diff --git a/.bazelrc b/.bazelrc deleted file mode 100644 index f9a0772778fd3..0000000000000 --- a/.bazelrc +++ /dev/null @@ -1,114 +0,0 @@ -build --cxxopt=--std=c++20 -build --copt=-I. -# Bazel does not support including its cc_library targets as system -# headers. We work around this for generated code -# (e.g. torch/headeronly/macros/cmake_macros.h) by making the generated directory a -# system include path. -build --copt=-isystem --copt bazel-out/k8-fastbuild/bin -build --copt=-isystem --copt bazel-out/darwin-fastbuild/bin -build --experimental_ui_max_stdouterr_bytes=2048576 - -# Configuration to disable tty features for environments like CI -build:no-tty --curses no -build:no-tty --progress_report_interval 10 -build:no-tty --show_progress_rate_limit 10 - -# Build with GPU support by default. -build --define=cuda=true -# rules_cuda configuration -build --@rules_cuda//cuda:enable_cuda -build --@rules_cuda//cuda:cuda_targets=sm_52 -build --@rules_cuda//cuda:compiler=nvcc -build --repo_env=CUDA_PATH=/usr/local/cuda - -# Configuration to build without GPU support -build:cpu-only --define=cuda=false -# define a separate build folder for faster switching between configs -build:cpu-only --platform_suffix=-cpu-only -# See the note on the config-less build for details about why we are -# doing this. We must also do it for the "-cpu-only" platform suffix. -build --copt=-isystem --copt=bazel-out/k8-fastbuild-cpu-only/bin -# rules_cuda configuration -build:cpu-only --@rules_cuda//cuda:enable_cuda=False - -# Definition of --config=shell -# interactive shell immediately before execution -build:shell --run_under="//tools/bazel_tools:shellwrap" - -# Disable all warnings for external repositories. We don't care about -# their warnings. -build --per_file_copt=^external/@-w - -# Set additional warnings to error level. -# -# Implementation notes: -# * we use file extensions to determine if we are using the C++ -# compiler or the cuda compiler -# * we use ^// at the start of the regex to only permit matching -# PyTorch files. This excludes external repos. -# -# Note that because this is logically a command-line flag, it is -# considered the word on what warnings are enabled. This has the -# unfortunate consequence of preventing us from disabling an error at -# the target level because those flags will come before these flags in -# the action invocation. Instead we provide per-file exceptions after -# this. -# -# On the bright side, this means we don't have to more broadly apply -# the exceptions to an entire target. -# -# Looking for CUDA flags? We have a cu_library macro that we can edit -# directly. Look in //tools/rules:cu.bzl for details. Editing the -# macro over this has the following advantages: -# * making changes does not require discarding the Bazel analysis -# cache -# * it allows for selective overrides on individual targets since the -# macro-level opts will come earlier than target level overrides - -build --per_file_copt='^//.*\.(cpp|cc)$'@-Werror=all -# The following warnings come from -Wall. We downgrade them from error -# to warnings here. -# -# We intentionally use #pragma unroll, which is compiler specific. -build --per_file_copt='^//.*\.(cpp|cc)$'@-Wno-error=unknown-pragmas - -build --per_file_copt='^//.*\.(cpp|cc)$'@-Werror=extra -# The following warnings come from -Wextra. We downgrade them from error -# to warnings here. -# -# unused-parameter-compare has a tremendous amount of violations in the -# codebase. It will be a lot of work to fix them, just disable it for -# now. -build --per_file_copt='^//.*\.(cpp|cc)$'@-Wno-unused-parameter -# missing-field-parameters has both a large number of violations in -# the codebase, but it also is used pervasively in the Python C -# API. There are a couple of catches though: -# * we use multiple versions of the Python API and hence have -# potentially multiple different versions of each relevant -# struct. They may have different numbers of fields. It will be -# unwieldy to support multiple versions in the same source file. -# * Python itself for many of these structs recommends only -# initializing a subset of the fields. We should respect the API -# usage conventions of our dependencies. -# -# Hence, we just disable this warning altogether. We may want to clean -# up some of the clear-cut cases that could be risky, but we still -# likely want to have this disabled for the most part. -build --per_file_copt='^//.*\.(cpp|cc)$'@-Wno-missing-field-initializers - -build --per_file_copt='^//.*\.(cpp|cc)$'@-Wno-unused-function -build --per_file_copt='^//.*\.(cpp|cc)$'@-Wno-unused-variable - -build --per_file_copt='//:aten/src/ATen/RegisterCompositeExplicitAutograd\.cpp$'@-Wno-error=unused-function -build --per_file_copt='//:aten/src/ATen/RegisterCompositeImplicitAutograd\.cpp$'@-Wno-error=unused-function -build --per_file_copt='//:aten/src/ATen/RegisterMkldnnCPU\.cpp$'@-Wno-error=unused-function -build --per_file_copt='//:aten/src/ATen/RegisterNestedTensorCPU\.cpp$'@-Wno-error=unused-function -build --per_file_copt='//:aten/src/ATen/RegisterQuantizedCPU\.cpp$'@-Wno-error=unused-function -build --per_file_copt='//:aten/src/ATen/RegisterSparseCPU\.cpp$'@-Wno-error=unused-function -build --per_file_copt='//:aten/src/ATen/RegisterSparseCsrCPU\.cpp$'@-Wno-error=unused-function -build --per_file_copt='//:aten/src/ATen/RegisterNestedTensorMeta\.cpp$'@-Wno-error=unused-function -build --per_file_copt='//:aten/src/ATen/RegisterSparseMeta\.cpp$'@-Wno-error=unused-function -build --per_file_copt='//:aten/src/ATen/RegisterQuantizedMeta\.cpp$'@-Wno-error=unused-function -build --per_file_copt='//:aten/src/ATen/RegisterZeroTensor\.cpp$'@-Wno-error=unused-function -build --per_file_copt='//:torch/csrc/lazy/generated/RegisterAutogradLazy\.cpp$'@-Wno-error=unused-function -build --per_file_copt='//:torch/csrc/lazy/generated/RegisterLazy\.cpp$'@-Wno-error=unused-function diff --git a/.bazelversion b/.bazelversion deleted file mode 100644 index f22d756da39d4..0000000000000 --- a/.bazelversion +++ /dev/null @@ -1 +0,0 @@ -6.5.0 diff --git a/.ci/docker/README.md b/.ci/docker/README.md index 6e0dcfd6b25d7..56d6ff1f2fe27 100644 --- a/.ci/docker/README.md +++ b/.ci/docker/README.md @@ -25,7 +25,6 @@ See `build.sh` for valid build environments (it's the giant switch). * `conda` - Dockerfile and build.sh to build Docker images used in nightly conda builds * `manywheel` - Dockerfile and build.sh to build Docker images used in nightly manywheel builds -* `libtorch` - Dockerfile and build.sh to build Docker images used in nightly libtorch builds ## Usage diff --git a/.ci/docker/almalinux/Dockerfile b/.ci/docker/almalinux/Dockerfile index b2bfb3e1212a7..d061353187c55 100644 --- a/.ci/docker/almalinux/Dockerfile +++ b/.ci/docker/almalinux/Dockerfile @@ -19,15 +19,12 @@ RUN git config --global --add safe.directory '*' ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH # cmake-3.18.4 from pip +# NS: Apr 1 2026 3.18.4 is gone, reported here https://github.com/scikit-build/cmake-python-distributions/issues/693 RUN yum install -y python3-pip && \ - python3 -mpip install cmake==3.18.4 && \ + python3 -mpip install cmake==3.18.4.post1 && \ ln -s /usr/local/bin/cmake /usr/bin/cmake3 RUN rm -rf /usr/local/cuda-* -FROM base as openssl -ADD ./common/install_openssl.sh install_openssl.sh -RUN bash ./install_openssl.sh && rm install_openssl.sh - FROM base as patchelf # Install patchelf ADD ./common/install_patchelf.sh install_patchelf.sh @@ -84,6 +81,8 @@ RUN yum -y update && \ yum -y install glibc-langpack-en && \ yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb RUN git config --global --add safe.directory '*' +# All rocm clang cfg files load the same rocm.cfg, make sure it points to the right toolchain. +RUN echo "--gcc-toolchain=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr" >> /opt/rocm/llvm/bin/rocm.cfg ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH FROM rocm_base as rocm @@ -109,7 +108,6 @@ COPY --from=cuda13.2 /usr/local/cuda-13.2 /usr/local/cuda-13.2 # Final step FROM ${BASE_TARGET} as final ARG DEVTOOLSET_VERSION=13 -COPY --from=openssl /opt/openssl /opt/openssl COPY --from=patchelf /patchelf /usr/local/bin/patchelf COPY --from=conda /opt/conda /opt/conda diff --git a/.ci/docker/almalinux/build.sh b/.ci/docker/almalinux/build.sh index 468f9b06418f7..668e7f71881c0 100755 --- a/.ci/docker/almalinux/build.sh +++ b/.ci/docker/almalinux/build.sh @@ -36,7 +36,7 @@ case ${DOCKER_TAG_PREFIX} in ;; rocm*) BASE_TARGET=rocm - PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151" + PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1200;gfx1201;gfx950;gfx1150;gfx1151" EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}" ;; *) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 23e9ef021c3ab..79ddeeb328438 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -95,7 +95,6 @@ case "$tag" in CUDA_VERSION=12.4 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes KATEX=yes TRITON=yes ;; @@ -103,7 +102,6 @@ case "$tag" in CUDA_VERSION=12.8.1 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes KATEX=yes TRITON=yes INSTALL_MINGW=yes @@ -112,15 +110,14 @@ case "$tag" in CUDA_VERSION=13.0.2 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes KATEX=yes TRITON=yes + INSTALL_MINGW=yes ;; pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks) CUDA_VERSION=13.0.2 ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes KATEX=yes TRITON=yes INDUCTOR_BENCHMARKS=yes @@ -129,35 +126,32 @@ case "$tag" in CUDA_VERSION=13.0.2 ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=11 - VISION=yes KATEX=yes TRITON=yes ;; - pytorch-linux-jammy-py3-clang15-onnx) + pytorch-linux-jammy-py3.10-clang18) ANACONDA_PYTHON_VERSION=3.10 - CLANG_VERSION=15 - VISION=yes + CLANG_VERSION=18 + GCC_VERSION=11 + KATEX=yes + DOCS=yes ONNX=yes ;; - pytorch-linux-jammy-py3.10-clang15) - ANACONDA_PYTHON_VERSION=3.10 - CLANG_VERSION=15 - ;; - pytorch-linux-jammy-py3.11-clang15) + pytorch-linux-jammy-py3.11-clang18) ANACONDA_PYTHON_VERSION=3.11 - CLANG_VERSION=15 + CLANG_VERSION=18 ;; - pytorch-linux-jammy-py3.12-clang15) + pytorch-linux-jammy-py3.12-clang18) ANACONDA_PYTHON_VERSION=3.12 - CLANG_VERSION=15 + CLANG_VERSION=18 ;; - pytorch-linux-jammy-py3.13-clang15) + pytorch-linux-jammy-py3.13-clang18) ANACONDA_PYTHON_VERSION=3.13 - CLANG_VERSION=15 + CLANG_VERSION=18 ;; - pytorch-linux-jammy-py3.14-clang15) + pytorch-linux-jammy-py3.14-clang18) ANACONDA_PYTHON_VERSION=3.14 - CLANG_VERSION=15 + CLANG_VERSION=18 ;; pytorch-linux-jammy-rocm-n-py3 | pytorch-linux-jammy-rocm-n-py3-benchmarks | pytorch-linux-noble-rocm-n-py3) if [[ $tag =~ "jammy" ]]; then @@ -166,9 +160,7 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.12 fi GCC_VERSION=13 - VISION=yes ROCM_VERSION=7.2 - NINJA_VERSION=1.9.0 TRITON=yes KATEX=yes PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950;gfx1100" @@ -179,9 +171,7 @@ case "$tag" in pytorch-linux-noble-rocm-nightly-py3) ANACONDA_PYTHON_VERSION=3.12 GCC_VERSION=13 - VISION=yes ROCM_VERSION=nightly - NINJA_VERSION=1.9.0 TRITON=yes KATEX=yes PYTORCH_ROCM_ARCH="gfx942" @@ -189,23 +179,19 @@ case "$tag" in pytorch-linux-jammy-xpu-n-1-py3) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes XPU_VERSION=2025.2 XPU_DRIVER_TYPE=LTS - NINJA_VERSION=1.9.0 TRITON=yes ;; pytorch-linux-noble-xpu-n-py3 | pytorch-linux-noble-xpu-n-py3-client | pytorch-linux-noble-xpu-n-py3-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=13 - VISION=yes XPU_VERSION=2025.3 if [[ $tag =~ "client" ]]; then XPU_DRIVER_TYPE=CLIENT else XPU_DRIVER_TYPE=LTS fi - NINJA_VERSION=1.9.0 TRITON=yes if [[ $tag =~ "benchmarks" ]]; then INDUCTOR_BENCHMARKS=yes @@ -214,36 +200,19 @@ case "$tag" in pytorch-linux-jammy-py3-gcc11-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 - VISION=yes KATEX=yes - TRITON=yes DOCS=yes INDUCTOR_BENCHMARKS=yes ;; - pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-clang15) + pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-clang18) ANACONDA_PYTHON_VERSION=3.10 CUDA_VERSION=12.8.1 - CLANG_VERSION=15 - VISION=yes - TRITON=yes - ;; - pytorch-linux-jammy-py3-clang18-asan) - ANACONDA_PYTHON_VERSION=3.10 CLANG_VERSION=18 - VISION=yes - ;; - pytorch-linux-jammy-py3.10-gcc11) - ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=11 - VISION=yes - KATEX=yes TRITON=yes - DOCS=yes - UNINSTALL_DILL=yes ;; - pytorch-linux-jammy-py3-clang15-executorch) + pytorch-linux-jammy-py3-clang18-executorch) ANACONDA_PYTHON_VERSION=3.10 - CLANG_VERSION=15 + CLANG_VERSION=18 EXECUTORCH=yes ;; pytorch-linux-jammy-py3.12-halide) @@ -288,21 +257,13 @@ case "$tag" in ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=13 ACL=yes - VISION=yes OPENBLAS=yes - # snadampal: skipping llvm src build install because the current version - # from pytorch/llvm:9.0.1 is x86 specific - SKIP_LLVM_SRC_BUILD_INSTALL=yes ;; pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=13 ACL=yes - VISION=yes OPENBLAS=yes - # snadampal: skipping llvm src build install because the current version - # from pytorch/llvm:9.0.1 is x86 specific - SKIP_LLVM_SRC_BUILD_INSTALL=yes INDUCTOR_BENCHMARKS=yes ;; pytorch-linux-noble-riscv64-py3.12-gcc14) @@ -310,7 +271,6 @@ case "$tag" in ;; *) # Catch-all for builds that are not hardcoded. - VISION=yes echo "image '$image' did not match an existing build configuration" if [[ "$image" == *py* ]]; then extract_version_from_image_name py ANACONDA_PYTHON_VERSION @@ -318,6 +278,7 @@ case "$tag" in then ANACONDA_PYTHON_VERSION=${ANACONDA_PYTHON_VERSION%?} PYTHON_FREETHREADED=1 + TSAN=yes fi fi if [[ "$image" == *cuda* ]]; then @@ -327,14 +288,10 @@ case "$tag" in if [[ -z "$ROCM_VERSION" ]]; then extract_version_from_image_name rocm ROCM_VERSION fi - NINJA_VERSION=1.9.0 TRITON=yes # To ensure that any ROCm config will build using conda cmake # and thus have LAPACK/MKL enabled fi - if [[ "$image" == *centos7* ]]; then - NINJA_VERSION=1.10.2 - fi if [[ "$image" == *gcc* ]]; then extract_version_from_image_name gcc GCC_VERSION fi @@ -366,7 +323,6 @@ docker buildx build \ ${progress_flag} \ --build-arg "BUILD_ENVIRONMENT=${image}" \ --build-arg "LLVMDEV=${LLVMDEV:-}" \ - --build-arg "VISION=${VISION:-}" \ --build-arg "UBUNTU_VERSION=${UBUNTU_VERSION}" \ --build-arg "DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}" \ --build-arg "GLIBC_VERSION=${GLIBC_VERSION}" \ @@ -376,7 +332,6 @@ docker buildx build \ --build-arg "PYTHON_VERSION=${PYTHON_VERSION}" \ --build-arg "GCC_VERSION=${GCC_VERSION}" \ --build-arg "CUDA_VERSION=${CUDA_VERSION}" \ - --build-arg "NINJA_VERSION=${NINJA_VERSION:-}" \ --build-arg "KATEX=${KATEX:-}" \ --build-arg "ROCM_VERSION=${ROCM_VERSION:-}" \ --build-arg "PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}" \ @@ -390,13 +345,12 @@ docker buildx build \ --build-arg "HALIDE=${HALIDE}" \ --build-arg "PALLAS=${PALLAS}" \ --build-arg "TPU=${TPU}" \ + --build-arg "TSAN=${TSAN}" \ --build-arg "XPU_VERSION=${XPU_VERSION}" \ --build-arg "XPU_DRIVER_TYPE=${XPU_DRIVER_TYPE}" \ - --build-arg "UNINSTALL_DILL=${UNINSTALL_DILL}" \ --build-arg "ACL=${ACL:-}" \ --build-arg "OPENBLAS=${OPENBLAS:-}" \ --build-arg "SKIP_SCCACHE_INSTALL=${SKIP_SCCACHE_INSTALL:-}" \ - --build-arg "SKIP_LLVM_SRC_BUILD_INSTALL=${SKIP_LLVM_SRC_BUILD_INSTALL:-}" \ --build-arg "INSTALL_MINGW=${INSTALL_MINGW:-}" \ -f $(dirname ${DOCKERFILE})/Dockerfile \ --load \ diff --git a/.ci/docker/ci_commit_pins/huggingface-requirements.txt b/.ci/docker/ci_commit_pins/huggingface-requirements.txt index 08538ff511057..51a16f10e0632 100644 --- a/.ci/docker/ci_commit_pins/huggingface-requirements.txt +++ b/.ci/docker/ci_commit_pins/huggingface-requirements.txt @@ -1,2 +1,2 @@ -transformers==5.2.0 +transformers==5.5.3 soxr==0.5.0 diff --git a/.ci/docker/ci_commit_pins/nccl-cu126.txt b/.ci/docker/ci_commit_pins/nccl-cu126.txt new file mode 100644 index 0000000000000..1706c910183ce --- /dev/null +++ b/.ci/docker/ci_commit_pins/nccl-cu126.txt @@ -0,0 +1 @@ +v2.29.3-1 diff --git a/.ci/docker/ci_commit_pins/nccl.txt b/.ci/docker/ci_commit_pins/nccl.txt index 1706c910183ce..9ad2e5cfc6595 100644 --- a/.ci/docker/ci_commit_pins/nccl.txt +++ b/.ci/docker/ci_commit_pins/nccl.txt @@ -1 +1 @@ -v2.29.3-1 +v2.29.7-1 diff --git a/.ci/docker/ci_commit_pins/triton-xpu.txt b/.ci/docker/ci_commit_pins/triton-xpu.txt index e6c93e1b5432c..3bec45a60a706 100644 --- a/.ci/docker/ci_commit_pins/triton-xpu.txt +++ b/.ci/docker/ci_commit_pins/triton-xpu.txt @@ -1 +1 @@ -33f782efa9464adebb448ea1f1df1a64ec37ceb0 +21033c4e2be9b42c9e6ce7a39a70ead2aba279b4 diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 3d17e9c0de64b..2d65035813652 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1,5 @@ +<<<<<<< HEAD ba5c1517e6f5906761cf5783036efb587026208d +======= +88b227e23f0445f3f695bad05bbf1a363b4f50e0 +>>>>>>> upstream/main diff --git a/.ci/docker/common/install_amdsmi.sh b/.ci/docker/common/install_amdsmi.sh index 8e0ee620da679..759e2bababe25 100644 --- a/.ci/docker/common/install_amdsmi.sh +++ b/.ci/docker/common/install_amdsmi.sh @@ -7,7 +7,7 @@ source /etc/rocm_env.sh # For theRock nightly, amd_smi may already be installed or in a different location if [ -d "${ROCM_PATH}/share/amd_smi" ]; then echo "Installing amdsmi from: ${ROCM_PATH}/share/amd_smi" - cd ${ROCM_PATH}/share/amd_smi && pip install . + cd ${ROCM_PATH}/share/amd_smi && python3 -m pip install . else echo "AMD SMI not found at ${ROCM_PATH}/share/amd_smi - skipping (may already be installed via pip)" fi diff --git a/.ci/docker/common/install_base.sh b/.ci/docker/common/install_base.sh index 7d8ae247d7a0b..0f041abad6eb1 100755 --- a/.ci/docker/common/install_base.sh +++ b/.ci/docker/common/install_base.sh @@ -11,36 +11,26 @@ install_ubuntu() { # "$UBUNTU_VERSION" == "18.04" if [[ "$UBUNTU_VERSION" == "20.04"* ]]; then cmake3="cmake=3.16*" - maybe_libiomp_dev="" elif [[ "$UBUNTU_VERSION" == "22.04"* ]]; then cmake3="cmake=3.22*" - maybe_libiomp_dev="" elif [[ "$UBUNTU_VERSION" == "24.04"* ]]; then cmake3="cmake=3.28*" - maybe_libiomp_dev="" else - cmake3="cmake=3.5*" - maybe_libiomp_dev="libiomp-dev" - fi - - if [[ "$CLANG_VERSION" == 15 ]]; then - maybe_libomp_dev="libomp-15-dev" - elif [[ "$CLANG_VERSION" == 12 ]]; then - maybe_libomp_dev="libomp-12-dev" - elif [[ "$CLANG_VERSION" == 10 ]]; then - maybe_libomp_dev="libomp-10-dev" - else - maybe_libomp_dev="" + echo "Unknown Ubuntu version $UBUNTU_VERSION" + exit 1 fi # Install common dependencies apt-get update + # Install prerequisites for add-apt-repository (needs gpg-agent for PPA key import) + apt-get install -y --no-install-recommends software-properties-common gpg-agent + # Add git-core PPA for a newer version of git + add-apt-repository ppa:git-core/ppa -y + apt-get update # TODO: Some of these may not be necessary - ccache_deps="asciidoc docbook-xml docbook-xsl xsltproc" deploy_deps="libffi-dev libbz2-dev libreadline-dev libncurses5-dev libncursesw5-dev libgdbm-dev libsqlite3-dev uuid-dev tk-dev" numpy_deps="gfortran" apt-get install -y --no-install-recommends \ - $ccache_deps \ $numpy_deps \ ${deploy_deps} \ ${cmake3} \ @@ -53,14 +43,14 @@ install_ubuntu() { git \ libatlas-base-dev \ libc6-dbg \ - ${maybe_libiomp_dev} \ libyaml-dev \ libz-dev \ libjemalloc2 \ + libgl1 \ libjpeg-dev \ libasound2-dev \ libsndfile-dev \ - ${maybe_libomp_dev} \ + libssl-dev \ software-properties-common \ wget \ sudo \ @@ -71,7 +61,9 @@ install_ubuntu() { unzip \ gpg-agent \ gdb \ - bc + bc \ + zip \ + valgrind # Should resolve issues related to various apt package repository cert issues # see: https://github.com/pytorch/pytorch/issues/65931 @@ -82,70 +74,14 @@ install_ubuntu() { rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* } -install_centos() { - # Need EPEL for many packages we depend on. - # See http://fedoraproject.org/wiki/EPEL - yum --enablerepo=extras install -y epel-release - - ccache_deps="asciidoc docbook-dtds docbook-style-xsl libxslt" - numpy_deps="gcc-gfortran" - yum install -y \ - $ccache_deps \ - $numpy_deps \ - autoconf \ - automake \ - bzip2 \ - cmake \ - cmake3 \ - curl \ - gcc \ - gcc-c++ \ - gflags-devel \ - git \ - glibc-devel \ - glibc-headers \ - glog-devel \ - libstdc++-devel \ - libsndfile-devel \ - make \ - opencv-devel \ - sudo \ - wget \ - vim \ - unzip \ - gdb - - # Cleanup - yum clean all - rm -rf /var/cache/yum - rm -rf /var/lib/yum/yumdb - rm -rf /var/lib/yum/history -} - # Install base packages depending on the base OS ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') case "$ID" in ubuntu) install_ubuntu ;; - centos) - install_centos - ;; *) echo "Unable to determine OS..." exit 1 ;; esac - -# Install Valgrind separately since the apt-get version is too old. -mkdir valgrind_build && cd valgrind_build -VALGRIND_VERSION=3.20.0 -wget https://ossci-linux.s3.amazonaws.com/valgrind-${VALGRIND_VERSION}.tar.bz2 -tar -xjf valgrind-${VALGRIND_VERSION}.tar.bz2 -cd valgrind-${VALGRIND_VERSION} -./configure --prefix=/usr/local -make -j$[$(nproc) - 2] -sudo make install -cd ../../ -rm -rf valgrind_build -alias valgrind="/usr/local/bin/valgrind" diff --git a/.ci/docker/common/install_cache.sh b/.ci/docker/common/install_cache.sh index 9bb80a4e80eca..fd35f32c68f9f 100644 --- a/.ci/docker/common/install_cache.sh +++ b/.ci/docker/common/install_cache.sh @@ -18,7 +18,8 @@ install_ubuntu() { cp target/release/sccache-dist /opt/cache/bin echo "Cleaning up" cd .. - rm -rf sccache .cargo + rm -rf sccache + rustup self uninstall -y apt-get remove -y pkg-config libssl-dev apt-get autoclean && apt-get clean diff --git a/.ci/docker/common/install_clang.sh b/.ci/docker/common/install_clang.sh index 93daeee919b3d..20eea1ee157fe 100755 --- a/.ci/docker/common/install_clang.sh +++ b/.ci/docker/common/install_clang.sh @@ -15,7 +15,7 @@ if [ -n "$CLANG_VERSION" ]; then sudo apt-get update if [[ $CLANG_VERSION -ge 18 ]]; then - apt-get install -y libomp-${CLANG_VERSION}-dev libclang-rt-${CLANG_VERSION}-dev clang-"$CLANG_VERSION" llvm-"$CLANG_VERSION" + apt-get install -y --no-install-recommends libomp-${CLANG_VERSION}-dev libclang-rt-${CLANG_VERSION}-dev clang-"$CLANG_VERSION" llvm-"$CLANG_VERSION" else apt-get install -y --no-install-recommends clang-"$CLANG_VERSION" llvm-"$CLANG_VERSION" fi diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index 57c9845d76a0b..547eec0d401f5 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -67,13 +67,6 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then conda_install sqlite fi - # Install PyTorch conda deps, as per https://github.com/pytorch/pytorch README - if [[ $(uname -m) != "aarch64" ]]; then - pip_install mkl==2024.2.0 - pip_install mkl-static==2024.2.0 - pip_install mkl-include==2024.2.0 - fi - # Install llvm-8 as it is required to compile llvmlite-0.30.0 from source # and libpython-static for torch deploy conda_install llvmdev=8.0.0 "libpython-static=${ANACONDA_PYTHON_VERSION}" @@ -103,5 +96,8 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then pip_install -r /opt/conda/requirements-docs.txt fi + # Clean conda package cache + as_jenkins conda clean -ya + popd fi diff --git a/.ci/docker/common/install_conda_docker.sh b/.ci/docker/common/install_conda_docker.sh index dc377075750ac..9665a799b6ff5 100755 --- a/.ci/docker/common/install_conda_docker.sh +++ b/.ci/docker/common/install_conda_docker.sh @@ -13,8 +13,8 @@ rm $(basename "$MINICONDA_URL") export PATH=/opt/conda/bin:$PATH # See https://github.com/pytorch/builder/issues/1473 # Pin conda to 23.5.2 as it's the last one compatible with openssl-1.1.1 -conda install -y conda=23.5.2 conda-build anaconda-client git ninja +conda install -y conda=23.5.2 conda-build anaconda-client git # The cmake version here needs to match with the minimum version of cmake -# supported by PyTorch (3.18). There is only 3.18.2 on anaconda -/opt/conda/bin/pip3 install cmake==3.18.2 +# supported by PyTorch (3.18). +/opt/conda/bin/pip3 install cmake==3.18.4.post1 ninja conda remove -y --force patchelf diff --git a/.ci/docker/common/install_cpython.sh b/.ci/docker/common/install_cpython.sh index 1c5fa03010801..4ccb380bfbb8d 100755 --- a/.ci/docker/common/install_cpython.sh +++ b/.ci/docker/common/install_cpython.sh @@ -27,10 +27,14 @@ function do_cpython_build { check_var $py_folder tar -xzf Python-$py_ver.tgz + local base_ver=${py_ver%%+*} local additional_flags="" - if [[ "$py_ver" == *"t" ]]; then + if [[ "$base_ver" == *"t" ]]; then additional_flags=" --disable-gil" fi + if [[ "$py_ver" == *"+tsan" ]]; then + additional_flags+=" --with-thread-sanitizer" + fi pushd $py_folder @@ -74,18 +78,22 @@ function do_cpython_build { # packaging is needed to create symlink since wheel no longer provides needed information retry ${prefix}/bin/pip install packaging==25.0 wheel==0.45.1 setuptools==80.9.0 local abi_tag=$(${prefix}/bin/python -c "from packaging.tags import interpreter_name, interpreter_version; import sysconfig ; from sysconfig import get_config_var; print('{0}{1}-{0}{1}{2}'.format(interpreter_name(), interpreter_version(), 't' if sysconfig.get_config_var('Py_GIL_DISABLED') else ''))") + # Append build variant suffix (e.g., "+tsan") to the abi tag + if [[ "$py_ver" == *"+"* ]]; then + abi_tag="${abi_tag}+${py_ver#*+}" + fi ln -sf ${prefix} /opt/python/${abi_tag} } function build_cpython { local py_ver=$1 check_var $py_ver - local py_suffix=$py_ver - local py_folder=$py_ver + local py_suffix=${py_ver%%+*} + local py_folder=$py_suffix # Special handling for nogil - if [[ "${py_ver}" == *"t" ]]; then - py_suffix=${py_ver::-1} + if [[ "${py_suffix}" == *"t" ]]; then + py_suffix=${py_suffix::-1} py_folder=$py_suffix fi retry wget -q $PYTHON_DOWNLOAD_URL/$py_folder/Python-$py_suffix.tgz -O Python-$py_ver.tgz diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index a6cfb8c27680e..9d7cd7ad78c05 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -82,21 +82,23 @@ function install_nvshmem { function install_124 { CUDNN_VERSION=9.1.0.70 - echo "Installing CUDA 12.4.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-0.6.2" + CUSPARSELT_VERSION=0.6.2.3 + echo "Installing CUDA 12.4.1 and cuDNN ${CUDNN_VERSION} and NCCL and cuSparseLt-${CUSPARSELT_VERSION}" install_cuda 12.4.1 cuda_12.4.1_550.54.15_linux install_cudnn 12 $CUDNN_VERSION CUDA_VERSION=12.4 bash install_nccl.sh - CUDA_VERSION=12.4 bash install_cusparselt.sh + CUDA_VERSION=12.4 bash install_cusparselt.sh $CUSPARSELT_VERSION ldconfig } function install_126 { CUDNN_VERSION=9.10.2.21 - echo "Installing CUDA 12.6.3 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" + CUSPARSELT_VERSION=0.7.1.0 + echo "Installing CUDA 12.6.3 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-${CUSPARSELT_VERSION}" install_cuda 12.6.3 cuda_12.6.3_560.35.05_linux install_cudnn 12 $CUDNN_VERSION @@ -105,14 +107,15 @@ function install_126 { CUDA_VERSION=12.6 bash install_nccl.sh - CUDA_VERSION=12.6 bash install_cusparselt.sh + CUDA_VERSION=12.6 bash install_cusparselt.sh $CUSPARSELT_VERSION ldconfig } function install_129 { - CUDNN_VERSION=9.17.1.4 - echo "Installing CUDA 12.9.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" + CUDNN_VERSION=9.20.0.48 + CUSPARSELT_VERSION=0.8.1.1 + echo "Installing CUDA 12.9.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-${CUSPARSELT_VERSION}" # install CUDA 12.9.1 in the same container install_cuda 12.9.1 cuda_12.9.1_575.57.08_linux @@ -123,14 +126,15 @@ function install_129 { CUDA_VERSION=12.9 bash install_nccl.sh - CUDA_VERSION=12.9 bash install_cusparselt.sh + CUDA_VERSION=12.9 bash install_cusparselt.sh $CUSPARSELT_VERSION ldconfig } function install_128 { - CUDNN_VERSION=9.19.0.56 - echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" + CUDNN_VERSION=9.20.0.48 + CUSPARSELT_VERSION=0.7.1.0 + echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-${CUSPARSELT_VERSION}" # install CUDA 12.8.1 in the same container install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux @@ -141,14 +145,15 @@ function install_128 { CUDA_VERSION=12.8 bash install_nccl.sh - CUDA_VERSION=12.8 bash install_cusparselt.sh + CUDA_VERSION=12.8 bash install_cusparselt.sh $CUSPARSELT_VERSION ldconfig } function install_130 { - CUDNN_VERSION=9.19.0.56 - echo "Installing CUDA 13.0 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.8.0" + CUDNN_VERSION=9.20.0.48 + CUSPARSELT_VERSION=0.8.1.1 + echo "Installing CUDA 13.0 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-${CUSPARSELT_VERSION}" # install CUDA 13.0 in the same container install_cuda 13.0.2 cuda_13.0.2_580.95.05_linux @@ -159,16 +164,17 @@ function install_130 { CUDA_VERSION=13.0 bash install_nccl.sh - CUDA_VERSION=13.0 bash install_cusparselt.sh + CUDA_VERSION=13.0 bash install_cusparselt.sh $CUSPARSELT_VERSION ldconfig } function install_132 { - CUDNN_VERSION=9.19.0.56 - echo "Installing CUDA 13.2 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.8.0" + CUDNN_VERSION=9.20.0.48 + CUSPARSELT_VERSION=0.8.1.1 + echo "Installing CUDA 13.2 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-${CUSPARSELT_VERSION}" # install CUDA 13.2 in the same container - install_cuda 13.2.0 cuda_13.2.0_595.45.04_linux + install_cuda 13.2.1 cuda_13.2.1_595.58.03_linux # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement install_cudnn 13 $CUDNN_VERSION @@ -177,7 +183,7 @@ function install_132 { CUDA_VERSION=13.2 bash install_nccl.sh - CUDA_VERSION=13.2 bash install_cusparselt.sh + CUDA_VERSION=13.2 bash install_cusparselt.sh $CUSPARSELT_VERSION ldconfig } diff --git a/.ci/docker/common/install_cusparselt.sh b/.ci/docker/common/install_cusparselt.sh index b532c086371f1..0568dd1a18f55 100644 --- a/.ci/docker/common/install_cusparselt.sh +++ b/.ci/docker/common/install_cusparselt.sh @@ -5,34 +5,29 @@ set -ex # cuSPARSELt license: https://docs.nvidia.com/cuda/cusparselt/license.html mkdir tmp_cusparselt && cd tmp_cusparselt -if [[ ${CUDA_VERSION:0:4} =~ "13" ]]; then - arch_path='sbsa' - export TARGETARCH=${TARGETARCH:-$(uname -m)} - if [ ${TARGETARCH} = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then - arch_path='x86_64' - fi - CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.8.0.4_cuda13-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz -elif [[ ${CUDA_VERSION:0:4} =~ ^12\.[5-9]$ ]]; then - arch_path='sbsa' - export TARGETARCH=${TARGETARCH:-$(uname -m)} - if [ ${TARGETARCH} = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then - arch_path='x86_64' - fi - CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.7.1.0-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz -elif [[ ${CUDA_VERSION:0:4} == "12.4" ]]; then - arch_path='sbsa' - export TARGETARCH=${TARGETARCH:-$(uname -m)} - if [ ${TARGETARCH} = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then - arch_path='x86_64' - fi - CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-0.6.2.3-archive" - curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz +cusparselt_version=$1 + +arch_path='sbsa' +export TARGETARCH=${TARGETARCH:-$(uname -m)} +if [ ${TARGETARCH} = 'amd64' ] || [ "${TARGETARCH}" = 'x86_64' ]; then + arch_path='x86_64' +fi + +if [[ -z "${cusparselt_version}" ]]; then + echo "Usage: install_cusparselt.sh " + exit 1 +fi + +cuda_major_version=${CUDA_VERSION%%.*} +cusparselt_minor=$(echo "${cusparselt_version}" | cut -d. -f2) +# Starting from 0.8.0, NVIDIA ships separate archives per CUDA major version +if [[ "${cusparselt_minor}" -ge 8 ]]; then + CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-${cusparselt_version}_cuda${cuda_major_version}-archive" else - echo "Not sure which libcusparselt version to install for this ${CUDA_VERSION}" + CUSPARSELT_NAME="libcusparse_lt-linux-${arch_path}-${cusparselt_version}-archive" fi +curl --retry 3 -OLs https://developer.download.nvidia.com/compute/cusparselt/redist/libcusparse_lt/linux-${arch_path}/${CUSPARSELT_NAME}.tar.xz tar xf ${CUSPARSELT_NAME}.tar.xz cp -a ${CUSPARSELT_NAME}/include/* /usr/local/cuda/include/ cp -a ${CUSPARSELT_NAME}/lib/* /usr/local/cuda/lib64/ diff --git a/.ci/docker/common/install_devtoolset.sh b/.ci/docker/common/install_devtoolset.sh deleted file mode 100755 index bdae637598138..0000000000000 --- a/.ci/docker/common/install_devtoolset.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash - -set -ex - -[ -n "$DEVTOOLSET_VERSION" ] - -yum install -y centos-release-scl -yum install -y devtoolset-$DEVTOOLSET_VERSION - -echo "source scl_source enable devtoolset-$DEVTOOLSET_VERSION" > "/etc/profile.d/devtoolset-$DEVTOOLSET_VERSION.sh" diff --git a/.ci/docker/common/install_docs_reqs.sh b/.ci/docker/common/install_docs_reqs.sh index c907145f2ec62..c06160373a05e 100644 --- a/.ci/docker/common/install_docs_reqs.sh +++ b/.ci/docker/common/install_docs_reqs.sh @@ -17,7 +17,7 @@ if [ -n "$KATEX" ]; then apt-get install -y --no-install-recommends yarn yarn global add katex --prefix /usr/local - sudo apt-get -y install doxygen + sudo apt-get -y install doxygen lcov apt-get autoclean && apt-get clean rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* diff --git a/.ci/docker/common/install_glibc.sh b/.ci/docker/common/install_glibc.sh deleted file mode 100755 index c98791e2bf85b..0000000000000 --- a/.ci/docker/common/install_glibc.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -set -ex - -[ -n "$GLIBC_VERSION" ] -if [[ -n "$CENTOS_VERSION" ]]; then - [ -n "$DEVTOOLSET_VERSION" ] -fi - -yum install -y wget sed - -mkdir -p /packages && cd /packages -wget -q http://ftp.gnu.org/gnu/glibc/glibc-$GLIBC_VERSION.tar.gz -tar xzf glibc-$GLIBC_VERSION.tar.gz -if [[ "$GLIBC_VERSION" == "2.26" ]]; then - cd glibc-$GLIBC_VERSION - sed -i 's/$name ne "nss_test1"/$name ne "nss_test1" \&\& $name ne "nss_test2"/' scripts/test-installation.pl - cd .. -fi -mkdir -p glibc-$GLIBC_VERSION-build && cd glibc-$GLIBC_VERSION-build - -if [[ -n "$CENTOS_VERSION" ]]; then - export PATH=/opt/rh/devtoolset-$DEVTOOLSET_VERSION/root/usr/bin:$PATH -fi - -../glibc-$GLIBC_VERSION/configure --prefix=/usr CFLAGS='-Wno-stringop-truncation -Wno-format-overflow -Wno-restrict -Wno-format-truncation -g -O2' -make -j$(nproc) -make install - -# Cleanup -rm -rf /packages -rm -rf /var/cache/yum/* -rm -rf /var/lib/rpm/__db.* -yum clean all diff --git a/.ci/docker/common/install_inductor_benchmark_deps.sh b/.ci/docker/common/install_inductor_benchmark_deps.sh index 674b141efcfb2..c54b8a44f0632 100644 --- a/.ci/docker/common/install_inductor_benchmark_deps.sh +++ b/.ci/docker/common/install_inductor_benchmark_deps.sh @@ -18,18 +18,16 @@ function install_timm() { function install_torchbench() { local commit commit=$(get_pinned_commit torchbench) - git clone https://github.com/pytorch/benchmark torchbench + mkdir torchbench && chown jenkins torchbench + as_jenkins git clone https://github.com/pytorch/benchmark torchbench pushd torchbench - git checkout "$commit" + as_jenkins git checkout "$commit" - python install.py --continue_on_fail + conda_run python install.py --continue_on_fail echo "Print all dependencies after TorchBench is installed" - python -mpip freeze + conda_run python -mpip freeze popd - - chown -R jenkins torchbench - chown -R jenkins /opt/conda } # Pango is needed for weasyprint which is needed for doctr diff --git a/.ci/docker/common/install_mingw.sh b/.ci/docker/common/install_mingw.sh index 6232a0d0245c7..e82a666ff4352 100644 --- a/.ci/docker/common/install_mingw.sh +++ b/.ci/docker/common/install_mingw.sh @@ -4,7 +4,7 @@ set -ex # Install MinGW-w64 for Windows cross-compilation apt-get update -apt-get install -y g++-mingw-w64-x86-64-posix +apt-get install -y g++-mingw-w64-x86-64-posix mingw-w64-tools echo "MinGW-w64 installed successfully" x86_64-w64-mingw32-g++ --version diff --git a/.ci/docker/common/install_miopen.sh b/.ci/docker/common/install_miopen.sh index 3dbc67b90abaf..039458add8406 100644 --- a/.ci/docker/common/install_miopen.sh +++ b/.ci/docker/common/install_miopen.sh @@ -16,7 +16,7 @@ case "$ID" in ubuntu) IS_UBUNTU=1 ;; - centos|almalinux) + almalinux) IS_UBUNTU=0 ;; *) diff --git a/.ci/docker/common/install_nccl.sh b/.ci/docker/common/install_nccl.sh index 486604140a983..f505e4c3f249a 100644 --- a/.ci/docker/common/install_nccl.sh +++ b/.ci/docker/common/install_nccl.sh @@ -18,6 +18,11 @@ NCCL_VERSION=$(cat ci_commit_pins/nccl.txt) # exit 1 # fi +# Use the NCCL version for CUDA 12.6 due to sm50 support +if [[ ${CUDA_VERSION:0:4} == "12.6" ]]; then + NCCL_VERSION=$(cat ci_commit_pins/nccl-cu126.txt) +fi + if [[ -n "${NCCL_VERSION}" ]]; then # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build diff --git a/.ci/docker/common/install_ninja.sh b/.ci/docker/common/install_ninja.sh deleted file mode 100644 index fa380722bdc2f..0000000000000 --- a/.ci/docker/common/install_ninja.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash - -set -ex - -[ -n "$NINJA_VERSION" ] - -arch=$(uname -m) -if [ "$arch" == "aarch64" ]; then - url="https://github.com/ninja-build/ninja/releases/download/v${NINJA_VERSION}/ninja-linux-aarch64.zip" -else - url="https://github.com/ninja-build/ninja/releases/download/v${NINJA_VERSION}/ninja-linux.zip" -fi - -pushd /tmp -wget --no-verbose --output-document=ninja-linux.zip "$url" -unzip ninja-linux.zip -d /usr/local/bin -rm -f ninja-linux.zip -popd \ No newline at end of file diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index 36ce5b11d9135..eacc7bbcad157 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -11,15 +11,11 @@ retry () { # ONNXRuntime should be installed before installing # onnx-weekly. Otherwise, onnx-weekly could be # overwritten by onnx. +# Note: parameterized, pytest-subtests, tabulate, packaging are already +# installed via requirements-ci.txt pip_install \ - parameterized==0.8.1 \ - pytest-cov==4.0.0 \ - pytest-subtests==0.10.0 \ - tabulate==0.9.0 \ - transformers==4.36.2 - -pip_install coloredlogs packaging -pip_install onnxruntime==1.23.1 + transformers==4.36.2 \ + onnxruntime==1.23.1 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/ @@ -34,4 +30,5 @@ conda_run python "${IMPORT_SCRIPT_FILENAME}" # Cleaning up conda_run pip uninstall -y torch +conda_run pip cache purge rm "${IMPORT_SCRIPT_FILENAME}" || true diff --git a/.ci/docker/common/install_openssl.sh b/.ci/docker/common/install_openssl.sh deleted file mode 100644 index c73c9c333c002..0000000000000 --- a/.ci/docker/common/install_openssl.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -set -ex - -OPENSSL=openssl-1.1.1k - -wget -q -O "${OPENSSL}.tar.gz" "https://ossci-linux.s3.amazonaws.com/${OPENSSL}.tar.gz" -tar xf "${OPENSSL}.tar.gz" -cd "${OPENSSL}" -./config --prefix=/opt/openssl -d '-Wl,--enable-new-dtags,-rpath,$(LIBRPATH)' -# NOTE: openssl install errors out when built with the -j option -NPROC=$[$(nproc) - 2] -make -j${NPROC}; make install_sw -# Link the ssl libraries to the /usr/lib folder. -sudo ln -s /opt/openssl/lib/lib* /usr/lib -cd .. -rm -rf "${OPENSSL}" diff --git a/.ci/docker/common/install_rocSHMEM.sh b/.ci/docker/common/install_rocSHMEM.sh new file mode 100644 index 0000000000000..ff59951f94875 --- /dev/null +++ b/.ci/docker/common/install_rocSHMEM.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +# Script used only in CD pipeline to build and install rocSHMEM + +set -eou pipefail + +function do_install() { + ROCSHMEM_VERSION=ea5c137103f18a9aadd570d09d72e78ec52f0a3a + rocm_dir="${ROCM_HOME:-}" + if [[ -z "${rocm_dir}" && -f /etc/rocm_env.sh ]]; then + source /etc/rocm_env.sh + rocm_dir="${ROCM_HOME:-}" + fi + rocm_dir="${rocm_dir:-/opt/rocm}" + echo "install_rocSHMEM.sh: using ROCM install prefix ${rocm_dir}" + if [[ -f "${rocm_dir}/lib/librocshmem.a" ]]; then + echo "install_rocSHMEM.sh: librocshmem.a already present in ${rocm_dir}/lib, skipping build" + return + fi + ( + set -x + curr_dir=$(pwd) + tmp_dir=$(mktemp -d) + + git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-systems.git ${tmp_dir}/rocm-systems + cd ${tmp_dir}/rocm-systems + git sparse-checkout set --cone projects/rocshmem + git checkout ${ROCSHMEM_VERSION} + + cd ${tmp_dir}/rocm-systems/projects/rocshmem + mkdir build + cd build + INSTALL_PREFIX="${rocm_dir}" ../scripts/build_configs/all_backends + cd ${curr_dir} + + ) +} + +do_install diff --git a/.ci/docker/common/install_rocm.sh b/.ci/docker/common/install_rocm.sh index 8b673a23f9de5..b4f6d70af3676 100644 --- a/.ci/docker/common/install_rocm.sh +++ b/.ci/docker/common/install_rocm.sh @@ -26,87 +26,131 @@ install_ubuntu() { apt-get install -y libc++1 apt-get install -y libc++abi1 - # When ROCM_VERSION=nightly, install ROCm from TheRock nightly wheels + # When ROCM_VERSION=nightly, install ROCm from TheRock nightly tarballs + # Mirrors: https://github.com/ROCm/TheRock/blob/main/dockerfiles/install_rocm_tarball.sh if [[ "${ROCM_VERSION}" == "nightly" ]]; then - echo "install_rocm.sh: installing ROCm from TheRock nightly wheels" + apt-get install -y --no-install-recommends pkg-config - # Clean any previous ROCm installation in the base CI image. if [[ -d /opt/rocm ]]; then - echo "Removing existing /opt/rocm from base image" rm -rf /opt/rocm fi - # Determine theRock nightly URL based on GPU architecture - # Check BUILD_ENVIRONMENT or PYTORCH_ROCM_ARCH for the target GPU - if [[ -z "${THEROCK_NIGHTLY_INDEX_URL:-}" ]]; then + # Determine GPU family based on target architecture + AMDGPU_FAMILY="${THEROCK_AMDGPU_FAMILY:-}" + if [[ -z "${AMDGPU_FAMILY}" ]]; then if [[ "${BUILD_ENVIRONMENT}" == *"gfx950"* ]] || [[ "${PYTORCH_ROCM_ARCH}" == *"gfx950"* ]]; then - # MI350 (gfx950) - THEROCK_NIGHTLY_INDEX_URL="https://rocm.nightlies.amd.com/v2/gfx950-dcgpu/" - echo "Detected gfx950 architecture - using MI350 theRock nightly repository" + AMDGPU_FAMILY="gfx950-dcgpu" else - # Default to MI300 (gfx942/gfx94X) - THEROCK_NIGHTLY_INDEX_URL="https://rocm.nightlies.amd.com/v2/gfx94X-dcgpu/" - echo "Using gfx94X (MI300) theRock nightly repository" + AMDGPU_FAMILY="gfx94X-dcgpu" fi fi - export THEROCK_NIGHTLY_INDEX_URL - echo "TheRock Index URL: ${THEROCK_NIGHTLY_INDEX_URL}" + # Auto-detect latest nightly version if not pinned + VERSION="${THEROCK_VERSION:-}" + if [[ -z "${VERSION}" ]]; then + VERSION=$(curl -fsSL "https://rocm.nightlies.amd.com/tarball/" \ + | grep -oP "therock-dist-linux-${AMDGPU_FAMILY}-\K[^\"]+(?=\.tar\.gz)" \ + | grep -v ADHOCBUILD \ + | sort -V \ + | tail -1) + if [[ -z "${VERSION}" ]]; then + echo "Error: Could not find a nightly tarball for ${AMDGPU_FAMILY}" + exit 1 + fi + fi + + # URL-encode '+' as '%2B' in VERSION (required for devreleases) + VERSION_ENCODED="${VERSION//+/%2B}" + + TARBALL_URL="https://rocm.nightlies.amd.com/tarball/therock-dist-linux-${AMDGPU_FAMILY}-${VERSION_ENCODED}.tar.gz" + + echo "==============================================" + echo "ROCm Tarball Installation" + echo "==============================================" + echo "Version: ${VERSION}" + echo "AMDGPU Family: ${AMDGPU_FAMILY}" + echo "Tarball URL: ${TARBALL_URL}" + echo "==============================================" + + # Download tarball + TARBALL_FILE="/tmp/rocm-tarball.tar.gz" + + echo "Downloading tarball..." + curl -fsSL -o "$TARBALL_FILE" "$TARBALL_URL" || { + echo "Error: Failed to download tarball from $TARBALL_URL" + exit 1 + } + + # Verify download + if [ ! -f "$TARBALL_FILE" ] || [ ! -s "$TARBALL_FILE" ]; then + echo "Error: Downloaded file is empty or does not exist" + exit 1 + fi + + # Install directory is fixed to /opt/rocm-{VERSION} + ROCM_INSTALL_DIR="/opt/rocm-${VERSION}" - python3 -m pip install \ - --index-url "${THEROCK_NIGHTLY_INDEX_URL}" \ - "rocm[libraries,devel]" + # Extract tarball to versioned directory + echo "Extracting tarball to ${ROCM_INSTALL_DIR}..." + mkdir -p "$ROCM_INSTALL_DIR" + tar -xzf "$TARBALL_FILE" -C "$ROCM_INSTALL_DIR" - # Use the rocm-sdk CLI helper to populate environment defaults - ROCM_HOME="$(rocm-sdk path --root)" - ROCM_BIN="$(rocm-sdk path --bin)" - ROCM_CMAKE_PREFIX="$(rocm-sdk path --cmake)" + # Clean up downloaded file + rm -f "$TARBALL_FILE" + echo "Tarball extracted and cleaned up" - echo "ROCM_HOME=${ROCM_HOME}" - echo "ROCM_BIN=${ROCM_BIN}" - echo "ROCM_CMAKE_PREFIX=${ROCM_CMAKE_PREFIX}" + # Create symlink /opt/rocm -> /opt/rocm-{VERSION} for compatibility + ln -sfn "$ROCM_INSTALL_DIR" /opt/rocm + echo "Created symlink: /opt/rocm -> $ROCM_INSTALL_DIR" - export ROCM_HOME - export ROCM_PATH="${ROCM_HOME}" - export PATH="${ROCM_BIN}:${PATH}" - export CMAKE_PREFIX_PATH="${ROCM_CMAKE_PREFIX}:${CMAKE_PREFIX_PATH:-}" + # Verify bin and lib folder exists after extraction + echo "Verifying installation..." + for dir in bin clients include lib libexec share; do + if [ ! -d "$ROCM_INSTALL_DIR/$dir" ]; then + echo "Error: ROCm $dir directory not found" + exit 1 + fi + echo "ROCm $dir found in $ROCM_INSTALL_DIR/$dir" + done - # theRock bundles system dependencies like libdrm, liblzma in rocm_sysdeps - ROCM_SYSDEPS="${ROCM_HOME}/lib/rocm_sysdeps" - ROCM_SYSDEPS_INCLUDE="${ROCM_SYSDEPS}/include" - ROCM_SYSDEPS_PKGCONFIG="${ROCM_SYSDEPS}/lib/pkgconfig" + echo "==============================================" + echo "ROCm installed successfully to $ROCM_INSTALL_DIR" + echo "ROCM_PATH=$ROCM_INSTALL_DIR" + echo "PATH should include: $ROCM_INSTALL_DIR/bin" + echo "==============================================" - # Write environment to file that can be sourced by CI scripts and users + # Write environment file (sourced by CI scripts and interactive shells) cat > /etc/rocm_env.sh << ROCM_ENV # ROCm paths -export ROCM_PATH="${ROCM_HOME}" -export ROCM_HOME="${ROCM_HOME}" -export ROCM_SOURCE_DIR="${ROCM_HOME}" -export ROCM_BIN="${ROCM_BIN}" -export ROCM_CMAKE="${ROCM_CMAKE_PREFIX}" -export PATH="${ROCM_BIN}:\${PATH}" -export CMAKE_PREFIX_PATH="${ROCM_CMAKE_PREFIX}:\${CMAKE_PREFIX_PATH:-}" -# Device library paths -export HIP_DEVICE_LIB_PATH="${ROCM_HOME}/lib/llvm/amdgcn/bitcode" -export ROCM_DEVICE_LIB_PATH="${ROCM_HOME}/lib/llvm/amdgcn/bitcode" -# theRock system dependencies -export ROCM_SYSDEPS_INCLUDE="${ROCM_SYSDEPS_INCLUDE}" -export CPLUS_INCLUDE_PATH="${ROCM_SYSDEPS_INCLUDE}:\${CPLUS_INCLUDE_PATH:-}" -export C_INCLUDE_PATH="${ROCM_SYSDEPS_INCLUDE}:\${C_INCLUDE_PATH:-}" -export PKG_CONFIG_PATH="${ROCM_SYSDEPS_PKGCONFIG}:\${PKG_CONFIG_PATH:-}" -export LD_LIBRARY_PATH="${ROCM_SYSDEPS}/lib:\${LD_LIBRARY_PATH:-}" -export LIBRARY_PATH="${ROCM_SYSDEPS}/lib:\${LIBRARY_PATH:-}" -export MAGMA_HOME="${ROCM_HOME}/magma" +export ROCM_PATH=/opt/rocm +export ROCM_HOME=/opt/rocm +export ROCM_SOURCE_DIR=/opt/rocm +export ROCM_BIN=/opt/rocm/bin +export ROCM_CMAKE=/opt/rocm +export PATH=/opt/rocm/bin:/opt/rocm/llvm/bin:\${PATH} +export LD_LIBRARY_PATH=/opt/rocm/lib:\${LD_LIBRARY_PATH:-} +# Sysdeps include paths (libdrm headers, etc.) +export CPLUS_INCLUDE_PATH=/opt/rocm/lib/rocm_sysdeps/include:\${CPLUS_INCLUDE_PATH:-} +export C_INCLUDE_PATH=/opt/rocm/lib/rocm_sysdeps/include:\${C_INCLUDE_PATH:-} +# Device library path +export HIP_DEVICE_LIB_PATH=/opt/rocm/amdgcn/bitcode +export MAGMA_HOME=/opt/rocm/magma +# Tarball bundles sysdeps (libdrm, liblzma, etc.); expose their libs and .pc files +if [ -d /opt/rocm/lib/rocm_sysdeps/lib ]; then + export LD_LIBRARY_PATH=/opt/rocm/lib/rocm_sysdeps/lib:\${LD_LIBRARY_PATH} + export PKG_CONFIG_PATH=/opt/rocm/lib/rocm_sysdeps/lib/pkgconfig:\${PKG_CONFIG_PATH:-} +fi # Disable MSLK for theRock nightly (not yet supported) export USE_MSLK=0 ROCM_ENV - # Append to bash.bashrc so interactive shells get the env vars echo "source /etc/rocm_env.sh" >> /etc/bash.bashrc - echo "install_rocm.sh: TheRock nightly ROCm install complete" - exit 0 - fi + # --- End of theRock nightly tarball installation --- + else + # ========================================================================= + # Non-nightly: install ROCm from repo.radeon.com apt packages + # ========================================================================= # Make sure rocm packages from repo.radeon.com have highest priority cat << EOF > /etc/apt/preferences.d/rocm-pin-600 @@ -120,6 +164,11 @@ EOF ROCM_VERSION="${ROCM_VERSION}.2" fi + # we want the patch version of 7.2 instead + if [[ $(ver $ROCM_VERSION) -eq $(ver 7.2) ]]; then + ROCM_VERSION="${ROCM_VERSION}.1" + fi + # Default url values rocm_baseurl="http://repo.radeon.com/rocm/apt/${ROCM_VERSION}" UBUNTU_VERSION_NAME=`cat /etc/os-release | grep UBUNTU_CODENAME | awk -F= '{print $2}'` @@ -154,29 +203,6 @@ EOF fi fi - # ROCm 7.2 needs a fix from procprof sdk that isn't available until 7.2.1 - if [[ $(ver $ROCM_VERSION) -eq $(ver 7.2) ]]; then - git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-systems.git - pushd rocm-systems/ - git sparse-checkout init --cone - git sparse-checkout set projects/rocprofiler-sdk shared/rocprofiler-compute - git checkout develop - git checkout rocm-7.2.0 - git config --global user.email "you@example.com" - git config --global user.name "Your Name" - git cherry-pick a71cc3cc88ed68b24c40cefec77d764053044862 - sudo apt install -y cmake libdw-dev libsqlite3-dev - cmake \ - -B rocprofiler-sdk-build \ - -DCMAKE_INSTALL_PREFIX=/opt/rocm \ - -DCMAKE_PREFIX_PATH=/opt/rocm \ - -DGPU_TARGETS="${PYTORCH_ROCM_ARCH}" \ - projects/rocprofiler-sdk - cmake --build rocprofiler-sdk-build --target all --parallel $(nproc) - cmake --build rocprofiler-sdk-build --target install - popd - fi - # ROCm 6.0 had a regression where journal_mode was enabled on the kdb files resulting in permission errors at runtime for kdb in /opt/rocm/share/miopen/db/*.kdb do @@ -217,7 +243,7 @@ EOF pip_install "git+https://github.com/rocm/composable_kernel@$ROCM_COMPOSABLE_KERNEL_VERSION" - # Write environment to file that can be sourced by CI scripts and users + # Write environment file (sourced by CI scripts and interactive shells) cat > /etc/rocm_env.sh << ROCM_ENV # ROCm paths export ROCM_PATH=/opt/rocm @@ -226,85 +252,18 @@ export ROCM_SOURCE_DIR=/opt/rocm export ROCM_BIN=/opt/rocm/bin export ROCM_CMAKE=/opt/rocm export PATH=/opt/rocm/bin:/opt/rocm/llvm/bin:\${PATH} -# Device library paths -export ROCM_DEVICE_LIB_PATH=/opt/rocm/amdgcn/bitcode +export LD_LIBRARY_PATH=/opt/rocm/lib:\${LD_LIBRARY_PATH:-} +# Device library path export HIP_DEVICE_LIB_PATH=/opt/rocm/amdgcn/bitcode export MAGMA_HOME=/opt/rocm/magma ROCM_ENV - # Append to bash.bashrc so interactive shells get the env vars echo "source /etc/rocm_env.sh" >> /etc/bash.bashrc # Cleanup apt-get autoclean && apt-get clean rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -} - -install_centos() { - - yum update -y - yum install -y kmod - yum install -y wget - yum install -y openblas-devel - - yum install -y epel-release - yum install -y dkms kernel-headers-`uname -r` kernel-devel-`uname -r` - - # Add amdgpu repository - local amdgpu_baseurl - if [[ $OS_VERSION == 9 ]]; then - amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/rhel/9.0/main/x86_64" - else - amdgpu_baseurl="https://repo.radeon.com/amdgpu/${ROCM_VERSION}/rhel/7.9/main/x86_64" - fi - echo "[AMDGPU]" > /etc/yum.repos.d/amdgpu.repo - echo "name=AMDGPU" >> /etc/yum.repos.d/amdgpu.repo - echo "baseurl=${amdgpu_baseurl}" >> /etc/yum.repos.d/amdgpu.repo - echo "enabled=1" >> /etc/yum.repos.d/amdgpu.repo - echo "gpgcheck=1" >> /etc/yum.repos.d/amdgpu.repo - echo "gpgkey=http://repo.radeon.com/rocm/rocm.gpg.key" >> /etc/yum.repos.d/amdgpu.repo - - local rocm_baseurl="http://repo.radeon.com/rocm/yum/${ROCM_VERSION}" - echo "[ROCm]" > /etc/yum.repos.d/rocm.repo - echo "name=ROCm" >> /etc/yum.repos.d/rocm.repo - echo "baseurl=${rocm_baseurl}" >> /etc/yum.repos.d/rocm.repo - echo "enabled=1" >> /etc/yum.repos.d/rocm.repo - echo "gpgcheck=1" >> /etc/yum.repos.d/rocm.repo - echo "gpgkey=http://repo.radeon.com/rocm/rocm.gpg.key" >> /etc/yum.repos.d/rocm.repo - - yum update -y - - yum install -y \ - rocm-dev \ - rocm-utils \ - rocm-libs \ - rccl \ - rocprofiler-dev \ - roctracer-dev \ - amd-smi-lib - - # precompiled miopen kernels; search for all unversioned packages - # if search fails it will abort this script; use true to avoid case where search fails - MIOPENHIPGFX=$(yum -q search miopen-hip-gfx | grep miopen-hip-gfx | awk '{print $1}'| grep -F kdb. || true) - if [[ "x${MIOPENHIPGFX}" = x ]]; then - echo "miopen-hip-gfx package not available" && exit 1 - else - yum install -y ${MIOPENHIPGFX} - fi - - # ROCm 6.0 had a regression where journal_mode was enabled on the kdb files resulting in permission errors at runtime - for kdb in /opt/rocm/share/miopen/db/*.kdb - do - sqlite3 $kdb "PRAGMA journal_mode=off; PRAGMA VACUUM;" - done - - pip_install "git+https://github.com/rocm/composable_kernel@$ROCM_COMPOSABLE_KERNEL_VERSION" - - # Cleanup - yum clean all - rm -rf /var/cache/yum - rm -rf /var/lib/yum/yumdb - rm -rf /var/lib/yum/history + fi } # Install Python packages depending on the base OS @@ -313,9 +272,6 @@ case "$ID" in ubuntu) install_ubuntu ;; - centos) - install_centos - ;; *) echo "Unable to determine OS..." exit 1 diff --git a/.ci/docker/common/install_rocm_drm.sh b/.ci/docker/common/install_rocm_drm.sh index c70f5880f2c5c..a6b0fe2c03924 100644 --- a/.ci/docker/common/install_rocm_drm.sh +++ b/.ci/docker/common/install_rocm_drm.sh @@ -14,7 +14,7 @@ case "$ID" in apt-get install -y libpciaccess-dev pkg-config apt-get clean ;; - centos|almalinux) + almalinux) yum install -y libpciaccess-devel pkgconfig ;; *) diff --git a/.ci/docker/common/install_torch_tpu.sh b/.ci/docker/common/install_torch_tpu.sh index 93326e44a6161..c4e4104edfe42 100644 --- a/.ci/docker/common/install_torch_tpu.sh +++ b/.ci/docker/common/install_torch_tpu.sh @@ -60,7 +60,7 @@ fetch_secret() { set +x fi - if ! gcloud secrets versions access latest --secret="torchtpu-readonly-key" --project="ml-velocity-actions-testing" > "temp_ssh_key"; then + if ! gcloud secrets versions access latest --secret="torchtpu-read-key" --project="ml-velocity-actions-testing" > "temp_ssh_key"; then echo "Error: Failed to fetch secret. Ensure you are authenticated with gcloud." # Restore xtrace if it was enabled, before exiting @@ -82,7 +82,7 @@ clone_repo() { # Use GIT_SSH_COMMAND to specify the key and disable strict host key checking for automation export GIT_SSH_COMMAND="ssh -i temp_ssh_key -o IdentitiesOnly=yes -o StrictHostKeyChecking=no" - if git clone --recursive "git@github.com:google-ml-infra/torch_tpu.git"; then + if git clone --recursive "git@github.com:google-pytorch/torch_tpu.git"; then echo "Repository cloned successfully." else echo "Error: Failed to clone repository." @@ -110,7 +110,7 @@ pull_torch_tpu() { # sleep 28800 # Debug sleep to connect to runner to streamline debugging, do not submit # 3. Configuration -TORCH_TPU_REPO="${TORCH_TPU_REPO:-https://github.com/google-ml-infra/torch_tpu.git}" +TORCH_TPU_REPO="${TORCH_TPU_REPO:-https://github.com/google-pytorch/torch_tpu.git}" TORCH_TPU_BRANCH="${TORCH_TPU_BRANCH:-main}" # Pin File Configuration diff --git a/.ci/docker/common/install_vision.sh b/.ci/docker/common/install_vision.sh deleted file mode 100755 index 78c445568ddcd..0000000000000 --- a/.ci/docker/common/install_vision.sh +++ /dev/null @@ -1,46 +0,0 @@ -#!/bin/bash - -set -ex - -install_ubuntu() { - apt-get update - apt-get install -y --no-install-recommends \ - libopencv-dev - - # Cleanup - apt-get autoclean && apt-get clean - rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* -} - -install_centos() { - # Need EPEL for many packages we depend on. - # See http://fedoraproject.org/wiki/EPEL - yum --enablerepo=extras install -y epel-release - - yum install -y \ - opencv-devel - - # Cleanup - yum clean all - rm -rf /var/cache/yum - rm -rf /var/lib/yum/yumdb - rm -rf /var/lib/yum/history -} - -# Install base packages depending on the base OS -ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') -case "$ID" in - ubuntu) - install_ubuntu - ;; - centos) - install_centos - ;; - *) - echo "Unable to determine OS..." - exit 1 - ;; -esac - -# Cache vision models used by the test -source "$(dirname "${BASH_SOURCE[0]}")/cache_vision_models.sh" diff --git a/.ci/docker/libtorch/Dockerfile b/.ci/docker/libtorch/Dockerfile deleted file mode 100644 index 9c4be5abe459d..0000000000000 --- a/.ci/docker/libtorch/Dockerfile +++ /dev/null @@ -1,117 +0,0 @@ -ARG BASE_TARGET=base -ARG GPU_IMAGE=ubuntu:20.04 -FROM ${GPU_IMAGE} as base - -ENV DEBIAN_FRONTEND=noninteractive - -RUN apt-get clean && apt-get update -RUN apt-get install -y curl locales g++ git-all autoconf automake make cmake wget unzip sudo -# Just add everything as a safe.directory for git since these will be used in multiple places with git -RUN git config --global --add safe.directory '*' - -RUN locale-gen en_US.UTF-8 - -ENV LC_ALL en_US.UTF-8 -ENV LANG en_US.UTF-8 -ENV LANGUAGE en_US.UTF-8 - -# Install openssl -FROM base as openssl -ADD ./common/install_openssl.sh install_openssl.sh -RUN bash ./install_openssl.sh && rm install_openssl.sh - -# Install python -FROM base as python -ADD common/install_cpython.sh install_cpython.sh -RUN apt-get update -y && \ - apt-get install build-essential gdb lcov libbz2-dev libffi-dev \ - libgdbm-dev liblzma-dev libncurses5-dev libreadline6-dev \ - libsqlite3-dev libssl-dev lzma lzma-dev tk-dev uuid-dev zlib1g-dev -y && \ - bash ./install_cpython.sh && \ - rm install_cpython.sh && \ - apt-get clean - -FROM base as conda -ADD ./common/install_conda_docker.sh install_conda.sh -RUN bash ./install_conda.sh && rm install_conda.sh - -FROM base as cpu -# Install Anaconda -COPY --from=conda /opt/conda /opt/conda -# Install python -COPY --from=python /opt/python /opt/python -COPY --from=python /opt/_internal /opt/_internal -ENV PATH=/opt/conda/bin:/usr/local/cuda/bin:$PATH -# Install MKL -ADD ./common/install_mkl.sh install_mkl.sh -RUN bash ./install_mkl.sh && rm install_mkl.sh - -FROM cpu as cuda -ADD ./common/install_cuda.sh install_cuda.sh -ADD ./common/install_magma.sh install_magma.sh -COPY ./common/install_nccl.sh install_nccl.sh -COPY ./ci_commit_pins/nccl* /ci_commit_pins/ -COPY ./common/install_cusparselt.sh install_cusparselt.sh -ENV CUDA_HOME /usr/local/cuda - -FROM cuda as cuda12.6 -RUN bash ./install_cuda.sh 12.6 -RUN bash ./install_magma.sh 12.6 -RUN ln -sf /usr/local/cuda-12.6 /usr/local/cuda - -FROM cuda as cuda12.8 -RUN bash ./install_cuda.sh 12.8 -RUN bash ./install_magma.sh 12.8 -RUN ln -sf /usr/local/cuda-12.8 /usr/local/cuda - -FROM cuda as cuda12.9 -RUN bash ./install_cuda.sh 12.9 -RUN bash ./install_magma.sh 12.9 -RUN ln -sf /usr/local/cuda-12.9 /usr/local/cuda - -FROM cuda as cuda13.0 -RUN bash ./install_cuda.sh 13.0 -RUN bash ./install_magma.sh 13.0 -RUN ln -sf /usr/local/cuda-13.0 /usr/local/cuda - -# Install libibverbs for libtorch and copy to CUDA directory -RUN apt-get update -y && \ - apt-get install -y libibverbs-dev librdmacm-dev && \ - cp /usr/lib/x86_64-linux-gnu/libmlx5.so* /usr/local/cuda/lib64/ && \ - cp /usr/lib/x86_64-linux-gnu/librdmacm.so* /usr/local/cuda/lib64/ && \ - cp /usr/lib/x86_64-linux-gnu/libibverbs.so* /usr/local/cuda/lib64/ && \ - cp /usr/lib/x86_64-linux-gnu/libnl* /usr/local/cuda/lib64/ - -FROM cpu as rocm -ARG ROCM_VERSION -ARG PYTORCH_ROCM_ARCH -ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} -ENV MKLROOT /opt/intel -# Adding ROCM_PATH env var so that LoadHip.cmake (even with logic updated for ROCm6.0) -# find HIP works for ROCm5.7. Not needed for ROCm6.0 and above. -# Remove below when ROCm5.7 is not in support matrix anymore. -ENV ROCM_PATH /opt/rocm -# No need to install ROCm as base docker image should have full ROCm install -#ADD ./common/install_rocm.sh install_rocm.sh -ADD ./common/install_rocm_drm.sh install_rocm_drm.sh -ADD ./common/install_rocm_magma.sh install_rocm_magma.sh -# gfortran and python needed for building magma from source for ROCm -RUN apt-get update -y && \ - apt-get install gfortran -y && \ - apt-get install python3 python-is-python3 -y && \ - apt-get clean - -RUN bash ./install_rocm_drm.sh /opt/amdgpu && rm install_rocm_drm.sh -RUN bash ./install_rocm_magma.sh ${ROCM_VERSION} && rm install_rocm_magma.sh - -FROM ${BASE_TARGET} as final -COPY --from=openssl /opt/openssl /opt/openssl -# Install patchelf -ADD ./common/install_patchelf.sh install_patchelf.sh -RUN bash ./install_patchelf.sh && rm install_patchelf.sh -# Install Anaconda -COPY --from=conda /opt/conda /opt/conda -# Install python -COPY --from=python /opt/python /opt/python -COPY --from=python /opt/_internal /opt/_internal -ENV PATH=/opt/conda/bin:/usr/local/cuda/bin:$PATH diff --git a/.ci/docker/libtorch/build.sh b/.ci/docker/libtorch/build.sh deleted file mode 100755 index 5bfe70f34347e..0000000000000 --- a/.ci/docker/libtorch/build.sh +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env bash -# Script used only in CD pipeline - -set -eoux pipefail - -image="$1" -shift - -if [ -z "${image}" ]; then - echo "Usage: $0 IMAGENAME:ARCHTAG" - exit 1 -fi - -TOPDIR=$(git rev-parse --show-toplevel) - -DOCKER=${DOCKER:-docker} - -# Go from imagename:tag to tag -DOCKER_TAG_PREFIX=$(echo "${image}" | awk -F':' '{print $2}') - -GPU_ARCH_VERSION="" -if [[ "${DOCKER_TAG_PREFIX}" == cuda* ]]; then - # extract cuda version from image name. e.g. manylinux2_28-builder:cuda12.8 returns 12.8 - GPU_ARCH_VERSION=$(echo "${DOCKER_TAG_PREFIX}" | awk -F'cuda' '{print $2}') -elif [[ "${DOCKER_TAG_PREFIX}" == rocm* ]]; then - # extract rocm version from image name. e.g. manylinux2_28-builder:rocm6.2.4 returns 6.2.4 - GPU_ARCH_VERSION=$(echo "${DOCKER_TAG_PREFIX}" | awk -F'rocm' '{print $2}') -fi - -case ${DOCKER_TAG_PREFIX} in - cpu) - BASE_TARGET=cpu - GPU_IMAGE=ubuntu:20.04 - DOCKER_GPU_BUILD_ARG="" - ;; - cuda*) - BASE_TARGET=cuda${GPU_ARCH_VERSION} - GPU_IMAGE=ubuntu:20.04 - DOCKER_GPU_BUILD_ARG="" - ;; - rocm*) - # we want the patch version of 7.1 instead - if [[ "$GPU_ARCH_VERSION" == *"7.1"* ]]; then - GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.1" - fi - # we want the patch version of 7.0 instead - if [[ "$GPU_ARCH_VERSION" == *"7.0"* ]]; then - GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.2" - fi - # we want the patch version of 6.4 instead - if [[ "$GPU_ARCH_VERSION" == *"6.4"* ]]; then - GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.4" - fi - BASE_TARGET=rocm - GPU_IMAGE=rocm/dev-ubuntu-22.04:${GPU_ARCH_VERSION}-complete - PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151" - DOCKER_GPU_BUILD_ARG="--build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg ROCM_VERSION=${GPU_ARCH_VERSION}" - ;; - *) - echo "ERROR: Unrecognized DOCKER_TAG_PREFIX: ${DOCKER_TAG_PREFIX}" - exit 1 - ;; -esac - -tmp_tag=$(basename "$(mktemp -u)" | tr '[:upper:]' '[:lower:]') - -DOCKER_BUILDKIT=1 ${DOCKER} build \ - --target final \ - ${DOCKER_GPU_BUILD_ARG} \ - --build-arg "GPU_IMAGE=${GPU_IMAGE}" \ - --build-arg "BASE_TARGET=${BASE_TARGET}" \ - -t "${tmp_tag}" \ - $@ \ - -f "${TOPDIR}/.ci/docker/libtorch/Dockerfile" \ - "${TOPDIR}/.ci/docker/" diff --git a/.ci/docker/manywheel/Dockerfile_2_28 b/.ci/docker/manywheel/Dockerfile_2_28 index 4055e6b872539..2ad5cd6498249 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28 +++ b/.ci/docker/manywheel/Dockerfile_2_28 @@ -12,20 +12,12 @@ RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH -# cmake-3.18.4 from pip +# cmake-3.18.4.post1 from pip +# NS: Apr 1 2026 3.18.4 is gone, reported here https://github.com/scikit-build/cmake-python-distributions/issues/693 RUN yum install -y python3-pip && \ - python3 -mpip install cmake==3.18.4 && \ + python3 -mpip install cmake==3.18.4.post1 && \ ln -s /usr/local/bin/cmake /usr/bin/cmake3 -FROM base as openssl -# Install openssl (this must precede `build python` step) -# (In order to have a proper SSL module, Python is compiled -# against a recent openssl [see env vars above], which is linked -# statically. We delete openssl afterwards.) -ADD ./common/install_openssl.sh install_openssl.sh -RUN bash ./install_openssl.sh && rm install_openssl.sh - - FROM base as cuda ARG BASE_CUDA_VERSION=12.6 # Install CUDA @@ -95,7 +87,6 @@ RUN git config --global --add safe.directory "*" ENV SSL_CERT_FILE=/opt/_internal/certs.pem # Install LLVM version -COPY --from=openssl /opt/openssl /opt/openssl COPY --from=base /opt/python /opt/python COPY --from=base /usr/local/lib/ /usr/local/lib/ COPY --from=base /opt/_internal /opt/_internal @@ -127,9 +118,9 @@ RUN for cpython_version in "cp312-cp312" "cp313-cp313" "cp313-cp313t"; do \ ADD ./common/patch_libstdc.sh patch_libstdc.sh RUN bash ./patch_libstdc.sh && rm patch_libstdc.sh -# cmake-3.18.4 from pip; force in case cmake3 already exists +# cmake-3.18.4.post1 from pip; force in case cmake3 already exists RUN yum install -y python3-pip && \ - python3 -mpip install cmake==3.18.4 && \ + python3 -mpip install cmake==3.18.4.post1 && \ ln -sf /usr/local/bin/cmake /usr/bin/cmake3 FROM cpu_final as cuda_final @@ -144,7 +135,8 @@ ARG ROCM_VERSION=6.0 ARG PYTORCH_ROCM_ARCH ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} ARG DEVTOOLSET_VERSION=13 -ENV LDFLAGS="-Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64 -Wl,-rpath=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib" +# All rocm clang cfg files load the same rocm.cfg, make sure it points to the right toolchain. +RUN echo "--gcc-toolchain=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr" >> /opt/rocm/llvm/bin/rocm.cfg # Somewhere in ROCm stack, we still use non-existing /opt/rocm/hip path, # below workaround helps avoid error ENV ROCM_PATH /opt/rocm @@ -160,6 +152,10 @@ RUN yum install -y libdrm-devel ENV MKLROOT /opt/intel ADD ./common/install_rocm_magma.sh install_rocm_magma.sh RUN bash ./install_rocm_magma.sh ${ROCM_VERSION} && rm install_rocm_magma.sh + +ADD ./common/install_rocSHMEM.sh install_rocSHMEM.sh +RUN bash ./install_rocSHMEM.sh ${ROCM_VERSION} && rm install_rocSHMEM.sh + ADD ./common/install_miopen.sh install_miopen.sh RUN bash ./install_miopen.sh ${ROCM_VERSION} && rm install_miopen.sh diff --git a/.ci/docker/manywheel/Dockerfile_2_28_aarch64 b/.ci/docker/manywheel/Dockerfile_2_28_aarch64 index b5bf2ffc1c081..477e4221cb49f 100644 --- a/.ci/docker/manywheel/Dockerfile_2_28_aarch64 +++ b/.ci/docker/manywheel/Dockerfile_2_28_aarch64 @@ -39,12 +39,7 @@ RUN yum install -y \ gcc-toolset-${GCCTOOLSET_VERSION}-gcc-c++ \ gcc-toolset-${GCCTOOLSET_VERSION}-gcc-gfortran \ gcc-toolset-${GCCTOOLSET_VERSION}-gdb - -# (optional) Install non-default Ninja version -ARG NINJA_VERSION -COPY ./common/install_ninja.sh install_ninja.sh -RUN if [ -n "${NINJA_VERSION}" ]; then bash ./install_ninja.sh; fi -RUN rm install_ninja.sh +RUN yum install -y --enablerepo=powertools ninja-build # Ensure the expected devtoolset is used ENV PATH=/opt/rh/gcc-toolset-${GCCTOOLSET_VERSION}/root/usr/bin:$PATH diff --git a/.ci/docker/manywheel/Dockerfile_cuda_aarch64 b/.ci/docker/manywheel/Dockerfile_cuda_aarch64 index 794a791b2721a..2a3f266fc413f 100644 --- a/.ci/docker/manywheel/Dockerfile_cuda_aarch64 +++ b/.ci/docker/manywheel/Dockerfile_cuda_aarch64 @@ -50,16 +50,7 @@ ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/op RUN git config --global --add safe.directory "*" -FROM base as openssl -# Install openssl (this must precede `build python` step) -# (In order to have a proper SSL module, Python is compiled -# against a recent openssl [see env vars above], which is linked -# statically. We delete openssl afterwards.) -ADD ./common/install_openssl.sh install_openssl.sh -RUN bash ./install_openssl.sh && rm install_openssl.sh -ENV SSL_CERT_FILE=/opt/_internal/certs.pem - -FROM openssl as final +FROM base as final FROM base as cuda ARG BASE_CUDA_VERSION diff --git a/.ci/docker/manywheel/Dockerfile_s390x b/.ci/docker/manywheel/Dockerfile_s390x index 1cf83acb1c736..1367b004ee8a3 100644 --- a/.ci/docker/manywheel/Dockerfile_s390x +++ b/.ci/docker/manywheel/Dockerfile_s390x @@ -84,7 +84,7 @@ RUN cp $(which patchelf) /patchelf FROM patchelf as python # build python -COPY manywheel/build_scripts /build_scripts +COPY manywheel/s390_scripts /build_scripts ADD ./common/install_cpython.sh /build_scripts/install_cpython.sh ENV SSL_CERT_FILE= RUN bash build_scripts/build.sh && rm -r build_scripts diff --git a/.ci/docker/manywheel/build.sh b/.ci/docker/manywheel/build.sh index 60a7216ce5f7b..0129853c2dd18 100755 --- a/.ci/docker/manywheel/build.sh +++ b/.ci/docker/manywheel/build.sh @@ -40,7 +40,7 @@ case ${image} in manylinux2_28_aarch64-builder:cpu-aarch64) TARGET=final GPU_IMAGE=arm64v8/almalinux:8 - DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=13 --build-arg NINJA_VERSION=1.12.1" + DOCKER_GPU_BUILD_ARG=" --build-arg DEVTOOLSET_VERSION=13" MANY_LINUX_VERSION="2_28_aarch64" ;; manylinuxs390x-builder:cpu-s390x) @@ -75,6 +75,10 @@ case ${image} in DOCKERFILE_SUFFIX="_cuda_aarch64" ;; manylinux2_28-builder:rocm*) + # we want the patch version of 7.2 instead + if [[ "$GPU_ARCH_VERSION" == *"7.2"* ]]; then + GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.1" + fi # we want the patch version of 7.1 instead if [[ "$GPU_ARCH_VERSION" == *"7.1"* ]]; then GPU_ARCH_VERSION="${GPU_ARCH_VERSION}.1" @@ -91,7 +95,7 @@ case ${image} in MANY_LINUX_VERSION="2_28" DEVTOOLSET_VERSION="13" GPU_IMAGE=rocm/dev-almalinux-8:${GPU_ARCH_VERSION}-complete - PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950;gfx1150;gfx1151" + PYTORCH_ROCM_ARCH="gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1103;gfx1200;gfx1201;gfx950;gfx1150;gfx1151" DOCKER_GPU_BUILD_ARG="--build-arg ROCM_VERSION=${GPU_ARCH_VERSION} --build-arg PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} --build-arg DEVTOOLSET_VERSION=${DEVTOOLSET_VERSION}" ;; manylinux2_28-builder:xpu) diff --git a/.ci/docker/manywheel/build_scripts/manylinux1-check.py b/.ci/docker/manywheel/build_scripts/manylinux1-check.py deleted file mode 100644 index f6b9b9fc2393e..0000000000000 --- a/.ci/docker/manywheel/build_scripts/manylinux1-check.py +++ /dev/null @@ -1,63 +0,0 @@ -# Logic copied from PEP 513 - - -def is_manylinux1_compatible(): - # Only Linux, and only x86-64 / i686 - from distutils.util import get_platform - - if get_platform() not in ["linux-x86_64", "linux-i686", "linux-s390x"]: - return False - - # Check for presence of _manylinux module - try: - import _manylinux - - return bool(_manylinux.manylinux1_compatible) - except (ImportError, AttributeError): - # Fall through to heuristic check below - pass - - # Check glibc version. CentOS 5 uses glibc 2.5. - return have_compatible_glibc(2, 5) - - -def have_compatible_glibc(major, minimum_minor): - import ctypes - - process_namespace = ctypes.CDLL(None) - try: - gnu_get_libc_version = process_namespace.gnu_get_libc_version - except AttributeError: - # Symbol doesn't exist -> therefore, we are not linked to - # glibc. - return False - - # Call gnu_get_libc_version, which returns a string like "2.5". - gnu_get_libc_version.restype = ctypes.c_char_p - version_str = gnu_get_libc_version() - # py2 / py3 compatibility: - if not isinstance(version_str, str): - version_str = version_str.decode("ascii") - - # Parse string and check against requested version. - version = [int(piece) for piece in version_str.split(".")] - if len(version) != 2: - raise AssertionError( - f"Expected version to have 2 components (major.minor), got {len(version)}: {version_str}" - ) - if major != version[0]: - return False - if minimum_minor > version[1]: - return False - return True - - -import sys - - -if is_manylinux1_compatible(): - print(f"{sys.executable} is manylinux1 compatible") - sys.exit(0) -else: - print(f"{sys.executable} is NOT manylinux1 compatible") - sys.exit(1) diff --git a/.ci/docker/manywheel/build_scripts/ssl-check.py b/.ci/docker/manywheel/build_scripts/ssl-check.py deleted file mode 100644 index c4df0eacbb7fd..0000000000000 --- a/.ci/docker/manywheel/build_scripts/ssl-check.py +++ /dev/null @@ -1,26 +0,0 @@ -# cf. https://github.com/pypa/manylinux/issues/53 - -import sys -from urllib.request import urlopen - - -GOOD_SSL = "https://google.com" -BAD_SSL = "https://self-signed.badssl.com" - - -print("Testing SSL certificate checking for Python:", sys.version) - -EXC = OSError - -print(f"Connecting to {GOOD_SSL} should work") -urlopen(GOOD_SSL) -print("...it did, yay.") - -print(f"Connecting to {BAD_SSL} should fail") -try: - urlopen(BAD_SSL) - # If we get here then we failed: - print("...it DIDN'T!!!!!11!!1one!") - sys.exit(1) -except EXC: - print("...it did, yay.") diff --git a/.ci/docker/manywheel/build_scripts/build.sh b/.ci/docker/manywheel/s390_scripts/build.sh similarity index 90% rename from .ci/docker/manywheel/build_scripts/build.sh rename to .ci/docker/manywheel/s390_scripts/build.sh index b6a70f0a72787..13141dfd4ae33 100644 --- a/.ci/docker/manywheel/build_scripts/build.sh +++ b/.ci/docker/manywheel/s390_scripts/build.sh @@ -18,13 +18,7 @@ AUTOCONF_HASH=954bd69b391edc12d6a4a51a2dd1476543da5c6bbf05a95b59dc0dd6fd4c2969 # Dependencies for compiling Python that we want to remove from # the final image after compiling Python -PYTHON_COMPILE_DEPS="zlib-devel bzip2-devel ncurses-devel sqlite-devel readline-devel tk-devel gdbm-devel libpcap-devel xz-devel libffi-devel" - -if [ "$(uname -m)" != "s390x" ] ; then - PYTHON_COMPILE_DEPS="${PYTHON_COMPILE_DEPS} db4-devel" -else - PYTHON_COMPILE_DEPS="${PYTHON_COMPILE_DEPS} libdb-devel" -fi +PYTHON_COMPILE_DEPS="zlib-devel bzip2-devel ncurses-devel sqlite-devel readline-devel tk-devel gdbm-devel libpcap-devel xz-devel libffi-devel libdb-devel" # Libraries that are allowed as part of the manylinux1 profile MANYLINUX1_DEPS="glibc-devel libstdc++-devel glib2-devel libX11-devel libXext-devel libXrender-devel mesa-libGL-devel libICE-devel libSM-devel ncurses-devel" @@ -103,13 +97,6 @@ find /opt/_internal \ -o \( -type f -a -name '*.pyc' -o -name '*.pyo' \) \ -print0 | xargs -0 rm -f -for PYTHON in /opt/python/*/bin/python; do - # Smoke test to make sure that our Pythons work, and do indeed detect as - # being manylinux compatible: - $PYTHON $MY_DIR/manylinux1-check.py - # Make sure that SSL cert checking works - $PYTHON $MY_DIR/ssl-check.py -done # Fix libc headers to remain compatible with C99 compilers. find /usr/include/ -type f -exec sed -i 's/\bextern _*inline_*\b/extern __inline __attribute__ ((__gnu_inline__))/g' {} + diff --git a/.ci/docker/manywheel/build_scripts/build_utils.sh b/.ci/docker/manywheel/s390_scripts/build_utils.sh similarity index 100% rename from .ci/docker/manywheel/build_scripts/build_utils.sh rename to .ci/docker/manywheel/s390_scripts/build_utils.sh diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index cf79a13b4e444..61052dd73073d 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -75,16 +75,16 @@ librosa==0.10.2 ; python_version == "3.12" and platform_machine != "s390x" #test that import: test_spectral_ops.py #librosa depends on numba; disable it for s390x while numba is disabled too -#mkl #this breaks linux-bionic-rocm4.5-py3.7 +# Only mkl-static and mkl-include are needed; the mkl package contains +# dynamic libraries that are not discoverable by our build scripts. +mkl-static==2024.2.0 ; platform_machine != "aarch64" and sys_platform != "darwin" +mkl-include==2024.2.0 ; platform_machine != "aarch64" and sys_platform != "darwin" #Description: Intel oneAPI Math Kernel Library -#Pinned versions: +#Pinned versions: 2024.2.0 #test that import: test_profiler.py, test_public_bindings.py, test_testing.py, #test_nn.py, test_mkldnn.py, test_jit.py, test_fx_experimental.py, #test_autograd.py -#mkl-devel -# see mkl - #mock #Description: A testing library that allows you to replace parts of your #system under test with mock objects @@ -170,7 +170,7 @@ optree==0.17.0 ; python_version >= "3.14" #test_pointwise_ops.py, test_dtensor_ops.py, test_torchinductor.py, test_fx.py, #test_fake_tensor.py, test_mps.py -pillow==12.1.1 +pillow==12.2.0 #Description: Python Imaging Library fork #Pinned versions: 11.0.0 #test that import: @@ -225,7 +225,7 @@ xdoctest==1.3.0 #Pinned versions: 1.1.0 #test that import: -pygments==2.15.0 +pygments==2.20.0 #Description: support doctest highlighting #Pinned versions: 2.12.0 #test that import: the doctests @@ -292,6 +292,11 @@ lintrunner==0.12.11 #Pinned versions: 0.12.11 #test that import: +spin==0.17 +#Description: developer CLI for common build/lint tasks +#Pinned versions: 0.17 +#test that import: + redis>=4.0.0 #Description: redis database #test that import: anything that tests OSS caching/mocking (inductor/test_codecache.py, inductor/test_max_autotune.py) @@ -339,7 +344,7 @@ sympy==1.13.3 #Pinned versions: #test that import: -onnx==1.20.0 +onnx==1.21.0 #Description: Required by the torch.onnx exporter #Pinned versions: #test that import: @@ -369,6 +374,7 @@ pwlf==2.2.1 #test that import: test_sac_estimator.py # To build PyTorch itself +pip==26.0.1 pyyaml==6.0.3 pyzstd setuptools==78.1.1 @@ -396,7 +402,7 @@ tlparse==0.4.0 filelock==3.20.3 #Description: required for inductor testing -cuda-bindings>=12.0,<13.0 ; platform_machine != "s390x" and platform_system != "Darwin" +cuda-bindings>=12.0,<13.0 ; platform_machine != "s390x" and platform_machine != "riscv64" and platform_system != "Darwin" #Description: required for testing CUDAGraph::raw_cuda_graph(). See https://nvidia.github.io/cuda-python/cuda-bindings/latest/support.html for how this version was chosen. Note "Any fix in the latest bindings would be backported to the prior major version" means that only the newest version of cuda-bindings will get fixes. Depending on the latest version of 12.x is okay because all 12.y versions will be supported via "CUDA minor version compatibility". Pytorch builds against 13.z versions of cuda toolkit work with 12.x versions of cuda-bindings as well because newer drivers work with old toolkits. #test that import: test_cuda.py @@ -410,5 +416,5 @@ tqdm>=4.66.0 #Description: progress bar library required for dynamo benchmarks #test that import: benchmarks/dynamo/* -aiohttp==3.13.3 +aiohttp==3.13.4 #Description: required for torch.distributed.debug diff --git a/.ci/docker/requirements-docs.txt b/.ci/docker/requirements-docs.txt index 7f3e0b5cc9215..484d99ec1152e 100644 --- a/.ci/docker/requirements-docs.txt +++ b/.ci/docker/requirements-docs.txt @@ -2,17 +2,14 @@ sphinx==7.2.6 #Description: This is used to generate PyTorch docs #Pinned versions: 7.2.6 -pytorch_sphinx_theme2==0.4.3 +pytorch_sphinx_theme2==0.4.9 #Description: This is needed to generate PyTorch docs -#Pinned versions: 0.4.3 +#Pinned versions: 0.4.9 -# TODO: sphinxcontrib.katex 0.9.0 adds a local KaTeX server to speed up pre-rendering -# but it doesn't seem to work and hangs around idly. The initial thought that it is probably -# something related to Docker setup. We can investigate this later. - -sphinxcontrib.katex==0.8.6 +sphinxcontrib.katex==0.9.11 #Description: This is used to generate PyTorch docs -#Pinned versions: 0.8.6 +#Pinned versions: 0.9.11 (0.9.0+ uses a persistent KaTeX server instead of +# spawning a subprocess per math expression, ~20% faster writes) sphinxext-opengraph==0.9.1 #Description: This is used to generate PyTorch docs @@ -48,6 +45,10 @@ docutils==0.20 #Description: This is used to generate PyTorch C++ docs #Pinned versions: 0.20 +coverxygen==1.8.1 +#Description: This is used to measure C++ API doc coverage from Doxygen XML +#Pinned versions: 1.8.1 + bs4==0.0.1 #Description: This is used to generate PyTorch C++ docs #Pinned versions: 0.0.1 diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 40c341bdcdbe8..7c69a55dbb185 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.6.0 +3.7.0 diff --git a/.ci/docker/triton_xpu_version.txt b/.ci/docker/triton_xpu_version.txt index 7c69a55dbb185..a76ccff2a6e0d 100644 --- a/.ci/docker/triton_xpu_version.txt +++ b/.ci/docker/triton_xpu_version.txt @@ -1 +1 @@ -3.7.0 +3.7.1 diff --git a/.ci/docker/ubuntu-rocm/Dockerfile b/.ci/docker/ubuntu-rocm/Dockerfile index ff62e4a934c74..a3d697f1b27f2 100644 --- a/.ci/docker/ubuntu-rocm/Dockerfile +++ b/.ci/docker/ubuntu-rocm/Dockerfile @@ -13,6 +13,7 @@ ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} # Install common dependencies (so that this step can be cached separately) COPY ./common/install_base.sh install_base.sh RUN bash ./install_base.sh && rm install_base.sh +RUN apt-get update && apt-get install -y --no-install-recommends libtbb-dev && rm -rf /var/lib/apt/lists/* # Install user COPY ./common/install_user.sh install_user.sh @@ -43,13 +44,6 @@ ARG CLANG_VERSION COPY ./common/install_clang.sh install_clang.sh RUN bash ./install_clang.sh && rm install_clang.sh -# (optional) Install vision packages like OpenCV -ARG VISION -COPY ./common/install_vision.sh ./common/cache_vision_models.sh ./common/common_utils.sh ./ -RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi -RUN rm install_vision.sh cache_vision_models.sh common_utils.sh -ENV INSTALLED_VISION ${VISION} - # Install rocm ARG ROCM_VERSION ENV ROCM_VERSION=${ROCM_VERSION} @@ -63,14 +57,27 @@ RUN rm -r ci_commit_pins COPY ./common/install_rocm_magma.sh install_rocm_magma.sh RUN if [ "${ROCM_VERSION}" != "nightly" ]; then bash ./install_rocm_magma.sh ${ROCM_VERSION}; fi RUN rm install_rocm_magma.sh +COPY ./common/install_rocSHMEM.sh install_rocSHMEM.sh +RUN bash ./install_rocSHMEM.sh ${ROCM_VERSION} +RUN rm install_rocSHMEM.sh ADD ./common/install_miopen.sh install_miopen.sh RUN if [ "${ROCM_VERSION}" != "nightly" ]; then bash ./install_miopen.sh ${ROCM_VERSION}; fi && rm install_miopen.sh ADD ./common/install_rocm_drm.sh install_rocm_drm.sh RUN if [ "${ROCM_VERSION}" != "nightly" ]; then bash ./install_rocm_drm.sh /usr ; fi && rm install_rocm_drm.sh -# ROCm environment variables are set in /etc/rocm_env.sh by install_rocm.sh -# and sourced via /etc/bash.bashrc for interactive shells. -# CI scripts should source /etc/rocm_env.sh directly. +# Default ROCm environment; /etc/rocm_env.sh (created by install_rocm.sh) may +# override these at runtime for different install methods (tarballs vs wheels). +ENV ROCM_PATH=/opt/rocm \ + ROCM_HOME=/opt/rocm \ + ROCM_SOURCE_DIR=/opt/rocm \ + ROCM_BIN=/opt/rocm/bin \ + ROCM_CMAKE=/opt/rocm \ + ROCM_DEVICE_LIB_PATH=/opt/rocm/amdgcn/bitcode \ + HIP_DEVICE_LIB_PATH=/opt/rocm/amdgcn/bitcode \ + MAGMA_HOME=/opt/rocm/magma +ENV PATH=/opt/rocm/bin:/opt/rocm/llvm/bin:$PATH +ENV LD_LIBRARY_PATH=/opt/rocm/lib:${LD_LIBRARY_PATH:-} + ENV LANG C.UTF-8 ENV LC_ALL C.UTF-8 @@ -79,11 +86,6 @@ COPY ./common/install_amdsmi.sh install_amdsmi.sh RUN bash ./install_amdsmi.sh RUN rm install_amdsmi.sh -COPY ./common/install_openssl.sh install_openssl.sh -ENV OPENSSL_ROOT_DIR /opt/openssl -RUN bash ./install_openssl.sh -ENV OPENSSL_DIR /opt/openssl - ARG INDUCTOR_BENCHMARKS ARG ANACONDA_PYTHON_VERSION ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION @@ -95,12 +97,6 @@ COPY ci_commit_pins/torchbench.txt torchbench.txt RUN if [ -n "${INDUCTOR_BENCHMARKS}" ]; then bash ./install_inductor_benchmark_deps.sh; fi RUN rm install_inductor_benchmark_deps.sh common_utils.sh timm.txt huggingface-requirements.txt torchbench.txt -# (optional) Install non-default Ninja version -ARG NINJA_VERSION -COPY ./common/install_ninja.sh install_ninja.sh -RUN if [ -n "${NINJA_VERSION}" ]; then bash ./install_ninja.sh; fi -RUN rm install_ninja.sh - ARG TRITON # Install triton, this needs to be done before sccache because the latter will # try to reach out to S3, which docker build runners don't have access @@ -126,8 +122,5 @@ RUN rm install_openmpi.sh ARG BUILD_ENVIRONMENT ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT} -# Install LLVM dev version (Defined in the pytorch/builder github repository) -COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm - USER jenkins CMD ["bash"] diff --git a/.ci/docker/ubuntu-xpu/Dockerfile b/.ci/docker/ubuntu-xpu/Dockerfile index c61612882032d..3aba712da78d7 100644 --- a/.ci/docker/ubuntu-xpu/Dockerfile +++ b/.ci/docker/ubuntu-xpu/Dockerfile @@ -47,12 +47,6 @@ RUN bash ./install_gcc.sh && rm install_gcc.sh COPY ./common/install_lcov.sh install_lcov.sh RUN bash ./install_lcov.sh && rm install_lcov.sh -COPY ./common/install_openssl.sh install_openssl.sh -RUN bash ./install_openssl.sh -ENV OPENSSL_ROOT_DIR /opt/openssl -ENV OPENSSL_DIR /opt/openssl -RUN rm install_openssl.sh - ARG INDUCTOR_BENCHMARKS ARG ANACONDA_PYTHON_VERSION ENV ANACONDA_PYTHON_VERSION=$ANACONDA_PYTHON_VERSION @@ -80,19 +74,6 @@ COPY triton_xpu_version.txt triton_version.txt RUN if [ -n "${TRITON}" ]; then bash ./install_triton.sh; fi RUN rm install_triton.sh common_utils.sh triton-xpu.txt triton_version.txt -# (optional) Install vision packages like OpenCV -ARG VISION -COPY ./common/install_vision.sh ./common/cache_vision_models.sh ./common/common_utils.sh ./ -RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi -RUN rm install_vision.sh cache_vision_models.sh common_utils.sh -ENV INSTALLED_VISION ${VISION} - -# (optional) Install non-default Ninja version -ARG NINJA_VERSION -COPY ./common/install_ninja.sh install_ninja.sh -RUN if [ -n "${NINJA_VERSION}" ]; then bash ./install_ninja.sh; fi -RUN rm install_ninja.sh - # Install ccache/sccache (do this last, so we get priority in PATH) COPY ./common/install_cache.sh install_cache.sh ENV PATH /opt/cache/bin:$PATH @@ -102,8 +83,5 @@ RUN bash ./install_cache.sh && rm install_cache.sh ARG BUILD_ENVIRONMENT ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT} -# Install LLVM dev version (Defined in the pytorch/builder github repository) -COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm - USER jenkins CMD ["bash"] diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index 64ba3a86ddfdd..488eb5df7f512 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -42,7 +42,6 @@ COPY ./common/install_conda.sh install_conda.sh COPY ./common/common_utils.sh common_utils.sh COPY ./common/install_magma_conda.sh install_magma_conda.sh RUN bash ./install_conda.sh && rm install_conda.sh install_magma_conda.sh common_utils.sh /opt/conda/requirements-ci.txt /opt/conda/requirements-docs.txt -RUN if [ -n "${UNINSTALL_DILL}" ]; then pip uninstall -y dill; fi # Install gcc ARG GCC_VERSION @@ -70,25 +69,6 @@ ENV NCCL_LIB_DIR="/usr/local/cuda/lib64/" ARG CUDA_VERSION -# (optional) Install vision packages like OpenCV -ARG VISION -COPY ./common/install_vision.sh ./common/cache_vision_models.sh ./common/common_utils.sh ./ -RUN if [ -n "${VISION}" ]; then bash ./install_vision.sh; fi -RUN rm install_vision.sh cache_vision_models.sh common_utils.sh -ENV INSTALLED_VISION ${VISION} - -# (optional) Install non-default Ninja version -ARG NINJA_VERSION -COPY ./common/install_ninja.sh install_ninja.sh -RUN if [ -n "${NINJA_VERSION}" ]; then bash ./install_ninja.sh; fi -RUN rm install_ninja.sh - -COPY ./common/install_openssl.sh install_openssl.sh -RUN bash ./install_openssl.sh -ENV OPENSSL_ROOT_DIR /opt/openssl -ENV OPENSSL_DIR /opt/openssl -RUN rm install_openssl.sh - ARG INDUCTOR_BENCHMARKS COPY ./common/install_inductor_benchmark_deps.sh install_inductor_benchmark_deps.sh COPY ./common/common_utils.sh common_utils.sh @@ -163,6 +143,17 @@ COPY ./common/install_onnx.sh ./common/common_utils.sh ./ RUN if [ -n "${ONNX}" ]; then bash ./install_onnx.sh; fi RUN rm install_onnx.sh common_utils.sh +# Build TSan-instrumented CPython for thread sanitizer testing +ARG TSAN +COPY ./common/install_cpython.sh install_cpython.sh +COPY requirements-ci.txt /tmp/requirements-ci.txt +RUN if [ -n "${TSAN}" ]; then \ + CC=clang-18 CXX=clang++-18 \ + CPYTHON_VERSIONS="3.14.4t+tsan" bash ./install_cpython.sh && \ + /opt/python/cp314-cp314t+tsan/bin/pip install -r /tmp/requirements-ci.txt; \ + fi +RUN rm -f install_cpython.sh /tmp/requirements-ci.txt + # (optional) Build ACL ARG ACL COPY ./common/install_acl.sh install_acl.sh @@ -197,11 +188,6 @@ RUN rm install_openmpi.sh ARG BUILD_ENVIRONMENT ENV BUILD_ENVIRONMENT ${BUILD_ENVIRONMENT} -# Install LLVM dev version (Defined in the pytorch/builder github repository) -ARG SKIP_LLVM_SRC_BUILD_INSTALL -COPY --from=pytorch/llvm:9.0.1 /opt/llvm /opt/llvm -RUN if [ -n "${SKIP_LLVM_SRC_BUILD_INSTALL}" ]; then set -eu; rm -rf /opt/llvm; fi - # AWS specific CUDA build guidance ENV TORCH_NVCC_FLAGS "-Xfatbin -compress-all" ENV CUDA_PATH /usr/local/cuda diff --git a/.ci/lumen_cli/cli/lib/common/pip_helper.py b/.ci/lumen_cli/cli/lib/common/pip_helper.py index a0cb1e17840e1..9011ac020e34d 100644 --- a/.ci/lumen_cli/cli/lib/common/pip_helper.py +++ b/.ci/lumen_cli/cli/lib/common/pip_helper.py @@ -6,7 +6,7 @@ import shutil import sys from collections.abc import Iterable # noqa: TC003 -from importlib.metadata import PackageNotFoundError, version # noqa: UP035 +from importlib.metadata import PackageNotFoundError, version from cli.lib.common.utils import run_command diff --git a/.ci/lumen_cli/cli/lib/core/torchtitan/torchtitan_test.py b/.ci/lumen_cli/cli/lib/core/torchtitan/torchtitan_test.py index 32d22556628d7..d42fc93d6f649 100644 --- a/.ci/lumen_cli/cli/lib/core/torchtitan/torchtitan_test.py +++ b/.ci/lumen_cli/cli/lib/core/torchtitan/torchtitan_test.py @@ -21,11 +21,12 @@ def __init__(self, args: Any): def prepare(self): clone_torchtitan(dst=self.work_directory) - # torchao nightly is required by torchtitan + # torchao and torchcomms nightlies are required by torchtitan pip_install_packages( packages=[ "--pre", "torchao", + "torchcomms", "--index-url", "https://download.pytorch.org/whl/nightly/cu129", ], diff --git a/.ci/lumen_cli/cli/lib/core/vllm/vllm_test.py b/.ci/lumen_cli/cli/lib/core/vllm/vllm_test.py index 73ef9934d2d89..6688514f33601 100644 --- a/.ci/lumen_cli/cli/lib/core/vllm/vllm_test.py +++ b/.ci/lumen_cli/cli/lib/core/vllm/vllm_test.py @@ -161,21 +161,23 @@ def _install_wheels(self, params: VllmTestParameters): def _install_test_dependencies(self): """ This method replaces torch dependencies with local torch wheel info in - requirements/test.in file from vllm repo. then generates the test.txt + requirements/test/cuda.in file from vllm repo. then generates the test.txt in runtime """ - logger.info("generate test.txt from requirements/test.in with local torch whls") + logger.info( + "generate test.txt from requirements/test/cuda.in with local torch whls" + ) preprocess_test_in() - copy("requirements/test.txt", "snapshot_constraint.txt") + copy("requirements/test/cuda.txt", "snapshot_constraint.txt") run_command( - f"{sys.executable} -m uv pip compile requirements/test.in " - "-o test.txt " + f"{sys.executable} -m uv pip compile requirements/test/cuda.in " + "-o test/cuda.txt " "--index-strategy unsafe-best-match " "--constraint snapshot_constraint.txt " "--torch-backend cu129" ) - pip_install_packages(requirements="test.txt", prefer_uv=True) + pip_install_packages(requirements="test/cuda.txt", prefer_uv=True) logger.info("Done. installed requirements for test dependencies") def _install_dependencies(self): @@ -187,7 +189,7 @@ def _install_dependencies(self): run_python("use_existing_torch.py") # install common packages - for requirements in ["requirements/common.txt", "requirements/build.txt"]: + for requirements in ["requirements/common.txt", "requirements/build/cuda.txt"]: pip_install_packages( requirements=requirements, prefer_uv=True, @@ -222,7 +224,8 @@ def _set_envs(self, inputs: VllmTestParameters): def preprocess_test_in( - target_file: str = "requirements/test.in", additional_packages: Iterable[str] = () + target_file: str = "requirements/test/cuda.in", + additional_packages: Iterable[str] = (), ): """ This modifies the target_file file in place in vllm work directory. diff --git a/.ci/lumen_cli/cli/lib/core/vllm/vllm_test_library.yaml b/.ci/lumen_cli/cli/lib/core/vllm/vllm_test_library.yaml index 402f2d8bf0e69..18c2d0f80a745 100644 --- a/.ci/lumen_cli/cli/lib/core/vllm/vllm_test_library.yaml +++ b/.ci/lumen_cli/cli/lib/core/vllm/vllm_test_library.yaml @@ -75,8 +75,10 @@ vllm_multi_model_processor_test: - git+https://github.com/TIGER-AI-Lab/Mantis.git steps: - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py --ignore models/multimodal/processing/test_common.py --ignore models/multimodal/processing/test_glm4_1v.py - - pytest -v -s models/multimodal/processing/test_common.py -k 'not mistralai' - - HF_DATASETS_OFFLINE=0 TRANSFORMERS_OFFLINE=0 pytest -v -s models/multimodal/processing/test_common.py -k mistralai + # mistralai and moonshotai ship custom tiktoken tokenizers without a tokenizer.json; transformers v5's + # TokenizersBackend needs Hub access to resolve tokenizer_file, so run them with OFFLINE=0. + - pytest -v -s models/multimodal/processing/test_common.py -k 'not mistralai and not moonshotai' + - HF_DATASETS_OFFLINE=0 TRANSFORMERS_OFFLINE=0 pytest -v -s models/multimodal/processing/test_common.py -k 'mistralai or moonshotai' - HF_DATASETS_OFFLINE=0 TRANSFORMERS_OFFLINE=0 pytest -v -s models/multimodal/processing/test_glm4_1v.py vllm_multi_model_test_28_failure_test: @@ -124,4 +126,4 @@ vllm_lora_test: id: lora_test parallelism: 4 steps: - - pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py --ignore=lora/test_olmoe_tp.py --ignore=lora/test_deepseekv2_tp.py --ignore=lora/test_gptoss_tp.py --ignore=lora/test_qwen3moe_tp.py + - pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_llm_with_multi_loras.py --ignore=lora/test_olmoe_tp.py --ignore=lora/test_deepseekv2_tp.py --ignore=lora/test_gptoss_tp.py --ignore=lora/test_qwen3moe_tp.py --ignore=lora/test_qwen35_densemodel_lora.py diff --git a/.ci/lumen_cli/pyproject.toml b/.ci/lumen_cli/pyproject.toml index b2ac379e34ab0..ce8cf59d99fda 100644 --- a/.ci/lumen_cli/pyproject.toml +++ b/.ci/lumen_cli/pyproject.toml @@ -6,7 +6,7 @@ dependencies = [ "GitPython==3.1.45", "docker==7.1.0", "pytest==7.3.2", - "uv==0.9.6" + "uv==0.11.6" ] [tool.setuptools] diff --git a/.ci/manywheel/build_common.sh b/.ci/manywheel/build_common.sh index d50bd623dace0..d76b75dea6850 100644 --- a/.ci/manywheel/build_common.sh +++ b/.ci/manywheel/build_common.sh @@ -91,11 +91,6 @@ export PYTORCH_BUILD_NUMBER=$build_number export CMAKE_LIBRARY_PATH="/opt/intel/lib:/lib:$CMAKE_LIBRARY_PATH" export CMAKE_INCLUDE_PATH="/opt/intel/include:$CMAKE_INCLUDE_PATH" -if [[ -e /opt/openssl ]]; then - export OPENSSL_ROOT_DIR=/opt/openssl - export CMAKE_INCLUDE_PATH="/opt/openssl/include":$CMAKE_INCLUDE_PATH -fi - mkdir -p /tmp/$WHEELHOUSE_DIR export PATCHELF_BIN=/usr/local/bin/patchelf @@ -118,6 +113,9 @@ retry pip install -qUr requirements-build.txt python setup.py clean retry pip install -qr requirements.txt case ${DESIRED_PYTHON} in + cp314*) + retry pip install -q --pre numpy==2.3.4 + ;; cp31*) retry pip install -q --pre numpy==2.1.0 ;; diff --git a/.ci/manywheel/build_cuda.sh b/.ci/manywheel/build_cuda.sh index 94bf6a6b4b26c..3a34200456d6e 100644 --- a/.ci/manywheel/build_cuda.sh +++ b/.ci/manywheel/build_cuda.sh @@ -4,15 +4,21 @@ set -ex SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P ))" +if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then + echo "ERROR: PYTORCH_EXTRA_INSTALL_REQUIREMENTS is not set." + echo "CUDA wheels rely on nvidia pypi packages; this variable must define the runtime dependencies." + exit 1 +fi + export TORCH_NVCC_FLAGS="-Xfatbin -compress-all" export NCCL_ROOT_DIR=/usr/local/cuda export TH_BINARY_BUILD=1 -export USE_STATIC_CUDNN=1 -export USE_STATIC_NCCL=1 -export ATEN_STATIC_CUDA=1 -export USE_CUDA_STATIC_LINK=1 +export USE_STATIC_CUDNN=0 +export USE_STATIC_NCCL=0 +export ATEN_STATIC_CUDA=0 +export USE_CUDA_STATIC_LINK=0 export INSTALL_TEST=0 # dont install test binaries into site-packages -export USE_CUPTI_SO=0 +export USE_CUPTI_SO=1 export USE_CUSPARSELT=${USE_CUSPARSELT:-1} # Enable if not disabled by libtorch build export USE_CUFILE=${USE_CUFILE:-1} export USE_SYSTEM_NCCL=1 @@ -68,7 +74,7 @@ if [[ -n "$DESIRED_CUDA" ]]; then if [[ ${DESIRED_CUDA} =~ ^[0-9]+\.[0-9]+$ ]]; then CUDA_VERSION=${DESIRED_CUDA} else - # cu126, cu128 etc... + # cu126, cu130 etc... if [[ ${#DESIRED_CUDA} -eq 5 ]]; then CUDA_VERSION="${DESIRED_CUDA:2:2}.${DESIRED_CUDA:4:1}" fi @@ -108,13 +114,7 @@ TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;9.0;10.0" case ${CUDA_VERSION} in 12.6) TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;${TORCH_CUDA_ARCH_LIST//10.0/}" ;; # Only 12.6 includes legacy Maxwell/Pascal/Volta, -Hopper support - 12.8) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};12.0" ;; # +Blackwell support - 12.9) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};12.0+PTX" # +Blackwell support + PTX for forward compatibility - if [[ "$PACKAGE_TYPE" == "libtorch" ]]; then - TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST//8.6;/}" # Remove 8.6 for libtorch - fi - ;; - 13.0) + 13.0|13.2) TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};$([[ "$ARCH" == "aarch64" ]] && echo "11.0;" || echo "")12.0+PTX" export TORCH_NVCC_FLAGS="-compress-mode=size" export BUILD_BUNDLE_PTXAS=1 @@ -169,130 +169,41 @@ DEPS_SONAME=( # CUDA_VERSION 12.*, 13.* if [[ $CUDA_VERSION == 12* || $CUDA_VERSION == 13* ]]; then - export USE_STATIC_CUDNN=0 # Try parallelizing nvcc as well TORCH_NVCC_FLAGS="-Xfatbin -compress-all --threads 2" # Compress the fatbin with -compress-mode=size for CUDA 13 if [[ $CUDA_VERSION == 13* ]]; then export TORCH_NVCC_FLAGS="$TORCH_NVCC_FLAGS -compress-mode=size" fi - if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then - echo "Bundling with cudnn and cublas." - - DEPS_LIST+=( - "/usr/local/cuda/lib64/libcudnn_adv.so.9" - "/usr/local/cuda/lib64/libcudnn_cnn.so.9" - "/usr/local/cuda/lib64/libcudnn_graph.so.9" - "/usr/local/cuda/lib64/libcudnn_ops.so.9" - "/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9" - "/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9" - "/usr/local/cuda/lib64/libcudnn_heuristic.so.9" - "/usr/local/cuda/lib64/libcudnn.so.9" - "/usr/local/cuda/lib64/libcusparseLt.so.0" - "/usr/local/cuda/lib64/libnvrtc-builtins.so" - "/usr/local/cuda/lib64/libcufile.so.0" - "/usr/local/cuda/lib64/libcufile_rdma.so.1" - "/usr/local/cuda/lib64/libnvshmem_host.so.3" - "/usr/local/cuda/extras/CUPTI/lib64/libnvperf_host.so" - ) - DEPS_SONAME+=( - "libcudnn_adv.so.9" - "libcudnn_cnn.so.9" - "libcudnn_graph.so.9" - "libcudnn_ops.so.9" - "libcudnn_engines_runtime_compiled.so.9" - "libcudnn_engines_precompiled.so.9" - "libcudnn_heuristic.so.9" - "libcudnn.so.9" - "libcusparseLt.so.0" - "libnvrtc-builtins.so" - "libnvshmem_host.so.3" - "libcufile.so.0" - "libcufile_rdma.so.1" - "libnvperf_host.so" - ) - # Add libnvToolsExt only if CUDA version is not 12.9 - if [[ $CUDA_VERSION == 13* ]]; then - DEPS_LIST+=( - "/usr/local/cuda/lib64/libcublas.so.13" - "/usr/local/cuda/lib64/libcublasLt.so.13" - "/usr/local/cuda/lib64/libcudart.so.13" - "/usr/local/cuda/lib64/libnvrtc.so.13" - "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.13" - "/usr/local/cuda/lib64/libibverbs.so.1" - "/usr/local/cuda/lib64/librdmacm.so.1" - "/usr/local/cuda/lib64/libmlx5.so.1" - "/usr/local/cuda/lib64/libnl-3.so.200" - "/usr/local/cuda/lib64/libnl-route-3.so.200") - DEPS_SONAME+=( - "libcublas.so.13" - "libcublasLt.so.13" - "libcudart.so.13" - "libnvrtc.so.13" - "libcupti.so.13" - "libibverbs.so.1" - "librdmacm.so.1" - "libmlx5.so.1" - "libnl-3.so.200" - "libnl-route-3.so.200") - export USE_CUPTI_SO=1 - export ATEN_STATIC_CUDA=0 - export USE_CUDA_STATIC_LINK=0 - export USE_CUFILE=0 - else - DEPS_LIST+=( - "/usr/local/cuda/lib64/libcublas.so.12" - "/usr/local/cuda/lib64/libcublasLt.so.12" - "/usr/local/cuda/lib64/libcudart.so.12" - "/usr/local/cuda/lib64/libnvrtc.so.12" - "/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.12") - DEPS_SONAME+=( - "libcublas.so.12" - "libcublasLt.so.12" - "libcudart.so.12" - "libnvrtc.so.12" - "libcupti.so.12") - - if [[ $CUDA_VERSION != 12.9* ]]; then - DEPS_LIST+=("/usr/local/cuda/lib64/libnvToolsExt.so.1") - DEPS_SONAME+=("libnvToolsExt.so.1") - fi - fi + echo "Using nvidia libs from pypi." + CUDA_RPATHS=( + '$ORIGIN/../../nvidia/cudnn/lib' + '$ORIGIN/../../nvidia/nvshmem/lib' + '$ORIGIN/../../nvidia/nccl/lib' + '$ORIGIN/../../nvidia/cusparselt/lib' + ) + if [[ $CUDA_VERSION == 13* ]]; then + CUDA_RPATHS+=('$ORIGIN/../../nvidia/cu13/lib') else - echo "Using nvidia libs from pypi." - CUDA_RPATHS=( - '$ORIGIN/../../nvidia/cudnn/lib' - '$ORIGIN/../../nvidia/nvshmem/lib' - '$ORIGIN/../../nvidia/nccl/lib' - '$ORIGIN/../../nvidia/cusparselt/lib' + CUDA_RPATHS+=( + '$ORIGIN/../../nvidia/cublas/lib' + '$ORIGIN/../../nvidia/cuda_cupti/lib' + '$ORIGIN/../../nvidia/cuda_nvrtc/lib' + '$ORIGIN/../../nvidia/cuda_runtime/lib' + '$ORIGIN/../../nvidia/cufft/lib' + '$ORIGIN/../../nvidia/curand/lib' + '$ORIGIN/../../nvidia/cusolver/lib' + '$ORIGIN/../../nvidia/cusparse/lib' + '$ORIGIN/../../cusparselt/lib' + '$ORIGIN/../../nvidia/nvtx/lib' + '$ORIGIN/../../nvidia/cufile/lib' ) - if [[ $CUDA_VERSION == 13* ]]; then - CUDA_RPATHS+=('$ORIGIN/../../nvidia/cu13/lib') - else - CUDA_RPATHS+=( - '$ORIGIN/../../nvidia/cublas/lib' - '$ORIGIN/../../nvidia/cuda_cupti/lib' - '$ORIGIN/../../nvidia/cuda_nvrtc/lib' - '$ORIGIN/../../nvidia/cuda_runtime/lib' - '$ORIGIN/../../nvidia/cufft/lib' - '$ORIGIN/../../nvidia/curand/lib' - '$ORIGIN/../../nvidia/cusolver/lib' - '$ORIGIN/../../nvidia/cusparse/lib' - '$ORIGIN/../../cusparselt/lib' - '$ORIGIN/../../nvidia/nvtx/lib' - '$ORIGIN/../../nvidia/cufile/lib' - ) - fi - - CUDA_RPATHS=$(IFS=: ; echo "${CUDA_RPATHS[*]}") - export C_SO_RPATH=$CUDA_RPATHS':$ORIGIN:$ORIGIN/lib' - export LIB_SO_RPATH=$CUDA_RPATHS':$ORIGIN' - export FORCE_RPATH="--force-rpath" - export USE_STATIC_NCCL=0 - export ATEN_STATIC_CUDA=0 - export USE_CUDA_STATIC_LINK=0 - export USE_CUPTI_SO=1 fi + + CUDA_RPATHS=$(IFS=: ; echo "${CUDA_RPATHS[*]}") + export C_SO_RPATH=$CUDA_RPATHS':$ORIGIN:$ORIGIN/lib' + export LIB_SO_RPATH=$CUDA_RPATHS':$ORIGIN' + export FORCE_RPATH="--force-rpath" else echo "Unknown cuda version $CUDA_VERSION" exit 1 diff --git a/.ci/manywheel/build_libtorch.sh b/.ci/manywheel/build_libtorch.sh index d78fbd5c3ed36..852ecf7500604 100644 --- a/.ci/manywheel/build_libtorch.sh +++ b/.ci/manywheel/build_libtorch.sh @@ -59,12 +59,6 @@ export PYTORCH_BUILD_NUMBER=$build_number export CMAKE_LIBRARY_PATH="/opt/intel/lib:/lib:$CMAKE_LIBRARY_PATH" export CMAKE_INCLUDE_PATH="/opt/intel/include:$CMAKE_INCLUDE_PATH" -# set OPENSSL_ROOT_DIR=/opt/openssl if it exists -if [[ -e /opt/openssl ]]; then - export OPENSSL_ROOT_DIR=/opt/openssl - export CMAKE_INCLUDE_PATH="/opt/openssl/include":$CMAKE_INCLUDE_PATH -fi - # If given a python version like 3.6m or 2.7mu, convert this to the format we # expect. The binary CI jobs pass in python versions like this; they also only # ever pass one python version, so we assume that DESIRED_PYTHON is not a list diff --git a/.ci/manywheel/build_rocm.sh b/.ci/manywheel/build_rocm.sh index bac56746f4501..fa5724dca25a7 100755 --- a/.ci/manywheel/build_rocm.sh +++ b/.ci/manywheel/build_rocm.sh @@ -97,20 +97,13 @@ ROCM_SO_FILES=( "libhipblaslt.so" "libhipsparselt.so" "libhiprtc.so" + "librocprofiler-sdk.so" + "librocprofiler-register.so" + "libhsa-amd-aqlprofile64.so" + "librocm-core.so" + "librocroller.so" ) -if [[ $ROCM_INT -ge 60100 ]]; then - ROCM_SO_FILES+=("librocprofiler-register.so") -fi - -if [[ $ROCM_INT -ge 60200 ]]; then - ROCM_SO_FILES+=("librocm-core.so") -fi - -if [[ $ROCM_INT -ge 70000 ]]; then - ROCM_SO_FILES+=("librocroller.so") -fi - OS_NAME=`awk -F= '/^NAME/{print $2}' /etc/os-release` if [[ "$OS_NAME" == *"CentOS Linux"* || "$OS_NAME" == *"AlmaLinux"* ]]; then LIBGOMP_PATH="/usr/lib64/libgomp.so.1" @@ -121,61 +114,23 @@ if [[ "$OS_NAME" == *"CentOS Linux"* || "$OS_NAME" == *"AlmaLinux"* ]]; then else LIBTINFO_PATH="/usr/lib64/libtinfo.so.6" fi + LIBDW_PATH="/usr/lib64/libdw.so.1" LIBDRM_PATH="/opt/amdgpu/lib64/libdrm.so.2" LIBDRM_AMDGPU_PATH="/opt/amdgpu/lib64/libdrm_amdgpu.so.1" - if [[ $ROCM_INT -ge 60100 && $ROCM_INT -lt 60300 ]]; then - # Below libs are direct dependencies of libhipsolver - LIBSUITESPARSE_CONFIG_PATH="/lib64/libsuitesparseconfig.so.4" - if [[ "$OS_NAME" == *"CentOS Linux"* ]]; then - LIBCHOLMOD_PATH="/lib64/libcholmod.so.2" - # Below libs are direct dependencies of libsatlas - LIBGFORTRAN_PATH="/lib64/libgfortran.so.3" - else - LIBCHOLMOD_PATH="/lib64/libcholmod.so.3" - # Below libs are direct dependencies of libsatlas - LIBGFORTRAN_PATH="/lib64/libgfortran.so.5" - fi - # Below libs are direct dependencies of libcholmod - LIBAMD_PATH="/lib64/libamd.so.2" - LIBCAMD_PATH="/lib64/libcamd.so.2" - LIBCCOLAMD_PATH="/lib64/libccolamd.so.2" - LIBCOLAMD_PATH="/lib64/libcolamd.so.2" - LIBSATLAS_PATH="/lib64/atlas/libsatlas.so.3" - # Below libs are direct dependencies of libsatlas - LIBQUADMATH_PATH="/lib64/libquadmath.so.0" - fi MAYBE_LIB64=lib64 elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1" LIBNUMA_PATH="/usr/lib/x86_64-linux-gnu/libnuma.so.1" LIBELF_PATH="/usr/lib/x86_64-linux-gnu/libelf.so.1" - if [[ $ROCM_INT -ge 50300 ]]; then - LIBTINFO_PATH="/lib/x86_64-linux-gnu/libtinfo.so.6" - else - LIBTINFO_PATH="/lib/x86_64-linux-gnu/libtinfo.so.5" - fi + LIBTINFO_PATH="/lib/x86_64-linux-gnu/libtinfo.so.6" + LIBDW_PATH="/usr/lib/x86_64-linux-gnu/libdw.so.1" LIBDRM_PATH="/usr/lib/x86_64-linux-gnu/libdrm.so.2" LIBDRM_AMDGPU_PATH="/usr/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1" - if [[ $ROCM_INT -ge 60100 && $ROCM_INT -lt 60300 ]]; then - # Below libs are direct dependencies of libhipsolver - LIBCHOLMOD_PATH="/lib/x86_64-linux-gnu/libcholmod.so.3" - # Below libs are direct dependencies of libcholmod - LIBSUITESPARSE_CONFIG_PATH="/lib/x86_64-linux-gnu/libsuitesparseconfig.so.5" - LIBAMD_PATH="/lib/x86_64-linux-gnu/libamd.so.2" - LIBCAMD_PATH="/lib/x86_64-linux-gnu/libcamd.so.2" - LIBCCOLAMD_PATH="/lib/x86_64-linux-gnu/libccolamd.so.2" - LIBCOLAMD_PATH="/lib/x86_64-linux-gnu/libcolamd.so.2" - LIBMETIS_PATH="/lib/x86_64-linux-gnu/libmetis.so.5" - LIBLAPACK_PATH="/lib/x86_64-linux-gnu/liblapack.so.3" - LIBBLAS_PATH="/lib/x86_64-linux-gnu/libblas.so.3" - # Below libs are direct dependencies of libblas - LIBGFORTRAN_PATH="/lib/x86_64-linux-gnu/libgfortran.so.5" - LIBQUADMATH_PATH="/lib/x86_64-linux-gnu/libquadmath.so.0" - fi MAYBE_LIB64=lib fi OS_SO_PATHS=($LIBGOMP_PATH $LIBNUMA_PATH\ $LIBELF_PATH $LIBTINFO_PATH\ + $LIBDW_PATH\ $LIBDRM_PATH $LIBDRM_AMDGPU_PATH\ $LIBSUITESPARSE_CONFIG_PATH\ $LIBCHOLMOD_PATH $LIBAMD_PATH\ diff --git a/.ci/onnx/README.md b/.ci/onnx/README.md deleted file mode 100644 index 47739136aabdf..0000000000000 --- a/.ci/onnx/README.md +++ /dev/null @@ -1,12 +0,0 @@ -# Jenkins - -The scripts in this directory are the entrypoint for testing ONNX exporter. - -The environment variable `BUILD_ENVIRONMENT` is expected to be set to -the build environment you intend to test. It is a hint for the build -and test scripts to configure Caffe2 a certain way and include/exclude -tests. Docker images, they equal the name of the image itself. For -example: `py2-cuda9.0-cudnn7-ubuntu16.04`. The Docker images that are -built on Jenkins and are used in triggered builds already have this -environment variable set in their manifest. Also see -`./docker/jenkins/*/Dockerfile` and search for `BUILD_ENVIRONMENT`. diff --git a/.ci/onnx/common.sh b/.ci/onnx/common.sh deleted file mode 100644 index b8f912fbbb4e6..0000000000000 --- a/.ci/onnx/common.sh +++ /dev/null @@ -1,107 +0,0 @@ -#!/bin/bash - -set -ex - -source "$(dirname "${BASH_SOURCE[0]}")/../pytorch/common_utils.sh" - -LOCAL_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) -ROOT_DIR=$(cd "$LOCAL_DIR"/../.. && pwd) -TEST_DIR="$ROOT_DIR/test" -pytest_reports_dir="${TEST_DIR}/test-reports/python" - -# Figure out which Python to use -PYTHON="$(which python)" -if [[ "${BUILD_ENVIRONMENT}" =~ py((2|3)\.?[0-9]?\.?[0-9]?) ]]; then - PYTHON=$(which "python${BASH_REMATCH[1]}") -fi - -if [[ "${BUILD_ENVIRONMENT}" == *rocm* ]]; then - # HIP_PLATFORM is auto-detected by hipcc; unset to avoid build errors - unset HIP_PLATFORM -fi - -mkdir -p "$pytest_reports_dir" || true - -########################################## -# copied from .ci/pytorch/common_utils.sh -########################################## - -function get_pinned_commit() { - cat .github/ci_commit_pins/"${1}".txt -} - -function pip_install_whl() { - # This is used to install PyTorch and other build artifacts wheel locally - # without using any network connection - - # Convert the input arguments into an array - local args=("$@") - - # Check if the first argument contains multiple paths separated by spaces - if [[ "${args[0]}" == *" "* ]]; then - # Split the string by spaces into an array - IFS=' ' read -r -a paths <<< "${args[0]}" - # Loop through each path and install individually - for path in "${paths[@]}"; do - echo "Installing $path" - python3 -mpip install --no-index --no-deps "$path" - done - else - # Loop through each argument and install individually - for path in "${args[@]}"; do - echo "Installing $path" - python3 -mpip install --no-index --no-deps "$path" - done - fi -} - -function pip_build_and_install() { - local build_target=$1 - local wheel_dir=$2 - - local found_whl=0 - for file in "${wheel_dir}"/*.whl - do - if [[ -f "${file}" ]]; then - found_whl=1 - break - fi - done - - # Build the wheel if it doesn't exist - if [ "${found_whl}" == "0" ]; then - python3 -m pip wheel \ - --no-build-isolation \ - --no-deps \ - -w "${wheel_dir}" \ - "${build_target}" - fi - - for file in "${wheel_dir}"/*.whl - do - pip_install_whl "${file}" - done -} - -function install_torchvision() { - local orig_preload - local commit - commit=$(get_pinned_commit vision) - orig_preload=${LD_PRELOAD} - if [ -n "${LD_PRELOAD}" ]; then - # Silence dlerror to work-around glibc ASAN bug, see https://sourceware.org/bugzilla/show_bug.cgi?id=27653#c9 - echo 'char* dlerror(void) { return "";}'|gcc -fpic -shared -o "${HOME}/dlerror.so" -x c - - LD_PRELOAD=${orig_preload}:${HOME}/dlerror.so - fi - - if [[ "${BUILD_ENVIRONMENT}" == *cuda* ]]; then - # Not sure if both are needed, but why not - export FORCE_CUDA=1 - export WITH_CUDA=1 - fi - pip_build_and_install "git+https://github.com/pytorch/vision.git@${commit}" dist/vision - - if [ -n "${LD_PRELOAD}" ]; then - LD_PRELOAD=${orig_preload} - fi -} diff --git a/.ci/onnx/test.sh b/.ci/onnx/test.sh deleted file mode 100755 index 1f2a23b49dc45..0000000000000 --- a/.ci/onnx/test.sh +++ /dev/null @@ -1,29 +0,0 @@ -#!/bin/bash - -# shellcheck source=./common.sh -source "$(dirname "${BASH_SOURCE[0]}")/common.sh" - -# Workaround for dind-rootless userid mapping (https://github.com/pytorch/ci-infra/issues/96) -WORKSPACE_ORIGINAL_OWNER_ID=$(stat -c '%u' "/var/lib/jenkins/workspace") -cleanup_workspace() { - echo "sudo may print the following warning message that can be ignored. The chown command will still run." - echo " sudo: setrlimit(RLIMIT_STACK): Operation not permitted" - echo "For more details refer to https://github.com/sudo-project/sudo/issues/42" - sudo chown -R "$WORKSPACE_ORIGINAL_OWNER_ID" /var/lib/jenkins/workspace -} -# Disable shellcheck SC2064 as we want to parse the original owner immediately. -# shellcheck disable=SC2064 -trap_add cleanup_workspace EXIT -sudo chown -R jenkins /var/lib/jenkins/workspace -git config --global --add safe.directory /var/lib/jenkins/workspace - -if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then - # TODO: This can be removed later once vision is also part of the Docker image - install_torchvision - # JIT C++ extensions require ninja, so put it into PATH. - export PATH="/var/lib/jenkins/.local/bin:$PATH" - # NB: ONNX test is fast (~15m) so it's ok to retry it few more times to avoid any flaky issue, we - # need to bring this to the standard PyTorch run_test eventually. The issue will be tracked in - # https://github.com/pytorch/pytorch/issues/98626 - "$ROOT_DIR/scripts/onnx/test.sh" -fi diff --git a/.ci/pytorch/binary_populate_env.sh b/.ci/pytorch/binary_populate_env.sh index 9ea588d1555e2..53914914c8c93 100755 --- a/.ci/pytorch/binary_populate_env.sh +++ b/.ci/pytorch/binary_populate_env.sh @@ -112,9 +112,8 @@ if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_B fi fi -USE_GLOO_WITH_OPENSSL="ON" +USE_GLOO_WITH_OPENSSL="OFF" if [[ "$GPU_ARCH_TYPE" =~ .*aarch64.* ]]; then - USE_GLOO_WITH_OPENSSL="OFF" USE_GOLD_LINKER="OFF" fi diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index 5853b7be2a98a..522486d784983 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -25,7 +25,10 @@ env if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then # Use jemalloc during compilation to mitigate https://github.com/pytorch/pytorch/issues/116289 - export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libjemalloc.so.2 + JEMALLOC_LIB=$(find /usr/lib* -name "libjemalloc.so.2" 2>/dev/null | head -1) + if [[ -n "${JEMALLOC_LIB}" ]]; then + export LD_PRELOAD="${JEMALLOC_LIB}" + fi echo "NVCC version:" nvcc --version fi @@ -144,6 +147,7 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then export USE_XCCL=1 export USE_MPI=0 export TORCH_XPU_ARCH_LIST=pvc + export USE_STATIC_MKL=1 fi # sccache will fail for CUDA builds if all cores are used for compiling @@ -174,14 +178,9 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && echo "${TORCH_CUDA_ARCH_LIST}" | tr ' export BUILD_CUSTOM_STEP="ninja -C build flash_attention -j ${J}" fi -if [[ "${BUILD_ENVIRONMENT}" == *clang* ]]; then - export CC=clang - export CXX=clang++ - # TODO: Removeme once all the wrappers are gone - if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then - sudo rm -f /opt/cache/bin/clang++ - fi - +# TODO: Removeme once all the wrappers are gone +if [[ "$BUILD_ENVIRONMENT" == *clang* ]] && [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then + sudo rm -f /opt/cache/bin/clang++ fi if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then @@ -193,11 +192,20 @@ if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then export UBSAN_FLAGS="-fno-sanitize-recover=all" fi +if [[ "$BUILD_ENVIRONMENT" == *-tsan* ]]; then + export USE_TSAN=1 + export USE_CUDA=0 + export USE_XNNPACK=0 + export USE_FBGEMM=0 + export USE_DISTRIBUTED=0 + export BUILD_TEST=0 +fi + if [[ "${BUILD_ENVIRONMENT}" == *no-ops* ]]; then export USE_PER_OPERATOR_HEADERS=0 fi -if [[ "${BUILD_ENVIRONMENT}" != *cuda* ]]; then +if [[ "${BUILD_ENVIRONMENT}" != *cuda* && "${BUILD_ENVIRONMENT}" != *-tsan* ]]; then export BUILD_STATIC_RUNTIME_BENCHMARK=ON fi @@ -225,118 +233,101 @@ if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && "$BUI git config --global --add safe.directory /var/lib/jenkins/workspace fi -if [[ "$BUILD_ENVIRONMENT" == *-bazel-* ]]; then - set -e -o pipefail - - get_bazel - python3 tools/optional_submodules.py checkout_eigen +# check that setup.py would fail with bad arguments +echo "The next three invocations are expected to fail with invalid command error messages." +( ! get_exit_code python setup.py bad_argument ) +( ! get_exit_code python setup.py clean] ) +( ! get_exit_code python setup.py clean bad_argument ) + +if [[ "$BUILD_ENVIRONMENT" != *libtorch* ]]; then + # rocm builds fail when WERROR=1 + # XLA test build fails when WERROR=1 + # set only when building other architectures + # or building non-XLA tests. + if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *xla* && "$BUILD_ENVIRONMENT" != *riscv64* ]]; then + # TODO: Remove me and may be just focus on numpy-2.x testing + if [[ "$ANACONDA_PYTHON_VERSION" =~ ^3\.1[0-2]$ ]]; then + # Install numpy-2.0.2 for builds which are backward compatible with 1.X + # In relality it's only needed for numpy_2_x and vllm shards (where vllm depends on numpy-2) + python -mpip install numpy==2.0.2 + fi - # Leave 1 CPU free and use only up to 80% of memory to reduce the change of crashing - # the runner - BAZEL_MEM_LIMIT="--local_ram_resources=HOST_RAM*.8" - BAZEL_CPU_LIMIT="--local_cpu_resources=HOST_CPUS-1" + WERROR=1 python setup.py clean - if [[ "$CUDA_VERSION" == "cpu" ]]; then - # Build torch, the Python module, and tests for CPU-only - tools/bazel build --config=no-tty "${BAZEL_MEM_LIMIT}" "${BAZEL_CPU_LIMIT}" --config=cpu-only :torch :torch/_C.so :all_tests + WERROR=1 python -m build --wheel --no-isolation else - tools/bazel build --config=no-tty "${BAZEL_MEM_LIMIT}" "${BAZEL_CPU_LIMIT}" //... - fi -else - # check that setup.py would fail with bad arguments - echo "The next three invocations are expected to fail with invalid command error messages." - ( ! get_exit_code python setup.py bad_argument ) - ( ! get_exit_code python setup.py clean] ) - ( ! get_exit_code python setup.py clean bad_argument ) - - if [[ "$BUILD_ENVIRONMENT" != *libtorch* ]]; then - # rocm builds fail when WERROR=1 - # XLA test build fails when WERROR=1 - # set only when building other architectures - # or building non-XLA tests. - if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *xla* && "$BUILD_ENVIRONMENT" != *riscv64* ]]; then - # TODO: Remove me and may be just focus on numpy-2.x testing - if [[ "$ANACONDA_PYTHON_VERSION" =~ ^3\.1[0-2]$ ]]; then - # Install numpy-2.0.2 for builds which are backward compatible with 1.X - # In relality it's only needed for numpy_2_x and vllm shards (where vllm depends on numpy-2) - python -mpip install numpy==2.0.2 - fi - - WERROR=1 python setup.py clean - - WERROR=1 python -m build --wheel --no-isolation - else - python setup.py clean - if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then - source .ci/pytorch/install_cache_xla.sh - fi - python -m build --wheel --no-isolation - fi - pip_install_whl "$(echo dist/*.whl)" - if [[ "$BUILD_ENVIRONMENT" == *full-debug* ]]; then - # Regression test for https://github.com/pytorch/pytorch/issues/164297 - # Torch should be importable and that's about it - pushd /; python -c "import torch;print(torch.__config__.show(), torch.randn(5) + 1.7)"; popd + python setup.py clean + if [[ "$BUILD_ENVIRONMENT" == *xla* ]]; then + source .ci/pytorch/install_cache_xla.sh fi + python -m build --wheel --no-isolation + fi + pip_install_whl "$(echo dist/*.whl)" + if [[ "$BUILD_ENVIRONMENT" == *full-debug* ]]; then + # Regression test for https://github.com/pytorch/pytorch/issues/164297 + # Torch should be importable and that's about it + pushd /; python -c "import torch;print(torch.__config__.show(), torch.randn(5) + 1.7)"; popd + fi - if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *vision* ]]; then - install_torchvision - fi + if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *vision* ]]; then + install_torchvision + fi - if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *audio* ]]; then - install_torchaudio - fi + if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *audio* ]]; then + install_torchaudio + fi - if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *torchrec* || "${BUILD_ADDITIONAL_PACKAGES:-}" == *fbgemm* ]]; then - install_torchrec_and_fbgemm - fi + if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *torchrec* || "${BUILD_ADDITIONAL_PACKAGES:-}" == *fbgemm* ]]; then + install_torchrec_and_fbgemm + fi - if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *torchao* ]]; then - install_torchao - fi + if [[ "${BUILD_ADDITIONAL_PACKAGES:-}" == *torchao* ]]; then + install_torchao + fi - if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then - echo "Checking that xpu is compiled" - pushd dist/ - if python -c 'import torch; exit(0 if torch.xpu._is_compiled() else 1)'; then - echo "XPU support is compiled in." - else - echo "XPU support is NOT compiled in." - exit 1 - fi - popd + if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then + echo "Checking that xpu is compiled" + pushd dist/ + if python -c 'import torch; exit(0 if torch.xpu._is_compiled() else 1)'; then + echo "XPU support is compiled in." + else + echo "XPU support is NOT compiled in." + exit 1 fi + popd + fi - # TODO: I'm not sure why, but somehow we lose verbose commands - set -x + # TODO: I'm not sure why, but somehow we lose verbose commands + set -x - assert_git_not_dirty - # Copy ninja build logs to dist folder - mkdir -p dist - if [ -f build/.ninja_log ]; then - cp build/.ninja_log dist - fi + assert_git_not_dirty + # Copy ninja build logs to dist folder + mkdir -p dist + if [ -f build/.ninja_log ]; then + cp build/.ninja_log dist + fi - if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then - # remove sccache wrappers post-build; runtime compilation of MIOpen kernels does not yet fully support them - sudo rm -f /opt/cache/bin/cc - sudo rm -f /opt/cache/bin/c++ - sudo rm -f /opt/cache/bin/gcc - sudo rm -f /opt/cache/bin/g++ - # Restore original clang compilers that were backed up during sccache wrapping. - # Skip for theRock nightly: sccache wrapping is disabled, so no backup exists. - # theRock also uses ${ROCM_PATH}/lib/llvm/bin instead of /opt/rocm/llvm/bin. - if [[ -d /opt/rocm/llvm/bin ]]; then - pushd /opt/rocm/llvm/bin - if [[ -d original ]]; then - sudo mv original/clang . - sudo mv original/clang++ . - fi - sudo rm -rf original - popd + if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then + # remove sccache wrappers post-build; runtime compilation of MIOpen kernels does not yet fully support them + sudo rm -f /opt/cache/bin/cc + sudo rm -f /opt/cache/bin/c++ + sudo rm -f /opt/cache/bin/gcc + sudo rm -f /opt/cache/bin/g++ + # Restore original clang compilers that were backed up during sccache wrapping. + # Skip for theRock nightly: sccache wrapping is disabled, so no backup exists. + # theRock also uses ${ROCM_PATH}/lib/llvm/bin instead of /opt/rocm/llvm/bin. + if [[ -d /opt/rocm/llvm/bin ]]; then + pushd /opt/rocm/llvm/bin + if [[ -d original ]]; then + sudo mv original/clang . + sudo mv original/clang++ . fi + sudo rm -rf original + popd fi + fi + if [[ "$BUILD_ENVIRONMENT" != *-tsan* ]]; then CUSTOM_TEST_ARTIFACT_BUILD_DIR=${CUSTOM_TEST_ARTIFACT_BUILD_DIR:-"build/custom_test_artifacts"} CUSTOM_TEST_USE_ROCM=$([[ "$BUILD_ENVIRONMENT" == *rocm* ]] && echo "ON" || echo "OFF") CUSTOM_TEST_MODULE_PATH="${PWD}/cmake/public" @@ -380,32 +371,32 @@ else make VERBOSE=1 popd assert_git_not_dirty - else - # Test no-Python build - echo "Building libtorch" - - # This is an attempt to mitigate flaky libtorch build OOM error. By default, the build parallelization - # is set to be the number of CPU minus 2. So, let's try a more conservative value here. A 4xlarge has - # 16 CPUs - MAX_JOBS=$(nproc --ignore=4) - export MAX_JOBS - - # NB: Install outside of source directory (at the same level as the root - # pytorch folder) so that it doesn't get cleaned away prior to docker push. - BUILD_LIBTORCH_PY=$PWD/tools/build_libtorch.py - mkdir -p ../cpp-build/caffe2 - pushd ../cpp-build/caffe2 - WERROR=1 VERBOSE=1 DEBUG=1 python "$BUILD_LIBTORCH_PY" - popd fi +else + # Test no-Python build + echo "Building libtorch" + + # This is an attempt to mitigate flaky libtorch build OOM error. By default, the build parallelization + # is set to be the number of CPU minus 2. So, let's try a more conservative value here. A 4xlarge has + # 16 CPUs + MAX_JOBS=$(nproc --ignore=4) + export MAX_JOBS + + BUILD_LIBTORCH_PY=$PWD/tools/build_libtorch.py + # Build outside the source tree so the artifacts don't interfere with + # the workspace. /tmp is writable on both EC2 and OSDC runners. + mkdir -p /tmp/cpp-build/caffe2 + pushd /tmp/cpp-build/caffe2 + WERROR=1 VERBOSE=1 DEBUG=1 python "$BUILD_LIBTORCH_PY" + popd fi -if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then +if [[ "$BUILD_ENVIRONMENT" != *libtorch* ]]; then # export test times so that potential sharded tests that'll branch off this build will use consistent data # don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build PYTHONPATH=. python tools/stats/export_test_times.py fi -# don't do this for bazel or s390x or riscv64 as they don't use sccache -if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* && "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then +# don't do this for s390x or riscv64 as they don't use sccache +if [[ "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *riscv64* ]]; then print_sccache_stats fi diff --git a/.ci/pytorch/check_binary.sh b/.ci/pytorch/check_binary.sh index 3ff73084e4065..9356970394e81 100755 --- a/.ci/pytorch/check_binary.sh +++ b/.ci/pytorch/check_binary.sh @@ -292,25 +292,6 @@ if [[ "$PACKAGE_TYPE" != 'libtorch' ]]; then popd fi -############################################################################### -# Check PyTorch supports TCP_TLS gloo transport -############################################################################### - -if [[ "$(uname)" == 'Linux' && "$PACKAGE_TYPE" != 'libtorch' ]]; then - GLOO_CHECK="import torch.distributed as dist -try: - dist.init_process_group('gloo', rank=0, world_size=1) -except RuntimeError as e: - print(e) -" - RESULT=`GLOO_DEVICE_TRANSPORT=TCP_TLS MASTER_ADDR=localhost MASTER_PORT=63945 python -c "$GLOO_CHECK"` - GLOO_TRANSPORT_IS_NOT_SUPPORTED='gloo transport is not supported' - if [[ "$RESULT" =~ "$GLOO_TRANSPORT_IS_NOT_SUPPORTED" ]]; then - echo "PyTorch doesn't support TLS_TCP transport, please build with USE_GLOO_WITH_OPENSSL=1" - exit 1 - fi -fi - ############################################################################### # Restore LD_LIBRARY_PATH to its original value ############################################################################### diff --git a/.ci/pytorch/common.sh b/.ci/pytorch/common.sh index eae12816fe71e..94d9629eac519 100644 --- a/.ci/pytorch/common.sh +++ b/.ci/pytorch/common.sh @@ -5,7 +5,7 @@ source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh" set -ex -o pipefail -# for ROCm environment variables +# Source ROCm environment variables (paths may vary between tarball/wheel installs) if [[ "${BUILD_ENVIRONMENT}" == *rocm* ]] && [[ -f /etc/rocm_env.sh ]]; then # shellcheck disable=SC1091 source /etc/rocm_env.sh @@ -14,6 +14,19 @@ fi # Required environment variables: # $BUILD_ENVIRONMENT (should be set by your Docker image) +# Select compiler based on build environment name. Images that have both +# GCC and Clang installed default cc/c++ to Clang (via install_clang.sh), +# so we need to override when a gcc build is requested. +if [[ "${BUILD_ENVIRONMENT}" == *clang* ]]; then + export CC=clang + export CXX=clang++ +elif [[ "${BUILD_ENVIRONMENT}" == *gcc* ]]; then + export CC=gcc + export CXX=g++ + sudo update-alternatives --install /usr/bin/cc cc /usr/bin/gcc 100 + sudo update-alternatives --install /usr/bin/c++ c++ /usr/bin/g++ 100 +fi + # Figure out which Python to use for ROCm if [[ "${BUILD_ENVIRONMENT}" == *rocm* ]]; then # HIP_PLATFORM is auto-detected by hipcc; unset to avoid build errors diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh index 88e587ab5ff7d..e5ac1bab16192 100644 --- a/.ci/pytorch/common_utils.sh +++ b/.ci/pytorch/common_utils.sh @@ -127,17 +127,6 @@ function get_exit_code() { return $retcode } -function get_bazel() { - # Download and use the cross-platform, dependency-free Python - # version of Bazelisk to fetch the platform specific version of - # Bazel to use from .bazelversion. - retry curl --location --output tools/bazel \ - https://raw.githubusercontent.com/bazelbuild/bazelisk/v1.23.0/bazelisk.py - shasum --algorithm=1 --check \ - <(echo '01df9cf7f08dd80d83979ed0d0666a99349ae93c tools/bazel') - chmod u+x tools/bazel -} - function install_monkeytype { # Install MonkeyType pip_install MonkeyType @@ -307,25 +296,9 @@ function install_torchao() { } function install_flash_attn_cute() { - echo "Installing FlashAttention CuTe from GitHub..." - # Grab latest main til we have a pinned commit - local flash_attn_commit - flash_attn_commit=$(git ls-remote https://github.com/Dao-AILab/flash-attention.git HEAD | cut -f1) - - # Clone the repo to a temporary directory - rm -rf flash-attention-build - git clone --depth 1 --recursive https://github.com/Dao-AILab/flash-attention.git flash-attention-build - - pushd flash-attention-build - git checkout "${flash_attn_commit}" - - # Install only the 'cute' sub-directory - pip_install flash_attn/cute/ - popd - - # remove the local repo - rm -rf flash-attention-build - echo "FlashAttention CuTe installation complete." + echo "Installing FlashAttention 4 from PyPI..." + pip_install flash-attn-4==4.0.0b5 + echo "FlashAttention 4 installation complete." } function install_cutlass_dsl() { diff --git a/.ci/pytorch/cpp_doc_push_script.sh b/.ci/pytorch/cpp_doc_push_script.sh index f085fa78bebe9..d0b4fd38826fa 100755 --- a/.ci/pytorch/cpp_doc_push_script.sh +++ b/.ci/pytorch/cpp_doc_push_script.sh @@ -1,7 +1,7 @@ #!/bin/bash # This is where the local pytorch install in the docker image is located -pt_checkout="/var/lib/jenkins/workspace" +pt_checkout="${GITHUB_WORKSPACE:-/var/lib/jenkins/workspace}" # Since we're cat-ing this file, we need to escape all $'s echo "cpp_doc_push_script.sh: Invoked with $*" @@ -60,6 +60,34 @@ time python tools/setup_helpers/generate_code.py \ pushd docs/cpp time make VERBOSE=1 html +# Run C++ API coverage check (allowlist-based + HTML formatting) +echo "Running C++ docs coverage check..." +python check_coverage.py --coverxygen || coverage_exit=$? + +# Generate coverxygen HTML report if coverxygen produced output +if [ -f coverxygen.info ] && command -v genhtml &> /dev/null; then + genhtml --no-function-coverage coverxygen.info -o build/html/_coverage \ + --title "PyTorch C++ API Doc Coverage" \ + --legend --highlight 2>/dev/null || true +fi + +# Copy coverage reports into the build output so they get uploaded +mkdir -p build/html/_coverage +cp -f cpp_coverage.txt cpp_html_issues.txt build/html/_coverage/ 2>/dev/null || true +cp -f coverxygen.info build/html/_coverage/ 2>/dev/null || true + +if [ "${coverage_exit:-0}" -ne 0 ]; then + echo "" + echo "========================================" + echo "C++ DOCS COVERAGE: HIGH-PRIORITY GAPS" + echo "========================================" + echo "" + cat cpp_coverage.txt + echo "" + echo "See the full coverage report at: _coverage/cpp_coverage.txt" + echo "See the HTML issues report at: _coverage/cpp_html_issues.txt" +fi + popd popd @@ -76,6 +104,7 @@ cp -r "${pt_checkout}"/docs/cpp/build/html/* . # Copy back _config.yml rm -rf _config.yml mv /tmp/cppdocs-sync/* . +touch .nojekyll # Make a new commit git add . || true diff --git a/.ci/pytorch/macos-build.sh b/.ci/pytorch/macos-build.sh index f2bcb486cf95b..3259c62149b61 100755 --- a/.ci/pytorch/macos-build.sh +++ b/.ci/pytorch/macos-build.sh @@ -48,7 +48,7 @@ if [[ ${BUILD_ENVIRONMENT} == *"distributed"* ]]; then else # Explicitly set USE_DISTRIBUTED=0 to align with the default build config on mac. This also serves as the sole CI config that tests # that building with USE_DISTRIBUTED=0 works at all. See https://github.com/pytorch/pytorch/issues/86448 - USE_DISTRIBUTED=0 USE_OPENMP=1 MACOSX_DEPLOYMENT_TARGET=11.0 WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python -m build --wheel --no-isolation -C--build-option=--plat-name=macosx_11_0_arm64 + USE_DISTRIBUTED=0 USE_OPENMP=1 WERROR=1 BUILD_TEST=OFF USE_PYTORCH_METAL=1 python -m build --wheel --no-isolation fi if which sccache > /dev/null; then print_sccache_stats diff --git a/.ci/pytorch/macos-common.sh b/.ci/pytorch/macos-common.sh index 6826a52577a29..d21ee458f6c99 100755 --- a/.ci/pytorch/macos-common.sh +++ b/.ci/pytorch/macos-common.sh @@ -9,6 +9,6 @@ sysctl -a | grep machdep.cpu # These are required for both the build job and the test job. # In the latter to test cpp extensions. -export MACOSX_DEPLOYMENT_TARGET=11.1 +export MACOSX_DEPLOYMENT_TARGET=14.0 export CXX=clang++ export CC=clang diff --git a/.ci/pytorch/python_doc_push_script.sh b/.ci/pytorch/python_doc_push_script.sh index 6bcd46c4815a6..44b2a26e9fa66 100755 --- a/.ci/pytorch/python_doc_push_script.sh +++ b/.ci/pytorch/python_doc_push_script.sh @@ -1,7 +1,7 @@ #!/bin/bash # This is where the local pytorch install in the docker image is located -pt_checkout="/var/lib/jenkins/workspace" +pt_checkout="${GITHUB_WORKSPACE:-/var/lib/jenkins/workspace}" source "$pt_checkout/.ci/pytorch/common_utils.sh" @@ -50,8 +50,14 @@ echo "install_path: $install_path version: $version" build_docs () { set +e - set -o pipefail - make "$1" 2>&1 | tee /tmp/docs_build.txt + # Don't pipe through tee: sphinx -j auto forks workers that inherit + # the pipe fd and hold it open after sphinx exits, causing tee to + # block forever. Write to a file and tail with --pid so it exits + # (after draining) when make finishes. + make "$1" > /tmp/docs_build.txt 2>&1 & + local make_pid=$! + tail -f --pid=$make_pid /tmp/docs_build.txt + wait $make_pid code=$? if [ $code -ne 0 ]; then set +x @@ -87,7 +93,10 @@ pushd docs if [ "$is_main_doc" = true ]; then build_docs html || exit $? - make coverage + # Run coverage check without parallel workers since it's a quick + # check that doesn't need parallelism, and avoids re-triggering the + # expensive parallel read/write machinery. + SPHINXOPTS="-WT --keep-going" make coverage # Now we have the coverage report, we need to make sure it is empty. # Sphinx 7.2.6+ format: python.txt contains a statistics table with a TOTAL row # showing the undocumented count in the third column. diff --git a/.ci/pytorch/smoke_test/check_binary_symbols.py b/.ci/pytorch/smoke_test/check_binary_symbols.py index 4a49be6db5730..0af32d4454b50 100755 --- a/.ci/pytorch/smoke_test/check_binary_symbols.py +++ b/.ci/pytorch/smoke_test/check_binary_symbols.py @@ -471,7 +471,7 @@ def check_lib_symbols_for_abi_correctness(lib: str) -> None: def main() -> None: if "install_root" in os.environ: - install_root = Path(os.getenv("install_root")) # noqa: SIM112 + install_root = Path(os.getenv("install_root")) else: if os.getenv("PACKAGE_TYPE") == "libtorch": install_root = Path(os.getcwd()) diff --git a/.ci/pytorch/smoke_test/check_wheel_tags.py b/.ci/pytorch/smoke_test/check_wheel_tags.py new file mode 100644 index 0000000000000..901304657b531 --- /dev/null +++ b/.ci/pytorch/smoke_test/check_wheel_tags.py @@ -0,0 +1,269 @@ +"""Validate wheel platform tags and macOS dylib minos. +Supports two modes: +1. Pre-install: reads .whl files from PYTORCH_FINAL_PACKAGE_DIR +2. Post-install: reads metadata from installed torch package (soft warnings) +- (macOS only) dylib minos matches the wheel platform tag +""" + +import os +import platform +import re +import subprocess +import sys +import tempfile +import zipfile +from pathlib import Path + + +EXPECTED_PLATFORM_TAGS: dict[str, str] = { + "linux": r"_x86_64$", + "linux-aarch64": r"_aarch64$", + "windows": r"^win_amd64$", + "win32": r"^win_amd64$", + "macos-arm64": r"^macosx_\d+_\d+_arm64$", + "darwin": r"^macosx_\d+_\d+_(arm64|x86_64)$", +} + + +def _extract_wheel_tags(whl_path: Path) -> list[str]: + """Extract Tag values from the WHEEL metadata file inside a .whl archive.""" + tags = [] + with zipfile.ZipFile(whl_path, "r") as zf: + wheel_files = [n for n in zf.namelist() if n.endswith("/WHEEL")] + if not wheel_files: + return tags + content = zf.read(wheel_files[0]).decode("utf-8") + for line in content.splitlines(): + if line.startswith("Tag:"): + tags.append(line.split(":", 1)[1].strip()) + return tags + + +def _extract_installed_wheel_tags(package: str = "torch") -> list[str]: + """Extract Tag values from an installed package's WHEEL metadata.""" + from importlib.metadata import distribution + + dist = distribution(package) + wheel_text = dist.read_text("WHEEL") + if not wheel_text: + return [] + tags = [] + for line in wheel_text.splitlines(): + if line.startswith("Tag:"): + tags.append(line.split(":", 1)[1].strip()) + return tags + + +def check_wheel_platform_tag() -> None: + """Validate that wheel Tags in WHEEL metadata match the expected platform. + + Mode 1: PYTORCH_FINAL_PACKAGE_DIR set → read .whl file (strict, raises on mismatch) + Mode 2: No wheel dir → read from installed torch package (soft, prints warnings) + """ + wheel_dir = os.getenv("PYTORCH_FINAL_PACKAGE_DIR", "") + + target_os = os.getenv("TARGET_OS", sys.platform) + if target_os == "linux" and platform.machine() == "aarch64": + target_os = "linux-aarch64" + expected_python = f"cp{sys.version_info.major}{sys.version_info.minor}" + import sysconfig + + abiflags = getattr(sys, "abiflags", "") + if not abiflags and ( + os.getenv("MATRIX_PYTHON_VERSION", "").endswith("t") + or bool(sysconfig.get_config_var("Py_GIL_DISABLED")) + or not getattr(sys, "_is_gil_enabled", lambda: True)() + ): + abiflags = "t" + expected_abi = f"cp{sys.version_info.major}{sys.version_info.minor}{abiflags}" + print(f"Expected ABI tag: {expected_abi}") + + platform_pattern = EXPECTED_PLATFORM_TAGS.get(target_os) + if not platform_pattern: + print( + f"No expected platform pattern for TARGET_OS={target_os}, " + "skipping wheel tag check" + ) + return + + # Mode 1: Read from .whl file + if wheel_dir and os.path.isdir(wheel_dir): + whls = list(Path(wheel_dir).glob("torch-*.whl")) + if not whls: + print(f"No torch wheel found in {wheel_dir}, skipping wheel tag check") + return + if len(whls) > 1: + raise RuntimeError( + f"Expected exactly one torch wheel in {wheel_dir}, " + f"found {len(whls)}: {[w.name for w in whls]}" + ) + whl = whls[0] + print(f"Checking wheel platform tag for: {whl.name}") + tags = _extract_wheel_tags(whl) + source = whl.name + else: + # Mode 2: Read from installed package (soft) + print("PYTORCH_FINAL_PACKAGE_DIR not set, reading from installed torch package") + try: + tags = _extract_installed_wheel_tags("torch") + source = "installed torch" + except Exception as e: + print(f"Could not read installed torch metadata: {e}, skipping") + return + + if not tags: + raise RuntimeError(f"No Tag found in WHEEL metadata of {source}") + + for tag_str in tags: + parts = tag_str.split("-") + if len(parts) != 3: + msg = ( + f"Malformed wheel tag '{tag_str}' in {source}, " + f"expected format: --" + ) + raise RuntimeError(msg) + + python_tag, abi_tag, platform_tag = parts + + print(f"Checking tag: {tag_str} (from {source})") + if python_tag != expected_python: + msg: str = ( + f"Python tag mismatch in {source}: " + f"got '{python_tag}', expected '{expected_python}'" + ) + raise RuntimeError(msg) + + if abi_tag != expected_abi: + msg = ( + f"ABI tag mismatch in {source}: " + f"got '{abi_tag}', expected '{expected_abi}'" + ) + raise RuntimeError(msg) + + if not re.search(platform_pattern, platform_tag): + msg = ( + f"Platform tag mismatch in {source}: " + f"got '{platform_tag}', expected pattern matching " + f"'{platform_pattern}' for TARGET_OS={target_os}" + ) + raise RuntimeError(msg) + + print(f"OK: Wheel tag(s) valid for {source}: {', '.join(tags)}") + + +def _check_dylibs_minos(dylibs: list, expected_minos: str, source: str) -> None: + mismatches = [] + for dylib in dylibs: + try: + result = subprocess.run( + ["otool", "-l", str(dylib)], + capture_output=True, + text=True, + timeout=30, + ) + except Exception: + continue + + minos = None + lines = result.stdout.splitlines() + for i, line in enumerate(lines): + s = line.strip() + if "LC_BUILD_VERSION" in s: + for j in range(i + 1, min(i + 6, len(lines))): + if lines[j].strip().startswith("minos"): + minos = lines[j].strip().split()[1] + break + break + if "LC_VERSION_MIN_MACOSX" in s: + for j in range(i + 1, min(i + 4, len(lines))): + if lines[j].strip().startswith("version"): + minos = lines[j].strip().split()[1] + break + break + + # A dylib with a lower minos than the wheel tag is safe (forward compatible). + # Only flag dylibs that require a *higher* macOS than the wheel claims to support. + if minos and tuple(int(x) for x in minos.split(".")) > tuple( + int(x) for x in expected_minos.split(".") + ): + mismatches.append( + f"{dylib.name}: minos={minos}, expected<={expected_minos}" + ) + + if mismatches: + raise RuntimeError( + f"minos/platform tag mismatch in {len(mismatches)} dylib(s):\n" + + "\n".join(f" {m}" for m in mismatches) + ) + print( + f"OK: All {len(dylibs)} dylib(s) have minos matching " + f"platform tag ({expected_minos}) for {source}" + ) + + +def check_mac_wheel_minos() -> None: + if sys.platform != "darwin": + return + + wheel_dir = os.getenv("PYTORCH_FINAL_PACKAGE_DIR", "") + + if wheel_dir and os.path.isdir(wheel_dir): + # Mode 1: extract dylibs from .whl file + whls = list(Path(wheel_dir).glob("*.whl")) + if not whls: + print(f"No .whl files in {wheel_dir}, skipping wheel minos check") + return + + macos_whl_re = re.compile(r"macosx_(\d+)_(\d+)_(\w+)\.whl$") + for whl in whls: + print(f"Checking wheel tag minos for: {whl.name}") + m = macos_whl_re.search(whl.name) + if not m: + print(f"No macOS platform tag in {whl.name}, skipping") + continue + expected_minos = f"{m.group(1)}.{m.group(2)}" + + with tempfile.TemporaryDirectory() as tmpdir: + with zipfile.ZipFile(whl, "r") as zf: + dylib_names = [n for n in zf.namelist() if n.endswith(".dylib")] + if not dylib_names: + print("No .dylib files in wheel, skipping minos check") + continue + for name in dylib_names: + zf.extract(name, tmpdir) + dylibs = list(Path(tmpdir).rglob("*.dylib")) + _check_dylibs_minos(dylibs, expected_minos, whl.name) + else: + # Mode 2: read from installed torch package + print("PYTORCH_FINAL_PACKAGE_DIR not set, checking installed torch dylibs") + try: + tags = _extract_installed_wheel_tags("torch") + except Exception as e: + print(f"Could not read installed torch metadata: {e}, skipping") + return + + expected_minos = None + for tag_str in tags: + m = re.search(r"macosx_(\d+)_(\d+)_\w+", tag_str) + if m: + expected_minos = f"{m.group(1)}.{m.group(2)}" + break + + if not expected_minos: + print("No macOS platform tag found in installed torch metadata, skipping") + return + + print(f"Expected minos from installed wheel tag: {expected_minos}") + + import torch + + torch_dir = Path(torch.__file__).parent + dylibs = list(torch_dir.rglob("*.dylib")) + if not dylibs: + raise RuntimeError("No .dylib files found in installed torch") + _check_dylibs_minos(dylibs, expected_minos, "installed torch") + + +if __name__ == "__main__": + check_wheel_platform_tag() + check_mac_wheel_minos() diff --git a/.ci/pytorch/smoke_test/smoke_test.py b/.ci/pytorch/smoke_test/smoke_test.py index ce1beb68ff49a..8041f0bef64a5 100644 --- a/.ci/pytorch/smoke_test/smoke_test.py +++ b/.ci/pytorch/smoke_test/smoke_test.py @@ -10,6 +10,8 @@ from pathlib import Path from tempfile import NamedTemporaryFile +from check_wheel_tags import check_mac_wheel_minos, check_wheel_platform_tag + import torch import torch._dynamo import torch.nn as nn @@ -25,6 +27,7 @@ package_type = os.getenv("MATRIX_PACKAGE_TYPE") target_os = os.getenv("TARGET_OS", sys.platform) BASE_DIR = Path(__file__).parent.parent.parent +PYTORCH_ROOT = BASE_DIR.parent is_cuda_system = gpu_arch_type == "cuda" NIGHTLY_ALLOWED_DELTA = 3 @@ -193,10 +196,15 @@ def test_cuda_gds_errors_captured() -> None: print("Testing test_cuda_gds_errors_captured") with NamedTemporaryFile() as f: torch.cuda.gds.GdsFile(f.name, os.O_CREAT | os.O_RDWR) + # cuFile >= 1.17 (CUDA 13.2+) compat mode: registration succeeds + # without nvidia-fs driver, falling back to POSIX I/O + if major_version > 13 or (major_version == 13 and minor_version >= 2): + print("GDS handle registered successfully via compatibility mode") + cuda_exception_missed = False except RuntimeError as e: expected_error = "cuFileHandleRegister failed" if re.search(expected_error, f"{e}"): - print(f"Caught CUDA exception with success: {e}") + print(f"Caught expected CUDA exception: {e}") cuda_exception_missed = False else: raise e @@ -216,6 +224,89 @@ def find_pypi_package_version(package: str) -> str | None: return None +def get_expected_cudnn_version_linux(cuda_version: str) -> str | None: + """Parse expected cuDNN version from generate_binary_build_matrix.py for Linux. + + Reads PYTORCH_EXTRA_INSTALL_REQUIREMENTS and extracts the cudnn version + for the given CUDA version (e.g. "12.6"). + """ + matrix_script = ( + PYTORCH_ROOT / ".github" / "scripts" / "generate_binary_build_matrix.py" + ) + if not matrix_script.exists(): + print(f"Warning: {matrix_script} not found, skipping cuDNN version check") + return None + + content = matrix_script.read_text() + # Match the full cudnn package version like nvidia-cudnn-cu12==9.10.2.21 + # and extract major.minor.patch (dropping the build number) + pattern = ( + rf'"{re.escape(cuda_version)}":\s*\(\s*' + r"[\s\S]*?nvidia-cudnn-cu\d+==(\d+\.\d+\.\d+)\.\d+" + ) + match = re.search(pattern, content) + if match: + return match.group(1) + return None + + +def get_expected_cudnn_version_windows(cuda_version: str) -> str | None: + """Parse expected cuDNN version from cuda_install.bat for Windows. + + Reads the batch file and extracts EXPECTED_CUDNN_VERSION for the given + CUDA version (e.g. "12.6" maps to CUDA_VER 126). + """ + bat_file = ( + PYTORCH_ROOT / ".ci" / "pytorch" / "windows" / "internal" / "cuda_install.bat" + ) + if not bat_file.exists(): + print(f"Warning: {bat_file} not found, skipping cuDNN version check") + return None + + content = bat_file.read_text() + # Convert "12.6" to "126" to match batch file's CUDA_VER format + cuda_ver_nodot = cuda_version.replace(".", "") + # Match: if %CUDA_VER% EQU 126 ( ... set EXPECTED_CUDNN_VERSION=9.10.2 ) + pattern = ( + rf"if %CUDA_VER% EQU {re.escape(cuda_ver_nodot)}\s*\(" + r"[\s\S]*?set EXPECTED_CUDNN_VERSION=(\d+\.\d+\.\d+)" + ) + match = re.search(pattern, content) + if match: + return match.group(1) + return None + + +def check_cudnn_version(cuda_version: str, actual_cudnn_version: str) -> None: + """Validate cuDNN version matches expected version from build config files.""" + if sys.platform in ["linux", "linux2"]: + expected = get_expected_cudnn_version_linux(cuda_version) + source = "generate_binary_build_matrix.py" + elif sys.platform == "win32": + expected = get_expected_cudnn_version_windows(cuda_version) + source = "cuda_install.bat" + else: + print(f"cuDNN version check not supported on platform {sys.platform}") + return + + if expected is None: + print( + f"Warning: Could not determine expected cuDNN version for CUDA {cuda_version} " + f"from {source}, skipping validation" + ) + return + + if not actual_cudnn_version.startswith(expected): + raise RuntimeError( + f"cuDNN version mismatch for CUDA {cuda_version}. " + f"Loaded: {actual_cudnn_version} Expected: {expected} (from {source})" + ) + print( + f"cuDNN version check passed: {actual_cudnn_version} matches " + f"expected {expected} from {source}" + ) + + def cudnn_to_version_str(cudnn_version: int) -> str: patch = int(cudnn_version % 10) minor = int((cudnn_version / 100) % 100) @@ -294,6 +385,8 @@ def smoke_test_cuda( f"Expected: {torch_cudnn_compile_version}" ) + check_cudnn_version(gpu_arch_ver, torch_cudnn_version) + if sys.platform in ["linux", "linux2"]: torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) print(f"Torch nccl; version: {torch_nccl_version}") @@ -551,6 +644,9 @@ def main() -> None: smoke_test_nvshmem() + check_wheel_platform_tag() + check_mac_wheel_minos() + if __name__ == "__main__": main() diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index c4cc0cc1282ab..35abebff5b3d5 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -14,9 +14,10 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" # shellcheck source=./common-build.sh source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" -# Do not change workspace permissions for ROCm and s390x CI jobs -# as it can leave workspace with bad permissions for cancelled jobs -if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *s390x* && -d /var/lib/jenkins/workspace ]]; then +# Only change workspace permissions if passwordless sudo is available +# (e.g. ROCm and s390x CI jobs lack it, and changing permissions +# can leave the workspace in a bad state for cancelled jobs) +if sudo -n true 2>/dev/null && [[ -d /var/lib/jenkins/workspace ]]; then # Workaround for dind-rootless userid mapping (https://github.com/pytorch/ci-infra/issues/96) WORKSPACE_ORIGINAL_OWNER_ID=$(stat -c '%u' "/var/lib/jenkins/workspace") cleanup_workspace() { @@ -44,6 +45,16 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then fi fi +# Remove onnxruntime if present to avoid interference with non-ONNX tests +if [[ "$TEST_CONFIG" != "onnx" ]]; then + pip uninstall -y onnxruntime 2>/dev/null || true +fi + +# Remove dill to test that serialization works without it +if [[ "$BUILD_ENVIRONMENT" == *py3.10-gcc11 ]]; then + pip uninstall -y dill 2>/dev/null || true +fi + echo "Environment variables:" env @@ -129,9 +140,7 @@ if [[ "${PYTORCH_TEST_RERUN_DISABLED_TESTS}" == "1" ]] || [[ "${CONTINUE_THROUGH fi # Get fully qualified path using realpath -if [[ "$BUILD_ENVIRONMENT" != *bazel* ]]; then - CUSTOM_TEST_ARTIFACT_BUILD_DIR=$(realpath "${CUSTOM_TEST_ARTIFACT_BUILD_DIR:-"build/custom_test_artifacts"}") -fi +CUSTOM_TEST_ARTIFACT_BUILD_DIR=$(realpath "${CUSTOM_TEST_ARTIFACT_BUILD_DIR:-"build/custom_test_artifacts"}") # Reduce set of tests to include when running run_test.py if [[ -n $TESTS_TO_INCLUDE ]]; then @@ -144,11 +153,34 @@ env echo "Testing pytorch" +# Set OMP_NUM_THREADS to nproc/4 on k8s ARC runners if not already set. +# +# We use nproc (cgroup-aware) rather than os.cpu_count() because on k8s (ARC) +# pods, os.cpu_count() returns the host's CPU count (e.g., 192) rather than +# the pod's cpuset allocation (e.g., 16). +# +# We use nproc/4 rather than nproc because OpenMP spin-waits at thread barriers. +# When thread count equals cpuset size (e.g., 16 threads on 16 CPUs), spinning +# barrier threads monopolize all CPUs and the OS must context-switch to let +# actual work complete. This causes ~5000x slowdowns on small tensor ops +# (e.g., aten::copy_ on 147KB: ~34ms instead of ~7us). Using nproc/4 leaves +# headroom for the main thread and for NUM_PROCS=3 parallel test processes. +if [[ -z "${OMP_NUM_THREADS:-}" ]] && [[ -n "${USE_ARC:-}" ]]; then + OMP_NUM_THREADS=$(( $(nproc) / 4 )) + # Floor of 4: low OMP_NUM_THREADS (1-2) changes floating-point reduction + # order, causing numerical mismatches in tests with tight tolerances + # (e.g., test_batchnorm_nhwc_cpu). + if [[ "$OMP_NUM_THREADS" -lt 4 ]]; then + OMP_NUM_THREADS=4 + fi + export OMP_NUM_THREADS +fi + export LANG=C.UTF-8 PR_NUMBER=${PR_NUMBER:-${CIRCLE_PR_NUMBER:-}} -if [[ -d "${HF_CACHE}" ]]; then +if [[ -d "${HF_CACHE}" && "$TEST_CONFIG" != "onnx" ]]; then export HF_HOME="${HF_CACHE}" fi @@ -222,12 +254,10 @@ if [[ "$BUILD_ENVIRONMENT" == *xpu* ]]; then timeout 30 xpu-smi discovery || true fi -if [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]] ; then - # JIT C++ extensions require ninja (installed from requirements-ci.txt). - # ninja is installed in $HOME/.local/bin, e.g., /var/lib/jenkins/.local/bin for CI user jenkins - # but this script should be runnable by any user, including root - export PATH="$HOME/.local/bin:$PATH" -fi +# JIT C++ extensions require ninja (installed from requirements-ci.txt). +# ninja is installed in $HOME/.local/bin, e.g., /var/lib/jenkins/.local/bin for CI user jenkins +# but this script should be runnable by any user, including root +export PATH="$HOME/.local/bin:$PATH" if [[ "$BUILD_ENVIRONMENT" == *aarch64* ]]; then # TODO: revisit this once the CI is stabilized on aarch64 linux @@ -292,6 +322,17 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then (cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_aten_asan(3)") fi +if [[ "$BUILD_ENVIRONMENT" == *-tsan* ]]; then + # Switch to TSan-instrumented CPython so that all subsequent python + # invocations (including the pre-test sanity checks below) use an + # interpreter that has the TSan runtime. + export PATH=/opt/python/cp314-cp314t+tsan/bin:$PATH + python -m pip install "$(echo dist/*.whl)[opt-einsum]" + TSAN_OPTIONS="log_path=$(pwd)/test/test-reports/tsan_toprint.log" + export TSAN_OPTIONS + export PYTORCH_TEST_WITH_TSAN=1 +fi + # The torch._C._crash_if_debug_asserts_fail() function should only fail if both of the following are true: # 1. The build is in debug mode # 2. The value 424242 is passed in @@ -299,8 +340,7 @@ fi if [[ "$BUILD_ENVIRONMENT" == *-debug* ]]; then echo "We are in debug mode: $BUILD_ENVIRONMENT. Expect the python assertion to fail" (cd test && ! get_exit_code python -c "import torch; torch._C._crash_if_debug_asserts_fail(424242)") -elif [[ "$BUILD_ENVIRONMENT" != *-bazel-* ]]; then - # Noop when debug is disabled. Skip bazel jobs because torch isn't available there yet. +else echo "We are not in debug mode: $BUILD_ENVIRONMENT. Expect the assertion to pass" (cd test && python -c "import torch; torch._C._crash_if_debug_asserts_fail(424242)") fi @@ -311,6 +351,29 @@ elif [[ $TEST_CONFIG == 'nogpu_AVX512' ]]; then export ATEN_CPU_CAPABILITY=avx2 fi +test_tsan() { + # PATH, TSAN_OPTIONS, and wheel install are set up earlier in this + # script when BUILD_ENVIRONMENT matches *-tsan*. + local test_status=0 + python test/test_tsan.py -v || test_status=$? + + # TSan appends . to log_path. Merge all reports into a single + # file with the _toprint.log suffix so the CI "Print remaining test + # logs" step picks them up automatically. + TSAN_REPORT=$(pwd)/test/test-reports/tsan_toprint.log + if ls "${TSAN_REPORT}".* 1>/dev/null 2>&1; then + cat "${TSAN_REPORT}".* > "${TSAN_REPORT}" + rm -f "${TSAN_REPORT}".* + fi + + if [ "$test_status" -ne 0 ]; then + echo "TSan tests failed with exit code $test_status" + exit "$test_status" + fi + + assert_git_not_dirty +} + test_python_legacy_jit() { time python test/run_test.py --include test_jit_legacy test_jit_fuser_legacy --verbose assert_git_not_dirty @@ -360,14 +423,17 @@ test_python_smoke_b200() { inductor/test_torchinductor \ inductor/test_nv_universal_gemm \ inductor/test_fused_attention \ + test_varlen_attention \ $PYTHON_TEST_EXTRA_OPTION \ --upload-artifacts-while-running assert_git_not_dirty } + test_python_smoke_xpu() { # Smoke tests for XPU client time python test/run_test.py --include test_transformers $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time test_xpu_sycl_tla_backend assert_git_not_dirty } @@ -414,10 +480,20 @@ test_b200_symm_mem() { test_h100_cutlass_backend() { # cutlass backend tests for H100 + git submodule update --init --depth 1 third_party/cutlass TORCHINDUCTOR_CUTLASS_DIR=$(realpath "./third_party/cutlass") python test/run_test.py --include inductor/test_cutlass_backend -k "not addmm" $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running TORCHINDUCTOR_CUTLASS_DIR=$(realpath "./third_party/cutlass") python test/run_test.py --include inductor/test_cutlass_evt $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running } +test_xpu_sycl_tla_backend() { + # Inductor sycl-tla backend tests for XPU + # shellcheck disable=SC1091 + source /opt/intel/oneapi/mkl/latest/env/vars.sh + sycl_tla_dir=$(realpath "./third_party/sycl-tla") + rm -rf "${sycl_tla_dir}" && git clone --depth 1 --single-branch -b v0.8 --quiet https://github.com/intel/sycl-tla.git "${sycl_tla_dir}" + TORCHINDUCTOR_CUTLASS_DIR=$(realpath "./third_party/sycl-tla") python test/run_test.py --include inductor/test_cutlass_backend -k "not addmm" $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running +} + test_lazy_tensor_meta_reference_disabled() { export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1 echo "Testing lazy tensor operations without meta reference" @@ -434,12 +510,16 @@ test_dynamo_core() { } test_dynamo_cpython() { + # Disable TD for cpython since it's pretty cheap to run the cpython tests (< 10 min) + # and if TD is enabled, only 25% of the tests will be executed + export NO_TD=1 time python test/run_test.py \ --include-cpython-tests \ --dynamo \ --verbose \ --upload-artifacts-while-running assert_git_not_dirty + unset NO_TD } test_dynamo_wrapped_shard() { @@ -546,7 +626,7 @@ test_inductor_shard() { # Do not add --inductor for the following inductor unit tests, otherwise we will fail because of nested dynamo state python test/run_test.py \ - --include inductor/test_torchinductor inductor/test_torchinductor_opinfo inductor/test_aot_inductor \ + --include inductor/test_torchinductor inductor/test_torchinductor_opinfo inductor/test_aot_inductor inductor/test_cpu_select_algorithm \ --shard "$1" "$NUM_TEST_SHARDS" \ --verbose } @@ -612,8 +692,10 @@ test_inductor_cpp_wrapper_shard() { -k 'take' \ --shard "$1" "$NUM_TEST_SHARDS" \ --verbose - TORCHINDUCTOR_AUTOTUNE_AT_COMPILE_TIME=0 python test/run_test.py \ - --include inductor/test_torchinductor inductor/test_triton_kernels\ + # Keep testing TORCHINDUCTOR_AUTOTUNE_AT_COMPILE_TIME=1 for the near future. + # Will drop this after AOTInductor also switches to lazy Triton compilation. + TORCHINDUCTOR_AUTOTUNE_AT_COMPILE_TIME=1 python test/run_test.py \ + --include inductor/test_torchinductor inductor/test_triton_kernels inductor/test_max_autotune \ --shard "$1" "$NUM_TEST_SHARDS" \ --verbose if [[ "${BUILD_ENVIRONMENT}" == *xpu* ]]; then @@ -729,6 +811,10 @@ test_perf_for_dashboard() { TEST_REPORTS_DIR=$(pwd)/test/test-reports mkdir -p "$TEST_REPORTS_DIR" + if [[ "${EXPORT_PROFILER_TRACE:-0}" == "1" ]]; then + mkdir -p "$TEST_REPORTS_DIR/profiler_traces" + fi + local suite="$1" shift @@ -788,52 +874,102 @@ test_perf_for_dashboard() { fi if [[ "$DASHBOARD_TAG" == *default-true* ]]; then + local profiler_trace_flags=() + if [[ "${EXPORT_PROFILER_TRACE:-0}" == "1" && "$target" == "performance" ]]; then + profiler_trace_flags=(--export-profiler-trace --profiler-trace-name "$TEST_REPORTS_DIR/profiler_traces/${backend}_no_cudagraphs_${suite}_${dtype}_${mode}_${device}") + fi $TASKSET python "benchmarks/dynamo/$suite.py" \ "${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" --disable-cudagraphs "$@" \ + "${profiler_trace_flags[@]}" \ --output "$TEST_REPORTS_DIR/${backend}_no_cudagraphs_${suite}_${dtype}_${mode}_${device}_${target}.csv" fi if [[ "$DASHBOARD_TAG" == *cudagraphs-true* ]]; then + local profiler_trace_flags=() + if [[ "${EXPORT_PROFILER_TRACE:-0}" == "1" && "$target" == "performance" ]]; then + profiler_trace_flags=(--export-profiler-trace --profiler-trace-name "$TEST_REPORTS_DIR/profiler_traces/${backend}_with_cudagraphs_${suite}_${dtype}_${mode}_${device}") + fi $TASKSET python "benchmarks/dynamo/$suite.py" \ "${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" "$@" \ + "${profiler_trace_flags[@]}" \ --output "$TEST_REPORTS_DIR/${backend}_with_cudagraphs_${suite}_${dtype}_${mode}_${device}_${target}.csv" fi if [[ "$DASHBOARD_TAG" == *dynamic-true* ]]; then + local profiler_trace_flags=() + if [[ "${EXPORT_PROFILER_TRACE:-0}" == "1" && "$target" == "performance" ]]; then + profiler_trace_flags=(--export-profiler-trace --profiler-trace-name "$TEST_REPORTS_DIR/profiler_traces/${backend}_dynamic_${suite}_${dtype}_${mode}_${device}") + fi $TASKSET python "benchmarks/dynamo/$suite.py" \ "${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" --dynamic-shapes \ --dynamic-batch-only "$@" \ + "${profiler_trace_flags[@]}" \ --output "$TEST_REPORTS_DIR/${backend}_dynamic_${suite}_${dtype}_${mode}_${device}_${target}.csv" fi if [[ "$DASHBOARD_TAG" == *cppwrapper-true* ]]; then + local profiler_trace_flags=() + if [[ "${EXPORT_PROFILER_TRACE:-0}" == "1" && "$target" == "performance" ]]; then + profiler_trace_flags=(--export-profiler-trace --profiler-trace-name "$TEST_REPORTS_DIR/profiler_traces/${backend}_cpp_wrapper_${suite}_${dtype}_${mode}_${device}") + fi TORCHINDUCTOR_CPP_WRAPPER=1 $TASKSET python "benchmarks/dynamo/$suite.py" \ "${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" --disable-cudagraphs "$@" \ + "${profiler_trace_flags[@]}" \ --output "$TEST_REPORTS_DIR/${backend}_cpp_wrapper_${suite}_${dtype}_${mode}_${device}_${target}.csv" fi if [[ "$DASHBOARD_TAG" == *freezing_cudagraphs-true* ]] && [[ "$mode" == "inference" ]]; then + local profiler_trace_flags=() + if [[ "${EXPORT_PROFILER_TRACE:-0}" == "1" && "$target" == "performance" ]]; then + profiler_trace_flags=(--export-profiler-trace --profiler-trace-name "$TEST_REPORTS_DIR/profiler_traces/${backend}_with_cudagraphs_freezing_${suite}_${dtype}_${mode}_${device}") + fi $TASKSET python "benchmarks/dynamo/$suite.py" \ "${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" "$@" --freezing \ + "${profiler_trace_flags[@]}" \ --output "$TEST_REPORTS_DIR/${backend}_with_cudagraphs_freezing_${suite}_${dtype}_${mode}_${device}_${target}.csv" fi if [[ "$DASHBOARD_TAG" == *freeze_autotune_cudagraphs-true* ]] && [[ "$mode" == "inference" ]]; then + local profiler_trace_flags=() + if [[ "${EXPORT_PROFILER_TRACE:-0}" == "1" && "$target" == "performance" ]]; then + profiler_trace_flags=(--export-profiler-trace --profiler-trace-name "$TEST_REPORTS_DIR/profiler_traces/${backend}_with_cudagraphs_freezing_autotune_${suite}_${dtype}_${mode}_${device}") + fi TORCHINDUCTOR_MAX_AUTOTUNE=1 $TASKSET python "benchmarks/dynamo/$suite.py" \ "${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" "$@" --freezing \ + "${profiler_trace_flags[@]}" \ --output "$TEST_REPORTS_DIR/${backend}_with_cudagraphs_freezing_autotune_${suite}_${dtype}_${mode}_${device}_${target}.csv" fi if [[ "$DASHBOARD_TAG" == *aotinductor-true* ]] && [[ "$mode" == "inference" ]]; then - if [[ "$target" == "accuracy" ]]; then # Also collect Export pass rate and display as a separate row + if [[ "$target" == "accuracy" ]]; then $TASKSET python "benchmarks/dynamo/$suite.py" \ "${target_flag[@]}" --"$mode" --"$dtype" --export --disable-cudagraphs "$@" \ --output "$TEST_REPORTS_DIR/${backend}_export_${suite}_${dtype}_${mode}_${device}_${target}.csv" fi + local profiler_trace_flags=() + if [[ "${EXPORT_PROFILER_TRACE:-0}" == "1" && "$target" == "performance" ]]; then + profiler_trace_flags=(--export-profiler-trace --profiler-trace-name "$TEST_REPORTS_DIR/profiler_traces/${backend}_aot_inductor_${suite}_${dtype}_${mode}_${device}") + fi $TASKSET python "benchmarks/dynamo/$suite.py" \ "${target_flag[@]}" --"$mode" --"$dtype" --export-aot-inductor --disable-cudagraphs "$@" \ + "${profiler_trace_flags[@]}" \ --output "$TEST_REPORTS_DIR/${backend}_aot_inductor_${suite}_${dtype}_${mode}_${device}_${target}.csv" fi if [[ "$DASHBOARD_TAG" == *maxautotune-true* ]]; then + local profiler_trace_flags=() + if [[ "${EXPORT_PROFILER_TRACE:-0}" == "1" && "$target" == "performance" ]]; then + profiler_trace_flags=(--export-profiler-trace --profiler-trace-name "$TEST_REPORTS_DIR/profiler_traces/${backend}_max_autotune_${suite}_${dtype}_${mode}_${device}") + fi TORCHINDUCTOR_MAX_AUTOTUNE=1 $TASKSET python "benchmarks/dynamo/$suite.py" \ "${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" "$@" \ + "${profiler_trace_flags[@]}" \ --output "$TEST_REPORTS_DIR/${backend}_max_autotune_${suite}_${dtype}_${mode}_${device}_${target}.csv" fi + if [[ "$DASHBOARD_TAG" == *deterministic_perf-true* ]]; then + $TASKSET python "benchmarks/dynamo/$suite.py" \ + "${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" --disable-cudagraphs --deterministic "$@" \ + --output "$TEST_REPORTS_DIR/${backend}_deterministic_perf_${suite}_${dtype}_${mode}_${device}_${target}.csv" + fi + if [[ "$DASHBOARD_TAG" == *batch_invariant_accuracy-true* ]] && [[ "$target" == "accuracy" ]]; then + $TASKSET python "benchmarks/dynamo/$suite.py" \ + "${target_flag[@]}" --"$mode" --"$dtype" --backend "$backend" --disable-cudagraphs --batch-invariant "$@" \ + --output "$TEST_REPORTS_DIR/${backend}_batch_invariant_accuracy_${suite}_${dtype}_${mode}_${device}_${target}.csv" + fi done done } @@ -861,9 +997,15 @@ test_single_dynamo_benchmark() { fi if [[ "${TEST_CONFIG}" == *perf_compare* ]]; then + local profiler_trace_flags=() + if [[ "${EXPORT_PROFILER_TRACE:-0}" == "1" ]]; then + mkdir -p "$TEST_REPORTS_DIR/profiler_traces" + profiler_trace_flags=(--export-profiler-trace --profiler-trace-name "$TEST_REPORTS_DIR/profiler_traces/${name}_${suite}") + fi python "benchmarks/dynamo/$suite.py" \ --ci --performance --disable-cudagraphs --inductor \ "${DYNAMO_BENCHMARK_FLAGS[@]}" "$@" "${partition_flags[@]}" \ + "${profiler_trace_flags[@]}" \ --output "$TEST_REPORTS_DIR/${name}_${suite}.csv" elif [[ "${TEST_CONFIG}" == *perf* ]]; then test_perf_for_dashboard "$suite" \ @@ -917,6 +1059,49 @@ test_inductor_triton_cpu() { assert_git_not_dirty } +setup_torch_trace() { + if [[ "${ENABLE_TORCH_TRACE:-0}" != "1" ]]; then + return + fi + local trace_dir="${RUNNER_TEMP:-/tmp}/torch_traces" + mkdir -p "$trace_dir" + export TORCH_TRACE="$trace_dir" + echo "TORCH_TRACE enabled: writing structured trace logs to $trace_dir" +} + +collect_tlparse_output() { + if [[ "${ENABLE_TORCH_TRACE:-0}" != "1" ]]; then + return + fi + local trace_dir="${RUNNER_TEMP:-/tmp}/torch_traces" + local test_reports_dir + test_reports_dir=$(pwd)/test/test-reports + + if [[ ! -d "$trace_dir" ]] || [[ -z "$(ls -A "$trace_dir" 2>/dev/null)" ]]; then + echo "No torch trace files found in $trace_dir, skipping tlparse" + return + fi + + echo "Collecting tlparse output from $trace_dir" + + # Install tlparse if not already available + if ! command -v tlparse &>/dev/null; then + pip install tlparse 2>/dev/null || { + echo "Warning: failed to install tlparse, skipping HTML generation" + return + } + fi + + # Run tlparse to generate HTML report + mkdir -p "$test_reports_dir/tlparse_output" + tlparse -o "$test_reports_dir/tlparse_output/" --no-browser --overwrite "$trace_dir" 2>&1 || { + echo "Warning: tlparse failed to generate HTML output" + return + } + + echo "TLParse output generated in $test_reports_dir/tlparse_output/" +} + test_dynamo_benchmark() { # Usage: test_dynamo_benchmark huggingface 0 TEST_REPORTS_DIR=$(pwd)/test/test-reports @@ -999,6 +1184,147 @@ test_inductor_torchbench_smoketest_perf() { done } +test_unbacked_parity_smoketest() { + # Check that unbacked batch-only has performance parity with backed batch-only + # Fails if any model regresses >THRESHOLD% consistently across 3 retries + TEST_REPORTS_DIR=$(pwd)/test/test-reports + mkdir -p "$TEST_REPORTS_DIR" + + local THRESHOLD=1.0 + local MAX_RETRIES=3 + local MODELS="MobileBertForMaskedLM|DistilBertForMaskedLM|DistillGPT2|T5Small" + + # Issue 6: Write per-run output files for post-failure debugging + run_comparison() { + local run_num=$1 + local output_file="$TEST_REPORTS_DIR/unbacked_parity_results_run${run_num}.txt" + python benchmarks/dynamo/huggingface.py \ + --compare-backed-unbacked \ + --performance --inference --inductor --device cuda \ + --filter "$MODELS" 2>&1 | tee "$output_file" + } + + check_regressions() { + local run_num=$1 + local output_file="$TEST_REPORTS_DIR/unbacked_parity_results_run${run_num}.txt" + # Parse the comparison table and check for regressions > threshold + # Returns 0 if regressions found, 1 if no regressions + local regressions=() + while IFS= read -r line; do + # Issue 3: Broadened regex to match model names with hyphens, slashes, dots + # Match lines like: " ModelName 10.000 10.500 +5.0%" + if [[ "$line" =~ ^[[:space:]]+([A-Za-z0-9_./-]+)[[:space:]]+([0-9.]+)[[:space:]]+([0-9.]+)[[:space:]]+\+([0-9.]+)% ]]; then + local model="${BASH_REMATCH[1]}" + local diff="${BASH_REMATCH[4]}" + # Nit: Use awk instead of bc -l to avoid dependency on bc + if awk "BEGIN{exit !($diff > $THRESHOLD)}"; then + regressions+=("$model:+${diff}%") + fi + fi + done < "$output_file" + + if [[ ${#regressions[@]} -gt 0 ]]; then + echo "Regressions found: ${regressions[*]}" + return 0 + fi + return 1 + } + + check_failures() { + local run_num=$1 + local output_file="$TEST_REPORTS_DIR/unbacked_parity_results_run${run_num}.txt" + # Issue 2: Check for any model failure — not just paired failures. + # Specifically flags when unbacked fails but backed succeeds (regression signal). + # Returns 0 if failures found, 1 if no failures + local current_model="" + local backed_failed=false + local unbacked_failed=false + local both_failures=() + local unbacked_only_failures=() + + # Append a sentinel header so the loop naturally evaluates the last real model + while IFS= read -r line; do + if [[ "$line" =~ ^---[[:space:]]+([A-Za-z0-9_./-]+)[[:space:]]+--- ]]; then + if [[ -n "$current_model" ]]; then + if $backed_failed && $unbacked_failed; then + both_failures+=("$current_model") + elif $unbacked_failed && ! $backed_failed; then + unbacked_only_failures+=("$current_model") + fi + fi + current_model="${BASH_REMATCH[1]}" + backed_failed=false + unbacked_failed=false + elif [[ "$line" =~ backed.*FAILED|backed.*TIMEOUT|backed.*ERROR ]]; then + backed_failed=true + elif [[ "$line" =~ unbacked.*FAILED|unbacked.*TIMEOUT|unbacked.*ERROR ]]; then + unbacked_failed=true + fi + done < <(cat "$output_file"; echo "--- END ---") + + local has_failures=false + if [[ ${#both_failures[@]} -gt 0 ]]; then + echo "❌ FAILURES DETECTED: Both backed and unbacked failed for: ${both_failures[*]}" + has_failures=true + fi + if [[ ${#unbacked_only_failures[@]} -gt 0 ]]; then + echo "❌ FAILURES DETECTED: Unbacked failed (but backed succeeded) for: ${unbacked_only_failures[*]}" + has_failures=true + fi + + if $has_failures; then + return 0 + fi + return 1 + } + + # Run initial comparison + echo "=== Run 1/$MAX_RETRIES ===" + run_comparison 1 + + # Check for failures first + if check_failures 1; then + echo "❌ Test failed: Models failed to run (see above for details)" + exit 1 + fi + + # Check for regressions + if ! check_regressions 1; then + echo "✅ PASSED: No regressions above ${THRESHOLD}% threshold" + exit 0 + fi + + # Regression detected - retry to confirm + local regression_count=1 + for ((retry=2; retry<=MAX_RETRIES; retry++)); do + echo "" + echo "=== Retry $retry/$MAX_RETRIES (potential regression detected) ===" + run_comparison "$retry" + + # Issue 4: Also check for failures on retries (e.g., intermittent OOM) + if check_failures "$retry"; then + echo "❌ Test failed: Models failed on retry $retry (see above for details)" + exit 1 + fi + + if check_regressions "$retry"; then + ((regression_count++)) + fi + done + + # Check if regression was consistent (majority of runs) + local required=$((MAX_RETRIES / 2 + 1)) + if [[ $regression_count -ge $required ]]; then + echo "" + echo "❌ REGRESSION CONFIRMED: Detected in $regression_count/$MAX_RETRIES runs (threshold: ${THRESHOLD}%)" + exit 1 + else + echo "" + echo "✅ PASSED: Regressions were not consistent ($regression_count/$MAX_RETRIES runs, needed $required)" + exit 0 + fi +} + test_inductor_set_cpu_affinity(){ JEMALLOC_LIB="$(find /usr/lib -name libjemalloc.so.2)" export LD_PRELOAD="$JEMALLOC_LIB":"$LD_PRELOAD" @@ -1510,7 +1836,9 @@ EOF fi # Ensure invalid item is in the test output. - echo "${test_output}" | grep -q "${invalid_item_name}" && ret=$? || ret=$? + # Use a here-string instead of a pipe to avoid SIGPIPE when grep -q + # exits early on large output (causes exit code 141 with pipefail). + grep -q "${invalid_item_name}" <<< "${test_output}" && ret=$? || ret=$? if [ $ret -ne 0 ]; then cat << EOF @@ -1654,74 +1982,6 @@ EOF assert_git_not_dirty } -test_bazel() { - set -e -o pipefail - - # bazel test needs sccache setup. - # shellcheck source=./common-build.sh - source "$(dirname "${BASH_SOURCE[0]}")/common-build.sh" - - get_bazel - - if [[ "$CUDA_VERSION" == "cpu" ]]; then - # Test //c10/... without Google flags and logging libraries. The - # :all_tests target in the subsequent Bazel invocation tests - # //c10/... with the Google libraries. - tools/bazel test --config=cpu-only --test_timeout=480 --test_output=all --test_tag_filters=-gpu-required --test_filter=-*CUDA \ - --no//c10:use_gflags --no//c10:use_glog //c10/... - - tools/bazel test --config=cpu-only --test_timeout=480 --test_output=all --test_tag_filters=-gpu-required --test_filter=-*CUDA :all_tests - else - # Increase the test timeout to 480 like CPU tests because modules_test frequently timeout - tools/bazel test --test_timeout=480 --test_output=errors \ - //:any_test \ - //:autograd_test \ - //:dataloader_test \ - //:dispatch_test \ - //:enum_test \ - //:expanding_array_test \ - //:fft_test \ - //:functional_test \ - //:grad_mode_test \ - //:inference_mode_test \ - //:init_test \ - //:jit_test \ - //:memory_test \ - //:meta_tensor_test \ - //:misc_test \ - //:moduledict_test \ - //:modulelist_test \ - //:modules_test \ - //:namespace_test \ - //:nested_test \ - //:nn_utils_test \ - //:operations_test \ - //:ordered_dict_test \ - //:parallel_benchmark_test \ - //:parameterdict_test \ - //:parameterlist_test \ - //:sequential_test \ - //:serialize_test \ - //:special_test \ - //:static_test \ - //:support_test \ - //:tensor_flatten_test \ - //:tensor_indexing_test \ - //:tensor_options_cuda_test \ - //:tensor_options_test \ - //:tensor_test \ - //:torch_dist_autograd_test \ - //:torch_include_test \ - //:transformer_test \ - //:test_bazel \ - //c10/cuda/test:test \ - //c10/test:core_tests \ - //c10/test:typeid_test \ - //c10/test:util/ssize_test \ - //c10/test:util_base_tests - fi -} - test_benchmarks() { if [[ "$BUILD_ENVIRONMENT" == *cuda* && $TEST_CONFIG != *nogpu* ]]; then pip_install "pytest-benchmark==3.2.3" @@ -1792,34 +2052,6 @@ test_executorch() { assert_git_not_dirty } -test_linux_aarch64() { - python test/run_test.py --include test_modules test_utils test_mkldnn test_mkldnn_fusion test_openmp test_torch test_dynamic_shapes \ - test_transformers test_multiprocessing test_numpy_interop test_autograd test_binary_ufuncs test_complex test_spectral_ops \ - test_foreach test_reductions test_unary_ufuncs test_tensor_creation_ops test_ops profiler/test_memory_profiler \ - distributed/elastic/timer/api_test distributed/elastic/timer/local_timer_example distributed/elastic/timer/local_timer_test \ - test_linalg \ - --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose - - # Dynamo tests - python test/run_test.py --include dynamo/test_compile dynamo/test_backends dynamo/test_comptime dynamo/test_config \ - dynamo/test_functions dynamo/test_fx_passes_pre_grad dynamo/test_interop dynamo/test_model_output dynamo/test_modules \ - dynamo/test_optimizers dynamo/test_recompile_ux dynamo/test_recompiles \ - --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose - - # Inductor tests - python test/run_test.py --include inductor/test_torchinductor inductor/test_benchmark_fusion inductor/test_codecache \ - inductor/test_config inductor/test_control_flow inductor/test_coordinate_descent_tuner inductor/test_fx_fusion \ - inductor/test_group_batch_fusion inductor/test_inductor_freezing inductor/test_inductor_utils \ - inductor/test_inplacing_pass inductor/test_kernel_benchmark inductor/test_layout_optim \ - inductor/test_max_autotune inductor/test_memory_planning inductor/test_metrics inductor/test_multi_kernel inductor/test_pad_mm \ - inductor/test_pattern_matcher inductor/test_perf inductor/test_profiler inductor/test_select_algorithm inductor/test_smoke \ - inductor/test_split_cat_fx_passes inductor/test_compile inductor/test_torchinductor \ - inductor/test_torchinductor_codegen_dynamic_shapes inductor/test_torchinductor_dynamic_shapes inductor/test_memory \ - inductor/test_triton_cpu_backend inductor/test_triton_extension_backend inductor/test_mkldnn_pattern_matcher inductor/test_cpu_cpp_wrapper \ - inductor/test_cpu_select_algorithm inductor/test_cpu_repro \ - --shard "$SHARD_NUMBER" "$NUM_TEST_SHARDS" --verbose -} - test_operator_benchmark() { TEST_REPORTS_DIR=$(pwd)/test/test-reports mkdir -p "$TEST_REPORTS_DIR" @@ -1887,11 +2119,14 @@ test_openreg() { assert_git_not_dirty } -if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then +if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then (cd test && python -c "import torch; print(torch.__config__.show())") (cd test && python -c "import torch; print(torch.__config__.parallel_info())") fi -if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then +if [[ "${TEST_CONFIG}" == "onnx" ]]; then + install_torchvision + "$(dirname "${BASH_SOURCE[0]}")/../../scripts/onnx/test.sh" +elif [[ "${TEST_CONFIG}" == *numpy_2* ]]; then # Install numpy-2.0.2 and compatible scipy & numba versions # Force re-install of pandas to avoid error where pandas checks numpy version from initial install and fails upon import TMP_PANDAS_VERSION=$(python -c "import pandas; print(pandas.__version__)" 2>/dev/null) @@ -1901,8 +2136,6 @@ if [[ "${TEST_CONFIG}" == *numpy_2* ]]; then python -m pip install --pre numpy==2.0.2 scipy==1.13.1 numba==0.60.0 fi python test/run_test.py --include dynamo/test_functions.py dynamo/test_unspec.py test_binary_ufuncs.py test_fake_tensor.py test_linalg.py test_numpy_interop.py test_tensor_creation_ops.py test_torch.py torch_np/test_basic.py -elif [[ "${BUILD_ENVIRONMENT}" == *aarch64* && "${TEST_CONFIG}" == 'default' ]]; then - test_linux_aarch64 elif [[ "${TEST_CONFIG}" == *backward* ]]; then test_forward_backward_compatibility # Do NOT add tests after bc check tests, see its comment. @@ -1953,7 +2186,9 @@ elif [[ "${TEST_CONFIG}" == *operator_microbenchmark* ]]; then elif [[ "${TEST_CONFIG}" == *attention_microbenchmark* ]]; then test_attention_microbenchmark elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then + setup_torch_trace test_inductor_distributed + collect_tlparse_output elif [[ "${TEST_CONFIG}" == *inductor-halide* ]]; then test_inductor_halide elif [[ "${TEST_CONFIG}" == *inductor-pallas* ]]; then @@ -1967,11 +2202,19 @@ elif [[ "${TEST_CONFIG}" == *aoti_cross_compile_for_windows* ]]; then elif [[ "${TEST_CONFIG}" == *huggingface* ]]; then install_torchvision id=$((SHARD_NUMBER-1)) - test_dynamo_benchmark huggingface "$id" + setup_torch_trace + if [[ "${TEST_CONFIG}" == *unbacked_parity* ]]; then + test_unbacked_parity_smoketest + else + test_dynamo_benchmark huggingface "$id" + fi + collect_tlparse_output elif [[ "${TEST_CONFIG}" == *timm* ]]; then install_torchvision id=$((SHARD_NUMBER-1)) + setup_torch_trace test_dynamo_benchmark timm_models "$id" + collect_tlparse_output elif [[ "${TEST_CONFIG}" == cachebench ]]; then install_torchaudio install_torchvision @@ -2002,19 +2245,27 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then LIBTBB_PATH="$(find "$(dirname "$(which python)")/../lib/" -name libtbb.so.12)" export LD_PRELOAD="$LIBTBB_PATH":"$LD_PRELOAD" fi + setup_torch_trace PYTHONPATH=/torchbench test_dynamo_benchmark torchbench "$id" + collect_tlparse_output fi elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then install_torchvision + setup_torch_trace PYTHONPATH=/torchbench test_inductor_cpp_wrapper_shard "$SHARD_NUMBER" if [[ "$SHARD_NUMBER" -eq "1" ]]; then test_inductor_aoti_cpp fi + collect_tlparse_output elif [[ "${TEST_CONFIG}" == *inductor_core* ]]; then + setup_torch_trace test_inductor_core + collect_tlparse_output elif [[ "${TEST_CONFIG}" == *inductor* ]]; then install_torchvision + setup_torch_trace test_inductor_shard "${SHARD_NUMBER}" + collect_tlparse_output elif [[ "${TEST_CONFIG}" == *einops* ]]; then test_einops elif [[ "${TEST_CONFIG}" == *dynamo_core* ]]; then @@ -2056,8 +2307,6 @@ elif [[ "${SHARD_NUMBER}" -gt 2 ]]; then test_python_shard "$SHARD_NUMBER" elif [[ "${BUILD_ENVIRONMENT}" == *vulkan* ]]; then test_vulkan -elif [[ "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then - test_bazel elif [[ "${BUILD_ENVIRONMENT}" == *-mobile-lightweight-dispatch* ]]; then test_libtorch elif [[ "${TEST_CONFIG}" = docs_test ]]; then @@ -2080,6 +2329,8 @@ elif [[ "${TEST_CONFIG}" == h100_cutlass_backend ]]; then test_h100_cutlass_backend elif [[ "${TEST_CONFIG}" == openreg ]]; then test_openreg +elif [[ "${TEST_CONFIG}" == "tsan" ]]; then + test_tsan else install_torchvision install_monkeytype diff --git a/.ci/pytorch/test_example_code/CMakeLists.txt b/.ci/pytorch/test_example_code/CMakeLists.txt index 0af9ba6bb655d..c89a728f27649 100644 --- a/.ci/pytorch/test_example_code/CMakeLists.txt +++ b/.ci/pytorch/test_example_code/CMakeLists.txt @@ -8,7 +8,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") add_executable(simple-torch-test simple-torch-test.cpp) target_include_directories(simple-torch-test PRIVATE ${TORCH_INCLUDE_DIRS}) target_link_libraries(simple-torch-test "${TORCH_LIBRARIES}") -set_property(TARGET simple-torch-test PROPERTY CXX_STANDARD 17) +set_property(TARGET simple-torch-test PROPERTY CXX_STANDARD 20) find_package(CUDAToolkit 11.8) diff --git a/.ci/pytorch/windows/arm64/build_libtorch.bat b/.ci/pytorch/windows/arm64/build_libtorch.bat index 1ac14ff697730..33080af122c73 100644 --- a/.ci/pytorch/windows/arm64/build_libtorch.bat +++ b/.ci/pytorch/windows/arm64/build_libtorch.bat @@ -27,8 +27,8 @@ where cl.exe :: change to source directory cd %PYTORCH_ROOT% -:: copy libuv.dll -copy %libuv_ROOT%\lib\Release\uv.dll torch\lib\uv.dll +:: copy libuv.dll (cmake installs the dll to bin/, not lib/Release/) +copy %libuv_ROOT%\bin\uv.dll torch\lib\uv.dll :: create virtual environment python -m venv .venv diff --git a/.ci/pytorch/windows/arm64/build_pytorch.bat b/.ci/pytorch/windows/arm64/build_pytorch.bat index b5c2ef65b84ad..7d10b26339d25 100644 --- a/.ci/pytorch/windows/arm64/build_pytorch.bat +++ b/.ci/pytorch/windows/arm64/build_pytorch.bat @@ -5,6 +5,7 @@ set CMAKE_BUILD_TYPE=%BUILD_TYPE% set CMAKE_C_COMPILER_LAUNCHER=sccache set CMAKE_CXX_COMPILER_LAUNCHER=sccache set libuv_ROOT=%DEPENDENCIES_DIR%\libuv\install +set INSTALL_TEST=0 set MSSdk=1 if defined PYTORCH_BUILD_VERSION ( set PYTORCH_BUILD_VERSION=%PYTORCH_BUILD_VERSION% @@ -27,8 +28,8 @@ where cl.exe :: change to source directory cd %PYTORCH_ROOT% -:: copy libuv.dll -copy %libuv_ROOT%\lib\Release\uv.dll torch\lib\uv.dll +:: copy libuv.dll (cmake installs the dll to bin/, not lib/Release/) +copy %libuv_ROOT%\bin\uv.dll torch\lib\uv.dll :: create virtual environment python -m venv .venv @@ -57,4 +58,4 @@ sccache --show-stats if %errorlevel% neq 0 ( echo "Failed on build_pytorch. (exitcode = %errorlevel%)" exit /b 1 -) \ No newline at end of file +) diff --git a/.ci/pytorch/windows/internal/cuda_config.bat b/.ci/pytorch/windows/internal/cuda_config.bat index 352eb8a3391bf..69059e2e46ee1 100644 --- a/.ci/pytorch/windows/internal/cuda_config.bat +++ b/.ci/pytorch/windows/internal/cuda_config.bat @@ -22,6 +22,10 @@ if "%CUDA_VER%"=="126" ( set "CUDA_DOTTED_VERSION=13.0" set "CUDA_ARCH_LIST=7.5;8.0;8.6;9.0;10.0;12.0" set "VISION_GENCODE=-gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_120,code=compute_120" +) else if "%CUDA_VER%"=="132" ( + set "CUDA_DOTTED_VERSION=13.2" + set "CUDA_ARCH_LIST=7.5;8.0;8.6;9.0;10.0;12.0" + set "VISION_GENCODE=-gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_90,code=compute_90 -gencode=arch=compute_100,code=compute_100 -gencode=arch=compute_120,code=compute_120" ) else ( echo Unknown CUDA version: %CUDA_VER% exit /b 1 diff --git a/.ci/pytorch/windows/internal/cuda_install.bat b/.ci/pytorch/windows/internal/cuda_install.bat index 3538c7aa2d323..c1050edecc0b9 100644 --- a/.ci/pytorch/windows/internal/cuda_install.bat +++ b/.ci/pytorch/windows/internal/cuda_install.bat @@ -27,6 +27,7 @@ if %CUDA_VER% EQU 126 goto cuda126 if %CUDA_VER% EQU 128 goto cuda128 if %CUDA_VER% EQU 129 goto cuda129 if %CUDA_VER% EQU 130 goto cuda130 +if %CUDA_VER% EQU 132 goto cuda132 echo CUDA %CUDA_VERSION_STR% is not supported exit /b 1 @@ -42,19 +43,25 @@ goto cuda_download :cuda128 set CUDA_INSTALL_EXE=cuda_12.8.0_571.96_windows.exe set "ARGS=cuda_profiler_api_12.8 thrust_12.8 nvcc_12.8 cuobjdump_12.8 nvprune_12.8 nvprof_12.8 cupti_12.8 cublas_12.8 cublas_dev_12.8 cudart_12.8 cufft_12.8 cufft_dev_12.8 curand_12.8 curand_dev_12.8 cusolver_12.8 cusolver_dev_12.8 cusparse_12.8 cusparse_dev_12.8 npp_12.8 npp_dev_12.8 nvrtc_12.8 nvrtc_dev_12.8 nvml_dev_12.8 nvjitlink_12.8 nvtx_12.8" -set CUDNN_FOLDER=cudnn-windows-x86_64-9.19.0.56_cuda12-archive +set CUDNN_FOLDER=cudnn-windows-x86_64-9.20.0.48_cuda12-archive goto cuda_download :cuda129 set CUDA_INSTALL_EXE=cuda_12.9.1_576.57_windows.exe set "ARGS=cuda_profiler_api_12.9 thrust_12.9 nvcc_12.9 cuobjdump_12.9 nvprune_12.9 nvprof_12.9 cupti_12.9 cublas_12.9 cublas_dev_12.9 cudart_12.9 cufft_12.9 cufft_dev_12.9 curand_12.9 curand_dev_12.9 cusolver_12.9 cusolver_dev_12.9 cusparse_12.9 cusparse_dev_12.9 npp_12.9 npp_dev_12.9 nvrtc_12.9 nvrtc_dev_12.9 nvml_dev_12.9 nvjitlink_12.9 nvtx_12.9" -set CUDNN_FOLDER=cudnn-windows-x86_64-9.17.1.4_cuda12-archive +set CUDNN_FOLDER=cudnn-windows-x86_64-9.20.0.48_cuda12-archive goto cuda_download :cuda130 set CUDA_INSTALL_EXE=cuda_13.0.0_windows.exe set "ARGS=" -set CUDNN_FOLDER=cudnn-windows-x86_64-9.19.0.56_cuda13-archive +set CUDNN_FOLDER=cudnn-windows-x86_64-9.20.0.48_cuda13-archive +goto cuda_download + +:cuda132 +set CUDA_INSTALL_EXE=cuda_13.2.1_windows.exe +set "ARGS=" +set CUDNN_FOLDER=cudnn-windows-x86_64-9.20.0.48_cuda13-archive goto cuda_download :: Common download logic for CUDA toolkit, cuDNN, and ZLIB @@ -161,16 +168,20 @@ if %CUDA_VER% EQU 126 ( set EXPECTED_CUDNN_VERSION=9.10.2 ) if %CUDA_VER% EQU 128 ( - set CUDNN_FOLDER=cudnn-windows-x86_64-9.19.0.56_cuda12-archive - set EXPECTED_CUDNN_VERSION=9.19.0 + set CUDNN_FOLDER=cudnn-windows-x86_64-9.20.0.48_cuda12-archive + set EXPECTED_CUDNN_VERSION=9.20.0 ) if %CUDA_VER% EQU 129 ( - set CUDNN_FOLDER=cudnn-windows-x86_64-9.17.1.4_cuda12-archive - set EXPECTED_CUDNN_VERSION=9.17.1 + set CUDNN_FOLDER=cudnn-windows-x86_64-9.20.0.48_cuda12-archive + set EXPECTED_CUDNN_VERSION=9.20.0 ) if %CUDA_VER% EQU 130 ( - set CUDNN_FOLDER=cudnn-windows-x86_64-9.19.0.56_cuda13-archive - set EXPECTED_CUDNN_VERSION=9.19.0 + set CUDNN_FOLDER=cudnn-windows-x86_64-9.20.0.48_cuda13-archive + set EXPECTED_CUDNN_VERSION=9.20.0 +) +if %CUDA_VER% EQU 132 ( + set CUDNN_FOLDER=cudnn-windows-x86_64-9.20.0.48_cuda13-archive + set EXPECTED_CUDNN_VERSION=9.20.0 ) set "CUDNN_INSTALL_ZIP=%CUDNN_FOLDER%.zip" diff --git a/.ci/pytorch/windows/internal/smoke_test.bat b/.ci/pytorch/windows/internal/smoke_test.bat index 0dec6a04a8a5c..8cce63a81693d 100644 --- a/.ci/pytorch/windows/internal/smoke_test.bat +++ b/.ci/pytorch/windows/internal/smoke_test.bat @@ -1,4 +1,5 @@ set SRC_DIR=%~dp0 +set TARGET_OS=windows pushd %SRC_DIR%\.. @@ -98,6 +99,10 @@ echo Checking that basic CNN works %PYTHON_EXEC% %PYTORCH_ROOT%\.ci\pytorch\test_example_code\cnn_smoke.py if ERRORLEVEL 1 exit /b 1 +echo Running smoke_test.py +%PYTHON_EXEC% %PYTORCH_ROOT%\.ci\pytorch\smoke_test\smoke_test.py --package=torchonly --torch-compile-check disabled --runtime-error-check disabled +if ERRORLEVEL 1 exit /b 1 + goto end :libtorch diff --git a/.ci/wheel/build_wheel.sh b/.ci/wheel/build_wheel.sh index 6fb63c361f018..2563d5ba31765 100755 --- a/.ci/wheel/build_wheel.sh +++ b/.ci/wheel/build_wheel.sh @@ -97,7 +97,7 @@ fi whl_tmp_dir="${MAC_PACKAGE_WORK_DIR}/dist" mkdir -p "$whl_tmp_dir" -mac_version='macosx-11.0-arm64' +mac_version='macosx-14.0-arm64' libtorch_arch='arm64' # Create a consistent wheel package name to rename the wheel to @@ -125,27 +125,19 @@ popd export TH_BINARY_BUILD=1 export INSTALL_TEST=0 # dont install test binaries into site-packages -export MACOSX_DEPLOYMENT_TARGET=11.0 +export MACOSX_DEPLOYMENT_TARGET=14.0 EXTRA_CONDA_INSTALL_FLAGS="" CONDA_ENV_CREATE_FLAGS="" RENAME_WHEEL=false VERIFY_WHEELNAME=true case $desired_python in - 3.14t) - echo "Using 3.14 deps" - NUMPY_PINNED_VERSION="==2.1.0" - ;; - 3.14) - echo "Using 3.14t deps" - NUMPY_PINNED_VERSION="==2.1.0" - ;; - 3.13t) - echo "Using 3.13t deps" - NUMPY_PINNED_VERSION="==2.1.0" + 3.14*) + echo "Using ${desired_python} deps" + NUMPY_PINNED_VERSION="==2.3.4" ;; - 3.13) - echo "Using 3.13 deps" + 3.13*) + echo "Using ${desired_python} deps" NUMPY_PINNED_VERSION="==2.1.0" ;; 3.12) @@ -179,6 +171,7 @@ retry pip install "${PINNED_PACKAGES[@]}" -r "${pytorch_rootdir}/requirements.tx if [[ -d "/opt/llvm-openmp" ]]; then export OMP_PREFIX=/opt/llvm-openmp else + echo "libomp not found, installing via brew" retry brew install libomp fi diff --git a/.clang-tidy b/.clang-tidy index 2b8eb00bb03f5..3a26102cb753a 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -78,4 +78,5 @@ CheckOptions: cppcoreguidelines-special-member-functions.AllowSoleDefaultDtor: true cppcoreguidelines-special-member-functions.AllowImplicitlyDeletedCopyOrMove: true misc-header-include-cycle.IgnoredFilesList: 'format.h;ivalue.h;custom_class.h;Dict.h;List.h;IListRef.h' + performance-inefficient-vector-operation.VectorLikeClasses: '::std::vector;::c10::SmallVector' ... diff --git a/.claude/skills/document-public-apis/SKILL.md b/.claude/skills/document-public-apis/SKILL.md index 51cf5c2c34241..211b444fc5461 100644 --- a/.claude/skills/document-public-apis/SKILL.md +++ b/.claude/skills/document-public-apis/SKILL.md @@ -7,7 +7,9 @@ description: Document undocumented public APIs in PyTorch by removing functions This skill documents undocumented public APIs in PyTorch by removing entries from the coverage ignore lists in `docs/source/conf.py` and adding Sphinx autodoc directives (e.g., `autosummary`, `currentmodule`, `autoclass`, `automodule`) to the corresponding `.md` or `.rst` doc source files in `docs/source/`. -**"Documenting" means adding autodoc directives to doc source files — NEVER modifying Python source code.** Do not add or edit docstrings in `.py` files. Do not read or inspect Python source files. Sphinx will pull whatever docstring exists (or render an empty entry if none exists). Your only job is to add the correct directive to the correct doc file. +**"Documenting" means adding autodoc directives to doc source files — NEVER modifying Python source code.** Do not add or edit docstrings in `.py` files. Your only job is to add the correct directive to the correct doc file. + +**IMPORTANT: Before adding any function to the sphinx doctree, verify it has a real docstring.** Use a quick Python check (e.g., `python -c "from torch.module import func; print(bool(func.__doc__))"`) to confirm the function has actual documentation content — not just an empty docstring or a bare `.. warning:: This API is experimental` stub. Functions without meaningful docstrings should be left in the `coverage_ignore_functions`/`coverage_ignore_classes` lists. Adding undocumented functions to the doctree creates empty or near-empty pages that degrade documentation quality. ## Overview @@ -71,9 +73,24 @@ Work through the lists top-to-bottom. Choose enough groups to make meaningful pr If a module group has a **mix** of regular entries and entries with inline comments, still process the group — but only comment out the regular entries. Leave entries with inline comments untouched in the ignore list. +### Step 1b: Verify functions have actual docstrings + +For each function selected in Step 1, check that it has a meaningful docstring by running: + +```bash +python -c "from torch.module.path import func_name; doc = func_name.__doc__; print('HAS DOC' if doc and len(doc.strip()) > 80 else 'NO DOC'); print(repr(doc[:120]) if doc else 'None')" +``` + +A function has a **meaningful docstring** if it has real descriptive content — not just: +- `None` or empty string +- Only a `.. warning:: This API is experimental` stub with no description +- Only a one-line auto-generated signature + +**Functions without meaningful docstrings must stay in the ignore list.** Remove them from your batch. If an entire module group has no functions with docstrings, skip the whole group. + ### Step 2: Present the batch to the user -**Before making any edits**, present the selected module groups and their functions to the user. Show them organized by module: +**Before making any edits**, present the selected module groups and their functions to the user. Indicate which functions passed the docstring check and which were excluded. Show them organized by module: ``` Module: torch.ao.quantization.fx.convert @@ -315,8 +332,8 @@ Also delete any module label comments that no longer have active entries beneath ## Important notes -- **Follow the steps exactly as written.** Do not add extra investigation steps like importing Python modules to check docstrings, inspecting source code to verify function signatures, or running any commands not specified in the instructions. The `make coverage` step is the only verification needed — let it tell you what's wrong. -- **Never modify Python source files (`.py`).** This skill only edits `docs/source/conf.py` and doc source files (`.md`/`.rst`) in `docs/source/`. Do not add or edit docstrings, do not read Python source to check function signatures, do not inspect implementations. +- **Follow the steps exactly as written.** The `make coverage` step is the primary verification for correct Sphinx directives, and Step 1b's docstring check ensures you only document functions that have real content. +- **Never modify Python source files (`.py`).** This skill only edits `docs/source/conf.py` and doc source files (`.md`/`.rst`) in `docs/source/`. Do not add or edit docstrings. The only reason to inspect Python modules is in Step 1b to check whether a docstring exists — never to modify source code. - Entries are commented out in Step 3, verified in Step 6, and cleaned up in Step 8 after verification passes. Never delete uncommented entries directly. - **Read inline comments** on entries before deciding to document them. Entries marked `# deprecated`, `# documented as ...`, `# looks unintentionally public`, or `# legacy helper` should stay in the ignore list. - The `coverage_ignore_functions` list uses bare function names (not fully qualified), so the same name can appear multiple times for different modules. Use the module label comment above each entry to identify which module it belongs to. Be careful during Step 8 cleanup to only delete the correct commented-out lines — commented-out string entries have **quotes** (`# "func_name",`), module label comments do not. diff --git a/.claude/skills/metal-kernel/SKILL.md b/.claude/skills/metal-kernel/SKILL.md index 75e6684b73a40..18ba4f64570ea 100644 --- a/.claude/skills/metal-kernel/SKILL.md +++ b/.claude/skills/metal-kernel/SKILL.md @@ -325,6 +325,83 @@ python test/test_mps.py -k test_output_match_my_op python test/test_mps.py ``` +## Debugging Metal Kernels with `torch.mps.compile_shader` + +Use `torch.mps.compile_shader` to JIT-compile and test individual Metal kernels in isolation. This is invaluable for debugging multi-kernel pipelines where you need to verify each stage independently. + +### Basic Usage + +```python +import torch + +source = ''' +#include +using namespace metal; + +kernel void my_kernel( + const device float* input [[buffer(0)]], + device float* output [[buffer(1)]], + uint tid [[thread_position_in_grid]]) { + output[tid] = input[tid] * 2.0; +} +''' + +lib = torch.mps.compile_shader(source) + +inp = torch.tensor([1.0, 2.0, 3.0], device='mps') +out = torch.zeros(3, device='mps') +lib.my_kernel(inp, out, threads=[3, 1, 1], group_size=[3, 1, 1]) +torch.mps.synchronize() +print(out) # tensor([2., 4., 6.], device='mps:0') +``` + +### Dispatch Semantics + +`compile_shader` uses **`dispatchThreads`** semantics (same as `mtl_dispatch1DJob` in PyTorch): +- `threads=[N, 1, 1]` — total number of threads (NOT threadgroups) +- `group_size=[G, 1, 1]` — threads per threadgroup + +This differs from the `dispatchThreadgroups` API used by some host-side code. To match `dispatchThreadgroups:MTLSizeMake(num_tgs, num_slices, 1) threadsPerThreadgroup:MTLSizeMake(TG_SIZE, 1, 1)`: + +```python +# Equivalent compile_shader call: +lib.kernel(args..., + threads=[num_tgs * TG_SIZE, num_slices, 1], + group_size=[TG_SIZE, 1, 1]) +``` + +### Constant Buffer Parameters + +Pass scalar constants as single-element tensors: + +```python +slice_size = torch.tensor([1024], dtype=torch.int32, device='mps') +lib.my_kernel(data, output, slice_size, threads=[1024, 1, 1], group_size=[256, 1, 1]) +``` + +### Debugging Strategy for Multi-Kernel Pipelines + +When a pipeline of kernels (e.g., histogram → prefix_sum → scatter) produces wrong results, test each kernel individually and verify its output against a Python/NumPy reference: + +```python +# 1. Run GPU kernel +lib.histogram(keys, hist, ..., threads=[N, 1, 1], group_size=[256, 1, 1]) +torch.mps.synchronize() + +# 2. Compute reference in Python +ref_hist = compute_histogram_cpu(keys.cpu().numpy(), ...) + +# 3. Compare +assert np.array_equal(hist.cpu().numpy(), ref_hist), "Histogram mismatch!" +``` + +This isolates which kernel in the pipeline is broken, rather than debugging the entire pipeline at once. + +### Common Pitfalls + +- **Wrong `threads` count** — `threads` is total threads, not threadgroups. For 5 threadgroups of 256, use `threads=[1280, 1, 1]`. +- **Threadgroup memory** — `compile_shader` doesn't support `[[threadgroup(N)]]` parameters directly. If your kernel needs threadgroup memory, restructure to use `threadgroup` arrays declared inside the kernel body instead. + ## Checklist - [ ] Added MPS dispatch to `native_functions.yaml` diff --git a/.claude/skills/pr-review/SKILL.md b/.claude/skills/pr-review/SKILL.md index ff9cefbc36ced..098c68a110d86 100644 --- a/.claude/skills/pr-review/SKILL.md +++ b/.claude/skills/pr-review/SKILL.md @@ -5,7 +5,7 @@ description: Review PyTorch pull requests for code quality, test coverage, secur # PyTorch PR Review Skill -Review PyTorch pull requests focusing on what CI cannot check: code quality, test coverage adequacy, security vulnerabilities, and backward compatibility. Linting, formatting, type checking, and import ordering are handled by CI. +Review PyTorch pull requests focusing on what CI cannot check: code quality, test coverage adequacy, security vulnerabilities, and backward compatibility. ## Usage Modes @@ -115,90 +115,40 @@ only the diff and commit log need to be fetched via git. A single line of code can have deep cross-cutting implications: a missing device guard causes silent data corruption on multi-GPU, a missing `Composite` dispatch key breaks every out-of-tree backend, a manual dtype check instead of `TensorIterator` silently skips type promotion. **Treat every line as potentially load-bearing.** -Do not skim. Do not summarize the diff and move on. Read every changed line and ask: *does this interact with existing PyTorch infrastructure that the author may not know about?* When uncertain, **investigate** — spawn a sub-agent to read the surrounding code, the infrastructure the PR should be using, or the tests that should exist. The cost of a false negative (missing a real issue) is much higher than the cost of investigation. +1. **Investigate, don't guess** — When uncertain whether a checklist item applies, spawn a sub-agent to read the relevant code. A reviewer who guesses wrong provides negative value. +2. **Review the design, not just the implementation** — A PR can have perfectly correct implementation of a bad design. Question side-channel communication, on/off private flags, and demand concrete interface documentation for new contracts between components. +3. **Focus on what CI cannot check** — Don't comment on formatting, linting, type errors, or CI failures. Focus on design quality, interface correctness, thread safety, BC implications, test adequacy, and pattern adherence. +4. **Everything is a must-fix** — There are no "nits." If it's worth mentioning, it's worth fixing. Every inconsistency degrades the codebase over time. +5. **Be specific and actionable** — Reference file paths and line numbers. Name the function/class/file the author should use. +6. **Match the immediate context** — Read how similar features are already implemented in the same file. Pattern mismatches within a file are always wrong. +7. **Assume competence** — The author knows PyTorch; explain only non-obvious context. +8. **No repetition** — Each observation appears in exactly one section of the review output. -## Review Workflow - -### Step 1: Fetch PR Information +### Using sub-agents -**Local CLI mode**: Use `gh` commands to get PR metadata, changed files, full diff, -existing comments/reviews, and associated issue information. +The review checklist is large. You cannot hold the full context of every infrastructure system in your head. **Spawn sub-agents** to investigate whether checklist items apply: read surrounding code, infrastructure the PR should be using, or tests that should exist. Spawn them in parallel for independent areas. A typical medium PR should spawn 3-8 sub-agents. -**Local Branch mode**: Use `git diff` and `git log` against `main` as shown in the -Local Branch Mode section above. - -**GitHub Actions mode**: PR metadata, comments, and reviews are already in the prompt. -Use `git diff origin/...HEAD` for the full diff and -`git log origin/..HEAD --oneline` for the commit log. +## Review Workflow -### Step 2: Understand Context +### Step 1: Understand Context Before reviewing, build understanding of what the PR touches and why: 1. Identify the purpose of the change from title/description/issue 2. Group changes by type (new code, tests, config, docs) 3. Note the scope of changes (files affected, lines changed) -4. **Spawn sub-agents to read the unchanged code surrounding each changed file.** The diff alone is not enough — you need to understand the existing patterns, base classes, and infrastructure in the files being modified. For each significantly changed file, a sub-agent should read the full file (or the relevant class/function) and report back: what patterns does this file follow? What infrastructure does it use? What invariants does it maintain? - -### Step 3: Deep Review — Line-by-Line with Investigation - -This is the core of the review. Go through **every changed line** in the diff and evaluate it against the review checklist in [review-checklist.md](review-checklist.md). - -**How to use sub-agents during review:** - -The checklist is large. You cannot hold the full context of every infrastructure system in your head. Instead, when you encounter a changed line that touches a checklist area, **spawn a sub-agent** to investigate whether the checklist item applies. For example: +4. Spawn sub-agents to read the unchanged code surrounding each significantly changed file to understand existing patterns and infrastructure -- A PR adds a new C++ kernel → spawn a sub-agent to check: Does it use TensorIterator? DispatchStub? Structured kernels? AT_DISPATCH? Does it have a meta implementation? A Composite fallback? -- A PR adds a new test → spawn a sub-agent to check: Does an OpInfo exist for this op? Is the test device-generic? Does it use make_tensor, @dtypes, TestCase? -- A PR modifies autograd code → spawn a sub-agent to check: Is derivatives.yaml the right place? Does it use setup_context? Does it have gradcheck tests? -- A PR adds a new operator → spawn a sub-agent to check: Is it in native_functions.yaml? Does it have proper tags? A Composite dispatch? Meta/fake impls? Schema annotations? +### Step 2: Deep Review -**Spawn sub-agents in parallel** for independent investigation areas. A typical review of a medium PR should spawn 3-8 sub-agents. Large PRs touching multiple subsystems may need more. +Go through **every changed line** in the diff and evaluate it against the review checklist in [review-checklist.md](review-checklist.md). -**Checklist areas** (see [review-checklist.md](review-checklist.md) for full details): -- Code quality and design -- PyTorch infrastructure — C++ kernels (TensorIterator, DispatchStub, AT_DISPATCH, device guards), CUDA/device management, operator registration and codegen (native_functions.yaml, Composite dispatch, meta/fake implementations), autograd (derivatives.yaml, autograd.Function, gradcheck), Python utilities (pytree, __torch_function__, logging), nn module patterns, Dynamo/Inductor/compile, FX/export, type promotion, serialization, distributed, tensor subclasses -- Testing adequacy (OpInfo, ModuleInfo, device-generic tests, @dtypes, @parametrize, make_tensor) -- Security considerations -- Thread safety and concurrency (Python, C++, CPython C API, NoGIL) -- Performance implications -- Any behavior change not expected by author +### Step 3: Check Backward Compatibility -### Step 4: Check Backward Compatibility +Evaluate BC implications per [bc-guidelines.md](bc-guidelines.md). For non-trivial BC questions, spawn a sub-agent to search for existing callers of the modified API. -Evaluate BC implications. See [bc-guidelines.md](bc-guidelines.md) for: -- What constitutes a BC-breaking change -- Required deprecation patterns -- Common BC pitfalls +### Step 4: Formulate Review -For non-trivial BC questions (e.g., "does changing this default break downstream users?"), spawn a sub-agent to search for existing callers of the modified API. - -### Step 5: Formulate Review - -Structure your review with actionable feedback organized by category. Every finding should be traceable to a specific line in the diff and a specific checklist item. - -## Review Areas - -| Area | Focus | Reference | -|------|-------|-----------| -| Code Quality | Abstractions, patterns, complexity | [review-checklist.md](review-checklist.md) | -| API Design | New patterns, flag-based access, broader implications | [review-checklist.md](review-checklist.md) | -| C++ Kernels | TensorIterator, DispatchStub, AT_DISPATCH, structured kernels, device guards, memory format | [review-checklist.md](review-checklist.md) | -| CUDA/Device | C10_CUDA_CHECK, stream/event guards, recordStream, CUDA graphs, AcceleratorHooks | [review-checklist.md](review-checklist.md) | -| Op Registration | native_functions.yaml, Composite fallback, meta/fake impls, tags, schema annotations | [review-checklist.md](review-checklist.md) | -| Autograd | derivatives.yaml, autograd.Function patterns, gradcheck, forward-mode AD, vmap | [review-checklist.md](review-checklist.md) | -| Python Utils | __torch_function__, pytree, logging, deprecation, backends context | [review-checklist.md](review-checklist.md) | -| nn Modules | ModuleList/Dict, nn.init, parametrize, state_dict versioning, LazyModule | [review-checklist.md](review-checklist.md) | -| Dynamo/Inductor | @register_lowering, decompositions, CustomGraphPass, config.patch, graph breaks | [review-checklist.md](review-checklist.md) | -| FX/Export | PassBase, PassManager, Interpreter, subgraph rewriter, ShapeProp, make_fx | [review-checklist.md](review-checklist.md) | -| Type Promotion | elementwise_dtypes, TensorIterator dtype handling, result_type, promoteTypes | [review-checklist.md](review-checklist.md) | -| Serialization | weights_only, safe_globals, skip_data | [review-checklist.md](review-checklist.md) | -| Distributed | DeviceMesh, distributed testing with MultiThreadedPG | [review-checklist.md](review-checklist.md) | -| Tensor Subclasses | _make_wrapper_subclass, __tensor_flatten__/__unflatten__ | [review-checklist.md](review-checklist.md) | -| Testing | OpInfo, ModuleInfo, device-generic, @dtypes, @parametrize, make_tensor | [review-checklist.md](review-checklist.md) | -| Security | Injection, credentials, input handling | [review-checklist.md](review-checklist.md) | -| Performance | Regressions, device handling, memory, profiling, benchmarking | [review-checklist.md](review-checklist.md) | -| Thread Safety | Data races, GIL assumptions, NoGIL, CPython C API | [review-checklist.md](review-checklist.md) | -| BC | Breaking changes, deprecation | [bc-guidelines.md](bc-guidelines.md) | +Structure your review with actionable feedback organized by category. Every finding should be traceable to a specific line in the diff. ## Output Format @@ -240,6 +190,8 @@ Reference the specific infrastructure the PR should be using.] ### Recommendation **Approve** / **Request Changes** / **Needs Discussion** +Missing tests (new functionality without tests, bug fixes without regression tests) always means **Request Changes**. + [Brief justification for recommendation] ``` @@ -258,20 +210,9 @@ When requested, add file-specific feedback with line references: - `torch/nn/modules/linear.py:78` - This allocation could be moved outside the loop ``` -## Key Principles - -1. **Investigate, don't guess** - When uncertain whether a checklist item applies, spawn a sub-agent to read the relevant infrastructure code. A reviewer who guesses wrong provides negative value. A reviewer who investigates and reports findings provides immense value. -2. **Every line matters** - A single missing `C10_CUDA_KERNEL_LAUNCH_CHECK()`, a single `weights_only=False`, a single missing Composite dispatch key — each of these is a real bug that affects real users. Do not skip lines. -3. **No repetition** - Each observation appears in exactly one section. Never repeat the same issue, concern, or suggestion across multiple sections. If an issue spans categories (e.g., a security issue that also affects performance), place it in the most relevant section only. -4. **Focus on what CI cannot check** - Don't comment on formatting, linting, or type errors -5. **Be specific** - Reference file paths and line numbers. Every finding should point to a concrete line in the diff. -6. **Be actionable** - Provide concrete suggestions with the right infrastructure to use, not vague concerns. If flagging a missing pattern, name the function/class/file the author should use. -7. **Be proportionate** - Minor issues shouldn't block, but note them -8. **Assume competence** - The author knows PyTorch; explain only non-obvious context. The value of this review is in catching infrastructure patterns the author may not know about, not in explaining basic programming. - ## Files to Reference -When reviewing, consult these project files for context. **Spawn sub-agents to read these** rather than relying on memory — the files change frequently: +When reviewing, consult these project files for context — read them rather than relying on memory, as they change frequently: - `CLAUDE.md` - Coding style philosophy and testing patterns - `CONTRIBUTING.md` - PR requirements and review process - `torch/testing/_internal/common_utils.py` - Test patterns and utilities diff --git a/.claude/skills/pr-review/review-checklist.md b/.claude/skills/pr-review/review-checklist.md index 2151f2977b2be..6fa73391074ed 100644 --- a/.claude/skills/pr-review/review-checklist.md +++ b/.claude/skills/pr-review/review-checklist.md @@ -7,7 +7,10 @@ This checklist covers areas that CI cannot check. Skip items related to linting, ### Abstractions and Design - [ ] **Clear abstractions** - State management is explicit; no dynamic attribute setting/getting -- [ ] **Match existing patterns** - Code follows architectural patterns already in the codebase +- [ ] **No side-channel communication** - If behavior changes based on a hidden flag or dynamically-set attribute, the interface itself should change instead (different function signature, different class, different code path). Side-channel patterns (set a private flag in one place, check it in another via `getattr`) create undocumented behavioral modes +- [ ] **Proper interface, not on/off flags** - A private boolean that switches between two fundamentally different behaviors should be two separate code paths or a proper interface change, not a flag +- [ ] **Interface documentation** - New internal calling conventions, protocols, or contracts between components must have concrete documentation: what the caller provides, what the callee receives, what invariants hold, and cleanup responsibilities. Motivational comments ("this allows X") are not interface documentation +- [ ] **Match existing patterns in the same file** - Before accepting new code in a file, read how similar features are already implemented in that same file. If the file uses class attributes for boolean flags, new boolean flags must use class attributes. If the file uses a specific setter pattern, new setters must use the same pattern - [ ] **No over-engineering** - Only requested changes are made; no speculative features - [ ] **No premature abstraction** - Helpers and utilities are only created when reused; three similar lines is better than a one-use helper - [ ] **No trivial helpers** - Avoid 1-2 LOC helper functions used only once (unless significantly improves readability) @@ -36,17 +39,9 @@ When a PR introduces new API patterns, carefully evaluate the broader implicatio - [ ] **No fragile init ordering** - If multiple imports/calls must happen in a specific undocumented order, flag the design. Dependencies should be explicit or combined into a single entry point - [ ] **Idempotent global state** - Registries and global lists that accumulate entries must handle multiple calls safely (no duplicate registration, clear cleanup story) -### Common Issues to Flag - -- Dynamic `setattr`/`getattr` for state management (prefer explicit class members) -- Unused imports, variables, or dead code paths -- Copy-pasted code that could be a shared helper -- Magic numbers without explanation -- Overly defensive error handling for impossible cases - ## PyTorch Infrastructure -When a PR touches code in the scope of any item below, **stop and investigate** whether the established infrastructure should be used. Spawn a sub-agent to read the relevant infrastructure code and determine if the PR should be using it instead of rolling its own solution. +When a PR touches code in the scope of any item below, **stop and investigate** whether the established infrastructure should be used. ### C++ Kernel Infrastructure @@ -55,8 +50,9 @@ When a PR touches code in the scope of any item below, **stop and investigate** - [ ] **Structured Kernels** — PR adds a new ATen operator with separate hand-written functional, inplace, and out= variants instead of using `structured: True` + `structured_delegate` in `native_functions.yaml` to generate boilerplate - [ ] **TORCH_CHECK variants** — PR uses generic `TORCH_CHECK` for conditions that have a more specific variant: `ValueError` → `TORCH_CHECK_VALUE`, `IndexError` → `TORCH_CHECK_INDEX`, `TypeError` → `TORCH_CHECK_TYPE`, `NotImplementedError` → `TORCH_CHECK_NOT_IMPLEMENTED` - [ ] **AT_DISPATCH macros** — PR manually switches on `dtype` with `if (dtype == kFloat) ... else if (dtype == kDouble)` instead of using `AT_DISPATCH_FLOATING_TYPES`, `AT_DISPATCH_ALL_TYPES_AND`, or the `AT_DISPATCH_SWITCH` / `AT_DISPATCH_CASE` pattern from `aten/src/ATen/Dispatch.h` -- [ ] **Device guards (RAII)** — PR manually saves/restores device context (`cudaSetDevice` + try/catch) instead of using `DeviceGuard` or `OptionalDeviceGuard` from `c10/core/DeviceGuard.h` +- [ ] **Device guards (RAII)** — PR manually saves/restores device context (`cudaSetDevice` + try/catch) instead of using `DeviceGuard` or `OptionalDeviceGuard` from `c10/core/DeviceGuard.h`. **Note:** Operators registered in `native_functions.yaml` get automatic `DeviceGuard` insertion from codegen (controlled by `device_guard: True`, the default) — do NOT flag missing device guards for these ops unless they explicitly set `device_guard: False` - [ ] **Memory format propagation** — PR allocates output tensors with `at::empty(shape, options)` (defaulting to contiguous) without calling `input.suggest_memory_format()` to preserve ChannelsLast or other input formats +- [ ] **Subclass-safe tensor allocation** — PR uses `at::empty(shape, input.options())` instead of `input.new_empty(shape)` or `at::empty_like(input)`, which don't propagate tensor subclass metadata - [ ] **TORCH_LIBRARY operator registration** — PR registers operators using manual dispatcher calls instead of `TORCH_LIBRARY` / `TORCH_LIBRARY_IMPL` macros from `torch/library.h` - [ ] **TORCH_WARN_DEPRECATION** — PR uses `TORCH_WARN` for deprecation notices instead of `TORCH_WARN_DEPRECATION` which issues a proper `DeprecationWarning` @@ -164,6 +160,7 @@ When a PR touches code in the scope of any item below, **stop and investigate** ### Test Existence - [ ] **Tests exist** - New functionality has corresponding tests +- [ ] **Regression tests for bug fixes** - Bug fixes must include a test that reproduces the bug before the fix - [ ] **Tests are in the right place** - Tests should be added to an existing test file next to other related tests - [ ] **New test file is rare** - New test file should only be added when new major features are added @@ -189,32 +186,9 @@ When a PR touches code in the scope of any item below, **stop and investigate** ### Test Quality - [ ] **Edge cases covered** - Tests include boundary conditions, empty inputs, error cases -- [ ] **Error conditions tested** - Expected exceptions are tested with `assertRaises` or `assertRaisesRegex` -- [ ] **No duplicated test logic** - Similar tests share a private helper method (e.g., `_test_foo(config)`) called from individual tests with different configs - -**Example of good test structure:** -```python -def _test_feature_with_config(self, flag, expected_shape): - """Shared test logic called by device-specific tests.""" - x = torch.randn(10) - result = my_feature(x, flag) - self.assertEqual(result.shape, expected_shape) - -def test_feature_enabled(self): - self._test_feature_with_config(True, (10, 10)) - -def test_feature_disabled(self): - self._test_feature_with_config(False, (10, 5)) -``` - -### Common Testing Issues - -- Tests that only check the happy path without error cases -- Duplicated test code that should be a parameterized helper -- Manual operator tests that duplicate existing OpInfo coverage — the fix should update the OpInfo's dtype list instead -- Tests that don't clean up resources (files, CUDA memory) -- Flaky tests (timing-dependent, order-dependent, golden value) -- Tests that skip without clear justification +- [ ] **Error conditions tested** - Expected exceptions are tested with `assertRaisesRegex`, not bare `assertRaises`. `assertRaisesRegex` verifies both the exception type and message, catching cases where the right exception is raised for the wrong reason. Bare `assertRaises` should be flagged — always require a message pattern match +- [ ] **No duplicated test logic** - Similar tests share a private helper method called from individual tests with different configs +- [ ] **Use weakref for lifetime testing** - PR uses `sys.getrefcount()` to test whether objects are kept alive. Use `weakref.ref()` instead — create a weak reference, delete the strong references, then check if the weakref is dead (`wr() is None`). `sys.getrefcount` is a CPython implementation detail that varies across versions and is fragile ## Security @@ -258,7 +232,7 @@ This is particularly important for PyTorch's autograd, which has multi-threaded - [ ] **GIL held for Python object access** - Any code that touches `PyObject*` (incref, decref, attribute access, container mutation) must hold the GIL. When releasing the GIL for long-running C++ work (`Py_BEGIN_ALLOW_THREADS`), verify no Python objects are accessed in that region - [ ] **Borrowed references across GIL release** - Borrowed references (`PyTuple_GET_ITEM`, `PyList_GET_ITEM`) become unsafe if the GIL is released and reacquired, since another thread may have mutated the container -- [ ] **Decref-before-update hazard** - When replacing an item in a container (tuple, list, dict), update the container slot first, then `Py_DECREF` the old value. Decref can trigger `__del__` finalizers that re-enter and observe the container in an inconsistent state. Without the GIL (free-threaded builds), this is also a data race. +- [ ] **Decref-before-update hazard** - When replacing an item in a container (tuple, list, dict), update the container slot first, then `Py_DECREF` the old value. Decref can trigger `__del__` finalizers that re-enter and observe the container in an inconsistent state. Without the GIL (free-threaded builds), this is also a data race. This is **always** a must-fix — even if "safe in practice" because of refcount guarantees, the pattern is wrong and breaks under NoGIL. The correct pattern costs nothing extra ### Free-Threaded Python (NoGIL, PEP 703) @@ -298,10 +272,3 @@ CPython 3.13t+ can run without the GIL. Code that was previously safe under the - [ ] **Use torch.profiler** - PR adds manual `time.time()` instrumentation instead of using `torch.profiler.profile()` context manager with `schedule()` and `tensorboard_trace_handler()` - [ ] **Use torch.utils.benchmark.Timer** - PR benchmarks with `time.time()` loops instead of `torch.utils.benchmark.Timer` which handles warmup, statistics, and proper CUDA synchronization - -### Common Performance Issues - -- Creating new tensors inside training loops instead of pre-allocating -- Synchronous CUDA operations where async would work -- Keeping computation graph alive longer than needed -- Redundant clones or copies diff --git a/.claude/skills/pt2-bug-basher/SKILL.md b/.claude/skills/pt2-bug-basher/SKILL.md index 8de7587d56f99..5c30f9002f79c 100644 --- a/.claude/skills/pt2-bug-basher/SKILL.md +++ b/.claude/skills/pt2-bug-basher/SKILL.md @@ -10,17 +10,19 @@ Debug test failures and runtime errors in the PyTorch 2 compiler stack (Dynamo, ## Workflow Summary -1. **Reproduce** -- Get a consistent reproduction of the failure -2. **Minimize** -- Reduce the repro to the smallest possible standalone case. Strip away unrelated model logic, use minimal tensor shapes, and isolate the specific op or pattern that triggers the bug. -3. **Add a unit test** -- **Do this BEFORE diving into code search or root cause investigation.** Add a failing test to the codebase that captures the bug. Place it in a specific, topic-appropriate test file (e.g., `test/dynamo/test_repros.py`, `test/inductor/test_torchinductor.py`, `test/export/test_export.py`). **Avoid `test/dynamo/test_misc.py`** — it is already oversized; find a more specific test file that matches the area of the bug. Use `torch.testing._internal.common_utils.TestCase` and `run_tests`. The test must fail before the fix and pass after. Having the test first keeps you grounded — you know exactly what "fixed" looks like before you start exploring the codebase. -4. **Gather logs** -- Run with appropriate `TORCH_LOGS` settings -5. **Classify** -- Use the [Error Triage](#error-triage) table to identify the category -6. **Inspect artifacts** -- Check FX graphs, IR, and generated code via `TORCH_COMPILE_DEBUG=1` -7. **Identify root cause** -- Trace from the error back through the compilation pipeline -8. **Fix** -- Apply the fix -9. **Verify** -- Run the new unit test AND nearby related existing tests (e.g., if you changed how `is_exporting` works, also run the existing `test_is_exporting` export test). Use `pytest -k` to quickly run related tests by name. The task is not complete until all pass. -10. **Self-review** -- Use the `/pr-review` skill to review your own changes before presenting them. Fix any issues it flags. -11. **Celebrate** -- Summarize the changes: explain the root cause, what was changed and why, and which tests were added/verified. Then tell the user the bug is squashed. Include a fun, varied motivational message or easter egg to keep spirits high (e.g., a pun, a quote, an ASCII art bug getting squashed). Keep it short and different each time. +1. **Environment check** -- Ask the user which conda environment to use. Verify it is active by checking `$CONDA_DEFAULT_ENV`. Then run `python -c "import torch; print(torch.__version__)"` to confirm torch is importable and report the version. If the environment is not active or torch cannot be imported, stop and ask the user to activate the correct environment before proceeding. +2. **Reproduce** -- Get a consistent reproduction of the failure +3. **Minimize** -- Reduce the repro to the smallest possible standalone case. Strip away unrelated model logic, use minimal tensor shapes, and isolate the specific op or pattern that triggers the bug. +4. **Add a unit test** -- **Do this BEFORE diving into code search or root cause investigation.** Add a failing test to the codebase that captures the bug. Place it in a specific, topic-appropriate test file (e.g., `test/dynamo/test_repros.py`, `test/inductor/test_torchinductor.py`, `test/export/test_export.py`). **Avoid `test/dynamo/test_misc.py`** — it is already oversized; find a more specific test file that matches the area of the bug. Use `torch.testing._internal.common_utils.TestCase` and `run_tests`. The test must fail before the fix and pass after. Having the test first keeps you grounded — you know exactly what "fixed" looks like before you start exploring the codebase. +5. **Validate on main** -- Use `EnterWorktree` to create a worktree checked out at `main`. Copy the new test file into the worktree and run the test there to confirm it **fails** on main. If the test passes on main, stop — the test may not be capturing the right bug, or the bug may already be fixed. Exit the worktree with `ExitWorktree` (action: remove) and return to the working branch before continuing. +6. **Gather logs** -- Run with appropriate `TORCH_LOGS` settings +7. **Classify** -- Use the [Error Triage](#error-triage) table to identify the category +8. **Inspect artifacts** -- Check FX graphs, IR, and generated code via `TORCH_COMPILE_DEBUG=1` +9. **Identify root cause** -- Trace from the error back through the compilation pipeline +10. **Fix** -- Apply the fix +11. **Verify** -- Run the new unit test AND nearby related existing tests (e.g., if you changed how `is_exporting` works, also run the existing `test_is_exporting` export test). Use `pytest -k` to quickly run related tests by name. The task is not complete until all pass. +12. **Self-review** -- Use the `/pr-review` skill to review your own changes before presenting them. Fix any issues it flags. +13. **Celebrate** -- Summarize the changes: explain the root cause, what was changed and why, and which tests were added/verified. Then tell the user the bug is squashed. Include a fun, varied motivational message or easter egg to keep spirits high (e.g., a pun, a quote, an ASCII art bug getting squashed). Keep it short and different each time. ## Investigation Strategy diff --git a/.claude/skills/triaging-issues/SKILL.md b/.claude/skills/triaging-issues/SKILL.md index 9cc741540a9f4..e774d6dd40cb6 100644 --- a/.claude/skills/triaging-issues/SKILL.md +++ b/.claude/skills/triaging-issues/SKILL.md @@ -156,7 +156,7 @@ If the issue belongs in another repo (vision/text/audio/RL/ExecuTorch/etc.), tra **PT2 is NOT a redirect.** `oncall: pt2` is not like the other oncall labels in Step 3. PT2 issues continue through Steps 4–7 for full triage — add `oncall: pt2`, then proceed to label with `module:` labels, mark `triaged`, etc. -See [pt2-triage-rubric.md](pt2-triage-rubric.md) for detailed labeling decisions on which `module:` labels to apply. +**Every `oncall: pt2` issue MUST have at least one `module:` label.** The PT2 oncall queue is too broad without a module label — the team needs to know which component is affected (e.g., `module: dynamo`, `module: inductor`, `module: helion`, `module: dynamic shapes`). If you cannot determine the specific module, use `module: compile ux` as a fallback, but always try to be specific first. See [pt2-triage-rubric.md](pt2-triage-rubric.md) for detailed guidance. ### 3) Redirect to Secondary Oncall diff --git a/.claude/skills/triaging-issues/labels.json b/.claude/skills/triaging-issues/labels.json index 2e4191075a92c..fa3703186472b 100644 --- a/.claude/skills/triaging-issues/labels.json +++ b/.claude/skills/triaging-issues/labels.json @@ -105,7 +105,7 @@ }, { "name": "module: aotdispatch", - "description": "umbrella label for AOTAutograd issues" + "description": "Issues in the AOTAutograd subsystem (torch/_functorch/aot_autograd.py, torch/_functorch/_aot_autograd/). Covers tracing of the forward+backward graph, partitioning, runtime wrappers, and autograd cache serialization. If a traceback includes _aot_autograd/runtime_wrappers.py or _aot_autograd/functional_utils.py, this label applies. Prefer this over 'module: inductor' when the crash is in the autograd dispatch layer, not in code generation." }, { "name": "module: aotinductor", @@ -117,7 +117,7 @@ }, { "name": "module: autograd", - "description": "Related to torch.autograd, and the autograd engine in general" + "description": "Issues in the eager-mode autograd engine (torch/autograd/, torch/csrc/autograd/). Covers gradient computation, autograd functions, hooks, and the backward pass in eager mode. Do NOT apply when the issue is about torch.compile silently dropping or mishandling an autograd operation — if the bug only reproduces under torch.compile (not in eager), the root cause is likely in Dynamo tracing ('module: dynamo') or AOTAutograd ('module: aotdispatch'), not in the eager autograd engine." }, { "name": "module: backend", @@ -177,7 +177,7 @@ }, { "name": "module: compiled autograd", - "description": "compiled_autograd" + "description": "Issues in the compiled autograd subsystem (torch/csrc/autograd/compiled_autograd.*, torch/_dynamo/compiled_autograd.py). This is the system that compiles the autograd backward graph with Dynamo+Inductor. Do NOT apply just because 'autograd' appears in the issue — use 'module: aotdispatch' for AOTAutograd tracing/partitioning issues, and 'module: autograd' for eager autograd issues." }, { "name": "module: complex", @@ -269,7 +269,7 @@ }, { "name": "module: decompositions", - "description": "Topics related to decomposition (excluding PrimTorch)" + "description": "Issues with operator decompositions (torch/_decomp/). Decompositions rewrite higher-level ops into lower-level ones for tracing/compilation. Apply when traced/compiled results diverge from eager due to a decomposition producing wrong gradients, numerically incorrect results, or missing higher-order derivative support. Key signal: 'eager vs traced results diverge' + a specific op (e.g., sqrt, layer_norm) suggests a bad decomposition." }, { "name": "module: dependency bug", @@ -345,7 +345,7 @@ }, { "name": "module: dynamo", - "description": "Related to TorchDynamo (torch.compile frontend/tracing)" + "description": "Issues in TorchDynamo, the torch.compile frontend/tracing system (torch/_dynamo/). Covers graph capture, graph breaks, symbolic tracing of Python bytecode, and variable tracking. Key signal: if torch.compile silently drops, ignores, or mishandles an operation that works correctly in eager mode (e.g., in-place mutations like detach_() being silently skipped, side-effects not captured), the root cause is almost always Dynamo's tracing." }, { "name": "module: edge cases", @@ -365,7 +365,7 @@ }, { "name": "module: error checking", - "description": "Bugs related to incorrect/lacking error checking" + "description": "Bugs where PyTorch fails to validate inputs or raises a wrong/missing error message at a validation boundary (e.g., shape checks, dtype checks, device checks). Do NOT apply just because the user complains about a bad error message — if the root cause is a bug in a specific feature, use the feature-specific label instead." }, { "name": "module: expecttest", @@ -417,7 +417,7 @@ }, { "name": "module: fx", - "description": "" + "description": "Issues in torch.fx (torch/fx/) — the graph IR, symbolic tracing, graph transforms, and make_fx. Note: make_fx with symbolic tracing is PT2 infrastructure, so issues involving make_fx should also get 'oncall: pt2'." }, { "name": "module: fx.passes", @@ -437,7 +437,7 @@ }, { "name": "module: helion", - "description": "" + "description": "Helion (pytorch/helion) is a high-level Python DSL for writing GPU kernels that compiles to Triton. Users write kernels with @helion.kernel decorator and helion.language (hl.tile(), etc.) using familiar PyTorch operations. Helion kernels integrate with torch.compile/inductor via template fusion hooks. Apply when the issue involves helion kernel authoring, helion compilation, helion-to-Triton codegen, or helion+inductor template fusion. Do NOT apply for plain Triton kernel issues (use module: inductor) or general torch.compile issues unless helion is specifically involved." }, { "name": "module: higher order operators", @@ -453,7 +453,7 @@ }, { "name": "module: inductor", - "description": "Related to TorchInductor (torch.compile codegen backend)" + "description": "Issues in the TorchInductor code generation backend (torch/_inductor/). Covers kernel codegen, scheduling, fusion, and inductor-specific caching (TORCHINDUCTOR_* env vars). Do NOT apply just because the user mentions 'inductor cache' or 'torch.compile' — if the traceback points to _aot_autograd/ or _functorch/, use 'module: aotdispatch' instead." }, { "name": "module: infallible views", @@ -769,7 +769,7 @@ }, { "name": "module: pt2-dispatcher", - "description": "PT2 dispatcher-related issues (e.g., aotdispatch, functionalization, faketensor, custom-op," + "description": "Umbrella label for PT2 dispatcher-layer issues: aotdispatch, functionalization, FakeTensor, custom-op registration, proxy tensor tracing. Apply alongside the more specific label (e.g., 'module: aotdispatch') when the issue touches the dispatcher layer. If unsure whether a PT2 issue is inductor vs dispatcher, check the traceback: _functorch/ and _aot_autograd/ paths mean dispatcher, _inductor/ paths mean inductor." }, { "name": "module: pybind", @@ -789,7 +789,7 @@ }, { "name": "module: python version", - "description": "Issues related to specific Python versions" + "description": "Issues where behavior differs across Python versions or where a specific Python version causes breakage (e.g., 'works on 3.11 but fails on 3.12'). Do NOT apply just because the reporter mentions a Python version in their environment info, test path, or conformance test name — the bug must be specifically about Python version compatibility." }, { "name": "module: pytree", diff --git a/.claude/skills/triaging-issues/pt2-triage-rubric.md b/.claude/skills/triaging-issues/pt2-triage-rubric.md index 48eaf250fe89b..967725380f7be 100644 --- a/.claude/skills/triaging-issues/pt2-triage-rubric.md +++ b/.claude/skills/triaging-issues/pt2-triage-rubric.md @@ -2,6 +2,8 @@ This rubric guides labeling decisions for PT2 oncall triage. +**Every `oncall: pt2` issue MUST have at least one `module:` label.** The PT2 queue is too broad without one — the team needs to know which component is affected. Use the sections below to determine the right module label(s). + ## 1. Component Isolation - Be Precise, Don't Over-Tag ### Dynamo vs Dynamic Shapes @@ -33,6 +35,29 @@ When component isn't clear from the issue body: **This is critical** when you have identified an issue as inductor, and the failing device is "cpu" ONLY, then this is a CPU inductor issue, and should be redirected to `oncall: cpu inductor` +### Silently Dropped Operations = Dynamo + +If torch.compile silently drops or ignores an operation that works in eager, the bug is in Dynamo's tracing. + +| Signal | Label | +|--------|-------| +| In-place mutation skipped under compile (`detach_()`, `requires_grad_()`) | `module: dynamo` | +| Side-effect not captured (global state, tensor metadata) | `module: dynamo` | + +Don't apply `module: autograd` just because the dropped operation involves autograd. If eager works fine, the autograd engine is fine. + +### Decomposition Bugs + +If eager and traced results diverge numerically for a specific op, suspect a bad decomposition. + +| Signal | Label | +|--------|-------| +| Eager vs traced diverge for a specific op | `module: decompositions` | +| Higher-order gradients wrong under tracing | `module: decompositions` | +| `make_fx` symbolic tracing diverges from eager | `module: fx` + `module: decompositions` + `oncall: pt2` | + +--- + ## 3. Don't Over-Tag pt2-dispatcher `module: pt2-dispatcher` is for bugs **IN** the dispatcher code, not just when it appears in a stack trace. @@ -112,7 +137,34 @@ Check for existing labels before inventing categories: --- -## 7. functorch + compile +## 7. Helion Kernel Issues + +[Helion](https://github.com/pytorch/helion) is a high-level Python DSL for writing GPU kernels. Users write kernels with `@helion.kernel` and `helion.language` (`hl.tile()`, etc.) using standard PyTorch ops, and Helion compiles them to Triton. Helion kernels can also be fused into `torch.compile` graphs via inductor's template fusion hooks. + +### Identifying Helion Issues + +| Signal | Label | +|--------|-------| +| Issue mentions `helion`, `@helion.kernel`, `helion.language`, `hl.tile` | `module: helion` | +| Error traceback includes `helion/` or `helion.` frames | `module: helion` | +| Helion kernel produces wrong results (standalone, no torch.compile) | `module: helion` only | +| Helion kernel fails or miscompiles under `torch.compile` | `module: helion` + `module: inductor` | +| Inductor template fusion bug triggered by a Helion template | `module: helion` + `module: inductor` | + +### Routing + +- Helion issues get `oncall: pt2` — Helion is a PT2 component. +- If the bug is purely in Helion's own compilation (standalone kernel, not under torch.compile), apply `module: helion` without `module: inductor`. +- If the bug is in how inductor fuses or emits Helion templates, apply both `module: helion` and `module: inductor`. + +### Common Mistakes + +- **Don't confuse Helion with raw Triton**: If the user is writing Triton kernels directly (using `@triton.jit`, `tl.load`, etc.) without Helion, that's `module: inductor`, not `module: helion`. +- **Don't apply `module: helion` for general inductor codegen bugs**: Just because inductor generates Triton code doesn't make it a Helion issue. Helion is a specific DSL — look for explicit `helion` imports or mentions. + +--- + +## 8. Functorch + Compile | Situation | Labels | |-----------|--------| @@ -121,7 +173,7 @@ Check for existing labels before inventing categories: --- -## 8. High Priority Criteria +## 9. High Priority Criteria **This is critical** You should not explicitly add `high priority` - add `triage review` instead so that it is reviewed at the next triage meeting by the oncall. @@ -139,7 +191,7 @@ Mark `triage review` if ANY of these apply: --- -## 9. Fuzzer Issues +## 10. Fuzzer Issues For `topic: fuzzer` issues: @@ -151,7 +203,7 @@ For `topic: fuzzer` issues: --- -## 10. Quick Label Reference +## 11. Quick Label Reference ### Core Components - `module: dynamo` - Tracing, bytecode, graph breaks @@ -160,6 +212,7 @@ For `topic: fuzzer` issues: - `module: pt2-dispatcher` - AOT autograd, functionalization, FakeTensor - `module: cuda graphs` - CUDA graph capture/replay - `module: flex attention` - Flex attention API +- `module: helion` - Helion DSL kernel authoring, compilation, and inductor fusion ### Holistic Areas - `module: compile ux` - Error messages, APIs, programming model diff --git a/.claude/skills/triaging-issues/scripts/validate_labels.py b/.claude/skills/triaging-issues/scripts/validate_labels.py index cd2b4f22e5098..df489189a20b5 100755 --- a/.claude/skills/triaging-issues/scripts/validate_labels.py +++ b/.claude/skills/triaging-issues/scripts/validate_labels.py @@ -208,12 +208,18 @@ def main(): except json.JSONDecodeError as e: debug_log(f"JSON decode error: {e}") print(f"Hook error: Invalid JSON input: {e}", file=sys.stderr) - print("Hook was unable to validate labels; stopping triage.", file=sys.stderr) + print( + "Hook was unable to validate labels; stopping triage.", + file=sys.stderr, + ) sys.exit(2) except Exception as e: debug_log(f"Unexpected error: {type(e).__name__}: {e}") print(f"Hook error: {e}", file=sys.stderr) - print("Hook was unable to validate labels; stopping triage.", file=sys.stderr) + print( + "Hook was unable to validate labels; stopping triage.", + file=sys.stderr, + ) sys.exit(2) diff --git a/.flake8 b/.flake8 index 12d7432bbb7c6..f5507047d7f39 100644 --- a/.flake8 +++ b/.flake8 @@ -12,13 +12,21 @@ ignore = # to line this up with executable bit EXE001, # these ignores are from flake8-bugbear; please fix! - B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910 + B007,B008,B017,B019,B023,B028,B903,B905,B906,B907,B908,B910, # these ignores are from flake8-simplify. please fix or ignore with commented reason SIM105,SIM108,SIM110,SIM111,SIM113,SIM114,SIM115,SIM116,SIM117,SIM118,SIM119,SIM12, # SIM104 is already covered by pyupgrade ruff SIM104, # flake8-simplify code styles - SIM102,SIM103,SIM106,SIM112 + SIM102,SIM103,SIM106,SIM112, + # Codes where ruff is the source of truth. Ruff handles these (sometimes + # differently than flake8). As ruff promotes preview codes, move them here + # from ruff's external list. + B001,B036,B902,B904,B950, + C406,C417,C419,C901, + E121,E122,E124,E128,E131,E712,E722,E723,E731, + F401,F722,F723,F811,F812, + G001,G003,G004,G010,G200,G201,G202 per-file-ignores = __init__.py: F401 test/**: F821 diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 91c35a03948b5..1dc9f3f35c2fd 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -58,7 +58,6 @@ self-hosted-runner: - linux.rocm.gpu.2 - linux.rocm.gpu.4 - linux.rocm.mi210.docker-cache - - linux.rocm.mi250.docker-cache # gfx942 runners - linux.rocm.gpu.gfx942.1 - linux.rocm.gpu.gfx942.4 @@ -68,10 +67,13 @@ self-hosted-runner: # Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors) - macos-m1-stable - macos-m1-14 + - macos-m2-15 + - macos-m2-26 # GitHub-hosted MacOS runners - macos-latest-xlarge - macos-13-xlarge - macos-14-xlarge + - macos-26-xlarge # Organization-wide Intel hosted XPU runners - linux.idc.xpu # Organization-wide Google Cloud TPU runners diff --git a/.github/actions/checkout-pytorch/action.yml b/.github/actions/checkout-pytorch/action.yml index 3ea7d295fb006..97303742c499c 100644 --- a/.github/actions/checkout-pytorch/action.yml +++ b/.github/actions/checkout-pytorch/action.yml @@ -60,6 +60,7 @@ runs: # --depth=1 for speed, manually fetch history and other refs as necessary fetch-depth: ${{ inputs.fetch-depth }} single-branch: true + fetch-tags: ${{ github.ref_type == 'tag' }} submodules: ${{ inputs.submodules }} show-progress: false filter: ${{ env.filter }} @@ -107,6 +108,7 @@ runs: ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} fetch-depth: ${{ inputs.fetch-depth }} single-branch: true + fetch-tags: ${{ github.ref_type == 'tag' }} submodules: ${{ inputs.submodules }} show-progress: false filter: ${{ env.filter }} diff --git a/.github/actions/download-build-artifacts/action.yml b/.github/actions/download-build-artifacts/action.yml index c44b6a4083448..ea78440e87686 100644 --- a/.github/actions/download-build-artifacts/action.yml +++ b/.github/actions/download-build-artifacts/action.yml @@ -18,14 +18,14 @@ runs: using: composite steps: - name: Download PyTorch Build Artifacts from S3 - if: ${{ !inputs.use-gha }} + if: inputs.use-gha == '' || inputs.use-gha == 'false' uses: seemethere/download-artifact-s3@v4 with: name: ${{ inputs.name }} s3-bucket: ${{ inputs.s3-bucket }} - name: Download PyTorch Build Artifacts from GHA - if: ${{ inputs.use-gha }} + if: inputs.use-gha != '' && inputs.use-gha != 'false' uses: actions/download-artifact@v4 with: name: ${{ inputs.name }} diff --git a/.github/actions/download-td-artifacts/action.yml b/.github/actions/download-td-artifacts/action.yml index 18766bf670f63..3b25f14ef6520 100644 --- a/.github/actions/download-td-artifacts/action.yml +++ b/.github/actions/download-td-artifacts/action.yml @@ -11,13 +11,13 @@ runs: using: composite steps: - name: Download TD Artifacts from S3 - if: ${{ !inputs.use-gha }} + if: inputs.use-gha == '' || inputs.use-gha == 'false' uses: seemethere/download-artifact-s3@v4 with: name: td_results - name: Download TD Artifacts from GHA - if: inputs.use-gha + if: inputs.use-gha != '' && inputs.use-gha != 'false' uses: actions/download-artifact@v4 with: name: td_results.json diff --git a/.github/actions/filter-test-configs/action.yml b/.github/actions/filter-test-configs/action.yml index 8afbc40cebeaa..36ccd839f26be 100644 --- a/.github/actions/filter-test-configs/action.yml +++ b/.github/actions/filter-test-configs/action.yml @@ -61,6 +61,7 @@ runs: using: composite steps: - name: Setup uv + if: ${{ !env.UV_PYTHON }} uses: pytorch/test-infra/.github/actions/setup-uv@main with: python-version: "3.12" diff --git a/.github/actions/linux-test/action.yml b/.github/actions/linux-test/action.yml deleted file mode 100644 index 8ad74e308168a..0000000000000 --- a/.github/actions/linux-test/action.yml +++ /dev/null @@ -1,406 +0,0 @@ -name: linux-test - -inputs: - build-environment: - required: true - type: string - description: Top-level label for what's being built/tested. - test-matrix: - required: true - type: string - description: JSON description of what test configs to run. - docker-image: - required: true - type: string - description: Docker image to run in. - sync-tag: - required: false - type: string - default: "" - description: | - If this is set, our linter will use this to make sure that every other - job with the same `sync-tag` is identical. - use-gha: - required: false - type: string - default: "" - description: If set to any value, upload to GHA. Otherwise upload to S3. - dashboard-tag: - required: false - type: string - default: "" - s3-bucket: - description: S3 bucket to download artifact - required: false - type: string - default: "gha-artifacts" - aws-role-to-assume: - description: role to assume for downloading artifacts - required: false - type: string - default: "" - HUGGING_FACE_HUB_TOKEN: - description: | - HF Auth token to avoid rate limits when downloading models or datasets from hub - required: false - default: "" - GITHUB_TOKEN: - description: GitHub token - required: true - disable-monitor: - description: | - [Experimental] Disable utilization monitoring for tests. - Currently, by default we disable the monitor job and only look for specific tests, - since we are investigating the behaviour of the monitor script with different tests. - required: false - type: boolean - default: true -#env: -# GIT_DEFAULT_BRANCH: ${{ inputs.default_branch }} - -runs: - using: composite - steps: - - name: Setup Linux - uses: ./.github/actions/setup-linux - - - name: Login to ECR - uses: ./.github/actions/ecr-login - with: - aws-role-to-assume: ${{ inputs.aws-role-to-assume }} - - - name: Calculate docker image - id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@main - with: - docker-image-name: ${{ inputs.docker-image }} - - - name: Use following to pull public copy of the image - id: print-ghcr-mirror - env: - ECR_DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - shell: bash - run: | - tag=${ECR_DOCKER_IMAGE##*/} - echo "docker pull ghcr.io/pytorch/ci-image:${tag/:/-}" - - - name: Pull docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@main - with: - docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - - - name: Check if in a container runner - shell: bash - id: check_container_runner - run: echo "IN_CONTAINER_RUNNER=$(if [ -f /.inarc ] || [ -f /.incontainer ]; then echo true ; else echo false; fi)" >> "$GITHUB_OUTPUT" - - - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG - id: install-nvidia-driver - uses: pytorch/test-infra/.github/actions/setup-nvidia@main - if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' }} - - - name: Setup GPU_FLAG for docker run - id: setup-gpu-flag - run: echo "GPU_FLAG=--gpus all -e NVIDIA_DRIVER_CAPABILITIES=all" >> "${GITHUB_ENV}" - if: ${{ contains(inputs.build-environment, 'cuda') && !contains(matrix.config, 'nogpu') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' }} - - - name: Setup SCCACHE_SERVER_PORT environment for docker run when on container - id: setup-sscache-port-flag - run: echo "SCCACHE_SERVER_PORT_DOCKER_FLAG=-e SCCACHE_SERVER_PORT=$((RUNNER_UID + 4226))" >> "${GITHUB_ENV}" - if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' }} - - - name: Lock NVIDIA A100 40GB Frequency - shell: bash - run: | - sudo nvidia-smi -pm 1 - sudo nvidia-smi -ac 1215,1410 - nvidia-smi - if: ${{ contains(matrix.runner, 'a100') && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' }} - - - name: Start monitoring script - id: monitor-script - if: ${{ !inputs.disable-monitor }} - shell: bash - continue-on-error: true - run: | - python3 -m pip install psutil==5.9.8 nvidia-ml-py==11.525.84 - python3 -m tools.stats.monitor > usage_log.txt 2>&1 & - echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" - - - name: Download build artifacts - uses: ./.github/actions/download-build-artifacts - with: - name: ${{ inputs.build-environment }} - s3-bucket: ${{ inputs.s3-bucket }} - - - name: Download TD artifacts - continue-on-error: true - uses: ./.github/actions/download-td-artifacts - - - name: Parse ref - id: parse-ref - shell: bash - run: .github/scripts/parse_ref.py - - - name: Get workflow job id - id: get-job-id - uses: ./.github/actions/get-workflow-job-id - if: always() - with: - github-token: ${{ inputs.GITHUB_TOKEN }} - - - name: Check for keep-going label and re-enabled test issues - # This uses the filter-test-configs action because it conveniently - # checks for labels and re-enabled test issues. It does not actually do - # any filtering. All filtering is done in the build step. - id: keep-going - uses: ./.github/actions/filter-test-configs - with: - github-token: ${{ inputs.GITHUB_TOKEN }} - test-matrix: ${{ inputs.test-matrix }} - job-name: ${{ steps.get-job-id.outputs.job-name }} - - - name: Test - id: test - env: - BUILD_ENVIRONMENT: ${{ inputs.build-environment }} - PR_NUMBER: ${{ github.event.pull_request.number }} - GITHUB_REPOSITORY: ${{ github.repository }} - GITHUB_WORKFLOW: ${{ github.workflow }} - GITHUB_JOB: ${{ github.job }} - GITHUB_RUN_ID: ${{ github.run_id }} - GITHUB_RUN_NUMBER: ${{ github.run_number }} - GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }} - JOB_ID: ${{ steps.get-job-id.outputs.job-id }} - JOB_NAME: ${{ steps.get-job-id.outputs.job-name }} - BRANCH: ${{ steps.parse-ref.outputs.branch }} - SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - BASE_SHA: ${{ github.event.pull_request.base.sha || github.sha }} - TEST_CONFIG: ${{ matrix.config }} - SHARD_NUMBER: ${{ matrix.shard }} - NUM_TEST_SHARDS: ${{ matrix.num_shards }} - REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }} - CONTINUE_THROUGH_ERROR: ${{ steps.keep-going.outputs.keep-going }} - VERBOSE_TEST_LOGS: ${{ steps.keep-going.outputs.ci-verbose-test-logs }} - TEST_SHOWLOCALS: ${{ steps.keep-going.outputs.ci-test-showlocals }} - NO_TEST_TIMEOUT: ${{ steps.keep-going.outputs.ci-no-test-timeout }} - NO_TD: ${{ steps.keep-going.outputs.ci-no-td }} - TD_DISTRIBUTED: ${{ steps.keep-going.outputs.ci-td-distributed }} - SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 - SCCACHE_REGION: us-east-1 - SCCACHE_S3_KEY_PREFIX: ${{ github.workflow }} - SHM_SIZE: ${{ contains(inputs.build-environment, 'cuda') && '2g' || '1g' }} - DOCKER_IMAGE: ${{ inputs.docker-image }} - XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} - XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla - PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} - PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} - DASHBOARD_TAG: ${{ inputs.dashboard-tag }} - HUGGING_FACE_HUB_TOKEN: ${{ inputs.HUGGING_FACE_HUB_TOKEN }} - SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} - IS_A100_RUNNER: ${{ contains(matrix.runner, 'a100') && '1' || '0' }} - - shell: bash - run: | - set -x - - if [[ $TEST_CONFIG == 'multigpu' ]]; then - TEST_COMMAND=.ci/pytorch/multigpu-test.sh - elif [[ $BUILD_ENVIRONMENT == *onnx* ]]; then - TEST_COMMAND=.ci/onnx/test.sh - else - TEST_COMMAND=.ci/pytorch/test.sh - fi - - # detached container should get cleaned up by teardown_ec2_linux - # TODO: Stop building test binaries as part of the build phase - # Used for GPU_FLAG since that doesn't play nice - # shellcheck disable=SC2086,SC2090 - container_name=$(docker run \ - ${GPU_FLAG:-} \ - ${SCCACHE_SERVER_PORT_DOCKER_FLAG:-} \ - -e BUILD_ENVIRONMENT \ - -e PR_NUMBER \ - -e GITHUB_ACTIONS \ - -e GITHUB_REPOSITORY \ - -e GITHUB_WORKFLOW \ - -e GITHUB_JOB \ - -e GITHUB_RUN_ID \ - -e GITHUB_RUN_NUMBER \ - -e GITHUB_RUN_ATTEMPT \ - -e JOB_ID \ - -e JOB_NAME \ - -e BASE_SHA \ - -e BRANCH \ - -e SHA1 \ - -e AWS_DEFAULT_REGION \ - -e IN_WHEEL_TEST \ - -e SHARD_NUMBER \ - -e TEST_CONFIG \ - -e NUM_TEST_SHARDS \ - -e REENABLED_ISSUES \ - -e CONTINUE_THROUGH_ERROR \ - -e VERBOSE_TEST_LOGS \ - -e NO_TEST_TIMEOUT \ - -e NO_TD \ - -e TD_DISTRIBUTED \ - -e PR_LABELS \ - -e MAX_JOBS="$(nproc --ignore=2)" \ - -e SCCACHE_BUCKET \ - -e SCCACHE_REGION \ - -e SCCACHE_S3_KEY_PREFIX \ - -e XLA_CUDA \ - -e XLA_CLANG_CACHE_S3_BUCKET_NAME \ - -e PYTORCH_TEST_CUDA_MEM_LEAK_CHECK \ - -e PYTORCH_TEST_RERUN_DISABLED_TESTS \ - -e SKIP_SCCACHE_INITIALIZATION=1 \ - -e HUGGING_FACE_HUB_TOKEN \ - -e SCRIBE_GRAPHQL_ACCESS_TOKEN \ - -e DASHBOARD_TAG \ - -e IS_A100_RUNNER \ - --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ - --security-opt seccomp=unconfined \ - --cap-add=SYS_PTRACE \ - --ipc=host \ - --shm-size="${SHM_SIZE}" \ - --tty \ - --detach \ - --name="${container_name}" \ - --user jenkins \ - -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ - -w /var/lib/jenkins/workspace \ - "${DOCKER_IMAGE}" - ) - echo "DOCKER_CONTAINER_ID=${container_name}" >> "${GITHUB_ENV}" - docker exec -t "${container_name}" sh -c "pip install $(echo dist/*.whl)[opt-einsum] && ${TEST_COMMAND}" - - - name: Upload pytest cache if tests failed - uses: ./.github/actions/pytest-cache-upload - continue-on-error: true - if: failure() && steps.test.conclusion && steps.test.conclusion == 'failure' - with: - cache_dir: .pytest_cache - shard: ${{ matrix.shard }} - sha: ${{ github.event.pull_request.head.sha || github.sha }} - test_config: ${{ matrix.config }} - job_identifier: ${{ github.workflow }}_${{ inputs.build-environment }} - - - name: Print remaining test logs - shell: bash - if: always() && steps.test.conclusion - run: | - cat test/**/*_toprint.log || true - - - name: Stop monitoring script - if: ${{ always() && steps.monitor-script.outputs.monitor-script-pid }} - shell: bash - continue-on-error: true - env: - MONITOR_SCRIPT_PID: ${{ steps.monitor-script.outputs.monitor-script-pid }} - run: | - kill "$MONITOR_SCRIPT_PID" - - - name: Upload test artifacts - uses: ./.github/actions/upload-test-artifacts - if: always() && steps.test.conclusion && steps.test.conclusion != 'skipped' - with: - file-suffix: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.get-job-id.outputs.job-id }} - use-gha: ${{ inputs.use-gha }} - s3-bucket: ${{ inputs.s3-bucket }} - - - name: Collect backtraces from coredumps (if any) - if: always() - shell: bash - run: | - # shellcheck disable=SC2156 - find . -iname "core.[1-9]*" -exec docker exec "${DOCKER_CONTAINER_ID}" sh -c "gdb python {} -ex 'bt' -ex 'q'" \; - - - name: Store Core dumps on S3 - uses: seemethere/upload-artifact-s3@v5 - if: failure() - with: - name: coredumps-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }} - retention-days: 14 - if-no-files-found: ignore - path: ./**/core.[1-9]* - - - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@main - if: always() && steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' - - # NB: We are currently having an intermittent GPU-related issue on G5 runners with - # A10G GPU. Once this happens, trying to reset the GPU as done in setup-nvidia does - # not seem to help. Here are some symptoms: - # * Calling nvidia-smi timeouts after 60 second - # * Fail to run nvidia-smi with an unable to determine the device handle for GPU - # unknown error - # * Test fails with a missing CUDA GPU error when initializing CUDA in PyTorch - # * Run docker --gpus all fails with error response from daemon - # - # As both the root cause and recovery path are unclear, let's take the runner out of - # service so that it doesn't get any more jobs - - name: Check NVIDIA driver installation step - if: failure() && steps.install-nvidia-driver.outcome && steps.install-nvidia-driver.outcome != 'skipped' - shell: bash - env: - RUNNER_WORKSPACE: ${{ runner.workspace }} - run: | - set +e - set -x - - nvidia-smi - # NB: Surprisingly, nvidia-smi command returns successfully with return code 0 even in - # the case where the driver has already crashed as it still can get the driver version - # and some basic information like the bus ID. However, the rest of the information - # would be missing (ERR!), for example: - # - # +-----------------------------------------------------------------------------+ - # | NVIDIA-SMI 525.89.02 Driver Version: 525.89.02 CUDA Version: 12.0 | - # |-------------------------------+----------------------+----------------------+ - # | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | - # | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | - # | | | MIG M. | - # |===============================+======================+======================| - # | 0 ERR! Off | 00000000:00:1E.0 Off | ERR! | - # |ERR! ERR! ERR! ERR! / ERR! | 4184MiB / 23028MiB | ERR! Default | - # | | | ERR! | - # +-------------------------------+----------------------+----------------------+ - # - # +-----------------------------------------------------------------------------+ - # | Processes: | - # | GPU GI CI PID Type Process name GPU Memory | - # | ID ID Usage | - # |=============================================================================| - # +-----------------------------------------------------------------------------+ - # - # This should be reported as a failure instead as it will guarantee to fail when - # Docker tries to run with --gpus all - # - # So, the correct check here is to query one of the missing piece of info like - # GPU name, so that the command can fail accordingly - nvidia-smi --query-gpu=gpu_name --format=csv,noheader --id=0 - NVIDIA_SMI_STATUS=$? - - # These are acceptable return code from nvidia-smi as copied from setup-nvidia GitHub action - if [ "$NVIDIA_SMI_STATUS" -ne 0 ] && [ "$NVIDIA_SMI_STATUS" -ne 14 ]; then - echo "NVIDIA driver installation has failed, shutting down the runner..." - .github/scripts/stop_runner_service.sh - fi - - # For runner with multiple GPUs, we also want to confirm that the number of GPUs are the - # power of 2, i.e. 1, 2, 4, or 8. This is to avoid flaky test issue when one GPU fails - # https://github.com/pytorch/test-infra/issues/4000 - GPU_COUNT=$(nvidia-smi --list-gpus | wc -l) - NVIDIA_SMI_STATUS=$? - - # These are acceptable return code from nvidia-smi as copied from setup-nvidia GitHub action - if [ "$NVIDIA_SMI_STATUS" -ne 0 ] && [ "$NVIDIA_SMI_STATUS" -ne 14 ]; then - echo "NVIDIA driver installation has failed, shutting down the runner..." - .github/scripts/stop_runner_service.sh - fi - - # Check the GPU count to be a power of 2 - if [ "$GPU_COUNT" -le 8 ] && [ "$GPU_COUNT" -ne 1 ] && [ "$GPU_COUNT" -ne 2 ] && [ "$GPU_COUNT" -ne 4 ] && [ "$GPU_COUNT" -ne 8 ]; then - echo "NVIDIA driver detects $GPU_COUNT GPUs. The runner has a broken GPU, shutting it down..." - .github/scripts/stop_runner_service.sh - fi diff --git a/.github/actions/pytest-cache-download/action.yml b/.github/actions/pytest-cache-download/action.yml index e23e0a5eaba70..bb9f3f33aef62 100644 --- a/.github/actions/pytest-cache-download/action.yml +++ b/.github/actions/pytest-cache-download/action.yml @@ -18,6 +18,7 @@ runs: using: composite steps: - name: Setup uv + if: ${{ !env.UV_PYTHON }} uses: pytorch/test-infra/.github/actions/setup-uv@main with: python-version: "3.12" diff --git a/.github/actions/pytest-cache-upload/action.yml b/.github/actions/pytest-cache-upload/action.yml index 46e7d2db7e935..ac6c1a709c15e 100644 --- a/.github/actions/pytest-cache-upload/action.yml +++ b/.github/actions/pytest-cache-upload/action.yml @@ -25,6 +25,7 @@ runs: using: composite steps: - name: Setup uv + if: ${{ !env.UV_PYTHON }} uses: pytorch/test-infra/.github/actions/setup-uv@main with: python-version: "3.12" diff --git a/.github/actions/reuse-old-whl/action.yml b/.github/actions/reuse-old-whl/action.yml index 46e7d0e278ba4..584c3a2ab64af 100644 --- a/.github/actions/reuse-old-whl/action.yml +++ b/.github/actions/reuse-old-whl/action.yml @@ -30,6 +30,7 @@ runs: steps: - name: Setup uv + if: ${{ !env.UV_PYTHON }} uses: pytorch/test-infra/.github/actions/setup-uv@main with: python-version: "3.12" diff --git a/.github/actions/setup-linux/action.yml b/.github/actions/setup-linux/action.yml index 6314e0e1f7db4..e9f0b74cbb97b 100644 --- a/.github/actions/setup-linux/action.yml +++ b/.github/actions/setup-linux/action.yml @@ -1,11 +1,100 @@ name: Setup Linux -description: Set up Docker workspace on EC2 +description: Set up Linux workspace on EC2 or OSDC ARC runners + +inputs: + use-arc: + description: Whether the runner is an OSDC ARC runner (returned by runner_determinator) + required: false + default: '' + python-version: + description: Python version to install (e.g. "3.12"). Empty string keeps the default. + required: false + default: '' + compiler: + description: Compiler name and version (e.g. "gcc11", "clang15"). Empty string keeps the base image default. + required: false + default: '' + cuda-version: + description: CUDA version to activate (e.g. "13.0"). Empty string keeps the default. + required: false + default: '' + submodules: + description: Submodule checkout mode passed to checkout-pytorch (default "recursive", use "false" for test jobs). + required: false + default: 'recursive' + github-token: + description: GITHUB_TOKEN, needed to retrieve the workflow job id. + required: false + default: '' + +outputs: + branch: + description: Parsed branch name from GITHUB_REF + value: ${{ steps.parse-ref.outputs.branch }} + tag: + description: Parsed tag name from GITHUB_REF (if applicable) + value: ${{ steps.parse-ref.outputs.tag }} + job-id: + description: The workflow job id + value: ${{ steps.get-job-id.outputs.job-id }} + job-name: + description: The workflow job name + value: ${{ steps.get-job-id.outputs.job-name }} runs: using: composite steps: + # ── ARC-only steps ────────────────────────────────────────────────── + - name: Fix workspace permissions + if: ${{ inputs.use-arc == 'true' }} + shell: bash + run: | + # GH runner image has switched to uid 1001 https://github.com/actions/runner-images/issues/10936 + # while current PyTorch CI image are still using uid 1000 (ec2-user). We + # can update the uid to 1001 eventually when everything migrates to ARC. + # In the meantime, this is a quick fix to ensure that CI has the permission + # to use the GITHUB_WORKSPACE while still allowing the GH hook (uid 1001, + # gid 1001) to clean up the directory after the job + sudo chmod -R 777 "$GITHUB_WORKSPACE" + + - name: Ack Git cache ownership + if: ${{ inputs.use-arc == 'true' }} + shell: bash + run: | + git config --global --add safe.directory "$GITHUB_WORKSPACE" + + # ── Common steps (shared) ──────────────────────────────────────── + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + with: + no-sudo: true + checkout-mode: treeless + submodules: ${{ inputs.submodules }} + + - name: Parse ref + id: parse-ref + shell: bash + run: | + if [ -f "${{ github.action_path }}/../../scripts/parse_ref.py" ]; then + python3 "${{ github.action_path }}/../../scripts/parse_ref.py" + elif [ -f .github/scripts/parse_ref.py ]; then + python3 .github/scripts/parse_ref.py + else + echo "ERROR: parse_ref.py not found" >&2 + exit 1 + fi + + - name: Get workflow job id + id: get-job-id + if: ${{ always() && inputs.github-token != '' }} + uses: pytorch/pytorch/.github/actions/get-workflow-job-id@main + with: + github-token: ${{ inputs.github-token }} + + # ── EC2-only steps ────────────────────────────────────────────────── - name: Display EC2 information + if: ${{ inputs.use-arc != 'true' }} shell: bash run: | set -euo pipefail @@ -28,11 +117,8 @@ runs: echo "instance-type: $(get_ec2_metadata instance-type)" echo "system info $(uname -a)" - - name: Print GPU info (if present) - shell: bash - run: if [ -f /usr/bin/nvidia-smi ]; then nvidia-smi; fi - - name: Check if in a container runner + if: ${{ inputs.use-arc != 'true' }} shell: bash id: check_container_runner run: echo "IN_CONTAINER_RUNNER=$(if [ -f /.inarc ] || [ -f /.incontainer ]; then echo true ; else echo false; fi)" >> "$GITHUB_OUTPUT" @@ -49,7 +135,7 @@ runs: fi fi - - name: Install uv + - name: Install uv (EC2) uses: pytorch/test-infra/.github/actions/setup-uv@main if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' }} with: @@ -67,12 +153,6 @@ runs: uv tool install pip fi - - name: Preserve github env variables for use in docker - shell: bash - run: | - env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}" - env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}" - - name: Kill any existing containers, clean up images if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'false' }} shell: bash @@ -107,3 +187,14 @@ runs: done echo "Reached maximum attempts to connect to Docker. Exiting." exit 1 + + # ── Shared steps ──────────────────────────────────────────────────── + - name: Print GPU info (if present) + shell: bash + run: if [ -f /usr/bin/nvidia-smi ]; then nvidia-smi; fi + + - name: Preserve github env variables for use in docker + shell: bash + run: | + env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}" + env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}" diff --git a/.github/actions/setup-rocm/action.yml b/.github/actions/setup-rocm/action.yml index 5bb982a4085b1..3c1445b85fa84 100644 --- a/.github/actions/setup-rocm/action.yml +++ b/.github/actions/setup-rocm/action.yml @@ -110,7 +110,7 @@ runs: # This is due to the device files (/dev/kfd & /dev/dri) being owned by video group on bare metal. # This video group ID maps to subgid 1 inside the docker image due to the /etc/subgid entries. # The group name corresponding to group ID 1 can change depending on the OS, so both are necessary. - echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd $DEVICE_FLAG --group-add video --group-add $render_gid --group-add daemon --group-add bin --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --network=host" >> "${GITHUB_ENV}" + echo "GPU_FLAG=--device=/dev/mem --device=/dev/kfd $DEVICE_FLAG --group-add video --group-add $render_gid --group-add daemon --group-add bin --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" >> "${GITHUB_ENV}" - name: Login to ECR uses: pytorch/pytorch/.github/actions/ecr-login@main diff --git a/.github/actions/setup-xpu/action.yml b/.github/actions/setup-xpu/action.yml index 740492475d6e2..23b750205ab08 100644 --- a/.github/actions/setup-xpu/action.yml +++ b/.github/actions/setup-xpu/action.yml @@ -44,7 +44,7 @@ runs: fi - name: Runner diskspace health check - uses: ./.github/actions/diskspace-cleanup + uses: pytorch/pytorch/.github/actions/diskspace-cleanup@main if: always() - name: Runner health check disconnect on failure @@ -53,11 +53,30 @@ runs: run: | killall runsvc.sh + - name: Setup useful environment variables + shell: bash + run: | + RUNNER_ARTIFACT_DIR="${RUNNER_TEMP}/artifacts" + rm -rf "${RUNNER_ARTIFACT_DIR}" + mkdir -p "${RUNNER_ARTIFACT_DIR}" + echo "RUNNER_ARTIFACT_DIR=${RUNNER_ARTIFACT_DIR}" >> "${GITHUB_ENV}" + + RUNNER_TEST_RESULTS_DIR="${RUNNER_TEMP}/test-results" + rm -rf "${RUNNER_TEST_RESULTS_DIR}" + mkdir -p "${RUNNER_TEST_RESULTS_DIR}" + echo "RUNNER_TEST_RESULTS_DIR=${RUNNER_TEST_RESULTS_DIR}" >> "${GITHUB_ENV}" + + RUNNER_DOCS_DIR="${RUNNER_TEMP}/docs" + rm -rf "${RUNNER_DOCS_DIR}" + mkdir -p "${RUNNER_DOCS_DIR}" + echo "RUNNER_DOCS_DIR=${RUNNER_DOCS_DIR}" >> "${GITHUB_ENV}" + - name: Preserve github env variables for use in docker shell: bash run: | - env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}" - env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}" + env | grep '^GITHUB' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" + env | grep '^CI' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" + env | grep '^RUNNER' >> "${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" - name: XPU set GPU_FLAG shell: bash @@ -65,3 +84,6 @@ runs: # Add render group for container creation. render_gid=`cat /etc/group | grep render | cut -d: -f3` echo "GPU_FLAG=--device=/dev/mem --device=/dev/dri --group-add video --group-add $render_gid" >> "${GITHUB_ENV}" + + - name: Login to ECR + uses: pytorch/pytorch/.github/actions/ecr-login@main diff --git a/.github/actions/upload-build-artifacts/action.yml b/.github/actions/upload-build-artifacts/action.yml new file mode 100644 index 0000000000000..62413131581ea --- /dev/null +++ b/.github/actions/upload-build-artifacts/action.yml @@ -0,0 +1,37 @@ +name: Upload PyTorch Build Artifacts + +description: Upload build artifacts to S3 or GHA. + +inputs: + name: + description: Name of the artifact + required: true + use-gha: + description: If set to any value, use GHA to upload the artifact. Otherwise use S3. + required: false + s3-bucket: + description: S3 bucket to upload builds + required: false + default: "gha-artifacts" + +runs: + using: composite + steps: + - name: Upload PyTorch Build Artifacts to S3 + if: inputs.use-gha == '' || inputs.use-gha == 'false' + uses: seemethere/upload-artifact-s3@baba72d0712b404f646cebe0730933554ebce96a # v5.1.0 + with: + name: ${{ inputs.name }} + retention-days: 14 + if-no-files-found: error + path: artifacts.zip + s3-bucket: ${{ inputs.s3-bucket }} + + - name: Upload PyTorch Build Artifacts to GHA + if: inputs.use-gha != '' && inputs.use-gha != 'false' + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 + with: + name: ${{ inputs.name }} + retention-days: 14 + if-no-files-found: error + path: artifacts.zip diff --git a/.github/actions/upload-test-artifacts/action.yml b/.github/actions/upload-test-artifacts/action.yml index 6f4c8a89c0f45..9fcb125a42313 100644 --- a/.github/actions/upload-test-artifacts/action.yml +++ b/.github/actions/upload-test-artifacts/action.yml @@ -25,7 +25,7 @@ runs: steps: # Always set up uv and zip files first (needed for S3, reusable for GHA fallback) - name: Setup uv - if: runner.os != 'Windows' + if: ${{ runner.os != 'Windows' && !env.UV_PYTHON }} uses: pytorch/test-infra/.github/actions/setup-uv@main with: python-version: "3.12" @@ -37,7 +37,7 @@ runs: FILE_SUFFIX: ${{ inputs.file-suffix }} run: | set -euo pipefail - rm -f test-jsons-*.zip test-reports-*.zip logs-*.zip debug-*.zip + rm -f test-jsons-*.zip test-reports-*.zip logs-*.zip debug-*.zip profiler-traces-*.zip tlparse-*.zip ZIP_CMD=' import sys, zipfile, os from pathlib import Path @@ -60,6 +60,14 @@ runs: uv run --no-project python -c "$ZIP_CMD" \ test/debug "debug-${FILE_SUFFIX}.zip" '**/*' fi + if [ -d 'test/test-reports/profiler_traces' ]; then + uv run --no-project python -c "$ZIP_CMD" \ + test/test-reports/profiler_traces "profiler-traces-${FILE_SUFFIX}.zip" '**/*' + fi + if [ -d 'test/test-reports/tlparse_output' ]; then + uv run --no-project python -c "$ZIP_CMD" \ + test/test-reports/tlparse_output "tlparse-${FILE_SUFFIX}.zip" '**/*' + fi # Windows zip - name: Zip JSONs for upload @@ -141,6 +149,30 @@ runs: if-no-files-found: ignore path: debug-*.zip + - name: Store Profiler Traces on S3 + id: s3-upload-profiler-traces + uses: seemethere/upload-artifact-s3@v5 + continue-on-error: true + with: + s3-bucket: ${{ inputs.s3-bucket }} + s3-prefix: | + ${{ github.repository }}/${{ github.run_id }}/${{ github.run_attempt }}/artifact + retention-days: 14 + if-no-files-found: ignore + path: profiler-traces-*.zip + + - name: Store TLParse Output on S3 + id: s3-upload-tlparse + uses: seemethere/upload-artifact-s3@v5 + continue-on-error: true + with: + s3-bucket: ${{ inputs.s3-bucket }} + s3-prefix: | + ${{ github.repository }}/${{ github.run_id }}/${{ github.run_attempt }}/artifact + retention-days: 14 + if-no-files-found: ignore + path: tlparse-*.zip + # Check if S3 upload failed (test-reports is the critical one) - name: Check S3 upload status id: check-s3 @@ -193,3 +225,21 @@ runs: path: | usage_log.txt test/**/*.log + + - name: Store Profiler Traces on Github + uses: actions/upload-artifact@v4 + continue-on-error: true + with: + name: profiler-traces-runattempt${{ github.run_attempt }}-${{ inputs.file-suffix }} + retention-days: 14 + if-no-files-found: ignore + path: test/test-reports/profiler_traces/**/* + + - name: Store TLParse Output on Github + uses: actions/upload-artifact@v4 + continue-on-error: true + with: + name: tlparse-runattempt${{ github.run_attempt }}-${{ inputs.file-suffix }} + retention-days: 14 + if-no-files-found: ignore + path: test/test-reports/tlparse_output/**/* diff --git a/.github/actions/upload-utilization-stats/action.yml b/.github/actions/upload-utilization-stats/action.yml index 61332feaec688..007abb7c4bdf4 100644 --- a/.github/actions/upload-utilization-stats/action.yml +++ b/.github/actions/upload-utilization-stats/action.yml @@ -39,6 +39,7 @@ runs: using: composite steps: - name: Setup uv + if: ${{ !env.UV_PYTHON }} uses: pytorch/test-infra/.github/actions/setup-uv@main with: python-version: "3.12" diff --git a/.github/allowlist.yml b/.github/allowlist.yml new file mode 100644 index 0000000000000..a2b7da7c14ad7 --- /dev/null +++ b/.github/allowlist.yml @@ -0,0 +1,39 @@ +# Cross Repo CI Relay (CRCR) Allowlist +# +# This document defines which downstream repositories can receive PyTorch +# PR events or feedback through CRCR. Each repository is assigned a level +# (L1-L4), which determines the depth to which downstream repositories can +# participate. +# +# Trust Levels: +# +# L1 – Onboarding +# Events are forwarded to downstream, but upstream receives no feedback. +# +# L2 – Observation +# Downstream CI results are displayed on the HUD page, but not on PRs. +# +# L3 – Stable +# Adds a non-blocking check run on PRs when ciflow/oot/ label is applied. +# +# L4 – Mature +# Adds a blocking check run on every PR; reserved for critical accelerators. +# +# For more information, see the RFC: https://github.com/pytorch/rfcs/pull/90 + +# Example: +# +# L1: +# - org1/downstream-repo1 +# +# L2: +# - org2/downstream-repo2 +# +# L3: +# - org3/downstream-repo3: @oncall1,oncall2 +# +# L4: +# - org4/downstream-repo4: @oncall1,oncall2 + +L1: + - Ascend/pytorch diff --git a/.github/arc.yaml b/.github/arc.yaml new file mode 100644 index 0000000000000..09be5f49c62ab --- /dev/null +++ b/.github/arc.yaml @@ -0,0 +1,95 @@ +# ARC (Actions Runner Controller) Runner Label Mapping +# +# Maps current GitHub Actions runner labels to new ARC runner labels. +# Reference: https://github.com/pytorch/ci-infra/issues/396 +# +# New label format: +# {os}-[b]{arch}{vendor}{features}-{vcpu}-{memory}[-{gpu_type}[-{gpu_count}]] +# +# Fields: +# os - l=Linux, w=Windows, m=MacOS +# b - (optional) bare-metal instance +# arch - x86=x86_64, arm64=AArch64 +# vendor - i=Intel, a=AMD, g2/g3/g4=Graviton gen +# features - (x86 only) avx2, avx512, amx +# vcpu - vCPU count +# memory - RAM in GiB +# gpu_type - (optional) t4, a10g, l4 +# gpu_count- (optional, omitted when 1) +# +# Entries marked "# upgraded" had no exact ARC equivalent and were mapped to +# the next larger available runner. + +runner_mapping: + + # ---- x86 CPU — Intel AVX-512 (c5, c7i families) ---- + + linux.large: l-x86iavx512-2-4 # c5.large + linux.2xlarge: l-x86iavx512-8-64 # c5.2xlarge + linux.c7i.2xlarge: l-x86iavx512-8-64 # c7i.2xlarge + linux.4xlarge: l-x86iavx512-16-128 # c5.4xlarge + linux.c7i.4xlarge: l-x86iavx512-16-128 # c7i.4xlarge + linux.12xlarge: l-x86iavx512-48-384 # c5.12xlarge + linux.c7i.12xlarge: l-x86iavx512-48-384 # c7i.12xlarge + linux.24xl.spr-metal: l-bx86iamx-92-167 # c7i.metal-24xl + + # ---- x86 CPU — Intel AMX (m7i-flex family) ---- + + linux.2xlarge.amx: l-x86iamx-8-64 # m7i-flex.2xlarge + linux.8xlarge.amx: l-x86iamx-32-128 # m7i-flex.8xlarge + + # ---- x86 CPU — Intel AVX2 (m4 family) ---- + + linux.2xlarge.avx2: l-x86iavx2-8-32 # m4.2xlarge + linux.10xlarge.avx2: l-x86iavx2-40-160 # m4.10xlarge + + # ---- x86 CPU — Memory-optimized (r5, r7i families) ---- + + linux.r7i.2xlarge: l-x86iavx512-8-64 # r7i.2xlarge + linux.r7i.4xlarge: l-x86iavx512-16-128 # r7i.4xlarge + linux.4xlarge.memory: l-x86iavx512-16-128 # r5.4xlarge + linux.8xlarge.memory: l-x86iavx512-32-256 # r5.8xlarge + linux.12xlarge.memory: l-x86iavx512-48-384 # r5.12xlarge + linux.24xlarge.memory: l-x86iavx512-94-768 # r5.24xlarge + + # ---- x86 CPU — AMD (m6a, m7a families) ---- + + linux.24xlarge.amd: l-x86aavx512-125-463 # m7a.24xlarge + + # ---- x86 GPU — T4 (g4dn family) ---- + + linux.g4dn.4xlarge.nvidia.gpu: l-x86iavx512-29-115-t4 # g4dn.4xlarge + linux.g4dn.12xlarge.nvidia.gpu: l-x86iavx512-45-172-t4-4 # g4dn.12xlarge + linux.g4dn.metal.nvidia.gpu: l-bx86iavx512-94-344-t4-8 # g4dn.metal + + # ---- x86 GPU — A10G (g5 family) ---- + + linux.g5.4xlarge.nvidia.gpu: l-x86aavx2-29-113-a10g # g5.4xlarge + linux.g5.12xlarge.nvidia.gpu: l-x86aavx2-45-167-a10g-4 # g5.12xlarge + linux.g5.48xlarge.nvidia.gpu: l-x86aavx2-189-704-a10g-8 # g5.48xlarge + + # ---- x86 GPU — L4 (g6 family) ---- + + linux.g6.4xlarge.experimental.nvidia.gpu: l-x86aavx2-29-113-l4 # g6.4xlarge + linux.g6.12xlarge.nvidia.gpu: l-x86aavx2-45-172-l4-4 # g6.12xlarge + + # ---- ARM64 — Graviton ---- + + linux.arm64.2xlarge: l-arm64g2-6-32 # t4g.2xlarge + linux.arm64.2xlarge.ephemeral: l-arm64g2-6-32 # t4g.2xlarge + linux.arm64.m7g.4xlarge: l-arm64g3-16-62 # m7g.4xlarge + linux.arm64.m8g.4xlarge: l-arm64g4-16-62 # m8g.4xlarge + linux.arm64.r7g.12xlarge.memory: l-arm64g3-61-463 # r7g.12xlarge + linux.arm64.m7g.metal: l-barm64g4-62-226 # m7g.metal + + + # ---- x86 GPU — B200 (p6 family) ---- + + linux.dgx.b200: l-x86iamx-22-225-b200 # p6-b200.48xlarge (1 GPU) + + # ---- Partner hardwares ---- + + linux.idc.xpu: linux.idc.xpu + linux.rocm.gpu.2: linux.rocm.gpu.2 + linux.rocm.gpu.gfx950.1: linux.rocm.gpu.gfx950.1 + linux.rocm.gpu.gfx950.2: linux.rocm.gpu.gfx950.2 diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index ad631f37d819f..e5487b7ac5eff 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -11ed3578e70404c7329f219b02d78b7a89603ebe +c0cbdb95674556cdff7266f2d44bb855f634cfde diff --git a/.github/ci_commit_pins/torchao.txt b/.github/ci_commit_pins/torchao.txt index f33eac1d8778d..261d79443851c 100644 --- a/.github/ci_commit_pins/torchao.txt +++ b/.github/ci_commit_pins/torchao.txt @@ -1 +1 @@ -985d970b5e16b58c1e5b8bab440169d3da78cf16 +fbce9178697ca22d604314b316feb765360e0dec diff --git a/.github/ci_commit_pins/torchtitan.txt b/.github/ci_commit_pins/torchtitan.txt index 2bdccc1802837..0c66dcaabc888 100644 --- a/.github/ci_commit_pins/torchtitan.txt +++ b/.github/ci_commit_pins/torchtitan.txt @@ -1 +1 @@ -fa8e6ccc973a4a6d32a6e236156193f9eacdfa18 +7c8fc3f540edaec84d4610066ccff423141f7aa7 diff --git a/.github/ci_commit_pins/vision.txt b/.github/ci_commit_pins/vision.txt index 25b0151690966..2deb0694e30fe 100644 --- a/.github/ci_commit_pins/vision.txt +++ b/.github/ci_commit_pins/vision.txt @@ -1 +1 @@ -d63f7ed275b36e0fd6b37e52eef023ec35f337b0 +8ad7115a03bc886b68c76c8aeb412855cd0ec802 diff --git a/.github/ci_commit_pins/vllm.txt b/.github/ci_commit_pins/vllm.txt index 777d7a629ad8c..f2c9cd509ea77 100644 --- a/.github/ci_commit_pins/vllm.txt +++ b/.github/ci_commit_pins/vllm.txt @@ -1 +1 @@ -e9163b536e721c431500f6f43ace22fcb3532e7e +0e884fe638a3120a8772c3e95e71728b56db20e4 diff --git a/.github/ci_configs/vllm/Dockerfile b/.github/ci_configs/vllm/Dockerfile index 549c336a444cb..94c95a3c84c8f 100644 --- a/.github/ci_configs/vllm/Dockerfile +++ b/.github/ci_configs/vllm/Dockerfile @@ -130,7 +130,7 @@ COPY . . RUN python3 use_existing_torch.py RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system -r requirements/build.txt + uv pip install --system -r requirements/build/cuda.txt ARG GIT_REPO_CHECK=0 RUN --mount=type=bind,source=.git,target=.git \ @@ -237,13 +237,13 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match" ENV UV_LINK_MODE=copy # Install build and runtime dependencies -COPY requirements/build.txt requirements/build.txt +COPY requirements/build/cuda.txt requirements/build/cuda.txt COPY use_existing_torch.py use_existing_torch.py RUN python3 use_existing_torch.py -RUN cat requirements/build.txt +RUN cat requirements/build/cuda.txt RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system -r requirements/build.txt + uv pip install --system -r requirements/build/cuda.txt # Default mount file as placeholder, this just avoid the mount error ARG TORCH_WHEELS_PATH="./requirements" diff --git a/.github/ci_configs/vllm/use_existing_torch.py b/.github/ci_configs/vllm/use_existing_torch.py index 3d59fd67a398c..52b5a308219f9 100644 --- a/.github/ci_configs/vllm/use_existing_torch.py +++ b/.github/ci_configs/vllm/use_existing_torch.py @@ -1,22 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse import glob import os +import re +import sys + + +# Only strip targeted libraries when checking prefix +TORCH_LIB_PREFIXES = ( + # requirements/*.txt/in + "torch=", + "torchvision=", + "torchaudio=", + # pyproject.toml + '"torch =', + '"torchvision =', + '"torchaudio =', +) + +# Match lines where the package name is exactly torch/torchvision/torchaudio, +# not a substring of another package (e.g. terratorch, open_clip_torch). +_TORCH_PKG_RE = re.compile( + r"""^\s*['"]?\s*(?:torchvision|torchaudio|torch)\s*(?:[=<>!;\[,\]'"@~#(]|$)""", + re.IGNORECASE, +) + + +def main(argv): + parser = argparse.ArgumentParser( + description="Strip torch lib requirements to use installed version." + ) + parser.add_argument( + "--prefix", + action="store_true", + help="Strip prefix matches only (default: False)", + ) + args = parser.parse_args(argv) + + for file in ( + *glob.glob("requirements/**/*.txt", recursive=True), + *glob.glob("requirements/**/*.in", recursive=True), + "pyproject.toml", + ): + if not os.path.exists(file): + continue + with open(file) as f: + lines = f.readlines() + if "torch" in "".join(lines).lower(): + with open(file, "w") as f: + for line in lines: + if ( + args.prefix + and not line.lower().strip().startswith(TORCH_LIB_PREFIXES) + or not args.prefix + and not _TORCH_PKG_RE.match(line) + ): + f.write(line) + else: + print(f">>> removed from {file}:", line.strip()) -requires_files = glob.glob("requirements/*.txt") -requires_files += ["pyproject.toml"] - -for file in requires_files: - if not os.path.exists(file): - print(f"!!! skipping missing {file}") - continue - print(f">>> cleaning {file}") - with open(file) as f: - lines = f.readlines() - if "torch" in "".join(lines).lower(): - print("removed:") - with open(file, "w") as f: - for line in lines: - if "torch" not in line.lower(): - f.write(line) - print(f"<<< done cleaning {file}") - print() +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/.github/labeler.yml b/.github/labeler.yml index f1d10d3a8082b..c8f77e91be65d 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -62,16 +62,6 @@ - test/test_jit_llga_fuser.py - test/test_mkldnn.py -"ciflow/linux-aarch64": -- third_party/ideep -- caffe2/ideep/** -- caffe2/python/ideep/** -- cmake/Modules/FindMKLDNN.cmake -- third_party/mkl-dnn.BUILD -- torch/csrc/jit/codegen/onednn/** -- test/test_jit_llga_fuser.py -- test/test_mkldnn.py - "module: amp (automated mixed precision)": - torch/amp/** - aten/src/ATen/autocast_mode.* @@ -200,6 +190,8 @@ - test/distributed/**/*mem*/** "ciflow/torchtitan": +# torchtitan commit pin updates +- .github/ci_commit_pins/torchtitan.txt # torch.distributed (FSDP, DTensor, etc.) - torch/distributed/** # torch.compile diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index f47c3bc005cd2..c52035f7089ad 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -27,7 +27,6 @@ ciflow_push_tags: - ciflow/inductor-rocm-mi200 - ciflow/inductor-rocm-mi300 - ciflow/inductor-rocm-mi355 -- ciflow/linux-aarch64 - ciflow/mps - ciflow/nightly - ciflow/op-benchmark @@ -48,6 +47,7 @@ ciflow_push_tags: - ciflow/slow-rocm-mi200 - ciflow/torchbench - ciflow/torchtitan +- ciflow/tsan - ciflow/triton_binaries - ciflow/trunk - ciflow/unstable @@ -60,6 +60,8 @@ retryable_workflows: - linux-binary - windows-binary - inductor-A100-perf-nightly +retryable_step_names: +- Initialize containers labeler_config: labeler.yml label_to_label_config: label_to_label.yml mergebot: true diff --git a/.github/requirements-gha-cache.txt b/.github/requirements-gha-cache.txt index c274ca1e5914d..9bd898687ba67 100644 --- a/.github/requirements-gha-cache.txt +++ b/.github/requirements-gha-cache.txt @@ -11,5 +11,5 @@ lintrunner==0.12.7 ninja==1.10.0.post1 nvidia-ml-py==11.525.84 pyyaml==6.0.2 -requests==2.32.4 +requests==2.33.0 rich==14.1.0 diff --git a/.github/scripts/amd/package_triton_wheel.sh b/.github/scripts/amd/package_triton_wheel.sh index 501e50e2fe2f1..549bacc5a7573 100755 --- a/.github/scripts/amd/package_triton_wheel.sh +++ b/.github/scripts/amd/package_triton_wheel.sh @@ -28,9 +28,7 @@ if [[ -z "${TRITON_ROCM_DIR}" ]]; then export TRITON_ROCM_DIR=third_party/amd/backend fi -# Remove packaged libs and headers -rm -rf $TRITON_ROCM_DIR/include/* - +# Remove packaged libs LIBNUMA_PATH="/usr/lib64/libnuma.so.1" LIBELF_PATH="/usr/lib64/libelf.so.1" OS_NAME=`awk -F= '/^NAME/{print $2}' /etc/os-release` diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index 0979c6f3f436e..5928a70a42406 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -116,16 +116,27 @@ def build_triton( check_call(["git", "clone", triton_repo, "triton"], cwd=tmpdir) if release: ver, rev, patch = version.split(".") - check_call( - ["git", "checkout", f"release/{ver}.{rev}.x"], cwd=triton_basedir - ) + if device == "xpu": + # XPU uses the patch version in the release branch name + check_call( + ["git", "checkout", f"release/{ver}.{rev}.{patch}"], + cwd=triton_basedir, + ) + else: + check_call( + ["git", "checkout", f"release/{ver}.{rev}.x"], cwd=triton_basedir + ) else: check_call(["git", "fetch", "origin", commit_hash], cwd=triton_basedir) check_call(["git", "checkout", commit_hash], cwd=triton_basedir) # change built wheel name and version env["TRITON_WHEEL_NAME"] = triton_pkg_name +<<<<<<< HEAD env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix +======= + env["TRITON_EXT_ENABLED"] = "ON" +>>>>>>> upstream/main if with_clang_ldd: env["TRITON_BUILD_WITH_CLANG_LLD"] = "1" diff --git a/.github/scripts/filter_test_configs.py b/.github/scripts/filter_test_configs.py index 5393f50673ef7..9f7f7a3a7bedc 100755 --- a/.github/scripts/filter_test_configs.py +++ b/.github/scripts/filter_test_configs.py @@ -30,19 +30,31 @@ logging.basicConfig(level=logging.INFO) -def is_cuda_or_rocm_job(job_name: str | None) -> bool: - if not job_name: - return False +def is_cuda_or_rocm_job( + job_name: str | None, config: dict[str, Any] | None = None +) -> bool: + if job_name and ("cuda" in job_name or "rocm" in job_name): + return True - return "cuda" in job_name or "rocm" in job_name + # Also check the runner name in the config, since some workflows (e.g. + # inductor-unittest) use job names that don't include "cuda" even though + # they target CUDA runners. + if config: + runner = config.get("runner", "") + if "nvidia.gpu" in runner or "rocm.gpu" in runner: + return True + + return False # Supported modes when running periodically. Only applying the mode when -# its lambda condition returns true -SUPPORTED_PERIODICAL_MODES: dict[str, Callable[[str | None], bool]] = { +# its lambda condition returns true. Each callable receives (job_name, config). +SUPPORTED_PERIODICAL_MODES: dict[ + str, Callable[[str | None, dict[str, Any] | None], bool] +] = { # Memory leak check is only needed for CUDA and ROCm jobs which utilize GPU memory "mem_leak_check": is_cuda_or_rocm_job, - "rerun_disabled_tests": lambda job_name: True, + "rerun_disabled_tests": lambda job_name, config=None: True, } # The link to the published list of disabled jobs @@ -225,7 +237,7 @@ def set_periodic_modes( for config in test_matrix.get("include", []): for mode, cond in SUPPORTED_PERIODICAL_MODES.items(): - if not cond(job_name): + if not cond(job_name, config): continue cfg = config.copy() diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 4f7995fb6b907..9ce45f70e1fb8 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -23,19 +23,17 @@ REPO_ROOT = SCRIPT_DIR.parent.parent -CUDA_ARCHES = ["12.6", "12.8", "12.9", "13.0"] -CUDA_STABLE = "12.8" +CUDA_ARCHES = ["12.6", "13.0", "13.2"] +CUDA_STABLE = "13.0" CUDA_ARCHES_FULL_VERSION = { "12.6": "12.6.3", - "12.8": "12.8.1", - "12.9": "12.9.1", "13.0": "13.0.2", + "13.2": "13.2.1", } CUDA_ARCHES_CUDNN_VERSION = { "12.6": "9", - "12.8": "9", - "12.9": "9", "13.0": "9", + "13.2": "9", } ROCM_ARCHES = ["7.1", "7.2"] @@ -46,40 +44,40 @@ CPU_S390X_ARCH = ["cpu-s390x"] -CUDA_AARCH64_ARCHES = ["12.6-aarch64", "12.8-aarch64", "12.9-aarch64", "13.0-aarch64"] +CUDA_AARCH64_ARCHES = [ + "12.6-aarch64", + "13.0-aarch64", + "13.2-aarch64", +] +# WARNING: For CUDA 13.0, cublas is pinned to a version range rather +# than an exact version. A broken cublas release within that range will be +# silently pulled in. PYTORCH_EXTRA_INSTALL_REQUIREMENTS = { "12.6": ( - "cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.6.3; platform_system == 'Linux' | " # noqa: B950 + "cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.6.3; platform_system == 'Linux' | " "cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | " "nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | " "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | " "nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | " "nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux'" ), - "12.8": ( - "cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | " # noqa: B950 - "cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | " - "nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | " - "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | " - "nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | " - "nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux'" - ), - "12.9": ( - "cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | " # noqa: B950 - "cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | " - "nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | " - "nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | " - "nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | " - "nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux'" - ), "13.0": ( - "cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | " # noqa: B950 + "cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | " + "nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | " "cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | " - "nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | " - "nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | " - "nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | " + "nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | " + "nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | " + "nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | " + "nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux'" + ), + "13.2": ( + "cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | " + "cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | " + "nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | " + "nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | " + "nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | " "nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux'" ), "xpu": ( @@ -163,7 +161,7 @@ def read_nccl_pin(arch_version: str) -> str: # Single source of truth for NCCL version from optional_submodules import read_nccl_pin - return read_nccl_pin() + return read_nccl_pin(arch_version) def validate_nccl_dep_consistency(arch_version: str) -> None: @@ -272,12 +270,6 @@ def arch_type(arch_version: str) -> str: RELEASE = "release" DEBUG = "debug" -LIBTORCH_CONTAINER_IMAGES: dict[str, str] = { - **{gpu_arch: f"libtorch-cxx11-builder:cuda{gpu_arch}" for gpu_arch in CUDA_ARCHES}, - **{gpu_arch: f"libtorch-cxx11-builder:rocm{gpu_arch}" for gpu_arch in ROCM_ARCHES}, - "cpu": "libtorch-cxx11-builder:cpu", -} - FULL_PYTHON_VERSIONS = ["3.10", "3.11", "3.12", "3.13", "3.13t", "3.14", "3.14t"] @@ -305,15 +297,8 @@ def generate_libtorch_matrix( ) -> list[dict[str, str]]: if arches is None: arches = ["cpu"] - if os == "linux": + if os == "windows": arches += CUDA_ARCHES - arches += ROCM_ARCHES - elif os == "windows": - # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up - # in 2.10 - windows_cuda_arches = CUDA_ARCHES.copy() - windows_cuda_arches.remove("12.9") - arches += windows_cuda_arches if libtorch_variants is None: libtorch_variants = [ "shared-with-deps", @@ -327,9 +312,6 @@ def generate_libtorch_matrix( for libtorch_variant in libtorch_variants: gpu_arch_type = arch_type(arch_version) gpu_arch_version = "" if arch_version == "cpu" else arch_version - # ROCm builds without-deps failed even in ROCm runners; skip for now - if gpu_arch_type == "rocm" and ("without-deps" in libtorch_variant): - continue ret.append( { "gpu_arch_type": gpu_arch_type, @@ -339,16 +321,8 @@ def generate_libtorch_matrix( ), "libtorch_config": release_type, "libtorch_variant": libtorch_variant, - "container_image": ( - LIBTORCH_CONTAINER_IMAGES[arch_version].split(":")[0] - if os not in ("windows", "windows-arm64") - else "" - ), - "container_image_tag_prefix": ( - LIBTORCH_CONTAINER_IMAGES[arch_version].split(":")[1] - if os not in ("windows", "windows-arm64") - else "" - ), + "container_image": "", + "container_image_tag_prefix": "", "package_type": "libtorch", "build_name": f"libtorch-{gpu_arch_type}{gpu_arch_version}-{libtorch_variant}-{release_type}".replace( ".", "_" @@ -377,11 +351,7 @@ def generate_wheels_matrix( if os == "linux": arches += CUDA_ARCHES + ROCM_ARCHES + XPU_ARCHES elif os == "windows": - # TODO (huydhn): Only build CUDA 12.9 for Linux. This logic is to be cleaned up - # in 2.10 - windows_cuda_arches = CUDA_ARCHES.copy() - windows_cuda_arches.remove("12.9") - arches += windows_cuda_arches + XPU_ARCHES + arches += CUDA_ARCHES + XPU_ARCHES elif os == "linux-aarch64": # Separate new if as the CPU type is different and # uses different build/test scripts @@ -417,7 +387,7 @@ def generate_wheels_matrix( # cuda linux wheels require PYTORCH_EXTRA_INSTALL_REQUIREMENTS to install if ( - arch_version in ["13.0", "12.9", "12.8", "12.6"] + arch_version in ["13.2", "13.0", "12.6"] and os == "linux" or arch_version in CUDA_AARCH64_ARCHES ): @@ -472,6 +442,9 @@ def generate_wheels_matrix( "pytorch_extra_install_requirements": ( PYTORCH_EXTRA_INSTALL_REQUIREMENTS["xpu"] if gpu_arch_type == "xpu" + else PYTORCH_EXTRA_INSTALL_REQUIREMENTS[CUDA_STABLE] + if gpu_arch_type == "cpu" + and os in ("windows", "macos-arm64") else "" ), } diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index 2dfe570dad339..dbdd768c9fd48 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -211,7 +211,7 @@ class OperatingSystem: os=OperatingSystem.MACOS_ARM64, package_type="wheel", build_configs=_MACOS_ARM64_WHEEL_CONFIGS, - macos_runner="macos-14-xlarge", + macos_runner="macos-26-xlarge", ciflow_config=CIFlowConfig( labels={ LABEL_CIFLOW_BINARIES, @@ -303,7 +303,7 @@ def main() -> None: if not isinstance(workflows, Iterable): raise Exception( # noqa: TRY002 f"How is workflows not iterable? {workflows}" - ) # noqa: TRY002 + ) for workflow in workflows: workflow.generate_workflow_file(workflow_template=template) diff --git a/.github/scripts/map_ec2_to_arc.py b/.github/scripts/map_ec2_to_arc.py new file mode 100644 index 0000000000000..1ae440c0dc97d --- /dev/null +++ b/.github/scripts/map_ec2_to_arc.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +"""Map EC2 runner labels to ARC equivalents using .github/arc.yaml. + +Takes a GitHub Actions test matrix, replaces each runner with its ARC +equivalent, and prints the updated matrix as JSON. + +Usage: + python map_ec2_to_arc.py --prefix mt- '{ include: [ + { config: "default", shard: 1, num_shards: 5, runner: "mt-linux.4xlarge" }, + ]}' +""" + +import argparse +import json +import os +import sys +from pathlib import Path + +import yaml + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Map EC2 runner labels to ARC runner labels in a test matrix" + ) + parser.add_argument( + "matrix", + help="GitHub Actions test matrix string to transform", + ) + parser.add_argument( + "--prefix", + default="", + help="Runner prefix to strip from labels (e.g. 'mt-')", + ) + return parser.parse_args() + + +def strip_prefix(label: str, prefix: str) -> str: + if prefix and label.startswith(prefix): + return label[len(prefix) :] + return label + + +def load_mapping(arc_yaml: Path) -> dict[str, str]: + with open(arc_yaml) as f: + data = yaml.safe_load(f) + return data["runner_mapping"] + + +def set_output(name: str, val: str) -> None: + print(f"Setting {name}={val}") + github_output = os.getenv("GITHUB_OUTPUT") + if github_output: + with open(github_output, "a") as f: + print(f"{name}={val}", file=f) + + +def main() -> None: + args = parse_args() + arc_yaml = Path(__file__).resolve().parent.parent / "arc.yaml" + mapping = load_mapping(arc_yaml) + + matrix = yaml.safe_load(args.matrix) + if not matrix: + set_output("test-matrix", args.matrix) + return + + entries = matrix.get("include", []) + if not entries: + set_output("test-matrix", json.dumps(matrix)) + return + + # TODO(huydo): onnxruntime uses hardware_concurrency() to size its thread + # pool, which sees all host CPUs (e.g., 192) on ARC k8s instead of the + # container's cpuset (e.g., 16). This causes pthread_setaffinity_np errors. + # Skip onnx tests on ARC until the onnxruntime session options are fixed to + # use cgroup-aware CPU counts. + excluded_configs = {"onnx"} + filtered = [] + for entry in entries: + if entry.get("config") in excluded_configs: + print(f"Excluding config '{entry['config']}' from ARC test matrix") + continue + filtered.append(entry) + matrix["include"] = filtered + + for entry in filtered: + if "runner" not in entry: + continue + clean = strip_prefix(entry["runner"].strip(), args.prefix) + if clean not in mapping: + print(f"error: no ARC runner found for '{clean}'", file=sys.stderr) + sys.exit(1) + mapped = mapping[clean] + # Passthrough runners (e.g. linux.rocm.gpu.2, linux.idc.xpu) are not + # OSDC-managed so they keep their original label without the prefix. + entry["runner"] = mapped if mapped == clean else args.prefix + mapped + + set_output("test-matrix", json.dumps(matrix)) + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/runner_determinator.py b/.github/scripts/runner_determinator.py index baf560234549b..169f83eb17de7 100644 --- a/.github/scripts/runner_determinator.py +++ b/.github/scripts/runner_determinator.py @@ -1,9 +1,5 @@ # flake8: noqa: G004 -# Note: Copies of this script in runner_determinator.py and _runner-determinator.yml -# must be kept in sync. You can do it easily by running the following command: -# python .github/scripts/update_runner_determinator.py - """ This runner determinator is used to determine which set of runners to run a GitHub job on. It uses the first comment of a GitHub issue (by default @@ -27,7 +23,11 @@ - Users are GitHub usernames, which must start with the @ prefix - Each user is also a comma-separated list of features/experiments to enable +- Each experiment can optionally include a per-user rollout percentage + using the syntax "experiment:percentage" (e.g. "arc:10" for 10% rollout) +- Without a percentage, opted-in experiments are enabled 100% of the time - A "#" prefix opts the user out of all experiments +- A "-" prefix on an experiment opts the user out of that experiment Example config: # A list of experiments that can be opted into. @@ -46,12 +46,14 @@ # Opt-ins: # Users can opt into the LF fleet by adding their GitHub username to this list # and specifying experiments to enable in a comma-separated list. + # Optionally append :N to set a per-user rollout percentage (0-100). # To always opt out of an experiment, prefix it with a "-". # Experiments should be from the above list. @User1,-lf,split_build @User2,lf @User3,split_build + @User4,lf,arc:10 """ import json @@ -79,13 +81,18 @@ GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "") GH_OUTPUT_KEY_AMI = "runner-ami" GH_OUTPUT_KEY_LABEL_TYPE = "label-type" +GH_OUTPUT_KEY_USE_ARC = "use-arc" OPT_OUT_LABEL = "no-runner-experiments" SETTING_EXPERIMENTS = "experiments" LF_FLEET_EXPERIMENT = "lf" +ARC_FLEET_EXPERIMENT = "arc" CANARY_FLEET_SUFFIX = ".c" +ARC_LABEL_PREFIX = "mt-" +ARC_CANARY_LABEL_PREFIX = "c-" + class Experiment(NamedTuple): rollout_perc: float = ( @@ -101,6 +108,11 @@ class Experiment(NamedTuple): # Add more fields as needed +class RunnerPrefixResult(NamedTuple): + prefix: str + use_arc: bool = False + + class Settings(NamedTuple): """ Settings for the experiments that can be opted into. @@ -295,9 +307,18 @@ def extract_settings_user_opt_in_from_text(rollout_state: str) -> tuple[str, str return "", rollout_state -class UserOptins(dict[str, list[str]]): +class UserExperimentConfig(NamedTuple): + """ + Per-user experiment configuration parsed from the opt-in line. + """ + + name: str + rollout_perc: float = 100 # default: always enabled when opted in + + +class UserOptins(dict[str, list[UserExperimentConfig]]): """ - Dictionary of users with a list of features they have opted into + Dictionary of users with a list of experiment configs they have opted into """ @@ -320,7 +341,32 @@ def parse_user_opt_in_from_text(user_optin_text: str) -> UserOptins: if user: usr_name = user.split(",")[0].strip("@") - optins[usr_name] = [exp.strip(" ") for exp in user.split(",")[1:]] + configs = [] + for exp_str in user.split(",")[1:]: + exp_str = exp_str.strip(" ") + if not exp_str: + continue + # Parse optional per-user rollout percentage (e.g. "arc:10") + # Opt-out entries (e.g. "-lf") never have a percentage + if ":" in exp_str and not exp_str.startswith("-"): + name, perc_str = exp_str.split(":", 1) + try: + perc = float(perc_str) + except ValueError: + log.warning( + f"Invalid rollout percentage for user {usr_name}, experiment {exp_str}. Defaulting to 100%." + ) + perc = 100 + if not (0 <= perc <= 100): + log.warning( + f"Rollout percentage {perc} for user {usr_name}, experiment {name} " + f"is out of range [0, 100]. Clamping." + ) + perc = max(0.0, min(100.0, perc)) + configs.append(UserExperimentConfig(name=name, rollout_perc=perc)) + else: + configs.append(UserExperimentConfig(name=exp_str, rollout_perc=100)) + optins[usr_name] = configs return optins @@ -352,11 +398,8 @@ def parse_settings_from_text(settings_text: str) -> Settings: """ try: if settings_text: - # Escape the backtick as well so that we can have the settings in a code block on the GH issue - # for easy reading - # Note: Using ascii for the backtick so that the cat step in _runner-determinator.yml doesn't choke on - # the backtick character in shell commands. - backtick = chr(96) # backtick character + # Strip backticks so settings can be in a code block on the GH issue + backtick = chr(96) settings_text = settings_text.strip(f"\r\n\t{backtick} ") settings = load_yaml(settings_text) @@ -408,11 +451,24 @@ def parse_users(rollout_state: str) -> UserOptins: return parse_user_opt_in_from_text(users_text) +def get_user_experiment_config( + user: str, user_optins: UserOptins, experiment_name: str +) -> UserExperimentConfig | None: + """ + Get a user's experiment config if they are opted in. + Returns None if the user is not opted into the experiment. + """ + for config in user_optins.get(user, []): + if config.name == experiment_name: + return config + return None + + def is_user_opted_in(user: str, user_optins: UserOptins, experiment_name: str) -> bool: """ Check if a user is opted into an experiment """ - return experiment_name in user_optins.get(user, []) + return get_user_experiment_config(user, user_optins, experiment_name) is not None def is_user_opted_out(user: str, user_optins: UserOptins, experiment_name: str) -> bool: @@ -421,7 +477,10 @@ def is_user_opted_out(user: str, user_optins: UserOptins, experiment_name: str) """ # if the experiment is prefixed with a "-", then it's an opt-out experiment_optout = "-" + experiment_name - if experiment_optout not in user_optins.get(user, []): + opted_out = any( + config.name == experiment_optout for config in user_optins.get(user, []) + ) + if not opted_out: return False if is_user_opted_in(user, user_optins, experiment_name): @@ -439,12 +498,13 @@ def get_runner_prefix( eligible_experiments: frozenset[str] = frozenset(), opt_out_experiments: frozenset[str] = frozenset(), is_canary: bool = False, -) -> str: +) -> RunnerPrefixResult: settings = parse_settings(rollout_state) user_optins = parse_users(rollout_state) fleet_prefix = "" prefixes = [] + use_arc = False for experiment_name, experiment_settings in settings.experiments.items(): if not experiment_settings.all_branches and is_exception_branch(branch): log.info( @@ -495,10 +555,37 @@ def get_runner_prefix( enabled = False if opted_in_users: - log.info( - f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." - ) - enabled = True + # Get the minimum per-user rollout percentage among opted-in requesters. + # This is conservative: if the PR author sets 10%, that intent is respected + # even if the triggering actor (e.g. pytorchmergebot) has 100%. + user_rollout_percs = [ + get_user_experiment_config(u, user_optins, experiment_name).rollout_perc + for u in opted_in_users + ] + min_perc = min(user_rollout_percs) + + if min_perc >= 100: + log.info( + f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." + ) + enabled = True + elif min_perc > 0: + if random.uniform(0, 100) <= min_perc: + log.info( + f"{', '.join(opted_in_users)} have opted into experiment {experiment_name} " + f"with {min_perc}% rollout. Enabling this run." + ) + enabled = True + else: + log.info( + f"{', '.join(opted_in_users)} have opted into experiment {experiment_name} " + f"with {min_perc}% rollout. Not enabling this run." + ) + else: + log.info( + f"{', '.join(opted_in_users)} have opted into experiment {experiment_name} " + f"with 0% rollout. Not enabling." + ) elif experiment_settings.rollout_perc: # If no user is opted in, then we randomly enable the experiment based on the rollout percentage @@ -510,7 +597,12 @@ def get_runner_prefix( if enabled: label = experiment_name - if experiment_name == LF_FLEET_EXPERIMENT: + if experiment_name == ARC_FLEET_EXPERIMENT: + use_arc = True + log.info( + f"ARC experiment enabled. Using ARC runner prefix ({'canary' if is_canary else 'production'})." + ) + elif experiment_name == LF_FLEET_EXPERIMENT: # We give some special treatment to the "lf" experiment since determines the fleet we use # - If it's enabled, then we always list it's prefix first # - If we're in the canary branch, then we append ".c" to the lf prefix @@ -520,6 +612,15 @@ def get_runner_prefix( else: prefixes.append(label) + # ARC experiment takes precedence: return a fixed label prefix + if use_arc: + arc_prefix = ( + ARC_CANARY_LABEL_PREFIX + ARC_LABEL_PREFIX + if is_canary + else ARC_LABEL_PREFIX + ) + return RunnerPrefixResult(prefix=arc_prefix, use_arc=True) + if len(prefixes) > 1: log.error( f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}" @@ -530,7 +631,8 @@ def get_runner_prefix( if fleet_prefix: prefixes.insert(0, fleet_prefix) - return ".".join(prefixes) + "." if prefixes else "" + prefix = ".".join(prefixes) + "." if prefixes else "" + return RunnerPrefixResult(prefix=prefix) def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: @@ -619,7 +721,7 @@ def main() -> None: is_canary = args.github_repo == "pytorch/pytorch-canary" - runner_label_prefix = get_runner_prefix( + result = get_runner_prefix( rollout_state, (args.github_issue_owner, username), args.github_branch, @@ -627,6 +729,8 @@ def main() -> None: args.opt_out_experiments, is_canary, ) + runner_label_prefix = result.prefix + set_github_output(GH_OUTPUT_KEY_USE_ARC, str(result.use_arc).lower()) except Exception as e: log.error( diff --git a/.github/scripts/test_filter_test_configs.py b/.github/scripts/test_filter_test_configs.py index 26e38828b7865..1463f82f61eda 100755 --- a/.github/scripts/test_filter_test_configs.py +++ b/.github/scripts/test_filter_test_configs.py @@ -376,13 +376,31 @@ def test_set_periodic_modes(self) -> None: scheduled_test_matrix = set_periodic_modes(test_matrix, job_name) expected_modes = [ - m for m, c in SUPPORTED_PERIODICAL_MODES.items() if c(job_name) + m for m, c in SUPPORTED_PERIODICAL_MODES.items() if c(job_name, None) ] self.assertEqual( len(test_matrix["include"]) * len(expected_modes), len(scheduled_test_matrix["include"]), ) + def test_set_periodic_modes_gpu_runner(self) -> None: + """Job name without 'cuda' but runner indicates a CUDA job (e.g. inductor-unittest).""" + test_matrix = yaml.safe_load( + "{include: [" + '{config: "inductor", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu"}, ' + '{config: "inductor", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu"}' + "]}" + ) + scheduled = set_periodic_modes(test_matrix, "inductor-build / build") + + modes_per_config = [ + entry.get("mem_leak_check") or entry.get("rerun_disabled_tests") + for entry in scheduled["include"] + ] + self.assertIn("mem_leak_check", modes_per_config) + self.assertIn("rerun_disabled_tests", modes_per_config) + self.assertEqual(len(scheduled["include"]), 4) + @mock.patch("filter_test_configs.download_json") def test_remove_disabled_jobs(self, mock_download_json: Any) -> None: mock_download_json.return_value = MOCKED_DISABLED_UNSTABLE_JOBS diff --git a/.github/scripts/test_map_ec2_to_arc.py b/.github/scripts/test_map_ec2_to_arc.py new file mode 100644 index 0000000000000..d139c5253e456 --- /dev/null +++ b/.github/scripts/test_map_ec2_to_arc.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 + +import json +import os +import subprocess +import sys +import tempfile +from pathlib import Path + + +SCRIPT = Path(__file__).resolve().parent / "map_ec2_to_arc.py" + + +def run( + matrix: str, prefix: str = "", github_output: str | None = None +) -> subprocess.CompletedProcess: + cmd = [sys.executable, str(SCRIPT)] + if prefix: + cmd += ["--prefix", prefix] + cmd.append(matrix) + + env = os.environ.copy() + if github_output is not None: + env["GITHUB_OUTPUT"] = github_output + else: + env.pop("GITHUB_OUTPUT", None) + + return subprocess.run(cmd, capture_output=True, text=True, env=env) + + +def parse_output(stdout: str) -> dict: + """Extract the JSON matrix from the 'Setting test-matrix=...' line.""" + prefix = "Setting test-matrix=" + for line in stdout.splitlines(): + if line.startswith(prefix): + return json.loads(line[len(prefix) :]) + raise ValueError(f"no test-matrix output found in: {stdout}") + + +def check(condition: bool, msg: str = "") -> None: + if not condition: + raise AssertionError(msg) + + +def test_basic_matrix(): + matrix = """{ include: [ + { config: "default", shard: 1, num_shards: 1, runner: "linux.4xlarge" }, + { config: "openreg", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, + ]}""" + result = run(matrix) + check(result.returncode == 0, result.stderr) + output = parse_output(result.stdout) + runners = [e["runner"] for e in output["include"]] + check(runners == ["l-x86iavx512-16-128", "l-x86iavx512-8-64"]) + + +def test_matrix_with_prefix(): + matrix = """{ include: [ + { config: "default", shard: 1, num_shards: 7, runner: "mt-linux.4xlarge" }, + { config: "default", shard: 2, num_shards: 7, runner: "mt-linux.4xlarge" }, + { config: "openreg", shard: 1, num_shards: 1, runner: "mt-linux.2xlarge" }, + ]}""" + result = run(matrix, prefix="mt-") + check(result.returncode == 0, result.stderr) + output = parse_output(result.stdout) + runners = [e["runner"] for e in output["include"]] + check( + runners + == [ + "mt-l-x86iavx512-16-128", + "mt-l-x86iavx512-16-128", + "mt-l-x86iavx512-8-64", + ] + ) + + +def test_matrix_without_prefix_when_none_present(): + matrix = """{ include: [ + { config: "default", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + ]}""" + result = run(matrix) + check(result.returncode == 0, result.stderr) + output = parse_output(result.stdout) + check(output["include"][0]["runner"] == "l-x86aavx2-29-113-a10g") + + +def test_unknown_runner_fails(): + matrix = """{ include: [ + { config: "default", shard: 1, num_shards: 1, runner: "bogus.runner" }, + ]}""" + result = run(matrix) + check(result.returncode == 1) + check("no ARC runner found for 'bogus.runner'" in result.stderr) + + +def test_prefix_not_present_on_runner(): + """When --prefix is given but a runner doesn't have it, the raw label is used.""" + matrix = """{ include: [ + { config: "default", shard: 1, num_shards: 1, runner: "linux.4xlarge" }, + ]}""" + result = run(matrix, prefix="mt-") + check(result.returncode == 0, result.stderr) + output = parse_output(result.stdout) + check(output["include"][0]["runner"] == "mt-l-x86iavx512-16-128") + + +def test_preserves_non_runner_fields(): + matrix = """{ include: [ + { config: "default", shard: 3, num_shards: 7, runner: "linux.large" }, + ]}""" + result = run(matrix) + check(result.returncode == 0, result.stderr) + entry = parse_output(result.stdout)["include"][0] + check(entry["config"] == "default") + check(entry["shard"] == 3) + check(entry["num_shards"] == 7) + check(entry["runner"] == "l-x86iavx512-2-4") + + +def test_empty_include_passes_through(): + matrix = """{ include: [] }""" + result = run(matrix) + check(result.returncode == 0, result.stderr) + output = parse_output(result.stdout) + check(output == {"include": []}, f"expected empty include, got {output}") + + +def test_empty_string_passes_through(): + result = run("") + check(result.returncode == 0, result.stderr) + + +def test_mixed_runners(): + matrix = """{ include: [ + { config: "default", shard: 1, num_shards: 1, runner: "linux.4xlarge" }, + { config: "gpu", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "arm", shard: 1, num_shards: 1, runner: "linux.arm64.2xlarge" }, + ]}""" + result = run(matrix) + check(result.returncode == 0, result.stderr) + output = parse_output(result.stdout) + runners = [e["runner"] for e in output["include"]] + check( + runners + == [ + "l-x86iavx512-16-128", + "l-x86aavx2-29-113-a10g", + "l-arm64g2-6-32", + ] + ) + + +def test_passthrough_runner_no_prefix(): + """Passthrough runners (ROCm, XPU) should not get the OSDC prefix.""" + matrix = """{ include: [ + { config: "default", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.2" }, + { config: "default", shard: 1, num_shards: 1, runner: "linux.idc.xpu" }, + ]}""" + result = run(matrix, prefix="mt-") + check(result.returncode == 0, result.stderr) + output = parse_output(result.stdout) + runners = [e["runner"] for e in output["include"]] + check( + runners == ["linux.rocm.gpu.2", "linux.idc.xpu"], + f"passthrough runners should not get prefix, got {runners}", + ) + + +def test_github_output_file(): + """When GITHUB_OUTPUT is set, the script writes test-matrix to that file.""" + matrix = """{ include: [ + { config: "default", shard: 1, num_shards: 1, runner: "linux.2xlarge" }, + ]}""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f: + tmp_path = f.name + + try: + result = run(matrix, github_output=tmp_path) + check(result.returncode == 0, result.stderr) + + contents = Path(tmp_path).read_text() + check( + contents.startswith("test-matrix="), f"unexpected file contents: {contents}" + ) + written = json.loads(contents[len("test-matrix=") :].strip()) + check(written["include"][0]["runner"] == "l-x86iavx512-8-64") + finally: + os.unlink(tmp_path) + + +if __name__ == "__main__": + tests = [v for k, v in sorted(globals().items()) if k.startswith("test_")] + failed = 0 + for t in tests: + try: + t() + print(f" PASS {t.__name__}") + except AssertionError as e: + print(f" FAIL {t.__name__}: {e}") + failed += 1 + if failed: + print(f"\n{failed}/{len(tests)} tests failed") + sys.exit(1) + print(f"\nAll {len(tests)} tests passed") diff --git a/.github/scripts/test_runner_determinator.py b/.github/scripts/test_runner_determinator.py index e8f9f1b8b4aa6..be31ad96baf00 100644 --- a/.github/scripts/test_runner_determinator.py +++ b/.github/scripts/test_runner_determinator.py @@ -146,10 +146,16 @@ def test_parse_users(self) -> None: """ users = rd.parse_users(settings_text) - self.assertDictEqual( - {"User1": ["lf"], "User2": ["lf", "otherExp"]}, - users, - "Users not parsed correctly", + self.assertEqual( + [rd.UserExperimentConfig("lf", 100)], + users["User1"], + ) + self.assertEqual( + [ + rd.UserExperimentConfig("lf", 100), + rd.UserExperimentConfig("otherExp", 100), + ], + users["User2"], ) def test_parse_users_without_settings(self) -> None: @@ -161,10 +167,97 @@ def test_parse_users_without_settings(self) -> None: """ users = rd.parse_users(settings_text) - self.assertDictEqual( - {"User1": ["lf"], "User2": ["lf", "otherExp"]}, - users, - "Users not parsed correctly", + self.assertEqual( + [rd.UserExperimentConfig("lf", 100)], + users["User1"], + ) + self.assertEqual( + [ + rd.UserExperimentConfig("lf", 100), + rd.UserExperimentConfig("otherExp", 100), + ], + users["User2"], + ) + + def test_parse_users_with_rollout_perc(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + arc: + rollout_perc: 0 + --- + + Users: + @User1,lf,arc:10 + @User2,arc:50 + @User3,lf + + """ + + users = rd.parse_users(settings_text) + self.assertEqual( + [ + rd.UserExperimentConfig("lf", 100), + rd.UserExperimentConfig("arc", 10), + ], + users["User1"], + ) + self.assertEqual( + [rd.UserExperimentConfig("arc", 50)], + users["User2"], + ) + self.assertEqual( + [rd.UserExperimentConfig("lf", 100)], + users["User3"], + ) + + def test_parse_users_invalid_percentage_defaults_to_100(self) -> None: + """Non-numeric percentage like arc:abc should default to 100%.""" + settings_text = """ + @User1,arc:abc + """ + + users = rd.parse_users(settings_text) + self.assertEqual( + [rd.UserExperimentConfig("arc", 100)], + users["User1"], + ) + + def test_parse_users_negative_percentage_clamped_to_zero(self) -> None: + """Negative percentage like arc:-5 should be clamped to 0.""" + settings_text = """ + @User1,arc:-5 + """ + + users = rd.parse_users(settings_text) + self.assertEqual( + [rd.UserExperimentConfig("arc", 0)], + users["User1"], + ) + + def test_parse_users_over_100_percentage_clamped(self) -> None: + """Percentage over 100 like arc:200 should be clamped to 100.""" + settings_text = """ + @User1,arc:200 + """ + + users = rd.parse_users(settings_text) + self.assertEqual( + [rd.UserExperimentConfig("arc", 100)], + users["User1"], + ) + + def test_parse_users_opt_out_ignores_percentage(self) -> None: + """Opt-out entries like -lf should not parse a percentage.""" + settings_text = """ + @User1,-lf + """ + + users = rd.parse_users(settings_text) + self.assertEqual( + [rd.UserExperimentConfig("-lf", 100)], + users["User1"], ) @@ -183,8 +276,8 @@ def test_opted_in_user(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) - self.assertEqual("lf.", prefix, "Runner prefix not correct for User1") + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("lf.", result.prefix, "Runner prefix not correct for User1") def test_explicitly_opted_out_user(self) -> None: settings_text = """ @@ -200,8 +293,8 @@ def test_explicitly_opted_out_user(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) - self.assertEqual("", prefix, "Runner prefix not correct for User1") + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for User1") def test_explicitly_opted_in_and_out_user_should_opt_out(self) -> None: settings_text = """ @@ -217,8 +310,8 @@ def test_explicitly_opted_in_and_out_user_should_opt_out(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) - self.assertEqual("", prefix, "Runner prefix not correct for User1") + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for User1") def test_opted_in_user_two_experiments(self) -> None: settings_text = """ @@ -234,8 +327,10 @@ def test_opted_in_user_two_experiments(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for User2") + result = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) + self.assertEqual( + "lf.otherExp.", result.prefix, "Runner prefix not correct for User2" + ) def test_opted_in_user_two_experiments_default(self) -> None: settings_text = """ @@ -252,8 +347,8 @@ def test_opted_in_user_two_experiments_default(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) - self.assertEqual("lf.", prefix, "Runner prefix not correct for User2") + result = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) + self.assertEqual("lf.", result.prefix, "Runner prefix not correct for User2") def test_opted_in_user_two_experiments_default_exp(self) -> None: settings_text = """ @@ -270,10 +365,12 @@ def test_opted_in_user_two_experiments_default_exp(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix( + result = rd.get_runner_prefix( settings_text, ["User2"], USER_BRANCH, frozenset(["lf", "otherExp"]) ) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for User2") + self.assertEqual( + "lf.otherExp.", result.prefix, "Runner prefix not correct for User2" + ) def test_opted_in_user_two_experiments_default_exp_2(self) -> None: settings_text = """ @@ -290,10 +387,12 @@ def test_opted_in_user_two_experiments_default_exp_2(self) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix( + result = rd.get_runner_prefix( settings_text, ["User2"], USER_BRANCH, frozenset(["otherExp"]) ) - self.assertEqual("otherExp.", prefix, "Runner prefix not correct for User2") + self.assertEqual( + "otherExp.", result.prefix, "Runner prefix not correct for User2" + ) @patch("random.uniform", return_value=50) def test_opted_out_user(self, mock_uniform: Mock) -> None: @@ -310,8 +409,8 @@ def test_opted_out_user(self, mock_uniform: Mock) -> None: @User2,lf,otherExp """ - prefix = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) - self.assertEqual("", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for user") @patch("random.uniform", return_value=10) def test_opted_out_user_was_pulled_in_by_rollout(self, mock_uniform: Mock) -> None: @@ -330,8 +429,10 @@ def test_opted_out_user_was_pulled_in_by_rollout(self, mock_uniform: Mock) -> No """ # User3 is opted out, but is pulled into both experiments by the 10% rollout - prefix = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual( + "lf.otherExp.", result.prefix, "Runner prefix not correct for user" + ) @patch("random.uniform", return_value=10) def test_opted_out_user_was_pulled_in_by_rollout_excl_nondefault( @@ -353,8 +454,8 @@ def test_opted_out_user_was_pulled_in_by_rollout_excl_nondefault( """ # User3 is opted out, but is pulled into default experiments by the 10% rollout - prefix = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) - self.assertEqual("lf.", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("lf.", result.prefix, "Runner prefix not correct for user") @patch("random.uniform", return_value=10) def test_opted_out_user_was_pulled_in_by_rollout_filter_exp( @@ -376,10 +477,12 @@ def test_opted_out_user_was_pulled_in_by_rollout_filter_exp( """ # User3 is opted out, but is pulled into default experiments by the 10% rollout - prefix = rd.get_runner_prefix( + result = rd.get_runner_prefix( settings_text, ["User3"], USER_BRANCH, frozenset(["otherExp"]) ) - self.assertEqual("otherExp.", prefix, "Runner prefix not correct for user") + self.assertEqual( + "otherExp.", result.prefix, "Runner prefix not correct for user" + ) @patch("random.uniform", return_value=25) def test_opted_out_user_was_pulled_out_by_rollout_filter_exp( @@ -401,8 +504,8 @@ def test_opted_out_user_was_pulled_out_by_rollout_filter_exp( """ # User3 is opted out, but is pulled into default experiments by the 10% rollout - prefix = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) - self.assertEqual("", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for user") def test_lf_prefix_always_comes_first(self) -> None: settings_text = """ @@ -419,8 +522,10 @@ def test_lf_prefix_always_comes_first(self) -> None: """ - prefix = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) + self.assertEqual( + "lf.otherExp.", result.prefix, "Runner prefix not correct for user" + ) def test_ignores_commented_users(self) -> None: settings_text = """ @@ -437,8 +542,8 @@ def test_ignores_commented_users(self) -> None: """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) - self.assertEqual("", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for user") def test_ignores_extra_experiments(self) -> None: settings_text = """ @@ -456,8 +561,10 @@ def test_ignores_extra_experiments(self) -> None: """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) - self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual( + "lf.otherExp.", result.prefix, "Runner prefix not correct for user" + ) def test_disables_experiment_on_exception_branches_when_not_explicitly_opted_in( self, @@ -473,8 +580,8 @@ def test_disables_experiment_on_exception_branches_when_not_explicitly_opted_in( """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], EXCEPTION_BRANCH) - self.assertEqual("", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User1"], EXCEPTION_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for user") def test_allows_experiment_on_exception_branches_when_explicitly_opted_in( self, @@ -491,8 +598,341 @@ def test_allows_experiment_on_exception_branches_when_explicitly_opted_in( """ - prefix = rd.get_runner_prefix(settings_text, ["User1"], EXCEPTION_BRANCH) - self.assertEqual("lf.", prefix, "Runner prefix not correct for user") + result = rd.get_runner_prefix(settings_text, ["User1"], EXCEPTION_BRANCH) + self.assertEqual("lf.", result.prefix, "Runner prefix not correct for user") + + @patch("random.uniform", return_value=5) + def test_opted_in_user_with_rollout_perc_enabled(self, mock_uniform: Mock) -> None: + """User opted in with 10% rollout, random=5 -> enabled""" + settings_text = """ + experiments: + lf: + rollout_perc: 0 + --- + + Users: + @User1,lf:10 + + """ + + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("lf.", result.prefix, "Runner prefix not correct for user") + + @patch("random.uniform", return_value=50) + def test_opted_in_user_with_rollout_perc_disabled(self, mock_uniform: Mock) -> None: + """User opted in with 10% rollout, random=50 -> disabled""" + settings_text = """ + experiments: + lf: + rollout_perc: 0 + --- + + Users: + @User1,lf:10 + + """ + + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for user") + + def test_opted_in_user_without_rollout_perc_always_enabled(self) -> None: + """User opted in without percentage (default 100%) -> always enabled""" + settings_text = """ + experiments: + lf: + rollout_perc: 0 + --- + + Users: + @User1,lf + + """ + + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("lf.", result.prefix, "Runner prefix not correct for user") + + @patch("random.uniform", return_value=15) + def test_multiple_requesters_uses_min_perc(self, mock_uniform: Mock) -> None: + """Two requesters with different rollout_percs, uses the minimum (10%).""" + settings_text = """ + experiments: + lf: + rollout_perc: 0 + --- + + Users: + @User1,lf:10 + @User2,lf:50 + + """ + + # random=15, min_perc=10 -> 15 > 10 -> disabled + result = rd.get_runner_prefix(settings_text, ["User1", "User2"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for user") + + @patch("random.uniform", return_value=5) + def test_multiple_requesters_uses_min_perc_enabled( + self, mock_uniform: Mock + ) -> None: + """Two requesters with different rollout_percs, min=10%, random=5 -> enabled.""" + settings_text = """ + experiments: + lf: + rollout_perc: 0 + --- + + Users: + @User1,lf:10 + @User2,lf:50 + + """ + + result = rd.get_runner_prefix(settings_text, ["User1", "User2"], USER_BRANCH) + self.assertEqual("lf.", result.prefix, "Runner prefix not correct for user") + + def test_opt_out_overrides_rollout_perc(self) -> None: + """Opt-out (-lf) wins over opt-in with rollout_perc (lf:50).""" + settings_text = """ + experiments: + lf: + rollout_perc: 100 + --- + + Users: + @User1,-lf,lf:50 + + """ + + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for user") + + @patch("random.uniform", return_value=5) + def test_opted_in_user_with_rollout_perc_two_experiments( + self, mock_uniform: Mock + ) -> None: + """User opted into lf at 100% and otherExp at 10%, random=5 -> both enabled""" + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + --- + + Users: + @User1,lf,otherExp:10 + + """ + + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual( + "lf.otherExp.", result.prefix, "Runner prefix not correct for user" + ) + + @patch("random.uniform", return_value=50) + def test_opted_in_user_with_rollout_perc_partial_enable( + self, mock_uniform: Mock + ) -> None: + """User opted into lf at 100% and otherExp at 10%, random=50 -> only lf enabled""" + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + --- + + Users: + @User1,lf,otherExp:10 + + """ + + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("lf.", result.prefix, "Runner prefix not correct for user") + + def test_opted_in_user_with_zero_rollout_perc(self) -> None: + """User opted in with 0% rollout -> never enabled""" + settings_text = """ + experiments: + lf: + rollout_perc: 0 + --- + + Users: + @User1,lf:0 + + """ + + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("", result.prefix, "Runner prefix not correct for user") + + @patch("random.uniform", return_value=5) + def test_arc_opted_in_user_with_rollout_perc_enabled( + self, mock_uniform: Mock + ) -> None: + """User opted into arc with 10% rollout, random=5 -> arc enabled""" + settings_text = """ + experiments: + arc: + rollout_perc: 0 + --- + + Users: + @User1,arc:10 + + """ + + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("mt-", result.prefix) + self.assertTrue(result.use_arc) + + @patch("random.uniform", return_value=50) + def test_arc_opted_in_user_with_rollout_perc_disabled( + self, mock_uniform: Mock + ) -> None: + """User opted into arc with 10% rollout, random=50 -> arc disabled""" + settings_text = """ + experiments: + arc: + rollout_perc: 0 + --- + + Users: + @User1,arc:10 + + """ + + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("", result.prefix) + self.assertFalse(result.use_arc) + + +class TestRunnerDeterminatorArcExperiment(TestCase): + ARC_SETTINGS = """ + experiments: + arc: + rollout_perc: 0 + --- + + Users: + @User1,arc + @User2,lf + + """ + + def test_arc_opted_in_user_returns_mt_prefix(self) -> None: + result = rd.get_runner_prefix(self.ARC_SETTINGS, ["User1"], USER_BRANCH) + self.assertEqual("mt-", result.prefix) + self.assertTrue(result.use_arc) + + def test_arc_opted_in_user_canary_returns_c_mt_prefix(self) -> None: + result = rd.get_runner_prefix( + self.ARC_SETTINGS, ["User1"], USER_BRANCH, is_canary=True + ) + self.assertEqual("c-mt-", result.prefix) + self.assertTrue(result.use_arc) + + def test_arc_not_enabled_returns_use_arc_false(self) -> None: + result = rd.get_runner_prefix(self.ARC_SETTINGS, ["User2"], USER_BRANCH) + self.assertFalse(result.use_arc) + + def test_arc_not_enabled_no_experiments_returns_use_arc_false(self) -> None: + settings_text = """ + experiments: + arc: + rollout_perc: 0 + --- + + Users: + @User1,arc + + """ + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("", result.prefix) + self.assertFalse(result.use_arc) + + @patch("random.uniform", return_value=10) + def test_arc_rollout_percentage(self, mock_uniform: Mock) -> None: + settings_text = """ + experiments: + arc: + rollout_perc: 25 + --- + + Users: + + """ + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("mt-", result.prefix) + self.assertTrue(result.use_arc) + + @patch("random.uniform", return_value=50) + def test_arc_rollout_percentage_not_selected(self, mock_uniform: Mock) -> None: + settings_text = """ + experiments: + arc: + rollout_perc: 25 + --- + + Users: + + """ + result = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("", result.prefix) + self.assertFalse(result.use_arc) + + def test_arc_opted_out_user(self) -> None: + settings_text = """ + experiments: + arc: + rollout_perc: 100 + --- + + Users: + @User1,-arc + + """ + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("", result.prefix) + self.assertFalse(result.use_arc) + + def test_arc_exception_branch_not_enabled(self) -> None: + result = rd.get_runner_prefix(self.ARC_SETTINGS, ["User1"], EXCEPTION_BRANCH) + self.assertEqual("", result.prefix) + self.assertFalse(result.use_arc) + + def test_arc_exception_branch_all_branches(self) -> None: + settings_text = """ + experiments: + arc: + rollout_perc: 0 + all_branches: true + --- + + Users: + @User1,arc + + """ + result = rd.get_runner_prefix(settings_text, ["User1"], EXCEPTION_BRANCH) + self.assertEqual("mt-", result.prefix) + self.assertTrue(result.use_arc) + + def test_arc_takes_precedence_over_lf(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + arc: + rollout_perc: 0 + --- + + Users: + @User1,lf,arc + + """ + result = rd.get_runner_prefix(settings_text, ["User1"], USER_BRANCH) + self.assertEqual("mt-", result.prefix) + self.assertTrue(result.use_arc) if __name__ == "__main__": diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 761af8b691d9b..658728b096d47 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -2034,9 +2034,13 @@ def validate_revert( # For some reason, one can not be a member of private repo, only CONTRIBUTOR if pr.is_base_repo_private(): allowed_reverters.append("CONTRIBUTOR") - # Special case the pytorch-auto-revert app, whose does not have association - # But should be able to issue revert command - if comment.author_url == "https://github.com/apps/pytorch-auto-revert": + # Special case GitHub Apps that don't have a repo association + # but should be able to issue revert commands + allowed_apps = { + "https://github.com/apps/pytorch-auto-revert", + "https://github.com/apps/facebook-github-tools", + } + if comment.author_url in allowed_apps: allowed_reverters.append("NONE") if author_association not in allowed_reverters: @@ -2174,6 +2178,13 @@ def try_revert( f"Failed to fetch dependent PRs: {str(e)}, fall over to single revert" ) + if not shas_and_prs: + raise RuntimeError( + f"No revertable PRs found in ghstack for #{pr.pr_num}. " + f"This typically means the PR is still open (not merged) or " + f"its GitHub state is inconsistent. Only closed/merged PRs can be reverted." + ) + do_revert_prs( repo, pr, diff --git a/.github/scripts/update_runner_determinator.py b/.github/scripts/update_runner_determinator.py deleted file mode 100755 index 772df87c6405a..0000000000000 --- a/.github/scripts/update_runner_determinator.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python3 - -import re - - -# Read the contents of runner_determinator.py -with open(".github/scripts/runner_determinator.py") as script_file: - script_content = script_file.read() - -# Indent the script content by 10 spaces to match destination indentation -indented_script_content = "\n".join( - [" " * 10 + line if line else line for line in script_content.splitlines()] -) - -# Read the contents of _runner-determinator.yml -with open(".github/workflows/_runner-determinator.yml") as yml_file: - yml_content = yml_file.read() - -# Replace the content between the markers -new_yml_content = re.sub( - r"(cat < runner_determinator.py\n)(.*?)(\n\s+EOF)", - lambda match: match.group(1) + indented_script_content + match.group(3), - yml_content, - flags=re.DOTALL, -) - -# Save the modified content back to _runner-determinator.yml -with open(".github/workflows/_runner-determinator.yml", "w") as yml_file: - yml_file.write(new_yml_content) - -print("Updated _runner-determinator.yml with the contents of runner_determinator.py") diff --git a/.github/scripts/windows/build_magma.bat b/.github/scripts/windows/build_magma.bat index 75c916ecdbef7..ddb38c9ed9e6d 100644 --- a/.github/scripts/windows/build_magma.bat +++ b/.github/scripts/windows/build_magma.bat @@ -35,6 +35,9 @@ cd magma mkdir build && cd build set GPU_TARGET=All +if "%CUVER_NODOT%" == "132" ( + set CUDA_ARCH_LIST=-gencode=arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 -gencode arch=compute_100,code=sm_100 -gencode arch=compute_120,code=sm_120 +) if "%CUVER_NODOT%" == "130" ( set CUDA_ARCH_LIST=-gencode=arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90 -gencode arch=compute_100,code=sm_100 -gencode arch=compute_120,code=sm_120 ) diff --git a/.github/templates/linux_binary_build_workflow.yml.j2 b/.github/templates/linux_binary_build_workflow.yml.j2 index 21a978f801116..4182ae6c33874 100644 --- a/.github/templates/linux_binary_build_workflow.yml.j2 +++ b/.github/templates/linux_binary_build_workflow.yml.j2 @@ -79,7 +79,7 @@ jobs: timeout-minutes: 420 {%- elif config["gpu_arch_type"] == "rocm" %} runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 {%- elif "conda" in build_environment and config["gpu_arch_type"] == "cuda" %} runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.24xlarge.ephemeral @@ -183,7 +183,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 diff --git a/.github/workflows/_bazel-build-test.yml b/.github/workflows/_bazel-build-test.yml deleted file mode 100644 index eaebce92ba898..0000000000000 --- a/.github/workflows/_bazel-build-test.yml +++ /dev/null @@ -1,215 +0,0 @@ -name: bazel - -on: - workflow_call: - inputs: - build-environment: - required: true - type: string - description: Top-level label for what's being built/tested. - docker-image-name: - required: true - type: string - description: Name of the base docker image to build with. - cuda-version: - required: true - type: string - description: What CUDA version to build with (i.e. "11.7"), "cpu" for none. - sync-tag: - required: false - type: string - default: "" - description: | - If this is set, our linter will use this to make sure that every other - job with the same `sync-tag` is identical. - test-matrix: - required: true - type: string - description: | - A JSON description of what configs to run later on. - runner: - required: false - type: string - default: "linux.large" - description: Runner type - -env: - GIT_DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} - -jobs: - filter: - if: github.repository_owner == 'pytorch' - runs-on: ${{ inputs.runner }} - outputs: - test-matrix: ${{ steps.filter.outputs.test-matrix }} - is-test-matrix-empty: ${{ steps.filter.outputs.is-test-matrix-empty }} - keep-going: ${{ steps.filter.outputs.keep-going }} - reenabled-issues: ${{ steps.filter.outputs.reenabled-issues }} - steps: - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - with: - fetch-depth: 1 - submodules: false - - - name: Select all requested test configurations - id: filter - uses: ./.github/actions/filter-test-configs - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - test-matrix: ${{ inputs.test-matrix }} - - build-and-test: - needs: filter - # Don't run on forked repos. - if: github.repository_owner == 'pytorch' && needs.filter.outputs.is-test-matrix-empty == 'False' - strategy: - matrix: ${{ fromJSON(needs.filter.outputs.test-matrix) }} - fail-fast: false - runs-on: ${{ matrix.runner }} - steps: - - name: Setup SSH (Click me for login details) - uses: pytorch/test-infra/.github/actions/setup-ssh@main - with: - github-secret: ${{ secrets.GITHUB_TOKEN }} - - # [see note: pytorch repo ref] - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - - - name: Setup Linux - uses: ./.github/actions/setup-linux - - - name: Login to ECR - uses: ./.github/actions/ecr-login - - - name: Calculate docker image - id: calculate-docker-image - uses: pytorch/test-infra/.github/actions/calculate-docker-image@main - with: - docker-image-name: ${{ inputs.docker-image-name }} - - - name: Pull docker image - uses: pytorch/test-infra/.github/actions/pull-docker-image@main - with: - docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - - - name: Check if in a container runner - shell: bash - id: check_container_runner - run: echo "IN_CONTAINER_RUNNER=$(if [ -f /.inarc ] || [ -f /.incontainer ]; then echo true ; else echo false; fi)" >> "$GITHUB_OUTPUT" - - - name: Install nvidia driver, nvidia-docker runtime, set GPU_FLAG - uses: pytorch/test-infra/.github/actions/setup-nvidia@main - - - name: Output disk space left - run: | - sudo df -H - - - name: Preserve github env variables for use in docker - run: | - env | grep '^GITHUB' >> "/tmp/github_env_${GITHUB_RUN_ID}" - env | grep '^CI' >> "/tmp/github_env_${GITHUB_RUN_ID}" - - - name: Parse ref - id: parse-ref - run: .github/scripts/parse_ref.py - - - name: Get workflow job id - id: get-job-id - uses: ./.github/actions/get-workflow-job-id - if: always() - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - - - name: Build - env: - BUILD_ENVIRONMENT: ${{ inputs.build-environment }} - PR_NUMBER: ${{ github.event.pull_request.number }} - BRANCH: ${{ steps.parse-ref.outputs.branch }} - GITHUB_REPOSITORY: ${{ github.repository }} - GITHUB_WORKFLOW: ${{ github.workflow }} - GITHUB_JOB: ${{ github.job }} - GITHUB_RUN_ID: ${{ github.run_id }} - GITHUB_RUN_NUMBER: ${{ github.run_number }} - GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }} - JOB_ID: ${{ steps.get-job-id.outputs.job-id }} - REENABLED_ISSUES: ${{ needs.filter.outputs.reenabled-issues }} - # TODO duplicated - AWS_DEFAULT_REGION: us-east-1 - SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 - SCCACHE_REGION: us-east-1 - TORCH_CUDA_ARCH_LIST: 5.2 - DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} - OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} - CUDA_VERSION: ${{ inputs.cuda-version }} - run: | - export SHARD_NUMBER=0 - # detached container should get cleaned up by teardown_ec2_linux - # TODO: Stop building test binaries as part of the build phase - # Make sure we copy test results from bazel-testlogs symlink to - # a regular directory ./test/test-reports - # shellcheck disable=SC2086 - container_name=$(docker run \ - ${GPU_FLAG:-} \ - -e AWS_DEFAULT_REGION \ - -e BUILD_ENVIRONMENT \ - -e GITHUB_ACTIONS \ - -e GITHUB_REPOSITORY \ - -e GITHUB_WORKFLOW \ - -e GITHUB_JOB \ - -e GITHUB_RUN_NUMBER \ - -e GITHUB_RUN_ATTEMPT \ - -e JOB_ID \ - -e GIT_DEFAULT_BRANCH="$GIT_DEFAULT_BRANCH" \ - -e SHARD_NUMBER \ - -e NUM_TEST_SHARDS \ - -e MAX_JOBS="$(nproc --ignore=2)" \ - -e SCCACHE_BUCKET \ - -e SCCACHE_REGION \ - -e SKIP_SCCACHE_INITIALIZATION=1 \ - -e REENABLED_ISSUES \ - -e TORCH_CUDA_ARCH_LIST \ - -e OUR_GITHUB_JOB_ID \ - -e CUDA_VERSION \ - --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ - --security-opt seccomp=unconfined \ - --cap-add=SYS_PTRACE \ - --shm-size="1g" \ - --tty \ - --detach \ - --user jenkins \ - -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ - -w /var/lib/jenkins/workspace \ - "${DOCKER_IMAGE}" - ) - docker exec -t "${container_name}" sh -c '.ci/pytorch/build.sh' - echo "container_id=${container_name}" >> "${GITHUB_ENV}" - - - name: Test - id: test - # Time out the test phase after 3.5 hours - timeout-minutes: 120 - run: | - docker exec -t "${container_id}" sh -c '.ci/pytorch/test.sh && cp -Lr ./bazel-testlogs ./test/test-reports' - - - name: Print remaining test logs - shell: bash - if: always() && steps.test.conclusion - run: | - cat test/**/*_toprint.log || true - - - name: Chown workspace - uses: ./.github/actions/chown-workspace - if: always() - - - name: Upload test artifacts - uses: ./.github/actions/upload-test-artifacts - if: always() && steps.test.conclusion && steps.test.conclusion != 'skipped' - with: - file-suffix: bazel-${{ github.job }}_${{ steps.get-job-id.outputs.job-id }} - - - name: Teardown Linux - uses: pytorch/test-infra/.github/actions/teardown-linux@main - if: always() diff --git a/.github/workflows/_binary-build-flash-attention-wheel-linux.yml b/.github/workflows/_binary-build-flash-attention-wheel-linux.yml index 3fdc1dc4175c9..aa945af50d39c 100644 --- a/.github/workflows/_binary-build-flash-attention-wheel-linux.yml +++ b/.github/workflows/_binary-build-flash-attention-wheel-linux.yml @@ -88,14 +88,11 @@ jobs: github-secret: ${{ secrets.GITHUB_TOKEN }} fail-silently: false - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + - name: Setup Linux + uses: pytorch/pytorch/.github/actions/setup-linux@main with: submodules: true - - name: Setup Linux - uses: ./.github/actions/setup-linux - - name: Pull Docker image uses: pytorch/test-infra/.github/actions/pull-docker-image@main with: diff --git a/.github/workflows/_binary-build-flash-attention-wheel-windows.yml b/.github/workflows/_binary-build-flash-attention-wheel-windows.yml index 4fc1dc8a53367..5339753a9526f 100644 --- a/.github/workflows/_binary-build-flash-attention-wheel-windows.yml +++ b/.github/workflows/_binary-build-flash-attention-wheel-windows.yml @@ -90,7 +90,7 @@ jobs: echo "CUDA_HOME=${CUDA_PATH}" >> "${GITHUB_ENV}" echo "${CUDA_PATH}/bin" >> "${GITHUB_PATH}" - name: Setup MSVC - uses: ilammy/msvc-dev-cmd@v1 + uses: ilammy/msvc-dev-cmd@0b201ec74fa43914dc39ae48a89fd1d8cb592756 # v1 - name: Remove link.exe conflict run: rm -f /usr/bin/link diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml index 7f7a6e20d96d5..2db2bb75bf978 100644 --- a/.github/workflows/_binary-build-linux.yml +++ b/.github/workflows/_binary-build-linux.yml @@ -154,7 +154,7 @@ jobs: - name: Setup Linux if: inputs.build_environment != 'linux-s390x-binary-manywheel' - uses: ./.github/actions/setup-linux + uses: pytorch/pytorch/.github/actions/setup-linux@main - name: Login to ECR if: inputs.build_environment != 'linux-s390x-binary-manywheel' @@ -180,6 +180,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + fetch-depth: 2 submodules: recursive path: pytorch show-progress: false @@ -224,6 +225,7 @@ jobs: # The build.sh script in this folder is not actually the correct one, # this is just needed for sha calculation docker-build-dir: .ci/docker + docker-build-script: ${{ contains(inputs.DOCKER_IMAGE, 'manylinux') && 'manywheel/build.sh' || 'build.sh' }} working-directory: pytorch - name: Pull Docker image diff --git a/.github/workflows/_binary-test-linux.yml b/.github/workflows/_binary-test-linux.yml index ec2bd6ffe5017..4dc035055955e 100644 --- a/.github/workflows/_binary-test-linux.yml +++ b/.github/workflows/_binary-test-linux.yml @@ -138,7 +138,7 @@ jobs: - name: Setup Linux if: inputs.build_environment != 'linux-s390x-binary-manywheel' - uses: ./.github/actions/setup-linux + uses: pytorch/pytorch/.github/actions/setup-linux@main - name: Login to ECR if: inputs.build_environment != 'linux-s390x-binary-manywheel' diff --git a/.github/workflows/_docs.yml b/.github/workflows/_docs.yml index f5cb186c6f189..616f7027f8a38 100644 --- a/.github/workflows/_docs.yml +++ b/.github/workflows/_docs.yml @@ -47,6 +47,26 @@ on: description: prefix for runner label type: string default: "" + use-arc: + required: false + type: boolean + default: false + description: If true, use ARC (OSDC) runner path instead of EC2. + python-version: + required: false + type: string + default: "" + description: Python version to use for the OSDC build. + compiler: + required: false + type: string + default: "" + description: Compiler to use for the OSDC build. + cuda-version: + required: false + type: string + default: "" + description: CUDA version to use for the OSDC build. secrets: GH_PYTORCHBOT_TOKEN: required: false @@ -55,7 +75,7 @@ on: jobs: build-docs: # Don't run on forked repos. - if: github.repository_owner == 'pytorch' + if: github.repository_owner == 'pytorch' && !inputs.use-arc runs-on: ${{ matrix.runner }} environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'pytorchbot-env' || '' }} strategy: @@ -73,8 +93,7 @@ jobs: timeout-minutes: 360 - docs_type: python runner: ${{ inputs.runner_prefix }}linux.c7i.2xlarge - # It takes less than 30m to finish python docs unless there are issues - timeout-minutes: 30 + timeout-minutes: 45 # Set a fixed name for this job instead of using the current matrix-generated name, i.e. build-docs (cpp, linux.12xlarge, 180) # The current name requires updating the database last docs push query from test-infra every time the matrix is updated name: build-docs-${{ matrix.docs_type }}-${{ inputs.push }} @@ -89,12 +108,8 @@ jobs: To start Python docs build type: cd docs && make html && make coverage - # [see note: pytorch repo ref] - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - - name: Setup Linux - uses: ./.github/actions/setup-linux + uses: pytorch/pytorch/.github/actions/setup-linux@main - name: Login to ECR uses: ./.github/actions/ecr-login @@ -160,6 +175,7 @@ jobs: -e RUN_DOXYGEN \ -e WITH_PUSH \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ + -e GITHUB_WORKSPACE=/var/lib/jenkins/workspace \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ --tty \ @@ -204,6 +220,40 @@ jobs: path: cppdocs/ s3-prefix: pytorch/pytorch/${{ github.event.pull_request.number }}/cppdocs + - name: Post C++ Docs Coverage Comment + if: ${{ github.event_name == 'pull_request' && matrix.docs_type == 'cpp' && steps.build-docs.outcome == 'success' }} + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + PR_NUM="${{ github.event.pull_request.number }}" + MARKER="" + + # Only post if the PR touches docs/cpp + if ! gh pr diff "$PR_NUM" --name-only | grep -q '^docs/cpp/'; then + echo "No changes to docs/cpp/, skipping coverage comment." + exit 0 + fi + + # Skip if we already posted + if gh pr view "$PR_NUM" --json comments --jq '.comments[].body' | grep -q "$MARKER"; then + echo "Coverage comment already posted, skipping." + exit 0 + fi + + body="${MARKER} + ## C++ Docs + + - [Doc Preview](https://docs-preview.pytorch.org/pytorch/pytorch/${PR_NUM}/cppdocs/index.html) + - [API Coverage Report](https://docs-preview.pytorch.org/pytorch/pytorch/${PR_NUM}/cppdocs/_coverage/cpp_coverage.txt) + - [HTML Issues Report](https://docs-preview.pytorch.org/pytorch/pytorch/${PR_NUM}/cppdocs/_coverage/cpp_html_issues.txt)" + + if [ -f cppdocs/_coverage/index.html ]; then + body="${body} + - [Coverxygen Report (interactive)](https://docs-preview.pytorch.org/pytorch/pytorch/${PR_NUM}/cppdocs/_coverage/index.html)" + fi + + gh pr comment "$PR_NUM" --body "$body" + - name: Upload C++ Docs Preview (nightly dry-run) if: ${{ !inputs.push && github.event_name != 'pull_request' && matrix.docs_type == 'cpp' && steps.build-docs.outcome == 'success' }} run: | @@ -215,3 +265,157 @@ jobs: - name: Teardown Linux uses: pytorch/test-infra/.github/actions/teardown-linux@main if: always() + + build-docs-osdc: + if: github.repository_owner == 'pytorch' && inputs.use-arc + permissions: + id-token: write + contents: read + actions: read + runs-on: ${{ matrix.runner }} + container: + image: ${{ inputs.docker-image }} + strategy: + fail-fast: false + matrix: + include: + - docs_type: cpp + runner: ${{ inputs.runner_prefix }}l-x86iavx512-48-384 + timeout-minutes: 360 + - docs_type: python + runner: ${{ inputs.runner_prefix }}l-x86iavx512-16-128 + timeout-minutes: 45 + name: build-docs-${{ matrix.docs_type }}-${{ inputs.push }} + steps: + - name: Setup Linux + id: setup-linux + uses: pytorch/pytorch/.github/actions/setup-linux@main + with: + use-arc: true + python-version: ${{ inputs.python-version }} + compiler: ${{ inputs.compiler }} + cuda-version: ${{ inputs.cuda-version }} + submodules: false + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Configure AWS credentials + id: aws-creds + continue-on-error: true + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 + with: + role-to-assume: arn:aws:iam::308535385114:role/arc + aws-region: us-east-1 + role-duration-seconds: 18000 + + - name: Download build artifacts + uses: pytorch/pytorch/.github/actions/download-build-artifacts@main + with: + name: ${{ inputs.build-environment }} + s3-bucket: ${{ inputs.s3-bucket }} + use-gha: ${{ steps.aws-creds.outcome != 'success' }} + + - name: Generate netrc (only for docs-push) + if: inputs.push + env: + GITHUB_PYTORCHBOT_TOKEN: ${{ secrets.GH_PYTORCHBOT_TOKEN }} + run: | + rm -rf "${HOME}/.netrc" + echo "machine github.com" > "${HOME}/.netrc" + echo "login pytorchbot" >> "${HOME}/.netrc" + echo "password ${GITHUB_PYTORCHBOT_TOKEN}" >> "${HOME}/.netrc" + + - name: Build ${{ matrix.docs_type }} docs + timeout-minutes: ${{ matrix.timeout-minutes }} + id: build-docs + env: + WITH_PUSH: ${{ inputs.push }} + DOCS_TYPE: ${{ matrix.docs_type }} + RUN_DOXYGEN: ${{ inputs.run-doxygen }} + BUILD_ENVIRONMENT: ${{ inputs.build-environment }} + SHA1: ${{ github.sha }} + shell: bash + run: | + set -ex + if [[ "${GITHUB_REF}" =~ ^refs/tags/v([0-9]+\.[0-9]+)\.* ]]; then + target="${BASH_REMATCH[1]}" + else + target="main" + fi + export DOCS_VERSION="${target}" + pip install $(echo dist/*.whl)[opt-einsum] + ./.ci/pytorch/${DOCS_TYPE}_doc_push_script.sh + + - name: Upload Python Docs Preview + uses: seemethere/upload-artifact-s3@baba72d0712b404f646cebe0730933554ebce96a # v5.1.0 + if: ${{ steps.aws-creds.outcome == 'success' && github.event_name == 'pull_request' && matrix.docs_type == 'python' && steps.build-docs.outcome == 'success' }} + with: + retention-days: 14 + s3-bucket: doc-previews + if-no-files-found: error + path: pytorch_docs/main/ + s3-prefix: pytorch/pytorch/${{ github.event.pull_request.number }} + + - name: Upload C++ Docs Preview + uses: seemethere/upload-artifact-s3@baba72d0712b404f646cebe0730933554ebce96a # v5.1.0 + if: ${{ steps.aws-creds.outcome == 'success' && github.event_name == 'pull_request' && matrix.docs_type == 'cpp' && steps.build-docs.outcome == 'success' }} + with: + retention-days: 14 + if-no-files-found: error + s3-bucket: doc-previews + path: cppdocs/ + s3-prefix: pytorch/pytorch/${{ github.event.pull_request.number }}/cppdocs + + - name: Post C++ Docs Coverage Comment + if: ${{ steps.aws-creds.outcome == 'success' && github.event_name == 'pull_request' && matrix.docs_type == 'cpp' && steps.build-docs.outcome == 'success' }} + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + PR_NUM="${{ github.event.pull_request.number }}" + MARKER="" + + # Only post if the PR touches docs/cpp + if ! gh pr diff "$PR_NUM" --name-only | grep -q '^docs/cpp/'; then + echo "No changes to docs/cpp/, skipping coverage comment." + exit 0 + fi + + # Skip if we already posted + if gh pr view "$PR_NUM" --json comments --jq '.comments[].body' | grep -q "$MARKER"; then + echo "Coverage comment already posted, skipping." + exit 0 + fi + + body="${MARKER} + ## C++ Docs + + - [Doc Preview](https://docs-preview.pytorch.org/pytorch/pytorch/${PR_NUM}/cppdocs/index.html) + - [API Coverage Report](https://docs-preview.pytorch.org/pytorch/pytorch/${PR_NUM}/cppdocs/_coverage/cpp_coverage.txt) + - [HTML Issues Report](https://docs-preview.pytorch.org/pytorch/pytorch/${PR_NUM}/cppdocs/_coverage/cpp_html_issues.txt)" + + if [ -f cppdocs/_coverage/index.html ]; then + body="${body} + - [Coverxygen Report (interactive)](https://docs-preview.pytorch.org/pytorch/pytorch/${PR_NUM}/cppdocs/_coverage/index.html)" + fi + + gh pr comment "$PR_NUM" --body "$body" + + - name: Upload C++ Docs Preview (nightly dry-run) + if: ${{ steps.aws-creds.outcome == 'success' && !inputs.push && github.event_name != 'pull_request' && matrix.docs_type == 'cpp' && steps.build-docs.outcome == 'success' }} + run: | + # Unlike EC2 runners, the OSDC container doesn't have aws CLI pre-installed + pip install awscli==1.29.40 + aws s3 cp cppdocs/ s3://doc-previews/pytorch/pytorch/nightly-${{ github.sha }}/cppdocs --recursive --quiet + echo "C++ docs preview available at:" + echo "https://docs-preview.pytorch.org/pytorch/pytorch/nightly-${{ github.sha }}/cppdocs/index.html" + + memory-viz-tests: + # Tests for torch memory visualizer + # tag shangdiy for any issue with this test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-node@v4 + with: + node-version: '20' + - name: Run MemoryViz JS tests (nonretryable) + run: node test/profiler/test_memory_viz.js diff --git a/.github/workflows/_get-changed-files.yml b/.github/workflows/_get-changed-files.yml index f7138a1ddaacc..f3367916f8783 100644 --- a/.github/workflows/_get-changed-files.yml +++ b/.github/workflows/_get-changed-files.yml @@ -25,7 +25,7 @@ jobs: env: GH_TOKEN: ${{ github.token }} run: | - set -e + set -eo pipefail # Check if we're in a pull request context if [ "${{ github.event_name }}" = "pull_request" ] || [ "${{ github.event_name }}" = "pull_request_target" ]; then echo "Running in PR context" @@ -38,8 +38,15 @@ jobs: echo "all_files input is true, returning all files" echo "changed-files=*" >> "$GITHUB_OUTPUT" else - # Use gh CLI to get changed files in the PR with explicit repo - CHANGED_FILES=$(gh api repos/${{ github.repository }}/pulls/$PR_NUMBER/files --paginate --jq '.[] | select(.status != "removed") | .filename' | tr '\n' ' ' | sed 's/ $//') + # Use gh CLI to get changed files in the PR with explicit repo. + # pipefail ensures that if gh fails (e.g. 403 rate limit), the + # error propagates through the pipe instead of being silently + # captured as the file list and injected into downstream lint jobs. + if ! CHANGED_FILES=$(gh api repos/${{ github.repository }}/pulls/$PR_NUMBER/files --paginate --jq '.[] | select(.status != "removed") | .filename' | tr '\n' ' ' | sed 's/ $//'); then + echo "Failed to get changed files from GitHub API, falling back to all files" + echo "changed-files=*" >> "$GITHUB_OUTPUT" + exit 0 + fi # See https://github.com/pytorch/pytorch/pull/134215#issuecomment-2332128790 PYI_FILES_TO_ADD="" diff --git a/.github/workflows/_link_check.yml b/.github/workflows/_link_check.yml index 014e6106b0730..efa5b433947ef 100644 --- a/.github/workflows/_link_check.yml +++ b/.github/workflows/_link_check.yml @@ -11,15 +11,10 @@ on: jobs: lint-urls: if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-url-lint') }} - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + uses: ./.github/workflows/_lint.yml with: - job-name: lint-urls - timeout: 120 - runner: ${{ inputs.runner }}linux.2xlarge - docker-image: ci-image:pytorch-linux-jammy-linter - fetch-depth: 0 - submodules: false - ref: ${{ inputs.ref }} + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-810d48d script: | ./scripts/lint_urls.sh $( if [ "${{ github.event_name }}" = "pull_request" ]; then @@ -37,15 +32,10 @@ jobs: lint-xrefs: if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-xref-lint') }} - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + uses: ./.github/workflows/_lint.yml with: - job-name: lint-xrefs - timeout: 60 - runner: ${{ inputs.runner }}linux.2xlarge - docker-image: ci-image:pytorch-linux-jammy-linter - fetch-depth: 0 - submodules: false - ref: ${{ inputs.ref }} + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-810d48d script: | ./scripts/lint_xrefs.sh $( if [ "${{ github.event_name }}" = "pull_request" ]; then diff --git a/.github/workflows/_lint.yml b/.github/workflows/_lint.yml new file mode 100644 index 0000000000000..8e707c748f582 --- /dev/null +++ b/.github/workflows/_lint.yml @@ -0,0 +1,82 @@ +name: Run linters + +on: + workflow_call: + inputs: + runner: + required: true + type: string + description: The runner to use + docker-image: + required: true + type: string + description: The Docker image to use + script: + required: true + type: string + description: The linter script to run + +jobs: + lint: + runs-on: ${{ inputs.runner }} + container: + image: ${{ inputs.docker-image }} + timeout-minutes: 120 + steps: + - name: Fix Git ownership + shell: bash + run: | + git config --global --add safe.directory "$GITHUB_WORKSPACE" + + - name: Checkout PyTorch + uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + with: + fetch-depth: 0 + submodules: true + no-sudo: true + checkout-mode: treeless + + - name: Setup uv + uses: pytorch/test-infra/.github/actions/setup-uv@main + with: + python-version: "3.12" + activate-environment: true + + - name: Install pip requirements + shell: bash + run: | + set -eux + uv pip install -r .ci/docker/requirements-ci.txt + + - name: Install system requirements + shell: bash + run: | + set -eux + # Update repository + dnf install -y doxygen graphviz nodejs npm + + - name: Install Node.js packages + shell: bash + run: | + set -eux + npm install -g markdown-toc + + - name: Prepare lintrunner + shell: bash + run: | + set -eux + lintrunner init + + - name: Run linter + shell: bash + env: + SCRIPT: ${{ inputs.script }} + run: | + { + echo "#!/usr/bin/env bash"; + echo "set -eou pipefail"; + echo "${SCRIPT}"; + } > "${RUNNER_TEMP}/linter_script" + + # Execute the linter script + bash "${RUNNER_TEMP}/linter_script" diff --git a/.github/workflows/_linux-build.yml b/.github/workflows/_linux-build.yml index c94dbf397be96..14ef0a775204d 100644 --- a/.github/workflows/_linux-build.yml +++ b/.github/workflows/_linux-build.yml @@ -103,34 +103,49 @@ on: required: false type: string default: "" + use-arc: + required: false + type: boolean + default: false + description: If true, use ARC (OSDC) runner path instead of EC2. + python-version: + required: false + type: string + default: "" + description: Python version to use for the OSDC build. + compiler: + required: false + type: string + default: "" + description: Compiler to use for the OSDC build. + cuda-version: + required: false + type: string + default: "" + description: CUDA version to use for the OSDC build. secrets: HUGGING_FACE_HUB_TOKEN: required: false description: | HF Auth token to avoid rate limits when downloading models or datasets from hub - SCRIBE_GRAPHQL_ACCESS_TOKEN: - required: false - description: | - FB app token to write to scribe endpoint outputs: docker-image: - value: ${{ jobs.build.outputs.docker-image }} + value: ${{ jobs.build.outputs.docker-image || jobs.build-osdc.outputs.docker-image }} description: The docker image containing the built PyTorch. test-matrix: - value: ${{ jobs.build.outputs.test-matrix }} + value: ${{ jobs.build.outputs.test-matrix || jobs.build-osdc.outputs.test-matrix }} description: An optional JSON description of what test configs to run later on. build-environment: - value: ${{ jobs.build.outputs.build-environment }} + value: ${{ jobs.build.outputs.build-environment || jobs.build-osdc.outputs.build-environment }} description: Top-level label for what's being built/tested. jobs: build: - environment: ${{ github.ref == 'refs/heads/main' && 'scribe-protected' || startsWith(github.ref, 'refs/heads/release/') && 'scribe-protected' || contains(github.event.pull_request.labels.*.name, 'ci-scribe') && 'scribe-pr' || '' }} # Don't run on forked repos - if: github.repository_owner == 'pytorch' - runs-on: ${{ inputs.runner_prefix}}${{ inputs.runner }} + if: github.repository_owner == 'pytorch' && !inputs.use-arc + runs-on: ${{ inputs.runner_prefix }}${{ inputs.runner }} timeout-minutes: 480 outputs: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} @@ -146,19 +161,14 @@ jobs: Build is done inside the container, to start an interactive session run: docker exec -it $(docker container ps --format '{{.ID}}') bash - # [pytorch repo ref] - # Use a pytorch/pytorch reference instead of a reference to the local - # checkout because when we run this action we don't *have* a local - # checkout. In other cases you should prefer a local checkout. - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - with: - no-sudo: true - checkout-mode: treeless - - name: Setup Linux - uses: ./.github/actions/setup-linux - if: inputs.build-environment != 'linux-s390x-binary-manywheel' + id: setup-linux + uses: pytorch/pytorch/.github/actions/setup-linux@main + with: + python-version: ${{ inputs.python-version }} + compiler: ${{ inputs.compiler }} + cuda-version: ${{ inputs.cuda-version }} + github-token: ${{ secrets.GITHUB_TOKEN }} - name: Login to ECR if: inputs.build-environment != 'linux-s390x-binary-manywheel' @@ -166,13 +176,6 @@ jobs: with: aws-role-to-assume: ${{ inputs.aws-role-to-assume }} - - name: Get workflow job id - id: get-job-id - uses: ./.github/actions/get-workflow-job-id - if: always() - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - - name: Check if can use old whl build id: use-old-whl uses: ./.github/actions/reuse-old-whl @@ -181,8 +184,8 @@ jobs: build-environment: ${{ inputs.build-environment }} run-id: ${{ github.run_id }} github-token: ${{ secrets.GITHUB_TOKEN }} - job-id: ${{ steps.get-job-id.outputs.job-id }} - job-name: ${{ steps.get-job-id.outputs.job-name }} + job-id: ${{ steps.setup-linux.outputs.job-id }} + job-name: ${{ steps.setup-linux.outputs.job-name }} - name: Calculate docker image id: calculate-docker-image @@ -207,10 +210,6 @@ jobs: with: docker-image: ${{ steps.calculate-docker-image.outputs.docker-image }} - - name: Parse ref - id: parse-ref - run: .github/scripts/parse_ref.py - # Apply the filter logic to the build step too if the test-config label is already there - name: Select all requested test configurations (if the test matrix is available) id: filter @@ -219,7 +218,7 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} test-matrix: ${{ inputs.test-matrix }} selected-test-configs: ${{ inputs.selected-test-configs }} - job-name: ${{ steps.get-job-id.outputs.job-name }} + job-name: ${{ steps.setup-linux.outputs.job-name }} - name: Start monitoring script id: monitor-script @@ -227,8 +226,8 @@ jobs: shell: bash continue-on-error: true env: - JOB_ID: ${{ steps.get-job-id.outputs.job-id }} - JOB_NAME: ${{ steps.get-job-id.outputs.job-name }} + JOB_ID: ${{ steps.setup-linux.outputs.job-id }} + JOB_NAME: ${{ steps.setup-linux.outputs.job-name }} WORKFLOW_NAME: ${{ github.workflow }} WORKFLOW_RUN_ID: ${{github.run_id}} MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }} @@ -242,21 +241,12 @@ jobs: > "../../usage_logs/usage_log_build_${JOB_ID}.txt" 2>&1 & echo "monitor-script-pid=${!}" >> "${GITHUB_OUTPUT}" - - name: Download pytest cache - uses: ./.github/actions/pytest-cache-download - continue-on-error: true - if: inputs.build-environment != 'linux-s390x-binary-manywheel' && steps.use-old-whl.outputs.reuse != 'true' - with: - cache_dir: .pytest_cache - job_identifier: ${{ github.workflow }}_${{ inputs.build-environment }} - s3_bucket: ${{ inputs.s3-bucket }} - - name: Build if: (steps.filter.outputs.is-test-matrix-empty == 'False' || inputs.test-matrix == '') && steps.use-old-whl.outputs.reuse != 'true' id: build env: BUILD_ENVIRONMENT: ${{ inputs.build-environment }} - BRANCH: ${{ steps.parse-ref.outputs.branch }} + BRANCH: ${{ steps.setup-linux.outputs.branch }} PR_NUMBER: ${{ github.event.pull_request.number }} SHA1: ${{ github.event.pull_request.head.sha || github.sha }} # Do not set SCCACHE_S3_KEY_PREFIX to share the cache between all build jobs @@ -268,9 +258,8 @@ jobs: DOCKER_IMAGE: ${{ steps.calculate-docker-image.outputs.docker-image }} DOCKER_IMAGE_S390X: ${{ inputs.docker-image-name }} XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} - OUR_GITHUB_JOB_ID: ${{ steps.get-job-id.outputs.job-id }} + OUR_GITHUB_JOB_ID: ${{ steps.setup-linux.outputs.job-id }} HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} BUILD_ADDITIONAL_PACKAGES: ${{ inputs.build-additional-packages }} RUNNER: ${{ inputs.runner }} run: | @@ -342,7 +331,6 @@ jobs: -e PR_LABELS \ -e OUR_GITHUB_JOB_ID \ -e HUGGING_FACE_HUB_TOKEN \ - -e SCRIBE_GRAPHQL_ACCESS_TOKEN \ -e BUILD_ADDITIONAL_PACKAGES \ -e RUNNER \ --memory="${TOTAL_AVAILABLE_MEMORY_IN_GB%.*}g" \ @@ -426,7 +414,7 @@ jobs: if-no-files-found: error path: artifacts.zip - - name: copy logs + - name: Copy logs shell: bash if: ${{ always() && steps.build.outcome != 'skipped' && !inputs.disable-monitor && inputs.build-environment != 'linux-s390x-binary-manywheel'}} continue-on-error: true @@ -437,7 +425,7 @@ jobs: - name: Upload raw usage log to s3 if: ${{ always() && steps.build.outcome != 'skipped' && !inputs.disable-monitor && inputs.build-environment != 'linux-s390x-binary-manywheel'}} - uses: seemethere/upload-artifact-s3@v5 + uses: seemethere/upload-artifact-s3@baba72d0712b404f646cebe0730933554ebce96a # v5 with: s3-prefix: | ${{ github.repository }}/${{ github.run_id }}/${{ github.run_attempt }}/artifact @@ -457,12 +445,12 @@ jobs: continue-on-error: true uses: ./.github/actions/upload-utilization-stats with: - job_id: ${{ steps.get-job-id.outputs.job-id }} - job_name: ${{ steps.get-job-id.outputs.job-name }} + job_id: ${{ steps.setup-linux.outputs.job-id }} + job_name: ${{ steps.setup-linux.outputs.job-name }} workflow_name: ${{ github.workflow }} workflow_run_id: ${{github.run_id}} workflow_attempt: ${{github.run_attempt}} - artifact_prefix: usage_log_build_${{ steps.get-job-id.outputs.job-id }} + artifact_prefix: usage_log_build_${{ steps.setup-linux.outputs.job-id }} - name: Teardown Linux uses: pytorch/test-infra/.github/actions/teardown-linux@main @@ -475,3 +463,136 @@ jobs: # on s390x stop the container for clean worker stop docker stop -a || true docker kill -a || true + + build-osdc: + permissions: + id-token: write + contents: read + actions: read + # Don't run on forked repos + if: github.repository_owner == 'pytorch' && inputs.use-arc + runs-on: ${{ inputs.runner_prefix }}${{ startsWith(inputs.runner, 'l-') && inputs.runner || contains(inputs.runner, 'arm64') && 'l-arm64g4-16-62' || 'l-x86iavx512-8-64' }} + container: + image: ghcr.io/pytorch/${{ inputs.docker-image-name }} + timeout-minutes: 480 + outputs: + docker-image: ghcr.io/pytorch/${{ inputs.docker-image-name }} + test-matrix: ${{ steps.map-runners.outputs.test-matrix }} + build-environment: ${{ inputs.build-environment }} + steps: + - name: Setup Linux + id: setup-linux + uses: pytorch/pytorch/.github/actions/setup-linux@main + with: + use-arc: true + python-version: ${{ inputs.python-version }} + compiler: ${{ inputs.compiler }} + cuda-version: ${{ inputs.cuda-version }} + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Check if can use old whl build + id: use-old-whl + uses: pytorch/pytorch/.github/actions/reuse-old-whl@main + if: ${{ inputs.allow-reuse-old-whl }} + with: + build-environment: ${{ inputs.build-environment }} + run-id: ${{ github.run_id }} + github-token: ${{ secrets.GITHUB_TOKEN }} + job-id: ${{ steps.setup-linux.outputs.job-id }} + job-name: ${{ steps.setup-linux.outputs.job-name }} + + # Apply the filter logic to the build step too if the test-config label is already there + - name: Select all requested test configurations (if the test matrix is available) + id: filter + uses: pytorch/pytorch/.github/actions/filter-test-configs@main + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + test-matrix: ${{ inputs.test-matrix }} + selected-test-configs: ${{ inputs.selected-test-configs }} + job-name: ${{ steps.setup-linux.outputs.job-name }} + + - name: Map EC2 runners to ARC runners + id: map-runners + env: + FILTERED_TEST_MATRIX: ${{ steps.filter.outputs.test-matrix }} + RUNNER_PREFIX: ${{ inputs.runner_prefix }} + shell: bash + run: | + python3 .github/scripts/map_ec2_to_arc.py --prefix "${RUNNER_PREFIX}" "${FILTERED_TEST_MATRIX}" + + - name: Configure AWS credentials + id: aws-creds + continue-on-error: true + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 + with: + role-to-assume: arn:aws:iam::308535385114:role/arc + aws-region: us-east-1 + # The max duration enforced by the server side + role-duration-seconds: 18000 + + - name: Build + if: (steps.filter.outputs.is-test-matrix-empty == 'False' || inputs.test-matrix == '') && steps.use-old-whl.outputs.reuse != 'true' + id: build + env: + BUILD_ENVIRONMENT: ${{ inputs.build-environment }} + BRANCH: ${{ steps.setup-linux.outputs.branch }} + PR_NUMBER: ${{ github.event.pull_request.number }} + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + SCCACHE_REGION: us-east-1 + SCCACHE_S3_NO_CREDENTIALS: ${{ steps.aws-creds.outcome != 'success' && 'true' || 'false' }} + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla + PR_LABELS: ${{ toJson(github.event.pull_request.labels.*.name) }} + TORCH_CUDA_ARCH_LIST: ${{ inputs.cuda-arch-list }} + XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} + OUR_GITHUB_JOB_ID: ${{ steps.setup-linux.outputs.job-id }} + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + BUILD_ADDITIONAL_PACKAGES: ${{ inputs.build-additional-packages }} + RUNNER: ${{ inputs.runner }} + SKIP_SCCACHE_INITIALIZATION: 1 + shell: bash + run: | + START_TIME=$(date +%s) + .ci/pytorch/build.sh + END_TIME=$(date +%s) + echo "build_time=$((END_TIME - START_TIME))" >> "$GITHUB_OUTPUT" + + - name: Build external packages + id: build-external-packages + if: inputs.build-external-packages != '' && steps.build.outcome != 'skipped' + uses: pytorch/pytorch/.github/actions/build-external-packages@main + with: + build-targets: ${{ inputs.build-external-packages }} + docker-image: ghcr.io/pytorch/${{ inputs.docker-image-name }} + cuda-arch-list: ${{ inputs.cuda-arch-list }} + output-dir: external + + - name: Move external packages to dist + if: steps.build-external-packages.outputs.output_dir != '' && steps.build-external-packages.outcome != 'skipped' + shell: bash + run: | + src="${{ steps.build-external-packages.outputs.output_dir }}" + if [ -d "$src" ]; then + mkdir -p "dist/$(dirname "$src")" + mv "$src" "dist/$(dirname "$src")/" + fi + + - name: Archive artifacts into zip + if: inputs.build-generates-artifacts && steps.build.outcome != 'skipped' && steps.use-old-whl.outputs.reuse != 'true' + run: | + zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .additional_ci_files + + - name: Store build artifacts + if: inputs.build-generates-artifacts && (steps.build.outcome != 'skipped' || steps.use-old-whl.outputs.reuse == 'true') + uses: pytorch/pytorch/.github/actions/upload-build-artifacts@main + with: + name: ${{ inputs.build-environment }} + s3-bucket: ${{ inputs.s3-bucket }} + use-gha: ${{ steps.aws-creds.outcome != 'success' }} + + - name: Upload sccache stats + if: steps.build.outcome != 'skipped' && steps.aws-creds.outcome == 'success' + uses: pytorch/pytorch/.github/actions/upload-sccache-stats@main + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + build-time: ${{ steps.build.outputs.build_time }} diff --git a/.github/workflows/_linux-test-stable-fa3.yml b/.github/workflows/_linux-test-stable-fa3.yml index f2e16712ff447..98d64b9246869 100644 --- a/.github/workflows/_linux-test-stable-fa3.yml +++ b/.github/workflows/_linux-test-stable-fa3.yml @@ -54,15 +54,13 @@ jobs: # Don't run on forked repos if: github.repository_owner == 'pytorch' runs-on: linux.aws.h100 - timeout-minutes: ${{ inputs.timeout-minutes || 30 }} + timeout-minutes: ${{ inputs.timeout-minutes || 60 }} permissions: id-token: write contents: read steps: - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - with: - no-sudo: true + - name: Setup Linux + uses: pytorch/pytorch/.github/actions/setup-linux@main - name: Checkout flash-attention as a secondary repository uses: actions/checkout@v4 @@ -70,9 +68,6 @@ jobs: repository: Dao-AILab/flash-attention path: flash-attention - - name: Setup Linux - uses: ./.github/actions/setup-linux - - name: Login to ECR uses: ./.github/actions/ecr-login diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index 3ad1706fb3af2..2a9c0626175c8 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -72,6 +72,38 @@ on: required: false type: number default: 1 + use-arc: + required: false + type: boolean + default: false + description: If true, use ARC (OSDC) runner path instead of EC2. + python-version: + required: false + type: string + default: "" + description: Python version to use for the OSDC test. + compiler: + required: false + type: string + default: "" + description: Compiler to use for the OSDC test. + cuda-version: + required: false + type: string + default: "" + description: CUDA version to use for the OSDC test. + export-profiler-trace: + description: | + If set to "1", export Chrome profiler traces from performance benchmarks. + required: false + type: string + default: "" + enable-torch-trace: + description: | + If set to "1", enable TORCH_TRACE structured logging and collect tlparse output. + required: false + type: string + default: "" secrets: HUGGING_FACE_HUB_TOKEN: required: false @@ -92,7 +124,7 @@ env: jobs: test: # Don't run on forked repos or empty test matrix - if: github.repository_owner == 'pytorch' && toJSON(fromJSON(inputs.test-matrix).include) != '[]' + if: github.repository_owner == 'pytorch' && toJSON(fromJSON(inputs.test-matrix).include) != '[]' && !inputs.use-arc strategy: matrix: ${{ fromJSON(inputs.test-matrix) }} fail-fast: false @@ -112,23 +144,15 @@ jobs: All testing is done inside the container, to start an interactive session run: docker exec -it $(docker container ps --format '{{.ID}}') bash - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - with: - no-sudo: true - checkout-mode: treeless - submodules: false - - - name: Setup Python - if: contains(matrix.runner, 'b200') - uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 - with: - python-version: '3.12' - cache: pip - - name: Setup Linux - uses: ./.github/actions/setup-linux - if: inputs.build-environment != 'linux-s390x-binary-manywheel' && !contains(matrix.runner, 'b200') + id: setup-linux + uses: pytorch/pytorch/.github/actions/setup-linux@main + with: + python-version: ${{ inputs.python-version }} + compiler: ${{ inputs.compiler }} + cuda-version: ${{ inputs.cuda-version }} + submodules: 'false' + github-token: ${{ secrets.GITHUB_TOKEN }} - name: Check TPU Availability id: check-tpu @@ -192,21 +216,14 @@ jobs: run: echo "SCCACHE_SERVER_PORT_DOCKER_FLAG=-e SCCACHE_SERVER_PORT=$((RUNNER_UID + 4226))" >> "${GITHUB_ENV}" if: ${{ steps.check_container_runner.outputs.IN_CONTAINER_RUNNER == 'true' && !contains(matrix.runner, 'b200') }} - - name: Get workflow job id - id: get-job-id - uses: ./.github/actions/get-workflow-job-id - if: always() - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - - name: Start monitoring script id: monitor-script if: ${{ !inputs.disable-monitor }} shell: bash continue-on-error: true env: - JOB_ID: ${{ steps.get-job-id.outputs.job-id }} - JOB_NAME: ${{ steps.get-job-id.outputs.job-name }} + JOB_ID: ${{ steps.setup-linux.outputs.job-id }} + JOB_NAME: ${{ steps.setup-linux.outputs.job-name }} WORKFLOW_NAME: ${{ github.workflow }} WORKFLOW_RUN_ID: ${{github.run_id}} MONITOR_LOG_INTERVAL: ${{ inputs.monitor-log-interval }} @@ -262,13 +279,20 @@ jobs: echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/" fi - # Verify CUDA libraries are present + # Setup CUDA runtime DLL (needed for MinGW import lib generation on CUDA 13.0+) + mkdir -p win-torch-wheel-extracted/bin/x64 + for dll in win-torch-wheel/cudart64_*.dll; do + if [ -f "$dll" ]; then + mv "$dll" win-torch-wheel-extracted/bin/x64/ + echo "Moved $(basename $dll) to win-torch-wheel-extracted/bin/x64/" + fi + done + + # Verify CUDA libraries and DLLs are present echo "CUDA libraries:" ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found" - - - name: Parse ref - id: parse-ref - run: .github/scripts/parse_ref.py + echo "CUDA DLLs:" + ls -la win-torch-wheel-extracted/bin/x64/ || echo "No CUDA DLLs found" - name: Check for keep-going label and re-enabled test issues # This uses the filter-test-configs action because it conveniently @@ -279,7 +303,7 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} test-matrix: ${{ inputs.test-matrix }} - job-name: ${{ steps.get-job-id.outputs.job-name }} + job-name: ${{ steps.setup-linux.outputs.job-name }} - name: Set Test step time id: test-timeout @@ -318,9 +342,9 @@ jobs: GITHUB_RUN_ID: ${{ github.run_id }} GITHUB_RUN_NUMBER: ${{ github.run_number }} GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }} - JOB_ID: ${{ steps.get-job-id.outputs.job-id }} - JOB_NAME: ${{ steps.get-job-id.outputs.job-name }} - BRANCH: ${{ steps.parse-ref.outputs.branch }} + JOB_ID: ${{ steps.setup-linux.outputs.job-id }} + JOB_NAME: ${{ steps.setup-linux.outputs.job-name }} + BRANCH: ${{ steps.setup-linux.outputs.branch }} SHA1: ${{ github.event.pull_request.head.sha || github.sha }} BASE_SHA: ${{ github.event.pull_request.base.sha || github.sha }} TEST_CONFIG: ${{ matrix.config }} @@ -347,17 +371,20 @@ jobs: PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} TESTS_TO_INCLUDE: ${{ inputs.tests-to-include }} DASHBOARD_TAG: ${{ inputs.dashboard-tag }} + EXPORT_PROFILER_TRACE: ${{ inputs.export-profiler-trace }} + ENABLE_TORCH_TRACE: ${{ inputs.enable-torch-trace }} VLLM_TEST_HUGGING_FACE_TOKEN: ${{ secrets.VLLM_TEST_HUGGING_FACE_TOKEN }} HF_CACHE: /mnt/hf_cache # Use offline mode by default, only enable online mode to refresh the models - # from HF when the PR has a special ci-refresh-hf-cache label. This label is + # from HF when the PR has a special ci-refresh-hf-cache label or when the job + # is a nightly scheduled run on main. The ci-refresh-hf-cache label is # automatically added to the vLLM pinned commit hash update, which effectively # refreshes the cache daily - TRANSFORMERS_OFFLINE: ${{ contains(steps.keep-going.outputs.labels, 'ci-refresh-hf-cache') && '0' || '1' }} - HF_DATASETS_OFFLINE: ${{ contains(steps.keep-going.outputs.labels, 'ci-refresh-hf-cache') && '0' || '1' }} + TRANSFORMERS_OFFLINE: ${{ (github.event_name == 'schedule' || contains(steps.keep-going.outputs.labels, 'ci-refresh-hf-cache')) && '0' || '1' }} + HF_DATASETS_OFFLINE: ${{ (github.event_name == 'schedule' || contains(steps.keep-going.outputs.labels, 'ci-refresh-hf-cache')) && '0' || '1' }} HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} - ARTIFACTS_FILE_SUFFIX: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.get-job-id.outputs.job-id }} + ARTIFACTS_FILE_SUFFIX: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.setup-linux.outputs.job-id }} TORCH_TPU: ${{ steps.check-tpu.outputs.has_tpu == 'true' && '1' || '' }} TORCH_TPU_TEXT_FILE: /var/lib/jenkins/workspace/.github/ci_commit_pins/torch_tpu.txt run: | @@ -365,8 +392,6 @@ jobs: if [[ $TEST_CONFIG == 'multigpu' ]]; then TEST_COMMAND=.ci/pytorch/multigpu-test.sh - elif [[ $BUILD_ENVIRONMENT == *onnx* ]]; then - TEST_COMMAND=.ci/onnx/test.sh else TEST_COMMAND=.ci/pytorch/test.sh fi @@ -457,6 +482,8 @@ jobs: -e HF_DATASETS_OFFLINE \ -e SCRIBE_GRAPHQL_ACCESS_TOKEN \ -e DASHBOARD_TAG \ + -e EXPORT_PROFILER_TRACE \ + -e ENABLE_TORCH_TRACE \ -e ARTIFACTS_FILE_SUFFIX \ -e TORCH_TPU \ -e TORCH_TPU_TEXT_FILE \ @@ -539,7 +566,7 @@ jobs: uses: ./.github/actions/upload-test-artifacts if: always() && steps.test.conclusion && steps.test.conclusion != 'skipped' with: - file-suffix: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.get-job-id.outputs.job-id }} + file-suffix: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.setup-linux.outputs.job-id }} use-gha: ${{ inputs.use-gha }} s3-bucket: ${{ inputs.s3-bucket }} @@ -563,8 +590,8 @@ jobs: continue-on-error: true uses: ./.github/actions/upload-utilization-stats with: - job_id: ${{ steps.get-job-id.outputs.job-id }} - job_name: ${{ steps.get-job-id.outputs.job-name }} + job_id: ${{ steps.setup-linux.outputs.job-id }} + job_name: ${{ steps.setup-linux.outputs.job-name }} workflow_name: ${{ github.workflow }} workflow_run_id: ${{github.run_id}} workflow_attempt: ${{github.run_attempt}} @@ -580,3 +607,237 @@ jobs: # on s390x stop the container for clean worker stop docker stop -a || true docker kill -a || true + + test-osdc: + # Don't run on forked repos or empty test matrix + if: github.repository_owner == 'pytorch' && toJSON(fromJSON(inputs.test-matrix).include) != '[]' && inputs.use-arc + strategy: + matrix: ${{ fromJSON(inputs.test-matrix) }} + fail-fast: false + environment: ${{ github.ref == 'refs/heads/main' && 'scribe-protected' || startsWith(github.ref, 'refs/heads/release/') && 'scribe-protected' || contains(github.event.pull_request.labels.*.name, 'ci-scribe') && 'scribe-pr' || '' }} + runs-on: ${{ matrix.runner }} + container: + image: ${{ inputs.docker-image }} + options: "--gpus all" + timeout-minutes: ${{ matrix.mem_leak_check == 'mem_leak_check' && 600 || inputs.timeout-minutes }} + permissions: + id-token: write + contents: read + actions: read + steps: + - name: Setup Linux + id: setup-linux + uses: pytorch/pytorch/.github/actions/setup-linux@main + with: + use-arc: true + python-version: ${{ inputs.python-version }} + compiler: ${{ inputs.compiler }} + cuda-version: ${{ inputs.cuda-version }} + submodules: 'false' + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Configure AWS credentials + id: aws-creds + continue-on-error: true + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 + with: + role-to-assume: arn:aws:iam::308535385114:role/arc + aws-region: us-east-1 + # The max duration enforced by the server side + role-duration-seconds: 18000 + + - name: Download build artifacts + uses: pytorch/pytorch/.github/actions/download-build-artifacts@main + with: + name: ${{ inputs.build-environment }} + s3-bucket: ${{ inputs.s3-bucket }} + use-gha: ${{ steps.aws-creds.outcome != 'success' || inputs.use-gha }} + + - name: Download TD artifacts + continue-on-error: true + uses: pytorch/pytorch/.github/actions/download-td-artifacts@main + + - name: Download Windows torch wheel for cross-compilation + if: matrix.win_torch_wheel_artifact != '' + uses: seemethere/download-artifact-s3@1da556a7aa0a088e3153970611f6c432d58e80e6 # v4.2.0 + with: + name: ${{ matrix.win_torch_wheel_artifact }} + path: win-torch-wheel + + - name: Extract Windows wheel and setup CUDA libraries + if: matrix.win_torch_wheel_artifact != '' + shell: bash + run: | + set -x + + # Find the wheel file + WHEEL_FILE=$(find win-torch-wheel -name "*.whl" -type f | head -n 1) + if [ -z "$WHEEL_FILE" ]; then + echo "Error: No wheel file found in win-torch-wheel directory" + exit 1 + fi + echo "Found wheel file: $WHEEL_FILE" + + # Unzip the wheel file + unzip -q "$WHEEL_FILE" -d win-torch-wheel-extracted + echo "Extracted wheel contents" + + # Setup CUDA libraries (cuda.lib and cudart.lib) directory + mkdir -p win-torch-wheel-extracted/lib/x64 + if [ -f "win-torch-wheel/cuda.lib" ]; then + mv win-torch-wheel/cuda.lib win-torch-wheel-extracted/lib/x64/ + echo "Moved cuda.lib to win-torch-wheel-extracted/lib/x64/" + fi + if [ -f "win-torch-wheel/cudart.lib" ]; then + mv win-torch-wheel/cudart.lib win-torch-wheel-extracted/lib/x64/ + echo "Moved cudart.lib to win-torch-wheel-extracted/lib/x64/" + fi + + # Setup CUDA runtime DLL (needed for MinGW import lib generation on CUDA 13.0+) + mkdir -p win-torch-wheel-extracted/bin/x64 + for dll in win-torch-wheel/cudart64_*.dll; do + if [ -f "$dll" ]; then + mv "$dll" win-torch-wheel-extracted/bin/x64/ + echo "Moved $(basename $dll) to win-torch-wheel-extracted/bin/x64/" + fi + done + + # Verify CUDA libraries and DLLs are present + echo "CUDA libraries:" + ls -la win-torch-wheel-extracted/lib/x64/ || echo "No CUDA libraries found" + echo "CUDA DLLs:" + ls -la win-torch-wheel-extracted/bin/x64/ || echo "No CUDA DLLs found" + + - name: Check for keep-going label and re-enabled test issues + # This uses the filter-test-configs action because it conveniently + # checks for labels and re-enabled test issues. It does not actually do + # any filtering. All filtering is done in the build step. + id: keep-going + uses: pytorch/pytorch/.github/actions/filter-test-configs@main + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + test-matrix: ${{ inputs.test-matrix }} + job-name: ${{ steps.setup-linux.outputs.job-name }} + + - name: Set test step time + id: test-timeout + shell: bash + env: + JOB_TIMEOUT: ${{ matrix.mem_leak_check == 'mem_leak_check' && 600 || inputs.timeout-minutes }} + run: | + echo "timeout=$((JOB_TIMEOUT-30))" >> "${GITHUB_OUTPUT}" + + - name: Test + id: test + timeout-minutes: ${{ fromJson(steps.test-timeout.outputs.timeout) }} + env: + BUILD_ENVIRONMENT: ${{ inputs.build-environment }} + PR_NUMBER: ${{ github.event.pull_request.number }} + GITHUB_REPOSITORY: ${{ github.repository }} + GITHUB_WORKFLOW: ${{ github.workflow }} + GITHUB_JOB: ${{ github.job }} + GITHUB_RUN_ID: ${{ github.run_id }} + GITHUB_RUN_NUMBER: ${{ github.run_number }} + GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }} + JOB_ID: ${{ steps.setup-linux.outputs.job-id }} + JOB_NAME: ${{ steps.setup-linux.outputs.job-name }} + BRANCH: ${{ steps.setup-linux.outputs.branch }} + SHA1: ${{ github.event.pull_request.head.sha || github.sha }} + BASE_SHA: ${{ github.event.pull_request.base.sha || github.sha }} + TEST_CONFIG: ${{ matrix.config }} + SHARD_NUMBER: ${{ matrix.shard }} + NUM_TEST_SHARDS: ${{ matrix.num_shards }} + EXTRA_FLAGS: ${{ matrix.extra_flags || '' }} + OP_BENCHMARK_TESTS: ${{ matrix.op_benchmark_tests }} + REENABLED_ISSUES: ${{ steps.keep-going.outputs.reenabled-issues }} + CONTINUE_THROUGH_ERROR: ${{ steps.keep-going.outputs.keep-going }} + VERBOSE_TEST_LOGS: ${{ steps.keep-going.outputs.ci-verbose-test-logs }} + TEST_SHOWLOCALS: ${{ steps.keep-going.outputs.ci-test-showlocals }} + NO_TEST_TIMEOUT: ${{ steps.keep-going.outputs.ci-no-test-timeout }} + NO_TD: ${{ steps.keep-going.outputs.ci-no-td }} + TD_DISTRIBUTED: ${{ steps.keep-going.outputs.ci-td-distributed }} + SCCACHE_BUCKET: ossci-compiler-cache-circleci-v2 + SCCACHE_REGION: us-east-1 + SCCACHE_S3_NO_CREDENTIALS: ${{ steps.aws-creds.outcome != 'success' && 'true' || 'false' }} + XLA_CUDA: ${{ contains(inputs.build-environment, 'xla') && '0' || '' }} + XLA_CLANG_CACHE_S3_BUCKET_NAME: ossci-compiler-clang-cache-circleci-xla + PYTORCH_TEST_CUDA_MEM_LEAK_CHECK: ${{ matrix.mem_leak_check && '1' || '0' }} + PYTORCH_TEST_RERUN_DISABLED_TESTS: ${{ matrix.rerun_disabled_tests && '1' || '0' }} + TESTS_TO_INCLUDE: ${{ inputs.tests-to-include }} + DASHBOARD_TAG: ${{ inputs.dashboard-tag }} + VLLM_TEST_HUGGING_FACE_TOKEN: ${{ secrets.VLLM_TEST_HUGGING_FACE_TOKEN }} + HF_CACHE: /mnt/hf_cache + TRANSFORMERS_OFFLINE: ${{ (github.event_name == 'schedule' || contains(steps.keep-going.outputs.labels, 'ci-refresh-hf-cache')) && '0' || '1' }} + HF_DATASETS_OFFLINE: ${{ (github.event_name == 'schedule' || contains(steps.keep-going.outputs.labels, 'ci-refresh-hf-cache')) && '0' || '1' }} + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} + ARTIFACTS_FILE_SUFFIX: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.setup-linux.outputs.job-id }} + TORCH_TPU_TEXT_FILE: /var/lib/jenkins/workspace/.github/ci_commit_pins/torch_tpu.txt + USE_ARC: "1" + shell: bash + run: | + set -x + + if [[ "${TEST_CONFIG}" == 'multigpu' ]]; then + TEST_COMMAND=.ci/pytorch/multigpu-test.sh + else + TEST_COMMAND=.ci/pytorch/test.sh + fi + + # Just create an empty HF_CACHE dir if it doesn't exist. This dir is not + # used for anything besides vLLM jobs + if [[ ! -d "${HF_CACHE}" ]]; then + export HF_CACHE="${RUNNER_TEMP}/hf_cache" + mkdir -p "${HF_CACHE}" + + # When there is no cache directory, e.g. benchmark, the job has no + # way but to reach out to HF if needed + export TRANSFORMERS_OFFLINE=0 + export HF_DATASETS_OFFLINE=0 + fi + + # shellcheck disable=SC2046 + python3 -m pip install $(echo dist/*.whl)[opt-einsum] + ${TEST_COMMAND} + + - name: Configure AWS credentials + id: aws-creds-benchmark + continue-on-error: true + uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results + aws-region: us-east-1 + + - name: Upload pytest cache if tests failed + uses: pytorch/pytorch/.github/actions/pytest-cache-upload@main + continue-on-error: true + if: failure() && steps.test.conclusion && steps.test.conclusion == 'failure' && steps.aws-creds-benchmark.outcome == 'success' + with: + cache_dir: .pytest_cache + shard: ${{ matrix.shard }} + sha: ${{ github.event.pull_request.head.sha || github.sha }} + test_config: ${{ matrix.config }} + job_identifier: ${{ github.workflow }}_${{ inputs.build-environment }} + + - name: Upload the benchmark results + if: steps.aws-creds-benchmark.outcome == 'success' + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main + with: + benchmark-results-dir: test/test-reports + dry-run: false + schema-version: v3 + github-token: ${{ secrets.GITHUB_TOKEN }} + + - name: Print remaining test logs + shell: bash + if: always() && steps.test.conclusion + run: | + cat test/**/*_toprint.log || true + + - name: Upload test artifacts + uses: pytorch/pytorch/.github/actions/upload-test-artifacts@main + if: always() && steps.test.conclusion && steps.test.conclusion != 'skipped' + with: + file-suffix: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.setup-linux.outputs.job-id }} + use-gha: ${{ inputs.use-gha }} + s3-bucket: ${{ inputs.s3-bucket }} diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml index 0d674f044ec42..9119a6c0d16bb 100644 --- a/.github/workflows/_runner-determinator.yml +++ b/.github/workflows/_runner-determinator.yml @@ -7,7 +7,8 @@ on: required: false type: string description: | - List of experiments for this workflow. If not defined, all default experiments are included. + Comma-separated list of non-default experiments to opt into for this workflow. + These are added on top of all default experiments. If not defined, only default experiments are included. opt_out_experiments: required: false type: string @@ -36,11 +37,28 @@ on: description: | Fetch's GitHub Issue from pytorch/test-infra Example: https://github.com/pytorch/test-infra/issues/5132 + runner_config: + required: false + type: string + default: "" + description: Runner configuration used by the caller to derive runner-type. outputs: label-type: description: Type of runners to use value: ${{ jobs.runner-determinator.outputs.label-type }} + runner-config: + description: Normalized runner configuration derived from the caller input + value: ${{ jobs.runner-determinator.outputs.runner-config }} + runner-type: + description: Runner suffix to use for workflow-specific runner selection + value: ${{ jobs.runner-determinator.outputs.runner-type }} + runner-label: + description: Fully qualified runner label derived from runner_config + value: ${{ jobs.runner-determinator.outputs.runner-label }} + use-arc: + description: Whether to use ARC runners + value: ${{ jobs.runner-determinator.outputs.use-arc }} jobs: runner-determinator: @@ -49,6 +67,10 @@ jobs: runs-on: ubuntu-latest outputs: label-type: ${{ steps.set-condition.outputs.label-type }} + runner-config: ${{ steps.set-runner-info.outputs.runner-config }} + runner-type: ${{ steps.set-runner-info.outputs.runner-type }} + runner-label: ${{ steps.set-runner-info.outputs.runner-label }} + use-arc: ${{ steps.set-condition.outputs.use-arc }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} ISSUE_NUMBER: ${{ inputs.issue_number }} @@ -58,658 +80,8 @@ jobs: OPT_OUT_EXPERIMENTS: ${{ inputs.opt_out_experiments }} PR_NUMBER: ${{ github.event.pull_request.number }} steps: - # - name: Checkout PyTorch - # uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - # with: - # fetch-depth: 1 - # submodules: true - - # TODO: Remove the hardcoded step below - # Hardcoding below is temporary for testing ALI runners - # This file below should match the script found in .github/scripts/runner_determinator.py - - name: Hardcode runner-determinator script - id: hardcode-script - run: | - cat < runner_determinator.py - # flake8: noqa: G004 - - # Note: Copies of this script in runner_determinator.py and _runner-determinator.yml - # must be kept in sync. You can do it easily by running the following command: - # python .github/scripts/update_runner_determinator.py - - """ - This runner determinator is used to determine which set of runners to run a - GitHub job on. It uses the first comment of a GitHub issue (by default - https://github.com/pytorch/test-infra/issues/5132) to define the configuration - of which runners should be used to run which job. - - The configuration has two parts, the settings and a list of opted-in users, - separated by a line containing "---". If the line is not present, the - settings are considered to be empty with only the second part, the user - list, defined. - - The first part is a YAML block that defines the rollout settings. This can be - used to define any settings that are needed to determine which runners to use. - It's fields are defined by the RolloutSettings class below. - - The second part is a list of users who are explicitly opted in to the LF fleet. - The user list is also a comma separated list of additional features or - experiments which the user could be opted in to. - - The user list has the following rules: - - - Users are GitHub usernames, which must start with the @ prefix - - Each user is also a comma-separated list of features/experiments to enable - - A "#" prefix opts the user out of all experiments - - Example config: - # A list of experiments that can be opted into. - # This defines the behavior they'll induce when opted into. - # Expected syntax is: - # [experiment_name]: # Name of the experiment. Also used for the label prefix. - # rollout_perc: [int] # % of workflows to run with this experiment when users are not opted in. - - experiments: - lf: - rollout_percent: 25 - all_branches: false - default: true - --- - - # Opt-ins: - # Users can opt into the LF fleet by adding their GitHub username to this list - # and specifying experiments to enable in a comma-separated list. - # To always opt out of an experiment, prefix it with a "-". - # Experiments should be from the above list. - - @User1,-lf,split_build - @User2,lf - @User3,split_build - """ - - import json - import logging - import os - import random - import re - import sys - from argparse import ArgumentParser - from collections.abc import Iterable - from functools import cache - from logging import LogRecord - from typing import Any, NamedTuple - from urllib.request import Request, urlopen - - import yaml - from github import Auth, Github - from github.Issue import Issue - - - DEFAULT_LABEL_PREFIX = "" # use meta runners - WORKFLOW_LABEL_LF = "lf." # use runners from the linux foundation - WORKFLOW_LABEL_LF_CANARY = "lf.c." # use canary runners from the linux foundation - - GITHUB_OUTPUT = os.getenv("GITHUB_OUTPUT", "") - GH_OUTPUT_KEY_AMI = "runner-ami" - GH_OUTPUT_KEY_LABEL_TYPE = "label-type" - OPT_OUT_LABEL = "no-runner-experiments" - - SETTING_EXPERIMENTS = "experiments" - - LF_FLEET_EXPERIMENT = "lf" - CANARY_FLEET_SUFFIX = ".c" - - - class Experiment(NamedTuple): - rollout_perc: float = ( - 0 # Percentage of workflows to experiment on when user is not opted-in. - ) - all_branches: bool = ( - False # If True, the experiment is also enabled on the exception branches - ) - default: bool = ( - True # If True, the experiment is enabled by default for all queries - ) - - # Add more fields as needed - - - class Settings(NamedTuple): - """ - Settings for the experiments that can be opted into. - """ - - experiments: dict[str, Experiment] = {} - - - class ColorFormatter(logging.Formatter): - """Color codes the log messages based on the log level""" - - COLORS = { - "WARNING": "\033[33m", # Yellow - "ERROR": "\033[31m", # Red - "CRITICAL": "\033[31m", # Red - "INFO": "\033[0m", # Reset - "DEBUG": "\033[0m", # Reset - } - - def format(self, record: LogRecord) -> str: - log_color = self.COLORS.get(record.levelname, "\033[0m") # Default to reset - record.msg = f"{log_color}{record.msg}\033[0m" - return super().format(record) - - - handler = logging.StreamHandler() - handler.setFormatter(ColorFormatter(fmt="%(levelname)-8s: %(message)s")) - - log = logging.getLogger(os.path.basename(__file__)) - log.addHandler(handler) - log.setLevel(logging.INFO) - - - def set_github_output(key: str, value: str) -> None: - """ - Defines outputs of the github action that invokes this script - """ - if not GITHUB_OUTPUT: - # See https://github.blog/changelog/2022-10-11-github-actions-deprecating-save-state-and-set-output-commands/ for deprecation notice - log.warning( - "No env var found for GITHUB_OUTPUT, you must be running this code locally. Falling back to the deprecated print method." - ) - print(f"::set-output name={key}::{value}") - return - - with open(GITHUB_OUTPUT, "a") as f: - log.info(f"Setting output: {key}='{value}'") - f.write(f"{key}={value}\n") - - - def _str_comma_separated_to_set(value: str) -> frozenset[str]: - return frozenset( - filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(","))) - ) - - - def parse_args() -> Any: - parser = ArgumentParser("Get dynamic rollout settings") - parser.add_argument("--github-token", type=str, required=True, help="GitHub token") - parser.add_argument( - "--github-issue-repo", - type=str, - required=False, - default="pytorch/test-infra", - help="GitHub repo to get the issue", - ) - parser.add_argument( - "--github-repo", - type=str, - required=True, - help="GitHub repo where CI is running", - ) - parser.add_argument( - "--github-issue", type=int, required=True, help="GitHub issue number" - ) - parser.add_argument( - "--github-actor", type=str, required=True, help="GitHub triggering_actor" - ) - parser.add_argument( - "--github-issue-owner", type=str, required=True, help="GitHub issue owner" - ) - parser.add_argument( - "--github-branch", type=str, required=True, help="Current GitHub branch or tag" - ) - parser.add_argument( - "--github-ref-type", - type=str, - required=True, - help="Current GitHub ref type, branch or tag", - ) - parser.add_argument( - "--eligible-experiments", - type=_str_comma_separated_to_set, - required=False, - default="", - help="comma separated list of experiments to check, if omitted all experiments marked with default=True are checked", - ) - parser.add_argument( - "--opt-out-experiments", - type=_str_comma_separated_to_set, - required=False, - default="", - help=( - "comma separated list of experiments to opt-out of. If unset, no opt-outs will occur. " - "If the same experiment is listed both here and in '--eligible-experiments' opt-out will take priority." - ), - ) - parser.add_argument( - "--pr-number", - type=str, - required=False, - default="", - help="the optional PR number where this is run", - ) - - return parser.parse_args() - - - def get_gh_client(github_token: str) -> Github: # type: ignore[no-any-unimported] - auth = Auth.Token(github_token) - return Github(auth=auth) - - - def get_issue(gh: Github, repo: str, issue_num: int) -> Issue: # type: ignore[no-any-unimported] - repo = gh.get_repo(repo) - return repo.get_issue(number=issue_num) - - - def get_potential_pr_author( - github_token: str, repo: str, username: str, ref_type: str, ref_name: str - ) -> str: - # If the trigger was a new tag added by a bot, this is a ciflow case - # Fetch the actual username from the original PR. The PR number is - # embedded in the tag name: ciflow// - - gh = get_gh_client(github_token) - - if username == "pytorch-bot[bot]" and ref_type == "tag": - split_tag = ref_name.split("/") - if ( - len(split_tag) == 3 - and split_tag[0] == "ciflow" - and split_tag[2].isnumeric() - ): - pr_number = split_tag[2] - try: - repository = gh.get_repo(repo) - pull = repository.get_pull(number=int(pr_number)) - except Exception as e: - raise Exception( # noqa: TRY002 - f"issue with pull request {pr_number} from repo {repository}" - ) from e - return pull.user.login # type: ignore[no-any-return] - # In all other cases, return the original input username - return username - - - def is_exception_branch(branch: str) -> bool: - """ - Branches that get opted out of experiments by default, until they're explicitly enabled. - """ - return branch.split("/")[0] in {"main", "nightly", "release", "landchecks"} - - - def load_yaml(yaml_text: str) -> Any: - try: - data = yaml.safe_load(yaml_text) - return data - except yaml.YAMLError: - log.exception("Error loading YAML") - raise - - - def extract_settings_user_opt_in_from_text(rollout_state: str) -> tuple[str, str]: - """ - Extracts the text with settings, if any, and the opted in users from the rollout state. - - If the issue body contains "---" then the text above that is the settings - and the text below is the list of opted in users. - - If it doesn't contain "---" then the settings are empty and the rest is the users. - """ - rollout_state_parts = rollout_state.split("---") - if len(rollout_state_parts) >= 2: - return rollout_state_parts[0], rollout_state_parts[1] - else: - return "", rollout_state - - - class UserOptins(dict[str, list[str]]): - """ - Dictionary of users with a list of features they have opted into - """ - - - def parse_user_opt_in_from_text(user_optin_text: str) -> UserOptins: - """ - Parse the user opt-in text into a key value pair of username and the list of features they have opted into - - Users are GitHub usernames with the @ prefix. Each user is also a comma-separated list of features/experiments to enable. - - Example line: "@User1,lf,split_build" - - A "#" prefix indicates the user is opted out of all experiments - - - """ - optins = UserOptins() - for user in user_optin_text.split("\n"): - user = user.strip("\r\n\t -") - if not user or not user.startswith("@"): - # Not a valid user. Skip - continue - - if user: - usr_name = user.split(",")[0].strip("@") - optins[usr_name] = [exp.strip(" ") for exp in user.split(",")[1:]] - - return optins - - - def is_valid_experiment_name(experiment_name: str) -> bool: - """ - Check if the experiment name is valid. - A valid name: - - Contains only alphanumeric characters and the special characters "_" & "-" - - The special characters "_" & "-" shouldn't be the first or last characters - - Cannot contain spaces - """ - - valid_char_regex = r"^[a-zA-Z0-9]([\w-]*[a-zA-Z0-9])?$" - valid = bool(re.match(valid_char_regex, experiment_name)) - - if valid: - return True - - log.error( - f"Invalid experiment name: {experiment_name}. Experiment names should only contain alphanumeric characters, '_', and '-'. They cannot contain spaces, and the special characters '_' and '-' cannot be the first or last characters." - ) - return False - - - def parse_settings_from_text(settings_text: str) -> Settings: - """ - Parse the experiments from the issue body into a list of ExperimentSettings - """ - try: - if settings_text: - # Escape the backtick as well so that we can have the settings in a code block on the GH issue - # for easy reading - # Note: Using ascii for the backtick so that the cat step in _runner-determinator.yml doesn't choke on - # the backtick character in shell commands. - backtick = chr(96) # backtick character - settings_text = settings_text.strip(f"\r\n\t{backtick} ") - settings = load_yaml(settings_text) - - # For now we just load experiments. We can expand this if/when we add more settings - experiments = {} - - for exp_name, exp_settings in settings.get(SETTING_EXPERIMENTS).items(): - if not is_valid_experiment_name(exp_name): - # Exclude invalid experiments from the list. We log an error, but don't raise an exception so that other experiments can still be processed. - continue - - valid_settings = {} - for setting in exp_settings: - if setting not in Experiment._fields: - log.warning( - f"Unexpected setting in experiment: {setting} = {exp_settings[setting]}" - ) - else: - valid_settings[setting] = exp_settings[setting] - - experiments[exp_name] = Experiment(**valid_settings) - return Settings(experiments) - - except Exception: - log.exception("Failed to parse settings") - - return Settings() - - - def parse_settings(rollout_state: str) -> Settings: - """ - Parse settings, if any, from the rollout state. - - If the issue body contains "---" then the text above that is the settings - and the text below is the list of opted in users. - - If it doesn't contain "---" then the settings are empty and the default values are used. - """ - settings_text, _ = extract_settings_user_opt_in_from_text(rollout_state) - return parse_settings_from_text(settings_text) - - - def parse_users(rollout_state: str) -> UserOptins: - """ - Parse users from the rollout state. - - """ - _, users_text = extract_settings_user_opt_in_from_text(rollout_state) - return parse_user_opt_in_from_text(users_text) - - - def is_user_opted_in(user: str, user_optins: UserOptins, experiment_name: str) -> bool: - """ - Check if a user is opted into an experiment - """ - return experiment_name in user_optins.get(user, []) - - - def is_user_opted_out(user: str, user_optins: UserOptins, experiment_name: str) -> bool: - """ - Check if a user explicitly opted out of an experiment - """ - # if the experiment is prefixed with a "-", then it's an opt-out - experiment_optout = "-" + experiment_name - if experiment_optout not in user_optins.get(user, []): - return False - - if is_user_opted_in(user, user_optins, experiment_name): - log.warning( - f"User {user} is opted into experiment {experiment_name}, but also opted out of it. Defaulting to opting out" - ) - - return True - - - def get_runner_prefix( - rollout_state: str, - workflow_requestors: Iterable[str], - branch: str, - eligible_experiments: frozenset[str] = frozenset(), - opt_out_experiments: frozenset[str] = frozenset(), - is_canary: bool = False, - ) -> str: - settings = parse_settings(rollout_state) - user_optins = parse_users(rollout_state) - - fleet_prefix = "" - prefixes = [] - for experiment_name, experiment_settings in settings.experiments.items(): - if not experiment_settings.all_branches and is_exception_branch(branch): - log.info( - f"Branch {branch} is an exception branch. Not enabling experiment {experiment_name}." - ) - continue - - if opt_out_experiments: - if experiment_name in opt_out_experiments: - opt_out_exp_list = ", ".join(opt_out_experiments) - log.info( - f"Skipping experiment '{experiment_name}', as this workflow has opted-out (opted out experiments are: {opt_out_exp_list})" - ) - continue - - if eligible_experiments: - if experiment_name not in eligible_experiments: - exp_list = ", ".join(eligible_experiments) - log.info( - f"Skipping experiment '{experiment_name}', as it is not in the eligible_experiments list: {exp_list}" - ) - continue - elif not experiment_settings.default: - log.info( - f"Skipping experiment '{experiment_name}', as it is not a default experiment" - ) - continue - - # Is any workflow_requestor opted out to this experiment? - opted_out_users = [ - requestor - for requestor in workflow_requestors - if is_user_opted_out(requestor, user_optins, experiment_name) - ] - - if opted_out_users: - log.info( - f"{', '.join(opted_out_users)} have opted out of experiment {experiment_name}." - ) - continue - - # Is any workflow_requestor opted in to this experiment? - opted_in_users = [ - requestor - for requestor in workflow_requestors - if is_user_opted_in(requestor, user_optins, experiment_name) - ] - - enabled = False - if opted_in_users: - log.info( - f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." - ) - enabled = True - - elif experiment_settings.rollout_perc: - # If no user is opted in, then we randomly enable the experiment based on the rollout percentage - if random.uniform(0, 100) <= experiment_settings.rollout_perc: - log.info( - f"Based on rollout percentage of {experiment_settings.rollout_perc}%, enabling experiment {experiment_name}." - ) - enabled = True - - if enabled: - label = experiment_name - if experiment_name == LF_FLEET_EXPERIMENT: - # We give some special treatment to the "lf" experiment since determines the fleet we use - # - If it's enabled, then we always list it's prefix first - # - If we're in the canary branch, then we append ".c" to the lf prefix - if is_canary: - label += CANARY_FLEET_SUFFIX - fleet_prefix = label - else: - prefixes.append(label) - - if len(prefixes) > 1: - log.error( - f"Only a fleet and one other experiment can be enabled for a job at any time. Enabling {prefixes[0]} and ignoring the rest, which are {', '.join(prefixes[1:])}" - ) - prefixes = prefixes[:1] - - # Fleet always comes first - if fleet_prefix: - prefixes.insert(0, fleet_prefix) - - return ".".join(prefixes) + "." if prefixes else "" - - - def get_rollout_state_from_issue(github_token: str, repo: str, issue_num: int) -> str: - """ - Gets the first comment of the issue, which contains the desired rollout state. - - The default issue we use - https://github.com/pytorch/test-infra/issues/5132 - """ - gh = get_gh_client(github_token) - issue = get_issue(gh, repo, issue_num) - return str(issue.get_comments()[0].body.strip("\n\t ")) - - - def download_json(url: str, headers: dict[str, str], num_retries: int = 3) -> Any: - for _ in range(num_retries): - try: - req = Request(url=url, headers=headers) - content = urlopen(req, timeout=5).read().decode("utf-8") - return json.loads(content) - except Exception as e: - log.warning(f"Could not download {url}: {e}") - - log.warning(f"All {num_retries} retries exhausted, downloading {url} failed") - return {} - - - @cache - def get_pr_info(github_repo: str, github_token: str, pr_number: int) -> dict[str, Any]: - """ - Dynamically get PR information - """ - github_api = f"https://api.github.com/repos/{github_repo}" - headers = { - "Accept": "application/vnd.github.v3+json", - "Authorization": f"token {github_token}", - } - json_response: dict[str, Any] = download_json( - url=f"{github_api}/issues/{pr_number}", - headers=headers, - ) - - if not json_response: - log.warning(f"Failed to get the labels for #{pr_number}") - return {} - - return json_response - - - def get_labels(github_repo: str, github_token: str, pr_number: int) -> set[str]: - """ - Dynamically get the latest list of labels from the pull request - """ - pr_info = get_pr_info(github_repo, github_token, pr_number) - return { - label.get("name") for label in pr_info.get("labels", []) if label.get("name") - } - - - def main() -> None: - args = parse_args() - - runner_label_prefix = DEFAULT_LABEL_PREFIX - - # Check if the PR is opt-out - if args.pr_number: - labels = get_labels(args.github_repo, args.github_token, int(args.pr_number)) - if OPT_OUT_LABEL in labels: - log.info( - f"Opt-out runner determinator because #{args.pr_number} has {OPT_OUT_LABEL} label" - ) - set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, runner_label_prefix) - sys.exit() - - try: - rollout_state = get_rollout_state_from_issue( - args.github_token, args.github_issue_repo, args.github_issue - ) - - username = get_potential_pr_author( - args.github_token, - args.github_repo, - args.github_actor, - args.github_ref_type, - args.github_branch, - ) - - is_canary = args.github_repo == "pytorch/pytorch-canary" - - runner_label_prefix = get_runner_prefix( - rollout_state, - (args.github_issue_owner, username), - args.github_branch, - args.eligible_experiments, - args.opt_out_experiments, - is_canary, - ) - - except Exception as e: - log.error( - f"Failed to get issue. Defaulting to Meta runners and no experiments. Exception: {e}" - ) - - set_github_output(GH_OUTPUT_KEY_LABEL_TYPE, runner_label_prefix) - - - if __name__ == "__main__": - main() - - EOF - - cat runner_determinator.py + - name: Checkout PyTorch + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install dependencies run: python3 -m pip install urllib3==1.26.18 PyGithub==2.3.0 @@ -721,7 +93,7 @@ jobs: curr_ref_type="${{ inputs.curr_ref_type }}" echo "Current branch is '$curr_branch'" - python3 runner_determinator.py \ + python3 .github/scripts/runner_determinator.py \ --github-token "$GITHUB_TOKEN" \ --github-issue "$ISSUE_NUMBER" \ --github-branch "$curr_branch" \ @@ -732,3 +104,30 @@ jobs: --eligible-experiments "$CHECK_EXPERIMENTS" \ --opt-out-experiments "$OPT_OUT_EXPERIMENTS" \ --pr-number "${PR_NUMBER}" + + - name: Determine runner configuration + id: set-runner-info + run: | + case "${{ inputs.runner_config }}" in + "") + runner_config="m8g" + runner_type="metal-24xl" + ;; + "m8g") + runner_config="m8g" + runner_type="metal-24xl" + ;; + "m7g") + runner_config="m7g" + runner_type="metal" + ;; + *) + echo "Unsupported runner_config: ${{ inputs.runner_config }}" + exit 1 + ;; + esac + + runner_label="linux.arm64.${runner_config}.${runner_type}" + echo "runner-config=${runner_config}" >> "$GITHUB_OUTPUT" + echo "runner-type=${runner_type}" >> "$GITHUB_OUTPUT" + echo "runner-label=${runner_label}" >> "$GITHUB_OUTPUT" diff --git a/.github/workflows/_vllm-benchmark.yml b/.github/workflows/_vllm-benchmark.yml index a97a3190c2ea6..a993d54bac4b9 100644 --- a/.github/workflows/_vllm-benchmark.yml +++ b/.github/workflows/_vllm-benchmark.yml @@ -190,6 +190,7 @@ jobs: --models "${MODELS}" \ --device "${DEVICE_NAME}" \ --include-eager-mode \ + --include-inductor-graph-partition \ --compilation-config "${COMPILATION_CONFIG}" popd diff --git a/.github/workflows/_vllm-build.yml b/.github/workflows/_vllm-build.yml index b3e8c546c66f0..1d0c7426f95c9 100644 --- a/.github/workflows/_vllm-build.yml +++ b/.github/workflows/_vllm-build.yml @@ -88,7 +88,7 @@ jobs: set -eux python use_existing_torch.py - pip install -r requirements/build.txt + pip install -r requirements/build/cuda.txt sccache --show-stats python setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38 diff --git a/.github/workflows/_win-build.yml b/.github/workflows/_win-build.yml index 005d68ece857d..896241e3ea860 100644 --- a/.github/workflows/_win-build.yml +++ b/.github/workflows/_win-build.yml @@ -193,9 +193,19 @@ jobs: cp "${CUDA_PATH}/lib/x64/cudart.lib" /c/${{ github.run_id }}/build-results/ fi + # Copy CUDA runtime DLL (needed for MinGW import lib generation on CUDA 13.0+) + echo "Searching for CUDA runtime DLLs in ${CUDA_PATH}/bin/:" + ls -la "${CUDA_PATH}"/bin/cudart*.dll 2>/dev/null || echo "No cudart*.dll found in ${CUDA_PATH}/bin/" + for dll in "${CUDA_PATH}"/bin/cudart64_*.dll "${CUDA_PATH}"/bin/cudart.dll; do + if [ -f "$dll" ]; then + cp "$dll" /c/${{ github.run_id }}/build-results/ + echo "Copied $(basename $dll)" + fi + done + # List collected files - echo "Collected CUDA libs:" - ls -lah /c/${{ github.run_id }}/build-results/*.lib + echo "Collected CUDA libs and DLLs:" + ls -lah /c/${{ github.run_id }}/build-results/*.lib /c/${{ github.run_id }}/build-results/*.dll 2>/dev/null || true # Upload to github so that people can click and download artifacts - name: Upload artifacts to s3 diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index 5724403e6de44..31d33e6bc9f6a 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -91,9 +91,6 @@ jobs: - name: Setup XPU uses: ./.github/actions/setup-xpu - - name: Login to ECR - uses: ./.github/actions/ecr-login - - name: Calculate docker image id: calculate-docker-image uses: pytorch/test-infra/.github/actions/calculate-docker-image@main @@ -251,7 +248,7 @@ jobs: -e ZE_AFFINITY_MASK \ -e HUGGING_FACE_HUB_TOKEN \ -e DASHBOARD_TAG \ - --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ + --env-file="${RUNNER_TEMP}/github_env_${GITHUB_RUN_ID}" \ --ulimit stack=10485760:83886080 \ --ulimit core=0 \ --security-opt seccomp=unconfined \ diff --git a/.github/workflows/apply-lint.yml b/.github/workflows/apply-lint.yml index dbbce9c14e292..d30c0a2b316e7 100644 --- a/.github/workflows/apply-lint.yml +++ b/.github/workflows/apply-lint.yml @@ -10,7 +10,7 @@ jobs: environment: mergebot env: GH_RUN_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} - LINTERS: CLANGFORMAT,NEWLINE,PYFMT,BAZEL_LINTER,RUFF,ATEN_CPU_GPU_AGNOSTIC + LINTERS: CLANGFORMAT,NEWLINE,PYFMT,RUFF,ATEN_CPU_GPU_AGNOSTIC steps: - name: Checkout repo id: checkout diff --git a/.github/workflows/attention_op_microbenchmark.yml b/.github/workflows/attention_op_microbenchmark.yml index cd04a48223ce1..30ae0978cb68c 100644 --- a/.github/workflows/attention_op_microbenchmark.yml +++ b/.github/workflows/attention_op_microbenchmark.yml @@ -16,8 +16,20 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: + get-label-type: + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + check_experiments: arc,lf + attn-microbenchmark-build: if: github.repository_owner == 'pytorch' uses: ./.github/workflows/_linux-build.yml @@ -44,13 +56,15 @@ jobs: test-matrix: ${{ needs.attn-microbenchmark-build.outputs.test-matrix }} secrets: inherit - # B200 runner + # B200 runner (OSDC), always use OSDC runner to test this workflow opmicrobenchmark-build-b200: if: github.repository_owner == 'pytorch' name: opmicrobenchmark-build-b200 uses: ./.github/workflows/_linux-build.yml + needs: get-label-type with: - runner: linux.12xlarge.memory + runner_prefix: "mt-" + runner: linux.r7i.4xlarge build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 cuda-arch-list: '10.0' @@ -58,16 +72,25 @@ jobs: { include: [ { config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.dgx.b200" }, ]} + use-arc: true + python-version: "3.10" + compiler: gcc11 + cuda-version: "12.8" secrets: inherit opmicrobenchmark-test-b200: name: opmicrobenchmark-test-b200 uses: ./.github/workflows/_linux-test.yml - needs: opmicrobenchmark-build-b200 + needs: + - opmicrobenchmark-build-b200 + - get-label-type with: timeout-minutes: 500 build-environment: ${{ needs.opmicrobenchmark-build-b200.outputs.build-environment }} docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }} - aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + use-arc: true + python-version: "3.10" + compiler: gcc11 + cuda-version: "12.8" secrets: inherit diff --git a/.github/workflows/b200-distributed.yml b/.github/workflows/b200-distributed.yml index e52c7a4b5f5c5..6ce72fd11c442 100644 --- a/.github/workflows/b200-distributed.yml +++ b/.github/workflows/b200-distributed.yml @@ -18,6 +18,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: diff --git a/.github/workflows/b200-symm-mem.yml b/.github/workflows/b200-symm-mem.yml index 62367b61b07b9..256409ace2d6c 100644 --- a/.github/workflows/b200-symm-mem.yml +++ b/.github/workflows/b200-symm-mem.yml @@ -18,6 +18,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: diff --git a/.github/workflows/build-almalinux-images.yml b/.github/workflows/build-almalinux-images.yml index f459c307118fe..2a2b2c8f0caf0 100644 --- a/.github/workflows/build-almalinux-images.yml +++ b/.github/workflows/build-almalinux-images.yml @@ -36,7 +36,7 @@ jobs: runs-on: linux.9xlarge.ephemeral strategy: matrix: - tag: ["cuda12.6", "cuda12.8", "cuda12.9", "cuda13.0", "cuda13.2", "rocm7.0", "rocm7.1", "rocm7.2", "cpu"] + tag: ["cuda12.6", "cuda12.8", "cuda13.0", "cuda13.2", "rocm7.0", "rocm7.1", "rocm7.2", "cpu"] steps: - name: Build docker image uses: pytorch/pytorch/.github/actions/binary-docker-build@main diff --git a/.github/workflows/build-libtorch-images.yml b/.github/workflows/build-libtorch-images.yml deleted file mode 100644 index 47bf15e1db3ab..0000000000000 --- a/.github/workflows/build-libtorch-images.yml +++ /dev/null @@ -1,68 +0,0 @@ -name: Build libtorch docker images - -on: - push: - branches: - - main - - release/* - tags: - # NOTE: Binary build pipelines should only get triggered on release candidate or nightly builds - # Release candidate tags look like: v1.11.0-rc1 - - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ - paths: - - .ci/docker/** - - .github/workflows/build-libtorch-images.yml - - .github/actions/binary-docker-build/** - pull_request: - paths: - - .ci/docker/** - - .github/workflows/build-libtorch-images.yml - - .github/actions/binary-docker-build/** - -env: - DOCKER_REGISTRY: "docker.io" - DOCKER_BUILDKIT: 1 - WITH_PUSH: ${{ github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release') || startsWith(github.ref, 'refs/tags/v')) }} - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true - -jobs: - get-label-type: - if: github.repository_owner == 'pytorch' - name: get-label-type - uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main - with: - triggering_actor: ${{ github.triggering_actor }} - issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} - curr_branch: ${{ github.head_ref || github.ref_name }} - curr_ref_type: ${{ github.ref_type }} - - build: - environment: ${{ (github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release') || startsWith(github.ref, 'refs/tags/v')) && 'docker-build') || '' }} - needs: get-label-type - runs-on: ${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral - name: libtorch-cxx11-builder:${{ matrix.tag }} - strategy: - fail-fast: false - matrix: - include: [ - { tag: "cuda13.0" }, - { tag: "cuda12.9" }, - { tag: "cuda12.8" }, - { tag: "cuda12.6" }, - { tag: "rocm7.0" }, - { tag: "rocm7.1" }, - { tag: "rocm7.2" }, - { tag: "cpu" }, - ] - steps: - - name: Build docker image - uses: pytorch/pytorch/.github/actions/binary-docker-build@main - with: - docker-image-name: libtorch-cxx11-builder - custom-tag-prefix: ${{ matrix.tag }} - docker-build-dir: libtorch - DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }} - DOCKER_ID: ${{ secrets.DOCKER_ID }} diff --git a/.github/workflows/build-magma-windows.yml b/.github/workflows/build-magma-windows.yml index b7d293a5cec11..39a705338f189 100644 --- a/.github/workflows/build-magma-windows.yml +++ b/.github/workflows/build-magma-windows.yml @@ -22,7 +22,7 @@ jobs: runs-on: windows-2022 strategy: matrix: - cuda_version: ["130", "129", "128", "126"] + cuda_version: ["132", "130", "129", "128", "126"] config: ["Release", "Debug"] env: CUDA_VERSION: ${{ matrix.cuda_version }} diff --git a/.github/workflows/build-manywheel-images.yml b/.github/workflows/build-manywheel-images.yml index f86cefd7c7a1a..d25dd4a14b1db 100644 --- a/.github/workflows/build-manywheel-images.yml +++ b/.github/workflows/build-manywheel-images.yml @@ -46,14 +46,14 @@ jobs: fail-fast: false matrix: include: [ + { name: "manylinux2_28-builder", tag: "cuda13.2", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "cuda13.0", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "cuda12.8", runner: "linux.9xlarge.ephemeral" }, - { name: "manylinux2_28-builder", tag: "cuda12.9", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "cuda12.6", runner: "linux.9xlarge.ephemeral" }, - { name: "manylinuxaarch64-builder", tag: "cuda13.0", runner: "linux.arm64.2xlarge.ephemeral" }, - { name: "manylinuxaarch64-builder", tag: "cuda12.9", runner: "linux.arm64.2xlarge.ephemeral" }, - { name: "manylinuxaarch64-builder", tag: "cuda12.8", runner: "linux.arm64.2xlarge.ephemeral" }, - { name: "manylinuxaarch64-builder", tag: "cuda12.6", runner: "linux.arm64.2xlarge.ephemeral" }, + { name: "manylinuxaarch64-builder", tag: "cuda13.2", runner: "linux.arm64.m7g.4xlarge.ephemeral" }, + { name: "manylinuxaarch64-builder", tag: "cuda13.0", runner: "linux.arm64.m7g.4xlarge.ephemeral" }, + { name: "manylinuxaarch64-builder", tag: "cuda12.8", runner: "linux.arm64.m7g.4xlarge.ephemeral" }, + { name: "manylinuxaarch64-builder", tag: "cuda12.6", runner: "linux.arm64.m7g.4xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "rocm7.0", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "rocm7.1", runner: "linux.9xlarge.ephemeral" }, { name: "manylinux2_28-builder", tag: "rocm7.2", runner: "linux.9xlarge.ephemeral" }, diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 49534b2c81bbc..17e5639b01598 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -78,14 +78,11 @@ jobs: with: github-secret: ${{ secrets.GITHUB_TOKEN }} - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + - name: Setup Linux + uses: pytorch/pytorch/.github/actions/setup-linux@main with: submodules: false - - name: Setup Linux - uses: ./.github/actions/setup-linux - - name: Login to ECR uses: ./.github/actions/ecr-login @@ -165,7 +162,7 @@ jobs: docker exec -t "${container_name}" bash -c "${PYTHON_EXECUTABLE} /pytorch/.github/scripts/build_triton_wheel.py --device=$BUILD_DEVICE $RELEASE $WITH_CLANG_LDD" if [[ ("${{ matrix.device }}" == "cuda" || "${{ matrix.device }}" == "xpu") ]]; then - docker exec -t "${container_name}" bash -c "auditwheel repair --plat ${PLATFORM} //artifacts/*.whl" + docker exec -t "${container_name}" bash -c "auditwheel repair --plat ${PLATFORM} --exclude libtriton.so //artifacts/*.whl" else docker exec -t "${container_name}" bash -c "mkdir //artifacts/wheelhouse" docker exec -t "${container_name}" bash -c "mv //artifacts/*.whl //artifacts/wheelhouse/" diff --git a/.github/workflows/build-vllm-wheel.yml b/.github/workflows/build-vllm-wheel.yml index 9865ec6acec81..5008d3bb1dad2 100644 --- a/.github/workflows/build-vllm-wheel.yml +++ b/.github/workflows/build-vllm-wheel.yml @@ -28,33 +28,24 @@ jobs: matrix: python-version: [ '3.12' ] platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ] - device: [ 'cu128', 'cu129', 'cu130' ] + device: [ 'cu130', 'cu132' ] include: - - platform: manylinux_2_28_x86_64 - device: cu128 - manylinux-image: 'pytorch/manylinux2_28-builder:cuda12.8' - runner: linux.12xlarge.memory - - platform: manylinux_2_28_x86_64 - device: cu129 - manylinux-image: 'pytorch/manylinux2_28-builder:cuda12.9' - runner: linux.12xlarge.memory - platform: manylinux_2_28_x86_64 device: cu130 manylinux-image: 'pytorch/manylinux2_28-builder:cuda13.0' runner: linux.12xlarge.memory + - platform: manylinux_2_28_x86_64 + device: cu132 + manylinux-image: 'pytorch/manylinux2_28-builder:cuda13.2' + runner: linux.12xlarge.memory - platform: manylinux_2_28_aarch64 - device: cu128 - manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.8' + device: cu130 + manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda13.0' runner: linux.arm64.r7g.12xlarge.memory - platform: manylinux_2_28_aarch64 - device: cu129 - manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda12.9' + device: cu132 + manylinux-image: 'pytorch/manylinuxaarch64-builder:cuda13.2' runner: linux.arm64.r7g.12xlarge.memory - exclude: - # TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and - # xformers is update to support 13.0 - - platform: manylinux_2_28_aarch64 - device: cu130 name: "Build ${{ matrix.device }} vLLM wheel on ${{ matrix.platform }}" runs-on: ${{ matrix.runner }} timeout-minutes: 480 @@ -69,14 +60,11 @@ jobs: with: github-secret: ${{ secrets.GITHUB_TOKEN }} - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + - name: Setup Linux + uses: pytorch/pytorch/.github/actions/setup-linux@main with: submodules: false - - name: Setup Linux - uses: ./.github/actions/setup-linux - - name: Login to ECR uses: ./.github/actions/ecr-login @@ -180,12 +168,7 @@ jobs: fail-fast: false matrix: platform: [ 'manylinux_2_28_x86_64', 'manylinux_2_28_aarch64' ] - device: [ 'cu128', 'cu129', 'cu130' ] - exclude: - # TODO (huydhn): Add cu130 aarch64 once PyTorch is on 2.9+ and - # xformers is update to support 13.0 - - platform: manylinux_2_28_aarch64 - device: cu130 + device: [ 'cu130', 'cu132' ] env: PLATFORM: ${{ matrix.platform }} BUILD_DEVICE: ${{ matrix.device }} diff --git a/.github/workflows/claude-autorevert-advisor.yml b/.github/workflows/claude-autorevert-advisor.yml index 8bbc04a86bf51..6cd96349c6e8e 100644 --- a/.github/workflows/claude-autorevert-advisor.yml +++ b/.github/workflows/claude-autorevert-advisor.yml @@ -38,15 +38,16 @@ jobs: git fetch origin ${{ inputs.suspect_commit }} --depth=32 || true - name: Configure AWS credentials via OIDC - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_claude_code aws-region: us-east-1 - name: Run AI Advisor id: claude - uses: anthropics/claude-code-action@v1 + uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89 with: + allowed_bots: "pytorch-auto-revert[bot],pytorch-bot[bot]" use_bedrock: "true" claude_args: | --model global.anthropic.claude-opus-4-6-v1 @@ -127,6 +128,46 @@ jobs: path: /tmp/verdict/verdict.json retention-days: 30 + - name: Upload verdict to S3 for ClickHouse ingestion + if: always() && steps.claude.outputs.structured_output != '' + env: + VERDICT_JSON: ${{ steps.claude.outputs.structured_output }} + SIGNAL_PATTERN: ${{ inputs.signal_pattern }} + run: | + # Build enriched verdict JSON with signal metadata for CH ingestion + jq -n \ + --arg repo "${{ github.repository }}" \ + --argjson run_id "${{ github.run_id }}" \ + --argjson run_attempt "${{ github.run_attempt }}" \ + --arg timestamp "$(date -u +%Y-%m-%dT%H:%M:%S.000)" \ + --arg suspect_commit "${{ inputs.suspect_commit }}" \ + --argjson pr_number "${{ inputs.pr_number }}" \ + --arg signal_key "$(echo "$SIGNAL_PATTERN" | jq -r '.signal_key // ""')" \ + --arg signal_source "$(echo "$SIGNAL_PATTERN" | jq -r '.signal_source // ""')" \ + --arg workflow_name "$(echo "$SIGNAL_PATTERN" | jq -r '.workflow_name // ""')" \ + --arg verdict "$(echo "$VERDICT_JSON" | jq -r '.verdict // ""')" \ + --argjson confidence "$(echo "$VERDICT_JSON" | jq -r '.confidence // 0')" \ + --arg summary "$(echo "$VERDICT_JSON" | jq -r '.summary // ""')" \ + --arg causal_reasoning "$(echo "$VERDICT_JSON" | jq -r '.causal_reasoning // ""')" \ + '{ + repo: $repo, + run_id: $run_id, + run_attempt: $run_attempt, + timestamp: $timestamp, + suspect_commit: $suspect_commit, + pr_number: $pr_number, + signal_key: $signal_key, + signal_source: $signal_source, + workflow_name: $workflow_name, + verdict: $verdict, + confidence: $confidence, + summary: $summary, + causal_reasoning: $causal_reasoning + }' > /tmp/advisor_verdict.json + + aws s3 cp /tmp/advisor_verdict.json \ + "s3://ossci-raw-job-status/autorevert_advisor_verdicts/${{ github.repository }}/${{ github.run_id }}_${{ github.run_attempt }}.json" + - name: Upload usage metrics if: always() uses: pytorch/test-infra/.github/actions/upload-claude-usage@main diff --git a/.github/workflows/claude-code.yml b/.github/workflows/claude-code.yml index 114c480a17e3b..e09b63a1eac89 100644 --- a/.github/workflows/claude-code.yml +++ b/.github/workflows/claude-code.yml @@ -17,3 +17,6 @@ jobs: secrets: inherit with: additional_claude_args: '--allowedTools Skill' + append_system_prompt: | + When asked to review a PR, always use the /pr-review skill first. + It contains PyTorch-specific review guidelines, output format, and critical checks. diff --git a/.github/workflows/claude-issue-triage-run.yml b/.github/workflows/claude-issue-triage-run.yml index 8794a68314221..0173b7fa3b5ef 100644 --- a/.github/workflows/claude-issue-triage-run.yml +++ b/.github/workflows/claude-issue-triage-run.yml @@ -60,22 +60,24 @@ jobs: "ghcr.io/github/github-mcp-server:v0.30.1" ], "env": { - "GITHUB_PERSONAL_ACCESS_TOKEN": "${{ secrets.GITHUB_TOKEN }}" + "GITHUB_PERSONAL_ACCESS_TOKEN": "${GITHUB_TOKEN}" } } } } EOF + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Configure AWS credentials via OIDC - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_claude_code aws-region: us-east-1 - name: Run Issue Triage timeout-minutes: 5 - uses: anthropics/claude-code-action@v1 + uses: anthropics/claude-code-action@6e2bd52842c65e914eba5c8badd17560bd26b5de # v1.0.89 env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} TRIAGE_HOOK_DEBUG_LOG: /tmp/triage_hooks.log diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 3345c92f65faf..c1d75ee093366 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -53,18 +53,17 @@ jobs: pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11, pytorch-linux-jammy-cuda13.0-cudnn9-py3.12-gcc11-vllm, pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks, - pytorch-linux-jammy-py3.10-clang15, - pytorch-linux-jammy-py3.11-clang15, - pytorch-linux-jammy-py3.12-clang15, - pytorch-linux-jammy-py3.13-clang15, - pytorch-linux-jammy-py3.14-clang15, - pytorch-linux-jammy-py3.14t-clang15, + pytorch-linux-jammy-py3.10-clang18, + pytorch-linux-jammy-py3.11-clang18, + pytorch-linux-jammy-py3.12-clang18, + pytorch-linux-jammy-py3.13-clang18, + pytorch-linux-jammy-py3.14-clang18, + pytorch-linux-jammy-py3.14t-clang18, pytorch-linux-jammy-rocm-n-py3, pytorch-linux-noble-rocm-n-py3, pytorch-linux-noble-rocm-nightly-py3, pytorch-linux-jammy-rocm-n-py3-benchmarks, - pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-clang15, - pytorch-linux-jammy-py3.10-gcc11, + pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-clang18, pytorch-linux-jammy-py3-gcc11-inductor-benchmarks, pytorch-linux-jammy-py3.12-halide, pytorch-linux-jammy-py3.12-pallas, @@ -74,12 +73,10 @@ jobs: pytorch-linux-noble-xpu-n-py3, pytorch-linux-noble-xpu-n-py3-client, pytorch-linux-noble-xpu-n-py3-inductor-benchmarks, - pytorch-linux-jammy-py3-clang18-asan, - pytorch-linux-jammy-py3-clang15-onnx, pytorch-linux-jammy-linter, pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter, # TODO: Re-enable me when docker pin update happens - # pytorch-linux-jammy-py3-clang15-executorch, + # pytorch-linux-jammy-py3-clang18-executorch, pytorch-linux-jammy-py3.12-triton-cpu, pytorch-linux-noble-riscv64-py3.12-gcc14 ] @@ -100,13 +97,10 @@ jobs: sudo rm -rf "${GITHUB_WORKSPACE}" mkdir "${GITHUB_WORKSPACE}" - # [see note: pytorch repo ref] - # deep clone (fetch-depth 0) required for git merge-base - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - - name: Setup Linux - uses: ./.github/actions/setup-linux + uses: pytorch/pytorch/.github/actions/setup-linux@main + with: + submodules: 'false' - name: Login to ECR uses: ./.github/actions/ecr-login @@ -159,6 +153,9 @@ jobs: echo $GHCR_PAT | docker login ghcr.io -u pytorch --password-stdin docker tag "${ECR_DOCKER_IMAGE}" "${ghcr_image}:${tag}" docker push "${ghcr_image}:${tag}" + # Also push a tag without the hash for easier reference + docker tag "${ECR_DOCKER_IMAGE}" "${ghcr_image}:${{ matrix.docker-image-name }}" + docker push "${ghcr_image}:${{ matrix.docker-image-name }}" - name: Chown workspace uses: ./.github/actions/chown-workspace diff --git a/.github/workflows/docker-cache-rocm.yml b/.github/workflows/docker-cache-rocm.yml index 9d19bda10e2a7..8b5d7e4206023 100644 --- a/.github/workflows/docker-cache-rocm.yml +++ b/.github/workflows/docker-cache-rocm.yml @@ -57,7 +57,7 @@ jobs: strategy: fail-fast: false matrix: - runner: [linux.rocm.mi250.docker-cache, linux.rocm.mi210.docker-cache] + runner: [linux.rocm.mi210.docker-cache] docker-image: [ "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-jammy-rocm-n-py3 }}", "${{ needs.download-docker-builds-artifacts.outputs.pytorch-linux-noble-rocm-n-py3 }}" @@ -66,22 +66,26 @@ jobs: runs-on: "${{ matrix.runner }}" steps: - name: debug + env: + DOWNLOAD_OUTPUTS_JSON: ${{ toJSON(needs.download-docker-builds-artifacts.outputs) }} run: | - JSON_STRINGIFIED="${{ toJSON(needs.download-docker-builds-artifacts.outputs) }}" - echo "Outputs of download-docker-builds-artifacts job: ${JSON_STRINGIFIED}" + echo "Outputs of download-docker-builds-artifacts job: ${DOWNLOAD_OUTPUTS_JSON}" - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: no-sudo: true + submodules: 'false' - name: Login to ECR uses: ./.github/actions/ecr-login - name: Generate ghrc.io tag id: ghcr-io-tag + env: + ECR_IMAGE: ${{ matrix.docker-image }} run: | - ecr_image="${{ matrix.docker-image }}" + ecr_image="${ECR_IMAGE}" ghcr_image="ghcr.io/pytorch/ci-image:${ecr_image##*:}" echo "ghcr_image=${ghcr_image}" >> "$GITHUB_OUTPUT" @@ -91,11 +95,15 @@ jobs: docker-image: ${{ steps.ghcr-io-tag.outputs.ghcr_image }} - name: Save as tarball + env: + BRANCH: ${{ github.event.workflow_run.head_branch || github.event.inputs.branch }} + ECR_IMAGE: ${{ matrix.docker-image }} + GHCR_IMAGE: ${{ steps.ghcr-io-tag.outputs.ghcr_image }} run: | - docker_image_tag=${{ matrix.docker-image }} + docker_image_tag=${ECR_IMAGE} docker_image_tag="${docker_image_tag#*:}" # Remove everything before and including first ":" docker_image_tag="${docker_image_tag%-*}" # Remove everything after and including last "-" - ref_name=${{ github.event.workflow_run.head_branch || github.event.inputs.branch }} + ref_name=${BRANCH} if [[ $ref_name =~ "release/" ]]; then ref_suffix="release" elif [[ $ref_name == "main" ]]; then @@ -103,7 +111,7 @@ jobs: else echo "Unexpected branch in ref_name: ${ref_name}" && exit 1 fi - docker tag ${{ steps.ghcr-io-tag.outputs.ghcr_image }} ${{ matrix.docker-image }} + docker tag "${GHCR_IMAGE}" "${ECR_IMAGE}" # mv is atomic operation, so we use intermediate tar.tmp file to prevent read-write contention - docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ${{ matrix.docker-image }} + docker save -o ~/pytorch-data/docker/${docker_image_tag}.tar.tmp "${ECR_IMAGE}" mv ~/pytorch-data/docker/${docker_image_tag}.tar.tmp ~/pytorch-data/docker/${docker_image_tag}_${ref_suffix}.tar diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index 577a8acb5203f..d7463bb61cf03 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -56,7 +56,7 @@ jobs: uses: pytorch/pytorch/.github/actions/checkout-pytorch@main with: fetch-depth: 1 - submodules: true + submodules: 'false' - name: Get docker release matrix id: generate-matrix run: | @@ -86,15 +86,11 @@ jobs: uses: pytorch/test-infra/.github/actions/setup-ssh@main with: github-secret: ${{ secrets.GITHUB_TOKEN }} - # [see note: pytorch repo ref] - # deep clone (fetch-depth 0) required for git merge-base - - name: Checkout PyTorch - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - submodules: 'recursive' + - name: Setup Linux - uses: ./.github/actions/setup-linux + uses: pytorch/pytorch/.github/actions/setup-linux@main + with: + submodules: 'false' - name: Login to ECR uses: ./.github/actions/ecr-login diff --git a/.github/workflows/dtensor.yml b/.github/workflows/dtensor.yml index b56012a885a1b..b45c12b0df10b 100644 --- a/.github/workflows/dtensor.yml +++ b/.github/workflows/dtensor.yml @@ -18,6 +18,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-label-type: diff --git a/.github/workflows/dynamo-unittest.yml b/.github/workflows/dynamo-unittest.yml index f7eea350b5644..d8432abc4338c 100644 --- a/.github/workflows/dynamo-unittest.yml +++ b/.github/workflows/dynamo-unittest.yml @@ -17,6 +17,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-label-type: @@ -39,8 +40,8 @@ jobs: python-version: ['3.11', '3.12', '3.13'] with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py${{ matrix.python-version }}-clang15 - docker-image-name: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang15 + build-environment: linux-jammy-py${{ matrix.python-version }}-clang18 + docker-image-name: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang18 test-matrix: | { include: [ { config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, @@ -58,8 +59,8 @@ jobs: matrix: python-version: ['3.11', '3.12', '3.13'] with: - build-environment: linux-jammy-py${{ matrix.python-version }}-clang15 - docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang15 + build-environment: linux-jammy-py${{ matrix.python-version }}-clang18 + docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang18 test-matrix: | { include: [ { config: "dynamo_core", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index aa4b95446a18e..2e9c0ecc15bc9 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -190,81 +190,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda-aarch64-12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.r7g.12xlarge.memory - ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_10-cuda-aarch64-12_8 - build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - timeout-minutes: 420 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_10-cuda-aarch64-12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda-aarch64-12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda-aarch64-12_8 - build_environment: linux-aarch64-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.2xlarge - ALPINE_IMAGE: "arm64v8/alpine" - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda-aarch64-12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda-aarch64-12_8-build - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda-aarch64-12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_10-cuda-aarch64-12_9-build: + manywheel-py3_10-cuda-aarch64-13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -273,26 +199,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_10-cuda-aarch64-12_9 + build_name: manywheel-py3_10-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda-aarch64-12_9-test: # Testing + manywheel-py3_10-cuda-aarch64-13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda-aarch64-12_9-build + - manywheel-py3_10-cuda-aarch64-13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -300,37 +226,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda-aarch64-12_9 + build_name: manywheel-py3_10-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda-aarch64-12_9-upload: # Uploading + manywheel-py3_10-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_10-cuda-aarch64-12_9-build + needs: manywheel-py3_10-cuda-aarch64-13_0-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda-aarch64-12_9 + build_name: manywheel-py3_10-cuda-aarch64-13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -338,7 +264,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda-aarch64-13_0-build: + manywheel-py3_10-cuda-aarch64-13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -347,26 +273,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_10-cuda-aarch64-13_0 + build_name: manywheel-py3_10-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda-aarch64-13_0-test: # Testing + manywheel-py3_10-cuda-aarch64-13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda-aarch64-13_0-build + - manywheel-py3_10-cuda-aarch64-13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -374,37 +300,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda-aarch64-13_0 + build_name: manywheel-py3_10-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda-aarch64-13_0-upload: # Uploading + manywheel-py3_10-cuda-aarch64-13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_10-cuda-aarch64-13_0-build + needs: manywheel-py3_10-cuda-aarch64-13_2-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda-aarch64-13_0 + build_name: manywheel-py3_10-cuda-aarch64-13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -555,81 +481,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda-aarch64-12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.r7g.12xlarge.memory - ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_11-cuda-aarch64-12_8 - build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - timeout-minutes: 420 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_11-cuda-aarch64-12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_11-cuda-aarch64-12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda-aarch64-12_8 - build_environment: linux-aarch64-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.2xlarge - ALPINE_IMAGE: "arm64v8/alpine" - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda-aarch64-12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_11-cuda-aarch64-12_8-build - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda-aarch64-12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_11-cuda-aarch64-12_9-build: + manywheel-py3_11-cuda-aarch64-13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -638,26 +490,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_11-cuda-aarch64-12_9 + build_name: manywheel-py3_11-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda-aarch64-12_9-test: # Testing + manywheel-py3_11-cuda-aarch64-13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-cuda-aarch64-12_9-build + - manywheel-py3_11-cuda-aarch64-13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -665,37 +517,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda-aarch64-12_9 + build_name: manywheel-py3_11-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda-aarch64-12_9-upload: # Uploading + manywheel-py3_11-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda-aarch64-12_9-build + needs: manywheel-py3_11-cuda-aarch64-13_0-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda-aarch64-12_9 + build_name: manywheel-py3_11-cuda-aarch64-13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -703,7 +555,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda-aarch64-13_0-build: + manywheel-py3_11-cuda-aarch64-13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -712,26 +564,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_11-cuda-aarch64-13_0 + build_name: manywheel-py3_11-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda-aarch64-13_0-test: # Testing + manywheel-py3_11-cuda-aarch64-13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-cuda-aarch64-13_0-build + - manywheel-py3_11-cuda-aarch64-13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -739,37 +591,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda-aarch64-13_0 + build_name: manywheel-py3_11-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda-aarch64-13_0-upload: # Uploading + manywheel-py3_11-cuda-aarch64-13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda-aarch64-13_0-build + needs: manywheel-py3_11-cuda-aarch64-13_2-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda-aarch64-13_0 + build_name: manywheel-py3_11-cuda-aarch64-13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -920,81 +772,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-cuda-aarch64-12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.r7g.12xlarge.memory - ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_12-cuda-aarch64-12_8 - build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - timeout-minutes: 420 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_12-cuda-aarch64-12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_12-cuda-aarch64-12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda-aarch64-12_8 - build_environment: linux-aarch64-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.2xlarge - ALPINE_IMAGE: "arm64v8/alpine" - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda-aarch64-12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-cuda-aarch64-12_8-build - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda-aarch64-12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-cuda-aarch64-12_9-build: + manywheel-py3_12-cuda-aarch64-13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1003,26 +781,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_12-cuda-aarch64-12_9 + build_name: manywheel-py3_12-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda-aarch64-12_9-test: # Testing + manywheel-py3_12-cuda-aarch64-13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-cuda-aarch64-12_9-build + - manywheel-py3_12-cuda-aarch64-13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1030,37 +808,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda-aarch64-12_9 + build_name: manywheel-py3_12-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda-aarch64-12_9-upload: # Uploading + manywheel-py3_12-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-cuda-aarch64-12_9-build + needs: manywheel-py3_12-cuda-aarch64-13_0-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda-aarch64-12_9 + build_name: manywheel-py3_12-cuda-aarch64-13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -1068,7 +846,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-cuda-aarch64-13_0-build: + manywheel-py3_12-cuda-aarch64-13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1077,26 +855,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_12-cuda-aarch64-13_0 + build_name: manywheel-py3_12-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda-aarch64-13_0-test: # Testing + manywheel-py3_12-cuda-aarch64-13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-cuda-aarch64-13_0-build + - manywheel-py3_12-cuda-aarch64-13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1104,37 +882,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda-aarch64-13_0 + build_name: manywheel-py3_12-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda-aarch64-13_0-upload: # Uploading + manywheel-py3_12-cuda-aarch64-13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-cuda-aarch64-13_0-build + needs: manywheel-py3_12-cuda-aarch64-13_2-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda-aarch64-13_0 + build_name: manywheel-py3_12-cuda-aarch64-13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -1285,81 +1063,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda-aarch64-12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.r7g.12xlarge.memory - ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_13-cuda-aarch64-12_8 - build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - timeout-minutes: 420 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_13-cuda-aarch64-12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13-cuda-aarch64-12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda-aarch64-12_8 - build_environment: linux-aarch64-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.2xlarge - ALPINE_IMAGE: "arm64v8/alpine" - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda-aarch64-12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13-cuda-aarch64-12_8-build - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda-aarch64-12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_13-cuda-aarch64-12_9-build: + manywheel-py3_13-cuda-aarch64-13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1368,26 +1072,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_13-cuda-aarch64-12_9 + build_name: manywheel-py3_13-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda-aarch64-12_9-test: # Testing + manywheel-py3_13-cuda-aarch64-13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13-cuda-aarch64-12_9-build + - manywheel-py3_13-cuda-aarch64-13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1395,37 +1099,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda-aarch64-12_9 + build_name: manywheel-py3_13-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda-aarch64-12_9-upload: # Uploading + manywheel-py3_13-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13-cuda-aarch64-12_9-build + needs: manywheel-py3_13-cuda-aarch64-13_0-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda-aarch64-12_9 + build_name: manywheel-py3_13-cuda-aarch64-13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -1433,7 +1137,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda-aarch64-13_0-build: + manywheel-py3_13-cuda-aarch64-13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1442,26 +1146,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_13-cuda-aarch64-13_0 + build_name: manywheel-py3_13-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda-aarch64-13_0-test: # Testing + manywheel-py3_13-cuda-aarch64-13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13-cuda-aarch64-13_0-build + - manywheel-py3_13-cuda-aarch64-13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1469,37 +1173,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda-aarch64-13_0 + build_name: manywheel-py3_13-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda-aarch64-13_0-upload: # Uploading + manywheel-py3_13-cuda-aarch64-13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13-cuda-aarch64-13_0-build + needs: manywheel-py3_13-cuda-aarch64-13_2-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda-aarch64-13_0 + build_name: manywheel-py3_13-cuda-aarch64-13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -1537,112 +1241,38 @@ jobs: uses: ./.github/workflows/_binary-test-linux.yml with: PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: manylinux2_28_aarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cpu-aarch64 - build_environment: linux-aarch64-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.2xlarge - ALPINE_IMAGE: "arm64v8/alpine" - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cpu-aarch64-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13t-cpu-aarch64-test - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu-aarch64 - DOCKER_IMAGE: manylinux2_28_aarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 - DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cpu-aarch64 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_13t-cuda-aarch64-12_6-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: "12.6-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 - DESIRED_PYTHON: "3.13t" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.r7g.12xlarge.memory - ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_13t-cuda-aarch64-12_6 - build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.6.3; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - timeout-minutes: 420 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_13t-cuda-aarch64-12_6-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13t-cuda-aarch64-12_6-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: "12.6-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + PACKAGE_TYPE: manywheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu-aarch64 + DOCKER_IMAGE: manylinux2_28_aarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda-aarch64-12_6 + build_name: manywheel-py3_13t-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda-aarch64-12_6-upload: # Uploading + manywheel-py3_13t-cpu-aarch64-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13t-cuda-aarch64-12_6-build + needs: manywheel-py3_13t-cpu-aarch64-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu126 - GPU_ARCH_VERSION: "12.6-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.6 + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu-aarch64 + DOCKER_IMAGE: manylinux2_28_aarch64-builder + DOCKER_IMAGE_TAG_PREFIX: cpu-aarch64 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda-aarch64-12_6 + build_name: manywheel-py3_13t-cpu-aarch64 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -1650,7 +1280,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13t-cuda-aarch64-12_8-build: + manywheel-py3_13t-cuda-aarch64-12_6-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1659,26 +1289,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_13t-cuda-aarch64-12_8 + build_name: manywheel-py3_13t-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.6.3; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.10.2.21; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda-aarch64-12_8-test: # Testing + manywheel-py3_13t-cuda-aarch64-12_6-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13t-cuda-aarch64-12_8-build + - manywheel-py3_13t-cuda-aarch64-12_6-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1686,37 +1316,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda-aarch64-12_8 + build_name: manywheel-py3_13t-cuda-aarch64-12_6 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda-aarch64-12_8-upload: # Uploading + manywheel-py3_13t-cuda-aarch64-12_6-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13t-cuda-aarch64-12_8-build + needs: manywheel-py3_13t-cuda-aarch64-12_6-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" + DESIRED_CUDA: cu126 + GPU_ARCH_VERSION: "12.6-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 + DOCKER_IMAGE_TAG_PREFIX: cuda12.6 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda-aarch64-12_8 + build_name: manywheel-py3_13t-cuda-aarch64-12_6 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -1724,7 +1354,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13t-cuda-aarch64-12_9-build: + manywheel-py3_13t-cuda-aarch64-13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1733,26 +1363,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_13t-cuda-aarch64-12_9 + build_name: manywheel-py3_13t-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda-aarch64-12_9-test: # Testing + manywheel-py3_13t-cuda-aarch64-13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13t-cuda-aarch64-12_9-build + - manywheel-py3_13t-cuda-aarch64-13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1760,37 +1390,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda-aarch64-12_9 + build_name: manywheel-py3_13t-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda-aarch64-12_9-upload: # Uploading + manywheel-py3_13t-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13t-cuda-aarch64-12_9-build + needs: manywheel-py3_13t-cuda-aarch64-13_0-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda-aarch64-12_9 + build_name: manywheel-py3_13t-cuda-aarch64-13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -1798,7 +1428,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13t-cuda-aarch64-13_0-build: + manywheel-py3_13t-cuda-aarch64-13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1807,26 +1437,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_13t-cuda-aarch64-13_0 + build_name: manywheel-py3_13t-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda-aarch64-13_0-test: # Testing + manywheel-py3_13t-cuda-aarch64-13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13t-cuda-aarch64-13_0-build + - manywheel-py3_13t-cuda-aarch64-13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1834,37 +1464,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda-aarch64-13_0 + build_name: manywheel-py3_13t-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda-aarch64-13_0-upload: # Uploading + manywheel-py3_13t-cuda-aarch64-13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13t-cuda-aarch64-13_0-build + needs: manywheel-py3_13t-cuda-aarch64-13_2-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda-aarch64-13_0 + build_name: manywheel-py3_13t-cuda-aarch64-13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -2015,81 +1645,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_14-cuda-aarch64-12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.14" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.r7g.12xlarge.memory - ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_14-cuda-aarch64-12_8 - build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - timeout-minutes: 420 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_14-cuda-aarch64-12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_14-cuda-aarch64-12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-cuda-aarch64-12_8 - build_environment: linux-aarch64-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.2xlarge - ALPINE_IMAGE: "arm64v8/alpine" - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14-cuda-aarch64-12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_14-cuda-aarch64-12_8-build - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-cuda-aarch64-12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_14-cuda-aarch64-12_9-build: + manywheel-py3_14-cuda-aarch64-13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2098,26 +1654,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_14-cuda-aarch64-12_9 + build_name: manywheel-py3_14-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14-cuda-aarch64-12_9-test: # Testing + manywheel-py3_14-cuda-aarch64-13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_14-cuda-aarch64-12_9-build + - manywheel-py3_14-cuda-aarch64-13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2125,37 +1681,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-cuda-aarch64-12_9 + build_name: manywheel-py3_14-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14-cuda-aarch64-12_9-upload: # Uploading + manywheel-py3_14-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_14-cuda-aarch64-12_9-build + needs: manywheel-py3_14-cuda-aarch64-13_0-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-cuda-aarch64-12_9 + build_name: manywheel-py3_14-cuda-aarch64-13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -2163,7 +1719,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_14-cuda-aarch64-13_0-build: + manywheel-py3_14-cuda-aarch64-13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2172,26 +1728,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_14-cuda-aarch64-13_0 + build_name: manywheel-py3_14-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14-cuda-aarch64-13_0-test: # Testing + manywheel-py3_14-cuda-aarch64-13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_14-cuda-aarch64-13_0-build + - manywheel-py3_14-cuda-aarch64-13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2199,37 +1755,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-cuda-aarch64-13_0 + build_name: manywheel-py3_14-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14-cuda-aarch64-13_0-upload: # Uploading + manywheel-py3_14-cuda-aarch64-13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_14-cuda-aarch64-13_0-build + needs: manywheel-py3_14-cuda-aarch64-13_2-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-cuda-aarch64-13_0 + build_name: manywheel-py3_14-cuda-aarch64-13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -2380,81 +1936,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_14t-cuda-aarch64-12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.14t" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.r7g.12xlarge.memory - ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_14t-cuda-aarch64-12_8 - build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - timeout-minutes: 420 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_14t-cuda-aarch64-12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_14t-cuda-aarch64-12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-cuda-aarch64-12_8 - build_environment: linux-aarch64-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.arm64.2xlarge - ALPINE_IMAGE: "arm64v8/alpine" - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14t-cuda-aarch64-12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_14t-cuda-aarch64-12_8-build - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8-aarch64" - GPU_ARCH_TYPE: cuda-aarch64 - DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-cuda-aarch64-12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_14t-cuda-aarch64-12_9-build: + manywheel-py3_14t-cuda-aarch64-13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2463,26 +1945,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_14t-cuda-aarch64-12_9 + build_name: manywheel-py3_14t-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14t-cuda-aarch64-12_9-test: # Testing + manywheel-py3_14t-cuda-aarch64-13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_14t-cuda-aarch64-12_9-build + - manywheel-py3_14t-cuda-aarch64-13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2490,37 +1972,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-cuda-aarch64-12_9 + build_name: manywheel-py3_14t-cuda-aarch64-13_0 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14t-cuda-aarch64-12_9-upload: # Uploading + manywheel-py3_14t-cuda-aarch64-13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_14t-cuda-aarch64-12_9-build + needs: manywheel-py3_14t-cuda-aarch64-13_0-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9-aarch64" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-cuda-aarch64-12_9 + build_name: manywheel-py3_14t-cuda-aarch64-13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -2528,7 +2010,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_14t-cuda-aarch64-13_0-build: + manywheel-py3_14t-cuda-aarch64-13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2537,26 +2019,26 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.r7g.12xlarge.memory ALPINE_IMAGE: "arm64v8/alpine" - build_name: manywheel-py3_14t-cuda-aarch64-13_0 + build_name: manywheel-py3_14t-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' timeout-minutes: 420 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14t-cuda-aarch64-13_0-test: # Testing + manywheel-py3_14t-cuda-aarch64-13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_14t-cuda-aarch64-13_0-build + - manywheel-py3_14t-cuda-aarch64-13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2564,37 +2046,37 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-cuda-aarch64-13_0 + build_name: manywheel-py3_14t-cuda-aarch64-13_2 build_environment: linux-aarch64-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.arm64.2xlarge ALPINE_IMAGE: "arm64v8/alpine" secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14t-cuda-aarch64-13_0-upload: # Uploading + manywheel-py3_14t-cuda-aarch64-13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_14t-cuda-aarch64-13_0-build + needs: manywheel-py3_14t-cuda-aarch64-13_2-build with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0-aarch64" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2-aarch64" GPU_ARCH_TYPE: cuda-aarch64 DOCKER_IMAGE: manylinuxaarch64-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-cuda-aarch64-13_0 + build_name: manywheel-py3_14t-cuda-aarch64-13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index 622d86cb009e9..c140c7f83fdf4 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -184,77 +184,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda12_8 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_10-cuda12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_8 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda12_8-test - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_10-cuda12_9-build: + manywheel-py3_10-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -263,23 +193,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda12_9 + build_name: manywheel-py3_10-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_9-test: # Testing + manywheel-py3_10-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda12_9-build + - manywheel-py3_10-cuda13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -287,36 +217,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_9 + build_name: manywheel-py3_10-cuda13_0 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_9-upload: # Uploading + manywheel-py3_10-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_10-cuda12_9-test + needs: manywheel-py3_10-cuda13_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_9 + build_name: manywheel-py3_10-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -324,7 +254,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_10-cuda13_0-build: + manywheel-py3_10-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -333,23 +263,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda13_0 + build_name: manywheel-py3_10-cuda13_2 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda13_0-test: # Testing + manywheel-py3_10-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda13_0-build + - manywheel-py3_10-cuda13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -357,36 +287,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda13_0 + build_name: manywheel-py3_10-cuda13_2 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda13_0-upload: # Uploading + manywheel-py3_10-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_10-cuda13_0-test + needs: manywheel-py3_10-cuda13_2-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda13_0 + build_name: manywheel-py3_10-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -410,7 +340,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_10-rocm7_1 build_environment: linux-binary-manywheel secrets: @@ -462,7 +392,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -526,7 +456,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.2 DESIRED_PYTHON: "3.10" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_10-rocm7_2 build_environment: linux-binary-manywheel secrets: @@ -578,7 +508,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -866,77 +796,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda12_8 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_11-cuda12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_11-cuda12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_8 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_11-cuda12_8-test - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_11-cuda12_9-build: + manywheel-py3_11-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -945,23 +805,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda12_9 + build_name: manywheel-py3_11-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_9-test: # Testing + manywheel-py3_11-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-cuda12_9-build + - manywheel-py3_11-cuda13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -969,36 +829,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_9 + build_name: manywheel-py3_11-cuda13_0 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_9-upload: # Uploading + manywheel-py3_11-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda12_9-test + needs: manywheel-py3_11-cuda13_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_9 + build_name: manywheel-py3_11-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -1006,7 +866,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_11-cuda13_0-build: + manywheel-py3_11-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1015,23 +875,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda13_0 + build_name: manywheel-py3_11-cuda13_2 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda13_0-test: # Testing + manywheel-py3_11-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_11-cuda13_0-build + - manywheel-py3_11-cuda13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1039,36 +899,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda13_0 + build_name: manywheel-py3_11-cuda13_2 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda13_0-upload: # Uploading + manywheel-py3_11-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_11-cuda13_0-test + needs: manywheel-py3_11-cuda13_2-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda13_0 + build_name: manywheel-py3_11-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -1092,7 +952,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_11-rocm7_1 build_environment: linux-binary-manywheel secrets: @@ -1144,7 +1004,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -1208,7 +1068,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.2 DESIRED_PYTHON: "3.11" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_11-rocm7_2 build_environment: linux-binary-manywheel secrets: @@ -1260,7 +1120,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -1548,77 +1408,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-cuda12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda12_8 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_12-cuda12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_12-cuda12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_8 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-cuda12_8-test - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-cuda12_9-build: + manywheel-py3_12-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1627,23 +1417,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda12_9 + build_name: manywheel-py3_12-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_9-test: # Testing + manywheel-py3_12-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-cuda12_9-build + - manywheel-py3_12-cuda13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1651,36 +1441,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_9 + build_name: manywheel-py3_12-cuda13_0 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_9-upload: # Uploading + manywheel-py3_12-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-cuda12_9-test + needs: manywheel-py3_12-cuda13_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_9 + build_name: manywheel-py3_12-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -1688,7 +1478,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_12-cuda13_0-build: + manywheel-py3_12-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -1697,23 +1487,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda13_0 + build_name: manywheel-py3_12-cuda13_2 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda13_0-test: # Testing + manywheel-py3_12-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_12-cuda13_0-build + - manywheel-py3_12-cuda13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -1721,36 +1511,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda13_0 + build_name: manywheel-py3_12-cuda13_2 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda13_0-upload: # Uploading + manywheel-py3_12-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_12-cuda13_0-test + needs: manywheel-py3_12-cuda13_2-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda13_0 + build_name: manywheel-py3_12-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -1774,7 +1564,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_12-rocm7_1 build_environment: linux-binary-manywheel secrets: @@ -1826,7 +1616,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -1890,7 +1680,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.2 DESIRED_PYTHON: "3.12" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_12-rocm7_2 build_environment: linux-binary-manywheel secrets: @@ -1942,7 +1732,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -2230,77 +2020,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cuda12_8 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_13-cuda12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13-cuda12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_8 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13-cuda12_8-test - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_13-cuda12_9-build: + manywheel-py3_13-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2309,23 +2029,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cuda12_9 + build_name: manywheel-py3_13-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_9-test: # Testing + manywheel-py3_13-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13-cuda12_9-build + - manywheel-py3_13-cuda13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2333,36 +2053,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_9 + build_name: manywheel-py3_13-cuda13_0 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_9-upload: # Uploading + manywheel-py3_13-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13-cuda12_9-test + needs: manywheel-py3_13-cuda13_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_9 + build_name: manywheel-py3_13-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -2370,7 +2090,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13-cuda13_0-build: + manywheel-py3_13-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2379,23 +2099,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cuda13_0 + build_name: manywheel-py3_13-cuda13_2 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda13_0-test: # Testing + manywheel-py3_13-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13-cuda13_0-build + - manywheel-py3_13-cuda13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -2403,36 +2123,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda13_0 + build_name: manywheel-py3_13-cuda13_2 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda13_0-upload: # Uploading + manywheel-py3_13-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13-cuda13_0-test + needs: manywheel-py3_13-cuda13_2-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda13_0 + build_name: manywheel-py3_13-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -2456,7 +2176,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_13-rocm7_1 build_environment: linux-binary-manywheel secrets: @@ -2508,7 +2228,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -2572,7 +2292,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.2 DESIRED_PYTHON: "3.13" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_13-rocm7_2 build_environment: linux-binary-manywheel secrets: @@ -2624,7 +2344,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -2912,77 +2632,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13t-cuda12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.13t" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13t-cuda12_8 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_13t-cuda12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13t-cuda12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda12_8 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13t-cuda12_8-test - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_13t-cuda12_9-build: + manywheel-py3_13t-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -2991,23 +2641,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13t-cuda12_9 + build_name: manywheel-py3_13t-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda12_9-test: # Testing + manywheel-py3_13t-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13t-cuda12_9-build + - manywheel-py3_13t-cuda13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -3015,36 +2665,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda12_9 + build_name: manywheel-py3_13t-cuda13_0 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda12_9-upload: # Uploading + manywheel-py3_13t-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13t-cuda12_9-test + needs: manywheel-py3_13t-cuda13_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda12_9 + build_name: manywheel-py3_13t-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -3052,7 +2702,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_13t-cuda13_0-build: + manywheel-py3_13t-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3061,23 +2711,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13t-cuda13_0 + build_name: manywheel-py3_13t-cuda13_2 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda13_0-test: # Testing + manywheel-py3_13t-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_13t-cuda13_0-build + - manywheel-py3_13t-cuda13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -3085,36 +2735,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda13_0 + build_name: manywheel-py3_13t-cuda13_2 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda13_0-upload: # Uploading + manywheel-py3_13t-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_13t-cuda13_0-test + needs: manywheel-py3_13t-cuda13_2-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda13_0 + build_name: manywheel-py3_13t-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -3138,7 +2788,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_13t-rocm7_1 build_environment: linux-binary-manywheel secrets: @@ -3190,7 +2840,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -3254,7 +2904,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.2 DESIRED_PYTHON: "3.13t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_13t-rocm7_2 build_environment: linux-binary-manywheel secrets: @@ -3306,7 +2956,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -3594,77 +3244,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_14-cuda12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.14" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_14-cuda12_8 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_14-cuda12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_14-cuda12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-cuda12_8 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14-cuda12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_14-cuda12_8-test - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-cuda12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_14-cuda12_9-build: + manywheel-py3_14-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3673,23 +3253,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_14-cuda12_9 + build_name: manywheel-py3_14-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14-cuda12_9-test: # Testing + manywheel-py3_14-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_14-cuda12_9-build + - manywheel-py3_14-cuda13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -3697,36 +3277,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-cuda12_9 + build_name: manywheel-py3_14-cuda13_0 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14-cuda12_9-upload: # Uploading + manywheel-py3_14-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_14-cuda12_9-test + needs: manywheel-py3_14-cuda13_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-cuda12_9 + build_name: manywheel-py3_14-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -3734,7 +3314,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_14-cuda13_0-build: + manywheel-py3_14-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -3743,23 +3323,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_14-cuda13_0 + build_name: manywheel-py3_14-cuda13_2 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14-cuda13_0-test: # Testing + manywheel-py3_14-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_14-cuda13_0-build + - manywheel-py3_14-cuda13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -3767,36 +3347,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-cuda13_0 + build_name: manywheel-py3_14-cuda13_2 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14-cuda13_0-upload: # Uploading + manywheel-py3_14-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_14-cuda13_0-test + needs: manywheel-py3_14-cuda13_2-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.14" - build_name: manywheel-py3_14-cuda13_0 + build_name: manywheel-py3_14-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -3820,7 +3400,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_14-rocm7_1 build_environment: linux-binary-manywheel secrets: @@ -3872,7 +3452,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -3936,7 +3516,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.2 DESIRED_PYTHON: "3.14" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_14-rocm7_2 build_environment: linux-binary-manywheel secrets: @@ -3988,7 +3568,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -4276,77 +3856,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_14t-cuda12_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.14t" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_14t-cuda12_8 - build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.8.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_14t-cuda12_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_14t-cuda12_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-cuda12_8 - build_environment: linux-binary-manywheel - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.g4dn.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14t-cuda12_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_14t-cuda12_8-test - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.8 - DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-cuda12_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_14t-cuda12_9-build: + manywheel-py3_14t-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -4355,23 +3865,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_14t-cuda12_9 + build_name: manywheel-py3_14t-cuda13_0 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==12.9.1; platform_system == 'Linux' | cuda-bindings>=12.9.4,<13; platform_system == 'Linux' | nvidia-cudnn-cu12==9.17.1.4; platform_system == 'Linux' | nvidia-cusparselt-cu12==0.7.1; platform_system == 'Linux' | nvidia-nccl-cu12==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu12==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14t-cuda12_9-test: # Testing + manywheel-py3_14t-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_14t-cuda12_9-build + - manywheel-py3_14t-cuda13_0-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -4379,36 +3889,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-cuda12_9 + build_name: manywheel-py3_14t-cuda13_0 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14t-cuda12_9-upload: # Uploading + manywheel-py3_14t-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_14t-cuda12_9-test + needs: manywheel-py3_14t-cuda13_0-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda12.9 + DOCKER_IMAGE_TAG_PREFIX: cuda13.0 DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-cuda12_9 + build_name: manywheel-py3_14t-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -4416,7 +3926,7 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - manywheel-py3_14t-cuda13_0-build: + manywheel-py3_14t-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} uses: ./.github/workflows/_binary-build-linux.yml needs: get-label-type @@ -4425,23 +3935,23 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_14t-cuda13_0 + build_name: manywheel-py3_14t-cuda13_2 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.19.0.56; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.0; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.3; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cublas,cufile,nvjitlink,nvtx]==13.2.1; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14t-cuda13_0-test: # Testing + manywheel-py3_14t-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_14t-cuda13_0-build + - manywheel-py3_14t-cuda13_2-build - get-label-type uses: ./.github/workflows/_binary-test-linux.yml with: @@ -4449,36 +3959,36 @@ jobs: PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-cuda13_0 + build_name: manywheel-py3_14t-cuda13_2 build_environment: linux-binary-manywheel runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runs_on: linux.g4dn.4xlarge.nvidia.gpu secrets: github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_14t-cuda13_0-upload: # Uploading + manywheel-py3_14t-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: manywheel-py3_14t-cuda13_0-test + needs: manywheel-py3_14t-cuda13_2-test with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: manywheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DOCKER_IMAGE: manylinux2_28-builder - DOCKER_IMAGE_TAG_PREFIX: cuda13.0 + DOCKER_IMAGE_TAG_PREFIX: cuda13.2 DESIRED_PYTHON: "3.14t" - build_name: manywheel-py3_14t-cuda13_0 + build_name: manywheel-py3_14t-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -4502,7 +4012,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.1 DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_14t-rocm7_1 build_environment: linux-binary-manywheel secrets: @@ -4554,7 +4064,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -4618,7 +4128,7 @@ jobs: DOCKER_IMAGE_TAG_PREFIX: rocm7.2 DESIRED_PYTHON: "3.14t" runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - timeout-minutes: 300 + timeout-minutes: 420 build_name: manywheel-py3_14t-rocm7_2 build_environment: linux-binary-manywheel secrets: @@ -4670,7 +4180,7 @@ jobs: - name: configure aws credentials id: aws_creds if: ${{ startsWith(github.event.ref, 'refs/tags/ciflow/') }} - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only aws-region: us-east-1 @@ -4943,76 +4453,15 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda12_8-shared-with-deps-release-extract: - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_8-build - - get-label-type - runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" - timeout-minutes: 60 - env: - DESIRED_CUDA: cu128 - LIBTORCH_VARIANT: shared-with-deps - steps: - - name: Checkout PyTorch - uses: actions/checkout@v4 - with: - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - sparse-checkout: .ci/libtorch/ - show-progress: false - - uses: actions/download-artifact@v4.1.7 - name: Download Wheel Artifact - with: - name: manywheel-py3_10-cuda12_8 - path: "${{ runner.temp }}/wheel_artifact/" - - name: Extract libtorch from wheel - run: | - set -eux - mkdir -p "${{ runner.temp }}/libtorch_output" - python3 .ci/libtorch/extract_libtorch_from_wheel.py \ - --wheel-dir "${{ runner.temp }}/wheel_artifact" \ - --output-dir "${{ runner.temp }}/libtorch_output" \ - --platform linux \ - --desired-cuda "$DESIRED_CUDA" \ - --libtorch-variant "$LIBTORCH_VARIANT" - - uses: actions/upload-artifact@v4.4.0 - if: always() - with: - name: libtorch-cuda12_8-shared-with-deps-release - retention-days: 14 - if-no-files-found: error - path: "${{ runner.temp }}/libtorch_output/" - libtorch-cuda12_8-shared-with-deps-release-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: libtorch-cuda12_8-shared-with-deps-release-extract - with: - PYTORCH_ROOT: /pytorch - PACKAGE_TYPE: libtorch - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" - GPU_ARCH_TYPE: cuda - LIBTORCH_CONFIG: release - LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-cuda12_8-shared-with-deps-release - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} - R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} - R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} - uses: ./.github/workflows/_binary-upload.yml - - libtorch-cuda12_9-shared-with-deps-release-extract: + libtorch-cuda13_0-shared-with-deps-release-extract: if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda12_9-build + - manywheel-py3_10-cuda13_0-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" timeout-minutes: 60 env: - DESIRED_CUDA: cu129 + DESIRED_CUDA: cu130 LIBTORCH_VARIANT: shared-with-deps steps: - name: Checkout PyTorch @@ -5024,7 +4473,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Wheel Artifact with: - name: manywheel-py3_10-cuda12_9 + name: manywheel-py3_10-cuda13_0 path: "${{ runner.temp }}/wheel_artifact/" - name: Extract libtorch from wheel run: | @@ -5039,25 +4488,25 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: libtorch-cuda12_9-shared-with-deps-release + name: libtorch-cuda13_0-shared-with-deps-release retention-days: 14 if-no-files-found: error path: "${{ runner.temp }}/libtorch_output/" - libtorch-cuda12_9-shared-with-deps-release-upload: # Uploading + libtorch-cuda13_0-shared-with-deps-release-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda12_9-shared-with-deps-release-extract + needs: libtorch-cuda13_0-shared-with-deps-release-extract with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: libtorch - DESIRED_CUDA: cu129 - GPU_ARCH_VERSION: "12.9" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-cuda12_9-shared-with-deps-release + build_name: libtorch-cuda13_0-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -5065,15 +4514,15 @@ jobs: R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda13_0-shared-with-deps-release-extract: + libtorch-cuda13_2-shared-with-deps-release-extract: if: ${{ github.repository_owner == 'pytorch' }} needs: - - manywheel-py3_10-cuda13_0-build + - manywheel-py3_10-cuda13_2-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" timeout-minutes: 60 env: - DESIRED_CUDA: cu130 + DESIRED_CUDA: cu132 LIBTORCH_VARIANT: shared-with-deps steps: - name: Checkout PyTorch @@ -5085,7 +4534,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Wheel Artifact with: - name: manywheel-py3_10-cuda13_0 + name: manywheel-py3_10-cuda13_2 path: "${{ runner.temp }}/wheel_artifact/" - name: Extract libtorch from wheel run: | @@ -5100,25 +4549,25 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: libtorch-cuda13_0-shared-with-deps-release + name: libtorch-cuda13_2-shared-with-deps-release retention-days: 14 if-no-files-found: error path: "${{ runner.temp }}/libtorch_output/" - libtorch-cuda13_0-shared-with-deps-release-upload: # Uploading + libtorch-cuda13_2-shared-with-deps-release-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda13_0-shared-with-deps-release-extract + needs: libtorch-cuda13_2-shared-with-deps-release-extract with: PYTORCH_ROOT: /pytorch PACKAGE_TYPE: libtorch - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-cuda13_0-shared-with-deps-release + build_name: libtorch-cuda13_2-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index e336e52c48434..c1d956c5bdc1f 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -33,7 +33,7 @@ concurrency: jobs: wheel-py3_10-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-14-xlarge + runs-on: macos-26-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }} @@ -44,6 +44,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -144,7 +145,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_11-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-14-xlarge + runs-on: macos-26-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }} @@ -155,6 +156,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -255,7 +257,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_12-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-14-xlarge + runs-on: macos-26-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }} @@ -266,6 +268,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -366,7 +369,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_13-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-14-xlarge + runs-on: macos-26-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }} @@ -377,6 +380,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -477,7 +481,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_13t-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-14-xlarge + runs-on: macos-26-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }} @@ -488,6 +492,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -588,7 +593,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_14-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-14-xlarge + runs-on: macos-26-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }} @@ -599,6 +604,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -699,7 +705,7 @@ jobs: uses: ./.github/workflows/_binary-upload.yml wheel-py3_14t-cpu-build: if: ${{ github.repository_owner == 'pytorch' }} - runs-on: macos-14-xlarge + runs-on: macos-26-xlarge timeout-minutes: 240 env: PYTORCH_ROOT: ${{ github.workspace }} @@ -710,6 +716,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14t" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -811,7 +818,7 @@ jobs: libtorch-cpu-shared-with-deps-release-extract: if: ${{ github.repository_owner == 'pytorch' }} needs: wheel-py3_10-cpu-build - runs-on: macos-14-xlarge + runs-on: macos-26-xlarge timeout-minutes: 60 env: DESIRED_CUDA: cpu diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml index cb2de611421fe..1bb00a665713b 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml @@ -528,7 +528,7 @@ jobs: R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda12_8-shared-with-deps-debug-build: + libtorch-cuda13_0-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -538,8 +538,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug @@ -622,7 +622,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: libtorch-cuda12_8-shared-with-deps-debug + name: libtorch-cuda13_0-shared-with-deps-debug retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -638,10 +638,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda12_8-shared-with-deps-debug-test: # Testing + libtorch-cuda13_0-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-cuda12_8-shared-with-deps-debug-build + - libtorch-cuda13_0-shared-with-deps-debug-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -650,8 +650,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug @@ -726,7 +726,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: libtorch-cuda12_8-shared-with-deps-debug + name: libtorch-cuda13_0-shared-with-deps-debug path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -747,33 +747,33 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda12_8-shared-with-deps-debug-upload: # Uploading + libtorch-cuda13_0-shared-with-deps-debug-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda12_8-shared-with-deps-debug-test + needs: libtorch-cuda13_0-shared-with-deps-debug-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason DESIRED_PYTHON: "3.10" - build_name: libtorch-cuda12_8-shared-with-deps-debug + build_name: libtorch-cuda13_0-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda13_0-shared-with-deps-debug-build: + libtorch-cuda13_2-shared-with-deps-debug-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -783,8 +783,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug @@ -867,7 +867,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: libtorch-cuda13_0-shared-with-deps-debug + name: libtorch-cuda13_2-shared-with-deps-debug retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -883,10 +883,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda13_0-shared-with-deps-debug-test: # Testing + libtorch-cuda13_2-shared-with-deps-debug-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - libtorch-cuda13_0-shared-with-deps-debug-build + - libtorch-cuda13_2-shared-with-deps-debug-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -895,8 +895,8 @@ jobs: PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 LIBTORCH_CONFIG: debug @@ -971,7 +971,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: libtorch-cuda13_0-shared-with-deps-debug + name: libtorch-cuda13_2-shared-with-deps-debug path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -992,26 +992,26 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - libtorch-cuda13_0-shared-with-deps-debug-upload: # Uploading + libtorch-cuda13_2-shared-with-deps-debug-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda13_0-shared-with-deps-debug-test + needs: libtorch-cuda13_2-shared-with-deps-debug-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: libtorch # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: debug LIBTORCH_VARIANT: shared-with-deps # This is a dummy value for libtorch to work correctly with our batch scripts # without this value pip does not get installed for some reason DESIRED_PYTHON: "3.10" - build_name: libtorch-cuda13_0-shared-with-deps-debug + build_name: libtorch-cuda13_2-shared-with-deps-debug secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index f92df48b5d86f..7da288111208d 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -56,6 +56,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -505,7 +506,7 @@ jobs: R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_10-cuda12_8-build: + wheel-py3_10-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -515,8 +516,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -595,7 +596,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_10-cuda12_8 + name: wheel-py3_10-cuda13_0 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -611,10 +612,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda12_8-test: # Testing + wheel-py3_10-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_10-cuda12_8-build + - wheel-py3_10-cuda13_0-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -623,8 +624,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -695,7 +696,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_10-cuda12_8 + name: wheel-py3_10-cuda13_0 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -716,29 +717,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda12_8-upload: # Uploading + wheel-py3_10-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_10-cuda12_8-test + needs: wheel-py3_10-cuda13_0-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.10" - build_name: wheel-py3_10-cuda12_8 + build_name: wheel-py3_10-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_10-cuda13_0-build: + wheel-py3_10-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -748,8 +749,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -828,7 +829,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_10-cuda13_0 + name: wheel-py3_10-cuda13_2 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -844,10 +845,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda13_0-test: # Testing + wheel-py3_10-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_10-cuda13_0-build + - wheel-py3_10-cuda13_2-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -856,8 +857,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" @@ -928,7 +929,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_10-cuda13_0 + name: wheel-py3_10-cuda13_2 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -949,22 +950,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_10-cuda13_0-upload: # Uploading + wheel-py3_10-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_10-cuda13_0-test + needs: wheel-py3_10-cuda13_2-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.10" - build_name: wheel-py3_10-cuda13_0 + build_name: wheel-py3_10-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -1216,6 +1217,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -1665,7 +1667,7 @@ jobs: R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cuda12_8-build: + wheel-py3_11-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -1675,8 +1677,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -1755,7 +1757,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_11-cuda12_8 + name: wheel-py3_11-cuda13_0 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -1771,10 +1773,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_8-test: # Testing + wheel-py3_11-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_11-cuda12_8-build + - wheel-py3_11-cuda13_0-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -1783,8 +1785,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -1855,7 +1857,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_11-cuda12_8 + name: wheel-py3_11-cuda13_0 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -1876,29 +1878,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda12_8-upload: # Uploading + wheel-py3_11-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cuda12_8-test + needs: wheel-py3_11-cuda13_0-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cuda12_8 + build_name: wheel-py3_11-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_11-cuda13_0-build: + wheel-py3_11-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -1908,8 +1910,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -1988,7 +1990,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_11-cuda13_0 + name: wheel-py3_11-cuda13_2 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2004,10 +2006,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda13_0-test: # Testing + wheel-py3_11-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_11-cuda13_0-build + - wheel-py3_11-cuda13_2-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -2016,8 +2018,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" @@ -2088,7 +2090,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_11-cuda13_0 + name: wheel-py3_11-cuda13_2 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -2109,22 +2111,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_11-cuda13_0-upload: # Uploading + wheel-py3_11-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_11-cuda13_0-test + needs: wheel-py3_11-cuda13_2-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.11" - build_name: wheel-py3_11-cuda13_0 + build_name: wheel-py3_11-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -2376,6 +2378,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -2825,7 +2828,7 @@ jobs: R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_12-cuda12_8-build: + wheel-py3_12-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -2835,8 +2838,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -2915,7 +2918,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_12-cuda12_8 + name: wheel-py3_12-cuda13_0 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -2931,10 +2934,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda12_8-test: # Testing + wheel-py3_12-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_12-cuda12_8-build + - wheel-py3_12-cuda13_0-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -2943,8 +2946,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -3015,7 +3018,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_12-cuda12_8 + name: wheel-py3_12-cuda13_0 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -3036,29 +3039,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda12_8-upload: # Uploading + wheel-py3_12-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_12-cuda12_8-test + needs: wheel-py3_12-cuda13_0-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.12" - build_name: wheel-py3_12-cuda12_8 + build_name: wheel-py3_12-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_12-cuda13_0-build: + wheel-py3_12-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -3068,8 +3071,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -3148,7 +3151,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_12-cuda13_0 + name: wheel-py3_12-cuda13_2 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -3164,10 +3167,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda13_0-test: # Testing + wheel-py3_12-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_12-cuda13_0-build + - wheel-py3_12-cuda13_2-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -3176,8 +3179,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" @@ -3248,7 +3251,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_12-cuda13_0 + name: wheel-py3_12-cuda13_2 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -3269,22 +3272,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_12-cuda13_0-upload: # Uploading + wheel-py3_12-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_12-cuda13_0-test + needs: wheel-py3_12-cuda13_2-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.12" - build_name: wheel-py3_12-cuda13_0 + build_name: wheel-py3_12-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -3536,6 +3539,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -3985,7 +3989,7 @@ jobs: R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_13-cuda12_8-build: + wheel-py3_13-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -3995,8 +3999,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -4075,7 +4079,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_13-cuda12_8 + name: wheel-py3_13-cuda13_0 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -4091,10 +4095,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13-cuda12_8-test: # Testing + wheel-py3_13-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_13-cuda12_8-build + - wheel-py3_13-cuda13_0-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -4103,8 +4107,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -4175,7 +4179,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_13-cuda12_8 + name: wheel-py3_13-cuda13_0 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -4196,29 +4200,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13-cuda12_8-upload: # Uploading + wheel-py3_13-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_13-cuda12_8-test + needs: wheel-py3_13-cuda13_0-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13" - build_name: wheel-py3_13-cuda12_8 + build_name: wheel-py3_13-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_13-cuda13_0-build: + wheel-py3_13-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -4228,8 +4232,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -4308,7 +4312,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_13-cuda13_0 + name: wheel-py3_13-cuda13_2 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -4324,10 +4328,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13-cuda13_0-test: # Testing + wheel-py3_13-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_13-cuda13_0-build + - wheel-py3_13-cuda13_2-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -4336,8 +4340,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13" @@ -4408,7 +4412,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_13-cuda13_0 + name: wheel-py3_13-cuda13_2 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -4429,22 +4433,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13-cuda13_0-upload: # Uploading + wheel-py3_13-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_13-cuda13_0-test + needs: wheel-py3_13-cuda13_2-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13" - build_name: wheel-py3_13-cuda13_0 + build_name: wheel-py3_13-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -4696,6 +4700,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -5145,7 +5150,7 @@ jobs: R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_13t-cuda12_8-build: + wheel-py3_13t-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -5155,8 +5160,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -5235,7 +5240,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_13t-cuda12_8 + name: wheel-py3_13t-cuda13_0 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -5251,10 +5256,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13t-cuda12_8-test: # Testing + wheel-py3_13t-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_13t-cuda12_8-build + - wheel-py3_13t-cuda13_0-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -5263,8 +5268,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -5335,7 +5340,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_13t-cuda12_8 + name: wheel-py3_13t-cuda13_0 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -5356,29 +5361,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13t-cuda12_8-upload: # Uploading + wheel-py3_13t-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_13t-cuda12_8-test + needs: wheel-py3_13t-cuda13_0-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13t" - build_name: wheel-py3_13t-cuda12_8 + build_name: wheel-py3_13t-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_13t-cuda13_0-build: + wheel-py3_13t-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -5388,8 +5393,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -5468,7 +5473,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_13t-cuda13_0 + name: wheel-py3_13t-cuda13_2 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -5484,10 +5489,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13t-cuda13_0-test: # Testing + wheel-py3_13t-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_13t-cuda13_0-build + - wheel-py3_13t-cuda13_2-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -5496,8 +5501,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.13t" @@ -5568,7 +5573,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_13t-cuda13_0 + name: wheel-py3_13t-cuda13_2 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -5589,22 +5594,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_13t-cuda13_0-upload: # Uploading + wheel-py3_13t-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_13t-cuda13_0-test + needs: wheel-py3_13t-cuda13_2-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.13t" - build_name: wheel-py3_13t-cuda13_0 + build_name: wheel-py3_13t-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -5856,6 +5861,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -6305,7 +6311,7 @@ jobs: R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_14-cuda12_8-build: + wheel-py3_14-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -6315,8 +6321,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14" @@ -6395,7 +6401,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_14-cuda12_8 + name: wheel-py3_14-cuda13_0 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -6411,10 +6417,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_14-cuda12_8-test: # Testing + wheel-py3_14-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_14-cuda12_8-build + - wheel-py3_14-cuda13_0-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -6423,8 +6429,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14" @@ -6495,7 +6501,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_14-cuda12_8 + name: wheel-py3_14-cuda13_0 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -6516,29 +6522,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_14-cuda12_8-upload: # Uploading + wheel-py3_14-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_14-cuda12_8-test + needs: wheel-py3_14-cuda13_0-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.14" - build_name: wheel-py3_14-cuda12_8 + build_name: wheel-py3_14-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_14-cuda13_0-build: + wheel-py3_14-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -6548,8 +6554,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14" @@ -6628,7 +6634,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_14-cuda13_0 + name: wheel-py3_14-cuda13_2 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -6644,10 +6650,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_14-cuda13_0-test: # Testing + wheel-py3_14-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_14-cuda13_0-build + - wheel-py3_14-cuda13_2-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -6656,8 +6662,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14" @@ -6728,7 +6734,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_14-cuda13_0 + name: wheel-py3_14-cuda13_2 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -6749,22 +6755,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_14-cuda13_0-upload: # Uploading + wheel-py3_14-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_14-cuda13_0-test + needs: wheel-py3_14-cuda13_2-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.14" - build_name: wheel-py3_14-cuda13_0 + build_name: wheel-py3_14-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -7016,6 +7022,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14t" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: cuda-toolkit[nvrtc,cudart,cupti,cufft,curand,cusolver,cusparse,cufile,nvjitlink,nvtx]==13.0.2; platform_system == 'Linux' | nvidia-cublas>=13.1.0.3,<=13.1.1.3; platform_system == 'Linux' | cuda-bindings>=13.0.3,<14; platform_system == 'Linux' | nvidia-cudnn-cu13==9.20.0.48; platform_system == 'Linux' | nvidia-cusparselt-cu13==0.8.1; platform_system == 'Linux' | nvidia-nccl-cu13==2.29.7; platform_system == 'Linux' | nvidia-nvshmem-cu13==3.4.5; platform_system == 'Linux' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -7465,7 +7472,7 @@ jobs: R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_14t-cuda12_8-build: + wheel-py3_14t-cuda13_0-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -7475,8 +7482,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14t" @@ -7555,7 +7562,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_14t-cuda12_8 + name: wheel-py3_14t-cuda13_0 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -7571,10 +7578,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_14t-cuda12_8-test: # Testing + wheel-py3_14t-cuda13_0-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_14t-cuda12_8-build + - wheel-py3_14t-cuda13_0-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -7583,8 +7590,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14t" @@ -7655,7 +7662,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_14t-cuda12_8 + name: wheel-py3_14t-cuda13_0 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -7676,29 +7683,29 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_14t-cuda12_8-upload: # Uploading + wheel-py3_14t-cuda13_0-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_14t-cuda12_8-test + needs: wheel-py3_14t-cuda13_0-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.14t" - build_name: wheel-py3_14t-cuda12_8 + build_name: wheel-py3_14t-cuda13_0 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - wheel-py3_14t-cuda13_0-build: + wheel-py3_14t-cuda13_2-build: if: ${{ github.repository_owner == 'pytorch' }} needs: get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.12xlarge" @@ -7708,8 +7715,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14t" @@ -7788,7 +7795,7 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: wheel-py3_14t-cuda13_0 + name: wheel-py3_14t-cuda13_2 retention-days: 14 if-no-files-found: error path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" @@ -7804,10 +7811,10 @@ jobs: run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_14t-cuda13_0-test: # Testing + wheel-py3_14t-cuda13_2-test: # Testing if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_14t-cuda13_0-build + - wheel-py3_14t-cuda13_2-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}windows.g4dn.xlarge" timeout-minutes: 360 @@ -7816,8 +7823,8 @@ jobs: PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.14t" @@ -7888,7 +7895,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Build Artifacts with: - name: wheel-py3_14t-cuda13_0 + name: wheel-py3_14t-cuda13_2 path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" - name: Populate binary env shell: bash @@ -7909,22 +7916,22 @@ jobs: if: always() run: | .github\scripts\kill_active_ssh_sessions.ps1 - wheel-py3_14t-cuda13_0-upload: # Uploading + wheel-py3_14t-cuda13_2-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: wheel-py3_14t-cuda13_0-test + needs: wheel-py3_14t-cuda13_2-test with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: wheel # TODO: This is a legacy variable that we eventually want to get rid of in # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda DESIRED_PYTHON: "3.14t" - build_name: wheel-py3_14t-cuda13_0 + build_name: wheel-py3_14t-cuda13_2 secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} @@ -8283,15 +8290,15 @@ jobs: R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda12_8-shared-with-deps-release-extract: + libtorch-cuda13_0-shared-with-deps-release-extract: if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_10-cuda12_8-build + - wheel-py3_10-cuda13_0-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" timeout-minutes: 60 env: - DESIRED_CUDA: cu128 + DESIRED_CUDA: cu130 LIBTORCH_VARIANT: shared-with-deps steps: - name: Checkout PyTorch @@ -8303,7 +8310,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Wheel Artifact with: - name: wheel-py3_10-cuda12_8 + name: wheel-py3_10-cuda13_0 path: "${{ runner.temp }}/wheel_artifact/" - name: Extract libtorch from wheel shell: bash @@ -8319,40 +8326,40 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: libtorch-cuda12_8-shared-with-deps-release + name: libtorch-cuda13_0-shared-with-deps-release retention-days: 14 if-no-files-found: error path: "${{ runner.temp }}/libtorch_output/" - libtorch-cuda12_8-shared-with-deps-release-upload: # Uploading + libtorch-cuda13_0-shared-with-deps-release-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda12_8-shared-with-deps-release-extract + needs: libtorch-cuda13_0-shared-with-deps-release-extract with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: libtorch - DESIRED_CUDA: cu128 - GPU_ARCH_VERSION: "12.8" + DESIRED_CUDA: cu130 + GPU_ARCH_VERSION: "13.0" GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-cuda12_8-shared-with-deps-release + build_name: libtorch-cuda13_0-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }} R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }} uses: ./.github/workflows/_binary-upload.yml - libtorch-cuda13_0-shared-with-deps-release-extract: + libtorch-cuda13_2-shared-with-deps-release-extract: if: ${{ github.repository_owner == 'pytorch' }} needs: - - wheel-py3_10-cuda13_0-build + - wheel-py3_10-cuda13_2-build - get-label-type runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" timeout-minutes: 60 env: - DESIRED_CUDA: cu130 + DESIRED_CUDA: cu132 LIBTORCH_VARIANT: shared-with-deps steps: - name: Checkout PyTorch @@ -8364,7 +8371,7 @@ jobs: - uses: actions/download-artifact@v4.1.7 name: Download Wheel Artifact with: - name: wheel-py3_10-cuda13_0 + name: wheel-py3_10-cuda13_2 path: "${{ runner.temp }}/wheel_artifact/" - name: Extract libtorch from wheel shell: bash @@ -8380,25 +8387,25 @@ jobs: - uses: actions/upload-artifact@v4.4.0 if: always() with: - name: libtorch-cuda13_0-shared-with-deps-release + name: libtorch-cuda13_2-shared-with-deps-release retention-days: 14 if-no-files-found: error path: "${{ runner.temp }}/libtorch_output/" - libtorch-cuda13_0-shared-with-deps-release-upload: # Uploading + libtorch-cuda13_2-shared-with-deps-release-upload: # Uploading if: ${{ github.repository_owner == 'pytorch' }} permissions: id-token: write contents: read - needs: libtorch-cuda13_0-shared-with-deps-release-extract + needs: libtorch-cuda13_2-shared-with-deps-release-extract with: PYTORCH_ROOT: ${{ github.workspace }} PACKAGE_TYPE: libtorch - DESIRED_CUDA: cu130 - GPU_ARCH_VERSION: "13.0" + DESIRED_CUDA: cu132 + GPU_ARCH_VERSION: "13.2" GPU_ARCH_TYPE: cuda LIBTORCH_CONFIG: release LIBTORCH_VARIANT: shared-with-deps - build_name: libtorch-cuda13_0-shared-with-deps-release + build_name: libtorch-cuda13_2-shared-with-deps-release secrets: github-token: ${{ secrets.GITHUB_TOKEN }} R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }} diff --git a/.github/workflows/h100-cutlass-backend.yml b/.github/workflows/h100-cutlass-backend.yml index e5406f7600133..3e02f3c80f57d 100644 --- a/.github/workflows/h100-cutlass-backend.yml +++ b/.github/workflows/h100-cutlass-backend.yml @@ -21,6 +21,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: diff --git a/.github/workflows/h100-distributed.yml b/.github/workflows/h100-distributed.yml index 0e5370a51c160..0a69cce8f868c 100644 --- a/.github/workflows/h100-distributed.yml +++ b/.github/workflows/h100-distributed.yml @@ -18,6 +18,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: diff --git a/.github/workflows/h100-symm-mem.yml b/.github/workflows/h100-symm-mem.yml index 09c362a546024..7ffe701656cb0 100644 --- a/.github/workflows/h100-symm-mem.yml +++ b/.github/workflows/h100-symm-mem.yml @@ -18,6 +18,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: diff --git a/.github/workflows/inductor-micro-benchmark-x86.yml b/.github/workflows/inductor-micro-benchmark-x86.yml index b0e9c869cdcf0..744ef7d7f75ce 100644 --- a/.github/workflows/inductor-micro-benchmark-x86.yml +++ b/.github/workflows/inductor-micro-benchmark-x86.yml @@ -16,6 +16,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: inductor-build: diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index 2473b5a515f77..8e2e6e04acbc7 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -16,6 +16,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-default-label-prefix: diff --git a/.github/workflows/inductor-nightly.yml b/.github/workflows/inductor-nightly.yml index 4258e8fdb0c84..24d0ebf9ef3f0 100644 --- a/.github/workflows/inductor-nightly.yml +++ b/.github/workflows/inductor-nightly.yml @@ -19,6 +19,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-default-label-prefix: diff --git a/.github/workflows/inductor-pallas.yml b/.github/workflows/inductor-pallas.yml index 8676434d0e580..72467be2fd9c6 100644 --- a/.github/workflows/inductor-pallas.yml +++ b/.github/workflows/inductor-pallas.yml @@ -16,6 +16,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-label-type: @@ -46,9 +47,6 @@ jobs: secrets: inherit linux-jammy-py3_12-inductor-pallas-gpu-test: - permissions: - id-token: write - contents: read name: pallas-gpu-py3.12-inductor uses: ./.github/workflows/_linux-test.yml needs: linux-jammy-py3_12-inductor-pallas-gpu-build @@ -73,9 +71,6 @@ jobs: secrets: inherit linux-jammy-py3_12-inductor-pallas-tpu-test: - permissions: - id-token: write - contents: read name: pallas-tpu-py3.12-inductor uses: ./.github/workflows/_linux-test.yml needs: linux-jammy-py3_12-inductor-pallas-tpu-build diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index 1ab538f96c1f0..f4b743058ce8a 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -13,6 +13,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-default-label-prefix: @@ -59,4 +60,5 @@ jobs: disable-monitor: false monitor-log-interval: 15 monitor-data-collect-interval: 4 + export-profiler-trace: "1" secrets: inherit diff --git a/.github/workflows/inductor-perf-test-b200.yml b/.github/workflows/inductor-perf-test-b200.yml index 96608d6d1d1d0..6eacb939c941a 100644 --- a/.github/workflows/inductor-perf-test-b200.yml +++ b/.github/workflows/inductor-perf-test-b200.yml @@ -64,6 +64,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-label-type: diff --git a/.github/workflows/inductor-perf-test-nightly-aarch64.yml b/.github/workflows/inductor-perf-test-nightly-aarch64.yml index f7b3517dccc06..64cdda3d57223 100644 --- a/.github/workflows/inductor-perf-test-nightly-aarch64.yml +++ b/.github/workflows/inductor-perf-test-nightly-aarch64.yml @@ -7,6 +7,14 @@ on: # NB: GitHub has an upper limit of 10 inputs here workflow_dispatch: inputs: + runner_config: + description: "AArch64 runner instance type" + required: true + type: choice + default: m8g + options: + - m8g + - m7g training: # CPU for training is not typical, but leave the option open here description: Run training (off by default)? @@ -51,6 +59,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-label-type: @@ -63,6 +72,7 @@ jobs: curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} opt_out_experiments: lf + runner_config: ${{ github.event.inputs.runner_config || 'm8g' }} linux-jammy-aarch64-py3_10-inductor-build: name: linux-jammy-aarch64-py3.10-inductor @@ -70,50 +80,50 @@ jobs: needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runner: linux.arm64.m7g.4xlarge + runner: linux.arm64.${{ needs.get-label-type.outputs.runner-config }}.4xlarge build-environment: linux-jammy-aarch64-py3.10 docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks test-matrix: | { include: [ - { config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_huggingface_perf_cpu_aarch64", shard: 2, num_shards: 9, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_huggingface_perf_cpu_aarch64", shard: 3, num_shards: 9, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_huggingface_perf_cpu_aarch64", shard: 4, num_shards: 9, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_huggingface_perf_cpu_aarch64", shard: 5, num_shards: 9, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_huggingface_perf_cpu_aarch64", shard: 6, num_shards: 9, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_huggingface_perf_cpu_aarch64", shard: 7, num_shards: 9, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_huggingface_perf_cpu_aarch64", shard: 8, num_shards: 9, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_huggingface_perf_cpu_aarch64", shard: 9, num_shards: 9, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 1, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 2, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 3, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 4, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 5, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 6, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 7, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 8, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 9, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 10, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 11, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 12, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 13, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 14, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_timm_perf_cpu_aarch64", shard: 15, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 1, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 2, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 3, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 4, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 5, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 6, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 7, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 8, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 9, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 10, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 11, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 12, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 13, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 14, num_shards: 15, runner: "linux.arm64.m7g.metal" }, - { config: "inductor_torchbench_perf_cpu_aarch64", shard: 15, num_shards: 15, runner: "linux.arm64.m7g.metal" }, + { config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_huggingface_perf_cpu_aarch64", shard: 2, num_shards: 9, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_huggingface_perf_cpu_aarch64", shard: 3, num_shards: 9, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_huggingface_perf_cpu_aarch64", shard: 4, num_shards: 9, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_huggingface_perf_cpu_aarch64", shard: 5, num_shards: 9, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_huggingface_perf_cpu_aarch64", shard: 6, num_shards: 9, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_huggingface_perf_cpu_aarch64", shard: 7, num_shards: 9, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_huggingface_perf_cpu_aarch64", shard: 8, num_shards: 9, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_huggingface_perf_cpu_aarch64", shard: 9, num_shards: 9, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 1, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 2, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 3, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 4, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 5, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 6, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 7, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 8, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 9, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 10, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 11, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 12, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 13, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 14, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_timm_perf_cpu_aarch64", shard: 15, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 1, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 2, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 3, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 4, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 5, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 6, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 7, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 8, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 9, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 10, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 11, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 12, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 13, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 14, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, + { config: "inductor_torchbench_perf_cpu_aarch64", shard: 15, num_shards: 15, runner: "${{ needs.get-label-type.outputs.runner-label }}" }, ]} selected-test-configs: ${{ inputs.benchmark_configs }} build-additional-packages: "vision audio torchao" diff --git a/.github/workflows/inductor-perf-test-nightly-h100.yml b/.github/workflows/inductor-perf-test-nightly-h100.yml index 4d161a00825b0..2c3c094090ce0 100644 --- a/.github/workflows/inductor-perf-test-nightly-h100.yml +++ b/.github/workflows/inductor-perf-test-nightly-h100.yml @@ -51,6 +51,16 @@ on: required: false type: boolean default: false + deterministic_perf: + description: Run benchmarks with deterministic mode enabled? + required: false + type: boolean + default: false + batch_invariant_accuracy: + description: Run accuracy benchmarks with batch-invariant mode enabled? + required: false + type: boolean + default: false benchmark_configs: description: The list of configs used the benchmark required: false @@ -69,6 +79,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-label-type: @@ -131,7 +142,7 @@ jobs: if: github.event.schedule == '15 0 * * 1-6' with: build-environment: ${{ needs.build.outputs.build-environment }} - dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true + dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-deterministic_perf-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} timeout-minutes: 720 @@ -139,6 +150,8 @@ jobs: disable-monitor: false monitor-log-interval: 15 monitor-data-collect-interval: 4 + export-profiler-trace: "1" + enable-torch-trace: "1" secrets: inherit test-weekly: @@ -148,7 +161,7 @@ jobs: if: github.event.schedule == '0 7 * * 0' with: build-environment: ${{ needs.build.outputs.build-environment }} - dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true + dashboard-tag: training-true-inference-true-default-true-dynamic-true-cudagraphs-true-cppwrapper-true-aotinductor-true-freezing_cudagraphs-true-maxautotune-true-freeze_autotune_cudagraphs-true-deterministic_perf-true docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} timeout-minutes: 1440 @@ -156,6 +169,8 @@ jobs: disable-monitor: false monitor-log-interval: 15 monitor-data-collect-interval: 4 + export-profiler-trace: "1" + enable-torch-trace: "1" secrets: inherit test: @@ -167,7 +182,7 @@ jobs: if: ${{ github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' }} with: build-environment: ${{ needs.build.outputs.build-environment }} - dashboard-tag: training-${{ inputs.training || 'true' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cudagraphs-${{ inputs.cudagraphs || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'false' }}-aotinductor-${{ inputs.aotinductor || 'false' }}-maxautotune-${{ inputs.maxautotune || 'false' }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs || 'false' }} + dashboard-tag: training-${{ inputs.training || 'true' }}-inference-${{ inputs.inference || 'true' }}-default-${{ inputs.default || 'true' }}-dynamic-${{ inputs.dynamic || 'true' }}-cudagraphs-${{ inputs.cudagraphs || 'true' }}-cppwrapper-${{ inputs.cppwrapper || 'false' }}-aotinductor-${{ inputs.aotinductor || 'false' }}-maxautotune-${{ inputs.maxautotune || 'false' }}-freezing_cudagraphs-${{ inputs.freezing_cudagraphs || 'false' }}-deterministic_perf-${{ inputs.deterministic_perf || 'false' }}-batch_invariant_accuracy-${{ inputs.batch_invariant_accuracy || 'false' }} docker-image: ${{ needs.build.outputs.docker-image }} test-matrix: ${{ needs.build.outputs.test-matrix }} timeout-minutes: 720 @@ -175,4 +190,6 @@ jobs: disable-monitor: false monitor-log-interval: 15 monitor-data-collect-interval: 4 + export-profiler-trace: "1" + enable-torch-trace: "1" secrets: inherit diff --git a/.github/workflows/inductor-perf-test-nightly-rocm-mi300.yml b/.github/workflows/inductor-perf-test-nightly-rocm-mi300.yml index c556c6b455783..9ecd66709e54a 100644 --- a/.github/workflows/inductor-perf-test-nightly-rocm-mi300.yml +++ b/.github/workflows/inductor-perf-test-nightly-rocm-mi300.yml @@ -63,7 +63,10 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read + actions: read jobs: get-label-type: @@ -79,10 +82,10 @@ jobs: linux-jammy-rocm-py3_10-inductor-benchmark-build: if: github.repository_owner == 'pytorch' - name: rocm-py3_10-inductor-benchmark-build + name: linux-jammy-rocm-py3.10-mi300 uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-jammy-rocm-py3_10 + build-environment: linux-jammy-rocm-py3.10-mi300 docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks test-matrix: | { include: [ @@ -111,10 +114,7 @@ jobs: secrets: inherit linux-jammy-rocm-py3_10-inductor-benchmark-test: - permissions: - id-token: write - contents: read - name: rocm-py3_10-inductor-benchmark-test + name: linux-jammy-rocm-py3.10-mi300 uses: ./.github/workflows/_rocm-test.yml needs: linux-jammy-rocm-py3_10-inductor-benchmark-build with: diff --git a/.github/workflows/inductor-perf-test-nightly-rocm-mi355.yml b/.github/workflows/inductor-perf-test-nightly-rocm-mi355.yml index e6fd83193202c..eac1f04ff5590 100644 --- a/.github/workflows/inductor-perf-test-nightly-rocm-mi355.yml +++ b/.github/workflows/inductor-perf-test-nightly-rocm-mi355.yml @@ -63,7 +63,10 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read + actions: read jobs: get-label-type: @@ -79,10 +82,10 @@ jobs: linux-jammy-rocm-py3_10-inductor-benchmark-build: if: github.repository_owner == 'pytorch' - name: rocm-py3_10-inductor-benchmark-build + name: linux-jammy-rocm-py3.10-mi355 uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-jammy-rocm-py3_10 + build-environment: linux-jammy-rocm-py3.10-mi355 docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks test-matrix: | { include: [ @@ -111,10 +114,7 @@ jobs: secrets: inherit linux-jammy-rocm-py3_10-inductor-benchmark-test: - permissions: - id-token: write - contents: read - name: rocm-py3_10-inductor-benchmark-test + name: linux-jammy-rocm-py3.10-mi355 uses: ./.github/workflows/_rocm-test.yml needs: linux-jammy-rocm-py3_10-inductor-benchmark-build with: diff --git a/.github/workflows/inductor-perf-test-nightly-x86-zen.yml b/.github/workflows/inductor-perf-test-nightly-x86-zen.yml index eee51b7ff8889..a88d3d63ef50c 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86-zen.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86-zen.yml @@ -61,6 +61,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-label-type: diff --git a/.github/workflows/inductor-perf-test-nightly-x86.yml b/.github/workflows/inductor-perf-test-nightly-x86.yml index 87875831e2a0b..6c312586f27b5 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86.yml @@ -61,6 +61,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-label-type: diff --git a/.github/workflows/inductor-perf-test-nightly-xpu.yml b/.github/workflows/inductor-perf-test-nightly-xpu.yml index b51795c663957..70ca8f8a5b44b 100644 --- a/.github/workflows/inductor-perf-test-nightly-xpu.yml +++ b/.github/workflows/inductor-perf-test-nightly-xpu.yml @@ -63,7 +63,10 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read + actions: read jobs: get-label-type: diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 0d470a3bcd4cb..2e7c3d0e43f60 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -64,6 +64,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-label-type: diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index cfa844d3b7d96..a16a4fd80a63c 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -18,6 +18,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-default-label-prefix: @@ -59,6 +60,9 @@ jobs: { config: "dynamic_aot_eager_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_aot_eager_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_huggingface_unbacked_parity", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_huggingface_unbacked_parity", shard: 1, num_shards: 1, runner: "linux.aws.a100" }, + { config: "inductor_huggingface_unbacked_parity", shard: 1, num_shards: 1, runner: "linux.aws.h100" }, { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, @@ -87,7 +91,7 @@ jobs: name: rocm-periodic-dynamo-benchmarks-build uses: ./.github/workflows/_linux-build.yml with: - build-environment: linux-jammy-rocm-py3_10 + build-environment: linux-jammy-rocm-py3.10-mi355 docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3-benchmarks test-matrix: | { include: [ @@ -120,9 +124,6 @@ jobs: secrets: inherit rocm-periodic-dynamo-benchmarks-test: - permissions: - id-token: write - contents: read name: rocm-periodic-dynamo-benchmarks-test uses: ./.github/workflows/_rocm-test.yml needs: rocm-periodic-dynamo-benchmarks-build diff --git a/.github/workflows/inductor-rocm-mi200.yml b/.github/workflows/inductor-rocm-mi200.yml index c67f41daf5ad3..44fda5a4ed4e7 100644 --- a/.github/workflows/inductor-rocm-mi200.yml +++ b/.github/workflows/inductor-rocm-mi200.yml @@ -17,6 +17,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-label-type: @@ -31,12 +32,12 @@ jobs: opt_out_experiments: lf linux-jammy-rocm-py3_10-inductor-build: - name: rocm-py3.10-inductor + name: linux-jammy-rocm-py3.10-mi200 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-rocm-py3.10 + build-environment: linux-jammy-rocm-py3.10-mi200 docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 test-matrix: | { include: [ @@ -46,10 +47,7 @@ jobs: secrets: inherit linux-jammy-rocm-py3_10-inductor-test: - permissions: - id-token: write - contents: read - name: rocm-py3.10-inductor + name: linux-jammy-rocm-py3.10-mi200 uses: ./.github/workflows/_rocm-test.yml needs: linux-jammy-rocm-py3_10-inductor-build with: diff --git a/.github/workflows/inductor-rocm-mi300.yml b/.github/workflows/inductor-rocm-mi300.yml index 633386aba487b..0b196b4a214b9 100644 --- a/.github/workflows/inductor-rocm-mi300.yml +++ b/.github/workflows/inductor-rocm-mi300.yml @@ -17,15 +17,13 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: target-determination: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/target_determination.yml - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -39,7 +37,7 @@ jobs: opt_out_experiments: lf linux-noble-rocm-py3_12-inductor-build: - name: rocm-py3.12-inductor-mi300 + name: linux-noble-rocm-py3.12-mi300 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: @@ -54,10 +52,7 @@ jobs: secrets: inherit linux-noble-rocm-py3_12-inductor-test: - permissions: - id-token: write - contents: read - name: rocm-py3.12-inductor-mi300 + name: linux-noble-rocm-py3.12-mi300 uses: ./.github/workflows/_rocm-test.yml needs: linux-noble-rocm-py3_12-inductor-build with: diff --git a/.github/workflows/inductor-rocm-mi355.yml b/.github/workflows/inductor-rocm-mi355.yml index 70ea41a6da698..bd75369adbd00 100644 --- a/.github/workflows/inductor-rocm-mi355.yml +++ b/.github/workflows/inductor-rocm-mi355.yml @@ -16,15 +16,13 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: target-determination: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/target_determination.yml - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -38,7 +36,7 @@ jobs: opt_out_experiments: lf linux-noble-rocm-py3_12-inductor-build: - name: rocm-py3.12-inductor-mi355 + name: linux-noble-rocm-py3.12-mi355 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: @@ -53,10 +51,7 @@ jobs: secrets: inherit linux-noble-rocm-py3_12-inductor-test: - permissions: - id-token: write - contents: read - name: rocm-py3.12-inductor-mi355 + name: linux-noble-rocm-py3.12-mi355 uses: ./.github/workflows/_rocm-test.yml needs: linux-noble-rocm-py3_12-inductor-build with: diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index 82fb485f24b2b..cb5419437d703 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -18,6 +18,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-label-type: @@ -166,8 +167,8 @@ jobs: python-version: ['3.11', '3.12', '3.13'] with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py${{ matrix.python-version }}-clang15 - docker-image-name: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang15 + build-environment: linux-jammy-py${{ matrix.python-version }}-clang18 + docker-image-name: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang18 test-matrix: | { include: [ { config: "inductor_core", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, @@ -183,8 +184,8 @@ jobs: matrix: python-version: ['3.11', '3.12', '3.13'] with: - build-environment: linux-jammy-py${{ matrix.python-version }}-clang15 - docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang15 + build-environment: linux-jammy-py${{ matrix.python-version }}-clang18 + docker-image: ci-image:pytorch-linux-jammy-py${{ matrix.python-version }}-clang18 test-matrix: | { include: [ { config: "inductor_core", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 2532279500849..d4a44968ffd93 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -25,6 +25,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: unit-test: @@ -73,6 +74,7 @@ jobs: build-environment: ${{ needs.inductor-build.outputs.build-environment }} docker-image: ${{ needs.inductor-build.outputs.docker-image }} test-matrix: ${{ needs.inductor-build.outputs.test-matrix }} + enable-torch-trace: "1" secrets: inherit inductor-cpu-build: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index aab977b11469b..cd244274039d5 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -36,7 +36,7 @@ jobs: all_files: ${{ contains(github.event.pull_request.labels.*.name, 'lint-all-files') || contains(github.event.pull_request.labels.*.name, 'Reverted') || github.event_name == 'push' }} lintrunner-clang: - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + uses: ./.github/workflows/_lint.yml # Needed to prevent deduping on HUD name: lintrunner-clang-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }} needs: [get-label-type, get-changed-files] @@ -56,14 +56,8 @@ jobs: contains(needs.get-changed-files.outputs.changed-files, '.metal') ) with: - timeout: 120 - runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter - # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout - # to run git rev-parse HEAD~:.ci/docker when a new image is needed - fetch-depth: 0 - submodules: true - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cuda-x86_64-810d48d script: | CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" if [ "$CHANGED_FILES" = "*" ]; then @@ -78,7 +72,7 @@ jobs: # fails to find types when it should # NOTE: We should be able to disable this and consolidate with Pyrefly lintrunner-pyrefly: - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + uses: ./.github/workflows/_lint.yml name: lintrunner-pyrefly-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }} needs: [get-label-type, get-changed-files] # Only run if there are changed files relevant to pyrefly @@ -89,32 +83,20 @@ jobs: contains(needs.get-changed-files.outputs.changed-files, '.pyi') ) with: - timeout: 120 - runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-jammy-linter - # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout - # to run git rev-parse HEAD~:.ci/docker when a new image is needed - fetch-depth: 0 - submodules: true - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-810d48d script: | CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" echo "Running pyrefly" ADDITIONAL_LINTRUNNER_ARGS="--take PYREFLY --all-files" .github/scripts/lintrunner.sh lintrunner-noclang: - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + uses: ./.github/workflows/_lint.yml name: lintrunner-noclang-${{ needs.get-changed-files.outputs.changed-files == '*' && 'all' || 'partial' }} needs: [get-label-type, get-changed-files] with: - timeout: 120 - runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-jammy-linter - # NB: A shallow checkout won't work here because calculate-docker-image requires a full checkout - # to run git rev-parse HEAD~:.ci/docker when a new image is needed - fetch-depth: 0 - submodules: true - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-810d48d script: | CHANGED_FILES="${{ needs.get-changed-files.outputs.changed-files }}" echo "Running all other linters" @@ -125,14 +107,12 @@ jobs: fi quick-checks: - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + if: github.repository_owner == 'pytorch' needs: get-label-type + uses: ./.github/workflows/_lint.yml with: - timeout: 120 - runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-jammy-linter - fetch-depth: 0 - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-810d48d script: | # Ensure no non-breaking spaces # NB: We use 'printf' below rather than '\u000a' since bash pre-4.2 @@ -159,10 +139,9 @@ jobs: pr-sanity-checks: name: pr-sanity-checks - needs: get-label-type - runs-on: [self-hosted, "${{ needs.get-label-type.outputs.label-type }}linux.large"] + runs-on: linux.24_04.4x # Only run this on pull requests. This check is simple enough to be done without a Docker image - if: github.event_name == 'pull_request' && !contains(github.event.pull_request.labels.*.name, 'skip-pr-sanity-checks') + if: ${{ github.event_name == 'pull_request' && !contains(github.event.pull_request.labels.*.name, 'skip-pr-sanity-checks') && github.repository_owner == 'pytorch' }} steps: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@main @@ -177,17 +156,13 @@ jobs: run: | bash .github/scripts/pr-sanity-check.sh - workflow-checks: - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + if: github.repository_owner == 'pytorch' needs: get-label-type + uses: ./.github/workflows/_lint.yml with: - timeout: 120 - runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-jammy-linter - fetch-depth: -1 - submodules: true - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-810d48d script: | # Regenerate workflows .github/scripts/generate_ci_workflows.py @@ -213,14 +188,12 @@ jobs: exit $RC toc: - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + if: github.repository_owner == 'pytorch' needs: get-label-type + uses: ./.github/workflows/_lint.yml with: - timeout: 120 - runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-jammy-linter - fetch-depth: 0 - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-810d48d script: | # Regenerate ToCs and check that they didn't change set -eu @@ -249,14 +222,11 @@ jobs: test-tools: name: Test tools if: ${{ github.repository == 'pytorch/pytorch' }} - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main needs: get-label-type + uses: ./.github/workflows/_lint.yml with: - timeout: 120 - runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" - docker-image: ci-image:pytorch-linux-jammy-linter - fetch-depth: 0 - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + runner: mt-l-x86iamx-8-16 + docker-image: ghcr.io/pytorch/test-infra:cpu-x86_64-810d48d script: | # Test tools PYTHONPATH=$(pwd) pytest tools/stats @@ -351,18 +321,17 @@ jobs: python3 torch/utils/collect_env.py link-check: - name: Link checks + if: github.repository_owner == 'pytorch' needs: get-label-type + name: Link checks uses: ./.github/workflows/_link_check.yml with: runner: ${{ needs.get-label-type.outputs.label-type }} - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - secrets: inherit + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} doc-redirects-check: name: doc-redirects-check - needs: get-label-type - runs-on: [self-hosted, "${{ needs.get-label-type.outputs.label-type }}linux.large"] + runs-on: linux.24_04.4x if: github.event_name == 'pull_request' && github.repository_owner == 'pytorch' steps: - name: Checkout PyTorch diff --git a/.github/workflows/linux-aarch64.yml b/.github/workflows/linux-aarch64.yml deleted file mode 100644 index 2c6e6b6dac39c..0000000000000 --- a/.github/workflows/linux-aarch64.yml +++ /dev/null @@ -1,62 +0,0 @@ -name: linux-aarch64 - -on: - push: - branches: - - main - - release/* - tags: - - ciflow/linux-aarch64/* - - ciflow/trunk/* - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' && github.run_id }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -jobs: - - get-label-type: - if: github.repository_owner == 'pytorch' - name: get-label-type - uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main - with: - triggering_actor: ${{ github.triggering_actor }} - issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} - curr_branch: ${{ github.head_ref || github.ref_name }} - curr_ref_type: ${{ github.ref_type }} - - linux-jammy-aarch64-py3_10-build: - name: linux-jammy-aarch64-py3.10 - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: ${{ needs.get-label-type.outputs.label-type }} - build-environment: linux-jammy-aarch64-py3.10 - docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13 - runner: linux.arm64.m7g.4xlarge - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, - { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, - { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, - { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, - { config: "default", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, - { config: "default", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, - { config: "default", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, - { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, - ]} - secrets: inherit - - linux-jammy-aarch64-py3_10-test: - name: linux-jammy-aarch64-py3.10 - uses: ./.github/workflows/_linux-test.yml - needs: linux-jammy-aarch64-py3_10-build - permissions: - id-token: write - contents: read - with: - build-environment: ${{ needs.linux-jammy-aarch64-py3_10-build.outputs.build-environment }} - docker-image: ${{ needs.linux-jammy-aarch64-py3_10-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-aarch64-py3_10-build.outputs.test-matrix }} - secrets: inherit diff --git a/.github/workflows/llm_td_retrieval.yml b/.github/workflows/llm_td_retrieval.yml index 565a9b25df50f..23416b9b0627f 100644 --- a/.github/workflows/llm_td_retrieval.yml +++ b/.github/workflows/llm_td_retrieval.yml @@ -26,15 +26,8 @@ jobs: continue-on-error: true needs: get-label-type steps: - - name: Clone PyTorch - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: pytorch/pytorch - fetch-depth: 0 - path: pytorch - - name: Setup Linux - uses: ./pytorch/.github/actions/setup-linux + uses: pytorch/pytorch/.github/actions/setup-linux@main - name: Clone CodeLlama uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index c6282855aa88c..82f718c9b1077 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -18,6 +18,11 @@ concurrency: group: ${{ github.workflow }}--${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} cancel-in-progress: true +permissions: + id-token: write + contents: read + actions: read + jobs: get-label-type: name: get-label-type @@ -45,7 +50,7 @@ jobs: with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.10-gcc11 - docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang18 secrets: inherit docs-push: diff --git a/.github/workflows/nitpicker.yml b/.github/workflows/nitpicker.yml index 40bd245ce913f..a112fb0dcec42 100644 --- a/.github/workflows/nitpicker.yml +++ b/.github/workflows/nitpicker.yml @@ -20,7 +20,7 @@ jobs: steps: - name: Checkout PyTorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@main - - uses: ethanis/nitpicker@v1 + - uses: ethanis/nitpicker@c102a39683a80c7db9065f8eab7de8b58871f946 # v1 with: nitpicks: '.github/nitpicks.yml' token: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.github/workflows/operator_benchmark.yml b/.github/workflows/operator_benchmark.yml index e682e1eb06c24..7d6cbe58b726b 100644 --- a/.github/workflows/operator_benchmark.yml +++ b/.github/workflows/operator_benchmark.yml @@ -28,6 +28,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: x86-opbenchmark-build: diff --git a/.github/workflows/operator_microbenchmark.yml b/.github/workflows/operator_microbenchmark.yml index 445cdcc4be04a..51a826620ce7c 100644 --- a/.github/workflows/operator_microbenchmark.yml +++ b/.github/workflows/operator_microbenchmark.yml @@ -17,6 +17,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-label-type: @@ -28,6 +29,7 @@ jobs: issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} + check_experiments: arc,lf # H100 A100 runners opmicrobenchmark-build: @@ -58,14 +60,14 @@ jobs: test-matrix: ${{ needs.opmicrobenchmark-build.outputs.test-matrix }} secrets: inherit - # B200 runner + # B200 runner (OSDC), always use OSDC runner to test this workflow opmicrobenchmark-build-b200: if: github.repository_owner == 'pytorch' name: opmicrobenchmark-build-b200 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runner_prefix: "mt-" runner: linux.r7i.4xlarge build-environment: linux-jammy-cuda12.8-py3.10-gcc11-sm100 docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 @@ -74,18 +76,27 @@ jobs: { include: [ { config: "operator_microbenchmark_test", shard: 1, num_shards: 1, runner: "linux.dgx.b200" }, ]} + use-arc: true + python-version: "3.10" + compiler: gcc11 + cuda-version: "12.8" secrets: inherit opmicrobenchmark-test-b200: name: opmicrobenchmark-test-b200 uses: ./.github/workflows/_linux-test.yml - needs: opmicrobenchmark-build-b200 + needs: + - opmicrobenchmark-build-b200 + - get-label-type with: timeout-minutes: 500 build-environment: ${{ needs.opmicrobenchmark-build-b200.outputs.build-environment }} docker-image: ${{ needs.opmicrobenchmark-build-b200.outputs.docker-image }} test-matrix: ${{ needs.opmicrobenchmark-build-b200.outputs.test-matrix }} - aws-role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_s3_and_ecr_read_only + use-arc: true + python-version: "3.10" + compiler: gcc11 + cuda-version: "12.8" secrets: inherit # ROCM MI300 runner diff --git a/.github/workflows/periodic-rocm-mi200.yml b/.github/workflows/periodic-rocm-mi200.yml index 3107d6897e0d7..234c7933b0596 100644 --- a/.github/workflows/periodic-rocm-mi200.yml +++ b/.github/workflows/periodic-rocm-mi200.yml @@ -17,23 +17,18 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: llm-td: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/llm_td_retrieval.yml - permissions: - id-token: write - contents: read target-determination: name: before-test uses: ./.github/workflows/target_determination.yml needs: llm-td - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -46,12 +41,12 @@ jobs: curr_ref_type: ${{ github.ref_type }} linux-jammy-rocm-py3_10-build: - name: linux-jammy-rocm-py3.10 + name: linux-jammy-rocm-py3.10-mi200 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-rocm-py3.10 + build-environment: linux-jammy-rocm-py3.10-mi200 docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 test-matrix: | { include: [ @@ -62,10 +57,7 @@ jobs: secrets: inherit linux-jammy-rocm-py3_10-test: - permissions: - id-token: write - contents: read - name: linux-jammy-rocm-py3.10 + name: linux-jammy-rocm-py3.10-mi200 uses: ./.github/workflows/_rocm-test.yml needs: - linux-jammy-rocm-py3_10-build diff --git a/.github/workflows/periodic-rocm-mi300.yml b/.github/workflows/periodic-rocm-mi300.yml index 88da168926444..cfbe99812d8f6 100644 --- a/.github/workflows/periodic-rocm-mi300.yml +++ b/.github/workflows/periodic-rocm-mi300.yml @@ -14,24 +14,21 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' && github.run_id }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read + actions: read jobs: llm-td: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/llm_td_retrieval.yml - permissions: - id-token: write - contents: read target-determination: name: before-test uses: ./.github/workflows/target_determination.yml needs: llm-td - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -60,9 +57,6 @@ jobs: secrets: inherit linux-noble-rocm-py3_12-test: - permissions: - id-token: write - contents: read name: linux-noble-rocm-py3.12-mi300 uses: ./.github/workflows/_rocm-test.yml needs: diff --git a/.github/workflows/periodic-rocm-mi355.yml b/.github/workflows/periodic-rocm-mi355.yml index 8c6063260861d..7fd9d36b13423 100644 --- a/.github/workflows/periodic-rocm-mi355.yml +++ b/.github/workflows/periodic-rocm-mi355.yml @@ -15,24 +15,21 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' && github.run_id }}-${{ github.event_name == 'schedule' }}-${{ github.event.schedule }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read + actions: read jobs: llm-td: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/llm_td_retrieval.yml - permissions: - id-token: write - contents: read target-determination: name: before-test uses: ./.github/workflows/target_determination.yml needs: llm-td - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -61,9 +58,6 @@ jobs: secrets: inherit linux-noble-rocm-py3_12-test: - permissions: - id-token: write - contents: read name: linux-noble-rocm-py3.12-mi355 uses: ./.github/workflows/_rocm-test.yml needs: diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 09018f6cf1613..a36b7eae500b5 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -23,23 +23,18 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: llm-td: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/llm_td_retrieval.yml - permissions: - id-token: write - contents: read target-determination: name: before-test uses: ./.github/workflows/target_determination.yml needs: llm-td - permissions: - id-token: write - contents: read get-label-type: name: get-label-type diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 4cbc0b5077360..69af709c5bbff 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -33,6 +33,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: # See job-filter.yml for rules on adding job filter conditions @@ -47,17 +48,11 @@ jobs: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/llm_td_retrieval.yml - permissions: - id-token: write - contents: read target-determination: name: before-test uses: ./.github/workflows/target_determination.yml needs: llm-td - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -67,6 +62,11 @@ jobs: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} + check_experiments: arc,lf + + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ linux-jammy-py3.10-gcc11 (build + test) ║ + # ╠══════════════════════════════════════════════════════════════════════╣ linux-jammy-py3_10-gcc11-build: if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.10-gcc11 ') || contains(needs.job-filter.outputs.jobs, ' linux-docs ') }} @@ -78,24 +78,29 @@ jobs: with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.10-gcc11 - docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang18 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, - { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + # Use m7i instead of m7a to match the existing (Intel Xeon) numerics. DTensor crossref tests like linalg.multi_dot have tight + # float32 tolerances sensitive to different FMA/reduction order across CPU vendors. + { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "numpy_2_x", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, { config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc11 secrets: inherit linux-jammy-py3_10-gcc11-test: @@ -106,22 +111,74 @@ jobs: - linux-jammy-py3_10-gcc11-build - target-determination - job-filter + - get-label-type with: build-environment: ${{ needs.linux-jammy-py3_10-gcc11-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-py3_10-gcc11-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_10-gcc11-build.outputs.test-matrix }} tests-to-include: ${{ github.event.inputs.tests-to-include || '' }} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc11 secrets: inherit + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ linux-jammy-aarch64-py3.10-gcc11 (build + test) ║ + # ╠══════════════════════════════════════════════════════════════════════╣ + linux-jammy-aarch64-py3_10-gcc13-build: + name: linux-jammy-aarch64-py3.10 + uses: ./.github/workflows/_linux-build.yml + needs: + - get-label-type + - job-filter + with: + runner_prefix: ${{ needs.get-label-type.outputs.label-type }} + build-environment: linux-jammy-aarch64-py3.10 + docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13 + runner: linux.arm64.m8g.4xlarge + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, + { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m8g.4xlarge" }, + ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc13 + secrets: inherit - linux-jammy-py3_14t-clang15-build: - name: linux-jammy-py3.14t-clang15 + linux-jammy-aarch64-py3_10-gcc13-test: + name: linux-jammy-aarch64-py3.10 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-aarch64-py3_10-gcc13-build + - target-determination + - job-filter + - get-label-type + with: + build-environment: ${{ needs.linux-jammy-aarch64-py3_10-gcc13-build.outputs.build-environment }} + docker-image: ${{ needs.linux-jammy-aarch64-py3_10-gcc13-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-aarch64-py3_10-gcc13-build.outputs.test-matrix }} + tests-to-include: ${{ github.event.inputs.tests-to-include || '' }} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc13 + secrets: inherit + + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ linux-jammy-py3.14t-clang18 (build + test) ║ + # ╠══════════════════════════════════════════════════════════════════════╣ + linux-jammy-py3_14t-clang18-build: + name: linux-jammy-py3.14t-clang18 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.14t-clang15 - docker-image-name: ci-image:pytorch-linux-jammy-py3.14t-clang15 + build-environment: linux-jammy-py3.14t-clang18 + docker-image-name: ci-image:pytorch-linux-jammy-py3.14t-clang18 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -131,26 +188,39 @@ jobs: { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + # Free-threaded Python 3.14t has ~20-50% higher per-object memory overhead from biased reference counting and + # per-object locks, and test_nn (531 tests) under dynamo wrapping with compiled autograd consistently OOMs at 64GB. + { config: "dynamo_wrapped", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "dynamo_wrapped", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "einops", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.14t" + compiler: clang18 secrets: inherit - linux-jammy-py3_14t-clang15-test: - name: linux-jammy-py3.14t-clang15 + linux-jammy-py3_14t-clang18-test: + name: linux-jammy-py3.14t-clang18 uses: ./.github/workflows/_linux-test.yml needs: - - linux-jammy-py3_14t-clang15-build + - linux-jammy-py3_14t-clang18-build - target-determination + - get-label-type with: - build-environment: linux-jammy-py3.14t-clang15 - docker-image: ${{ needs.linux-jammy-py3_14t-clang15-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_14t-clang15-build.outputs.test-matrix }} + build-environment: linux-jammy-py3.14t-clang18 + docker-image: ${{ needs.linux-jammy-py3_14t-clang18-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_14t-clang18-build.outputs.test-matrix }} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.14t" + compiler: clang18 secrets: inherit + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ linux-docs ║ + # ╠══════════════════════════════════════════════════════════════════════╣ + linux-docs: if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-docs ') }} name: linux-docs @@ -158,11 +228,21 @@ jobs: needs: - linux-jammy-py3_10-gcc11-build - job-filter + - get-label-type with: build-environment: ${{ needs.linux-jammy-py3_10-gcc11-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-py3_10-gcc11-build.outputs.docker-image }} + run-doxygen: true + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + python-version: "3.10" + compiler: gcc11 secrets: inherit + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ linux-jammy-py3.10-gcc11-no-ops (build only) ║ + # ╠══════════════════════════════════════════════════════════════════════╣ + linux-jammy-py3_10-gcc11-no-ops: if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.10-gcc11-no-ops ') }} name: linux-jammy-py3.10-gcc11-no-ops @@ -173,13 +253,20 @@ jobs: with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.10-gcc11-no-ops - docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang18 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc11 secrets: inherit + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ linux-jammy-py3.10-clang18-asan (build + test) ║ + # ╠══════════════════════════════════════════════════════════════════════╣ + linux-jammy-py3_10-clang18-asan-build: if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.10-clang18-asan ') }} name: linux-jammy-py3.10-clang18-asan @@ -191,7 +278,7 @@ jobs: runner: linux.c7i.4xlarge runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.10-clang18-asan - docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan + docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang18 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -203,7 +290,11 @@ jobs: { config: "default", shard: 7, num_shards: 7, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, ]} - sync-tag: asan-build + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: clang18 + # TODO (huydhn): Add this back once other workflow migrates + # sync-tag: asan-build secrets: inherit linux-jammy-py3_10-clang18-asan-test: @@ -214,58 +305,34 @@ jobs: - linux-jammy-py3_10-clang18-asan-build - target-determination - job-filter + - get-label-type with: build-environment: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.build-environment }} docker-image: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-py3_10-clang18-asan-build.outputs.test-matrix }} - sync-tag: asan-test tests-to-include: ${{ github.event.inputs.tests-to-include || '' }} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: clang18 + # TODO (huydhn): Add this back once other workflow migrates + # sync-tag: asan-test secrets: inherit - linux-jammy-py3_10-clang15-onnx-build: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.10-clang15-onnx ') }} - name: linux-jammy-py3.10-clang15-onnx - uses: ./.github/workflows/_linux-build.yml - needs: - - get-label-type - - job-filter - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.10-clang15-onnx - docker-image-name: ci-image:pytorch-linux-jammy-py3-clang15-onnx - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, - { config: "default", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, - ]} - secrets: inherit + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ linux-jammy-py3.10-clang18 (build + test) ║ + # ╠══════════════════════════════════════════════════════════════════════╣ - linux-jammy-py3_10-clang15-onnx-test: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.10-clang15-onnx ') }} - name: linux-jammy-py3.10-clang15-onnx - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-jammy-py3_10-clang15-onnx-build - - target-determination - - job-filter - with: - build-environment: ${{ needs.linux-jammy-py3_10-clang15-onnx-build.outputs.build-environment }} - docker-image: ${{ needs.linux-jammy-py3_10-clang15-onnx-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_10-clang15-onnx-build.outputs.test-matrix }} - tests-to-include: ${{ github.event.inputs.tests-to-include || '' }} - secrets: inherit - - linux-jammy-py3_10-clang15-build: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.10-clang15 ') }} - name: linux-jammy-py3.10-clang15 + linux-jammy-py3_10-clang18-build: + if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.10-clang18 ') }} + name: linux-jammy-py3.10-clang18 uses: ./.github/workflows/_linux-build.yml needs: - get-label-type - job-filter with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.10-clang15 - docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang15 + build-environment: linux-jammy-py3.10-clang18 + docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang18 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -280,35 +347,47 @@ jobs: { config: "dynamo_wrapped", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "einops", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "onnx", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: clang18 secrets: inherit - linux-jammy-py3_10-clang15-test: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.10-clang15 ') }} - name: linux-jammy-py3.10-clang15 + linux-jammy-py3_10-clang18-test: + if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.10-clang18 ') }} + name: linux-jammy-py3.10-clang18 uses: ./.github/workflows/_linux-test.yml needs: - - linux-jammy-py3_10-clang15-build + - linux-jammy-py3_10-clang18-build - target-determination - job-filter + - get-label-type with: - build-environment: ${{ needs.linux-jammy-py3_10-clang15-build.outputs.build-environment }} - docker-image: ${{ needs.linux-jammy-py3_10-clang15-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_10-clang15-build.outputs.test-matrix }} + build-environment: ${{ needs.linux-jammy-py3_10-clang18-build.outputs.build-environment }} + docker-image: ${{ needs.linux-jammy-py3_10-clang18-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_10-clang18-build.outputs.test-matrix }} tests-to-include: ${{ github.event.inputs.tests-to-include || '' }} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: clang18 secrets: inherit - linux-jammy-py3_14-clang15-build: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.14-clang15 ') }} - name: linux-jammy-py3.14-clang15 + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ linux-jammy-py3.14-clang18 (build + test) ║ + # ╠══════════════════════════════════════════════════════════════════════╣ + + linux-jammy-py3_14-clang18-build: + if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.14-clang18 ') }} + name: linux-jammy-py3.14-clang18 uses: ./.github/workflows/_linux-build.yml needs: - get-label-type - job-filter with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.14-clang15 - docker-image-name: ci-image:pytorch-linux-jammy-py3.14-clang15 + build-environment: linux-jammy-py3.14-clang18 + docker-image-name: ci-image:pytorch-linux-jammy-py3.14-clang18 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, @@ -324,56 +403,57 @@ jobs: { config: "einops", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.14" + compiler: clang18 secrets: inherit - linux-jammy-py3_14-clang15-test: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.14-clang15 ') }} - name: linux-jammy-py3.14-clang15 + linux-jammy-py3_14-clang18-test: + if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.14-clang18 ') }} + name: linux-jammy-py3.14-clang18 uses: ./.github/workflows/_linux-test.yml needs: - - linux-jammy-py3_14-clang15-build + - linux-jammy-py3_14-clang18-build - job-filter + - get-label-type with: - build-environment: ${{ needs.linux-jammy-py3_14-clang15-build.outputs.build-environment }} - docker-image: ${{ needs.linux-jammy-py3_14-clang15-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_14-clang15-build.outputs.test-matrix }} + build-environment: ${{ needs.linux-jammy-py3_14-clang18-build.outputs.build-environment }} + docker-image: ${{ needs.linux-jammy-py3_14-clang18-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_14-clang18-build.outputs.test-matrix }} tests-to-include: ${{ github.event.inputs.tests-to-include || '' }} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.14" + compiler: clang18 secrets: inherit - linux-jammy-cuda12_8-cudnn9-py3_10-clang15-build: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-cuda12.8-cudnn9-py3.10-clang15 ') }} - name: linux-jammy-cuda12.8-cudnn9-py3.10-clang15 + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ linux-jammy-cuda12.8-cudnn9-py3.10-clang18 (build only) ║ + # ╠══════════════════════════════════════════════════════════════════════╣ + + linux-jammy-cuda12_8-cudnn9-py3_10-clang18-build: + if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-cuda12.8-cudnn9-py3.10-clang18 ') }} + name: linux-jammy-cuda12.8-cudnn9-py3.10-clang18 uses: ./.github/workflows/_linux-build.yml needs: - get-label-type - job-filter with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda12.8-cudnn9-py3.10-clang15 - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-clang15 + build-environment: linux-jammy-cuda12.8-cudnn9-py3.10-clang18 + docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-clang18 test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: clang18 + cuda-version: "12.8" secrets: inherit - linux-jammy-cpu-py3_10-gcc11-bazel-test: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-cpu-py3.10-gcc11-bazel-test ') }} - name: linux-jammy-cpu-py3.10-gcc11-bazel-test - uses: ./.github/workflows/_bazel-build-test.yml - needs: - - get-label-type - - job-filter - with: - runner: "${{ needs.get-label-type.outputs.label-type }}linux.large" - build-environment: linux-jammy-cuda12.8-py3.10-gcc11-bazel-test - docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 - cuda-version: cpu - test-matrix: | - { include: [ - { config: "default", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - ]} - secrets: inherit + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ linux-jammy-py3.10-gcc11-mobile-lightweight-dispatch (build only) ║ + # ╠══════════════════════════════════════════════════════════════════════╣ linux-jammy-py3_10-gcc11-mobile-lightweight-dispatch-build: if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3.10-gcc11-mobile-lightweight-dispatch-build ') }} @@ -385,16 +465,23 @@ jobs: with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.10-gcc11-mobile-lightweight-dispatch-build - docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang18 build-generates-artifacts: false test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 1 }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc11 secrets: inherit + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ linux-jammy-rocm-py3.10 (build) ║ + # ╠══════════════════════════════════════════════════════════════════════╣ + linux-jammy-rocm-py3_10-build: - if: github.event_name == 'pull_request' || (needs.job-filter.outputs.jobs != '' && contains(needs.job-filter.outputs.jobs, ' linux-jammy-rocm-py3.10 ')) + if: ${{ github.event_name == 'pull_request' || (needs.job-filter.outputs.jobs != '' && contains(needs.job-filter.outputs.jobs, ' linux-jammy-rocm-py3.10 ')) }} # don't run build twice on main name: linux-jammy-rocm-py3.10 uses: ./.github/workflows/_linux-build.yml @@ -405,45 +492,21 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-rocm-py3.10 docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 - sync-tag: rocm-build + # TODO (huydhn): Add this back once other workflow migrates + # sync-tag: rocm-build test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 3, runner: "linux.rocm.gpu.2" }, { config: "default", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.2" }, { config: "default", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.2" }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" secrets: inherit - linux-jammy-cuda13_0-py3_10-gcc11-inductor-build: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' cuda13.0-py3.10-gcc11-sm75 ') }} - name: cuda13.0-py3.10-gcc11-sm75 - uses: ./.github/workflows/_linux-build.yml - needs: - - get-label-type - - job-filter - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm75 - docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks - test-matrix: | - { include: [ - { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, - ]} - secrets: inherit - - linux-jammy-cuda13_0-py3_10-gcc11-inductor-test: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' cuda13.0-py3.10-gcc11-sm75 ') }} - name: cuda13.0-py3.10-gcc11-sm75 - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-jammy-cuda13_0-py3_10-gcc11-inductor-build - - job-filter - with: - build-environment: linux-jammy-cuda13.0-py3.10-gcc11-sm75 - docker-image: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-inductor-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-inductor-build.outputs.test-matrix }} - tests-to-include: ${{ github.event.inputs.tests-to-include || '' }} - secrets: inherit + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ linux-jammy-xpu-n-py3.10 (build) ║ + # ╠══════════════════════════════════════════════════════════════════════╣ linux-jammy-xpu-n-py3_10-build: if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-xpu-n-py3.10 ') }} @@ -453,8 +516,6 @@ jobs: - get-label-type - job-filter with: - # This should sync with the build in xpu.yml but xpu uses a larger runner - # sync-tag: linux-xpu-n-build runner_prefix: ${{ needs.get-label-type.outputs.label-type }} runner: linux.c7i.4xlarge build-environment: linux-noble-xpu-n-py3.10 @@ -466,20 +527,29 @@ jobs: { config: "default", shard: 3, num_shards: 4, runner: "linux.idc.xpu" }, { config: "default", shard: 4, num_shards: 4, runner: "linux.idc.xpu" }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" secrets: inherit + # ╠══════════════════════════════════════════════════════════════════════╣ + # ║ dynamo-cpython (build + test) ║ + # ╠══════════════════════════════════════════════════════════════════════╣ + dynamo-cpython-build: name: dynamo-cpython-build uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.13-clang15 - docker-image-name: ci-image:pytorch-linux-jammy-py3.13-clang15 + build-environment: linux-jammy-py3.13-clang18 + docker-image-name: ci-image:pytorch-linux-jammy-py3.13-clang18 test-matrix: | { include: [ { config: "dynamo_cpython", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.13" + compiler: clang18 secrets: inherit dynamo-cpython-test: @@ -487,10 +557,10 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: [get-label-type, dynamo-cpython-build] with: - build-environment: linux-jammy-py3.13-clang15 - docker-image: ci-image:pytorch-linux-jammy-py3.13-clang15 - test-matrix: | - { include: [ - { config: "dynamo_cpython", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.c7i.2xlarge" }, - ]} + build-environment: linux-jammy-py3.13-clang18 + docker-image: ${{ needs.dynamo-cpython-build.outputs.docker-image }} + test-matrix: ${{ needs.dynamo-cpython-build.outputs.test-matrix }} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.13" + compiler: clang18 secrets: inherit diff --git a/.github/workflows/quantization-periodic.yml b/.github/workflows/quantization-periodic.yml index 8dd97ff9308db..3b7ac1fb62a67 100644 --- a/.github/workflows/quantization-periodic.yml +++ b/.github/workflows/quantization-periodic.yml @@ -16,6 +16,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-default-label-prefix: diff --git a/.github/workflows/rocm-mi200.yml b/.github/workflows/rocm-mi200.yml index 7ca06317eebe5..16c9e0097df5d 100644 --- a/.github/workflows/rocm-mi200.yml +++ b/.github/workflows/rocm-mi200.yml @@ -38,12 +38,12 @@ jobs: linux-jammy-rocm-py3_10-build: if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} - name: linux-jammy-rocm-py3.10 + name: linux-jammy-rocm-py3.10-mi200 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-rocm-py3.10 + build-environment: linux-jammy-rocm-py3.10-mi200 docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -61,7 +61,7 @@ jobs: permissions: id-token: write contents: read - name: linux-jammy-rocm-py3.10 + name: linux-jammy-rocm-py3.10-mi200 uses: ./.github/workflows/_rocm-test.yml needs: - linux-jammy-rocm-py3_10-build diff --git a/.github/workflows/rocm-mi300.yml b/.github/workflows/rocm-mi300.yml index 9c2bae06f32bd..95eef38c17981 100644 --- a/.github/workflows/rocm-mi300.yml +++ b/.github/workflows/rocm-mi300.yml @@ -14,16 +14,16 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read + actions: read jobs: target-determination: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/target_determination.yml - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -56,9 +56,6 @@ jobs: secrets: inherit linux-noble-rocm-py3_12-test: - permissions: - id-token: write - contents: read name: linux-noble-rocm-py3.12-mi300 uses: ./.github/workflows/_rocm-test.yml needs: diff --git a/.github/workflows/rocm-mi355.yml b/.github/workflows/rocm-mi355.yml index 5a77695011f3c..6fac9dc83d503 100644 --- a/.github/workflows/rocm-mi355.yml +++ b/.github/workflows/rocm-mi355.yml @@ -12,16 +12,16 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read + actions: read jobs: target-determination: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/target_determination.yml - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -54,9 +54,6 @@ jobs: secrets: inherit linux-noble-rocm-py3_12-test: - permissions: - id-token: write - contents: read name: linux-noble-rocm-py3.12-mi355 uses: ./.github/workflows/_rocm-test.yml needs: diff --git a/.github/workflows/rocm-navi31.yml b/.github/workflows/rocm-navi31.yml index b1dd74c3ba8d4..7452762a43bfb 100644 --- a/.github/workflows/rocm-navi31.yml +++ b/.github/workflows/rocm-navi31.yml @@ -15,16 +15,16 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read + actions: read jobs: target-determination: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/target_determination.yml - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -38,12 +38,12 @@ jobs: linux-jammy-rocm-py3_10-build: if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} - name: linux-jammy-rocm-py3.10 + name: linux-jammy-rocm-py3.10-navi31 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-rocm-py3.10 + build-environment: linux-jammy-rocm-py3.10-navi31 docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -54,10 +54,7 @@ jobs: secrets: inherit linux-jammy-rocm-py3_10-test: - permissions: - id-token: write - contents: read - name: linux-jammy-rocm-py3_10 + name: linux-jammy-rocm-py3.10-navi31 uses: ./.github/workflows/_rocm-test.yml needs: - linux-jammy-rocm-py3_10-build diff --git a/.github/workflows/rocm-nightly.yml b/.github/workflows/rocm-nightly.yml index 3a575f9b5607d..e17a002a22de4 100644 --- a/.github/workflows/rocm-nightly.yml +++ b/.github/workflows/rocm-nightly.yml @@ -2,19 +2,30 @@ name: rocm-nightly on: push: - branches: - - nightly # Trigger when nightly branch is updated tags: - ciflow/rocm-nightly/* - workflow_dispatch: # Allow manual triggering + workflow_dispatch: + schedule: + - cron: 0 0 * * * # midnight UTC concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true -permissions: read-all +permissions: + id-token: write + contents: read + actions: read jobs: + target-determination: + if: github.repository_owner == 'pytorch' + name: before-test + uses: ./.github/workflows/target_determination.yml + permissions: + id-token: write + contents: read + get-label-type: name: get-label-type uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main @@ -34,4 +45,28 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-noble-rocm-nightly-py3.12-gfx942 docker-image-name: ci-image:pytorch-linux-noble-rocm-nightly-py3 + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 2, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 3, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 4, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 5, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + { config: "default", shard: 6, num_shards: 6, runner: "linux.rocm.gpu.gfx942.1" }, + ]} + secrets: inherit + + linux-noble-rocm-nightly-py3_12-test: + permissions: + id-token: write + contents: read + name: linux-noble-rocm-nightly-py3.12-gfx942 + uses: ./.github/workflows/_rocm-test.yml + needs: + - linux-noble-rocm-nightly-py3_12-build + - target-determination + with: + build-environment: ${{ needs.linux-noble-rocm-nightly-py3_12-build.outputs.build-environment }} + docker-image: ${{ needs.linux-noble-rocm-nightly-py3_12-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-noble-rocm-nightly-py3_12-build.outputs.test-matrix }} secrets: inherit diff --git a/.github/workflows/runner-determinator-validator.yml b/.github/workflows/runner-determinator-validator.yml deleted file mode 100644 index 0710229a7b9ff..0000000000000 --- a/.github/workflows/runner-determinator-validator.yml +++ /dev/null @@ -1,41 +0,0 @@ -name: Validate Runner Determinator Script is in Sync - -on: - # Run on PRs when the runner-determinator script is updated to ensure it's copies are kept in sync - pull_request: - paths: - - .github/workflows/_runner-determinator.yml - - .github/workflows/runner-determinator-validator.yml - - .github/scripts/runner_determinator.py - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} - cancel-in-progress: true - -jobs: - check-runner-determinator: - if: github.repository_owner == 'pytorch' - runs-on: ubuntu-latest - - steps: - - name: Checkout repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Run Hardcode runner-determinator script - id: hardcode-script - run: | - # Extract the script content from _runner-determinator.yml and skip the first 10 spaces of each line - script_content=$(awk '/cat < runner_determinator.py/{flag=1;next}/EOF$/{flag=0}flag{print substr($0, 11)}' .github/workflows/_runner-determinator.yml) - - # Write the extracted script content to runner_determinator.py - echo "$script_content" > runner_determinator_workflow.py - - - name: Compare runner-determinator script embedded in workflow with checked in script - run: | - # Compare the extracted runner_determinator script with the existing one - # If this check fails, then make sure the contents of .github/scripts/runner_determinator.py is in sync with the - # version embedded into .github/workflows/_runner-determinator.yml - diff runner_determinator_workflow.py .github/scripts/runner_determinator.py - # Fail the job if the scripts are not identical - continue-on-error: false diff --git a/.github/workflows/runner_determinator_script_sync.yaml b/.github/workflows/runner_determinator_script_sync.yaml deleted file mode 100644 index a5f52f6980f7c..0000000000000 --- a/.github/workflows/runner_determinator_script_sync.yaml +++ /dev/null @@ -1,43 +0,0 @@ -name: runner-determinator - -on: - workflow_dispatch: - pull_request: - branches: [main] - paths: - - .github/workflows/_runner-determinator.yaml - - .github/workflows/_runner_determinator_script_sync.yaml - - .github/workflows/scripts/runner_determinator.py - -jobs: - python-script-sync-check: - if: github.repository_owner == 'pytorch' - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - sparse-checkout: | - .github - - - name: Extract the script from runner_determinator - run: | - # Runner determinator files - RUNNER_DETERMINATOR_WORKFLOW_FILE=.github/workflows/_runner-determinator.yml - RUNNER_DETERMINATOR_PYTHON_SCRIPT_FILE=.github/scripts/runner_determinator.py - - # Parse the job file, extract the script and run it, up to the final EOF, - # to generate the python file in the local folder - yq '.jobs.runner-determinator.steps[] | select(.id == "hardcode-script") | .run' \ - "${RUNNER_DETERMINATOR_WORKFLOW_FILE}" | sed '/^EOF$/q' | bash - - set +e - DIFF="$(diff "$(basename ${RUNNER_DETERMINATOR_PYTHON_SCRIPT_FILE})" ${RUNNER_DETERMINATOR_PYTHON_SCRIPT_FILE})" - IS_DIFF=$? - set -e - if [ $IS_DIFF -eq 0 ]; then - echo "Scripts are in sync! ^_^"; - else - echo -e "Scripts are *NOT* in sync:\n ${DIFF}"; - exit 1 - fi diff --git a/.github/workflows/s390x-periodic.yml b/.github/workflows/s390x-periodic.yml index bf94ee3bb00f1..2b4343fb37040 100644 --- a/.github/workflows/s390x-periodic.yml +++ b/.github/workflows/s390x-periodic.yml @@ -1,13 +1,12 @@ name: s390x-periodic +# DISABLED: this workflow has been failing for months and not providing a useful +# signal. +# Details in the PR description for https://github.com/pytorch/pytorch/pull/181005 +# Re-enable once workflow has been fully fixed and someone is willing to maintain it. on: - schedule: - # We have several schedules so jobs can check github.event.schedule to activate only for a fraction of the runs. - # Also run less frequently on weekends. - - cron: 29 8 * * * # about 1:29am PDT, for mem leak check and rerun disabled tests push: tags: - - ciflow/periodic/* - ciflow/s390/* workflow_dispatch: @@ -18,23 +17,18 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: llm-td: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/llm_td_retrieval.yml - permissions: - id-token: write - contents: read target-determination: name: before-test uses: ./.github/workflows/target_determination.yml needs: llm-td - permissions: - id-token: write - contents: read linux-manylinux-2_28-py3-cpu-s390x-build: if: github.repository_owner == 'pytorch' @@ -60,9 +54,6 @@ jobs: secrets: inherit linux-manylinux-2_28-py3-cpu-s390x-test: - permissions: - id-token: write - contents: read name: linux-manylinux-2_28-py3-cpu-s390x uses: ./.github/workflows/_linux-test.yml needs: diff --git a/.github/workflows/slow-rocm-mi200.yml b/.github/workflows/slow-rocm-mi200.yml index f1e3bcc00b2b5..dfc0f70a7c598 100644 --- a/.github/workflows/slow-rocm-mi200.yml +++ b/.github/workflows/slow-rocm-mi200.yml @@ -21,23 +21,18 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: llm-td: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/llm_td_retrieval.yml - permissions: - id-token: write - contents: read target-determination: name: before-test uses: ./.github/workflows/target_determination.yml needs: llm-td - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -50,12 +45,12 @@ jobs: curr_ref_type: ${{ github.ref_type }} linux-jammy-rocm-py3_10-build: - name: linux-jammy-rocm-py3.10 + name: linux-jammy-rocm-py3.10-mi200 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-rocm-py3.10 + build-environment: linux-jammy-rocm-py3.10-mi200 docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 sync-tag: rocm-build test-matrix: | @@ -66,10 +61,7 @@ jobs: secrets: inherit linux-jammy-rocm-py3_10-test: - permissions: - id-token: write - contents: read - name: linux-jammy-rocm-py3.10 + name: linux-jammy-rocm-py3.10-mi200 uses: ./.github/workflows/_rocm-test.yml needs: - linux-jammy-rocm-py3_10-build diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 8da9c9bd219d5..ff1fda43d39ba 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -21,23 +21,18 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: llm-td: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/llm_td_retrieval.yml - permissions: - id-token: write - contents: read target-determination: name: before-test uses: ./.github/workflows/target_determination.yml needs: llm-td - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -107,14 +102,14 @@ jobs: test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-sm86-build.outputs.test-matrix }} secrets: inherit - linux-jammy-py3_10-clang15-build: - name: linux-jammy-py3.10-clang15 + linux-jammy-py3_10-clang18-build: + name: linux-jammy-py3.10-clang18 uses: ./.github/workflows/_linux-build.yml needs: get-label-type with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3.10-clang15 - docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang15 + build-environment: linux-jammy-py3.10-clang18 + docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang18 test-matrix: | { include: [ { config: "slow", shard: 1, num_shards: 2, runner: "linux.2xlarge" }, @@ -122,16 +117,16 @@ jobs: ]} secrets: inherit - linux-jammy-py3_10-clang15-test: - name: linux-jammy-py3.10-clang15 + linux-jammy-py3_10-clang18-test: + name: linux-jammy-py3.10-clang18 uses: ./.github/workflows/_linux-test.yml needs: - - linux-jammy-py3_10-clang15-build + - linux-jammy-py3_10-clang18-build - target-determination with: - build-environment: ${{ needs.linux-jammy-py3_10-clang15-build.outputs.build-environment }} - docker-image: ${{ needs.linux-jammy-py3_10-clang15-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3_10-clang15-build.outputs.test-matrix }} + build-environment: ${{ needs.linux-jammy-py3_10-clang18-build.outputs.build-environment }} + docker-image: ${{ needs.linux-jammy-py3_10-clang18-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_10-clang18-build.outputs.test-matrix }} secrets: inherit linux-jammy-py3_10-clang18-asan-build: @@ -142,7 +137,7 @@ jobs: runner: linux.c7i.4xlarge runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-py3.10-clang18-asan - docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-asan + docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang18 test-matrix: | { include: [ { config: "slow", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, diff --git a/.github/workflows/target-determination-indexer.yml b/.github/workflows/target-determination-indexer.yml index 3438b1dd5ac57..2121d06a1cde3 100644 --- a/.github/workflows/target-determination-indexer.yml +++ b/.github/workflows/target-determination-indexer.yml @@ -25,23 +25,18 @@ jobs: runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" # 1 GPU A10G 24GB each environment: target-determinator-env steps: - - name: Clone PyTorch - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - path: pytorch - - name: Setup Linux - uses: ./pytorch/.github/actions/setup-linux + uses: pytorch/pytorch/.github/actions/setup-linux@main - name: Login to ECR - uses: ./pytorch/.github/actions/ecr-login + uses: ./.github/actions/ecr-login - name: Calculate docker image id: calculate-docker-image uses: pytorch/test-infra/.github/actions/calculate-docker-image@main with: docker-image-name: ci-image:pytorch-linux-jammy-cuda12.8-cudnn9-py3-gcc11 - working-directory: pytorch + working-directory: . - name: Use following to pull public copy of the image id: print-ghcr-mirror @@ -105,6 +100,9 @@ jobs: # detached container should get cleaned up by teardown_ec2_linux # Disable shellcheck warning for GPU_FLAG # shellcheck disable=SC2086 + # setup-linux checks out pytorch directly at GITHUB_WORKSPACE, but + # llm-target-determinator@v0.0.2 expects `pytorch/` as a sibling of + # its own checkout, so remap the layout inside the container. container_name=$(docker run \ ${GPU_FLAG:-} \ -e MAX_JOBS="$(nproc --ignore=2)" \ @@ -115,11 +113,13 @@ jobs: --tty \ --detach \ --user jenkins \ - -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace" \ + -v "${GITHUB_WORKSPACE}:/var/lib/jenkins/workspace/pytorch" \ + -v "${GITHUB_WORKSPACE}/codellama:/var/lib/jenkins/workspace/codellama" \ + -v "${GITHUB_WORKSPACE}/llm-target-determinator:/var/lib/jenkins/workspace/llm-target-determinator" \ -w /var/lib/jenkins/workspace \ "${DOCKER_IMAGE}" ) - chmod +x pytorch/.github/scripts/td_llm_indexer.sh + chmod +x .github/scripts/td_llm_indexer.sh docker exec -t "${container_name}" sh -c 'pytorch/.github/scripts/td_llm_indexer.sh' - name: Upload to s3 diff --git a/.github/workflows/target_determination.yml b/.github/workflows/target_determination.yml index c712b11185a76..9594ce0bf8f63 100644 --- a/.github/workflows/target_determination.yml +++ b/.github/workflows/target_determination.yml @@ -22,18 +22,11 @@ jobs: runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" needs: get-label-type steps: - # [pytorch repo ref] - # Use a pytorch/pytorch reference instead of a reference to the local - # checkout because when we run this action we don't *have* a local - # checkout. In other cases you should prefer a local checkout. - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + - name: Setup Linux + uses: pytorch/pytorch/.github/actions/setup-linux@main with: submodules: false - - name: Setup Linux - uses: ./.github/actions/setup-linux - - name: Get workflow job id id: get-job-id uses: ./.github/actions/get-workflow-job-id diff --git a/.github/workflows/test-b200.yml b/.github/workflows/test-b200.yml index 19dcb07c29844..e753dd27ca938 100644 --- a/.github/workflows/test-b200.yml +++ b/.github/workflows/test-b200.yml @@ -35,6 +35,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: diff --git a/.github/workflows/test-h100.yml b/.github/workflows/test-h100.yml index 4351b427b0b8a..c1e8c13797da7 100644 --- a/.github/workflows/test-h100.yml +++ b/.github/workflows/test-h100.yml @@ -22,6 +22,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index 2b370f6083185..f14e66576fd96 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -13,6 +13,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: get-default-label-prefix: diff --git a/.github/workflows/torchtitan.yml b/.github/workflows/torchtitan.yml index 92bda75958e54..b12f980ec4202 100644 --- a/.github/workflows/torchtitan.yml +++ b/.github/workflows/torchtitan.yml @@ -16,6 +16,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: build: diff --git a/.github/workflows/trunk-rocm-sandbox.yml b/.github/workflows/trunk-rocm-sandbox.yml index ddeb4e5809aec..bef62c9fbe5bc 100644 --- a/.github/workflows/trunk-rocm-sandbox.yml +++ b/.github/workflows/trunk-rocm-sandbox.yml @@ -16,23 +16,18 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: llm-td: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/llm_td_retrieval.yml - permissions: - id-token: write - contents: read target-determination: name: before-test uses: ./.github/workflows/target_determination.yml needs: llm-td - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -70,9 +65,6 @@ jobs: secrets: inherit linux-jammy-rocm-py3_10-test: - permissions: - id-token: write - contents: read name: linux-jammy-rocm-py3.10 uses: ./.github/workflows/_rocm-test.yml needs: diff --git a/.github/workflows/trunk-tagging.yml b/.github/workflows/trunk-tagging.yml index d96f2de8366aa..27472f786794c 100644 --- a/.github/workflows/trunk-tagging.yml +++ b/.github/workflows/trunk-tagging.yml @@ -28,8 +28,8 @@ jobs: - name: Pre-checkout validation run: | # For workflow_dispatch, validate SHA format before checkout - if [ -n "${{ github.event.inputs.commit_sha }}" ]; then - COMMIT_SHA="${{ github.event.inputs.commit_sha }}" + if [ -n "${INPUT_COMMIT_SHA}" ]; then + COMMIT_SHA="${INPUT_COMMIT_SHA}" # Verify it's a well-formed SHA (40 hex characters) if ! echo "${COMMIT_SHA}" | grep -qE '^[a-f0-9]{40}$'; then @@ -42,6 +42,8 @@ jobs: echo "✅ Using current commit SHA - no pre-checkout validation needed" fi + env: + INPUT_COMMIT_SHA: ${{ github.event.inputs.commit_sha }} - name: Checkout repository uses: actions/checkout@v4 with: @@ -53,8 +55,8 @@ jobs: - name: Set commit SHA id: commit run: | - if [ -n "${{ github.event.inputs.commit_sha }}" ]; then - COMMIT_SHA="${{ github.event.inputs.commit_sha }}" + if [ -n "${INPUT_COMMIT_SHA}" ]; then + COMMIT_SHA="${INPUT_COMMIT_SHA}" else COMMIT_SHA="${{ github.sha }}" fi @@ -63,6 +65,8 @@ jobs: echo "tag_name=trunk/${COMMIT_SHA}" } >> "${GITHUB_OUTPUT}" + env: + INPUT_COMMIT_SHA: ${{ github.event.inputs.commit_sha }} - name: Validate commit SHA run: | COMMIT_SHA="${{ steps.commit.outputs.sha }}" @@ -74,7 +78,7 @@ jobs: fi # For workflow_dispatch, verify the commit exists on main branch - if [ -n "${{ github.event.inputs.commit_sha }}" ]; then + if [ -n "${INPUT_COMMIT_SHA}" ]; then echo "Manual dispatch detected - validating commit is on main branch..." # Get all commits reachable from main branch @@ -89,6 +93,8 @@ jobs: echo "✅ Commit ${COMMIT_SHA} is valid (automatic push trigger)" fi + env: + INPUT_COMMIT_SHA: ${{ github.event.inputs.commit_sha }} - name: Create and push tag(s) with retry id: check_tag env: @@ -309,7 +315,10 @@ jobs: echo " Name: ${{ steps.commit.outputs.tag_name }}" echo " Commit: ${{ steps.commit.outputs.sha }}" echo " Trigger: ${{ github.event_name }}" - if [ -n "${{ github.event.inputs.commit_sha }}" ]; then - echo " Manual commit: ${{ github.event.inputs.commit_sha }}" + if [ -n "${INPUT_COMMIT_SHA}" ]; then + echo " Manual commit: ${INPUT_COMMIT_SHA}" fi fi + + env: + INPUT_COMMIT_SHA: ${{ github.event.inputs.commit_sha }} \ No newline at end of file diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index e3bf1d63613dc..f992c2adc8ff0 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -30,6 +30,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: # See job-filter.yml for rules on adding job filter conditions @@ -44,17 +45,11 @@ jobs: if: github.repository_owner == 'pytorch' name: before-test uses: ./.github/workflows/llm_td_retrieval.yml - permissions: - id-token: write - contents: read target-determination: name: before-test uses: ./.github/workflows/target_determination.yml needs: llm-td - permissions: - id-token: write - contents: read get-label-type: name: get-label-type @@ -65,6 +60,7 @@ jobs: issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} + check_experiments: arc,lf libtorch-linux-jammy-cuda12_8-py3_10-gcc11-debug-build: if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' libtorch-linux-jammy-cuda12.8-py3.10-gcc11-debug ') }} @@ -83,10 +79,14 @@ jobs: { include: [ { config: "default", shard: 1, num_shards: 1 }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc11 + cuda-version: "12.8" secrets: inherit linux-jammy-cuda12_8-py3_10-gcc11-build: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-cuda12.8-py3.10-gcc11 ') || contains(needs.job-filter.outputs.jobs, ' cross-compile-linux-test ') }} + if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-cuda12.8-py3.10-gcc11 ') }} name: linux-jammy-cuda12.8-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml needs: @@ -107,9 +107,12 @@ jobs: { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, - { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, - { config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, + { config: "libtorch_agnostic_targetting", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.metal.nvidia.gpu" }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc11 + cuda-version: "12.8" secrets: inherit # CUDA 12.8 GPU tests moved to periodic.yml to reduce per-commit compute. @@ -118,7 +121,7 @@ jobs: # See P2188981399 for the full CI workflow analysis. linux-jammy-cuda13_0-py3_10-gcc11-build: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-cuda13.0-py3.10-gcc11 ') }} + if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-cuda13.0-py3.10-gcc11 ') || contains(needs.job-filter.outputs.jobs, ' cross-compile-linux-test-cuda13 ') }} name: linux-jammy-cuda13.0-py3.10-gcc11 uses: ./.github/workflows/_linux-build.yml needs: @@ -139,8 +142,14 @@ jobs: { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, - { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "linux.g4dn.metal.nvidia.gpu" }, + { config: "pr_time_benchmarks", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.metal.nvidia.gpu" }, + # Test cross-compiled models with Windows libs extracted from wheel (CUDA 13.0) + { config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda13.0-py3" }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc11 + cuda-version: "13.0" secrets: inherit linux-jammy-cuda13_0-py3_10-gcc11-test: @@ -149,14 +158,20 @@ jobs: uses: ./.github/workflows/_linux-test.yml needs: - linux-jammy-cuda13_0-py3_10-gcc11-build + - win-vs2022-cuda13_0-py3-build - target-determination - job-filter + - get-label-type with: timeout-minutes: 360 build-environment: linux-jammy-cuda13.0-py3.10-gcc11 docker-image: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.docker-image }} test-matrix: ${{ needs.linux-jammy-cuda13_0-py3_10-gcc11-build.outputs.test-matrix }} tests-to-include: ${{ github.event.inputs.tests-to-include || '' }} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc11 + cuda-version: "13.0" secrets: inherit # no-ops builds test USE_PER_OPERATOR_HEADERS=0 where ATen/ops is not generated @@ -177,6 +192,10 @@ jobs: { include: [ { config: "default", shard: 1, num_shards: 1 }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc11 + cuda-version: "13.0" secrets: inherit macos-py3-arm64-build: @@ -218,6 +237,8 @@ jobs: disable-monitor: false secrets: inherit + # NB: Windows runners are not available in OSDC, so we run this on Meta account + # for now until we get OSDC deployed to LF account win-vs2022-cpu-py3-build: if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' win-vs2022-cpu-py3 ') }} name: win-vs2022-cpu-py3 @@ -228,14 +249,14 @@ jobs: with: build-environment: win-vs2022-cpu-py3 cuda-version: cpu - runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + runner: windows.4xlarge.nonephemeral test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, - { config: "openreg", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" }, + { config: "default", shard: 1, num_shards: 4, runner: "windows.4xlarge.nonephemeral" }, + { config: "default", shard: 2, num_shards: 4, runner: "windows.4xlarge.nonephemeral" }, + { config: "default", shard: 3, num_shards: 4, runner: "windows.4xlarge.nonephemeral" }, + { config: "default", shard: 4, num_shards: 4, runner: "windows.4xlarge.nonephemeral" }, + { config: "openreg", shard: 1, num_shards: 1, runner: "windows.4xlarge.nonephemeral" }, ]} secrets: inherit @@ -254,30 +275,33 @@ jobs: disable-monitor: false secrets: inherit - win-vs2022-cuda12_8-py3-build: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' win-vs2022-cuda12.8-py3 ') || contains(needs.job-filter.outputs.jobs, ' cross-compile-linux-test ') }} - name: win-vs2022-cuda12.8-py3 + # NB: Windows runners are not available in OSDC, so we run this on Meta account + # for now until we get OSDC deployed to LF account + win-vs2022-cuda13_0-py3-build: + if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' win-vs2022-cuda13.0-py3 ') || contains(needs.job-filter.outputs.jobs, ' cross-compile-linux-test-cuda13 ') }} + name: win-vs2022-cuda13.0-py3 uses: ./.github/workflows/_win-build.yml needs: - get-label-type - job-filter with: - build-environment: win-vs2022-cuda12.8-py3 - cuda-version: "12.8" - runner: "${{ needs.get-label-type.outputs.label-type }}windows.4xlarge.nonephemeral" + build-environment: win-vs2022-cuda13.0-py3 + cuda-version: "13.0" + runner: windows.4xlarge.nonephemeral secrets: inherit linux-jammy-rocm-py3_10-build: - name: linux-jammy-rocm-py3.10 + name: linux-jammy-rocm-py3.10-mi355 uses: ./.github/workflows/_linux-build.yml needs: - get-label-type - job-filter with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-rocm-py3.10 + build-environment: linux-jammy-rocm-py3.10-mi355 docker-image-name: ci-image:pytorch-linux-jammy-rocm-n-py3 - sync-tag: rocm-build + # TODO (huydhn): Add this back once other workflow migrates + # sync-tag: rocm-build test-matrix: | { include: [ { config: "default", shard: 1, num_shards: 6, runner: "linux.rocm.gpu.gfx950.1" }, @@ -290,13 +314,12 @@ jobs: { config: "distributed", shard: 2, num_shards: 3, runner: "linux.rocm.gpu.gfx950.2" }, { config: "distributed", shard: 3, num_shards: 3, runner: "linux.rocm.gpu.gfx950.2" }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" secrets: inherit linux-jammy-rocm-py3_10-test: - permissions: - id-token: write - contents: read - name: linux-jammy-rocm-py3.10 + name: linux-jammy-rocm-py3.10-mi355 uses: ./.github/workflows/_rocm-test.yml needs: - linux-jammy-rocm-py3_10-build @@ -317,29 +340,14 @@ jobs: - get-label-type - job-filter with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-jammy-cuda13.0-py3.12-gcc11-sm80 docker-image-name: ci-image:pytorch-linux-jammy-cuda13.0-cudnn9-py3-gcc11-inductor-benchmarks cuda-arch-list: '8.0' - secrets: inherit - - # Test cross-compiled models with Windows libs extracted from wheel - cross-compile-linux-test: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' cross-compile-linux-test ') }} - name: cross-compile-linux-test - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-jammy-cuda12_8-py3_10-gcc11-build - - get-label-type - - win-vs2022-cuda12_8-py3-build - - job-filter - with: - build-environment: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.build-environment }} - docker-image: ${{ needs.linux-jammy-cuda12_8-py3_10-gcc11-build.outputs.docker-image }} - test-matrix: | - { include: [ - { config: "aoti_cross_compile_for_windows", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g6.4xlarge.experimental.nvidia.gpu", win_torch_wheel_artifact: "win-vs2022-cuda12.8-py3" }, - ]} - tests-to-include: ${{ github.event.inputs.tests-to-include || '' }} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.12" + compiler: gcc11 + cuda-version: "13.0" secrets: inherit verify-cachebench-cpu-build: @@ -357,6 +365,9 @@ jobs: { include: [ { config: "verify_cachebench", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc11 secrets: inherit verify-cachebench-cpu-test: @@ -367,16 +378,20 @@ jobs: - verify-cachebench-cpu-build - target-determination - job-filter + - get-label-type with: build-environment: ${{ needs.verify-cachebench-cpu-build.outputs.build-environment }} docker-image: ${{ needs.verify-cachebench-cpu-build.outputs.docker-image }} test-matrix: ${{ needs.verify-cachebench-cpu-build.outputs.test-matrix }} tests-to-include: ${{ github.event.inputs.tests-to-include || '' }} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc11 secrets: inherit - linux-jammy-py3-clang15-executorch-build: -# if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3-clang15-executorch ') }} - name: linux-jammy-py3-clang15-executorch + linux-jammy-py3-clang18-executorch-build: +# if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3-clang18-executorch ') }} + name: linux-jammy-py3-clang18-executorch uses: ./.github/workflows/_linux-build.yml needs: - get-label-type @@ -384,25 +399,25 @@ jobs: if: false # Has been broken for a while with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build-environment: linux-jammy-py3-clang15-executorch - docker-image-name: ci-image:pytorch-linux-jammy-py3-clang15-executorch + build-environment: linux-jammy-py3-clang18-executorch + docker-image-name: ci-image:pytorch-linux-jammy-py3-clang18-executorch test-matrix: | { include: [ { config: "executorch", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, ]} secrets: inherit - linux-jammy-py3-clang15-executorch-test: - if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3-clang15-executorch ') }} - name: linux-jammy-py3-clang15-executorch + linux-jammy-py3-clang18-executorch-test: + if: ${{ needs.job-filter.outputs.jobs == '' || contains(needs.job-filter.outputs.jobs, ' linux-jammy-py3-clang18-executorch ') }} + name: linux-jammy-py3-clang18-executorch uses: ./.github/workflows/_linux-test.yml needs: - - linux-jammy-py3-clang15-executorch-build + - linux-jammy-py3-clang18-executorch-build - job-filter with: - build-environment: ${{ needs.linux-jammy-py3-clang15-executorch-build.outputs.build-environment }} - docker-image: ${{ needs.linux-jammy-py3-clang15-executorch-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-jammy-py3-clang15-executorch-build.outputs.test-matrix }} + build-environment: ${{ needs.linux-jammy-py3-clang18-executorch-build.outputs.build-environment }} + docker-image: ${{ needs.linux-jammy-py3-clang18-executorch-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3-clang18-executorch-build.outputs.test-matrix }} tests-to-include: ${{ github.event.inputs.tests-to-include || '' }} secrets: inherit @@ -417,5 +432,46 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: linux.r7i.2xlarge build-environment: linux-jammy-py3.10-gcc11-full-debug-build-only - docker-image-name: ci-image:pytorch-linux-jammy-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-py3.10-clang18 + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc11 + secrets: inherit + + linux-jammy-aarch64-py3_10-gcc13-build: + name: linux-jammy-aarch64-py3.10 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: ${{ needs.get-label-type.outputs.label-type }} + build-environment: linux-jammy-aarch64-py3.10 + docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13 + runner: linux.arm64.m7g.4xlarge + # Periodic AArch64 tests for SVE256 coverage + test-matrix: | + { include: [ + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.arm64.m7g.4xlarge" }, + ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc13 + secrets: inherit + + linux-jammy-aarch64-py3_10-gcc13-test: + name: linux-jammy-aarch64-py3.10 + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-aarch64-py3_10-gcc13-build + - get-label-type + with: + build-environment: ${{ needs.linux-jammy-aarch64-py3_10-gcc13-build.outputs.build-environment }} + docker-image: ${{ needs.linux-jammy-aarch64-py3_10-gcc13-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-aarch64-py3_10-gcc13-build.outputs.test-matrix }} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.10" + compiler: gcc13 secrets: inherit diff --git a/.github/workflows/tsan.yml b/.github/workflows/tsan.yml new file mode 100644 index 0000000000000..e4e1d381a6235 --- /dev/null +++ b/.github/workflows/tsan.yml @@ -0,0 +1,60 @@ +name: TSan + +on: + push: + tags: + - ciflow/tsan/* + - ciflow/trunk/* + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +permissions: + id-token: write + contents: read + actions: read + +jobs: + get-label-type: + if: github.repository_owner == 'pytorch' + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + + linux-jammy-py3_14t-clang18-tsan-build: + name: linux-jammy-py3.14t-clang18-tsan + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-jammy-py3.14t-clang18-tsan + docker-image-name: ci-image:pytorch-linux-jammy-py3.14t-clang18 + test-matrix: | + { include: [ + { config: "tsan", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + ]} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.14t" + compiler: clang18 + secrets: inherit + + linux-jammy-py3_14t-clang18-tsan-test: + name: linux-jammy-py3.14t-clang18-tsan + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-jammy-py3_14t-clang18-tsan-build + - get-label-type + with: + build-environment: linux-jammy-py3.14t-clang18-tsan + docker-image: ${{ needs.linux-jammy-py3_14t-clang18-tsan-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-jammy-py3_14t-clang18-tsan-build.outputs.test-matrix }} + use-arc: ${{ needs.get-label-type.outputs.use-arc == 'true' }} + python-version: "3.14t" + compiler: clang18 + secrets: inherit diff --git a/.github/workflows/update-viablestrict.yml b/.github/workflows/update-viablestrict.yml index 1b4af0f274913..439a2d8c6457e 100644 --- a/.github/workflows/update-viablestrict.yml +++ b/.github/workflows/update-viablestrict.yml @@ -23,7 +23,7 @@ jobs: with: repository: pytorch/pytorch stable-branch: viable/strict - requires: '[\"pull\", \"trunk\", \"lint\", \"linux-aarch64\"]' + requires: '[\"pull\", \"trunk\", \"lint\"]' secret-bot-token: ${{ secrets.MERGEBOT_TOKEN }} clickhouse-url: ${{ secrets.CLICKHOUSE_URL }} clickhouse-username: ${{ secrets.CLICKHOUSE_VIABLESTRICT_USERNAME }} diff --git a/.github/workflows/upload-test-stats-while-running.yml b/.github/workflows/upload-test-stats-while-running.yml index 9aecaad0e068f..7771278448a49 100644 --- a/.github/workflows/upload-test-stats-while-running.yml +++ b/.github/workflows/upload-test-stats-while-running.yml @@ -15,15 +15,11 @@ jobs: name: Upload test stats while running runs-on: linux.2xlarge steps: - - name: Checkout PyTorch - uses: pytorch/pytorch/.github/actions/checkout-pytorch@main + - name: Setup Linux + uses: pytorch/pytorch/.github/actions/setup-linux@main with: - fetch-depth: 1 submodules: false - - name: Setup Linux - uses: ./.github/actions/setup-linux - - name: Install requirements run: | python3 -m pip install requests==2.32.2 boto3==1.35.42 diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index 5dcbfe7fd65fa..8d274071662e1 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -20,6 +20,7 @@ on: - rocm-mi300 - rocm-mi355 - rocm-navi31 + - rocm-nightly - inductor-micro-benchmark - inductor-micro-benchmark-x86 - inductor-cu124 @@ -29,7 +30,6 @@ on: - inductor-perf-nightly-rocm-mi300 - inductor-perf-nightly-rocm-mi355 - mac-mps - - linux-aarch64 types: - completed diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index eb1c78019b10e..f65e73042c245 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -16,6 +16,7 @@ concurrency: permissions: id-token: write contents: read + actions: read jobs: build: diff --git a/.github/workflows/win-arm64-build-test.yml b/.github/workflows/win-arm64-build-test.yml index 95b4e2f027f60..5c31628b96871 100644 --- a/.github/workflows/win-arm64-build-test.yml +++ b/.github/workflows/win-arm64-build-test.yml @@ -31,7 +31,7 @@ jobs: steps: - name: configure aws credentials id: aws_creds - uses: aws-actions/configure-aws-credentials@v4 + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4 with: role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_sscache aws-region: us-east-1 diff --git a/.github/workflows/xpu.yml b/.github/workflows/xpu.yml index fbfa9d62cc571..8e089db1f9cee 100644 --- a/.github/workflows/xpu.yml +++ b/.github/workflows/xpu.yml @@ -14,6 +14,11 @@ concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} cancel-in-progress: true +permissions: + id-token: write + contents: read + actions: read + jobs: get-label-type: @@ -78,9 +83,6 @@ jobs: name: linux-noble-xpu-n-py3.10 uses: ./.github/workflows/_xpu-test.yml needs: linux-noble-xpu-n-py3_10-build - permissions: - id-token: write - contents: read with: build-environment: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.build-environment }} docker-image: ${{ needs.linux-noble-xpu-n-py3_10-build.outputs.docker-image }} @@ -107,9 +109,6 @@ jobs: name: linux-noble-xpu-n-py3.10-client uses: ./.github/workflows/_xpu-test.yml needs: linux-noble-xpu-n-py3_10-client-build - permissions: - id-token: write - contents: read with: build-environment: ${{ needs.linux-noble-xpu-n-py3_10-client-build.outputs.build-environment }} docker-image: ${{ needs.linux-noble-xpu-n-py3_10-client-build.outputs.docker-image }} diff --git a/.gitignore b/.gitignore index 54e3fa32b2fdb..5105977ae52c0 100644 --- a/.gitignore +++ b/.gitignore @@ -8,8 +8,7 @@ ## PyTorch -.coverage -coverage.xml +agent_space/ .dmypy.json .gradle .hypothesis @@ -40,7 +39,6 @@ docs/build/ docs/cpp/src docs/src/**/* docs/cpp/build -docs/cpp/source/api docs/cpp/source/html/ docs/cpp/source/latex/ docs/source/compile/generated/ @@ -52,7 +50,6 @@ usage_log* test-reports/ test/*.bak test/**/*.bak -test/.coverage test/.hypothesis/ test/cpp/api/mnist test/custom_operator/model.pt @@ -367,7 +364,13 @@ xla/ pr.diff # coverage files -*/**/.coverage.* +.coverage.* +.coverage +coverage.xml +*.profdata +*.profraw +coverage-html +coverage.lcov # buck generated files .buckd/ diff --git a/.lintrunner.toml b/.lintrunner.toml index 5c42dc81c44d8..cd4f0a3dce756 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -191,6 +191,11 @@ exclude_patterns = [ 'aten/src/ATen/cpu/FlushDenormal.cpp', 'aten/src/ATen/cpu/vml.h', 'aten/src/ATen/CPUFixedAllocator.h', + # interned_strings.h includes the generated aten_interned_strings.h which + # #errors under TORCH_ASSERT_ONLY_METHOD_OPERATORS. The .cpp file #undefs + # the macro before including the header, but clang-tidy lints the .h + # standalone using the .cpp's compile flags, so the #error always fires. + 'aten/src/ATen/core/interned_strings.h', 'aten/src/ATen/Parallel*.h', 'c10/xpu/**/*.h', 'c10/xpu/**/*.cpp', @@ -532,6 +537,8 @@ exclude_patterns = [ 'torch/csrc/inductor/aoti_runtime/**', # Test files use EXPECT_THROW which is a gtest macro 'test/cpp/**/*.cpp', + # JIT frontend uses throw(ErrorReport(...)) pervasively + 'torch/csrc/jit/frontend/schema_type_parser.cpp', ] command = [ 'python3', @@ -1496,27 +1503,6 @@ command = [ '@{{PATHSFILE}}' ] -[[linter]] -code = 'BAZEL_LINTER' -include_patterns = ['WORKSPACE'] -command = [ - 'python3', - 'tools/linter/adapters/bazel_linter.py', - '--binary=.lintbin/bazel', - '--', - '@{{PATHSFILE}}' -] -init_command = [ - 'python3', - 'tools/linter/adapters/s3_init.py', - '--config-json=tools/linter/adapters/s3_init_config.json', - '--linter=bazel', - '--dry-run={{DRYRUN}}', - '--output-dir=.lintbin', - '--output-name=bazel', -] -is_formatter = true - [[linter]] code = 'LINTRUNNER_VERSION' include_patterns = ['**'] diff --git a/.spin/cmds.py b/.spin/cmds.py index f31acfa977424..f1aceb77edd25 100644 --- a/.spin/cmds.py +++ b/.spin/cmds.py @@ -8,6 +8,11 @@ import spin +CWD = Path(__file__).absolute().parent.parent +sys.path.insert(0, str(CWD)) # this only affects the current process +from tools.clean import clean as _clean + + def file_digest(file, algorithm: str): try: return hashlib.file_digest(file, algorithm) @@ -133,7 +138,6 @@ def regenerate_clangtidy_files(): #: These linters are expected to need less than 3s cpu time total VERY_FAST_LINTERS = { "ATEN_CPU_GPU_AGNOSTIC", - "BAZEL_LINTER", "C10_NODISCARD", "C10_UNUSED", "CALL_ONCE", @@ -441,6 +445,12 @@ def quickfix(ctx, *, lintrunner_args, **kwargs): ctx.invoke(quicklint, apply_patches=True) +@click.command() +def clean(): + """Clean, that is remove all files in .gitignore except in the NOT-CLEAN-FILES section.""" + _clean() + + @click.command() def regenerate_github_workflows(): """Regenerate GitHub workflows from templates.""" diff --git a/AGENTS.md b/AGENTS.md deleted file mode 100644 index 718217d3e663d..0000000000000 --- a/AGENTS.md +++ /dev/null @@ -1,18 +0,0 @@ -- This is the only AGENTS.md, there are no recursive AGENTS.md -- When you are working on a bug, first create a standalone file that - reproduces the bug and verify it fails in the expected way. Use this to - test if your changes work. Once the change is passing, find an appropriate - test file to add the test to and make sure to follow local conventions on - the test file. -- If you are running the real test suite, DO NOT run the entire test suite. - Instead run only a single test case, e.g., 'python test/test_torch.py TestTorch.test_dir' -- Do NOT run setup.py, you do not have a working build environment -- Do NOT run pre-commit, it is not setup -- To run lint, run 'lintrunner -a' (which will autoapply changes) -- Do NOT attempt to install dependencies, you do not have Internet access -- Do NOT create summary files unless explicitly asked -- When you are ready to make a PR, do exactly these steps: - - git stash -u - - git reset --hard $(cat /tmp/orig_work.txt) # NB: reset to the LOCAL branch, do NOT fetch - - git stash pop - - Resolve conflicts if necessary diff --git a/AGENTS.md b/AGENTS.md new file mode 120000 index 0000000000000..681311eb9cf45 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1 @@ +CLAUDE.md \ No newline at end of file diff --git a/BUILD.bazel b/BUILD.bazel deleted file mode 100644 index 4737a2a0c486c..0000000000000 --- a/BUILD.bazel +++ /dev/null @@ -1,1104 +0,0 @@ -load("@bazel_skylib//lib:paths.bzl", "paths") -load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") -load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") -load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test") -load("@rules_python//python:defs.bzl", "py_library", "py_test") -load("@pytorch//third_party:substitution.bzl", "header_template_rule", "template_rule") -load("@pytorch//:tools/bazel.bzl", "rules") -load("@pytorch//tools/rules:cu.bzl", "cu_library") -load("@pytorch//tools/config:defs.bzl", "if_cuda") -load("@pytorch//:aten.bzl", "generate_aten", "intern_build_aten_ops") -load(":build.bzl", "GENERATED_AUTOGRAD_CPP", "GENERATED_AUTOGRAD_PYTHON", "define_targets") -load(":build_variables.bzl", "jit_core_sources", "lazy_tensor_ts_sources", "libtorch_core_sources", "libtorch_cuda_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "libtorch_python_core_sources", "torch_cpp_srcs", "libtorch_python_cuda_sources", "libtorch_python_distributed_sources") -load(":ufunc_defs.bzl", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources") -load("//:tools/bazel.bzl", "rules") - -# Export files for use by torch/headeronly (where version.h generation now lives) -exports_files(["version.txt"]) - -define_targets(rules = rules) - -COMMON_COPTS = [ - "-DHAVE_MALLOC_USABLE_SIZE=1", - "-DHAVE_MMAP=1", - "-DHAVE_SHM_OPEN=1", - "-DHAVE_SHM_UNLINK=1", - "-D_FILE_OFFSET_BITS=64", - "-DUSE_FBGEMM", - "-DUSE_DISTRIBUTED", - "-DAT_PER_OPERATOR_HEADERS", - "-DATEN_THREADING=NATIVE", - "-DNO_CUDNN_DESTROY_HANDLE", -] + if_cuda([ - "-DUSE_CUDA", - "-DUSE_CUDNN", - # TODO: This should be passed only when building for CUDA-11.5 or newer - # use cub in a safe manner, see: - # https://github.com/pytorch/pytorch/pull/55292 - "-DCUB_WRAPPED_NAMESPACE=at_cuda_detail", -]) - -aten_generation_srcs = ["aten/src/ATen/native/native_functions.yaml"] + ["aten/src/ATen/native/tags.yaml"] + glob(["aten/src/ATen/templates/**"]) - -generated_cpu_cpp = [ - "aten/src/ATen/RegisterBackendSelect.cpp", - "aten/src/ATen/RegisterCPU_0.cpp", - "aten/src/ATen/RegisterCPU_1.cpp", - "aten/src/ATen/RegisterCPU_2.cpp", - "aten/src/ATen/RegisterCPU_3.cpp", - "aten/src/ATen/RegisterFunctionalization_0.cpp", - "aten/src/ATen/RegisterFunctionalization_1.cpp", - "aten/src/ATen/RegisterFunctionalization_2.cpp", - "aten/src/ATen/RegisterFunctionalization_3.cpp", - # "aten/src/ATen/RegisterFunctionalizationEverything.cpp", - "aten/src/ATen/RegisterMkldnnCPU_0.cpp", - "aten/src/ATen/RegisterNestedTensorCPU_0.cpp", - "aten/src/ATen/RegisterQuantizedCPU_0.cpp", - "aten/src/ATen/RegisterSparseCPU_0.cpp", - "aten/src/ATen/RegisterSparseCsrCPU_0.cpp", - "aten/src/ATen/RegisterZeroTensor_0.cpp", - "aten/src/ATen/RegisterCompositeImplicitAutograd_0.cpp", - "aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor_0.cpp", - "aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp", - "aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp", - "aten/src/ATen/RegisterMeta_0.cpp", - "aten/src/ATen/RegisterSparseMeta_0.cpp", - "aten/src/ATen/RegisterQuantizedMeta_0.cpp", - "aten/src/ATen/RegisterNestedTensorMeta_0.cpp", - "aten/src/ATen/RegisterSchema.cpp", - "aten/src/ATen/CPUFunctions.h", - "aten/src/ATen/CPUFunctions_inl.h", - "aten/src/ATen/CompositeExplicitAutogradFunctions.h", - "aten/src/ATen/CompositeExplicitAutogradFunctions_inl.h", - "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h", - "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h", - "aten/src/ATen/CompositeImplicitAutogradFunctions.h", - "aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h", - "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions.h", - "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h", - "aten/src/ATen/CompositeViewCopyKernels.cpp", - "aten/src/ATen/FunctionalInverses.h", - "aten/src/ATen/Functions.h", - "aten/src/ATen/Functions.cpp", - "aten/src/ATen/RedispatchFunctions.h", - "aten/src/ATen/Operators.h", - "aten/src/ATen/Operators_0.cpp", - "aten/src/ATen/Operators_1.cpp", - "aten/src/ATen/Operators_2.cpp", - "aten/src/ATen/Operators_3.cpp", - "aten/src/ATen/Operators_4.cpp", - "aten/src/ATen/NativeFunctions.h", - "aten/src/ATen/MetaFunctions.h", - "aten/src/ATen/MetaFunctions_inl.h", - "aten/src/ATen/MethodOperators.h", - "aten/src/ATen/NativeMetaFunctions.h", - "aten/src/ATen/RegistrationDeclarations.h", - "aten/src/ATen/VmapGeneratedPlumbing.h", - "aten/src/ATen/ViewMetaClasses.h", - "aten/src/ATen/ViewMetaClasses.cpp", - "aten/src/ATen/core/aten_interned_strings.h", - "aten/src/ATen/core/enum_tag.h", - "aten/src/ATen/core/TensorBody.h", - "aten/src/ATen/core/TensorMethods.cpp", - "aten/src/ATen/core/ATenOpList.cpp", -] - -generated_cuda_cpp = [ - "aten/src/ATen/CUDAFunctions.h", - "aten/src/ATen/CUDAFunctions_inl.h", - "aten/src/ATen/RegisterCUDA_0.cpp", - "aten/src/ATen/RegisterNestedTensorCUDA_0.cpp", - "aten/src/ATen/RegisterQuantizedCUDA_0.cpp", - "aten/src/ATen/RegisterSparseCUDA_0.cpp", - "aten/src/ATen/RegisterSparseCsrCUDA_0.cpp", -] - -generate_aten( - name = "generated_aten_cpp", - srcs = aten_generation_srcs, - outs = ( - generated_cpu_cpp + - generated_cuda_cpp + - aten_ufunc_generated_cpu_sources("aten/src/ATen/{}") + - aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}") + - aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") + [ - "aten/src/ATen/Declarations.yaml", - ] - ), - generator = "//torchgen:gen", -) - -filegroup( - name = "cpp_generated_code", - srcs = GENERATED_AUTOGRAD_CPP, - data = [":generate-code"], -) - -# ATen -filegroup( - name = "aten_base_cpp", - srcs = glob([ - "aten/src/ATen/*.cpp", - "aten/src/ATen/functorch/*.cpp", - "aten/src/ATen/detail/*.cpp", - "aten/src/ATen/cpu/*.cpp", - ]), -) - -filegroup( - name = "ATen_CORE_SRCS", - srcs = glob( - [ - "aten/src/ATen/core/**/*.cpp", - ], - exclude = [ - "aten/src/ATen/core/**/*_test.cpp", - ], - ), -) - -filegroup( - name = "aten_native_cpp", - srcs = glob(["aten/src/ATen/native/*.cpp"]), -) - -filegroup( - name = "aten_native_sparse_cpp", - srcs = glob(["aten/src/ATen/native/sparse/*.cpp"]), -) - -filegroup( - name = "aten_native_nested_cpp", - srcs = glob(["aten/src/ATen/native/nested/*.cpp"]), -) - -filegroup( - name = "aten_native_quantized_cpp", - srcs = glob( - [ - "aten/src/ATen/native/quantized/*.cpp", - "aten/src/ATen/native/quantized/cpu/*.cpp", - ], - ), -) - -filegroup( - name = "aten_native_transformers_cpp", - srcs = glob(["aten/src/ATen/native/transformers/*.cpp"]), -) - -filegroup( - name = "aten_native_mkl_cpp", - srcs = glob([ - "aten/src/ATen/native/mkl/*.cpp", - "aten/src/ATen/mkl/*.cpp", - ]), -) - -filegroup( - name = "aten_native_mkldnn_cpp", - srcs = glob(["aten/src/ATen/native/mkldnn/*.cpp"]), -) - -filegroup( - name = "aten_native_xnnpack", - srcs = glob(["aten/src/ATen/native/xnnpack/*.cpp"]), -) - -filegroup( - name = "aten_base_vulkan", - srcs = glob(["aten/src/ATen/vulkan/*.cpp"]), -) - -filegroup( - name = "aten_base_metal", - srcs = glob(["aten/src/ATen/metal/*.cpp"]), -) - -filegroup( - name = "ATen_QUANTIZED_SRCS", - srcs = glob( - [ - "aten/src/ATen/quantized/**/*.cpp", - ], - exclude = [ - "aten/src/ATen/quantized/**/*_test.cpp", - ], - ), -) - -filegroup( - name = "aten_cuda_cpp_srcs", - srcs = glob( - [ - "aten/src/ATen/cuda/*.cpp", - "aten/src/ATen/cuda/detail/*.cpp", - "aten/src/ATen/cuda/tunable/*.cpp", - "aten/src/ATen/cudnn/*.cpp", - "aten/src/ATen/native/cuda/*.cpp", - "aten/src/ATen/native/cuda/linalg/*.cpp", - "aten/src/ATen/native/cudnn/*.cpp", - "aten/src/ATen/native/miopen/*.cpp", - "aten/src/ATen/native/nested/cuda/*.cpp", - "aten/src/ATen/native/quantized/cuda/*.cpp", - "aten/src/ATen/native/quantized/cudnn/*.cpp", - "aten/src/ATen/native/sparse/cuda/*.cpp", - "aten/src/ATen/native/transformers/cuda/*.cpp", - ], - ), -) - -filegroup( - name = "aten_cu_srcs", - srcs = glob([ - "aten/src/ATen/cuda/*.cu", - "aten/src/ATen/cuda/detail/*.cu", - "aten/src/ATen/native/cuda/*.cu", - "aten/src/ATen/native/nested/cuda/*.cu", - "aten/src/ATen/native/quantized/cuda/*.cu", - "aten/src/ATen/native/sparse/cuda/*.cu", - "aten/src/ATen/native/transformers/cuda/*.cu", - ]) + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}"), - # It's a bit puzzling to me why it's not necessary to declare the - # target that generates these sources... -) - -# TODO: Enable support for KleidiAI bazel build -header_template_rule( - name = "aten_src_ATen_config", - src = "aten/src/ATen/Config.h.in", - out = "aten/src/ATen/Config.h", - include = "aten/src", - substitutions = { - "@AT_MKLDNN_ENABLED@": "1", - "@AT_MKLDNN_ACL_ENABLED@": "0", - "@AT_MKL_ENABLED@": "1", - "@AT_MKL_SEQUENTIAL@": "0", - "@AT_POCKETFFT_ENABLED@": "0", - "@AT_NNPACK_ENABLED@": "0", - "@CAFFE2_STATIC_LINK_CUDA_INT@": "0", - "@AT_BUILD_WITH_BLAS@": "1", - "@AT_BUILD_WITH_LAPACK@": "1", - "@AT_PARALLEL_OPENMP@": "0", - "@AT_PARALLEL_NATIVE@": "1", - "@AT_BLAS_F2C@": "0", - "@AT_BLAS_USE_CBLAS_DOT@": "1", - "@AT_KLEIDIAI_ENABLED@": "0", - "@AT_USE_EIGEN_SPARSE@": "0", - }, -) - -header_template_rule( - name = "aten_src_ATen_cuda_config", - src = "aten/src/ATen/cuda/CUDAConfig.h.in", - out = "aten/src/ATen/cuda/CUDAConfig.h", - include = "aten/src", - substitutions = { - "@AT_CUDNN_ENABLED@": "1", - "@AT_CUSPARSELT_ENABLED@": "0", - "@AT_HIPSPARSELT_ENABLED@": "0", - "@AT_ROCM_ENABLED@": "0", - "@AT_MAGMA_ENABLED@": "0", - "@NVCC_FLAGS_EXTRA@": "", - }, -) - -cc_library( - name = "aten_headers", - hdrs = [ - "torch/csrc/Export.h", - "torch/csrc/jit/frontend/function_schema_parser.h", - ] + glob( - [ - "aten/src/**/*.h", - "aten/src/**/*.hpp", - "aten/src/ATen/cuda/**/*.cuh", - "aten/src/ATen/native/**/*.cuh", - "aten/src/THC/*.cuh", - ], - ) + [ - ":aten_src_ATen_config", - ":generated_aten_cpp", - ], - includes = [ - "aten/src", - ], - deps = [ - "//c10", - ], -) - -ATEN_COPTS = COMMON_COPTS + [ - "-DCAFFE2_BUILD_MAIN_LIBS", - "-DHAVE_AVX_CPU_DEFINITION", - "-DHAVE_AVX2_CPU_DEFINITION", - "-fvisibility-inlines-hidden", - "-fno-math-errno", - "-fno-trapping-math", -] - -intern_build_aten_ops( - copts = ATEN_COPTS, - extra_impls = aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}"), - deps = [ - ":aten_headers", - "@fbgemm", - "@mkl", - "@sleef", - "@mkl_dnn//:mkl-dnn", - ], -) - -cc_library( - name = "aten", - srcs = [ - ":ATen_CORE_SRCS", - ":ATen_QUANTIZED_SRCS", - ":aten_base_cpp", - ":aten_base_metal", - ":aten_base_vulkan", - ":aten_native_cpp", - ":aten_native_mkl_cpp", - ":aten_native_mkldnn_cpp", - ":aten_native_nested_cpp", - ":aten_native_quantized_cpp", - ":aten_native_sparse_cpp", - ":aten_native_transformers_cpp", - ":aten_native_xnnpack", - ":aten_src_ATen_config", - ] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"), - copts = ATEN_COPTS, - linkopts = [ - "-ldl", - ], - data = if_cuda( - [":libcaffe2_nvrtc.so"], - [], - ), - visibility = ["//visibility:public"], - deps = [ - ":ATen_CPU", - ":aten_headers", - ":caffe2_for_aten_headers", - ":torch_headers", - "@fbgemm", - "@ideep", - "@nlohmann", - ], - alwayslink = True, -) - -cc_library( - name = "aten_nvrtc", - srcs = glob([ - "aten/src/ATen/cuda/nvrtc_stub/*.cpp", - ]), - copts = ATEN_COPTS, - linkstatic = True, - visibility = ["//visibility:public"], - deps = [ - ":aten_headers", - "//c10", - "@cuda", - "@cuda//:cuda_driver", - "@cuda//:nvrtc", - ], - alwayslink = True, -) - -cc_binary( - name = "libcaffe2_nvrtc.so", - linkshared = True, - visibility = ["//visibility:public"], - deps = [ - ":aten_nvrtc", - ], -) - -cc_library( - name = "aten_cuda_cpp", - srcs = [":aten_cuda_cpp_srcs"] + generated_cuda_cpp, - hdrs = [":aten_src_ATen_cuda_config"], - copts = ATEN_COPTS, - visibility = ["//visibility:public"], - deps = [ - ":aten", - "@cuda", - "@cuda//:cusolver", - "@cuda//:nvrtc", - "@cudnn", - "@cudnn_frontend", - ], - alwayslink = True, -) - -torch_cuda_half_options = [ - "-DCUDA_HAS_FP16=1", - "-D__CUDA_NO_HALF_OPERATORS__", - "-D__CUDA_NO_HALF_CONVERSIONS__", - "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-D__CUDA_NO_HALF2_OPERATORS__", -] - -cu_library( - name = "aten_cuda", - srcs = [":aten_cu_srcs"], - copts = ATEN_COPTS + torch_cuda_half_options, - visibility = ["//visibility:public"], - deps = [ - ":aten_cuda_cpp", - "//c10/util:bit_cast", - "@cuda//:cublas", - "@cuda//:cufft", - "@cuda//:cusparse", - "@cutlass", - ], - alwayslink = True, -) - -# caffe2 -CAFFE2_COPTS = COMMON_COPTS + [ - "-Dcaffe2_EXPORTS", - "-DCAFFE2_USE_CUDNN", - "-DCAFFE2_BUILD_MAIN_LIB", - "-fvisibility-inlines-hidden", - "-fno-math-errno", - "-fno-trapping-math", -] - -filegroup( - name = "caffe2_core_srcs", - srcs = [ - "caffe2/core/common.cc", - ], -) - -filegroup( - name = "caffe2_perfkernels_srcs", - srcs = [ - "caffe2/perfkernels/embedding_lookup_idx.cc", - ], -) - - -filegroup( - name = "caffe2_serialize_srcs", - srcs = [ - "caffe2/serialize/file_adapter.cc", - "caffe2/serialize/inline_container.cc", - "caffe2/serialize/istream_adapter.cc", - "caffe2/serialize/read_adapter_interface.cc", - ], -) - -filegroup( - name = "caffe2_utils_srcs", - srcs = [ - "caffe2/utils/proto_wrap.cc", - "caffe2/utils/string_utils.cc", - "caffe2/utils/threadpool/ThreadPool.cc", - "caffe2/utils/threadpool/pthreadpool.cc", - "caffe2/utils/threadpool/pthreadpool_impl.cc", - "caffe2/utils/threadpool/thread_pool_guard.cpp", - ], -) - -# To achieve finer granularity and make debug easier, caffe2 is split into three libraries: -# ATen, caffe2 and caffe2_for_aten_headers. ATen lib group up source codes under -# aten/ directory and caffe2 contains most files under `caffe2/` directory. Since the -# ATen lib and the caffe2 lib would depend on each other, `caffe2_for_aten_headers` is split -# out from `caffe2` to avoid dependency cycle. -cc_library( - name = "caffe2_for_aten_headers", - hdrs = [ - "caffe2/core/common.h", - "caffe2/perfkernels/common.h", - "caffe2/perfkernels/embedding_lookup_idx.h", - "caffe2/utils/fixed_divisor.h", - ] + glob([ - "caffe2/utils/threadpool/*.h", - ]), - copts = CAFFE2_COPTS, - visibility = ["//visibility:public"], - deps = [ - ":caffe2_core_macros", - "//c10", - ], -) - -cc_library( - name = "caffe2_headers", - hdrs = glob( - [ - "caffe2/perfkernels/*.h", - "caffe2/serialize/*.h", - "caffe2/utils/*.h", - "caffe2/utils/threadpool/*.h", - "modules/**/*.h", - ], - exclude = [ - "caffe2/core/macros.h", - ], - ) + if_cuda(glob([ - "caffe2/**/*.cuh", - ])), - copts = CAFFE2_COPTS, - visibility = ["//visibility:public"], - deps = [ - ":caffe2_core_macros", - ":caffe2_for_aten_headers", - ], -) - -cc_library( - name = "caffe2", - srcs = [ - ":caffe2_core_srcs", - ":caffe2_perfkernels_srcs", - ":caffe2_serialize_srcs", - ":caffe2_utils_srcs", - ], - copts = CAFFE2_COPTS + ["-mf16c"], - linkstatic = 1, - visibility = ["//visibility:public"], - deps = [ - ":caffe2_core_macros", - ":caffe2_headers", - ":caffe2_perfkernels_avx", - ":caffe2_perfkernels_avx2", - "//third_party/miniz-3.0.2:miniz", - "@com_google_protobuf//:protobuf", - "@eigen", - "@fbgemm//:fbgemm_src_headers", - "@fmt", - "@onnx", - ] + if_cuda( - [ - ":aten_cuda", - "@tensorpipe//:tensorpipe_cuda", - ], - [ - ":aten", - "@tensorpipe//:tensorpipe_cpu", - ], - ), - alwayslink = True, -) - -cu_library( - name = "torch_cuda", - srcs = [ - "torch/csrc/distributed/c10d/NanCheck.cu", - "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", - "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu", - ], - copts = torch_cuda_half_options, - visibility = ["//visibility:public"], - deps = [ - ":aten", - "@cuda//:cublas", - "@cuda//:curand", - "@cudnn", - "@eigen", - "@tensorpipe//:tensorpipe_cuda", - ], - alwayslink = True, -) - -PERF_COPTS = [ - "-DHAVE_AVX_CPU_DEFINITION", - "-DHAVE_AVX2_CPU_DEFINITION", - "-DENABLE_ALIAS=1", - "-DHAVE_MALLOC_USABLE_SIZE=1", - "-DHAVE_MMAP=1", - "-DHAVE_SHM_OPEN=1", - "-DHAVE_SHM_UNLINK=1", - "-DSLEEF_STATIC_LIBS=1", - "-DTH_BALS_MKL", - "-D_FILE_OFFSET_BITS=64", - "-DUSE_FBGEMM", - "-fvisibility-inlines-hidden", - "-Wunused-parameter", - "-fno-math-errno", - "-fno-trapping-math", - "-mf16c", -] - -PERF_HEADERS = glob([ - "caffe2/perfkernels/*.h", - "caffe2/core/*.h", -]) - -cc_library( - name = "caffe2_perfkernels_avx", - srcs = glob([ - "caffe2/perfkernels/*_avx.cc", - ]), - hdrs = PERF_HEADERS, - copts = PERF_COPTS + [ - "-mavx", - ], - visibility = ["//visibility:public"], - deps = [ - ":caffe2_headers", - "//c10", - ], - alwayslink = True, -) - -cc_library( - name = "caffe2_perfkernels_avx2", - srcs = glob([ - "caffe2/perfkernels/*_avx2.cc", - ]), - hdrs = PERF_HEADERS, - copts = PERF_COPTS + [ - "-mavx2", - "-mfma", - "-mavx", - ], - visibility = ["//visibility:public"], - deps = [ - ":caffe2_headers", - "//c10", - ], - alwayslink = True, -) - -# torch -torch_cuda_headers = glob(["torch/csrc/cuda/*.h"]) - -flatbuffer_cc_library( - name = "torch_flatbuffers", - srcs = [ - "torch/csrc/jit/serialization/mobile_bytecode.fbs", - ], - flatc_args = ["--cpp", "--gen-mutable", "--scoped-enums"], - out_prefix = "torch/csrc/jit/serialization/", -) - -cc_library( - name = "torch_headers", - hdrs = if_cuda( - torch_cuda_headers, - ) + glob( - [ - "torch/*.h", - "torch/csrc/**/*.h", - "torch/nativert/**/*.h", - "torch/csrc/distributed/c10d/**/*.hpp", - "torch/lib/libshm/*.h", - ], - exclude = [ - "torch/csrc/*/generated/*.h", - "torch/csrc/jit/serialization/mobile_bytecode_generated.h", - ] + torch_cuda_headers, - ) + GENERATED_AUTOGRAD_CPP + [ - "//torch/headeronly:version_h", - ], - includes = [ - "third_party/kineto/libkineto/include", - "torch/csrc", - "torch/csrc/api/include", - "torch/csrc/distributed", - "torch/lib", - "torch/lib/libshm", - ], - visibility = ["//visibility:public"], - deps = [ - ":aten_headers", - ":caffe2_headers", - ":torch_flatbuffers", - "//c10", - "@com_github_google_flatbuffers//:flatbuffers", - "@local_config_python//:python_headers", - "@onnx", - ], - alwayslink = True, -) - -TORCH_COPTS = COMMON_COPTS + [ - "-Dtorch_EXPORTS", - "-DHAVE_AVX_CPU_DEFINITION", - "-DHAVE_AVX2_CPU_DEFINITION", - "-DCAFFE2_USE_GLOO", - "-fvisibility-inlines-hidden", - "-fno-math-errno ", - "-fno-trapping-math", - "-Wno-error=unused-function", -] - -torch_sources = { - k: "" - for k in ( - libtorch_core_sources + - libtorch_distributed_sources + - torch_cpp_srcs + - libtorch_extra_sources + - jit_core_sources + - lazy_tensor_ts_sources + - GENERATED_AUTOGRAD_CPP - ) -}.keys() - -cc_library( - name = "torch", - srcs = if_cuda(glob( - libtorch_cuda_sources, - exclude = [ - "torch/csrc/cuda/nccl.cpp", - "torch/csrc/cuda/python_nccl.cpp", - "torch/csrc/distributed/c10d/NanCheck.cu", - "torch/csrc/distributed/c10d/cuda/AsyncMM.cu", - "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu", - "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu", - "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu", - "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp", - "torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp", - "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu", - ], - )) + torch_sources, - copts = TORCH_COPTS, - linkopts = [ - "-lrt", - ], - defines = [ - "CAFFE2_NIGHTLY_VERSION=20200115", - ], - visibility = ["//visibility:public"], - deps = [ - ":caffe2", - ":torch_headers", - "@kineto", - "@cpp-httplib", - "@nlohmann", - ] + if_cuda([ - "@cuda//:nvToolsExt", - "@cutlass", - ":torch_cuda", - ]), - alwayslink = True, -) - -cc_library( - name = "shm", - srcs = glob(["torch/lib/libshm/*.cpp"]), - linkopts = [ - "-lrt", - ], - deps = [ - ":torch", - ], -) - -cc_library( - name = "libtorch_headers", - hdrs = glob([ - "**/*.h", - "**/*.cuh", - ]) + [ - # We need the filegroup here because the raw list causes Bazel - # to see duplicate files. It knows how to deduplicate with the - # filegroup. - ":cpp_generated_code", - ], - includes = [ - "torch/csrc/api/include", - "torch/csrc/distributed", - "torch/lib", - "torch/lib/libshm", - ], - visibility = ["//visibility:public"], - deps = [ - ":torch_headers", - ], -) - -cc_library( - name = "torch_python", - srcs = libtorch_python_core_sources - + if_cuda(libtorch_python_cuda_sources) - + if_cuda(libtorch_python_distributed_sources) - + GENERATED_AUTOGRAD_PYTHON, - hdrs = glob([ - "torch/csrc/generic/*.cpp", - ]), - copts = COMMON_COPTS + if_cuda(["-DUSE_CUDA=1"]), - deps = [ - ":torch", - ":shm", - "@pybind11", - ], -) - -pybind_extension( - name = "torch/_C", - srcs = ["torch/csrc/stub.c"], - deps = [ - ":torch_python", - ":aten_nvrtc", - ], -) - -cc_binary( - name = "torch/bin/torch_shm_manager", - srcs = [ - "torch/lib/libshm/manager.cpp", - ], - deps = [ - ":shm", - ], - linkstatic = False, -) - -template_rule( - name = "gen_version_py", - src = ":torch/version.py.tpl", - out = "torch/version.py", - substitutions = if_cuda({ - # Set default to 11.2. Otherwise Torchvision complains about incompatibility. - "{{CUDA_VERSION}}": "11.2", - "{{VERSION}}": "2.0.0", - }, { - "{{CUDA_VERSION}}": "None", - "{{VERSION}}": "2.0.0", - }), -) - -py_library( - name = "pytorch_py", - visibility = ["//visibility:public"], - srcs = glob(["torch/**/*.py"], exclude = ["torch/version.py"]) + [":torch/version.py"] + glob(["functorch/**/*.py"]), - deps = [ - rules.requirement("numpy"), - rules.requirement("pyyaml"), - rules.requirement("requests"), - rules.requirement("setuptools"), - rules.requirement("sympy"), - rules.requirement("typing_extensions"), - "//torchgen", - ], - data = [ - ":torch/_C.so", - ":torch/bin/torch_shm_manager", - ], -) - -# cpp api tests -cc_library( - name = "test_support", - testonly = True, - srcs = [ - "test/cpp/api/support.cpp", - ], - hdrs = [ - "test/cpp/api/init_baseline.h", - "test/cpp/api/optim_baseline.h", - "test/cpp/api/support.h", - "test/cpp/common/support.h", - ], - deps = [ - ":torch", - "@com_google_googletest//:gtest_main", - ], -) - -# Torch integration tests rely on a labeled data set from the MNIST database. -# http://yann.lecun.com/exdb/mnist/ - -cpp_api_tests = glob( - ["test/cpp/api/*.cpp"], - exclude = [ - "test/cpp/api/imethod.cpp", - "test/cpp/api/integration.cpp", - ], -) - -cc_test( - name = "integration_test", - size = "medium", - srcs = ["test/cpp/api/integration.cpp"], - data = [ - ":download_mnist", - ], - tags = [ - "gpu-required", - ], - deps = [ - ":test_support", - "@com_google_googletest//:gtest_main", - ], -) - -[ - cc_test( - name = paths.split_extension(paths.basename(filename))[0].replace("-", "_") + "_test", - size = "medium", - srcs = [filename], - deps = [ - ":test_support", - "@com_google_googletest//:gtest_main", - ], - ) - for filename in cpp_api_tests -] - -test_suite( - name = "api_tests", - tests = [ - "any_test", - "autograd_test", - "dataloader_test", - "enum_test", - "expanding_array_test", - "functional_test", - "init_test", - "integration_test", - "jit_test", - "memory_test", - "misc_test", - "module_test", - "modulelist_test", - "modules_test", - "nn_utils_test", - "optim_test", - "ordered_dict_test", - "rnn_test", - "sequential_test", - "serialize_test", - "static_test", - "tensor_options_test", - "tensor_test", - "torch_include_test", - ], -) - -# dist autograd tests -cc_test( - name = "torch_dist_autograd_test", - size = "small", - srcs = ["test/cpp/dist_autograd/test_dist_autograd.cpp"], - tags = [ - "exclusive", - "gpu-required", - ], - deps = [ - ":torch", - "@com_google_googletest//:gtest_main", - ], -) - -# jit tests -# Because these individual unit tests require custom registering, -# it is easier to mimic the cmake build by globing together a single test. -cc_test( - name = "jit_tests", - size = "small", - srcs = glob( - [ - "test/cpp/jit/*.cpp", - "test/cpp/jit/*.h", - "test/cpp/tensorexpr/*.cpp", - "test/cpp/tensorexpr/*.h", - ], - exclude = [ - # skip this since is not found in OSS build - "test/cpp/jit/test_exception.cpp", - ], - ), - linkstatic = True, - tags = [ - "exclusive", - "gpu-required", - ], - deps = [ - ":torch", - "@com_google_googletest//:gtest_main", - ], -) - -cc_test( - name = "lazy_tests", - size = "small", - srcs = glob( - [ - "test/cpp/lazy/*.cpp", - "test/cpp/lazy/*.h", - ], - exclude = [ - # skip these since they depend on generated LazyIr.h which isn't available in bazel yet - "test/cpp/lazy/test_ir.cpp", - "test/cpp/lazy/test_lazy_ops.cpp", - "test/cpp/lazy/test_lazy_ops_util.cpp", - "test/cpp/lazy/test_lazy_graph_executor.cpp", - ], - ), - linkstatic = True, - tags = [ - "exclusive", - ], - deps = [ - ":torch", - "@com_google_googletest//:gtest_main", - ], -) - -# python api tests - -py_test( - name = "test_bazel", - srcs = ["test/_test_bazel.py"], - main = "test/_test_bazel.py", - deps = [ - ":pytorch_py", - rules.requirement("networkx"), - ], -) - -# all tests -test_suite( - name = "all_tests", - tests = [ - "api_tests", - "jit_tests", - "torch_dist_autograd_test", - "//c10/test:tests", - ], -) - -# An internal genrule that we are converging with refers to these file -# as if they are from this package, so we alias them for -# compatibility. - -[ - alias( - name = paths.basename(path), - actual = path, - ) - for path in [ - "aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp", - "aten/src/ATen/templates/DispatchKeyNativeFunctions.h", - "aten/src/ATen/templates/LazyIr.h", - "aten/src/ATen/templates/LazyNonNativeIr.h", - "aten/src/ATen/templates/RegisterDispatchKey.cpp", - "aten/src/ATen/templates/RegisterDispatchDefinitions.ini", - "aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp", - "aten/src/ATen/native/native_functions.yaml", - "aten/src/ATen/native/tags.yaml", - "aten/src/ATen/native/ts_native_functions.yaml", - "torch/csrc/lazy/core/shape_inference.h", - "torch/csrc/lazy/ts_backend/ts_native_functions.cpp", - ] -] - -genrule( - name = "download_mnist", - srcs = ["//:tools/download_mnist.py"], - outs = [ - "mnist/train-images-idx3-ubyte", - "mnist/train-labels-idx1-ubyte", - "mnist/t10k-images-idx3-ubyte", - "mnist/t10k-labels-idx1-ubyte", - ], - cmd = "python3 tools/download_mnist.py -d $(RULEDIR)/mnist", -) diff --git a/CLAUDE.md b/CLAUDE.md index 6f3f7ab6913e2..4c148a7684e8a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,3 +1,7 @@ +# Scratch Space + +Use `agent_space/` (git-ignored, at repo root) for temporary scripts, scratch files, and throwaway experiments. Do not commit files from this directory. + # PR Review When asked to review a PR, always use the /pr-review skill. @@ -47,6 +51,45 @@ entirely. Disclose that the PR was authored with Claude. +If a commit message contains `ghstack-source-id` or `Pull-Request` trailers, +you MUST preserve them when rewriting or splitting commit messages. ghstack +will update the source id automatically when needed. + +# ghstack Workflow + +ghstack commits follow a different workflow than the conventional GitHub branch +and PR workflow. First identify whether you're on a ghstack commit: + +- If HEAD is a detached commit, you are almost certainly in a ghstack flow. +- If the commit message contains a `ghstack-source-id` trailer, it is an + existing ghstack commit. +- If the commit is associated with a remote branch like `origin/gh/USERNAME/N`, + it is likely a ghstack commit (imperfect signal: local amends without a push + can desync this). + +Rules for working with ghstack: + +- **Don't amend unless asked.** If the user asks you to work on a ghstack + commit, leave changes uncommitted so the user can review with `git diff`. + Only amend into the commit if the user explicitly asks you to amend or to + submit it directly. +- **Submitting.** Run `ghstack` to submit. When only working on a single + commit, use `ghstack --no-stack` to avoid updating the rest of the stack and + burning unnecessary CI. Use a full `ghstack` when you're intentionally + updating CI for the whole stack. +- **Preserve metadata trailers.** When editing a commit message, never delete + `Pull-Request:` or `ghstack-source-id:` trailers. If you modified the commit + message, run `ghstack -u` afterwards to push the updated PR description. +- **Never push directly.** Do not `git push` to branches, and never directly + modify the `gh/USERNAME/N` branches — ghstack manages those. +- **Finding the PR.** If the user asks to pull CI results or code review for a + ghstack commit, get the PR URL from the `Pull-Request` trailer in the commit + message. Use `gh` CLI to fetch status/comments from there. +- **Editing earlier commits / splitting.** Treat it like a normal stack of + commits (use `git rebase`, etc.). Commits that keep their metadata trailers + stay associated with their existing PRs; commits without trailers will get a + fresh PR on submit. A full `ghstack` run is usually appropriate here. + # Coding Style Guidelines Follow these rules for all code changes in this repository: diff --git a/CMakeLists.txt b/CMakeLists.txt index fdb5062824815..dd7f97211ed90 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,23 @@ endif() # ---[ Project and semantic versioning. project(Torch CXX C) +# When built via scikit-build-core, Python_EXECUTABLE is already set. +# For standalone cmake invocations, find it via find_package. +find_package(Python COMPONENTS Interpreter REQUIRED) + +# Forward environment variables to CMake variables, replicating the behavior +# previously handled by setup.py / tools/setup_helpers/cmake.py. +# When built via setup.py, these are already set as -D flags; the module +# skips variables that are already defined, so it is safe to include +# unconditionally. +include(cmake/EnvVarForwarding.cmake) + +# Pre-build steps: submodule init, NCCL checkout. +# When built via setup.py, these are handled by build_deps() in setup.py; +# the module checks for existing state, so it is safe to include +# unconditionally. +include(cmake/PreBuildSteps.cmake) + if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux") set(LINUX TRUE) else() @@ -232,6 +249,7 @@ option(BUILD_MOBILE_AUTOGRAD "Build autograd function in mobile build (in development)" OFF) cmake_dependent_option(INSTALL_TEST "Install test binaries if BUILD_TEST is on" ON "BUILD_TEST" OFF) +option(USE_RELATIVE_PATHS "Use relative paths in generated files for ccache friendliness" OFF) option(USE_CPP_CODE_COVERAGE "Compile C/C++ with code coverage flags" OFF) option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON) option(USE_ASAN "Use Address+Undefined Sanitizers" OFF) @@ -246,9 +264,11 @@ else() set(_USE_CUDA_EXPLICITLY_SET FALSE) endif() -option(USE_CUDA "Use CUDA" ON) +# CUDA is incompatible with TSAN +cmake_dependent_option(USE_CUDA "Use CUDA" ON "NOT USE_TSAN" OFF) option(USE_XPU "Use XPU" ON) +option(USE_MTIA "Use MTIA" OFF) cmake_dependent_option( BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON "USE_CUDA AND LINUX AND BUILD_PYTHON" OFF) @@ -291,7 +311,7 @@ option(USE_NNPACK "Use NNPACK" ON) cmake_dependent_option(USE_NUMA "Use NUMA. Only available on Linux." ON "LINUX" OFF) cmake_dependent_option(USE_NVRTC "Use NVRTC. Only available if USE_CUDA is on." - OFF "USE_CUDA" OFF) + ON "USE_CUDA" OFF) option(USE_NUMPY "Use NumPy" ON) option(USE_OBSERVERS "Use observers module." OFF) option(USE_OPENCL "Use OpenCL" OFF) @@ -590,7 +610,7 @@ if(MSVC) # https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus?view=msvc-170 # * https://en.cppreference.com/w/cpp/preprocessor/replace#Predefined_macros set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus") - set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus -Xcompiler /Zc:preprocessor") set(CMAKE_NINJA_CMCLDEPS_RC OFF) if(MSVC_Z7_OVERRIDE) @@ -1181,6 +1201,12 @@ if(NOT MSVC) string(APPEND CMAKE_CXX_FLAGS_DEBUG " -fno-omit-frame-pointer -O0") string(APPEND CMAKE_LINKER_FLAGS_DEBUG " -fno-omit-frame-pointer -O0") endif() + # aarch64 C++ stack unwinding uses frame-pointer chain walking, so frame + # pointers must be present in all build types. The cost is negligible on + # aarch64 (31 GPRs vs x86-64's 16, so dedicating x29 rarely spills). + if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64") + append_cxx_flag_if_supported("-fno-omit-frame-pointer" CMAKE_CXX_FLAGS) + endif() append_cxx_flag_if_supported("-fno-math-errno" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-fno-trapping-math" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Werror=format" CMAKE_CXX_FLAGS) @@ -1229,10 +1255,11 @@ if(USE_CPP_CODE_COVERAGE) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") string(APPEND CMAKE_C_FLAGS " --coverage -fprofile-abs-path") string(APPEND CMAKE_CXX_FLAGS " --coverage -fprofile-abs-path") + string(APPEND CMAKE_OBJCXX_FLAGS " --coverage -fprofile-abs-path") elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") string(APPEND CMAKE_C_FLAGS " -fprofile-instr-generate -fcoverage-mapping") - string(APPEND CMAKE_CXX_FLAGS - " -fprofile-instr-generate -fcoverage-mapping") + string(APPEND CMAKE_CXX_FLAGS " -fprofile-instr-generate -fcoverage-mapping") + string(APPEND CMAKE_OBJCXX_FLAGS " -fprofile-instr-generate -fcoverage-mapping") else() message( ERROR diff --git a/CODEOWNERS b/CODEOWNERS index 5fc084f4799f1..ddac59548f0a8 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -48,6 +48,17 @@ nn/qat/ @jerryzh168 # c10d backend APIs /torch/csrc/distributed/c10d/Backend.* @kwen2501 /torch/csrc/distributed/c10d/Ops.* @kwen2501 +# FSDP, DDP +/torch/distributed/_composable/fsdp/ @weifengpy +/torch/distributed/_composable/replicate.py @weifengpy +/torch/distributed/_composable/replicate_with_fsdp.py @weifengpy +/torch/distributed/fsdp/ @weifengpy +/torch/distributed/optim/ @weifengpy +/torch/nn/parallel/ @weifengpy +# DTensor view, matmul, _StridedShard +/torch/distributed/tensor/placement_types.py @weifengpy +/torch/distributed/tensor/_ops/_view_ops.py @weifengpy +/torch/distributed/tensor/_ops/_matrix_ops.py @weifengpy # ONNX Export /torch/_dynamo/backends/onnxrt.py @titaiwangms @xadupre @justinchuby @@ -131,10 +142,10 @@ aten/src/ATen/detail/MTIAHooksInterface.h @egienvalue torch/csrc/mtia/ @egienvalue # Profiler -torch/csrc/autograd/profiler* @scotts -torch/autograd/profiler* @scotts -torch/csrc/profiler/ @scotts -torch/profiler/ @scotts +torch/csrc/autograd/profiler* @scotts @ryanzhang22 +torch/autograd/profiler* @scotts @ryanzhang22 +torch/csrc/profiler/ @scotts @ryanzhang22 +torch/profiler/ @scotts @ryanzhang22 # AOTDispatch tests test/functorch/test_aotdispatch.py @ezyang @Chillee @@ -173,9 +184,9 @@ caffe2/utils/hip @jeffdaily @jithunnair-amd /torch/_export/serde/schema.py @SherlockNoMad @zhxchen17 # Dynamic Shapes -/torch/fx/experimental/symbolic_shapes.py @bobrenjc93 @laithsakka -/torch/fx/experimental/sym_node.py @bobrenjc93 @laithsakka -/torch/fx/experimental/recording.py @bobrenjc93 @laithsakka +/torch/fx/experimental/symbolic_shapes.py @laithsakka +/torch/fx/experimental/sym_node.py @laithsakka +/torch/fx/experimental/recording.py @laithsakka # ProxyTorchDispatchMode torch/fx/experimental/proxy_tensor.py @aorenste diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4e5850e354814..e92aec464946a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -16,6 +16,7 @@ aspects of contributing to PyTorch. - [Codebase structure](#codebase-structure) - [AI-Assisted Development](#ai-assisted-development) - [Spin](#spin) + - [Building](#building) - [Linting](#linting) - [default lint](#default-lint) - [Regenerating](#regenerating) @@ -300,6 +301,14 @@ helps running common tasks. To list the available tasks, run `spin --help`. Currently, we support the following tasks with Spin: +### Building + +To support building and general development, the following commands exist. + +|command|| +|-|-| +|`clean`|clean, that is remove files and directories listed in .gitignore before the NOT-CLEAN-FILES marker| + ### Linting Spin helps with linting by making sure that lintrunner is installed correctly diff --git a/Dockerfile b/Dockerfile index 7b28e9b056e62..871143c82ddff 100644 --- a/Dockerfile +++ b/Dockerfile @@ -45,9 +45,12 @@ ARG INSTALL_CHANNEL=whl/nightly ARG TARGETPLATFORM # INSTALL_CHANNEL whl - release, whl/nightly - nightly, whl/test - test channels -RUN case ${TARGETPLATFORM} in \ +# TODO: revert cu132->cu130 fallback once cu132 wheels are published +RUN WHEEL_CUDA_PATH="${CUDA_PATH#.}"; \ + if [ "${WHEEL_CUDA_PATH}" = "cu132" ]; then WHEEL_CUDA_PATH=cu130; fi; \ + case ${TARGETPLATFORM} in \ "linux/arm64") pip3 install --extra-index-url https://download.pytorch.org/whl/cpu/ torch torchvision torchaudio ;; \ - *) pip3 install --index-url https://download.pytorch.org/${INSTALL_CHANNEL}/${CUDA_PATH#.}/ torch torchvision torchaudio ;; \ + *) pip3 install --index-url https://download.pytorch.org/${INSTALL_CHANNEL}/${WHEEL_CUDA_PATH}/ torch torchvision torchaudio ;; \ esac RUN pip3 install torchelastic RUN IS_CUDA=$(python3 -c 'import torch ; print(torch.cuda._is_compiled())'); \ diff --git a/RELEASE.md b/RELEASE.md index 6d10e359c1324..3120931b38557 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -130,7 +130,7 @@ Following requirements need to be met prior to cutting a release branch: * Triton release branch must be created (e.g., [release/3.6.x](https://github.com/triton-lang/triton/tree/release/3.6.x)) and the Triton pin update PR must be landed (e.g., [#168096](https://github.com/pytorch/pytorch/pull/168096)) at least 1 week before the branch cut * Resolve all outstanding issues in the milestones that are feature work and release blocking (for example [release 2.10 milestone](https://github.com/pytorch/pytorch/milestone/57)). A report of outstanding cherry-picks can be produced by running the [github-analytics-daily workflow](https://github.com/pytorch/test-infra/blob/main/.github/workflows/github-analytics-daily.yml) * Validate that all new workflows have been created in the PyTorch and domain libraries included in the release. Validate it against all dimensions of release matrix, including operating systems (Linux, macOS, Windows), Python versions as well as CPU architectures (x86 and arm) and accelerator versions (CUDA, ROCm, XPU). -* All [viable/strict](.github/workflows/update-viablestrict.yml) jobs are green, which requires the following jobs to pass: `pull`, `trunk`, `lint`, `linux-aarch64` +* All [viable/strict](.github/workflows/update-viablestrict.yml) jobs are green, which requires the following jobs to pass: `pull`, `trunk`, `lint` * All the nightly jobs for pytorch and domain libraries should be green. Validate this using the following HUD links: * [PyTorch](https://hud.pytorch.org/hud/pytorch/pytorch/nightly) * [TorchVision](https://hud.pytorch.org/hud/pytorch/vision/nightly) diff --git a/WORKSPACE b/WORKSPACE deleted file mode 100644 index c7d6307a9a394..0000000000000 --- a/WORKSPACE +++ /dev/null @@ -1,369 +0,0 @@ -workspace(name = "pytorch") - -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") -load("//tools/rules:workspace.bzl", "new_patched_local_repository") - -http_archive( - name = "rules_cc", - patches = [ - "//:tools/rules_cc/cuda_support.patch", - ], - strip_prefix = "rules_cc-40548a2974f1aea06215272d9c2b47a14a24e556", - urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/40548a2974f1aea06215272d9c2b47a14a24e556.tar.gz", - "https://github.com/bazelbuild/rules_cc/archive/40548a2974f1aea06215272d9c2b47a14a24e556.tar.gz", - ], -) - -http_archive( - name = "rules_cuda", - strip_prefix = "runtime-b1c7cce21ba4661c17ac72421c6a0e2015e7bef3/third_party/rules_cuda", - urls = ["https://github.com/tensorflow/runtime/archive/b1c7cce21ba4661c17ac72421c6a0e2015e7bef3.tar.gz"], -) - -http_archive( - name = "platforms", - urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/platforms/releases/download/0.0.10/platforms-0.0.10.tar.gz", - # TODO Fix bazel linter to support hashes for release tarballs. - # "https://github.com/bazelbuild/platforms/releases/download/0.0.10/platforms-0.0.10.tar.gz", - ], - # sha256 = "218efe8ee736d26a3572663b374a253c012b716d8af0c07e842e82f238a0a7ee", -) - -load("@rules_cuda//cuda:dependencies.bzl", "rules_cuda_dependencies") - -rules_cuda_dependencies(with_rules_cc = False) - -load("@rules_cc//cc:repositories.bzl", "rules_cc_toolchains") - -rules_cc_toolchains() - -http_archive( - name = "bazel_skylib", - urls = [ - "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz", - ], -) - -http_archive( - name = "pybind11_bazel", - strip_prefix = "pybind11_bazel-b162c7c88a253e3f6b673df0c621aca27596ce6b", - urls = ["https://github.com/pybind/pybind11_bazel/archive/b162c7c88a253e3f6b673df0c621aca27596ce6b.zip"], -) - -new_local_repository( - name = "pybind11", - build_file = "@pybind11_bazel//:pybind11.BUILD", - path = "third_party/pybind11", -) - -http_archive( - name = "com_github_glog", - build_file_content = """ -licenses(['notice']) - -load(':bazel/glog.bzl', 'glog_library') -# TODO: figure out why enabling gflags leads to SIGSEV on the logging init -glog_library(with_gflags=0) - """, - strip_prefix = "glog-0.4.0", - urls = [ - "https://github.com/google/glog/archive/v0.4.0.tar.gz", - ], -) - -http_archive( - name = "com_github_gflags_gflags", - strip_prefix = "gflags-2.2.2", - urls = [ - "https://github.com/gflags/gflags/archive/v2.2.2.tar.gz", - ], -) - -new_local_repository( - name = "gloo", - build_file = "//third_party:gloo.BUILD", - path = "third_party/gloo", -) - -new_local_repository( - name = "onnx", - build_file = "//third_party:onnx.BUILD", - path = "third_party/onnx", -) - -local_repository( - name = "com_google_protobuf", - path = "third_party/protobuf", -) - -new_local_repository( - name = "eigen", - build_file = "//third_party:eigen.BUILD", - path = "third_party/eigen", -) - -new_local_repository( - name = "cutlass", - build_file = "//third_party:cutlass.BUILD", - path = "third_party/cutlass", -) - -new_local_repository( - name = "fbgemm", - build_file = "//third_party:fbgemm/BUILD.bazel", - path = "third_party/fbgemm", - repo_mapping = {"@cpuinfo": "@org_pytorch_cpuinfo"}, -) - -new_local_repository( - name = "ideep", - build_file = "//third_party:ideep.BUILD", - path = "third_party/ideep", -) - -new_local_repository( - name = "mkl_dnn", - build_file = "//third_party:mkl-dnn.BUILD", - path = "third_party/ideep/mkl-dnn", -) - -new_local_repository( - name = "org_pytorch_cpuinfo", - build_file = "//third_party:cpuinfo/BUILD.bazel", - path = "third_party/cpuinfo", -) - -new_local_repository( - name = "asmjit", - build_file = "//third_party:fbgemm/external/asmjit.BUILD", - path = "third_party/fbgemm/external/asmjit", -) - -new_local_repository( - name = "sleef", - build_file = "//third_party:sleef.BUILD", - path = "third_party/sleef", -) - -new_local_repository( - name = "fmt", - build_file = "//third_party:fmt.BUILD", - path = "third_party/fmt", -) - -new_local_repository( - name = "kineto", - build_file = "//third_party:kineto.BUILD", - path = "third_party/kineto", -) - -new_local_repository( - name = "opentelemetry-cpp", - build_file = "//third_party::opentelemetry-cpp.BUILD", - path = "third_party/opentelemetry-cpp", -) - -new_local_repository( - name = "cpp-httplib", - build_file = "//third_party:cpp-httplib.BUILD", - path = "third_party/cpp-httplib", -) - -new_local_repository( - name = "nlohmann", - build_file = "//third_party:nlohmann.BUILD", - path = "third_party/nlohmann", -) - -new_local_repository( - name = "moodycamel", - build_file = "//third_party:moodycamel.BUILD", - path = "third_party/concurrentqueue", -) - -new_local_repository( - name = "tensorpipe", - build_file = "//third_party:tensorpipe.BUILD", - path = "third_party/tensorpipe", -) - -http_archive( - name = "mkl", - build_file = "//third_party:mkl.BUILD", - sha256 = "59154b30dd74561e90d547f9a3af26c75b6f4546210888f09c9d4db8f4bf9d4c", - strip_prefix = "lib", - urls = [ - "https://anaconda.org/anaconda/mkl/2020.0/download/linux-64/mkl-2020.0-166.tar.bz2", - ], -) - -http_archive( - name = "mkl_headers", - build_file = "//third_party:mkl_headers.BUILD", - sha256 = "2af3494a4bebe5ddccfdc43bacc80fcd78d14c1954b81d2c8e3d73b55527af90", - urls = [ - "https://anaconda.org/anaconda/mkl-include/2020.0/download/linux-64/mkl-include-2020.0-166.tar.bz2", - ], -) - -http_archive( - name = "rules_python", - # TODO Fix bazel linter to support hashes for release tarballs. - # - # sha256 = "94750828b18044533e98a129003b6a68001204038dc4749f40b195b24c38f49f", - strip_prefix = "rules_python-0.21.0", - url = "https://github.com/bazelbuild/rules_python/releases/download/0.21.0/rules_python-0.21.0.tar.gz", -) - -load("@rules_python//python:repositories.bzl", "py_repositories") - -py_repositories() - -load("@rules_python//python:repositories.bzl", "python_register_toolchains") - -python_register_toolchains( - name = "python3_10", - python_version = "3.10", -) - -load("@python3_10//:defs.bzl", "interpreter") -load("@rules_python//python:pip.bzl", "pip_parse") - -pip_parse( - name = "pip_deps", - python_interpreter_target = interpreter, - requirements_lock = "//:tools/build/bazel/requirements.txt", -) - -load("@pip_deps//:requirements.bzl", "install_deps") - -install_deps() - -load("@pybind11_bazel//:python_configure.bzl", "python_configure") - -python_configure( - name = "local_config_python", - python_interpreter_target = interpreter, -) - -load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") - -protobuf_deps() - -new_local_repository( - name = "cuda", - build_file = "@//third_party:cuda.BUILD", - path = "/usr/local/cuda", -) - -new_local_repository( - name = "cudnn", - build_file = "@//third_party:cudnn.BUILD", - path = "/usr/local/cuda", -) - -new_local_repository( - name = "cudnn_frontend", - build_file = "@//third_party:cudnn_frontend.BUILD", - path = "third_party/cudnn_frontend/", -) - -local_repository( - name = "com_github_google_flatbuffers", - path = "third_party/flatbuffers", -) - -local_repository( - name = "google_benchmark", - path = "third_party/benchmark", -) - -local_repository( - name = "com_google_googletest", - path = "third_party/googletest", -) - -local_repository( - name = "pthreadpool", - path = "third_party/pthreadpool", - repo_mapping = {"@com_google_benchmark": "@google_benchmark"}, -) - -local_repository( - name = "FXdiv", - path = "third_party/FXdiv", - repo_mapping = {"@com_google_benchmark": "@google_benchmark"}, -) - -local_repository( - name = "XNNPACK", - path = "third_party/XNNPACK", - repo_mapping = {"@com_google_benchmark": "@google_benchmark"}, -) - -local_repository( - name = "gemmlowp", - path = "third_party/gemmlowp/gemmlowp", -) - -local_repository( - name = "kleidiai", - path = "third_party/kleidiai", - repo_mapping = {"@com_google_googletest": "@com_google_benchmark"}, -) - -### Unused repos start - -# `unused` repos are defined to hide bazel files from submodules of submodules. -# This allows us to run `bazel build //...` and not worry about the submodules madness. -# Otherwise everything traverses recursively and a lot of submodules of submodules have -# they own bazel build files. - -local_repository( - name = "unused_tensorpipe_googletest", - path = "third_party/tensorpipe/third_party/googletest", -) - -local_repository( - name = "unused_fbgemm", - path = "third_party/fbgemm", -) - -local_repository( - name = "unused_ftm_bazel", - path = "third_party/fmt/support/bazel", -) - -local_repository( - name = "unused_kineto_fmt_bazel", - path = "third_party/kineto/libkineto/third_party/fmt/support/bazel", -) - -local_repository( - name = "unused_kineto_dynolog_googletest", - path = "third_party/kineto/libkineto/third_party/dynolog/third_party/googletest", -) - -local_repository( - name = "unused_kineto_dynolog_gflags", - path = "third_party/kineto/libkineto/third_party/dynolog/third_party/gflags", -) - -local_repository( - name = "unused_kineto_dynolog_glog", - path = "third_party/kineto/libkineto/third_party/dynolog/third_party/glog", -) - -local_repository( - name = "unused_kineto_googletest", - path = "third_party/kineto/libkineto/third_party/googletest", -) - -local_repository( - name = "unused_onnx_benchmark", - path = "third_party/onnx/third_party/benchmark", -) - -### Unused repos end diff --git a/android/pytorch_android/CMakeLists.txt b/android/pytorch_android/CMakeLists.txt index c640cbdeac9f1..4ee3011683011 100644 --- a/android/pytorch_android/CMakeLists.txt +++ b/android/pytorch_android/CMakeLists.txt @@ -14,7 +14,7 @@ endif() include(GNUInstallDirs) -set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_CXX_STANDARD 20 CACHE STRING "The C++ standard whose features are requested to build this target.") set(CMAKE_VERBOSE_MAKEFILE ON) message(STATUS "ANDROID_STL:${ANDROID_STL}") diff --git a/android/pytorch_android_torchvision/CMakeLists.txt b/android/pytorch_android_torchvision/CMakeLists.txt index 2c8931f3fb911..b49b779a18b1d 100644 --- a/android/pytorch_android_torchvision/CMakeLists.txt +++ b/android/pytorch_android_torchvision/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.10) project(pytorch_vision_jni CXX) -set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_CXX_STANDARD 20 CACHE STRING "The C++ standard whose features are requested to build this target.") set(CMAKE_VERBOSE_MAKEFILE ON) set(pytorch_vision_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 1e06d30774d90..06522a9298893 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -892,12 +892,23 @@ if(USE_MPS) metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.1") endforeach() air_to_metallib(kernels_basic.metallib ${AIR_BASIC}) + set(METALLIB_DEPS kernels_basic.metallib) + if(CAN_COMPILE_METAL_40) + foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal}) + cmake_path(GET SHADER STEM TGT_STEM) + string(CONCAT TGT_40 ${TGT_STEM} "_40.air") + list(APPEND AIR_40 ${TGT_40}) + metal_to_air(${SHADER} ${TGT_40} "-std=metal4.0") + endforeach() + air_to_metallib(kernels_40.metallib ${AIR_40}) + list(APPEND METALLIB_DEPS kernels_40.metallib) + endif() add_custom_command( COMMAND echo "// $$(date)" > metallib_dummy.cpp - DEPENDS kernels_basic.metallib + DEPENDS ${METALLIB_DEPS} OUTPUT metallib_dummy.cpp COMMENT "Updating metallibs timestamp") - add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp) + add_custom_target(metallibs DEPENDS ${METALLIB_DEPS} metallib_dummy.cpp) else() file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps") foreach(SHADER ${native_mps_metal} ${native_mps_sparse_metal}) @@ -925,7 +936,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS}) if(NOT INTERN_BUILD_MOBILE) - list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${native_mtia_h} ${cudnn_h} ${hip_h} ${mtia_h} ${xpu_h} ${mps_h} ${native_kleidiai_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h}) + list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${native_mtia_h} ${cudnn_h} ${hip_h} ${mtia_h} ${xpu_h} ${mps_h} ${native_kleidiai_h} ${native_mps_h} ${native_utils_h} ${miopen_h}) # Metal if(USE_PYTORCH_METAL_EXPORT) # Add files needed from exporting metal models(optimized_for_mobile) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index b8cfa42295432..cb07afa1f8fdd 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -85,6 +85,29 @@ std::string precision2str(Float32Precision prec) { TORCH_CHECK(false, "Invalid enum Float32Precision(", static_cast(prec), ")"); } +CuDNNDepthwiseKernel str2cudnn_depthwise(const std::string& name) { + if (name == "auto") + return CuDNNDepthwiseKernel::AUTO; + else if (name == "cudnn") + return CuDNNDepthwiseKernel::CUDNN; + else if (name == "native") + return CuDNNDepthwiseKernel::NATIVE; + TORCH_CHECK(false, "Unknown cuDNN depthwise kernel mode: ", name, + ". Expected one of: auto, cudnn, native"); +} + +std::string cudnn_depthwise2str(CuDNNDepthwiseKernel k) { + switch (k) { + case CuDNNDepthwiseKernel::AUTO: + return "auto"; + case CuDNNDepthwiseKernel::CUDNN: + return "cudnn"; + case CuDNNDepthwiseKernel::NATIVE: + return "native"; + } + TORCH_CHECK(false, "Invalid enum CuDNNDepthwiseKernel(", static_cast(k), ")"); +} + #ifdef USE_ROCM static constexpr const auto rocm_allow_group_gemm_ck = "ROCM_ALLOW_GROUP_GEMM_CK"; #endif @@ -182,6 +205,14 @@ void Context::setUserEnabledNNPACK(bool e) { enabled_nnpack = e; } +CuDNNDepthwiseKernel Context::cudnnDepthwiseKernel() const { + return depthwise_kernel_cudnn; +} + +void Context::setCuDNNDepthwiseKernel(CuDNNDepthwiseKernel k) { + depthwise_kernel_cudnn = k; +} + bool Context::allowTF32CuDNN(std::optional op) const { if (!op.has_value()) { bool allow_tf32_rnn = float32Precision(Float32Backend::CUDA, Float32Op::RNN) == Float32Precision::TF32; @@ -452,37 +483,29 @@ void Context::setLinalgPreferredBackend(at::LinalgBackend b) { } } -at::BlasBackend Context::blasPreferredBackend() { - // Rather than put logic for interpreting what Default means at every - // call site for blasPreferredBackend(), we set it to an actual value. - if (blas_preferred_backend == at::BlasBackend::Default) { -#ifdef USE_ROCM - // ROCm - BLAS is default. May change to Lt in the code below. - blas_preferred_backend = at::BlasBackend::Cublas; -#else - // CUDA - Lt by default if available - blas_preferred_backend = hasCuBLASLt() - ? at::BlasBackend::Cublaslt - : at::BlasBackend::Cublas; -#endif - // This logic sits in the getter because it needs to validate - // values set via env vars such as TORCH_BLAS_PREFER_CUBLASLT - // which initialize the backend without calling the setter +at::BlasBackend Context::blasDefaultBackend() { + at::BlasBackend result = at::BlasBackend::Cublas; #ifdef USE_ROCM - // AMD Instinct targets prefer hipblaslt - static const bool hipblaslt_preferred = []() { - const auto& archs = detail::getCUDAHooks().getHipblasltPreferredArchs(); - for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) { - if (!detail::getCUDAHooks().isGPUArch(archs, index)) { - return false; - } + // AMD Instinct targets prefer hipblaslt + static const bool hipblaslt_preferred = []() { + const auto& archs = detail::getCUDAHooks().getHipblasltPreferredArchs(); + for (auto index : c10::irange(detail::getCUDAHooks().deviceCount())) { + if (!detail::getCUDAHooks().isGPUArch(archs, index)) { + return false; } - return true; - }(); - if (hipblaslt_preferred) { - blas_preferred_backend = at::BlasBackend::Cublaslt; } + return true; + }(); + if (hipblaslt_preferred) { + result = at::BlasBackend::Cublaslt; + } #endif + return result; +} + +at::BlasBackend Context::blasPreferredBackend() { + if (blas_preferred_backend == at::BlasBackend::Default) { + blas_preferred_backend = blasDefaultBackend(); } #ifdef USE_ROCM @@ -699,7 +722,7 @@ at::QEngine Context::qEngine() const { qengine = at::kONEDNN; #endif -#ifdef USE_FBGEMM +#if defined(USE_FBGEMM) && (defined(__x86_64__) || defined(_M_X64)) if (fbgemm::fbgemmSupportedCPU()) { /* X86 is enabled if and only if fbgemm is available. * It combines goodness of fbgemm and onednn by dispatching. @@ -738,7 +761,7 @@ const std::vector& Context::supportedQEngines() { engines.push_back(at::kONEDNN); #endif -#ifdef USE_FBGEMM +#if defined(USE_FBGEMM) && (defined(__x86_64__) || defined(_M_X64)) if (fbgemm::fbgemmSupportedCPU()) { engines.push_back(at::kX86); // The X86 qengine is available if and only if FBGEMM is available diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index de6a7dda66d73..6a7eb4b669da9 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -49,10 +49,14 @@ enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN }; enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL }; enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 }; +enum class TORCH_API CuDNNDepthwiseKernel { AUTO, CUDNN, NATIVE }; + TORCH_API Float32Backend str2backend(const std::string& name); TORCH_API Float32Op str2op(const std::string& name); TORCH_API Float32Precision str2precision(const std::string& name); TORCH_API std::string precision2str(Float32Precision prec); +TORCH_API CuDNNDepthwiseKernel str2cudnn_depthwise(const std::string& name); +TORCH_API std::string cudnn_depthwise2str(CuDNNDepthwiseKernel k); class TORCH_API Context { public: @@ -251,6 +255,9 @@ class TORCH_API Context { bool userEnabledNNPACK() const; void setUserEnabledNNPACK(bool e); + CuDNNDepthwiseKernel cudnnDepthwiseKernel() const; + void setCuDNNDepthwiseKernel(CuDNNDepthwiseKernel k); + // Note [Disabling Fused SDP Kernels] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Flash and Memory Efficient SDP kernels are enabled by default. @@ -289,6 +296,7 @@ class TORCH_API Context { at::LinalgBackend linalgPreferredBackend() const; void setLinalgPreferredBackend(at::LinalgBackend /*b*/); + at::BlasBackend blasDefaultBackend(); at::BlasBackend blasPreferredBackend(); void setBlasPreferredBackend(at::BlasBackend /*b*/); @@ -493,6 +501,7 @@ class TORCH_API Context { bool enabled_mkldnn = true; bool allow_tf32_onednn = false; bool enabled_nnpack = true; + CuDNNDepthwiseKernel depthwise_kernel_cudnn = CuDNNDepthwiseKernel::AUTO; at::LinalgBackend linalg_preferred_backend = (c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true || c10::utils::check_env("TORCH_LINALG_PREFER_HIPSOLVER") == true) // alias diff --git a/aten/src/ATen/DeviceAccelerator.cpp b/aten/src/ATen/DeviceAccelerator.cpp index efab9ec9c5927..fed2e6c4febe8 100644 --- a/aten/src/ATen/DeviceAccelerator.cpp +++ b/aten/src/ATen/DeviceAccelerator.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -136,6 +137,11 @@ c10::DeviceCapability getDeviceCapability(c10::DeviceIndex device_index) { c10::impl::VirtualGuardImpl impl(device_type); return impl.getDeviceCapability({device_type, device_index}); } + +void emptyHostCache() { + const auto device_type = getAccelerator(true).value(); + at::getHostAllocator(device_type)->empty_cache(); +} // NOLINTEND(bugprone-unchecked-optional-access) } // namespace at::accelerator diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index 81678ba6efc29..90fad38a5faca 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -78,11 +78,19 @@ TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); TORCH_API c10::DeviceCapability getDeviceCapability( c10::DeviceIndex device_index); +// Releases all unused device memory currently held by the accelerator's +// device-side caching allocator. The freed memory becomes available for reuse +// by other applications or processes. TORCH_API inline void emptyCache() { const auto device_type = getAccelerator(true).value(); at::getDeviceAllocator(device_type)->emptyCache(); } +// Releases all unused host (pinned) memory currently held by the accelerator's +// host-side caching allocator. The freed memory becomes available for reuse by +// other applications or processes. +TORCH_API void emptyHostCache(); + TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats( c10::DeviceIndex device_index) { const auto device_type = getAccelerator(true).value(); diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index 123d87b304148..c372ae0ad339f 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -105,12 +105,10 @@ static Tensor unsqueeze_copy_to(const Tensor & self, IntArrayRef dim, c10::SymIn Tensor FunctionalInverses::_fw_primal_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t level) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _fw_primal() during the functionalization pass. For now, this is not supported."); - return Tensor(); } Tensor FunctionalInverses::_make_dual_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode, const at::Tensor& tangent, int64_t level) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _make_dual() during the functionalization pass. For now, this is not supported."); - return Tensor(); } Tensor FunctionalInverses::view_as_real_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { @@ -301,7 +299,6 @@ Tensor FunctionalInverses::transpose_int_inverse(const Tensor& base, const Tenso Tensor FunctionalInverses::_nested_view_from_buffer_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& nested_sizes, const Tensor& nested_strides, const Tensor& storage_offsets) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _nested_view_from_buffer() during the functionalization pass. For now, nested tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::_nested_view_from_jagged_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, const Tensor& offsets, const Tensor& dummy, const std::optional& lengths, int64_t ragged_idx, const std::optional& min_seqlen, const std::optional& max_seqlen) { @@ -342,47 +339,38 @@ Tensor FunctionalInverses::unsqueeze_inverse(const Tensor& base, const Tensor& m Tensor FunctionalInverses::_indices_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::_values_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::indices_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::values_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call values() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::_sparse_broadcast_to_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, at::IntArrayRef size) { TORCH_INTERNAL_ASSERT(false, "Attempted to call _sparse_broadcast_to() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::crow_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call crow_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::col_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call col_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::ccol_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call ccol_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::row_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view, InverseReturnMode inverse_return_mode) { TORCH_INTERNAL_ASSERT(false, "Attempted to call row_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization"); - return Tensor(); } Tensor FunctionalInverses::unbind_int_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, int64_t dim) { diff --git a/aten/src/ATen/TensorIterator.h b/aten/src/ATen/TensorIterator.h index d8593a80292b3..f5616b41e7ab1 100644 --- a/aten/src/ATen/TensorIterator.h +++ b/aten/src/ATen/TensorIterator.h @@ -734,7 +734,7 @@ struct TORCH_API TensorIteratorBase : public impl::MetaBase { }; struct TORCH_API TensorIterator final : public TensorIteratorBase { - TensorIterator() : TensorIteratorBase() {} + TensorIterator() = default; // Slicing is OK, TensorIterator guaranteed NOT to have any fields TensorIterator(const TensorIteratorBase& iter) : TensorIteratorBase(iter) {} diff --git a/aten/src/ATen/Version.cpp b/aten/src/ATen/Version.cpp index 682883b32c187..4a1c1525157f3 100644 --- a/aten/src/ATen/Version.cpp +++ b/aten/src/ATen/Version.cpp @@ -218,8 +218,9 @@ std::string get_cxx_flags() { "Buck does not populate the `CXX_FLAGS` field of Caffe2 build options. " "As a result, `get_cxx_flags` is OSS only." ); - #endif + #else return caffe2::GetBuildOptions().at("CXX_FLAGS"); + #endif } } diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index 7715ece7c6187..bb9c44dcf0beb 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -570,6 +570,10 @@ class TORCH_API TensorBase { return impl_->is_meta(); } + bool is_fake() const { + return impl_->is_fake(); + } + /// Returns if a `Tensor` is an inference tensor. bool is_inference() const { return impl_->is_inference(); diff --git a/aten/src/ATen/core/Vitals.cpp b/aten/src/ATen/core/Vitals.cpp deleted file mode 100644 index db58c03830539..0000000000000 --- a/aten/src/ATen/core/Vitals.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include -#include -#include - -namespace at::vitals { - -APIVitals VitalsAPI; - -std::ostream& operator<<(std::ostream& os, TorchVital const& tv) { - for (const auto& m : tv.attrs) { - os << "[TORCH_VITAL] " << tv.name << '.' << m.first << "\t\t " - << m.second.value << '\n'; - } - return os; -} - -TorchVital::~TorchVital() { - if (torchVitalEnabled()) { - std::cout << *this; - } -} - -TorchVitalAttr& TorchVital::create(const std::string& attr) { - return create(attr, /* force = */ false); -} - -TorchVitalAttr& TorchVital::create(const std::string& attr, bool force) { - if (!(torchVitalEnabled() || force)) { - static TorchVitalAttr disabled; - return disabled; - } - auto iter = attrs.find(attr); - if (iter == attrs.end()) { - auto r = attrs.emplace(attr, TorchVitalAttr()); - return r.first->second; - } - return iter->second; -} - -bool torchVitalEnabled() { - // If this is a performance hit, make `enabled` variable static - // and return `const bool&` instead - bool enabled = []() { - auto const e = c10::utils::get_env("TORCH_VITAL"); - if (e.has_value()) { - return !e.value().empty(); - } - return false; - }(); - if (enabled) { - VitalsAPI.vitals_enabled = true; - } - return VitalsAPI.vitals_enabled; -} - -std::string APIVitals::readVitals() { - if (!torchVitalEnabled()) { - return ""; - } - - std::stringstream buf; - for (const auto& x : name_map_) { - buf << x.second; - } - return buf.str(); -} - -bool APIVitals::setVital( - const std::string& vital_name, - const std::string& attr_name, - const std::string& value, - bool force) { - if (!(torchVitalEnabled() || force)) { - return false; - } - - auto iter = name_map_.find(vital_name); - TorchVital* vital = nullptr; - if (iter == name_map_.end()) { - auto r = name_map_.emplace(vital_name, TorchVital(vital_name)); - vital = &r.first->second; - } else { - vital = &iter->second; - } - - vital->create(attr_name, force).write(value, force); - return true; -} - -APIVitals::APIVitals() : vitals_enabled(false) { - // Set default values, force is necessary because in unit tests the env - // variable may not be set when global APIVitals are constructed. - setVital("CUDA", "used", "False", /* force = */ true); -} - -} // namespace at::vitals diff --git a/aten/src/ATen/core/Vitals.h b/aten/src/ATen/core/Vitals.h deleted file mode 100644 index 2fd7729744a10..0000000000000 --- a/aten/src/ATen/core/Vitals.h +++ /dev/null @@ -1,94 +0,0 @@ -#pragma once -#include -#include -#include - -#include - -namespace at::vitals { - -TORCH_API bool torchVitalEnabled(); - -struct TORCH_API TorchVitalAttr { - // always initialized to empty - std::string value; - template - TorchVitalAttr& operator<<(const T& t) { - if (torchVitalEnabled()) { - std::stringstream ss; - ss << t; - value += ss.str(); - } - return *this; - } - - template - void write(const T& t, bool force) { - if (force || torchVitalEnabled()) { - std::stringstream ss; - ss << t; - value = ss.str(); - } - } -}; - -struct TORCH_API TorchVital { - std::string name; - std::unordered_map attrs; - - explicit TorchVital(std::string n) : name(std::move(n)) {} - TorchVital(const TorchVital&) = default; - TorchVital(TorchVital&&) = default; - TorchVital& operator=(const TorchVital&) = default; - TorchVital& operator=(TorchVital&&) = default; - TorchVital() = delete; - - TorchVitalAttr& create(const std::string& attr); - TorchVitalAttr& create(const std::string& attr, bool force); - friend std::ostream& operator<<(std::ostream& os, const TorchVital& dt); - - ~TorchVital(); -}; - -std::ostream& operator<<(std::ostream& os, TorchVital const& tv); - -// A way to access vitals by string names instead of by global reference. -// This enables access to vitals from the PythonAPI. -class TORCH_API APIVitals { - public: - bool vitals_enabled; - - // Set any vital sign that was added to the map. - bool setVital( - const std::string& vital_name, - const std::string& attr_name, - const std::string& value, - bool force = false); - std::string readVitals(); - - APIVitals(); - - // Ensure this stays a singleton - APIVitals(APIVitals const& other) = delete; - APIVitals(APIVitals&& other) = delete; - APIVitals& operator=(const APIVitals&) = delete; - APIVitals& operator=(APIVitals&&) = delete; - ~APIVitals() = default; - - private: - std::unordered_map name_map_; -}; - -extern TORCH_API APIVitals VitalsAPI; - -} // namespace at::vitals - -#define TORCH_VITAL_DECLARE(name) \ - TORCH_API at::vitals::TorchVital TorchVital_##name; - -#define TORCH_VITAL_DEFINE(name) \ - TORCH_API at::vitals::TorchVital TorchVital_##name(#name); - -#define TORCH_VITAL_BASE(name) TorchVital_##name - -#define TORCH_VITAL(name, attr) TORCH_VITAL_BASE(name).create(#attr) diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index 6b63bd48009ee..f8998c5addc6d 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -552,10 +552,8 @@ class TORCH_API OperatorHandle { } template - PyObject* getPythonOp( - c10::impl::PyInterpreter* self_interpreter, - F slow_accessor) const { - return operatorDef_->op.getPythonOp(self_interpreter, slow_accessor); + PyObject* getPythonOp(F slow_accessor) const { + return operatorDef_->op.getPythonOp(slow_accessor); } bool operator==(const OperatorHandle& other) const { diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index cc5736ba0e77e..05e84b0a442f3 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -224,9 +224,8 @@ class TORCH_API OperatorEntry final { void setReportErrorCallback_(std::unique_ptr callback); template - PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor) - const { - return py_cache_.ptr_or(self_interpreter, slow_accessor); + PyObject* getPythonOp(F slow_accessor) const { + return py_cache_.ptr_or(slow_accessor); } private: diff --git a/aten/src/ATen/core/dynamic_type.cpp b/aten/src/ATen/core/dynamic_type.cpp index 2b1a32bd0ac8a..85900a8f1d230 100644 --- a/aten/src/ATen/core/dynamic_type.cpp +++ b/aten/src/ATen/core/dynamic_type.cpp @@ -226,8 +226,7 @@ TypeKind DynamicType::dynamicKind() const { // resolve to integers #undef CASE_TYPE default: - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); - return TypeKind::AnyType; + TORCH_INTERNAL_ASSERT_FALSE_OR_RETURN(TypeKind::AnyType); } } @@ -310,8 +309,7 @@ TypePtr DynamicType::fallback() const { case Tag::Any: return AnyType::get(); } - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); - return nullptr; + TORCH_INTERNAL_ASSERT_FALSE_OR_RETURN(nullptr); } bool DynamicType::LabeledDynamicType::isSubtypeOf( @@ -372,8 +370,7 @@ DynamicTypePtr ivalue::TupleTypeFactory::create( DynamicTypePtr ivalue::TupleTypeFactory::fallback( const Type& /*unused*/) { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); - return nullptr; + TORCH_INTERNAL_ASSERT_FALSE_OR_RETURN(nullptr); } TORCH_API TupleTypePtr ivalue::TupleTypeFactory::fallback( diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 38942031befcd..532c921ca4258 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -187,7 +187,6 @@ namespace c10 { _(aten, append) \ _(aten, as_tensor) \ _(aten, adaptive_avg_pool2d_backward) \ - _(aten, dim) \ _(aten, format) \ _(aten, percentFormat) \ _(aten, __not__) \ diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index b2236dbb68828..4bfef7cfde5d9 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -371,7 +371,7 @@ inline ShapeSymbol merge_primitive( // dims, partially known and fully known shapes are all supported. struct TORCH_API SymbolicShape { // Unranked shape constructor. - SymbolicShape() : dims_(std::nullopt) {} + SymbolicShape() = default; // Known rank but unknown dimensions. SymbolicShape(std::optional rank) : dims_(std::nullopt) { diff --git a/aten/src/ATen/cpu/vec/sve/vec_float.h b/aten/src/ATen/cpu/vec/sve/vec_float.h index 743ebde5105a5..8136e6b8c6ee9 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_float.h +++ b/aten/src/ATen/cpu/vec/sve/vec_float.h @@ -285,8 +285,56 @@ class Vectorized { svfloat32_t poly = svmla_x(svptrue_b32(), r, r2, c1); return svmla_x(svptrue_b32(), scale, scale, poly); } + // Implementation from Arm Optimized Routines: + // https://github.com/ARM-software/optimized-routines/blob/v26.01/math/aarch64/experimental/sve/sv_expf_inline.h Vectorized fexp_u20() const { - return exp_u20(); + // fast exponential intended for cases where outputs will be downcasted to + // FP16 / BF16 (e.g. attention softmax). + // Accurate within 1 ULP for FP16 + // Accurate within 1 ULP for BF16 for inputs in [-87.346, max_float] & + // clamps + // inputs < -87.346 to zero. + // Implementation is similar to exp_u20, but: + // - approximates exp(r) - 1 as r instead of r + 0.5 r^2 + // - does not split natural log (ln) into high / low parts + // - avoids special case code by clamping exp(x) to 0 for x < -87.346 and + // inf for x > 88.717 + + constexpr float upper_bound = 0x1.62dea4p+6f; + constexpr float lower_bound = -0x1.5d619ap+6f; + + const svfloat32_t ln2 = svdup_n_f32(0x1.62e43p-1f); + const svfloat32_t inv_ln2 = svdup_n_f32(0x1.715476p+0f); + + constexpr float shift = 0x1.803f8p17f; + + // exp(x) = 2^n (1 + poly(r)), with 1 + poly(r) in [1/sqrt(2),sqrt(2)] + // x = ln2*n + r, with r in [-ln2/2, ln2/2] and poly(r) ~= r. + + // n = round(x/(ln2/N)) + svfloat32_t z = svmad_x(svptrue_b32(), inv_ln2, values, shift); + svfloat32_t n = svsub_x(svptrue_b32(), z, shift); + + // n = round(x/(ln2/N)) + svfloat32_t r = svmls_x(svptrue_b32(), values, n, ln2); + + // scale = 2^(n/N) + svfloat32_t scale = svexpa(svreinterpret_u32(z)); + + // poly(r) = exp(r) - 1 ~= r + svfloat32_t y = svmla_x(svptrue_b32(), scale, scale, r); + + // clamp to 0, inf + y = svsel_f32( + svcmplt_f32(svptrue_b32(), values, svdup_n_f32(lower_bound)), + svdup_n_f32(0.0f), + y); + y = svsel_f32( + svcmpgt_f32(svptrue_b32(), values, svdup_n_f32(upper_bound)), + svdup_n_f32(INFINITY), + y); + + return y; } Vectorized fmod(const Vectorized& q) const { USE_SLEEF( diff --git a/aten/src/ATen/cpu/vec/sve/vec_qint.h b/aten/src/ATen/cpu/vec/sve/vec_qint.h index c3107720dc8d5..b3c5db4c30962 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_qint.h +++ b/aten/src/ATen/cpu/vec/sve/vec_qint.h @@ -139,7 +139,7 @@ struct VectorizedQuantizedConverter { } protected: - VectorizedQuantizedConverter() {} + VectorizedQuantizedConverter() = default; }; template <> diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_convert.h b/aten/src/ATen/cpu/vec/vec128/vec128_convert.h index 6ab8bf6a98dc7..2f4eaeeabc9ad 100644 --- a/aten/src/ATen/cpu/vec/vec128/vec128_convert.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_convert.h @@ -373,6 +373,8 @@ struct VecConvert { } }; +// bf16/fp16 vec classes are not available for C10_MOBILE +#if !defined(C10_MOBILE) template <> struct VecConvert { static inline VectorizedN apply( @@ -392,6 +394,8 @@ struct VecConvert { } }; +#endif // !defined(C10_MOBILE) + #endif // defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256) } // namespace CPU_CAPABILITY } // namespace at::vec diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h index 425fb6aa79e13..86cb950501686 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_16bit_float.h @@ -195,7 +195,7 @@ class Vectorized16 { static constexpr size_type size() { return 16; } - Vectorized16() {} + Vectorized16() = default; Vectorized16(__m256i v) : values(v) {} Vectorized16(T val) { value_type uw = val.x; diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_int.h b/aten/src/ATen/cpu/vec/vec256/vec256_int.h index eac5c710c9002..70a942a3e6677 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_int.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_int.h @@ -194,7 +194,7 @@ class Vectorized : public Vectorizedi { return 8; } using Vectorizedi::Vectorizedi; - Vectorized() {} + Vectorized() = default; Vectorized(int32_t v) { values = _mm256_set1_epi32(v); } @@ -412,7 +412,7 @@ class Vectorized : public Vectorizedi { return 16; } using Vectorizedi::Vectorizedi; - Vectorized() {} + Vectorized() = default; Vectorized(int16_t v) { values = _mm256_set1_epi16(v); } @@ -642,7 +642,7 @@ class Vectorized8 : public Vectorizedi { return 32; } using Vectorizedi::Vectorizedi; - Vectorized8() {} + Vectorized8() = default; Vectorized8(T v) { values = _mm256_set1_epi8(v); } diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h index 9e12589b58590..087737d65c29a 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h @@ -359,7 +359,7 @@ struct Vectorized : public Vectorizedqi { public: using Vectorizedqi::Vectorizedqi; - Vectorized() {} + Vectorized() = default; Vectorized(__m256i vals_) { vals = vals_; @@ -566,7 +566,7 @@ struct Vectorized : public Vectorizedqi { public: using Vectorizedqi::Vectorizedqi; - Vectorized() {} + Vectorized() = default; Vectorized(__m256i vals_) { vals = vals_; } @@ -778,7 +778,7 @@ struct Vectorized : public Vectorizedqi { public: using Vectorizedqi::Vectorizedqi; - Vectorized() {} + Vectorized() = default; Vectorized(__m256i vals_) { vals = vals_; diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h index a6a883e53b39b..34e8e3e7a6ced 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_double_vsx.h @@ -36,7 +36,7 @@ class Vectorized { static constexpr size_type size() { return 2; } - Vectorized() {} + Vectorized() = default; C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {} C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h index 9acc79cdeb4c5..0f5aa527f41bf 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_complex_float_vsx.h @@ -39,7 +39,7 @@ class Vectorized { static constexpr size_type size() { return 4; } - Vectorized() {} + Vectorized() = default; C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {} C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h index db574702f3ee1..24fa15cc35fc9 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_double_vsx.h @@ -38,7 +38,7 @@ class Vectorized { static constexpr size_type size() { return 4; } - Vectorized() {} + Vectorized() = default; C10_ALWAYS_INLINE Vectorized(vfloat64 v) : _vec0{v}, _vec1{v} {} C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} C10_ALWAYS_INLINE Vectorized(vfloat64 v1, vfloat64 v2) diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h index 535d3a23173d5..a50ab8a8d7410 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_float_vsx.h @@ -37,7 +37,7 @@ class Vectorized { static constexpr size_type size() { return 8; } - Vectorized() {} + Vectorized() = default; C10_ALWAYS_INLINE Vectorized(vfloat32 v) : _vec0{v}, _vec1{v} {} C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h index 7176dd15d75ed..5fd83af090fed 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int16_vsx.h @@ -34,7 +34,7 @@ class Vectorized { static constexpr size_type size() { return 16; } - Vectorized() {} + Vectorized() = default; C10_ALWAYS_INLINE Vectorized(vint16 v) : _vec0{v}, _vec1{v} {} C10_ALWAYS_INLINE Vectorized(vbool16 vmask) : _vecb0{vmask}, _vecb1{vmask} {} C10_ALWAYS_INLINE Vectorized(vint16 v1, vint16 v2) : _vec0{v1}, _vec1{v2} {} diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h index 75d3ba381ad41..1c85df8dafa27 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int32_vsx.h @@ -34,7 +34,7 @@ class Vectorized { static constexpr size_type size() { return 8; } - Vectorized() {} + Vectorized() = default; C10_ALWAYS_INLINE Vectorized(vint32 v) : _vec0{v}, _vec1{v} {} C10_ALWAYS_INLINE Vectorized(vbool32 vmask) : _vecb0{vmask}, _vecb1{vmask} {} C10_ALWAYS_INLINE Vectorized(vint32 v1, vint32 v2) : _vec0{v1}, _vec1{v2} {} diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h index 653c277b7d033..e37fc33a656d9 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_int64_vsx.h @@ -35,7 +35,7 @@ class Vectorized { static constexpr size_type size() { return 4; } - Vectorized() {} + Vectorized() = default; C10_ALWAYS_INLINE Vectorized(vint64 v) : _vec0{v}, _vec1{v} {} C10_ALWAYS_INLINE Vectorized(vbool64 vmask) : _vecb0{vmask}, _vecb1{vmask} {} C10_ALWAYS_INLINE Vectorized(vint64 v1, vint64 v2) : _vec0{v1}, _vec1{v2} {} diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h index ad895bf54d95a..bccc598a19d57 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h @@ -49,7 +49,7 @@ struct Vectorized { } __attribute__((__may_alias__)); public: - Vectorized() {} + Vectorized() = default; using size_type = int; static constexpr size_type size() { diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h index a707155aad787..a87ad75b36c8f 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint8_vsx.h @@ -49,7 +49,7 @@ struct Vectorized { } __attribute__((__may_alias__)); public: - Vectorized() {} + Vectorized() = default; using size_type = int; static constexpr size_type size() { return 32; diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h index 5863df6bd667c..f864ee945426f 100644 --- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h +++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_quint8_vsx.h @@ -53,7 +53,7 @@ struct Vectorized { } __attribute__((__may_alias__)); public: - Vectorized() {} + Vectorized() = default; using size_type = int; static constexpr size_type size() { return 32; diff --git a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h index efb97b3c614db..bce551e880ee1 100644 --- a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h +++ b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h @@ -386,7 +386,7 @@ struct Vectorized()>> { static constexpr size_type size() { return VECTOR_WIDTH / sizeof(ElementType); } - Vectorized() {} + Vectorized() = default; C10_ALWAYS_INLINE Vectorized(vtype v) : _vec0{v}, _vec1{v} {} C10_ALWAYS_INLINE Vectorized(const vinner_data& v) @@ -1781,7 +1781,7 @@ struct Vectorized()>> { vinner_type _vec; public: - Vectorized() {} + Vectorized() = default; explicit C10_ALWAYS_INLINE Vectorized(vinner_type v) : _vec{v} {} Vectorized(const T& val) : _vec(val.val_) {} @@ -2250,7 +2250,7 @@ struct Vectorized()>> { vinner_type _vec; public: - Vectorized() {} + Vectorized() = default; C10_ALWAYS_INLINE Vectorized(const vinner_data& v) : _vec{v.first, v.second} {} diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float8.h b/aten/src/ATen/cpu/vec/vec512/vec512_float8.h index 2a4ca9b958c55..e94e6fbf5ba9e 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_float8.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_float8.h @@ -125,10 +125,15 @@ static inline __m128i cvtfp32_fp8e4m3(const __m512& src) { __m512i result = _mm512_setzero_si512(); // Step 1: Handle case of overflow - // (f_bits >= fp8_max): set result = 0x7f + // NaN (f_bits > fp32_inf): set result = 0x7f + // Finite overflow or inf (fp8_max <= f_bits <= fp32_inf): saturate to 0x7e + const __m512i fp32_inf = _mm512_set1_epi32(UINT32_C(0x7F800000)); __mmask16 overflow_mask = _mm512_cmpge_epu32_mask(f_bits, fp8_max); if (overflow_mask) { - result = _mm512_mask_set1_epi32(result, overflow_mask, 0x7f); + __mmask16 nan_mask = _mm512_cmpgt_epu32_mask(f_bits, fp32_inf); + __mmask16 sat_mask = overflow_mask & ~nan_mask; + result = _mm512_mask_set1_epi32(result, nan_mask, 0x7f); + result = _mm512_mask_set1_epi32(result, sat_mask, 0x7e); } // Step 2: Handle small numbers (denormals) @@ -158,6 +163,10 @@ static inline __m128i cvtfp32_fp8e4m3(const __m512& src) { rounded = _mm512_add_epi32(rounded, mant_odd); // Shift right by 20 bits __m512i normal_result = _mm512_srli_epi32(rounded, 20); + // Rounding may carry into the NaN bit pattern (0x7f); saturate to max + __mmask16 round_overflow = + _mm512_cmpeq_epi32_mask(normal_result, _mm512_set1_epi32(0x7f)); + normal_result = _mm512_mask_set1_epi32(normal_result, round_overflow, 0x7e); result = _mm512_mask_mov_epi32(result, normal_mask, normal_result); } @@ -338,7 +347,7 @@ class Vectorizedf8 { static constexpr size_type size() { return 64; } - Vectorizedf8() {} + Vectorizedf8() = default; Vectorizedf8(__m512i v) : values(v) {} Vectorizedf8(T val) { value_type uw = val.x; @@ -377,6 +386,15 @@ class Vectorizedf8 { } } + static Vectorized blendv( + const Vectorized& a, + const Vectorized& b, + const Vectorized& mask) { + auto msb_one = _mm512_set1_epi8(0xFF); + auto mask_ = _mm512_cmp_epu8_mask((__m512i)mask, msb_one, _MM_CMPINT_EQ); + return _mm512_mask_blend_epi8(mask_, (__m512i)a, (__m512i)b); + } + Vectorized abs() const { return _mm512_andnot_si512(_mm512_set1_epi8(0x80), values); } diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_int.h b/aten/src/ATen/cpu/vec/vec512/vec512_int.h index 236c31e24244d..f573b9e12492d 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_int.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_int.h @@ -23,7 +23,7 @@ struct Vectorizedi { } public: - Vectorizedi() {} + Vectorizedi() = default; Vectorizedi(__m512i v) : values(v) {} operator __m512i() const { return values; @@ -210,7 +210,7 @@ class Vectorized : public Vectorizedi { return 16; } using Vectorizedi::Vectorizedi; - Vectorized() {} + Vectorized() = default; Vectorized(int32_t v) { values = _mm512_set1_epi32(v); } @@ -459,7 +459,7 @@ class Vectorized : public Vectorizedi { return 32; } using Vectorizedi::Vectorizedi; - Vectorized() {} + Vectorized() = default; Vectorized(int16_t v) { values = _mm512_set1_epi16(v); } @@ -740,7 +740,7 @@ class Vectorized8 : public Vectorizedi { return 64; } using Vectorizedi::Vectorizedi; - Vectorized8() {} + Vectorized8() = default; Vectorized8(T v) { values = _mm512_set1_epi8(v); } diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h index d5899b4b4f948..94bd60dfa8799 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h @@ -386,7 +386,7 @@ struct Vectorized : public Vectorizedqi { public: using Vectorizedqi::Vectorizedqi; - Vectorized() {} + Vectorized() = default; Vectorized(__m512i vals_) { vals = vals_; @@ -602,7 +602,7 @@ struct Vectorized : public Vectorizedqi { public: using Vectorizedqi::Vectorizedqi; - Vectorized() {} + Vectorized() = default; Vectorized(__m512i vals_) { vals = vals_; } @@ -838,7 +838,7 @@ struct Vectorized : public Vectorizedqi { public: using Vectorizedqi::Vectorizedqi; - Vectorized() {} + Vectorized() = default; Vectorized(__m512i vals_) { vals = vals_; @@ -1133,7 +1133,7 @@ struct VectorizedQuantizedConverter { } protected: - VectorizedQuantizedConverter() {} + VectorizedQuantizedConverter() = default; }; template <> diff --git a/aten/src/ATen/cpu/vec/vec_convert.h b/aten/src/ATen/cpu/vec/vec_convert.h index b601ab0bc07b8..3aa4df91f25bc 100644 --- a/aten/src/ATen/cpu/vec/vec_convert.h +++ b/aten/src/ATen/cpu/vec/vec_convert.h @@ -76,12 +76,24 @@ inline std::enable_if_t, Vectorized> convert return src; } +template +inline std::enable_if_t, Vectorized> +round_convert(const Vectorized& src) { + return src; +} + template inline std::enable_if_t, Vectorized> convert(const Vectorized& src) { return VecConvert::apply(src); } +template +inline std::enable_if_t, Vectorized> +round_convert(const Vectorized& src) { + return VecRoundConvert::apply(src); +} + template < typename dst_t, int dst_n, @@ -103,12 +115,6 @@ inline VectorizedN round_convert( return VecRoundConvert::apply(src); } -template -inline std::enable_if_t, Vectorized> -round_convert(const Vectorized& src) { - return VecRoundConvert::apply(src); -} - template < typename dst_t, int dst_n, @@ -121,6 +127,18 @@ convert(const VectorizedN& src) { return VecConvert::apply(src); } +template < + typename dst_t, + int dst_n, + typename src_t, + int src_n, + bool keep = false, + std::enable_if_t = 0> +inline std::conditional_t, Vectorized> +round_convert(const VectorizedN& src) { + return VecRoundConvert::apply(src); +} + } // namespace CPU_CAPABILITY template < diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index b9e66fea5ebdb..748eecbc1572a 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -1956,7 +1956,6 @@ case ScalingType::TensorWise: default: TORCH_CHECK(false); - return -1; } } diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h index 6251d40e6daad..66d4bce542039 100644 --- a/aten/src/ATen/cuda/CUDAContextLight.h +++ b/aten/src/ATen/cuda/CUDAContextLight.h @@ -85,13 +85,13 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator(); /* Handles */ TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle(); -TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle(); +TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle(bool setup = true); TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle(); TORCH_CUDA_CPP_API void clearCublasWorkspaces(); TORCH_CUDA_CPP_API void clearCublasWorkspacesForStream(cudaStream_t stream); struct WorkspaceMapWithMutex { - std::map, at::DataPtr> map; + std::map, std::pair> map; std::shared_mutex mutex; }; @@ -100,6 +100,10 @@ TORCH_CUDA_CPP_API WorkspaceMapWithMutex& cublaslt_handle_stream_to_workspace(); TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize(); TORCH_CUDA_CPP_API size_t getCUDABlasLtWorkspaceSize(); TORCH_CUDA_CPP_API void* getCUDABlasLtWorkspace(); +TORCH_CUDA_CPP_API void setChosenWorkspaceSize(size_t size); +TORCH_CUDA_CPP_API void setCUDABlasLtWorkspaceSize(size_t size); +TORCH_CUDA_CPP_API void resetChosenWorkspaceSize(); +TORCH_CUDA_CPP_API void resetCUDABlasLtWorkspaceSize(); TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle(); diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp index 536020ffa1c86..9b690502bf7a5 100644 --- a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp +++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp @@ -82,151 +82,144 @@ Generator createCUDAGenerator(DeviceIndex device_index) { } // namespace cuda::detail /** - * Creates a clone of this CUDA Generator State. + * Allocate GPU tensors for this capture state. + * + * We allocate on the default stream so that the caching allocator routes + * these tensors to the default memory pool, not the graph's capture pool. */ +void CUDAGeneratorCaptureState::initialize(uint64_t seed) { + if (is_initialized()) { + return; + } + + auto options = at::TensorOptions().device(at::kCUDA).dtype(at::kLong); + c10::InferenceMode inference_guard(false); + + // Allocate on the default stream so that the caching allocator routes + // these tensors to the default memory pool, not the graph's capture pool. + // The relaxed capture mode guard is needed because the thread-local capture + // mode may be Global (set by cudaStreamBeginCapture), which would block + // cudaMalloc even on a non-capturing stream. + c10::cuda::CUDAStreamCaptureModeGuard capture_mode_guard( + cudaStreamCaptureModeRelaxed); + c10::cuda::CUDAStreamGuard stream_guard(c10::cuda::getDefaultCUDAStream()); + + rng_state_seed_extragraph_ = at::empty({1}, options); + rng_state_offset_extragraph_ = at::empty({1}, options); + + // Synchronize the default stream so that any prior work completes before + // a different stream writes to this memory. + c10::cuda::getDefaultCUDAStream().synchronize(); + + offset_intragraph_ = 0; +} + +void CUDAGeneratorCaptureState::increase(uint64_t increment) { + // see Note [Why enforce RNG offset % 4 == 0?] + TORCH_INTERNAL_ASSERT( + offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4."); + TORCH_INTERNAL_ASSERT( + offset_intragraph_ <= std::numeric_limits::max() - increment, + "Increment causes overflow in the offset value."); + offset_intragraph_ += increment; +} + +uint64_t CUDAGeneratorCaptureState::finalize() { + uint64_t result = offset_intragraph_; + offset_intragraph_ = 0; + return result; +} + +void CUDAGeneratorCaptureState::setup_for_replay(uint64_t seed, uint64_t philox_offset) { + TORCH_INTERNAL_ASSERT(is_initialized(), + "Capture state not initialized"); + rng_state_seed_extragraph_.fill_(static_cast(seed)); + rng_state_offset_extragraph_.fill_(static_cast(philox_offset)); +} + c10::intrusive_ptr CUDAGeneratorState::clone() { - return make_intrusive( - seed_, philox_offset_per_thread_, offset_intragraph_); + return make_intrusive(seed_, philox_offset_per_thread_); } /** - * Function to increase the internal offset based on the specified increment. + * Lookup capture state for a capture ID. Returns nullptr if not found. */ -void CUDAGeneratorState::increase(uint64_t increment) { - // Rounds increment up to the nearest multiple of 4 to meet alignment - // requirements. - // see Note [Why enforce RNG offset % 4 == 0?] - increment = ((increment + 3) / 4) * 4; - // Handling different behaviors based on whether capturing is active. - if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { - // Ensures that the state is actually capturing. - TORCH_CHECK( - capturing_, - "Attempt to increase offset for a CUDA generator not in capture mode."); - // Ensures the offset is a multiple of 4 - // see Note [Why enforce RNG offset % 4 == 0?] - TORCH_INTERNAL_ASSERT( - offset_intragraph_ % 4 == 0, "RNG offset must be a multiple of 4."); - // Ensures the increment does not cause overflow. - TORCH_INTERNAL_ASSERT( - offset_intragraph_ <= std::numeric_limits::max() - increment, - "Increment causes overflow in the offset value."); - offset_intragraph_ += increment; - } else { - // Checks that the increment is expected outside graph capturing. - TORCH_CHECK( - !capturing_, - "Offset increment outside graph capture encountered unexpectedly."); - // Ensures the offset is a multiple of 4 - // see Note [Why enforce RNG offset % 4 == 0?] - TORCH_INTERNAL_ASSERT( - philox_offset_per_thread_ % 4 == 0, - "RNG offset must be a multiple of 4."); - philox_offset_per_thread_ += increment; +CUDAGeneratorCaptureState* CUDAGeneratorState::get_capture_state(CaptureId_t capture_id) { + std::lock_guard lock(capture_states_mutex_); + auto it = capture_states_.find(capture_id); + if (it != capture_states_.end()) { + return it->second.get(); } + return nullptr; } /** - * Registers this state to a CUDA graph to manage within the graph. + * Create and initialize capture state for a given capture ID. + * Called during capture_begin for each registered generator. */ -void CUDAGeneratorState::register_graph(cuda::CUDAGraph* graph) { - // Ensures that the RNG state is not currently being captured. - at::cuda::assertNotCapturing( - "Cannot register the state during capturing stage."); - - // If this is the first graph to be registered, allocate memory for the seed - // and offset on the GPU. - if (registered_graphs_.empty()) { - auto options = at::TensorOptions().device(at::kCUDA).dtype(at::kLong); - // Create these tensors outside of inference mode to ensure they can be - // modified in-place later. If we create them as inference tensors, - // subsequent fill_() calls outside inference mode - // will fail with "Inplace update to inference tensor outside InferenceMode". - c10::InferenceMode guard(false); - seed_extragraph_ = at::empty({1}, options); - offset_extragraph_ = at::empty({1}, options); +void CUDAGeneratorState::init_capture_state(CaptureId_t capture_id) { + { + std::lock_guard lock(capture_states_mutex_); + if (capture_states_.count(capture_id)) { + return; + } } - // Insert the graph into the set of registered graphs if it's not already - // registered. - if (registered_graphs_.find(graph) == registered_graphs_.end()) { - registered_graphs_.insert(graph); + auto capture_state = make_intrusive(); + capture_state->initialize(seed_); + + std::lock_guard lock(capture_states_mutex_); + if (!capture_states_.count(capture_id)) { + capture_states_[capture_id] = std::move(capture_state); } } /** - * Unregisters a CUDA graph from the RNG state. + * Function to increase the internal offset based on the specified increment. */ -void CUDAGeneratorState::unregister_graph(cuda::CUDAGraph* graph) { - // Verify the graph was previously registered. - TORCH_CHECK( - registered_graphs_.find(graph) != registered_graphs_.end(), - "The graph should be registered to the state"); - - // Remove the graph from the set of registered graphs. - registered_graphs_.erase(graph); - - // If no more graphs are registered, deallocate the GPU memory for the seed - // and offset. - if (registered_graphs_.empty()) { - seed_extragraph_.reset(); - offset_extragraph_.reset(); +void CUDAGeneratorState::increase(uint64_t increment) { + // see Note [Why enforce RNG offset % 4 == 0?] + increment = ((increment + 3) / 4) * 4; + + auto capture_id = at::cuda::currentStreamCaptureId(); + if (capture_id.has_value()) { + auto* capture_state = get_capture_state(capture_id.value()); + TORCH_CHECK(capture_state != nullptr, + "RNG op during graph capture but generator is not registered with " + "the capturing graph. Call graph.register_generator_state() before " + "capture_begin()."); + capture_state->increase(increment); + } else { + TORCH_INTERNAL_ASSERT( + philox_offset_per_thread_ % 4 == 0, + "RNG offset must be a multiple of 4."); + philox_offset_per_thread_ += increment; } } -/** - * Note [Explicit Registration of Generators to the CUDA Graph] - * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - * - * Ideally, it would be more user-friendly if the state could be exchanged and generators - * could be registered with the CUDA graph implicitly. However, resetting GPU tensors during - * the capture stage causes these reset operations to be recorded within the CUDA graph. - * This behavior is undesirable because we do not want these tensors to be reset during - * the replay stage of the graph. - * - * As of now, there is no available method to perform a CUDA operation during the graph's - * recording phase without having that operation be included in the CUDA graph. - * This limitation necessitates explicit user action to register generators with the graph. - * By requiring users to manually register their generators, we can ensure that state resets - * (capture_prologue) only occur before the graph capture begins, thus avoiding unintended - * resets during the replay of the graph. See https://github.com/pytorch/pytorch/pull/114068. - */ - -/** - * Performs the prologue steps for capturing a CUDA graph state. - * This method is intended to reset graph-related state variables before capturing begins. - */ -void CUDAGeneratorState::capture_prologue() { - capturing_ = true; - offset_intragraph_ = 0; - seed_extragraph_.fill_(static_cast(seed_)); - offset_extragraph_.fill_(0); +uint64_t CUDAGeneratorState::capture_epilogue(CaptureId_t capture_id) { + auto* capture_state = get_capture_state(capture_id); + if (capture_state) { + return capture_state->finalize(); + } + return 0; } -/** - * Ends the capturing phase and resets related variables, returning the whole - * graph increment. - */ -uint64_t CUDAGeneratorState::capture_epilogue() { - capturing_ = false; - return offset_intragraph_; +void CUDAGeneratorState::remove_capture_state(CaptureId_t capture_id) { + std::lock_guard lock(capture_states_mutex_); + capture_states_.erase(capture_id); } -/** - * Prepares the state for replay by setting initial state tensors and applying - * total increment. - */ -void CUDAGeneratorState::replay_prologue(uint64_t wholegraph_increment) { - // Ensures the generator is not in capturing mode. - at::cuda::assertNotCapturing( - "Cannot prepare for replay during capturing stage."); - if (wholegraph_increment) { - seed_extragraph_.fill_(static_cast(seed_)); - offset_extragraph_.fill_(static_cast(philox_offset_per_thread_)); - // Applies the total increment achieved during previous captures to update the - // offset. - increase(wholegraph_increment); +void CUDAGeneratorState::replay_prologue(CaptureId_t capture_id, uint64_t wholegraph_increment) { + if (wholegraph_increment == 0) { + return; } + + auto* capture_state = get_capture_state(capture_id); + TORCH_INTERNAL_ASSERT(capture_state != nullptr, + "replay_prologue called but no capture state found for this capture_id"); + capture_state->setup_for_replay(seed_, philox_offset_per_thread_); + philox_offset_per_thread_ += wholegraph_increment; } /** @@ -271,7 +264,8 @@ CUDAGeneratorImpl::CUDAGeneratorImpl( * See Note [Acquire lock when using random generators] */ void CUDAGeneratorImpl::set_current_seed(uint64_t seed) { - if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) { + auto capture_id = at::cuda::currentStreamCaptureId(); + if (C10_LIKELY(!capture_id.has_value())) { state_->seed_ = seed; state_->philox_offset_per_thread_ = 0; no_reset_rnn_state_.clear(); @@ -410,10 +404,14 @@ void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) { // set_philox_offset_per_thread instead of set_offset will cause the // cudnn RNN rng state to become stale. TORCH_CHECK(offset % 4 == 0, "offset must be a multiple of 4"); - if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) { + auto capture_id = at::cuda::currentStreamCaptureId(); + if (C10_LIKELY(!capture_id.has_value())) { state_->philox_offset_per_thread_ = offset; } else { - state_->offset_intragraph_ = offset; + auto* capture_state = state_->get_capture_state(capture_id.value()); + TORCH_CHECK(capture_state != nullptr, + "Generator not registered with the capturing graph."); + capture_state->offset_intragraph_ = offset; } } @@ -421,26 +419,19 @@ void CUDAGeneratorImpl::set_philox_offset_per_thread(uint64_t offset) { * Gets the current philox_offset_per_thread_ of CUDAGeneratorImpl. */ uint64_t CUDAGeneratorImpl::philox_offset_per_thread() const { - if (C10_LIKELY(at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None)) { + auto capture_id = at::cuda::currentStreamCaptureId(); + if (C10_LIKELY(!capture_id.has_value())) { return state_->philox_offset_per_thread_; } else { - return state_->offset_intragraph_; + auto* capture_state = state_->get_capture_state(capture_id.value()); + TORCH_CHECK(capture_state != nullptr, + "Generator not registered with the capturing graph."); + return capture_state->offset_intragraph_; } } -/** - * Registers this state to a CUDA graph to manage within the graph. - */ void CUDAGeneratorImpl::register_graph(cuda::CUDAGraph* graph) { graph->register_generator_state(state_); - state_->register_graph(graph); -} - -/** - * Unregisters a CUDA graph from the RNG state. - */ -void CUDAGeneratorImpl::unregister_graph(cuda::CUDAGraph* graph) { - state_->unregister_graph(graph); } /** @@ -465,12 +456,20 @@ void CUDAGeneratorImpl::unregister_graph(cuda::CUDAGraph* graph) { * See Note [Acquire lock when using random generators] */ PhiloxCudaState CUDAGeneratorImpl::philox_cuda_state(uint64_t increment) { - if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) { - uint64_t offset = state_->offset_intragraph_; + auto capture_id = at::cuda::currentStreamCaptureId(); + if (capture_id.has_value()) { + auto* capture_state = state_->get_capture_state(capture_id.value()); + TORCH_CHECK(capture_state != nullptr, + "RNG op during graph capture but generator is not registered with " + "the capturing graph. Call graph.register_generator_state() before " + "capture_begin()."); + + uint64_t offset = capture_state->offset_intragraph_; state_->increase(increment); + return PhiloxCudaState( - state_->seed_extragraph_.data_ptr(), - state_->offset_extragraph_.data_ptr(), + capture_state->rng_state_seed_extragraph_.data_ptr(), + capture_state->rng_state_offset_extragraph_.data_ptr(), offset); } else { uint64_t offset = state_->philox_offset_per_thread_; diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.h b/aten/src/ATen/cuda/CUDAGeneratorImpl.h index d4ab49382e7ff..fe62fc53eef07 100644 --- a/aten/src/ATen/cuda/CUDAGeneratorImpl.h +++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.h @@ -4,15 +4,20 @@ #include #include #include +#include +#include #include #include -#include +#include + namespace at { namespace cuda { struct CUDAGraph; } +using CaptureId_t = c10::CaptureId_t; + /** * Note [CUDA Graph-safe RNG states] * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -93,32 +98,51 @@ struct CUDAGraph; * */ +/** + * Per-capture state for a generator. + * Each (generator, capture_id) pair gets its own CUDAGeneratorCaptureState. + * This holds the GPU tensors and offset tracking for a specific graph capture. + */ +struct CUDAGeneratorCaptureState : public c10::intrusive_ptr_target { + uint64_t offset_intragraph_{0}; + at::TensorBase rng_state_seed_extragraph_; + at::TensorBase rng_state_offset_extragraph_; + + CUDAGeneratorCaptureState() = default; + + bool is_initialized() const { return rng_state_seed_extragraph_.defined(); } + void initialize(uint64_t seed); + void increase(uint64_t increment); + uint64_t finalize(); + void setup_for_replay(uint64_t seed, uint64_t philox_offset); +}; + +/** + * Generator state that supports multiple concurrent graph captures. + * Each capture gets its own CUDAGeneratorCaptureState keyed by CaptureId_t. + */ struct CUDAGeneratorState : public c10::intrusive_ptr_target { uint64_t seed_; uint64_t philox_offset_per_thread_; - uint64_t offset_intragraph_; - bool capturing_{}; - std::unordered_set registered_graphs_; - at::TensorBase seed_extragraph_; - at::TensorBase offset_extragraph_; + + // Map from capture ID to per-capture state + ska::flat_hash_map> capture_states_; + mutable std::mutex capture_states_mutex_; CUDAGeneratorState( uint64_t seed = default_rng_seed_val, - uint64_t philox_offset_per_thread = 0, - uint64_t offset_intragraph = 0) + uint64_t philox_offset_per_thread = 0) : seed_(seed), - philox_offset_per_thread_(philox_offset_per_thread), - offset_intragraph_(offset_intragraph) {} + philox_offset_per_thread_(philox_offset_per_thread) {} void increase(uint64_t increment); - void register_graph(cuda::CUDAGraph* graph); - void unregister_graph(cuda::CUDAGraph* graph); + CUDAGeneratorCaptureState* get_capture_state(CaptureId_t capture_id); + void init_capture_state(CaptureId_t capture_id); + uint64_t capture_epilogue(CaptureId_t capture_id); + void replay_prologue(CaptureId_t capture_id, uint64_t wholegraph_increment); + void remove_capture_state(CaptureId_t capture_id); - void capture_prologue(); - // capture_epilogue returns the wholegraph_increment - uint64_t capture_epilogue(); - void replay_prologue(uint64_t wholegraph_increment); c10::intrusive_ptr clone(); }; @@ -147,7 +171,6 @@ struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl { uint64_t philox_offset_per_thread() const; void register_graph(cuda::CUDAGraph* graph); - void unregister_graph(cuda::CUDAGraph* graph); // Generates a PhiloxCudaState with a specified increment, and increment // current state diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index 5a0dd48cbf288..38bead02d2943 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -9,6 +10,7 @@ #include #include +#include namespace at::cuda { @@ -22,6 +24,17 @@ static bool _cuda_graphs_debug = false; static std::mutex _currently_capturing_graphs_mutex; static ska::flat_hash_map _currently_capturing_graphs; + +#if defined(USE_ROCM) +// Returns true when at least one CUDAGraph capture is currently active in this +// process. Uses the same mutex-protected capture map as capture lifecycle +// bookkeeping. +bool is_graph_capture_active() { + std::unique_lock lock(_currently_capturing_graphs_mutex); + return !_currently_capturing_graphs.empty(); +} +#endif // defined(USE_ROCM) + MempoolId_t graph_pool_handle() { // Sets just the second value, to distinguish it from MempoolId_ts created from // cudaStreamGetCaptureInfo id_s in capture_begin. @@ -67,6 +80,24 @@ void CUDAGraph::register_generator_state(const at::Generator& generator) { cuda_gen->register_graph(this); } + +template <> +std::function CUDAGraph::create_allocate_filter() const { + return [this](cudaStream_t stream) { + auto capture_id_opt = c10::cuda::captureIdMayInitCtx(stream); + return capture_id_opt.has_value() && capture_id_opt.value() == capture_id_; + }; +} + +template <> +std::function CUDAGraph::create_allocate_filter() const { + return [this](c10::Stream stream) { + cudaStream_t cuda_stream = CUDAStream(CUDAStream::UNCHECKED, stream); + auto capture_id_opt = c10::cuda::captureIdMayInitCtx(cuda_stream); + return capture_id_opt.has_value() && capture_id_opt.value() == capture_id_; + }; +} + void CUDAGraph::capture_begin(MempoolId_t pool/*={0,0}*/, cudaStreamCaptureMode capture_mode) { TORCH_CHECK(!has_graph_exec_, "This CUDAGraph instance already owns a captured graph. " @@ -79,11 +110,6 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*={0,0}*/, cudaStreamCaptureMode std::nullopt, cuda::detail::getDefaultCUDAGenerator()); gen->register_graph(this); - for (auto& [generator_state, wholegraph_increments] : - captured_generator_states_) { - generator_state->capture_prologue(); - } - auto stream = at::cuda::getCurrentCUDAStream(); TORCH_CHECK(stream != at::cuda::getDefaultCUDAStream(), @@ -94,6 +120,16 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*={0,0}*/, cudaStreamCaptureMode capture_stream_ = stream; capture_dev_ = c10::cuda::current_device(); +#if defined(USE_ROCM) + // hipBLASLt handles are per-(device, stream) on ROCm and lazily created. + // Ensure the handle for the intended capture stream exists before + // capture begins, because hipblasLtCreate performs internal allocations + // that are not allowed once stream capture is active. + if (at::globalContext().blasPreferredBackend() == at::BlasBackend::Cublaslt) { + (void)at::cuda::getCurrentCUDABlasLtHandle(); + } +#endif + if (pool.first != 0 || pool.second != 0) { // Either value being nonzero means the user supplied a pool to share. // But only one should be nonzero. @@ -111,27 +147,29 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*={0,0}*/, cudaStreamCaptureMode // Addendum: beginAllocateStreamToPool is now called before cudaStreamBeginCapture to prevent an // autograd thread's free() call triggering an invalid cudaEventRecord in the caching allocator // due to the capture status being updated _after_ a capture had already started. - c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, create_allocate_filter()); - - auto filter = create_allocate_filter(); + c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, create_allocate_filter()); - at::getHostAllocator(at::kCUDA)->begin_allocate_to_pool(mempool_id_, [filter](c10::Stream stream) { - return filter(CUDAStream(CUDAStream::UNCHECKED, stream)); - }); + at::getHostAllocator(at::kCUDA)->begin_allocate_to_pool(mempool_id_, create_allocate_filter()); // cudaStreamCaptureModeGlobal is the most conservative option to // prevent potentially unsafe CUDA API calls during capture. See // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, capture_mode)); - cudaStreamCaptureStatus status{}; - AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &capture_id_)); - TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive); + auto capture_id_opt = c10::cuda::captureIdMayInitCtx(stream); + TORCH_INTERNAL_ASSERT(capture_id_opt.has_value(), + "Stream should be actively capturing after cudaStreamBeginCapture"); + capture_id_ = capture_id_opt.value(); { - std::unique_lock lock(_currently_capturing_graphs_mutex); + std::lock_guard lock(_currently_capturing_graphs_mutex); _currently_capturing_graphs.emplace(capture_id_, this); } + + for (auto& [generator_state, wholegraph_increment] : + captured_generator_states_) { + generator_state->init_capture_state(capture_id_); + } } void CUDAGraph::capture_end() { @@ -140,8 +178,10 @@ void CUDAGraph::capture_end() { TORCH_CHECK(stream.stream() == capture_stream_.stream(), "Capture must end on the same stream it began on."); - AT_CUDA_CHECK(cudaStreamEndCapture(capture_stream_, &graph_)); - + // Capture is over once cudaStreamEndCapture returns (success or failure). + // Clear bookkeeping before propagating the return status so watchdog-side + // checks cannot observe stale "capture active" state on error paths. + cudaError_t endCaptureErr = cudaStreamEndCapture(capture_stream_, &graph_); { std::unique_lock lock(_currently_capturing_graphs_mutex); TORCH_CHECK( @@ -152,12 +192,13 @@ void CUDAGraph::capture_end() { c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_); at::getHostAllocator(at::kCUDA)->end_allocate_to_pool(mempool_id_); + AT_CUDA_CHECK(endCaptureErr); TORCH_CHECK(graph_ != nullptr, "Invalid capture."); - for (auto& [generator_state, wholegraph_increments] : + for (auto& [generator_state, wholegraph_increment] : captured_generator_states_) { - wholegraph_increments = generator_state->capture_epilogue(); + wholegraph_increment = generator_state->capture_epilogue(capture_id_); } size_t numCUDAGraphNodes = 0; @@ -219,9 +260,9 @@ void CUDAGraph::replay() { c10::OptionalDeviceGuard device_guard{capture_stream_.device()}; - for (auto& [generator_state, wholegraph_increments] : + for (auto& [generator_state, wholegraph_increment] : captured_generator_states_) { - generator_state->replay_prologue(wholegraph_increments); + generator_state->replay_prologue(capture_id_, wholegraph_increment); } // graph_exec_ may be replayed in any stream. AT_CUDA_CHECK(cudaGraphLaunch(graph_exec_, at::cuda::getCurrentCUDAStream())); @@ -280,6 +321,20 @@ void CUDAGraph::reset() { // and the allocator could end up in all kinds of weird states depending where failure occurred. // If the user catches the failure exception in a script, or is running in REPL or (god forbid) // a Jupyter notebook, I don't see an easy way for reset() to gracefully fix all such possible error states. + + if (capture_id_ != 0) { + for (auto& [generator_state, wholegraph_increment] : captured_generator_states_) { + generator_state->remove_capture_state(capture_id_); + } + } + captured_generator_states_.clear(); + + if (capture_id_ != 0) { + std::lock_guard lock(_currently_capturing_graphs_mutex); + _currently_capturing_graphs.erase(capture_id_); + capture_id_ = 0; + } + if (capture_ended_) { // Clean up cuBLAS workspaces allocated on the capture stream, otherwise live allocations prevent // private pool cleanup @@ -308,10 +363,6 @@ MempoolId_t CUDAGraph::pool() { } CUDAGraph::~CUDAGraph() { - for (auto& [generator_state, wholegraph_increments] : - captured_generator_states_) { - generator_state->unregister_graph(this); - } reset(); // There are recent HIP changes where hipGraphExecDestroy doesn't immediately free memory. @@ -329,17 +380,14 @@ CUDAGraph::~CUDAGraph() { CUDAGraph* CUDAGraph::get_currently_capturing_graph() { std::unique_lock lock(_currently_capturing_graphs_mutex); - cudaStreamCaptureStatus status{}; - CaptureId_t current_capture_id = 0; - auto stream = at::cuda::getCurrentCUDAStream(); - AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, ¤t_capture_id)); + auto capture_id_opt = c10::cuda::currentStreamCaptureIdMayInitCtx(); TORCH_CHECK( - status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive, + capture_id_opt.has_value(), "The current stream is not currently capturing."); TORCH_CHECK( - _currently_capturing_graphs.count(current_capture_id), + _currently_capturing_graphs.count(capture_id_opt.value()), "get_currently_capturing_graph() can be used only between capture_begin() and capture_end(). Did you use a stream without making it depend upon the original stream used for capture?"); - return _currently_capturing_graphs.at(current_capture_id); + return _currently_capturing_graphs.at(capture_id_opt.value()); } void CUDAGraph::begin_capture_to_if_node( @@ -425,13 +473,6 @@ getCurrentCUDAStream(), &cond_node, nullptr, 1, cudaStreamSetCaptureDependencies CUDAStream child_stream = getStreamFromPool(); conditional_graph_capture_ids_.push(0); - conditional_rng_snapshots_.emplace(); - auto& conditional_rng_snapshot = conditional_rng_snapshots_.top(); - for (auto& [generator_state, wholegraph_increments] : - captured_generator_states_) { - conditional_rng_snapshot.emplace( - generator_state, generator_state->offset_intragraph_); - } c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_); at::getHostAllocator(at::kCUDA)->end_allocate_to_pool(mempool_id_); @@ -445,10 +486,10 @@ getCurrentCUDAStream(), &cond_node, nullptr, 1, cudaStreamSetCaptureDependencies AT_CUDA_CHECK(cudaStreamBeginCaptureToGraph( child_stream, if_node_child_graph, nullptr, nullptr, 0, capture_mode_)); - AT_CUDA_CHECK(cudaStreamGetCaptureInfo( - child_stream, &status, &conditional_graph_capture_ids_.top())); - TORCH_INTERNAL_ASSERT( - status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive); + auto child_capture_id_opt = c10::cuda::captureIdMayInitCtx(child_stream); + TORCH_INTERNAL_ASSERT(child_capture_id_opt.has_value(), + "Child stream should be actively capturing after cudaStreamBeginCaptureToGraph"); + conditional_graph_capture_ids_.top() = child_capture_id_opt.value(); conditional_node_streams_.emplace(child_stream); @@ -469,50 +510,38 @@ getCurrentCUDAStream(), &cond_node, nullptr, 1, cudaStreamSetCaptureDependencies void CUDAGraph::end_capture_to_conditional_node() { #if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040) TORCH_INTERNAL_ASSERT( - !conditional_rng_snapshots_.empty(), - "Missing RNG snapshot for conditional node capture."); + !conditional_graph_capture_ids_.empty(), + "Missing capture ID for conditional node."); + CaptureId_t child_capture_id = conditional_graph_capture_ids_.top(); bool rng_or_generators_changed = false; - auto& conditional_rng_snapshot = conditional_rng_snapshots_.top(); - if (conditional_rng_snapshot.size() != captured_generator_states_.size()) { - rng_or_generators_changed = true; - } else { - for (const auto& [generator_state, offset_intragraph_before_capture] : - conditional_rng_snapshot) { - const auto generator_it = captured_generator_states_.find(generator_state); - if (generator_it == captured_generator_states_.end() || - generator_state->offset_intragraph_ != - offset_intragraph_before_capture) { - rng_or_generators_changed = true; - break; - } + for (const auto& [generator_state, wholegraph_increment] : + captured_generator_states_) { + if (generator_state->get_capture_state(child_capture_id) != nullptr) { + rng_or_generators_changed = true; + break; } } { std::unique_lock lock(_currently_capturing_graphs_mutex); - CaptureId_t capture_id = conditional_graph_capture_ids_.top(); TORCH_CHECK( - _currently_capturing_graphs.count(capture_id), + _currently_capturing_graphs.count(child_capture_id), "capture_end() called before capture_begin()."); - _currently_capturing_graphs.erase(capture_id); + _currently_capturing_graphs.erase(child_capture_id); } CUDAStream stream = conditional_node_streams_.top().current_stream(); AT_CUDA_CHECK(cudaStreamEndCapture(stream.stream(), nullptr)); conditional_node_streams_.pop(); conditional_graph_capture_ids_.pop(); - conditional_rng_snapshots_.pop(); c10::cuda::CUDACachingAllocator::endAllocateToPool(capture_dev_, mempool_id_); at::getHostAllocator(at::kCUDA)->end_allocate_to_pool(mempool_id_); if (conditional_graph_capture_ids_.empty()) { c10::cuda::CUDACachingAllocator::beginAllocateToPool( - capture_dev_, mempool_id_, create_allocate_filter()); - auto filter = create_allocate_filter(); - at::getHostAllocator(at::kCUDA)->begin_allocate_to_pool(mempool_id_, [filter](c10::Stream stream) { - return filter(CUDAStream(CUDAStream::UNCHECKED, stream)); - }); + capture_dev_, mempool_id_, create_allocate_filter()); + at::getHostAllocator(at::kCUDA)->begin_allocate_to_pool(mempool_id_, create_allocate_filter()); } else { c10::cuda::CUDACachingAllocator::beginAllocateToPool( capture_dev_, mempool_id_, create_child_allocate_filter()); @@ -521,7 +550,6 @@ void CUDAGraph::end_capture_to_conditional_node() { return filter(CUDAStream(CUDAStream::UNCHECKED, stream)); }); } - constexpr const char* rng_with_conditional_nodes_error = "RNG within data-dependent conditional nodes is not supported yet."; TORCH_CHECK(!rng_or_generators_changed, rng_with_conditional_nodes_error); @@ -533,22 +561,11 @@ void CUDAGraph::end_capture_to_conditional_node() { #endif } -std::function CUDAGraph::create_allocate_filter() { - return [this](cudaStream_t stream) { - cudaStreamCaptureStatus status{}; - CaptureId_t stream_capture_id = 0; - AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id)); - return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == capture_id_; - }; -} - std::function CUDAGraph::create_child_allocate_filter() { #if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040) return [¤t_capture_id = conditional_graph_capture_ids_.top()](cudaStream_t stream) { - cudaStreamCaptureStatus status{}; - CaptureId_t stream_capture_id{}; - AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id)); - return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == current_capture_id; + auto capture_id_opt = c10::cuda::captureIdMayInitCtx(stream); + return capture_id_opt.has_value() && capture_id_opt.value() == current_capture_id; }; #else // !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040) AT_ERROR( diff --git a/aten/src/ATen/cuda/CUDAGraph.h b/aten/src/ATen/cuda/CUDAGraph.h index 6c93be8ed8577..9a971e82bb787 100644 --- a/aten/src/ATen/cuda/CUDAGraph.h +++ b/aten/src/ATen/cuda/CUDAGraph.h @@ -9,6 +9,7 @@ #include #include +#include #include #if defined(USE_ROCM) || !(defined(CUDA_VERSION) && CUDA_VERSION >= 12040) @@ -30,6 +31,16 @@ namespace cuda { // to CUDAGraph::capture_begin TORCH_CUDA_CPP_API MempoolId_t graph_pool_handle(); +// Returns true if any CUDAGraph capture is currently active in this process. +// Used by ProcessGroupNCCL's ROCm watchdog workaround to avoid calling +// hipEventQuery during active capture on HIP runtimes without the +// event-query capture-mode fix (https://github.com/ROCm/clr/pull/3176). +// Not needed on CUDA/NVIDIA where cross-thread event query does not have this +// restriction. +#if defined(USE_ROCM) +TORCH_CUDA_CPP_API bool is_graph_capture_active(); +#endif // defined(USE_ROCM) + struct TORCH_CUDA_CPP_API CUDAGraph { CUDAGraph(bool keep_graph=false); ~CUDAGraph(); @@ -54,7 +65,6 @@ struct TORCH_CUDA_CPP_API CUDAGraph { CUDAGraph(CUDAGraph&& other) = delete; CUDAGraph& operator=(CUDAGraph&& other) = delete; - // See Note [Explicit Registration of Generators to the CUDA Graph] void register_generator_state(c10::intrusive_ptr state); void register_generator_state(const at::Generator& generator); void capture_begin( @@ -78,7 +88,8 @@ struct TORCH_CUDA_CPP_API CUDAGraph { const Tensor& scalar_cuda_pred_tensor); private: - std::function create_allocate_filter(); + template + std::function create_allocate_filter() const; std::function create_child_allocate_filter(); protected: @@ -136,11 +147,13 @@ struct TORCH_CUDA_CPP_API CUDAGraph { #if !defined(USE_ROCM) && (defined(CUDA_VERSION) && CUDA_VERSION >= 12040) std::stack conditional_node_streams_; std::stack conditional_graph_capture_ids_; - std::stack< - ska::flat_hash_map, uint64_t>> - conditional_rng_snapshots_; #endif // !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12040 }; +template <> +std::function CUDAGraph::create_allocate_filter() const; +template <> +std::function CUDAGraph::create_allocate_filter() const; + } // namespace cuda } // namespace at diff --git a/aten/src/ATen/cuda/CUDAGraphsUtils.cuh b/aten/src/ATen/cuda/CUDAGraphsUtils.cuh index d3a5b306eeea4..64ab0140793bb 100644 --- a/aten/src/ATen/cuda/CUDAGraphsUtils.cuh +++ b/aten/src/ATen/cuda/CUDAGraphsUtils.cuh @@ -19,12 +19,17 @@ using CaptureStatus = c10::cuda::CaptureStatus; // Use this version where you don't want to create a CUDA context if none exists. inline CaptureStatus currentStreamCaptureStatus() { - // don't create a context if we don't have to if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) { return c10::cuda::currentStreamCaptureStatusMayInitCtx(); - } else { - return CaptureStatus::None; } + return CaptureStatus::None; +} + +inline std::optional currentStreamCaptureId() { + if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) { + return c10::cuda::currentStreamCaptureIdMayInitCtx(); + } + return std::nullopt; } inline void assertNotCapturing(const std::string& attempt) { diff --git a/aten/src/ATen/cuda/CUDAGreenContext.cpp b/aten/src/ATen/cuda/CUDAGreenContext.cpp index 7a0eb81bfe3c2..07861d7d78a11 100644 --- a/aten/src/ATen/cuda/CUDAGreenContext.cpp +++ b/aten/src/ATen/cuda/CUDAGreenContext.cpp @@ -11,10 +11,27 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-private-field") #endif +#if defined(CUDA_VERSION) && CUDA_VERSION >= 13010 && HAS_CUDA_GREEN_CONTEXT() +#define HAS_CUDA_WORKQUEUE_SUPPORT() 1 +#else +#define HAS_CUDA_WORKQUEUE_SUPPORT() 0 +#endif + namespace at::cuda { -GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { +GreenContext::GreenContext( + uint32_t device_id, + std::optional num_sms, + std::optional workqueue_scope, + std::optional workqueue_concurrency_limit) { #if HAS_CUDA_GREEN_CONTEXT() + TORCH_CHECK( + num_sms.has_value() || workqueue_scope.has_value(), + "At least one of num_sms or workqueue_scope must be specified"); + TORCH_CHECK( + !workqueue_concurrency_limit.has_value() || workqueue_scope.has_value(), + "workqueue_concurrency_limit requires workqueue_scope to be set"); + int driver_version; C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); TORCH_CHECK( @@ -29,46 +46,73 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { cudaFree(nullptr); } - CUdevice device; + CUdevice device; device_id_ = device_id; C10_CUDA_DRIVER_CHECK( c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id)); - // Get device resources - CUdevResource device_resource; - C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( - device, &device_resource, CU_DEV_RESOURCE_TYPE_SM)); + std::vector resources; - TORCH_CHECK( - num_sms > 0 && num_sms <= device_resource.sm.smCount, - "Invalid number of SMs requested for green context: ", - num_sms, - " (device has ", - device_resource.sm.smCount, - " SMs)"); - - // Split resources - std::vector result(1); - auto result_data = result.data(); - unsigned int nb_groups = 1; - CUdevResource remaining; + // --- SM resource --- + if (num_sms.has_value()) { + CUdevResource sm_resource; + C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( + device, &sm_resource, CU_DEV_RESOURCE_TYPE_SM)); - C10_CUDA_DRIVER_CHECK( - c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( - result_data, - &nb_groups, - &device_resource, - &remaining, - 0, // default flags - num_sms)); + TORCH_CHECK( + *num_sms > 0 && *num_sms <= sm_resource.sm.smCount, + "Invalid number of SMs requested for green context: ", + *num_sms, + " (device has ", + sm_resource.sm.smCount, + " SMs)"); + + // Split resources + std::vector split_result(1); + unsigned int nb_groups = 1; + CUdevResource remaining; + + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuDevSmResourceSplitByCount_( + split_result.data(), + &nb_groups, + &sm_resource, + &remaining, + 0, // default flags + *num_sms)); + TORCH_CHECK(nb_groups == 1, "Failed to create single SM resource group"); + resources.push_back(split_result[0]); + } - TORCH_CHECK(nb_groups == 1, "Failed to create single resource group"); + // --- Workqueue config resource --- + if (workqueue_scope.has_value()) { +#if HAS_CUDA_WORKQUEUE_SUPPORT() + TORCH_CHECK( + driver_version >= 13010, "cuda driver too old to use workqueue configuration!"); + CUdevResource wq_resource{}; + C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( + device, &wq_resource, CU_DEV_RESOURCE_TYPE_WORKQUEUE_CONFIG)); + + wq_resource.wqConfig.sharingScope = + static_cast(*workqueue_scope); + if (workqueue_concurrency_limit.has_value()) { + wq_resource.wqConfig.wqConcurrencyLimit = *workqueue_concurrency_limit; + } + resources.push_back(wq_resource); +#else + TORCH_CHECK( + false, + "Workqueue configuration for green contexts requires CUDA 13.1+!"); +#endif + } // Generate resource descriptor CUdevResourceDesc desc; C10_CUDA_DRIVER_CHECK( c10::cuda::DriverAPI::get()->cuDevResourceGenerateDesc_( - &desc, result_data, 1)); + &desc, + resources.data(), + static_cast(resources.size()))); // Create green context // CU_GREEN_CTX_DEFAULT_STREAM is required per docs: @@ -83,20 +127,45 @@ GreenContext::GreenContext(uint32_t device_id, uint32_t num_sms) { #else TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); #endif - } +} - std::unique_ptr GreenContext::create( - uint32_t num_sms, - std::optional device_id) { +std::unique_ptr GreenContext::create( + std::optional device_id, + std::optional num_sms, + std::optional workqueue_scope, + std::optional workqueue_concurrency_limit) { #if HAS_CUDA_GREEN_CONTEXT() - if (!device_id.has_value()) { - device_id = at::cuda::current_device(); - } - return std::unique_ptr(new GreenContext(device_id.value(), num_sms)); + if (!device_id.has_value()) { + device_id = at::cuda::current_device(); + } + return std::unique_ptr(new GreenContext( + device_id.value(), num_sms, workqueue_scope, workqueue_concurrency_limit)); #else - TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); + TORCH_CHECK(false, "Green Context is only supported on CUDA 12.8+!"); #endif +} + +uint32_t GreenContext::max_workqueue_concurrency( + std::optional device_id) { +#if HAS_CUDA_WORKQUEUE_SUPPORT() + int driver_version; + C10_CUDA_CHECK(cudaDriverGetVersion(&driver_version)); + TORCH_CHECK( + driver_version >= 13010, "cuda driver too old to use workqueue configuration!"); + if (!device_id.has_value()) { + device_id = at::cuda::current_device(); } + CUdevice device; + C10_CUDA_DRIVER_CHECK( + c10::cuda::DriverAPI::get()->cuDeviceGet_(&device, device_id.value())); + CUdevResource wq_resource; + C10_CUDA_DRIVER_CHECK(c10::cuda::DriverAPI::get()->cuDeviceGetDevResource_( + device, &wq_resource, CU_DEV_RESOURCE_TYPE_WORKQUEUE_CONFIG)); + return wq_resource.wqConfig.wqConcurrencyLimit; +#else + TORCH_CHECK(false, "Workqueue configuration requires CUDA 13.1+!"); +#endif +} // Implement move operations #if HAS_CUDA_GREEN_CONTEXT() diff --git a/aten/src/ATen/cuda/CUDAGreenContext.h b/aten/src/ATen/cuda/CUDAGreenContext.h index b1eda3f48a45a..a93eedfaeedfa 100644 --- a/aten/src/ATen/cuda/CUDAGreenContext.h +++ b/aten/src/ATen/cuda/CUDAGreenContext.h @@ -11,12 +11,24 @@ namespace { constexpr int kStreamPerGreenContextPool = 32; } +// Workqueue sharing scope for green contexts. +// Values match the CUDA driver API's CUdevWorkqueueConfigScope enum. +enum class WorkqueueScope : int32_t { + DeviceCtx = 0, + Balanced = 1, +}; + class TORCH_CUDA_CPP_API GreenContext { public: - // Green context creation static std::unique_ptr create( - uint32_t num_sms, - std::optional device_id); + std::optional device_id, + std::optional num_sms, + std::optional workqueue_scope = std::nullopt, + std::optional workqueue_concurrency_limit = std::nullopt); + + static uint32_t max_workqueue_concurrency( + std::optional device_id = std::nullopt); + ~GreenContext() noexcept; // Delete copy constructor and assignment @@ -31,7 +43,12 @@ class TORCH_CUDA_CPP_API GreenContext { CUDAStream Stream(); private: - GreenContext(uint32_t device_id, uint32_t num_sms); + GreenContext( + uint32_t device_id, + std::optional num_sms, + std::optional workqueue_scope, + std::optional workqueue_concurrency_limit); + // Implement move operations GreenContext(GreenContext&& other) noexcept; GreenContext& operator=(GreenContext&& other) noexcept; diff --git a/aten/src/ATen/cuda/CachingHostAllocator.cpp b/aten/src/ATen/cuda/CachingHostAllocator.cpp index dcff81333cb2f..c376eb3a0a4d5 100644 --- a/aten/src/ATen/cuda/CachingHostAllocator.cpp +++ b/aten/src/ATen/cuda/CachingHostAllocator.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include #include #include @@ -14,6 +16,25 @@ using Block = HostBlock; struct CUDACachingHostAllocatorImpl : public CachingHostAllocatorImpl { + void free(void* ctx) override { + using Base = CachingHostAllocatorImpl; + try { + Base::free(ctx); + } catch (...) { + if (!c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig:: + pinned_free_catch_all()) { + TORCH_WARN("Exception in pinned allocator free(), rethrowing"); + throw; + } + // pinned_free_catch_all is enabled: suppress the exception to prevent + // it from escaping through ~StorageImpl() (implicitly noexcept), which + // would cause std::terminate. Allows graceful shutdown to proceed. + STATIC_GAUGE(pytorch.CUDACachingHostAllocator.free_fail_catch_all) + .record(1); + TORCH_WARN("Suppressed exception in pinned allocator free()"); + } + } + private: ska::flat_hash_map use_host_register; @@ -245,9 +266,7 @@ struct CUDACachingHostAllocatorImpl } bool stream_is_capturing(CUDAStream s) const override { - cudaStreamCaptureStatus status{cudaStreamCaptureStatusNone}; - C10_CUDA_CHECK(cudaStreamIsCapturing(s, &status)); - return status != cudaStreamCaptureStatusNone; + return s.is_capturing(); } }; diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index f3df32749fa6b..dc9f19a670ab5 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -4,12 +4,18 @@ #include +#include #include #include #include +#include #include #include +#if defined(USE_ROCM) +#include +#endif + /** * Note [hipblaslt handles] * ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -24,12 +30,22 @@ * For CUDA builds, getCurrentCUDABlasLtHandle will alias for getCurrentCUDABlasHandle, * whereas for ROCm builds, it is a distinct function. * + * Additionally, hipblaslt cannot share a single handle across multiple streams. + * On ROCm, getCurrentCUDABlasLtHandle returns a handle unique to each (device, stream) + * pair, rather than just per-device like the cublas handle pool. + * * The workspace pools are separate for ROCm. On CUDA, the env var * TORCH_CUBLASLT_UNIFIED_WORKSPACE can be used to opt-in to unifying the workspace pools. */ namespace at::cuda { +namespace { +// -1 means no override; use env var / default +std::atomic cublas_workspace_override{-1}; +std::atomic cublaslt_workspace_override{-1}; +} // namespace + namespace { #if defined(USE_ROCM) @@ -53,8 +69,6 @@ void destroyCublasLtHandle(cublasLtHandle_t handle) { using CuBlasLtPoolType = DeviceThreadHandlePool; // ugly hack until hipblasSetWorkspace exists -#include - static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) { switch(error) { case rocblas_status_size_unchanged: @@ -135,24 +149,16 @@ void clearCublasWorkspacesForStream(cudaStream_t stream) { { auto& workspace = cublas_handle_stream_to_workspace(); std::unique_lock lock(workspace.mutex); - for (auto it = workspace.map.begin(); it != workspace.map.end(); ) { - if (std::get<1>(it->first) == stream_ptr) { - it = workspace.map.erase(it); - } else { - ++it; - } - } + std::erase_if(workspace.map, [stream_ptr](const auto& entry) { + return std::get<1>(entry.first) == stream_ptr; + }); } { auto& workspace = cublaslt_handle_stream_to_workspace(); std::unique_lock lock(workspace.mutex); - for (auto it = workspace.map.begin(); it != workspace.map.end(); ) { - if (std::get<1>(it->first) == stream_ptr) { - it = workspace.map.erase(it); - } else { - ++it; - } - } + std::erase_if(workspace.map, [stream_ptr](const auto& entry) { + return std::get<1>(entry.first) == stream_ptr; + }); } } @@ -247,11 +253,43 @@ size_t parseCUDABlasLtWorkspaceSize() { return workspace_size * 1024; } +size_t getChosenWorkspaceSize() { + int64_t ov = cublas_workspace_override.load(std::memory_order_relaxed); + if (ov >= 0) { + return static_cast(ov); + } + static size_t pool_size = parseChosenWorkspaceSize(); + return pool_size; +} + +void setChosenWorkspaceSize(size_t size) { + cublas_workspace_override.store(static_cast(size), std::memory_order_relaxed); +} + +void setCUDABlasLtWorkspaceSize(size_t size) { + cublaslt_workspace_override.store(static_cast(size), std::memory_order_relaxed); +} + +void resetChosenWorkspaceSize() { + cublas_workspace_override.store(-1, std::memory_order_relaxed); +} + +void resetCUDABlasLtWorkspaceSize() { + cublaslt_workspace_override.store(-1, std::memory_order_relaxed); +} + size_t getCUDABlasLtWorkspaceSize() { - size_t pool_size = parseCUDABlasLtWorkspaceSize(); + int64_t ov = cublaslt_workspace_override.load(std::memory_order_relaxed); + const size_t pool_size = [&] { + if (ov >= 0) { + return static_cast(ov); + } + static size_t parsed_pool_size = parseCUDABlasLtWorkspaceSize(); + return parsed_pool_size; + }(); #ifndef USE_ROCM if (unified_cublas_and_lt_workspaces()) { - auto cublasWorkspaceSize = parseChosenWorkspaceSize(); + size_t cublasWorkspaceSize = getChosenWorkspaceSize(); if (cublasWorkspaceSize < pool_size) { TORCH_WARN_ONCE("Requested unified CUBLASLT workspace size of ", pool_size, " bytes exceeds CUBLAS workspace size of ", cublasWorkspaceSize, @@ -259,7 +297,7 @@ size_t getCUDABlasLtWorkspaceSize() { " via CUBLAS_WORKSPACE_CONFIG or decrease requested" " CUBLASLT_WORKSPACE_SIZE. Otherwise CUBLASLT workspace" " size will be limited to the CUBLAS workspace size."); - pool_size = cublasWorkspaceSize; + return cublasWorkspaceSize; } } #endif @@ -267,7 +305,7 @@ size_t getCUDABlasLtWorkspaceSize() { } at::DataPtr getNewWorkspace() { - return c10::cuda::CUDACachingAllocator::get()->allocate(parseChosenWorkspaceSize()); + return c10::cuda::CUDACachingAllocator::get()->allocate(getChosenWorkspaceSize()); } at::DataPtr getNewCUDABlasLtWorkspace() { @@ -280,15 +318,15 @@ void setWorkspaceForHandle(cublasHandle_t handle, c10::cuda::CUDAStream stream) auto& workspace = cublas_handle_stream_to_workspace(); - size_t workspace_size = parseChosenWorkspaceSize(); + size_t workspace_size = getChosenWorkspaceSize(); - // Fast path: check if workspace already exists + // Fast path: check if workspace already exists and is large enough { std::shared_lock lock(workspace.mutex); auto workspace_it = workspace.map.find(key); - if (workspace_it != workspace.map.end()) { + if (workspace_it != workspace.map.end() && workspace_it->second.second >= workspace_size) { TORCH_CUDABLAS_CHECK(cublasSetWorkspace( - handle, workspace_it->second.get(), workspace_size)); + handle, workspace_it->second.first.get(), workspace_size)); return; } } @@ -296,28 +334,39 @@ void setWorkspaceForHandle(cublasHandle_t handle, c10::cuda::CUDAStream stream) // Slow path: allocate workspace outside the lock auto new_workspace = getNewWorkspace(); - // Insert with lock (double-check in case another thread inserted while we - // were allocating) + // Insert with lock, replacing any undersized entry { std::unique_lock lock(workspace.mutex); - auto workspace_it = workspace.map.try_emplace(key, std::move(new_workspace)).first; + workspace.map.insert_or_assign(key, std::make_pair(std::move(new_workspace), workspace_size)); + auto workspace_it = workspace.map.find(key); TORCH_CUDABLAS_CHECK( - cublasSetWorkspace(handle, workspace_it->second.get(), workspace_size)); + cublasSetWorkspace(handle, workspace_it->second.first.get(), workspace_size)); } } void* getCUDABlasLtWorkspace() { #ifndef USE_ROCM if (unified_cublas_and_lt_workspaces()) { - cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(/*setup=*/false); auto stream = c10::cuda::getCurrentCUDAStream(); cudaStream_t _stream = stream; auto key = std::make_tuple(static_cast(handle), static_cast(_stream)); auto& workspace = at::cuda::cublas_handle_stream_to_workspace(); - std::shared_lock lock(workspace.mutex); - auto workspace_it = workspace.map.find(key); - TORCH_INTERNAL_ASSERT(workspace_it != workspace.map.end()); - return workspace_it->second.mutable_get(); + { + std::shared_lock lock(workspace.mutex); + auto workspace_it = workspace.map.find(key); + if (workspace_it != workspace.map.end()) { + return workspace_it->second.first.mutable_get(); + } + } + // First use for this handle+stream pair — allocate and insert directly. + // No need to call cublasSetWorkspace; Lt passes workspace explicitly. + auto new_workspace = getNewWorkspace(); + { + std::unique_lock lock(workspace.mutex); + auto workspace_it = workspace.map.try_emplace(key, std::make_pair(std::move(new_workspace), getChosenWorkspaceSize())).first; + return workspace_it->second.first.mutable_get(); + } } #endif cublasLtHandle_t handle = getCurrentCUDABlasLtHandle(); @@ -327,29 +376,30 @@ void* getCUDABlasLtWorkspace() { auto& workspace = cublaslt_handle_stream_to_workspace(); - // Fast path: check if workspace already exists + size_t workspace_size = getCUDABlasLtWorkspaceSize(); + + // Fast path: check if workspace already exists and is large enough { std::shared_lock lock(workspace.mutex); auto workspace_it = workspace.map.find(key); - if (workspace_it != workspace.map.end()) { - return workspace_it->second.mutable_get(); + if (workspace_it != workspace.map.end() && workspace_it->second.second >= workspace_size) { + return workspace_it->second.first.mutable_get(); } } // Slow path: allocate workspace outside the lock auto new_workspace = getNewCUDABlasLtWorkspace(); - // Insert with lock (double-check in case another thread inserted while we - // were allocating) + // Insert with lock, replacing any undersized entry { std::unique_lock lock(workspace.mutex); - auto workspace_it = - workspace.map.try_emplace(key, std::move(new_workspace)).first; - return workspace_it->second.mutable_get(); + workspace.map.insert_or_assign(key, std::make_pair(std::move(new_workspace), workspace_size)); + auto workspace_it = workspace.map.find(key); + return workspace_it->second.first.mutable_get(); } } -cublasHandle_t getCurrentCUDABlasHandle() { +cublasHandle_t getCurrentCUDABlasHandle(bool setup) { c10::DeviceIndex device = 0; AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); @@ -357,8 +407,6 @@ cublasHandle_t getCurrentCUDABlasHandle() { CUcontext pctx = nullptr; at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx); if (C10_UNLIKELY(!pctx)) { - // workaround for corner case where a primary context exists but is not - // the current context, seen in multithreaded use-cases TORCH_WARN_ONCE("Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context..."); at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(&pctx, device); at::globalContext().getNVRTC().cuCtxSetCurrent(pctx); @@ -381,6 +429,11 @@ cublasHandle_t getCurrentCUDABlasHandle() { pool->newPoolWindow()); auto handle = myPoolWindow->reserve(device); + + if (!setup) { + return handle; + } + auto stream = c10::cuda::getCurrentCUDAStream(); TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream)); // We explicitly set the cublas workspace even though CUDA 12.2+ fixed the @@ -434,10 +487,14 @@ cublasLtHandle_t getCurrentCUDABlasLtHandle() { thread_local std::unique_ptr myPoolWindow( pool->newPoolWindow()); - auto handle = myPoolWindow->reserve(device); + // hipblaslt cannot share a single handle across multiple streams, + // so reserve a handle unique to each (device, stream) pair. + auto stream = c10::cuda::getCurrentCUDAStream(); + cudaStream_t _stream = stream; + auto handle = myPoolWindow->reserve(device, static_cast(_stream)); return handle; #else - return reinterpret_cast(getCurrentCUDABlasHandle()); + return reinterpret_cast(getCurrentCUDABlasHandle(/*setup=*/false)); #endif } diff --git a/aten/src/ATen/cuda/StatelessPhilox4x32.cuh b/aten/src/ATen/cuda/StatelessPhilox4x32.cuh new file mode 100644 index 0000000000000..7d21af6b33480 --- /dev/null +++ b/aten/src/ATen/cuda/StatelessPhilox4x32.cuh @@ -0,0 +1,61 @@ +// Stateless Philox-4x32 PRNG implementation. +// +// Unlike PhiloxRNGEngine (PhiloxUtils.cuh), this is a pure function: given +// (seed, offset) it returns 4 pseudo-random uint32 values with no mutable +// state. This makes it suitable for use in stateless random APIs. +// +// The Philox-4x32 cipher operates on a 128-bit counter. The full counter +// is (offset_lo, offset_hi, subsequence_lo, subsequence_hi), but we fix +// subsequence=0 so that the entire 128-bit counter space is addressed by +// the 64-bit offset alone. This keeps the API simple and maintains +// cross-device consistency. For example, utilizing thread ID-based subsequence +// numbers and SM-based thread count causes different random values to +// be generated across GPU types. We avoid this situation by always setting +// subsequence=0. + +#pragma once + +#include + +namespace at::cuda { + +__device__ __forceinline__ uint2 mulhilo32(uint32_t a, uint32_t b) { + return {a * b, __umulhi(a, b)}; +} + +__device__ __forceinline__ uint4 philox_round(uint4 ctr, uint2 key) { + constexpr uint32_t kPhiloxSA = 0xD2511F53; + constexpr uint32_t kPhiloxSB = 0xCD9E8D57; + uint2 r0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 r1 = mulhilo32(kPhiloxSB, ctr.z); + return {r1.y ^ ctr.y ^ key.x, r1.x, r0.y ^ ctr.w ^ key.y, r0.x}; +} + +// Stateless Philox-4x32. Returns 4 pseudo-random uint32 values (128 bits) +// determined entirely by (seed, offset). Each unique offset produces a +// distinct 128-bit output. +template +__device__ __forceinline__ uint4 philox_4x32( + uint64_t seed, uint64_t offset) { + uint2 key = { + static_cast(seed), + static_cast(seed >> 32)}; + uint4 ctr = { + static_cast(offset), + static_cast(offset >> 32), + // restrict subsequence=0 + 0, 0}; + + constexpr uint32_t kPhilox10A = 0x9E3779B9; + constexpr uint32_t kPhilox10B = 0xBB67AE85; + + #pragma unroll + for (int i = 0; i < N_ROUNDS - 1; i++) { + ctr = philox_round(ctr, key); + key.x += kPhilox10A; + key.y += kPhilox10B; + } + return philox_round(ctr, key); +} + +} // namespace at::cuda diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 03a3a97525a43..5f81407b1ac03 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include @@ -86,9 +85,6 @@ struct _Initializer { // let's not if we don't need to!) void CUDAHooks::init() const { C10_LOG_API_USAGE_ONCE("aten.init.cuda"); - // Force the update to enable unit testing. This code get executed before unit tests - // have a chance to enable vitals. - at::vitals::VitalsAPI.setVital("CUDA", "used", "true", /* force = */ true); const auto num_devices = c10::cuda::device_count_ensure_non_zero(); c10::cuda::CUDACachingAllocator::init(num_devices); diff --git a/aten/src/ATen/cuda/detail/DeviceThreadHandles.h b/aten/src/ATen/cuda/detail/DeviceThreadHandles.h index 71a344d281d2a..9ec194fc878b6 100644 --- a/aten/src/ATen/cuda/detail/DeviceThreadHandles.h +++ b/aten/src/ATen/cuda/detail/DeviceThreadHandles.h @@ -15,6 +15,7 @@ #pragma once +#include #include #include #include @@ -114,9 +115,45 @@ struct DeviceThreadHandlePool : public std::enable_shared_from_this guard(parent->mutex); + + if(parent->available_handles[device].size() > 0) + { + my_stream_handles[key] = parent->available_handles[device].back(); + parent->available_handles[device].pop_back(); + } + else + { + parent->created_handles[device].emplace_back(true /*create*/); + my_stream_handles[key] = parent->created_handles[device].back().handle; + } + + return my_stream_handles[key]; + } +#endif + private: // Stores the per-device handles currently owned by this thread std::unordered_map my_handles; +#ifdef USE_ROCM + // Stores per-(device, stream) handles for ROCm, where hipblaslt + // requires a unique handle per stream. + std::map, Handle_t> my_stream_handles; +#endif std::weak_ptr weak_parent; @@ -134,6 +171,18 @@ struct DeviceThreadHandlePool : public std::enable_shared_from_thisavailable_handles[d_h.first].push_back(d_h.second); } +#ifdef USE_ROCM + if(!my_stream_handles.empty()) { + auto parent = weak_parent.lock(); + if (!parent) { + return; + } + + std::lock_guard guard(parent->mutex); + for(auto& [key, handle] : my_stream_handles) + parent->available_handles[key.first].push_back(handle); + } +#endif } }; diff --git a/aten/src/ATen/cuda/jiterator.cu b/aten/src/ATen/cuda/jiterator.cu index 0545c8354eda3..c75da3298ebf4 100644 --- a/aten/src/ATen/cuda/jiterator.cu +++ b/aten/src/ATen/cuda/jiterator.cu @@ -356,6 +356,9 @@ c10::SmallVector CompileAndLaunchKernel( at::native::jitted_gpu_kernel_dynamic(kernel_name, iter, code_string, extra_args, return_by_ref); c10::SmallVector outputs; + if (num_outputs > 0) { + outputs.reserve(num_outputs); + } for (int i = 0; i < num_outputs; ++i) { outputs.emplace_back(iter.output(i)); } diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index 29affa2d21ff1..29c15720f4a66 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -94,6 +94,16 @@ constexpr hipDataType HipDataTypeFor() { #endif } +template +constexpr hipblasComputeType_t HipBlasComputeTypeFor() { + return HIPBLAS_COMPUTE_32F; +} + +template <> +constexpr hipblasComputeType_t HipBlasComputeTypeFor() { + return HIPBLAS_COMPUTE_64F; +} + template int GetBatchFromParams(const GemmParams* params) { return 1; @@ -175,43 +185,43 @@ int GetStrideCFromParams(const ScaledGemmParams* params) { } template -float GetAlphaFromParams(const GemmParams* params) { +at::opmath_type GetAlphaFromParams(const GemmParams* params) { return params->alpha; } template -float GetAlphaFromParams(const GemmAndBiasParams* params) { +at::opmath_type GetAlphaFromParams(const GemmAndBiasParams* params) { return params->alpha; } template -float GetAlphaFromParams(const GemmStridedBatchedParams* params) { +at::opmath_type GetAlphaFromParams(const GemmStridedBatchedParams* params) { return params->alpha; } template -float GetAlphaFromParams(const ScaledGemmParams* params) { - return 1.0; +at::opmath_type GetAlphaFromParams(const ScaledGemmParams* params) { + return at::opmath_type{1.0}; } template -float GetBetaFromParams(const GemmParams* params) { +at::opmath_type GetBetaFromParams(const GemmParams* params) { return params->beta; } template -float GetBetaFromParams(const GemmAndBiasParams* params) { - return 0.0; +at::opmath_type GetBetaFromParams(const GemmAndBiasParams* params) { + return at::opmath_type{0.0}; } template -float GetBetaFromParams(const GemmStridedBatchedParams* params) { +at::opmath_type GetBetaFromParams(const GemmStridedBatchedParams* params) { return params->beta; } template -float GetBetaFromParams(const ScaledGemmParams* params) { - return 0.0; +at::opmath_type GetBetaFromParams(const ScaledGemmParams* params) { + return at::opmath_type{0.0}; } template @@ -467,8 +477,9 @@ class HipblasltGemmOp : public Callable { TORCH_CHECK(transa_outer == opa && transb_outer == opb, "trans mismatch, shouldn't happen"); - float alpha = GetAlphaFromParams(params); - float beta = GetBetaFromParams(params); + using opmath_t = at::opmath_type; + opmath_t alpha = GetAlphaFromParams(params); + opmath_t beta = GetBetaFromParams(params); hipblasLtMatrixLayout_t mat_a, mat_b, mat_c; if (opa == HIPBLAS_OP_N) { @@ -505,11 +516,14 @@ class HipblasltGemmOp : public Callable { mat_c, HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c, sizeof(stride_c))); } - hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; - if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) { - computeType = HIPBLAS_COMPUTE_32F_FAST_TF32; + hipblasComputeType_t computeType = HipBlasComputeTypeFor(); + if constexpr (std::is_same_v) { + if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) { + computeType = HIPBLAS_COMPUTE_32F_FAST_TF32; + } } - HipBlasLtMatmulDescriptor matmul(computeType, HIP_R_32F); + auto scale_type = HipDataTypeFor(); + HipBlasLtMatmulDescriptor matmul(computeType, scale_type); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa); matmul.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb); @@ -630,9 +644,11 @@ auto GetHipBlasLtTypeStringAndOps() { } #endif - hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; - if (at::globalContext().allowTF32CuBLAS()) { - computeType = HIPBLAS_COMPUTE_32F_FAST_TF32; + hipblasComputeType_t computeType = HipBlasComputeTypeFor(); + if constexpr (std::is_same_v) { + if (at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATMUL) == at::Float32Precision::TF32) { + computeType = HIPBLAS_COMPUTE_32F_FAST_TF32; + } } hipblasLtHandle_t handle; diff --git a/aten/src/ATen/functorch/BatchRulesActivation.cpp b/aten/src/ATen/functorch/BatchRulesActivation.cpp index 92b5527db77c5..47fe73e4212c3 100644 --- a/aten/src/ATen/functorch/BatchRulesActivation.cpp +++ b/aten/src/ATen/functorch/BatchRulesActivation.cpp @@ -21,8 +21,8 @@ glu_batch_rule(const Tensor& self, std::optional self_bdim, int64_t dim const auto self_ = moveBatchDimToFront(self, self_bdim); - const auto res = at::glu(self_, dim_); - return std::make_tuple(res, 0); + auto res = at::glu(self_, dim_); + return std::make_tuple(std::move(res), 0); } static std::tuple> glu_backward_batch_rule( @@ -42,8 +42,8 @@ static std::tuple> glu_backward_batch_rule( const auto grad_output_ = ensure_has_bdim(moveBatchDimToFront(grad_output, grad_output_bdim), grad_output_bdim.has_value(), batch_size); const auto self_ = ensure_has_bdim(moveBatchDimToFront(self, self_bdim), self_bdim.has_value(), batch_size); - const auto res = at::glu_backward(grad_output_, self_, dim_); - return std::make_tuple(res, 0); + auto res = at::glu_backward(grad_output_, self_, dim_); + return std::make_tuple(std::move(res), 0); } diff --git a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp index 25fe96c6cda15..0acdd148e3eef 100644 --- a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp @@ -188,8 +188,8 @@ static std::tuple> masked_select_batch_rule( self_ = maybePadToLogicalRank(self_, 0, max_logical_rank); // masked_select returns a 1D tensor, so we have to reshape it into 2D - const auto result = at::masked_select(self_, mask).view({ batch_size, -1 }); - return std::make_tuple(result, 0); + auto result = at::masked_select(self_, mask).view({ batch_size, -1 }); + return std::make_tuple(std::move(result), 0); } static std::tuple> masked_select_backward_batch_rule( @@ -213,8 +213,8 @@ static std::tuple> masked_select_backward_batch_r self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size); grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), batch_size); - const auto result = at::masked_select_backward(grad_, self_.contiguous(), mask); - return std::make_tuple(result, 0); + auto result = at::masked_select_backward(grad_, self_.contiguous(), mask); + return std::make_tuple(std::move(result), 0); } static std::tuple> cdist_backward_batch_rule( @@ -294,7 +294,7 @@ rrelu_with_noise_batch_rule( auto ret = at::rrelu_with_noise(self_, noise_, lower, upper, training, std::move(generator)); - return std::make_tuple(ret, 0, noise_, 0); + return std::make_tuple(std::move(ret), 0, std::move(noise_), 0); } static Tensor rrelu_with_noise_batch( diff --git a/aten/src/ATen/functorch/BatchRulesConvolution.cpp b/aten/src/ATen/functorch/BatchRulesConvolution.cpp index 748d5b1687a3c..602517256cef1 100644 --- a/aten/src/ATen/functorch/BatchRulesConvolution.cpp +++ b/aten/src/ATen/functorch/BatchRulesConvolution.cpp @@ -329,7 +329,7 @@ convolution_backward_weight_batch_rule( dilation, transposed, output_padding, groups, mask); auto& grad_weight = std::get<1>(result); grad_weight = reshape_dim_outof_symint(1, batch_size, grad_weight); - return std::make_tuple(grad_weight, 1); + return std::make_tuple(std::move(grad_weight), 1); } else { // transposed: N(GO), BN(GI) -> N(GO), N(GBI) -> (GBI)O const auto dummy_weight = make_dummy(weight, weight_bdim, 0, batch_size); diff --git a/aten/src/ATen/functorch/BatchRulesFactory.cpp b/aten/src/ATen/functorch/BatchRulesFactory.cpp index 34a537a9edb40..d5450ed972115 100644 --- a/aten/src/ATen/functorch/BatchRulesFactory.cpp +++ b/aten/src/ATen/functorch/BatchRulesFactory.cpp @@ -94,14 +94,14 @@ static std::tuple> _new_zeros_with_same_feature_m // [K0, K1, B, 6], [B, 5], 2 -> [K0, K1, B, 5] tangent_ = tangent.movedim(*tangent_bdim, self_num_batch_dims); } - const auto result = at::_new_zeros_with_same_feature_meta(tangent_, base_, self_num_batch_dims); - return std::make_tuple(result, self_num_batch_dims); + auto result = at::_new_zeros_with_same_feature_meta(tangent_, base_, self_num_batch_dims); + return std::make_tuple(std::move(result), self_num_batch_dims); } // Case 1: auto tangent_ = moveBatchDimToFront(tangent, tangent_bdim); auto result = at::_new_zeros_with_same_feature_meta(tangent_, base, self_num_batch_dims + 1); - return std::make_tuple(result, 0); + return std::make_tuple(std::move(result), 0); } static std::tuple> linspace_logspace_batch_rule_helper( @@ -139,7 +139,7 @@ static std::tuple> linspace_logspace_batch_rule_h result = result.to(*dtype); } - return std::make_tuple(result, 0); + return std::make_tuple(std::move(result), 0); } static std::tuple> linspace_Tensor_Tensor_batch_rule( diff --git a/aten/src/ATen/functorch/BatchRulesHelper.h b/aten/src/ATen/functorch/BatchRulesHelper.h index 0d2f075d0c540..f4583ac32a4a0 100644 --- a/aten/src/ATen/functorch/BatchRulesHelper.h +++ b/aten/src/ATen/functorch/BatchRulesHelper.h @@ -141,6 +141,8 @@ void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::S auto arguments = torch::jit::pop(*stack, num_arguments); std::vector>> tensor_inputs; std::vector tensor_pos; + tensor_inputs.reserve(num_arguments); + tensor_pos.reserve(num_arguments); for (const auto idx : c10::irange(0, num_arguments)) { const auto& ivalue = arguments[idx]; if (ivalue.isTensor()) { diff --git a/aten/src/ATen/functorch/BatchRulesModules.cpp b/aten/src/ATen/functorch/BatchRulesModules.cpp index 4e0b50c4e3fe7..51cb1d4751d6f 100644 --- a/aten/src/ATen/functorch/BatchRulesModules.cpp +++ b/aten/src/ATen/functorch/BatchRulesModules.cpp @@ -168,13 +168,20 @@ static std::tuple, Tensor, std::optional grid_sample_backward_helper_out( // NOLINTNEXTLINE(performance-unnecessary-value-param) std::tuple bw_out, - int64_t grad_input_out_bdim, - int64_t grad_grid_out_bdim, + int64_t grad_input_bdim, + int64_t grad_grid_bdim, int64_t bdim_size) { auto& [grad_input, grad_grid] = bw_out; - grad_input = reshape_dim_outof(grad_input_out_bdim, bdim_size, grad_input); - grad_grid = reshape_dim_outof(grad_grid_out_bdim, bdim_size, grad_grid); - return std::make_tuple(std::move(grad_input), grad_input_out_bdim, std::move(grad_grid), grad_grid_out_bdim); + std::optional grad_input_bdim_out, grad_grid_bdim_out; + if (grad_input.defined()) { + grad_input = reshape_dim_outof(grad_input_bdim, bdim_size, grad_input); + grad_input_bdim_out = grad_input_bdim; + } + if (grad_grid.defined()) { + grad_grid = reshape_dim_outof(grad_grid_bdim, bdim_size, grad_grid); + grad_grid_bdim_out = grad_grid_bdim; + } + return std::make_tuple(std::move(grad_input), grad_input_bdim_out, std::move(grad_grid), grad_grid_bdim_out); } diff --git a/aten/src/ATen/functorch/BatchRulesNorm.cpp b/aten/src/ATen/functorch/BatchRulesNorm.cpp index 51dae00e6b7ed..affb6ce369f2b 100644 --- a/aten/src/ATen/functorch/BatchRulesNorm.cpp +++ b/aten/src/ATen/functorch/BatchRulesNorm.cpp @@ -284,7 +284,7 @@ std::tuple batch_norm_backward_plumbing( training, eps); grad_input = makeBatched(std::move(std::get<0>(results)), std::get<1>(results), cur_level); } - return std::make_tuple(grad_input, grad_weight, grad_bias); + return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias)); } static std::tuple native_group_norm_plumbing( @@ -351,7 +351,7 @@ static at::Tensor group_norm_backward_no_weight_bias_batch_rule( auto rstd_ = moveBatchDimToFront(rstd, rstd_bdim); const auto bdim_size = get_bdim_size2(grad_out, grad_out_bdim, input, input_bdim); - grad_out_ = ensure_has_bdim(grad_out, grad_out_bdim.has_value(), bdim_size); + grad_out_ = ensure_has_bdim(grad_out_, grad_out_bdim.has_value(), bdim_size); input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size); mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size); rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size); @@ -432,7 +432,7 @@ static std::tuple native_group_norm_backward_plumbing( ); grad_input = makeBatched(std::move(tensor), 0, cur_level); } - return std::make_tuple(grad_input, grad_weight, grad_bias); + return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias)); } static bool has_same_shape( @@ -524,7 +524,7 @@ native_layer_norm_batch_rule( bias_ = maybePadToLogicalRank(bias_, /*has_bdim*/bias_bdim, result_logical_rank); result0 = result0 + bias_; } - return std::make_tuple(result0, 0, mean, stats_bdim, rstd, stats_bdim); + return std::make_tuple(std::move(result0), 0, std::move(mean), stats_bdim, std::move(rstd), stats_bdim); } static std::tuple> native_layer_norm_backward_no_weight_bias_batch_rule( @@ -536,9 +536,9 @@ static std::tuple> native_layer_norm_backward if (!grad_out_bdim.has_value() && !input_bdim.has_value() && !mean_bdim.has_value() && !rstd_bdim.has_value()) { - const auto result = at::native_layer_norm_backward( + auto result = at::native_layer_norm_backward( grad_out, input, normalized_shape, mean, rstd, std::nullopt, std::nullopt, {true, false, false}); - return std::make_tuple(std::get<0>(result), std::nullopt); + return std::make_tuple(std::get<0>(std::move(result)), std::nullopt); } auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim); @@ -644,7 +644,7 @@ static std::tuple native_layer_norm_backward_p rstd_value, rstd_bdim); grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level); } - return std::make_tuple(grad_input, grad_weight, grad_bias); + return std::make_tuple(std::move(grad_input), std::move(grad_weight), std::move(grad_bias)); } template diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp index 80034ff95ca3c..eab17f23ce2bf 100644 --- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp @@ -647,7 +647,7 @@ std::tuple> index_put_batch_rule( values_ = maybe_permute_values(values_, indices, indices_bdims); auto result = at::index_put(self_, List>(indices_), values_, accumulate); - return std::make_tuple(result, 0); + return std::make_tuple(std::move(result), 0); } // plumbing done since we don't support List> in codegen @@ -704,7 +704,7 @@ std::tuple> scatter_batch_rule( if (self_logical_rank == 0) { result = result.squeeze(-1); } - return std::make_tuple(result, 0); + return std::make_tuple(std::move(result), 0); } template @@ -742,7 +742,7 @@ inline std::tuple> scatter_batch_rule( if (self_logical_rank == 0) { result = result.squeeze(-1); } - return std::make_tuple(result, 0); + return std::make_tuple(std::move(result), 0); } } // namespace @@ -774,11 +774,22 @@ std::tuple> scatter_add_batch_rule( self, self_bdim, dim, index, index_bdim, src, src_bdim); } +static void check_scatter_inplace_bdim( + std::optional self_bdim, + std::optional index_bdim, + std::optional src_bdim, + const char* schema_name) { + if (!self_bdim.has_value() && (index_bdim.has_value() || src_bdim.has_value())) { + vmapIncompatibleInplaceError(schema_name); + } +} + std::tuple> scatter_add__batch_rule( const Tensor& self, std::optional self_bdim, int64_t dim, const Tensor& index, std::optional index_bdim, const Tensor& src, std::optional src_bdim) { + check_scatter_inplace_bdim(self_bdim, index_bdim, src_bdim, "scatter_add_"); return scatter_batch_rule(ATEN_FN(scatter_add_), self, self_bdim, dim, index, index_bdim, src, src_bdim); } @@ -811,6 +822,8 @@ std::tuple> scatter_reduce__two_batch_rule( const Tensor& src, std::optional src_bdim, const std::string_view reduce, bool include_self) { + check_scatter_inplace_bdim( + self_bdim, index_bdim, src_bdim, "scatter_reduce_"); return scatter_batch_rule(ATEN_FN2(scatter_reduce_, two), self, self_bdim, dim, index, index_bdim, src, src_bdim, reduce, include_self); } @@ -852,7 +865,7 @@ std::tuple> gather_batch_rule( if (index_logical_rank == 0) { result = result.squeeze(-1); } - return std::make_tuple(result, 0); + return std::make_tuple(std::move(result), 0); } Tensor get_expanded_index(const Tensor& index, SymIntArrayRef self_size, int64_t dim) { @@ -1000,7 +1013,7 @@ std::tuple> index_add_batch_rule_impl( if (self_logical_rank == 0) { result = result.squeeze(-1); } - return std::make_tuple(result, 0); + return std::make_tuple(std::move(result), 0); } // Index is batched. For-loop and stack is the best thing I can come up with @@ -1071,7 +1084,7 @@ std::tuple binary_pointwise_align( tensor_ = maybePadToLogicalRank(tensor_, self_bdim, max_logical_rank); other_ = maybePadToLogicalRank(other_, mask_bdim, max_logical_rank); - return std::make_tuple(tensor_, other_); + return std::make_tuple(std::move(tensor_), std::move(other_)); } std::tuple> masked_fill_scalar_batch_rule( @@ -1082,7 +1095,7 @@ std::tuple> masked_fill_scalar_batch_rule( const Scalar& source) { auto tensors = binary_pointwise_align(self, self_bdim, mask, mask_bdim); auto result = at::masked_fill(std::get<0>(tensors), std::get<1>(tensors), source); - return std::make_tuple(result, 0); + return std::make_tuple(std::move(result), 0); } std::tuple> index_fill_batch_rule_helper( diff --git a/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp b/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp index 3cabdd251480f..ef06451a17314 100644 --- a/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesUnaryOps.cpp @@ -39,12 +39,12 @@ clone_batch_rule( // philosophically vmap hides the batch dims and operates on a per-sample level. auto self_ = moveBatchDimToFront(self, self_bdim); auto result = at::clone(self_, memory_format); - return std::make_tuple(result, 0); + return std::make_tuple(std::move(result), 0); } TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve); auto result = at::clone(self, memory_format); - return std::make_tuple(result, self_bdim); + return std::make_tuple(std::move(result), self_bdim); } std::tuple> @@ -55,7 +55,7 @@ view_as_complex_batch_rule(const Tensor& self, std::optional self_bdim) auto self_ = moveBatchDimToFront(self, self_bdim); auto result = at::view_as_complex(self_); - return std::make_tuple(result, 0); + return std::make_tuple(std::move(result), 0); } } diff --git a/aten/src/ATen/functorch/BatchRulesViews.cpp b/aten/src/ATen/functorch/BatchRulesViews.cpp index 08724d4fc1243..925964ddb17db 100644 --- a/aten/src/ATen/functorch/BatchRulesViews.cpp +++ b/aten/src/ATen/functorch/BatchRulesViews.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -300,7 +301,7 @@ std::tuple> roll_batch_rule(const Tensor& self, s // NOTE: For scalar tensor, we don't need to unsqueeze as reshape // with `old_shape` takes care of it. output = output.reshape_symint(old_shape); - return std::make_tuple(output, 0); + return std::make_tuple(std::move(output), 0); } std::tuple> diagonal_batching_rule( @@ -425,6 +426,19 @@ std::tuple> view_batching_rule( return std::make_tuple(self_.view_symint(size_), 0); } +std::tuple> view_dtype_batch_rule( + const Tensor& self, + std::optional self_bdim, + ScalarType dtype) { + TORCH_INTERNAL_ASSERT(self_bdim.has_value()); + auto logical_rank = rankWithoutBatchDim(self, self_bdim); + TORCH_CHECK( + logical_rank != 0 || self.itemsize() == c10::elementSize(dtype), + "self.dim() cannot be 0 to view ", self.scalar_type(), " as ", dtype, " (different element sizes)"); + auto self_ = moveBatchDimToFront(self, self_bdim); + return std::make_tuple(self_.view(dtype), 0); +} + std::tuple> view_copy_batch_rule( const Tensor& self, std::optional self_bdim, @@ -498,7 +512,6 @@ std::tuple> narrow_copy_batch_rule( auto logical_rank = rankWithoutBatchDim(self, self_bdim); dim = maybe_wrap_dim(dim, logical_rank) + 1; auto result = self_.narrow_copy_symint(dim, std::move(start), std::move(length)); - return std::make_tuple(std::move(result), 0); } @@ -515,6 +528,18 @@ std::tuple, std::optional> unsafe_split_batch_rule( return std::make_tuple(std::move(result), 0); } +std::tuple, std::optional> unbind_copy_batch_rule( + const Tensor& self, + std::optional self_bdim, + int64_t dim) { + TORCH_INTERNAL_ASSERT(self_bdim.has_value()); + auto self_ = moveBatchDimToFront(self, self_bdim); + auto logical_rank = rankWithoutBatchDim(self, self_bdim); + dim = maybe_wrap_dim(dim, logical_rank) + 1; + auto result = at::unbind_copy(self_, dim); + return std::make_tuple(std::move(result), 0); +} + std::tuple> diag_embed_batch_rule(const Tensor& self, std::optional self_bdim, int64_t offset, int64_t dim1, int64_t dim2) { auto logical_rank = rankWithoutBatchDim(self, self_bdim); auto self_ = moveBatchDimToFront(self, self_bdim); @@ -570,7 +595,9 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { VMAP_SUPPORT(diagonal_backward, diagonal_backward_batch_rule); VMAP_SUPPORT(select_backward, select_backward_batch_rule); VMAP_SUPPORT(slice_backward, slice_backward_batch_rule); + VMAP_SUPPORT2(unbind_copy, int, unbind_copy_batch_rule); VMAP_SUPPORT(view, view_batching_rule); + VMAP_SUPPORT2(view, dtype, view_dtype_batch_rule); VMAP_SUPPORT(view_copy, view_copy_batch_rule); VMAP_SUPPORT(expand, SINGLE_ARG(expand_batch_rule)); VMAP_SUPPORT(expand_copy, SINGLE_ARG(expand_batch_rule)); diff --git a/aten/src/ATen/functorch/BatchedFallback.cpp b/aten/src/ATen/functorch/BatchedFallback.cpp index b479639f1c1a5..ee7b9b69eafb6 100644 --- a/aten/src/ATen/functorch/BatchedFallback.cpp +++ b/aten/src/ATen/functorch/BatchedFallback.cpp @@ -412,10 +412,8 @@ void batchedNestedTensorForLoopFallback(const c10::OperatorHandle& op, torch::ji return; } - if (isInplaceOp(schema)) { - TORCH_INTERNAL_ASSERT(false, "vmap fallback not supported for in-place ops on nested tensors"); - return; - } + TORCH_INTERNAL_ASSERT(!isInplaceOp(schema), "vmap fallback not supported for in-place ops on nested tensors"); + TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(), "Nested batching rule not implemented for ", schema.operator_name(), "; ", "the fallback path doesn't work on out= or view ops."); diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.cpp b/aten/src/ATen/functorch/BatchedTensorImpl.cpp index 895770fc69921..72af353064661 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.cpp +++ b/aten/src/ATen/functorch/BatchedTensorImpl.cpp @@ -156,7 +156,6 @@ c10::intrusive_ptr BatchedTensorImpl::shallow_copy_and_detach( const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const { TORCH_CHECK(false, "accessing `data` under vmap transform is not allowed"); - return nullptr; } c10::intrusive_ptr BatchedTensorImpl::shallow_copy_and_detach( @@ -164,7 +163,6 @@ c10::intrusive_ptr BatchedTensorImpl::shallow_copy_and_detach( c10::VariableVersion&& version_counter, bool allow_tensor_metadata_change) const { TORCH_CHECK(false, "accessing `data` under vmap transform is not allowed"); - return nullptr; } void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr& impl) { diff --git a/aten/src/ATen/functorch/DynamicLayer.cpp b/aten/src/ATen/functorch/DynamicLayer.cpp index 1420aaf0ab943..2ef7c6d23ea2c 100644 --- a/aten/src/ATen/functorch/DynamicLayer.cpp +++ b/aten/src/ATen/functorch/DynamicLayer.cpp @@ -29,7 +29,8 @@ DynamicLayer::DynamicLayer( std::optional randomness, std::optional prev_grad_mode, std::optional prev_fwd_grad_mode, - std::optional functionalize_add_back_views) + std::optional functionalize_add_back_views, + std::optional prev_inference_mode) { if (transform_type == TransformType::Grad) { TORCH_INTERNAL_ASSERT(prev_grad_mode.has_value()); @@ -43,10 +44,10 @@ DynamicLayer::DynamicLayer( interpreter_ = Interpreter::Vmap(layerId, std::move(batchSize.value()), randomness.value()); break; case TransformType::Grad: - interpreter_ = Interpreter::Grad(layerId, prev_grad_mode.value()); + interpreter_ = Interpreter::Grad(layerId, prev_grad_mode.value(), prev_inference_mode.value_or(false)); break; case TransformType::Jvp: - interpreter_ = Interpreter::Jvp(layerId, prev_fwd_grad_mode.value()); + interpreter_ = Interpreter::Jvp(layerId, prev_fwd_grad_mode.value(), prev_inference_mode.value_or(false)); break; case TransformType::Functionalize: // NOLINTNEXTLINE(bugprone-unchecked-optional-access) @@ -246,10 +247,11 @@ int64_t initAndPushDynamicLayer( std::optional randomness, std::optional prev_grad_mode, std::optional prev_fwd_grad_mode, - std::optional functionalize_add_back_views) { + std::optional functionalize_add_back_views, + std::optional prev_inference_mode) { const auto& dynamicLayerStack = dynamicLayerStackAccessor(); const int64_t layerId = static_cast(1 + dynamicLayerStack.size()); - DynamicLayer new_layer(transform_type, layerId, std::move(batch_size), randomness, prev_grad_mode, prev_fwd_grad_mode, functionalize_add_back_views); + DynamicLayer new_layer(transform_type, layerId, std::move(batch_size), randomness, prev_grad_mode, prev_fwd_grad_mode, functionalize_add_back_views, prev_inference_mode); // NB: this function should be called while holding the GIL to avoid races new_layer.interpreter().set_is_alive(true); pushDynamicLayer(std::move(new_layer)); diff --git a/aten/src/ATen/functorch/DynamicLayer.h b/aten/src/ATen/functorch/DynamicLayer.h index 672a33fda0016..bb203ea8ab8ea 100644 --- a/aten/src/ATen/functorch/DynamicLayer.h +++ b/aten/src/ATen/functorch/DynamicLayer.h @@ -47,7 +47,8 @@ struct TORCH_API DynamicLayer { std::optional randomness = std::nullopt, std::optional prev_grad_mode = std::nullopt, std::optional pre_fwd_grad_mode = std::nullopt, - std::optional functionalize_add_back_views = std::nullopt); + std::optional functionalize_add_back_views = std::nullopt, + std::optional prev_inference_mode = std::nullopt); TransformType key() const; int64_t layerId() const; @@ -69,7 +70,8 @@ TORCH_API int64_t initAndPushDynamicLayer( std::optional randomness = std::nullopt, std::optional prev_grad_mode = std::nullopt, std::optional prev_fwd_grad_mode = std::nullopt, - std::optional functionalize_add_back_views = std::nullopt); + std::optional functionalize_add_back_views = std::nullopt, + std::optional prev_inference_mode = std::nullopt); TORCH_API DynamicLayer popDynamicLayerAndDeleteMetadata(); TORCH_API std::optional maybeCurrentDynamicLayer(); TORCH_API const std::vector& getDynamicLayerStack(); diff --git a/aten/src/ATen/functorch/Interpreter.h b/aten/src/ATen/functorch/Interpreter.h index 3d3b2069387d7..02681f830fb7a 100644 --- a/aten/src/ATen/functorch/Interpreter.h +++ b/aten/src/ATen/functorch/Interpreter.h @@ -123,7 +123,8 @@ struct VmapInterpreterMeta { }; struct GradInterpreterMeta { - explicit GradInterpreterMeta(bool prevGradMode): prevGradMode_(prevGradMode) {} + explicit GradInterpreterMeta(bool prevGradMode, bool prevInferenceMode = false) + : prevGradMode_(prevGradMode), prevInferenceMode_(prevInferenceMode) {} GradInterpreterMeta() = default; GradInterpreterMeta(const GradInterpreterMeta&) = default; GradInterpreterMeta(GradInterpreterMeta&&) = default; @@ -132,19 +133,23 @@ struct GradInterpreterMeta { ~GradInterpreterMeta() = default; bool prevGradMode_; + bool prevInferenceMode_; template friend void to_json(T& json_j, const GradInterpreterMeta& json_t) { json_j["prevGradMode"] = json_t.prevGradMode_; + json_j["prevInferenceMode"] = json_t.prevInferenceMode_; } template friend void from_json(const T& json_j, GradInterpreterMeta& json_t) { json_t.prevGradMode_ = json_j["prevGradMode"]; + json_t.prevInferenceMode_ = json_j.value("prevInferenceMode", false); } }; struct JvpInterpreterMeta { - explicit JvpInterpreterMeta(bool prevFwdGradMode) : prevFwdGradMode_(prevFwdGradMode) {} + explicit JvpInterpreterMeta(bool prevFwdGradMode, bool prevInferenceMode = false) + : prevFwdGradMode_(prevFwdGradMode), prevInferenceMode_(prevInferenceMode) {} JvpInterpreterMeta() = default; JvpInterpreterMeta(const JvpInterpreterMeta&) = default; JvpInterpreterMeta(JvpInterpreterMeta&&) = default; @@ -153,14 +158,17 @@ struct JvpInterpreterMeta { ~JvpInterpreterMeta() = default; bool prevFwdGradMode_; + bool prevInferenceMode_; template friend void to_json(T& json_j, const JvpInterpreterMeta& json_t) { json_j["prevFwdGradMode"] = json_t.prevFwdGradMode_; + json_j["prevInferenceMode"] = json_t.prevInferenceMode_; } template friend void from_json(const T& json_j, JvpInterpreterMeta& json_t) { json_t.prevFwdGradMode_ = json_j["prevFwdGradMode"]; + json_t.prevInferenceMode_ = json_j.value("prevInferenceMode", false); } }; @@ -200,11 +208,11 @@ struct Interpreter { static Interpreter Vmap(int64_t level, c10::SymInt batchSize, RandomnessType randomness) { return Interpreter(TransformType::Vmap, level, VmapInterpreterMeta(std::move(batchSize), randomness)); } - static Interpreter Grad(int64_t level, bool prevGradMode) { - return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode)); + static Interpreter Grad(int64_t level, bool prevGradMode, bool prevInferenceMode = false) { + return Interpreter(TransformType::Grad, level, GradInterpreterMeta(prevGradMode, prevInferenceMode)); } - static Interpreter Jvp(int64_t level, bool prevFwdGradMode) { - return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode)); + static Interpreter Jvp(int64_t level, bool prevFwdGradMode, bool prevInferenceMode = false) { + return Interpreter(TransformType::Jvp, level, JvpInterpreterMeta(prevFwdGradMode, prevInferenceMode)); } static Interpreter Functionalize(int64_t level, bool functionalizeAddBackViews) { return Interpreter(TransformType::Functionalize, level, FunctionalizeInterpreterMeta(functionalizeAddBackViews)); diff --git a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp index 1df4c8938183a..e02f20b102bc7 100644 --- a/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp +++ b/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp @@ -709,8 +709,10 @@ Tensor nested_cat_batching_rule(const ITensorListRef& tensors, int64_t dim) { // Do a cat for each set of zipped unbound components const auto num_components = unbound.front().size(); std::vector outputs; + outputs.reserve(num_components); for (auto i : c10::irange(num_components)) { std::vector arg_list; + arg_list.reserve(unbound.size()); for (auto j : c10::irange(unbound.size())) { arg_list.push_back(unbound[j][i]); } diff --git a/aten/src/ATen/mps/MPSDevice.h b/aten/src/ATen/mps/MPSDevice.h index 3c2dee71aa752..67f65b99c79e0 100644 --- a/aten/src/ATen/mps/MPSDevice.h +++ b/aten/src/ATen/mps/MPSDevice.h @@ -23,6 +23,7 @@ enum class MacOSVersion : uint32_t { MACOS_VER_15_1_PLUS, MACOS_VER_15_2_PLUS, MACOS_VER_26_0_PLUS, + MACOS_VER_26_4_PLUS, }; //----------------------------------------------------------------- diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index 89b54d3d3d047..4e864aeaa73f8 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -66,6 +66,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de static bool _macos_15_1_plus = is_os_version_at_least(15, 1); static bool _macos_15_2_plus = is_os_version_at_least(15, 2); static bool _macos_26_0_plus = is_os_version_at_least(26, 0); + static bool _macos_26_4_plus = is_os_version_at_least(26, 4); switch (version) { case MacOSVersion::MACOS_VER_14_4_PLUS: @@ -78,6 +79,8 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de return _macos_15_2_plus; case MacOSVersion::MACOS_VER_26_0_PLUS: return _macos_26_0_plus; + case MacOSVersion::MACOS_VER_26_4_PLUS: + return _macos_26_4_plus; default: return false; } diff --git a/aten/src/ATen/mps/MPSGuardImpl.h b/aten/src/ATen/mps/MPSGuardImpl.h index 008a8d57f3df6..a02e4dd0b4300 100644 --- a/aten/src/ATen/mps/MPSGuardImpl.h +++ b/aten/src/ATen/mps/MPSGuardImpl.h @@ -34,7 +34,7 @@ struct TORCH_API MPSGuardImpl final static constexpr c10::DeviceType static_type = c10::DeviceType::MPS; // constructor - MPSGuardImpl() {} + MPSGuardImpl() = default; explicit MPSGuardImpl(c10::DeviceType t) { TORCH_CHECK( t == DeviceType::MPS, @@ -84,6 +84,18 @@ struct TORCH_API MPSGuardImpl final Stream exchangeStream(Stream s) const override { return Stream(Stream::DEFAULT, Device(c10::DeviceType::MPS, 0)); } + DeviceCapability getDeviceCapability(Device /* unused */) const override { + DeviceCapability cap; + cap.capability_data.capability_bits = (1ULL << kIndex_Byte) | + (1ULL << kIndex_Char) | (1ULL << kIndex_Short) | (1ULL << kIndex_Int) | + (1ULL << kIndex_Long) | (1ULL << kIndex_Half) | (1ULL << kIndex_Float) | + (1ULL << kIndex_ComplexHalf) | (1ULL << kIndex_ComplexFloat) | + (1ULL << kIndex_Bool) | (1ULL << kIndex_BFloat16) | + (1ULL << kIndex_UInt32) | (1ULL << kIndex_UInt16) | + (1ULL << kIndex_UInt64); + return cap; + } + DeviceIndex deviceCount() const noexcept override { if (at::hasMPS()) { // TODO: extend it for multi-device case diff --git a/aten/src/ATen/mps/MPSStream.h b/aten/src/ATen/mps/MPSStream.h index b00890b9f5901..f83a387ee846e 100644 --- a/aten/src/ATen/mps/MPSStream.h +++ b/aten/src/ATen/mps/MPSStream.h @@ -73,7 +73,6 @@ class TORCH_API MPSStream { MTLComputeCommandEncoder_t commandEncoder(); void endKernelCoalescing(); void synchronize(SyncType syncType); - void fill(MTLBuffer_t buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE); void copy(MTLBuffer_t srcBuffer, MTLBuffer_t dstBuffer, size_t length, diff --git a/aten/src/ATen/mps/MPSStream.mm b/aten/src/ATen/mps/MPSStream.mm index b02d2447ce17f..8cecc94f5926b 100644 --- a/aten/src/ATen/mps/MPSStream.mm +++ b/aten/src/ATen/mps/MPSStream.mm @@ -157,33 +157,6 @@ @interface MPSGraphExecutionDescriptor () }); } -void MPSStream::fill(id buffer, uint8_t value, size_t length, size_t offset, SyncType syncType) { - if (length == 0) { - return; - } - dispatch_sync_with_rethrow(_serialQueue, ^() { - @autoreleasepool { - endKernelCoalescing(); - id blitEncoder = [commandBuffer() blitCommandEncoder]; - - // For some reason fillBufferfor stopped working for length > 4Gb on MacOS 26 - // See https://github.com/pytorch/pytorch/issues/163962 - // Workaround by batching copy commands into 4Gb chunks - constexpr size_t max_copy_size = 0x100000000; // 4GB - size_t bytes_filled = 0; - size_t bytes_remains = length; - while (bytes_remains > 0) { - NSUInteger bytes_to_copy = std::min(max_copy_size, bytes_remains); - [blitEncoder fillBuffer:buffer range:NSMakeRange(offset + bytes_filled, bytes_to_copy) value:value]; - bytes_filled += bytes_to_copy; - bytes_remains -= bytes_to_copy; - } - [blitEncoder endEncoding]; - synchronize(syncType); - } - }); -} - void MPSStream::copy(id srcBuffer, id dstBuffer, size_t length, diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index cd24d51c62c1c..2de3e24b584b8 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -1843,7 +1843,21 @@ TORCH_IMPL_FUNC(linalg_cholesky_ex_out)(const Tensor& A, cholesky_stub(L.device().type(), L, info, upper); - if (!cpu) { + // On non-CPU devices (MAGMA) the pre-copy doesn't zero the unused triangle, + // so we must clean up after. On macOS, Accelerate's LAPACK writes into the + // unreferenced triangle for matrices larger than its internal block size + // (e.g. n > 64), violating the LAPACK spec which says "not referenced" + // elements are "never read, written to, or otherwise accessed" + // (see https://www.netlib.org/lapack/lug/node121.html). + // We work around this by applying the same cleanup on macOS. + // TODO(https://github.com/pytorch/pytorch/issues/179152): always + // clean up the unused triangle on all platforms. +#if defined(__APPLE__) + constexpr bool needs_triangle_cleanup = true; +#else + const bool needs_triangle_cleanup = !cpu; +#endif + if (needs_triangle_cleanup) { if (upper) { L.triu_(); } else { diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index c7de276f5f88f..164e709536135 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -1201,7 +1201,6 @@ struct Brgemm : public KernelCache { case ScalarType::Float8_e5m2: return f8_support; default: return false; } - return false; } }; @@ -1261,7 +1260,6 @@ struct Pack : public KernelCache { case ScalarType::Float8_e5m2: return fp8_pack; default: return false; } - return false; } }; #endif diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h index 8c0771ddfdcc6..adf76c01c5a6c 100644 --- a/aten/src/ATen/native/ConvUtils.h +++ b/aten/src/ATen/native/ConvUtils.h @@ -333,12 +333,11 @@ inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input weight.scalar_type() == at::kDouble) { return at::MemoryFormat::Contiguous; } - long cudnn_version = at::detail::getCUDAHooks().versionCuDNN(); auto input_memory_format = input.suggest_memory_format(); auto weight_memory_format = weight.suggest_memory_format(); auto weight_ndim = weight.ndimension(); - bool can_use_cudnn_channels_last_2d = (cudnn_version >= 7603) && (weight_ndim == 4) && ( + bool can_use_cudnn_channels_last_2d = weight_ndim == 4 && ( (input_memory_format == at::MemoryFormat::ChannelsLast) || (weight_memory_format == at::MemoryFormat::ChannelsLast) ); @@ -346,7 +345,7 @@ inline at::MemoryFormat cudnn_conv_suggest_memory_format(const at::Tensor& input return at::MemoryFormat::ChannelsLast; } - bool can_use_cudnn_channels_last_3d = (cudnn_version >= 8005) && (weight_ndim == 5) && ( + bool can_use_cudnn_channels_last_3d = weight_ndim == 5 && ( (input_memory_format == at::MemoryFormat::ChannelsLast3d) || (weight_memory_format == at::MemoryFormat::ChannelsLast3d) ); @@ -471,8 +470,14 @@ inline bool mps_conv_use_channels_last(const at::Tensor& input, const at::Tensor return false; } + // Use exact-match so a tensor whose strides merely *look* like channels-last + // (e.g. a channel-slice of a channels-last tensor) is not misclassified -- + // MPS reads raw buffers assuming packed NHWC, which would be incorrect for + // such views. exact_match also correctly excludes degenerate 1x1 weights + // whose NCHW-contiguous strides happen to also satisfy is_contiguous(CL). + // See https://github.com/pytorch/pytorch/issues/180984 auto is_channel_last = [](const at::Tensor& t) { - auto fmt = t.suggest_memory_format(); + auto fmt = t.suggest_memory_format(/*channels_last_strides_exact_match=*/true); return fmt == at::MemoryFormat::ChannelsLast || fmt == at::MemoryFormat::ChannelsLast3d; }; return is_channel_last(input) || is_channel_last(weight); diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index 94e37647a8a5f..f3477e10acb19 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -473,6 +473,12 @@ struct ConvParams { (stride[0] == stride[1] || at::symint::size(input, 2) == 1) && // square or 1d at::symint::size(input, 1) >= 32); // min 32 channels supported) if (kernel_cond) { + auto depthwise_kernel = at::globalContext().cudnnDepthwiseKernel(); + if (depthwise_kernel == at::CuDNNDepthwiseKernel::NATIVE) { + return false; + } else if (depthwise_kernel == at::CuDNNDepthwiseKernel::CUDNN) { + return true; + } return check_cudnn_depthwise_workload_with_filter(input, stride[1], weight); } return false; @@ -1774,10 +1780,6 @@ std::tuple convolution_backward_overrideable( IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed, IntArrayRef output_padding, int64_t groups, std::array output_mask) { TORCH_CHECK_NOT_IMPLEMENTED(false, "convolution_backward_overrideable: You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function "); - return std::tuple( - at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT), - at::empty_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT), - at::empty({})); } static Tensor subvariable(const Tensor& var, int64_t dim, int64_t groups, int64_t g) { diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index 090a028786c72..cff843e0ee5c8 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -214,7 +214,6 @@ DispatchResult DispatchStubImpl::try_get_call_ptr( default: TORCH_INTERNAL_ASSERT(false, "An unexpected device type was provided ", device_type); - return ErrorType::DeviceNotSupported; } } @@ -268,7 +267,6 @@ void* DispatchStubImpl::get_call_ptr( case ErrorType::MissingDeviceKernel: TORCH_INTERNAL_ASSERT( false, "DispatchStub: missing kernel for ", device_type); - return nullptr; case ErrorType::DeviceNotSupported: TORCH_CHECK(false, "DispatchStub: unsupported device type", device_type); } diff --git a/aten/src/ATen/native/Distance.cpp b/aten/src/ATen/native/Distance.cpp index b7dfef7779ca7..af8d6185d4c26 100644 --- a/aten/src/ATen/native/Distance.cpp +++ b/aten/src/ATen/native/Distance.cpp @@ -103,8 +103,8 @@ static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, std // See Note [cdist relies on cdist_impl redispatching] // Keep this condition in sync with the condition at the Note if (!(p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25))))) { - TORCH_CHECK(device1 == kCPU || device1 == kCUDA || device1 == kXPU, "cdist only supports CPU, XPU and CUDA devices, X1 got: ", device1); - TORCH_CHECK(device2 == kCPU || device2 == kCUDA || device2 == kXPU, "cdist only supports CPU, XPU and CUDA devices, X2 got: ", device2); + TORCH_CHECK(device1 == kCPU || device1 == kCUDA || device1 == kXPU || device1 == kPrivateUse1, "cdist only supports CPU, XPU, CUDA and PrivateUse1 devices, X1 got: ", device1); + TORCH_CHECK(device2 == kCPU || device2 == kCUDA || device2 == kXPU || device2 == kPrivateUse1, "cdist only supports CPU, XPU, CUDA and PrivateUse1 devices, X2 got: ", device2); } auto dim1 = x1.dim(); @@ -229,9 +229,9 @@ Tensor _cdist_backward(const Tensor& _grad, const Tensor& _x1, const Tensor& _x2 int64_t n = x1.size(-2); int64_t m = x1.size(-1); auto device1 = x1.device().type(); - TORCH_CHECK(device1 == kCPU || device1 == kCUDA || device1 == kXPU, "_cdist_backward only supports CPU, XPU and CUDA devices, X1 got: ", device1); + TORCH_CHECK(device1 == kCPU || device1 == kCUDA || device1 == kXPU || device1 == kPrivateUse1, "_cdist_backward only supports CPU, XPU, CUDA and PrivateUse1 devices, X1 got: ", device1); auto device2 = x2.device().type(); - TORCH_CHECK(device2 == kCPU || device2 == kCUDA || device2 == kXPU, "_cdist_backward only supports CPU, XPU and CUDA devices, X2 got: ", device2); + TORCH_CHECK(device2 == kCPU || device2 == kCUDA || device2 == kXPU || device2 == kPrivateUse1, "_cdist_backward only supports CPU, XPU, CUDA and PrivateUse1 devices, X2 got: ", device2); Tensor grad_x1 = at::empty({batch_product, n, m}, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT); @@ -245,7 +245,7 @@ Tensor _cdist_backward(const Tensor& _grad, const Tensor& _x1, const Tensor& _x2 Tensor _pdist_forward(const Tensor& self, const double p) { TORCH_CHECK(self.is_contiguous(), "_pdist_forward requires contiguous input"); auto device = self.device().type(); - TORCH_CHECK(device == kCPU || device == kCUDA || device == kXPU, "_pdist_forward only supports CPU, XPU and CUDA devices, got: ", device); + TORCH_CHECK(device == kCPU || device == kCUDA || device == kXPU || device == kPrivateUse1, "_pdist_forward only supports CPU, XPU, CUDA and PrivateUse1 devices, got: ", device); Tensor result = at::empty({0}, self.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT); if (self.size(0) <= 1) { result.resize_({0}); @@ -266,7 +266,7 @@ Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, c TORCH_CHECK(self.is_contiguous(), "_pdist_backward requires self to be contiguous"); TORCH_CHECK(pdist.is_contiguous(), "_pdist_backward requires pdist to be contiguous"); auto device = self.device().type(); - TORCH_CHECK(device == kCPU || device == kCUDA || device == kXPU, "_pdist_backward only supports CPU, XPU and CUDA devices, got: ", device); + TORCH_CHECK(device == kCPU || device == kCUDA || device == kXPU || device == kPrivateUse1, "_pdist_backward only supports CPU, XPU, CUDA and PrivateUse1 devices, got: ", device); Tensor result = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT); pdist_backward_stub(device, result, grad, self, p, pdist); return result; diff --git a/aten/src/ATen/native/Distributions.cpp b/aten/src/ATen/native/Distributions.cpp index 5f34ed9d24c17..6382b1ecaace1 100644 --- a/aten/src/ATen/native/Distributions.cpp +++ b/aten/src/ATen/native/Distributions.cpp @@ -594,7 +594,8 @@ Tensor& multinomial_out(const Tensor& self, // https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503 if (!with_replacement || n_sample == 1) { // Sanity checks on `self`. - auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)); + auto [self_min, self_max] = self.aminmax(); + auto is_valid = ((self_max < INFINITY) & (self_min >= 0)); at::_assert_async(is_valid, "probability tensor contains either `inf`, `nan` or element < 0"); at::Tensor zero_prob_condition; if (self.dim() == 1){ diff --git a/aten/src/ATen/native/ForeachOpsKernels.cpp b/aten/src/ATen/native/ForeachOpsKernels.cpp index d411a68b037e0..8c57aa08338ca 100644 --- a/aten/src/ATen/native/ForeachOpsKernels.cpp +++ b/aten/src/ATen/native/ForeachOpsKernels.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -364,6 +365,21 @@ FOREACH_BINARY_OP_LIST(div) FOREACH_BINARY_OP_LIST(clamp_min) FOREACH_BINARY_OP_LIST(clamp_max) FOREACH_BINARY_OP_LIST(pow) + +// _foreach_clone +std::vector foreach_tensor_clone_slow( + TensorList self, + std::optional memory_format) { + check_foreach_api_restrictions(self); + + std::vector ret{}; + ret.reserve(self.size()); + for (const auto& t : self) { + ret.emplace_back(t.clone(memory_format)); + } + return ret; +} + // _foreach_copy_ void foreach_tensor_copy_list_kernel_slow_( TensorList self, @@ -521,6 +537,9 @@ std::vector foreach_tensor_max_slow(TensorList tensors) { std::vector result; result.reserve(tensors.size()); for (const auto& t : tensors) { + TORCH_CHECK( + t.numel() > 0, + "_foreach_max cannot compute the maximum of an empty tensor; max over zero elements is undefined."); result.emplace_back(at::max(t)); } return result; diff --git a/aten/src/ATen/native/ForeachUtils.h b/aten/src/ATen/native/ForeachUtils.h index f0dce20a6eff4..d68d770f215c0 100644 --- a/aten/src/ATen/native/ForeachUtils.h +++ b/aten/src/ATen/native/ForeachUtils.h @@ -93,22 +93,27 @@ inline void check_foreach_api_restrictions( // same device and dtype. inline bool _check_tensors_share_device_and_dtype( ArrayRef tensorLists, - const bool skip_dtype_check = false) { + const bool skip_cross_list_dtype_check = false) { const auto expected_dtype = tensorLists[0][0].dtype(); const auto expected_device = tensorLists[0][0].device(); - auto is_tensor_okay = [&](const Tensor& tensor) { - return (skip_dtype_check || tensor.dtype() == expected_dtype) && - tensor.device() == expected_device && tensor.layout() == at::kStrided && - tensor.is_non_overlapping_and_dense(); - }; - return std::all_of( tensorLists.cbegin(), tensorLists.cend(), [&](const TensorList& tensorList) { + if (tensorList.empty()) { + return true; + } + const auto list_dtype = tensorList[0].dtype(); return std::all_of( - tensorList.cbegin(), tensorList.cend(), is_tensor_okay); + tensorList.cbegin(), tensorList.cend(), [&](const Tensor& tensor) { + return tensor.device() == expected_device && + tensor.layout() == at::kStrided && + tensor.is_non_overlapping_and_dense() && + tensor.dtype() == list_dtype && + (skip_cross_list_dtype_check || + tensor.dtype() == expected_dtype); + }); }); } @@ -192,8 +197,10 @@ inline bool _check_tensors_do_type_promotion_with_scalars( inline bool check_fast_path_restrictions( ArrayRef tensorLists, ArrayRef scalarList = {}, - bool does_op_promote_integer_inputs_to_float = false) { - return _check_tensors_share_device_and_dtype(tensorLists) && + bool does_op_promote_integer_inputs_to_float = false, + bool skip_cross_list_dtype_check = false) { + return _check_tensors_share_device_and_dtype( + tensorLists, skip_cross_list_dtype_check) && _check_tensors_share_sizes_and_strides(tensorLists) && _check_tensors_do_type_promotion_with_scalars( tensorLists[0], @@ -318,9 +325,17 @@ inline FlatMap _group_tensors_by_first_tensors_device_and_dtype( const auto s = tensor->scalar_type(); const auto d = tensor->device(); // Note: `step` or `state_step` is float32 by default. + // BFloat16 is allowed here for mixed-precision optimizer + // states (e.g. fp32 params with bf16 exp_avg/exp_avg_sq). + // Currently only BFloat16 is permitted because it is the + // only low-precision state dtype validated end-to-end in + // large-scale training (e.g. DeepSeek-V3 671B). + // TBD: make the set of allowed extra dtypes configurable + // per optimizer so this function stays dtype-agnostic. if (key.first == d) { return key.second == s || s == at::ScalarType::Float || - s == at::ScalarType::Double; + s == at::ScalarType::Double || + s == at::ScalarType::BFloat16; } else if (d.is_cpu()) { // note(crcrpar): There are some test cases (e.g. // TestOptim::test_adam) where state_steps are on CPU and the @@ -333,7 +348,9 @@ inline FlatMap _group_tensors_by_first_tensors_device_and_dtype( } } }), - "Tensors of the same index must be on the same device and the same dtype except `step` tensors that can be CPU and float32/64 notwithstanding"); + "Tensors of the same index must be on the same device and the same dtype " + "except `step` tensors that can be CPU and float32/64, and optimizer " + "states that can be bfloat16 for mixed-precision training"); grouped_tensors_with_indices.try_emplace( key, TensorsAndIndicesT{ diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index b7b8424c8bbce..b5896575580f2 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -476,7 +476,7 @@ std::tuple get_atol_rtol( ? at::where(atol_opt.value() > 0, at::zeros({}, options), default_rtol) : std::move(default_rtol); } - return std::make_tuple(atol, rtol); + return std::make_tuple(std::move(atol), std::move(rtol)); } std::tuple get_atol_rtol( @@ -502,7 +502,7 @@ std::tuple get_atol_rtol( } auto atol_tensor = at::full({}, atol, options); auto rtol_tensor = at::full({}, rtol, options); - return std::make_tuple(atol_tensor, rtol_tensor); + return std::make_tuple(std::move(atol_tensor), std::move(rtol_tensor)); } } // anonymous namespace diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 257863573d3a8..7b5c0758ad9c1 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -336,7 +336,7 @@ inline std::tuple _linalg_broadcast_batch_dims(const Tensor& arg1 auto arg1_broadcasted = arg1_expand_size == arg1.sizes() ? arg1 : arg1.expand(arg1_expand_size); auto arg2_broadcasted = arg2_expand_size == arg2.sizes() ? arg2 : arg2.expand(arg2_expand_size); - return std::make_tuple(arg1_broadcasted, arg2_broadcasted); + return std::make_tuple(std::move(arg1_broadcasted), std::move(arg2_broadcasted)); } inline std::vector broadcast_batch_size(const Tensor& t1, const Tensor& t2, int64_t n_batch_dims) { @@ -351,6 +351,7 @@ inline Tensor _move_to_end(const Tensor& self, IntArrayRef axes) { const std::vector a = axes.vec(); const int64_t ndim = self.ndimension(); std::vector perm; + perm.reserve(static_cast(std::max(0, ndim))); for (const auto i : c10::irange(ndim)) { auto it = std::find(a.begin(), a.end(), i); @@ -405,7 +406,7 @@ inline std::tuple _compute_geometry_for_Q( n_columns_q = std::min(m, n); } auto q_strides = batched_matrix_contiguous_strides(q_sizes, /*f-contig*/true); - return std::make_tuple(q_sizes, q_strides, n_columns_q); + return std::make_tuple(std::move(q_sizes), std::move(q_strides), n_columns_q); } inline bool svd_uses_cusolver(const Tensor& A) { diff --git a/aten/src/ATen/native/LossCTC.cpp b/aten/src/ATen/native/LossCTC.cpp index ec4ce8d8550f4..7b777ec744413 100644 --- a/aten/src/ATen/native/LossCTC.cpp +++ b/aten/src/ATen/native/LossCTC.cpp @@ -472,8 +472,8 @@ Tensor ctc_loss_backward_tensor( Tensor ilc = input_lengths.to(Device(at::kCPU), at::kLong).contiguous(); Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous(); - IntArrayRef il(ilc.data_ptr(), ilc.numel()); - IntArrayRef tl(tlc.data_ptr(), tlc.numel()); + IntArrayRef il(ilc.const_data_ptr(), ilc.numel()); + IntArrayRef tl(tlc.const_data_ptr(), tlc.numel()); return at::_ctc_loss_backward(grad, log_probs, targets, il, tl, neg_log_likelihood, log_alpha, BLANK, zero_infinity); } diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index de0a40ae412d6..7d9e73779a4bd 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -2281,7 +2281,7 @@ inline C10_HOST_DEVICE T airy_ai_forward(T x) { int domain_flag = 0; - T ai; + T ai = T(0.0); if (std::isinf(x)) { return std::numeric_limits::quiet_NaN(); @@ -3267,7 +3267,7 @@ inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) { +8.04490411014108831608e-01, }; - T p; + T p = T{0}; T q = 0.0; if (std::abs(x) <= T(8.0)) { @@ -3355,7 +3355,7 @@ inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) { +7.78576235018280120474e-01, }; - T p; + T p = T{0}; T q = 0.0; if (std::abs(x) <= T(8.0)) { @@ -3440,7 +3440,7 @@ inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) { return std::numeric_limits::quiet_NaN(); } - T p; + T p = T{0}; T q = 0.0; if (x <= T(2.0)) { @@ -3518,7 +3518,7 @@ inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) { return std::numeric_limits::quiet_NaN(); } - T p; + T p = T{0}; T q = 0.0; if (x <= T(2.0)) { @@ -3595,7 +3595,7 @@ inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) { return std::numeric_limits::quiet_NaN(); } - T p; + T p = T{0}; T q = 0.0; if (x <= T(2.0)) { @@ -3673,7 +3673,7 @@ inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) { return std::numeric_limits::quiet_NaN(); } - T p; + T p = T{0}; T q = 0.0; if (x <= T(2.0)) { diff --git a/aten/src/ATen/native/NNPACK.cpp b/aten/src/ATen/native/NNPACK.cpp index 6869f331994e1..1179210ea1c6d 100644 --- a/aten/src/ATen/native/NNPACK.cpp +++ b/aten/src/ATen/native/NNPACK.cpp @@ -235,9 +235,9 @@ Tensor _nnpack_spatial_convolution( input_padding, kernel_size, output_subsample, - input_.data_ptr() + batch * input_size_per_batch, - weight_.data_ptr(), - bias_.data_ptr(), + input_.const_data_ptr() + batch * input_size_per_batch, + weight_.const_data_ptr(), + bias_.const_data_ptr(), output.data_ptr() + batch * output_size_per_batch, workspace.buffer, &workspace.size, @@ -262,9 +262,9 @@ Tensor _nnpack_spatial_convolution( input_size, input_padding, kernel_size, - input_.data_ptr(), - weight_.data_ptr(), - bias_.data_ptr(), + input_.const_data_ptr(), + weight_.const_data_ptr(), + bias_.const_data_ptr(), output.data_ptr(), workspace.buffer, &workspace.size, diff --git a/aten/src/ATen/native/NaiveDilatedConvolution.cpp b/aten/src/ATen/native/NaiveDilatedConvolution.cpp index bd8ada650a96b..1129afb44388b 100644 --- a/aten/src/ATen/native/NaiveDilatedConvolution.cpp +++ b/aten/src/ATen/native/NaiveDilatedConvolution.cpp @@ -397,7 +397,7 @@ void slow_conv_dilated_all_cpu_template( Tensor grad_input_n = grad_input.select(0, elt); col2hvol( - columns.data_ptr(), + columns.const_data_ptr(), nInputPlane, input_size, output_size, diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 22e2d4b75e950..71dab6aed8955 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -118,14 +118,14 @@ struct Var { } }; -static inline bool is_contiguous(const Tensor& t) { +static bool is_contiguous_in_any_format(const Tensor& t) { return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast) || t.is_contiguous(at::MemoryFormat::ChannelsLast3d); } // For some ambiguous cases, it is possible a channels last contiguous Tensor has // `suggest_memory_format` of Contiguous. // See https://github.com/pytorch/pytorch/issues/63224 for details. -static inline MemoryFormat suggest_memory_format_contig(const Tensor& t) { +static MemoryFormat suggest_memory_format_contig(const Tensor& t) { return t.is_contiguous() ? at::MemoryFormat::Contiguous : (t.is_contiguous(at::MemoryFormat::ChannelsLast3d) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast); @@ -138,8 +138,8 @@ static std::tuple batch_norm_cpu_transform_input_template( const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */, bool train, double eps, Tensor& output) { - bool all_contiguous = is_contiguous(input) - && is_contiguous(output) + bool all_contiguous = is_contiguous_in_any_format(input) + && is_contiguous_in_any_format(output) && (!weight.defined() || weight.is_contiguous()) && (!bias.defined() || bias.is_contiguous()) && running_mean.is_contiguous() @@ -207,7 +207,7 @@ static std::tuple batch_norm_cpu_update_stats_template( TORCH_CHECK(input.numel() != 0, "input tensor must have at least one element, but got input_sizes = ", input.sizes()); int64_t n = input.numel() / n_input; - bool all_contiguous = is_contiguous(input); + bool all_contiguous = is_contiguous_in_any_format(input); constexpr bool mixed_type = !std::is_same_v; // Using float data type for Half _var_sum in batchnorm stats updating on CPU // to avoid _var_sum overflow since the representation range of Half is small. @@ -300,7 +300,7 @@ static std::tuple batch_norm_cpu_update_stats_template( constexpr bool mixed_type = !std::is_same_v; const auto dtype = mixed_type ? kFloat : input.scalar_type(); - Tensor save_mean = is_contiguous(input) ? at::empty({n_input}, input.options().dtype(dtype)) : at::mean(input, /*dim=*/reduce_dims, /*keepdim=*/false, dtype); + Tensor save_mean = is_contiguous_in_any_format(input) ? at::empty({n_input}, input.options().dtype(dtype)) : at::mean(input, /*dim=*/reduce_dims, /*keepdim=*/false, dtype); Tensor save_var_transform = at::empty({n_input}, input.options().dtype(dtype)); return batch_norm_cpu_update_stats_template(input, running_mean, running_var, momentum, eps, save_mean, save_var_transform); } @@ -331,8 +331,8 @@ static std::tuple batch_norm_backward_cpu_template( // since we are directly manipulating pointers in contiguous path, // need to make sure input and grad_out have the same memory format. - bool all_contiguous = is_contiguous(input) - && is_contiguous(grad_out_) + bool all_contiguous = is_contiguous_in_any_format(input) + && is_contiguous_in_any_format(grad_out_) && input.suggest_memory_format() == grad_out_.suggest_memory_format(); if (all_contiguous) { @@ -863,7 +863,7 @@ std::tuple batch_norm_cpu(const Tensor& self, const std: checkBackend("batch_norm_cpu", {self, weight, bias, running_mean, running_var}, Backend::CPU); // Prepare output tensor - const bool all_contiguous = is_contiguous(self) + const bool all_contiguous = is_contiguous_in_any_format(self) && (!weight.defined() || weight.is_contiguous()) && (!bias.defined() || bias.is_contiguous()) && running_mean.is_contiguous() @@ -885,7 +885,7 @@ std::tuple batch_norm_cpu(const Tensor& self, const std: save_mean = at::empty({0}, self.options().dtype(kFloat)); save_var = at::empty({0}, self.options().dtype(kFloat)); } else { - save_mean = is_contiguous(self) ? at::empty({self.size(1)}, self.options().dtype(kFloat)) : at::mean(self, /*dim=*/reduce_dims, /*keepdim=*/false, kFloat); + save_mean = is_contiguous_in_any_format(self) ? at::empty({self.size(1)}, self.options().dtype(kFloat)) : at::mean(self, /*dim=*/reduce_dims, /*keepdim=*/false, kFloat); save_var = at::empty({self.size(1)}, self.options().dtype(kFloat)); } } else { @@ -893,7 +893,7 @@ std::tuple batch_norm_cpu(const Tensor& self, const std: save_mean = at::empty({0}, self.options()); save_var = at::empty({0}, self.options()); } else { - save_mean = is_contiguous(self) ? at::empty({self.size(1)}, self.options()) : at::mean(self, /*dim=*/reduce_dims, /*keepdim=*/false); + save_mean = is_contiguous_in_any_format(self) ? at::empty({self.size(1)}, self.options()) : at::mean(self, /*dim=*/reduce_dims, /*keepdim=*/false); save_var = at::empty({self.size(1)}, self.options()); } } diff --git a/aten/src/ATen/native/PackedSequence.cpp b/aten/src/ATen/native/PackedSequence.cpp index d757e8ba41c8d..bb93c0e6e05fb 100644 --- a/aten/src/ATen/native/PackedSequence.cpp +++ b/aten/src/ATen/native/PackedSequence.cpp @@ -38,7 +38,7 @@ std::tuple _pack_padded_sequence(const Tensor& _input, const Ten checkLongTensor(lengths_t); int64_t batch_size = input.size(1); - int64_t * lengths = lengths_t.data_ptr(); + const int64_t * lengths = lengths_t.const_data_ptr(); TORCH_CHECK(lengths_t.size(0) == batch_size, "Expected `len(lengths)` to be equal to batch_size, but got ", lengths_t.size(0), @@ -126,7 +126,7 @@ Tensor _pack_padded_sequence_backward_symint(const Tensor& grad, c10::SymIntArra // NOTE: this op advertises as CompositeImplicitAutograd, but uses data_ptr(). // we should fix this. auto max_seq_len = batch_sizes_t.size(0); - int64_t * batch_sizes = batch_sizes_t.data_ptr(); + const int64_t * batch_sizes = batch_sizes_t.const_data_ptr(); for (const auto i : c10::irange(max_seq_len)) { grad_input[i].slice(0, 0, batch_sizes[i]).copy_(grad.slice(0, offset, offset + batch_sizes[i])); offset += batch_sizes[i]; @@ -144,7 +144,7 @@ std::tuple _pad_packed_sequence(const Tensor& data, const Tensor checkLongTensor(batch_sizes_t); TORCH_CHECK(batch_sizes_t.numel() > 0, "batch_sizes can not be empty"); - int64_t * batch_sizes = batch_sizes_t.data_ptr(); + const int64_t * batch_sizes = batch_sizes_t.const_data_ptr(); int64_t max_batch_size = batch_sizes[0]; int64_t max_real_seq_length = batch_sizes_t.size(0); int64_t max_seq_length = max_real_seq_length; diff --git a/aten/src/ATen/native/README.md b/aten/src/ATen/native/README.md index ab836593dc297..87512d6058a4b 100644 --- a/aten/src/ATen/native/README.md +++ b/aten/src/ATen/native/README.md @@ -5,7 +5,7 @@ in one of the `cpp` files in this directory. Like all ATen methods/functions, native functions are made available from both ATen's C++ and Python APIs. In C++, they are made available -either as methods on `Tensor` (`t.mymeth()`) and functions in the ATen +either as methods on `Tensor` (`t.mymeth()`) or as functions in the ATen namespace (`at::myfunc()`). In PyTorch, they are made available as methods on `Variable` or as functions on `torch._C._FunctionBase`. (It is the user's responsibility to re-export these functions in diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index 1b465790d306c..80c466a6b5815 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -104,8 +104,9 @@ bool use_mkldnn(const Tensor& input, TensorList params, TensorList hx) { (input.scalar_type() == kHalf && !at::GradMode::is_enabled() && mkldnn_fp16_device_check())) && input.numel() != 0; -#endif +#else return false; +#endif } bool use_cudnn(const Tensor& t) { diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 68dadf218945f..d5dc9ba1eeb46 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -117,6 +117,7 @@ #include #include #include +#include #include #include #endif @@ -645,7 +646,12 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co // O(n) implementation. The derivative of this implementation is _not_ // the second derivative of cumprod. As such, we fallback to a less efficient // O(n^2) implementation when at::GradMode::is_enabled(). - if (!at::GradMode::is_enabled() && !are_inputs_tensors_sublcass) { + // + // NOTE: We use at::where instead of masked_scatter_/masked_select to make + // this path composite compliant for tensor subclasses (e.g., FakeTensors + // used by torch.compile). masked_select has dynamic output shape which + // causes issues with tracing. See https://github.com/pytorch/pytorch/issues/136263 + if (!at::GradMode::is_enabled()) { // n.b. This could probably be implemented much faster with a kernel // From here on we need to use some mask gymnastics to @@ -669,9 +675,16 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co // case k < z1 // select everything before the first zero [0, z1) auto mask = cumsum == 0; - // equiv to grad_input[mask] = deriv[grad] - grad_input.masked_scatter_(mask, - reversed_cumsum(w.masked_fill(~mask, 0.), dim).div_(input_conj).masked_select(mask)); + // Compute gradient for positions before the first zero + // Using at::where instead of masked_scatter_ for composite compliance + auto grad_before_first_zero = reversed_cumsum(w.masked_fill(~mask, 0.), dim); + if (!are_inputs_tensors_sublcass) { + grad_before_first_zero = grad_before_first_zero.div_(input_conj); + } else { + grad_before_first_zero = grad_before_first_zero.div(input_conj); + } + grad_input = at::where(mask, grad_before_first_zero, grad_input); + // select everything from the first zero to the second zero [z1, z2) mask = cumsum == 1; @@ -693,13 +706,21 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co // dy_j / dx_z1 = sum(cumprod(input[z1+1:z2] * grad[z1+1:z2])) * prod(output[z1-1]) // relu_() necessary as gather does not support negative indices // finally, we do grad_input[z1] = dy_j / dx_z1 - grad_input.masked_scatter_(first_zero_mask, - input_conj.masked_fill(~mask, 1.).cumprod(dim) - .mul_(grad.masked_fill(cumsum != 1, 0.)) - .sum(dim, /*keepdim*/true) - .mul_(at::gather(output_conj, dim, (first_zero_index - 1).relu_()) - .masked_fill_(first_zero_index == 0, 1.)) - .masked_select(first_zero_mask)); + // Using at::where instead of masked_scatter_ for composite compliance + auto grad_at_first_zero = input_conj.masked_fill(~mask, 1.).cumprod(dim); + const auto grad_masked = grad.masked_fill(cumsum != 1, 0.); + const auto output_before_zero = at::gather(output_conj, dim, (first_zero_index - 1).relu_()) + .masked_fill_(first_zero_index == 0, 1.); + if (!are_inputs_tensors_sublcass) { + grad_at_first_zero = grad_at_first_zero.mul_(grad_masked) + .sum(dim, /*keepdim*/true) + .mul_(output_before_zero); + } else { + grad_at_first_zero = grad_at_first_zero.mul(grad_masked) + .sum(dim, /*keepdim*/true) + .mul(output_before_zero); + } + grad_input = at::where(first_zero_mask, grad_at_first_zero, grad_input); return grad_input; } else { // GradMode::enabled() /* @@ -1864,8 +1885,8 @@ static Tensor& std_var_out( const char* fname, Tensor& result, const Tensor& self, at::OptionalIntArrayRef dim, const std::optional& correction_opt, bool keepdim, bool take_sqrt) { - TORCH_CHECK(self.device().is_cpu() || self.device().is_cuda() || self.device().is_xpu(), - "std and var supports tensors on a CPU, CUDA, or XPU device only, but got: ", + TORCH_CHECK(self.device().is_cpu() || self.device().is_cuda() || self.device().is_xpu() || self.device().is_privateuseone(), + "std and var supports tensors on a CPU, CUDA, XPU or PrivateUse1 device only, but got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "std and var only supports strided layout, got: ", self.layout()); @@ -1937,8 +1958,8 @@ static std::tuple std_var_mean_out( at::OptionalIntArrayRef dim, const std::optional& correction_opt, bool keepdim, bool take_sqrt) { AT_ASSERT(result1.defined() && result2.defined()); - TORCH_CHECK(self.device().is_cpu() || self.is_cuda() || self.is_xpu(), - fname, " supports tensors on a CPU, CUDA, or XPU device only, got: ", + TORCH_CHECK(self.device().is_cpu() || self.is_cuda() || self.is_xpu() || self.is_privateuseone(), + fname, " supports tensors on a CPU, CUDA, XPU or PrivateUse1 device only, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, fname, " only supports strided layout, got: ", self.layout()); @@ -2351,10 +2372,14 @@ bool cpu_equal(const Tensor& self, const Tensor& other) { Tensor value_selecting_reduction_backward_symint(const Tensor& grad, int64_t dim, const Tensor& indices, c10::SymIntArrayRef sizes, bool keepdim) { auto inplace_scatter_if_not_tensor_subclass = [&](const Tensor& grad_out, const Tensor& indices_) { - auto grad_in = at::zeros_symint(sizes, grad_out.options()); if (areAnyTensorSubclassLike({grad, indices})) { + // Use new_zeros_symint so that tensor subclasses (e.g. DTensor) + // can intercept the zeros creation through dispatch, ensuring + // the result has matching subclass type for subsequent scatter. + auto grad_in = grad_out.new_zeros_symint(sizes); return grad_in.scatter(dim, indices_, grad_out); } + auto grad_in = at::zeros_symint(sizes, grad_out.options()); return grad_in.scatter_(dim, indices_, grad_out); }; diff --git a/aten/src/ATen/native/Repeat.cpp b/aten/src/ATen/native/Repeat.cpp index 11b528b445ed4..62c66e2be1bea 100644 --- a/aten/src/ATen/native/Repeat.cpp +++ b/aten/src/ATen/native/Repeat.cpp @@ -105,7 +105,7 @@ Tensor repeat_interleave_symint( std::optional output_size) { Tensor input = dim_opt ? self : self.flatten(); int64_t dim = c10::maybe_wrap_dim(dim_opt.value_or(0), self.dim()); - TORCH_CHECK(repeats >= 0, "Repeats must be non-negative"); + TORCH_SYM_CHECK(repeats.sym_ge(0), "Repeats must be non-negative"); input = input.unsqueeze(dim + 1); auto expand_shape = input.sym_sizes().vec(); @@ -115,9 +115,13 @@ Tensor repeat_interleave_symint( // This argument doesn't really make sense for the scalar overload, but exists // for consistency with the tensor overload if (output_size) { - auto calculated_size = (repeats * expand_shape[dim]).guard_int(__FILE__, __LINE__); - TORCH_CHECK(*output_size == calculated_size, "repeat_interleave: Invalid output_size, expected ", - calculated_size, " but got ", *output_size); + auto calculated_size = repeats * expand_shape[dim]; + TORCH_SYM_CHECK( + output_size->sym_eq(calculated_size), + "repeat_interleave: Invalid output_size, expected ", + calculated_size, + " but got ", + *output_size); } return input.clone(at::MemoryFormat::Contiguous).flatten(dim, dim + 1); diff --git a/aten/src/ATen/native/Resize.h b/aten/src/ATen/native/Resize.h index 3346cd2cb220e..d4bc55e30e701 100644 --- a/aten/src/ATen/native/Resize.h +++ b/aten/src/ATen/native/Resize.h @@ -181,24 +181,76 @@ inline void checkSetStorage(Tensor& result, Storage storage, T storage_offset, * (size, stride, storage_offset) must be in bounds for self's storage. */ template -inline void setStrided( - const Tensor& self, +void checkAsStridedArgs( ArrayRef size, ArrayRef stride, T storage_offset) { - TORCH_CHECK(size.size() == stride.size(), "mismatch in length of strides and shape"); + TORCH_CHECK( + size.size() == stride.size(), "mismatch in length of strides and shape"); for (const auto& val : stride) { - TORCH_CHECK(val >= 0, - "as_strided: Negative strides are not supported at the moment, " - "got strides: ", stride); + TORCH_CHECK( + val >= 0, + "as_strided: Negative strides are not supported at the moment, " + "got strides: ", + stride); } + TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset); +} + +template +void checkAsStridedArgsAllowUnbackedSymInts( + ArrayRef size, + ArrayRef stride, + T storage_offset) { + TORCH_CHECK( + size.size() == stride.size(), "mismatch in length of strides and shape"); + if constexpr (std::is_same_v) { + // FakeTensor/Meta view replay can pass ephemeral symbolic metadata here, + // so only validate values once the SymInts become concrete. + for (const auto& val : stride) { + if (auto maybe_val = val.maybe_as_int()) { + TORCH_CHECK( + *maybe_val >= 0, + "as_strided: Negative strides are not supported at the moment, " + "got strides: ", + stride); + } + } + + if (auto maybe_storage_offset = storage_offset.maybe_as_int()) { + TORCH_CHECK( + *maybe_storage_offset >= 0, + "Tensor: invalid storage offset ", + storage_offset); + } + } else { + for (const auto& val : stride) { + TORCH_CHECK( + val >= 0, + "as_strided: Negative strides are not supported at the moment, " + "got strides: ", + stride); + } + + TORCH_CHECK( + storage_offset >= 0, + "Tensor: invalid storage offset ", + storage_offset); + } +} + +template +inline void setStrided( + const Tensor& self, + ArrayRef size, + ArrayRef stride, + T storage_offset) { + checkAsStridedArgs(size, stride, storage_offset); auto* self_ = self.unsafeGetTensorImpl(); checkInBoundsForStorage( size, stride, storage_offset, self_->dtype(), self_->storage()); - /* storage offset */ - TORCH_CHECK(storage_offset >= 0, "Tensor: invalid storage offset ", storage_offset); self_->set_sizes_and_strides(size, stride, storage_offset); } diff --git a/aten/src/ATen/native/RowwisePrune.cpp b/aten/src/ATen/native/RowwisePrune.cpp index ec4e5fff10a01..ee3854c49e755 100644 --- a/aten/src/ATen/native/RowwisePrune.cpp +++ b/aten/src/ATen/native/RowwisePrune.cpp @@ -23,7 +23,7 @@ std::tuple _rowwise_prune_helper( ScalarType compressed_indices_dtype) { int num_non_masked_rows = 0; auto mask_contig = mask.contiguous(); - auto mask_data = mask_contig.data_ptr(); + auto mask_data = mask_contig.const_data_ptr(); for (const auto i : c10::irange(mask.numel())) { num_non_masked_rows += ((mask_data[i] == true) ? 1 : 0); } @@ -39,7 +39,7 @@ std::tuple _rowwise_prune_helper( auto* pruned_2d_tensor_data = pruned_2d_tensor.data_ptr(); auto compressed_indices_mapping_data = compressed_indices_mapping.data_ptr(); - auto weights_data = weights.data_ptr(); + auto weights_data = weights.const_data_ptr(); int last_row_kept = 0; for (const auto i : c10::irange(mask.numel())) { if (mask_data[i]) { diff --git a/aten/src/ATen/native/ScaledBlas.cpp b/aten/src/ATen/native/ScaledBlas.cpp index d58593bfe018f..0b384d7226b73 100644 --- a/aten/src/ATen/native/ScaledBlas.cpp +++ b/aten/src/ATen/native/ScaledBlas.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #if !defined(__s390x__) && !defined(__powerpc__) #include #endif @@ -32,12 +33,102 @@ #include #include #include +#include #include #include #endif namespace at::native { +using at::blas::ScalingType; +using at::blas::SwizzleType; + +namespace { + +void invalid_scaling_config( + const at::Tensor& mat_a, + const at::Tensor& mat_b, + const std::optional& scale_a, + const std::optional& scale_b) { + std::stringstream exception_ss; + exception_ss << "Invalid scaling configuration.\n"; + exception_ss << "- For TensorWise scaling, a and b should be float8, scales should be float and singletons.\n"; + exception_ss << "- For RowWise scaling, a and b should be float8, scales should be float, scale_a should be ("; + exception_ss << mat_a.size(0) << ", 1) and scale_b should be (1, " << mat_b.size(1) << "), and both should be contiguous.\n"; + exception_ss << "Got mat_a.dtype()=" << mat_a.scalar_type(); + if (scale_a.has_value()) { + exception_ss << ", scale_a.dtype()=" << scale_a.value().scalar_type() << ", scale_a.size()=" << scale_a.value().sizes() << ", scale_a.stride()=" << scale_a.value().strides(); + } + else { + exception_ss << ", scale_a=None"; + } + exception_ss << ", mat_b.dtype()=" << mat_b.scalar_type(); + if (scale_b.has_value()) { + exception_ss << ", scale_b.dtype()=" << scale_b.value().scalar_type() << ", scale_b.size()=" << scale_b.value().sizes() << " and scale_b.stride()=" << scale_b.value().strides(); + } + else { + exception_ss << " and scale_b=None"; + } + + TORCH_CHECK_VALUE(false, exception_ss.str()); +} + +/* + * Scaling Type Determination: + * --------------------------- + * Conditions and corresponding Scaling Types: + * + * - If scale.numel() == 1: + * - Returns TensorWise. + * + * - Else if scale.dim() == 2 && scale.size(0) == outer_dim && scale.size(1) == + * 1: + * - Returns RowWise. + * + * - Otherwise: + * - Returns Error. + */ + +bool is_tensorwise_scaling(const at::Tensor& t, const at::Tensor& scale) { + return at::isFloat8Type(t.scalar_type()) && + scale.scalar_type() == at::kFloat && scale.numel() == 1; +} + +bool is_rowwise_scaling(const at::Tensor& t, const at::Tensor& scale) { + return ( + at::isFloat8Type(t.scalar_type()) && scale.scalar_type() == at::kFloat && + scale.dim() == 2 && scale.size(0) == t.size(0) && scale.size(1) == 1 && + scale.is_contiguous()); +} + +bool is_desired_scaling( + const at::Tensor& t, + const at::Tensor& scale, + ScalingType desired_scaling) { + auto result = desired_scaling == ScalingType::TensorWise + ? is_tensorwise_scaling(t, scale) + : is_rowwise_scaling(t, scale); + return result; +} + +std::pair get_joint_scaling( + std::initializer_list> options, + const at::Tensor& mat_a, + const at::Tensor& mat_b, + const at::Tensor& scale_a, + const at::Tensor& scale_b) { + for (auto [lhs, rhs] : options) { + if (is_desired_scaling(mat_a, scale_a, lhs) && + is_desired_scaling(mat_b.t(), scale_b.t(), rhs)) { + return {lhs, rhs}; + } + } + invalid_scaling_config(mat_a, mat_b, scale_a, scale_b); + return {}; +} + +} // namespace + static Tensor& _scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2, const Tensor& scale_a, @@ -46,14 +137,70 @@ _scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2, const std::optional& scale_result, std::optional out_dtype, bool use_fast_accum, + const scaled::ScaledGemmImplementation gemm_impl, Tensor& out) { TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); TORCH_CHECK( - mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", + mat1.sizes()[1] == mat2.sizes()[0], "mat_a and mat_b shapes cannot be multiplied (", mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); - TORCH_INTERNAL_ASSERT((scale_a.numel() == 1 && scale_b.numel() == 1), "Now _scaled_mm only supports per-tensor scaling for CPU backend."); + TORCH_CHECK_VALUE(gemm_impl == scaled::ScaledGemmImplementation::TENSORWISE_TENSORWISE || + gemm_impl == scaled::ScaledGemmImplementation::ROWWISE_ROWWISE, "Unsupported scaling implementation"); + + if (gemm_impl == scaled::ScaledGemmImplementation::TENSORWISE_TENSORWISE) { + // Restrictions: + // A, B are FP8, scales are fp32 + TORCH_CHECK_VALUE( + isFloat8Type(mat1.scalar_type()) && isFloat8Type(mat2.scalar_type()), + "mat1 and mat2 must be fp8 types, got: ", + mat1.scalar_type(), + mat2.scalar_type()); + TORCH_CHECK_VALUE( + scale_a.numel() == 1 && scale_a.scalar_type() == kFloat, + "scale_a must have 1 Float element"); + TORCH_CHECK_VALUE( + scale_b.numel() == 1 && scale_b.scalar_type() == kFloat, + "scale_b must have 1 Float element"); + } + else { + // Restrictions: + // A, B are FP8, scales are fp32, shape M/N for A/B + TORCH_CHECK_VALUE( + isFloat8Type(mat1.scalar_type()) && isFloat8Type(mat2.scalar_type()), + "mat1 and mat_b must be fp8 types, got: ", + mat1.scalar_type(), + mat2.scalar_type()); + TORCH_CHECK_VALUE( + scale_a.size(0) == mat1.size(0) && scale_a.size(1) == 1, + "scale_a must have shape [", + mat1.size(0), + ", 1], got [", + scale_a.sizes(), + "]"); + TORCH_CHECK_VALUE( + scale_a.numel() == mat1.size(0) && scale_a.scalar_type() == kFloat, + "scale_a must have ", + mat1.size(0), + " Float elements, got ", + scale_a.numel()); + TORCH_CHECK_VALUE( + scale_b.numel() == mat2.size(1) && scale_b.scalar_type() == kFloat, + "scale_b must have ", + mat2.size(1), + " Float elements, got ", + scale_b.numel()); + + TORCH_CHECK_VALUE( + scale_a.stride(1) == 1, + "expected scale_a.stride(1) to be 1, but got ", + scale_a.stride(1)); + TORCH_CHECK_VALUE( + scale_b.stride(1) == 1, + "expected scale_b.stride(1) to be 1, but got ", + scale_b.stride(1)); + } + TORCH_CHECK( !scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), @@ -62,32 +209,41 @@ _scaled_mm_out_cpu_emulated(const Tensor& mat1, const Tensor& mat2, " but got ", bias->numel()); // Check types - TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type"); TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type()); TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type()); - auto mat1_c = mat1.contiguous(); - auto mat2_c = mat2.contiguous(); - IntArrayRef mat1_sizes = mat1_c.sizes(); - IntArrayRef mat2_sizes = mat2_c.sizes(); + auto mat1_cont = mat1.contiguous(); + auto mat2_cont = mat2.contiguous(); + IntArrayRef mat1_sizes = mat1_cont.sizes(); + IntArrayRef mat2_sizes = mat2_cont.sizes(); at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); - float input_scale = scale_a.item(); - float weight_scale = scale_b.item(); + const auto out_dtype_ = out_dtype.value_or(c10::ScalarType::BFloat16); + float output_scale = 1.0f; if (scale_result.has_value() && - (*out_dtype == ScalarType::Float8_e4m3fn || - *out_dtype == ScalarType::Float8_e5m2)) { + (out_dtype_ == ScalarType::Float8_e4m3fn || + out_dtype_ == ScalarType::Float8_e5m2)) { output_scale = scale_result.value().item(); } - auto fp32_mat1 = at::mul(mat1.to(kFloat), input_scale); - auto fp32_mat2 = at::mul(mat2_c.to(kFloat), weight_scale); - auto out_tmp = at::matmul(fp32_mat1, fp32_mat2); + + at::Tensor fp32_mat_a; + at::Tensor fp32_mat_b; + if (gemm_impl == scaled::ScaledGemmImplementation::TENSORWISE_TENSORWISE) { + fp32_mat_a = at::mul(mat1_cont.to(kFloat), scale_a.item()); + fp32_mat_b = at::mul(mat2_cont.to(kFloat), scale_b.item()); + } + else { + fp32_mat_a = at::mul(mat1_cont.to(kFloat), scale_a); + fp32_mat_b = at::mul(mat2_cont.to(kFloat), scale_b); + } + + auto out_tmp = at::matmul(fp32_mat_a, fp32_mat_b); if (bias) { out_tmp.add_(bias.value()); } - if (*out_dtype == ScalarType::Float8_e4m3fn || - *out_dtype == ScalarType::Float8_e5m2) { + if (out_dtype_ == ScalarType::Float8_e4m3fn || + out_dtype_ == ScalarType::Float8_e5m2) { out_tmp = at::mul(out_tmp, 1 / output_scale); } out_tmp = out_tmp.to(out.scalar_type()); @@ -105,7 +261,7 @@ _scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2, bool use_fast_accum, Tensor& out) { #if AT_MKLDNN_ENABLED() && !defined(__powerpc__) - if (at::globalContext().userEnabledMkldnn()) { + if (at::globalContext().userEnabledMkldnn() && scale_a.numel() == 1 && scale_b.numel() == 1) { bool mixed_dtype = mat1.scalar_type() != mat2.scalar_type(); if ((!mixed_dtype && cpuinfo_has_x86_amx_int8()) || (mixed_dtype && cpuinfo_has_x86_amx_fp16())) { @@ -123,7 +279,30 @@ _scaled_mm_out_cpu(const Tensor& mat1, const Tensor& mat2, } #endif { - return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); + TORCH_CHECK_VALUE( + !out_dtype || *out_dtype == out.scalar_type(), + "out_dtype must match output matrix type"); + + const auto [scaling_choice_a, scaling_choice_b] = get_joint_scaling( + { + std::make_pair(ScalingType::TensorWise, ScalingType::TensorWise), + std::make_pair(ScalingType::RowWise, ScalingType::RowWise), + }, + mat1, + mat2, + scale_a, + scale_b); + + scaled::ScaledGemmImplementation gemm_impl{scaled::ScaledGemmImplementation::NONE}; + + if (scaling_choice_a == ScalingType::TensorWise && scaling_choice_b == ScalingType::TensorWise) { + gemm_impl = scaled::ScaledGemmImplementation::TENSORWISE_TENSORWISE; + } + else if (scaling_choice_a == ScalingType::RowWise && scaling_choice_b == ScalingType::RowWise) { + gemm_impl = scaled::ScaledGemmImplementation::ROWWISE_ROWWISE; + } + + return _scaled_mm_out_cpu_emulated(mat1, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, gemm_impl, out); } } @@ -140,6 +319,201 @@ _scaled_mm_cpu(const Tensor& mat_a, const Tensor& mat_b, return _scaled_mm_out_cpu(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); } +using acceptance_fn = std::function&, + ArrayRef&, + c10::ScalarType, + std::vector&, + ArrayRef&)>; + +namespace scaled_blas = at::native::scaled; +using scaled_blas::convert_int_to_enum; +using scaled_blas::ScaledGemmImplementation; + +std::array, 2> + scale_kernel_dispatch = {{ + {"tensorwise_tensorwise", + scaled_blas::check_tensorwise_recipe, + ScaledGemmImplementation::TENSORWISE_TENSORWISE}, + {"rowwise_rowwise", + scaled_blas::check_rowwise_recipe, + ScaledGemmImplementation::ROWWISE_ROWWISE}, + + }}; + +Tensor& _scaled_mm_cpu_v2_out( + const Tensor& mat_a, + const Tensor& mat_b, + ArrayRef scale_a, + IntArrayRef scale_recipe_a, + IntArrayRef swizzle_a, + ArrayRef scale_b, + IntArrayRef scale_recipe_b, + IntArrayRef swizzle_b, + const std::optional& bias, + const std::optional out_dtype, + IntArrayRef contraction_dim, + bool use_fast_accum, + Tensor& out) { + TORCH_CHECK_VALUE(mat_a.dim() == 2, "mat_a must be a matrix"); + TORCH_CHECK_VALUE(mat_b.dim() == 2, "mat_b must be a matrix"); + + // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm kernels + // do not support this case). + if (mat_a.size(0) == 0 || mat_a.size(1) == 0 || mat_b.size(1) == 0) { + // `out` was created with `at::empty`. In the case where we are multiplying + // MxK by KxN and K is the zero dim, we need to initialize here to properly + // return a tensor of zeros. + at::native::resize_output(out, {mat_a.size(0), mat_b.size(1)}); + if (mat_a.size(1) == 0) { + out.zero_(); + } + + return out; + } + + // Check if the input matrix sizes can be multiplied + // - if optional contraction dims are provided, use those + // -- mostly for < 1B formats (i.e. nvfp4x2) where cheap .t() is not available. + if (contraction_dim.size() > 0) { + TORCH_CHECK_VALUE(contraction_dim.size() == 2, "contraction_dim must have exactly 2 elements"); + auto mat_a_dim = contraction_dim[0]; + auto mat_b_dim = contraction_dim[1]; + TORCH_CHECK_VALUE( + mat_a.size(mat_a_dim) == mat_b.size(mat_b_dim), "mat_a and mat_b shapes cannot be multiplied (", + mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ") ", + "with contraction dims mat_a: ", mat_a_dim, ", mat_b: ", mat_b_dim); + } else { + TORCH_CHECK_VALUE( + mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied (", + mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")"); + } + + TORCH_CHECK_VALUE( + !bias || bias->numel() == mat_b.sizes()[1], + "Bias must be size ", + mat_b.sizes()[1], + " but got ", + bias->numel()); + + TORCH_CHECK_VALUE( + !out_dtype || *out_dtype == out.scalar_type(), + "out_dtype must match output matrix type"); + + if (bias) { + TORCH_CHECK_VALUE( + bias->scalar_type() == kFloat || + bias->scalar_type() == c10::ScalarType::BFloat16 || + bias->scalar_type() == c10::ScalarType::Half, + "Bias must be Float32 or BFloat16 or Half, but got ", + bias->scalar_type()); + } + + // Align with CUDA's default out to be bf16 + const auto out_dtype_ = out_dtype.value_or(c10::ScalarType::BFloat16); + + // Conversion of implicitly-defined enums to explicit + auto scale_recipe_a_enum = convert_int_to_enum(scale_recipe_a); + auto swizzle_a_enum = convert_int_to_enum(swizzle_a); + auto scale_recipe_b_enum = convert_int_to_enum(scale_recipe_b); + auto swizzle_b_enum = convert_int_to_enum(swizzle_b); + + if (!swizzle_a_enum.empty() && !swizzle_b_enum.empty()) { + TORCH_CHECK_VALUE( + swizzle_a_enum[0] == at::blas::SwizzleType::NO_SWIZZLE && + swizzle_b_enum[0] == at::blas::SwizzleType::NO_SWIZZLE, + "CPU does not support swizzle."); + } + + // at this point we can start working out what we want to be doing + // Try to do as few steps as possible. + // NOTE: support is deliberately sparse, can explicitly enumerate all + // combinations allowed. Do this via a list of defined (name, acceptance, + // concrete_impl) tuples. + bool found_impl = false; + ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE; + + for (const auto& fn_entry : scale_kernel_dispatch) { + const auto [name, accept_fn, scaled_gemm_impl] = fn_entry; + const bool ok = accept_fn( + mat_a.scalar_type(), + scale_recipe_a_enum, + scale_a, + mat_b.scalar_type(), + scale_recipe_b_enum, + scale_b); + + if (ok) { + gemm_impl = scaled_gemm_impl; + found_impl = true; + break; + } + } + + if (!found_impl) { + const std::optional scale_a_opt = scale_a.empty() ? std::optional{std::nullopt} : std::optional{scale_a[0]}; + const std::optional scale_b_opt = scale_b.empty() ? std::optional{std::nullopt} : std::optional{scale_b[0]}; + + invalid_scaling_config(mat_a, mat_b, scale_a_opt, scale_b_opt); + } + + at::native::resize_output(out, {mat_a.size(0), mat_b.size(1)}); + + auto bias_ = bias.value_or(Tensor()); + + if (gemm_impl == ScaledGemmImplementation::TENSORWISE_TENSORWISE || + gemm_impl == ScaledGemmImplementation::ROWWISE_ROWWISE) { + _scaled_mm_out_cpu_emulated( + mat_a, + mat_b, + scale_a[0], + scale_b[0], + bias, + std::nullopt, // scale-result + out_dtype_, + use_fast_accum, + gemm_impl, + out); + } else { + TORCH_CHECK_VALUE(false, "Invalid state - found an implementation, but not really"); + } + + return out; +} + +Tensor _scaled_mm_cpu_v2( + const Tensor& mat_a, + const Tensor& mat_b, + ArrayRef scale_a, + IntArrayRef scale_recipe_a, + IntArrayRef swizzle_a, + ArrayRef scale_b, + IntArrayRef scale_recipe_b, + IntArrayRef swizzle_b, + const std::optional& bias, + const std::optional out_dtype, + IntArrayRef contraction_dim, + bool use_fast_accum) { + const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); + Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_)); + + return _scaled_mm_cpu_v2_out( + mat_a, + mat_b, + scale_a, + scale_recipe_a, + swizzle_a, + scale_b, + scale_recipe_b, + swizzle_b, + bias, + out_dtype, + contraction_dim, + use_fast_accum, + out); +} + // TODO(vasiliy, future PR): figure out why we need to declare this function, when // other functions that live in ATen/native/*.cpp without declarations // or headers work just fine. diff --git a/aten/src/ATen/native/SobolEngineOps.cpp b/aten/src/ATen/native/SobolEngineOps.cpp index c061adc475859..2626d16018c3c 100644 --- a/aten/src/ATen/native/SobolEngineOps.cpp +++ b/aten/src/ATen/native/SobolEngineOps.cpp @@ -41,7 +41,7 @@ std::tuple _sobol_engine_draw(const Tensor& quasi, int64_t n, co // We deal with `data` and `strides` due to performance issues. int64_t l; int64_t* wquasi_data = wquasi.data_ptr(); - int64_t* sobolstate_data = sobolstate.data_ptr(); + const int64_t* sobolstate_data = sobolstate.const_data_ptr(); scalar_t* result_data = result.data_ptr(); int64_t wquasi_stride = wquasi.stride(0); @@ -74,7 +74,7 @@ Tensor& _sobol_engine_ff_(Tensor& quasi, int64_t n, const Tensor& sobolstate, // We deal with `data` and `strides` due to performance issues. int64_t* quasi_data = quasi.data_ptr(); - int64_t* sobolstate_data = sobolstate.data_ptr(); + const int64_t* sobolstate_data = sobolstate.const_data_ptr(); int64_t quasi_stride = quasi.stride(0); int64_t sobolstate_row_stride = sobolstate.stride(0), sobolstate_col_stride = sobolstate.stride(1); diff --git a/aten/src/ATen/native/SoftMax.cpp b/aten/src/ATen/native/SoftMax.cpp index fa83f3b6122fc..e3f2013eaf63d 100644 --- a/aten/src/ATen/native/SoftMax.cpp +++ b/aten/src/ATen/native/SoftMax.cpp @@ -154,7 +154,7 @@ void host_softmax( Tensor& output, const Tensor& input, const int64_t dim, - bool* mask, + const bool* mask, const std::optional mask_type_) { TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined"); @@ -173,9 +173,9 @@ void host_softmax( } int64_t dim_stride = inner_size; int64_t outer_stride = dim_size * dim_stride; - scalar_t* input_data_base = input.data_ptr(); + const scalar_t* input_data_base = input.const_data_ptr(); scalar_t* output_data_base = output.data_ptr(); - bool* mask_data_base = mask; + const bool* mask_data_base = mask; int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, static_cast(1)); parallel_for( 0, outer_size * inner_size, grain_size, @@ -183,7 +183,7 @@ void host_softmax( for (const auto i : c10::irange(begin, end)) { int64_t outer_idx = i / inner_size; int64_t inner_idx = i % inner_size; - scalar_t* input_data = + const scalar_t* input_data = input_data_base + outer_idx * outer_stride + inner_idx; scalar_t* output_data = output_data_base + outer_idx * outer_stride + inner_idx; @@ -201,7 +201,7 @@ void host_softmax( mask_outer_idx = outer_idx / (input.size(1) * input.size(2)); } - bool* mask_data = mask_data_base + mask_outer_idx * outer_stride + inner_idx; + const bool* mask_data = mask_data_base + mask_outer_idx * outer_stride + inner_idx; // Calc max in softmax dim bool is_meaningful_max = false; @@ -248,7 +248,7 @@ void host_softmax_backward( const Tensor& grad, const Tensor& output, int64_t dim, - bool* mask = nullptr) { + const bool* mask = nullptr) { int64_t outer_size = 1; int64_t dim_size = grad.size(dim); @@ -262,9 +262,9 @@ void host_softmax_backward( int64_t dim_stride = inner_size; int64_t outer_stride = dim_size * dim_stride; scalar_t* gradInput_data_base = gI.data_ptr(); - scalar_t* output_data_base = output.data_ptr(); - scalar_t* gradOutput_data_base = grad.data_ptr(); - bool* mask_data_base = mask; + const scalar_t* output_data_base = output.const_data_ptr(); + const scalar_t* gradOutput_data_base = grad.const_data_ptr(); + const bool* mask_data_base = mask; int64_t grain_size = std::min(internal::GRAIN_SIZE / dim_size, static_cast(1)); parallel_for( 0, outer_size * inner_size, grain_size, [&](int64_t begin, int64_t end) { @@ -273,11 +273,11 @@ void host_softmax_backward( int64_t inner_idx = i % inner_size; scalar_t* gradInput_data = gradInput_data_base + outer_idx * outer_stride + inner_idx; - scalar_t* output_data = + const scalar_t* output_data = output_data_base + outer_idx * outer_stride + inner_idx; const scalar_t* gradOutput_data = gradOutput_data_base + outer_idx * outer_stride + inner_idx; - bool* mask_data = mask_data_base + outer_idx * outer_stride + inner_idx; + const bool* mask_data = mask_data_base + outer_idx * outer_stride + inner_idx; acc_type sum = 0; for (const auto d : c10::irange(dim_size)) { @@ -590,7 +590,7 @@ Tensor masked_softmax_cpu(const Tensor& input_, const Tensor& mask_, const std:: AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "masked_softmax", [&] { host_softmax( - output, input, dim, mask.data_ptr(), mask_type); + output, input, dim, mask.const_data_ptr(), mask_type); }); return output; } @@ -619,7 +619,7 @@ Tensor masked_softmax_backward_cpu( Tensor grad_input = at::empty_like(grad, grad.options()); AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::BFloat16, at::ScalarType::Half, grad.scalar_type(), "masked_softmax_backward", [&] { - host_softmax_backward(grad_input, grad, output, dim, mask.data_ptr()); + host_softmax_backward(grad_input, grad, output, dim, mask.const_data_ptr()); }); return grad_input; } diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index eb3c2c93ec8a0..9c1d6c4c4e0bc 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ b/aten/src/ATen/native/Sorting.cpp @@ -206,7 +206,7 @@ QUANTILE_INTERPOLATION_MODE get_quantile_interpolation_mode( } void quantile_checks(const Tensor& self, const Tensor& q) { - TORCH_CHECK(self.numel() > 0, "quantile() input tensor must be non-empty"); + TORCH_SYM_CHECK(self.sym_numel().sym_gt(0), "quantile() input tensor must be non-empty"); TORCH_CHECK(q.dim() <= 1, "quantile() q must be a scalar or 1D tensor"); TORCH_CHECK( self.scalar_type() == kFloat || self.scalar_type() == kDouble, @@ -219,26 +219,26 @@ void quantile_checks(const Tensor& self, const Tensor& q) { "quantile() q tensor must be on the same device as the input tensor"); } -std::vector quantile_output_shape( +std::vector quantile_output_shape( const std::optional original_dim, const Tensor& self, const Tensor& q, const bool keepdim, int64_t wrapped_dim) { // Compute output shape: q_size + reduced_size - std::vector out_shape; + std::vector out_shape; if (original_dim && self.dim() > 0) { - out_shape = self.sizes().vec(); + out_shape = self.sym_sizes().vec(); if (keepdim) { out_shape[wrapped_dim] = 1; } else { out_shape.erase(out_shape.begin() + wrapped_dim); } } else if (keepdim) { - out_shape = std::vector(self.dim(), 1); + out_shape = std::vector(self.dim(), 1); } if (q.dim() > 0) { - out_shape.insert(out_shape.begin(), q.numel()); + out_shape.insert(out_shape.begin(), q.sym_numel()); } return out_shape; @@ -252,11 +252,13 @@ Tensor quantile_compute( const QUANTILE_INTERPOLATION_MODE& interpolation, const bool ignore_nan, int64_t wrapped_dim, - std::vector out_shape) { + std::vector out_shape) { // Checks that all q values are between 0 and 1, inclusive // NOTE: this check is only performed when running on the CPU to avoid // synchronizing an accelerator with the CPU - if (self.device().is_cpu()) { + // The check is also skipped when the actual q values are not available yet + // e.g. with symbolic shapes or during export + if (self.device().is_cpu() && !isTensorSubclassLike(q)) { auto all_q_in_range = q.ge(0).logical_and_(q.le(1)).all(); TORCH_CHECK(at::is_scalar_tensor_true(all_q_in_range), "quantile() q values must be in the range [0, 1]"); @@ -275,18 +277,18 @@ Tensor quantile_compute( // Treat q as a 1D tensor for the following computations if (q.dim() == 0) { - out_shape.insert(out_shape.begin(), q.numel()); + out_shape.insert(out_shape.begin(), 1); } // View input as reduced_size + size of dim to reduce - std::vector in_shape(out_shape.size()); + std::vector in_shape(out_shape.size()); std::copy(out_shape.begin() + 1, out_shape.end(), in_shape.begin()); - in_shape[in_shape.size() - 1] = sorted.size(-1); - sorted = sorted.view(in_shape); + in_shape[in_shape.size() - 1] = sorted.sym_size(-1); + sorted = sorted.view_symint(in_shape); // Ensure converting from int64_t to double won't overflow - TORCH_CHECK( - sorted.size(-1) <= std::pow(2, 24), + TORCH_SYM_CHECK( + sorted.sym_size(-1).sym_le(1 << 24), "quantile() input tensor is too large"); // Convert q in [0, 1] to ranks in [0, reduction_size) @@ -308,7 +310,7 @@ Tensor quantile_compute( } else { // For quantile, compute ranks based on reduction size. If there is nan // set rank to last index so the quantile computed will be nan. - int64_t last_index = sorted.size(-1) - 1; + auto last_index = sorted.sym_size(-1) - 1; std::vector tl = at::broadcast_tensors({q * last_index, sorted.isnan().any(-1, true)}); ranks = at::masked_fill(tl[0], tl[1], last_index); @@ -388,7 +390,7 @@ void quantile_out_impl( int64_t wrapped_dim = at::maybe_wrap_dim(original_dim.value_or(0), self.dim()); auto out_shape = quantile_output_shape(original_dim, self, q, keepdim, wrapped_dim); - resize_output(out, out_shape); + resize_output_symint(out, out_shape); auto quantile = quantile_compute( self, q, original_dim, keepdim, interpolation, ignore_nan, wrapped_dim, std::move(out_shape)); diff --git a/aten/src/ATen/native/SpectralOps.cpp b/aten/src/ATen/native/SpectralOps.cpp index 91a0c3ff8cf93..6a8dd6c75769b 100644 --- a/aten/src/ATen/native/SpectralOps.cpp +++ b/aten/src/ATen/native/SpectralOps.cpp @@ -80,9 +80,9 @@ ScalarType promote_type_fft(ScalarType type, bool require_complex, Device device } const bool maybe_support_half = ( - // Only CUDA supports half precision, but since meta tensors don't have a + // CUDA and XPU support half precision, but since meta tensors don't have a // device we err on the side of accepting it - device.is_cuda() || device.is_meta() + device.is_cuda() || device.is_meta() || device.is_xpu() ); if (maybe_support_half) { TORCH_CHECK(type == kHalf || type == kFloat || type == kDouble, "Unsupported dtype ", type); diff --git a/aten/src/ATen/native/SummaryOps.cpp b/aten/src/ATen/native/SummaryOps.cpp index 870a73328a97a..ebcd9d49e2579 100644 --- a/aten/src/ATen/native/SummaryOps.cpp +++ b/aten/src/ATen/native/SummaryOps.cpp @@ -31,12 +31,12 @@ Tensor _bincount_cpu_template( if (self.dim() == 1 && self.numel() == 0) { return at::zeros({minlength}, kLong); } - if (self.dim() != 1 || *self.min().data_ptr() < 0) { + if (self.dim() != 1 || *self.min().const_data_ptr() < 0) { TORCH_CHECK(false, "bincount only supports 1-d non-negative integral inputs."); } // Ensure max_val < 2 ^ 63 - 1 (9223372036854775807) - auto max_val = *self.max().data_ptr(); + auto max_val = *self.max().const_data_ptr(); if (max_val >= std::numeric_limits::max()) { TORCH_CHECK(false, "maximum value of input overflowed, it should be < ", diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index ecb67c9ef3799..ff24f132988a1 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -988,7 +988,8 @@ Tensor& _index_put_impl_( } } if ((self.device().type() == DeviceType::CUDA || - self.device().type() == DeviceType::XPU) && + self.device().type() == DeviceType::XPU || + self.device().type() == DeviceType::PrivateUse1) && (accumulate || (globalContext().deterministicAlgorithms() && value_.numel() > 1))) { TORCH_CHECK( @@ -2585,7 +2586,7 @@ static Tensor& masked_select_out_impl_cpu( auto mask_long = at::empty(shape, self.options().dtype(at::kLong)).copy_(*_mask); auto mask_prefix_sum = at::empty(shape, self.options().dtype(at::kLong)); - auto mask_long_data = mask_long.data_ptr(); + auto mask_long_data = mask_long.const_data_ptr(); auto mask_prefix_sum_data = mask_prefix_sum.data_ptr(); // TODO: Here can only use std::partial_sum for C++14, // use std::exclusive_scan when PyTorch upgrades to C++17, which have better diff --git a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h index 6f127b711d3e8..b7b8ab17e9a4e 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h +++ b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h @@ -66,7 +66,7 @@ inline std::tuple canDispatchToMaskedFill( c10::irange(num_ind, self.ndimension())) { mask = mask.unsqueeze(-1); } - return std::make_tuple(true, mask); + return std::make_tuple(true, std::move(mask)); } inline AdvancedIndex make_info(Tensor self, IOptTensorListRef orig) { diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index 909381ce5ed30..7bda22ffb945e 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -652,7 +652,6 @@ Tensor to_dense_backward( default: TORCH_CHECK( false, "to_dense_backward: Unsupported input layout: ", input_layout); - return Tensor{}; } } @@ -1399,7 +1398,6 @@ Tensor dense_to_sparse_with_mask( " to ", layout_to, " conversion not supported"); - return Tensor{}; } Tensor dense_to_sparse_csr( @@ -1482,7 +1480,6 @@ Tensor dense_to_sparse( " to ", layout_to, " conversion not supported"); - return Tensor{}; } Tensor dense_to_sparse(const Tensor& self, int64_t sparse_dim) { @@ -1766,7 +1763,6 @@ Tensor sparse_compressed_to_sparse_csr( false, "sparse_compressed_to_sparse_csr: expected SparseCsr or SparseCsc layout but got ", self.layout()); - return Tensor{}; } Tensor sparse_compressed_to_sparse_csc( @@ -1787,7 +1783,6 @@ Tensor sparse_compressed_to_sparse_csc( false, "sparse_compressed_to_sparse_csc: expected SparseCsr or SparseCsc layout but got ", self.layout()); - return Tensor{}; } Tensor coo_to_sparse_csr( @@ -2138,8 +2133,8 @@ static Tensor _compressed_to_block_compressed_cpu( plain_dim, compressed_blocksize, plain_blocksize, - input_compressed_indices.data_ptr(), - input_plain_indices.data_ptr()); + input_compressed_indices.const_data_ptr(), + input_plain_indices.const_data_ptr()); }); DimVector dense_shape{input_values.sizes().slice(1, input_values.dim() - 1)}; DimVector values_shape{num_blocks, blocksize[0], blocksize[1]}; @@ -2222,7 +2217,6 @@ Tensor sparse_compressed_to_sparse_bsr( false, "sparse_compressed_to_sparse_bsr: expected SparseCsr, SparseCsc, SparseBsr or SparseBsc layout but got ", self.layout()); - return Tensor{}; } Tensor sparse_compressed_to_sparse_bsc( @@ -2260,7 +2254,6 @@ Tensor sparse_compressed_to_sparse_bsc( false, "sparse_compressed_to_sparse_bsc: expected SparseCsr, SparseCsc, SparseBsr or SparseBsc layout but got ", self.layout()); - return Tensor{}; } Tensor sparse_coo_to_sparse(const Tensor& self, const int64_t sparse_dim) { @@ -2273,7 +2266,6 @@ Tensor sparse_coo_to_sparse(const Tensor& self, const int64_t sparse_dim) { " to ", kSparse, " conversion not supported"); - return Tensor{}; } Tensor sparse_compressed_to_sparse( @@ -2377,7 +2369,6 @@ Tensor sparse_compressed_to_sparse( " to ", layout_to, " conversion not supported"); - return Tensor{}; } Tensor sparse_coo_to_sparse( @@ -2414,7 +2405,6 @@ Tensor sparse_coo_to_sparse( " to ", layout_to, " conversion not supported"); - return Tensor{}; } Tensor to_sparse(const Tensor& self, const int64_t sparse_dim) { diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index c15b082f107b2..26cafbb4b2585 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -13,10 +13,15 @@ #include #include #include +#include #include +#include +#include #include #include +#include #include +#include #include #include #include @@ -76,6 +81,30 @@ c10::SymInt sym_storage_offset(const Tensor& self) { return self.sym_storage_offset(); } +int64_t numel(const Tensor& self) { + return self.numel(); +} + +int64_t dim(const Tensor& self) { + return self.dim(); +} + +int64_t get_device(const Tensor& self) { + return self.get_device(); +} + +int64_t storage_offset(const Tensor& self) { + return self.storage_offset(); +} + +bool is_contiguous(const Tensor& self) { + return self.is_contiguous(); +} + +bool is_contiguous(const Tensor& self, at::MemoryFormat memory_format) { + return self.is_contiguous(memory_format); +} + int64_t size(const Tensor& self, Dimname dim) { size_t pos_dim = dimname_to_position(self, dim); return self.sizes()[pos_dim]; diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index d52072cec6a10..2e572ec0400a6 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1422,11 +1422,12 @@ Tensor as_strided_tensorimpl( } template -static inline void setStridedUnchecked( +static void setStridedUnchecked( const Tensor& self, ArrayRef size, ArrayRef stride, T&& storage_offset) { + checkAsStridedArgsAllowUnbackedSymInts(size, stride, storage_offset); auto* self_ = self.unsafeGetTensorImpl(); self_->set_sizes_and_strides(size, stride, std::forward(storage_offset)); } @@ -1958,7 +1959,7 @@ Tensor repeat(const Tensor& self, IntArrayRef repeats) { auto range_a = at::arange(xtensor.dim(), at::TensorOptions(at::kLong)); auto range_b = range_a + n_dims; auto stacked = stack({std::move(range_a), std::move(range_b)}, 1).flatten(); - auto permutation = IntArrayRef(stacked.data_ptr(), n_dims * 2); + auto permutation = IntArrayRef(stacked.const_data_ptr(), n_dims * 2); // Permute from [a0, ..., ad-1, b0, ..., bd-1] to [a0, b0, ..., ad-1, bd-1] urtensor = urtensor.permute(permutation); // Reshape from [a0, b0, ..., ad-1, bd-1] to [a0 * b0, ..., ad-1 * bd-1] @@ -2454,7 +2455,7 @@ Tensor index_select_sparse_cpu( const auto index_contiguous = index.contiguous(); auto nneg_index = at::empty_like(index_contiguous); // nneg_index = (index < 0) * (index + size) + (index >= 0) * index - auto* ptr_index = index_contiguous.data_ptr(); + const auto* ptr_index = index_contiguous.const_data_ptr(); auto* ptr_nneg_index = nneg_index.data_ptr(); at::parallel_for( 0, diff --git a/aten/src/ATen/native/TestOps.cpp b/aten/src/ATen/native/TestOps.cpp index 90dbf97075093..a1dbde708157b 100644 --- a/aten/src/ATen/native/TestOps.cpp +++ b/aten/src/ATen/native/TestOps.cpp @@ -137,7 +137,6 @@ Tensor FunctionalInverses::_test_autograd_multiple_dispatch_view_inverse(const a TORCH_INTERNAL_ASSERT(false, "Attempted to call _test_autograd_multiple_dispatch_view_inverse() during the functionalization pass. ", "This function is for testing only and should never be called."); - return Tensor(); } } // namespace at::functionalization diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 8b59cb87730d0..ef534208e61e8 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -912,6 +912,10 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) { } Tensor& mvlgamma_out(const Tensor& self, int64_t p, Tensor& result) { + TORCH_CHECK( + self.device() == result.device(), + "Expected tensors to be on the same device, but found ", self.device(), " and ", result.device() + ); auto out = self.mvlgamma(p); TORCH_CHECK( at::can_cast(out.scalar_type(), result.scalar_type()), diff --git a/aten/src/ATen/native/Unique.cpp b/aten/src/ATen/native/Unique.cpp index b14079e7ea19c..cd8017290ed49 100644 --- a/aten/src/ATen/native/Unique.cpp +++ b/aten/src/ATen/native/Unique.cpp @@ -124,7 +124,7 @@ struct IsUnique {}; template struct IsUnique { - bool operator() (scalar_t* data_ptr, int64_t i) { + bool operator() (const scalar_t* data_ptr, int64_t i) { if (i == 0) { return true; } return c10::load(&data_ptr[i]) != c10::load(&data_ptr[i - 1]); } @@ -132,7 +132,7 @@ struct IsUnique { template struct IsUnique { - bool operator() (scalar_t* data_ptr, int64_t i) { + bool operator() (const scalar_t* data_ptr, int64_t i) { if (i == 0) { return true; } return (c10::load(&data_ptr[i]) != c10::load(&data_ptr[i - 1])) && !(_isnan(data_ptr[i]) && _isnan(data_ptr[i - 1])); @@ -184,8 +184,8 @@ std::tuple unique_cpu_sorted_template( auto [input_sorted, indices] = input_flattened.sort(); - scalar_t* input_sorted_data = input_sorted.data_ptr(); - int64_t* indices_data = indices.data_ptr(); + const scalar_t* input_sorted_data = input_sorted.const_data_ptr(); + const int64_t* indices_data = indices.const_data_ptr(); int num_threads = at::get_num_threads(); std::vector unique_count_thread(num_threads, 0); @@ -433,7 +433,7 @@ std::tuple _unique_dim_cpu_template( output = output.view(new_sizes); output = output.moveaxis(0, dim); - return std::make_tuple(output, inverse_indices, counts); + return std::make_tuple(std::move(output), std::move(inverse_indices), std::move(counts)); } } // namespace diff --git a/aten/src/ATen/native/UpSample.h b/aten/src/ATen/native/UpSample.h index cf6727c2207c7..6483ca256414e 100644 --- a/aten/src/ATen/native/UpSample.h +++ b/aten/src/ATen/native/UpSample.h @@ -78,6 +78,7 @@ using _upsampling_bilinear2d_aa = void(*)(const Tensor& output, const Tensor& in using upsampling_trilinear3d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_d, scale_t scales_h, scale_t scales_w); using upsampling_bicubic2d = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w); using _upsampling_bicubic2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w); +using _upsampling_lanczos2d_aa = void(*)(const Tensor& output, const Tensor& input, bool align_corners, scale_t scales_h, scale_t scales_w); DECLARE_DISPATCH(upsampling_nearest1d, upsample_nearest1d_kernel) DECLARE_DISPATCH(_upsampling_nearest_exact1d, _upsample_nearest_exact1d_kernel) DECLARE_DISPATCH(upsampling_nearest2d, upsample_nearest2d_kernel) @@ -101,6 +102,8 @@ DECLARE_DISPATCH(upsampling_trilinear3d, upsample_trilinear3d_backward_kernel) DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel) DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel) DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel) +DECLARE_DISPATCH(_upsampling_lanczos2d_aa, _upsample_lanczos2d_aa_kernel) +DECLARE_DISPATCH(_upsampling_lanczos2d_aa, _upsample_lanczos2d_aa_backward_kernel) [[maybe_unused]] inline std::array upsample_1d_common_check( IntArrayRef input_size, diff --git a/aten/src/ATen/native/UpSampleLanczos2d.cpp b/aten/src/ATen/native/UpSampleLanczos2d.cpp new file mode 100644 index 0000000000000..22ae7d972029f --- /dev/null +++ b/aten/src/ATen/native/UpSampleLanczos2d.cpp @@ -0,0 +1,102 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +namespace at::meta { + +TORCH_META_FUNC(_upsample_lanczos2d_aa) ( + const Tensor& input, IntArrayRef output_size, bool align_corners, std::optional scales_h, std::optional scales_w +) { + auto full_output_size = native::upsample_2d_common_check(input.sizes(), output_size); + + TORCH_CHECK( + input.numel() != 0 || c10::multiply_integers(input.sizes().begin() + 1, input.sizes().end()), + "Non-empty 4D data tensor expected but got a tensor with sizes ", + input.sizes()); + + set_output_raw_strided(0, full_output_size, {}, input.options().memory_format(input.suggest_memory_format())); +} + +TORCH_META_FUNC(_upsample_lanczos2d_aa_backward) ( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scales_h, + std::optional scales_w +) { + auto full_output_size = native::upsample_2d_common_check(input_size, output_size); + + TORCH_CHECK( + grad_output.dim() == 4, + "Expected grad_output to be a tensor of dimension 4 but got: dimension ", grad_output.dim()); + + for (const auto i : c10::irange(4)) { + TORCH_CHECK( + grad_output.size(i) == full_output_size[i], + "Expected grad_output to have the same shape as output;", + " output.size(", i, ") = ", full_output_size[i], + " but got grad_output.size(", i, ") = ", grad_output.size(i)); + } + + set_output_raw_strided(0, input_size, {}, grad_output.options()); +} + +} // namespace at::meta +namespace at::native { + +TORCH_IMPL_FUNC(_upsample_lanczos2d_aa_out_cpu) ( + const Tensor& input, + IntArrayRef output_size, + bool align_corners, + std::optional scales_h, + std::optional scales_w, + const Tensor& output +) { + _upsample_lanczos2d_aa_kernel(kCPU, output, input, align_corners, scales_h, scales_w); +} + +TORCH_IMPL_FUNC(_upsample_lanczos2d_aa_backward_out_cpu) ( + const Tensor& grad_output, + IntArrayRef output_size, + IntArrayRef input_size, + bool align_corners, + std::optional scales_h, + std::optional scales_w, + const Tensor& grad_input +) { + grad_input.zero_(); + _upsample_lanczos2d_aa_backward_kernel(kCPU, grad_input, grad_output, align_corners, scales_h, scales_w); +} + +// vec variant + +using at::native::upsample::compute_output_size; +using at::native::upsample::get_scale_value; + +Tensor _upsample_lanczos2d_aa( + const Tensor& input, + at::OptionalIntArrayRef output_size, + bool align_corners, + std::optional> scale_factors) { + auto osize = compute_output_size(input.sizes(), output_size, scale_factors); + auto scale_h = get_scale_value(scale_factors, 0); + auto scale_w = get_scale_value(scale_factors, 1); + return at::_upsample_lanczos2d_aa(input, osize, align_corners, scale_h, scale_w); +} + +DEFINE_DISPATCH(_upsample_lanczos2d_aa_kernel); +DEFINE_DISPATCH(_upsample_lanczos2d_aa_backward_kernel); + +} // namespace at::native diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp index 12841ad8e7391..8ecf4a2324074 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_dynamic.cpp @@ -29,7 +29,6 @@ at::Tensor PackedLinearWeightQnnp::apply_dynamic_impl( false, "Sparse quantized dynamic linear with fused relu is not yet " "supported on qnnpack backend."); - return at::Tensor(); } template <> diff --git a/aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp b/aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp index 7126c1f7b5c37..4c2d60014d994 100644 --- a/aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp +++ b/aten/src/ATen/native/cpu/AmpGradScalerKernels.cpp @@ -62,7 +62,7 @@ void _amp_foreach_non_finite_check_and_unscale_cpu_kernel( "_amp_foreach_non_finite_check_and_unscale_cpu", [&iter, &found_inf, &inv_scale] { auto* found_inf_ptr = found_inf.data_ptr(); - auto* inv_scale_ptr = inv_scale.data_ptr(); + const auto* inv_scale_ptr = inv_scale.const_data_ptr(); using opmath_t = at::opmath_type; @@ -96,7 +96,7 @@ void _amp_foreach_non_finite_check_and_unscale_cpu_kernel( "_amp_foreach_non_finite_check_and_unscale_cpu", [&iter, &found_inf, &inv_scale] { auto* found_inf_ptr = found_inf.data_ptr(); - auto* inv_scale_ptr = inv_scale.data_ptr(); + const auto* inv_scale_ptr = inv_scale.const_data_ptr(); at::native::cpu_kernel_vec( iter, [found_inf_ptr, inv_scale_ptr](scalar_t val_in) -> scalar_t { @@ -166,7 +166,7 @@ at::Tensor& _amp_update_scale_cpu_kernel( float* current_scale_ptr = current_scale.data_ptr(); int* growth_tracker_ptr = growth_tracker.data_ptr(); - float* found_inf_ptr = found_inf.data_ptr(); + const float* found_inf_ptr = found_inf.const_data_ptr(); if (*found_inf_ptr) { *current_scale_ptr = (*current_scale_ptr) * backoff_factor; diff --git a/aten/src/ATen/native/cpu/AtomicAddFloat.h b/aten/src/ATen/native/cpu/AtomicAddFloat.h index 1ecfbe0357fa8..81d3884a36feb 100644 --- a/aten/src/ATen/native/cpu/AtomicAddFloat.h +++ b/aten/src/ATen/native/cpu/AtomicAddFloat.h @@ -11,26 +11,20 @@ static inline void cpu_atomic_add_float(float* dst, float fvalue) { - typedef union { - unsigned intV; - float floatV; - } uf32_t; - - uf32_t new_value, old_value; - std::atomic* dst_intV = (std::atomic*)dst; - - old_value.floatV = *dst; - new_value.floatV = old_value.floatV + fvalue; - - unsigned* old_intV = &old_value.intV; - while (!std::atomic_compare_exchange_strong(dst_intV, old_intV, new_value.intV)) { +#if defined(__cpp_lib_atomic_ref) && __cpp_lib_atomic_ref >= 201806L + std::atomic_ref atomic_dst(*dst); +#else + auto& atomic_dst = *reinterpret_cast*>(dst); +#endif + float old_value = atomic_dst.load(); + float new_value = old_value + fvalue; + while (!atomic_dst.compare_exchange_weak(old_value, new_value)) { #ifdef __aarch64__ __asm__ __volatile__("yield;" : : : "memory"); #else _mm_pause(); #endif - old_value.floatV = *dst; - new_value.floatV = old_value.floatV + fvalue; + new_value = old_value + fvalue; } } diff --git a/aten/src/ATen/native/cpu/DistributionTemplates.h b/aten/src/ATen/native/cpu/DistributionTemplates.h index 0a30fcb2aab26..a1ec35f926b38 100644 --- a/aten/src/ATen/native/cpu/DistributionTemplates.h +++ b/aten/src/ATen/native/cpu/DistributionTemplates.h @@ -85,134 +85,94 @@ struct RandomKernel { // ==================================================== Normal ======================================================== -#ifdef CPU_CAPABILITY_AVX2 -void normal_fill_16_AVX2(float *data, - const __m256* two_pi, - const __m256* one, - const __m256* minus_two, - const __m256* mean, - const __m256* std_v) { - const __m256 u1 = _mm256_sub_ps(*one, _mm256_loadu_ps(data)); - const __m256 u2 = _mm256_loadu_ps(data + 8); - // sincos256_ps and log256_ps are from avx_mathfun.h - const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(*minus_two, log256_ps(u1))); - const __m256 theta = _mm256_mul_ps(*two_pi, u2); - __m256 sintheta, costheta; - sincos256_ps(theta, &sintheta, &costheta); - const __m256 n1 = _mm256_mul_ps(radius, costheta); - const __m256 n2 = _mm256_mul_ps(radius, sintheta); - _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, *std_v, *mean)); - _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, *std_v, *mean)); -} - -template -void normal_fill_AVX2(const TensorBase &self, const float mean, const float std, RNG generator) { - float *data = self.data_ptr(); - auto size = self.numel(); - std::lock_guard lock(generator->mutex_); - for (const auto i : c10::irange(size)) { - at::uniform_real_distribution uniform(0, 1); - data[i] = uniform(generator); - } - const __m256 two_pi = _mm256_set1_ps(2.0f * c10::pi); - const __m256 one = _mm256_set1_ps(1.0f); - const __m256 minus_two = _mm256_set1_ps(-2.0f); - const __m256 mean_v = _mm256_set1_ps(mean); - const __m256 std_v = _mm256_set1_ps(std); - - for (int64_t i = 0; i < size - 15; i += 16) { - normal_fill_16_AVX2(data + i, &two_pi, &one, &minus_two, &mean_v, &std_v); - } - - if (size % 16 != 0) { - // Recompute the last 16 values. - data = data + size - 16; - for (const auto i : c10::irange(16)) { - at::uniform_real_distribution uniform(0, 1); - data[i] = uniform(generator); +// Box-Muller transform on 16 elements (8 pairs of uniforms). +// Primary template is scalar; float specialization uses SIMD when available. +// Constructed once with mean/std so constants are not rebuilt per call. +template +struct NormalFill16 { + opmath_t mean_; + opmath_t std_; + + NormalFill16(opmath_t mean, opmath_t std) + : mean_(mean), std_(std) {} + + void operator()(opmath_t* data) const { + for (const auto j : c10::irange(8)) { + const opmath_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log. + const opmath_t u2 = data[j + 8]; + const opmath_t radius = std::sqrt(-2 * std::log(u1)); + const opmath_t theta = 2.0f * c10::pi * u2; + data[j] = std::fma(radius * std::cos(theta), std_, mean_); + data[j + 8] = std::fma(radius * std::sin(theta), std_, mean_); } - normal_fill_16_AVX2(data, &two_pi, &one, &minus_two, &mean_v, &std_v); } -} -#endif +}; -template -void normal_fill_16(scalar_t *data, const scalar_t mean, const scalar_t std) { - for (const auto j : c10::irange(8)) { - const scalar_t u1 = 1 - data[j]; // [0, 1) -> (0, 1] for log. - const scalar_t u2 = data[j + 8]; - const scalar_t radius = std::sqrt(-2 * std::log(u1)); - const scalar_t theta = 2.0f * c10::pi * u2; - data[j] = radius * std::cos(theta) * std + mean; - data[j + 8] = radius * std::sin(theta) * std + mean; +#if defined(CPU_CAPABILITY_AVX2) + +template <> +struct NormalFill16 { + __m256 mean_; + __m256 std_; + __m256 two_pi_ = _mm256_set1_ps(2.0f * c10::pi); + __m256 one_ = _mm256_set1_ps(1.0f); + __m256 minus_two_ = _mm256_set1_ps(-2.0f); + + NormalFill16(float mean, float std) + : mean_(_mm256_set1_ps(mean)), std_(_mm256_set1_ps(std)) {} + + void operator()(float* data) const { + const __m256 u1 = _mm256_sub_ps(one_, _mm256_loadu_ps(data)); + const __m256 u2 = _mm256_loadu_ps(data + 8); + // sincos256_ps and log256_ps are from avx_mathfun.h + const __m256 radius = _mm256_sqrt_ps(_mm256_mul_ps(minus_two_, log256_ps(u1))); + const __m256 theta = _mm256_mul_ps(two_pi_, u2); + __m256 sintheta, costheta; + sincos256_ps(theta, &sintheta, &costheta); + const __m256 n1 = _mm256_mul_ps(radius, costheta); + const __m256 n2 = _mm256_mul_ps(radius, sintheta); + _mm256_storeu_ps(data, _mm256_fmadd_ps(n1, std_, mean_)); + _mm256_storeu_ps(data + 8, _mm256_fmadd_ps(n2, std_, mean_)); } -} - -#if defined(__VSX__) || defined(CPU_CAPABILITY_VSX) -static void normal_fill_16_VSX(float *data,const Vectorized &two_pi,const Vectorized &one,const Vectorized &minus_two,const Vectorized &mean,const Vectorized &std) { - using Vec = Vectorized; - Vec u1=one-Vec::loadu(data); - Vec u2=Vec::loadu(data+8); - Vec radius=(minus_two * u1.log()); - radius=radius.sqrt(); - Vec theta=two_pi * u2; - Vec output_vec=radius * theta.cos() * std + mean; - Vec output_vec2=radius * theta.sin() * std + mean; - output_vec.store(data); - output_vec2.store(data+8); -} +}; -template -void normal_fill_VSX(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) { - float *data = self.data_ptr(); - auto size = self.numel(); - std::lock_guard lock(generator->mutex_); - for (const auto i : c10::irange(size)) { - at::uniform_real_distribution uniform(0, 1); - data[i] = uniform(generator); - } +#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX) +template <> +struct NormalFill16::size() == 8> { using Vec = Vectorized; - const Vec two_pi = Vec(2.0f * c10::pi); - const Vec one = Vec(1.0f); - const Vec minus_two = Vec(-2.0f); - const Vec var_vec = Vec(std); - const Vec mean_vec = Vec(mean); - - for (int64_t i = 0; i < size - 15; i += 16) { - if(Vec::size()==8) { - normal_fill_16_VSX(data + i, two_pi, one, minus_two, mean_vec, var_vec); - } - else{ - normal_fill_16(data + i, mean, std); - } + Vec mean_; + Vec std_; + Vec two_pi_ = Vec(2.0f * c10::pi); + Vec one_ = Vec(1.0f); + Vec minus_two_ = Vec(-2.0f); + + NormalFill16(float mean, float std) + : mean_(mean), std_(std) {} + + void operator()(float* data) const { + Vec u1 = one_ - Vec::loadu(data); + Vec u2 = Vec::loadu(data + 8); + Vec radius = (minus_two_ * u1.log()).sqrt(); + Vec theta = two_pi_ * u2; + Vec output1 = radius * theta.cos() * std_ + mean_; + Vec output2 = radius * theta.sin() * std_ + mean_; + output1.store(data); + output2.store(data + 8); } - if (size % 16 != 0) { - // Recompute the last 16 values. - data = data + size - 16; - for (const auto i : c10::irange(16)) { - at::uniform_real_distribution uniform(0, 1); - data[i] = uniform(generator); - } - if(Vec::size()==8){ - normal_fill_16_VSX(data, two_pi, one, minus_two, mean_vec, var_vec); - } - else{ - normal_fill_16(data, mean, std); - } - } -} -#endif //VSX +}; + +#endif template -void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std, RNG generator) { +void normal_fill(const TensorBase &self, double mean, double std, RNG generator) { using opmath_t = at::opmath_type; scalar_t *data = self.data_ptr(); auto size = self.numel(); std::lock_guard lock(generator->mutex_); - auto omean = static_cast(mean); - auto ostd = static_cast(std); at::uniform_real_distribution uniform(0, 1); + NormalFill16 normal_fill_16( + static_cast(mean), static_cast(std)); if constexpr (std::is_same_v) { // float/double: generate uniform samples directly into the output buffer, @@ -221,35 +181,37 @@ void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std data[i] = uniform(generator); } for (int64_t i = 0; i < size - 15; i += 16) { - normal_fill_16(data + i, omean, ostd); + normal_fill_16(data + i); } + // Recompute the last 16 values. if (size % 16 != 0) { data = data + size - 16; for (const auto i : c10::irange(16)) { data[i] = uniform(generator); } - normal_fill_16(data, omean, ostd); + normal_fill_16(data); } } else { // bf16/fp16: generate in opmath_t precision using a stack buffer, // apply Box-Muller, then cast down to scalar_t. opmath_t buf[16]; for (int64_t i = 0; i < size - 15; i += 16) { - for (int j = 0; j < 16; j++) { + for (const auto j : c10::irange(16)) { buf[j] = uniform(generator); } - normal_fill_16(buf, omean, ostd); - for (int j = 0; j < 16; j++) { + normal_fill_16(buf); + for (const auto j : c10::irange(16)) { data[i + j] = static_cast(buf[j]); } } + // Recompute the last 16 values. if (size % 16 != 0) { int64_t offset = size - 16; - for (int j = 0; j < 16; j++) { + for (const auto j : c10::irange(16)) { buf[j] = uniform(generator); } - normal_fill_16(buf, omean, ostd); - for (int j = 0; j < 16; j++) { + normal_fill_16(buf); + for (const auto j : c10::irange(16)) { data[offset + j] = static_cast(buf[j]); } } @@ -259,28 +221,18 @@ void normal_fill(const TensorBase &self, const scalar_t mean, const scalar_t std template void normal_kernel(const TensorBase &self, double mean, double std, RNG generator) { auto size = self.numel(); - if (self.scalar_type() == ScalarType::Float && size >= 16 && self.is_contiguous()) { -#ifdef CPU_CAPABILITY_AVX2 - normal_fill_AVX2(self, static_cast(mean), static_cast(std), generator); -#elif defined(__VSX__) || defined(CPU_CAPABILITY_VSX) - normal_fill_VSX(self, static_cast(mean), static_cast(std), generator); -#else - normal_fill(self, static_cast(mean), static_cast(std), generator); -#endif - } else { - AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] { - if (size >= 16 && self.is_contiguous()) { - normal_fill(self, static_cast(mean), static_cast(std), generator); - } else { - auto iter = TensorIterator::borrowing_nullary_op(self); - std::lock_guard lock(generator->mutex_); - cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t { - at::normal_distribution normal(mean, std); - return static_cast(normal(generator)); - }); - } - }); - } + AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "normal_kernel_cpu", [&] { + if (size >= 16 && self.is_contiguous()) { + normal_fill(self, mean, std, generator); + } else { + auto iter = TensorIterator::borrowing_nullary_op(self); + std::lock_guard lock(generator->mutex_); + cpu_serial_kernel(iter, [mean, std, generator]() -> scalar_t { + at::normal_distribution normal(mean, std); + return static_cast(normal(generator)); + }); + } + }); } template diff --git a/aten/src/ATen/native/cpu/HistogramKernel.cpp b/aten/src/ATen/native/cpu/HistogramKernel.cpp index 261683a187b8a..a2994721ee577 100644 --- a/aten/src/ATen/native/cpu/HistogramKernel.cpp +++ b/aten/src/ATen/native/cpu/HistogramKernel.cpp @@ -107,12 +107,12 @@ void histogramdd_cpu_contiguous(Tensor& hist, const TensorList& bin_edges, ? std::optional>(weight.value().accessor()) : std::optional>(); - std::vector bin_seq(D); + std::vector bin_seq(D); std::vector num_bin_edges(D); std::vector leftmost_edge(D), rightmost_edge(D); for (const auto dim : c10::irange(D)) { - bin_seq[dim] = bin_edges[dim].data_ptr(); + bin_seq[dim] = bin_edges[dim].const_data_ptr(); num_bin_edges[dim] = bin_edges[dim].numel(); leftmost_edge[dim] = bin_seq[dim][0]; rightmost_edge[dim] = bin_seq[dim][num_bin_edges[dim] - 1]; diff --git a/aten/src/ATen/native/cpu/Loops.h b/aten/src/ATen/native/cpu/Loops.h index 05e7f93d2f364..adadb636e7f9b 100644 --- a/aten/src/ATen/native/cpu/Loops.h +++ b/aten/src/ATen/native/cpu/Loops.h @@ -301,13 +301,15 @@ VectorizedLoop2d make_vectorized_loop2d( } template -void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) { +void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE, bool check_dynamic_casting = true) { using traits = function_traits; // this could be extended to work with void return types TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); // dynamic casting not currently supported on CPU - TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + if (check_dynamic_casting) { + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + } iter.for_each([&](char** data, const int64_t* strides, int64_t n) { // basic loop can handle 1d slices with arbitrary strides, and 1d slices is all that @@ -317,6 +319,11 @@ void cpu_kernel(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at:: iter.cast_outputs(); } +template +void cpu_kernel_opaque(TensorIteratorBase& iter, func_t&& op, int64_t grain_size = at::internal::GRAIN_SIZE) { + return cpu_kernel(iter, op, grain_size, false); +} + // This function helps write elementwise kernels that requires multiple outputs. // It follows the similar structure of cpu_kernel. // Instead of `basic_loop` function, a new `multiple_outputs_loop` function is diff --git a/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp b/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp index 9f535af4781c6..97973bdd089d8 100644 --- a/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp +++ b/aten/src/ATen/native/cpu/SpmmReduceKernel.cpp @@ -310,7 +310,7 @@ void spmm_reduce_backward_input_arg_kernel_impl( const scalar_t* grad_out_data = grad_out.const_data_ptr(); auto col_data = col_indices.accessor(); const scalar_t* other_data = other.const_data_ptr(); - index_t* arg_out_data = arg_out.data_ptr(); + const index_t* arg_out_data = arg_out.const_data_ptr(); int64_t M = grad_out.size(0); int64_t K = grad_out.size(1); @@ -321,7 +321,7 @@ void spmm_reduce_backward_input_arg_kernel_impl( for (const auto m : c10::irange(begin, end)) { const scalar_t* grad_out_ptr = grad_out_data + m * K; scalar_t* grad_ptr = grad_data + m * K; - index_t* arg_out_ptr = arg_out_data + m * K; + const index_t* arg_out_ptr = arg_out_data + m * K; for (const auto k : c10::irange(K)) { if (arg_out_ptr[k] == index_t(nnz)) { diff --git a/aten/src/ATen/native/cpu/UpSampleKernel.cpp b/aten/src/ATen/native/cpu/UpSampleKernel.cpp index 682a8a2b8eff6..6274f22ecad7c 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernel.cpp +++ b/aten/src/ATen/native/cpu/UpSampleKernel.cpp @@ -1416,6 +1416,89 @@ struct HelperInterpCubic : public HelperInterpBase { }; +struct HelperInterpLanczos : public HelperInterpBase { + + static constexpr int interp_size = 6; + + // Taken from + // https://github.com/python-pillow/Pillow/blob/8004234d879254cc354935ad42fbb51b1700925e/src/libImaging/Resample.c#L64-L86 + template + static inline scalar_t sinc_filter(scalar_t x) { + if (x == 0.0) { + return 1.0; + } + x *= c10::pi; + return std::sin(x) / x; + } + + template + static inline scalar_t aa_filter(scalar_t x) { + // Lanczos-3 filter: sinc(x) * sinc(x/3) for |x| < 3 + x = std::abs(x); + if (x < 3.0) { + return sinc_filter(x) * sinc_filter(x / 3.0); + } + return 0.0; + } + + static inline std::vector compute_index_ranges_weights( + at::ScalarType scalar_type, + int64_t input_size, + int64_t output_size, + int64_t stride, + int64_t ndims, + int64_t reshape_dim, + bool align_corners, + const std::optional& opt_scale, + bool antialias + ) { + + std::vector indices_weights; + AT_DISPATCH_FLOATING_TYPES( + scalar_type, "compute_index_ranges_weights", [&] { + + scalar_t scale = area_pixel_compute_scale( + input_size, output_size, align_corners, opt_scale); + + auto interp_size = HelperInterpLanczos::interp_size; + + indices_weights = std::get<0>(HelperInterpLanczos::_compute_index_ranges_weights( + input_size, + output_size, + stride, + ndims, + reshape_dim, + scale, + interp_size, + &HelperInterpLanczos::aa_filter, + /*antialias=*/antialias, + /*align_corners=*/align_corners)); + } + ); + return indices_weights; + } + + static inline std::tuple, int, unsigned int> compute_index_ranges_int16_weights( + int64_t input_size, + int64_t output_size, + int64_t stride, + int64_t ndims, + int64_t reshape_dim, + bool align_corners, + const std::optional& opt_scale, + bool antialias, + bool align_i32=false + ) { + + auto interp_size = HelperInterpLanczos::interp_size; + auto fn = HelperInterpLanczos::aa_filter; + return HelperInterpLanczos::_compute_index_ranges_int16_weights( + input_size, output_size, stride, ndims, reshape_dim, + align_corners, opt_scale, interp_size, fn, antialias, align_i32); + } + +}; + // Generic upsampling interpolation kernel for N-d case. // Input is assumed to be like NCHW, NCL, NCKHW - interpolated spatial dimension // are those from the end up to batch size N and number of channels C. @@ -1485,15 +1568,19 @@ void upsample_non_separable_Nd_kernel_impl( if (interp_size > 1) { // Nearest also supports uint8 tensor, so need to handle it separately + // Dispatch name should be "upsample_non_separable" but we keep the old + // name for internal BC. AT_DISPATCH_FLOATING_TYPES_AND2( - kBFloat16, kHalf, iter.dtype(), "upsample_non_separable", [&] { + kBFloat16, kHalf, iter.dtype(), "upsample_generic_Nd", [&] { // MSVC can not catch constexpr int interp_size here constexpr int mode = F::interp_size; upsample_non_separable(iter); }); } else { + // Dispatch name should be "upsample_non_separable" but we keep the old + // name for internal BC. AT_DISPATCH_FLOATING_TYPES_AND3(kByte, kBFloat16, kHalf, - iter.dtype(), "upsample_non_separable", [&] { + iter.dtype(), "upsample_generic_Nd", [&] { constexpr int mode = F::interp_size; upsample_non_separable(iter); }); @@ -1530,8 +1617,8 @@ void upsample_separable_1d( unsigned int weights_precision = 0; if (input_scalar_type == at::kByte) { - // This is a special branch to provide uint8 dtype support for bilinear and bicubic modes only - TORCH_INTERNAL_ASSERT(F::interp_size == 2 || F::interp_size == 4); + // This is a special branch to provide uint8 dtype support for bilinear, bicubic and lanczos modes only + TORCH_INTERNAL_ASSERT(F::interp_size == 2 || F::interp_size == 4 || F::interp_size == 6); int unused = 0; std::tie(indices_weights, unused, weights_precision) = F::compute_index_ranges_int16_weights( @@ -1561,8 +1648,10 @@ void upsample_separable_1d( auto iter = config.build(); + // Dispatch name should be "upsample_separable_1d" but we keep the old + // name for internal BC. AT_DISPATCH_FLOATING_TYPES_AND( - at::ScalarType::Byte, iter.dtype(), "upsample_separable_1d", [&] { + at::ScalarType::Byte, iter.dtype(), "upsample_generic_Nd_aa", [&] { auto loop = [&](char** data, const int64_t* strides, int64_t n) { if constexpr (is_horizontal) { // Strides are : X 0 | 8 8 8 0 8 (Channels first) @@ -1931,6 +2020,35 @@ void upsample_bicubic2d_aa_kernel_impl( /*antialias=*/true); } +void upsample_lanczos2d_aa_kernel_impl( + const Tensor& output, + const Tensor& input, + bool align_corners, + std::optional scales_h, + std::optional scales_w) { + + if (input.dtype() == at::kByte) { + #ifdef CPU_CAPABILITY_AVX2 + if (input.size(1) <= 4) { + return upsample_avx_bilinear_bicubic_uint8( + input, output, align_corners, {scales_h, scales_w}, + /*antialias=*/true); + } + #elif defined(__aarch64__) + if (input.size(1) == 3 + && input.is_contiguous(at::MemoryFormat::ChannelsLast) + && output.is_contiguous(at::MemoryFormat::ChannelsLast)) { + return upsample_neon_bilinear_bicubic_uint8( + input, output, align_corners, {scales_h, scales_w}, + /*antialias=*/true); + } + #endif // CPU_CAPABILITY_AVX2 + } + return upsample_separable_Nd_kernel_impl<2, scale_t, HelperInterpLanczos>( + output, input, align_corners, {scales_h, scales_w}, + /*antialias=*/true); +} + template < typename scalar_t, typename scale_type, @@ -2071,6 +2189,19 @@ void upsample_bicubic2d_aa_backward_kernel_impl( }); } +void upsample_lanczos2d_aa_backward_kernel_impl( + const Tensor& grad_input, + const Tensor& grad_output, + bool align_corners, + std::optional scales_h, + std::optional scales_w) { + AT_DISPATCH_FLOATING_TYPES( + grad_output.scalar_type(), "upsample_lanczos2d_aa_backward_cpu", [&] { + upsample_separable_Nd_backward_aa( + grad_input, grad_output, align_corners, {scales_h, scales_w}); + }); +} + } // anonymous namespace REGISTER_DISPATCH(upsample_nearest1d_kernel, &upsample_nearest1d_kernel_impl) @@ -2089,4 +2220,7 @@ REGISTER_DISPATCH(upsample_trilinear3d_kernel, &upsample_trilinear3d_kernel_impl REGISTER_DISPATCH(upsample_bicubic2d_kernel, &upsample_bicubic2d_kernel_impl) REGISTER_DISPATCH(_upsample_bicubic2d_aa_kernel, &upsample_bicubic2d_aa_kernel_impl) REGISTER_DISPATCH(_upsample_bicubic2d_aa_backward_kernel, &upsample_bicubic2d_aa_backward_kernel_impl) + +REGISTER_DISPATCH(_upsample_lanczos2d_aa_kernel, &upsample_lanczos2d_aa_kernel_impl) +REGISTER_DISPATCH(_upsample_lanczos2d_aa_backward_kernel, &upsample_lanczos2d_aa_backward_kernel_impl) } // namespace at::native diff --git a/aten/src/ATen/native/cpu/UpSampleKernelNEONAntialias.h b/aten/src/ATen/native/cpu/UpSampleKernelNEONAntialias.h index f6b2ea72156ac..c8d3a1a59a508 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernelNEONAntialias.h +++ b/aten/src/ATen/native/cpu/UpSampleKernelNEONAntialias.h @@ -111,6 +111,26 @@ void NeonResampleHorizontal(const at::Tensor& unpacked_output, acc_b = vmlal_s16(acc_b, vget_high_s16(b16), vget_high_s16(weights)); } + // Block 4: handle 4 pixels that didn't fit in a block of 8 + // We also use vld3_u8 here, which still loads 8 pixels, but we only use + // the lower half - so the computation is correct. + // On all rows except the last one, reading 8 pixels is safe (the tensors + // are channels-last). But on the last row, we have to be careful not to + // read past the buffer, hence the extra boundary check. + const uint8_t* block_of_4_safe_load_end = input_p + yin * xin_stride - 24; + for (; i + 4 <= ids_size && lineIn_min + num_channels * i <= block_of_4_safe_load_end; i += 4) { + uint8x8x3_t rgb = vld3_u8(lineIn_min + num_channels * i); + int16x4_t weights4 = vld1_s16(&k[i]); + + int16x4_t r16 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(rgb.val[0]))); + int16x4_t g16 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(rgb.val[1]))); + int16x4_t b16 = vget_low_s16(vreinterpretq_s16_u16(vmovl_u8(rgb.val[2]))); + + acc_r = vmlal_s16(acc_r, r16, weights4); + acc_g = vmlal_s16(acc_g, g16, weights4); + acc_b = vmlal_s16(acc_b, b16, weights4); + } + // Horizontal reduction + rounding bias int32_t sum_r = vaddvq_s32(acc_r) + initial_val; int32_t sum_g = vaddvq_s32(acc_g) + initial_val; diff --git a/aten/src/ATen/native/cpu/int4mm_kernel.cpp b/aten/src/ATen/native/cpu/int4mm_kernel.cpp index 9a9897b88fba2..e52865a9a135d 100644 --- a/aten/src/ATen/native/cpu/int4mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int4mm_kernel.cpp @@ -40,7 +40,7 @@ inline bool is_block_start(int index, int BLOCK_SIZE) { #if (defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)) && !defined(_MSC_VER) // convert 16x int4 to int8, handle 64 bits at a time // used in avx2 and avx512 -inline __m128i conver_int4_to_int8(const uint8_t* data) { +inline __m128i convert_int4_to_int8(const uint8_t* data) { __m128i tmp = _mm_loadu_si64((const __m128i*)data); __m128i bytes = _mm_cvtepu8_epi16(tmp); const __m128i lowMask = _mm_set1_epi8(0xF); @@ -169,7 +169,7 @@ inline void tinygemm_kernel( vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]); } } else { - __m128i b8 = conver_int4_to_int8(B + k * ldb + col * 8); + __m128i b8 = convert_int4_to_int8(B + k * ldb + col * 8); __m512i b32 = _mm512_cvtepu8_epi32(b8); vb[col] = _mm512_permutexvar_ps(b32, lut); vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]); @@ -312,7 +312,7 @@ inline void tinygemm_kernel( } else { if constexpr (col % 2 == 0) { // de-quantize per 64 bits (16x int4) - __m128i b8 = conver_int4_to_int8(B + k * ldb + col * 4); + __m128i b8 = convert_int4_to_int8(B + k * ldb + col * 4); __m128i b8_val0 = _mm_set1_epi64x(_mm_extract_epi64(b8, 0)); __m128i b8_val1 = _mm_set1_epi64x(_mm_extract_epi64(b8, 1)); if (k + PREFETCH_SIZE_K < K) { @@ -619,7 +619,7 @@ void weight_to_int4pack_kernel( const Tensor& weight) { auto weight_packed_data = reinterpret_cast(weight_packed.data_ptr()); - const auto weight_data = weight.data_ptr(); + const auto weight_data = weight.const_data_ptr(); int N = weight.size(0); int K = weight.size(1); diff --git a/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu b/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu index 2030dbb904505..30745cb8f0a0a 100644 --- a/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu +++ b/aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu @@ -393,7 +393,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda) C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { // run updateGradInput kernel - atomicadaptivemaxgradinput<<< + adaptivemaxgradinput<<< blocks, threads, 0, diff --git a/aten/src/ATen/native/cuda/AveragePool3d.cu b/aten/src/ATen/native/cuda/AveragePool3d.cu index dabcf5b63be99..f290864a9320b 100644 --- a/aten/src/ATen/native/cuda/AveragePool3d.cu +++ b/aten/src/ATen/native/cuda/AveragePool3d.cu @@ -500,11 +500,6 @@ TORCH_IMPL_FUNC(avg_pool3d_backward_out_cuda) ( const int64_t oheight = gradOutput.size(-2); const int64_t owidth = gradOutput.size(-1); - /* XXX shape check behavior from TH */ - const int64_t otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode); - const int64_t oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode); - const int64_t owidth_for_chape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode); - const bool kernelsOverlap = (dT < kT) || (dH < kH) || (dW < kW); Tensor work_grad_input = gradInput; diff --git a/aten/src/ATen/native/cuda/BinaryDivFloorKernel.cu b/aten/src/ATen/native/cuda/BinaryDivFloorKernel.cu index bdfec7faffeab..1f05f2ba17161 100644 --- a/aten/src/ATen/native/cuda/BinaryDivFloorKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryDivFloorKernel.cu @@ -9,7 +9,6 @@ #include #include #include -#include #include #include diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index e427ff9a50da0..33abc9003e8cb 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -34,6 +34,7 @@ #else #include #include +#include #include #include #include @@ -108,8 +109,7 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa case Activation::GELU: return cuda::blas::GEMMAndBiasActivationEpilogue::GELU; default: - TORCH_CHECK(false); - return cuda::blas::GEMMAndBiasActivationEpilogue::None; + TORCH_CHECK(false, "Unknown activation epologue type"); } } @@ -228,9 +228,6 @@ static bool isInputCompliesAddmmCudaLt( mat2_sizes[0] > 1 && mat2_sizes[1] > 1 ) ); - - // no compliance by default - return false; } template diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index 688ee468939e0..00f0aedf4578c 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -648,7 +648,6 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) { TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); - TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); std::array data; for (int i = 0; i < ntensors; i++) { diff --git a/aten/src/ATen/native/cuda/CUDAScalar.cu b/aten/src/ATen/native/cuda/CUDAScalar.cu index 169a2ab92615f..49af9c8087cdf 100644 --- a/aten/src/ATen/native/cuda/CUDAScalar.cu +++ b/aten/src/ATen/native/cuda/CUDAScalar.cu @@ -10,28 +10,67 @@ #endif #include +#include namespace at::native { +namespace { + +bool is_cuda_caching_allocator_tensor(const Tensor& self) { + auto* cuda_allocator = c10::cuda::CUDACachingAllocator::get(); + if (cuda_allocator == nullptr) { + return false; + } + // SymmMem/NVSHMEM/rocSHMEM tensors are typically backed by custom + // from_blob-style deleters, so this check filters them out and keeps + // the direct dereference path limited to allocator-managed CUDA memory. + return self.storage().data_ptr().get_deleter() == cuda_allocator->raw_deleter(); +} + +template +void _local_scalar_dense_cuda_impl(const Tensor& self, Scalar& r) { +#if defined(USE_ROCM) && (ROCM_VERSION >= 70200) + // If this is a large BAR device, we can just read directly from VRAM + if ( + at::cuda::getCurrentDeviceProperties()->isLargeBar && + is_cuda_caching_allocator_tensor(self)) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + hipStreamCaptureStatus captureStatus; + C10_CUDA_CHECK(hipStreamGetCaptureInfo(stream, &captureStatus, nullptr)); + if (C10_LIKELY(captureStatus == hipStreamCaptureStatusNone)) { + at::cuda::stream_synchronize(stream); + r = Scalar(*self.template const_data_ptr()); + } else { + C10_CUDA_CHECK(hipErrorStreamCaptureUnsupported); + } + return; + } +#endif + + // Create pinned memory for the scalar value to avoid implicit + // locking/sync in cuda library due to pageable memory + auto value = at::detail::empty_cpu( + {1}, /* size */ + c10::CppTypeToScalarType(), /* dtype */ + std::nullopt, /* layout */ + std::nullopt, /* device */ + true, /* pin_memory */ + std::nullopt /* memory format */ + ); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + at::cuda::memcpy_and_sync(value.template mutable_data_ptr(), self.template const_data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream); + r = Scalar(*value.template const_data_ptr()); +} + +} // anonymous namespace + Scalar _local_scalar_dense_cuda(const Tensor& self) { Scalar r; TORCH_CHECK(self.numel() > 0, "_local_scalar_dense: Empty tensor not supported"); AT_DISPATCH_V2( self.scalar_type(), "_local_scalar_dense_cuda", AT_WRAP([&] { - // Create pinned memory for the scalar value to avoid implicit - // locking/sync in cuda library due to pageable memory - auto value = at::detail::empty_cpu( - {1}, /* size */ - c10::CppTypeToScalarType(), /* dtype */ - std::nullopt, /* layout */ - std::nullopt, /* device */ - true, /* pin_memory */ - std::nullopt /* memory format */ - ); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - at::cuda::memcpy_and_sync(value.mutable_data_ptr(), self.const_data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream); - r = Scalar(*value.const_data_ptr()); - }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + _local_scalar_dense_cuda_impl(self, r); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); return r; } diff --git a/aten/src/ATen/native/cuda/CompositeRandomAccessor.h b/aten/src/ATen/native/cuda/CompositeRandomAccessor.h index eb8587d1f9337..d519e333b1e9d 100644 --- a/aten/src/ATen/native/cuda/CompositeRandomAccessor.h +++ b/aten/src/ATen/native/cuda/CompositeRandomAccessor.h @@ -25,7 +25,7 @@ void swap( references_holder rh1, references_holder rh2 ) { - return thrust::swap(rh1.data(), rh2.data()); + thrust::swap(rh1.data(), rh2.data()); } template diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu index 9e78d2e71f043..68741f90ef85e 100644 --- a/aten/src/ATen/native/cuda/Copy.cu +++ b/aten/src/ATen/native/cuda/Copy.cu @@ -21,8 +21,7 @@ #include #include -// TODO(NS): Investigate why FP8 conversion intrinsics end up being slower -#ifdef AT_USE_NV_CVT_INTRINSICS +#if defined(CUDA_VERSION) && CUDA_VERSION >= 13000 #include #endif @@ -69,25 +68,53 @@ void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) { } #endif +template +struct ConvertToFloat8E4M3fnOp { + __device__ __forceinline__ Float8_e4m3fn operator()(SrcT value) const { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 13000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 890 + __nv_fp8_storage_t x; + if constexpr (std::is_same_v) { + x = __nv_cvt_float_to_fp8(value, __NV_SATFINITE, __NV_E4M3); + } else if constexpr (std::is_same_v) { + x = __nv_cvt_halfraw_to_fp8(static_cast<__half>(value), __NV_SATFINITE, __NV_E4M3); + } else if constexpr (std::is_same_v) { + x = __nv_cvt_bfloat16raw_to_fp8(static_cast<__nv_bfloat16>(value), __NV_SATFINITE, __NV_E4M3); + } else { + x = __nv_cvt_float_to_fp8(static_cast(value), __NV_SATFINITE, __NV_E4M3); + } + return Float8_e4m3fn(x, Float8_e4m3fn::from_bits()); +#else + return Float8_e4m3fn(value); +#endif + } +}; + +// e5m2 intrinsics are correct but slower; only used for float on Blackwell +// to work around the ptxas subnormal codegen bug. +struct ConvertFloatToFloat8E5M2Op { + __device__ __forceinline__ Float8_e5m2 operator()(float value) const { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 13020 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 + auto x = __nv_cvt_float_to_fp8(value, __NV_NOSAT, __NV_E5M2); + return Float8_e5m2(x, Float8_e5m2::from_bits()); +#else + return Float8_e5m2(value); +#endif + } +}; + void float8_copy_kernel_cuda(TensorIteratorBase &iter) { ScalarType dtype = iter.dtype(0); ScalarType other_dtype = iter.dtype(1); if (dtype == kFloat8_e4m3fn) { switch (other_dtype) { case kFloat: - gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { - return Float8_e4m3fn(value); - }); + gpu_kernel_nocast(iter, ConvertToFloat8E4M3fnOp{}); break; case kHalf: - gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { - return Float8_e4m3fn(value); - }); + gpu_kernel_nocast(iter, ConvertToFloat8E4M3fnOp{}); break; case kBFloat16: - gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { - return Float8_e4m3fn(value); - }); + gpu_kernel_nocast(iter, ConvertToFloat8E4M3fnOp{}); break; default: gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fn x) { return x; }); @@ -96,33 +123,16 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) { } else if (dtype == kFloat8_e5m2) { switch (other_dtype) { case kFloat: - gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { -#ifdef AT_USE_NV_CVT_INTRINSICS - const auto x = __nv_cvt_float_to_fp8(value, __NV_NOSAT, __NV_E5M2); - return Float8_e5m2(x, Float8_e5m2::from_bits()); -#else - return Float8_e5m2(value); -#endif - }); + gpu_kernel_nocast(iter, ConvertFloatToFloat8E5M2Op{}); break; case kHalf: gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { -#ifdef AT_USE_NV_CVT_INTRINSICS - const auto x = __nv_cvt_halfraw_to_fp8(static_cast<__half>(value), __NV_NOSAT, __NV_E5M2); - return Float8_e5m2(x, Float8_e5m2::from_bits()); -#else return Float8_e5m2(value); -#endif }); break; case kBFloat16: gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { -#ifdef AT_USE_NV_CVT_INTRINSICS - const auto x = __nv_cvt_bfloat16raw_to_fp8(static_cast<__nv_bfloat16>(value), __NV_NOSAT, __NV_E5M2); - return Float8_e5m2(x, Float8_e5m2::from_bits()); -#else return Float8_e5m2(value); -#endif }); break; default: @@ -194,7 +204,7 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) { break; } } else { - TORCH_CHECK(false, "This supposed ot be called only for Float8 types"); + TORCH_CHECK(false, "This supposed to be called only for Float8 types"); } } diff --git a/aten/src/ATen/native/cuda/Distributions.cpp b/aten/src/ATen/native/cuda/Distributions.cpp index be397f4bc217f..f270ea66d3755 100644 --- a/aten/src/ATen/native/cuda/Distributions.cpp +++ b/aten/src/ATen/native/cuda/Distributions.cpp @@ -28,6 +28,14 @@ Tensor _s_poisson_cuda(const Tensor& lambda, std::optional gen_) { // NOLINTNEXTLINE(performance-unnecessary-value-param) Tensor _s_binomial_cuda(const Tensor& count, const Tensor& prob, std::optional gen_) { + TORCH_CHECK_VALUE( + at::isFloatingType(count.scalar_type()), + "binomial only supports floating-point dtypes for count, got: ", + count.scalar_type()); + TORCH_CHECK_VALUE( + at::isFloatingType(prob.scalar_type()), + "binomial only supports floating-point dtypes for prob, got: ", + prob.scalar_type()); auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); Tensor ret = at::empty(count.sizes(), count.options()); at::TensorIterator iter = at::TensorIteratorConfig() diff --git a/aten/src/ATen/native/cuda/ForeachFunctors.cuh b/aten/src/ATen/native/cuda/ForeachFunctors.cuh index 37175776097df..1e18e42c34565 100644 --- a/aten/src/ATen/native/cuda/ForeachFunctors.cuh +++ b/aten/src/ATen/native/cuda/ForeachFunctors.cuh @@ -77,6 +77,34 @@ __device__ bool init_args( return all_aligned; } +template < + int depth, + typename param_type, + typename grad_type, + typename exp_avg_type, + typename exp_avg_sq_type> +__device__ bool init_args_mixed_prec( + param_type** param_args, + grad_type** grad_args, + exp_avg_type** exp_avg_args, + exp_avg_sq_type** exp_avg_sq_args, + FusedOptimizerTensorListMetadata& tl, + const int64_t chunk_idx, + const int64_t chunk_size, + const int64_t tensor_loc) { + *param_args = + (param_type*)tl.addresses[0][tensor_loc] + chunk_idx * chunk_size; + *grad_args = (grad_type*)tl.addresses[1][tensor_loc] + chunk_idx * chunk_size; + *exp_avg_args = + (exp_avg_type*)tl.addresses[2][tensor_loc] + chunk_idx * chunk_size; + *exp_avg_sq_args = + (exp_avg_sq_type*)tl.addresses[3][tensor_loc] + chunk_idx * chunk_size; + + bool all_aligned = is_aligned(*param_args) && is_aligned(*grad_args) && + is_aligned(*exp_avg_args) && is_aligned(*exp_avg_sq_args); + return all_aligned; +} + template __device__ void load_args( T r_args[][kILP], @@ -96,6 +124,43 @@ __device__ void load_args( } } +template < + typename T, + typename param_type, + typename grad_type, + typename exp_avg_type, + typename exp_avg_sq_type> +__device__ void load_args( + T r_args[][kILP], + const param_type* param_args, + const grad_type* grad_args, + const exp_avg_type* exp_avg_args, + const exp_avg_sq_type* exp_avg_sq_args, + const int64_t i_start, + const int64_t chunk_size, + const int64_t n) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const auto i = i_start + threadIdx.x + ii * blockDim.x; + r_args[0][ii] = 0; + if (i < n && i < chunk_size) { + r_args[0][ii] = static_cast(param_args[i]); + } + r_args[1][ii] = 0; + if (i < n && i < chunk_size) { + r_args[1][ii] = static_cast(grad_args[i]); + } + r_args[2][ii] = 0; + if (i < n && i < chunk_size) { + r_args[2][ii] = static_cast(exp_avg_args[i]); + } + r_args[3][ii] = 0; + if (i < n && i < chunk_size) { + r_args[3][ii] = static_cast(exp_avg_sq_args[i]); + } + } +} + template __device__ void store_args( T* dst, @@ -111,6 +176,21 @@ __device__ void store_args( } } +template +__device__ void store_args( + dT* dst, + sT* src, + const int64_t i_start, + const int64_t chunk_size, + const int64_t n) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const int64_t i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) + dst[i] = static_cast
(src[ii]); + } +} + template __device__ __forceinline__ void binary_op_scalar( T r_args[][kILP], diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu index 78f0434d9f7a8..e6910caa3302a 100644 --- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu +++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu @@ -28,22 +28,25 @@ namespace at::native { -// _foreach_norm supports only L1, L2, and inf norm -enum class NormType { L1, L2, LInf }; +// _foreach_norm supports L0, L1, L2, and inf norm +enum class NormType { L0, L1, L2, LInf }; // NOTE: This is a simple variant of TensorListMetadata in MultiTensorApply.cuh // as we only need to track addresses for the lpnorm_cleanup function below. // Why is this struct necessary? For the same reason the TensorListMetadata // struct is necessary--which is to ferry static metadata to the CUDA kernel -// while complying with the 4kb size constraint. Since we only need to track -// addresses, we introduce this struct to be able to fit more Tensor pointers at -// a time, currently 400 empirically, compared to the much smaller values in -// depth_to_max_tensors. This way, we can launch fewer kernels for better -// performance. +// while complying with the kernel arg size constraint. Since we only need to +// track addresses, we introduce this struct to be able to fit more Tensor +// pointers at a time compared to depth_to_max_tensors. This way, we can +// launch fewer kernels for better performance. // // IF YOU USE THIS STRUCT, PLEASE ADD A ONE-OFF TEST IN test_foreach.py AS THIS // IS CURRENTLY ONLY TESTED FOR _foreach_norm. -const size_t MAX_TENSORS_PER_KERNEL = 400; +#if defined(CUDART_VERSION) && CUDART_VERSION >= 13000 && !defined(USE_ROCM) +static constexpr size_t MAX_TENSORS_PER_KERNEL = 3200; +#else +static constexpr size_t MAX_TENSORS_PER_KERNEL = 400; +#endif struct TensorListAddresses { const void* addresses[MAX_TENSORS_PER_KERNEL]; }; @@ -133,17 +136,19 @@ __global__ void lpmax_cleanup( std::vector foreach_tensor_max_cuda(TensorList tensors) { check_foreach_api_restrictions(tensors); - if (!can_use_fast_route(tensors)) { - return foreach_tensor_max_slow(tensors); - } - // for parity with max in ReduceAllOps.cpp, as max(empty) is ??? + // for parity with max in ReduceAllOps.cpp, as max(empty) is undefined + // Check this early before routing to slow path TORCH_CHECK( std::all_of( tensors.begin(), tensors.end(), [](const auto& t) { return t.numel() > 0; }), - "max(): Expected reduction dim to be specified for input.numel() == 0. Specify the reduction dim with the 'dim' argument."); + "_foreach_max cannot compute the maximum of an empty tensor; max over zero elements is undefined."); + + if (!can_use_fast_route(tensors)) { + return foreach_tensor_max_slow(tensors); + } const size_t ntensors = tensors.size(); int max_chunks_per_tensor = -1; @@ -281,7 +286,10 @@ struct LpNormFunctor { #pragma unroll for (int ii = 0; ii < kILP; ii++) { const auto next = static_cast(r_x[ii]); - if constexpr (norm_type == NormType::LInf) { + if constexpr (norm_type == NormType::L0) { + vals[ii] += + next != out_opmath_t(0) ? out_opmath_t(1) : out_opmath_t(0); + } else if constexpr (norm_type == NormType::LInf) { vals[ii] = max_propagate_nan(vals[ii], ::abs(next)); } else { vals[ii] += norm_type == NormType::L1 ? ::abs(next) : next * next; @@ -296,7 +304,10 @@ struct LpNormFunctor { int i = i_start + threadIdx.x + ii * blockDim.x; if (i < n && i < chunk_size) { const auto next = static_cast(x[i]); - if constexpr (norm_type == NormType::LInf) { + if constexpr (norm_type == NormType::L0) { + vals[ii] += + next != out_opmath_t(0) ? out_opmath_t(1) : out_opmath_t(0); + } else if constexpr (norm_type == NormType::LInf) { vals[ii] = max_propagate_nan(vals[ii], ::abs(next)); } else { vals[ii] += norm_type == NormType::L1 ? ::abs(next) : next * next; @@ -314,7 +325,8 @@ struct LpNormFunctor { val += vals[i]; } } - auto final_val = norm_type == NormType::L1 || norm_type == NormType::L2 + auto final_val = norm_type == NormType::L0 || norm_type == NormType::L1 || + norm_type == NormType::L2 ? at::native::cuda_utils::BlockReduceSum(val, s_vals) : at::native::cuda_utils::BlockReduceMax(val, s_vals); @@ -348,8 +360,8 @@ __global__ void lpnorm_cleanup( val += output_this_tensor[i]; } } - out_opmath_t final_val = - norm_type == NormType::L1 || norm_type == NormType::L2 + out_opmath_t final_val = norm_type == NormType::L0 || + norm_type == NormType::L1 || norm_type == NormType::L2 ? at::native::cuda_utils::BlockReduceSum(val, vals) : at::native::cuda_utils::BlockReduceMax(val, vals); if (threadIdx.x == 0) { @@ -476,7 +488,13 @@ std::vector foreach_tensor_norm_cuda_internal( AT_DISPATCH_OUT_DTYPES( output_dtype, ForeachNormDispatchName::value, [&]() { using out_opmath_t = typename at::opmath_type; - if (p == static_cast(1)) { + if (p == static_cast(0)) { + multi_tensor_apply<1>( + tensor_lists, + LpNormFunctor(), + output_per_tensor.template mutable_data_ptr(), + max_chunks_per_tensor); + } else if (p == static_cast(1)) { multi_tensor_apply<1>( tensor_lists, LpNormFunctor(), @@ -519,7 +537,16 @@ std::vector foreach_tensor_norm_cuda_internal( .template mutable_data_ptr(); } - if (p == static_cast(1)) { + if (p == static_cast(0)) { + lpnorm_cleanup + <<>>( + output_per_tensor + .template const_data_ptr() + + i * MAX_TENSORS_PER_KERNEL * + max_chunks_per_tensor, + addr_struct, + max_chunks_per_tensor); + } else if (p == static_cast(1)) { lpnorm_cleanup <<>>( output_per_tensor @@ -613,7 +640,8 @@ std::vector foreach_tensor_norm_cuda( }); } if (!can_use_fast_route(tensors) || has_int_or_complex || - !(p == static_cast(1) || p == static_cast(2) || + !(p == static_cast(0) || p == static_cast(1) || + p == static_cast(2) || p == std::numeric_limits::infinity())) { return foreach_tensor_norm_slow(tensors, ord, dtype); } diff --git a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu index bb070f9d97616..9a36a879615a0 100644 --- a/aten/src/ATen/native/cuda/ForeachUnaryOp.cu +++ b/aten/src/ATen/native/cuda/ForeachUnaryOp.cu @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -40,7 +41,9 @@ #include #include +#include #include +#include #endif namespace at::native { @@ -431,4 +434,48 @@ void foreach_tensor_zero_cuda_(TensorList tensors) { }); } +std::vector foreach_tensor_clone_cuda( + TensorList self, + std::optional memory_format) { + check_foreach_api_restrictions(self); + if (!_check_tensors_share_device_and_dtype({self})) { + return at::native::foreach_tensor_clone_slow(self, memory_format); + } + + std::vector ret{}; + ret.reserve(self.size()); + + auto realized_memory_format = memory_format.value_or(MemoryFormat::Preserve); + for (const auto& s : self) { + // This logic modified from at::native::clone. + if (realized_memory_format == MemoryFormat::Preserve) { + if (s.is_non_overlapping_and_dense()) { + // Copy all strides, this is marginally faster than calling empty_like + auto options = s.options(); + ret.emplace_back(at::native::empty_strided_cuda( + s.sizes(), + s.strides(), + c10::optTypeMetaToScalarType(options.dtype_opt()), + options.layout_opt(), + options.device_opt(), + options.pinned_memory_opt())); + } else { + ret.emplace_back(at::native::empty_like(s)); + } + } else { + auto options = s.options(); + ret.emplace_back(at::native::empty_like( + s, + c10::optTypeMetaToScalarType(options.dtype_opt()), + options.layout_opt(), + options.device_opt(), + options.pinned_memory_opt(), + realized_memory_format)); + } + } + + at::native::foreach_tensor_copy_list_kernel_cuda_(ret, self); + return ret; +} + } // namespace at::native diff --git a/aten/src/ATen/native/cuda/FusedAdamKernel.cu b/aten/src/ATen/native/cuda/FusedAdamKernel.cu index 0858f24e17c6e..498dc634fab2b 100644 --- a/aten/src/ATen/native/cuda/FusedAdamKernel.cu +++ b/aten/src/ATen/native/cuda/FusedAdamKernel.cu @@ -29,10 +29,15 @@ void _fused_adam_kernel_cuda_( const bool maximize, const std::optional& grad_scale, const std::optional& found_inf) { + const bool is_mixed_precision = + params[0].scalar_type() != exp_avgs[0].scalar_type(); if (amsgrad) { TORCH_CHECK( at::native::check_fast_path_restrictions( - {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}), + {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}, + /*scalarList=*/{}, + /*does_op_promote_integer_inputs_to_float=*/false, + /*skip_cross_list_dtype_check=*/is_mixed_precision), "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout"); _fused_adam_amsgrad_cuda_impl_( params, @@ -52,7 +57,10 @@ void _fused_adam_kernel_cuda_( } else { TORCH_CHECK( at::native::check_fast_path_restrictions( - {params, grads, exp_avgs, exp_avg_sqs}), + {params, grads, exp_avgs, exp_avg_sqs}, + /*scalarList=*/{}, + /*does_op_promote_integer_inputs_to_float=*/false, + /*skip_cross_list_dtype_check=*/is_mixed_precision), "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout"); _fused_adam_cuda_impl_( params, @@ -125,10 +133,15 @@ void _fused_adam_kernel_cuda_( lr.device() == param_device, "lr must be on the same GPU device as the params"); + const bool is_mixed_precision = + params[0].scalar_type() != exp_avgs[0].scalar_type(); if (amsgrad) { TORCH_CHECK( at::native::check_fast_path_restrictions( - {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}), + {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}, + /*scalarList=*/{}, + /*does_op_promote_integer_inputs_to_float=*/false, + /*skip_cross_list_dtype_check=*/is_mixed_precision), "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout"); _fused_adam_amsgrad_cuda_impl_( params, @@ -148,7 +161,10 @@ void _fused_adam_kernel_cuda_( } else { TORCH_CHECK( at::native::check_fast_path_restrictions( - {params, grads, exp_avgs, exp_avg_sqs}), + {params, grads, exp_avgs, exp_avg_sqs}, + /*scalarList=*/{}, + /*does_op_promote_integer_inputs_to_float=*/false, + /*skip_cross_list_dtype_check=*/is_mixed_precision), "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout"); _fused_adam_cuda_impl_( params, diff --git a/aten/src/ATen/native/cuda/FusedAdamWKernel.cu b/aten/src/ATen/native/cuda/FusedAdamWKernel.cu index 4b758aa574c66..13e11df2a24e9 100644 --- a/aten/src/ATen/native/cuda/FusedAdamWKernel.cu +++ b/aten/src/ATen/native/cuda/FusedAdamWKernel.cu @@ -29,10 +29,15 @@ void _fused_adamw_kernel_cuda_( const bool maximize, const std::optional& grad_scale, const std::optional& found_inf) { + const bool is_mixed_precision = + params[0].scalar_type() != exp_avgs[0].scalar_type(); if (amsgrad) { TORCH_CHECK( at::native::check_fast_path_restrictions( - {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}), + {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}, + /*scalarList=*/{}, + /*does_op_promote_integer_inputs_to_float=*/false, + /*skip_cross_list_dtype_check=*/is_mixed_precision), "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout"); _fused_adamw_amsgrad_cuda_impl_( params, @@ -52,7 +57,10 @@ void _fused_adamw_kernel_cuda_( } else { TORCH_CHECK( at::native::check_fast_path_restrictions( - {params, grads, exp_avgs, exp_avg_sqs}), + {params, grads, exp_avgs, exp_avg_sqs}, + /*scalarList=*/{}, + /*does_op_promote_integer_inputs_to_float=*/false, + /*skip_cross_list_dtype_check=*/is_mixed_precision), "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout"); _fused_adamw_cuda_impl_( params, @@ -125,10 +133,15 @@ void _fused_adamw_kernel_cuda_( lr.device() == param_device, "lr must be on the same GPU device as the params"); + const bool is_mixed_precision = + params[0].scalar_type() != exp_avgs[0].scalar_type(); if (amsgrad) { TORCH_CHECK( at::native::check_fast_path_restrictions( - {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}), + {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}, + /*scalarList=*/{}, + /*does_op_promote_integer_inputs_to_float=*/false, + /*skip_cross_list_dtype_check=*/is_mixed_precision), "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout"); _fused_adamw_amsgrad_cuda_impl_( params, @@ -148,7 +161,10 @@ void _fused_adamw_kernel_cuda_( } else { TORCH_CHECK( at::native::check_fast_path_restrictions( - {params, grads, exp_avgs, exp_avg_sqs}), + {params, grads, exp_avgs, exp_avg_sqs}, + /*scalarList=*/{}, + /*does_op_promote_integer_inputs_to_float=*/false, + /*skip_cross_list_dtype_check=*/is_mixed_precision), "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout"); _fused_adamw_cuda_impl_( params, diff --git a/aten/src/ATen/native/cuda/GroupedBlas.cpp b/aten/src/ATen/native/cuda/GroupedBlas.cpp index e617c46d5cbce..5b393401a29b4 100644 --- a/aten/src/ATen/native/cuda/GroupedBlas.cpp +++ b/aten/src/ATen/native/cuda/GroupedBlas.cpp @@ -37,6 +37,7 @@ #else #include #include +#include #include #include #include @@ -132,10 +133,10 @@ _mx8_mx8_bf16_grouped_mm_mslk( scale_b, offs.value(), out); + return out; #else TORCH_CHECK_NOT_IMPLEMENTED(false, "mxfp8_mxfp8 grouped gemm requires compile with USE_MSLK"); #endif - return out; } // 2d-2d and 2d-3d cases @@ -152,7 +153,7 @@ _f8_f8_bf16_rowwise_grouped_mm_cuda( const bool use_fast_accum, Tensor& out) { TORCH_CHECK_VALUE(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type()); - TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_b.scalar_type()); + TORCH_CHECK_VALUE(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_b to be Float8_e4m3 matrix got ", mat_b.scalar_type()); at::cuda::detail::f8f8bf16_grouped_mm( mat_a, @@ -197,10 +198,10 @@ _f8_f8_bf16_rowwise_grouped_mm_rocm( scale_b, offs, out); + return out; #else TORCH_CHECK_NOT_IMPLEMENTED(false, "grouped gemm is not supported without USE_MSLK on ROCM") #endif - return out; } #endif // USE_ROCM @@ -291,11 +292,11 @@ _f4_f4_bf16_grouped_mm_mslk( out, combined_global_scale ); + + return out; #else TORCH_CHECK_NOT_IMPLEMENTED(false, "nvfp4 grouped gemm is not supported without USE_MSLK, and only for CUDA") #endif - - return out; } void _check_scales_fp8_rowwise(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) { diff --git a/aten/src/ATen/native/cuda/IndexKernelUtils.cu b/aten/src/ATen/native/cuda/IndexKernelUtils.cu index e44450c542e83..d9bf529f38db1 100644 --- a/aten/src/ATen/native/cuda/IndexKernelUtils.cu +++ b/aten/src/ATen/native/cuda/IndexKernelUtils.cu @@ -30,7 +30,7 @@ void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int constexpr int64_t max_num_threads=256; auto num_threads = at::round_up( at::ceil_div(slice_size_in_bytes, Alignment), - static_cast(C10_WARP_SIZE)); + static_cast(at::cuda::warp_size())); uint32_t grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; grid_y = std::min(static_cast(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), grid_y); dim3 grid = {static_cast(num_ind), grid_y, 1}; diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index 962f71431e727..063db29d38279 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -81,7 +81,7 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) { namespace at:: native { template -void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f) { +void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f, bool check_cast = true) { for (int arg = 0; arg < iter.ntensors(); arg++) { TORCH_INTERNAL_ASSERT( @@ -95,14 +95,22 @@ void gpu_kernel_nocast(TensorIteratorBase& iter, const func_t& f) { if (!iter.can_use_32bit_indexing()) { for (auto& sub_iter : iter.with_32bit_indexing()) { - gpu_kernel_nocast(sub_iter, f); + gpu_kernel_nocast(sub_iter, f, check_cast); } return; } + if (check_cast) { + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + } gpu_kernel_impl_nocast(iter, f); } +template +void gpu_kernel_opaque(TensorIteratorBase& iter, const func_t& f) { + gpu_kernel_nocast(iter, f, false); +} + template void gpu_kernel(TensorIteratorBase& iter, const func_t& f) { diff --git a/aten/src/ATen/native/cuda/MultiTensorApply.cuh b/aten/src/ATen/native/cuda/MultiTensorApply.cuh index 2fe431f778b1a..d11cfc19c5856 100644 --- a/aten/src/ATen/native/cuda/MultiTensorApply.cuh +++ b/aten/src/ATen/native/cuda/MultiTensorApply.cuh @@ -15,13 +15,35 @@ static constexpr int64_t kChunkSize = 65536; static constexpr int64_t kBlockSize = 512; // TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy` -// TensorListMetadata has to be < 4KB - the limit for kernel launch argument +// TensorListMetadata has to fit within the CUDA kernel launch argument limit. +// While CUDA 12.1, driver version R530+ and Volta+ would work with 32KB, we +// decide to be safe and only swap for CUDA 13+ during compile time. This saves +// binary size and will guarantees 32KB kernel arg space; older versions are +// still limited to 4KB. We adopt naive values for 32KB from +// https://github.com/pytorch/pytorch/pull/134373. +// TODO: The values for 32KB can very much be optimized further. +#if defined(CUDART_VERSION) && CUDART_VERSION >= 13000 && !defined(USE_ROCM) + +static constexpr int depth_to_max_tensors[5] = {770, 448, 336, 252, 210}; +static constexpr int depth_to_max_blocks[5] = {2240, 2240, 2240, 2240, 2240}; +static constexpr int depth_to_max_tensors_scalarlist[5] = + {672, 448, 336, 252, 210}; +static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { + 504, + 420}; +using block_index_t = uint16_t; + +#else + static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30}; static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320}; static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30}; static constexpr int depth_to_max_tensors_scalarlist_of_complex_double[2] = { 72, 60}; +using block_index_t = unsigned char; + +#endif template __device__ __forceinline__ bool is_aligned(T* p) { @@ -42,7 +64,7 @@ template struct TensorListMetadata { const void* addresses[n][depth_to_max_tensors[n - 1]]; int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + block_index_t block_to_tensor[depth_to_max_blocks[n - 1]]; int block_to_chunk[depth_to_max_blocks[n - 1]]; int start_tensor_this_launch; }; @@ -52,12 +74,12 @@ struct TensorListScalarListMetadata { const void* addresses[n][depth_to_max_tensors_scalarlist[n - 1]]; int64_t numel_for_tensor[depth_to_max_tensors_scalarlist[n - 1]]; scalar_vals_t scalar_vals[depth_to_max_tensors_scalarlist[n - 1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + block_index_t block_to_tensor[depth_to_max_blocks[n - 1]]; int block_to_chunk[depth_to_max_blocks[n - 1]]; }; -// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size of -// 4kb with `c10::complex` +// note(mkozuki): `n` of 1&2 violate the limit of cuda kernel argument size +// with `c10::complex` template <> struct TensorListScalarListMetadata, 1> { const void* addresses[1] @@ -66,7 +88,7 @@ struct TensorListScalarListMetadata, 1> { numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[0]]; c10::complex scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[0]]; - unsigned char block_to_tensor[depth_to_max_blocks[1 - 1]]; + block_index_t block_to_tensor[depth_to_max_blocks[1 - 1]]; int block_to_chunk[depth_to_max_blocks[1 - 1]]; }; @@ -78,19 +100,22 @@ struct TensorListScalarListMetadata, 2> { numel_for_tensor[depth_to_max_tensors_scalarlist_of_complex_double[1]]; c10::complex scalar_vals[depth_to_max_tensors_scalarlist_of_complex_double[1]]; - unsigned char block_to_tensor[depth_to_max_blocks[2 - 1]]; + block_index_t block_to_tensor[depth_to_max_blocks[2 - 1]]; int block_to_chunk[depth_to_max_blocks[2 - 1]]; }; // NOTE(crcrpar): This is a conservative resolution to handle `state_steps` // whose each element is `at::Tensor` of 1 element representing the number of // `step`s called so far. +// We're aware this struct overflows the kernel arg limit at n=1 (4244 bytes), +// but our current fused optimizers only instantiate at n>=4 so it's not a +// concern (yet). template struct FusedOptimizerTensorListMetadata { const void* addresses[n][depth_to_max_tensors[n - 1]]; int64_t numel_for_tensor[depth_to_max_tensors[n - 1]]; - const void* state_steps_addresses[depth_to_max_tensors_scalarlist[n - 1]]; - unsigned char block_to_tensor[depth_to_max_blocks[n - 1]]; + const void* state_steps_addresses[depth_to_max_tensors[n - 1]]; + block_index_t block_to_tensor[depth_to_max_blocks[n - 1]]; int block_to_chunk[depth_to_max_blocks[n - 1]]; int start_tensor_this_launch; }; diff --git a/aten/src/ATen/native/cuda/Normalization.cuh b/aten/src/ATen/native/cuda/Normalization.cuh index 8e31f8fa9a694..e4fa548a1bd62 100644 --- a/aten/src/ATen/native/cuda/Normalization.cuh +++ b/aten/src/ATen/native/cuda/Normalization.cuh @@ -54,7 +54,7 @@ __device__ __forceinline__ int getMSB(int val) { template struct Float2 { accscalar_t v1, v2; - __device__ Float2() {} + __device__ Float2() = default; __device__ Float2(scalar_t v1, scalar_t v2) : v1(static_cast(v1)), v2(static_cast(v2)) {} __device__ Float2(int v) : v1(static_cast(v)), v2(static_cast(v)) {} __device__ Float2& operator+=(const Float2& a) { @@ -134,7 +134,7 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) { } #endif } - __shared__ scalar_t shared[C10_WARP_SIZE]; + __shared__ scalar_t shared[C10_WARP_SIZE_UPPER_BOUND]; SumReduceOp reduce_op; sum = cuda_utils::BlockReduce, cuda_utils::Block2D>(sum, reduce_op, 0, shared); if (threadIdx.x == 0 && threadIdx.y == 0) { @@ -288,7 +288,7 @@ __global__ void batch_norm_collect_statistics_kernel( GenericPackedTensorAccessor save_mean, GenericPackedTensorAccessor save_transformed_var) { - __shared__ int shared_n[2 * 2 * C10_WARP_SIZE + C10_WARP_SIZE]; + __shared__ int shared_n[2 * 2 * C10_WARP_SIZE_UPPER_BOUND + C10_WARP_SIZE_UPPER_BOUND]; int plane = blockIdx.x; int N = input.size(0) * input.size(2); @@ -750,7 +750,7 @@ void batch_norm_elemt_cuda_template(const Tensor& output_, const Tensor& input_, // The input_transform kernel is pointwise, but we need to balance reading parameters (save_var/mean, // weight/bias) - which we only do once and have a for loop afterwards - with having many threads and blocks - // and good occupancy. Quiet likely, we could go with even more blocks than 1024. + // and good occupancy. Quite likely, we could go with even more blocks than 1024. // The various planes are independent, so we use blocks for them. int tf = std::max(getNumThreads(input.size(2)/4), std::min(getNumThreads(input.size(2)), 64)); diff --git a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh index f0871fa0ead6f..fee3d0a5aba1d 100644 --- a/aten/src/ATen/native/cuda/PersistentSoftmax.cuh +++ b/aten/src/ATen/native/cuda/PersistentSoftmax.cuh @@ -64,12 +64,12 @@ __device__ __forceinline__ void warp_reduce(acc_t* sum) { // input_t=half, acc_t=float, output_t=float => read half tensor, float accumulators, write float tensor. // input_t_float, acc_t=float, output_t=half => read float tensor, float accumulators, write half tensor. -template +template __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batch_size, int stride, int element_count, const bool *mask = nullptr, const int head_chunk_size = -1, bool is_transformer_mask = false) { // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_forward_kernel. constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_SIZE = (next_power_of_two < WARP_SIZE_PARAM) ? next_power_of_two : WARP_SIZE_PARAM; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; @@ -211,12 +211,12 @@ __global__ void softmax_warp_forward(output_t *dst, const input_t *src, int batc } } -template +template __global__ void softmax_warp_backward(output_t *gradInput, const input_t *grad, const input_t *output, int batch_size, int stride, int element_count, const bool *mask = nullptr) { // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and warp_size of method warp_softmax_backward_kernel. constexpr int next_power_of_two = 1 << log2_elements; - constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_SIZE = (next_power_of_two < WARP_SIZE_PARAM) ? next_power_of_two : WARP_SIZE_PARAM; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; @@ -325,12 +325,29 @@ void dispatch_softmax_forward(output_t *dst, const input_t *src, int softmax_ele dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { +#ifdef USE_ROCM + // To support ROCm amdgcnspirv target, we must compile both a 32 and 64 warpSize version of each kernel #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E: \ - softmax_warp_forward \ + if (warp_size == 64) { \ + softmax_warp_forward \ + <<>>(dst, \ + src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask); \ + } \ + else { \ + softmax_warp_forward \ + <<>>(dst, \ + src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask); \ + } \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + break; +#else + #define LAUNCH_SOFTMAX_WARP_FORWARD(L2E) case L2E: \ + softmax_warp_forward \ <<>>(dst, \ src, batch_count, softmax_elements_stride, softmax_elements, mask, chunk_size, is_transformer_mask); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); \ break; +#endif LAUNCH_SOFTMAX_WARP_FORWARD(0); // 1 LAUNCH_SOFTMAX_WARP_FORWARD(1); // 2 @@ -376,13 +393,32 @@ void dispatch_softmax_backward(output_t *grad_input, const input_t *grad, const dim3 threads(warp_size, warps_per_block, 1); // Launch code would be more elegant if C++ supported FOR CONSTEXPR switch (log2_elements) { +#ifdef USE_ROCM + // To support ROCm amdgcnspirv target, we must compile both a 32 and 64 warpSize version of each kernel + #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E: \ + if (warp_size == 64) { \ + softmax_warp_backward \ + <<>> \ + (grad_input, grad, output, batch_count, softmax_elements_stride, \ + softmax_elements, mask); \ + } \ + else { \ + softmax_warp_backward \ + <<>> \ + (grad_input, grad, output, batch_count, softmax_elements_stride, \ + softmax_elements, mask); \ + } \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + break; +#else #define LAUNCH_SOFTMAX_WARP_BACKWARD(L2E) case L2E: \ - softmax_warp_backward \ + softmax_warp_backward \ <<>> \ (grad_input, grad, output, batch_count, softmax_elements_stride, \ softmax_elements, mask); \ C10_CUDA_KERNEL_LAUNCH_CHECK(); \ break; +#endif LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1 LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2 diff --git a/aten/src/ATen/native/cuda/PhiloxDistribution.cu b/aten/src/ATen/native/cuda/PhiloxDistribution.cu new file mode 100644 index 0000000000000..193ea9618e748 --- /dev/null +++ b/aten/src/ATen/native/cuda/PhiloxDistribution.cu @@ -0,0 +1,322 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +namespace at::native { + +namespace { + +using at::cuda::philox_4x32; + +// Elements produced per Philox 4x32 call: 4 for float/half/bfloat16, 2 for double. +// Note that we use a full float for each generated half/bfloat16 for better numerics. +template +constexpr int elems_per_call = std::is_same_v ? 2 : 4; + +// Box-Muller: convert 4 uniform uint32 values into 4 standard normal floats. +__device__ __forceinline__ float4 box_muller_float(uint4 r) { + constexpr float M = 2.3283064365386963e-10f; // 1/2^32 + constexpr float TWO_PI = 6.2831853071795864f; + // Map to (0, 1] to avoid log(0). + float u1 = fmaf(r.x, M, M * 0.5f); + float u2 = fmaf(r.y, M, M * 0.5f); + float u3 = fmaf(r.z, M, M * 0.5f); + float u4 = fmaf(r.w, M, M * 0.5f); + + float radius1 = sqrtf(-2.0f * __logf(u1)); + float radius2 = sqrtf(-2.0f * __logf(u3)); + float s1, c1, s2, c2; + __sincosf(TWO_PI * u2, &s1, &c1); + __sincosf(TWO_PI * u4, &s2, &c2); + return {radius1 * c1, radius1 * s1, radius2 * c2, radius2 * s2}; +} + +// Box-Muller: convert 4 uint32 values (packed into 2 uint64) into 2 standard +// normal doubles. +__device__ __forceinline__ double2 box_muller_double(uint4 r) { + constexpr double M = 2.3283064365386963e-10; // 1/2^32 + constexpr double TWO_PI = 6.2831853071795864; + // Pack pairs of uint32 for ~64 bits of uniform randomness. + double u1 = fma(static_cast(r.x), M, + static_cast(r.y) * M * M + M * M * 0.5); + double u2 = fma(static_cast(r.z), M, + static_cast(r.w) * M * M + M * M * 0.5); + + double radius = ::sqrt(-2.0 * ::log(u1)); + double s, c; + ::sincos(TWO_PI * u2, &s, &c); + return {radius * c, radius * s}; +} + +// Single-key kernel: one thread per chunk of elements, where each chunk +// comes from a single Philox 4x32 call. Uses vectorized stores for full +// chunks and scalar writes for the tail. +template +__global__ void philox_single_key_kernel( + scalar_t* __restrict__ output, + const uint64_t* __restrict__ key, + int64_t num_elems, + sample_t sample_func, + param_t param_func) { + + // Use vectorized load to get (seed, offset) + auto key_vec = memory::ld_vec<16>(key); + auto* key_vals = reinterpret_cast(&key_vec); + uint64_t seed = key_vals[0]; + uint64_t offset = key_vals[1]; + + // Use vectorized stores for full chunks since they're aligned. + constexpr int epc = elems_per_call; + int64_t num_full_chunks = num_elems / epc; + int64_t chunk = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (chunk < num_full_chunks) { + auto sample = sample_func(seed, offset + static_cast(chunk)); + constexpr int vec_bytes = epc * sizeof(scalar_t); + memory::Vec v; + auto* vals = reinterpret_cast(&v); + #pragma unroll + for (int j = 0; j < epc; j++) { + vals[j] = param_func((&sample.x)[j]); + } + memory::st_vec(output + chunk * epc, v); + } + + // Scalar tail for remaining elements. + if (chunk == num_full_chunks) { + int64_t tail_start = num_full_chunks * epc; + auto sample = sample_func(seed, offset + static_cast(num_full_chunks)); + for (int j = 0; j < num_elems - tail_start; j++) { + output[tail_start + j] = param_func((&sample.x)[j]); + } + } +} + +// Multi-key kernel: one thread per (key_idx, chunk) pair, where each chunk +// comes from a single Philox 4x32 call. Uses vectorized stores for full +// chunks and scalar writes for the tail. +template +__global__ void philox_multi_key_kernel( + scalar_t* __restrict__ output, + const uint64_t* __restrict__ keys, + int64_t num_keys, + int64_t elems_per_key, + sample_t sample_func, + param_t param_func, + OffsetCalculator<1> key_offset_calc) { + constexpr int epc = elems_per_call; + int64_t chunks_per_key = (elems_per_key + epc - 1) / epc; + int64_t total_threads = num_keys * chunks_per_key; + int64_t tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid >= total_threads) return; + + // Determine correct (seed, offset) to use and sample. + int64_t key_idx = tid / chunks_per_key; + int64_t chunk = tid % chunks_per_key; + auto elem_offset = key_offset_calc.get(key_idx)[0]; + uint64_t seed = keys[elem_offset]; + uint64_t offset = keys[elem_offset + 1]; + auto sample = sample_func(seed, offset + static_cast(chunk)); + + // Vectorized writes require aligned base addresses. This is guaranteed + // when elems_per_key is a multiple of epc, since + // base = key_idx * elems_per_key + chunk * epc. + int64_t full_chunks_per_key = elems_per_key / epc; + bool aligned = elems_per_key % epc == 0; + int64_t base = key_idx * elems_per_key + chunk * epc; + if (aligned && chunk < full_chunks_per_key) { + constexpr int vec_bytes = epc * sizeof(scalar_t); + memory::Vec v; + auto* vals = reinterpret_cast(&v); + #pragma unroll + for (int j = 0; j < epc; j++) { + vals[j] = param_func((&sample.x)[j]); + } + memory::st_vec(output + base, v); + } else { + for (int j = 0; j < epc && chunk * epc + j < elems_per_key; j++) { + output[base + j] = param_func((&sample.x)[j]); + } + } +} + +// Dispatches to single-key or multi-key kernels as needed. +template +void philox_distribution_kernel( + const char* op_name, + Tensor& self, const Tensor& key, + const sample_t& sample_func, const param_t& param_func) { + TORCH_CHECK(self.is_floating_point(), + op_name, ": self must be a floating point tensor, got ", + self.scalar_type()); + TORCH_CHECK(key.scalar_type() == kUInt64, + op_name, ": key must have dtype uint64, got ", + key.scalar_type()); + TORCH_CHECK(self.device() == key.device(), + op_name, ": self and key must be on the same device, got ", + self.device(), " and ", key.device()); + TORCH_CHECK(key.dim() >= 1 && key.size(-1) == 2, + op_name, ": key must have shape (2,) or (*batch, 2), got shape ", + key.sizes()); + if (key.dim() > 1) { + TORCH_CHECK(key.dim() == self.dim() + 1, + op_name, ": batched key must have ndim == output ndim + 1, " + "got key shape ", key.sizes(), " with output shape ", self.sizes()); + auto key_batch = key.sizes().slice(0, self.dim()); + TORCH_CHECK(is_expandable_to(key_batch, self.sizes()), + op_name, ": key batch shape ", key_batch, + " is not broadcastable with output shape ", self.sizes()); + } + + if (self.numel() == 0) { + return; + } + + // Ensure contiguous, aligned output for vectorized stores. Clone if needed + // to ensure alignment; the result is copied back into self afterwards. + constexpr int vec_bytes = elems_per_call * sizeof(scalar_t); + auto output = self.contiguous(); + if (reinterpret_cast(output.data_ptr()) % vec_bytes != 0) { + output = output.clone(); + } + + constexpr int block_size = 256; + + if (key.dim() == 1) { + // === Launch single key kernel === + constexpr int epc = elems_per_call; + int64_t num_chunks = (self.numel() + epc - 1) / epc; + int num_blocks = static_cast((num_chunks + block_size - 1) / block_size); + + auto key_contig = key.contiguous(); + philox_single_key_kernel + <<>>( + output.mutable_data_ptr(), + key_contig.data_ptr(), + self.numel(), sample_func, param_func); + } else { + // === Launch batched (multiple) key kernel === + // The kernel writes each key's output as a contiguous block of + // elems_per_key elements. We determine elems_per_key by counting + // trailing size-1 key dims; these are the output dimensions that a + // single key generates over. For example, with key shape (4, 1, 1, 2) + // and output shape (4, 10, 100): key_dims=1, elems_per_key=1000. + int64_t elems_per_key = 1; + int64_t key_dims = self.dim(); + for (int64_t i = self.dim() - 1; i >= 0; i--) { + if (key.size(i) != 1) break; + elems_per_key *= self.size(i); + key_dims--; + } + int64_t num_keys = self.numel() / elems_per_key; + + // Handle key, self broadcasting via OffsetCalculator. + c10::SmallVector oc_sizes(key_dims); + c10::SmallVector oc_strides(key_dims); + for (int64_t i = 0; i < key_dims; i++) { + int64_t dim = key_dims - 1 - i; + oc_sizes[i] = self.size(dim); + oc_strides[i] = key.size(dim) > 1 ? key.stride(dim) : 0; + } + const int64_t* oc_strides_ptr = oc_strides.data(); + auto key_offset_calc = OffsetCalculator<1>( + key_dims, oc_sizes.data(), &oc_strides_ptr); + + int64_t chunks_per_key = + (elems_per_key + elems_per_call - 1) / elems_per_call; + int64_t total_threads = num_keys * chunks_per_key; + int num_blocks = static_cast((total_threads + block_size - 1) / block_size); + + philox_multi_key_kernel + <<>>( + output.mutable_data_ptr(), + key.data_ptr(), + num_keys, elems_per_key, + sample_func, param_func, key_offset_calc); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + if (output.data_ptr() != self.data_ptr()) { + self.copy_(output); + } +} + +} // anonymous namespace + +Tensor& _philox_uniform_cuda_( + Tensor& self, const Tensor& key, double low, double high) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, self.scalar_type(), "_philox_uniform_", [&] { + auto sample_func = []() { + if constexpr (std::is_same_v) { + return [] __device__ (uint64_t seed, uint64_t offset) { + uint4 r = philox_4x32(seed, offset); + ulonglong2 packed; + packed.x = (static_cast(r.x) << 32) | r.y; + packed.y = (static_cast(r.z) << 32) | r.w; + return packed; + }; + } else { + return [] __device__ (uint64_t seed, uint64_t offset) { + return philox_4x32(seed, offset); + }; + } + }(); + + auto lo = static_cast(low); + auto hi = static_cast(high); + auto param_func = [lo, hi] __device__ (auto rand) { + return static_cast( + at::transformation::uniform_real(rand, lo, hi)); + }; + + philox_distribution_kernel( + "_philox_uniform_", self, key, sample_func, param_func); + }); + return self; +} + +Tensor& _philox_normal_cuda_( + Tensor& self, const Tensor& key, double mean, double stddev) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, kBFloat16, self.scalar_type(), "_philox_normal_", [&] { + using compute_t = std::conditional_t, double, float>; + auto sample_func = []() { + if constexpr (std::is_same_v) { + return [] __device__ (uint64_t seed, uint64_t offset) { + return box_muller_double(philox_4x32(seed, offset)); + }; + } else { + return [] __device__ (uint64_t seed, uint64_t offset) { + return box_muller_float(philox_4x32(seed, offset)); + }; + } + }(); + + auto mu = static_cast(mean); + auto sigma = static_cast(stddev); + auto param_func = [mu, sigma] __device__ (compute_t rand) { + return static_cast(rand * sigma + mu); + }; + + philox_distribution_kernel( + "_philox_normal_", self, key, sample_func, param_func); + }); + return self; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/cuda/PhiloxKeySplit.cu b/aten/src/ATen/native/cuda/PhiloxKeySplit.cu new file mode 100644 index 0000000000000..07893c061ee0d --- /dev/null +++ b/aten/src/ATen/native/cuda/PhiloxKeySplit.cu @@ -0,0 +1,140 @@ +#define TORCH_ASSERT_ONLY_METHOD_OPERATORS + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +namespace at::native { + +namespace { + +using at::cuda::philox_4x32; + +// Derive a new (seed, offset) key from 4 random uint32 values. +// Use 2 uint32s for the 64-bit seed and 2 for the 64-bit offset. +__device__ __forceinline__ void philox_derive_key( + uint4 r, + uint64_t* out_seed, + uint64_t* out_offset) { + *out_seed = static_cast(r.x) | (static_cast(r.y) << 32); + *out_offset = static_cast(r.z) | (static_cast(r.w) << 32); +} + +// Grid-stride loop over (split_idx, key_idx) pairs. +__global__ void philox_key_split_kernel( + const uint64_t* __restrict__ input, + uint64_t* __restrict__ output, + int64_t num_keys, + int64_t num_splits) { + int64_t total = num_keys * num_splits; + int64_t tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + for (; tid < total; tid += static_cast(gridDim.x) * blockDim.x) { + int64_t split_idx = tid / num_keys; + int64_t key_idx = tid % num_keys; + + uint64_t seed = input[key_idx * 2]; + uint64_t offset = input[key_idx * 2 + 1]; + + // Sample randomness to get the next (seed, offset pair). + uint4 r = philox_4x32(seed, offset + static_cast(split_idx)); + int64_t out = (split_idx * num_keys + key_idx) * 2; + philox_derive_key(r, &output[out], &output[out + 1]); + } +} + +__global__ void philox_key_fold_in_kernel( + const uint64_t* __restrict__ input, + uint64_t* __restrict__ output, + int64_t num_keys, + int64_t data) { + int64_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + for (; idx < num_keys; idx += static_cast(gridDim.x) * blockDim.x) { + uint64_t seed = input[idx * 2]; + uint64_t offset = input[idx * 2 + 1]; + + // Sample randomness to get the next (seed, offset pair). + uint4 r = philox_4x32(seed, offset + static_cast(data)); + philox_derive_key(r, &output[idx * 2], &output[idx * 2 + 1]); + } +} + +} // anonymous namespace + +Tensor _philox_key_split_cuda(const Tensor& key, int64_t num_splits) { + TORCH_CHECK(key.dim() >= 1 && key.size(-1) == 2, + "_philox_key_split: key must have shape (*batch, 2), got shape ", + key.sizes()); + TORCH_CHECK(key.scalar_type() == kUInt64, + "_philox_key_split: key must have dtype uint64, got ", + key.scalar_type()); + TORCH_CHECK(num_splits > 0, + "_philox_key_split: num_splits must be positive, got ", + num_splits); + + // Output shape: (num_splits, *key.shape) + auto output_sizes = key.sizes().vec(); + output_sizes.insert(output_sizes.begin(), num_splits); + Tensor output = at::empty(output_sizes, key.options()); + int64_t num_keys = key.numel() / 2; + if (num_keys == 0) { + return output; + } + + int64_t total_threads = num_keys * num_splits; + constexpr int block_size = 256; + int num_blocks = std::min( + static_cast((total_threads + block_size - 1) / block_size), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 4); + + auto key_contig = key.contiguous(); + philox_key_split_kernel<<>>( + key_contig.data_ptr(), + output.data_ptr(), + num_keys, num_splits); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return output; +} + +Tensor _philox_key_fold_in_cuda(const Tensor& key, int64_t data) { + TORCH_CHECK(key.dim() >= 1 && key.size(-1) == 2, + "_philox_key_fold_in: key must have shape (*batch, 2), got shape ", + key.sizes()); + TORCH_CHECK(key.scalar_type() == kUInt64, + "_philox_key_fold_in: key must have dtype uint64, got ", + key.scalar_type()); + + Tensor output = at::empty_like(key); + int64_t num_keys = key.numel() / 2; + if (num_keys == 0) { + return output; + } + + constexpr int block_size = 256; + int num_blocks = std::min( + static_cast((num_keys + block_size - 1) / block_size), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount * 4); + + auto key_contig = key.contiguous(); + philox_key_fold_in_kernel<<>>( + key_contig.data_ptr(), + output.data_ptr(), + num_keys, data); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return output; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index ee3a05c854540..5d973274928c2 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -635,10 +635,10 @@ struct ReduceOp { using args_vec_t = std::array; int dim_x = blockDim.x; args_vec_t* shared = (args_vec_t*)shared_memory; - if (dim_x > warpSize) { + if (dim_x > C10_WARP_SIZE) { int address_base = threadIdx.x + threadIdx.y*blockDim.x; shared[address_base] = value; - for (int offset = dim_x/2; offset >= warpSize; offset >>= 1) { + for (int offset = dim_x/2; offset >= C10_WARP_SIZE; offset >>= 1) { __syncthreads(); if (threadIdx.x < offset && threadIdx.x + offset < blockDim.x) { args_vec_t other = shared[address_base + offset]; @@ -649,7 +649,7 @@ struct ReduceOp { shared[address_base] = value; } } - dim_x = warpSize; + dim_x = C10_WARP_SIZE; } __syncthreads(); @@ -836,9 +836,9 @@ struct ReduceOp { } } } else { +#if defined(USE_ROCM) && ROCM_VERSION <= 71300 index_t input_offset = threadIdx.y; index_t step = blockDim.y; -#ifdef USE_ROCM // Prefetch loads to better hide their latency #define PRFCH 4 for (; input_offset < config.ctas_per_output; input_offset += step*PRFCH) { arg_vec_t next[PRFCH]; @@ -855,6 +855,14 @@ struct ReduceOp { } } #else +#if defined(USE_ROCM) + int input_offset = threadIdx.y; + int step = blockDim.y; + #pragma unroll +#else + index_t input_offset = threadIdx.y; + index_t step = blockDim.y; +#endif for (; input_offset < config.ctas_per_output; input_offset += step) { index_t idx = config.staging_memory_offset(input_offset); arg_vec_t next = reduce_buffer[idx]; @@ -969,7 +977,7 @@ inline void launch_jitted_reduce_kernel( class AccumulationBuffer { public: - AccumulationBuffer() {} + AccumulationBuffer() = default; AccumulationBuffer(size_t acc_t_size, size_t out_t_size, char* out_ptr, int64_t size) { out_ptr_ = (char*)out_ptr; @@ -1117,7 +1125,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ if (iter.ndim() == 0 || reduction_on_fastest_striding_dimension) { // Split the input across lanes if the input is contiguous in the reduced // dimension. This will require reduction between threads using warp - // shuffle instructions and shared memory (if block_width > warpSize). + // shuffle instructions and shared memory (if block_width > C10_WARP_SIZE). config.input_mult[0] = config.split_input(block_width); } else { // Otherwise split the output across lanes in a warp. diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index 967f522c2eb8f..4bafe771c7e61 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -198,7 +198,7 @@ void f8f8bf16_rowwise_impl( cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, - DtypeOutput, + void, // Indicate there is no beta scaling to save register LayoutOutput, AlignmentOutput, DtypeOutput, @@ -255,7 +255,7 @@ void f8f8bf16_rowwise_impl( : nullptr}, {{reinterpret_cast(w_scale.data_ptr())}, {{reinterpret_cast(x_scale.data_ptr())}}}}}, - reinterpret_cast(out.data_ptr()), + nullptr, stride_output, reinterpret_cast(out.data_ptr()), stride_output}}; @@ -288,7 +288,7 @@ void f8f8bf16_rowwise_impl( } // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm.initialize(arguments, workspace.data_ptr()); + status = gemm.initialize(arguments, workspace.data_ptr(), at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } @@ -390,7 +390,7 @@ void f8f8bf16_rowwise_impl_sm100_sm120( TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeEpilogue, - DtypeOutput, LayoutOutput, AlignmentOutput, + void, LayoutOutput, AlignmentOutput, DtypeOutput, LayoutOutput, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp; @@ -448,7 +448,7 @@ void f8f8bf16_rowwise_impl_sm100_sm120( : nullptr}, {{reinterpret_cast(w_scale.data_ptr())}, {{reinterpret_cast(x_scale.data_ptr())}}}}}, - reinterpret_cast(out.data_ptr()), + nullptr, stride_output, reinterpret_cast(out.data_ptr()), stride_output}}; @@ -481,7 +481,7 @@ void f8f8bf16_rowwise_impl_sm100_sm120( } // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm.initialize(arguments, workspace.data_ptr()); + status = gemm.initialize(arguments, workspace.data_ptr(), at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } @@ -697,7 +697,7 @@ void f8f8bf16_rowwise_impl_sm89( } // Initialize CUTLASS kernel with arguments and workspace pointer - status = gemm.initialize(arguments, workspace.data_ptr()); + status = gemm.initialize(arguments, workspace.data_ptr(), at::cuda::getCurrentCUDAStream()); if (status != cutlass::Status::kSuccess) { throw std::runtime_error("cutlass cannot initialize"); } diff --git a/aten/src/ATen/native/cuda/ScaledBlas.cpp b/aten/src/ATen/native/cuda/ScaledBlas.cpp index 6f33092b99eea..047aec2c8334a 100644 --- a/aten/src/ATen/native/cuda/ScaledBlas.cpp +++ b/aten/src/ATen/native/cuda/ScaledBlas.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -205,8 +206,7 @@ bool is_desired_scaling(const at::Tensor& t, const at::Tensor& scale, ScalingTyp case ScalingType::BlockWise128x128: return is_blockwise_128x128_scaling(t, scale); default: - TORCH_CHECK(false); - return false; + TORCH_CHECK(false, "Unknown scaling type"); } } @@ -827,7 +827,7 @@ _scaled_block1x128_block1x128( scale_a.stride(0) == 1 && ( scale_a.stride(1) == M || - (scale_a.size(1) == 1 && scale_b.stride(1) == 1) + (scale_a.size(1) == 1 && scale_a.stride(1) == 1) ), "scale_a strides must be (", 1, ", ", M, "); got: ", scale_a.strides() ); @@ -849,7 +849,7 @@ _scaled_block1x128_block1x128( scale_b.stride(1) == 1 ) ), - "scale_b strides must be (", 1, ", ", N, "); got: ", scale_a.strides() + "scale_b strides must be (", 1, ", ", N, "); got: ", scale_b.strides() ); auto scaling_choice_a = ScalingType::BlockWise1x128; @@ -987,7 +987,7 @@ _scaled_block1x128_block128x128( scale_a.stride(1) == 1 ) ), - "scale_a must have strides (1, ", M, "); got ", scale_b.strides() + "scale_a must have strides (1, ", M, "); got ", scale_a.strides() ); // scale_b shape TORCH_CHECK_VALUE( @@ -1229,10 +1229,6 @@ _scaled_nvfp4_nvfp4( void check_swizzle_lengths(ScaledGemmImplementation impl, std::vector& swizzle_a, std::vector& swizzle_b) { -#ifdef ROCM - // ROCM doesn't swizzle their formats - we don't care what's passed. - return; -#else // Store implementations that care about swizzling, and how many swizzle arguments // they have to have // NOTE(slayton): auto here is unable to deduce the correct type.. @@ -1249,13 +1245,45 @@ void check_swizzle_lengths(ScaledGemmImplementation impl, if (impl != check_impl) { continue; } - TORCH_CHECK_VALUE(swizzle_a.size() == num_args, "swizzle_a must have ", num_args, " values, got ", swizzle_a.size()); - TORCH_CHECK_VALUE(swizzle_b.size() == num_args, "swizzle_b must have ", num_args, " values, got ", swizzle_b.size()); +#ifdef USE_ROCM + if ( + check_impl != ScaledGemmImplementation::MXFP8_MXFP8 && + check_impl != ScaledGemmImplementation::MXFP4_MXFP4) { + // ROCm currently does not support NVFP4 paths. + break; + } + TORCH_CHECK_VALUE( + swizzle_a.size() == 1 && swizzle_b.size() == 1, + "For ROCM MX gemm, swizzle_a and swizzle_b must each have 1 value, got ", + swizzle_a.size(), + " and ", + swizzle_b.size()); + TORCH_CHECK_VALUE( + swizzle_a[0] == SwizzleType::NO_SWIZZLE && + swizzle_b[0] == SwizzleType::NO_SWIZZLE, + "For ROCM MX gemm, swizzle_a and swizzle_b must both be NO_SWIZZLE"); +#else + TORCH_CHECK_VALUE( + swizzle_a.size() == num_args, + "swizzle_a must have ", + num_args, + " value", + num_args == 1 ? "" : "s", + ", got ", + swizzle_a.size()); + TORCH_CHECK_VALUE( + swizzle_b.size() == num_args, + "swizzle_b must have ", + num_args, + " value", + num_args == 1 ? "" : "s", + ", got ", + swizzle_b.size()); +#endif // No need to check anything else break; } -#endif } }; // anonymous namespace diff --git a/aten/src/ATen/native/cuda/ScaledGroupMM.cu b/aten/src/ATen/native/cuda/ScaledGroupMM.cu index a276579b15fd2..69a1a27c76bb1 100644 --- a/aten/src/ATen/native/cuda/ScaledGroupMM.cu +++ b/aten/src/ATen/native/cuda/ScaledGroupMM.cu @@ -192,7 +192,7 @@ void f8f8bf16_grouped_gemm_impl_sm90( cutlass::epilogue::collective::EpilogueTileAuto, DtypeAccum, DtypeAccum, - DtypeOutput, + void, // Indicate there is no beta scaling to save register LayoutOutput*, AlignmentOutput, DtypeOutput, @@ -390,7 +390,7 @@ void f8f8bf16_grouped_gemm_impl_sm90( (const DtypeB**)inputB_ptrs, stride_B}, {{{{inputB_scale_ptrs}, {{inputA_scale_ptrs}, {}, {}}, {}}, {}}, - (const DtypeOutput**)output_ptrs, + nullptr, stride_output, output_ptrs, stride_output}}; diff --git a/aten/src/ATen/native/cuda/Sort.cu b/aten/src/ATen/native/cuda/Sort.cu index bae17a6a01236..a05ded2e2bb0f 100644 --- a/aten/src/ATen/native/cuda/Sort.cu +++ b/aten/src/ATen/native/cuda/Sort.cu @@ -107,7 +107,7 @@ struct SmallBitonicSort { // For small sorts (n <= 128) we use warpMergeSortKVInPlace which // sorts one slice per warp and potentially multiple slices in the // same block for improved occupancy with large batch sizes. -template +template struct WarpMergeSort { template @@ -128,7 +128,7 @@ struct WarpMergeSort { const auto min_grid = minimum_grid_for_occupancy( warpMergeSortKVInPlace< A, -1, sort_size, max_block_y, - K, V, LTOp, IndexType>, + K, V, LTOp, IndexType, WARP_SIZE>, block_x * max_block_y); const auto max_batch = std::max(IndexType{1}, keySlices / min_grid); const int block_y = std::min(IndexType(max_block_y), max_batch); @@ -142,7 +142,9 @@ struct WarpMergeSort { if (descending) { const K invalid_key = at::numeric_limits::lower_bound(); - warpMergeSortKVInPlace + warpMergeSortKVInPlace< + A, -1, sort_size, max_block_y, + K, V, GTOp, IndexType, WARP_SIZE> <<>>( keyInfo, keySlices, @@ -161,7 +163,9 @@ struct WarpMergeSort { } return at::numeric_limits::upper_bound(); }(); - warpMergeSortKVInPlace + warpMergeSortKVInPlace< + A, -1, sort_size, max_block_y, + K, V, LTOp, IndexType, WARP_SIZE> <<>>( keyInfo, keySlices, @@ -375,7 +379,16 @@ void sortKeyValueInplace( sortCommon(SmallBitonicSort{}, key, value, dim, descending); #if HAS_WARP_MERGE_SORT() } else if (sort_size <= 128) { - sortCommon(WarpMergeSort<128>{}, key, value, dim, descending); +#ifdef USE_ROCM + if (at::cuda::warp_size() == 32) { + sortCommon(WarpMergeSort<128, 32>{}, key, value, dim, descending); + } + else { + sortCommon(WarpMergeSort<128, 64>{}, key, value, dim, descending); + } +#else + sortCommon(WarpMergeSort<128, C10_WARP_SIZE>{}, key, value, dim, descending); +#endif #endif } else { sortCommon(MediumRadixSort{}, key, value, dim, descending); diff --git a/aten/src/ATen/native/cuda/SortUtils.cuh b/aten/src/ATen/native/cuda/SortUtils.cuh index 8e424707c7ef6..8bcb0ae323cfc 100644 --- a/aten/src/ATen/native/cuda/SortUtils.cuh +++ b/aten/src/ATen/native/cuda/SortUtils.cuh @@ -165,15 +165,18 @@ bitonicSortKVInPlace(at::cuda::detail::TensorInfo keys, #if HAS_WARP_MERGE_SORT() +// Note [warp merge sort WARP_SIZE template param] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// warpMergeSortKVInPlace was written assuming C10_WARP_SIZE is a constexpr. +// In torch/headeronly/macros/Macros.h, C10_WARP_SIZE is 32 for CUDA, and on +// ROCm it will be 32 or 64 based on the current compile-time gfx target. +// Ideally, warpSize should be used instead of C10_WARP_SIZE in device code, but +// C10_WARP_SIZE within this kernel has been used as a template parameter for +// some device functions. Therefore, a template param for WARP_SIZE was added. template -#if !defined(USE_ROCM) -// On CUDA, use explicit launch bounds for better occupancy -C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE * max_block_dim_y) -#endif -// Note: ROCm doesn't use launch bounds here because C10_WARP_SIZE is not -// a true compile-time constant in device code (it's a constexpr function). -// The compiler infers good launch bounds from the kernel code automatically. + typename K, typename V, typename Comparator, typename IndexType, + int WARP_SIZE> +C10_LAUNCH_BOUNDS_1(WARP_SIZE * max_block_dim_y) __global__ void warpMergeSortKVInPlace( at::cuda::detail::TensorInfo keys, @@ -207,17 +210,17 @@ warpMergeSortKVInPlace( namespace cub = ROCM_HIPCUB(at_cuda_detail::cub); - CUDA_KERNEL_ASSERT(blockDim.x == C10_WARP_SIZE); + CUDA_KERNEL_ASSERT(blockDim.x == WARP_SIZE); CUDA_KERNEL_ASSERT(blockDim.y <= max_block_dim_y); - constexpr int items_per_thread = sort_size / C10_WARP_SIZE; + constexpr int items_per_thread = sort_size / WARP_SIZE; static_assert( - items_per_thread * C10_WARP_SIZE == sort_size, - "sort_size must be a multiple of C10_WARP_SIZE"); + items_per_thread * WARP_SIZE == sort_size, + "sort_size must be a multiple of WARP_SIZE template param"); using LoadKeys = cub::WarpLoad; using LoadValues = cub::WarpLoad; - using Sort = cub::WarpMergeSort; + using Sort = cub::WarpMergeSort; using StoreKeys = cub::WarpStore; using StoreValues = cub::WarpStore; diff --git a/aten/src/ATen/native/cuda/SortingRadixSelect.cuh b/aten/src/ATen/native/cuda/SortingRadixSelect.cuh index 4ebc6de07e534..67c15e0dd3385 100644 --- a/aten/src/ATen/native/cuda/SortingRadixSelect.cuh +++ b/aten/src/ATen/native/cuda/SortingRadixSelect.cuh @@ -474,13 +474,13 @@ __device__ __forceinline__ void countRadixAggregateCounts( // Maximum number of warps per workgroup. HIP workgroups have at most 1024 threads. // Warp size is at least 32 (can be 64 on some architectures), so we use 32 for safety. // This sizes shared memory buffers to accommodate all possible warps: 1024/32 = 32. - constexpr uint MAX_WARPS = 1024/32; + constexpr uint MAX_WARPS = 1024/C10_WARP_SIZE_LOWER_BOUND; const int buffer_offset = buffer_index * MAX_WARPS * RadixSize; // offset of the buffer in smem. - const uint WARP_BITS = __builtin_ctz(warpSize); + const uint WARP_BITS = __builtin_ctz(C10_WARP_SIZE); const uint num_warps = blockDim.x >> WARP_BITS; // Actual number of warps in this block - const uint warp_id = threadIdx.x >> WARP_BITS; // = threadIdx.x / warpSize - const int lane_id = at::cuda::getLaneId(); // = threadIdx.x % warpSize + const uint warp_id = threadIdx.x >> WARP_BITS; // = threadIdx.x / C10_WARP_SIZE + const int lane_id = at::cuda::getLaneId(); // = threadIdx.x % C10_WARP_SIZE // Stage 1: Each warp's lane 0 stores its counts in smem. // Layout after Stage 1: [warp0: all radix bins], [warp1: all radix bins], ... @@ -521,7 +521,6 @@ __device__ __forceinline__ void countRadixAggregateCounts( for (uint32_t i = 0; i < RadixSize; ++i) { counts[i] = smem[buffer_offset + i]; } - __syncthreads(); // Wait for all threads to finish reading the final counts. } // This function counts the distribution of all input values in a @@ -692,6 +691,15 @@ __device__ scalar_t findPatternDataSmem( const scalar_t* dataSmem, // input data stored in shared memory. index_t dataSmemSize) { // input data size stored in shared memory. + // Ensure all threads have finished reading from smem before overwriting it. + // countRadixAggregateCounts Stage 3 reads from smem[buffer_offset + i]; + // when buffer_offset == 0, those locations overlap with smem[0]/smem[1] + // written below. Warp 0 (which writes smem[0]/smem[1]) may get ahead of + // lagging warps still in Stage 3. Syncing here (rather than at the end of + // Stage 3) is cheaper because findPatternDataSmem is called at most once per + // radixSelect invocation, only when a unique element is found (count == 1). + __syncthreads(); + // initialize smem to 0. // smem[0] is a flag to indicate if a value has been found. // smem[1] is the found value. diff --git a/aten/src/ATen/native/cuda/SummaryOps.cu b/aten/src/ATen/native/cuda/SummaryOps.cu index 00916903c3eb1..dd192bf91444e 100644 --- a/aten/src/ATen/native/cuda/SummaryOps.cu +++ b/aten/src/ATen/native/cuda/SummaryOps.cu @@ -326,8 +326,9 @@ Tensor _histc_cuda_template( bounds_t maxvalue = max; if (min == max && self.numel() > 0) { - minvalue = *self.min().cpu().const_data_ptr(); - maxvalue = *self.max().cpu().const_data_ptr(); + auto [min_tensor, max_tensor] = self.aminmax(); + minvalue = min_tensor.item(); + maxvalue = max_tensor.item(); } if (minvalue == maxvalue) { minvalue = minvalue - 1; diff --git a/aten/src/ATen/native/cuda/TensorTopK.cu b/aten/src/ATen/native/cuda/TensorTopK.cu index 835b718fb6d49..2d21944e92ceb 100644 --- a/aten/src/ATen/native/cuda/TensorTopK.cu +++ b/aten/src/ATen/native/cuda/TensorTopK.cu @@ -319,7 +319,7 @@ __global__ void gatherTopK(at::cuda::detail::TensorInfo inpu } __syncthreads(); // All threads within the warp need to participate in the loop, so rounding up to a multiple of the warp size. - IndexType numIterations = round_up(inputSliceSize, (IndexType) warpSize); + IndexType numIterations = round_up(inputSliceSize, (IndexType) C10_WARP_SIZE); // phase 1: write actual > `pattern` (or < `pattern`, depending on the sort direction) values to the output. // prefetching data from global memory. @@ -447,8 +447,9 @@ constexpr int MAX_WARP_TOPK_SLICE = 512; // GTOp/LTOp instead of bitwise conversion. Bitwise conversion is only needed for radix sorting. // Kernel using WarpMergeSort for small topK operations +// See Note [warp merge sort WARP_SIZE template param] template + typename scalar_t, typename IndexType, bool is_descending, int WARP_SIZE> __global__ void warpMergeSortTopK( at::cuda::detail::TensorInfo input, IndexType inputSliceSize, @@ -486,14 +487,14 @@ __global__ void warpMergeSortTopK( namespace cub = ROCM_HIPCUB(at_cuda_detail::cub); - CUDA_KERNEL_ASSERT(blockDim.x == C10_WARP_SIZE); + CUDA_KERNEL_ASSERT(blockDim.x == WARP_SIZE); CUDA_KERNEL_ASSERT(blockDim.y <= max_block_dim_y); - constexpr int items_per_thread = sort_size / C10_WARP_SIZE; - static_assert(items_per_thread * C10_WARP_SIZE == sort_size, - "sort_size must be a multiple of C10_WARP_SIZE"); + constexpr int items_per_thread = sort_size / WARP_SIZE; + static_assert(items_per_thread * WARP_SIZE == sort_size, + "sort_size must be a multiple of WARP_SIZE template param"); using LoadKeys = cub::WarpLoad; - using Sort = cub::WarpMergeSort; + using Sort = cub::WarpMergeSort; using StoreKeys = cub::WarpStore; using StoreIndices = cub::WarpStore; @@ -569,12 +570,30 @@ void launch( "Too many slices for warp topk"); // Dispatch based on sort size and sort direction + // See Note [warp merge sort WARP_SIZE template param] +#ifdef USE_ROCM #define LAUNCH_KERNEL(SORT_SIZE, IS_DESCENDING) \ - warpMergeSortTopK \ + if (at::cuda::warp_size() == 32) { \ + warpMergeSortTopK \ + <<>>( \ + input, inputSliceSize, k, numInputSlices, inputWithinSliceStride, \ + topK, topKWithinSliceStride, indices, indicesWithinSliceStride); \ + } \ + else { \ + warpMergeSortTopK \ + <<>>( \ + input, inputSliceSize, k, numInputSlices, inputWithinSliceStride, \ + topK, topKWithinSliceStride, indices, indicesWithinSliceStride); \ + } \ + C10_CUDA_KERNEL_LAUNCH_CHECK() +#else + #define LAUNCH_KERNEL(SORT_SIZE, IS_DESCENDING) \ + warpMergeSortTopK \ <<>>( \ input, inputSliceSize, k, numInputSlices, inputWithinSliceStride, \ topK, topKWithinSliceStride, indices, indicesWithinSliceStride); \ C10_CUDA_KERNEL_LAUNCH_CHECK() +#endif // We have specialized launches for different sizes, as sort_size affects // shared memory, registers per thread and occupancy. We can use 'LAUNCH_KERNEL(512, false);' @@ -849,8 +868,9 @@ __global__ void computeBlockwiseWithinKCounts( } } - constexpr int num_warps = RADIX_DIGITS / C10_WARP_SIZE; - __shared__ uint32_t warp_counts[num_warps]; + constexpr int SHMEM_SIZE = RADIX_DIGITS / C10_WARP_SIZE_LOWER_BOUND; // max shmem size on ROCm + const int num_warps = RADIX_DIGITS / C10_WARP_SIZE; + __shared__ uint32_t warp_counts[SHMEM_SIZE]; if (tidx % C10_WARP_SIZE == 0) { warp_counts[warp] = count; } @@ -1023,8 +1043,9 @@ __global__ void computeBlockwiseWithinKCounts( } } - constexpr int num_warps = RADIX_DIGITS / C10_WARP_SIZE; - __shared__ uint32_t warp_counts[num_warps]; + constexpr int SHMEM_SIZE = RADIX_DIGITS / C10_WARP_SIZE_LOWER_BOUND; // max shmem size on ROCm + const int num_warps = RADIX_DIGITS / C10_WARP_SIZE; + __shared__ uint32_t warp_counts[SHMEM_SIZE]; if (tidx % C10_WARP_SIZE == 0) { warp_counts[warp] = count; } diff --git a/aten/src/ATen/native/cuda/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/aten/src/ATen/native/cuda/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h index dc07a53543ad7..f75606bdc9eae 100644 --- a/aten/src/ATen/native/cuda/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +++ b/aten/src/ATen/native/cuda/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -68,13 +68,13 @@ struct GemmFpAIntB { using ElementA = typename Mma::IteratorA::Element; using LayoutA = typename Mma::IteratorA::Layout; using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; using ElementC = typename Epilogue::OutputTileIterator::Element; using LayoutC = typename Mma::LayoutC; using ElementScale = ElementC; static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; // Type definitions about the mainloop. using Operator = typename Mma::Operator; diff --git a/aten/src/ATen/native/cuda/fused_adagrad_utils.cuh b/aten/src/ATen/native/cuda/fused_adagrad_utils.cuh index cb82c30fae198..fd70a33affdcf 100644 --- a/aten/src/ATen/native/cuda/fused_adagrad_utils.cuh +++ b/aten/src/ATen/native/cuda/fused_adagrad_utils.cuh @@ -104,6 +104,9 @@ struct FusedAdagradMathFunctor { found_inf_ptr); load_store(args[kParamIdx], r_args[kParamIdx], i_start, 0); + if (grad_scale_ptr) { + load_store(args[kGradIdx], r_args[kGradIdx], i_start, 0); + } load_store(args[kStateSumIdx], r_args[kStateSumIdx], i_start, 0); } } else { diff --git a/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu b/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu index cef07de1b41f9..f4844a79b2d70 100644 --- a/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu +++ b/aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -36,26 +37,64 @@ void _fused_adam_amsgrad_cuda_impl_( found_inf.has_value() ? found_inf->data_ptr() : nullptr; const float* lr_ptr = nullptr; - AT_DISPATCH_FLOATING_TYPES_AND2( - kHalf, - kBFloat16, - params[0].scalar_type(), - "fused_adam_kernel_cuda", - [&]() { - multi_tensor_apply_for_fused_optimizer<5>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, // unused - lr, - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + if (params[0].scalar_type() != exp_avgs[0].scalar_type()) { + validate_mixed_precision_dtypes( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + "Mixed-precision fused Adam"); + AT_DISPATCH_V2( + exp_avgs[0].scalar_type(), + "fused_adam_amsgrad_mp_kernel_cuda", + AT_WRAP([&]() { + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctorMP< + float, + float, + float, + scalar_t, + scalar_t, + scalar_t, + 5, + ADAM_MODE::ORIGINAL, + true>(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }), + kBFloat16); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + params[0].scalar_type(), + "fused_adam_kernel_cuda", + [&]() { + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); + } } // The following overload simply has a Tensor lr @@ -87,26 +126,64 @@ void _fused_adam_amsgrad_cuda_impl_( found_inf.has_value() ? found_inf->data_ptr() : nullptr; const float* lr_ptr = lr.const_data_ptr(); - AT_DISPATCH_FLOATING_TYPES_AND2( - kHalf, - kBFloat16, - params[0].scalar_type(), - "fused_adam_kernel_cuda", - [&]() { - multi_tensor_apply_for_fused_optimizer<5>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, - 1.0, // unused - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + if (params[0].scalar_type() != exp_avgs[0].scalar_type()) { + validate_mixed_precision_dtypes( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + "Mixed-precision fused Adam"); + AT_DISPATCH_V2( + exp_avgs[0].scalar_type(), + "fused_adam_amsgrad_mp_kernel_cuda", + AT_WRAP([&]() { + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctorMP< + float, + float, + float, + scalar_t, + scalar_t, + scalar_t, + 5, + ADAM_MODE::ORIGINAL, + true>(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }), + kBFloat16); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + params[0].scalar_type(), + "fused_adam_kernel_cuda", + [&]() { + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); + } } } // namespace at::native diff --git a/aten/src/ATen/native/cuda/fused_adam_impl.cu b/aten/src/ATen/native/cuda/fused_adam_impl.cu index 2c1f5ce0d6d57..e51e446ea5e74 100644 --- a/aten/src/ATen/native/cuda/fused_adam_impl.cu +++ b/aten/src/ATen/native/cuda/fused_adam_impl.cu @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -31,26 +32,59 @@ void _fused_adam_cuda_impl_( found_inf.has_value() ? found_inf->data_ptr() : nullptr; const float* lr_ptr = nullptr; - AT_DISPATCH_FLOATING_TYPES_AND2( - kHalf, - kBFloat16, - params[0].scalar_type(), - "fused_adam_kernel_cuda", - [&]() { - multi_tensor_apply_for_fused_optimizer<4>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, // unused - lr, - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + if (params[0].scalar_type() != exp_avgs[0].scalar_type()) { + validate_mixed_precision_dtypes( + params, grads, exp_avgs, exp_avg_sqs, "Mixed-precision fused Adam"); + AT_DISPATCH_V2( + exp_avgs[0].scalar_type(), + "fused_adam_mp_kernel_cuda", + AT_WRAP([&]() { + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctorMP< + float, + float, + float, + scalar_t, + scalar_t, + float, + 4, + ADAM_MODE::ORIGINAL, + false>(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }), + kBFloat16); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + params[0].scalar_type(), + "fused_adam_kernel_cuda", + [&]() { + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); + } } // The following overload simply has a Tensor lr @@ -77,26 +111,59 @@ void _fused_adam_cuda_impl_( found_inf.has_value() ? found_inf->data_ptr() : nullptr; const float* lr_ptr = lr.const_data_ptr(); - AT_DISPATCH_FLOATING_TYPES_AND2( - kHalf, - kBFloat16, - params[0].scalar_type(), - "fused_adam_kernel_cuda", - [&]() { - multi_tensor_apply_for_fused_optimizer<4>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, - 1.0, // unused - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + if (params[0].scalar_type() != exp_avgs[0].scalar_type()) { + validate_mixed_precision_dtypes( + params, grads, exp_avgs, exp_avg_sqs, "Mixed-precision fused Adam"); + AT_DISPATCH_V2( + exp_avgs[0].scalar_type(), + "fused_adam_mp_kernel_cuda", + AT_WRAP([&]() { + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctorMP< + float, + float, + float, + scalar_t, + scalar_t, + float, + 4, + ADAM_MODE::ORIGINAL, + false>(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }), + kBFloat16); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + params[0].scalar_type(), + "fused_adam_kernel_cuda", + [&]() { + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); + } } } // namespace at::native diff --git a/aten/src/ATen/native/cuda/fused_adam_utils.cuh b/aten/src/ATen/native/cuda/fused_adam_utils.cuh index 0533db28558b7..5440577ded0ad 100644 --- a/aten/src/ATen/native/cuda/fused_adam_utils.cuh +++ b/aten/src/ATen/native/cuda/fused_adam_utils.cuh @@ -3,12 +3,73 @@ #include #include #include +#include #include namespace at::native { enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 }; +// Validates the dtype configuration for mixed-precision fused Adam/AdamW. +// +// Currently the only supported configuration is: +// params/grads: float32 +// optimizer states (exp_avg, exp_avg_sq, ...): bfloat16 +// +// This specific configuration (fp32 params + bf16 optimizer states) has been +// validated end-to-end in large-scale training runs (e.g. DeepSeek-V3 671B) +// and is the only one for which training convergence has been demonstrated. +// Additional mixed-precision configurations (e.g. float16 states) can be +// enabled here once convergence is verified for those as well. +// +// Only [0] is checked because within-list dtype homogeneity is guaranteed by +// _check_tensors_share_device_and_dtype (with skip_cross_list_dtype_check) +// and the Python-side grouping in +// _group_tensors_by_first_tensors_device_and_dtype. +inline void validate_mixed_precision_dtypes( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + const char* op_name) { + TORCH_CHECK( + params[0].scalar_type() == at::kFloat, + op_name, + " requires float32 params, got ", + params[0].scalar_type()); + TORCH_CHECK( + grads[0].scalar_type() == at::kFloat, + op_name, + " requires float32 grads, got ", + grads[0].scalar_type()); + TORCH_CHECK( + exp_avgs[0].scalar_type() == at::kBFloat16, + op_name, + " requires bfloat16 optimizer states, got ", + exp_avgs[0].scalar_type()); + TORCH_CHECK( + exp_avg_sqs[0].scalar_type() == at::kBFloat16, + op_name, + " requires bfloat16 optimizer states, got ", + exp_avg_sqs[0].scalar_type()); +} + +inline void validate_mixed_precision_dtypes( + at::TensorList params, + at::TensorList grads, + at::TensorList exp_avgs, + at::TensorList exp_avg_sqs, + at::TensorList max_exp_avg_sqs, + const char* op_name) { + validate_mixed_precision_dtypes( + params, grads, exp_avgs, exp_avg_sqs, op_name); + TORCH_CHECK( + max_exp_avg_sqs[0].scalar_type() == at::kBFloat16, + op_name, + " requires bfloat16 max_exp_avg_sqs, got ", + max_exp_avg_sqs[0].scalar_type()); +} + namespace { constexpr uint8_t kParamIdx = 0; @@ -211,6 +272,282 @@ struct FusedAdamMathFunctor { } } }; + +template < + typename scalar_type, + typename param_type, + typename grad_type, + typename exp_avg_type, + typename exp_avg_sq_type, + typename max_exp_avg_sq_type, + int depth, + ADAM_MODE adam_mode, + bool amsgrad> +struct FusedAdamMathFunctorMP { + static_assert( + depth == 4 || depth == 5, + "depth of 4 for Adam, depth of 5 for Adam with AMSGrad."); + using opmath_t = at::opmath_type; + C10_DEVICE __forceinline__ void operator()( + int64_t chunk_size, + FusedOptimizerTensorListMetadata& tl, + const float* lr_ptr, + const double& lr, + const double& beta1, + const double& beta2, + const double& weight_decay, + const double& eps, + const bool& maximize, + const float* grad_scale_ptr, + const float* found_inf_ptr) { + const auto tensor_loc = tl.block_to_tensor[blockIdx.x]; + const auto chunk_idx = tl.block_to_chunk[blockIdx.x]; + const double lr_double = lr_ptr ? *lr_ptr : lr; + + if (found_inf_ptr && *found_inf_ptr == 1) { + return; + } + const auto [bias_correction1, bias_correction2_sqrt] = + [&]() -> std::pair { + auto* step_count = + reinterpret_cast(tl.state_steps_addresses[tensor_loc]); + const auto bias_correction1 = 1 - at::native::pow_(beta1, *step_count); + const auto bias_correction2 = 1 - at::native::pow_(beta2, *step_count); + const auto bias_correction2_sqrt = std::sqrt(bias_correction2); + return {bias_correction1, bias_correction2_sqrt}; + }(); + + param_type* param_args; + grad_type* grad_args; + exp_avg_type* exp_avg_args; + exp_avg_sq_type* exp_avg_sq_args; + [[maybe_unused]] max_exp_avg_sq_type* max_exp_avg_sq_args; + + // r_args represents the state when everything is casted to scalar_type + // to be passed into the adam_math function. scalar_type is our operation + // math type. + scalar_type r_args[depth][kILP]; + + // n = total numel of tensor - what's already been processed + // so n = numel in current tensor not yet processed + const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size; + + bool all_aligned = init_args_mixed_prec< + depth, + param_type, + grad_type, + exp_avg_type, + exp_avg_sq_type>( + ¶m_args, + &grad_args, + &exp_avg_args, + &exp_avg_sq_args, + tl, + chunk_idx, + chunk_size, + tensor_loc); + if constexpr (amsgrad) { + max_exp_avg_sq_args = + (max_exp_avg_sq_type*)tl.addresses[kMaxExpAvgSqIdx][tensor_loc] + + chunk_idx * chunk_size; + all_aligned = all_aligned && is_aligned(max_exp_avg_sq_args); + } + if ((n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned) { + for (int64_t i_start = threadIdx.x; + i_start * kILP < n && i_start * kILP < chunk_size; + i_start += blockDim.x) { + if constexpr (!std::is_same_v) { + scalar_type casted_param_args[kILP]; + for (int ii = 0; ii < kILP; ii++) { + casted_param_args[ii] = + static_cast(param_args[ii + i_start * kILP]); + } + load_store(r_args[kParamIdx], casted_param_args, 0, 0); + } else { + load_store(r_args[kParamIdx], (scalar_type*)param_args, 0, i_start); + } + if constexpr (!std::is_same_v) { + scalar_type casted_grad_args[kILP]; + for (int ii = 0; ii < kILP; ii++) { + casted_grad_args[ii] = + static_cast(grad_args[ii + i_start * kILP]); + } + load_store(r_args[kGradIdx], casted_grad_args, 0, 0); + } else { + load_store(r_args[kGradIdx], (scalar_type*)grad_args, 0, i_start); + } + if constexpr (!std::is_same_v) { + scalar_type casted_exp_avg_args[kILP]; + for (int ii = 0; ii < kILP; ii++) { + casted_exp_avg_args[ii] = + static_cast(exp_avg_args[ii + i_start * kILP]); + } + load_store(r_args[kExpAvgIdx], casted_exp_avg_args, 0, 0); + } else { + load_store( + r_args[kExpAvgIdx], (scalar_type*)exp_avg_args, 0, i_start); + } + if constexpr (!std::is_same_v) { + scalar_type casted_exp_avg_sq_args[kILP]; + for (int ii = 0; ii < kILP; ii++) { + casted_exp_avg_sq_args[ii] = + static_cast(exp_avg_sq_args[ii + i_start * kILP]); + } + load_store(r_args[kExpAvgSqIdx], casted_exp_avg_sq_args, 0, 0); + } else { + load_store( + r_args[kExpAvgSqIdx], (scalar_type*)exp_avg_sq_args, 0, i_start); + } + if constexpr (amsgrad) { + if constexpr (!std::is_same_v) { + scalar_type casted[kILP]; + for (int ii = 0; ii < kILP; ii++) { + casted[ii] = static_cast( + max_exp_avg_sq_args[ii + i_start * kILP]); + } + load_store(r_args[kMaxExpAvgSqIdx], casted, 0, 0); + } else { + load_store( + r_args[kMaxExpAvgSqIdx], + (scalar_type*)max_exp_avg_sq_args, + 0, + i_start); + } + } + adam_math( + r_args, + lr_double, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr, + bias_correction1, + bias_correction2_sqrt); + if constexpr (!std::is_same_v) { + param_type casted_r_args[kILP]; + for (int ii = 0; ii < kILP; ii++) { + casted_r_args[ii] = static_cast(r_args[kParamIdx][ii]); + } + load_store(param_args, casted_r_args, i_start, 0); + } else { + load_store(param_args, (param_type*)r_args[kParamIdx], i_start, 0); + } + if constexpr (!std::is_same_v) { + exp_avg_type casted_r_args[kILP]; + for (int ii = 0; ii < kILP; ii++) { + casted_r_args[ii] = + static_cast(r_args[kExpAvgIdx][ii]); + } + load_store(exp_avg_args, casted_r_args, i_start, 0); + } else { + load_store( + exp_avg_args, (exp_avg_type*)r_args[kExpAvgIdx], i_start, 0); + } + if constexpr (!std::is_same_v) { + exp_avg_sq_type casted_r_args[kILP]; + for (int ii = 0; ii < kILP; ii++) { + casted_r_args[ii] = + static_cast(r_args[kExpAvgSqIdx][ii]); + } + load_store(exp_avg_sq_args, casted_r_args, i_start, 0); + } else { + load_store( + exp_avg_sq_args, + (exp_avg_sq_type*)r_args[kExpAvgSqIdx], + i_start, + 0); + } + if constexpr (amsgrad) { + if constexpr (!std::is_same_v) { + max_exp_avg_sq_type casted[kILP]; + for (int ii = 0; ii < kILP; ii++) { + casted[ii] = + static_cast(r_args[kMaxExpAvgSqIdx][ii]); + } + load_store(max_exp_avg_sq_args, casted, i_start, 0); + } else { + load_store( + max_exp_avg_sq_args, + (max_exp_avg_sq_type*)r_args[kMaxExpAvgSqIdx], + i_start, + 0); + } + } + if (grad_scale_ptr) { + if constexpr (!std::is_same_v) { + grad_type casted_r_args[kILP]; + for (int ii = 0; ii < kILP; ii++) { + casted_r_args[ii] = static_cast(r_args[kGradIdx][ii]); + } + load_store(grad_args, casted_r_args, i_start, 0); + } else { + load_store(grad_args, (grad_type*)r_args[kGradIdx], i_start, 0); + } + } + } + } else { + for (int64_t i_start = 0; i_start < n && i_start < chunk_size; + i_start += blockDim.x * kILP) { + load_args< + scalar_type, + param_type, + grad_type, + exp_avg_type, + exp_avg_sq_type>( + r_args, + param_args, + grad_args, + exp_avg_args, + exp_avg_sq_args, + i_start, + chunk_size, + n); + if constexpr (amsgrad) { +#pragma unroll + for (int ii = 0; ii < kILP; ii++) { + const auto i = i_start + threadIdx.x + ii * blockDim.x; + r_args[kMaxExpAvgSqIdx][ii] = 0; + if (i < n && i < chunk_size) { + r_args[kMaxExpAvgSqIdx][ii] = + static_cast(max_exp_avg_sq_args[i]); + } + } + } + adam_math( + r_args, + lr_double, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr, + bias_correction1, + bias_correction2_sqrt); + store_args(param_args, r_args[kParamIdx], i_start, chunk_size, n); + store_args(exp_avg_args, r_args[kExpAvgIdx], i_start, chunk_size, n); + store_args( + exp_avg_sq_args, r_args[kExpAvgSqIdx], i_start, chunk_size, n); + if constexpr (amsgrad) { + store_args( + max_exp_avg_sq_args, + r_args[kMaxExpAvgSqIdx], + i_start, + chunk_size, + n); + } + if (grad_scale_ptr) { + store_args(grad_args, r_args[kGradIdx], i_start, chunk_size, n); + } + } + } + } +}; + } // namespace } // namespace at::native diff --git a/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu b/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu index b2eff4839133f..dba2b2c2af830 100644 --- a/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu +++ b/aten/src/ATen/native/cuda/fused_adamw_amsgrad_impl.cu @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -36,26 +37,64 @@ void _fused_adamw_amsgrad_cuda_impl_( found_inf.has_value() ? found_inf->data_ptr() : nullptr; const float* lr_ptr = nullptr; - AT_DISPATCH_FLOATING_TYPES_AND2( - kHalf, - kBFloat16, - params[0].scalar_type(), - "fused_adamw_kernel_cuda", - [&]() { - multi_tensor_apply_for_fused_optimizer<5>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, // unused - lr, - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + if (params[0].scalar_type() != exp_avgs[0].scalar_type()) { + validate_mixed_precision_dtypes( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + "Mixed-precision fused AdamW"); + AT_DISPATCH_V2( + exp_avgs[0].scalar_type(), + "fused_adamw_amsgrad_mp_kernel_cuda", + AT_WRAP([&]() { + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctorMP< + float, + float, + float, + scalar_t, + scalar_t, + scalar_t, + 5, + ADAM_MODE::ADAMW, + true>(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }), + kBFloat16); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + params[0].scalar_type(), + "fused_adamw_kernel_cuda", + [&]() { + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); + } } // The following overload simply has a Tensor lr @@ -87,26 +126,64 @@ void _fused_adamw_amsgrad_cuda_impl_( found_inf.has_value() ? found_inf->data_ptr() : nullptr; const float* lr_ptr = lr.const_data_ptr(); - AT_DISPATCH_FLOATING_TYPES_AND2( - kHalf, - kBFloat16, - params[0].scalar_type(), - "fused_adamw_kernel_cuda", - [&]() { - multi_tensor_apply_for_fused_optimizer<5>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, - 1.0, // unused - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + if (params[0].scalar_type() != exp_avgs[0].scalar_type()) { + validate_mixed_precision_dtypes( + params, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + "Mixed-precision fused AdamW"); + AT_DISPATCH_V2( + exp_avgs[0].scalar_type(), + "fused_adamw_amsgrad_mp_kernel_cuda", + AT_WRAP([&]() { + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctorMP< + float, + float, + float, + scalar_t, + scalar_t, + scalar_t, + 5, + ADAM_MODE::ADAMW, + true>(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }), + kBFloat16); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + params[0].scalar_type(), + "fused_adamw_kernel_cuda", + [&]() { + multi_tensor_apply_for_fused_optimizer<5>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); + } } } // namespace at::native diff --git a/aten/src/ATen/native/cuda/fused_adamw_impl.cu b/aten/src/ATen/native/cuda/fused_adamw_impl.cu index 90318854bec4c..0a65fee7ba84d 100644 --- a/aten/src/ATen/native/cuda/fused_adamw_impl.cu +++ b/aten/src/ATen/native/cuda/fused_adamw_impl.cu @@ -1,6 +1,7 @@ #include #include +#include #include #include #include @@ -31,26 +32,59 @@ void _fused_adamw_cuda_impl_( found_inf.has_value() ? found_inf->data_ptr() : nullptr; const float* lr_ptr = nullptr; - AT_DISPATCH_FLOATING_TYPES_AND2( - kHalf, - kBFloat16, - params[0].scalar_type(), - "fused_adamw_kernel_cuda", - [&]() { - multi_tensor_apply_for_fused_optimizer<4>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, // unused - lr, - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + if (params[0].scalar_type() != exp_avgs[0].scalar_type()) { + validate_mixed_precision_dtypes( + params, grads, exp_avgs, exp_avg_sqs, "Mixed-precision fused AdamW"); + AT_DISPATCH_V2( + exp_avgs[0].scalar_type(), + "fused_adamw_mp_kernel_cuda", + AT_WRAP([&]() { + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctorMP< + float, + float, + float, + scalar_t, + scalar_t, + float, + 4, + ADAM_MODE::ADAMW, + false>(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }), + kBFloat16); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + params[0].scalar_type(), + "fused_adamw_kernel_cuda", + [&]() { + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, // unused + lr, + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); + } } // The following overload simply has a Tensor lr @@ -77,26 +111,59 @@ void _fused_adamw_cuda_impl_( found_inf.has_value() ? found_inf->data_ptr() : nullptr; const float* lr_ptr = lr.const_data_ptr(); - AT_DISPATCH_FLOATING_TYPES_AND2( - kHalf, - kBFloat16, - params[0].scalar_type(), - "fused_adamw_kernel_cuda", - [&]() { - multi_tensor_apply_for_fused_optimizer<4>( - tensor_lists, - state_steps, - FusedAdamMathFunctor(), - lr_ptr, - 1.0, // unused - beta1, - beta2, - weight_decay, - eps, - maximize, - grad_scale_ptr, - found_inf_ptr); - }); + if (params[0].scalar_type() != exp_avgs[0].scalar_type()) { + validate_mixed_precision_dtypes( + params, grads, exp_avgs, exp_avg_sqs, "Mixed-precision fused AdamW"); + AT_DISPATCH_V2( + exp_avgs[0].scalar_type(), + "fused_adamw_mp_kernel_cuda", + AT_WRAP([&]() { + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctorMP< + float, + float, + float, + scalar_t, + scalar_t, + float, + 4, + ADAM_MODE::ADAMW, + false>(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }), + kBFloat16); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + params[0].scalar_type(), + "fused_adamw_kernel_cuda", + [&]() { + multi_tensor_apply_for_fused_optimizer<4>( + tensor_lists, + state_steps, + FusedAdamMathFunctor(), + lr_ptr, + 1.0, // unused + beta1, + beta2, + weight_decay, + eps, + maximize, + grad_scale_ptr, + found_inf_ptr); + }); + } } } // namespace at::native diff --git a/aten/src/ATen/native/cuda/group_norm_kernel.cu b/aten/src/ATen/native/cuda/group_norm_kernel.cu index 470dd801127d5..23c25554b3a0e 100644 --- a/aten/src/ATen/native/cuda/group_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/group_norm_kernel.cu @@ -26,6 +26,19 @@ namespace { constexpr int kCUDANumThreads = 256; constexpr int kReduceTileSize = 32; +// Reduce across exactly 32 lanes (offsets 16, 8, 4, 2, 1). +// On NVIDIA (warp=32) this is identical to WarpReduceSum. +// On AMD (wavefront=64) this avoids summing across two tile columns +// when the block is (32, 16) and consecutive y-rows share a wavefront. +template +__inline__ __device__ T ReduceSum32(T val) { +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val += WARP_SHFL_DOWN(val, offset); + } + return val; +} + template __global__ void RowwiseMomentsCUDAKernel( int64_t N, @@ -52,7 +65,7 @@ __global__ void RowwiseMomentsCUDAKernel( // https://github.com/pytorch/pytorch/pull/13967 __shared__ typename std::aligned_storage< sizeof(WelfordType), - alignof(WelfordType)>::type val_shared[C10_WARP_SIZE]; + alignof(WelfordType)>::type val_shared[C10_WARP_SIZE_UPPER_BOUND]; WelfordType* val_shared_ptr = reinterpret_cast(val_shared); val = cuda_utils::BlockReduce( val, @@ -123,8 +136,8 @@ __global__ void Compute1dBackwardFusedParamsCUDAKernel( sum1 = cuda_utils::WarpReduceSum(sum1); sum2 = cuda_utils::WarpReduceSum(sum2); } else { - __shared__ T_ACC ds_shared[C10_WARP_SIZE]; - __shared__ T_ACC db_shared[C10_WARP_SIZE]; + __shared__ T_ACC ds_shared[C10_WARP_SIZE_UPPER_BOUND]; + __shared__ T_ACC db_shared[C10_WARP_SIZE_UPPER_BOUND]; sum1 = cuda_utils::BlockReduceSum(sum1, ds_shared); sum2 = cuda_utils::BlockReduceSum(sum2, db_shared); } @@ -238,8 +251,11 @@ __global__ void GammaBeta1dBackwardCUDAKernel2( // Do warp reduce for the 1st 16 cols in the tile. T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y]; T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y]; - sum1 = cuda_utils::WarpReduceSum(sum1); - sum2 = cuda_utils::WarpReduceSum(sum2); + // Use ReduceSum32 (not WarpReduceSum) to reduce exactly 32 lanes. + // On AMD wavefront-64, WarpReduceSum would incorrectly sum across two + // tile columns since consecutive y-rows share a wavefront. + sum1 = ReduceSum32(sum1); + sum2 = ReduceSum32(sum2); if (threadIdx.x == 0) { const int64_t c = blockIdx.x * blockDim.x + threadIdx.y; if (c < C) { @@ -255,8 +271,8 @@ __global__ void GammaBeta1dBackwardCUDAKernel2( // Do warp reduce for the 2nd 16 cols in the tile. sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y]; sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y]; - sum1 = cuda_utils::WarpReduceSum(sum1); - sum2 = cuda_utils::WarpReduceSum(sum2); + sum1 = ReduceSum32(sum1); + sum2 = ReduceSum32(sum2); if (threadIdx.x == 0) { const int64_t c = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y; if (c < C) { @@ -290,8 +306,8 @@ __global__ void ComputeInternalGradientsCUDAKernel( sum1 = cuda_utils::WarpReduceSum(sum1); sum2 = cuda_utils::WarpReduceSum(sum2); } else { - __shared__ T_ACC ds_shared[C10_WARP_SIZE]; - __shared__ T_ACC db_shared[C10_WARP_SIZE]; + __shared__ T_ACC ds_shared[C10_WARP_SIZE_UPPER_BOUND]; + __shared__ T_ACC db_shared[C10_WARP_SIZE_UPPER_BOUND]; sum1 = cuda_utils::BlockReduceSum(sum1, ds_shared); sum2 = cuda_utils::BlockReduceSum(sum2, db_shared); } @@ -333,8 +349,8 @@ __global__ void ComputeBackwardFusedParamsCUDAKernel( sum1 = cuda_utils::WarpReduceSum(sum1); sum2 = cuda_utils::WarpReduceSum(sum2); } else { - __shared__ T_ACC ds_shared[C10_WARP_SIZE]; - __shared__ T_ACC db_shared[C10_WARP_SIZE]; + __shared__ T_ACC ds_shared[C10_WARP_SIZE_UPPER_BOUND]; + __shared__ T_ACC db_shared[C10_WARP_SIZE_UPPER_BOUND]; sum1 = cuda_utils::BlockReduceSum(sum1, ds_shared); sum2 = cuda_utils::BlockReduceSum(sum2, db_shared); } @@ -440,10 +456,11 @@ __global__ void GammaBetaBackwardCUDAKernel2( __syncthreads(); // Do warp reduce for the 1st 16 cols in the tile. + // Use ReduceSum32 for correctness on AMD wavefront-64 (see above). T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y]; T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y]; - sum1 = cuda_utils::WarpReduceSum(sum1); - sum2 = cuda_utils::WarpReduceSum(sum2); + sum1 = ReduceSum32(sum1); + sum2 = ReduceSum32(sum2); if (threadIdx.x == 0) { const int64_t c = blockIdx.x * blockDim.x + threadIdx.y; if (c < C) { @@ -459,8 +476,8 @@ __global__ void GammaBetaBackwardCUDAKernel2( // Do warp reduce for the 2nd 16 cols in the tile. sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y]; sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y]; - sum1 = cuda_utils::WarpReduceSum(sum1); - sum2 = cuda_utils::WarpReduceSum(sum2); + sum1 = ReduceSum32(sum1); + sum2 = ReduceSum32(sum2); if (threadIdx.x == 0) { const int64_t c = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y; if (c < C) { diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp index 3ad4242618d4a..74739b348db1f 100644 --- a/aten/src/ATen/native/cuda/jit_utils.cpp +++ b/aten/src/ATen/native/cuda/jit_utils.cpp @@ -981,7 +981,6 @@ int calc_thread_work_size( } else { return 4; } - return io_size; #else auto io_size = at::cuda::jit::calc_io_size(nInputs, nOutputs, inputs_type, result_type); TORCH_INTERNAL_ASSERT(io_size > 0); @@ -990,7 +989,6 @@ int calc_thread_work_size( } else { return 8; } - return io_size; #endif } diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu index 04333d7450266..59f5e4a041117 100644 --- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu +++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu @@ -32,12 +32,6 @@ namespace at::native { namespace { constexpr int kCUDANumThreads = 256; -#ifdef USE_ROCM -// C10_WARP_SIZE is not constexpr for host code. -#define kWarpSize C10_WARP_SIZE -#else -constexpr unsigned int kWarpSize = C10_WARP_SIZE; -#endif constexpr int vec_size = 4; //we could make it dependent on dtype, but that would lead to different results between float and low-p types // aligned vector generates vectorized load/store on CUDA (copy-pasted from MemoryAccess.cuh) @@ -67,7 +61,7 @@ __global__ void RowwiseMomentsCUDAKernel( __shared__ typename std::aligned_storage:: - type val_shared[C10_WARP_SIZE]; + type val_shared[C10_WARP_SIZE_UPPER_BOUND]; WelfordType* val_shared_ptr = reinterpret_cast(val_shared); const int64_t i = blockIdx.x; @@ -662,7 +656,7 @@ blockReduceGammaBetaBackwardsHelper( constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; int64_t thread_x = blockIdx.x * block_dim_x + threadIdx.x; - int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (kWarpSize - 1); + int lane_id = (threadIdx.y * blockDim.x + threadIdx.x) & (C10_WARP_SIZE - 1); int64_t mean_index = M_start + threadIdx.y * rows_per_thread_y; T_ACC warp_mean = 0, warp_rstd = 0; if (lane_id < rows_per_thread_y && mean_index + lane_id < M) { @@ -695,9 +689,9 @@ blockReduceGammaBetaBackwardsHelper( #pragma unroll for (int i = 0; i < rows_per_thread_y; ++i) { - T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, kWarpSize); + T_ACC rstd_reg = WARP_SHFL(warp_rstd, i, C10_WARP_SIZE); if constexpr (!rms_norm){ - T_ACC mean_reg = WARP_SHFL(warp_mean, i, kWarpSize); + T_ACC mean_reg = WARP_SHFL(warp_mean, i, C10_WARP_SIZE); dg_sum += dY_regs[i] * (X_regs[i] - mean_reg) * rstd_reg; db_sum += dY_regs[i]; } else{ @@ -775,7 +769,7 @@ __launch_bounds__(block_dim_x * block_dim_y) T* __restrict__ db) { // This assert is a compile-time check only. constexpr int rows_per_thread_y = rows_per_block_y / block_dim_y; - static_assert(rows_per_thread_y <= kWarpSize); + static_assert(rows_per_thread_y <= C10_WARP_SIZE_LOWER_BOUND); T_ACC dg_sum = 0; T_ACC db_sum = 0; @@ -818,7 +812,7 @@ __launch_bounds__(block_dim_x * block_dim_y) } else { // The caller requested a full reduction so we must reduce across // warps using shared memory and warp shuffles. - static_assert(rows_per_thread_y <= C10_WARP_SIZE); + static_assert(rows_per_thread_y <= C10_WARP_SIZE_LOWER_BOUND); alignas(sizeof(double)) extern __shared__ char s_data1[]; T_ACC* s_data_typed = reinterpret_cast(&s_data1); T_ACC* s_dg; @@ -834,8 +828,8 @@ __launch_bounds__(block_dim_x * block_dim_y) // Load transposed so that a warp holds an entire column // Because block_dim_x != block_dim_y in the general case, we need // some code to handle the general case. - static_assert(block_dim_x * block_dim_y % C10_WARP_SIZE == 0); - constexpr int warps_available_to_reduce = block_dim_x * block_dim_y / C10_WARP_SIZE; + static_assert(block_dim_x * block_dim_y % C10_WARP_SIZE_LOWER_BOUND == 0); + const int warps_available_to_reduce = block_dim_x * block_dim_y / C10_WARP_SIZE; int thread_id = threadIdx.y * block_dim_x + threadIdx.x; int warp_id = thread_id / C10_WARP_SIZE; int lane_id = thread_id & (C10_WARP_SIZE - 1); @@ -848,8 +842,8 @@ __launch_bounds__(block_dim_x * block_dim_y) } #pragma unroll for (unsigned delta = block_dim_y >> 1; delta >= 1; delta >>= 1) { - reg_dg += WARP_SHFL_XOR(reg_dg, delta, kWarpSize); - reg_db += WARP_SHFL_XOR(reg_db, delta, kWarpSize); + reg_dg += WARP_SHFL_XOR(reg_dg, delta, C10_WARP_SIZE); + reg_db += WARP_SHFL_XOR(reg_db, delta, C10_WARP_SIZE); } // Reduce is done. Now write it out to global memory. int64_t out_index = ((int64_t)blockIdx.x) * block_dim_x + i; @@ -955,7 +949,11 @@ void LaunchGammaBetaBackwardCUDAKernel( Tensor* dgamma, Tensor* dbeta, cudaStream_t cuda_stream) { +#ifdef USE_ROCM + constexpr int block_dim_x = 64; +#else constexpr int block_dim_x = 32; +#endif const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; if (M > 64 * 1024 && N / block_dim_x < sm_count / 2) { // We have a situation where M >> N and N is small. @@ -1010,7 +1008,13 @@ void LaunchGammaBetaBackwardCUDAKernel( } else if (M < 256) { ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } else { +#ifdef USE_ROCM + // Cap block_dim_y at 16 to keep total threads (64*16=1024) within GPU limits. + // rows_per_thread_y = 256/16 = 16, still within warp size constraint. + ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); +#else ConfigureAndLaunchGammaBetaBackwardKernel(dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); +#endif } } } @@ -1617,40 +1621,12 @@ void LayerNormBackwardKernelImplInternal( dbeta_data); C10_CUDA_KERNEL_LAUNCH_CHECK(); } else { - // For small batch size, do colwise reduce directly. - const int part_size = warp_size; - const dim3 threads2(warp_size, 4, 1); - const dim3 blocks2((N + threads2.x - 1) / threads2.x, part_size, 1); - const int nshared2_a = 2 * sizeof(T_ACC) * threads2.y * threads2.y * (threads2.x + 1); - const int nshared2_b = threads2.x * threads2.y * sizeof(T_ACC); - const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; - - const auto part_grad_dtype = at::toAccumulateType(X.scalar_type(), true); - Tensor part_grad_gamma = at::empty({part_size,N}, gamma.options().dtype(part_grad_dtype)); - Tensor part_grad_beta = at::native::empty_like(part_grad_gamma); - - cuComputePartGradGammaBeta<<>>( - dY_data, - X_data, - M,N, - mean_data, - rstd_data, - part_grad_gamma.template data_ptr(), - part_grad_beta.template data_ptr()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - const dim3 threads3(warp_size, 8, 1); // Optimization for ROCm - const dim3 blocks3((N + threads3.x - 1) / threads3.x, 1, 1); - const int nshared3 = threads3.x * threads3.y * sizeof(T_ACC); - - cuComputeGradGammaBeta<<>>( - part_grad_gamma.template data_ptr(), - part_grad_beta.template data_ptr(), - part_size, - M,N, - dgamma_data, - dbeta_data); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + // Use the optimized tiled kernel adapted for wavefront-64. + // This replaces the legacy two-pass cuComputePartGradGammaBeta + + // cuComputeGradGammaBeta approach with a single-pass tiled reduction + // that has coalesced memory access and adaptive tile sizing. + LaunchGammaBetaBackwardCUDAKernel( + dY_data, X_data, mean_data, rstd_data, M, N, dgamma, dbeta, cuda_stream); } #else LaunchGammaBetaBackwardCUDAKernel( diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp index da21a3fed9c00..c3c2f52fbd66b 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp @@ -51,11 +51,11 @@ struct MagmaInitializer { } initializer; } // namespace (anonymous) -#define AT_MAGMA_VERSION MAGMA_VERSION_MAJOR*100 + MAGMA_VERSION_MINOR*10 + MAGMA_VERSION_MICRO +#define AT_MAGMA_VERSION MAGMA_VERSION_MAJOR*10000 + MAGMA_VERSION_MINOR*100 + MAGMA_VERSION_MICRO -// Check that MAGMA never releases MAGMA_VERSION_MINOR >= 10 or MAGMA_VERSION_MICRO >= 10 -#if MAGMA_VERSION_MINOR >= 10 || MAGMA_VERSION_MICRO >= 10 -#error "MAGMA release minor or micro version >= 10, please correct AT_MAGMA_VERSION" +// Check that MAGMA never releases MAGMA_VERSION_MINOR >= 100 or MAGMA_VERSION_MICRO >= 100 +#if MAGMA_VERSION_MINOR >= 100 || MAGMA_VERSION_MICRO >= 100 +#error "MAGMA release minor or micro version >= 100, please correct AT_MAGMA_VERSION" #endif #endif @@ -123,7 +123,7 @@ void magmaEig( magma_int_t *info); #endif -#if AT_MAGMA_VERSION >= 254 +#if AT_MAGMA_VERSION >= 20504 template <> void magmaLdlHermitian( @@ -179,7 +179,7 @@ void magmaLdlHermitian>( AT_CUDA_CHECK(cudaGetLastError()); } -#endif // AT_MAGMA_VERSION >= 254 +#endif // AT_MAGMA_VERSION >= 20504 template<> void magmaLu( @@ -597,7 +597,7 @@ void ldl_factor_kernel( // If cusolver and magma 2.5.4+ are both available and hermitian=true, // call magma for complex inputs #ifdef USE_LINALG_SOLVER -#if AT_MAGMA_ENABLED() && (AT_MAGMA_VERSION >= 254) +#if AT_MAGMA_ENABLED() && (AT_MAGMA_VERSION >= 20504) if (LD.is_complex() && hermitian) { return ldl_factor_magma( LD, pivots, info, upper, hermitian); diff --git a/aten/src/ATen/native/cudnn/BatchNorm.h b/aten/src/ATen/native/cudnn/BatchNorm.h index 3da76c0c16e41..7bd0a18b94de2 100644 --- a/aten/src/ATen/native/cudnn/BatchNorm.h +++ b/aten/src/ATen/native/cudnn/BatchNorm.h @@ -1,6 +1,6 @@ namespace at::native { -TORCH_API size_t +TORCH_CUDA_CU_API size_t _get_cudnn_batch_norm_reserve_space_size(const Tensor& input_t, bool training); } // namespace at::native diff --git a/aten/src/ATen/native/cudnn/GridSampler.cpp b/aten/src/ATen/native/cudnn/GridSampler.cpp index 1b1e3d19699b7..fc15fa917221a 100644 --- a/aten/src/ATen/native/cudnn/GridSampler.cpp +++ b/aten/src/ATen/native/cudnn/GridSampler.cpp @@ -1,5 +1,6 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include +#include #include #include #include @@ -138,6 +139,7 @@ std::tuple cudnn_grid_sampler_backward( TORCH_CHECK( cond_cudnn_grid_sampler(input_t, grid_t), "Invalid arguments to cudnn_grid_sampler_backward"); + globalContext().alertNotDeterministic("cudnn_grid_sampler_backward"); auto input_contig = contiguousIfZeroInStrides(input_t); auto grid_contig = grid_t.contiguous(); diff --git a/aten/src/ATen/native/cudnn/LossCTC.cpp b/aten/src/ATen/native/cudnn/LossCTC.cpp index 1b617199330fb..22343822668d9 100644 --- a/aten/src/ATen/native/cudnn/LossCTC.cpp +++ b/aten/src/ATen/native/cudnn/LossCTC.cpp @@ -134,7 +134,7 @@ bool _use_cudnn_ctc_loss_tensor( if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) { Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous(); - IntArrayRef tl(tlc.data_ptr(), tlc.numel()); + IntArrayRef tl(tlc.const_data_ptr(), tlc.numel()); for (const auto b : c10::irange(tl.size())) { // target length < 256 is documented, but we see illegal memory accesses // when target lengths > input lengths for CuDNN @@ -142,7 +142,7 @@ bool _use_cudnn_ctc_loss_tensor( Tensor tlc = target_lengths.to(Device(at::kCPU), at::kLong).contiguous(); IntArrayRef il(ilc.const_data_ptr(), ilc.numel()); - IntArrayRef tl(tlc.data_ptr(), tlc.numel()); + IntArrayRef tl(tlc.const_data_ptr(), tlc.numel()); use_cudnn = use_cudnn && (tl[b] < 256) && (tl[b] <= il[b]); if (!use_cudnn) { break; diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 79fd4cedf45b8..a45fda2b22db4 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -1960,7 +1960,7 @@ std::tuple _cudnn_rnn_backward_input( dx = dx.transpose_(0, 1); } - return std::make_tuple(dx, dhx, dcx); + return std::make_tuple(std::move(dx), std::move(dhx), std::move(dcx)); } // NB: This MUST BE CALLED AFTER _cudnn_rnn_backward_input. @@ -2376,12 +2376,7 @@ struct DropoutState { if (event) { #if !defined(USE_ROCM) // See Note [DropoutState and CUDA graph capture] - cudaStreamCaptureStatus status; - AT_CUDA_CHECK(cudaStreamGetCaptureInfo( - cuda::getCurrentCUDAStream(), &status, &capture_id_last_lock)); - if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) { - capture_id_last_lock = 0; - } + capture_id_last_lock = at::cuda::currentStreamCaptureId().value_or(0); if (capture_id_last_lock == capture_id_last_unlock) { event->block(cuda::getCurrentCUDAStream()); } @@ -2396,12 +2391,7 @@ struct DropoutState { event->record(); #if !defined(USE_ROCM) // See Note [DropoutState and CUDA graph capture] - cudaStreamCaptureStatus status; - AT_CUDA_CHECK(cudaStreamGetCaptureInfo( - cuda::getCurrentCUDAStream(), &status, &capture_id_last_unlock)); - if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) { - capture_id_last_unlock = 0; - } + capture_id_last_unlock = at::cuda::currentStreamCaptureId().value_or(0); TORCH_INTERNAL_ASSERT(capture_id_last_unlock == capture_id_last_lock); #endif } diff --git a/aten/src/ATen/native/hip/ck_group_gemm.hip b/aten/src/ATen/native/hip/ck_group_gemm.hip index 7b6014fbddb4d..aa53441af2156 100644 --- a/aten/src/ATen/native/hip/ck_group_gemm.hip +++ b/aten/src/ATen/native/hip/ck_group_gemm.hip @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -31,7 +32,54 @@ namespace CkTypes { } template -using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< +using GroupedGemmXdlSplit = ck::tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle< + ALayout, // Layout for A + BLayout, // Layout for B + ck::Tuple<>, // DsLayout + ck::tensor_layout::gemm::RowMajor, // ELayout/CLayout + DataType, // A type + DataType, // B type + CkTypes::F32, // Accumulator type + DataType, // CShuffle type + ck::Tuple<>, // Ds type + DataType, // E type + CkTypes::PassThrough, // Elementwise functors + CkTypes::PassThrough, + CkTypes::PassThrough, + ck::tensor_operation::device::GemmSpecialization::MNKPadding, + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<1,4,64,1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0,2,1,3>, // ABlockTransferThreadClusterArrangeOrder + S<0,2,1,3>, // ABlockTransferSrcAccessOrder + 3, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + 1, // ABlockLdsExtraM + S<1,4,64,1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0,2,1,3>, // BBlockTransferThreadClusterArrangeOrder + S<0,2,1,3>, // BBlockTransferSrcAccessOrder + 3, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1,32,1,8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4 // CDEBlockTransferScalarPerVector_NPerBlock +>; + +template +using GroupedGemmMultipleDSplit = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor, DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType, CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough, @@ -52,7 +100,6 @@ void launch_grouped_bgemm_ck_impl_dispatch( const std::optional& offs, at::Tensor& out) { - using DeviceOp = GroupedGemmKernel; using PassThrough = CkTypes::PassThrough; std::vector gemm_descs; @@ -381,27 +428,51 @@ void launch_grouped_bgemm_ck_impl_dispatch( // Initialize d_ptrs with the correct size std::vector> d_ptrs(p_a_ptrs.size()); - static DeviceOp gemm_instance; - auto argument = gemm_instance.MakeArgument( - p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs, - gemm_descs, PassThrough{}, PassThrough{}, PassThrough{} - ); - TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument), - "CK Group GEMM: argument unsupported (shape/strides/type config)"); - size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument); - size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument); - - void* gemm_arg_buf = c10::cuda::CUDACachingAllocator::raw_alloc(arg_buf_size); - void* ws_buf = c10::cuda::CUDACachingAllocator::raw_alloc(ws_size); - - gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf); - gemm_instance.SetWorkSpacePointer(&argument, ws_buf); - - auto invoker = gemm_instance.MakeInvoker(); - hipStream_t stream = c10::cuda::getCurrentCUDAStream(); - invoker.Run(argument, {stream}); - c10::cuda::CUDACachingAllocator::raw_delete(gemm_arg_buf); - c10::cuda::CUDACachingAllocator::raw_delete(ws_buf); + bool all_K_equal = true; + ck::index_t base_K = gemm_descs.empty() ? 0 : gemm_descs[0].K_; + for (const auto& desc : gemm_descs) { + if (desc.K_ != base_K) { + all_K_equal = false; + break; + } + } + + auto run_gemm = [](auto& gemm_instance, auto& argument) { + TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument), + "CK Group GEMM: argument unsupported (shape/strides/type config)"); + size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument); + size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument); + + void* gemm_arg_buf = c10::cuda::CUDACachingAllocator::raw_alloc(arg_buf_size); + void* ws_buf = c10::cuda::CUDACachingAllocator::raw_alloc(ws_size); + + gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf); + gemm_instance.SetWorkSpacePointer(&argument, ws_buf); + + auto invoker = gemm_instance.MakeInvoker(); + hipStream_t stream = c10::cuda::getCurrentCUDAStream(); + invoker.Run(argument, {stream}); + c10::cuda::CUDACachingAllocator::raw_delete(gemm_arg_buf); + c10::cuda::CUDACachingAllocator::raw_delete(ws_buf); + }; + + if (all_K_equal) { + using DeviceOp = GroupedGemmXdlSplit; + static DeviceOp gemm_instance; + auto argument = gemm_instance.MakeArgument( + p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs, + gemm_descs, PassThrough{}, PassThrough{}, PassThrough{} + ); + run_gemm(gemm_instance, argument); + } else { + using DeviceOp = GroupedGemmMultipleDSplit; + static DeviceOp gemm_instance; + auto argument = gemm_instance.MakeArgument( + p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs, + gemm_descs, PassThrough{}, PassThrough{}, PassThrough{} + ); + run_gemm(gemm_instance, argument); + } } void group_gemm_ck( diff --git a/aten/src/ATen/native/miopen/LossCTC_miopen.cpp b/aten/src/ATen/native/miopen/LossCTC_miopen.cpp index 21797e7537d59..cf5836b7bb296 100644 --- a/aten/src/ATen/native/miopen/LossCTC_miopen.cpp +++ b/aten/src/ATen/native/miopen/LossCTC_miopen.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #endif // TODO: Remove the condition on AT_ROCM_ENABLED entirely, diff --git a/aten/src/ATen/native/miopen/RNN_miopen.cpp b/aten/src/ATen/native/miopen/RNN_miopen.cpp index e9534ec9fb5ea..da4ac06ce02e5 100644 --- a/aten/src/ATen/native/miopen/RNN_miopen.cpp +++ b/aten/src/ATen/native/miopen/RNN_miopen.cpp @@ -743,7 +743,7 @@ std::tuple miopen_rnn_backward_input( dx = dx.transpose_(0, 1); } - return std::make_tuple(dx, dhx, dcx, workspace); + return std::make_tuple(std::move(dx), std::move(dhx), std::move(dcx), std::move(workspace)); } std::vector miopen_rnn_backward_weight( diff --git a/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp b/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp index cd87850fe9eb9..6b34c0c7d8b4a 100644 --- a/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp +++ b/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp @@ -23,7 +23,6 @@ Tensor& _sparse_mm_mkl_( #else TORCH_CHECK(false, "sparse_mm_mkl: ATen not compiled with MKL support"); #endif - return self; // for stopping compiler warnings. } } // namespace native diff --git a/aten/src/ATen/native/mkldnn/RNN.cpp b/aten/src/ATen/native/mkldnn/RNN.cpp index c93198aec33a2..63e1f5a9e1242 100644 --- a/aten/src/ATen/native/mkldnn/RNN.cpp +++ b/aten/src/ATen/native/mkldnn/RNN.cpp @@ -432,7 +432,7 @@ std::tuple mkldnn_rnn_la forward_hint.workspace_desc(), workspace.template data_ptr()); ideep::lstm_backward::compute(forward_hint, x, hx, cx, w1, w2, b, y, hy, cy, diff_y, diff_hy, diff_cy, mkldnn_workspace, diff_x, diff_hx, diff_cx, diff_w1, diff_w2, diff_b, reverse); auto diff_b2_ = at::clone(diff_b_); - return std::make_tuple(diff_x_, diff_w1_, diff_w2_, diff_b_, diff_b2_, diff_hx_, diff_cx_); + return std::make_tuple(std::move(diff_x_), std::move(diff_w1_), std::move(diff_w2_), std::move(diff_b_), std::move(diff_b2_), std::move(diff_hx_), std::move(diff_cx_)); } // MKLDNN RNN integration notes: @@ -530,7 +530,7 @@ static std::tuple mkldnn_rnn( if (batch_first) { output = output.transpose(0, 1); } - return std::make_tuple(output, hy, cy); + return std::make_tuple(std::move(output), std::move(hy), std::move(cy)); } //////////////////////////////////////////////////////////////////////////////// diff --git a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp index 7272fab02261a..c8b185510edfe 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp @@ -85,6 +85,30 @@ Tensor& addmm_out( " but got:", self.sizes()); + // Bypass OneDNN optimization path for float64 due to lack of full double + // precision support. + if (mat1.scalar_type() == at::kDouble) { + bool is_inplace = self.is_same(result); + bool is_beta_ne_zero = beta.to() != 0.0; + + Tensor self_copy; + if (is_inplace && is_beta_ne_zero) { + self_copy = self.clone(); + } + + onednn::matmul(result, mat1, mat2, Tensor(), true, onednn::Attr()); + + if (alpha.to() != 1.0) { + result.mul_(alpha); + } + + if (is_beta_ne_zero) { + result.add_(is_inplace ? self_copy : self, beta); + } + + return result; + } + // general case Tensor bias = Tensor(); onednn::Attr attr; @@ -223,6 +247,31 @@ Tensor& baddbmm_out( return result; } + // Bypass OneDNN optimization path for float64 due to lack of full double + // precision support. + if (batch1.scalar_type() == at::kDouble || + batch2.scalar_type() == at::kDouble) { + bool is_inplace = input.is_same(result); + bool is_beta_ne_zero = beta.to() != 0.0; + + Tensor input_copy; + if (is_inplace && is_beta_ne_zero) { + input_copy = input.clone(); + } + + onednn::matmul(result, batch1, batch2, Tensor(), true, onednn::Attr()); + + if (alpha.to() != 1.0) { + result.mul_(alpha); + } + + if (is_beta_ne_zero) { + result.add_(is_inplace ? input_copy : input, beta); + } + + return result; + } + // general case onednn::Attr attr; float beta_ = beta.to(); @@ -279,12 +328,6 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) { return result; } -Tensor bmm(const Tensor& self, const Tensor& batch2) { - auto result = at::empty({0}, self.options()); - at::native::xpu::bmm_out(self, batch2, result); - return result; -} - Tensor& addmv_out( const Tensor& self, const Tensor& mat, @@ -326,6 +369,17 @@ Tensor& addmv_out( } Tensor vec_v = vec.view({vec.size(0), 1}); + + bool is_float64 = + mat.scalar_type() == at::kDouble || vec.scalar_type() == at::kDouble; + bool is_inplace = self.is_same(out); + if (is_float64 && is_inplace) { + Tensor self_v_copy = self_v.clone(); + at::native::xpu::addmm_out(self_v_copy, mat, vec_v, beta, alpha, out); + out.resize_({mat.size(0)}); + return out; + } + at::native::xpu::addmm_out(self_v, mat, vec_v, beta, alpha, out); out.resize_({mat.size(0)}); return out; diff --git a/aten/src/ATen/native/mkldnn/xpu/Conv.cpp b/aten/src/ATen/native/mkldnn/xpu/Conv.cpp index eb8f154dedaf0..4f24c21f2b377 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Conv.cpp @@ -28,21 +28,11 @@ struct ConvParams { bool transposed{}; std::vector output_padding; int64_t groups{}; - bool benchmark{}; - bool deterministic{}; - bool is_strided() const; - bool is_dilated() const; - bool is_padded() const; bool is_output_padding_neg() const; - bool is_output_padding_big() const; bool is_padding_neg() const; bool is_stride_nonpos() const; void view1d_as_2d(); - bool use_cpu_depthwise3x3_winograd( - const at::Tensor& input, - const at::Tensor& weight) const; - bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const; }; std::ostream& operator<<(std::ostream& out, const ConvParams& params) { @@ -52,35 +42,10 @@ std::ostream& operator<<(std::ostream& out, const ConvParams& params) { << " dilation = " << IntArrayRef{params.dilation} << " transposed = " << params.transposed << " output_padding = " << IntArrayRef{params.output_padding} - << " groups = " << params.groups << " benchmark = " << params.benchmark - << " deterministic = " << params.deterministic << '}'; + << " groups = " << params.groups << '}'; return out; } -bool ConvParams::is_strided() const { - bool is_strided = false; - for (auto s : stride) { - is_strided |= (s != 1); - } - return is_strided; -} - -bool ConvParams::is_dilated() const { - bool is_dilated = false; - for (auto d : dilation) { - is_dilated |= (d != 1); - } - return is_dilated; -} - -bool ConvParams::is_padded() const { - bool is_padded = false; - for (auto p : padding) { - is_padded |= (p != 0); - } - return is_padded; -} - bool ConvParams::is_output_padding_neg() const { bool is_non_neg = false; for (auto p : output_padding) { @@ -89,15 +54,6 @@ bool ConvParams::is_output_padding_neg() const { return is_non_neg; } -bool ConvParams::is_output_padding_big() const { - bool is_big = false; - for (size_t i = 0; i < output_padding.size(); i++) { - is_big |= - (output_padding[i] >= stride[i] || output_padding[i] >= dilation[i]); - } - return is_big; -} - bool ConvParams::is_padding_neg() const { bool is_non_neg = false; for (auto p : padding) { @@ -123,30 +79,15 @@ void ConvParams::view1d_as_2d() { } } -bool ConvParams::use_cpu_depthwise3x3_winograd( - const at::Tensor& input, - const at::Tensor& weight) const { - return false; -} - -bool ConvParams::is_depthwise(const at::Tensor& input, const at::Tensor& weight) - const { - return !transposed && input.ndimension() == 4 && input.size(1) == groups && - groups > 1 && // no point if there is only a single group - weight.size(0) % input.size(1) == - 0; // output channels must be a multiple of input channels -} - static void check_shape_forward( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias, - const ConvParams& params, - bool input_is_mkldnn) { + const ConvParams& params) { int64_t k = input.ndimension(); int64_t weight_dim = weight.ndimension(); std::vector weight_sizes(weight_dim); - if ((weight_dim == k + 1) && input_is_mkldnn) { + if (weight_dim == k + 1) { weight_sizes[0] = weight.size(0) * weight.size(1); std::copy_n(weight.sizes().cbegin() + 2, k - 1, weight_sizes.begin() + 1); weight_dim = k; @@ -384,10 +325,11 @@ Tensor _convolution_out( at::MemoryFormat mfmt = is_channels_last_suggested ? get_cl_tag_by_ndim(input.ndimension()) : at::MemoryFormat::Contiguous; - auto bias = bias_r.defined() ? bias_r.contiguous() : bias_r; - input = input.contiguous(mfmt); - weight = weight.contiguous(mfmt); - check_shape_forward(input, weight, bias, params, true); + + auto bias = bias_r.defined() ? make_contiguous_and_aligned(bias_r) : bias_r; + input = make_contiguous_and_aligned(input, mfmt); + weight = make_contiguous_and_aligned(weight, mfmt); + check_shape_forward(input, weight, bias, params); Tensor output; if (transposed_) { @@ -591,9 +533,9 @@ std::tuple convolution_backward_overrideable( auto mfmt = is_channels_last_suggested ? get_cl_tag_by_ndim(input_.ndimension()) : at::MemoryFormat::Contiguous; - grad_output_ = grad_output_.contiguous(mfmt); - weight_ = weight_.contiguous(mfmt); - input_ = input_.contiguous(mfmt); + grad_output_ = make_contiguous_and_aligned(grad_output_, mfmt); + weight_ = make_contiguous_and_aligned(weight_, mfmt); + input_ = make_contiguous_and_aligned(input_, mfmt); auto opt = grad_output_.options(); Tensor grad_input; diff --git a/aten/src/ATen/native/mkldnn/xpu/FusionUtils.h b/aten/src/ATen/native/mkldnn/xpu/FusionUtils.h index b8b4e1cccf0ac..1ed479cd9a297 100644 --- a/aten/src/ATen/native/mkldnn/xpu/FusionUtils.h +++ b/aten/src/ATen/native/mkldnn/xpu/FusionUtils.h @@ -9,16 +9,6 @@ // namespace at::native::xpu { -at::native::onednn::Attr& unary_attr_with_arg( - onednn::Attr& attr, - std::string_view unary, - torch::List> scalars, - std::optional algorithm); - -at::native::onednn::Attr& string_to_unary_attr( - onednn::Attr& attr, - std::string_view unary); - at::native::onednn::Attr& construct_unary_attr( onednn::Attr& attr, std::string_view unary, diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp index f79dfadd65454..953b3b5a1b2e4 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp @@ -345,8 +345,14 @@ struct ScaleSpec { int64_t outer_dim, int64_t inner_dim, const std::string& arg_type) const { - if (groups == dnnl::memory::dims{1, 1}) + // TensorWise: mask=0, groups={} -> single scale + if (groups.empty()) { + TORCH_INTERNAL_ASSERT( + mask == 0, + "Empty groups only valid for TensorWise (mask=0), got mask=", + mask); return 1; // tensorwise scaling + } TORCH_CHECK( arg_type == "src" || arg_type == "wei", @@ -399,8 +405,8 @@ inline ScaleSpec make_scale_spec( if (scaling_type == at::blas::ScalingType::TensorWise) { // Scale tensorwise. The same as `--attr-scales=common`. // mask=0 : scale whole tensor - // groups={1, 1}: indicates that there is only one group for scaling - return {0, {1, 1}, dnnl::memory::data_type::f32}; + // groups={}: indicates that there is only one group for scaling + return {0, {}, dnnl::memory::data_type::f32}; } else { // (scaling_type == at::blas::ScalingType::RowWise) // Scale RowWise. The same as `--attr-scales=per_dim_01`. @@ -479,7 +485,7 @@ sycl::event scaled_matmul( // scale_result tensor currently only supports scalar(TensorWise Scaling). bool with_dst_scale = scale_result && scale_result->defined(); if (with_dst_scale) { - op_attr.set_scales(DNNL_ARG_DST, 0, {1}, dnnl::memory::data_type::f32); + op_attr.set_scales(DNNL_ARG_DST, 0, {}, dnnl::memory::data_type::f32); } op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp index a8a6b870ff6b6..111c271e6cea5 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.cpp @@ -131,36 +131,36 @@ dnnl::memory::desc get_onednn_md(const at::Tensor& tensor) { bool onednn_strides_check(const Tensor& src) { auto adims = get_onednn_dims(src); int ndims = (int)adims.size(); - auto dims = adims.data(); auto data_type = static_cast( get_onednn_dtype_include_double(src, /*allow_undef*/ false)); auto strides_info = get_onednn_strides(src); auto strides = strides_info.empty() ? nullptr : &strides_info[0]; dnnl_memory_desc_t md; - dnnl_memory_desc_create_with_strides(&md, ndims, dims, data_type, strides); + dnnl_memory_desc_create_with_strides( + &md, ndims, adims.data(), data_type, strides); dnnl_format_kind_t md_fmt_kind; int md_ndims = 0; int md_inner_nblks = 0; dnnl_dims_t* md_padded_dims = nullptr; - dnnl_memory_desc_query(md, dnnl_query_inner_nblks_s32, &md_inner_nblks); dnnl_memory_desc_query(md, dnnl_query_format_kind, &md_fmt_kind); dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &md_ndims); + dnnl_memory_desc_query(md, dnnl_query_inner_nblks_s32, &md_inner_nblks); dnnl_memory_desc_query(md, dnnl_query_padded_dims, &md_padded_dims); - auto block_size = 1; - // const auto& blk = md->format_desc.blocking; - dnnl_dims_t md_inner_blks; - dnnl_dims_t md_blk_inner_idxs; - dnnl_memory_desc_query(md, dnnl_query_inner_idxs, &md_blk_inner_idxs); - dnnl_memory_desc_query(md, dnnl_query_inner_blks, &md_inner_blks); dnnl_memory_desc_destroy(md); if (strides == nullptr || md_ndims == 0 || md_fmt_kind != dnnl_format_kind_t::dnnl_blocked) return true; - dnnl_dims_t blocks = {0}; + // XPU does not support inner-block formats (e.g. nChw16c); + TORCH_INTERNAL_ASSERT( + md_inner_nblks == 0, + "XPU backend does not support block format. But found inner blocks: ", + md_inner_nblks); + + // Plain blocked format: verify strides are non-overlapping. std::array perm = {0}; for (int d = 0; d < md_ndims; ++d) { // no strides check needed for empty tensor @@ -172,12 +172,6 @@ bool onednn_strides_check(const Tensor& src) { return true; perm[d] = d; - blocks[d] = 1; - } - - for (int iblk = 0; iblk < md_inner_nblks; ++iblk) { - blocks[md_blk_inner_idxs[iblk]] *= md_inner_blks[iblk]; - block_size *= md_inner_blks[iblk]; } // A custom comparator to yield linear order on perm @@ -192,7 +186,7 @@ bool onednn_strides_check(const Tensor& src) { }; std::sort(perm.begin(), perm.begin() + md_ndims, idx_sorter); - auto min_stride = block_size; + int64_t min_stride = 1; for (int idx = 0; idx < md_ndims; ++idx) { const int d = perm[idx]; @@ -204,8 +198,7 @@ bool onednn_strides_check(const Tensor& src) { return false; // update min_stride for next iteration - const auto padded_dim = (*md_padded_dims)[d]; - min_stride = block_size * strides[d] * (padded_dim / blocks[d]); + min_stride = strides[d] * (*md_padded_dims)[d]; } return true; @@ -287,6 +280,27 @@ void undo_broadcast(at::Tensor& tensor) { return; } +bool is_64_bytes_aligned(const at::Tensor& tensor) { + constexpr uintptr_t alignment_byte = 64; + auto data_ptr = reinterpret_cast(tensor.const_data_ptr()); + return (data_ptr % alignment_byte) == 0; +} + +at::Tensor make_contiguous_and_aligned( + const at::Tensor& tensor, + std::optional memory_format) { + at::Tensor out = memory_format.has_value() ? tensor.contiguous(*memory_format) + : tensor.contiguous(); + if (!is_64_bytes_aligned(out)) { + TORCH_WARN( + "Tensor is not 64-byte aligned. Cloning to ensure alignment for oneDNN " + "operations, which incurs a device-to-device copy."); + out = out.clone(); + } + + return out; +} + bool is_onednn_matmul_strides(const at::Tensor& tensor) { // https://oneapi-src.github.io/oneDNN/dev_guide_matmul.html // oneDNN matmul only support 2-dim and 3-dim @@ -300,11 +314,8 @@ bool is_onednn_matmul_strides(const at::Tensor& tensor) { if (tensor.is_contiguous()) return true; - if (tensor.storage_offset() > 0) { - // currently onednn asks 64 byte alignment - constexpr int alignment_byte = 64; - if (reinterpret_cast(tensor.data_ptr()) % alignment_byte > 0) - return false; + if (tensor.storage_offset() > 0 && !is_64_bytes_aligned(tensor)) { + return false; } // the overlapped cases are not supported diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h index 0055fc2f296ad..e841ab8de6583 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/Utils.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -46,6 +47,12 @@ void undo_broadcast(at::Tensor& tensor); bool is_onednn_matmul_strides(const at::Tensor& tensor); +bool is_64_bytes_aligned(const at::Tensor& tensor); + +at::Tensor make_contiguous_and_aligned( + const at::Tensor& tensor, + std::optional memory_format = std::nullopt); + bool is_broadcast_from_other_to_self( const at::Tensor& self, const at::Tensor& other); diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h index d919e87bfa586..f28aaa131e51e 100644 --- a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h +++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNNContext.h @@ -13,7 +13,7 @@ namespace at::native::onednn { -TORCH_XPU_API dnnl::memory make_onednn_memory( +dnnl::memory make_onednn_memory( dnnl::memory::desc md, dnnl::engine& engine, void* ptr); @@ -22,7 +22,7 @@ TORCH_XPU_API dnnl::memory make_onednn_memory( bool set_onednn_verbose(int level); // GpuEngineManager singleton -struct TORCH_XPU_API GpuEngineManager { +struct GpuEngineManager { static GpuEngineManager& Instance(); // Singleton dnnl::engine& get_engine( @@ -50,7 +50,7 @@ struct TORCH_XPU_API GpuEngineManager { }; // GpuStreamManager singleton -struct TORCH_XPU_API GpuStreamManager { +struct GpuStreamManager { static GpuStreamManager& Instance(); // Singleton dnnl::stream& get_stream( diff --git a/aten/src/ATen/native/mps/MetalShaderLibrary.h b/aten/src/ATen/native/mps/MetalShaderLibrary.h index c39e1e5928df0..b7222dca6a27c 100644 --- a/aten/src/ATen/native/mps/MetalShaderLibrary.h +++ b/aten/src/ATen/native/mps/MetalShaderLibrary.h @@ -141,7 +141,8 @@ class MetalShaderLibrary { TensorIteratorBase& iter, const std::string& name, const std::optional alpha = std::nullopt, - const std::optional scalar_arg_type = std::nullopt); + const std::optional scalar_arg_type = std::nullopt, + bool supports_vec4 = false); void exec_binary_kernel( TensorIteratorBase& iter, const std::string& name, @@ -195,4 +196,11 @@ class DynamicMetalShaderLibrary : public MetalShaderLibrary { ~DynamicMetalShaderLibrary() override; }; +class PrecompiledMetalShaderLibrary : public MetalShaderLibrary { + public: + explicit PrecompiledMetalShaderLibrary(std::vector data); + explicit PrecompiledMetalShaderLibrary(const std::string& path); + ~PrecompiledMetalShaderLibrary() override; +}; + } // namespace at::native::mps diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 7b76da64200c9..112f22742ac13 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -28,16 +28,6 @@ #include -@interface MPSGraph (PyTorchFixups) -- (MPSGraphTensor*)minimumWithNaNPropagationAndIntFallbackWithPrimaryTensor:(MPSGraphTensor*)primaryTensor - secondaryTensor:(MPSGraphTensor*)secondaryTensor - name:(NSString*)name; - -- (MPSGraphTensor*)maximumWithNaNPropagationAndIntFallbackWithPrimaryTensor:(MPSGraphTensor*)primaryTensor - secondaryTensor:(MPSGraphTensor*)secondaryTensor - name:(NSString*)name; -@end - using namespace at::mps; namespace at::native::mps { @@ -54,6 +44,7 @@ struct MPSScalar { float f; // MPS doesn't support 'double' at::Half h; int64_t i; + uint64_t u; bool b; c10::complex cf; c10::complex ch; @@ -260,7 +251,7 @@ struct MPSKernelCache { __block MPSCachedKernel* cachedKernel = nil; MPSCacheKey hash = std::hash{}(key); dispatch_sync_with_rethrow(serialQueue_, ^() { - if (cache_.count(hash) != 0) { + if (cache_.contains(hash)) { auto& entry = cache_.at(hash); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached kernel!\n"); cachedKernel = entry.cachedKernel_; diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index ec7e95361cd9a..994abf2ef59cb 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -25,34 +25,6 @@ #include #include -@implementation MPSGraph (PyTorchFixups) -- (MPSGraphTensor*)minimumWithNaNPropagationAndIntFallbackWithPrimaryTensor:(MPSGraphTensor*)primaryTensor - secondaryTensor:(MPSGraphTensor*)secondaryTensor - name:(NSString*)name { - // As of MacOS-15.1 m..imumWithNanPropagation is only defined for floating types and calling it with integral - // arguments results in - // /AppleInternal/Library/BuildRoots/c7c74b64-74b4-11ef-aeda-9635a580fe0d/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Utility/MPSKernelDAG.mm:805: - // failed assertion `Error getting visible function: (null) Function isNaN_u8_i8 was not found in the library' - if (([primaryTensor dataType] & MPSDataTypeFloatBit) == 0) { - return [self minimumWithPrimaryTensor:primaryTensor secondaryTensor:secondaryTensor name:name]; - } - return [self minimumWithNaNPropagationWithPrimaryTensor:primaryTensor secondaryTensor:secondaryTensor name:name]; -} - -- (MPSGraphTensor*)maximumWithNaNPropagationAndIntFallbackWithPrimaryTensor:(MPSGraphTensor*)primaryTensor - secondaryTensor:(MPSGraphTensor*)secondaryTensor - name:(NSString*)name { - // As of MacOS-15.1 m..imumWithNanPropagation is only defined for floating types and calling it with integral - // arguments results in - // /AppleInternal/Library/BuildRoots/c7c74b64-74b4-11ef-aeda-9635a580fe0d/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Utility/MPSKernelDAG.mm:805: - // failed assertion `Error getting visible function: (null) Function isNaN_u8_i8 was not found in the library' - if (([primaryTensor dataType] & MPSDataTypeFloatBit) == 0) { - return [self maximumWithPrimaryTensor:primaryTensor secondaryTensor:secondaryTensor name:name]; - } - return [self maximumWithNaNPropagationWithPrimaryTensor:primaryTensor secondaryTensor:secondaryTensor name:name]; -} -@end - namespace at::native::mps { /** * Computes distance from lowest to highest element offset in given tensor. @@ -470,8 +442,6 @@ bool isTooLargeForMPSGraph(const Tensor& tensor, bool useMPSStridedAPI) { auto sizes = src.sizes(); auto nStrides = strides.size(); auto nonZeroStrides = src.strides(); - int64_t crtNonZeroStride = 1; - bool hasZeroStrides = false; auto sortedStridesIndices = getSortedStrides(nonZeroStrides); NSMutableArray* sortedStridesShape = [NSMutableArray arrayWithCapacity:nStrides]; @@ -673,6 +643,8 @@ MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type) { case ScalarType::ComplexDouble: return {.size = sizeof(int64_t), .type = type, .value.cf = scalar.to>()}; // Unsigned types + case ScalarType::UInt64: + return {.size = sizeof(uint64_t), .type = type, .value.u = scalar.to()}; case ScalarType::UInt32: return {.size = sizeof(uint32_t), .type = type, .value.i = scalar.to()}; case ScalarType::UInt16: @@ -959,16 +931,21 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} return raw_ptr; } -class BundledShaderLibary : public MetalShaderLibrary { +class BundledShaderLibrary : public MetalShaderLibrary { public: - BundledShaderLibary() : MetalShaderLibrary("") {} + BundledShaderLibrary() : MetalShaderLibrary("") {} protected: id getLibrary() override { if (C10_UNLIKELY(!library)) { auto device = MPSDevice::getInstance()->device(); NSError* error = nil; - library = [device newLibraryWithData:getSectionData("metal_basic") error:&error]; +#ifdef CAN_BUILD_METAL_4 + const auto section_name = is_macos_13_or_newer(MacOSVersion::MACOS_VER_26_0_PLUS) ? "metal_40" : "metal_basic"; +#else + const auto section_name = "metal_basic"; +#endif + library = [device newLibraryWithData:getSectionData(section_name) error:&error]; TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]); } return library; @@ -1004,11 +981,12 @@ static dispatch_data_t getSectionData(const std::string& name) { void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter, const std::string& name, std::optional alpha, - std::optional scalar_arg_type) { + std::optional scalar_arg_type, + bool supports_vec4) { // Decompose 64-bit tensor into 32-bit ones if (!iter.can_use_32bit_indexing()) { for (auto&& sub_iter : iter.with_32bit_indexing()) { - exec_unary_kernel(sub_iter, name, alpha, scalar_arg_type); + exec_unary_kernel(sub_iter, name, alpha, scalar_arg_type, supports_vec4); } return; } @@ -1020,10 +998,12 @@ static dispatch_data_t getSectionData(const std::string& name) { return; } using namespace mps; + bool use_vec4 = + supports_vec4 && iter.is_contiguous() && !alpha.has_value() && at::isFloatingType(iter.common_dtype()); const auto alpha_type = scalar_arg_type.has_value() ? scalar_arg_type.value() : iter.common_dtype(); auto kernel_name = fmt::format("{}_{}_{}_{}{}", name, - iter.is_contiguous() ? "dense" : "strided", + use_vec4 ? "dense_vec4" : (iter.is_contiguous() ? "dense" : "strided"), scalarToMetalTypeString(outputTensor), scalarToMetalTypeString(inputTensor), alpha.has_value() ? fmt::format("_{}", scalarToMetalTypeString(alpha_type)) : ""); @@ -1038,17 +1018,22 @@ static dispatch_data_t getSectionData(const std::string& name) { [computeEncoder setComputePipelineState:cplState]; bind_iter_tensors(computeEncoder, iter); - if (!iter.is_contiguous()) { - mtl_setArgs<2>(computeEncoder, - outputTensor.sizes(), - inputTensor.strides(), - outputTensor.strides(), - inputTensor.ndimension()); - } - if (alpha) { - mtl_setBytes(computeEncoder, getMPSScalar(*alpha, alpha_type), iter.is_contiguous() ? 2 : 6); + if (use_vec4) { + mtl_setBytes(computeEncoder, length, 2); + mtl_dispatch1DJob(computeEncoder, cplState, (length + 3) / 4); + } else { + if (!iter.is_contiguous()) { + mtl_setArgs<2>(computeEncoder, + outputTensor.sizes(), + inputTensor.strides(), + outputTensor.strides(), + inputTensor.ndimension()); + } + if (alpha) { + mtl_setBytes(computeEncoder, getMPSScalar(*alpha, alpha_type), iter.is_contiguous() ? 2 : 6); + } + mtl_dispatch1DJob(computeEncoder, cplState, length); } - mtl_dispatch1DJob(computeEncoder, cplState, length); getMPSProfiler().endProfileKernel(cplState); }); @@ -1129,7 +1114,8 @@ static dispatch_data_t getSectionData(const std::string& name) { } } - const auto alpha_type = scalar_arg_type.has_value() ? scalar_arg_type.value() : iter.common_dtype(); + const auto alpha_type = + scalar_arg_type.has_value() ? scalar_arg_type.value() : (cast_needed ? out.scalar_type() : iter.common_dtype()); const auto alpha_suffix = alpha.has_value() ? fmt::format("_{}", scalarToMetalTypeString(alpha_type)) : ""; std::string kernel_name; @@ -1342,7 +1328,7 @@ static dispatch_data_t getSectionData(const std::string& name) { } MetalShaderLibrary& MetalShaderLibrary::getBundledLibrary() { - static BundledShaderLibary l; + static BundledShaderLibrary l; return l; } @@ -1351,6 +1337,33 @@ static dispatch_data_t getSectionData(const std::string& name) { [library release]; } +// PrecompiledMetalShaderLibrary implementation +PrecompiledMetalShaderLibrary::PrecompiledMetalShaderLibrary(std::vector data) : MetalShaderLibrary("") { + auto device = MPSDevice::getInstance()->device(); + NSError* error = nil; + dispatch_data_t dd = + dispatch_data_create(data.data(), data.size(), dispatch_get_main_queue(), DISPATCH_DATA_DESTRUCTOR_DEFAULT); + library = [device newLibraryWithData:dd error:&error]; + dispatch_release(dd); + TORCH_CHECK(library, "Failed to load metallib: ", error ? [[error description] UTF8String] : "unknown error"); +} + +PrecompiledMetalShaderLibrary::PrecompiledMetalShaderLibrary(const std::string& path) : MetalShaderLibrary("") { + auto device = MPSDevice::getInstance()->device(); + NSError* error = nil; + NSURL* url = [NSURL fileURLWithPath:[NSString stringWithUTF8String:path.c_str()]]; + library = [device newLibraryWithURL:url error:&error]; + TORCH_CHECK(library, + "Failed to load metallib from '", + path, + "': ", + error ? [[error description] UTF8String] : "unknown error"); +} + +PrecompiledMetalShaderLibrary::~PrecompiledMetalShaderLibrary() { + [library release]; +} + // MetalKernelFunction implementation MetalKernelFunction::MetalKernelFunction(MTLComputePipelineState_t cps_, MTLFunction_t f_) : cps([cps_ retain]), func([f_ retain]) {} diff --git a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal index 7d1f3aa5bacf6..dc892dceea2d0 100644 --- a/aten/src/ATen/native/mps/kernels/ActivationKernel.metal +++ b/aten/src/ATen/native/mps/kernels/ActivationKernel.metal @@ -44,6 +44,27 @@ REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float); REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half); REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat); +struct relu_functor { + template + inline T operator()(const T x) { + return x > T(0) ? x : T(0); + } +}; + +REGISTER_UNARY_OP(relu, float, float); +REGISTER_UNARY_OP(relu, half, half); +REGISTER_UNARY_OP(relu, bfloat, bfloat); +REGISTER_UNARY_OP(relu, long, long); +REGISTER_UNARY_OP(relu, int, int); +REGISTER_UNARY_OP(relu, short, short); +REGISTER_UNARY_OP(relu, char, char); +REGISTER_UNARY_OP(relu, uchar, uchar); +REGISTER_UNARY_OP(relu, bool, bool); + +REGISTER_UNARY_VEC4_OP(relu, float, float); +REGISTER_UNARY_VEC4_OP(relu, half, half); +REGISTER_UNARY_VEC4_OP(relu, bfloat, bfloat); + struct hardsigmoid_functor { template inline T operator()(const T x) { @@ -180,3 +201,37 @@ REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat); REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float); REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half); REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat); + +struct silu_functor { + template + inline T operator()(const T x) { + float xf = float(x); + return static_cast(xf / (1.0f + ::metal::precise::exp(-xf))); + } +}; + +REGISTER_UNARY_OP(silu, float, float); +REGISTER_UNARY_OP(silu, half, half); +REGISTER_UNARY_OP(silu, bfloat, bfloat); +REGISTER_UNARY_OP(silu, int, int); +REGISTER_UNARY_OP(silu, short, short); +REGISTER_UNARY_OP(silu, char, char); +REGISTER_UNARY_OP(silu, uchar, uchar); +REGISTER_UNARY_OP(silu, bool, bool); + +REGISTER_UNARY_VEC4_OP(silu, float, float); +REGISTER_UNARY_VEC4_OP(silu, half, half); +REGISTER_UNARY_VEC4_OP(silu, bfloat, bfloat); + +struct silu_backward_functor { + template + inline T operator()(const T grad_output, const T self) { + float sf = float(self); + float sig = 1.0f / (1.0f + ::metal::precise::exp(-sf)); + return static_cast(float(grad_output) * sig * (1.0f + sf - sf * sig)); + } +}; + +REGISTER_BINARY_OP(silu_backward, float, float); +REGISTER_BINARY_OP(silu_backward, half, half); +REGISTER_BINARY_OP(silu_backward, bfloat, bfloat); diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal index 15d259cc9b27a..1def63fff16ae 100644 --- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal +++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal @@ -63,14 +63,14 @@ struct fmin_functor { struct maximum_functor { template inline T operator()(const T a, const T b) { - return max(a, b); + return c10::metal::max(a, b); } }; struct minimum_functor { template inline T operator()(const T a, const T b) { - return min(a, b); + return c10::metal::min(a, b); } }; @@ -122,6 +122,20 @@ struct logaddexp2_functor { } }; +struct xlogy_functor { + template , bool> = true> + inline T operator()(const T a, const T b) { + return static_cast(c10::metal::xlogy(a, b)); + } + template , bool> = true> + inline float operator()(const T a, const T b) { + return c10::metal::xlogy(float(a), float(b)); + } + inline float operator()(const bool a, const bool b) { + return (a && !b) ? -INFINITY : 0; + } +}; + struct xlog1py_functor { template , bool> = true> inline T operator()(const T a, const T b) { @@ -449,6 +463,8 @@ REGISTER_FLOAT_BINARY_OP(logaddexp); REGISTER_INT2FLOAT_BINARY_OP(logaddexp); REGISTER_FLOAT_BINARY_OP(logaddexp2); REGISTER_INT2FLOAT_BINARY_OP(logaddexp2); +REGISTER_FLOAT_BINARY_OP(xlogy); +REGISTER_INT2FLOAT_BINARY_OP(xlogy); REGISTER_FLOAT_BINARY_OP(xlog1py); REGISTER_INT2FLOAT_BINARY_OP(xlog1py); REGISTER_FLOAT_BINARY_OP(chebyshev_polynomial_t); @@ -544,3 +560,176 @@ REGISTER_BINARY_ALPHA_OP(sub_alpha, float2, float2, float2); REGISTER_BINARY_ALPHA_OP(sub_alpha, half2, half2, half2); REGISTER_BINARY_ALPHA_OP(lerp_alpha, float2, float2, float2); REGISTER_BINARY_ALPHA_OP(lerp_alpha, half2, half2, half2); + +// lerp with tensor weight: lerp(s, e, w) = fma(w, e - s, s) +template +inline T lerp_op(T s, T e, T w) { + return fma(w, e - s, s); +} + +inline bfloat lerp_op(bfloat s, bfloat e, bfloat w) { + return static_cast(fma(float(w), float(e) - float(s), float(s))); +} + +inline long lerp_op(long s, long e, long w) { + return s + w * (e - s); +} + +inline float2 lerp_op(float2 s, float2 e, float2 w) { + return s + mul(w, e - s); +} + +template +kernel void lerp_tensor_dense( + device T* out [[buffer(0)]], + device const T* self [[buffer(1)]], + device const T* end [[buffer(2)]], + device const T* weight [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + out[tid] = lerp_op(self[tid], end[tid], weight[tid]); +} + +// Scalar weight broadcast: self/end/out contiguous, weight is a single element +template +kernel void lerp_tensor_scalar_weight( + device T* out [[buffer(0)]], + device const T* self [[buffer(1)]], + device const T* end [[buffer(2)]], + device const T& weight [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + out[tid] = lerp_op(self[tid], end[tid], weight); +} + +// 2D strided: coordinates from 2D dispatch, no integer division +template +kernel void lerp_tensor_strided_2d( + device void* out_ptr [[buffer(0)]], + constant void* self_ptr [[buffer(1)]], + constant void* end_ptr [[buffer(2)]], + constant void* weight_ptr [[buffer(3)]], + constant long* out_strides [[buffer(4)]], + constant long* self_strides [[buffer(5)]], + constant long* end_strides [[buffer(6)]], + constant long* weight_strides [[buffer(7)]], + uint2 tid [[thread_position_in_grid]]) { + int out_off = + int(tid.x) * int(out_strides[0]) + int(tid.y) * int(out_strides[1]); + int self_off = + int(tid.x) * int(self_strides[0]) + int(tid.y) * int(self_strides[1]); + int end_off = + int(tid.x) * int(end_strides[0]) + int(tid.y) * int(end_strides[1]); + int wt_off = + int(tid.x) * int(weight_strides[0]) + int(tid.y) * int(weight_strides[1]); + ref_at_offs(out_ptr, long(out_off)) = lerp_op( + val_at_offs(self_ptr, long(self_off)), + val_at_offs(end_ptr, long(end_off)), + val_at_offs(weight_ptr, long(wt_off))); +} + +// 3D strided: coordinates from 3D dispatch, no integer division +template +kernel void lerp_tensor_strided_3d( + device void* out_ptr [[buffer(0)]], + constant void* self_ptr [[buffer(1)]], + constant void* end_ptr [[buffer(2)]], + constant void* weight_ptr [[buffer(3)]], + constant long* out_strides [[buffer(4)]], + constant long* self_strides [[buffer(5)]], + constant long* end_strides [[buffer(6)]], + constant long* weight_strides [[buffer(7)]], + uint3 tid [[thread_position_in_grid]]) { + int out_off = int(tid.x) * int(out_strides[0]) + + int(tid.y) * int(out_strides[1]) + int(tid.z) * int(out_strides[2]); + int self_off = int(tid.x) * int(self_strides[0]) + + int(tid.y) * int(self_strides[1]) + int(tid.z) * int(self_strides[2]); + int end_off = int(tid.x) * int(end_strides[0]) + + int(tid.y) * int(end_strides[1]) + int(tid.z) * int(end_strides[2]); + int wt_off = int(tid.x) * int(weight_strides[0]) + + int(tid.y) * int(weight_strides[1]) + int(tid.z) * int(weight_strides[2]); + ref_at_offs(out_ptr, long(out_off)) = lerp_op( + val_at_offs(self_ptr, long(self_off)), + val_at_offs(end_ptr, long(end_off)), + val_at_offs(weight_ptr, long(wt_off))); +} + +template +kernel void lerp_tensor_strided( + device void* out_ptr [[buffer(0)]], + constant void* self_ptr [[buffer(1)]], + constant void* end_ptr [[buffer(2)]], + constant void* weight_ptr [[buffer(3)]], + constant long* sizes [[buffer(4)]], + constant long* out_strides [[buffer(5)]], + constant long* self_strides [[buffer(6)]], + constant long* end_strides [[buffer(7)]], + constant long* weight_strides [[buffer(8)]], + constant uint& ndim [[buffer(9)]], + uint tid [[thread_position_in_grid]]) { + int pos[max_ndim]; + pos_from_thread_index(int(tid), pos, sizes, ndim); + auto self_off = offset_from_coord(pos, self_strides, ndim); + auto end_off = offset_from_coord(pos, end_strides, ndim); + auto weight_off = offset_from_coord(pos, weight_strides, ndim); + auto out_off = offset_from_coord(pos, out_strides, ndim); + ref_at_offs(out_ptr, out_off) = lerp_op( + val_at_offs(self_ptr, self_off), + val_at_offs(end_ptr, end_off), + val_at_offs(weight_ptr, weight_off)); +} + +#define INSTANTIATE_LERP(DTYPE) \ + template [[host_name("lerp_tensor_dense_" #DTYPE)]] kernel void \ + lerp_tensor_dense( \ + device DTYPE * out [[buffer(0)]], \ + device const DTYPE* self [[buffer(1)]], \ + device const DTYPE* end [[buffer(2)]], \ + device const DTYPE* weight [[buffer(3)]], \ + uint tid [[thread_position_in_grid]]); \ + template [[host_name("lerp_tensor_scalar_weight_" #DTYPE)]] kernel void \ + lerp_tensor_scalar_weight( \ + device DTYPE * out [[buffer(0)]], \ + device const DTYPE* self [[buffer(1)]], \ + device const DTYPE* end [[buffer(2)]], \ + device const DTYPE& weight [[buffer(3)]], \ + uint tid [[thread_position_in_grid]]); \ + template [[host_name("lerp_tensor_strided_2d_" #DTYPE)]] kernel void \ + lerp_tensor_strided_2d( \ + device void* out_ptr [[buffer(0)]], \ + constant void* self_ptr [[buffer(1)]], \ + constant void* end_ptr [[buffer(2)]], \ + constant void* weight_ptr [[buffer(3)]], \ + constant long* out_strides [[buffer(4)]], \ + constant long* self_strides [[buffer(5)]], \ + constant long* end_strides [[buffer(6)]], \ + constant long* weight_strides [[buffer(7)]], \ + uint2 tid [[thread_position_in_grid]]); \ + template [[host_name("lerp_tensor_strided_3d_" #DTYPE)]] kernel void \ + lerp_tensor_strided_3d( \ + device void* out_ptr [[buffer(0)]], \ + constant void* self_ptr [[buffer(1)]], \ + constant void* end_ptr [[buffer(2)]], \ + constant void* weight_ptr [[buffer(3)]], \ + constant long* out_strides [[buffer(4)]], \ + constant long* self_strides [[buffer(5)]], \ + constant long* end_strides [[buffer(6)]], \ + constant long* weight_strides [[buffer(7)]], \ + uint3 tid [[thread_position_in_grid]]); \ + template [[host_name("lerp_tensor_strided_" #DTYPE)]] kernel void \ + lerp_tensor_strided( \ + device void* out_ptr [[buffer(0)]], \ + constant void* self_ptr [[buffer(1)]], \ + constant void* end_ptr [[buffer(2)]], \ + constant void* weight_ptr [[buffer(3)]], \ + constant long* sizes [[buffer(4)]], \ + constant long* out_strides [[buffer(5)]], \ + constant long* self_strides [[buffer(6)]], \ + constant long* end_strides [[buffer(7)]], \ + constant long* weight_strides [[buffer(8)]], \ + constant uint& ndim [[buffer(9)]], \ + uint tid [[thread_position_in_grid]]); + +INSTANTIATE_LERP(float); +INSTANTIATE_LERP(half); +INSTANTIATE_LERP(bfloat); +INSTANTIATE_LERP(float2); +INSTANTIATE_LERP(long); diff --git a/aten/src/ATen/native/mps/kernels/ConstantKernel.metal b/aten/src/ATen/native/mps/kernels/ConstantKernel.metal new file mode 100644 index 0000000000000..1b4c1277b6448 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/ConstantKernel.metal @@ -0,0 +1,102 @@ +#include +#include + +using namespace metal; + +template +kernel void fill_scalar_dense( + device T* out [[buffer(0)]], + constant T& fill_val [[buffer(1)]], + constant uint& numel [[buffer(2)]], + uint index [[thread_position_in_grid]]) { + if (index < numel) + out[index] = fill_val; +} + +// Single-byte types: each thread fills 4 elements via vec writes. +template +kernel void fill_scalar_dense_vec4( + device T* out [[buffer(0)]], + constant T& fill_val [[buffer(1)]], + constant uint& numel [[buffer(2)]], + uint index [[thread_position_in_grid]]) { + uint base = index * 4; + if (base + 4 <= numel) { + *(device vec*)(out + base) = vec(fill_val); + } else { + for (uint i = base; i < numel; i++) + out[i] = fill_val; + } +} + +#define REGISTER_FILL_OP(T) \ + template [[host_name("fill_scalar_dense_" #T)]] kernel void \ + fill_scalar_dense(device T*, constant T&, constant uint&, uint) + +REGISTER_FILL_OP(float); +REGISTER_FILL_OP(half); +REGISTER_FILL_OP(bfloat); +REGISTER_FILL_OP(long); +REGISTER_FILL_OP(ulong); +REGISTER_FILL_OP(int); +REGISTER_FILL_OP(uint); +REGISTER_FILL_OP(short); +REGISTER_FILL_OP(ushort); +REGISTER_FILL_OP(float2); +REGISTER_FILL_OP(half2); + +// Byte types use the vec4 variant under the same naming convention. +#define REGISTER_FILL_BYTE_OP(T) \ + template [[host_name("fill_scalar_dense_" #T)]] kernel void \ + fill_scalar_dense_vec4(device T*, constant T&, constant uint&, uint) + +REGISTER_FILL_BYTE_OP(char); +REGISTER_FILL_BYTE_OP(uchar); +REGISTER_FILL_BYTE_OP(bool); + +// 2D dispatch: tid.y = dim-0 index (no division), tid.x = linear index for +// dims 1..ndim-1. For an N-dim tensor this requires N-1 divisions instead of N, +// and consecutive threads in x access consecutive addresses in the innermost +// dimension (coalesced writes). +// Strides from TensorIterator are in bytes, use ref_at_offs(out+offs) +template +kernel void fill_scalar_strided( + device void* out [[buffer(0)]], + constant T& fill_val [[buffer(1)]], + constant long* sizes [[buffer(2)]], + constant long* strides [[buffer(3)]], + constant uint& ndim [[buffer(4)]], + uint2 tid [[thread_position_in_grid]]) { + long offset = long(tid.y) * strides[0]; + uint inner = tid.x; + for (uint i = 1; i < ndim; i++) { + offset += long(inner % uint(sizes[i])) * strides[i]; + inner /= uint(sizes[i]); + } + c10::metal::ref_at_offs(out, offset) = fill_val; +} + +#define REGISTER_FILL_STRIDED_OP(T) \ + template [[host_name("fill_scalar_strided_" #T)]] kernel void \ + fill_scalar_strided( \ + device void*, \ + constant T&, \ + constant long*, \ + constant long*, \ + constant uint&, \ + uint2) + +REGISTER_FILL_STRIDED_OP(float); +REGISTER_FILL_STRIDED_OP(half); +REGISTER_FILL_STRIDED_OP(bfloat); +REGISTER_FILL_STRIDED_OP(long); +REGISTER_FILL_STRIDED_OP(ulong); +REGISTER_FILL_STRIDED_OP(int); +REGISTER_FILL_STRIDED_OP(uint); +REGISTER_FILL_STRIDED_OP(short); +REGISTER_FILL_STRIDED_OP(ushort); +REGISTER_FILL_STRIDED_OP(char); +REGISTER_FILL_STRIDED_OP(uchar); +REGISTER_FILL_STRIDED_OP(bool); +REGISTER_FILL_STRIDED_OP(float2); +REGISTER_FILL_STRIDED_OP(half2); diff --git a/aten/src/ATen/native/mps/kernels/Distributions.metal b/aten/src/ATen/native/mps/kernels/Distributions.metal index 6c9a2722f378d..2c8a9404ab533 100644 --- a/aten/src/ATen/native/mps/kernels/Distributions.metal +++ b/aten/src/ATen/native/mps/kernels/Distributions.metal @@ -1,4 +1,5 @@ #include +#include #include using namespace metal; @@ -49,6 +50,26 @@ kernel void geometric( output[tid] = static_cast(result); } +template +kernel void exponential( + device T* output [[buffer(0)]], + constant float2& params [[buffer(1)]], + constant long2& seed_base_offset [[buffer(2)]], + constant uint& numel [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + uint base = tid * 4; + uint4 raw = + c10::metal::philox4::rand(seed_base_offset.x, seed_base_offset.y + tid); + float lambda = params.x; + uint count = min(4u, numel - base); + for (uint i = 0; i < count; ++i) { + float u = clamp( + c10::metal::detail::uint32_to_uniform_float(raw[i]), eps, 1.0f - eps); + output[base + i] = + static_cast(-::metal::precise::log(1.0f - u) / lambda); + } +} + #define REGISTER_OP(NAME, DTYPE) \ template [[host_name(#NAME "_" #DTYPE)]] kernel void NAME( \ device DTYPE*, constant float2&, constant long2&, uint) @@ -69,3 +90,190 @@ REGISTER_OP(geometric, long); REGISTER_OP(geometric, short); REGISTER_OP(geometric, char); REGISTER_OP(geometric, uchar); + +#define REGISTER_EXPONENTIAL(DTYPE) \ + template [[host_name("exponential_" #DTYPE)]] kernel void \ + exponential( \ + device DTYPE*, constant float2&, constant long2&, constant uint&, uint) + +REGISTER_EXPONENTIAL(float); +REGISTER_EXPONENTIAL(half); +REGISTER_EXPONENTIAL(bfloat); + +// Marsaglia & Tsang (2000) acceptance-rejection method for Gamma distribution. +// Adapted from aten/src/ATen/native/Distributions.h sample_gamma(), +// which originates from NumPy's random module (Copyright 2005 Robert Kern). +// Each thread uses a per-thread RNG offset stride to allow variable-length +// rejection loops without colliding with other threads' random streams. +constant constexpr int GAMMA_RANDOMS_STRIDE = 32; + +template +kernel void standard_gamma( + device T* output [[buffer(0)]], + device const T* alpha_in [[buffer(1)]], + constant long2& seed_base_offset [[buffer(2)]], + uint tid [[thread_position_in_grid]]) { + float alpha = static_cast(alpha_in[tid]); + float scale = 1.0f; + long base = + seed_base_offset.y + static_cast(tid) * GAMMA_RANDOMS_STRIDE; + long seed = seed_base_offset.x; + int rng_idx = 0; + + // Boost alpha < 1 for higher acceptance probability + if (alpha < 1.0f) { + if (alpha == 0.0f) { + output[tid] = static_cast(0.0f); + return; + } + float u = c10::metal::rand(seed, base + rng_idx++); + scale = ::metal::precise::pow(1.0f - u, 1.0f / alpha); + alpha += 1.0f; + } + + // Marsaglia & Tsang acceptance-rejection + float d = alpha - 1.0f / 3.0f; + float c = 1.0f / ::metal::precise::sqrt(9.0f * d); + for (;;) { + float x, y; + do { + x = c10::metal::randn(seed, base + rng_idx++); + y = 1.0f + c * x; + } while (y <= 0.0f); + float v = y * y * y; + float u = 1.0f - c10::metal::rand(seed, base + rng_idx++); + float xx = x * x; + if (u < 1.0f - 0.0331f * xx * xx) { + float result = scale * d * v; + output[tid] = static_cast(max(result, FLT_MIN)); + return; + } + if (::metal::precise::log(u) < + 0.5f * xx + d * (1.0f - v + ::metal::precise::log(v))) { + float result = scale * d * v; + output[tid] = static_cast(max(result, FLT_MIN)); + return; + } + } +} + +#define REGISTER_GAMMA(DTYPE) \ + template [[host_name("standard_gamma_" #DTYPE)]] \ + kernel void standard_gamma( \ + device DTYPE*, device const DTYPE*, constant long2&, uint) + +REGISTER_GAMMA(float); +REGISTER_GAMMA(half); +REGISTER_GAMMA(bfloat); + +// Reparameterized gradient for Gamma distribution. +// Computes -(d/dalpha cdf(x;alpha)) / pdf(x;alpha). +// Adapted from aten/src/ATen/native/Distributions.h standard_gamma_grad_one(). + +constant constexpr float GAMMA_GRAD_COEF_UV[3][8] = { + {0.16009398f, + -0.094634809f, + 0.025146376f, + -0.0030648343f, + 1.0f, + 0.32668115f, + 0.10406089f, + 0.0014179084f}, + {0.53487893f, + 0.1298071f, + 0.065735949f, + -0.0015649758f, + 0.16639465f, + 0.020070113f, + -0.0035938915f, + -0.00058392623f}, + {0.040121004f, + -0.0065914022f, + -0.0026286047f, + -0.0013441777f, + 0.017050642f, + -0.0021309326f, + 0.00085092367f, + -1.5247877e-07f}, +}; + +template +kernel void standard_gamma_grad( + device T* output [[buffer(0)]], + device const T* self_data [[buffer(1)]], + device const T* output_data [[buffer(2)]], + uint tid [[thread_position_in_grid]]) { + float alpha = static_cast(self_data[tid]); + float x = static_cast(output_data[tid]); + + // Region 1: Small x - Taylor series expansion + if (x < 0.8f) { + float numer = 1.0f; + float denom = alpha; + float series1 = numer / denom; + float series2 = numer / (denom * denom); + for (int i = 1; i <= 5; ++i) { + numer *= -x / static_cast(i); + denom += 1.0f; + series1 += numer / denom; + series2 += numer / (denom * denom); + } + float pow_x_alpha = ::metal::precise::pow(x, alpha); + float gamma_pdf = + ::metal::precise::pow(x, alpha - 1.0f) * ::metal::precise::exp(-x); + float gamma_cdf = pow_x_alpha * series1; + float gamma_cdf_alpha = + (::metal::precise::log(x) - c10::metal::digamma(alpha)) * gamma_cdf - + pow_x_alpha * series2; + float result = -gamma_cdf_alpha / gamma_pdf; + output[tid] = static_cast(isnan(result) ? 0.0f : result); + return; + } + + // Region 2: Large alpha - Rice saddle point expansion + if (alpha > 8.0f) { + if (0.9f * alpha <= x && x <= 1.1f * alpha) { + float numer_1 = 1.0f + 24.0f * alpha * (1.0f + 12.0f * alpha); + float numer_2 = 1440.0f * (alpha * alpha) + + 6.0f * x * (53.0f - 120.0f * x) - 65.0f * x * x / alpha + + alpha * (107.0f + 3600.0f * x); + float denom = 1244160.0f * (alpha * alpha) * (alpha * alpha); + output[tid] = static_cast(numer_1 * numer_2 / denom); + return; + } + float denom = ::metal::precise::sqrt(8.0f * alpha); + float term2 = denom / (alpha - x); + float term3 = ::metal::precise::pow( + x - alpha - alpha * ::metal::precise::log(x / alpha), -1.5f); + float term23 = (x < alpha) ? term2 - term3 : term2 + term3; + float term1 = ::metal::precise::log(x / alpha) * term23 - + ::metal::precise::sqrt(2.0f / alpha) * (alpha + x) / + ((alpha - x) * (alpha - x)); + float stirling = + 1.0f + 1.0f / (12.0f * alpha) * (1.0f + 1.0f / (24.0f * alpha)); + float numer = x * term1; + output[tid] = static_cast(-stirling * numer / denom); + return; + } + + // Region 3: Moderate alpha - bivariate rational approximation + float u = ::metal::precise::log(x / alpha); + float v = ::metal::precise::log(alpha); + float coef_v[8]; + for (int i = 0; i < 8; ++i) { + coef_v[i] = GAMMA_GRAD_COEF_UV[0][i] + + u * (GAMMA_GRAD_COEF_UV[1][i] + u * GAMMA_GRAD_COEF_UV[2][i]); + } + float p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3])); + float q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7])); + output[tid] = static_cast(::metal::precise::exp(p / q)); +} + +#define REGISTER_GAMMA_GRAD(DTYPE) \ + template [[host_name("standard_gamma_grad_" #DTYPE)]] \ + kernel void standard_gamma_grad( \ + device DTYPE*, device const DTYPE*, device const DTYPE*, uint) + +REGISTER_GAMMA_GRAD(float); +REGISTER_GAMMA_GRAD(half); +REGISTER_GAMMA_GRAD(bfloat); diff --git a/aten/src/ATen/native/mps/kernels/Eye.metal b/aten/src/ATen/native/mps/kernels/Eye.metal new file mode 100644 index 0000000000000..ea389866241c4 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/Eye.metal @@ -0,0 +1,59 @@ +#include +using namespace metal; + +// For complex types (float2/half2), "1" is (1, 0), not (1, 1) +template +constexpr T eye_one() { + return static_cast(1); +} +template <> +constexpr float2 eye_one() { + return float2(1.0, 0.0); +} +template <> +constexpr half2 eye_one() { + return half2(half(1.0), half(0.0)); +} + +// Single-pass: writes both 0s and 1s in one dispatch (better for small tensors) +template +kernel void eye( + device T* output [[buffer(0)]], + constant long& stride0 [[buffer(1)]], + constant long& stride1 [[buffer(2)]], + uint2 pos [[thread_position_in_grid]]) { + output[pos.y * stride0 + pos.x * stride1] = + (pos.x == pos.y) ? eye_one() : static_cast(0); +} + +// Diagonal-only: writes 1s to pre-zeroed tensor (better for large tensors) +template +kernel void eye_diag( + device T* output [[buffer(0)]], + constant long& diag_stride [[buffer(1)]], + uint index [[thread_position_in_grid]]) { + output[index * diag_stride] = eye_one(); +} + +#define REGISTER_EYE_OP(DTYPE) \ + template [[host_name("eye_" #DTYPE)]] kernel void eye( \ + device DTYPE * output [[buffer(0)]], \ + constant long& stride0 [[buffer(1)]], \ + constant long& stride1 [[buffer(2)]], \ + uint2 pos [[thread_position_in_grid]]); \ + template [[host_name("eye_diag_" #DTYPE)]] kernel void eye_diag( \ + device DTYPE * output [[buffer(0)]], \ + constant long& diag_stride [[buffer(1)]], \ + uint index [[thread_position_in_grid]]); + +REGISTER_EYE_OP(float); +REGISTER_EYE_OP(half); +REGISTER_EYE_OP(bfloat); +REGISTER_EYE_OP(float2); +REGISTER_EYE_OP(half2); +REGISTER_EYE_OP(long); +REGISTER_EYE_OP(int); +REGISTER_EYE_OP(short); +REGISTER_EYE_OP(char); +REGISTER_EYE_OP(uchar); +REGISTER_EYE_OP(bool); diff --git a/aten/src/ATen/native/mps/kernels/GridSampler.h b/aten/src/ATen/native/mps/kernels/GridSampler.h index c9d4112508613..288c736927cff 100644 --- a/aten/src/ATen/native/mps/kernels/GridSampler.h +++ b/aten/src/ATen/native/mps/kernels/GridSampler.h @@ -12,3 +12,27 @@ struct GridSamplerParams { ::c10::metal::array grid_strides; bool align_corners; }; + +template +struct GridSamplerBackwardParams { + GridSamplerParams forward; + ::c10::metal::array grad_output_strides; + ::c10::metal::array grad_input_strides; + idx_type_t grad_grid_sW; + int32_t padding_mode; +}; + +struct GridSampler3DBackwardParams { + int32_t interpolation_mode; + int32_t padding_mode; + bool align_corners; + bool compute_grad_input; + bool compute_grad_grid; + ::c10::metal::array input_sizes; + ::c10::metal::array output_sizes; + ::c10::metal::array input_strides; + ::c10::metal::array grad_input_strides; + ::c10::metal::array grad_grid_strides; + ::c10::metal::array grid_strides; + ::c10::metal::array grad_output_strides; +}; diff --git a/aten/src/ATen/native/mps/kernels/GridSampler.metal b/aten/src/ATen/native/mps/kernels/GridSampler.metal index 0d8e6aab87e9e..a18593fd0dcd8 100644 --- a/aten/src/ATen/native/mps/kernels/GridSampler.metal +++ b/aten/src/ATen/native/mps/kernels/GridSampler.metal @@ -1,4 +1,6 @@ #include +#include +#include #include #include #include @@ -6,59 +8,6 @@ using namespace metal; using namespace c10::metal; -struct GridSamplerOffsets { - int32_t output; - int32_t input; - int32_t grid; - - GridSamplerOffsets() : output(0), input(0), grid(0) {} -}; - -// Find offsets into the tensors that this thread will operate on, -// based on the thread ID. -static GridSamplerOffsets find_grid_sampler_offsets( - constant int32_t* output_sizes, - constant int32_t* output_strides, - constant int32_t* input_strides, - constant int32_t* grid_strides, - int32_t sampler_dims, - uint tid) { - auto dims = sampler_dims + 2; - auto output_idx = static_cast(tid); - GridSamplerOffsets offsets; - - for (auto dim = dims - 1; dim >= 0; dim--) { - auto dim_idx = output_idx % output_sizes[dim]; - output_idx = output_idx / output_sizes[dim]; - - // Select the output element that this thread will calculate. - // output shape: - // 2 sampler dims: (N, C, Hout, Wout) - // 3 sampler dims: (N, C, Dout, Hout, Wout) - offsets.output += output_strides[dim] * dim_idx; - - // Select the batch and channel for the input. - // input shape: - // 2 sampler dims: (N, C, Hin, Win) - // 3 sampler dims: (N, C, Din, Hin, Win) - if (dim < 2) { - offsets.input += input_strides[dim] * dim_idx; - } - - // Select the grid coordinates for the output element. - // grid shape: - // 2 sampler dims: (N, Hout, Wout, 2) - // 3 sampler dims: (N, Dout, Hout, Wout, 3) - if (dim == 0) { - offsets.grid += grid_strides[dim] * dim_idx; - } else if (dim >= 2) { - offsets.grid += grid_strides[dim - 1] * dim_idx; - } - } - - return offsets; -} - // Mod function which gives positive output when `a` is negative static int32_t mod(int32_t a, int32_t b) { auto r = a % b; @@ -80,23 +29,22 @@ static float grid_sampler_unnormalize( } } -// Clip coordinates for border padding -static float clip_coordinates(float in, int32_t clip_limit) { - return ::metal::clamp(in, 0.0, clip_limit - 1.0); +// Clip coordinates to [0, max_val] +static float clip_coordinates(float in, float max_val) { + return ::metal::clamp(in, 0.0f, max_val); } // Reflect coordinates for reflection padding template -static T reflect_coordinates(T in, int32_t twice_low, int32_t twice_high) { - if (twice_low == twice_high) { +static T reflect_coordinates(T in, T low, T high) { + if (low == high) { return 0; } - auto min_val = static_cast(twice_low) / 2; - auto span = static_cast(twice_high - twice_low) / 2; - in = fabs(in - min_val); + auto span = high - low; + in = fabs(in - low); auto extra = fmod(in, span); int32_t flips = static_cast(floor(in / span)); - return (flips % 2 == 0) ? (extra + min_val) : (span - extra + min_val); + return (flips % 2 == 0) ? (extra + low) : (span - extra + low); } // Padding functors: each encapsulates the padding logic for integer indices @@ -108,6 +56,10 @@ struct PadZeros { return (idx < 0 || idx >= input_size) ? IDX_ZERO : idx; } + static float apply_padding(float coord, int32_t, bool) { + return coord; + } + static float compute_source(float coord, int32_t size, bool align_corners) { return grid_sampler_unnormalize(coord, size, align_corners); } @@ -120,9 +72,15 @@ struct PadBorder { return clamp(idx, 0, input_size - 1); } + static float apply_padding(float coord, int32_t size, bool) { + return clip_coordinates(coord, size - 1.0f); + } + static float compute_source(float coord, int32_t size, bool align_corners) { - coord = grid_sampler_unnormalize(coord, size, align_corners); - return clip_coordinates(coord, size); + return apply_padding( + grid_sampler_unnormalize(coord, size, align_corners), + size, + align_corners); } }; @@ -137,46 +95,22 @@ struct PadReflection { return is_reverse ? idx_mod_reverse : idx_mod; } - static float compute_source(float coord, int32_t size, bool align_corners) { - coord = grid_sampler_unnormalize(coord, size, align_corners); + static float apply_padding(float coord, int32_t size, bool align_corners) { if (align_corners) { - coord = reflect_coordinates(coord, 0, 2 * (size - 1)); + coord = reflect_coordinates(coord, 0.0f, static_cast(size - 1)); } else { - coord = reflect_coordinates(coord, -1, 2 * size - 1); + coord = reflect_coordinates(coord, -0.5f, size - 0.5f); } - return clip_coordinates(coord, size); + return clip_coordinates(coord, size - 1.0f); } -}; - -// Cubic convolution helper 1: for |x| < 1 -template -static T cubic_convolution1(T x, T A) { - return ((A + 2) * x - (A + 3)) * x * x + 1; -} -// Cubic convolution helper 2: for 1 <= |x| < 2 -template -static T cubic_convolution2(T x, T A) { - return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; -} - -// Get cubic upsampling coefficients (Catmull-Rom spline with A=-0.75) -template -static void get_cubic_coefficients(T coeffs[4], T t) { - T A = static_cast(-0.75); - coeffs[0] = cubic_convolution2(t + 1, A); - coeffs[1] = cubic_convolution1(t, A); - coeffs[2] = cubic_convolution1(1 - t, A); - coeffs[3] = cubic_convolution2(2 - t, A); -} - -// 1D cubic interpolation -template -static T cubic_interp1d(T x0, T x1, T x2, T x3, T t) { - T coeffs[4]; - get_cubic_coefficients(coeffs, t); - return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; -} + static float compute_source(float coord, int32_t size, bool align_corners) { + return apply_padding( + grid_sampler_unnormalize(coord, size, align_corners), + size, + align_corners); + } +}; // 2D Bilinear interpolation template @@ -448,159 +382,534 @@ kernel void grid_sampler_2d( } } -template -T get_tensor_val( +// 3D trilinear interpolation matching the 2D bilinear pattern. +// Takes pre-read grid coordinates as values, returns the interpolated value. +template +static T interpolate_trilinear_3d( constant T* input, - constant int32_t* input_strides, - int32_t indices[dims]) { - bool found_idx_zero = false; - int32_t offset = 0; + opmath_t ix, + opmath_t iy, + opmath_t iz, + int32_t inp_D, + int32_t inp_H, + int32_t inp_W, + int32_t inp_sD, + int32_t inp_sH, + int32_t inp_sW, + bool align_corners) { + ix = grid_sampler_unnormalize(ix, inp_W, align_corners); + iy = grid_sampler_unnormalize(iy, inp_H, align_corners); + iz = grid_sampler_unnormalize(iz, inp_D, align_corners); + + int32_t ix_l = static_cast(floor(ix)); + int32_t iy_l = static_cast(floor(iy)); + int32_t iz_l = static_cast(floor(iz)); + int32_t ix_r = ix_l + 1; + int32_t iy_r = iy_l + 1; + int32_t iz_r = iz_l + 1; + + opmath_t sx = ix - ix_l; + opmath_t sy = iy - iy_l; + opmath_t sz = iz - iz_l; + + int32_t ix_l_p = Pad::pad(ix_l, inp_W, align_corners); + int32_t ix_r_p = Pad::pad(ix_r, inp_W, align_corners); + int32_t iy_l_p = Pad::pad(iy_l, inp_H, align_corners); + int32_t iy_r_p = Pad::pad(iy_r, inp_H, align_corners); + int32_t iz_l_p = Pad::pad(iz_l, inp_D, align_corners); + int32_t iz_r_p = Pad::pad(iz_r, inp_D, align_corners); - for (auto dim = 0; dim < dims; dim++) { - auto idx = indices[dim]; - found_idx_zero = found_idx_zero || (idx == IDX_ZERO); - offset += (found_idx_zero ? 0 : idx) * input_strides[dim]; + opmath_t out_acc = 0; + if (!Pad::checks_bounds || + (iz_l_p != IDX_ZERO && iy_l_p != IDX_ZERO && ix_l_p != IDX_ZERO)) { + out_acc += input[iz_l_p * inp_sD + iy_l_p * inp_sH + ix_l_p * inp_sW] * + (1 - sz) * (1 - sy) * (1 - sx); + } + if (!Pad::checks_bounds || + (iz_l_p != IDX_ZERO && iy_l_p != IDX_ZERO && ix_r_p != IDX_ZERO)) { + out_acc += input[iz_l_p * inp_sD + iy_l_p * inp_sH + ix_r_p * inp_sW] * + (1 - sz) * (1 - sy) * sx; + } + if (!Pad::checks_bounds || + (iz_l_p != IDX_ZERO && iy_r_p != IDX_ZERO && ix_l_p != IDX_ZERO)) { + out_acc += input[iz_l_p * inp_sD + iy_r_p * inp_sH + ix_l_p * inp_sW] * + (1 - sz) * sy * (1 - sx); + } + if (!Pad::checks_bounds || + (iz_l_p != IDX_ZERO && iy_r_p != IDX_ZERO && ix_r_p != IDX_ZERO)) { + out_acc += input[iz_l_p * inp_sD + iy_r_p * inp_sH + ix_r_p * inp_sW] * + (1 - sz) * sy * sx; + } + if (!Pad::checks_bounds || + (iz_r_p != IDX_ZERO && iy_l_p != IDX_ZERO && ix_l_p != IDX_ZERO)) { + out_acc += input[iz_r_p * inp_sD + iy_l_p * inp_sH + ix_l_p * inp_sW] * sz * + (1 - sy) * (1 - sx); + } + if (!Pad::checks_bounds || + (iz_r_p != IDX_ZERO && iy_l_p != IDX_ZERO && ix_r_p != IDX_ZERO)) { + out_acc += input[iz_r_p * inp_sD + iy_l_p * inp_sH + ix_r_p * inp_sW] * sz * + (1 - sy) * sx; + } + if (!Pad::checks_bounds || + (iz_r_p != IDX_ZERO && iy_r_p != IDX_ZERO && ix_l_p != IDX_ZERO)) { + out_acc += input[iz_r_p * inp_sD + iy_r_p * inp_sH + ix_l_p * inp_sW] * sz * + sy * (1 - sx); + } + if (!Pad::checks_bounds || + (iz_r_p != IDX_ZERO && iy_r_p != IDX_ZERO && ix_r_p != IDX_ZERO)) { + out_acc += input[iz_r_p * inp_sD + iy_r_p * inp_sH + ix_r_p * inp_sW] * sz * + sy * sx; } - return found_idx_zero ? 0 : input[offset]; + return static_cast(out_acc); } -// This function performs 3D linear interpolation for one value. One way to -// think of how this works is to imagine a unit cube where each corner of the -// cube has one scalar value associated with it. Inside the cube, the values -// change linearly, so the gradient is constant. The values associated with each -// corner are given by the `input`, indexed at all eight different combinations -// of the `left_indices` and `right_indices`. Given a 3D coordinate anywhere -// within the cube, specified by the `scales` argument, we must calculate the -// value associated with that position. -template -T interpolate_linear_3d( - constant T* input, - constant int32_t* input_strides, - int32_t left_indices[3], - int32_t right_indices[3], - opmath_t scales[3]) { - int32_t a_idx[3] = {left_indices[0], left_indices[1], left_indices[2]}; - int32_t b_idx[3] = {left_indices[0], left_indices[1], right_indices[2]}; - int32_t c_idx[3] = {left_indices[0], right_indices[1], left_indices[2]}; - int32_t d_idx[3] = {left_indices[0], right_indices[1], right_indices[2]}; - int32_t e_idx[3] = {right_indices[0], left_indices[1], left_indices[2]}; - int32_t f_idx[3] = {right_indices[0], left_indices[1], right_indices[2]}; - int32_t g_idx[3] = {right_indices[0], right_indices[1], left_indices[2]}; - int32_t h_idx[3] = {right_indices[0], right_indices[1], right_indices[2]}; - auto a = - static_cast>(get_tensor_val<3>(input, input_strides, a_idx)); - auto b = - static_cast>(get_tensor_val<3>(input, input_strides, b_idx)); - auto c = - static_cast>(get_tensor_val<3>(input, input_strides, c_idx)); - auto d = - static_cast>(get_tensor_val<3>(input, input_strides, d_idx)); - auto e = - static_cast>(get_tensor_val<3>(input, input_strides, e_idx)); - auto f = - static_cast>(get_tensor_val<3>(input, input_strides, f_idx)); - auto g = - static_cast>(get_tensor_val<3>(input, input_strides, g_idx)); - auto h = - static_cast>(get_tensor_val<3>(input, input_strides, h_idx)); - - auto scale0_right = scales[0]; - auto scale1_right = scales[1]; - auto scale2_right = scales[2]; - auto scale0_left = 1 - scale0_right; - auto scale1_left = 1 - scale1_right; - auto scale2_left = 1 - scale2_right; - - return static_cast( - scale0_left * scale1_left * scale2_left * a + - scale0_left * scale1_left * scale2_right * b + - scale0_left * scale1_right * scale2_left * c + - scale0_left * scale1_right * scale2_right * d + - scale0_right * scale1_left * scale2_left * e + - scale0_right * scale1_left * scale2_right * f + - scale0_right * scale1_right * scale2_left * g + - scale0_right * scale1_right * scale2_right * h); -} - -// 3D bilinear sampling for a single output element. +// 3D nearest neighbor interpolation matching the 2D nearest pattern. template -void grid_sampler_3d_single_element( - device T* output, +static T interpolate_nearest_3d( constant T* input, - constant T* coords, - int32_t dims, - constant int32_t* input_sizes, - constant int32_t* input_strides, - int32_t coord_stride, + opmath_t ix, + opmath_t iy, + opmath_t iz, + int32_t inp_D, + int32_t inp_H, + int32_t inp_W, + int32_t inp_sD, + int32_t inp_sH, + int32_t inp_sW, bool align_corners) { - int32_t left_indices[3]; - int32_t right_indices[3]; - opmath_t scales[3]; - - for (auto coord_dim = 0; coord_dim < dims; coord_dim++) { - auto input_dim = dims - coord_dim - 1; - auto input_size = input_sizes[input_dim]; - auto coord = static_cast>(coords[coord_dim * coord_stride]); - - if (!align_corners) { - auto corner_alignment_factor = static_cast>(input_size) / - static_cast>(input_size - 1); - coord = coord * corner_alignment_factor; - } + ix = Pad::compute_source(ix, inp_W, align_corners); + iy = Pad::compute_source(iy, inp_H, align_corners); + iz = Pad::compute_source(iz, inp_D, align_corners); - coord = (coord + 1) * (static_cast>(input_size - 1) / 2); + int32_t ix_nearest = static_cast(rint(ix)); + int32_t iy_nearest = static_cast(rint(iy)); + int32_t iz_nearest = static_cast(rint(iz)); - auto left_idx = static_cast(floor(coord)); - auto right_idx = static_cast(ceil(coord)); - left_indices[input_dim] = Pad::pad(left_idx, input_size, align_corners); - right_indices[input_dim] = Pad::pad(right_idx, input_size, align_corners); - scales[input_dim] = coord - left_idx; + if (Pad::checks_bounds) { + if (ix_nearest < 0 || ix_nearest >= inp_W || iy_nearest < 0 || + iy_nearest >= inp_H || iz_nearest < 0 || iz_nearest >= inp_D) { + return static_cast(0); + } } - *output = interpolate_linear_3d( - input, input_strides, left_indices, right_indices, scales); + return input[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; } -template +// Interpolation strategies for 3D kernel (matching 2D pattern). +template +struct Bilinear3D { + template + static T interpolate( + constant T* input, + opmath_t ix, + opmath_t iy, + opmath_t iz, + int32_t inp_D, + int32_t inp_H, + int32_t inp_W, + int32_t inp_sD, + int32_t inp_sH, + int32_t inp_sW, + bool align_corners) { + return interpolate_trilinear_3d( + input, + ix, + iy, + iz, + inp_D, + inp_H, + inp_W, + inp_sD, + inp_sH, + inp_sW, + align_corners); + } +}; + +template +struct Nearest3D { + template + static T interpolate( + constant T* input, + opmath_t ix, + opmath_t iy, + opmath_t iz, + int32_t inp_D, + int32_t inp_H, + int32_t inp_W, + int32_t inp_sD, + int32_t inp_sH, + int32_t inp_sW, + bool align_corners) { + return interpolate_nearest_3d( + input, + ix, + iy, + iz, + inp_D, + inp_H, + inp_W, + inp_sD, + inp_sH, + inp_sW, + align_corners); + } +}; + +// 3D grid sampler kernel: one thread per output element (n, c, d, h, w). +template kernel void grid_sampler_3d( device T* output [[buffer(0)]], constant T* input [[buffer(1)]], constant T* grid [[buffer(2)]], constant GridSamplerParams<5>& params [[buffer(3)]], uint tid [[thread_position_in_grid]]) { - auto output_sizes = params.output_sizes.data(); - auto output_strides = params.output_strides.data(); - auto input_sizes = params.input_sizes.data(); - auto input_strides = params.input_strides.data(); - auto grid_strides = params.grid_strides.data(); - auto sampler_dims = params.sampler_dims; - - auto offsets = find_grid_sampler_offsets( - output_sizes, - output_strides, - input_strides, - grid_strides, - sampler_dims, - tid); - - output += offsets.output; - input += offsets.input; - auto coords = grid + offsets.grid; - - input_sizes += 2; - input_strides += 2; - auto coord_stride = grid_strides[sampler_dims + 1]; - - grid_sampler_3d_single_element( - output, - input, - coords, - sampler_dims, - input_sizes, - input_strides, - coord_stride, - params.align_corners); + auto C = params.output_sizes[1]; + auto out_D = params.output_sizes[2]; + auto out_H = params.output_sizes[3]; + auto out_W = params.output_sizes[4]; + + auto out_sN = params.output_strides[0]; + auto out_sC = params.output_strides[1]; + auto out_sD = params.output_strides[2]; + auto out_sH = params.output_strides[3]; + auto out_sW = params.output_strides[4]; + auto inp_sN = params.input_strides[0]; + auto inp_sC = params.input_strides[1]; + auto inp_sD = params.input_strides[2]; + auto inp_sH = params.input_strides[3]; + auto inp_sW = params.input_strides[4]; + auto inp_D = params.input_sizes[2]; + auto inp_H = params.input_sizes[3]; + auto inp_W = params.input_sizes[4]; + + auto grid_sN = params.grid_strides[0]; + auto grid_sD = params.grid_strides[1]; + auto grid_sH = params.grid_strides[2]; + auto grid_sW = params.grid_strides[3]; + auto grid_sCoor = params.grid_strides[4]; + + auto align_corners = params.align_corners; + + int32_t w = tid % out_W; + int32_t h = (tid / out_W) % out_H; + int32_t d = (tid / (out_W * out_H)) % out_D; + int32_t c = (tid / (out_W * out_H * out_D)) % C; + int32_t n = tid / (out_W * out_H * out_D * C); + + auto grid_ptr = grid + n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + opmath_t ix = static_cast>(grid_ptr[0]); + opmath_t iy = static_cast>(grid_ptr[grid_sCoor]); + opmath_t iz = static_cast>(grid_ptr[2 * grid_sCoor]); + + auto inp_ptr_NC = input + n * inp_sN + c * inp_sC; + auto result = Interp::template interpolate( + inp_ptr_NC, + ix, + iy, + iz, + inp_D, + inp_H, + inp_W, + inp_sD, + inp_sH, + inp_sW, + align_corners); + output[n * out_sN + c * out_sC + d * out_sD + h * out_sH + w * out_sW] = + result; +} + +// Padding mode constants (must match GridSamplerPadding enum) +constant int32_t kPaddingZeros = 0; +constant int32_t kPaddingBorder = 1; +constant int32_t kPaddingReflection = 2; + +// Uses opmath_t for intermediate computations to avoid overflow with +// half/bfloat +template +T grid_sampler_compute_source_index_set_grad( + T coord, + int32_t size, + int32_t padding_mode, + bool align_corners, + thread T* grad_in) { + using U = opmath_t; + U u_coord = static_cast(coord); + U u_grad_in = static_cast(*grad_in); + U u_size = static_cast(size); + + // Unnormalize + if (align_corners) { + u_coord = ((u_coord + U(1.0)) / U(2.0)) * (u_size - U(1.0)); + u_grad_in = (u_size - U(1.0)) / U(2.0); + } else { + u_coord = ((u_coord + U(1.0)) * u_size - U(1.0)) / U(2.0); + u_grad_in = u_size / U(2.0); + } + + if (padding_mode == kPaddingBorder) { + // Borders are considered out of bounds for gradient calculation + // (matching CUDA clip_coordinates_set_grad behavior). + U grad_clip = U(1.0); + if (u_coord <= U(0.0)) { + u_coord = U(0.0); + grad_clip = U(0.0); + } else if (u_coord >= u_size - U(1.0)) { + u_coord = u_size - U(1.0); + grad_clip = U(0.0); + } + u_grad_in = u_grad_in * grad_clip; + } else if (padding_mode == kPaddingReflection) { + U grad_refl = U(1.0); + U twice_low, twice_high; + if (align_corners) { + twice_low = U(0.0); + twice_high = U(2 * (size - 1)); + } else { + twice_low = U(-1.0); + twice_high = U(2 * size - 1); + } + + if (twice_low != twice_high) { + U min_val = twice_low / U(2.0); + U span = (twice_high - twice_low) / U(2.0); + u_coord = u_coord - min_val; + + if (u_coord < U(0.0)) { + u_coord = -u_coord; + grad_refl = -grad_refl; + } + + U extra = u_coord - span * floor(u_coord / span); + int32_t flips = static_cast(floor(u_coord / span)); + + if (flips % 2 == 0) { + u_coord = extra + min_val; + } else { + u_coord = span - extra + min_val; + grad_refl = -grad_refl; + } + } else { + u_coord = U(0.0); + } + + // Clip after reflection (borders out of bounds for gradient) + U grad_clip = U(1.0); + if (u_coord <= U(0.0)) { + u_coord = U(0.0); + grad_clip = U(0.0); + } else if (u_coord >= u_size - U(1.0)) { + u_coord = u_size - U(1.0); + grad_clip = U(0.0); + } + u_grad_in = u_grad_in * grad_refl * grad_clip; + } + + coord = static_cast(u_coord); + *grad_in = static_cast(u_grad_in); + return coord; +} + +inline bool within_bounds_3d( + int32_t z, + int32_t y, + int32_t x, + int32_t D, + int32_t H, + int32_t W) { + return z >= 0 && z < D && y >= 0 && y < H && x >= 0 && x < W; +} + +template +kernel void grid_sampler_3d_backward( + constant T* grad_output [[buffer(0)]], + constant T* input [[buffer(1)]], + constant T* grid [[buffer(2)]], + device AtomicType_t* grad_input [[buffer(3)]], + device T* grad_grid [[buffer(4)]], + constant GridSampler3DBackwardParams& params [[buffer(5)]], + uint3 thread_index [[thread_position_in_grid]]) { + const auto out_w = thread_index.x; + const auto out_d_h_combined = thread_index.y; + const auto n = thread_index.z; + + const auto out_d = out_d_h_combined / params.output_sizes[3]; + const auto out_h = out_d_h_combined % params.output_sizes[3]; + + if (n >= params.input_sizes[0] || out_d >= params.output_sizes[2] || + out_h >= params.output_sizes[3] || out_w >= params.output_sizes[4]) { + return; + } + + const auto C = params.input_sizes[1]; + const auto inp_D = params.input_sizes[2]; + const auto inp_H = params.input_sizes[3]; + const auto inp_W = params.input_sizes[4]; + + const auto grid_offset = n * params.grid_strides[0] + + out_d * params.grid_strides[1] + out_h * params.grid_strides[2] + + out_w * params.grid_strides[3]; + + const opmath_t grid_x = grid[grid_offset]; + const opmath_t grid_y = grid[grid_offset + params.grid_strides[4]]; + const opmath_t grid_z = grid[grid_offset + 2 * params.grid_strides[4]]; + + opmath_t gix_mult, giy_mult, giz_mult; + opmath_t ix = grid_sampler_compute_source_index_set_grad( + grid_x, + static_cast(inp_W), + params.padding_mode, + params.align_corners, + &gix_mult); + opmath_t iy = grid_sampler_compute_source_index_set_grad( + grid_y, + static_cast(inp_H), + params.padding_mode, + params.align_corners, + &giy_mult); + opmath_t iz = grid_sampler_compute_source_index_set_grad( + grid_z, + static_cast(inp_D), + params.padding_mode, + params.align_corners, + &giz_mult); + + if (params.interpolation_mode == 0) { // trilinear + const int ix_0 = static_cast(floor(ix)); + const int iy_0 = static_cast(floor(iy)); + const int iz_0 = static_cast(floor(iz)); + const opmath_t dx = ix - ix_0; + const opmath_t dy = iy - iy_0; + const opmath_t dz = iz - iz_0; + const opmath_t wx[2] = {1 - dx, dx}; + const opmath_t wy[2] = {1 - dy, dy}; + const opmath_t wz[2] = {1 - dz, dz}; + + opmath_t gix = 0, giy = 0, giz = 0; + + for (uint32_t c = 0; c < C; c++) { + const auto grad_out_offset = n * params.grad_output_strides[0] + + c * params.grad_output_strides[1] + + out_d * params.grad_output_strides[2] + + out_h * params.grad_output_strides[3] + + out_w * params.grad_output_strides[4]; + const opmath_t gOut = grad_output[grad_out_offset]; + const auto base_grad_input_offset = + n * params.grad_input_strides[0] + c * params.grad_input_strides[1]; + const auto input_base_offset = + n * params.input_strides[0] + c * params.input_strides[1]; + + for (int i = 0; i < 8; i++) { + const int xi = i & 1; + const int yi = (i >> 1) & 1; + const int zi = (i >> 2) & 1; + const int cx = ix_0 + xi; + const int cy = iy_0 + yi; + const int cz = iz_0 + zi; + + if (within_bounds_3d( + cz, + cy, + cx, + static_cast(inp_D), + static_cast(inp_H), + static_cast(inp_W))) { + const opmath_t w = wx[xi] * wy[yi] * wz[zi]; + + if (params.compute_grad_input) { + AtomicType::atomic_add( + grad_input, + base_grad_input_offset + cz * params.grad_input_strides[2] + + cy * params.grad_input_strides[3] + + cx * params.grad_input_strides[4], + static_cast(w * gOut)); + } + + if (params.compute_grad_grid) { + const opmath_t val = input + [input_base_offset + cz * params.input_strides[2] + + cy * params.input_strides[3] + cx * params.input_strides[4]]; + const opmath_t sign_x = xi ? 1 : -1; + const opmath_t sign_y = yi ? 1 : -1; + const opmath_t sign_z = zi ? 1 : -1; + gix += sign_x * val * wy[yi] * wz[zi] * gOut; + giy += sign_y * val * wx[xi] * wz[zi] * gOut; + giz += sign_z * val * wx[xi] * wy[yi] * gOut; + } + } + } + } + + if (params.compute_grad_grid) { + const auto grad_grid_base_offset = n * params.grad_grid_strides[0] + + out_d * params.grad_grid_strides[1] + + out_h * params.grad_grid_strides[2] + + out_w * params.grad_grid_strides[3]; + grad_grid[grad_grid_base_offset] = static_cast(gix_mult * gix); + grad_grid[grad_grid_base_offset + params.grid_strides[4]] = + static_cast(giy_mult * giy); + grad_grid[grad_grid_base_offset + 2 * params.grid_strides[4]] = + static_cast(giz_mult * giz); + } + } else if (params.compute_grad_input) { // nearest + int32_t ix_n = static_cast(rint(ix)); + int32_t iy_n = static_cast(rint(iy)); + int32_t iz_n = static_cast(rint(iz)); + + if (params.padding_mode == kPaddingBorder) { + ix_n = clamp(ix_n, 0, static_cast(inp_W - 1)); + iy_n = clamp(iy_n, 0, static_cast(inp_H - 1)); + iz_n = clamp(iz_n, 0, static_cast(inp_D - 1)); + } else if (params.padding_mode == kPaddingReflection) { + if (params.align_corners) { + ix_n = static_cast(rint(reflect_coordinates( + static_cast(ix_n), 0.0f, 2.0f * (inp_W - 1)))); + iy_n = static_cast(rint(reflect_coordinates( + static_cast(iy_n), 0.0f, 2.0f * (inp_H - 1)))); + iz_n = static_cast(rint(reflect_coordinates( + static_cast(iz_n), 0.0f, 2.0f * (inp_D - 1)))); + } else { + ix_n = static_cast(rint(reflect_coordinates( + static_cast(ix_n), -1.0f, 2.0f * inp_W - 1))); + iy_n = static_cast(rint(reflect_coordinates( + static_cast(iy_n), -1.0f, 2.0f * inp_H - 1))); + iz_n = static_cast(rint(reflect_coordinates( + static_cast(iz_n), -1.0f, 2.0f * inp_D - 1))); + } + ix_n = clamp(ix_n, 0, static_cast(inp_W - 1)); + iy_n = clamp(iy_n, 0, static_cast(inp_H - 1)); + iz_n = clamp(iz_n, 0, static_cast(inp_D - 1)); + } + + bool in_bounds = params.padding_mode != kPaddingZeros || + within_bounds_3d(iz_n, + iy_n, + ix_n, + static_cast(inp_D), + static_cast(inp_H), + static_cast(inp_W)); + + if (in_bounds) { + const auto base_offset = n * params.grad_input_strides[0] + + iz_n * params.grad_input_strides[2] + + iy_n * params.grad_input_strides[3] + + ix_n * params.grad_input_strides[4]; + + for (uint32_t c = 0; c < C; c++) { + const auto grad_out_offset = n * params.grad_output_strides[0] + + c * params.grad_output_strides[1] + + out_d * params.grad_output_strides[2] + + out_h * params.grad_output_strides[3] + + out_w * params.grad_output_strides[4]; + const opmath_t gOut = grad_output[grad_out_offset]; + AtomicType::atomic_add( + grad_input, + base_offset + c * params.grad_input_strides[1], + static_cast(gOut)); + } + } + } } #define REGISTER_GRID_SAMPLER_2D(DTYPE, INTERP, INAME, PAD, PNAME) \ @@ -617,23 +926,632 @@ kernel void grid_sampler_3d( REGISTER_GRID_SAMPLER_2D(DTYPE, INTERP, INAME, PadBorder, "border") \ REGISTER_GRID_SAMPLER_2D(DTYPE, INTERP, INAME, PadReflection, "reflection") -#define REGISTER_GRID_SAMPLER_3D(DTYPE, PAD, PNAME) \ - template [[host_name("grid_sampler_3d_" PNAME "_" #DTYPE)]] \ - kernel void grid_sampler_3d( \ - device DTYPE * output [[buffer(0)]], \ - constant DTYPE * input [[buffer(1)]], \ - constant DTYPE * grid [[buffer(2)]], \ - constant GridSamplerParams<5> & params [[buffer(3)]], \ +#define REGISTER_GRID_SAMPLER_3D(DTYPE, INTERP, INAME, PAD, PNAME) \ + template [[host_name("grid_sampler_3d_" INAME "_" PNAME "_" #DTYPE)]] \ + kernel void grid_sampler_3d, DTYPE>( \ + device DTYPE * output [[buffer(0)]], \ + constant DTYPE * input [[buffer(1)]], \ + constant DTYPE * grid [[buffer(2)]], \ + constant GridSamplerParams<5> & params [[buffer(3)]], \ uint tid [[thread_position_in_grid]]); +#define REGISTER_GRID_SAMPLER_3D_INTERP(DTYPE, INTERP, INAME) \ + REGISTER_GRID_SAMPLER_3D(DTYPE, INTERP, INAME, PadZeros, "zeros") \ + REGISTER_GRID_SAMPLER_3D(DTYPE, INTERP, INAME, PadBorder, "border") \ + REGISTER_GRID_SAMPLER_3D(DTYPE, INTERP, INAME, PadReflection, "reflection") + +#define REGISTER_GRID_SAMPLER_BACKWARD(DTYPE) \ + template [[host_name("grid_sampler_3d_backward_" #DTYPE)]] \ + kernel void grid_sampler_3d_backward( \ + constant DTYPE * grad_output [[buffer(0)]], \ + constant DTYPE * input [[buffer(1)]], \ + constant DTYPE * grid [[buffer(2)]], \ + device AtomicType_t * grad_input [[buffer(3)]], \ + device DTYPE * grad_grid [[buffer(4)]], \ + constant GridSampler3DBackwardParams & params [[buffer(5)]], \ + uint3 thread_index [[thread_position_in_grid]]); + #define REGISTER_GRID_SAMPLER_OPS(DTYPE) \ REGISTER_GRID_SAMPLER_2D_INTERP(DTYPE, Bilinear2D, "bilinear") \ REGISTER_GRID_SAMPLER_2D_INTERP(DTYPE, Nearest2D, "nearest") \ REGISTER_GRID_SAMPLER_2D_INTERP(DTYPE, Bicubic2D, "bicubic") \ - REGISTER_GRID_SAMPLER_3D(DTYPE, PadZeros, "zeros") \ - REGISTER_GRID_SAMPLER_3D(DTYPE, PadBorder, "border") \ - REGISTER_GRID_SAMPLER_3D(DTYPE, PadReflection, "reflection") + REGISTER_GRID_SAMPLER_3D_INTERP(DTYPE, Bilinear3D, "bilinear") \ + REGISTER_GRID_SAMPLER_3D_INTERP(DTYPE, Nearest3D, "nearest") \ + REGISTER_GRID_SAMPLER_BACKWARD(DTYPE) REGISTER_GRID_SAMPLER_OPS(float); REGISTER_GRID_SAMPLER_OPS(half); REGISTER_GRID_SAMPLER_OPS(bfloat); + +// ========== Backward kernels ========== + +// Each _set_grad function returns float2{coord, grad} where grad is +// d(output_coord)/d(input_coord), used to chain-rule through the +// coordinate transform in the backward pass. + +static float2 grid_sampler_unnormalize_set_grad( + float coord, + int32_t size, + bool align_corners) { + float grad = align_corners ? (size - 1) / 2.0f : size / 2.0f; + return {grid_sampler_unnormalize(coord, size, align_corners), grad}; +} + +static float2 clip_coordinates_set_grad(float in, float max_val) { + if (in <= 0.0f) { + return {0.0f, 0.0f}; + } + if (in >= max_val) { + return {max_val, 0.0f}; + } + return {in, 1.0f}; +} + +static float2 reflect_coordinates_set_grad(float in, float low, float high) { + if (low == high) { + return {0.0f, 0.0f}; + } + int grad_in_mult = 1; + float span = high - low; + in = in - low; + if (in < 0.0f) { + grad_in_mult = -1; + in = -in; + } + float extra = fmod(in, span); + int flips = static_cast(floor(in / span)); + if (flips % 2 == 0) { + return {extra + low, static_cast(grad_in_mult)}; + } + return {span - extra + low, static_cast(-grad_in_mult)}; +} + +// Combines unnormalize + padding, returns {source_index, grad_multiplier}. +template +static float2 compute_source_index_set_grad( + float coord, + int32_t size, + bool align_corners); + +template <> +float2 compute_source_index_set_grad( + float coord, + int32_t size, + bool align_corners) { + return grid_sampler_unnormalize_set_grad(coord, size, align_corners); +} + +template <> +float2 compute_source_index_set_grad( + float coord, + int32_t size, + bool align_corners) { + float2 unnorm = grid_sampler_unnormalize_set_grad(coord, size, align_corners); + float2 clip = clip_coordinates_set_grad(unnorm.x, size - 1.0f); + return {clip.x, unnorm.y * clip.y}; +} + +template <> +float2 compute_source_index_set_grad( + float coord, + int32_t size, + bool align_corners) { + float2 unnorm = grid_sampler_unnormalize_set_grad(coord, size, align_corners); + float2 refl; + if (align_corners) { + refl = reflect_coordinates_set_grad( + unnorm.x, 0.0f, static_cast(size - 1)); + } else { + refl = reflect_coordinates_set_grad(unnorm.x, -0.5f, size - 0.5f); + } + float2 clip = clip_coordinates_set_grad(refl.x, size - 1.0f); + return {clip.x, unnorm.y * refl.y * clip.y}; +} + +// Runtime-dispatch versions for kernels where Pad is not in the hot loop. +static float2 compute_source_index_set_grad( + float coord, + int32_t size, + bool align_corners, + int32_t padding_mode) { + switch (padding_mode) { + case 1: + return compute_source_index_set_grad( + coord, size, align_corners); + case 2: + return compute_source_index_set_grad( + coord, size, align_corners); + default: + return compute_source_index_set_grad( + coord, size, align_corners); + } +} + +static float compute_source( + float coord, + int32_t size, + bool align_corners, + int32_t padding_mode) { + switch (padding_mode) { + case 1: + return PadBorder::compute_source(coord, size, align_corners); + case 2: + return PadReflection::compute_source(coord, size, align_corners); + default: + return PadZeros::compute_source(coord, size, align_corners); + } +} + +static bool within_bounds_2d(int2 pos, int2 size) { + return pos.x >= 0 && pos.x < size.x && pos.y >= 0 && pos.y < size.y; +} + +// Atomic safe add for grad_input +template +static void safe_add_2d_atomic( + device AtomicType_t* data, + int2 pos, + int2 stride, + int2 size, + opmath_t delta, + long NC_offset) { + if (within_bounds_2d(pos, size)) { + AtomicType::atomic_add( + data, + NC_offset + pos.y * stride.y + pos.x * stride.x, + static_cast(delta)); + } +} + +// Apply padding and convert to bounded int2 position (for bicubic backward +// where coordinates are already in pixel space). +template +static int2 apply_padding_2d(float x, float y, int2 size, bool align_corners) { + return { + static_cast(Pad::apply_padding(x, size.x, align_corners)), + static_cast(Pad::apply_padding(y, size.y, align_corners))}; +} + +// Get bounded value for bicubic backward +template +static opmath_t get_value_bounded_backward( + constant T* data, + float x, + float y, + int2 size, + int2 stride, + bool align_corners) { + int2 pos = apply_padding_2d(x, y, size, align_corners); + if (within_bounds_2d(pos, size)) { + return static_cast>(data[pos.y * stride.y + pos.x * stride.x]); + } + return 0; +} + +// Add value at bounded coordinates for bicubic backward grad_input +template +static void add_value_bounded_backward( + device AtomicType_t* data, + float x, + float y, + int2 size, + int2 stride, + opmath_t delta, + bool align_corners, + long NC_offset) { + int2 pos = apply_padding_2d(x, y, size, align_corners); + safe_add_2d_atomic(data, pos, stride, size, delta, NC_offset); +} + +// Get cubic coefficients gradient +template +static void get_cubic_coefficients_grad(T coeffs[4], T t) { + T A = static_cast(-0.75); + T x; + x = -1 - t; + coeffs[0] = (-3 * A * x - 10 * A) * x - 8 * A; + x = -t; + coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 1 - t; + coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x; + x = 2 - t; + coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A; +} + +// Common preamble for all backward kernels: decompose thread id into n,h,w +// and read grid coordinates. +template +struct BackwardPreamble { + int32_t n, h, w; + float x, y; + + BackwardPreamble( + constant T* grid, + constant GridSamplerBackwardParams<4>& params, + uint tid) { + auto out_H = params.forward.output_sizes[2]; + auto out_W = params.forward.output_sizes[3]; + w = tid % out_W; + h = (tid / out_W) % out_H; + n = tid / (out_H * out_W); + auto grid_offset = n * params.forward.grid_strides[0] + + h * params.forward.grid_strides[1] + w * params.forward.grid_strides[2]; + x = static_cast(grid[grid_offset]); + y = static_cast(grid[grid_offset + params.forward.grid_strides[3]]); + } +}; + +// Bilinear backward kernel for grad_input +template +kernel void grid_sampler_2d_backward_bilinear_input( + device AtomicType_t* grad_input [[buffer(0)]], + constant T* grad_output [[buffer(1)]], + constant T* grid [[buffer(2)]], + constant GridSamplerBackwardParams<4>& params [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + BackwardPreamble p(grid, params, tid); + auto C = params.forward.input_sizes[1]; + int2 inp_size = { + params.forward.input_sizes[3], params.forward.input_sizes[2]}; + int2 gInp_stride = { + params.grad_input_strides[3], params.grad_input_strides[2]}; + auto gOut_sN = params.grad_output_strides[0]; + auto gOut_sC = params.grad_output_strides[1]; + auto gOut_sH = params.grad_output_strides[2]; + auto gOut_sW = params.grad_output_strides[3]; + auto gInp_sN = params.grad_input_strides[0]; + auto gInp_sC = params.grad_input_strides[1]; + + float ix = compute_source( + p.x, inp_size.x, params.forward.align_corners, params.padding_mode); + float iy = compute_source( + p.y, inp_size.y, params.forward.align_corners, params.padding_mode); + + int32_t ix_nw = static_cast(floor(ix)); + int32_t iy_nw = static_cast(floor(iy)); + int32_t ix_ne = ix_nw + 1; + int32_t iy_ne = iy_nw; + int32_t ix_sw = ix_nw; + int32_t iy_sw = iy_nw + 1; + int32_t ix_se = ix_nw + 1; + int32_t iy_se = iy_nw + 1; + + float nw = (ix_se - ix) * (iy_se - iy); + float ne = (ix - ix_sw) * (iy_sw - iy); + float sw = (ix_ne - ix) * (iy - iy_ne); + float se = (ix - ix_nw) * (iy - iy_nw); + + auto gOut_ptr_NCHW = + grad_output + p.n * gOut_sN + p.h * gOut_sH + p.w * gOut_sW; + long NC_offset = p.n * gInp_sN; + + for (int32_t c = 0; c < C; + ++c, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) { + opmath_t gOut = static_cast>(*gOut_ptr_NCHW); + safe_add_2d_atomic( + grad_input, + {ix_nw, iy_nw}, + gInp_stride, + inp_size, + nw * gOut, + NC_offset); + safe_add_2d_atomic( + grad_input, + {ix_ne, iy_ne}, + gInp_stride, + inp_size, + ne * gOut, + NC_offset); + safe_add_2d_atomic( + grad_input, + {ix_sw, iy_sw}, + gInp_stride, + inp_size, + sw * gOut, + NC_offset); + safe_add_2d_atomic( + grad_input, + {ix_se, iy_se}, + gInp_stride, + inp_size, + se * gOut, + NC_offset); + } +} + +// Bilinear backward kernel for grad_grid +template +kernel void grid_sampler_2d_backward_bilinear_grid( + device T* grad_grid [[buffer(0)]], + constant T* grad_output [[buffer(1)]], + constant T* input [[buffer(2)]], + constant T* grid [[buffer(3)]], + constant GridSamplerBackwardParams<4>& params [[buffer(4)]], + uint tid [[thread_position_in_grid]]) { + BackwardPreamble p(grid, params, tid); + auto C = params.forward.input_sizes[1]; + int2 inp_size = { + params.forward.input_sizes[3], params.forward.input_sizes[2]}; + int2 inp_stride = { + params.forward.input_strides[3], params.forward.input_strides[2]}; + auto inp_sN = params.forward.input_strides[0]; + auto inp_sC = params.forward.input_strides[1]; + auto gOut_sN = params.grad_output_strides[0]; + auto gOut_sC = params.grad_output_strides[1]; + auto gOut_sH = params.grad_output_strides[2]; + auto gOut_sW = params.grad_output_strides[3]; + auto gGrid_sW = params.grad_grid_sW; + + // .x = source index, .y = gradient multiplier from coordinate transform + float2 ix = compute_source_index_set_grad( + p.x, inp_size.x, params.forward.align_corners, params.padding_mode); + float2 iy = compute_source_index_set_grad( + p.y, inp_size.y, params.forward.align_corners, params.padding_mode); + + int32_t ix_nw = static_cast(floor(ix.x)); + int32_t iy_nw = static_cast(floor(iy.x)); + + opmath_t gix = 0, giy = 0; + auto gOut_ptr_NCHW = + grad_output + p.n * gOut_sN + p.h * gOut_sH + p.w * gOut_sW; + auto inp_ptr_NC = input + p.n * inp_sN; + + for (int32_t c = 0; c < C; + ++c, inp_ptr_NC += inp_sC, gOut_ptr_NCHW += gOut_sC) { + opmath_t gOut = static_cast>(*gOut_ptr_NCHW); + + if (within_bounds_2d({ix_nw, iy_nw}, inp_size)) { + opmath_t nw_val = + inp_ptr_NC[iy_nw * inp_stride.y + ix_nw * inp_stride.x]; + gix -= nw_val * (iy_nw + 1 - iy.x) * gOut; + giy -= nw_val * (ix_nw + 1 - ix.x) * gOut; + } + if (within_bounds_2d({ix_nw + 1, iy_nw}, inp_size)) { + opmath_t ne_val = + inp_ptr_NC[iy_nw * inp_stride.y + (ix_nw + 1) * inp_stride.x]; + gix += ne_val * (iy_nw + 1 - iy.x) * gOut; + giy -= ne_val * (ix.x - ix_nw) * gOut; + } + if (within_bounds_2d({ix_nw, iy_nw + 1}, inp_size)) { + opmath_t sw_val = + inp_ptr_NC[(iy_nw + 1) * inp_stride.y + ix_nw * inp_stride.x]; + gix -= sw_val * (iy.x - iy_nw) * gOut; + giy += sw_val * (ix_nw + 1 - ix.x) * gOut; + } + if (within_bounds_2d({ix_nw + 1, iy_nw + 1}, inp_size)) { + opmath_t se_val = + inp_ptr_NC[(iy_nw + 1) * inp_stride.y + (ix_nw + 1) * inp_stride.x]; + gix += se_val * (iy.x - iy_nw) * gOut; + giy += se_val * (ix.x - ix_nw) * gOut; + } + } + + auto gGrid_ptr_NHW = grad_grid + tid * gGrid_sW; + gGrid_ptr_NHW[0] = static_cast(ix.y * gix); + gGrid_ptr_NHW[1] = static_cast(iy.y * giy); +} + +// Nearest backward kernel for grad_input +template +kernel void grid_sampler_2d_backward_nearest_input( + device AtomicType_t* grad_input [[buffer(0)]], + constant T* grad_output [[buffer(1)]], + constant T* grid [[buffer(2)]], + constant GridSamplerBackwardParams<4>& params [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + BackwardPreamble p(grid, params, tid); + int2 inp_size = { + params.forward.input_sizes[3], params.forward.input_sizes[2]}; + int2 gInp_stride = { + params.grad_input_strides[3], params.grad_input_strides[2]}; + auto gOut_sN = params.grad_output_strides[0]; + auto gOut_sC = params.grad_output_strides[1]; + auto gOut_sH = params.grad_output_strides[2]; + auto gOut_sW = params.grad_output_strides[3]; + auto gInp_sN = params.grad_input_strides[0]; + auto gInp_sC = params.grad_input_strides[1]; + + float ix = compute_source( + p.x, inp_size.x, params.forward.align_corners, params.padding_mode); + float iy = compute_source( + p.y, inp_size.y, params.forward.align_corners, params.padding_mode); + int2 nearest = { + static_cast(rint(ix)), static_cast(rint(iy))}; + + auto gOut_ptr_NCHW = + grad_output + p.n * gOut_sN + p.h * gOut_sH + p.w * gOut_sW; + long NC_offset = p.n * gInp_sN; + for (int32_t c = 0; c < params.forward.input_sizes[1]; + ++c, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) { + safe_add_2d_atomic( + grad_input, + nearest, + gInp_stride, + inp_size, + static_cast>(*gOut_ptr_NCHW), + NC_offset); + } +} + +// Bicubic backward kernel for grad_input +template +kernel void grid_sampler_2d_backward_bicubic_input( + device AtomicType_t* grad_input [[buffer(0)]], + constant T* grad_output [[buffer(1)]], + constant T* grid [[buffer(2)]], + constant GridSamplerBackwardParams<4>& params [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + BackwardPreamble p(grid, params, tid); + auto C = params.forward.input_sizes[1]; + int2 inp_size = { + params.forward.input_sizes[3], params.forward.input_sizes[2]}; + int2 gInp_stride = { + params.grad_input_strides[3], params.grad_input_strides[2]}; + auto gOut_sN = params.grad_output_strides[0]; + auto gOut_sC = params.grad_output_strides[1]; + auto gOut_sH = params.grad_output_strides[2]; + auto gOut_sW = params.grad_output_strides[3]; + auto gInp_sN = params.grad_input_strides[0]; + auto gInp_sC = params.grad_input_strides[1]; + + float ix = + grid_sampler_unnormalize(p.x, inp_size.x, params.forward.align_corners); + float iy = + grid_sampler_unnormalize(p.y, inp_size.y, params.forward.align_corners); + + float ix_nw = floor(ix); + float iy_nw = floor(iy); + + float x_coeffs[4], y_coeffs[4]; + get_cubic_coefficients(x_coeffs, ix - ix_nw); + get_cubic_coefficients(y_coeffs, iy - iy_nw); + + auto gOut_ptr_NCHW = + grad_output + p.n * gOut_sN + p.h * gOut_sH + p.w * gOut_sW; + long NC_offset = p.n * gInp_sN; + int32_t ix_nw_i = static_cast(ix_nw); + int32_t iy_nw_i = static_cast(iy_nw); + + for (int32_t c = 0; c < C; + ++c, gOut_ptr_NCHW += gOut_sC, NC_offset += gInp_sC) { + opmath_t gOut = static_cast>(*gOut_ptr_NCHW); + + for (int32_t i = 0; i < 4; ++i) { + for (int32_t j = 0; j < 4; ++j) { + add_value_bounded_backward( + grad_input, + ix_nw_i - 1 + i, + iy_nw_i - 1 + j, + inp_size, + gInp_stride, + gOut * x_coeffs[i] * y_coeffs[j], + params.forward.align_corners, + NC_offset); + } + } + } +} + +// Bicubic backward kernel for grad_grid +template +kernel void grid_sampler_2d_backward_bicubic_grid( + device T* grad_grid [[buffer(0)]], + constant T* grad_output [[buffer(1)]], + constant T* input [[buffer(2)]], + constant T* grid [[buffer(3)]], + constant GridSamplerBackwardParams<4>& params [[buffer(4)]], + uint tid [[thread_position_in_grid]]) { + BackwardPreamble p(grid, params, tid); + auto C = params.forward.input_sizes[1]; + int2 inp_size = { + params.forward.input_sizes[3], params.forward.input_sizes[2]}; + int2 inp_stride = { + params.forward.input_strides[3], params.forward.input_strides[2]}; + auto inp_sN = params.forward.input_strides[0]; + auto inp_sC = params.forward.input_strides[1]; + auto gOut_sN = params.grad_output_strides[0]; + auto gOut_sC = params.grad_output_strides[1]; + auto gOut_sH = params.grad_output_strides[2]; + auto gOut_sW = params.grad_output_strides[3]; + auto gGrid_sW = params.grad_grid_sW; + auto align_corners = params.forward.align_corners; + + float2 ix = grid_sampler_unnormalize_set_grad(p.x, inp_size.x, align_corners); + float2 iy = grid_sampler_unnormalize_set_grad(p.y, inp_size.y, align_corners); + + float ix_nw = floor(ix.x); + float iy_nw = floor(iy.x); + float tx = ix.x - ix_nw; + float ty = iy.x - iy_nw; + + float x_coeffs[4], y_coeffs[4]; + float x_coeffs_grad[4], y_coeffs_grad[4]; + get_cubic_coefficients(x_coeffs, tx); + get_cubic_coefficients(y_coeffs, ty); + get_cubic_coefficients_grad(x_coeffs_grad, tx); + get_cubic_coefficients_grad(y_coeffs_grad, ty); + + opmath_t gix = 0, giy = 0; + auto gOut_ptr_NCHW = + grad_output + p.n * gOut_sN + p.h * gOut_sH + p.w * gOut_sW; + auto inp_ptr_NC = input + p.n * inp_sN; + int32_t ix_nw_i = static_cast(ix_nw); + int32_t iy_nw_i = static_cast(iy_nw); + + for (int32_t c = 0; c < C; + ++c, gOut_ptr_NCHW += gOut_sC, inp_ptr_NC += inp_sC) { + opmath_t gOut = static_cast>(*gOut_ptr_NCHW); + + for (int32_t i = 0; i < 4; ++i) { + for (int32_t j = 0; j < 4; ++j) { + opmath_t val = get_value_bounded_backward( + inp_ptr_NC, + ix_nw_i - 1 + i, + iy_nw_i - 1 + j, + inp_size, + inp_stride, + align_corners); + + gix -= val * x_coeffs_grad[i] * y_coeffs[j] * gOut; + giy -= val * y_coeffs_grad[j] * x_coeffs[i] * gOut; + } + } + } + + auto gGrid_ptr_NHW = grad_grid + tid * gGrid_sW; + gGrid_ptr_NHW[0] = static_cast(ix.y * gix); + gGrid_ptr_NHW[1] = static_cast(iy.y * giy); +} + +// Registration macros for backward kernels. +// Bilinear/nearest _input and bilinear _grid kernels use runtime padding +// dispatch (only templated on dtype). Bicubic keeps the Pad template because +// padding affects its inner loop (16 bounded lookups per channel). +#define REGISTER_GRID_SAMPLER_2D_BACKWARD(DTYPE, INTERP) \ + template [[host_name("grid_sampler_2d_backward_" #INTERP \ + "_input_" #DTYPE)]] kernel void \ + grid_sampler_2d_backward_##INTERP##_input( \ + device AtomicType_t * grad_input [[buffer(0)]], \ + constant DTYPE * grad_output [[buffer(1)]], \ + constant DTYPE * grid [[buffer(2)]], \ + constant GridSamplerBackwardParams<4> & params [[buffer(3)]], \ + uint tid [[thread_position_in_grid]]); + +#define REGISTER_GRID_SAMPLER_2D_BACKWARD_BICUBIC(DTYPE, PAD, PNAME) \ + template [[host_name("grid_sampler_2d_backward_bicubic_input_" PNAME \ + "_" #DTYPE)]] kernel void \ + grid_sampler_2d_backward_bicubic_input( \ + device AtomicType_t * grad_input [[buffer(0)]], \ + constant DTYPE * grad_output [[buffer(1)]], \ + constant DTYPE * grid [[buffer(2)]], \ + constant GridSamplerBackwardParams<4> & params [[buffer(3)]], \ + uint tid [[thread_position_in_grid]]); \ + template [[host_name("grid_sampler_2d_backward_bicubic_grid_" PNAME \ + "_" #DTYPE)]] kernel void \ + grid_sampler_2d_backward_bicubic_grid( \ + device DTYPE * grad_grid [[buffer(0)]], \ + constant DTYPE * grad_output [[buffer(1)]], \ + constant DTYPE * input [[buffer(2)]], \ + constant DTYPE * grid [[buffer(3)]], \ + constant GridSamplerBackwardParams<4> & params [[buffer(4)]], \ + uint tid [[thread_position_in_grid]]); + +#define REGISTER_GRID_SAMPLER_2D_BACKWARD_OPS(DTYPE) \ + REGISTER_GRID_SAMPLER_2D_BACKWARD(DTYPE, bilinear) \ + REGISTER_GRID_SAMPLER_2D_BACKWARD(DTYPE, nearest) \ + template [[host_name( \ + "grid_sampler_2d_backward_bilinear_grid_" #DTYPE)]] kernel void \ + grid_sampler_2d_backward_bilinear_grid( \ + device DTYPE * grad_grid [[buffer(0)]], \ + constant DTYPE * grad_output [[buffer(1)]], \ + constant DTYPE * input [[buffer(2)]], \ + constant DTYPE * grid [[buffer(3)]], \ + constant GridSamplerBackwardParams<4> & params [[buffer(4)]], \ + uint tid [[thread_position_in_grid]]); \ + REGISTER_GRID_SAMPLER_2D_BACKWARD_BICUBIC(DTYPE, PadZeros, "zeros") \ + REGISTER_GRID_SAMPLER_2D_BACKWARD_BICUBIC(DTYPE, PadBorder, "border") \ + REGISTER_GRID_SAMPLER_2D_BACKWARD_BICUBIC(DTYPE, PadReflection, "reflection") + +REGISTER_GRID_SAMPLER_2D_BACKWARD_OPS(float); +REGISTER_GRID_SAMPLER_2D_BACKWARD_OPS(half); +REGISTER_GRID_SAMPLER_2D_BACKWARD_OPS(bfloat); diff --git a/aten/src/ATen/native/mps/kernels/HistogramKernel.metal b/aten/src/ATen/native/mps/kernels/HistogramKernel.metal index 689eee6e47f11..7d6275296a4ba 100644 --- a/aten/src/ATen/native/mps/kernels/HistogramKernel.metal +++ b/aten/src/ATen/native/mps/kernels/HistogramKernel.metal @@ -114,6 +114,11 @@ kernel void histogramdd( REGISTER_HISTOGRAMDD_OP(float); REGISTER_HISTOGRAMDD_OP(half); REGISTER_HISTOGRAMDD_OP(bfloat); +REGISTER_HISTOGRAMDD_OP(int); +REGISTER_HISTOGRAMDD_OP(long); +REGISTER_HISTOGRAMDD_OP(short); +REGISTER_HISTOGRAMDD_OP(char); +REGISTER_HISTOGRAMDD_OP(uchar); kernel void kernel_index_offset( constant uint* strides [[buffer(0)]], diff --git a/aten/src/ATen/native/mps/kernels/Indexing.metal b/aten/src/ATen/native/mps/kernels/Indexing.metal index 463c1812d0302..dcc0c4ff3cb82 100644 --- a/aten/src/ATen/native/mps/kernels/Indexing.metal +++ b/aten/src/ATen/native/mps/kernels/Indexing.metal @@ -760,14 +760,6 @@ INSTANTIATE_INDEX_FILL(uchar); INSTANTIATE_INDEX_FILL(float2); INSTANTIATE_INDEX_FILL(half2); -[[host_name("index_fill_set_mask")]] -kernel void index_fill_set_mask( - device bool*, - constant long*, - constant long&, - constant long&, - uint); - #define INSTANTIATE_INDEX_FILL_FROM_MASK(T) \ template [[host_name("index_fill_dense_from_mask_" #T)]] \ kernel void index_fill_dense_from_mask( \ @@ -799,3 +791,237 @@ INSTANTIATE_INDEX_FILL_FROM_MASK(uchar) INSTANTIATE_INDEX_FILL_FROM_MASK(bool) INSTANTIATE_INDEX_FILL_FROM_MASK(float2) INSTANTIATE_INDEX_FILL_FROM_MASK(half2) + +// Nonzero kernel implementation using prefix-sum + scatter approach. +// +// Step 1 (count_nonzero_prefix_sum): Each threadgroup computes an exclusive +// prefix sum of the nonzero flags over its chunk. Per-threadgroup totals are +// written to block_sums. +// +// Step 2 (prefix_sum_blocks): A single threadgroup computes the exclusive +// prefix sum of block_sums → block_offsets and writes the total nonzero count +// to a 1-element buffer. The host reads back only that single int, then +// allocates the output tensor. +// +// Step 3 (scatter_nonzero_indices): Each thread with a nonzero element writes +// its multi-dimensional indices into the output at the position determined by +// block_offsets[tgid] + prefix[tid]. + +template , bool> = true> +inline bool is_nonzero(T val) { + return val != T(0); +} + +template , bool> = true> +inline bool is_nonzero(T val) { + return val.x != 0 || val.y != 0; +} + +template +[[max_total_threads_per_threadgroup(1024)]] +kernel void count_nonzero_prefix_sum( + const device T* input [[buffer(0)]], + device int* prefix [[buffer(1)]], + device int* block_sums [[buffer(2)]], + uint tid [[thread_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint tgsize [[threads_per_threadgroup]], + uint tgid [[threadgroup_position_in_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + uint num_simds = (tgsize + simdgroup_size - 1) / simdgroup_size; + + int flag = is_nonzero(input[tid]) ? 1 : 0; + + // Inclusive prefix sum within SIMD group using shuffle + int val = flag; + for (uint offset = 1; offset < simdgroup_size; offset <<= 1) { + int other = simd_shuffle_and_fill_up(val, 0, static_cast(offset)); + val += other; + } + + // The last lane in each simd group writes its total. + // For full groups this is lane 31; for the last (partial) group we compute + // which lane is actually last. + threadgroup int simdgroup_totals[32]; + bool is_last_lane_in_simd; + if (simd_group_id < num_simds - 1) { + is_last_lane_in_simd = (simd_lane_id == simdgroup_size - 1); + } else { + uint lanes_in_last = tgsize - simd_group_id * simdgroup_size; + is_last_lane_in_simd = (simd_lane_id == lanes_in_last - 1); + } + if (is_last_lane_in_simd) { + simdgroup_totals[simd_group_id] = val; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // First simd group computes exclusive prefix sum of simd group totals + threadgroup int simdgroup_offsets[32]; + if (simd_group_id == 0) { + int sg_val = + (simd_lane_id < num_simds) ? simdgroup_totals[simd_lane_id] : 0; + for (uint offset = 1; offset < simdgroup_size; offset <<= 1) { + int other = + simd_shuffle_and_fill_up(sg_val, 0, static_cast(offset)); + sg_val += other; + } + int exclusive = simd_shuffle_and_fill_up(sg_val, 0, static_cast(1)); + simdgroup_offsets[simd_lane_id] = exclusive; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + int exclusive_val = val - flag + simdgroup_offsets[simd_group_id]; + + prefix[tid] = exclusive_val; + + if (lid == tgsize - 1) { + block_sums[tgid] = + simdgroup_offsets[num_simds - 1] + simdgroup_totals[num_simds - 1]; + } +} + +// Step 2: exclusive prefix sum of block_sums → block_offsets, and write +// total nonzero count to a 1-element buffer. Runs in a single threadgroup. +// Each thread handles ceil(num_blocks / tgsize) consecutive blocks via a +// serial loop, then the per-thread totals are scanned in parallel. +[[max_total_threads_per_threadgroup(1024)]] +kernel void prefix_sum_blocks( + const device int* block_sums [[buffer(0)]], + device int* block_offsets [[buffer(1)]], + device int* total_nonzero [[buffer(2)]], + constant uint& num_blocks [[buffer(3)]], + uint lid [[thread_position_in_threadgroup]], + uint tgsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + uint num_simds = (tgsize + simdgroup_size - 1) / simdgroup_size; + + // Each thread handles a contiguous chunk of blocks + uint chunk_size = (num_blocks + tgsize - 1) / tgsize; + uint start = lid * chunk_size; + uint end = min(start + chunk_size, num_blocks); + + // Serial sum over this thread's chunk + int chunk_total = 0; + for (uint i = start; i < end; i++) { + chunk_total += block_sums[i]; + } + + // Parallel inclusive prefix sum of chunk_totals across threads + int val = chunk_total; + for (uint offset = 1; offset < simdgroup_size; offset <<= 1) { + int other = simd_shuffle_and_fill_up(val, 0, static_cast(offset)); + val += other; + } + + threadgroup int simdgroup_totals[32]; + bool is_last_lane_in_simd; + if (simd_group_id < num_simds - 1) { + is_last_lane_in_simd = (simd_lane_id == simdgroup_size - 1); + } else { + uint lanes_in_last = tgsize - simd_group_id * simdgroup_size; + is_last_lane_in_simd = (simd_lane_id == lanes_in_last - 1); + } + if (is_last_lane_in_simd) { + simdgroup_totals[simd_group_id] = val; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + threadgroup int simdgroup_offsets[32]; + if (simd_group_id == 0) { + int sg_val = + (simd_lane_id < num_simds) ? simdgroup_totals[simd_lane_id] : 0; + for (uint offset = 1; offset < simdgroup_size; offset <<= 1) { + int other = + simd_shuffle_and_fill_up(sg_val, 0, static_cast(offset)); + sg_val += other; + } + int exclusive = simd_shuffle_and_fill_up(sg_val, 0, static_cast(1)); + simdgroup_offsets[simd_lane_id] = exclusive; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // This thread's exclusive offset = inclusive_scan - chunk_total + + // simdgroup_offset + int thread_offset = val - chunk_total + simdgroup_offsets[simd_group_id]; + + // Write block_offsets for this thread's chunk using a serial exclusive scan + int running = thread_offset; + for (uint i = start; i < end; i++) { + block_offsets[i] = running; + running += block_sums[i]; + } + + if (lid == tgsize - 1) { + *total_nonzero = + simdgroup_offsets[num_simds - 1] + simdgroup_totals[num_simds - 1]; + } +} + +// Scatter the multi-dimensional indices of nonzero elements. +// Output layout: out[position * ndim + d] = index along dimension d. +template +[[max_total_threads_per_threadgroup(1024)]] +kernel void scatter_nonzero_indices( + const device T* input [[buffer(0)]], + const device int* prefix [[buffer(1)]], + device int64_t* output [[buffer(2)]], + constant int& ndim [[buffer(3)]], + constant int64_t* sizes [[buffer(4)]], + constant int* block_offsets [[buffer(5)]], + constant int& max_entries [[buffer(6)]], + uint tid [[thread_position_in_grid]], + uint tgid [[threadgroup_position_in_grid]]) { + if (!is_nonzero(input[tid])) + return; + + int pos = block_offsets[tgid] + prefix[tid]; + if (pos >= max_entries) + return; + + uint flat = tid; + for (int d = ndim - 1; d >= 0; d--) { + int64_t dim_size = sizes[d]; + output[pos * ndim + d] = + static_cast(flat % static_cast(dim_size)); + flat /= static_cast(dim_size); + } +} + +#define REGISTER_NONZERO_KERNELS(DTYPE) \ + template [[host_name("count_nonzero_prefix_sum_" #DTYPE)]] [[kernel]] void \ + count_nonzero_prefix_sum( \ + const device DTYPE* input [[buffer(0)]], \ + device int* prefix [[buffer(1)]], \ + device int* block_sums [[buffer(2)]], \ + uint tid [[thread_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint tgsize [[threads_per_threadgroup]], \ + uint tgid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); \ + \ + template [[host_name("scatter_nonzero_indices_" #DTYPE)]] [[kernel]] void \ + scatter_nonzero_indices( \ + const device DTYPE* input [[buffer(0)]], \ + const device int* prefix [[buffer(1)]], \ + device int64_t* output [[buffer(2)]], \ + constant int& ndim [[buffer(3)]], \ + constant int64_t* sizes [[buffer(4)]], \ + constant int* block_offsets [[buffer(5)]], \ + constant int& max_entries [[buffer(6)]], \ + uint tid [[thread_position_in_grid]], \ + uint tgid [[threadgroup_position_in_grid]]) + +REGISTER_NONZERO_KERNELS(float); +REGISTER_NONZERO_KERNELS(half); +REGISTER_NONZERO_KERNELS(bfloat); +REGISTER_NONZERO_KERNELS(long); +REGISTER_NONZERO_KERNELS(int); +REGISTER_NONZERO_KERNELS(short); +REGISTER_NONZERO_KERNELS(char); +REGISTER_NONZERO_KERNELS(uchar); +REGISTER_NONZERO_KERNELS(bool); +REGISTER_NONZERO_KERNELS(float2); +REGISTER_NONZERO_KERNELS(half2); diff --git a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal index ad68c9549fa00..cacdee974bc75 100644 --- a/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal +++ b/aten/src/ATen/native/mps/kernels/LinearAlgebra.metal @@ -15,26 +15,27 @@ inline c10::metal::opmath_t matmul_inner( constant T* mat2Data, constant array& strides, constant uint3& sizes, - threadgroup T A_tile[TILE_DIM][TILE_DIM], - threadgroup T B_tile[TILE_DIM][TILE_DIM], + threadgroup c10::metal::opmath_t A_tile[TILE_DIM][TILE_DIM], + threadgroup c10::metal::opmath_t B_tile[TILE_DIM][TILE_DIM], uint2 tid, uint2 thread_id) { - c10::metal::opmath_t sum = 0; + using TA = c10::metal::opmath_t; + TA sum = 0; uint numTiles = (sizes.y + TILE_DIM - 1) / TILE_DIM; for (uint t = 0; t < numTiles; t++) { uint tiledCol = t * TILE_DIM + tid.x; if (thread_id.y < sizes.x && tiledCol < sizes.y) { - A_tile[tid.y][tid.x] = - mat1Data[thread_id.y * strides[0].x + tiledCol * strides[0].y]; + A_tile[tid.y][tid.x] = static_cast( + mat1Data[thread_id.y * strides[0].x + tiledCol * strides[0].y]); } else { A_tile[tid.y][tid.x] = 0; } uint tiledRow = t * TILE_DIM + tid.y; if (tiledRow < sizes.y && thread_id.x < sizes.z) { - B_tile[tid.y][tid.x] = - mat2Data[tiledRow * strides[1].x + thread_id.x * strides[1].y]; + B_tile[tid.y][tid.x] = static_cast( + mat2Data[tiledRow * strides[1].x + thread_id.x * strides[1].y]); } else { B_tile[tid.y][tid.x] = 0; } @@ -58,12 +59,13 @@ inline c10::metal::opmath_t batched_matmul_inner( uint batch, constant array& strides, constant uint4& sizes, - threadgroup T A_tile[TILE_DIM][TILE_DIM], - threadgroup T B_tile[TILE_DIM][TILE_DIM], + threadgroup c10::metal::opmath_t A_tile[TILE_DIM][TILE_DIM], + threadgroup c10::metal::opmath_t B_tile[TILE_DIM][TILE_DIM], uint3 tid, uint row, uint col) { - c10::metal::opmath_t sum = 0; + using TA = c10::metal::opmath_t; + TA sum = 0; // Compute batch offsets uint batch1Offset = batch * strides[2]; @@ -73,16 +75,16 @@ inline c10::metal::opmath_t batched_matmul_inner( for (uint t = 0; t < numTiles; t++) { uint tiledCol = t * TILE_DIM + tid.x; if (row < sizes.x && tiledCol < sizes.y) { - A_tile[tid.y][tid.x] = - mat1Data[batch1Offset + row * strides[1] + tiledCol * strides[0]]; + A_tile[tid.y][tid.x] = static_cast( + mat1Data[batch1Offset + row * strides[1] + tiledCol * strides[0]]); } else { A_tile[tid.y][tid.x] = 0; } uint tiledRow = t * TILE_DIM + tid.y; if (tiledRow < sizes.y && col < sizes.z) { - B_tile[tid.y][tid.x] = - mat2Data[batch2Offset + tiledRow * strides[4] + col * strides[3]]; + B_tile[tid.y][tid.x] = static_cast( + mat2Data[batch2Offset + tiledRow * strides[4] + col * strides[3]]); } else { B_tile[tid.y][tid.x] = 0; } @@ -108,8 +110,8 @@ kernel void matmul( constant uint3& sizes [[buffer(4)]], uint2 tid [[thread_position_in_threadgroup]], uint2 thread_id [[thread_position_in_grid]]) { - threadgroup T A_tile[TILE_DIM][TILE_DIM]; - threadgroup T B_tile[TILE_DIM][TILE_DIM]; + threadgroup c10::metal::opmath_t A_tile[TILE_DIM][TILE_DIM]; + threadgroup c10::metal::opmath_t B_tile[TILE_DIM][TILE_DIM]; auto sum = matmul_inner( mat1Data, mat2Data, strides, sizes, A_tile, B_tile, tid, thread_id); @@ -130,8 +132,8 @@ kernel void addmm( constant uint3& sizes [[buffer(6)]], uint2 tid [[thread_position_in_threadgroup]], uint2 thread_id [[thread_position_in_grid]]) { - threadgroup T A_tile[TILE_DIM][TILE_DIM]; - threadgroup T B_tile[TILE_DIM][TILE_DIM]; + threadgroup c10::metal::opmath_t A_tile[TILE_DIM][TILE_DIM]; + threadgroup c10::metal::opmath_t B_tile[TILE_DIM][TILE_DIM]; auto sum = matmul_inner( mat1Data, @@ -143,8 +145,9 @@ kernel void addmm( tid, thread_id); if (thread_id.y < sizes.x && thread_id.x < sizes.z) { - auto bias = - biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y]; + using TA = c10::metal::opmath_t; + auto bias = static_cast( + biasData[thread_id.y * strides[3].x + thread_id.x * strides[3].y]); outputData[thread_id.y * strides[2].x + thread_id.x * strides[2].y] = static_cast( c10::metal::mul(alpha_beta[0], sum) + @@ -165,8 +168,8 @@ kernel void naive_bmm( uint col = group_id.x * TILE_DIM + tid.x; uint row = group_id.y * TILE_DIM + tid.y; - threadgroup T A_tile[TILE_DIM][TILE_DIM]; - threadgroup T B_tile[TILE_DIM][TILE_DIM]; + threadgroup c10::metal::opmath_t A_tile[TILE_DIM][TILE_DIM]; + threadgroup c10::metal::opmath_t B_tile[TILE_DIM][TILE_DIM]; auto sum = batched_matmul_inner( mat1Data, mat2Data, batch, strides, sizes, A_tile, B_tile, tid, row, col); @@ -192,15 +195,17 @@ kernel void naive_baddbmm( uint col = group_id.x * TILE_DIM + tid.x; uint row = group_id.y * TILE_DIM + tid.y; - threadgroup T A_tile[TILE_DIM][TILE_DIM]; - threadgroup T B_tile[TILE_DIM][TILE_DIM]; + threadgroup c10::metal::opmath_t A_tile[TILE_DIM][TILE_DIM]; + threadgroup c10::metal::opmath_t B_tile[TILE_DIM][TILE_DIM]; auto sum = batched_matmul_inner( mat1Data, mat2Data, batch, strides, sizes, A_tile, B_tile, tid, row, col); if (row < sizes.x && col < sizes.z) { + using TA = c10::metal::opmath_t; uint biasOffset = batch * strides[11]; - auto bias = biasData[biasOffset + row * strides[10] + col * strides[9]]; + auto bias = static_cast( + biasData[biasOffset + row * strides[10] + col * strides[9]]); outputData[batch * strides[8] + col * strides[6] + row * strides[7]] = static_cast( c10::metal::mul(alpha_beta[0], sum) + @@ -224,8 +229,8 @@ kernel void naive_addbmm( c10::metal::opmath_t sum = 0; - threadgroup T A_tile[TILE_DIM][TILE_DIM]; - threadgroup T B_tile[TILE_DIM][TILE_DIM]; + threadgroup c10::metal::opmath_t A_tile[TILE_DIM][TILE_DIM]; + threadgroup c10::metal::opmath_t B_tile[TILE_DIM][TILE_DIM]; // Iterate through all batches and accumulate for (uint batch = 0; batch < sizes.w; batch++) { @@ -243,7 +248,8 @@ kernel void naive_addbmm( } if (row < sizes.x && col < sizes.z) { - auto bias = biasData[row * strides[10] + col * strides[9]]; + using TA = c10::metal::opmath_t; + auto bias = static_cast(biasData[row * strides[10] + col * strides[9]]); outputData[row * strides[7] + col * strides[6]] = static_cast( c10::metal::mul(alpha_beta[0], sum) + c10::metal::mul(alpha_beta[1], bias)); diff --git a/aten/src/ATen/native/mps/kernels/ReduceOps.h b/aten/src/ATen/native/mps/kernels/ReduceOps.h new file mode 100644 index 0000000000000..f7e6db407dcf4 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/ReduceOps.h @@ -0,0 +1,18 @@ +#pragma once +#include + +#define MAX_THREADGROUP_SIZE static_cast(1024) +C10_METAL_CONSTEXPR uint32_t SUM_NCHAINS = 8; + +template +struct NormParams { + float p; + uint32_t reduction_size; + uint32_t ndim; + + ::c10::metal::array input_sizes; + ::c10::metal::array input_strides; + + ::c10::metal::array output_sizes; + ::c10::metal::array output_strides; +}; diff --git a/aten/src/ATen/native/mps/kernels/ReduceOps.metal b/aten/src/ATen/native/mps/kernels/ReduceOps.metal new file mode 100644 index 0000000000000..76dcc03e178da --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/ReduceOps.metal @@ -0,0 +1,674 @@ +#include +#include +#include +#include +#include + +using namespace metal; +using namespace c10::metal; + +struct norm_abs_functor { + template , bool> = true> + inline T operator()(const T x) { + return static_cast(::precise::abs(x)); + } + + template , bool> = true> + inline float operator()(const T x) { + const auto abs_2 = ::precise::abs(float2(x)); + return c10::metal::hypot(abs_2.x, abs_2.y); + } +}; + +// `reduction_idx` is the index of a particular batch of input elements that all +// get reduced to one output element. `reduction_element_idx` is the index of +// just one input element within its batch. +static uint32_t get_input_offset( + uint32_t reduction_element_idx, + uint32_t reduction_idx, + constant NormParams<>& params) { + uint32_t input_offset = 0; + + for (int32_t dim = params.ndim - 1; dim >= 0; dim--) { + auto input_dim_size = params.input_sizes[dim]; + auto output_dim_size = params.output_sizes[dim]; + + // If the the input and output have the same size for this dim, then this + // dim is not being reduced, so we index by `reduction_idx` + if (input_dim_size == output_dim_size) { + auto index_in_dim = reduction_idx % input_dim_size; + reduction_idx /= input_dim_size; + input_offset += index_in_dim * params.input_strides[dim]; + + // Otherwise, this dim is being reduced, so we index by + // `reduction_element_idx` + } else { + auto index_in_dim = reduction_element_idx % input_dim_size; + reduction_element_idx /= input_dim_size; + input_offset += index_in_dim * params.input_strides[dim]; + } + } + return input_offset; +} + +// In this kernel, each threadgroup is responsible for calculating one element +// of the output. +// TI - dtype of the input tensor. +// TO - dtype of the output tensor. +template +kernel void norm( + constant TI* input [[buffer(0)]], + device TO* output [[buffer(1)]], + constant NormParams<>& params [[buffer(2)]], + uint tid [[thread_position_in_threadgroup]], + uint tptg [[threads_per_threadgroup]], + uint tgid [[threadgroup_position_in_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simdgroup_id [[simdgroup_index_in_threadgroup]], + uint simdgroup_size [[threads_per_simdgroup]]) { + using TA = opmath_t; + TA output_val = 0; + const auto p = static_cast(params.p); + + if (p == INFINITY) { + output_val = -INFINITY; + } else if (p == -INFINITY) { + output_val = INFINITY; + } + + // First, all the input elements assigned to the threadgroup are divided + // between all the threads in the threadgroup, and each thread reduces those + // elements down to one partial `output_val`. + for (uint32_t reduction_element_idx = tid; + reduction_element_idx < params.reduction_size; + reduction_element_idx += tptg) { + auto input_elem = + input[get_input_offset(reduction_element_idx, tgid, params)]; + auto input_abs = static_cast(norm_abs_functor()(input_elem)); + + if (p == INFINITY) { + output_val = max(input_abs, output_val); + + } else if (p == -INFINITY) { + output_val = min(input_abs, output_val); + + } else if (p == 0) { + output_val += (input_abs == 0) ? 0 : 1; + + } else { + output_val += static_cast(::precise::pow(input_abs, p)); + } + } + + // Next, all the threads in a threadgroup reduce their `output_val`s together + // with a series of SIMD group reductions. + auto threads_remaining = tptg; + threadgroup TA shared_outputs[MAX_THREADGROUP_SIZE]; + + while (threads_remaining > 1) { + if (p == INFINITY) { + output_val = simd_max(output_val); + } else if (p == -INFINITY) { + output_val = simd_min(output_val); + } else { + output_val = simd_sum(output_val); + } + + threads_remaining = ceil_div(threads_remaining, simdgroup_size); + + if (threads_remaining > 1) { + // One thread from each SIMD group writes to a shared buffer + if (simd_lane_id == 0) { + shared_outputs[simdgroup_id] = output_val; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // The remaining threads each read one of the partial outputs from the + // shared buffer + if (tid < threads_remaining) { + output_val = shared_outputs[tid]; + } else { + return; + } + } + } + + // Finally, one thread in the threadgroup writes the final output + if (tid == 0) { + uint32_t output_offset = 0; + uint32_t reduction_idx = tgid; + + for (int32_t dim = params.ndim - 1; dim >= 0; dim--) { + auto output_dim_size = params.output_sizes[dim]; + + if (output_dim_size > 1) { + auto index_in_dim = reduction_idx % output_dim_size; + reduction_idx /= output_dim_size; + output_offset += index_in_dim * params.output_strides[dim]; + } + } + + if (p != 0 && p != 1 && p != INFINITY && p != -INFINITY) { + output_val = static_cast(::precise::pow(output_val, 1 / p)); + } + output[output_offset] = static_cast(output_val); + } +} + +#define REGISTER_NORM(TI, TO) \ + template [[host_name("norm_" #TI "_" #TO)]] \ + kernel void norm( \ + constant TI * input [[buffer(0)]], \ + device TO * output [[buffer(1)]], \ + constant NormParams<> & params [[buffer(2)]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint tptg [[threads_per_threadgroup]], \ + uint tgid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simdgroup_id [[simdgroup_index_in_threadgroup]], \ + uint simdgroup_size [[threads_per_simdgroup]]); + +REGISTER_NORM(float, float); +REGISTER_NORM(half, half); +REGISTER_NORM(bfloat, bfloat); +REGISTER_NORM(float2, float); +REGISTER_NORM(half2, half); + +#include + +// Load modes for sum_reduction: identity (sum), nan-to-zero (nansum), +// or nonzero-as-one (count_nonzero). +enum LoadMode : uint { + LOAD_IDENTITY = 0, + LOAD_NAN_TO_ZERO = 1, + LOAD_NONZERO = 2 +}; + +template , bool> = true> +inline bool load_is_nonzero(T v) { + return v != T(0); +} + +template , bool> = true> +inline bool load_is_nonzero(T v) { + return v.x != 0 || v.y != 0; +} + +// Load helper: cast to opmath_t, optionally replacing NaN with zero, +// or map nonzero to 1 for count_nonzero semantics. +template < + LoadMode MODE, + typename TI, + ::metal::enable_if_t = true> +inline opmath_t load_val(TI v) { + return static_cast>(v); +} + +template < + LoadMode MODE, + typename TI, + ::metal::enable_if_t = true> +inline opmath_t load_val(TI v) { + auto r = static_cast>(v); + if (::metal::isnan(static_cast(r))) + r = 0; + return r; +} + +// LOAD_NONZERO returns uint: MPS tensor numel fits in uint32, so per-TG +// (and per-output-element) non-zero counts cannot overflow. This lets +// count_nonzero accumulate in 32-bit integer instead of 64-bit, which is a +// meaningful speedup for small inputs (especially bool) where compute +// overhead dominates. The final cast back to long happens at the output +// store in the kernel. +template < + LoadMode MODE, + typename TI, + ::metal::enable_if_t = true> +inline uint load_val(TI v) { + return load_is_nonzero(v) ? 1u : 0u; +} + +// Sum reduction kernel with multiple independent accumulation chains (ILP). +// Each thread maintains NCHAINS independent accumulators to hide ALU latency +// and keep the memory pipeline saturated. +// +// Two internal paths selected per-threadgroup (not per-element): +// - Single reduced dim (or full reduction): compute input_base + k * stride +// once per TG, then direct indexing — no per-element dim loop. +// - Multiple reduced dims: fall back to get_input_offset per element. +// MODE: LOAD_IDENTITY (sum), LOAD_NAN_TO_ZERO (nansum), +// LOAD_NONZERO (count_nonzero — contributes 1 per nonzero element). +// The compiler eliminates dead branches per instantiation. +template < + typename TI, + typename TO, + uint NCHAINS = SUM_NCHAINS, + LoadMode MODE = LOAD_IDENTITY> +kernel void sum_reduction( + constant TI* input [[buffer(0)]], + device TO* output [[buffer(1)]], + constant NormParams<>& params [[buffer(2)]], + uint tid [[thread_position_in_threadgroup]], + uint tptg [[threads_per_threadgroup]], + uint tgid [[threadgroup_position_in_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simdgroup_id [[simdgroup_index_in_threadgroup]], + uint simdgroup_size [[threads_per_simdgroup]]) { + using TA = ::metal::conditional_t>; + + // Compute input_base (once per TG) and detect reduction pattern. + // For single reduced dim: input_base + k * reduction_stride gives + // the k-th reduction element — no per-element dim loop needed. + uint32_t input_base = 0; + uint32_t reduction_stride = 1; + uint32_t num_reduced_dims = 0; + { + uint32_t out_idx = tgid; + for (int32_t dim = params.ndim - 1; dim >= 0; dim--) { + if (params.input_sizes[dim] != params.output_sizes[dim]) { + num_reduced_dims++; + reduction_stride = params.input_strides[dim]; + } else { + auto idx = out_idx % params.output_sizes[dim]; + out_idx /= params.output_sizes[dim]; + input_base += idx * params.input_strides[dim]; + } + } + } + + // Load helper: cast to accumulator type, optionally replacing NaN with zero + + metal::array acc; + for (uint j = 0; j < NCHAINS; j++) { + acc[j] = 0; + } + + const uint32_t rsize = params.reduction_size; + const uint32_t stride = tptg * NCHAINS; + uint32_t base = tid * NCHAINS; + + if (num_reduced_dims <= 1) { + // Fast path: direct indexing with base + k * reduction_stride + for (; base + NCHAINS <= rsize; base += stride) { + for (uint j = 0; j < NCHAINS; j++) { + acc[j] += + load_val(input[input_base + (base + j) * reduction_stride]); + } + } + for (uint32_t idx = base; idx < rsize; idx++) { + acc[idx % NCHAINS] += + load_val(input[input_base + idx * reduction_stride]); + } + } else { + // Generic path: per-element strided offset for multi-dim reductions + for (; base + NCHAINS <= rsize; base += stride) { + for (uint j = 0; j < NCHAINS; j++) { + acc[j] += + load_val(input[get_input_offset(base + j, tgid, params)]); + } + } + for (uint32_t idx = base; idx < rsize; idx++) { + acc[idx % NCHAINS] += + load_val(input[get_input_offset(idx, tgid, params)]); + } + } + + // Collapse chains into a single value + TA output_val = acc[0]; + for (uint j = 1; j < NCHAINS; j++) { + output_val += acc[j]; + } + + // SIMD + threadgroup tree reduction + auto threads_remaining = tptg; + threadgroup TA shared_outputs[MAX_THREADGROUP_SIZE]; + + while (threads_remaining > 1) { + output_val = c10::metal::simd_sum(output_val); + threads_remaining = ceil_div(threads_remaining, simdgroup_size); + + if (threads_remaining > 1) { + if (simd_lane_id == 0) { + shared_outputs[simdgroup_id] = output_val; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (tid < threads_remaining) { + output_val = shared_outputs[tid]; + } else { + return; + } + } + } + + if (tid == 0) { + uint32_t output_offset = 0; + uint32_t reduction_idx = tgid; + + for (int32_t dim = params.ndim - 1; dim >= 0; dim--) { + auto output_dim_size = params.output_sizes[dim]; + if (output_dim_size > 1) { + auto index_in_dim = reduction_idx % output_dim_size; + reduction_idx /= output_dim_size; + output_offset += index_in_dim * params.output_strides[dim]; + } + } + // params.p > 0 means "divide the accumulator by p before casting" + // (used by mean to keep the division in opmath_t precision so the + // fp32 accumulation isn't lost when TO is fp16/bf16/half2). + if (params.p > 0) { + output_val /= static_cast(params.p); + } + output[output_offset] = static_cast(output_val); + } +} + +// Specialized kernel for reducing a non-innermost dim of a contiguous 2D +// tensor. Each thread handles one column, iterating over all rows with +// coalesced reads. Multiple row-workers per threadgroup reduce via shared +// memory. This avoids the strided-access penalty of the generic kernel for +// dim=0. +// +// Grid: (ceil(N/TG_X), 1) threadgroups, each (TG_X, TG_Y) threads. +// TG_X threads cover adjacent columns (coalesced), TG_Y threads split rows. +template < + typename TI, + typename TO, + uint TG_X = 32, + uint TG_Y = 32, + uint NCHAINS = SUM_NCHAINS, + LoadMode MODE = LOAD_IDENTITY> +kernel void sum_reduction_outer( + constant TI* input [[buffer(0)]], + device TO* output [[buffer(1)]], + constant uint3& sizes [[buffer(2)]], // [M, N, output_stride] + constant float& divisor [[buffer(3)]], // >0 divides accumulator before cast + uint2 tid_tg [[thread_position_in_threadgroup]], + uint2 tg_pos [[threadgroup_position_in_grid]]) { + using TA = ::metal::conditional_t>; + const uint M = sizes.x; + const uint N = sizes.y; + const uint out_stride = sizes.z; + + uint col = tg_pos.x * TG_X + tid_tg.x; + if (col >= N) + return; + + // Split rows among TG_Y workers + uint rows_per_y = ceil_div(M, TG_Y); + uint row_start = tid_tg.y * rows_per_y; + uint row_end = min(row_start + rows_per_y, M); + + // Multiple accumulation chains for ILP + metal::array acc; + for (uint j = 0; j < NCHAINS; j++) + acc[j] = 0; + + uint row = row_start; + for (; row + NCHAINS <= row_end; row += NCHAINS) { + for (uint j = 0; j < NCHAINS; j++) { + acc[j] += load_val(input[(row + j) * N + col]); + } + } + for (; row < row_end; row++) { + acc[row % NCHAINS] += load_val(input[row * N + col]); + } + + TA sum = acc[0]; + for (uint j = 1; j < NCHAINS; j++) + sum += acc[j]; + + // Reduce across TG_Y row-workers via shared memory + threadgroup TA shmem[TG_Y][TG_X]; + shmem[tid_tg.y][tid_tg.x] = sum; + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint stride = TG_Y / 2; stride > 0; stride >>= 1) { + if (tid_tg.y < stride) + shmem[tid_tg.y][tid_tg.x] += shmem[tid_tg.y + stride][tid_tg.x]; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (tid_tg.y == 0) { + TA final_val = shmem[0][tid_tg.x]; + if (divisor > 0) { + final_val /= static_cast(divisor); + } + output[col * out_stride] = static_cast(final_val); + } +} + +#define REGISTER_SUM_OUTER_IMPL(TI, TO, PREFIX, MODE) \ + template [[host_name(PREFIX "reduction_outer_" #TI "_" #TO)]] \ + kernel void sum_reduction_outer( \ + constant TI * input [[buffer(0)]], \ + device TO * output [[buffer(1)]], \ + constant uint3 & sizes [[buffer(2)]], \ + constant float& divisor [[buffer(3)]], \ + uint2 tid_tg [[thread_position_in_threadgroup]], \ + uint2 tg_pos [[threadgroup_position_in_grid]]); + +#define REGISTER_SUM_OUTER(TI, TO) \ + REGISTER_SUM_OUTER_IMPL(TI, TO, "sum_", LOAD_IDENTITY) +#define REGISTER_NANSUM_OUTER(TI, TO) \ + REGISTER_SUM_OUTER_IMPL(TI, TO, "nansum_", LOAD_NAN_TO_ZERO) +#define REGISTER_COUNT_NONZERO_OUTER(TI) \ + REGISTER_SUM_OUTER_IMPL(TI, long, "count_nonzero_", LOAD_NONZERO) + +REGISTER_SUM_OUTER(float, float); +REGISTER_SUM_OUTER(half, half); +REGISTER_SUM_OUTER(half, float); +REGISTER_SUM_OUTER(bfloat, bfloat); +REGISTER_SUM_OUTER(bfloat, float); +REGISTER_SUM_OUTER(int, int); +REGISTER_SUM_OUTER(int, long); +REGISTER_SUM_OUTER(long, long); +REGISTER_SUM_OUTER(short, short); +REGISTER_SUM_OUTER(short, long); +REGISTER_SUM_OUTER(char, char); +REGISTER_SUM_OUTER(char, long); +REGISTER_SUM_OUTER(uchar, uchar); +REGISTER_SUM_OUTER(uchar, long); +REGISTER_SUM_OUTER(bool, long); +REGISTER_SUM_OUTER(bool, int); +REGISTER_SUM_OUTER(float2, float2); +REGISTER_SUM_OUTER(half2, half2); + +// Specialized kernel for reducing the innermost dim of a contiguous tensor. +// Input [M, N] -> output [M], each SIMD group reduces one row of N elements. +// Multiple SIMD groups per TG handle different rows for occupancy. +// No shared memory needed — simd_sum suffices for intra-row reduction. +template < + typename TI, + typename TO, + uint NCHAINS = SUM_NCHAINS, + LoadMode MODE = LOAD_IDENTITY> +kernel void sum_reduction_inner( + constant TI* input [[buffer(0)]], + device TO* output [[buffer(1)]], + constant uint2& sizes [[buffer(2)]], // [M, N] + constant float& divisor [[buffer(3)]], // >0 divides accumulator before cast + uint tptg [[threads_per_threadgroup]], + uint tgid [[threadgroup_position_in_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simdgroup_id [[simdgroup_index_in_threadgroup]]) { + using TA = ::metal::conditional_t>; + const uint M = sizes.x; + const uint N = sizes.y; + const uint num_simd_groups = tptg / 32; + + // Each SIMD group handles a different row + uint row = tgid * num_simd_groups + simdgroup_id; + if (row >= M) + return; + + constant TI* row_ptr = input + row * N; + + metal::array acc; + for (uint j = 0; j < NCHAINS; j++) + acc[j] = 0; + + // Each of 32 lanes reads elements at stride 32, NCHAINS at a time. + // Align down to full blocks of stride = 32 * NCHAINS elements. + const uint stride = 32 * NCHAINS; + const uint aligned_N = (N / stride) * stride; + uint base = simd_lane_id * NCHAINS; + for (; base < aligned_N; base += stride) { + for (uint j = 0; j < NCHAINS; j++) { + acc[j] += load_val(row_ptr[base + j]); + } + } + // Tail: remaining elements after last full block, one per lane + for (uint i = aligned_N + simd_lane_id; i < N; i += 32) { + acc[0] += load_val(row_ptr[i]); + } + + TA sum = acc[0]; + for (uint j = 1; j < NCHAINS; j++) + sum += acc[j]; + + sum = c10::metal::simd_sum(sum); + + if (simd_lane_id == 0) { + if (divisor > 0) { + sum /= static_cast(divisor); + } + output[row] = static_cast(sum); + } +} + +#define REGISTER_SUM_INNER_IMPL(TI, TO, PREFIX, MODE) \ + template [[host_name(PREFIX "reduction_inner_" #TI "_" #TO)]] \ + kernel void sum_reduction_inner( \ + constant TI * input [[buffer(0)]], \ + device TO * output [[buffer(1)]], \ + constant uint2 & sizes [[buffer(2)]], \ + constant float& divisor [[buffer(3)]], \ + uint tptg [[threads_per_threadgroup]], \ + uint tgid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simdgroup_id [[simdgroup_index_in_threadgroup]]); + +#define REGISTER_SUM_INNER(TI, TO) \ + REGISTER_SUM_INNER_IMPL(TI, TO, "sum_", LOAD_IDENTITY) +#define REGISTER_NANSUM_INNER(TI, TO) \ + REGISTER_SUM_INNER_IMPL(TI, TO, "nansum_", LOAD_NAN_TO_ZERO) +#define REGISTER_COUNT_NONZERO_INNER(TI) \ + REGISTER_SUM_INNER_IMPL(TI, long, "count_nonzero_", LOAD_NONZERO) + +REGISTER_SUM_INNER(float, float); +REGISTER_SUM_INNER(half, half); +REGISTER_SUM_INNER(half, float); +REGISTER_SUM_INNER(bfloat, bfloat); +REGISTER_SUM_INNER(bfloat, float); +REGISTER_SUM_INNER(int, int); +REGISTER_SUM_INNER(int, long); +REGISTER_SUM_INNER(long, long); +REGISTER_SUM_INNER(short, short); +REGISTER_SUM_INNER(short, long); +REGISTER_SUM_INNER(char, char); +REGISTER_SUM_INNER(char, long); +REGISTER_SUM_INNER(uchar, uchar); +REGISTER_SUM_INNER(uchar, long); +REGISTER_SUM_INNER(bool, long); +REGISTER_SUM_INNER(bool, int); +REGISTER_SUM_INNER(float2, float2); +REGISTER_SUM_INNER(half2, half2); + +#define REGISTER_SUM_IMPL(TI, TO, PREFIX, MODE) \ + template [[host_name(PREFIX "reduction_" #TI "_" #TO)]] \ + kernel void sum_reduction( \ + constant TI * input [[buffer(0)]], \ + device TO * output [[buffer(1)]], \ + constant NormParams<> & params [[buffer(2)]], \ + uint tid [[thread_position_in_threadgroup]], \ + uint tptg [[threads_per_threadgroup]], \ + uint tgid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simdgroup_id [[simdgroup_index_in_threadgroup]], \ + uint simdgroup_size [[threads_per_simdgroup]]); + +#define REGISTER_SUM(TI, TO) REGISTER_SUM_IMPL(TI, TO, "sum_", LOAD_IDENTITY) +#define REGISTER_NANSUM(TI, TO) \ + REGISTER_SUM_IMPL(TI, TO, "nansum_", LOAD_NAN_TO_ZERO) +#define REGISTER_COUNT_NONZERO(TI) \ + REGISTER_SUM_IMPL(TI, long, "count_nonzero_", LOAD_NONZERO) + +REGISTER_SUM(float, float); +REGISTER_SUM(float, half); +REGISTER_SUM(float, bfloat); +REGISTER_SUM(half, half); +REGISTER_SUM(half, float); +REGISTER_SUM(bfloat, bfloat); +REGISTER_SUM(bfloat, float); +REGISTER_SUM(int, int); +REGISTER_SUM(int, long); +REGISTER_SUM(long, long); +REGISTER_SUM(short, short); +REGISTER_SUM(short, long); +REGISTER_SUM(char, char); +REGISTER_SUM(char, long); +REGISTER_SUM(uchar, uchar); +REGISTER_SUM(uchar, long); +REGISTER_SUM(bool, long); +REGISTER_SUM(bool, int); +REGISTER_SUM(float2, float2); +REGISTER_SUM(half2, half2); + +// nansum variants (floating-point only — integers can't have NaN) +REGISTER_NANSUM(float, float); +REGISTER_NANSUM(half, half); +REGISTER_NANSUM(half, float); +REGISTER_NANSUM(bfloat, bfloat); +REGISTER_NANSUM(bfloat, float); + +REGISTER_NANSUM_OUTER(float, float); +REGISTER_NANSUM_OUTER(half, half); +REGISTER_NANSUM_OUTER(half, float); +REGISTER_NANSUM_OUTER(bfloat, bfloat); +REGISTER_NANSUM_OUTER(bfloat, float); + +REGISTER_NANSUM_INNER(float, float); +REGISTER_NANSUM_INNER(half, half); +REGISTER_NANSUM_INNER(half, float); +REGISTER_NANSUM_INNER(bfloat, bfloat); +REGISTER_NANSUM_INNER(bfloat, float); + +// count_nonzero: output is always long; reuses sum-reduction machinery +// with LOAD_NONZERO mode (1 per nonzero element, 0 otherwise). +REGISTER_COUNT_NONZERO(float); +REGISTER_COUNT_NONZERO(half); +REGISTER_COUNT_NONZERO(bfloat); +REGISTER_COUNT_NONZERO(long); +REGISTER_COUNT_NONZERO(int); +REGISTER_COUNT_NONZERO(short); +REGISTER_COUNT_NONZERO(char); +REGISTER_COUNT_NONZERO(uchar); +REGISTER_COUNT_NONZERO(bool); +REGISTER_COUNT_NONZERO(float2); +REGISTER_COUNT_NONZERO(half2); + +REGISTER_COUNT_NONZERO_OUTER(float); +REGISTER_COUNT_NONZERO_OUTER(half); +REGISTER_COUNT_NONZERO_OUTER(bfloat); +REGISTER_COUNT_NONZERO_OUTER(long); +REGISTER_COUNT_NONZERO_OUTER(int); +REGISTER_COUNT_NONZERO_OUTER(short); +REGISTER_COUNT_NONZERO_OUTER(char); +REGISTER_COUNT_NONZERO_OUTER(uchar); +REGISTER_COUNT_NONZERO_OUTER(bool); +REGISTER_COUNT_NONZERO_OUTER(float2); +REGISTER_COUNT_NONZERO_OUTER(half2); + +REGISTER_COUNT_NONZERO_INNER(float); +REGISTER_COUNT_NONZERO_INNER(half); +REGISTER_COUNT_NONZERO_INNER(bfloat); +REGISTER_COUNT_NONZERO_INNER(long); +REGISTER_COUNT_NONZERO_INNER(int); +REGISTER_COUNT_NONZERO_INNER(short); +REGISTER_COUNT_NONZERO_INNER(char); +REGISTER_COUNT_NONZERO_INNER(uchar); +REGISTER_COUNT_NONZERO_INNER(bool); +REGISTER_COUNT_NONZERO_INNER(float2); +REGISTER_COUNT_NONZERO_INNER(half2); diff --git a/aten/src/ATen/native/mps/kernels/SamplingHelpers.h b/aten/src/ATen/native/mps/kernels/SamplingHelpers.h new file mode 100644 index 0000000000000..2fd06dfc26247 --- /dev/null +++ b/aten/src/ATen/native/mps/kernels/SamplingHelpers.h @@ -0,0 +1,42 @@ +#pragma once + +// Shared cubic interpolation helpers used by both GridSampler and UpSample +// kernels. Based on the Catmull-Rom spline with A=-0.75. +// See +// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm + +template +accscalar_t cubic_convolution1(accscalar_t x, accscalar_t A) { + return ((A + 2) * x - (A + 3)) * x * x + 1; +} + +template +accscalar_t cubic_convolution2(accscalar_t x, accscalar_t A) { + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +} + +template +void get_cubic_coefficients(accscalar_t coeffs[4], accscalar_t t) { + accscalar_t A = -0.75; + + accscalar_t x1 = t; + coeffs[0] = cubic_convolution2(x1 + 1.0, A); + coeffs[1] = cubic_convolution1(x1, A); + + accscalar_t x2 = 1.0 - t; + coeffs[2] = cubic_convolution1(x2, A); + coeffs[3] = cubic_convolution2(x2 + 1.0, A); +} + +template +accscalar_t cubic_interp1d( + scalar_t x0, + scalar_t x1, + scalar_t x2, + scalar_t x3, + accscalar_t t) { + accscalar_t coeffs[4]; + get_cubic_coefficients(coeffs, t); + + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; +} diff --git a/aten/src/ATen/native/mps/kernels/ScanKernel.metal b/aten/src/ATen/native/mps/kernels/ScanKernel.metal index de493af7aaa05..1fde656f46e69 100644 --- a/aten/src/ATen/native/mps/kernels/ScanKernel.metal +++ b/aten/src/ATen/native/mps/kernels/ScanKernel.metal @@ -3,26 +3,16 @@ using namespace metal; #include +#include #include -using c10::metal::accum_t; +using namespace c10::metal; struct LogAddExp { template T operator()(T x, T y) { - // Reference: - // https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp - T min_val = c10::metal::min(x, y); - T max_val = c10::metal::max(x, y); - - if (min_val != max_val || metal::isfinite(min_val)) { - // nan will be propagated here - return c10::metal::log1p(metal::exp(min_val - max_val)) + max_val; - } else { - // special case to correctly handle infinite cases - return x; - } - }; + return c10::metal::logaddexp(x, y); + } }; C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size; @@ -123,6 +113,29 @@ struct LogCumSumExpOp { } }; +template > +struct CumProdOp { + static constexpr constant acc_t init = + is_complex_v ? acc_t(1, 0) : static_cast(1); + + acc_t operator()(acc_t a, acc_t b) { + return c10::metal::mul(a, b); + } + + acc_t simd_scan(acc_t x) { + for (int i = 1; i <= 16; i *= 2) { + acc_t other = simd_shuffle_and_fill_up(x, init, i); + x = this->operator()(x, other); + } + return x; + } + + acc_t simd_exclusive_scan(acc_t x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + // Pair structure to hold value and index for cummin/cummax operations template > struct ValueIndexPair { @@ -765,6 +778,11 @@ kernel void scan_with_indices_outer_dim( REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, float, 4); REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, half, 4); REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, bfloat, 4); +REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, float2, 2); +REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, half2, 4); + +REGISTER_SCAN_OP(cumprod, CumProdOp, float2, 2); +REGISTER_SCAN_OP(cumprod, CumProdOp, half2, 4); // Scan with indices operations for cummin/cummax REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, float, 4); diff --git a/aten/src/ATen/native/mps/kernels/UpSample.metal b/aten/src/ATen/native/mps/kernels/UpSample.metal index fa9b5a1bb107d..cd4be4d0fc5bc 100644 --- a/aten/src/ATen/native/mps/kernels/UpSample.metal +++ b/aten/src/ATen/native/mps/kernels/UpSample.metal @@ -1,3 +1,4 @@ +#include #include #include #include @@ -5,45 +6,6 @@ using namespace metal; using namespace c10::metal; -// Based on -// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm -template -accscalar_t cubic_convolution1(accscalar_t x, accscalar_t A) { - return ((A + 2) * x - (A + 3)) * x * x + 1; -} - -template -accscalar_t cubic_convolution2(accscalar_t x, accscalar_t A) { - return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; -} - -template -void get_cubic_upsampling_coefficients(accscalar_t coeffs[4], accscalar_t t) { - accscalar_t A = -0.75; - - accscalar_t x1 = t; - coeffs[0] = cubic_convolution2(x1 + 1.0, A); - coeffs[1] = cubic_convolution1(x1, A); - - // opposite coefficients - accscalar_t x2 = 1.0 - t; - coeffs[2] = cubic_convolution1(x2, A); - coeffs[3] = cubic_convolution2(x2 + 1.0, A); -} - -template -accscalar_t cubic_interp1d( - scalar_t x0, - scalar_t x1, - scalar_t x2, - scalar_t x3, - accscalar_t t) { - accscalar_t coeffs[4]; - get_cubic_upsampling_coefficients(coeffs, t); - - return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; -} - template accscalar_t area_pixel_compute_source_index( accscalar_t scale, @@ -720,8 +682,8 @@ kernel void upsample_bicubic2d_backward( float x_coeffs[4]; float y_coeffs[4]; - get_cubic_upsampling_coefficients(x_coeffs, t_x); - get_cubic_upsampling_coefficients(y_coeffs, t_y); + get_cubic_coefficients(x_coeffs, t_x); + get_cubic_coefficients(y_coeffs, t_y); for (int n = 0; n < output_sizes.x; n++) { for (int c = 0; c < output_sizes.y; ++c) { diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm index 802c648c888d5..205b6bb3872c2 100644 --- a/aten/src/ATen/native/mps/operations/Activation.mm +++ b/aten/src/ATen/native/mps/operations/Activation.mm @@ -22,8 +22,6 @@ #include #include #include -#include -#include #include #include #include @@ -35,86 +33,6 @@ namespace at::native { -Tensor relu_mps(const Tensor& self) { - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - - bool executeGatherOp = - !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) || - self.is_contiguous(MemoryFormat::ChannelsLast3d)); - Tensor output = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve); - - if (output.numel() == 0) { - return output; - } - - MPSStream* stream = getCurrentMPSStream(); - @autoreleasepool { - std::string key = "relu" + getTensorsStringKey({self}); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - // passing selector of reLUWithTensor on the mpsGraph object - MPSGraphTensor* outputTensor = [mpsGraph reLUWithTensor:inputTensor name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, nil, false); - - // Create dictionary of inputs and outputs - auto feeds = dictionaryFromPlaceholders(selfPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - } - - return output; -} - -Tensor& relu_mps_(Tensor& self) { - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - - if (self.numel() == 0) { - return self; - } - // Inplace relu - Tensor& output = self; - bool executeGatherOp = - !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) || - self.is_contiguous(MemoryFormat::ChannelsLast3d)); - Tensor out; - if (executeGatherOp) { - out = at::empty_like(self, MemoryFormat::Contiguous); - } - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "relu_" + getTensorsStringKey({self}); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - // passing selector of reLUWithTensor on the mpsGraph object - MPSGraphTensor* outputTensor = [mpsGraph reLUWithTensor:inputTensor name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, executeGatherOp ? out : output, nil, false); - - // Create dictionary of inputs and outputs - auto feeds = dictionaryFromPlaceholders(selfPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - if (executeGatherOp) { - output.copy_(out); - } - } - - return output; -} - TORCH_IMPL_FUNC(log_softmax_mps_out) (const Tensor& self, const int64_t dim, const bool half_to_float, const Tensor& out) { TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); @@ -1216,112 +1134,6 @@ Tensor prelu_mps(const Tensor& self, const Tensor& weight_) { return std::tuple{grad_input, weight_grad}; } -TORCH_IMPL_FUNC(silu_out_mps)(const Tensor& self, const Tensor& result) { - using namespace mps; - using CachedGraph = MPSUnaryCachedGraph; - - TORCH_CHECK(self.is_mps()); - TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "MPS doesn't know how to do exponent_i64"); - - // Empty output - if (result.numel() == 0) - return; - - MPSStream* stream = getCurrentMPSStream(); - - bool executeGatherOp = - !(self.is_contiguous(MemoryFormat::Contiguous) || self.is_contiguous(MemoryFormat::ChannelsLast) || - self.is_contiguous(MemoryFormat::ChannelsLast3d)); - Tensor result_ = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve); - - @autoreleasepool { - std::string key = "silu_out_mps:" + getTensorsStringKey({self}); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - - MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:getMPSDataType(self)]; - MPSGraphTensor* negativeInput = [mpsGraph negativeWithTensor:inputTensor name:nil]; - MPSGraphTensor* expNegativeTensor = [mpsGraph exponentWithTensor:negativeInput name:nil]; - MPSGraphTensor* expPlusOneTensor = [mpsGraph additionWithPrimaryTensor:expNegativeTensor - secondaryTensor:unitTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph divisionWithPrimaryTensor:inputTensor - secondaryTensor:expPlusOneTensor - name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp); - Placeholder outputPlaceholder = - Placeholder(cachedGraph->outputTensor_, executeGatherOp ? result_ : result, nil, false); - - auto feeds = dictionaryFromPlaceholders(selfPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - } - if (executeGatherOp) { - result.copy_(result_); - } -} - -TORCH_IMPL_FUNC(silu_backward_out_mps) -(const Tensor& grad_output, const Tensor& self, const Tensor& grad_input) { - using namespace mps; - using CachedGraph = MPSUnaryGradCachedGraph; - TORCH_CHECK(grad_output.is_mps()); - - // Empty output - if (grad_input.numel() == 0) - return; - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "silu_out_backward_mps:" + getTensorsStringKey({grad_output}); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - MPSGraphTensor* gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output); - - MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:getMPSDataType(grad_output)]; - MPSGraphTensor* negativeInput = [mpsGraph negativeWithTensor:inputTensor name:nil]; - MPSGraphTensor* expNegativeTensor = [mpsGraph exponentWithTensor:negativeInput name:nil]; - MPSGraphTensor* expPlusOneTensor = [mpsGraph additionWithPrimaryTensor:expNegativeTensor - secondaryTensor:unitTensor - name:nil]; - MPSGraphTensor* sigmoidTensor = [mpsGraph reciprocalWithTensor:expPlusOneTensor name:nil]; - MPSGraphTensor* oneMinusSigmoid = [mpsGraph subtractionWithPrimaryTensor:unitTensor - secondaryTensor:sigmoidTensor - name:nil]; - MPSGraphTensor* inputTimesDiff = [mpsGraph multiplicationWithPrimaryTensor:inputTensor - secondaryTensor:oneMinusSigmoid - name:nil]; - MPSGraphTensor* onePlusTensor = [mpsGraph additionWithPrimaryTensor:unitTensor - secondaryTensor:inputTimesDiff - name:nil]; - MPSGraphTensor* gradTensor = [mpsGraph multiplicationWithPrimaryTensor:sigmoidTensor - secondaryTensor:onePlusTensor - name:nil]; - MPSGraphTensor* gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradTensor - secondaryTensor:gradOutputTensor - name:nil]; - - newCachedGraph->gradOutputTensor_ = gradOutputTensor; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->gradInputTensor_ = gradInputTensor; - }); - - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output); - Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input); - - auto feeds = dictionaryFromPlaceholders(selfPlaceholder, gradOutputPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, gradInputPlaceholder); - } -} - // ------------------------------------------------- // Hardtanh backward diff --git a/aten/src/ATen/native/mps/operations/ActivationKernel.mm b/aten/src/ATen/native/mps/operations/ActivationKernel.mm index f6d3ad986ade0..2138b0ddb6a96 100644 --- a/aten/src/ATen/native/mps/operations/ActivationKernel.mm +++ b/aten/src/ATen/native/mps/operations/ActivationKernel.mm @@ -4,6 +4,18 @@ #include #include #include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#include +#include +#include +#include +#endif #include #include @@ -15,6 +27,25 @@ #include #endif +Tensor relu_mps(const Tensor& self) { + TORCH_CHECK(!self.is_complex(), "relu is not supported for complex types"); + auto output = at::empty_like(self); + if (output.numel() == 0) + return output; + auto iter = at::TensorIteratorConfig().add_output(output).add_input(self).build(); + lib.exec_unary_kernel(iter, "relu", /*alpha=*/std::nullopt, /*scalar_arg_type=*/std::nullopt, /*supports_vec4=*/true); + return output; +} + +Tensor& relu_mps_(Tensor& self) { + TORCH_CHECK(!self.is_complex(), "relu is not supported for complex types"); + if (self.numel() == 0) + return self; + auto iter = at::TensorIteratorConfig().add_output(self).add_input(self).set_check_mem_overlap(false).build(); + lib.exec_unary_kernel(iter, "relu", /*alpha=*/std::nullopt, /*scalar_arg_type=*/std::nullopt, /*supports_vec4=*/true); + return self; +} + static void hardshrink_kernel(TensorIteratorBase& iter, const Scalar& lambda = 0.5) { lib.exec_unary_kernel(iter, "hardshrink", lambda); } @@ -67,6 +98,30 @@ static void elu_backward_kernel(TensorIteratorBase& iter, }); } +static void silu_kernel(TensorIteratorBase& iter) { + if (isComplexType(iter.common_dtype())) { + auto out = iter.output(0); + auto self = iter.input(0); + at::mul_out(out, self, at::sigmoid(self)); + return; + } + lib.exec_unary_kernel(iter, "silu", /*alpha=*/std::nullopt, /*scalar_arg_type=*/std::nullopt, /*supports_vec4=*/true); +} + +static void silu_backward_kernel(TensorIteratorBase& iter) { + if (isComplexType(iter.common_dtype())) { + auto grad_input = iter.output(0); + auto grad_output = iter.input(0); + auto self = iter.input(1); + auto sig = at::sigmoid(self); + auto one_minus_sig = at::rsub(sig, 1); + auto inner = at::add(at::mul(self, one_minus_sig), 1); + grad_input.copy_(at::mul(grad_output, at::mul(sig, inner))); + return; + } + lib.exec_binary_kernel(iter, "silu_backward"); +} + static void leaky_relu_kernel(TensorIteratorBase& iter, const Scalar& negative_slope) { lib.exec_unary_kernel(iter, "leaky_relu", negative_slope); } @@ -86,5 +141,7 @@ static void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& n REGISTER_DISPATCH(elu_backward_stub, elu_backward_kernel); REGISTER_DISPATCH(leaky_relu_stub, leaky_relu_kernel); REGISTER_DISPATCH(leaky_relu_backward_stub, leaky_relu_backward_kernel); +REGISTER_DISPATCH(silu_stub, silu_kernel); +REGISTER_DISPATCH(silu_backward_stub, silu_backward_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Amp.mm b/aten/src/ATen/native/mps/operations/Amp.mm index e410d434ec7ad..2b1e11575cd50 100644 --- a/aten/src/ATen/native/mps/operations/Amp.mm +++ b/aten/src/ATen/native/mps/operations/Amp.mm @@ -29,7 +29,6 @@ static void _amp_non_finite_check_and_unscale_mps_single_impl(const Tensor& scal } TORCH_CHECK(scaled_grad.is_mps(), "Tensor is not on the MPS device."); TORCH_CHECK(scaled_grad.numel() <= std::numeric_limits::max(), "scaled_grad is too large"); - float inv_scale_val = inv_scale.item(); auto stream = getCurrentMPSStream(); auto device = MPSDevice::getInstance()->device(); auto ampPipelineState = @@ -43,7 +42,7 @@ static void _amp_non_finite_check_and_unscale_mps_single_impl(const Tensor& scal dispatch_sync_with_rethrow(stream->queue(), ^() { auto computeEncoder = stream->commandEncoder(); [computeEncoder setComputePipelineState:ampPipelineState]; - mtl_setArgs(computeEncoder, scaled_grad, found_inf, inv_scale_val); + mtl_setArgs(computeEncoder, scaled_grad, found_inf, inv_scale); [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize]; }); } diff --git a/aten/src/ATen/native/mps/operations/Attention.mm b/aten/src/ATen/native/mps/operations/Attention.mm index 90600cd8a7b85..696094d800320 100644 --- a/aten/src/ATen/native/mps/operations/Attention.mm +++ b/aten/src/ATen/native/mps/operations/Attention.mm @@ -439,13 +439,42 @@ } std::tuple _scaled_dot_product_attention_math_mps(const Tensor& query, - const Tensor& key, - const Tensor& value, + const Tensor& key_, + const Tensor& value_, const std::optional& attn_mask, double dropout_p, bool is_causal, const std::optional& dropout_mask, - std::optional scale) { + std::optional scale, + bool enable_gqa) { + TORCH_CHECK_NOT_IMPLEMENTED(c10::isFloatingType(query.scalar_type()), + "scaled_dot_product_attention for MPS does not support dtype ", + query.scalar_type()); + TORCH_CHECK_NOT_IMPLEMENTED(c10::isFloatingType(key_.scalar_type()), + "scaled_dot_product_attention for MPS does not support dtype ", + key_.scalar_type()); + TORCH_CHECK_NOT_IMPLEMENTED(c10::isFloatingType(value_.scalar_type()), + "scaled_dot_product_attention for MPS does not support dtype ", + value_.scalar_type()); + const auto any_nested = query.is_nested() || key_.is_nested() || value_.is_nested(); + const auto all_contiguous = + query.is_contiguous_or_false() && key_.is_contiguous_or_false() && value_.is_contiguous_or_false(); + auto key = key_; + auto value = value_; + if (enable_gqa) { + int64_t q_heads = query.size(-3); + int64_t k_heads = key_.size(-3); + int64_t repeat_factor = q_heads / k_heads; + + if (repeat_factor > 1) { + TORCH_CHECK(q_heads % k_heads == 0, + "For GQA, the query tensor's head dimension (" + std::to_string(q_heads) + + ") must be divisible by the key tensor's head dimension (" + std::to_string(k_heads) + ")."); + key = key_.repeat_interleave(repeat_factor, /*dim=*/-3); + value = value_.repeat_interleave(repeat_factor, /*dim=*/-3); + } + } + auto query_tuple = ensure_4d(query); Tensor q_ = std::get<0>(query_tuple); bool unsqueezed = std::get<1>(query_tuple); diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm index 03ecee79449c0..9f78ab11dfab1 100644 --- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm +++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm @@ -109,6 +109,10 @@ static void logaddexp2_mps_kernel(TensorIteratorBase& iter) { lib.exec_binary_kernel(iter, "logaddexp2"); } +static void xlogy_mps_kernel(TensorIteratorBase& iter) { + lib.exec_binary_kernel(iter, "xlogy"); +} + static void xlog1py_mps_kernel(TensorIteratorBase& iter) { TORCH_CHECK_TYPE(isFloatingType(iter.common_dtype()), "xlog1py_mps not implemented for non-floating types"); lib.exec_binary_kernel(iter, "xlog1py"); @@ -186,6 +190,68 @@ static void lerp_scalar_mps_kernel(at::TensorIteratorBase& iter, const Scalar& w lib.exec_binary_kernel(iter, "lerp_alpha", weight); } +static void lerp_tensor_mps_kernel(at::TensorIteratorBase& iter) { + using namespace mps; + auto type_str = scalarToMetalTypeString(iter.common_dtype()); + auto numel = static_cast(iter.numel()); + auto ndim = static_cast(iter.ndim()); + + // simple elementwise kernel for dense tensors + if (iter.is_contiguous()) { + auto pso = lib.getPipelineStateForFunc("lerp_tensor_dense_" + type_str); + dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ^() { + auto computeEncoder = getCurrentMPSStream()->commandEncoder(); + [computeEncoder setComputePipelineState:pso]; + bind_iter_tensors(computeEncoder, iter); + mtl_dispatch1DJob(computeEncoder, pso, numel); + }); + return; + } + + // Scalar weight broadcast path + if (ndim == 1 && iter.strides(3)[0] == 0) { + auto pso = lib.getPipelineStateForFunc("lerp_tensor_scalar_weight_" + type_str); + dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ^() { + auto computeEncoder = getCurrentMPSStream()->commandEncoder(); + [computeEncoder setComputePipelineState:pso]; + bind_iter_tensors(computeEncoder, iter); + mtl_dispatch1DJob(computeEncoder, pso, numel); + }); + return; + } + + // 2D/3D: multi-dimensional dispatch, to avoid integer division for coordinates + if (ndim >= 2 && ndim <= 3) { + auto pso = lib.getPipelineStateForFunc(fmt::format("lerp_tensor_strided_{}d_{}", ndim, type_str)); + dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ^() { + auto computeEncoder = getCurrentMPSStream()->commandEncoder(); + [computeEncoder setComputePipelineState:pso]; + bind_iter_tensors(computeEncoder, iter); + mtl_setArgs<4>(computeEncoder, iter.strides(0), iter.strides(1), iter.strides(2), iter.strides(3)); + auto sizes = iter.shape(); + auto maxTg = [pso maxTotalThreadsPerThreadgroup]; + auto tg_x = std::min(static_cast(sizes[0]), maxTg); + auto tg_y = std::min(static_cast(sizes[1]), maxTg / tg_x); + auto grid_z = ndim > 2 ? static_cast(sizes[2]) : 1; + auto tg_z = std::min(grid_z, std::max(maxTg / (tg_x * tg_y), (NSUInteger)1)); + [computeEncoder dispatchThreads:MTLSizeMake(sizes[0], sizes[1], grid_z) + threadsPerThreadgroup:MTLSizeMake(tg_x, tg_y, tg_z)]; + }); + return; + } + + // General strided fallback + auto pso = lib.getPipelineStateForFunc("lerp_tensor_strided_" + type_str); + dispatch_sync_with_rethrow(getCurrentMPSStream()->queue(), ^() { + auto computeEncoder = getCurrentMPSStream()->commandEncoder(); + [computeEncoder setComputePipelineState:pso]; + bind_iter_tensors(computeEncoder, iter); + mtl_setArgs<4>( + computeEncoder, iter.shape(), iter.strides(0), iter.strides(1), iter.strides(2), iter.strides(3), ndim); + mtl_dispatch1DJob(computeEncoder, pso, numel); + }); +} + static void native_dropout_mask_and_scale_mps_kernel(at::TensorIteratorBase& iter, const Scalar& scale) { lib.exec_binary_kernel(iter, "native_dropout_mask_and_scale", scale); } @@ -242,6 +308,7 @@ static void gcd_mps_kernel(TensorIteratorBase& iter) { REGISTER_DISPATCH(zeta_stub, &zeta_mps_kernel) REGISTER_DISPATCH(logaddexp_stub, &logaddexp_mps_kernel); REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_mps_kernel); +REGISTER_DISPATCH(xlogy_stub, &xlogy_mps_kernel) REGISTER_DISPATCH(xlog1py_stub, &xlog1py_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_t_stub, &chebyshev_polynomial_t_mps_kernel) REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_mps_kernel) @@ -256,6 +323,7 @@ static void gcd_mps_kernel(TensorIteratorBase& iter) { REGISTER_DISPATCH(polar_stub, &polar_mps_kernel); REGISTER_DISPATCH(complex_stub, &complex_mps_kernel); REGISTER_DISPATCH(lerp_kernel_scalar_weight, &lerp_scalar_mps_kernel) +REGISTER_DISPATCH(lerp_kernel_tensor_weight, &lerp_tensor_mps_kernel) REGISTER_DISPATCH(mul_stub, &mul_mps_kernel) REGISTER_DISPATCH(div_true_stub, &div_true_mps_kernel) REGISTER_DISPATCH(div_floor_stub, &div_floor_mps_kernel) diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm index cfbb194916b83..e56dd92679507 100644 --- a/aten/src/ATen/native/mps/operations/BinaryOps.mm +++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm @@ -20,15 +20,12 @@ #include #include #include -#include -#include #include #include #include #include #include #include -#include #endif namespace at::native { @@ -243,8 +240,6 @@ static void add_sub_lerp_template(const Tensor& self, CREATE_MPS_STRUCTURED_BOOLEAN_OP_FUNC(gt_tensor_out_mps, greaterThan, Tensor); // Arithmetic Binary Ops -CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(minimum_out_mps, minimumWithNaNPropagationAndIntFallback, Tensor); -CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(maximum_out_mps, maximumWithNaNPropagationAndIntFallback, Tensor); CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(pow_tensor_tensor_out_mps, power, Tensor); CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_and_out_mps, logicalAND, Tensor); CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_or_out_mps, logicalOR, Tensor); @@ -266,29 +261,4 @@ static void add_sub_lerp_template(const Tensor& self, } } -TORCH_IMPL_FUNC(xlogy_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) { - mps::BinaryOpBlock xlogy_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) { - MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 shape:@[ @1 ] dataType:primaryCastTensor.dataType]; - MPSGraphTensor* yIsNaNPredicateTensor = [mpsGraph isNaNWithTensor:secondaryCastTensor name:nil]; - MPSGraphTensor* logyTensor = [mpsGraph logarithmWithTensor:secondaryCastTensor name:nil]; - MPSGraphTensor* xlogyTensor = [mpsGraph multiplicationWithPrimaryTensor:primaryCastTensor - secondaryTensor:logyTensor - name:nil]; - MPSGraphTensor* xEqualZeroPredicateTensor = [mpsGraph equalWithPrimaryTensor:primaryCastTensor - secondaryTensor:zeroTensor - name:nil]; - MPSGraphTensor* outputTensor = [mpsGraph selectWithPredicateTensor:xEqualZeroPredicateTensor - truePredicateTensor:zeroTensor - falsePredicateTensor:xlogyTensor - name:nil]; - outputTensor = [mpsGraph selectWithPredicateTensor:yIsNaNPredicateTensor - truePredicateTensor:secondaryCastTensor - falsePredicateTensor:outputTensor - name:nil]; - return outputTensor; - }; - mps::binaryOpTensor(self, other, output, "xlogy_out_mps", xlogy_op_block); -} - } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm index e36ac4dc45246..c0d3b795ded9f 100644 --- a/aten/src/ATen/native/mps/operations/ConstantOps.mm +++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm @@ -1,120 +1,89 @@ // Copyright © 2022 Apple Inc. #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include #include #else -#include -#include -#include +#include #endif namespace at::native { -static Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) { - using namespace mps; +#ifndef PYTORCH_JIT_COMPILE_SHADERS +static auto& lib = mps::MetalShaderLibrary::getBundledLibrary(); +#else +#include +#endif - if (self.numel() == 0) { - return self; - } - Tensor output = self; - bool needsCopyToOutput = false; - if (needsGather(self)) { - output = at::empty(self.sizes(), self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt); - needsCopyToOutput = true; +static void fill_mps_kernel(TensorIterator& iter, const Scalar& value) { + using namespace mps; + if (iter.numel() == 0) { + return; } - struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* inputTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; - }; - - @autoreleasepool { - std::string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + std::to_string(value.toDouble()); - - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type())); - MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil]; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); - - auto mpsScalar = getMPSScalar(value, self.scalar_type()); - auto mpsScalarData = getMPSGraphTensorFromScalar(getCurrentMPSStream(), mpsScalar); - NSDictionary* feeds = @{cachedGraph->inputTensor_ : mpsScalarData}; - - Placeholder outputPlaceholder = - Placeholder(cachedGraph->outputTensor_, needsCopyToOutput ? output : self, nullptr, !needsCopyToOutput); - - NSDictionary* results = - @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()}; - - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results); - - if (needsCopyToOutput) { - self.copy_(output); + // Metal compute kernels use uint (32-bit) thread indices; decompose large + // tensors into chunks that fit in 32-bit indexing. + if (!iter.can_use_32bit_indexing()) { + for (auto&& sub_iter : iter.with_32bit_indexing()) { + fill_mps_kernel(sub_iter, value); } + return; } - return self; -} - -static Tensor& fill_mps_tensor_(Tensor& self, uint8_t value) { - TORCH_INTERNAL_ASSERT(self.is_contiguous()); + const Tensor& self = iter.tensor(0); + const auto dtype = self.scalar_type(); const auto stream = getCurrentMPSStream(); - auto storage_byte_offset = self.storage_offset() * self.itemsize(); - stream->fill(mps::getMTLBufferStorage(self), value, self.nbytes(), storage_byte_offset); - return self; -} - -Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) { - if (isComplexType(self.scalar_type())) { - auto self_as_real = at::view_as_real(self); - auto self_as_real_real = self_as_real.select(self.dim(), 0); - auto self_as_real_imag = self_as_real.select(self.dim(), 1); - if (value.isComplex()) { - auto value_cdouble = value.to>(); - fill_scalar_mps_impl(self_as_real_real, value_cdouble.real()); - fill_scalar_mps_impl(self_as_real_imag, value_cdouble.imag()); - return self; - } - fill_scalar_mps_impl(self_as_real_real, value); - fill_scalar_mps_impl(self_as_real_imag, 0.0f); - return self; - } - // check if it's possible to use fillBuffer() to fill the Tensor's storage - if (self.is_contiguous()) { - if (value.toDouble() == 0.0) { - return fill_mps_tensor_(self, 0); - } - if (self.scalar_type() == kBool) { - return fill_mps_tensor_(self, value.toBool()); - } - if (self.scalar_type() == kByte) { - return fill_mps_tensor_(self, value.toByte()); - } - if (self.scalar_type() == kChar) { - return fill_mps_tensor_(self, value.toChar()); - } + const auto type_str = scalarToMetalTypeString(dtype); + const bool can_fill_linearly = self.is_non_overlapping_and_dense(); + + // For tensors with gaps or overlaps (e.g. stride-2 slices) use a 2D strided + // kernel: tid.y indexes dim 0 directly (no division), tid.x is the linear + // index for the remaining dims. Consecutive threads in x write consecutive + // addresses in the innermost dimension, giving coalesced writes. + if (!can_fill_linearly) { + auto fillPSO = lib.getPipelineStateForFunc(fmt::format("fill_scalar_strided_{}", type_str)); + const int64_t dim0_size = iter.ndim() > 0 ? iter.shape()[0] : 1; + const int64_t inner_numel = iter.numel() / dim0_size; + const uint32_t ndim = static_cast(iter.ndim()); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = stream->commandEncoder(); + auto mpsScalar = getMPSScalar(value, dtype); + [computeEncoder setComputePipelineState:fillPSO]; + bind_iter_tensors(computeEncoder, iter); + mtl_setArgs<1>(computeEncoder, mpsScalar, iter.shape(), iter.strides(0), ndim); + const NSUInteger maxTG = fillPSO.maxTotalThreadsPerThreadgroup; + const MTLSize tgSize = MTLSizeMake(std::min(maxTG, (NSUInteger)inner_numel), 1, 1); + const MTLSize gridSize = MTLSizeMake(inner_numel, dim0_size, 1); + [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:tgSize]; + } + }); + return; } - return fill_scalar_mps_impl(self, value); -} - -Tensor& fill_tensor_mps_(Tensor& self, const Tensor& value) { - TORCH_CHECK(value.dim() == 0, - "fill_ only supports 0-dimension value tensor but got tensor with ", - value.dim(), - " dimensions."); - Scalar scalar_value = value.item(); - return fill_scalar_mps(self, scalar_value); + // Single-byte dtypes (bool, uint8, int8) use vec4 kernels that fill + // 4 elements per thread; all others fill 1 element per thread. + const bool is_byte_type = self.element_size() == 1; + const uint32_t numel = static_cast(iter.numel()); + const int64_t threads = is_byte_type ? (numel + 3) / 4 : numel; + + auto fillPSO = lib.getPipelineStateForFunc(fmt::format("fill_scalar_dense_{}", type_str)); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = stream->commandEncoder(); + auto mpsScalar = getMPSScalar(value, dtype); + [computeEncoder setComputePipelineState:fillPSO]; + bind_iter_tensors(computeEncoder, iter); + mtl_setArgs<1>(computeEncoder, mpsScalar, numel); + mtl_dispatch1DJob(computeEncoder, fillPSO, threads); + } + }); } -Tensor& zero_mps_(Tensor& self) { - return fill_scalar_mps(self, 0.0f); -} +REGISTER_DISPATCH(fill_stub, &fill_mps_kernel); } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm index a457267a9d850..cf46afdd92b9d 100644 --- a/aten/src/ATen/native/mps/operations/Convolution.mm +++ b/aten/src/ATen/native/mps/operations/Convolution.mm @@ -114,7 +114,10 @@ static Tensor _mps_convolution_impl(const Tensor& input_t, const bool is_macos_15_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); const bool is3DConv = input_t.dim() == 5; - const auto memory_format = input_t.suggest_memory_format(); + // Use exact-match: a channel-slice of a channels-last tensor has CL-like + // strides but is not NHWC-packed, so the raw-buffer NHWC path would misread + // it. See https://github.com/pytorch/pytorch/issues/180984 + const auto memory_format = input_t.suggest_memory_format(/*channels_last_strides_exact_match=*/true); const auto input_suggested_layout = memory_format == kChannelsLast && is_macos_15_plus ? kChannelsLast : kContiguous; const bool is_channels_last = mps_conv_use_channels_last(input_t, weight_t) && !is3DConv; const bool bias_defined = bias_opt ? bias_opt->defined() : false; diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm index a3cba05c975cf..fedf6d6cfe837 100644 --- a/aten/src/ATen/native/mps/operations/Copy.mm +++ b/aten/src/ATen/native/mps/operations/Copy.mm @@ -333,6 +333,7 @@ void copy_blit_mps(void* dst, const void* src, size_t size) { } // namespace mps Tensor _copy_from_and_resize_mps(const at::Tensor& self, const at::Tensor& dst) { + const_cast(dst).resize_as_(self); return mps::mps_copy_(const_cast(dst), self, false); } diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index d1d7bd60f7a16..66747cf79af30 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -11,9 +11,13 @@ #include #include #else +#include +#include #include #include #include +#include +#include #include #include #include @@ -21,9 +25,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -419,38 +425,13 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, std::optional gen) { - TORCH_CHECK(lambda > 0.0, "exponential_ expects lambda > 0.0, but found lambda=", lambda); - - mps::RandomOpBlock random_op_block = ^RandomOpFn(cachedGraph, randomTensor) { - MPSGraph* mpsGraph = cachedGraph->graph(); - MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0f dataType:randomTensor.dataType]; - MPSGraphTensor* minusLambdaTensor = [mpsGraph constantWithScalar:-lambda dataType:randomTensor.dataType]; - MPSGraphTensor* subtractTensor = [mpsGraph subtractionWithPrimaryTensor:unitTensor - secondaryTensor:randomTensor - name:nil]; - MPSGraphTensor* logTensor = [mpsGraph logarithmWithTensor:subtractTensor name:nil]; - return [mpsGraph divisionWithPrimaryTensor:logTensor secondaryTensor:minusLambdaTensor name:nil]; - }; - auto eps = std::numeric_limits::epsilon(); - return mps::random_mps_impl(self, - eps, - 1.0, - std::nullopt, - std::nullopt, - MPSGraphRandomDistributionUniform, - gen, - "exponential_mps_:" + std::to_string(lambda), - random_op_block); -} - static Tensor& distribution_kernel_mps_impl(Tensor& self, double param1, double param2, const std::string& kernel_name, int64_t randoms_per_element, - std::optional gen) { + std::optional gen, + int64_t elements_per_thread = 1) { if (self.numel() == 0) { return self; } @@ -483,7 +464,13 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, std::optional{static_cast(param1), static_cast(param2)}, std::array{seed, base_offset}); - mtl_dispatch1DJob(computeEncoder, pso, output.numel()); + if (elements_per_thread > 1) { + auto numel = static_cast(output.numel()); + mtl_setBytes(computeEncoder, numel, 3); + mtl_dispatch1DJob(computeEncoder, pso, (numel + elements_per_thread - 1) / elements_per_thread); + } else { + mtl_dispatch1DJob(computeEncoder, pso, output.numel()); + } } }); } @@ -495,6 +482,11 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, std::optional gen) { + TORCH_CHECK(lambda > 0.0, "exponential_ expects lambda > 0.0, but found lambda=", lambda); + return distribution_kernel_mps_impl(self, lambda, 0.0, "exponential", 1, gen, /*elements_per_thread=*/4); +} + Tensor& cauchy_mps_(Tensor& self, double median, double sigma, std::optional gen) { TORCH_CHECK(sigma > 0.0, "cauchy_ expects sigma > 0.0, but found sigma=", sigma); return distribution_kernel_mps_impl(self, median, sigma, "cauchy", 1, gen); @@ -511,6 +503,74 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, std::optional gen) { + if (alpha.numel() == 0) { + return at::empty_like(alpha); + } + + using namespace mps; + + auto mps_gen = get_generator_or_default(gen, at::mps::detail::getDefaultMPSGenerator()); + auto stream = getCurrentMPSStream(); + Tensor ret = at::empty_like(alpha, alpha.options(), at::MemoryFormat::Contiguous); + auto alpha_contig = alpha.contiguous(); + + @autoreleasepool { + auto pso = lib.getPipelineStateForFunc("standard_gamma_" + scalarToMetalTypeString(ret)); + + int64_t seed; + int64_t base_offset; + // Each thread may consume up to GAMMA_RANDOMS_STRIDE random numbers + constexpr int64_t GAMMA_RANDOMS_STRIDE = 32; + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(mps_gen->mutex_); + seed = static_cast(mps_gen->current_seed()); + base_offset = static_cast(mps_gen->get_offset()); + mps_gen->set_offset(base_offset + GAMMA_RANDOMS_STRIDE * ret.numel()); + } + + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:pso]; + mtl_setArgs(computeEncoder, ret, alpha_contig, std::array{seed, base_offset}); + mtl_dispatch1DJob(computeEncoder, pso, ret.numel()); + } + }); + } + + return ret; +} + +Tensor _standard_gamma_grad_mps(const Tensor& self, const Tensor& output) { + if (self.numel() == 0) { + return at::empty_like(self); + } + + using namespace mps; + + auto stream = getCurrentMPSStream(); + Tensor ret = at::empty_like(self, self.options(), at::MemoryFormat::Contiguous); + const auto self_contig = self.contiguous(); + const auto output_contig = output.contiguous(); + + @autoreleasepool { + auto pso = lib.getPipelineStateForFunc("standard_gamma_grad_" + scalarToMetalTypeString(ret)); + + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:pso]; + mtl_setArgs(computeEncoder, ret, self_contig, output_contig); + mtl_dispatch1DJob(computeEncoder, pso, ret.numel()); + } + }); + } + + return ret; +} + Tensor& randperm_out_mps(int64_t n, std::optional generator, Tensor& result) { TORCH_CHECK(n >= 0, "n must be non-negative, got", n); TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()), @@ -550,124 +610,15 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, std::optional generator, Tensor& result) { - using namespace mps; - - auto mps_gen = get_generator_or_default(generator, at::mps::detail::getDefaultMPSGenerator()); - int inputSize = self.dim(); - int numDist = inputSize == 1 ? 1 : self.size(0); - int numCategories = inputSize == 1 ? self.size(0) : self.size(1); - - // Restructure data for 2d - auto self_v = inputSize == 1 ? self.view({numDist, numCategories}) : self; - auto result_v = inputSize == 1 ? result.view({numDist, n_sample}) : result; - - MPSStream* stream = getCurrentMPSStream(); - - @autoreleasepool { - std::string key = "multinomial_with_replacement:" + getTensorsStringKey({self}) + ":" + std::to_string(n_sample); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSShape* prob_shape = getMPSShape(self_v); - newCachedGraph->stateTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @7 ]); - - auto prob_dtype = getMPSDataType(self_v); - - // This is probability weights - newCachedGraph->probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v), prob_shape); - - MPSGraphTensor* sumProbs = [mpsGraph reductionSumWithTensor:newCachedGraph->probTensor axis:-1 name:nil]; - - MPSGraphTensor* normalizedProbs = [mpsGraph divisionWithPrimaryTensor:newCachedGraph->probTensor - secondaryTensor:sumProbs - name:nil]; - - auto ns_numCategories = [NSNumber numberWithInt:numCategories]; - auto ns_numDist = [NSNumber numberWithInt:numDist]; - auto ns_n_sample = [NSNumber numberWithInt:n_sample]; - - MPSGraphTensor* ones = [mpsGraph constantWithScalar:1.0f - shape:@[ ns_numCategories, ns_numCategories ] - dataType:prob_dtype]; - auto zeroTensor = [mpsGraph constantWithScalar:0.0f dataType:MPSDataTypeInt32]; - auto minusOneTensor = [mpsGraph constantWithScalar:-1.0f dataType:MPSDataTypeInt32]; - - MPSGraphTensor* upperTriangle = [mpsGraph bandPartWithTensor:ones - numLowerTensor:zeroTensor - numUpperTensor:minusOneTensor - name:nil]; - MPSGraphTensor* upperProbRange = [mpsGraph matrixMultiplicationWithPrimaryTensor:normalizedProbs - secondaryTensor:upperTriangle - name:nil]; - - MPSGraphTensor* lowerProbRange = [mpsGraph subtractionWithPrimaryTensor:upperProbRange - secondaryTensor:normalizedProbs - name:nil]; - - upperProbRange = [mpsGraph reshapeTensor:upperProbRange withShape:@[ ns_numDist, @1, ns_numCategories ] name:nil]; - lowerProbRange = [mpsGraph reshapeTensor:lowerProbRange withShape:@[ ns_numDist, @1, ns_numCategories ] name:nil]; - - MPSGraphRandomOpDescriptor* descriptor = - [MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform dataType:prob_dtype]; - NSArray* generatorTensors = [mpsGraph randomTensorWithShape:@[ ns_numDist, ns_n_sample, @1 ] - descriptor:descriptor - stateTensor:newCachedGraph->stateTensor - name:nil]; - MPSGraphTensor* randomTensor = generatorTensors[0]; - - auto broadcastShape = @[ ns_numDist, ns_n_sample, ns_numCategories ]; - int broadcastShapeVals[3] = {numDist, static_cast(n_sample), numCategories}; - MPSGraphTensor* broadcastShapeTensor = - [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count] - shape:@[ [NSNumber numberWithUnsignedInteger:broadcastShape.count] ] - dataType:MPSDataTypeUInt32]; - - MPSGraphTensor* samplesTensor = [mpsGraph broadcastTensor:randomTensor toShape:broadcastShape name:nil]; - MPSGraphTensor* sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor - secondaryTensor:lowerProbRange - name:nil]; - MPSGraphTensor* sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor - secondaryTensor:upperProbRange - name:nil]; - MPSGraphTensor* sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove - secondaryTensor:sampleBelow - name:nil]; - MPSGraphTensor* sampleMask = [mpsGraph castTensor:sampleWithin toType:MPSDataTypeInt32 name:@"sampleMask"]; - MPSGraphTensor* categoriesTensor = [mpsGraph coordinateAlongAxis:-1 - withShapeTensor:broadcastShapeTensor - name:nil]; - MPSGraphTensor* binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor - secondaryTensor:sampleMask - name:nil]; - MPSGraphTensor* reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor axis:-1 name:nil]; - MPSGraphTensor* reshapeTensor = [mpsGraph reshapeTensor:reducedTensor - withShape:@[ ns_numDist, ns_n_sample ] - name:nil]; - newCachedGraph->resultTensor = [mpsGraph castTensor:reshapeTensor - toType:getMPSDataType(result) - name:@"resultTensor"]; - }); - // update the Philox state values on each run of the same graph - MPSNDArrayDescriptor* stateDesc = - [MPSNDArrayDescriptor descriptorWithDataType:MPSDataTypeInt32 shape:@[ @(at::mps::detail::PHILOX_STATE_N) ]]; - MPSNDArray* stateNDArray = [[[MPSNDArray alloc] initWithDevice:stream->device() descriptor:stateDesc] autorelease]; - { - // See Note [Acquire lock when using random generators] - std::lock_guard lock(mps_gen->mutex_); - // update the Philox state values on each run - mps_gen->update_philox_counters(); - [stateNDArray writeBytes:mps_gen->state_data() strideBytes:nil]; - } - MPSGraphTensorData* stateTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray:stateNDArray] autorelease]; - - auto probPlaceholder = Placeholder(cachedGraph->probTensor, self_v); - auto outputPlaceholder = Placeholder(cachedGraph->resultTensor, result_v); - NSDictionary* feeds = @{ - cachedGraph->stateTensor : stateTensorData, - probPlaceholder.getMPSGraphTensor() : probPlaceholder.getMPSGraphTensorData() - }; - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - } - - return result; + auto numCategories = self.size(-1); + // CDF accumulated in float32 since bfloat16/float16 lose precision summing many small probabilities. + // Sample u from U[0, total) and search the unnormalized CDF, + // equivalent to normalizing then sampling u from U[0, 1) + auto cdf = self.cumsum(-1, /*dtype=*/kFloat); + auto uniform = at::rand(result.sizes(), generator, self.options().dtype(kFloat)) + .mul_(cdf.select(-1, numCategories - 1).unsqueeze(-1)); + at::searchsorted_out(result, cdf, uniform); + return result.clamp_(0, numCategories - 1); } /* The largest consecutive integer representable in float32 (2^24) */ diff --git a/aten/src/ATen/native/mps/operations/Eye.mm b/aten/src/ATen/native/mps/operations/Eye.mm index 592654dd57740..1b603304f565e 100644 --- a/aten/src/ATen/native/mps/operations/Eye.mm +++ b/aten/src/ATen/native/mps/operations/Eye.mm @@ -1,5 +1,7 @@ #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include + #ifndef AT_PER_OPERATOR_HEADERS #include #include @@ -7,95 +9,76 @@ #include #endif -// Steps to add op for MPS backend: -// 1. Register the op in aten/src/ATen/native/native_functions.yaml with the "MPS" dispatch key -// 2. Define the function interface for the MPS backend similar to other -// backends depending on whether its structured or non-structured -// 3. Add boiler-plate error checking code as expected for the Op -// 4. The code structure roughly follows the pattern -// a) get the MPS stream handle to encode work onto -// b) get an instance of MPSGraphCache and create a key unique to the Graph -// needed for implementing this Op. Any shape, dataType or parameter -// passed to the MPSGraph during its construction will need to be included -// here. -// c) Create the graph using make_mps_graph() and add operations to the -// instance of MPSGraph. This is if the Cache->lookup() fails. -// d) Store the MPSGraphTensors for inputs and output which are needed at -// runtime. -// e) Use the CachedGraph instance's inputs and output to create Placeholders -// You will need to pass in Tensor to create MPSGraphTensorData objects. -// f) Using MPSGraphTensor and MPSGraphTensorData instances create a feeds -// dictionary. -// g) Then call runMPSGraph() with input params and return the result. -// - namespace at::native { +using namespace mps; + +#ifndef PYTORCH_JIT_COMPILE_SHADERS +static auto& lib = MetalShaderLibrary::getBundledLibrary(); +#else +#include +#endif Tensor& eye_out_mps(int64_t n, Tensor& result) { // the default value of `m` equals to `n` return eye_out_mps(n, n, result); } -using namespace mps; - Tensor& eye_out_mps(int64_t n, int64_t m, Tensor& result) { - // This is one example of boiler-plate error checking, taking after CPU/CUDA counterparts TORCH_CHECK(n >= 0, "n must be greater or equal to 0, got ", n); TORCH_CHECK(m >= 0, "m must be greater or equal to 0, got ", m); result.resize_({n, m}); - result.zero_(); - // Handle empty outputs if (result.numel() == 0) return result; - // Get MPS stream - MPSStream* stream = getCurrentMPSStream(); - - auto outputDataType = result.scalar_type(); - // Derive from MPSCachedGraph - // This structure is used to cache an MPSGraph with certain keys, so that we don't have to compile the same MPSGraph - // time and time again for the same operation The keys of this structure are based on the inputs and outputs needed - // for the operation here, we don't have any input tensors, just an output tensor. - // If the operator to be added is unary or binary, instead of creating a new CachedGraph struct yourself, please - // consider using `MPSUnaryCachedGraph` or `MPSBinaryCachedGraph` and their corresponding Grad versions in - // `OperationUtils.h`. - struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* outputTensor_ = nil; - }; - - @autoreleasepool { - // A key is used to identify the MPSGraph which was created once, and can be reused if the parameters, data types - // etc match the earlier created MPSGraph - std::string key = "eye_out_mps:" + getTensorsStringKey({result}); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto* mpsGraph, auto* newCachedGraph) { - MPSGraphTensor* onesTensor = [mpsGraph constantWithScalar:1.0f - shape:getMPSShape(result) - dataType:getMPSDataType(outputDataType)]; - - // Here we can call the MPSGraph API needed to execute the operation. - // The API details can be found here: - // https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph - MPSGraphTensor* outputTensor = [mpsGraph bandPartWithTensor:onesTensor numLower:0 numUpper:0 name:nil]; - - if ([outputTensor dataType] != getMPSDataType(outputDataType)) { - outputTensor = castMPSTensor(mpsGraph, outputTensor, outputDataType); + MPSStream* mpsStream = getCurrentMPSStream(); + auto stride0 = result.stride(0); + auto stride1 = result.stride(1); + + // Small tensors: single-pass 2D kernel (one dispatch, no zero_() overhead). + // Large tensors: zero_() + diagonal fill (memset is faster than n*m branching writes). + constexpr int64_t kSinglePassThreshold = 1 << 22; + + if (n * m <= kSinglePassThreshold) { + auto key = "eye_" + scalarToMetalTypeString(result); + id computeEncoder = mpsStream->commandEncoder(); + id pso = lib.getPipelineStateForFunc(key); + + // Map x to the smaller stride for coalesced writes + bool swap = stride0 < stride1; + auto x_stride = swap ? stride0 : stride1; + auto y_stride = swap ? stride1 : stride0; + auto grid_x = swap ? static_cast(n) : static_cast(m); + auto grid_y = swap ? static_cast(m) : static_cast(n); + + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + [computeEncoder setComputePipelineState:pso]; + mtl_setArgs(computeEncoder, result, y_stride, x_stride); + + auto maxTg = [pso maxTotalThreadsPerThreadgroup]; + auto tg_x = std::min(grid_x, maxTg); + auto tg_y = std::min(grid_y, std::max(maxTg / tg_x, static_cast(1))); + [computeEncoder dispatchThreads:MTLSizeMake(grid_x, grid_y, 1) + threadsPerThreadgroup:MTLSizeMake(tg_x, tg_y, 1)]; + } + }); + } else { + result.zero_(); + int64_t sz = std::min(n, m); + int64_t diag_stride = stride0 + stride1; + auto key = "eye_diag_" + scalarToMetalTypeString(result); + id computeEncoder = mpsStream->commandEncoder(); + id pso = lib.getPipelineStateForFunc(key); + + dispatch_sync(mpsStream->queue(), ^() { + @autoreleasepool { + [computeEncoder setComputePipelineState:pso]; + mtl_setArgs(computeEncoder, result, diag_stride); + mtl_dispatch1DJob(computeEncoder, pso, sz); } - newCachedGraph->outputTensor_ = outputTensor; }); - - // Create placeholders which use the keys of the CachedGraph to create inputs and outputs of the operation - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result); - - // Create dictionary of inputs/feeds and outputs/results - // In this case, there are no inputs, so the feeds are nil - NSDictionary* feeds = nil; - auto results = dictionaryFromPlaceholders(outputPlaceholder); - - // Run the graph - runMPSGraph(stream, cachedGraph->graph(), feeds, results); } return result; diff --git a/aten/src/ATen/native/mps/operations/GridSampler.mm b/aten/src/ATen/native/mps/operations/GridSampler.mm index 2a232decc66ed..17af52c23aa80 100644 --- a/aten/src/ATen/native/mps/operations/GridSampler.mm +++ b/aten/src/ATen/native/mps/operations/GridSampler.mm @@ -9,9 +9,13 @@ #include #include #else +#include #include +#include #include +#include #include +#include #endif namespace at::native { @@ -142,9 +146,7 @@ static void grid_sampler_3d_mps_impl(Tensor& output, switch (interpolation_mode) { case GridSamplerInterpolation::Bilinear: - break; case GridSamplerInterpolation::Nearest: - TORCH_CHECK(false, op_name, ": Unsupported Nearest interpolation"); break; case GridSamplerInterpolation::Bicubic: TORCH_CHECK(false, op_name, ": Unsupported Bicubic interpolation"); @@ -178,8 +180,10 @@ static void grid_sampler_3d_mps_impl(Tensor& output, dispatch_sync_with_rethrow(mpsStream->queue(), ^() { @autoreleasepool { id computeEncoder = mpsStream->commandEncoder(); - auto pso = lib.getPipelineStateForFunc( - fmt::format("grid_sampler_3d_{}_{}", padding_to_string(padding_mode), scalarToMetalTypeString(input))); + auto pso = lib.getPipelineStateForFunc(fmt::format("grid_sampler_3d_{}_{}_{}", + interp_to_string(interpolation_mode), + padding_to_string(padding_mode), + scalarToMetalTypeString(input))); getMPSProfiler().beginProfileKernel(pso, op_name, {input, grid}); [computeEncoder setComputePipelineState:pso]; @@ -223,4 +227,214 @@ Tensor grid_sampler_3d_mps(const Tensor& input, return output; } +std::tuple grid_sampler_2d_backward_mps(const Tensor& grad_output, + const Tensor& input, + const Tensor& grid, + int64_t _interpolation_mode, + int64_t _padding_mode, + bool align_corners, + std::array output_mask) { + check_grid_sampler_common(input, grid); + check_grid_sampler_2d(input, grid); + + TORCH_CHECK(input.scalar_type() == grid.scalar_type(), + "expected input and grid to have the same type, but got ", + input.scalar_type(), + " and ", + grid.scalar_type()); + + TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(input.scalar_type()), + "grid_sampler_2d_backward is not supported for complex on MPS"); + + auto input_requires_grad = output_mask[0]; + auto grad_input = input_requires_grad ? at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT) : Tensor(); + auto interpolation_mode = static_cast(_interpolation_mode); + auto grad_grid = interpolation_mode == GridSamplerInterpolation::Nearest + ? at::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT) + : at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto padding_mode = static_cast(_padding_mode); + + auto N = input.size(0); + auto out_H = grid.size(1); + auto out_W = grid.size(2); + auto num_threads = N * out_H * out_W; + + if (num_threads == 0) { + return std::make_tuple(grad_input, grad_grid); + } + + GridSamplerBackwardParams<4> params; + params.forward.sampler_dims = 2; + params.forward.align_corners = align_corners; + + // Forward output shape is [N, C, out_H, out_W] + params.forward.output_sizes[0] = safe_downcast(N); + params.forward.output_sizes[1] = safe_downcast(input.size(1)); + params.forward.output_sizes[2] = safe_downcast(out_H); + params.forward.output_sizes[3] = safe_downcast(out_W); + + for (const auto dim : c10::irange(input.dim())) { + params.forward.input_sizes[dim] = safe_downcast(input.size(dim)); + params.forward.input_strides[dim] = safe_downcast(input.stride(dim)); + params.forward.grid_sizes[dim] = safe_downcast(grid.size(dim)); + params.forward.grid_strides[dim] = safe_downcast(grid.stride(dim)); + params.grad_output_strides[dim] = safe_downcast(grad_output.stride(dim)); + params.grad_input_strides[dim] = input_requires_grad ? safe_downcast(grad_input.stride(dim)) : 0; + } + params.grad_grid_sW = safe_downcast(grad_grid.stride(2)); + params.padding_mode = static_cast(padding_mode); + + using namespace mps; + auto interp_str = mps::interp_to_string(interpolation_mode); + auto pad_str = mps::padding_to_string(padding_mode); + auto type_str = scalarToMetalTypeString(input); + + MPSStream* mpsStream = getCurrentMPSStream(); + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + + if (input_requires_grad) { + auto input_name = interpolation_mode == GridSamplerInterpolation::Bicubic + ? fmt::format("grid_sampler_2d_backward_bicubic_input_{}_{}", pad_str, type_str) + : fmt::format("grid_sampler_2d_backward_{}_input_{}", interp_str, type_str); + auto input_pso = lib.getPipelineStateForFunc(input_name); + getMPSProfiler().beginProfileKernel(input_pso, "grid_sampler_2d_backward_input", {grad_output, grid}); + [computeEncoder setComputePipelineState:input_pso]; + mtl_setArgs(computeEncoder, grad_input, grad_output, grid, params); + mtl_dispatch1DJob(computeEncoder, input_pso, num_threads); + getMPSProfiler().endProfileKernel(input_pso); + } + + if (interpolation_mode != GridSamplerInterpolation::Nearest) { + auto grid_name = interpolation_mode == GridSamplerInterpolation::Bicubic + ? fmt::format("grid_sampler_2d_backward_bicubic_grid_{}_{}", pad_str, type_str) + : fmt::format("grid_sampler_2d_backward_bilinear_grid_{}", type_str); + auto grid_pso = lib.getPipelineStateForFunc(grid_name); + getMPSProfiler().beginProfileKernel(grid_pso, "grid_sampler_2d_backward_grid", {grad_output, input, grid}); + [computeEncoder setComputePipelineState:grid_pso]; + mtl_setArgs(computeEncoder, grad_grid, grad_output, input, grid, params); + mtl_dispatch1DJob(computeEncoder, grid_pso, num_threads); + getMPSProfiler().endProfileKernel(grid_pso); + } + } + }); + + return std::make_tuple(std::move(grad_input), std::move(grad_grid)); +} + +std::tuple grid_sampler_3d_backward_mps(const Tensor& grad_output, + const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + std::array output_mask) { + using namespace mps; + check_grid_sampler_common(input, grid); + check_grid_sampler_3d(input, grid, interpolation_mode); + + TORCH_CHECK_NOT_IMPLEMENTED(interpolation_mode == 0 || interpolation_mode == 1, + "grid_sampler_3d backward on MPS only supports bilinear and nearest interpolation"); + + TORCH_CHECK(input.scalar_type() == grid.scalar_type(), + "expected input and grid to have the same type, but got ", + input.scalar_type(), + " and ", + grid.scalar_type()); + + auto input_requires_grad = output_mask[0]; + int32_t interp_mode = static_cast(interpolation_mode); + int32_t pad_mode = static_cast(padding_mode); + + Tensor grad_input; + if (input_requires_grad) { + grad_input = at::zeros_like(input); + } + // Always allocate grad_grid, matching CPU/CUDA and the 2D MPS backward. + // Autograd requires a defined tensor for every output declared in the + // derivative, even when the corresponding input doesn't require grad. + auto grad_grid = interp_mode == 1 ? at::zeros_like(grid, MemoryFormat::Contiguous) + : at::empty_like(grid, MemoryFormat::Contiguous); + + const auto& input_contiguous = input.contiguous(); + const auto& grid_contiguous = grid.contiguous(); + const auto& grad_output_contiguous = grad_output.contiguous(); + + auto N = input_contiguous.size(0); + auto C = input_contiguous.size(1); + auto in_D = input_contiguous.size(2); + auto in_H = input_contiguous.size(3); + auto in_W = input_contiguous.size(4); + auto out_D = grid_contiguous.size(1); + auto out_H = grid_contiguous.size(2); + auto out_W = grid_contiguous.size(3); + + bool run_grad_input = input_requires_grad; + bool run_grad_grid = interp_mode != 1; + + if (!run_grad_input && !run_grad_grid) { + return std::make_tuple(std::move(grad_input), std::move(grad_grid)); + } + + // The combined kernel needs a valid buffer pointer for grad_input even when + // it is not requested, so allocate a dummy with the expected rank so stride + // queries below remain in range. + auto grad_input_buf = run_grad_input ? grad_input : at::zeros({1, 1, 1, 1, 1}, input.options()); + + GridSampler3DBackwardParams params; + params.interpolation_mode = interp_mode; + params.padding_mode = pad_mode; + params.align_corners = align_corners; + params.compute_grad_input = run_grad_input; + params.compute_grad_grid = run_grad_grid; + params.input_sizes = {safe_downcast(N), + safe_downcast(C), + safe_downcast(in_D), + safe_downcast(in_H), + safe_downcast(in_W)}; + params.output_sizes = {safe_downcast(N), + safe_downcast(C), + safe_downcast(out_D), + safe_downcast(out_H), + safe_downcast(out_W)}; + for (int i = 0; i < 5; i++) { + params.grid_strides[i] = safe_downcast(grid_contiguous.stride(i)); + params.grad_output_strides[i] = safe_downcast(grad_output_contiguous.stride(i)); + params.input_strides[i] = safe_downcast(input_contiguous.stride(i)); + params.grad_input_strides[i] = safe_downcast(grad_input_buf.stride(i)); + params.grad_grid_strides[i] = safe_downcast(grad_grid.stride(i)); + } + + MPSStream* mpsStream = getCurrentMPSStream(); + + dispatch_sync_with_rethrow(mpsStream->queue(), ^() { + @autoreleasepool { + id computeEncoder = mpsStream->commandEncoder(); + + auto pso = + lib.getPipelineStateForFunc(fmt::format("grid_sampler_3d_backward_{}", scalarToMetalTypeString(input))); + + getMPSProfiler().beginProfileKernel( + pso, + "grid_sampler_3d_backward", + {grad_output_contiguous, input_contiguous, grid_contiguous, grad_input_buf, grad_grid}); + + [computeEncoder setComputePipelineState:pso]; + + mtl_setArgs( + computeEncoder, grad_output_contiguous, input_contiguous, grid_contiguous, grad_input_buf, grad_grid, params); + + MTLSize threadsPerThreadgroup = MTLSizeMake(16, 16, 1); + MTLSize threadsPerGrid = MTLSizeMake(out_W, out_H * out_D, N); + [computeEncoder dispatchThreads:threadsPerGrid threadsPerThreadgroup:threadsPerThreadgroup]; + + getMPSProfiler().endProfileKernel(pso); + } + }); + + return std::make_tuple(std::move(grad_input), std::move(grad_grid)); +} + } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/HistogramKernel.mm b/aten/src/ATen/native/mps/operations/HistogramKernel.mm index 0e0dcc53b04e6..20d9a806ff410 100644 --- a/aten/src/ATen/native/mps/operations/HistogramKernel.mm +++ b/aten/src/ATen/native/mps/operations/HistogramKernel.mm @@ -172,9 +172,10 @@ static void histogramdd_out_mps_template(const Tensor& self, mps::histogramdd_kernel_impl( hist, bin_edges_contig, reshaped_input, reshaped_weight); }), - kFloat, + AT_EXPAND(AT_ALL_TYPES), + kBFloat16, kHalf, - kBFloat16); + AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); /* Divides each bin's value by the total count/weight in all bins, * and by the bin's volume. diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm index dd333d2dc7ebd..1433f4ad9844c 100644 --- a/aten/src/ATen/native/mps/operations/Indexing.mm +++ b/aten/src/ATen/native/mps/operations/Indexing.mm @@ -27,8 +27,6 @@ #include #else #include -#include -#include #include #include #include @@ -42,12 +40,11 @@ #include #include #include +#include #include #include #endif -constexpr auto nonZeroMaxSize = 1UL << 24; - namespace at::native { namespace mps { @@ -292,150 +289,139 @@ static void index_put_kernel_mps(TensorIterator& iter, }); } -static Tensor nonzero_fallback(const Tensor& self) { - return at::nonzero(self.to("cpu")).to("mps"); -} - -static Tensor& nonzero_out_native_mps(const Tensor& self, Tensor& out_) { +// Metal kernel-based nonzero using prefix-sum + scatter. +// Step 1: Per-element exclusive prefix sum of nonzero flags + block totals. +// Step 2: GPU prefix sum of block totals → block offsets + total count. +// Host (optional): Read back total count, allocate output, unless max_element is provided +// Step 3: Scatter multi-dimensional indices into the output. +static void nonzero_impl_mps(const Tensor& self, Tensor& out_, std::optional max_elements) { using namespace mps; - int64_t nDim = self.dim(); + TORCH_CHECK(self.numel() < std::numeric_limits::max(), + "nonzero is not supported for tensors with more than INT_MAX elements, " + "See https://github.com/pytorch/pytorch/issues/51871"); + TORCH_CHECK(out_.dtype() == at::kLong, "Expected output type to be Long, but got ", out_.dtype()); + TORCH_CHECK(self.device() == out_.device(), + "expected self and out to be on the same device, but got out on ", + out_.device(), + " and self on ", + self.device()); + TORCH_CHECK(out_.is_mps()); + + Tensor input = self.contiguous(); + const int64_t nDim = self.dim(); + const auto numel = static_cast(input.numel()); + const auto type_str = scalarToMetalTypeString(input); MPSStream* stream = getCurrentMPSStream(); - using CachedGraph = MPSUnaryCachedGraph; + auto pso_step1 = lib.getPipelineStateForFunc(fmt::format("count_nonzero_prefix_sum_{}", type_str)); + auto pso_step2 = lib.getPipelineStateForFunc("prefix_sum_blocks"); + auto pso_step3 = lib.getPipelineStateForFunc(fmt::format("scatter_nonzero_indices_{}", type_str)); + TORCH_INTERNAL_ASSERT([pso_step1 maxTotalThreadsPerThreadgroup] == [pso_step3 maxTotalThreadsPerThreadgroup], + "nonzero: step 1 and step 3 threadgroup sizes must match"); + + uint32_t threads_per_group = static_cast([pso_step1 maxTotalThreadsPerThreadgroup]); + uint32_t num_blocks = (numel + threads_per_group - 1) / threads_per_group; + + auto tmp = at::empty({input.numel() + 2 * num_blocks + 1}, input.options().dtype(kInt)); + Tensor prefix_buf = tmp.slice(0, 0, numel); + Tensor block_sums_buf = tmp.slice(0, numel, numel + num_blocks); + Tensor block_offsets_buf = tmp.slice(0, numel + num_blocks, numel + 2 * num_blocks); + Tensor total_nonzero_buf = tmp.slice(0, numel + 2 * num_blocks, numel + 2 * num_blocks + 1); + + // Steps 1+2: compute prefix sums and block offsets entirely on GPU dispatch_sync_with_rethrow(stream->queue(), ^() { - stream->synchronize(SyncType::COMMIT_AND_WAIT); + @autoreleasepool { + auto computeEncoder = stream->commandEncoder(); + + [computeEncoder setComputePipelineState:pso_step1]; + mtl_setArgs(computeEncoder, input, prefix_buf, block_sums_buf); + mtl_dispatch1DJob(computeEncoder, pso_step1, numel); + + [computeEncoder setComputePipelineState:pso_step2]; + mtl_setArgs(computeEncoder, block_sums_buf, block_offsets_buf, total_nonzero_buf, num_blocks); + uint32_t tg_size_blocks = std::min(1024u, ((num_blocks + 31) / 32) * 32); + [computeEncoder dispatchThreads:MTLSizeMake(tg_size_blocks, 1, 1) + threadsPerThreadgroup:MTLSizeMake(tg_size_blocks, 1, 1)]; + } }); - int64_t total_nonzero = at::count_nonzero(self).item(); - at::native::resize_output(out_, {total_nonzero, nDim}); - if (out_.numel() == 0) { - return out_; - } - bool contiguous_output = !needsGather(out_); - Tensor out = out_; - if (!contiguous_output) { - out = at::empty_like(out_, MemoryFormat::Contiguous); + if (!max_elements) { + // Dynamic path: sync to learn output size + const int64_t total_nonzero = total_nonzero_buf.item(); + at::native::resize_output(out_, {total_nonzero, nDim}); + max_elements = total_nonzero; } - @autoreleasepool { - std::string key = "nonzero_out_native_mps" + getTensorsStringKey(self); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); + if (out_.numel() == 0) { + return; + } - MPSGraphTensor* outputTensor = [mpsGraph nonZeroIndicesOfTensor:inputTensor name:nil]; + bool contiguous_output = out_.is_contiguous(); + Tensor out = contiguous_output ? out_ : at::empty_like(out_, MemoryFormat::Contiguous); - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); + int ndim_int = static_cast(nDim); + int max_entries = static_cast(*max_elements); - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out); - auto feeds = dictionaryFromPlaceholders(selfPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - } + // Step 3: scatter indices, capped at max_entries + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:pso_step3]; + mtl_setArgs(computeEncoder, input, prefix_buf, out, ndim_int, input.sizes(), block_offsets_buf, max_entries); + mtl_dispatch1DJob(computeEncoder, pso_step3, numel); + } + }); if (!contiguous_output) { out_.copy_(out); } - - return out_; } Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) { - if (self.is_complex()) { - TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes. ", - "Falling back on CPU. This may have performance implications."); - Tensor out_fallback = nonzero_fallback(self); - at::native::resize_output(out_, out_fallback.sizes()); - out_.copy_(out_fallback); - return out_; - } - int64_t nDim = self.dim(); if (self.numel() == 0) { at::native::resize_output(out_, {0, nDim}); return out_; } - using namespace mps; - const uint32_t maxDimensions = 16; - - TORCH_CHECK(self.numel() < std::numeric_limits::max(), - "nonzero is not supported for tensors with more than INT_MAX elements, \ - See https://github.com/pytorch/pytorch/issues/51871"); - TORCH_CHECK( - out_.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out_.dtype()); - TORCH_CHECK(self.device() == out_.device(), - "expected self and out to be on the same device, but got out on ", - out_.device(), - " and self on ", - self.device()); - TORCH_CHECK(self.dim() <= maxDimensions, "nonzero is not supported for tensor with more than ", 16, " dimensions"); - TORCH_CHECK(out_.is_mps()); + nonzero_impl_mps(self, out_, std::nullopt); + return out_; +} - if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS) && - (self.numel() >= nonZeroMaxSize || self.is_complex())) { - TORCH_WARN_ONCE("MPS: nonzero op is not natively supported for the provided input on MacOS14", - "Falling back on CPU. This may have performance implications.", - "See github.com/pytorch/pytorch/issues/122916 for further info"); - Tensor out_fallback = nonzero_fallback(self); - at::native::resize_output(out_, out_fallback.sizes()); - out_.copy_(out_fallback); - return out_; - } +Tensor nonzero_mps(const Tensor& self) { + Tensor out = at::empty({0}, self.options().dtype(kLong)); + return nonzero_out_mps(self, out); +} - MPSStream* stream = getCurrentMPSStream(); - using CachedGraph = MPSUnaryCachedGraph; +Tensor& nonzero_static_out_mps(const Tensor& self, int64_t size, int64_t fill_value, Tensor& result) { + TORCH_CHECK(size >= 0, "nonzero_static: 'size' must be an non-negative integer"); - dispatch_sync_with_rethrow(stream->queue(), ^() { - stream->synchronize(SyncType::COMMIT_AND_WAIT); - }); - int64_t total_nonzero = at::count_nonzero(self).item(); - at::native::resize_output(out_, {total_nonzero, nDim}); - if (out_.numel() == 0) { - return out_; + int64_t nDim = self.dim(); + if (result.dim() != 2 || result.size(0) != size || result.size(1) != nDim) { + at::native::resize_output(result, {size, nDim}); } - bool contiguous_output = !needsGather(out_); - Tensor out = out_; - if (!contiguous_output) { - out = at::empty_like(out_, MemoryFormat::Contiguous); + if (result.size(0) == 0 || result.size(1) == 0) { + return result; } - @autoreleasepool { - std::string key = "nonzero_out_native_mps" + getTensorsStringKey(self); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - - MPSGraphTensor* outputTensor = [mpsGraph nonZeroIndicesOfTensor:inputTensor name:nil]; - - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = outputTensor; - }); + result.fill_(fill_value); - Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self); - Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out); - auto feeds = dictionaryFromPlaceholders(selfPlaceholder); - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - } - - if (!contiguous_output) { - out_.copy_(out); + if (self.numel() == 0) { + return result; } - return out_; + nonzero_impl_mps(self, result, size); + return result; } -Tensor nonzero_mps(const Tensor& self) { - if (self.is_complex()) { - TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes ", - "Falling back on CPU. This may have performance implications."); - return nonzero_fallback(self); - } - - Tensor out = at::empty({0}, self.options().dtype(kLong)); - return nonzero_out_mps(self, out); +Tensor nonzero_static_mps(const Tensor& self, int64_t size, int64_t fill_value) { + TORCH_CHECK(size >= 0, "nonzero_static: 'size' must be an non-negative integer"); + int64_t nDim = self.dim(); + auto result = at::empty({size, nDim}, at::TensorOptions().dtype(kLong).device(kMPS)); + nonzero_static_out_mps(self, size, fill_value, result); + return result; } Tensor masked_select_mps(const Tensor& self, const Tensor& mask) { diff --git a/aten/src/ATen/native/mps/operations/Lerp.mm b/aten/src/ATen/native/mps/operations/Lerp.mm deleted file mode 100644 index 4d13b8b2466bd..0000000000000 --- a/aten/src/ATen/native/mps/operations/Lerp.mm +++ /dev/null @@ -1,49 +0,0 @@ -#define TORCH_ASSERT_ONLY_METHOD_OPERATORS -#include -#include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#include -#else -#include -#include -#endif - -namespace at::native { -TORCH_IMPL_FUNC(lerp_Tensor_mps)(const Tensor& self, const Tensor& end, const Tensor& weight, const Tensor& out) { - TORCH_CHECK(out.is_mps()); - std::array args{{{out, "out", 0}, {self, "self", 1}, {end, "end", 2}, {weight, "weight", 3}}}; - checkAllSameGPU(__func__, args); - using namespace mps; - struct CachedGraph : public MPSCachedGraph { - CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} - MPSGraphTensor* selfTensor_ = nil; - MPSGraphTensor* endTensor_ = nil; - MPSGraphTensor* weightTensor_ = nil; - MPSGraphTensor* outputTensor_ = nil; - }; - @autoreleasepool { - std::string key = "lerp_Tensor_mps" + getTensorsStringKey({self, end, weight}); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto graph) { - auto selfTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); - auto endTensor = mpsGraphRankedPlaceHolder(mpsGraph, end); - auto weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight); - auto distance = [mpsGraph subtractionWithPrimaryTensor:endTensor secondaryTensor:selfTensor name:nil]; - auto weighedDistance = [mpsGraph multiplicationWithPrimaryTensor:weightTensor secondaryTensor:distance name:nil]; - auto output = [mpsGraph additionWithPrimaryTensor:selfTensor secondaryTensor:weighedDistance name:nil]; - graph->selfTensor_ = selfTensor; - graph->endTensor_ = endTensor; - graph->weightTensor_ = weightTensor; - graph->outputTensor_ = output; - }); - auto selfPlaceholder = Placeholder(cachedGraph->selfTensor_, self); - auto endPlaceholder = Placeholder(cachedGraph->endTensor_, end); - auto weightPlaceholder = Placeholder(cachedGraph->weightTensor_, weight); - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out); - auto feeds = dictionaryFromPlaceholders(selfPlaceholder, endPlaceholder, weightPlaceholder); - runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder); - } -} - -} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Linear.mm b/aten/src/ATen/native/mps/operations/Linear.mm index 219086edd8e37..2ccb89bc4d702 100644 --- a/aten/src/ATen/native/mps/operations/Linear.mm +++ b/aten/src/ATen/native/mps/operations/Linear.mm @@ -76,8 +76,6 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::opt TORCH_CHECK(input.is_mps(), "Tensor for argument input is on ", input.device(), " but expected on mps"); TORCH_CHECK(supportedFloatingOrComplexType(weight_arg), "MPS device does not support linear for non-float weights"); TORCH_CHECK(weight_arg.is_mps(), "Tensor for argument weight is on ", weight_arg.device(), " but expected on mps"); - TORCH_CHECK((input.scalar_type() != kComplexFloat && input.scalar_type() != kComplexHalf), - "mps linear does not support complex types"); const Tensor& bias = *(at::borrow_from_optional_tensor(bias_opt)); const bool is_bias_defined = bias.defined(); @@ -117,8 +115,9 @@ Tensor _mps_linear(const Tensor& input, const Tensor& weight_arg, const std::opt // No-graph execution causes nonsense if these are non-contiguous. const bool is_contiguous = input.is_contiguous() && weight.is_contiguous() && bias.is_contiguous(); + const bool is_complex = input.is_complex() || weight.is_complex() || (is_bias_defined && bias.is_complex()); - if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS) && is_contiguous) { + if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS) && is_contiguous && !is_complex) { _mps_linear_nograph(input, weight, bias, output); // Squeeze last dim of 1D linear return weight_arg.dim() != 1 ? output : output.squeeze(-1); @@ -206,8 +205,7 @@ static Tensor _mps_linear_backward_input(IntArrayRef input_size, const Tensor& g MPSGraphTensor* outputTensor_ = nil; }; - Tensor output = at::empty( - input_size, grad_output.scalar_type(), std::nullopt, kMPS, std::nullopt, grad_output.suggest_memory_format()); + Tensor output = at::empty(input_size, grad_output.options()); TORCH_CHECK(output.is_mps()); if (grad_output.numel() == 0) { return output; @@ -275,18 +273,8 @@ static Tensor _mps_linear_backward_input(IntArrayRef input_size, const Tensor& g TORCH_CHECK(grad_output_reshaped.is_mps()); TORCH_CHECK(input_reshaped.is_mps()); - Tensor output = at::empty({grad_output_reshaped.size(1), input_reshaped.size(1)}, - grad_output.scalar_type(), - std::nullopt, - kMPS, - std::nullopt, - grad_output.suggest_memory_format()); - Tensor bias = at::empty({grad_output_reshaped.size(1)}, - grad_output.scalar_type(), - std::nullopt, - kMPS, - std::nullopt, - grad_output.suggest_memory_format()); + Tensor output = at::empty({grad_output_reshaped.size(1), input_reshaped.size(1)}, grad_output.options()); + Tensor bias = at::empty({grad_output_reshaped.size(1)}, grad_output.options()); TORCH_CHECK(output.is_mps()); TORCH_CHECK(bias.is_mps()); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index fe62a82a33b15..6bfd7d1049c4b 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -310,6 +310,20 @@ bool use_metal_mm(const Tensor& self, const Tensor& other, const Tensor& output) if (c10::isComplexType(self.scalar_type()) && self.size(1) > max_complex_inner_size) { return true; } + // Detect conditions that would trigger LORADOWN GEMV kernel with potential padding overflow + // See https://github.com/pytorch/pytorch/issues/178056 + if (self.scalar_type() == at::ScalarType::Half && (self.size(0) <= 16 || other.size(1) <= 16) && + self.stride(1) == 1 && other.stride(0) == 1) { + int64_t self_padding = self.stride(0) - self.size(1); + int64_t other_padding = other.stride(1) - other.size(0); + + if (self_padding > 15 || other_padding > 15 || self_padding % 4 != 0 || other_padding % 4 != 0) { + TORCH_WARN_ONCE( + "MPS mm implementation has a known issue with this shape, dtype and slice. Dispatching to metal implementation instead. This may impact performance."); + return true; + } + } + return !is_macos_14_4_or_newer && (self.stride(0) > max_stride_size || self.stride(1) > max_stride_size || self.size(0) > max_stride_size || self.size(1) > max_stride_size || other.stride(0) > max_stride_size || other.stride(1) > max_stride_size || @@ -698,13 +712,21 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const }); // MPS TODO: // Strided API doesn't play nice with complex data types (at least not in case of matmul). + // MPSGraph's matrixMultiplication produces incorrect results with stride-0 NDArray + // inputs on macOS < 26.4 (only every 16th row is computed). Contiguify such tensors + // by disabling the strided API so they go through the gather/clone path first. + // See https://github.com/pytorch/pytorch/issues/180201 + static const bool is_macOS_26_4_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_26_4_PLUS); + auto hasZeroStride = [](const Tensor& t) { + return std::ranges::any_of(t.strides(), [](auto s) { return s == 0; }); + }; + auto useStridedSelf = !isComplexType(self.scalar_type()) && (is_macOS_26_4_or_newer || !hasZeroStride(self)); + auto useStridedOther = !isComplexType(other.scalar_type()) && (is_macOS_26_4_or_newer || !hasZeroStride(other)); auto selfPlaceholder = self.numel() != 0 - ? Placeholder( - cachedGraph->inputTensor_, self, nil, true, MPSDataTypeInvalid, !isComplexType(self.scalar_type())) + ? Placeholder(cachedGraph->inputTensor_, self, nil, true, MPSDataTypeInvalid, useStridedSelf) : Placeholder(); auto otherPlaceholder = other.numel() != 0 - ? Placeholder( - cachedGraph->otherTensor_, other, nil, true, MPSDataTypeInvalid, !isComplexType(other.scalar_type())) + ? Placeholder(cachedGraph->otherTensor_, other, nil, true, MPSDataTypeInvalid, useStridedOther) : Placeholder(); auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output); @@ -899,7 +921,8 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const std::string key = "addmm_out_mps_impl" + getTensorsStringKey({self, other, *bias_}) + ":" + std::to_string(beta.toDouble()) + ":" + std::to_string(alpha.toDouble()); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* biasTensor = mpsGraphRankedPlaceHolder(mpsGraph, *bias_); + auto biasTensor = mpsGraphRankedPlaceHolder(mpsGraph, *bias_); + auto biasTensor_ = bias_->is_conj() ? [mpsGraph conjugateWithTensor:biasTensor name:nil] : biasTensor; // TODO: Use alpha and beta here with fill_.Scalar and mul auto [selfTensor, otherTensor, productTensor] = do_mm(mpsGraph, self, other); @@ -912,11 +935,11 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const secondaryTensor:alphaTensor name:@"MM/alpha*(mat1@mat2)"]; } - auto biasTimesBetaTensor = biasTensor; + auto biasTimesBetaTensor = biasTensor_; if (is_beta_non_zero && beta.toDouble() != 1.0) { auto betaTensor = [mpsGraph constantWithScalar:beta.toDouble() dataType:getMPSScalarType((*bias_).scalar_type())]; - biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:biasTensor + biasTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:biasTensor_ secondaryTensor:betaTensor name:@"MM/beta*input"]; } @@ -1128,7 +1151,8 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const // Call tiled implementation if the number of elements exceeds 2^32 uint64_t resultSize = batch1.size(0) * batch1.size(1) * batch2.size(2); if (resultSize > pow(2, 32)) { - result = tiled_bmm_out_mps_impl(batch1, batch2, result); + // Tiled path uses MPSNDArray directly, so resolve conjugate views upfront + result = tiled_bmm_out_mps_impl(batch1.resolve_conj(), batch2.resolve_conj(), result); return result; } @@ -1146,16 +1170,18 @@ static void linalg_inv_ex_out_mps_impl(const Tensor& A, bool check_errors, const std::to_string(doTranspose); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - MPSGraphTensor* batch1Tensor = mps::mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(batch1.scalar_type())); - MPSGraphTensor* batch2Tensor = mps::mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(batch2.scalar_type())); - MPSGraphTensor* batch2TensorTranspose = batch2Tensor; + auto batch1Tensor = mps::mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(batch1.scalar_type())); + auto batch2Tensor = mps::mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(batch2.scalar_type())); + + auto batch1TensorOp = batch1.is_conj() ? [mpsGraph conjugateWithTensor:batch1Tensor name:nil] : batch1Tensor; + auto batch2TensorOp = batch2.is_conj() ? [mpsGraph conjugateWithTensor:batch2Tensor name:nil] : batch2Tensor; if (doTranspose) { - batch2TensorTranspose = [mpsGraph transposeTensor:batch2Tensor dimension:-1 withDimension:-2 name:nil]; + batch2TensorOp = [mpsGraph transposeTensor:batch2TensorOp dimension:-1 withDimension:-2 name:nil]; } - MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:batch1Tensor - secondaryTensor:batch2TensorTranspose + MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:batch1TensorOp + secondaryTensor:batch2TensorOp name:@"MM/(batch1@batch2)"]; newCachedGraph->batch1Tensor_ = batch1Tensor; diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index 27ef097963bfb..29d6b6ae3f3f2 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -96,6 +96,24 @@ static void get_shapes(MPSShape* input_shape_readonly, Tensor& output, Tensor& save_mean, Tensor& save_var) { + // Flatten 5D to 4D: MPSGraph normalization is significantly slower for rank-5 tensors. + // Merging spatial dims is safe since BatchNorm reduces over all dims except channel. + if (self.dim() == 5) { + auto input_4d = self.contiguous().reshape({self.size(0), self.size(1), self.size(2) * self.size(3), self.size(4)}); + auto output_4d = output.reshape(input_4d.sizes()); + return batch_norm_mps_out(input_4d, + weight_opt, + bias_opt, + running_mean_opt, + running_var_opt, + train, + momentum, + epsilon, + output_4d, + save_mean, + save_var); + } + TORCH_CHECK_NOT_IMPLEMENTED(self.scalar_type() != kLong, "Long batch norm is not supported with MPS"); TORCH_CHECK_NOT_IMPLEMENTED(!c10::isComplexType(self.scalar_type()), "Batch norm for complex is not supported for MPS"); @@ -124,7 +142,10 @@ static void get_shapes(MPSShape* input_shape_readonly, const bool has_weight = (weight_opt.has_value() && weight_opt->defined()); const bool has_bias = (bias_opt.has_value() && bias_opt->defined()); - auto memory_format = self.suggest_memory_format(); + // Use exact-match: a channel-slice of a channels-last tensor has CL-like + // strides but is not NHWC-packed. + // See https://github.com/pytorch/pytorch/issues/180984 + auto memory_format = self.suggest_memory_format(/*channels_last_strides_exact_match=*/true); if (output.numel() == 0) { return std::tuple(output, save_mean, save_var); @@ -221,25 +242,29 @@ Check if running mean exists (maybe do this check before making graph) MPSGraphTensor* batchVarianceTensor = [mpsGraph varianceOfTensor:inputTensor axes:axes name:nil]; varTensor = batchVarianceTensor; if (has_running_mean) { + // Running stats may have a different dtype (e.g. float32 with float16 input) + auto running_mean_dtype = getMPSDataType(running_mean_opt.value()); // TODO: This is not the formula used in PyTorch, is this OK? Seems more robust // float besselCorrectionTerm = float(N) / std::max(N - 1.0f, 1.0f); float besselCorrectionTerm = float(N) / float(N - 1.0f); MPSGraphTensor* besselConstantTensor = [mpsGraph constantWithScalar:(double)besselCorrectionTerm shape:@[ @1 ] - dataType:input_mps_dtype]; - MPSGraphTensor* unbiasedVarianceTensor = [mpsGraph multiplicationWithPrimaryTensor:batchVarianceTensor - secondaryTensor:besselConstantTensor - name:nil]; + dataType:running_mean_dtype]; + MPSGraphTensor* unbiasedVarianceTensor = + [mpsGraph multiplicationWithPrimaryTensor:castMPSTensor(mpsGraph, batchVarianceTensor, running_mean_dtype) + secondaryTensor:besselConstantTensor + name:nil]; MPSGraphTensor* momentumTensor = [mpsGraph constantWithScalar:(double)momentum shape:@[ @1 ] - dataType:input_mps_dtype]; + dataType:running_mean_dtype]; MPSGraphTensor* oneMinusMomentum = [mpsGraph constantWithScalar:(double)(1.0 - momentum) shape:@[ @1 ] - dataType:input_mps_dtype]; + dataType:running_mean_dtype]; // Compute updated running mean - MPSGraphTensor* scaledBatchMean = [mpsGraph multiplicationWithPrimaryTensor:batchMeanTensor - secondaryTensor:momentumTensor - name:nil]; + MPSGraphTensor* scaledBatchMean = + [mpsGraph multiplicationWithPrimaryTensor:castMPSTensor(mpsGraph, batchMeanTensor, running_mean_dtype) + secondaryTensor:momentumTensor + name:nil]; MPSGraphTensor* scaledRunningMean = [mpsGraph multiplicationWithPrimaryTensor:runningMeanTensor secondaryTensor:oneMinusMomentum name:nil]; @@ -278,12 +303,16 @@ Check if running mean exists (maybe do this check before making graph) varTensor = saveVarTensor; } + // Cast weight and bias to input dtype if needed (mixed-precision support) + MPSGraphTensor* gammaTensor = has_weight ? castMPSTensor(mpsGraph, weightTensor, input_mps_dtype) : nil; + MPSGraphTensor* betaTensor = has_bias ? castMPSTensor(mpsGraph, biasTensor, input_mps_dtype) : nil; + // Compute output of batch norm MPSGraphTensor* outputTensor = [mpsGraph normalizationWithTensor:inputTensor meanTensor:saveMeanTensor varianceTensor:varTensor - gammaTensor:weightTensor - betaTensor:biasTensor + gammaTensor:gammaTensor + betaTensor:betaTensor epsilon:(float)epsilon name:nil]; @@ -385,7 +414,8 @@ Check if running mean exists (maybe do this check before making graph) bool train, double momentum, double epsilon) { - const auto memory_format = self.suggest_memory_format(); + // See https://github.com/pytorch/pytorch/issues/180984 + const auto memory_format = self.suggest_memory_format(/*channels_last_strides_exact_match=*/true); auto output = at::empty(self.sizes(), self.scalar_type(), std::nullopt, kMPS, std::nullopt, memory_format); @@ -533,11 +563,32 @@ Check if running mean exists (maybe do this check before making graph) bool train, double epsilon, std::array grad_input_mask) { + // Flatten 5D to 4D (see batch_norm_mps_out for rationale). + if (input.dim() == 5) { + auto input_4d = + input.contiguous().reshape({input.size(0), input.size(1), input.size(2) * input.size(3), input.size(4)}); + auto grad_out_4d = grad_out.contiguous().reshape(input_4d.sizes()); + auto [gi, gw, gb] = batch_norm_backward_mps(grad_out_4d, + input_4d, + weight_opt, + running_mean_opt, + running_var_opt, + save_mean_opt, + save_var_opt, + train, + epsilon, + grad_input_mask); + if (gi.defined()) + gi = gi.reshape(input.sizes()); + return std::make_tuple(std::move(gi), std::move(gw), std::move(gb)); + } + Tensor grad_input; Tensor grad_weight; Tensor grad_bias; - const auto memory_format = input.suggest_memory_format(); + // See https://github.com/pytorch/pytorch/issues/180984 + const auto memory_format = input.suggest_memory_format(/*channels_last_strides_exact_match=*/true); if (grad_input_mask[0]) { grad_input = at::empty(input.sizes(), input.scalar_type(), std::nullopt, kMPS, std::nullopt, memory_format); @@ -610,7 +661,9 @@ Check if running mean exists (maybe do this check before making graph) NSString* ns_shape_key = [[input_shape valueForKey:@"description"] componentsJoinedByString:@","]; - std::string key = fmt::format("batch_norm_backward_mps:{}:{}:{}:{}:{}:{}:{}:{}", + auto input_mps_dtype = getMPSDataType(input); + auto weight_mps_dtype = has_weight ? getMPSDataType(weight_opt.value()) : input_mps_dtype; + std::string key = fmt::format("batch_norm_backward_mps:{}:{}:{}:{}:{}:{}:{}:{}:{}", get_mem_string(memory_format), epsilon, train, @@ -618,8 +671,8 @@ Check if running mean exists (maybe do this check before making graph) has_weight, [ns_shape_key UTF8String], c10::Join(",", grad_input_mask), - getMPSTypeString(input)); - auto input_mps_dtype = getMPSDataType(input); + getMPSTypeString(input), + has_weight ? getMPSTypeString(weight_opt.value()) : "none"); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { // NCHW - Channels dim is 1 int channelsDim = 1; @@ -628,8 +681,11 @@ Check if running mean exists (maybe do this check before making graph) // Shape is the ORIGINAL NCHW shape auto gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(grad_out), input_shape_readonly); MPSGraphTensor* weightTensor = nil; - if (has_weight) - weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(weight_opt.value()), new_mean_shape); + MPSGraphTensor* weightTensorCasted = nil; + if (has_weight) { + weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_mps_dtype, new_mean_shape); + weightTensorCasted = castMPSTensor(mpsGraph, weightTensor, input_mps_dtype); + } MPSGraphTensor* runningMeanTensor = nil; MPSGraphTensor* runningVarTensor = nil; if (has_running_mean) { @@ -698,7 +754,7 @@ Check if running mean exists (maybe do this check before making graph) sourceTensor:inputTensor meanTensor:saveMeanTensor varianceTensor:revertSaveVarTensor - gammaTensor:weightTensor + gammaTensor:weightTensorCasted gammaGradientTensor:gradWeightTensor betaGradientTensor:gradBiasTensor reductionAxes:axes @@ -766,7 +822,7 @@ Check if running mean exists (maybe do this check before making graph) gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:unitTensor secondaryTensor:rsqrtTensor name:nil]; if (has_weight) gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradInputTensor - secondaryTensor:weightTensor + secondaryTensor:weightTensorCasted name:nil]; gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:gradInputTensor secondaryTensor:gradOutputTensor @@ -778,11 +834,13 @@ Check if running mean exists (maybe do this check before making graph) gradWeightTensor = [mpsGraph reshapeTensor:gradWeightTensor withShape:@[ input_shape_readonly[channelsDim] ] name:nil]; + gradWeightTensor = castMPSTensor(mpsGraph, gradWeightTensor, weight_mps_dtype); } if (grad_input_mask[2]) { gradBiasTensor = [mpsGraph reshapeTensor:gradBiasTensor withShape:@[ input_shape_readonly[channelsDim] ] name:nil]; + gradBiasTensor = castMPSTensor(mpsGraph, gradBiasTensor, weight_mps_dtype); } MPSGraphTensor* gradInputTensorFinal = nil; @@ -944,7 +1002,7 @@ Check if running mean exists (maybe do this check before making graph) } mean = mean.view(stat_shape); rstd = rstd.view(stat_shape); - return std::make_tuple(out, mean, rstd); + return std::make_tuple(std::move(out), std::move(mean), std::move(rstd)); } std::tuple layer_norm_backward_mps(const Tensor& grad_out, diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index 936ab93e65b7c..96fd799889473 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -69,7 +69,10 @@ static void pool2d_template(const Tensor& input, const bool is_backward_pass = grad_output.defined(); const bool has_indices = indices.defined(); const bool has_divisor = divisor_override.has_value() && divisor_override.value() != 0; - const auto suggested_memory_format = input.suggest_memory_format(); + // Use exact-match: a channel-slice of a channels-last tensor has CL-like + // strides but is not NHWC-packed, so the raw-buffer NHWC path would misread + // it. See https://github.com/pytorch/pytorch/issues/180984 + const auto suggested_memory_format = input.suggest_memory_format(/*channels_last_strides_exact_match=*/true); // for max_pool2d_with_indices() we cannot pass ChannelsLast (i.e., NHWC) to 'desc.dataLayout' in MPSGraph. // Because the returned indices will be selected based on NHWC memory layout which will // be incompatible with the PyTorch's global NCHW layout. diff --git a/aten/src/ATen/native/mps/operations/RMSNorm.mm b/aten/src/ATen/native/mps/operations/RMSNorm.mm index 7948b5acd8e93..33348c6cbff8e 100644 --- a/aten/src/ATen/native/mps/operations/RMSNorm.mm +++ b/aten/src/ATen/native/mps/operations/RMSNorm.mm @@ -55,7 +55,7 @@ if (N <= LOOPED_LIMIT) { size_t threadgroup_needed = (N + N_READS - 1) / N_READS; size_t simds_needed = (threadgroup_needed + SIMD_SIZE - 1) / SIMD_SIZE; - size_t threadgroup_size = SIMD_SIZE * simds_needed; + threadgroup_size = SIMD_SIZE * simds_needed; assert(threadgroup_size <= maxThreadsPerGroup); } size_t n_threads = M * threadgroup_size; diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm index e634eefee2058..d38722cef2fe8 100644 --- a/aten/src/ATen/native/mps/operations/ReduceOps.mm +++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm @@ -3,8 +3,10 @@ #include #include #include +#include #include #include +#include #include #ifndef AT_PER_OPERATOR_HEADERS @@ -20,7 +22,7 @@ #include #include #include -#include +#include #include #include #include @@ -28,7 +30,6 @@ #include #include #include -#include #include #include #include @@ -40,10 +41,13 @@ #endif namespace at::native { -namespace mps { -typedef MPSGraphTensor* (^NormOpBlock)(MPSBinaryCachedGraph*, MPSGraphTensor*, MPSGraphTensor*); -#define NormOpFn(graph, primary, secondary) \ - MPSGraphTensor*(MPSBinaryCachedGraph * graph, MPSGraphTensor * primary, MPSGraphTensor * secondary) +using namespace mps; + +#ifndef PYTORCH_JIT_COMPILE_SHADERS +static auto& lib = MetalShaderLibrary::getBundledLibrary(); +#else +#include +#endif enum StdVarType { STANDARD_VARIANCE, STANDARD_DEVIATION }; @@ -52,12 +56,9 @@ MIN, AMAX, AMIN, - SUM, PROD, MEAN, - COUNT_NONZERO, TRACE, - NANSUM, }; static void set_apparent_shapes(NSMutableArray*& apparent_out_shape, @@ -201,11 +202,6 @@ static void reduction_out_mps(const Tensor& input_t, case MPSReductionType::MEAN: output_t.fill_(std::numeric_limits::quiet_NaN()); break; - case MPSReductionType::SUM: - case MPSReductionType::NANSUM: - case MPSReductionType::COUNT_NONZERO: - output_t.zero_(); - break; case MPSReductionType::AMAX: case MPSReductionType::AMIN: case MPSReductionType::MAX: @@ -246,18 +242,10 @@ static void reduction_out_mps(const Tensor& input_t, MPSGraphTensor* castOutputTensor = nil; - if (reduction_type == MPSReductionType::SUM) { - castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor axes:wrappedAxes name:nil]; - } else if (reduction_type == MPSReductionType::PROD) { + if (reduction_type == MPSReductionType::PROD) { castOutputTensor = [mpsGraph reductionProductWithTensor:castInputTensor axes:wrappedAxes name:nil]; } else if (reduction_type == MPSReductionType::MEAN) { castOutputTensor = [mpsGraph meanOfTensor:castInputTensor axes:wrappedAxes name:nil]; - } else if (reduction_type == MPSReductionType::COUNT_NONZERO) { - MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0 dataType:castInputTensor.dataType]; - - MPSGraphTensor* nonZeros = [mpsGraph notEqualWithPrimaryTensor:castInputTensor secondaryTensor:zeros name:nil]; - - castOutputTensor = [mpsGraph reductionSumWithTensor:nonZeros axes:wrappedAxes name:nil]; } else if (reduction_type == MPSReductionType::AMAX) { castOutputTensor = [mpsGraph reductionMaximumPropagateNaNWithTensor:castInputTensor axes:wrappedAxes name:nil]; } else if (reduction_type == MPSReductionType::AMIN) { @@ -268,23 +256,6 @@ static void reduction_out_mps(const Tensor& input_t, numUpper:0 name:nil]; castOutputTensor = [mpsGraph reductionSumWithTensor:bandPartWithTensor axes:@[ @0, @1 ] name:nil]; - } else if (reduction_type == MPSReductionType::NANSUM) { - // Integral types cannot contain NaN, so just do regular sum - if (([castInputTensor dataType] & MPSDataTypeFloatBit) == 0) { - castOutputTensor = [mpsGraph reductionSumWithTensor:castInputTensor axes:wrappedAxes name:nil]; - } else { - // Create a 0 tensor of the same shape as inputTensor - auto zeros = [mpsGraph constantWithScalar:0.0 dataType:castInputTensor.dataType]; - // Find NaNs - auto nanMask = [mpsGraph isNaNWithTensor:castInputTensor name:nil]; - // Replace NaNs with 0 - auto nanReplaced = [mpsGraph selectWithPredicateTensor:nanMask - truePredicateTensor:zeros - falsePredicateTensor:castInputTensor - name:nil]; - // Sum - castOutputTensor = [mpsGraph reductionSumWithTensor:nanReplaced axes:wrappedAxes name:nil]; - } } MPSGraphTensor* outputTensor = castOutputTensor; @@ -303,143 +274,58 @@ static void reduction_out_mps(const Tensor& input_t, } } -static void impl_func_norm_mps(const Tensor& input_tensor, - const Tensor& other_tensor, - const OptionalScalarRef& opt_p, - IntArrayRef dim, - bool keepdim, - std::optional opt_dtype, - const Tensor& output_t, - bool cdist = false, - std::optional input_broadcasted_shape = std::nullopt, - NormOpBlock normOpBlock = nullptr) { - auto p = opt_p.has_value() ? opt_p.get().to() : Scalar(2.0).to(); - if (input_tensor.numel() == 0) { - output_t.fill_((p < 0) ? INFINITY : 0); - return; - } - - auto input_t = (input_tensor.sizes().size() == 0) ? input_tensor.view({1}) : input_tensor; - auto in_dtype = opt_dtype.value_or(input_tensor.scalar_type()); - auto mps_input_dtype = getMPSDataType(in_dtype); - TORCH_CHECK(!input_tensor.is_complex(), "norm ops are not supported for complex yet"); - - IntArrayRef input_shape = cdist ? input_broadcasted_shape.value() : input_t.sizes(); - - for (const auto dim_val : dim) { - auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size()); - TORCH_CHECK(wrap_dim < static_cast(input_shape.size()), - "norm_out_mps: reduction dim must be in the range of input shape") - } - - auto reciprocal_p = 1 / p; - bool pIsZero = (p == 0.0); - bool pIsPosInf = (p == std::numeric_limits::infinity()); - bool pIsNegInf = (p == -std::numeric_limits::infinity()); +static void norm_kernel_mps(TensorIterator& iter, const Scalar& p_scalar) { + const Tensor& output = iter.output(0); + const Tensor& input = iter.input(0); + auto p = p_scalar.to(); - int64_t num_input_dims = input_shape.size(); - int64_t num_reduce_dims = dim.size(); - int64_t num_output_dims; - - // For output shape calculation, assume that keepdim is true - num_output_dims = num_input_dims; - NSMutableArray* apparent_output_shape = nil; - NSMutableArray* apparent_input_shape = nil; - - // Reduction axes - NSMutableArray* axes; - set_axes(axes, num_reduce_dims, dim, input_shape.size()); - - set_apparent_shapes(apparent_output_shape, apparent_input_shape, num_reduce_dims, num_output_dims, input_shape, axes); - - NSArray* wrappedAxes = getTensorAxes(input_shape, dim); - if (cdist) { - apparent_input_shape = [getMPSShape(input_tensor.sizes()) mutableCopy]; - apparent_output_shape = [getMPSShape(output_t.sizes()) mutableCopy]; + if (input.numel() == 0) { + output.fill_((p < 0) ? INFINITY : 0); + return; } - if (output_t.numel() == 0) { + if (output.numel() == 0) { return; } - auto stream = getCurrentMPSStream(); - @autoreleasepool { - NSString* ns_key = [[wrappedAxes valueForKey:@"description"] componentsJoinedByString:@","]; - std::string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0"; - std::string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : getTensorsStringKey({input_t}); - std::string key = std::string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + std::to_string(p) + - ":" + keepdim_info + ":" + toString(in_dtype); + // Number of input elements that are reduced into one output element + uint32_t reduction_size = input.numel() / output.numel(); - auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor); + TORCH_INTERNAL_ASSERT(output.dim() == input.dim()); - if (cdist) { - newCachedGraph->otherTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, other_tensor); - } - - MPSGraphTensor* inputTensor = cdist - ? normOpBlock(newCachedGraph, newCachedGraph->inputTensor_, newCachedGraph->otherTensor_) - : newCachedGraph->inputTensor_; - - if (opt_dtype.has_value()) { - inputTensor = castMPSTensor(mpsGraph, inputTensor, mps_input_dtype); - } - - MPSGraphTensor* outputTensor; - - if (pIsZero) { - MPSGraphTensor* zeros = [mpsGraph constantWithScalar:0.0 dataType:mps_input_dtype]; - MPSGraphTensor* ones = [mpsGraph constantWithScalar:1.0 dataType:mps_input_dtype]; - MPSGraphTensor* nonZeros = [mpsGraph selectWithPredicateTensor:inputTensor - truePredicateTensor:ones - falsePredicateTensor:zeros - name:nil]; - outputTensor = [mpsGraph reductionSumWithTensor:nonZeros axes:wrappedAxes name:nil]; - } else if (pIsPosInf) { - MPSGraphTensor* absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; - outputTensor = [mpsGraph reductionMaximumWithTensor:absoluteTensor axes:wrappedAxes name:nil]; - } else if (pIsNegInf) { - MPSGraphTensor* absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; - outputTensor = [mpsGraph reductionMinimumWithTensor:absoluteTensor axes:wrappedAxes name:nil]; - } else { - MPSGraphTensor* absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor name:nil]; - - MPSGraphTensor* powerValTensor = [mpsGraph constantWithScalar:p dataType:mps_input_dtype]; - - MPSGraphTensor* reciprocalPowerValTensor = [mpsGraph constantWithScalar:reciprocal_p dataType:mps_input_dtype]; + NormParams params; - MPSGraphTensor* powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor - secondaryTensor:powerValTensor - name:nil]; + params.ndim = input.dim(); + params.p = static_cast(p); + params.reduction_size = reduction_size; - MPSGraphTensor* reductionSumTensor = [mpsGraph reductionSumWithTensor:powerTensor axes:wrappedAxes name:nil]; - - outputTensor = [mpsGraph powerWithPrimaryTensor:reductionSumTensor - secondaryTensor:reciprocalPowerValTensor - name:nil]; - } + for (const auto dim_idx : c10::irange(input.dim())) { + params.input_sizes[dim_idx] = input.size(dim_idx); + params.input_strides[dim_idx] = input.stride(dim_idx); + params.output_sizes[dim_idx] = output.size(dim_idx); + params.output_strides[dim_idx] = output.stride(dim_idx); + } - if (cdist) { - outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:getMPSShape(output_t) name:nil]; - } + MPSStream* stream = getCurrentMPSStream(); - newCachedGraph->outputTensor_ = outputTensor; - }); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + id compute_encoder = stream->commandEncoder(); + auto pipeline_state = lib.getPipelineStateForFunc( + fmt::format("norm_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output))); + getMPSProfiler().beginProfileKernel(pipeline_state, "norm", {input}); + [compute_encoder setComputePipelineState:pipeline_state]; + mtl_setArgs(compute_encoder, input, output, params); - auto otherPlaceholder = Placeholder(); - auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t); - auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape); + auto threads_per_group = std::min(MAX_THREADGROUP_SIZE, reduction_size); + uint32_t num_threads = output.numel() * threads_per_group; - NSMutableDictionary* feeds = [NSMutableDictionary dictionary]; - feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData(); + [compute_encoder dispatchThreads:MTLSizeMake(num_threads, 1, 1) + threadsPerThreadgroup:MTLSizeMake(threads_per_group, 1, 1)]; - if (cdist) { - otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other_tensor); - feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData(); + getMPSProfiler().endProfileKernel(pipeline_state); } - - runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder); - } + }); } static Tensor std_var_common_impl_mps(const Tensor& input_t, @@ -972,39 +858,246 @@ static void argmax_argmin_out_mps(const Tensor& input_t, } } -} // namespace mps +// Shared implementation for sum/nansum/count_nonzero/mean Metal kernels. +// `kernel_prefix` is "sum_", "nansum_" or "count_nonzero_" — selects the +// kernel variant to dispatch. `divisor` > 0 divides the accumulator (in +// opmath_t) before casting to output, enabling fused mean without losing the +// fp32 accumulation precision for fp16/bf16/half2 outputs. +static void sum_nansum_kernel_mps(TensorIterator& iter, const std::string& kernel_prefix, float divisor = 0.0f) { + const Tensor& output = iter.output(0); + const Tensor& input = iter.input(0); + + if (input.numel() == 0) { + output.zero_(); + return; + } -using namespace mps; + if (output.numel() == 0) { + return; + } + + uint32_t reduction_size = input.numel() / output.numel(); + + // TensorIterator ensures input and output have matching ndim + // (reduced dims have size 1 in output) + TORCH_INTERNAL_ASSERT(output.dim() == input.dim()); + + constexpr uint32_t NCHAINS = SUM_NCHAINS; + + auto kernel_name = + fmt::format("{}reduction_{}_{}", kernel_prefix, scalarToMetalTypeString(input), scalarToMetalTypeString(output)); + + MPSStream* stream = getCurrentMPSStream(); + + // For large full reductions (output is scalar), use multi-TG with a + // two-pass approach: first pass splits work across num_groups TGs writing + // partial sums, second pass reduces the partials to the final scalar. + if (output.numel() == 1 && reduction_size > MAX_THREADGROUP_SIZE * NCHAINS) { + auto num_groups = std::min(512u, c10::metal::ceil_div(reduction_size, MAX_THREADGROUP_SIZE * NCHAINS)); + + // elems_per_group * num_groups must equal reduction_size exactly, + // otherwise pass 1's last TG reads past the input's logical end. + // Reduce num_groups down to a divisor of reduction_size (falling back + // to 1 is always safe — the inner loop still parallelizes via threads). + while (num_groups > 1 && reduction_size % num_groups != 0) { + num_groups--; + } + + auto partials = at::empty({num_groups}, output.options()); + const auto elems_per_group = reduction_size / num_groups; + + auto out_metal = scalarToMetalTypeString(output); + auto p1_kernel = fmt::format("{}reduction_{}_{}", kernel_prefix, scalarToMetalTypeString(input), out_metal); + // Pass 2 combines partials by summing them regardless of pass-1 mode. + // For count_nonzero the partials are already per-block counts (long); + // counting them again would be wrong, so always use "sum_" here. + auto p2_kernel = fmt::format("sum_reduction_{}_{}", out_metal, out_metal); + + // Model as 2D: input is [num_groups, elems_per_group], reduce dim=1 + // Dim 0 (non-reduced): size=num_groups, input_stride=elems_per_group, output_stride=1 + // Dim 1 (reduced): size=elems_per_group, input_stride=1 + NormParams params1; + params1.ndim = 2; + params1.p = 0; + params1.reduction_size = elems_per_group; + params1.input_sizes[0] = num_groups; + params1.input_strides[0] = elems_per_group; + params1.output_sizes[0] = num_groups; + params1.output_strides[0] = 1; + params1.input_sizes[1] = elems_per_group; + params1.input_strides[1] = 1; + params1.output_sizes[1] = 1; + params1.output_strides[1] = 0; + + // Pass 2: partials[num_groups] -> output[1], reduce dim=0. + // divisor applies here (not on pass 1), so pass 2 produces + // accumulator/divisor before the final cast to output dtype. + NormParams params2; + params2.ndim = 1; + params2.p = divisor; + params2.reduction_size = num_groups; + params2.input_sizes[0] = num_groups; + params2.input_strides[0] = 1; + params2.output_sizes[0] = 1; + params2.output_strides[0] = 0; + + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + id compute_encoder = stream->commandEncoder(); + + // Pass 1: input -> partials + auto ps1 = lib.getPipelineStateForFunc(p1_kernel); + getMPSProfiler().beginProfileKernel(ps1, "sum_reduction_pass1", {input}); + [compute_encoder setComputePipelineState:ps1]; + mtl_setArgs(compute_encoder, input, partials, params1); + auto tpg1 = std::min(MAX_THREADGROUP_SIZE, elems_per_group); + [compute_encoder dispatchThreads:MTLSizeMake(num_groups * tpg1, 1, 1) + threadsPerThreadgroup:MTLSizeMake(tpg1, 1, 1)]; + getMPSProfiler().endProfileKernel(ps1); + + // Pass 2: partials -> output + auto ps2 = lib.getPipelineStateForFunc(p2_kernel); + getMPSProfiler().beginProfileKernel(ps2, "sum_reduction_pass2", {partials}); + [compute_encoder setComputePipelineState:ps2]; + mtl_setArgs(compute_encoder, partials, output, params2); + auto tpg2 = std::min(MAX_THREADGROUP_SIZE, num_groups); + [compute_encoder dispatchThreads:MTLSizeMake(tpg2, 1, 1) threadsPerThreadgroup:MTLSizeMake(tpg2, 1, 1)]; + getMPSProfiler().endProfileKernel(ps2); + } + }); + return; + } + + // Detect outer-dim (non-innermost) reduction on contiguous 2D tensor. + // For this case, use a specialized kernel with coalesced column reads. + // Condition: exactly one reduced dim, it's not the last dim, input is contiguous. + { + int num_reduced = 0; + int reduced_dim = -1; + for (int64_t d = 0; d < input.dim(); d++) { + if (input.size(d) != output.size(d)) { + num_reduced++; + reduced_dim = d; + } + } + bool is_outer_reduction = (num_reduced == 1 && reduced_dim < input.dim() - 1 && input.is_contiguous()); + bool is_inner_reduction = (num_reduced == 1 && reduced_dim == input.dim() - 1 && input.is_contiguous()); + + if (is_outer_reduction && reduced_dim == 0 && output.is_contiguous()) { + uint32_t M = input.size(0); + uint32_t N = input.numel() / M; + + auto outer_kernel = fmt::format( + "{}reduction_outer_{}_{}", kernel_prefix, scalarToMetalTypeString(input), scalarToMetalTypeString(output)); + constexpr uint32_t TG_X = 32, TG_Y = 32; + const auto num_tg_x = c10::metal::ceil_div(N, TG_X); + + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + id compute_encoder = stream->commandEncoder(); + auto ps = lib.getPipelineStateForFunc(outer_kernel); + getMPSProfiler().beginProfileKernel(ps, "sum_reduction_outer", {input}); + struct { + uint32_t M, N, out_stride; + } sizes_s = {M, N, 1}; + [compute_encoder setComputePipelineState:ps]; + mtl_setArgs(compute_encoder, input, output, sizes_s, divisor); + [compute_encoder dispatchThreads:MTLSizeMake(num_tg_x * TG_X, TG_Y, 1) + threadsPerThreadgroup:MTLSizeMake(TG_X, TG_Y, 1)]; + getMPSProfiler().endProfileKernel(ps); + } + }); + return; + } + + if (is_inner_reduction && output.is_contiguous()) { + // M = product of all non-reduced dims, N = size of last dim + uint32_t N = input.size(input.dim() - 1); + uint32_t M = input.numel() / N; + + auto inner_kernel = fmt::format( + "{}reduction_inner_{}_{}", kernel_prefix, scalarToMetalTypeString(input), scalarToMetalTypeString(output)); + // Pack multiple rows per TG: each SIMD group (32 threads) handles one row + constexpr uint32_t TG_SIZE = 256; // 8 SIMD groups = 8 rows per TG + constexpr uint32_t rows_per_tg = TG_SIZE / 32; + const auto num_tgs = c10::metal::ceil_div(M, rows_per_tg); + + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + id compute_encoder = stream->commandEncoder(); + auto ps = lib.getPipelineStateForFunc(inner_kernel); + getMPSProfiler().beginProfileKernel(ps, "sum_reduction_inner", {input}); + struct { + uint32_t M, N; + } sizes_s = {M, N}; + [compute_encoder setComputePipelineState:ps]; + mtl_setArgs(compute_encoder, input, output, sizes_s, divisor); + [compute_encoder dispatchThreads:MTLSizeMake(num_tgs * TG_SIZE, 1, 1) + threadsPerThreadgroup:MTLSizeMake(TG_SIZE, 1, 1)]; + getMPSProfiler().endProfileKernel(ps); + } + }); + return; + } + } + + NormParams params; + params.ndim = input.dim(); + params.p = divisor; + params.reduction_size = reduction_size; + + for (const auto dim_idx : c10::irange(input.dim())) { + params.input_sizes[dim_idx] = input.size(dim_idx); + params.input_strides[dim_idx] = input.stride(dim_idx); + params.output_sizes[dim_idx] = output.size(dim_idx); + params.output_strides[dim_idx] = output.stride(dim_idx); + } + + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + id compute_encoder = stream->commandEncoder(); + auto pipeline_state = lib.getPipelineStateForFunc(kernel_name); + getMPSProfiler().beginProfileKernel(pipeline_state, "sum_reduction", {input}); + [compute_encoder setComputePipelineState:pipeline_state]; + mtl_setArgs(compute_encoder, input, output, params); + + auto threads_per_group = std::min(MAX_THREADGROUP_SIZE, reduction_size); + uint32_t num_threads = output.numel() * threads_per_group; + + [compute_encoder dispatchThreads:MTLSizeMake(num_threads, 1, 1) + threadsPerThreadgroup:MTLSizeMake(threads_per_group, 1, 1)]; + + getMPSProfiler().endProfileKernel(pipeline_state); + } + }); +} -TORCH_IMPL_FUNC(sum_out_mps) -(const Tensor& input_t, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional dtype, - const Tensor& output_t) { - reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::SUM, "sum_out_mps"); +static void sum_kernel_mps(TensorIterator& iter) { + sum_nansum_kernel_mps(iter, "sum_"); } -Tensor& nansum_out_mps(const Tensor& self, - OptionalIntArrayRef dim, - bool keepdim, - std::optional opt_dtype, - Tensor& result) { - TORCH_CHECK(!c10::isComplexType(self.scalar_type()), "nansum on MPS does not support complex inputs"); - if (c10::isIntegralType(self.scalar_type(), true)) { - return at::sum_out(result, self, dim, keepdim, opt_dtype); +static void nansum_kernel_mps(TensorIterator& iter) { + auto in_dtype = iter.input(0).scalar_type(); + bool is_float = c10::isFloatingType(in_dtype) || c10::isComplexType(in_dtype); + sum_nansum_kernel_mps(iter, is_float ? "nansum_" : "sum_"); +} + +static void mean_kernel_mps(TensorIterator& iter) { + auto output = iter.output(0); + auto input = iter.input(0); + if (input.numel() == 0 || output.numel() == 0) { + sum_nansum_kernel_mps(iter, "sum_"); + return; } - ScalarType dtype = get_dtype_from_result(result, opt_dtype); - const auto mask = make_dim_mask(dim, self.dim()); - resize_reduction_result(result, self, mask, keepdim, dtype); - reduction_out_mps(self, dim, keepdim, dtype, result, MPSReductionType::NANSUM, "nansum_out_mps"); - return result; + int64_t reduction_size = input.numel() / output.numel(); + // Fused divide: the sum kernel divides the accumulator (in opmath_t) + // before casting to output, so fp32 accumulation precision is preserved + // for fp16/bf16/half2 without an intermediate tensor. + sum_nansum_kernel_mps(iter, "sum_", static_cast(reduction_size)); } -Tensor nansum_mps(const Tensor& self, OptionalIntArrayRef dim, bool keepdim, std::optional opt_dtype) { - ScalarType dtype = get_dtype_from_self(self, opt_dtype, true); - Tensor result = create_reduction_result(self, dim, keepdim, dtype); - return nansum_out_mps(self, dim, keepdim, dtype, result); +static void count_nonzero_kernel_mps(TensorIterator& iter) { + sum_nansum_kernel_mps(iter, "count_nonzero_"); } Tensor trace_mps(const Tensor& self) { @@ -1076,67 +1169,11 @@ Tensor prod_mps(const Tensor& self, std::optional opt_dtype) { } Tensor count_nonzero_mps(const Tensor& self, IntArrayRef dims) { - int64_t shape_size = dims.size() == 0 ? 0 : self.sizes().size() - dims.size(); - int64_t out_shape = std::max(shape_size, 0LL); - std::vector output_shape(out_shape); - std::vector dims_vec = dims.vec(); - std::for_each(dims_vec.begin(), dims_vec.end(), [&](int64_t& n) { n = maybe_wrap_dim(n, self); }); - - if (out_shape != 0) { - int out_dim = 0; - for (const auto self_dim : c10::irange((self.sizes().size()))) { - if (std::find(dims_vec.begin(), dims_vec.end(), self_dim) == dims_vec.end()) { - output_shape[out_dim++] = (self.sizes()[self_dim]); - } - } - } - - Tensor output_t = - at::empty(IntArrayRef(output_shape), ScalarType::Long, std::nullopt, kMPS, std::nullopt, std::nullopt); - reduction_out_mps(self, - dims, - false, - self.scalar_type(), - const_cast(output_t), - MPSReductionType::COUNT_NONZERO, - "count_nonzero_mps"); - - return output_t; -} - -TORCH_IMPL_FUNC(mean_out_mps) -(const Tensor& input_t, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional dtype, - const Tensor& output_t) { - reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::MEAN, "mean_out_mps"); -} - -TORCH_IMPL_FUNC(norm_out_mps) -(const Tensor& self, const OptionalScalarRef opt_p, IntArrayRef dim, bool keepdim, const Tensor& result) { - impl_func_norm_mps(self, self, opt_p, dim, keepdim, std::nullopt, result, /*cdist=*/false); -} - -TORCH_IMPL_FUNC(norm_dtype_out_mps) -(const Tensor& self, - const OptionalScalarRef opt_p, - IntArrayRef dim, - bool keepdim, - ScalarType dtype, - const Tensor& result) { - impl_func_norm_mps(self, self, opt_p, dim, keepdim, dtype, result, /*cdist=*/false); -} - -TORCH_IMPL_FUNC(linalg_vector_norm_out_mps) -(const Tensor& self, - const Scalar& scalar_ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype, - const Tensor& result) { - impl_func_norm_mps( - self, self, scalar_ord, opt_dim.value_or(IntArrayRef{}), keepdim, opt_dtype, result, /*cdist=*/false); + Tensor result = create_reduction_result(self, dims, /*keepdim=*/false, ScalarType::Long); + auto iter = + make_reduction("count_nonzero_mps", result, self, dims, /*keepdim=*/false, self.scalar_type(), ScalarType::Long); + count_nonzero_kernel_mps(iter); + return result; } Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, std::optional compute_mode) { @@ -1149,96 +1186,20 @@ Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, st x2.size(-1)); TORCH_CHECK( at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type()); - auto device1 = x1.device().type(); TORCH_CHECK( at::isFloatingType(x2.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type()); - auto device2 = x2.device().type(); TORCH_CHECK(p >= 0, "cdist only supports non-negative p values"); - TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2); - TORCH_CHECK(x1.is_mps() && (x1.get_device() == x2.get_device()), - "device of X1 (", - x1.get_device(), - ") must match device of X2 (", - x2.get_device(), - ")"); - - int64_t c1 = x1.size(-1); - int64_t c2 = x2.size(-1); - - auto dim1 = x1.dim(); - auto dim2 = x2.dim(); + int64_t mode = compute_mode.value_or(0); TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode); - int64_t r1 = x1.size(-2); - int64_t r2 = x2.size(-2); - - // For batch calculation we expand all dimensions(except the last two) to one, with size that equals to product of - // them. The last two dimensions will stay the same - IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2); - IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2); - std::vector expand_batch_portion = infer_size(batch_tensor1, batch_tensor2); - std::vector tensor1_expand_size(expand_batch_portion); - tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1}); - std::vector tensor2_expand_size(expand_batch_portion); - tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2}); - - const int64_t expand_batch_product = c10::multiply_integers(expand_batch_portion); - std::vector tensor1_view{expand_batch_product, r1, c1}; - std::vector tensor2_view{expand_batch_product, r2, c2}; - - std::vector output_shape(expand_batch_portion); - output_shape.insert(output_shape.end(), {r1, r2}); + Tensor x1_ = x1.unsqueeze(-2); + Tensor x2_ = x2.unsqueeze(-3); + Tensor diff = x1_.sub(x2_); + IntArrayRef output_shape(diff.sizes().data(), diff.dim() - 1); Tensor result = at::empty(output_shape, x1.options()); + linalg_vector_norm_out(result, diff, p, makeArrayRef(-1), /*keepdim=*/false, /*dtype=*/std::nullopt); - NormOpBlock norm_op_block = ^NormOpFn(cachedGraph, x1Tensor, x2Tensor) { - MPSGraph* mpsGraph = cachedGraph->graph(); - - MPSGraphTensor* inputBroadcast = [mpsGraph broadcastTensor:x1Tensor - toShape:getMPSShape(tensor1_expand_size) - name:nil]; - MPSGraphTensor* inputBroadcastReshape = [mpsGraph reshapeTensor:inputBroadcast - withShape:getMPSShape(tensor1_view) - name:nil]; - - MPSGraphTensor* otherBroadcast = [mpsGraph broadcastTensor:x2Tensor - toShape:getMPSShape(tensor2_expand_size) - name:nil]; - MPSGraphTensor* otherBroadcastReshape = [mpsGraph reshapeTensor:otherBroadcast - withShape:getMPSShape(tensor2_view) - name:nil]; - - NSMutableArray* inputArray = [NSMutableArray arrayWithCapacity:tensor1_view[1]]; - NSMutableArray* otherArray = [NSMutableArray arrayWithCapacity:tensor2_view[1]]; - - for (const auto i : c10::irange(tensor2_view[1])) { - inputArray[i] = inputBroadcastReshape; - } - - for (const auto i : c10::irange(tensor1_view[1])) { - otherArray[i] = otherBroadcastReshape; - } - - MPSGraphTensor* inputTensorReshaped = [mpsGraph concatTensors:inputArray dimension:1 interleave:YES name:nil]; - MPSGraphTensor* otherTensorReshaped = [mpsGraph concatTensors:otherArray dimension:1 interleave:NO name:nil]; - - MPSGraphTensor* inputTensorPNorm = [mpsGraph subtractionWithPrimaryTensor:inputTensorReshaped - secondaryTensor:otherTensorReshaped - name:nil]; - return inputTensorPNorm; - }; - - IntArrayRef inputBroadcastSize = makeArrayRef(tensor1_view.data(), tensor1_view.size()); - impl_func_norm_mps(x1, - x2, - OptionalScalarRef(p), - makeArrayRef(2), - false, - std::nullopt, - result, - /*cdist=*/true, - inputBroadcastSize, - norm_op_block); return result; } @@ -1701,4 +1662,9 @@ Tensor nanmedian_mps(const Tensor& self) { return {var, mean}; } +REGISTER_DISPATCH(norm_stub, &norm_kernel_mps) +REGISTER_DISPATCH(sum_stub, &sum_kernel_mps) +REGISTER_DISPATCH(nansum_stub, &nansum_kernel_mps) +REGISTER_DISPATCH(mean_stub, &mean_kernel_mps) + } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/Repeat.mm b/aten/src/ATen/native/mps/operations/Repeat.mm index f350b0137b05e..7abfb0a2e4500 100644 --- a/aten/src/ATen/native/mps/operations/Repeat.mm +++ b/aten/src/ATen/native/mps/operations/Repeat.mm @@ -10,6 +10,14 @@ #include #include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#endif + namespace at::native { Tensor permute_mps(const Tensor& self, IntArrayRef dims) { @@ -36,7 +44,13 @@ Tensor repeat_mps(const Tensor& self, IntArrayRef repeats) { TORCH_CHECK(repeats.size() >= (size_t)self.dim(), "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor"); - TORCH_CHECK(!self.is_complex(), "repeat(): Not supported for complex yet!"); + + if (self.is_complex()) { + std::vector repeats_real = repeats.vec(); + repeats_real.push_back(1); + auto self_real = at::view_as_real(self); + return at::view_as_complex(repeat_mps(self_real, repeats_real)); + } // Add new leading dimensions to the tensor if the // number of target dimensions is larger than the diff --git a/aten/src/ATen/native/mps/operations/RnnOps.mm b/aten/src/ATen/native/mps/operations/RnnOps.mm index d72ead5aa8d24..cc49c6b1d6621 100644 --- a/aten/src/ATen/native/mps/operations/RnnOps.mm +++ b/aten/src/ATen/native/mps/operations/RnnOps.mm @@ -134,7 +134,7 @@ std::string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_dropout_" + std::to_string(dropout_p) + "_batch_first_" + - std::to_string(batch_first); + std::to_string(batch_first) + "_train_" + std::to_string(train); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { NSMutableArray* kernelWeightsList = [[NSMutableArray alloc] initWithCapacity:params.size()]; NSMutableArray* recurrentKernelWeightsList = diff --git a/aten/src/ATen/native/mps/operations/ScanKernel.h b/aten/src/ATen/native/mps/operations/ScanKernel.h new file mode 100644 index 0000000000000..06bedfbdd215e --- /dev/null +++ b/aten/src/ATen/native/mps/operations/ScanKernel.h @@ -0,0 +1,11 @@ +#pragma once + +namespace at::native::mps { + +void scan_simple_mps_impl( + const Tensor& self, + const Tensor& output, + int64_t dim, + const std::string& op_name); + +} // namespace at::native::mps diff --git a/aten/src/ATen/native/mps/operations/ScanKernel.mm b/aten/src/ATen/native/mps/operations/ScanKernel.mm index fce8f77e84f07..9b8a8a5d485f7 100644 --- a/aten/src/ATen/native/mps/operations/ScanKernel.mm +++ b/aten/src/ATen/native/mps/operations/ScanKernel.mm @@ -3,6 +3,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -45,7 +46,7 @@ return {static_cast(grid_x), static_cast(grid_y)}; } -static void scan_simple_mps_impl(const Tensor& self, const Tensor& output, int64_t dim, const std::string& op_name) { +void scan_simple_mps_impl(const Tensor& self, const Tensor& output, int64_t dim, const std::string& op_name) { if (output.numel() == 0) { return; } diff --git a/aten/src/ATen/native/mps/operations/ScatterGather.mm b/aten/src/ATen/native/mps/operations/ScatterGather.mm index ce65421c71c9d..b9b055b4fee6f 100644 --- a/aten/src/ATen/native/mps/operations/ScatterGather.mm +++ b/aten/src/ATen/native/mps/operations/ScatterGather.mm @@ -10,10 +10,23 @@ #include #include #include +#include #endif namespace at::native { +static Tensor maybe_expand_0_dim(const Tensor& t) { + return t.dim() == 0 ? t.view({1}) : t; +} + +static Tensor expand_index_as_real(const Tensor& index) { + auto index_view = maybe_expand_0_dim(index); + std::vector index_expanded_sizes = index_view.sizes().vec(); + index_expanded_sizes.push_back(2); + auto index_expanded = index_view.unsqueeze(-1).expand(index_expanded_sizes); + return index_expanded; +} + TORCH_IMPL_FUNC(gather_out_mps) (const Tensor& self_arg, int64_t dim, const Tensor& index, bool sparse_grad, const Tensor& output) { using namespace mps; @@ -27,7 +40,14 @@ TORCH_CHECK(!sparse_grad, "sparse_grad not supported in MPS yet") TORCH_CHECK(self.scalar_type() == output.scalar_type(), "gather(): self and output must have the same scalar type"); TORCH_CHECK(dim >= 0 && dim < self.dim(), "gather(): Indexing dim ", dim, " is out of bounds of tensor"); - TORCH_CHECK(!self.is_complex(), "gather(): Yet not supported for complex"); + + if (self.is_complex()) { + auto self_real = at::view_as_real(self); + auto index_expanded = expand_index_as_real(index); + auto output_real = at::view_as_real(maybe_expand_0_dim(output)); + structured_gather_out_mps::impl(self_real, dim, index_expanded, sparse_grad, output_real); + return; + } struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} @@ -145,7 +165,15 @@ static void scatter_mps_general(const Tensor& self_arg, TORCH_CHECK(self.scalar_type() == output.scalar_type() && output.scalar_type() == src.scalar_type(), "scatter(): self, src and output must have the same scalar type"); TORCH_CHECK(dim >= 0 && dim < self.dim(), "scatter(): Indexing dim ", dim, " is out of bounds of tensor"); - TORCH_CHECK(!self.is_complex(), "scatter(): Yet not supported for complex"); + + if (self.is_complex()) { + auto self_real = at::view_as_real(self); + auto index_expanded = expand_index_as_real(index); + auto src_real = at::view_as_real(maybe_expand_0_dim(src)); + auto output_real = at::view_as_real(maybe_expand_0_dim(output)); + scatter_mps_general(self_real, dim, index_expanded, src_real, output_real, func_name, reduce); + return; + } struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {} diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm index eb827d9f44c26..9dc318a393a5c 100644 --- a/aten/src/ATen/native/mps/operations/UnaryOps.mm +++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm @@ -3,6 +3,7 @@ #include #include #include +#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -333,8 +334,21 @@ static void cumulative_op_impl(const Tensor& self, "(original dim is ", dim, ")"); - TORCH_CHECK(!self.is_complex(), "cumulative ops are not yet supported for complex"); auto input = dtype.has_value() ? self.to(dtype.value()) : self; + if (input.is_complex()) { + if (cumulativeOpType == MPSCumulativeOpType::CUMSUM) { + auto input_real = at::view_as_real(input.dim() == 0 ? input.view({1}) : input); + auto result_real = at::view_as_real(result.dim() == 0 ? result.view({1}) : result); + return cumulative_op_impl( + input_real, wrapped_dim, std::nullopt, result_real, MPSCumulativeOpType::CUMSUM, "cumsum_out_mps"); + } else if (cumulativeOpType == MPSCumulativeOpType::CUMPROD) { + auto input_view = input.dim() == 0 ? input.view({1}) : input; + auto result_view = result.dim() == 0 ? result.view({1}) : result; + return mps::scan_simple_mps_impl(input_view, result_view, wrapped_dim, "cumprod"); + } else { + TORCH_INTERNAL_ASSERT(false); + } + } // issue #103810551: cumsum / cumprod are broken for int8, int16 and as chances for overflow are pretty high, cast to // int32 fixed in macOS 13.3 diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index a5f04b6c6e1ae..953f16875c89a 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -480,6 +480,7 @@ CompositeExplicitAutograd: _conj_physical SparseCsrCPU, SparseCsrCUDA, SparseCsrMPS, SparseCsrMeta: conj_physical_sparse_csr autogen: _conj_physical.out + tags: pointwise - func: conj_physical(Tensor self) -> Tensor variants: function, method @@ -2735,8 +2736,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: fill_ - MPS: fill_scalar_mps + CPU, CUDA, MPS: fill_ QuantizedCPU, QuantizedCUDA: fill_quantized_ Meta: fill_meta_ SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: fill_sparse_csr_ @@ -2747,8 +2747,7 @@ device_check: NoCheck # TensorIterator variants: function, method dispatch: - CPU, CUDA: fill_ - MPS: fill_tensor_mps_ + CPU, CUDA, MPS: fill_ QuantizedCPU, QuantizedCUDA: fill_quantized_ Meta: fill_meta_ NestedTensorCPU, NestedTensorHPU, NestedTensorCUDA: fill_nested_ @@ -2939,6 +2938,7 @@ dispatch: CPU: grid_sampler_2d_backward_cpu CUDA: grid_sampler_2d_backward_cuda + MPS: grid_sampler_2d_backward_mps autogen: grid_sampler_2d_backward.out # See NOTE [ grid_sample CPU fallback ] @@ -2948,6 +2948,8 @@ autogen: _grid_sampler_2d_cpu_fallback.out - func: _grid_sampler_2d_cpu_fallback_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor) + dispatch: + CompositeExplicitAutograd: _grid_sampler_2d_cpu_fallback_backward - func: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor dispatch: @@ -2963,6 +2965,7 @@ dispatch: CPU: grid_sampler_3d_backward_cpu CUDA: grid_sampler_3d_backward_cuda + MPS: grid_sampler_3d_backward_mps autogen: grid_sampler_3d_backward.out - func: hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor @@ -3296,6 +3299,42 @@ device_guard: False manual_cpp_binding: True +- func: numel(Tensor self) -> int + variants: method + device_check: NoCheck + device_guard: False + manual_cpp_binding: True + +- func: dim(Tensor self) -> int + variants: method + device_check: NoCheck + device_guard: False + manual_cpp_binding: True + +- func: get_device(Tensor self) -> int + variants: method + device_check: NoCheck + device_guard: False + manual_cpp_binding: True + +- func: storage_offset(Tensor self) -> int + variants: method + device_check: NoCheck + device_guard: False + manual_cpp_binding: True + +- func: is_contiguous(Tensor self) -> bool + variants: method + device_check: NoCheck + device_guard: False + manual_cpp_binding: True + +- func: is_contiguous.memory_format(Tensor self, MemoryFormat memory_format) -> bool + variants: method + device_check: NoCheck + device_guard: False + manual_cpp_binding: True + - func: kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor - func: kron(Tensor self, Tensor other) -> Tensor @@ -3350,11 +3389,13 @@ dispatch: CUDA: _fused_rms_norm_cuda MPS: _fused_rms_norm_mps + XPU: _fused_rms_norm_xpu CompositeImplicitAutograd: rms_norm_composite - func: _fused_rms_norm_backward(Tensor grad_out, Tensor input, int[] normalized_shape, Tensor rstd, Tensor? weight, bool[2] output_mask) -> (Tensor, Tensor) dispatch: CUDA: _fused_rms_norm_backward_cuda + XPU: _fused_rms_norm_backward_xpu - func: nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor variants: function, method @@ -3691,8 +3732,7 @@ structured_inherits: TensorIteratorBase variants: function dispatch: - CPU, CUDA: xlogy_out - MPS: xlogy_out_mps + CPU, CUDA, MPS: xlogy_out tags: pointwise - func: xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) @@ -3886,7 +3926,7 @@ device_check: NoCheck # TensorIterator structured: True dispatch: - CPU, CUDA, MTIA: aminmax_out + CPU, CUDA: aminmax_out MPS: aminmax_out_mps tags: reduction @@ -3941,7 +3981,7 @@ - func: amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) structured: True dispatch: - CPU, CUDA, MTIA: amax_out + CPU, CUDA: amax_out MPS: amax_out_mps tags: reduction @@ -4026,8 +4066,7 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: mean_out - MPS: mean_out_mps + CPU, CUDA, MPS: mean_out QuantizedCPU: mean_out_quantized_cpu tags: reduction @@ -4130,7 +4169,7 @@ - func: amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) structured: True dispatch: - CPU, CUDA, MTIA: amin_out + CPU, CUDA: amin_out MPS: amin_out_mps tags: reduction @@ -4419,7 +4458,7 @@ - func: mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: mvlgamma_out + CPU, CUDA, MPS: mvlgamma_out tags: pointwise - func: mvlgamma(Tensor self, int p) -> Tensor @@ -5381,6 +5420,7 @@ - func: selu_(Tensor(a!) self) -> Tensor(a!) device_check: NoCheck # TensorIterator + tags: pointwise - func: celu(Tensor self, Scalar alpha=1.0) -> Tensor device_check: NoCheck # TensorIterator @@ -5393,6 +5433,7 @@ dispatch: CompositeExplicitAutograd: celu_ autogen: celu.out + tags: pointwise - func: silu(Tensor self) -> Tensor structured_delegate: silu.out @@ -5413,8 +5454,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA, MTIA: silu_out - MPS: silu_out_mps + CPU, CUDA, MPS, MTIA: silu_out tags: pointwise - func: silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) @@ -5422,8 +5462,7 @@ structured_inherits: TensorIteratorBase python_module: nn dispatch: - CPU, CUDA: silu_backward_out - MPS: silu_backward_out_mps + CPU, CUDA, MPS: silu_backward_out tags: pointwise - func: silu_backward(Tensor grad_output, Tensor self) -> Tensor @@ -5980,8 +6019,7 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: sum_out - MPS: sum_out_mps + CPU, CUDA, MPS: sum_out tags: reduction - func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) @@ -5996,14 +6034,12 @@ - func: nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor variants: function, method dispatch: - CPU, CUDA: nansum - MPS: nansum_mps + CPU, CUDA, MPS: nansum tags: reduction - func: nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) dispatch: - CPU, CUDA: nansum_out - MPS: nansum_out_mps + CPU, CUDA, MPS: nansum_out tags: reduction - func: hash_tensor(Tensor self, int[1] dim=[], *, bool keepdim=False, int mode=0) -> Tensor @@ -6656,7 +6692,6 @@ dispatch: CPU, CUDA: var MPS: var_mps - MTIA: var_mtia tags: [core, reduction] - func: var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) @@ -6818,6 +6853,7 @@ dispatch: CPU: _standard_gamma_grad_cpu CUDA: _standard_gamma_grad_cuda + MPS: _standard_gamma_grad_mps autogen: _standard_gamma_grad.out - func: _standard_gamma(Tensor self, Generator? generator=None) -> Tensor @@ -6825,9 +6861,32 @@ dispatch: CPU: _s_gamma_cpu CUDA: _s_gamma_cuda + MPS: _s_gamma_mps tags: nondeterministic_seeded autogen: _standard_gamma.out +- func: _philox_key_split(Tensor key, int num_splits) -> Tensor + variants: function + dispatch: + CUDA: _philox_key_split_cuda + +- func: _philox_key_fold_in(Tensor key, int data) -> Tensor + variants: function + dispatch: + CUDA: _philox_key_fold_in_cuda + +- func: _philox_normal_(Tensor(a!) self, Tensor key, float mean=0, float std=1) -> Tensor(a!) + variants: function, method + dispatch: + CUDA: _philox_normal_cuda_ + autogen: _philox_normal, _philox_normal.out + +- func: _philox_uniform_(Tensor(a!) self, Tensor key, float low=0, float high=1) -> Tensor(a!) + variants: function, method + dispatch: + CUDA: _philox_uniform_cuda_ + autogen: _philox_uniform, _philox_uniform.out + - func: _dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor dispatch: CPU: _dirichlet_grad_cpu @@ -7016,16 +7075,14 @@ structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: norm_dtype_out - MPS: norm_dtype_out_mps + CPU, CUDA, MPS: norm_dtype_out tags: reduction - func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) structured: True device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA: norm_out - MPS: norm_out_mps + CPU, CUDA, MPS: norm_out tags: reduction # These four redispatch in their implementation, so OK to be CompositeImplicitAutograd @@ -7118,8 +7175,7 @@ device_check: NoCheck # TensorIterator variants: method, function dispatch: - CPU, CUDA: zero_ - MPS: zero_mps_ + CPU, CUDA, MPS: zero_ Meta: zero_meta_ SparseCPU, SparseCUDA, SparseMPS, SparseMeta: zero_sparse_ SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: zero_sparse_csr_ @@ -7327,6 +7383,7 @@ - func: _scaled_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor variants: function dispatch: + CPU: _scaled_mm_cpu_v2 CUDA: _scaled_mm_cuda_v2 XPU: _scaled_mm_xpu_v2 tags: needs_exact_strides @@ -7334,6 +7391,7 @@ - func: _scaled_mm_v2.out(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? bias, ScalarType? out_dtype, int[] contraction_dim=[], bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) variants: function dispatch: + CPU: _scaled_mm_cpu_v2_out CUDA: _scaled_mm_cuda_v2_out XPU: _scaled_mm_xpu_v2_out tags: needs_exact_strides @@ -9568,12 +9626,14 @@ dispatch: CPU: nonzero_static_out_cpu CUDA: nonzero_static_out_cuda + MPS: nonzero_static_out_mps - func: nonzero_static(Tensor self, *, SymInt size, int fill_value=-1) -> Tensor variants: method, function dispatch: CPU: nonzero_static_cpu CUDA: nonzero_static_cuda + MPS: nonzero_static_mps - func: nonzero_numpy(Tensor self) -> Tensor[] variants: method, function @@ -10010,8 +10070,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: lerp_Tensor - MPS: lerp_Tensor_mps + CPU, CUDA, MPS: lerp_Tensor tags: pointwise - func: lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor @@ -10296,8 +10355,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA, MTIA: maximum_out - MPS: maximum_out_mps + CPU, CUDA, MTIA, MPS: maximum_out tags: pointwise # binary max, alias of maximum @@ -10329,8 +10387,7 @@ structured_inherits: TensorIteratorBase device_check: NoCheck # TensorIterator dispatch: - CPU, CUDA, MTIA: minimum_out - MPS: minimum_out_mps + CPU, CUDA, MTIA, MPS: minimum_out tags: pointwise # binary min, alias for minimum @@ -11798,6 +11855,14 @@ CUDA: foreach_tensor_zero_cuda_ autogen: _foreach_zero, _foreach_zero.out +- func: _foreach_clone(Tensor[] self, *, MemoryFormat? memory_format=None) -> Tensor[] + device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices + variants: function + dispatch: + CompositeExplicitAutograd: foreach_tensor_clone_slow + CUDA: foreach_tensor_clone_cuda + autogen: _foreach_clone.out + - func: _foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> () device_check: NoCheck # foreach kernels fall back to slow path when tensor are on different devices variants: function @@ -12131,6 +12196,7 @@ structured_delegate: elu.out device_check: NoCheck # TensorIterator python_module: nn + tags: pointwise - func: glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!) structured: True @@ -12192,6 +12258,7 @@ structured_delegate: hardsigmoid.out device_check: NoCheck # TensorIterator python_module: nn + tags: pointwise - func: hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) structured: True @@ -12237,6 +12304,7 @@ dispatch: CPU, CUDA, MPS: hardtanh_ QuantizedCPU: hardtanh_quantized_cpu_ + tags: pointwise - func: hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -12277,7 +12345,7 @@ python_module: nn dispatch: QuantizedCPU: leaky_relu_quantized_cpu - tags: core + tags: [core, pointwise] - func: leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!) structured: True @@ -12296,6 +12364,7 @@ python_module: nn dispatch: QuantizedCPU: leaky_relu_quantized_cpu_ + tags: pointwise - func: log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) device_check: NoCheck # TensorIterator @@ -12952,6 +13021,10 @@ python_module: nn autogen: _upsample_bicubic2d_aa.vec_out +- func: _upsample_lanczos2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + python_module: nn + autogen: _upsample_lanczos2d_aa.vec_out + - func: upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor python_module: nn autogen: upsample_nearest1d.vec_out @@ -13098,6 +13171,26 @@ python_module: nn structured_delegate: _upsample_bicubic2d_aa_backward.grad_input +- func: _upsample_lanczos2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: _upsample_lanczos2d_aa_out_cpu + +- func: _upsample_lanczos2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_lanczos2d_aa.out + +- func: _upsample_lanczos2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + python_module: nn + structured: True + dispatch: + CPU: _upsample_lanczos2d_aa_backward_out_cpu + +- func: _upsample_lanczos2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + python_module: nn + structured_delegate: _upsample_lanczos2d_aa_backward.grad_input + - func: upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) python_module: nn structured: True @@ -14519,8 +14612,7 @@ python_module: linalg structured: True dispatch: - CPU, CUDA: linalg_vector_norm_out - MPS: linalg_vector_norm_out_mps + CPU, CUDA, MPS: linalg_vector_norm_out tags: reduction # Computes sum(|x|^ord) - the "power sum" without the final root. @@ -15182,7 +15274,7 @@ variants: function tags: nondeterministic_seeded -- func: _scaled_dot_product_attention_math_for_mps(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor) +- func: _scaled_dot_product_attention_math_for_mps(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor) dispatch: MPS: _scaled_dot_product_attention_math_mps tags: nondeterministic_seeded diff --git a/aten/src/ATen/native/nested/NestedTensorBackward.cpp b/aten/src/ATen/native/nested/NestedTensorBackward.cpp index 841b67addd1bd..80c9119fa1717 100644 --- a/aten/src/ATen/native/nested/NestedTensorBackward.cpp +++ b/aten/src/ATen/native/nested/NestedTensorBackward.cpp @@ -30,7 +30,7 @@ std::tuple matmul_backward_nested( if (grad_input_mask[1]) { grad_other = at::matmul(self.transpose(-1, -2), grad); } - return std::make_tuple(grad_self, grad_other); + return std::make_tuple(std::move(grad_self), std::move(grad_other)); } std::tuple nested_linear_backward( diff --git a/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp b/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp index e10a829a60fc0..224bf69eb2f35 100644 --- a/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp +++ b/aten/src/ATen/native/nested/NestedTensorBinaryOps.cpp @@ -58,8 +58,8 @@ static get_elementwise_nested_tensor_impl( op_name, " requires strides to match when given NestedTensors"); const auto self_offsets = self_ptr->get_storage_offsets(); - int64_t *self_offsets_ptr = self_offsets.data_ptr(); - int64_t *other_offsets_ptr = other_ptr->get_storage_offsets().data_ptr(); + const int64_t *self_offsets_ptr = self_offsets.const_data_ptr(); + const int64_t *other_offsets_ptr = other_ptr->get_storage_offsets().const_data_ptr(); bool offsets_match = true; for (auto i = 0; i < self_offsets.size(0); i++) { offsets_match = offsets_match && (self_offsets_ptr[i] == other_offsets_ptr[i]); diff --git a/aten/src/ATen/native/nested/NestedTensorFactories.cpp b/aten/src/ATen/native/nested/NestedTensorFactories.cpp index e9ceb91d77376..9a58ff3c5484d 100644 --- a/aten/src/ATen/native/nested/NestedTensorFactories.cpp +++ b/aten/src/ATen/native/nested/NestedTensorFactories.cpp @@ -185,7 +185,7 @@ std::vector NestedTensor_unbind( auto buffer = self.values(); std::vector sizes = NestedTensor_get_sizes(self_ptr), strides = NestedTensor_get_strides(self_ptr); - int64_t *offsets_ptr = self_ptr->get_storage_offsets().data_ptr(); + const int64_t *offsets_ptr = self_ptr->get_storage_offsets().const_data_ptr(); for (const int64_t i: c10::irange(ntensors)){ result_tensors[i] = buffer.as_strided(sizes[i], strides[i], offsets_ptr[i]); } @@ -212,11 +212,10 @@ Tensor narrow_nested_symint(const at::Tensor& self, int64_t dim, SymInt start, S auto nested_sizes = nt_impl->get_nested_sizes(); auto nested_strides = nt_impl->get_nested_strides(); auto storage_offsets = nt_impl->get_storage_offsets(); - auto storage_offsets_ptr = storage_offsets.data_ptr(); auto start_int = start.guard_int(__FILE__, __LINE__); auto length_int = length.guard_int(__FILE__, __LINE__); - auto buffer_offset = storage_offsets_ptr[start_int]; + auto buffer_offset = storage_offsets.const_data_ptr()[start_int]; nested_sizes = nested_sizes.narrow(0, start_int, length_int); nested_strides = nested_strides.narrow(0, start_int, length_int); diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp index 318bbb3728a85..bf3d6da97db60 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp @@ -232,8 +232,7 @@ Tensor nested_from_padded_generic( std::vector masks; std::vector all_sizes = sizes.unbind(); for (const auto& size : all_sizes) { - IntArrayRef sizes_i( - size.data_ptr(), size.data_ptr() + size.numel()); + IntArrayRef sizes_i(size.const_data_ptr(), size.numel()); at::Tensor mask_i = padded_transformed.new_full( sizes_i, true, kBool, std::nullopt, std::nullopt, std::nullopt); masks.push_back(pad_tensor_to_shape(mask_i, target_size_arr)); @@ -273,7 +272,7 @@ Tensor NestedTensor_to_padded_tensor_generic( const auto sizes_num_rows = sizes.sizes()[0]; const auto sizes_num_columns = sizes.sizes()[1]; - const auto sizes_data_start = sizes.data_ptr(); + const auto sizes_data_start = sizes.const_data_ptr(); const auto sizes_data_end = sizes_data_start + sizes.numel(); std::vector split_sizes; split_sizes.reserve(sizes_num_rows); @@ -430,7 +429,7 @@ Tensor select_nested(const Tensor& self, int64_t dim, int64_t index) { auto self_ptr = get_nested_tensor_impl(self); std::vector sizes = NestedTensor_get_sizes(self_ptr), strides = NestedTensor_get_strides(self_ptr); - int64_t *offsets_ptr = self_ptr->get_storage_offsets().data_ptr(); + const int64_t *offsets_ptr = self_ptr->get_storage_offsets().const_data_ptr(); const at::Tensor& buffer = self_ptr->get_unsafe_storage_as_tensor(); int64_t positive_dim = at::maybe_wrap_dim(dim, self_ptr->dim()); int64_t ntensors = self_ptr->size(0); @@ -615,7 +614,6 @@ Tensor squeeze_nested(const Tensor& self) { "squeeze(): For nested tensors, squeeze without the dim argument is not supported ", "at the moment, however you can use squeeze(Tensor self, int dim) instead ", "if you need this feature, please open an issue on github describing your use case."); - return self; } Tensor squeeze_dim_nested(const Tensor& self, IntArrayRef dims) { @@ -994,8 +992,8 @@ static bool can_cat_nested_sizes(const Tensor& nested_sizes1, const Tensor& nest return false; } - auto nested_sizes1_ptr = nested_sizes1.data_ptr(); - auto nested_sizes2_ptr = nested_sizes2.data_ptr(); + auto nested_sizes1_ptr = nested_sizes1.const_data_ptr(); + auto nested_sizes2_ptr = nested_sizes2.const_data_ptr(); const auto num_components = nested_sizes1.size(0); const auto num_dims = nested_sizes1.size(1); for (auto c : c10::irange(num_components)) { @@ -1022,6 +1020,7 @@ static Tensor cat_nested_as_jagged( const auto first_item_dim = first_item.dim(); const auto first_item_batch_size = first_item.size(0); std::vector jagged_views; + jagged_views.reserve(tensors.size()); for (auto i : c10::irange(tensors.size())) { auto t = tensors[i].get(); TORCH_CHECK(t.is_nested(), @@ -1073,6 +1072,8 @@ static Tensor cat_nested_impl( // handle simple case of dim=0: concat NT components std::vector buffers; std::vector sizes; + buffers.reserve(tensors.size()); + sizes.reserve(tensors.size()); for (const auto i : c10::irange(tensors.size())) { const Tensor& t = tensors[i]; TORCH_CHECK( diff --git a/aten/src/ATen/native/nested/NestedTensorMatmul.cpp b/aten/src/ATen/native/nested/NestedTensorMatmul.cpp index 60de6dd2bdaba..a0dbd7c5de97d 100644 --- a/aten/src/ATen/native/nested/NestedTensorMatmul.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMatmul.cpp @@ -79,15 +79,15 @@ static Tensor matmul_with_bmm_nested(const Tensor& self, const Tensor& mat2) { // metadata for self std::vector self_sizes = NestedTensor_get_sizes(self_ptr); std::vector self_strides = NestedTensor_get_strides(self_ptr); - int64_t* self_offsets_ptr = - self_ptr->get_storage_offsets().data_ptr(); + const int64_t* self_offsets_ptr = + self_ptr->get_storage_offsets().const_data_ptr(); auto opt = self_ptr->get_nested_sizes().options(); // metadata for mat2 std::vector mat2_sizes = NestedTensor_get_sizes(mat2_ptr); std::vector mat2_strides = NestedTensor_get_strides(mat2_ptr); - int64_t* mat2_offsets_ptr = - mat2_ptr->get_storage_offsets().data_ptr(); + const int64_t* mat2_offsets_ptr = + mat2_ptr->get_storage_offsets().const_data_ptr(); auto opt2 = mat2_ptr->get_nested_sizes().options(); int64_t N = static_cast(self_sizes.size()); diff --git a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp index afed47b15c4dc..409856f19bcf6 100644 --- a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.cpp @@ -195,7 +195,7 @@ Tensor NestedTensor_softmax_dropout_cuda(const Tensor& self, const Tensor& query Tensor NestedTensor_batch_offsets_from_size_tensor( const Tensor& sizes, int64_t extra_elements) { - int64_t* const sizes_ptr = sizes.data_ptr(); + const int64_t* const sizes_ptr = sizes.const_data_ptr(); Tensor offsets = at::empty({1 + sizes.size(0) + extra_elements}, at::kInt); int32_t* const offsets_ptr = offsets.mutable_data_ptr(); offsets_ptr[0] = 0; @@ -236,7 +236,7 @@ Tensor NestedTensor_to_mask(const Tensor& nt, std::optional mask_dim, s auto result = at::ones({sizes.sizes()[0], result_size_1}, at::kBool); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2); auto* result_data = result.data_ptr(); - auto* sizes_ptr = sizes.data_ptr(); + const auto* sizes_ptr = sizes.const_data_ptr(); const auto sizes_size_1 = sizes.sizes()[1]; for (const auto ii : c10::irange(sizes.sizes()[0])) { auto length = sizes_ptr[ii * sizes_size_1]; diff --git a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h index 47119fdd4a1ab..8fd0c62542e07 100644 --- a/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h +++ b/aten/src/ATen/native/nested/NestedTensorTransformerFunctions.h @@ -74,7 +74,7 @@ void remove_padding_transform0213_kernelLauncher( template void add_padding_kernelLauncher( - T* input, + const T* input, T* output, T padding_value, const int* offsets, diff --git a/aten/src/ATen/native/nested/NestedTensorUtils.cpp b/aten/src/ATen/native/nested/NestedTensorUtils.cpp index 8284796047952..5239ec4a7b671 100644 --- a/aten/src/ATen/native/nested/NestedTensorUtils.cpp +++ b/aten/src/ATen/native/nested/NestedTensorUtils.cpp @@ -37,7 +37,7 @@ std::vector NestedTensor_get_max_size_from_size_tensor( if (sizes.dim() == 0) { return {}; } - const auto sizes_ptr = sizes.data_ptr(); + const auto sizes_ptr = sizes.const_data_ptr(); const auto sizes_size_0 = sizes.sizes()[0]; const auto sizes_size_1 = sizes.sizes()[1]; TORCH_INTERNAL_ASSERT(sizes_size_1 > 0); @@ -88,7 +88,7 @@ std::vector chunk_nested_tensor(const Tensor& self, int64_t chunks, int6 const auto& sizes = self_impl->get_nested_sizes(); const auto& strides = self_impl->get_nested_strides(); const auto offsets = self_impl->get_storage_offsets(); - int64_t *offsets_ptr = offsets.data_ptr(); + const int64_t *offsets_ptr = offsets.const_data_ptr(); // Account for the implicit batch dim --dim; int64_t tensor_dim = sizes.size(1); @@ -143,7 +143,7 @@ std::vector split_with_sizes_nested( const auto& sizes = self_impl->get_nested_sizes(); const auto& strides = self_impl->get_nested_strides(); const auto offsets = self_impl->get_storage_offsets(); - int64_t *offsets_ptr = offsets.data_ptr(); + const int64_t *offsets_ptr = offsets.const_data_ptr(); // Account for the implicit batch dim --dim; int64_t tensor_dim = sizes.size(1); diff --git a/aten/src/ATen/native/nested/NestedTensorUtils.h b/aten/src/ATen/native/nested/NestedTensorUtils.h index 47b8e2ad8086f..c0a8ef2b9139d 100644 --- a/aten/src/ATen/native/nested/NestedTensorUtils.h +++ b/aten/src/ATen/native/nested/NestedTensorUtils.h @@ -216,15 +216,11 @@ inline IntArrayRef get_stride_for_index(const Tensor& tensor, int64_t i) { inline int64_t get_offset_for_index(const Tensor& tensor, int64_t i) { if (tensor.is_nested()) { - int64_t* offsets_ptr = get_nested_tensor_impl(tensor) - ->get_storage_offsets() - .data_ptr(); - return offsets_ptr[i]; - - } else { - int64_t offset = tensor.storage_offset(); - return offset + tensor.strides()[0] * i; + return get_nested_tensor_impl(tensor) + ->get_storage_offsets() + .const_data_ptr()[i]; } + return tensor.storage_offset() + tensor.strides()[0] * i; } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // Data structures and functions for generically applying a function on a nested @@ -433,6 +429,8 @@ inline Tensor wrap_tensor_node( } else { // Slow path std::vector flat_tensors; std::vector sizes; + flat_tensors.reserve(tensor_node.degree()); + sizes.reserve(tensor_node.degree()); for (const auto i : c10::irange(tensor_node.degree())) { flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous()); sizes.push_back( diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 96c6ab8310f80..e2b4db95b5ac3 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -64,7 +64,7 @@ Tensor nested_from_padded_cuda( at::cat({target_size_sizes, padded_sizes_tensor, target_offsets}); metadata = metadata.to(at::Device(kCUDA), kInt, true, true); - auto output_size_ptr = metadata.data_ptr(); + auto output_size_ptr = metadata.const_data_ptr(); auto input_size_ptr = output_size_ptr + target_size_sizes.numel(); auto offsets_ptr = input_size_ptr + padded_sizes_tensor.numel(); @@ -72,7 +72,7 @@ Tensor nested_from_padded_cuda( if (padded.dtype() == kFloat) { if (do_transform_0213) { remove_padding_transform0213_kernelLauncher( - padded_contiguous.data_ptr(), + padded_contiguous.const_data_ptr(), output.data_ptr(), offsets_ptr, input_size_ptr, @@ -81,7 +81,7 @@ Tensor nested_from_padded_cuda( padded_contiguous.sizes()[0]); } else { remove_padding_kernelLauncher( - padded_contiguous.data_ptr(), + padded_contiguous.const_data_ptr(), output.data_ptr(), offsets_ptr, input_size_ptr, @@ -92,7 +92,7 @@ Tensor nested_from_padded_cuda( } else if (padded.dtype() == kHalf) { if (do_transform_0213) { remove_padding_transform0213_kernelLauncher( - padded_contiguous.data_ptr(), + padded_contiguous.const_data_ptr(), output.data_ptr(), offsets_ptr, input_size_ptr, @@ -101,7 +101,7 @@ Tensor nested_from_padded_cuda( padded_contiguous.sizes()[0]); } else { remove_padding_kernelLauncher( - padded_contiguous.data_ptr(), + padded_contiguous.const_data_ptr(), output.data_ptr(), offsets_ptr, input_size_ptr, @@ -119,7 +119,7 @@ Tensor nested_from_padded_cuda( } static Tensor batch_offsets_from_efficient_size(const Tensor& ef_sizes) { - int64_t* nt_sizes_ptr = ef_sizes.data_ptr(); + const int64_t* nt_sizes_ptr = ef_sizes.const_data_ptr(); int64_t ef_sizes_size_0 = ef_sizes.sizes()[0]; Tensor offsets = at::empty({1 + ef_sizes_size_0}, at::kLong); int64_t* offsets_ptr = offsets.mutable_data_ptr(); @@ -200,11 +200,11 @@ Tensor NestedTensor_to_padded_tensor_cuda( AT_DISPATCH_FLOATING_TYPES_AND_HALF( nt_buffer.scalar_type(), "NestedTensor_to_padded_tensor_cuda", [&]() { add_padding_kernelLauncher( - nt_buffer.data_ptr(), + nt_buffer.const_data_ptr(), output.data_ptr(), (scalar_t)(padding), - offsets.data_ptr(), - nt_sizes.data_ptr(), + offsets.const_data_ptr(), + nt_sizes.const_data_ptr(), input_dim, new_size, batch_size, @@ -463,7 +463,7 @@ std::tuple _scaled_dot_product_flash_attenti grad_k = wrap_buffer(grad_k.view(-1), key.transpose(1,2)._nested_tensor_size()).transpose(1,2); grad_v = wrap_buffer(grad_v.view(-1), value.transpose(1,2)._nested_tensor_size()).transpose(1,2); - return std::make_tuple(grad_q, grad_k, grad_v); + return std::make_tuple(std::move(grad_q), std::move(grad_k), std::move(grad_v)); } } // namespace at::native diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu index 92675321fb117..2ec16c2121cc6 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cu @@ -385,7 +385,7 @@ __global__ void add_padding_3( template void add_padding_kernelLauncher( - T* input, // [batch_size x None] + const T* input, // [batch_size x None] T* output, // [batch_size x max(input.nested_size(1)) x inner_size] T padding_value, const int* offsets, @@ -437,7 +437,7 @@ void add_padding_kernelLauncher( } template void add_padding_kernelLauncher( - double* input, + const double* input, double* output, double padding_value, const int* offsets, @@ -448,7 +448,7 @@ template void add_padding_kernelLauncher( const int output_batch_size); template void add_padding_kernelLauncher( - float* input, + const float* input, float* output, float padding_value, const int* offsets, @@ -459,7 +459,7 @@ template void add_padding_kernelLauncher( const int output_batch_size); template void add_padding_kernelLauncher( - c10::Half* input, + const c10::Half* input, c10::Half* output, c10::Half padding_value, const int* offsets, diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.cpp index c44fd27902e04..3a1261557f10d 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.cpp @@ -18,7 +18,7 @@ int64_t get_nnz(const Tensor& nestedtensor) { const auto& sizes = nt_impl->get_nested_sizes(); auto size_tensor_stride = sizes.stride(0); const int64_t batch_size = nestedtensor.size(0); - auto* sizes_ptr = sizes.data_ptr(); + const auto* sizes_ptr = sizes.const_data_ptr(); int64_t cumulative_sequence_length = 0; for (const auto i : c10::irange(batch_size)) { // Calculate the cumulative sum of the sequence lengths @@ -49,7 +49,7 @@ int64_t get_nnz(const Tensor& nestedtensor) { auto cumulative_seqlen = at::zeros( {batch_size + 1}, TensorOptions().device(at::kCPU).dtype(at::kInt)); - auto* sizes_ptr = sizes.data_ptr(); + const auto* sizes_ptr = sizes.const_data_ptr(); auto* cumulative_seqlen_ptr = cumulative_seqlen.data_ptr(); int64_t sum = 0; @@ -84,7 +84,7 @@ int64_t get_nnz(const Tensor& nestedtensor) { */ bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) { const int64_t* tensor_offsets_ptr = - tensor->get_storage_offsets().data_ptr(); + tensor->get_storage_offsets().const_data_ptr(); const Tensor& tensor_sizes = tensor->get_nested_sizes(); const Tensor& tensor_strides = tensor->get_nested_strides(); @@ -99,7 +99,7 @@ int64_t get_nnz(const Tensor& nestedtensor) { return true; } - int64_t* previous_tensor_stride = tensor_strides.data_ptr(); + const int64_t* previous_tensor_stride = tensor_strides.const_data_ptr(); // Check initially that the first tensor's strides // are in strictly descending order @@ -133,9 +133,8 @@ int64_t get_nnz(const Tensor& nestedtensor) { // Check the offsets are a constant multiple from the previous numels const int64_t* tensor_size_ptr = tensor_sizes.const_data_ptr(); - const int64_t* tensor_stride_ptr = tensor_strides.const_data_ptr(); - int64_t numel_0 = (tensor_size_ptr[0] * tensor_stride_ptr[0]); + int64_t numel_0 = (tensor_size_ptr[0] * previous_tensor_stride[0]); TORCH_INTERNAL_ASSERT(numel_0 > 0, "numels must be positive!"); int64_t offset_constant = @@ -144,7 +143,7 @@ int64_t get_nnz(const Tensor& nestedtensor) { // TODO: When 0 seq_len nested tensors are allowed we need to guard // against this int64_t previous_numel = tensor_size_ptr[(i - 1) * tensor_stride_0] * - tensor_stride_ptr[(i - 1) * tensor_stride_0]; + previous_tensor_stride[(i - 1) * tensor_stride_0]; TORCH_INTERNAL_ASSERT(previous_numel > 0, "numels must be positive!"); int64_t current_offset_constant = (tensor_offsets_ptr[i] - tensor_offsets_ptr[i - 1]) / previous_numel; @@ -180,9 +179,9 @@ int64_t get_nnz(const Tensor& nestedtensor) { constexpr int64_t head_dim_stride = 1; const int64_t* nt_strides = - tensor_impl->get_nested_strides().data_ptr(); + tensor_impl->get_nested_strides().const_data_ptr(); const int64_t* nt_offsets_ptr = - tensor_impl->get_storage_offsets().data_ptr(); + tensor_impl->get_storage_offsets().const_data_ptr(); const int64_t nnz_stride = nt_strides[0]; const int64_t head_stride = num_heads_needs_broadcast ? 0 : nt_strides[1]; diff --git a/aten/src/ATen/native/quantized/FakeQuantPerChannelAffine.cpp b/aten/src/ATen/native/quantized/FakeQuantPerChannelAffine.cpp index ffe6f4c31829d..af4d543648180 100644 --- a/aten/src/ATen/native/quantized/FakeQuantPerChannelAffine.cpp +++ b/aten/src/ATen/native/quantized/FakeQuantPerChannelAffine.cpp @@ -260,6 +260,6 @@ std::tuple _fake_quantize_learnable_per_channel_affine_b auto dScale = dScale_vec.sum(at::IntArrayRef(axis_for_reduction.data(), numElements)); auto dZeroPoint = dZeroPoint_vec.sum(at::IntArrayRef(axis_for_reduction.data(), numElements)); - return std::make_tuple(dX, dScale, dZeroPoint); + return std::make_tuple(std::move(dX), std::move(dScale), std::move(dZeroPoint)); } } // namespace at::native diff --git a/aten/src/ATen/native/quantized/cpu/ChannelShuffle.cpp b/aten/src/ATen/native/quantized/cpu/ChannelShuffle.cpp index 5df69d01b2549..36baffc36c1ed 100644 --- a/aten/src/ATen/native/quantized/cpu/ChannelShuffle.cpp +++ b/aten/src/ATen/native/quantized/cpu/ChannelShuffle.cpp @@ -102,10 +102,11 @@ Tensor channel_shuffle_quantized_cpu( int64_t groups) { #ifdef USE_PYTORCH_QNNPACK return quantized_channel_shuffle_impl(self, groups); -#endif +#else // If QNNPACK is not available then fall back to the // non quantized path. return at::native::channel_shuffle(self, groups); +#endif } // Keep the registry in the anonymous namespace. diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 94ac6350aeb0e..ac51ce9ade6eb 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -1490,7 +1490,9 @@ static at::Tensor _fp8_convolution_onednn_ref( y_f32 = y_f32.to(at::kHalf); } x1.copy_(y_f32.to(x1.scalar_type()).view(x1.sizes())); - return x1; + // Return a copy: custom ops must not return tensors that alias inputs. + // The accum buffer has already been mutated in-place above. + return x1.clone(); } else { TORCH_CHECK( false, @@ -1751,7 +1753,9 @@ static at::Tensor _quantized_convolution_onednn( c10::MemoryFormat::ChannelsLast3d) ); if (output.numel() == 0) { - return output; + // When has_accum_postop_sum, output aliases accum (the input). Custom ops + // must not return tensors that alias inputs, so return a copy. + return has_accum_postop_sum ? output.clone() : output; } ideep::tensor dst = at::native::itensor_view_from_dense(output); static ideep::tensor::desc dummy_accum_desc; @@ -1885,10 +1889,12 @@ static at::Tensor _quantized_convolution_onednn( if (is_1d) { output.squeeze_(quant_utils::kConv1dSqueezeDim + 2); - return output; } if (has_accum_postop_sum) { - return accum.value(); + // When has_accum_postop_sum, output aliases accum (the input) — see + // assignment above. Return a copy: custom ops must not return tensors + // that alias inputs. + return output.clone(); } else { return output; } diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp index b115b25c42784..333796e8a24ce 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp @@ -94,7 +94,6 @@ at::Tensor PackedEmbeddingBagWeight::unpack() { TORCH_INTERNAL_ASSERT( false, "We currently only support 8-bit and 4-bit quantization of embedding_bag."); - return weight_origin; } namespace at::native { diff --git a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp index 7fe44de11e54c..857d0c52140e9 100644 --- a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp @@ -65,7 +65,6 @@ Tensor adaptive_avg_pool2d_quantized_cuda( return at::quantize_per_tensor(result_fp32, input.q_scale(), input.q_zero_point(), input.scalar_type()); #else // USE_CUDA TORCH_CHECK(false, "at::native::adaptive_avg_pool2d_quantized_cuda: ATen not compiled with USE_CUDA support"); - return Tensor{}; // never reached, placates the compiler #endif } @@ -214,7 +213,6 @@ Tensor quantized_max_pool2d_cudnn( #endif // AT_CUDNN_ENABLED() #else // USE_CUDA TORCH_CHECK(false, "at::native::quantized_max_pool2d_cudnn: ATen not compiled with USE_CUDA support"); - return Tensor{}; // never reached, placates the compiler #endif } diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 9ce3619261553..012feb784e53f 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -265,8 +265,8 @@ TORCH_LIBRARY(onednn, m) { m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, str attr, Scalar?[] scalars, str? algorithm) -> Tensor")); // Conv2D with binary postop - m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor qaccum, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, float accum_scale, int accum_zero_point, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor")); - m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.binary_tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor qaccum, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, float accum_scale, int accum_zero_point, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor(a!) qaccum, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, float accum_scale, int accum_zero_point, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor")); + m.def(TORCH_SELECTIVE_SCHEMA("onednn::qconv2d_pointwise.binary_tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor(a!) qaccum, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point, ScalarType? output_dtype, float accum_scale, int accum_zero_point, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor")); // Linear prepack m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_prepack(Tensor weight, int[]? x_shape) -> Tensor")); diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h index 7cec767d44660..e50171f48f4c6 100644 --- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h +++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h @@ -297,8 +297,8 @@ void _sparse_binary_op_intersection_kernel_impl( std::tie(sorted_hash, argsort_hash) = [&]() -> std::tuple { if (probably_coalesced.is_coalesced()) { // NOTE: argsort.dtype == nnz_arange.dtype - const auto argsort = nnz_arange.narrow(-1, 0, probably_coalesced._nnz()); - return std::make_tuple(probably_coalesced_indices_hash, argsort); + auto argsort = nnz_arange.narrow(-1, 0, probably_coalesced._nnz()); + return std::make_tuple(probably_coalesced_indices_hash, std::move(argsort)); } else { // NOTE: we want argsort.dtype == nnz_arange.dtype, // but sort() produces indices of type int64_t, diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp index 5ecfbd589ca39..79ce7fb17bfb3 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp @@ -189,8 +189,10 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres batch_ndim, " + ", block_ndim, ") but got ", values.dim()); // 3.5 - TORCH_CHECK(plain_indices.stride(-1) == 1, - "expected ", plain_indices_name, " to be a contiguous tensor per batch"); + if (plain_indices.numel() != 0) { + TORCH_CHECK(plain_indices.stride(-1) == 1, + "expected ", plain_indices_name, " to be a contiguous tensor per batch"); + } // 3.6 TORCH_CHECK(compressed_indices.stride(-1) == 1, diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp index edb95c23b98ba..cb8baf89a86e2 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp @@ -731,8 +731,8 @@ Tensor& addmm_out_sparse_compressed_cpu( " without MKL. PyTorch built with MKL has better support for addmm with sparse CPU tensors."); #else sparse::impl::mkl::addmm_out_sparse_csr(mat1, mat2, beta, alpha, result); -#endif return result; +#endif } Tensor addmm_sparse_compressed_dense( diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 5dec3746eaa88..b51bb8d1184f3 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -95,7 +95,6 @@ bool is_coalesced_sparse(const SparseTensor& self) { bool is_coalesced_default(const Tensor& self) { TORCH_CHECK(false, "is_coalesced expected sparse coordinate tensor layout but got ", self.layout()); - return false; } int64_t _nnz_sparse(const SparseTensor& self) { @@ -676,7 +675,7 @@ SparseTensor _coalesce_sparse_cpu(const SparseTensor& self) { values.scalar_type(), "coalesce", [&] { int64_t prev = -1; int64_t blockSize = values.stride(0); - scalar_t* values_ptr = values.data_ptr(); + const scalar_t* values_ptr = values.const_data_ptr(); scalar_t* newValues_ptr = newValues.data_ptr(); for (const auto j : c10::irange(nnz)) { int64_t pos = indicesPermutationAccessor[j]; diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh index 078b354659d35..7b9e61ceb4cc7 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh @@ -246,11 +246,7 @@ __global__ void coalesceValuesKernel( // `if constexpr` when CUDA codes will be compiled under C++-17, see // gh-56055 for blockers. template -#ifdef USE_ROCM -C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE_STATIC*4) -#else -C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE*4) -#endif +C10_LAUNCH_BOUNDS_1(C10_WARP_SIZE_UPPER_BOUND*4) __global__ void coalesceValuesKernel( int64_t *segment_offsets, int64_t *value_indices, bool *values, bool *newValues, diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu index 599a59a18312c..cfe91ecc2e5e8 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu @@ -59,7 +59,7 @@ struct MetadataCuSparseLt { // TODO: Cast metadata to Short static_assert(kBytesPerScalar == 2, "or modify the last dim below"); metadata = metadata.view({rows / 128, cols / 32, 256}); - return std::make_tuple(storage, packed, metadata); + return std::make_tuple(std::move(storage), std::move(packed), std::move(metadata)); } MetadataCuSparseLt(at::Tensor metaN, at::Tensor metaT, int rows, int cols) { diff --git a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp index 9d735ac0f2c88..24893da98c6c3 100644 --- a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp +++ b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp @@ -84,7 +84,7 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) { type = CUDA_R_32F; break; #endif -#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM) +#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 || defined(USE_ROCM) case at::ScalarType::Float8_e4m3fn: type = CUDA_R_8F_E4M3; break; @@ -203,12 +203,18 @@ std::tuple _cslt_sparse_mm_impl( compute_type = CUSPARSE_COMPUTE_32F; break; #endif -// if cuSPARSELt >= 6.2.3, we can add Float8 support -#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM) +// cuSPARSELt >= 0.6.2 or hipSparseLt: add Float8 support +#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 || defined(USE_ROCM) case at::ScalarType::Float8_e4m3fn: input_type = CUDA_R_8F_E4M3; +#ifdef USE_ROCM + // hipSparseLt 0.2.7: FP8 input only supports FP32 output + output_type = CUDA_R_32F; + C_type = CUDA_R_32F; +#else output_type = CUDA_R_8F_E4M3; C_type = CUDA_R_16F; +#endif compute_type = CUSPARSE_COMPUTE_32F; break; #endif @@ -265,10 +271,11 @@ std::tuple _cslt_sparse_mm_impl( break; } } -// cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support -#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM) +// cslt 0.6.2+ or hipSparseLt: fp8 output dtype support +#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 || defined(USE_ROCM) else if (input_type == CUDA_R_8F_E4M3) { switch (out_dtype) { +#ifndef USE_ROCM case at::ScalarType::Float8_e4m3fn: output_type = CUDA_R_8F_E4M3; C_type = CUDA_R_16F; @@ -281,6 +288,7 @@ std::tuple _cslt_sparse_mm_impl( output_type = CUDA_R_16BF; C_type = CUDA_R_16BF; break; +#endif case at::ScalarType::Float: output_type = CUDA_R_32F; C_type = CUDA_R_32F; @@ -288,7 +296,11 @@ std::tuple _cslt_sparse_mm_impl( default: TORCH_CHECK( false, - "Unsupported out_dtype passed, must be one of {fp16, bf16, float32} for fp8 inputs"); +#ifdef USE_ROCM + "Unsupported out_dtype passed, must be float32 for fp8 inputs on ROCm"); +#else + "Unsupported out_dtype passed, must be one of {fp8, fp16, bf16, float32} for fp8 inputs"); +#endif break; } } @@ -468,6 +480,8 @@ std::tuple _cslt_sparse_mm_impl( &alg_id, sizeof(alg_id))); +#ifndef USE_ROCM + // hipSPARSELt does not support querying SPLIT_K attributes TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute( &handle, &alg_sel, @@ -481,6 +495,7 @@ std::tuple _cslt_sparse_mm_impl( CUSPARSELT_MATMUL_SPLIT_K_MODE, &splitKMode, sizeof(splitKMode))); +#endif TORCH_CUDASPARSE_CHECK(cusparseLtMatmulAlgGetAttribute( &handle, diff --git a/aten/src/ATen/native/tags.yaml b/aten/src/ATen/native/tags.yaml index d76e38dae183d..b6ba3701c661d 100644 --- a/aten/src/ATen/native/tags.yaml +++ b/aten/src/ATen/native/tags.yaml @@ -97,9 +97,12 @@ desc: | This tag indicates that an operator performs a reduction operation, computing aggregate values (sum, mean, max, min, etc.) across one or more dimensions of the input tensor(s). +- tag: out + desc: | + This tag indicates that an operator has "out" semantics: its mutable + arguments are write-only output buffers that are not read from. + The operator must return exactly the mutable arguments in the order + they are declared, and nothing else. - tag: out_variant desc: | - This tag indicates that the operator is an out variant of a functional - operator. This tag only applies to custom ops. Out variant operators - write their results to pre-allocated output tensors (the out args) - rather than allocating new tensors. + Deprecated, does nothing. Use "out" instead. diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp index e04b64cf4efc1..eecfe1058c2c4 100644 --- a/aten/src/ATen/native/transformers/attention.cpp +++ b/aten/src/ATen/native/transformers/attention.cpp @@ -492,6 +492,20 @@ int64_t _fused_sdp_choice_meta( return choice_int; } #endif + bool has_xpu = query_key_set.has(c10::DispatchKey::XPU); + if (has_xpu) { + auto choice_int = _fused_sdp_choice_stub( + at::kXPU, + query_, + key, + value, + attn_mask_, + dropout_p, + is_causal, + scale, + enable_gqa); + return choice_int; + } return static_cast(sdp::SDPBackend::math); } namespace { @@ -533,7 +547,7 @@ inline void validate_sdpa_input( // the math and memory efficient attn_mask implementation // Args: // attn_mask: attn_mask of shape (B, L, S) or (L, S) or (B, N_heads, L, S) -std::optional convert_boolean_attn_mask_(const std::optional& attn_mask, caffe2::TypeMeta dtype, double neg_inf) { +std::optional convert_boolean_attn_mask(const std::optional& attn_mask, caffe2::TypeMeta dtype) { // Pass through if (!attn_mask.has_value()) { return std::nullopt; @@ -541,23 +555,13 @@ std::optional convert_boolean_attn_mask_(const std::optional& at // Convert boolean mask to additive mask; need to invert mask to indicate what // to mask *out*. if (attn_mask->dtype() == at::kBool) { + constexpr double neg_inf = -std::numeric_limits::infinity(); return at::where(*attn_mask, 0.0, at::scalar_tensor(neg_inf, at::TensorOptions().dtype(dtype).device(attn_mask->device()))); } // Otherwise, attn_mask represents an additive attention tensor return attn_mask; } -std::optional convert_boolean_attn_mask(const std::optional& attn_mask, caffe2::TypeMeta dtype) { - return convert_boolean_attn_mask_(attn_mask, dtype, -std::numeric_limits::infinity()); -} - -// alternate version to workaround -inf issue with cuDNN -// TODO(eqy): delete this when cuDNN -inf issue is resolved -std::optional convert_boolean_attn_mask_cudnn(const std::optional& attn_mask, caffe2::TypeMeta dtype) { - // TODO Use the max type of the input and output - return convert_boolean_attn_mask_(attn_mask, dtype, -65504.0); -} - // Memory Efficient Attention requires a padded attn mask bias // This function pads the attn_mask bias to be a multiple of 16 // Then slices the padded bias to the original size @@ -686,11 +690,11 @@ Tensor _safe_softmax( // greater than 0.0 is specified. // // Args: -// query (Tensor): Query tensor; shape (N, ..., L, E) -// key (Tensor): Key tensor; shape (N, ..., S, E) -// value (Tensor): Value tensor; shape (N, ..., S, E) +// query (Tensor): Query tensor; shape (N, ..., Hq, L, E) +// key (Tensor): Key tensor; shape (N, ..., H, S, E) +// value (Tensor): Value tensor; shape (N, ..., H, S, Ev) // attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, -// which is (N,..., L, S). Two types of masks are supported. +// which is (N,..., Hq, L, S). Two types of masks are supported. // A boolean mask where a value of True indicates that the element *should* take part in attention. // A float mask of the same type as query, key, value that is added to the attention score. // dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied @@ -702,14 +706,17 @@ Tensor _safe_softmax( // sparse masks) via tensor subclassing, allowing for a leaner API. // // Returns a tensor: -// output (Tensor): Attention output; shape (N, ..., L, E) +// output (Tensor): Attention output; shape (N, ..., Hq, L, Ev) // // Shape legend: // N: Batch size // ...: Any number of other batch dimensions (optional) // S: Source sequence length // L: Target sequence length -// E: Embedding dimension +// E: Embedding dimension of the query and key +// Ev: Embedding dimension of the value +// Hq: Number of heads of query +// H: Number of heads of key and value Tensor scaled_dot_product_attention( const Tensor& query_, const Tensor& key, @@ -744,8 +751,7 @@ Tensor scaled_dot_product_attention( } const auto query_device_type = query_.device().type(); const auto backend = static_cast(choice_int); - const auto convert_attn_func = backend != SDPBackend::cudnn_attention ? convert_boolean_attn_mask : convert_boolean_attn_mask_cudnn; - auto attn_mask = convert_attn_func(attn_mask_, query_.dtype()); + auto attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype()); switch (backend) { case SDPBackend::cudnn_attention: { bool compute_logsumexp = should_compute_logsumexp(query_, key, value); @@ -786,48 +792,8 @@ Tensor scaled_dot_product_attention( return std::get<0>(out_lse_softmax); } case SDPBackend::math: { -#ifdef USE_MPS - TORCH_CHECK_NOT_IMPLEMENTED( - c10::isFloatingType(query_.scalar_type()), - "scaled_dot_product_attention for MPS does not support dtype ", - query_.scalar_type()); - TORCH_CHECK_NOT_IMPLEMENTED( - c10::isFloatingType(key.scalar_type()), - "scaled_dot_product_attention for MPS does not support dtype ", - key.scalar_type()); - TORCH_CHECK_NOT_IMPLEMENTED( - c10::isFloatingType(value.scalar_type()), - "scaled_dot_product_attention for MPS does not support dtype ", - value.scalar_type()); - const auto any_nested = query_.is_nested() || key.is_nested() || value.is_nested(); const bool any_inputs_require_grad = query_.requires_grad() || key.requires_grad() || value.requires_grad(); - const auto all_contiguous = query_.is_contiguous_or_false() && key.is_contiguous_or_false() && value.is_contiguous_or_false(); - if (query_device_type == DeviceType::MPS && dropout_p == 0.0 - && !(GradMode::is_enabled() && any_inputs_require_grad) - && (all_contiguous || mps::is_macos_13_or_newer(mps::MacOSVersion::MACOS_VER_15_0_PLUS)) - && !any_nested) { - if (enable_gqa) { - int64_t q_heads = query_.size(-3); - int64_t k_heads = key.size(-3); - int64_t repeat_factor = q_heads / k_heads; - - if (repeat_factor > 1) { - TORCH_CHECK(q_heads % k_heads == 0, - "For GQA, the query tensor's head dimension (" + std::to_string(q_heads) + - ") must be divisible by the key tensor's head dimension (" + std::to_string(k_heads) + ")."); - auto repeated_key = key.repeat_interleave(repeat_factor, /*dim=*/-3); - auto repeated_value = value.repeat_interleave(repeat_factor, /*dim=*/-3); - return std::get<0>(at::_scaled_dot_product_attention_math_for_mps( - query_, - repeated_key, - repeated_value, - attn_mask, - dropout_p, - is_causal, - std::nullopt, /*dropout_mask*/ - scale)); - } - } + if (query_device_type == c10::kMPS && !(at::GradMode::is_enabled() && any_inputs_require_grad)) { return std::get<0>(at::_scaled_dot_product_attention_math_for_mps( query_, key, @@ -836,9 +802,9 @@ Tensor scaled_dot_product_attention( dropout_p, is_causal, std::nullopt, /*dropout_mask*/ - scale)); + scale, + enable_gqa)); } -#endif return std::get<0>(at::_scaled_dot_product_attention_math( query_, key, @@ -854,7 +820,6 @@ Tensor scaled_dot_product_attention( TORCH_CHECK( false, "No viable backend for scaled_dot_product_attention was found."); - return Tensor(); } } @@ -1064,11 +1029,6 @@ _scaled_dot_product_fused_attention_overrideable_backward( const at::Tensor & philox_offset, std::optional scale) { TORCH_CHECK_NOT_IMPLEMENTED(false, "_scaled_dot_product_fused_attention_overrideable_backward not implemented: This is an operator for privateuse1 backends, please use TORCH_LIBRARY_IMPL to override this function "); - return std::tuple( - at::empty_like(query), - at::empty_like(key), - at::empty_like(value), - at::empty_like(attn_bias)); } Tensor triton_multi_head_attention( diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 4257b61773e09..c094d0b59e358 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -86,6 +86,7 @@ #include #else // MemoryEfficient Attention Specific Imports for ROCM +#include #ifndef DISABLE_AOTRITON #include #include @@ -95,7 +96,7 @@ #endif #endif -#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)) +#if defined(USE_ROCM) && defined(USE_FLASH_ATTENTION) namespace pytorch_flash { std::tuple< @@ -488,9 +489,6 @@ _flash_attention_forward_impl( std::optional num_splits ) { #if defined(USE_FLASH_ATTENTION) - TORCH_CHECK( - !num_splits.has_value(), - "num_splits requires FA3. Register FA3 with `register_flash_attention_fa3()` to set num_splits."); const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); @@ -499,7 +497,11 @@ _flash_attention_forward_impl( std::optional alibi_slopes = _alibi_slopes; const float softcap = 0.0; -#ifdef USE_ROCM // ROCM backend accepts std::optional for window_size_left/right directly. +#ifdef USE_ROCM + TORCH_CHECK( + !num_splits.has_value(), + "num_splits is not supported on ROCm"); + // ROCM backend accepts std::optional for window_size_left/right directly. #ifdef DISABLE_AOTRITON // CK backend, Passing window_size as it is const auto window_left = window_size_left; const auto window_right = window_size_right; @@ -553,7 +555,11 @@ _flash_attention_forward_impl( window_right, softcap, return_debug_mask, - std::nullopt /*gen_*/); + std::nullopt /*gen_*/ +#ifndef USE_ROCM + , num_splits.value_or(0) +#endif + ); } else { std::tie( output, @@ -956,7 +962,7 @@ std::tuple _scaled_dot_product_flash_attention_cuda_quantized( diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 183f99e975cda..6bf6d03541f2c 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -1109,30 +1110,27 @@ std::tuple _scaled_dot_product_e if (batch_size > MAX_BATCH_SIZE) { Tensor final_grad_q, final_grad_k, final_grad_v, final_grad_bias; - auto create_strided_output = [batch_size](const Tensor& tensor) -> Tensor { + auto create_permuted_output = [batch_size](const Tensor& tensor) -> Tensor { if (!tensor.defined()) { return Tensor{}; } - int dim = tensor.dim(); - std::vector sizes; - sizes.reserve(dim); - sizes.push_back(batch_size); - for (int i = 1; i < dim; i++) { - sizes.push_back(tensor.size(i)); - } - return at::empty_strided(std::move(sizes), tensor.strides(), tensor.options()); + TORCH_INTERNAL_ASSERT(tensor.dim() == 4); + return at::empty_permuted( + {batch_size, tensor.size(1), tensor.size(2), tensor.size(3)}, + {0, 2, 1, 3}, + tensor.options()); }; if (grad_input_mask[0]) { - final_grad_q = create_strided_output(query); + final_grad_q = create_permuted_output(query); } if (grad_input_mask[1]) { - final_grad_k = create_strided_output(key); + final_grad_k = create_permuted_output(key); } if (grad_input_mask[2]) { - final_grad_v = create_strided_output(value); + final_grad_v = create_permuted_output(value); } if (grad_input_mask[3] && attn_bias.defined()) { final_grad_bias = at::zeros_like(attn_bias); diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp index 07fd359a3ec93..8aa4b98e3f1b7 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp @@ -590,7 +590,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q int window_size_right, const float softcap, const bool return_softmax, - std::optional gen_) { + std::optional gen_, + int num_splits) { auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm80_or_newer = (dprops->major * 10) >= 80; @@ -763,11 +764,10 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q params.page_block_size = page_block_size; // Keep references to these tensors to extend their lifetime at::Tensor softmax_lse_accum, out_accum; - if (seqlenq_ngroups_swapped) { - // Only apply split-k for decoding + if (paged_KV || seqlenq_ngroups_swapped) { std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(params, batch_size, num_heads, head_size, max_seqlen_k, max_seqlen_q, - head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts); + head_size_rounded, p_dropout, num_splits, dprops, opts); } // [Note] BC breaking change to flash seed/offset diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h index f5ba2c117d99b..8beece5f71ee7 100644 --- a/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h +++ b/aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.h @@ -43,7 +43,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q int window_size_right, const float softcap, const bool return_softmax, - std::optional gen_); + std::optional gen_, + int num_splits = 0); std::tuple diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index f90d5beeb60be..073dd37db3188 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -496,6 +496,18 @@ bool check_cudnn_dropout(sdp_params const& params, bool debug) { } bool check_cudnn_tensor_shapes(sdp_params const& params, bool debug) { + constexpr int64_t max_cudnn_dim_size = 65535; + const auto b = params.query.sym_size(0); + const auto h = params.query.sym_size(1); + if (b > max_cudnn_dim_size || h > max_cudnn_dim_size) { + if (debug) { + TORCH_WARN( + "cuDNN SDPA does not support batch size or num_heads greater than ", + max_cudnn_dim_size, + ". Got batch size: ", b, ", num_heads: ", h); + } + return false; + } const auto s_q = params.query.sym_size(2); const auto s_k = params.key.sym_size(2); const auto d_qk = params.query.sym_size(3); @@ -937,8 +949,8 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { if (dprop->major >= 8) { return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug); } -#endif return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug); +#endif } SDPBackend select_sdp_backend(sdp_params const& kernel_params) { @@ -1015,7 +1027,7 @@ bool check_for_seq_len_1_nested_tensor(sdp_params const& params, bool debug) { const auto nt_q_tensor_impl = at::native::get_nested_tensor_impl(params.query); const at::Tensor& sizes = nt_q_tensor_impl->get_nested_sizes(); - auto* sizes_ptr = sizes.data_ptr(); + const auto* sizes_ptr = sizes.const_data_ptr(); const int64_t n_tensors = params.query.size(0); const int64_t size_tensor_stride = sizes.stride(0); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index b96e80d5e5a9e..f3816296a1d4f 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -318,8 +318,16 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x persistent_counter, stream); } - - return {out, q_padded, k_padded, v_padded, M.view({batch_size, num_heads, seqlen_q}), seed_t, offset_t, softmax_fa_t}; + // Note: These are propagated up to the return of mha_fwd(). comments + // represent the assignments at that level + return {out, // output + q_padded, // q_padded + k_padded, // k_padded + v_padded, // v_padded + M.view({batch_size, num_heads, seqlen_q}), // logsumexp + seed_t, // philox_seed + offset_t, // philox_offset + softmax_fa_t};// debug_attn_mask } std::tuple @@ -635,7 +643,7 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); } else { - dv = at::empty_like(k); + dv = at::empty_like(v); } auto [needs_swa, window_left, window_right] = calculate_swa(window_size_left, diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip index 2d3692d9f98df..d7633521d8eef 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip @@ -127,7 +127,7 @@ aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, return aiter::mha_bwd_args{ // aiter args static_cast(mask.type), - true, // use_asm_v3 + hdim <= 192, // use_asm_v3: ASM v3 only supports head dim <= 192 true, // v3_atomic_fp32 1, // v3_bf16_cvt false, // v3_api_check @@ -375,11 +375,16 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x dv_expanded = dv; } - uint64_t drop_seed = 1, drop_offset = 0; - drop_seed = *philox_seed.data_ptr(); - drop_offset = *philox_offset.data_ptr(); - auto drop_seed_offset = std::make_pair(&drop_seed, &drop_offset); + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + uint64_t* drop_seed, drop_offset; + int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); + std::pair drop_seed_offset = {nullptr,nullptr}; + if(is_dropout) { + drop_seed_offset.first = philox_seed.data_ptr(); + drop_seed_offset.second = philox_offset.data_ptr(); + } if (seqlen_q > 0) { ck_tile::stream_config stream_config{stream}; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip index 441589e70d763..c5902e33afd36 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip @@ -183,7 +183,6 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); const auto sizes = q.sizes(); - const int batch_size = sizes[0]; int seqlen_q = sizes[1]; int num_heads = sizes[2]; @@ -232,7 +231,6 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); - at::Tensor q_padded, k_padded, v_padded; if (head_size % 8 != 0) { q_padded = at::pad(temp_q, {0, 8 - head_size % 8}); @@ -245,7 +243,6 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x v_padded = v; } - at::Tensor out; if (out_.has_value()) { out = out_.value(); @@ -272,7 +269,6 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x auto opts = q.options(); bool has_lse = true; bool has_dropout = p_dropout > 0.0f; - at::Tensor softmax_lse; // TODO - check gradient, only training require lse softmax_lse = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); @@ -283,47 +279,46 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x p = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(at::kByte)); } else { - p = at::empty({ 0 }, opts); + p = at::empty({ 0, 0, 0, 0 }, opts.dtype(at::kByte)); } + uint64_t drop_seed = 1, drop_offset = 0; + at::Tensor seed_t, offset_t; int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size(); - auto rng_state = at::empty({2}, opts.dtype(at::kLong)); - auto rng_state_ptr = reinterpret_cast(rng_state.data_ptr()); - - at::Tensor seed_t, offset_t; + // rng_state is used to pass philox params to CK in a type it likes i.e. uint64 + auto rng_state_options = at::TensorOptions().dtype(at::kUInt64).device(at::kCUDA); + auto rng_state = at::zeros({2}, rng_state_options.dtype(at::kUInt64)); + auto _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA)); if (p_dropout > 0.0) { + auto gen = at::get_generator_or_default( gen_, at::cuda::detail::getDefaultCUDAGenerator()); + // See Note [Acquire lock when using random generators] std::lock_guard lock(gen->mutex_); - auto philox_args = gen->philox_cuda_state(counter_offset); + std::tie(drop_seed, drop_offset) = at::cuda::philox::unpack(philox_args); + rng_state[0] = *(reinterpret_cast(&drop_seed)); + rng_state[1] = *(reinterpret_cast(&drop_offset)); - - hipLaunchKernelGGL( - flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::cuda::getCurrentCUDAStream(), philox_args, rng_state_ptr); - seed_t = at::scalar_tensor(at::Scalar(static_cast(rng_state_ptr[0])), at::dtype(at::kLong)); - offset_t = at::scalar_tensor(at::Scalar(static_cast(rng_state_ptr[1])), at::dtype(at::kLong)); - } - else - { + } else { seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); } - std::optional attn_bias; + std::pair drop_seed_offset; if( attn_bias_.has_value()) { attn_bias = attn_bias_; } - if (seqlen_k > 0) { - auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1); - auto stream = at::cuda::getCurrentCUDAStream().stream(); + drop_seed_offset.first = rng_state[0].data_ptr(); + drop_seed_offset.second = rng_state[1].data_ptr(); + auto stream = at::cuda::getCurrentHIPStream().stream(); ck_tile::stream_config stream_config{stream}; auto traits = @@ -364,12 +359,28 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x out.zero_(); softmax_lse.fill_(std::numeric_limits::infinity()); } - if (seqlenq_ngroups_swapped) { out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); } - return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p}; + // Before returning populate seed_t and offset_t if dropout is used + if(p_dropout > 0.0) + { + seed_t = at::scalar_tensor(at::Scalar(reinterpret_cast(*(drop_seed_offset.first))), + at::dtype(at::kUInt64).device(at::kCUDA)); + offset_t = at::scalar_tensor(at::Scalar(reinterpret_cast(*(drop_seed_offset.second))), + at::dtype(at::kUInt64).device(at::kCUDA)); + } + // Note: These are propagated up to the return of mha_fwd(). comments + // represent the assignments at that level + return {out, // output + q_padded, // q_padded + k_padded, // k_padded + v_padded, // v_padded + softmax_lse, // logsumexp + seed_t, // philox_seed + offset_t, // philox_offset + p}; // debug_attn_mask } } //namespace pytorch_flash diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip index 3031965833608..896afc7320e96 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_bwd_ck.hip @@ -401,8 +401,10 @@ mha_varlen_bwd_ck(const at::Tensor &dout, // total_q x num_hea uint64_t drop_seed = 1, drop_offset = 0; - drop_seed = *philox_seed.data_ptr(); - drop_offset = *philox_offset.data_ptr(); + if (is_dropout) { + drop_seed = *reinterpret_cast(philox_seed.data_ptr()); + drop_offset = *reinterpret_cast(philox_offset.data_ptr()); + } auto drop_seed_offset = std::make_pair(&drop_seed, &drop_offset); if (max_seqlen_q > 0) { diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip index 2a9d4899e8236..39ce82ef1e700 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip @@ -73,7 +73,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, ck_tile::index_t nhead_stride_k = k.stride(1); ck_tile::index_t nhead_stride_v = v.stride(1); ck_tile::index_t nhead_stride_o = out.stride(1); - ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0; + ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(0) : 0; ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0; ck_tile::index_t batch_stride_q = 0; @@ -81,7 +81,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, ck_tile::index_t batch_stride_v = 0; ck_tile::index_t batch_stride_o = 0; - ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0; + ck_tile::index_t batch_stride_lse = 0; ck_tile::index_t batch_stride_randval = 0; void *attn_bias_ptr = nullptr; @@ -287,7 +287,7 @@ mha_varlen_fwd_ck(const at::Tensor &q, // total_q x num_heads at::Tensor softmax_lse; // TODO - check gradient, only training require lse - softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + softmax_lse = at::empty({num_heads, total_q}, opts.dtype(at::kFloat)); at::Tensor p; if (return_dropout_randval) { diff --git a/aten/src/ATen/native/transformers/sdp_utils.h b/aten/src/ATen/native/transformers/sdp_utils.h index 6623e5332ca12..18f13781944f8 100644 --- a/aten/src/ATen/native/transformers/sdp_utils.h +++ b/aten/src/ATen/native/transformers/sdp_utils.h @@ -37,9 +37,7 @@ void alloc_with_matching_layout( ordered_strides[dim_idx] = current_stride; current_stride *= shape[dim_idx]; } - output = at::empty(at::IntArrayRef(shape), q.options()) - .as_strided( - at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), 0); + output = at::empty_strided(at::IntArrayRef(shape), at::IntArrayRef(ordered_strides), q.options()); } void permute_to_matching_layout(const Tensor& output, Tensor& grad_output) { diff --git a/aten/src/ATen/native/transformers/xpu/sdp_utils.cpp b/aten/src/ATen/native/transformers/xpu/sdp_utils.cpp index a69ecf3274d2a..b5abfa27e8baa 100644 --- a/aten/src/ATen/native/transformers/xpu/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/xpu/sdp_utils.cpp @@ -29,7 +29,8 @@ bool check_flash_attention_hardware_support( c10::array_of( sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc, sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg, - sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21); + sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21, + sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g31); auto* device_prop = at::xpu::getCurrentDeviceProperties(); auto device_architecture = device_prop->architecture; @@ -39,7 +40,7 @@ bool check_flash_attention_hardware_support( device_architecture) == supported_architectures.end()) { if (debug) { TORCH_WARN( - "XPU device architecture does not support flash attention. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21."); + "XPU device architecture does not support flash attention. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21, intel_gpu_bmg_g31."); } return false; } @@ -80,9 +81,12 @@ inline bool check_flash_attention_datatype( inline bool check_flash_attention_head_dim_size( sdp_params const& params, bool debug) { - const int query_size_last = params.query.size(3); - const int key_size_last = params.key.size(3); - const int value_size_last = params.value.size(3); + // Use sym_size to preserve symbolic shapes during tracing. + // Using concrete .size() would materialize symbolic dimensions into static + // guards, preventing dynamic shape generalization across recompilations. + const auto query_size_last = params.query.sym_size(-1); + const auto key_size_last = params.key.sym_size(-1); + const auto value_size_last = params.value.sym_size(-1); const bool head_dims_equal = (query_size_last == key_size_last) && (query_size_last == value_size_last); @@ -101,7 +105,7 @@ inline bool check_flash_attention_head_dim_size( return false; } - constexpr auto max_supported_headdim = 192; + const auto max_supported_headdim = c10::SymInt(192); if (query_size_last > max_supported_headdim) { if (debug) { TORCH_WARN( diff --git a/aten/src/ATen/native/vulkan/ops/Mm.cpp b/aten/src/ATen/native/vulkan/ops/Mm.cpp index aeb25b56e60e9..a5a0d93cccc3d 100644 --- a/aten/src/ATen/native/vulkan/ops/Mm.cpp +++ b/aten/src/ATen/native/vulkan/ops/Mm.cpp @@ -631,6 +631,7 @@ Tensor run_quantized_addmm_context( return output; } else { std::vector shape; + shape.reserve(static_cast(std::max(0, input_arg.dim()))); for (const auto i : c10::irange(input_arg.dim() - 1)) { shape.emplace_back(input_arg.size(i)); } @@ -751,6 +752,7 @@ Tensor run_addmm_context( return output; } else { std::vector shape; + shape.reserve(static_cast(std::max(0, input_arg.dim()))); for (const auto i : c10::irange(input_arg.dim() - 1)) { shape.emplace_back(input_arg.size(i)); } diff --git a/aten/src/ATen/native/vulkan/ops/NativeLayerNorm.cpp b/aten/src/ATen/native/vulkan/ops/NativeLayerNorm.cpp index 1ec6957162cbb..338363c49cdbe 100644 --- a/aten/src/ATen/native/vulkan/ops/NativeLayerNorm.cpp +++ b/aten/src/ATen/native/vulkan/ops/NativeLayerNorm.cpp @@ -76,6 +76,7 @@ std::tuple native_layer_norm( const Tensor bias = bias_opt->is_vulkan() ? *bias_opt : bias_opt->vulkan(); std::vector dims_to_reduce; + dims_to_reduce.reserve(normalized_shape.size()); for (const auto i : c10::irange(normalized_shape.size())) { dims_to_reduce.push_back(input_arg.dim() - i - 1); } diff --git a/aten/src/ATen/native/vulkan/ops/Repeat.cpp b/aten/src/ATen/native/vulkan/ops/Repeat.cpp index 94c29b55d0571..0c08d7f7c26dc 100644 --- a/aten/src/ATen/native/vulkan/ops/Repeat.cpp +++ b/aten/src/ATen/native/vulkan/ops/Repeat.cpp @@ -37,6 +37,8 @@ Tensor repeat(const Tensor& self, const IntArrayRef repeats) { std::vector tensor_seq_to_concat; for (const auto i : c10::irange(out_ndims)) { + tensor_seq_to_concat.reserve( + static_cast(std::max(0, repeats[i]))); for (const auto k : c10::irange(repeats[i])) { (void)k; tensor_seq_to_concat.emplace_back(tensor_to_repeat.clone()); diff --git a/aten/src/ATen/nnapi/CMakeLists.txt b/aten/src/ATen/nnapi/CMakeLists.txt index 9e367028e8a2d..4ba4048220370 100644 --- a/aten/src/ATen/nnapi/CMakeLists.txt +++ b/aten/src/ATen/nnapi/CMakeLists.txt @@ -3,7 +3,7 @@ if(PYTORCH_NNAPI_STANDALONE) cmake_minimum_required(VERSION 3.10 FATAL_ERROR) project(pytorch_nnapi) - set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard whose features are requested to build this target.") + set(CMAKE_CXX_STANDARD 20 CACHE STRING "The C++ standard whose features are requested to build this target.") find_package(Torch REQUIRED) set(NNAPI_SRCS diff --git a/aten/src/ATen/nnapi/codegen.py b/aten/src/ATen/nnapi/codegen.py index 57b1e3a696fa8..f8bc0741fa751 100755 --- a/aten/src/ATen/nnapi/codegen.py +++ b/aten/src/ATen/nnapi/codegen.py @@ -36,156 +36,156 @@ NNAPI_FUNCTIONS = [ - ("int", "ANeuralNetworks_getDeviceCount", "uint32_t* numDevices"), # noqa: B950 + ("int", "ANeuralNetworks_getDeviceCount", "uint32_t* numDevices"), ( "int", "ANeuralNetworks_getDevice", "uint32_t devIndex, ANeuralNetworksDevice** device", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksDevice_getName", "const ANeuralNetworksDevice* device, const char** name", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksDevice_getVersion", "const ANeuralNetworksDevice* device, const char** version", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksDevice_getFeatureLevel", "const ANeuralNetworksDevice* device, int64_t* featureLevel", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksModel_getSupportedOperationsForDevices", " const ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, bool* supportedOps", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksCompilation_createForDevices", - "ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation", # noqa: B950 + "ANeuralNetworksModel* model, const ANeuralNetworksDevice* const* devices, uint32_t numDevices, ANeuralNetworksCompilation** compilation", ), ( "int", "ANeuralNetworksExecution_compute", "ANeuralNetworksExecution* execution", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksMemory_createFromFd", "size_t size, int protect, int fd, size_t offset, ANeuralNetworksMemory** memory", - ), # noqa: B950 + ), ( "void", "ANeuralNetworksMemory_free", "ANeuralNetworksMemory* memory", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksModel_create", "ANeuralNetworksModel** model", - ), # noqa: B950 - ("void", "ANeuralNetworksModel_free", "ANeuralNetworksModel* model"), # noqa: B950 - ("int", "ANeuralNetworksModel_finish", "ANeuralNetworksModel* model"), # noqa: B950 + ), + ("void", "ANeuralNetworksModel_free", "ANeuralNetworksModel* model"), + ("int", "ANeuralNetworksModel_finish", "ANeuralNetworksModel* model"), ( "int", "ANeuralNetworksModel_addOperand", "ANeuralNetworksModel* model, const ANeuralNetworksOperandType* type", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksModel_setOperandValue", "ANeuralNetworksModel* model, int32_t index, const void* buffer, size_t length", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksModel_setOperandValueFromMemory", "ANeuralNetworksModel* model, int32_t index, const ANeuralNetworksMemory* memory, size_t offset, size_t length", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksModel_addOperation", - "ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs", # noqa: B950 + "ANeuralNetworksModel* model, ANeuralNetworksOperationType type, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs", ), ( "int", "ANeuralNetworksModel_identifyInputsAndOutputs", "ANeuralNetworksModel* model, uint32_t inputCount, const uint32_t* inputs, uint32_t outputCount, const uint32_t* outputs", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksModel_relaxComputationFloat32toFloat16", "ANeuralNetworksModel* model, bool allow", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksCompilation_create", "ANeuralNetworksModel* model, ANeuralNetworksCompilation** compilation", - ), # noqa: B950 + ), ( "void", "ANeuralNetworksCompilation_free", "ANeuralNetworksCompilation* compilation", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksCompilation_setPreference", "ANeuralNetworksCompilation* compilation, int32_t preference", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksCompilation_finish", "ANeuralNetworksCompilation* compilation", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksExecution_create", "ANeuralNetworksCompilation* compilation, ANeuralNetworksExecution** execution", - ), # noqa: B950 + ), ( "void", "ANeuralNetworksExecution_free", "ANeuralNetworksExecution* execution", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksExecution_setInput", - "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length", # noqa: B950 + "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const void* buffer, size_t length", ), ( "int", "ANeuralNetworksExecution_setInputFromMemory", - "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length", # noqa: B950 + "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length", ), ( "int", "ANeuralNetworksExecution_setOutput", "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, void* buffer, size_t length", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksExecution_setOutputFromMemory", - "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length", # noqa: B950 + "ANeuralNetworksExecution* execution, int32_t index, const ANeuralNetworksOperandType* type, const ANeuralNetworksMemory* memory, size_t offset, size_t length", ), ( "int", "ANeuralNetworksExecution_startCompute", "ANeuralNetworksExecution* execution, ANeuralNetworksEvent** event", - ), # noqa: B950 - ("int", "ANeuralNetworksEvent_wait", "ANeuralNetworksEvent* event"), # noqa: B950 - ("void", "ANeuralNetworksEvent_free", "ANeuralNetworksEvent* event"), # noqa: B950 + ), + ("int", "ANeuralNetworksEvent_wait", "ANeuralNetworksEvent* event"), + ("void", "ANeuralNetworksEvent_free", "ANeuralNetworksEvent* event"), ( "int", "ANeuralNetworksExecution_getOutputOperandRank", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* rank", - ), # noqa: B950 + ), ( "int", "ANeuralNetworksExecution_getOutputOperandDimensions", "ANeuralNetworksExecution* execution, int32_t index, uint32_t* dimensions", - ), # noqa: B950 + ), ] diff --git a/aten/src/ATen/test/native_test.cpp b/aten/src/ATen/test/native_test.cpp index 5f27ce4886e47..ff4507941a648 100644 --- a/aten/src/ATen/test/native_test.cpp +++ b/aten/src/ATen/test/native_test.cpp @@ -253,6 +253,9 @@ void test(TensorOptions T, TensorOptions AccT) { } TEST(TestNative, NativeTestCPU) { +#if defined(__aarch64__) + GTEST_SKIP() << "Known failure on AArch64 (label is too far / stack test mismatch)"; +#endif manual_seed(123); test(at::device(kCPU).dtype(kFloat), diff --git a/aten/src/ATen/test/test_install/CMakeLists.txt b/aten/src/ATen/test/test_install/CMakeLists.txt index c2d54feeb37d0..10d6805a092cb 100644 --- a/aten/src/ATen/test/test_install/CMakeLists.txt +++ b/aten/src/ATen/test/test_install/CMakeLists.txt @@ -2,9 +2,6 @@ cmake_minimum_required(VERSION 3.10) find_package(ATen REQUIRED) include_directories(${ATEN_INCLUDE_DIR}) -# C++17 -if(not MSVC) - set(CMAKE_CXX_FLAGS "--std=c++17 ${CMAKE_CXX_FLAGS}") -endif() +set(CMAKE_CXX_STANDARD 20 CACHE STRING "The C++ standard whose features are requested to build this target.") add_executable(main main.cpp) target_link_libraries(main ${ATEN_LIBRARIES}) diff --git a/aten/src/ATen/test/vitals.cpp b/aten/src/ATen/test/vitals.cpp deleted file mode 100644 index eaf1cc152bc37..0000000000000 --- a/aten/src/ATen/test/vitals.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include - -using namespace at::vitals; -using ::testing::HasSubstr; - -TEST(Vitals, Basic) { - std::stringstream buffer; - - std::streambuf* sbuf = std::cout.rdbuf(); - std::cout.rdbuf(buffer.rdbuf()); - { - c10::utils::set_env("TORCH_VITAL", "1"); - TORCH_VITAL_DEFINE(Testing); - TORCH_VITAL(Testing, Attribute0) << 1; - TORCH_VITAL(Testing, Attribute1) << '1'; - TORCH_VITAL(Testing, Attribute2) << 1.0f; - TORCH_VITAL(Testing, Attribute3) << 1.0; - auto t = at::ones({1, 1}); - TORCH_VITAL(Testing, Attribute4) << t; - } - std::cout.rdbuf(sbuf); - - auto s = buffer.str(); - ASSERT_THAT(s, HasSubstr("Testing.Attribute0\t\t 1")); - ASSERT_THAT(s, HasSubstr("Testing.Attribute1\t\t 1")); - ASSERT_THAT(s, HasSubstr("Testing.Attribute2\t\t 1")); - ASSERT_THAT(s, HasSubstr("Testing.Attribute3\t\t 1")); - ASSERT_THAT(s, HasSubstr("Testing.Attribute4\t\t 1")); -} - -TEST(Vitals, MultiString) { - std::stringstream buffer; - - std::streambuf* sbuf = std::cout.rdbuf(); - std::cout.rdbuf(buffer.rdbuf()); - { - c10::utils::set_env("TORCH_VITAL", "1"); - TORCH_VITAL_DEFINE(Testing); - TORCH_VITAL(Testing, Attribute0) << 1 << " of " << 2; - TORCH_VITAL(Testing, Attribute1) << 1; - TORCH_VITAL(Testing, Attribute1) << " of "; - TORCH_VITAL(Testing, Attribute1) << 2; - } - std::cout.rdbuf(sbuf); - - auto s = buffer.str(); - ASSERT_THAT(s, HasSubstr("Testing.Attribute0\t\t 1 of 2")); - ASSERT_THAT(s, HasSubstr("Testing.Attribute1\t\t 1 of 2")); -} - -TEST(Vitals, OnAndOff) { - for (const auto i : c10::irange(2)) { - std::stringstream buffer; - - std::streambuf* sbuf = std::cout.rdbuf(); - std::cout.rdbuf(buffer.rdbuf()); - { - c10::utils::set_env("TORCH_VITAL", i ? "1" : "0"); - TORCH_VITAL_DEFINE(Testing); - TORCH_VITAL(Testing, Attribute0) << 1; - } - std::cout.rdbuf(sbuf); - - auto s = buffer.str(); - auto f = s.find("Testing.Attribute0\t\t 1"); - if (i) { - ASSERT_TRUE(f != std::string::npos); - } else { - ASSERT_TRUE(f == std::string::npos); - } - } -} - -TEST(Vitals, APIVitals) { - std::stringstream buffer; - bool rvalue = false; - std::streambuf* sbuf = std::cout.rdbuf(); - std::cout.rdbuf(buffer.rdbuf()); - { - c10::utils::set_env("TORCH_VITAL", "1"); - APIVitals api_vitals; - rvalue = api_vitals.setVital("TestingSetVital", "TestAttr", "TestValue"); - } - std::cout.rdbuf(sbuf); - - auto s = buffer.str(); - ASSERT_TRUE(rvalue); - ASSERT_THAT(s, HasSubstr("TestingSetVital.TestAttr\t\t TestValue")); -} diff --git a/aten/src/ATen/xpu/CachingHostAllocator.cpp b/aten/src/ATen/xpu/CachingHostAllocator.cpp index 2be5eaf81bb03..c6a1a807801e2 100644 --- a/aten/src/ATen/xpu/CachingHostAllocator.cpp +++ b/aten/src/ATen/xpu/CachingHostAllocator.cpp @@ -43,8 +43,7 @@ struct XPUCachingHostAllocatorImpl } bool stream_is_capturing(XPUStream s) const override { - return c10::xpu::CaptureStatus(s.queue().ext_oneapi_get_state()) == - c10::xpu::CaptureStatus::Recording; + return s.is_capturing(); } }; diff --git a/aten/src/ATen/xpu/XPUGeneratorImpl.cpp b/aten/src/ATen/xpu/XPUGeneratorImpl.cpp index 09c84eeef9cea..133a9e4728dd0 100644 --- a/aten/src/ATen/xpu/XPUGeneratorImpl.cpp +++ b/aten/src/ATen/xpu/XPUGeneratorImpl.cpp @@ -97,7 +97,7 @@ void XPUGeneratorState::increase(uint64_t increment) { } // State can be used by multiple graph -void XPUGeneratorState::register_graph(xpu::XPUGraph* graph) { +void XPUGeneratorState::register_graph(xpu::XPUGraphImpl* graph) { // Ensures that the RNG state is not currently being captured. at::xpu::assertNotCapturing( "Cannot register the state during capturing stage."); @@ -113,7 +113,7 @@ void XPUGeneratorState::register_graph(xpu::XPUGraph* graph) { } } -void XPUGeneratorState::unregister_graph(xpu::XPUGraph* graph) { +void XPUGeneratorState::unregister_graph(xpu::XPUGraphImpl* graph) { TORCH_CHECK( registered_graphs_.find(graph) != registered_graphs_.end(), "The graph should be registered to the state"); @@ -280,12 +280,12 @@ uint64_t XPUGeneratorImpl::philox_offset_per_thread() const { } } -void XPUGeneratorImpl::register_graph(xpu::XPUGraph* graph) { +void XPUGeneratorImpl::register_graph(xpu::XPUGraphImpl* graph) { graph->register_generator_state(state_); state_->register_graph(graph); } -void XPUGeneratorImpl::unregister_graph(xpu::XPUGraph* graph) { +void XPUGeneratorImpl::unregister_graph(xpu::XPUGraphImpl* graph) { state_->unregister_graph(graph); } diff --git a/aten/src/ATen/xpu/XPUGeneratorImpl.h b/aten/src/ATen/xpu/XPUGeneratorImpl.h index 8ee4967a4f13f..853d1f28f006a 100644 --- a/aten/src/ATen/xpu/XPUGeneratorImpl.h +++ b/aten/src/ATen/xpu/XPUGeneratorImpl.h @@ -8,7 +8,7 @@ namespace at { namespace xpu { -struct XPUGraph; +struct XPUGraphImpl; } struct XPUGeneratorState : public c10::intrusive_ptr_target { @@ -16,7 +16,7 @@ struct XPUGeneratorState : public c10::intrusive_ptr_target { uint64_t philox_offset_per_thread_; uint32_t offset_intragraph_; bool capturing_{}; - std::unordered_set registered_graphs_; + std::unordered_set registered_graphs_; at::TensorBase seed_extragraph_{}; at::TensorBase offset_extragraph_{}; @@ -29,8 +29,8 @@ struct XPUGeneratorState : public c10::intrusive_ptr_target { offset_intragraph_(offset_intragraph) {} void increase(uint64_t increment); - void register_graph(xpu::XPUGraph* graph); - void unregister_graph(xpu::XPUGraph* graph); + void register_graph(xpu::XPUGraphImpl* graph); + void unregister_graph(xpu::XPUGraphImpl* graph); void capture_prologue(); uint64_t capture_epilogue(); void replay_prologue(uint64_t wholegraph_increment); @@ -62,8 +62,8 @@ struct TORCH_XPU_API XPUGeneratorImpl : public GeneratorImpl { void set_philox_offset_per_thread(uint64_t offset); uint64_t philox_offset_per_thread() const; - void register_graph(xpu::XPUGraph* graph); - void unregister_graph(xpu::XPUGraph* graph); + void register_graph(xpu::XPUGraphImpl* graph); + void unregister_graph(xpu::XPUGraphImpl* graph); PhiloxXpuState philox_xpu_state(uint64_t increment); std::pair philox_engine_inputs(uint64_t increment); static c10::DeviceType device_type(); diff --git a/aten/src/ATen/xpu/XPUGraph.cpp b/aten/src/ATen/xpu/XPUGraph.cpp index 3d3a22e5b7219..d6d3cad494a04 100644 --- a/aten/src/ATen/xpu/XPUGraph.cpp +++ b/aten/src/ATen/xpu/XPUGraph.cpp @@ -1,6 +1,5 @@ #include #include -#include #include #include @@ -16,23 +15,43 @@ MempoolId_t graph_pool_handle() { return c10::xpu::MemPool::graph_pool_handle(); } -XPUGraph::XPUGraph(bool keep_graph) +XPUGraphImpl::XPUGraphImpl(const GraphImplArgs& args) : capture_stream_(at::xpu::getCurrentXPUStream()), - keep_graph_(keep_graph) {} + keep_graph_(args.keep_graph) {} -void XPUGraph::register_generator_state( +void XPUGraphImpl::register_generator_state( c10::intrusive_ptr state) { captured_generator_states_[std::move(state)] = 0; } -void XPUGraph::register_generator_state(const at::Generator& generator) { +void XPUGraphImpl::register_generator_state(const at::Generator& generator) { c10::intrusive_ptr xpu_gen = dynamic_intrusive_pointer_cast( generator.getIntrusivePtr()); xpu_gen->register_graph(this); } -void XPUGraph::capture_begin(MempoolId_t pool) { +void XPUGraphImpl::capture_begin( + MempoolId_t pool /*={0,0}*/, + GraphCaptureMode capture_mode) { + switch (capture_mode) { + case GraphCaptureMode::Default: + break; + + case GraphCaptureMode::Global: + case GraphCaptureMode::ThreadLocal: + case GraphCaptureMode::Relaxed: + TORCH_WARN( + "XPUGraph currently only support default GraphCaptureMode. " + "Falling back to default capture behavior."); + break; + + default: + TORCH_CHECK( + false, + "Invalid GraphCaptureMode value: ", + static_cast(capture_mode)); + } TORCH_CHECK( !has_graph_exec_, "This XPUGraph instance already owns a captured graph. " @@ -90,7 +109,7 @@ void XPUGraph::capture_begin(MempoolId_t pool) { capture_stream_.queue().ext_oneapi_get_state() == queue_state::recording); } -void XPUGraph::capture_end() { +void XPUGraphImpl::capture_end() { auto stream = at::xpu::getCurrentXPUStream(); TORCH_CHECK( @@ -125,7 +144,7 @@ void XPUGraph::capture_end() { } } -void XPUGraph::instantiate() { +void XPUGraphImpl::instantiate() { TORCH_CHECK( capture_ended_, "capture_end() must have been called before calling instantiate"); @@ -141,7 +160,7 @@ void XPUGraph::instantiate() { has_graph_exec_ = true; } -void XPUGraph::replay() { +void XPUGraphImpl::replay() { TORCH_CHECK( capture_ended_, "Called XPUGraph::replay without a preceding successful capture."); @@ -164,7 +183,7 @@ void XPUGraph::replay() { queue.ext_oneapi_graph(*graph_exec_); } -void XPUGraph::reset() { +void XPUGraphImpl::reset() { if (capture_ended_) { c10::xpu::XPUCachingAllocator::releasePool(capture_dev_, mempool_id_); at::getHostAllocator(at::kXPU)->release_pool(mempool_id_); @@ -180,11 +199,11 @@ void XPUGraph::reset() { } } -void XPUGraph::enable_debug_mode() { +void XPUGraphImpl::enable_debug_mode() { _xpu_graphs_debug = true; } -void XPUGraph::debug_dump(const std::string& debug_path) { +void XPUGraphImpl::debug_dump(const std::string& debug_path) { TORCH_CHECK( debug_path.size() >= 4 && debug_path.substr(debug_path.size() - 4) == ".dot", @@ -207,7 +226,7 @@ void XPUGraph::debug_dump(const std::string& debug_path) { } } -xpuGraph_t* XPUGraph::raw_xpu_graph() { +xpuGraph_t* XPUGraphImpl::raw_xpu_graph() { TORCH_CHECK( keep_graph_, "You cannot access the raw xpuGraph_t instance unless XPUGraph was initialized with keep_graph=true"); @@ -217,7 +236,7 @@ xpuGraph_t* XPUGraph::raw_xpu_graph() { return graph_.get(); } -xpuGraphExec_t* XPUGraph::raw_xpu_graph_exec() { +xpuGraphExec_t* XPUGraphImpl::raw_xpu_graph_exec() { TORCH_CHECK( has_graph_exec_, "You cannot access the raw xpuGraphExec_t instance until instantiate() has been called"); @@ -226,14 +245,14 @@ xpuGraphExec_t* XPUGraph::raw_xpu_graph_exec() { // Returns an id another graph's capture_begin can use to share the same memory // pool as this graph. -MempoolId_t XPUGraph::pool() { +MempoolId_t XPUGraphImpl::pool() const { TORCH_CHECK( capture_ended_, "Called XPUGraph::pool() without a preceding successful capture."); return mempool_id_; } -XPUGraph::~XPUGraph() { +XPUGraphImpl::~XPUGraphImpl() { for (auto& [generator_state, wholegraph_increments] : captured_generator_states_) { generator_state->unregister_graph(this); @@ -241,4 +260,6 @@ XPUGraph::~XPUGraph() { reset(); } +REGISTER_GRAPH_IMPL(XPU, XPUGraphImpl) + } // namespace at::xpu diff --git a/aten/src/ATen/xpu/XPUGraph.h b/aten/src/ATen/xpu/XPUGraph.h index 0f23a1e093ee0..fd60bb5b6ba9e 100644 --- a/aten/src/ATen/xpu/XPUGraph.h +++ b/aten/src/ATen/xpu/XPUGraph.h @@ -1,18 +1,15 @@ #pragma once #include +#include +#include #include #include #include #include #include -namespace at { - -struct Generator; -struct XPUGeneratorState; - -namespace xpu { +namespace at::xpu { TORCH_XPU_API MempoolId_t graph_pool_handle(); @@ -21,21 +18,26 @@ using xpuGraph_t = sycl::ext::oneapi::experimental::command_graph< using xpuGraphExec_t = sycl::ext::oneapi::experimental::command_graph< sycl::ext::oneapi::experimental::graph_state::executable>; -struct TORCH_XPU_API XPUGraph { - XPUGraph(bool keep_graph = false); - ~XPUGraph(); +struct TORCH_XPU_API XPUGraphImpl : public at::GraphImplInterface { + XPUGraphImpl(const GraphImplArgs& args = {}); + ~XPUGraphImpl() override; + + C10_DISABLE_COPY_AND_ASSIGN(XPUGraphImpl); void register_generator_state( c10::intrusive_ptr state); void register_generator_state(const at::Generator& generator); - void capture_begin(MempoolId_t pool = {0, 0}); - void capture_end(); - void instantiate(); - void replay(); - void reset(); - MempoolId_t pool(); - void enable_debug_mode(); - void debug_dump(const std::string& debug_path); + + void capture_begin( + MempoolId_t pool = {0, 0}, + GraphCaptureMode capture_mode = GraphCaptureMode::Default) override; + void capture_end() override; + void instantiate() override; + void replay() override; + void reset() override; + MempoolId_t pool() const override; + void enable_debug_mode() override; + void debug_dump(const std::string& debug_path) override; xpuGraph_t* raw_xpu_graph(); xpuGraphExec_t* raw_xpu_graph_exec(); @@ -59,5 +61,58 @@ struct TORCH_XPU_API XPUGraph { bool keep_graph_; }; -} // namespace xpu -} // namespace at +struct TORCH_XPU_API XPUGraph { + XPUGraph(bool keep_graph = false) { + GraphImplArgs args; + args.keep_graph = keep_graph; + impl_ = std::make_unique(args); + } + ~XPUGraph() = default; + + C10_DISABLE_COPY_AND_ASSIGN(XPUGraph); + XPUGraph(XPUGraph&& other) = delete; + XPUGraph& operator=(XPUGraph&& other) = delete; + + void register_generator_state( + c10::intrusive_ptr state) { + impl_->register_generator_state(state); + } + void register_generator_state(const at::Generator& generator) { + impl_->register_generator_state(generator); + } + void capture_begin(MempoolId_t pool = {0, 0}) { + impl_->capture_begin(pool); + } + void capture_end() { + impl_->capture_end(); + } + void instantiate() { + impl_->instantiate(); + } + void replay() { + impl_->replay(); + } + void reset() { + impl_->reset(); + } + MempoolId_t pool() const { + return impl_->pool(); + } + void enable_debug_mode() { + impl_->enable_debug_mode(); + } + void debug_dump(const std::string& debug_path) { + impl_->debug_dump(debug_path); + } + xpuGraph_t* raw_xpu_graph() { + return impl_->raw_xpu_graph(); + } + xpuGraphExec_t* raw_xpu_graph_exec() { + return impl_->raw_xpu_graph_exec(); + } + + private: + std::unique_ptr impl_; +}; + +} // namespace at::xpu diff --git a/aten/tools/run_tests.sh b/aten/tools/run_tests.sh index 93e48aec90851..2c57d428f58ea 100755 --- a/aten/tools/run_tests.sh +++ b/aten/tools/run_tests.sh @@ -55,6 +55,7 @@ run_if_exists cuda_allocator_test if [ "$VALGRIND" == "ON" ]; then # NB: As these tests are invoked by valgrind, let's leave them for now as it's # unclear if valgrind -> python -> gtest would work + export LD_LIBRARY_PATH="${CPP_TESTS_DIR}:${LD_LIBRARY_PATH}" valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 "${CPP_TESTS_DIR}/basic" --gtest_filter='-*CUDA' if [[ -x ${CPP_TESTS_DIR}/tensor_interop_test ]]; then valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 "${CPP_TESTS_DIR}/tensor_interop_test" diff --git a/aten/tools/valgrind.sup b/aten/tools/valgrind.sup index 585487c4d2be2..c86c374bd1307 100644 --- a/aten/tools/valgrind.sup +++ b/aten/tools/valgrind.sup @@ -10,6 +10,16 @@ ... } +{ + + Memcheck:Addr8 + fun:strncmp + fun:is_dst + ... + fun:decompose_rpath + ... +} + { ignore_empty_generic_uninitialised_conditional_jump Memcheck:Cond diff --git a/benchmarks/diffusion/compile_benchmark.py b/benchmarks/diffusion/compile_benchmark.py index 934636da28840..c78ab839b00d2 100644 --- a/benchmarks/diffusion/compile_benchmark.py +++ b/benchmarks/diffusion/compile_benchmark.py @@ -107,8 +107,8 @@ def wan_benchmark(mode, backend="inductor"): prompt = ( "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in " "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." - ) # noqa: B950 - negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" # noqa: B950 + ) + negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" block_types = (diffusers.models.transformers.transformer_wan.WanTransformerBlock,) compile_model(pipe.transformer, mode, block_types, backend) @@ -139,7 +139,7 @@ def ltx_benchmark(mode, backend="inductor"): height = 512 - (512 % pipe.vae_spatial_compression_ratio) width = 704 - (704 % pipe.vae_spatial_compression_ratio) - prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region." # noqa: B950 + prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region." negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" block_types = ( diff --git a/benchmarks/distributed/ddp/benchmark.py b/benchmarks/distributed/ddp/benchmark.py index c8c2112e27925..fdd77c65b9bf5 100644 --- a/benchmarks/distributed/ddp/benchmark.py +++ b/benchmarks/distributed/ddp/benchmark.py @@ -108,7 +108,7 @@ def append_benchmark(prefix, ranks, opts=None): def local_print(msg): if dist.get_rank() == 0: - print(msg, end="", flush=True) # noqa: E999 + print(msg, end="", flush=True) def print_header(): local_print("\n") diff --git a/benchmarks/distributed/ddp/diff.py b/benchmarks/distributed/ddp/diff.py index cfeb90cd6fa25..23fdb56283966 100644 --- a/benchmarks/distributed/ddp/diff.py +++ b/benchmarks/distributed/ddp/diff.py @@ -49,9 +49,9 @@ def main(): # Print header print() - print(f"{'':>10s}", end="") # noqa: E999 + print(f"{'':>10s}", end="") for _ in [75, 95]: - print(f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end="") # noqa: E999 + print(f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end="") print() # Print measurements @@ -66,7 +66,7 @@ def main(): ngpus = len(xa["ranks"]) ma = sorted(xa["measurements"]) mb = sorted(xb["measurements"]) - print(f"{ngpus:>4d} GPUs:", end="") # noqa: E999 + print(f"{ngpus:>4d} GPUs:", end="") for p in [75, 95]: va = np.percentile(ma, p) vb = np.percentile(mb, p) @@ -75,7 +75,7 @@ def main(): print( f" p{p:02d}: {vb:8.3f}s {int(batch_size / vb):7d}/s {delta:+8.1f}%", end="", - ) # noqa: E999 + ) print() print() diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv index 87dd88078f222..24ac655e319ae 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_huggingface_inference.csv @@ -122,4 +122,4 @@ meta-llama/Llama-3.2-1B,pass,0 -openai/whisper-tiny,pass,5 +openai/whisper-tiny,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv index 0799b804bbf8c..8fd420c017d12 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv @@ -50,47 +50,47 @@ densenet121,pass,0 -detectron2_fasterrcnn_r_101_c4,fail_accuracy,42 +detectron2_fasterrcnn_r_101_c4,fail_accuracy,43 -detectron2_fasterrcnn_r_101_dc5,fail_accuracy,42 +detectron2_fasterrcnn_r_101_dc5,fail_accuracy,43 -detectron2_fasterrcnn_r_101_fpn,fail_accuracy,46 +detectron2_fasterrcnn_r_101_fpn,fail_accuracy,47 -detectron2_fasterrcnn_r_50_c4,fail_accuracy,42 +detectron2_fasterrcnn_r_50_c4,fail_accuracy,43 -detectron2_fasterrcnn_r_50_dc5,fail_accuracy,42 +detectron2_fasterrcnn_r_50_dc5,fail_accuracy,43 -detectron2_fasterrcnn_r_50_fpn,fail_accuracy,46 +detectron2_fasterrcnn_r_50_fpn,fail_accuracy,47 -detectron2_fcos_r_50_fpn,fail_accuracy,24 +detectron2_fcos_r_50_fpn,fail_accuracy,25 -detectron2_maskrcnn_r_101_c4,fail_accuracy,56 +detectron2_maskrcnn_r_101_c4,fail_accuracy,57 -detectron2_maskrcnn_r_101_fpn,fail_accuracy,62 +detectron2_maskrcnn_r_101_fpn,fail_accuracy,63 -detectron2_maskrcnn_r_50_c4,fail_accuracy,56 +detectron2_maskrcnn_r_50_c4,fail_accuracy,57 -detectron2_maskrcnn_r_50_fpn,fail_accuracy,62 +detectron2_maskrcnn_r_50_fpn,fail_accuracy,63 @@ -242,7 +242,7 @@ vgg16,pass,0 -vision_maskrcnn,fail_accuracy,30 +vision_maskrcnn,fail_accuracy,31 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv index e4aabce10466d..2712e2d1a5012 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_huggingface_inference.csv @@ -30,7 +30,7 @@ DistilBertForMaskedLM,pass,0 -DistillGPT2,pass,2 +DistillGPT2,pass,3 @@ -50,7 +50,7 @@ LayoutLMForMaskedLM,pass,0 -M2M100ForConditionalGeneration,pass,7 +M2M100ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv index 87dd88078f222..24ac655e319ae 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_huggingface_inference.csv @@ -122,4 +122,4 @@ meta-llama/Llama-3.2-1B,pass,0 -openai/whisper-tiny,pass,5 +openai/whisper-tiny,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv index e4aabce10466d..2712e2d1a5012 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_huggingface_inference.csv @@ -30,7 +30,7 @@ DistilBertForMaskedLM,pass,0 -DistillGPT2,pass,2 +DistillGPT2,pass,3 @@ -50,7 +50,7 @@ LayoutLMForMaskedLM,pass,0 -M2M100ForConditionalGeneration,pass,7 +M2M100ForConditionalGeneration,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv index 87dd88078f222..24ac655e319ae 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_huggingface_inference.csv @@ -122,4 +122,4 @@ meta-llama/Llama-3.2-1B,pass,0 -openai/whisper-tiny,pass,5 +openai/whisper-tiny,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 72fd3af5beeda..3ef5716e9306a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 -detectron2_fcos_r_50_fpn,pass,22 +detectron2_fcos_r_50_fpn,pass,23 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv index e4c20cfebf465..9dd6d40930389 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv @@ -190,7 +190,7 @@ vgg16,pass,6 -vision_maskrcnn,fail_to_run,37 +vision_maskrcnn,fail_to_run,38 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv index 87dd88078f222..24ac655e319ae 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_huggingface_inference.csv @@ -122,4 +122,4 @@ meta-llama/Llama-3.2-1B,pass,0 -openai/whisper-tiny,pass,5 +openai/whisper-tiny,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv index 87dd88078f222..ac90b8870040e 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_inference.csv @@ -30,7 +30,7 @@ DistilBertForMaskedLM,pass,0 -DistillGPT2,pass,2 +DistillGPT2,pass,3 @@ -50,7 +50,7 @@ LayoutLMForMaskedLM,pass,0 -M2M100ForConditionalGeneration,pass,7 +M2M100ForConditionalGeneration,pass,0 @@ -122,4 +122,4 @@ meta-llama/Llama-3.2-1B,pass,0 -openai/whisper-tiny,pass,5 +openai/whisper-tiny,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv index 5ca03b5ecf9fb..452bf8707fbfb 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_huggingface_training.csv @@ -10,7 +10,7 @@ AllenaiLongformerBase,pass,9 -BartForCausalLM,pass,6 +BartForCausalLM,pass,7 @@ -30,7 +30,7 @@ DistilBertForMaskedLM,pass,5 -DistillGPT2,pass,7 +DistillGPT2,pass,8 @@ -50,11 +50,11 @@ LayoutLMForMaskedLM,pass,5 -M2M100ForConditionalGeneration,pass,11 +M2M100ForConditionalGeneration,pass,4 -MBartForCausalLM,pass,6 +MBartForCausalLM,pass,7 @@ -70,15 +70,15 @@ MobileBertForMaskedLM,pass,3 -OPTForCausalLM,pass,8 +OPTForCausalLM,pass,7 -PLBartForCausalLM,pass,6 +PLBartForCausalLM,pass,7 -PegasusForCausalLM,pass,6 +PegasusForCausalLM,pass,7 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index bc98e325ec784..d04a273049dbf 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 -detectron2_fcos_r_50_fpn,pass,22 +detectron2_fcos_r_50_fpn,pass,23 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv index e4c20cfebf465..90c1f0972a3b2 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv @@ -170,6 +170,10 @@ soft_actor_critic,pass,6 +speech_transformer,pass,16 + + + squeezenet1_1,pass,6 @@ -190,7 +194,7 @@ vgg16,pass,6 -vision_maskrcnn,fail_to_run,37 +vision_maskrcnn,fail_to_run,38 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv index 87dd88078f222..24ac655e319ae 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_huggingface_inference.csv @@ -122,4 +122,4 @@ meta-llama/Llama-3.2-1B,pass,0 -openai/whisper-tiny,pass,5 +openai/whisper-tiny,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv index 87dd88078f222..24ac655e319ae 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_huggingface_inference.csv @@ -122,4 +122,4 @@ meta-llama/Llama-3.2-1B,pass,0 -openai/whisper-tiny,pass,5 +openai/whisper-tiny,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv index 87dd88078f222..24ac655e319ae 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_huggingface_inference.csv @@ -122,4 +122,4 @@ meta-llama/Llama-3.2-1B,pass,0 -openai/whisper-tiny,pass,5 +openai/whisper-tiny,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv index d1150f849e2ee..0c22866fffccb 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_inference.csv @@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 -detectron2_fcos_r_50_fpn,pass,22 +detectron2_fcos_r_50_fpn,pass,23 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv index 877277a5aa192..a4d42c7d7723a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_inductor_torchbench_training.csv @@ -90,7 +90,7 @@ microbench_unbacked_tolist_sum,pass,9 -mnasnet1_0,pass,0 +mnasnet1_0,pass,7 @@ -102,7 +102,7 @@ mobilenet_v2_quantized_qat,eager_fail_to_run,0 -mobilenet_v3_large,pass,0 +mobilenet_v3_large,pass,7 @@ -141,7 +141,7 @@ pytorch_unet,pass_due_to_skip,7 -resnet152,eager_two_runs_differ,0 +resnet152,eager_two_runs_differ,7 @@ -193,7 +193,7 @@ vgg16,pass,6 -vision_maskrcnn,fail_to_run,37 +vision_maskrcnn,fail_to_run,38 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv index 87dd88078f222..24ac655e319ae 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_huggingface_inference.csv @@ -122,4 +122,4 @@ meta-llama/Llama-3.2-1B,pass,0 -openai/whisper-tiny,pass,5 +openai/whisper-tiny,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv index 6f65795e3f04e..69c864d8853d7 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/rocm/inductor_huggingface_inference.csv @@ -122,7 +122,7 @@ google/gemma-3-4b-it,pass_due_to_skip,0 -openai/whisper-tiny,pass,6 +openai/whisper-tiny,pass,0 diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index bd067c7985a76..95d7b75b5b93c 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -1953,10 +1953,6 @@ def disable_cudagraph_models(self): def guard_on_nn_module_models(self): return set() - @property - def inline_inbuilt_nn_modules_models(self): - return set() - def get_tolerance_and_cosine_flag(self, is_training, current_device, name): raise NotImplementedError @@ -2028,17 +2024,18 @@ def cast_based_on_args(self, model, example_inputs): return model, example_inputs - def validate_model(self, model, example_inputs): + def validate_model(self, name, model, example_inputs): """ Runs the eager model with example inputs to ensure that eager passes. """ model = self.deepcopy_model(model) example_inputs = clone_inputs(example_inputs) model, example_inputs = self.cast_based_on_args(model, example_inputs) - try: - self.model_iter_fn(model, example_inputs) - except Exception as e: - raise RuntimeError("Eager run failed") from e + with self.pick_grad(name, self.args.training): + try: + self.model_iter_fn(model, example_inputs) + except Exception as e: + raise RuntimeError("Eager run failed") from e def maybe_cast(self, model, example_inputs): model, example_inputs = self.cast_based_on_args(model, example_inputs) @@ -2169,6 +2166,38 @@ def deepcopy_and_maybe_parallelize(self, model): ) return model + def _write_accuracy_row(self, status, dynamo_start_stats, tag): + """ + Shared CSV + signpost writer for accuracy checks. + """ + headers = ["dev", "name", "batch_size", "accuracy"] + fields = [current_device, current_name, current_batch_size, status] + + if tag is not None: + headers.insert(3, "tag") + fields.insert(3, tag) + + o_headers = list(headers) + o_fields = list(fields) + + dynamo_stats = get_dynamo_stats() + dynamo_stats.subtract(dynamo_start_stats) + for k, v in dynamo_stats.items(): + headers.append(k) + fields.append(v) + + total_wall_time = output_signpost( + dict(zip(o_headers, o_fields)), + self.args, + self.suite_name, + ) + headers.append("compilation_latency") + fields.append(total_wall_time) + write_outputs(output_filename, headers, fields) + + if self.args.print_compilation_time: + print(f"Compilation time (from dynamo_timed): {total_wall_time}") + def check_accuracy( self, name, model, example_inputs, optimize_ctx, experiment, tag ): @@ -2191,34 +2220,7 @@ def record_status(accuracy_status, dynamo_start_stats): ): accuracy_status = "pass" - headers = ["dev", "name", "batch_size", "accuracy"] - fields = [current_device, current_name, current_batch_size, accuracy_status] - - if tag is not None: - headers.insert(3, "tag") - fields.insert(3, tag) - - o_headers = list(headers) - o_fields = list(fields) - - dynamo_stats = get_dynamo_stats() - dynamo_stats.subtract(dynamo_start_stats) - for k, v in dynamo_stats.items(): - headers.append(k) - fields.append(v) - - total_wall_time = output_signpost( - dict(zip(o_headers, o_fields)), - self.args, - self.suite_name, - ) - headers.append("compilation_latency") - fields.append(total_wall_time) - write_outputs(output_filename, headers, fields) - - if self.args.print_compilation_time: - print(f"Compilation time (from dynamo_timed): {total_wall_time}") - + self._write_accuracy_row(accuracy_status, dynamo_start_stats, tag) return accuracy_status if name in self.skip_accuracy_checks_large_models_dashboard: @@ -2507,6 +2509,141 @@ def record_status(accuracy_status, dynamo_start_stats): return record_status(accuracy_status, dynamo_start_stats=start_stats) + def check_batch_invariance( + self, name, model, example_inputs, optimize_ctx, experiment, tag + ): + """ + Batch invariance check: run the compiled forward at N, N/2, ..., 1 and + verify each output matches the reference sliced to that range bitwise. + + Always exercises forward-only, even under --training: batch invariance + is a property of the forward pass; backward and optimizer step + aggregate over the batch and are not batch-invariant by construction. + Models with batch-dependent forward ops (e.g. BatchNorm in train mode) + will still fail here -- that's inherent, not a harness bug. + """ + start_stats = get_dynamo_stats() + + def record_status(status, dynamo_start_stats): + self._write_accuracy_row(status, dynamo_start_stats, tag) + return status + + if name in self.skip_accuracy_checks_large_models_dashboard: + return record_status("pass_due_to_skip", dynamo_start_stats=start_stats) + + if ( + name in self.skip_accuracy_check_as_eager_non_deterministic + or name in self.non_deterministic_models + ): + return record_status("pass_due_to_skip", dynamo_start_stats=start_stats) + + full_batch = current_batch_size + if full_batch is None or full_batch < 2: + return record_status("pass_due_to_skip", dynamo_start_stats=start_stats) + + # If no input tensor has batch as its first dim, the slicer below is a + # no-op and the comparison would trivially pass without actually + # exercising batch invariance. Skip rather than report a misleading pass. + if not any( + isinstance(x, torch.Tensor) and x.dim() > 0 and x.shape[0] == full_batch + for x in pytree.tree_leaves(example_inputs) + ): + return record_status("pass_due_to_skip", dynamo_start_stats=start_stats) + + def make_slicer(target): + def slicer(x): + if x.dim() > 0 and x.shape[0] == full_batch: + return x[:target].contiguous() + return x + + return slicer + + def run_fresh(inputs): + # Rebuild model for every run so parameter-mutating side effects + # (BN running stats, caches, etc.) from a prior run don't bleed + # into the next comparison. Force eval mode regardless of + # --training: dropout and train-mode BN are batch-size-dependent + # by construction. Forward-only: backward/optimizer aggregate + # over the batch and are not batch-invariant by construction. + reset_rng_state() + torch._dynamo.reset() + torch._dynamo.utils.counters.clear() + model_copy = self.deepcopy_and_maybe_parallelize(model) + model_copy.eval() + try: + optimized_iter_fn = optimize_ctx(self.forward_pass) + return self.run_n_iterations(model_copy, inputs, optimized_iter_fn) + finally: + del model_copy + empty_gpu_cache(current_device) + + with self.pick_grad(name, self.args.training): + model, example_inputs = self.maybe_cast(model, example_inputs) + + try: + reference = run_fresh(clone_inputs(example_inputs)) + except Exception as e: + log.exception("") + status = ( + "OOM" + if isinstance(e, torch.cuda.OutOfMemoryError) + else "fail_to_run" + ) + return record_status(status, dynamo_start_stats=start_stats) + + size = full_batch // 2 + while size >= 1: + slicer = make_slicer(size) + sliced_inputs = tree_map_only( + torch.Tensor, slicer, clone_inputs(example_inputs) + ) + + try: + out = run_fresh(sliced_inputs) + except Exception as e: + log.exception("") + status = ( + "OOM" + if isinstance(e, torch.cuda.OutOfMemoryError) + else f"fail_to_run_at_batch_{size}" + ) + return record_status(status, dynamo_start_stats=start_stats) + + reference_sliced = tree_map_only(torch.Tensor, slicer, reference) + + # Only compare batch-first output tensors. Aggregated outputs + # (e.g. HuggingFace's MaskedLMOutput.loss) don't have a batch + # dim and legitimately differ between batch sizes; comparing + # them would produce misleading failures. + def keep_batch_first(x): + return x if x.dim() > 0 and x.shape[0] == size else None + + ref_for_cmp = tree_map_only( + torch.Tensor, keep_batch_first, reference_sliced + ) + out_for_cmp = tree_map_only(torch.Tensor, keep_batch_first, out) + + try: + is_same = bitwise_same( + ref_for_cmp, out_for_cmp, equal_nan=self.equal_nan + ) + except Exception: + is_same = False + + if not is_same: + if self.args.skip_accuracy_check: + return record_status( + "pass_due_to_skip", dynamo_start_stats=start_stats + ) + return record_status( + f"fail_batch_invariance_at_{size}", + dynamo_start_stats=start_stats, + ) + + size //= 2 + + return record_status("pass", dynamo_start_stats=start_stats) + def check_tolerance( self, name, model, example_inputs, optimize_ctx, base_device="cpu" ): @@ -3025,9 +3162,14 @@ def run_one_model( start_stats = get_dynamo_stats() if self.args.accuracy: - status = self.check_accuracy( - name, model, example_inputs, optimize_ctx, experiment, tag - ) + if self.args.batch_invariant: + status = self.check_batch_invariance( + name, model, example_inputs, optimize_ctx, experiment, tag + ) + else: + status = self.check_accuracy( + name, model, example_inputs, optimize_ctx, experiment, tag + ) print(status) if status == "fail_accuracy" and self.args.minify: self.minify_model( @@ -3208,6 +3350,17 @@ def parse_args(args=None): parser.add_argument( "--freezing", action="store_true", help="turn on freezing", default=False ) + parser.add_argument( + "--deterministic", + action="store_true", + help="Enable deterministic mode (torch.use_deterministic_algorithms, cudnn.deterministic, etc.)", + ) + parser.add_argument( + "--batch-invariant", + action="store_true", + help="Check batch invariance: compare compiled forward outputs at full vs half batch " + "size and verify they match bitwise. Only valid with --accuracy.", + ) parser.add_argument( "--inductor-config", "-c", @@ -3277,7 +3430,7 @@ def get_example_inputs(self): parser.add_argument( "--distributed-master-port", default="6789", - help="Port to bind for for torch.distributed. Use the default unless it's conflicting with another user", + help="Port to bind for torch.distributed. Use the default unless it's conflicting with another user", ) parser.add_argument( "--dynamic-shapes", @@ -3719,7 +3872,10 @@ def get_example_inputs(self): run_mode_group.add_argument( "--inference", action="store_true", help="Performs inference" ) - return parser.parse_args(args) + parsed = parser.parse_args(args) + if parsed.batch_invariant and not parsed.accuracy: + parser.error("--batch-invariant requires --accuracy") + return parsed def process_caching_precompile(): @@ -3871,7 +4027,8 @@ def print_comparison(all_results): ) print(f"{'=' * 80}", flush=True) - # Build base command, stripping --compare-backed-unbacked and --only + value + # Build base command, stripping --compare-backed-unbacked, --only, --filter and their values + # Handles both space-separated (--filter VALUE) and equals-separated (--filter=VALUE) forms filtered = [] skip_next = False for a in sys.argv: @@ -3880,8 +4037,13 @@ def print_comparison(all_results): if skip_next: skip_next = False continue - if a == "--only": - skip_next = True + if a == "--only" or a.startswith("--only="): + if "=" not in a: + skip_next = True + continue + if a == "--filter" or a.startswith("--filter="): + if "=" not in a: + skip_next = True continue filtered.append(a) base_cmd = [sys.executable, "-B"] + filtered @@ -3988,7 +4150,7 @@ def write_csv_when_exception(args, name: str, status: str, device=None): write_outputs(output_filename, headers, row) -def setup_determinism_for_accuracy_test(args): +def setup_determinism(args): if args.only is not None and args.only not in { "alexnet", "Background_Matting", @@ -4024,6 +4186,21 @@ def setup_determinism_for_accuracy_test(args): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False torch.backends.mkldnn.deterministic = True + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(False) + patch_torch_manual_seed() + + +def setup_batch_invariant(args): + if not torch.cuda.is_available(): + return + setup_determinism(args) + inductor_config.triton.cudagraphs = False + torch.backends.cuda.preferred_blas_library("cublaslt") + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (False, False) + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (False, False) def run(runner, args, original_dir=None): @@ -4084,12 +4261,22 @@ def run(runner, args, original_dir=None): "DLRM+DDP is unsupported as it requires sharding the embedding layer separately from DDP" ) return sys.exit(-1) + if args.deterministic and not args.accuracy: + setup_determinism(args) + if args.accuracy: # Use small batch size. We use >1 batch size to ensure we test # batch_norm type of operators that work on batch dims. # TODO - Go through the failures for batch size = 2 if args.batch_size is None: - if runner.suite_name == "huggingface": + if args.batch_invariant: + if runner.suite_name == "huggingface": + args.batch_size = 8 + elif runner.suite_name == "torchbench": + args.batch_size = 8 + else: + args.batch_size = 16 + elif runner.suite_name == "huggingface": args.batch_size = 1 elif runner.suite_name == "torchbench": args.batch_size = 4 @@ -4107,9 +4294,10 @@ def run(runner, args, original_dir=None): args.use_eval_mode = True inductor_config.fallback_random = True - setup_determinism_for_accuracy_test(args) + setup_determinism(args) + if args.batch_invariant: + setup_batch_invariant(args) - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" if args.only is not None and args.only in { "nvidia_deeprecommender", }: @@ -4117,13 +4305,6 @@ def run(runner, args, original_dir=None): torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False - torch.backends.cudnn.allow_tf32 = False - torch.backends.cuda.matmul.allow_tf32 = False - torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(False) - - # Remove randomness when torch manual seed is called - patch_torch_manual_seed() - # Some models e.g. yolov3 assert batch size on n_gpus if "CUDA_VISIBLE_DEVICES" not in os.environ and not args.multiprocess: args.device_index = "0" @@ -4335,8 +4516,8 @@ def run(runner, args, original_dir=None): ) model_iter_fn = baseline_ctx(runner.model_iter_fn) - # needed to avoid error that causes inconsistent timing due to: - # Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards + # needed to avoid CUDAGraph fast-path warning / inconsistent timing when prior + # outputs still require backward (see torch._inductor.cudagraph_trees) def model_iter_fn_and_mark_step(*args, **kwargs): torch.compiler.cudagraph_mark_step_begin() model_iter_fn(*args, **kwargs) @@ -4657,22 +4838,17 @@ def detect_and_mark_batch(t, use_unbacked=False): if name in runner.guard_on_nn_module_models: guard_ctx = torch._dynamo.config.patch(guard_nn_modules=True) - inline_ctx = contextlib.nullcontext() - if name in runner.inline_inbuilt_nn_modules_models: - inline_ctx = torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) - with guard_ctx: - with inline_ctx: - runner.run_one_model( - name, - model, - example_inputs, - optimize_ctx, - experiment, - explain=args.explain, - tag=args.tag, - batch_size=batch_size if args.dynamic_batch_only else None, - ) + runner.run_one_model( + name, + model, + example_inputs, + optimize_ctx, + experiment, + explain=args.explain, + tag=args.tag, + batch_size=batch_size if args.dynamic_batch_only else None, + ) if args.generate_aot_autograd_stats: stats_file = output_filename.split(".csv")[0] + "_stats.csv" write_outputs( diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index b91777f70ed98..d62d99bfc4ffb 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -489,7 +489,7 @@ def generate(self, _, example_inputs, collect_outputs=True): else: model.eval() - self.validate_model(model, example_inputs) + self.validate_model(model_name, model, example_inputs) return device, model_name, model, example_inputs, batch_size def iter_model_names(self, args): @@ -552,7 +552,7 @@ def compute_loss(self, pred): return pred[0] def forward_pass(self, mod, inputs, collect_outputs=True): - with self.autocast(**self.autocast_arg): + with torch.no_grad(), self.autocast(**self.autocast_arg): res = mod(**inputs) return res.logits if self.hf_llm else res diff --git a/benchmarks/dynamo/huggingface.yaml b/benchmarks/dynamo/huggingface.yaml index a4af57174ea19..b22988c4ba9c9 100644 --- a/benchmarks/dynamo/huggingface.yaml +++ b/benchmarks/dynamo/huggingface.yaml @@ -74,7 +74,6 @@ batch_size: tolerance: higher_training: - MT5ForConditionalGeneration - - BertForMaskedLM higher_max_autotune_training: [] diff --git a/benchmarks/dynamo/launch_compile_op_numerics.py b/benchmarks/dynamo/launch_compile_op_numerics.py new file mode 100644 index 0000000000000..808e8414f265f --- /dev/null +++ b/benchmarks/dynamo/launch_compile_op_numerics.py @@ -0,0 +1,974 @@ +import argparse +import asyncio +import datetime +import functools +import hashlib +import itertools +import json +import logging +import os +import re +import shlex +import sys +from dataclasses import dataclass, field +from pathlib import Path + + +log = logging.getLogger(__name__) + +try: + import torch +except ImportError: + torch = None + +PYTORCH_NIGHTLY_CUDA_VERSIONS = ["12.6", "12.8", "13.0"] +PYTORCH_CUDA_VERSIONS = { + "2.9.1": ["12.6", "12.8", "13.0"], + "2.9.0": ["12.6", "12.8", "13.0"], + "2.8.0": ["12.6", "12.8", "12.9"], + "2.7.1": ["11.8", "12.6", "12.8"], + "2.7.0": ["11.8", "12.6", "12.8"], + "2.6.0": ["11.8", "12.4", "12.6"], + "2.5.1": ["11.8", "12.1", "12.4"], + "2.5.0": ["11.8", "12.1", "12.4"], + "2.4.1": ["11.8", "12.1", "12.4"], + "2.4.0": ["11.8", "12.1", "12.4"], + "2.3.1": ["11.8", "12.1"], + "2.3.0": ["11.8", "12.1"], + "2.2.2": ["11.8", "12.1"], + "2.2.1": ["11.8", "12.1"], + "2.2.0": ["11.8", "12.1"], + "2.1.2": ["11.8", "12.1"], + "2.1.1": ["11.8", "12.1"], + "2.1.0": ["11.8", "12.1"], + "2.0.1": ["11.8"], + "2.0.0": ["11.8"], +} + +ENABLED_CONFIGS = [ + # ("git:5f09e6a6c93e0b5bf75b635cddc03b85bbe85938", "12.8"), + ("nightly", "13.0"), + # ("2.9.1", "12.8"), + # ("2.7.1", "12.8"), +] + +PYTHON_VERSION = "3.11" + + +@dataclass +class Mode: + compile: bool + backend: str | None = None + mode: str | None = None + options: dict | None = None + env: dict = field(default_factory=dict) + + def __call__(self, fn): + if self.compile: + kwargs = {} + if self.backend is not None: + kwargs["backend"] = self.backend + if self.mode is not None: + kwargs["mode"] = self.mode + if self.options is not None: + options = { + k: v + for k, v in self.options.items() + if k in torch._inductor.list_options() + } + kwargs["options"] = options + return torch.compile(fn, **kwargs) + assert self.backend is None # noqa: S101 + assert self.mode is None # noqa: S101 + assert self.options is None # noqa: S101 + return fn + + +MODES = { + "eager": Mode(compile=False), + "decomp": Mode(compile=True, backend="aot_eager_decomp_partition"), + "compile_numerics": Mode( + compile=True, + options={ + "emulate_precision_casts": True, + "use_fast_math": False, + "emulate_division_rounding": True, + "eager_numerics.division_rounding": True, + "eager_numerics.disable_ftz": True, + }, + ), + "compile": Mode(compile=True), +} + +DTYPES = ["float32", "float16", "bfloat16"] + + +async def run(command, capture_output=True, capture_stderr=True, stderr=None, **kwargs): + log.info("Running command: %s", command) + proc = await asyncio.create_subprocess_exec( + *shlex.split(command), + stdout=asyncio.subprocess.PIPE if capture_output else None, + stderr=asyncio.subprocess.STDOUT if capture_stderr else stderr, + stdin=asyncio.subprocess.DEVNULL, + **kwargs, + ) + if capture_output: + result = await proc.communicate() + result = result[0].decode("utf-8") + log.info("Command result: %s", result) + log.info("Return code: %s", proc.returncode) + return result + else: + await proc.wait() + log.info("Return code: %s", proc.returncode) + return + + +async def copy_file_to_remote(hostname, source, destination): + await run(f"scp {source} {hostname}:{destination}") + + +async def execute_on_remote(hostname, command): + await run(f"ssh {hostname} {command}", capture_output=False) + + +async def copy_results_from_remote(hostname, run_id): + await run(f"mkdir -p results_{run_id}/logs") + await run(f"scp {hostname}:/workspace/result_*.jsonl results_{run_id}/") + await run( + f"scp {hostname}:/workspace/logs.tar.gz results_{run_id}/logs_{hostname}.tar.gz" + ) + await run(f"mkdir -p results_{run_id}/logs") + await run( + f"tar xzf results_{run_id}/logs_{hostname}.tar.gz -C results_{run_id}/logs --strip 2" + ) + await run( + "bash -c 'for f in *.trace; do cd $f; uvx tlparse --no-browser *.log & cd ..; done; wait'", + cwd=f"results_{run_id}/logs", + ) + await run(f"rm results_{run_id}/logs_{hostname}.tar.gz") + await run(f"bash -c 'rm results_{run_id}/logs/*.trace/*.log'") + await run(f"bash -c 'rm results_{run_id}/logs/*.trace/tl_out/raw*'") + + +async def do_numerics_test(args, hostname, gpu, run_id): + await copy_file_to_remote(hostname, __file__, "/workspace/run.py") + await execute_on_remote(hostname, f"python3 /workspace/run.py --worker --gpu {gpu}") + await copy_results_from_remote(hostname, run_id) + + +async def reserve_gpu(gpu): + stdout = await run( + f"gpu-dev reserve -g 1 -h 24 -t {gpu} --ignore-no-persist --disk none" + ) + # the output has a string like this: SSH Command: ssh gpu-dev-1db65ec7 + match = re.search(r"SSH Command: ssh (.*)", stdout) + hostname = match.group(1) + log.info( + "Hostname: %s, Reservation ID: %s", + hostname, + hostname.removeprefix("gpu-dev-"), + ) + return hostname, hostname.removeprefix("gpu-dev-") + + +async def cancel_reservation(reservation_id): + await run(f"gpu-dev cancel {reservation_id}") + + +async def run_on_gpu(gpu, command): + log.info("Reserving GPU %s", gpu) + hostname, reservation_id = await reserve_gpu(gpu) + try: + log.info("Running command on %s", hostname) + await command(hostname=hostname) + except Exception: + log.exception("Error running command on %s", gpu) + finally: + log.info("Cancelling reservation %s", reservation_id) + await cancel_reservation(reservation_id) + + +async def launcher(args): + run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + await asyncio.gather( + *[ + run_on_gpu( + gpu, + functools.partial(do_numerics_test, args=args, gpu=gpu, run_id=run_id), + ) + for gpu in args.gpu + ] + ) + await run(f"{sys.executable} {__file__} --html {run_id}") + + +def format_results_to_html(run_id): + results_dir = Path(f"results_{run_id}") + for f in results_dir.glob("*.jsonl"): + html_file = f.with_suffix(".html") + lines = f.read_text().splitlines(keepends=True) + data = [json.loads(line) for line in lines] + KEYS = data[0].keys() + print(KEYS) + FILTER_KEYS = [ + "gpu", + "pytorch_version", + "cuda_version", + "pytorch_mode", + "is_golden", + "data_type", + "function", + "pass_type", + "category", + "match_full", + "match_normal", + ] + KEY_VALUES = {k: sorted({line[k] for line in data}) for k in FILTER_KEYS} + with html_file.open("w") as html_f: + html_f.write("\n") + html_f.write(f"

Results for {f.stem}

\n") + for key, values in KEY_VALUES.items(): + html_f.write( + f"
{key} ({len(values)})\n" + ) + html_f.write( + f"Enable all\n" + ) + html_f.write( + f"Disable all\n" + ) + for value in values: + safe_value = str(value).replace(".", "_") + html_f.write( + f"\n" + ) + html_f.write( + f"\n" + ) + html_f.write("
") + # one more filter: show "first" mode only + html_f.write("
First mode that fails only\n") + html_f.write( + "\n" + ) + html_f.write("
\n") + # column filter: + html_f.write( + "
Column filter\n" + ) + html_f.write( + "Enable all\n" + ) + html_f.write( + "Disable all\n" + ) + DEFAULT_HIDDEN_KEYS = [ + "gpu", + "pytorch_version", + "cuda_version", + "is_golden", + ] + for key in KEYS: + html_f.write( + f"\n" + ) + html_f.write( + f"\n" + ) + html_f.write("
\n") + html_f.write("\n") + html_f.write("\n") + for key in KEYS: + html_f.write(f"\n") + + html_f.write("\n") + html_f.write("\n") + html_f.write("\n") + for line in lines: + data = json.loads(line) + classes = [ + f"visible-{k}-{str(data[k]).replace('.', '_')}" for k in KEY_VALUES + ] + html_f.write(f"\n") + for key in KEYS: + html_f.write(f"\n") + html_f.write( + f"" + ) + html_f.write( + f"\n" + ) + html_f.write("\n") + if data["mismatch_sample"]: + html_f.write( + f"\n" + ) + html_f.write(f"\n") + html_f.write("\n") + + html_f.write("
{key}logtlparse
\n") + if key == "mismatch_sample": + html_f.write( + f"\n" + ) + html_f.write( + f"\n" + ) + else: + html_f.write(f"{data[key]}\n") + html_f.write("logtrace
\n")
+                    html_f.write("\n")
+                    html_f.write("\n")
+                    html_f.write("\n")
+                    html_f.write("\n")
+                    html_f.write("\n")
+                    html_f.write("\n")
+                    html_f.write("\n")
+                    html_f.write("\n")
+                    for sample in data["mismatch_sample"]:
+                        html_f.write("\n")
+                        html_f.write(f"\n")
+                        html_f.write(
+                            f"\n"
+                        )
+                        html_f.write(f"\n")
+                        html_f.write(f"\n")
+                        html_f.write(f"\n")
+                        html_f.write("\n")
+                    html_f.write("
posinputoutputgoldenrel_err
{sample['pos']}{', '.join(map(str, sample['input']))}{sample['output']}{sample['golden']}{sample['rel_err']}
\n") + html_f.write("
\n") + html_f.write("\n") + html_f.write("\n") + + +global_utils_lock = asyncio.Lock() +cuda_download_lock = asyncio.Lock() +venv_creation_lock = asyncio.Lock() +pytorch_build_lock = asyncio.Lock() +numerics_test_lock = asyncio.Lock() + + +async def create_results(args, config): + pytorch_version, cuda_version = config + async with global_utils_lock: + await ensure_global_utils() + async with cuda_download_lock: + await ensure_cuda(args, cuda_version) + async with venv_creation_lock: + venv = await create_venv(args, config) + async with pytorch_build_lock: + await maybe_build_pytorch(args, config, venv) + async with numerics_test_lock: + await run_numerics_test(args, config, venv) + + +async def ensure_global_utils(): + if os.environ.get("HAS_GLOBAL_UTILS", "0") == "1": + return + await run("mkdir global_bin") + os.environ["PATH"] = f"/workspace/global_bin:{os.environ.get('PATH', '')}" + + os.environ["UV_INSTALL_DIR"] = "/workspace/global_bin" + os.environ["UV_CACHE_DIR"] = "/workspace/.cache/uv" + await run("wget --no-verbose https://astral.sh/uv/install.sh") + await run("bash install.sh") + await run("rm install.sh") + + await run( + "wget --no-verbose https://github.com/seeraven/gitcache/releases/download/v1.0.29/gitcache_v1.0.29_Ubuntu22.04_x86_64" + ) + await run("mv gitcache_v1.0.29_Ubuntu22.04_x86_64 global_bin/gitcache") + await run("chmod +x global_bin/gitcache") + which_gitcache = await run("which gitcache") + await run(f"ln -s {which_gitcache.strip()} global_bin/git") + os.environ["GITCACHE_DIR"] = "/workspace/.cache/gitcache" + + await run("sudo apt-get update") + await run("sudo apt-get install -y ccache git-lfs") + os.environ["CMAKE_C_COMPILER_LAUNCHER"] = "ccache" + os.environ["CMAKE_CXX_COMPILER_LAUNCHER"] = "ccache" + os.environ["CMAKE_CUDA_COMPILER_LAUNCHER"] = "ccache" + os.environ["CCACHE_DIR"] = "/workspace/.cache/ccache" + os.environ["CCACHE_NOHASHDIR"] = "1" + os.environ["CCACHE_BASEDIR"] = "/workspace" + os.environ["CMAKE_BUILD_PARALLEL_LEVEL"] = "20" + os.environ["MAX_JOBS"] = "8" + os.environ["MAKEFLAGS"] = "-j20" + os.environ["NINJAFLAGS"] = "-j20" + + os.environ["HAS_GLOBAL_UTILS"] = "1" + + +async def ensure_cuda(args, cuda_version): + if os.path.exists(f"/workspace/cuda-{cuda_version}"): + return + await run("sudo apt-get update") + await run( + f"sudo apt-get install -y cuda-toolkit-{cuda_version.split('.')[0]}-{cuda_version.split('.')[1]}" + ) + + +async def create_venv(args, config): + pytorch_version, cuda_version = config + env = os.environ.copy() + # cuda_home = f"/usr/local/cuda-{cuda_version}" + cuda_home = ( + f"/usr/local/cuda-{cuda_version.split('.')[0]}.{cuda_version.split('.')[1]}" + ) + env["CUDA_HOME"] = cuda_home + env["CUDA_ROOT"] = cuda_home + env["PATH"] = f"{cuda_home}/bin:{env.get('PATH', '')}" + env["LD_LIBRARY_PATH"] = f"{cuda_home}/lib64:{env.get('LD_LIBRARY_PATH', '')}" + env["CUDA_NVCC_EXECUTABLE"] = f"{cuda_home}/bin/nvcc" + env["PYTORCH_NVCC"] = f"{cuda_home}/bin/nvcc" + venv_dir = f"/workspace/venv-pytorch_{pytorch_version}-cuda_{cuda_version}" + await run(f"mkdir -p {venv_dir}") + await run(f"uv venv -p {PYTHON_VERSION} --managed-python", cwd=venv_dir, env=env) + # update env with venv paths + env["VIRTUAL_ENV"] = f"{venv_dir}/.venv" + env["PATH"] = f"{venv_dir}/.venv/bin:{env.get('PATH', '')}" + await run("uv pip install pip numpy", env=env, cwd=venv_dir) + + index_url = f"https://download.pytorch.org/whl/cu{cuda_version.replace('.', '')}" + if pytorch_version == "nightly": + await run( + f"uv pip install --pre torch --index-url {index_url}", env=env, cwd=venv_dir + ) + elif not pytorch_version.startswith("git:"): + await run( + f"uv pip install torch=={pytorch_version} --index-url {index_url}", + env=env, + cwd=venv_dir, + ) + return {"env": env, "cwd": venv_dir} + + +async def maybe_build_pytorch(args, config, venv): + pytorch_version, cuda_version = config + if pytorch_version.startswith("git:"): + raise NotImplementedError("Building PyTorch from git is not supported yet") + return + + +async def run_numerics_test(args, config, venv): + pytorch_version, cuda_version = config + golden_flag = "" + if args.create_golden: + golden_flag = " --create-golden" + for mode in MODES: + cmd = ( + f"python {__file__} --runner --gpu {args.gpu[0]}" + f" --pytorch-version {pytorch_version}" + f" --cuda-version {cuda_version}" + f" --mode {mode}" + f" --golden /workspace/golden{golden_flag}" + ) + await run(cmd, capture_output=False, **venv) + golden_flag = "" + + +def sortable_config_key(config): + pytorch_version, cuda_version = config + cuda_version = tuple(map(int, cuda_version.split("."))) + + if pytorch_version == "nightly": + pytorch_version = (-1,) + elif pytorch_version.startswith("git:"): + pytorch_version = (-2,) + else: + pytorch_version = tuple(map(int, pytorch_version.split("."))) + + return (pytorch_version, cuda_version) + + +async def worker(args): + assert len(args.gpu) == 1 # noqa: S101 + # find golden task + # a config is a tuple (pytorch_version, cuda_version) + # where pytorch_version can either be a version string or nightly or a git hash + golden_config = sorted(ENABLED_CONFIGS, key=sortable_config_key)[-1] + await run("mkdir -p /workspace/logs") + args.create_golden = True + await create_results(args, golden_config) + args.create_golden = False + await asyncio.gather( + *[ + create_results(args, config) + for config in ENABLED_CONFIGS + if config != golden_config + ] + ) + await run("tar czf /workspace/logs.tar.gz /workspace/logs") + + +CATEGORIES = { + "reduction": ["torch.sum", "torch.mean", "torch.softmax"], + "matrix": ["torch.matmul"], + "normalization": [ + "torch.nn.functional.layer_norm", + "torch.nn.functional.rms_norm", + ], + "activation": [ + "torch.nn.functional.relu", + "torch.nn.functional.sigmoid", + "torch.nn.functional.tanh", + "torch.nn.functional.gelu", + "torch.nn.functional.silu", + ], + "elementary": [ + "torch.sin", + "torch.cos", + "torch.tan", + "torch.sigmoid", + "torch.exp", + "torch.exp2", + "torch.log", + "torch.log2", + "torch.sqrt", + "torch.erf", + "torch.reciprocal", + "torch.rsqrt", + ], + "binary": [ + "torch.add", + "torch.sub", + "torch.mul", + "torch.div", + "torch.pow", + ], +} + +PASS_TYPES = { + "reduction": ["fwd", "bwd_0"], + "matrix": ["fwd", "bwd_0", "bwd_1"], + "normalization": ["fwd", "bwd_0", "bwd_1"], + "activation": ["fwd", "bwd_0"], + "elementary": ["fwd", "bwd_0"], + "binary": ["fwd", "bwd_0", "bwd_1"], +} + +PASSES = {} + + +def register_pass(pass_type): + def decorator(func): + PASSES[pass_type] = func + return func + + return decorator + + +@register_pass("fwd") +def pass_fwd(callable): + def wrapper(*args): + return callable(*args) + + return wrapper + + +@register_pass("bwd_0") +def pass_bwd_0(callable): + def wrapper(*args): + args = [arg.detach().requires_grad_(idx == 0) for idx, arg in enumerate(args)] + output = callable(*args) + output.sum().backward() + return args[0].grad.detach() + + return wrapper + + +@register_pass("bwd_1") +def pass_bwd_1(callable): + def wrapper(*args): + args = [arg.detach().requires_grad_(idx == 1) for idx, arg in enumerate(args)] + output = callable(*args) + output.sum().backward() + return args[1].grad.detach() + + return wrapper + + +def filter_nan(a): + return torch.where(torch.isnan(a), 0, a) + + +def mantissa_bits(dtype) -> int: + return {torch.float32: 23, torch.bfloat16: 5, torch.float16: 10}[dtype] + + +def float_to_int_type(dtype): + return { + torch.float32: torch.uint32, + torch.bfloat16: torch.uint16, + torch.float16: torch.uint16, + }[dtype] + + +def generate_test_tensor(dtype, slow=False): + if dtype in [torch.float16, torch.bfloat16]: + return filter_nan( + torch.arange(0, 2**16, dtype=torch.int32, device="cuda") + .to(torch.uint16) + .view(dtype) + ) + if dtype == torch.float32: + if slow: + return filter_nan( + torch.arange(0, 2**32, dtype=torch.int64, device="cuda") + .to(torch.uint32) + .view(dtype) + ) + else: + # E8M5, E5M10, E8M23 + result = [] + for t in [ + generate_test_tensor(torch.bfloat16), + generate_test_tensor(torch.float16), + ]: + orig_dtype = t.dtype + t = t.to(dtype).view(torch.uint32) + missing_mantissa_bits = mantissa_bits(dtype) - mantissa_bits(orig_dtype) + r = torch.randint( + 0, 2**missing_mantissa_bits, t.shape, dtype=t.dtype, device="cuda" + ) + result.append(t.view(dtype)) + result.append((t.view(torch.int32) | r.view(torch.int32)).view(dtype)) + return filter_nan(torch.cat(result)) + + +def make_input(dtype, category): + dtype = getattr(torch, dtype) + from torch.testing import make_tensor + + if category in ["matrix"]: + return make_tensor(1024, 1024, dtype=dtype, device="cuda"), make_tensor( + 1024, 1024, dtype=dtype, device="cuda" + ) + if category == "normalization": + return make_tensor(1024, 1024, dtype=dtype, device="cuda"), make_tensor( + 1024, dtype=dtype, device="cuda" + ) + if category == "reduction": + return (make_tensor(1024, 1024, dtype=dtype, device="cuda"),) + if category == "binary": + test_tensor = generate_test_tensor(dtype) + # randomly permute twice: + return ( + test_tensor[torch.randperm(test_tensor.shape[0])], + test_tensor[torch.randperm(test_tensor.shape[0])], + ) + return (generate_test_tensor(dtype),) + + +def make_function(function, category): + if category == "reduction": + return eval(f"lambda x: {function}(x, dim=-1)") + elif category == "matrix": + return eval(f"lambda x, y: {function}(x, y)") + elif category == "normalization": + return eval(f"lambda x, y: {function}(x, [x.shape[-1]], y, eps=1e-5)") + elif category == "binary": + return eval(f"lambda x, y: {function}(x, y)") + elif category == "activation": + return eval(f"lambda x: {function}(x)") + elif category == "elementary": + return eval(f"lambda x: {function}(x)") + else: + raise ValueError(f"Unknown category: {category}") + + +def rel_err_ulp(a, b, dtype): + return (a - b).abs() / (b.abs() * torch.finfo(dtype).eps + torch.finfo(dtype).tiny) + + +def evaluate_output(input, output, golden_output): + # todo: how to make sure this handles denormals in the input well? + # for now, only handle it if the inputs have the same shape, otherwise assume there are none + # handle more than one input + + input = [i.flatten().float() for i in input if i.shape == output.shape] + + dtype = golden_output.dtype + + output = output.flatten().float() + golden_output = golden_output.flatten().float() + + # we are checking subnormals separate from the rest of the numbers + # we also need to check NaNs and Infs carefully + + subnormal_mask = golden_output.abs() < torch.finfo(dtype).smallest_normal + for i in input: + subnormal_mask |= i.abs() < torch.finfo(dtype).smallest_normal + nan_mask_golden = torch.isnan(golden_output) + golden_output = torch.where(nan_mask_golden, float("nan"), golden_output) + nan_mask_output = torch.isnan(output) + output = torch.where(nan_mask_output, float("nan"), output) + + equal_mask = (output == golden_output) | (nan_mask_golden & nan_mask_output) + equal_subnormal_mask = ( + torch.where(subnormal_mask, 0, output) + == torch.where(subnormal_mask, 0, golden_output) + ) | equal_mask + output_flushed = torch.where(subnormal_mask, 0.0, output) + golden_flushed = torch.where(subnormal_mask, 0.0, golden_output) + output_flushed = torch.where(~torch.isfinite(output_flushed), 0.0, output_flushed) + golden_flushed = torch.where(~torch.isfinite(golden_flushed), 0.0, golden_flushed) + + num_nonequal = (~equal_mask).sum().item() + num_nonequal_subnormal = (~equal_subnormal_mask).sum().item() + + err = rel_err_ulp(output_flushed, golden_flushed, dtype) + max_ulp_to_golden = err.max().item() + avg_ulp_to_golden = err.mean().item() + mismatch_sample = [] + pos = (~equal_mask).nonzero().squeeze(1) + log.info( + "pos.shape: %s, input: %s, golden_output.shape: %s, output.shape: %s", + pos.shape, + [i.shape for i in input], + golden_output.shape, + output.shape, + ) + ordered = torch.argsort(err[pos], descending=True) + random = torch.randperm(pos.shape[0]) + for sampling in [ordered, random]: + for i in range(min(pos.shape[0], 5)): + sample_idx = pos[sampling[i]] + mismatch_sample.append( + { + "pos": sample_idx.item(), + "input": [inp[sample_idx].item() for inp in input], + "output": output[sample_idx].item(), + "golden": golden_output[sample_idx].item(), + "rel_err": err[sample_idx].item(), + } + ) + + # use hashlib.sha256 to hash the tensors + # this i + return { + "normal_hash": hashlib.sha256( + output_flushed.cpu().numpy().tobytes() + ).hexdigest()[:8], + "full_hash": hashlib.sha256(output.cpu().numpy().tobytes()).hexdigest()[:8], + "max_ulp_to_golden": max_ulp_to_golden, + "avg_ulp_to_golden": avg_ulp_to_golden, + "num_nonequal": num_nonequal, + "num_nonequal_subnormal": num_nonequal_subnormal, + "num_total": output.shape[0], + "match_full": num_nonequal == 0, + "match_normal": num_nonequal_subnormal == 0, + "mismatch_sample": mismatch_sample, + } + + +def create_golden(args): + golden_inputs = { + dtype: {category: make_input(dtype, category) for category in CATEGORIES} + for dtype in DTYPES + } + golden_outputs = { + dtype: { + category: { + function: { + pass_type: MODES[args.mode]( + PASSES[pass_type](make_function(function, category)) + )(*golden_inputs[dtype][category]) + for pass_type in PASS_TYPES[category] + } + for function in CATEGORIES[category] + } + for category in CATEGORIES + } + for dtype in DTYPES + } + return golden_inputs, golden_outputs + + +def get_metadata(args, dtype, category, function, pass_type): + metadata = { + "gpu": args.gpu[0], + "pytorch_version": args.pytorch_version, + "cuda_version": args.cuda_version, + "pytorch_mode": args.mode, + "is_golden": args.create_golden, + "data_type": dtype, + "function": function, + "pass_type": pass_type, + "category": category, + } + identifier = hashlib.sha256(json.dumps(metadata).encode()).hexdigest()[:8] + metadata["identifier"] = identifier + return metadata + + +def run_test_case(args): + golden = torch.load(args.golden) + golden_inputs, golden_outputs = golden + metadata = get_metadata( + args, args.dtype, args.category, args.function, args.pass_type + ) + input = golden_inputs[args.dtype][args.category] + callable = make_function(args.function, args.category) + callable = PASSES[args.pass_type](callable) + callable = MODES[args.mode](callable) + output = callable(*input) + evaluation = evaluate_output( + input, + output, + golden_outputs[args.dtype][args.category][args.function][args.pass_type], + ) + data = json.dumps(metadata | evaluation) + log.info(data) + print(data) + + +test_concurrency = asyncio.Semaphore(16) + + +async def launch_test_case(args, dtype, category, function, pass_type): + async with test_concurrency: + env = os.environ | MODES[args.mode].env + metadata = get_metadata(args, dtype, category, function, pass_type) + identifier = metadata["identifier"] + env["TORCH_TRACE"] = f"/workspace/logs/{identifier}.trace" + Path(env["TORCH_TRACE"]).mkdir(parents=True, exist_ok=True) + env["TORCH_LOGS"] = "+all" + env["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1" + env["TRITON_ALWAYS_COMPILE"] = "1" + cmd_args = [ + "--test", + f"--gpu {args.gpu[0]}", + f"--pytorch-version {args.pytorch_version}", + f"--cuda-version {args.cuda_version}", + f"--mode {args.mode}", + f"--golden {args.golden}", + f"--dtype {dtype}", + f"--category {category}", + f"--function {function}", + f"--pass-type {pass_type}", + ] + if args.create_golden: + cmd_args.append("--create-golden") + return await run( + f"{sys.executable} {__file__} {' '.join(cmd_args)}", + env=env, + capture_stderr=False, + stderr=open(f"/workspace/logs/{identifier}.log", "w"), + ) + + +async def runner(args): + log.info("Runner called! %s", args) + + if args.create_golden: + golden = create_golden(args) + torch.save(golden, args.golden) + + results = await asyncio.gather( + *[ + launch_test_case( + args, + dtype, + category, + function, + pass_type, + ) + for dtype, category in itertools.product(DTYPES, CATEGORIES) + for function, pass_type in itertools.product( + CATEGORIES[category], PASS_TYPES[category] + ) + ] + ) + + with open( + f"/workspace/result_{args.gpu[0]}_{args.pytorch_version}_{args.cuda_version}.jsonl", + "a+", + ) as result: + for r in results: + result.write(r) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--launcher", + default=False, + action="store_true", + help="launch workers across machines", + ) + parser.add_argument( + "--worker", + default=False, + action="store_true", + help="launches jobs on a single machine", + ) + parser.add_argument( + "--runner", + default=False, + action="store_true", + help="runs tests in a specific environment", + ) + parser.add_argument( + "--test", + action="store_true", + help="runs a single test case", + ) + parser.add_argument("--gpu", nargs="+", type=str, default=["t4"]) + parser.add_argument("--html", type=str, default=None) + parser.add_argument("--golden", type=str, default=None) + parser.add_argument("--create-golden", action="store_true") + parser.add_argument("--dtype", type=str) + parser.add_argument("--category", type=str) + parser.add_argument("--function", type=str) + parser.add_argument("--pass-type", type=str) + parser.add_argument("--pytorch-version", type=str) + parser.add_argument("--cuda-version", type=str) + parser.add_argument("--mode") + args = parser.parse_args() + if args.runner: + assert torch is not None # noqa: S101 + torch.set_default_device("cuda") + (gpu,) = args.gpu + logging.basicConfig( + level=logging.INFO, + format=f"%(asctime)s - runner:{gpu}/{args.pytorch_version}/{args.cuda_version}/{args.mode} - %(message)s", + ) + asyncio.run(runner(args)) + if args.worker: + (gpu,) = args.gpu + logging.basicConfig( + level=logging.INFO, format=f"%(asctime)s - worker:{gpu} - %(message)s" + ) + asyncio.run(worker(args)) + if args.launcher: + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - launcher - %(message)s" + ) + try: + asyncio.run(launcher(args)) + except KeyboardInterrupt: + log.error("Cancelled by user") + if args.test: + test_id = ( + f"{args.gpu[0]}/{args.pytorch_version}/{args.cuda_version}" + f"/{args.mode}/{args.dtype}/{args.category}" + f"/{args.function}/{args.pass_type}" + ) + logging.basicConfig( + level=logging.INFO, + format=f"%(asctime)s - test:{test_id} - %(message)s", + ) + run_test_case(args) + if args.html: + format_results_to_html(args.html) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py b/benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py index abea0504603ac..197ec3f5c3b2a 100644 --- a/benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py +++ b/benchmarks/dynamo/microbenchmarks/bench_mm_fusion.py @@ -1,5 +1,3 @@ -# flake8: noqa: B902 - from prettytable import PrettyTable import torch diff --git a/benchmarks/dynamo/perf_cli.py b/benchmarks/dynamo/perf_cli.py new file mode 100644 index 0000000000000..f4c85a5e4a2a8 --- /dev/null +++ b/benchmarks/dynamo/perf_cli.py @@ -0,0 +1,1894 @@ +#!/usr/bin/env python3 +""" +CLI tool for launching, summarizing, and reproducing inductor perf regression runs. + +Usage: + python benchmarks/dynamo/perf_cli.py launch [--device a100 h100 ...] [--ref BRANCH] [--wait] + python benchmarks/dynamo/perf_cli.py summary [--top 5] [--config PATTERN] + python benchmarks/dynamo/perf_cli.py repro [--model MODEL] [--suite SUITE] [--print-only] + +Requires: gh CLI (authenticated), internet access to S3 (gha-artifacts bucket). +""" + +from __future__ import annotations + +import argparse +import csv +import io +import json +import os +import re +import shutil +import subprocess +import sys +import time +import urllib.request +import zipfile +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import datetime +from math import exp, log +from pathlib import Path + + +WORKFLOWS = { + "a100": { + "name": "inductor-A100-perf-nightly", + "id": 42513231, + }, + "a100-compare": { + "name": "inductor-A100-perf-compare", + "id": 50531883, + }, + "h100": { + "name": "inductor-perf-nightly-h100", + "id": 144201955, + }, + "b200": { + "name": "inductor-perf-b200", + "id": 173716622, + }, + "rocm-mi300": { + "name": "inductor-perf-nightly-rocm-mi300", + "id": 197925166, + }, + "rocm-mi355": { + "name": "inductor-perf-nightly-rocm-mi355", + "id": 197925165, + }, + "x86": { + "name": "inductor-perf-nightly-x86", + "id": 108782874, + }, + "x86-zen": { + "name": "inductor-perf-nightly-x86-zen", + "id": 167573808, + }, + "aarch64": { + "name": "inductor-perf-nightly-aarch64", + "id": 109196799, + }, + "macos": { + "name": "inductor-perf-nightly-macos", + "id": 117199085, + }, + "xpu": { + "name": "inductor-perf-nightly-xpu", + "id": 201149053, + }, +} + +DEVICE_CHOICES = sorted(k for k in WORKFLOWS if k != "a100-compare") + +S3_BUCKET = "gha-artifacts" +S3_URL = f"https://{S3_BUCKET}.s3.amazonaws.com" +REPO = "pytorch/pytorch" + +# Regex to parse test job names like: +# "cuda13.0-py3.10-gcc11-sm80 / test (inductor_huggingface_perf, 1, 5, linux.aws.a100)" +JOB_RE = re.compile( + r"test \((?P[^,]+),\s*(?P\d+),\s*(?P\d+),\s*(?P[^)]+)\)" +) + +PERF_CONFIGS = re.compile(r"inductor_(huggingface|timm|torchbench)_perf") + +SUITE_ALIASES = { + "hf": "huggingface", + "huggingface": "huggingface", + "timm": "timm_models", + "timm_models": "timm_models", + "tb": "torchbench", + "torchbench": "torchbench", +} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def gh(*args: str, json_output: bool = False) -> str | dict | list: + cmd = ["gh"] + list(args) + result = subprocess.run(cmd, capture_output=True, text=True) + if result.returncode != 0: + print(f"gh error: {result.stderr.strip()}", file=sys.stderr) + sys.exit(1) + if json_output: + return json.loads(result.stdout) + return result.stdout.strip() + + +def git(*args: str) -> str: + result = subprocess.run(["git"] + list(args), capture_output=True, text=True) + return result.stdout.strip() + + +def gmean(values: list[float]) -> float: + if not values: + return 0.0 + return exp(sum(log(v) for v in values if v > 0) / max(len(values), 1)) + + +@dataclass +class Metric: + name: str + field: str # attribute on ModelResult + unit: str # display suffix + higher_is_better: bool + aggregate: str # "gmean" or "mean" + + +METRICS = { + "speedup": Metric("speedup", "speedup", "x", True, "gmean"), + "compilation_latency": Metric( + "compilation latency", "compilation_latency", "s", False, "mean" + ), + "compression_ratio": Metric( + "memory compression", "compression_ratio", "x", True, "gmean" + ), + "abs_latency": Metric("absolute latency", "abs_latency", "ms", False, "mean"), +} + +METRIC_CHOICES = list(METRICS.keys()) + +# HUD uses 5% relative threshold for flagging regressions +RELATIVE_THRESHOLD = 0.05 + +WORKFLOW_NAME_TO_NIGHTLY_ID = { + v["name"]: v["id"] for k, v in WORKFLOWS.items() if k != "a100-compare" +} +# compare's baseline is A100 nightly +WORKFLOW_NAME_TO_NIGHTLY_ID["inductor-A100-perf-compare"] = WORKFLOWS["a100"]["id"] + + +def _short_config(config: str, device: str = "") -> str: + """Compact label: e.g. 'a100 cudagraphs huggingface training'.""" + c = config + backend = c + for s in ("_huggingface_", "_timm_models_", "_torchbench_"): + if s in c: + backend = c.split(s)[0] + break + backend = backend.removeprefix("inductor_") + mode = "training" if "training" in c else "inference" if "inference" in c else "" + suite = "" + for s in ("huggingface", "timm_models", "torchbench"): + if s in c: + suite = s + break + parts = [p for p in (device, backend, suite, mode) if p] + return " ".join(parts) + + +@dataclass +class ModelResult: + name: str + speedup: float + abs_latency: float = 0.0 + compilation_latency: float = 0.0 + compression_ratio: float = 0.0 + eager_peak_mem: float = 0.0 + dynamo_peak_mem: float = 0.0 + config: str = "" + device: str = "" + + @property + def short_config(self) -> str: + return _short_config(self.config, self.device) + + +@dataclass +class PerfData: + config: str # e.g. "inductor_with_cudagraphs_huggingface_amp_training_cuda" + models: list[ModelResult] = field(default_factory=list) + device: str = "" + + @property + def suite(self) -> str: + for s in ("huggingface", "timm_models", "torchbench"): + if s in self.config: + return s + return "unknown" + + @property + def mode(self) -> str: + if "training" in self.config: + return "training" + if "inference" in self.config: + return "inference" + return "unknown" + + @property + def dtype(self) -> str: + # Config format: {backend}_{suite}_{dtype}_{mode}_{device} + # e.g. inductor_with_cudagraphs_huggingface_amp_training_cuda + for s in ("_huggingface_", "_timm_models_", "_torchbench_"): + if s in self.config: + tail = self.config.split(s, 1)[1] + # tail is e.g. "amp_training_cuda" + parts = tail.split("_") + if parts: + return parts[0] + return "unknown" + + @property + def short_name(self) -> str: + name = self.config + for s in ("_huggingface_", "_timm_models_", "_torchbench_"): + if s in name: + name = name.split(s)[0] + break + return name + + @property + def runtime(self) -> str: + # Last token in config: e.g. "..._training_cuda" → "cuda" + parts = self.config.rsplit("_", 1) + if len(parts) == 2: + return parts[1] + return "unknown" + + @property + def qualified_config(self) -> str: + if self.device: + return f"{self.device}/{self.config}" + return self.config + + def gmean_speedup(self) -> float: + vals = [m.speedup for m in self.models if m.speedup > 0] + return gmean(vals) + + def aggregate_metric(self, metric: Metric) -> float: + vals = [ + getattr(m, metric.field) + for m in self.models + if getattr(m, metric.field) > 0 + ] + if not vals: + return 0.0 + if metric.aggregate == "gmean": + return gmean(vals) + return sum(vals) / len(vals) + + +# --------------------------------------------------------------------------- +# Artifact downloading +# --------------------------------------------------------------------------- + + +def get_run_jobs(run_id: int) -> list[dict]: + data = gh( + "run", + "view", + str(run_id), + "--repo", + REPO, + "--json", + "jobs", + json_output=True, + ) + return data["jobs"] + + +def get_perf_jobs(jobs: list[dict]) -> list[dict]: + perf_jobs = [] + for job in jobs: + m = JOB_RE.search(job["name"]) + if not m: + continue + config = m.group("config") + if not PERF_CONFIGS.match(config): + continue + if job.get("conclusion") != "success": + continue + perf_jobs.append( + { + "config": config, + "shard": m.group("shard"), + "num_shards": m.group("num_shards"), + "runner": m.group("runner"), + "job_id": job["databaseId"], + "name": job["name"], + } + ) + return perf_jobs + + +def s3_artifact_url(run_id: int, attempt: int, job: dict) -> str: + config = job["config"] + shard = job["shard"] + num_shards = job["num_shards"] + runner = job["runner"] + job_id = job["job_id"] + filename = f"test-reports-test-{config}-{shard}-{num_shards}-{runner}_{job_id}.zip" + return f"{S3_URL}/{REPO}/{run_id}/{attempt}/artifact/{filename}" + + +CACHE_DIR = Path.home() / ".cache" / "perf_cli" + + +def get_cache_dir(run_id: int, attempt: int) -> Path: + d = CACHE_DIR / f"{run_id}" / f"{attempt}" + d.mkdir(parents=True, exist_ok=True) + return d + + +def download_and_extract_csvs( + run_id: int, + jobs: list[dict], + attempt: int = 1, + no_cache: bool = False, +) -> list[tuple[str, str]]: + """Download artifacts and return list of (csv_filename, csv_content) pairs.""" + cache = get_cache_dir(run_id, attempt) + results = [] + fetched = 0 + + for job in jobs: + # Check cache first + cache_key = ( + f"{job['config']}-{job['shard']}-{job['num_shards']}-{job['job_id']}" + ) + cache_marker = cache / f"{cache_key}.done" + + if not no_cache and cache_marker.exists(): + # Read cached CSVs + for csv_file in cache.glob(f"{cache_key}__*.csv"): + csv_name = csv_file.name.split("__", 1)[1] + results.append((csv_name, csv_file.read_text())) + continue + + url = s3_artifact_url(run_id, attempt, job) + zip_path = cache / f"{cache_key}.zip" + try: + urllib.request.urlretrieve(url, str(zip_path)) + fetched += 1 + except urllib.error.HTTPError as e: + print( + f" warning: failed to download shard {job['config']} " + f"shard {job['shard']}: {e}", + file=sys.stderr, + ) + continue + + with zipfile.ZipFile(zip_path) as zf: + for name in zf.namelist(): + if name.endswith("_performance.csv"): + csv_name = os.path.basename(name) + with zf.open(name) as f: + content = f.read().decode("utf-8") + results.append((csv_name, content)) + # Write to cache + (cache / f"{cache_key}__{csv_name}").write_text(content) + + # Mark this shard as cached and remove the zip + cache_marker.touch() + zip_path.unlink(missing_ok=True) + + if fetched > 0: + print(f" downloaded {fetched} shards (cached at {cache})") + elif results: + print(f" using cached data from {cache}") + + return results + + +def parse_csvs(csv_pairs: list[tuple[str, str]], device: str = "") -> list[PerfData]: + grouped: dict[str, list[ModelResult]] = defaultdict(list) + + for csv_name, content in csv_pairs: + # csv_name like: inductor_with_cudagraphs_huggingface_amp_training_cuda_performance.csv + config = csv_name.replace("_performance.csv", "") + reader = csv.DictReader(io.StringIO(content)) + for row in reader: + try: + speedup = float(row.get("speedup", 0)) + except (ValueError, TypeError): + continue + grouped[config].append( + ModelResult( + name=row.get("name", "?"), + speedup=speedup, + abs_latency=float(row.get("abs_latency", 0) or 0), + compilation_latency=float(row.get("compilation_latency", 0) or 0), + compression_ratio=float(row.get("compression_ratio", 0) or 0), + eager_peak_mem=float(row.get("eager_peak_mem", 0) or 0), + dynamo_peak_mem=float(row.get("dynamo_peak_mem", 0) or 0), + config=config, + device=device, + ) + ) + + return [ + PerfData(config=k, models=v, device=device) for k, v in sorted(grouped.items()) + ] + + +# --------------------------------------------------------------------------- +# S-curve rendering +# --------------------------------------------------------------------------- + + +def subsample(items: list, max_rows: int) -> list: + """Evenly subsample a sorted list, always keeping first and last.""" + n = len(items) + if n <= max_rows: + return items + # Always include first and last; evenly space the rest + indices = {0, n - 1} + for i in range(1, max_rows - 1): + indices.add(round(i * (n - 1) / (max_rows - 1))) + return [items[i] for i in sorted(indices)] + + +def render_scurve( + perf: PerfData, + metric: Metric, + top_n: int = 5, + term_width: int | None = None, + term_height: int | None = None, +): + if not perf.models: + return + + if term_width is None or term_height is None: + sz = shutil.get_terminal_size((100, 50)) + term_width = term_width or sz.columns + term_height = term_height or sz.lines + + live = [m for m in perf.models if getattr(m, metric.field) > 0] + if not live: + return + + sorted_models = sorted(live, key=lambda m: getattr(m, metric.field)) + agg = perf.aggregate_metric(metric) + n = len(sorted_models) + + # Reserve lines for header (2) + axis label (1) + padding (2) + max_rows = max(term_height - 5, 15) + display = subsample(sorted_models, max_rows) + skipped = n - len(display) + + agg_label = metric.aggregate + header = f"{perf.config} ({n} data points, {agg_label}={agg:.2f}{metric.unit})" + if skipped > 0: + header += f" [showing {len(display)}/{n}]" + print(f"\n {header}") + print(f" {'─' * min(len(header), term_width - 4)}") + + def fmt_val(v: float) -> str: + if metric.unit == "s" or metric.unit == "ms": + return f"{v:7.1f}{metric.unit}" + return f"{v:5.2f}{metric.unit}" + + # Layout: " {name:8} {dots}" + sample_val = fmt_val(display[0] and getattr(display[0], metric.field)) + max_name = min(max(len(m.name) for m in display), 30) + val_width = len(sample_val) + prefix_len = 2 + max_name + 2 + val_width + 2 + plot_width = max(term_width - prefix_len - 1, 20) + + def get_val(m): + return getattr(m, metric.field) + + min_val = get_val(sorted_models[0]) + p95_idx = max(0, int(n * 0.95) - 1) + p95_val = get_val(sorted_models[p95_idx]) + + # For ratio metrics (speedup, compression_ratio), anchor at 1.0 + # For absolute metrics (latency), anchor at 0 + if metric.unit == "x": + plot_min = min(min_val, 0.5) + plot_max = max(p95_val * 1.1, 1.5) + marker_val = 1.0 + marker_label = "1.0x" + else: + plot_min = 0 + plot_max = p95_val * 1.1 + marker_val = None + marker_label = None + + span = plot_max - plot_min + if span == 0: + span = 1 + + def val_to_col(v: float) -> int: + return max( + 0, min(plot_width - 1, int((v - plot_min) / span * (plot_width - 1))) + ) + + marker_col = val_to_col(marker_val) if marker_val is not None else None + + for m in display: + name = m.name[:max_name].ljust(max_name) + v = get_val(m) + col = val_to_col(v) + bar = [" "] * plot_width + if marker_col is not None: + bar[marker_col] = "|" + for i in range(col + 1): + if marker_col is not None and i == marker_col: + bar[i] = "|" + else: + bar[i] = "·" + print(f" {name} {fmt_val(v)} {''.join(bar)}") + + pad = " " * prefix_len + if marker_label and marker_col is not None: + print(f"{pad}{' ' * marker_col}{marker_label}") + else: + print() + + +def print_worst_offenders(perf: PerfData, metric: Metric, top_n: int = 5): + def get_val(m): + return getattr(m, metric.field) + + live = [m for m in perf.models if get_val(m) > 0] + if not live: + return + # "worst" depends on metric direction + if metric.higher_is_better: + worst = sorted(live, key=get_val)[:top_n] + else: + worst = sorted(live, key=get_val, reverse=True)[:top_n] + print(f"\n Worst offenders ({metric.name}):") + for i, m in enumerate(worst, 1): + v = get_val(m) + parts = [f"{v:.3f}{metric.unit}"] + if m.config: + parts.append(m.short_config) + detail = " ".join(parts) + print(f" {i}. {m.name:<30} {detail}") + + +# --------------------------------------------------------------------------- +# Subcommands +# --------------------------------------------------------------------------- + + +def build_dispatch_inputs(args) -> list[str]: + """Build -f flags for workflow dispatch inputs from CLI args.""" + flags = [] + bool_inputs = [ + "training", + "inference", + "default", + "dynamic", + "cppwrapper", + "cudagraphs", + "freezing_cudagraphs", + "aotinductor", + "maxautotune", + ] + for name in bool_inputs: + val = getattr(args, name, None) + if val is not None: + flags.extend(["-f", f"{name}={'true' if val else 'false'}"]) + if args.benchmark_configs: + flags.extend(["-f", f"benchmark_configs={args.benchmark_configs}"]) + return flags + + +def dispatch_one(device: str, ref: str, extra_flags: list[str]) -> int | None: + wf = WORKFLOWS[device] + print(f"\nLaunching {wf['name']} on ref: {ref}") + + dispatch_args = [ + "workflow", + "run", + str(wf["id"]), + "--repo", + REPO, + "--ref", + ref, + ] + extra_flags + + gh(*dispatch_args) + print("Dispatched. Waiting a few seconds for the run to appear...") + time.sleep(5) + + runs = gh( + "run", + "list", + "--repo", + REPO, + "--workflow", + str(wf["id"]), + "--branch", + ref, + "--limit", + "1", + "--json", + "databaseId,status,url,createdAt", + json_output=True, + ) + if not runs: + print("Could not find the dispatched run. Check the Actions tab manually.") + return None + + run = runs[0] + run_id = run["databaseId"] + url = f"https://github.com/{REPO}/actions/runs/{run_id}" + print(f"Run ID: {run_id}") + print(f"URL: {url}") + print(f"Status: {run['status']}") + return run_id + + +def wait_for_runs(pending: dict[str, int]) -> dict[str, int]: + """Poll all runs until they complete. Returns dict of successful runs.""" + print(f"\nWaiting for {len(pending)} run(s)...", flush=True) + remaining = dict(pending) + succeeded: dict[str, int] = {} + while remaining: + time.sleep(30) + done = [] + for device, run_id in remaining.items(): + data = gh( + "run", + "view", + str(run_id), + "--repo", + REPO, + "--json", + "status,conclusion", + json_output=True, + ) + status = data.get("status", "unknown") + if status == "completed": + conclusion = data.get("conclusion", "unknown") + print(f" {device} (run {run_id}): {conclusion}") + if conclusion == "success": + succeeded[device] = run_id + done.append(device) + for d in done: + del remaining[d] + if remaining: + ts = datetime.now().strftime("%H:%M:%S") + print(f" [{ts}] still waiting: {', '.join(remaining)}...", flush=True) + return succeeded + + +def cmd_launch(args): + ref = args.ref or git("rev-parse", "--abbrev-ref", "HEAD") + if ref == "HEAD": + ref = git("rev-parse", "HEAD") + + extra_flags = build_dispatch_inputs(args) + launched: dict[str, int] = {} + for device in args.device: + if device not in WORKFLOWS: + print(f"Unknown device: {device}", file=sys.stderr) + sys.exit(1) + run_id = dispatch_one(device, ref, extra_flags) + if run_id: + launched[device] = run_id + + if (args.wait or args.wait_and_summarize) and launched: + succeeded = wait_for_runs(launched) + if args.wait_and_summarize and succeeded: + # Use the first device's run ID as the positional arg; pass all + # device→run_id pairs via _run_ids so cmd_summary skips resolution. + first_run = next(iter(succeeded.values())) + summary_args = argparse.Namespace( + run_id=str(first_run), + device=list(succeeded.keys()), + _run_ids=succeeded, + baseline="latest", + metric="speedup", + top=5, + config=None, + suite=None, + mode=None, + group_by=None, + attempt=1, + no_cache=False, + ) + print(f"\n{'=' * 70}") + print("Summary") + print(f"{'=' * 70}") + cmd_summary(summary_args) + + +def filter_perf(all_perf: list[PerfData], args) -> list[PerfData]: + result = all_perf + if getattr(args, "config", None): + pattern = re.compile(args.config, re.IGNORECASE) + result = [p for p in result if pattern.search(p.config)] + if getattr(args, "suite", None): + suite = SUITE_ALIASES.get(args.suite, args.suite) + result = [p for p in result if p.suite == suite] + if getattr(args, "mode", None): + result = [p for p in result if p.mode == args.mode] + if getattr(args, "backend", None): + pattern = re.compile(args.backend, re.IGNORECASE) + result = [p for p in result if pattern.search(p.short_name)] + if getattr(args, "dtype", None): + result = [p for p in result if p.dtype == args.dtype] + if getattr(args, "runtime", None): + result = [p for p in result if p.runtime == args.runtime] + return result + + +GROUP_KEY_FNS: dict[str, callable] = { + "config": lambda p: p.qualified_config, + "suite": lambda p: p.suite, + "mode": lambda p: p.mode, + "backend": lambda p: p.short_name, + "device": lambda p: p.device or "unknown", + "dtype": lambda p: p.dtype, + "runtime": lambda p: p.runtime, +} + +GROUP_CHOICES = sorted(GROUP_KEY_FNS.keys()) + + +def group_perf(all_perf: list[PerfData], group_by: list[str] | None) -> list[PerfData]: + if not group_by: + all_models = [] + for p in all_perf: + all_models.extend(p.models) + return [PerfData(config="all", models=all_models)] + + fns = [] + for key in group_by: + fn = GROUP_KEY_FNS.get(key) + if fn is None: + print(f"Unknown group-by: {key}", file=sys.stderr) + sys.exit(1) + fns.append(fn) + + def composite_key(p: PerfData) -> str: + return " / ".join(fn(p) for fn in fns) + + groups: dict[str, list[ModelResult]] = defaultdict(list) + for p in all_perf: + groups[composite_key(p)].extend(p.models) + + return [PerfData(config=k, models=v) for k, v in sorted(groups.items())] + + +# --------------------------------------------------------------------------- +# Baseline comparison +# --------------------------------------------------------------------------- + + +NIGHTLY_WORKFLOW_IDS = {k: v["id"] for k, v in WORKFLOWS.items() if k != "a100-compare"} + + +def _find_latest_run( + branch: str, + device: str, +) -> dict | None: + """Find the latest successful perf nightly run for a branch + device. + + Returns {databaseId, createdAt, headSha} or None. + """ + wf_id = NIGHTLY_WORKFLOW_IDS.get(device) + if wf_id is None: + return None + + runs = gh( + "run", + "list", + "--repo", + REPO, + "--workflow", + str(wf_id), + "--branch", + branch, + "--status", + "success", + "--limit", + "1", + "--json", + "databaseId,createdAt,headSha", + json_output=True, + ) + if not runs: + return None + return runs[0] + + +def resolve_run(branch: str, device: str) -> int: + """Find the latest successful perf nightly run for a branch + device.""" + run = _find_latest_run(branch, device) + if not run: + print( + f"No successful perf run found for branch '{branch}' on {device}.", + file=sys.stderr, + ) + sys.exit(1) + run_id = run["databaseId"] + created = run["createdAt"][:10] + print(f"Resolved '{branch}' → run {run_id} ({device}, {created})") + return run_id + + +def discover_runs(branch: str) -> dict[str, int]: + """Auto-discover all devices with successful runs for a branch. + + Finds the latest commit SHA that has runs, then returns all runs matching + that commit. + """ + candidates: list[tuple[str, dict]] = [] + for device in DEVICE_CHOICES: + run = _find_latest_run(branch, device) + if run: + candidates.append((device, run)) + + if not candidates: + print( + f"No successful perf runs found for branch '{branch}' on any device.", + file=sys.stderr, + ) + sys.exit(1) + + # Pick the most recent commit (by createdAt) and collect all runs on it + candidates.sort(key=lambda x: x[1]["createdAt"], reverse=True) + target_sha = candidates[0][1]["headSha"] + + result = {} + for device, run in candidates: + if run["headSha"] == target_sha: + run_id = run["databaseId"] + created = run["createdAt"][:10] + print( + f"Discovered '{branch}' → run {run_id} ({device}, {created}, {target_sha[:10]})" + ) + result[device] = run_id + + return result + + +def resolve_runs(branch: str, devices: list[str]) -> dict[str, int]: + """Resolve the latest successful run for each device. Returns {device: run_id}.""" + result = {} + for device in devices: + result[device] = resolve_run(branch, device) + return result + + +def resolve_baseline(head_run_id: int) -> int: + """Find the latest successful nightly on main, skipping the head run itself.""" + run_data = gh( + "run", + "view", + str(head_run_id), + "--repo", + REPO, + "--json", + "workflowName", + json_output=True, + ) + wf_name = run_data.get("workflowName", "") + nightly_id = WORKFLOW_NAME_TO_NIGHTLY_ID.get(wf_name) + if nightly_id is None: + print( + f"Don't know which nightly corresponds to workflow '{wf_name}'", + file=sys.stderr, + ) + sys.exit(1) + + runs = gh( + "run", + "list", + "--repo", + REPO, + "--workflow", + str(nightly_id), + "--branch", + "main", + "--status", + "success", + "--limit", + "5", + "--json", + "databaseId,createdAt,headBranch", + json_output=True, + ) + # Skip the head run itself to avoid comparing a run to itself + for run in runs: + if run["databaseId"] != head_run_id: + baseline_id = run["databaseId"] + created = run["createdAt"][:10] + print(f"Baseline: run {baseline_id} (main, {created})") + return baseline_id + + print("No suitable baseline nightly found on main.", file=sys.stderr) + sys.exit(1) + + +def device_for_workflow(workflow_name: str) -> str: + """Reverse-lookup device key from workflow name.""" + for k, v in WORKFLOWS.items(): + if v["name"] == workflow_name: + return k + return "" + + +@dataclass +class RunMeta: + run_id: int + head_sha: str + head_branch: str + workflow_name: str + created_at: str + event: str + + @property + def short_sha(self) -> str: + return self.head_sha[:10] + + @property + def date(self) -> str: + return self.created_at[:10] + + +def fetch_run_meta(run_id: int) -> RunMeta: + data = gh( + "api", + f"repos/{REPO}/actions/runs/{run_id}", + "-q", + "{headSha: .head_sha, headBranch: .head_branch, workflowName: .name, createdAt: .created_at, event: .event}", + json_output=True, + ) + return RunMeta( + run_id=run_id, + head_sha=data.get("headSha", "unknown"), + head_branch=data.get("headBranch", "unknown"), + workflow_name=data.get("workflowName", "unknown"), + created_at=data.get("createdAt", "unknown"), + event=data.get("event", "unknown"), + ) + + +def print_run_header( + label: str, + metas: list[RunMeta], + configs: list[str] | None = None, +): + print(f"\n {label}") + print(f" {'─' * len(label)}") + if len(metas) == 1: + m = metas[0] + print(f" Run: {m.run_id} ({m.workflow_name})") + print(f" Commit: {m.short_sha} ({m.head_branch}, {m.date})") + else: + # Show commit from first (should all match for multi-device) + print( + f" Commit: {metas[0].short_sha} ({metas[0].head_branch}, {metas[0].date})" + ) + print(" Runs:") + for m in metas: + print(f" {m.run_id} ({m.workflow_name})") + if configs: + print(f" Configs: {len(configs)} — {', '.join(sorted(configs))}") + + +def fetch_run_perf( + run_id: int, + attempt: int, + no_cache: bool, + device: str = "", + allow_empty: bool = False, +) -> list[PerfData]: + """Fetch and parse perf data for a run.""" + jobs = get_run_jobs(run_id) + perf_jobs = get_perf_jobs(jobs) + if not perf_jobs: + if allow_empty: + print(f" no perf jobs in run {run_id}, skipping", file=sys.stderr) + return [] + print(f"No successful perf jobs in run {run_id}.", file=sys.stderr) + sys.exit(1) + csv_pairs = download_and_extract_csvs(run_id, perf_jobs, attempt, no_cache=no_cache) + if not csv_pairs: + if allow_empty: + print(f" no CSVs in run {run_id}, skipping", file=sys.stderr) + return [] + print(f"No performance CSVs in run {run_id}.", file=sys.stderr) + sys.exit(1) + return parse_csvs(csv_pairs, device=device) + + +@dataclass +class ModelDelta: + name: str + base_val: float + head_val: float + config: str = "" + device: str = "" + + @property + def delta_pct(self) -> float: + if self.base_val == 0: + return 0.0 + return (self.head_val - self.base_val) / self.base_val * 100 + + @property + def short_config(self) -> str: + return _short_config(self.config, self.device) + + +@dataclass +class ConfigAgg: + base_agg: float + base_count: int + head_agg: float + head_count: int + paired_ratio: float # gmean(head_val / base_val) over paired models + paired_count: int + + +def compute_deltas( + head_perf: list[PerfData], base_perf: list[PerfData], metric: Metric +) -> tuple[list[ModelDelta], dict[str, ConfigAgg]]: + """Join head and base on (device, config, model_name) and compute deltas. + + Returns (per_model_deltas, per_config_aggregates). + per_config_aggregates maps qualified_config -> ConfigAgg. + """ + # Build base lookup: (device, config, model_name) -> metric value + base_lookup: dict[tuple[str, str, str], float] = {} + for perf in base_perf: + for m in perf.models: + v = getattr(m, metric.field) + if v > 0: + base_lookup[(perf.device, perf.config, m.name)] = v + + deltas = [] + for perf in head_perf: + for m in perf.models: + head_val = getattr(m, metric.field) + if head_val <= 0: + continue + key = (perf.device, perf.config, m.name) + if key not in base_lookup: + continue + base_val = base_lookup[key] + deltas.append( + ModelDelta( + name=m.name, + base_val=base_val, + head_val=head_val, + config=perf.config, + device=perf.device, + ) + ) + + # Group deltas by qualified_config for paired aggregates + deltas_by_qconfig: dict[str, list[ModelDelta]] = defaultdict(list) + for d in deltas: + qc = f"{d.device}/{d.config}" if d.device else d.config + deltas_by_qconfig[qc].append(d) + + # Per-config aggregates (keyed by qualified_config for display) + config_aggs: dict[str, ConfigAgg] = {} + base_by_qconfig: dict[str, PerfData] = {p.qualified_config: p for p in base_perf} + for perf in head_perf: + qc = perf.qualified_config + if qc not in base_by_qconfig: + continue + base_perf_data = base_by_qconfig[qc] + head_agg = perf.aggregate_metric(metric) + base_agg = base_perf_data.aggregate_metric(metric) + head_count = len([m for m in perf.models if getattr(m, metric.field) > 0]) + base_count = len( + [m for m in base_perf_data.models if getattr(m, metric.field) > 0] + ) + + paired = deltas_by_qconfig.get(qc, []) + ratios = [d.head_val / d.base_val for d in paired if d.base_val > 0] + config_aggs[qc] = ConfigAgg( + base_agg=base_agg, + base_count=base_count, + head_agg=head_agg, + head_count=head_count, + paired_ratio=gmean(ratios) if ratios else 0.0, + paired_count=len(ratios), + ) + + return deltas, config_aggs + + +def print_comparison_table(config_aggs: dict[str, ConfigAgg], metric: Metric): + u = metric.unit + print(f"\n{'Config':<55} {'base':>16} {'new':>16} {'head/base':>18}") + print("─" * 108) + for config in sorted(config_aggs): + agg = config_aggs[config] + flag = "" + if agg.paired_ratio > 0: + delta_pct = (agg.paired_ratio - 1.0) * 100 + if abs(delta_pct) > RELATIVE_THRESHOLD * 100: + if metric.higher_is_better: + flag = " !!" if delta_pct < 0 else " ++" + else: + flag = " !!" if delta_pct > 0 else " ++" + print( + f" {config:<53} " + f"{agg.base_agg:>5.2f}{u} (n={agg.base_count}) " + f"{agg.head_agg:>5.2f}{u} (n={agg.head_count}) " + f"{agg.paired_ratio:>5.3f}x (n={agg.paired_count}){flag}" + ) + + +def print_regressions(deltas: list[ModelDelta], metric: Metric, top_n: int): + # For higher_is_better metrics, regression = negative delta + # For lower_is_better metrics, regression = positive delta + if metric.higher_is_better: + bad = [d for d in deltas if d.delta_pct < -RELATIVE_THRESHOLD * 100] + bad.sort(key=lambda d: d.delta_pct) + else: + bad = [d for d in deltas if d.delta_pct > RELATIVE_THRESHOLD * 100] + bad.sort(key=lambda d: d.delta_pct, reverse=True) + + if not bad: + print(f"\n No regressions (>{RELATIVE_THRESHOLD * 100:.0f}% change).") + return + + print(f"\n Regressions ({len(bad)} models, showing top {min(top_n, len(bad))}):") + for i, d in enumerate(bad[:top_n], 1): + print( + f" {i}. {d.name:<30} " + f"{d.base_val:.2f}{metric.unit} → {d.head_val:.2f}{metric.unit} " + f"{d.delta_pct:>+6.1f}% {d.short_config}" + ) + + +def render_delta_scurve( + deltas: list[ModelDelta], + metric: Metric, + term_width: int | None = None, + term_height: int | None = None, +): + if not deltas: + return + + if term_width is None or term_height is None: + sz = shutil.get_terminal_size((100, 50)) + term_width = term_width or sz.columns + term_height = term_height or sz.lines + + sorted_deltas = sorted(deltas, key=lambda d: d.delta_pct) + n = len(sorted_deltas) + max_rows = max(term_height - 5, 15) + display = subsample(sorted_deltas, max_rows) + skipped = n - len(display) + + header = f"Delta S-curve ({n} models)" + if skipped > 0: + header += f" [showing {len(display)}/{n}]" + print(f"\n {header}") + print(f" {'─' * min(len(header), term_width - 4)}") + + max_name = min(max(len(d.name) for d in display), 28) + # " name +12.3% {bar}" + prefix_len = 2 + max_name + 2 + 7 + 2 + plot_width = max(term_width - prefix_len - 1, 20) + + # Range: cap at p5/p95 to avoid outlier squishing + p5_idx = max(0, int(n * 0.05)) + p95_idx = min(n - 1, int(n * 0.95)) + range_lo = min(sorted_deltas[p5_idx].delta_pct, -10) + range_hi = max(sorted_deltas[p95_idx].delta_pct, 10) + # Ensure symmetric-ish around 0 + abs_max = max(abs(range_lo), abs(range_hi)) + range_lo = -abs_max + range_hi = abs_max + span = range_hi - range_lo + if span == 0: + span = 1 + + def pct_to_col(pct: float) -> int: + return max( + 0, min(plot_width - 1, int((pct - range_lo) / span * (plot_width - 1))) + ) + + zero_col = pct_to_col(0) + + for d in display: + name = d.name[:max_name].ljust(max_name) + col = pct_to_col(d.delta_pct) + bar = [" "] * plot_width + bar[zero_col] = "|" + if col <= zero_col: + for i in range(col, zero_col): + bar[i] = "·" + bar[zero_col] = "|" + else: + bar[zero_col] = "|" + for i in range(zero_col + 1, col + 1): + bar[i] = "·" + print(f" {name} {d.delta_pct:>+6.1f}% {''.join(bar)}") + + pad = " " * prefix_len + print(f"{pad}{' ' * zero_col}0%") + + +def print_summary_table(all_perf: list[PerfData], metric: Metric): + agg_label = metric.aggregate + print(f"\n{'Config':<65} {'Models':>6} {agg_label:>10}") + print("─" * 85) + for perf in all_perf: + n = len([m for m in perf.models if getattr(m, metric.field) > 0]) + agg = perf.aggregate_metric(metric) + print(f" {perf.qualified_config:<63} {n:>6} {agg:>8.2f}{metric.unit}") + + +def _resolve_head_runs(args) -> dict[str, int]: + """Parse run_id arg into {device: run_id} mapping. + + run_id can be: + - A single numeric run ID (device inferred from workflow) + - A branch name (resolved across all --device values, or auto-discovered) + Pre-resolved IDs can be passed via args._run_ids. + """ + # Pre-resolved (from --wait-and-summarize) + if hasattr(args, "_run_ids") and args._run_ids: + return args._run_ids + + raw = args.run_id + try: + run_id = int(raw) + # Single run ID — infer device from the workflow + meta = fetch_run_meta(run_id) + device = device_for_workflow(meta.workflow_name) + return {device: run_id} + except ValueError: + # Branch name + if args.device: + return resolve_runs(raw, args.device) + return discover_runs(raw) + + +def cmd_summary(args): + attempt = args.attempt + metric = METRICS[args.metric] + + head_run_ids = _resolve_head_runs(args) + + # Fetch head runs + auto_discovered = not hasattr(args, "_run_ids") and not args.device + head_metas: list[RunMeta] = [] + head_perf: list[PerfData] = [] + for device, run_id in list(head_run_ids.items()): + print(f"Fetching head run {run_id} ({device})...") + perf = fetch_run_perf( + run_id, + attempt, + args.no_cache, + device=device, + allow_empty=auto_discovered, + ) + if not perf: + del head_run_ids[device] + continue + head_metas.append(fetch_run_meta(run_id)) + head_perf.extend(perf) + + head_perf = filter_perf(head_perf, args) + if not head_perf: + print("No configs matched filters.") + sys.exit(1) + + head_configs = [p.qualified_config for p in head_perf] + + # Baseline comparison mode + if args.baseline and args.baseline.lower() != "none": + base_metas: list[RunMeta] = [] + base_perf: list[PerfData] = [] + + for device, head_run_id in head_run_ids.items(): + if args.baseline == "latest": + baseline_id = resolve_baseline(head_run_id) + else: + try: + baseline_id = int(args.baseline) + except ValueError: + baseline_id = resolve_run(args.baseline, device) + + print(f"Fetching baseline run {baseline_id} ({device})...") + base_data = fetch_run_perf( + baseline_id, + attempt, + args.no_cache, + device=device, + allow_empty=auto_discovered, + ) + if not base_data: + continue + base_metas.append(fetch_run_meta(baseline_id)) + base_perf.extend(base_data) + + base_perf = filter_perf(base_perf, args) + if not base_perf: + print("No baseline configs matched filters.") + sys.exit(1) + + print_run_header("HEAD", head_metas, head_configs) + print_run_header("BASE", base_metas, [p.qualified_config for p in base_perf]) + print() + + deltas, config_aggs = compute_deltas(head_perf, base_perf, metric) + if not deltas: + print("No matching models between head and baseline.") + sys.exit(1) + + print_comparison_table(config_aggs, metric) + print_regressions(deltas, metric, args.top) + render_delta_scurve(deltas, metric) + return + + # Absolute mode (no baseline) + print_run_header("RUN", head_metas, head_configs) + print() + + print_summary_table(head_perf, metric) + grouped = group_perf(head_perf, args.group_by) + for perf in grouped: + print_worst_offenders(perf, metric, args.top) + render_scurve(perf, metric, args.top) + + +CONFIG_RE = re.compile( + r"(?Pinductor_[a-z_]+?)_" + r"(?Phuggingface|timm_models|torchbench)_" + r"(?P\w+)_" + r"(?Ptraining|inference)_" + r"(?P\w+)" +) + + +def config_to_command( + config: str, + suite: str, + model: str | None = None, +) -> str | None: + """Turn a config name into a runnable benchmark command.""" + m = CONFIG_RE.match(config) + if not m: + return None + + backend_variant = m.group("backend") + dtype = m.group("dtype") + mode = m.group("mode") + runtime = m.group("device") + + # Runtime → --device flag (strip platform suffix like _x86_zen) + device_flag = runtime.split("_")[0] # "cpu_x86_zen" → "cpu" + + cmd_parts = [ + "python", + f"benchmarks/dynamo/{suite}.py", + f"--{mode}", + f"--{dtype}", + "--backend", + "inductor", + "--device", + device_flag, + ] + + if "no_cudagraphs" in backend_variant: + cmd_parts.append("--disable-cudagraphs") + if "dynamic" in backend_variant: + cmd_parts.extend(["--dynamic-shapes", "--dynamic-batch-only"]) + if "cpp_wrapper" in backend_variant: + cmd_parts.insert(0, "TORCHINDUCTOR_CPP_WRAPPER=1") + cmd_parts.append("--disable-cudagraphs") + if "freezing" in backend_variant: + cmd_parts.append("--freezing") + if "max_autotune" in backend_variant: + cmd_parts.insert(0, "TORCHINDUCTOR_MAX_AUTOTUNE=1") + if "aot_inductor" in backend_variant: + cmd_parts.append("--export-aot-inductor") + cmd_parts.append("--disable-cudagraphs") + + cmd_parts.extend(["--performance", "--cold-start-latency"]) + + if model: + cmd_parts.extend(["--only", model]) + + cmd_parts.extend(["--output", f"{config}_performance.csv"]) + + return " ".join(cmd_parts) + + +def cmd_repro(args): + run_ids = _resolve_head_runs(args) + + # Fetch perf data to discover configs + auto_discovered = not hasattr(args, "_run_ids") and not args.device + all_perf: list[PerfData] = [] + metas: list[RunMeta] = [] + for device, run_id in list(run_ids.items()): + print(f"Fetching run {run_id} ({device})...") + metas.append(fetch_run_meta(run_id)) + perf = fetch_run_perf( + run_id, + args.attempt, + no_cache=False, + device=device, + allow_empty=auto_discovered, + ) + all_perf.extend(perf) + + all_perf = filter_perf(all_perf, args) + if not all_perf: + print("No configs matched filters.") + sys.exit(1) + + print_run_header("REPRO", metas) + + configs_seen: dict[str, str] = {} # config_name → suite + for perf in all_perf: + configs_seen[perf.config] = perf.suite + + count = 0 + commands: list[str] = [] + for config in sorted(configs_seen): + suite = configs_seen[config] + cmd = config_to_command(config, suite, args.model) + if not cmd: + continue + commands.append(f"# {config}\n{cmd}") + count += 1 + + print(f"\nReproducible commands ({count} configs):\n") + for cmd in commands: + print(f"{cmd}\n") + + +PIN_DIR = Path(".ci/docker/ci_commit_pins") + + +def read_pin(name: str) -> str: + """Read a pinned commit or requirements file.""" + path = PIN_DIR / name + if not path.exists(): + return f"<{name} not found>" + return path.read_text().strip() + + +def cmd_prepare_repro(args): + suites = set() + if args.suite: + suite = SUITE_ALIASES.get(args.suite, args.suite) + suites = {suite} + else: + suites = {"huggingface", "timm_models", "torchbench"} + + torchbench_pin = read_pin("torchbench.txt") + timm_pin = read_pin("timm.txt") + hf_reqs = read_pin("huggingface-requirements.txt") + + print("# Setup commands for inductor perf benchmark suites") + print("# These mirror what CI does in the inductor-benchmarks Docker image.") + print("#") + print("# Pinned versions (commits, package versions) are read live from") + print("# .ci/docker/ci_commit_pins/") + print("# Install steps are based on:") + print("# .ci/docker/common/install_inductor_benchmark_deps.sh (build-time)") + print("# .ci/pytorch/test.sh (runtime)") + print("# If the setup process changes, check those files.") + print() + + if "huggingface" in suites: + print("# ── HuggingFace ──") + for line in hf_reqs.splitlines(): + line = line.strip() + if line and not line.startswith("#"): + print(f"pip install {line}") + print() + + if "timm_models" in suites: + print("# ── Timm ──") + print( + f"pip install git+https://github.com/huggingface/pytorch-image-models@{timm_pin}" + ) + print() + + if "torchbench" in suites: + print("# ── TorchBench ──") + print("git clone https://github.com/pytorch/benchmark torchbench") + print(f"cd torchbench && git checkout {torchbench_pin}") + print("python install.py --continue_on_fail") + print("cd ..") + print() + print("# Set PYTHONPATH so benchmark scripts find torchbench") + print("export PYTHONPATH=$(pwd)/torchbench") + print() + + print("# ── Runtime dependencies ──") + print("pip install torchvision torchaudio") + if "torchbench" in suites: + print("pip install opencv-python==4.8.0.74") + print() + + print("# ── Environment variables ──") + print("export TORCHINDUCTOR_FX_GRAPH_CACHE=True") + print("export TORCHINDUCTOR_AUTOGRAD_CACHE=True") + print() + + if not args.no_repro: + # Also print repro commands if we have a run_id + if args.run_id: + # Reuse the repro logic + repro_args = argparse.Namespace( + run_id=args.run_id, + device=args.device, + model=args.model, + suite=args.suite, + mode=args.mode, + backend=None, + dtype=None, + runtime=None, + attempt=args.attempt, + ) + print("# ── Benchmark commands ──") + cmd_repro(repro_args) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description="CLI tool for inductor perf regression runs", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Launch a perf run on your current branch (A100) + python benchmarks/dynamo/perf_cli.py launch + + # Launch on both A100 and H100 + python benchmarks/dynamo/perf_cli.py launch --device a100 h100 + + # Launch on H100 and wait for completion + python benchmarks/dynamo/perf_cli.py launch --device h100 --wait + + # Launch with only inference, no training + python benchmarks/dynamo/perf_cli.py launch --no-training --inference + + # Launch with dynamic shapes enabled + python benchmarks/dynamo/perf_cli.py launch --dynamic + + # Launch on ROCm MI300 + python benchmarks/dynamo/perf_cli.py launch --device rocm-mi300 + + # Check your branch against latest main nightly + python benchmarks/dynamo/perf_cli.py summary my-feature-branch + + # Same but on H100 + python benchmarks/dynamo/perf_cli.py summary my-feature-branch --device h100 + + # Compare across A100 and H100 in one summary + python benchmarks/dynamo/perf_cli.py summary my-feature-branch --device a100 h100 + + # Use a specific run ID instead of branch name + python benchmarks/dynamo/perf_cli.py summary 22842783236 + + # Compare against a specific baseline (run ID or branch) + python benchmarks/dynamo/perf_cli.py summary 22842783236 --baseline 22816292132 + + # Absolute metrics only (no comparison) + python benchmarks/dynamo/perf_cli.py summary 22842783236 --baseline none + + # Filter to cudagraphs training + python benchmarks/dynamo/perf_cli.py summary main --config cudagraphs --mode training + + # Show commands to reproduce a run locally for a single model + python benchmarks/dynamo/perf_cli.py repro 22842783236 --model BERT_pytorch --suite tb +""", + ) + sub = parser.add_subparsers(dest="command", required=True) + + # -- launch -- + p_launch = sub.add_parser( + "launch", help="Launch a perf regression run on your branch" + ) + p_launch.add_argument( + "--device", + nargs="+", + default=["a100"], + choices=DEVICE_CHOICES, + metavar="DEVICE", + help=f"Devices to launch (default: a100). Choices: {', '.join(DEVICE_CHOICES)}", + ) + p_launch.add_argument( + "--ref", + type=str, + default=None, + help="Git ref to benchmark (default: current branch)", + ) + p_launch.add_argument( + "--wait", action="store_true", help="Wait for all runs to complete" + ) + p_launch.add_argument( + "--wait-and-summarize", + dest="wait_and_summarize", + action="store_true", + help="Wait for all runs, then print summary vs. latest main nightly", + ) + # Workflow dispatch inputs + launch_opts = p_launch.add_argument_group("workflow options") + launch_opts.add_argument( + "--training", + action="store_true", + default=None, + help="Enable training benchmarks", + ) + launch_opts.add_argument( + "--no-training", + dest="training", + action="store_false", + help="Disable training benchmarks", + ) + launch_opts.add_argument( + "--inference", + action="store_true", + default=None, + help="Enable inference benchmarks", + ) + launch_opts.add_argument( + "--no-inference", + dest="inference", + action="store_false", + help="Disable inference benchmarks", + ) + launch_opts.add_argument( + "--cudagraphs", action="store_true", default=None, help="Enable cudagraphs" + ) + launch_opts.add_argument( + "--no-cudagraphs", + dest="cudagraphs", + action="store_false", + help="Disable cudagraphs", + ) + launch_opts.add_argument( + "--dynamic", action="store_true", default=None, help="Enable dynamic shapes" + ) + launch_opts.add_argument( + "--no-dynamic", + dest="dynamic", + action="store_false", + help="Disable dynamic shapes", + ) + launch_opts.add_argument( + "--cppwrapper", action="store_true", default=None, help="Enable cpp wrapper" + ) + launch_opts.add_argument( + "--freezing-cudagraphs", + dest="freezing_cudagraphs", + action="store_true", + default=None, + ) + launch_opts.add_argument("--aotinductor", action="store_true", default=None) + launch_opts.add_argument("--maxautotune", action="store_true", default=None) + launch_opts.add_argument( + "--default", dest="default", action="store_true", default=None + ) + launch_opts.add_argument( + "--benchmark-configs", + dest="benchmark_configs", + type=str, + default=None, + help="Override benchmark_configs input", + ) + + # -- summary -- + p_summary = sub.add_parser("summary", help="Summarize results of a perf run") + p_summary.add_argument( + "run_id", type=str, help="GitHub Actions run ID or branch name" + ) + p_summary.add_argument( + "--device", + nargs="+", + default=None, + choices=DEVICE_CHOICES, + metavar="DEVICE", + help=f"Device(s) to summarize (default: auto-discover from branch). Choices: {', '.join(DEVICE_CHOICES)}", + ) + p_summary.add_argument( + "--baseline", + type=str, + default="latest", + help="Baseline run ID, 'latest' for most recent main nightly (default), or 'none' to disable", + ) + p_summary.add_argument( + "--metric", + type=str, + default="speedup", + choices=METRIC_CHOICES, + help="Metric to display (default: speedup)", + ) + p_summary.add_argument( + "--top", + type=int, + default=5, + help="Number of worst offenders to show (default: 5)", + ) + # Filters + filters = p_summary.add_argument_group("filters") + filters.add_argument( + "--config", + type=str, + default=None, + help="Regex to filter config names (e.g. 'cudagraphs', 'dynamic')", + ) + filters.add_argument( + "--suite", type=str, default=None, help="Filter to suite: hf, timm, tb" + ) + filters.add_argument( + "--mode", + type=str, + choices=["training", "inference"], + default=None, + help="Filter to training or inference", + ) + filters.add_argument( + "--backend", + type=str, + default=None, + help="Regex to filter backend (e.g. 'cudagraphs', 'dynamic')", + ) + filters.add_argument( + "--dtype", + type=str, + default=None, + help="Filter to dtype (e.g. amp, float16, bfloat16)", + ) + filters.add_argument( + "--runtime", + type=str, + default=None, + help="Filter to runtime (e.g. cuda, cpu, xpu)", + ) + p_summary.add_argument( + "--group-by", + dest="group_by", + nargs="+", + default=None, + choices=GROUP_CHOICES, + metavar="KEY", + help=f"Group S-curves by key(s) (default: single combined). Choices: {', '.join(GROUP_CHOICES)}", + ) + p_summary.add_argument( + "--attempt", type=int, default=1, help="Run attempt number (default: 1)" + ) + p_summary.add_argument( + "--no-cache", + dest="no_cache", + action="store_true", + default=False, + help="Re-download artifacts even if cached", + ) + + # -- repro -- + p_repro = sub.add_parser("repro", help="Reproduce a remote perf run locally") + p_repro.add_argument( + "run_id", type=str, help="GitHub Actions run ID or branch name" + ) + p_repro.add_argument( + "--device", + nargs="+", + default=None, + choices=DEVICE_CHOICES, + metavar="DEVICE", + help="Device(s) (default: auto-discover from branch)", + ) + p_repro.add_argument( + "--model", + type=str, + default=None, + help="Run only this model (e.g. BERT_pytorch)", + ) + p_repro.add_argument( + "--suite", type=str, default=None, help="Filter to suite (hf, timm, tb)" + ) + p_repro.add_argument( + "--mode", + type=str, + choices=["training", "inference"], + default=None, + help="Filter to mode", + ) + p_repro.add_argument( + "--backend", + type=str, + default=None, + help="Regex to filter backend (e.g. 'cudagraphs', 'dynamic')", + ) + p_repro.add_argument( + "--dtype", type=str, default=None, help="Filter to dtype (e.g. amp, bfloat16)" + ) + p_repro.add_argument( + "--runtime", type=str, default=None, help="Filter to runtime (e.g. cuda, cpu)" + ) + p_repro.add_argument( + "--attempt", type=int, default=1, help="Run attempt number (default: 1)" + ) + + # -- prepare-repro -- + p_prep = sub.add_parser( + "prepare-repro", help="Show setup commands to prepare benchmark suites locally" + ) + p_prep.add_argument( + "run_id", + nargs="?", + type=str, + default=None, + help="Optional: run ID or branch to also show benchmark commands", + ) + p_prep.add_argument( + "--device", + nargs="+", + default=None, + choices=DEVICE_CHOICES, + metavar="DEVICE", + help="Device(s) for benchmark commands", + ) + p_prep.add_argument( + "--suite", + type=str, + default=None, + help="Only show setup for this suite (hf, timm, tb)", + ) + p_prep.add_argument( + "--mode", + type=str, + choices=["training", "inference"], + default=None, + help="Filter benchmark commands to mode", + ) + p_prep.add_argument( + "--model", type=str, default=None, help="Filter benchmark commands to model" + ) + p_prep.add_argument( + "--no-repro", + dest="no_repro", + action="store_true", + help="Only show setup, skip benchmark commands", + ) + p_prep.add_argument( + "--attempt", type=int, default=1, help="Run attempt number (default: 1)" + ) + + args = parser.parse_args() + + if args.command == "launch": + cmd_launch(args) + elif args.command == "summary": + cmd_summary(args) + elif args.command == "repro": + cmd_repro(args) + elif args.command == "prepare-repro": + cmd_prepare_repro(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/benchmark_base.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/benchmark_base.py index 179a4e780ce59..55dac427ec4da 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/benchmark_base.py +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/benchmark_base.py @@ -54,7 +54,7 @@ # A unique number for each run of a particular workflow in a repository, e.g., 238742. Derived from GITHUB_RUN_NUMBER. 28: optional string github_run_number_str; } - """, # noqa: B950 + """, ) diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index ccd7d1acd7dc7..375c73eb6e917 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,8 +1,8 @@ -add_loop_eager,compile_time_instruction_count,3347000000,0.1 +add_loop_eager,compile_time_instruction_count,3479000000,0.1 -add_loop_eager_dynamic,compile_time_instruction_count,4578000000,0.1 +add_loop_eager_dynamic,compile_time_instruction_count,4779000000,0.1 @@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29660000000,0.1 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,37280000000,0.1 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,38150000000,0.1 @@ -18,11 +18,11 @@ add_loop_inductor_gpu,compile_time_instruction_count,26140000000,0.1 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,898000000,0.1 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,932400000,0.1 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15240000000,0.1 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,15640000000,0.1 @@ -34,19 +34,19 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,11090000 -update_hint_regression,compile_time_instruction_count,1548000000,0.1 +update_hint_regression,compile_time_instruction_count,1609000000,0.1 -sum_floordiv_regression,compile_time_instruction_count,3477000000,0.1 +sum_floordiv_regression,compile_time_instruction_count,3641000000,0.1 -symint_sum,compile_time_instruction_count,3112000000,0.1 +symint_sum,compile_time_instruction_count,3210000000,0.1 -symint_sum_loop,compile_time_instruction_count,4305000000,0.1 +symint_sum_loop,compile_time_instruction_count,4448000000,0.1 @@ -54,15 +54,15 @@ aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,1793000000 -aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,4988000000,0.1 +aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5109000000,0.1 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,7858000000,0.1 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,8318000000,0.1 -aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1834000000,0.1 +aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1937000000,0.1 @@ -70,23 +70,23 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3152000000, -aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8021000000,0.1 +aotdispatcher_training_subclass_cpu,compile_time_instruction_count,8317000000,0.1 -mm_loop_inductor_gpu,compile_time_instruction_count,4597000000,0.1 +mm_loop_inductor_gpu,compile_time_instruction_count,4707000000,0.1 -mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9051000000,0.1 +mm_loop_inductor_dynamic_gpu,compile_time_instruction_count,9270000000,0.1 -basic_NestedModule_eager,compile_time_instruction_count,5728000000,0.1 +basic_NestedModule_eager,compile_time_instruction_count,5866000000,0.1 -basic_InlineMod_eager,compile_time_instruction_count,7701000000,0.1 +basic_InlineMod_eager,compile_time_instruction_count,7863000000,0.1 @@ -102,11 +102,11 @@ dtensor_dispatch_collectives,instruction_count,28890000,0.1 -dtensor_dispatch_add_backward,instruction_count,346201,0.1 +dtensor_dispatch_add_backward,instruction_count,360100,0.1 -dtensor_dispatch_inplace,instruction_count,56530,0.1 +dtensor_dispatch_inplace,instruction_count,61490,0.1 @@ -118,4 +118,4 @@ dtensor_dispatch_random,instruction_count,592800,0.1 -dtensor_dispatch_custom_handler,instruction_count,35020,0.1 +dtensor_dispatch_custom_handler,instruction_count,36050,0.1 diff --git a/benchmarks/dynamo/timm_models.py b/benchmarks/dynamo/timm_models.py index ed52ffde5b618..09e7647537a4a 100755 --- a/benchmarks/dynamo/timm_models.py +++ b/benchmarks/dynamo/timm_models.py @@ -56,55 +56,6 @@ def pip_install(package): TIMM_MODELS[model_name] = int(batch_size) -# TODO - Figure out the reason of cold start memory spike - -BATCH_SIZE_DIVISORS = { - "beit_base_patch16_224": 2, - "deit_base_distilled_patch16_224": 2, - "gluon_xception65": 2, - "mobilevit_s": 2, - "swin_base_patch4_window7_224": 2, -} - -REQUIRE_HIGHER_TOLERANCE = { - "inception_v3", - "mobilenetv3_large_100", - "convnextv2_nano.fcmae_ft_in22k_in1k", -} - -REQUIRE_HIGHER_TOLERANCE_FP16_XPU = { - "botnet26t_256", -} - -REQUIRE_HIGHER_TOLERANCE_AMP = {} - -REQUIRE_EVEN_HIGHER_TOLERANCE = { - "deit_base_distilled_patch16_224", - "vit_base_patch16_siglip_256", -} - -# These models need higher tolerance in MaxAutotune mode -REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {} - -REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = { - "adv_inception_v3", -} - -SCALED_COMPUTE_LOSS = { - "mobilevit_s", -} - -FORCE_AMP_FOR_FP16_BF16_MODELS = {} - -SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {} - -REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = { - "inception_v3", - "mobilenetv3_large_100", - "vit_base_patch14_dinov2.lvd142m", -} - - def refresh_model_names(): import glob @@ -201,6 +152,22 @@ def _config(self): def _skip(self): return self._config["skip"] + @property + def _batch_size(self): + return self._config["batch_size"] + + @property + def _tolerance(self): + return self._config["tolerance"] + + @property + def _accuracy(self): + return self._config["accuracy"] + + @property + def _require_larger_multiplier_for_smaller_tensor(self): + return self._config["require_larger_multiplier_for_smaller_tensor"] + @property def skip_models_for_cpu(self): return self._skip["device"]["cpu"] @@ -215,7 +182,7 @@ def skip_models(self): @property def force_amp_for_fp16_bf16_models(self): - return FORCE_AMP_FOR_FP16_BF16_MODELS + return self._config["dtype"]["force_amp_for_fp16_bf16_models"] @property def force_fp16_for_bf16_models(self): @@ -228,17 +195,13 @@ def get_output_amp_train_process_func(self): @property def skip_accuracy_check_as_eager_non_deterministic(self): if self.args.accuracy and self.args.training: - return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS + return self._accuracy["skip"]["eager_not_deterministic"] return set() @property def guard_on_nn_module_models(self): return {} - @property - def inline_inbuilt_nn_modules_models(self): - return {} - @download_retry_decorator def _download_model(self, model_name): model = create_model( @@ -286,9 +249,10 @@ def load_model( input_size = data_config["input_size"] recorded_batch_size = TIMM_MODELS[model_name] - if model_name in BATCH_SIZE_DIVISORS: + batch_size_divisors = self._batch_size["divisors"] + if model_name in batch_size_divisors: recorded_batch_size = max( - int(recorded_batch_size / BATCH_SIZE_DIVISORS[model_name]), 1 + int(recorded_batch_size / batch_size_divisors[model_name]), 1 ) batch_size = batch_size or recorded_batch_size @@ -310,7 +274,7 @@ def load_model( self.loss = torch.nn.CrossEntropyLoss().to(device) - if model_name in SCALED_COMPUTE_LOSS: + if model_name in self._config["scaled_compute_loss"]: self.compute_loss = self.scaled_compute_loss if is_training and not use_eval_mode: @@ -318,7 +282,7 @@ def load_model( else: model.eval() - self.validate_model(model, example_inputs) + self.validate_model(model_name, model, example_inputs) return device, model_name, model, example_inputs, batch_size @@ -346,34 +310,34 @@ def pick_grad(self, name, is_training): return torch.no_grad() def use_larger_multiplier_for_smaller_tensor(self, name): - return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR + return name in self._require_larger_multiplier_for_smaller_tensor def get_tolerance_and_cosine_flag(self, is_training, current_device, name): cosine = self.args.cosine tolerance = 1e-3 - if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING: + if self.args.freezing and name in self._tolerance["freezing"]: # the conv-batchnorm fusion used under freezing may cause relatively - # large numerical difference. We need are larger tolerance. + # large numerical difference. We need a larger tolerance. # Check https://github.com/pytorch/pytorch/issues/120545 for context tolerance = 8 * 1e-2 if is_training: from torch._inductor import config as inductor_config - if name == "beit_base_patch16_224": + if name in self._tolerance["highest_training"]: tolerance = 16 * 1e-2 - elif name in REQUIRE_EVEN_HIGHER_TOLERANCE or ( + elif name in self._tolerance["even_higher"] or ( inductor_config.max_autotune - and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE + and name in self._tolerance["even_higher_max_autotune"] ): tolerance = 8 * 1e-2 - elif name in REQUIRE_HIGHER_TOLERANCE or ( - self.args.amp and name in REQUIRE_HIGHER_TOLERANCE_AMP + elif name in self._tolerance["higher_training"] or ( + self.args.amp and name in self._tolerance["higher_amp"] ): tolerance = 4 * 1e-2 elif ( - name in REQUIRE_HIGHER_TOLERANCE_FP16_XPU + name in self._tolerance["higher_fp16_xpu"] and self.args.float16 and current_device == "xpu" ): diff --git a/benchmarks/dynamo/timm_models.yaml b/benchmarks/dynamo/timm_models.yaml index 1650a87500537..3e94903791f93 100644 --- a/benchmarks/dynamo/timm_models.yaml +++ b/benchmarks/dynamo/timm_models.yaml @@ -11,3 +11,62 @@ skip: - dm_nfnet_f0 - nfnet_l0 - visformer_small + + +# TODO - Figure out the reason of cold start memory spike +batch_size: + divisors: + beit_base_patch16_224: 2 + deit_base_distilled_patch16_224: 2 + gluon_xception65: 2 + mobilevit_s: 2 + swin_base_patch4_window7_224: 2 + + +tolerance: + higher_training: + - inception_v3 + - mobilenetv3_large_100 + + higher_fp16_xpu: + - botnet26t_256 + + higher_amp: [] + + even_higher: + - deit_base_distilled_patch16_224 + - vit_base_patch16_siglip_256 + + # These models need higher tolerance in MaxAutotune mode + even_higher_max_autotune: [] + + # beit_base_patch16_224 needs 16 * 1e-2 tolerance in training + highest_training: + - beit_base_patch16_224 + + freezing: + # the conv-batchnorm fusion used under freezing may cause relatively + # large numerical difference. We need a larger tolerance. + # Check https://github.com/pytorch/pytorch/issues/120545 for context + - adv_inception_v3 + + cosine: [] + + +scaled_compute_loss: + - mobilevit_s + + +require_larger_multiplier_for_smaller_tensor: + - inception_v3 + - mobilenetv3_large_100 + - vit_base_patch14_dinov2.lvd142m + + +dtype: + force_amp_for_fp16_bf16_models: [] + + +accuracy: + skip: + eager_not_deterministic: [] diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py index 77d2cb5bad808..cbe575f926d22 100755 --- a/benchmarks/dynamo/torchbench.py +++ b/benchmarks/dynamo/torchbench.py @@ -208,19 +208,6 @@ def guard_on_nn_module_models(self): "vision_maskrcnn", } - @property - def inline_inbuilt_nn_modules_models(self): - return { - "basic_gnn_edgecnn", - "drq", - "hf_Reformer", - "DALLE2_pytorch", - "detectron2_maskrcnn_r_50_fpn", - "detectron2_maskrcnn_r_101_fpn", - "vision_maskrcnn", - "doctr_reco_predictor", - } - @cached_property def _fb_models_available(self): """This property exists because importing IS_FBCODE causes some models to be @@ -394,7 +381,7 @@ def load_model( ): model.config.use_cache = False - self.validate_model(model, example_inputs) + self.validate_model(benchmark.name, model, example_inputs) return device, benchmark.name, model, example_inputs, batch_size def iter_model_names(self, args): diff --git a/benchmarks/dynamo/torchbench.yaml b/benchmarks/dynamo/torchbench.yaml index 9a1825bc3d290..6ab39de217ba8 100644 --- a/benchmarks/dynamo/torchbench.yaml +++ b/benchmarks/dynamo/torchbench.yaml @@ -38,6 +38,7 @@ tolerance: - tacotron2 - yolov3 - squeezenet1_1 + - shufflenet_v2_x1_0 higher_fp16: - doctr_reco_predictor diff --git a/benchmarks/fastrnns/bench.py b/benchmarks/fastrnns/bench.py index d8bc4c2fc9079..c748d70eeb22a 100644 --- a/benchmarks/fastrnns/bench.py +++ b/benchmarks/fastrnns/bench.py @@ -324,7 +324,7 @@ def bench_group(model_list, bench_name, bench_group, bench_args): vlrnns = ["vl_cudnn", "vl_jit", "vl_py"] if args.print_json: - print_stderr = lambda *args, **kwargs: None # noqa: E731,F811 + print_stderr = lambda *args, **kwargs: None # noqa: E731 print_stderr(args) bench_args = copy.deepcopy(vars(args)) diff --git a/benchmarks/functional_autograd_benchmark/torchvision_models.py b/benchmarks/functional_autograd_benchmark/torchvision_models.py index f749617d397e5..c4febceb1fba5 100644 --- a/benchmarks/functional_autograd_benchmark/torchvision_models.py +++ b/benchmarks/functional_autograd_benchmark/torchvision_models.py @@ -794,7 +794,7 @@ def loss_masks(self, outputs, targets, indices, num_boxes): losses = { "loss_mask": sigmoid_focal_loss( # noqa: F821 src_masks, target_masks, num_boxes - ), # noqa: F821 + ), "loss_dice": dice_loss(src_masks, target_masks, num_boxes), # noqa: F821 } return losses diff --git a/benchmarks/gpt_fast/mixtral_moe_model.py b/benchmarks/gpt_fast/mixtral_moe_model.py index a0733588cab48..d0d186592b64a 100644 --- a/benchmarks/gpt_fast/mixtral_moe_model.py +++ b/benchmarks/gpt_fast/mixtral_moe_model.py @@ -1,4 +1,4 @@ -# flake8: noqa: E266, C417, B950 +# flake8: noqa: E266, C417 from dataclasses import dataclass import torch diff --git a/benchmarks/gpt_fast/mixtral_moe_quantize.py b/benchmarks/gpt_fast/mixtral_moe_quantize.py index 016a47e1e9734..65eac31a95ba5 100644 --- a/benchmarks/gpt_fast/mixtral_moe_quantize.py +++ b/benchmarks/gpt_fast/mixtral_moe_quantize.py @@ -1,4 +1,4 @@ -# flake8: noqa: E266, C417, B950 +# flake8: noqa: E266 from mixtral_moe_model import ConditionalFeedForward import torch diff --git a/benchmarks/gpt_fast/model.py b/benchmarks/gpt_fast/model.py index 5a1805bb1c32a..a137ba505dd82 100644 --- a/benchmarks/gpt_fast/model.py +++ b/benchmarks/gpt_fast/model.py @@ -1,4 +1,4 @@ -# flake8: noqa: E266, C417, B950 +# flake8: noqa: E266, C417 from dataclasses import dataclass import torch diff --git a/benchmarks/gpt_fast/quantize.py b/benchmarks/gpt_fast/quantize.py index 524c7072b2a4c..127f13fcecc5e 100644 --- a/benchmarks/gpt_fast/quantize.py +++ b/benchmarks/gpt_fast/quantize.py @@ -1,4 +1,4 @@ -# flake8: noqa: E266, C417, B950 +# flake8: noqa: E266 import torch import torch.nn as nn import torch.nn.functional as F diff --git a/benchmarks/inductor_backends/cutlass.py b/benchmarks/inductor_backends/cutlass.py index c231e9ddcda63..adab252b63575 100644 --- a/benchmarks/inductor_backends/cutlass.py +++ b/benchmarks/inductor_backends/cutlass.py @@ -454,7 +454,7 @@ def main(): ) ) for i, group_config in enumerate(tqdm(configs)): - group_results = run_single_experiment_group(group_config) # noqa: G004 + group_results = run_single_experiment_group(group_config) results.append( ExperimentGroup(config=group_config, results=group_results), ) diff --git a/benchmarks/instruction_counts/worker/main.py b/benchmarks/instruction_counts/worker/main.py index e08e741029844..e7ae09b76ac09 100644 --- a/benchmarks/instruction_counts/worker/main.py +++ b/benchmarks/instruction_counts/worker/main.py @@ -177,7 +177,7 @@ def main(communication_file: str) -> None: # Runner process sent SIGINT. sys.exit() - except BaseException: # noqa: B036 + except BaseException: trace_f = io.StringIO() traceback.print_exc(file=trace_f) result = WorkerFailure(failure_trace=trace_f.getvalue()) diff --git a/benchmarks/operator_benchmark/benchmark_all_test.py b/benchmarks/operator_benchmark/benchmark_all_test.py index f7d967c2c261a..2f476c01c997d 100644 --- a/benchmarks/operator_benchmark/benchmark_all_test.py +++ b/benchmarks/operator_benchmark/benchmark_all_test.py @@ -1,5 +1,12 @@ +import platform + import benchmark_all_other_test # noqa: F401 -import benchmark_all_quantized_test # noqa: F401 + + +# Quantized benchmarks use fbgemm which only supports x86 +if platform.machine() in ("x86_64", "AMD64"): + import benchmark_all_quantized_test # noqa: F401 + from pt import unary_test # noqa: F401 import operator_benchmark as op_bench diff --git a/benchmarks/operator_benchmark/benchmark_runner.py b/benchmarks/operator_benchmark/benchmark_runner.py index 6568cf9bf3ee6..4c96b40b8db47 100644 --- a/benchmarks/operator_benchmark/benchmark_runner.py +++ b/benchmarks/operator_benchmark/benchmark_runner.py @@ -16,6 +16,7 @@ description="Run microbenchmarks.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, conflict_handler="resolve", + allow_abbrev=False, ) diff --git a/benchmarks/operator_benchmark/operator_benchmark.py b/benchmarks/operator_benchmark/operator_benchmark.py index b3c6678420f85..95a30dad117bf 100644 --- a/benchmarks/operator_benchmark/operator_benchmark.py +++ b/benchmarks/operator_benchmark/operator_benchmark.py @@ -1,6 +1,6 @@ # TODO (mingzhe09088): get rid of noqa import benchmark_runner # noqa: F401 from benchmark_pytorch import TorchBenchmarkBase # noqa: F401 -from benchmark_test_generator import * # noqa: F401,F403 +from benchmark_test_generator import * # noqa: F403 -from benchmark_utils import * # noqa: F401,F403 +from benchmark_utils import * # noqa: F403 diff --git a/benchmarks/serialization/export_save_linear_bench.py b/benchmarks/serialization/export_save_linear_bench.py new file mode 100644 index 0000000000000..6f3f9360e9896 --- /dev/null +++ b/benchmarks/serialization/export_save_linear_bench.py @@ -0,0 +1,287 @@ +import argparse +import gc +import statistics +import tempfile +import time +from dataclasses import dataclass +from pathlib import Path + +import torch + + +DEFAULT_NUM_PARAMS = ( + 1_000_000, + 10_000_000, + 100_000_000, + 1_000_000_000, + 2_000_000_000, + 3_000_000_000, + 4_000_000_000, +) +NUM_LAYERS = 5 +DEFAULT_REPEATS = 5 + + +class LinearModel(torch.nn.Module): + def __init__(self, hidden_size: int, *, dtype: torch.dtype) -> None: + super().__init__() + self.layers = torch.nn.ModuleList( + [ + torch.nn.Linear( + hidden_size, + hidden_size, + bias=True, + dtype=dtype, + ) + for _ in range(NUM_LAYERS) + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for layer in self.layers: + x = layer(x) + return x + + +@dataclass +class BenchmarkResult: + num_params: int + median_save_ms: float | None + status: str + error_detail: str | None = None + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Benchmark torch.export.save for a five-layer LinearModel across " + "a range of parameter counts." + ) + ) + parser.add_argument( + "--num-params", + type=int, + nargs="+", + default=list(DEFAULT_NUM_PARAMS), + help=( + "Target total parameter counts to benchmark. The script derives the " + "nearest hidden size for each target." + ), + ) + parser.add_argument( + "--repeats", + type=int, + default=DEFAULT_REPEATS, + help="Number of torch.export.save timings to collect per parameter-count case.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=0, + help=( + "Batch size for the example input passed to torch.export.export. " + "The default of 0 keeps the benchmark focused on serialization cost." + ), + ) + parser.add_argument( + "--dtype", + choices=("float32", "float16", "bfloat16"), + default="float32", + help="Parameter and example-input dtype.", + ) + return parser.parse_args() + + +def _dtype_from_name(dtype_name: str) -> torch.dtype: + return { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + }[dtype_name] + + +def _parameter_count(hidden_size: int) -> int: + return NUM_LAYERS * (hidden_size * hidden_size + hidden_size) + + +def _hidden_size_from_num_params(num_params: int) -> int: + if num_params <= 0: + raise ValueError("--num-params values must be positive") + + root = (-1.0 + (1.0 + 4.0 * num_params / NUM_LAYERS) ** 0.5) / 2.0 + base_hidden_size = max(1, int(root)) + candidates = [ + candidate + for candidate in ( + base_hidden_size - 1, + base_hidden_size, + base_hidden_size + 1, + base_hidden_size + 2, + ) + if candidate >= 1 + ] + return min( + candidates, + key=lambda hidden_size: abs(_parameter_count(hidden_size) - num_params), + ) + + +def _format_num_params(num_params: int) -> str: + for threshold, suffix in ( + (1_000_000_000, "B"), + (1_000_000, "M"), + (1_000, "K"), + ): + if num_params >= threshold: + value = num_params / threshold + return f"{value:.1f}".rstrip("0").rstrip(".") + suffix + return str(num_params) + + +def _measure_save_times_ms( + exported_program: torch.export.ExportedProgram, + repeats: int, +) -> list[float]: + times_ms: list[float] = [] + with tempfile.TemporaryDirectory(prefix="export_save_bench_") as temp_dir: + temp_path = Path(temp_dir) + for iteration in range(repeats): + save_path = temp_path / f"exported_program_{iteration}.pt2" + start_time = time.perf_counter() + torch.export.save(exported_program, save_path) + elapsed_ms = (time.perf_counter() - start_time) * 1000.0 + times_ms.append(elapsed_ms) + save_path.unlink(missing_ok=True) + return times_ms + + +def _status_for_exception(exc: BaseException) -> tuple[str, str]: + message = ( + str(exc).strip().splitlines()[0] if str(exc).strip() else type(exc).__name__ + ) + lowered = message.lower() + if ( + isinstance(exc, MemoryError) + or "out of memory" in lowered + or "can't allocate memory" in lowered + ): + return "oom", message + return "error", message + + +def _run_case( + target_num_params: int, + *, + repeats: int, + batch_size: int, + dtype: torch.dtype, +) -> BenchmarkResult: + exported_program = None + hidden_size = _hidden_size_from_num_params(target_num_params) + num_params = _parameter_count(hidden_size) + try: + with torch.no_grad(): + model = LinearModel(hidden_size, dtype=dtype).eval() + model.requires_grad_(False) + example_input = torch.zeros(batch_size, hidden_size, dtype=dtype) + exported_program = torch.export.export( + model, + (example_input,), + strict=True, + ) + + save_times_ms = _measure_save_times_ms(exported_program, repeats) + return BenchmarkResult( + num_params=num_params, + median_save_ms=statistics.median(save_times_ms), + status="ok", + ) + except (MemoryError, RuntimeError) as exc: + status, detail = _status_for_exception(exc) + return BenchmarkResult( + num_params=num_params, + median_save_ms=None, + status=status, + error_detail=detail, + ) + finally: + del exported_program + del example_input + del model + gc.collect() + + +def _format_table(results: list[BenchmarkResult]) -> str: + headers = ("num_params", "median_save_ms", "status") + rows = [ + ( + _format_num_params(result.num_params), + ( + f"{result.median_save_ms:.3f}" + if result.median_save_ms is not None + else "n/a" + ), + result.status, + ) + for result in results + ] + widths = [ + max(len(header), *(len(row[column]) for row in rows)) + for column, header in enumerate(headers) + ] + + def _format_row(row: tuple[str, ...]) -> str: + return " | ".join( + [ + row[0].rjust(widths[0]), + row[1].rjust(widths[1]), + row[2].ljust(widths[2]), + ] + ) + + lines = [ + _format_row(headers), + "-+-".join("-" * width for width in widths), + ] + lines.extend(_format_row(row) for row in rows) + return "\n".join(lines) + + +def main() -> None: + args = _parse_args() + dtype = _dtype_from_name(args.dtype) + + print("Benchmarking torch.export.save for a 5-layer LinearModel") + print(f"dtype={args.dtype}, batch_size={args.batch_size}, repeats={args.repeats}") + + results = [ + _run_case( + num_params, + repeats=args.repeats, + batch_size=args.batch_size, + dtype=dtype, + ) + for num_params in args.num_params + ] + + print() + print( + "Note: This benchmark is highly sensitive to disk performance and OS " + "behavior. Results for smaller parameter counts can be noisy or flaky." + ) + print() + print(_format_table(results)) + + failures = [result for result in results if result.error_detail is not None] + if failures: + print() + print("failed cases:") + for result in failures: + print( + f"- num_params={_format_num_params(result.num_params)}: " + f"{result.error_detail}" + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/static_runtime/test_static_runtime.cc b/benchmarks/static_runtime/test_static_runtime.cc index 69d925675d478..13dc7a21bf2b3 100644 --- a/benchmarks/static_runtime/test_static_runtime.cc +++ b/benchmarks/static_runtime/test_static_runtime.cc @@ -1074,6 +1074,10 @@ TEST(StaticRuntime, NanToNum) { } TEST(StaticRuntime, Stack) { +#if defined(__aarch64__) || defined(_M_ARM64) + // See https://github.com/pytorch/pytorch/issues/178522. + GTEST_SKIP() << "Skipping StaticRuntime.Stack on AArch64."; +#endif const auto stack_dim = R"JIT( def forward(self, a: Tensor, b: Tensor, dim: int): inputs = [a] @@ -2307,6 +2311,10 @@ TEST(StaticRuntime, Append) { } TEST(StaticRuntime, QuantizedLinear) { +#if defined(__aarch64__) || defined(_M_ARM64) + // See https://github.com/pytorch/pytorch/issues/178522. + GTEST_SKIP() << "Skipping QuantizedLinear on AArch64."; +#endif const std::string quantize_script = R"IR( graph(%input: Tensor, %weights: Tensor): %scale: float = prim::Constant[value=1.]() @@ -2331,6 +2339,10 @@ TEST(StaticRuntime, QuantizedLinear) { } TEST(StaticRuntime, QuantizedLinearDynamicFp16) { +#if defined(__aarch64__) || defined(_M_ARM64) + // See https://github.com/pytorch/pytorch/issues/178522. + GTEST_SKIP() << "Skipping QuantizedLinearDynamicFp16 on AArch64."; +#endif const std::string quantized_linear_dynamic_fp16_script = R"IR( graph(%input: Tensor, %weights: Tensor): %bias: None = prim::Constant() @@ -2352,6 +2364,10 @@ TEST(StaticRuntime, QuantizedLinearDynamicFp16) { } TEST(StaticRuntime, QuantizedLinearReluDynamicFp16) { +#if defined(__aarch64__) || defined(_M_ARM64) + // See https://github.com/pytorch/pytorch/issues/178522. + GTEST_SKIP() << "Skipping QuantizedLinearReluDynamicFp16 on AArch64."; +#endif const std::string quantized_linear_relu_dynamic_fp16_script = R"IR( graph(%input: Tensor, %weights: Tensor): %bias: None = prim::Constant() diff --git a/benchmarks/tensorexpr/__main__.py b/benchmarks/tensorexpr/__main__.py index c50eb338cd706..8fecc338df63f 100644 --- a/benchmarks/tensorexpr/__main__.py +++ b/benchmarks/tensorexpr/__main__.py @@ -2,9 +2,9 @@ import itertools import os -# from . import conv # noqa: F401 -# from . import normalization # noqa: F401 -# from . import pooling # noqa: F401 +# from . import conv +# from . import normalization +# from . import pooling from . import ( # noqa: F401 attention, benchmark, diff --git a/buckbuild.bzl b/buckbuild.bzl index 1745af1b06dd4..4ad6de567c4b3 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -71,15 +71,15 @@ def read_bool(section, field, default, required = True): else: fail("`{}:{}`: no value set".format(section, field)) -def _is_build_mode_dev(): - if is_production_build_android(): - # Android Prod builds - return False +def _select_if_build_mode_dev(dev_value, default = []): if is_production_build_ios() or is_profile_build_ios(): - # iOS Prod builds - return False + return default + + return select({ + "DEFAULT": default, + "ovr_config//build_mode:optimization[dev]": dev_value, + }) - return True def _get_enable_lightweight_dispatch(): return read_bool("pt", "enable_lightweight_dispatch", False) @@ -93,10 +93,24 @@ def get_enable_mobile_dispatch_keys_trimming(): def get_disable_per_op_profiling(): return read_bool("pt", "disable_per_op_profiling", True) -def get_strip_error_messages(): +def strip_error_messages_select(value, default = []): if IS_OSS: - return True # always strip in OSS CI to expose potential issues - return read_bool("pt", "strip_error_messages", not _is_build_mode_dev()) + return value # always strip in OSS CI to expose potential issues + strip_error = read_bool("pt", "strip_error_messages", default = None, required = False) + + if strip_error == None: + + if is_production_build_ios() or is_profile_build_ios(): + return value + + return select({ + "DEFAULT": default, + "ovr_config//build_mode:optimization[opt]": value, + }) + + if strip_error: + return value + return default def get_disable_warn(): return read_bool("pt", "disable_warn", False) @@ -249,9 +263,7 @@ _COMMON_PREPROCESSOR_FLAGS = [ "-DNO_EXPORT", ] + ( ["-DC10_MOBILE_TRIM_DISPATCH_KEYS"] if get_enable_mobile_dispatch_keys_trimming() else [] -) + ( - ["-DSTRIP_ERROR_MESSAGES"] if get_strip_error_messages() else [] -) + ( +) + strip_error_messages_select(["-DSTRIP_ERROR_MESSAGES"]) + ( ["-DDISABLE_WARN"] if get_disable_warn() else [] ) @@ -282,9 +294,9 @@ def get_aten_preprocessor_flags(): "-DUSE_RUY_QMATMUL", ] if get_disable_per_op_profiling(): - ATEN_PREPROCESSOR_FLAGS.append("-DPYTORCH_DISABLE_PER_OP_PROFILING") + ATEN_PREPROCESSOR_FLAGS += ["-DPYTORCH_DISABLE_PER_OP_PROFILING"] if _get_enable_record_kernel_dtype(): - ATEN_PREPROCESSOR_FLAGS.append("-DENABLE_RECORD_KERNEL_FUNCTION_DTYPE") + ATEN_PREPROCESSOR_FLAGS += ["-DENABLE_RECORD_KERNEL_FUNCTION_DTYPE"] return ATEN_PREPROCESSOR_FLAGS def get_pt_preprocessor_flags(): @@ -295,8 +307,7 @@ def get_pt_preprocessor_flags(): "-DNO_CUDNN_DESTROY_HANDLE", ] - if _is_build_mode_dev(): - PT_PREPROCESSOR_FLAGS.append("-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS") + PT_PREPROCESSOR_FLAGS += _select_if_build_mode_dev(["-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS"]) return PT_PREPROCESSOR_FLAGS # This needs to be kept in sync with https://github.com/pytorch/pytorch/blob/release/1.9/torchgen/gen.py#L892 @lint-ignore @@ -905,6 +916,7 @@ def define_buck_targets( ("aten/src", "ATen/ops/*.h"), # ATen Base ("aten/src", "ATen/*.h"), + ("aten/src", "ATen/accelerator/*.h"), ("aten/src", "ATen/cpu/**/*.h"), ("aten/src", "ATen/detail/*.h"), ("aten/src", "ATen/functorch/**/*.h"), @@ -1519,7 +1531,12 @@ def define_buck_targets( srcs = [ "torch/csrc/api/src/data/samplers/random.cpp", "torch/csrc/api/src/data/samplers/sequential.cpp", + "torch/csrc/api/src/optim/adagrad.cpp", + "torch/csrc/api/src/optim/adam.cpp", + "torch/csrc/api/src/optim/adamw.cpp", + "torch/csrc/api/src/optim/lbfgs.cpp", "torch/csrc/api/src/optim/optimizer.cpp", + "torch/csrc/api/src/optim/rmsprop.cpp", "torch/csrc/api/src/optim/serialize.cpp", "torch/csrc/api/src/optim/sgd.cpp", "torch/csrc/api/src/serialize/input-archive.cpp", diff --git a/build_variables.bzl b/build_variables.bzl index fab555b9c1714..91c5923b6cd76 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -691,7 +691,7 @@ libtorch_lite_cmake_sources = sorted( torch_mobile_core, ) -libtorch_cmake_sources = libtorch_core_sources + libtorch_core_jit_sources + libtorch_nativert_sources +libtorch_cmake_sources = libtorch_core_sources + libtorch_core_jit_sources libtorch_extra_sources = libtorch_core_jit_sources + [ "torch/csrc/autograd/TraceTypeManual.cpp", @@ -742,6 +742,7 @@ libtorch_cuda_core_sources = [ "torch/csrc/cuda/shim_common.cpp", "torch/csrc/inductor/aoti_runner/model_container_runner_cuda.cpp", "torch/csrc/inductor/aoti_torch/shim_cuda.cpp", + "torch/csrc/inductor/inductor_ops_gpu.cpp", "torch/csrc/jit/codegen/fuser/cuda/fused_kernel.cpp", "torch/csrc/profiler/stubs/cuda.cpp", "torch/csrc/autograd/functions/comm.cpp", @@ -776,6 +777,7 @@ libtorch_cuda_distributed_extra_sources = [ "torch/csrc/distributed/c10d/symm_mem/CudaDMAConnectivity.cpp", "torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu", "torch/csrc/distributed/c10d/symm_mem/nccl_extension.cu", + "torch/csrc/distributed/c10d/symm_mem/ops/nccl_reduce_scatter_offset.cu", "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp", "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu", "torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp", @@ -1159,6 +1161,7 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/LegacyVmapMode.cpp", "aten/src/ATen/LegacyVmapTransforms.cpp", "aten/src/ATen/core/BackendSelectFallbackKernel.cpp", + "aten/src/ATen/core/CachingHostAllocator.cpp", "aten/src/ATen/core/DeprecatedTypeProperties.cpp", "aten/src/ATen/core/DeprecatedTypePropertiesRegistry.cpp", "aten/src/ATen/core/Dict.cpp", @@ -1172,7 +1175,6 @@ aten_cpu_source_non_codegen_list = [ "aten/src/ATen/core/Tensor.cpp", "aten/src/ATen/core/VariableFallbackKernel.cpp", "aten/src/ATen/core/VariableHooksInterface.cpp", - "aten/src/ATen/core/Vitals.cpp", "aten/src/ATen/core/boxing/KernelFunction.cpp", "aten/src/ATen/core/custom_class.cpp", "aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp", @@ -1503,6 +1505,7 @@ aten_native_source_non_codegen_list = [ # "aten/src/ATen/native/UpSample.cpp", "aten/src/ATen/native/UpSampleBicubic2d.cpp", "aten/src/ATen/native/UpSampleBilinear2d.cpp", + "aten/src/ATen/native/UpSampleLanczos2d.cpp", "aten/src/ATen/native/UpSampleLinear1d.cpp", "aten/src/ATen/native/UpSampleNearest1d.cpp", "aten/src/ATen/native/UpSampleNearest2d.cpp", diff --git a/c10/core/AllocatorConfig.cpp b/c10/core/AllocatorConfig.cpp index 3f9cee10ad528..212230c42a92f 100644 --- a/c10/core/AllocatorConfig.cpp +++ b/c10/core/AllocatorConfig.cpp @@ -11,6 +11,28 @@ constexpr size_t kRoundUpPowerOfTwoStart = 1 * kMB; // 1MB constexpr size_t kRoundUpPowerOfTwoEnd = 64 * 1024ul * kMB; // 64GB } // anonymous namespace +std::unordered_set& AcceleratorAllocatorConfig::getMutableKeys() { + static std::unordered_set keys{ + "large_segment_size_mb", + "max_split_size_mb", + "max_non_split_rounding_mb", + "garbage_collection_threshold", + "roundup_power2_divisions", + "expandable_segments", + "pinned_use_background_threads"}; + return keys; +} + +const std::unordered_set& AcceleratorAllocatorConfig::getKeys() { + return getMutableKeys(); +} + +std::function& AcceleratorAllocatorConfig:: + getConfigParserHook() { + static std::function hook{nullptr}; + return hook; +} + AcceleratorAllocatorConfig& AcceleratorAllocatorConfig::instance() { static AcceleratorAllocatorConfig instance; static bool env_flag [[maybe_unused]] = []() { diff --git a/c10/core/AllocatorConfig.h b/c10/core/AllocatorConfig.h index d314c93f0494f..603b2fdc4130e 100644 --- a/c10/core/AllocatorConfig.h +++ b/c10/core/AllocatorConfig.h @@ -234,32 +234,17 @@ class C10_API AcceleratorAllocatorConfig { // Use `Construct On First Use Idiom` to avoid `Static Initialization Order` // issue. - static std::unordered_set& getMutableKeys() { - static std::unordered_set keys{ - "large_segment_size_mb", - "max_split_size_mb", - "max_non_split_rounding_mb", - "garbage_collection_threshold", - "roundup_power2_divisions", - "expandable_segments", - "pinned_use_background_threads"}; - return keys; - } + static std::unordered_set& getMutableKeys(); // Returns the set of valid keys for the allocator configuration. // This set is used to validate the presence and correctness of keys in // device-specific configuration parsers. - static const std::unordered_set& getKeys() { - return getMutableKeys(); - } + static const std::unordered_set& getKeys(); // Optional hook for parsing additional device-specific allocator settings. // This allows backends (e.g., CUDA, XPU) to register a custom parser for // their own environment configuration extensions. - static std::function& getConfigParserHook() { - static std::function hook{nullptr}; - return hook; - } + static std::function& getConfigParserHook(); // Registers a device-specific configuration parser hook and its key. This // allows backends to parse additional device-specific configuration options diff --git a/c10/core/AutogradState.h b/c10/core/AutogradState.h index d2b9cc080413d..616597dc400c0 100644 --- a/c10/core/AutogradState.h +++ b/c10/core/AutogradState.h @@ -21,8 +21,7 @@ struct C10_API AutogradState { grad_mode_(grad_mode), inference_mode_(inference_mode), fw_grad_mode_(fw_grad_mode), - multithreading_enabled_(multithreading_enabled), - view_replay_enabled_(false) {} + multithreading_enabled_(multithreading_enabled) {} void set_grad_mode(bool enabled) { grad_mode_ = enabled; @@ -78,8 +77,7 @@ struct C10_API AutogradState { bool inference_mode_ : 1; bool fw_grad_mode_ : 1; bool multithreading_enabled_ : 1; - // NOLINTNEXTLINE(cppcoreguidelines-use-default-member-init) - bool view_replay_enabled_ : 1; + bool view_replay_enabled_ : 1 = false; }; } // namespace c10 diff --git a/c10/core/CachingDeviceAllocator.h b/c10/core/CachingDeviceAllocator.h index 6800c79ab09ff..2e26e4091abe7 100644 --- a/c10/core/CachingDeviceAllocator.h +++ b/c10/core/CachingDeviceAllocator.h @@ -56,6 +56,9 @@ struct DeviceStats { // un-mapped and free memory. int64_t num_device_free = 0; + // COUNT: total number of allocations rejected by OOM preemption policy + int64_t num_oom_rejections = 0; + // SIZE: maximum block size that is allowed to be split. int64_t max_split_size = 0; }; diff --git a/c10/core/DeviceType.cpp b/c10/core/DeviceType.cpp index 907493981e117..4d5ea626a8800 100644 --- a/c10/core/DeviceType.cpp +++ b/c10/core/DeviceType.cpp @@ -62,9 +62,6 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) { ". If you have recently updated the caffe2.proto file to add a new " "device type, did you forget to update the DeviceTypeName() " "function to reflect such recent changes?"); - // The below code won't run but is needed to suppress some compiler - // warnings. - return ""; } } diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h index d0445f009af05..36e50dcbb1b9d 100644 --- a/c10/core/DispatchKey.h +++ b/c10/core/DispatchKey.h @@ -221,12 +221,12 @@ enum class DispatchKey : uint16_t { // correct backend. BackendSelect, - Python, - - // Out-of-core key for Fake Tensor in torchdistx. - // See https://pytorch.org/torchdistx/latest/fake_tensor.html - // TODO: delete this in favor of Python-implemented fake tensor + // Fake dispatch key for C++ FakeTensor mode. Must be BELOW Python so that + // TorchDispatchModes (e.g. ProxyTorchDispatchMode, FakeTensorMode) fire + // first, matching the Python FakeTensor dispatch order. Fake, + + Python, // See Note [Out-of-tree vmap+grad prototype]. The purpose of this key // is to insert code after the "autograd subsystem" runs, so this key should // be directly after ADInplaceOrView and all of the autograd keys. diff --git a/c10/core/PyHandleCache.h b/c10/core/PyHandleCache.h index 8861f568bd972..3a312fd314ec0 100644 --- a/c10/core/PyHandleCache.h +++ b/c10/core/PyHandleCache.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -37,40 +38,27 @@ namespace c10 { // not be a way to conveniently index based on the object.) class PyHandleCache { public: - PyHandleCache() : pyinterpreter_(nullptr) {} + PyHandleCache() = default; - // Attempt to fetch the pointer from the cache, if the PyInterpreter + // Attempt to fetch the pointer from the cache, if the PyObject // matches. If it doesn't exist, or the cache entry is not valid, // use slow_accessor to get the real pointer value and return that // (possibly writing it to the cache, if the cache entry is // available.) template - PyObject* ptr_or(impl::PyInterpreter* self_interpreter, F slow_accessor) - const { - // Note [Memory ordering on Python interpreter tag] - impl::PyInterpreter* interpreter = - pyinterpreter_.load(std::memory_order_acquire); - if (C10_LIKELY(interpreter == self_interpreter)) { - return data_; - } else if (interpreter == nullptr) { - auto* r = slow_accessor(); - impl::PyInterpreter* expected = nullptr; - // attempt to claim this cache entry with the specified interpreter tag - if (pyinterpreter_.compare_exchange_strong( - expected, self_interpreter, std::memory_order_acq_rel)) { - data_ = r; - } - // This shouldn't be possible, as you should be GIL protected - TORCH_INTERNAL_ASSERT(expected != self_interpreter); - return r; - } else { - return slow_accessor(); + PyObject* ptr_or(F slow_accessor) const { + PyObject* d = data_.load(std::memory_order_acquire); + if (C10_LIKELY(d != nullptr)) { + return d; } + auto* r = slow_accessor(); + PyObject* expected = nullptr; + data_.compare_exchange_strong(expected, r, std::memory_order_acq_rel); + return r; } private: - mutable std::atomic pyinterpreter_; - mutable PyObject* data_{nullptr}; + mutable std::atomic data_{nullptr}; }; } // namespace c10 diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 040c6abb7d8e2..dc9d168f053e7 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -54,6 +54,31 @@ inline size_t elementSize(ScalarType t) { #undef CASE_ELEMENTSIZE_CASE } +inline ScalarType opaqueScalarType(ScalarType t) { + auto esize = elementSize(t); + ScalarType result; + switch (esize) { + case 1: + result = kByte; + break; + case 2: + result = kUInt16; + break; + case 4: + result = kUInt32; + break; + case 8: + result = kUInt64; + break; + case 16: + result = kComplexDouble; + break; + default: + TORCH_CHECK(false, "Unknown ScalarType"); + } + return result; +} + inline bool isIntegralType(ScalarType t, bool includeBool) { bool isIntegral = (t == ScalarType::Byte || t == ScalarType::Char || t == ScalarType::Int || diff --git a/c10/core/StorageImpl.cpp b/c10/core/StorageImpl.cpp index 56bc75e01adb1..25388d37ff98b 100644 --- a/c10/core/StorageImpl.cpp +++ b/c10/core/StorageImpl.cpp @@ -3,6 +3,18 @@ namespace c10 { +void StorageImpl::incref_pyobject() const noexcept { + pyobj_slot_.incref(); +} + +void StorageImpl::decref_pyobject() const noexcept { + pyobj_slot_.decref(); +} + +bool StorageImpl::try_incref_pyobject() const noexcept { + return pyobj_slot_.try_incref(); +} + // The array to save function pointer for custom storageImpl create. static std::array StorageImplCreate; @@ -48,30 +60,6 @@ void warnDeprecatedDataPtr() { TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid."); } -void StorageImpl::incref_pyobject() const noexcept { - // Because intrusive_ptr incref uses relaxed memory order, we need to - // do an acquire fence to ensure that the kHasPyObject bit was - // observed before the load of the PyObject* below. - // NB: This is a no-op on x86/x86-64 - std::atomic_thread_fence(std::memory_order_acquire); - - PyObject* obj = pyobj_slot_.load_pyobj(); - (*pyobj_slot_.pyobj_interpreter())->incref(obj); -} - -void StorageImpl::decref_pyobject() const noexcept { - PyObject* obj = pyobj_slot_.load_pyobj(); - (*pyobj_slot_.pyobj_interpreter())->decref(obj); -} - -bool StorageImpl::try_incref_pyobject() const noexcept { - c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter(); - if (C10_UNLIKELY(!interp)) { - return false; - } - return (*interp)->try_incref(pyobj_slot_); -} - void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) { // Allowlist verification. // Only if the devicetype is in the allowlist, diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index 0cd3a8fd5db0f..7a4631a88f753 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -106,9 +106,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { } void incref_pyobject() const noexcept final; - void decref_pyobject() const noexcept final; - bool try_incref_pyobject() const noexcept final; size_t nbytes() const { @@ -312,6 +310,14 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { refresh_has_data_ptr_check(); } + void clear_data_ptr_access_error_msg_() { + throw_on_immutable_data_ptr_ = false; + if (extra_meta_) { + extra_meta_->custom_data_ptr_error_msg_ = std::nullopt; + } + refresh_has_data_ptr_check(); + } + void set_throw_on_mutable_data_ptr() { throw_on_mutable_data_ptr_ = true; refresh_has_data_ptr_check(); diff --git a/c10/core/Stream.cpp b/c10/core/Stream.cpp index 45339719a91b1..afb8f525dabe4 100644 --- a/c10/core/Stream.cpp +++ b/c10/core/Stream.cpp @@ -22,6 +22,12 @@ void Stream::synchronize() const { impl.synchronizeStream(*this); } +// Return whether this stream is currently under graph capturing mode. +bool Stream::is_capturing() const { + impl::VirtualGuardImpl impl{device_.type()}; + return impl.isStreamCapturing(*this); +} + // Not very parsable, but I don't know a good compact syntax for streams. // Feel free to change this into something more compact if needed. std::ostream& operator<<(std::ostream& stream, const Stream& s) { diff --git a/c10/core/Stream.h b/c10/core/Stream.h index 3890faa05d9c4..08eb0eed3c6e1 100644 --- a/c10/core/Stream.h +++ b/c10/core/Stream.h @@ -134,6 +134,10 @@ class C10_API Stream final { // on this stream has completed running on the device. void synchronize() const; + // Return the stream is currently recording work for graph capture. True while + // the stream is in capture mode, false otherwise. + bool is_capturing() const; + // The purpose of this function is to more conveniently permit binding // of Stream to and from Python. Without packing, I have to setup a whole // class with two fields (device and stream id); with packing I can just diff --git a/c10/core/SymbolicShapeMeta.h b/c10/core/SymbolicShapeMeta.h index 0820038968a8e..59c3147c878f7 100644 --- a/c10/core/SymbolicShapeMeta.h +++ b/c10/core/SymbolicShapeMeta.h @@ -141,7 +141,7 @@ class C10_API SymbolicShapeMeta { available_.fetch_or(is_contiguous_avail); } void assume_channels_last_contiguous(SymBool val = true) { - is_contiguous_ = std::move(val); + is_channels_last_contiguous_ = std::move(val); available_.fetch_or(is_channels_last_contiguous_avail); } void assume_channels_last_3d_contiguous(SymBool val = true) { diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index c890d6d084eb3..43efebea85d0f 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -195,6 +196,44 @@ void TensorImpl::_change_backend_component_keys(c10::Device device) { key_set_ = key_set | DispatchKeySet(new_backend); } +void TensorImpl::set_fake_device(c10::Device fake_device) { + TORCH_CHECK( + fake_device.type() != c10::DeviceType::Meta, + "FakeTensor does not support meta device"); + + // in python FakeTensor, it checks whether or not + // we are in in_kernel_invocation manager to determine + // which device we return + + // but since we have an extra field for fake_device_, + // we can just set it upon FakeTensor creation + // and determine in device_custom() which device to return + // (based on if DispatchKey::Fake is excluded or not) + get_extra_meta().fake_device_ = fake_device; + key_set_ = key_set_.add(DispatchKey::Fake); + + // we need this so that device() calls device_custom() + // where the fake device logic is instead of just calling device_default() + set_custom_device(true); + + // change backend key from Meta to the fake device + _change_backend_component_keys(fake_device); +} + +void TensorImpl::set_and_normalize_fake_device(c10::Device fake_device) { + // normalize device index for indexed device types (not CPU) + if (fake_device.index() == -1 && fake_device.type() != c10::DeviceType::CPU) { + const auto* guard_impl = c10::impl::getDeviceGuardImpl(fake_device.type()); + if (guard_impl) { + fake_device = guard_impl->getDevice(); + } + if (fake_device.index() == -1) { + fake_device = c10::Device(fake_device.type(), 0); + } + } + set_fake_device(fake_device); +} + void TensorImpl::HandleResize() { // If needed, we will free the data. the next mutable_data() call // will create the data storage. @@ -272,6 +311,18 @@ bool TensorImpl::compute_non_overlapping_and_dense() const { sizes_and_strides_.strides_arrayref()); } +void TensorImpl::incref_pyobject() const noexcept { + pyobj_slot_.incref(); +} + +void TensorImpl::decref_pyobject() const noexcept { + pyobj_slot_.decref(); +} + +bool TensorImpl::try_incref_pyobject() const noexcept { + return pyobj_slot_.try_incref(); +} + void TensorImpl::release_resources() { autograd_meta_.reset(); if (storage_) { @@ -375,6 +426,12 @@ c10::Device TensorImpl::device_custom() const { if (C10_UNLIKELY(python_custom_device_)) { return pyobj_slot_.load_pyobj_interpreter()->device(this); } + if (C10_UNLIKELY(extra_meta_ && extra_meta_->fake_device_.has_value())) { + if (c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Fake)) { + return device_default(); + } + return *extra_meta_->fake_device_; + } return device_default(); } @@ -988,30 +1045,6 @@ void TensorImpl::empty_tensor_restride_symint(MemoryFormat memory_format) { } } -void TensorImpl::incref_pyobject() const noexcept { - // Because intrusive_ptr incref uses relaxed memory order, we need to - // do an acquire fence to ensure that the kHasPyObject bit was - // observed before the load of the PyObject* below. - // NB: This is a no-op on x86/x86-64 - std::atomic_thread_fence(std::memory_order_acquire); - - PyObject* obj = pyobj_slot_.load_pyobj(); - (*pyobj_slot_.pyobj_interpreter())->incref(obj); -} - -void TensorImpl::decref_pyobject() const noexcept { - PyObject* obj = pyobj_slot_.load_pyobj(); - (*pyobj_slot_.pyobj_interpreter())->decref(obj); -} - -bool TensorImpl::try_incref_pyobject() const noexcept { - c10::impl::PyInterpreter* interp = pyobj_slot_.pyobj_interpreter(); - if (C10_UNLIKELY(!interp)) { - return false; - } - return (*interp)->try_incref(pyobj_slot_); -} - namespace impl { namespace { diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index b680a9fd52c96..fedd9ecf5b9d6 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -237,12 +238,30 @@ struct C10_API BackendMeta : intrusive_ptr_target { } }; +// same as Python's FakeTensorMode +// storing shape env and converter from Python, we'll use these later +// to implement sym ints, real tensor conversion, etc +// this doesn't have caching because we're not implementing it +// no in_kernel_invocation_manager since that's handled by dispatch keys in C++ +struct C10_API FakeTensorMode { + std::shared_ptr shape_env_; + std::shared_ptr fake_tensor_converter_; + + FakeTensorMode( + std::shared_ptr shape_env, + std::shared_ptr converter) + : shape_env_(std::move(shape_env)), + fake_tensor_converter_(std::move(converter)) {} +}; + struct C10_API ExtraMeta { std::unique_ptr symbolic_shape_meta_ = nullptr; std::unique_ptr named_tensor_meta_ = nullptr; intrusive_ptr backend_meta_ = nullptr; std::optional custom_data_ptr_error_msg_ = std::nullopt; std::optional custom_storage_error_msg_ = std::nullopt; + std::optional fake_device_ = std::nullopt; + std::shared_ptr fake_tensor_mode_ = nullptr; ExtraMeta() = default; ~ExtraMeta() = default; @@ -263,6 +282,8 @@ struct C10_API ExtraMeta { if (other.custom_storage_error_msg_) { custom_storage_error_msg_ = other.custom_storage_error_msg_; } + fake_device_ = other.fake_device_; + fake_tensor_mode_ = other.fake_tensor_mode_; } ExtraMeta& operator=(const ExtraMeta& other) = delete; ExtraMeta(ExtraMeta&& other) = delete; @@ -1131,6 +1152,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return device_opt_.has_value() && device_opt_->type() == kMeta; } + bool is_fake() const { + return key_set_.has(DispatchKey::Fake); + } + bool is_cpu() const { // NB: This method is not virtual and avoid dispatches for performance // reasons. @@ -1429,6 +1454,34 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return key_set_.has_all(conjugate_ks); } + /** + * Transmute this meta tensor into a fake tensor + * The underlying device_opt_ stays as Meta for dispatch routing + * and fake device is stored in ExtraMeta and returned by device() + * via the device_policy_ mechanism + * also converting backend key from Meta to Fake and adding Fake key + * to DispatchKeySet + */ + + // this is the fast path: caller guarantees fake_device already has a valid + // index + void set_fake_device(c10::Device fake_device); + + // Normalizes the device index then calls set_fake_device. + // use when the device might lack an index ("cuda" vs "cuda:0"). + void set_and_normalize_fake_device(c10::Device fake_device); + + void set_fake_tensor_mode(std::shared_ptr mode) { + get_extra_meta().fake_tensor_mode_ = std::move(mode); + } + + std::shared_ptr fake_tensor_mode() const { + if (!extra_meta_) { + return nullptr; + } + return extra_meta_->fake_tensor_mode_; + } + /** * Set whether or not to take the conjugate of the tensor (flip the imaginary * bit). @@ -2185,9 +2238,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { } void incref_pyobject() const noexcept final; - void decref_pyobject() const noexcept final; - bool try_incref_pyobject() const noexcept final; private: diff --git a/c10/core/impl/DeviceGuardImplInterface.cpp b/c10/core/impl/DeviceGuardImplInterface.cpp index 428ea63c04151..6d69c31e3e1e6 100644 --- a/c10/core/impl/DeviceGuardImplInterface.cpp +++ b/c10/core/impl/DeviceGuardImplInterface.cpp @@ -21,11 +21,6 @@ DeviceGuardImplRegistrar::DeviceGuardImplRegistrar( registerDeviceGuard(type, impl); } -namespace { -thread_local std::unique_ptr tls_fake_device_guard = - nullptr; -} // namespace - void ensureCUDADeviceGuardSet() { constexpr auto cuda_idx = static_cast(DeviceType::CUDA); @@ -38,8 +33,8 @@ void ensureCUDADeviceGuardSet() { // In following cases, we override CUDA guard interface with a no-op // device guard. When p->deviceCount() == 0, cuda build is enabled, but no // cuda devices available. - tls_fake_device_guard = std::make_unique>(); - device_guard_impl_registry[cuda_idx].store(tls_fake_device_guard.get()); + static FakeGuardImpl fake_cuda_guard; + device_guard_impl_registry[cuda_idx].store(&fake_cuda_guard); } } diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h index 01ef32163bb5c..ef7a628835d97 100644 --- a/c10/core/impl/DeviceGuardImplInterface.h +++ b/c10/core/impl/DeviceGuardImplInterface.h @@ -228,6 +228,13 @@ struct C10_API DeviceGuardImplInterface { TORCH_CHECK(false, "Backend doesn't support synchronizing streams."); } + /** + * Return true if this stream is currently recording work for graph capture. + */ + virtual bool isStreamCapturing(const Stream& /*stream*/) const { + TORCH_CHECK(false, "Backend doesn't support stream capture query."); + } + /** * Wait (by blocking the calling thread) until all the work previously * recorded on the event has completed running on the device. diff --git a/c10/core/impl/LocalDispatchKeySet.cpp b/c10/core/impl/LocalDispatchKeySet.cpp index 0b3d87fce2410..c8277f19acc7c 100644 --- a/c10/core/impl/LocalDispatchKeySet.cpp +++ b/c10/core/impl/LocalDispatchKeySet.cpp @@ -46,29 +46,21 @@ void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set) { // RAII API IncludeDispatchKeyGuard::IncludeDispatchKeyGuard(DispatchKeySet include) - : tls_(&raw_local_dispatch_key_set), include_(include - tls_->included()) { - if (!include_.empty()) { - tls_->set_included(tls_->included() | include_); - } + : tls_(&raw_local_dispatch_key_set), saved_state_(tls_->included()) { + tls_->set_included(saved_state_ | include); } IncludeDispatchKeyGuard::~IncludeDispatchKeyGuard() { - if (!include_.empty()) { - tls_->set_included(tls_->included() - include_); - } + tls_->set_included(saved_state_); } ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(DispatchKeySet exclude) - : tls_(&raw_local_dispatch_key_set), exclude_(exclude - tls_->excluded()) { - if (!exclude_.empty()) { - tls_->set_excluded(tls_->excluded() | exclude_); - } + : tls_(&raw_local_dispatch_key_set), saved_state_(tls_->excluded()) { + tls_->set_excluded(saved_state_ | exclude); } ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() { - if (!exclude_.empty()) { - tls_->set_excluded(tls_->excluded() - exclude_); - } + tls_->set_excluded(saved_state_); } // Non-RAII API diff --git a/c10/core/impl/LocalDispatchKeySet.h b/c10/core/impl/LocalDispatchKeySet.h index bba089bb2ad11..bd8fe811b2c06 100644 --- a/c10/core/impl/LocalDispatchKeySet.h +++ b/c10/core/impl/LocalDispatchKeySet.h @@ -94,7 +94,7 @@ class C10_API IncludeDispatchKeyGuard { // A little micro-optimization to save us from tls_get_addr call // on destruction PODLocalDispatchKeySet* tls_; - DispatchKeySet include_; + DispatchKeySet saved_state_; }; class C10_API ExcludeDispatchKeyGuard { @@ -112,7 +112,7 @@ class C10_API ExcludeDispatchKeyGuard { // A little micro-optimization to save us from tls_get_addr call // on destruction PODLocalDispatchKeySet* tls_; - DispatchKeySet exclude_; + DispatchKeySet saved_state_; }; struct C10_API ForceDispatchKeyGuard { diff --git a/c10/core/impl/PyObjectSlot.h b/c10/core/impl/PyObjectSlot.h index a0633401b3634..a576c1a17d2a4 100644 --- a/c10/core/impl/PyObjectSlot.h +++ b/c10/core/impl/PyObjectSlot.h @@ -49,6 +49,32 @@ struct C10_API PyObjectSlot { pyobj_interpreter_.store(nullptr, std::memory_order_relaxed); } + // Helper methods for incref/decref/try_incref of the stored PyObject. + // Used by intrusive_ptr_target subclasses (TensorImpl, StorageImpl, Node) + // to implement their virtual pyobject refcount overrides. + void incref() const noexcept { + // Because intrusive_ptr incref uses relaxed memory order, we need to + // do an acquire fence to ensure that the kHasPyObject bit was + // observed before the load of the PyObject* below. + // NB: This is a no-op on x86/x86-64 + std::atomic_thread_fence(std::memory_order_acquire); + PyObject* obj = load_pyobj(); + load_pyobj_interpreter()->incref(obj); + } + + void decref() const noexcept { + PyObject* obj = load_pyobj(); + load_pyobj_interpreter()->decref(obj); + } + + bool try_incref() const noexcept { + PyInterpreter* interp = pyobj_interpreter(); + if (C10_UNLIKELY(!interp)) { + return false; + } + return (*interp)->try_incref(*this); + } + private: // This is now always the global interpreter if the PyObject is set. // Maybe we can remove this field some day... diff --git a/c10/core/impl/VirtualGuardImpl.h b/c10/core/impl/VirtualGuardImpl.h index 6c86a71676e55..2444e4e03b3ee 100644 --- a/c10/core/impl/VirtualGuardImpl.h +++ b/c10/core/impl/VirtualGuardImpl.h @@ -89,6 +89,9 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface { void synchronizeStream(const Stream& stream) const override { impl_->synchronizeStream(stream); } + bool isStreamCapturing(const Stream& stream) const override { + return impl_->isStreamCapturing(stream); + } void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) const override { diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index 5414d838cd8c4..f6fb7c9d6d9a8 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -109,6 +109,12 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) { } else if (key == "per_process_memory_fraction") { i = parsePerProcessMemoryFraction(tokenizer, i); used_native_specific_option = true; + } else if (key == "pinned_free_catch_all") { + i = parsePinnedFreeCatchAll(tokenizer, i); + used_native_specific_option = true; + } else if (key == "throw_on_cudamalloc_oom") { + i = parseThrowOnCudaMallocOom(tokenizer, i); + used_native_specific_option = true; } else { const auto& keys = c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys(); @@ -192,6 +198,25 @@ size_t CUDAAllocatorConfig::parsePinnedReserveSegmentSize( return i; } +size_t CUDAAllocatorConfig::parsePinnedFreeCatchAll( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i) { + tokenizer.checkToken(++i, ":"); + m_pinned_free_catch_all = tokenizer.toBool(++i); + return i; +} + +size_t CUDAAllocatorConfig::parseThrowOnCudaMallocOom( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i) { + // Format: throw_on_cudamalloc_oom:true or throw_on_cudamalloc_oom:false + // When enabled, throws OOM error before calling cudaMalloc if the allocation + // would likely fail due to insufficient memory. + tokenizer.checkToken(++i, ":"); + m_throw_on_cudamalloc_oom = tokenizer.toBool(++i); + return i; +} + REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(CUDAAllocatorConfig) } // namespace c10::cuda::CUDACachingAllocator diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index cd9c9b86285c4..928384a874394 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -34,7 +34,8 @@ class C10_CUDA_API CUDAAllocatorConfig { static bool expandable_segments() { bool enabled = c10::CachingAllocator::AcceleratorAllocatorConfig:: use_expandable_segments(); -#if !defined(PYTORCH_C10_DRIVER_API_SUPPORTED) && !defined(USE_ROCM) +#if !defined(PYTORCH_C10_DRIVER_API_SUPPORTED) && \ + (!defined(USE_ROCM) || (ROCM_VERSION < 70000)) if (enabled) { TORCH_WARN_ONCE("expandable_segments not supported on this platform") } @@ -65,6 +66,13 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_per_process_memory_fraction; } + // When enabled, throws OOM error before calling cudaMalloc if the allocation + // would likely fail due to insufficient memory. This provides early failure + // with clear error messages instead of letting cudaMalloc fail. + static bool throw_on_cudamalloc_oom() { + return instance().m_throw_on_cudamalloc_oom; + } + /** Pinned memory allocator settings */ static bool pinned_use_cuda_host_register() { return instance().m_pinned_use_cuda_host_register; @@ -92,6 +100,10 @@ class C10_CUDA_API CUDAAllocatorConfig { return 128; } + static bool pinned_free_catch_all() { + return instance().m_pinned_free_catch_all; + } + C10_DEPRECATED_MESSAGE( "c10::cuda::CUDACachingAllocator::CUDAAllocatorConfig::roundup_power2_divisions() is deprecated. Please use c10::CachingAllocator::AcceleratorAllocatorConfig::roundup_power2_divisions() instead.") static size_t roundup_power2_divisions(size_t size) { @@ -157,7 +169,9 @@ class C10_CUDA_API CUDAAllocatorConfig { "graph_capture_record_stream_reuse", "pinned_reserve_segment_size_mb", "pinned_num_register_threads", - "per_process_memory_fraction"}; + "per_process_memory_fraction", + "pinned_free_catch_all", + "throw_on_cudamalloc_oom"}; return keys; } @@ -185,6 +199,12 @@ class C10_CUDA_API CUDAAllocatorConfig { double parsePerProcessMemoryFraction( const c10::CachingAllocator::ConfigTokenizer& tokenizer, size_t i); + size_t parsePinnedFreeCatchAll( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i); + size_t parseThrowOnCudaMallocOom( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i); std::atomic m_pinned_num_register_threads{1}; std::atomic m_pinned_reserve_segment_size_mb{0}; @@ -198,6 +218,10 @@ class C10_CUDA_API CUDAAllocatorConfig { std::atomic m_pinned_use_cuda_host_register{false}; std::atomic m_graph_capture_record_stream_reuse{false}; std::atomic m_per_process_memory_fraction{1.0}; + std::atomic m_pinned_free_catch_all{false}; + // When true, throw OOM error before calling cudaMalloc if allocation would + // fail + std::atomic m_throw_on_cudamalloc_oom{false}; }; // Keep this for backwards compatibility diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 6e94a5f694bf8..774c79102b8d6 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -139,7 +141,7 @@ namespace Native { */ // counter to track order for Mempool Registration -thread_local int32_t registration_counter_global = -1; +std::atomic registration_counter_global{-1}; static char SHAREABLE_HANDLE_VERSION = 2; enum ShareableHandleType : char { @@ -164,12 +166,12 @@ void decrease_stat_array( struct Block; struct PrivatePool; typedef bool (*Comparison)(const Block*, const Block*); -static bool BlockComparatorSize(const Block* a, const Block* b); +static bool BlockComparatorRegistrationCounter(const Block* a, const Block* b); static bool BlockComparatorAddress(const Block* a, const Block* b); struct BlockPool { BlockPool(bool small, PrivatePool* private_pool = nullptr) - : blocks(BlockComparatorSize), + : blocks(BlockComparatorRegistrationCounter), unmapped(BlockComparatorAddress), is_small(small), owner_PrivatePool(private_pool) {} @@ -232,14 +234,16 @@ struct Block { requested_size(0), pool(pool), ptr(ptr) { - registration_counter = ++registration_counter_global; + registration_counter = + registration_counter_global.fetch_add(1, std::memory_order_relaxed) + 1; } // constructor for search key + // Use the default value for registration_counter and not modify + // registration_counter_global, because the search key is just a + // dummy placeholder. Block(c10::DeviceIndex device, cudaStream_t stream, size_t size) - : device(device), stream(stream), size(size), requested_size(0) { - registration_counter = ++registration_counter_global; - } + : device(device), stream(stream), size(size), requested_size(0) {} size_t gc_count() { TORCH_INTERNAL_ASSERT(pool); @@ -459,9 +463,15 @@ struct ExpandableSegment { if (enable_ipc_handles) { if (CUDAAllocatorConfig::expandable_segments_handle_type() != Expandable_Segments_Handle_Type::FABRIC_HANDLE) { +#ifdef USE_ROCM + prop.requestedHandleType = hipMemHandleTypePosixFileDescriptor; +#else prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; +#endif } else { +#ifndef USE_ROCM prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_FABRIC; +#endif } } int flag = 0; @@ -679,11 +689,15 @@ struct ExpandableSegment { C10_CUDA_CHECK(hipMemImportFromShareableHandle( &handle, myfd_handle, hipMemHandleTypePosixFileDescriptor)); #else - C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemImportFromShareableHandle_( - &handle, - // NOLINTNEXTLINE(performance-no-int-to-ptr) - (void*)(uintptr_t)myfd, - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + C10_CUDA_DRIVER_CHECK_MSG( + DriverAPI::get()->cuMemImportFromShareableHandle_( + &handle, + // NOLINTNEXTLINE(performance-no-int-to-ptr) + (void*)(uintptr_t)myfd, + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), + " fabric_info: {", + get_nvml_fabric_info(device), + "}"); #endif LOG(INFO) << "use posix fd to import expandable segments."; close(static_cast(myfd)); @@ -702,11 +716,15 @@ struct ExpandableSegment { buf.read( reinterpret_cast(&fabric_handle), sizeof(CUmemFabricHandle)); CUmemGenericAllocationHandle handle = 0; - C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemImportFromShareableHandle_( - &handle, - // NOLINTNEXTLINE(performance-no-int-to-ptr) - (void*)&fabric_handle, - CU_MEM_HANDLE_TYPE_FABRIC)); + C10_CUDA_DRIVER_CHECK_MSG( + DriverAPI::get()->cuMemImportFromShareableHandle_( + &handle, + // NOLINTNEXTLINE(performance-no-int-to-ptr) + (void*)&fabric_handle, + CU_MEM_HANDLE_TYPE_FABRIC), + " fabric_info: {", + get_nvml_fabric_info(device), + "}"); LOG(INFO) << "use fabric handle to import expandable segments."; segment->handles_.emplace_back(Handle{handle, std::nullopt}); } @@ -756,20 +774,32 @@ struct ExpandableSegment { private: void setAccess(c10::DeviceIndex device, size_t begin, size_t end) { - CUmemAccessDesc desc; - desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; +#if defined(USE_ROCM) && (ROCM_VERSION >= 70200) + constexpr int num_desc = 2; + CUmemAccessDesc desc[num_desc]; + desc[1].location.type = CU_MEM_LOCATION_TYPE_HOST; + desc[1].location.id = 0; // ignored + desc[1].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; +#else + constexpr int num_desc = 1; + CUmemAccessDesc desc[num_desc]; +#endif + desc[0].location.type = CU_MEM_LOCATION_TYPE_DEVICE; // NOLINTNEXTLINE(bugprone-signed-char-misuse) - desc.location.id = static_cast(device); - desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + desc[0].location.id = static_cast(device); + desc[0].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; #ifdef USE_ROCM C10_CUDA_CHECK(hipMemSetAccess( ptr() + begin * segment_size_, (end - begin) * segment_size_, - &desc, - 1)); + &desc[0], + num_desc)); #else C10_CUDA_DRIVER_CHECK(DriverAPI::get()->cuMemSetAccess_( - ptr_ + begin * segment_size_, (end - begin) * segment_size_, &desc, 1)); + ptr_ + begin * segment_size_, + (end - begin) * segment_size_, + &desc[0], + num_desc)); #endif } @@ -978,15 +1008,16 @@ struct RestoreResult { std::vector allocations_created; }; -bool BlockComparatorSize(const Block* a, const Block* b) { +bool BlockComparatorRegistrationCounter(const Block* a, const Block* b) { if (a->stream != b->stream) { return (uintptr_t)a->stream < (uintptr_t)b->stream; } if (a->size != b->size) { return a->size < b->size; } - return (uintptr_t)a->ptr < (uintptr_t)b->ptr; + return a->registration_counter < b->registration_counter; } + bool BlockComparatorAddress(const Block* a, const Block* b) { if (a->stream != b->stream) { return (uintptr_t)a->stream < (uintptr_t)b->stream; @@ -994,6 +1025,14 @@ bool BlockComparatorAddress(const Block* a, const Block* b) { return (uintptr_t)a->ptr < (uintptr_t)b->ptr; } +// Info about OOM rejection, used to defer observer callbacks outside of lock +struct OomRejectionInfo { + bool rejected{false}; + size_t alloc_size{0}; + size_t total_allocated{0}; + size_t device_total{0}; +}; + struct AllocParams { AllocParams( c10::DeviceIndex device, @@ -1024,6 +1063,7 @@ struct AllocParams { Block* block{nullptr}; StatTypes stat_types = {false}; cudaError_t err{cudaSuccess}; + OomRejectionInfo oom_rejection_info; }; // Note: cudaEventCreate when concurrently invoked from multiple threads can be @@ -1387,6 +1427,7 @@ class DeviceCachingAllocator { // XXX - maybe we should generalize and have multiple events std::vector oom_observers_; + std::vector oom_rejection_observers_; std::vector trace_trackers_; @@ -1495,6 +1536,10 @@ class DeviceCachingAllocator { oom_observers_.emplace_back(std::move(observer)); } + void attachOomRejectionObserver(OomRejectionObserver observer) { + oom_rejection_observers_.emplace_back(std::move(observer)); + } + void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) { std::unique_lock lock(mutex); trace_trackers_.emplace_back(std::move(tracker)); @@ -1575,19 +1620,28 @@ class DeviceCachingAllocator { // cudaMalloc. So far this function has not modified allocator state, but // keep in mind that any observed allocator state may change across calls // to alloc_block since it may release the lock. - block_found = alloc_block(params, false, context, lock) - // Try to use memory pools that have opted in as overflow before - // expensive memory freeing operations. - || try_mempool_fallback( - params, size, stream, device_id, alloc_size, stats) - // Free enough available cached blocks to satisfy alloc and retry - // alloc. - || (release_available_cached_blocks(params, context) && - alloc_block(params, false, context, lock)) - // Free all non-split cached blocks and retry alloc. - || (C10_LIKELY(captures_underway.empty()) && - release_cached_blocks(context, {0, 0}) && - alloc_block(params, true, context, lock)); + block_found = alloc_block(params, false, context, lock); + + // If allocation was rejected by OOM policy, skip retry chain and fail + // immediately + if (!block_found && params.oom_rejection_info.rejected) { + // Skip retry chain - will be handled below in the !block_found path + } else if (!block_found) { + // Normal retry chain: try various strategies to free memory and retry + block_found = + // Try to use memory pools that have opted in as overflow before + // expensive memory freeing operations. + try_mempool_fallback( + params, size, stream, device_id, alloc_size, stats) + // Free enough available cached blocks to satisfy alloc and retry + // alloc. + || (release_available_cached_blocks(params, context) && + alloc_block(params, false, context, lock)) + // Free all non-split cached blocks and retry alloc. + || (C10_LIKELY(captures_underway.empty()) && + release_cached_blocks(context, {0, 0}) && + alloc_block(params, true, context, lock)); + } } if (!block_found) { @@ -1595,6 +1649,43 @@ class DeviceCachingAllocator { // alloc_block should have thrown an exception already. TORCH_INTERNAL_ASSERT(params.err == cudaErrorMemoryAllocation); + // Handle OOM rejection separately - return nullptr instead of throwing + // This allows callers to handle rejection gracefully without crashing + if (params.oom_rejection_info.rejected) { + if (!oom_rejection_observers_.empty()) { + auto observers_local = oom_rejection_observers_; + auto rejection_info = params.oom_rejection_info; + auto device = device_id; + + // Release lock before dispatching observers + lock.unlock(); + + // Dispatch observers asynchronously to minimize latency impact on + // rejection path + std::thread([observers_local = std::move(observers_local), + rejection_info, + device]() { + try { + for (const auto& observer : observers_local) { + observer( + device, + rejection_info.alloc_size, + rejection_info.total_allocated, + rejection_info.device_total); + } + } catch (const std::exception& e) { + LOG(ERROR) << "Exception in OOM rejection observer: " << e.what(); + } catch (...) { + LOG(ERROR) << "Unknown exception in OOM rejection observer"; + } + }).detach(); + } else { + lock.unlock(); + } + + return nullptr; + } + size_t device_free = 0; size_t device_total = 0; C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); @@ -2294,6 +2385,7 @@ class DeviceCachingAllocator { stats.num_sync_all_streams = 0; stats.num_device_alloc = 0; stats.num_device_free = 0; + stats.num_oom_rejections = 0; stats.oversize_allocations.reset_accumulated(); stats.oversize_segments.reset_accumulated(); } @@ -3452,8 +3544,37 @@ class DeviceCachingAllocator { total_allocated_memory + size > allowed_memory_maximum.value()) { p.err = cudaErrorMemoryAllocation; return false; - // Temporarily disable checkpointing & cudagraphs internally - } else if ( + } + + // When throw_on_cudamalloc_oom is enabled, check + // per_process_memory_fraction to reject allocations that would exceed the + // configured limit. This prevents fatal HSA/driver aborts by throwing + // OutOfMemoryError instead. + if (CUDAAllocatorConfig::throw_on_cudamalloc_oom()) { + size_t device_total = static_cast(device_prop.totalGlobalMem); + size_t max_allowed = static_cast( + CUDAAllocatorConfig::per_process_memory_fraction() * + static_cast(device_total)); + if (total_allocated_memory + size > max_allowed) { + stats.num_oom_rejections++; + p.oom_rejection_info = { + true, size, total_allocated_memory, device_total}; + C10_LOG_EVERY_MS(WARNING, 60000) + << "Preemptively rejecting allocation: requested " + << format_size(size) << " but total_allocated_memory (" + << format_size(total_allocated_memory) + << ") + requested size would exceed memory limit (" + << format_size(max_allowed) << " = " + << CUDAAllocatorConfig::per_process_memory_fraction() * 100 + << "% of " << format_size(device_total) + << "). Set throw_on_cudamalloc_oom:false to disable."; + p.err = cudaErrorMemoryAllocation; + return false; + } + } + + if ( + // Temporarily disable checkpointing & cudagraphs internally p.is_expandable_segments_active && !(in_fbcode && p.pool->owner_PrivatePool)) { p.block = try_allocate_expandable_block( @@ -3995,7 +4116,7 @@ class DeviceCachingAllocator { if (record_history) { // Skip if action is in the skip_actions set - bool should_skip = skip_actions_list.count(action) > 0; + bool should_skip = skip_actions_list.contains(action); if (!should_skip) { alloc_buffer.insertEntries(te); } @@ -4125,6 +4246,18 @@ class NativeCachingAllocator : public CUDAAllocator { device, ": did you call init?"); Block* block = device_allocator[device]->malloc(size, stream); + // block can be nullptr if allocation was rejected by + // throw_on_cudamalloc_oom + // + per_process_memory_fraction policy. Throw OutOfMemoryError so callers + // with OOM error handlers can catch it gracefully. + if (C10_UNLIKELY(!block)) { + TORCH_CHECK_WITH( + OutOfMemoryError, + false, + "CUDA out of memory. Allocation was preemptively rejected because " + "it would exceed per_process_memory_fraction limit. Requested size: ", + format_size(size)); + } add_allocated_block(block); *devPtr = block->ptr; const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); @@ -4276,6 +4409,12 @@ class NativeCachingAllocator : public CUDAAllocator { } } + void attachOomRejectionObserver(OomRejectionObserver observer) override { + for (auto& allocator : device_allocator) { + allocator->attachOomRejectionObserver(observer); + } + } + void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) override { for (auto& allocator : device_allocator) { allocator->attachAllocatorTraceTracker(tracker); diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 9cd6b4ee128b6..c2d502976dda6 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -95,14 +95,26 @@ using OutOfMemoryObserver = std::function; +// Observer called when an allocation is preemptively rejected due to +// throw_on_cudamalloc_oom policy. Parameters: +// - device: GPU device index +// - alloc_size: size of the rejected allocation request +// - total_allocated: total memory allocated before the request +// - device_total: total GPU memory +using OomRejectionObserver = std::function; + struct ShareableHandle { ptrdiff_t offset; std::string handle; }; struct StreamSegmentSize { - StreamSegmentSize(cudaStream_t s, bool small, size_t sz) - : stream(s), is_small_pool(small), total_size(sz) {} + StreamSegmentSize(cudaStream_t s, bool small_, size_t sz) + : stream(s), is_small_pool(small_), total_size(sz) {} cudaStream_t stream; bool is_small_pool; size_t total_size; @@ -216,6 +228,7 @@ class CUDAAllocator : public DeviceAllocator { return ""; } virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0; + virtual void attachOomRejectionObserver(OomRejectionObserver observer) = 0; // Attached AllocatorTraceTracker callbacks will be called while the // per-device allocator lock is held. Any additional locks taken from within @@ -421,6 +434,10 @@ inline void attachOutOfMemoryObserver(OutOfMemoryObserver observer) { get()->attachOutOfMemoryObserver(std::move(observer)); } +inline void attachOomRejectionObserver(OomRejectionObserver observer) { + get()->attachOomRejectionObserver(std::move(observer)); +} + inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) { get()->attachAllocatorTraceTracker(std::move(tracker)); } diff --git a/c10/cuda/CUDADeviceAssertionHost.cpp b/c10/cuda/CUDADeviceAssertionHost.cpp index df2796368bd32..1bfa150326feb 100644 --- a/c10/cuda/CUDADeviceAssertionHost.cpp +++ b/c10/cuda/CUDADeviceAssertionHost.cpp @@ -188,7 +188,7 @@ std::string c10_retrieve_device_side_assertion_info() { } return oss.str(); #else - return "Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.\n"; + return ""; #endif } diff --git a/c10/cuda/CUDADeviceAssertionHost.h b/c10/cuda/CUDADeviceAssertionHost.h index 7907376cd019e..1338959ccf419 100644 --- a/c10/cuda/CUDADeviceAssertionHost.h +++ b/c10/cuda/CUDADeviceAssertionHost.h @@ -9,7 +9,7 @@ #include #include -#if defined(USE_CUDA) || defined(USE_ROCM) +#if (defined(USE_CUDA) || defined(USE_ROCM)) && !defined(TORCH_USE_CUDA_DSA) #define TORCH_USE_CUDA_DSA #endif diff --git a/c10/cuda/CUDAException.cpp b/c10/cuda/CUDAException.cpp index f0dbe49d2ea6c..c5451d0e1ac9c 100644 --- a/c10/cuda/CUDAException.cpp +++ b/c10/cuda/CUDAException.cpp @@ -30,7 +30,7 @@ void c10_cuda_check_implementation( const char* error_string = cudaGetErrorString(cuda_error); check_message.append(error_string); check_message.append(c10::cuda::get_cuda_error_help(cuda_error)); - check_message.append(c10::cuda::get_cuda_check_suffix()); + check_message.append(c10::cuda::get_cuda_async_error_suffix(cuda_error)); check_message.append("\n"); if (include_device_assertions) { check_message.append(c10_retrieve_device_side_assertion_info()); diff --git a/c10/cuda/CUDAGraphsC10Utils.h b/c10/cuda/CUDAGraphsC10Utils.h index 936875fd71d5c..3709aa64c7845 100644 --- a/c10/cuda/CUDAGraphsC10Utils.h +++ b/c10/cuda/CUDAGraphsC10Utils.h @@ -1,8 +1,10 @@ #pragma once +#include #include + #include -#include +#include // CUDA Graphs utils used by c10 and aten. // aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only. @@ -67,10 +69,41 @@ inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { // Use this version where you're sure a CUDA context exists already. inline CaptureStatus currentStreamCaptureStatusMayInitCtx() { - cudaStreamCaptureStatus is_capturing{cudaStreamCaptureStatusNone}; + cudaStreamCaptureStatus status{cudaStreamCaptureStatusNone}; C10_CUDA_CHECK( - cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing)); - return CaptureStatus(is_capturing); + cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &status)); + return CaptureStatus(status); +} + +inline CaptureStatus captureStatusMayInitCtx(cudaStream_t stream) { + cudaStreamCaptureStatus status{cudaStreamCaptureStatusNone}; + C10_CUDA_CHECK(cudaStreamIsCapturing(stream, &status)); + return CaptureStatus(status); +} + +inline bool isStreamCapturingMayInitCtx(cudaStream_t stream) { + return captureStatusMayInitCtx(stream) == CaptureStatus::Active; +} + +inline std::optional currentStreamCaptureIdMayInitCtx() { + cudaStreamCaptureStatus status{}; + CaptureId_t capture_id = 0; + C10_CUDA_CHECK(cudaStreamGetCaptureInfo( + c10::cuda::getCurrentCUDAStream(), &status, &capture_id)); + if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) { + return capture_id; + } + return std::nullopt; +} + +inline std::optional captureIdMayInitCtx(cudaStream_t stream) { + cudaStreamCaptureStatus status{}; + CaptureId_t capture_id = 0; + C10_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &capture_id)); + if (status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) { + return capture_id; + } + return std::nullopt; } } // namespace c10::cuda diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index cc8bec049fc3f..d634de1d12718 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -668,6 +668,13 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { "If you need it, please file an issue describing your use case."); } + void attachOomRejectionObserver(OomRejectionObserver observer) override { + TORCH_CHECK( + false, + "cudaMallocAsync does not yet support attachOomRejectionObserver. " + "If you need it, please file an issue describing your use case."); + } + void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) override { TORCH_CHECK( false, @@ -790,8 +797,6 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { "(For backend:native, snapshot returns a detailed summary of all " "blocks tracked by the allocator, but the cudaMallocAsync backend " "does not track individual blocks.)"); - // Alternative: TORCH_WARN - return {}; } // CUDAGraph interactions diff --git a/c10/cuda/CUDAMiscFunctions.cpp b/c10/cuda/CUDAMiscFunctions.cpp index 70bb0f841b35c..ffb75aa46e4d8 100644 --- a/c10/cuda/CUDAMiscFunctions.cpp +++ b/c10/cuda/CUDAMiscFunctions.cpp @@ -42,6 +42,33 @@ const char* get_cuda_check_suffix() noexcept { "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1"; } } +// NOLINTNEXTLINE(bugprone-exception-escape,-warnings-as-errors) +const char* get_cuda_async_error_suffix(cudaError_t error) noexcept { + switch (error) { + case cudaErrorLaunchFailure: + case cudaErrorIllegalAddress: + case cudaErrorAssert: +#ifndef USE_ROCM + case cudaErrorIllegalInstruction: + case cudaErrorMisalignedAddress: +#endif + { + static auto device_blocking_flag = + c10::utils::check_env("CUDA_LAUNCH_BLOCKING"); + static bool blocking_enabled = device_blocking_flag.value_or(false); + if (!blocking_enabled) { + return "\nCUDA kernel errors might be asynchronously reported at some" + " other API call, so the stacktrace below might be incorrect." + "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1"; + } + return ""; + } + default: + return "\nFor more detailed error information, run with" + " CUDA_LOG_FILE=stderr"; + } +} + std::mutex* getFreeMutex() { static std::mutex cuda_free_mutex; return &cuda_free_mutex; diff --git a/c10/cuda/CUDAMiscFunctions.h b/c10/cuda/CUDAMiscFunctions.h index bdb2f9998ecd1..32d5647fa23ff 100644 --- a/c10/cuda/CUDAMiscFunctions.h +++ b/c10/cuda/CUDAMiscFunctions.h @@ -11,5 +11,7 @@ namespace c10::cuda { C10_CUDA_API std::string get_cuda_error_help(cudaError_t /*error*/) noexcept; C10_CUDA_API const char* get_cuda_check_suffix() noexcept; +C10_CUDA_API const char* get_cuda_async_error_suffix( + cudaError_t error) noexcept; C10_CUDA_API std::mutex* getFreeMutex(); } // namespace c10::cuda diff --git a/c10/cuda/CUDAStream.h b/c10/cuda/CUDAStream.h index f27c7c9176631..183232c4d9191 100644 --- a/c10/cuda/CUDAStream.h +++ b/c10/cuda/CUDAStream.h @@ -116,6 +116,13 @@ class C10_CUDA_API CUDAStream { void synchronize() const; + bool is_capturing() const { + DeviceGuard guard{stream_.device()}; + cudaStreamCaptureStatus status{cudaStreamCaptureStatusNone}; + C10_CUDA_CHECK(cudaStreamIsCapturing(stream(), &status)); + return status != cudaStreamCaptureStatusNone; + } + int priority() const { DeviceGuard guard{stream_.device()}; int priority = 0; diff --git a/c10/cuda/PeerToPeerAccess.cpp b/c10/cuda/PeerToPeerAccess.cpp index 526d794bc49d5..cc312bdfaa4cb 100644 --- a/c10/cuda/PeerToPeerAccess.cpp +++ b/c10/cuda/PeerToPeerAccess.cpp @@ -10,12 +10,15 @@ #include #include +#include +#include #include namespace c10::cuda { static std::vector p2pAccessEnabled_; static std::vector fabricAccessEnabled_; +static std::vector fabricCliqueId_; static int64_t num_devices_ = -1; namespace detail { @@ -35,6 +38,8 @@ void init_p2p_access_cache(int64_t num_devices) { } fabricAccessEnabled_.clear(); fabricAccessEnabled_.resize(num_devices, -1); + fabricCliqueId_.clear(); + fabricCliqueId_.resize(num_devices, kCliqueIdNotQueried); } } // namespace detail @@ -180,6 +185,7 @@ bool get_fabric_access(c10::DeviceIndex dev) { fabricInfo.version = nvmlGpuFabricInfo_v2; if (DriverAPI::get()->nvmlDeviceGetGpuFabricInfoV_ == nullptr) { cache = 0; + fabricCliqueId_[dev] = kCliqueIdUnsupported; return false; } TORCH_CHECK( @@ -188,13 +194,17 @@ bool get_fabric_access(c10::DeviceIndex dev) { nvml_device, &fabricInfo)); auto state = fabricInfo.state != NVML_GPU_FABRIC_STATE_NOT_SUPPORTED; if (state) { + fabricCliqueId_[dev] = static_cast(fabricInfo.cliqueId); // now perform the full cycle of allocating - exporting - importing memory state = isFabricSupported(); + } else { + fabricCliqueId_[dev] = kCliqueIdUnsupported; } cache = state ? 1 : 0; return cache; } else { cache = 0; + fabricCliqueId_[dev] = kCliqueIdUnsupported; return false; } #else @@ -203,4 +213,85 @@ bool get_fabric_access(c10::DeviceIndex dev) { #endif } +int get_fabric_clique_id(c10::DeviceIndex dev) { +#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12040 && \ + defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + // Ensure cache is populated via get_fabric_access (which does the NVML query + // and stashes clique_id as a side effect). + get_fabric_access(dev); + return fabricCliqueId_[dev]; +#else + (void)dev; + return kCliqueIdUnsupported; +#endif +} + +std::string get_nvml_fabric_info([[maybe_unused]] c10::DeviceIndex dev) { +#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12040 && \ + defined(PYTORCH_C10_DRIVER_API_SUPPORTED) + if (DriverAPI::get()->nvmlDeviceGetGpuFabricInfoV_ == nullptr) { + return "fabric info unsupported (nvmlDeviceGetGpuFabricInfoV not available)"; + } + + auto nvml_device = get_nvml_device(dev); + if (nvml_device == nullptr) { + return "fabric info unknown (failed to get NVML device handle)"; + } + + nvmlGpuFabricInfoV_t info{}; +#ifdef nvmlGpuFabricInfo_v3 + bool has_health_summary = false; + info.version = nvmlGpuFabricInfo_v3; + if (DriverAPI::get()->nvmlDeviceGetGpuFabricInfoV_(nvml_device, &info) == + NVML_SUCCESS) { + has_health_summary = true; + } else +#endif + { + info = {}; + info.version = nvmlGpuFabricInfo_v2; + if (DriverAPI::get()->nvmlDeviceGetGpuFabricInfoV_(nvml_device, &info) != + NVML_SUCCESS) { + return "fabric info unknown (nvmlDeviceGetGpuFabricInfoV failed)"; + } + } + + char uuid_hex[33]; + for (int i = 0; i < 16; ++i) { + snprintf(uuid_hex + i * 2, 3, "%02x", info.clusterUuid[i]); + } + + const char* state_str = "unknown"; + switch (info.state) { + case NVML_GPU_FABRIC_STATE_NOT_SUPPORTED: + state_str = "not_supported"; + break; + case NVML_GPU_FABRIC_STATE_NOT_STARTED: + state_str = "not_started"; + break; + case NVML_GPU_FABRIC_STATE_IN_PROGRESS: + state_str = "in_progress"; + break; + case NVML_GPU_FABRIC_STATE_COMPLETED: + state_str = "completed"; + break; + } + + std::ostringstream oss; + oss << "clique_id=" << info.cliqueId << ", cluster_uuid=" << uuid_hex + << ", state=" << state_str << ", status=" << info.status + << ", health_mask=0x" << std::hex << std::setfill('0') << std::setw(8) + << info.healthMask; +#ifdef nvmlGpuFabricInfo_v3 + if (has_health_summary) { + oss << ", health_summary=" << std::dec + << static_cast(info.healthSummary); + } +#endif + return oss.str(); +#else + return "fabric info unsupported (requires CUDA >= 12.4)"; +#endif +} + } // namespace c10::cuda diff --git a/c10/cuda/PeerToPeerAccess.h b/c10/cuda/PeerToPeerAccess.h index d526843974264..8615f1f335d96 100644 --- a/c10/cuda/PeerToPeerAccess.h +++ b/c10/cuda/PeerToPeerAccess.h @@ -5,6 +5,7 @@ #include #include +#include namespace c10::cuda { @@ -32,4 +33,17 @@ C10_CUDA_API bool get_p2p_access( /// @return true if fabric access is available, false otherwise. C10_CUDA_API bool get_fabric_access(c10::DeviceIndex device); +constexpr int kCliqueIdNotQueried = -2; +constexpr int kCliqueIdUnsupported = -1; + +/// Query the NVLink fabric clique ID for a device. +/// Returns the clique ID (>= 0) if fabric is supported, or kCliqueIdUnsupported +/// if unsupported. +C10_CUDA_API int get_fabric_clique_id(c10::DeviceIndex device); + +/// Returns a formatted string with NVML fabric info (clique_id, cluster_uuid, +/// state, status, health_mask) for the given device. Intended for error +/// diagnostics — only call on failure paths. +C10_CUDA_API std::string get_nvml_fabric_info(c10::DeviceIndex device); + } // namespace c10::cuda diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 970a3f4584117..681ebf49a1452 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -20,6 +20,23 @@ } \ } while (0) +// clang-format off +#define C10_CUDA_DRIVER_CHECK_MSG(EXPR, ...) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + const char* err_str; \ + CUresult get_error_str_err [[maybe_unused]] = \ + c10::cuda::DriverAPI::get()->cuGetErrorString_(__err, &err_str); \ + if (get_error_str_err != CUDA_SUCCESS) { \ + TORCH_CHECK(false, "CUDA driver error: unknown error", __VA_ARGS__);\ + } else { \ + TORCH_CHECK(false, "CUDA driver error: ", err_str, __VA_ARGS__); \ + } \ + } \ + } while (0) +// clang-format on + #define C10_CUDA_DRIVER_CHECK_GOTO(EXPR, NEXT) \ do { \ CUresult __err = EXPR; \ diff --git a/c10/cuda/impl/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h index 14a398ee4dc37..c9617271cd498 100644 --- a/c10/cuda/impl/CUDAGuardImpl.h +++ b/c10/cuda/impl/CUDAGuardImpl.h @@ -213,6 +213,11 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { cuda_stream.synchronize(); } + bool isStreamCapturing(const Stream& stream) const override { + CUDAStream cuda_stream{stream}; + return cuda_stream.is_capturing(); + } + void synchronizeEvent(void* event) const override { if (!event) return; diff --git a/c10/metal/common.h b/c10/metal/common.h index 87e6400c92252..1d3b681c12a70 100644 --- a/c10/metal/common.h +++ b/c10/metal/common.h @@ -38,6 +38,20 @@ template using array = std::array; #endif +// Integer ceiling division: ceil(a / b). Usable from both host code and +// Metal shaders (where the overload is selected by ADL via `using namespace +// c10::metal;` in shader sources). +template +inline T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +// Round `a` up to the next multiple of `b`: ceil(a / b) * b. +template +inline T round_up(T a, T b) { + return ceil_div(a, b) * b; +} + enum class ScalarType { #define _DEFINE_ENUM_VAL_(_v, _n) _v = _n, C10_METAL_ALL_TYPES_FUNCTOR(_DEFINE_ENUM_VAL_) diff --git a/c10/metal/indexing.h b/c10/metal/indexing.h index 7d930a4665067..ffbec0a3d71d1 100644 --- a/c10/metal/indexing.h +++ b/c10/metal/indexing.h @@ -53,6 +53,25 @@ kernel void unary_dense( output[index] = f(input[index]); } +template +kernel void unary_dense_vec4( + device result_of* output [[buffer(0)]], + constant T* input [[buffer(1)]], + constant uint& numel [[buffer(2)]], + uint index [[thread_position_in_grid]]) { + F f; + uint base = index * 4; + if (base + 4 <= numel) { + using ::metal::vec; + vec val = *(constant vec*)(input + base); + *(device vec, 4>*)(output + base) = { + f(val.x), f(val.y), f(val.z), f(val.w)}; + } else { + for (uint i = base; i < numel; i++) + output[i] = f(input[i]); + } +} + template kernel void unary_strided( device result_of* output [[buffer(0)]], @@ -90,6 +109,18 @@ kernel void unary_strided( constant uint& ndim, \ uint index) +#define REGISTER_UNARY_VEC4_OP(NAME, DTYPE0, DTYPE1) \ + static_assert( \ + ::metal:: \ + is_same_v>, \ + "Output dtype mismatch for unary op " #NAME " and input " #DTYPE0); \ + template [[host_name(#NAME "_dense_vec4_" #DTYPE1 "_" #DTYPE0)]] \ + kernel void ::c10::metal::unary_dense_vec4( \ + device ::c10::metal::result_of * output, \ + constant DTYPE0 * input, \ + constant uint & numel, \ + uint index) + #define DEFINE_UNARY_FLOATING_FUNCTOR(NAME) \ struct NAME##_functor { \ template \ diff --git a/c10/metal/random.h b/c10/metal/random.h index c03d9b8a3149c..647b860b3f22c 100644 --- a/c10/metal/random.h +++ b/c10/metal/random.h @@ -33,7 +33,7 @@ uint4 single_round(uint4 ctr, uint2 key) { constexpr uint kPhiloxSB = 0xCD9E8D57; auto rc0 = mulhilo(kPhiloxSA, ctr.x); auto rc1 = mulhilo(kPhiloxSB, ctr.z); - return uint4(rc1.y ^ ctr.y ^ key.x, rc1.x, rc0.y ^ ctr.w ^ key.y, rc0.x); + return uint4(rc1.x ^ ctr.y ^ key.x, rc1.y, rc0.x ^ ctr.w ^ key.y, rc0.y); } uint4 multiple_rounds(uint4 ctr, uint2 key, uint rounds) { diff --git a/c10/metal/reduction_utils.h b/c10/metal/reduction_utils.h index 2d97820191663..0c3a81747d046 100644 --- a/c10/metal/reduction_utils.h +++ b/c10/metal/reduction_utils.h @@ -22,15 +22,33 @@ struct simd_type { } // namespace detail template -inline ::metal::enable_if_t, T> simd_sum(T val) { +inline ::metal:: + enable_if_t && !c10::metal::is_complex_v, T> + simd_sum(T val) { return T(::metal::simd_sum(detail::simd_type_t(val))); } +inline float2 simd_sum(float2 val) { + return float2(::metal::simd_sum(val.x), ::metal::simd_sum(val.y)); +} + template -inline ::metal::enable_if_t, T> simd_prod(T val) { +inline ::metal:: + enable_if_t && !c10::metal::is_complex_v, T> + simd_prod(T val) { return T(::metal::simd_product(detail::simd_type_t(val))); } +// Complex product reduction via shuffle, using c10::metal::mul for (a+bi)(c+di) +// Uses simd_shuffle_and_fill_down with identity (1+0i) for inactive lanes. +inline float2 simd_prod(float2 val) { + for (ushort i = simdgroup_size / 2; i > 0; i /= 2) { + val = c10::metal::mul( + val, ::metal::simd_shuffle_and_fill_down(val, float2(1, 0), i)); + } + return val; +} + // Extend simd_broadcast to 64-bit integral types using int2 trick template < typename T, diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h index 8d6131d867068..ead38cbfff1ec 100644 --- a/c10/metal/special_math.h +++ b/c10/metal/special_math.h @@ -723,6 +723,19 @@ inline T logaddexp2(T a, T b) { } } +template +inline float xlogy(T x, T y) { + if (::metal::isnan(y)) { + return NAN; + } + + if (x == 0) { + return x; + } + + return x * precise::log(float(y)); +} + template inline float xlog1py(T x, T y) { if (::metal::isnan(y)) { diff --git a/c10/metal/utils.h b/c10/metal/utils.h index 357348a26b20d..34f05acae6281 100644 --- a/c10/metal/utils.h +++ b/c10/metal/utils.h @@ -82,6 +82,11 @@ struct OpMathType { using type = float; }; +template <> +struct OpMathType { + using type = float2; +}; + // Type promotion structure for higher precision accumulation template struct AccumulationType { diff --git a/c10/test/CMakeLists.txt b/c10/test/CMakeLists.txt index 27385159976f4..9a576bb996f22 100644 --- a/c10/test/CMakeLists.txt +++ b/c10/test/CMakeLists.txt @@ -18,12 +18,12 @@ if(BUILD_TEST) endforeach() endif() -# ---[ C++17 header compilation test +# ---[ C++17/20 header compilation test if(BUILD_TEST) add_executable(c10_cpp17_header_build_test util/cpp17_header_build_check.cpp) target_link_libraries(c10_cpp17_header_build_test ${C10_LIB} gmock gtest gtest_main) set_target_properties(c10_cpp17_header_build_test PROPERTIES - CXX_STANDARD 17 + CXX_STANDARD 20 CXX_STANDARD_REQUIRED ON ) add_test(NAME c10_cpp17_header_build_test COMMAND $) diff --git a/c10/test/core/DeviceGuard_test.cpp b/c10/test/core/DeviceGuard_test.cpp index a6148adf8af25..6ae4fc1dd6774 100644 --- a/c10/test/core/DeviceGuard_test.cpp +++ b/c10/test/core/DeviceGuard_test.cpp @@ -1,8 +1,12 @@ #include #include +#include #include +#include +#include + using namespace c10; using namespace c10::impl; @@ -39,3 +43,75 @@ TEST(OptionalDeviceGuard, ResetDeviceDifferentDeviceType) { ASSERT_EQ(g.current_device(), Device(DeviceType::HIP, 2)); ASSERT_EQ(g.original_device(), Device(DeviceType::HIP, 0)); } + +// -- ensureCUDADeviceGuardSet ------------------------------------------- + +// Regression test: ensureCUDADeviceGuardSet() used to store a thread-local +// FakeGuardImpl* in the global device_guard_impl_registry. When the owning +// thread exited its TLS was freed, leaving a dangling pointer that the next +// thread to call deviceCount() would dereference (segfault). +// +// The fix is a function-local static, which has program lifetime. We verify +// that the pointer in the registry is still valid (and returns the expected +// deviceCount) after the threads that triggered guard installation have exited. +TEST(EnsureCUDADeviceGuard, NoUseAfterFreeWhenThreadsExit) { + // Simulate "CUDA compiled, no devices visible": a guard that is non-null but + // returns deviceCount() == 0, which is the condition that triggers fake guard + // installation in ensureCUDADeviceGuardSet(). + struct ZeroDeviceGuardImpl final : public DeviceGuardImplInterface { + DeviceType type() const override { + return DeviceType::CUDA; + } + Device exchangeDevice(Device d) const override { + return d; + } + Device getDevice() const override { + return Device(DeviceType::CUDA, 0); + } + void setDevice(Device) const override {} + void uncheckedSetDevice(Device) const noexcept override {} + Stream getStream(Device d) const noexcept override { + return Stream(Stream::UNSAFE, d, 0); + } + Stream exchangeStream(Stream s) const noexcept override { + return s; + } + DeviceIndex deviceCount() const noexcept override { + return 0; + } + void record(void**, const Stream&, const DeviceIndex, const EventFlag) + const override {} + void block(void*, const Stream&) const override {} + bool queryEvent(void*) const override { + return true; + } + void destroyEvent(void*, const DeviceIndex) const noexcept override {} + }; + + constexpr auto cuda_idx = static_cast(DeviceType::CUDA); + const auto* saved = device_guard_impl_registry[cuda_idx].load(); + + static ZeroDeviceGuardImpl zero_impl; + device_guard_impl_registry[cuda_idx].store(&zero_impl); + + // Phase 1: threads call ensureCUDADeviceGuardSet(), detect deviceCount()==0, + // and install a FakeGuardImpl in the global registry. + { + std::vector threads; + for (int i = 0; i < 4; i++) { + threads.emplace_back(ensureCUDADeviceGuardSet); + } + for (auto& t : threads) { + t.join(); + } + } + // The threads' TLS is now destroyed. With the old code the registry now + // holds a dangling pointer; with the fix it holds &fake_cuda_guard (static). + + // Phase 2: the pointer must still be valid and return the expected count. + const auto* p = device_guard_impl_registry[cuda_idx].load(); + ASSERT_NE(p, nullptr); + ASSERT_EQ(p->deviceCount(), kFakeGuardImplMaxDevices); + + device_guard_impl_registry[cuda_idx].store(saved); +} diff --git a/c10/test/core/DispatchKeySet_test.cpp b/c10/test/core/DispatchKeySet_test.cpp index cdbdc150167e0..9066d6100a3db 100644 --- a/c10/test/core/DispatchKeySet_test.cpp +++ b/c10/test/core/DispatchKeySet_test.cpp @@ -404,7 +404,7 @@ TEST(DispatchKeySet, TestBackendComponentToString) { auto k = static_cast(i); auto res = std::string(toString(k)); ASSERT_FALSE(res == "UNKNOWN_BACKEND_BIT"); - ASSERT_FALSE(seen_strings.count(res) > 0); + ASSERT_FALSE(seen_strings.contains(res)); seen_strings.insert(res); } } @@ -439,7 +439,7 @@ TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) { } else { ASSERT_TRUE(res.find("Unknown") == std::string::npos) << i; } - ASSERT_TRUE(seen_strings.count(res) == 0); + ASSERT_TRUE(!seen_strings.contains(res)); seen_strings.insert(res); } } diff --git a/c10/test/util/WaitCounter_test.cpp b/c10/test/util/WaitCounter_test.cpp index 1707aec75aa28..77b6b45c125b7 100644 --- a/c10/test/util/WaitCounter_test.cpp +++ b/c10/test/util/WaitCounter_test.cpp @@ -11,130 +11,134 @@ namespace { +// Struct to hold shared state +struct CounterState { + std::atomic startCount{0}; + std::atomic stopCount{0}; +}; + // Mock backend for testing WaitCounter functionality class MockWaitCounterBackend : public c10::monitor::detail::WaitCounterBackendIf { public: - explicit MockWaitCounterBackend( - std::atomic& startCount, - std::atomic& stopCount) - : startCount_(startCount), stopCount_(stopCount) {} + // Backend now holds a shared_ptr to the state + explicit MockWaitCounterBackend(std::shared_ptr state) + : state_(state) {} intptr_t start(std::chrono::steady_clock::time_point now) noexcept override { - startCount_.fetch_add(1); + state_->startCount.fetch_add(1); return reinterpret_cast(this); } void stop(std::chrono::steady_clock::time_point now, intptr_t ctx) noexcept override { - stopCount_.fetch_add(1); + state_->stopCount.fetch_add(1); EXPECT_EQ(ctx, reinterpret_cast(this)); } private: - std::atomic& startCount_; - std::atomic& stopCount_; + std::shared_ptr state_; }; class MockWaitCounterBackendFactory : public c10::monitor::detail::WaitCounterBackendFactoryIf { public: + // Factory accepts and stores a shared_ptr to the state MockWaitCounterBackendFactory( - std::atomic& startCount, - std::atomic& stopCount, + std::shared_ptr state, std::string_view keyFilter = "") - : startCount_(startCount), stopCount_(stopCount), keyFilter_(keyFilter) {} + : state_(state), keyFilter_(keyFilter) {} std::unique_ptr create( std::string_view key) noexcept override { if (!keyFilter_.empty() && key.find(keyFilter_) == std::string_view::npos) { return nullptr; } - return std::make_unique(startCount_, stopCount_); + // Pass the shared_ptr to the backend + return std::make_unique(state_); } private: - std::atomic& startCount_; - std::atomic& stopCount_; + std::shared_ptr state_; std::string keyFilter_; }; TEST(WaitCounter, BackendRegistration) { - auto backends = c10::monitor::detail::getRegisteredWaitCounterBackends(); - size_t initialCount = backends.size(); - - std::atomic startCount{0}; - std::atomic stopCount{0}; + auto state = std::make_shared(); + constexpr std::string_view key = "backend_registration"; c10::monitor::detail::registerWaitCounterBackend( - std::make_unique(startCount, stopCount)); + std::make_unique(state, key)); - backends = c10::monitor::detail::getRegisteredWaitCounterBackends(); - EXPECT_EQ(backends.size(), initialCount + 1); + c10::monitor::WaitCounterHandle handle(key); + { + auto guard = handle.start(); + EXPECT_EQ(state->startCount.load(), 1); + EXPECT_EQ(state->stopCount.load(), 0); + } + EXPECT_EQ(state->startCount.load(), 1); + EXPECT_EQ(state->stopCount.load(), 1); } TEST(WaitCounter, WaitGuardStartStop) { - std::atomic startCount{0}; - std::atomic stopCount{0}; + auto state = std::make_shared(); + constexpr std::string_view key = "wait_guard_start_stop"; c10::monitor::detail::registerWaitCounterBackend( - std::make_unique( - startCount, stopCount, "wait_guard_start_stop")); + std::make_unique(state, key)); - int startBefore = startCount.load(); - int stopBefore = stopCount.load(); + int startBefore = state->startCount.load(); + int stopBefore = state->stopCount.load(); - EXPECT_GE(startCount.load(), startBefore); + EXPECT_GE(state->startCount.load(), startBefore); { - c10::monitor::WaitCounterHandle handle("wait_guard_start_stop"); + c10::monitor::WaitCounterHandle handle(key); auto guard = handle.start(); - EXPECT_GE(startCount.load(), startBefore + 1); - EXPECT_EQ(stopCount.load(), stopBefore); + EXPECT_GE(state->startCount.load(), startBefore + 1); + EXPECT_EQ(state->stopCount.load(), stopBefore); } - EXPECT_GE(stopCount.load(), stopBefore + 1); + EXPECT_GE(state->stopCount.load(), stopBefore + 1); } TEST(WaitCounter, WaitGuardExplicitStop) { - std::atomic startCount{0}; - std::atomic stopCount{0}; + auto state = std::make_shared(); + constexpr std::string_view key = "wait_guard_explicit_stop"; c10::monitor::detail::registerWaitCounterBackend( - std::make_unique( - startCount, stopCount, "wait_guard_explicit_stop")); + std::make_unique(state, key)); - int startBefore = startCount.load(); - int stopBefore = stopCount.load(); + int startBefore = state->startCount.load(); + int stopBefore = state->stopCount.load(); - c10::monitor::WaitCounterHandle handle("wait_guard_explicit_stop"); + c10::monitor::WaitCounterHandle handle(key); auto guard = handle.start(); - EXPECT_GE(startCount.load(), startBefore + 1); - EXPECT_EQ(stopCount.load(), stopBefore); + EXPECT_GE(state->startCount.load(), startBefore + 1); + EXPECT_EQ(state->stopCount.load(), stopBefore); guard.stop(); - EXPECT_GE(stopCount.load(), stopBefore + 1); + EXPECT_GE(state->stopCount.load(), stopBefore + 1); // Calling stop() again should be a no-op (guard is already stopped) - int stopAfterFirst = stopCount.load(); + int stopAfterFirst = state->stopCount.load(); guard.stop(); - EXPECT_EQ(stopCount.load(), stopAfterFirst); + EXPECT_EQ(state->stopCount.load(), stopAfterFirst); } TEST(WaitCounter, WaitGuardMoveConstruction) { - std::atomic startCount{0}; - std::atomic stopCount{0}; + auto state = std::make_shared(); + constexpr std::string_view key = "wait_guard_move"; c10::monitor::detail::registerWaitCounterBackend( - std::make_unique( - startCount, stopCount, "wait_guard_move")); + std::make_unique(state, key)); - int startBefore = startCount.load(); - int stopBefore = stopCount.load(); + int startBefore = state->startCount.load(); + int stopBefore = state->stopCount.load(); { - c10::monitor::WaitCounterHandle handle("wait_guard_move"); + c10::monitor::WaitCounterHandle handle(key); auto guard1 = handle.start(); - EXPECT_GE(startCount.load(), startBefore + 1); + EXPECT_GE(state->startCount.load(), startBefore + 1); // Move the guard auto guard2 = std::move(guard1); @@ -142,116 +146,110 @@ TEST(WaitCounter, WaitGuardMoveConstruction) { } // Stop should be called exactly once when guard2 is destroyed - EXPECT_GE(stopCount.load(), stopBefore + 1); + EXPECT_GE(state->stopCount.load(), stopBefore + 1); } TEST(WaitCounter, StaticWaitCounterMacro) { - std::atomic startCount{0}; - std::atomic stopCount{0}; + auto state = std::make_shared(); + constexpr std::string_view key = "static_macro_test"; c10::monitor::detail::registerWaitCounterBackend( - std::make_unique( - startCount, stopCount, "static_macro_test")); + std::make_unique(state, key)); - int startBefore = startCount.load(); - int stopBefore = stopCount.load(); + int startBefore = state->startCount.load(); + int stopBefore = state->stopCount.load(); { auto guard = STATIC_WAIT_COUNTER(static_macro_test).start(); - EXPECT_GE(startCount.load(), startBefore + 1); + EXPECT_GE(state->startCount.load(), startBefore + 1); } - EXPECT_GE(stopCount.load(), stopBefore + 1); + EXPECT_GE(state->stopCount.load(), stopBefore + 1); } TEST(WaitCounter, StaticScopedWaitCounterMacro) { - std::atomic startCount{0}; - std::atomic stopCount{0}; + auto state = std::make_shared(); + constexpr std::string_view key = "static_scoped_test"; c10::monitor::detail::registerWaitCounterBackend( - std::make_unique( - startCount, stopCount, "static_scoped_test")); + std::make_unique(state, key)); - int startBefore = startCount.load(); - int stopBefore = stopCount.load(); + int startBefore = state->startCount.load(); + int stopBefore = state->stopCount.load(); { STATIC_SCOPED_WAIT_COUNTER(static_scoped_test); - EXPECT_GE(startCount.load(), startBefore + 1); + EXPECT_GE(state->startCount.load(), startBefore + 1); } - EXPECT_GE(stopCount.load(), stopBefore + 1); + EXPECT_GE(state->stopCount.load(), stopBefore + 1); } TEST(WaitCounter, WithWaitCounterMacroAssign) { - std::atomic startCount{0}; - std::atomic stopCount{0}; + auto state = std::make_shared(); + constexpr std::string_view key = "execute_with_test_assign"; c10::monitor::detail::registerWaitCounterBackend( - std::make_unique( - startCount, stopCount, "execute_with_test")); + std::make_unique(state, key)); - int startBefore = startCount.load(); - int stopBefore = stopCount.load(); + int startBefore = state->startCount.load(); + int stopBefore = state->stopCount.load(); int value = 0; - WITH_WAIT_COUNTER(execute_with_test, value = 42); + WITH_WAIT_COUNTER(execute_with_test_assign, value = 42); EXPECT_EQ(value, 42); - EXPECT_GE(startCount.load(), startBefore + 1); - EXPECT_GE(stopCount.load(), stopBefore + 1); + EXPECT_GE(state->startCount.load(), startBefore + 1); + EXPECT_GE(state->stopCount.load(), stopBefore + 1); } TEST(WaitCounter, WithWaitCounterMacroReturn) { - std::atomic startCount{0}; - std::atomic stopCount{0}; + auto state = std::make_shared(); + constexpr std::string_view key = "execute_with_test_return"; c10::monitor::detail::registerWaitCounterBackend( - std::make_unique( - startCount, stopCount, "execute_with_test")); + std::make_unique(state, key)); - int startBefore = startCount.load(); - int stopBefore = stopCount.load(); + int startBefore = state->startCount.load(); + int stopBefore = state->stopCount.load(); int value = 0; - value = WITH_WAIT_COUNTER(execute_with_test, []() { return 42; }()); + value = WITH_WAIT_COUNTER(execute_with_test_return, []() { return 42; }()); EXPECT_EQ(value, 42); - EXPECT_GE(startCount.load(), startBefore + 1); - EXPECT_GE(stopCount.load(), stopBefore + 1); + EXPECT_GE(state->startCount.load(), startBefore + 1); + EXPECT_GE(state->stopCount.load(), stopBefore + 1); } TEST(WaitCounter, SameHandleMultipleTimes) { - std::atomic startCount{0}; - std::atomic stopCount{0}; + auto state = std::make_shared(); + constexpr std::string_view key = "multiple_times_test"; c10::monitor::detail::registerWaitCounterBackend( - std::make_unique( - startCount, stopCount, "multiple_times_test")); + std::make_unique(state, key)); - int startBefore = startCount.load(); - int stopBefore = stopCount.load(); + int startBefore = state->startCount.load(); + int stopBefore = state->stopCount.load(); - c10::monitor::WaitCounterHandle handle("multiple_times_test"); + c10::monitor::WaitCounterHandle handle(key); for (int i = 0; i < 5; ++i) { auto guard = handle.start(); } - EXPECT_GE(startCount.load(), startBefore + 5); - EXPECT_GE(stopCount.load(), stopBefore + 5); + EXPECT_GE(state->startCount.load(), startBefore + 5); + EXPECT_GE(state->stopCount.load(), stopBefore + 5); } TEST(WaitCounter, ConcurrentUsage) { - std::atomic startCount{0}; - std::atomic stopCount{0}; + auto state = std::make_shared(); + constexpr std::string_view key = "concurrent_test"; c10::monitor::detail::registerWaitCounterBackend( - std::make_unique( - startCount, stopCount, "concurrent_test")); + std::make_unique(state, key)); - int startBefore = startCount.load(); - int stopBefore = stopCount.load(); + int startBefore = state->startCount.load(); + int stopBefore = state->stopCount.load(); constexpr int kNumThreads = 10; constexpr int kIterationsPerThread = 100; @@ -273,8 +271,10 @@ TEST(WaitCounter, ConcurrentUsage) { } EXPECT_GE( - startCount.load(), startBefore + kNumThreads * kIterationsPerThread); - EXPECT_GE(stopCount.load(), stopBefore + kNumThreads * kIterationsPerThread); + state->startCount.load(), + startBefore + kNumThreads * kIterationsPerThread); + EXPECT_GE( + state->stopCount.load(), stopBefore + kNumThreads * kIterationsPerThread); } TEST(WaitCounter, StaticHandlePerCallSite) { diff --git a/c10/util/ApproximateClock.h b/c10/util/ApproximateClock.h index 803ea404c9aa2..c1b65b33af5a7 100644 --- a/c10/util/ApproximateClock.h +++ b/c10/util/ApproximateClock.h @@ -29,6 +29,8 @@ #else #undef C10_RDTSC #endif +#elif defined(__aarch64__) && !defined(__CUDACC__) && !defined(__HIPCC__) +#define C10_ARMTSC #endif namespace c10 { @@ -70,6 +72,14 @@ inline time_t getTime(bool allow_monotonic = false) { #endif } +#if defined(C10_ARMTSC) +inline uint64_t getArmApproximateTime() { + uint64_t val; + __asm__ __volatile__("mrs %0, cntvct_el0" : "=r"(val)); + return val; +} +#endif + // We often do not need to capture true wall times. If a fast mechanism such // as TSC is available we can use that instead and convert back to epoch time // during post processing. This greatly reduce the clock's contribution to @@ -81,6 +91,8 @@ inline time_t getTime(bool allow_monotonic = false) { inline auto getApproximateTime() { #if defined(C10_RDTSC) return static_cast(__rdtsc()); +#elif defined(C10_ARMTSC) + return getArmApproximateTime(); #else return getTime(); #endif diff --git a/c10/util/Exception.cpp b/c10/util/Exception.cpp index 18d767e62d272..307e183d5c70e 100644 --- a/c10/util/Exception.cpp +++ b/c10/util/Exception.cpp @@ -3,10 +3,23 @@ #include #include +#include #include #include #include +// Google glog's api does not have an external function that allows one to check +// if glog is initialized or not. It does have an internal function - so we are +// declaring it here. This is a hack but has been used by a bunch of others too +// (e.g. Torch, common/init). See also Logging.cpp in this directory. +#ifdef C10_USE_GLOG +namespace google { +namespace glog_internal_namespace_ { +bool IsGoogleLoggingInitialized(); +} // namespace glog_internal_namespace_ +} // namespace google +#endif + namespace c10 { Error::Error(std::string msg, Backtrace backtrace, const void* caller) @@ -245,6 +258,18 @@ bool Warning::verbatim() const { } void WarningHandler::process(const Warning& warning) { +#ifdef C10_USE_GLOG + // During static initialization (before InitGoogleLogging), glog's global + // flags may not be constructed yet. Accessing them causes SIOF crashes + // (T253115013, D96553733). Fall back to stderr in that case. + if (!::google::glog_internal_namespace_::IsGoogleLoggingInitialized()) { + std::cerr << warning.source_location().file << ':' + << warning.source_location().line + << ": Warning: " << warning.msg() << " (function " + << warning.source_location().function << ')' << std::endl; + return; + } +#endif LOG_AT_FILE_LINE( WARNING, warning.source_location().file, warning.source_location().line) << "Warning: " << warning.msg() << " (function " diff --git a/c10/util/Exception.h b/c10/util/Exception.h index 19e242a0f0980..24d14d65563e5 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -619,9 +619,19 @@ namespace c10::detail { #define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...) \ while (false) \ C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__)) +// In release: TORCH_INTERNAL_ASSERT_DEBUG_ONLY is a no-op, so return +// __VA_ARGS__ as a fallback. In debug: crashes via +// TORCH_INTERNAL_ASSERT(false), so no return is emitted (avoids +// -Wunreachable-code-return). +#define TORCH_INTERNAL_ASSERT_FALSE_OR_RETURN(...) \ + do { \ + return __VA_ARGS__; \ + } while (0) #else #define TORCH_INTERNAL_ASSERT_DEBUG_ONLY(...) \ C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(__VA_ARGS__)) +#define TORCH_INTERNAL_ASSERT_FALSE_OR_RETURN(...) \ + C10_EXPAND_MSVC_WORKAROUND(TORCH_INTERNAL_ASSERT(false)) #endif // TODO: We're going to get a lot of similar looking string literals diff --git a/c10/util/Semaphore.h b/c10/util/Semaphore.h index 041d9abecf515..6d2235c09da5a 100644 --- a/c10/util/Semaphore.h +++ b/c10/util/Semaphore.h @@ -8,7 +8,13 @@ // note: __cpp_lib_semaphore will not be defined in some apple platforms // even if >= C++20. -#if __has_include() && defined(__cpp_lib_semaphore) && __cpp_lib_semaphore >= 201907L +// +// libstdc++'s __atomic_semaphore has a lost-wakeup bug: _M_release skips +// the futex notify when the counter is already positive, but a concurrent +// _S_do_try_acquire can fail its CAS, see zero, and block — missing the +// wakeup. https://gcc.gnu.org/bugzilla/show_bug.cgi?id=98033 +#if __has_include() && defined(__cpp_lib_semaphore) && \ + __cpp_lib_semaphore >= 201907L && !defined(__GLIBCXX__) #define C10_SEMAPHORE_USE_STL #endif diff --git a/c10/util/generic_math.h b/c10/util/generic_math.h index 9c2a424560484..e98be901c8c53 100644 --- a/c10/util/generic_math.h +++ b/c10/util/generic_math.h @@ -59,6 +59,10 @@ inline C10_HOST_DEVICE scalar_t div_floor_floating(scalar_t a, scalar_t b) template inline C10_HOST_DEVICE scalar_t div_floor_integer(scalar_t a, scalar_t b) { + if (C10_UNLIKELY(b == 0)) { + return scalar_t(0); + } + if (C10_UNLIKELY( std::is_signed::value && a == std::numeric_limits::min() && b == scalar_t(-1))) { diff --git a/c10/xpu/XPUCachingAllocator.cpp b/c10/xpu/XPUCachingAllocator.cpp index 7e534c7aa403b..f01e0c5c15224 100644 --- a/c10/xpu/XPUCachingAllocator.cpp +++ b/c10/xpu/XPUCachingAllocator.cpp @@ -1547,7 +1547,7 @@ class DeviceCachingAllocator { std::shared_ptr context) { if (!record_history) return; - bool should_skip = skip_actions_list.count(action) > 0; + bool should_skip = skip_actions_list.contains(action); if (should_skip) return; TraceEntry te( diff --git a/c10/xpu/XPUDeviceProp.h b/c10/xpu/XPUDeviceProp.h index a06fa8e5cdef5..d1dd85b29d05d 100644 --- a/c10/xpu/XPUDeviceProp.h +++ b/c10/xpu/XPUDeviceProp.h @@ -129,8 +129,14 @@ namespace c10::xpu { /* the device identifier of the Intel GPU, also known as the product ID. */ \ _(device_id, device_id, 0) \ \ - /* the device descriptor for device Universal Unique ID, 16 bytes*/ \ - _(uuid, device_info_uuid, (std::array{})) + /* the device descriptor for device Universal Unique ID, 16 bytes. */ \ + _(uuid, device_info_uuid, (std::array{})) \ + \ + /* the maximum clock rate of device's global memory in MHz. */ \ + _(memory_clock_rate, memory_clock_rate, 0) \ + \ + /* the maximum bus width between device and memory in bits. */ \ + _(memory_bus_width, memory_bus_width, 0) #define AT_FORALL_XPU_DEVICE_ASPECT(_) \ /* sycl::half is supported on device. */ \ @@ -180,20 +186,24 @@ namespace c10::xpu { _DEFINE_SYCL_PROP( \ sycl::ext::oneapi::experimental::info::device, property, property) -struct C10_XPU_API DeviceProp { - AT_FORALL_XPU_DEVICE_PROPERTIES(DEFINE_DEVICE_PROP); +struct C10_XPU_API DeviceProp{ + AT_FORALL_XPU_DEVICE_PROPERTIES(DEFINE_DEVICE_PROP) - // the platform name. - DEFINE_PLATFORM_PROP(name, platform_name); + // the platform name. + DEFINE_PLATFORM_PROP(name, platform_name) - AT_FORALL_XPU_EXT_DEVICE_PROPERTIES(DEFINE_EXT_DEVICE_PROP); + // ext properties. + AT_FORALL_XPU_EXT_DEVICE_PROPERTIES(DEFINE_EXT_DEVICE_PROP) - AT_FORALL_XPU_DEVICE_ASPECT(DEFINE_DEVICE_ASPECT); + // device aspects. + AT_FORALL_XPU_DEVICE_ASPECT(DEFINE_DEVICE_ASPECT) - AT_FORALL_XPU_EXP_CL_ASPECT(DEFINE_DEVICE_ASPECT); + // experimental device aspects. + AT_FORALL_XPU_EXP_CL_ASPECT(DEFINE_DEVICE_ASPECT) #if SYCL_COMPILER_VERSION >= 20250000 - AT_FORALL_XPU_EXP_DEVICE_PROPERTIES(DEFINE_EXP_DEVICE_PROP); + // experimental device properties. + AT_FORALL_XPU_EXP_DEVICE_PROPERTIES(DEFINE_EXP_DEVICE_PROP) #endif }; diff --git a/c10/xpu/XPUStream.h b/c10/xpu/XPUStream.h index c922759c2c48b..396002eb211b9 100644 --- a/c10/xpu/XPUStream.h +++ b/c10/xpu/XPUStream.h @@ -109,6 +109,11 @@ class C10_XPU_API XPUStream { } } + bool is_capturing() const { + return queue().ext_oneapi_get_state() == + sycl::ext::oneapi::experimental::queue_state::recording; + } + /// Return the priority that this stream is associated with. Lower numbers /// represent higher priority. int priority() const; diff --git a/c10/xpu/impl/XPUGuardImpl.h b/c10/xpu/impl/XPUGuardImpl.h index ad788ceb88141..d9a3d83ee4488 100644 --- a/c10/xpu/impl/XPUGuardImpl.h +++ b/c10/xpu/impl/XPUGuardImpl.h @@ -50,7 +50,11 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface { cap.capability_data.capability_bits = (1ULL << kIndex_Byte) | (1ULL << kIndex_Char) | (1ULL << kIndex_Short) | (1ULL << kIndex_Int) | (1ULL << kIndex_Long) | (1ULL << kIndex_Float) | - (1ULL << kIndex_ComplexFloat) | (1ULL << kIndex_Bool); + (1ULL << kIndex_ComplexFloat) | (1ULL << kIndex_Bool) | + (1ULL << kIndex_Float8_e5m2) | (1ULL << kIndex_Float8_e4m3fn) | + (1ULL << kIndex_Float8_e5m2fnuz) | (1ULL << kIndex_Float8_e4m3fnuz) | + (1ULL << kIndex_Float8_e8m0fnu) | (1ULL << kIndex_UInt16) | + (1ULL << kIndex_UInt32) | (1ULL << kIndex_UInt64); // BFloat16 may be emulated. We always assume BFloat16 is available; // users can call is_bf16_supported() to check for native hardware support. cap.capability_data.capability_bits |= (1ULL << kIndex_BFloat16); @@ -214,6 +218,11 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface { xpu_stream.synchronize(); } + bool isStreamCapturing(const Stream& stream) const override { + const XPUStream xpu_stream{stream}; + return xpu_stream.is_capturing(); + } + void synchronizeEvent(void* event) const override { if (!event) return; diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index ea96059cedffd..92f85a0379476 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -53,6 +53,9 @@ if(INTERN_BUILD_ATEN_OPS) # Add source, includes, and libs to lists list(APPEND Caffe2_CPU_SRCS ${ATen_CPU_SRCS}) + if(USE_MTIA) + list(APPEND Caffe2_CPU_SRCS ${ATen_MTIA_SRCS}) + endif() list(APPEND Caffe2_GPU_SRCS ${ATen_CUDA_CPP_SRCS}) list(APPEND Caffe2_XPU_SRCS ${ATen_XPU_SRCS}) list(APPEND Caffe2_XPU_INCLUDE ${ATen_XPU_INCLUDE}) @@ -556,13 +559,6 @@ if(USE_CUDA OR USE_ROCM) append_filelist("libtorch_cuda_core_sources" Caffe2_GPU_HIP_JIT_FUSERS_SRCS) endif() -# NativeRT is disabled -# if(USE_CUDA) -# append_filelist("libtorch_nativert_cuda_sources" Caffe2_GPU_SRCS) -# endif() -# if(USE_ROCM) -# append_filelist("libtorch_nativert_cuda_sources" Caffe2_HIP_SRCS) -# endif() if(USE_CUDA) list(APPEND Caffe2_GPU_CU_SRCS ${Caffe2_GPU_HIP_JIT_FUSERS_SRCS}) @@ -594,6 +590,7 @@ if(USE_CUDA) ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/nccl_extension.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/ops/nccl_reduce_scatter_offset.cu ${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" ) @@ -604,7 +601,7 @@ if(USE_CUDA) if(CMAKE_COMPILER_IS_GNUCXX) set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-Wno-unused-but-set-variable") endif() - if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0 AND CUDA_NVCC_FLAGS MATCHES ".*compute_90.*") + if(CUDA_NVCC_FLAGS MATCHES ".*compute_90.*") set_source_files_properties(${ASYNC_MM_FILE} PROPERTIES COMPILE_FLAGS "-gencode arch=compute_90a,code=sm_90a") endif() endif() @@ -845,6 +842,10 @@ if(USE_MPS) if(CAN_COMPILE_METAL) add_dependencies(torch_cpu metallibs) target_link_options(torch_cpu PRIVATE -Wl,-sectcreate,__TEXT,metal_basic,${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/kernels_basic.metallib) + if(CAN_COMPILE_METAL_40) + target_link_options(torch_cpu PRIVATE -Wl,-sectcreate,__TEXT,metal_40,${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/kernels_40.metallib) + target_compile_definitions(torch_cpu PRIVATE CAN_BUILD_METAL_4) + endif() else() target_compile_definitions(torch_cpu PRIVATE PYTORCH_JIT_COMPILE_SHADERS) endif() @@ -954,6 +955,12 @@ if(USE_ROCM) target_compile_definitions(torch_hip PRIVATE USE_NCCL) endif() + # IntraNodeComm loads libamd_smi at runtime via dlopen (avoids linking amd_smi + # which would cause bus errors if Python amdsmi loads a different copy). + if(USE_DISTRIBUTED) + target_link_libraries(torch_hip PRIVATE ${CMAKE_DL_LIBS}) + endif() + if(USE_PRECOMPILED_HEADERS) target_precompile_headers(torch_hip PRIVATE "$<$:ATen/core/ATen_pch.h>") @@ -1363,8 +1370,6 @@ if(BUILD_TEST) else() add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) add_subdirectory(${TORCH_ROOT}/test/cpp/lazy ${CMAKE_BINARY_DIR}/test_lazy) - # NativeRT is disabled - # add_subdirectory(${TORCH_ROOT}/test/cpp/nativert ${CMAKE_BINARY_DIR}/test_nativert) add_subdirectory(${TORCH_ROOT}/test/cpp/profiler ${CMAKE_BINARY_DIR}/test_profiler) add_subdirectory(${TORCH_ROOT}/test/inductor ${CMAKE_BINARY_DIR}/test_inductor) add_subdirectory(${TORCH_ROOT}/test/cpp/aoti_abi_check ${CMAKE_BINARY_DIR}/test_aoti_abi_check) @@ -1462,6 +1467,80 @@ if(USE_ROCM) if(USE_ROCM_CK_SDPA) target_compile_definitions(torch_hip PRIVATE USE_ROCM_CK_SDPA) endif() + + message(INFO "USE_NVSHMEM=${USE_NVSHMEM} USE_ROCM=${USE_ROCM} rocshmem_FOUND=${rocshmem_FOUND} rocshmem_VERSION=${rocshmem_VERSION}") + if(USE_NVSHMEM) + if(rocshmem_FOUND AND rocshmem_VERSION VERSION_GREATER_EQUAL "3.3.0") + # rocSHMEM upstream ships device code for a subset of AMD gfx targets only. + set(ROCSHMEM_SUPPORTED_ARCH gfx90a gfx942 gfx950 gfx1100 gfx1201) + set(_torch_rocshmem_build_arches) + foreach(_arch ${PYTORCH_ROCM_ARCH}) + list(FIND ROCSHMEM_SUPPORTED_ARCH "${_arch}" _rocsmem_supported_idx) + if(NOT _rocsmem_supported_idx EQUAL -1) + # rocSHMEM device archives are commonly built for gfx90a xnack variants, + # so map plain gfx90a to gfx90a:xnack-/gfx90a:xnack+ to match device + # symbols during HIP RDC link. + if(_arch STREQUAL "gfx90a") + list(APPEND _torch_rocshmem_build_arches "gfx90a:xnack-" "gfx90a:xnack+") + else() + list(APPEND _torch_rocshmem_build_arches "${_arch}") + endif() + endif() + endforeach() + if(_torch_rocshmem_build_arches) + message(STATUS "rocSHMEM found, building with rocSHMEM support for arches: ${_torch_rocshmem_build_arches}") + message(STATUS "ROCSHMEM_INCLUDE_DIR: '${ROCSHMEM_INCLUDE_DIR}'") + set(TORCH_ROCSHMEM_SRCS + "${TORCH_SRC_DIR}/csrc/distributed/c10d/cuda/utils.cpp" + "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/rocshmem_extension.cu" + "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cpp" + "${TORCH_SRC_DIR}/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp" + ) + set_source_files_properties(${TORCH_ROCSHMEM_SRCS} PROPERTIES LANGUAGE HIP) + + # FindHIP's generated compile script bakes in global HIP_CLANG_FLAGS. Override + # -fno-gpu-rdc with -fgpu-rdc so device linker can pull in rocSHMEM device APIs. + set(HIP_CLANG_FLAGS_SAVED ${HIP_CLANG_FLAGS}) + set(HIP_CLANG_FLAGS_FOR_ROCSHMEM ${HIP_CLANG_FLAGS}) + list(FILTER HIP_CLANG_FLAGS_FOR_ROCSHMEM EXCLUDE REGEX "^-fno-gpu-rdc$") + list(APPEND HIP_CLANG_FLAGS_FOR_ROCSHMEM -fgpu-rdc) + set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_FOR_ROCSHMEM}) + + hip_add_library(torch_rocshmem SHARED ${TORCH_ROCSHMEM_SRCS}) + + set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_SAVED}) + + set_target_properties(torch_rocshmem PROPERTIES LINKER_LANGUAGE HIP) + + + # torch_rocshmem is built with -fgpu-rdc, so its HIP link step performs + # device linking and must carry the full arch list explicitly. + set_target_properties( + torch_rocshmem PROPERTIES HIP_ARCHITECTURES "${_torch_rocshmem_build_arches}") + + + target_compile_definitions(torch_hip PUBLIC USE_NVSHMEM) + target_compile_definitions(torch_rocshmem PUBLIC USE_NVSHMEM USE_ROCM) + + target_link_libraries(torch_rocshmem PRIVATE roc::rocshmem) + target_link_libraries(torch_hip PRIVATE torch_rocshmem) + + install(TARGETS torch_rocshmem EXPORT Caffe2Targets DESTINATION + "${TORCH_INSTALL_LIB_DIR}") + else() + message(STATUS + "rocSHMEM found but skipped: PYTORCH_ROCM_ARCH has no entry in " + "ROCSHMEM_SUPPORTED_ARCH (${ROCSHMEM_SUPPORTED_ARCH}). " + "Building without rocSHMEM support.") + endif() + elseif(rocshmem_FOUND) + message(STATUS + "rocSHMEM found but skipped: requires rocSHMEM >= 3.3.0 " + "(found '${rocshmem_VERSION}'). Building without rocSHMEM support.") + else() + message(WARNING "rocSHMEM not found (USE_NVSHMEM=ON). Building without rocSHMEM support.") + endif() + endif() #USE_NVSHMEM endif() if(BUILD_LITE_INTERPRETER) @@ -1984,7 +2063,11 @@ if(BUILD_TEST) endif() if(USE_ROCM) + string(JOIN " " HIP_HIPCC_FLAGS_STR ${HIP_HIPCC_FLAGS}) + set(HIP_HIPCC_FLAGS ${HIP_HIPCC_FLAGS_STR}) + set(BASE_HIPCC_FLAGS ${HIP_HIPCC_FLAGS}) + foreach(test_src ${Caffe2_HIP_TEST_SRCS}) get_filename_component(test_name ${test_src} NAME_WE) if(test_src MATCHES "^.*\.hip$") diff --git a/caffe2/core/macros.h.in b/caffe2/core/macros.h.in index 8c7592b2d5315..ba0e20df271ed 100644 --- a/caffe2/core/macros.h.in +++ b/caffe2/core/macros.h.in @@ -68,4 +68,5 @@ {"USE_CUSPARSELT", "${USE_CUSPARSELT}"}, \ {"USE_XPU", "${USE_XPU}"}, \ {"USE_XCCL", "${USE_XCCL}"}, \ + {"SYCL_COMPILER_VERSION", "${SYCL_COMPILER_VERSION}"}, \ } diff --git a/caffe2/perfkernels/batch_box_cox_avx512.cc b/caffe2/perfkernels/batch_box_cox_avx512.cc deleted file mode 100644 index a97cb364a359d..0000000000000 --- a/caffe2/perfkernels/batch_box_cox_avx512.cc +++ /dev/null @@ -1,118 +0,0 @@ -#ifdef CAFFE2_PERF_USE_MKL -#include - -// Enable compiler vectorized version only if numerical consistency is not -// required between dev and opt versions - disabled for now -#ifndef FAST_VECTORIZED_KERNEL -#define CPU_CAPABILITY_AVX512 -#include - -namespace at::vec { -namespace { -// Implements the vectorized version of std::max() operation, -// which DOESNOT propagates NaN for second argument -template -Vectorized max(const Vectorized& a, const Vectorized& b); - -template <> -Vectorized max(const Vectorized& a, const Vectorized& b) { - // std::max(NaN, nonNan) -> NaN - return _mm512_max_pd(b, a); -} - -template <> -Vectorized max(const Vectorized& a, const Vectorized& b) { - // std::max(NaN, nonNan) -> NaN - return _mm512_max_ps(b, a); -} - -// Implements recieprocal method based on newton-rapson method -// 1. user RCP approximiation -// 2. update with RCP = RCP * (2 - X * RCP) -template -Vectorized fast_recieprocal(const Vectorized& b); -template -scalar_t fast_recieprocal(scalar_t b); - -template<> -Vectorized fast_recieprocal(const Vectorized& b) { - auto minus2 = _mm512_set1_ps(-2.f); - auto rcp = _mm512_rcp14_ps(b); - rcp = _mm512_mul_ps(rcp, _mm512_fnmsub_ps(rcp, b, minus2)); - rcp = _mm512_mul_ps(rcp, _mm512_fnmsub_ps(rcp, b, minus2)); - return rcp; -} - -template <> -float fast_recieprocal(float b) { - auto minus2 = _mm_set_ss(-2.f); - auto b_reg = _mm_set_ss(b); - auto rcp = _mm_rcp_ss(b_reg); - rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); - rcp = _mm_mul_ss(rcp, _mm_fnmsub_ss(rcp, b_reg, minus2)); - return _mm_cvtss_f32(rcp); -} - -template<> -Vectorized fast_recieprocal(const Vectorized& b) { - auto minus2 = _mm512_set1_pd(-2.); - auto rcp = _mm512_rcp14_pd(b); - rcp = _mm512_mul_pd(rcp, _mm512_fnmsub_pd(rcp, b, minus2)); - rcp = _mm512_mul_pd(rcp, _mm512_fnmsub_pd(rcp, b, minus2)); - return rcp; -} - -template <> -double fast_recieprocal(double b) { - return 1./b; -} -} // namespace -} // namespace at::vec -#endif - -#include "caffe2/perfkernels/batch_box_cox_vec.h" - -namespace caffe2::details { - -template -void compute_batch_box_cox__avx512( - std::size_t N, - std::size_t D, - std::size_t block_size, - const T* self_data, - const T* __restrict lambda1_data, - const T* __restrict lambda2_data, - T* output_data) { - compute_batch_box_cox_vec_fma( - N, - D, - block_size, - self_data, - lambda1_data, - lambda2_data, - output_data); - } - -// Vectorized version specializations for float and double -template -void compute_batch_box_cox__avx512( - std::size_t N, - std::size_t D, - std::size_t block_size, - const float* self_data, - const float* __restrict lambda1_data, - const float* __restrict lambda2_data, - float* output_data); - -template -void compute_batch_box_cox__avx512( - std::size_t N, - std::size_t D, - std::size_t block_size, - const double* self_data, - const double* __restrict lambda1_data, - const double* __restrict lambda2_data, - double* output_data); - -} // namespace caffe2::detail -#endif // CAFFE2_PERF_USE_MKL diff --git a/caffe2/perfkernels/batch_box_cox_sve128.cc b/caffe2/perfkernels/batch_box_cox_sve128.cc deleted file mode 100644 index 897e3a8ee4755..0000000000000 --- a/caffe2/perfkernels/batch_box_cox_sve128.cc +++ /dev/null @@ -1,200 +0,0 @@ -#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(CAFFE2_PERF_WITH_SVE128) -#include -#include -#include -#include -#include - -#include "c10/macros/Macros.h" - -/// Select `svlog` accuracy: -/// - 0: original. -/// - 1: more accurate, similar performance. -/// - 2: very high accuracy, a bit lower speed. -#define SVLOG_ACCURACY 2 - -/// Handle special cases in `svexp`: -/// - 0: original. -/// - 1: use clamp, better performance. -/// - 2: no special case handling. -#define SVEXP_SPECIAL_CLAMP 1 - -#if SVLOG_ACCURACY == 2 -static inline svfloat32_t svlog(svfloat32_t x) { - const svbool_t ptrue = svptrue_b8(); - - svint32_t u = svreinterpret_s32(x) - 0x3F2AAAAB; - - svfloat32_t r = svreinterpret_f32((u & 0x007FFFFF) + 0x3F2AAAAB) - 1.0f; - svfloat32_t n = svcvt_f32_x(ptrue, u >> 23); - asm("" : "+w"(r)); // NOTE: can improve instruction scheduling. - - svfloat32_t r2 = r * r; - svfloat32_t p = -0x1.4F9934p-3f + r * 0x1.5A9AA2p-3f; - svfloat32_t q = -0x1.00187Cp-2f + r * 0x1.961348p-3f; - svfloat32_t y = -0x1.FFFFC8p-2f + r * 0x1.555D7Cp-2f; - return (r + n * 0x1.62E43p-1f) + - (y + (q + (p + -0x1.3E737Cp-3f * r2) * r2) * r2) * r2; -} -#elif SVLOG_ACCURACY == 1 -static inline svfloat32_t svlog(svfloat32_t x) { - const svbool_t ptrue = svptrue_b8(); - - svint32_t u = svreinterpret_s32(x) - 0x3F2AAAAB; - - svfloat32_t r = svreinterpret_f32((u & 0x007FFFFF) + 0x3F2AAAAB) - 1.0f; - svfloat32_t n = svcvt_f32_x(ptrue, u >> 23); - asm("" : "+w"(r)); // NOTE: can improve instruction scheduling. - - svfloat32_t r2 = r * r; - svfloat32_t A = -0x1.923814p-3f + r * 0x1.689E5Ep-3f; - svfloat32_t B = -0x1.FC0968p-3f + r * 0x1.93BF0Cp-3f; - svfloat32_t C = -0x1.000478p-1f + r * 0x1.556906p-2f; - - return (r + n * 0x1.62E43p-1f) + (C + (B + A * r2) * r2) * r2; -} -#elif SVLOG_ACCURACY == 0 -static inline svfloat32_t svlog(svfloat32_t x) { - const svbool_t ptrue = svptrue_b8(); - - svint32_t u = svsra_n_s32(svdup_n_s32(-127), svreinterpret_s32(x), 23); - - svfloat32_t n = svcvt_f32_x(ptrue, u); - svfloat32_t r = svreinterpret_f32(svreinterpret_s32(x) - (u << 23)); - - svfloat32_t D = -0.165253549814f + r * 0.0141278216615f; - svfloat32_t C = -2.47071170807f + r * 0.844007015228f; - svfloat32_t B = -5.68692588806f + r * 4.58445882797f; - svfloat32_t A = -2.29561495781f + r * 5.17591238022f; - - svfloat32_t r2 = r * r; - return (A + n * 0.6931471805f) + (B + (C + D * r2) * r2) * r2; -} -#endif - -static inline svfloat32_t svexp(svfloat32_t x) { - // Clamp interval set to prevent denormals! - const svfloat32_t max_input = svdup_n_f32(88.722839f); - const svfloat32_t min_input = svdup_n_f32(-87.33654f); - const svfloat32_t shift = svdup_n_f32(0x1.0000FEp+23f); - const svbool_t ptrue = svptrue_b8(); - -#if SVEXP_SPECIAL_CLAMP == 1 - x = svmax_x(ptrue, svmin_x(ptrue, x, max_input), min_input); -#endif - - svfloat32_t z = svmla_n_f32_x(ptrue, shift, x, 0x1.715476p+0f); - svfloat32_t n = z - shift; - svfloat32_t scale = svreinterpret_f32(svreinterpret_u32(z) << 23); - - svfloat32_t r_hi = x - n * 0x1.62E400p-1f; - svfloat32_t r = r_hi - n * 0x1.7F7D1Cp-20f; - svfloat32_t r2 = r * r; - - svfloat32_t C = 0x1.573E2Ep-5f + r * 0x1.0E4020p-7f; - svfloat32_t B = 0x1.FFFDB6p-2f + r * 0x1.555E66p-3f; - svfloat32_t A = r * 0x1.FFFFECp-1f; - - svfloat32_t poly = scale + (A + (B + C * r2) * r2) * scale; - -#if SVEXP_SPECIAL_CLAMP == 0 - const svfloat32_t inf = svdup_n_f32(std::numeric_limits::infinity()); - poly = svsel_f32(svcmplt_f32(ptrue, x, min_input), svdup_n_f32(0.0f), poly); - poly = svsel_f32(svcmpgt_f32(ptrue, x, max_input), inf, poly); -#endif - - return poly; -} - -static inline svfloat32_t compute_batch_box_cox_vec_sve128_float( - svfloat32_t lambda1_v, - svfloat32_t lambda2_v, - svfloat32_t data_v, - svfloat32_t k_eps) { - const svbool_t ptrue = svptrue_b8(); - - svfloat32_t lnData = svlog(svmax_x(ptrue, data_v + lambda2_v, k_eps)); - svbool_t predNZ = svcmpne_n_f32(ptrue, lambda1_v, 0.0f); - if (C10_LIKELY(svptest_any(predNZ, predNZ))) { - svfloat32_t lambda1_r = svdivr_f32_m(predNZ, lambda1_v, svdup_n_f32(1.0f)); - svfloat32_t pow = svexp(lnData * lambda1_v); - lnData = svsel_f32(predNZ, lambda1_r, lnData); - lnData = svnmsb_f32_m(predNZ, lnData, pow, lnData); - } - return lnData; -} - -template -void compute_batch_box_cox_vec_sve128( - std::size_t N, - std::size_t D, - const T* data_ptr, - const T* __restrict lambda1_ptr, - const T* __restrict lambda2_ptr, - T* output_ptr); - -template <> -void compute_batch_box_cox_vec_sve128( - std::size_t N, - std::size_t D, - const float *data_ptr, - const float *__restrict lambda1_ptr, - const float *__restrict lambda2_ptr, - float *output_ptr) { - const svfloat32_t k_eps = svdup_n_f32(static_cast(1e-6)); - - std::size_t remainder = D % 4; - std::size_t loopBound = D - remainder; - svbool_t remainderPred = svwhilelt_b32_u64(0, remainder); - - for (; C10_LIKELY(N > 0); --N) { - for (std::size_t j = 0; C10_LIKELY(j != loopBound); - j += 4, data_ptr += 4, output_ptr += 4) { - svfloat32_t lambda1_v = - svset_neonq(svundef_f32(), vld1q_f32(lambda1_ptr + j)); - svfloat32_t lambda2_v = - svset_neonq(svundef_f32(), vld1q_f32(lambda2_ptr + j)); - svfloat32_t data_v = svset_neonq(svundef_f32(), vld1q_f32(data_ptr)); - svfloat32_t result = compute_batch_box_cox_vec_sve128_float( - lambda1_v, lambda2_v, data_v, k_eps); - vst1q_f32(output_ptr, svget_neonq(result)); - } - if (C10_LIKELY(remainder > 0)) { - svfloat32_t lambda1_v = svld1_f32(remainderPred, lambda1_ptr + loopBound); - svfloat32_t lambda2_v = svld1_f32(remainderPred, lambda2_ptr + loopBound); - svfloat32_t data_v = svld1_f32(remainderPred, data_ptr); - svfloat32_t result = compute_batch_box_cox_vec_sve128_float( - lambda1_v, lambda2_v, data_v, k_eps); - svst1_f32(remainderPred, output_ptr, result); - data_ptr += remainder; - output_ptr += remainder; - } - } -} - -namespace caffe2::details { - -template -void compute_batch_box_cox__sve128( - std::size_t N, - std::size_t D, - const T* self_data, - const T* __restrict lambda1_data, - const T* __restrict lambda2_data, - T* output_data) { - compute_batch_box_cox_vec_sve128( - N, D, self_data, lambda1_data, lambda2_data, output_data); -} - -// Vectorized version specializations for float and double -template void compute_batch_box_cox__sve128( - std::size_t N, - std::size_t D, - const float* self_data, - const float* __restrict lambda1_data, - const float* __restrict lambda2_data, - float* output_data); - -} // namespace caffe2::details - -#endif // __aarch64__ && __ARM_FEATURE_SVE && CAFFE2_PERF_WITH_SVE128 diff --git a/caffe2/perfkernels/batch_box_cox_vec.h b/caffe2/perfkernels/batch_box_cox_vec.h deleted file mode 100644 index 08e4f84fe4327..0000000000000 --- a/caffe2/perfkernels/batch_box_cox_vec.h +++ /dev/null @@ -1,321 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include "vectorizer.h" -#include - -namespace caffe2::details { - -namespace { -void TileIndicesInPlace(std::vector& v, const std::size_t D, const std::size_t K) { - auto n = v.size(); - v.resize(K * n); - for (const auto k : c10::irange(1, K)) { - for (const auto j : c10::irange(n)) { - v[k * n + j] = v[j] + k * D; - } - } -} - -// MKL VML function templates. -template -void PackV(const int N, const T* a, const int* ia, T* y); -template -void UnpackV(const int N, const T* a, T* y, const int* iy); - -#define DELEGATE_PACKV_FUNCTION(T, OriginalFunc) \ - template <> \ - void PackV(const int N, const T* a, const int* ia, T* y) { \ - OriginalFunc(N, a, ia, y); \ - } -DELEGATE_PACKV_FUNCTION(float, vsPackV) -DELEGATE_PACKV_FUNCTION(double, vdPackV) -#undef DELEGATE_PACKV_FUNCTION - -#define DELEGATE_UNPACKV_FUNCTION(T, OriginalFunc) \ - template <> \ - void UnpackV(const int N, const T* a, T* y, const int* iy) { \ - OriginalFunc(N, a, y, iy); \ - } -DELEGATE_UNPACKV_FUNCTION(float, vsUnpackV) -DELEGATE_UNPACKV_FUNCTION(double, vdUnpackV) -#undef DELEGATE_UNPACKV_FUNCTION - -#ifndef FAST_VECTORIZED_KERNEL -template -void box_cox_zero_lambda( - size_t D, - const T* const self_data, - const T* const lambda2_data, - T k_eps, - T* const output_data) { - int j = 0; - using Vec = at::vec::Vectorized; - constexpr int64_t VLEN = Vec::size(); - auto k_eps_vec = Vec(k_eps); - for(; j + VLEN < D; j += VLEN) { - auto data = Vec::loadu(self_data + j); - auto lambda2 = Vec::loadu(lambda2_data + j); - auto sum = data + lambda2; - auto max = at::vec::max(sum, k_eps_vec); - auto res = max.log(); - res.store(output_data + j); - } - for ( ;j < D; ++j) { - auto sum = self_data[j] + lambda2_data[j]; - auto max = std::max(sum, k_eps); - output_data[j] = std::log(max); - } -} - -template -at::vec::Vectorized box_cox_nonzero_lambda_impl( - at::vec::Vectorized data, - at::vec::Vectorized lambda1, - at::vec::Vectorized lambda2, - at::vec::Vectorized k_eps) { - auto sum = data + lambda2; - auto max = at::vec::max(sum, k_eps); - auto lambda_over_1 = at::vec::fast_recieprocal(lambda1); - auto pow = max.pow(lambda1); - return at::vec::fmsub(pow, lambda_over_1, lambda_over_1); -} - -template -void box_cox_nonzero_lambda( - int64_t D, - const T* data_ptr, - const T* lambda1_ptr, - const T* lambda2_ptr, - T k_eps, - T* out) { - - int j = 0; - using Vec = at::vec::Vectorized; - constexpr int64_t VLEN = Vec::size(); - auto k_eps_vec = Vec(k_eps); - for(; j + VLEN < D; j += VLEN) { - auto data = Vec::loadu(data_ptr + j); - auto lambda1 = Vec::loadu(lambda1_ptr + j); - auto lambda2 = Vec::loadu(lambda2_ptr + j); - auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec); - res.store(out + j); - } - if (j < D) { - auto remaining = D - j; - auto data = Vec::loadu(data_ptr + j, remaining); - auto lambda1 = Vec::loadu(lambda1_ptr + j, remaining); - auto lambda2 = Vec::loadu(lambda2_ptr + j, remaining); - auto res = box_cox_nonzero_lambda_impl(data, lambda1, lambda2, k_eps_vec); - res.store(out + j, remaining); - } -} -#else -template -void box_cox_zero_lambda( - size_t D, - const T* const self_data, - const T* const lambda2_data, - T k_eps, - T* const output_data) { - VECTOR_LOOP for (auto j=0 ;j < D; ++j) { - auto sum = self_data[j] + lambda2_data[j]; - auto max = std::max(sum, k_eps); - output_data[j] = std::log(max); - } -} - -template -void box_cox_nonzero_lambda( - int64_t D, - const T* data_ptr, - const T* lambda1_ptr, - const T* lambda2_ptr, - T k_eps, - T* out) { - - VECTOR_LOOP for (auto j=0 ;j < D; ++j) { - FAST_MATH - auto sum = data_ptr[j] + lambda2_ptr[j]; - auto max = std::max(sum, k_eps); - auto lamda1 = lambda1_ptr[j]; - auto lambda_over_1 = 1 / lamda1; - if constexpr (std::is_same::value) { - lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); - lambda_over_1 = lambda_over_1 * (T{2} - lambda_over_1 * lamda1); - } - auto pow = std::pow(max, lamda1); - out[j] = pow * lambda_over_1 - lambda_over_1; - } -} -#endif // FAST_VECTORIZED_KERNEL - -template -void box_cox_mixed_lambda( - const T* const self_data, - const std::vector& nonzeros, - const std::vector& zeros, - const T* const lambda1, - const T* const lambda2, - const T* const lambda2_z_, - T k_eps, - T* const buffer, - T* const output_data) { - PackV(nonzeros.size(), self_data, nonzeros.data(), buffer); - box_cox_nonzero_lambda( - nonzeros.size(), buffer, lambda1, lambda2, k_eps, buffer); - UnpackV(nonzeros.size(), buffer, output_data, nonzeros.data()); - - PackV(zeros.size(), self_data, zeros.data(), buffer); - box_cox_zero_lambda( - zeros.size(), buffer, lambda2_z_, k_eps, buffer); - UnpackV(zeros.size(), buffer, output_data, zeros.data()); -} - -template -void TileArrayIntoVector( - const T* const a, - const size_t D, - const int K, - std::vector& b) { - b.resize(K * D); - for (const auto k : c10::irange(K)) { - std::copy(a, a + D, b.begin() + k * D); - } -} - -template -void compute_batch_box_cox_vec_fma( - std::size_t N, - std::size_t D, - std::size_t block_size, - const T* self_data, - const T* __restrict lambda1_data, - const T* __restrict lambda2_data, - T* output_data) { - constexpr T k_eps = static_cast(1e-6); - - FOLLY_DECLARE_REUSED(zeros, std::vector); - FOLLY_DECLARE_REUSED(nonzeros, std::vector); - // Don't bother calling reserve; calls after the first will get a - // correctly-sized allocation anyway. - for (const auto j : c10::irange(D)) { - if (lambda1_data[j] == 0) { - zeros.push_back(j); - } else { - nonzeros.push_back(j); - } - } - - // Process K rows at a time for effective vectorization with small rows. - const auto K = std::min(N, (block_size + D - 1) / D); - - FOLLY_DECLARE_REUSED(lambda1_, std::vector); - FOLLY_DECLARE_REUSED(lambda2_, std::vector); - FOLLY_DECLARE_REUSED(lambda2_z_, std::vector); - - if (nonzeros.size() == D) { - // ((x + lambda2)^lambda1 - 1)/lambda1, if lambda1 != 0 - size_t i = 0; - if (K > 1) { - TileArrayIntoVector(lambda1_data, D, K, lambda1_); - TileArrayIntoVector(lambda2_data, D, K, lambda2_); - DCHECK_EQ(K * D, lambda1_.size()); - DCHECK_EQ(K * D, lambda2_.size()); - for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { - box_cox_nonzero_lambda( - K * D, - self_data, - lambda1_.data(), - lambda2_.data(), - k_eps, - output_data); - } - } - for (; i < N; i++, self_data += D, output_data += D) { - box_cox_nonzero_lambda( - D, self_data, lambda1_data, lambda2_data, k_eps, output_data); - } - } else if (zeros.size() == D) { - // ln(x + lambda2), if lambda1 == 0 - size_t i = 0; - if (K > 1) { - TileArrayIntoVector(lambda2_data, D, K, lambda2_z_); - DCHECK_EQ(K * D, lambda2_z_.size()); - for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { - box_cox_zero_lambda( - K * D, self_data, lambda2_z_.data(), k_eps, output_data); - } - } - for (; i < N; i++, self_data += D, output_data += D) { - box_cox_zero_lambda( - D, self_data, lambda2_data, k_eps, output_data); - } - } else { - // mix zeros and nonzeros - const size_t n = nonzeros.size(); - if (K > 1) { - TileIndicesInPlace(nonzeros, 0, K); - TileIndicesInPlace(zeros, 0, K); - } - - FOLLY_DECLARE_REUSED(buffer, std::vector); - - buffer.resize(std::max(nonzeros.size(), zeros.size())); - lambda1_.resize(nonzeros.size()); - lambda2_.resize(nonzeros.size()); - lambda2_z_.resize(zeros.size()); - PackV(nonzeros.size(), lambda1_data, nonzeros.data(), lambda1_.data()); - PackV(nonzeros.size(), lambda2_data, nonzeros.data(), lambda2_.data()); - PackV(zeros.size(), lambda2_data, zeros.data(), lambda2_z_.data()); - - size_t i = 0; - if (K > 1) { - // Truncate to original size, and re-tile with offsets this time. - nonzeros.resize(n); - DCHECK_GT(D, n); - zeros.resize(D - n); - TileIndicesInPlace(nonzeros, D, K); - TileIndicesInPlace(zeros, D, K); - DCHECK_EQ(nonzeros.size(), lambda1_.size()); - DCHECK_EQ(nonzeros.size(), lambda2_.size()); - DCHECK_EQ(zeros.size(), lambda2_z_.size()); - - for (; i < N - K + 1; i += K, self_data += K * D, output_data += K * D) { - box_cox_mixed_lambda( - self_data, - nonzeros, - zeros, - lambda1_.data(), - lambda2_.data(), - lambda2_z_.data(), - k_eps, - buffer.data(), - output_data); - } - // Truncate to original size. - nonzeros.resize(n); - zeros.resize(D - n); - } - for (; i < N; i++, self_data += D, output_data += D) { - box_cox_mixed_lambda( - self_data, - nonzeros, - zeros, - lambda1_.data(), - lambda2_.data(), - lambda2_z_.data(), - k_eps, - buffer.data(), - output_data); - } - } -} -} // namespace - -} // namespace caffe2::details diff --git a/cmake/Codegen.cmake b/cmake/Codegen.cmake index c3003e1c8af39..17456252ad16d 100644 --- a/cmake/Codegen.cmake +++ b/cmake/Codegen.cmake @@ -163,6 +163,11 @@ if(INTERN_BUILD_ATEN_OPS) set(GEN_XPU_FLAG --xpu) endif() + set(GEN_MTIA_FLAG) + if(USE_MTIA) + set(GEN_MTIA_FLAG --mtia) + endif() + set(CUSTOM_BUILD_FLAGS) if(INTERN_BUILD_MOBILE) if(USE_VULKAN) @@ -249,6 +254,7 @@ if(INTERN_BUILD_ATEN_OPS) ${GEN_ROCM_FLAG} ${GEN_MPS_FLAG} ${GEN_XPU_FLAG} + ${GEN_MTIA_FLAG} ${CUSTOM_BUILD_FLAGS} ) @@ -426,6 +432,11 @@ if(INTERN_BUILD_ATEN_OPS) function(process_vec NAME) list(GET CPU_CAPABILITY_NAMES ${i} CPU_CAPABILITY) set(NEW_IMPL ${CMAKE_BINARY_DIR}/aten/src/ATen/${NAME}.${CPU_CAPABILITY}.cpp) + # IMPL is absolute here; make it relative to NEW_IMPL's directory so the + # generated #include is worktree-independent (ccache/re-cc friendly). + if(USE_RELATIVE_PATHS) + file(RELATIVE_PATH IMPL "${CMAKE_BINARY_DIR}/aten/src/ATen" "${IMPL}") + endif() configure_file("${PROJECT_SOURCE_DIR}/cmake/IncludeSource.cpp.in" ${NEW_IMPL}) set(cpu_kernel_cpp ${NEW_IMPL} ${cpu_kernel_cpp} PARENT_SCOPE) # Create list of copies list(GET CPU_CAPABILITY_FLAGS ${i} FLAGS) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 513ad71f2f19c..69fabe0e9f44b 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -47,15 +47,21 @@ if(USE_CUDA) # torch::cudart is dealt with separately, due to CUDA_ADD_LIBRARY # design reason (it adds CUDA_LIBRARIES itself). set(Caffe2_PUBLIC_CUDA_DEPENDENCY_LIBS ) - if(NOT CAFFE2_USE_NVRTC) - caffe2_update_option(USE_NVRTC OFF) - endif() list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS caffe2::curand caffe2::cufft caffe2::cublas) if(CAFFE2_USE_CUDNN) + if(NOT CAFFE2_USE_NVRTC) + message(FATAL_ERROR + "USE_CUDNN requires USE_NVRTC (required by cudnn_frontend 1.21+). " + "Please set -DUSE_NVRTC=ON or disable cuDNN with -DUSE_CUDNN=OFF.") + endif() list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS torch::cudnn) + list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS caffe2::nvrtc_runtime) else() caffe2_update_option(USE_CUDNN OFF) endif() + if(NOT CAFFE2_USE_NVRTC) + caffe2_update_option(USE_NVRTC OFF) + endif() if(CAFFE2_USE_CUSPARSELT) list(APPEND Caffe2_CUDA_DEPENDENCY_LIBS torch::cusparselt) else() @@ -131,7 +137,10 @@ if(USE_ASAN OR USE_LSAN OR USE_TSAN) endif() if(USE_TSAN) if(TARGET Sanitizer::thread) - list(APPEND Caffe2_DEPENDENCY_LIBS Sanitizer::thread) + # Use global flags so that all targets (including executables like + # torch_shm_manager that don't link torch_cpu) get TSan instrumentation. + add_compile_options(-fsanitize=thread) + add_link_options(-fsanitize=thread) else() message(WARNING "TSAN not found. Suppress this warning with -DUSE_TSAN=OFF.") caffe2_update_option(USE_TSAN OFF) @@ -1023,7 +1032,7 @@ if(USE_ROCM) list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK_GEMM) endif() list(APPEND HIP_HIPCC_FLAGS --offload-compress) - list(APPEND HIP_HIPCC_FLAGS -std=c++17) + list(APPEND HIP_HIPCC_FLAGS -std=c++20) # Pass device library path for theRock nightly builds if(DEFINED ENV{HIP_DEVICE_LIB_PATH}) file(TO_CMAKE_PATH "$ENV{HIP_DEVICE_LIB_PATH}" _hip_device_lib_path) @@ -1063,6 +1072,8 @@ if(USE_ROCM) list(APPEND HIP_HIPCC_FLAGS -fclang-abi-compat=17) set(HIP_CLANG_FLAGS ${HIP_CXX_FLAGS}) + string(JOIN " " HIP_HIPCC_FLAGS_STR ${HIP_HIPCC_FLAGS}) + set(HIP_HIPCC_FLAGS ${HIP_HIPCC_FLAGS_STR}) set(CMAKE_HIP_FLAGS ${HIP_HIPCC_FLAGS}) # Ask hcc to generate device code during compilation so we can use # host linker to link. @@ -1087,6 +1098,16 @@ if(USE_ROCM) list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS roc::hipsparselt ) + if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "7.12.0") + set(CAFFE2_USE_HIPSPARSELT ON) + endif() + endif() + + # ROCM-SMI needed to support symmetric memory + if(USE_DISTRIBUTED AND UNIX) + list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS + rocm_smi64 + ) endif() # ---[ Kernel asserts diff --git a/cmake/EnvVarForwarding.cmake b/cmake/EnvVarForwarding.cmake new file mode 100644 index 0000000000000..a31ce6db26637 --- /dev/null +++ b/cmake/EnvVarForwarding.cmake @@ -0,0 +1,138 @@ +# Forward environment variables to CMake variables. +# +# This replicates the behavior of setup.py / tools/setup_helpers/cmake.py which +# passes all BUILD_*, USE_*, and CMAKE_* environment variables as -D flags, plus +# a set of additional variables that don't follow the prefix convention. + +# Additional env vars that are forwarded with a different CMake variable name. +set(_ENV_ALIASES + "CUDNN_LIB_DIR=CUDNN_LIBRARY" + "USE_CUDA_STATIC_LINK=CAFFE2_STATIC_LINK_CUDA" +) + +# Additional env vars forwarded with the same name. +set(_ENV_PASSTHROUGH + UBSAN_FLAGS + BLAS + WITH_BLAS + CUDA_HOST_COMPILER + CUDA_NVCC_EXECUTABLE + CUDA_SEPARABLE_COMPILATION + CUDNN_LIBRARY + CUDNN_INCLUDE_DIR + CUDNN_ROOT + EXPERIMENTAL_SINGLE_THREAD_POOL + INSTALL_TEST + INTEL_MKL_DIR + INTEL_OMP_DIR + MKL_THREADING + MKLDNN_CPU_RUNTIME + MSVC_Z7_OVERRIDE + CAFFE2_USE_MSVC_STATIC_RUNTIME + Numa_INCLUDE_DIR + Numa_LIBRARIES + ONNX_ML + ONNX_NAMESPACE + ATEN_THREADING + WERROR + OPENSSL_ROOT_DIR + STATIC_DISPATCH_BACKEND + SELECTED_OP_LIST + TORCH_CUDA_ARCH_LIST + TORCH_XPU_ARCH_LIST + TRACING_BASED + PYTHON_LIB_REL_PATH +) + +# Low-priority aliases: if the canonical var is not set, use the alias. +set(_LOW_PRIORITY_ALIASES + "CUDA_HOST_COMPILER=CMAKE_CUDA_HOST_COMPILER" + "CUDAHOSTCXX=CUDA_HOST_COMPILER" + "CMAKE_CUDA_HOST_COMPILER=CUDA_HOST_COMPILER" + "CMAKE_CUDA_COMPILER=CUDA_NVCC_EXECUTABLE" + "CUDACXX=CUDA_NVCC_EXECUTABLE" +) + +# Forward aliased env vars (env name -> different cmake name) +foreach(_alias IN LISTS _ENV_ALIASES) + string(REPLACE "=" ";" _parts "${_alias}") + list(GET _parts 0 _env_name) + list(GET _parts 1 _cmake_name) + if(DEFINED ENV{${_env_name}} AND NOT DEFINED ${_cmake_name}) + set(${_cmake_name} "$ENV{${_env_name}}" CACHE STRING "From env ${_env_name}" FORCE) + endif() +endforeach() + +# Forward passthrough env vars (same name) +foreach(_var IN LISTS _ENV_PASSTHROUGH) + if(DEFINED ENV{${_var}} AND NOT DEFINED ${_var}) + set(${_var} "$ENV{${_var}}" CACHE STRING "From env ${_var}" FORCE) + endif() +endforeach() + +# Forward all BUILD_*, USE_*, CMAKE_* env vars not already set as CMake +# variables, plus vars ending in EXITCODE or EXITCODE__TRYRUN_OUTPUT. +# This matches the existing behavior where setup.py passed everything with +# these prefixes/suffixes through to CMake. +# We use execute_process + env to get the full list since CMake has no +# built-in way to enumerate environment variables. +execute_process( + COMMAND "${CMAKE_COMMAND}" -E environment + OUTPUT_VARIABLE _all_env + OUTPUT_STRIP_TRAILING_WHITESPACE +) +string(REPLACE "\n" ";" _env_lines "${_all_env}") +foreach(_line IN LISTS _env_lines) + if(_line MATCHES "^([A-Za-z_0-9]+)=(.*)") + set(_var_name "${CMAKE_MATCH_1}") + set(_var_value "${CMAKE_MATCH_2}") + # Only forward vars with BUILD_/USE_/CMAKE_ prefix or *EXITCODE* suffix. + string(REGEX MATCH "^(BUILD_|USE_|CMAKE_)" _has_prefix "${_var_name}") + string(REGEX MATCH "(EXITCODE|EXITCODE__TRYRUN_OUTPUT)$" _has_suffix "${_var_name}") + if(NOT _has_prefix AND NOT _has_suffix) + continue() + endif() + if(NOT DEFINED ${_var_name}) + set(${_var_name} "${_var_value}" CACHE STRING "From environment" FORCE) + endif() + endif() +endforeach() + +# Low-priority aliases +foreach(_alias IN LISTS _LOW_PRIORITY_ALIASES) + string(REPLACE "=" ";" _parts "${_alias}") + list(GET _parts 0 _env_name) + list(GET _parts 1 _cmake_name) + if(DEFINED ENV{${_env_name}} AND NOT DEFINED ${_cmake_name}) + set(${_cmake_name} "$ENV{${_env_name}}" CACHE STRING "From env alias ${_env_name}" FORCE) + endif() +endforeach() + +# Ensure Python's purelib is on CMAKE_PREFIX_PATH so CMake can find +# packages installed there (e.g., pybind11, numpy). +if(Python_EXECUTABLE) + execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import sysconfig; print(sysconfig.get_path('purelib'))" + OUTPUT_VARIABLE _py_purelib + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + if(_py_purelib AND NOT "${_py_purelib}" STREQUAL "") + list(PREPEND CMAKE_PREFIX_PATH "${_py_purelib}") + # Preserve paths from the CMAKE_PREFIX_PATH environment variable. + # Setting the cmake variable shadows the env var, so we must merge it in + # explicitly. This ensures conda's prefix (e.g. /opt/conda/envs/py_3.10) + # is present so cmake can find conda-provided libraries (libgomp, libnuma). + if(DEFINED ENV{CMAKE_PREFIX_PATH} AND NOT "$ENV{CMAKE_PREFIX_PATH}" STREQUAL "") + if(WIN32) + # On Windows the env var is already ;-separated and : appears in drive + # letters (e.g. C:\conda\envs\py310), so use it as-is. + set(_env_prefix "$ENV{CMAKE_PREFIX_PATH}") + else() + string(REPLACE ":" ";" _env_prefix "$ENV{CMAKE_PREFIX_PATH}") + endif() + list(APPEND CMAKE_PREFIX_PATH ${_env_prefix}) + endif() + list(REMOVE_DUPLICATES CMAKE_PREFIX_PATH) + endif() +endif() diff --git a/cmake/Metal.cmake b/cmake/Metal.cmake index c9565e2fc0e9e..f378dfb0e803d 100644 --- a/cmake/Metal.cmake +++ b/cmake/Metal.cmake @@ -45,6 +45,13 @@ set(BFLOAT_METAL_CODE " ptr[idx] += 1; } ") +set(LAMBDA_METAL_CODE " + kernel void test(device float* ptr, + uint idx [[thread_position_in_grid]]) { + auto fn = [](float x) { return x + 1.0; }; + ptr[idx] = fn(ptr[idx]); + } +") if(NOT CAN_COMPILE_METAL_FOUND) file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/bfloat_inc.metal" "${BFLOAT_METAL_CODE}") execute_process(COMMAND xcrun metal -std=metal3.1 bfloat_inc.metal @@ -59,6 +66,23 @@ if(NOT CAN_COMPILE_METAL_FOUND) message(WARNING "Machine can not compile metal shaders, fails with ${XCRUN_OUTPUT}") set(CAN_COMPILE_METAL NO CACHE BOOL "Host can compile metal shaders") endif() + if(CAN_COMPILE_METAL) + file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/lambda_test.metal" "${LAMBDA_METAL_CODE}") + execute_process(COMMAND xcrun metal -std=metal4.0 -c lambda_test.metal -o /dev/null + WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}" + OUTPUT_VARIABLE XCRUN_OUTPUT + ERROR_VARIABLE XCRUN_OUTPUT + RESULT_VARIABLE XCRUN_RC) + if(${XCRUN_RC} EQUAL 0) + message(STATUS "Metal toolchain supports Metal 4.0") + set(CAN_COMPILE_METAL_40 YES CACHE BOOL "Host can compile Metal 4.0 shaders" FORCE) + else() + message(STATUS "Metal toolchain does not support Metal 4.0") + set(CAN_COMPILE_METAL_40 NO CACHE BOOL "Host can compile Metal 4.0 shaders" FORCE) + endif() + else() + set(CAN_COMPILE_METAL_40 NO CACHE BOOL "Host can compile Metal 4.0 shaders" FORCE) + endif() set(CAN_COMPILE_METAL_FOUND YES CACHE INTERNAL "Run check for shader compiler") endif() diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index 7f53dacadef59..10ee8546b8d9d 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -47,7 +47,7 @@ IF(NOT MKLDNN_FOUND) endif() ExternalProject_Add(xpu_mkldnn_proj GIT_REPOSITORY https://github.com/uxlfoundation/oneDNN - GIT_TAG v3.10.2 + GIT_TAG v3.11.2 PREFIX ${XPU_MKLDNN_DIR_PREFIX} BUILD_IN_SOURCE 0 CMAKE_ARGS -DCMAKE_C_COMPILER=icx diff --git a/cmake/Modules/FindNCCL.cmake b/cmake/Modules/FindNCCL.cmake index cef8002f81706..447753f48c615 100644 --- a/cmake/Modules/FindNCCL.cmake +++ b/cmake/Modules/FindNCCL.cmake @@ -57,31 +57,6 @@ if(NCCL_FOUND) # obtaining NCCL version and some sanity checks include(CheckCXXSymbolExists) check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED) - # this condition check only works for non static NCCL linking - if (NCCL_VERSION_DEFINED AND NOT USE_STATIC_NCCL) - set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc") - file(WRITE ${file} " - #include - #include - int main() - { - std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl; - int x; - ncclGetVersion(&x); - return x == NCCL_VERSION_CODE; - } -") - try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file} - RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER - CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}" - LINK_LIBRARIES ${NCCL_LIBRARIES}) - if (NOT NCCL_VERSION_MATCHED) - message(FATAL_ERROR "Found NCCL header version and library version do not match! \ -(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.") - endif() - message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}") - endif () - set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake index cbc26c4b23ad8..642b96ba119e8 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake @@ -30,10 +30,10 @@ endif() set(CUDA_KNOWN_GPU_ARCHITECTURES "Kepler" "Maxwell") # This list will be used for CUDA_ARCH_NAME = Common option (enabled by default) -set(CUDA_COMMON_GPU_ARCHITECTURES "3.5" "5.0") +set(CUDA_COMMON_GPU_ARCHITECTURES "5.0") # This list is used to filter CUDA archs when autodetecting -set(CUDA_ALL_GPU_ARCHITECTURES "3.5" "5.0") +set(CUDA_ALL_GPU_ARCHITECTURES "5.0") if(CUDA_VERSION VERSION_GREATER "10.5") list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ampere") @@ -65,19 +65,10 @@ if(NOT CUDA_VERSION VERSION_LESS "11.8") list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.9") list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0") - if(CUDA_VERSION VERSION_LESS "12.0") - set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0") - list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9+PTX") - list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0+PTX") - endif() endif() -if(NOT CUDA_VERSION VERSION_LESS "12.0") - list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0a") - list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0a") - list(REMOVE_ITEM CUDA_COMMON_GPU_ARCHITECTURES "3.5") - list(REMOVE_ITEM CUDA_ALL_GPU_ARCHITECTURES "3.5") -endif() +list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0a") +list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0a") if(CUDA_VERSION VERSION_GREATER "12.6") list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Blackwell") diff --git a/cmake/PreBuildSteps.cmake b/cmake/PreBuildSteps.cmake new file mode 100644 index 0000000000000..14dbbb7cf8410 --- /dev/null +++ b/cmake/PreBuildSteps.cmake @@ -0,0 +1,146 @@ +# Pre-build steps previously handled by setup.py: +# 1. Git submodule initialization and verification +# 2. NCCL checkout from pinned version +# +# This file is included early in CMakeLists.txt, before option() declarations. +# It relies on env vars (forwarded by EnvVarForwarding.cmake) and CMake -D +# cache variables, both of which are available at this point. + +find_package(Git QUIET) + +# --- Submodule initialization and verification --- +# Matches the logic in setup.py::check_submodules(). +if(NOT DEFINED USE_SYSTEM_LIBS OR NOT USE_SYSTEM_LIBS) + # Read submodule paths from .gitmodules if available, otherwise use defaults. + set(_gitmodules_file "${PROJECT_SOURCE_DIR}/.gitmodules") + if(EXISTS "${_gitmodules_file}") + file(STRINGS "${_gitmodules_file}" _gitmodule_lines REGEX "^[[:space:]]*path") + set(_submodule_folders) + foreach(_line IN LISTS _gitmodule_lines) + string(REGEX REPLACE ".*=[[:space:]]*" "" _path "${_line}") + list(APPEND _submodule_folders "${PROJECT_SOURCE_DIR}/${_path}") + endforeach() + else() + set(_submodule_folders + "${PROJECT_SOURCE_DIR}/third_party/gloo" + "${PROJECT_SOURCE_DIR}/third_party/cpuinfo" + "${PROJECT_SOURCE_DIR}/third_party/onnx" + "${PROJECT_SOURCE_DIR}/third_party/fbgemm" + "${PROJECT_SOURCE_DIR}/third_party/cutlass" + ) + endif() + + set(_all_missing TRUE) + foreach(_dir IN LISTS _submodule_folders) + if(EXISTS "${_dir}" AND IS_DIRECTORY "${_dir}") + file(GLOB _contents "${_dir}/*") + if(_contents) + set(_all_missing FALSE) + break() + endif() + endif() + endforeach() + + if(_all_missing AND GIT_FOUND) + message(STATUS "Initializing git submodules...") + execute_process( + COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive + WORKING_DIRECTORY "${PROJECT_SOURCE_DIR}" + RESULT_VARIABLE _submodule_result + ) + if(NOT _submodule_result EQUAL 0) + message(FATAL_ERROR + "Git submodule initialization failed. Please run:\n" + " git submodule update --init --recursive" + ) + endif() + endif() + + # Verify submodules contain expected files (catches corrupt/partial checkouts). + set(_expected_files CMakeLists.txt Makefile setup.py LICENSE LICENSE.md LICENSE.txt) + foreach(_dir IN LISTS _submodule_folders) + set(_found FALSE) + foreach(_file IN LISTS _expected_files) + if(EXISTS "${_dir}/${_file}") + set(_found TRUE) + break() + endif() + endforeach() + if(NOT _found) + message(FATAL_ERROR + "Submodule ${_dir} appears incomplete (none of " + "${_expected_files} found).\n" + "Please run: git submodule update --init --recursive" + ) + endif() + endforeach() + # Extra check for fbgemm's nested dependency + if(NOT EXISTS "${PROJECT_SOURCE_DIR}/third_party/fbgemm/external/asmjit/CMakeLists.txt") + message(FATAL_ERROR + "third_party/fbgemm/external/asmjit appears incomplete.\n" + "Please run: git submodule update --init --recursive" + ) + endif() +endif() + +# --- NCCL checkout --- +# Clone NCCL from the pinned tag if building with NCCL and not using +# system NCCL. Conditions match build_pytorch_libs.py::build_pytorch(). +if(NOT USE_SYSTEM_NCCL) + # Only attempt if USE_DISTRIBUTED, USE_CUDA, USE_NCCL are not explicitly OFF. + set(_skip_nccl FALSE) + foreach(_var USE_DISTRIBUTED USE_CUDA USE_NCCL) + if(DEFINED ${_var}) + string(TOUPPER "${${_var}}" _val) + if(_val MATCHES "^(OFF|0|NO|FALSE|N)$") + set(_skip_nccl TRUE) + break() + endif() + endif() + endforeach() + + if(NOT _skip_nccl) + set(_nccl_dir "${PROJECT_SOURCE_DIR}/third_party/nccl") + if(NOT EXISTS "${_nccl_dir}") + # Select pin file: try a CUDA-version-specific pin (e.g. nccl-cu126.txt) + # first, fall back to nccl.txt. Adding a new pin file is sufficient to + # support a new CUDA version — no CMake changes needed. + set(_nccl_pin_name "nccl.txt") + if(DEFINED ENV{DESIRED_CUDA}) + set(_cuda_ver "$ENV{DESIRED_CUDA}") + elseif(DEFINED ENV{CUDA_VERSION}) + set(_cuda_ver "$ENV{CUDA_VERSION}") + else() + set(_cuda_ver "") + endif() + set(_cuda_suffix "") + if(_cuda_ver MATCHES "^([0-9]+)\\.([0-9]+)") + set(_cuda_suffix "cu${CMAKE_MATCH_1}${CMAKE_MATCH_2}") + elseif(_cuda_ver MATCHES "^(cu[0-9]+)$") + set(_cuda_suffix "${CMAKE_MATCH_1}") + endif() + if(_cuda_suffix) + set(_versioned_pin "${PROJECT_SOURCE_DIR}/.ci/docker/ci_commit_pins/nccl-${_cuda_suffix}.txt") + if(EXISTS "${_versioned_pin}") + set(_nccl_pin_name "nccl-${_cuda_suffix}.txt") + endif() + endif() + + set(_nccl_pin_file "${PROJECT_SOURCE_DIR}/.ci/docker/ci_commit_pins/${_nccl_pin_name}") + if(EXISTS "${_nccl_pin_file}") + file(READ "${_nccl_pin_file}" _nccl_tag) + string(STRIP "${_nccl_tag}" _nccl_tag) + message(STATUS "Checking out NCCL release tag: ${_nccl_tag} (from ${_nccl_pin_name})") + include(FetchContent) + FetchContent_Declare( + nccl + GIT_REPOSITORY https://github.com/NVIDIA/nccl + GIT_TAG "${_nccl_tag}" + GIT_SHALLOW TRUE + SOURCE_DIR "${_nccl_dir}" + ) + FetchContent_Populate(nccl) + endif() + endif() + endif() +endif() diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index 686f5861960d3..2d68d81c6b367 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -149,6 +149,7 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_PYTORCH_METAL_EXPORT : ${USE_PYTORCH_METAL_EXPORT}") message(STATUS " USE_MPS : ${USE_MPS}") message(STATUS " CAN_COMPILE_METAL : ${CAN_COMPILE_METAL}") + message(STATUS " CAN_COMPILE_METAL_40 : ${CAN_COMPILE_METAL_40}") message(STATUS " USE_MKL : ${CAFFE2_USE_MKL}") if(${CAFFE2_USE_MKL}) message(STATUS " USE_STATIC_MKL : ${USE_STATIC_MKL}") @@ -177,7 +178,11 @@ function(caffe2_print_configuration_summary) if(${USE_NCCL}) message(STATUS " USE_SYSTEM_NCCL : ${USE_SYSTEM_NCCL}") endif() - message(STATUS " Found NVSHMEM : ${NVSHMEM_INCLUDE_DIR}") + if(${USE_ROCM}) + message(STATUS " Found ROCSHMEM : ${ROCSHMEM_INCLUDE_DIR}") + else() + message(STATUS " Found NVSHMEM : ${NVSHMEM_INCLUDE_DIR}") + endif() message(STATUS " USE_NNPACK : ${USE_NNPACK}") message(STATUS " USE_NUMPY : ${USE_NUMPY}") message(STATUS " USE_OBSERVERS : ${USE_OBSERVERS}") diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in index abf5c8149116f..5749c67db112b 100644 --- a/cmake/TorchConfig.cmake.in +++ b/cmake/TorchConfig.cmake.in @@ -155,7 +155,7 @@ if(NOT @BUILD_SHARED_LIBS@) endif() set_target_properties(torch PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIRS}" - CXX_STANDARD 17 + CXX_STANDARD 20 ) if(TORCH_CXX_FLAGS) set_property(TARGET torch PROPERTY INTERFACE_COMPILE_OPTIONS "${TORCH_CXX_FLAGS}") diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 1408da46f25a4..a87b16f5ba889 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -75,7 +75,7 @@ message("Building PyTorch for GPU arch: ${PYTORCH_ROCM_ARCH}") # needed because the find_package call to this module uses the Module mode search # https://cmake.org/cmake/help/latest/command/find_package.html#search-modes if(UNIX) - set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH}) + set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake/hip;${ROCM_PATH}/lib/${CMAKE_LIBRARY_ARCHITECTURE}/cmake/hip ${CMAKE_MODULE_PATH}) else() # Win32 set(CMAKE_MODULE_PATH ${ROCM_PATH}/cmake/ ${CMAKE_MODULE_PATH}) endif() @@ -100,6 +100,9 @@ endmacro() # MODULE argument is added for clarity that CMake is searching # for FindHIP.cmake in Module mode find_package_and_print_version(HIP 1.0 MODULE) +if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + enable_language(HIP) +endif() if(HIP_FOUND) set(PYTORCH_FOUND_HIP TRUE) @@ -205,6 +208,7 @@ if(HIP_FOUND) find_package_and_print_version(rocthrust REQUIRED) find_package_and_print_version(hipsolver REQUIRED) find_package_and_print_version(rocsolver REQUIRED) + find_package_and_print_version(rocshmem) # workaround cmake 4 build issue if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0") message(WARNING "Work around hiprtc cmake failure for cmake >= 4") @@ -219,6 +223,7 @@ if(HIP_FOUND) if(UNIX) find_package_and_print_version(rccl) find_package_and_print_version(hsa-runtime64 REQUIRED) + find_package_and_print_version(rocm_smi REQUIRED) endif() # Optional components. diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake index cb0e4e07c7f65..239641d352c41 100644 --- a/cmake/public/cuda.cmake +++ b/cmake/public/cuda.cmake @@ -78,8 +78,8 @@ endif() message(STATUS "PyTorch: CUDA detected: " ${CUDA_VERSION}) message(STATUS "PyTorch: CUDA nvcc is: " ${CUDA_NVCC_EXECUTABLE}) message(STATUS "PyTorch: CUDA toolkit directory: " ${CUDA_TOOLKIT_ROOT_DIR}) -if(CUDA_VERSION VERSION_LESS 12.0) - message(FATAL_ERROR "PyTorch requires CUDA 12.0 or above.") +if(CUDA_VERSION VERSION_LESS 12.1) + message(FATAL_ERROR "PyTorch requires CUDA 12.1 or above.") endif() if(CUDA_FOUND) @@ -307,6 +307,14 @@ else() endif() # nvrtc +# cuDNN frontend needs libnvrtc symbols, but linking through CUDA::nvrtc pulls +# CUDA::cuda_driver transitively. Keep a driver-free target for cuDNN users and +# reserve caffe2::nvrtc for the stub library that actually needs the driver API. +add_library(caffe2::nvrtc_runtime INTERFACE IMPORTED) +set_property( + TARGET caffe2::nvrtc_runtime PROPERTY INTERFACE_LINK_LIBRARIES + "${CUDA_NVRTC_LIB}") + add_library(caffe2::nvrtc INTERFACE IMPORTED) set_property( TARGET caffe2::nvrtc PROPERTY INTERFACE_LINK_LIBRARIES diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index 731bbe7d21862..a667ad55771b4 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -337,7 +337,7 @@ endmacro() # Usage: # torch_compile_options(lib_name) function(torch_compile_options libname) - set_property(TARGET ${libname} PROPERTY CXX_STANDARD 17) + set_property(TARGET ${libname} PROPERTY CXX_STANDARD 20) # until they can be unified, keep these lists synced with setup.py if(MSVC) @@ -391,12 +391,17 @@ function(torch_compile_options libname) endif() if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") if(NOT USE_CUDA) - # NS: One can compile CUDA code with extra-semi flag as nvcc generates code like + # NS: One can not compile CUDA code with extra-semi flag as nvcc generates code like # namespace MemoryOps_cu_d8602b38_109889 __attribute__((visibility("hidden"))) { }; list(APPEND private_compile_options -Wextra-semi) else() # NVCC + clang15 reports deprecated copies from GPU lambda instantiations list(APPEND private_compile_options -Wno-deprecated-copy) + # NVCC + clang18 reports spurious deprecated deprecated literal operator declaration when there were none + # I.e. failures look like torch/headeronly/util/complex.h:334:40: error: identifier '_if' preceded by whitespace in a literal operator declaration is deprecated + # but if one to look at the source code, there are no space there + list(APPEND private_compile_options -Wno-deprecated-literal-operator) + endif() list(APPEND private_compile_options -Wmove) else() diff --git a/cmake/public/xpu.cmake b/cmake/public/xpu.cmake index b39e31d0ade8a..2731c2842c7f4 100644 --- a/cmake/public/xpu.cmake +++ b/cmake/public/xpu.cmake @@ -37,6 +37,11 @@ torch_xpu_get_arch_list(XPU_ARCH_FLAGS) # propagate to torch-xpu-ops set(TORCH_XPU_ARCH_LIST ${XPU_ARCH_FLAGS}) +# Ensure SYCL device code compiles with C++20 (matching CMAKE_CXX_STANDARD). +# SYCL_FLAGS flows into SYCL_COMPILE_FLAGS in torch-xpu-ops' BuildFlags.cmake +# and is passed directly to icpx on the device compilation command line. +list(APPEND SYCL_FLAGS -std=c++20) + # Ensure USE_XPU is enabled. string(APPEND XPU_HOST_CXX_FLAGS " -DUSE_XPU") string(APPEND XPU_HOST_CXX_FLAGS " -DSYCL_COMPILER_VERSION=${SYCL_COMPILER_VERSION}") diff --git a/docs/cpp/Makefile b/docs/cpp/Makefile index e244432b9fdb3..21379d1856690 100644 --- a/docs/cpp/Makefile +++ b/docs/cpp/Makefile @@ -2,24 +2,41 @@ # You can set these variables from the command line. SPHINXOPTS = -j auto -SPHINXBUILD = sphinx-build +SPHINXBUILD = python -m sphinx SPHINXPROJ = PyTorch SOURCEDIR = source BUILDDIR = build PYCMD = python +DOXYGEN = doxygen # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -.PHONY: help Makefile clean +.PHONY: help Makefile clean doxygen html + +# Generate Doxygen XML (required for Breathe directives) +doxygen: + @echo "Running Doxygen..." + @cd $(SOURCEDIR) && $(DOXYGEN) Doxyfile clean: - @# Clean up sphinx and doxygen build artifacts. + @# Clean up sphinx build artifacts and Doxygen output. + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + @rm -rf $(BUILDDIR)/xml + +# Build HTML with Doxygen XML generation first +html: doxygen @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - @# Clean up exhale generated api. - @echo "Removing everything under generated 'source/api'..." - @rm -rf $(SOURCEDIR)/api + +# Check documentation coverage against curated allowlist and HTML formatting +coverage: html + @echo "Running C++ docs coverage check..." + @$(PYCMD) check_coverage.py + +coverage-only: + @echo "Running C++ docs coverage check (no rebuild)..." + @$(PYCMD) check_coverage.py # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). diff --git a/docs/cpp/README.md b/docs/cpp/README.md new file mode 100644 index 0000000000000..0808a5cb4af0b --- /dev/null +++ b/docs/cpp/README.md @@ -0,0 +1,315 @@ +# PyTorch C++ API Documentation + +This directory contains the source for PyTorch's C++ API documentation, built with +Sphinx, Breathe, and Doxygen. + +## How it works + +The documentation pipeline has three stages: + +``` +C++ headers ──→ Doxygen ──→ XML ──→ Breathe ──→ Sphinx ──→ HTML + ↑ ↑ + Doxyfile Markdown files with + (which headers) Breathe directives (MyST) +``` + +1. **Doxygen** reads C++ headers listed in `source/Doxyfile` `INPUT` and produces + XML in `build/xml/`. +2. **Breathe** is a Sphinx extension that reads Doxygen XML and makes it available + via directives like `` ```{doxygenclass} ClassName `` (MyST Markdown syntax). +3. **Sphinx** builds the final HTML from `.md` files in `source/`, using Breathe + directives to pull in C++ API documentation. The `myst_parser` extension + enables Markdown support. + +Only headers listed in the Doxyfile's `INPUT` are processed. Source files must +explicitly reference each symbol — nothing is auto-generated. + +## Building + +```bash +make html # Build Doxygen XML + Sphinx HTML +make doxygen # Build Doxygen XML only +make clean # Clean build artifacts +``` + +The output is in `build/html/`. + +## Contributing to the C++ docs + +### Adding a new API + +1. **Ensure the header is in the Doxyfile** — check that `source/Doxyfile` `INPUT` + includes the header file or its parent directory. If not, add it: + + ``` + INPUT = ... \ + ../../../path/to/your/header.h + ``` + +2. **Add a Breathe directive to the appropriate `.md` file** under `source/api/`. + See [Which directive to use](#which-directive-to-use) below. + +3. **Run `make html`** and check the output in `build/html/`. + +4. **Run `python check_coverage.py`** to verify your API shows as documented. + +### Which directive to use + +Use Breathe directives to pull documentation from Doxygen XML. These render the +full C++ signature, doc comments, parameters, and members automatically. + +All source files use MyST Markdown syntax (fenced directives with backticks). + +**Classes and structs** — use `doxygenclass` or `doxygenstruct`: + +````markdown +```{doxygenclass} torch::nn::Linear +:members: +:undoc-members: +``` +```` + +- `:members:` shows all public member functions and variables +- `:undoc-members:` includes members without doc comments +- Omit both flags to show only the class description (useful when `:members:` + causes rendering issues) + +**Free functions** — use `doxygenfunction`: + +````markdown +```{doxygenfunction} torch::autograd::grad +``` +```` + +For overloaded functions, Breathe will document all overloads. + +**Macros** — use `doxygendefine`: + +````markdown +```{doxygendefine} TORCH_LIBRARY +``` +```` + +**Typedefs** — use `doxygentypedef`: + +````markdown +```{doxygentypedef} torch::DeviceType +``` +```` + +### When Breathe directives don't work + +Some symbols can't be documented with Breathe directives: + +- **TORCH_MODULE holder classes** (e.g., `Conv2d`, `Linear`): The `TORCH_MODULE()` + macro generates these, but Doxygen can't index them. Document the `*Impl` class + instead (e.g., `Conv2dImpl`) — it contains all the actual methods. + +- **Functions with broken `\rst`/`\endrst` blocks**: Some doc comments use Doxygen's + `\rst` alias to embed RST. When a comment has multiple such blocks, Doxygen + generates malformed XML and Breathe renders raw text. In these cases, either: + - Fix the header to use native Doxygen (`@code{.cpp}`/`@endcode`, `@note`, + `@warning`) instead of `\rst`/`\endrst` + - Use a hand-written Sphinx C++ domain directive as a fallback + +- **Functions with mismatched `\param` names**: If a header's `\param` names don't + match the actual parameter names, Doxygen may fail to index the function. Use a + hand-written `cpp:function` directive instead. + +**Hand-written Sphinx C++ domain directives** (fallback): + +````markdown +```{cpp:function} void torch::autograd::backward(const variable_list& tensors, const variable_list& grad_tensors = {}, std::optional retain_graph = std::nullopt, bool create_graph = false, const variable_list& inputs = {}) + +Computes gradients of given tensors with respect to graph leaves. + +:param tensors: Tensors of which the derivative will be computed. +:param grad_tensors: The "vector" in the Jacobian-vector product. +``` +```` + +These don't pull from Doxygen — you write the signature and docs manually. + +**`{eval-rst}` escape hatch** — if a MyST directive doesn't render correctly, +you can embed raw RST: + +````markdown +```{eval-rst} +.. doxygenclass:: X::A + :members: + :protected-members: + :private-members: +``` +```` + +### Source file structure + +Each `.md` file under `source/api/` documents one topic area using MyST Markdown. +The typical pattern: + +````markdown +# Page Title + +Brief description of this API area. + +## Section Name + +Optional prose explaining usage, with a code example: + +```cpp +#include +auto x = torch::randn({2, 3}); +``` + +```{doxygenclass} torch::nn::SomeClass +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::SomeClassOptions +:members: +:undoc-members: +``` +```` + +**Nesting directives:** When a directive contains other directives or code blocks, +the outer fence must use more backticks than the inner ones: + +`````markdown +````{cpp:class} at::Tensor + +The primary tensor class. + +```{cpp:function} int64_t dim() const + +Returns the number of dimensions. +``` +```` +````` + +### Writing doc comments in C++ headers + +Doxygen extracts documentation from comments in headers. Use `///` style: + +```cpp +/// Applies a linear transformation to the incoming data: :math:`y = xA^T + b`. +/// +/// @code{.cpp} +/// auto linear = torch::nn::Linear(torch::nn::LinearOptions(10, 5)); +/// auto output = linear->forward(input); +/// @endcode +/// +/// @note The weight matrix is transposed compared to the Python API. +class LinearImpl : public Cloneable { +``` + +**Preferred Doxygen commands:** +- `@code{.cpp}` / `@endcode` for code examples +- `@note` for important notes +- `@warning` for warnings +- `@param name` for parameter descriptions +- `@return` for return value descriptions + +**Avoid** `\rst` / `\endrst` blocks — they cause rendering issues when a single +doc comment contains multiple blocks. Use native Doxygen commands instead. + +## Coverage checking + +```bash +make coverage # Build docs + run coverage check +make coverage-only # Run coverage check without rebuilding +``` + +`check_coverage.py` auto-discovers public APIs from Doxygen XML (`build/xml/index.xml`) +and checks which ones have Breathe or Sphinx directives in the source files. + +Reports are written to `cpp_coverage.txt` and `cpp_html_issues.txt`. + +### How coverage discovery works + +The script parses `build/xml/index.xml` to find all classes, structs, functions, +and macros that Doxygen indexed. It then: + +1. Filters out internal symbols using `EXCLUDED_PATTERNS` and `EXCLUDED_SYMBOLS` +2. Re-includes any symbols in `INCLUDED_SYMBOLS` (overrides exclusions) +3. Scans `.md` files for Breathe (`doxygenclass`, `doxygenfunction`, etc.) and + Sphinx C++ domain (`cpp:class`, `cpp:function`, etc.) directives +4. Reports the gap + +### Excluding internal APIs from coverage + +Not everything in Doxygen XML is public API. The script has three exclusion mechanisms: + +**`EXCLUDED_PATTERNS`** — regex patterns for broad categories: + +```python +EXCLUDED_PATTERNS = [ + r".*::detail::.*", # Internal namespaces + r"torch::jit::.*", # Deprecated + r".*::_\w+", # Underscore-prefixed internals +] +``` + +**`EXCLUDED_SYMBOLS`** — exact match for specific symbols: + +```python +EXCLUDED_SYMBOLS = { + "torch::autograd::deleteNode", + "at::native::dataSize", +} +``` + +**Doxyfile `EXCLUDE`** — prevents Doxygen from indexing files entirely: + +``` +EXCLUDE = ../../../torch/csrc/api/include/torch/detail +``` + +### Including "internal" APIs that are actually public + +Some APIs live in internal-looking namespaces but are widely used by external +developers. To track these for documentation coverage, add them to +`INCLUDED_SYMBOLS` in `check_coverage.py`: + +```python +INCLUDED_SYMBOLS: set[str] = { + "c10::IValue", # Used for custom op registration +} +``` + +`INCLUDED_SYMBOLS` takes priority over all exclusions. + +### HTML formatting checks + +The script also checks built HTML for: +- Unresolved Breathe directives ("Cannot find class/struct/function") +- Raw directive text in output (build failures) +- Sphinx "problematic" nodes (broken references) +- Near-empty API pages + +### coverxygen (optional) + +```bash +python check_coverage.py --coverxygen +``` + +This runs [coverxygen](https://github.com/psycofdj/coverxygen) on the Doxygen XML +to measure what percentage of C++ symbols have doc comments in the source code. +This is complementary to file coverage — it tells you which headers need more +doc comments, not which symbols need Sphinx directives. + +## Known issues + +- **`\rst`/`\endrst` rendering** — Some C++ headers use Doxygen's `\rst` alias to + embed RST in doc comments. When a single doc comment has multiple `\rst`/`\endrst` + blocks, Doxygen generates malformed XML and Breathe renders them as raw text. + Avoid using `:members:` on classes with this issue. Use hand-written examples + in the source files instead, or convert the headers to native Doxygen commands. + +- **TORCH_MODULE holder classes** — Doxygen can't index classes generated by the + `TORCH_MODULE()` macro. Use the `*Impl` class name in directives instead. + +- **"Subclassed by" plain text** — Without per-class pages (exhale), Breathe + renders "Subclassed by" lists as plain text instead of links. A `conf.py` hook + (`remove_subclassed_by`) strips these automatically. diff --git a/docs/cpp/check_coverage.py b/docs/cpp/check_coverage.py new file mode 100644 index 0000000000000..bd958671cacbf --- /dev/null +++ b/docs/cpp/check_coverage.py @@ -0,0 +1,782 @@ +#!/usr/bin/env python3 +"""C++ API documentation coverage checker. + +Auto-discovers public C++ APIs from Doxygen XML output and checks which ones +are documented in the RST source files via Breathe or Sphinx C++ domain +directives. + +Uses an exclusion list (EXCLUDED_APIS) to skip internal/detail symbols that +don't need public documentation, rather than maintaining a hardcoded allowlist. + +Additionally checks built HTML for broken formatting (empty pages, +unresolved directives, rendering errors). + +Usage: + python check_coverage.py # RST coverage + HTML checks + python check_coverage.py --coverxygen # also run coverxygen on Doxygen XML +""" + +import argparse +import re +import subprocess +import sys +import xml.etree.ElementTree as ET +from pathlib import Path + + +# ─── Paths ─────────────────────────────────────────────────────────────────── + +SCRIPT_DIR = Path(__file__).resolve().parent +SOURCE_DIR = SCRIPT_DIR / "source" +BUILD_HTML = SCRIPT_DIR / "build" / "html" +BUILD_XML = SCRIPT_DIR / "build" / "xml" +COVERAGE_OUTPUT = SCRIPT_DIR / "cpp_coverage.txt" +HTML_REPORT = SCRIPT_DIR / "cpp_html_issues.txt" + +# ─── Inclusion override ────────────────────────────────────────────────────── +# Symbols that match an exclusion pattern but should still be tracked. +# Use this for "internal" APIs that are widely used as public API. +# Add the fully-qualified symbol name here and it will bypass all exclusions. +INCLUDED_SYMBOLS: set[str] = { + # Example: "c10::IValue" would track it even though c10::IValue is excluded +} + +# ─── Exclusion list ────────────────────────────────────────────────────────── +# Symbols that should NOT be flagged as missing documentation. +# Add internal, detail, or otherwise non-public symbols here. +# Note: INCLUDED_SYMBOLS takes priority over these exclusions. + +EXCLUDED_PATTERNS = [ + # Internal/detail namespaces + r".*::detail::.*", + r".*::detail_::.*", + r"torch::python::.*", + # Underscore-prefixed internal classes + r".*::_\w+", + # Enum helper structs + r"torch::enumtype::.*", + # OptimizerCloneableOptions SFINAE helpers + r"torch::optim::OptimizerCloneableOptions::.*", + # Internal optimizer state/options cloneable helpers + r"torch::optim::OptimizerCloneable.*", + # Error classes (c10 exceptions) + r"c10::.*Error$", + r"c10::ErrorAlwaysShowCppStacktrace", + # Warning internals + r"c10::Warning.*", + r"c10::WarningHandler", + r"c10::WarningUtils::.*", + # c10 IValue internals + r"c10::IValue::.*", + r"c10::IValue", + r"c10::WeakIValue", + r"c10::ivalue::.*", + r"c10::StrongTypePtr", + r"c10::WeakTypePtr", + r"c10::WeakOrStrongTypePtr", + r"c10::WeakOrStrongCompilationUnit", + r"c10::Capsule", + r"c10::OptionalArray", + r"c10::StreamData3", + # OrderedDict::Item (internal helper) + r"torch::OrderedDict::Item", + # ExpandingArray (internal template utility) + r"torch::ExpandingArray.*", + # IMethod (internal) + r"torch::IMethod", + # CustomClassHolder (internal base) + r"torch::CustomClassHolder", + # NodeGuard (internal autograd) + r"torch::autograd::NodeGuard", + # Autograd internals + r"torch::autograd::CppNode", + r"torch::autograd::ExtractVariables", + r"torch::autograd::Node", + r"torch::autograd::Node::.*", + r"torch::autograd::TraceableFunction", + r"torch::autograd::TypeAndSize", + # Sequencer internals + r"torch::data::.*::detail::.*", + # cuDNN descriptor internals + r"at::native::ActivationDescriptor", + r"at::native::ConvolutionDescriptor", + r"at::native::SpatialTransformerDescriptor", + r"at::native::DropoutDescriptor", + r"at::native::RNNDataDescriptor", + r"at::native::DftiDescriptor", + r"at::native::DescriptorDeleter", + r"at::native::DftiDescriptorDeleter", + r"at::native::RNNDescriptor", + # ATen internals + r"at::OptionalTensorRef", + r"at::TensorRef", + # at::cuda internals (allocator, workspace, cublas) + r"at::cuda::WorkspaceMapWithMutex", + r"at::cuda::clearCublasWorkspaces.*", + r"at::cuda::cublas_handle_stream_to_workspace", + r"at::cuda::cublaslt_handle_stream_to_workspace", + r"at::cuda::getCUDABlasLt.*", + r"at::cuda::getCUDADeviceAllocator", + r"at::cuda::getChosenWorkspaceSize", + r"at::cuda::getNumGPUs", + r"at::cuda::is_available", + r"at::cuda::warp_size", + # jit namespace (deprecated) + r"torch::jit::.*", + # Operators that are just operator<< or operator>> + r".*::operator<<", + r".*::operator>>", + r".*::operator==", + r".*::operator!=", + # Internal serialize helpers + r"torch::optim::serialize", + r"torch::optim::detail::.*", + # Reduction enum helpers + r"torch::nn::reduction", + r"torch::nn::log_target", + # Internal module utils + r"torch::nn::modules::utils::.*", + # Internal c10 helpers + r"c10::detail::.*", + r"c10::detail_::.*", + r"c10::makeArrayRef", + r"c10::checkObjectSortSchema", + r"c10::getGreaterThanComparator", + r"c10::getLessThanComparator", + r"c10::value_or_else", + r"c10::warn", + r"c10::GetExceptionString", + # torch::detail + r"torch::detail::.*", + # Internal data shuttle/queue + r"torch::data::detail::.*", + # DataLoaderBase internal types + r"torch::data::DataLoaderBase::.*", + r"torch::data::WorkerException", + r"torch::data::FullDataLoaderOptions", + # Template specializations of Stack + r"torch::data::transforms::Stack< .*>", + # Example partial specialization + r"torch::data::Example< .*>", + # Doxygen internal macros + r"DEFINE_CASE", + r"DEFINE_TAG", + r"COUNT_TAG", + r"TRUTH_TABLE_ENTRY", + r"C10_EXPAND_MSVC_WORKAROUND", + r"TORCH_FORALL_TAGS", + # Non-public torch::nn functions (module stream operators, etc.) + r"torch::nn::operator.*", + # AnyModule/AnyValue internal holders + r"torch::nn::AnyModuleHolder.*", + r"torch::nn::AnyModulePlaceholder", + r"torch::nn::AnyValue.*", + r"torch::nn::NamedAnyModule", + # Internal base classes (users use the derived classes) + r"torch::nn::ConvNdImpl", + r"torch::nn::ConvTransposeNdImpl", + r"torch::nn::BatchNormImplBase", + r"torch::nn::NormImplBase", + r"torch::nn::InstanceNormImpl", + r"torch::nn::MaxPoolImpl", + r"torch::nn::AvgPoolImpl", + r"torch::nn::AdaptiveAvgPoolImpl", + r"torch::nn::AdaptiveMaxPoolImpl", + r"torch::nn::MaxUnpoolImpl", + r"torch::nn::LPPoolImpl", + r"torch::nn::ConstantPadImpl", + r"torch::nn::ReflectionPadImpl", + r"torch::nn::ReplicationPadImpl", + r"torch::nn::ZeroPadImpl", + r"torch::nn::FractionalMaxPoolImpl", + # nn::functions internal namespace + r"torch::nn::functions::.*", + # AdaptiveLogSoftmaxWithLoss (niche, rarely used in C++) + r"torch::nn::AdaptiveLogSoftmaxWithLoss.*", + r"torch::nn::ASMoutput", + # CrossMapLRN2d (niche) + r"torch::nn::CrossMapLRN2d.*", + # _out function variants (documented alongside the main function) + r"torch::special::.*_out", + r"torch::fft::.*_out", + # torch internal helpers + r"torch::InitLambda", + r"torch::dispatch", + r"torch::equal_if_defined", + r"torch::getAllCustomClassesNames", + r"torch::init", + r"torch::make_custom_class", + r"torch::selective_class_", + r"torch::pickle_load", + r"torch::pickle_save", + r"torch::schema", + r"torch::nativert::.*", + # RNNCellOptionsBase (internal base) + r".*::RNNCellOptionsBase", + # Unnamespaced Options structs (indexed without namespace by Doxygen) + r"^[A-Z]\w+Options$", + # Unnamespaced classes without namespace (Doxygen quirk) + r"^TransformerDecoderLayer$", + r"^TransformerDecoderLayerOptions$", + # functional namespace internal options structs + r"functional::.*FuncOptions", +] + +# Specific symbols to exclude (exact match) +EXCLUDED_SYMBOLS = { + # Internal / not useful to document individually + "torch::data::datasets::map", + "torch::data::datasets::make_shared_dataset", + "torch::data::datasets::operator<<", + "torch::data::datasets::operator>>", + "torch::enumtype::get_enum_name", + "torch::enumtype::reduction_get_enum", + "torch::autograd::_wrap_outputs", + "torch::autograd::check_variable_result", + "torch::autograd::CppNode_apply_functional", + "torch::autograd::CppNode_apply_functional_ivalue", + "torch::autograd::forward_ad::enter_dual_level", + "torch::autograd::forward_ad::exit_dual_level", + "torch::autograd::any_variable_requires_grad", + "torch::autograd::collect_next_edges", + "torch::autograd::create_gradient_edge", + "torch::autograd::deleteNode", + "torch::autograd::extract_vars", + "torch::autograd::get_current_node", + "torch::autograd::to_optional", + "torch::autograd::to_output_type", + "torch::nn::parallel::replicate", + "torch::nn::parallel::parallel_apply", + "torch::nn::parallel::data_parallel", + "torch::python::add_module_bindings", + "torch::python::bind_module", + "torch::python::init_bindings", + # at::native cuDNN internals + "at::native::dataSize", + "at::native::fixSizeOneDimStride", + "at::native::operator<<", + "at::native::getCudnnDataTypeFromScalarType", + # c10 cuda pool functions (internal) + "c10::cuda::getStreamFromPool", + "c10::cuda::getStreamFromExternal", + "c10::xpu::getStreamFromPool", + "c10::xpu::getStreamFromExternal", + # c10 private use backend registration (internal) + "c10::get_privateuse1_backend", + "c10::is_privateuse1_backend_registered", + "c10::register_privateuse1_backend", + "c10::isValidDeviceType", + "c10::DeviceTypeName", + # torch::stable::detail internals + "torch::stable::detail::unbox_to_tuple_impl", + "torch::stable::detail::unbox_to_tuple", + "torch::stable::detail::box_from_tuple_impl", + "torch::stable::detail::box_from_tuple", + # torch::stable::accelerator (documented in stable API page) + "torch::stable::accelerator::getCurrentStream", +} + +# Namespaces whose free functions should be checked for documentation +PUBLIC_FUNCTION_NAMESPACES = { + "torch", + "torch::autograd", + "torch::cuda", + "torch::mps", + "torch::xpu", + "torch::fft", + "torch::special", + "torch::nn::functional", + "torch::nn::init", + "torch::nn::utils", + "torch::nn::utils::rnn", + "torch::data", + "torch::stable", + "torch::stable::accelerator", + "c10", + "c10::cuda", + "c10::xpu", + "at::cuda", +} + + +# ─── XML parsing ───────────────────────────────────────────────────────────── + + +def _is_excluded(symbol: str) -> bool: + """Check if a symbol should be excluded from coverage tracking.""" + if symbol in INCLUDED_SYMBOLS: + return False + if symbol in EXCLUDED_SYMBOLS: + return True + for pattern in EXCLUDED_PATTERNS: + if re.fullmatch(pattern, symbol): + return True + return False + + +def _categorize(name: str) -> str: + """Assign a category based on the symbol's namespace.""" + if name.startswith("torch::nn::functional::"): + return "torch::nn::functional" + if name.startswith("torch::nn::init::"): + return "torch::nn::init" + if name.startswith("torch::nn::utils::"): + return "torch::nn::utils" + if name.startswith("torch::nn::"): + # Distinguish modules from other nn symbols + short = name.split("::")[-1] + if short[0].isupper(): + return "torch::nn (modules)" + return "torch::nn" + if name.startswith("torch::optim::"): + return "torch::optim" + if name.startswith("torch::data::"): + return "torch::data" + if name.startswith("torch::autograd::"): + return "torch::autograd" + if name.startswith("torch::serialize::") or name in ("torch::save", "torch::load"): + return "torch::serialize" + if name.startswith("torch::stable::"): + return "torch::stable" + if name.startswith("torch::fft::"): + return "torch::fft" + if name.startswith("torch::special::"): + return "torch::special" + if name.startswith(("torch::cuda::", "torch::mps::", "torch::xpu::")): + return "torch (device)" + if name.startswith("torch::"): + return "torch (core)" + if name.startswith("c10::cuda::"): + return "c10::cuda" + if name.startswith("c10::xpu::"): + return "c10::xpu" + if name.startswith("c10::"): + return "c10" + if name.startswith("at::cuda::"): + return "at::cuda" + if name.startswith("at::"): + return "at" + return "other" + + +def discover_apis_from_xml(xml_dir: Path) -> dict[str, list[tuple[str, str]]]: + """Parse Doxygen index.xml to discover all public APIs. + + Returns dict of category -> list of (symbol, kind). + """ + index_path = xml_dir / "index.xml" + if not index_path.exists(): + print( + f"ERROR: {index_path} not found. Run 'make doxygen' first.", + file=sys.stderr, + ) + sys.exit(1) + + tree = ET.parse(index_path) + root = tree.getroot() + + apis: dict[str, list[tuple[str, str]]] = {} + + # Collect classes and structs + for compound in root.findall("compound"): + kind = compound.get("kind") + if kind not in ("class", "struct"): + continue + name = compound.find("name").text + if _is_excluded(name): + continue + category = _categorize(name) + apis.setdefault(category, []).append((name, kind)) + + # Collect free functions from public namespaces + for compound in root.findall("compound"): + if compound.get("kind") != "namespace": + continue + ns_name = compound.find("name").text + if ns_name not in PUBLIC_FUNCTION_NAMESPACES: + continue + seen_funcs = set() + for member in compound.findall("member"): + if member.get("kind") != "function": + continue + func_name = member.find("name").text + qualified = f"{ns_name}::{func_name}" + if qualified in seen_funcs: + continue # skip overloads + seen_funcs.add(qualified) + if _is_excluded(qualified): + continue + category = _categorize(qualified) + apis.setdefault(category, []).append((qualified, "function")) + + # Collect macros (defines) from file compounds + for compound in root.findall("compound"): + if compound.get("kind") != "file": + continue + for member in compound.findall("member"): + if member.get("kind") != "define": + continue + macro_name = member.find("name").text + # Only track well-known public macros + if macro_name.startswith(("TORCH_LIBRARY", "TORCH_MODULE")): + if _is_excluded(macro_name): + continue + apis.setdefault("torch (macros)", []).append((macro_name, "define")) + + # Sort each category and deduplicate + for category in apis: + apis[category] = sorted(set(apis[category])) + + return apis + + +# ─── Source scanning ───────────────────────────────────────────────────────── + +# RST directives: .. doxygenclass:: torch::nn::ReLU +RST_DIRECTIVE_RE = re.compile( + r"^\.\.\s+doxygen(class|struct|function|typedef|define|enum|namespace)" + r"::\s*(.+?)\s*$", + re.MULTILINE, +) + +RST_CPP_DIRECTIVE_RE = re.compile( + r"^\.\.\s+cpp:(class|struct|function|enum|type)" r"::\s*(.+?)\s*$", + re.MULTILINE, +) + +# MyST directives: ```{doxygenclass} torch::nn::ReLU +MYST_DIRECTIVE_RE = re.compile( + r"^`{3,}\{doxygen(class|struct|function|typedef|define|enum|namespace)\}\s*(.+?)\s*$", + re.MULTILINE, +) + +MYST_CPP_DIRECTIVE_RE = re.compile( + r"^`{3,}\{cpp:(class|struct|function|enum|type)\}\s*(.+?)\s*$", + re.MULTILINE, +) + + +def scan_sources(source_dir: Path) -> set[str]: + """Extract all documented symbols from RST/MyST breathe and cpp domain directives.""" + documented = set() + for src_file in list(source_dir.rglob("*.rst")) + list(source_dir.rglob("*.md")): + content = src_file.read_text(errors="replace") + patterns = ( + RST_DIRECTIVE_RE, + RST_CPP_DIRECTIVE_RE, + MYST_DIRECTIVE_RE, + MYST_CPP_DIRECTIVE_RE, + ) + for pattern in patterns: + for match in pattern.finditer(content): + symbol = match.group(2) + # Strip template prefix + if symbol.startswith("template"): + gt = symbol.find(">") + if gt != -1: + symbol = symbol[gt + 1 :].lstrip() + # Strip function signature + paren = symbol.find("(") + if paren != -1: + symbol = symbol[:paren].rstrip() + documented.add(symbol) + return documented + + +# ─── Coverage report ───────────────────────────────────────────────────────── + + +def generate_coverage_report( + apis: dict[str, list[tuple[str, str]]], documented: set[str] +) -> str: + """Generate a coverage report comparing discovered APIs against RST docs.""" + lines = [] + lines.append("Undocumented C++ objects") + lines.append("=" * 50) + lines.append("") + + total = 0 + total_missing = 0 + section_stats = [] + + for category in sorted(apis.keys()): + symbols = apis[category] + section_missing = [] + for symbol, kind in symbols: + total += 1 + unqualified = symbol.rsplit("::", 1)[-1] + if symbol not in documented and unqualified not in documented: + section_missing.append((symbol, kind)) + total_missing += 1 + + covered = len(symbols) - len(section_missing) + section_stats.append((category, covered, len(symbols))) + + if section_missing: + lines.append(category) + lines.append("-" * len(category)) + for symbol, kind in section_missing: + lines.append(f" * {symbol} ({kind})") + lines.append("") + + # Summary + total_covered = total - total_missing + pct = (total_covered / total * 100) if total else 0 + + lines.append("") + lines.append("=" * 50) + lines.append("Summary") + lines.append("=" * 50) + lines.append("") + lines.append(f"Total APIs discovered: {total}") + lines.append(f"Documented: {total_covered}") + lines.append(f"Missing: {total_missing}") + lines.append(f"Coverage: {pct:.1f}%") + lines.append("") + + # Per-section table + lines.append(f"{'Category':<45} {'Covered':>8} {'Total':>6} {'%':>7}") + lines.append("-" * 70) + for category, covered, section_total in section_stats: + spct = (covered / section_total * 100) if section_total else 0 + lines.append(f"{category:<45} {covered:>8} {section_total:>6} {spct:>6.1f}%") + lines.append("") + + return "\n".join(lines) + + +# ─── HTML checks ───────────────────────────────────────────────────────────── + +BROKEN_PATTERNS = [ + ( + re.compile(r"Cannot find (?:class|struct|function|file)", re.IGNORECASE), + "unresolved breathe directive", + ), + ( + re.compile(r"Unable to resolve (?:class|struct|function)", re.IGNORECASE), + "unresolved breathe directive (ambiguous overload)", + ), + ( + re.compile(r"doxygenclass:|doxygenfunction:|doxygenstruct:", re.IGNORECASE), + "raw directive text in output", + ), + ( + re.compile(r"", re.IGNORECASE), + "Sphinx problematic node (broken reference)", + ), + ( + re.compile(r"System Message:", re.IGNORECASE), + "Sphinx system message (build error)", + ), +] + +MIN_CONTENT_LENGTH = 500 + + +def check_html_output(build_dir: Path) -> str: + """Check built HTML for broken formatting and empty pages.""" + issues = [] + + if not build_dir.exists(): + return "ERROR: build/html directory not found. Run 'make html' first.\n" + + for html_file in sorted(build_dir.rglob("*.html")): + rel = html_file.relative_to(build_dir) + if rel.name in ("search.html", "genindex.html", "objects.inv"): + continue + + try: + content = html_file.read_text(errors="replace") + except Exception as e: + issues.append((str(rel), f"cannot read: {e}")) + continue + + for pattern, description in BROKEN_PATTERNS: + matches = pattern.findall(content) + if matches: + issues.append((str(rel), f"{description} ({len(matches)}x)")) + + if str(rel).startswith("api/"): + text = re.sub(r"<[^>]+>", "", content) + text = re.sub(r"\s+", " ", text).strip() + if len(text) < MIN_CONTENT_LENGTH: + issues.append((str(rel), f"possibly empty page ({len(text)} chars)")) + + lines = [] + lines.append("HTML Formatting Check") + lines.append("=" * 50) + lines.append("") + + if not issues: + lines.append("No issues found.") + else: + lines.append(f"Found {len(issues)} issue(s):") + lines.append("") + lines.append(f"{'File':<55} Issue") + lines.append("-" * 90) + for filepath, issue in issues: + lines.append(f"{filepath:<55} {issue}") + + lines.append("") + return "\n".join(lines) + + +# ─── coverxygen integration ───────────────────────────────────────────────── + + +def run_coverxygen(xml_dir: Path) -> str: + """Run coverxygen on Doxygen XML output for doc-comment coverage.""" + lines = [] + lines.append("Coverxygen Report (Doxygen doc-comment coverage)") + lines.append("=" * 50) + lines.append("") + + if not xml_dir.exists(): + lines.append("ERROR: build/xml directory not found. Run 'make doxygen' first.") + return "\n".join(lines) + + coverxygen_cmd = None + for cmd in [ + ["coverxygen", "--version"], + [sys.executable, "-m", "coverxygen", "--version"], + ]: + try: + subprocess.run(cmd, capture_output=True, check=True) + coverxygen_cmd = cmd[:-1] + break + except (FileNotFoundError, subprocess.CalledProcessError): + continue + if coverxygen_cmd is None: + lines.append("coverxygen not installed. Install with: pip install coverxygen") + lines.append("") + lines.append("Once installed, coverxygen analyzes Doxygen XML to report what") + lines.append("percentage of C++ symbols have doc comments in the source code.") + lines.append("This is complementary to the RST coverage check above.") + lines.append("") + lines.append("Usage:") + lines.append( + f" coverxygen --xml-dir {xml_dir} --src-dir ../../ --output coverxygen.info" + ) + lines.append(" # Then use lcov/genhtml to visualize:") + lines.append( + " genhtml --no-function-coverage coverxygen.info -o coverxygen_html" + ) + return "\n".join(lines) + + try: + result = subprocess.run( + coverxygen_cmd + + [ + "--xml-dir", + str(xml_dir), + "--src-dir", + str(SCRIPT_DIR / ".." / ".."), + "--output", + "-", + "--kind", + "class,struct,function", + "--scope", + "public", + "--exclude", + ".*/build/.*", + "--exclude", + ".*/detail/.*", + "--exclude", + ".*/nativert/.*", + "--exclude", + ".*/stable/library\\.h", + ], + capture_output=True, + text=True, + timeout=120, + ) + if result.returncode == 0: + total = 0 + documented_count = 0 + for line in result.stdout.splitlines(): + if line.startswith("DA:"): + total += 1 + parts = line.split(",") + if len(parts) >= 2 and parts[1].strip() != "0": + documented_count += 1 + pct = (documented_count / total * 100) if total else 0 + lines.append(f"Symbols scanned: {total}") + lines.append(f"With doc comments: {documented_count}") + lines.append(f"Coverage: {pct:.1f}%") + lines.append("") + lines.append("Full lcov output saved to: coverxygen.info") + (SCRIPT_DIR / "coverxygen.info").write_text(result.stdout) + else: + lines.append(f"coverxygen failed (exit {result.returncode}):") + lines.append(result.stderr[:500]) + except subprocess.TimeoutExpired: + lines.append("coverxygen timed out after 120s") + except Exception as e: + lines.append(f"coverxygen error: {e}") + + lines.append("") + return "\n".join(lines) + + +# ─── Main ──────────────────────────────────────────────────────────────────── + + +def main(): + parser = argparse.ArgumentParser(description="C++ docs coverage checker") + parser.add_argument( + "--coverxygen", + action="store_true", + help="Also run coverxygen on Doxygen XML for doc-comment coverage", + ) + parser.add_argument( + "--html-only", + action="store_true", + help="Only run HTML formatting checks", + ) + args = parser.parse_args() + + reports = [] + + if not args.html_only: + # Phase 1: Discover APIs from Doxygen XML + print("Discovering APIs from Doxygen XML...") + apis = discover_apis_from_xml(BUILD_XML) + total_apis = sum(len(v) for v in apis.values()) + print(f" Found {total_apis} public APIs across {len(apis)} categories") + + # Phase 2: Scan RST for documented symbols + print("Scanning sources for breathe directives...") + documented = scan_sources(SOURCE_DIR) + print(f" Found {len(documented)} documented symbols") + + coverage_report = generate_coverage_report(apis, documented) + reports.append(coverage_report) + + COVERAGE_OUTPUT.write_text(coverage_report) + print(f" Coverage report written to: {COVERAGE_OUTPUT}") + + # Phase 3: HTML checks + print("Checking HTML output for formatting issues...") + html_report = check_html_output(BUILD_HTML) + reports.append(html_report) + HTML_REPORT.write_text(html_report) + print(f" HTML report written to: {HTML_REPORT}") + + # Phase 4: coverxygen (optional) + if args.coverxygen: + print("Running coverxygen...") + cov_report = run_coverxygen(BUILD_XML) + reports.append(cov_report) + + # Print everything + print() + print("=" * 60) + for report in reports: + print(report) + print() + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/docs/cpp/cpp_coverage.txt b/docs/cpp/cpp_coverage.txt new file mode 100644 index 0000000000000..4fce09f90a646 --- /dev/null +++ b/docs/cpp/cpp_coverage.txt @@ -0,0 +1,292 @@ +Undocumented C++ objects +================================================== + +torch (device) +-------------- + * torch::cuda::device_count (function) + * torch::cuda::is_available (function) + * torch::cuda::manual_seed (function) + * torch::cuda::manual_seed_all (function) + * torch::cuda::synchronize (function) + * torch::mps::commit (function) + * torch::mps::get_command_buffer (function) + * torch::mps::get_dispatch_queue (function) + * torch::mps::is_available (function) + * torch::mps::manual_seed (function) + * torch::mps::synchronize (function) + +torch (macros) +-------------- + * TORCH_MODULE (define) + * TORCH_MODULE_IMPL (define) + +torch::data +----------- + * torch::data::datasets::ChunkDatasetOptions (struct) + * torch::data::datasets::TensorDataset (struct) + * torch::data::make_data_loader (function) + * torch::data::samplers::BatchSize (struct) + * torch::data::samplers::CustomBatchRequest (struct) + +torch::fft +---------- + * torch::fft::fft (function) + * torch::fft::fft2 (function) + * torch::fft::fftfreq (function) + * torch::fft::fftn (function) + * torch::fft::fftshift (function) + * torch::fft::hfft (function) + * torch::fft::hfft2 (function) + * torch::fft::hfftn (function) + * torch::fft::ifft (function) + * torch::fft::ifft2 (function) + * torch::fft::ifftn (function) + * torch::fft::ifftshift (function) + * torch::fft::ihfft (function) + * torch::fft::ihfft2 (function) + * torch::fft::ihfftn (function) + * torch::fft::irfft (function) + * torch::fft::irfft2 (function) + * torch::fft::irfftn (function) + * torch::fft::rfft (function) + * torch::fft::rfft2 (function) + * torch::fft::rfftfreq (function) + * torch::fft::rfftn (function) + +torch::nn (modules) +------------------- + * torch::nn::AdaptiveAvgPoolOptions (struct) + * torch::nn::AdaptiveMaxPoolOptions (struct) + * torch::nn::AvgPoolOptions (struct) + * torch::nn::BatchNormOptions (struct) + * torch::nn::BilinearOptions (struct) + * torch::nn::CELUOptions (struct) + * torch::nn::ConstantPad1dImpl (class) + * torch::nn::ConstantPad2dImpl (class) + * torch::nn::ConstantPad3dImpl (class) + * torch::nn::ConstantPadOptions (struct) + * torch::nn::ConvOptions (struct) + * torch::nn::ConvTransposeOptions (struct) + * torch::nn::CosineSimilarityImpl (class) + * torch::nn::CosineSimilarityOptions (struct) + * torch::nn::DropoutOptions (struct) + * torch::nn::ELUOptions (struct) + * torch::nn::EmbeddingBagFromPretrainedOptions (struct) + * torch::nn::EmbeddingBagOptions (struct) + * torch::nn::EmbeddingFromPretrainedOptions (struct) + * torch::nn::EmbeddingOptions (struct) + * torch::nn::FlattenOptions (struct) + * torch::nn::FoldImpl (class) + * torch::nn::FractionalMaxPoolOptions (struct) + * torch::nn::Functional (class) + * torch::nn::GELUOptions (struct) + * torch::nn::GLUOptions (struct) + * torch::nn::HardshrinkOptions (struct) + * torch::nn::HardtanhOptions (struct) + * torch::nn::InstanceNormOptions (struct) + * torch::nn::KLDivLossOptions (struct) + * torch::nn::L1LossOptions (struct) + * torch::nn::LPPoolOptions (struct) + * torch::nn::LayerNormOptions (struct) + * torch::nn::LeakyReLUOptions (struct) + * torch::nn::LinearOptions (struct) + * torch::nn::LogSoftmaxOptions (struct) + * torch::nn::MaxPoolOptions (struct) + * torch::nn::MaxUnpoolOptions (struct) + * torch::nn::ModuleDictImpl (class) + * torch::nn::ModuleListImpl (class) + * torch::nn::MultiheadAttentionOptions (struct) + * torch::nn::PReLUOptions (struct) + * torch::nn::PairwiseDistanceImpl (class) + * torch::nn::PairwiseDistanceOptions (struct) + * torch::nn::PixelShuffleImpl (struct) + * torch::nn::PixelUnshuffleImpl (struct) + * torch::nn::RNNOptions (struct) + * torch::nn::RReLUOptions (struct) + * torch::nn::ReLU6Options (struct) + * torch::nn::ReLUOptions (struct) + * torch::nn::ReflectionPad1dImpl (class) + * torch::nn::ReflectionPad2dImpl (class) + * torch::nn::ReflectionPad3dImpl (class) + * torch::nn::ReflectionPadOptions (struct) + * torch::nn::ReplicationPad1dImpl (class) + * torch::nn::ReplicationPad2dImpl (class) + * torch::nn::ReplicationPad3dImpl (class) + * torch::nn::ReplicationPadOptions (struct) + * torch::nn::SELUOptions (struct) + * torch::nn::SequentialImpl (class) + * torch::nn::SoftmaxOptions (struct) + * torch::nn::SoftminOptions (struct) + * torch::nn::SoftplusOptions (struct) + * torch::nn::SoftshrinkOptions (struct) + * torch::nn::ThresholdOptions (struct) + * torch::nn::TransformerDecoderOptions (struct) + * torch::nn::TransformerEncoderLayerOptions (struct) + * torch::nn::TransformerEncoderOptions (struct) + * torch::nn::TransformerOptions (struct) + * torch::nn::UnflattenOptions (struct) + * torch::nn::UnfoldImpl (class) + * torch::nn::UpsampleImpl (class) + * torch::nn::ZeroPad1dImpl (class) + * torch::nn::ZeroPad2dImpl (class) + * torch::nn::ZeroPad3dImpl (class) + * torch::nn::ZeroPadOptions (struct) + +torch::nn::functional +--------------------- + * torch::nn::functional::AlphaDropoutFuncOptions (struct) + * torch::nn::functional::BatchNormFuncOptions (struct) + * torch::nn::functional::ConvFuncOptions (struct) + * torch::nn::functional::ConvTransposeFuncOptions (struct) + * torch::nn::functional::DropoutFuncOptions (struct) + * torch::nn::functional::EmbeddingBagFuncOptions (struct) + * torch::nn::functional::EmbeddingFuncOptions (struct) + * torch::nn::functional::FeatureAlphaDropoutFuncOptions (struct) + * torch::nn::functional::GridSampleFuncOptions (struct) + * torch::nn::functional::GumbelSoftmaxFuncOptions (struct) + * torch::nn::functional::InstanceNormFuncOptions (struct) + * torch::nn::functional::InterpolateFuncOptions (struct) + * torch::nn::functional::LogSoftmaxFuncOptions (struct) + * torch::nn::functional::MaxUnpoolFuncOptions (struct) + * torch::nn::functional::MultiheadAttentionForwardFuncOptions (struct) + * torch::nn::functional::PadFuncOptions (struct) + * torch::nn::functional::RReLUFuncOptions (struct) + * torch::nn::functional::SoftmaxFuncOptions (struct) + * torch::nn::functional::SoftminFuncOptions (struct) + * torch::nn::functional::adaptive_max_pool2d_with_indices (function) + * torch::nn::functional::adaptive_max_pool3d_with_indices (function) + * torch::nn::functional::fractional_max_pool2d_with_indices (function) + * torch::nn::functional::fractional_max_pool3d_with_indices (function) + * torch::nn::functional::multi_head_attention_forward (function) + +torch::nn::init +--------------- + * torch::nn::init::calculate_gain (function) + * torch::nn::init::constant_ (function) + * torch::nn::init::dirac_ (function) + * torch::nn::init::eye_ (function) + * torch::nn::init::kaiming_normal_ (function) + * torch::nn::init::kaiming_uniform_ (function) + * torch::nn::init::normal_ (function) + * torch::nn::init::ones_ (function) + * torch::nn::init::orthogonal_ (function) + * torch::nn::init::sparse_ (function) + * torch::nn::init::uniform_ (function) + * torch::nn::init::xavier_normal_ (function) + * torch::nn::init::xavier_uniform_ (function) + * torch::nn::init::zeros_ (function) + +torch::nn::utils +---------------- + * torch::nn::utils::clip_grad_norm_ (function) + * torch::nn::utils::clip_grad_value_ (function) + * torch::nn::utils::parameters_to_vector (function) + * torch::nn::utils::rnn::invert_permutation (function) + * torch::nn::utils::vector_to_parameters (function) + +torch::optim +------------ + * torch::optim::AdagradOptions (struct) + * torch::optim::AdagradParamState (struct) + * torch::optim::AdamOptions (struct) + * torch::optim::AdamParamState (struct) + * torch::optim::AdamWOptions (struct) + * torch::optim::AdamWParamState (struct) + * torch::optim::LBFGSOptions (struct) + * torch::optim::LBFGSParamState (struct) + * torch::optim::RMSpropOptions (struct) + * torch::optim::RMSpropParamState (struct) + * torch::optim::SGDOptions (struct) + * torch::optim::SGDParamState (struct) + +torch::special +-------------- + * torch::special::airy_ai (function) + * torch::special::bessel_j0 (function) + * torch::special::bessel_j1 (function) + * torch::special::bessel_y0 (function) + * torch::special::bessel_y1 (function) + * torch::special::chebyshev_polynomial_t (function) + * torch::special::chebyshev_polynomial_u (function) + * torch::special::chebyshev_polynomial_v (function) + * torch::special::chebyshev_polynomial_w (function) + * torch::special::digamma (function) + * torch::special::entr (function) + * torch::special::erf (function) + * torch::special::erfc (function) + * torch::special::erfcx (function) + * torch::special::erfinv (function) + * torch::special::exp2 (function) + * torch::special::expit (function) + * torch::special::expm1 (function) + * torch::special::gammainc (function) + * torch::special::gammaincc (function) + * torch::special::gammaln (function) + * torch::special::hermite_polynomial_h (function) + * torch::special::hermite_polynomial_he (function) + * torch::special::i0 (function) + * torch::special::i0e (function) + * torch::special::i1 (function) + * torch::special::i1e (function) + * torch::special::laguerre_polynomial_l (function) + * torch::special::legendre_polynomial_p (function) + * torch::special::log1p (function) + * torch::special::log_ndtr (function) + * torch::special::log_softmax (function) + * torch::special::logit (function) + * torch::special::logsumexp (function) + * torch::special::modified_bessel_i0 (function) + * torch::special::modified_bessel_i1 (function) + * torch::special::modified_bessel_k0 (function) + * torch::special::modified_bessel_k1 (function) + * torch::special::multigammaln (function) + * torch::special::ndtr (function) + * torch::special::ndtri (function) + * torch::special::polygamma (function) + * torch::special::psi (function) + * torch::special::round (function) + * torch::special::scaled_modified_bessel_k0 (function) + * torch::special::scaled_modified_bessel_k1 (function) + * torch::special::shifted_chebyshev_polynomial_t (function) + * torch::special::shifted_chebyshev_polynomial_u (function) + * torch::special::shifted_chebyshev_polynomial_v (function) + * torch::special::shifted_chebyshev_polynomial_w (function) + * torch::special::sinc (function) + * torch::special::softmax (function) + * torch::special::spherical_bessel_j0 (function) + * torch::special::xlog1py (function) + * torch::special::xlogy (function) + * torch::special::zeta (function) + + +================================================== +Summary +================================================== + +Total APIs discovered: 699 +Documented: 472 +Missing: 227 +Coverage: 67.5% + +Category Covered Total % +---------------------------------------------------------------------- +at 4 4 100.0% +at::cuda 7 7 100.0% +c10 10 10 100.0% +c10::cuda 9 9 100.0% +c10::xpu 4 4 100.0% +other 1 1 100.0% +torch (core) 4 4 100.0% +torch (device) 5 16 31.2% +torch (macros) 3 5 60.0% +torch::autograd 3 3 100.0% +torch::data 29 34 85.3% +torch::fft 0 22 0.0% +torch::nn (modules) 243 319 76.2% +torch::nn::functional 100 124 80.6% +torch::nn::init 0 14 0.0% +torch::nn::utils 0 5 0.0% +torch::optim 13 25 52.0% +torch::serialize 4 4 100.0% +torch::special 0 56 0.0% +torch::stable 33 33 100.0% diff --git a/docs/cpp/cpp_html_issues.txt b/docs/cpp/cpp_html_issues.txt new file mode 100644 index 0000000000000..754a4bd2bc4c8 --- /dev/null +++ b/docs/cpp/cpp_html_issues.txt @@ -0,0 +1,4 @@ +HTML Formatting Check +================================================== + +No issues found. diff --git a/docs/cpp/source/Doxyfile b/docs/cpp/source/Doxyfile index 72e83733b6b4c..748f7a69ea266 100644 --- a/docs/cpp/source/Doxyfile +++ b/docs/cpp/source/Doxyfile @@ -34,6 +34,7 @@ INPUT = ../../../aten/src/ATen/ATen.h \ ../../../aten/src/ATen/core/ivalue.h \ ../../../aten/src/ATen/core/ScalarType.h \ ../../../aten/src/ATen/cuda/CUDAContext.h \ + ../../../aten/src/ATen/cuda/CUDAContextLight.h \ ../../../aten/src/ATen/cudnn/Descriptors.h \ ../../../aten/src/ATen/cudnn/Handles.h \ ../../../aten/src/ATen/cudnn/Types.h \ @@ -48,7 +49,9 @@ INPUT = ../../../aten/src/ATen/ATen.h \ ../../../build/aten/src/ATen/Functions.h \ ../../../build/aten/src/ATen/core/TensorBody.h \ ../../../c10/core/Device.h \ + ../../../c10/core/DeviceGuard.h \ ../../../c10/core/DeviceType.h \ + ../../../c10/core/Stream.h \ ../../../c10/util/Half.h \ ../../../c10/util/ArrayRef.h \ ../../../c10/util/OptionalArrayRef.h \ @@ -100,7 +103,8 @@ GENERATE_XML = YES # Set to NO if you do not want the Doxygen program listing included. # NOTE: setting to NO may result in unrecovered file relationships # (which file defined which compound) -XML_PROGRAMLISTING = YES +# Setting to NO to prevent inline code from polluting documentation +XML_PROGRAMLISTING = NO ################################################################################ # Doxygen preprocessor / parser control. # ################################################################################ @@ -129,8 +133,16 @@ EXTRACT_ALL = YES EXTRACT_PACKAGE = YES EXTRACT_STATIC = YES CASE_SENSE_NAMES = NO -EXCLUDE_SYMBOLS = caffe2::* cereal* DL* TH* cudnn* std::* +EXCLUDE_SYMBOLS = caffe2::* cereal* DL* TH* cudnn* std::* torch::nested* # EXCLUDE_SYMBOLS = c10::* caffe2::* cereal* DL* TH* cudnn* std::* + +# Don't include inline source code in documentation +INLINE_SOURCES = NO +# Don't include code bodies in detailed description +SOURCE_BROWSER = NO +# Don't reference source files +REFERENCES_LINK_SOURCE = NO +VERBATIM_HEADERS = NO ################################################################################ # Docstring control / customization. # ################################################################################ @@ -148,10 +160,21 @@ JAVADOC_AUTOBRIEF = YES # * # * \endrst # */ +# Or with /// style: +# /// \rst +# /// .. code-block:: cpp +# /// +# /// int main() { +# /// return 0; +# /// } +# /// +# /// \endrst +# # NOTE: # 1. \rst and \endrst must be on their own line. -# 2. leading-asterisk required. -ALIASES = "rst=\verbatim embed:rst:leading-asterisk" +# 2. For /// comments, use leading-slashes; for /** */ use leading-asterisk +# 3. PyTorch uses /// style comments primarily +ALIASES = "rst=\verbatim embed:rst:leading-slashes" ALIASES += "endrst=\endverbatim" ################################################################################ # Warning suppression. # diff --git a/docs/cpp/source/api/aten/accessors.md b/docs/cpp/source/api/aten/accessors.md new file mode 100644 index 0000000000000..83fefce724e7e --- /dev/null +++ b/docs/cpp/source/api/aten/accessors.md @@ -0,0 +1,49 @@ +--- +myst: + html_meta: + description: Tensor accessors in PyTorch C++ — efficient element-wise access to tensor data without overhead. + keywords: PyTorch, C++, tensor accessor, packed_accessor, data access +--- + +# Tensor Accessors + +For element-wise operations in custom kernels, use *accessors* to avoid +dynamic dispatch overhead. + +## CPU Accessors + +```cpp +torch::Tensor foo = torch::rand({12, 12}); + +// Create accessor - validates type and dimensions once +auto foo_a = foo.accessor(); + +float trace = 0; +for (int i = 0; i < foo_a.size(0); i++) { + trace += foo_a[i][i]; +} +``` + +## CUDA Packed Accessors + +For CUDA kernels, use *packed accessors* which copy metadata instead of +pointing to it: + +```cpp +__global__ void kernel(torch::PackedTensorAccessor64 foo, float* trace) { + int i = threadIdx.x; + gpuAtomicAdd(trace, foo[i][i]); +} + +torch::Tensor foo = torch::rand({12, 12}).cuda(); +auto foo_a = foo.packed_accessor64(); + +float trace = 0; +kernel<<<1, 12>>>(foo_a, &trace); +``` + +```{tip} + +Use `PackedTensorAccessor32` and `packed_accessor32` for 32-bit indexing, +which is faster on CUDA but may overflow for large tensors. +``` diff --git a/docs/cpp/source/api/aten/creation.md b/docs/cpp/source/api/aten/creation.md new file mode 100644 index 0000000000000..27ddcd616db55 --- /dev/null +++ b/docs/cpp/source/api/aten/creation.md @@ -0,0 +1,235 @@ +--- +myst: + html_meta: + description: Tensor creation functions in PyTorch C++ — zeros, ones, randn, arange, from_blob, and more. + keywords: PyTorch, C++, tensor creation, zeros, ones, randn, arange, from_blob +--- + +# Tensor Creation + +Factory functions create new tensors with different initialization patterns. +All factory functions follow a general schema: + +```cpp +torch::(, , ) +``` + +## Available Factory Functions + +- `torch::zeros` - Tensor filled with zeros +- `torch::ones` - Tensor filled with ones +- `torch::empty` - Uninitialized tensor +- `torch::full` - Tensor filled with a single value +- `torch::rand` - Uniform random values on [0, 1) +- `torch::randn` - Standard normal distribution +- `torch::randint` - Random integers in a range +- `torch::arange` - Sequence of integers +- `torch::linspace` - Linearly spaced values +- `torch::logspace` - Logarithmically spaced values +- `torch::eye` - Identity matrix +- `torch::randperm` - Random permutation of integers + +## Specifying a Size + +Functions that do not require specific arguments can be invoked with just a +size. For example, the following line creates a vector with 5 components: + +```cpp +torch::Tensor tensor = torch::ones(5); +``` + +An `IntArrayRef` is constructed by specifying the size along each dimension in +curly braces. For example, `{2, 3}` for a matrix with two rows and three +columns, `{3, 4, 5}` for a three-dimensional tensor: + +```cpp +torch::Tensor tensor = torch::randn({3, 4, 5}); +assert(tensor.sizes() == std::vector{3, 4, 5}); +``` + +You can also pass an `std::vector` instead of curly braces. +Use `tensor.size(i)` to access a single dimension. + +## Passing Function-Specific Parameters + +Some factory functions accept additional parameters. For example, `randint` +takes an upper bound on the value for the integers it generates: + +```cpp +torch::Tensor tensor = torch::randint(/*high=*/10, {5, 5}); + +// With a lower bound +torch::Tensor tensor = torch::randint(/*low=*/3, /*high=*/10, {5, 5}); +``` + +```{tip} + +The size always follows the function-specific arguments. +``` + +```{attention} + +Some functions like `arange` do not need a size at all, since it is fully +determined by the function-specific arguments (the range bounds). +``` + +## Configuring Properties with TensorOptions + +`TensorOptions` configures the data type, layout, device, and +`requires_grad` of a new tensor. The construction axes are: + +- `dtype`: the data type of the elements (e.g. `kFloat32`, `kInt64`) +- `layout`: either `kStrided` (dense) or `kSparse` +- `device`: a compute device (e.g. `kCPU`, `kCUDA`) +- `requires_grad`: whether to track gradients + +Allowed values: + +- `dtype`: `kUInt8`, `kInt8`, `kInt16`, `kInt32`, `kInt64`, + `kFloat32`, `kFloat64` +- `layout`: `kStrided`, `kSparse` +- `device`: `kCPU`, or `kCUDA` (with an optional device index) +- `requires_grad`: `true` or `false` + +```{tip} + +Rust-style shorthands exist for dtypes, like `kF32` instead of +`kFloat32`. See +[torch/types.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/api/include/torch/types.h) +for the full list. +``` + +Here is an example of creating a `TensorOptions` object: + +```cpp +auto options = + torch::TensorOptions() + .dtype(torch::kFloat32) + .layout(torch::kStrided) + .device(torch::kCUDA, 1) + .requires_grad(true); + +torch::Tensor tensor = torch::full({3, 4}, /*value=*/123, options); + +assert(tensor.dtype() == torch::kFloat32); +assert(tensor.layout() == torch::kStrided); +assert(tensor.device().type() == torch::kCUDA); +assert(tensor.device().index() == 1); +assert(tensor.requires_grad()); +``` + +**Defaults:** Any axis you omit takes its default value: `kFloat32` for dtype, +`kStrided` for layout, `kCPU` for device, and `false` for +`requires_grad`. This means you can omit `TensorOptions` entirely: + +```cpp +// A 32-bit float, strided, CPU tensor that does not require a gradient. +torch::Tensor tensor = torch::randn({3, 4}); +``` + +**Shorthand syntax:** For each axis there is a free function in the `torch::` +namespace (`torch::dtype()`, `torch::device()`, `torch::layout()`, +`torch::requires_grad()`). Each returns a `TensorOptions` object that can +be further refined with builder methods: + +```cpp +// These are equivalent: +torch::ones(10, torch::TensorOptions().dtype(torch::kFloat32)) +torch::ones(10, torch::dtype(torch::kFloat32)) + +// Chaining: +torch::ones(10, torch::dtype(torch::kFloat32).layout(torch::kStrided)) +``` + +**Implicit construction:** `TensorOptions` is implicitly constructible from +individual values, so when only one axis differs from the default you can +write: + +```cpp +torch::ones(10, torch::kFloat32) +``` + +Putting it all together, a C++ tensor creation call mirrors the Python +equivalent closely: + +```python +# Python +torch.randn(3, 4, dtype=torch.float32, device=torch.device('cuda', 1), requires_grad=True) +``` + +```cpp +// C++ +torch::randn({3, 4}, torch::dtype(torch::kFloat32).device(torch::kCUDA, 1).requires_grad(true)) +``` + +## Using Externally Created Data + +If you already have tensor data allocated in memory (CPU or CUDA), use +`from_blob` to view that memory as a `Tensor`: + +```cpp +float data[] = {1, 2, 3, 4, 5, 6}; +torch::Tensor tensor = torch::from_blob(data, {2, 3}); +``` + +```{note} + +Tensors created with `from_blob` cannot be resized because ATen does not +own the memory. +``` + +## Tensor Conversion + +Use `to()` to convert tensors between dtypes and devices. The conversion +creates a new tensor and does not occur in-place: + +```cpp +torch::Tensor source = torch::randn({2, 3}, torch::kInt64); + +// Convert dtype +torch::Tensor float_tensor = source.to(torch::kFloat32); + +// Move to GPU (default CUDA device) +torch::Tensor gpu_tensor = float_tensor.to(torch::kCUDA); + +// Specific GPU device +torch::Tensor gpu1_tensor = float_tensor.to(torch::Device(torch::kCUDA, 1)); + +// Async copy +torch::Tensor async_tensor = gpu_tensor.to(torch::kCPU, /*non_blocking=*/true); +``` + +```{attention} + +The result of the conversion is a new tensor pointing to new memory, +unrelated to the source tensor. +``` + +## Scalars and Zero-Dimensional Tensors + +`Scalar` represents a single dynamically-typed number. Like a `Tensor`, +`Scalar` is dynamically typed and can hold any of ATen's number types. +Scalars can be implicitly constructed from C++ number types: + +```cpp +namespace torch { +Tensor addmm(Scalar beta, const Tensor & self, + Scalar alpha, const Tensor & mat1, + const Tensor & mat2); +Scalar sum(const Tensor & self); +} // namespace torch + +// Usage +torch::Tensor a = ...; +torch::Tensor b = ...; +torch::Tensor c = ...; +torch::Tensor r = torch::addmm(1.0, a, .5, b, c); +``` + +Zero-dimensional tensors hold a single value and can reference elements in +larger tensors: + +```cpp +torch::Tensor matrix = torch::rand({10, 20}); +matrix[1][2] = 4; // matrix[1][2] is a zero-dimensional tensor +``` diff --git a/docs/cpp/source/api/aten/index.md b/docs/cpp/source/api/aten/index.md new file mode 100644 index 0000000000000..e9980c4cca39f --- /dev/null +++ b/docs/cpp/source/api/aten/index.md @@ -0,0 +1,76 @@ +--- +myst: + html_meta: + description: ATen C++ API — the foundational tensor library powering PyTorch, including tensor operations, types, and CUDA support. + keywords: PyTorch, C++, ATen, tensor, operations, CUDA +--- + +# ATen: Tensor Library + +ATen (A Tensor Library) is the foundational tensor and mathematical operation +library on which all of PyTorch is built. It provides the core `Tensor` class +and hundreds of mathematical operations that work on tensors. + +**When to use ATen directly:** + +- When writing low-level tensor operations or custom kernels +- When you need direct access to tensor data and metadata +- When working with the PyTorch internals or extending PyTorch + +**Basic usage:** + +```cpp +#include + +// Create tensors +at::Tensor a = at::ones({2, 3}); +at::Tensor b = at::randn({2, 3}); + +// Perform operations +at::Tensor c = a + b; +at::Tensor d = at::matmul(a.t(), b); + +// Move to GPU +if (at::cuda::is_available()) { + at::Tensor gpu_tensor = c.to(at::kCUDA); +} +``` + +For most applications, prefer using the higher-level `torch::` namespace +(see {doc}`../nn/index`, {doc}`../optim/index`) which provides a more user-friendly API. + +## Header Files + +The following headers are part of the ATen public API: + +- `ATen/ATen.h` - Main ATen header +- `ATen/Backend.h` - Backend enumeration +- `ATen/core/Tensor.h` - Tensor class +- `ATen/core/ivalue.h` - IValue type (see {doc}`../c10/index`) +- `ATen/core/ScalarType.h` - Data type definitions +- `ATen/TensorOptions.h` - Tensor creation options +- `ATen/Scalar.h` - Scalar type +- `ATen/Layout.h` - Tensor layout +- `ATen/DeviceGuard.h` - Device context management +- `ATen/native/TensorShape.h` - Tensor shape operations +- `ATen/cuda/CUDAContext.h` - CUDA context (see {doc}`../cuda/index`) +- `ATen/cudnn/Descriptors.h` - cuDNN descriptors +- `ATen/mkl/Descriptors.h` - MKL descriptors + +```{note} + +The core `at::Tensor` class is defined in a generated header file +(`TensorBody.h`) that only exists after building PyTorch. The documentation +below describes the API manually. +``` + +## ATen Categories + +```{toctree} +:maxdepth: 1 + +tensor +creation +indexing +accessors +``` diff --git a/docs/cpp/source/api/aten/indexing.md b/docs/cpp/source/api/aten/indexing.md new file mode 100644 index 0000000000000..63481cb1cd5a4 --- /dev/null +++ b/docs/cpp/source/api/aten/indexing.md @@ -0,0 +1,193 @@ +--- +myst: + html_meta: + description: Tensor indexing in PyTorch C++ — Slice, None, Ellipsis, boolean masks, and advanced indexing with index() and index_put_(). + keywords: PyTorch, C++, tensor indexing, Slice, index, index_put_, boolean mask, fancy indexing +--- + +# Tensor Indexing + +The PyTorch C++ API provides tensor indexing similar to Python. Use +`torch::indexing` namespace for index types: + +```cpp +using namespace torch::indexing; +``` + +The main difference from Python is that instead of using the `[]` operator, +the C++ API uses the `index` and `index_put_` methods: + +- `torch::Tensor::index` — read elements +- `torch::Tensor::index_put_` — write elements + +## Index Types + +The `TensorIndex` class accepts six types of indices via implicit constructors: + +```{list-table} +:widths: 25 35 40 +:header-rows: 1 + +* - Type + - C++ + - Python equivalent +* - None (unsqueeze) + - `None` + - `None` +* - Ellipsis + - `Ellipsis` or `"..."` + - `...` +* - Integer + - `0`, `1`, `-1` + - `0`, `1`, `-1` +* - Boolean + - `true`, `false` + - `True`, `False` +* - Slice + - `Slice(start, stop, step)` + - `start:stop:step` +* - Tensor + - `torch::tensor({0, 2})` + - `torch.tensor([0, 2])` +``` + +## Getter Operations + +```{list-table} +:widths: 40 60 +:header-rows: 1 + +* - Python + - C++ +* - `tensor[0]` + - `tensor.index({0})` +* - `tensor[-1]` + - `tensor.index({-1})` +* - `tensor[1, 2]` + - `tensor.index({1, 2})` +* - `tensor[1, :, 3]` + - `tensor.index({1, Slice(), 3})` +* - `tensor[None]` + - `tensor.index({None})` +* - `tensor[:, None]` + - `tensor.index({Slice(), None})` +* - `tensor[...]` + - `tensor.index({Ellipsis})` or `tensor.index({"..."})` +* - `tensor[..., 0]` + - `tensor.index({Ellipsis, 0})` +* - `tensor[1::2]` + - `tensor.index({Slice(1, None, 2)})` +* - `tensor[True]` + - `tensor.index({true})` +* - `tensor[torch.tensor([1, 2])]` + - `tensor.index({torch::tensor({1, 2})})` +* - `tensor[bool_mask]` + - `tensor.index({bool_mask})` +* - `tensor[:, torch.tensor([[0,1],[4,3]])]` + - `tensor.index({Slice(), torch::tensor({{0,1},{4,3}})})` +* - `tensor[cond > 0]` + - `tensor.index({cond > 0})` +``` + +## Setter Operations + +```{list-table} +:widths: 40 60 +:header-rows: 1 + +* - Python + - C++ +* - `tensor[0] = 1` + - `tensor.index_put_({0}, 1)` +* - `tensor[1, 2] = 1` + - `tensor.index_put_({1, 2}, 1)` +* - `tensor[1] = torch.arange(5)` + - `tensor.index_put_({1}, torch::arange(5))` +* - `tensor[1::2] = 1` + - `tensor.index_put_({Slice(1, None, 2)}, 1)` +* - `tensor[0, 1::2] = torch.tensor([3., 4.])` + - `tensor.index_put_({0, Slice(1, None, 2)}, torch::tensor({3., 4.}))` +* - `tensor[...] = 0` + - `tensor.index_put_({Ellipsis}, 0)` +* - `tensor[None] = value` + - `tensor.index_put_({None}, value)` +* - `tensor[bool_mask] = 0` + - `tensor.index_put_({bool_mask}, 0)` +* - `tensor[torch.tensor([0, 2])] = value` + - `tensor.index_put_({torch::tensor({0, 2})}, value)` +* - `tensor[1:2, torch.tensor([1,2])] = 0` + - `tensor.index_put_({Slice(1, 2), torch::tensor({1, 2})}, 0)` +``` + +The `index_put_` method also accepts an optional `accumulate` parameter. +When `true`, values are added to existing values instead of replacing them: + +```cpp +tensor.index_put_({mask}, values, /*accumulate=*/true); +``` + +## Slice Syntax + +The `Slice` constructor signature is: + +```cpp +Slice( + std::optional start = std::nullopt, + std::optional stop = std::nullopt, + std::optional step = std::nullopt); +``` + +Pass `None` for open-ended bounds: + +```{list-table} +:widths: 30 70 +:header-rows: 1 + +* - Python + - C++ +* - `:` or `::` + - `Slice()` or `Slice(None, None)` +* - `1:` + - `Slice(1, None)` +* - `:3` + - `Slice(None, 3)` +* - `1:3` + - `Slice(1, 3)` +* - `1:3:2` + - `Slice(1, 3, 2)` +* - `::2` + - `Slice(None, None, 2)` +``` + +## Full Example + +```cpp +#include + +using namespace torch::indexing; + +auto tensor = torch::arange(2 * 3 * 4).reshape({2, 3, 4}); + +// Basic indexing +auto row = tensor.index({0}); // tensor[0] +auto elem = tensor.index({1, 2, 3}); // tensor[1, 2, 3] + +// Slicing +auto sliced = tensor.index({Slice(), Slice(0, 2)}); // tensor[:, 0:2] + +// None (unsqueeze) and Ellipsis +auto unsqueezed = tensor.index({None}); // tensor[None] +auto last_dim = tensor.index({Ellipsis, -1}); // tensor[..., -1] + +// Boolean mask indexing +auto mask = tensor > 10; +auto selected = tensor.index({mask}); // tensor[tensor > 10] + +// Integer tensor (fancy) indexing +auto idx = torch::tensor({0, 2}); +auto gathered = tensor.index({Slice(), idx}); // tensor[:, [0, 2]] + +// Setting values +tensor.index_put_({0, Slice(), 0}, 99); // tensor[0, :, 0] = 99 +tensor.index_put_({mask}, 0); // tensor[tensor > 10] = 0 +``` diff --git a/docs/cpp/source/api/aten/tensor.md b/docs/cpp/source/api/aten/tensor.md new file mode 100644 index 0000000000000..227c94e13d7ff --- /dev/null +++ b/docs/cpp/source/api/aten/tensor.md @@ -0,0 +1,288 @@ +--- +myst: + html_meta: + description: at::Tensor class reference — the primary tensor type in PyTorch C++ with creation, indexing, device, and dtype APIs. + keywords: PyTorch, C++, Tensor, at::Tensor, TensorOptions, ScalarType, DeviceGuard +--- + +# Tensor Class + +The `at::Tensor` class is the primary tensor class in ATen, representing +a multi-dimensional array with a specific data type and device. + +## Tensor + +```{cpp:class} at::Tensor + +The primary tensor class in ATen. Represents a multi-dimensional array +with a specific data type and device. +``` + +```{cpp:function} Tensor() + +Default constructor. Creates an undefined tensor. +``` + +```{cpp:function} int64_t dim() const + +Returns the number of dimensions of the tensor. +``` + +```{cpp:function} int64_t size(int64_t dim) const + +Returns the size of the tensor at the given dimension. +``` + +```{cpp:function} IntArrayRef sizes() const + +Returns the sizes of all dimensions. +``` + +```{cpp:function} IntArrayRef strides() const + +Returns the strides of all dimensions. +``` + +```{cpp:function} ScalarType scalar_type() const + +Returns the data type of the tensor. +``` + +```{cpp:function} Device device() const + +Returns the device where the tensor is stored. +``` + +```{cpp:function} bool is_cuda() const + +Returns true if the tensor is on a CUDA device. +``` + +```{cpp:function} bool is_cpu() const + +Returns true if the tensor is on CPU. +``` + +```{cpp:function} bool requires_grad() const + +Returns true if gradients need to be computed for this tensor. +``` + +```{cpp:function} Tensor& requires_grad_(bool requires_grad = true) + +Sets whether gradients should be computed for this tensor. +``` + +```{cpp:function} Tensor to(Device device) const + +Returns a tensor on the specified device. +``` + +```{cpp:function} Tensor to(ScalarType dtype) const + +Returns a tensor with the specified data type. +``` + +```{cpp:function} Tensor contiguous() const + +Returns a contiguous tensor with the same data. +``` + +```{cpp:function} void* data_ptr() const + +Returns a pointer to the underlying data. +``` + +**Example:** + +```cpp +#include + +at::Tensor a = at::ones({2, 2}, at::kInt); +at::Tensor b = at::randn({2, 2}); +auto c = a + b.to(at::kInt); +``` + +## TensorOptions + +```{cpp:class} at::TensorOptions + +A class to specify options for tensor creation, including dtype, device, +layout, and requires_grad. +``` + +```{cpp:function} TensorOptions() + +Default constructor. +``` + +```{cpp:function} TensorOptions dtype(ScalarType dtype) const + +Returns options with the specified data type. +``` + +```{cpp:function} TensorOptions device(Device device) const + +Returns options with the specified device. +``` + +```{cpp:function} TensorOptions layout(Layout layout) const + +Returns options with the specified layout. +``` + +```{cpp:function} TensorOptions requires_grad(bool requires_grad) const + +Returns options with the specified requires_grad setting. +``` + +**Example:** + +```cpp +auto options = at::TensorOptions() + .dtype(at::kFloat) + .device(at::kCUDA, 0) + .requires_grad(true); + +at::Tensor tensor = at::zeros({3, 4}, options); +``` + +## Scalar + +```{cpp:class} at::Scalar + +Represents a scalar value that can be converted to various numeric types. +``` + +```{cpp:function} Scalar(int64_t v) + +Construct from an integer. +``` + +```{cpp:function} Scalar(double v) + +Construct from a double. +``` + +```{cpp:function} template T to() const + +Convert to the specified type. +``` + +```{cpp:function} bool isIntegral(bool includeBool = false) const + +Returns true if the scalar is an integral type. +``` + +```{cpp:function} bool isFloatingPoint() const + +Returns true if the scalar is a floating point type. +``` + +## ScalarType + +```{cpp:enum-class} at::ScalarType + +Enumeration of data types supported by tensors. +``` + +```{cpp:enumerator} Byte + +8-bit unsigned integer (uint8_t) +``` + +```{cpp:enumerator} Char + +8-bit signed integer (int8_t) +``` + +```{cpp:enumerator} Short + +16-bit signed integer (int16_t) +``` + +```{cpp:enumerator} Int + +32-bit signed integer (int32_t) +``` + +```{cpp:enumerator} Long + +64-bit signed integer (int64_t) +``` + +```{cpp:enumerator} Half + +16-bit floating point (float16) +``` + +```{cpp:enumerator} Float + +32-bit floating point (float) +``` + +```{cpp:enumerator} Double + +64-bit floating point (double) +``` + +```{cpp:enumerator} Bool + +Boolean type +``` + +```{cpp:enumerator} BFloat16 + +Brain floating point (bfloat16) +``` + +Convenience constants: + +- `at::kByte`, `at::kChar`, `at::kShort`, `at::kInt`, `at::kLong` +- `at::kHalf`, `at::kFloat`, `at::kDouble`, `at::kBFloat16` +- `at::kBool` + +## DeviceGuard + +```{doxygenclass} c10::DeviceGuard +:members: +:undoc-members: +``` + +**Example:** + +```cpp +{ + c10::DeviceGuard guard(at::Device(at::kCUDA, 1)); + // Operations here run on CUDA device 1 + auto tensor = at::zeros({2, 2}); +} +// Previous device is restored +``` + +## Layout + +```{cpp:enum-class} at::Layout + +Enumeration of tensor memory layouts. +``` + +```{cpp:enumerator} Strided + +Dense tensor with strides. +``` + +```{cpp:enumerator} Sparse + +Sparse tensor (COO format). +``` + +```{cpp:enumerator} SparseCsr + +Sparse tensor in CSR format. +``` + +```{cpp:enumerator} SparseCsc + +Sparse tensor in CSC format. +``` diff --git a/docs/cpp/source/api/autograd/custom_functions.md b/docs/cpp/source/api/autograd/custom_functions.md new file mode 100644 index 0000000000000..475c975cdd9aa --- /dev/null +++ b/docs/cpp/source/api/autograd/custom_functions.md @@ -0,0 +1,95 @@ +--- +myst: + html_meta: + description: Custom autograd functions in PyTorch C++ — defining forward and backward passes with torch::autograd::Function. + keywords: PyTorch, C++, autograd, custom function, forward, backward, Function +--- + +# Custom Autograd Functions + +PyTorch allows you to define custom autograd functions with custom forward +and backward implementations. + +## Function Base Class + +```{doxygenstruct} torch::autograd::Function +:members: +:undoc-members: +``` + +## AutogradContext + +```{doxygenstruct} torch::autograd::AutogradContext +:members: +:undoc-members: +``` + +## Creating Custom Functions + +To create a custom autograd function, inherit from `torch::autograd::Function` +and implement the static `forward` and `backward` methods: + +**Example:** + +```cpp +class MyReLU : public torch::autograd::Function { + public: + static torch::Tensor forward( + torch::autograd::AutogradContext* ctx, + torch::Tensor input) { + ctx->save_for_backward({input}); + return input.clamp_min(0); + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_outputs) { + auto saved = ctx->get_saved_variables(); + auto input = saved[0]; + auto grad_output = grad_outputs[0]; + auto grad_input = grad_output * (input > 0).to(grad_output.dtype()); + return {grad_input}; + } +}; + +// Usage +auto output = MyReLU::apply(input); +``` + +## Custom Kernels and AutoDispatchBelowADInplaceOrView + +For users implementing custom kernels who want to redispatch below `Autograd` dispatch +keys, use `at::AutoDispatchBelowADInplaceOrView` instead of `InferenceMode`: + +```cpp +class ROIAlignFunction : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& rois, + double spatial_scale, + int64_t pooled_height, + int64_t pooled_width, + int64_t sampling_ratio, + bool aligned) { + ctx->saved_data["spatial_scale"] = spatial_scale; + ctx->saved_data["pooled_height"] = pooled_height; + ctx->saved_data["pooled_width"] = pooled_width; + ctx->saved_data["sampling_ratio"] = sampling_ratio; + ctx->saved_data["aligned"] = aligned; + ctx->saved_data["input_shape"] = input.sizes(); + ctx->save_for_backward({rois}); + + at::AutoDispatchBelowADInplaceOrView guard; + auto result = roi_align( + input, rois, spatial_scale, pooled_height, + pooled_width, sampling_ratio, aligned); + return {result}; + } +}; +``` + +For customized inplace and view kernels, see the +[custom kernel tutorial](https://pytorch.org/tutorials/advanced/cpp_extension.html#backward-pass) +for more details. diff --git a/docs/cpp/source/api/autograd/gradient.md b/docs/cpp/source/api/autograd/gradient.md new file mode 100644 index 0000000000000..6f595496fb7a0 --- /dev/null +++ b/docs/cpp/source/api/autograd/gradient.md @@ -0,0 +1,86 @@ +--- +myst: + html_meta: + description: PyTorch C++ API for computing gradients — torch::autograd::backward and torch::autograd::grad functions for automatic differentiation. + keywords: PyTorch, C++, autograd, backward, grad, gradient, automatic differentiation +--- + +# Gradient Computation + +PyTorch provides functions for computing gradients of tensors with respect +to graph leaves. + +## Gradient Functions + +```{cpp:function} void torch::autograd::backward(const variable_list& tensors, const variable_list& grad_tensors = {}, std::optional retain_graph = std::nullopt, bool create_graph = false, const variable_list& inputs = {}) + +Computes the sum of gradients of given tensors with respect to graph leaves. + +The graph is differentiated using the chain rule. If any of `tensors` +are non-scalar (i.e. their data has more than one element) and require +gradient, then the Jacobian-vector product would be computed, in this case +the function additionally requires specifying `grad_tensors`. It should be a +sequence of matching length, that contains the "vector" in the +Jacobian-vector product, usually the gradient of the differentiated function +w.r.t. corresponding tensors (`torch::Tensor()` is an acceptable value for +all tensors that don't need gradient tensors). + +This function accumulates gradients in the leaves — you might need to zero +them before calling it. + +:param tensors: Tensors of which the derivative will be computed. +:param grad_tensors: The "vector" in the Jacobian-vector product, usually + gradients w.r.t. each element of corresponding tensors. + `torch::Tensor()` values can be specified for scalar Tensors or ones + that don't require grad. If a `torch::Tensor()` value would be + acceptable for all grad_tensors, then this argument is optional. +:param retain_graph: If `false`, the graph used to compute the grad will + be freed. Note that in nearly all cases setting this option to `true` + is not needed and often can be worked around in a much more efficient + way. Defaults to the value of `create_graph`. +:param create_graph: If `true`, graph of the derivative will be + constructed, allowing to compute higher order derivative products. + Defaults to `false`. +:param inputs: Inputs w.r.t. which the gradient will be accumulated into + `at::Tensor::grad`. All other Tensors will be ignored. If not + provided, the gradient is accumulated into all the leaf Tensors that + were used to compute `tensors`. +``` + +```{doxygenfunction} torch::autograd::grad +``` + +**Example:** + +```cpp +#include + +auto x = torch::randn({2, 2}, torch::requires_grad()); +auto y = x * x; +auto z = y.sum(); + +// Compute gradients +z.backward(); +std::cout << x.grad() << std::endl; + +// Or use grad() for specific outputs +auto grads = torch::autograd::grad({z}, {x}); +``` + +## Tensor Gradient Methods + +Tensors have built-in methods for gradient computation: + +```cpp +// Enable gradient tracking +auto x = torch::randn({2, 2}).requires_grad_(true); + +// Check if gradient is required +bool needs_grad = x.requires_grad(); + +// Access the gradient after backward +auto grad = x.grad(); + +// Detach from computation graph +auto x_detached = x.detach(); +``` diff --git a/docs/cpp/source/api/autograd/index.md b/docs/cpp/source/api/autograd/index.md new file mode 100644 index 0000000000000..37b3cb79f747e --- /dev/null +++ b/docs/cpp/source/api/autograd/index.md @@ -0,0 +1,56 @@ +--- +myst: + html_meta: + description: PyTorch C++ Autograd API — automatic differentiation for tensor computations. + keywords: PyTorch, C++, autograd, automatic differentiation, gradient +--- + +# Autograd: Automatic Differentiation + +PyTorch's autograd system provides automatic differentiation for all operations +on tensors. It records operations on tensors to build a computational graph, +then computes gradients automatically via backpropagation. + +**When to use Autograd:** + +- When training neural networks (gradients are computed automatically) +- When implementing custom backward passes for specialized operations +- When you need fine-grained control over gradient computation + +**Basic usage:** + +```cpp +#include + +// Create tensor with gradient tracking +auto x = torch::randn({2, 2}, torch::requires_grad()); +auto y = x * x; +auto z = y.sum(); + +// Compute gradients via backpropagation +z.backward(); +std::cout << x.grad() << std::endl; // dz/dx = 2x + +// Disable gradient tracking for inference +{ + torch::NoGradGuard no_grad; + auto result = model->forward(input); // No gradients computed +} +``` + +## Header Files + +- `torch/csrc/autograd/autograd.h` - High-level autograd API +- `torch/csrc/autograd/function.h` - Custom autograd functions +- `torch/csrc/autograd/grad_mode.h` - Gradient computation modes +- `torch/csrc/api/include/torch/autograd.h` - C++ Frontend autograd + +## Autograd Categories + +```{toctree} +:maxdepth: 1 + +gradient +custom_functions +modes +``` diff --git a/docs/cpp/source/api/autograd/modes.md b/docs/cpp/source/api/autograd/modes.md new file mode 100644 index 0000000000000..f1a281deae4f8 --- /dev/null +++ b/docs/cpp/source/api/autograd/modes.md @@ -0,0 +1,166 @@ +--- +myst: + html_meta: + description: Gradient mode guards in PyTorch C++ — NoGradGuard and InferenceMode for disabling gradient computation. + keywords: PyTorch, C++, NoGradGuard, InferenceMode, no_grad, inference, RAII guard +--- + +# Gradient Modes + +PyTorch provides RAII guards to control gradient computation behavior. + +## NoGradGuard + +```{cpp:class} torch::NoGradGuard + +RAII guard that disables gradient computation within its scope. +``` + +```{cpp:function} NoGradGuard() + +Disables gradient computation. +``` + +```{cpp:function} ~NoGradGuard() + +Restores previous gradient mode. +``` + +**Example:** + +```cpp +{ + torch::NoGradGuard no_grad; + // No gradients computed in this scope + auto result = model->forward(input); +} +``` + +## InferenceMode + +`c10::InferenceMode` is a RAII guard analogous to `NoGradMode` designed for use +when you are certain your operations will have no interactions with autograd +(e.g., model inference). Compared to `NoGradMode`, code run under this mode gets +better performance by disabling autograd-related work like view tracking and version +counter bumps. However, tensors created inside `InferenceMode` have more limitations +when interacting with the autograd system. + +```{cpp:class} c10::InferenceMode + +RAII guard that enables inference mode for optimized inference. +This is more efficient than NoGradGuard for inference-only workloads. +``` + +```{cpp:function} explicit InferenceMode(bool enabled = true) + +Enables or disables inference mode. +``` + +**Inference Tensors:** + +`InferenceMode` can be enabled for a given block of code. Inside `InferenceMode`, +all newly allocated (non-view) tensors are marked as inference tensors. Inference tensors: + +- Do not have a version counter, so an error will be raised if you try to read their version + (e.g., because you saved this tensor for backward). +- Are immutable outside `InferenceMode`. An error will be raised if you try to: + + - Mutate their data outside InferenceMode. + - Mutate them to `requires_grad=True` outside InferenceMode. + - To work around this, make a clone outside `InferenceMode` to get a normal tensor before mutating. + +A non-view tensor is an inference tensor if and only if it was allocated inside `InferenceMode`. +A view tensor is an inference tensor if and only if it is a view of an inference tensor. + +**Performance Guarantees:** + +Inside an `InferenceMode` block: + +- Like `NoGradMode`, all operations do not record `grad_fn` even if their inputs have + `requires_grad=True`. This applies to both inference tensors and normal tensors. +- View operations on inference tensors do not perform view tracking. View and non-view + inference tensors are indistinguishable. +- Inplace operations on inference tensors are guaranteed not to do a version bump. + +For more implementation details, see the [RFC-0011-InferenceMode](https://github.com/pytorch/rfcs/pull/17). + +**Basic Example:** + +```cpp +{ + c10::InferenceMode guard; + // Optimized inference without gradient tracking + auto result = model->forward(input); +} +``` + +**Inference Workload Example:** + +```cpp +c10::InferenceMode guard; +model.load_jit(saved_model); +auto inputs = preprocess_tensors(data); +auto out = model.forward(inputs); +auto outputs = postprocess_tensors(out); +``` + +**Nested InferenceMode:** + +Unlike some other guards, `InferenceMode` can be nested with different enabled/disabled states: + +```cpp +{ + c10::InferenceMode guard(true); + // InferenceMode is on + { + c10::InferenceMode guard(false); + // InferenceMode is off + } + // InferenceMode is on +} +// InferenceMode is off +``` + +## InferenceMode vs NoGradMode + +`InferenceMode` is preferred over `NoGradMode` for pure inference workloads because +it provides better performance. Key differences: + +- Both guards affect tensor execution to skip work not related to inference, but + `InferenceMode` also affects tensor creation while `NoGradMode` doesn't. +- Tensors created inside `InferenceMode` are marked as inference tensors with + certain limitations that apply after exiting `InferenceMode`. +- `InferenceMode` can be nested with enabled/disabled states. + +## Migrating from AutoNonVariableTypeMode + +The legacy `AutoNonVariableTypeMode` guard (now renamed to +`AutoDispatchBelowADInplaceOrView`) was commonly used for inference workloads +but is unsafe — it can silently bypass safety checks and produce wrong results. + +- **For inference-only workloads** (e.g. loading a pretrained JIT model and + running inference in C++ runtime), use `c10::InferenceMode` as a drop-in + replacement. It preserves the performance characteristics while providing + correctness guarantees. + +- **For custom autograd kernels** that need to redispatch below the Autograd + dispatch key, use `AutoDispatchBelowADInplaceOrView` instead: + + ```cpp + class ROIAlignFunction : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const torch::autograd::Variable& input, + const torch::autograd::Variable& rois, + double spatial_scale, int64_t pooled_height, + int64_t pooled_width, int64_t sampling_ratio, bool aligned) { + ctx->saved_data["spatial_scale"] = spatial_scale; + ctx->save_for_backward({rois}); + at::AutoDispatchBelowADInplaceOrView guard; + auto result = roi_align(input, rois, spatial_scale, + pooled_height, pooled_width, sampling_ratio, aligned); + return {result}; + } + }; + ``` diff --git a/docs/cpp/source/api/c10/device.md b/docs/cpp/source/api/c10/device.md new file mode 100644 index 0000000000000..229082935c11b --- /dev/null +++ b/docs/cpp/source/api/c10/device.md @@ -0,0 +1,114 @@ +--- +myst: + html_meta: + description: Device and DeviceType in PyTorch C++ — c10::Device for specifying CPU, CUDA, XPU, and other backends. + keywords: PyTorch, C++, Device, DeviceType, CPU, CUDA, XPU, MPS, c10 +--- + +# Device and DeviceType + +PyTorch provides device abstractions for writing code that works across +CPU, CUDA, and other backends. + +## Device + +```{doxygenstruct} c10::Device +:members: +:undoc-members: +``` + +**Example:** + +```cpp +c10::Device cpu_device(c10::kCPU); +c10::Device cuda_device(c10::kCUDA, 0); // CUDA device 0 + +if (cuda_device.is_cuda()) { + std::cout << "Using CUDA device " << cuda_device.index() << std::endl; +} +``` + +## DeviceType + +```{cpp:enum-class} c10::DeviceType + +Enumeration of supported device types. +``` + +```{cpp:enumerator} CPU = 0 + +CPU device. +``` + +```{cpp:enumerator} CUDA = 1 + +NVIDIA CUDA GPU. +``` + +```{cpp:enumerator} HIP = 6 + +AMD HIP GPU. +``` + +```{cpp:enumerator} XLA = 9 + +XLA / TPU. +``` + +```{cpp:enumerator} Vulkan = 10 + +Vulkan GPU. +``` + +```{cpp:enumerator} Metal = 11 + +Apple Metal GPU. +``` + +```{cpp:enumerator} XPU = 12 + +Intel XPU GPU. +``` + +```{cpp:enumerator} MPS = 13 + +Apple Metal Performance Shaders. +``` + +```{cpp:enumerator} Meta = 14 + +Meta tensors (shape only, no data). +``` + +```{cpp:enumerator} HPU = 15 + +Habana HPU. +``` + +```{cpp:enumerator} Lazy = 17 + +Lazy tensors. +``` + +```{cpp:enumerator} IPU = 18 + +Graphcore IPU. +``` + +```{cpp:enumerator} MTIA = 19 + +Meta training and inference accelerator. +``` + +```{cpp:enumerator} PrivateUse1 = 20 + +Custom backend registered via `c10::register_privateuse1_backend()`. +``` + +Convenience constants: + +- `c10::kCPU`, `c10::kCUDA`, `c10::kHIP` +- `c10::kXLA`, `c10::kVulkan`, `c10::kMetal` +- `c10::kXPU`, `c10::kMPS`, `c10::kMeta` +- `c10::kHPU`, `c10::kLazy`, `c10::kIPU`, `c10::kMTIA` +- `c10::kPrivateUse1` diff --git a/docs/cpp/source/api/c10/guards.md b/docs/cpp/source/api/c10/guards.md new file mode 100644 index 0000000000000..fba83afbcee57 --- /dev/null +++ b/docs/cpp/source/api/c10/guards.md @@ -0,0 +1,52 @@ +--- +myst: + html_meta: + description: Device and stream guards in PyTorch C++ — RAII guards for managing current device and stream context. + keywords: PyTorch, C++, DeviceGuard, StreamGuard, RAII, device management +--- + +# Device Guards + +C10 provides device-agnostic RAII guards for managing the current device +context. These guards work across all backends (CUDA, XPU, etc.) and +automatically restore the previous device when they go out of scope. + +For backend-specific guards, see {doc}`../cuda/guards` and {doc}`../xpu/index`. + +## DeviceGuard + +```{doxygenclass} c10::DeviceGuard +:members: +:undoc-members: +``` + +**Example:** + +```cpp +#include + +{ + c10::DeviceGuard guard(c10::Device(c10::kCUDA, 1)); + // All operations here run on CUDA device 1 +} +// Previous device is restored +``` + +## OptionalDeviceGuard + +```{doxygenclass} c10::OptionalDeviceGuard +:members: +:undoc-members: +``` + +**Example:** + +```cpp +#include + +c10::OptionalDeviceGuard guard; +if (use_gpu) { + guard.reset_device(c10::Device(c10::kCUDA, 0)); +} +// Guard only restores device if it was set +``` diff --git a/docs/cpp/source/api/c10/index.md b/docs/cpp/source/api/c10/index.md new file mode 100644 index 0000000000000..dd34df87409fa --- /dev/null +++ b/docs/cpp/source/api/c10/index.md @@ -0,0 +1,66 @@ +--- +myst: + html_meta: + description: C10 core library API — fundamental types, device management, streams, and utilities used throughout PyTorch. + keywords: PyTorch, C++, c10, core, Device, Stream, types +--- + +# C10: Core Utilities + +C10 (Caffe2 + ATen = C10) is the core library that provides fundamental +utilities and data types used throughout PyTorch. It contains device +abstractions, memory management utilities, and common data structures. + +**When to use C10:** + +- When working with device-agnostic code (CPU, CUDA, XPU, etc.) +- When you need efficient array views without copying data +- When handling optional values or type-erased containers +- When writing code that needs to work across different PyTorch backends + +**Basic usage:** + +```cpp +#include +#include + +// Device abstraction +c10::Device device(c10::kCUDA, 0); +if (device.is_cuda()) { + std::cout << "Using CUDA device " << device.index() << std::endl; +} + +// Efficient array views (no copy) +std::vector sizes = {3, 4, 5}; +c10::ArrayRef sizes_ref(sizes); + +// Optional values +c10::optional maybe_dim = 2; +int64_t dim = maybe_dim.value_or(-1); +``` + +## Header Files + +- `c10/core/Device.h` - Device abstraction +- `c10/core/DeviceType.h` - Device type enumeration +- `c10/util/ArrayRef.h` - Non-owning array reference +- `c10/util/OptionalArrayRef.h` - Optional array reference +- `c10/util/Optional.h` - Optional value wrapper +- `c10/util/Half.h` - Half-precision float +- `c10/util/Exception.h` - Error checking macros (`TORCH_CHECK`, `TORCH_INTERNAL_ASSERT`, etc.) +- `c10/cuda/CUDAGuard.h` - CUDA device guards (see {doc}`../cuda/index`) +- `c10/cuda/CUDAStream.h` - CUDA stream management (see {doc}`../cuda/index`) +- `c10/xpu/XPUStream.h` - Intel XPU stream management +- `ATen/core/ivalue.h` - IValue for TorchScript interop + +## C10 Categories + +```{toctree} +:maxdepth: 1 + +device +guards +streams +types +utilities +``` diff --git a/docs/cpp/source/api/c10/streams.md b/docs/cpp/source/api/c10/streams.md new file mode 100644 index 0000000000000..e87e3d3fdad7a --- /dev/null +++ b/docs/cpp/source/api/c10/streams.md @@ -0,0 +1,34 @@ +--- +myst: + html_meta: + description: Stream API in PyTorch C++ — c10::Stream for asynchronous execution on devices. + keywords: PyTorch, C++, Stream, c10::Stream, asynchronous, device +--- + +# Streams + +`c10::Stream` is the device-agnostic base stream class. It provides a +common interface for working with streams across different backends +(CUDA, XPU, etc.). + +For backend-specific stream APIs, see {doc}`../cuda/streams` and {doc}`../xpu/index`. + +## Stream + +```{doxygenclass} c10::Stream +:members: +:undoc-members: +``` + +**Example:** + +```cpp +#include + +// Streams are typically obtained from backend-specific APIs +auto cuda_stream = c10::cuda::getCurrentCUDAStream(); + +// c10::Stream provides the common interface +c10::Device device = cuda_stream.device(); +c10::DeviceType type = cuda_stream.device_type(); +``` diff --git a/docs/cpp/source/api/c10/types.md b/docs/cpp/source/api/c10/types.md new file mode 100644 index 0000000000000..60ffa1d784c45 --- /dev/null +++ b/docs/cpp/source/api/c10/types.md @@ -0,0 +1,225 @@ +--- +myst: + html_meta: + description: Core types in PyTorch C++ — ArrayRef, optional, Dict, List, IListRef, Half, and IValue. + keywords: PyTorch, C++, c10, ArrayRef, optional, Dict, List, IListRef, Half, IValue +--- + +# Core Types + +C10 provides fundamental types used throughout PyTorch. + +## ArrayRef + +```{doxygenclass} c10::ArrayRef +:members: +:undoc-members: +``` + +**Example:** + +```cpp +std::vector sizes = {3, 4, 5}; +c10::ArrayRef sizes_ref(sizes); + +// Can also use initializer list +auto tensor = at::zeros({3, 4, 5}); // implicitly converts +``` + +## OptionalArrayRef + +```{doxygenclass} c10::OptionalArrayRef +:members: +:no-link: +``` + +**Example:** + +```cpp +void my_function(c10::OptionalArrayRef sizes = c10::nullopt) { + if (sizes.has_value()) { + for (auto s : sizes.value()) { + // process sizes + } + } +} +``` + +## Optional + +```{cpp:class} c10::optional + +A wrapper type that may or may not contain a value. +Similar to `std::optional`. +``` + +```{cpp:function} bool has_value() const + +Returns true if a value is present. +``` + +```{cpp:function} T& value() + +Returns the contained value. Throws if empty. +``` + +```{cpp:function} T value_or(T default_value) const + +Returns the value if present, otherwise returns the default. +``` + +**Example:** + +```cpp +c10::optional maybe_dim = c10::nullopt; + +if (maybe_dim.has_value()) { + std::cout << "Dim: " << maybe_dim.value() << std::endl; +} + +int64_t dim = maybe_dim.value_or(-1); // Returns -1 if empty +``` + +## Half + +```{cpp:class} c10::Half + +16-bit floating point type (IEEE 754 half-precision). +``` + +```{cpp:function} Half(float value) + +Construct from a float. +``` + +```{cpp:function} operator float() const + +Convert to float. +``` + +**Example:** + +```cpp +c10::Half h = 3.14f; +float f = static_cast(h); +``` + +## Containers + +C10 provides container types that store `IValue` elements internally. These +are pointer types: copies share the same underlying storage. + +### Dict + +An ordered hash map from `Key` to `Value`. Valid key types are `int64_t`, +`double`, `bool`, `std::string`, and `at::Tensor`. + +```{doxygenclass} c10::Dict +:members: +:undoc-members: +``` + +**Example:** + +```cpp +#include + +c10::Dict named_tensors; +named_tensors.insert("weight", torch::randn({3, 3})); +named_tensors.insert("bias", torch::zeros({3})); + +if (named_tensors.contains("weight")) { + at::Tensor w = named_tensors.at("weight"); +} + +for (const auto& entry : named_tensors) { + std::cout << entry.key() << ": " << entry.value().sizes() << std::endl; +} +``` + +### List + +A type-safe list container backed by `IValue` elements. + +```{doxygenclass} c10::List +:members: +:undoc-members: +``` + +**Example:** + +```cpp +#include + +c10::List tensor_list; +tensor_list.push_back(torch::randn({2, 3})); +tensor_list.push_back(torch::zeros({2, 3})); + +at::Tensor first = tensor_list.get(0); +std::cout << "List size: " << tensor_list.size() << std::endl; + +c10::List int_list; +int_list.push_back(1); +int_list.push_back(2); +int_list.push_back(3); +``` + +### IListRef + +`c10::IListRef` is a lightweight reference type that provides a unified +interface over different list-like types (`List`, `ArrayRef`, +`std::vector`). It avoids copying when passing list arguments to operators. + +```{doxygenclass} c10::IListRef +:members: +:undoc-members: +``` + +**Example:** + +```cpp +#include + +// IListRef can wrap different underlying types +std::vector vec = {torch::randn({2}), torch::randn({3})}; +c10::IListRef ref(vec); + +for (const auto& t : ref) { + std::cout << t.sizes() << std::endl; +} +``` + +## IValue + +`c10::IValue` (Interpreter Value) is a type-erased container used extensively +for storing values of different types. It can hold tensors, +scalars, lists, dictionaries, and other types. + +```{note} + +The full API documentation for IValue is complex due to its many type +conversion methods. See the header file `ATen/core/ivalue.h` for complete +details. +``` + +**Common methods:** + +- `isTensor()` / `toTensor()` - Check if tensor / convert to tensor +- `isInt()` / `toInt()` - Check if int / convert to int +- `isDouble()` / `toDouble()` - Check if double / convert to double +- `isBool()` / `toBool()` - Check if bool / convert to bool +- `isString()` / `toString()` - Check if string / convert to string +- `isList()` / `toList()` - Check if list / convert to list +- `isGenericDict()` / `toGenericDict()` - Check if dict / convert to dict +- `isTuple()` / `toTuple()` - Check if tuple / convert to tuple +- `isNone()` - Check if None/null + +**Example:** + +```cpp +c10::IValue val = at::ones({2, 2}); + +if (val.isTensor()) { + at::Tensor t = val.toTensor(); +} +``` diff --git a/docs/cpp/source/api/c10/utilities.md b/docs/cpp/source/api/c10/utilities.md new file mode 100644 index 0000000000000..2e61c14c8e8dc --- /dev/null +++ b/docs/cpp/source/api/c10/utilities.md @@ -0,0 +1,153 @@ +--- +myst: + html_meta: + description: C10 utility classes in PyTorch C++ — Flags, QEngine, and Reduction enumerations. + keywords: PyTorch, C++, c10, utilities, QEngine, Reduction +--- + +# Utilities + +C10 provides utility classes for memory management and other common patterns. + +## MaybeOwned + +`MaybeOwned` is a C++ smart pointer class that dynamically +encodes whether a Tensor is *owned* or *borrowed*. It is used in +certain performance-sensitive situations to avoid unnecessarily +incrementing a Tensor's reference count (at a small cost in +overhead from the extra indirection). + +```{warning} + + MaybeOwned must be used with **extreme** care. Claims of (non-)ownership + are not statically checked, and mistakes can cause reference undercounting + and use-after-free crashes. + + Due to this lack of safety net, we discourage the use of MaybeOwned + outside code paths that are known to be highly performance sensitive. + However, if you encounter pre-existing uses of MaybeOwned in code that + you want to modify, it's critical to understand how to use it correctly. +``` + +**Use Case:** + +The primary use case for `MaybeOwned` is a function or method that +dynamically chooses between returning one of its arguments (typically +from a passthrough or "no-op" code path) and returning a freshly constructed +Tensor. Such a function would return a `MaybeOwned` in both cases: +the former in a "borrowed" state via `MaybeOwned::borrowed()`, +and the latter in an "owned" state via `MaybeOwned::owned()`. + +**Example - expect_contiguous:** + +The canonical example is `Tensor`'s `expect_contiguous` method, which shortcuts +and returns a borrowed self-reference when already contiguous: + +```cpp +inline c10::MaybeOwned Tensor::expect_contiguous( + MemoryFormat memory_format) const & { + if (is_contiguous(memory_format)) { + return c10::MaybeOwned::borrowed(*this); + } else { + return c10::MaybeOwned::owned( + __dispatch_contiguous(memory_format)); + } +} +``` + +Using the vocabulary of lifetimes, the essential safety requirement for borrowing +is that a borrowed Tensor must outlive any borrowing references to it. In the example +above, we can safely borrow `*this`, but the Tensor returned by +`__dispatch_contiguous()` is freshly created, and borrowing a reference would +effectively leave it ownerless. + +**Rules of Thumb:** + +- When in doubt, don't use `MaybeOwned` at all - in particular, prefer + avoiding using it in code that doesn't use it already. New usage should only be + introduced when critical (and demonstrable) performance gains result. + +- When modifying or calling code that already uses `MaybeOwned`, remember + that it's always safe to produce a `MaybeOwned` from a Tensor in hand + via a call to `MaybeOwned::owned()`. This may result in an unnecessary + reference count, but never in misbehavior - so it's always the safer bet, unless + the lifetime of the Tensor you're looking to wrap is crystal clear. + +More details and implementation code can be found at +[MaybeOwned.h](https://github.com/pytorch/pytorch/blob/main/c10/util/MaybeOwned.h) and +[TensorBody.h](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/templates/TensorBody.h). + +## Error Handling and Assertions + +PyTorch provides macros for error checking and assertions that produce +informative error messages with source location. These are defined in +`c10/util/Exception.h`. + +### TORCH_CHECK + +The primary macro for validating user input and runtime conditions. On failure, +raises `c10::Error` (which becomes `RuntimeError` in Python). + +```cpp +#include + +// Basic check +TORCH_CHECK(tensor.dim() == 2, "Expected 2D tensor, got ", tensor.dim(), "D"); + +// Without message (default message generated) +TORCH_CHECK(x >= 0); +``` + +Typed variants raise specific Python exception types: + +- `TORCH_CHECK_INDEX(cond, ...)` — raises `IndexError` +- `TORCH_CHECK_VALUE(cond, ...)` — raises `ValueError` +- `TORCH_CHECK_TYPE(cond, ...)` — raises `TypeError` +- `TORCH_CHECK_LINALG(cond, ...)` — raises `LinAlgError` +- `TORCH_CHECK_NOT_IMPLEMENTED(cond, ...)` — raises `NotImplementedError` + +### TORCH_INTERNAL_ASSERT + +For internal invariants that should always hold (i.e., failures indicate a bug +in PyTorch, not user error). Produces a message asking users to report the bug. + +```cpp +TORCH_INTERNAL_ASSERT(googol > 0); +TORCH_INTERNAL_ASSERT(googol > 0, "googol was ", googol); +``` + +```{note} + +Use `TORCH_CHECK` for conditions that can fail due to user input. +Use `TORCH_INTERNAL_ASSERT` only for conditions that indicate a PyTorch bug. +`TORCH_INTERNAL_ASSERT_DEBUG_ONLY` is the debug-build-only variant for +hot paths. +``` + +### TORCH_WARN + +Issues a warning (not an error) to the user. + +```cpp +TORCH_WARN("This operation is slow for sparse tensors"); +TORCH_WARN_ONCE("This warning appears only once"); +``` + +### c10::Error + +The base exception class for PyTorch C++ errors. Provides source location +and optional backtrace. + +```cpp +try { + auto result = some_operation(); +} catch (const c10::Error& e) { + std::cerr << e.what() << std::endl; + // Or without backtrace: + std::cerr << e.what_without_backtrace() << std::endl; +} +``` + +Specialized subclasses: `c10::IndexError`, `c10::ValueError`, +`c10::TypeError`, `c10::NotImplementedError`, `c10::LinAlgError`, +`c10::OutOfMemoryError`. diff --git a/docs/cpp/source/api/cuda/guards.md b/docs/cpp/source/api/cuda/guards.md new file mode 100644 index 0000000000000..d7724e390525a --- /dev/null +++ b/docs/cpp/source/api/cuda/guards.md @@ -0,0 +1,96 @@ +--- +myst: + html_meta: + description: CUDA device and stream guards in PyTorch C++ — CUDAGuard, CUDAStreamGuard, and CUDAMultiStreamGuard. + keywords: PyTorch, C++, CUDA, CUDAGuard, CUDAStreamGuard, device guard, multi-GPU +--- + +# CUDA Guards + +CUDA guards are RAII wrappers that set a CUDA device or stream as the current +context and automatically restore the previous context when the guard goes +out of scope. + +## CUDAGuard + +```{doxygenstruct} c10::cuda::CUDAGuard +:members: +:undoc-members: +``` + +**Example:** + +```cpp +#include + +{ + c10::cuda::CUDAGuard guard(1); // Switch to device 1 + // All CUDA operations here run on device 1 + auto tensor = torch::zeros({2, 2}, torch::device(torch::kCUDA)); +} +// Previous device is restored +``` + +## CUDAStreamGuard + +```{doxygenstruct} c10::cuda::CUDAStreamGuard +:members: +:undoc-members: +``` + +**Example:** + +```cpp +#include + +auto stream = c10::cuda::getStreamFromPool(); +{ + c10::cuda::CUDAStreamGuard guard(stream); + // Operations here use the specified stream +} +// Previous stream is restored +``` + +## OptionalCUDAGuard + +```{doxygenstruct} c10::cuda::OptionalCUDAGuard +:members: +:undoc-members: +``` + +**Example:** + +```cpp +c10::cuda::OptionalCUDAGuard guard; +if (use_cuda) { + guard.set_device(0); +} +// Guard only switches device if set_device was called +``` + +## OptionalCUDAStreamGuard + +```{doxygenstruct} c10::cuda::OptionalCUDAStreamGuard +:members: +:undoc-members: +``` + +## CUDAMultiStreamGuard + +```{doxygenstruct} c10::cuda::CUDAMultiStreamGuard +:members: +:undoc-members: +``` + +**Example:** + +```cpp +at::cuda::CUDAStream stream0 = at::cuda::getStreamFromPool(false, 0); +at::cuda::CUDAStream stream1 = at::cuda::getStreamFromPool(false, 1); + +{ + at::cuda::CUDAMultiStreamGuard multi_guard({stream0, stream1}); + // stream0 is current on device 0, stream1 on device 1 +} +// Both streams restored +``` diff --git a/docs/cpp/source/api/cuda/index.md b/docs/cpp/source/api/cuda/index.md new file mode 100644 index 0000000000000..0f7f07f447b4f --- /dev/null +++ b/docs/cpp/source/api/cuda/index.md @@ -0,0 +1,58 @@ +--- +myst: + html_meta: + description: PyTorch CUDA C++ API — device management, streams, guards, and cuDNN/cuBLAS utilities. + keywords: PyTorch, C++, CUDA, GPU, streams, guards, cuDNN, cuBLAS +--- + +# CUDA Support + +PyTorch provides comprehensive CUDA support for GPU-accelerated tensor +operations and neural network training. The CUDA API allows you to manage +GPU devices, streams for asynchronous execution, and memory efficiently. + +**When to use CUDA APIs:** + +- When you need explicit control over which GPU device to use +- When implementing custom CUDA kernels or operations +- When optimizing performance with asynchronous stream execution +- When managing multi-GPU workloads + +**Basic usage:** + +```cpp +#include +#include + +// Check if CUDA is available +if (torch::cuda::is_available()) { + // Create tensor on GPU + auto tensor = torch::randn({2, 3}, torch::device(torch::kCUDA)); + + // Switch to a specific GPU + c10::cuda::CUDAGuard guard(0); // Use GPU 0 + + // Get the current CUDA stream + auto stream = c10::cuda::getCurrentCUDAStream(); + + // Move model to GPU + model->to(torch::kCUDA); +} +``` + +## Header Files + +- `c10/cuda/CUDAStream.h` - CUDA stream management +- `c10/cuda/CUDAGuard.h` - CUDA device guards +- `ATen/cuda/CUDAContext.h` - CUDA context management +- `ATen/cudnn/Descriptors.h` - cuDNN tensor descriptors + +## CUDA Categories + +```{toctree} +:maxdepth: 1 + +streams +guards +utilities +``` diff --git a/docs/cpp/source/api/cuda/streams.md b/docs/cpp/source/api/cuda/streams.md new file mode 100644 index 0000000000000..a55eb609d1bf7 --- /dev/null +++ b/docs/cpp/source/api/cuda/streams.md @@ -0,0 +1,190 @@ +--- +myst: + html_meta: + description: CUDA streams in PyTorch C++ — CUDAStream for asynchronous GPU execution and synchronization. + keywords: PyTorch, C++, CUDA, CUDAStream, stream, asynchronous, GPU, synchronization +--- + +# CUDA Streams + +CUDA streams provide a mechanism for asynchronous execution of operations +on the GPU. Operations queued to the same stream execute in order, while +operations on different streams can execute concurrently. + +## CUDAStream + +```{doxygenclass} c10::cuda::CUDAStream +:members: +:undoc-members: +``` + +**Example:** + +```cpp +#include + +// Get the default stream for current device +auto stream = c10::cuda::getDefaultCUDAStream(); + +// Create a new stream +auto new_stream = c10::cuda::getStreamFromPool(); + +// Get current stream +auto current = c10::cuda::getCurrentCUDAStream(); + +// Synchronize +stream.synchronize(); +``` + +## Acquiring CUDA Streams + +PyTorch provides several ways to acquire CUDA streams: + +1. **From the stream pool** (round-robin allocation): + + ```cpp + // Normal priority stream + at::cuda::CUDAStream stream = at::cuda::getStreamFromPool(); + + // High priority stream + at::cuda::CUDAStream high_prio = at::cuda::getStreamFromPool(/*isHighPriority=*/true); + + // Stream for specific device + at::cuda::CUDAStream dev1_stream = at::cuda::getStreamFromPool(false, /*device=*/1); + ``` + +2. **Default stream** (where most computation occurs): + + ```cpp + at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(); + ``` + +3. **Current stream** (may differ if changed with guards): + + ```cpp + at::cuda::CUDAStream currentStream = at::cuda::getCurrentCUDAStream(); + ``` + +## Setting CUDA Streams + +**Using setCurrentCUDAStream:** + +```cpp +torch::Tensor tensor0 = torch::ones({2, 2}, torch::device(torch::kCUDA)); + +// Get a new stream and set it as current +at::cuda::CUDAStream myStream = at::cuda::getStreamFromPool(); +at::cuda::setCurrentCUDAStream(myStream); + +// Operations now use myStream +tensor0.sum(); + +// Restore default stream +at::cuda::setCurrentCUDAStream(at::cuda::getDefaultCUDAStream()); +``` + +**Using CUDAStreamGuard (recommended):** + +```cpp +torch::Tensor tensor0 = torch::ones({2, 2}, torch::device(torch::kCUDA)); +at::cuda::CUDAStream myStream = at::cuda::getStreamFromPool(); + +{ + at::cuda::CUDAStreamGuard guard(myStream); + // Operations use myStream within this scope + tensor0.sum(); +} +// Stream automatically restored to default +``` + +## Multi-Device Stream Management + +**Streams on multiple devices:** + +```cpp +// Acquire streams for different devices +at::cuda::CUDAStream stream0 = at::cuda::getStreamFromPool(false, 0); +at::cuda::CUDAStream stream1 = at::cuda::getStreamFromPool(false, 1); + +// Set current streams on each device +at::cuda::setCurrentCUDAStream(stream0); +at::cuda::setCurrentCUDAStream(stream1); + +// Create tensors on device 0 +torch::Tensor tensor0 = torch::ones({2, 2}, torch::device(at::kCUDA)); +tensor0.sum(); // Uses stream0 + +// Switch to device 1 +{ + at::cuda::CUDAGuard device_guard(1); + torch::Tensor tensor1 = torch::ones({2, 2}, torch::device(at::kCUDA)); + tensor1.sum(); // Uses stream1 +} +``` + +**Using CUDAMultiStreamGuard:** + +```cpp +torch::Tensor tensor0 = torch::ones({2, 2}, torch::device({torch::kCUDA, 0})); +torch::Tensor tensor1 = torch::ones({2, 2}, torch::device({torch::kCUDA, 1})); + +at::cuda::CUDAStream stream0 = at::cuda::getStreamFromPool(false, 0); +at::cuda::CUDAStream stream1 = at::cuda::getStreamFromPool(false, 1); + +{ + // Set streams on both devices simultaneously + at::cuda::CUDAMultiStreamGuard multi_guard({stream0, stream1}); + + tensor0.sum(); // Uses stream0 on device 0 + tensor1.sum(); // Uses stream1 on device 1 +} +// Both streams restored to defaults +``` + +```{attention} + +`CUDAMultiStreamGuard` does not change the current device index. It only +changes the stream on each passed-in stream's device. +``` + +## Multi-Device Stream Handling Pattern + +The following skeleton shows three common patterns for acquiring and setting +streams across multiple CUDA devices: + +```cpp +// Create stream vectors on device 0 +std::vector streams0 = + {at::cuda::getDefaultCUDAStream(), at::cuda::getStreamFromPool()}; +at::cuda::setCurrentCUDAStream(streams0[0]); + +// Create stream vector on device 1 using CUDAGuard +std::vector streams1; +{ + at::cuda::CUDAGuard device_guard(1); + streams1.push_back(at::cuda::getDefaultCUDAStream()); + streams1.push_back(at::cuda::getStreamFromPool()); +} +at::cuda::setCurrentCUDAStream(streams1[0]); + +// Pattern 1: CUDAGuard changes current device only, not streams +{ + at::cuda::CUDAGuard device_guard(1); + // current device is 1, current stream on device 1 is still streams1[0] +} + +// Pattern 2: CUDAStreamGuard changes both current device and current stream +{ + at::cuda::CUDAStreamGuard stream_guard(streams1[1]); + // current device is 1, current stream is streams1[1] +} +// restored to device 0, stream streams0[0] + +// Pattern 3: CUDAMultiStreamGuard sets streams on multiple devices at once +{ + at::cuda::CUDAMultiStreamGuard multi_guard({streams0[1], streams1[1]}); + // current device unchanged (still 0) + // stream on device 0 is streams0[1], stream on device 1 is streams1[1] +} +// streams restored to streams0[0] and streams1[0] +``` diff --git a/docs/cpp/source/api/cuda/utilities.md b/docs/cpp/source/api/cuda/utilities.md new file mode 100644 index 0000000000000..951c49e3b71de --- /dev/null +++ b/docs/cpp/source/api/cuda/utilities.md @@ -0,0 +1,226 @@ +--- +myst: + html_meta: + description: CUDA utility functions in PyTorch C++ — device properties, cuBLAS/cuDNN handles, and stream management. + keywords: PyTorch, C++, CUDA, device_count, cuBLAS, cuDNN, cuSPARSE, TensorDescriptor +--- + +# CUDA Utility Functions + +PyTorch provides utility functions for querying and managing CUDA devices, +streams, and library handles. + +## Device Management + +```{doxygenfunction} torch::cuda::device_count +``` + +```{cpp:function} int c10::cuda::current_device() + +Returns the index of the current CUDA device. +``` + +**Example:** + +```cpp +#include + +// Check available devices +int num_devices = c10::cuda::device_count(); + +// Get current device +int current = c10::cuda::current_device(); +``` + +## Device Properties + +```{doxygenfunction} at::cuda::getCurrentDeviceProperties +``` + +```{doxygenfunction} at::cuda::getDeviceProperties +``` + +```{doxygenfunction} at::cuda::canDeviceAccessPeer +``` + +```{doxygenfunction} at::cuda::warp_size +``` + +**Example:** + +```cpp +#include + +// Query properties of the current device +cudaDeviceProp* props = at::cuda::getCurrentDeviceProperties(); +std::cout << "Device: " << props->name << std::endl; +std::cout << "Compute capability: " << props->major << "." << props->minor << std::endl; + +// Query a specific device +cudaDeviceProp* dev1_props = at::cuda::getDeviceProperties(1); + +// Check peer access +bool can_access = at::cuda::canDeviceAccessPeer(0, 1); +``` + +## Library Handles + +These functions return handles for CUDA math libraries on the current device +and stream. They are primarily useful when writing custom CUDA kernels that +call cuBLAS, cuSPARSE, or cuSOLVER directly. + +```{doxygenfunction} at::cuda::getCurrentCUDABlasHandle +``` + +```{doxygenfunction} at::cuda::getCurrentCUDABlasLtHandle +``` + +```{doxygenfunction} at::cuda::getCurrentCUDASparseHandle +``` + +```{doxygenfunction} at::cuda::getCurrentCUDASolverDnHandle +``` + +**Example:** + +```cpp +#include + +// Get cuBLAS handle for current device/stream +cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + +// Get cuSPARSE handle +cusparseHandle_t sparse_handle = at::cuda::getCurrentCUDASparseHandle(); +``` + +## cuDNN Descriptors + +When writing custom kernels that call cuDNN directly, PyTorch provides RAII +wrapper classes for cuDNN descriptors. These are defined in +`ATen/cudnn/Descriptors.h`. + +### Descriptor (Base Class) + +```{doxygenclass} at::native::Descriptor +:members: +:undoc-members: +``` + +A generic RAII wrapper for cuDNN descriptor types. Descriptors default +construct to a `nullptr` and are initialized on first use via `mut_desc()`. +Use `desc()` for read-only access. + +### TensorDescriptor + +```{doxygenclass} at::native::TensorDescriptor +:members: +:undoc-members: +``` + +Wraps `cudnnTensorDescriptor_t`. Supports padding lower-dimensional tensors +to meet cuDNN broadcasting requirements (see `pad` parameter). + +**Example:** + +```cpp +#include + +at::Tensor input = torch::randn({32, 3, 224, 224}, torch::kCUDA); +at::native::TensorDescriptor desc(input); +cudnnTensorDescriptor_t raw = desc.desc(); +``` + +### FilterDescriptor + +```{doxygenclass} at::native::FilterDescriptor +:members: +:undoc-members: +``` + +Wraps `cudnnFilterDescriptor_t` for convolution filter weights. + +### ConvolutionDescriptor + +```{doxygenstruct} at::native::ConvolutionDescriptor +:members: +:undoc-members: +``` + +Wraps `cudnnConvolutionDescriptor_t`. Configures padding, stride, dilation, +groups, and math type (TF32, tensor ops) for convolution operations. + +### RNNDataDescriptor + +```{doxygenclass} at::native::RNNDataDescriptor +:members: +:undoc-members: +``` + +Wraps `cudnnRNNDataDescriptor_t` for variable-length sequence data. + +### DropoutDescriptor + +```{doxygenstruct} at::native::DropoutDescriptor +:members: +:undoc-members: +``` + +Wraps `cudnnDropoutDescriptor_t`. Manages RNG state for cuDNN dropout. + +### ActivationDescriptor + +```{doxygenstruct} at::native::ActivationDescriptor +:members: +:undoc-members: +``` + +Wraps `cudnnActivationDescriptor_t`. + +### SpatialTransformerDescriptor + +```{doxygenstruct} at::native::SpatialTransformerDescriptor +:members: +:undoc-members: +``` + +### CTCLossDescriptor + +```{doxygenstruct} at::native::CTCLossDescriptor +:members: +:undoc-members: +``` + +## Stream Management + +```{doxygenfunction} c10::cuda::getDefaultCUDAStream +``` + +```{doxygenfunction} c10::cuda::getCurrentCUDAStream +``` + +```{doxygenfunction} c10::cuda::setCurrentCUDAStream +``` + +```{doxygenfunction} c10::cuda::getStreamFromPool(const bool isHighPriority, DeviceIndex device) +``` + +```{doxygenfunction} c10::cuda::getStreamFromExternal +``` + +**Example:** + +```cpp +#include + +// Create and set custom stream +auto stream = c10::cuda::getStreamFromPool(); +c10::cuda::setCurrentCUDAStream(stream); + +// Get default stream +auto default_stream = c10::cuda::getDefaultCUDAStream(); + +// Wrap an externally created CUDA stream +cudaStream_t ext_stream; +cudaStreamCreate(&ext_stream); +auto wrapped = c10::cuda::getStreamFromExternal(ext_stream, /*device_index=*/0); +``` diff --git a/docs/cpp/source/api/data/dataloader.md b/docs/cpp/source/api/data/dataloader.md new file mode 100644 index 0000000000000..f368363368bf6 --- /dev/null +++ b/docs/cpp/source/api/data/dataloader.md @@ -0,0 +1,103 @@ +--- +myst: + html_meta: + description: DataLoader in PyTorch C++ — parallel data loading with batching, sampling, and multi-worker support. + keywords: PyTorch, C++, DataLoader, data loading, batching, workers, make_data_loader +--- + +# DataLoader + +The DataLoader batches samples from a dataset and optionally shuffles and +parallelizes the loading process. It is the main interface for iterating +over training data. + +## DataLoader Classes + +```{doxygenclass} torch::data::DataLoaderBase +:members: +:undoc-members: +``` + +## DataLoaderOptions + +```{doxygenstruct} torch::data::DataLoaderOptions +:members: +:undoc-members: +``` + +## StatefulDataLoader + +A DataLoader for `StatefulDataset` types that manage their own batching logic +internally. + +```{doxygenclass} torch::data::StatefulDataLoader +:members: +:undoc-members: +``` + +## StatelessDataLoader + +A DataLoader for `Dataset` types that use external samplers for batching. + +```{doxygenclass} torch::data::StatelessDataLoader +:members: +:undoc-members: +``` + +## Iterator + +```{doxygenclass} torch::data::Iterator +:members: +:undoc-members: +``` + +## Creating a DataLoader + +Use `make_data_loader` to create a DataLoader from a dataset: + +```cpp +auto data_loader = torch::data::make_data_loader( + std::move(dataset), + torch::data::DataLoaderOptions() + .batch_size(64) + .workers(4)); + +for (auto& batch : *data_loader) { + auto data = batch.data; + auto target = batch.target; + // Train on batch +} +``` + +## Complete Training Example + +```cpp +#include + +int main() { + // Load dataset + auto dataset = torch::data::datasets::MNIST("./data") + .map(torch::data::transforms::Normalize<>(0.1307, 0.3081)) + .map(torch::data::transforms::Stack<>()); + + // Create data loader + auto data_loader = torch::data::make_data_loader( + std::move(dataset), + torch::data::DataLoaderOptions().batch_size(64).workers(2)); + + // Create model and optimizer + auto model = std::make_shared(); + auto optimizer = torch::optim::Adam(model->parameters(), 0.001); + + // Training loop + for (size_t epoch = 1; epoch <= 10; ++epoch) { + for (auto& batch : *data_loader) { + optimizer.zero_grad(); + auto output = model->forward(batch.data); + auto loss = torch::nll_loss(output, batch.target); + loss.backward(); + optimizer.step(); + } + } +} +``` diff --git a/docs/cpp/source/api/data/datasets.md b/docs/cpp/source/api/data/datasets.md new file mode 100644 index 0000000000000..4bb01e255c393 --- /dev/null +++ b/docs/cpp/source/api/data/datasets.md @@ -0,0 +1,110 @@ +--- +myst: + html_meta: + description: Dataset classes in PyTorch C++ — Dataset, MapDataset, StreamDataset, and built-in datasets like MNIST. + keywords: PyTorch, C++, Dataset, MapDataset, StreamDataset, MNIST, data +--- + +# Datasets + +The dataset abstraction defines how to access individual samples in your data. +All datasets inherit from `Dataset` and must implement `get()` and `size()`. + +## Dataset Base Class + +```{doxygenclass} torch::data::datasets::Dataset +:members: +:undoc-members: +``` + +```{doxygenclass} torch::data::datasets::BatchDataset +:members: +:undoc-members: +``` + +## StatefulDataset + +A dataset that manages its own state across batches (e.g., position in a stream). +Unlike `Dataset`, it produces batches directly without external samplers. + +```{doxygenclass} torch::data::datasets::StatefulDataset +:members: +:undoc-members: +``` + +## ChunkDataReader + +Interface for reading chunks of data from a data source. Used with +`ChunkDataset` for large-scale data loading. + +```{doxygenclass} torch::data::datasets::ChunkDataReader +:members: +:undoc-members: +``` + +## Custom Dataset Example + +```cpp +class CustomDataset : public torch::data::datasets::Dataset { + public: + explicit CustomDataset(const std::string& root) { + // Load data from root directory + } + + torch::data::Example<> get(size_t index) override { + return {images_[index], labels_[index]}; + } + + torch::optional size() const override { + return images_.size(0); + } + + private: + torch::Tensor images_, labels_; +}; +``` + +## MapDataset + +```{doxygenclass} torch::data::datasets::MapDataset +:members: +:undoc-members: +``` + +## ChunkDataset + +```{doxygenclass} torch::data::datasets::ChunkDataset +:members: +:undoc-members: +``` + +## SharedBatchDataset + +```{doxygenclass} torch::data::datasets::SharedBatchDataset +:members: +:undoc-members: +``` + +## Built-in Datasets + +### MNIST + +```{doxygenclass} torch::data::datasets::MNIST +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto dataset = torch::data::datasets::MNIST("./data") + .map(torch::data::transforms::Normalize<>(0.1307, 0.3081)) + .map(torch::data::transforms::Stack<>()); +``` + +## Example Struct + +```{doxygenstruct} torch::data::Example +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/data/index.md b/docs/cpp/source/api/data/index.md new file mode 100644 index 0000000000000..7e5c5517c510c --- /dev/null +++ b/docs/cpp/source/api/data/index.md @@ -0,0 +1,66 @@ +--- +myst: + html_meta: + description: PyTorch C++ data loading API — datasets, data loaders, samplers, and transforms. + keywords: PyTorch, C++, data, DataLoader, Dataset, sampler, transform +--- + +# Data Loading (torch::data) + +The `torch::data` namespace provides utilities for loading and processing +datasets during training. It includes dataset abstractions, data loaders for +batching and shuffling, samplers for controlling data access patterns, and +transforms for data augmentation. + +**When to use torch::data:** + +- When loading training data in batches +- When you need parallel data loading with multiple workers +- When implementing custom datasets or transforms + +**Components overview:** + +- **Dataset**: Defines how to access individual samples (implement `get()` and `size()`) +- **DataLoader**: Batches samples and optionally shuffles/parallelizes loading +- **Sampler**: Controls the order in which samples are accessed +- **Transform**: Applies preprocessing (normalization, augmentation) to samples + +**Basic usage:** + +```cpp +#include + +// Load built-in dataset +auto dataset = torch::data::datasets::MNIST("./data") + .map(torch::data::transforms::Normalize<>(0.1307, 0.3081)) + .map(torch::data::transforms::Stack<>()); + +// Create data loader with batching and shuffling +auto data_loader = torch::data::make_data_loader( + std::move(dataset), + torch::data::DataLoaderOptions().batch_size(64).workers(4)); + +// Iterate over batches +for (auto& batch : *data_loader) { + auto images = batch.data; // Shape: [64, 1, 28, 28] + auto labels = batch.target; // Shape: [64] +} +``` + +## Header Files + +- `torch/csrc/api/include/torch/data.h` - Main data header +- `torch/csrc/api/include/torch/data/dataloader.h` - DataLoader +- `torch/csrc/api/include/torch/data/datasets.h` - Dataset classes +- `torch/csrc/api/include/torch/data/samplers.h` - Samplers + +## Module Categories + +```{toctree} +:maxdepth: 1 + +datasets +dataloader +samplers +transforms +``` diff --git a/docs/cpp/source/api/data/samplers.md b/docs/cpp/source/api/data/samplers.md new file mode 100644 index 0000000000000..f18f09c27d493 --- /dev/null +++ b/docs/cpp/source/api/data/samplers.md @@ -0,0 +1,69 @@ +--- +myst: + html_meta: + description: Data samplers in PyTorch C++ — RandomSampler, SequentialSampler, DistributedRandomSampler, and StreamSampler. + keywords: PyTorch, C++, sampler, RandomSampler, SequentialSampler, DistributedRandomSampler +--- + +# Samplers + +Samplers control the order in which samples are accessed from a dataset. +They determine the indices that the DataLoader uses to fetch data. + +## Sampler Base Class + +```{doxygenclass} torch::data::samplers::Sampler +:members: +:undoc-members: +``` + +## Sequential Sampler + +Accesses samples in order from 0 to N-1. Use this for evaluation or when +order matters. + +```{doxygenclass} torch::data::samplers::SequentialSampler +:members: +:undoc-members: +``` + +## Random Sampler + +Accesses samples in random order. Use this for training to ensure the model +sees samples in different orders each epoch. + +```{doxygenclass} torch::data::samplers::RandomSampler +:members: +:undoc-members: +``` + +## Distributed Random Sampler + +For distributed training, ensures each process gets a different subset of +the data without overlap. + +```{doxygenclass} torch::data::samplers::DistributedRandomSampler +:members: +:undoc-members: +``` + +## Distributed Sampler (Base) + +```{doxygenclass} torch::data::samplers::DistributedSampler +:members: +:undoc-members: +``` + +## Distributed Sequential Sampler + +```{doxygenclass} torch::data::samplers::DistributedSequentialSampler +:members: +:undoc-members: +``` + +## Stream Sampler + +```{doxygenclass} torch::data::samplers::StreamSampler +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/data/transforms.md b/docs/cpp/source/api/data/transforms.md new file mode 100644 index 0000000000000..5026596e8b63b --- /dev/null +++ b/docs/cpp/source/api/data/transforms.md @@ -0,0 +1,95 @@ +--- +myst: + html_meta: + description: Data transforms in PyTorch C++ — Stack, Normalize, Lambda, and Collate transforms for data pipelines. + keywords: PyTorch, C++, transforms, Stack, Normalize, Lambda, Collate, data pipeline +--- + +# Transforms + +Transforms apply preprocessing to data samples, such as normalization or +augmentation. They can be chained using the `.map()` method on datasets. + +## Transform (Base Class) + +The base class for all transforms. Subclass this to create custom transforms. + +```{doxygenclass} torch::data::transforms::Transform +:members: +:undoc-members: +``` + +## BatchTransform (Base Class) + +Base class for transforms that operate on entire batches. + +```{doxygenclass} torch::data::transforms::BatchTransform +:members: +:undoc-members: +``` + +## TensorTransform + +Base class for transforms that operate on tensors specifically. + +```{doxygenclass} torch::data::transforms::TensorTransform +:members: +:undoc-members: +``` + +## Normalize + +Normalizes tensors with a given mean and standard deviation. + +```{doxygenstruct} torch::data::transforms::Normalize +:members: +:undoc-members: +``` + +## Stack + +Stacks a batch of tensors into a single tensor. + +```{doxygenstruct} torch::data::transforms::Stack +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto dataset = torch::data::datasets::MNIST("./data") + .map(torch::data::transforms::Normalize<>(0.5, 0.5)) + .map(torch::data::transforms::Stack<>()); +``` + +## Lambda + +```{doxygenclass} torch::data::transforms::Lambda +:members: +:undoc-members: +``` + +## TensorLambda + +```{doxygenclass} torch::data::transforms::TensorLambda +:members: +:undoc-members: +``` + +## BatchLambda + +```{doxygenclass} torch::data::transforms::BatchLambda +:members: +:undoc-members: +``` + +## Chaining Transforms + +Transforms can be chained together using `.map()`: + +```cpp +auto dataset = torch::data::datasets::MNIST("./data") + .map(torch::data::transforms::Normalize<>(0.1307, 0.3081)) + .map(torch::data::transforms::Stack<>()); +``` diff --git a/docs/cpp/source/api/index.md b/docs/cpp/source/api/index.md new file mode 100644 index 0000000000000..3c3ff9d855039 --- /dev/null +++ b/docs/cpp/source/api/index.md @@ -0,0 +1,40 @@ +--- +myst: + html_meta: + description: PyTorch C++ API reference — complete documentation for ATen, Autograd, nn, optim, data, CUDA, and more. + keywords: PyTorch, C++, API reference, ATen, Autograd, nn, optim, CUDA +--- + +# C++ API Reference + +This section provides reference documentation for the PyTorch C++ API, +organized by module. + +```{toctree} +:maxdepth: 2 +:caption: Core + +aten/index +c10/index +autograd/index +cuda/index +xpu/index +``` + +```{toctree} +:maxdepth: 2 +:caption: C++ Frontend + +nn/index +optim/index +data/index +serialize/index +``` + +```{toctree} +:maxdepth: 2 +:caption: Extensions + +library/index +stable/index +``` diff --git a/docs/cpp/source/api/library/custom_classes.md b/docs/cpp/source/api/library/custom_classes.md new file mode 100644 index 0000000000000..48828eb731aba --- /dev/null +++ b/docs/cpp/source/api/library/custom_classes.md @@ -0,0 +1,94 @@ +--- +myst: + html_meta: + description: Custom classes in PyTorch C++ — registering C++ classes for use in TorchScript and Python. + keywords: PyTorch, C++, custom class, TorchScript, TORCH_CLASS, registration +--- + +# Custom Classes + +PyTorch allows registering custom C++ classes that can be used from Python +and TorchScript. + +Header: `torch/custom_class.h` + +## class\_ Template + +```{doxygenclass} torch::class_ +:members: +:undoc-members: +``` + +**Example:** + +```cpp +#include + +struct MyClass : torch::CustomClassHolder { + int value; + + MyClass(int v) : value(v) {} + + int getValue() const { return value; } + void setValue(int v) { value = v; } +}; + +TORCH_LIBRARY(my_classes, m) { + m.class_("MyClass") + .def(torch::init()) + .def("getValue", &MyClass::getValue) + .def("setValue", &MyClass::setValue) + .def_readwrite("value", &MyClass::value); +} +``` + +## Registering Methods + +**Constructor:** + +```cpp +m.class_("MyClass") + .def(torch::init()) // Constructor taking int +``` + +**Methods:** + +```cpp +m.class_("MyClass") + .def("getValue", &MyClass::getValue) + .def("setValue", &MyClass::setValue) +``` + +**Properties:** + +```cpp +m.class_("MyClass") + .def_readwrite("value", &MyClass::value) // Read-write + .def_readonly("const_value", &MyClass::const_value) // Read-only +``` + +## Using Custom Classes + +**From C++:** + +```cpp +auto my_obj = c10::make_intrusive(42); +int val = my_obj->getValue(); +``` + +**From Python:** + +```python +import torch +torch.classes.load_library("path/to/library.so") +obj = torch.classes.my_classes.MyClass(42) +print(obj.getValue()) +``` + +**In TorchScript:** + +```python +@torch.jit.script +def use_my_class(x: torch.classes.my_classes.MyClass) -> int: + return x.getValue() +``` diff --git a/docs/cpp/source/api/library/index.md b/docs/cpp/source/api/library/index.md new file mode 100644 index 0000000000000..2d3e1e225a1be --- /dev/null +++ b/docs/cpp/source/api/library/index.md @@ -0,0 +1,62 @@ +--- +myst: + html_meta: + description: PyTorch Library API in C++ — operator registration, custom classes, and versioning. + keywords: PyTorch, C++, Library, operator registration, custom class, dispatch +--- + +# Torch Library API + +The Torch Library API provides capabilities for extending PyTorch's core library +of operators with user-defined operators and data types. This is the primary +mechanism for registering custom C++ operators that can be called from both +Python and C++. + +**When to use the Library API:** + +- When creating custom operators for PyTorch +- When implementing backend-specific kernels (CPU, CUDA, etc.) +- When registering custom classes for use in TorchScript +- When extending PyTorch with new functionality + +**Basic usage:** + +```cpp +#include + +// Define a custom operator +torch::Tensor my_add(const torch::Tensor& a, const torch::Tensor& b) { + return a + b; +} + +// Register the operator +TORCH_LIBRARY(myops, m) { + m.def("add(Tensor a, Tensor b) -> Tensor", &my_add); +} + +// Use from C++ +auto result = torch::dispatcher::call("myops::add", tensor_a, tensor_b); +``` + +For a tutorial-style introduction to the library API, check out the +[Extending TorchScript with Custom C++ Operators](https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html) +tutorial. + +## Header Files + +- `torch/library.h` - Main library API header +- `torch/custom_class.h` - Custom class registration + +## Library API Categories + +```{toctree} +:maxdepth: 1 + +registration +custom_classes +versioning +``` + +## See Also + +- {doc}`../stable/index` - For stable ABI operator registration diff --git a/docs/cpp/source/api/library/registration.md b/docs/cpp/source/api/library/registration.md new file mode 100644 index 0000000000000..ec74871915681 --- /dev/null +++ b/docs/cpp/source/api/library/registration.md @@ -0,0 +1,113 @@ +--- +myst: + html_meta: + description: Operator registration in PyTorch C++ — TORCH_LIBRARY, TORCH_LIBRARY_IMPL for custom operators. + keywords: PyTorch, C++, TORCH_LIBRARY, operator registration, dispatch, custom op +--- + +# Operator Registration + +The library API provides macros and classes for registering custom operators +with PyTorch's dispatcher. + +## Macros + +### TORCH_LIBRARY + +```{doxygendefine} TORCH_LIBRARY +``` + +**Example:** + +```cpp +TORCH_LIBRARY(myops, m) { + m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl); + m.def("mul(Tensor self, Tensor other) -> Tensor"); + m.impl("mul", torch::kCPU, &mul_cpu_impl); + m.impl("mul", torch::kCUDA, &mul_cuda_impl); +} +``` + +### TORCH_LIBRARY_IMPL + +```{doxygendefine} TORCH_LIBRARY_IMPL +``` + +**Example:** + +```cpp +TORCH_LIBRARY_IMPL(myops, XLA, m) { + m.impl("mul", &mul_xla_impl); +} +``` + +### TORCH_LIBRARY_FRAGMENT + +```{doxygendefine} TORCH_LIBRARY_FRAGMENT +``` + +**Example:** + +```cpp +// In file1.cpp +TORCH_LIBRARY(myops, m) { + m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl); +} + +// In file2.cpp +TORCH_LIBRARY_FRAGMENT(myops, m) { + m.def("mul(Tensor self, Tensor other) -> Tensor", &mul_impl); +} +``` + +## Classes + +### Library + +```{doxygenclass} torch::Library +``` + +**Example:** + +```cpp +TORCH_LIBRARY(myops, m) { + // Define with implementation + m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl); + + // Define schema only + m.def("mul(Tensor self, Tensor other) -> Tensor"); + + // Provide backend-specific implementations + m.impl("mul", torch::kCPU, &mul_cpu_impl); + m.impl("mul", torch::kCUDA, &mul_cuda_impl); +} +``` + +### CppFunction + +```{doxygenclass} torch::CppFunction +:members: +:no-link: +``` + +### OrderedDict + +```{doxygenclass} torch::OrderedDict +:members: +:undoc-members: +``` + +## Functions + +The library API provides builder methods on the `Library` class for registering +operators. See the `Library` class documentation above for the full API including +`def()`, `impl()`, and `fallback()` methods. + +## Dispatch Keys + +Common dispatch keys used with `torch::dispatch()`: + +- `torch::kCPU` - CPU backend +- `torch::kCUDA` - CUDA backend +- `torch::kAutograd` - Autograd backend +- `torch::kMeta` - Meta tensor backend diff --git a/docs/cpp/source/api/library/versioning.md b/docs/cpp/source/api/library/versioning.md new file mode 100644 index 0000000000000..a287f3c2a5751 --- /dev/null +++ b/docs/cpp/source/api/library/versioning.md @@ -0,0 +1,44 @@ +--- +myst: + html_meta: + description: Operator versioning in PyTorch C++ — managing backward compatibility for serialized models. + keywords: PyTorch, C++, versioning, backward compatibility, operator, serialization +--- + +# Library Versioning + +PyTorch provides version number macros for identifying the version of LibTorch in use. + +**Example:** + +```cpp +#include +#include + +int main() { + std::cout << "PyTorch version from parts: " + << TORCH_VERSION_MAJOR << "." + << TORCH_VERSION_MINOR << "." + << TORCH_VERSION_PATCH << std::endl; + std::cout << "PyTorch version: " << TORCH_VERSION << std::endl; +} +``` + +This will output something like: + +```text +PyTorch version from parts: 1.8.0 +PyTorch version: 1.8.0 +``` + +```{note} + +These macros are only available in PyTorch >= 1.8.0. +``` + +## Version Macros + +- `TORCH_VERSION_MAJOR` - Major version number +- `TORCH_VERSION_MINOR` - Minor version number +- `TORCH_VERSION_PATCH` - Patch version number +- `TORCH_VERSION` - Full version string (e.g., "1.8.0") diff --git a/docs/cpp/source/api/nn/activation.md b/docs/cpp/source/api/nn/activation.md new file mode 100644 index 0000000000000..d4551490f2f70 --- /dev/null +++ b/docs/cpp/source/api/nn/activation.md @@ -0,0 +1,356 @@ +--- +myst: + html_meta: + description: Activation functions in PyTorch C++ — ReLU, GELU, Sigmoid, Softmax, and more torch::nn activation modules. + keywords: PyTorch, C++, activation, ReLU, GELU, Sigmoid, Softmax, LeakyReLU, ELU, Mish +--- + +# Activation Functions + +Activation functions introduce non-linearity into neural networks, allowing them +to learn complex patterns. Without activations, stacked linear layers would collapse +into a single linear transformation. + +**Common choices:** + +- **ReLU family** (ReLU, LeakyReLU, PReLU, RReLU): Fast, widely used, good default choice +- **ELU family** (ELU, SELU, CELU): Smoother than ReLU, can produce negative outputs +- **GELU/SiLU/Mish**: Modern activations popular in transformers and advanced architectures +- **Sigmoid/Tanh**: Classic activations, useful for output layers (probabilities, bounded outputs) +- **Softmax**: Converts logits to probability distribution (classification output) + +## ReLU + +```{doxygenclass} torch::nn::ReLU +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ReLUImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto relu = torch::nn::ReLU(torch::nn::ReLUOptions().inplace(true)); +``` + +## LeakyReLU + +```{doxygenclass} torch::nn::LeakyReLU +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LeakyReLUImpl +:members: +:undoc-members: +``` + +## PReLU + +```{doxygenclass} torch::nn::PReLU +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::PReLUImpl +:members: +:undoc-members: +``` + +## RReLU + +```{doxygenclass} torch::nn::RReLU +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::RReLUImpl +:members: +:undoc-members: +``` + +## ReLU6 + +Like ReLU but caps the output at 6: `min(max(0, x), 6)`. Commonly used in +mobile architectures (MobileNet). + +```{doxygenclass} torch::nn::ReLU6 +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ReLU6Impl +:members: +:undoc-members: +``` + +## GLU + +Gated Linear Unit. Splits the input tensor in half along a dimension, +then applies `a * sigmoid(b)`. + +```{doxygenclass} torch::nn::GLU +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::GLUImpl +:members: +:undoc-members: +``` + +## LogSigmoid + +Applies element-wise `log(sigmoid(x))`. Numerically more stable than +computing `log` and `sigmoid` separately. + +```{doxygenclass} torch::nn::LogSigmoid +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LogSigmoidImpl +:members: +:undoc-members: +``` + +## ELU + +```{doxygenclass} torch::nn::ELU +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ELUImpl +:members: +:undoc-members: +``` + +## SELU + +```{doxygenclass} torch::nn::SELU +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::SELUImpl +:members: +:undoc-members: +``` + +## CELU + +```{doxygenclass} torch::nn::CELU +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::CELUImpl +:members: +:undoc-members: +``` + +## GELU + +```{doxygenclass} torch::nn::GELU +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::GELUImpl +:members: +:undoc-members: +``` + +## SiLU (Swish) + +```{doxygenclass} torch::nn::SiLU +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::SiLUImpl +:members: +:undoc-members: +``` + +## Mish + +```{doxygenclass} torch::nn::Mish +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::MishImpl +:members: +:undoc-members: +``` + +## Sigmoid + +```{doxygenclass} torch::nn::Sigmoid +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::SigmoidImpl +:members: +:undoc-members: +``` + +## Tanh + +```{doxygenclass} torch::nn::Tanh +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::TanhImpl +:members: +:undoc-members: +``` + +## Softmax + +```{doxygenclass} torch::nn::Softmax +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::SoftmaxImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto softmax = torch::nn::Softmax(torch::nn::SoftmaxOptions(/*dim=*/1)); +``` + +## Softmax2d + +Applies `Softmax` over features to each spatial location in a 4D input +tensor of shape `(N, C, H, W)`. + +```{doxygenclass} torch::nn::Softmax2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::Softmax2dImpl +:members: +:undoc-members: +``` + +## LogSoftmax + +```{doxygenclass} torch::nn::LogSoftmax +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LogSoftmaxImpl +:members: +:undoc-members: +``` + +## Softmin + +```{doxygenclass} torch::nn::Softmin +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::SoftminImpl +:members: +:undoc-members: +``` + +## Softplus + +```{doxygenclass} torch::nn::Softplus +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::SoftplusImpl +:members: +:undoc-members: +``` + +## Softshrink + +```{doxygenclass} torch::nn::Softshrink +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::SoftshrinkImpl +:members: +:undoc-members: +``` + +## Softsign + +```{doxygenclass} torch::nn::Softsign +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::SoftsignImpl +:members: +:undoc-members: +``` + +## Hardshrink + +```{doxygenclass} torch::nn::Hardshrink +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::HardshrinkImpl +:members: +:undoc-members: +``` + +## Hardtanh + +```{doxygenclass} torch::nn::Hardtanh +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::HardtanhImpl +:members: +:undoc-members: +``` + +## Tanhshrink + +```{doxygenclass} torch::nn::Tanhshrink +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::TanhshrinkImpl +:members: +:undoc-members: +``` + +## Threshold + +```{doxygenclass} torch::nn::Threshold +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ThresholdImpl +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/nn/containers.md b/docs/cpp/source/api/nn/containers.md new file mode 100644 index 0000000000000..c9a694a96b158 --- /dev/null +++ b/docs/cpp/source/api/nn/containers.md @@ -0,0 +1,110 @@ +--- +myst: + html_meta: + description: Module containers in PyTorch C++ — Sequential, ModuleList, ModuleDict for composing neural networks. + keywords: PyTorch, C++, Sequential, ModuleList, ModuleDict, container, module +--- + +# Containers + +Container modules hold other modules and define how they are composed together. +Use containers to build complex architectures from simpler building blocks. + +- **Sequential**: Chain modules in order, output of one feeds into the next +- **ModuleList**: Store modules in a list for iteration (not auto-forwarded) +- **ModuleDict**: Store modules in a dictionary for named access +- **ParameterList/ParameterDict**: Store parameters directly without wrapping in modules + +```{note} + +PyTorch's C++ API uses the PIMPL (Pointer to Implementation) pattern. You create +modules using the public class name (e.g., `torch::nn::Sequential`), which +internally wraps an implementation class (`SequentialImpl`). The documentation +below shows the implementation classes, which contain all the actual methods. +``` + +## Sequential + +`Sequential` is a container that chains modules together. Each module's output +becomes the next module's input. This is the simplest way to build feed-forward +networks. + +```{doxygenclass} torch::nn::Sequential +:members: +:undoc-members: +``` + +**Example:** + +```cpp +torch::nn::Sequential seq( + torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 32, 3)), + torch::nn::ReLU(), + torch::nn::Conv2d(torch::nn::Conv2dOptions(32, 64, 3)), + torch::nn::ReLU() +); + +auto output = seq->forward(input); +``` + +## ModuleList + +`ModuleList` stores modules in a list for indexed or iterated access. Unlike +`Sequential`, it does not have a built-in `forward()` method—you control how +modules are called. + +```{doxygenclass} torch::nn::ModuleList +:members: +:undoc-members: +``` + +**Example:** + +```cpp +torch::nn::ModuleList layers; +layers->push_back(torch::nn::Linear(10, 20)); +layers->push_back(torch::nn::Linear(20, 30)); + +torch::Tensor x = input; +for (const auto& layer : *layers) { + x = layer->as()->forward(x); +} +``` + +## ModuleDict + +`ModuleDict` stores modules in a dictionary for named access. Useful when you +need to select modules by name at runtime. + +```{doxygenclass} torch::nn::ModuleDict +:members: +:undoc-members: +``` + +## ParameterList + +`ParameterList` stores parameters directly without wrapping them in modules. + +```{doxygenclass} torch::nn::ParameterList +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ParameterListImpl +:members: +:undoc-members: +``` + +## ParameterDict + +`ParameterDict` stores parameters in a dictionary for named access. + +```{doxygenclass} torch::nn::ParameterDict +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ParameterDictImpl +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/nn/convolution.md b/docs/cpp/source/api/nn/convolution.md new file mode 100644 index 0000000000000..2c2addf3ce266 --- /dev/null +++ b/docs/cpp/source/api/nn/convolution.md @@ -0,0 +1,134 @@ +--- +myst: + html_meta: + description: Convolution layers in PyTorch C++ — Conv1d, Conv2d, Conv3d, and transposed convolutions. + keywords: PyTorch, C++, convolution, Conv1d, Conv2d, Conv3d, ConvTranspose, nn +--- + +# Convolution Layers + +Convolutional layers apply learnable filters to input data, extracting local features +through sliding window operations. They are fundamental to CNNs for image, audio, and +sequential data processing. + +- **Conv1d/2d/3d**: Standard convolution for 1D sequences, 2D images, or 3D volumes +- **ConvTranspose1d/2d/3d**: Transposed convolution (deconvolution) for upsampling + +**Key parameters:** + +- `in_channels`: Number of input channels +- `out_channels`: Number of output channels (number of filters) +- `kernel_size`: Size of the convolving kernel +- `stride`: Stride of the convolution (default: 1) +- `padding`: Zero-padding added to input (default: 0) +- `dilation`: Spacing between kernel elements (default: 1) +- `groups`: Number of blocked connections (default: 1, use `in_channels` for depthwise) + +## Conv1d + +Applies 1D convolution over an input signal composed of several input planes. + +```{doxygenclass} torch::nn::Conv1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::Conv1dImpl +:members: +:undoc-members: +``` + +## Conv2d + +Applies 2D convolution over an input image. The most commonly used layer for +image processing tasks. + +```{doxygenclass} torch::nn::Conv2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::Conv2dImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +// Create Conv2d: 3 input channels, 64 output channels, 3x3 kernel +auto conv = torch::nn::Conv2d( + torch::nn::Conv2dOptions(3, 64, 3) + .stride(1) + .padding(1) + .bias(true)); + +auto output = conv->forward(input); // input: [N, 3, H, W] +``` + +## Conv3d + +Applies 3D convolution over an input volume (e.g., video frames or 3D medical images). + +```{doxygenclass} torch::nn::Conv3d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::Conv3dImpl +:members: +:undoc-members: +``` + +## ConvTranspose1d + +Applies 1D transposed convolution (fractionally-strided convolution) for upsampling. + +```{doxygenclass} torch::nn::ConvTranspose1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ConvTranspose1dImpl +:members: +:undoc-members: +``` + +## ConvTranspose2d + +Applies 2D transposed convolution for upsampling. Commonly used in decoder +networks and generative models. + +```{doxygenclass} torch::nn::ConvTranspose2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ConvTranspose2dImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +// Create ConvTranspose2d for upsampling +auto conv_transpose = torch::nn::ConvTranspose2d( + torch::nn::ConvTranspose2dOptions(64, 32, 4) + .stride(2) + .padding(1)); +``` + +## ConvTranspose3d + +Applies 3D transposed convolution for upsampling volumetric data. + +```{doxygenclass} torch::nn::ConvTranspose3d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ConvTranspose3dImpl +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/nn/dropout.md b/docs/cpp/source/api/nn/dropout.md new file mode 100644 index 0000000000000..0777452ab9f34 --- /dev/null +++ b/docs/cpp/source/api/nn/dropout.md @@ -0,0 +1,86 @@ +--- +myst: + html_meta: + description: Dropout layers in PyTorch C++ — Dropout, Dropout2d, Dropout3d, AlphaDropout, and FeatureAlphaDropout. + keywords: PyTorch, C++, Dropout, Dropout2d, Dropout3d, AlphaDropout, regularization +--- + +# Dropout Layers + +Dropout randomly zeros elements during training as a regularization technique, +preventing overfitting by forcing the network to learn redundant representations. +During evaluation, dropout is disabled and outputs are scaled appropriately. + +- **Dropout**: Standard dropout for fully-connected layers +- **Dropout2d/3d**: Spatial dropout that zeros entire channels (better for CNNs) +- **AlphaDropout**: Maintains self-normalizing property (use with SELU activation) + +```{note} + +Remember to call `model->train()` during training and `model->eval()` during +inference to properly enable/disable dropout behavior. +``` + +## Dropout + +```{doxygenclass} torch::nn::Dropout +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::DropoutImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto dropout = torch::nn::Dropout(torch::nn::DropoutOptions(0.5)); +``` + +## Dropout2d / Dropout3d + +```{doxygenclass} torch::nn::Dropout2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::Dropout2dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::Dropout3d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::Dropout3dImpl +:members: +:undoc-members: +``` + +## AlphaDropout + +```{doxygenclass} torch::nn::AlphaDropout +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AlphaDropoutImpl +:members: +:undoc-members: +``` + +## FeatureAlphaDropout + +```{doxygenclass} torch::nn::FeatureAlphaDropout +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::FeatureAlphaDropoutImpl +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/nn/embedding.md b/docs/cpp/source/api/nn/embedding.md new file mode 100644 index 0000000000000..f766a4ea17bfc --- /dev/null +++ b/docs/cpp/source/api/nn/embedding.md @@ -0,0 +1,55 @@ +--- +myst: + html_meta: + description: Embedding layers in PyTorch C++ — Embedding and EmbeddingBag for sparse and dense lookups. + keywords: PyTorch, C++, Embedding, EmbeddingBag, lookup table, sparse, NLP +--- + +# Embedding Layers + +Embedding layers map discrete tokens (words, categories, IDs) to dense vector +representations. They are the foundation of NLP models and recommendation systems. + +- **Embedding**: Standard lookup table that maps indices to dense vectors +- **EmbeddingBag**: Computes sums or means of embeddings (efficient for sparse features) + +**Key parameters:** + +- `num_embeddings`: Size of the vocabulary (number of unique tokens) +- `embedding_dim`: Dimension of each embedding vector +- `padding_idx`: Index that outputs zeros (useful for padding tokens) + +## Embedding + +```{doxygenclass} torch::nn::Embedding +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::EmbeddingImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto embedding = torch::nn::Embedding( + torch::nn::EmbeddingOptions(10000, 256) // num_embeddings, embedding_dim + .padding_idx(0)); + +auto indices = torch::tensor({1, 2, 3, 4}); +auto embedded = embedding->forward(indices); // [4, 256] +``` + +## EmbeddingBag + +```{doxygenclass} torch::nn::EmbeddingBag +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::EmbeddingBagImpl +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/nn/functional.md b/docs/cpp/source/api/nn/functional.md new file mode 100644 index 0000000000000..0ace995624041 --- /dev/null +++ b/docs/cpp/source/api/nn/functional.md @@ -0,0 +1,363 @@ +--- +myst: + html_meta: + description: Functional API in PyTorch C++ — torch::nn::functional stateless operations for neural networks. + keywords: PyTorch, C++, functional, torch::nn::functional, relu, conv2d, linear, softmax +--- + +# Functional API + +The `torch::nn::functional` namespace provides stateless versions of neural +network operations. Unlike module classes, functional operations do not hold +learnable parameters — you pass weights explicitly. + +**When to use functional vs modules:** + +- Use **modules** (`torch::nn::Conv2d`) when you need learnable parameters + managed automatically (training, saving, loading). +- Use **functional** (`torch::nn::functional::conv2d`) when you already have + weights as tensors, or for operations without parameters (e.g., `relu`). + +```cpp +#include +namespace F = torch::nn::functional; + +// Stateless activation — no module needed +auto output = F::relu(input); + +// Convolution with explicit weight tensor +auto output = F::conv2d(input, weight, F::Conv2dFuncOptions().stride(1).padding(1)); + +// Softmax along a dimension +auto probs = F::softmax(logits, F::SoftmaxFuncOptions(/*dim=*/1)); +``` + +## Activation Functions + +```{doxygenfunction} torch::nn::functional::elu +``` +```{doxygenfunction} torch::nn::functional::selu +``` +```{doxygenfunction} torch::nn::functional::hardshrink +``` +```{doxygenfunction} torch::nn::functional::hardtanh +``` +```{doxygenfunction} torch::nn::functional::leaky_relu +``` +```{doxygenfunction} torch::nn::functional::logsigmoid +``` +```{doxygenfunction} torch::nn::functional::glu +``` +```{doxygenfunction} torch::nn::functional::gelu +``` +```{doxygenfunction} torch::nn::functional::silu +``` +```{doxygenfunction} torch::nn::functional::mish +``` +```{doxygenfunction} torch::nn::functional::prelu +``` +```{doxygenfunction} torch::nn::functional::relu +``` +```{doxygenfunction} torch::nn::functional::relu6 +``` +```{doxygenfunction} torch::nn::functional::rrelu +``` +```{doxygenfunction} torch::nn::functional::celu +``` +```{doxygenfunction} torch::nn::functional::softplus +``` +```{doxygenfunction} torch::nn::functional::softshrink +``` +```{doxygenfunction} torch::nn::functional::softsign +``` +```{doxygenfunction} torch::nn::functional::tanhshrink +``` +```{doxygenfunction} torch::nn::functional::threshold +``` +```{doxygenfunction} torch::nn::functional::softmax +``` +```{doxygenfunction} torch::nn::functional::softmin +``` +```{doxygenfunction} torch::nn::functional::log_softmax +``` +```{doxygenfunction} torch::nn::functional::gumbel_softmax +``` + +## Convolution Functions + +```{doxygenfunction} torch::nn::functional::conv1d +``` +```{doxygenfunction} torch::nn::functional::conv2d +``` +```{doxygenfunction} torch::nn::functional::conv3d +``` +```{doxygenfunction} torch::nn::functional::conv_transpose1d +``` +```{doxygenfunction} torch::nn::functional::conv_transpose2d +``` +```{doxygenfunction} torch::nn::functional::conv_transpose3d +``` + +## Pooling Functions + +```{doxygenfunction} torch::nn::functional::avg_pool1d +``` +```{doxygenfunction} torch::nn::functional::avg_pool2d +``` +```{doxygenfunction} torch::nn::functional::avg_pool3d +``` +```{doxygenfunction} torch::nn::functional::max_pool1d +``` +```{doxygenfunction} torch::nn::functional::max_pool2d +``` +```{doxygenfunction} torch::nn::functional::max_pool3d +``` +```{doxygenfunction} torch::nn::functional::max_pool1d_with_indices +``` +```{doxygenfunction} torch::nn::functional::max_pool2d_with_indices +``` +```{doxygenfunction} torch::nn::functional::max_pool3d_with_indices +``` +```{doxygenfunction} torch::nn::functional::adaptive_max_pool1d +``` +```{doxygenfunction} torch::nn::functional::adaptive_max_pool2d +``` +```{doxygenfunction} torch::nn::functional::adaptive_max_pool3d +``` +```{doxygenfunction} torch::nn::functional::adaptive_avg_pool1d +``` +```{doxygenfunction} torch::nn::functional::adaptive_avg_pool2d +``` +```{doxygenfunction} torch::nn::functional::adaptive_avg_pool3d +``` +```{doxygenfunction} torch::nn::functional::max_unpool1d +``` +```{doxygenfunction} torch::nn::functional::max_unpool2d +``` +```{doxygenfunction} torch::nn::functional::max_unpool3d +``` +```{doxygenfunction} torch::nn::functional::fractional_max_pool2d +``` +```{doxygenfunction} torch::nn::functional::fractional_max_pool3d +``` +```{doxygenfunction} torch::nn::functional::lp_pool1d +``` +```{doxygenfunction} torch::nn::functional::lp_pool2d +``` +```{doxygenfunction} torch::nn::functional::lp_pool3d +``` + +## Linear Functions + +```{doxygenfunction} torch::nn::functional::linear +``` +```{doxygenfunction} torch::nn::functional::bilinear +``` + +## Dropout Functions + +```{doxygenfunction} torch::nn::functional::dropout +``` +```{doxygenfunction} torch::nn::functional::dropout2d +``` +```{doxygenfunction} torch::nn::functional::dropout3d +``` +```{doxygenfunction} torch::nn::functional::alpha_dropout +``` +```{doxygenfunction} torch::nn::functional::feature_alpha_dropout +``` + +## Embedding Functions + +```{doxygenfunction} torch::nn::functional::one_hot +``` +```{doxygenfunction} torch::nn::functional::embedding +``` +```{doxygenfunction} torch::nn::functional::embedding_bag +``` + +## Normalization Functions + +```{doxygenfunction} torch::nn::functional::batch_norm +``` +```{doxygenfunction} torch::nn::functional::instance_norm +``` +```{doxygenfunction} torch::nn::functional::layer_norm +``` +```{doxygenfunction} torch::nn::functional::group_norm +``` +```{doxygenfunction} torch::nn::functional::local_response_norm +``` +```{doxygenfunction} torch::nn::functional::normalize +``` + +## Loss Functions + +```{doxygenfunction} torch::nn::functional::l1_loss +``` +```{doxygenfunction} torch::nn::functional::mse_loss +``` +```{doxygenfunction} torch::nn::functional::binary_cross_entropy +``` +```{doxygenfunction} torch::nn::functional::binary_cross_entropy_with_logits +``` +```{doxygenfunction} torch::nn::functional::cross_entropy +``` +```{doxygenfunction} torch::nn::functional::nll_loss +``` +```{doxygenfunction} torch::nn::functional::kl_div +``` +```{doxygenfunction} torch::nn::functional::smooth_l1_loss(const Tensor& input, const Tensor& target, const SmoothL1LossFuncOptions& options) +``` +```{doxygenfunction} torch::nn::functional::huber_loss +``` +```{doxygenfunction} torch::nn::functional::hinge_embedding_loss +``` +```{doxygenfunction} torch::nn::functional::multi_margin_loss +``` +```{doxygenfunction} torch::nn::functional::cosine_embedding_loss +``` +```{doxygenfunction} torch::nn::functional::margin_ranking_loss +``` +```{doxygenfunction} torch::nn::functional::multilabel_margin_loss +``` +```{doxygenfunction} torch::nn::functional::soft_margin_loss +``` +```{doxygenfunction} torch::nn::functional::multilabel_soft_margin_loss +``` +```{doxygenfunction} torch::nn::functional::triplet_margin_loss +``` +```{doxygenfunction} torch::nn::functional::triplet_margin_with_distance_loss +``` +```{doxygenfunction} torch::nn::functional::ctc_loss +``` +```{doxygenfunction} torch::nn::functional::poisson_nll_loss +``` + +## Distance Functions + +```{doxygenfunction} torch::nn::functional::cosine_similarity +``` +```{doxygenfunction} torch::nn::functional::pairwise_distance +``` +```{doxygenfunction} torch::nn::functional::pdist +``` + +## Vision Functions + +```{doxygenfunction} torch::nn::functional::interpolate +``` +```{doxygenfunction} torch::nn::functional::affine_grid +``` +```{doxygenfunction} torch::nn::functional::grid_sample +``` +```{doxygenfunction} torch::nn::functional::pad +``` +```{doxygenfunction} torch::nn::functional::pixel_shuffle +``` +```{doxygenfunction} torch::nn::functional::pixel_unshuffle +``` + +## Fold/Unfold + +```{doxygenfunction} torch::nn::functional::fold +``` +```{doxygenfunction} torch::nn::functional::unfold +``` + +## Functional Options Structs + +Each functional operation that takes configuration uses a corresponding options +struct. The naming convention is `FuncOptions`. + +**Activation Options:** + +```{doxygentypedef} torch::nn::functional::ELUFuncOptions +``` +```{doxygentypedef} torch::nn::functional::SELUFuncOptions +``` +```{doxygentypedef} torch::nn::functional::GLUFuncOptions +``` +```{doxygentypedef} torch::nn::functional::GELUFuncOptions +``` +```{doxygentypedef} torch::nn::functional::HardshrinkFuncOptions +``` +```{doxygentypedef} torch::nn::functional::HardtanhFuncOptions +``` +```{doxygentypedef} torch::nn::functional::LeakyReLUFuncOptions +``` +```{doxygentypedef} torch::nn::functional::ReLUFuncOptions +``` +```{doxygentypedef} torch::nn::functional::ReLU6FuncOptions +``` +```{doxygentypedef} torch::nn::functional::CELUFuncOptions +``` +```{doxygentypedef} torch::nn::functional::SoftplusFuncOptions +``` +```{doxygentypedef} torch::nn::functional::SoftshrinkFuncOptions +``` +```{doxygentypedef} torch::nn::functional::ThresholdFuncOptions +``` + +**Convolution Options:** + +```{doxygentypedef} torch::nn::functional::Conv1dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::Conv2dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::Conv3dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::ConvTranspose1dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::ConvTranspose2dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::ConvTranspose3dFuncOptions +``` + +**Pooling Options:** + +```{doxygentypedef} torch::nn::functional::AvgPool1dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::AvgPool2dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::AvgPool3dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::MaxPool1dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::MaxPool2dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::MaxPool3dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::AdaptiveMaxPool1dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::AdaptiveMaxPool2dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::AdaptiveMaxPool3dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::AdaptiveAvgPool1dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::AdaptiveAvgPool2dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::AdaptiveAvgPool3dFuncOptions +``` + +**Other Options:** + +```{doxygentypedef} torch::nn::functional::CosineSimilarityFuncOptions +``` +```{doxygentypedef} torch::nn::functional::PairwiseDistanceFuncOptions +``` +```{doxygentypedef} torch::nn::functional::Dropout2dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::Dropout3dFuncOptions +``` +```{doxygentypedef} torch::nn::functional::L1LossFuncOptions +``` +```{doxygentypedef} torch::nn::functional::FoldFuncOptions +``` +```{doxygentypedef} torch::nn::functional::UnfoldFuncOptions +``` +```{doxygentypedef} torch::nn::functional::PixelShuffleFuncOptions +``` +```{doxygentypedef} torch::nn::functional::PixelUnshuffleFuncOptions +``` diff --git a/docs/cpp/source/api/nn/index.md b/docs/cpp/source/api/nn/index.md new file mode 100644 index 0000000000000..c5bf65125e12b --- /dev/null +++ b/docs/cpp/source/api/nn/index.md @@ -0,0 +1,94 @@ +--- +myst: + html_meta: + description: PyTorch C++ neural network modules — torch::nn API for defining and training models. + keywords: PyTorch, C++, nn, Module, neural network, torch::nn +--- + +# Neural Network Modules (torch::nn) + +The `torch::nn` namespace provides neural network building blocks that mirror +Python's `torch.nn` module. It uses a PIMPL (Pointer to Implementation) pattern +where user-facing classes like `Conv2d` wrap internal `Conv2dImpl` classes. + +**When to use torch::nn:** + +- Building neural network models in C++ +- Creating custom layers and modules +- Porting Python models to C++ for production inference +- Training models entirely in C++ + +**Basic usage:** + +```cpp +#include + +// Define a simple model +struct Net : torch::nn::Module { + torch::nn::Conv2d conv1{nullptr}; + torch::nn::Linear fc1{nullptr}; + + Net() { + conv1 = register_module("conv1", torch::nn::Conv2d( + torch::nn::Conv2dOptions(1, 32, 3).stride(1).padding(1))); + fc1 = register_module("fc1", torch::nn::Linear(32 * 28 * 28, 10)); + } + + torch::Tensor forward(torch::Tensor x) { + x = torch::relu(conv1->forward(x)); + x = x.view({-1, 32 * 28 * 28}); + return fc1->forward(x); + } +}; + +// Create and use the model +auto model = std::make_shared(); +auto input = torch::randn({1, 1, 28, 28}); +auto output = model->forward(input); +``` + +## Header Files + +- `torch/nn.h` - Main neural network header (includes all modules) +- `torch/nn/module.h` - Base Module class +- `torch/nn/modules.h` - All module implementations +- `torch/nn/options.h` - Options structs for modules +- `torch/nn/functional.h` - Functional API + +## Module Base Class + +All neural network modules inherit from `torch::nn::Module`, which provides +parameter management, serialization, device/dtype conversion, and hooks. + +```{doxygenclass} torch::nn::Module +``` + +**Key features:** + +- `register_module()`: Register submodules for parameter tracking +- `register_parameter()`: Register learnable parameters +- `register_buffer()`: Register non-learnable state (e.g., running mean) +- `parameters()` / `named_parameters()`: Iterate over all parameters +- `to()`: Move module to a device or convert dtype +- `train()` / `eval()`: Toggle training/evaluation mode +- `save()` / `load()`: Serialize and deserialize module state + +## Module Categories + +```{toctree} +:maxdepth: 1 + +containers +convolution +pooling +linear +activation +normalization +dropout +embedding +recurrent +transformer +loss +functional +utilities +``` diff --git a/docs/cpp/source/api/nn/linear.md b/docs/cpp/source/api/nn/linear.md new file mode 100644 index 0000000000000..8aa0c55a0ba4a --- /dev/null +++ b/docs/cpp/source/api/nn/linear.md @@ -0,0 +1,84 @@ +--- +myst: + html_meta: + description: Linear layers in PyTorch C++ — Linear, Bilinear, and Flatten modules. + keywords: PyTorch, C++, Linear, Bilinear, Flatten, fully connected, dense layer +--- + +# Linear Layers + +Linear layers apply affine transformations to input data: `y = xW^T + b`. +They are the building blocks of fully-connected networks and are used for +feature transformation, classification heads, and projection layers. + +- **Linear**: Standard fully-connected layer +- **Bilinear**: Bilinear transformation of two inputs +- **Identity**: Pass-through layer (useful for skip connections) +- **Flatten/Unflatten**: Reshape tensors between convolutional and linear layers + +## Linear + +```{doxygenclass} torch::nn::Linear +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LinearImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto linear = torch::nn::Linear(torch::nn::LinearOptions(784, 256).bias(true)); +auto output = linear->forward(input); // input: [N, 784] +``` + +## Bilinear + +```{doxygenclass} torch::nn::Bilinear +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::BilinearImpl +:members: +:undoc-members: +``` + +## Identity + +```{doxygenclass} torch::nn::Identity +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::IdentityImpl +:members: +:undoc-members: +``` + +## Flatten + +```{doxygenclass} torch::nn::Flatten +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::FlattenImpl +:members: +:undoc-members: +``` + +## Unflatten + +```{doxygenclass} torch::nn::Unflatten +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::UnflattenImpl +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/nn/loss.md b/docs/cpp/source/api/nn/loss.md new file mode 100644 index 0000000000000..8f1ab82bc234d --- /dev/null +++ b/docs/cpp/source/api/nn/loss.md @@ -0,0 +1,284 @@ +--- +myst: + html_meta: + description: Loss functions in PyTorch C++ — CrossEntropyLoss, MSELoss, NLLLoss, BCELoss, and more. + keywords: PyTorch, C++, loss, CrossEntropyLoss, MSELoss, NLLLoss, BCELoss, L1Loss +--- + +# Loss Functions + +Loss functions measure how well the model's predictions match the targets. +The choice of loss function depends on your task type and data characteristics. + +**Regression losses:** + +- **L1Loss/MSELoss**: Basic regression losses (MAE vs MSE) +- **SmoothL1Loss/HuberLoss**: Robust to outliers + +**Classification losses:** + +- **CrossEntropyLoss**: Multi-class classification (combines LogSoftmax + NLLLoss) +- **NLLLoss**: Negative log likelihood (use with LogSoftmax output) +- **BCELoss/BCEWithLogitsLoss**: Binary classification + +**Specialized losses:** + +- **CTCLoss**: Sequence-to-sequence without alignment (speech recognition) +- **TripletMarginLoss**: Metric learning (similarity/embedding tasks) +- **CosineEmbeddingLoss**: Similarity learning with cosine distance + +## L1Loss + +```{doxygenclass} torch::nn::L1Loss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::L1LossImpl +:members: +:undoc-members: +``` + +## MSELoss + +```{doxygenclass} torch::nn::MSELoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::MSELossImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto loss_fn = torch::nn::MSELoss(); +auto loss = loss_fn->forward(predictions, targets); +``` + +## CrossEntropyLoss + +```{doxygenclass} torch::nn::CrossEntropyLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::CrossEntropyLossImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto loss_fn = torch::nn::CrossEntropyLoss(); +auto logits = torch::randn({32, 10}); // [batch, num_classes] +auto targets = torch::randint(0, 10, {32}); // [batch] +auto loss = loss_fn->forward(logits, targets); +``` + +## NLLLoss + +```{doxygenclass} torch::nn::NLLLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::NLLLossImpl +:members: +:undoc-members: +``` + +## BCELoss + +```{doxygenclass} torch::nn::BCELoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::BCELossImpl +:members: +:undoc-members: +``` + +## BCEWithLogitsLoss + +```{doxygenclass} torch::nn::BCEWithLogitsLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::BCEWithLogitsLossImpl +:members: +:undoc-members: +``` + +## HuberLoss + +```{doxygenclass} torch::nn::HuberLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::HuberLossImpl +:members: +:undoc-members: +``` + +## SmoothL1Loss + +```{doxygenclass} torch::nn::SmoothL1Loss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::SmoothL1LossImpl +:members: +:undoc-members: +``` + +## KLDivLoss + +```{doxygenclass} torch::nn::KLDivLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::KLDivLossImpl +:members: +:undoc-members: +``` + +## CTCLoss + +```{doxygenclass} torch::nn::CTCLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::CTCLossImpl +:members: +:undoc-members: +``` + +## PoissonNLLLoss + +```{doxygenclass} torch::nn::PoissonNLLLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::PoissonNLLLossImpl +:members: +:undoc-members: +``` + +## MarginRankingLoss + +```{doxygenclass} torch::nn::MarginRankingLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::MarginRankingLossImpl +:members: +:undoc-members: +``` + +## HingeEmbeddingLoss + +```{doxygenclass} torch::nn::HingeEmbeddingLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::HingeEmbeddingLossImpl +:members: +:undoc-members: +``` + +## CosineEmbeddingLoss + +```{doxygenclass} torch::nn::CosineEmbeddingLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::CosineEmbeddingLossImpl +:members: +:undoc-members: +``` + +## MultiMarginLoss + +```{doxygenclass} torch::nn::MultiMarginLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::MultiMarginLossImpl +:members: +:undoc-members: +``` + +## MultiLabelMarginLoss + +```{doxygenclass} torch::nn::MultiLabelMarginLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::MultiLabelMarginLossImpl +:members: +:undoc-members: +``` + +## MultiLabelSoftMarginLoss + +```{doxygenclass} torch::nn::MultiLabelSoftMarginLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::MultiLabelSoftMarginLossImpl +:members: +:undoc-members: +``` + +## SoftMarginLoss + +```{doxygenclass} torch::nn::SoftMarginLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::SoftMarginLossImpl +:members: +:undoc-members: +``` + +## TripletMarginLoss + +```{doxygenclass} torch::nn::TripletMarginLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::TripletMarginLossImpl +:members: +:undoc-members: +``` + +## TripletMarginWithDistanceLoss + +```{doxygenclass} torch::nn::TripletMarginWithDistanceLoss +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::TripletMarginWithDistanceLossImpl +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/nn/normalization.md b/docs/cpp/source/api/nn/normalization.md new file mode 100644 index 0000000000000..3ccfe99957e06 --- /dev/null +++ b/docs/cpp/source/api/nn/normalization.md @@ -0,0 +1,142 @@ +--- +myst: + html_meta: + description: Normalization layers in PyTorch C++ — BatchNorm, LayerNorm, GroupNorm, InstanceNorm, and LocalResponseNorm. + keywords: PyTorch, C++, normalization, BatchNorm, LayerNorm, GroupNorm, InstanceNorm +--- + +# Normalization Layers + +Normalization layers stabilize and accelerate training by normalizing intermediate +activations. They help with gradient flow and allow higher learning rates. + +- **BatchNorm**: Normalizes across batch dimension; most common in CNNs +- **InstanceNorm**: Normalizes each sample independently; popular in style transfer +- **LayerNorm**: Normalizes across feature dimension; standard in transformers +- **GroupNorm**: Normalizes within groups of channels; works with small batches +- **LocalResponseNorm**: Lateral inhibition inspired by neuroscience (less common today) + +## BatchNorm1d / BatchNorm2d / BatchNorm3d + +```{doxygenclass} torch::nn::BatchNorm1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::BatchNorm1dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::BatchNorm2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::BatchNorm2dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::BatchNorm3d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::BatchNorm3dImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto bn = torch::nn::BatchNorm2d( + torch::nn::BatchNorm2dOptions(64) // num_features + .eps(1e-5) + .momentum(0.1) + .affine(true) + .track_running_stats(true)); +``` + +## InstanceNorm1d / InstanceNorm2d / InstanceNorm3d + +```{doxygenclass} torch::nn::InstanceNorm1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::InstanceNorm1dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::InstanceNorm2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::InstanceNorm2dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::InstanceNorm3d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::InstanceNorm3dImpl +:members: +:undoc-members: +``` + +## LayerNorm + +```{doxygenclass} torch::nn::LayerNorm +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LayerNormImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto ln = torch::nn::LayerNorm( + torch::nn::LayerNormOptions({768})); // normalized_shape +``` + +## GroupNorm + +```{doxygenclass} torch::nn::GroupNorm +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::GroupNormImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto gn = torch::nn::GroupNorm( + torch::nn::GroupNormOptions(32, 256)); // num_groups, num_channels +``` + +## LocalResponseNorm + +```{doxygenclass} torch::nn::LocalResponseNorm +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LocalResponseNormImpl +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/nn/pooling.md b/docs/cpp/source/api/nn/pooling.md new file mode 100644 index 0000000000000..679b24833d7ab --- /dev/null +++ b/docs/cpp/source/api/nn/pooling.md @@ -0,0 +1,262 @@ +--- +myst: + html_meta: + description: Pooling layers in PyTorch C++ — MaxPool, AvgPool, AdaptiveMaxPool, AdaptiveAvgPool, and LPPool. + keywords: PyTorch, C++, pooling, MaxPool2d, AvgPool2d, AdaptiveAvgPool2d, max pooling +--- + +# Pooling Layers + +Pooling layers reduce spatial dimensions by aggregating values in local regions, +providing translation invariance and reducing computational cost in deeper layers. + +- **MaxPool**: Takes the maximum value in each pooling window (preserves strong features) +- **AvgPool**: Takes the average value in each pooling window (smoother downsampling) +- **AdaptivePool**: Automatically calculates kernel size to produce a target output size +- **FractionalMaxPool**: Randomized pooling with fractional output size +- **MaxUnpool**: Computes the partial inverse of MaxPool using stored indices +- **LPPool**: Power-average pooling (generalization of avg/max pooling) + +## MaxPool1d / MaxPool2d / MaxPool3d + +```{doxygenclass} torch::nn::MaxPool1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::MaxPool1dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::MaxPool2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::MaxPool2dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::MaxPool3d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::MaxPool3dImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto pool = torch::nn::MaxPool2d( + torch::nn::MaxPool2dOptions(2).stride(2)); +``` + +## AvgPool1d / AvgPool2d / AvgPool3d + +```{doxygenclass} torch::nn::AvgPool1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AvgPool1dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AvgPool2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AvgPool2dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AvgPool3d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AvgPool3dImpl +:members: +:undoc-members: +``` + +## AdaptiveAvgPool1d / AdaptiveAvgPool2d / AdaptiveAvgPool3d + +```{doxygenclass} torch::nn::AdaptiveAvgPool1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AdaptiveAvgPool1dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AdaptiveAvgPool2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AdaptiveAvgPool2dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AdaptiveAvgPool3d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AdaptiveAvgPool3dImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +// Output will always be 7x7 regardless of input size +auto adaptive_pool = torch::nn::AdaptiveAvgPool2d( + torch::nn::AdaptiveAvgPool2dOptions({7, 7})); +``` + +## AdaptiveMaxPool1d / AdaptiveMaxPool2d / AdaptiveMaxPool3d + +```{doxygenclass} torch::nn::AdaptiveMaxPool1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AdaptiveMaxPool1dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AdaptiveMaxPool2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AdaptiveMaxPool2dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AdaptiveMaxPool3d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::AdaptiveMaxPool3dImpl +:members: +:undoc-members: +``` + +## FractionalMaxPool2d / FractionalMaxPool3d + +```{doxygenclass} torch::nn::FractionalMaxPool2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::FractionalMaxPool2dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::FractionalMaxPool3d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::FractionalMaxPool3dImpl +:members: +:undoc-members: +``` + +## MaxUnpool1d / MaxUnpool2d / MaxUnpool3d + +Computes a partial inverse of `MaxPool`, using the indices of the maximum +values computed during pooling to place values back into unpooled positions. + +```{doxygenclass} torch::nn::MaxUnpool1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::MaxUnpool1dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::MaxUnpool2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::MaxUnpool2dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::MaxUnpool3d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::MaxUnpool3dImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto pool = torch::nn::MaxPool2d( + torch::nn::MaxPool2dOptions(2).stride(2).return_indices(true)); +auto unpool = torch::nn::MaxUnpool2d( + torch::nn::MaxUnpoolOptions<2>(2).stride(2)); + +auto [output, indices] = pool->forward_with_indices(input); +auto reconstructed = unpool->forward(output, indices); +``` + +## LPPool1d / LPPool2d / LPPool3d + +```{doxygenclass} torch::nn::LPPool1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LPPool1dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LPPool2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LPPool2dImpl +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LPPool3d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LPPool3dImpl +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/nn/recurrent.md b/docs/cpp/source/api/nn/recurrent.md new file mode 100644 index 0000000000000..f286ebb6b52b5 --- /dev/null +++ b/docs/cpp/source/api/nn/recurrent.md @@ -0,0 +1,125 @@ +--- +myst: + html_meta: + description: Recurrent layers in PyTorch C++ — RNN, LSTM, and GRU modules for sequence modeling. + keywords: PyTorch, C++, RNN, LSTM, GRU, recurrent, sequence, hidden state +--- + +# Recurrent Layers + +Recurrent layers process sequential data by maintaining hidden state across time steps. +They are essential for tasks involving sequences: language modeling, speech recognition, +time series prediction, and more. + +- **RNN**: Basic recurrent layer (simple but prone to vanishing gradients) +- **LSTM**: Long Short-Term Memory (gated architecture, handles long-range dependencies) +- **GRU**: Gated Recurrent Unit (simpler than LSTM, often similar performance) +- **Cell variants**: Single-step versions for custom loop implementations + +**Key parameters:** + +- `input_size`: Number of features in input +- `hidden_size`: Number of features in hidden state +- `num_layers`: Number of stacked recurrent layers +- `batch_first`: If true, input shape is `[batch, seq, features]` +- `bidirectional`: Process sequence in both directions + +## RNN + +```{doxygenclass} torch::nn::RNN +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::RNNImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto rnn = torch::nn::RNN( + torch::nn::RNNOptions(128, 256) // input_size, hidden_size + .num_layers(2) + .batch_first(true) + .bidirectional(false)); + +auto input = torch::randn({32, 10, 128}); // [batch, seq_len, input_size] +auto [output, hidden] = rnn->forward(input); +``` + +## LSTM + +```{doxygenclass} torch::nn::LSTM +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LSTMImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto lstm = torch::nn::LSTM( + torch::nn::LSTMOptions(128, 256) + .num_layers(2) + .batch_first(true) + .dropout(0.1) + .bidirectional(true)); + +auto input = torch::randn({32, 10, 128}); +auto [output, state] = lstm->forward(input); +auto [h_n, c_n] = state; // hidden state, cell state +``` + +## GRU + +```{doxygenclass} torch::nn::GRU +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::GRUImpl +:members: +:undoc-members: +``` + +## RNNCell + +```{doxygenclass} torch::nn::RNNCell +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::RNNCellImpl +:members: +:undoc-members: +``` + +## LSTMCell + +```{doxygenclass} torch::nn::LSTMCell +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::LSTMCellImpl +:members: +:undoc-members: +``` + +## GRUCell + +```{doxygenclass} torch::nn::GRUCell +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::GRUCellImpl +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/nn/transformer.md b/docs/cpp/source/api/nn/transformer.md new file mode 100644 index 0000000000000..664564d4b3519 --- /dev/null +++ b/docs/cpp/source/api/nn/transformer.md @@ -0,0 +1,112 @@ +--- +myst: + html_meta: + description: Transformer layers in PyTorch C++ — Transformer, TransformerEncoder, TransformerDecoder, and MultiheadAttention. + keywords: PyTorch, C++, Transformer, TransformerEncoder, TransformerDecoder, MultiheadAttention, attention +--- + +# Transformer Layers + +Transformer layers use self-attention mechanisms to process sequences in parallel, +enabling efficient training on long sequences. They are the foundation of modern +NLP models (BERT, GPT) and increasingly used in vision and other domains. + +- **Transformer**: Complete encoder-decoder architecture +- **TransformerEncoder/Decoder**: Standalone encoder or decoder stacks +- **TransformerEncoderLayer/DecoderLayer**: Individual transformer blocks +- **MultiheadAttention**: Core attention mechanism used throughout + +**Key parameters:** + +- `d_model`: Dimension of the model (embedding dimension) +- `nhead`: Number of attention heads +- `num_encoder_layers/num_decoder_layers`: Number of stacked layers +- `dim_feedforward`: Dimension of feedforward network +- `dropout`: Dropout rate for regularization + +## Transformer + +Complete encoder-decoder transformer architecture. + +```{doxygenclass} torch::nn::Transformer +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::TransformerImpl +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto transformer = torch::nn::Transformer( + torch::nn::TransformerOptions() + .d_model(512) + .nhead(8) + .num_encoder_layers(6) + .num_decoder_layers(6) + .dim_feedforward(2048) + .dropout(0.1)); +``` + +## TransformerEncoder + +Stack of encoder layers for processing source sequences. + +```{doxygenclass} torch::nn::TransformerEncoder +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::TransformerEncoderImpl +:members: +:undoc-members: +``` + +## TransformerDecoder + +Stack of decoder layers for generating target sequences. + +```{doxygenclass} torch::nn::TransformerDecoder +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::TransformerDecoderImpl +:members: +:undoc-members: +``` + +## TransformerEncoderLayer + +Single encoder layer with self-attention and feedforward network. + +```{doxygenclass} torch::nn::TransformerEncoderLayerImpl +:members: +:undoc-members: +``` + +## TransformerDecoderLayer + +Single decoder layer with self-attention, cross-attention, and feedforward network. + +```{doxygenclass} TransformerDecoderLayerImpl +:members: +:undoc-members: +``` + +## MultiheadAttention + +Scaled dot-product attention with multiple parallel heads. + +```{doxygenclass} torch::nn::MultiheadAttention +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::MultiheadAttentionImpl +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/nn/utilities.md b/docs/cpp/source/api/nn/utilities.md new file mode 100644 index 0000000000000..6bde6de870a0d --- /dev/null +++ b/docs/cpp/source/api/nn/utilities.md @@ -0,0 +1,261 @@ +--- +myst: + html_meta: + description: Neural network utilities in PyTorch C++ — parameter initialization, module cloning, padding layers, and vision utilities. + keywords: PyTorch, C++, nn utilities, init, Cloneable, AnyModule, padding, PixelShuffle, Upsample +--- + +# Utilities + +Additional utilities for building neural networks: parameter initialization, +module cloning, type-erased containers, padding layers, and vision utilities. + +## Parameter Initialization + +The `torch::nn::init` namespace provides functions for initializing module parameters: + +```cpp +#include + +// Xavier/Glorot initialization +torch::nn::init::xavier_uniform_(linear->weight); +torch::nn::init::xavier_normal_(linear->weight); + +// Kaiming/He initialization +torch::nn::init::kaiming_uniform_(conv->weight, /*a=*/0, torch::kFanIn, torch::kReLU); +torch::nn::init::kaiming_normal_(conv->weight); + +// Other initializations +torch::nn::init::zeros_(linear->bias); +torch::nn::init::ones_(bn->weight); +torch::nn::init::constant_(linear->bias, 0.1); +torch::nn::init::normal_(linear->weight, /*mean=*/0, /*std=*/0.01); +torch::nn::init::uniform_(linear->weight, /*a=*/-0.1, /*b=*/0.1); +torch::nn::init::orthogonal_(rnn->weight_hh); +``` + +## Cloneable + +```{doxygenclass} torch::nn::Cloneable +``` + +All `torch::nn` modules inherit from `Cloneable`, enabling deep copies: + +```cpp +auto model = torch::nn::Linear(10, 5); +auto model_copy = std::dynamic_pointer_cast(model->clone()); +``` + +## AnyModule + +`AnyModule` provides type-erased storage for any module, allowing you to +store heterogeneous modules in a single container. + +```{doxygenclass} torch::nn::AnyModule +:members: +:undoc-members: +``` + +**Example:** + +```cpp +torch::nn::AnyModule any_module(torch::nn::Linear(10, 5)); +auto output = any_module.forward(input); +``` + +## Functional + +Wraps a function or callable as a module, useful for inserting arbitrary +functions into a `Sequential` container. + +```{doxygenclass} torch::nn::FunctionalImpl +:members: +:undoc-members: +``` + +## ModuleHolder + +```{doxygenclass} torch::nn::ModuleHolder +:members: +:undoc-members: +``` + +## CosineSimilarity + +```{doxygenclass} torch::nn::CosineSimilarity +:members: +:undoc-members: +``` + +## PairwiseDistance + +```{doxygenclass} torch::nn::PairwiseDistance +:members: +:undoc-members: +``` + +## PackedSequence + +```{cpp:class} torch::nn::utils::rnn::PackedSequence + +Holds the data and list of `batch_sizes` of a packed sequence. +All RNN modules accept packed sequences as inputs. +``` + +```{cpp:function} const Tensor& data() const + +Returns the packed tensor containing all sequence elements. +``` + +```{cpp:function} const Tensor& batch_sizes() const + +Returns a 1D tensor of batch sizes at each time step. +``` + +```{cpp:function} const Tensor& sorted_indices() const + +Returns indices used to sort sequences by length (descending). +``` + +```{cpp:function} const Tensor& unsorted_indices() const + +Returns indices to restore the original sequence order. +``` + +```{cpp:function} PackedSequence to(torch::Device device) const + +Moves the packed sequence to the specified device. +``` + +See also: `torch::nn::utils::rnn::pack_padded_sequence` and +`torch::nn::utils::rnn::pad_packed_sequence`. + +## Padding Layers + +### ReflectionPad1d / ReflectionPad2d / ReflectionPad3d + +```{doxygenclass} torch::nn::ReflectionPad1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ReflectionPad2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ReflectionPad3d +:members: +:undoc-members: +``` + +### ReplicationPad1d / ReplicationPad2d / ReplicationPad3d + +```{doxygenclass} torch::nn::ReplicationPad1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ReplicationPad2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ReplicationPad3d +:members: +:undoc-members: +``` + +### ZeroPad1d / ZeroPad2d / ZeroPad3d + +```{doxygenclass} torch::nn::ZeroPad1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ZeroPad2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ZeroPad3d +:members: +:undoc-members: +``` + +### ConstantPad1d / ConstantPad2d / ConstantPad3d + +```{doxygenclass} torch::nn::ConstantPad1d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ConstantPad2d +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::ConstantPad3d +:members: +:undoc-members: +``` + +## Vision Layers + +### PixelShuffle + +```{doxygenclass} torch::nn::PixelShuffle +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::PixelShuffleOptions +:members: +:undoc-members: +``` + +### PixelUnshuffle + +```{doxygenclass} torch::nn::PixelUnshuffle +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::PixelUnshuffleOptions +:members: +:undoc-members: +``` + +### Upsample + +```{doxygenclass} torch::nn::Upsample +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::UpsampleOptions +:members: +:undoc-members: +``` + +### Fold / Unfold + +```{doxygenclass} torch::nn::Fold +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::FoldOptions +:members: +:undoc-members: +``` + +```{doxygenclass} torch::nn::Unfold +:members: +:undoc-members: +``` + +```{doxygenstruct} torch::nn::UnfoldOptions +:members: +:undoc-members: +``` diff --git a/docs/cpp/source/api/optim/adaptive.md b/docs/cpp/source/api/optim/adaptive.md new file mode 100644 index 0000000000000..3f48534567627 --- /dev/null +++ b/docs/cpp/source/api/optim/adaptive.md @@ -0,0 +1,158 @@ +--- +myst: + html_meta: + description: Adaptive optimizers in PyTorch C++ — Adam, AdamW, Adagrad, and RMSprop. + keywords: PyTorch, C++, Adam, AdamW, Adagrad, RMSprop, adaptive optimizer +--- + +# Adaptive Learning Rate Optimizers + +These optimizers automatically adapt the learning rate for each parameter based +on historical gradient information. They typically require less hyperparameter +tuning and work well across a wide range of problems. + +## Adam (Adaptive Moment Estimation) + +Adam combines the benefits of RMSprop and momentum, maintaining per-parameter +adaptive learning rates. It's an excellent default choice for most deep learning +tasks, especially when you want fast convergence with minimal tuning. + +**When to use:** + +- Transformers and attention-based models +- Quick prototyping and experimentation +- When you don't have time for extensive hyperparameter search +- General-purpose deep learning + +**Key parameters:** + +- `lr`: Learning rate (typical: 1e-3 to 1e-4) +- `betas`: Coefficients for running averages (default: {0.9, 0.999}) +- `eps`: Numerical stability term (default: 1e-8) +- `weight_decay`: L2 regularization (note: applied differently than in SGD) + +```{doxygenclass} torch::optim::Adam +:members: +:undoc-members: +``` + +**Example:** + +```cpp +// Standard Adam configuration +auto optimizer = torch::optim::Adam( + model->parameters(), + torch::optim::AdamOptions(1e-3) // learning rate + .betas({0.9, 0.999}) // momentum terms + .eps(1e-8) // numerical stability + .weight_decay(0)); // L2 penalty + +// For transformers, lower learning rate with warmup +auto optimizer = torch::optim::Adam( + model->parameters(), + torch::optim::AdamOptions(1e-4) + .betas({0.9, 0.98})); // β2=0.98 for transformers +``` + +## AdamW (Adam with Decoupled Weight Decay) + +AdamW fixes a subtle issue with Adam's weight decay implementation. In Adam, +weight decay is coupled with the gradient update, which can lead to suboptimal +regularization. AdamW decouples weight decay, applying it directly to the +weights as in SGD. + +**When to use:** + +- Always prefer AdamW over Adam when using weight decay +- Training transformers (BERT, GPT, etc.) +- When you want proper L2 regularization behavior + +**Key difference from Adam:** + +- In Adam: `weight = weight - lr * (grad + weight_decay * weight)` +- In AdamW: `weight = weight - lr * grad - lr * weight_decay * weight` + +```{doxygenclass} torch::optim::AdamW +:members: +:undoc-members: +``` + +**Example:** + +```cpp +// AdamW with decoupled weight decay - preferred for transformers +auto optimizer = torch::optim::AdamW( + model->parameters(), + torch::optim::AdamWOptions(1e-4) + .betas({0.9, 0.999}) + .weight_decay(0.01)); // Decoupled regularization +``` + +## RMSprop (Root Mean Square Propagation) + +RMSprop adapts the learning rate by dividing by a running average of recent +gradient magnitudes. It's particularly effective for recurrent neural networks +and problems with non-stationary objectives. + +**When to use:** + +- Training RNNs and LSTMs +- Non-stationary problems where gradient scale varies significantly +- Online learning scenarios + +**Key parameters:** + +- `lr`: Learning rate (typical: 1e-3 to 1e-2) +- `alpha`: Smoothing constant (default: 0.99) +- `momentum`: Optional momentum term +- `centered`: Use centered RMSprop (normalizes by variance) + +```{doxygenclass} torch::optim::RMSprop +:members: +:undoc-members: +``` + +**Example:** + +```cpp +// RMSprop for RNN training +auto optimizer = torch::optim::RMSprop( + model->parameters(), + torch::optim::RMSpropOptions(1e-3) + .alpha(0.99) // smoothing constant + .momentum(0.9) // optional momentum + .centered(true)); // normalize by variance +``` + +## Adagrad (Adaptive Gradient) + +Adagrad adapts the learning rate based on the accumulated sum of squared +gradients. Parameters with frequent updates get smaller learning rates, while +parameters with infrequent updates get larger rates. This makes it ideal for +sparse data. + +**When to use:** + +- NLP tasks with sparse features +- Embedding layers with infrequent updates +- Recommendation systems with sparse user/item features + +**Limitation:** Learning rate monotonically decreases, which can cause training +to stop prematurely. For long training runs, consider Adam or RMSprop. + +```{doxygenclass} torch::optim::Adagrad +:members: +:undoc-members: +``` + +**Example:** + +```cpp +// Adagrad for sparse NLP features +auto optimizer = torch::optim::Adagrad( + model->parameters(), + torch::optim::AdagradOptions(0.01) + .lr_decay(0) // learning rate decay + .weight_decay(0) + .initial_accumulator_value(0)); +``` diff --git a/docs/cpp/source/api/optim/gradient_descent.md b/docs/cpp/source/api/optim/gradient_descent.md new file mode 100644 index 0000000000000..c72b30171f717 --- /dev/null +++ b/docs/cpp/source/api/optim/gradient_descent.md @@ -0,0 +1,49 @@ +--- +myst: + html_meta: + description: SGD optimizer in PyTorch C++ — stochastic gradient descent with momentum and weight decay. + keywords: PyTorch, C++, SGD, gradient descent, momentum, weight decay, optimizer +--- + +# Gradient Descent Optimizers + +These optimizers use gradient descent with optional enhancements like momentum. +They are the foundation of neural network training and work well when you can +afford careful hyperparameter tuning. + +## SGD (Stochastic Gradient Descent) + +The classic optimization algorithm. SGD with momentum is often the best choice +for convolutional neural networks when properly tuned. While requiring more +careful learning rate selection than adaptive methods, it frequently achieves +the best final accuracy. + +**When to use:** + +- Training CNNs (ResNet, VGG, etc.) where you want maximum accuracy +- When you have time for hyperparameter tuning +- When combined with learning rate schedules (warmup, cosine annealing) + +**Key parameters:** + +- `lr`: Learning rate (typical: 0.01-0.1 for CNNs) +- `momentum`: Accelerates convergence (typical: 0.9) +- `weight_decay`: L2 regularization coefficient +- `nesterov`: Use Nesterov momentum (often improves convergence) + +```{doxygenclass} torch::optim::SGD +:members: +:undoc-members: +``` + +**Example:** + +```cpp +// Standard SGD with momentum - good for CNNs +auto optimizer = torch::optim::SGD( + model->parameters(), + torch::optim::SGDOptions(0.01) // learning rate + .momentum(0.9) // momentum factor + .weight_decay(1e-4) // L2 regularization + .nesterov(true)); // Nesterov momentum +``` diff --git a/docs/cpp/source/api/optim/index.md b/docs/cpp/source/api/optim/index.md new file mode 100644 index 0000000000000..e46c75be09f55 --- /dev/null +++ b/docs/cpp/source/api/optim/index.md @@ -0,0 +1,116 @@ +--- +myst: + html_meta: + description: PyTorch C++ optimizer API — SGD, Adam, and other optimizers for training neural networks. + keywords: PyTorch, C++, optimizer, optim, SGD, Adam, training +--- + +# Optimizers (torch::optim) + +The `torch::optim` namespace provides optimization algorithms for +training neural networks. These optimizers update model parameters based +on computed gradients to minimize the loss function. + +**When to use torch::optim:** + +- When training neural networks with gradient descent +- When you need different optimization strategies (SGD, Adam, etc.) +- When implementing learning rate schedules + +**Basic usage:** + +```cpp +#include + +// Create model and optimizer +auto model = std::make_shared(); +auto optimizer = torch::optim::Adam( + model->parameters(), + torch::optim::AdamOptions(1e-3)); + +// Training loop +for (auto& batch : *data_loader) { + optimizer.zero_grad(); // Clear gradients + auto loss = loss_fn(model->forward(batch.data), batch.target); + loss.backward(); // Compute gradients + optimizer.step(); // Update parameters +} +``` + +## Header Files + +- `torch/csrc/api/include/torch/optim.h` - Main optim header +- `torch/csrc/api/include/torch/optim/optimizer.h` - Optimizer base class +- `torch/csrc/api/include/torch/optim/sgd.h` - SGD optimizer +- `torch/csrc/api/include/torch/optim/adam.h` - Adam optimizer + +## Optimizer Base Class + +All optimizers inherit from the `Optimizer` base class, which provides common +functionality for parameter updates, gradient zeroing, and state management. + +```{doxygenclass} torch::optim::Optimizer +:members: +:undoc-members: +``` + +### OptimizerOptions + +```{doxygenclass} torch::optim::OptimizerOptions +:members: +:undoc-members: +``` + +### OptimizerParamGroup + +```{doxygenclass} torch::optim::OptimizerParamGroup +:members: +:undoc-members: +``` + +### OptimizerParamState + +```{doxygenclass} torch::optim::OptimizerParamState +:members: +:undoc-members: +``` + +## Choosing an Optimizer + +Selecting the right optimizer depends on your model architecture, dataset, and +training requirements: + +```{list-table} +:widths: 20 40 40 +:header-rows: 1 + +* - Optimizer + - Best For + - Trade-offs +* - **SGD + Momentum** + - CNNs, well-understood problems, when you can tune hyperparameters + - Requires careful learning rate tuning; often achieves best final accuracy +* - **Adam/AdamW** + - General-purpose, transformers, quick prototyping + - Works well out-of-the-box; AdamW preferred with weight decay +* - **RMSprop** + - RNNs, non-stationary objectives + - Good for recurrent architectures; handles varying gradient scales +* - **Adagrad** + - Sparse data (NLP, embeddings) + - Learning rate decreases over time; good for infrequent features +* - **LBFGS** + - Small models, fine-tuning, convex problems + - Memory-intensive; requires closure function +``` + +## Optimizer Categories + +```{toctree} +:maxdepth: 1 + +gradient_descent +adaptive +second_order +schedulers +``` diff --git a/docs/cpp/source/api/optim/schedulers.md b/docs/cpp/source/api/optim/schedulers.md new file mode 100644 index 0000000000000..814ff45e2fbaf --- /dev/null +++ b/docs/cpp/source/api/optim/schedulers.md @@ -0,0 +1,152 @@ +--- +myst: + html_meta: + description: Learning rate schedulers in PyTorch C++ — StepLR, ExponentialLR, and other LR scheduling policies. + keywords: PyTorch, C++, learning rate, scheduler, StepLR, ExponentialLR, LRScheduler +--- + +# Learning Rate Schedulers + +Learning rate schedulers adjust the learning rate during training, which often +improves convergence and final accuracy. Common strategies include: + +- **Step decay**: Reduce LR by a factor every N epochs +- **Exponential decay**: Multiply LR by gamma each epoch +- **Cosine annealing**: Smoothly decrease LR following a cosine curve +- **Warmup**: Gradually increase LR at the start of training + +## LRScheduler Base Class + +```{doxygenclass} torch::optim::LRScheduler +:members: +:undoc-members: +``` + +## StepLR + +Decays the learning rate by `gamma` every `step_size` epochs. This is the +simplest and most commonly used scheduler. + +```{doxygenclass} torch::optim::StepLR +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto optimizer = torch::optim::SGD( + model->parameters(), + torch::optim::SGDOptions(0.1)); + +// Reduce LR by 10x every 30 epochs +auto scheduler = torch::optim::StepLR( + optimizer, + /*step_size=*/30, + /*gamma=*/0.1); + +for (int epoch = 0; epoch < 90; ++epoch) { + train_one_epoch(model, optimizer, data_loader); + scheduler.step(); + // LR: 0.1 (epochs 0-29), 0.01 (30-59), 0.001 (60-89) +} +``` + +## ReduceLROnPlateau + +Reduces the learning rate when a metric has stopped improving. Useful when +you want the scheduler to respond to validation loss rather than follow a +fixed schedule. + +```{doxygenclass} torch::optim::ReduceLROnPlateauScheduler +:members: +:undoc-members: +``` + +## ExponentialLR + +Decays the learning rate by `gamma` every epoch. Provides smoother decay than +StepLR but may be slower to reduce the learning rate. + +**Example:** + +```cpp +auto optimizer = torch::optim::Adam( + model->parameters(), + torch::optim::AdamOptions(1e-3)); + +// Reduce LR by 5% each epoch +auto scheduler = torch::optim::ExponentialLR( + optimizer, + /*gamma=*/0.95); + +for (int epoch = 0; epoch < num_epochs; ++epoch) { + train_one_epoch(model, optimizer, data_loader); + scheduler.step(); +} +``` + +## Complete Training Example + +Here's a complete example showing optimizer usage with learning rate scheduling: + +```cpp +#include + +struct Net : torch::nn::Module { + Net() { + fc1 = register_module("fc1", torch::nn::Linear(784, 256)); + fc2 = register_module("fc2", torch::nn::Linear(256, 10)); + } + + torch::Tensor forward(torch::Tensor x) { + x = torch::relu(fc1->forward(x.view({-1, 784}))); + return fc2->forward(x); + } + + torch::nn::Linear fc1{nullptr}, fc2{nullptr}; +}; + +int main() { + // Create model + auto model = std::make_shared(); + + // Create optimizer with weight decay + auto optimizer = torch::optim::AdamW( + model->parameters(), + torch::optim::AdamWOptions(1e-3) + .weight_decay(0.01)); + + // Learning rate scheduler + auto scheduler = torch::optim::StepLR(optimizer, 10, 0.5); + + // Loss function + auto loss_fn = torch::nn::CrossEntropyLoss(); + + // Training loop + for (int epoch = 0; epoch < 30; ++epoch) { + model->train(); + double epoch_loss = 0.0; + + for (auto& batch : *train_loader) { + optimizer.zero_grad(); + + auto output = model->forward(batch.data); + auto loss = loss_fn(output, batch.target); + + loss.backward(); + optimizer.step(); + + epoch_loss += loss.item(); + } + + scheduler.step(); + std::cout << "Epoch " << epoch + << " Loss: " << epoch_loss + << " LR: " << scheduler.get_last_lr()[0] + << std::endl; + } + + return 0; +} +``` diff --git a/docs/cpp/source/api/optim/second_order.md b/docs/cpp/source/api/optim/second_order.md new file mode 100644 index 0000000000000..0f93621aecf41 --- /dev/null +++ b/docs/cpp/source/api/optim/second_order.md @@ -0,0 +1,60 @@ +--- +myst: + html_meta: + description: Second-order optimizers in PyTorch C++ — LBFGS optimizer for full-batch optimization. + keywords: PyTorch, C++, LBFGS, second-order, optimizer, full-batch +--- + +# Second-Order Optimizers + +Second-order methods use curvature information (Hessian or its approximations) +to make better optimization steps. They can converge faster but are more +computationally expensive and memory-intensive. + +## LBFGS (Limited-memory Broyden-Fletcher-Goldfarb-Shanno) + +LBFGS is a quasi-Newton method that approximates the inverse Hessian using +gradient history. It can converge much faster than first-order methods for +smooth, convex-like loss surfaces. + +**When to use:** + +- Small models where memory isn't a concern +- Fine-tuning pre-trained models +- Convex or near-convex optimization problems +- Full-batch training (not mini-batch) + +**Key parameters:** + +- `lr`: Learning rate (often 1.0 for LBFGS) +- `max_iter`: Maximum iterations per step +- `history_size`: Number of past gradients to store + +**Important:** LBFGS requires a closure function that recomputes the loss. + +```{doxygenclass} torch::optim::LBFGS +:members: +:undoc-members: +``` + +**Example:** + +```cpp +auto optimizer = torch::optim::LBFGS( + model->parameters(), + torch::optim::LBFGSOptions(1.0) + .max_iter(20) + .history_size(10)); + +// LBFGS requires a closure that recomputes the model +for (int epoch = 0; epoch < num_epochs; ++epoch) { + auto closure = [&]() { + optimizer.zero_grad(); + auto output = model->forward(data); + auto loss = loss_fn(output, target); + loss.backward(); + return loss; + }; + optimizer.step(closure); +} +``` diff --git a/docs/cpp/source/api/serialize/archives.md b/docs/cpp/source/api/serialize/archives.md new file mode 100644 index 0000000000000..2d50e3db08c9f --- /dev/null +++ b/docs/cpp/source/api/serialize/archives.md @@ -0,0 +1,72 @@ +--- +myst: + html_meta: + description: Serialization archives in PyTorch C++ — InputArchive and OutputArchive for model serialization. + keywords: PyTorch, C++, InputArchive, OutputArchive, serialization, archive +--- + +# Archives + +Archives provide a lower-level interface for serialization, allowing you to +save multiple values to a single file with named keys. + +## OutputArchive + +```{doxygenclass} torch::serialize::OutputArchive +:members: +:undoc-members: +``` + +**Example:** + +```cpp +torch::serialize::OutputArchive archive; +archive.write("tensor1", tensor1); +archive.write("tensor2", tensor2); +archive.save_to("multi_tensor.pt"); +``` + +## InputArchive + +```{doxygenclass} torch::serialize::InputArchive +:members: +:undoc-members: +``` + +**Example:** + +```cpp +torch::serialize::InputArchive archive; +archive.load_from("multi_tensor.pt"); + +torch::Tensor tensor1, tensor2; +archive.read("tensor1", tensor1); +archive.read("tensor2", tensor2); +``` + +## Saving Multiple Values + +Archives are useful when you need to save multiple related values together: + +```cpp +// Save multiple tensors and metadata +torch::serialize::OutputArchive out_archive; +out_archive.write("weights", model_weights); +out_archive.write("biases", model_biases); +out_archive.write("epoch", torch::tensor(current_epoch)); +out_archive.write("loss", torch::tensor(best_loss)); +out_archive.save_to("training_state.pt"); + +// Load them back +torch::serialize::InputArchive in_archive; +in_archive.load_from("training_state.pt"); + +torch::Tensor weights, biases, epoch_tensor, loss_tensor; +in_archive.read("weights", weights); +in_archive.read("biases", biases); +in_archive.read("epoch", epoch_tensor); +in_archive.read("loss", loss_tensor); + +int epoch = epoch_tensor.item(); +float loss = loss_tensor.item(); +``` diff --git a/docs/cpp/source/api/serialize/checkpoints.md b/docs/cpp/source/api/serialize/checkpoints.md new file mode 100644 index 0000000000000..27cae975b47c5 --- /dev/null +++ b/docs/cpp/source/api/serialize/checkpoints.md @@ -0,0 +1,113 @@ +--- +myst: + html_meta: + description: Checkpointing in PyTorch C++ — saving and resuming training state. + keywords: PyTorch, C++, checkpoint, save, resume, training state +--- + +# Checkpoints + +Checkpoints save the complete training state so you can resume training +after interruption. A checkpoint typically includes: + +- Model parameters +- Optimizer state (momentum buffers, learning rates) +- Current epoch number +- Best validation loss/accuracy + +## Creating Checkpoints + +```cpp +void save_checkpoint( + std::shared_ptr model, + torch::optim::Adam& optimizer, + int epoch, + const std::string& path) { + torch::serialize::OutputArchive archive; + model->save(archive); + archive.write("epoch", torch::tensor(epoch)); + optimizer.save(archive); + archive.save_to(path); +} +``` + +## Loading Checkpoints + +```cpp +int load_checkpoint( + std::shared_ptr model, + torch::optim::Adam& optimizer, + const std::string& path) { + torch::serialize::InputArchive archive; + archive.load_from(path); + model->load(archive); + torch::Tensor epoch_tensor; + archive.read("epoch", epoch_tensor); + optimizer.load(archive); + return epoch_tensor.item(); +} +``` + +## Complete Checkpoint Example + +```cpp +#include +#include +#include + +struct Net : torch::nn::Module { + Net() { + fc1 = register_module("fc1", torch::nn::Linear(784, 256)); + fc2 = register_module("fc2", torch::nn::Linear(256, 10)); + } + + torch::Tensor forward(torch::Tensor x) { + x = torch::relu(fc1->forward(x.view({-1, 784}))); + return fc2->forward(x); + } + + torch::nn::Linear fc1{nullptr}, fc2{nullptr}; +}; + +int main() { + auto model = std::make_shared(); + auto optimizer = torch::optim::Adam(model->parameters(), 1e-3); + + int start_epoch = 0; + const std::string checkpoint_path = "checkpoint.pt"; + + // Resume from checkpoint if it exists + if (std::filesystem::exists(checkpoint_path)) { + std::cout << "Loading checkpoint..." << std::endl; + start_epoch = load_checkpoint(model, optimizer, checkpoint_path); + std::cout << "Resuming from epoch " << start_epoch << std::endl; + } + + // Training loop + for (int epoch = start_epoch; epoch < 100; ++epoch) { + // ... training code ... + + // Save checkpoint every 10 epochs + if ((epoch + 1) % 10 == 0) { + save_checkpoint(model, optimizer, epoch + 1, checkpoint_path); + std::cout << "Saved checkpoint at epoch " << epoch + 1 << std::endl; + } + } + + return 0; +} +``` + +## Best Practices + +1. **Save periodically**: Save checkpoints at regular intervals (e.g., every epoch + or every N batches) to minimize lost work. + +2. **Keep multiple checkpoints**: Maintain the last few checkpoints in case the + most recent one is corrupted or represents a poor model state. + +3. **Include all state**: Save everything needed to resume training, including + learning rate scheduler state if using one. + +4. **Verify checkpoints**: Occasionally verify that checkpoints can be loaded + correctly. diff --git a/docs/cpp/source/api/serialize/index.md b/docs/cpp/source/api/serialize/index.md new file mode 100644 index 0000000000000..002431d9ef632 --- /dev/null +++ b/docs/cpp/source/api/serialize/index.md @@ -0,0 +1,58 @@ +--- +myst: + html_meta: + description: PyTorch C++ serialization API — saving and loading models and tensors. + keywords: PyTorch, C++, serialization, save, load, checkpoint, model +--- + +# Serialization (torch::serialize) + +The `torch::serialize` namespace provides utilities for saving and loading +model weights, tensors, and optimizer state. This enables checkpointing during +training and deployment of trained models. + +**When to use torch::serialize:** + +- When saving trained models to disk +- When implementing training checkpoints +- When loading pre-trained weights +- When transferring models between C++ and Python + +**Basic usage:** + +```cpp +#include + +// Save a model +auto model = std::make_shared(); +// ... train the model ... +torch::save(model, "model.pt"); + +// Load a model +auto loaded_model = std::make_shared(); +torch::load(loaded_model, "model.pt"); + +// Save and load tensors +torch::Tensor tensor = torch::randn({2, 3}); +torch::save(tensor, "tensor.pt"); + +torch::Tensor loaded_tensor; +torch::load(loaded_tensor, "tensor.pt"); +``` + +## Header Files + +- `torch/csrc/api/include/torch/serialize.h` - Main serialization header +- `torch/csrc/api/include/torch/serialize/archive.h` - Archive classes +- `torch/csrc/api/include/torch/serialize/input-archive.h` - Input archive +- `torch/csrc/api/include/torch/serialize/output-archive.h` - Output archive + +## Serialization Categories + +```{toctree} +:maxdepth: 1 + +save_load +archives +checkpoints +``` diff --git a/docs/cpp/source/api/serialize/save_load.md b/docs/cpp/source/api/serialize/save_load.md new file mode 100644 index 0000000000000..e08cd21f905cd --- /dev/null +++ b/docs/cpp/source/api/serialize/save_load.md @@ -0,0 +1,84 @@ +--- +myst: + html_meta: + description: Save and load functions in PyTorch C++ — torch::save and torch::load for tensors and modules. + keywords: PyTorch, C++, save, load, torch::save, torch::load, tensor, module +--- + +# Saving and Loading + +The primary interface for serialization uses the `torch::save` and +`torch::load` functions, which can save and load tensors, modules, +and optimizers. + +## Save Functions + +```{doxygenfunction} torch::save(const Value &value, SaveToArgs&&... args) +``` + +```{doxygenfunction} torch::save(const std::vector &tensor_vec, SaveToArgs&&... args) +``` + +## Load Functions + +```{doxygenfunction} torch::load(Value &value, LoadFromArgs&&... args) +``` + +```{doxygenfunction} torch::load(std::vector &tensor_vec, LoadFromArgs&&... args) +``` + +## Saving and Loading Tensors + +```cpp +// Save a tensor +torch::Tensor tensor = torch::randn({2, 3}); +torch::save(tensor, "tensor.pt"); + +// Load a tensor +torch::Tensor loaded; +torch::load(loaded, "tensor.pt"); +``` + +## Saving and Loading Modules + +```cpp +// Define a model +struct Net : torch::nn::Module { + Net() { + fc1 = register_module("fc1", torch::nn::Linear(784, 64)); + fc2 = register_module("fc2", torch::nn::Linear(64, 10)); + } + + torch::Tensor forward(torch::Tensor x) { + x = torch::relu(fc1->forward(x)); + return fc2->forward(x); + } + + torch::nn::Linear fc1{nullptr}, fc2{nullptr}; +}; + +// Save model +auto model = std::make_shared(); +torch::save(model, "model.pt"); + +// Load model +auto loaded_model = std::make_shared(); +torch::load(loaded_model, "model.pt"); +``` + +## Saving Optimizer State + +```cpp +auto model = std::make_shared(); +auto optimizer = torch::optim::Adam(model->parameters(), 0.001); + +// Train... + +// Save both model and optimizer +torch::save(model, "model.pt"); +torch::save(optimizer, "optimizer.pt"); + +// Load both +torch::load(model, "model.pt"); +torch::load(optimizer, "optimizer.pt"); +``` diff --git a/docs/cpp/source/api/stable/index.md b/docs/cpp/source/api/stable/index.md new file mode 100644 index 0000000000000..adb3d477d18b6 --- /dev/null +++ b/docs/cpp/source/api/stable/index.md @@ -0,0 +1,72 @@ +--- +myst: + html_meta: + description: PyTorch Stable ABI C++ API — binary-compatible operator registration across PyTorch versions. + keywords: PyTorch, C++, stable ABI, binary compatibility, operator registration +--- + +# Torch Stable API + +The PyTorch Stable C++ API provides a binary-compatible interface for calling +tensor operations and utilities that is guaranteed to remain stable across +PyTorch versions. This enables ahead-of-time compiled extensions that don't +need recompilation when PyTorch is updated. + +**When to use the Stable API:** + +- When building extensions that must work across multiple PyTorch versions +- When distributing pre-compiled binaries +- When binary compatibility is more important than access to the latest features +- When writing custom operators for production deployment + +**Basic usage:** + +```cpp +#include +#include + +// Create a tensor using stable API +auto tensor = torch::stable::empty( + {3, 4}, + torch::headeronly::ScalarType::Float, + torch::headeronly::Layout::Strided, + torch::stable::Device(torch::headeronly::DeviceType::CPU), + false, + torch::headeronly::MemoryFormat::Contiguous); + +// Register operators with stable ABI +STABLE_TORCH_LIBRARY(myops, m) { + m.def("my_op(Tensor input) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(myops, CPU, m) { + m.impl("my_op", TORCH_BOX(&my_cpu_kernel)); +} +``` + +For more information on the stable ABI, see the +[Stable ABI notes](https://docs.pytorch.org/docs/stable/notes/libtorch_stable_abi.html). + +## Header Files + +- `torch/csrc/stable/library.h` - Stable library registration +- `torch/csrc/stable/ops.h` - Stable operator definitions +- `torch/csrc/stable/tensor_struct.h` - Stable tensor structures +- `torch/csrc/stable/device_struct.h` - Stable device structures +- `torch/csrc/stable/accelerator.h` - Accelerator support +- `torch/csrc/stable/macros.h` - Stable API macros + +## Stable API Categories + +```{toctree} +:maxdepth: 1 + +registration +operators +utilities +``` + +## See Also + +- {doc}`../library/index` - For standard (non-stable) operator registration +- [Stable ABI documentation](https://pytorch.org/docs/stable/cpp_extension.html) diff --git a/docs/cpp/source/api/stable/operators.md b/docs/cpp/source/api/stable/operators.md new file mode 100644 index 0000000000000..8c3bfac176ea6 --- /dev/null +++ b/docs/cpp/source/api/stable/operators.md @@ -0,0 +1,146 @@ +--- +myst: + html_meta: + description: Stable ABI operator API in PyTorch C++ — StableLibrary and boxed kernel registration. + keywords: PyTorch, C++, stable ABI, StableLibrary, operator, boxed kernel +--- + +# Stable Operators + +The stable API provides tensor operations that maintain binary compatibility +across PyTorch versions. + +## Tensor Class + +```{doxygenclass} torch::stable::Tensor +:members: +:undoc-members: +``` + +**Example:** + +```cpp +torch::stable::Tensor tensor = torch::stable::empty({3, 4}, ...); +float* data = tensor.data_ptr(); +auto shape = tensor.sizes(); +``` + +## Device Class + +```{doxygenclass} torch::stable::Device +:members: +:undoc-members: +``` + +**Example:** + +```cpp +torch::stable::Device cpu_device(torch::headeronly::DeviceType::CPU); +torch::stable::Device cuda_device(torch::headeronly::DeviceType::CUDA, 0); +``` + +## Tensor Creation + +```{doxygenfunction} torch::stable::empty +``` + +```{doxygenfunction} torch::stable::empty_like +``` + +```{doxygenfunction} torch::stable::new_empty(const torch::stable::Tensor &self, torch::headeronly::IntHeaderOnlyArrayRef size, std::optional dtype, std::optional layout, std::optional device, std::optional pin_memory) +``` + +```{doxygenfunction} torch::stable::new_zeros(const torch::stable::Tensor &self, torch::headeronly::IntHeaderOnlyArrayRef size, std::optional dtype, std::optional layout, std::optional device, std::optional pin_memory) +``` + +```{doxygenfunction} torch::stable::full +``` + +```{doxygenfunction} torch::stable::from_blob(void *data, torch::headeronly::IntHeaderOnlyArrayRef sizes, torch::headeronly::IntHeaderOnlyArrayRef strides, torch::stable::Device device, torch::headeronly::ScalarType dtype, int64_t storage_offset, torch::headeronly::Layout layout) +``` + +**Example:** + +```cpp +auto tensor = torch::stable::empty( + {3, 4}, + torch::headeronly::ScalarType::Float, + torch::headeronly::Layout::Strided, + torch::stable::Device(torch::headeronly::DeviceType::CUDA, 0), + false, + torch::headeronly::MemoryFormat::Contiguous); +``` + +## Tensor Manipulation + +```{doxygenfunction} torch::stable::clone +``` + +```{doxygenfunction} torch::stable::contiguous +``` + +```{doxygenfunction} torch::stable::reshape +``` + +```{doxygenfunction} torch::stable::view +``` + +```{doxygenfunction} torch::stable::flatten +``` + +```{doxygenfunction} torch::stable::squeeze +``` + +```{doxygenfunction} torch::stable::unsqueeze +``` + +```{doxygenfunction} torch::stable::transpose +``` + +```{doxygenfunction} torch::stable::select +``` + +```{doxygenfunction} torch::stable::narrow +``` + +```{doxygenfunction} torch::stable::pad +``` + +## Device and Type Conversion + +```{doxygenfunction} torch::stable::to(const torch::stable::Tensor &self, std::optional dtype, std::optional layout, std::optional device, std::optional pin_memory, bool non_blocking, bool copy, std::optional memory_format) +``` + +```{doxygenfunction} torch::stable::to(const torch::stable::Tensor &self, torch::stable::Device device, bool non_blocking, bool copy) +``` + +## In-place Operations + +```{doxygenfunction} torch::stable::fill_ +``` + +```{doxygenfunction} torch::stable::zero_ +``` + +```{doxygenfunction} torch::stable::copy_ +``` + +## Mathematical Operations + +```{doxygenfunction} torch::stable::matmul +``` + +```{doxygenfunction} torch::stable::amax(const torch::stable::Tensor &self, int64_t dim, bool keepdim) +``` + +```{doxygenfunction} torch::stable::amax(const torch::stable::Tensor &self, torch::headeronly::IntHeaderOnlyArrayRef dims, bool keepdim) +``` + +```{doxygenfunction} torch::stable::sum +``` + +```{doxygenfunction} torch::stable::sum_out +``` + +```{doxygenfunction} torch::stable::subtract +``` diff --git a/docs/cpp/source/api/stable/registration.md b/docs/cpp/source/api/stable/registration.md new file mode 100644 index 0000000000000..ad970a66d72ae --- /dev/null +++ b/docs/cpp/source/api/stable/registration.md @@ -0,0 +1,116 @@ +--- +myst: + html_meta: + description: Stable ABI registration macros in PyTorch C++ — STABLE_TORCH_LIBRARY, STABLE_TORCH_LIBRARY_IMPL, and TORCH_BOX. + keywords: PyTorch, C++, stable ABI, STABLE_TORCH_LIBRARY, TORCH_BOX, macro, registration +--- + +# Library Registration Macros + +These macros provide stable ABI equivalents of the standard PyTorch operator +registration macros (`TORCH_LIBRARY`, `TORCH_LIBRARY_IMPL`, etc.). +Use these when building custom operators that need to maintain binary +compatibility across PyTorch versions. + +## STABLE_TORCH_LIBRARY + +```{c:macro} STABLE_TORCH_LIBRARY(ns, m) + +Defines a library of operators in a namespace using the stable ABI. + +This is the stable ABI equivalent of {c:macro}`TORCH_LIBRARY`. +Use this macro to define operator schemas that will maintain +binary compatibility across PyTorch versions. Only one `STABLE_TORCH_LIBRARY` +block can exist per namespace; use `STABLE_TORCH_LIBRARY_FRAGMENT` for +additional definitions in the same namespace from different translation units. + +:param ns: The namespace in which to define operators (e.g., `mylib`). +:param m: The name of the StableLibrary variable available in the block. + +**Example:** + +```cpp +STABLE_TORCH_LIBRARY(mylib, m) { + m.def("my_op(Tensor input, int size) -> Tensor"); + m.def("another_op(Tensor a, Tensor b) -> Tensor"); +} +``` + +Minimum compatible version: PyTorch 2.9. +``` +## STABLE_TORCH_LIBRARY_IMPL + +```{c:macro} STABLE_TORCH_LIBRARY_IMPL(ns, k, m) + +Registers operator implementations for a specific dispatch key using the stable ABI. + +This is the stable ABI equivalent of `TORCH_LIBRARY_IMPL`. Use this macro +to provide implementations of operators for a specific dispatch key (e.g., +CPU, CUDA) while maintaining binary compatibility across PyTorch versions. +``` + +```{note} + +All kernel functions registered with this macro must be boxed using +the `TORCH_BOX` macro. +``` + +:param ns: The namespace in which the operators are defined. +:param k: The dispatch key (e.g., `CPU`, `CUDA`). +:param m: The name of the StableLibrary variable available in the block. + +**Example:** + +```cpp +STABLE_TORCH_LIBRARY_IMPL(mylib, CPU, m) { + m.impl("my_op", TORCH_BOX(&my_cpu_kernel)); +} + +STABLE_TORCH_LIBRARY_IMPL(mylib, CUDA, m) { + m.impl("my_op", TORCH_BOX(&my_cuda_kernel)); +} +``` + +Minimum compatible version: PyTorch 2.9. + +## STABLE_TORCH_LIBRARY_FRAGMENT + +```{c:macro} STABLE_TORCH_LIBRARY_FRAGMENT(ns, m) + +Extends operator definitions in an existing namespace using the stable ABI. + +This is the stable ABI equivalent of `TORCH_LIBRARY_FRAGMENT`. Use this macro +to add additional operator definitions to a namespace that was already +created with `STABLE_TORCH_LIBRARY`. + +:param ns: The namespace to extend. +:param m: The name of the StableLibrary variable available in the block. + +Minimum compatible version: PyTorch 2.9. +``` + +## TORCH_BOX + +```{c:macro} TORCH_BOX(func) + +Wraps a function to conform to the stable boxed kernel calling convention. + +This macro takes an unboxed kernel function pointer and generates a boxed wrapper +that can be registered with the stable library API. + +:param func: The unboxed kernel function to wrap. + +**Example:** + +```cpp +Tensor my_kernel(const Tensor& input, int64_t size) { + return input.reshape({size}); +} + +STABLE_TORCH_LIBRARY_IMPL(my_namespace, CPU, m) { + m.impl("my_op", TORCH_BOX(&my_kernel)); +} +``` + +Minimum compatible version: PyTorch 2.9. +``` \ No newline at end of file diff --git a/docs/cpp/source/api/stable/utilities.md b/docs/cpp/source/api/stable/utilities.md new file mode 100644 index 0000000000000..c9d30e96ddef9 --- /dev/null +++ b/docs/cpp/source/api/stable/utilities.md @@ -0,0 +1,267 @@ +--- +myst: + html_meta: + description: Stable ABI utilities in PyTorch C++ — version checking and compatibility helpers. + keywords: PyTorch, C++, stable ABI, utilities, version, compatibility +--- + +# Utilities + +The stable API provides various utility functions and types for working with +tensors and CUDA operations. + +## DeviceGuard Class + +```{doxygenclass} torch::stable::accelerator::DeviceGuard +:members: +:undoc-members: +``` + +```{doxygenfunction} torch::stable::accelerator::getCurrentDeviceIndex +``` + +**Example:** + +```cpp +{ + torch::stable::accelerator::DeviceGuard guard(1); + // Operations here run on device 1 +} +// Previous device is restored +``` + +## Stream + +```{doxygenclass} torch::stable::accelerator::Stream +:members: +:undoc-members: +``` + +## Stream Utilities + +For CUDA stream access, we currently recommend the ABI stable C shim API. This +will be improved in a future release with a more ergonomic wrapper. + +### Getting the Current CUDA Stream + +To obtain the current `cudaStream_t` for use in CUDA kernels: + +```cpp +#include +#include + +// For now, we rely on the ABI stable C shim API to get the current CUDA stream. +void* stream_ptr = nullptr; +TORCH_ERROR_CODE_CHECK( + aoti_torch_get_current_cuda_stream(tensor.get_device_index(), &stream_ptr)); +cudaStream_t stream = static_cast(stream_ptr); + +// Now you can use 'stream' in your CUDA kernel launches +my_kernel<<>>(args...); +``` + +```{note} + +The `TORCH_ERROR_CODE_CHECK` macro is required when using C shim APIs +to properly check error codes and throw appropriate exceptions. +``` + +## CUDA Error Checking Macros + +These macros provide stable ABI equivalents for CUDA error checking. +They wrap CUDA API calls and kernel launches, providing detailed error +messages using PyTorch's error formatting. + +### STD_CUDA_CHECK + +```{c:macro} STD_CUDA_CHECK(EXPR) + +Checks the result of a CUDA API call and throws an exception on error. +Users of this macro are expected to include `cuda_runtime.h`. + +**Example:** + +```cpp +STD_CUDA_CHECK(cudaMalloc(&ptr, size)); +STD_CUDA_CHECK(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost)); +``` + +Minimum compatible version: PyTorch 2.10. +``` +### STD_CUDA_KERNEL_LAUNCH_CHECK + +```{c:macro} STD_CUDA_KERNEL_LAUNCH_CHECK() + +Checks for errors from the most recent CUDA kernel launch. Equivalent to +`STD_CUDA_CHECK(cudaGetLastError())`. + +**Example:** + +```cpp +my_kernel<<>>(args...); +STD_CUDA_KERNEL_LAUNCH_CHECK(); +``` + +Minimum compatible version: PyTorch 2.10. +``` +## Header-Only Utilities + +The `torch::headeronly` namespace provides header-only versions of common +PyTorch types and utilities. These can be used without linking against libtorch +at all! This portability makes them ideal for maintaining binary compatibility +across PyTorch versions. + +### Error Checking + +`STD_TORCH_CHECK` is a header-only macro for runtime assertions: + +```cpp +#include + +STD_TORCH_CHECK(condition, "Error message with ", variable, " interpolation"); +``` + +Wherever you used `TORCH_CHECK` before, you can replace usage with `STD_TORCH_CHECK` +to remove the need to link against libtorch. The only difference is that when the +condition check fails, `TORCH_CHECK` throws a fancier `c10::Error` while +`STD_TORCH_CHECK` throws a `std::runtime_error`. + +### Core Types + +The following `c10::` types are available as header-only versions under +`torch::headeronly::`: + +- `torch::headeronly::ScalarType` - Tensor data types (Float, Double, Int, etc.) +- `torch::headeronly::DeviceType` - Device types (CPU, CUDA, etc.) +- `torch::headeronly::MemoryFormat` - Memory layout formats (Contiguous, ChannelsLast, etc.) +- `torch::headeronly::Layout` - Tensor layouts (Strided, Sparse, etc.) + +```cpp +#include +#include +#include +#include + +auto dtype = torch::headeronly::ScalarType::Float; +auto device_type = torch::headeronly::DeviceType::CUDA; +auto memory_format = torch::headeronly::MemoryFormat::Contiguous; +auto layout = torch::headeronly::Layout::Strided; +``` + +### TensorAccessor + +`TensorAccessor` provides efficient, bounds-checked access to tensor data. +You can construct one from a stable tensor's data pointer, sizes, and strides: + +```cpp +#include + +// Create a TensorAccessor for a 2D float tensor +auto sizes = tensor.sizes(); +auto strides = tensor.strides(); +torch::headeronly::TensorAccessor accessor( + static_cast(tensor.mutable_data_ptr()), + sizes.data(), + strides.data()); + +// Access elements +float value = accessor[i][j]; +``` + +### Dispatch Macros + +Header-only dispatch macros (THO = Torch Header Only) are available for +dtype dispatching: + +```cpp +#include + +THO_DISPATCH_V2( + tensor.scalar_type(), // will be resolved as scalar_t + "my_kernel", + AT_WRAP(([&]() { + // code to specialize with scalar_t + // scalar_t is the resolved C++ type (e.g. float, double) + auto* data = static_cast(tensor.mutable_data_ptr()); + Scalar s(*data); + })), + AT_EXPAND(AT_ALL_TYPES), + AT_EXPAND(AT_COMPLEX_TYPES), + torch::headeronly::ScalarType::Half, + // as many type arguments as needed +); +``` + +`THO_DISPATCH_V2` works the same way as `AT_DISPATCH_V2` (see +`ATen/Dispatch_v2.h`) but does not require linking against libtorch. +As a result, whereas `AT_DISPATCH_V2` would have thrown `c10::NotImplementedError` +for unimplemented paths, `THO_DISPATCH_V2` will throw `std::runtime_error`. + +For ease of use, we've also migrated the below AT_* macros representing +collections of types to be header-only and thus have no dependency on libtorch: + +- `AT_FLOATING_TYPES` +- `AT_INTEGRAL_TYPES` +- `AT_INTEGRAL_TYPES_V2` +- `AT_ALL_TYPES` +- `AT_COMPLEX_TYPES` +- `AT_ALL_TYPES_AND_COMPLEX` +- `AT_FLOAT8_TYPES` +- `AT_BAREBONES_UNSIGNED_TYPES` +- `AT_QINT_TYPES` + +If your extension uses our older AT_DISPATCH version 1 infrastructure, +you can also migrate to a header-only libtorch-free world without upgrading +everything to version 2. + +`THO_DISPATCH_SWITCH` and `THO_DISPATCH_CASE` are the header-only +equivalents of `AT_DISPATCH_SWITCH` and `AT_DISPATCH_CASE`. Similarly, +the only user-visible difference is the exception type on an unhandled dtype, +where the `AT_` version throws a `c10::NotImplementedError` and the `THO_` +version throws a `std::runtime_error`. + +The migration is pretty mechanical: + +- `AT_DISPATCH_SWITCH` → `THO_DISPATCH_SWITCH` +- `AT_DISPATCH_CASE` → `THO_DISPATCH_CASE` +- `AT_PRIVATE_CASE_TYPE_USING_HINT` → `THO_PRIVATE_CASE_TYPE_USING_HINT` +- `at::ScalarType::X` → `torch::headeronly::ScalarType::X` + +```cpp +// ---- Before (requires linking against libtorch) ---- +#include + +#define MY_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define MY_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + MY_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +``` + +```cpp +// ---- After (header-only, no libtorch dependency) ---- +#include + +#define MY_DISPATCH_CASE_FLOATING_TYPES(...) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::Float, __VA_ARGS__) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::Half, __VA_ARGS__) \ + THO_DISPATCH_CASE(torch::headeronly::ScalarType::BFloat16, __VA_ARGS__) + +#define MY_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + THO_DISPATCH_SWITCH(TYPE, NAME, \ + MY_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +``` + +For the complete list of header-only APIs, see `torch/header_only_apis.txt` +in the PyTorch source tree. + +## Parallelization Utilities + +```{doxygenfunction} torch::stable::parallel_for +``` + +```{doxygenfunction} torch::stable::get_num_threads +``` diff --git a/docs/cpp/source/api/xpu/index.md b/docs/cpp/source/api/xpu/index.md new file mode 100644 index 0000000000000..c88eff6a69165 --- /dev/null +++ b/docs/cpp/source/api/xpu/index.md @@ -0,0 +1,48 @@ +--- +myst: + html_meta: + description: PyTorch XPU C++ API — Intel GPU support with device management, streams, and guards. + keywords: PyTorch, C++, XPU, Intel GPU, device, stream +--- + +# XPU Support + +PyTorch provides XPU support for Intel GPU-accelerated tensor operations. +The XPU API allows you to manage Intel GPU devices, streams for asynchronous +execution, and synchronization. + +**When to use XPU APIs:** + +- When running on Intel GPUs (Data Center GPU Max, Arc, etc.) +- When implementing custom XPU kernels or operations +- When managing asynchronous execution with XPU streams +- When writing device-portable code alongside CUDA + +**Basic usage:** + +```cpp +#include + +// Check if XPU is available +if (torch::xpu::is_available()) { + // Create tensor on XPU + auto tensor = torch::randn({2, 3}, torch::device(torch::kXPU)); + + // Move model to XPU + model->to(torch::kXPU); +} +``` + +## Header Files + +- `torch/xpu.h` - High-level XPU utilities (device count, availability, seeding) +- `c10/xpu/XPUStream.h` - XPU stream management + +## XPU Categories + +```{toctree} +:maxdepth: 1 + +streams +utilities +``` diff --git a/docs/cpp/source/api/xpu/streams.md b/docs/cpp/source/api/xpu/streams.md new file mode 100644 index 0000000000000..a9b6911471d4f --- /dev/null +++ b/docs/cpp/source/api/xpu/streams.md @@ -0,0 +1,50 @@ +--- +myst: + html_meta: + description: XPU streams in PyTorch C++ — XPUStream for asynchronous Intel GPU execution. + keywords: PyTorch, C++, XPU, XPUStream, Intel GPU, stream, asynchronous +--- + +# XPU Streams + +XPU streams provide a mechanism for asynchronous execution of operations +on Intel GPUs. Like CUDA streams, operations queued to the same stream execute +in order, while operations on different streams can execute concurrently. + +## XPUStream + +```{doxygenclass} c10::xpu::XPUStream +:members: +:undoc-members: +``` + +**Example:** + +```cpp +#include + +// Get the current XPU stream +auto stream = c10::xpu::getCurrentXPUStream(); + +// Create a new stream from the pool +auto new_stream = c10::xpu::getStreamFromPool(); + +// Synchronize +stream.synchronize(); +``` + +## Acquiring XPU Streams + +```{doxygenfunction} c10::xpu::getCurrentXPUStream +``` + +```{doxygenfunction} c10::xpu::setCurrentXPUStream +``` + +```{doxygenfunction} c10::xpu::getStreamFromPool(const bool isHighPriority, DeviceIndex device) +``` + +## Stream Synchronization + +```{doxygenfunction} c10::xpu::syncStreamsOnDevice +``` diff --git a/docs/cpp/source/api/xpu/utilities.md b/docs/cpp/source/api/xpu/utilities.md new file mode 100644 index 0000000000000..9857d5e291a2b --- /dev/null +++ b/docs/cpp/source/api/xpu/utilities.md @@ -0,0 +1,53 @@ +--- +myst: + html_meta: + description: XPU utility functions in PyTorch C++ — device count, properties, and stream management for Intel GPUs. + keywords: PyTorch, C++, XPU, Intel GPU, device_count, utilities +--- + +# XPU Utility Functions + +High-level utility functions for querying and managing XPU devices. + +## Device Management + +```{doxygenfunction} torch::xpu::device_count +``` + +```{doxygenfunction} torch::xpu::is_available +``` + +```{doxygenfunction} torch::xpu::synchronize +``` + +**Example:** + +```cpp +#include + +if (torch::xpu::is_available()) { + size_t num_devices = torch::xpu::device_count(); + std::cout << "Found " << num_devices << " XPU device(s)" << std::endl; + + // Synchronize all streams on device 0 + torch::xpu::synchronize(0); +} +``` + +## Random Number Generation + +```{doxygenfunction} torch::xpu::manual_seed +``` + +```{doxygenfunction} torch::xpu::manual_seed_all +``` + +**Example:** + +```cpp +// Set seed for reproducibility on current XPU device +torch::xpu::manual_seed(42); + +// Set seed for all XPU devices +torch::xpu::manual_seed_all(42); +``` diff --git a/docs/cpp/source/conf.py b/docs/cpp/source/conf.py index 10d854c21db4f..ab3c722bb63ed 100644 --- a/docs/cpp/source/conf.py +++ b/docs/cpp/source/conf.py @@ -20,7 +20,6 @@ # See https://github.com/pytorch/pytorch/issues/79992. import os -import textwrap # sys.path.insert(0, os.path.abspath('.')) import pytorch_sphinx_theme2 @@ -31,14 +30,55 @@ # If your documentation needs a minimal Sphinx version, state it here. # needs_sphinx = "3.1.2" -run_doxygen = os.environ.get("RUN_DOXYGEN", "false") == "true" # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ "sphinx.ext.intersphinx", -] + (["breathe", "exhale"] if run_doxygen else []) + "breathe", + "myst_parser", +] + +# -- Breathe Configuration ------------------------------------------------ +# Breathe connects Sphinx to Doxygen XML output +# Use doxygenclass, doxygenfunction, etc. directives in RST files +# to pull documentation from C++ source code + +breathe_projects = {"PyTorch": "../build/xml"} +breathe_default_project = "PyTorch" + +# Default members to show when using doxygenclass/doxygenstruct directives +breathe_default_members = ("members", "undoc-members") + +# Map file extensions to language domains for proper syntax highlighting +breathe_domain_by_extension = { + "h": "cpp", + "hpp": "cpp", + "cpp": "cpp", + "c": "c", +} + +# Implementation detail filters - skip internal/private content +breathe_implementation_filename_extensions = [".c", ".cc", ".cpp"] + +# Show the file where items are defined +breathe_show_define_initializer = True +breathe_show_enumvalue_initializer = True + +# Control what gets shown in documentation +breathe_show_include = False # Don't show #include directives + +# Order of member documentation +breathe_order_parameters_first = False + +# Use Sphinx's C++ domain for cross-references +breathe_use_project_refids = True + +# Suppress specific Breathe warnings for cleaner builds +breathe_debug_trace_directives = False +breathe_debug_trace_doxygen_ids = False +breathe_debug_trace_qualification = False intersphinx_mapping = {"pytorch": ("https://docs.pytorch.org/docs/main", None)} @@ -52,99 +92,6 @@ "misc.highlighting_failure", ] -# Configure Breathe -breathe_show_define_initializer = True -breathe_show_enumvalue_initializer = True -breathe_default_members = ("members", "undoc-members") - - -# Fix for Python 3.10+ compatibility with exhale 2.3.0 -# MutableMapping was moved from collections to collections.abc in Python 3.10 -try: - import collections - from collections.abc import MutableMapping - - if not hasattr(collections, "MutableMapping"): - collections.MutableMapping = MutableMapping -except ImportError: - pass - -# Setup absolute paths for communicating with breathe / exhale where -# items are expected / should be trimmed by. -# This file is {repo_root}/docs/cpp/source/conf.py -this_file_dir = os.path.abspath(os.path.dirname(__file__)) -doxygen_xml_dir = os.path.join( - os.path.dirname(this_file_dir), # {repo_root}/docs/cpp - "build", # {repo_root}/docs/cpp/build - "xml", # {repo_root}/docs/cpp/build/xml -) -repo_root = os.path.dirname( # {repo_root} - os.path.dirname( # {repo_root}/docs - os.path.dirname( # {repo_root}/docs/cpp - this_file_dir # {repo_root}/docs/cpp/source - ) - ) -) - -breathe_projects = {"PyTorch": doxygen_xml_dir} -breathe_default_project = "PyTorch" - -# Setup the exhale extension -exhale_args = { - ############################################################################ - # These arguments are required. # - ############################################################################ - "containmentFolder": "./api", - "rootFileName": "library_root.rst", - "rootFileTitle": "Library API", - "doxygenStripFromPath": repo_root, - ############################################################################ - # Suggested optional arguments. # - ############################################################################ - "createTreeView": True, - "exhaleExecutesDoxygen": True, - "exhaleUseDoxyfile": True, - "verboseBuild": True, - ############################################################################ - # HTML Theme specific configurations. # - ############################################################################ - # Fix broken Sphinx RTD Theme 'Edit on GitHub' links - # Search for 'Edit on GitHub' on the FAQ: - # http://exhale.readthedocs.io/en/latest/faq.html - "pageLevelConfigMeta": ":github_url: https://github.com/pytorch/pytorch", - ############################################################################ - # Individual page layout example configuration. # - ############################################################################ - # Example of adding contents directives on custom kinds with custom title - "contentsTitle": "Page Contents", - "kindsWithContentsDirectives": ["class", "file", "namespace", "struct"], - # Exclude PIMPL files from class hierarchy tree and namespace pages. - "listingExclude": [r".*Impl$"], - ############################################################################ - # Main library page layout example configuration. # - ############################################################################ - "afterTitleDescription": textwrap.dedent( - """ - Welcome to the developer reference for the PyTorch C++ API. - """ - ), - ############################################################################ - # Duplicate handling and error management. # - ############################################################################ - # Note: Using Doxyfile instead of stdin configuration - # "exhaleDoxygenStdin" is not compatible with "exhaleUseDoxyfile" - # Handle unresolved references more gracefully - "unabridgedOrphanKinds": { - "function", - "define", - "enum", - "enumvalue", - "typedef", - "variable", - }, - "fullToctreeMaxDepth": 2, -} - # Tell sphinx what the primary language being documented is. primary_domain = "cpp" @@ -164,7 +111,10 @@ # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = ".rst" +source_suffix = { + ".rst": "restructuredtext", + ".md": "markdown", +} # The master toctree document. master_doc = "index" @@ -264,8 +214,24 @@ # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # NOTE: sharing python docs resources -html_static_path = [os.path.join(repo_root, "docs", "cpp", "source", "_static")] -html_css_files = ["cpp_theme.css"] + + +# Remove "Subclassed by" paragraphs that Breathe renders as plain text +# (not links) because per-class pages don't exist without exhale. +def remove_subclassed_by(app, doctree, docname): + from docutils import nodes + + for node in doctree.traverse(nodes.paragraph): + text = node.astext() + if text.startswith("Subclassed by "): + # Keep if it contains actual reference links + if not node.traverse(nodes.reference): + node.parent.remove(node) + + +def setup(app): + app.connect("doctree-resolved", remove_subclassed_by) + # Called automatically by Sphinx, making this `conf.py` an "extension". diff --git a/docs/cpp/source/faq.md b/docs/cpp/source/faq.md new file mode 100644 index 0000000000000..3ee305ccdf3ad --- /dev/null +++ b/docs/cpp/source/faq.md @@ -0,0 +1,244 @@ +--- +myst: + html_meta: + description: Frequently asked questions about the PyTorch C++ API and libtorch. + keywords: PyTorch, C++, FAQ, libtorch, troubleshooting +--- + +# FAQ + +Listed below are a number of common issues users face with the various parts of +the C++ API. + +## C++ Extensions + +### Undefined symbol errors from PyTorch/ATen + +**Problem**: You import your extension and get an `ImportError` stating that +some C++ symbol from PyTorch or ATen is undefined. For example: + +```cpp +>>> import extension +Traceback (most recent call last): + File "", line 1, in +ImportError: /home/user/.pyenv/versions/3.7.1/lib/python3.7/site-packages/extension.cpython-37m-x86_64-linux-gnu.so: undefined symbol: _ZN2at19UndefinedTensorImpl10_singletonE +``` + +**Fix**: The fix is to `import torch` before you import your extension. This will make +the symbols from the PyTorch dynamic (shared) library that your extension +depends on available, allowing them to be resolved once you import your extension. + +### I created a tensor using a function from `at::` and get errors + +**Problem**: You created a tensor using e.g. `at::ones` or `at::randn` or +any other tensor factory from the `at::` namespace and are getting errors. + +**Fix**: Replace `at::` with `torch::` for factory function calls. You +should never use factory functions from the `at::` namespace, as they will +create tensors. The corresponding `torch::` functions will create variables, +and you should only ever deal with variables in your code. + +## LibTorch + +### How do I move a model to GPU? + +**Problem**: You want to run your model on GPU but are unsure how to move both +the model and tensors to the correct device. + +**Fix**: Use the `to()` method to move your model and tensors to a CUDA device: + +```cpp +torch::Device device(torch::kCUDA); +model->to(device); +auto input = torch::randn({1, 3, 224, 224}).to(device); +auto output = model->forward(input); +``` + +You can also check for CUDA availability before moving: + +```cpp +torch::Device device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU); +``` + +Make sure to compile with the TorchScript headers by including ``. + +### My model runs slower in C++ than in Python + +**Problem**: Your model inference is slower in C++ compared to Python. + +**Fix**: There are several common causes: + +1. **Enable inference mode**: Wrap your inference code with `torch::NoGradGuard` + to disable gradient computation: + +```cpp +torch::NoGradGuard no_grad; +auto output = model->forward(input); +``` + +2. **Enable optimizations**: For TorchScript models, use `optimize_for_inference`: + +```cpp +module = torch::jit::optimize_for_inference(module); +``` + +3. **Warm up the model**: Run a few inference passes before benchmarking to allow + JIT compilation and memory allocation to complete. + +4. **Check thread settings**: Ensure proper thread configuration: + +```cpp +at::set_num_threads(4); // Adjust based on your hardware +``` + +## Neural Network Modules + +### How do I register submodules in a custom module? + +**Problem**: You created a custom module but the submodules are not being +recognized during `forward()` or when saving/loading the model. + +**Fix**: You must register submodules in the constructor using +`register_module()`: + +```cpp +struct MyModel : torch::nn::Module { + MyModel() { + fc1 = register_module("fc1", torch::nn::Linear(784, 128)); + fc2 = register_module("fc2", torch::nn::Linear(128, 10)); + } + + torch::Tensor forward(torch::Tensor x) { + x = torch::relu(fc1->forward(x)); + return fc2->forward(x); + } + + torch::nn::Linear fc1{nullptr}, fc2{nullptr}; +}; +``` + +### How do I set a module to evaluation mode? + +**Problem**: Layers like Dropout and BatchNorm behave differently during training +and evaluation, and you need to switch between modes. + +**Fix**: Use the `eval()` and `train()` methods: + +```cpp +model->eval(); // Set to evaluation mode +// ... run inference ... +model->train(); // Set back to training mode +``` + +## Data Loading + +### How do I create a custom dataset? + +**Problem**: You want to load your own data instead of using built-in datasets. + +**Fix**: Create a class that inherits from `torch::data::datasets::Dataset` and +implement the `get()` and `size()` methods: + +```cpp +class CustomDataset : public torch::data::datasets::Dataset { + public: + explicit CustomDataset(const std::string& data_path) { + // Load your data here + } + + torch::data::Example<> get(size_t index) override { + // Return a single data sample + torch::Tensor data = /* load data at index */; + torch::Tensor label = /* load label at index */; + return {data, label}; + } + + torch::optional size() const override { + return dataset_size_; + } + + private: + size_t dataset_size_; +}; +``` + +Then use it with a DataLoader: + +```cpp +auto dataset = CustomDataset("path/to/data") + .map(torch::data::transforms::Stack<>()); +auto dataloader = torch::data::make_data_loader( + std::move(dataset), + torch::data::DataLoaderOptions().batch_size(32).workers(4)); +``` + +## Serialization + +### How do I save and load model weights? + +**Problem**: You want to save trained model weights and load them later. + +**Fix**: Use `torch::save()` and `torch::load()`: + +```cpp +// Saving +torch::save(model, "model.pt"); + +// Loading +torch::load(model, "model.pt"); +``` + +For saving only specific tensors or state: + +```cpp +torch::serialize::OutputArchive archive; +model->save(archive); +archive.save_to("model_weights.pt"); + +// Loading +torch::serialize::InputArchive archive; +archive.load_from("model_weights.pt"); +model->load(archive); +``` + +## Build and Compilation + +### CMake cannot find Torch + +**Problem**: When building your project with CMake, you get an error that +`Torch` package cannot be found. + +**Fix**: You need to specify the path to the LibTorch installation using +`CMAKE_PREFIX_PATH`: + +```cpp +cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch .. +``` + +Alternatively, set `Torch_DIR` to point to the directory containing +`TorchConfig.cmake`: + +```cpp +cmake -DTorch_DIR=/path/to/libtorch/share/cmake/Torch .. +``` + +### Linker errors with undefined references + +**Problem**: Your project compiles but you get linker errors with undefined +references to PyTorch symbols. + +**Fix**: Ensure you're linking against all required libraries in your +`CMakeLists.txt`: + +```cpp +find_package(Torch REQUIRED) +add_executable(my_app main.cpp) +target_link_libraries(my_app "${TORCH_LIBRARIES}") +set_property(TARGET my_app PROPERTY CXX_STANDARD 17) +``` + +Also ensure that the compiler flags are set correctly: + +```cpp +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") +``` diff --git a/docs/cpp/source/frontend.md b/docs/cpp/source/frontend.md new file mode 100644 index 0000000000000..244e2c836bf98 --- /dev/null +++ b/docs/cpp/source/frontend.md @@ -0,0 +1,146 @@ +--- +myst: + html_meta: + description: PyTorch C++ Frontend guide — defining models, training loops, and using torch::nn modules in C++. + keywords: PyTorch, C++, frontend, torch::nn, Module, training, inference +--- + +# The C++ Frontend + +The PyTorch C++ frontend is a C++17 library for CPU and GPU +tensor computation, with automatic differentiation and high level building +blocks for state of the art machine learning applications. + +## Description + +The PyTorch C++ frontend can be thought of as a C++ version of the +PyTorch Python frontend, providing automatic differentiation and various higher +level abstractions for machine learning and neural networks. Specifically, +it consists of the following components: + +| Component | Description | +| --- | --- | +| `torch::Tensor` | Automatically differentiable, efficient CPU and GPU enabled tensors | +| `torch::nn` | A collection of composable modules for neural network modeling | +| `torch::optim` | Optimization algorithms like SGD, Adam or RMSprop to train your models | +| `torch::data` | Datasets, data pipelines and multi-threaded, asynchronous data loader | +| `torch::serialize` | A serialization API for storing and loading model checkpoints | +| `torch::python` | Glue to bind your C++ models into Python | +| `torch::jit` | Pure C++ access to the TorchScript JIT compiler | + +## End-to-end example + +Here is a simple, end-to-end example of defining and training a simple +neural network on the MNIST dataset: + +```cpp +#include + +// Define a new Module. +struct Net : torch::nn::Module { + Net() { + // Construct and register two Linear submodules. + fc1 = register_module("fc1", torch::nn::Linear(784, 64)); + fc2 = register_module("fc2", torch::nn::Linear(64, 32)); + fc3 = register_module("fc3", torch::nn::Linear(32, 10)); + } + + // Implement the Net's algorithm. + torch::Tensor forward(torch::Tensor x) { + // Use one of many tensor manipulation functions. + x = torch::relu(fc1->forward(x.reshape({x.size(0), 784}))); + x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training()); + x = torch::relu(fc2->forward(x)); + x = torch::log_softmax(fc3->forward(x), /*dim=*/1); + return x; + } + + // Use one of many "standard library" modules. + torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr}; +}; + +int main() { + // Create a new Net. + auto net = std::make_shared(); + + // Create a multi-threaded data loader for the MNIST dataset. + auto data_loader = torch::data::make_data_loader( + torch::data::datasets::MNIST("./data").map( + torch::data::transforms::Stack<>()), + /*batch_size=*/64); + + // Instantiate an SGD optimization algorithm to update our Net's parameters. + torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01); + + for (size_t epoch = 1; epoch <= 10; ++epoch) { + size_t batch_index = 0; + // Iterate the data loader to yield batches from the dataset. + for (auto& batch : *data_loader) { + // Reset gradients. + optimizer.zero_grad(); + // Execute the model on the input data. + torch::Tensor prediction = net->forward(batch.data); + // Compute a loss value to judge the prediction of our model. + torch::Tensor loss = torch::nll_loss(prediction, batch.target); + // Compute gradients of the loss w.r.t. the parameters of our model. + loss.backward(); + // Update the parameters based on the calculated gradients. + optimizer.step(); + // Output the loss and checkpoint every 100 batches. + if (++batch_index % 100 == 0) { + std::cout << "Epoch: " << epoch << " | Batch: " << batch_index + << " | Loss: " << loss.item() << std::endl; + // Serialize your model periodically as a checkpoint. + torch::save(net, "net.pt"); + } + } + } +} +``` + +To see more complete examples of using the PyTorch C++ frontend, see [the example repository](https://github.com/pytorch/examples/tree/master/cpp). + +## Philosophy + +PyTorch's C++ frontend was designed with the idea that the Python frontend is +great, and should be used when possible; but in some settings, performance and +portability requirements make the use of the Python interpreter infeasible. For +example, Python is a poor choice for low latency, high performance or +multithreaded environments, such as video games or production servers. The +goal of the C++ frontend is to address these use cases, while not sacrificing +the user experience of the Python frontend. + +As such, the C++ frontend has been written with a few philosophical goals in mind: + +* **Closely model the Python frontend in its design**, naming, conventions and + functionality. While there may be occasional differences between the two + frontends (e.g., where we have dropped deprecated features or fixed "warts" + in the Python frontend), we guarantee that the effort in porting a Python model + to C++ should lie exclusively in **translating language features**, + not modifying functionality or behavior. + +* **Prioritize flexibility and user-friendliness over micro-optimization.** + In C++, you can often get optimal code, but at the cost of an extremely + unfriendly user experience. Flexibility and dynamism is at the heart of + PyTorch, and the C++ frontend seeks to preserve this experience, in some + cases sacrificing performance (or "hiding" performance knobs) to keep APIs + simple and explicable. We want researchers who don't write C++ for a living + to be able to use our APIs. + +A word of warning: Python is not necessarily slower than +C++! The Python frontend calls into C++ for almost anything computationally expensive +(especially any kind of numeric operation), and these operations will take up +the bulk of time spent in a program. If you would prefer to write Python, +and can afford to write Python, we recommend using the Python interface to +PyTorch. However, if you would prefer to write C++, or need to write C++ +(because of multithreading, latency or deployment requirements), the +C++ frontend to PyTorch provides an API that is approximately as convenient, +flexible, friendly and intuitive as its Python counterpart. The two frontends +serve different use cases, work hand in hand, and neither is meant to +unconditionally replace the other. + +## Installation + +Instructions on how to install the C++ frontend library distribution, including +an example for how to build a minimal application depending on LibTorch, may be +found by following [this](https://pytorch.org/cppdocs/installing.html) link. diff --git a/docs/cpp/source/frontend.rst b/docs/cpp/source/frontend.rst deleted file mode 100644 index 7a1776f7bd4a6..0000000000000 --- a/docs/cpp/source/frontend.rst +++ /dev/null @@ -1,153 +0,0 @@ -The C++ Frontend -================ - -The PyTorch C++ frontend is a C++17 library for CPU and GPU -tensor computation, with automatic differentiation and high level building -blocks for state of the art machine learning applications. - -Description ------------ - -The PyTorch C++ frontend can be thought of as a C++ version of the -PyTorch Python frontend, providing automatic differentiation and various higher -level abstractions for machine learning and neural networks. Specifically, -it consists of the following components: - -+----------------------+------------------------------------------------------------------------+ -| Component | Description | -+======================+========================================================================+ -| ``torch::Tensor`` | Automatically differentiable, efficient CPU and GPU enabled tensors | -+----------------------+------------------------------------------------------------------------+ -| ``torch::nn`` | A collection of composable modules for neural network modeling | -+----------------------+------------------------------------------------------------------------+ -| ``torch::optim`` | Optimization algorithms like SGD, Adam or RMSprop to train your models | -+----------------------+------------------------------------------------------------------------+ -| ``torch::data`` | Datasets, data pipelines and multi-threaded, asynchronous data loader | -+----------------------+------------------------------------------------------------------------+ -| ``torch::serialize`` | A serialization API for storing and loading model checkpoints | -+----------------------+------------------------------------------------------------------------+ -| ``torch::python`` | Glue to bind your C++ models into Python | -+----------------------+------------------------------------------------------------------------+ -| ``torch::jit`` | Pure C++ access to the TorchScript JIT compiler | -+----------------------+------------------------------------------------------------------------+ - -End-to-end example ------------------- - -Here is a simple, end-to-end example of defining and training a simple -neural network on the MNIST dataset: - -.. code-block:: cpp - - #include - - // Define a new Module. - struct Net : torch::nn::Module { - Net() { - // Construct and register two Linear submodules. - fc1 = register_module("fc1", torch::nn::Linear(784, 64)); - fc2 = register_module("fc2", torch::nn::Linear(64, 32)); - fc3 = register_module("fc3", torch::nn::Linear(32, 10)); - } - - // Implement the Net's algorithm. - torch::Tensor forward(torch::Tensor x) { - // Use one of many tensor manipulation functions. - x = torch::relu(fc1->forward(x.reshape({x.size(0), 784}))); - x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training()); - x = torch::relu(fc2->forward(x)); - x = torch::log_softmax(fc3->forward(x), /*dim=*/1); - return x; - } - - // Use one of many "standard library" modules. - torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr}; - }; - - int main() { - // Create a new Net. - auto net = std::make_shared(); - - // Create a multi-threaded data loader for the MNIST dataset. - auto data_loader = torch::data::make_data_loader( - torch::data::datasets::MNIST("./data").map( - torch::data::transforms::Stack<>()), - /*batch_size=*/64); - - // Instantiate an SGD optimization algorithm to update our Net's parameters. - torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01); - - for (size_t epoch = 1; epoch <= 10; ++epoch) { - size_t batch_index = 0; - // Iterate the data loader to yield batches from the dataset. - for (auto& batch : *data_loader) { - // Reset gradients. - optimizer.zero_grad(); - // Execute the model on the input data. - torch::Tensor prediction = net->forward(batch.data); - // Compute a loss value to judge the prediction of our model. - torch::Tensor loss = torch::nll_loss(prediction, batch.target); - // Compute gradients of the loss w.r.t. the parameters of our model. - loss.backward(); - // Update the parameters based on the calculated gradients. - optimizer.step(); - // Output the loss and checkpoint every 100 batches. - if (++batch_index % 100 == 0) { - std::cout << "Epoch: " << epoch << " | Batch: " << batch_index - << " | Loss: " << loss.item() << std::endl; - // Serialize your model periodically as a checkpoint. - torch::save(net, "net.pt"); - } - } - } - } - -To see more complete examples of using the PyTorch C++ frontend, see `the example repository -`_. - -Philosophy ----------- - -PyTorch's C++ frontend was designed with the idea that the Python frontend is -great, and should be used when possible; but in some settings, performance and -portability requirements make the use of the Python interpreter infeasible. For -example, Python is a poor choice for low latency, high performance or -multithreaded environments, such as video games or production servers. The -goal of the C++ frontend is to address these use cases, while not sacrificing -the user experience of the Python frontend. - -As such, the C++ frontend has been written with a few philosophical goals in mind: - -* **Closely model the Python frontend in its design**, naming, conventions and - functionality. While there may be occasional differences between the two - frontends (e.g., where we have dropped deprecated features or fixed "warts" - in the Python frontend), we guarantee that the effort in porting a Python model - to C++ should lie exclusively in **translating language features**, - not modifying functionality or behavior. - -* **Prioritize flexibility and user-friendliness over micro-optimization.** - In C++, you can often get optimal code, but at the cost of an extremely - unfriendly user experience. Flexibility and dynamism is at the heart of - PyTorch, and the C++ frontend seeks to preserve this experience, in some - cases sacrificing performance (or "hiding" performance knobs) to keep APIs - simple and explicable. We want researchers who don't write C++ for a living - to be able to use our APIs. - -A word of warning: Python is not necessarily slower than -C++! The Python frontend calls into C++ for almost anything computationally expensive -(especially any kind of numeric operation), and these operations will take up -the bulk of time spent in a program. If you would prefer to write Python, -and can afford to write Python, we recommend using the Python interface to -PyTorch. However, if you would prefer to write C++, or need to write C++ -(because of multithreading, latency or deployment requirements), the -C++ frontend to PyTorch provides an API that is approximately as convenient, -flexible, friendly and intuitive as its Python counterpart. The two frontends -serve different use cases, work hand in hand, and neither is meant to -unconditionally replace the other. - -Installation ------------- - -Instructions on how to install the C++ frontend library distribution, including -an example for how to build a minimal application depending on LibTorch, may be -found by following `this `_ link. diff --git a/docs/cpp/source/index.rst b/docs/cpp/source/index.md similarity index 54% rename from docs/cpp/source/index.rst rename to docs/cpp/source/index.md index fb7e5986e869a..1979721ad7a94 100644 --- a/docs/cpp/source/index.rst +++ b/docs/cpp/source/index.md @@ -1,5 +1,11 @@ -PyTorch C++ API -=============== +--- +myst: + html_meta: + description: PyTorch C++ API documentation — ATen tensors, Autograd, C++ Frontend, TorchScript, and C++ Extensions. + keywords: PyTorch, C++, API, libtorch, ATen, Autograd, TorchScript, C++ Frontend +--- + +# PyTorch C++ API These pages provide the documentation for the public portions of the PyTorch C++ API. This API can roughly be divided into five parts: @@ -16,108 +22,99 @@ networks with strong emphasis on GPU acceleration as well as fast CPU performance. It is currently in use at Facebook in research and production; we are looking forward to welcoming more users of the PyTorch C++ API. -.. warning:: +```{warning} - At the moment, the C++ API should be considered "beta" stability; we may - make major breaking changes to the backend in order to improve the API, - or in service of providing the Python interface to PyTorch, which is our - most stable and best supported interface. +At the moment, the C++ API should be considered "beta" stability; we may +make major breaking changes to the backend in order to improve the API, +or in service of providing the Python interface to PyTorch, which is our +most stable and best supported interface. +``` -ATen ----- +## ATen ATen is fundamentally a tensor library, on top of which almost all other Python -and C++ interfaces in PyTorch are built. It provides a core ``Tensor`` class, +and C++ interfaces in PyTorch are built. It provides a core `Tensor` class, on which many hundreds of operations are defined. Most of these operations have -both CPU and GPU implementations, to which the ``Tensor`` class will +both CPU and GPU implementations, to which the `Tensor` class will dynamically dispatch based on its type. A small example of using ATen could look as follows: -.. code-block:: cpp - - #include +```cpp +#include - at::Tensor a = at::ones({2, 2}, at::kInt); - at::Tensor b = at::randn({2, 2}); - auto c = a + b.to(at::kInt); +at::Tensor a = at::ones({2, 2}, at::kInt); +at::Tensor b = at::randn({2, 2}); +auto c = a + b.to(at::kInt); +``` -This ``Tensor`` class and all other symbols in ATen are found in the ``at::`` +This `Tensor` class and all other symbols in ATen are found in the `at::` namespace, documented -`here `_. +[here](https://pytorch.org/cppdocs/api/namespace_at.html#namespace-at). -Autograd --------- +## Autograd What we term *autograd* are the portions of PyTorch's C++ API that augment the -ATen ``Tensor`` class with capabilities concerning automatic differentiation. +ATen `Tensor` class with capabilities concerning automatic differentiation. The autograd system records operations on tensors to form an *autograd graph*. -Calling ``backwards()`` on a leaf variable in this graph performs reverse mode +Calling `backwards()` on a leaf variable in this graph performs reverse mode differentiation through the network of functions and tensors spanning the autograd graph, ultimately yielding gradients. The following example provides a taste of this interface: -.. code-block:: cpp - - #include - #include +```cpp +#include +#include - torch::Tensor a = torch::ones({2, 2}, torch::requires_grad()); - torch::Tensor b = torch::randn({2, 2}); - auto c = a + b; - c.backward(); // a.grad() will now hold the gradient of c w.r.t. a. +torch::Tensor a = torch::ones({2, 2}, torch::requires_grad()); +torch::Tensor b = torch::randn({2, 2}); +auto c = a + b; +c.backward(); // a.grad() will now hold the gradient of c w.r.t. a. +``` -The ``at::Tensor`` class in ATen is not differentiable by default. To add the +The `at::Tensor` class in ATen is not differentiable by default. To add the differentiability of tensors the autograd API provides, you must use tensor factory functions from the `torch::` namespace instead of the `at::` namespace. For example, while a tensor created with `at::ones` will not be differentiable, a tensor created with `torch::ones` will be. -C++ Frontend ------------- +## C++ Frontend The PyTorch C++ frontend provides a high level, pure C++ modeling interface for neural networks and general ML (Machine Learning) research and production use cases, largely following the Python API in design and provided functionality. The C++ frontend includes the following: -- An interface for defining machine learning models through a hierarchical module system (like ``torch.nn.Module``); +- An interface for defining machine learning models through a hierarchical module system (like `torch.nn.Module`); - A "standard library" of pre-existing modules for the most common modeling purposes (e.g. convolutions, RNNs, batch normalization etc.); - An optimization API, including implementations of popular optimizers such as SGD, Adam, RMSprop and others; - A means of representing datasets and data pipelines, including functionality to load data in parallel over many CPU cores; -- A serialization format for storing and loading checkpoints of a training session (like ``torch.utils.data.DataLoader``); -- Automatic parallelization of models onto multiple GPUs (like ``torch.nn.parallel.DataParallel``); +- A serialization format for storing and loading checkpoints of a training session (like `torch.utils.data.DataLoader`); +- Automatic parallelization of models onto multiple GPUs (like `torch.nn.parallel.DataParallel`); - Support code to easily bind C++ models into Python using pybind11; - Entry points to the TorchScript JIT compiler; - Helpful utilities to facilitate interfacing with the ATen and Autograd APIs. -See `this document `_ for a more +See [this document](https://pytorch.org/cppdocs/frontend.html) for a more detailed description of the C++ frontend. Relevant sections of the `torch::` -namespace related to the C++ Frontend include `torch::nn -`_, -`torch::optim -`_, -`torch::data -`_, -`torch::serialize -`_, -`torch::jit -`_ -and `torch::python -`_. -Examples of the C++ frontend can be found in `this repository -`_ which is being +namespace related to the C++ Frontend include [torch::nn](https://pytorch.org/cppdocs/api/namespace_torch__nn.html#namespace-torch-nn), +[torch::optim](https://pytorch.org/cppdocs/api/namespace_torch__optim.html#namespace-torch-optim), +[torch::data](https://pytorch.org/cppdocs/api/namespace_torch__data.html#namespace-torch-data), +[torch::serialize](https://pytorch.org/cppdocs/api/namespace_torch__serialize.html#namespace-torch-serialize), +[torch::jit](https://pytorch.org/cppdocs/api/namespace_torch__jit.html#namespace-torch-jit) +and [torch::python](https://pytorch.org/cppdocs/api/namespace_torch__python.html#namespace-torch-python). +Examples of the C++ frontend can be found in [this repository](https://github.com/pytorch/examples/tree/master/cpp) which is being expanded on a continuous and active basis. -.. note:: +```{note} - Unless you have a particular reason to constrain yourself exclusively to ATen - or the Autograd API, the C++ frontend is the recommended entry point to the - PyTorch C++ ecosystem. While it is still in beta as we collect user feedback - (from you!), it provides both more functionality and better stability - guarantees than the ATen and Autograd APIs. +Unless you have a particular reason to constrain yourself exclusively to ATen +or the Autograd API, the C++ frontend is the recommended entry point to the +PyTorch C++ ecosystem. While it is still in beta as we collect user feedback +(from you!), it provides both more functionality and better stability +guarantees than the ATen and Autograd APIs. +``` -TorchScript ------------ +## TorchScript TorchScript is a representation of a PyTorch model that can be understood, compiled and serialized by the TorchScript compiler. Fundamentally, TorchScript @@ -132,16 +129,13 @@ functionality: The first mechanism may be of great interest to you if you would like to define your models in Python as much as possible, but subsequently export them to C++ for production environments and no-Python inference. You can find out more -about this by following `this -`_ link. The second +about this by following [this](https://pytorch.org/tutorials/advanced/cpp_export.html) link. The second API concerns itself with scenarios in which you would like to extend TorchScript with custom operators, which can similarly be serialized and -invoked from C++ during inference. Lastly, the `torch::jit::compile -`_ +invoked from C++ during inference. Lastly, the [torch::jit::compile](https://pytorch.org/cppdocs/api/function_namespacetorch_1_1jit_1a8660dc13a6b82336aadac667e6dccba1.html) function may be used to access the TorchScript compiler directly from C++. -C++ Extensions --------------- +## C++ Extensions *C++ Extensions* offer a simple yet powerful way of accessing all of the above interfaces for the purpose of extending regular Python use-cases of PyTorch. @@ -151,37 +145,26 @@ does not add any new functionality to the PyTorch C++ API. Instead, it provides integration with Python setuptools as well as JIT compilation mechanisms that allow access to ATen, the autograd and other C++ APIs from Python. To learn more about the C++ extension API, go through -`this tutorial `_. - -Contents --------- - -.. toctree:: - :maxdepth: 2 +[this tutorial](https://pytorch.org/tutorials/advanced/cpp_extension.html). - installing - frontend - stable - api/library_root +## Contents -.. toctree:: - :glob: - :maxdepth: 1 - :caption: Notes +```{toctree} +:maxdepth: 2 - notes/* +installing +frontend +api/index +faq +``` -Indices and tables -================== +# Indices and tables -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` +* {ref}`genindex` +* {ref}`modindex` +* {ref}`search` -Acknowledgements ----------------- +## Acknowledgements -This documentation website for the PyTorch C++ universe has been enabled by the -`Exhale `_ project and generous investment -of time and effort by its maintainer, `svenevs `_. -We thank Stephen for his work and his efforts providing help with the PyTorch C++ documentation. +This documentation website for the PyTorch C++ universe uses the Sphinx +C++ domain for API documentation. diff --git a/docs/cpp/source/installing.md b/docs/cpp/source/installing.md new file mode 100644 index 0000000000000..7fb9067b9611f --- /dev/null +++ b/docs/cpp/source/installing.md @@ -0,0 +1,187 @@ +--- +myst: + html_meta: + description: How to install and configure the PyTorch C++ API (libtorch) for development. + keywords: PyTorch, C++, install, libtorch, setup, CMake, build +--- + +# Installing C++ Distributions of PyTorch + +We provide binary distributions of all headers, libraries and CMake +configuration files required to depend on PyTorch. We call this distribution +*LibTorch*, and you can download ZIP archives containing the latest LibTorch +distribution on [our website](https://pytorch.org/get-started/locally/). Below +is a small example of writing a minimal application that depends on LibTorch +and uses the `torch::Tensor` class which comes with the PyTorch C++ API. + +## Minimal Example + +The first step is to download the LibTorch ZIP archive via the link above. For +example: + +```sh +wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip +unzip libtorch-shared-with-deps-latest.zip +``` + +Note that the above link has CPU-only libtorch. If you would like to download a GPU-enabled +libtorch, find the right link in the link selector on https://pytorch.org + +If you're a Windows developer and wouldn't like to use CMake, you could jump to the Visual Studio +Extension section. + +Next, we can write a minimal CMake build configuration to develop a small +application that depends on LibTorch. CMake is not a hard requirement for using +LibTorch, but it is the recommended and blessed build system and will be well +supported into the future. A most basic `CMakeLists.txt` file could look like +this: + +```cmake +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) +project(example-app) + +find_package(Torch REQUIRED) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + +add_executable(example-app example-app.cpp) +target_link_libraries(example-app "${TORCH_LIBRARIES}") +set_property(TARGET example-app PROPERTY CXX_STANDARD 20) + +# The following code block is suggested to be used on Windows. +# According to https://github.com/pytorch/pytorch/issues/25457, +# the DLLs need to be copied to avoid memory errors. +if (MSVC) + file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") + add_custom_command(TARGET example-app + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${TORCH_DLLS} + $) +endif (MSVC) +``` + +The implementation of our example will simply create a new `torch::Tensor` and +print it: + +```cpp +#include +#include + +int main() { + torch::Tensor tensor = torch::rand({2, 3}); + std::cout << tensor << std::endl; +} +``` + +While there are more fine-grained headers you can include to access only parts +of the PyTorch C++ API, including `torch/torch.h` is the most sure-proof way of +including most of its functionality. + +The last step is to build the application. For this, assume our example +directory is laid out like this: + +```sh +example-app/ + CMakeLists.txt + example-app.cpp +``` + +We can now run the following commands to build the application from within the +`example-app/` folder: + +```sh +mkdir build +cd build +cmake -DCMAKE_PREFIX_PATH=/absolute/path/to/libtorch .. +cmake --build . --config Release +``` + +where `/absolute/path/to/libtorch` should be the absolute (!) path to the unzipped LibTorch +distribution. If PyTorch was installed via pip, `CMAKE_PREFIX_PATH` can be queried +using `torch.utils.cmake_prefix_path` variable. In that case CMake configuration step would look something like follows: + +```sh +cmake -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` .. +``` + +If all goes well, it will look something like this: + +```sh +root@4b5a67132e81:/example-app# mkdir build +root@4b5a67132e81:/example-app# cd build +root@4b5a67132e81:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch .. +-- The C compiler identification is GNU 5.4.0 +-- The CXX compiler identification is GNU 5.4.0 +-- Check for working C compiler: /usr/bin/cc +-- Check for working C compiler: /usr/bin/cc -- works +-- Detecting C compiler ABI info +-- Detecting C compiler ABI info - done +-- Detecting C compile features +-- Detecting C compile features - done +-- Check for working CXX compiler: /usr/bin/c++ +-- Check for working CXX compiler: /usr/bin/c++ -- works +-- Detecting CXX compiler ABI info +-- Detecting CXX compiler ABI info - done +-- Detecting CXX compile features +-- Detecting CXX compile features - done +-- Looking for pthread.h +-- Looking for pthread.h - found +-- Looking for pthread_create +-- Looking for pthread_create - not found +-- Looking for pthread_create in pthreads +-- Looking for pthread_create in pthreads - not found +-- Looking for pthread_create in pthread +-- Looking for pthread_create in pthread - found +-- Found Threads: TRUE +-- Configuring done +-- Generating done +-- Build files have been written to: /example-app/build +root@4b5a67132e81:/example-app/build# cmake --build . --config Release +Scanning dependencies of target example-app +[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o +[100%] Linking CXX executable example-app +[100%] Built target example-app +``` + +Executing the resulting `example-app` binary found in the `build` folder +should now merrily print the tensor (exact output subject to randomness): + +```sh +root@4b5a67132e81:/example-app/build# ./example-app +0.2063 0.6593 0.0866 +0.0796 0.5841 0.1569 +[ Variable[CPUFloatType]{2,3} ] +``` + +```{tip} + +On Windows, debug and release builds are not ABI-compatible. If you plan to +build your project in debug mode, please try the debug version of LibTorch. +Also, make sure you specify the correct configuration in the `cmake --build .` +line above. +``` + +## System Requirements + +To ensure smooth installation and usage of LibTorch, please ensure your system +meets the following requirements: + +1. **GLIBC Version**: + + - GLIBC 2.29 or newer for cxx11 ABI version + +2. **GCC Version**: + + - GCC 9 or newer for cxx11 + +## Visual Studio Extension + +[LibTorch Project Template](https://marketplace.visualstudio.com/items?itemName=YiZhang.LibTorch001) can help Windows developers +set all libtorch project settings and link options for debug and release. +It's easy to use and you could check out the [demo video](https://ossci-windows.s3.us-east-1.amazonaws.com/vsextension/demo.mp4). +The only prerequisite is to download the libtorch on https://pytorch.org + +## Support + +If you run into any troubles with this installation and minimal usage guide, +please use our [forum](https://discuss.pytorch.org/) or [GitHub issues](https://github.com/pytorch/pytorch/issues) to get in touch. diff --git a/docs/cpp/source/installing.rst b/docs/cpp/source/installing.rst deleted file mode 100644 index 0612d2592f21a..0000000000000 --- a/docs/cpp/source/installing.rst +++ /dev/null @@ -1,182 +0,0 @@ -Installing C++ Distributions of PyTorch -======================================= - -We provide binary distributions of all headers, libraries and CMake -configuration files required to depend on PyTorch. We call this distribution -*LibTorch*, and you can download ZIP archives containing the latest LibTorch -distribution on `our website `_. Below -is a small example of writing a minimal application that depends on LibTorch -and uses the ``torch::Tensor`` class which comes with the PyTorch C++ API. - -Minimal Example ---------------- - -The first step is to download the LibTorch ZIP archive via the link above. For -example: - -.. code-block:: sh - - wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip - unzip libtorch-shared-with-deps-latest.zip - -Note that the above link has CPU-only libtorch. If you would like to download a GPU-enabled -libtorch, find the right link in the link selector on https://pytorch.org - -If you're a Windows developer and wouldn't like to use CMake, you could jump to the Visual Studio -Extension section. - -Next, we can write a minimal CMake build configuration to develop a small -application that depends on LibTorch. CMake is not a hard requirement for using -LibTorch, but it is the recommended and blessed build system and will be well -supported into the future. A most basic `CMakeLists.txt` file could look like -this: - -.. code-block:: cmake - - cmake_minimum_required(VERSION 3.18 FATAL_ERROR) - project(example-app) - - find_package(Torch REQUIRED) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") - - add_executable(example-app example-app.cpp) - target_link_libraries(example-app "${TORCH_LIBRARIES}") - set_property(TARGET example-app PROPERTY CXX_STANDARD 17) - - # The following code block is suggested to be used on Windows. - # According to https://github.com/pytorch/pytorch/issues/25457, - # the DLLs need to be copied to avoid memory errors. - if (MSVC) - file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") - add_custom_command(TARGET example-app - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different - ${TORCH_DLLS} - $) - endif (MSVC) - -The implementation of our example will simply create a new `torch::Tensor` and -print it: - -.. code-block:: cpp - - #include - #include - - int main() { - torch::Tensor tensor = torch::rand({2, 3}); - std::cout << tensor << std::endl; - } - -While there are more fine-grained headers you can include to access only parts -of the PyTorch C++ API, including `torch/torch.h` is the most sure-proof way of -including most of its functionality. - -The last step is to build the application. For this, assume our example -directory is laid out like this: - -.. code-block:: sh - - example-app/ - CMakeLists.txt - example-app.cpp - -We can now run the following commands to build the application from within the -``example-app/`` folder: - -.. code-block:: sh - - mkdir build - cd build - cmake -DCMAKE_PREFIX_PATH=/absolute/path/to/libtorch .. - cmake --build . --config Release - -where ``/absolute/path/to/libtorch`` should be the absolute (!) path to the unzipped LibTorch -distribution. If PyTorch was installed via pip, `CMAKE_PREFIX_PATH` can be queried -using `torch.utils.cmake_prefix_path` variable. In that case CMake configuration step would look something like follows: - -.. code-block:: sh - - cmake -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` .. - -If all goes well, it will look something like this: - -.. code-block:: sh - - root@4b5a67132e81:/example-app# mkdir build - root@4b5a67132e81:/example-app# cd build - root@4b5a67132e81:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch .. - -- The C compiler identification is GNU 5.4.0 - -- The CXX compiler identification is GNU 5.4.0 - -- Check for working C compiler: /usr/bin/cc - -- Check for working C compiler: /usr/bin/cc -- works - -- Detecting C compiler ABI info - -- Detecting C compiler ABI info - done - -- Detecting C compile features - -- Detecting C compile features - done - -- Check for working CXX compiler: /usr/bin/c++ - -- Check for working CXX compiler: /usr/bin/c++ -- works - -- Detecting CXX compiler ABI info - -- Detecting CXX compiler ABI info - done - -- Detecting CXX compile features - -- Detecting CXX compile features - done - -- Looking for pthread.h - -- Looking for pthread.h - found - -- Looking for pthread_create - -- Looking for pthread_create - not found - -- Looking for pthread_create in pthreads - -- Looking for pthread_create in pthreads - not found - -- Looking for pthread_create in pthread - -- Looking for pthread_create in pthread - found - -- Found Threads: TRUE - -- Configuring done - -- Generating done - -- Build files have been written to: /example-app/build - root@4b5a67132e81:/example-app/build# cmake --build . --config Release - Scanning dependencies of target example-app - [ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o - [100%] Linking CXX executable example-app - [100%] Built target example-app - -Executing the resulting ``example-app`` binary found in the ``build`` folder -should now merrily print the tensor (exact output subject to randomness): - -.. code-block:: sh - - root@4b5a67132e81:/example-app/build# ./example-app - 0.2063 0.6593 0.0866 - 0.0796 0.5841 0.1569 - [ Variable[CPUFloatType]{2,3} ] - -.. tip:: - On Windows, debug and release builds are not ABI-compatible. If you plan to - build your project in debug mode, please try the debug version of LibTorch. - Also, make sure you specify the correct configuration in the ``cmake --build .`` - line above. - -System Requirements -------------------- - -To ensure smooth installation and usage of LibTorch, please ensure your system -meets the following requirements: - -1. **GLIBC Version**: - - GLIBC 2.29 or newer for cxx11 ABI version - -2. **GCC Version**: - - GCC 9 or newer for cxx11 - -Visual Studio Extension ------------------------ - -`LibTorch Project Template `_ can help Windows developers -set all libtorch project settings and link options for debug and release. -It's easy to use and you could check out the `demo video `_. -The only prerequisite is to download the libtorch on https://pytorch.org - -Support -------- - -If you run into any troubles with this installation and minimal usage guide, -please use our `forum `_ or `GitHub issues -`_ to get in touch. diff --git a/docs/cpp/source/library.rst b/docs/cpp/source/library.rst deleted file mode 100644 index 5cbaaf959910b..0000000000000 --- a/docs/cpp/source/library.rst +++ /dev/null @@ -1,37 +0,0 @@ -Torch Library API -================= - -The PyTorch C++ API provides capabilities for extending PyTorch's core library -of operators with user defined operators and data types. Extensions implemented -using the Torch Library API are made available for use in both the PyTorch eager -API as well as in TorchScript. - -For a tutorial style introduction to the library API, check out the -`Extending TorchScript with Custom C++ Operators -`_ -tutorial. - -Macros ------- - -.. doxygendefine:: TORCH_LIBRARY - -.. doxygendefine:: TORCH_LIBRARY_IMPL - -Classes -------- - -.. doxygenclass:: torch::Library - :members: - -.. doxygenclass:: torch::CppFunction - :members: - -Functions ---------- - -.. doxygengroup:: torch-dispatch-overloads - :content-only: - -.. doxygengroup:: torch-schema-overloads - :content-only: diff --git a/docs/cpp/source/notes/faq.rst b/docs/cpp/source/notes/faq.rst deleted file mode 100644 index 37a1a609f3a5f..0000000000000 --- a/docs/cpp/source/notes/faq.rst +++ /dev/null @@ -1,34 +0,0 @@ -FAQ ---- - -Listed below are a number of common issues users face with the various parts of -the C++ API. - -C++ Extensions -============== - -Undefined symbol errors from PyTorch/ATen -***************************************** - -**Problem**: You import your extension and get an ``ImportError`` stating that -some C++ symbol from PyTorch or ATen is undefined. For example:: - - >>> import extension - Traceback (most recent call last): - File "", line 1, in - ImportError: /home/user/.pyenv/versions/3.7.1/lib/python3.7/site-packages/extension.cpython-37m-x86_64-linux-gnu.so: undefined symbol: _ZN2at19UndefinedTensorImpl10_singletonE - -**Fix**: The fix is to ``import torch`` before you import your extension. This will make -the symbols from the PyTorch dynamic (shared) library that your extension -depends on available, allowing them to be resolved once you import your extension. - -I created a tensor using a function from ``at::`` and get errors -**************************************************************** - -**Problem**: You created a tensor using e.g. ``at::ones`` or ``at::randn`` or -any other tensor factory from the ``at::`` namespace and are getting errors. - -**Fix**: Replace ``at::`` with ``torch::`` for factory function calls. You -should never use factory functions from the ``at::`` namespace, as they will -create tensors. The corresponding ``torch::`` functions will create variables, -and you should only ever deal with variables in your code. diff --git a/docs/cpp/source/notes/inference_mode.rst b/docs/cpp/source/notes/inference_mode.rst deleted file mode 100644 index 60cc48ca93d5b..0000000000000 --- a/docs/cpp/source/notes/inference_mode.rst +++ /dev/null @@ -1,121 +0,0 @@ -Inference Mode -============== - -``c10::InferenceMode`` is a new RAII guard analogous to ``NoGradMode`` -to be used when you are certain your operations will have no interactions -with autograd (e.g. model training). Compared to ``NoGradMode``, code run -under this mode gets better performance by disabling autograd related work like -view tracking and version counter bumps. However, tensors created inside -``c10::InferenceMode`` have more limitations when interacting with autograd system as well. - -``InferenceMode`` can be enabled for a given block of code. Inside ``InferenceMode`` -all newly allocated (non-view) tensors are marked as inference tensors. Inference tensors: - -- do not have a version counter so an error will be raised if you try to read their version - (e.g., because you saved this tensor for backward). -- are immutable outside ``InferenceMode``. So an error will be raised if you try to: - - mutate their data outside InferenceMode. - - mutate them into ``requires_grad=True`` outside InferenceMode. - To work around you can make a clone outside ``InferenceMode`` to get a normal tensor before mutating. - -A non-view tensor is an inference tensor if and only if it was allocated inside ``InferenceMode``. -A view tensor is an inference tensor if and only if it is a view of an inference tensor. - -Inside an ``InferenceMode`` block, we make the following performance guarantees: - -- Like ``NoGradMode``, all operations do not record ``grad_fn`` even if their inputs have ``requires_grad=True``. - This applies to both inference tensors and normal tensors. -- View operations on inference tensors do not do view tracking. View and non-view inference tensors are - indistinguishable. -- Inplace operations on inference tensors are guaranteed not to do a version bump. - -For more implementation details of ``InferenceMode`` please see the `RFC-0011-InferenceMode `_. - -Migration guide from ``AutoNonVariableTypeMode`` ------------------------------------------------- - -In production use of PyTorch for inference workload, we have seen a proliferation -of uses of the C++ guard ``AutoNonVariableTypeMode`` (now ``AutoDispatchBelowADInplaceOrView``), -which disables autograd, view tracking and version counter bumps. Unfortunately, -current colloquial of this guard for inference workload is unsafe: it's possible to -use ``AutoNonVariableTypeMode`` to bypass PyTorch's safety checks and result in -silently wrong results, e.g. PyTorch throws an error when tensors saved for backwards -are subsequently mutated, but mutation happens inside ``AutoNonVariableTypeMode`` will -silently bypass the check and returns wrong gradient to users. - -When current users of ``AutoNonVariableTypeMode`` think about migrating, the following -steps might help you decide the best alternatives: - -1. Users trying to run workload in inference only mode (like loading a pretrained JIT model and - run inference in C++ runtime) should add ``c10::InferenceMode guard`` to guard all operations - on tensors (including model loading). See an inference workload example below: - -.. code-block:: cpp - - c10::InferenceMode guard; - model.load_jit(saved_model); - auto inputs = preprocess_tensors(data); - auto out = model.forward(inputs); - auto outputs = postprocess_tensors(out); - -Note ``c10::InferenceMode`` offers a drop in replacement for ``AutoNonVariableTypeMode`` which preserves -the performance characteristics of ``AutoNonVariableTypeMode``. But they also have some differences that -users should pay additional attention to: - - - Both guards affects tensor execution process to skip work not related to inference, but ``InferenceMode`` - also affects tensor creation while ``AutoNonVariableTypeMode`` doesn't. In other words, tensors created - inside ``InferenceMode`` are marked as inference tensors so that certain limitations can be applied after - exiting ``InferenceMode``. - - Enabled/disabled ``InferenceMode`` states can be nested while ``AutoNonVariableTypeMode`` only allows enabled state. - -.. code-block:: cpp - - { - InferenceMode guard(true); - // InferenceMode is on - { - InferenceMode guard(false); - // InferenceMode is off - } - // InferenceMode is on - } - // InferenceMode is off - - -2. Users trying to implement a customized kernel who want to redispatch under ``Autograd`` dispatch - keys should use ``AutoDispatchBelowADInplaceOrView`` instead. Note ``AutoDispatchBelowADInplaceOrView`` is just a new name - of ``AutoNonVariableTypeMode`` since it explains the guard's functionality better. We're deprecating - ``AutoNonVariableTypeMode`` and it'll be removed in 1.10 release. See customized kernel - ``ROIAlignFunction`` in ``pytorch/vision`` for an example: - -.. code-block:: cpp - - class ROIAlignFunction : public torch::autograd::Function { - public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, - const torch::autograd::Variable& input, - const torch::autograd::Variable& rois, - double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t sampling_ratio, - bool aligned) { - ctx->saved_data["spatial_scale"] = spatial_scale; - ctx->saved_data["pooled_height"] = pooled_height; - ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["sampling_ratio"] = sampling_ratio; - ctx->saved_data["aligned"] = aligned; - ctx->saved_data["input_shape"] = input.sizes(); - ctx->save_for_backward({rois}); - // Used to be at::AutoNonVariableTypeMode g; - at::AutoDispatchBelowADInplaceOrView guard; - auto result = roi_align( - input, rois, spatial_scale, pooled_height, - pooled_width, sampling_ratio, aligned); - return {result}; - } - -Customized inplace & view kernels need some special handling in addition to the guard above, see -`custom kernel tutorial `_ -for more details. diff --git a/docs/cpp/source/notes/maybe_owned.rst b/docs/cpp/source/notes/maybe_owned.rst deleted file mode 100644 index a1bc2d02b9882..0000000000000 --- a/docs/cpp/source/notes/maybe_owned.rst +++ /dev/null @@ -1,59 +0,0 @@ -MaybeOwned -================== - -``MaybeOwned`` is a C++ smart pointer class that dynamically -encodes whether a Tensor is *owned* or *borrowed*. It is used in -certain performance-sensitive situations to avoid unnecessarily -incrementing a Tensor’s reference count (at a small cost in -overhead from the extra indirection). - -.. warning:: - MaybeOwned must be used with **extreme** care. Claims of (non-)ownership - are not statically checked, and mistakes can cause reference undercounting - and use-after-free crashes. - - Due to this lack of safety net, we discourage the use of MaybeOwned - outside code paths that are known to be highly performance sensitive. - However, if you encounter pre-existing uses of MaybeOwned in code that - you want to modify, it’s critical to understand how to use it correctly. - -The primary use case for ``MaybeOwned`` is a function or method that -dynamically chooses between returning one of its arguments (typically -from a passthrough or “no-op” code path) and returning a freshly constructed -Tensor. Such a function would return a ``MaybeOwned`` in both cases, -the former in a "borrowed" state via a call to ``MaybeOwned::borrowed()``, -and the latter in an "owned" state via a call to ``MaybeOwned::owned()``. - -The canonical example is ``Tensor``'s ``expect_contiguous`` method, which shortcuts -and returns a borrowed self-reference when already contiguous: - -.. code-block:: cpp - - inline c10::MaybeOwned Tensor::expect_contiguous(MemoryFormat memory_format) const & { - if (is_contiguous(memory_format)) { - return c10::MaybeOwned::borrowed(*this); - } else { - return c10::MaybeOwned::owned(__dispatch_contiguous(memory_format)); - } - } - -Using the vocabulary of lifetimes, the essential safety requirement for borrowing -is that a borrowed Tensor must outlive any borrowing references to it. Here, for -example, we can safely borrow ``*this``, but the Tensor returned by -``__dispatch_contiguous()`` is freshly created, and borrowing a reference would -effectively leave it ownerless. - -So, general rules of thumb: - -- When in doubt, don’t use ``MaybeOwned`` at all - in particular, prefer - avoiding using it in code that doesn’t use it already. New usage should only be - introduced when critical (and demonstrable) performance gains result. - -- When modifying or calling code that already uses ``MaybeOwned``, remember - that it's always safe to produce a ``MaybeOwned`` from a Tensor in hand - via a call to ``MaybeOwned::owned()``. This may result in an unnecessary - reference count, but never in misbehavior - so it's always the safer bet, unless - the lifetime of the Tensor you're looking to wrap is crystal clear. - -More details and implementation code can be found at and -. diff --git a/docs/cpp/source/notes/tensor_basics.rst b/docs/cpp/source/notes/tensor_basics.rst deleted file mode 100644 index cf8f68a24eec7..0000000000000 --- a/docs/cpp/source/notes/tensor_basics.rst +++ /dev/null @@ -1,154 +0,0 @@ -Tensor Basics -============= - -The ATen tensor library backing PyTorch is a simple tensor library that exposes -the Tensor operations in Torch directly in C++17. ATen's API is auto-generated -from the same declarations PyTorch uses so the two APIs will track each other -over time. - -Tensor types are resolved dynamically, such that the API is generic and does not -include templates. That is, there is one ``Tensor`` type. It can hold a CPU or -CUDA Tensor, and the tensor may have Doubles, Float, Ints, etc. This design -makes it easy to write generic code without templating everything. - -See https://pytorch.org/cppdocs/api/namespace_at.html#functions for the provided -API. Excerpt: - -.. code-block:: cpp - - Tensor atan2(const Tensor & other) const; - Tensor & atan2_(const Tensor & other); - Tensor pow(Scalar exponent) const; - Tensor pow(const Tensor & exponent) const; - Tensor & pow_(Scalar exponent); - Tensor & pow_(const Tensor & exponent); - Tensor lerp(const Tensor & end, Scalar weight) const; - Tensor & lerp_(const Tensor & end, Scalar weight); - Tensor histc() const; - Tensor histc(int64_t bins) const; - Tensor histc(int64_t bins, Scalar min) const; - Tensor histc(int64_t bins, Scalar min, Scalar max) const; - -In place operations are also provided, and always suffixed by `_` to indicate -they will modify the Tensor. - -Efficient Access to Tensor Elements ------------------------------------ - -When using Tensor-wide operations, the relative cost of dynamic dispatch is very -small. However, there are cases, especially in your own kernels, where efficient -element-wise access is needed, and the cost of dynamic dispatch inside the -element-wise loop is very high. ATen provides *accessors* that are created with -a single dynamic check that a Tensor is the type and number of dimensions. -Accessors then expose an API for accessing the Tensor elements efficiently. - -Accessors are temporary views of a Tensor. They are only valid for the lifetime -of the tensor that they view and hence should only be used locally in a -function, like iterators. - -Note that accessors are not compatible with CUDA tensors inside kernel functions. -Instead, you will have to use a *packed accessor* which behaves the same way but -copies tensor metadata instead of pointing to it. - -It is thus recommended to use *accessors* for CPU tensors and *packed accessors* -for CUDA tensors. - -CPU accessors -************* - -.. code-block:: cpp - - torch::Tensor foo = torch::rand({12, 12}); - - // assert foo is 2-dimensional and holds floats. - auto foo_a = foo.accessor(); - float trace = 0; - - for(int i = 0; i < foo_a.size(0); i++) { - // use the accessor foo_a to get tensor data. - trace += foo_a[i][i]; - } - -CUDA accessors -************** - - -.. code-block:: cpp - - __global__ void packed_accessor_kernel( - torch::PackedTensorAccessor64 foo, - float* trace) { - int i = threadIdx.x; - gpuAtomicAdd(trace, foo[i][i]); - } - - torch::Tensor foo = torch::rand({12, 12}); - - // assert foo is 2-dimensional and holds floats. - auto foo_a = foo.packed_accessor64(); - float trace = 0; - - packed_accessor_kernel<<<1, 12>>>(foo_a, &trace); - -In addition to ``PackedTensorAccessor64`` and ``packed_accessor64`` there are -also the corresponding ``PackedTensorAccessor32`` and ``packed_accessor32`` -which use 32-bit integers for indexing. This can be quite a bit faster on CUDA -but may lead to overflows in the indexing calculations. - -Note that the template can hold other parameters such as the pointer restriction -and the integer type for indexing. See documentation for a thorough template -description of *accessors* and *packed accessors*. - - -Using Externally Created Data ------------------------------ - -If you already have your tensor data allocated in memory (CPU or CUDA), -you can view that memory as a ``Tensor`` in ATen: - -.. code-block:: cpp - - float data[] = { 1, 2, 3, - 4, 5, 6 }; - torch::Tensor f = torch::from_blob(data, {2, 3}); - -These tensors cannot be resized because ATen does not own the memory, but -otherwise behave as normal tensors. - -Scalars and zero-dimensional tensors ------------------------------------- - -In addition to the ``Tensor`` objects, ATen also includes ``Scalar``\s that -represent a single number. Like a Tensor, Scalars are dynamically typed and can -hold any one of ATen's number types. Scalars can be implicitly constructed from -C++ number types. Scalars are needed because some functions like ``addmm`` take -numbers along with Tensors and expect these numbers to be the same dynamic type -as the tensor. They are also used in the API to indicate places where a function -will *always* return a Scalar value, like ``sum``. - -.. code-block:: cpp - - namespace torch { - Tensor addmm(Scalar beta, const Tensor & self, - Scalar alpha, const Tensor & mat1, - const Tensor & mat2); - Scalar sum(const Tensor & self); - } // namespace torch - - // Usage. - torch::Tensor a = ... - torch::Tensor b = ... - torch::Tensor c = ... - torch::Tensor r = torch::addmm(1.0, a, .5, b, c); - -In addition to ``Scalar``\s, ATen also allows ``Tensor`` objects to be -zero-dimensional. These Tensors hold a single value and they can be references -to a single element in a larger ``Tensor``. They can be used anywhere a -``Tensor`` is expected. They are normally created by operators like `select` -which reduce the dimensions of a ``Tensor``. - -.. code-block:: cpp - - torch::Tensor two = torch::rand({10, 20}); - two[1][2] = 4; - // ^^^^^^ <- zero-dimensional Tensor diff --git a/docs/cpp/source/notes/tensor_creation.rst b/docs/cpp/source/notes/tensor_creation.rst deleted file mode 100644 index 25c9a8ca0cd46..0000000000000 --- a/docs/cpp/source/notes/tensor_creation.rst +++ /dev/null @@ -1,353 +0,0 @@ -Tensor Creation API -=================== - -This note describes how to create tensors in the PyTorch C++ API. It highlights -the available factory functions, which populate new tensors according to some -algorithm, and lists the options available to configure the shape, data type, -device and other properties of a new tensor. - -Factory Functions ------------------ - -A *factory function* is a function that produces a new tensor. There are many -factory functions available in PyTorch (both in Python and C++), which differ -in the way they initialize a new tensor before returning it. All factory -functions adhere to the following general "schema": - -.. code-block:: cpp - - torch::(, , ) - -Let's bisect the various parts of this "schema": - -1. ```` is the name of the function you would like to invoke, -2. ```` are any required or optional parameters a particular factory function accepts, -3. ```` is an object of type ``IntArrayRef`` and specifies the shape of the resulting tensor, -4. ```` is an instance of ``TensorOptions`` and configures the data type, device, layout and other properties of the resulting tensor. - -Picking a Factory Function -************************** - -The following factory functions are available at the time of this writing (the -hyperlinks lead to the corresponding Python functions, since they often have -more eloquent documentation -- the options are the same in C++): - -- `arange `_: Returns a tensor with a sequence of integers, -- `empty `_: Returns a tensor with uninitialized values, -- `eye `_: Returns an identity matrix, -- `full `_: Returns a tensor filled with a single value, -- `linspace `_: Returns a tensor with values linearly spaced in some interval, -- `logspace `_: Returns a tensor with values logarithmically spaced in some interval, -- `ones `_: Returns a tensor filled with all ones, -- `rand `_: Returns a tensor filled with values drawn from a uniform distribution on ``[0, 1)``. -- `randint `_: Returns a tensor with integers randomly drawn from an interval, -- `randn `_: Returns a tensor filled with values drawn from a unit normal distribution, -- `randperm `_: Returns a tensor filled with a random permutation of integers in some interval, -- `zeros `_: Returns a tensor filled with all zeros. - -Specifying a Size -***************** - -Functions that do not require specific arguments by nature of how they fill the -tensor can be invoked with just a size. For example, the following line creates -a vector with 5 components, initially all set to 1: - -.. code-block:: cpp - - torch::Tensor tensor = torch::ones(5); - - -What if we wanted to instead create a ``3 x 5`` matrix, or a ``2 x 3 x 4`` -tensor? In general, an ``IntArrayRef`` -- the type of the size parameter of factory -functions -- is constructed by specifying the size along each dimension in -curly braces. For example, ``{2, 3}`` for a tensor (in this case matrix) with -two rows and three columns, ``{3, 4, 5}`` for a three-dimensional tensor, and -``{2}`` for a one-dimensional tensor with two components. In the one -dimensional case, you can omit the curly braces and just pass the single -integer like we did above. Note that the squiggly braces are just one way of -constructing an ``IntArrayRef``. You can also pass an ``std::vector`` and -a few other types. Either way, this means we can construct a three-dimensional -tensor filled with values from a unit normal distribution by writing: - -.. code-block:: cpp - - torch::Tensor tensor = torch::randn({3, 4, 5}); - assert(tensor.sizes() == std::vector{3, 4, 5}); - -``tensor.sizes()`` returns an ``IntArrayRef`` which can be compared against an -``std::vector``, and we can see that it contains the sizes we passed -to the tensor. You can also write ``tensor.size(i)`` to access a single dimension, -which is equivalent to but preferred over ``tensor.sizes()[i]``. - -Passing Function-Specific Parameters -************************************ - -Neither ``ones`` nor ``randn`` accept any additional parameters to change their -behavior. One function which does require further configuration is ``randint``, -which takes an upper bound on the value for the integers it generates, as well -as an optional lower bound, which defaults to zero. Here we create a ``5 x 5`` -square matrix with integers between 0 and 10: - -.. code-block:: cpp - - torch::Tensor tensor = torch::randint(/*high=*/10, {5, 5}); - -And here we raise the lower bound to 3: - -.. code-block:: cpp - - torch::Tensor tensor = torch::randint(/*low=*/3, /*high=*/10, {5, 5}); - -The inline comments ``/*low=*/`` and ``/*high=*/`` are not required of course, -but aid readability just like keyword arguments in Python. - -.. tip:: - - The main take-away is that the size always follows the function specific - arguments. - -.. attention:: - - Sometimes a function does not need a size at all. For example, the size of - the tensor returned by ``arange`` is fully specified by its function-specific - arguments -- the lower and upper bound of a range of integers. In that case - the function does not take a ``size`` parameter. - -Configuring Properties of the Tensor -************************************ - -The previous section discussed function-specific arguments. Function-specific -arguments can only change the values with which tensors are filled, and -sometimes the size of the tensor. They never change things like the data type -(e.g. ``float32`` or ``int64``) of the tensor being created, or whether it -lives in CPU or GPU memory. The specification of these properties is left to -the very last argument to every factory function: a ``TensorOptions`` object, -discussed below. - -``TensorOptions`` is a class that encapsulates the construction axes of a -Tensor. With *construction axis* we mean a particular property of a Tensor that -can be configured before its construction (and sometimes changed afterwards). -These construction axes are: - -- The ``dtype`` (previously "scalar type"), which controls the data type of the - elements stored in the tensor, -- The ``layout``, which is either strided (dense) or sparse, -- The ``device``, which represents a compute device on which a tensor is stored (like a CPU or CUDA GPU), -- The ``requires_grad`` boolean to enable or disable gradient recording for a tensor, - -If you are used to PyTorch in Python, these axes will sound very familiar. The -allowed values for these axes at the moment are: - -- For ``dtype``: ``kUInt8``, ``kInt8``, ``kInt16``, ``kInt32``, ``kInt64``, ``kFloat32`` and ``kFloat64``, -- For ``layout``: ``kStrided`` and ``kSparse``, -- For ``device``: Either ``kCPU``, or ``kCUDA`` (which accepts an optional device index), -- For ``requires_grad``: either ``true`` or ``false``. - -.. tip:: - - There exist "Rust-style" shorthands for dtypes, like ``kF32`` instead of - ``kFloat32``. See `here - `_ - for the full list. - - -An instance of ``TensorOptions`` stores a concrete value for each of these -axes. Here is an example of creating a ``TensorOptions`` object that represents -a 64-bit float, strided tensor that requires a gradient, and lives on CUDA -device 1: - -.. code-block:: cpp - - auto options = - torch::TensorOptions() - .dtype(torch::kFloat32) - .layout(torch::kStrided) - .device(torch::kCUDA, 1) - .requires_grad(true); - - -Notice how we use the '"builder"-style methods of ``TensorOptions`` to -construct the object piece by piece. If we pass this object as the last -argument to a factory function, the newly created tensor will have these -properties: - -.. code-block:: cpp - - torch::Tensor tensor = torch::full({3, 4}, /*value=*/123, options); - - assert(tensor.dtype() == torch::kFloat32); - assert(tensor.layout() == torch::kStrided); - assert(tensor.device().type() == torch::kCUDA); // or device().is_cuda() - assert(tensor.device().index() == 1); - assert(tensor.requires_grad()); - -Now, you may be thinking: do I really need to specify each axis for every new -tensor I create? Fortunately, the answer is "no", as **every axis has a default -value**. These defaults are: - -- ``kFloat32`` for the dtype, -- ``kStrided`` for the layout, -- ``kCPU`` for the device, -- ``false`` for ``requires_grad``. - -What this means is that any axis you omit during the construction of a -``TensorOptions`` object will take on its default value. For example, this is -our previous ``TensorOptions`` object, but with the ``dtype`` and ``layout`` -defaulted: - -.. code-block:: cpp - - auto options = torch::TensorOptions().device(torch::kCUDA, 1).requires_grad(true); - -In fact, we can even omit all axes to get an entirely defaulted -``TensorOptions`` object: - -.. code-block:: cpp - - auto options = torch::TensorOptions(); // or `torch::TensorOptions options;` - -A nice consequence of this is that the ``TensorOptions`` object we just spoke -so much about can be entirely omitted from any tensor factory call: - -.. code-block:: cpp - - // A 32-bit float, strided, CPU tensor that does not require a gradient. - torch::Tensor tensor = torch::randn({3, 4}); - torch::Tensor range = torch::arange(5, 10); - -But the sugar gets sweeter: In the API presented here so far, you may have -noticed that the initial ``torch::TensorOptions()`` is quite a mouthful to -write. The good news is that for every construction axis (dtype, layout, device -and ``requires_grad``), there is one *free function* in the ``torch::`` -namespace which you can pass a value for that axis. Each function then returns -a ``TensorOptions`` object preconfigured with that axis, but allowing even -further modification via the builder-style methods shown above. For example, - -.. code-block:: cpp - - torch::ones(10, torch::TensorOptions().dtype(torch::kFloat32)) - -is equivalent to - -.. code-block:: cpp - - torch::ones(10, torch::dtype(torch::kFloat32)) - -and further instead of - -.. code-block:: cpp - - torch::ones(10, torch::TensorOptions().dtype(torch::kFloat32).layout(torch::kStrided)) - -we can just write - -.. code-block:: cpp - - torch::ones(10, torch::dtype(torch::kFloat32).layout(torch::kStrided)) - -which saves us quite a bit of typing. What this means is that in practice, you -should barely, if ever, have to write out ``torch::TensorOptions``. Instead use -the ``torch::dtype()``, ``torch::device()``, ``torch::layout()`` and -``torch::requires_grad()`` functions. - -A final bit of convenience is that ``TensorOptions`` is implicitly -constructible from individual values. This means that whenever a function has a -parameter of type ``TensorOptions``, like all factory functions do, we can -directly pass a value like ``torch::kFloat32`` or ``torch::kStrided`` in place -of the full object. Therefore, when there is only a single axis we would like -to change compared to its default value, we can pass only that value. As such, -what was - -.. code-block:: cpp - - torch::ones(10, torch::TensorOptions().dtype(torch::kFloat32)) - -became - -.. code-block:: cpp - - torch::ones(10, torch::dtype(torch::kFloat32)) - -and can finally be shortened to - -.. code-block:: cpp - - torch::ones(10, torch::kFloat32) - -Of course, it is not possible to modify further properties of the -``TensorOptions`` instance with this short syntax, but if all we needed was to -change one property, this is quite practical. - -In conclusion, we can now compare how ``TensorOptions`` defaults, together with -the abbreviated API for creating ``TensorOptions`` using free functions, allow -tensor creation in C++ with the same convenience as in Python. Compare this -call in Python: - -.. code-block:: python - - torch.randn(3, 4, dtype=torch.float32, device=torch.device('cuda', 1), requires_grad=True) - -with the equivalent call in C++: - -.. code-block:: cpp - - torch::randn({3, 4}, torch::dtype(torch::kFloat32).device(torch::kCUDA, 1).requires_grad(true)) - -Pretty close! - -Conversion ----------- - -Just as we can use ``TensorOptions`` to configure how new tensors should be -created, we can also use ``TensorOptions`` to convert a tensor from one set of -properties to a new set of properties. Such a conversion usually creates a new -tensor and does not occur in-place. For example, if we have a ``source_tensor`` -created with - -.. code-block:: cpp - - torch::Tensor source_tensor = torch::randn({2, 3}, torch::kInt64); - -we can convert it from ``int64`` to ``float32``: - -.. code-block:: cpp - - torch::Tensor float_tensor = source_tensor.to(torch::kFloat32); - -.. attention:: - - The result of the conversion, ``float_tensor``, is a new tensor pointing to - new memory, unrelated to the source ``source_tensor``. - -We can then move it from CPU memory to GPU memory: - -.. code-block:: cpp - - torch::Tensor gpu_tensor = float_tensor.to(torch::kCUDA); - -If you have multiple CUDA devices available, the above code will copy the -tensor to the *default* CUDA device, which you can configure with a -``torch::DeviceGuard``. If no ``DeviceGuard`` is in place, this will be GPU -1. If you would like to specify a different GPU index, you can pass it to -the ``Device`` constructor: - -.. code-block:: cpp - - torch::Tensor gpu_two_tensor = float_tensor.to(torch::Device(torch::kCUDA, 1)); - -In the case of CPU to GPU copy and reverse, we can also configure the memory -copy to be *asynchronous* by passing ``/*non_blocking=*/false`` as the last -argument to ``to()``: - -.. code-block:: cpp - - torch::Tensor async_cpu_tensor = gpu_tensor.to(torch::kCPU, /*non_blocking=*/true); - -Conclusion ----------- - -This note hopefully gave you a good understanding of how to create and convert -tensors in an idiomatic fashion using the PyTorch C++ API. If you have any -further questions or suggestions, please use our `forum -`_ or `GitHub issues -`_ to get in touch. diff --git a/docs/cpp/source/notes/tensor_cuda_stream.rst b/docs/cpp/source/notes/tensor_cuda_stream.rst deleted file mode 100644 index 4940317713635..0000000000000 --- a/docs/cpp/source/notes/tensor_cuda_stream.rst +++ /dev/null @@ -1,276 +0,0 @@ -Tensor CUDA Stream API -====================== - -A `CUDA Stream`_ is a linear sequence of execution that belongs to a specific CUDA device. -The PyTorch C++ API supports CUDA streams with the CUDAStream class and useful helper functions to make streaming operations easy. -You can find them in `CUDAStream.h`_. This note provides more details on how to use Pytorch C++ CUDA Stream APIs. - -.. _CUDA Stream: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#streams -.. _CUDAStream.h: https://pytorch.org/cppdocs/api/file_c10_cuda_CUDAStream.h.html#file-c10-cuda-cudastream-h -.. _CUDAStreamGuard.h: https://pytorch.org/cppdocs/api/structc10_1_1cuda_1_1_c_u_d_a_stream_guard.html - -Acquiring CUDA stream -********************* - -Pytorch's C++ API provides the following ways to acquire CUDA stream: - -1. Acquire a new stream from the CUDA stream pool, streams are preallocated from the pool and returned in a round-robin fashion. - -.. code-block:: cpp - - CUDAStream getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); - -.. tip:: - - You can request a stream from the high priority pool by setting isHighPriority to true, or a stream for a specific device - by setting device index (defaulting to the current CUDA stream's device index). - -2. Acquire the default CUDA stream for the passed CUDA device, or for the current device if no device index is passed. - -.. code-block:: cpp - - CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1); - -.. tip:: - - The default stream is where most computation occurs when you aren't explicitly using streams. - -3. Acquire the current CUDA stream, for the CUDA device with index ``device_index``, or for the current device if no device index is passed. - -.. code-block:: cpp - - CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1); - -.. tip:: - - The current CUDA stream will usually be the default CUDA stream for the device, but it may be different if someone - called ``setCurrentCUDAStream`` or used ``StreamGuard`` or ``CUDAStreamGuard``. - - - -Set CUDA stream -*************** - -Pytorch's C++ API provides the following ways to set CUDA stream: - -1. Set the current stream on the device of the passed in stream to be the passed in stream. - -.. code-block:: cpp - - void setCurrentCUDAStream(CUDAStream stream); - -.. attention:: - - This function may have nothing to do with the current device. It only changes the current stream on the stream's device. - We recommend using ``CUDAStreamGuard``, instead, since it switches to the stream's device and makes it the current stream on that device. - ``CUDAStreamGuard`` will also restore the current device and stream when it's destroyed - -2. Use ``CUDAStreamGuard`` to switch to a CUDA stream within a scope, it is defined in `CUDAStreamGuard.h`_ - -.. tip:: - - Use ``CUDAMultiStreamGuard`` if you need to set streams on multiple CUDA devices. - -CUDA Stream Usage Examples -************************** - -1. Acquiring and setting CUDA stream on the same device - -.. code-block:: cpp - - // This example shows how to acquire and set CUDA stream on the same device. - // `at::cuda::setCurrentCUDAStream` is used to set current CUDA stream - - // create a tensor on device 0 - torch::Tensor tensor0 = torch::ones({2, 2}, torch::device(torch::kCUDA)); - // get a new CUDA stream from CUDA stream pool on device 0 - at::cuda::CUDAStream myStream = at::cuda::getStreamFromPool(); - // set current CUDA stream from default stream to `myStream` on device 0 - at::cuda::setCurrentCUDAStream(myStream); - // sum() on tensor0 uses `myStream` as current CUDA stream - tensor0.sum(); - - // get the default CUDA stream on device 0 - at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(); - // set current CUDA stream back to default CUDA stream on device 0 - at::cuda::setCurrentCUDAStream(defaultStream); - // sum() on tensor0 uses `defaultStream` as current CUDA stream - tensor0.sum(); - -.. code-block:: cpp - - // This example is the same as previous example, but explicitly specify device - // index and use CUDA stream guard to set current CUDA stream - - // create a tensor on device 0 - torch::Tensor tensor0 = torch::ones({2, 2}, torch::device(torch::kCUDA)); - // get a new stream from CUDA stream pool on device 0 - at::cuda::CUDAStream myStream = at::cuda::getStreamFromPool(false, 0); - // set the current CUDA stream to `myStream` within the scope using CUDA stream guard - { - at::cuda::CUDAStreamGuard guard(myStream); - // current CUDA stream is `myStream` from here till the end of bracket. - // sum() on tensor0 uses `myStream` as current CUDA stream - tensor0.sum(); - } - // current CUDA stream is reset to default CUDA stream after CUDA stream guard is destroyed - // sum() on tensor0 uses default CUDA stream on device 0 as current CUDA stream - tensor0.sum(); - -.. attention:: - - Above code is running on the same CUDA device. `setCurrentCUDAStream` will always set current CUDA stream on current device, - but note that `setCurrentCUDAStream` actually set current stream on the device of passed in CUDA stream. - - -2. Acquiring and setting CUDA streams on multiple devices. - -.. code-block:: cpp - - // This example shows how to acquire and set CUDA stream on two devices. - - // acquire new CUDA streams from CUDA stream pool on device 0 and device 1 - at::cuda::CUDAStream myStream0 = at::cuda::getStreamFromPool(false, 0); - at::cuda::CUDAStream myStream1 = at::cuda::getStreamFromPool(false, 1); - - // set current CUDA stream to `myStream0` on device 0 - at::cuda::setCurrentCUDAStream(myStream0); - // set current CUDA stream to `myStream1` on device 1 - at::cuda::setCurrentCUDAStream(myStream1); - - // create a tensor on device 0, no need to specify device index since - // current device index is 0 - torch::Tensor tensor0 = torch::ones({2, 2}, torch::device(at::kCUDA)); - // sum() on tensor0 use `myStream0` as current CUDA stream on device 0 - tensor0.sum(); - - // change the current device index to 1 by using CUDA device guard within a bracket scope - { - at::cuda::CUDAGuard device_guard{1}; - // create a tensor on device 1 - torch::Tensor tensor1 = torch::ones({2, 2}, torch::device(at::kCUDA)); - // sum() on tensor 1 uses `myStream1` as current CUDA stream on device 1 - tensor1.sum(); - } - - // current device is reset to device 0 after device_guard is destroyed - - // acquire a new CUDA stream on device 1 - at::cuda::CUDAStream myStream1_1 = at::cuda::getStreamFromPool(false, 1); - // create a new tensor on device 1 - torch::Tensor tensor1 = torch::ones({2, 2}, torch::device({torch::kCUDA, 1})); - - // change the current device index to 1 and current CUDA stream on device 1 - // to `myStream1_1` using CUDA stream guard within a scope - { - at::cuda::CUDAStreamGuard stream_guard(myStream1_1); - // sum() on tensor1 use `myStream1_1` as current CUDA stream on device 1 - tensor1.sum(); - } - - // current device is reset to device 0 and current CUDA stream on device 1 is - // reset to `myStream1` - - // sum() on tensor1 uses `myStream1` as current CUDA stream on device 1 - tensor1.sum(); - - -3. Working with CUDA multistream guard - -.. code-block:: cpp - - // This example shows how to use CUDA multistream guard to set - // two streams on two devices at the same time. - - // create two tensor, one on device 0, one on device 1 - torch::Tensor tensor0 = torch::ones({2, 2}, torch::device({torch::kCUDA, 0})); - torch::Tensor tensor1 = torch::ones({2, 2}, torch::device({torch::kCUDA, 1})); - - // acquire new CUDA streams from CUDA stream pool on device 0 and device 1 - at::cuda::CUDAStream myStream0 = at::cuda::getStreamFromPool(false, 0); - at::cuda::CUDAStream myStream1 = at::cuda::getStreamFromPool(false, 1); - - // set current CUDA stream on device 0 to `myStream0` and - // set current CUDA stream on device 1 to `myStream1` CUDA using multistream guard - { - at::cuda::CUDAMultiStreamGuard multi_guard({myStream0, myStream1}); - - // sum() on tensor0 uses `myStream0` as current CUDA stream on device 0 - tensor0.sum(); - // sum() on tensor1 uses `myStream1` as current CUDA stream on device 1 - tensor1.sum(); - } - - // current CUDA stream on device 0 is reset to default CUDA stream on device 0 - // current CUDA stream on device 1 is reset to default CUDA stream on device 1 - - // sum() on tensor0 uses default CUDA stream as current CUDA stream on device 0 - tensor0.sum(); - // sum() on tensor1 uses default CUDA stream as current CUDA stream on device 1 - tensor1.sum(); - -.. attention:: - ``CUDAMultiStreamGuard`` does not change current device index, it only changes the stream on - each passed in stream's device. Other than scope controlling, this guard is equivalent to - calling ``setCurrentCUDAStream`` on each passed in stream. - -4. A skeleton example for handling CUDA streams on multiple devices - -.. code-block:: cpp - - // This is a skeleton example that shows how to handle CUDA streams on multiple devices - // Suppose you want to do work on the non-default stream on two devices simultaneously, and we - // already have streams on both devices in two vectors. The following code shows three ways - // of acquiring and setting the streams. - - // Usage 0: acquire CUDA stream and set current CUDA stream with `setCurrentCUDAStream` - // Create a CUDA stream vector `streams0` on device 0 - std::vector streams0 = - {at::cuda::getDefaultCUDAStream(), at::cuda::getStreamFromPool()}; - // set current stream as `streams0[0]` on device 0 - at::cuda::setCurrentCUDAStream(streams0[0]); - - // create a CUDA stream vector `streams1` on device using CUDA device guard - std::vector streams1; - { - // device index is set to 1 within this scope - at::cuda::CUDAGuard device_guard(1); - streams1.push_back(at::cuda::getDefaultCUDAStream()); - streams1.push_back(at::cuda::getStreamFromPool()); - } - // device index is reset to 0 after device_guard is destroyed - - // set current stream as `streams1[0]` on device 1 - at::cuda::setCurrentCUDAStream(streams1[0]); - - - // Usage 1: use CUDA device guard to change the current device index only - { - at::cuda::CUDAGuard device_guard(1); - - // current device index is changed to 1 within scope - // current CUDA stream is still `streams1[0]` on device 1, no change - } - // current device index is reset to 0 after `device_guard` is destroyed - - - // Usage 2: use CUDA stream guard to change both current device index and current CUDA stream. - { - at::cuda::CUDAStreamGuard stream_guard(streams1[1]); - - // current device index and current CUDA stream are set to 1 and `streams1[1]` within scope - } - // current device index and current CUDA stream are reset to 0 and `streams0[0]` after - // stream_guard is destroyed - - - // Usage 3: use CUDA multi-stream guard to change multiple streams on multiple devices - { - // This is the same as calling `torch::cuda::setCurrentCUDAStream` on both streams - at::cuda::CUDAMultiStreamGuard multi_guard({streams0[1], streams1[1]}); - - // current device index is not change, still 0 - // current CUDA stream on device 0 and device 1 are set to `streams0[1]` and `streams1[1]` - } - // current CUDA stream on device 0 and device 1 are reset to `streams0[0]` and `streams1[0]` - // after `multi_guard` is destroyed. diff --git a/docs/cpp/source/notes/tensor_indexing.rst b/docs/cpp/source/notes/tensor_indexing.rst deleted file mode 100644 index 013ab7a737c75..0000000000000 --- a/docs/cpp/source/notes/tensor_indexing.rst +++ /dev/null @@ -1,99 +0,0 @@ -Tensor Indexing API -=================== - -Indexing a tensor in the PyTorch C++ API works very similar to the Python API. -All index types such as ``None`` / ``...`` / integer / boolean / slice / tensor -are available in the C++ API, making translation from Python indexing code to C++ -very simple. The main difference is that, instead of using the ``[]``-operator -similar to the Python API syntax, in the C++ API the indexing methods are: - -- ``torch::Tensor::index`` (`link `_) -- ``torch::Tensor::index_put_`` (`link `_) - -It's also important to note that index types such as ``None`` / ``Ellipsis`` / ``Slice`` -live in the ``torch::indexing`` namespace, and it's recommended to put ``using namespace torch::indexing`` -before any indexing code for convenient use of those index types. - -Here are some examples of translating Python indexing code to C++: - -Getter ------- - -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| Python | C++ (assuming ``using namespace torch::indexing``) | -+==========================================================+======================================================================================+ -| ``tensor[None]`` | ``tensor.index({None})`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| ``tensor[Ellipsis, ...]`` | ``tensor.index({Ellipsis, "..."})`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| ``tensor[1, 2]`` | ``tensor.index({1, 2})`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| ``tensor[True, False]`` | ``tensor.index({true, false})`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| ``tensor[1::2]`` | ``tensor.index({Slice(1, None, 2)})`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| ``tensor[torch.tensor([1, 2])]`` | ``tensor.index({torch::tensor({1, 2})})`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| ``tensor[..., 0, True, 1::2, torch.tensor([1, 2])]`` | ``tensor.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})})`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ - -Setter ------- - -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| Python | C++ (assuming ``using namespace torch::indexing``) | -+==========================================================+======================================================================================+ -| ``tensor[None] = 1`` | ``tensor.index_put_({None}, 1)`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| ``tensor[Ellipsis, ...] = 1`` | ``tensor.index_put_({Ellipsis, "..."}, 1)`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| ``tensor[1, 2] = 1`` | ``tensor.index_put_({1, 2}, 1)`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| ``tensor[True, False] = 1`` | ``tensor.index_put_({true, false}, 1)`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| ``tensor[1::2] = 1`` | ``tensor.index_put_({Slice(1, None, 2)}, 1)`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| ``tensor[torch.tensor([1, 2])] = 1`` | ``tensor.index_put_({torch::tensor({1, 2})}, 1)`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ -| ``tensor[..., 0, True, 1::2, torch.tensor([1, 2])] = 1`` | ``tensor.index_put_({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}, 1)`` | -+----------------------------------------------------------+--------------------------------------------------------------------------------------+ - - -Translating between Python/C++ index types ------------------------------------------- - -The one-to-one translation between Python and C++ index types is as follows: - -+-------------------------+------------------------------------------------------------------------+ -| Python | C++ (assuming ``using namespace torch::indexing``) | -+=========================+========================================================================+ -| ``None`` | ``None`` | -+-------------------------+------------------------------------------------------------------------+ -| ``Ellipsis`` | ``Ellipsis`` | -+-------------------------+------------------------------------------------------------------------+ -| ``...`` | ``"..."`` | -+-------------------------+------------------------------------------------------------------------+ -| ``123`` | ``123`` | -+-------------------------+------------------------------------------------------------------------+ -| ``True`` | ``true`` | -+-------------------------+------------------------------------------------------------------------+ -| ``False`` | ``false`` | -+-------------------------+------------------------------------------------------------------------+ -| ``:`` or ``::`` | ``Slice()`` or ``Slice(None, None)`` or ``Slice(None, None, None)`` | -+-------------------------+------------------------------------------------------------------------+ -| ``1:`` or ``1::`` | ``Slice(1, None)`` or ``Slice(1, None, None)`` | -+-------------------------+------------------------------------------------------------------------+ -| ``:3`` or ``:3:`` | ``Slice(None, 3)`` or ``Slice(None, 3, None)`` | -+-------------------------+------------------------------------------------------------------------+ -| ``::2`` | ``Slice(None, None, 2)`` | -+-------------------------+------------------------------------------------------------------------+ -| ``1:3`` | ``Slice(1, 3)`` | -+-------------------------+------------------------------------------------------------------------+ -| ``1::2`` | ``Slice(1, None, 2)`` | -+-------------------------+------------------------------------------------------------------------+ -| ``:3:2`` | ``Slice(None, 3, 2)`` | -+-------------------------+------------------------------------------------------------------------+ -| ``1:3:2`` | ``Slice(1, 3, 2)`` | -+-------------------------+------------------------------------------------------------------------+ -| ``torch.tensor([1, 2])``| ``torch::tensor({1, 2})`` | -+-------------------------+------------------------------------------------------------------------+ diff --git a/docs/cpp/source/notes/versioning.rst b/docs/cpp/source/notes/versioning.rst deleted file mode 100644 index b7c79427bb031..0000000000000 --- a/docs/cpp/source/notes/versioning.rst +++ /dev/null @@ -1,29 +0,0 @@ -Library Versioning -================== - -We provide version number macros for identifying the version of LibTorch in use. -Example usage: - -.. code-block:: cpp - - #include - #include - - int main() { - std::cout << "PyTorch version from parts: " - << TORCH_VERSION_MAJOR << "." - << TORCH_VERSION_MINOR << "." - << TORCH_VERSION_PATCH << std::endl; - std::cout << "PyTorch version: " << TORCH_VERSION << std::endl; - } - -This will output something like: - -.. code-block:: text - - PyTorch version from parts: 1.8.0 - PyTorch version: 1.8.0 - -.. note:: - - These macros are only available in PyTorch >= 1.8.0. diff --git a/docs/cpp/source/stable.rst b/docs/cpp/source/stable.rst deleted file mode 100644 index c7c60995419da..0000000000000 --- a/docs/cpp/source/stable.rst +++ /dev/null @@ -1,388 +0,0 @@ -Torch Stable API -================ - -The PyTorch Stable C++ API provides a convenient high level interface to call -ABI-stable tensor operations and other utilities commonly used in custom operators. -These functions are designed to maintain binary compatibility across PyTorch versions, -making them suitable for use in ahead-of-time compiled code. - -For more information on the stable ABI, see the -`Stable ABI notes `_. - -Library Registration Macros ---------------------------- - -These macros provide stable ABI equivalents of the standard PyTorch operator -registration macros (``TORCH_LIBRARY``, ``TORCH_LIBRARY_IMPL``, etc.). -Use these when building custom operators that need to maintain binary -compatibility across PyTorch versions. - -``STABLE_TORCH_LIBRARY(ns, m)`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Defines a library of operators in a namespace using the stable ABI. - -This is the stable ABI equivalent of :c:macro:`TORCH_LIBRARY`. -Use this macro to define operator schemas that will maintain -binary compatibility across PyTorch versions. Only one ``STABLE_TORCH_LIBRARY`` -block can exist per namespace; use ``STABLE_TORCH_LIBRARY_FRAGMENT`` for -additional definitions in the same namespace from different translation units. - -**Parameters:** - -- ``ns`` - The namespace in which to define operators (e.g., ``mylib``). -- ``m`` - The name of the StableLibrary variable available in the block. - -**Example:** - -.. code-block:: cpp - - STABLE_TORCH_LIBRARY(mylib, m) { - m.def("my_op(Tensor input, int size) -> Tensor"); - m.def("another_op(Tensor a, Tensor b) -> Tensor"); - } - -Minimum compatible version: PyTorch 2.9. - -``STABLE_TORCH_LIBRARY_IMPL(ns, k, m)`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Registers operator implementations for a specific dispatch key using the stable ABI. - -This is the stable ABI equivalent of ``TORCH_LIBRARY_IMPL``. Use this macro -to provide implementations of operators for a specific dispatch key (e.g., -CPU, CUDA) while maintaining binary compatibility across PyTorch versions. - -.. note:: - - All kernel functions registered with this macro must be boxed using - the ``TORCH_BOX`` macro. - -**Parameters:** - -- ``ns`` - The namespace in which the operators are defined. -- ``k`` - The dispatch key (e.g., ``CPU``, ``CUDA``). -- ``m`` - The name of the StableLibrary variable available in the block. - -**Example:** - -.. code-block:: cpp - - STABLE_TORCH_LIBRARY_IMPL(mylib, CPU, m) { - m.impl("my_op", TORCH_BOX(&my_cpu_kernel)); - } - - STABLE_TORCH_LIBRARY_IMPL(mylib, CUDA, m) { - m.impl("my_op", TORCH_BOX(&my_cuda_kernel)); - } - -Minimum compatible version: PyTorch 2.9. - -``STABLE_TORCH_LIBRARY_FRAGMENT(ns, m)`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Extends operator definitions in an existing namespace using the stable ABI. - -This is the stable ABI equivalent of ``TORCH_LIBRARY_FRAGMENT``. Use this macro -to add additional operator definitions to a namespace that was already -created with ``STABLE_TORCH_LIBRARY``. - -**Parameters:** - -- ``ns`` - The namespace to extend. -- ``m`` - The name of the StableLibrary variable available in the block. - -Minimum compatible version: PyTorch 2.9. - -``TORCH_BOX(&func)`` -^^^^^^^^^^^^^^^^^^^ - -Wraps a function to conform to the stable boxed kernel calling convention. - -This macro takes an unboxed kernel function pointer and generates a boxed wrapper -that can be registered with the stable library API. - -**Parameters:** - -- ``func`` - The unboxed kernel function to wrap. - -**Example:** - -.. code-block:: cpp - - Tensor my_kernel(const Tensor& input, int64_t size) { - return input.reshape({size}); - } - - STABLE_TORCH_LIBRARY_IMPL(my_namespace, CPU, m) { - m.impl("my_op", TORCH_BOX(&my_kernel)); - } - -Minimum compatible version: PyTorch 2.9. - -Tensor Class ------------- - -The ``torch::stable::Tensor`` class offers a user-friendly C++ interface similar -to ``torch::Tensor`` while maintaining binary compatibility across PyTorch versions. - -.. doxygenclass:: torch::stable::Tensor - :members: - - -Device Class ------------- - -The ``torch::stable::Device`` class provides a user-friendly C++ interface similar -to ``c10::Device`` while maintaining binary compatibility across PyTorch versions. -It represents a compute device (CPU, CUDA, etc.) with an optional device index. - -.. doxygenclass:: torch::stable::Device - :members: - -DeviceGuard Class ------------------ - -The ``torch::stable::accelerator::DeviceGuard`` provides a user-friendly C++ -interface similar to ``c10::DeviceGuard`` while maintaining binary compatibility -across PyTorch versions. - -.. doxygenclass:: torch::stable::accelerator::DeviceGuard - :members: - -.. doxygenfunction:: torch::stable::accelerator::getCurrentDeviceIndex - - -Stream Utilities ----------------- - -For CUDA stream access, we currently recommend the ABI stable C shim API. This -will be improved in a future release with a more ergonomic wrapper. - -Getting the Current CUDA Stream -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -To obtain the current ``cudaStream_t`` for use in CUDA kernels: - -.. code-block:: cpp - - #include - #include - - // For now, we rely on the ABI stable C shim API to get the current CUDA stream. - // This will be improved in a future release. - // When using a C shim API, we need to use TORCH_ERROR_CODE_CHECK to - // check the error code and throw an appropriate runtime_error otherwise. - void* stream_ptr = nullptr; - TORCH_ERROR_CODE_CHECK( - aoti_torch_get_current_cuda_stream(tensor.get_device_index(), &stream_ptr)); - cudaStream_t stream = static_cast(stream_ptr); - - // Now you can use 'stream' in your CUDA kernel launches - my_kernel<<>>(args...); - -.. note:: - - The ``TORCH_ERROR_CODE_CHECK`` macro is required when using C shim APIs - to properly check error codes and throw appropriate exceptions. - -CUDA Error Checking Macros --------------------------- - -These macros provide stable ABI equivalents for CUDA error checking. -They wrap CUDA API calls and kernel launches, providing detailed error -messages using PyTorch's error formatting. - -``STD_CUDA_CHECK(EXPR)`` -^^^^^^^^^^^^^^^^^^^^^^^^ - -Checks the result of a CUDA API call and throws an exception on error. -Users of this macro are expected to include ``cuda_runtime.h``. - -**Example:** - -.. code-block:: cpp - - STD_CUDA_CHECK(cudaMalloc(&ptr, size)); - STD_CUDA_CHECK(cudaMemcpy(dst, src, size, cudaMemcpyDeviceToHost)); - -Minimum compatible version: PyTorch 2.10. - -``STD_CUDA_KERNEL_LAUNCH_CHECK()`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Checks for errors from the most recent CUDA kernel launch. Equivalent to -``STD_CUDA_CHECK(cudaGetLastError())``. - -**Example:** - -.. code-block:: cpp - - my_kernel<<>>(args...); - STD_CUDA_KERNEL_LAUNCH_CHECK(); - -Minimum compatible version: PyTorch 2.10. - -Header-Only Utilities ---------------------- - -The ``torch::headeronly`` namespace provides header-only versions of common -PyTorch types and utilities. These can be used without linking against libtorch, -making them ideal for maintaining binary compatibility across PyTorch versions. - -Error Checking -^^^^^^^^^^^^^^ - -``STD_TORCH_CHECK`` is a header-only macro for runtime assertions: - -.. code-block:: cpp - - #include - - STD_TORCH_CHECK(condition, "Error message with ", variable, " interpolation"); - -Core Types -^^^^^^^^^^ - -The following ``c10::`` types are available as header-only versions under -``torch::headeronly::``: - -- ``torch::headeronly::ScalarType`` - Tensor data types (Float, Double, Int, etc.) -- ``torch::headeronly::DeviceType`` - Device types (CPU, CUDA, etc.) -- ``torch::headeronly::MemoryFormat`` - Memory layout formats (Contiguous, ChannelsLast, etc.) -- ``torch::headeronly::Layout`` - Tensor layouts (Strided, Sparse, etc.) - -.. code-block:: cpp - - #include - #include - #include - #include - - auto dtype = torch::headeronly::ScalarType::Float; - auto device_type = torch::headeronly::DeviceType::CUDA; - auto memory_format = torch::headeronly::MemoryFormat::Contiguous; - auto layout = torch::headeronly::Layout::Strided; - -TensorAccessor -^^^^^^^^^^^^^^ - -``TensorAccessor`` provides efficient, bounds-checked access to tensor data. -You can construct one from a stable tensor's data pointer, sizes, and strides: - -.. code-block:: cpp - - #include - - // Create a TensorAccessor for a 2D float tensor - auto sizes = tensor.sizes(); - auto strides = tensor.strides(); - torch::headeronly::TensorAccessor accessor( - static_cast(tensor.mutable_data_ptr()), - sizes.data(), - strides.data()); - - // Access elements - float value = accessor[i][j]; - -Dispatch Macros -^^^^^^^^^^^^^^^ - -Header-only dispatch macros (THO = Torch Header Only) are available for -dtype and device dispatching: - -.. code-block:: cpp - - #include - - THO_DISPATCH_FLOATING_TYPES(tensor.scalar_type(), "my_kernel", [&] { - // scalar_t is the resolved type - auto* data = tensor.data_ptr(); - }); - -Full API List -^^^^^^^^^^^^^ - -For the complete list of header-only APIs, see ``torch/header_only_apis.txt`` -in the PyTorch source tree. - -Stable Operators ----------------- - -Tensor Creation -^^^^^^^^^^^^^^^ - -.. doxygenfunction:: torch::stable::empty - -.. doxygenfunction:: torch::stable::empty_like - -.. doxygenfunction:: torch::stable::new_empty(const torch::stable::Tensor &self, torch::headeronly::IntHeaderOnlyArrayRef size, std::optional dtype, std::optional layout, std::optional device, std::optional pin_memory) - -.. doxygenfunction:: torch::stable::new_zeros(const torch::stable::Tensor &self, torch::headeronly::IntHeaderOnlyArrayRef size, std::optional dtype, std::optional layout, std::optional device, std::optional pin_memory) - -.. doxygenfunction:: torch::stable::full - -.. doxygenfunction:: torch::stable::from_blob - -Tensor Manipulation -^^^^^^^^^^^^^^^^^^^ - -.. doxygenfunction:: torch::stable::clone - -.. doxygenfunction:: torch::stable::contiguous - -.. doxygenfunction:: torch::stable::reshape - -.. doxygenfunction:: torch::stable::view - -.. doxygenfunction:: torch::stable::flatten - -.. doxygenfunction:: torch::stable::squeeze - -.. doxygenfunction:: torch::stable::unsqueeze - -.. doxygenfunction:: torch::stable::transpose - -.. doxygenfunction:: torch::stable::select - -.. doxygenfunction:: torch::stable::narrow - -.. doxygenfunction:: torch::stable::pad - - -Device and Type Conversion -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. doxygenfunction:: torch::stable::to(const torch::stable::Tensor &self, std::optional dtype, std::optional layout, std::optional device, std::optional pin_memory, bool non_blocking, bool copy, std::optional memory_format) - -.. doxygenfunction:: torch::stable::to(const torch::stable::Tensor &self, torch::stable::Device device, bool non_blocking, bool copy) - -.. doxygenfunction:: torch::stable::fill_ - -.. doxygenfunction:: torch::stable::zero_ - -.. doxygenfunction:: torch::stable::copy_ - -.. doxygenfunction:: torch::stable::matmul - -.. doxygenfunction:: torch::stable::amax(const torch::stable::Tensor &self, int64_t dim, bool keepdim) - -.. doxygenfunction:: torch::stable::amax(const torch::stable::Tensor &self, torch::headeronly::IntHeaderOnlyArrayRef dims, bool keepdim) - -.. doxygenfunction:: torch::stable::sum - -.. doxygenfunction:: torch::stable::sum_out - -.. doxygenfunction:: torch::stable::subtract - -.. doxygenfunction:: torch::stable::parallel_for - -.. doxygenfunction:: torch::stable::get_num_threads - - -Parallelization Utilities -^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. doxygenfunction:: torch::stable::parallel_for - -.. doxygenfunction:: torch::stable::get_num_threads diff --git a/docs/source/accelerator.md b/docs/source/accelerator.md index c5904563ee711..1af297d315775 100644 --- a/docs/source/accelerator.md +++ b/docs/source/accelerator.md @@ -2,6 +2,7 @@ ```{eval-rst} .. automodule:: torch.accelerator + :no-members: ``` ```{eval-rst} @@ -20,12 +21,23 @@ set_device_idx current_device_index current_device_idx + get_device_capability set_stream current_stream synchronize device_index ``` +## Graphs + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + Graph +``` + ```{eval-rst} .. automodule:: torch.accelerator.memory ``` @@ -40,6 +52,7 @@ :nosignatures: empty_cache + empty_host_cache get_memory_info max_memory_allocated max_memory_reserved diff --git a/docs/source/accelerator/ci.md b/docs/source/accelerator/ci.md new file mode 100644 index 0000000000000..12ec06d41f461 --- /dev/null +++ b/docs/source/accelerator/ci.md @@ -0,0 +1,180 @@ +# CI Integration + +## Background + +Out-of-tree (OOT) accelerator backends need to maintain compatibility with PyTorch's evolving codebase. As PyTorch continues to develop rapidly, changes in the upstream repository can potentially break downstream accelerator integrations. To address this challenge, PyTorch provides a Cross-Repository CI Relay (CRCR) mechanism that enables automatic CI coordination between the PyTorch repository and downstream accelerator repositories. + +This chapter guides third-party accelerator vendors through the process of integrating their repositories with PyTorch's CI ecosystem, ensuring continuous compatibility validation. + +## Why CI/CD Integration Matters + +Integrating with PyTorch's CI ecosystem provides several key benefits: + +* **Early Detection**: Catch compatibility issues before they reach production, reducing debugging effort and user impact. +* **Automated Validation**: Automatically test your accelerator against PyTorch PRs without manual intervention. +* **Reduced Maintenance Burden**: Proactive testing reduces the need for reactive fixes when compatibility breaks. + +## How It Works + +The CRCR system consists of four components: a **GitHub App** that bridges authentication and events, the **PyTorch repository** as the upstream event source, a **Relay Server** that dispatches events to eligible downstream repos, and **downstream repositories** that receive events and optionally report results back. + +When a PR is opened or updated in PyTorch, GitHub notifies the Relay Server via the GitHub App. The Relay Server verifies the event, reads the allowlist, and dispatches a ``repository_dispatch`` event to each registered downstream repository. Downstream repos can optionally report CI results back to the Relay Server, which surfaces them in the PyTorch HUD or as PR check runs. + +```{mermaid} +flowchart TD + PyTorch["PyTorch\n(PR Event)"] -->|webhook| RS["Relay Server\n(Allowlist/Dispatch/Callback)"] + GH["GitHub APP\n(Auth&Bridge)"] <--> RS + RS <--> HUD["HUD\n(Dashboard)"] + RS -->|repo_dispatch| DA["Downstream A\n(e.g. Ascend)"] + RS -->|repo_dispatch| DB[Downstream B] + RS -->|repo_dispatch| DC[Downstream C] + DA -->|callback| RS + DB -->|callback| RS + DC -->|callback| RS +``` + +Participation is governed by a four-tier allowlist: + +* **L1**: Events are forwarded to the downstream repo; no results are reported upstream. +* **L2**: Results are displayed on dedicated HUD pages for the downstream repo. +* **L3**: Non-blocking check runs appear on PyTorch PRs, triggered by maintainer labels. +* **L4**: Blocking check runs run on all PyTorch PRs (reserved for critical accelerators). + +Downstream repos advance through levels by meeting documented requirements around hardware verification, CI reliability, and success rates. + +For a deeper dive into the architecture and design decisions, see the [RFC-0050: Cross-Repository CI Relay for PyTorch Out-of-Tree Backends](https://github.com/fffrog/rfcs/blob/5e138470e962b0f9c5092e564f35bd7fb13b0b2f/RFC-0050-Cross-Repository-CI-Relay-for-PyTorch-Out-of-Tree-Backends.md). + +```{note} +The CRCR currently supports **L1 (Silent)** integration only. +``` + +## Integration Steps + +### Step 1: Install the GitHub App + +Install the [PyTorch Cross-Repo CI Relay](https://github.com/apps/pytorch-fdn-cross-repo-ci-relay) GitHub App on your repository by clicking the **Configure** button and selecting your repository. + +### Step 2: Add Your Repository to the Allowlist + +Submit a pull request to ``pytorch/pytorch`` adding your repository to ``.github/allowlist.yml`` under the ``L1`` key: + +```{eval-rst} +.. code-block:: yaml + + L1: + - your-org/your-accelerator +``` + +See [#180352](https://github.com/pytorch/pytorch/pull/180352) for a reference example. The PyTorch team will review and merge the PR to complete the onboarding. + +### Step 3: Create the Workflow File + +Create a GitHub Actions workflow in your repository to receive ``repository_dispatch`` events: + +```{eval-rst} +.. code-block:: yaml + :caption: .github/workflows/pytorch_ci.yml + + name: PyTorch CI + + run-name: >- + PyTorch CI - + ${{ + github.event.client_payload.event_type == 'pull_request' && + format('PR #{0} ({1})', + github.event.client_payload.payload.pull_request.number, + github.event.client_payload.payload.action) || + format('Push {0}', github.event.client_payload.payload.after) + }} + + on: + repository_dispatch: + types: [pull_request, push] + + concurrency: + group: >- + pytorch-ci-${{ github.event.client_payload.payload.repository.full_name }}-${{ + github.event.client_payload.payload.pull_request.number || github.run_id }} + cancel-in-progress: true + + permissions: + contents: read + + jobs: + cancel-pr: + if: ${{ github.event.client_payload.payload.action == 'closed' }} + runs-on: ubuntu-latest + steps: + - run: echo "PR closed, canceling in-progress runs" + + ci: + if: ${{ github.event.client_payload.payload.action != 'closed' }} + runs-on: ubuntu-latest + steps: + - name: Checkout downstream repo + uses: actions/checkout@v4 + + - name: Checkout PyTorch at triggered commit + uses: actions/checkout@v4 + with: + repository: pytorch/pytorch + ref: >- + ${{ github.event.client_payload.event_type == 'pull_request' && + github.event.client_payload.payload.pull_request.head.sha || + github.event.client_payload.payload.after }} + path: pytorch + + - name: Build and test + run: | + # Your build and test commands + echo "Running tests against PyTorch..." +``` + +### Step 4: Test the Integration + +Verify your integration works correctly: + +1. Create a test PR in PyTorch (or ask maintainers to trigger a test dispatch) +2. Confirm your workflow triggers correctly + +## Event Payload + +The CRCR relay is a stateless pass-through: it forwards the complete GitHub webhook payload as ``client_payload`` in the ``repository_dispatch`` event. There is no simplified intermediary schema. + +The ``client_payload`` has two top-level fields: + +* ``event_type``: either ``pull_request`` or ``push`` +* ``payload``: the raw GitHub webhook payload for that event type + +Commonly used fields: + +```{eval-rst} +.. code-block:: yaml + + github.event.client_payload.event_type # "pull_request" or "push" + github.event.client_payload.payload.action # "opened", "synchronize", "reopened" or "closed" only + github.event.client_payload.payload.pull_request.number # PR number (pull_request events only) + github.event.client_payload.payload.pull_request.head.sha # Head commit SHA to checkout + github.event.client_payload.payload.after # Commit SHA (push events only) +``` + +Supported ``action`` values for ``pull_request`` events: + +| Action | Description | +| ------ | ----------- | +| ``opened`` | New PR created | +| ``synchronized`` | New commits pushed to an existing PR | +| ``reopened`` | Previously closed PR reopened | +| ``closed`` | PR closed or merged; triggers the ``cancel-pr`` job to stop in-progress runs | + +## Troubleshooting + +### Workflow Not Triggering + +1. Confirm your onboarding with the PyTorch team is complete +2. Check that your workflow file is on the default branch +3. Ensure the ``repository_dispatch`` event type in your workflow matches what the relay sends + +## Resources + +* [RFC-0050: Cross-Repository CI Relay for PyTorch Out-of-Tree Backends](https://github.com/pytorch/rfcs/blob/master/RFC-0050-Cross-Repository-CI-Relay-for-PyTorch-Out-of-Tree-Backends.md) diff --git a/docs/source/accelerator/index.md b/docs/source/accelerator/index.md index 6c21b99ec7765..813183484f557 100644 --- a/docs/source/accelerator/index.md +++ b/docs/source/accelerator/index.md @@ -49,6 +49,7 @@ autoload operators amp profiler +ci ``` [OpenReg URL]: https://github.com/pytorch/pytorch/tree/main/test/cpp_extensions/open_registration_extension/torch_openreg "OpenReg URL" diff --git a/docs/source/backends.md b/docs/source/backends.md index d85c53f200054..8d4b60cbd9fce 100644 --- a/docs/source/backends.md +++ b/docs/source/backends.md @@ -24,6 +24,7 @@ These backends include: - `torch.backends.nnpack` - `torch.backends.openmp` - `torch.backends.opt_einsum` +- `torch.backends.python_native` - `torch.backends.xeon` ## torch.backends.cpu @@ -101,10 +102,26 @@ These backends include: .. autofunction:: torch.backends.cuda.preferred_blas_library ``` +```{eval-rst} +.. autofunction:: torch.backends.cuda.cublas_workspace_size +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.cublaslt_workspace_size +``` + +```{eval-rst} +.. autofunction:: torch.backends.cuda.blas_workspace_size +``` + ```{eval-rst} .. autofunction:: torch.backends.cuda.preferred_rocm_fa_library ``` +```{eval-rst} +.. autofunction:: torch.backends.cuda.is_ck_sdpa_available +``` + ```{eval-rst} .. autofunction:: torch.backends.cuda.preferred_linalg_library ``` @@ -396,6 +413,122 @@ These backends include: ``` +## torch.backends.python_native + +```{eval-rst} +.. automodule:: torch.backends.python_native +``` + +The `torch.backends.python_native` module provides user control over native operators implemented in python +via. DSLs (Domain Specific Languages) that are defined in `torch._native`. This allows users to selectively +enable or disable high-performance implementations from various DSLs like Triton and CuteDSL. + +### Module-level Functions + +```{eval-rst} +.. autofunction:: torch.backends.python_native.get_dsl_operations +``` + +```{eval-rst} +.. autofunction:: torch.backends.python_native.disable_operations +``` + +```{eval-rst} +.. autofunction:: torch.backends.python_native.enable_operations +``` + +```{eval-rst} +.. autofunction:: torch.backends.python_native.disable_dispatch_keys +``` + +```{eval-rst} +.. autofunction:: torch.backends.python_native.enable_dispatch_keys +``` + +```{eval-rst} +.. autofunction:: torch.backends.python_native.operations_disabled +``` + +### Module-level Properties + +```{eval-rst} +.. attribute:: available_dsls + + A :class:`list` of :class:`str` containing the names of DSLs that are available at runtime. + This is a subset of :attr:`all_dsls` that have their runtime dependencies satisfied. +``` + +```{eval-rst} +.. attribute:: all_dsls + + A :class:`list` of :class:`str` containing the names of all registered DSLs, whether + available at runtime or not. +``` + +### DSL Controllers + +For each registered DSL (e.g., `triton`, `cutedsl`), auto-populated controller modules are available: + +```{eval-rst} +.. currentmodule:: torch.backends.python_native +``` + +#### DSL Properties + +Each DSL controller (e.g., `torch.backends.python_native.triton`) provides the following properties: + +| Property | Type | Description | +|----------|------|-------------| +| `name` | `str` | The name of the DSL | +| `available` | `bool` | Whether the DSL's runtime dependencies are available | +| `enabled` | `bool` | Controls whether all operations from this DSL are enabled. Setting to `False` disables all operations from the DSL, while `True` re-enables them | +| `version` | `Version` or `None` | The version of the DSL runtime, if available. Returns `None` if the DSL is not available | + +#### DSL Methods + +Each DSL controller provides the following methods: + +**disable()** + Disable all operations from this DSL. + +**enable()** + Re-enable all operations from this DSL. + +**disabled()** + Context manager that temporarily disables all operations from this DSL. + Operations are automatically re-enabled when exiting the context. + + Example:: + + with torch.backends.python_native.triton.disabled(): + # Triton operations are disabled here + result = model(input) + # Triton operations restored here + +### Usage Examples + +```{eval-rst} +.. code-block:: python + + import torch.backends.python_native as pn + + # Query available DSLs + print(pn.available_dsls) # ['triton', 'cutedsl'] + + # Disable all Triton operations + pn.triton.enabled = False + + # Temporarily disable CuteDSL operations + with pn.cutedsl.disabled(): + result = model(input) # CuteDSL ops disabled + + # Disable specific operations across all DSLs + pn.disable_operations('scaled_mm', '_flash_attention_forward') + + # Query operations for a specific DSL + triton_ops = pn.get_dsl_operations('triton') +``` + ## torch.backends.xeon ```{eval-rst} diff --git a/docs/source/community/persons_of_interest.rst b/docs/source/community/persons_of_interest.rst index 4b6c553a5dcd1..8fbf8bfa202eb 100644 --- a/docs/source/community/persons_of_interest.rst +++ b/docs/source/community/persons_of_interest.rst @@ -346,9 +346,12 @@ ExecuTorch (Edge, Mobile) - Mergen Nachin (`mergennachin `__) - Jacob Szwejbka (`JacobSzwejbka `__) +- Digant Desai (`digantdesai `__) - (emeritus) Kimish Patel (`kimishpatel `__) - (emeritus) Dave Bort (`dbort `__) - (emeritus) Martin Yuan (`iseeyuan `__) +- (emeritus) Mengwei Liu (`larryliu0820 `__) +- (emeritus) Chen Lai (`cccclai `__) TorchCodec ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/conf.py b/docs/source/conf.py index eae6ac07dfea2..a87c6f91cb80d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -78,6 +78,12 @@ "html_image", ] +# Don't execute notebooks during the docs build. Notebook correctness is +# verified by the separate docs_test CI job; re-executing them here just +# adds ~3 minutes to the build for no benefit. +nb_execution_mode = "off" +suppress_warnings = ["myst-nb.lexer"] + html_baseurl = "https://docs.pytorch.org/docs/stable/" # needed for sphinx-sitemap sitemap_locales = [None] sitemap_excludes = [ @@ -150,7 +156,7 @@ }, ], "show_version_warning_banner": True, - "llm_disabled": False, + "llm_disabled": os.environ.get("CI") and os.environ.get("WITH_PUSH") != "true", "icon_links": [ { "name": "X", @@ -253,31 +259,18 @@ coverage_ignore_functions = [ "main", # utility script "run", # utility script - # torch.distributions.constraints - "is_dependent", # torch.hub "import_module", - # torch.jit - "export_opnames", # torch.jit.unsupported_tensor_ops "execWrapper", - # torch.onnx - "unregister_custom_op_symbolic", - # torch.ao.quantization - "default_eval_fn", # torch.backends "disable_global_flags", "flags_frozen", - # torch.distributed.algorithms.ddp_comm_hooks - "register_ddp_comm_hook", # torch.nn.parallel "DistributedDataParallelCPU", # torch.utils - "set_module", "burn_in_info", "get_info_and_burn_skeleton", - "get_inline_skeleton", - "get_model_info", "get_storage_info", "hierarchical_pickle", # torch.amp.autocast_mode @@ -612,70 +605,33 @@ "make_sharded_output_tensor", # torch.fx.passes.annotate_getitem_nodes "annotate_getitem_nodes", - # torch.fx.passes.backends.cudagraphs - "partition_cudagraphs", # torch.fx.passes.dialect.common.cse_pass "get_CSE_banned_ops", # torch.fx.passes.graph_manipulation "get_size_of_all_nodes", "get_size_of_node", "get_tensor_meta", - "replace_target_nodes_with", - # torch.fx.passes.infra.pass_manager - "pass_result_wrapper", - "this_before_that_pass_constraint", - # torch.fx.passes.operator_support - "any_chain", - "chain", - "create_op_support", - # torch.fx.passes.param_fetch - "default_matching", - "extract_attrs_for_lowering", - "lift_lowering_attrs_to_nodes", - # torch.fx.passes.pass_manager - "inplace_wrapper", - "log_hook", - "loop_pass", - "these_before_those_pass_constraint", - "this_before_that_pass_constraint", - # torch.fx.passes.regional_inductor - "regional_inductor", - # torch.fx.passes.reinplace - "reinplace", # torch.fx.passes.split_module "split_module", # torch.fx.passes.split_utils "getattr_recursive", - "setattr_recursive", "split_by_tags", # torch.fx.passes.splitter_base "generate_inputs_for_submodules", # torch.fx.passes.tools_common "get_acc_ops_name", "get_node_target", - "is_node_output_tensor", "legalize_graph", # torch.fx.passes.utils.common - "compare_graphs", "lift_subgraph_as_module", # torch.fx.passes.utils.fuser_utils - "erase_nodes", "fuse_as_graphmodule", "fuse_by_partitions", "insert_subgm", - "topo_sort", - "validate_partition", # torch.fx.passes.utils.source_matcher_utils - "check_subgraphs_connected", "get_source_partitions", # torch.fx.proxy "assert_fn", - # torch.fx.subgraph_rewriter - "replace_pattern", - "replace_pattern_with_filters", - # torch.fx.tensor_type - "is_consistent", - "is_more_precise", # torch.fx.traceback "format_stack", "get_current_meta", @@ -683,9 +639,9 @@ "preserve_node_meta", "reset_grad_fn_seq_nr", "set_current_meta", + "set_current_replay_node", "set_grad_fn_seq_nr", "set_stack_trace", - "set_current_replay_node", "get_current_replay_node", # torch.jit.annotations "ann_to_type", @@ -726,13 +682,6 @@ "quantize_linear_modules", "quantize_rnn_cell_modules", "quantize_rnn_modules", - # torch.library - "define", - "get_ctx", - "impl", - "impl_abstract", - # torch.masked.maskedtensor.core - "is_masked_tensor", # torch.masked.maskedtensor.creation "as_masked_tensor", "masked_tensor", @@ -1018,19 +967,8 @@ "check_export_model_diff", "verify", "verify_aten_graph", - # torch.optim.optimizer - "register_optimizer_step_post_hook", - "register_optimizer_step_pre_hook", # torch.overrides "enable_reentrant_dispatch", - # torch.package.analyze.find_first_use_of_broken_modules - "find_first_use_of_broken_modules", - # torch.package.analyze.is_from_package - "is_from_package", - # torch.package.analyze.trace_dependencies - "trace_dependencies", - # torch.profiler.itt - "range", # torch.profiler.profiler "schedule", "supported_activities", @@ -1038,8 +976,6 @@ # torch.return_types "pytree_register_structseq", # torch.serialization - "check_module_version_greater_or_equal", - "default_restore_location", "load", "location_tag", "mkdtemp", @@ -1060,8 +996,6 @@ "hann", "kaiser", "nuttall", - # torch.sparse.semi_structured - "to_sparse_semi_structured", # torch.utils.backend_registration "generate_methods_for_privateuse1_backend", "rename_privateuse1_backend", @@ -2502,9 +2436,123 @@ def setup(app): app.connect("autodoc-process-docstring", process_docstring) app.connect("html-page-context", hide_edit_button_for_pages) app.config.add_last_updated = True + + # Force serial reads to avoid pipe congestion from large env pickles. + # Sphinx's parallel read sends the entire environment (100s of MB for + # PyTorch) through a 64KB OS pipe per worker, which causes extreme + # slowdowns with many workers. Serial reads avoid this overhead while + # parallel writes (which send trivial payloads) remain enabled. + from sphinx.builders import Builder + + _orig_read_serial = Builder._read_serial + + def _serial_read_ignoring_nproc(self, docnames, nproc=1): + return _orig_read_serial(self, docnames) + + Builder._read_parallel = _serial_read_ignoring_nproc + + # Skip pickling doctrees to disk for the HTML builder. This is only used + # for incremental rebuilds which don't apply in CI clean builds. Saves + # ~2 minutes by avoiding serializing large autodoc-generated doctrees. + # Other builders (doctest, coverage) may need doctrees on disk. + from sphinx.builders.html import StandaloneHTMLBuilder + + def _write_doctree_no_disk(self, docname, doctree, *, _cache=True): + # Still do the cleanup and in-memory caching, just skip the disk I/O + doctree.reporter = None + doctree.transformer = None + doctree.settings = doctree.settings.copy() + doctree.settings.warning_stream = None + doctree.settings.env = None + from docutils.utils import DependencyList + + doctree.settings.record_dependencies = DependencyList() + if _cache: + self.env._write_doc_doctree_cache[docname] = doctree + + StandaloneHTMLBuilder.write_doctree = _write_doctree_no_disk + + _skip_git_dates_on_ci(app) + _fix_katex_server_race(app) + return {"version": "0.1", "parallel_read_safe": True} +def _fix_katex_server_race(app): + """Retry on ConnectionRefusedError when connecting to the KaTeX server. + + sphinxcontrib-katex 0.9.x starts a Node.js server and connects via Unix + socket. There's a race: Python sees the socket file (created by bind()) + and calls connect() before Node.js has called listen(). On slow CI + machines this causes ConnectionRefusedError. Fix by retrying connect(). + """ + try: + from sphinxcontrib.katex import KaTeXServer + except ImportError: + return + + import socket as _socket + import time as _time + + @classmethod + def _start_with_retry(cls, rundir, timeout): + from subprocess import PIPE, Popen + + from sphinxcontrib.katex import ONE_MILLISECOND + + socket_path = rundir / "katex.sock" + cmd = cls.build_command(socket=socket_path) + process = Popen(cmd, stdin=PIPE, stdout=PIPE, cwd=rundir) + + startup_start = _time.monotonic() + while not socket_path.is_socket(): + _time.sleep(ONE_MILLISECOND) + if _time.monotonic() - startup_start > timeout: + raise cls.timeout_error(timeout) + + # Retry connect() — bind() creates the socket file but listen() + # is async in Node.js and may not be ready yet. + while True: + remaining = startup_start + timeout - _time.monotonic() + if remaining <= 0: + raise cls.timeout_error(timeout) + try: + sock = _socket.socket(_socket.AF_UNIX, _socket.SOCK_STREAM) + sock.settimeout(remaining) + sock.connect(str(socket_path)) + break + except ConnectionRefusedError: + sock.close() + _time.sleep(ONE_MILLISECOND) + except TimeoutError: + sock.close() + raise cls.timeout_error(timeout) # noqa: B904 + + return process, sock + + KaTeXServer.start_server_process = _start_with_retry + + +def _skip_git_dates_on_ci(app): + """Skip git date lookups on non-release CI builds. + + The pytorch theme runs 2 git subprocess calls per page to display + "Created/Updated" dates. On CI preview builds this is wasteful (dates + aren't needed) and problematic (treeless clones make git log slow). + Release builds (WITH_PUSH=true) keep the original behavior so dates + appear in published docs. + """ + if not os.environ.get("CI") or os.environ.get("WITH_PUSH") == "true": + return + + try: + import pytorch_sphinx_theme2 + except ImportError: + return + + pytorch_sphinx_theme2.get_git_dates = lambda path: ("Unknown", "Unknown") + + def hide_edit_button_for_pages(app, pagename, templatename, context, doctree): if pagename.startswith("generated/"): context["theme_use_edit_page_button"] = False diff --git a/docs/source/cuda.md b/docs/source/cuda.md index fc30c4f851d92..929f90a823bc7 100644 --- a/docs/source/cuda.md +++ b/docs/source/cuda.md @@ -165,6 +165,7 @@ :toctree: generated :nosignatures: + caching_allocator_disabled caching_allocator_enable ``` @@ -307,7 +308,19 @@ See the docs for {class}`~torch.cuda.green_contexts.GreenContext` for an example ``` ```{eval-rst} -.. py:module:: torch.cuda.nvtx +.. automodule:: torch.cuda.nvtx +``` + +```{eval-rst} +.. currentmodule:: torch.cuda.nvtx +``` + +```{eval-rst} +.. autofunction:: range_start +``` + +```{eval-rst} +.. autofunction:: range_end ``` ```{eval-rst} diff --git a/docs/source/ddp_comm_hooks.md b/docs/source/ddp_comm_hooks.md index 059c388cd003a..5ead08d6d8564 100644 --- a/docs/source/ddp_comm_hooks.md +++ b/docs/source/ddp_comm_hooks.md @@ -33,6 +33,13 @@ Particularly, {class}`torch.distributed.GradBucket` represents a bucket of gradi .. autofunction:: torch.distributed.GradBucket.parameters ``` +## Register Communication Hook + +```{eval-rst} +.. currentmodule:: torch.distributed.algorithms.ddp_comm_hooks +.. autofunction:: register_ddp_comm_hook +``` + ## Default Communication Hooks Default communication hooks are simple **stateless** hooks, so the input state diff --git a/docs/source/distributed.md b/docs/source/distributed.md index b2699b8530624..6be381685f783 100644 --- a/docs/source/distributed.md +++ b/docs/source/distributed.md @@ -663,6 +663,10 @@ with torch.profiler(): Please refer to the [profiler documentation](https://pytorch.org/docs/main/profiler.html) for a full overview of profiler features. +```{eval-rst} +.. autofunction:: torch.distributed.distributed_c10d.record_comm +``` + ## Optimization with Symmetric Memory ### Copy Engine Collectives diff --git a/docs/source/distributed.tensor.md b/docs/source/distributed.tensor.md index 7361c04924675..d768de0e096eb 100644 --- a/docs/source/distributed.tensor.md +++ b/docs/source/distributed.tensor.md @@ -264,7 +264,10 @@ these features. ```{eval-rst} .. autofunction:: register_sharding +``` +```{eval-rst} +.. autofunction:: implicit_replication ``` % modules that are missing docs, add the doc later when necessary diff --git a/docs/source/fx.experimental.md b/docs/source/fx.experimental.md index 6dac091c44e72..d842fa6a4edf3 100644 --- a/docs/source/fx.experimental.md +++ b/docs/source/fx.experimental.md @@ -77,8 +77,6 @@ These APIs are experimental and subject to change without notice. InnerTensorKey Specialization - hint_int - size_hint is_concrete_int is_concrete_bool is_concrete_float @@ -116,11 +114,13 @@ These APIs are experimental and subject to change without notice. guard_float guard_int guard_scalar - has_hint + guarding_hint_or_throw + has_guarding_hint has_symbolic_sizes_strides is_nested_int is_symbol_binding_fx_node is_symbolic + optimization_hint expect_true log_lru_cache_stats ``` @@ -362,6 +362,7 @@ These APIs are experimental and subject to change without notice. view_inference_rule register_inference_rule transpose_inference_rule + range_check ``` ## torch.fx.experimental.migrate_gradual_types.constraint_transformation diff --git a/docs/source/fx.md b/docs/source/fx.md index 426f19a5f6689..50181b3c189e8 100644 --- a/docs/source/fx.md +++ b/docs/source/fx.md @@ -8,6 +8,7 @@ ## Overview ```{eval-rst} .. automodule:: torch.fx + :no-members: ``` @@ -1176,6 +1177,236 @@ The set of leaf modules can be customized by overriding annotate_fn ``` +## torch.fx.subgraph_rewriter + +```{eval-rst} +.. currentmodule:: torch.fx.subgraph_rewriter +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + replace_pattern + replace_pattern_with_filters +``` + +## torch.fx.tensor_type + +```{eval-rst} +.. currentmodule:: torch.fx.tensor_type +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + is_consistent + is_more_precise +``` + +## torch.fx.passes.backends.cudagraphs + +```{eval-rst} +.. currentmodule:: torch.fx.passes.backends.cudagraphs +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + partition_cudagraphs +``` + +## torch.fx.passes.graph_manipulation + +```{eval-rst} +.. currentmodule:: torch.fx.passes.graph_manipulation +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + replace_target_nodes_with +``` + +## torch.fx.passes.infra.pass_manager + +```{eval-rst} +.. currentmodule:: torch.fx.passes.infra.pass_manager +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + pass_result_wrapper + this_before_that_pass_constraint +``` + +## torch.fx.passes.operator_support + +```{eval-rst} +.. currentmodule:: torch.fx.passes.operator_support +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + any_chain + chain + create_op_support +``` + +## torch.fx.passes.param_fetch + +```{eval-rst} +.. automodule:: torch.fx.passes.param_fetch + :no-members: +``` + +```{eval-rst} +.. currentmodule:: torch.fx.passes.param_fetch +``` + +```{eval-rst} +.. autofunction:: default_matching +``` + +```{eval-rst} +.. autofunction:: extract_attrs_for_lowering +``` + +```{eval-rst} +.. autofunction:: lift_lowering_attrs_to_nodes +``` + +## torch.fx.passes.pass_manager + +```{eval-rst} +.. currentmodule:: torch.fx.passes.pass_manager +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + inplace_wrapper + log_hook + loop_pass + these_before_those_pass_constraint + this_before_that_pass_constraint +``` + +## torch.fx.passes.regional_inductor + +```{eval-rst} +.. currentmodule:: torch.fx.passes.regional_inductor +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + regional_inductor +``` + +## torch.fx.passes.reinplace + +```{eval-rst} +.. currentmodule:: torch.fx.passes.reinplace +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + reinplace +``` + +## torch.fx.passes.split_utils + +```{eval-rst} +.. currentmodule:: torch.fx.passes.split_utils +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + setattr_recursive +``` + +## torch.fx.passes.tools_common + +```{eval-rst} +.. currentmodule:: torch.fx.passes.tools_common +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + is_node_output_tensor +``` + +## torch.fx.passes.utils.common + +```{eval-rst} +.. currentmodule:: torch.fx.passes.utils.common +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + compare_graphs +``` + +## torch.fx.passes.utils.fuser_utils + +```{eval-rst} +.. currentmodule:: torch.fx.passes.utils.fuser_utils +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + erase_nodes + topo_sort + validate_partition +``` + +## torch.fx.passes.utils.source_matcher_utils + +```{eval-rst} +.. currentmodule:: torch.fx.passes.utils.source_matcher_utils +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + check_subgraphs_connected +``` + @@ -1218,7 +1449,7 @@ The set of leaf modules can be customized by overriding .. py:module:: torch.fx.passes.infra.pass_manager .. py:module:: torch.fx.passes.net_min_base .. py:module:: torch.fx.passes.operator_support -.. py:module:: torch.fx.passes.param_fetch + .. py:module:: torch.fx.passes.pass_manager .. py:module:: torch.fx.passes.regional_inductor .. py:module:: torch.fx.passes.reinplace diff --git a/docs/source/masked.md b/docs/source/masked.md index 6a79a6be2ea95..af9bc9911720f 100644 --- a/docs/source/masked.md +++ b/docs/source/masked.md @@ -304,6 +304,20 @@ The following ops are currently supported: Tensor.view ``` +## torch.masked.maskedtensor.core + +```{eval-rst} +.. currentmodule:: torch.masked.maskedtensor.core +``` + +```{eval-rst} +.. autosummary:: + :toctree: generated + :nosignatures: + + is_masked_tensor +``` + ```{eval-rst} .. This module needs to be documented. Adding here in the meantime .. for tracking purposes diff --git a/docs/source/multiprocessing.md b/docs/source/multiprocessing.md index 6669fcaa24b30..74944ca58ed20 100644 --- a/docs/source/multiprocessing.md +++ b/docs/source/multiprocessing.md @@ -206,6 +206,12 @@ terminate processes upon detecting an error in one of them. % for tracking purposes +## torch.multiprocessing.pool + +```{eval-rst} +.. currentmodule:: torch.multiprocessing.pool +``` + ```{eval-rst} .. py:module:: torch.multiprocessing.pool ``` diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index 15c67f4ab67d8..399bf6a19bdb9 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -627,6 +627,24 @@ Available options: on all the CUDA devices to a specified fraction of the available memory. This is a value between 0 and 1. Attempting to allocate more memory will raise an out of memory error. +* ``throw_on_cudamalloc_oom`` (default: ``False``) + When set to ``True``, the allocator will preemptively reject allocations that would + exceed the ``per_process_memory_fraction`` limit by throwing an ``OutOfMemoryError`` + instead of attempting the allocation through the driver, which could trigger a fatal + GPU runtime abort (e.g., ``HSA_STATUS_ERROR_OUT_OF_RESOURCES`` on ROCm). The rejected + allocation skips the retry chain and throws immediately, allowing the caller to catch + the exception and handle it gracefully. + + Example configuration: + + .. code-block:: bash + + PYTORCH_CUDA_ALLOC_CONF=per_process_memory_fraction:0.95,throw_on_cudamalloc_oom:True + + This is particularly useful for inference serving, where a fatal GPU OOM would crash the + server process. With this option, the serving framework can catch the ``OutOfMemoryError`` + and reject the individual request while continuing to serve subsequent requests. + .. note:: Some stats reported by the diff --git a/docs/source/notes/ddp.rst b/docs/source/notes/ddp.rst index 07804f8bae166..4f74114fabce8 100644 --- a/docs/source/notes/ddp.rst +++ b/docs/source/notes/ddp.rst @@ -199,7 +199,7 @@ DistributedDataParallel parameters when ``find_unused_parameters`` is set to ``True`` in DDP constructor. -.. image:: https://user-images.githubusercontent.com/16999635/72313120-4e7c1c80-3658-11ea-9c6d-44336b2daeac.png +.. image:: https://github.com/user-attachments/assets/b6fe5258-f1f1-4c73-8e1d-6fd96406faa2 :alt: ddp_code.png :width: 400 px diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst index c3e1b664a2197..41d59f32eb80f 100644 --- a/docs/source/notes/extending.rst +++ b/docs/source/notes/extending.rst @@ -142,7 +142,7 @@ the autograd engine. tensors filled with zeros. The default value of this setting is True. In addition to ``ctx`` methods, the :class:`~Function` class supports the following -class attribute: +class attributes: - :attr:`~Function.clear_saved_tensors_on_access`: When set to ``True`` on the :class:`~Function` subclass, accessing ``ctx.saved_tensors`` in the backward pass @@ -153,6 +153,50 @@ class attribute: needed once. The default is ``False``. Note that ``saved_tensors`` can only be accessed once when this is enabled; a second access will raise an error. +- :attr:`~Function.boxed_grads_call`: When set to ``True`` on the + :class:`~Function` subclass, backward receives grads as a single mutable list + argument instead of individual args in an immutable tuple. This allows backward + to free individual grads mid-execution by removing them from the list, reducing + peak memory. When enabled, the backward calling convention changes from + ``backward(ctx, *grads) -> Tuple[Tensor, ...]`` to ``backward(ctx, grads) -> Tuple[Tensor, ...]`` where ``grads`` is a list. + The return convention is unchanged. + The default is ``False``. + +Here is an example using ``boxed_grads_call`` to reduce peak memory in the backward +pass of a QKV projection — a common building block of transformer models. The forward +pass projects an input into three separate tensors (Q, K, V), so backward receives +three gradients. Without ``boxed_grads_call``, all three grad tensors are held alive +for the entire backward call because they are unpacked from an immutable tuple. +With ``boxed_grads_call``, each grad can be freed as soon as it is consumed, so peak +memory is reduced by up to 2/3 of the total grad memory:: + + class QKVProjection(Function): + """Projects input x into Q, K, V: q = x @ w_q, k = x @ w_k, v = x @ w_v.""" + boxed_grads_call = True + + @staticmethod + def forward(ctx, x, w_q, w_k, w_v): + ctx.save_for_backward(x, w_q, w_k, w_v) + return x.mm(w_q), x.mm(w_k), x.mm(w_v) + + @staticmethod + def backward(ctx, grads): + x, w_q, w_k, w_v = ctx.saved_tensors + grad_x = torch.zeros_like(x) + grad_weights = [] + + # Process each grad independently and free it immediately. + # Without boxed_grads_call, all three grads would stay alive + # until backward returns, tripling peak grad memory. + for i, w in enumerate((w_q, w_k, w_v)): + grad_out = grads[i] + grads[i] = None # Release reference in the caller's list + grad_x += grad_out.mm(w.t()) + grad_weights.append(x.t().mm(grad_out)) + del grad_out # grad_out can now be freed by the runtime + + return grad_x, *grad_weights + **Step 3:** If your :class:`~Function` does not support double backward you should explicitly declare this by decorating backward with the :func:`~function.once_differentiable`. With this decorator, attempts to diff --git a/docs/source/notes/get_start_xpu.rst b/docs/source/notes/get_start_xpu.rst index dc0cb984c47fc..e1c098872d968 100644 --- a/docs/source/notes/get_start_xpu.rst +++ b/docs/source/notes/get_start_xpu.rst @@ -26,7 +26,7 @@ Intel Client GPU +---------------------------------------+-----------------------------------------------------------------------------------------------------+ | Supported OS | Validated Hardware | +=======================================+=====================================================================================================+ -| Windows 11 & Ubuntu 24.04/25.04 | | Intel® Arc A-Series Graphics (CodeName: Alchemist) | +| Windows 11 & Ubuntu 24.04/25.10 | | Intel® Arc A-Series Graphics (CodeName: Alchemist) | | | | Intel® Arc B-Series Graphics (CodeName: Battlemage) | | | | Intel® Core™ Ultra Processors with Intel® Arc™ Graphics (CodeName: Meteor Lake-H) | | | | Intel® Core™ Ultra Processors (Series 2) with Intel® Arc™ Graphics (CodeName: Arrow Lake-H) | @@ -42,7 +42,7 @@ Software Prerequisite To use PyTorch on Intel GPUs, you need to install the Intel GPUs driver first. For installation guide, visit `Intel GPUs Driver Installation `_. -Please skip the Intel® Deep Learning Essentials installation section if you install from binaries. For building from source, please refer to `PyTorch Installation Prerequisites for Intel GPUs `_ for both Intel GPU Driver and Intel® Deep Learning Essentials Installation. +Please skip the Intel® Deep Learning Essentials installation section if you install from binaries. For building from source, please refer to `PyTorch Installation Prerequisites for Intel GPUs `_ for both Intel GPU Driver and Intel® Deep Learning Essentials Installation. Installation @@ -74,17 +74,17 @@ To install the latest preview/nightly wheels: Previous Versions ~~~~~~~~~~~~~~~~~ -**v2.10.0** +**v2.11.0** .. code-block:: bash - pip install torch==2.10.0 torchvision==0.25.0 torchaudio==2.10.0 --index-url https://download.pytorch.org/whl/xpu + pip install torch==2.11.0 torchvision==0.26.0 torchaudio==2.11.0 --index-url https://download.pytorch.org/whl/xpu -**v2.9.1** +**v2.10.0** .. code-block:: bash - pip install torch==2.9.1 torchvision==0.24.1 torchaudio==2.9.1 --index-url https://download.pytorch.org/whl/xpu + pip install torch==2.10.0 torchvision==0.25.0 torchaudio==2.10.0 --index-url https://download.pytorch.org/whl/xpu .. note:: @@ -93,7 +93,7 @@ Previous Versions From Source ^^^^^^^^^^^ -Now that we have `Intel GPU Driver and Intel® Deep Learning Essentials `_ installed. Follow guides to build ``pytorch``, ``torchvision``, ``torchaudio`` from source. +Now that we have `Intel GPU Driver and Intel® Deep Learning Essentials `_ installed, follow the guides to build ``pytorch``, ``torchvision``, ``torchaudio`` from source. Build from source for ``torch`` refer to `PyTorch Installation Build from source `_. @@ -129,7 +129,7 @@ If you are migrating code from ``cuda``, you would change references from ``cuda The following points outline the support and limitations for PyTorch with Intel GPU: #. Both training and inference workflows are supported. -#. Both eager mode and ``torch.compile`` is supported. The feature ``torch.compile`` is also supported on Windows from PyTorch* 2.7 with Intel GPU, refer to `How to use torch.compile on Windows CPU/XPU `_. +#. Both eager mode and ``torch.compile`` are supported. The feature ``torch.compile`` is also supported on Windows from PyTorch* 2.7 with Intel GPU, refer to `How to use torch.compile on Windows CPU/XPU `_. #. Data types such as FP32, BF16, FP16, and Automatic Mixed Precision (AMP) are all supported. Examples @@ -226,7 +226,7 @@ Inference with ``torch.compile`` Training Examples ^^^^^^^^^^^^^^^^^ -Here is a few training workflow examples. +Here are a few training workflow examples. Train with FP32 ~~~~~~~~~~~~~~~ diff --git a/docs/source/optim.md b/docs/source/optim.md index c389fa6c37d2b..ffbc172ede073 100644 --- a/docs/source/optim.md +++ b/docs/source/optim.md @@ -149,6 +149,18 @@ for input, target in dataset: Optimizer.zero_grad ``` +## Module-level hooks + +```{eval-rst} +.. currentmodule:: torch.optim.optimizer + +.. autofunction:: register_optimizer_step_post_hook + +.. autofunction:: register_optimizer_step_pre_hook + +.. currentmodule:: torch.optim +``` + ## Algorithms ```{eval-rst} diff --git a/docs/source/package.md b/docs/source/package.md index 1b50f743d5793..dcc699cf545f6 100644 --- a/docs/source/package.md +++ b/docs/source/package.md @@ -726,12 +726,23 @@ statements more clearly show whether they are referring to packaged code or not. :members: ``` - +## Analysis Utilities ```{eval-rst} .. py:module:: torch.package.analyze.find_first_use_of_broken_modules .. py:module:: torch.package.analyze.is_from_package .. py:module:: torch.package.analyze.trace_dependencies + +.. currentmodule:: torch.package.analyze.find_first_use_of_broken_modules + +.. autofunction:: find_first_use_of_broken_modules + +.. currentmodule:: torch.package.analyze.is_from_package + +.. autofunction:: is_from_package + +.. currentmodule:: torch.package.analyze.trace_dependencies + +.. autofunction:: trace_dependencies .. py:module:: torch.package.file_structure_representation .. py:module:: torch.package.find_file_dependencies .. py:module:: torch.package.glob_group diff --git a/docs/source/profiler.md b/docs/source/profiler.md index 1578b7334d849..ca65df06da223 100644 --- a/docs/source/profiler.md +++ b/docs/source/profiler.md @@ -11,11 +11,10 @@ ## API Reference ```{eval-rst} -.. autoclass:: torch.profiler._KinetoProfile - :members: .. autoclass:: torch.profiler.profile :members: + :inherited-members: .. autoclass:: torch.profiler.ProfilerAction :members: @@ -26,6 +25,8 @@ .. autofunction:: torch.profiler.schedule .. autofunction:: torch.profiler.tensorboard_trace_handler + +.. autofunction:: torch.profiler.supported_activities ``` ## Intel Instrumentation and Tracing Technology APIs @@ -38,6 +39,8 @@ .. autofunction:: torch.profiler.itt.range_push .. autofunction:: torch.profiler.itt.range_pop + +.. autofunction:: torch.profiler.itt.range ``` )) \+\d+ ([A-Za-z0-9_]+)", "\1 \3", @@ -285,7 +285,7 @@ def f(x): y = g(y) return y + 3 - def munge_filenames(s): # noqa: F841 + def munge_filenames(s): return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s) f(torch.randn(2)) diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index f65d860fab64a..db7314ef5f84f 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -2,12 +2,18 @@ import contextlib import sys import unittest +from collections import defaultdict from contextlib import contextmanager import torch import torch._dynamo.test_case import torch._dynamo.testing -from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same +from torch._dynamo.testing import ( + check_dynamic_shape_capture, + EagerAndRecordGraphs, + normalize_gm, + same, +) from torch._dynamo.utils import counters from torch.nn import functional as F from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION @@ -17,6 +23,8 @@ ) +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" + try: from . import test_functions except ImportError: @@ -76,7 +84,7 @@ def customized_ctx_manager_with_graph_break(mode): torch._C._set_grad_enabled(prev) -class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks): +class CtxManagerTests(torch._dynamo.test_case.TestCase): def test_no_grad(self): def fn1(a, b): x = a + 1 @@ -568,7 +576,7 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) self.assertEqual(cnts.frame_count, 1) - self.assertExpectedInline(str(cnts.op_count), """16""") + self.assertExpectedInline(str(cnts.op_count), """17""") @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_cuda_device(self): @@ -583,6 +591,17 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") + def test_cuda__exchange_device_args(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(args, kwargs): + torch.cuda._exchange_device(*args, **kwargs) + + initial_dev = torch.cuda.current_device() + for args, kwargs in (((), ()), ((0, 0), ()), ((), ("kwarg",))): + self.assertRaises(torch._dynamo.exc.Unsupported, fn, args, kwargs) + self.assertEqual(torch.cuda.current_device(), initial_dev) + def test_autograd_profiler_enabled(self): def fn(x): if torch.autograd._profiler_enabled(): @@ -687,7 +706,7 @@ def test_autocast_sdpa(self): class MyModule(torch.nn.Module): def forward(self, query, key, value): with torch.autocast("cpu"): - with torch.autocast("cuda", dtype=torch.float32): + with torch.autocast(device_type, dtype=torch.float32): out = F.scaled_dot_product_attention( query, key, value, None, 0.0, True ) @@ -698,13 +717,31 @@ def forward(self, query, key, value): seq_len_k = 1 head_dim = 8 query = torch.ones( - 1, 8, seq_len_q, head_dim, device="cuda", dtype=dtype, requires_grad=True + 1, + 8, + seq_len_q, + head_dim, + device=device_type, + dtype=dtype, + requires_grad=True, ) key = torch.ones( - 1, 8, seq_len_k, head_dim, device="cuda", dtype=dtype, requires_grad=True + 1, + 8, + seq_len_k, + head_dim, + device=device_type, + dtype=dtype, + requires_grad=True, ) value = torch.ones( - 1, 8, seq_len_k, head_dim, device="cuda", dtype=dtype, requires_grad=True + 1, + 8, + seq_len_k, + head_dim, + device=device_type, + dtype=dtype, + requires_grad=True, ) module = MyModule() @@ -718,7 +755,7 @@ def forward(self, query, key, value): self.assertEqual(compiled.device, real_device) self.assertEqual(compiled.dtype, real_dtype) - self.assertEqual(compiled.device.type, "cuda") + self.assertEqual(compiled.device.type, device_type) self.assertEqual(compiled.device.index, 0) self.assertEqual(compiled.dtype, torch.float32) @@ -1012,6 +1049,348 @@ def fn(a, b): self.assertTrue(res[0].dtype == torch.float16) self.assertTrue(res[1].dtype == torch.float16) + def test__enter__exit_autocast(self): + def f(x, y): + m = torch.amp.autocast_mode._enter_autocast("cpu") + x = x @ y + torch.amp.autocast_mode._exit_autocast(m) + return x + + eager = EagerAndRecordGraphs() + opt_f = torch.compile(f, backend=eager, fullgraph=True) + x = torch.randn(3, 3, dtype=torch.float32) + y = torch.randn(3, 3, dtype=torch.float32) + z = f(x, y) + opt_z = opt_f(x, y) + self.assertEqual(z, opt_z) + self.assertEqual(z.dtype, opt_z.dtype) + self.assertFalse(torch.is_autocast_enabled("cpu")) + graph = eager.graphs[0] + actual = normalize_gm(graph.print_readable(False)) + + if check_dynamic_shape_capture(): + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, s77]", L_y_: "f32[s77, s77]"): + l_x_ = L_x_ + l_y_ = L_y_ + + _is_autocast_available = torch._C._is_autocast_available('cpu'); _is_autocast_available = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', True); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + autocast_increment_nesting = torch.autocast_increment_nesting(); autocast_increment_nesting = None + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + + x: "bf16[s77, s77]" = l_x_ @ l_y_; l_x_ = l_y_ = None + + autocast_decrement_nesting = torch.autocast_decrement_nesting(); autocast_decrement_nesting = None + + clear_autocast_cache = torch.clear_autocast_cache(); clear_autocast_cache = None + + set_autocast_enabled_1 = torch.set_autocast_enabled('cpu', False); set_autocast_enabled_1 = None + + set_autocast_dtype_1 = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype_1 = None + + set_autocast_cache_enabled_1 = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled_1 = None + return (x,) +""", + ) + else: + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"): + l_x_ = L_x_ + l_y_ = L_y_ + + _is_autocast_available = torch._C._is_autocast_available('cpu'); _is_autocast_available = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', True); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + autocast_increment_nesting = torch.autocast_increment_nesting(); autocast_increment_nesting = None + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + + x: "bf16[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None + + autocast_decrement_nesting = torch.autocast_decrement_nesting(); autocast_decrement_nesting = None + + clear_autocast_cache = torch.clear_autocast_cache(); clear_autocast_cache = None + + set_autocast_enabled_1 = torch.set_autocast_enabled('cpu', False); set_autocast_enabled_1 = None + + set_autocast_dtype_1 = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype_1 = None + + set_autocast_cache_enabled_1 = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled_1 = None + return (x,) +""", + ) + + def test__enter__exit_autocast_graph_break(self): + def f(x, y, z): + m = torch.amp.autocast_mode._enter_autocast("cpu") + x = x @ y + torch._dynamo.graph_break() + x = x @ z + # At this point m is wrapped as an AutocastModeVariable, which will graph break on the __exit__ call + torch.amp.autocast_mode._exit_autocast(m) + return x + + eager = EagerAndRecordGraphs() + opt_f = torch.compile(f, backend=eager, fullgraph=False) + x = torch.randn(3, 3, dtype=torch.float32) + y = torch.randn(3, 3, dtype=torch.float32) + z = torch.randn(3, 3, dtype=torch.float32) + out = f(x, y, z) + opt_out = opt_f(x, y, z) + self.assertEqual(out, opt_out) + self.assertEqual(out.dtype, opt_out.dtype) + self.assertFalse(torch.is_autocast_enabled("cpu")) + graph = eager.graphs[0] + actual = normalize_gm(graph.print_readable(False)) + + if check_dynamic_shape_capture(): + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, s77: "Sym(s77)", L_x_: "f32[s77, s77]", L_y_: "f32[s77, s77]"): + l_x_ = L_x_ + l_y_ = L_y_ + + _is_autocast_available = torch._C._is_autocast_available('cpu'); _is_autocast_available = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', True); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + autocast_increment_nesting = torch.autocast_increment_nesting(); autocast_increment_nesting = None + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + + x: "bf16[s77, s77]" = l_x_ @ l_y_; l_x_ = l_y_ = None + return (x,) +""", + ) + else: + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"): + l_x_ = L_x_ + l_y_ = L_y_ + + _is_autocast_available = torch._C._is_autocast_available('cpu'); _is_autocast_available = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', True); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + autocast_increment_nesting = torch.autocast_increment_nesting(); autocast_increment_nesting = None + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + + x: "bf16[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None + return (x,) +""", + ) + + # Doesn't include autocast functions, see comment above + graph = eager.graphs[1] + actual = normalize_gm(graph.print_readable(False)) + + if check_dynamic_shape_capture(): + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, s77: "Sym(s77)", L_x_: "bf16[s77, s77]", L_z_: "f32[s77, s77]"): + l_x_ = L_x_ + l_z_ = L_z_ + + x: "bf16[s77, s77]" = l_x_ @ l_z_; l_x_ = l_z_ = None + + autocast_decrement_nesting = torch.autocast_decrement_nesting(); autocast_decrement_nesting = None + + clear_autocast_cache = torch.clear_autocast_cache(); clear_autocast_cache = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', False); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + return (x,) +""", + ) + else: + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "bf16[3, 3]", L_z_: "f32[3, 3]"): + l_x_ = L_x_ + l_z_ = L_z_ + + x: "bf16[3, 3]" = l_x_ @ l_z_; l_x_ = l_z_ = None + + autocast_decrement_nesting = torch.autocast_decrement_nesting(); autocast_decrement_nesting = None + + clear_autocast_cache = torch.clear_autocast_cache(); clear_autocast_cache = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', False); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + return (x,) +""", + ) + + def test_autocast_low_level_api(self): + def f(x, y): + torch.set_autocast_enabled("cpu", True) + torch.set_autocast_dtype("cpu", torch.bfloat16) + torch.set_autocast_cache_enabled(True) + x = x @ y + torch.autocast_decrement_nesting() + torch.clear_autocast_cache() + torch.set_autocast_enabled("cpu", False) + return x + + prev_enabled = torch.is_autocast_enabled("cpu") + prev_dtype = torch.get_autocast_dtype("cpu") + prev_cache = torch.is_autocast_cache_enabled() + + try: + opt_f = torch.compile(f, backend="eager", fullgraph=True) + x = torch.randn(3, 3, dtype=torch.float32) + y = torch.randn(3, 3, dtype=torch.float32) + out = f(x, y) + opt_out = opt_f(x, y) + self.assertEqual(out, opt_out) + self.assertEqual(out.dtype, opt_out.dtype) + self.assertFalse(torch.is_autocast_enabled("cpu")) + finally: + torch.set_autocast_enabled("cpu", prev_enabled) + torch.set_autocast_dtype("cpu", prev_dtype) + torch.set_autocast_cache_enabled(prev_cache) + + def test__enter__exit_autocast_function_mode(self): + class FunctionCount(torch.overrides.TorchFunctionMode): + def __init__(self): + self.counts = defaultdict(int) + + def __torch_function__(self, func, types, args, kwargs=None): + self.counts[func] += 1 + return func(*args, **(kwargs or {})) + + def f(x, y): + m = torch.amp.autocast_mode._enter_autocast("cpu") + x = x @ y + torch.amp.autocast_mode._exit_autocast(m) + return x + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + x = torch.randn(3, 3, dtype=torch.float32) + y = torch.randn(3, 3, dtype=torch.float32) + with FunctionCount() as fc: + z = f(x, y) + self.assertEqual(fc.counts[torch.amp.autocast_mode._enter_autocast], 1) + self.assertEqual(fc.counts[torch.amp.autocast_mode._exit_autocast], 1) + with FunctionCount() as fc: + opt_z = opt_f(x, y) + self.assertEqual(fc.counts[torch.amp.autocast_mode._enter_autocast], 1) + self.assertEqual(fc.counts[torch.amp.autocast_mode._exit_autocast], 1) + self.assertEqual(z, opt_z) + self.assertEqual(z.dtype, opt_z.dtype) + self.assertFalse(torch.is_autocast_enabled("cpu")) + + def test__enter__exit_autocast_non_idempotent(self): + # Recompile trick doesn't work with dynamic shapes + if check_dynamic_shape_capture(): + return + + def f(x, y): + with torch.amp.autocast("cpu"): + x = x @ y + return x + + eager = EagerAndRecordGraphs() + opt_f = torch.compile(f, backend=eager, fullgraph=False) + x = torch.randn(3, 3, dtype=torch.float32) + y = torch.randn(3, 3, dtype=torch.float32) + out = f(x, y) + opt_out = opt_f(x, y) + self.assertEqual(out, opt_out) + self.assertEqual(out.dtype, opt_out.dtype) + self.assertFalse(torch.is_autocast_enabled("cpu")) + graph = eager.graphs[0] + actual = normalize_gm(graph.print_readable(False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]"): + l_x_ = L_x_ + l_y_ = L_y_ + + _enter_autocast = torch.amp.autocast_mode._enter_autocast('cpu', None, True, None) + + x: "bf16[3, 3]" = l_x_ @ l_y_; l_x_ = l_y_ = None + + _exit_autocast = torch.amp.autocast_mode._exit_autocast(_enter_autocast); _enter_autocast = _exit_autocast = None + return (x,) +""", + ) + + # Recompiling will decompose the _enter_autocast and _exit_autocast calls to lower level autocast functions + eager = EagerAndRecordGraphs() + d = {} + exec(actual, globals(), d) + retraced = torch.compile(d["GraphModule"], backend=eager, fullgraph=True) + retraced_out = retraced()(x, y)[0] + self.assertEqual(out, retraced_out) + self.assertEqual(out.dtype, retraced_out.dtype) + self.assertFalse(torch.is_autocast_enabled("cpu")) + + graph = eager.graphs[0] + actual = normalize_gm(graph.print_readable(False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_L_x_: "f32[3, 3]", L_L_y_: "f32[3, 3]"): + l_l_x_ = L_L_x_ + l_l_y_ = L_L_y_ + + _is_autocast_available = torch._C._is_autocast_available('cpu'); _is_autocast_available = None + + set_autocast_enabled = torch.set_autocast_enabled('cpu', True); set_autocast_enabled = None + + set_autocast_dtype = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype = None + + autocast_increment_nesting = torch.autocast_increment_nesting(); autocast_increment_nesting = None + set_autocast_cache_enabled = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled = None + x: "bf16[3, 3]" = l_l_x_ @ l_l_y_; l_l_x_ = l_l_y_ = None + autocast_decrement_nesting = torch.autocast_decrement_nesting(); autocast_decrement_nesting = None + + clear_autocast_cache = torch.clear_autocast_cache(); clear_autocast_cache = None + + set_autocast_enabled_1 = torch.set_autocast_enabled('cpu', False); set_autocast_enabled_1 = None + + set_autocast_dtype_1 = torch.set_autocast_dtype('cpu', torch.bfloat16); set_autocast_dtype_1 = None + + set_autocast_cache_enabled_1 = torch.set_autocast_cache_enabled(True); set_autocast_cache_enabled_1 = None + return (x,) +""", + ) + @parametrize( "Ctx", [CustomizedCtxManagerWithGraphBreak, customized_ctx_manager_with_graph_break], @@ -1296,7 +1675,7 @@ def forward(self): _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (add,) -""", # NOQA: B950 +""", ) def test_disable_saved_tensors_hooks_prev_disabled(self): @@ -1339,7 +1718,7 @@ def forward(self): _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message'); _saved_tensors_hooks_disable_1 = None return (add,) -""", # NOQA: B950 +""", ) def test_disable_saved_tensors_hooks_prev_disabled_nested(self): @@ -1394,7 +1773,7 @@ def forward(self): _saved_tensors_hooks_disable_3 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message'); _saved_tensors_hooks_disable_3 = None return (add_1,) -""", # NOQA: B950 +""", ) def test_disable_saved_tensors_hooks_graph_break(self): @@ -1409,7 +1788,7 @@ def fn(x): eager = EagerAndRecordGraphs() torch.compile(fn, backend=eager, fullgraph=False)(torch.randn(())) - def check_graph(actual, expected): # noqa: F841 + def check_graph(actual, expected): self.assertExpectedInline(actual, expected) graph = eager.graphs[0] @@ -1427,7 +1806,7 @@ def forward(self, L_x_: "f32[]"): _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (y,) -""", # NOQA: B950 +""", ) graph = eager.graphs[1] @@ -1445,7 +1824,37 @@ def forward(self, L_y_: "f32[]"): _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None return (mul,) -""", # NOQA: B950 +""", + ) + + def test__saved_tensors_hooks_disable(self): + def fn(x): + y = x + 1 + torch._C._autograd._saved_tensors_hooks_disable("This is not supported") + y *= 2 + torch._C._autograd._saved_tensors_hooks_enable() + return y + + eager = EagerAndRecordGraphs() + torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(())) + graph = eager.graphs[0] + actual = normalize_gm(graph.print_readable(False)) + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[]"): + l_x_ = L_x_ + + y: "f32[]" = l_x_ + 1; l_x_ = None + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable = None + + y *= 2; y_1: "f32[]" = y; y = None + + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None + return (y_1,) +""", ) def test_context_wrapping_grad_mode_decorator(self): @@ -1902,9 +2311,7 @@ def fn(x): self.assertGreater(len(counters["graph_break"]), 0) -class ContextlibContextManagerTests( - torch._dynamo.test_case.TestCaseWithNestedGraphBreaks -): +class ContextlibContextManagerTests(torch._dynamo.test_case.TestCase): def setUp(self): super().setUp() self._prev = torch._dynamo.config.enable_trace_contextlib @@ -2863,6 +3270,90 @@ def fn(t): y = fn(t) self.assertEqual(y, t.sin()) + @parametrize("gb", (True, False)) + def test_functorch_low_level(self, gb): + def f(x, gb): + level = torch._C._functorch._grad_increment_nesting() + torch._C._functorch.set_inplace_requires_grad_allowed(True) + torch._functorch.eager_transforms._set_tensor_requires_grad(x) + if gb: + torch._dynamo.graph_break() + torch._C._functorch.set_inplace_requires_grad_allowed(False) + torch._C._functorch._grad_decrement_nesting() + return x + level + + prev_inplace = torch._C._functorch.get_inplace_requires_grad_allowed() + prev_level = torch._C._functorch.maybe_current_level() + opt_f = torch.compile(f, fullgraph=not gb, backend="eager") + x = torch.randn(3, 3, requires_grad=False) + opt_y = opt_f(x, gb) + self.assertTrue(x.requires_grad) + y = f(x, gb) + self.assertEqual(y, opt_y) + self.assertEqual(torch._C._functorch.maybe_current_level(), prev_level) + self.assertEqual( + torch._C._functorch.get_inplace_requires_grad_allowed(), prev_inplace + ) + + def test_retrace_grad(self): + # Recompile trick doesn't work with dynamic shapes + if check_dynamic_shape_capture(): + return + + def fn(x): + return x.sin().sum() + + def wrapper_fn(x): + return torch.func.grad(fn)(x) + + x = torch.randn(3, 3) + eager = EagerAndRecordGraphs() + opt_f = torch.compile(wrapper_fn, backend=eager, fullgraph=True) + y = wrapper_fn(x) + opt_y = opt_f(x) + self.assertEqual(y, opt_y) + first_graph = normalize_gm(eager.graphs[0].print_readable(False)) + self.assertExpectedInline( + first_graph, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 3]"): + l_x_ = L_x_ + + _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable = None + _grad_increment_nesting = torch._C._functorch._grad_increment_nesting(); _grad_increment_nesting = None + + diff_args: "f32[3, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 1); l_x_ = None + + set_inplace_requires_grad_allowed = torch._C._functorch.set_inplace_requires_grad_allowed(True); set_inplace_requires_grad_allowed = None + + _set_tensor_requires_grad: "f32[3, 3]" = torch._functorch.eager_transforms._set_tensor_requires_grad(diff_args); _set_tensor_requires_grad = None + + set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None + + sin: "f32[3, 3]" = diff_args.sin() + output: "f32[]" = sin.sum(); sin = None + + _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None + grad_input: "f32[3, 3]" = _autograd_grad[0]; _autograd_grad = None + + grad_input_1: "f32[3, 3]" = torch._functorch.predispatch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(output, 1); output = output_1 = None + + _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None + _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None + return (grad_input_1,) +""", + ) + + d = {} + exec(first_graph, globals(), d) + retraced = torch.compile(d["GraphModule"], backend=eager, fullgraph=True) + retraced_out = retraced()(x)[0] + self.assertEqual(y, retraced_out) + retraced_graph = normalize_gm(eager.graphs[0].print_readable(False)) + self.assertEqual(first_graph, retraced_graph) + instantiate_parametrized_tests(CtxManagerTests) instantiate_parametrized_tests(ContextlibContextManagerTests) diff --git a/test/dynamo/test_debug_utils.py b/test/dynamo/test_debug_utils.py index f560460a320d7..3dc06c350d7f7 100644 --- a/test/dynamo/test_debug_utils.py +++ b/test/dynamo/test_debug_utils.py @@ -43,7 +43,7 @@ def forward(self, x_1): full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) return (convert_element_type, _to_copy, full, empty) - """, # NOQA: B950 + """, ) _, fp64_examples = debug_utils.cast_to_fp64(fx, (x,)) @@ -58,7 +58,7 @@ def forward(self, x_1): full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False) empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False) return (convert_element_type, _to_copy, full, empty) - """, # NOQA: B950 + """, ) @patch.dict( @@ -190,7 +190,7 @@ def forward( def test_sym_aot_graph_parser(self, device): def forward( self, - primals_1: "f32[1001, 6]", # noqa: F821 + primals_1: "f32[1001, 6]", primals_2: "f32[s0]", # noqa: F821 primals_3: "Sym(s0)", # noqa: F821, primals_4: "f32[s1]", # noqa: F821, @@ -325,9 +325,12 @@ def test_multiple_rules(self, device): result = self._run_with_override(device, "0:aot_eager;1:inductor;3:eager") self.assertEqual(result, ["aot_eager", "inductor", "eager"]) - def test_first_rule_wins(self, device): - result = self._run_with_override(device, ">=0:aot_eager;>=1:inductor") - self.assertEqual(result, ["aot_eager", "aot_eager", "aot_eager", "aot_eager"]) + def test_conflicting_rules_raise(self, device): + with self.assertRaisesRegex( + torch._dynamo.exc.InternalTorchDynamoError, + "Conflicting backend override", + ): + self._run_with_override(device, ">=0:aot_eager;>=1:inductor") def test_complex_config(self, device): result = self._run_with_override(device, "0:aot_eager;>=2:inductor") @@ -468,6 +471,18 @@ def test_config_router_aggregation_multiple_rules(self, device): self.assertEqual(router.get_value_for_graph(1), {"b": 2, "c": 3}) self.assertEqual(router.get_value_for_graph(2), {"c": 3}) + def test_backend_router_conflict_raises(self, device): + from torch._dynamo.graph_id_filter import GraphBackendRouter + + with self.assertRaisesRegex(ValueError, "Conflicting backend override"): + GraphBackendRouter("0-5:eager;3-10:inductor") + + def test_backend_router_same_backend_no_conflict(self, device): + from torch._dynamo.graph_id_filter import GraphBackendRouter + + router = GraphBackendRouter("0:eager;>=0:eager") + self.assertIsNotNone(router.get_value_for_graph(0)) + def test_get_inductor_config_override_empty(self, device): from torch._dynamo.graph_id_filter import ( get_inductor_config_override_for_compile_id, diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 871582ebd0d71..7240a5fc8a108 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -import copy import functools import operator import os @@ -8,15 +7,11 @@ from unittest.mock import patch import torch -import torch._dynamo.config as config import torch._dynamo.testing -from torch._dynamo.decorators import leaf_function from torch._dynamo.exc import Unsupported -from torch._dynamo.testing import normalize_gm from torch._dynamo.utils import counters from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, - parametrize, skipIfWindows, ) from torch.testing._internal.dynamo_pytree_test_utils import PytreeRegisteringTestCase @@ -221,11 +216,6 @@ def fn(a): self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 5) - def test_allow_in_graph_deprecation_warning(self): - with self.assertWarnsRegex(FutureWarning, "nonstrict_trace"): - torch._dynamo.allow_in_graph(my_custom_function) - torch._dynamo.disallow_in_graph(my_custom_function) - def test_allow_in_graph_no_id_reuse(self): cnts = torch._dynamo.testing.CompileCounter() @@ -254,7 +244,7 @@ def fn(a): fn(torch.randn(10)) # Check for graph break - self.assertEqual(cnts.frame_count, 3) + self.assertEqual(cnts.frame_count, 2) def test_incorrect_usage_disallow_in_graph(self): with self.assertRaisesRegex(RuntimeError, "disallow_in_graph is expected"): @@ -661,15 +651,16 @@ def fn(x, y): def test_nonstrict_trace_no_action_at_a_distance(self): def trace_me(x): + x = x + 4 torch._dynamo.graph_break() - return x + 42 + return x + 8 # No effect on traceability of `trace_me` torch._dynamo.nonstrict_trace(trace_me) def fn(x): - res = trace_me(x) - return res + 1 + res = trace_me(x + 1) + return res + 2 x = torch.randn(10) cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") @@ -695,7 +686,7 @@ def trace_me(x, y): fn(torch.ones(10), torch.ones(1)) self.assertFalse(True) # must raise error before this except torch._dynamo.exc.Unsupported as e: - msg = "Applying `nonstrict_trace` to function ; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." # NOQA: B950 + msg = "Applying `nonstrict_trace` to function ; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." self.assertIn(msg, str(e)) def test_nonstrict_trace_custom_class_error(self): @@ -1328,8 +1319,7 @@ def forward(self, a, *args): def _test_mark_static_address(self, guarded): # This test verifies that dynamo properly marks inputs as static # when using the mark_static_address API. - # For both inline_inbuilt_nn_modules True and False, we expect the - # tensor to be present in the buffers attribute of the graph. + # We expect the tensor to be present in the buffers attribute of the graph. compiles_with_buffers = 0 compiles = 0 @@ -1366,15 +1356,9 @@ def fn(x): self.assertEqual(compiles, 2 if guarded else 1) def test_mark_static_address_guarded(self): - with torch._dynamo.config.patch("inline_inbuilt_nn_modules", True): - self._test_mark_static_address(guarded=True) - self._test_mark_static_address(guarded=True) def test_mark_static_address_unguarded(self): - with torch._dynamo.config.patch("inline_inbuilt_nn_modules", True): - self._test_mark_static_address(guarded=False) - self._test_mark_static_address(guarded=False) def test_class_methods(self): @@ -1530,7 +1514,6 @@ def fn(x, y, z): # Would have been 4 without stance self.assertEqual(cnts.op_count, 2) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_mark_static_nn_module(self): @torch._dynamo.mark_static class Mock(torch.nn.Module): @@ -1672,7 +1655,7 @@ def post_munge(s): Detected recompile when torch.compile stance is 'fail_on_recompile'. filename: 'test_decorators.py', function name: 'f', line number: N triggered by the following guard failure(s): - 0/0: tensor 'x' size mismatch at index 0. expected 4, actual 7 - - 0/1: tensor 'x' size mismatch at index 0. expected 5, actual 7""", # noqa: B950 + - 0/1: tensor 'x' size mismatch at index 0. expected 5, actual 7""", post_munge=post_munge, ) @@ -2021,7 +2004,7 @@ def f5(x): inp = torch.ones(3) self.assertEqual(f5(inp), inp + 7) - self.assertEqual(cnts.frame_count, 4) + self.assertEqual(cnts.frame_count, 2) def inner_f6(x): x = x + 2 @@ -2037,7 +2020,7 @@ def f6(x): cnts.clear() self.assertEqual(f6(inp), inp + 7) - self.assertEqual(cnts.frame_count, 3) + self.assertEqual(cnts.frame_count, 2) def inner_f7(x): x = x + 2 @@ -2077,7 +2060,7 @@ def f8(x): inp = torch.ones(3) self.assertEqual(f8(inp), inp + 7) - self.assertEqual(cnts.frame_count, 4) + self.assertEqual(cnts.frame_count, 3) def inner2_f9(x): x = x + 2 @@ -2140,7 +2123,7 @@ def f1(x): inp = torch.ones(3) self.assertEqual(f1(inp), inp + 7) - self.assertEqual(cnts.frame_count, 4) + self.assertEqual(cnts.frame_count, 2) def inner1_f2(x): x = x + 1 @@ -2260,6 +2243,7 @@ def fn(x): self.assertEqual(cnts.frame_count, 0) + @torch._dynamo.config.patch(nested_graph_breaks=False) def test_nested_compile_fullgraph(self): # Test that fullgraph=True cannot be toggled back by fullgraph=False inp = torch.ones(3) @@ -2455,1647 +2439,20 @@ def forward(self, x) -> torch.Tensor: with self.assertRaises(RuntimeError): exported_model = torch.export.export(model, (inp,)) - def _assert_models_equal( - self, - model_expected, - model_test, - x_expected, - x_test, - ): - out_expected = model_expected(x_expected) - out_test = model_test(x_test) - self.assertEqual(out_expected, out_test) - - loss_expected = out_expected.sum() - loss_test = out_test.sum() - loss_expected.backward() - loss_test.backward() - self.assertEqual(x_expected.grad, x_test.grad) - - expected_grads = { - name: param.grad for name, param in model_expected.named_parameters() - } - test_grads = {name: param.grad for name, param in model_test.named_parameters()} - - self.assertEqual(set(expected_grads.keys()), set(test_grads.keys())) - for name in expected_grads: - if expected_grads[name] is not None: - self.assertEqual( - expected_grads[name], - test_grads[name], - msg=f"Gradient mismatch for parameter {name}", - ) - - def _test_leaf_function_helper(self, mod_class, args_fn, loss_fn): - import torch.utils._pytree as pytree - from torch._dynamo.testing import AotEagerAndRecordGraphs, EagerAndRecordGraphs - - mod_eager = mod_class() - mod_compile_eager = mod_class() - mod_compile_eager.load_state_dict(dict(mod_eager.state_dict())) - mod_compile_aot = mod_class() - mod_compile_aot.load_state_dict(dict(mod_eager.state_dict())) - - eager_backend = EagerAndRecordGraphs() - compiled_eager = torch.compile( - mod_compile_eager, backend=eager_backend, fullgraph=True - ) - - backend = AotEagerAndRecordGraphs() - compiled_aot = torch.compile(mod_compile_aot, backend=backend, fullgraph=True) - - for _ in range(2): - mod_eager.zero_grad() - mod_compile_eager.zero_grad() - mod_compile_aot.zero_grad() - - args = args_fn() - args_clone = pytree.tree_map( - lambda x: x.clone().detach().requires_grad_(x.requires_grad), args - ) - args_clone2 = pytree.tree_map( - lambda x: x.clone().detach().requires_grad_(x.requires_grad), args - ) - - out_eager = mod_eager(*args) - loss_fn(out_eager).backward() - - out_compile_eager = compiled_eager(*args_clone) - loss_fn(out_compile_eager).backward() - - out_compile_aot = compiled_aot(*args_clone2) - loss_fn(out_compile_aot).backward() - - self.assertEqual(out_eager, out_compile_eager) - self.assertEqual(out_eager, out_compile_aot) - - for (name_eager, param_eager), (_, param_compile_eager), ( - _, - param_compile_aot, - ) in zip( - mod_eager.named_parameters(), - mod_compile_eager.named_parameters(), - mod_compile_aot.named_parameters(), - ): - self.assertEqual( - param_eager.grad, - param_compile_eager.grad, - msg=f"Gradient mismatch for {name_eager} between eager and compile_eager", - ) - self.assertEqual( - param_eager.grad, - param_compile_aot.grad, - msg=f"Gradient mismatch for {name_eager} between eager and compile_aot", - ) - - pytree.tree_map( - lambda x, compile_x: self.assertEqual(x.grad, compile_x.grad) - if isinstance(x, torch.Tensor) and x.requires_grad - else None, - args, - args_clone, - ) - pytree.tree_map( - lambda x, compile_x: self.assertEqual(x.grad, compile_x.grad) - if isinstance(x, torch.Tensor) and x.requires_grad - else None, - args, - args_clone2, - ) - - return ( - normalize_gm(eager_backend.graphs[0].print_readable(print_output=False)), - normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)), - normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)), - ) - - def test_leaf_function_simple(self): - @leaf_function - def non_tracable_forward(mod, x): - if x.sum() > 0: - return (mod.linear(x),) - else: - return (mod.linear(x) + x,) - - @non_tracable_forward.register_fake - def non_tracable_forward_fake(mod, x): - return (mod.linear(x),) - - class NonTracable(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - return non_tracable_forward(self, x) - - def args_fn(): - return (torch.randn(3, 3, requires_grad=True),) - - def loss_fn(out): - return out[0].sum() - - dynamo_graph_str, fw_graph_str, bw_graph_str = self._test_leaf_function_helper( - NonTracable, args_fn, loss_fn - ) - self.assertExpectedInline( - dynamo_graph_str, - """\ -class GraphModule(torch.nn.Module): - def forward(self, L_x_: "f32[3, 3]", L_self_modules_linear_parameters_weight_: "f32[3, 3]", L_self_modules_linear_parameters_bias_: "f32[3]"): - l_x_ = L_x_ - l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_ - l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_ - - real_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.real_fn - fake_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.fake_fn - input_spec : torch.utils._pytree.TreeSpec = self.input_spec - invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(real_fn, fake_fn, input_spec, '', 0, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_, l_x_); real_fn = fake_fn = input_spec = l_self_modules_linear_parameters_weight_ = l_self_modules_linear_parameters_bias_ = l_x_ = None - getitem: "f32[3, 3]" = invoke_leaf_function[0]; invoke_leaf_function = None - return (getitem,) -""", # noqa: B950 - ) - self.assertExpectedInline( - fw_graph_str, - """\ -class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f32[0]", primals_2: "f32[3, 3]", primals_3: "f32[3, 3]", primals_4: "f32[3]"): - _opaque_obj0 = self._opaque_obj0 - _opaque_obj1 = self._opaque_obj1 - _tree_spec_constant0 = self._tree_spec_constant0 - with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', 0, primals_3, primals_4, primals_2, requires_grad_indices = (1, 2, 3)); primals_1 = _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = primals_3 = primals_4 = primals_2 = None - - getitem: "f32[0]" = with_effects[0] - getitem_1: "f32[3, 3]" = with_effects[1]; with_effects = None - return (getitem, getitem_1) -""", # noqa: B950 - ) - self.assertExpectedInline( - bw_graph_str, - """\ -class GraphModule(torch.nn.Module): - def forward(self, tangents_1: "f32[3, 3]", tangents_token: "f32[0]"): - _opaque_obj2 = self._opaque_obj2 - _opaque_obj3 = self._opaque_obj3 - _tree_spec_constant1 = self._tree_spec_constant1 - with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.higher_order.invoke_leaf_function, _opaque_obj2, _opaque_obj3, _tree_spec_constant1, '', tangents_1, requires_grad_indices = ()); tangents_token = _opaque_obj2 = _opaque_obj3 = _tree_spec_constant1 = tangents_1 = None - getitem_2: "f32[0]" = with_effects_1[0] - getitem_4: "f32[3, 3]" = with_effects_1[2] - getitem_5: "f32[3]" = with_effects_1[3] - getitem_6: "f32[3, 3]" = with_effects_1[4]; with_effects_1 = None - return (getitem_6, getitem_4, getitem_5, getitem_2) -""", # noqa: B950 - ) - - def test_leaf_function_with_logging(self): - @leaf_function - def logging_forward(mod, x): - print("Processing input") - return (mod.linear(x),) - - @logging_forward.register_fake - def logging_forward_fake(mod, x): - return (mod.linear(x),) - - class LoggingModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - return logging_forward(self, x) - - def args_fn(): - return (torch.randn(3, 3, requires_grad=True),) - - def loss_fn(out): - return out[0].sum() - - with patch("builtins.print") as mock_print: - self._test_leaf_function_helper(LoggingModule, args_fn, loss_fn) - mock_print.assert_any_call("Processing input") - # Called 6 times: eager, compile_eager, and compile_aot, 2 iterations each - self.assertEqual(mock_print.call_count, 6) - - def test_leaf_function_dynamic_autograd_module_config(self): - from torch._dynamo.testing import CompileCounterWithBackend - - @leaf_function - def configurable_scale(mod, x): - # Branch based on module config, not input - if mod.use_double_scale: - return (mod.linear(x) * 2,) - else: - return (mod.linear(x) * 3,) - - @configurable_scale.register_fake - def configurable_scale_fake(mod, x): - return (mod.linear(x),) - - class ConfigurableModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - self.use_double_scale = True # Config attribute - - def forward(self, x): - return configurable_scale(self, x) - - mod_eager = ConfigurableModule() - mod_compiled = ConfigurableModule() - mod_compiled.load_state_dict(dict(mod_eager.state_dict())) - - counter = CompileCounterWithBackend("aot_eager") - compiled_fn = torch.compile(mod_compiled, backend=counter, fullgraph=True) - - x_value = torch.randn(3, 3) - - mod_eager.use_double_scale = True - mod_compiled.use_double_scale = True - - x1 = x_value.clone().requires_grad_(True) - x1_clone = x_value.clone().requires_grad_(True) - - out_eager_1 = mod_eager(x1) - out_eager_1[0].sum().backward() - - out_compiled_1 = compiled_fn(x1_clone) - out_compiled_1[0].sum().backward() - - self.assertEqual(out_eager_1, out_compiled_1) - self.assertEqual(x1.grad, x1_clone.grad) - - mod_eager.zero_grad() - mod_compiled.zero_grad() - - mod_eager.use_double_scale = False - mod_compiled.use_double_scale = False - - x2 = x_value.clone().requires_grad_(True) - x2_clone = x_value.clone().requires_grad_(True) - - out_eager_2 = mod_eager(x2) - out_eager_2[0].sum().backward() - - out_compiled_2 = compiled_fn(x2_clone) - out_compiled_2[0].sum().backward() - - self.assertEqual(out_eager_2, out_compiled_2) - self.assertEqual(x2.grad, x2_clone.grad) - - # Same inputs but different config -> different gradients - # This proves leaf_function builds autograd dynamically (not burned in at trace time) - self.assertNotEqual(x1.grad, x2.grad) - - # Verify only ONE compilation happened (no recompilation when changing config) - self.assertEqual(counter.frame_count, 1) - - def test_leaf_function_dynamic_autograd_closure(self): - from torch._dynamo.testing import CompileCounterWithBackend - - config = {"use_double_scale": True} - - @leaf_function - def configurable_scale(x, y): - # Branch based on closure variable, not input - if config["use_double_scale"]: - return (x @ y * 2,) - else: - return (x @ y * 3,) - - @configurable_scale.register_fake - def configurable_scale_fake(x, y): - return (x @ y,) - - def fn(x, y): - return configurable_scale(x, y) - - counter = CompileCounterWithBackend("aot_eager") - compiled_fn = torch.compile(fn, backend=counter, fullgraph=True) - - x_value = torch.randn(3, 3) - y_value = torch.randn(3, 3) - - config["use_double_scale"] = True - - x1 = x_value.clone().requires_grad_(True) - y1 = y_value.clone().requires_grad_(True) - x1_clone = x_value.clone().requires_grad_(True) - y1_clone = y_value.clone().requires_grad_(True) - - out_eager_1 = fn(x1, y1) - out_eager_1[0].sum().backward() - - out_compiled_1 = compiled_fn(x1_clone, y1_clone) - out_compiled_1[0].sum().backward() - - self.assertEqual(out_eager_1, out_compiled_1) - self.assertEqual(x1.grad, x1_clone.grad) - self.assertEqual(y1.grad, y1_clone.grad) - - config["use_double_scale"] = False - - x2 = x_value.clone().requires_grad_(True) - y2 = y_value.clone().requires_grad_(True) - x2_clone = x_value.clone().requires_grad_(True) - y2_clone = y_value.clone().requires_grad_(True) - - out_eager_2 = fn(x2, y2) - out_eager_2[0].sum().backward() - - out_compiled_2 = compiled_fn(x2_clone, y2_clone) - out_compiled_2[0].sum().backward() - - self.assertEqual(out_eager_2, out_compiled_2) - self.assertEqual(x2.grad, x2_clone.grad) - self.assertEqual(y2.grad, y2_clone.grad) - - # Same inputs but different closure -> different gradients - # This proves leaf_function builds autograd dynamically (not burned in at trace time) - self.assertNotEqual(x1.grad, x2.grad) - self.assertNotEqual(y1.grad, y2.grad) - - # Verify only ONE compilation happened (no recompilation when changing closure) - self.assertEqual(counter.frame_count, 1) - - def test_leaf_function_closure_constants_without_grad(self): - closure_scale = 2.0 - closure_tensor = torch.tensor([1.0, 2.0, 3.0]) - - @leaf_function - def closure_forward(mod, x): - out = mod.linear(x) * closure_scale * mod.scale - out = out + closure_tensor + mod.offset - return (out,) - - @closure_forward.register_fake - def closure_forward_fake(mod, x): - return (mod.linear(x) + mod.offset,) - - class ClosureModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - self.scale = 3.0 - self.offset = torch.nn.Parameter(torch.ones(3)) - - def forward(self, x): - return closure_forward(self, x) - - def args_fn(): - return (torch.randn(3, 3, requires_grad=True),) - - def loss_fn(out): - return out[0].sum() - - dynamo_graph_str, fw_graph_str, bw_graph_str = self._test_leaf_function_helper( - ClosureModule, args_fn, loss_fn - ) - self.assertExpectedInline( - dynamo_graph_str, - """\ -class GraphModule(torch.nn.Module): - def forward(self, L_x_: "f32[3, 3]", L_self_parameters_offset_: "f32[3]", L_self_modules_linear_parameters_weight_: "f32[3, 3]", L_self_modules_linear_parameters_bias_: "f32[3]"): - l_x_ = L_x_ - l_self_parameters_offset_ = L_self_parameters_offset_ - l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_ - l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_ - - real_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.real_fn - fake_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.fake_fn - input_spec : torch.utils._pytree.TreeSpec = self.input_spec - invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(real_fn, fake_fn, input_spec, '', 0, l_self_parameters_offset_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_, l_x_); real_fn = fake_fn = input_spec = l_self_parameters_offset_ = l_self_modules_linear_parameters_weight_ = l_self_modules_linear_parameters_bias_ = l_x_ = None - getitem: "f32[3, 3]" = invoke_leaf_function[0]; invoke_leaf_function = None - return (getitem,) -""", # noqa: B950 - ) - self.assertExpectedInline( - fw_graph_str, - """\ -class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f32[0]", primals_2: "f32[3, 3]", primals_3: "f32[3]", primals_4: "f32[3, 3]", primals_5: "f32[3]"): - _opaque_obj0 = self._opaque_obj0 - _opaque_obj1 = self._opaque_obj1 - _tree_spec_constant0 = self._tree_spec_constant0 - with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', 0, primals_3, primals_4, primals_5, primals_2, requires_grad_indices = (1, 2, 3, 4)); primals_1 = _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = primals_3 = primals_4 = primals_5 = primals_2 = None - - getitem: "f32[0]" = with_effects[0] - getitem_1: "f32[3, 3]" = with_effects[1]; with_effects = None - return (getitem, getitem_1) -""", # noqa: B950 - ) - self.assertExpectedInline( - bw_graph_str, - """\ -class GraphModule(torch.nn.Module): - def forward(self, tangents_1: "f32[3, 3]", tangents_token: "f32[0]"): - _opaque_obj2 = self._opaque_obj2 - _opaque_obj3 = self._opaque_obj3 - _tree_spec_constant1 = self._tree_spec_constant1 - with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.higher_order.invoke_leaf_function, _opaque_obj2, _opaque_obj3, _tree_spec_constant1, '', tangents_1, requires_grad_indices = ()); tangents_token = _opaque_obj2 = _opaque_obj3 = _tree_spec_constant1 = tangents_1 = None - getitem_2: "f32[0]" = with_effects_1[0] - getitem_4: "f32[3]" = with_effects_1[2] - getitem_5: "f32[3, 3]" = with_effects_1[3] - getitem_6: "f32[3]" = with_effects_1[4] - getitem_7: "f32[3, 3]" = with_effects_1[5]; with_effects_1 = None - return (getitem_7, getitem_4, getitem_5, getitem_6, getitem_2) -""", # noqa: B950 - ) - - def test_leaf_function_pytree_inputs(self): - @leaf_function - def pytree_forward(mod, inputs): - if inputs["x"].sum() > 0: - return (mod.linear(inputs["x"]), inputs["y"] + 1) - return (mod.linear(inputs["x"]) + inputs["y"], inputs["y"] - 1) - - @pytree_forward.register_fake - def pytree_forward_fake(mod, inputs): - return (mod.linear(inputs["x"]), inputs["y"]) - - class PytreeModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, inputs): - return pytree_forward(self, inputs) - - def args_fn(): - return ( - { - "x": torch.randn(3, 3, requires_grad=True), - "y": torch.randn(3, 3, requires_grad=True), - }, - ) - - def loss_fn(out): - return out[0].sum() + out[1].sum() - - self._test_leaf_function_helper(PytreeModule, args_fn, loss_fn) - - def test_leaf_function_nested_annotations(self): - @leaf_function - def inner_leaf_forward(mod, x): - y = mod.linear(x) - return (y + x,) - - @inner_leaf_forward.register_fake - def inner_leaf_forward_fake(mod, x): - return (mod.linear(x),) - - class InnerLeaf(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - return inner_leaf_forward(self, x) - - @leaf_function - def outer_leaf_forward(mod, x): - z = mod.linear(x) - return mod.inner(z + x) - - @outer_leaf_forward.register_fake - def outer_leaf_forward_fake(mod, x): - return mod.inner(mod.linear(x)) - - class OuterLeaf(torch.nn.Module): - def __init__(self): - super().__init__() - self.inner = InnerLeaf() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - return outer_leaf_forward(self, x) - - def args_fn(): - return (torch.randn(3, 3, requires_grad=True),) - - def loss_fn(out): - return out[0].sum() - - dynamo_graph_str, fw_graph_str, bw_graph_str = self._test_leaf_function_helper( - OuterLeaf, args_fn, loss_fn - ) - self.assertExpectedInline( - dynamo_graph_str, - """\ -class GraphModule(torch.nn.Module): - def forward(self, L_x_: "f32[3, 3]", L_self_modules_inner_modules_linear_parameters_weight_: "f32[3, 3]", L_self_modules_inner_modules_linear_parameters_bias_: "f32[3]", L_self_modules_linear_parameters_weight_: "f32[3, 3]", L_self_modules_linear_parameters_bias_: "f32[3]"): - l_x_ = L_x_ - l_self_modules_inner_modules_linear_parameters_weight_ = L_self_modules_inner_modules_linear_parameters_weight_ - l_self_modules_inner_modules_linear_parameters_bias_ = L_self_modules_inner_modules_linear_parameters_bias_ - l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_ - l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_ - - real_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.real_fn - fake_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.fake_fn - input_spec : torch.utils._pytree.TreeSpec = self.input_spec - invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(real_fn, fake_fn, input_spec, '', 0, l_self_modules_inner_modules_linear_parameters_weight_, l_self_modules_inner_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_, l_x_); real_fn = fake_fn = input_spec = l_self_modules_inner_modules_linear_parameters_weight_ = l_self_modules_inner_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = l_self_modules_linear_parameters_bias_ = l_x_ = None - getitem: "f32[3, 3]" = invoke_leaf_function[0]; invoke_leaf_function = None - return (getitem,) -""", # noqa: B950 - ) - self.assertExpectedInline( - fw_graph_str, - """\ -class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f32[0]", primals_2: "f32[3, 3]", primals_3: "f32[3, 3]", primals_4: "f32[3]", primals_5: "f32[3, 3]", primals_6: "f32[3]"): - _opaque_obj0 = self._opaque_obj0 - _opaque_obj1 = self._opaque_obj1 - _tree_spec_constant0 = self._tree_spec_constant0 - with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', 0, primals_3, primals_4, primals_5, primals_6, primals_2, requires_grad_indices = (1, 2, 3, 4, 5)); primals_1 = _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = primals_3 = primals_4 = primals_5 = primals_6 = primals_2 = None - - getitem: "f32[0]" = with_effects[0] - getitem_1: "f32[3, 3]" = with_effects[1]; with_effects = None - return (getitem, getitem_1) -""", # noqa: B950 - ) - self.assertExpectedInline( - bw_graph_str, - """\ -class GraphModule(torch.nn.Module): - def forward(self, tangents_1: "f32[3, 3]", tangents_token: "f32[0]"): - _opaque_obj2 = self._opaque_obj2 - _opaque_obj3 = self._opaque_obj3 - _tree_spec_constant1 = self._tree_spec_constant1 - with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.higher_order.invoke_leaf_function, _opaque_obj2, _opaque_obj3, _tree_spec_constant1, '', tangents_1, requires_grad_indices = ()); tangents_token = _opaque_obj2 = _opaque_obj3 = _tree_spec_constant1 = tangents_1 = None - getitem_2: "f32[0]" = with_effects_1[0] - getitem_4: "f32[3, 3]" = with_effects_1[2] - getitem_5: "f32[3]" = with_effects_1[3] - getitem_6: "f32[3, 3]" = with_effects_1[4] - getitem_7: "f32[3]" = with_effects_1[5] - getitem_8: "f32[3, 3]" = with_effects_1[6]; with_effects_1 = None - return (getitem_8, getitem_4, getitem_5, getitem_6, getitem_7, getitem_2) -""", # noqa: B950 - ) - - def test_leaf_function_data_dependent_nonzero(self): - @leaf_function - def nonzero_forward(mod, x): - out = mod.linear(x) - nonzero_indices = (out > 0).nonzero() - return (out, nonzero_indices) - - @nonzero_forward.register_fake - def nonzero_forward_fake(mod, x): - out = mod.linear(x) - return out, (out > 0).nonzero() - - class NonzeroModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - return nonzero_forward(self, x) - - class OuterModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.pre_linear = torch.nn.Linear(3, 3) - self.nonzero_module = NonzeroModule() - self.scale = torch.nn.Parameter(torch.tensor(2.0)) - - def forward(self, x): - x = self.pre_linear(x) - x = torch.relu(x) - out, nonzero_indices = self.nonzero_module(x) - num_nonzero = nonzero_indices.shape[0] - scaled_out = out * self.scale + num_nonzero - return scaled_out, nonzero_indices - - def args_fn(): - return (torch.randn(3, 3, requires_grad=True),) - - def loss_fn(out): - return out[0].sum() - - self._test_leaf_function_helper(OuterModule, args_fn, loss_fn) - - def test_leaf_function_data_dependent_item(self): - @leaf_function - def item_forward(mod, x): - out = mod.linear(x) - scalar_value = out.sum().item() - return (out, scalar_value) - - @item_forward.register_fake - def item_forward_fake(mod, x): - out = mod.linear(x) - return (out, out.sum().item()) - - class ItemModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - return item_forward(self, x) - - def args_fn(): - return (torch.randn(3, 3, requires_grad=True),) - - def loss_fn(out): - return out[0].sum() - - self._test_leaf_function_helper(ItemModule, args_fn, loss_fn) - - @parametrize("backend", ["eager", "aot_eager"]) - def test_leaf_function_multiple_compiled_submodules(self, backend): - @leaf_function - def leaf_forward(mod, x): - if x.sum() > 0: - return (mod.linear(x),) - else: - return (mod.linear(x) + x,) - - @leaf_forward.register_fake - def leaf_forward_fake(mod, x): - return (mod.linear(x),) - - class LeafModule(torch.nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features) - - def forward(self, x): - return leaf_forward(self, x) - - class CompiledSubmodule1(torch.nn.Module): - def __init__(self): - super().__init__() - self.pre_linear = torch.nn.Linear(4, 4) - self.leaf = LeafModule(4, 4) - - def forward(self, x): - x = self.pre_linear(x) - x = torch.relu(x) - out = self.leaf(x)[0] - return out - - class CompiledSubmodule2(torch.nn.Module): - def __init__(self): - super().__init__() - self.leaf = LeafModule(4, 4) - self.post_linear = torch.nn.Linear(4, 4) - - def forward(self, x): - out = self.leaf(x)[0] - out = self.post_linear(out) - return torch.sigmoid(out) - - class CompiledSubmodule3(torch.nn.Module): - def __init__(self): - super().__init__() - self.leaf1 = LeafModule(4, 4) - self.leaf2 = LeafModule(4, 4) - - def forward(self, x): - out1 = self.leaf1(x)[0] - out2 = self.leaf2(x)[0] - return out1 + out2 - - class TopLevelModule(torch.nn.Module): - def __init__(self, compile_submodules=False): - super().__init__() - self.submodule1 = CompiledSubmodule1() - self.submodule2 = CompiledSubmodule2() - self.submodule3 = CompiledSubmodule3() - self.final_linear = torch.nn.Linear(4, 4) - self.compile_submodules = compile_submodules - - def forward(self, x): - if self.compile_submodules: - out1 = torch.compile(self.submodule1, backend=backend)(x) - out2 = torch.compile(self.submodule2, backend=backend)(out1) - out3 = torch.compile(self.submodule3, backend=backend)(out2) - else: - out1 = self.submodule1(x) - out2 = self.submodule2(out1) - out3 = self.submodule3(out2) - final = self.final_linear(out3) - return final - - model_eager = TopLevelModule(compile_submodules=False) - model_compiled = TopLevelModule(compile_submodules=True) - model_compiled.load_state_dict(model_eager.state_dict()) - - x = torch.randn(2, 4, requires_grad=True) - x_compiled = x.clone().detach().requires_grad_(True) - - self._assert_models_equal( - model_eager, - model_compiled, - x, - x_compiled, - ) - - @parametrize("backend", ["eager", "aot_eager"]) - @parametrize("do_compile", [False, True]) - def test_leaf_function_with_graph_breaks(self, backend, do_compile): - @leaf_function - def leaf_forward(mod, x): - if x.sum() > 0: - return (mod.linear(x),) - else: - return (mod.linear(x) + 1,) - - @leaf_forward.register_fake - def leaf_forward_fake(mod, x): - return (mod.linear(x),) - - class LeafModule(torch.nn.Module): - def __init__(self, in_features, out_features): - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features) - - def forward(self, x): - return leaf_forward(self, x) - - class TopLevelModule(torch.nn.Module): - def __init__(self, do_compile=False, backend="eager"): - super().__init__() - self.leaf1 = LeafModule(4, 4) - self.leaf2 = LeafModule(4, 4) - self.leaf3 = LeafModule(4, 4) - self.final_linear = torch.nn.Linear(4, 4) - self.do_compile = do_compile - self.backend = backend - - def _forward(self, x): - out1 = self.leaf1(x)[0] - torch._dynamo.graph_break() - out2 = self.leaf2(out1)[0] - torch._dynamo.graph_break() - out3 = self.leaf3(out2)[0] - result = self.final_linear(out3) - return result - - def forward(self, x): - if self.do_compile: - return torch.compile( - self._forward, backend=self.backend, fullgraph=False - )(x) - else: - return self._forward(x) - - model_eager = TopLevelModule(do_compile=False) - model_test = TopLevelModule(do_compile=do_compile, backend=backend) - model_test.load_state_dict(model_eager.state_dict()) - - x = torch.randn(2, 4, requires_grad=True) - x_test = x.clone().detach().requires_grad_(True) - - self._assert_models_equal(model_eager, model_test, x, x_test) - - def test_leaf_function_with_module_in_pytree(self): - @leaf_function - def main_forward(modules_dict, x): - if x.sum() > 0: - return (modules_dict["first"](x) + modules_dict["second"](x),) - else: - return (modules_dict["first"](x) - modules_dict["second"](x),) - - @main_forward.register_fake - def main_forward_fake(modules_dict, x): - return (modules_dict["first"](x) + modules_dict["second"](x),) - - class HelperModule(torch.nn.Module): - def __init__(self, scale=1.0): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - self.scale = scale - - def forward(self, x): - return self.linear(x) * self.scale - - class WrapperModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.helper1 = HelperModule(scale=1.0) - self.helper2 = HelperModule(scale=0.5) - - def forward(self, x): - modules_dict = {"first": self.helper1, "second": self.helper2} - return main_forward(modules_dict, x) - - def args_fn(): - return (torch.randn(3, 3, requires_grad=True),) - - def loss_fn(out): - return out[0].sum() - - self._test_leaf_function_helper(WrapperModule, args_fn, loss_fn) - - def test_leaf_function_with_module_as_kwarg(self): - @leaf_function - def main_forward(x, helper_mod=None): - if x.sum() > 0: - return (helper_mod(x),) - else: - return (helper_mod(x) + x,) - - @main_forward.register_fake - def main_forward_fake(x, helper_mod=None): - return (helper_mod(x),) - - class HelperModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - return self.linear(x) - - class WrapperModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.helper = HelperModule() - - def forward(self, x): - return main_forward(x, helper_mod=self.helper) - - def args_fn(): - return (torch.randn(3, 3, requires_grad=True),) - - def loss_fn(out): - return out[0].sum() - - self._test_leaf_function_helper(WrapperModule, args_fn, loss_fn) - - def test_leaf_function_missing_fake_impl_error(self): - @leaf_function - def no_fake_impl_forward(mod, x): - return (mod.linear(x),) - - class SimpleModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - return no_fake_impl_forward(self, x) - - mod = SimpleModule() - x = torch.randn(3, 3) - - with self.assertRaisesRegex(Exception, "requires a fake implementation"): - mod(x) - - compiled_mod = torch.compile(mod, backend="eager", fullgraph=True) - with self.assertRaisesRegex(Exception, "requires a fake implementation"): - compiled_mod(x) - - @parametrize("backend", ["eager", "aot_eager"]) - def test_leaf_function_constant_tensor_closure_error(self, backend): - constant_weight = torch.randn(3, 3) - - @leaf_function - def constant_closure_forward(x): - return (x @ constant_weight,) - - @constant_closure_forward.register_fake - def constant_closure_forward_fake(x): - return (x @ constant_weight,) - - class ConstantClosureModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return constant_closure_forward(x) - - mod = ConstantClosureModule() - x = torch.randn(3, 3, requires_grad=True) - - result = mod(x) - expected = x @ constant_weight - self.assertEqual(result[0], expected) - - compiled_mod = torch.compile(mod, backend=backend, fullgraph=True) - with self.assertRaisesRegex( - Exception, "Please convert all Tensors to FakeTensors" - ): - compiled_mod(x) - - @parametrize("backend", ["eager", "aot_eager"]) - def test_leaf_function_input_mutation_error(self, backend): - @leaf_function - def mutate_input(x): - x.add_(1) - return (x,) - - @mutate_input.register_fake - def mutate_input_fake(x): - x.add_(1) - return (x,) - - def fn(x): - return mutate_input(x) - - x = torch.randn(3, 3) - - x_eager = x.clone() - with self.assertRaisesRegex(RuntimeError, "Undeclared in-place mutation"): - fn(x_eager) - - x = torch.randn(3, 3, requires_grad=True) - - compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) - with self.assertRaisesRegex(RuntimeError, "leaf Variable that requires grad"): - compiled_fn(x.clone().requires_grad_(True)) - - @parametrize("backend", ["eager", "aot_eager"]) - def test_leaf_function_validation_dtype_mismatch(self, backend): - @leaf_function - def dtype_mismatch_forward(mod, x): - return (mod.linear(x),) - - @dtype_mismatch_forward.register_fake - def dtype_mismatch_forward_fake(mod, x): - return (mod.linear(x).double(),) - - class DtypeMismatchModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - return dtype_mismatch_forward(self, x) - - mod = DtypeMismatchModule() - x = torch.randn(3, 3) - - with config.patch(leaf_function_validate_outputs=True): - compiled_mod = torch.compile(mod, backend=backend) - with self.assertRaisesRegex(RuntimeError, "Dtype mismatch"): - compiled_mod(x) - - @parametrize("backend", ["eager", "aot_eager"]) - @parametrize("validate_outputs", [True, False]) - def test_leaf_function_validation_shape_mismatch(self, backend, validate_outputs): - @leaf_function - def mismatched_forward(mod, x): - return (mod.linear(x),) - - @mismatched_forward.register_fake - def mismatched_forward_fake(mod, x): - return (torch.zeros(x.shape[0], 6),) - - class MismatchedModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - return mismatched_forward(self, x) - - mod = MismatchedModule() - x = torch.randn(3, 3) - - with config.patch(leaf_function_validate_outputs=validate_outputs): - compiled_mod = torch.compile(mod, backend=backend) - if validate_outputs: - with self.assertRaises((RuntimeError, AssertionError)): - compiled_mod(x) - else: - result = compiled_mod(x) - self.assertEqual(result[0].shape, (3, 3)) - - def test_leaf_function_no_module_inputs(self): - @leaf_function - def my_custom_fn(inputs: dict[str, torch.Tensor], scale: float, offset: int): - x = inputs["x"] - y = inputs["y"] - if x.sum() > 0: - return (x * scale + y + offset, x.sum() + y.sum()) - return (x * scale - y + offset, x.sum() - y.sum()) - - @my_custom_fn.register_fake - def my_custom_fn_fake( - inputs: dict[str, torch.Tensor], scale: float, offset: int - ): - x = inputs["x"] - y = inputs["y"] - return (x * scale + y + offset, x.sum() + y.sum()) - - class NoModuleInputsModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.scale = 2.0 - self.offset = 1 - - def forward(self, x, y): - inputs = {"x": x, "y": y} - return my_custom_fn(inputs, self.scale, self.offset) - - def args_fn(): - return ( - torch.randn(3, 3, requires_grad=True), - torch.randn(3, 3, requires_grad=True), - ) - - def loss_fn(out): - return out[0].sum() + out[1].sum() - - self._test_leaf_function_helper(NoModuleInputsModule, args_fn, loss_fn) - - @parametrize("backend", ["eager", "aot_eager"]) - @parametrize("check_escaped_gradients", [True, False]) - def test_leaf_function_escaped_gradient_multiple_tensors( - self, backend, check_escaped_gradients - ): - weight1 = torch.randn(3, 3, requires_grad=True) - weight2 = torch.randn(3, 3, requires_grad=True) - - @leaf_function - def uses_multiple_closures(x): - return (x @ weight1 + x @ weight2,) - - @uses_multiple_closures.register_fake - def uses_multiple_closures_fake(x): - return (torch.empty(x.shape[0], 3),) - - def fn(x): - return uses_multiple_closures(x) - - x = torch.randn(2, 3, requires_grad=True) - - compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) - with config.patch( - leaf_function_check_escaped_gradients=check_escaped_gradients - ): - if check_escaped_gradients: - with self.assertRaisesRegex(RuntimeError, "2 tensor"): - compiled_fn(x) - else: - result = compiled_fn(x) - self.assertEqual(result[0].shape, (2, 3)) - - @parametrize("backend", ["eager", "aot_eager"]) - @parametrize("check_escaped_gradients", [True, False]) - def test_leaf_function_escaped_gradient_input_no_grad( - self, backend, check_escaped_gradients - ): - closure_weight = torch.randn(3, 3, requires_grad=True) - - @leaf_function - def uses_closure(x): - return (x @ closure_weight,) - - @uses_closure.register_fake - def uses_closure_fake(x): - return (torch.empty(x.shape[0], 3),) - - def fn(x): - return uses_closure(x) - - x = torch.randn(2, 3, requires_grad=False) - - compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) - with config.patch( - leaf_function_check_escaped_gradients=check_escaped_gradients - ): - result = compiled_fn(x) - self.assertEqual(result[0].shape, (2, 3)) - - @parametrize("backend", ["eager", "aot_eager"]) - @parametrize("check_escaped_gradients", [True, False]) - def test_leaf_function_escaped_gradient_mixed_inputs( - self, backend, check_escaped_gradients - ): - base1 = torch.randn(3, 3, requires_grad=True) - base2 = torch.randn(3, 4, requires_grad=True) - closure_weight1 = base1 * 2 - closure_weight2 = base2 * 3 - - @leaf_function - def mixed_inputs(x, y): - out1 = x @ closure_weight1 + y - out2 = x @ closure_weight2 - return (out1, out2) - - @mixed_inputs.register_fake - def mixed_inputs_fake(x, y): - return (torch.empty(x.shape[0], 3), torch.empty(x.shape[0], 4)) - - def fn(x, y): - return mixed_inputs(x, y) - - x = torch.randn(2, 3, requires_grad=True) - y = torch.randn(2, 3, requires_grad=False) - - compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) - with config.patch( - leaf_function_check_escaped_gradients=check_escaped_gradients - ): - if check_escaped_gradients: - with self.assertRaisesRegex(RuntimeError, "2 tensor"): - compiled_fn(x, y) - else: - result = compiled_fn(x, y) - self.assertEqual(result[0].shape, (2, 3)) - self.assertEqual(result[1].shape, (2, 4)) - - @parametrize("backend", ["eager", "aot_eager"]) - def test_leaf_function_escaped_gradient_error_message_contains_tensor_info( - self, backend - ): - closure_weight = torch.randn(4, 5, dtype=torch.float32, requires_grad=True) - - @leaf_function - def uses_closure(x): - return (x @ closure_weight,) - - @uses_closure.register_fake - def uses_closure_fake(x): - return (torch.empty(x.shape[0], 5),) - - def fn(x): - return uses_closure(x) - - x = torch.randn(2, 4, requires_grad=True) - - compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) - with config.patch(leaf_function_check_escaped_gradients=True): - with self.assertRaisesRegex(RuntimeError, r"shape=\[4, 5\].*dtype="): - compiled_fn(x) - - @parametrize("backend", ["eager", "aot_eager"]) - def test_leaf_function_escaped_gradient_actually_lost(self, backend): - closure_weight = torch.randn(3, 3, requires_grad=True) - - @leaf_function - def uses_closure(x): - return (x @ closure_weight,) - - @uses_closure.register_fake - def uses_closure_fake(x): - return (torch.empty(x.shape[0], 3),) - - def fn(x): - return uses_closure(x) - - x = torch.randn(2, 3, requires_grad=True) - - compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) - result = compiled_fn(x) - loss = result[0].sum() - loss.backward() - - self.assertIsNotNone(x.grad) - self.assertIsNone(closure_weight.grad) - - def test_leaf_function_and_nonstrict_trace_mutually_exclusive(self): - from torch._dynamo.decorators import leaf_function, nonstrict_trace - - with self.assertRaisesRegex( - ValueError, - "cannot be both marked as @leaf_function and @nonstrict_trace", - ): - - @leaf_function - @nonstrict_trace - def bad_fn1(x): - return (x,) - - with self.assertRaisesRegex( - ValueError, - "cannot be both marked as @leaf_function and @nonstrict_trace", - ): - - @nonstrict_trace - @leaf_function - def bad_fn2(x): - return (x,) - - def test_leaf_function_no_return_value(self): - printed = [] - - @leaf_function - def fn_no_return(x): - print("processing") - - @fn_no_return.register_fake - def fn_no_return_fake(x): - pass - - class Mod(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - fn_no_return(x) - return (self.linear(x),) - - def args_fn(): - return (torch.randn(3, 3, requires_grad=True),) - - def loss_fn(out): - return out[0].sum() - - with patch("builtins.print", lambda *args, **kwargs: printed.append(args)): - eager_graph, fw_graph, bw_graph = self._test_leaf_function_helper( - Mod, args_fn, loss_fn - ) - self.assertTrue(any("processing" in p for p in printed)) - self.assertExpectedInline( - eager_graph, - """\ -class GraphModule(torch.nn.Module): - def forward(self, L_x_: "f32[3, 3]", L_self_modules_linear_parameters_weight_: "f32[3, 3]", L_self_modules_linear_parameters_bias_: "f32[3]"): - l_x_ = L_x_ - l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_ - l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_ - - real_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.real_fn - fake_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.fake_fn - input_spec : torch.utils._pytree.TreeSpec = self.input_spec - invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(real_fn, fake_fn, input_spec, '', l_x_); real_fn = fake_fn = input_spec = invoke_leaf_function = None - - linear: "f32[3, 3]" = torch._C._nn.linear(l_x_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_); l_x_ = l_self_modules_linear_parameters_weight_ = l_self_modules_linear_parameters_bias_ = None - return (linear,) -""", # noqa: B950 - ) - self.assertExpectedInline( - fw_graph, - """\ -class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f32[0]", primals_2: "f32[3, 3]", primals_3: "f32[3, 3]", primals_4: "f32[3]"): - _opaque_obj0 = self._opaque_obj0 - _opaque_obj1 = self._opaque_obj1 - _tree_spec_constant0 = self._tree_spec_constant0 - with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', primals_2, requires_grad_indices = (0,)); primals_1 = _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = None - - getitem: "f32[0]" = with_effects[0]; with_effects = None - - t: "f32[3, 3]" = torch.ops.aten.t.default(primals_3) - addmm: "f32[3, 3]" = torch.ops.aten.addmm.default(primals_4, primals_2, t); primals_4 = t = None - return (getitem, addmm, primals_2, primals_3) -""", # noqa: B950 - ) - self.assertExpectedInline( - bw_graph, - """\ -class GraphModule(torch.nn.Module): - def forward(self, primals_2: "f32[3, 3]", primals_3: "f32[3, 3]", tangents_1: "f32[3, 3]"): - t: "f32[3, 3]" = torch.ops.aten.t.default(primals_3); primals_3 = None - t_1: "f32[3, 3]" = torch.ops.aten.t.default(t); t = None - mm: "f32[3, 3]" = torch.ops.aten.mm.default(tangents_1, t_1); t_1 = None - t_2: "f32[3, 3]" = torch.ops.aten.t.default(tangents_1) - mm_1: "f32[3, 3]" = torch.ops.aten.mm.default(t_2, primals_2); t_2 = primals_2 = None - t_3: "f32[3, 3]" = torch.ops.aten.t.default(mm_1); mm_1 = None - sum_1: "f32[1, 3]" = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True); tangents_1 = None - view: "f32[3]" = torch.ops.aten.view.default(sum_1, [3]); sum_1 = None - t_4: "f32[3, 3]" = torch.ops.aten.t.default(t_3); t_3 = None - return (mm, t_4, view) -""", # noqa: B950 - ) - - def test_leaf_function_output_structure_mismatch(self): - @leaf_function - def mismatched_fn(x): - return {"a": x, "b": x * 2} - - @mismatched_fn.register_fake - def mismatched_fn_fake(x): - return (x, x * 2) - - def fn(x): - return mismatched_fn(x) - - x = torch.randn(3, 3) - with self.assertRaisesRegex(AssertionError, "output structure mismatch"): - torch.compile(fn, backend="eager")(x) - - def test_leaf_function_nested_output(self): - @leaf_function - def nested_output_fn(linear1, linear2, linear3, x): - if x.sum() > 0: - return { - "out": (linear1(x), linear2(x)), - "extra": linear3(x), - "count": 42, - } - else: - return { - "out": (linear1(x) + 1, linear2(x) + 1), - "extra": linear3(x) + 1, - "count": 42, - } - - @nested_output_fn.register_fake - def nested_output_fn_fake(linear1, linear2, linear3, x): - return { - "out": (linear1(x), linear2(x)), - "extra": linear3(x), - "count": 42, - } - - class NestedOutputModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(3, 3) - self.linear2 = torch.nn.Linear(3, 3) - self.linear3 = torch.nn.Linear(3, 3) - - def forward(self, x): - result = nested_output_fn(self.linear1, self.linear2, self.linear3, x) - return ( - result["out"][0] * result["count"] - + result["out"][1] - + result["extra"] - ) - - def args_fn(): - return (torch.randn(3, 3, requires_grad=True),) - - def loss_fn(out): - return out.sum() - - self._test_leaf_function_helper(NestedOutputModule, args_fn, loss_fn) - - def test_leaf_function_custom_pytree_output(self): - class Point: - x: torch.Tensor - y: torch.Tensor - - def __init__(self, x, y): - self.x = x - self.y = y - - self.register_pytree_node( - Point, - lambda p: ((p.x, p.y), ()), - lambda xy, _: Point(xy[0], xy[1]), - serialized_type_name=f"{Point.__module__}.{Point.__qualname__}", - ) - - @leaf_function - def point_fn(linear1, linear2, x): - return (Point(linear1(x), linear2(x)), 0.5) - - @point_fn.register_fake - def point_fn_fake(linear1, linear2, x): - return (Point(linear1(x), linear2(x)), 0.5) - - class PointModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(3, 3) - self.linear2 = torch.nn.Linear(3, 3) - - def forward(self, x): - p, scale = point_fn(self.linear1, self.linear2, x) - return (p.x * scale, p.y * scale) - - def args_fn(): - return (torch.randn(3, 3, requires_grad=True),) - - def loss_fn(out): - return out[0].sum() + out[1].sum() - - self._test_leaf_function_helper(PointModule, args_fn, loss_fn) - - def test_leaf_function_fake_requires_grad_ignored(self): - @leaf_function - def my_fn(x): - return (x * 2,) - - @my_fn.register_fake - def my_fn_fake(x): - return (torch.empty_like(x).requires_grad_(False),) - - from torch._dynamo.testing import EagerAndRecordGraphs - - backend = EagerAndRecordGraphs() - - @torch.compile(backend=backend, fullgraph=True) - def fn(x): - return my_fn(x) - - x = torch.randn(3, 3, requires_grad=True) - out = fn(x) - - self.assertTrue(out[0].requires_grad) - out[0].sum().backward() - self.assertIsNotNone(x.grad) - - graph = backend.graphs[0] - for node in graph.graph.nodes: - if node.op == "call_function" and "invoke_leaf_function" in str( - node.target - ): - example_value = node.meta.get("example_value") - self.assertIsNotNone(example_value) - self.assertTrue(example_value[0].requires_grad) - - @parametrize("backend", ["eager", "aot_eager"]) - def test_leaf_function_input_mutation_non_grad(self, backend): - @leaf_function(mutates_args={"buf"}) - def mutate_buffer(x, buf): - buf.add_(1) - return (x + buf,) - - @mutate_buffer.register_fake - def mutate_buffer_fake(x, buf): - buf.add_(1) - return (x + buf,) - - def fn(x, buf): - return mutate_buffer(x, buf) - - x = torch.randn(3, 3) - buf = torch.randn(3, 3) - - buf_eager = buf.clone() - result_eager = fn(x, buf_eager) - expected = x + buf + 1 - self.assertEqual(result_eager[0], expected) - self.assertEqual(buf_eager, buf + 1) - - compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) - buf_compiled = buf.clone() - result_compiled = compiled_fn(x, buf_compiled) - self.assertEqual(result_compiled[0], expected) - self.assertEqual(buf_compiled, buf + 1) - - @parametrize("backend", ["eager", "aot_eager"]) - def test_leaf_function_input_mutation_mixed(self, backend): - @leaf_function(mutates_args={"buf"}) - def mixed_fn(x, buf): - buf.mul_(2) - return (x * buf,) - - @mixed_fn.register_fake - def mixed_fn_fake(x, buf): - buf.mul_(2) - return (x * buf,) - - def fn(x, buf): - return mixed_fn(x, buf) - - x = torch.randn(3, 3, requires_grad=True) - buf = torch.randn(3, 3) - - buf_eager = buf.clone() - result_eager = fn(x, buf_eager) - expected = x * (buf * 2) - self.assertEqual(result_eager[0], expected) - self.assertEqual(buf_eager, buf * 2) - - compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) - buf_compiled = buf.clone() - result_compiled = compiled_fn(x, buf_compiled) - self.assertEqual(result_compiled[0], expected) - self.assertEqual(buf_compiled, buf * 2) - - @parametrize("backend", ["eager", "aot_eager"]) - def test_leaf_function_input_mutation_module_buffer(self, backend): - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("running_mean", torch.zeros(3)) - self.linear = torch.nn.Linear(3, 3) - - def forward(self, x): - return update_stats(self, x) - - @leaf_function(mutates_args={"model.running_mean"}) - def update_stats(model, x): - model.running_mean.add_(x.mean(dim=0)) - return (model.linear(x),) - - @update_stats.register_fake - def update_stats_fake(model, x): - model.running_mean.add_(x.mean(dim=0)) - return (model.linear(x),) - - mod = MyModule() - x = torch.randn(4, 3) - - mod_eager = copy.deepcopy(mod) - result_eager = mod_eager(x) - expected_mean = torch.zeros(3) + x.mean(dim=0) - self.assertEqual(mod_eager.running_mean, expected_mean) - - mod_compiled = copy.deepcopy(mod) - compiled_mod = torch.compile(mod_compiled, backend=backend, fullgraph=True) - result_compiled = compiled_mod(x) - self.assertEqual(result_compiled, result_eager) - self.assertEqual(mod_compiled.running_mean, expected_mean) - - @parametrize("backend", ["eager", "aot_eager"]) - def test_leaf_function_input_mutation_pytree(self, backend): - @leaf_function(mutates_args={"buffers"}) - def update_buffers(x, buffers): - for buf in buffers: - buf.add_(1) - return (x + sum(buffers),) - - @update_buffers.register_fake - def update_buffers_fake(x, buffers): - for buf in buffers: - buf.add_(1) - return (x + sum(buffers),) - - def fn(x, buffers): - return update_buffers(x, buffers) - - x = torch.randn(3, 3) - bufs = [torch.randn(3, 3), torch.randn(3, 3)] - - bufs_eager = [b.clone() for b in bufs] - result_eager = fn(x, bufs_eager) - expected = x + (bufs[0] + 1) + (bufs[1] + 1) - self.assertEqual(result_eager[0], expected) - self.assertEqual(bufs_eager[0], bufs[0] + 1) - self.assertEqual(bufs_eager[1], bufs[1] + 1) - - compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) - bufs_compiled = [b.clone() for b in bufs] - result_compiled = compiled_fn(x, bufs_compiled) - self.assertEqual(result_compiled[0], expected) - self.assertEqual(bufs_compiled[0], bufs[0] + 1) - self.assertEqual(bufs_compiled[1], bufs[1] + 1) - - @parametrize("backend", ["eager", "aot_eager"]) - def test_leaf_function_input_mutation_pytree_fine_grained(self, backend): - @leaf_function(mutates_args={"buffers[0]"}) - def update_first(x, buffers): - buffers[0].add_(1) - return (x + buffers[0] + buffers[1],) - - @update_first.register_fake - def update_first_fake(x, buffers): - buffers[0].add_(1) - return (x + buffers[0] + buffers[1],) - - def fn(x, buffers): - return update_first(x, buffers) - - x = torch.randn(3, 3) - bufs = [torch.randn(3, 3), torch.randn(3, 3)] - - bufs_eager = [b.clone() for b in bufs] - result_eager = fn(x, bufs_eager) - expected = x + (bufs[0] + 1) + bufs[1] - self.assertEqual(result_eager[0], expected) - self.assertEqual(bufs_eager[0], bufs[0] + 1) - self.assertEqual(bufs_eager[1], bufs[1]) - - compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) - bufs_compiled = [b.clone() for b in bufs] - result_compiled = compiled_fn(x, bufs_compiled) - self.assertEqual(result_compiled[0], expected) - self.assertEqual(bufs_compiled[0], bufs[0] + 1) - self.assertEqual(bufs_compiled[1], bufs[1]) - - def test_leaf_function_mutates_args_invalid_parameter(self): - with self.assertRaisesRegex(ValueError, "refers to parameter 'buf'"): - - @leaf_function(mutates_args={"buf"}) - def bad_fn(x, buffers): - buffers.add_(1) - return (x + buffers,) - - with self.assertRaisesRegex(ValueError, "refers to parameter 'mdl'"): - - @leaf_function(mutates_args={"mdl.running_mean"}) - def bad_fn2(x, model): - model.running_mean.add_(1) - return (x,) - - def test_leaf_function_mutates_args_non_leaf_expression(self): - @leaf_function(mutates_args={"model"}) - def bad_fn(x, model): - model.running_mean.add_(1) - return (x,) - - @bad_fn.register_fake - def bad_fn_fake(x, model): - model.running_mean.add_(1) - return (x,) - - class MyModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer("running_mean", torch.zeros(3)) - - def forward(self, x): - return bad_fn(x, self) + def test_allow_in_graph_inside_compile_gives_clear_error(self): + # Regression test for https://github.com/pytorch/pytorch/issues/178511 + # Calling allow_in_graph inside a compiled region is not supported. + # Verify the error message guides users to annotate before compilation. + def forward(x): + wrapped_fn = torch.compiler.allow_in_graph(my_custom_function) + return wrapped_fn(x) - mod = MyModule() - x = torch.randn(3) - compiled_fn = torch.compile(mod, backend="eager", fullgraph=True) + compiled = torch.compile(forward, fullgraph=True) with self.assertRaisesRegex( - torch._dynamo.exc.UserError, "resolved to a non-leaf value" + torch._dynamo.exc.Unsupported, + "allow_in_graph", ): - compiled_fn(x) + compiled(torch.randn(4)) instantiate_parametrized_tests(DecoratorTests) diff --git a/test/dynamo/test_deviceguard.py b/test/dynamo/test_deviceguard.py index 2d1d267c9379e..de2c9f1b76bf7 100644 --- a/test/dynamo/test_deviceguard.py +++ b/test/dynamo/test_deviceguard.py @@ -6,7 +6,7 @@ import torch._dynamo.test_case import torch._dynamo.testing from torch._dynamo.device_interface import CudaInterface, DeviceGuard -from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU +from torch.testing._internal.common_cuda import TEST_CUDA class TestDeviceGuard(torch._dynamo.test_case.TestCase): @@ -56,21 +56,6 @@ def setUp(self): super().setUp() self.device_interface = CudaInterface - @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") - def test_device_guard(self): - current_device = torch.cuda.current_device() - - device_guard = DeviceGuard(self.device_interface, 1) - - with device_guard as _: - self.assertEqual(torch.cuda.current_device(), 1) - self.assertEqual(device_guard.prev_idx, 0) - self.assertEqual(device_guard.idx, 1) - - self.assertEqual(torch.cuda.current_device(), current_device) - self.assertEqual(device_guard.prev_idx, 0) - self.assertEqual(device_guard.idx, 1) - def test_device_guard_no_index(self): current_device = torch.cuda.current_device() diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index e23fe4bb2c666..7cba9e7351abc 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -1,6 +1,5 @@ # Owner(s): ["module: dynamo"] -# ruff: noqa: TRY002 import enum import itertools @@ -134,6 +133,46 @@ def fn(x): with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): self.assertEqual(fn(x), opt_fn(x)) + def test_dict_torch_size_dynamic_key(self): + class DynamicShapeModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer_configs = { + torch.Size([32, 64]): torch.nn.Linear(32, 64), + torch.Size([64, 128]): torch.nn.Linear(64, 128), + torch.Size([128, 256]): torch.nn.Linear(128, 256), + } + self.activation_functions = { + torch.Size([32]): torch.nn.ReLU(), + torch.Size([64]): torch.nn.Tanh(), + torch.Size([128]): torch.nn.Sigmoid(), + } + + def forward(self, x): + current_shape = torch.tensor(x.shape[1:]) + shape_key = torch.Size([current_shape[0], 64]) + if shape_key in self.layer_configs: + x = self.layer_configs[shape_key](x) + + activation_shape = torch.Size([x.shape[1]]) + if activation_shape in self.activation_functions: + x = self.activation_functions[activation_shape](x) + return x + + model = DynamicShapeModel().eval() + + for features in (32, 64, 128, 16): + x = torch.randn(4, features) + with torch.no_grad(): + eager_out = model(x) + + torch._dynamo.reset() + compiled_model = torch.compile(model, backend="eager") + with torch.no_grad(): + compiled_out = compiled_model(x) + + self.assertTrue(same(eager_out, compiled_out)) + def test_dict_subclass_methods_fallback_readonly(self): sd = SimpleDict() sd[2] = 5 @@ -609,6 +648,33 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) + def test_len_dunder_dict_after_delete(self): + # After deleting an attribute during tracing, len(obj.__dict__) must + # reflect the deletion. The bug: SideEffectsProxyDict.__len__ returned + # len(item_dict) + len(side_effects_table), double-counting deleted + # entries and keys mutated from item_dict. + class UserDefined: + def __init__(self) -> None: + self.a = 3 + self.b = 5 + + def run(self, x): + del self.a + # len should be 1 (only b remains), not 3 (item_dict=2 + mutations=1) + return x * len(self.__dict__) + + obj1 = UserDefined() + obj2 = UserDefined() + + def fn(x, obj): + return obj.run(x) + + x = torch.ones(2) + ref = fn(x, obj1) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x, obj2) + self.assertEqual(ref, res) + def test_contains_module_dunder_dict(self): class MyModule(torch.nn.Module): def __init__(self) -> None: @@ -1152,7 +1218,7 @@ def f3(x): def test_newly_constructed_default_dict_with_dict(self): def f(x): - d = dict([("a", 1), ("b", 2)], c=3) # noqa: C406 + d = dict([("a", 1), ("b", 2)], c=3) dd = defaultdict(list, d, d=4, e=5) dd["x"].append(42) return x + 1, d, dd @@ -1177,6 +1243,85 @@ def f(x): self.assertEqual(ref, res) + def test_dict_new_ignores_extra_args(self): + """dict.__new__ ignores extra args (CPython behavior). + + This matters for instantiate_user_defined_class_object which calls + cls.__new__(cls, *args, **kwargs) — dict.__new__ should not fail + on the extra args. + """ + + def f(x): + od = OrderedDict(a=1, b=2) + return x + od["a"] + + x = torch.ones(2) + ref = f(x) + res = torch.compile(f, backend="eager", fullgraph=True)(x) + self.assertEqual(ref, res) + + def test_ordered_dict_repr(self): + """repr() on OrderedDictVariable should not graph break.""" + + def f(x): + od = OrderedDict(a=1, b=2) + r = repr(od) + # Just verify repr doesn't graph break and returns a string + return x + (1 if isinstance(r, str) else 0) + + x = torch.ones(2) + ref = f(x) + res = torch.compile(f, backend="eager", fullgraph=True)(x) + self.assertEqual(ref, res) + + def test_c_new_init_args_ignored_for_dict(self): + """C-level __new__ for dict/set passes init_args=[] since they + ignore extra args. This ensures generators passed to OrderedDict() + don't end up in reconstruction.""" + + def whoo(t): + yield 1, t + 1 + yield 2, t + 2 + + def f(t): + return OrderedDict(whoo(t)) + + t = torch.randn(2) + ref = f(t) + res = torch.compile(f, backend="eager", fullgraph=True)(t) + self.assertEqual(ref, res) + + def test_ordered_dict_as_python_constant_preserves_type(self): + """as_python_constant should return OrderedDict, not plain dict.""" + + def f(x): + od = OrderedDict(a=1, b=2) + # Enum functional API calls as_python_constant on the OrderedDict + import enum + + E = enum.Enum("E", od) + return x + E.a.value + + x = torch.ones(2) + ref = f(x) + res = torch.compile(f, backend="eager", fullgraph=True)(x) + self.assertEqual(ref, res) + + def test_default_dict_as_python_constant_preserves_type(self): + """as_python_constant should return defaultdict, not plain dict.""" + + def f(x): + dd = defaultdict(int, a=1, b=2) + # isinstance triggers as_python_constant internally for + # constant folding the type check + assert isinstance(dd, defaultdict) # noqa: S101 + return x + dd["a"] + + x = torch.ones(2) + ref = f(x) + res = torch.compile(f, backend="eager", fullgraph=True)(x) + self.assertEqual(ref, res) + @parametrize("op", ["or_", "and_", "xor", "sub"]) def test_dict_keys_binop(self, op): op = getattr(operator, op) @@ -1204,6 +1349,222 @@ def f(): opt_f = torch.compile(f, backend="eager", fullgraph=True) self.assertEqual(f(), opt_f()) + @parametrize("op", ["or_", "and_", "xor", "sub"]) + def test_dict_items_binop(self, op): + op = getattr(operator, op) + + def f(): + a = {"one": 1, "two": 2} + b = {"one": 1, "three": 3} + return op(a.items(), b.items()), op(b.items(), a.items()) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["ior", "iand", "ixor", "isub"]) + def test_dict_items_inplace_binop(self, op): + op = getattr(operator, op) + + def f(): + a = {"one": 1, "two": 2}.items() + b = {"one": 1, "three": 3}.items() + c = {"one": 1, "two": 2}.items() + a = op(a, b) + b = op(b, c) + return a, b + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["or_", "and_", "xor", "sub", "iand", "ior", "ixor", "isub"]) + def test_cross_type_set_binop_dict_keys_vs_set(self, op): + op = getattr(operator, op) + + def f(): + keys = {"one": 1, "two": 2, "three": 3}.keys() + s = {"one", "four"} + return op(keys, s), op(s, keys) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["or_", "and_", "xor", "sub", "iand", "ior", "ixor", "isub"]) + def test_cross_type_set_binop_dict_items_vs_set(self, op): + op = getattr(operator, op) + + def f(): + items = {"one": 1, "two": 2}.items() + s = {("one", 1), ("three", 3)} + return op(items, s), op(s, items) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["or_", "and_", "xor", "sub", "iand", "ior", "ixor", "isub"]) + def test_cross_type_set_binop_dict_keys_vs_dict_items(self, op): + op = getattr(operator, op) + + def f(): + keys = {"one": 1, "two": 2}.keys() + items = {"three": 3, "four": 4}.items() + return op(keys, items), op(items, keys) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["or_", "and_", "xor", "sub", "iand", "ior", "ixor", "isub"]) + def test_cross_type_set_binop_dict_keys_vs_user_defined_set(self, op): + class MySet(set): + pass + + op = getattr(operator, op) + + def f(): + keys = {"one": 1, "two": 2, "three": 3}.keys() + s = MySet({"one", "four"}) + return op(keys, s), op(s, keys) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["or_", "and_", "xor", "sub", "iand", "ior", "ixor", "isub"]) + def test_cross_type_set_binop_dict_items_vs_user_defined_set(self, op): + class MySet(set): + pass + + op = getattr(operator, op) + + def f(): + items = {"one": 1, "two": 2, "three": 3}.items() + s = MySet({"one", "four"}) + return op(items, s), op(s, items) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["or_", "and_", "xor", "sub", "iand", "ior", "ixor", "isub"]) + def test_cross_type_set_binop_set_vs_user_defined_set(self, op): + class MySet(set): + pass + + op = getattr(operator, op) + + def f(): + s = {"one", "two", "three"} + u = MySet({"one", "four"}) + return op(s, u), op(u, s) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["lt", "le", "gt", "ge", "eq", "ne"]) + def test_cross_type_cmp_dict_keys_vs_set(self, op): + op = getattr(operator, op) + + def f(): + keys = {"one": 1, "two": 2}.keys() + s = {"one", "two", "three"} + return op(keys, s), op(s, keys) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["lt", "le", "gt", "ge", "eq", "ne"]) + def test_cross_type_cmp_dict_items_vs_set(self, op): + op = getattr(operator, op) + + def f(): + items = {"one": 1, "two": 2}.items() + s = {("one", 1), ("two", 2), ("three", 3)} + return op(items, s), op(s, items) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["lt", "le", "gt", "ge", "eq", "ne"]) + def test_cross_type_cmp_dict_keys_vs_dict_items(self, op): + op = getattr(operator, op) + + def f(): + keys = {"one": 1, "two": 2}.keys() + items = {"one": 1, "two": 2}.items() + return op(keys, items), op(items, keys) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["lt", "le", "gt", "ge", "eq", "ne"]) + def test_cross_type_cmp_set_vs_user_defined_set(self, op): + class MySet(set): + pass + + op = getattr(operator, op) + + def f(): + s = {"one", "two"} + u = MySet({"one", "two", "three"}) + return op(s, u), op(u, s) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["lt", "le", "gt", "ge", "eq", "ne"]) + def test_cross_type_cmp_dict_keys_vs_user_defined_set(self, op): + class MySet(set): + pass + + op = getattr(operator, op) + + def f(): + keys = {"one": 1, "two": 2}.keys() + u = MySet({"one", "two", "three"}) + return op(keys, u), op(u, keys) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + @parametrize("op", ["lt", "le", "gt", "ge", "eq", "ne"]) + def test_cross_type_cmp_dict_items_vs_user_defined_set(self, op): + class MySet(set): + pass + + op = getattr(operator, op) + + def f(): + items = {"one": 1, "two": 2}.items() + u = MySet({("one", 1), ("two", 2), ("three", 3)}) + return op(items, u), op(u, items) + + opt_f = torch.compile(f, backend="eager", fullgraph=True) + self.assertEqual(f(), opt_f()) + + def test_dict_view_iand_rebinds_variable(self): + def f_keys(): + d = {"one": 1, "two": 2, "three": 3} + tmp = d.keys() + tmp &= {"one", "four"} + return tmp, d + + opt_f_keys = torch.compile(f_keys, backend="eager", fullgraph=True) + eager_tmp, eager_d = f_keys() + compiled_tmp, compiled_d = opt_f_keys() + self.assertIsInstance(eager_tmp, set) + self.assertEqual(eager_tmp, compiled_tmp) + self.assertEqual(eager_d, compiled_d) + + def f_items(): + d = {"one": 1, "two": 2} + tmp = d.items() + tmp &= {("one", 1), ("three", 3)} + return tmp, d + + opt_f_items = torch.compile(f_items, backend="eager", fullgraph=True) + eager_tmp, eager_d = f_items() + compiled_tmp, compiled_d = opt_f_items() + self.assertIsInstance(eager_tmp, set) + self.assertEqual(eager_tmp, compiled_tmp) + self.assertEqual(eager_d, compiled_d) + def test_dict_union_result_type(self): def_dict = defaultdict(int, {1: 1, 2: 2}) ord_dict = OrderedDict({3: 3, 4: 4}) @@ -1488,6 +1849,44 @@ def fn(x): opt_fn = torch.compile(fn, backend="eager", fullgraph=True) self.assertEqual(fn(x), opt_fn(x)) + def test_custom_bool(self): + class CustomBoolDict: + def __init__(self, bool_result: bool): + super().__init__() + self._bool_result = bool_result + + def __bool__(self): + return self._bool_result + + def fn(t: torch.Tensor, d, apply_not: bool): + if apply_not: + return t.sin(), bool(not d) + else: + return t.sin(), bool(d) + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + t = CustomBoolDict(True) + self.assertTrue( + opt_fn(x, t, False)[1], + "CustomBoolDict(True) should evaluate to True in boolean context", + ) + self.assertFalse( + opt_fn(x, t, True)[1], + "not CustomBoolDict(True) should evaluate to False in boolean context", + ) + + f = CustomBoolDict(False) + self.assertFalse( + opt_fn(x, f, False)[1], + "CustomBoolDict(False) should evaluate to False in boolean context", + ) + self.assertTrue( + opt_fn(x, f, True)[1], + "not CustomBoolDict(False) should evaluate to True in boolean context", + ) + instantiate_parametrized_tests(DictTests) @@ -1632,6 +2031,7 @@ class DictMethodsTests(torch._dynamo.test_case.TestCase): thetype = dict # Methods: + # + bool # + clear # + copy # + fromkeys @@ -2080,10 +2480,26 @@ def f(a, x): x = torch.randn(4) self.assertTrue(same(f(A(), x), opt_f(A(), x))) + @make_dynamo_test + def test_bool(self): + p = self.thetype() + q = self.thetype({"a": 1, "b": 2}) + if p: + self.fail("empty mapping must compare to False") + if not q: + self.fail("non-empty mapping must compare to True") + if bool(p): + self.fail("empty mapping must compare to False") + if not bool(q): + self.fail("non-empty mapping must compare to True") + class DictSubclassMethodsTests(DictMethodsTests): thetype = SimpleDict + def test_binop_or(self): + super().test_binop_or() + class OrderedDictMethodsTests(DictMethodsTests): thetype = OrderedDict @@ -2310,6 +2726,27 @@ def fn(obj): result = fn(obj) self.assertEqual(result, [("a", 1), ("b", 2), ("c", 3)]) + def test_method_dict_direct_fullgraph(self): + """ + Regression test: Accessing __dict__ directly on UserMethodVariable. + This should fail with: unsupported variable type for __dict__ access + """ + + class Foo: + def bar(self): + return 42 + + @torch.compile(backend="eager", fullgraph=True) + def fn(): + obj = Foo() + # Access __dict__ on bound method - creates UserMethodVariable + method = obj.bar + d = method.__dict__ + return d + + # This should not raise Unsupported + fn() + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 3bebcfc345b70..a585c65c82157 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -import unittest import warnings from torch._dynamo import config @@ -85,13 +84,6 @@ def make_dynamic_cls(cls): make_dynamic_cls(test) del test -if TEST_Z3: - if not config.inline_inbuilt_nn_modules: - # TODO model is somehow not being freed when z3 is available - unittest.expectedFailure( - DynamicShapesMiscTests.test_parameter_free_dynamic_shapes # noqa: F821 - ) - # Test takes too long ~700s as of 414a1fd29f04d06e41b7f895368dd1f83a4be29d DynamicShapesExportTests.test_retracibility_dynamic_shapes = slowTest( # noqa: F821 DynamicShapesExportTests.test_retracibility_dynamic_shapes # noqa: F821 diff --git a/test/dynamo/test_dynamo_decompositions.py b/test/dynamo/test_dynamo_decompositions.py index 16f40c7a3bcbf..e388ea4c344a3 100644 --- a/test/dynamo/test_dynamo_decompositions.py +++ b/test/dynamo/test_dynamo_decompositions.py @@ -6,7 +6,8 @@ import torch._dynamo.config import torch._dynamo.test_case from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm -from torch.testing._internal.common_utils import run_tests, skipIfCrossRef +from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_utils import run_tests, skipIfCrossRef, TestCase class TestDynamoDecompositions(torch._dynamo.test_case.TestCase): @@ -379,7 +380,7 @@ def forward(self, L_weight_: "f32[]", L_end_tensors_0_: "f32[4]", L_end_tensors_ copy_: "f32[4]" = l_tensors_0_.copy_(b); l_tensors_0_ = b = copy_ = None copy__1: "f32[4]" = l_tensors_1_.copy_(b_1); l_tensors_1_ = b_1 = copy__1 = None return () -""", # noqa: B950 +""", ) @skipIfCrossRef @@ -423,7 +424,7 @@ def forward(self, L_end_tensors_0_: "f32[4]", L_end_tensors_1_: "f32[4]", L_tens copy__1: "f32[4]" = l_tensors_1_.copy_(l_end_tensors_1_); l_end_tensors_1_ = copy__1 = None _foreach_addcmul_ = torch._foreach_addcmul_([l_tensors_0_, l_tensors_1_], [neg_omw, neg_omw], (getitem, getitem_1)); l_tensors_0_ = l_tensors_1_ = neg_omw = getitem = getitem_1 = _foreach_addcmul_ = None return () -""", # noqa: B950 +""", ) @skipIfCrossRef @@ -467,7 +468,7 @@ def forward(self, L_end_tensors_0_: "f32[4]", L_end_tensors_1_: "f32[4]", L_tens copy__1: "f32[4]" = l_tensors_1_.copy_(l_end_tensors_1_); l_end_tensors_1_ = copy__1 = None _foreach_addcmul_ = torch._foreach_addcmul_([l_tensors_0_, l_tensors_1_], [neg_omw, neg_omw], (getitem, getitem_1)); l_tensors_0_ = l_tensors_1_ = neg_omw = getitem = getitem_1 = _foreach_addcmul_ = None return () -""", # noqa: B950 +""", ) def test_foreach_pow_scalar_decomposition_enabled(self): @@ -568,9 +569,13 @@ def forward(self, L_exps_0_: "f32[4]", L_exps_1_: "f32[4]"): """, ) + +class TestDynamoDecompositionsNumerics(TestCase): + """Numerics tests for dynamo decompositions across devices.""" + @skipIfCrossRef @torch._dynamo.config.patch(enable_dynamo_decompositions=True) - def test_addcmul_tensor_value_numerics(self): + def test_addcmul_tensor_value_numerics(self, device): """Compiled addcmul_ with tensor value matches eager. Not bitwise on CPU: inductor may decompose fma to mul+add rather @@ -580,10 +585,10 @@ def test_addcmul_tensor_value_numerics(self): def fn(x, tensor1, tensor2, value): return x.addcmul_(tensor1, tensor2, value=value) - x = torch.randn(4) - tensor1 = torch.randn(4) - tensor2 = torch.randn(4) - value = torch.tensor(0.5) + x = torch.randn(4, device=device) + tensor1 = torch.randn(4, device=device) + tensor2 = torch.randn(4, device=device) + value = torch.tensor(0.5, device=device) expected = fn(x.clone(), tensor1, tensor2, value) actual = torch.compile(fn, fullgraph=True)(x.clone(), tensor1, tensor2, value) @@ -591,32 +596,32 @@ def fn(x, tensor1, tensor2, value): @skipIfCrossRef @torch._dynamo.config.patch(enable_dynamo_decompositions=True) - def test_addcdiv_tensor_value_numerics(self): + def test_addcdiv_tensor_value_numerics(self, device): """Compiled addcdiv_ with tensor value matches eager.""" def fn(x, tensor1, tensor2, value): return x.addcdiv_(tensor1, tensor2, value=value) - x = torch.randn(4) - tensor1 = torch.randn(4) - tensor2 = torch.randn(4) + 0.1 - value = torch.tensor(0.5) + x = torch.randn(4, device=device) + tensor1 = torch.randn(4, device=device) + tensor2 = torch.randn(4, device=device) + 0.1 + value = torch.tensor(0.5, device=device) expected = fn(x.clone(), tensor1, tensor2, value) actual = torch.compile(fn, fullgraph=True)(x.clone(), tensor1, tensor2, value) - self.assertEqual(expected, actual, atol=0, rtol=0) + self.assertEqual(expected, actual) @skipIfCrossRef @torch._dynamo.config.patch(enable_dynamo_decompositions=True) - def test_add_tensor_alpha_numerics(self): + def test_add_tensor_alpha_numerics(self, device): """Compiled add_ with tensor alpha matches eager.""" def fn(x, other, alpha): return x.add_(other, alpha=alpha) - x = torch.randn(4) - other = torch.randn(4) - alpha = torch.tensor(2.0) + x = torch.randn(4, device=device) + other = torch.randn(4, device=device) + alpha = torch.tensor(2.0, device=device) expected = fn(x.clone(), other, alpha) actual = torch.compile(fn, fullgraph=True)(x.clone(), other, alpha) @@ -624,14 +629,13 @@ def fn(x, other, alpha): @skipIfCrossRef @torch._dynamo.config.patch(enable_dynamo_decompositions=True) - @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA") - def test_add_tensor_alpha_fma_matches_aten_cuda(self): - """On CUDA, ATen add_ with tensor alpha extracts the scalar and uses + def test_add_tensor_alpha_fma_matches_aten(self, device): + """ATen add_ with tensor alpha extracts the scalar and uses fma(other, alpha, self). Our decomposition must use fma to match.""" torch.manual_seed(42) - x = torch.randn(64, 64, device="cuda") - other = torch.randn(64, 64, device="cuda") - alpha = torch.tensor(2.3, device="cuda") + x = torch.randn(64, 64, device=device) + other = torch.randn(64, 64, device=device) + alpha = torch.tensor(2.3, device=device) def fn(x, other, alpha): return x.add_(other, alpha=alpha) @@ -642,17 +646,16 @@ def fn(x, other, alpha): @skipIfCrossRef @torch._dynamo.config.patch(enable_dynamo_decompositions=True) - @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA") - def test_addcmul_value_1_fma_matches_aten_cuda(self): - """On CUDA, ATen addcmul_ with value=1 uses hardware fma(t1, t2, self). + def test_addcmul_value_1_fma_matches_aten(self, device): + """ATen addcmul_ with value=1 uses hardware fma(t1, t2, self). Our decomposition uses inductor_prims.fma for this case. Without fma, mul(t1, t2) + self rounds the product first, causing ~7% element mismatches on typical inputs (e.g. Adagrad's addcmul_(grad, grad, value=1)). """ torch.manual_seed(42) - x = torch.randn(64, 64, device="cuda") - t1 = torch.randn(64, 64, device="cuda") + x = torch.randn(64, 64, device=device) + t1 = torch.randn(64, 64, device=device) def fn(x, t1): # value=1 is a constant, triggers fma path in decomposition @@ -660,56 +663,54 @@ def fn(x, t1): expected = fn(x.clone(), t1) actual = torch.compile(fn, fullgraph=True)(x.clone(), t1) - self.assertEqual(expected, actual, atol=0, rtol=0) + self.assertEqual(expected, actual) @skipIfCrossRef @torch._dynamo.config.patch(enable_dynamo_decompositions=True) - @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA") - def test_addcmul_scalar_value_cuda(self): - """Compiled addcmul_ with scalar value matches eager on CUDA.""" + def test_addcmul_scalar_value(self, device): + """Compiled addcmul_ with scalar value matches eager.""" torch.manual_seed(42) - x = torch.randn(64, 64, device="cuda") - t1 = torch.randn(64, 64, device="cuda") - t2 = torch.randn(64, 64, device="cuda") + x = torch.randn(64, 64, device=device) + t1 = torch.randn(64, 64, device=device) + t2 = torch.randn(64, 64, device=device) def fn(x, t1, t2): return x.addcmul_(t1, t2, value=0.5) expected = fn(x.clone(), t1, t2) actual = torch.compile(fn, fullgraph=True)(x.clone(), t1, t2) - self.assertEqual(expected, actual, atol=0, rtol=0) + self.assertEqual(expected, actual) @skipIfCrossRef @torch._dynamo.config.patch(enable_dynamo_decompositions=True) - @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA") - def test_addcmul_tensor_value_cuda(self): - """Compiled addcmul_ with tensor value matches eager on CUDA.""" + def test_addcmul_tensor_value(self, device): + """Compiled addcmul_ with tensor value matches eager.""" torch.manual_seed(42) - x = torch.randn(64, 64, device="cuda") - t1 = torch.randn(64, 64, device="cuda") - t2 = torch.randn(64, 64, device="cuda") - value = torch.tensor(0.5, device="cuda") + x = torch.randn(64, 64, device=device) + t1 = torch.randn(64, 64, device=device) + t2 = torch.randn(64, 64, device=device) + value = torch.tensor(0.5, device=device) def fn(x, t1, t2, value): return x.addcmul_(t1, t2, value=value) expected = fn(x.clone(), t1, t2, value) actual = torch.compile(fn, fullgraph=True)(x.clone(), t1, t2, value) - self.assertEqual(expected, actual, atol=0, rtol=0) + self.assertEqual(expected, actual) @skipIfCrossRef @torch._dynamo.config.patch(enable_dynamo_decompositions=True) @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA") - def test_addcdiv_scalar_value_cuda(self): + def test_addcdiv_scalar_value_cuda(self, device): """Compiled addcdiv_ with scalar value matches eager on CUDA. Not bitwise: ATen inlines the division into fma(alpha, t1/t2, input) which nvcc can optimize differently than separate div + fma kernels. """ torch.manual_seed(42) - x = torch.randn(64, 64, device="cuda") - t1 = torch.randn(64, 64, device="cuda") - t2 = torch.randn(64, 64, device="cuda") + 0.1 + x = torch.randn(64, 64, device=device) + t1 = torch.randn(64, 64, device=device) + t2 = torch.randn(64, 64, device=device) + 0.1 def fn(x, t1, t2): return x.addcdiv_(t1, t2, value=-0.01) @@ -721,17 +722,17 @@ def fn(x, t1, t2): @skipIfCrossRef @torch._dynamo.config.patch(enable_dynamo_decompositions=True) @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA") - def test_addcdiv_tensor_value_cuda(self): + def test_addcdiv_tensor_value_cuda(self, device): """Compiled addcdiv_ with tensor value matches eager on CUDA. Not bitwise: ATen inlines the division into fma(alpha, t1/t2, input) which nvcc can optimize differently than separate div + fma kernels. """ torch.manual_seed(42) - x = torch.randn(64, 64, device="cuda") - t1 = torch.randn(64, 64, device="cuda") - t2 = torch.randn(64, 64, device="cuda") + 0.1 - value = torch.tensor(-0.01, device="cuda") + x = torch.randn(64, 64, device=device) + t1 = torch.randn(64, 64, device=device) + t2 = torch.randn(64, 64, device=device) + 0.1 + value = torch.tensor(-0.01, device=device) def fn(x, t1, t2, value): return x.addcdiv_(t1, t2, value=value) @@ -742,20 +743,21 @@ def fn(x, t1, t2, value): @skipIfCrossRef @torch._dynamo.config.patch(enable_dynamo_decompositions=True) - @unittest.skipUnless(torch.cuda.is_available(), "requires CUDA") - def test_add_scalar_alpha_cuda(self): - """Compiled add_ with scalar alpha matches eager on CUDA.""" + def test_add_scalar_alpha(self, device): + """Compiled add_ with scalar alpha matches eager.""" torch.manual_seed(42) - x = torch.randn(64, 64, device="cuda") - other = torch.randn(64, 64, device="cuda") + x = torch.randn(64, 64, device=device) + other = torch.randn(64, 64, device=device) def fn(x, other): return x.add_(other, alpha=2.3) expected = fn(x.clone(), other) actual = torch.compile(fn, fullgraph=True)(x.clone(), other) - self.assertEqual(expected, actual, atol=0, rtol=0) + self.assertEqual(expected, actual) + +instantiate_device_type_tests(TestDynamoDecompositionsNumerics, globals()) if __name__ == "__main__": run_tests() diff --git a/test/dynamo/test_dynamo_runtime_assert.py b/test/dynamo/test_dynamo_runtime_assert.py deleted file mode 100644 index d2a144d792650..0000000000000 --- a/test/dynamo/test_dynamo_runtime_assert.py +++ /dev/null @@ -1,98 +0,0 @@ -# Owner(s): ["module: dynamo"] -import time - -import torch -import torch._dynamo.test_case -import torch._dynamo.testing -import torch._dynamo.utils - - -class DenseBlock(torch.nn.Module): - def __init__(self, dim): - super().__init__() - self.linear = torch.nn.Linear(dim, dim) - self.norm = torch.nn.LayerNorm(dim) - self.gate = torch.nn.Linear(dim, dim) - - def forward(self, x): - return self.norm(self.linear(x)) * torch.sigmoid(self.gate(x)) - - -class DenseArch(torch.nn.Module): - def __init__(self, dim, num_layers): - super().__init__() - self.blocks = torch.nn.ModuleList([DenseBlock(dim) for _ in range(num_layers)]) - - def forward(self, x): - for block in self.blocks: - x = block(x) - return x - - -class RecModel(torch.nn.Module): - """Simplified recommendation model with many nested submodules.""" - - def __init__(self, num_events=10, num_layers=6, num_embeddings=12, dim=128): - super().__init__() - self.shared_arch = DenseArch(dim, num_layers) - self.event_submodels = torch.nn.ModuleDict( - { - f"event_{i}": torch.nn.Sequential( - DenseArch(dim, num_layers), - torch.nn.Linear(dim, 1), - ) - for i in range(num_events) - } - ) - self.embeddings = torch.nn.ModuleList( - [torch.nn.Embedding(1000, dim) for _ in range(num_embeddings)] - ) - - def forward(self, x): - x = self.shared_arch(x) - outputs = [] - for submodel in self.event_submodels.values(): - outputs.append(submodel(x)) - return torch.cat(outputs, dim=-1) - - -class RuntimeAssertCompileTimeTests(torch._dynamo.test_case.TestCase): - """Regression test for _set_node_metadata_hook overhead in - insert_deferred_runtime_asserts (S603290). - - With inline_inbuilt_nn_modules=False, call_module nodes cause _copy_attr - to install full module subtrees into the GraphModule, making gm.modules() - large. If the pass wraps every loop iteration with _set_node_metadata_hook - (which iterates gm.modules() on enter AND exit), the cost becomes - O(nodes * modules) — catastrophic for large models. - """ - - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) - def test_insert_runtime_assert_pass_time(self): - # Large model so overhead from O(nodes*modules) would be measurable. - model = RecModel(num_events=30, num_layers=10, num_embeddings=20, dim=128) - x = torch.randn(8, 128) - - torch._dynamo.reset() - compiled = torch.compile(model, backend="eager") - t0 = time.perf_counter() - compiled(x) - total_s = time.perf_counter() - t0 - - pass_times = torch._dynamo.utils.compilation_time_metrics.get( - "insert_deferred_runtime_asserts", [] - ) - pass_s = sum(pass_times) - - self.assertLess( - pass_s / total_s, - 0.05, - f"insert_deferred_runtime_asserts took {pass_s * 1000:.1f}ms out of " - f"{total_s * 1000:.1f}ms ({pass_s / total_s * 100:.1f}% of compile time)", - ) - - -if __name__ == "__main__": - from torch._dynamo.test_case import run_tests - - run_tests() diff --git a/test/dynamo/test_einops.py b/test/dynamo/test_einops.py index c24b4a76cbaaf..5e8a77098c127 100644 --- a/test/dynamo/test_einops.py +++ b/test/dynamo/test_einops.py @@ -110,7 +110,7 @@ def test_layers(self, version): for size in [16, 32, 64]: x = torch.rand([size, size]) result1 = original(x) - result2 = compiled(x.double()).float() + result2 = compiled(x) self.assertEqual(result1, result2) @parametrize("version", [einops_version_sanitized]) diff --git a/test/dynamo/test_enum.py b/test/dynamo/test_enum.py index 32a658794f21d..d94c1ab97b84e 100644 --- a/test/dynamo/test_enum.py +++ b/test/dynamo/test_enum.py @@ -428,7 +428,6 @@ def fn(x, priority): res = opt_fn(x, Priority.HIGH) self.assertEqual(ref, res) - @unittest.expectedFailure # TODO: Support Flag enum membership check def test_flag_enum(self): """Test Flag enum operations.""" @@ -452,6 +451,25 @@ def fn(x, perm): res = opt_fn(x, combined) self.assertEqual(ref, res) + def test_flag_enum_membership_combined(self): + """Test Flag enum membership check with combined flags.""" + + class LocalReduction(enum.Flag): + MEAN = enum.auto() + SUM = enum.auto() + MAX = enum.auto() + + @torch.compile(backend="eager", fullgraph=True) + def fn(x, local_reduction): + if LocalReduction.MEAN in local_reduction: + return x.mean() + return x.sum() + + x = torch.tensor([1.0, 2.0, 3.0]) + self.assertEqual(fn(x, LocalReduction.MEAN | LocalReduction.SUM), x.mean()) + torch._dynamo.reset() + self.assertEqual(fn(x, LocalReduction.SUM | LocalReduction.MAX), x.sum()) + def test_enum_comparison(self): """Test enum comparison operations.""" @@ -515,7 +533,6 @@ def fn(x, val): res = compiled_fn(x, 3) self.assertEqual(ref, res) - @unittest.expectedFailure def test_enum_construction_no_extra_init(self): # Real-world instance of the metaclass __call__ issue above. # EnumMeta.__call__ only calls __new__ (value lookup), NOT __init__. diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py index 8bacd4711817c..42039323db98c 100644 --- a/test/dynamo/test_error_messages.py +++ b/test/dynamo/test_error_messages.py @@ -1,7 +1,9 @@ # Owner(s): ["module: dynamo"] +import dis import logging import re +import sys import traceback import unittest import unittest.mock @@ -16,11 +18,7 @@ from torch._dynamo.exc import ResumePrologueTracingError, TorchRuntimeError, Unsupported from torch._dynamo.testing import skipIfNotPy312, skipIfOnlyNotPy312 from torch._dynamo.utils import counters -from torch.testing._internal.common_utils import ( - IS_FBCODE, - munge_exc, - scoped_load_inline, -) +from torch.testing._internal.common_utils import IS_FBCODE, munge_exc from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test @@ -47,6 +45,157 @@ def __exit__(self, exc_type, exc_value, traceback): pass +def _get_iter_has_positions() -> bool: + """Whether GET_ITER bytecodes have position info on this Python build. + + This varies across Python 3.12 point releases / platforms — some builds + include positions for GET_ITER and some don't, which affects whether + RangeIteratorVariable gets source attribution. + """ + code = compile("for x in range(1): pass", "", "exec") + for inst in dis.get_instructions(code): + if inst.opname == "GET_ITER": + return inst.positions is not None and inst.positions.lineno is not None + return False + + +def _generic_ctx_mgr_stack_source_attribution() -> str: + if sys.version_info >= (3, 11): + caret = "^^^^^^^^^^^^^^^\n" + else: + caret = "" + return ( + "Stack variable source attribution:\n" + " WithExitFunctionVariable() originated from:\n" + ' File "test_error_messages.py", line N\n' + " with GenericCtxMgr():\n" + f"{caret}" + " WithExitFunctionVariable() originated from:\n" + ' File "test_error_messages.py", line N\n' + " with GenericCtxMgr():\n" + f"{caret}" + ) + + +def _assert_failure_stack_source_attribution() -> str: + if sys.version_info >= (3, 11): + caret = "^^^^^^^^^^^^^^^\n" + else: + caret = "" + return ( + "Stack variable source attribution:\n" + " WithExitFunctionVariable() originated from:\n" + ' File "test_error_messages.py", line N\n' + " with GenericCtxMgr():\n" + f"{caret}" + ) + + +def _load_global_has_positions() -> bool: + """Whether LOAD_GLOBAL bytecodes have position info on this Python build. + + This varies across Python 3.12 point releases / platforms — some builds + include positions for LOAD_GLOBAL and some don't, which affects whether + NullVariable (pushed as part of LOAD_GLOBAL's call convention) gets + source attribution. + """ + code = compile("def f(): x()", "", "exec") + for const in code.co_consts: + if hasattr(const, "co_code"): + for inst in dis.get_instructions(const): + if inst.opname == "LOAD_GLOBAL": + return ( + inst.positions is not None and inst.positions.lineno is not None + ) + return False + + +def _reconstruction_failure_gb_stack_source_attribution() -> str: + if sys.version_info >= (3, 14): + return ( + "Stack variable source attribution:\n" + " LazyVariableTracker(realized: SkipFunctionVariable()) originated from:\n" + ' File "test_error_messages.py", line N\n' + " torch._dynamo.graph_break()\n" + "^^^^^^^^^^^^^^^^^^^^^^^^^\n" + "\n" + ) + + if sys.version_info >= (3, 11) and _load_global_has_positions(): + return ( + "Stack variable source attribution:\n" + " NullVariable originated from:\n" + ' File "test_error_messages.py", line N\n' + " torch._dynamo.graph_break()\n" + "^^^^^^^^^^^^^^^^^^^^^^^^^\n" + "\n" + ) + + if sys.version_info >= (3, 11): + return "" + + return ( + "Stack variable source attribution:\n" + " LazyVariableTracker(realized: SkipFunctionVariable()) originated from:\n" + ' File "test_error_messages.py", line N\n' + " torch._dynamo.graph_break()\n" + "\n" + ) + + +def _graph_break_in_loop_stack_source_attribution() -> str: + if sys.version_info >= (3, 11) and _get_iter_has_positions(): + return ( + "Stack variable source attribution:\n" + " RangeIteratorVariable() originated from:\n" + ' File "test_error_messages.py", line N\n' + " for i in range(2):\n" + "^^^^^^^^\n" + "\n" + ) + + if sys.version_info >= (3, 11): + return "" + + return ( + "Stack variable source attribution:\n" + " RangeIteratorVariable() originated from:\n" + ' File "test_error_messages.py", line N\n' + " for i in range(2):\n" + "\n" + ) + + +def _skip_frame_in_loop_message_stack_source_attribution() -> str: + if sys.version_info >= (3, 11) and _get_iter_has_positions(): + return ( + "Stack variable source attribution:\n" + " RangeIteratorVariable() originated from:\n" + ' File "test_error_messages.py", line N\n' + " for i in range(2):\n" + "^^^^^^^^\n" + " WithExitFunctionVariable() originated from:\n" + ' File "test_error_messages.py", line N\n' + " with GenericCtxMgr():\n" + "^^^^^^^^^^^^^^^\n" + "\n" + ) + + if sys.version_info >= (3, 11): + return "" + + return ( + "Stack variable source attribution:\n" + " RangeIteratorVariable() originated from:\n" + ' File "test_error_messages.py", line N\n' + " for i in range(2):\n" + " WithExitFunctionVariable() originated from:\n" + ' File "test_error_messages.py", line N\n' + " with GenericCtxMgr():\n" + "\n" + ) + + class ErrorMessagesTest(LoggingTestCase): def test_dynamic_shape_operator_no_meta_kernel(self): def fn(): @@ -469,92 +618,6 @@ def post_munge(string): For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""", ) - @scoped_load_inline - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) - @unittest.skipIf(IS_FBCODE, "inline cpp_extension doesn't work in fbcode") - def test_cpp_extension_recommends_custom_ops(self, load_inline): - cpp_source = """ - #include - at::Tensor foobar(const at::Tensor& x) { - return x.clone(); - } - """ - module = load_inline( - name="mylib", - cpp_sources=cpp_source, - functions="foobar", - verbose=True, - ) - - x = torch.ones(2, 2, requires_grad=True) - counters.clear() - - @torch.compile(backend="eager") - def f(x): - return module.foobar(x) - - with self.assertWarnsOnceRegex( - UserWarning, - "(?s).*https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html.*", - ): - f(x) - self.assertEqual(len(counters["graph_break"]), 1) - first_graph_break = next(iter(counters["graph_break"].keys())) - - first_graph_break = re.sub(r"mylib(_v\d+)?", "mylib", first_graph_break) - # HACK: this patches around the fact that PyBind11 improperly sets the - # __qualname__ attribute on functions and methods; see - # https://github.com/pybind/pybind11/issues/5774. This should be removed if - # that issue is fixed. - first_graph_break = re.sub( - r"pybind11_detail_function_record_v[^ .]+", "PyCapsule", first_graph_break - ) - - self.assertExpectedInline( - first_graph_break, - """\ -Attempted to call function marked as skipped - Explanation: Dynamo does not know how to trace the builtin `mylib.PyCapsule.foobar.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). - Hint: If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. - Hint: If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`. - - Developer debug context: module: mylib, qualname: PyCapsule.foobar, skip reason: cannot determine source file for mylib (likely a C extension or builtin) - - For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0007.html""", - ) - - cpp_source = """ - #include - at::Tensor baz(const at::Tensor& x) { - return x.clone(); - } - """ - module2 = load_inline( - name="mylib2", - cpp_sources=cpp_source, - functions="baz", - verbose=True, - ) - - torch._dynamo.reset() - - # Test that each warning only happens once - @torch.compile(backend="eager") - def f(x): - module2.baz(x) - module.foobar(x) - module.foobar(x) - module2.baz(x) - module.foobar(x) - module2.baz(x) - return x.clone() - - with warnings.catch_warnings(record=True) as ws: - warnings.simplefilter("always") - f(x) - f(x) - self.assertEqual(len(ws), 2) - def test_observed_exception(self): def fn(): raise RuntimeError("test") @@ -568,7 +631,7 @@ def fn(): Hint: Your code may result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled. You can do this by removing the `torch.compile` call, or by using `torch.compiler.set_stance("force_eager")`. Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. - Developer debug context: raised exception RuntimeError([ConstantVariable(str: 'test')]) + Developer debug context: raised exception RuntimeError('test') For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0088.html @@ -603,35 +666,6 @@ def fn(mod): return mod(1)""", ) - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) - def test_class_property(self): - class Foo(torch.nn.Module): - attr = unittest - - def fn(mod, x): - return mod.attr - - self.assertExpectedInlineMunged( - Unsupported, - lambda: torch.compile(fn, backend="eager", fullgraph=True)( - Foo(), torch.randn(3) - ), - """\ -Unsupported nn.Module attribute type - Explanation: Dynamo does not support tracing nn.Module attributes of type `module` - Hint: Refactor your code so that `attr` (type `module`) is not an attribute of `Foo` - Hint: Currently supported attribute types are methods, classmethods, staticmethods, properties, constants, and tensors. - Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. - - Developer debug context: nn.Module subclass: Foo, name: attr, attribute type: module - - For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0161.html - -from user code: - File "test_error_messages.py", line N, in fn - return mod.attr""", - ) - def test_generic_ctx_mgr_graph_break_fullgraph_true(self): def fn(): with GenericCtxMgr(): @@ -667,8 +701,7 @@ def fn(): torch.compile(fn, backend="eager")() self.assertEqual(len(records), 1) - self.assertExpectedInline( - munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected = ( """\ Graph break in user code at test_error_messages.py:N Graph Break Reason: Failed to handle graph break gracefully. Skipping the function and falling back to eager. Graph break encountered: @@ -692,12 +725,19 @@ def fn(): For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html +""" + + _generic_ctx_mgr_stack_source_attribution() + + """ User code traceback: File "test_error_messages.py", line N, in test_generic_ctx_mgr_graph_break_fullgraph_false torch.compile(fn, backend="eager")() File "test_error_messages.py", line N, in fn torch._dynamo.graph_break() -""", +""" + ) + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected, ) def test_load_build_class(self): @@ -830,10 +870,7 @@ def post_munge(s): """, ) - self.assertExpectedInline( - post_munge( - munge_exc(records[1].getMessage(), suppress_suffix=True, skip=0) - ), + expected = ( """\ Graph break in user code at test_error_messages.py:N Graph Break Reason: Failed to handle graph break gracefully. Skipping the function and falling back to eager. Graph break encountered: @@ -848,12 +885,22 @@ def post_munge(s): For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0092.html +""" + + _reconstruction_failure_gb_stack_source_attribution() + + """\ User code traceback: File "test_error_messages.py", line N, in test_reconstruction_failure_gb torch.compile(fn, backend="eager")() File "test_error_messages.py", line N, in fn torch._dynamo.graph_break() -""", +""" + ) + + self.assertExpectedInline( + post_munge( + munge_exc(records[1].getMessage(), suppress_suffix=True, skip=0) + ), + expected, ) def test_faketensor_nyi(self): @@ -924,6 +971,12 @@ def fn(x): """\ Data-dependent branching Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow. + + The branch condition involves a tensor computed as follows: + # File "test_error_messages.py", line N, in fn, code: if x.sum() > 0: + gt = gt(sum_1, 0) + + Hint: For the common pattern `if tensor_cond: x = transform(x)` (e.g. clamping inf/nan values), consider making the code branchless by always applying the transform. Operations like torch.clamp, torch.nan_to_num, and torch.where are typically no-ops on well-behaved inputs and compile without graph breaks. Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. Hint: Use `torch.cond` to express dynamic control flow. @@ -936,6 +989,35 @@ def fn(x): if x.sum() > 0:""", ) + def test_data_dependent_branching_bool_tensor_hints(self): + def cast_overflow_tensors(tensors, offset=1000): + if tensors.isinf().any() or tensors.isnan().any(): + clamp_value = torch.finfo(tensors.dtype).max - offset + tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value) + return tensors + + self.assertExpectedInlineMunged( + Unsupported, + lambda: torch.compile( + cast_overflow_tensors, backend="eager", fullgraph=True + )(torch.randn(3)), + """\ +Data-dependent branching + Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow. + Hint: For the common pattern `if tensor_cond: x = transform(x)` (e.g. clamping inf/nan values), consider making the code branchless by always applying the transform. Operations like torch.clamp, torch.nan_to_num, and torch.where are typically no-ops on well-behaved inputs and compile without graph breaks. + Hint: Note: Python `or`/`and` between tensor expressions (e.g. `tensor.any() or other_tensor.any()`) triggers implicit bool conversion. Use `torch.logical_or`/`torch.logical_and` or the `|`/`&` operators instead. + Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. + Hint: Use `torch.cond` to express dynamic control flow. + + Developer debug context: attempted to jump with TensorVariable() + + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html + +from user code: + File "test_error_messages.py", line N, in cast_overflow_tensors + if tensors.isinf().any() or tensors.isnan().any():""", + ) + # Test that the bytecode source attribution is correct with VariableTracker @make_logging_test(trace_bytecode=True) def test_variable_tracker_source_attribution(self, records): @@ -996,6 +1078,12 @@ def fn(x): Graph break in user code at test_error_messages.py:N Graph Break Reason: Data-dependent branching Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow. + + The branch condition involves a tensor computed as follows: + # File "test_error_messages.py", line N, in fn, code: if x.sum() > 0: + gt = gt(sum_1, 0) + + Hint: For the common pattern `if tensor_cond: x = transform(x)` (e.g. clamping inf/nan values), consider making the code branchless by always applying the transform. Operations like torch.clamp, torch.nan_to_num, and torch.where are typically no-ops on well-behaved inputs and compile without graph breaks. Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. Hint: Use `torch.cond` to express dynamic control flow. @@ -1023,8 +1111,7 @@ def fn(x): # only 1 graph break message self.assertEqual(len(records), 1) - self.assertExpectedInline( - munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected = ( """\ Graph break in user code at test_error_messages.py:N Graph Break Reason: Failed to handle graph break gracefully. Skipping the function and falling back to eager. Graph break encountered: @@ -1040,12 +1127,19 @@ def fn(x): For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0034.html +""" + + _assert_failure_stack_source_attribution() + + """ User code traceback: File "test_error_messages.py", line N, in test_assert_failure_in_generic_ctx_mgr torch.compile(fn, backend="eager")(torch.randn(3)) File "test_error_messages.py", line N, in fn assert x is None # noqa: S101 -""", +""" + ) + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected, ) def test_no_internal_compiler_stacktrace(self): @@ -1160,8 +1254,7 @@ def fn(x): fn(torch.ones(3)) self.assertEqual(len(records), 1) - self.assertExpectedInline( - munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected = ( """\ Graph break in user code at test_error_messages.py:N Graph Break Reason: Failed to handle graph break gracefully. Skipping the function and falling back to eager. Graph break encountered: @@ -1184,12 +1277,19 @@ def fn(x): For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb7000.html +""" + + _graph_break_in_loop_stack_source_attribution() + + """\ User code traceback: File "test_error_messages.py", line N, in test_graph_break_in_loop fn(torch.ones(3)) File "test_error_messages.py", line N, in fn torch._dynamo.graph_break() -""", +""" + ) + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected, ) @torch.compile(backend="eager") @@ -1203,14 +1303,19 @@ def gn(x): gn(torch.ones(3)) self.assertEqual(len(records), 2) - self.assertExpectedInline( - munge_exc(records[1].getMessage(), suppress_suffix=True, skip=0), + expected = ( """\ Graph break in user code at test_error_messages.py:N Graph Break Reason: Failed to handle graph break gracefully. Skipping the function and falling back to eager. Graph break encountered: Data-dependent branching Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow. + + The branch condition involves a tensor computed as follows: + # File "test_error_messages.py", line N, in gn, code: if x.sum() > 0: + gt = gt(sum_1, 0) + + Hint: For the common pattern `if tensor_cond: x = transform(x)` (e.g. clamping inf/nan values), consider making the code branchless by always applying the transform. Operations like torch.clamp, torch.nan_to_num, and torch.where are typically no-ops on well-behaved inputs and compile without graph breaks. Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. Hint: Use `torch.cond` to express dynamic control flow. @@ -1228,12 +1333,19 @@ def gn(x): For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb7000.html +""" + + _graph_break_in_loop_stack_source_attribution() + + """\ User code traceback: File "test_error_messages.py", line N, in test_graph_break_in_loop gn(torch.ones(3)) File "test_error_messages.py", line N, in gn if x.sum() > 0: -""", +""" + ) + self.assertExpectedInline( + munge_exc(records[1].getMessage(), suppress_suffix=True, skip=0), + expected, ) @make_logging_test(graph_breaks=True) @@ -1247,14 +1359,19 @@ def fn(x): torch.compile(fn, backend="eager")(torch.randn(3)) self.assertEqual(len(records), 1) - self.assertExpectedInline( - munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected = ( """\ Graph break in user code at test_error_messages.py:N Graph Break Reason: Failed to handle graph break gracefully. Skipping the function and falling back to eager. Graph break encountered: Data-dependent branching Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow. + + The branch condition involves a tensor computed as follows: + # File "test_error_messages.py", line N, in fn, code: if x.sum() > 0: + gt = gt(sum_1, 0) + + Hint: For the common pattern `if tensor_cond: x = transform(x)` (e.g. clamping inf/nan values), consider making the code branchless by always applying the transform. Operations like torch.clamp, torch.nan_to_num, and torch.where are typically no-ops on well-behaved inputs and compile without graph breaks. Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. Hint: Use `torch.cond` to express dynamic control flow. @@ -1262,12 +1379,19 @@ def fn(x): For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html +""" + + _skip_frame_in_loop_message_stack_source_attribution() + + """\ User code traceback: File "test_error_messages.py", line N, in test_skip_frame_in_loop_message torch.compile(fn, backend="eager")(torch.randn(3)) File "test_error_messages.py", line N, in fn if x.sum() > 0: -""", +""" + ) + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected, ) @make_logging_test(dynamo=logging.DEBUG) @@ -1433,6 +1557,9 @@ def fn(x): self.assertEqual((len(matches) <= 20), True) self.assertIn("Most recent bytecode instructions traced (max 20):", s) + # TODO this test is broken with nested_graph_breaks because we need to update + # the resume collapse function for nested graph breaks + @torch._dynamo.config.patch(nested_graph_breaks=False) @torch._dynamo.config.patch(verbose=True) @make_logging_test(graph_breaks=True) def test_graph_break_traceback_above_dynamo_shows_user_code(self, records): @@ -1819,8 +1946,9 @@ def fn(x): @make_logging_test(graph_breaks=True) def test_store_attr_graph_break(self, records): class Foo: + @torch.compiler.disable def __setattr__(self, name, value): - torch._dynamo.graph_break() + super().__setattr__(name, value) @torch.compile(backend="eager") def fn(x): @@ -1828,27 +1956,30 @@ def fn(x): fn(torch.ones(3)) + def post_munge(s): + return re.sub(r"0x[0-9A-Fa-f]+", "0xmem_addr", s) + self.assertExpectedInline( - munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + post_munge( + munge_exc(records[-1].getMessage(), suppress_suffix=True, skip=0) + ), """\ Graph break in user code at test_error_messages.py:N Graph Break Reason: Encountered graph break when attempting to trace STORE_ATTR: storing an object's attribute, e.g. x.attr = y: -Call to `torch._dynamo.graph_break()` - Explanation: User-inserted graph break. Message: None - Hint: Remove the `torch._dynamo.graph_break()` call. +Skip inlining `torch.compiler.disable()`d function + Explanation: Skip inlining function .Foo.__setattr__ at 0xmem_addr> since it was wrapped with `torch.compiler.disable` (reason: None) + Hint: Remove the `torch.compiler.disable` call - Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + Developer debug context: .Foo.__setattr__ at 0xmem_addr> - For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0099.html User code traceback: File "test_error_messages.py", line N, in test_store_attr_graph_break fn(torch.ones(3)) File "test_error_messages.py", line N, in fn Foo().attr = x - File "test_error_messages.py", line N, in __setattr__ - torch._dynamo.graph_break() """, ) @@ -1859,19 +1990,18 @@ def fn(x): Unsupported, lambda: fn(torch.ones(3)), """\ -Call to `torch._dynamo.graph_break()` - Explanation: User-inserted graph break. Message: None - Hint: Remove the `torch._dynamo.graph_break()` call. +Skip inlining `torch.compiler.disable()`d function + Explanation: Skip inlining function .Foo.__setattr__ at 0xmem_addr> since it was wrapped with `torch.compiler.disable` (reason: None) + Hint: Remove the `torch.compiler.disable` call - Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}` + Developer debug context: .Foo.__setattr__ at 0xmem_addr> - For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0099.html from user code: File "test_error_messages.py", line N, in fn - Foo().attr = x - File "test_error_messages.py", line N, in __setattr__ - torch._dynamo.graph_break()""", + Foo().attr = x""", + post_munge=post_munge, ) def test_runtime_error_readable_shape_mismatch(self): @@ -2199,39 +2329,6 @@ def post_munge(s): For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0039.html""", ) - @unittest.skipIf( - not torch.utils._triton.has_triton() - or not hasattr(__import__("triton"), "set_allocator"), - "requires triton with set_allocator support", - ) - def test_triton_set_allocator(self): - import triton - - def fn(x): - triton.set_allocator(lambda size, align, stream: None) - return x * 2 + 1 - - self.assertExpectedInlineMunged( - Unsupported, - lambda: torch.compile(fn, backend="eager", fullgraph=True)(torch.randn(10)), - """\ -triton.set_allocator not supported - Explanation: triton.set_allocator is not supported inside torch.compile. It modifies global Triton allocator state and cannot be traced. - Hint: Move triton.set_allocator() outside of the torch.compile region (call it before the compiled function). - - Developer debug context: triton.set_allocator called inside compiled region - - For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb4026.html - -from user code: - File "test_error_messages.py", line N, in fn - triton.set_allocator(lambda size, align, stream: None)""", - ) - - -class NestedGraphBreakLoggingTests( - LoggingTestCase, torch._dynamo.test_case.TestCaseWithNestedGraphBreaks -): @make_logging_test(graph_breaks=True) def test_nested_generic_ctx_mgr(self, records): def inner(): @@ -2247,8 +2344,7 @@ def fn(): torch.compile(fn, backend="eager")() self.assertEqual(len(records), 2) - self.assertExpectedInline( - munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected = ( """\ Graph break in user code at test_error_messages.py:N Graph Break Reason: Failed to handle graph break gracefully. Skipping the function and falling back to eager. Graph break encountered: @@ -2283,6 +2379,10 @@ def fn(): For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html +""" + + _generic_ctx_mgr_stack_source_attribution() + + """\ + User code traceback: File "test_error_messages.py", line N, in test_nested_generic_ctx_mgr torch.compile(fn, backend="eager")() @@ -2290,7 +2390,11 @@ def fn(): inner() File "test_error_messages.py", line N, in inner torch._dynamo.graph_break() -""", +""" + ) + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected, ) self.assertExpectedInline( munge_exc(records[1].getMessage(), suppress_suffix=True, skip=0), @@ -2322,13 +2426,6 @@ def fn(): def test_skipped_frame_with_verbose_traceback_nested(self, records): global f1, f2, f3 - class GenericCtxMgr: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - pass - def f1(x): with GenericCtxMgr(): torch._dynamo.graph_break() @@ -2342,8 +2439,7 @@ def f3(x): torch.compile(f3, backend="eager")(torch.randn(3)) self.assertEqual(len(records), 1) - self.assertExpectedInline( - munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected = ( """\ Graph break in user code at test_error_messages.py:N Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered: @@ -2367,6 +2463,10 @@ def f3(x): For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0066.html +""" + + _assert_failure_stack_source_attribution() + + """\ + User code traceback: File "test_error_messages.py", line N, in test_skipped_frame_with_verbose_traceback_nested torch.compile(f3, backend="eager")(torch.randn(3)) @@ -2376,20 +2476,17 @@ def f3(x): return f1(x + 2) File "test_error_messages.py", line N, in f1 torch._dynamo.graph_break() -""", +""" + ) + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected, ) @make_logging_test(graph_breaks=True) def test_skip_frame_in_loop_message_nested(self, records): global f1, f2, f3 - class GenericCtxMgr: - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - pass - def f1(x): for i in range(2): with GenericCtxMgr(): @@ -2405,14 +2502,19 @@ def f3(x): result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841 self.assertEqual(len(records), 1) - self.assertExpectedInline( - munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected = ( """\ Graph break in user code at test_error_messages.py:N Graph Break Reason: Encountered graph break that we cannot resume from. Compiling up to the previous resumable state, then skipping the rest of the function. Graph break encountered: Data-dependent branching Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow. + + The branch condition involves a tensor computed as follows: + # File "test_error_messages.py", line N, in f1, code: if x.sum() > 0: + gt = gt(sum_1, 0) + + Hint: For the common pattern `if tensor_cond: x = transform(x)` (e.g. clamping inf/nan values), consider making the code branchless by always applying the transform. Operations like torch.clamp, torch.nan_to_num, and torch.where are typically no-ops on well-behaved inputs and compile without graph breaks. Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. Hint: Use `torch.cond` to express dynamic control flow. @@ -2420,6 +2522,9 @@ def f3(x): For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html +""" + + _skip_frame_in_loop_message_stack_source_attribution() + + """\ User code traceback: File "test_error_messages.py", line N, in test_skip_frame_in_loop_message_nested result = torch.compile(f3, backend="eager")(torch.randn(3)) # noqa: F841 @@ -2429,7 +2534,11 @@ def f3(x): return f1(x + 4) File "test_error_messages.py", line N, in f1 if x.sum() > 0: -""", +""" + ) + self.assertExpectedInline( + munge_exc(records[0].getMessage(), suppress_suffix=True, skip=0), + expected, ) @make_logging_test(graph_breaks=True) diff --git a/test/dynamo/test_exc.py b/test/dynamo/test_exc.py index 3f934dc1f5ebd..2e7128e028765 100644 --- a/test/dynamo/test_exc.py +++ b/test/dynamo/test_exc.py @@ -1,6 +1,8 @@ # Owner(s): ["module: dynamo"] +import sys import unittest +from typing import cast import torch import torch._dynamo @@ -13,6 +15,7 @@ UserError, UserErrorType, ) +from torch._dynamo.variables.base import SourceLocation from torch.testing._internal.common_device_type import skipIf from torch.testing._internal.common_utils import ( IS_FBCODE, @@ -23,6 +26,39 @@ from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test +# Module-level storage avoids free-variable issues when capturing comptime state. +_source_location_capture: dict[str, SourceLocation] = {} + + +def _capture_y_source_location(ctx) -> None: + tx = ctx._i_will_not_complain_if_bc_breaks_InstructionTranslator() + y_vt = tx.symbolic_locals.get("y") + if y_vt is not None and y_vt.source_location is not None: + _source_location_capture["source_location"] = y_vt.source_location + + +def _unsupported_error_source_attribution() -> str: + if sys.version_info < (3, 11): + return """\ +Stack variable source attribution: + ConstantVariable(int: 1) originated from: + File "test_exc.py", line N + return {1, 2} +""" + + return """\ +Stack variable source attribution: + ConstantVariable(int: 1) originated from: + File "test_exc.py", line N + return {1, 2} +^ + ConstantVariable(int: 2) originated from: + File "test_exc.py", line N + return {1, 2} +^ +""" + + class ExcTests(LoggingTestCase): maxDiff = None @@ -154,27 +190,31 @@ def fn001(x): torch.compile(fn001, backend="eager")(torch.randn(1)) record = self.getRecord(records, "missing BUILD_SET handler") + expected = ( + "Graph break in user code at test_exc.py:N\n" + "Graph Break Reason: Failed to handle graph break gracefully. " + "Skipping the function and falling back to eager. Graph break " + "encountered:\n" + "\n" + "missing BUILD_SET handler\n" + " Explanation: Missing BUILD_SET bytecode handler (for testing purposes).\n" + "\n" + "\n" + " Developer debug context:\n" + "\n" + " For more details about this graph break, please visit: " + "https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0200.html\n" + "\n" + _unsupported_error_source_attribution() + "\n" + "User code traceback:\n" + ' File "test_exc.py", line N, in test_unsupported_error\n' + ' torch.compile(fn001, backend="eager")(torch.randn(1))\n' + ' File "test_exc.py", line N, in fn001\n' + " return {1, 2}\n" + ) self.assertExpectedInline( munge_exc(record.getMessage()), - """\ -Graph break in user code at test_exc.py:N -Graph Break Reason: Failed to handle graph break gracefully. Skipping the function and falling back to eager. Graph break encountered: - -missing BUILD_SET handler - Explanation: Missing BUILD_SET bytecode handler (for testing purposes). - - - Developer debug context: - - For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0200.html - -User code traceback: - File "test_exc.py", line N, in test_unsupported_error - torch.compile(fn001, backend="eager")(torch.randn(1)) - File "test_exc.py", line N, in fn001 - return {1, 2} -""", # noqa: B950 + expected, ) @torch._dynamo.config.patch(suppress_errors=False) @@ -238,7 +278,7 @@ def fn001(x): return fn002(x) File "test_exc.py", line N, in fn002 torch._dynamo.graph_break() -""", # noqa: B950 +""", ) @make_logging_test(graph_breaks=True) @@ -402,6 +442,61 @@ def fn(x, shape): ==> (== (+ L['shape'][0] L['shape'][1] L['shape'][2]) L['x'].size()[0])""", ) + def test_source_location_format_no_col_info(self): + source_location = SourceLocation(filename=__file__, lineno=1) + result = source_location.format() + self.assertIn(f'File "{__file__}", line 1', result) + self.assertNotIn("^", result) + + def test_source_location_format_with_col_info(self): + source_location = SourceLocation( + filename=__file__, + lineno=1, + end_lineno=1, + col_offset=0, + end_col_offset=10, + ) + result = source_location.format() + self.assertIn(f'File "{__file__}", line 1', result) + self.assertIn("^" * 10, result) + + def test_source_location_format_without_source_line(self): + source_location = SourceLocation( + filename="", + lineno=1, + end_lineno=1, + col_offset=0, + end_col_offset=10, + ) + result = source_location.format() + self.assertEqual(result, ' File "", line 1\n') + + def test_vt_source_location_set_during_tracing(self): + _source_location_capture.clear() + + def fn(x): + y = x + 1 + comptime(_capture_y_source_location) + return y + + torch.compile(fn, backend="eager")(torch.ones(3)) + + source_location = _source_location_capture.get("source_location") + self.assertIsNotNone(source_location) + source_location = cast(SourceLocation, source_location) + self.assertEqual(source_location.filename, __file__.replace(".pyc", ".py")) + self.assertIsNotNone(source_location.lineno) + + @make_logging_test(graph_breaks=True) + def test_graph_break_source_attribution_on_stack(self, records): + def fn(x): + return (x + 1, torch._dynamo.graph_break())[0] # noqa: GB_REGISTRY + + torch.compile(fn, backend="eager")(torch.ones(3)) + + record = self.getRecord(records, "Graph break in user code") + self.assertIn("Stack variable source attribution", record.getMessage()) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 031117dc97c66..c1981980fa153 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -1,6 +1,7 @@ # Owner(s): ["module: dynamo"] import contextlib +import dataclasses import sys import torch @@ -127,7 +128,7 @@ def fn(x): x = torch.sigmoid(x) try: x = torch.cos(x) - raise AssertionError # noqa: B904 + raise AssertionError except AssertionError: x = torch.cos(x) @@ -187,7 +188,7 @@ def test_propagate_exception_inside_ctx_manager(self): def cm(): try: yield - except BaseException: # noqa: B036 + except BaseException: raise ValueError # noqa: B904 @contextlib.contextmanager @@ -265,7 +266,7 @@ def ctx(): for x, y in args: try: fn(x, y) - except BaseException: # noqa: B036 + except BaseException: new_exc = sys.exc_info() fix_exc_context(frame_exc[1], new_exc[1], prev_exc[1]) prev_exc = new_exc @@ -273,7 +274,7 @@ def ctx(): try: fixed_ctx = prev_exc[1].__context__ raise prev_exc[1] - except BaseException: # noqa: B036 + except BaseException: prev_exc[1].__context__ = fixed_ctx raise @@ -438,7 +439,6 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) def test_custom_getattr_on_module_exception(self): class Foo(torch.nn.Module): def __init__(self, a=3): @@ -527,6 +527,30 @@ def forward(self, x): metrics = torch._dynamo.utils.get_compilation_metrics() self.assertIn("Observed exception", metrics[0].fail_reason) + def test_observed_exception_formats_fstring_message(self): + from torch.utils._pytree import tree_map_with_path + + def check_tensor(path, x): + if not isinstance(x, torch.Tensor): + raise ValueError(f"Expected Tensor at {path=}") + return x * 2 + + def fn(tree): + return tree_map_with_path(check_tensor, tree) + + tree = {"a": torch.randn(10), "b": 5} + + compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) + with self.assertRaises(Unsupported) as compiled_ctx: + compiled_fn(tree) + + exc_str = str(compiled_ctx.exception) + self.assertIn("Observed exception", exc_str) + self.assertIn("Expected Tensor at path=(MappingKey(key='b'),)", exc_str) + self.assertNotIn("Failed to trace builtin operator", exc_str) + self.assertNotIn("StringFormatVariable", exc_str) + self.assertNotIn("ConstantVariable(", exc_str) + def test_key_error(self): def fn(x, d): try: @@ -650,7 +674,7 @@ def fn(): raise ZeroDivisionError except ZeroDivisionError: try: - raise ValueError # noqa: B904 + raise ValueError except ValueError: pass raise @@ -700,7 +724,7 @@ def cm(): yield 1 except ValueError: try: - raise TypeError # noqa: B904 + raise TypeError finally: pass @@ -730,7 +754,7 @@ def fn(): raise ValueError except ValueError: try: - raise TypeError # noqa: B904 + raise TypeError finally: pass @@ -783,7 +807,7 @@ def fn(t): raise GeneratorExit except Exception: return t.sin() - except BaseException: # noqa: B036 + except BaseException: return t.cos() t = torch.randn(2) @@ -1027,6 +1051,309 @@ def f(x): self.assertIn("in g", str(ctx.exception)) self.assertIn('raise Exception("Invalid")', str(ctx.exception)) + def test_str_repr_exception_no_args(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + try: + raise ValueError + except ValueError as e: + return t.sin(), str(e), repr(e) + + t = torch.randn(2) + y, s, r = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(s, "") + self.assertEqual(r, "ValueError()") + + def test_str_repr_exception_single_arg(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + try: + raise ValueError("test error") + except ValueError as e: + return t.sin(), str(e), repr(e) + + t = torch.randn(2) + y, s, r = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(s, "test error") + self.assertEqual(r, "ValueError('test error')") + + def test_str_repr_exception_multi_args(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + try: + raise ValueError("hello", 42) + except ValueError as e: + return t.sin(), str(e), repr(e) + + t = torch.randn(2) + y, s, r = fn(t) + self.assertEqual(y, t.sin()) + self.assertEqual(s, str(("hello", 42))) + self.assertEqual(r, "ValueError('hello', 42)") + + def test_frozen_dataclass_setattr_raises(self): + @dataclasses.dataclass(frozen=True) + class TestDataClass: + x: int + + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + dc = TestDataClass(1) + try: + dc.x = 2 + except dataclasses.FrozenInstanceError: + return t + 1 + except Exception: + return t + 2 + return t + dc.x + + self.assertEqual(fn(torch.zeros(1)), 1) + + def test_exception_traceback_access(self): + # Test that __traceback__ is accessible after raising/catching an exception + def fn(x): + try: + raise ValueError("oops") + except ValueError as e: + tb = e.__traceback__ + if tb is not None: + x = x + 1 + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_traceback_tb_next(self): + # Test that tb_next can be accessed on a traceback + def fn(x): + try: + raise ValueError("oops") + except ValueError as e: + tb = e.__traceback__ + if tb is not None: + # tb_next is None for a single-frame traceback + if tb.tb_next is None: + x = x + 1 + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_traceback_tb_lineno(self): + # Test that tb_lineno is accessible on a traceback + def fn(x): + try: + raise ValueError("oops") + except ValueError as e: + tb = e.__traceback__ + if tb is not None and tb.tb_lineno > 0: + x = x + 1 + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_with_traceback_method(self): + # Test the with_traceback() method + def fn(x): + try: + raise ValueError("first") + except ValueError as e: + tb = e.__traceback__ + try: + raise RuntimeError("second").with_traceback(tb) from None + except RuntimeError as e2: + if e2.__traceback__ is not None: + x = x + 1 + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_set_traceback(self): + # Test assigning __traceback__ on an exception + def fn(x): + try: + raise ValueError("first") + except ValueError as e: + tb = e.__traceback__ + try: + raise RuntimeError("second") from None + except RuntimeError as e2: + e2.__traceback__ = tb + if e2.__traceback__ is not None: + x = x + 1 + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_set_traceback_none(self): + # Test assigning None to __traceback__ + def fn(x): + try: + raise ValueError("oops") + except ValueError as e: + e.__traceback__ = None + if e.__traceback__ is None: + x = x + 1 + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_traceback_tb_lasti_graph_break(self): + # Accessing tb_lasti should cause a graph break + def fn(x): + try: + raise ValueError("oops") + except ValueError as e: + tb = e.__traceback__ + if tb is not None: + _ = tb.tb_lasti + x = x + 1 + return x + + x = torch.randn(4) + ref = fn(x) + # Should graph break but still produce correct results + opt_fn = torch.compile(fn, backend="eager") + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_set_tb_next(self): + # Test setting tb_next on a traceback + def fn(x): + try: + raise ValueError("first") + except ValueError as e: + tb1 = e.__traceback__ + try: + raise RuntimeError("second") from None + except RuntimeError as e2: + tb2 = e2.__traceback__ + if tb2 is not None and tb1 is not None: + tb2.tb_next = None + x = x + 1 + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_exception_traceback_chained(self): + # Test traceback chaining through multiple frames + def inner(): + raise ValueError("inner") + + def fn(x): + try: + inner() + except ValueError as e: + tb = e.__traceback__ + if tb is not None: + x = x + 1 + # Walk the traceback chain + depth = 0 + curr = tb + while curr is not None: + depth += 1 + curr = curr.tb_next + if depth > 0: + x = x + 1 + return x + + x = torch.randn(4) + ref = fn(x) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + + @parametrize( + "exc_type_1,exc_type_2", + [ + (ValueError, TypeError), + (CustomException, ValueError), + ], + name_fn=lambda exc1, exc2: f"{exc1.__name__}_to_{exc2.__name__}", + ) + def test_exception_set_context(self, exc_type_1, exc_type_2): + # Test explicitly assigning to __context__ attribute (reaches ExceptionVariable.__context__ assignment) + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + exc1 = exc_type_1("first") + exc2 = exc_type_2("second") + + # This explicitly sets __context__ via call_setattr + exc2.__context__ = exc1 + + # Verify it was set correctly + if exc2.__context__ is exc1: + return t.sin() + else: + return t.cos() + + t = torch.randn(2) + ref_result = t.sin() + result = fn(t) + self.assertEqual(result, ref_result) + + @parametrize( + "exc_type_1,exc_type_2,exc_type_3", + [ + (ValueError, TypeError, RuntimeError), + (CustomException, ValueError, TypeError), + ], + name_fn=lambda exc1, exc2, exc3: ( + f"{exc1.__name__}_chain_{exc2.__name__}_{exc3.__name__}" + ), + ) + def test_exception_context_chain(self, exc_type_1, exc_type_2, exc_type_3): + # Test chaining contexts through multiple exceptions + @torch.compile(backend="eager", fullgraph=True) + def fn(t): + exc1 = exc_type_1("first") + exc2 = exc_type_2("second") + exc3 = exc_type_3("third") + + exc2.__context__ = exc1 + exc3.__context__ = exc2 + + # Verify the chain + if isinstance(exc3.__context__, exc_type_2) and isinstance( + exc3.__context__.__context__, exc_type_1 + ): + return t.sin() + else: + return t.cos() + + t = torch.randn(2) + ref_result = t.sin() + result = fn(t) + self.assertEqual(result, ref_result) + instantiate_parametrized_tests(ExceptionTests) diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 4e9b09fcaddee..aaf6a99c70609 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -130,6 +130,15 @@ def forward(self, x, y): return pytree.tree_unflatten([x], self._out_spec)""", ) + def test_export_empty_graph_no_error(self): + def func(x): + return len(x) + + exported = torch._dynamo.export(func)(torch.randn(5)) + out_graph = exported[0] + result = out_graph(torch.randn(5)) + self.assertEqual(result, 5) + def test_no_tensor_computation_2(self): inp = torch.randn(3) inp2 = 2 @@ -200,7 +209,7 @@ def func(x): hit = True self.assertExpectedInline( guard.code_list, - """["L['x'].stride()[0] == L['x'].size()[1]", "L['x'].stride()[1] == 1", "L['x'].storage_offset() == 0", "2 <= L['x'].size()[0] and L['x'].size()[0] <= 10", "2 <= L['x'].size()[1]"]""", # noqa: B950 + """["L['x'].stride()[0] == L['x'].size()[1]", "L['x'].stride()[1] == 1", "L['x'].storage_offset() == 0", "2 <= L['x'].size()[0] and L['x'].size()[0] <= 10", "2 <= L['x'].size()[1]"]""", ) break @@ -1931,7 +1940,7 @@ def forward(self, x): ge = sym_size_int_1 >= 2; sym_size_int_1 = None _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 2 on node 'ge'"); ge = _assert_scalar_default = None getitem_2 = cond[0]; cond = None - return pytree.tree_unflatten([getitem_2], self._out_spec)""", # noqa: B950 + return pytree.tree_unflatten([getitem_2], self._out_spec)""", ) self.assertExpectedInline( out_graph.cond_true_0.code.strip(), @@ -3839,7 +3848,7 @@ def forward(self, pred, x): cond_false_0 = self.cond_false_0 cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, (a, b, l_x_, d, c)); l_pred_ = cond_true_0 = cond_false_0 = a = b = l_x_ = d = c = None getitem = cond[0]; cond = None - return pytree.tree_unflatten([getitem], self._out_spec)""", # noqa: B950,E122 + return pytree.tree_unflatten([getitem], self._out_spec)""", ) self.assertExpectedInline( @@ -4369,7 +4378,7 @@ def bad_fn(x): expected = [ """x = torch.sin(l_x_)""", - """cos = torch.cos(l_stack0_)""", + """cos = torch.cos(l_nested_frame_values_0_1_)""", ] def test_backend(gm: torch.fx.GraphModule, example_inputs): @@ -4404,7 +4413,7 @@ def forward(self, x): _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) add = l_args_0_ + 1; l_args_0_ = None _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None - return pytree.tree_unflatten([add], self._out_spec)""", # NOQA: B950 + return pytree.tree_unflatten([add], self._out_spec)""", ) self.assertEqual(out.requires_grad, False) with self.assertRaisesRegex( @@ -4427,7 +4436,7 @@ def forward(self, x): _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(False) add = l_args_0_ + 1; l_args_0_ = None _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None - return pytree.tree_unflatten([add], self._out_spec)""", # NOQA: B950 + return pytree.tree_unflatten([add], self._out_spec)""", ) inp = torch.randn(2, 2) @@ -4449,7 +4458,7 @@ def forward(self, x): _enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True) add = l_x_ + 1; l_x_ = None _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None - return pytree.tree_unflatten([add], self._out_spec)""", # NOQA: B950 + return pytree.tree_unflatten([add], self._out_spec)""", ) inp = torch.randn(2, 2, requires_grad=True) out = gm(inp) @@ -4502,7 +4511,7 @@ def forward(self, x, b, y): x = l_x_.clone(); l_x_ = None x[l_b_] = l_y_; setitem = x; l_b_ = l_y_ = setitem = None _exit_inference_mode = torch.autograd.grad_mode._exit_inference_mode(_enter_inference_mode); _enter_inference_mode = _exit_inference_mode = None - return pytree.tree_unflatten([x], self._out_spec)""", # NOQA: B950 + return pytree.tree_unflatten([x], self._out_spec)""", ) gm, _ = torch._dynamo.export(fn)(x, b, y) diff --git a/test/dynamo/test_fake_distributed.py b/test/dynamo/test_fake_distributed.py index bb534e1a4f80d..fd05e6bc67940 100644 --- a/test/dynamo/test_fake_distributed.py +++ b/test/dynamo/test_fake_distributed.py @@ -27,12 +27,14 @@ def normalize_graph(gm): @skipIf(not dist.is_available(), "requires distributed") class TestFakeDistributed(DynamoTestCase): def setUp(self): + super().setUp() # Use FakeProcessGroup to run tests on a single process dist.init_process_group(backend="fake", rank=0, world_size=2) self.local_rank = 0 self.world_size = 2 def tearDown(self): + super().tearDown() dist.destroy_process_group() def test_all_to_all_single_autograd(self): @@ -65,7 +67,7 @@ def forward(self, primals_1: "Sym(s77)", primals_2: "Sym(s27)", primals_3: "f32[ wait_tensor: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single); all_to_all_single = None return (wait_tensor, primals_1, primals_2, floordiv) -""", # noqa: B950 +""", ) self.assertExpectedInline( normalize_graph(backend.bw_graphs[0]), @@ -75,7 +77,7 @@ def forward(self, primals_1: "Sym(s77)", primals_2: "Sym(s27)", floordiv: "Sym(( all_to_all_single_1: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.all_to_all_single.default(tangents_1, [floordiv, floordiv], [floordiv, floordiv], '0'); tangents_1 = floordiv = None wait_tensor_1: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_1); all_to_all_single_1 = None return (None, None, wait_tensor_1) -""", # noqa: B950 +""", ) backend.fw_graphs.clear() @@ -106,7 +108,7 @@ def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2 wait_tensor: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single); all_to_all_single = None return (wait_tensor, primals_1, primals_2, primals_3, floordiv) -""", # noqa: B950 +""", ) self.assertExpectedInline( normalize_graph(backend.bw_graphs[0]), @@ -116,7 +118,7 @@ def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2 all_to_all_single_1: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.all_to_all_single.default(tangents_1, [floordiv, floordiv], [floordiv, floordiv], '0'); tangents_1 = floordiv = None wait_tensor_1: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_1); all_to_all_single_1 = None return (None, None, None, wait_tensor_1) -""", # noqa: B950 +""", ) def test_device_mesh_get_local_rank(self): @@ -158,6 +160,7 @@ def fn(x): res = fn(x) self.assertEqual(res, (x + 1, [0])) + @torch._dynamo.config.patch(nested_graph_breaks=False) def test_device_mesh_init_skip_after_graph_break(self): device_mesh = init_device_mesh( device_type="cpu", diff --git a/test/dynamo/test_flat_apply.py b/test/dynamo/test_flat_apply.py index 6583efb033a05..10891dc3c94cf 100644 --- a/test/dynamo/test_flat_apply.py +++ b/test/dynamo/test_flat_apply.py @@ -55,7 +55,7 @@ class InputData: values: Tensor -torch.utils._pytree.register_dataclass(InputData) +pytree.register_dataclass(InputData) @dataclass @@ -73,7 +73,7 @@ class OutputData: result2: Tensor -torch.utils._pytree.register_dataclass(OutputData) +pytree.register_dataclass(OutputData) @dataclass @@ -198,7 +198,7 @@ def forward(self, L_x_: "f32[10]", L_y_: "f32[10]"): flat_apply_capture = torch__dynamo_variables_torch_flat_apply_capture(trace_point_tensor_callable, trace_point_tensor_input_spec, l_x_, l_y_, t); trace_point_tensor_callable = trace_point_tensor_input_spec = l_x_ = l_y_ = t = None res: "f32[10]" = flat_apply_capture[0]; flat_apply_capture = None return (res,) -""", # NOQA: B950 +""", ) def test_nonstrict_trace_captured_tensor_post_aot_graph(self): @@ -225,7 +225,7 @@ def forward(self, arg0_1: "f32[10]", arg1_1: "f32[10]"): _tensor_constant0: "f32[1]" = self._tensor_constant0 add: "f32[10]" = torch.ops.aten.add.Tensor(mul, _tensor_constant0); mul = _tensor_constant0 = None return (add,) -""", # NOQA: B950 +""", ) @@ -281,7 +281,7 @@ def forward(self, L_i_values: "f32[4, 4]"): add_1: "f32[4, 4]" = add + z_result1; add = z_result1 = None add_2: "f32[4, 4]" = add_1 + z_result2; add_1 = z_result2 = None return (add_2,) -""", # NOQA: B950 +""", ) def test_dataclass_input(self): @@ -334,7 +334,7 @@ def forward(self, L_i_values: "f32[4, 4]"): add_1: "f32[4, 4]" = add + z_result1; add = z_result1 = None add_2: "f32[4, 4]" = add_1 + z_result2; add_1 = z_result2 = None return (add_2,) -""", # NOQA: B950 +""", ) def test_invalid_input(self): @@ -416,7 +416,7 @@ def forward(self, L_i_values: "f32[4, 4]"): add_1: "f32[4, 4]" = add + value; add = value = None add_2: "f32[4, 4]" = add_1 + value_1; add_1 = value_1 = None return (add_2,) -""", # NOQA: B950 +""", ) def test_invalid_output(self): diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 35e46e0e4824e..c62cc88c6d2b6 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -31,11 +31,10 @@ EagerAndRecordGraphs, normalize_gm, ) -from torch._dynamo.utils import ifdynstaticdefault, range_iterator, same +from torch._dynamo.utils import counters, ifdynstaticdefault, range_iterator, same from torch._dynamo.variables import ConstantVariable, SkipFunctionVariable from torch._dynamo.variables.lists import RangeVariable from torch.nn import functional as F -from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -153,7 +152,7 @@ def inline_script_if_tracing_fn_with_default_args(x, y, c=1.2): return torch.cos(x * y) + c -class FunctionTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks): +class FunctionTests(torch._dynamo.test_case.TestCase): @make_test def test_inline_jit_annotations(x): x = inline_script_if_tracing(x) @@ -492,9 +491,9 @@ def test_cls_eq(a, b): @make_test def test_obj_is(a, b): v = a + b - if MyCls() is None: # noqa: E711 + if MyCls() is None: return -1 - if MyCls() is not None: # noqa: E711 + if MyCls() is not None: v = v.sin() if MyCls() is MyCls(): return -2 @@ -505,9 +504,9 @@ def test_obj_is(a, b): @make_test def test_cls_is(a, b): v = a + b - if MyCls is None: # noqa: E711 + if MyCls is None: return -1 - if MyCls is not None: # noqa: E711 + if MyCls is not None: v = v.sin() if MyCls is not MyCls: return -2 @@ -1754,11 +1753,11 @@ def test_default_dict_set(self): self._test_default_dict_helper(set) def test_default_dict_lambda(self): - self._test_default_dict_helper(lambda: dict()) # noqa: C408 + self._test_default_dict_helper(lambda: dict()) def test_default_dict_closure(self): def factory(): - return dict() # noqa: C408 + return dict() self._test_default_dict_helper(factory) @@ -1785,7 +1784,7 @@ def test_default_dict_constr(self): param = torch.nn.Parameter(torch.ones([2, 2])) def fn(x): - dd = collections.defaultdict(lambda: dict()) # noqa: C408 + dd = collections.defaultdict(lambda: dict()) dd["a"] = x + 1 dd[param] = 123 dd["c"] = x * 2 @@ -1824,7 +1823,7 @@ def fn(x, y): @make_test def test_call_dict1(x): - d1 = dict() # noqa: C408 + d1 = dict() d1["x"] = x + 1 d2 = collections.OrderedDict() d2["x"] = x + 2 @@ -1832,7 +1831,7 @@ def test_call_dict1(x): @make_test def test_call_dict2(x): - d1 = dict() # noqa: C408 + d1 = dict() d1["x"] = x d2 = collections.OrderedDict(d1) if isinstance(d2, collections.OrderedDict): @@ -3008,7 +3007,6 @@ def fn(y): else: return x.cos() - @unittest.expectedFailure def test_getattr_metaclass(self): class Meta(type): def __getattr__(cls, name): @@ -3701,6 +3699,25 @@ def fn(): opt_fn = torch.compile(fn, fullgraph=True, backend="eager") self.assertEqual(opt_fn(), fn()) + def test_operator_concat(self): + for seq_type in (list, tuple): + with self.subTest(seq_type=seq_type): + + def fn(a, b): + return operator.concat(a, b) + + opt_fn = torch.compile(fn, fullgraph=True) + a = seq_type([1, 2, 3]) + b = seq_type([4, 5, 6]) + self.assertEqual(opt_fn(a, b), fn(a, b)) + + def test_operator_iconcat(self): + def fn(a, b): + return operator.iconcat(a, b) + + opt_fn = torch.compile(fn, fullgraph=True) + self.assertEqual(opt_fn([1, 2, 3], [4, 5, 6]), [1, 2, 3, 4, 5, 6]) + def test_attrgetter(self): for attrs in ( ("shape",), @@ -4038,6 +4055,20 @@ def fn(t): g = torch.compile(fn, backend="eager", fullgraph=True)(t) self.assertEqual(e, g) + @unittest.skipIf(sys.platform == "darwin", "No mkldnn on MacOS") + def test_quantize_per_tensor(self): + def fn(t, scale, zero_point): + return torch.quantize_per_tensor(t, scale, zero_point, torch.quint8) + + scale = torch.tensor(2.0) + zero_point = torch.tensor(10.0) + t = torch.rand((2, 2)) * scale + zero_point + + result = fn(t, scale, zero_point) + compiled_fn = torch.compile(fn, fullgraph=True) + compiled_result = compiled_fn(t, scale, zero_point) + self.assertEqual(compiled_result, result) + def test_map_return(self): def fn(a, b): return map(lambda x: x + 1, [a, b]) @@ -4628,7 +4659,7 @@ def forward(self): return self.m() -class DefaultsTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks): +class DefaultsTests(torch._dynamo.test_case.TestCase): def test_func_default_tensor_args(self): """ Tests that we indeed reference (and mutate) "the one" default tensor arg @@ -4713,6 +4744,38 @@ def test_meth_default_tensor_args(self): self.assertEqual(cnts.frame_count, 3) self.assertEqual(cnts.op_count, 6) + def test_guard_on_constant_func_defaults(self): + """ + When a compiled function is re-invoked with a closure whose + __code__ is the same but __defaults__ differ (e.g. a different + constant default arg), Dynamo must recompile instead of reusing + the stale graph. + """ + cnts = torch._dynamo.testing.CompileCounter() + + def make_adder(offset): + def adder(x, _offset=offset): + return x + _offset + + return adder + + @torch.compile(backend=cnts) + def call_adder(x, fn): + return fn(x) + + x = torch.ones(4) + + adder0 = make_adder(0) + result0 = call_adder(x, adder0) + self.assertEqual(result0, x + 0) + self.assertEqual(cnts.frame_count, 1) + + # Same __code__, different __defaults__ → must recompile + adder5 = make_adder(5) + result5 = call_adder(x, adder5) + self.assertEqual(result5, x + 5) + self.assertEqual(cnts.frame_count, 2) + def test_func_default_torch_args(self): """ Tests other types of torch types as function default (size, dtype, device) @@ -4802,6 +4865,47 @@ def fn(x): ref = opt_fn(x) self.assertEqual(ref, res) + def test_pydantic_dataclass_construction(self): + @torch._dynamo.disable + def populate(self, x, y): + self.x = x + self.y = y + + @dataclass(init=False) + class Point: + x: torch.Tensor + y: torch.Tensor + # Pydantic uses this sentinel on decorated dataclasses. + __is_pydantic_dataclass__ = True + + def __init__(self, x, y): + populate(self, x, y) + + def fn(x, y): + p = Point(x=x, y=y) + return p.x + p.y + + torch._dynamo.reset() + counters.clear() + cnts = torch._dynamo.testing.CompileCounter() + compiled_fn = torch.compile(fn, backend=cnts) + x = torch.randn(4) + y = torch.randn(4) + + self.assertTrue(same(fn(x, y), compiled_fn(x, y))) + self.assertEqual(cnts.frame_count, 0) + self.assertEqual(cnts.op_count, 0) + # Skipping the whole frame records a second follow-on graph break, so + # assert on the specific pydantic entry rather than the raw count. + self.assertEqual( + [ + count + for msg, count in counters["graph_break"].items() + if "Pydantic dataclass constructor" in msg + ], + [1], + ) + def test_listlike_of_tensors_contains_constant(self): for listlike in [set, list]: @@ -5268,28 +5372,6 @@ def g(x, y): self.assertEqual(f(torch.ones(3, 3)), opt_f(torch.ones(3, 3))) self.assertEqual(cnts.frame_count, 3) - @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") - def test_gpu_current_device(self): - def fn(x): - y = torch.empty( - (2, 3), - dtype=torch.float32, - device=torch.accelerator.current_device_index(), - ) - y.copy_(x) - return torch.sin(y + y.device.index) - - counter = torch._dynamo.testing.CompileCounter() - opt_fn = torch.compile(backend=counter, fullgraph=True)(fn) - - with torch.accelerator.device_index(0): - x = torch.randn(2, 3) - self.assertEqual(opt_fn(x), fn(x)) - self.assertEqual(counter.frame_count, 1) - with torch.accelerator.device_index(1): - self.assertEqual(opt_fn(x), fn(x)) - self.assertEqual(counter.frame_count, 2) - def test_fn_with_attr(self): def fn(x): if fn.pred: @@ -5560,7 +5642,7 @@ def fn(x, klass): self.assertTrue(isinstance(res, tuple)) def test_udf_list(self): - class MyList(list): # noqa: SLOT001 + class MyList(list): def len_mulitply_2(self): return len(self) * 2 @@ -5594,7 +5676,7 @@ def fn(x, lst): self.assertTrue(res_lst.checked) def test_udf_list_slice(self): - class MyList(list): # noqa: SLOT001 + class MyList(list): def len_mulitply_2(self): return len(self) * 2 @@ -5612,7 +5694,7 @@ def fn(x, lst): self.assertEqual(len(ref_lst), len(res_lst)) def test_udf_list_reconstruction(self): - class MyList(list): # noqa: SLOT001 + class MyList(list): # def __new__(cls, *args, **kwargs): # return super().__new__(cls, *args, **kwargs) pass @@ -5706,7 +5788,7 @@ def get_torch_functional_functions(): self.assertTrue(callable(compiled_func)) def test_skip_function_call_very_weird_value(self): - class weird: # noqa: UP004 + class weird: def __getattribute__(self, name): if name == "__qualname__": raise AttributeError("test") diff --git a/test/dynamo/test_fwd_loss_bwd.py b/test/dynamo/test_fwd_loss_bwd.py index d33ff2559581c..603617c830c34 100644 --- a/test/dynamo/test_fwd_loss_bwd.py +++ b/test/dynamo/test_fwd_loss_bwd.py @@ -94,7 +94,7 @@ def forward(self, L_mod_parameters_weight_: "f32[4, 4]", L_mod_parameters_bias_: detach: "f32[]" = loss.detach(); loss = None return (detach, getitem, getitem_1) -""", # noqa: B950 +""", ) self.assertEqual(len(backend.fw_graphs), 1) @@ -122,7 +122,7 @@ def forward(self, arg0_1: "f32[4, 4]", arg1_1: "f32[4]", arg2_1: "f32[2, 4]"): detach: "f32[]" = torch.ops.aten.detach.default(sum_1); sum_1 = None return (detach, t_3, view) -""", # noqa: B950 +""", ) @skipIfCrossRef @@ -169,7 +169,7 @@ def forward(self, L_mod_parameters_weight_: "f32[4, 4]", L_mod_parameters_bias_: detach: "f32[]" = loss.detach(); loss = None return (detach,) -""", # noqa: B950 +""", ) self.assertEqual(len(backend.fw_graphs), 1) @@ -188,7 +188,7 @@ def forward(self, arg0_1: "f32[4, 4]", arg1_1: "f32[4]", arg2_1: "f32[2, 4]"): detach: "f32[]" = torch.ops.aten.detach.default(sum_1); sum_1 = None return (detach,) -""", # noqa: B950 +""", ) @skipIfCrossRef @@ -231,7 +231,7 @@ def forward(self, L_mod_parameters_weight_: "f32[4, 4]", L_mod_parameters_bias_: detach: "f32[]" = loss.detach(); loss = None return (detach, getitem) -""", # noqa: B950 +""", ) self.assertEqual(len(backend.fw_graphs), 1) @@ -257,7 +257,7 @@ def forward(self, arg0_1: "f32[4, 4]", arg1_1: "f32[4]", arg2_1: "f32[2, 4]"): detach: "f32[]" = torch.ops.aten.detach.default(sum_1); sum_1 = None return (detach, t_3) -""", # noqa: B950 +""", ) @skipIfCrossRef @@ -345,7 +345,7 @@ def fn(external_input): Hint: Otherwise, move the autograd.grad() call outside the compiled region. Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. - Developer debug context: inputs with external grad_fn: ["L['external_input']"]""" # noqa: B950 + Developer debug context: inputs with external grad_fn: ["L['external_input']"]""" ), ): fn(external_computation) @@ -402,7 +402,7 @@ def fn(ext): Hint: Otherwise, move the autograd.grad() call outside the compiled region. Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. - Developer debug context: inputs with external grad_fn: ["L['ext']"]""" # noqa: B950 + Developer debug context: inputs with external grad_fn: ["L['ext']"]""" ), ): fn(external) @@ -455,7 +455,7 @@ def forward(self, L_mod_parameters_weight_: "f32[4, 4]", L_mod_parameters_bias_: detach: "f32[]" = grad_norm.detach(); grad_norm = None sin: "f32[2, 4]" = l_x_.sin(); l_x_ = None return (detach, sin) -""", # noqa: B950 +""", ) self.assertEqual(len(backend.fw_graphs), 1) @@ -488,7 +488,7 @@ def forward(self, primals_1: "f32[4, 4]", primals_2: "f32[4]", primals_3: "f32[2 detach: "f32[]" = torch.ops.aten.detach.default(add); add = None sin: "f32[2, 4]" = torch.ops.aten.sin.default(primals_3) return (detach, sin, primals_3) -""", # noqa: B950 +""", ) # Trigger backward to compile the backward graph @@ -507,7 +507,7 @@ def forward(self, primals_3: "f32[2, 4]", tangents_1: "f32[2, 4]"): cos: "f32[2, 4]" = torch.ops.aten.cos.default(primals_3); primals_3 = None mul: "f32[2, 4]" = torch.ops.aten.mul.Tensor(tangents_1, cos); tangents_1 = cos = None return (None, None, mul) -""", # noqa: B950 +""", ) @skipIfCrossRef @@ -541,8 +541,9 @@ def step(mod, x): """\ autograd.grad consumed returned tensor's grad_fn Explanation: torch.autograd.grad() consumes grad_fns that are needed by tensors returned from this compiled function. This would cause 'backward through graph a second time' errors. - Hint: If you don't need to backward through the returned tensor, call .detach() before returning: `return loss.detach()` - Hint: If you need to backward through the returned tensor, use retain_graph=True in autograd.grad().""" # noqa: B950 + The following returned tensors have consumed grad_fns: loss + Hint: Detach the problematic tensor(s) before returning: e.g. `loss.detach()` + Hint: If you need to backward through the returned tensor, use retain_graph=True in autograd.grad().""" ), ): step_compiled_fullgraph(torch.nn.Linear(4, 4), torch.randn(2, 4)) @@ -590,17 +591,29 @@ def fn(x): torch._dynamo.reset() compiled_fn = torch.compile(fn, fullgraph=True, backend="aot_eager") - msg = textwrap.dedent( - """\ -autograd.grad consumed returned tensor's grad_fn - Explanation: torch.autograd.grad() consumes grad_fns that are needed by tensors returned from this compiled function. This would cause 'backward through graph a second time' errors. - Hint: If you don't need to backward through the returned tensor, call .detach() before returning: `return loss.detach()` - Hint: If you need to backward through the returned tensor, use retain_graph=True in autograd.grad().""" # noqa: B950 - ) + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + r"autograd\.grad consumed returned tensor's grad_fn", + ): + compiled_fn(torch.randn(4, requires_grad=True)) + + def test_autograd_grad_leaked_tensor_names_in_error(self): + """Test that the error message includes the names of all leaked tensors.""" + torch._dynamo.reset() + + def fn(x): + a = x * 2 + b = x * 3 + z = (a + b).sum() + torch.autograd.grad(z, x) + # Both a and b have consumed grad_fns + return a, b + + compiled_fn = torch.compile(fn, fullgraph=True, backend="aot_eager") with self.assertRaisesRegex( torch._dynamo.exc.Unsupported, - re.escape(msg) + r"[\s\S]*", + r"Leaked output tensors:", ): compiled_fn(torch.randn(4, requires_grad=True)) @@ -645,7 +658,7 @@ def fn(edge, x): Hint: Or use tensor inputs directly instead of GradientEdge objects. Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. - Developer debug context: GradientEdge in outputs: L['edge']""" # noqa: B950 + Developer debug context: GradientEdge in outputs: L['edge']""" ) with self.assertRaisesRegex( @@ -679,7 +692,7 @@ def fn(edges, x): Hint: Or use tensor inputs directly instead of GradientEdge objects. Hint: It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues. - Developer debug context: GradientEdge in outputs[0]: L['edges'][0]""" # noqa: B950 + Developer debug context: GradientEdge in outputs[0]: L['edges'][0]""" ) with self.assertRaisesRegex( @@ -830,7 +843,7 @@ def fn(x): """\ autograd.grad with already consumed grad_fn Explanation: torch.autograd.grad() is trying to consume grad_fns that were already consumed by a previous autograd.grad() call. This would cause 'backward through graph a second time' errors at runtime. - Hint: Use retain_graph=True in the first autograd.grad() call if you need to compute gradients through the same graph multiple times.""" # noqa: B950 + Hint: Use retain_graph=True in the first autograd.grad() call if you need to compute gradients through the same graph multiple times.""" ) with self.assertRaisesRegex( @@ -881,7 +894,7 @@ def forward(self, L_mod_parameters_weight_: "f32[4, 4]", L_mod_parameters_bias_: detach: "f32[]" = loss.detach(); loss = None return (detach, new_grad_strided, new_grad_strided_1) -""", # noqa: B950 +""", ) @skipIfCrossRef @@ -970,7 +983,7 @@ def forward(self, L_mod_parameters_weight_: "f32[4, 4]", L_mod_parameters_bias_: detach: "f32[2, 4]" = res.detach(); res = None sum_1: "f32[]" = detach.sum(); detach = None return (sum_1, new_grad_strided, new_grad_strided_1) -""", # noqa: B950 +""", ) @skipIfCrossRef @@ -1031,7 +1044,7 @@ def forward(self, L_mod_parameters_weight_: "f32[4, 4]", L_mod_parameters_bias_: detach: "f32[]" = loss.detach(); loss = None return (detach, new_grad_strided, new_grad_strided_1) -""", # noqa: B950 +""", ) @skipIfCrossRef @@ -1107,7 +1120,7 @@ def forward(self, SYNTHETIC_LOCAL_tmp_0_: "f32[4, 4]", L_x_: "f32[2, 4]"): grad = torch.autograd.grad(loss, [w]); loss = w = None grad_1: "f32[4, 4]" = grad[0]; grad = None return (grad_1,) -""", # noqa: B950 +""", ) @skipIfCrossRef @@ -1131,7 +1144,7 @@ def fn(x): """\ backward() with in-graph created tensor Explanation: backward(inputs=[...]) with tensors created inside the compiled function is not yet supported. - Hint: Only pass tensors that are inputs to the compiled function or captured from outside""" # noqa: B950 + Hint: Only pass tensors that are inputs to the compiled function or captured from outside""" ), ): compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) @@ -1184,7 +1197,7 @@ def fn(x): """\ backward() with non-leaf tensor Explanation: backward(inputs=[...]) with non-leaf tensors is not yet supported. - Hint: Only pass leaf tensors (parameters, graph inputs) to backward(inputs=...)""" # noqa: B950 + Hint: Only pass leaf tensors (parameters, graph inputs) to backward(inputs=...)""" ), ): compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) @@ -1210,6 +1223,81 @@ def fn(a, b): self.assertTrue(ref is act) + def test_autograd_grad_lost_grad_fn_in_closure(self): + def f(x): + return (x**2).sum() + + x = torch.randn(4, requires_grad=True) + _, vjp_fn = torch.func.vjp(f, x) + + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + "_autograd_grad with lost grad_fn linkage", + ): + torch.compile(vjp_fn, backend="eager", fullgraph=True)(torch.ones(())) + + def test_autograd_grad_transform_closure_compiled_separately(self): + def f(x): + return (x**2).sum() + + x = torch.randn(4, requires_grad=True) + _, vjp_fn = torch.func.vjp(f, x) + v = torch.ones(()) + + eager_grad = vjp_fn(v) + + cnt = torch._dynamo.testing.CompileCounter() + compiled_grad = torch.compile(vjp_fn, backend=cnt)(v) + + self.assertEqual(compiled_grad, eager_grad) + self.assertEqual(cnt.frame_count, 0) + + @skipIfCrossRef + def test_autograd_grad_transform_compiled_end_to_end(self): + def f(x): + return (x**2).sum() + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + _, vjp_fn = torch.func.vjp(f, x) + return vjp_fn(torch.ones(())) + + x = torch.randn(4, requires_grad=True) + result = fn(x) + expected = torch.func.vjp(f, x)[1](torch.ones(())) + self.assertEqual(result, expected) + + def test_autograd_grad_multi_output_transform_closure(self): + def f(x): + return x * 2, x * 3 + + x = torch.randn(4, requires_grad=True) + _, vjp_fn = torch.func.vjp(f, x) + v = (torch.ones(4), torch.ones(4)) + + eager_grad = vjp_fn(v) + + cnt = torch._dynamo.testing.CompileCounter() + compiled_grad = torch.compile(vjp_fn, backend=cnt)(v) + self.assertEqual(compiled_grad, eager_grad) + + @skipIfCrossRef + def test_autograd_grad_inline_computation_no_graph_break(self): + def f(x): + return (x * 2).sum() + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + y, vjp_fn = torch.func.vjp(f, x) + return y, vjp_fn(torch.ones(())) + + x = torch.randn(4, requires_grad=True) + compiled_y, compiled_grad = fn(x) + eager_y, eager_vjp_fn = torch.func.vjp(f, x) + eager_grad = eager_vjp_fn(torch.ones(())) + self.assertEqual(compiled_y, eager_y) + self.assertEqual(compiled_grad, eager_grad) + if __name__ == "__main__": run_tests() diff --git a/test/dynamo/test_fx_annotate.py b/test/dynamo/test_fx_annotate.py index 5228a5978aef0..e21f564b52a60 100644 --- a/test/dynamo/test_fx_annotate.py +++ b/test/dynamo/test_fx_annotate.py @@ -1,5 +1,7 @@ # Owner(s): ["module: dynamo"] +import warnings + import torch import torch._dynamo.test_case import torch.fx.traceback as fx_traceback @@ -48,21 +50,21 @@ def forward(self, x): ('placeholder', 'l_x_', {'pp_stage': 0, 'fdsp_bucket': 0}) ('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0}) ('call_function', 'sub', {'pp_stage': 0}) -('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""", # noqa: B950 +('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""", ) self.assertExpectedInline( str(fw_metadata), """\ ('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0}) ('call_function', 'sub', {'pp_stage': 0}) -('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""", # noqa: B950 +('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""", ) self.assertExpectedInline( str(bw_metadata), """\ ('call_function', 'mul_1', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1}) ('call_function', 'cos', {'pp_stage': 0, 'fdsp_bucket': 0}) -('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950 +('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", ) def test_activation_checkpointing(self): @@ -93,17 +95,17 @@ def fn(x): ('get_attr', 'wrap_body_0', {'ac_sin': 0}) [('placeholder', 'l_x_', {'ac_sin': 0}), ('call_function', 'sin', {'ac_sin': 0}), ('output', 'output', {'ac_sin': 0})] ('call_function', 'tag_activation_checkpoint', {'ac_sin': 0}) -('call_function', 'ac', {'ac_sin': 0})""", # noqa: B950 +('call_function', 'ac', {'ac_sin': 0})""", ) self.assertExpectedInline( str(fw_metadata), - """('call_function', 'sin', {'ac_sin': 0})""", # noqa: B950 + """('call_function', 'sin', {'ac_sin': 0})""", ) self.assertExpectedInline( str(bw_metadata), """\ ('call_function', 'cos', {'ac_sin': 0}) -('call_function', 'mul', {'ac_sin': 0})""", # noqa: B950 +('call_function', 'mul', {'ac_sin': 0})""", ) def test_activation_checkpointing_annotation_inside(self): @@ -131,17 +133,17 @@ def fn(x): bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0]) self.assertExpectedInline( str(dynamo_metadata), - """[('call_function', 'p', {'stage': 0})]""", # noqa: B950 + """[('call_function', 'p', {'stage': 0})]""", ) self.assertExpectedInline( str(fw_metadata), - """('call_function', 'sin', {'stage': 0})""", # noqa: B950 + """('call_function', 'sin', {'stage': 0})""", ) self.assertExpectedInline( str(bw_metadata), """\ ('call_function', 'cos', {'stage': 0}) -('call_function', 'mul', {'stage': 0})""", # noqa: B950 +('call_function', 'mul', {'stage': 0})""", ) @requires_cuda_and_triton @@ -203,7 +205,7 @@ def fn(x): ('get_attr', 'mask_fn_0', {'compile_inductor': 0}) [('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})] ('call_function', 'flex_attention', {'compile_inductor': 0}) -('call_function', 'out', {'compile_inductor': 0})""", # noqa: B950 +('call_function', 'out', {'compile_inductor': 0})""", ) self.assertExpectedInline( str(fw_metadata), @@ -216,7 +218,7 @@ def fn(x): ('call_function', 'getitem', {'compile_inductor': 0}) ('call_function', 'getitem_1', {'compile_inductor': 0}) ('call_function', 'detach_1', {'compile_inductor': 0}) -('call_function', 'detach_3', {'compile_inductor': 0})""", # noqa: B950 +('call_function', 'detach_3', {'compile_inductor': 0})""", ) self.assertExpectedInline( str(bw_metadata), @@ -234,8 +236,58 @@ def fn(x): ('call_function', 'flex_attention_backward', {'compile_inductor': 0}) ('call_function', 'getitem_3', {'compile_inductor': 0}) ('call_function', 'getitem_4', {'compile_inductor': 0}) -('call_function', 'getitem_5', {'compile_inductor': 0})""", # noqa: B950 +('call_function', 'getitem_5', {'compile_inductor': 0})""", + ) + + @requires_cuda_and_triton + def test_flex_attention_backward_tag_does_not_leak(self): + from torch.fx.experimental.proxy_tensor import make_fx + from torch.fx.traceback import preserve_node_meta + + def causal_mask(batch, head, q_idx, kv_idx): + del batch, head + return q_idx >= kv_idx + + q = torch.randn( + 1, 2, 128, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + k = torch.randn( + 1, 2, 128, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + v = torch.randn( + 1, 2, 128, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + block_mask = create_block_mask(causal_mask, 1, 2, 128, 128, device="cuda") + + def fn(q, k, v, block_mask): + with fx_traceback.annotate({"ac_region_id": 0}): + y = flex_attention(q, k, v, block_mask=block_mask) + torch.autograd.grad(y, (q, k, v), torch.ones_like(y)) + return y.cos() + + warnings.filterwarnings( + "ignore", + message="flex_attention called without torch.compile", + ) + with ( + torch._dynamo.config.patch(error_on_nested_fx_trace=False), + torch.compiler._non_strict_tracing_context(), + torch.compiler._patch_autograd_grad(), + preserve_node_meta(), + ): + gm = make_fx(fn, record_stack_traces=True)(q, k, v, block_mask) + + backward_nodes = [ + node for node in gm.graph.nodes if node.meta.get("autograd_backward", False) + ] + self.assertTrue(backward_nodes) + + flex_nodes = gm.graph.find_nodes( + op="call_function", target=torch.ops.higher_order.flex_attention ) + self.assertEqual(len(flex_nodes), 1) + self.assertNotIn("autograd_backward", flex_nodes[0].meta) + self.assertEqual(flex_nodes[0].meta.get("custom", {}), {"ac_region_id": 0}) def test_as_decorator(self): class Mod(torch.nn.Module): @@ -270,21 +322,21 @@ def forward(self, x): ('placeholder', 'l_x_', {'pp_stage': 0, 'fdsp_bucket': 0}) ('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0}) ('call_function', 'sub', {'pp_stage': 0}) -('call_function', 'mul', {'pp_stage': 0})""", # noqa: B950 +('call_function', 'mul', {'pp_stage': 0})""", ) self.assertExpectedInline( str(fw_metadata), """\ ('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0}) ('call_function', 'sub', {'pp_stage': 0}) -('call_function', 'mul', {'pp_stage': 0})""", # noqa: B950 +('call_function', 'mul', {'pp_stage': 0})""", ) self.assertExpectedInline( str(bw_metadata), """\ ('call_function', 'mul_1', {'pp_stage': 0}) ('call_function', 'cos', {'pp_stage': 0, 'fdsp_bucket': 0}) -('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", # noqa: B950 +('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""", ) def test_graph_break(self): @@ -326,7 +378,7 @@ def forward(self, x, y): ('call_function', 'mul_1', {'moo': 0}) ('call_function', 'ge', {'moo': 0}) ('call_function', '_check', {'moo': 0}) -('call_function', 'mul', {'moo': 0})""", # noqa: B950 +('call_function', 'mul', {'moo': 0})""", ) diff --git a/test/dynamo/test_fx_graph_runnable.py b/test/dynamo/test_fx_graph_runnable.py index 9ff32cbc2e60a..5da16806e1d1a 100644 --- a/test/dynamo/test_fx_graph_runnable.py +++ b/test/dynamo/test_fx_graph_runnable.py @@ -203,6 +203,7 @@ def f(x): self._exec_and_verify_payload() @unittest.skipUnless(has_triton(), "Triton not available") + @torch._dynamo.config.patch(nested_graph_breaks=False) def test_user_defined_triton_kernel_autotune(self): def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: output = torch.ones(x.shape, device=x.device, dtype=x.dtype) @@ -549,6 +550,44 @@ def f(view, weights): torch.compile(f)(view, weights) self._exec_and_verify_payload() + @torch._dynamo.config.patch(assume_static_by_default=False) + def test_repeat_interleave_with_output_size(self): + def f(data, repeats, output_size): + indices = torch.repeat_interleave(repeats, output_size=output_size.item()) + return data[indices] + + num_segments = 128 + data = torch.randn(1000, 16) + repeats = torch.randint(5, 15, (num_segments,), dtype=torch.int64) + output_size = repeats.sum() + + torch.compile(f, dynamic=True)(data, repeats, output_size) + + self._exec_and_verify_payload() + + # Verify the payload contains the repeat_interleave fixup + payload = self.buffer.getvalue().strip() + self.assertIn("def forward", payload) + self.assertIn("repeat_interleave", payload) + # Verify the fixup code is present + self.assertIn("# Fixup: ensure sum(repeats) == output_size", payload) + self.assertIn("_repeats.fill_", payload) + + def test_repeat_interleave_with_constant_output_size(self): + def f(data, repeats): + # output_size is a constant, not a dynamic input + indices = torch.repeat_interleave(repeats, output_size=1280) + return data[indices] + + num_segments = 128 + data = torch.randn(1000, 16) + repeats = torch.full((num_segments,), 10, dtype=torch.int64) + + torch.compile(f)(data, repeats) + self._exec_and_verify_payload() + payload = self.buffer.getvalue().strip() + self.assertNotIn("# Fixup: ensure sum(repeats) == output_size", payload) + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Skip in fbcode/sandcastle") class TestFxGraphRunnableMultiProcessGroup(TestCase): diff --git a/test/dynamo/test_generator.py b/test/dynamo/test_generator.py index 1660c1d045d99..d10bf6314a091 100644 --- a/test/dynamo/test_generator.py +++ b/test/dynamo/test_generator.py @@ -17,7 +17,7 @@ ) -class GeneratorTestsBase(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks): +class GeneratorTestsBase(torch._dynamo.test_case.TestCase): def setUp(self): super().setUp() self._old = torch._dynamo.config.enable_faithful_generator_behavior diff --git a/test/dynamo/test_getitem.py b/test/dynamo/test_getitem.py new file mode 100644 index 0000000000000..96de001815576 --- /dev/null +++ b/test/dynamo/test_getitem.py @@ -0,0 +1,795 @@ +# Owner(s): ["module: dynamo"] +"""Tests for mp_subscript_impl: unified __getitem__ dispatch via vt_getitem in Dynamo. + +Tests exercise the vt_getitem → mp_subscript_impl path via operator.getitem(), +and the call_method("__getitem__") → mp_subscript_impl path via obj.__getitem__(). + +See TODO(follow-up) comments on each mp_subscript_impl override for remaining +CPython behavioral gaps. +""" + +import collections +import operator +import types +import unittest + +import torch +import torch._dynamo.test_case +import torch._dynamo.testing +from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON, HAS_GPU + + +class GetItemTests(torch._dynamo.test_case.TestCase): + def _compile(self, fn, *args): + return torch.compile(fn, backend="eager", fullgraph=True)(*args) + + # --- BaseListVariable (ListVariable) --- + + def test_list_int_index(self): + def fn(x): + items = [x, x + 1, x + 2] + return ( + operator.getitem(items, 0), + operator.getitem(items, 1), + operator.getitem(items, 2), + ) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_list_getitem_compiled_directly(self): + compiled = torch.compile(operator.getitem, backend="eager", fullgraph=True) + items = [10, 20, 30] + self.assertEqual(compiled(items, 0), 10) + self.assertEqual(compiled(items, 2), 30) + + def test_list_slice(self): + def fn(x): + items = [x, x + 1, x + 2] + full = operator.getitem(items, slice(None)) + partial = operator.getitem(items, slice(0, 2)) + single = operator.getitem(items, slice(1, 2)) + return full, partial, single + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_list_negative_index(self): + def fn(x): + items = [x, x + 1, x + 2] + return operator.getitem(items, -1) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_list_bool_index(self): + def fn(x): + items = [x, x + 1, x + 2] + return ( + operator.getitem(items, False), + operator.getitem(items, True), + ) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_list_invalid_index_type(self): + def fn(x): + items = [x, x + 1, x + 2] + return operator.getitem(items, "a") + + x = torch.randn(4) + with self.assertRaises(torch._dynamo.exc.Unsupported): + self._compile(fn, x) + + def test_list_index_via_index_dunder(self): + """Custom __index__ object used as list index — _PyIndex_Check + nb_index_impl.""" + + class Idx: + def __index__(self): + return 2 + + def fn(x): + items = [x, x + 1, x + 2] + return operator.getitem(items, Idx()) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + # --- BaseListVariable (TupleVariable) --- + + def test_tuple_int_index(self): + def fn(x): + items = (x, x + 1, x + 2) + return operator.getitem(items, 0) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_tuple_negative_index(self): + def fn(x): + items = (x, x + 1, x + 2) + return operator.getitem(items, -1) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_tuple_slice(self): + def fn(x): + items = (x, x + 1, x + 2) + full = operator.getitem(items, slice(None)) + partial = operator.getitem(items, slice(0, 2)) + single = operator.getitem(items, slice(1, 2)) + return full, partial, single + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_tuple_bool_index(self): + def fn(x): + items = (x, x + 1, x + 2) + return operator.getitem(items, False) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_tuple_invalid_index_type(self): + def fn(x): + items = (x, x + 1, x + 2) + return operator.getitem(items, "a") + + x = torch.randn(4) + with self.assertRaises(torch._dynamo.exc.Unsupported): + self._compile(fn, x) + + def test_tuple_index_via_index_dunder(self): + class Idx: + def __index__(self): + return 2 + + def fn(x): + items = (x, x + 1, x + 2) + return operator.getitem(items, Idx()) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + # --- RangeVariable --- + + def test_range_int_index(self): + def fn(x): + r = range(0, 10, 2) + return x + operator.getitem(r, 3) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_range_negative_index(self): + def fn(x): + r = range(0, 10, 2) + return x + operator.getitem(r, -1) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_range_slice(self): + def fn(x): + r = range(0, 10, 2) + result = operator.getitem(r, slice(1, 3)) + return x + result[0] + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_range_bool_index(self): + def fn(x): + r = range(0, 10, 2) + return x + operator.getitem(r, True) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_range_index_via_index_dunder(self): + class Idx: + def __index__(self): + return 2 + + def fn(x): + r = range(0, 10, 2) + return x + operator.getitem(r, Idx()) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_range_invalid_index_type(self): + def fn(x): + r = range(0, 10, 2) + return x + operator.getitem(r, "a") + + x = torch.randn(4) + with self.assertRaises(torch._dynamo.exc.Unsupported): + self._compile(fn, x) + + # --- SizeVariable --- + + def test_size_int_index(self): + def fn(x): + s = x.size() + return x + operator.getitem(s, 0) + + x = torch.randn(4, 8) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_size_negative_index(self): + def fn(x): + s = x.size() + return x + operator.getitem(s, -1) + + x = torch.randn(4, 8) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_size_slice(self): + def fn(x): + s = x.size() + result = operator.getitem(s, slice(0, 1)) + return x + result[0] + + x = torch.randn(4, 8) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_size_bool_index(self): + def fn(x): + s = x.size() + return x + operator.getitem(s, False) + + x = torch.randn(4, 8) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_size_index_via_index_dunder(self): + class Idx: + def __index__(self): + return 1 + + def fn(x): + s = x.size() + return x + operator.getitem(s, Idx()) + + x = torch.randn(4, 8) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_size_invalid_index_type(self): + def fn(x): + s = x.size() + return x + operator.getitem(s, "a") + + x = torch.randn(4, 8) + with self.assertRaises(torch._dynamo.exc.Unsupported): + self._compile(fn, x) + + # --- ConstDictVariable --- + + def test_dict_str_key(self): + def fn(x): + d = {"a": x, "b": x + 1} + return operator.getitem(d, "a") + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_dict_int_key(self): + def fn(x): + d = {0: x, 1: x + 1} + return operator.getitem(d, 1) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_dict_missing_key(self): + def fn(x): + d = {"a": x} + return operator.getitem(d, "missing") + + x = torch.randn(4) + with self.assertRaises(torch._dynamo.exc.Unsupported): + self._compile(fn, x) + + # --- DefaultDictVariable --- + + def test_defaultdict_existing_key(self): + def fn(x): + d = collections.defaultdict(lambda: x + 99) + d["a"] = x + return operator.getitem(d, "a") + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_defaultdict_missing_key(self): + def fn(x): + d = collections.defaultdict(list) + operator.getitem(d, "new") + return x + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + # --- TensorVariable --- + + def test_tensor_int_index(self): + def fn(x): + return operator.getitem(x, 0) + + x = torch.randn(4, 4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_tensor_slice(self): + def fn(x): + return operator.getitem(x, slice(0, 2)) + + x = torch.randn(4, 4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_tensor_tuple_index(self): + def fn(x): + return operator.getitem(x, (0, slice(1, 3))) + + x = torch.randn(4, 4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_tensor_getitem_torch_function_mode(self): + """TorchFunctionMode intercepts tensor __getitem__ and can modify behavior.""" + + class AddOneMode(torch.overrides.TorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + result = func(*args, **(kwargs or {})) + if func is torch.Tensor.__getitem__: + return result + 1 + return result + + def fn(x): + with AddOneMode(): + return operator.getitem(x, 0) + + x = torch.randn(4, 4) + expected = fn(x) + compiled = torch.compile(fn, backend="eager")(x) + self.assertEqual(expected, compiled) + + def test_tensor_getitem_torch_function_subclass(self): + """Tensor subclass with __torch_function__ intercepts __getitem__.""" + + class ScaledTensor(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + result = super().__torch_function__(func, types, args, kwargs or {}) + if func is torch.Tensor.__getitem__: + return result * 2 + return result + + def fn(x): + return operator.getitem(x, 0) + + x = ScaledTensor(torch.randn(4, 4)) + expected = fn(x) + compiled = torch.compile(fn, backend="eager")(x) + self.assertEqual(expected, compiled) + + # --- NamedTupleVariable (via UserDefinedTupleVariable) --- + + def test_namedtuple_int_index(self): + def fn(x): + result = torch.topk(x, 2) + return operator.getitem(result, 1) + + x = torch.randn(10) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_namedtuple_values_index(self): + def fn(x): + result = torch.topk(x, 2) + return operator.getitem(result, 0) + + x = torch.randn(10) + self.assertEqual(fn(x), self._compile(fn, x)) + + # --- TypingVariable --- + + def test_typing_subscript(self): + def fn(x): + t = list[int] # noqa: F841 + return x + 1 + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + # --- MappingProxyVariable --- + + def test_mappingproxy_getitem(self): + def fn(x): + d = {"a": 1, "b": 2} + proxy = types.MappingProxyType(d) + return x + operator.getitem(proxy, "a") + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + # --- NNModuleVariable (ModuleList) --- + + def test_nn_module_list_int_index(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList( + [torch.nn.Linear(4, 4) for _ in range(3)] + ) + + def forward(self, x): + return operator.getitem(self.layers, 1)(x) + + model = Model() + x = torch.randn(4) + compiled = torch.compile(model, backend="eager", fullgraph=True) + self.assertEqual(model(x), compiled(x)) + + # --- NNModuleVariable (ModuleDict) --- + + def test_nn_module_dict_str_key(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleDict({"fc": torch.nn.Linear(4, 4)}) + + def forward(self, x): + return operator.getitem(self.layers, "fc")(x) + + model = Model() + x = torch.randn(4) + compiled = torch.compile(model, backend="eager", fullgraph=True) + self.assertEqual(model(x), compiled(x)) + + # --- NNModuleVariable (Sequential) --- + + def test_nn_sequential_int_index(self): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.seq = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.ReLU(), + ) + + def forward(self, x): + return operator.getitem(self.seq, 0)(x) + + model = Model() + x = torch.randn(4) + compiled = torch.compile(model, backend="eager", fullgraph=True) + self.assertEqual(model(x), compiled(x)) + + # --- UserDefinedObjectVariable --- + + def test_user_defined_object_getitem(self): + class Container: + def __init__(self, items): + self.items = items + + def __getitem__(self, key): + return self.items[key] + + def fn(x): + c = Container([x, x + 1]) + return operator.getitem(c, 0) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_object_without_getitem(self): + class NoGetItem: + pass + + def fn(x): + obj = NoGetItem() + return operator.getitem(obj, 0) + + x = torch.randn(4) + with self.assertRaises(torch._dynamo.exc.Unsupported): + self._compile(fn, x) + + # --- UserDefinedListVariable --- + + def test_user_defined_list_getitem(self): + class MyList(list): + pass + + def fn(x): + items = MyList([x, x + 1, x + 2]) + return operator.getitem(items, 1) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + # --- UserDefinedTupleVariable --- + + def test_user_defined_tuple_getitem(self): + class MyTuple(tuple): # noqa: SLOT001 + pass + + def fn(x): + items = MyTuple((x, x + 1, x + 2)) + return operator.getitem(items, 1) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + # --- UserDefinedDictVariable --- + + def test_user_defined_dict_getitem(self): + class MyDict(dict): + pass + + def fn(x): + d = MyDict(a=x, b=x + 1) + return operator.getitem(d, "a") + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_user_defined_dict_missing(self): + class MyDict(dict): + def __missing__(self, key): + return 42 + + def fn(x): + d = MyDict(a=1) + return x + operator.getitem(d, "b") + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_user_defined_dict_custom_missing(self): + class DefaultDict(dict): + def __missing__(self, key): + self[key] = len(self) + return self[key] + + def fn(x): + d = DefaultDict() + d["a"] = 1 + val = operator.getitem(d, "b") + return x + val + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_collections_counter_getitem(self): + def fn(x): + c = collections.Counter({"a": 1, "b": 2}) + return x + operator.getitem(c, "a") + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + # --- UserDefined* with overridden __getitem__ --- + + def test_user_defined_dict_overridden_getitem(self): + """Dict subclass with custom __getitem__ should NOT delegate to _base_vt.""" + + class MyDict(dict): + def __getitem__(self, key): + return super().__getitem__(key) + 100 + + def fn(x): + d = MyDict(a=1, b=2) + return x + operator.getitem(d, "a") + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_user_defined_list_overridden_getitem(self): + """List subclass with custom __getitem__ should NOT delegate to _base_vt.""" + + class MyList(list): + def __getitem__(self, key): + return super().__getitem__(key) * 2 + + def fn(x): + items = MyList([x, x + 1, x + 2]) + return operator.getitem(items, 1) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_counter_missing_key(self): + """Counter.__missing__ returns 0 for missing keys.""" + + def fn(x): + c = collections.Counter({"a": 1, "b": 2}) + return x + operator.getitem(c, "missing") + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + # --- TorchScriptObjectVariable --- + + def test_opaque_object_getitem(self): + from torch._library.opaque_object import ( + MemberType, + OpaqueBase, + register_opaque_type, + ) + + class OpaqueScaler(OpaqueBase): + def __init__(self, scale): + self.scale = scale + + def apply(self, x): + return x * self.scale + + class OpaqueContainer(OpaqueBase): + def __init__(self, items): + self.items = items + + def __getitem__(self, idx): + return self.items[idx] + + register_opaque_type( + OpaqueScaler, + typ="reference", + members={ + "scale": MemberType.USE_REAL, + "apply": MemberType.INLINED, + }, + ) + register_opaque_type( + OpaqueContainer, + typ="reference", + members={ + "items": MemberType.USE_REAL, + "__getitem__": MemberType.INLINED, + }, + ) + + def fn(x, c): + scaler = operator.getitem(c, 0) + return scaler.apply(x) + + x = torch.randn(4) + c = OpaqueContainer([OpaqueScaler(2.0), OpaqueScaler(3.0)]) + compiled = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x, c), compiled(x, c)) + + # --- TritonKernelVariable --- + + @unittest.skipUnless(HAS_GPU and HAS_CUDA_AND_TRITON, "requires gpu and triton") + def test_triton_kernel_getitem_grid(self): + from torch.testing._internal.triton_utils import add_kernel + + def fn(x, y): + output = torch.zeros_like(x) + n_elements = output.numel() + grid = (n_elements // 256,) + bound = operator.getitem(add_kernel, grid) + bound(x, y, output, n_elements, BLOCK_SIZE=256) + return output + + x = torch.randn(256, device="cuda") + y = torch.randn(256, device="cuda") + compiled = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x, y), compiled(x, y)) + + # =================================================================== + # CPython behavioral gaps — expectedFailure until implemented + # =================================================================== + + # GAP 1: deque has only sq_item (int index), no mp_subscript. + # CPython: deque[slice] → TypeError "sequence index must be integer, not 'slice'" + # Dynamo: DequeVariable inherits BaseListVariable.mp_subscript_impl which accepts slices. + # TODO: DequeVariable should override mp_subscript_impl to reject slices, matching + # CPython's deque which only has sq_item (Modules/_collectionsmodule.c:1888). + @unittest.expectedFailure + def test_deque_slice_should_reject(self): + """deque does not support slicing in CPython — only sq_item (int index).""" + + def fn(x): + d = collections.deque([x, x + 1, x + 2]) + return operator.getitem(d, slice(0, 2)) + + x = torch.randn(4) + with self.assertRaises(TypeError): + self._compile(fn, x) + + # TODO: deque int index works but through the wrong dispatch path. + # CPython: PyObject_GetItem Branch 2 → _PyIndex_Check(key) → PyNumber_AsSsize_t → sq_item. + # Dynamo: inherited BaseListVariable.mp_subscript_impl (Branch 1 path). + # Result is correct, dispatch path diverges. Fix when sq_item branch is implemented. + def test_deque_int_index(self): + def fn(x): + d = collections.deque([x, x + 1, x + 2]) + return operator.getitem(d, 1) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + # GAP 2: dict_subscript calls _PyObject_HashFast → TypeError for unhashable keys. + # TODO: ConstDictVariable.mp_subscript_impl should check tp_hash and raise TypeError + # for unhashable keys, matching CPython's dict_subscript (Objects/dictobject.c:3680). + @unittest.expectedFailure + def test_dict_unhashable_key(self): + """dict[unhashable] should raise TypeError, not KeyError or silent failure.""" + + def fn(x): + d = {0: x, 1: x + 1} + return operator.getitem(d, [0]) + + x = torch.randn(4) + with self.assertRaises(TypeError): + self._compile(fn, x) + + # TODO: str/bytes subscript works via constant fold fallback (base mp_subscript_impl + # raises Unsupported → _make_handler → operator.getitem("hello", 0) evaluates at + # Python level), not via a proper mp_subscript_impl override mirroring CPython's + # unicode_subscript / bytes_subscript. Should add dedicated overrides on + # ConstantVariable to match CPython's dispatch path. + # CPython: Objects/unicodeobject.c:13809 (unicode_subscript) + # CPython: Objects/bytesobject.c (bytes_subscript) + + def test_str_subscript(self): + def fn(x): + s = "hello" + c = operator.getitem(s, 0) + return x + len(c) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_bytes_subscript(self): + def fn(x): + b = b"hello" + val = operator.getitem(b, 0) + return x + val + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + # =================================================================== + # Explicit __getitem__ dunder call path tests + # Exercises: obj.__getitem__(key) → LOAD_ATTR + CALL, which may + # route through call_method → mp_subscript_impl rather than + # vt_getitem → mp_subscript_impl. + # =================================================================== + + def test_list_dunder_getitem(self): + def fn(x): + items = [x, x + 1, x + 2] + return items.__getitem__(1) + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_dict_dunder_getitem(self): + def fn(x): + d = {"a": x, "b": x + 1} + return d.__getitem__("a") + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + def test_user_defined_dict_missing_dunder_getitem(self): + """__missing__ fallback must work via __getitem__ method call, not just operator.getitem.""" + + class MyDict(dict): + def __missing__(self, key): + return 42 + + def fn(x): + d = MyDict(a=1) + return x + d.__getitem__("b") + + x = torch.randn(4) + self.assertEqual(fn(x), self._compile(fn, x)) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_global.py b/test/dynamo/test_global.py index 119d56d674e24..7ac80a056b6f6 100644 --- a/test/dynamo/test_global.py +++ b/test/dynamo/test_global.py @@ -12,7 +12,7 @@ import utils -class Pair: # noqa: B903 +class Pair: def __init__(self, x, y): self.x = x self.y = y diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py index 09ef1a7edd9e7..80c279abc038c 100644 --- a/test/dynamo/test_graph_deduplication.py +++ b/test/dynamo/test_graph_deduplication.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -# flake8: noqa: B950 import contextlib import torch @@ -620,7 +619,7 @@ def fn(x, y): add_node.args = (args[0], add_2) self.assertExpectedInline( _detect_cycles(mod.graph, {add_2: OrderedSet([add_2])}), - """cycle detected in path: deque([output, add_2, add_2])""", + """cycle detected in path: [output, add_2, sum_1, add, add_2]""", ) def test_cycle_detection_two_node(self): @@ -644,7 +643,7 @@ def fn(x, y): mod.graph, {add_2: OrderedSet([add_node]), add_node: OrderedSet([add_2])}, ), - """cycle detected in path: deque([output, add_2, add, add_2])""", + """cycle detected in path: [output, add_2, sum_1, add, add_2]""", ) def test_cycle_detection_arg_and_additional_deps(self): @@ -665,7 +664,7 @@ def fn(x, y): add_node.args = (args[0], add_2) self.assertExpectedInline( _detect_cycles(mod.graph, {add_2: OrderedSet([add_node])}), - """cycle detected in path: deque([output, add_2, add, add_2])""", + """cycle detected in path: [output, add_2, sum_1, add, add_2]""", ) def test_cycle_detection_simple(self): @@ -676,7 +675,7 @@ def test_cycle_detection_simple(self): add_node.args = (args[0], add_2) self.assertExpectedInline( _detect_cycles(mod.graph, {}), - """cycle detected in path: deque([output, add_2, sum_1, add, add_2])""", + """cycle detected in path: [output, add_2, sum_1, add, add_2]""", ) def test_cycle_detection_complex(self): @@ -711,7 +710,7 @@ def fn(x, y): invoke_subgraph_node.args = (add_2, args[1]) self.assertExpectedInline( _detect_cycles(mod.graph, {}), - """cycle detected in path: deque([output, add_2, add_1, sum_1, getitem, invoke_subgraph, add_2])""", + """cycle detected in path: [output, add_2, add_1, sum_1, getitem, invoke_subgraph, add_2]""", ) def test_autocast_ordering(self): diff --git a/test/dynamo/test_guard_exclusion.py b/test/dynamo/test_guard_exclusion.py new file mode 100644 index 0000000000000..cfa84c7e302d2 --- /dev/null +++ b/test/dynamo/test_guard_exclusion.py @@ -0,0 +1,566 @@ +# Owner(s): ["module: dynamo"] +import torch +import torch._dynamo +import torch._dynamo.testing +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase + + +class GraphTracker: + """Backend that tracks which compiled graph (by compilation order) handles each call.""" + + def __init__(self): + self.graphs = [] + self.call_log = [] + + def __call__(self, gm, example_inputs): + graph_id = len(self.graphs) + self.graphs.append(gm) + + def wrapper(*args, **kwargs): + self.call_log.append(graph_id) + return gm.forward(*args, **kwargs) + + return wrapper + + @property + def frame_count(self): + return len(self.graphs) + + def reset(self): + self.graphs.clear() + self.call_log.clear() + + +@skipIfTorchDynamo("uses custom backend incompatible with PYTORCH_TEST_WITH_DYNAMO") +@torch._dynamo.config.patch(automatic_dynamic_exclusion_guard=True) +class TestGuardExclusion(TestCase): + def setUp(self): + super().setUp() + torch._dynamo.reset() + + def tearDown(self): + super().tearDown() + torch._dynamo.reset() + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_automatic_dynamic_exclusive_guard_basic(self): + """ + 1. [3, 4] -> Graph 0 (static) + 2. [5, 4] -> Graph 1 (dim 0 dynamic), exclusion rejects dim0==3 + 3. [7, 4] -> Graph 1 (reuse dynamic graph) + 4. [3, 4] -> Graph 0 (exclusion triggers, reverts to static) + """ + + def foo(x): + return x * 2 + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + x1 = torch.randn(3, 4) + result1 = opt(x1) + self.assertEqual(tracker.frame_count, 1) + self.assertEqual(tracker.call_log[-1], 0) + + x2 = torch.randn(5, 4) + result2 = opt(x2) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + # dynamic graph reuse + opt(torch.randn(7, 4)) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + # original shape reverts to Graph 0 + x3 = torch.randn(3, 4) + result3 = opt(x3) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 0) + + self.assertEqual(result1, x1 * 2) + self.assertEqual(result2, x2 * 2) + self.assertEqual(result3, x3 * 2) + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_accumulated_exclusion_does_not_shadow_intermediate_graph(self): + """ + Tensor accumulation: dims become dynamic one at a time. + 1. func(3, 4) -> Graph 0: static (3, 4) + 2. func(5, 4) -> Graph 1: (s0, 4), excluded dim0=3 + 3. func(3, 19) -> Graph 2: (s0, s1), excluded dim1=4 + (dim0's exclusion is cleared since no dim transitioned) + 4. func(5, 4) -> should use Graph 1, not Graph 2 + """ + + def foo(x): + return x * 2 + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + # Call 1: shape [3, 4] -> compiles Graph 0 (static) + opt(torch.randn(3, 4)) + self.assertEqual(tracker.frame_count, 1) + self.assertEqual(tracker.call_log[-1], 0) + + # Call 2: shape [5, 4] -> compiles Graph 1 (s0, 4) + opt(torch.randn(5, 4)) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + # Call 3: shape [3, 19] -> Graph 1 exclusion rejects (size(0)==3), + # Graph 0 rejects (19!=4), recompiles Graph 2 (s0, s1) + opt(torch.randn(3, 19)) + self.assertEqual(tracker.frame_count, 3) + self.assertEqual(tracker.call_log[-1], 2) + + # Call 4: shape [5, 4] -> should still use Graph 1 (s0, 4). + # Graph 2's exclusion is dim1=4, so (5, 4) is rejected and + # falls through to Graph 1. + opt(torch.randn(5, 4)) + + self.assertEqual( + tracker.call_log[-1], + 1, + "Input [5,4] should use Graph 1 (s0, 4), not Graph 2 (s0, s1). " + "Graph 2's exclusion must reject size(1)==4 independently, not " + "require size(0)==3 AND size(1)==4.", + ) + + # Call 5: shape [3, 4] -> should still use Graph 0 (static) + opt(torch.randn(3, 4)) + self.assertEqual( + tracker.call_log[-1], + 0, + "Input [3,4] should use Graph 0 (static)", + ) + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_4d_progressive_dynamism_cascading(self): + """ + 4D tensor where dims become dynamic one at a time across recompilations. + Each new graph is more general, and exclusion guards ensure inputs cascade + to the most specialized graph. + + Graph 0: (2, 3, 4, 5) static + Graph 1: (dyn, 3, 4, 5) excluded dim0=2 + Graph 2: (dyn, dyn, 4, 5) excluded dim1=3 + Graph 3: (dyn, dyn, dyn, 5) excluded dim2=4 + """ + + def foo(x): + return x.sum() + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + # Graph 0: (2, 3, 4, 5) static + opt(torch.randn(2, 3, 4, 5)) + self.assertEqual(tracker.frame_count, 1) + + # Graph 1: dim0 changes -> (dyn, 3, 4, 5) + opt(torch.randn(7, 3, 4, 5)) + self.assertEqual(tracker.frame_count, 2) + + # Graph 2: dim1 also changes -> (dyn, dyn, 4, 5) + # Input (7, 8, 4, 5): Graph 1 rejects dim1=8≠3, Graph 0 rejects dim0=7≠2 + opt(torch.randn(7, 8, 4, 5)) + self.assertEqual(tracker.frame_count, 3) + + # Graph 3: dim2 also changes -> (dyn, dyn, dyn, 5) + # Input (7, 8, 9, 5): Graph 2 rejects dim2=9≠4 + opt(torch.randn(7, 8, 9, 5)) + self.assertEqual(tracker.frame_count, 4) + + # Now verify cascading: each input routes to the most specialized graph. + # (2, 3, 4, 5) -> Graph 0 (static, most specialized) + opt(torch.randn(2, 3, 4, 5)) + self.assertEqual(tracker.call_log[-1], 0, "(2,3,4,5) -> Graph 0 (static)") + + # (7, 3, 4, 5) -> Graph 1 (dyn, 3, 4, 5) + opt(torch.randn(7, 3, 4, 5)) + self.assertEqual(tracker.call_log[-1], 1, "(7,3,4,5) -> Graph 1 (dyn,3,4,5)") + + # (7, 8, 4, 5) -> Graph 2 (dyn, dyn, 4, 5) + opt(torch.randn(7, 8, 4, 5)) + self.assertEqual(tracker.call_log[-1], 2, "(7,8,4,5) -> Graph 2 (dyn,dyn,4,5)") + + # (7, 8, 9, 5) -> Graph 3 (dyn, dyn, dyn, 5) + opt(torch.randn(7, 8, 9, 5)) + self.assertEqual( + tracker.call_log[-1], 3, "(7,8,9,5) -> Graph 3 (dyn,dyn,dyn,5)" + ) + + # (20, 30, 40, 5) -> Graph 3 (most general, no exclusion hit) + opt(torch.randn(20, 30, 40, 5)) + self.assertEqual( + tracker.call_log[-1], 3, "(20,30,40,5) -> Graph 3 (most general)" + ) + + self.assertEqual(tracker.frame_count, 4, "No additional recompilations") + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_5d_two_rounds_of_dynamism(self): + """ + 5D tensor with two rounds of automatic_dynamic. Verify inputs route + to the most specialized graph after each round. + + Graph 0: (2, 3, 4, 5, 6) static + Graph 1: (dyn, 3, 4, 5, 6) excluded dim0=2 + Graph 2: (dyn, 3, dyn, 5, 6) excluded dim2=4 + """ + + def foo(x): + return x * 2 + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + # Graph 0: static + opt(torch.randn(2, 3, 4, 5, 6)) + self.assertEqual(tracker.frame_count, 1) + + # Graph 1: dim0 becomes dynamic + opt(torch.randn(8, 3, 4, 5, 6)) + self.assertEqual(tracker.frame_count, 2) + + # Graph 2: dim2 also becomes dynamic + opt(torch.randn(8, 3, 10, 5, 6)) + self.assertEqual(tracker.frame_count, 3) + + # Verify routing: + # Original static shape -> Graph 0 + opt(torch.randn(2, 3, 4, 5, 6)) + self.assertEqual(tracker.call_log[-1], 0, "Original -> Graph 0") + + # dim0 differs, dim2 matches static -> Graph 1 + opt(torch.randn(9, 3, 4, 5, 6)) + self.assertEqual(tracker.call_log[-1], 1, "dim0 changed -> Graph 1") + + # dim0 differs, dim2 differs -> Graph 2 + opt(torch.randn(9, 3, 11, 5, 6)) + self.assertEqual(tracker.call_log[-1], 2, "dim0+dim2 changed -> Graph 2") + + # dim0 is original excluded value, dim2 differs -> Graph 2 should still + # accept because dim0's exclusion is None (already dynamic when snapshot taken) + opt(torch.randn(2, 3, 11, 5, 6)) + self.assertEqual( + tracker.call_log[-1], + 2, + "dim0=2 with dim2≠4 -> Graph 2 (dim0 exclusion is None)", + ) + + self.assertEqual(tracker.frame_count, 3, "No additional recompilations") + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_many_entries_wrong_graph_selection(self): + """ + Convoluted scenario: 4D tensor, three rounds of dynamism creating 4 graphs. + Without exclusion guards, the most general graph would shadow all others. + Test that each input gets the best (most specialized) match. + + Graph 0: (2, 3, 4, 5) static + Graph 1: (dyn, 3, 4, 5) excluded dim0=2 + After Graph 1, (2, 8, 4, 5) triggers dim1 dynamic: + Graph 2: (dyn, dyn, 4, 5) excluded dim1=3 + After Graph 2, (2, 8, 9, 5) triggers dim2 dynamic: + Graph 3: (dyn, dyn, dyn, 5) excluded dim2=4 + + Key: Graph 3 should NOT steal inputs that belong to Graph 0, 1, or 2. + """ + + def foo(x): + return x.relu() + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + # Build up 4 graphs progressively + opt(torch.randn(2, 3, 4, 5)) # Graph 0 + opt(torch.randn(7, 3, 4, 5)) # Graph 1: dim0 dynamic + opt(torch.randn(7, 8, 4, 5)) # Graph 2: dim1 also dynamic + opt(torch.randn(7, 8, 9, 5)) # Graph 3: dim2 also dynamic + self.assertEqual(tracker.frame_count, 4) + + # Now stress-test routing with various inputs: + test_cases = [ + # (shape, expected_graph, description) + ((2, 3, 4, 5), 0, "exact original -> static Graph 0"), + ((7, 3, 4, 5), 1, "dim0 differs -> Graph 1"), + ((99, 3, 4, 5), 1, "dim0 differs (large) -> Graph 1"), + ((7, 8, 4, 5), 2, "dim0+dim1 differ -> Graph 2"), + ((7, 99, 4, 5), 2, "dim0+dim1 differ (large) -> Graph 2"), + ((7, 8, 9, 5), 3, "dim0+dim1+dim2 differ -> Graph 3"), + ((99, 99, 99, 5), 3, "all non-static dims differ -> Graph 3"), + ] + + for shape, expected_graph, desc in test_cases: + opt(torch.randn(*shape)) + self.assertEqual(tracker.call_log[-1], expected_graph, desc) + + self.assertEqual(tracker.frame_count, 4, "No additional recompilations") + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_multi_dim_dynamic_and_semantics(self): + """ + When multiple dims become dynamic at once, AND semantics is critical. + Graph 0: (4, 3, 234, 5) static + Graph 1: (s0, 3, s2, s3) dynamic on dims 0,2,3. excluded=(4, _, 234, 5) + + OR semantics (wrong): rejects (4, 3, 100, 20) because dim0==4 matches. + AND semantics (correct): accepts (4, 3, 100, 20) because not ALL excluded + dims match (dim2=100≠234 and dim3=20≠5). + """ + + def foo(x): + return x * 2 + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + opt(torch.randn(4, 3, 234, 5)) # Graph 0: static + opt(torch.randn(10, 3, 100, 20)) # Graph 1: dims 0,2,3 dynamic + self.assertEqual(tracker.frame_count, 2) + + # Only the exact original shape should be excluded from Graph 1 + opt(torch.randn(4, 3, 234, 5)) + self.assertEqual(tracker.call_log[-1], 0, "Exact original -> Graph 0") + + # Partial matches should NOT be excluded (AND semantics) + opt(torch.randn(4, 3, 100, 20)) # dim0=4 matches, dims 2,3 don't + self.assertEqual(tracker.call_log[-1], 1, "dim0=4 partial match -> Graph 1") + + opt(torch.randn(10, 3, 234, 20)) # dim2=234 matches, dims 0,3 don't + self.assertEqual(tracker.call_log[-1], 1, "dim2=234 partial match -> Graph 1") + + opt(torch.randn(10, 3, 234, 5)) # dim2=234, dim3=5 match, dim0 doesn't + self.assertEqual(tracker.call_log[-1], 1, "dim2+dim3 partial match -> Graph 1") + + opt(torch.randn(4, 3, 234, 20)) # dim0=4, dim2=234 match, dim3 doesn't + self.assertEqual(tracker.call_log[-1], 1, "dim0+dim2 partial match -> Graph 1") + + # Totally new shape, no exclusion hit + opt(torch.randn(99, 3, 88, 77)) + self.assertEqual(tracker.call_log[-1], 1, "New shape -> Graph 1") + + self.assertEqual(tracker.frame_count, 2, "No additional recompilations") + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_integer_input_exclusion_basic(self): + """ + Integer inputs that become dynamic should also get exclusion guards. + 1. foo(x, 3) -> Graph 0: static n=3 + 2. foo(x, 5) -> Graph 1: dynamic n, excluded should reject n==3 + 3. foo(x, 3) -> should use Graph 0 (static), not Graph 1 + """ + + def foo(x, n): + return x * n + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + x = torch.randn(4) + + opt(x, 3) + self.assertEqual(tracker.frame_count, 1) + self.assertEqual(tracker.call_log[-1], 0) + + opt(x, 5) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + opt(x, 3) + self.assertEqual( + tracker.call_log[-1], + 0, + "Input n=3 should use Graph 0 (static), not Graph 1 (dynamic n).", + ) + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_integer_input_exclusion_accumulation(self): + """ + Same accumulation scenario as the tensor test but with integer inputs. + 1. foo(x, 3, 4) -> Graph 0: static (3, 4) + 2. foo(x, 5, 4) -> Graph 1: dynamic (s0, 4), exclusion rejects n0==3 + 3. foo(x, 3, 19) -> Graph 2: dynamic (s0, s1), exclusion should reject + n1==4 independently, not require n0==3 AND n1==4 + 4. foo(x, 5, 4) -> should use Graph 1, not Graph 2 + """ + + def foo(x, n, m): + return x * n + m + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + x = torch.randn(4) + + opt(x, 3, 4) + self.assertEqual(tracker.frame_count, 1) + self.assertEqual(tracker.call_log[-1], 0) + + opt(x, 5, 4) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + opt(x, 3, 19) + self.assertEqual(tracker.frame_count, 3) + self.assertEqual(tracker.call_log[-1], 2) + + opt(x, 5, 4) + self.assertEqual( + tracker.call_log[-1], + 1, + "Input (5, 4) should use Graph 1 (s0, 4), not Graph 2 (s0, s1).", + ) + + opt(x, 3, 4) + self.assertEqual( + tracker.call_log[-1], + 0, + "Input (3, 4) should use Graph 0 (static)", + ) + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_two_tensor_inputs_exclusion(self): + """ + Multi-tensor exclusion: all pairs are flattened and guarded with + not-all semantics. The guard rejects only when ALL excluded values + across ALL tensors match simultaneously. + + 1. foo(x=[3,4], y=[5,6]) -> Graph 0: all static + 2. foo(x=[3,10], y=[5,11]) -> Graph 1: x.dim1, y.dim1 dynamic + exclusion: Or(x.dim1!=4, y.dim1!=6) + 3. foo(x=[3,10], y=[5,6]) -> Graph 1 (not all match -> passes) + 4. foo(x=[3,4], y=[5,21]) -> Graph 1 (not all match -> passes) + 5. foo(x=[3,4], y=[5,6]) -> Graph 0 (all match -> rejected) + 6. foo(x=[3,10], y=[5,11]) -> Graph 1 + """ + + def foo(x, y): + return x.sum() + y.sum() + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + opt(torch.randn(3, 4), torch.randn(5, 6)) + self.assertEqual(tracker.frame_count, 1) + self.assertEqual(tracker.call_log[-1], 0) + + opt(torch.randn(3, 10), torch.randn(5, 11)) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + # Only y.dim1 matches excluded value; combined Or guard passes. + opt(torch.randn(3, 10), torch.randn(5, 6)) + self.assertEqual(tracker.frame_count, 2, "Should not recompile") + self.assertEqual(tracker.call_log[-1], 1) + + # Only x.dim1 matches excluded value; combined Or guard passes. + opt(torch.randn(3, 4), torch.randn(5, 21)) + self.assertEqual(tracker.frame_count, 2, "Should not recompile") + self.assertEqual(tracker.call_log[-1], 1) + + # Both match excluded values; Or guard fails -> falls to Graph 0. + opt(torch.randn(3, 4), torch.randn(5, 6)) + self.assertEqual(tracker.call_log[-1], 0) + + # Neither matches; Or guard passes -> Graph 1. + opt(torch.randn(3, 10), torch.randn(5, 11)) + self.assertEqual(tracker.call_log[-1], 1) + + @torch._dynamo.config.patch( + automatic_dynamic_shapes=True, assume_static_by_default=True + ) + def test_multi_tensor_and_scalar_accumulation(self): + """ + 3-dim tensors + 3 scalar inputs with cascading accumulation. + Each step transitions one input while the rest stay the same, + verifying that only the current transition's exclusion is emitted. + + Graph 0: all static + Graph 1: x.dim2, y.dim2 dynamic excl: Or(Ne(x.dim2,4), Ne(y.dim2,7)) + Graph 2: + n dynamic excl: Ne(n, 2) + Graph 3: + m dynamic excl: Ne(m, 3) + Graph 4: + k dynamic excl: Ne(k, 4) + """ + + def foo(x, y, n, m, k): + return x.sum() * n + y.sum() * m + k + + tracker = GraphTracker() + opt = torch.compile(foo, backend=tracker) + + # -- Compilation steps -- + + # Graph 0: all static + opt(torch.randn(2, 3, 4), torch.randn(5, 6, 7), 2, 3, 4) + self.assertEqual(tracker.frame_count, 1) + self.assertEqual(tracker.call_log[-1], 0) + + # Graph 1: tensor dim2 changes + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 2, 3, 4) + self.assertEqual(tracker.frame_count, 2) + self.assertEqual(tracker.call_log[-1], 1) + + # Graph 2: scalar n changes (tensor exclusions cleared) + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 8, 3, 4) + self.assertEqual(tracker.frame_count, 3) + self.assertEqual(tracker.call_log[-1], 2) + + # Graph 3: scalar m changes (n exclusion cleared) + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 8, 9, 4) + self.assertEqual(tracker.frame_count, 4) + self.assertEqual(tracker.call_log[-1], 3) + + # Graph 4: scalar k changes (m exclusion cleared) + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 8, 9, 15) + self.assertEqual(tracker.frame_count, 5) + self.assertEqual(tracker.call_log[-1], 4) + + # -- Verification: each input routes to the correct graph -- + + # k=4 triggers Graph 4 exclusion -> Graph 3 + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 8, 9, 4) + self.assertEqual(tracker.call_log[-1], 3, "k=4 should fall to Graph 3") + + # m=3 also triggers Graph 3 exclusion -> Graph 2 + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 8, 3, 4) + self.assertEqual(tracker.call_log[-1], 2, "m=3 should fall to Graph 2") + + # n=2 also triggers Graph 2 exclusion -> Graph 1 + opt(torch.randn(2, 3, 10), torch.randn(5, 6, 11), 2, 3, 4) + self.assertEqual(tracker.call_log[-1], 1, "n=2 should fall to Graph 1") + + # tensor dims match original -> Graph 1 exclusion triggers -> Graph 0 + opt(torch.randn(2, 3, 4), torch.randn(5, 6, 7), 2, 3, 4) + self.assertEqual(tracker.call_log[-1], 0, "Original sizes should use Graph 0") + + # mixed: new tensor dims + new scalars -> Graph 4 + opt(torch.randn(2, 3, 20), torch.randn(5, 6, 21), 50, 60, 70) + self.assertEqual(tracker.frame_count, 5, "Should not recompile") + self.assertEqual(tracker.call_log[-1], 4, "All-new values should use Graph 4") + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/test_guard_manager.py b/test/dynamo/test_guard_manager.py index 508f03d241c52..5e53c49dccdbc 100644 --- a/test/dynamo/test_guard_manager.py +++ b/test/dynamo/test_guard_manager.py @@ -2,7 +2,6 @@ import abc import functools import inspect -import unittest import weakref import torch @@ -254,8 +253,12 @@ def test_default_device_guard(self): guard = guards.DEFAULT_DEVICE(root, ["cpu device"], None) self.assertTrue(guard(foo)) + if not torch.accelerator.is_available(): + self.skipTest("Accelerator is not available") + try: - torch.set_default_device("cuda") + device = torch.accelerator.current_accelerator() + torch.set_default_device(device) self.assertFalse(guard(foo)) finally: torch.set_default_device(None) @@ -445,10 +448,14 @@ def test_weakref_alive_guard(self): del x self.assertFalse(guard(weakref_x())) - @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") def test_call_function_no_args_guard(self): + if not torch.accelerator.is_available(): + self.skipTest("Accelerator is not available") + root = RootGuardManager() - x = torch.cuda.current_device() + device = torch.accelerator.current_accelerator() + # Use device.index which is device-agnostic (works on all accelerators) + x = device.index if device.index is not None else 0 guard = guards.EQUALS_MATCH(root, x, [0], None) self.assertTrue(guard(0)) self.assertFalse(guard(1)) @@ -1066,19 +1073,23 @@ def hook(guard_wrapper, f_locals, builder): class RecursiveDictTagTests(torch._dynamo.test_case.TestCase): def setUp(self): + super().setUp() self._prev = torch._dynamo.config.use_recursive_dict_tags_for_guards torch._dynamo.config.use_recursive_dict_tags_for_guards = True def tearDown(self): + super().tearDown() torch._dynamo.config.use_recursive_dict_tags_for_guards = self._prev class TagSafetyChecks(RecursiveDictTagTests): def setUp(self): + super().setUp() self._prev = torch._dynamo.config.use_recursive_dict_tags_for_guards torch._dynamo.config.use_recursive_dict_tags_for_guards = True def tearDown(self): + super().tearDown() torch._dynamo.config.use_recursive_dict_tags_for_guards = self._prev def test_immutable_tag_safe(self): diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 35f78f57afbfc..55d83764f99af 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -1,6 +1,7 @@ # Owner(s): ["module: dynamo"] import dataclasses +import itertools import pickle import sys import tempfile @@ -9,7 +10,6 @@ import weakref from collections.abc import Iterator from typing import NamedTuple -from unittest.mock import patch import torch import torch._dynamo.testing @@ -790,22 +790,6 @@ def fn(x, y): # guard should fail for different y value self._test_check_fn(ref, loaded, {"x": torch.randn(3), "y": 6}, False) - def test_nn_module(self): - def fn(m, x): - return m(x) - - m = GlobalModule() - x = torch.randn(3) - - # config setting controls whether the NN_MODULE guard is installed - with patch("torch._dynamo.config.inline_inbuilt_nn_modules", False): - # we don't support NN_MODULE because it adds an ID_MATCH guard, and we don't - # support that in serialization - with self.assertRaisesRegex( - PackageError, "NN_MODULE guard cannot be serialized." - ): - self._test_serialization("NN_MODULE", fn, m, x) - def test_class_match(self): def fn(x): # usage of this context manager installs a FUNCTION_MATCH guard @@ -946,6 +930,35 @@ def _gen_kwargs(x=x): ref, loaded, {"x": torch.randn(4), "r": iter(range(2, 15, 4))}, False ) + def test_count_iterator_match(self): + def fn(x, counter): + return x + next(counter) + + x = torch.randn(3) + + def _gen_kwargs(x=x): + return {"x": x, "counter": itertools.count(2, 3)} + + ref, loaded = self._test_serialization( + "COUNT_ITERATOR_MATCH", fn, _gen_fn=_gen_kwargs + ) + + self._test_check_fn( + ref, loaded, {"x": x, "counter": itertools.count(2, 3)}, True + ) + self._test_check_fn( + ref, + loaded, + {"x": torch.randn(4), "counter": itertools.count(2, 3)}, + True, + ) + self._test_check_fn( + ref, loaded, {"x": x, "counter": itertools.count(5, 3)}, False + ) + self._test_check_fn( + ref, loaded, {"x": x, "counter": itertools.count(2, 4)}, False + ) + def test_dict_version(self): def fn(x): return pytree.tree_leaves(x)[0] + 1 @@ -1507,6 +1520,7 @@ def foo(x, y): True, ) + @torch._dynamo.config.patch(nested_graph_breaks=False) def test_ddp_module(self): import torch.distributed as dist @@ -1662,19 +1676,17 @@ def foo(inputs): self._test_check_fn(ref, loaded, {"inputs": Inputs(x, weakref.ref(x))}, True) def test_unused_stream(self): - if not torch.cuda.is_available(): - self.skipTest("CUDA is not available") + if not torch.accelerator.is_available(): + self.skipTest("Accelerator is not available") def foo(inputs): return inputs.x + 1 x = torch.randn(3, 2) ref, loaded = self._test_serialization( - "TENSOR_MATCH", foo, Inputs(x, torch.cuda.Stream()) - ) - self._test_check_fn( - ref, loaded, {"inputs": Inputs(x, torch.cuda.Stream())}, True + "TENSOR_MATCH", foo, Inputs(x, torch.Stream()) ) + self._test_check_fn(ref, loaded, {"inputs": Inputs(x, torch.Stream())}, True) def test_unused_process_group(self): import torch.distributed as dist @@ -1857,6 +1869,17 @@ def test_source_serialization(self): self.assertEqual(pickle.dumps(src1), pickle.dumps(src2)) + def test_source_serialization_init_false_fields(self): + # Test that source serialization handles fields that are not initialized + from torch._dynamo.source import DefaultsSource, LocalSource + + base = LocalSource("x") + source = DefaultsSource(base=base, idx_key=0, is_kw=False) + + # Round-trip through pickle should work even with init=False fields + restored = pickle.loads(pickle.dumps(source)) + self.assertEqual(source, restored) + class SimpleModule(torch.nn.Module): def __init__(self, c): diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 0ba5d34144024..da1a74eb1360e 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -38,6 +38,7 @@ xfailIfTorchDynamo, ) from torch.testing._internal.hop_db import hop_db +from torch.testing._internal.inductor_utils import GPU_TYPE from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test from torch.testing._internal.triton_utils import ( requires_cuda_and_triton, @@ -135,7 +136,7 @@ def default_args_generator(seed_value): yield new_args -class HigherOrderOpTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks): +class HigherOrderOpTests(torch._dynamo.test_case.TestCase): def _assert_wrap_fallback(self, func, args, setup=lambda: None): counters.clear() backend = EagerAndRecordGraphs() @@ -422,7 +423,7 @@ def forward(self, l_d_x_: "f32[]", l_d_y_0_: "f32[]", l_d_y_1_2_: "f32[]"): sin_1: "f32[]" = l_d_y_1_2_.sin(); l_d_y_1_2_ = None sub: "f32[]" = add - sin_1; add = sin_1 = None return (sub,) -""", # NOQA: B950 +""", ) def test_wrap_pytree_args_with_symint_constant(self): @@ -2582,10 +2583,6 @@ def f(x): # 3 args - 1 for input, and other 2 for the weight and bias self.assertTrue(len(wrap_node.args), 3) - # Check that the linear bias and weight are getattr in the outer graph - if not torch._dynamo.config.inline_inbuilt_nn_modules: - self.assertTrue(len(dict(backend.graphs[0].named_parameters())) == 2) - # Check that the inner function has one op and its a linear op body_function = getattr(backend.graphs[0], wrap_node.args[0].name) self.assertEqual(op_count(body_function), 1) @@ -2714,10 +2711,6 @@ def f(x): wrap_node = find_first_node(backend.graphs[0], wrap) self.assertTrue(len(wrap_node.args), 3) - # Check that the linear bias and weight are getattr in the outer graph - if not torch._dynamo.config.inline_inbuilt_nn_modules: - self.assertTrue(len(dict(backend.graphs[0].named_parameters())) == 2) - # Check that the inner function has one op and its a linear op body_function = getattr(backend.graphs[0], wrap_node.args[0].name) self.assertEqual(op_count(body_function), 1) @@ -3097,7 +3090,7 @@ def forward(self, L_a_ : torch.SymInt, L_b_ : torch.SymInt, L_c_ : torch.SymInt, _vmap_decrement_nesting_2 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_2 = None _remove_batch_dim_3 = torch._functorch.predispatch._remove_batch_dim(batched_outputs_3, 1, l_d_, 0); batched_outputs_3 = l_d_ = None _vmap_decrement_nesting_3 = torch._functorch.predispatch._vmap_decrement_nesting(); _vmap_decrement_nesting_3 = None - return (_remove_batch_dim_3,)""", # noqa: B950 + return (_remove_batch_dim_3,)""", ) def test_cond_pytree_operands(self): @@ -3163,7 +3156,7 @@ def forward(self, L_pred_ : torch.Tensor, L_pytree_in_0_ : torch.Tensor, L_pytre cond_false_0 = self.cond_false_0 cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, (l_pytree_in_0_, l_pytree_in_1_0_0_0_, l_pytree_in_2_, l_pytree_in_3_0_, l_pytree_in_3_1_0_, l_pytree_in_3_2_, l_pytree_in_4_g_)); l_pred_ = cond_true_0 = cond_false_0 = l_pytree_in_0_ = l_pytree_in_1_0_0_0_ = l_pytree_in_2_ = l_pytree_in_3_0_ = l_pytree_in_3_1_0_ = l_pytree_in_3_2_ = l_pytree_in_4_g_ = None getitem = cond[0]; cond = None - return (getitem,)""", # noqa: B950 + return (getitem,)""", ) def test_cond_pytree_operands_with_non_tensor_leaves(self): @@ -3392,7 +3385,7 @@ def outer_body_fn(x): with self.assertRaisesRegex(RuntimeError, msg): fn_with_hints(x, y) - @requires_cuda_and_triton + @requires_gpu_and_triton def test_wrap_inductor_compiled_regions_option(self): """ Test that wrap_inductor_compiled_regions option wraps compiled regions @@ -3414,8 +3407,8 @@ def fn_wrapped(x, y): def fn_not_wrapped(x, y): return torch.matmul(x, y) - x = torch.randn(4, 4, device="cuda") - y = torch.randn(4, 4, device="cuda") + x = torch.randn(4, 4, device=GPU_TYPE) + y = torch.randn(4, 4, device=GPU_TYPE) # Test wrapped version - HOP should be visible in DebugMode with DebugMode() as debug_mode_wrapped: @@ -3436,7 +3429,7 @@ def fn_not_wrapped(x, y): self.assertEqual(result_wrapped, expected) self.assertEqual(result_not_wrapped, expected) - @requires_cuda_and_triton + @requires_gpu_and_triton def test_wrap_inductor_compiled_regions_with_backward(self): """ Test that wrap_inductor_compiled_regions works correctly with autograd. @@ -3451,8 +3444,8 @@ def test_wrap_inductor_compiled_regions_with_backward(self): def fn(x, y): return torch.matmul(x, y) - x = torch.randn(4, 4, device="cuda", requires_grad=True) - y = torch.randn(4, 4, device="cuda", requires_grad=True) + x = torch.randn(4, 4, device=GPU_TYPE, requires_grad=True) + y = torch.randn(4, 4, device=GPU_TYPE, requires_grad=True) # Clone for eager comparison x_eager = x.detach().clone().requires_grad_(True) @@ -3478,9 +3471,7 @@ def fn(x, y): self.assertEqual(y.grad, y_eager.grad) -class HigherOrderOpVmapGuardTests( - torch._dynamo.test_case.TestCaseWithNestedGraphBreaks, LoggingTestCase -): +class HigherOrderOpVmapGuardTests(LoggingTestCase): @make_logging_test(recompiles=True) def test_vmap_grad_guard_ok(self, records): vmap = torch.vmap @@ -3749,9 +3740,7 @@ def fn(x): self.assertGreater(len(records), 0) -class FuncTorchHigherOrderOpTests( - torch._dynamo.test_case.TestCaseWithNestedGraphBreaks -): +class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase): def tearDown(self): # Ensure that in the case of a test failure, the next test won't fail # because of a previous call to _vmap_increment_nesting that wasn't undone @@ -3778,7 +3767,7 @@ def tearDown(self): def test_teardown_resets_nested_graph_breaks(self): expected_nested_state = getattr( - self, "prev_nested_graph_breaks", torch._dynamo.config.nested_graph_breaks + self, "_prior_nested_graph_breaks", torch._dynamo.config.nested_graph_breaks ) def _check_flag(): @@ -3839,13 +3828,13 @@ def forward(self, L_x_: "f32[4, 3]"): child_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _jvp_increment_nesting = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None - _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None + _enter_dual_level = torch._functorch.predispatch._enter_dual_level(); _enter_dual_level = None _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None - child_2: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None + child_2: "f32[4, 3]" = torch._functorch.predispatch._make_dual(l_x_, child_1, level = 0); child_1 = None _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None @@ -3862,7 +3851,7 @@ def forward(self, L_x_: "f32[4, 3]"): primals_out: "f32[4, 3]" = torch.sin(diff_primals) - results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primals_out, 3) + results: "f32[4, 3]" = torch._functorch.predispatch._unwrap_for_grad(primals_out, 3) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -3898,16 +3887,16 @@ def forward(self, L_x_: "f32[4, 3]"): output_input: "f32[4, 3, 4, 3]" = split_1.view((4, 3, 4, 3)); split_1 = None - _unpack_dual = torch._unpack_dual(output_input, level = 0); output_input = None + _unpack_dual = torch._functorch.predispatch._unpack_dual(output_input, level = 0); output_input = None primal: "f32[4, 3, 4, 3]" = _unpack_dual[0] dual: "f32[4, 3, 4, 3]" = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None - tangents_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None + primals_out_unflatten: "f32[4, 3, 4, 3]" = torch._functorch.predispatch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None + tangents_out_unflatten: "f32[4, 3, 4, 3]" = torch._functorch.predispatch._unwrap_for_grad(dual, 2); dual = None - _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _exit_dual_level = torch._functorch.predispatch._exit_dual_level(level = 0); _exit_dual_level = None _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _jvp_decrement_nesting = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting = None results_1: "f32[12, 4, 3, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None @@ -3966,13 +3955,13 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): child_1: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _jvp_increment_nesting = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None - _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None + _enter_dual_level = torch._functorch.predispatch._enter_dual_level(); _enter_dual_level = None _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None - child_3: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None + child_3: "f32[3, 4]" = torch._functorch.predispatch._make_dual(l_y_, child_1, level = 0); child_1 = None child_2: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None @@ -3991,7 +3980,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): primals_out: "f32[4, 3]" = _wrap_for_grad_2.sin(); _wrap_for_grad_2 = None - results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primals_out, 3) + results: "f32[4, 3]" = torch._functorch.predispatch._unwrap_for_grad(primals_out, 3) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4027,17 +4016,17 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): child_7: "f32[4, 3, 3, 4]" = split_1.view((4, 3, 3, 4)); split_1 = None - _unpack_dual = torch._unpack_dual(child_7, level = 0); child_7 = None + _unpack_dual = torch._functorch.predispatch._unpack_dual(child_7, level = 0); child_7 = None primal: "f32[4, 3, 3, 4]" = _unpack_dual[0]; _unpack_dual = None tangent: "f32[4, 3, 3, 4]" = torch.zeros_like(primal) - child_8: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_8 = None - child_9: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None + child_8: "f32[4, 3, 3, 4]" = torch._functorch.predispatch._unwrap_for_grad(primal, 2); primal = child_8 = None + child_9: "f32[4, 3, 3, 4]" = torch._functorch.predispatch._unwrap_for_grad(tangent, 2); tangent = None - _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _exit_dual_level = torch._functorch.predispatch._exit_dual_level(level = 0); _exit_dual_level = None _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _jvp_decrement_nesting = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting = None child_10: "f32[12, 4, 3, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(child_9, 1, 12, 0); child_9 = None @@ -4088,7 +4077,7 @@ def forward(self, L_x_: "f32[4, 3]"): primals_out: "f32[4, 3]" = torch.sin(diff_primals) - results: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primals_out, 1) + results: "f32[4, 3]" = torch._functorch.predispatch._unwrap_for_grad(primals_out, 1) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4166,7 +4155,7 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): primals_out: "f32[3, 4]" = diff_primals.sin() - results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primals_out, 1) + results: "f32[3, 4]" = torch._functorch.predispatch._unwrap_for_grad(primals_out, 1) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4244,8 +4233,8 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): primals_out: "f32[3, 4]" = diff_primals.sin() - aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None - results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primals_out, 1) + aux_1: "f32[4, 3]" = torch._functorch.predispatch._unwrap_for_grad(aux, 1); aux = None + results: "f32[3, 4]" = torch._functorch.predispatch._unwrap_for_grad(primals_out, 1) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4324,7 +4313,7 @@ def forward(self, L_x_: "f32[5]"): sin: "f32[5]" = _wrap_for_grad.sin(); _wrap_for_grad = None primals_out: "f32[]" = sin.sum(); sin = None - results: "f32[]" = torch._C._functorch._unwrap_for_grad(primals_out, 1); primals_out = None + results: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(primals_out, 1); primals_out = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4372,8 +4361,8 @@ def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"): child: "f32[5]" = _wrap_for_grad.sin() child_1: "f32[5]" = _wrap_for_grad.cos(); _wrap_for_grad = None - _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1) - _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) + _unwrap_for_grad: "f32[5]" = torch._functorch.predispatch._unwrap_for_grad(child, 1) + _unwrap_for_grad_1: "f32[5]" = torch._functorch.predispatch._unwrap_for_grad(child_1, 1) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4424,8 +4413,8 @@ def forward(self, L_x_: "f32[5]", L_v_: "f32[5]"): child: "f32[5]" = _wrap_for_grad.sin() child_1: "f32[5]" = _wrap_for_grad.cos(); _wrap_for_grad = None - _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1) - _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) + _unwrap_for_grad: "f32[5]" = torch._functorch.predispatch._unwrap_for_grad(child, 1) + _unwrap_for_grad_1: "f32[5]" = torch._functorch.predispatch._unwrap_for_grad(child_1, 1) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4478,8 +4467,8 @@ def forward(self, L_x_: "f32[5]"): sin: "f32[5]" = aux.sin() primals_out: "f32[]" = sin.sum(); sin = None - aux_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = aux_1 = None - results: "f32[]" = torch._C._functorch._unwrap_for_grad(primals_out, 1); primals_out = None + aux_1: "f32[5]" = torch._functorch.predispatch._unwrap_for_grad(aux, 1); aux = aux_1 = None + results: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(primals_out, 1); primals_out = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4487,7 +4476,6 @@ def forward(self, L_x_: "f32[5]"): """, ) - @config.patch(inline_inbuilt_nn_modules=True) def test_functional_call(self): def wrapper_fn(model, params, inputs, targets): prediction = torch.func.functional_call(model, params, (inputs,)) @@ -4504,10 +4492,9 @@ def wrapper_fn(model, params, inputs, targets): return actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) - if torch._dynamo.config.inline_inbuilt_nn_modules: - self.assertExpectedInline( - actual, - """\ + self.assertExpectedInline( + actual, + """\ class GraphModule(torch.nn.Module): def forward(self, L_model_parameters_weight_: "f32[3, 3]", L_model_parameters_bias_: "f32[3]", L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"): l_model_parameters_weight_ = L_model_parameters_weight_ @@ -4520,24 +4507,8 @@ def forward(self, L_model_parameters_weight_: "f32[3, 3]", L_model_parameters_bi mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_); prediction = l_targets_ = None return (mse_loss,) """, - ) - else: - self.assertExpectedInline( - actual, - """\ -class GraphModule(torch.nn.Module): - def forward(self, L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"): - l_inputs_ = L_inputs_ - l_targets_ = L_targets_ - - prediction: "f32[64, 3]" = self.model(l_inputs_); l_inputs_ = None - - mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_); prediction = l_targets_ = None - return (mse_loss,) -""", - ) + ) - @config.patch(inline_inbuilt_nn_modules=True) def test_functional_call_sequential_params_and_buffers(self): # copied from test/test_stateless.py class MockModule(torch.nn.Module): @@ -4567,8 +4538,7 @@ def wrapper_fn(model, params, buffers, inputs): return actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) - if torch._dynamo.config.inline_inbuilt_nn_modules: - expected = """\ + expected = """\ class GraphModule(torch.nn.Module): def forward(self, L_inputs_: "f32[1, 1]", L_model_modules_l1_parameters_weight_: "f32[1, 1]", L_model_modules_l1_parameters_bias_: "f32[1]", L_model_buffers_buffer_: "f32[1]"): l_inputs_ = L_inputs_ @@ -4579,49 +4549,11 @@ def forward(self, L_inputs_: "f32[1, 1]", L_model_modules_l1_parameters_weight_: add: "f32[1, 1]" = linear + l_model_buffers_buffer_; linear = l_model_buffers_buffer_ = None return (add,) """ - # We found Windows/Linux have some empty line difference, empty_line_normalizer will help fix it. - self.assertExpectedInline( - empty_line_normalizer(actual), - empty_line_normalizer(normalize_gm(expected)), - ) - else: - self.assertExpectedInline( - actual, - """\ -class GraphModule(torch.nn.Module): - def forward(self, L_x_: "f32[1, 1]"): - l_x_ = L_x_ - - l__self___l1: "f32[1, 1]" = self.L__self___l1(l_x_); l_x_ = None - l__self___buffer: "f32[1]" = self.L__self___buffer - add: "f32[1, 1]" = l__self___l1 + l__self___buffer; l__self___l1 = l__self___buffer = None - return (add,) -""", - ) - - @config.patch(inline_inbuilt_nn_modules=False) - def test_functional_call_disable_inline_nn_module(self): - counters.clear() - - def wrapper_fn(model, params, inputs, targets): - prediction = torch.func.functional_call(model, params, (inputs,)) - return torch.nn.functional.mse_loss(prediction, targets) - - model = torch.nn.Linear(3, 3) - params = dict(model.named_parameters()) - inputs = torch.randn(64, 3) - targets = torch.randn(64, 3) - - actual = wrapper_fn(model, params, inputs, targets) - expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( - model, params, inputs, targets - ) - self.assertEqual(len(counters["graph_break"]), 1) - self.assertIn( - "torch.func.functional_call capture is disabled", - next(iter(counters["graph_break"].keys())), + # We found Windows/Linux have some empty line difference, empty_line_normalizer will help fix it. + self.assertExpectedInline( + empty_line_normalizer(actual), + empty_line_normalizer(normalize_gm(expected)), ) - self.assertEqual(actual, expected) def test_grad(self): counters.clear() @@ -4664,8 +4596,8 @@ def forward(self, L_x_: "f32[3, 3, 3]"): _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None - grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None + grad_input_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(output, 1); output = output_1 = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4731,8 +4663,8 @@ def forward(self, L_x_: "f32[3, 3, 3]"): _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None - grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None + grad_input_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(output, 1); output = output_1 = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4787,8 +4719,8 @@ def forward(self, L_x_: "f32[3, 3, 3]"): _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None - grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None + grad_input_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(output, 1); output = output_1 = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4843,8 +4775,8 @@ def forward(self, L_x_: "f32[3, 3, 3]"): _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None - grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None + grad_input_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(output, 1); output = output_1 = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4897,9 +4829,9 @@ def forward(self, L_x_: "f32[3, 3, 3]"): _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None - grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None - aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + grad_input_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(output, 1); output = output_1 = None + aux_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(aux, 1); aux = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4953,9 +4885,9 @@ def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None - grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None - aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + grad_input_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(output, 1); output = output_1 = None + aux_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(aux, 1); aux = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -5026,10 +4958,10 @@ def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): getitem: "f32[3, 3, 3]" = _autograd_grad[0] getitem_1: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None - _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(getitem, 1); getitem = None - _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(getitem_1, 1); getitem_1 = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None - aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + _unwrap_for_grad: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(getitem, 1); getitem = None + _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(getitem_1, 1); getitem_1 = None + output_1: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(output, 1); output = output_1 = None + aux_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(aux, 1); aux = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -5070,10 +5002,10 @@ def forward(self, L_x_: "f32[3, 3, 3]", L_y_: "f32[3, 3, 3]"): getitem: "f32[3, 3, 3]" = _autograd_grad[0] getitem_1: "f32[3, 3, 3]" = _autograd_grad[1]; _autograd_grad = None - _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(getitem, 1); getitem = None - _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(getitem_1, 1); getitem_1 = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None - aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + _unwrap_for_grad: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(getitem, 1); getitem = None + _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(getitem_1, 1); getitem_1 = None + output_1: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(output, 1); output = output_1 = None + aux_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(aux, 1); aux = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -5131,8 +5063,8 @@ def forward(self, L_x_: "f32[]"): _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args_1], create_graph = True); diff_args_1 = None grad_input: "f32[]" = _autograd_grad[0]; _autograd_grad = None - grad_input_1: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input, 2); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 2); output = output_1 = None + grad_input_1: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(grad_input, 2); grad_input = None + output_1: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(output, 2); output = output_1 = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable("torch.func.{grad, vjp, jacrev, hessian} don't yet support saved tensor hooks. Please open an issue with your use case."); _saved_tensors_hooks_disable_2 = None @@ -5140,8 +5072,8 @@ def forward(self, L_x_: "f32[]"): _autograd_grad_1 = torch._functorch.eager_transforms._autograd_grad((grad_input_1,), [diff_args], create_graph = True); diff_args = None grad_input_2: "f32[]" = _autograd_grad_1[0]; _autograd_grad_1 = None - grad_input_3: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_2, 1); grad_input_2 = None - output_2: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_1, 1); grad_input_1 = output_2 = None + grad_input_3: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(grad_input_2, 1); grad_input_2 = None + output_2: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(grad_input_1, 1); grad_input_1 = output_2 = None _grad_decrement_nesting_1 = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting_1 = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -5245,8 +5177,8 @@ def forward(self, L_x_: "f32[3, 3, 3]"): _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None - grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None - output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None + grad_input_1: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(output, 1); output = output_1 = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -5307,28 +5239,28 @@ def forward(self, L_x_: "f32[4, 3]"): child_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _jvp_increment_nesting = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None - _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None + _enter_dual_level = torch._functorch.predispatch._enter_dual_level(); _enter_dual_level = None _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None - _make_dual: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None + _make_dual: "f32[4, 3]" = torch._functorch.predispatch._make_dual(l_x_, child_1, level = 0); child_1 = None _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None result_duals: "f32[4, 3]" = torch.sin(_make_dual); _make_dual = None - _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None + _unpack_dual = torch._functorch.predispatch._unpack_dual(result_duals, level = 0); result_duals = None primal: "f32[4, 3]" = _unpack_dual[0] dual: "f32[4, 3]" = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None - tangents_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None + primals_out_unflatten: "f32[4, 3]" = torch._functorch.predispatch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None + tangents_out_unflatten: "f32[4, 3]" = torch._functorch.predispatch._unwrap_for_grad(dual, 2); dual = None - _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _exit_dual_level = torch._functorch.predispatch._exit_dual_level(level = 0); _exit_dual_level = None _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _jvp_decrement_nesting = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting = None results: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None @@ -5387,29 +5319,29 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): child_1: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _jvp_increment_nesting = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None - _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None + _enter_dual_level = torch._functorch.predispatch._enter_dual_level(); _enter_dual_level = None _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None - _make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None + _make_dual: "f32[3, 4]" = torch._functorch.predispatch._make_dual(l_y_, child_1, level = 0); child_1 = None _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None result_duals: "f32[3, 4]" = _make_dual.sin(); _make_dual = None - _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None + _unpack_dual = torch._functorch.predispatch._unpack_dual(result_duals, level = 0); result_duals = None primal: "f32[3, 4]" = _unpack_dual[0] dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None - tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None + primals_out_unflatten: "f32[3, 4]" = torch._functorch.predispatch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None + tangents_out_unflatten: "f32[3, 4]" = torch._functorch.predispatch._unwrap_for_grad(dual, 2); dual = None - _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _exit_dual_level = torch._functorch.predispatch._exit_dual_level(level = 0); _exit_dual_level = None _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _jvp_decrement_nesting = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting = None results: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None @@ -5468,31 +5400,31 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): child_1: "f32[3, 4]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _jvp_increment_nesting = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None - _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None + _enter_dual_level = torch._functorch.predispatch._enter_dual_level(); _enter_dual_level = None _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None - _make_dual: "f32[3, 4]" = torch._make_dual(l_y_, child_1, level = 0); child_1 = None + _make_dual: "f32[3, 4]" = torch._functorch.predispatch._make_dual(l_y_, child_1, level = 0); child_1 = None aux: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = None _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = _wrap_for_grad_1 = None result_duals: "f32[3, 4]" = _make_dual.sin(); _make_dual = None - aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 2); aux = None + aux_1: "f32[4, 3]" = torch._functorch.predispatch._unwrap_for_grad(aux, 2); aux = None - _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None + _unpack_dual = torch._functorch.predispatch._unpack_dual(result_duals, level = 0); result_duals = None primal: "f32[3, 4]" = _unpack_dual[0] dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None - tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None + primals_out_unflatten: "f32[3, 4]" = torch._functorch.predispatch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None + tangents_out_unflatten: "f32[3, 4]" = torch._functorch.predispatch._unwrap_for_grad(dual, 2); dual = None - _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _exit_dual_level = torch._functorch.predispatch._exit_dual_level(level = 0); _exit_dual_level = None _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _jvp_decrement_nesting = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting = None results: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(tangents_out_unflatten, 1, 12, 0); tangents_out_unflatten = None aux_2: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(aux_1, 1, 12, 0); aux_1 = None @@ -5554,36 +5486,36 @@ def forward(self, L_x_: "f32[4, 3]", L_y_: "f32[3, 4]"): child_1: "f32[4, 3]" = torch._functorch.predispatch._add_batch_dim(child, 0, 1); child = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _jvp_increment_nesting = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None - _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None + _enter_dual_level = torch._functorch.predispatch._enter_dual_level(); _enter_dual_level = None _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None - child_3: "f32[4, 3]" = torch._make_dual(l_x_, child_1, level = 0); child_1 = None + child_3: "f32[4, 3]" = torch._functorch.predispatch._make_dual(l_x_, child_1, level = 0); child_1 = None _wrap_for_grad: "f32[4, 3]" = torch._C._functorch._wrap_for_grad(l_x_, 2); l_x_ = _wrap_for_grad = None _wrap_for_grad_1: "f32[3, 4]" = torch._C._functorch._wrap_for_grad(l_y_, 2); l_y_ = None child_2: "f32[3, 4]" = _wrap_for_grad_1.sin(); _wrap_for_grad_1 = None - _unpack_dual = torch._unpack_dual(child_2, level = 0); child_2 = None + _unpack_dual = torch._functorch.predispatch._unpack_dual(child_2, level = 0); child_2 = None primal: "f32[3, 4]" = _unpack_dual[0]; _unpack_dual = None tangent: "f32[3, 4]" = torch.zeros_like(primal) - _unpack_dual_1 = torch._unpack_dual(child_3, level = 0); child_3 = None + _unpack_dual_1 = torch._functorch.predispatch._unpack_dual(child_3, level = 0); child_3 = None primal_1: "f32[4, 3]" = _unpack_dual_1[0] dual: "f32[4, 3]" = _unpack_dual_1[1]; _unpack_dual_1 = None - child_4: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_4 = None - child_5: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 2); primal_1 = child_5 = None - child_6: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None - child_7: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None + child_4: "f32[3, 4]" = torch._functorch.predispatch._unwrap_for_grad(primal, 2); primal = child_4 = None + child_5: "f32[4, 3]" = torch._functorch.predispatch._unwrap_for_grad(primal_1, 2); primal_1 = child_5 = None + child_6: "f32[3, 4]" = torch._functorch.predispatch._unwrap_for_grad(tangent, 2); tangent = None + child_7: "f32[4, 3]" = torch._functorch.predispatch._unwrap_for_grad(dual, 2); dual = None - _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _exit_dual_level = torch._functorch.predispatch._exit_dual_level(level = 0); _exit_dual_level = None _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _jvp_decrement_nesting = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting = None child_8: "f32[12, 3, 4]" = torch._functorch.predispatch._remove_batch_dim(child_6, 1, 12, 0); child_6 = None child_9: "f32[12, 4, 3]" = torch._functorch.predispatch._remove_batch_dim(child_7, 1, 12, 0); child_7 = None @@ -5631,27 +5563,27 @@ def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"): l_x_ = L_x_ l_v_ = L_v_ - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _jvp_increment_nesting = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None - _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None + _enter_dual_level = torch._functorch.predispatch._enter_dual_level(); _enter_dual_level = None _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None - _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None + _make_dual: "f32[3, 3]" = torch._functorch.predispatch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None result_duals: "f32[]" = sin.sum(); sin = None - _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None + _unpack_dual = torch._functorch.predispatch._unpack_dual(result_duals, level = 0); result_duals = None primal: "f32[]" = _unpack_dual[0] dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None - tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None + primals_out_unflatten: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(primal, 1); primal = None + tangents_out_unflatten: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(dual, 1); dual = None - _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _exit_dual_level = torch._functorch.predispatch._exit_dual_level(level = 0); _exit_dual_level = None _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _jvp_decrement_nesting = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting = None return (primals_out_unflatten, tangents_out_unflatten) """, ) @@ -5682,29 +5614,29 @@ def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"): l_x_ = L_x_ l_v_ = L_v_ - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _jvp_increment_nesting = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None - _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None + _enter_dual_level = torch._functorch.predispatch._enter_dual_level(); _enter_dual_level = None _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None - aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None + aux: "f32[3, 3]" = torch._functorch.predispatch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None sin: "f32[3, 3]" = aux.sin() result_duals: "f32[]" = sin.sum(); sin = None - aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + aux_1: "f32[3, 3]" = torch._functorch.predispatch._unwrap_for_grad(aux, 1); aux = None - _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None + _unpack_dual = torch._functorch.predispatch._unpack_dual(result_duals, level = 0); result_duals = None primal: "f32[]" = _unpack_dual[0] dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None - tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None + primals_out_unflatten: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(primal, 1); primal = None + tangents_out_unflatten: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(dual, 1); dual = None - _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _exit_dual_level = torch._functorch.predispatch._exit_dual_level(level = 0); _exit_dual_level = None _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _jvp_decrement_nesting = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting = None return (primals_out_unflatten, tangents_out_unflatten, aux_1) """, ) @@ -5737,35 +5669,35 @@ def forward(self, L_x_: "f32[3, 3]", L_y_: "f32[3, 3]", L_v_: "f32[3, 3]"): l_y_ = L_y_ l_v_ = L_v_ - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _jvp_increment_nesting = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None - _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None + _enter_dual_level = torch._functorch.predispatch._enter_dual_level(); _enter_dual_level = None _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None - aux: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = None + aux: "f32[3, 3]" = torch._functorch.predispatch._make_dual(l_x_, l_v_, level = 0); l_x_ = None _maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions_1 = None - _make_dual_1: "f32[3, 3]" = torch._make_dual(l_y_, l_v_, level = 0); l_y_ = l_v_ = None + _make_dual_1: "f32[3, 3]" = torch._functorch.predispatch._make_dual(l_y_, l_v_, level = 0); l_y_ = l_v_ = None sin: "f32[3, 3]" = aux.sin() sum_1: "f32[]" = sin.sum(); sin = None cos: "f32[3, 3]" = _make_dual_1.cos(); _make_dual_1 = None result_duals: "f32[3, 3]" = sum_1 + cos; sum_1 = cos = None - aux_1: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + aux_1: "f32[3, 3]" = torch._functorch.predispatch._unwrap_for_grad(aux, 1); aux = None - _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None + _unpack_dual = torch._functorch.predispatch._unpack_dual(result_duals, level = 0); result_duals = None primal: "f32[3, 3]" = _unpack_dual[0] dual: "f32[3, 3]" = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None - tangents_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None + primals_out_unflatten: "f32[3, 3]" = torch._functorch.predispatch._unwrap_for_grad(primal, 1); primal = None + tangents_out_unflatten: "f32[3, 3]" = torch._functorch.predispatch._unwrap_for_grad(dual, 1); dual = None - _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _exit_dual_level = torch._functorch.predispatch._exit_dual_level(level = 0); _exit_dual_level = None _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _jvp_decrement_nesting = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting = None return (primals_out_unflatten, tangents_out_unflatten, aux_1) """, ) @@ -5798,27 +5730,27 @@ def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"): l_v_ = L_v_ _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _jvp_increment_nesting = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None - _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None + _enter_dual_level = torch._functorch.predispatch._enter_dual_level(); _enter_dual_level = None _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None - _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None + _make_dual: "f32[3, 3]" = torch._functorch.predispatch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None result_duals: "f32[]" = sin.sum(); sin = None - _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None + _unpack_dual = torch._functorch.predispatch._unpack_dual(result_duals, level = 0); result_duals = None primal: "f32[]" = _unpack_dual[0] dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None - tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None + primals_out_unflatten: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(primal, 1); primal = None + tangents_out_unflatten: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(dual, 1); dual = None - _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _exit_dual_level = torch._functorch.predispatch._exit_dual_level(level = 0); _exit_dual_level = None _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_2 = None - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _jvp_decrement_nesting = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting = None _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None return (primals_out_unflatten, tangents_out_unflatten) """, @@ -5865,27 +5797,27 @@ def forward(self, L_x_: "f32[3, 3]", L_v_: "f32[3, 3]"): _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled = None _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_2 = None - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _jvp_increment_nesting = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None - _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None + _enter_dual_level = torch._functorch.predispatch._enter_dual_level(); _enter_dual_level = None _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None - _make_dual: "f32[3, 3]" = torch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None + _make_dual: "f32[3, 3]" = torch._functorch.predispatch._make_dual(l_x_, l_v_, level = 0); l_x_ = l_v_ = None sin: "f32[3, 3]" = _make_dual.sin(); _make_dual = None result_duals: "f32[]" = sin.sum(); sin = None - _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None + _unpack_dual = torch._functorch.predispatch._unpack_dual(result_duals, level = 0); result_duals = None primal: "f32[]" = _unpack_dual[0] dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None - tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None + primals_out_unflatten: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(primal, 1); primal = None + tangents_out_unflatten: "f32[]" = torch._functorch.predispatch._unwrap_for_grad(dual, 1); dual = None - _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _exit_dual_level = torch._functorch.predispatch._exit_dual_level(level = 0); _exit_dual_level = None _set_fwd_grad_enabled_4 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_4 = None - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _jvp_decrement_nesting = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting = None _set_fwd_grad_enabled_5 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_5 = None _set_fwd_grad_enabled_6 = torch._C._set_fwd_grad_enabled(False); _set_fwd_grad_enabled_6 = None _set_fwd_grad_enabled_7 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_7 = None @@ -5935,48 +5867,48 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3, 3]"): l_x_ = L_x_ - _jvp_increment_nesting = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting = None + _jvp_increment_nesting = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting = None _set_fwd_grad_enabled = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled = None - _enter_dual_level = torch._C._enter_dual_level(); _enter_dual_level = None + _enter_dual_level = torch._functorch.predispatch._enter_dual_level(); _enter_dual_level = None _maybe_load_decompositions = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions = None - child: "f32[3, 3, 3]" = torch._make_dual(l_x_, l_x_, level = 0); l_x_ = None + child: "f32[3, 3, 3]" = torch._functorch.predispatch._make_dual(l_x_, l_x_, level = 0); l_x_ = None - _jvp_increment_nesting_1 = torch._C._functorch._jvp_increment_nesting(); _jvp_increment_nesting_1 = None + _jvp_increment_nesting_1 = torch._functorch.predispatch._jvp_increment_nesting(); _jvp_increment_nesting_1 = None _set_fwd_grad_enabled_1 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_1 = None _maybe_load_decompositions_1 = torch.autograd.forward_ad._maybe_load_decompositions(); _maybe_load_decompositions_1 = None - _make_dual_1: "f32[3, 3, 3]" = torch._make_dual(child, child, level = 0); child = None + _make_dual_1: "f32[3, 3, 3]" = torch._functorch.predispatch._make_dual(child, child, level = 0); child = None result_duals: "f32[3, 3, 3]" = torch.sin(_make_dual_1); _make_dual_1 = None - _unpack_dual = torch._unpack_dual(result_duals, level = 0); result_duals = None + _unpack_dual = torch._functorch.predispatch._unpack_dual(result_duals, level = 0); result_duals = None primal: "f32[3, 3, 3]" = _unpack_dual[0] dual: "f32[3, 3, 3]" = _unpack_dual[1]; _unpack_dual = None - primals_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None - tangents_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None + primals_out_unflatten: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(primal, 2); primal = None + tangents_out_unflatten: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(dual, 2); dual = None _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_2 = None - _jvp_decrement_nesting = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting = None + _jvp_decrement_nesting = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting = None - _unpack_dual_1 = torch._unpack_dual(primals_out_unflatten, level = 0); primals_out_unflatten = None + _unpack_dual_1 = torch._functorch.predispatch._unpack_dual(primals_out_unflatten, level = 0); primals_out_unflatten = None primal_1: "f32[3, 3, 3]" = _unpack_dual_1[0] dual_1: "f32[3, 3, 3]" = _unpack_dual_1[1]; _unpack_dual_1 = None - _unpack_dual_2 = torch._unpack_dual(tangents_out_unflatten, level = 0); tangents_out_unflatten = None + _unpack_dual_2 = torch._functorch.predispatch._unpack_dual(tangents_out_unflatten, level = 0); tangents_out_unflatten = None primal_2: "f32[3, 3, 3]" = _unpack_dual_2[0] dual_2: "f32[3, 3, 3]" = _unpack_dual_2[1]; _unpack_dual_2 = None - _unwrap_for_grad_2: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 1); primal_1 = None - _unwrap_for_grad_3: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_2, 1); primal_2 = None - _unwrap_for_grad_4: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_1, 1); dual_1 = None - _unwrap_for_grad_5: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_2, 1); dual_2 = None + _unwrap_for_grad_2: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(primal_1, 1); primal_1 = None + _unwrap_for_grad_3: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(primal_2, 1); primal_2 = None + _unwrap_for_grad_4: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(dual_1, 1); dual_1 = None + _unwrap_for_grad_5: "f32[3, 3, 3]" = torch._functorch.predispatch._unwrap_for_grad(dual_2, 1); dual_2 = None - _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None + _exit_dual_level = torch._functorch.predispatch._exit_dual_level(level = 0); _exit_dual_level = None _set_fwd_grad_enabled_3 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_3 = None - _jvp_decrement_nesting_1 = torch._C._functorch._jvp_decrement_nesting(); _jvp_decrement_nesting_1 = None + _jvp_decrement_nesting_1 = torch._functorch.predispatch._jvp_decrement_nesting(); _jvp_decrement_nesting_1 = None return (_unwrap_for_grad_2, _unwrap_for_grad_3, _unwrap_for_grad_4, _unwrap_for_grad_5) """, ) @@ -6902,9 +6834,7 @@ def wrapper_fn(x): self.assertEqual(expected, actual) -class ActivationCheckpointingTests( - torch._dynamo.test_case.TestCaseWithNestedGraphBreaks -): +class ActivationCheckpointingTests(torch._dynamo.test_case.TestCase): def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True): cloned_args = [] for arg in args: @@ -7291,11 +7221,10 @@ def false_branch(x): # inductor "while_loop", # LoweringException: AssertionError "flex_attention", # LoweringException: AssertionError - "flex_attention_backward", # AssertionError: Input shapes should have M >= 16, N >= 16 and K >= 16 } -class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks): +class TestHigherOrderOpsOpInfo(torch._dynamo.test_case.TestCase): @requires_cuda_and_triton @parametrize("backend", ("aot_eager", "inductor")) @ops( diff --git a/test/dynamo/test_hooks.py b/test/dynamo/test_hooks.py index 5e0c90729d3f5..14935efe10636 100644 --- a/test/dynamo/test_hooks.py +++ b/test/dynamo/test_hooks.py @@ -370,7 +370,7 @@ def forward(self, L_x_ : torch.Tensor): a = getitem * 3 add = a + getitem; a = None sum_1 = add.sum(); add = None - return (sum_1, getitem)""", # noqa: B950 + return (sum_1, getitem)""", ) def test_hook_on_intermediate_used_before_and_after(self): @@ -522,7 +522,7 @@ def forward(self, L_x_ : torch.Tensor): sum_1 = result.sum(); result = None sum_2 = getitem_3.sum(); getitem_3 = None add = sum_1 + sum_2; sum_1 = sum_2 = None - return (add,)""", # noqa: B950 + return (add,)""", ) def test_intermediary_hooks(self): @@ -915,7 +915,7 @@ def forward(self, x): comp_out[0].backward(torch.ones(4)) self.assertEqual(cnts.frame_count, 1) - my_hook = my_hook2 # noqa: F811 + my_hook = my_hook2 self.assertEqual(x0.grad, x1.grad) eager_out = mod(x0) @@ -1017,7 +1017,6 @@ def run(input): x.grad = None run(i).sum().backward() - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) def test_no_recompile_on_same_hook(self): cnts = torch._dynamo.testing.CompileCounter() diff --git a/test/dynamo/test_inline_and_install.py b/test/dynamo/test_inline_and_install.py index e484ebaf9de51..157bea9b9c973 100644 --- a/test/dynamo/test_inline_and_install.py +++ b/test/dynamo/test_inline_and_install.py @@ -23,7 +23,6 @@ def make_dynamic_cls(cls): cls_prefix, suffix, (config, "install_free_tensors", True), - (config, "inline_inbuilt_nn_modules", True), xfail_prop="_expected_failure_inline_and_install", ) diff --git a/test/dynamo/test_input_attr_tracking.py b/test/dynamo/test_input_attr_tracking.py index 57734086729d2..2d7c19feec39e 100644 --- a/test/dynamo/test_input_attr_tracking.py +++ b/test/dynamo/test_input_attr_tracking.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -# flake8: noqa: B950 import torch import torch._dynamo import torch._dynamo.test_case diff --git a/test/dynamo/test_install_free_tensors.py b/test/dynamo/test_install_free_tensors.py index 43160f8cdb7a2..438ad5c58e2e1 100644 --- a/test/dynamo/test_install_free_tensors.py +++ b/test/dynamo/test_install_free_tensors.py @@ -80,7 +80,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class InstallParamsAsGraphAttrTests(torch._dynamo.test_case.TestCase): - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) @torch._dynamo.config.patch(install_free_tensors=False) def check_num_inputs_and_equality_no_install( self, @@ -99,7 +98,6 @@ def check_num_inputs_and_equality_no_install( self.assertEqual(actual_num_inputs, expected_num_inline_inputs) self.assertEqual(opt_fn(*example_inputs), fn_to_compile(*example_inputs)) - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) @torch._dynamo.config.patch(install_free_tensors=True) def check_num_inputs_and_equality_install( self, @@ -333,7 +331,6 @@ def test_linear_explicit( class InstallParamsWhenExport(torch._dynamo.test_case.TestCase): - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) @torch._dynamo.config.patch(install_free_tensors=True) def check_export_matches_expectation( self, @@ -476,7 +473,6 @@ def fn(a: torch.Tensor) -> torch.Tensor: inp = torch.randn((5, 5)) self.check_export_matches_expectation(fn, 1, (inp,)) - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True) @torch._dynamo.config.patch(install_free_tensors=True) def test_modify_net_state(self) -> None: class Mod(torch.nn.Module): diff --git a/test/dynamo/test_len_protocol.py b/test/dynamo/test_len_protocol.py new file mode 100644 index 0000000000000..6b35b37473f45 --- /dev/null +++ b/test/dynamo/test_len_protocol.py @@ -0,0 +1,1297 @@ +# Owner(s): ["module: dynamo"] + +""" +Comprehensive tests for len() builtin and __len__() method protocol in PyTorch Dynamo. + +Tests cover: +- len(obj) builtin calls +- obj.__len__() method calls +- type(obj).__len__(obj) unbound method calls (marked as expectedFailure - Dynamo limitation) +- Various container types: list, tuple, dict, set, frozenset, range, str, Tensor, nn.Module +- Dict views: keys(), values(), items() +- User-defined classes with __len__ +""" + +import collections +import dataclasses +import types + +import torch +import torch._dynamo.test_case +from torch.testing._internal.common_utils import make_dynamo_test + + +class _BaseSequenceLen: + """Base class for testing len() on sequence types (list, tuple)""" + + thetype = None # Override in subclass + + def setUp(self): + if self.thetype is None: + self.skipTest("Base class - not meant to be run directly") + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_basic(self): + seq = self.thetype([1, 2, 3]) + self.assertEqual(len(seq), 3) + self.assertEqual(seq.__len__(), 3) + + @make_dynamo_test + def test_len_empty(self): + seq = self.thetype([]) + self.assertEqual(len(seq), 0) + self.assertEqual(seq.__len__(), 0) + + @make_dynamo_test + def test_len_single_element(self): + seq = self.thetype([42]) + self.assertEqual(len(seq), 1) + self.assertEqual(seq.__len__(), 1) + + @make_dynamo_test + def test_len_nested(self): + inner = self.thetype([1, 2]) + seq = self.thetype([inner, inner, inner]) + self.assertEqual(len(seq), 3) + self.assertEqual(seq.__len__(), 3) + + @make_dynamo_test + def test_len_with_tensors(self): + seq = self.thetype([torch.tensor(1), torch.tensor(2), torch.tensor(3)]) + self.assertEqual(len(seq), 3) + self.assertEqual(seq.__len__(), 3) + + @make_dynamo_test + def test_len_with_mixed_types(self): + seq = self.thetype([1, "hello", 3.14, torch.tensor(4)]) + self.assertEqual(len(seq), 4) + self.assertEqual(seq.__len__(), 4) + + @make_dynamo_test + def test_len_large(self): + seq = self.thetype(range(100)) + self.assertEqual(len(seq), 100) + self.assertEqual(seq.__len__(), 100) + + +class TestListLen(_BaseSequenceLen, torch._dynamo.test_case.TestCase): + """Tests for len() on list objects""" + + thetype = list + + +class TestTupleLen(_BaseSequenceLen, torch._dynamo.test_case.TestCase): + """Tests for len() on tuple objects""" + + thetype = tuple + + +class _BaseMappingLen: + """Base class for testing len() on mapping types (dict, OrderedDict)""" + + def get_mapping(self, items): + """Override in subclass to return appropriate mapping type""" + return dict(items) + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_basic(self): + d = self.get_mapping({1: "a", 2: "b", 3: "c"}.items()) + self.assertEqual(len(d), 3) + self.assertEqual(d.__len__(), 3) + + @make_dynamo_test + def test_len_empty(self): + d = self.get_mapping({}.items()) + self.assertEqual(len(d), 0) + self.assertEqual(d.__len__(), 0) + + @make_dynamo_test + def test_len_single_entry(self): + d = self.get_mapping({"key": "value"}.items()) + self.assertEqual(len(d), 1) + self.assertEqual(d.__len__(), 1) + + @make_dynamo_test + def test_len_string_keys(self): + d = self.get_mapping({"one": 1, "two": 2, "three": 3}.items()) + self.assertEqual(len(d), 3) + self.assertEqual(d.__len__(), 3) + + @make_dynamo_test + def test_len_int_keys(self): + d = self.get_mapping({1: "one", 2: "two", 3: "three", 4: "four"}.items()) + self.assertEqual(len(d), 4) + self.assertEqual(d.__len__(), 4) + + @make_dynamo_test + def test_len_with_tensor_values(self): + d = self.get_mapping( + {"a": torch.tensor(1), "b": torch.tensor(2), "c": torch.tensor(3)}.items() + ) + self.assertEqual(len(d), 3) + self.assertEqual(d.__len__(), 3) + + @make_dynamo_test + def test_len_large(self): + d = self.get_mapping({i: i * 2 for i in range(20)}.items()) + self.assertEqual(len(d), 20) + self.assertEqual(d.__len__(), 20) + + +class TestDictLen(_BaseMappingLen, torch._dynamo.test_case.TestCase): + """Tests for len() on dict objects""" + + +class TestOrderedDictLen(_BaseMappingLen, torch._dynamo.test_case.TestCase): + """Tests for len() on OrderedDict objects""" + + def get_mapping(self, items): + return collections.OrderedDict(items) + + +class TestDefaultDictLen(_BaseMappingLen, torch._dynamo.test_case.TestCase): + """Tests for len() on defaultdict objects""" + + def get_mapping(self, items): + d = collections.defaultdict(int) + for k, v in items: + d[k] = v + return d + + +class _BaseSetLen: + """Base class for testing len() on set types""" + + __test__ = False # Prevent pytest from collecting this as a test class + + def get_set(self, items): + """Override in subclass to return appropriate set type""" + return set(items) + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_basic(self): + s = self.get_set([1, 2, 3]) + self.assertEqual(len(s), 3) + self.assertEqual(s.__len__(), 3) + + @make_dynamo_test + def test_len_empty(self): + s = self.get_set([]) + self.assertEqual(len(s), 0) + self.assertEqual(s.__len__(), 0) + + @make_dynamo_test + def test_len_single_element(self): + s = self.get_set([42]) + self.assertEqual(len(s), 1) + self.assertEqual(s.__len__(), 1) + + @make_dynamo_test + def test_len_with_strings(self): + s = self.get_set(["a", "b", "c", "d"]) + self.assertEqual(len(s), 4) + self.assertEqual(s.__len__(), 4) + + @make_dynamo_test + def test_len_with_duplicates(self): + # Set constructor deduplicates + s = self.get_set([1, 2, 2, 3, 3, 3]) + self.assertEqual(len(s), 3) + self.assertEqual(s.__len__(), 3) + + @make_dynamo_test + def test_len_large(self): + s = self.get_set(range(50)) + self.assertEqual(len(s), 50) + self.assertEqual(s.__len__(), 50) + + +class TestSetLen(_BaseSetLen, torch._dynamo.test_case.TestCase): + """Tests for len() on set objects""" + + def get_set(self, items): + return set(items) + + +class TestFrozenSetLen(_BaseSetLen, torch._dynamo.test_case.TestCase): + """Tests for len() on frozenset objects""" + + def get_set(self, items): + return frozenset(items) + + +class TestRangeLen(torch._dynamo.test_case.TestCase): + """Tests for len() on range objects""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_basic(self): + r = range(5) + self.assertEqual(len(r), 5) + self.assertEqual(r.__len__(), 5) + self.assertEqual(range.__len__(r), 5) + + @make_dynamo_test + def test_len_with_start_stop(self): + r = range(5, 15) + self.assertEqual(len(r), 10) + self.assertEqual(r.__len__(), 10) + self.assertEqual(range.__len__(r), 10) + + @make_dynamo_test + def test_len_with_step(self): + r = range(0, 10, 2) + self.assertEqual(len(r), 5) + self.assertEqual(r.__len__(), 5) + self.assertEqual(range.__len__(r), 5) + + @make_dynamo_test + def test_len_negative_step(self): + r = range(10, 0, -1) + self.assertEqual(len(r), 10) + self.assertEqual(r.__len__(), 10) + self.assertEqual(range.__len__(r), 10) + + @make_dynamo_test + def test_len_empty(self): + r = range(5, 5) + self.assertEqual(len(r), 0) + self.assertEqual(r.__len__(), 0) + self.assertEqual(range.__len__(r), 0) + + +class TestStringLen(torch._dynamo.test_case.TestCase): + """Tests for len() on string objects""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_empty(self): + s = "" + self.assertEqual(len(s), 0) + self.assertEqual(s.__len__(), 0) + self.assertEqual(str.__len__(s), 0) + + @make_dynamo_test + def test_len_single_char(self): + s = "a" + self.assertEqual(len(s), 1) + self.assertEqual(s.__len__(), 1) + self.assertEqual(str.__len__(s), 1) + + @make_dynamo_test + def test_len_multiple_chars(self): + s = "hello" + self.assertEqual(len(s), 5) + self.assertEqual(s.__len__(), 5) + self.assertEqual(str.__len__(s), 5) + + @make_dynamo_test + def test_len_with_spaces(self): + s = "hello world" + self.assertEqual(len(s), 11) + self.assertEqual(s.__len__(), 11) + self.assertEqual(str.__len__(s), 11) + + +class TestTensorLen(torch._dynamo.test_case.TestCase): + """Tests for len() on torch.Tensor objects""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_1d(self): + t = torch.tensor([1, 2, 3, 4, 5]) + self.assertEqual(len(t), 5) + self.assertEqual(t.__len__(), 5) + self.assertEqual(torch.Tensor.__len__(t), 5) + + @make_dynamo_test + def test_len_2d(self): + t = torch.tensor([[1, 2], [3, 4], [5, 6]]) + self.assertEqual(len(t), 3) + self.assertEqual(t.__len__(), 3) + self.assertEqual(torch.Tensor.__len__(t), 3) + + @make_dynamo_test + def test_len_3d(self): + t = torch.randn(4, 5, 6) + self.assertEqual(len(t), 4) + self.assertEqual(t.__len__(), 4) + self.assertEqual(torch.Tensor.__len__(t), 4) + + @make_dynamo_test + def test_len_empty(self): + t = torch.tensor([]) + self.assertEqual(len(t), 0) + self.assertEqual(t.__len__(), 0) + self.assertEqual(torch.Tensor.__len__(t), 0) + + @make_dynamo_test + def test_len_large_batch(self): + t = torch.randn(100, 5, 5) + self.assertEqual(len(t), 100) + self.assertEqual(t.__len__(), 100) + self.assertEqual(torch.Tensor.__len__(t), 100) + + @make_dynamo_test + def test_len_different_dtypes(self): + t = torch.tensor([1, 2, 3], dtype=torch.float32) + self.assertEqual(len(t), 3) + self.assertEqual(t.__len__(), 3) + self.assertEqual(torch.Tensor.__len__(t), 3) + + +class TestNNModuleLen(torch._dynamo.test_case.TestCase): + """Tests for len() on torch.nn module containers""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + # Pre-construct nn.Module instances outside compiled regions + self.seq_3layers = torch.nn.Sequential( + torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 5) + ) + self.empty_seq = torch.nn.Sequential() + self.ml_2modules = torch.nn.ModuleList( + [torch.nn.Linear(10, 20), torch.nn.Linear(20, 30)] + ) + self.empty_ml = torch.nn.ModuleList() + self.md_2modules = torch.nn.ModuleDict( + {"layer1": torch.nn.Linear(10, 20), "layer2": torch.nn.Linear(20, 5)} + ) + self.empty_md = torch.nn.ModuleDict() + self.seq_5layers = torch.nn.Sequential( + *[torch.nn.Linear(10, 10) for _ in range(5)] + ) + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_sequential(self): + seq = self.seq_3layers + self.assertEqual(len(seq), 3) + self.assertEqual(seq.__len__(), 3) + + @make_dynamo_test + def test_len_empty_sequential(self): + seq = self.empty_seq + self.assertEqual(len(seq), 0) + self.assertEqual(seq.__len__(), 0) + + @make_dynamo_test + def test_len_module_list(self): + ml = self.ml_2modules + self.assertEqual(len(ml), 2) + self.assertEqual(ml.__len__(), 2) + + @make_dynamo_test + def test_len_empty_module_list(self): + ml = self.empty_ml + self.assertEqual(len(ml), 0) + self.assertEqual(ml.__len__(), 0) + + @make_dynamo_test + def test_len_module_dict(self): + md = self.md_2modules + self.assertEqual(len(md), 2) + self.assertEqual(md.__len__(), 2) + + @make_dynamo_test + def test_len_empty_module_dict(self): + md = self.empty_md + self.assertEqual(len(md), 0) + self.assertEqual(md.__len__(), 0) + + @make_dynamo_test + def test_len_sequential_with_multiple_layers(self): + seq = self.seq_5layers + self.assertEqual(len(seq), 5) + self.assertEqual(seq.__len__(), 5) + + +class TestDictViewLen(torch._dynamo.test_case.TestCase): + """Tests for len() on dict view objects (keys, values, items)""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_dict_keys_view(self): + d = {"a": 1, "b": 2, "c": 3} + keys = d.keys() + self.assertEqual(len(keys), 3) + self.assertEqual(keys.__len__(), 3) + + @make_dynamo_test + def test_len_dict_values_view(self): + d = {"a": 1, "b": 2, "c": 3} + values = d.values() + self.assertEqual(len(values), 3) + self.assertEqual(values.__len__(), 3) + + @make_dynamo_test + def test_len_dict_items_view(self): + d = {"x": 10, "y": 20, "z": 30} + items = d.items() + self.assertEqual(len(items), 3) + self.assertEqual(items.__len__(), 3) + + @make_dynamo_test + def test_len_dict_keys_empty(self): + d = {} + keys = d.keys() + self.assertEqual(len(keys), 0) + self.assertEqual(keys.__len__(), 0) + + @make_dynamo_test + def test_len_dict_values_empty(self): + d = {} + values = d.values() + self.assertEqual(len(values), 0) + self.assertEqual(values.__len__(), 0) + + @make_dynamo_test + def test_len_dict_items_empty(self): + d = {} + items = d.items() + self.assertEqual(len(items), 0) + self.assertEqual(items.__len__(), 0) + + @make_dynamo_test + def test_len_dict_keys_single_entry(self): + """Test len() on dict.keys() with single entry""" + d = {"key": "value"} + keys = d.keys() + self.assertEqual(len(keys), 1) + self.assertEqual(keys.__len__(), 1) + + @make_dynamo_test + def test_len_dict_values_single_entry(self): + """Test len() on dict.values() with single entry""" + d = {"key": "value"} + values = d.values() + self.assertEqual(len(values), 1) + self.assertEqual(values.__len__(), 1) + + @make_dynamo_test + def test_len_dict_items_single_entry(self): + """Test len() on dict.items() with single entry""" + d = {"key": "value"} + items = d.items() + self.assertEqual(len(items), 1) + self.assertEqual(items.__len__(), 1) + + @make_dynamo_test + def test_len_dict_keys_int_keys(self): + """Test len() on dict.keys() with integer keys""" + d = {1: "one", 2: "two", 3: "three", 4: "four"} + keys = d.keys() + self.assertEqual(len(keys), 4) + self.assertEqual(keys.__len__(), 4) + + @make_dynamo_test + def test_len_dict_values_tensor_values(self): + """Test len() on dict.values() with tensor values""" + d = {"a": torch.tensor(1), "b": torch.tensor(2), "c": torch.tensor(3)} + values = d.values() + self.assertEqual(len(values), 3) + self.assertEqual(values.__len__(), 3) + + @make_dynamo_test + def test_len_dict_items_mixed_types(self): + """Test len() on dict.items() with mixed key/value types""" + d = {"str": "value", 42: torch.tensor(1), (1, 2): "tuple_key"} + items = d.items() + self.assertEqual(len(items), 3) + self.assertEqual(items.__len__(), 3) + + @make_dynamo_test + def test_len_dict_keys_large(self): + """Test len() on dict.keys() with large number of entries""" + d = {i: i * 2 for i in range(50)} + keys = d.keys() + self.assertEqual(len(keys), 50) + self.assertEqual(keys.__len__(), 50) + + @make_dynamo_test + def test_len_dict_values_large(self): + """Test len() on dict.values() with large number of entries""" + d = {i: i * 2 for i in range(50)} + values = d.values() + self.assertEqual(len(values), 50) + self.assertEqual(values.__len__(), 50) + + @make_dynamo_test + def test_len_dict_items_large(self): + """Test len() on dict.items() with large number of entries""" + d = {i: i * 2 for i in range(50)} + items = d.items() + self.assertEqual(len(items), 50) + self.assertEqual(items.__len__(), 50) + + +# User-defined classes for TestUserDefinedLen +class CustomList: + def __init__(self, items): + self.items = items + + def __len__(self): + return len(self.items) + + +class CustomContainer: + def __len__(self): + return 0 + + +class Container: + def __init__(self, size): + self.size = size + + def __len__(self): + return self.size + + +class FixedSize: + def __len__(self): + return 10 + + +class ListWrapper: + def __init__(self): + self.data = [1, 2, 3, 4, 5] + + def __len__(self): + return len(self.data) + + +class ListSubclassCustomLen(list): + def __len__(self): + return super().__len__() * 2 + + +class CustomMapping: + """A user-defined mapping (dict-like) class""" + + def __init__(self, data): + self._data = dict(data) if not isinstance(data, dict) else data + + def __len__(self): + return len(self._data) + + def __getitem__(self, key): + return self._data[key] + + def keys(self): + return self._data.keys() + + def values(self): + return self._data.values() + + def items(self): + return self._data.items() + + +class CustomMappingSubclass(CustomMapping): + """A subclass of CustomMapping""" + + def __len__(self): + # Custom len implementation (e.g., filtered length) + return super().__len__() + 1 + + +class SimpleDictLike: + """Minimal dict-like class with just __len__""" + + def __init__(self, size): + self.size = size + + def __len__(self): + return self.size + + +class TupleSubclassCustomLen(tuple): + __slots__ = () + + def __len__(self): + return super().__len__() + 1 + + +class DictSubclassCustomLen(dict): + def __len__(self): + return super().__len__() - 1 + + +class SetSubclassCustomLen(set): + def __len__(self): + return 0 + + +class TestUserDefinedLen(torch._dynamo.test_case.TestCase): + """Tests for len() on user-defined classes with __len__""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_simple_custom_class(self): + obj = CustomList([1, 2, 3]) + self.assertEqual(len(obj), 3) + self.assertEqual(obj.__len__(), 3) + self.assertEqual(CustomList.__len__(obj), 3) + + @make_dynamo_test + def test_len_custom_class_empty(self): + obj = CustomContainer() + self.assertEqual(len(obj), 0) + self.assertEqual(obj.__len__(), 0) + self.assertEqual(CustomContainer.__len__(obj), 0) + + @make_dynamo_test + def test_len_custom_class_with_properties(self): + obj = Container(42) + self.assertEqual(len(obj), 42) + self.assertEqual(obj.__len__(), 42) + self.assertEqual(Container.__len__(obj), 42) + + @make_dynamo_test + def test_len_custom_class_constant_return(self): + obj = FixedSize() + self.assertEqual(len(obj), 10) + self.assertEqual(obj.__len__(), 10) + self.assertEqual(FixedSize.__len__(obj), 10) + + @make_dynamo_test + def test_len_custom_class_with_list_attr(self): + obj = ListWrapper() + self.assertEqual(len(obj), 5) + self.assertEqual(obj.__len__(), 5) + self.assertEqual(ListWrapper.__len__(obj), 5) + + +class TestSubclassOverloadedLen(torch._dynamo.test_case.TestCase): + """Tests for custom classes that inherit from builtins and overload __len__""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_list_subclass_custom_len(self): + obj = ListSubclassCustomLen([1, 2, 3]) + self.assertEqual(len(obj), 6) + self.assertEqual(obj.__len__(), 6) + self.assertEqual(ListSubclassCustomLen.__len__(obj), 6) + + @make_dynamo_test + def test_tuple_subclass_custom_len(self): + obj = TupleSubclassCustomLen([1, 2, 3]) + self.assertEqual(len(obj), 4) + self.assertEqual(obj.__len__(), 4) + self.assertEqual(TupleSubclassCustomLen.__len__(obj), 4) + + @make_dynamo_test + def test_dict_subclass_custom_len(self): + obj = DictSubclassCustomLen({"a": 1, "b": 2, "c": 3}) + self.assertEqual(len(obj), 2) + self.assertEqual(obj.__len__(), 2) + self.assertEqual(DictSubclassCustomLen.__len__(obj), 2) + + @make_dynamo_test + def test_set_subclass_custom_len(self): + obj = SetSubclassCustomLen([1, 2, 3]) + self.assertEqual(len(obj), 0) + self.assertEqual(obj.__len__(), 0) + self.assertEqual(SetSubclassCustomLen.__len__(obj), 0) + + +class DescriptorLenClass: + """Test class with __len__ as a regular instance method""" + + def __len__(self): + """Regular instance method __len__ - should return 40""" + return 40 + + +class PartialLenClass: + """Class where __len__ is a lambda/callable object""" + + def __init__(self): + self._items = [1, 2, 3, 4, 5] + + def __len__(self): + return len(self._items) + + +class StaticMethodLenClass: + """Test class where __len__ is a staticmethod (unusual, likely to fail)""" + + @staticmethod + def __len__(): + """Staticmethod __len__ - unusual pattern""" + return 10 + + +class ClassMethodLenClass: + """Test class where __len__ is a classmethod (unusual, likely to fail)""" + + @classmethod + def __len__(cls): + """Classmethod __len__ - unusual pattern""" + return 20 + + +class CustomDescriptorLenClass: + """Test class where __len__ is a custom descriptor""" + + class CustomDescriptorLen: + """Custom descriptor that implements __get__ method""" + + def __get__(self, obj, objtype=None): + """Descriptor protocol: return a callable that returns len""" + if obj is None: + return self + return lambda: 50 + + __len__ = CustomDescriptorLen() + + +class TestUserDefinedMappingLen(torch._dynamo.test_case.TestCase): + """Tests for len() on user-defined mapping (dict-like) classes""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_custom_mapping_basic(self): + """Test len() on a basic custom mapping class""" + m = CustomMapping({"a": 1, "b": 2, "c": 3}) + self.assertEqual(len(m), 3) + self.assertEqual(m.__len__(), 3) + + @make_dynamo_test + def test_len_custom_mapping_empty(self): + """Test len() on an empty custom mapping""" + m = CustomMapping({}) + self.assertEqual(len(m), 0) + self.assertEqual(m.__len__(), 0) + + @make_dynamo_test + def test_len_custom_mapping_single_entry(self): + """Test len() on a custom mapping with single entry""" + m = CustomMapping({"key": "value"}) + self.assertEqual(len(m), 1) + self.assertEqual(m.__len__(), 1) + + @make_dynamo_test + def test_len_custom_mapping_string_keys(self): + """Test len() on a custom mapping with string keys""" + m = CustomMapping({"one": 1, "two": 2, "three": 3}) + self.assertEqual(len(m), 3) + self.assertEqual(m.__len__(), 3) + + @make_dynamo_test + def test_len_custom_mapping_int_keys(self): + """Test len() on a custom mapping with int keys""" + m = CustomMapping({1: "one", 2: "two", 3: "three", 4: "four"}) + self.assertEqual(len(m), 4) + self.assertEqual(m.__len__(), 4) + + @make_dynamo_test + def test_len_custom_mapping_with_tensor_values(self): + """Test len() on a custom mapping with tensor values""" + m = CustomMapping( + {"a": torch.tensor(1), "b": torch.tensor(2), "c": torch.tensor(3)} + ) + self.assertEqual(len(m), 3) + self.assertEqual(m.__len__(), 3) + + @make_dynamo_test + def test_len_custom_mapping_large(self): + """Test len() on a large custom mapping""" + m = CustomMapping({i: i * 2 for i in range(20)}) + self.assertEqual(len(m), 20) + self.assertEqual(m.__len__(), 20) + + @make_dynamo_test + def test_len_custom_mapping_subclass(self): + """Test len() on a subclass of CustomMapping""" + m = CustomMappingSubclass({"a": 1, "b": 2}) + # CustomMappingSubclass.__len__ returns len + 1 + self.assertEqual(len(m), 3) + self.assertEqual(m.__len__(), 3) + + @make_dynamo_test + def test_len_simple_dict_like(self): + """Test len() on a minimal dict-like class""" + m = SimpleDictLike(42) + self.assertEqual(len(m), 42) + self.assertEqual(m.__len__(), 42) + + +class TestDescriptorLenImpl(torch._dynamo.test_case.TestCase): + """Test that len_impl handles descriptor-based __len__ correctly""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_regular_instance_len(self): + """Test regular instance method __len__""" + obj = DescriptorLenClass() + + # Regular instance methods should work fine + self.assertEqual(len(obj), 40) + self.assertEqual(obj.__len__(), 40) + + @make_dynamo_test + def test_partial_callable_len(self): + """Test lambda/callable as __len__""" + obj = PartialLenClass() + + # Callable __len__ should work + self.assertEqual(len(obj), 5) + self.assertEqual(obj.__len__(), 5) + + @make_dynamo_test + def test_staticmethod_len_works(self): + """Test that staticmethod as __len__ actually works (unusual but supported) + + Staticmethods are descriptors that don't bind. CPython resolves the + descriptor and calls the underlying function without passing self. + """ + obj = StaticMethodLenClass() + + # Surprisingly, CPython's descriptor protocol makes this work + # staticmethod.__get__ returns the unwrapped function + self.assertEqual(len(obj), 10) + + @make_dynamo_test + def test_classmethod_len_works(self): + """Test that classmethod as __len__ actually works (unusual but supported) + + Classmethods are descriptors that bind to the class. CPython's + descriptor protocol handles this and passes the class instead of instance. + """ + obj = ClassMethodLenClass() + + # Surprisingly, CPython's descriptor protocol makes this work + # classmethod.__get__ returns a bound method with the class + self.assertEqual(len(obj), 20) + + @make_dynamo_test + def test_custom_descriptor_len(self): + """Test custom descriptor with __get__ method as __len__ + + Custom descriptors implement the descriptor protocol via __get__. + The descriptor is resolved and the returned callable is used as __len__. + """ + obj = CustomDescriptorLenClass() + + # Custom descriptor's __get__ returns a callable that returns 50 + self.assertEqual(len(obj), 50) + + +class TestRaisesTypeError(torch._dynamo.test_case.TestCase): + """Tests for types that don't support len() - should raise TypeError like Python""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_slice_raises_type_error(self): + """slice objects do not support len() - should raise TypeError""" + s = slice(1, 5) + with self.assertRaises(TypeError): + len(s) + + @make_dynamo_test + def test_len_slice_with_step_raises_type_error(self): + """slice with step also raises TypeError""" + s = slice(0, 10, 2) + with self.assertRaises(TypeError): + len(s) + + @make_dynamo_test + def test_len_list_iterator_raises_type_error(self): + """list iterator does not support len() - should raise TypeError""" + it = iter([1, 2, 3]) + with self.assertRaises(TypeError): + len(it) + + @make_dynamo_test + def test_len_empty_list_iterator_raises_type_error(self): + """empty list iterator also raises TypeError""" + it = iter([]) + with self.assertRaises(TypeError): + len(it) + + @make_dynamo_test + def test_len_tuple_iterator_raises_type_error(self): + """tuple iterator does not support len() - should raise TypeError""" + it = iter((1, 2, 3)) + with self.assertRaises(TypeError): + len(it) + + @make_dynamo_test + def test_len_range_iterator_raises_type_error(self): + """range iterator does not support len() - should raise TypeError""" + it = iter(range(5)) + with self.assertRaises(TypeError): + len(it) + + @make_dynamo_test + def test_len_dict_iterator_raises_type_error(self): + """dict iterator (keys) does not support len() - should raise TypeError""" + d = {"a": 1, "b": 2} + it = iter(d) + with self.assertRaises(TypeError): + len(it) + + +class TestDequeLen(torch._dynamo.test_case.TestCase): + """Tests for len() on collections.deque objects""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_basic(self): + d = collections.deque([1, 2, 3]) + self.assertEqual(len(d), 3) + self.assertEqual(d.__len__(), 3) + + @make_dynamo_test + def test_len_empty(self): + d = collections.deque([]) + self.assertEqual(len(d), 0) + self.assertEqual(d.__len__(), 0) + + @make_dynamo_test + def test_len_single_element(self): + d = collections.deque([42]) + self.assertEqual(len(d), 1) + self.assertEqual(d.__len__(), 1) + + @make_dynamo_test + def test_len_with_maxlen(self): + d = collections.deque([1, 2, 3, 4, 5], maxlen=3) + self.assertEqual(len(d), 3) + self.assertEqual(d.__len__(), 3) + + @make_dynamo_test + def test_len_with_strings(self): + d = collections.deque(["a", "b", "c", "d"]) + self.assertEqual(len(d), 4) + self.assertEqual(d.__len__(), 4) + + @make_dynamo_test + def test_len_large(self): + d = collections.deque(range(50)) + self.assertEqual(len(d), 50) + self.assertEqual(d.__len__(), 50) + + +class TestMappingProxyLen(torch._dynamo.test_case.TestCase): + """Tests for len() on types.MappingProxyType objects""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_basic(self): + """Test len() on a basic MappingProxyType""" + d = types.MappingProxyType({"a": 1, "b": 2, "c": 3}) + self.assertEqual(len(d), 3) + self.assertEqual(d.__len__(), 3) + + @make_dynamo_test + def test_len_empty(self): + """Test len() on an empty MappingProxyType""" + d = types.MappingProxyType({}) + self.assertEqual(len(d), 0) + self.assertEqual(d.__len__(), 0) + + @make_dynamo_test + def test_len_single_entry(self): + """Test len() on a MappingProxyType with single entry""" + d = types.MappingProxyType({"key": "value"}) + self.assertEqual(len(d), 1) + self.assertEqual(d.__len__(), 1) + + @make_dynamo_test + def test_len_string_keys(self): + """Test len() on a MappingProxyType with string keys""" + d = types.MappingProxyType({"one": 1, "two": 2, "three": 3}) + self.assertEqual(len(d), 3) + self.assertEqual(d.__len__(), 3) + + @make_dynamo_test + def test_len_int_keys(self): + """Test len() on a MappingProxyType with int keys""" + d = types.MappingProxyType({1: "one", 2: "two", 3: "three", 4: "four"}) + self.assertEqual(len(d), 4) + self.assertEqual(d.__len__(), 4) + + @make_dynamo_test + def test_len_with_tensor_values(self): + """Test len() on a MappingProxyType with tensor values""" + d = types.MappingProxyType( + {"a": torch.tensor(1), "b": torch.tensor(2), "c": torch.tensor(3)} + ) + self.assertEqual(len(d), 3) + self.assertEqual(d.__len__(), 3) + + @make_dynamo_test + def test_len_large(self): + """Test len() on a large MappingProxyType""" + d = types.MappingProxyType({i: i * 2 for i in range(20)}) + self.assertEqual(len(d), 20) + self.assertEqual(d.__len__(), 20) + + +class MetaclassWithLen(type): + """A metaclass that defines __len__ on the class itself""" + + def __len__(cls): + """Return the number of items defined in the metaclass""" + return 5 + + +class SimpleMetaclassClass(metaclass=MetaclassWithLen): + """A class using the MetaclassWithLen metaclass""" + + +class TestMetaclassLen(torch._dynamo.test_case.TestCase): + """Tests for len() on metaclasses, classmethods, staticmethods, and properties""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_metaclass_len_basic(self): + """Test len() on a class with __len__ defined in metaclass""" + self.assertEqual(len(SimpleMetaclassClass), 5) + + @make_dynamo_test + def test_metaclass_len_direct_call(self): + """Test direct call to __len__() on a class with metaclass-defined __len__""" + self.assertEqual(SimpleMetaclassClass.__len__(), 5) + + +class CustomMutableMapping(collections.abc.MutableMapping): + """Custom mutable mapping implementation with __len__.""" + + def __init__(self, data=None): + self._data = data if data is not None else {} + + def __len__(self): + return len(self._data) + + def __getitem__(self, key): + return self._data[key] + + def __setitem__(self, key, value): + self._data[key] = value + + def __delitem__(self, key): + del self._data[key] + + def __iter__(self): + return iter(self._data) + + +@dataclasses.dataclass(frozen=True) +class FrozenPoint: + """Frozen dataclass with __len__ method.""" + + x: float + y: float + z: float + + def __len__(self): + return 3 + + +@dataclasses.dataclass(frozen=True) +class FrozenData: + """Frozen dataclass with __len__ based on items.""" + + items: tuple + + def __len__(self): + return len(self.items) + + +class TestMutableMappingLen(torch._dynamo.test_case.TestCase): + """Tests for len() on mutable mapping types.""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_custom_mutable_mapping(self): + """Test len() on custom mutable mapping.""" + m = CustomMutableMapping({"x": 10, "y": 20, "z": 30}) + self.assertEqual(len(m), 3) + self.assertEqual(m.__len__(), 3) + + @make_dynamo_test + def test_len_custom_mutable_mapping_empty(self): + """Test len() on empty custom mutable mapping.""" + m = CustomMutableMapping() + self.assertEqual(len(m), 0) + + +class TestFrozenDataclassLen(torch._dynamo.test_case.TestCase): + """Tests for len() on frozen dataclasses.""" + + def setUp(self): + self.old = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + super().setUp() + + def tearDown(self): + torch._dynamo.config.enable_trace_unittest = self.old + return super().tearDown() + + @make_dynamo_test + def test_len_frozen_dataclass_with_len(self): + """Test len() on frozen dataclass with custom __len__.""" + p = FrozenPoint(1.0, 2.0, 3.0) + self.assertEqual(len(p), 3) + self.assertEqual(p.__len__(), 3) + + @make_dynamo_test + def test_len_frozen_dataclass_via_tuple(self): + """Test len() on frozen dataclass with __len__ based on contained data.""" + obj = FrozenData((1, 2, 3, 4)) + self.assertEqual(len(obj), 4) + self.assertEqual(obj.__len__(), 4) + + @make_dynamo_test + def test_len_frozen_dataclass_consistency(self): + """Test that len() on frozen dataclass is consistent across multiple calls.""" + obj = FrozenData(("a", "b", "c")) + # Call len() twice to ensure consistency + len1 = len(obj) + len2 = len(obj) + self.assertEqual(len1, 3) + self.assertEqual(len2, 3) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_list.py b/test/dynamo/test_list.py index 85415244db69c..379bcb037b75f 100644 --- a/test/dynamo/test_list.py +++ b/test/dynamo/test_list.py @@ -261,7 +261,8 @@ def test_pop(self): p = self.thetype("abcd") self.assertEqual(p.pop(), "d") self.assertEqual(p.pop(1), "b") - self.assertRaises(IndexError, p.pop, 10) + # The length of p is now 2, valid indices are 0, 1 + self.assertRaises(IndexError, p.pop, 2) # Wrong number of arguments self.assertRaises(TypeError, p.pop, 2, 3) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 0a59c81e23a61..81ebb02d2d5f6 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -375,7 +375,7 @@ def test_dynamo_error(self, records): from user code: File "test_logging.py", line N, in dynamo_error_fn - output = output.add(torch.ones(10, 10))""", # noqa: B950 + output = output.add(torch.ones(10, 10))""", ) test_aot = within_range_record_test(2, 6, aot=logging.INFO) @@ -817,6 +817,48 @@ def fn(x, lst): self.assertEqual(len(records), 0) + @make_logging_test(side_effects=True) + def test_side_effects_logs_fullgraph_graph_break(self, records): + """Test that side effects are logged even when fullgraph=True causes an error.""" + my_list = [1, 2, 3] + + @torch.compile(backend="eager", fullgraph=True) + def fn(x, lst): + lst.append(4) + # Force a graph break after the side effect + torch._dynamo.graph_break() + return x + len(lst) + + with self.assertRaises(torch._dynamo.exc.Unsupported): + fn(torch.ones(1), my_list) + + # Side effects should still be logged even though codegen never ran + self.assertGreater(len(records), 0) + self.assertIn("Mutating object of type list", records[0].getMessage()) + + @make_logging_test(side_effects=True) + def test_side_effects_logged_on_fullgraph_side_effect_error(self, records): + tracked_list = [] + hop_list = [] + + def fn(x): + hop_list.append(1) + return x.sin() + + @torch.compile(backend="eager", fullgraph=True) + def model(x, lst): + lst.append(1) + return torch.utils.checkpoint.checkpoint(fn, x, use_reentrant=False) + + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + "HOP: Unsafe side effect", + ): + model(torch.ones(1), tracked_list) + + self.assertGreaterEqual(len(records), 1) + self.assertIn("Mutating object of type list", records[0].getMessage()) + @make_settings_test("torch._dynamo.utils") def test_dump_compile_times(self, records): fn_opt = torch.compile(example_fn, backend="inductor") @@ -1233,7 +1275,7 @@ def f(x, y, z): +- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in # +- __SHAPE_GUARD__: L['z'].size()[0] == L['y'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False) +- __SHAPE_GUARD__: ((2*L['y'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in # -+- __SHAPE_GUARD__: 2 <= L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.decorators.mark_unbacked(tensor, dim))""", # noqa: B950 ++- __SHAPE_GUARD__: 2 <= L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.decorators.mark_unbacked(tensor, dim))""", ) @make_logging_test(guards=True) @@ -1249,7 +1291,7 @@ def f(x, y): munge_shape_guards(record.getMessage()), """\ +- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['y'].size()[0] # return any([x.size(0) == y.size(0) * 2]) # #:# in # #:# in # -+- __SHAPE_GUARD__: 2 <= L['y'].size()[0] # return any([x.size(0) == y.size(0) * 2]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.decorators.mark_unbacked(tensor, dim))""", # noqa: B950 ++- __SHAPE_GUARD__: 2 <= L['y'].size()[0] # return any([x.size(0) == y.size(0) * 2]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.decorators.mark_unbacked(tensor, dim))""", ) @make_logging_test(guards=True) @@ -1268,7 +1310,7 @@ def f(x, y): munge_shape_guards(record.getMessage()), """\ +- __SHAPE_GUARD__: L['x'].size()[0] == 2*L['y'].size()[0] # torch._check(x.size(0) == y.size(0) * 2) # #:# in # #:# in # -+- __SHAPE_GUARD__: 3 <= L['y'].size()[0] <= 14 # torch._check(x.size(0) > 5) # #:# in # #:# in # and torch._check(x.size(0) < 30) # #:# in # #:# in #""", # noqa: B950 ++- __SHAPE_GUARD__: 3 <= L['y'].size()[0] <= 14 # torch._check(x.size(0) > 5) # #:# in # #:# in # and torch._check(x.size(0) < 30) # #:# in # #:# in #""", ) @make_logging_test(guards=True) @@ -1282,7 +1324,7 @@ def f(x): record = self.getRecord(records, "TREE_GUARD_MANAGER") self.assertExpectedInline( munge_global_state_json(record.getMessage()), - """+- GLOBAL_STATE: ___check_global_state() against {"allow_bf16_reduce": "#","allow_fp16_reduce": "#","allow_tf32": "#","autocast_state":{"cached_enabled": "#","dtype": "#","enabled": "#"},"default_dtype": "#","deterministic_algorithms": "#","deterministic_algorithms_warn_only": "#","grad_mode": "#","num_threads": "#","torch_function": "#","torch_function_all_disabled": "#"}""", # noqa: B950 + """+- GLOBAL_STATE: ___check_global_state() against {"allow_bf16_reduce": "#","allow_fp16_reduce": "#","allow_tf32": "#","autocast_state":{"cached_enabled": "#","dtype": "#","enabled": "#"},"default_dtype": "#","deterministic_algorithms": "#","deterministic_algorithms_warn_only": "#","grad_mode": "#","num_threads": "#","torch_function": "#","torch_function_all_disabled": "#"}""", ) @make_logging_test(cudagraph_static_inputs=True) @@ -1431,10 +1473,20 @@ def bar(): z = y * x return z - return bar(), bar + # force top-level trace of bar + try: + return bar(), bar + finally: + pass foo() + @torch.compile + def baz(x): + return x + 1 + + baz(torch.ones(3)) + # `_log_traced_frames` is registered as an atexit callback, so we invoke # it explicitly for testing. torch._dynamo.eval_frame._log_traced_frames() @@ -1449,6 +1501,7 @@ def bar(): TorchDynamo attempted to trace the following frames: [ * foo test_logging.py:N * bar test_logging.py:N + * baz test_logging.py:N ]""", ) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index dcba65f24d90a..8bf16dd67aad7 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -44,6 +44,7 @@ from torch import Tensor from torch._C import FileCheck from torch._dynamo import allow_in_graph +from torch._dynamo.comptime import comptime from torch._dynamo.eval_frame import _debug_get_cache_entry_list from torch._dynamo.exc import Unsupported from torch._dynamo.source import ConstantSource, GetItemSource, LocalSource @@ -80,7 +81,6 @@ PLATFORM_SUPPORTS_FLASH_ATTENTION, SM80OrLater, TEST_CUDA, - TEST_MULTIGPU, ) from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_methods_invocations import ( @@ -407,11 +407,14 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): x = torch.ones(5) with YoloMode(): - out = torch.compile(torch.add, backend=backend, fullgraph=True)(x, x) - + out = torch.compile(torch.add, backend=backend, fullgraph=False)(x, x) self.assertEqual(out.sum().item(), 5.0) self.assertEqual(len(backend.graphs), 0) + with YoloMode(): + with self.assertRaisesRegex(RuntimeError, "found no compiled frames"): + torch.compile(torch.add, backend=backend, fullgraph=True)(x, x) + def test_compile_non_infra_empty_with_disalloed_dispatch_mode(self): from torch.utils._python_dispatch import TorchDispatchMode @@ -570,7 +573,7 @@ def _should_skip_dynamo(cls): return False def __torch_dispatch__(self, func, types, args=(), kwargs=None): - out = torch.compile(func, backend=backend, fullgraph=True)( + out = torch.compile(func, backend=backend, fullgraph=False)( *args, **kwargs ) return out @@ -1029,6 +1032,102 @@ def f(x, flag): self.assertEqual(res, torch.ones(5) + 1) self.assertTrue(res.offloading_activation) + @unittest.skipIf( + not torch.cuda.is_available() or torch.cuda.get_device_capability() < (9, 0), + "requires Hopper+ (SM >= 9.0) for TMA", + ) + @unittest.skipIf( + not torch.utils._triton.has_triton() + or not hasattr(__import__("triton"), "set_allocator"), + "requires triton with set_allocator support", + ) + def test_triton_set_allocator_no_graph_break(self): + """set_allocator inside torch.compile does not graph break and + replays correctly at runtime (including cache hits).""" + import triton + import triton.language as tl + from triton.runtime._allocation import NullAllocator + + @triton.jit + def tma_copy_kernel( + x_ptr, + out_ptr, + M, + N, + stride_m, + stride_n, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + pid = tl.program_id(0) + desc = tl.make_tensor_descriptor( + x_ptr, + shape=[M, N], + strides=[stride_m, stride_n], + block_shape=[BLOCK_M, BLOCK_N], + ) + block = tl.load_tensor_descriptor(desc, [pid * BLOCK_M, 0]) + out_desc = tl.make_tensor_descriptor( + out_ptr, + shape=[M, N], + strides=[stride_m, stride_n], + block_shape=[BLOCK_M, BLOCK_N], + ) + tl.store_tensor_descriptor(out_desc, [pid * BLOCK_M, 0], block) + + M, N, BLOCK_M, BLOCK_N = 128, 64, 64, 64 + + def run_kernel(x): + out = torch.empty_like(x) + tma_copy_kernel[(M // BLOCK_M,)]( + x, + out, + M, + N, + x.stride(0), + x.stride(1), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + return out + + x = torch.randn(M, N, device="cuda") + + from contextlib import contextmanager + + from triton.runtime._allocation import _allocator + + @contextmanager + def triton_allocator(allocator): + prev = _allocator.get() + triton.set_allocator(allocator) + try: + yield + finally: + triton.set_allocator(prev) + + def fn_with_set_allocator(x): + triton.set_allocator( + lambda size, alignment, stream: torch.empty( + size, device="cuda", dtype=torch.int8 + ) + ) + return run_kernel(x) + + opt_fn = torch.compile( + fn_with_set_allocator, backend="aot_eager", fullgraph=True + ) + + # set_allocator inside compiled region does NOT graph break + with triton_allocator(NullAllocator()): + out = opt_fn(x) + self.assertEqual(out, x) + + # Verify set_allocator replays on cache hit (not just tracing) + triton.set_allocator(NullAllocator()) + out2 = opt_fn(x) + self.assertEqual(out2, x) + def test_closure_recompiles(self): cnt = CompileCounter() @@ -1578,6 +1677,59 @@ def fn(x): result = fn(torch.ones(1)) self.assertEqual(torch.ones(1) + 2, result) + def test_known_tensor_methods_traced(self): + # Verify that known tensor methods (in all_tensor_attrs) are still + # traced into the graph via the generic proxy path. + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt, fullgraph=True) + def fn(x): + return x.abs().cos() + + result = fn(torch.randn(4)) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(cnt.op_count, 2) + + def test_tensor_subclass_method_traced(self): + # Methods defined on the actual tensor class (including dynamically + # added ones) should be proxied through the generic call_method path, + # not graph-broken. This validates that the guard uses the concrete + # class_type rather than the static all_tensor_attrs dict. + def _dynamo_test_method(self): + return self + 1 + + with unittest.mock.patch.object( + torch.Tensor, "_dynamo_test_method", _dynamo_test_method, create=True + ): + cnt = CompileCounterWithBackend("eager") + + @torch.compile(backend=cnt) + def fn(x): + y = x._dynamo_test_method() + return y + 1 + + result = fn(torch.randn(4)) + self.assertEqual(cnt.frame_count, 1) + # Verify _dynamo_test_method appears as a call_method in the FX graph + call_method_targets = [ + n.target for n in cnt.graphs[0].graph.nodes if n.op == "call_method" + ] + self.assertIn("_dynamo_test_method", call_method_targets) + + def test_unknown_tensor_method_graph_break(self): + # Truly unknown methods raise AttributeError during tracing at + # LOAD_ATTR time (dynamic_getattr), ensuring dynamo does not + # silently proxy them into the compiled graph. + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt) + def fn(x): + y = x._nonexistent_test_method_xyz() + return y + 1 + + with self.assertRaises(AttributeError): + fn(torch.randn(4)) + def test_shape_unpack(self): def fn(x): a, b = x.size() @@ -4237,6 +4389,7 @@ def double_nested_call(x, y): o = torch.compile(foo, fullgraph=True, backend="eager")(x, y) self.assertEqual(o, x * y) + @torch._dynamo.config.patch(nested_graph_breaks=False) def test_module_deepcopy(self): m1 = torch.nn.Sequential( torch.nn.Linear(10, 10), @@ -4335,6 +4488,149 @@ def fn(x): result = torch.compile(fn, fullgraph=True)(x) self.assertEqual(result, correct) + def test_deepcopy_user_defined_object(self): + class MyConfig: + def __init__(self, hidden_size=64): + self.hidden_size = hidden_size + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = MyConfig() + self.linear = torch.nn.Linear(64, 64) + + def forward(self, x): + cfg = copy.deepcopy(self.config) + return self.linear(x) * cfg.hidden_size + + m = MyModule() + x = torch.randn(2, 64) + result = torch.compile(m, fullgraph=True, backend="eager")(x) + self.assertEqual(m(x), result) + + def test_deepcopy_user_defined_object_with_containers(self): + class Config: + def __init__(self): + self.sizes = [1, 2, 3] + self.mapping = {"a": 10, "b": 20} + self.flags = (True, False) + + def fn(x, cfg): + c = copy.deepcopy(cfg) + c.sizes[0] = 99 + c.mapping["a"] = 77 + return x + c.sizes[0] + c.mapping["a"] + + cfg = Config() + x = torch.randn(4) + correct = fn(x, cfg) + result = torch.compile(fn, fullgraph=True, backend="eager")(x, cfg) + self.assertEqual(result, correct) + # Verify deepcopy didn't mutate original + self.assertEqual(cfg.sizes[0], 1) + self.assertEqual(cfg.mapping["a"], 10) + + def test_deepcopy_set(self): + MY_SET = {1, 2, 3} + + def fn(x): + s = copy.deepcopy(MY_SET) + return x + len(s) + + x = torch.randn(4) + correct = fn(x) + result = torch.compile(fn, fullgraph=True)(x) + self.assertEqual(result, correct) + + def test_deepcopy_frozenset(self): + MY_FROZENSET = frozenset([1, 2, 3]) + + def fn(x): + s = copy.deepcopy(MY_FROZENSET) + return x + len(s) + + x = torch.randn(4) + correct = fn(x) + result = torch.compile(fn, fullgraph=True)(x) + self.assertEqual(result, correct) + + def test_deepcopy_user_defined_object_with_method(self): + class MyConfig: + def __init__(self, hidden_size=64): + self.hidden_size = hidden_size + + def get_size(self): + return self.hidden_size + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = MyConfig() + self.linear = torch.nn.Linear(64, 64) + + def forward(self, x): + cfg = copy.deepcopy(self.config) + return self.linear(x) * cfg.get_size() + + m = MyModule() + x = torch.randn(2, 64) + correct = m(x) + result = torch.compile(m, fullgraph=True, backend="eager")(x) + self.assertEqual(result, correct) + + def test_deepcopy_nested_user_defined_object(self): + class Inner: + def __init__(self, scale): + self.scale = scale + + class Outer: + def __init__(self): + self.inner = Inner(2.0) + self.bias = 1.0 + + def fn(x, cfg): + c = copy.deepcopy(cfg) + c.inner.scale = 3.0 + return x * c.inner.scale + c.bias + + cfg = Outer() + x = torch.randn(4) + correct = fn(x, cfg) + result = torch.compile(fn, fullgraph=True, backend="eager")(x, cfg) + self.assertEqual(result, correct) + # Verify deepcopy didn't mutate original + self.assertEqual(cfg.inner.scale, 2.0) + + def test_deepcopy_with_getattribute_override(self): + # Regression test: classes that override __getattribute__ (like + # HuggingFace PretrainedConfig) caused a graph break on + # __reduce_ex__ because SuperVariable.call_method for + # object.__getattribute__ bypassed the polyfill detection in + # resolve_type_attr. + class Config: + attribute_map = {} + + def __init__(self, hidden_size=768, num_layers=6): + self.hidden_size = hidden_size + self.num_layers = num_layers + + def __getattribute__(self, key): + if key != "attribute_map" and key in super().__getattribute__( + "attribute_map" + ): + key = super().__getattribute__("attribute_map")[key] + return super().__getattribute__(key) + + def fn(x, config): + c = copy.deepcopy(config) + return x * c.hidden_size + c.num_layers + + x = torch.randn(3) + config = Config() + correct = fn(x, config) + result = torch.compile(fn, backend="eager", fullgraph=True)(x, config) + self.assertEqual(result, correct) + def test_global_state_guard_serialization(self): GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard guards = GlobalStateGuard() @@ -5497,6 +5793,54 @@ def fn(x): # Graph breaks at manual_seed. self.assertEqual(len(counters["graph_break"]), 1) + def test_torch_generator_manual_seed(self): + from torch._dynamo.utils import counters + + cnts = torch._dynamo.testing.CompileCounter() + counters.clear() + + def fn(x, gen): + gen.manual_seed(3) + return x + 1 + + x = torch.randn(10) + ref = fn(x, torch.Generator()) + + opt_fn = torch.compile(fn, backend=cnts, fullgraph=False) + res = opt_fn(x, torch.Generator()) + + self.assertTrue(same(ref, res)) + self.assertEqual(cnts.op_count, 1) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(len(counters["graph_break"]), 1) + + def test_torch_generator_initial_seed(self): + from torch._dynamo.utils import counters + + cnts = torch._dynamo.testing.CompileCounter() + counters.clear() + + def fn(x): + return x + 1, torch.default_generator.initial_seed() + + x = torch.randn(10) + ref = fn(x) + + opt_fn = torch.compile(fn, backend=cnts, fullgraph=False) + res = opt_fn(x) + + self.assertTrue(same(ref, res)) + self.assertEqual(cnts.op_count, 1) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(len(counters["graph_break"]), 1) + + def test_torch_generator_get_state_fullgraph(self): + def fn(): + return torch.default_generator.get_state() + + with self.assertRaises(torch._dynamo.exc.Unsupported): + torch.compile(fn, backend="eager", fullgraph=True)() + def test_is_tensor_like(self): cnts = torch._dynamo.testing.CompileCounter() @@ -6972,6 +7316,49 @@ def test_recompile(foo, *, exp_frame_count): torch._dynamo.reset() test_recompile(foo_graph_break, exp_frame_count=2) + def test_multithread_compile_dynamic(self): + def f(x): + comptime.assert_static(x.shape[0]) + return x * x + + def _do_test(func): + success = True + + def run(offset): + for i in range(20): + print(func(torch.randn(i * 2 + offset))) + + t1 = threading.Thread(target=run, args=[0]) + t2 = threading.Thread(target=run, args=[1]) + + def exc_hook(x): + nonlocal success + success = False + + try: + threading.excepthook = exc_hook + t1.start() + t2.start() + + t1.join() + t2.join() + finally: + threading.excepthook = threading.__excepthook__ + self.assertTrue(success) + + _do_test(torch.compile(f, backend="eager", dynamic=False)) + torch._dynamo.reset() + + f_opt = torch.compile(f, backend="eager") + + def g(x): + with torch._dynamo.config.patch( + automatic_dynamic_shapes=False, assume_static_by_default=True + ): + f_opt(x) + + _do_test(g) + def test_backend_match_guard_multi_threads(self): x = torch.randn([3, 4]) @@ -7152,11 +7539,7 @@ def body(x): mod = Module() - error_message = "" - if torch._dynamo.config.inline_inbuilt_nn_modules: - error_message = r"Higher Order Operator: torch\.ops\.higher_order\.map_impl" - else: - error_message = "Can't inplace modify module params/buffers" + error_message = r"Higher Order Operator: torch\.ops\.higher_order\.map_impl" with self.assertRaisesRegex( torch._dynamo.exc.UncapturedHigherOrderOpError, error_message @@ -7282,7 +7665,10 @@ def test_duplicate_graph_break_log(self): @torch.compile(backend="eager") def f1(a, b): - f2(a, b) + try: + f2(a, b) + finally: + pass def f2(a, b): c = a + b @@ -7291,7 +7677,10 @@ def f2(a, b): @torch.compile(backend="eager") def g1(a, b): - g2(a, b) + try: + g2(a, b) + finally: + pass def g2(a, b): c = a + b @@ -7788,6 +8177,60 @@ def fn(x, obj): res = opt_fn(x, obj) self.assertTrue(same(ref, res)) + def test_custom_instancecheck_does_not_cause_extra_init(self): + # When __new__ returns an object whose type is not a subclass of cls, + # CPython's type.__call__ skips __init__. The polyfill + # instantiate_user_defined_class_object must match this behavior even + # when the metaclass defines a custom __instancecheck__ that would + # return True for isinstance(). + class Meta(type): + def __instancecheck__(cls, instance): + return isinstance(instance, Base) and instance.tag == cls._tag + + class Base: + def __init__(self, tag="default"): + self.tag = tag + + class Child(Base, metaclass=Meta): + _tag = "child" + + def __new__(cls): + # Returns a Base (not a Child), like ByteStorage.__new__ + return Base(tag="child") + + def fn(): + obj = Child() + return obj.tag + + ref = fn() + self.assertEqual(ref, "child") + + opt_fn = torch.compile(fn, backend="eager") + res = opt_fn() + self.assertEqual(res, "child") + + def test_custom_instancecheck_init_not_called(self): + class AlwaysTrueMeta(type): + def __instancecheck__(cls, instance): + return True + + class Child(metaclass=AlwaysTrueMeta): + def __new__(cls): + return object() + + def __init__(self): + raise AssertionError("should NOT be called") + + def fn(): + return Child() + + ref = fn() + self.assertIsInstance(ref, object) + + opt_fn = torch.compile(fn, backend="eager") + res = opt_fn() + self.assertIsInstance(res, object) + def test_variable_tracker_recursively_contains(self): # VariableTracker.recursively_contains should be updated correctly when mutation happens def fn(x): @@ -8109,31 +8552,6 @@ def guard_export_print(guards): # This guard was created self.assertTrue(guard.name != "nested_fn.__closure__[0].cell_contents") - @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") - def test_symint_as_device_kwarg_multi_gpu(self): - def fn(rank): - # -2 to make device id smaller for easier testing on CI - return torch.ones(10, device=rank.size(0) - 2) - - x = torch.randn(2) - out = fn(torch.randn(2)) - - guard_failure = None - - def guard_failures(failure): - nonlocal guard_failure - guard_failure = failure - - opt_fn = torch._dynamo.optimize( - "eager", guard_fail_fn=guard_failures, dynamic=True - )(fn) - self.assertEqual(out, opt_fn(x)) - - x = torch.randn(3) - self.assertEqual(fn(x), opt_fn(x)) - self.assertTrue(guard_failure is not None) - self.assertIn("""tensor 'rank' size mismatch at index 0""", guard_failure[0]) - @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "Test requires CUDA or XPU.") def test_symint_as_device_kwarg_non_strict_export(self): class Mod(torch.nn.Module): @@ -8580,6 +8998,99 @@ def fn(x): self.assertTrue(same(ref, res)) + def test_cast_with_different_module_types(self): + # typing.cast works correctly when used in a mixin pattern with + # different module types, producing correct results without + # graph breaks. + from typing import cast + + class Mixin: + def get_self_as_module(self): + return cast(torch.nn.Module, self) + + class ModuleA(Mixin, torch.nn.Module): + def forward(self, x): + self.get_self_as_module() + return x + 1 + + class ModuleB(Mixin, torch.nn.Module): + def forward(self, x): + self.get_self_as_module() + return x + 2 + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt) + def fn(mod, x): + mod.get_self_as_module() + return x + 1 + + x = torch.randn(4) + ref_a = fn.__wrapped__(ModuleA(), x) + ref_b = fn.__wrapped__(ModuleB(), x) + res_a = fn(ModuleA(), x) + res_b = fn(ModuleB(), x) + + self.assertEqual(ref_a, res_a) + self.assertEqual(ref_b, res_b) + self.assertEqual(cnt.frame_count, 2) + + def test_cast_fullgraph_with_non_tensor(self): + # Verify typing.cast works with non-tensor values under fullgraph=True + from typing import cast + + def fn(x): + val = cast(int, x.shape[0]) + return x + val + + opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) + + ref = fn(torch.ones(3, 4)) + res = opt_fn(torch.ones(3, 4)) + + self.assertTrue(same(ref, res)) + + @torch._dynamo.config.patch(nested_graph_breaks=False) + def test_cast_no_recompile_after_graph_break(self): + # In FSDP, cast(nn.Module, self) can be called after a + # graph break. Without the polyfill + skip_code fix, PEP 523 compiles + # typing.cast as a standalone frame with TYPE_MATCH guards on val, + # causing recompilation when different module types pass through. + # https://github.com/pytorch/pytorch/blob/0feb90404fbeb9b1594ae194f8fd47bbe7f5f245/torch/distributed/fsdp/_fully_shard/_fully_shard.py#L376 + from typing import cast + + from torch._dynamo.utils import counters + + counters.clear() + + class Base(torch.nn.Module): + def get_state(self): + torch._dynamo.decorators.graph_break() + return cast(torch.nn.Module, self) + + class ModuleA(Base): + pass + + class ModuleB(Base): + pass + + cnt = torch._dynamo.testing.CompileCounter() + a, b = ModuleA(), ModuleB() + + @torch.compile(backend=cnt) + def fn(mod, x): + mod.get_state() + return x + 1 + + x = torch.randn(4) + fn(a, x) + fn(b, x) + self.assertEqual(cnt.frame_count, 1) + # 5 frames: fn (x2), get_state before graph_break (x2), + # get_state resume after graph_break (x1, no recompile). + # Without skip_code, typing.cast would add 2 more frames (7 total). + self.assertEqual(counters["frames"]["total"], 5) + def test_T_tensor_attribute(self): def fn(x, y): a = x.T @@ -10388,6 +10899,30 @@ def fn(): self.assertEqual(fn_out, compiled_out) self.assertEqual(fn_out, (True, False, True, False, True, False)) + def test_class_hasattr_sourceless_descriptor(self): + """Test that hasattr on sourceless UserDefinedClassVariable does not graph break.""" + + class FlagDescriptor: + def __get__(self, instance, owner): + if hasattr(owner, "flag"): + return 1 + return 0 + + class WithFlag: + flag = True + prop = FlagDescriptor() + + class WithoutFlag: + prop = FlagDescriptor() + + def fn(x, obj): + return x + obj.prop + + compiled_fn = torch.compile(backend="eager", fullgraph=True)(fn) + x = torch.randn(3) + self.assertEqual(fn(x, WithFlag()), compiled_fn(x, WithFlag())) + self.assertEqual(fn(x, WithoutFlag()), compiled_fn(x, WithoutFlag())) + def test_torch_objects_as_keys(self): remap = {torch.float16: torch.float32} @@ -11496,6 +12031,39 @@ def fn(x): self.assertEqual(list(eager), list(compiled)) self.assertEqual(len(counters["graph_break"]), 0) + def test_itertools_count_from_uncompiled_region(self): + counters.clear() + counter = itertools.count() + + def fn(x): + return x * (next(counter) + 1) + + x = torch.randn([2, 5]) + compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) + + self.assertEqual(compiled_fn(x), x) + self.assertEqual(compiled_fn(x), x * 2) + self.assertEqual(next(counter), 2) + self.assertEqual(len(counters["graph_break"]), 0) + + def test_itertools_count_already_advanced(self): + counters.clear() + counter = itertools.count() + # Advance the counter before entering the compiled region + next(counter) # 0 + next(counter) # 1 + + def fn(x): + return x * (next(counter) + 1) + + x = torch.randn([2, 5]) + compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) + + self.assertEqual(compiled_fn(x), x * 3) # next(counter) = 2 + self.assertEqual(compiled_fn(x), x * 4) # next(counter) = 3 + self.assertEqual(next(counter), 4) + self.assertEqual(len(counters["graph_break"]), 0) + def test_itertools_infinite_cycle(self): counters.clear() @@ -11994,6 +12562,35 @@ def fn(x, y): expected = fn(*inps) self.assertEqual(actual, expected) + def test_frozen_dataclass_in_compile(self): + from torch.utils._pytree import MappingKey, SequenceKey + + def fn(x): + path = (MappingKey("a"), SequenceKey(0)) + msg = f"path={path}" + return x * 2, msg + + x = torch.randn(4, 4) + eager_result = fn(x) + compiled_result = torch.compile(fn, fullgraph=True)(x) + self.assertEqual(eager_result[0], compiled_result[0]) + self.assertEqual(eager_result[1], compiled_result[1]) + + def test_frozen_dataclass_treespec_method_and_fields(self): + from torch.utils._pytree import tree_flatten + + def fn(x): + d = {"a": x, "b": [x * 2, x * 3]} + flat, spec = tree_flatten(d) + is_leaf = spec.is_leaf() + return sum(flat), spec.num_leaves, spec.num_nodes, is_leaf + + x = torch.randn(4) + eager_result = fn(x) + compiled_result = torch.compile(fn, fullgraph=True)(x) + for i in range(4): + self.assertEqual(eager_result[i], compiled_result[i]) + def test_shape_env_no_recording(self): main = ShapeEnv(should_record_events=False) @@ -12883,62 +13480,6 @@ def fn(x, const): c2 = _debug_get_cache_entry_list(fn.__code__) self.assertIs(c1[1], c2[0]) - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) - @skipIfWindows(msg="TODO: (xuhancn) conform, AssertionError: False is not true") - def test_dynamo_cache_invalidate(self): - DeletedGuardManagerWrapper = torch._dynamo.guards.DeletedGuardManagerWrapper - - class Mod(torch.nn.Module): - def __init__(self) -> None: - super(Mod, self).__init__() - self.fc = torch.nn.Linear(3, 3) - - def forward(self, out): - return self.fc(out) - - def fn(x, mod): - return mod(x) - - opt_fn = torch.compile(fn, backend="eager") - - m1 = Mod() - m2 = Mod() - m3 = Mod() - inp = torch.randn(3, 3) - - # NOTE: assumes that each cache entry is guarded - # on unique Mod instance - opt_fn(inp, m1) - opt_fn(inp, m2) - opt_fn(inp, m3) - - c1 = _debug_get_cache_entry_list(fn.__code__) - self.assertEqual(len(c1), 3) - - # move cache entry to front - opt_fn(inp, m2) - c2 = _debug_get_cache_entry_list(fn.__code__) - self.assertIs(c1[1], c2[0]) - - # delete center of cache - del m3 - c3 = _debug_get_cache_entry_list(fn.__code__) - self.assertEqual(len(c3), 3) - self.assertTrue(isinstance(c3[2].guard_manager, DeletedGuardManagerWrapper)) - - # delete end of cache - del m1 - c4 = _debug_get_cache_entry_list(fn.__code__) - self.assertEqual(len(c4), 3) - self.assertTrue(isinstance(c4[1].guard_manager, DeletedGuardManagerWrapper)) - self.assertTrue(isinstance(c4[2].guard_manager, DeletedGuardManagerWrapper)) - - del m2 - c5 = _debug_get_cache_entry_list(fn.__code__) - self.assertTrue(isinstance(c5[0].guard_manager, DeletedGuardManagerWrapper)) - self.assertTrue(isinstance(c5[1].guard_manager, DeletedGuardManagerWrapper)) - self.assertTrue(isinstance(c5[2].guard_manager, DeletedGuardManagerWrapper)) - def test_inspect_signature_bind(self): import inspect @@ -14376,6 +14917,47 @@ def f(*args, **kwargs): self.assertRaises(Unsupported, f, []) self.assertRaises(Unsupported, f, "1 + j") + def test_builtin_class_method_constant_fold(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(): + return ( + bool.__new__(bool), + bool.__new__(bool, 1), + bool.__new__(bool, 0), + bool.from_bytes(b"\x00" * 8, "big"), + bool.from_bytes(b"abcd", "little"), + int.__new__(int), + int.__new__(int, 42), + int.from_bytes(b"\x00\x03", "big"), + int.from_bytes(b"\xff", byteorder="big", signed=True), + float.fromhex("0x1.ffffp10"), + float.hex(1.5), + ) + + res = fn() + self.assertIs(res[0], False) + self.assertIs(res[1], True) + self.assertIs(res[2], False) + self.assertIs(res[3], False) + self.assertIs(res[4], True) + self.assertEqual(res[5], 0) + self.assertEqual(res[6], 42) + self.assertEqual(res[7], 3) + self.assertEqual(res[8], -1) + self.assertEqual(res[9], float.fromhex("0x1.ffffp10")) + self.assertEqual(res[10], "0x1.8000000000000p+0") + + def test_builtin_constant_fold_str_conversions(self): + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + s = hex(255) + oct(8) + bin(3) + ascii("hello") + format(42, "x") + return x + len(s) + + x = torch.randn(4) + res = fn(x) + expected = hex(255) + oct(8) + bin(3) + ascii("hello") + format(42, "x") + self.assertEqual(res, x + len(expected)) + def test_guard_string_escaped(self): d = {frozenset({0}): {frozenset({0}): 1}} @@ -14498,6 +15080,334 @@ def fn(x): result2 = compiled_fn(x2) self.assertEqual(result2, x2) + def test_requires_grad_changes_dynamo_graph(self): + # requires_grad_() on a graph input graph-breaks, so no fullgraph + def fn(x): + x.requires_grad_() + if x.requires_grad: + return x * 2 + return x + 1 + + x = torch.randn(3, 3) + opt_fn = torch.compile(fn) + result = opt_fn(x) + self.assertEqual(result, x * 2) + + def test_requires_grad_backward_outside_compile(self): + # requires_grad_() on a graph input graph-breaks, but eager fallback + # produces correct results. + def fn(x): + x.requires_grad_() + return (x * 2).sum() + + x_ref = torch.randn(3, 3) + x_test = x_ref.clone() + + fn(x_ref).backward() + torch.compile(fn)(x_test).backward() + + self.assertEqual(x_ref.grad, x_test.grad) + + def test_detach_inplace_on_intermediate_updates_metadata(self): + def fn(x): + y = x * 2 + y.detach_() + return y + 1, y.requires_grad, y.grad_fn is None + + x = torch.randn(3, 3, requires_grad=True) + ref = fn(x.clone()) + result = torch.compile(fn, backend="eager", fullgraph=True)(x.clone()) + + self.assertEqual(ref, result) + self.assertFalse(result[1]) + self.assertTrue(result[2]) + + def test_requires_grad_on_intermediate(self): + def fn(x): + y = x * 2 + y.requires_grad_() + return y + + x = torch.randn(3, 3) + + # fullgraph=True should error with actionable message + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + r"requires_grad_\(\)(.|\n)*\.detach\(\)", + ): + torch.compile(fn, fullgraph=True)(x) + + # Without fullgraph, falls back to eager and is correct + result = torch.compile(fn)(x) + self.assertTrue(result.requires_grad) + self.assertEqual(fn(x), result) + + def test_requires_grad_on_intermediate_derived_returned(self): + def fn(x): + y = x * 2 + y.requires_grad_() + return y * 3 + + x = torch.randn(3, 3) + + # Derived tensor also loses requires_grad — should error with message + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + r"requires_grad_\(\)(.|\n)*\.detach\(\)", + ): + torch.compile(fn, fullgraph=True)(x) + + # Without fullgraph, falls back to eager and is correct + result = torch.compile(fn)(x) + ref = fn(x) + self.assertTrue(result.requires_grad) + self.assertEqual(ref, result) + + def test_requires_grad_on_intermediate_partial_graph(self): + # When requires_grad_() on a source-less intermediate leaks as output, + # Dynamo should restart and graph break at requires_grad_(), capturing + # ops before it in a compiled graph (partial acceleration). + def fn(x): + a = x.sin() + b = a.cos() + b.requires_grad_() + return b + + backend = torch._dynamo.testing.EagerAndRecordGraphs() + x = torch.randn(3, 3) + result = torch.compile(fn, backend=backend)(x) + self.assertEqual(result, fn(x)) + self.assertTrue(result.requires_grad) + # The graph should capture the ops before requires_grad_() + self.assertEqual(len(backend.graphs), 1) + # Dynamic shapes adds shape guards to the graph, skip the exact check + if torch._dynamo.config.assume_static_by_default: + self.assertExpectedInline( + backend.graphs[0].code.strip(), + """\ +def forward(self, L_x_ : torch.Tensor): + l_x_ = L_x_ + a = l_x_.sin(); l_x_ = None + b = a.cos(); a = None + return (b,)""", + ) + + @torch._dynamo.config.patch(trace_autograd_ops=True) + def test_requires_grad_on_intermediate_not_returned(self): + def fn(x): + y = x * 2 + y.requires_grad_() + loss = (y * 3).sum() + loss.backward() + return y.grad + + x = torch.randn(3, 3) + + ref = fn(x.clone()) + result = torch.compile(fn, fullgraph=True)(x.clone()) + self.assertEqual(ref, result) + + @torch._dynamo.config.patch(trace_autograd_ops=True) + def test_requires_grad_intermediate_backward_grad_used_in_compute(self): + # Use the grad result in further computation within compile + def fn(x): + y = x * 2 + y.requires_grad_() + loss = (y**2).sum() + loss.backward() + return y.grad * 2 + 1 + + x = torch.randn(3, 3) + + ref = fn(x.clone()) + result = torch.compile(fn, fullgraph=True)(x.clone()) + self.assertEqual(ref, result) + + @torch._dynamo.config.patch(trace_autograd_ops=True) + def test_requires_grad_intermediate_chunked_loss_backward(self): + # Mirrors the TxtUnembedding pattern: forward compute, detach, make + # new leaf, chunked loss with per-chunk backward, then propagate + # accumulated grad back to the original input via h.backward(). + def fn(x, targets): + # Forward computation before detach (e.g. transformer layers) + h = x * 2 + 1 + x_detached = h.detach().requires_grad_() + chunksz = x_detached.shape[0] // 2 + total_loss = torch.tensor(0.0) + for start in range(0, x_detached.shape[0], chunksz): + chunk = x_detached[start : start + chunksz] + chunk_targets = targets[start : start + chunksz] + logits = chunk @ torch.eye(chunk.shape[-1]) + loss = torch.nn.functional.cross_entropy(logits, chunk_targets) + loss.backward() + total_loss = total_loss + loss.detach() + # Propagate chunked grad back through the forward computation + h.backward(x_detached.grad) + return x.grad, total_loss + + x_ref = torch.randn(4, 8, requires_grad=True) + targets = torch.randint(0, 8, (4,)) + + x_test = x_ref.clone().detach().requires_grad_(True) + ref_grad, ref_loss = fn(x_ref, targets) + compiled_grad, compiled_loss = torch.compile(fn, fullgraph=True)( + x_test, targets + ) + self.assertEqual(ref_grad, compiled_grad) + self.assertEqual(ref_loss, compiled_loss) + # Verify grad propagated to the input + self.assertEqual(x_ref.grad, x_test.grad) + + @torch._dynamo.config.patch(trace_autograd_ops=True) + def test_requires_grad_intermediate_backward_and_return_detached(self): + # Returning a detached version of the tainted tensor is safe — detach() + # strips requires_grad so AOTAutograd functionalization can't lose anything. + def fn(x): + y = x * 2 + y.requires_grad_() + out = y * 3 + loss = out.sum() + loss.backward() + return y.grad, out.detach() + + x = torch.randn(3, 3) + + ref_grad, ref_out = fn(x.clone()) + compiled_grad, compiled_out = torch.compile(fn, fullgraph=True)(x.clone()) + self.assertEqual(ref_grad, compiled_grad) + self.assertEqual(ref_out, compiled_out) + self.assertFalse(compiled_out.requires_grad) + + @torch._dynamo.config.patch(trace_autograd_ops=True) + def test_requires_grad_intermediate_metadata_checks(self): + # After requires_grad_() on an intermediate, requires_grad and is_leaf + # should report correctly and be usable in control flow. + def fn(x): + y = x * 2 + y.requires_grad_() + if y.requires_grad and y.is_leaf: + loss = (y * 3).sum() + loss.backward() + return y.grad + return y + + x = torch.randn(3, 3) + ref = fn(x.clone()) + result = torch.compile(fn, fullgraph=True)(x.clone()) + self.assertEqual(ref, result) + + @torch._dynamo.config.patch(trace_autograd_ops=True) + def test_requires_grad_intermediate_side_effect_global(self): + # requires_grad_() on intermediate, then store grad in a global + saved = {} + + def fn(x): + y = x * 2 + y.requires_grad_() + loss = (y**2).sum() + loss.backward() + saved["grad"] = y.grad + return y.grad.clone() + + x = torch.randn(3, 3) + ref = fn(x.clone()) + saved_ref = saved["grad"].clone() + saved.clear() + + result = torch.compile(fn, fullgraph=True)(x.clone()) + self.assertEqual(ref, result) + self.assertEqual(saved_ref, saved["grad"]) + + def test_import_user_defined_module(self): + # testcase for https://github.com/pytorch/pytorch/issues/177682 + # Bad import result for types.ModuleType subclass in sys.modules + class _ConfigModule(types.ModuleType): + x = 1 + + _ConfigModule.__module__ = __name__ + sys.modules["my_config"] = _ConfigModule("my_config") + + def fn(): + import my_config # noqa: F401 + + return torch.tensor(1) + + compilefn = torch.compile(fn, fullgraph=True, backend="eager") + + ret1 = fn() + ret2 = compilefn() + self.assertEqual(ret1, ret2) + + def test_constant_subclass_guard_recompiles(self): + class MyInt(int): + def __eq__(self, other): + raise RuntimeError("should not be called during guard check") + + class MyFloat(float): + def __eq__(self, other): + raise RuntimeError("should not be called during guard check") + + class MyStr(str): + def __eq__(self, other): + raise RuntimeError("should not be called during guard check") + + cnt = torch._dynamo.testing.CompileCounter() + + # int subclass + @torch.compile(backend=cnt) + def f(x, y): + return x + y + + r1 = f(torch.tensor(1), MyInt(5)) + self.assertEqual(r1.item(), 6) + self.assertEqual(cnt.frame_count, 1) + + r2 = f(torch.tensor(1), MyInt(10)) + self.assertEqual(r2.item(), 11) + self.assertEqual(cnt.frame_count, 2) + + r3 = f(torch.tensor(1), MyInt(5)) + self.assertEqual(r3.item(), 6) + self.assertEqual(cnt.frame_count, 2) + + # float subclass + cnt.clear() + + @torch.compile(backend=cnt) + def g(x, y): + return x + y + + r4 = g(torch.tensor(1.0), MyFloat(2.5)) + self.assertEqual(r4.item(), 3.5) + self.assertEqual(cnt.frame_count, 1) + + r5 = g(torch.tensor(1.0), MyFloat(3.5)) + self.assertEqual(r5.item(), 4.5) + self.assertEqual(cnt.frame_count, 2) + + r6 = g(torch.tensor(1.0), MyFloat(2.5)) + self.assertEqual(r6.item(), 3.5) + self.assertEqual(cnt.frame_count, 2) + + # str subclass + cnt.clear() + + @torch.compile(backend=cnt, fullgraph=True) + def h(x, s): + return x + len(s) + + r7 = h(torch.tensor(1), MyStr("abc")) + self.assertEqual(r7.item(), 4) + self.assertEqual(cnt.frame_count, 1) + + r8 = h(torch.tensor(1), MyStr("abcde")) + self.assertEqual(r8.item(), 6) + self.assertEqual(cnt.frame_count, 2) + + r9 = h(torch.tensor(1), MyStr("abc")) + self.assertEqual(r9.item(), 4) + self.assertEqual(cnt.frame_count, 2) + class MiscTestsPyTree(torch._inductor.test_case.TestCase): @parametrize_pytree_module @@ -15086,21 +15996,6 @@ def f(rank): opt_out = torch.compile(backend="eager", dynamic=True, fullgraph=True)(f)(x) self.assertEqual(out, opt_out) - @unittest.skipIf(not TEST_MULTIGPU, "need multiple GPU") - def test_gpu_set_device(self, device): - def fn(): - a = torch.ones(2, device=device) - torch.get_device_module(device).set_device(1) - return a + 1 - - with torch.get_device_module(device).device(0): - counter = CompileCounter() - opt_fn = torch.compile(fn, backend=counter) - res = opt_fn() - self.assertTrue(res.device.type in device) - self.assertEqual(res.device.index, 0) - self.assertEqual(counter.frame_count, 2) - def test_torch_device_python_type(self, device): device_type = torch.device(device).type for device, device_type, index in [ diff --git a/test/dynamo/test_model_output.py b/test/dynamo/test_model_output.py index 08d9c11b09906..36b8a7789c4f4 100644 --- a/test/dynamo/test_model_output.py +++ b/test/dynamo/test_model_output.py @@ -13,13 +13,13 @@ try: from transformers import modeling_outputs from transformers.configuration_utils import PretrainedConfig - from transformers.file_utils import ModelOutput from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithPast, ) + from transformers.utils import ModelOutput except ImportError: modeling_outputs = None @@ -37,11 +37,11 @@ def fn(a, tmp): if hasattr(tmp, "somekey"): a = a + 1 if tmp.return_dict: - return a + torch.ones(2) * tmp.max_length + return a + torch.ones(2) * tmp.chunk_size_feed_forward return a x = torch.randn(2) - tmp = PretrainedConfig(return_dict=True, max_length=20) + tmp = PretrainedConfig(return_dict=True, chunk_size_feed_forward=20) ref = fn(x, tmp) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) res = opt_fn(x, tmp) @@ -50,7 +50,7 @@ def fn(a, tmp): @maybe_skip def test_pretrained_non_const_attr(self): def fn(a, tmp): - if tmp.pruned_heads: + if tmp.attribute_map: return a + 1 else: return a - 1 @@ -359,11 +359,7 @@ def forward( ) -devices = ["cpu", "cuda", "xpu", "hpu"] - -instantiate_device_type_tests( - TestModelOutputBert, globals(), only_for=devices, allow_xpu=True -) +instantiate_device_type_tests(TestModelOutputBert, globals(), allow_xpu=True) if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 4a380640a0c8d..90f63f9686034 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -160,10 +160,12 @@ def tearDownClass(cls): super().tearDownClass() def setUp(self): + super().setUp() torch.set_default_device(None) torch._dynamo.reset() def tearDown(self): + super().tearDown() torch.set_default_device(None) torch._dynamo.reset() @@ -799,7 +801,7 @@ def test_hop(self): with self.assertRaisesRegex( torch._dynamo.exc.Unsupported, - "raised exception HopDetectionError([ConstantVariable(str: 'test')])", + "raised exception HopDetectionError\\('test'\\)", ): # This runs in fullgraph already with TestModeRaises(): @@ -821,7 +823,7 @@ def test_hop_eager(self): with torch.device(GPU_TYPE): with self.assertRaisesRegex( torch._dynamo.exc.Unsupported, - "raised exception HopDetectionError([ConstantVariable(str: 'test')])", + "raised exception HopDetectionError\\('test'\\)", ): with TestModeRaises(): flex_attention_eager( @@ -927,7 +929,7 @@ def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"): mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None add: "f32[3, 3]" = torch.ops.aten.add.Tensor(mul, arg1_1); mul = arg1_1 = None return (add,) -""", # noqa: B950 +""", ) @torch._dynamo.config.patch(force_compile_during_fx_trace=True) @@ -981,9 +983,10 @@ class repeated_subgraph0(torch.nn.Module): def forward(self, arg0_1: "f32[3, 3]"): mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None return (mul,) -""", # noqa: B950 +""", ) + @torch._functorch.config.patch(guess_tangent_strides_as_outputs=True) @torch._dynamo.config.patch(force_compile_during_fx_trace=True) def test_invoke_subgraph_seq_nr(self): """ @@ -1048,7 +1051,8 @@ def bw_compiler(gm, example_inputs): [ ["add"], # seq_nr 21 [ - "clone", + "copy", + "empty_strided", "getitem", "getitem_1", "getitem_2", @@ -1099,37 +1103,30 @@ def forward(self, arg0_1: "f32[3, 3]"): # Annotation: {'seq_nr': 10} File: test_modes.py:921 in inner_fn, code: return y / 2 div: "f32[3, 3]" = torch.ops.aten.div.Tensor(cos, 2); cos = None return (div, arg0_1) - """, # noqa: B950 + """, ignore_comments=True, ignore_empty_lines=True, ) self.assertExpectedInline( normalize_gm(bw_graph.print_readable(print_output=False)), - """ + """\ class GraphModule(torch.nn.Module): def forward(self, getitem_1: "f32[3, 3]", tangents_1: "f32[]"): - # Annotation: {'seq_nr': 15} No stacktrace found for following nodes expand: "f32[3, 3]" = torch.ops.aten.expand.default(tangents_1, [3, 3]); tangents_1 = None - - # Annotation: {'seq_nr': 14} No stacktrace found for following nodes - clone: "f32[3, 3]" = torch.ops.aten.clone.default(expand, memory_format = torch.contiguous_format); expand = None + empty_strided: "f32[3, 3]" = torch.ops.aten.empty_strided.default([3, 3], [3, 1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) + copy: "f32[3, 3]" = torch.ops.aten.copy.default(empty_strided, expand); empty_strided = expand = None repeated_subgraph1 = self.repeated_subgraph1 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1, 'invoke_subgraph_1', getitem_1, clone); repeated_subgraph1 = getitem_1 = clone = None + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1, 'invoke_subgraph_1', getitem_1, copy); repeated_subgraph1 = getitem_1 = copy = None getitem_2: "f32[3, 3]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None return (getitem_2,) - class repeated_subgraph1(torch.nn.Module): def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"): - # Annotation: {'seq_nr': 10} File: test_modes.py:921 in inner_fn, code: return y / 2 div: "f32[3, 3]" = torch.ops.aten.div.Tensor(arg1_1, 2); arg1_1 = None - - # Annotation: {'test': 'test', 'seq_nr': 9} File: test_modes.py:920 in inner_fn, code: y = x.cos() sin: "f32[3, 3]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None neg: "f32[3, 3]" = torch.ops.aten.neg.default(sin); sin = None mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(div, neg); div = neg = None - return (mul,) - """, # noqa: B950 + return (mul,)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1189,7 +1186,7 @@ class repeated_subgraph1(torch.nn.Module): def forward(self, arg0_1: "f32[3, 3]"): mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg0_1, 3); arg0_1 = None return (mul,) -""", # noqa: B950 +""", ) @torch._dynamo.config.patch(force_compile_during_fx_trace=True) @@ -1236,7 +1233,7 @@ def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]", arg2_1: "f32[3, 3]", mul_1: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg2_1, arg3_1); arg2_1 = arg3_1 = None add: "f32[3, 3]" = torch.ops.aten.add.Tensor(mul, mul_1); mul = mul_1 = None return (add,) -""", # noqa: B950 +""", ) @torch._dynamo.config.patch(force_compile_during_fx_trace=True) @@ -1286,7 +1283,7 @@ def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]"): sub: "f32[3, 3]" = torch.ops.aten.sub.Tensor(arg0_1, arg1_1) mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None return (add, sub, mul) -""", # noqa: B950 +""", ) @torch._dynamo.config.patch(force_compile_during_fx_trace=True) @@ -1337,7 +1334,7 @@ def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3, 3]", arg2_1: "f32[3, 3]") mul: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None mul_1: "f32[3, 3]" = torch.ops.aten.mul.Tensor(mul, arg2_1); mul = arg2_1 = None return (add_1, mul_1) -""", # noqa: B950 +""", ) @torch._dynamo.config.patch(force_compile_during_fx_trace=True) @@ -2011,6 +2008,142 @@ def wrapped_fn(*args, **kwargs): dist.destroy_process_group() + @unittest.skipUnless( + IS_FLEX_ATTENTION_CUDA_PLATFORM_SUPPORTED and not torch.version.hip, + "Requires CUDA with SM >= 8.0, Triton, and not ROCm", + ) + def test_2tier_blockmask_tensor_closure_nested_compile_aot_export(self): + """2-tier AOT export with BlockMask whose mask_mod captures tensors. + + Reproduces the sixlib/mango pattern where: + - BlockMask is pytree-registered (as in sixlib/attention_mask.py) + - BlockMask with tensor closure is created inside forward + - aot_export flattens intermediates via pytree, exposing the + mask_mod callable (with captured tensors) in the context + + visibility is a model input -> FunctionalTensor during outer trace. + mask_mod closure captures derived slices (lower, upper). + """ + import contextlib + + import torch.fx.traceback as fx_traceback + from torch._functorch.aot_autograd import aot_export_joint_with_descriptors + from torch._subclasses import FakeTensorMode + from torch.nn.attention.flex_attention import ( + BlockMask, + create_block_mask, + flex_attention, + ) + from torch.utils._pytree import register_pytree_node, SUPPORTED_NODES + + # Register BlockMask as pytree node (same as sixlib/attention_mask.py) + if BlockMask not in SUPPORTED_NODES: + register_pytree_node( + BlockMask, + BlockMask._flatten, + BlockMask._unflatten, + flatten_with_keys_fn=BlockMask._flatten_with_keys, + serialized_type_name="torch.nn.attention.flex_attention.BlockMask", + ) + + d_model, n_heads = 64, 4 + batch_size, seq_len = 2, 32 + + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.q_proj = torch.nn.Linear(d_model, d_model) + self.k_proj = torch.nn.Linear(d_model, d_model) + self.v_proj = torch.nn.Linear(d_model, d_model) + + def forward(self, x, visibility): + bs, sl, _ = x.shape + lower = visibility[:, 0, :] + upper = visibility[:, 1, :] + + def mask_mod(b_idx, h_idx, q_idx, k_idx): + return (lower[b_idx, q_idx] <= k_idx) & ( + k_idx <= upper[b_idx, q_idx] + ) + + bm = create_block_mask( + mask_mod, + B=bs, + H=None, + Q_LEN=sl, + KV_LEN=sl, + device=x.device, + ) + + q = ( + self.q_proj(x) + .view(bs, sl, n_heads, d_model // n_heads) + .transpose(1, 2) + ) + k = ( + self.k_proj(x) + .view(bs, sl, n_heads, d_model // n_heads) + .transpose(1, 2) + ) + v = ( + self.v_proj(x) + .view(bs, sl, n_heads, d_model // n_heads) + .transpose(1, 2) + ) + + with fx_traceback.annotate({"compile_with_inductor": 1}): + attn_out = flex_attention(q, k, v, block_mask=bm) + + out = attn_out.transpose(1, 2).contiguous().view(bs, sl, -1) + # Return BlockMask alongside output so it appears in the + # output pytree spec — this matches how sixlib's decoder + # returns context alongside outputs. + return out, bm + + with ( + torch._dynamo.config.patch(force_compile_during_fx_trace=True), + torch._inductor.config.patch(wrap_inductor_compiled_regions=True), + torch._functorch.config.patch(force_non_lazy_backward_lowering=True), + ): + torch._dynamo.reset() + + model = SimpleModel().to(GPU_TYPE) + x = torch.randn( + batch_size, + seq_len, + d_model, + device=GPU_TYPE, + dtype=torch.float32, + requires_grad=True, + ) + visibility = torch.zeros( + batch_size, + 2, + seq_len, + device=GPU_TYPE, + dtype=torch.int64, + ) + visibility[:, 1, :] = seq_len - 1 + + fake_mode = FakeTensorMode(allow_non_fake_inputs=True) + with contextlib.ExitStack() as stack: + stack.enter_context(fake_mode) + result = aot_export_joint_with_descriptors( + stack, + model, + args=(x, visibility), + kwargs={}, + keep_inference_input_mutations=True, + _disable_torch_fn_metadata_mode=True, + ) + + print("=== Outer (aot_export) graph ===") + print(result.graph_module.graph) + for name, submod in result.graph_module.named_modules(): + if name and hasattr(submod, "graph"): + print(f"\n=== Submodule: {name} ===") + print(submod.graph) + class TorchFunctionModeLifecycleTests(torch._dynamo.test_case.TestCase): def test_default_device_restored_after_mode_tests(self): diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 679f160590485..13e8b5ef61044 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -2140,35 +2140,8 @@ def forward(self, x): ): x = torch.randn(*size, requires_grad=True) mod(x) - if torch._dynamo.config.inline_inbuilt_nn_modules: - self.assertEqual(cnts.frame_count, 1) - else: - self.assertEqual(cnts.frame_count, num_submodules) - - @patch.object(torch._dynamo.config, "accumulated_recompile_limit", 2) - @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False) - def test_recompile_limit_on_freed_module(self): - class Mod(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.lin = torch.nn.Linear(5, 5) - - def forward(self, x): - return self.lin(x) - - def fn(x, mod): - return mod(x) - - cnts = torch._dynamo.testing.CompileCounterWithBackend("eager") - opt_mod = torch.compile(fn, backend=cnts) - for _ in range(8): - mod = Mod() - opt_mod(torch.randn(5, 5), mod) - - # fn compiles twice - self.assertEqual(cnts.frame_count, 2) + self.assertEqual(cnts.frame_count, 1) - @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True) def test_inline_inbuilt_nn_modules(self): size = (10, 10) recompile_limit = 1 @@ -2249,10 +2222,7 @@ def forward(self, x): ]: x = torch.randn(size) mod(x) - if torch._dynamo.config.inline_inbuilt_nn_modules: - self.assertEqual(cnts.frame_count, 2) - else: - self.assertEqual(cnts.frame_count, 2 * num_submodules) + self.assertEqual(cnts.frame_count, 2) def test_recursion(self): mod = MockModule() @@ -2449,56 +2419,6 @@ def new_forward_hook( self.assertEqual(compiled_func(inp), outer_func(inp)) self.assertEqual(compiled_func(inp).item(), 16) - @patch.object(torch._dynamo.config, "guard_nn_modules", False) - @patch.object(torch._dynamo.config, "skip_nnmodule_hook_guards", True) - @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", False) - def test_hooks_skip_guards(self): - class TestModule(torch.nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - return 2 * x + 1 - - m = TestModule() - - def forward_hook( - module: torch.nn.Module, inputs: tuple[torch.Tensor], output: torch.Tensor - ) -> torch.Tensor: - return 2 * output + 1 - - handle = m.register_forward_hook(forward_hook) - - def outer_func(tensor): - x = tensor * 2 + 1 - y = m(x) - return y - - inp = torch.tensor(1.0, requires_grad=True) - - failure_reason = None - - def guard_fail_fn(failure): - nonlocal failure_reason - failure_reason = failure[0] - - cc = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") - compiled_func = torch._dynamo.optimize( - guard_fail_fn=guard_fail_fn, - backend=cc, - )(outer_func) - - m = TestModule() - handle = m.register_forward_hook(forward_hook) - failure_reason = None - self.assertEqual(compiled_func(inp), outer_func(inp)) - self.assertEqual(compiled_func(inp).item(), 15) - self.assertEqual(cc.frame_count, 1) - self.assertEqual(cc.op_count, 6) - - # if we remove the hook, dynamo shouldn't notice - handle.remove() - self.assertNotEqual(compiled_func(inp), outer_func(inp)) - self.assertEqual(compiled_func(inp).item(), 15) - self.assertEqual(cc.frame_count, 1) - def _forward_hook_test_helper(self, model): forward_handles = {} compiled_activations = {} @@ -2814,18 +2734,6 @@ def run(): run() self.assertTrue(models[0].abc) - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) - def test_assign_does_not_exist(self): - class MyModule(torch.nn.Module): - def forward(self, x): - self.text_encoding = x + 1 - return self.text_encoding - - mod = MyModule() - out = torch.compile(mod, fullgraph=True, backend="eager")(torch.randn(10)) - if mod.text_encoding is not out: - raise AssertionError("Expected mod.text_encoding to be out") - def test_module_dict_iter_values(self): class MyModule(torch.nn.Module): def __init__(self) -> None: @@ -2886,16 +2794,10 @@ def foo(mod, x): mod = Mod() foo(mod, torch.rand([4])) - if torch._dynamo.config.inline_inbuilt_nn_modules: - self.assertEqual(compiles_without_buffers, 1) - else: - self.assertEqual(compiles_without_buffers, 0) + self.assertEqual(compiles_without_buffers, 1) foo(mod, torch.rand([4], dtype=torch.half)) - if torch._dynamo.config.inline_inbuilt_nn_modules: - self.assertEqual(compiles_without_buffers, 2) - else: - self.assertEqual(compiles_without_buffers, 1) + self.assertEqual(compiles_without_buffers, 2) class Mod2(Mod): def __setattr__(self, name, value): @@ -2913,7 +2815,6 @@ def test_unspec_non_inlinable_module(self): expected = mod(x) self.assertEqual(actual, expected) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_mark_static_previously_seen_tensor(self): # This test verifies that dynamo will mark # the buffers/params of a module as static @@ -2959,7 +2860,6 @@ def fn(x, b, mod): fn(inp, buf, mod) self.assertEqual(num_compiles, 1) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_mark_static_nn_module_tensor(self): # This test verifies that dynamo will mark # the nn module tensor attributes as static @@ -3002,7 +2902,6 @@ def fn(x): fn(inp) self.assertEqual(num_compiles, 1) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) @torch._inductor.config.patch("freezing", True) @torch.no_grad() def test_mark_static_with_freezing(self): @@ -3242,8 +3141,26 @@ def fn(x): self.assertEqual(cnt.frame_count, 2) helper() - with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True): - helper() + + def test_monkeypatching_forward_inside_compiled_region(self): + class Mod(torch.nn.Module): + def forward(self, x): + return x - 1 + + @torch.compile(backend="eager", fullgraph=True) + def fn(mod, x, y): + def patch(x): + return x + y + + mod.forward = patch + return mod(x) + + inp0 = torch.ones(3) + inp1 = torch.ones(3) + mod = Mod() + + self.assertEqual(fn(mod, inp0, inp1), inp0 + inp1) + self.assertEqual(mod(inp0), inp0 + inp1) def test_user_defined_nn_module_dynamic(self): class Conv2d(torch.nn.Conv2d): @@ -3279,7 +3196,6 @@ def forward(self, x): # Must be 3 compilations. If not marked static there would be 2, because strides would be converted to symints. self.assertEqual(cnts.frame_count, 3) - @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True) def test_overridden_call(self): class OverRiddenCallModule(torch.nn.Module): def __init__(self): @@ -3304,7 +3220,6 @@ def fn(x): res = opt_fn(x) self.assertEqual(ref, res) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_param_requires_grad(self): def adjust_model(model): to_freeze = model.num_iter % 2 == 0 @@ -3351,7 +3266,6 @@ def forward(self, x): self.assertEqual(cnt.frame_count, 3) @torch._dynamo.config.patch("use_recursive_dict_tags_for_guards", False) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_param_requires_grad_no_recursive_dict_tags(self): class MyModule(torch.nn.Module): def __init__(self): @@ -3377,7 +3291,35 @@ def forward(self, x): self.assertEqual(cnt.frame_count, 2) self.assertIsNotNone(model.linear.weight.grad) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) + @torch._dynamo.config.patch(skip_tensor_guards_with_matching_dict_tags=True) + @torch._dynamo.config.patch("use_recursive_dict_tags_for_guards", True) + def test_param_dtype_change_recompiles_with_recursive_dict_tags(self): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.randn(4)) + + def forward(self, x): + return x * self.scale + + model = MyModule() + x = torch.randn(4) + + cnt = torch._dynamo.testing.CompileCounter() + compiled = torch.compile(model, backend=cnt, fullgraph=True) + + self.assertTrue(torch._dynamo.testing.same(model(x), compiled(x))) + self.assertEqual(cnt.frame_count, 1) + + model.to(dtype=torch.float64) + + recompiled = torch.compile(model, backend=cnt, fullgraph=True) + result = recompiled(x) + + self.assertEqual(result.dtype, torch.float64) + self.assertTrue(torch._dynamo.testing.same(model(x), result)) + self.assertEqual(cnt.frame_count, 2) + def test_param_requires_grad_submodule(self): class Inner(torch.nn.Module): def __init__(self): @@ -3682,6 +3624,67 @@ def scale_output(module, args, kwargs, output): self.assertEqual(compiled_call_count, eager_call_count) self.assertTrue(torch.allclose(output_eager, output)) + @patch.object(torch._dynamo.config, "guard_nn_modules", True) + def test_dict_insertion_guard_method_func(self): + class SimpleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + + def forward(self, x): + return self.linear(x) + + def hook_function(module, args): + return (args[0] + 1.0,) + + class HookHelper: + def hook_method(self, module, args): + return (args[0] + 2.0,) + + model = SimpleModel() + helper = HookHelper() + + model.register_forward_pre_hook(hook_function) + model.register_forward_pre_hook(helper.hook_method, prepend=True) + + @torch.compile(fullgraph=True, backend="eager") + def runner_func(mod, x): + return mod(x) + + input_tensor = torch.randn(1, 10) + # This would error before fixing guard orering on nn.Modules (https://github.com/pytorch/pytorch/issues/170429) + _ = runner_func(model, input_tensor) + + def test_prepend_hook_ordering(self): + class HookedLinear(torch.nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.register_forward_pre_hook(self._hook_add) + self.register_forward_pre_hook(self._hook_mul, prepend=True) + + @staticmethod + def _hook_add(module, args): + return (args[0] + 1,) + + @staticmethod + def _hook_mul(module, args): + return (args[0] * 2,) + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer = HookedLinear(4, 4, bias=False) + + def forward(self, x): + return self.layer(x) + + model = Model() + x = torch.ones(1, 4) + + eager = model(x) + compiled = torch.compile(model, backend="eager", fullgraph=True)(x) + self.assertEqual(eager, compiled) + devices = ["cuda", "hpu", "xpu"] instantiate_device_type_tests( diff --git a/test/dynamo/test_nb_bool.py b/test/dynamo/test_nb_bool.py new file mode 100644 index 0000000000000..b22e764719fe4 --- /dev/null +++ b/test/dynamo/test_nb_bool.py @@ -0,0 +1,278 @@ +# Owner(s): ["module: dynamo"] +"""Tests for nb_bool / generic_bool: bool() via PyObject_IsTrue in Dynamo.""" + +import collections +import enum + +import torch + + +class _Color(enum.Enum): + RED = 1 + BLUE = 2 + + +import torch.nn +from torch.testing._internal.common_utils import make_dynamo_test, run_tests, TestCase + + +class NbBoolTests(TestCase): + # --- Scalar constants (ConstantVariable path) --- + + @make_dynamo_test + def test_int(self): + self.assertEqual(bool(0), False) + self.assertEqual(bool(1), True) + self.assertEqual(bool(-1), True) + + @make_dynamo_test + def test_float(self): + self.assertEqual(bool(0.0), False) + self.assertEqual(bool(-0.0), False) + self.assertEqual(bool(3.14), True) + + @make_dynamo_test + def test_none(self): + self.assertEqual(bool(None), False) + + @make_dynamo_test + def test_str(self): + self.assertEqual(bool(""), False) + self.assertEqual(bool("nonempty"), True) + + @make_dynamo_test + def test_bytes(self): + self.assertEqual(bool(b""), False) + self.assertEqual(bool(b"hello"), True) + + @make_dynamo_test + def test_bool(self): + self.assertEqual(False, False) + self.assertEqual(True, True) + + @make_dynamo_test + def test_complex_zero(self): + self.assertEqual(bool(0j), False) + + @make_dynamo_test + def test_complex_nonzero(self): + self.assertEqual(bool(1 + 2j), True) + + @make_dynamo_test + def test_complex_real_nonzero_imag_zero(self): + self.assertEqual(bool(1 + 0j), True) + + @make_dynamo_test + def test_complex_real_zero_imag_nonzero(self): + self.assertEqual(bool(0 + 1j), True) + + # --- Containers (length fallback / _bool_from_length path) --- + + @make_dynamo_test + def test_empty_list(self): + self.assertEqual(bool([]), False) + + @make_dynamo_test + def test_nonempty_list(self): + self.assertEqual(bool([1, 2, 3]), True) + + @make_dynamo_test + def test_empty_dict(self): + self.assertEqual(bool({}), False) + + @make_dynamo_test + def test_nonempty_dict(self): + self.assertEqual(bool({"a": 1}), True) + + @make_dynamo_test + def test_empty_tuple(self): + self.assertEqual(bool(()), False) + + @make_dynamo_test + def test_nonempty_tuple(self): + self.assertEqual(bool((1,)), True) + + @make_dynamo_test + def test_empty_set(self): + self.assertEqual(bool(set()), False) + + @make_dynamo_test + def test_nonempty_set(self): + self.assertEqual(bool({1, 2}), True) + + @make_dynamo_test + def test_empty_frozenset(self): + self.assertEqual(bool(frozenset()), False) + + @make_dynamo_test + def test_nonempty_frozenset(self): + self.assertEqual(bool(frozenset({1})), True) + + @make_dynamo_test + def test_empty_range(self): + self.assertEqual(bool(range(0)), False) + + @make_dynamo_test + def test_nonempty_range(self): + self.assertEqual(bool(range(5)), True) + + # --- dict subclasses --- + + @make_dynamo_test + def test_empty_defaultdict(self): + d = collections.defaultdict(int) + self.assertEqual(bool(d), False) + + @make_dynamo_test + def test_nonempty_defaultdict(self): + d = collections.defaultdict(int, {"x": 1}) + self.assertEqual(bool(d), True) + + @make_dynamo_test + def test_empty_counter(self): + c = collections.Counter() + self.assertEqual(bool(c), False) + + @make_dynamo_test + def test_nonempty_counter(self): + c = collections.Counter("abc") + self.assertEqual(bool(c), True) + + # --- Enum (UserDefinedClassVariable / ConstantVariable path) --- + + @make_dynamo_test + def test_enum_member(self): + self.assertEqual(bool(_Color.RED), True) + self.assertEqual(bool(_Color.BLUE), True) + + # --- UserDefinedObjectVariable tests (torch.compile path) --- + + def test_user_defined_with_bool(self): + class MyObj: + def __init__(self, val): + self.val = val + + def __bool__(self): + return self.val > 0 + + def fn(x, obj): + return x + 1 if bool(obj) else x - 1 + + x = torch.randn(4) + compiled = torch.compile(fn, backend="eager") + self.assertEqual(fn(x, MyObj(5)), compiled(x, MyObj(5))) + torch._dynamo.reset() + compiled = torch.compile(fn, backend="eager") + self.assertEqual(fn(x, MyObj(-1)), compiled(x, MyObj(-1))) + + def test_user_defined_with_len(self): + class Container: + def __init__(self, items): + self.items = items + + def __len__(self): + return len(self.items) + + def fn(x, c): + return x + 1 if bool(c) else x - 1 + + x = torch.randn(4) + compiled = torch.compile(fn, backend="eager") + self.assertEqual(fn(x, Container([1, 2])), compiled(x, Container([1, 2]))) + torch._dynamo.reset() + compiled = torch.compile(fn, backend="eager") + self.assertEqual(fn(x, Container([])), compiled(x, Container([]))) + + def test_user_defined_no_bool_no_len(self): + class Plain: + pass + + def fn(x, obj): + return x + 1 if bool(obj) else x - 1 + + x = torch.randn(4) + compiled = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x, Plain()), compiled(x, Plain())) + + def test_user_defined_bool_returns_non_bool_raises(self): + class BadBool: + def __bool__(self): + return 1 + + def fn(x, obj): + return x + 1 if bool(obj) else x - 1 + + with self.assertRaises(TypeError): + bool(BadBool()) + with self.assertRaises(TypeError): + torch.compile(fn, backend="eager")(torch.randn(4), BadBool()) + + # --- Metaclass with __bool__ (UserDefinedClassVariable path) --- + + def test_metaclass_bool(self): + class Foo(type): + def __bool__(cls): + return False + + class A(metaclass=Foo): + pass + + class Bar(type): + pass + + class B(metaclass=Bar): + pass + + def fn(x): + # A's metaclass defines __bool__ returning False; B's does not (truthy). + return x + bool(A) + bool(B) + + x = torch.randn(4) + compiled = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), compiled(x)) + + # --- nn.Module (NNModuleVariable path) --- + + def test_nn_module_nonempty(self): + mod = torch.nn.ModuleList([torch.nn.Linear(4, 4)]) + + def fn(x): + return x + 1 if bool(mod) else x - 1 + + x = torch.randn(4) + compiled = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), compiled(x)) + + def test_nn_module_empty(self): + mod = torch.nn.ModuleList() + + def fn(x): + return x + 1 if bool(mod) else x - 1 + + x = torch.randn(4) + compiled = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), compiled(x)) + + # --- Tensor (TensorVariable path) --- + + def test_tensor_nonzero(self): + def fn(x): + t = torch.tensor(1) + return x + 1 if bool(t) else x - 1 + + x = torch.randn(4) + compiled = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), compiled(x)) + + def test_tensor_zero(self): + def fn(x): + t = torch.tensor(0) + return x + 1 if bool(t) else x - 1 + + x = torch.randn(4) + compiled = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), compiled(x)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/test_nb_float.py b/test/dynamo/test_nb_float.py new file mode 100644 index 0000000000000..e4c11e5395940 --- /dev/null +++ b/test/dynamo/test_nb_float.py @@ -0,0 +1,321 @@ +# Owner(s): ["module: dynamo"] +"""Tests for nb_float_impl: unified __float__ / float() protocol in Dynamo.""" + +import torch +import torch._dynamo.testing +from torch.testing._internal.common_utils import ( + make_dynamo_test, + run_tests, + skipIfCrossRef, + TestCase, +) + + +class NbFloatTests(TestCase): + # --- float / int / bool (ConstantVariable) --- + + @make_dynamo_test + def test_float_float(self): + self.assertEqual(float(3.14), 3.14) # noqa: UP018 + + @make_dynamo_test + def test_int_float(self): + self.assertEqual(float(5), 5.0) + + @make_dynamo_test + def test_bool_float(self): + self.assertEqual(float(True), 1.0) + self.assertEqual(float(False), 0.0) + + @make_dynamo_test + def test_float_dunder_float(self): + self.assertEqual((3.14).__float__(), 3.14) + + @make_dynamo_test + def test_int_dunder_float(self): + self.assertEqual((5).__float__(), 5.0) + + @make_dynamo_test + def test_bool_dunder_float(self): + self.assertEqual(True.__float__(), 1.0) + + # --- TypeError for non-float types --- + + def test_complex_float_raises(self): + def fn(x): + try: + return float(1 + 2j) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertEqual(result, eager_result) + + def test_none_float_raises(self): + def fn(x): + try: + return float(None) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertEqual(result, eager_result) + + def test_list_float_raises(self): + def fn(x): + try: + return float([1, 2]) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertEqual(result, eager_result) + + def test_dict_float_raises(self): + def fn(x): + try: + return float({}) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertEqual(result, eager_result) + + def test_set_float_raises(self): + def fn(x): + try: + return float(set()) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertEqual(result, eager_result) + + def test_tuple_float_raises(self): + def fn(x): + try: + return float(()) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertEqual(result, eager_result) + + # --- str parsing (constructor, not nb_float) --- + + @make_dynamo_test + def test_str_float_parsing(self): + self.assertEqual(float("3.14"), 3.14) + + @make_dynamo_test + def test_str_float_int(self): + self.assertEqual(float("123"), 123.0) + + # --- UserDefinedObjectVariable with __float__ --- + + def test_user_defined_float(self): + class MyFloat: + def __init__(self, v): + self.v = v + + def __float__(self): + return self.v + + obj = MyFloat(3.14) + + def fn(x): + return float(obj) + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 3.14 + ) + + def test_user_defined_dunder_float(self): + class MyFloat: + def __init__(self, v): + self.v = v + + def __float__(self): + return self.v + + obj = MyFloat(2.71) + + def fn(x): + return obj.__float__() + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 2.71 + ) + + def test_user_defined_no_float_raises(self): + class NoFloat: + pass + + obj = NoFloat() + + def fn(x): + try: + return float(obj) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertIn("float() argument must be a string", result) + + def test_float_returning_non_float_raises(self): + class Bad: + def __float__(self): + return "not a float" + + obj = Bad() + + def fn(x): + try: + return float(obj) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertIn("__float__ returned non-float", result) + + def test_float_raising_exception_propagates(self): + class RaisingFloat: + def __float__(self): + raise ValueError("custom error") + + obj = RaisingFloat() + + def fn(x): + try: + return float(obj) + except ValueError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertEqual(result, "custom error") + + def test_user_defined_staticmethod_float(self): + class StaticFloat: + @staticmethod + def __float__(): + return 3.0 + + obj = StaticFloat() + + def fn(x): + return float(obj) + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 3.0 + ) + + def test_user_defined_classmethod_float(self): + class ClassFloat: + @classmethod + def __float__(cls): + return 4.0 + + obj = ClassFloat() + + def fn(x): + return float(obj) + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 4.0 + ) + + # --- nb_index fallback (PyNumber_Float step 3) --- + + def test_index_fallback_for_float(self): + class HasIndex: + def __index__(self): + return 42 + + obj = HasIndex() + + def fn(x): + return float(obj) + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 42.0 + ) + + def test_index_fallback_no_float_no_index_raises(self): + class NoSlots: + pass + + obj = NoSlots() + + def fn(x): + try: + return float(obj) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertIn("float() argument must be a string", result) + self.assertEqual(result, eager_result) + + # --- Tensor --- + + @skipIfCrossRef + def test_tensor_int_dtype(self): + def fn(x): + return float(x) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(5)) + self.assertEqual(result, 5.0) + + @skipIfCrossRef + def test_tensor_float_dtype(self): + def fn(x): + return float(x) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(3.14)) + self.assertAlmostEqual(result, 3.14, places=2) + + @skipIfCrossRef + def test_tensor_complex_raises(self): + def fn(x): + try: + return float(x) + except RuntimeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)( + torch.tensor(1 + 2j) + ) + eager_result = fn(torch.tensor(1 + 2j)) + self.assertIn( + "value cannot be converted to type double without overflow", result + ) + self.assertEqual(result, eager_result) + + @skipIfCrossRef + def test_tensor_dunder_float(self): + def fn(x): + return x.__float__() + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(5)) + self.assertEqual(result, 5.0) + + # --- SymNodeVariable --- + + @skipIfCrossRef + def test_symnode_float(self): + def fn(x): + return float(x.size(0)) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.randn(10, 20)) + self.assertEqual(result, 10.0) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/test_nb_index.py b/test/dynamo/test_nb_index.py new file mode 100644 index 0000000000000..e06dc4addbdca --- /dev/null +++ b/test/dynamo/test_nb_index.py @@ -0,0 +1,353 @@ +# Owner(s): ["module: dynamo"] +"""Tests for nb_index_impl: unified __index__ / operator.index protocol in Dynamo.""" + +import operator + +import torch +import torch._dynamo.testing +from torch.testing._internal.common_utils import make_dynamo_test, run_tests, TestCase + + +class NbIndexTests(TestCase): + # --- int / bool (ConstantVariable) --- + + @make_dynamo_test + def test_int_index(self): + self.assertEqual(operator.index(5), 5) + + @make_dynamo_test + def test_bool_index(self): + self.assertEqual(operator.index(True), 1) + + @make_dynamo_test + def test_int_dunder_index(self): + self.assertEqual((5).__index__(), 5) + + # --- TypeError for non-index types --- + + def test_str_index_raises(self): + def fn(x): + try: + return operator.index("hello") + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertIn("cannot be interpreted as an integer", result) + + def test_float_index_raises(self): + def fn(x): + try: + return operator.index(3.14) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertIn("cannot be interpreted as an integer", result) + + def test_list_index_raises(self): + def fn(x): + try: + return operator.index([1, 2]) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertIn("cannot be interpreted as an integer", result) + + def test_none_index_raises(self): + def fn(x): + try: + return operator.index(None) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertIn("cannot be interpreted as an integer", result) + + # --- UserDefinedObjectVariable with __index__ --- + + def test_user_defined_index(self): + class MyInt: + def __init__(self, v): + self.v = v + + def __index__(self): + return self.v + + obj = MyInt(42) + + def fn(x): + return operator.index(obj) + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 42 + ) + + def test_user_defined_dunder_index(self): + class MyInt: + def __init__(self, v): + self.v = v + + def __index__(self): + return self.v + + obj = MyInt(7) + + def fn(x): + return obj.__index__() + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 7 + ) + + def test_user_defined_no_index_raises(self): + class NoIndex: + pass + + obj = NoIndex() + + def fn(x): + try: + return operator.index(obj) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertIn("cannot be interpreted as an integer", result) + + # --- nb_index used in list/tuple subscript (PyNumber_AsSsize_t) --- + + @make_dynamo_test + def test_list_subscript_with_bool(self): + lst = [10, 20, 30] + self.assertEqual(lst[True], 20) + + @make_dynamo_test + def test_tuple_subscript_with_bool(self): + t = (10, 20, 30) + self.assertEqual(t[True], 20) + + def test_list_subscript_with_user_defined_index(self): + class MyIdx: + def __index__(self): + return 1 + + idx = MyIdx() + + def fn(x): + lst = [10, 20, 30] + return lst[idx] + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 20 + ) + + def test_tuple_subscript_with_user_defined_index(self): + class MyIdx: + def __index__(self): + return 2 + + idx = MyIdx() + + def fn(x): + t = (10, 20, 30) + return t[idx] + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 30 + ) + + def test_list_subscript_with_no_index_raises(self): + class NoIndex: + pass + + obj = NoIndex() + + def fn(x): + lst = [10, 20, 30] + try: + return lst[obj] + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertIn("list indices must be integers or slices", result) + + def test_user_defined_staticmethod_index(self): + class StaticIdx: + @staticmethod + def __index__(): + return 3 + + obj = StaticIdx() + + def fn(x): + return operator.index(obj) + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 3 + ) + + def test_user_defined_classmethod_index(self): + class ClassIdx: + @classmethod + def __index__(cls): + return 4 + + obj = ClassIdx() + + def fn(x): + return operator.index(obj) + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 4 + ) + + def test_custom_getitem_with_user_defined_index(self): + class MyIdx: + def __index__(self): + return 1 + + class MyContainer: + def __init__(self, data): + self.data = data + + def __getitem__(self, idx): + return self.data[idx] + + idx = MyIdx() + container = MyContainer([10, 20, 30]) + + def fn(x): + return container[idx] + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 20 + ) + + def test_dict_getitem_with_non_indexable(self): + class NoIndex: + pass + + obj = NoIndex() + + def fn(x): + d = {obj: 42} + return d[obj] + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 42 + ) + + def test_list_getitem_non_indexable_matches_cpython(self): + class NoIndex: + pass + + obj = NoIndex() + + def fn(x): + try: + return [10, 20][obj] + except TypeError: + return "caught" + + # Both CPython and Dynamo should raise TypeError and catch it + eager_result = fn(torch.tensor(0)) + compiled_result = torch.compile(fn, backend="eager", fullgraph=True)( + torch.tensor(0) + ) + self.assertEqual(eager_result, "caught") + self.assertEqual(compiled_result, "caught") + + def test_list_subscript_error_message_matches_cpython(self): + def fn(x): + try: + return [10, 20]["hello"] # noqa: RUF016 + except TypeError as e: + return str(e) + + eager_result = fn(torch.tensor(0)) + compiled_result = torch.compile(fn, backend="eager", fullgraph=True)( + torch.tensor(0) + ) + self.assertIn("list indices must be integers or slices", eager_result) + self.assertIn("list indices must be integers or slices", compiled_result) + + def test_index_returning_non_int_raises(self): + class Bad: + def __index__(self): + return "not an int" # noqa: PLE0305 + + obj = Bad() + + def fn(x): + try: + return operator.index(obj) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertIn("__index__ returned non-int", result) + + def test_negative_index_via_user_defined(self): + class NegIdx: + def __index__(self): + return -1 + + idx = NegIdx() + + def fn(x): + lst = [10, 20, 30] + return lst[idx] + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 30 + ) + + def test_index_raising_exception_propagates(self): + class RaisingIdx: + def __index__(self): + raise ValueError("custom error") + + obj = RaisingIdx() + + def fn(x): + try: + return operator.index(obj) + except ValueError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertEqual(result, "custom error") + + # --- Tensor __index__ --- + + def test_tensor_int_index(self): + def fn(x): + return operator.index(x) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(5)) + self.assertEqual(result, 5) + + def test_tensor_float_index_raises(self): + def fn(x): + try: + return operator.index(x) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(1.5)) + self.assertIn("only integer tensors", result) + + def test_list_subscript_with_tensor(self): + def fn(x): + lst = [10, 20, 30] + return lst[x] + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(2)) + self.assertEqual(result, 30) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/test_nb_int.py b/test/dynamo/test_nb_int.py new file mode 100644 index 0000000000000..60d538b986fc9 --- /dev/null +++ b/test/dynamo/test_nb_int.py @@ -0,0 +1,317 @@ +# Owner(s): ["module: dynamo"] +"""Tests for nb_int_impl: unified __int__ / int() protocol in Dynamo.""" + +import torch +import torch._dynamo.testing +from torch.testing._internal.common_utils import make_dynamo_test, run_tests, TestCase + + +class NbIntTests(TestCase): + # --- int / bool (ConstantVariable) --- + + @make_dynamo_test + def test_int_int(self): + self.assertEqual(5, 5) + + @make_dynamo_test + def test_bool_int(self): + self.assertEqual(int(True), 1) + self.assertEqual(int(False), 0) + + @make_dynamo_test + def test_int_dunder_int(self): + self.assertEqual((5).__int__(), 5) + + @make_dynamo_test + def test_bool_dunder_int(self): + self.assertEqual(True.__int__(), 1) + + # --- float (ConstantVariable) --- + + @make_dynamo_test + def test_float_int(self): + self.assertEqual(int(3.14), 3) + + @make_dynamo_test + def test_float_negative_int(self): + self.assertEqual(int(-2.9), -2) + + @make_dynamo_test + def test_float_dunder_int(self): + self.assertEqual((3.14).__int__(), 3) + + # --- TypeError for non-int types --- + + def test_complex_int_raises(self): + def fn(x): + try: + return int(1 + 2j) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertEqual(result, eager_result) + + def test_none_int_raises(self): + def fn(x): + try: + return int(None) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertEqual(result, eager_result) + + def test_list_int_raises(self): + def fn(x): + try: + return int([1, 2]) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertEqual(result, eager_result) + + def test_dict_int_raises(self): + def fn(x): + try: + return int({}) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertEqual(result, eager_result) + + def test_set_int_raises(self): + def fn(x): + try: + return int(set()) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertEqual(result, eager_result) + + def test_tuple_int_raises(self): + def fn(x): + try: + return int(()) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertEqual(result, eager_result) + + # --- str parsing (constructor, not nb_int) --- + + @make_dynamo_test + def test_str_int_parsing(self): + self.assertEqual(int("123"), 123) + + @make_dynamo_test + def test_str_int_base(self): + self.assertEqual(int("ff", 16), 255) + + # --- UserDefinedObjectVariable with __int__ --- + + def test_user_defined_int(self): + class MyInt: + def __init__(self, v): + self.v = v + + def __int__(self): + return self.v + + obj = MyInt(42) + + def fn(x): + return int(obj) + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 42 + ) + + def test_user_defined_dunder_int(self): + class MyInt: + def __init__(self, v): + self.v = v + + def __int__(self): + return self.v + + obj = MyInt(7) + + def fn(x): + return obj.__int__() + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 7 + ) + + def test_user_defined_no_int_raises(self): + class NoInt: + pass + + obj = NoInt() + + def fn(x): + try: + return int(obj) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertIn("int() argument must be a string", result) + + def test_int_returning_non_int_raises(self): + class Bad: + def __int__(self): + return "not an int" + + obj = Bad() + + def fn(x): + try: + return int(obj) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertIn("__int__ returned non-int", result) + + def test_int_raising_exception_propagates(self): + class RaisingInt: + def __int__(self): + raise ValueError("custom error") + + obj = RaisingInt() + + def fn(x): + try: + return int(obj) + except ValueError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + self.assertEqual(result, "custom error") + + def test_user_defined_staticmethod_int(self): + class StaticInt: + @staticmethod + def __int__(): + return 3 + + obj = StaticInt() + + def fn(x): + return int(obj) + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 3 + ) + + def test_user_defined_classmethod_int(self): + class ClassInt: + @classmethod + def __int__(cls): + return 4 + + obj = ClassInt() + + def fn(x): + return int(obj) + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 4 + ) + + # --- nb_index fallback (PyNumber_Long step 3) --- + + def test_index_fallback_for_int(self): + class HasIndex: + def __index__(self): + return 42 + + obj = HasIndex() + + def fn(x): + return int(obj) + + self.assertEqual( + torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)), 42 + ) + + def test_index_fallback_no_int_no_index_raises(self): + class NoSlots: + pass + + obj = NoSlots() + + def fn(x): + try: + return int(obj) + except TypeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(0)) + eager_result = fn(torch.tensor(0)) + self.assertIn("int() argument must be a string", result) + self.assertEqual(result, eager_result) + + # --- Tensor --- + + def test_tensor_int_dtype(self): + def fn(x): + return int(x) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(5)) + self.assertEqual(result, 5) + + def test_tensor_float_dtype(self): + def fn(x): + return int(x) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(3.7)) + self.assertEqual(result, 3) + + def test_tensor_complex_raises(self): + def fn(x): + try: + return int(x) + except RuntimeError as e: + return str(e) + + result = torch.compile(fn, backend="eager", fullgraph=True)( + torch.tensor(1 + 2j) + ) + eager_result = fn(torch.tensor(1 + 2j)) + self.assertIn( + "value cannot be converted to type int64_t without overflow", result + ) + self.assertEqual(result, eager_result) + + def test_tensor_dunder_int(self): + def fn(x): + return x.__int__() + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.tensor(5)) + self.assertEqual(result, 5) + + # --- SymNodeVariable --- + + def test_symnode_int(self): + def fn(x): + return int(x.size(0)) + + result = torch.compile(fn, backend="eager", fullgraph=True)(torch.randn(10, 20)) + self.assertEqual(result, 10) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/test_nested_graph_breaks.py b/test/dynamo/test_nested_graph_breaks.py index c77cd95f70c3a..54fdfd14426cd 100644 --- a/test/dynamo/test_nested_graph_breaks.py +++ b/test/dynamo/test_nested_graph_breaks.py @@ -73,7 +73,7 @@ def make_nested_cls(cls): global1, global2, global3, global4 = (torch.zeros(3),) * 4 -class NestedGraphBreakTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks): +class NestedGraphBreakTests(torch._dynamo.test_case.TestCase): def test_single_graph_break(self): # NOTE marking f1, f2, f3 as global # prevents them from being freevars @@ -1038,6 +1038,27 @@ def f4(x): # multiplication by 32, 64, 128, 256 self.assertEqual(cnts.op_count, 4) + def test_resume_closure_different_module_globals(self): + # Tests that resume functions with freevars (closures) from inlined + # frames get the correct f_globals. Without the factory fix, + # MAKE_FUNCTION inherits the root frame's globals, so the resume + # function for `inner` would not find HELPER_CONSTANT. + try: + from . import _test_nested_graph_breaks_helper + except ImportError: + import _test_nested_graph_breaks_helper + + def outer(x): + return _test_nested_graph_breaks_helper.closure_with_graph_break(x) + 2 + + cnts = torch._dynamo.testing.CompileCounter() + opt_fn = torch._dynamo.optimize(backend=cnts)(outer) + x = torch.zeros(3) + res = outer(x) + ref = opt_fn(x) + self.assertEqual(ref, res) + self.assertEqual(cnts.frame_count, 2) + def test_error_on_graph_break_nested(self): # error_on_graph_break in a nested frame cnts = torch._dynamo.testing.CompileCounter() @@ -1104,6 +1125,256 @@ def f8(x): with self.assertRaises(torch._dynamo.exc.Unsupported): f8(inp) + def test_graph_break_in_wrapped_user_function(self): + def fn(x): + x = x + 1 + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() # noqa: S101 + assert not torch.is_grad_enabled() # noqa: S101 + return x + 2 + + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnts) + def gn(x): + x = torch.no_grad()(fn)(x) + assert torch.compiler.is_compiling() # noqa: S101 + assert torch.is_grad_enabled() # noqa: S101 + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + self.assertEqual(cnts.frame_count, 2) + + def test_graph_break_in_nested_wrapped_user_function(self): + # no_grad wraps enable_grad wraps no_grad wraps fn + def fn(x): + x = x + 1 + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() # noqa: S101 + assert not torch.is_grad_enabled() # noqa: S101 + return x + 2 + + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnts) + def gn(x): + wrapped = torch.no_grad()(torch.enable_grad()(torch.no_grad()(fn))) + x = wrapped(x) + assert torch.compiler.is_compiling() # noqa: S101 + assert torch.is_grad_enabled() # noqa: S101 + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + self.assertEqual(cnts.frame_count, 2) + + def test_graph_break_in_wrapped_user_function_with_multiple_breaks(self): + def fn(x): + x = x + 1 + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() # noqa: S101 + assert not torch.is_grad_enabled() # noqa: S101 + x = x + 2 + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() # noqa: S101 + assert not torch.is_grad_enabled() # noqa: S101 + return x + 3 + + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnts) + def gn(x): + x = torch.no_grad()(fn)(x) + assert torch.compiler.is_compiling() # noqa: S101 + assert torch.is_grad_enabled() # noqa: S101 + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 6) + self.assertEqual(cnts.frame_count, 3) + + def test_graph_break_in_sequential_wrapped_user_functions(self): + def fn1(x): + x = x + 1 + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() # noqa: S101 + assert not torch.is_grad_enabled() # noqa: S101 + return x + 2 + + def fn2(x): + x = x + 3 + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() # noqa: S101 + assert not torch.is_grad_enabled() # noqa: S101 + return x + 4 + + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnts) + def gn(x): + x = torch.no_grad()(fn1)(x) + assert torch.compiler.is_compiling() # noqa: S101 + assert torch.is_grad_enabled() # noqa: S101 + x = torch.no_grad()(fn2)(x) + assert torch.compiler.is_compiling() # noqa: S101 + assert torch.is_grad_enabled() # noqa: S101 + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 10) + # 3 frames: before fn1's break, after fn1's break through fn2's break, + # after fn2's break to end + self.assertEqual(cnts.frame_count, 3) + + def test_graph_break_in_wrapped_enable_grad(self): + @torch.no_grad() + def fn(x): + x = x + 1 + # enable_grad inside no_grad + x = torch.enable_grad()(lambda y: y + 2)(x) + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() # noqa: S101 + assert not torch.is_grad_enabled() # noqa: S101 + return x + 3 + + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnts) + def gn(x): + return fn(x) + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 6) + + def test_graph_break_in_wrapped_user_method(self): + class Foo: + def __init__(self): + self.a = 1 + self.b = 2 + + def fn(self, x): + x = x + self.a + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() # noqa: S101 + assert not torch.is_grad_enabled() # noqa: S101 + return x + self.b + + obj = Foo() + + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnts) + def gn(x): + obj.fn = torch.no_grad()(obj.fn) + x = obj.fn(x) + assert torch.compiler.is_compiling() # noqa: S101 + assert torch.is_grad_enabled() # noqa: S101 + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + self.assertEqual(cnts.frame_count, 2) + + def test_graph_break_in_wrapped_nested_function(self): + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnts) + def gn(x): + a = 1 + b = 2 + + @torch.no_grad() + def fn(x): + x = x + a + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() # noqa: S101 + assert not torch.is_grad_enabled() # noqa: S101 + return x + b + + x = fn(x) + assert torch.compiler.is_compiling() # noqa: S101 + assert torch.is_grad_enabled() # noqa: S101 + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + self.assertEqual(cnts.frame_count, 2) + + def test_graph_break_in_wrapped_skipped_function(self): + from torch._dynamo.testing import _skipped_function_for_test_reconstruct + + def fn(x): + x = x + 1 + torch._dynamo.graph_break() + assert torch.compiler.is_compiling() # noqa: S101 + assert not torch.is_grad_enabled() # noqa: S101 + return x + 2 + + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnts) + def gn(x): + x = torch.no_grad()(_skipped_function_for_test_reconstruct)(fn, x) + assert torch.compiler.is_compiling() # noqa: S101 + assert torch.is_grad_enabled() # noqa: S101 + return x + + inp = torch.randn(3) + self.assertEqual(gn(inp), inp + 3) + self.assertEqual(cnts.frame_count, 2) + + def test_step_graph_break_frame_values_not_corrupted(self): + """Bytecode generation bug in step_graph_break corrupted parent frame + locals when the parent had a non-empty operand stack (num_stack > 0). + """ + + def inner(x): + x = x + 1 + x = x + 1 + torch._dynamo.step_unsupported() + return x + + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnts) + def fn(x): + x = x + 1 + y = (x, inner(x)) + return x, y + + x = torch.tensor([1.0, 2.0]) + result = fn(x) + self.assertEqual(result[0], torch.tensor([2.0, 3.0])) + self.assertEqual( + result[1], (torch.tensor([2.0, 3.0]), torch.tensor([4.0, 5.0])) + ) + self.assertEqual(cnts.frame_count, 1) + self.assertEqual(cnts.op_count, 3) + + def test_contextmanager_graph_break_in_init(self): + """Graph break in _GeneratorContextManager.__init__ when the generator + function is @torch._disable_dynamo (the DDP pattern).""" + from contextlib import contextmanager + + @contextmanager + @torch._disable_dynamo(recursive=False) + def my_ctx(): + yield + + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnts) + def fn(x): + x = x + 1 + with my_ctx(): + x = x + 2 + return x + 3 + + inp = torch.randn(3) + self.assertEqual(fn(inp), inp + 6) + self.assertEqual(cnts.frame_count, 1) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py index 5f47a3745bfd5..87b4c088a1a1a 100644 --- a/test/dynamo/test_package.py +++ b/test/dynamo/test_package.py @@ -492,7 +492,7 @@ def guard_filter_fn(guards): compiled_fn(*args) total_frames = torch._dynamo.convert_frame.FRAME_COUNTER - self._save_and_reload(expected_backends=8, expected_dynamo=1) + self._save_and_reload(expected_backends=9, expected_dynamo=1) compiled_fn = torch._dynamo.optimize( backend="inductor", guard_filter_fn=guard_filter_fn diff --git a/test/dynamo/test_profiler.py b/test/dynamo/test_profiler.py index 0145b9f79bb83..79ae8f775d832 100644 --- a/test/dynamo/test_profiler.py +++ b/test/dynamo/test_profiler.py @@ -312,7 +312,7 @@ def forward(self, L_x_ : torch.Tensor): _record_function_enter_new_2 = torch.ops.profiler._record_function_enter_new('my_net2', None) c = b + 2; b = None _record_function_exit__record_function_2 = torch.ops.profiler._record_function_exit._RecordFunction(_record_function_enter_new_2); _record_function_enter_new_2 = _record_function_exit__record_function_2 = None - return (c,)""", # noqa: B950 + return (c,)""", ) self.assertExpectedInline( backend.fw_graphs[0].code.strip(), @@ -327,7 +327,7 @@ def forward(self, arg0_1): _record_function_enter_new_2 = torch.ops.profiler._record_function_enter_new.default('my_net2') add = torch.ops.aten.add.Tensor(cos, 2); cos = None _record_function_exit_2 = torch.ops.profiler._record_function_exit._RecordFunction(_record_function_enter_new_2); _record_function_enter_new_2 = _record_function_exit_2 = None - return (add,)""", # noqa: B950 + return (add,)""", ) with torch.profiler.profile() as prof: fn_c( diff --git a/test/dynamo/test_python_dispatcher.py b/test/dynamo/test_python_dispatcher.py index d74077a5be4ce..42363cec6a2be 100644 --- a/test/dynamo/test_python_dispatcher.py +++ b/test/dynamo/test_python_dispatcher.py @@ -77,7 +77,7 @@ def forward(self, L_x_: "f32[2, 3]"): sub: "f32[2, 3]" = l_x_ - 1; l_x_ = None sin: "f32[2, 3]" = torch.sin(sub); sub = None return (sin,) -""", # NOQA: B950 +""", ) @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "requires cuda or xpu") diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py index a853020ad8213..b90635f119090 100644 --- a/test/dynamo/test_recompile_ux.py +++ b/test/dynamo/test_recompile_ux.py @@ -321,6 +321,120 @@ def h(x): self.assertEqual(res, inp + 3) +class RecompileLimitKwargTests(torch._dynamo.test_case.TestCase): + @staticmethod + def _num_cache_entries(code): + return len(torch._dynamo.eval_frame._debug_get_cache_entry_list(code)) + + def test_recompile_limit_basic(self): + cnt = torch._dynamo.testing.CompileCounter() + + def f(x, y): + return x + y + + opt_f = torch.compile(f, backend=cnt, recompile_limit=2) + + opt_f(torch.randn(3), torch.randn(3)) + self.assertEqual(self._num_cache_entries(f), 1) + + opt_f(torch.randn(3, dtype=torch.float64), torch.randn(3, dtype=torch.float64)) + self.assertEqual(self._num_cache_entries(f), 2) + + # Third dtype should NOT trigger recompilation (recompile_limit=2) + opt_f(torch.randn(3, dtype=torch.float16), torch.randn(3, dtype=torch.float16)) + self.assertEqual(self._num_cache_entries(f), 2) + + def test_recompile_limit_none_uses_global(self): + cnt = torch._dynamo.testing.CompileCounter() + + def f(x, y): + return x + y + + # Without recompile_limit kwarg, uses global config (default 8) + opt_f = torch.compile(f, backend=cnt) + + for i in range(10): + dtype = [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + torch.int32, + torch.int64, + torch.int16, + torch.int8, + torch.uint8, + torch.complex64, + ][i] + opt_f(torch.ones(3, dtype=dtype), torch.ones(3, dtype=dtype)) + + self.assertEqual( + self._num_cache_entries(f), torch._dynamo.config.recompile_limit + ) + + def test_recompile_limit_fullgraph_raises(self): + """With fullgraph=True, hitting the recompile_limit kwarg raises + FailOnRecompileLimitHit, consistent with the fullgraph contract.""" + cnt = torch._dynamo.testing.CompileCounter() + + def f(x): + return x.sin() + + opt_f = torch.compile(f, backend=cnt, fullgraph=True, recompile_limit=1) + + opt_f(torch.randn(3)) + self.assertEqual(cnt.frame_count, 1) + + with self.assertRaises(FailOnRecompileLimitHit): + opt_f(torch.randn(3, dtype=torch.float64)) + + @torch._dynamo.config.patch(automatic_dynamic_shapes=True) + def test_recompile_limit_resume_function_auto_dynamic(self): + """With automatic dynamic shapes and recompile_limit=2, the resume + function recompiles via dimension changes on a global tensor while + the main function gets cache hits. The resume function should stop + at 2 entries and fall back to eager.""" + cnt = torch._dynamo.testing.CompileCounter() + + y_holder = {"tensor": torch.randn(4, 8, 2)} + + def f(x): + x.sin() + print("graph break") + return y_holder["tensor"].cos() + + opt_f = torch.compile(f, backend=cnt, recompile_limit=2) + + # Call 1: static compile + y_holder["tensor"] = torch.randn(4, 8, 2) + opt_f(torch.randn(4, 8, 2)) + + # Call 2: y dim0 changes -> f cache hit, resume recompiles + y_holder["tensor"] = torch.randn(5, 8, 2) + opt_f(torch.randn(4, 8, 2)) + frame_count_after_2 = cnt.frame_count + + # Call 3: y dim1 changes -> resume should NOT recompile + # (resume already has 2 entries = recompile_limit) + y_holder["tensor"] = torch.randn(5, 9, 2) + opt_f(torch.randn(4, 8, 2)) + self.assertEqual(cnt.frame_count, frame_count_after_2) + + # Verify f has 1 entry, resume has 2 + num_f_entries = len(torch._dynamo.eval_frame._debug_get_cache_entry_list(f)) + self.assertEqual(num_f_entries, 1) + + from torch._dynamo.resume_execution import ContinueExecutionCache + + resume_codes = list(ContinueExecutionCache.cache[f.__code__].values()) + self.assertTrue(len(resume_codes) > 0, "No resume functions found") + for resume_code in resume_codes: + num_resume_entries = len( + torch._dynamo.eval_frame._debug_get_cache_entry_list(resume_code) + ) + self.assertEqual(num_resume_entries, 2) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_recompiles.py b/test/dynamo/test_recompiles.py index a8579cc4c6178..827062f17985e 100644 --- a/test/dynamo/test_recompiles.py +++ b/test/dynamo/test_recompiles.py @@ -8,54 +8,6 @@ class RecompileTests(torch._dynamo.test_case.TestCase): - def test_inline_inbuilt_nn_modules_candidate(self): - def hook_flag_on(guard_manager, f_locals, builder): - self.assertTrue( - "[inline-inbuilt-nn-modules-candidate]" not in str(guard_manager) - ) - - def hook_flag_off(guard_manager, f_locals, builder): - self.assertTrue( - "[inline-inbuilt-nn-modules-candidate]" in str(guard_manager) - ) - - class SubMod(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 2) - - @torch.compile(backend="eager") - def forward(self, x): - return self.linear(x) - - class Mod(torch.nn.Module): - def __init__(self): - super().__init__() - self.sm1 = SubMod() - self.sm2 = SubMod() - - def forward(self, x): - return self.sm1(x) + self.sm2(x) - - try: - from .utils import install_guard_manager_testing_hook - except ImportError: - from utils import install_guard_manager_testing_hook - - with ( - install_guard_manager_testing_hook(hook_flag_on), - dc.patch(inline_inbuilt_nn_modules=True), - ): - mod = Mod() - mod(torch.randn(2, 2)) - - with ( - install_guard_manager_testing_hook(hook_flag_off), - dc.patch(inline_inbuilt_nn_modules=False), - ): - mod = Mod() - mod(torch.randn(2, 2)) - def test_automatic_dynamic_reduce_recompiles(self): # Test the counterfactual, lots of recompiles without this config def foo(x, y): diff --git a/test/dynamo/test_reconstruct.py b/test/dynamo/test_reconstruct.py index 1662e63d9ffdd..49c3ca2895889 100644 --- a/test/dynamo/test_reconstruct.py +++ b/test/dynamo/test_reconstruct.py @@ -57,6 +57,89 @@ def f(d, t): opt_f(d_opt, t) self.assertEqual(d, d_opt) + def _compile_and_capture_side_effects(self, fn, *args): + """Compile fn and return side-effect metadata from bytecode hooks.""" + captured = {} + + def rewrite_hook(code, out_code): + return out_code.replace(co_name=f"{out_code.co_name}_hooked") + + def inspect_hook(code, out_code): + captured["refs"] = ( + torch._dynamo.convert_frame.get_compiled_code_side_effects(out_code) + ) + captured["has_side_effects"] = ( + torch._dynamo.convert_frame.compiled_code_has_side_effects(out_code) + ) + + torch._dynamo.reset() + rewrite_handle = torch._dynamo.convert_frame.register_bytecode_hook( + rewrite_hook + ) + inspect_handle = torch._dynamo.convert_frame.register_bytecode_hook( + inspect_hook + ) + try: + torch.compile(fn, backend="eager", fullgraph=True)(*args) + finally: + inspect_handle.remove() + rewrite_handle.remove() + + return captured + + def test_bytecode_hook_exposes_side_effect_refs(self): + def mutating_fn(x, lst): + lst.append(x + 1) + return x * 2 + + def pure_fn(x): + return x * 2 + + x = torch.randn(3) + + mutated = self._compile_and_capture_side_effects(mutating_fn, x, []) + self.assertEqual(mutated["refs"], ("L['lst']",)) + self.assertTrue(mutated["has_side_effects"]) + + pure = self._compile_and_capture_side_effects(pure_fn, x) + self.assertEqual(pure["refs"], ()) + self.assertFalse(pure["has_side_effects"]) + + def test_side_effect_refs_dict_mutation(self): + def fn(x, d): + d["result"] = x + 1 + return x * 2 + + result = self._compile_and_capture_side_effects(fn, torch.randn(3), {}) + self.assertEqual(result["refs"], ("L['d']",)) + self.assertTrue(result["has_side_effects"]) + + def test_side_effect_refs_tensor_in_container(self): + # Relevant to cudagraphs: a compiled function computes tensors and + # stores them into an external container as a side effect. + def fn(x, outputs): + y = x * 2 + z = x + 3 + outputs.append(y) + outputs.append(z) + return x + + result = self._compile_and_capture_side_effects(fn, torch.randn(4), []) + self.assertEqual(result["refs"], ("L['outputs']",)) + self.assertTrue(result["has_side_effects"]) + + def test_side_effect_refs_multiple_containers(self): + def fn(x, lst, d): + lst.append(x + 1) + d["out"] = x * 2 + return x + + result = self._compile_and_capture_side_effects(fn, torch.randn(3), [], {}) + self.assertEqual(len(result["refs"]), 2) + self.assertIn("L['lst']", result["refs"]) + self.assertIn("L['d']", result["refs"]) + self.assertTrue(result["has_side_effects"]) + def test_ConstDict_pop_reconstruct(self): """ If something is pop'ed from the dict, we reconstruct everything @@ -305,6 +388,22 @@ def fn(model, states, x): got = opt_fn(model, states, x) self.assertEqual(expected, got) + def test_ordered_dict_no_reconstruct_without_mutation(self): + """Sourced OrderedDict should not emit BUILD_MAP when not mutated.""" + + def hook(instructions: list[dis.Instruction]): + build_map = _filter_instructions(instructions, "BUILD_MAP") + self.assertEqual(len(build_map), 0) + + def fn(od, x): + return x + od["a"] + + od = collections.OrderedDict(a=1, b=2) + with self.register_bytecode_hook(hook): + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + got = opt_fn(od, torch.tensor(1.0)) + self.assertEqual(got, torch.tensor(2.0)) + def test_graph_break_in_wrapped_user_function(self): def fn(x): x = x + 1 @@ -559,54 +658,60 @@ def hook(instructions: list[dis.Instruction]): opt_gn = torch.compile(gn, backend="eager", fullgraph=True) opt_gn(torch.ones(3)) - def test_as_python_constant_super_delegation_no_false_positive(self): - """Passing a reference-type opaque object to a value-type constructor - should NOT produce a false 'self-referential' error. - - When OpaqueObjectClassVariable.call_function evaluates constructor args - via as_python_constant(), a reference-type TorchScriptObjectVariable's - as_python_constant delegates to UserDefinedObjectVariable's via super(). - The _add_call_once_guard must key on id(original_method) rather than the - method name string, so that this super() delegation is not mistaken for - a self-referential call. + def test_opaque_reference_as_python_constant(self): + """TSOV.as_python_constant must succeed for reference-type opaque + objects. Without this, __eq__ between two opaque objects graph breaks. """ - from torch._library.opaque_object import OpaqueBase, register_opaque_type + import torch._library.opaque_object + import torch._opaque_base - class _Counter(OpaqueBase): - def __init__(self, start, end): - self.start = start - self.end = end + class Config(torch._opaque_base.OpaqueBase): + def __init__(self, v): + self.v = v - register_opaque_type(_Counter, typ="reference") - - class _ValueWrapper(OpaqueBase): - def __init__(self, inner): - self.inner = inner + def __bool__(self): + return True def __eq__(self, other): - return isinstance(other, _ValueWrapper) and self.inner == other.inner + return isinstance(other, Config) and self.v == other.v def __hash__(self): - return hash(id(self.inner)) + return hash(self.v) - def __fx_repr__(self): - return "_ValueWrapper(inner=None)", {"_ValueWrapper": _ValueWrapper} + torch._library.opaque_object.register_opaque_type(Config, typ="reference") - register_opaque_type(_ValueWrapper, typ="value") + cfg = Config(42) - counter = _Counter(0, 10) + def fn(x, cfg): + if cfg: + return x + 1 + return x - @torch.compile(backend="eager", fullgraph=False) - def f(c): - return _ValueWrapper(c) + opt = torch.compile(fn, backend="eager", fullgraph=True) + result = opt(torch.ones(4), cfg) + self.assertEqual(result, torch.ones(4) + 1) - # Counter is a reference-type opaque object. Passing it to a value-type - # constructor triggers as_python_constant() on the TSOV, which delegates - # via super() to UDOV. This should fail because reference types are not - # constants, but the error must NOT be "self-referential". - with self.assertRaises(RuntimeError) as ctx: - f(counter) - self.assertNotIn("self-referential", str(ctx.exception)) + def test_call_once_guard_allows_super_delegation(self): + """_add_call_once_guard must key on (id(self), id(original_method)) + so that super().as_python_constant() between VT subclasses is not + mistaken for a self-referential call. + """ + from torch._dynamo.variables.base import VariableTracker + + class _Parent(VariableTracker): + def as_python_constant(self): + return 42 + + class _Child(_Parent): + def as_python_constant(self): + return super().as_python_constant() + + child = _Child() + # With name-based keying, _Child and _Parent share the same key + # (id(self), "as_python_constant"), causing a false + # AsPythonConstantNotImplementedError("self-referential"). + self.assertEqual(child.as_python_constant(), 42) + self.assertTrue(child.is_python_constant()) if __name__ == "__main__": diff --git a/test/dynamo/test_regional_inductor.py b/test/dynamo/test_regional_inductor.py index 4306529736cda..d16b35cd5f156 100644 --- a/test/dynamo/test_regional_inductor.py +++ b/test/dynamo/test_regional_inductor.py @@ -525,8 +525,9 @@ def flex_attn_fn(x): ) _, codes = run_fw_bw_and_get_code(lambda: compiled_module(x)) - # flex in forward and flex_backward in backward - self.assertEqual(len(codes), 2) + # flex in forward and flex_backward in backward; contiguous partitioning + # may split non-contiguous annotated nodes into separate regions + self.assertGreaterEqual(len(codes), 2) def test_refcounts(self): """Tests that activations can be cleared before the end of graph""" @@ -834,7 +835,7 @@ def forward(self, primals_1, primals_2): getitem_10 = invoke_subgraph_6[1] getitem_1 = invoke_subgraph_6[0]; invoke_subgraph_6 = None sin_3 = torch.ops.aten.sin.default(getitem_1) - return (sin_3, primals_1, getitem_9, getitem_8, getitem, sin_1, getitem_11, getitem_10, getitem_1)""", # noqa: B950 + return (sin_3, primals_1, getitem_9, getitem_8, getitem, sin_1, getitem_11, getitem_10, getitem_1)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -859,7 +860,7 @@ def forward(self, primals_1, getitem_9, getitem_8, getitem, sin_1, getitem_11, g add = torch.ops.aten.add.Tensor(getitem_3, getitem_6); getitem_3 = getitem_6 = None cos_3 = torch.ops.aten.cos.default(primals_1); primals_1 = None mul_3 = torch.ops.aten.mul.Tensor(getitem_5, cos_3); getitem_5 = cos_3 = None - return (mul_3, add)""", # noqa: B950 + return (mul_3, add)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -990,7 +991,7 @@ def forward(self, primals_0, primals_1, primals_2, primals_3, primals_4, primals alias_1 = torch.ops.aten.alias.default(getitem_1); getitem_1 = None alias_2 = torch.ops.aten.alias.default(alias); alias = None alias_3 = torch.ops.aten.alias.default(alias_1); alias_1 = None - return (getitem, primals_0, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, alias_2, alias_3)""", # noqa: B950 + return (getitem, primals_0, primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, alias_2, alias_3)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1008,7 +1009,7 @@ def forward(self, primals_0, primals_1, primals_2, primals_3, primals_4, primals getitem_5 = flex_attention_backward[2]; flex_attention_backward = None add = torch.ops.aten.add.Tensor(getitem_3, getitem_4); getitem_3 = getitem_4 = None add_1 = torch.ops.aten.add.Tensor(add, getitem_5); add = getitem_5 = None - return (add_1, None, None, None, None, None, None, None, None)""", # noqa: B950 + return (add_1, None, None, None, None, None, None, None, None)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1524,5 +1525,152 @@ def fn(x): self.assertIsNotNone(post_compiled._graph_module) # Now deserialized +class RegionalInductorPartitionTests(torch._inductor.test_case.TestCase): + """Tests for _RegionScooper partitioning behavior. + + Uses CapabilityBasedPartitioner per region ID. Nodes with the same region ID + are merged as aggressively as possible (only cycles prevent merging). Nodes + with different region IDs are never merged. + """ + + def _make_tag_node(self, g, inp, scalar, tagged): + node = g.call_function(torch.ops.aten.mul.Scalar, (inp, scalar)) + if tagged: + node.meta["custom"] = {"compile_with_inductor": True} + node.meta["val"] = torch.empty(10) + return node + + def _scoop_and_count(self, gm): + from torch.fx.passes.regional_inductor import _RegionScooper + + with torch.fx.traceback.preserve_node_meta(enable=False): + partitioned = _RegionScooper.scoop_regions(gm) + return sum( + 1 + for node in partitioned.graph.nodes + if node.op == "call_module" + and node.target.startswith("__marked_inductor_submod") + ) + + def _build_chain(self, tags): + """Build a linear chain: x -> op0 -> op1 -> ... -> output""" + g = torch.fx.Graph() + current = g.placeholder("x") + for tagged in tags: + current = self._make_tag_node(g, current, 1.0, tagged) + g.output(current) + return torch.fx.GraphModule(torch.nn.Module(), g) + + def test_linear_chain_partitioning(self): + """Contiguous runs of tagged nodes form separate partitions.""" + cases = [ + # (tags, expected_partitions) + ([False, True, True, True, False], 1), + ([True, True, False, True, True], 2), + ([True, False, True, False, True], 3), + ([True, False, True, False, True, False, True], 4), + ([False, False, False], 0), + ] + for tags, expected in cases: + with self.subTest(tags=tags): + gm = self._build_chain(tags) + self.assertEqual(self._scoop_and_count(gm), expected) + + def test_parallel_branches_not_fused(self): + """Two adjacent independent tagged branches form 1 partition.""" + g = torch.fx.Graph() + x = g.placeholder("x") + mul_a = self._make_tag_node(g, x, 2.0, tagged=True) + mul_b = self._make_tag_node(g, x, 3.0, tagged=True) + out = g.call_function(torch.ops.aten.add.Tensor, (mul_a, mul_b)) + out.meta["val"] = torch.empty(10) + g.output(out) + gm = torch.fx.GraphModule(torch.nn.Module(), g) + self.assertEqual(self._scoop_and_count(gm), 1) + + def test_parallel_branches_with_gap_same_region(self): + """Two independent tagged nodes separated by an untagged node but + sharing the same region ID are fused into 1 partition (no cycle). + """ + g = torch.fx.Graph() + x = g.placeholder("x") + mul_a = self._make_tag_node(g, x, 2.0, tagged=True) + sin = g.call_function(torch.ops.aten.sin.default, (x,)) + sin.meta["val"] = torch.empty(10) + mul_b = self._make_tag_node(g, x, 3.0, tagged=True) + out = g.call_function(torch.ops.aten.add.Tensor, (mul_a, mul_b)) + out.meta["val"] = torch.empty(10) + g.output(out) + gm = torch.fx.GraphModule(torch.nn.Module(), g) + self.assertEqual(self._scoop_and_count(gm), 1) + + def test_parallel_branches_with_gap_different_regions(self): + """Two independent tagged nodes with different region IDs produce + 2 separate partitions regardless of graph topology. + """ + g = torch.fx.Graph() + x = g.placeholder("x") + mul_a = self._make_tag_node(g, x, 2.0, tagged=True) + mul_a.meta["custom"] = {"compile_with_inductor": {"inductor_region": 0}} + sin = g.call_function(torch.ops.aten.sin.default, (x,)) + sin.meta["val"] = torch.empty(10) + mul_b = self._make_tag_node(g, x, 3.0, tagged=True) + mul_b.meta["custom"] = {"compile_with_inductor": {"inductor_region": 1}} + out = g.call_function(torch.ops.aten.add.Tensor, (mul_a, mul_b)) + out.meta["val"] = torch.empty(10) + g.output(out) + gm = torch.fx.GraphModule(torch.nn.Module(), g) + self.assertEqual(self._scoop_and_count(gm), 2) + + def test_dependent_partitions_merged_across_gap(self): + """Two tagged nodes (same region) separated by an untagged node and + connected by a data dependency are merged into 1 partition. + """ + g = torch.fx.Graph() + x = g.placeholder("x") + mul_a = self._make_tag_node(g, x, 2.0, tagged=True) + sin = g.call_function(torch.ops.aten.sin.default, (x,)) + sin.meta["val"] = torch.empty(10) + # mul_b consumes mul_a, creating a data dependency across the gap + mul_b = self._make_tag_node(g, mul_a, 3.0, tagged=True) + g.output(mul_b) + gm = torch.fx.GraphModule(torch.nn.Module(), g) + self.assertEqual(self._scoop_and_count(gm), 1) + + def test_different_annotations_not_merged(self): + """Two tagged nodes with different region IDs are NOT merged, + even if connected by a data dependency. + """ + g = torch.fx.Graph() + x = g.placeholder("x") + mul_a = self._make_tag_node(g, x, 2.0, tagged=True) + mul_a.meta["custom"] = {"compile_with_inductor": {"inductor_region": 0}} + sin = g.call_function(torch.ops.aten.sin.default, (x,)) + sin.meta["val"] = torch.empty(10) + mul_b = self._make_tag_node(g, mul_a, 3.0, tagged=True) + mul_b.meta["custom"] = {"compile_with_inductor": {"inductor_region": 1}} + g.output(mul_b) + gm = torch.fx.GraphModule(torch.nn.Module(), g) + self.assertEqual(self._scoop_and_count(gm), 2) + + def test_chained_merges_across_multiple_gaps(self): + """Multiple tagged nodes (same region) with chained data dependencies + across untagged gaps are all merged into 1 partition. + """ + g = torch.fx.Graph() + x = g.placeholder("x") + tagged_a = self._make_tag_node(g, x, 2.0, tagged=True) + # Untagged node consuming x (not tagged_a), just occupying a slot + filler_1 = g.call_function(torch.ops.aten.sin.default, (x,)) + filler_1.meta["val"] = torch.empty(10) + tagged_b = self._make_tag_node(g, tagged_a, 3.0, tagged=True) + filler_2 = g.call_function(torch.ops.aten.sin.default, (x,)) + filler_2.meta["val"] = torch.empty(10) + tagged_c = self._make_tag_node(g, tagged_b, 4.0, tagged=True) + g.output(tagged_c) + gm = torch.fx.GraphModule(torch.nn.Module(), g) + self.assertEqual(self._scoop_and_count(gm), 1) + + if __name__ == "__main__": run_tests() diff --git a/test/dynamo/test_reorder_logs.py b/test/dynamo/test_reorder_logs.py index 128211ac48a48..5825543b90642 100644 --- a/test/dynamo/test_reorder_logs.py +++ b/test/dynamo/test_reorder_logs.py @@ -299,9 +299,38 @@ def f(x): Developer debug context: call_method TensorVariable() item () {} - For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html""", # noqa: B950 + For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html""", ) + def test_reorder_print_data_dependent_fstring(self): + """Print with data-dependent bool in f-string should graph break on print, + but work when print is reorderable.""" + + def f(x, mask): + make_causal = bool((mask == 0).all()) + print(f"make_causal={make_causal}") + return x + 1 + + x = torch.randn(2, 3) + mask = torch.zeros(2, 3) + + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, + "Dynamo does not know how to trace builtin operator `print`", + ): + torch.compile(backend="eager", fullgraph=True)(f)(x, mask) + + with torch._dynamo.config.patch( + reorderable_logging_functions={print}, capture_scalar_outputs=True + ): + opt_f = torch.compile(backend="eager", fullgraph=True)(f) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + opt_out = opt_f(x, mask) + printed_output = mock_stdout.getvalue().strip() + + self.assertTrue(same(opt_out, x + 1)) + self.assertEqual(printed_output, "make_causal=True") + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 40100dbf6bb26..79c61273ec47f 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -56,7 +56,11 @@ ) from torch._inductor.utils import fresh_cache from torch.nn import functional as F -from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from torch.nn.attention.flex_attention import ( + AuxRequest, + create_block_mask, + flex_attention, +) from torch.profiler import profile, ProfilerActivity from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, @@ -1218,6 +1222,21 @@ def fn(x): compiled_out = compiled_fn(x) self.assertEqual(eager_out, compiled_out) + # https://github.com/pytorch/pytorch/issues/166626 + def test_inplace_add_from_meta_tensor_factory(self): + def fn(x): + log_det = torch.zeros(x.size(0), device=x.device) + log_det += torch.zeros(x.size(0), device="meta") + return log_det + + x = torch.randn(2, 4) + eager_out = fn(x) + compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) + compiled_out = compiled_fn(x) + + self.assertEqual(eager_out.device, compiled_out.device) + self.assertEqual(eager_out, compiled_out) + # https://github.com/pytorch/pytorch/issues/109053 def test_view_dtype_overload(self): def f(x): @@ -1384,12 +1403,8 @@ def test_reformer_eval(self): def test_reformer_train(self): with torch.enable_grad(): cnt = self._reformer(nopython=False) - expected_op_count = ( - """10""" if torch._dynamo.config.inline_inbuilt_nn_modules else """4""" - ) - self.assertExpectedInline(cnt.frame_count, """1""") - self.assertExpectedInline(cnt.op_count, expected_op_count) + self.assertExpectedInline(cnt.op_count, """10""") def test_longformer_chunk(self): input1 = torch.randn([1, 4096, 1]) @@ -1478,8 +1493,6 @@ def fn(input_lengths: torch.Tensor, new_ones_1): @torch._dynamo.config.patch(error_on_recompile=True) @torch.fx.experimental._config.patch(use_duck_shape=False) def test_dynamic_shape_disable_duck_size(self): - # noqa: F841 - class TestModel(nn.Module): def __init__( self, @@ -1995,6 +2008,38 @@ def fn(x): self.assertFalse(y.requires_grad) self.assertFalse(z.requires_grad) + def test_locals_traced_correctly_under_compile(self): + def fn(x): + if x.dim() > 2: + batch_size, seq_len, hidden_dim = x.shape + x = x.view(-1, hidden_dim) + + x = x + 1 + + if "batch_size" in locals() and "seq_len" in locals(): + x = x.view(batch_size, seq_len, -1) + return x + + x = torch.randn(2, 3, 4) + opt_fn = torch.compile(fn, backend="eager") + self.assertTrue(same(fn(x), opt_fn(x))) + + def test_vars_traced_correctly_under_compile(self): + def fn(x): + if x.dim() > 2: + batch_size, seq_len, hidden_dim = x.shape + x = x.view(-1, hidden_dim) + + x = x + 1 + + if "batch_size" in vars() and "seq_len" in vars(): + x = x.view(batch_size, seq_len, -1) + return x + + x = torch.randn(2, 3, 4) + opt_fn = torch.compile(fn, backend="eager") + self.assertTrue(same(fn(x), opt_fn(x))) + def test_abc_setattr(self): # tests that we correctly bail out of __setattr__ calls @@ -2123,7 +2168,7 @@ def b(x): self.assertTrue(same(b(y), y.sin().cos())) @skipIfWindows( - msg="torch._dynamo.exc.TorchRuntimeError: Failed running call_function (*(FakeTensor(..., size=(10,), dtype=torch.int32),), **{}):" # noqa: B950 + msg="torch._dynamo.exc.TorchRuntimeError: Failed running call_function (*(FakeTensor(..., size=(10,), dtype=torch.int32),), **{}):" ) def test_longtensor_list(self): for partition in [0, 5, 10]: @@ -4154,7 +4199,7 @@ def fn(inp1, inp2, inp3, inp4, c): cnt = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnt) opt_fn(inp1, inp2, inp3, inp4, c) - self.assertEqual(cnt.frame_count, 3) + self.assertEqual(cnt.frame_count, 2) def test_torch_variable_type(self): # from torchvision @@ -4327,16 +4372,18 @@ def fn(x): self.assertTrue(same(fn(x), opt_fn(x))) def test_add_sub_alpha_out(self): - inp = torch.randn(2, 3, 4) - other = 1 - alpha = 2 + test_cases = ( + (torch.randn(2, 3, 4), 1, 2, torch.zeros(2, 3, 4)), + (2, 1.1, 0.4, torch.tensor(0.0)), + ) for op in [torch.add, torch.sub]: - out = torch.zeros(2, 3, 4) - compile_out = torch.zeros(2, 3, 4) - op(inp, other, alpha=alpha, out=out) - compiled_fn = torch.compile(op, dynamic=True, backend="eager") - compiled_fn(inp, other, alpha=alpha, out=compile_out) - self.assertTrue(same(out, compile_out)) + for inp, other, alpha, out in test_cases: + compiled_fn = torch.compile(op, dynamic=True, backend="eager") + eager_out = out.clone() + compiled_out = out.clone() + op(inp, other, alpha=alpha, out=eager_out) + compiled_fn(inp, other, alpha=alpha, out=compiled_out) + self.assertTrue(same(eager_out, compiled_out)) def test_negative_shape_guard(self): def fn(x): @@ -4832,16 +4879,18 @@ def forward(self, x): model = SimpleModel().eval() input_tensor = torch.randn(1, 10, dtype=torch.float32) opt = torch.compile(model.eval(), backend="eager", fullgraph=True) - actual = opt(input_tensor) try: expected = model(input_tensor) except Exception as e: - raise unittest.SkipTest("eager failed, requires Python>=3.12") from e + raise unittest.SkipTest( + "eager failed, requires Python between 3.9 and 3.12" + ) from e + actual = opt(input_tensor) self.assertEqual(actual, expected) def test_invalid_seq_unpack(self): def myfn(arg): - (a, b) = arg # noqa: F841 + (a, b) = arg def fn(): return myfn((1, 2, 3)) @@ -5378,17 +5427,11 @@ def fn(obj): obj = A() - try: + with self.assertRaisesRegex(RuntimeError, r"super\(\)"): fn(obj) - except Exception as e: - orig_str = str(e) - self.assertIn("no arguments", orig_str) - try: + with self.assertRaisesRegex(RuntimeError, r"super\(\)"): torch.compile(backend="eager")(fn)(obj) - except Exception as e: - compiled_str = str(e) - self.assertEqual(orig_str, compiled_str) def test_super_staticmethod(self): class Parent: @@ -5888,10 +5931,9 @@ def forward(self, x): opt_mod = torch.compile(mod, backend=compiler) opt_mod(torch.randn(2, 2)) - with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True): - mod = Mod() - opt_mod = torch.compile(mod, backend=compiler) - opt_mod(torch.randn(2, 2)) + mod = Mod() + opt_mod = torch.compile(mod, backend=compiler) + opt_mod(torch.randn(2, 2)) # an example similar to Pippy usecase mod = Mod() @@ -5934,6 +5976,34 @@ def test_export_vs_dynamo_for_multiheadattention(self): self.assertEqual(len(compile_nodes), 0) self.assertEqual(len(export_nodes), 0) + def test_multiheadattention_tracing_slowpath_matches_fastpath_layout(self): + class MHAWithView(nn.Module): + def __init__(self) -> None: + super().__init__() + self.hidden_dim = 64 + self.attention = nn.MultiheadAttention( + self.hidden_dim, 8, batch_first=True + ) + + def forward(self, x): + attn_output, _ = self.attention(x, x, x) + return attn_output.view(-1, self.hidden_dim) + + with torch.no_grad(): + model = MHAWithView().eval() + x = torch.randn(4, 32, model.hidden_dim) + eager = model(x) + + backend = EagerAndRecordGraphs() + compiled_model = torch.compile(model, backend=backend, fullgraph=True) + compiled = compiled_model(x) + + compile_nodes = backend.graphs[0].graph.find_nodes( + op="call_function", target=torch._native_multi_head_attention + ) + self.assertEqual(compiled, eager) + self.assertEqual(len(compile_nodes), 0) + def test_negative_floor_div_solve(self): class CompiledClass(nn.Module): def __init__(self) -> None: @@ -6257,96 +6327,6 @@ def fn(aot6_sub_58, aot6_mul_170): # No assert necessary since this used to crash. fn(aot6_sub_58, aot6_mul_170) - @torch._dynamo.config.patch(guard_nn_modules=False) - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) - def test_inlining_cornercase(self): - """ - nn.Modules can be mapped to either NNModuleVariable or UnspecializedNNModuleVariable. For NNModuleVariable, the - tensor attributes become part of the Dynamo graph. For unspecialized, they are lifted as inputs. - - But there is a cornercase. Suppose you have NNModuleVariable with a submodule that is - UnspecializedNNModuleVariable. Today, Dynamo will still consider the submodule as specialized (courtesy of - guard.source().is_nn_module()). In retrospect, this is a mistake but there are dependencies of export and also - cudagraphs which make it harder to fix the corner case right away. The long term solution is - inline_inbuilt_nn_modules anyways, so we might have to live with this cornercase in the short term. - - We are starting to annotate the source of each nn module more precisely - NNModuleVariable attribute is marked - as NNModuleSource, UnspecilaizedNNModuleVariable attribute is marked as UnspecializedNNModuleSource. But this - changes the behavior for the cornercase. And fails some tests which have unfortunately relied on this behavior. - - - To solve this, we tag the source only when inline_inbuilt_nn_module flag is turned on. - - In this test, we purposely turn the flag off, testing that the tagging is disabled. - """ - - class SubMod(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(1, 1) - self.a = torch.randn(1, 1) - self.counter = 0 - self.multipliers = [2.2, 3.3] - - def forward(self, x): - self.counter += 1 - return ( - self.linear(x) * self.a * self.multipliers[0] * self.multipliers[1] - ) - - class Mod(torch.nn.Module): - def __init__(self): - super().__init__() - self.submod = SubMod() - - def forward(self, x): - return self.submod(x) - - mod = Mod() - opt_mod = torch.compile(mod, backend="eager") - - x = torch.randn(1, 1) - ref = mod(x) # noqa: F841 - res = opt_mod(x) # noqa: F841 - - mod.submod.multipliers = [3.3, 4.4] - # Since guard_nn_modules is False, this will not recompile - with torch._dynamo.config.patch(error_on_recompile=True): - ref = mod(x) # noqa: F841 - res = opt_mod(x) # noqa: F841 - - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) - def test_nnmodule_variable_children_wrap_value(self): - """ - tests wrap_values() in nn_module.py by calling children() on a submodule, - which triggers NNModuleVariable.call_method only when - inline_inbuilt_nn_modules=False. This path was previously untested - causing #173924 - """ - - class Parent(torch.nn.Module): - def __init__(self): - super().__init__() - self.container = torch.nn.Sequential( - torch.nn.Linear(10, 10), - torch.nn.ReLU(), - torch.nn.Linear(10, 5), - ) - - def forward(self, x): - for child in self.container.children(): - x = child(x) - return x - - model = Parent() - x = torch.randn(2, 10) - - eager_result = model(x) - compiled_model = torch.compile(model, backend="eager", fullgraph=True) - compiled_result = compiled_model(x) - - self.assertEqual(eager_result, compiled_result) - def test_optimized_module_training(self): mod = torch.nn.Linear(3, 3) mod.eval() @@ -8024,6 +8004,91 @@ def f(x): result = f(torch.randn(5)) self.assertEqual(result, 5) + def test_one_hot_bounds_check_compiled(self): + # https://github.com/pytorch/pytorch/issues/144211 + # torch.compile(one_hot) should raise on out-of-bounds indices, + # not silently produce wrong results. + one_hot = torch.compile(torch.nn.functional.one_hot, fullgraph=True) + + a = torch.arange(0, 5) % 3 # [0, 1, 2, 0, 1] + with self.assertRaises(RuntimeError): + one_hot(a, 1) + + torch._dynamo.reset() + with self.assertRaises(RuntimeError): + one_hot(torch.tensor([-1, 0, 1]), 3) + + torch._dynamo.reset() + expected = torch.nn.functional.one_hot(a, 3) + self.assertEqual(one_hot(a, 3), expected) + + @unittest.expectedFailure + def test_method_dunder_dict_setitem(self): + # Reproducer for: getattr(obj, method_name).__dict__['key'] = value + # method.__dict__ is handled specially by CPython at C level (no + # tp_dictoffset, no Python-visible descriptor), which caused + # object.__getattribute__(method, "__dict__") to raise AttributeError. + + @torch.compile(backend="eager", fullgraph=True) + def fn(x): + getattr(self, self._testMethodName).__dict__["slow_test"] = True + return x.sin() + + x = torch.randn(2) + _ = fn(x) + self.assertTrue(getattr(self, self._testMethodName).__dict__.get("slow_test")) + + def test_elementwise_dtypes_constant_fold(self): + from torch._prims_common import ( + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + ) + + @torch.compile(fullgraph=True, backend="eager") + def fn(x): + dt, _ = elementwise_dtypes( + x, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + return x.to(dt) + + result = fn(torch.randn(3)) + self.assertEqual(result.dtype, torch.float32) + + def test_elementwise_dtypes_int_to_float(self): + from torch._prims_common import ( + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + ) + + @torch.compile(fullgraph=True, backend="eager") + def fn(x): + dt, _ = elementwise_dtypes( + x, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT + ) + return x.to(dt) + + result = fn(torch.randint(0, 10, (3,))) + self.assertEqual(result.dtype, torch.float32) + + def test_elementwise_dtypes_multi_args(self): + from torch._prims_common import ( + elementwise_dtypes, + ELEMENTWISE_TYPE_PROMOTION_KIND, + ) + + @torch.compile(fullgraph=True, backend="eager") + def fn(x, y): + dt, _ = elementwise_dtypes( + x, y, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + return x.to(dt) + + result = fn( + torch.randn(3, dtype=torch.float16), + torch.randn(3, dtype=torch.float32), + ) + self.assertEqual(result.dtype, torch.float32) + class ReproTestsDevice(torch._dynamo.test_case.TestCase): def test_sub_alpha_scalar_repro(self, device): @@ -8080,7 +8145,7 @@ def f(): @skipIfHpu @unittest.skipIf( - TEST_WITH_ROCM or not PLATFORM_SUPPORTS_FLASH_ATTENTION, + not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported", ) def test_flash_attn_backward_mixed_strides(self, device): @@ -8569,7 +8634,7 @@ def fn(x): cnt = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnt) self.assertEqual(fn(x), opt_fn(x)) - self.assertEqual(cnt.frame_count, 2) + self.assertEqual(cnt.frame_count, 1) def test_filter_warnings(self): x = torch.ones(2, 2, requires_grad=True) @@ -8941,8 +9006,6 @@ def f(x): self.assertEqual(result.item(), 4.0) def test_enum_with_class_values(self): - # Enum whose members are user-defined classes; calling .value() - # instantiates the class, which Dynamo can't trace. from enum import Enum class AvgMetric: @@ -8977,6 +9040,128 @@ def fn(logger, x): logger = ScalarLogger() fn(logger, torch.tensor(1.0)) + def test_class_attr_mutation_recompiles(self): + class GlobalState: + factor = 1.0 + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt) + def fn(x): + return x * GlobalState.factor + + x = torch.tensor([4.0]) + + GlobalState.factor = 1.0 + result1 = fn(x) + self.assertEqual(result1, torch.tensor([4.0])) + self.assertEqual(cnt.frame_count, 1) + + GlobalState.factor = 10.0 + result2 = fn(x) + self.assertEqual(result2, torch.tensor([40.0])) + self.assertEqual(cnt.frame_count, 2) + + @skipIfHpu + @requires_cuda + def test_deterministic_pad_replicate_compile(self, device): + from torch.testing._internal.common_utils import DeterministicGuard + + pad = torch.nn.ReplicationPad1d(2).to(device) + compiled_pad = torch.compile(pad, backend="aot_eager", fullgraph=True) + x = torch.randn(3, 3, device=device, requires_grad=True) + with DeterministicGuard(True): + ref = pad(x) + res = compiled_pad(x) + self.assertEqual(ref, res) + grad = torch.autograd.grad(res.sum(), x) + ref_grad = torch.autograd.grad(ref.sum(), x) + self.assertEqual(grad, ref_grad) + + @requires_cuda + @unittest.skipIf( + TEST_WITH_ROCM or not PLATFORM_SUPPORTS_FLASH_ATTENTION, + "flash attention not supported", + ) + def test_flex_attention_guard_on_constant_func_defaults(self): + """ + Dynamo must guard on mask_mod.__defaults__ so that when a + compiled function is re-invoked with a new BlockMask whose + mask_mod has the same __code__ but different __defaults__, + Dynamo recompiles instead of reusing the stale first graph. + """ + from torch.utils._triton import has_triton + + if not has_triton(): + self.skipTest("requires triton") + + @torch.compile(fullgraph=True) + def flex_chunk(q, k, v, block_mask, scale): + out, aux = flex_attention( + q, + k, + v, + block_mask=block_mask, + scale=scale, + return_aux=AuxRequest(lse=True), + ) + return out, aux.lse + + def merge(out, lse, new_out, new_lse): + lse, new_lse = lse.unsqueeze(-1), new_lse.unsqueeze(-1) + mx = torch.maximum(lse, new_lse) + e0, e1 = torch.exp(lse - mx), torch.exp(new_lse - mx) + d = e0 + e1 + return (out * e0 + new_out * e1) / d, (mx + torch.log(d)).squeeze(-1) + + @torch.compile(fullgraph=True) + def ref_attn(q, k, v, block_mask, scale): + return flex_attention(q, k, v, block_mask=block_mask, scale=scale) + + torch.manual_seed(42) + B, H, S, D = 1, 1, 512, 16 + device = "cuda" + NUM_CHUNKS = 4 + chunk_size = S // NUM_CHUNKS + + q = torch.randn(B, H, S, D, device=device) + k = torch.randn(B, H, S, D, device=device) + v = torch.randn(B, H, S, D, device=device) + scale = D**-0.5 + + merged_out = merged_lse = None + for step in range(NUM_CHUNKS): + kv_offset = step * chunk_size + + def mask_mod(b, h, q_idx, kv_idx, _offset=kv_offset): + return q_idx >= kv_idx + _offset + + bm = create_block_mask( + mask_mod, B=B, H=H, Q_LEN=S, KV_LEN=chunk_size, device=device + ) + out, lse = flex_chunk( + q, + k[:, :, kv_offset : kv_offset + chunk_size], + v[:, :, kv_offset : kv_offset + chunk_size], + bm, + scale, + ) + if merged_out is None: + merged_out, merged_lse = out, lse + else: + merged_out, merged_lse = merge(merged_out, merged_lse, out, lse) + + def causal(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + ref_bm = create_block_mask(causal, B=B, H=H, Q_LEN=S, KV_LEN=S, device=device) + ref_out = ref_attn(q, k, v, ref_bm, scale) + + self.assertTrue( + (merged_out - ref_out).abs().max().item() < 1e-3, + "flex_attention mask_mod __defaults__ not properly guarded", + ) + instantiate_parametrized_tests(ReproTests) diff --git a/test/dynamo/test_resume.py b/test/dynamo/test_resume.py index 42103a7878e73..5f26b8a74fd43 100644 --- a/test/dynamo/test_resume.py +++ b/test/dynamo/test_resume.py @@ -13,7 +13,7 @@ def fn(x): torch._dynamo.graph_break() x = x + var1 - def inner_fn(): # noqa: F841 + def inner_fn(): return var2 return x @@ -26,10 +26,13 @@ def test_freevars(self): fn = fn_creator() opt_fn = torch.compile(fn, backend="eager") opt_fn(torch.randn(10)) - codes = [v for k, v in list(globals().items()) if k.startswith("__resume_at")] - self.assertEqual(len(codes), 1) + entries = [v for k, v in list(globals().items()) if k.startswith("__resume_at")] + self.assertEqual(len(entries), 1) + # When freevars are present, install_resume_function_global stores a + # factory that closes over the code object (first closure cell). + code = entries[0].__closure__[0].cell_contents # co_freevars of resume functions, are sorted concatenation of the original function's co_freevars and co_cellvars - self.assertEqual(codes[0].co_freevars, ("var1", "var2")) + self.assertEqual(code.co_freevars, ("var1", "var2")) if __name__ == "__main__": diff --git a/test/dynamo/test_sets.py b/test/dynamo/test_sets.py index dab0bdea8ea7f..d7267dd717c37 100644 --- a/test/dynamo/test_sets.py +++ b/test/dynamo/test_sets.py @@ -178,7 +178,7 @@ def fn(x, s): from user code: File "test_sets.py", line N, in fn - for i in s:""", # noqa: B950 + for i in s:""", ) def test_set_multiple_types(self): @@ -447,7 +447,7 @@ def test_equality(self): @make_dynamo_test def test_in_frozenset(self): item = self.thetype("abc") - container = self.thetype([frozenset("abc")]) # noqa: C405 + container = self.thetype([frozenset("abc")]) self.assertIn(item, container) @make_dynamo_test diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 0acdda103eff3..e7791e4a1c3c9 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -1,5 +1,4 @@ # Owner(s): ["module: dynamo"] -import functools import re import unittest import weakref @@ -13,15 +12,9 @@ store_user_object_weakrefs, ) from torch._dynamo.testing import extract_graph, remove_trailing_space -from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import requires_cuda -requires_multigpu = functools.partial( - unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices" -) - - def remove_file_comment(gm_str: str) -> str: return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str)) @@ -75,13 +68,13 @@ def fn(x, y, s1, s2): """\ class (torch.nn.Module): def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): - # Annotation: {'stream': 0} + # Annotation: {'stream': 1} add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None return (add_3,) @@ -164,38 +157,30 @@ def fn(x, s): self.assertEqual(s0, s1) @requires_cuda - @requires_multigpu() - def test_get_current_stream_return_different_device(self): - def fn(x, s0, s1): - with s1: - with s0: - s = torch.accelerator.current_stream(torch.device("cuda:1")) - return s + def test_cuda_current_stream_attrs(self): + """Verify that torch.cuda.current_stream() attributes are accessible + under torch.compile and match eager behavior.""" - s0 = torch.Stream(device="cuda:0") - s1 = torch.Stream(device="cuda:1") - inp = (torch.ones(2, 2) + 1, s0, s1) - fn_opt = torch.compile(fn, fullgraph=True) - s_act = fn_opt(*inp) - s_exp = fn(*inp) - self.assertEqual(s_act, s_exp) + def fn_cuda_stream(x): + return torch.cuda.current_stream().cuda_stream + + x = torch.zeros(1, device="cuda") + compiled = torch.compile(fn_cuda_stream, backend="eager", fullgraph=True) + self.assertEqual(compiled(x), fn_cuda_stream(x)) @requires_cuda - @requires_multigpu() - def test_get_current_stream_return_no_index(self): - def fn(x, s0, s1): - with s1: - with s0: - s = torch.accelerator.current_stream(torch.device("cuda")) - return s + def test_cuda_current_stream_with_entered_stream(self): + """Verify that torch.cuda.current_stream().cuda_stream returns the + correct value when inside a stream context for a user-created stream.""" - s0 = torch.Stream(device="cuda:0") - s1 = torch.Stream(device="cuda:1") - inp = (torch.ones(2, 2) + 1, s0, s1) - fn_opt = torch.compile(fn, fullgraph=True) - s_act = fn_opt(*inp) - s_exp = fn(*inp) - self.assertEqual(s_act, s_exp) + def fn(x, s): + with s: + return torch.cuda.current_stream().cuda_stream + + s = torch.cuda.Stream() + x = torch.zeros(1, device="cuda") + compiled = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(compiled(x, s), fn(x, s)) @requires_cuda def test_nested_stream_enter_exit(self): @@ -231,13 +216,13 @@ def fn(x, y, s0, s1, s2): """\ class (torch.nn.Module): def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) - # Annotation: {'stream': 2} + # Annotation: {'stream': 3} add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None return (add_1, add_2) """, @@ -279,13 +264,13 @@ def fn(x, y): """\ class (torch.nn.Module): def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) - # Annotation: {'stream': 0} + # Annotation: {'stream': 1} add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None - # Annotation: {'stream': 0} + # Annotation: {'stream': 1} add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None return (add_3,) @@ -323,46 +308,18 @@ def fn(x, y): """\ class (torch.nn.Module): def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): - # Annotation: {'stream': 0} + # Annotation: {'stream': 1} add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) - # Annotation: {'stream': 2} + # Annotation: {'stream': 3} add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None - # Annotation: {'stream': 0} + # Annotation: {'stream': 1} add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None return (add_1, add_2) """, ) - @requires_cuda - @requires_multigpu() - def test_new_event_api(self) -> None: - from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index - from torch._dynamo.variables.streams import new_event - - def event_generation_backend(gm, *args, **kwargs): # type: ignore[no-untyped-def] - e0_ind = new_event() - with torch.Stream(device="cuda:1"): - get_external_object_by_index(e0_ind).record() - e1_ind = new_event() - self.assertNotEqual(e0_ind, e1_ind) - self.assertNotEqual( - get_external_object_by_index(e0_ind), - get_external_object_by_index(e1_ind), - ) - with gm.graph.inserting_after(next(iter(gm.graph.nodes))): - gm.graph.call_function( - get_external_object_by_index, args=(1,), kwargs={} - ) - return gm - - @torch.compile(backend=event_generation_backend) - def fn(x): - return x + 1 - - fn(torch.ones(2, 2, device="cuda:0")) - @requires_cuda def test_new_stream_api(self) -> None: from torch._dynamo.graph_bytecode_inputs import get_external_object_by_index @@ -451,19 +408,19 @@ def fn(x, y): """\ class (torch.nn.Module): def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): - # Annotation: {'stream': 0} + # Annotation: {'stream': 1} add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) - # Annotation: {'stream': 2} + # Annotation: {'stream': 3} add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1) - # Annotation: {'stream': 2} + # Annotation: {'stream': 3} add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None - # Annotation: {'stream': 0} + # Annotation: {'stream': 1} add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None - # Annotation: {'stream': 0} + # Annotation: {'stream': 1} copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None return (add_2, add_3) """, @@ -499,13 +456,13 @@ def fn(x, y): """\ class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"): - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2) - # Annotation: {'stream': 0} - add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None - return (add, add_1) + # Annotation: {'stream': 1} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); primals_2 = None + return (add, add_1, mul, add_1) """, ) @@ -514,22 +471,38 @@ def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"): print_graph(bw_graphs[0]), """\ class GraphModule(torch.nn.Module): - def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): - # Annotation: {'stream': 0} + def forward(self, mul: "f32[2, 2]", add_1: "f32[2, 2]", tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): + # Annotation: {'stream': 1} mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2) # add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None - # Annotation: {'stream': 0} - add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = None + # Annotation: {'stream': 1} + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3) + + # No stacktrace found for following nodes + subgraph_record_event_default = self.subgraph_record_event_default + control_deps = torch.ops.higher_order.control_deps((mul, add_1, mul_2, add_3, add_2), subgraph_record_event_default, add_1, add_3, add_2); mul = add_1 = mul_2 = add_3 = add_2 = subgraph_record_event_default = None + + # + getitem_2: "f32[2, 2]" = control_deps[3] + + # Annotation: {'stream': 1} + getitem_1: "f32[2, 2]" = control_deps[2]; control_deps = None # No stacktrace found for following nodes - sync_dealloc_default = torch.ops.streams.sync_dealloc.default(2, 1, mul_3); mul_3 = sync_dealloc_default = None - return (add_3, add_2) + sync_dealloc_default = torch.ops.streams.sync_dealloc.default(3, 2, mul_3); mul_3 = sync_dealloc_default = None + return (getitem_1, getitem_2) + + class subgraph_record_event_default(torch.nn.Module): + def forward(self, dep_0: "f32[2, 2]", dep_1: "f32[2, 2]", dep_2: "f32[2, 2]"): + # No stacktrace found for following nodes + record_event_default = torch.ops.streams.record_event.default(3, 1) + return (record_event_default, dep_0, dep_1, dep_2) """, ) @@ -563,11 +536,11 @@ def fn(x, y): """\ class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"): - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2) - # Annotation: {'stream': 0} + # Annotation: {'stream': 1} add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); primals_2 = None return (add, add_1, mul, add, add_1) """, @@ -579,48 +552,64 @@ def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"): """\ class GraphModule(torch.nn.Module): def forward(self, mul: "f32[2, 2]", add: "f32[2, 2]", add_1: "f32[2, 2]", tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): - # Annotation: {'stream': 0} + # Annotation: {'stream': 1} mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2) # add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None # No stacktrace found for following nodes subgraph_record_event_default = self.subgraph_record_event_default - control_deps = torch.ops.higher_order.control_deps((mul, add, mul_3), subgraph_record_event_default, add, mul_3); add = mul_3 = subgraph_record_event_default = None + control_deps = torch.ops.higher_order.control_deps((mul, add, mul_3, add_2), subgraph_record_event_default, add, mul_3, add_2); add = mul_3 = add_2 = subgraph_record_event_default = None - # Annotation: {'stream': 1} + # + getitem_2: "f32[2, 2]" = control_deps[3] + + # Annotation: {'stream': 2} getitem_1: "f32[2, 2]" = control_deps[2] # No stacktrace found for following nodes subgraph_wait_event_default = self.subgraph_wait_event_default control_deps_1 = torch.ops.higher_order.control_deps((control_deps, mul, add_1, mul_2), subgraph_wait_event_default, add_1, mul_2); control_deps = mul = add_1 = mul_2 = subgraph_wait_event_default = None - # Annotation: {'stream': 0} - getitem_3: "f32[2, 2]" = control_deps_1[2]; control_deps_1 = None + # Annotation: {'stream': 1} + getitem_4: "f32[2, 2]" = control_deps_1[2]; control_deps_1 = None - # Annotation: {'stream': 0} - add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(getitem_3, getitem_1); getitem_3 = None + # Annotation: {'stream': 1} + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(getitem_4, getitem_1); getitem_4 = None + + # No stacktrace found for following nodes + subgraph_record_event_default_1 = self.subgraph_record_event_default_1 + control_deps_2 = torch.ops.higher_order.control_deps((add_3,), subgraph_record_event_default_1, add_3); add_3 = subgraph_record_event_default_1 = None + + # Annotation: {'stream': 1} + getitem_5: "f32[2, 2]" = control_deps_2[1]; control_deps_2 = None # No stacktrace found for following nodes - sync_dealloc_default = torch.ops.streams.sync_dealloc.default(3, 1, getitem_1); getitem_1 = sync_dealloc_default = None - return (add_3, add_2) + sync_dealloc_default = torch.ops.streams.sync_dealloc.default(4, 2, getitem_1); getitem_1 = sync_dealloc_default = None + return (getitem_5, getitem_2) class subgraph_record_event_default(torch.nn.Module): - def forward(self, dep_0: "f32[2, 2]", dep_1: "f32[2, 2]"): + def forward(self, dep_0: "f32[2, 2]", dep_1: "f32[2, 2]", dep_2: "f32[2, 2]"): # No stacktrace found for following nodes - record_event_default = torch.ops.streams.record_event.default(2, 1) - return (record_event_default, dep_0, dep_1) + record_event_default = torch.ops.streams.record_event.default(3, 2) + return (record_event_default, dep_0, dep_1, dep_2) class subgraph_wait_event_default(torch.nn.Module): def forward(self, dep_0: "f32[2, 2]", dep_1: "f32[2, 2]"): # No stacktrace found for following nodes - wait_event_default = torch.ops.streams.wait_event.default(2, 0) + wait_event_default = torch.ops.streams.wait_event.default(3, 1) return (wait_event_default, dep_0, dep_1) -""", # noqa: B950 + + class subgraph_record_event_default_1(torch.nn.Module): + def forward(self, dep_0: "f32[2, 2]"): + # No stacktrace found for following nodes + record_event_default = torch.ops.streams.record_event.default(4, 1) + return (record_event_default, dep_0) +""", ) @requires_cuda @@ -645,7 +634,7 @@ def fn(x) -> None: class (torch.nn.Module): def forward(self, arg0_1: "f32[2, 2]"): # - record_event = torch.ops.streams.record_event.default(0, 1); record_event = None + record_event = torch.ops.streams.record_event.default(1, 0); record_event = None # add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, 1) @@ -754,45 +743,58 @@ def fn(x): """\ class GraphModule(torch.nn.Module): def forward(self, getitem: "f32[2, 2]", mul: "f32[2, 2]", mul_1: "f32[2, 2]", mul_2: "f32[2, 2]", tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): - # Annotation: {'stream': 2} + # Annotation: {'stream': 3} mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, mul_1); tangents_1 = None - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} mul_4: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, getitem); tangents_2 = getitem = None # No stacktrace found for following nodes subgraph_record_event_default = self.subgraph_record_event_default control_deps_2 = torch.ops.higher_order.control_deps((mul, mul_4), subgraph_record_event_default, mul, mul_4); mul = mul_4 = subgraph_record_event_default = None - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} getitem_2: "f32[2, 2]" = control_deps_2[2] # No stacktrace found for following nodes subgraph_wait_event_default = self.subgraph_wait_event_default control_deps_3 = torch.ops.higher_order.control_deps((control_deps_2, mul_1, mul_2, mul_3), subgraph_wait_event_default, mul_2, mul_3); control_deps_2 = mul_1 = mul_2 = mul_3 = subgraph_wait_event_default = None - # Annotation: {'stream': 2} + # Annotation: {'stream': 3} getitem_4: "f32[2, 2]" = control_deps_3[2]; control_deps_3 = None - # Annotation: {'stream': 2} + # Annotation: {'stream': 3} add: "f32[2, 2]" = torch.ops.aten.add.Tensor(getitem_4, getitem_2); getitem_4 = None # No stacktrace found for following nodes - sync_dealloc_default = torch.ops.streams.sync_dealloc.default(4, 1, getitem_2); getitem_2 = sync_dealloc_default = None - return (add,) + subgraph_record_event_default_1 = self.subgraph_record_event_default_1 + control_deps_4 = torch.ops.higher_order.control_deps((add,), subgraph_record_event_default_1, add); add = subgraph_record_event_default_1 = None + + # Annotation: {'stream': 3} + getitem_5: "f32[2, 2]" = control_deps_4[1]; control_deps_4 = None + + # No stacktrace found for following nodes + sync_dealloc_default = torch.ops.streams.sync_dealloc.default(5, 2, getitem_2); getitem_2 = sync_dealloc_default = None + return (getitem_5,) class subgraph_record_event_default(torch.nn.Module): def forward(self, dep_0: "f32[2, 2]", dep_1: "f32[2, 2]"): # No stacktrace found for following nodes - record_event_default = torch.ops.streams.record_event.default(3, 1) + record_event_default = torch.ops.streams.record_event.default(4, 2) return (record_event_default, dep_0, dep_1) class subgraph_wait_event_default(torch.nn.Module): def forward(self, dep_0: "f32[2, 2]", dep_1: "f32[2, 2]"): # No stacktrace found for following nodes - wait_event_default = torch.ops.streams.wait_event.default(3, 2) + wait_event_default = torch.ops.streams.wait_event.default(4, 3) return (wait_event_default, dep_0, dep_1) -""", # noqa: B950 + + class subgraph_record_event_default_1(torch.nn.Module): + def forward(self, dep_0: "f32[2, 2]"): + # No stacktrace found for following nodes + record_event_default = torch.ops.streams.record_event.default(5, 3) + return (record_event_default, dep_0) +""", ) @requires_cuda @@ -842,135 +844,168 @@ def fn(x): """\ class GraphModule(torch.nn.Module): def forward(self, getitem: "f32[2, 2]", getitem_3: "f32[2, 2]", getitem_2: "f32[2, 2]", getitem_4: "f32[2, 2]", getitem_6: "f32[2, 2]", tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]", tangents_3: "f32[2, 2]"): - # Annotation: {'stream': 4} + # Annotation: {'stream': 5} mul_7: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_3, getitem_6); tangents_3 = getitem_6 = None # No stacktrace found for following nodes subgraph_record_event_default = self.subgraph_record_event_default control_deps_8 = torch.ops.higher_order.control_deps((mul_7,), subgraph_record_event_default, mul_7); mul_7 = subgraph_record_event_default = None - # Annotation: {'stream': 4} - getitem_8: "f32[2, 2]" = control_deps_8[1]; control_deps_8 = None + # Annotation: {'stream': 5} + getitem_8: "f32[2, 2]" = control_deps_8[1] - # Annotation: {'stream': 3} + # No stacktrace found for following nodes + subgraph_wait_event_default = self.subgraph_wait_event_default + control_deps_9 = torch.ops.higher_order.control_deps((control_deps_8,), subgraph_wait_event_default); control_deps_8 = subgraph_wait_event_default = control_deps_9 = None + + # Annotation: {'stream': 4} add: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, getitem_8); tangents_2 = None # No stacktrace found for following nodes subgraph_record_event_default_4 = self.subgraph_record_event_default_4 control_deps_10 = torch.ops.higher_order.control_deps((add,), subgraph_record_event_default_4, add); add = subgraph_record_event_default_4 = None - # Annotation: {'stream': 3} + # Annotation: {'stream': 4} getitem_9: "f32[2, 2]" = control_deps_10[1]; control_deps_10 = None # No stacktrace found for following nodes - sync_dealloc_default = torch.ops.streams.sync_dealloc.default(10, 4, getitem_8); getitem_8 = sync_dealloc_default = None + sync_dealloc_default = torch.ops.streams.sync_dealloc.default(10, 5, getitem_8); getitem_8 = sync_dealloc_default = None - # Annotation: {'stream': 3} + # Annotation: {'stream': 4} mul_8: "f32[2, 2]" = torch.ops.aten.mul.Tensor(getitem_9, getitem_4); getitem_4 = None # No stacktrace found for following nodes subgraph_record_event_default_1 = self.subgraph_record_event_default_1 control_deps_11 = torch.ops.higher_order.control_deps((mul_8,), subgraph_record_event_default_1, mul_8); mul_8 = subgraph_record_event_default_1 = None - # Annotation: {'stream': 3} - getitem_10: "f32[2, 2]" = control_deps_11[1]; control_deps_11 = None + # Annotation: {'stream': 4} + getitem_10: "f32[2, 2]" = control_deps_11[1] mul_9: "f32[2, 2]" = torch.ops.aten.mul.Tensor(getitem_9, getitem_3); getitem_9 = getitem_3 = None mul_10: "f32[2, 2]" = torch.ops.aten.mul.Tensor(mul_9, getitem) - # Annotation: {'stream': 2} + # No stacktrace found for following nodes + subgraph_wait_event_default_1 = self.subgraph_wait_event_default_1 + control_deps_12 = torch.ops.higher_order.control_deps((control_deps_11,), subgraph_wait_event_default_1); control_deps_11 = subgraph_wait_event_default_1 = control_deps_12 = None + + # Annotation: {'stream': 3} mul_11: "f32[2, 2]" = torch.ops.aten.mul.Tensor(getitem_10, getitem_2); getitem_2 = None # No stacktrace found for following nodes subgraph_record_event_default_5 = self.subgraph_record_event_default_5 control_deps_13 = torch.ops.higher_order.control_deps((mul_11,), subgraph_record_event_default_5, mul_11); mul_11 = subgraph_record_event_default_5 = None - # Annotation: {'stream': 2} + # Annotation: {'stream': 3} getitem_11: "f32[2, 2]" = control_deps_13[1]; control_deps_13 = None # No stacktrace found for following nodes - sync_dealloc_default_1 = torch.ops.streams.sync_dealloc.default(11, 3, getitem_10); getitem_10 = sync_dealloc_default_1 = None - record_event_default_2 = torch.ops.streams.record_event.default(8, 2); record_event_default_2 = None - wait_event_default_2 = torch.ops.streams.wait_event.default(8, 1); wait_event_default_2 = None + sync_dealloc_default_1 = torch.ops.streams.sync_dealloc.default(11, 4, getitem_10); getitem_10 = sync_dealloc_default_1 = None + record_event_default_2 = torch.ops.streams.record_event.default(8, 3); record_event_default_2 = None + wait_event_default_2 = torch.ops.streams.wait_event.default(8, 2); wait_event_default_2 = None - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_1, getitem_11); tangents_1 = None # No stacktrace found for following nodes subgraph_record_event_default_6 = self.subgraph_record_event_default_6 control_deps_14 = torch.ops.higher_order.control_deps((add_1,), subgraph_record_event_default_6, add_1); add_1 = subgraph_record_event_default_6 = None - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} getitem_12: "f32[2, 2]" = control_deps_14[1]; control_deps_14 = None # No stacktrace found for following nodes - sync_dealloc_default_2 = torch.ops.streams.sync_dealloc.default(12, 2, getitem_11); getitem_11 = sync_dealloc_default_2 = None + sync_dealloc_default_2 = torch.ops.streams.sync_dealloc.default(12, 3, getitem_11); getitem_11 = sync_dealloc_default_2 = None - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} mul_12: "f32[2, 2]" = torch.ops.aten.mul.Tensor(getitem_12, getitem); getitem_12 = getitem = None # No stacktrace found for following nodes subgraph_record_event_default_3 = self.subgraph_record_event_default_3 control_deps_15 = torch.ops.higher_order.control_deps((mul_12,), subgraph_record_event_default_3, mul_12); mul_12 = subgraph_record_event_default_3 = None - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} getitem_13: "f32[2, 2]" = control_deps_15[1] # No stacktrace found for following nodes subgraph_wait_event_default_3 = self.subgraph_wait_event_default_3 control_deps_16 = torch.ops.higher_order.control_deps((control_deps_15, mul_9, mul_10), subgraph_wait_event_default_3, mul_10); control_deps_15 = mul_9 = mul_10 = subgraph_wait_event_default_3 = None - # Annotation: {'stream': 3} + # Annotation: {'stream': 4} getitem_14: "f32[2, 2]" = control_deps_16[1]; control_deps_16 = None - # Annotation: {'stream': 3} + # Annotation: {'stream': 4} add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(getitem_14, getitem_13); getitem_14 = None # No stacktrace found for following nodes - sync_dealloc_default_3 = torch.ops.streams.sync_dealloc.default(13, 1, getitem_13); getitem_13 = sync_dealloc_default_3 = None - return (add_2,) + subgraph_record_event_default_7 = self.subgraph_record_event_default_7 + control_deps_17 = torch.ops.higher_order.control_deps((add_2,), subgraph_record_event_default_7, add_2); add_2 = subgraph_record_event_default_7 = None + + # Annotation: {'stream': 4} + getitem_15: "f32[2, 2]" = control_deps_17[1]; control_deps_17 = None + + # No stacktrace found for following nodes + sync_dealloc_default_3 = torch.ops.streams.sync_dealloc.default(13, 2, getitem_13); getitem_13 = sync_dealloc_default_3 = None + return (getitem_15,) class subgraph_record_event_default(torch.nn.Module): def forward(self, dep_0: "f32[2, 2]"): # No stacktrace found for following nodes - record_event_default = torch.ops.streams.record_event.default(6, 4) + record_event_default = torch.ops.streams.record_event.default(6, 5) return (record_event_default, dep_0) + class subgraph_wait_event_default(torch.nn.Module): + def forward(self): + # No stacktrace found for following nodes + wait_event_default = torch.ops.streams.wait_event.default(6, 4) + return wait_event_default + class subgraph_record_event_default_4(torch.nn.Module): def forward(self, dep_0: "f32[2, 2]"): # No stacktrace found for following nodes - record_event_default = torch.ops.streams.record_event.default(10, 3) + record_event_default = torch.ops.streams.record_event.default(10, 4) return (record_event_default, dep_0) class subgraph_record_event_default_1(torch.nn.Module): def forward(self, dep_0: "f32[2, 2]"): # No stacktrace found for following nodes - record_event_default = torch.ops.streams.record_event.default(7, 3) + record_event_default = torch.ops.streams.record_event.default(7, 4) return (record_event_default, dep_0) + class subgraph_wait_event_default_1(torch.nn.Module): + def forward(self): + # No stacktrace found for following nodes + wait_event_default = torch.ops.streams.wait_event.default(7, 3) + return wait_event_default + class subgraph_record_event_default_5(torch.nn.Module): def forward(self, dep_0: "f32[2, 2]"): # No stacktrace found for following nodes - record_event_default = torch.ops.streams.record_event.default(11, 2) + record_event_default = torch.ops.streams.record_event.default(11, 3) return (record_event_default, dep_0) class subgraph_record_event_default_6(torch.nn.Module): def forward(self, dep_0: "f32[2, 2]"): # No stacktrace found for following nodes - record_event_default = torch.ops.streams.record_event.default(12, 1) + record_event_default = torch.ops.streams.record_event.default(12, 2) return (record_event_default, dep_0) class subgraph_record_event_default_3(torch.nn.Module): def forward(self, dep_0: "f32[2, 2]"): # No stacktrace found for following nodes - record_event_default = torch.ops.streams.record_event.default(9, 1) + record_event_default = torch.ops.streams.record_event.default(9, 2) return (record_event_default, dep_0) class subgraph_wait_event_default_3(torch.nn.Module): def forward(self, dep_0: "f32[2, 2]"): # No stacktrace found for following nodes - wait_event_default = torch.ops.streams.wait_event.default(9, 3) + wait_event_default = torch.ops.streams.wait_event.default(9, 4) return (wait_event_default, dep_0) -""", # noqa: B950 + + class subgraph_record_event_default_7(torch.nn.Module): + def forward(self, dep_0: "f32[2, 2]"): + # No stacktrace found for following nodes + record_event_default = torch.ops.streams.record_event.default(13, 4) + return (record_event_default, dep_0) +""", ) @requires_cuda @@ -998,7 +1033,7 @@ def fn(x): """\ class (torch.nn.Module): def forward(self, arg0_1: "f32[2, 2]"): - # Annotation: {'stream': 0} + # Annotation: {'stream': 1} add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, 2) copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = None return (copy_,) @@ -1251,12 +1286,16 @@ def fn(x) -> torch.Tensor: return a + b + y + z inp = (torch.ones(2, 2, device="cuda"),) - ( - _, - _, - fw_graphs, - _, - ) = extract_graph(fn, *inp) + # Patch out wrapping so we get the raw graph to manually wrap below. + with patch( + "torch._functorch._aot_autograd.graph_capture.wrap_all_sync_nodes_with_control_deps" + ): + ( + _, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) gm = fw_graphs[0] graph = gm.graph @@ -1417,7 +1456,7 @@ def fn(x, y): """\ class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"): - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2) add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); primals_2 = None return (add, primals_1, mul) @@ -1434,13 +1473,13 @@ def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"): """\ class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[2, 2]", mul: "f32[2, 2]", tangents_1: "f32[2, 2]"): - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2) - # Annotation: {'stream': 1} + # Annotation: {'stream': 2} clone: "f32[2, 2]" = torch.ops.aten.clone.default(tangents_1); tangents_1 = None - # Annotation: {'stream': 0} No stacktrace found for following nodes + # Annotation: {'stream': 1} No stacktrace found for following nodes copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(primals_1, mul); primals_1 = mul = copy_ = None return (mul_2, clone) """, @@ -1475,6 +1514,10 @@ def test_is_marked_side_effectful(self): torch.ops.streams.record_event.default, torch.fx.node._side_effectful_functions, ) + self.assertIn( + torch.ops.streams.synchronize_event.default, + torch.fx.node._side_effectful_functions, + ) @requires_cuda def test_backward_sync_control_deps_e2e(self) -> None: @@ -1527,6 +1570,679 @@ def fn(x, y): "Expected control_deps nodes in backward graph for stream synchronization", ) + def test_sync_dealloc_has_fake_impl(self): + """Test that sync_dealloc has a registered fake impl. + + Without a fake impl, Inductor's backward compilation crashes when the + backward graph contains cross-stream sync_dealloc ops. + """ + from torch._subclasses.fake_tensor import FakeTensorMode + + with FakeTensorMode(): + t = torch.randn(4) + # Should not raise "no fake impl registered" + torch.ops.streams.sync_dealloc.default(0, 1, t) + + def test_record_stream_has_fake_impl(self): + """Test that record_stream's fake impl has the correct signature.""" + from torch._subclasses.fake_tensor import FakeTensorMode + + with FakeTensorMode(): + t = torch.randn(4) + # Should not raise due to signature mismatch + torch.ops.streams.record_stream.default(t, 0) + + @requires_cuda + def test_record_stream(self): + backend = torch._dynamo.testing.EagerAndRecordGraphs() + + def fn(x): + s = torch.Stream() + x.record_stream(s) + return x + + compiled = torch.compile(fn, backend=backend, fullgraph=True) + compiled(torch.randn(4, device="cuda")) + + self.assertEqual(len(backend.graphs), 1) + found = any( + node.target is torch.ops.streams.record_stream + for node in backend.graphs[0].graph.nodes + ) + self.assertTrue(found, "record_stream op not found in graph") + + @requires_cuda + def test_event_record_after_input_mutation_errors(self): + def fn(x): + s = torch.Stream() + e = torch.Event() + with s: + x.add_(1) + e.record() + return e + + with self.assertRaisesRegex(RuntimeError, "An event was recorded on a stream"): + torch.compile(fn, backend="eager", fullgraph=True)( + torch.ones(2, 2, device="cuda") + ) + + @requires_cuda + def test_event_record_after_input_mutation_stack_traces(self): + def fn(x): + s = torch.Stream() + e = torch.Event() + with s: + x.add_(1) + e.record() + return e + + try: + torch.compile(fn, backend="eager", fullgraph=True)( + torch.ones(2, 2, device="cuda") + ) + self.fail("Expected RuntimeError") + except RuntimeError as e: + msg = str(e) + self.assertIn("Input mutation occurred here:", msg) + self.assertIn("x.add_(1)", msg) + self.assertIn("Event record occurred here:", msg) + self.assertIn("e.record()", msg) + + @requires_cuda + def test_event_record_after_input_mutation_record_event(self): + def fn(x): + s = torch.Stream() + with s: + x.add_(1) + e = s.record_event() + return e + + with self.assertRaisesRegex(RuntimeError, "An event was recorded on a stream"): + torch.compile(fn, backend="eager", fullgraph=True)( + torch.ones(2, 2, device="cuda") + ) + + @requires_cuda + def test_event_record_after_input_mutation_through_view(self): + def fn(x): + s = torch.Stream() + e = torch.Event() + v = x.view(-1) + with s: + v.add_(1) + e.record() + return e + + with self.assertRaisesRegex(RuntimeError, "An event was recorded on a stream"): + torch.compile(fn, backend="eager", fullgraph=True)( + torch.ones(2, 2, device="cuda") + ) + + @requires_cuda + def test_event_record_after_input_mutation_input_event(self): + def fn(x, e): + s = torch.Stream() + with s: + x.add_(1) + e.record() + return x + + with self.assertRaisesRegex(RuntimeError, "An event was recorded on a stream"): + torch.compile(fn, backend="eager", fullgraph=True)( + torch.ones(2, 2, device="cuda"), + torch.Event(), + ) + + @requires_cuda + def test_event_record_before_input_mutation_no_error(self): + def fn(x): + s = torch.Stream() + e = torch.Event() + with s: + e.record() + x.add_(1) + return e + + torch.compile(fn, backend="eager", fullgraph=True)( + torch.ones(2, 2, device="cuda") + ) + + @requires_cuda + def test_event_record_on_different_stream_no_error(self): + def fn(x): + s0 = torch.Stream() + s1 = torch.Stream() + e = torch.Event() + with s0: + x.add_(1) + with s1: + e.record() + return e + + torch.compile(fn, backend="eager", fullgraph=True)( + torch.ones(2, 2, device="cuda") + ) + + @requires_cuda + def test_event_not_returned_no_error(self): + def fn(x): + s = torch.Stream() + e = torch.Event() + with s: + x.add_(1) + e.record() + return x + + with self.assertRaisesRegex(RuntimeError, "An event was recorded on a stream"): + torch.compile(fn, backend="eager", fullgraph=True)( + torch.ones(2, 2, device="cuda") + ) + + @requires_cuda + @unittest.skip("https://github.com/pytorch/pytorch/issues/177771") + def test_cuda_event_record_on_stream(self): + """torch.cuda.Event should be accepted by torch.Stream.record_event (C++ type check).""" + s = torch.Stream(device="cuda") + e = torch.cuda.Event() + # This hits THPStream_record_event in Stream.cpp which does a type check + s.record_event(e) + + @requires_cuda + def test_event_synchronize_tracing(self): + def fn(x): + e = torch.Event() + e.record() + x = x + 1 + e.synchronize() + return x + + inp = (torch.ones(2, 2, device="cuda"),) + ( + _, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]"): + # + record_event = torch.ops.streams.record_event.default(1, 0); record_event = None + + # + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None + + # No stacktrace found for following nodes + subgraph_synchronize_event = self.subgraph_synchronize_event + control_deps = torch.ops.higher_order.control_deps((add,), subgraph_synchronize_event, add); subgraph_synchronize_event = control_deps = None + return (add,) + + class subgraph_synchronize_event(torch.nn.Module): + def forward(self, dep_0: "f32[2, 2]"): + # + synchronize_event_default = torch.ops.streams.synchronize_event.default(1) + return (synchronize_event_default, dep_0) +""", + ) + + @requires_cuda + def test_event_synchronize_inductor_lowering(self): + with patch("torch._inductor.config.implicit_fallbacks", False): + + @torch.compile() + def fn(x): + e = torch.Event() + x = x + 1 + e.record() + e.synchronize() + return x + + inp = (torch.ones(2, 2, device="cuda"),) + fn(*inp) + + @requires_cuda + def test_control_deps_wrapping_synchronize_event(self) -> None: + """Test that synchronize_event threads recorded ops' values through. + + After record_event wraps ops in control_deps and produces getitem + pass-throughs, synchronize_event must also thread those through so + that subsequent consumers depend on the synchronize. + """ + + def fn(x) -> torch.Tensor: + e = torch.Event() + y = x + 1 + e.record() + e.synchronize() + # z uses y which was produced before the record — its value must + # be threaded through both record and synchronize control_deps. + z = y * 2 + return z + + inp = (torch.ones(2, 2, device="cuda"),) + # Patch out wrapping so we get the raw graph to manually wrap below. + with patch( + "torch._functorch._aot_autograd.graph_capture.wrap_all_sync_nodes_with_control_deps" + ): + ( + _, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + + gm = fw_graphs[0] + graph = gm.graph + + import operator + + from torch._functorch._aot_autograd.streams import ( + set_stream, + wrap_all_sync_nodes_with_control_deps, + ) + from torch._inductor.fx_passes.control_dependencies import control_deps + + # extract_graph doesn't annotate streams, so set stream metadata on + # compute nodes to match the record_event's stream index. + record_node = next( + n + for n in graph.nodes + if n.op == "call_function" + and n.target is torch.ops.streams.record_event.default + ) + stream_idx = record_node.args[1] + for n in graph.nodes: + if ( + n.op == "call_function" + and "val" in n.meta + and n.target + not in ( + torch.ops.streams.record_event.default, + torch.ops.streams.synchronize_event.default, + ) + ): + set_stream(n, stream_idx) + + wrap_all_sync_nodes_with_control_deps(gm) + + ctrl_nodes = list(graph.find_nodes(op="call_function", target=control_deps)) + # record_event + synchronize_event = 2 control_deps nodes + self.assertEqual(len(ctrl_nodes), 2) + record_ctrl = ctrl_nodes[0] + sync_ctrl = ctrl_nodes[1] + + # synchronize_event's control_deps should depend on record's ctrl + self.assertIn(record_ctrl, sync_ctrl.args[0]) + + # The record should thread through the add (y = x + 1) + record_getitems = [ + n + for n in graph.nodes + if n.op == "call_function" + and n.target == operator.getitem + and n.args[0] is record_ctrl + ] + self.assertGreaterEqual(len(record_getitems), 1) + + # Those getitems should be passed through synchronize's control_deps + # as additional args (the passthrough deps) + sync_passthrough_args = sync_ctrl.args[2:] # skip (deps_tuple, subgraph) + for getitem in record_getitems: + self.assertIn( + getitem, + sync_passthrough_args, + "record_event's getitem should be threaded through synchronize_event", + ) + + # The mul (z = y * 2) should consume a getitem from synchronize's + # control_deps, not directly from record's. + sync_getitems = [ + n + for n in graph.nodes + if n.op == "call_function" + and n.target == operator.getitem + and n.args[0] is sync_ctrl + ] + self.assertGreaterEqual(len(sync_getitems), 1) + + # Find the mul node and verify it uses a sync getitem + mul_nodes = [ + n + for n in graph.nodes + if n.op == "call_function" and n.target == torch.ops.aten.mul.Tensor + ] + self.assertEqual(len(mul_nodes), 1) + mul_args = set(mul_nodes[0].args) + self.assertTrue( + mul_args & set(sync_getitems), + "mul should depend on synchronize_event's getitem, not record_event's", + ) + + @requires_cuda + def test_external_event_synchronize_threads_inputs(self) -> None: + """When the event was recorded externally, synchronize threads graph inputs through.""" + + def fn(x): + e = torch.Event() + y = x + 1 + e.record() + e.synchronize() + z = y * 2 + return z + + inp = (torch.ones(2, 2, device="cuda"),) + # Patch out wrapping so we get the raw graph to manually wrap below. + with patch( + "torch._functorch._aot_autograd.graph_capture.wrap_all_sync_nodes_with_control_deps" + ): + ( + _, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + + gm = fw_graphs[0] + graph = gm.graph + + from torch._functorch._aot_autograd.streams import ( + set_stream, + wrap_all_sync_nodes_with_control_deps, + ) + + # Remove the record_event to simulate an externally-recorded event. + record_node = next( + n + for n in graph.nodes + if n.op == "call_function" + and n.target is torch.ops.streams.record_event.default + ) + stream_idx = record_node.args[1] + graph.erase_node(record_node) + + # Set stream metadata on compute nodes. + for n in graph.nodes: + if ( + n.op == "call_function" + and "val" in n.meta + and n.target is not torch.ops.streams.synchronize_event.default + ): + set_stream(n, stream_idx) + + wrap_all_sync_nodes_with_control_deps(gm) + gm.recompile() + + self.assertExpectedInline( + print_graph(gm), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]"): + # Annotation: {'stream': 0} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, 1) + + # No stacktrace found for following nodes + subgraph_synchronize_event = self.subgraph_synchronize_event + control_deps = torch.ops.higher_order.control_deps((arg0_1, add), subgraph_synchronize_event, add); arg0_1 = add = subgraph_synchronize_event = None + + # Annotation: {'stream': 0} + getitem: "f32[2, 2]" = control_deps[1]; control_deps = None + + # Annotation: {'stream': 0} + mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(getitem, 2); getitem = None + return (mul,) + + class subgraph_synchronize_event(torch.nn.Module): + def forward(self, dep_0: "f32[2, 2]"): + # + synchronize_event_default = torch.ops.streams.synchronize_event.default(1) + return (synchronize_event_default, dep_0) +""", + ) + + @requires_cuda + def test_event_synchronize_control_deps_e2e(self): + """E2E: compute → record → synchronize → use result through torch.compile.""" + + def f(x): + e = torch.Event() + y = x + 1 + e.record() + e.synchronize() + z = y * 2 + return z + + inp = torch.ones(2, 2, device="cuda") + eager_result = f(inp) + compiled_result = torch.compile(f)(inp) + self.assertEqual(eager_result, compiled_result) + + @requires_cuda + def test_event_synchronize_e2e(self): + def f(a_list): + a_cpu_list = [] + a_to_cpu_event_list = [] + for a in a_list: + a_cpu = a.to(device="cpu", non_blocking=True) + e = torch.Event() + e.record() + a_cpu_list.append(a_cpu) + a_to_cpu_event_list.append(e) + + for e in a_to_cpu_event_list: + e.synchronize() + + return torch.cat(a_cpu_list) + + f_compiled = torch.compile(f) + inputs = [ + torch.rand(100, dtype=torch.float16, device="cuda") for _ in range(10) + ] + eager_result = f(inputs) + compiled_result = f_compiled(inputs) + self.assertEqual(eager_result, compiled_result) + + @requires_cuda + def test_event_record_wait_on_default_stream(self): + e = torch.cuda.Event() + + def f(x): + y = x + 1 + e.record() + e.wait() + return y + 1 + + f_compiled = torch.compile(f) + x = torch.randn(10, device="cuda") + eager_result = f(x) + compiled_result = f_compiled(x) + self.assertEqual(eager_result, compiled_result) + + @requires_cuda + def test_record_stream_inductor_output_code(self) -> None: + """Verify record_stream is ordered between the producing kernel and the + consuming kernel in inductor-generated wrapper code.""" + from torch._inductor.utils import run_and_get_code + from torch.testing import FileCheck + + def fn(x): + s = torch.Stream(device="cuda") + y = x + 1 + y.record_stream(s) + z = y * 2 + return z + + compiled = torch.compile(fn, backend="inductor", fullgraph=True) + x = torch.randn(1024, device="cuda") + result, (code,) = run_and_get_code(compiled, x) + self.assertEqual(result, (x + 1) * 2) + + # record_stream must appear after the kernel that produces the tensor + # and before the return. + FileCheck().check(".run(").check( + "torch.ops.streams.record_stream.default(" + ).check("return").run(code) + + @requires_cuda + def test_del_multi_stream_sync_dealloc(self): + def fn(x, y): + s = torch.Stream() + e = torch.Event() + z0 = x + 1 + with s: + z = torch.add(x, y) + e.record() + e.wait() + del x + return z0, z + + inp = (torch.ones(2, 2, device="cuda"), torch.ones(2, 2, device="cuda")) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + graph_str = print_graph(fw_graphs[0]) + self.assertIn("sync_dealloc", graph_str) + self.assertIn("record_event", graph_str) + + @requires_cuda + def test_del_same_stream_no_sync_dealloc(self): + def fn(x, y): + s = torch.Stream() + e = torch.Event() + with s: + z = torch.add(x, y) + del x + e.record() + e.wait() + return z + + inp = (torch.ones(2, 2, device="cuda"), torch.ones(2, 2, device="cuda")) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + graph_str = print_graph(fw_graphs[0]) + self.assertNotIn("sync_dealloc", graph_str) + + @requires_cuda + def test_del_single_stream_no_sync_dealloc(self): + def fn(x, y): + z = torch.add(x, y) + del x + return z + + inp = (torch.ones(2, 2, device="cuda"), torch.ones(2, 2, device="cuda")) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + graph_str = print_graph(fw_graphs[0]) + self.assertNotIn("sync_dealloc", graph_str) + + @requires_cuda + def test_del_attr_multi_stream_sync_dealloc(self): + class Holder: + pass + + def fn(x, y): + s = torch.Stream() + e = torch.Event() + h = Holder() + h.tensor = x + z0 = x + 1 + with s: + z = torch.add(h.tensor, y) + e.record() + e.wait() + del h.tensor + return z0, z + + inp = (torch.ones(2, 2, device="cuda"), torch.ones(2, 2, device="cuda")) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + graph_str = print_graph(fw_graphs[0]) + self.assertIn("sync_dealloc", graph_str) + self.assertIn("record_event", graph_str) + + @requires_cuda + def test_del_subscr_multi_stream_sync_dealloc(self): + def fn(x, y): + s = torch.Stream() + e = torch.Event() + d = {"t": x} + z0 = x + 1 + with s: + z = torch.add(d["t"], y) + e.record() + e.wait() + del d["t"] + return z0, z + + inp = (torch.ones(2, 2, device="cuda"), torch.ones(2, 2, device="cuda")) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + graph_str = print_graph(fw_graphs[0]) + self.assertIn("sync_dealloc", graph_str) + self.assertIn("record_event", graph_str) + + @requires_cuda + def test_stream_pointer_extraction_edge_cases(self): + def get_ptrs(stream_a, stream_b, default_stream): + return ( + stream_a.cuda_stream, + stream_b.cuda_stream, + default_stream.cuda_stream, + ) + + s1, s2 = torch.cuda.Stream(), torch.cuda.Stream() + default_s = torch.cuda.default_stream() + expected_s1, expected_s2 = s1.cuda_stream, s2.cuda_stream + + self.assertNotEqual(expected_s1, expected_s2) + self.assertGreater(expected_s1, 1000) + + opt_get_ptrs = torch.compile(get_ptrs, backend="inductor") + + s3 = torch.cuda.Stream() + with torch.cuda.stream(s3): + actual_s1, actual_s2, actual_default = opt_get_ptrs(s1, s2, default_s) + + self.assertEqual(actual_s1, expected_s1) + self.assertEqual(actual_s2, expected_s2) + self.assertEqual(actual_default, default_s.cuda_stream) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index 526eb88cbda8f..379dad9cb9fe5 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -348,7 +348,7 @@ def test_schedule(self): {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0} -""", # noqa: B950 +""", ) self.assertParses() @@ -382,7 +382,7 @@ def test_cudagraphs(self): {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"compilation_metrics_runtime": "METRICS", "frame_id": 0, "frame_compile_id": 0} -""", # noqa: B950 +""", ) self.assertParses() @@ -444,7 +444,7 @@ def fn(x, y): {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0} -""", # noqa: B950 +""", ) self.assertParses() @@ -476,7 +476,7 @@ def test_example_fn(self): {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -""", # noqa: B950 +""", ) self.assertParses() @@ -534,24 +534,23 @@ def test_example_training_fn(self): {"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "compiled_fn_wrapper", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "torch_dynamo_resume_in_example_training_fn_at_45_MODIFIED_BYTECODE", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1} -{"dynamo_start": {"stack": "STACK"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} -{"compilation_metrics": "METRICS", "frame_id": 3, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"bwd_compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1} -{"dynamo_start": {"stack": "STACK"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} -{"artifact": {"name": "torch_dynamo_resume_in_example_training_fn_at_47_ORIGINAL_BYTECODE", "encoding": "string"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} -{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} -{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['output']"}, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} -{"compilation_metrics": "METRICS", "frame_id": 4, "frame_compile_id": 0, "attempt": 0} -""", # noqa: B950 +{"dynamo_start": {"stack": "STACK"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "torch_dynamo_resume_in_example_training_fn_at_47_ORIGINAL_BYTECODE", "encoding": "string"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "dynamo_hint_overrides": {}, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['output']"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +{"compilation_metrics": "METRICS", "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +""", ) self.assertParses() @@ -572,7 +571,7 @@ def test_dynamo_error(self): {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "dynamo_error", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -""", # noqa: B950 +""", ) self.assertParses() @@ -618,7 +617,7 @@ def throw(x): {"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "dynamo_error", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -""", # noqa: B950 +""", ) self.assertParses() @@ -655,67 +654,9 @@ def forward(self, x): dist.destroy_process_group() - if not torch._dynamo.config.inline_inbuilt_nn_modules: - self.assertExpectedInline( - self.buffer.getvalue(), - """\ -{"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} -{"compilation_metrics": "METRICS", "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} -{"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} -{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} -{"compilation_metrics": "METRICS", "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 1} -{"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} -{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "rank": 0, "frame_id": 2, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} -{"compilation_metrics": "METRICS", "rank": 0, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} -{"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} -{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} -{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} -{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} -{"dynamo_output_graph": {"sizes": {"l_x_": [1024, 1024], "l__self___layers_0": [1024, 1024], "l__self___layers_1": [1024, 1024]}}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"optimize_ddp_split_graph": {}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"optimize_ddp_split_child": {"name": "submod_0"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"optimize_ddp_split_child": {"name": "submod_1"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} -{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024, 1024], "is_leaf": true, "stride": [1024, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} -{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} -{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_joint_graph": {}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_forward_graph": {}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_backward_graph": {}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "before_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "after_pre_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_joint_graph": {}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "torch._functorch.config", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_forward_graph": {}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_backward_graph": {}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "before_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "inductor_post_grad_graph", "encoding": "string"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"inductor_output_code": {"filename": "FILENAME", "file_path": "FILENAME"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"artifact": {"name": "aotautograd_cache_bypass", "encoding": "json"}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"compilation_metrics": "METRICS", "rank": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} -""", # noqa: B950 - ) - else: - self.assertExpectedInline( - self.buffer.getvalue(), - """\ + self.assertExpectedInline( + self.buffer.getvalue(), + """\ {"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"describe_storage": {"id": 0, "describer_id": "ID", "size": 4194304}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} @@ -797,8 +738,8 @@ def forward(self, x): {"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0} -""", # noqa: B950 - ) +""", + ) self.assertParses() @@ -841,7 +782,7 @@ def fn(x): {"artifact": {"name": "torch_dynamo_resume_in_fn_at_808_MODIFIED_BYTECODE", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 1, "frame_compile_id": 0, "attempt": 0} -""", # noqa: B950 +""", ) self.assertParses() @@ -888,7 +829,7 @@ def fn(a, b): {"dynamo_output_graph": {"sizes": {"l_a_": ["s97", "s52"], "l_b_": ["s52", "s20"], "matmul": ["s97", "s20"]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0} -""", # noqa: B950 +""", ) self.assertParses() @@ -929,7 +870,7 @@ def inner(x, ys, zs): {"dynamo_output_graph": {"sizes": {"l_x_": [1], "x": [1]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 1, "attempt": 0} -""", # noqa: B950 +""", ) self.assertParses() @@ -962,7 +903,7 @@ def forward(self, x, y): {"dynamo_output_graph": {"sizes": {"l_x_": [3], "l_y_": [3], "add": [3]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -""", # noqa: B950 +""", ) @requires_tlparse @@ -1018,7 +959,7 @@ def fn(a): {"artifact": {"name": "aotautograd_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 0} -""", # noqa: B950 +""", ) self.assertParses() @@ -1104,7 +1045,7 @@ def fn(a): '"attempt": 0, "has_payload": "HASH"}' ), ( - '{"chromium_event": {}, "compiled_autograd_id": 0, "frame_id": 2, "frame_compile_id": 0, ' + '{"chromium_event": {}, "compiled_autograd_id": 0, "frame_id": 1, "frame_compile_id": 0, ' '"attempt": 0, "has_payload": "HASH"}' ), ] @@ -1146,13 +1087,10 @@ def backward(ctx, gO): '{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 3, "frame_compile_id": 0, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 4, "frame_compile_id": 0, "attempt": 0}', '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 5, "frame_compile_id": 0, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 6, "frame_compile_id": 0, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 7, "frame_compile_id": 0, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 8, "frame_compile_id": 0, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 0, "frame_id": 9, "frame_compile_id": 0, "attempt": 0}', - '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 12, "frame_compile_id": 1, "attempt": 0}', + '{"dynamo_start": {"stack": "STACK"}, "compiled_autograd_id": 1, "frame_id": 6, "frame_compile_id": 0, "attempt": 0}', ] logs = self.buffer.getvalue() self.assertTrue(all(event in logs for event in expected)) diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index ccb20644b0861..cff4ec776a291 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -728,7 +728,7 @@ def __torch_function__( ): if kwargs is None: kwargs = {} - if func not in HANDLED_FUNCTIONS or not all( # noqa: C419 + if func not in HANDLED_FUNCTIONS or not all( [ # noqa: C419 issubclass(t, (torch.Tensor, MyClass)) for t in types ] @@ -2285,7 +2285,6 @@ def f(x): out_test = compiled_f(view) self.assertEqual(out_ref, out_test) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) @parametrize("dynamic", [True, False]) def test_mark_static_with_subclass_desugaring(self, dynamic): from collections.abc import Callable @@ -2313,6 +2312,7 @@ def inner_compile( boxed_forward_device_index: BoxedDeviceIndex | None = None, layout_opt: bool | None = None, extern_node_serializer: Callable[[list[Any]], Any] | None = None, + **kwargs: Any, ): if dynamic: self.assertEqual(static_input_idxs, [2, 3, 4]) @@ -2328,7 +2328,6 @@ def fn(t0, t1, t2): fn(torch.ones(4), x, torch.ones(4)) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_subclass_parameters_are_static_under_training(self): from collections.abc import Callable from typing import Any @@ -2350,6 +2349,7 @@ def inner_compile( boxed_forward_device_index: BoxedDeviceIndex | None = None, layout_opt: bool | None = None, extern_node_serializer: Callable[[list[Any]], Any] | None = None, + **kwargs: Any, ): # Important bit: there are 3 params: linear.weight.a, linear.weight.b, linear.bias, # which are the first 3 args of the graph. @@ -2469,7 +2469,7 @@ def forward( primals_5, # SavedForBackwardsAOTOutput(idx=1) primals_7, # SavedForBackwardsAOTOutput(idx=2) ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -2495,7 +2495,7 @@ def forward( primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) ) -""", # noqa: B950 +""", ) def test_tensor_subclass_TwoTensor_clone_view(self): @@ -2537,7 +2537,7 @@ def forward( primals_5, # SavedForBackwardsAOTOutput(idx=0) primals_7, # SavedForBackwardsAOTOutput(idx=1) ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -2562,7 +2562,7 @@ def forward( primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) ) -""", # noqa: B950 +""", ) def test_tensor_subclass_TwoTensor_mul(self): @@ -2611,7 +2611,7 @@ def forward( primals_5, # SavedForBackwardsAOTOutput(idx=2) primals_7, # SavedForBackwardsAOTOutput(idx=3) ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -2644,7 +2644,7 @@ def forward( primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) ) -""", # noqa: B950 +""", ) def test_tensor_subclass_TwoTensor_view(self): @@ -2686,7 +2686,7 @@ def forward( primals_5, # SavedForBackwardsAOTOutput(idx=0) primals_7, # SavedForBackwardsAOTOutput(idx=1) ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -2711,7 +2711,7 @@ def forward( primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) ) -""", # noqa: B950 +""", ) def test_tensor_subclass_TwoTensor_view_mul(self): @@ -2752,7 +2752,7 @@ def forward( primals_5, # SavedForBackwardsAOTOutput(idx=0) primals_7, # SavedForBackwardsAOTOutput(idx=1) ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -2777,7 +2777,7 @@ def forward( primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) ) -""", # noqa: B950 +""", ) def test_tensor_subclass_TwoTensor_return_tensor_and_subclass(self): @@ -2819,7 +2819,7 @@ def forward( primals_5, # SavedForBackwardsAOTOutput(idx=0) primals_7, # SavedForBackwardsAOTOutput(idx=1) ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -2844,7 +2844,7 @@ def forward( primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1) primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0) ) -""", # noqa: B950 +""", ) @unittest.expectedFailure @@ -2872,7 +2872,7 @@ def forward(self, primals_1: "f32[3, 4]", primals_2: "f32[3, 4]", primals_3: "Sy view: "f32[12]" = torch.ops.aten.view.default(clone, [mul]) view_1: "f32[12]" = torch.ops.aten.view.default(clone_1, [mul]); clone_1 = None return [clone, view, view_1, mul, primals_5, primals_6] -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -2883,7 +2883,7 @@ def forward(self, primals_5: "Sym(3)", primals_6: "Sym(4)", tangents_1: "f32[12] view_2: "f32[3, 4]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_6]); tangents_1 = None view_3: "f32[3, 4]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_6]); tangents_2 = primals_5 = primals_6 = None return [view_2, view_3, None, None] -""", # noqa: B950 +""", ) def test_tensor_subclass_TwoTensor_automatic_dynamic_shapes(self): @@ -2923,7 +2923,7 @@ def forward( view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b') clone_1, # PlainAOTOutput(idx=2) ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -2952,7 +2952,7 @@ def forward( clone_1, # PlainAOTOutput(idx=2) primals_5, # SavedForBackwardsAOTOutput(idx=0) ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -2970,7 +2970,7 @@ def forward( view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='a') view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='b') ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -2992,7 +2992,7 @@ def forward( primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=1) primals_5, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=0) ) -""", # noqa: B950 +""", ) def test_tensor_subclass_TwoTensor_mark_dynamic_shapes(self): @@ -3040,7 +3040,7 @@ def forward( clone_1, # PlainAOTOutput(idx=2) primals_5, # SavedForBackwardsAOTOutput(idx=0) ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -3062,7 +3062,7 @@ def forward( primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=1) primals_5, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=0) ) -""", # noqa: B950 +""", ) def test_tensor_subclass_TwoTensor_different_shape(self): @@ -3094,7 +3094,7 @@ def forward( view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a') view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b') ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -3112,7 +3112,7 @@ def forward( view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='a') view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='b') ) -""", # noqa: B950 +""", ) def test_tensor_subclass_TwoTensor_return_shape(self): @@ -3269,7 +3269,7 @@ def forward( primals_8, # SavedForBackwardsAOTOutput(idx=1) primals_10, # SavedForBackwardsAOTOutput(idx=2) ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -3299,7 +3299,7 @@ def forward( primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2) primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1) ) -""", # noqa: B950 +""", ) def test_njt_subclass_from_cat(self): @@ -3346,7 +3346,7 @@ def forward( primals_10, # SavedForBackwardsAOTOutput(idx=1) add_2, # SavedForBackwardsAOTOutput(idx=2) ) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -3378,7 +3378,7 @@ def forward( primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2) primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1) ) -""", # noqa: B950 +""", ) def test_njt_subclass_from_buffer(self): @@ -3446,7 +3446,7 @@ def forward( sym_size_int, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2) sym_stride_int, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1) ) -""", # noqa: B950 +""", ) def test_deferred_init_subclass_init_not_traced(self): @@ -3633,7 +3633,7 @@ def forward(self, s71: "Sym(s71)", L_nt_: "NestedTensor(f64[3, s71, 5])"): add: "NestedTensor(f64[3, s71, 5])" = l_nt_ + 2; l_nt_ = None return (add,) -""", # noqa: B950 +""", ) # Note: [What kind of guards are involved in nested tensor compilation] diff --git a/test/dynamo/test_subgraphs.py b/test/dynamo/test_subgraphs.py index 16c765bfb1409..7095b4f80db1f 100644 --- a/test/dynamo/test_subgraphs.py +++ b/test/dynamo/test_subgraphs.py @@ -119,7 +119,7 @@ def fn(a, b): c2 = b - a return indirectly_unsupported(c1, c2) - self._common(fn, 2, 3) + self._common(fn, 1, 3) def test_indirect_unsupported2(self): def fn(a, b): @@ -129,14 +129,14 @@ def fn(a, b): c2 = b - a return local_const1 / (local_const2 - indirectly_unsupported(c1, c2)) - self._common(fn, 3, 5) + self._common(fn, 2, 5) def test_indirect_unsupported3(self): def fn(a, b): args = [a - b, b - a] return indirectly_unsupported(*args) - self._common(fn, 2, 3) + self._common(fn, 1, 3) def test_stack_state1(self): def fn(a, b): @@ -156,7 +156,7 @@ def fn(a, b): c2 = b - a return t1 / (t2 - indirectly_unsupported(c1, c2)) - self._common(fn, 3, 7) + self._common(fn, 2, 7) def test_multigraph(self): def fn(a, b): @@ -199,7 +199,7 @@ def fn(a, b): x = x + 2.0 return x - self._common(fn, 3, 7) + self._common(fn, 2, 7) def test_resume3(self): def fn(a, b): @@ -212,7 +212,7 @@ def fn(a, b): x = x + 2.0 return x - self._common(fn, 3, 7) + self._common(fn, 2, 7) def test_resume4(self): def fn(a, b): @@ -225,7 +225,7 @@ def fn(a, b): x = x + 2.0 return x - self._common(fn, 3, 7) + self._common(fn, 2, 7) def test_resume5(self): def fn(a, b): diff --git a/test/dynamo/test_torchrec.py b/test/dynamo/test_torchrec.py index 311270a8f652a..9ad8f54d3de08 100644 --- a/test/dynamo/test_torchrec.py +++ b/test/dynamo/test_torchrec.py @@ -62,7 +62,7 @@ def forward(self, features: "KeyedJaggedTensor") -> "KeyedJaggedTensor": if not HAS_TORCHREC: print("torchrec not available, skipping tests", file=sys.stderr) - TestCase = NoTest # noqa: F811 + TestCase = NoTest @unittest.skipIf(not HAS_TORCHREC, "these tests require torchrec") diff --git a/test/dynamo/test_tp_slots.py b/test/dynamo/test_tp_slots.py new file mode 100644 index 0000000000000..6952b92058a27 --- /dev/null +++ b/test/dynamo/test_tp_slots.py @@ -0,0 +1,280 @@ +# Owner(s): ["module: dynamo"] + +"""Tests for CPython type slot detection in Dynamo. + +Tests that get_type_slots correctly identifies which protocol methods +(sequence, mapping, number, type) are implemented by various types. +""" + +import collections.abc +import dataclasses +import enum + +from torch._C._dynamo import ( + get_type_slots, + has_slot, + PyMappingSlots, + PyNumberSlots, + PySequenceSlots, + PyTypeSlots, +) +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestTypeSlots(TestCase): + """Test suite for type slot detection.""" + + def _get_slot_info(self, obj_type): + """Helper to get and unpack slot information.""" + seq_slots, map_slots, num_slots, type_slots = get_type_slots(obj_type) + return seq_slots, map_slots, num_slots, type_slots + + def test_dict_slots(self): + """Test that dict has mapping protocol but not sequence protocol.""" + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(dict) + + # dict should NOT have sq_length (sequence protocol) + self.assertFalse(has_slot(seq_slots, PySequenceSlots.SQ_LENGTH)) + + # dict SHOULD have mapping protocol + self.assertTrue(has_slot(map_slots, PyMappingSlots.MP_LENGTH)) + self.assertTrue(has_slot(map_slots, PyMappingSlots.MP_SUBSCRIPT)) + self.assertTrue(has_slot(map_slots, PyMappingSlots.MP_ASS_SUBSCRIPT)) + + def test_list_slots(self): + """Test that list has sequence protocol.""" + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(list) + + # list SHOULD have sequence protocol + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_LENGTH)) + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_ITEM)) + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_CONTAINS)) + + # list also has mapping protocol for compatibility (mp_length, mp_subscript, mp_ass_subscript) + self.assertTrue(has_slot(map_slots, PyMappingSlots.MP_LENGTH)) + self.assertTrue(has_slot(map_slots, PyMappingSlots.MP_SUBSCRIPT)) + + def test_tuple_slots(self): + """Test that tuple has sequence protocol.""" + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(tuple) + + # tuple SHOULD have sequence protocol + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_LENGTH)) + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_ITEM)) + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_CONTAINS)) + + def test_set_slots(self): + """Test that set has sequence protocol for contains.""" + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(set) + + # set SHOULD have sq_contains + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_CONTAINS)) + + # set should NOT have mapping protocol + self.assertFalse(has_slot(map_slots, PyMappingSlots.MP_LENGTH)) + + def test_dict_subclass_slots(self): + """Test that dict subclasses have both mapping and sequence protocol.""" + + class DictSubclass(dict): + pass + + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(DictSubclass) + + # Dict subclasses expose both protocols for compatibility + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_LENGTH)) + self.assertTrue(has_slot(map_slots, PyMappingSlots.MP_LENGTH)) + + def test_list_subclass_slots(self): + """Test that list subclasses have sequence protocol.""" + + class ListSubclass(list): + pass + + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(ListSubclass) + + # List subclasses should have sequence protocol + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_LENGTH)) + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_ITEM)) + + def test_int_slots(self): + """Test that int has number protocol.""" + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(int) + + # int SHOULD have number protocol operations + self.assertTrue(has_slot(num_slots, PyNumberSlots.NB_ADD)) + self.assertTrue(has_slot(num_slots, PyNumberSlots.NB_SUBTRACT)) + self.assertTrue(has_slot(num_slots, PyNumberSlots.NB_MULTIPLY)) + + def test_str_slots(self): + """Test that str has sequence protocol.""" + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(str) + + # str SHOULD have sequence protocol + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_LENGTH)) + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_ITEM)) + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_CONTAINS)) + + def test_type_has_call_slot(self): + """Test that type objects have tp_call slot.""" + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(type) + + # type SHOULD have tp_call (for calling classes/types) + self.assertTrue(has_slot(type_slots, PyTypeSlots.TP_CALL)) + + def test_type_has_hash_slot(self): + """Test that most types have tp_hash slot.""" + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(int) + + # int SHOULD have tp_hash + self.assertTrue(has_slot(type_slots, PyTypeSlots.TP_HASH)) + + def test_custom_class_slots(self): + """Test that custom user-defined classes have minimal slots.""" + + class CustomClass: + pass + + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(CustomClass) + + # Custom classes should not have sequence/mapping/number protocol by default + self.assertFalse(has_slot(seq_slots, PySequenceSlots.SQ_LENGTH)) + self.assertFalse(has_slot(map_slots, PyMappingSlots.MP_LENGTH)) + + def test_set_subclass_slots(self): + """Test that set subclasses inherit set protocol.""" + + class SetSubclass(set): + pass + + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(SetSubclass) + + # Set subclasses should have sq_contains + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_CONTAINS)) + + def test_tuple_subclass_slots(self): + """Test that tuple subclasses inherit sequence protocol.""" + + class TupleSubclass(tuple): + __slots__ = () + + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(TupleSubclass) + + # Tuple subclasses should have sequence protocol + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_LENGTH)) + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_ITEM)) + + def test_mutable_mapping_slots(self): + """Test that MutableMapping ABC has mapping protocol.""" + + class MyMapping(collections.abc.MutableMapping): + def __init__(self): + self._data = {} + + def __getitem__(self, key): + return self._data[key] + + def __setitem__(self, key, value): + self._data[key] = value + + def __delitem__(self, key): + del self._data[key] + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(MyMapping) + + # MutableMapping subclasses should have mapping protocol + self.assertTrue(has_slot(map_slots, PyMappingSlots.MP_LENGTH)) + self.assertTrue(has_slot(map_slots, PyMappingSlots.MP_SUBSCRIPT)) + self.assertTrue(has_slot(map_slots, PyMappingSlots.MP_ASS_SUBSCRIPT)) + + def test_frozen_dataclass_slots(self): + """Test that frozen dataclasses have standard object slots.""" + + @dataclasses.dataclass(frozen=True) + class FrozenData: + x: int + y: str + + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(FrozenData) + + # Frozen dataclasses should have basic type slots like hash + self.assertTrue(has_slot(type_slots, PyTypeSlots.TP_HASH)) + + def test_enum_slots(self): + """Test that Enum types have expected slots.""" + + class Color(enum.Enum): + RED = 1 + GREEN = 2 + BLUE = 3 + + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(Color) + + # Enums should have type protocol slots + self.assertTrue(has_slot(type_slots, PyTypeSlots.TP_HASH)) + + def test_enum_member_slots(self): + """Test that individual enum members work correctly.""" + + class Status(enum.Enum): + PENDING = "pending" + ACTIVE = "active" + + # Enum members are instances, test the enum class + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(Status) + self.assertTrue(has_slot(type_slots, PyTypeSlots.TP_HASH)) + + def test_metaclass_slots(self): + """Test that metaclass types have tp_call for instantiation.""" + + class MyMeta(type): + pass + + class MyClass(metaclass=MyMeta): + pass + + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(MyMeta) + + # Metaclasses should have tp_call for creating instances + self.assertTrue(has_slot(type_slots, PyTypeSlots.TP_CALL)) + + def test_dict_subclass_with_custom_len(self): + """Test dict subclass with custom __len__ (the original bug case).""" + + class DictWithCustomLen(dict): + def __len__(self): + return super().__len__() - 1 + + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info( + DictWithCustomLen + ) + + # Should have both protocols (CPython compatibility) + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_LENGTH)) + self.assertTrue(has_slot(map_slots, PyMappingSlots.MP_LENGTH)) + + def test_frozenset_slots(self): + """Test that frozenset has sequence protocol for contains.""" + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(frozenset) + + # frozenset SHOULD have sq_contains + self.assertTrue(has_slot(seq_slots, PySequenceSlots.SQ_CONTAINS)) + + def test_float_slots(self): + """Test that float has number protocol.""" + seq_slots, map_slots, num_slots, type_slots = self._get_slot_info(float) + + # float SHOULD have number protocol operations + self.assertTrue(has_slot(num_slots, PyNumberSlots.NB_ADD)) + self.assertTrue(has_slot(num_slots, PyNumberSlots.NB_SUBTRACT)) + self.assertTrue(has_slot(num_slots, PyNumberSlots.NB_MULTIPLY)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/test_trace_rules.py b/test/dynamo/test_trace_rules.py index 8833445bb495f..e8c44be703d66 100644 --- a/test/dynamo/test_trace_rules.py +++ b/test/dynamo/test_trace_rules.py @@ -474,6 +474,8 @@ def test_no_special_handlers_for_torch_non_c_bindings(self): "handle_assert", # No global state (constant) "handle_nested_tensor", # No global state "handle_current_stream", # Safely implemented + "handle_synchronize", # Device type from function identity or arg + "handle_functorch_autograd_grad", # Only inspects placeholder metadata ) for fn in handlers: if isinstance(fn, staticmethod) or inspect.ismethod(fn): @@ -494,7 +496,7 @@ def test_no_special_handlers_for_torch_non_c_bindings(self): ) def test_almost_impossible_missing_name(self): - class weird: # noqa: UP004 + class weird: def __getattribute__(self, name): if name == "__name__": raise AttributeError("test") diff --git a/test/dynamo/test_tree_map.py b/test/dynamo/test_tree_map.py index 3d75e3a2fdaf9..dffa408a4b223 100644 --- a/test/dynamo/test_tree_map.py +++ b/test/dynamo/test_tree_map.py @@ -1,28 +1,18 @@ # Owner(s): ["module: dynamo"] -try: - import optree -except ImportError: # pragma: no cover - optree = None - from collections import namedtuple +from dataclasses import dataclass from typing import NamedTuple import torch import torch._dynamo +import torch.utils._pytree as python_pytree +from torch._dynamo.test_case import run_tests, TestCase from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, - run_tests, - TestCase, + subtest, ) -from torch.utils import _pytree as pytree - - -try: - import torch.utils._cxx_pytree as cxx_pytree -except ImportError: # pragma: no cover - cxx_pytree = None def _tensor_leaf(*values): @@ -52,38 +42,56 @@ def _tuple_is_leaf(node): return isinstance(node, tuple) +pytree_modules = { + "python": python_pytree, +} +if python_pytree._cxx_pytree_dynamo_traceable: + import torch.utils._cxx_pytree as cxx_pytree + + pytree_modules["cxx"] = cxx_pytree + pytree_modules["native_optree"] = cxx_pytree.optree +else: + cxx_pytree = None + +optree = cxx_pytree.optree if cxx_pytree is not None else None + + def _require_optree(test_case): if optree is None: test_case.skipTest("optree is unavailable") -TREE_MAP_IMPLEMENTATIONS = [] -if optree is not None: - TREE_MAP_IMPLEMENTATIONS.append(("optree", optree.tree_map)) -TREE_MAP_IMPLEMENTATIONS.append(("pytree_python", pytree.tree_map)) -if cxx_pytree is not None: - TREE_MAP_IMPLEMENTATIONS.append(("pytree_cxx", cxx_pytree.tree_map)) +parametrize_pytree_module = parametrize( + "pytree_name,pytree", + [subtest((name, module), name=name) for name, module in pytree_modules.items()], +) +_PYTREE_MODULES_WITH_PATH = {"python"} -TREE_MAP_WITH_PATH_IMPLEMENTATIONS = [ - ("pytree_python", pytree.tree_map_with_path), -] +parametrize_pytree_module_with_path = parametrize( + "pytree_name,pytree", + [ + subtest((name, module), name=name) + for name, module in pytree_modules.items() + if name in _PYTREE_MODULES_WITH_PATH + ], +) KWARG_CASES = [ ("default", {}, None), - ("none_is_leaf", {"none_is_leaf": True}, {"optree"}), + ("none_is_leaf", {"none_is_leaf": True}, {"native_optree"}), ("is_leaf", {"is_leaf": _tuple_is_leaf}, None), - ("namespace", {"namespace": "torch"}, {"optree"}), + ("namespace", {"namespace": "torch"}, {"native_optree"}), ( "namespace_and_none_is_leaf", {"namespace": "torch", "none_is_leaf": True}, - {"optree"}, + {"native_optree"}, ), ( "namespace_none_is_leaf_predicate", {"namespace": "torch", "none_is_leaf": True, "is_leaf": _tuple_is_leaf}, - {"optree"}, + {"native_optree"}, ), ] @@ -109,8 +117,8 @@ def _build_tree(offset: int) -> dict[str, object]: def _assert_trees_allclose(test_case: TestCase, ref, res) -> None: - ref_flat, ref_spec = pytree.tree_flatten(ref) - res_flat, res_spec = pytree.tree_flatten(res) + ref_flat, ref_spec = python_pytree.tree_flatten(ref) + res_flat, res_spec = python_pytree.tree_flatten(res) test_case.assertEqual(ref_spec, res_spec) for expected, actual in zip(ref_flat, res_flat): if isinstance(expected, torch.Tensor): @@ -125,37 +133,35 @@ def setUp(self): super().setUp() torch._dynamo.reset() - def _run_tree_map(self, tree_map_impl, kwargs): + def _run_tree_map(self, pytree, kwargs): lhs = _build_tree(0) rhs = _build_tree(7) def fn(a, b): - return tree_map_impl(_combine_leaves, a, b, **kwargs) + return pytree.tree_map(_combine_leaves, a, b, **kwargs) compiled = torch.compile(fn, backend="eager", fullgraph=True) expected = fn(lhs, rhs) result = compiled(lhs, rhs) _assert_trees_allclose(self, expected, result) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) + @parametrize_pytree_module @parametrize("kwargs_name,kwargs,allowed_impls", KWARG_CASES) def test_tree_map_variants( self, - tree_map_name: str, - tree_map_impl, + pytree_name: str, + pytree, kwargs_name: str, kwargs: dict, allowed_impls, ) -> None: - if tree_map_name == "pytree_cxx" and cxx_pytree is None: - self.skipTest("torch.utils._cxx_pytree is unavailable") - if allowed_impls is not None and tree_map_name not in allowed_impls: + if allowed_impls is not None and pytree_name not in allowed_impls: self.skipTest("kwargs unsupported for implementation") - self._run_tree_map(tree_map_impl, kwargs) + self._run_tree_map(pytree, kwargs) def test_tree_map_rejects_mismatched_container_types(self) -> None: def fn(a, b): - return pytree.tree_map(lambda u, v: u + v, a, b) + return python_pytree.tree_map(lambda u, v: u + v, a, b) lhs = [torch.ones(2), torch.ones(2)] rhs = (torch.ones(2), torch.ones(2)) @@ -172,7 +178,7 @@ def fn(a, b): def test_tree_map_is_leaf_handles_tensor_nodes(self) -> None: def fn(tree): - return pytree.tree_map( + return python_pytree.tree_map( lambda pair: torch.stack(pair).sum(dim=0), tree, is_leaf=lambda node: isinstance(node, tuple), @@ -193,7 +199,7 @@ def mapper(node): return node + 2 def fn(arg): - return pytree.tree_map_only(torch.Tensor, mapper, arg) + return python_pytree.tree_map_only(torch.Tensor, mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) expected = fn(tree) @@ -205,7 +211,7 @@ def test_tree_map_only_multiple_trees_falls_back(self) -> None: rhs = {"a": torch.ones(2) * 3, "b": torch.ones(2) * 4} def fn(a, b): - return pytree.tree_map_only(torch.Tensor, lambda x, y: x + y, a, b) + return python_pytree.tree_map_only(torch.Tensor, lambda x, y: x + y, a, b) with self.assertRaisesRegex(TypeError, "callable"): fn(lhs, rhs) @@ -228,7 +234,7 @@ def mapper(node): raise AssertionError("unexpected node passed to mapper") def fn(arg): - return pytree.tree_map_only((int, tuple), mapper, arg) + return python_pytree.tree_map_only((int, tuple), mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) expected = fn(tree) @@ -248,7 +254,7 @@ def mapper(node): return node * 2 if isinstance(node, torch.Tensor) else node def fn(arg): - return pytree.tree_map(mapper, arg, is_leaf=is_leaf) + return python_pytree.tree_map(mapper, arg, is_leaf=is_leaf) compiled = torch.compile(fn, backend="eager", fullgraph=True) expected = fn(tree) @@ -265,7 +271,7 @@ def mapper(node): return node + 5 if isinstance(node, torch.Tensor) else node def fn(arg): - return pytree.tree_map_only(selector, mapper, arg) + return python_pytree.tree_map_only(selector, mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) expected = fn(tree) @@ -291,15 +297,15 @@ def fn(a, b): ): compiled(lhs, rhs) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) + @parametrize_pytree_module def test_tree_map_none_nodes_default_behavior( - self, tree_map_name: str, tree_map_impl + self, pytree_name: str, pytree ) -> None: - if tree_map_name == "optree": + if pytree_name == "native_optree": self.skipTest("optree treats None as an internal node by default") def fn(a, b): - return tree_map_impl(lambda u, v: (u, v), a, b) + return pytree.tree_map(lambda u, v: (u, v), a, b) tree = {"k": None} compiled = torch.compile(fn, backend="eager", fullgraph=True) @@ -364,9 +370,9 @@ def mapper(node): self.assertIs(result["nested"]["dtype"], torch.float64) self.assertEqual(result, expected) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) + @parametrize_pytree_module def test_user_defined_object_treated_as_leaf( - self, tree_map_name: str, tree_map_impl + self, pytree_name: str, pytree ) -> None: """User-defined objects (not registered in pytree) should be treated as leaves.""" @@ -388,7 +394,7 @@ def mapper(node): return node def fn(arg): - return tree_map_impl(mapper, arg) + return pytree.tree_map(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) @@ -409,10 +415,8 @@ def fn(arg): self.assertTrue(torch.allclose(result["custom"].value, obj1.value * 3)) self.assertTrue(torch.allclose(result["tensor"], torch.ones(2))) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) - def test_user_defined_object_multiple_trees( - self, tree_map_name: str, tree_map_impl - ) -> None: + @parametrize_pytree_module + def test_user_defined_object_multiple_trees(self, pytree_name: str, pytree) -> None: """User-defined objects should work correctly with multiple input trees.""" class Point: @@ -429,7 +433,7 @@ def mapper(a, b): return a + b def fn(t1, t2): - return tree_map_impl(mapper, t1, t2) + return pytree.tree_map(mapper, t1, t2) compiled = torch.compile(fn, backend="eager", fullgraph=True) result = compiled(tree1, tree2) @@ -439,10 +443,8 @@ def fn(t1, t2): self.assertEqual(result["point"].y, 6) self.assertEqual(result["val"], 30) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) - def test_dict_subclass_treated_as_leaf( - self, tree_map_name: str, tree_map_impl - ) -> None: + @parametrize_pytree_module + def test_dict_subclass_treated_as_leaf(self, pytree_name: str, pytree) -> None: """Dict subclasses (not registered in pytree) should be treated as leaves.""" class MyDict(dict): @@ -461,7 +463,7 @@ def mapper(node): return node def fn(arg): - return tree_map_impl(mapper, arg) + return pytree.tree_map(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) result = compiled(tree) @@ -473,10 +475,8 @@ def fn(arg): # Regular dict should still be traversed self.assertEqual(result["regular"]["x"], 3) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) - def test_list_subclass_treated_as_leaf( - self, tree_map_name: str, tree_map_impl - ) -> None: + @parametrize_pytree_module + def test_list_subclass_treated_as_leaf(self, pytree_name: str, pytree) -> None: """List subclasses (not registered in pytree) should be treated as leaves.""" class MyList(list): @@ -494,7 +494,7 @@ def mapper(node): return node def fn(arg): - return tree_map_impl(mapper, arg) + return pytree.tree_map(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) result = compiled(tree) @@ -505,10 +505,8 @@ def fn(arg): # Regular list should be traversed self.assertEqual(result["regular"], [14, 15, 16]) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) - def test_tuple_subclass_treated_as_leaf( - self, tree_map_name: str, tree_map_impl - ) -> None: + @parametrize_pytree_module + def test_tuple_subclass_treated_as_leaf(self, pytree_name: str, pytree) -> None: """Tuple subclasses (not registered in pytree) should be treated as leaves.""" class MyTuple(tuple): # noqa: SLOT001 @@ -525,7 +523,7 @@ def mapper(node): return node def fn(arg): - return tree_map_impl(mapper, arg) + return pytree.tree_map(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) result = compiled(tree) @@ -536,9 +534,9 @@ def fn(arg): # Regular tuple should be traversed self.assertEqual(result["regular"], (14, 15, 16)) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) + @parametrize_pytree_module def test_user_defined_object_nested_in_containers( - self, tree_map_name: str, tree_map_impl + self, pytree_name: str, pytree ) -> None: """User-defined objects nested inside containers should be leaves.""" @@ -557,7 +555,7 @@ def mapper(node): return node def fn(arg): - return tree_map_impl(mapper, arg) + return pytree.tree_map(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) result = compiled(tree) @@ -566,9 +564,9 @@ def fn(arg): self.assertEqual(result["list_of_wrappers"][1].value, 20) self.assertEqual(result["nested"]["wrapper"].value, 30) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) + @parametrize_pytree_module def test_user_defined_object_with_is_leaf_predicate( - self, tree_map_name: str, tree_map_impl + self, pytree_name: str, pytree ) -> None: """Test that is_leaf predicate interacts correctly with user-defined objects.""" @@ -591,7 +589,7 @@ def mapper(node): return node def fn(arg): - return tree_map_impl(mapper, arg, is_leaf=is_leaf_fn) + return pytree.tree_map(mapper, arg, is_leaf=is_leaf_fn) compiled = torch.compile(fn, backend="eager", fullgraph=True) result = compiled(tree) @@ -607,7 +605,7 @@ def __init__(self, items): self.items = list(items) # Register with pytree - pytree.register_pytree_node( + python_pytree.register_pytree_node( RegisteredContainer, lambda x: (x.items, None), lambda items, _: RegisteredContainer(items), @@ -622,7 +620,7 @@ def mapper(node): return node def fn(arg): - return pytree.tree_map(mapper, arg) + return python_pytree.tree_map(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) result = compiled(tree) @@ -633,7 +631,38 @@ def fn(arg): self.assertTrue(torch.allclose(result.items[1], torch.zeros(2) + 1)) finally: # Clean up registration - pytree._deregister_pytree_node(RegisteredContainer) + python_pytree._deregister_pytree_node(RegisteredContainer) + + def test_registered_custom_type_multiple_trees_with_eq(self) -> None: + """Dataclass registrations compare class objects through the metaclass.""" + + @dataclass + class RegisteredContainer: + x: torch.Tensor + y: torch.Tensor + + python_pytree.register_pytree_node( + RegisteredContainer, + lambda c: ([c.x, c.y], None), + lambda children, _: RegisteredContainer(*children), + ) + + try: + lhs = RegisteredContainer(torch.zeros(2, 3), torch.ones(2, 3)) + rhs = RegisteredContainer(torch.ones(2, 3), torch.zeros(2, 3)) + + def fn(a, b): + return python_pytree.tree_map(lambda x, y: x + y, a, b) + + compiled = torch.compile(fn, backend="eager", fullgraph=True) + expected = fn(lhs, rhs) + result = compiled(lhs, rhs) + + self.assertIsInstance(result, RegisteredContainer) + self.assertEqual(result.x, expected.x) + self.assertEqual(result.y, expected.y) + finally: + python_pytree._deregister_pytree_node(RegisteredContainer) def test_registered_custom_type_falls_back_optree(self) -> None: """Custom types registered with optree should fall back to tracing.""" @@ -829,8 +858,8 @@ def fn(arg): finally: optree.unregister_pytree_node(NamespacedContainer, namespace="namespace_a") - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) - def test_dataclass_treated_as_leaf(self, tree_map_name: str, tree_map_impl) -> None: + @parametrize_pytree_module + def test_dataclass_treated_as_leaf(self, pytree_name: str, pytree) -> None: """Dataclasses should be treated as leaves (not registered by default).""" import dataclasses @@ -847,7 +876,7 @@ def mapper(node): return node + 10 def fn(arg): - return tree_map_impl(mapper, arg) + return pytree.tree_map(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) result = compiled(tree) @@ -857,9 +886,9 @@ def fn(arg): self.assertEqual(result["point"].y, 4) self.assertEqual(result["val"], 13) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) + @parametrize_pytree_module def test_user_defined_object_with_tensor_attribute( - self, tree_map_name: str, tree_map_impl + self, pytree_name: str, pytree ) -> None: """User-defined objects containing tensors should be treated as leaves.""" @@ -878,7 +907,7 @@ def mapper(node): return node def fn(arg): - return tree_map_impl(mapper, arg) + return pytree.tree_map(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) result = compiled(tree) @@ -887,10 +916,8 @@ def fn(arg): self.assertTrue(torch.allclose(result["wrapper"].tensor, torch.ones(2, 2) * 2)) self.assertTrue(torch.allclose(result["direct_tensor"], torch.ones(2))) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) - def test_user_defined_object_no_fallback( - self, tree_map_name: str, tree_map_impl - ) -> None: + @parametrize_pytree_module + def test_user_defined_object_no_fallback(self, pytree_name: str, pytree) -> None: """Verify user-defined objects use fastpath without triggering fallback.""" import logging @@ -906,7 +933,7 @@ def mapper(node): return node + 1 def fn(arg): - return tree_map_impl(mapper, arg) + return pytree.tree_map(mapper, arg) # Capture debug logs to ensure no fallback is triggered log_records = [] @@ -942,8 +969,8 @@ def fn(arg): logger.removeHandler(handler) logger.setLevel(old_level) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) - def test_namedtuple_tree_map(self, tree_map_name: str, tree_map_impl) -> None: + @parametrize_pytree_module + def test_namedtuple_tree_map(self, pytree_name: str, pytree) -> None: """Test tree_map with namedtuple uses fast path.""" Point = namedtuple("Point", ["x", "y"]) @@ -958,7 +985,7 @@ def mapper(node): return node def fn(arg): - return tree_map_impl(mapper, arg) + return pytree.tree_map(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) expected = fn(tree) @@ -972,10 +999,8 @@ def fn(arg): self.assertTrue(torch.allclose(result["nested"][0].y, torch.ones(3) * 2 + 1)) _assert_trees_allclose(self, expected, result) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) - def test_namedtuple_tree_map_multiple_trees( - self, tree_map_name: str, tree_map_impl - ) -> None: + @parametrize_pytree_module + def test_namedtuple_tree_map_multiple_trees(self, pytree_name: str, pytree) -> None: """Test tree_map with multiple namedtuple trees.""" Point = namedtuple("Point", ["x", "y"]) @@ -988,7 +1013,7 @@ def mapper(a, b): return a def fn(t1, t2): - return tree_map_impl(mapper, t1, t2) + return pytree.tree_map(mapper, t1, t2) compiled = torch.compile(fn, backend="eager", fullgraph=True) expected = fn(tree1, tree2) @@ -999,8 +1024,8 @@ def fn(t1, t2): self.assertTrue(torch.allclose(result["point"].y, torch.ones(2) * 3)) _assert_trees_allclose(self, expected, result) - @parametrize("tree_map_name,tree_map_impl", TREE_MAP_IMPLEMENTATIONS) - def test_namedtuple_with_is_leaf(self, tree_map_name: str, tree_map_impl) -> None: + @parametrize_pytree_module + def test_namedtuple_with_is_leaf(self, pytree_name: str, pytree) -> None: """Test tree_map with namedtuple and is_leaf predicate.""" class Point(NamedTuple): @@ -1021,7 +1046,7 @@ def mapper(node): return node def fn(arg): - return tree_map_impl(mapper, arg, is_leaf=is_leaf_fn) + return pytree.tree_map(mapper, arg, is_leaf=is_leaf_fn) compiled = torch.compile(fn, backend="eager", fullgraph=True) expected = fn(tree) @@ -1040,12 +1065,8 @@ def setUp(self): super().setUp() torch._dynamo.reset() - @parametrize( - "tree_map_name,tree_map_with_path_impl", TREE_MAP_WITH_PATH_IMPLEMENTATIONS - ) - def test_basic_nested_tree( - self, tree_map_name: str, tree_map_with_path_impl - ) -> None: + @parametrize_pytree_module_with_path + def test_basic_nested_tree(self, pytree_name: str, pytree) -> None: """Keypaths are correctly constructed for nested dicts, lists, and tuples.""" tree = { "tensor": torch.ones(2), @@ -1060,7 +1081,7 @@ def mapper(kp, x): return x + 1 if isinstance(x, torch.Tensor) else x def fn(arg): - return tree_map_with_path_impl(mapper, arg) + return pytree.tree_map_with_path(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) @@ -1077,10 +1098,8 @@ def fn(arg): _assert_trees_allclose(self, expected, result) self.assertEqual(eager_keypaths, compiled_keypaths) - @parametrize( - "tree_map_name,tree_map_with_path_impl", TREE_MAP_WITH_PATH_IMPLEMENTATIONS - ) - def test_multiple_trees(self, tree_map_name: str, tree_map_with_path_impl) -> None: + @parametrize_pytree_module_with_path + def test_multiple_trees(self, pytree_name: str, pytree) -> None: """tree_map_with_path with multiple input trees.""" tree1 = {"a": torch.ones(2), "b": [torch.zeros(3)]} tree2 = {"a": torch.ones(2) * 2, "b": [torch.ones(3) * 3]} @@ -1089,19 +1108,15 @@ def mapper(kp, x, y): return x + y def fn(t1, t2): - return tree_map_with_path_impl(mapper, t1, t2) + return pytree.tree_map_with_path(mapper, t1, t2) compiled = torch.compile(fn, backend="eager", fullgraph=True) expected = fn(tree1, tree2) result = compiled(tree1, tree2) _assert_trees_allclose(self, expected, result) - @parametrize( - "tree_map_name,tree_map_with_path_impl", TREE_MAP_WITH_PATH_IMPLEMENTATIONS - ) - def test_is_leaf_predicate( - self, tree_map_name: str, tree_map_with_path_impl - ) -> None: + @parametrize_pytree_module_with_path + def test_is_leaf_predicate(self, pytree_name: str, pytree) -> None: """is_leaf stops traversal and passes the subtree as a leaf.""" tree = {"a": [torch.ones(2), torch.zeros(2)]} @@ -1117,7 +1132,7 @@ def mapper(kp, x): return x def fn(arg): - return tree_map_with_path_impl(mapper, arg, is_leaf=is_leaf_fn) + return pytree.tree_map_with_path(mapper, arg, is_leaf=is_leaf_fn) compiled = torch.compile(fn, backend="eager", fullgraph=True) @@ -1133,12 +1148,8 @@ def fn(arg): self.assertEqual(len(eager_keypaths), 1) self.assertEqual(eager_keypaths, compiled_keypaths) - @parametrize( - "tree_map_name,tree_map_with_path_impl", TREE_MAP_WITH_PATH_IMPLEMENTATIONS - ) - def test_namedtuple_uses_getattr_key( - self, tree_map_name: str, tree_map_with_path_impl - ) -> None: + @parametrize_pytree_module_with_path + def test_namedtuple_uses_getattr_key(self, pytree_name: str, pytree) -> None: """Namedtuple fields produce GetAttrKey in keypaths.""" Point = namedtuple("Point", ["x", "y"]) tree = {"point": Point(torch.ones(2), torch.zeros(2))} @@ -1152,7 +1163,7 @@ def mapper(kp, x): return x def fn(arg): - return tree_map_with_path_impl(mapper, arg) + return pytree.tree_map_with_path(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) @@ -1167,12 +1178,8 @@ def fn(arg): _assert_trees_allclose(self, expected, result) self.assertEqual(eager_keypaths, compiled_keypaths) - @parametrize( - "tree_map_name,tree_map_with_path_impl", TREE_MAP_WITH_PATH_IMPLEMENTATIONS - ) - def test_deeply_nested_keypaths( - self, tree_map_name: str, tree_map_with_path_impl - ) -> None: + @parametrize_pytree_module_with_path + def test_deeply_nested_keypaths(self, pytree_name: str, pytree) -> None: """Deeply nested structures produce correct multi-level keypaths.""" tree = {"outer": {"inner": [torch.ones(2)]}} @@ -1183,7 +1190,7 @@ def mapper(kp, x): return x * 2 if isinstance(x, torch.Tensor) else x def fn(arg): - return tree_map_with_path_impl(mapper, arg) + return pytree.tree_map_with_path(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) @@ -1200,12 +1207,8 @@ def fn(arg): self.assertEqual(eager_keypaths, compiled_keypaths) _assert_trees_allclose(self, expected, result) - @parametrize( - "tree_map_name,tree_map_with_path_impl", TREE_MAP_WITH_PATH_IMPLEMENTATIONS - ) - def test_keypath_values_used_in_computation( - self, tree_map_name: str, tree_map_with_path_impl - ) -> None: + @parametrize_pytree_module_with_path + def test_keypath_values_used_in_computation(self, pytree_name: str, pytree) -> None: """The map function can use keypath values to influence the result.""" from torch.utils._pytree import MappingKey @@ -1219,18 +1222,16 @@ def mapper(kp, x): return x * 2 def fn(arg): - return tree_map_with_path_impl(mapper, arg) + return pytree.tree_map_with_path(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) expected = fn(tree) result = compiled(tree) _assert_trees_allclose(self, expected, result) - @parametrize( - "tree_map_name,tree_map_with_path_impl", TREE_MAP_WITH_PATH_IMPLEMENTATIONS - ) + @parametrize_pytree_module_with_path def test_user_defined_object_treated_as_leaf( - self, tree_map_name: str, tree_map_with_path_impl + self, pytree_name: str, pytree ) -> None: """Unregistered user-defined objects are leaves in tree_map_with_path.""" @@ -1246,7 +1247,7 @@ def mapper(kp, x): return x + 1 def fn(arg): - return tree_map_with_path_impl(mapper, arg) + return pytree.tree_map_with_path(mapper, arg) compiled = torch.compile(fn, backend="eager", fullgraph=True) result = compiled(tree) diff --git a/test/dynamo/test_unittest.py b/test/dynamo/test_unittest.py index de1364d424d7f..df1d1d7419b77 100644 --- a/test/dynamo/test_unittest.py +++ b/test/dynamo/test_unittest.py @@ -8,10 +8,12 @@ class TestUnittest(torch._dynamo.test_case.TestCase): def setUp(self): + super().setUp() self._prev = torch._dynamo.config.enable_trace_unittest torch._dynamo.config.enable_trace_unittest = True def tearDown(self): + super().tearDown() torch._dynamo.config.enable_trace_unittest = self._prev @make_dynamo_test diff --git a/test/dynamo/test_user_defined_object.py b/test/dynamo/test_user_defined_object.py new file mode 100644 index 0000000000000..eb920a800c4f7 --- /dev/null +++ b/test/dynamo/test_user_defined_object.py @@ -0,0 +1,695 @@ +# Owner(s): ["module: dynamo"] + +import dataclasses +import types +import unittest + +import torch +import torch._dynamo.testing as dynamo_testing +from torch._dynamo.test_case import run_tests, TestCase + + +class SlotsOnly: + __slots__ = ("x", "y") + + def __init__(self, x, y): + self.x = x + self.y = y + + +class SlotsAndDict: + __slots__ = ("x", "__dict__") + + def __init__(self, x): + self.x = x + + +@dataclasses.dataclass(frozen=True, slots=True) +class FrozenSlots: + x: int + y: int + + +class SlotsAndSetattr: + __slots__ = ("x",) + + def __init__(self, x): + self.x = x + + def __setattr__(self, name, value): + object.__setattr__(self, name, value * 2) + + +class SlotsAndDictAndSetattr: + __slots__ = ("x", "__dict__") + + def __init__(self, x): + self.x = x + + def __setattr__(self, name, value): + object.__setattr__(self, name, value * 2) + + +class SlotsBase: + __slots__ = ("x",) + + def __init__(self): + self.x = 0 + + +class SlotsDerived(SlotsBase): + __slots__ = ("y",) + + def __init__(self): + super().__init__() + self.y = 0 + + +class Plain: + pass + + +class SlotsChildOfPlain(Plain): + __slots__ = ("z",) + + def __init__(self): + self.z = 0 + + +class Slots: + __slots__ = ("x",) + + +class SlotsShadowed(SlotsBase): + x = 42 # class attribute shadows parent's slot descriptor + + +class SlotsAndProperty: + __slots__ = ("_x",) + + def __init__(self, x): + self._x = x + + @property + def x(self): + return self._x + + @x.setter + def x(self, value): + self._x = value * 2 + + +class TestSlotsAttrAssignment(TestCase): + """Tests for attribute assignment on objects with __slots__.""" + + def test_valid_slot_assignment(self): + # Case 1: assign to a declared slot — should succeed + def fn(t): + obj = SlotsOnly(1, 2) + obj.x = 99 + return t + obj.x + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_invalid_slot_assignment_raises(self): + # Case 2: assign to an undeclared attr on a slotted object (no __dict__) + # should raise AttributeError in eager; compiled raises an exception too + def fn(t): + obj = SlotsOnly(1, 2) + obj.z = 99 + return t + obj.x + + compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) + t = torch.ones(1) + self.assertRaises(AttributeError, fn, t) + self.assertRaises(Exception, compiled_fn, t) + + def test_slots_with_dict_allows_arbitrary_attrs(self): + # Case 3: __slots__ includes __dict__ — arbitrary attr assignment should work + def fn(t): + obj = SlotsAndDict(1) + obj.extra = 42 + return t + obj.x + obj.extra + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_frozen_dataclass_with_slots_construction(self): + # Case 4: frozen dataclass with slots uses object.__setattr__ in __init__ + # to bypass the frozen __setattr__. Dynamo must allow this for slot descriptors. + def fn(t): + obj = FrozenSlots(3, 4) + return t + obj.x + obj.y + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_custom_setattr_with_slots(self): + # Case 5: __slots__ + custom __setattr__ — the custom __setattr__ is traced + def fn(t): + obj = SlotsAndSetattr(1) + obj.x = 10 + return t + obj.x + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_with_dict_valid_slot_assignment(self): + # Case 6: __slots__ + __dict__: assigning to a declared slot still works + def fn(t): + obj = SlotsAndDict(1) + obj.x = 99 + return t + obj.x + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_with_dict_undeclared_attr_goes_to_dict(self): + # Case 7: __slots__ + __dict__: assigning to an undeclared attr goes to + # __dict__ instead of raising AttributeError + def fn(t): + obj = SlotsAndDict(1) + obj.z = 42 + return t + obj.z + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_custom_setattr_with_slots_and_dict(self): + # Case 8: __slots__ + __dict__ + custom __setattr__ — custom __setattr__ + # is traced for both slot and non-slot attrs + def fn(t): + obj = SlotsAndDictAndSetattr(1) + obj.x = 10 + obj.extra = 3 + return t + obj.x + obj.extra + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_inheritance_parent_and_child_slots(self): + # Subclass adds its own slot on top of parent's slot — both accessible + def fn(t): + obj = SlotsDerived() + obj.x = 1 + obj.y = 2 + return t + obj.x + obj.y + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_child_inherits_dict_from_no_slots_parent(self): + # Subclass with __slots__ inheriting from a parent without __slots__ + # gets __dict__ from the parent, so arbitrary attrs are allowed + def fn(t): + obj = SlotsChildOfPlain() + obj.z = 1 + obj.extra = 42 + return t + obj.z + obj.extra + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_object_setattr_bypasses_custom_setattr(self): + # object.__setattr__ skips the custom __setattr__ and writes directly to slot + def fn(t): + obj = SlotsAndSetattr(1) + object.__setattr__(obj, "x", 5) + return t + obj.x + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_getattr_default_on_unset_slot(self): + # getattr with a default on an unset slot returns the default + def fn(t): + obj = Slots() + val = getattr(obj, "x", 99) + return t + val + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slot_read_after_delete_raises(self): + # Reading a slot after deletion raises AttributeError in both eager and compiled + def fn(t): + obj = Slots() + obj.x = 1 + del obj.x + return t + obj.x + + compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) + t = torch.ones(1) + self.assertRaises(AttributeError, fn, t) + self.assertRaises(Exception, compiled_fn, t) + + def test_slot_shadowed_by_class_attribute(self): + # Class attribute in subclass shadows parent slot descriptor: + # reads return the class attribute, writes raise AttributeError + def fn(t): + obj = SlotsShadowed() + return t + obj.x # returns class attr 42 + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slot_assignment_with_object_as_argument(self): + # Slotted object passed as argument (not created inside fn) + def fn(t, obj): + obj.x = 10 + return t + obj.x + + compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) + t = torch.ones(1) + obj = Slots() + self.assertEqual(fn(t.clone(), obj), compiled_fn(t.clone(), obj)) + + def test_slot_mutation_materialized_on_argument(self): + # Slot mutation on an object passed as argument must be visible after + # the compiled function returns (side effect materialization) + def fn(t, obj): + obj.x = 10 + return t.sin() + + compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) + obj = Slots() + compiled_fn(torch.ones(1), obj) + self.assertEqual(obj.x, 10) + + def test_slot_delete_materialized(self): + # del on a slot inside a compiled fn must be visible after the call returns + def fn(t, obj): + del obj.x + return t.sin() + + compiled_fn = torch.compile(fn, backend="eager", fullgraph=True) + obj = Slots() + obj.x = 1 + compiled_fn(torch.ones(1), obj) + self.assertFalse(hasattr(obj, "x")) + + def test_hasattr_on_slotted_object(self): + # hasattr inside compiled code reflects actual slot state + def fn(t): + obj = Slots() + before = hasattr(obj, "x") # False — slot not set + obj.x = 5 + after = hasattr(obj, "x") # True — slot is now set + return t + before + after + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_with_property_setter(self): + # property setter is called instead of writing directly to the slot + def fn(t): + obj = SlotsAndProperty(1) + obj.x = 5 # calls setter: _x = 5 * 2 = 10 + return t + obj.x # calls getter: returns _x = 10 + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slot_assignment_no_recompile_same_type(self): + # Calling compiled fn repeatedly with the same slotted object type + # must not trigger recompilation + cnts = dynamo_testing.CompileCounter() + + def fn(t, obj): + obj.x = 10 + return t + obj.x + + compiled_fn = torch.compile(fn, backend=cnts) + t = torch.ones(1) + compiled_fn(t, Slots()) + compiled_fn(t, Slots()) + compiled_fn(t, Slots()) + self.assertEqual(cnts.frame_count, 1) + + def test_slot_assignment_recompiles_on_type_change(self): + # Compiled fn sees slot assigned to int first, then float — guards recompile + cnts = dynamo_testing.CompileCounter() + + def fn(t, a, obj): + obj.x = a + return t + obj.x + + compiled_fn = torch.compile(fn, backend=cnts) + t = torch.ones(1) + + compiled_fn(t, 1, Slots()) + compiled_fn(t, 1, Slots()) + self.assertEqual(cnts.frame_count, 1) # same type, no recompile + + x = t.clone() + res = compiled_fn(x, 1.0, Slots()) + self.assertEqual(cnts.frame_count, 2) # float instead of int — recompile + self.assertEqual(res, fn(x, 1.0, Slots())) + + +class WithGetattribute: + # __slots__ = ("x", "_side_effects") + + def __init__(self, x): + object.__setattr__(self, "x", x) + object.__setattr__(self, "_side_effects", set()) + + def __getattribute__(self, name): + effects = object.__getattribute__(self, "_side_effects") + effects.add(name) + return object.__getattribute__(self, name) + + +class TestSlotsFromCPython(TestCase): + """Slot tests extracted from CPython's test_descr.py::test_slots.""" + + def setUp(self): + super().setUp() + self._u_prev = torch._dynamo.config.enable_trace_unittest + torch._dynamo.config.enable_trace_unittest = True + + def tearDown(self): + super().tearDown() + torch._dynamo.config.enable_trace_unittest = self._u_prev + + def test_slots_empty(self): + class C: + __slots__ = [] + + def fn(t): + x = C() + self.assertFalse(hasattr(x, "__dict__")) + self.assertFalse(hasattr(x, "foo")) + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_single(self): + class C: + __slots__ = ["a"] + + def fn(t): + x = C() + self.assertFalse(hasattr(x, "__dict__")) + self.assertFalse(hasattr(x, "a")) + x.a = 1 + self.assertEqual(x.a, 1) + x.a = None + self.assertEqual(x.a, None) + del x.a + self.assertFalse(hasattr(x, "a")) + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_multiple(self): + class C: + __slots__ = ["a", "b", "c"] + + def fn(t): + x = C() + self.assertFalse(hasattr(x, "__dict__")) + self.assertFalse(hasattr(x, "a")) + self.assertFalse(hasattr(x, "b")) + self.assertFalse(hasattr(x, "c")) + x.a = 1 + x.b = 2 + x.c = 3 + self.assertEqual(x.a, 1) + self.assertEqual(x.b, 2) + self.assertEqual(x.c, 3) + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_name_mangling(self): + class C: + __slots__ = ["__a"] + + def __init__(self, value): + self.__a = value + + def get(self): + return self.__a + + def fn(t): + x = C(5) + self.assertFalse(hasattr(x, "__dict__")) + self.assertFalse(hasattr(x, "__a")) + self.assertEqual(x.get(), 5) + try: + x.__a = 6 + except AttributeError: + pass + else: + self.fail("Double underscored names not mangled") + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_string_not_expanded(self): + # A single string is not expanded as a sequence + class C: + __slots__ = "abc" # noqa: PLC0205 + + def fn(t): + c = C() + c.abc = 5 + self.assertEqual(c.abc, 5) + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_tuple(self): + slots = ("foo", "bar") + + class C: + __slots__ = slots + + def fn(t): + x = C() + x.foo = 5 + self.assertEqual(x.foo, 5) + self.assertIs(type(slots[0]), str) + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_get_unset_raises(self): + class X: + __slots__ = "a" # noqa: PLC0205 + + def fn(t): + with self.assertRaises(AttributeError): + X().a + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_str_subclass(self): + # gh-98783: string subclass in __slots__ + class SubStr(str): # noqa: SLOT000 + pass + + class X: + __slots__ = (SubStr("x"),) + + def fn(t): + X().x = 1 + with self.assertRaises(AttributeError): + X().a + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_special_dict(self): + # __dict__ in __slots__ enables arbitrary attr assignment + class D: + __slots__ = ["__dict__"] + + def fn(t): + a = D() + self.assertTrue(hasattr(a, "__dict__")) + self.assertFalse(hasattr(a, "__weakref__")) + a.foo = 42 + self.assertEqual(a.__dict__, {"foo": 42}) + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_special_weakref(self): + # __weakref__ in __slots__ — no __dict__, arbitrary attr raises + class W: + __slots__ = ["__weakref__"] + + def fn(t): + a = W() + self.assertTrue(hasattr(a, "__weakref__")) + self.assertFalse(hasattr(a, "__dict__")) + with self.assertRaises(AttributeError): + a.foo = 42 + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_special_inherit_dict_weakref(self): + # Inheriting from both __dict__ and __weakref__ slot classes + class D: + __slots__ = ["__dict__"] + + class W: + __slots__ = ["__weakref__"] + + class C1(W, D): + __slots__ = [] + + def fn(t): + a = C1() + self.assertTrue(hasattr(a, "__dict__")) + self.assertTrue(hasattr(a, "__weakref__")) + a.foo = 42 + self.assertEqual(a.__dict__, {"foo": 42}) + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + @unittest.expectedFailure + def test_slots_special2_classcell(self): + # Testing __classcell__ in __slots__ + class Meta(type): + def __new__(metacls, name, bases, namespace, attr): + self.assertIn(attr, namespace) + return super().__new__(metacls, name, bases, namespace) + + class C1: + def __init__(self): + self.b = 42 + + class C2(C1, metaclass=Meta, attr="__classcell__"): + __slots__ = ["__classcell__"] + + def __init__(self): + super().__init__() + + def fn(t): + self.assertIsInstance( + C2.__dict__["__classcell__"], types.MemberDescriptorType + ) + c = C2() + self.assertEqual(c.b, 42) + self.assertFalse(hasattr(c, "__classcell__")) + c.__classcell__ = 42 + self.assertEqual(c.__classcell__, 42) + with self.assertRaises(TypeError): + + class C3: + __classcell__ = 42 + __slots__ = ["__classcell__"] + + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_slots_multiple_inheritance(self): + # SF bug 575229: multiple inheritance w/ slots dumps core + class A: + __slots__ = () + + class B: + pass + + class C(A, B): + __slots__ = () + + def fn(t): + self.assertTrue(hasattr(C, "__dict__")) + self.assertTrue(hasattr(C, "__weakref__")) + C().x = 2 + return t.sin() + + dynamo_testing.standard_test(self, fn, nargs=1) + + +class TestUserDefinedClassDict(TestCase): + def test_class_dict_read(self): + class MyClass: + x = 3 + + def fn(t): + t = t + MyClass.__dict__["x"] + t = t + MyClass.__dict__.get("x", 0) + t = t + MyClass.__dict__.get("z", 99) + t = t + (1 if "x" in MyClass.__dict__ else 0) + t = t + (1 if "z" in MyClass.__dict__ else 0) + return t + + dynamo_testing.standard_test(self, fn, nargs=1) + + def test_class_dict_via_arg(self): + class MyClass: + x = 7 + + def fn(t, cls): + return t + cls.__dict__.get("x", 0) + + cnt = dynamo_testing.CompileCounter() + compiled = torch.compile(fn, backend=cnt) + result = compiled(torch.tensor([0.0]), MyClass) + self.assertEqual(result, torch.tensor([7.0])) + + def test_class_dict_mutation_recompiles(self): + # Mutating a class attribute between calls should trigger recompilation, + # and the compiled function should see the updated value. + class MyClass: + x = 1 + + def fn(t): + return t + MyClass.__dict__["x"] + + cnt = dynamo_testing.CompileCounter() + compiled = torch.compile(fn, backend=cnt) + + result1 = compiled(torch.tensor([0.0])) + self.assertEqual(result1, torch.tensor([1.0])) + self.assertEqual(cnt.frame_count, 1) + + MyClass.x = 10 + result2 = compiled(torch.tensor([0.0])) + self.assertEqual(result2, torch.tensor([10.0])) + # Should have recompiled due to guard failure + self.assertEqual(cnt.frame_count, 2) + + def test_class_dict_add_key_recompiles(self): + # Adding a new attribute to the class should trigger recompilation + # when the compiled code checks for key presence. + class MyClass: + x = 1 + + def fn(t): + return t + (1 if "y" in MyClass.__dict__ else 0) + + cnt = dynamo_testing.CompileCounter() + compiled = torch.compile(fn, backend=cnt) + + result1 = compiled(torch.tensor([0.0])) + self.assertEqual(result1, torch.tensor([0.0])) + self.assertEqual(cnt.frame_count, 1) + + MyClass.y = 99 + result2 = compiled(torch.tensor([0.0])) + self.assertEqual(result2, torch.tensor([1.0])) + # Should have recompiled + self.assertEqual(cnt.frame_count, 2) + + def test_class_dict_delete_key_recompiles(self): + # Deleting a class attribute should trigger recompilation. + class MyClass: + x = 5 + y = 10 + + def fn(t): + return t + MyClass.__dict__.get("y", 0) + + cnt = dynamo_testing.CompileCounter() + compiled = torch.compile(fn, backend=cnt) + + result1 = compiled(torch.tensor([0.0])) + self.assertEqual(result1, torch.tensor([10.0])) + self.assertEqual(cnt.frame_count, 1) + + del MyClass.y + result2 = compiled(torch.tensor([0.0])) + self.assertEqual(result2, torch.tensor([0.0])) + # Should have recompiled + self.assertEqual(cnt.frame_count, 2) + + +if __name__ == "__main__": + run_tests() diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index 2cd19fb3334d6..ca4f967f9a9f6 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -1,5 +1,6 @@ # Owner(s): ["module: dynamo"] import dataclasses +import json import os import pprint import sys @@ -84,7 +85,6 @@ def test_larger_multiplier_for_even_smaller_tensor(self): @dynamo_config.patch( { "log_compilation_metrics": True, - "inline_inbuilt_nn_modules": False, } ) def test_graph_break_counting(self): @@ -194,7 +194,8 @@ def fn(x): @torch.compile(backend=my_backend) def fn(x): z = x + 1 - y = break_it(z) + with torch._dynamo.disable_nested_graph_breaks(): + y = break_it(z) return y * 2 x = torch.randn(3) @@ -298,6 +299,14 @@ def test_reinplace_counters_use_trigger_name_not_enum_value(self): "Should not use enum value (integer) in key, should use trigger.name instead", ) + def test_get_dynamo_config_for_logging_ignores_logging_functions(self): + with dynamo_config.patch(ignore_logging_functions={print}): + result = utils._get_dynamo_config_for_logging() + parsed = json.loads(result) + + self.assertIsInstance(parsed, dict) + self.assertNotIn("ignore_logging_functions", parsed) + class TestModel(torch.nn.Module): def __init__(self): @@ -445,7 +454,6 @@ def backward(grad_output): @dynamo_config.patch( { "log_compilation_metrics": True, - "inline_inbuilt_nn_modules": False, } ) @inductor_config.patch( @@ -598,7 +606,7 @@ def filter_expected(s: str) -> str: 'pass.pre_grad_passes.apply_gumbel_max_trick_pass': [0.0], 'pass.pre_grad_passes.efficient_conv_bn_eval_pass': [0.0], 'pass.pre_grad_passes.group_batch_fusion_passes': [0.0]}""" - ), # noqa: B950 + ), ) # Now validate utils.calculate_time_spent(). Formatting the return @@ -631,7 +639,7 @@ def filter_expected(s: str) -> str: 'gc': 0.0, 'inductor_compile': 0.0, 'total_wall_time': 0.0}""" - ), # noqa: B950 + ), ) # Now validate the CompilationMetrics logs. We expect a log for the @@ -688,7 +696,7 @@ def filter_expected(s: str) -> str: 'compile_time_autotune_time_us': None, 'compiler_config': None, 'compliant_custom_ops': set(), - 'config_inline_inbuilt_nn_modules': False, + 'config_inline_inbuilt_nn_modules': True, 'config_suppress_errors': False, 'cuda_version': None, 'cudagraph_skip_reason': None, @@ -707,11 +715,11 @@ def filter_expected(s: str) -> str: 'fail_user_frame_lineno': None, 'frame_key': '1', 'gc_time_us': 0, - 'graph_input_count': 1, - 'graph_node_count': 3, + 'graph_input_count': 3, + 'graph_node_count': 5, 'graph_node_shapes': None, 'graph_op_count': 1, - 'guard_count': 10, + 'guard_count': 31, 'has_guarded_code': True, 'inductor_code_gen_cumulative_compile_time_us': 0, 'inductor_compile_time_s': 0.0, @@ -781,7 +789,7 @@ def filter_expected(s: str) -> str: 'compile_time_autotune_time_us': None, 'compiler_config': None, 'compliant_custom_ops': set(), - 'config_inline_inbuilt_nn_modules': False, + 'config_inline_inbuilt_nn_modules': True, 'config_suppress_errors': False, 'cuda_version': None, 'cudagraph_skip_reason': None, @@ -800,11 +808,11 @@ def filter_expected(s: str) -> str: 'fail_user_frame_lineno': None, 'frame_key': '1', 'gc_time_us': 0, - 'graph_input_count': 1, - 'graph_node_count': 3, + 'graph_input_count': 3, + 'graph_node_count': 5, 'graph_node_shapes': None, 'graph_op_count': 1, - 'guard_count': 10, + 'guard_count': 31, 'has_guarded_code': True, 'inductor_code_gen_cumulative_compile_time_us': 0, 'inductor_compile_time_s': 0.0, @@ -855,7 +863,7 @@ def filter_expected(s: str) -> str: 'triton_compile_time_us': 0, 'triton_kernel_compile_times_us': None, 'triton_version': None}""" - ), # noqa: B950 + ), ) # Second event is for the backward @@ -888,7 +896,7 @@ def filter_expected(s: str) -> str: 'compile_time_autotune_time_us': None, 'compiler_config': None, 'compliant_custom_ops': None, - 'config_inline_inbuilt_nn_modules': False, + 'config_inline_inbuilt_nn_modules': True, 'config_suppress_errors': False, 'cuda_version': None, 'cudagraph_skip_reason': None, @@ -981,7 +989,7 @@ def filter_expected(s: str) -> str: 'compile_time_autotune_time_us': None, 'compiler_config': None, 'compliant_custom_ops': None, - 'config_inline_inbuilt_nn_modules': False, + 'config_inline_inbuilt_nn_modules': True, 'config_suppress_errors': False, 'cuda_version': None, 'cudagraph_skip_reason': None, @@ -1055,7 +1063,7 @@ def filter_expected(s: str) -> str: 'triton_compile_time_us': 0, 'triton_kernel_compile_times_us': None, 'triton_version': None}""" - ), # noqa: B950 + ), ) @dynamo_config.patch( diff --git a/test/dynamo/test_wrap_inductor_compiled_regions.py b/test/dynamo/test_wrap_inductor_compiled_regions.py index ffe364ded4671..0f2d335adfe30 100644 --- a/test/dynamo/test_wrap_inductor_compiled_regions.py +++ b/test/dynamo/test_wrap_inductor_compiled_regions.py @@ -99,6 +99,33 @@ def fn(x, y): expected = torch.matmul(x, y) self.assertEqual(result, expected) + @requires_cuda_and_triton + def test_wrap_name_visible_in_debug_mode(self): + """Test that named compiled regions surface their name in DebugMode""" + + @torch.compile( + backend="inductor", + options={"wrap_inductor_compiled_regions": True}, + fullgraph=True, + name="flex_attention", + ) + def fn(x, y): + return torch.matmul(x, y) + + x = torch.randn(4, 4, device="cuda") + y = torch.randn(4, 4, device="cuda") + + with DebugMode() as debug_mode: + result = fn(x, y) + + debug_string = debug_mode.debug_string() + + self.assertIn("inductor_compiled_code", debug_string) + self.assertIn("name=flex_attention", debug_string) + + expected = torch.matmul(x, y) + self.assertEqual(result, expected) + @requires_cuda_and_triton def test_wrap_disabled_not_visible_in_debug_mode(self): """Test that compiled regions are not wrapped when option is disabled""" @@ -922,6 +949,60 @@ def checkpointed_fn(x, y): self.assertEqual(x.grad, x_eager.grad) self.assertEqual(y.grad, y_eager.grad) + @requires_cuda_and_triton + def test_sac_outer_compile_inner_name_visible_to_policy(self): + """Test that SAC policies can inspect torch.compile region names""" + + @torch.compile( + backend="inductor", + options={"wrap_inductor_compiled_regions": True}, + fullgraph=True, + name="flex_attention", + ) + def inner_compiled_matmul(x, y): + return torch.matmul(x, y) + + seen_region_names = [] + + def policy_fn(ctx, op, *args, **kwargs): + from torch._higher_order_ops.wrap import inductor_compiled_code + + if op == inductor_compiled_code: + seen_region_names.append(kwargs.get("name")) + return CheckpointPolicy.PREFER_RECOMPUTE + + def checkpointed_fn(x, y): + a = inner_compiled_matmul(x, y) + return torch.relu(a) + + x = torch.randn(4, 4, device="cuda", requires_grad=True) + y = torch.randn(4, 4, device="cuda", requires_grad=True) + + x_eager = x.detach().clone().requires_grad_(True) + y_eager = y.detach().clone().requires_grad_(True) + + context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + + output = checkpoint( + checkpointed_fn, + x, + y, + use_reentrant=False, + context_fn=context_fn, + ) + loss = output.sum() + loss.backward() + + a_eager = torch.matmul(x_eager, y_eager) + b_eager = torch.relu(a_eager) + loss_eager = b_eager.sum() + loss_eager.backward() + + self.assertIn("flex_attention", seen_region_names) + self.assertEqual(output, b_eager) + self.assertEqual(x.grad, x_eager.grad) + self.assertEqual(y.grad, y_eager.grad) + @requires_cuda_and_triton def test_wrap_no_dispatch_mode_no_hop_invoked(self): """ diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_interface_multi_arg b/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_multiset_operations_equivalent_to_set_operations similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_interface_multi_arg rename to test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_multiset_operations_equivalent_to_set_operations diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_interface_no_arg b/test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_union similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_interface_no_arg rename to test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_union diff --git a/test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_interface_single_arg b/test/dynamo_expected_failures/CPython313-test_descr-AAAPTypesLongInitTest.test_pytype_long_ready similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_baseexception-ExceptionClassTests.test_interface_single_arg rename to test/dynamo_expected_failures/CPython313-test_descr-AAAPTypesLongInitTest.test_pytype_long_ready diff --git a/test/dynamo_expected_failures/CPython313-test_bool-BoolTest.test_bool_new b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_abstractmethods similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_bool-BoolTest.test_bool_new rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_abstractmethods diff --git a/test/dynamo_expected_failures/CPython313-test_bool-BoolTest.test_from_bytes b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_altmro similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_bool-BoolTest.test_from_bytes rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_altmro diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_copy_subclass b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_attr_raise_through_property similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_copy_subclass rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_attr_raise_through_property diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_repr_nonsortable b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_basic_inheritance similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_collections-TestCounter.test_repr_nonsortable rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_basic_inheritance diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestNamedTuple.test_name_conflicts b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_binary_operator_override similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_collections-TestNamedTuple.test_name_conflicts rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_binary_operator_override diff --git a/test/dynamo_expected_failures/CPython313-test_collections-TestUserObjects.test_list_copy b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_bound_method_repr similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_collections-TestUserObjects.test_list_copy rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_bound_method_repr diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_instance_docstring_given_cm_docstring b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_bpo25750 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_contextlib-ContextManagerTestCase.test_instance_docstring_given_cm_docstring rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_bpo25750 diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TestAbstractContextManager.test_slots b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_buffer_inheritance similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_contextlib-TestAbstractContextManager.test_slots rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_buffer_inheritance diff --git a/test/dynamo_expected_failures/CPython313-test_contextlib-TextExitStack.test_exit_exception_with_existing_context b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_builtin_bases similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_contextlib-TextExitStack.test_exit_exception_with_existing_context rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_builtin_bases diff --git a/test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_missing b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_builtin_function_or_method similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_missing rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_builtin_function_or_method diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictview_mixed_set_operations b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_carloverre similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictview_mixed_set_operations rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_carloverre diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictview_set_operations_on_items b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_carloverre_multi_inherit_invalid similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dictview_set_operations_on_items rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_carloverre_multi_inherit_invalid diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_instance_dict_getattr_str_subclass b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_carloverre_multi_inherit_valid similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_instance_dict_getattr_str_subclass rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_carloverre_multi_inherit_valid diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_keys_contained b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_classic similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_keys_contained rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_classic diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_missing b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_classmethods similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_missing rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_classmethods diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_tuple_keyerror b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_classmethods_in_c similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_tuple_keyerror rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_classmethods_in_c diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_bool b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_compattr similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_bool rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_compattr diff --git a/test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_copy_setstate similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_dict-SubclassMappingTests.test_read rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_copy_setstate diff --git a/test/dynamo_expected_failures/CPython313-test_enum-MiscTestCase.test__all__ b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_cycle_through_dict similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-MiscTestCase.test__all__ rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_cycle_through_dict diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_and b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_delete_hook similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_and rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_delete_hook diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_member_contains b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_dict_constructors similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_member_contains rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_dict_constructors diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_member_iter b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_dir similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_member_iter rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_dir diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_programatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_dynamics similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_programatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_dynamics diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_programatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_evil_type_name similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_programatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_evil_type_name diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_programatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_file_fault similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_programatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_file_fault diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_programatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_getattr_hooks similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_programatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_getattr_hooks diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_programatic_function_string_with_start b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_gh55664 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestFlag.test_programatic_function_string_with_start rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_gh55664 diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_and b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_hash_inheritance similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_and rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_hash_inheritance diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_global_repr_conform1 b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_init similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_global_repr_conform1 rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_init diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_global_repr_keep b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_instance_method_get_behavior similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_global_repr_keep rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_instance_method_get_behavior diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_invert b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_ipow_returns_not_implemented similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_invert rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_ipow_returns_not_implemented diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_member_iter b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_keyword_arguments similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_member_iter rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_keyword_arguments diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_keywords similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_keywords diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_from_empty_list b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_load_attr_extended_arg similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_from_empty_list rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_load_attr_extended_arg diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_from_empty_tuple b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_metaclass similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_from_empty_tuple rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_metaclass diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_meth_class_get similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_meth_class_get diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_method_wrapper similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_method_wrapper diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_methods similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_methods diff --git a/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_string_with_start b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_methods_in_c similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_programatic_function_string_with_start rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_methods_in_c diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestConvert.test_convert_raise b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_mixing_slot_wrappers similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestConvert.test_convert_raise rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_mixing_slot_wrappers diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestEmptyAndNonLatinStrings.test_non_latin_character_string b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_module_subclasses similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestEmptyAndNonLatinStrings.test_non_latin_character_string rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_module_subclasses diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestEmptyAndNonLatinStrings.test_non_latin_number_string b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_mutable_bases similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestEmptyAndNonLatinStrings.test_non_latin_number_string rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_mutable_bases diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_mutable_bases_catch_mro_conflict similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_mutable_bases_catch_mro_conflict diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_mutable_bases_with_failing_mro similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_mutable_bases_with_failing_mro diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_mutable_names similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_mutable_names diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_no_ipow similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_no_ipow diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_not_implemented similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_not_implemented diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_object_class similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumClass.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_object_class diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_object_class_assignment_between_heaptypes_and_nonheaptypes similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_object_class_assignment_between_heaptypes_and_nonheaptypes diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_object_new similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_object_new diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_object_new_and_init_with_parameters similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_object_new_and_init_with_parameters diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_overloading similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_overloading diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_properties similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_properties diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_properties_plus similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_properties_plus diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_shadowed_attr b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_proxy_call similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntEnumFunction.test_shadowed_attr rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_proxy_call diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_proxy_super similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_proxy_super diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_python_lists similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_python_lists diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_qualname similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_qualname diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_recursive_call similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_recursive_call diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_refleaks_in_classmethod___init__ similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_refleaks_in_classmethod___init__ diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_refleaks_in_staticmethod___init__ similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_refleaks_in_staticmethod___init__ diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_repr b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_remove_subclass similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagClass.test_repr rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_remove_subclass diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_repr_as_str similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_repr_as_str diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_repr_with_module_str_subclass similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_repr_with_module_str_subclass diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_rich_comparisons similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_rich_comparisons diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_set_and_no_get similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_set_and_no_get diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_set_class similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_set_class diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_set_dict similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_set_dict diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_repr b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_set_doc similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestIntFlagFunction.test_repr rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_set_doc diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalDateClass.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_slices similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalDateClass.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_slices diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalDateClass.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_slot_shadows_class_variable similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalDateClass.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_slot_shadows_class_variable diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalDateFunction.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_slots_descriptor similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalDateFunction.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_slots_descriptor diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalDateFunction.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_slots_special2 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalDateFunction.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_slots_special2 diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_slots_trash similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_slots_trash diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_special_method_lookup similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_special_method_lookup diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_special_unbound_method_types similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_special_unbound_method_types diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_specialized_method_calls_check_types similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_specialized_method_calls_check_types diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_specials similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_specials diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_staticmethods similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatClass.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_staticmethods diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_staticmethods_in_c similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_staticmethods_in_c diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_str_of_str_subclass similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_str_of_str_subclass diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_str_operations similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_str_operations diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_str_subclass_as_dict_key similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_str_subclass_as_dict_key diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_subclass_propagation similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_subclass_propagation diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_subclass_right_op similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMinimalFloatFunction.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_subclass_right_op diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedDateClass.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_subtype_resurrection similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedDateClass.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_subtype_resurrection diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedDateClass.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_supers similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedDateClass.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_supers diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedDateFunction.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_type___getattribute__ similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedDateFunction.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_type___getattribute__ diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedDateFunction.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_uninitialized_modules similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedDateFunction.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_uninitialized_modules diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatClass.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_vicious_descriptor_nonsense similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatClass.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_vicious_descriptor_nonsense diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatClass.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_weakref_segfault similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatClass.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_weakref_segfault diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatClass.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_weakrefs similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatClass.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_weakrefs diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatClass.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_wrapper_segfault similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatClass.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_wrapper_segfault diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatClass.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_wrong_class_slot_wrapper similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatClass.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-ClassPropertiesAndMethods.test_wrong_class_slot_wrapper diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatClass.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-DictProxyTests.test_repr similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatClass.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-DictProxyTests.test_repr diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatFunction.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-MiscTests.test_type_lookup_mro_reference similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatFunction.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-MiscTests.test_type_lookup_mro_reference diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatFunction.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_disappearing_custom_mro similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatFunction.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_disappearing_custom_mro diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatFunction.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_incomplete_extend similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatFunction.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_incomplete_extend diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatFunction.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_incomplete_set_bases_on_self similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatFunction.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_incomplete_set_bases_on_self diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatFunction.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_incomplete_super similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatFunction.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_incomplete_super diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatFunction.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_reent_set_bases_on_base similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedFloatFunction.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_reent_set_bases_on_base diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntClass.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_reent_set_bases_on_direct_base similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntClass.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_reent_set_bases_on_direct_base diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntClass.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_reent_set_bases_tp_base_cycle similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntClass.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_reent_set_bases_tp_base_cycle diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntClass.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_tp_subclasses_cycle_error_return_path similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntClass.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_tp_subclasses_cycle_error_return_path diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntClass.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_tp_subclasses_cycle_in_update_slots similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntClass.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-MroTest.test_tp_subclasses_cycle_in_update_slots diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntClass.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_complexes similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntClass.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_complexes diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntClass.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_dicts similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntClass.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_dicts diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_explicit_reverse_methods similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_explicit_reverse_methods diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_floats similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_floats diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_ints similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_ints diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_lists similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_lists diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_spam_dicts similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_spam_dicts diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_spam_lists similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-OperatorsTest.test_spam_lists diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_repr b/test/dynamo_expected_failures/CPython313-test_descr-PicklingTests.test_issue24097 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagClass.test_repr rename to test/dynamo_expected_failures/CPython313-test_descr-PicklingTests.test_issue24097 diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_descr-PicklingTests.test_object_reduce similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_descr-PicklingTests.test_object_reduce diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_descr-PicklingTests.test_pickle_slots similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_descr-PicklingTests.test_pickle_slots diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_descr-PicklingTests.test_reduce similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_descr-PicklingTests.test_reduce diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_descr-PicklingTests.test_reduce_copying similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_descr-PicklingTests.test_reduce_copying diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_descr-PicklingTests.test_special_method_lookup similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_descr-PicklingTests.test_special_method_lookup diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_descr-SharedKeyTests.test_subclasses similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_descr-SharedKeyTests.test_subclasses diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_repr b/test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_contains_tf similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFlagFunction.test_repr rename to test/dynamo_expected_failures/CPython313-test_enum-OldTestIntFlag.test_contains_tf diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_211 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_211 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_212 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_212 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_213 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_213 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_214 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_214 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_215 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_215 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_216 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_216 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_217 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_217 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_218 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_218 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_219 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_219 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_220 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_220 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_221 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_221 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_246 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_246 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_247 b/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_247 deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestCopyingEmpty.test_deep_copy b/test/dynamo_expected_failures/CPython313-test_set-TestCopyingEmpty.test_deep_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestCopyingNested.test_deep_copy b/test/dynamo_expected_failures/CPython313-test_set-TestCopyingNested.test_deep_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestCopyingSingleton.test_deep_copy b/test/dynamo_expected_failures/CPython313-test_set-TestCopyingSingleton.test_deep_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestCopyingTriple.test_deep_copy b/test/dynamo_expected_failures/CPython313-test_set-TestCopyingTriple.test_deep_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestCopyingTuple.test_deep_copy b/test/dynamo_expected_failures/CPython313-test_set-TestCopyingTuple.test_deep_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_intersection b/test/dynamo_expected_failures/CPython313-test_set-TestFrozenSetSubclass.test_intersection deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_intersection b/test/dynamo_expected_failures/CPython313-test_set-TestSetSubclass.test_intersection deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFunction.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_bad___prepare__ similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFunction.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_bad___prepare__ diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFunction.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_get_original_bases similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFunction.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_get_original_bases diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFunction.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_metaclass_derivation similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFunction.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_metaclass_derivation diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFunction.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_metaclass_override_callable similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFunction.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_metaclass_override_callable diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFunction.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_metaclass_override_function similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFunction.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_metaclass_override_function diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFunction.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_basics similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedIntFunction.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_basics diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_defaults similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_defaults diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_exec_body similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_exec_body diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_meta similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_meta diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_meta_with_base similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_meta_with_base diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_subclass similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_subclass diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_with_mro_entry similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_with_mro_entry diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_repr b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_with_mro_entry_genericalias similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrClass.test_repr rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_with_mro_entry_genericalias diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_with_mro_entry_multiple similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_with_mro_entry_multiple diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_with_mro_entry_multiple_2 similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_with_mro_entry_multiple_2 diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_with_mro_entry_none similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_new_class_with_mro_entry_none diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_one_argument_type similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_one_argument_type diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_prepare_class similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_prepare_class diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_resolve_bases_with_mro_entry similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_resolve_bases_with_mro_entry diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_repr b/test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_tuple_subclass_as_bases similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestMixedStrFunction.test_repr rename to test/dynamo_expected_failures/CPython313-test_types-ClassCreationTests.test_tuple_subclass_as_bases diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumClass.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_async_def similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumClass.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_async_def diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumClass.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_duck_coro similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumClass.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_duck_coro diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumClass.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_duck_corogen similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumClass.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_duck_corogen diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumClass.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_duck_functional_gen similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumClass.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_duck_functional_gen diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumClass.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_duck_gen similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumClass.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_duck_gen diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumFunction.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_gen similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumFunction.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_gen diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumFunction.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_genfunc similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumFunction.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_genfunc diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumFunction.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_non_gen_values similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumFunction.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_non_gen_values diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumFunction.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_returning_itercoro similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumFunction.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_returning_itercoro diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumFunction.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_wrapper_object similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainEnumFunction.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_wrapper_object diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagClass.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_wrong_args similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagClass.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_types-CoroutineTests.test_wrong_args diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagClass.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_types-FunctionTests.test_function_type_defaults similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagClass.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_types-FunctionTests.test_function_type_defaults diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagClass.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_types-FunctionTests.test_function_type_wrong_defaults similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagClass.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_types-FunctionTests.test_function_type_wrong_defaults diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagClass.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_chainmap similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagClass.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_chainmap diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagClass.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_constructor similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagClass.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_constructor diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagFunction.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_customdict similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagFunction.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_customdict diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagFunction.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_hash similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagFunction.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_hash diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagFunction.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_methods similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagFunction.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_methods diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagFunction.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_missing similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagFunction.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_missing diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagFunction.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_reversed similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestPlainFlagFunction.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_reversed diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestSpecial.test_programmatic_function_string_list_with_start b/test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_union similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestSpecial.test_programmatic_function_string_list_with_start rename to test/dynamo_expected_failures/CPython313-test_types-MappingProxyTests.test_union diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestSpecial.test_programmatic_function_string_with_start b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_as_dict similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestSpecial.test_programmatic_function_string_with_start rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_as_dict diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestSpecial.test_programmatic_function_type b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_attrdel similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestSpecial.test_programmatic_function_type rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_attrdel diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestSpecial.test_programmatic_function_type_from_subclass b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_attrget similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestSpecial.test_programmatic_function_type_from_subclass rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_attrget diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestSpecial.test_programmatic_function_type_from_subclass_with_start b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_attrset similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestSpecial.test_programmatic_function_type_from_subclass_with_start rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_attrset diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestSpecial.test_programmatic_function_type_with_start b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_constructor similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestSpecial.test_programmatic_function_type_with_start rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_constructor diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_equal similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_equal diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_fake_namespace_compare similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_fake_namespace_compare diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_nested similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_nested diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_pickle similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_pickle diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_recursive similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_recursive diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_recursive_repr similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_recursive_repr diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_repr b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_replace similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumClass.test_repr rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_replace diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_dir_on_class b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_replace_subclass similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_dir_on_class rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_replace_subclass diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_gnv_is_static b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_repr similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_gnv_is_static rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_repr diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_programmatic_function_from_dict b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_subclass similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_programmatic_function_from_dict rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_subclass diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_programmatic_function_iterable b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_unbound similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_programmatic_function_iterable rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_unbound diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_programmatic_function_string b/test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_underlying_dict similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_programmatic_function_string rename to test/dynamo_expected_failures/CPython313-test_types-SimpleNamespaceTests.test_underlying_dict diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_programmatic_function_string_list b/test/dynamo_expected_failures/CPython313-test_types-SubinterpreterTests.test_static_types_inherited_slots similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_programmatic_function_string_list rename to test/dynamo_expected_failures/CPython313-test_types-SubinterpreterTests.test_static_types_inherited_slots diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_repr b/test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_call_unbound_crash similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_repr rename to test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_call_unbound_crash diff --git a/test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_shadowed_attr b/test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_dunder_get_signature similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_enum-TestStrEnumFunction.test_shadowed_attr rename to test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_dunder_get_signature diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested b/test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_float__format__ similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested rename to test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_float__format__ diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested_else b/test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_float__format__locale similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested_else rename to test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_float__format__locale diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested_else_mixed2 b/test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_float_to_string similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested_else_mixed2 rename to test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_float_to_string diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested_mixed1 b/test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_format_spec_errors similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_nested_mixed1 rename to test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_format_spec_errors diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except b/test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_int__format__ similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except rename to test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_int__format__ diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except_else b/test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_int__format__locale similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except_else rename to test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_int__format__locale diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except_else_finally b/test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_method_descriptor_crash similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except_else_finally rename to test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_method_descriptor_crash diff --git a/test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except_finally b/test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_method_descriptor_types similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_exception_variations-ExceptStarTestCases.test_try_except_finally rename to test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_method_descriptor_types diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_raise_does_not_create_context_chain_cycle b/test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_method_wrapper_types similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_raise_does_not_create_context_chain_cycle rename to test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_method_wrapper_types diff --git a/test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_unicode_error_str_does_not_crash b/test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_traceback_and_frame_types similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_exceptions-ExceptionTests.test_unicode_error_str_does_not_crash rename to test/dynamo_expected_failures/CPython313-test_types-TypesTests.test_traceback_and_frame_types diff --git a/test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_error_message b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_bad_instancecheck similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_float-GeneralFloatCases.test_error_message rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_bad_instancecheck diff --git a/test/dynamo_expected_failures/CPython313-test_functools-TestPartialMethod.test_repr b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_bad_subclasscheck similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_functools-TestPartialMethod.test_repr rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_bad_subclasscheck diff --git a/test/dynamo_expected_failures/CPython313-test_functools-TestSingleDispatch.test_c3_abc b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_hash similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_functools-TestSingleDispatch.test_c3_abc rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_hash diff --git a/test/dynamo_expected_failures/CPython313-test_functools-TestTotalOrdering.test_total_ordering_no_overwrite b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_instancecheck_and_subclasscheck_order similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_functools-TestTotalOrdering.test_total_ordering_no_overwrite rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_instancecheck_and_subclasscheck_order diff --git a/test/dynamo_expected_failures/CPython313-test_heapq-TestErrorHandlingPython.test_iterable_args b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_reference_cycle similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_heapq-TestErrorHandlingPython.test_iterable_args rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_reference_cycle diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_disabled_limit b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_Alias similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_int-IntStrDigitLimitsTests.test_disabled_limit rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_Alias diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_disabled_limit b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_Literal similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_disabled_limit rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_Literal diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_max_str_digits_edge_cases b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_NewType similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_max_str_digits_edge_cases rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_NewType diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_power_of_two_bases_unlimited b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_SpecialForm similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_power_of_two_bases_unlimited rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_SpecialForm diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_sign_not_counted b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_TypeVar similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_int-IntSubclassStrDigitLimitsTests.test_sign_not_counted rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_TypeVar diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_subclass_with_index b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_bad_module similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_int_subclass_with_index rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_bad_module diff --git a/test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_intconversion b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_forward similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_int-IntTestCases.test_intconversion rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_forward diff --git a/test/dynamo_expected_failures/CPython313-test_math-MathTests.testRemainder b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_genericalias similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_math-MathTests.testRemainder rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_operator_with_genericalias diff --git a/test/dynamo_expected_failures/CPython313-test_operator-COperatorTestCase.test_dunder_is_original b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_repr similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_operator-COperatorTestCase.test_dunder_is_original rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_type_repr diff --git a/test/dynamo_expected_failures/CPython313-test_operator-PyOperatorTestCase.test_dunder_is_original b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_types_operator similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_operator-PyOperatorTestCase.test_dunder_is_original rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_or_types_operator diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_clear b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_args similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_clear rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_args diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_setdefault b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_copy similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_setdefault rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_copy diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_setitem b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_of_unhashable similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_setitem rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_of_unhashable diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_update b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_parameter_chaining similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_dict_update rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_parameter_chaining diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_bool b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_parameter_substitution similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_bool rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_parameter_substitution diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_read b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_parameter_substitution_errors similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonSubclassMappingTests.test_read rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_parameter_substitution_errors diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_add_after_full b/test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_pickle similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_add_after_full rename to test/dynamo_expected_failures/CPython313-test_types-UnionTests.test_union_pickle diff --git a/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_bool b/test/dynamo_expected_failures/CPython313-test_userdict-UserDictTest.test_bool deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_init b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_init deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_userlist_copy b/test/dynamo_expected_failures/CPython313-test_userlist-UserListTest.test_userlist_copy deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_popitem b/test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_ordered_dict-CSimpleLRUCacheTests.test_popitem rename to test/dynamo_expected_failures/TestCompileTransformsCPU.test_compile_vmap_hessian_cpu diff --git a/test/dynamo_expected_failures/TestComposabilityCPU.test_requires_grad_inside_transform_cpu b/test/dynamo_expected_failures/TestComposabilityCPU.test_requires_grad_inside_transform_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestComposabilityCUDA.test_requires_grad_inside_transform_cuda b/test/dynamo_expected_failures/TestComposabilityCUDA.test_requires_grad_inside_transform_cuda deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestInheritance.test_late_registration_mapping b/test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_patma-TestInheritance.test_late_registration_mapping rename to test/dynamo_expected_failures/TestNN.test_module_apply_inplace_op diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestInheritance.test_late_registration_sequence b/test/dynamo_expected_failures/TestNN.test_to similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_patma-TestInheritance.test_late_registration_sequence rename to test/dynamo_expected_failures/TestNN.test_to diff --git a/test/dynamo_expected_failures/TestTestParametrizationDeviceTypeCPU.test_modules_decorator_applies_module_and_param_specific_decorators_cpu b/test/dynamo_expected_failures/TestTestParametrizationDeviceTypeCPU.test_modules_decorator_applies_module_and_param_specific_decorators_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestTestParametrizationDeviceTypeCPU.test_ops_composition_names_cpu b/test/dynamo_expected_failures/TestTestParametrizationDeviceTypeCPU.test_ops_composition_names_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestTestParametrizationDeviceTypeCPU.test_ops_decorator_applies_op_and_param_specific_decorators_cpu b/test/dynamo_expected_failures/TestTestParametrizationDeviceTypeCPU.test_ops_decorator_applies_op_and_param_specific_decorators_cpu deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestTorch.test_print b/test/dynamo_expected_failures/TestTorch.test_print deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_204 b/test/dynamo_skips/TestLoadStateDict.test_load_state_dict_assign_shape_stride_swap_True similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_204 rename to test/dynamo_skips/TestLoadStateDict.test_load_state_dict_assign_shape_stride_swap_True diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_205 b/test/dynamo_skips/TestLoadStateDict.test_load_state_dict_custom_swap_True similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_205 rename to test/dynamo_skips/TestLoadStateDict.test_load_state_dict_custom_swap_True diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_206 b/test/dynamo_skips/TestNNParametrization.test_weight_norm_state_dict_compat_swap_True similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_206 rename to test/dynamo_skips/TestNNParametrization.test_weight_norm_state_dict_compat_swap_True diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_207 b/test/dynamo_skips/TestStateDictHooks.test_load_state_dict_post_hook_swap_True similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_207 rename to test/dynamo_skips/TestStateDictHooks.test_load_state_dict_post_hook_swap_True diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect index 103cb407e078b..63d248c5a1420 100644 --- a/test/expect/HasDecompTest.test_aten_core_operators.expect +++ b/test/expect/HasDecompTest.test_aten_core_operators.expect @@ -351,6 +351,8 @@ aten::lt.Tensor aten::lt.Tensor_out aten::lt_.Scalar aten::lt_.Tensor +aten::max_pool2d_with_indices_backward +aten::max_pool2d_with_indices_backward.grad_input aten::maximum aten::maximum.out aten::mean diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 29997570218d2..61882630021d8 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -197,6 +197,8 @@ aten::_foreach_clamp_min.Scalar_out aten::_foreach_clamp_min_.List aten::_foreach_clamp_min_.Scalar aten::_foreach_clamp_min_.ScalarList +aten::_foreach_clone +aten::_foreach_clone.out aten::_foreach_copy aten::_foreach_copy.out aten::_foreach_copy_ @@ -391,6 +393,7 @@ aten::_fw_primal_copy aten::_fw_primal_copy.out aten::_grid_sampler_2d_cpu_fallback aten::_grid_sampler_2d_cpu_fallback.out +aten::_grid_sampler_2d_cpu_fallback_backward aten::_grouped_mm aten::_has_same_storage_numel aten::_histogramdd_bin_edges @@ -507,6 +510,14 @@ aten::_pdist_backward aten::_pdist_backward.out aten::_pdist_forward aten::_pdist_forward.out +aten::_philox_key_fold_in +aten::_philox_key_split +aten::_philox_normal +aten::_philox_normal.out +aten::_philox_normal_ +aten::_philox_uniform +aten::_philox_uniform.out +aten::_philox_uniform_ aten::_pin_memory aten::_pin_memory.out aten::_reshape_alias_copy @@ -649,6 +660,10 @@ aten::_upsample_bilinear2d_aa aten::_upsample_bilinear2d_aa.out aten::_upsample_bilinear2d_aa_backward aten::_upsample_bilinear2d_aa_backward.grad_input +aten::_upsample_lanczos2d_aa +aten::_upsample_lanczos2d_aa.out +aten::_upsample_lanczos2d_aa_backward +aten::_upsample_lanczos2d_aa_backward.grad_input aten::_upsample_nearest_exact1d_backward aten::_upsample_nearest_exact1d_backward.grad_input aten::_upsample_nearest_exact2d_backward @@ -857,10 +872,6 @@ aten::hamming_window.periodic_alpha_beta aten::hamming_window.periodic_alpha_beta_out aten::hamming_window.periodic_alpha_out aten::hamming_window.periodic_out -aten::hann_window -aten::hann_window.out -aten::hann_window.periodic -aten::hann_window.periodic_out aten::hardshrink_backward aten::hardshrink_backward.grad_input aten::hash_tensor @@ -954,8 +965,6 @@ aten::max_pool2d_backward aten::max_pool2d_backward.out aten::max_pool2d_with_indices aten::max_pool2d_with_indices.out -aten::max_pool2d_with_indices_backward -aten::max_pool2d_with_indices_backward.grad_input aten::max_pool3d_with_indices aten::max_pool3d_with_indices.out aten::max_pool3d_with_indices_backward diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index 985b353f759d4..b0c6550eca753 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -1,4 +1,4 @@ -torch.fx._symbolic_trace.Tracer.__init__(self, autowrap_modules: Tuple[Callable] = (,), autowrap_functions: Tuple[Callable, ...] = (,), param_shapes_constant: bool = False) -> None +torch.fx._symbolic_trace.Tracer.__init__(self, autowrap_modules: Tuple[Callable] = (,), autowrap_functions: Tuple[Callable[..., Any], ...] = (,), param_shapes_constant: bool = False) -> None torch.fx._symbolic_trace.Tracer.call_module(self, m: torch.nn.modules.module.Module, forward: Callable[..., Any], args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any torch.fx._symbolic_trace.Tracer.create_arg(self, a: Any) -> 'Argument' torch.fx._symbolic_trace.Tracer.get_fresh_qualname(self, prefix: str) -> str @@ -6,45 +6,45 @@ torch.fx._symbolic_trace.Tracer.is_leaf_module(self, m: torch.nn.modules.module. torch.fx._symbolic_trace.Tracer.path_of_module(self, mod: torch.nn.modules.module.Module) -> str torch.fx._symbolic_trace.Tracer.trace(self, root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph.Graph torch.fx._symbolic_trace.symbolic_trace(root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph_module.GraphModule -torch.fx._symbolic_trace.wrap(fn_or_name: Union[str, Callable]) -torch.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None, tracer_extras: Optional[Dict[str, Any]] = None) -torch.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None, name: Optional[str] = None) -> torch.fx.node.Node -torch.fx.graph.Graph.call_method(self, method_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node -torch.fx.graph.Graph.call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node -torch.fx.graph.Graph.create_node(self, op: str, target: 'Target', args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, name: Optional[str] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node -torch.fx.graph.Graph.eliminate_dead_code(self, is_impure_node: Optional[Callable[[torch.fx.node.Node], bool]] = None) -> bool -torch.fx.graph.Graph.erase_node(self, to_erase: torch.fx.node.Node) -> None -torch.fx.graph.Graph.get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> torch.fx.node.Node -torch.fx.graph.Graph.graph_copy(self, g: 'Graph', val_map: Dict[torch.fx.node.Node, torch.fx.node.Node], return_output_node = False) -> 'Optional[Argument]' -torch.fx.graph.Graph.inserting_after(self, n: Optional[torch.fx.node.Node] = None) -torch.fx.graph.Graph.inserting_before(self, n: Optional[torch.fx.node.Node] = None) -torch.fx.graph.Graph.lint(self) -torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Callable[[torch.fx.node.Node], Argument] = >) -> torch.fx.node.Node -torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) -torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node -torch.fx.graph.Graph.print_tabular(self) -torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False, additional_meta: Optional[List[str]] = None) -> torch.fx.graph.PythonCode -torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') -torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool -torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None -torch.fx.graph_module.GraphModule.delete_submodule(self, target: str) -> bool -torch.fx.graph_module.GraphModule.recompile(self) -> torch.fx.graph.PythonCode -torch.fx.graph_module.reduce_graph_module(body: Dict[Any, Any], import_block: str) -> torch.nn.modules.module.Module -torch.fx.graph_module.reduce_package_graph_module(importer: Callable, body: Dict[Any, Any], generated_module_name: str) -> torch.nn.modules.module.Module -torch.fx.interpreter.Interpreter.__init__(self, module: torch.nn.modules.module.Module, garbage_collect_values: bool = True, graph: Optional[torch.fx.graph.Graph] = None) -torch.fx.interpreter.Interpreter.boxed_run(self, args_list) +torch.fx._symbolic_trace.wrap(fn_or_name: torch.fx.node.Target) -> torch.fx.node.Target +torch.fx.graph.Graph.__init__(self, owning_module: 'Optional[GraphModule]' = None, tracer_cls: 'Optional[type[Tracer]]' = None, tracer_extras: 'Optional[dict[str, Any]]' = None) -> 'None' +torch.fx.graph.Graph.call_function(self, the_function: 'Callable[..., Any]', args: 'Optional[tuple[Argument, ...]]' = None, kwargs: 'Optional[dict[str, Argument]]' = None, type_expr: 'Optional[Any]' = None, name: 'Optional[str]' = None) -> 'Node' +torch.fx.graph.Graph.call_method(self, method_name: 'str', args: 'Optional[tuple[Argument, ...]]' = None, kwargs: 'Optional[dict[str, Argument]]' = None, type_expr: 'Optional[Any]' = None) -> 'Node' +torch.fx.graph.Graph.call_module(self, module_name: 'str', args: 'Optional[tuple[Argument, ...]]' = None, kwargs: 'Optional[dict[str, Argument]]' = None, type_expr: 'Optional[Any]' = None) -> 'Node' +torch.fx.graph.Graph.create_node(self, op: 'str', target: 'Target', args: 'Optional[tuple[Argument, ...]]' = None, kwargs: 'Optional[dict[str, Argument]]' = None, name: 'Optional[str]' = None, type_expr: 'Optional[Any]' = None) -> 'Node' +torch.fx.graph.Graph.eliminate_dead_code(self, is_impure_node: 'Optional[Callable[[Node], bool]]' = None) -> 'bool' +torch.fx.graph.Graph.erase_node(self, to_erase: 'Node') -> 'None' +torch.fx.graph.Graph.get_attr(self, qualified_name: 'str', type_expr: 'Optional[Any]' = None) -> 'Node' +torch.fx.graph.Graph.graph_copy(self, g: 'Graph', val_map: 'dict[Node, Node]', return_output_node: 'bool' = False) -> 'Optional[Argument]' +torch.fx.graph.Graph.inserting_after(self, n: 'Optional[Node]' = None) -> '_InsertPoint' +torch.fx.graph.Graph.inserting_before(self, n: 'Optional[Node]' = None) -> '_InsertPoint' +torch.fx.graph.Graph.lint(self) -> 'None' +torch.fx.graph.Graph.node_copy(self, node: 'Node', arg_transform: 'Callable[[Node], Argument]' = >) -> 'Node' +torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: 'Optional[Any]' = None) +torch.fx.graph.Graph.placeholder(self, name: 'str', type_expr: 'Optional[Any]' = None, default_value: 'Any') -> 'Node' +torch.fx.graph.Graph.print_tabular(self) -> 'None' +torch.fx.graph.Graph.python_code(self, root_module: 'str', verbose: 'bool' = False, include_stride: 'bool' = False, include_device: 'bool' = False, colored: 'bool' = False, expanded_def: 'bool' = False, record_func: 'bool' = False, additional_meta: 'Optional[list[str]]' = None) -> 'PythonCode' +torch.fx.graph_module.GraphModule.__init__(self, root: 'torch.nn.Module | dict[str, Any]', graph: 'Graph', class_name: 'str' = 'GraphModule') -> 'None' +torch.fx.graph_module.GraphModule.add_submodule(self, target: 'str', m: 'torch.nn.Module') -> 'bool' +torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> 'None' +torch.fx.graph_module.GraphModule.delete_submodule(self, target: 'str') -> 'bool' +torch.fx.graph_module.GraphModule.recompile(self) -> 'PythonCode' +torch.fx.graph_module.reduce_graph_module(body: 'dict[str, Any]', import_block: 'str') -> 'torch.nn.Module' +torch.fx.graph_module.reduce_package_graph_module(importer: 'PackageImporter', body: 'dict[str, Any]', generated_module_name: 'str') -> 'torch.nn.Module' +torch.fx.interpreter.Interpreter.__init__(self, module: torch.nn.modules.module.Module, garbage_collect_values: bool = True, graph: Optional[torch.fx.graph.Graph] = None) -> None +torch.fx.interpreter.Interpreter.boxed_run(self, args_list: List[Any]) -> Any torch.fx.interpreter.Interpreter.call_function(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any torch.fx.interpreter.Interpreter.call_method(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any torch.fx.interpreter.Interpreter.call_module(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -torch.fx.interpreter.Interpreter.fetch_args_kwargs_from_env(self, n: torch.fx.node.Node) -> Tuple[Tuple, Dict] -torch.fx.interpreter.Interpreter.fetch_attr(self, target: str) +torch.fx.interpreter.Interpreter.fetch_args_kwargs_from_env(self, n: torch.fx.node.Node) -> Tuple[Tuple[Any, ...], Dict[str, Any]] +torch.fx.interpreter.Interpreter.fetch_attr(self, target: str) -> Any torch.fx.interpreter.Interpreter.get_attr(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any torch.fx.interpreter.Interpreter.map_nodes_to_values(self, args: torch.fx.node.Argument, n: torch.fx.node.Node) -> torch.fx.node.Argument torch.fx.interpreter.Interpreter.output(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any torch.fx.interpreter.Interpreter.placeholder(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any -torch.fx.interpreter.Interpreter.run(self, *args, initial_env: Optional[Dict[torch.fx.node.Node, Any]] = None, enable_io_processing: bool = True) -> Any +torch.fx.interpreter.Interpreter.run(self, *args: Any, initial_env: Optional[Dict[torch.fx.node.Node, Any]] = None, enable_io_processing: bool = True) -> Any torch.fx.interpreter.Interpreter.run_node(self, n: torch.fx.node.Node) -> Any -torch.fx.interpreter.Transformer.__init__(self, module) +torch.fx.interpreter.Transformer.__init__(self, module: torch.fx.graph_module.GraphModule) -> None torch.fx.interpreter.Transformer.call_function(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any torch.fx.interpreter.Transformer.call_module(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> Any torch.fx.interpreter.Transformer.get_attr(self, target: 'Target', args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, Any]) -> torch.fx.proxy.Proxy @@ -63,15 +63,15 @@ torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.no torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument torch.fx.passes.reinplace.reinplace(gm, *sample_args) torch.fx.passes.runtime_assert.insert_deferred_runtime_asserts(gm: torch.fx.graph_module.GraphModule, shape_env: Any, name: str, export: bool = False) -> None -torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False, keep_original_input_name: bool = True, partition_affix: Optional[str] = None) -torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str) -torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None) -torch.fx.proxy.Proxy.keys(self) +torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False, keep_original_input_name: bool = True, partition_affix: Optional[str] = None, tuple_return: bool = False) -> torch.fx.graph_module.GraphModule +torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str) -> None +torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None) -> None +torch.fx.proxy.Proxy.keys(self) -> 'Proxy' torch.fx.proxy.TracerBase.create_arg(self, a: Any) -> torch.fx.node.Argument torch.fx.proxy.TracerBase.create_node(self, kind: str, target: torch.fx.node.Target, args: Tuple[torch.fx.node.Argument, ...], kwargs: Dict[str, torch.fx.node.Argument], name: Optional[str] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node -torch.fx.proxy.TracerBase.create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None, proxy_factory_fn: Callable[[torch.fx.node.Node], Proxy] = None) +torch.fx.proxy.TracerBase.create_proxy(self, kind: str, target: torch.fx.node.Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], name: Optional[str] = None, type_expr: Optional[Any] = None, proxy_factory_fn: Optional[Callable[[torch.fx.node.Node], Proxy]] = None) -> 'Proxy' torch.fx.proxy.TracerBase.iter(self, obj: 'Proxy') -> Iterator -torch.fx.proxy.TracerBase.keys(self, obj: 'Proxy') -> Any +torch.fx.proxy.TracerBase.keys(self, obj: 'Proxy') -> 'Proxy' torch.fx.proxy.TracerBase.proxy(self, node: torch.fx.node.Node) -> 'Proxy' torch.fx.proxy.TracerBase.to_bool(self, obj: 'Proxy') -> bool -torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Union[Callable, torch.fx.graph_module.GraphModule], replacement: Union[Callable, torch.fx.graph_module.GraphModule]) -> List[torch.fx.subgraph_rewriter.Match] +torch.fx.subgraph_rewriter.replace_pattern(gm: torch.fx.graph_module.GraphModule, pattern: Union[Callable[..., Any], torch.fx.graph_module.GraphModule], replacement: Union[Callable[..., Any], torch.fx.graph_module.GraphModule]) -> List[torch.fx.subgraph_rewriter.Match] diff --git a/test/export/test_converter.py b/test/export/test_converter.py index 0cb48529635cf..47e470bc5f6fa 100644 --- a/test/export/test_converter.py +++ b/test/export/test_converter.py @@ -1366,12 +1366,12 @@ def func1(x, x_list: list[torch.Tensor]): x_list.append(x_list[k] + x_list[k + 1] - x_list[k + 2]) return x, x_list - def func2(x): # noqa: F841 + def func2(x): for i in range(x.size(0)): x = x * x * i return x - def func3(x): # noqa: F841 + def func3(x): while x.sum() < 10: x += x.sin() return x diff --git a/test/export/test_draft_export.py b/test/export/test_draft_export.py index 32aaa56d69baf..fefd35ad99ead 100644 --- a/test/export/test_draft_export.py +++ b/test/export/test_draft_export.py @@ -427,7 +427,7 @@ def forward(self, x, y): for node in _ep.graph.nodes: if bindings := node.meta.get("unbacked_bindings"): unbacked_binding_symbols.update(bindings.keys()) - self.assertEqual(len(unbacked_binding_symbols), 2) + self.assertEqual(len(unbacked_binding_symbols), 1) def test_offsets(self): class M(torch.nn.Module): diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index 646ec14f9a8c5..949c704e5903d 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: export"] # flake8: noqa +import contextlib import copy import types import unittest @@ -14,7 +15,12 @@ dynamo_graph_capture_for_export, ) from torch._dynamo.test_case import run_tests, TestCase -from torch._functorch.aot_autograd import aot_export_module +from torch._export.utils import _compiling_state_context +from torch._functorch.aot_autograd import ( + aot_export_joint_with_descriptors, + aot_export_module, +) +from torch._guards import tracing as torch_tracing, TracingContext from torch.export import export from torch.export.experimental import _export_forward_backward, _sticky_export from torch.export.graph_signature import OutputKind @@ -1434,8 +1440,14 @@ def forward(self, arg0_1): @unittest.skipIf(not TEST_CUDA, "CUDA not available") def test_aot_export_blockmask_closure_spec_mismatch(self): - """BlockMasks with same closure code but different captured values must - produce different TreeSpecs, so pytree won't confuse them.""" + """BlockMasks with same closure structure produce equal TreeSpecs. + + Closure values are extracted into pytree leaves, so two BlockMasks + whose mask_mod closures have the same code + structure but different + captured values have the same spec (values differ in the leaves, not + the context). BlockMasks with different closure *structure* (e.g. + different code) must still produce different specs. + """ from torch.nn.attention.flex_attention import create_block_mask _register_blockmask_pytree() @@ -1462,8 +1474,154 @@ def fn(b, h, q, k): # Same closure code + same captured value -> same spec self.assertEqual(spec_a, spec_a_same) - # Same closure code + different captured value -> different spec - self.assertNotEqual(spec_a, spec_b) + # Same closure code + different captured value -> same spec + # (values are in the leaves, not the context) + self.assertEqual(spec_a, spec_b) + + # Different closure *code* -> different spec + def different_mask(b, h, q, k): + return q > k + + mask_c = create_block_mask( + different_mask, B=1, H=1, Q_LEN=64, KV_LEN=64, device="cuda" + ) + _, spec_c = pytree.tree_flatten(mask_c) + self.assertNotEqual(spec_a, spec_c) + + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_blockmask_and_masks_closure_extraction(self): + """and_masks closure tensors are recursively extracted into pytree leaves. + + and_masks(fn1, fn2) returns a closure capturing a tuple of functions. + _extract_closure_pytree must recursively process these functions + (extracting their closure tensors) rather than emitting the functions + themselves as leaves, since functions are not supported export input + types. + """ + from torch.nn.attention.flex_attention import ( + and_masks, + BlockMask, + create_block_mask, + ) + + _register_blockmask_pytree() + + def causal(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + + offset = torch.tensor(3, device="cuda") + + def offset_mask(b, h, q_idx, kv_idx): + return q_idx >= kv_idx + offset + + mask_mod = and_masks(causal, offset_mask) + block_mask = create_block_mask( + mask_mod, B=1, H=None, Q_LEN=128, KV_LEN=128, device="cuda" + ) + + leaves, spec = pytree.tree_flatten(block_mask) + + # 8 regular BlockMask tensor attrs + offset extracted from + # offset_mask's closure + n_regular = len(BlockMask._TENSOR_ATTRS) + self.assertEqual(len(leaves), n_regular + 1) + self.assertTrue(all(isinstance(l, torch.Tensor) for l in leaves)) + self.assertIs(leaves[n_regular], offset) + + # Leaves must all pass check_user_input_output + from torch._dynamo.eval_frame import check_user_input_output + from torch._dynamo.exc import UserErrorType + + check_user_input_output(leaves, UserErrorType.INVALID_INPUT) + + # Round-trip: unflatten should reconstruct a working mask_mod + restored = pytree.tree_unflatten(leaves, spec) + self.assertTrue(callable(restored.mask_mod)) + + def test_aot_export_closure_buffer_mutation(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buf", torch.zeros(())) + + def forward(self, x): + self.buf.add_(x.sum()) + return x.sin() + + def make_closure(mod): + def fn(x): + mod._buffers["buf"].add_(x.sum()) + return x.sin() + + return fn + + class Wrapper(torch.nn.Module): + def __init__(self, fn, mod): + super().__init__() + self._parameters = mod._parameters + self._buffers = mod._buffers + self._modules = mod._modules + self._fn = fn + + def forward(self, x): + return self._fn(x) + + def run_export(capture_fn): + mod = Mod() + wrapped = Wrapper(make_closure(mod), mod) + x = torch.randn(4) + gm = capture_fn(wrapped)(x) + + with contextlib.ExitStack() as stack: + stack.enter_context( + torch_tracing( + gm.meta.get( + "tracing_context", TracingContext(gm.meta["fake_mode"]) + ) + ) + ) + stack.enter_context(_compiling_state_context()) + stack.enter_context(gm.meta["fake_mode"]) + + jd = aot_export_joint_with_descriptors( + stack, + gm, + args=(x,), + kwargs={}, + keep_inference_input_mutations=True, + disable_functionalization=True, + ) + return jd.graph_module, wrapped, x + + # Verify Dynamo-captured graph mutates the buffer via closure + mod = Mod() + wrapped = Wrapper(make_closure(mod), mod) + x = torch.randn(4) + gm = dynamo_graph_capture_for_export(wrapped)(x) + wrapped.buf.zero_() + gm(x) + self.assertEqual(wrapped.buf, x.sum()) + + # Verify joint graphs from both APIs match + joint_public, _, _ = run_export(dynamo_graph_capture_for_export) + joint_private, _, _ = run_export(_dynamo_graph_capture_for_export) + self.assertEqual( + str(joint_public.code).strip(), str(joint_private.code).strip() + ) + + # Verify numerical correctness of both joint graphs against eager + mod = Mod() + x = torch.randn(4) + eager_out = mod(x) + eager_buf = mod.buf.clone() + + for label, joint_gm in [("public", joint_public), ("private", joint_private)]: + buf_input = torch.zeros(()) + (exported_out,) = joint_gm(buf_input, x) + self.assertEqual(exported_out, eager_out, msg=f"{label}: output mismatch") + self.assertEqual( + buf_input, eager_buf, msg=f"{label}: buffer mutation mismatch" + ) if __name__ == "__main__": diff --git a/test/export/test_export.py b/test/export/test_export.py index ad4b071d1e8cb..680ddb6607548 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -79,8 +79,8 @@ run_tests, skipIfCrossRef, skipIfRocm, + skipIfTorchDynamo, skipIfXpu, - TEST_TRANSFORMERS, TEST_WITH_CROSSREF, TestCase as TorchTestCase, ) @@ -127,6 +127,8 @@ from torch.export import export +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" + torch.library.define("testlib::returns_tensor_symint", "(Tensor x) -> (Tensor, SymInt)") torch.library.define( "testlib::foo", @@ -609,6 +611,136 @@ def forward(self, x, y): inp = ([torch.ones(1, 3)], torch.ones(1, 3)) self._test_export_same_as_eager(f, inp) + @skipIfCrossRef # CrossRefMode interferes with functorch ops + @skipIfTorchDynamo("export inside dynamo is not supported") + def test_gradient_tracking_tensors(self) -> None: + class JVP(torch.nn.Module): + def foo(self, x, r, t) -> torch.Tensor: + return x - 0.1 * r + 0.1 * t + + def forward(self, x, y, r, t, z, o) -> tuple[torch.Tensor, torch.Tensor]: + return torch.func.jvp( + self.foo, + (x, r, t), + (y, z, o), + ) + + inp = ( + torch.rand(2, 4), + torch.rand(2, 4), + torch.rand(2, 1), + torch.rand(2, 1), + torch.zeros(2, 1), + torch.ones(2, 1), + ) + + output_before = JVP()(*inp) + ep = torch.export.export(JVP(), inp) + unf = torch.export.unflatten(ep) + output_after = unf(*inp) + self.assertTrue(torch.allclose(output_after[0], output_before[0])) + self.assertTrue(torch.allclose(output_after[1], output_before[1])) + + @skipIfCrossRef + @skipIfTorchDynamo("export inside dynamo is not supported") + def test_jvp_export_complex_dtype(self) -> None: + class ComplexJVP(torch.nn.Module): + def forward( + self, x: torch.Tensor, v: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + def fn(x: torch.Tensor) -> torch.Tensor: + return x * x + + return torch.func.jvp(fn, (x,), (v,)) + + inp = ( + torch.randn(3, 3, dtype=torch.complex64), + torch.randn(3, 3, dtype=torch.complex64), + ) + + output_before = ComplexJVP()(*inp) + ep = torch.export.export(ComplexJVP(), inp) + unf = torch.export.unflatten(ep) + output_after = unf(*inp) + self.assertTrue(torch.allclose(output_after[0], output_before[0])) + self.assertTrue(torch.allclose(output_after[1], output_before[1])) + + @skipIfCrossRef + @skipIfTorchDynamo("export inside dynamo is not supported") + def test_jvp_export_inplace_ops(self) -> None: + class InplaceJVP(torch.nn.Module): + def forward( + self, x: torch.Tensor, v: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + def fn(x: torch.Tensor) -> torch.Tensor: + y = x.clone() + y.mul_(2.0) + y.add_(1.0) + return y + + return torch.func.jvp(fn, (x,), (v,)) + + inp = (torch.randn(4, 4), torch.randn(4, 4)) + + output_before = InplaceJVP()(*inp) + ep = torch.export.export(InplaceJVP(), inp) + unf = torch.export.unflatten(ep) + output_after = unf(*inp) + self.assertTrue(torch.allclose(output_after[0], output_before[0])) + self.assertTrue(torch.allclose(output_after[1], output_before[1])) + + @skipIfCrossRef + @skipIfTorchDynamo("export inside dynamo is not supported") + def test_jvp_export_nested(self) -> None: + class NestedJVP(torch.nn.Module): + def forward( + self, x: torch.Tensor, v: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + def outer_fn(x: torch.Tensor) -> torch.Tensor: + def inner_fn(x: torch.Tensor) -> torch.Tensor: + return torch.sin(x) + + primal, tangent = torch.func.jvp(inner_fn, (x,), (x,)) + return primal + tangent + + return torch.func.jvp(outer_fn, (x,), (v,)) + + inp = (torch.randn(3, 3), torch.randn(3, 3)) + + output_before = NestedJVP()(*inp) + ep = torch.export.export(NestedJVP(), inp) + unf = torch.export.unflatten(ep) + output_after = unf(*inp) + self.assertTrue(torch.allclose(output_after[0], output_before[0])) + self.assertTrue(torch.allclose(output_after[1], output_before[1])) + + @skipIfCrossRef + @skipIfTorchDynamo("export inside dynamo is not supported") + def test_jvp_export_multiple_outputs(self) -> None: + class MultiOutputJVP(torch.nn.Module): + def forward( + self, x: torch.Tensor, v: torch.Tensor + ) -> tuple[ + tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor] + ]: + def fn(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return torch.sin(x), torch.cos(x) + + return torch.func.jvp(fn, (x,), (v,)) + + inp = (torch.randn(4, 4), torch.randn(4, 4)) + + output_before = MultiOutputJVP()(*inp) + ep = torch.export.export(MultiOutputJVP(), inp) + unf = torch.export.unflatten(ep) + output_after = unf(*inp) + # Check primals + self.assertTrue(torch.allclose(output_after[0][0], output_before[0][0])) + self.assertTrue(torch.allclose(output_after[0][1], output_before[0][1])) + # Check tangents + self.assertTrue(torch.allclose(output_after[1][0], output_before[1][0])) + self.assertTrue(torch.allclose(output_after[1][1], output_before[1][1])) + @testing.expectedFailureStrictV2 @skipIfCrossRef def test_custom_tag_metadata_re_export(self): @@ -646,7 +778,7 @@ def forward(self, x): # should not be copied to other nodes counter = 0 for node in new_ep.graph.nodes: - if "custom" in node.meta: + if "quantization_tag" in node.meta.get("custom", {}): counter += 1 self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo") self.assertTrue(node.target == torch.ops.aten.linear.default) @@ -2461,6 +2593,49 @@ def auto_dynamic_shapes_from_args(args): # pyre-ignore dynamic_shapes=auto_dynamic_shapes_from_args(sample_input), ).run_decompositions({}) + def test_where_decomp_non_strict_inference_mode_dynamic_shapes(self): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.where(x > 0) + + test_module = TestModule() + sample_input = (torch.rand(2, 10),) + dynamic_shapes = ({0: Dim("batch_size", max=100)},) + + with torch.inference_mode(): + ep = torch.export.export( + test_module, + sample_input, + strict=False, + dynamic_shapes=dynamic_shapes, + ).run_decompositions({}) + + self.assertEqual(ep.module()(*sample_input), test_module(*sample_input)) + + larger_input = (torch.rand(4, 10),) + self.assertEqual(ep.module()(*larger_input), test_module(*larger_input)) + + def test_where_decomp_non_bool_input(self): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.where(x) + + test_module = TestModule() + sample_input = (torch.tensor([[0.0, 1.0], [2.0, 0.0]]),) + dynamic_shapes = ({0: Dim("batch_size", max=100)},) + + ep = torch.export.export( + test_module, + sample_input, + strict=False, + dynamic_shapes=dynamic_shapes, + ).run_decompositions({}) + + self.assertEqual(ep.module()(*sample_input), test_module(*sample_input)) + + larger_input = (torch.tensor([[0.0, 3.0], [4.0, 0.0], [5.0, 6.0]]),) + self.assertEqual(ep.module()(*larger_input), test_module(*larger_input)) + def test_basic_non_strict_fake_tensor(self): class Basic(torch.nn.Module): def __init__(self) -> None: @@ -10564,6 +10739,28 @@ def forward(self, x): ) self.assertTrue(torch.allclose(core_aten_ep.module()(*inp), m(*inp))) + def test_export_decomps_isin_dynamic(self): + class M(torch.nn.Module): + def forward(self, elements, test_elements): + return torch.isin(elements, test_elements) + + m = M() + elements = torch.tensor([1, 2, 3, 4, 5]) + test_elements = torch.tensor([2, 4]) + inp = (elements, test_elements) + + ep = export( + m, + inp, + dynamic_shapes={ + "elements": {0: Dim("n_elements")}, + "test_elements": {0: Dim("n_test")}, + }, + ) + decomposed = ep.run_decompositions() + + self.assertEqual(decomposed.module()(*inp), m(*inp)) + def test_where_broadcast_preserves_symint(self): import torch.fx.experimental._config as config from torch._dynamo.source import ConstantSource @@ -11033,7 +11230,6 @@ def forward(self, x): export_res = decomposed_ep.module()(x) self.assertTrue(export_res.size() == exp_res.size()) - @skipIfXpu def test_export_with_fake_tensor_inputs_on_cuda_devices(self): fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() @@ -11052,9 +11248,9 @@ def forward(self, x): model = Model() # Manually set the fake_device of fake tensors. - x.fake_device = torch.device("cuda:0") + x.fake_device = torch.device(f"{device_type}:0") for n, p in model.named_parameters(): - p.fake_device = torch.device("cuda:0") + p.fake_device = torch.device(f"{device_type}:0") # Need to set all the requires_grad of tensors to False, because fake_tensor with CUDA device # doesn't quite work well with aot_autograd right now due to some logic fails @@ -17526,6 +17722,25 @@ def forward(self, x, y): expected_mask = torch.ones(3, 5, dtype=torch.bool).triu(diagonal=2) self.assertEqual(eager_out, expected_mask) + def test_quantile_export(self): + class QuantilePair(torch.nn.Module): + def __init__(self, noise=0.1): + super().__init__() + self.noise = noise + + def forward(self, x): + q = torch.tensor( + [self.noise, 1.0 - self.noise], + device=x.device, + dtype=x.dtype, + ) + return torch.quantile(x, q, dim=-1, keepdim=True) + + model = QuantilePair(noise=0.1) + x = torch.randn(1, 3200) + ep = export(model, (x,)) + self.assertEqual(ep.module()(x), model(x)) + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestOneOffModelExportResult(TestCase): @@ -17564,6 +17779,7 @@ def forward(self, q, k, v): ep = torch.export.export(ScaledDotProductAttention(), (q, k, v)) ep.run_decompositions() + @skipIfXpu(msg="scaled_dot_product_attention issue on xpu") @skipIfCrossRef @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, @@ -17588,9 +17804,9 @@ def forward(self, q, k, v): ) return attn_output - q = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda") - k = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda") - v = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device="cuda") + q = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device=device_type) + k = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device=device_type) + v = torch.randn(1, 16, 16, 64, dtype=torch.bfloat16, device=device_type) ep = torch.export.export( ScaledDotProductAttention(), (q, k, v) @@ -17791,30 +18007,35 @@ def forward(self, x): add_2: USER_OUTPUT""", ) - @unittest.skipIf(not TEST_TRANSFORMERS, "No transformers") def test_hf_logging_logger(self): - import transformers + # Replicate the HF transformers logging pattern (stdlib logging.Logger + # with a monkey-patched warning_once) without importing transformers, + # whose import can hang in CI on HF Hub I/O. + @functools.lru_cache(None) + def warning_once(self, *args, **kwargs): + self.warning(*args, **kwargs) - logger = transformers.utils.logging.get_logger(__name__) + with patch.object(logging.Logger, "warning_once", warning_once, create=True): + logger = logging.getLogger(__name__) - class M(torch.nn.Module): - def forward(self, x): - logger.warning_once("start") - x1 = x + x - x2 = x1 * x1 - x3 = x2 + x2 - return (x1, x3) + class M(torch.nn.Module): + def forward(self, x): + logger.warning_once("start") + x1 = x + x + x2 = x1 * x1 + x3 = x2 + x2 + return (x1, x3) - gm = export(M(), (torch.randn(3, 3),)).graph_module - self.assertExpectedInline( - gm.code.strip(), - """\ + gm = export(M(), (torch.randn(3, 3),)).graph_module + self.assertExpectedInline( + gm.code.strip(), + """\ def forward(self, x): add = torch.ops.aten.add.Tensor(x, x); x = None mul = torch.ops.aten.mul.Tensor(add, add) add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None return (add, add_1)""", - ) + ) def test_warning(self): class M(torch.nn.Module): diff --git a/test/export/test_export_opinfo.py b/test/export/test_export_opinfo.py index 361674a69c7ae..b33aeb45438a3 100644 --- a/test/export/test_export_opinfo.py +++ b/test/export/test_export_opinfo.py @@ -22,6 +22,7 @@ ) from torch.testing._internal.common_utils import ( IS_FBCODE, + IS_WINDOWS, run_tests, skipIfRocm, TestCase, @@ -152,6 +153,11 @@ class TestExportOnFakeCuda(TestCase): # We set CUDA_VISIBLE_DEVICES="" to simulate a CPU machine with cuda build # Running this on all ops in op_db is too slow, so we only run on a selected subset @onlyCUDA + @unittest.skipIf( + IS_WINDOWS, + 'Subprocess with CUDA_VISIBLE_DEVICES="" imports op_db which triggers ' + "get_device_capability(); 0 devices raises Invalid device id on Windows.", + ) @ops(selected_op_db, allowed_dtypes=(torch.float,)) def test_fake_export(self, device, dtype, op): test_script = f"""\ @@ -218,6 +224,10 @@ def forward(self, *args): self.assertEqual(r, "") @unittest.skipIf(not torch.backends.cuda.is_built(), "requires CUDA build") + @unittest.skipIf( + IS_WINDOWS, + "Failing on Windows, device_count() changes from 0 to 1 ", + ) def test_preserve_original_behavior(self): test_script = f"""\ import torch diff --git a/test/export/test_passes.py b/test/export/test_passes.py index 56f56438776b2..d05d14c3b1783 100644 --- a/test/export/test_passes.py +++ b/test/export/test_passes.py @@ -644,7 +644,7 @@ def forward(self, obj_attr, x): getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(foo = obj_attr, x = add); obj_attr = add = None - return (takes_foo_default,)""", # noqa: B950 + return (takes_foo_default,)""", ) def test_fakify_script_objects(self): @@ -905,7 +905,7 @@ def forward(self, x): sub = torch.ops.aten.sub.Tensor(add_1, 1) sub_1 = torch.ops.aten.sub.Tensor(add_2, 1) return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec) - """, # noqa: B950 + """, ) mod_orig, mod, args = self.SET_GRAD_ENABLED_TESTS[ @@ -931,7 +931,7 @@ def forward(self, x): sub = wrap_with_set_grad_enabled_1[0] sub_1 = wrap_with_set_grad_enabled_1[1]; wrap_with_set_grad_enabled_1 = None return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec) - """, # noqa: B950 + """, ) def test_sequential_split(self): @@ -1149,7 +1149,7 @@ def forward(self, x): sub = wrap_with_autocast_2[0] sub_1 = wrap_with_autocast_2[1]; wrap_with_autocast_2 = None return pytree.tree_unflatten((add_1, add_2, sub, sub_1), self._out_spec) - """, # noqa: B950 + """, ) self.assertExpectedInline( @@ -1342,7 +1342,7 @@ def forward(self, x): to = torch.ops.aten.to.device(x, 'cuda', torch.float32); x = None add = torch.ops.aten.add.Tensor(to, to); to = None return (add,) - """, # noqa: B950 + """, ) @unittest.skipIf(not TEST_CUDA, "requires cuda") @@ -1364,7 +1364,7 @@ def forward(self, arg0_1): to = torch.ops.aten.to.dtype_layout(arg0_1, dtype = torch.float32, layout = torch.strided, device = 'cuda'); arg0_1 = None add = torch.ops.aten.add.Tensor(to, to); to = None return (add,) - """, # noqa: B950 + """, ) @unittest.skipIf(not TEST_CUDA, "requires cuda") diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index c14775ee48cd4..97a8919e11616 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -24,6 +24,11 @@ from torch.library import wrap_triton from torch.utils._triton import has_triton +else: + + def has_triton(): + return False + import torch import torch._dynamo as torchdynamo @@ -1953,11 +1958,13 @@ def forward(self): roundtrip_ep = deserialize(serialize(ep)) self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()())) - def test_serialize_float8(self): + def test_serialize_dtypes(self): for dtype in [ torch.float8_e5m2, torch.float8_e4m3fn, torch.float8_e8m0fnu, + torch.uint32, + torch.uint64, ]: class MyModule(torch.nn.Module): diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index adf0986811648..5ed05137dfa16 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -198,7 +198,7 @@ def forward(self, token, obj_attr, x, n): getitem = with_effects[0] getitem_1 = with_effects[1]; with_effects = None add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None - return (getitem, add)""", # noqa: B950 + return (getitem, add)""", ) def test_method_schema(self): @@ -246,7 +246,7 @@ def forward(self, token, obj_attr, x): getitem = with_effects[0] getitem_1 = with_effects[1]; with_effects = None add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None - return (getitem, add)""", # noqa: B950 + return (getitem, add)""", ) @parametrize("pre_dispatch", [True, False]) @@ -281,7 +281,7 @@ def forward(self, token, obj_attr, x): getitem = with_effects[0] getitem_1 = with_effects[1]; with_effects = None add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None - return (getitem, add)""", # noqa: B950 + return (getitem, add)""", ) @parametrize("pre_dispatch", [True, False]) @@ -316,7 +316,7 @@ def forward(self, token, x, cc): getitem = with_effects[0] getitem_1 = with_effects[1]; with_effects = None add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None - return (getitem, add)""", # noqa: B950 + return (getitem, add)""", ) # aot_export_function runs the program twice # in run_functionalized_fw_and_collect_metadata and create_aot_dispatcher_function @@ -379,7 +379,7 @@ def forward(self, token, x, cc): getitem = with_effects[0] getitem_1 = with_effects[1]; with_effects = None add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None - return (getitem, add)""", # noqa: B950 + return (getitem, add)""", ) @parametrize("pre_dispatch", [True, False]) @@ -464,7 +464,7 @@ def forward(self, x): takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x) takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default); attr = takes_foo_default = None add = torch.ops.aten.add.Tensor(x, takes_foo_default_1); x = takes_foo_default_1 = None - return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950 + return pytree.tree_unflatten((add,), self._out_spec)""", ) self.assertExpectedInline( ep.graph_module.code.strip(), @@ -477,7 +477,7 @@ def forward(self, token, obj_attr, x): getitem_2 = with_effects_1[0] getitem_3 = with_effects_1[1]; with_effects_1 = None add = torch.ops.aten.add.Tensor(x, getitem_3); x = getitem_3 = None - return (getitem_2, add)""", # noqa: B950 + return (getitem_2, add)""", ) @parametrize("pre_dispatch", [True, False]) @@ -530,7 +530,7 @@ def forward(self, token, obj_attr, x): getitem_5 = with_effects_1[0] getitem_6 = with_effects_1[1]; with_effects_1 = None add_2 = torch.ops.aten.add.Tensor(x, getitem_6); x = getitem_6 = None - return (getitem_5, add_2)""", # noqa: B950 + return (getitem_5, add_2)""", ) @parametrize("pre_dispatch", [True, False]) @@ -578,7 +578,7 @@ def forward(self, token, obj_attr, x): getitem_3 = with_effects_1[0] getitem_4 = with_effects_1[1]; with_effects_1 = None add_1 = torch.ops.aten.add.Tensor(x, getitem_4); x = getitem_4 = None - return (getitem_3, add_1)""", # noqa: B950 + return (getitem_3, add_1)""", ) @parametrize("pre_dispatch", [True, False]) @@ -777,7 +777,7 @@ def forward(self, token, p_linear_weight, p_linear_bias, tq, x): getitem_8 = with_effects_4[0]; with_effects_4 = None add_2 = torch.ops.aten.add.Tensor(getitem_7, 0); getitem_7 = None add_3 = torch.ops.aten.add.Tensor(add_2, x); add_2 = x = None - return (getitem_8, add_3, add_1, tq)""", # noqa: B950 + return (getitem_8, add_3, add_1, tq)""", ) self.assertEqual(tq.size(), 2) self.assertTrue(tq.pop() is a) @@ -818,7 +818,7 @@ def forward(self, token, safe_obj): getitem = with_effects[0] getitem_1 = with_effects[1]; with_effects = None sin = torch.ops.aten.sin.default(getitem_1); getitem_1 = None - return (getitem, sin)""", # noqa: B950 + return (getitem, sin)""", ) def test_identifying_torchbind_ops(self): @@ -1050,7 +1050,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): with_effects_5 = torch.ops.higher_order.with_effects(getitem_8, torch.ops._TorchScriptTesting.queue_size.default, arg1_1); getitem_8 = None getitem_10 = with_effects_5[0]; with_effects_5 = None add = torch.ops.aten.add.Tensor(getitem_9, 0); getitem_9 = None - return (getitem_10, sub, add, arg1_1)""", # noqa: B950 + return (getitem_10, sub, add, arg1_1)""", ) def test_export_inplace_custom_op(self): @@ -1081,7 +1081,7 @@ def forward(self, tq, x): def forward(self, token, tq, x): with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.queue_push.default, tq, x); token = x = None getitem = with_effects[0]; with_effects = None - return (getitem, tq)""", # noqa: B950 + return (getitem, tq)""", ) self.assertExpectedInline( str(ep.graph_module.graph).strip(), @@ -1092,7 +1092,7 @@ def forward(self, token, tq, x): %x : [num_users=1] = placeholder[target=x] %with_effects : [num_users=1] = call_function[target=torch.ops.higher_order.with_effects](args = (%token, _TorchScriptTesting.queue_push.default, %tq, %x), kwargs = {}) %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {}) - return (getitem, tq)""", # noqa: B950 + return (getitem, tq)""", ) def test_deepcopy(self): @@ -1421,7 +1421,7 @@ def forward(self, token, obj, x): with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.call_torchbind, obj, 'get'); getitem = obj = None getitem_2 = with_effects_1[0]; with_effects_1 = None add_1 = torch.ops.aten.add.Tensor(add, 3); add = None - return (getitem_2, add_1)""", # noqa: B950 + return (getitem_2, add_1)""", ) self.assertEqual(eager_out, compiled_out) self.assertEqual(eager_out, ep.module()(test_obj, x)) @@ -1634,7 +1634,7 @@ def __init__(self, x, y): self.x = x self.y = y - def __obj_unflatten__(cls, flattend_foo): # noqa: B902 + def __obj_unflatten__(cls, flattend_foo): return cls(**dict(flattend_foo)) def test_register_fake_class_valid(self): diff --git a/test/export/test_unflatten.py b/test/export/test_unflatten.py index 13b234c173e5b..b62b31ea3ea8d 100644 --- a/test/export/test_unflatten.py +++ b/test/export/test_unflatten.py @@ -178,6 +178,48 @@ def forward(self, x): id(getattr(unflattened_module.sub_net, "2")), ) + def test_unflatten_shared_submodule_reorder(self): + """Test that _reorder_submodules handles @N-suffixed FQNs correctly. + + When modules are shared (aliased), PyTorch assigns @N suffixes to + duplicate entries in _modules. The fqn_order dict (built from + fqn_list) filters out @N entries, so _reorder_submodules must fall + back to the base FQN for ordering. + """ + + class Block(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(10, 10) + + def forward(self, x): + return self.linear(x) + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + shared_block = Block() + self.blocks = torch.nn.Sequential( + shared_block, + torch.nn.ReLU(), + shared_block, + torch.nn.ReLU(), + ) + + def forward(self, x): + return self.blocks(x) + + eager_module = Model() + inps = (torch.rand(2, 10),) + export_module = export(eager_module, inps, {}, strict=True) + unflattened_module = unflatten(export_module) + self.compare_outputs(eager_module, unflattened_module, inps) + # Verify shared identity is preserved + self.assertEqual( + id(getattr(unflattened_module.blocks, "0")), + id(getattr(unflattened_module.blocks, "2")), + ) + def test_assert_tensor_metadata_stack(self): class N(torch.nn.Module): def __init__(self): diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 5a962dfa57c05..64495d1ae303e 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -63,6 +63,7 @@ ("prim::ModuleDictIndex", datetime.date(9999, 1, 1)), ("prim::MKLDNNRelu6", datetime.date(9999, 1, 1)), ("prim::MKLDNNRelu6_", datetime.date(9999, 1, 1)), + ("onednn::qconv2d_pointwise", datetime.date(2026, 5, 1)), ("prim::is_ort", datetime.date(9999, 1, 1)), ("prim::Concat", datetime.date(9999, 1, 1)), ("aten::_NestedTensor_GeneralizedBMM", datetime.date(9999, 1, 1)), @@ -301,7 +302,7 @@ def check_bc(existing_schemas): log.warning( "Can NOT find backward compatible schemas after changes " "for schema %s from the following candidates:\n[\n%s\n]", - str(existing_schema), + existing_schema, "\n\t".join(str(s) for s in matching_new_schemas), ) # TODO Print out more details about why candidates don't match. @@ -346,7 +347,7 @@ def check_fc(existing_schemas): log.warning( "Can NOT find forward compatible schemas after changes " "for schema %s from the following candidates:\n[\n\t%s\n]", - str(existing_schema), + existing_schema, "\n\t".join(str(s) for s in matching_new_schemas), ) log.warning( diff --git a/test/functorch/test_ac_knapsack.py b/test/functorch/test_ac_knapsack.py index 2d2899e9ca297..7866f1454b2f9 100644 --- a/test/functorch/test_ac_knapsack.py +++ b/test/functorch/test_ac_knapsack.py @@ -128,7 +128,7 @@ def test_recomputable_node_only_graph(self): ) def test_recomputable_node_only_graph_with_larger_graph_context(self): - recomputable_node_only_graph_with_larger_graph_context = self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context # noqa: B950 + recomputable_node_only_graph_with_larger_graph_context = self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context expected_nodes = self.all_recomputable_banned_nodes # node1 does not have an indirect path to node5 because of node2 # node2 has an indirect path to node5 diff --git a/test/functorch/test_aot_joint_with_descriptors.py b/test/functorch/test_aot_joint_with_descriptors.py index 13f2318a0f59c..0079d6e2948ab 100644 --- a/test/functorch/test_aot_joint_with_descriptors.py +++ b/test/functorch/test_aot_joint_with_descriptors.py @@ -289,7 +289,7 @@ def forward( None, # None None, # None ], self._out_spec) -""", # noqa: B950 +""", ) # Compile the result @@ -475,7 +475,7 @@ def forward( as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='linear2.bias')) None, # None ], self._out_spec) -""", # noqa: B950 +""", ) # Compile the result @@ -1237,7 +1237,7 @@ def forward(self, arg0_1: "f32[4, 3]", arg1_1: "f32[4, 3]"): cos: "f32[4, 3]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None mul_2: "f32[4, 3]" = torch.ops.aten.mul.Tensor(mul_1, cos); mul_1 = cos = None return (mul_2, mul) -""", # noqa: B950 +""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1254,7 +1254,7 @@ def forward(self, arg0_1: "f32[4, 3]", arg1_1: "f32[4, 3]"): ('get_attr', 'repeated_subgraph1', {'mod_name': 'my_mod'}) [('placeholder', 'arg0_1', {'mod_name': 'my_mod'}), ('placeholder', 'arg1_1', {'mod_name': 'my_mod'}), ('call_function', 'sin', {'mod_name': 'bar'}), ('call_function', 'mul', {'mod_name': 'bar'}), ('call_function', 'mul_1', {'mod_name': 'bar'}), ('call_function', 'cos', {'mod_name': 'bar'}), ('call_function', 'mul_2', {'mod_name': 'bar'}), ('output', 'output', {'mod_name': 'my_mod'})] ('call_function', 'invoke_subgraph_1', {'mod_name': 'my_mod'}) -('call_function', 'getitem_1', {'mod_name': 'my_mod'})""", # noqa: B950 +('call_function', 'getitem_1', {'mod_name': 'my_mod'})""", ) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 834d756db3368..1f5f4dd20d909 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -11,6 +11,7 @@ import operator import unittest import warnings +import weakref from collections.abc import Callable from contextlib import ContextDecorator, ExitStack, nullcontext from functools import partial, wraps @@ -1017,7 +1018,7 @@ def f(x): """ During the backward, we encountered a tensor subclass where we guessed its metadata incorrectly. -""", # noqa: F541 +""", ): new_out.sum().backward() @@ -3246,7 +3247,7 @@ def forward(self, primals_1): as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0) add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None return (as_strided_scatter, add_1)""", - ) # noqa: B950 + ) def test_input_mutation_aliases_other_input2(self): def f(a, b): @@ -3279,7 +3280,7 @@ def forward(self, primals_1): as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0) add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None return (as_strided_scatter, add_1)""", - ) # noqa: B950 + ) def test_input_mutation_aliases_and_output_alias(self): def f(a, b): @@ -3310,7 +3311,7 @@ def forward(self, primals_1): as_strided_9 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) view_1 = torch.ops.aten.view.default(as_strided_9, [4]); as_strided_9 = None return (as_strided_scatter, view_1)""", - ) # noqa: B950 + ) def test_input_aliased_with_mutation_output_alias(self): def f(a, b, c): @@ -3347,7 +3348,7 @@ def forward(self, primals_1, primals_2): as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) view_1 = torch.ops.aten.view.default(as_strided_7, [-1]); as_strided_7 = None return (as_strided_scatter, add, view_1)""", - ) # noqa: B950 + ) def test_input_metadata_mutation_aliases(self): def f(a, b): @@ -3419,7 +3420,7 @@ def forward(self, primals_1, primals_2): add = torch.ops.aten.add.Tensor(as_strided_3, 1); as_strided_3 = None add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None return (as_strided_scatter, add, add_1)""", - ) # noqa: B950 + ) @skipIfDynamoInput("Fails with dynamo") def test_input_mutation_aliases_bases_out_of_order(self): @@ -3478,7 +3479,7 @@ def forward(self, primals_1, primals_2, primals_3): as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0) view_2 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None return (as_strided_scatter, add_2, view_2, unsqueeze)""", - ) # noqa: B950 + ) @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") def test_synthetic_base_base_attribute_is_none(self): @@ -3552,7 +3553,7 @@ def forward(self, primals_1, primals_2): t_1 = torch.ops.aten.t.default(t) unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0) return (as_strided_scatter, t, view_1, t_1, unsqueeze, add)""", - ) # noqa: B950 + ) def test_dynamic_shape_output_not_in_bw_graph(self): def f(x): @@ -3792,7 +3793,7 @@ def forward(self, x, y): self.assertExpectedRaisesInline( AssertionError, lambda: fxx(x, y), - """At compilation time, graph 2 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""", # noqa: B950 + """At compilation time, graph 2 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""", ) @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) @@ -3848,7 +3849,7 @@ def forward(self, x, y): self.assertExpectedRaisesInline( AssertionError, lambda: fxx(x, y), - """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""", # noqa: B950 + """At compilation time, graph 1 was compiled under the assumption that input 1 would be a duplicate of input 0, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""", ) @patch("torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count) @@ -3897,7 +3898,7 @@ def forward(self, x, y): self.assertExpectedRaisesInline( AssertionError, lambda: fxz(x, y), - """At compilation time, graph 1 was compiled under the assumption that input 1 would not require grad, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""", # noqa: B950 + """At compilation time, graph 1 was compiled under the assumption that input 1 would not require grad, but at runtime this was not the case. This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch.""", ) def test_custom_autograd(self): @@ -4216,7 +4217,7 @@ def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals copy_ = torch.ops.aten.copy_.default(primals_3, getitem_3); primals_3 = copy_ = None copy__1 = torch.ops.aten.copy_.default(primals_4, getitem_4); primals_4 = copy__1 = None copy__2 = torch.ops.aten.copy_.default(primals_5, add); primals_5 = add = copy__2 = None - return (getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4)""", # noqa: B950 + return (getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4)""", ) self.assertEqual(out_ref, out_test) @@ -4236,7 +4237,7 @@ def forward(self, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem getitem_5 = native_batch_norm_backward[0] getitem_6 = native_batch_norm_backward[1] getitem_7 = native_batch_norm_backward[2]; native_batch_norm_backward = None - return (getitem_6, getitem_7, None, None, None, getitem_5)""", # noqa: B950 + return (getitem_6, getitem_7, None, None, None, getitem_5)""", ) self.assertEqual(inp_ref.grad, inp_test.grad) @@ -4278,7 +4279,7 @@ def forward(self, primals_1, primals_2): clone = torch.ops.aten.clone.default(primals_1); primals_1 = None add = torch.ops.aten.add.Tensor(clone, primals_2); clone = primals_2 = None return (add, add)""", - ) # noqa: B950 + ) self.assertEqual(out_ref, out_test) @@ -4290,7 +4291,7 @@ def forward(self, primals_1, primals_2): """\ def forward(self, tangents_1): return (None, tangents_1)""", - ) # noqa: B950 + ) def test_real_weights_in_symbolic_mode(self): from functorch.experimental import functionalize @@ -4733,6 +4734,135 @@ def f(x): out = f(inp) self.assertEqual(out.stride(), inp.stride()) + def _make_model_and_input(self, hidden=1024, vocab=4096, seq_len=128): + """Build a small model that produces non-scalar output (like an LM head).""" + model = nn.Sequential( + nn.Linear(hidden, hidden, bias=False), + nn.ReLU(), + nn.Linear(hidden, hidden, bias=False), + nn.ReLU(), + nn.Linear(hidden, vocab, bias=False), + ).cuda() + x = torch.randn(seq_len, hidden, device="cuda", requires_grad=True) + labels = torch.randint(0, vocab, (seq_len,), device="cuda") + return model, x, labels + + def _make_refcount_probe(self): + """Create a boxed-grads probe that checks the framework holds no extra refs. + + The probe uses boxed_grads_call=True so PyNode::apply moves grads + into a mutable list (same mechanism as CompiledFunction). After + removing the grad from the list, a weakref must become dead — proving + the framework released all its refs. + + Returns (ProbeClass, result_box). After backward, + result_box["ref_dead"] is True if no extra refs were held.""" + result_box = {} + + class RefcountProbe(torch.autograd.Function): + boxed_grads_call = True + + @staticmethod + def forward(ctx, pred): + return pred.clone() + + @staticmethod + def backward(ctx, grads): + grad = grads.pop(0) + grad_for_return = grad.clone() + ref = weakref.ref(grad) + del grad + result_box["ref_dead"] = ref() is None + return grad_for_return + + return RefcountProbe, result_box + + def _assert_no_extra_refs(self, result_box): + """Assert the framework holds no extra refs to the grad tensor.""" + self.assertIn("ref_dead", result_box) + self.assertTrue( + result_box["ref_dead"], + "Framework holds extra refs to tangent: weakref still alive after " + "removing all user-visible references", + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") + def test_tangent_freed_compiled_model_and_loss(self): + """Scenario: compiled model + compiled loss (torchtitan simple_fsdp pattern). + + Mirrors torchtitan's simple_fsdp training loop: + pred = compiled_model(input) # torch.compile(model) + loss = compiled_loss(pred, labels) # torch.compile(loss_fn) + loss.backward() + + The tangent for model backward is created internally by loss backward. + No user variable holds it — only the C++ pyInputs tuple on the stack. + The boxed calling convention alone frees it.""" + model, x, labels = self._make_model_and_input() + compiled_model = torch.compile(model, backend="inductor") + + def loss_fn(pred, labels): + return torch.nn.functional.cross_entropy(pred.float(), labels) + + compiled_loss = torch.compile(loss_fn, backend="inductor") + + Probe, refcount_box = self._make_refcount_probe() + pred = compiled_model(x) + pred = Probe.apply(pred) + loss = compiled_loss(pred, labels) + del pred + loss.backward() + + self._assert_no_extra_refs(refcount_box) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") + def test_tangent_freed_compiled_model_eager_loss(self): + """Scenario: compiled model + eager loss (common training pattern). + + Similar to torchtitan but loss is not compiled. The tangent is + created by eager cross_entropy backward. Boxed convention frees it.""" + model, x, labels = self._make_model_and_input() + compiled_model = torch.compile(model, backend="inductor") + + Probe, refcount_box = self._make_refcount_probe() + pred = compiled_model(x) + pred = Probe.apply(pred) + loss = torch.nn.functional.cross_entropy(pred.float(), labels) + del pred + loss.backward() + + self._assert_no_extra_refs(refcount_box) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable") + def test_tangent_user_provided_via_pop(self): + """Scenario: user creates tangent but releases it before backward runs. + + The user wraps the tangent in a list and calls out.backward(l.pop()). + After pop(), no Python variable holds the tangent — same as torchtitan. + The boxed calling convention should free it.""" + model, x, _labels = self._make_model_and_input() + compiled_model = torch.compile(model, backend="inductor") + + def loss_fn(pred, labels): + return torch.nn.functional.cross_entropy(pred.float(), labels) + + compiled_loss = torch.compile(loss_fn, backend="inductor") + _, _, labels = self._make_model_and_input() + + Probe, refcount_box = self._make_refcount_probe() + pred = compiled_model(x) + pred = Probe.apply(pred) + loss = compiled_loss(pred, labels) + del pred + + # Wrap loss in a list and use pop() — no user variable holds loss + # after this. The tangent created by loss backward has no user ref. + loss_list = [loss] + del loss + loss_list.pop().backward() + + self._assert_no_extra_refs(refcount_box) + def extract_graph(fx_g, _, graph_cell): graph_cell[0] = fx_g @@ -5088,7 +5218,7 @@ def forward(self, arg0_1): getitem = cond[0]; cond = None add = torch.ops.aten.add.Tensor(getitem, 3) add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None - return (add, add_1)""", # noqa: B950 + return (add, add_1)""", ) self.assertExpectedInline( @@ -5105,7 +5235,7 @@ def forward(self, arg0_1): false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (cos_1,)); gt = true_graph_0 = false_graph_0 = cos_1 = None getitem = cond[0]; cond = None - return (getitem,)""", # noqa: B950 + return (getitem,)""", ) self.assertExpectedInline( @@ -5227,7 +5357,7 @@ def forward(self, arg0_1: "f32[2]", arg1_1: "f32[2]"): add_1: "f32[2]" = torch.ops.aten.add.Tensor(add, arg1_1); add = arg1_1 = None return (add_1,) -""", # noqa: B950 +""", ) def test_aot_export_predispatch_map_2(self): @@ -5317,7 +5447,7 @@ def forward(self, arg0_1): getitem = cond[0]; cond = None add = torch.ops.aten.add.Tensor(getitem, 3) add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None - return (add, add_1)""", # noqa: B950 + return (add, add_1)""", ) self.assertExpectedInline( str(gm.true_graph_0.code).strip(), @@ -5358,7 +5488,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1 getitem = _native_batch_norm_legit_functional[0] getitem_3 = _native_batch_norm_legit_functional[3] getitem_4 = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None - return (getitem_3, getitem_4, add, getitem)""", # noqa: B950 + return (getitem_3, getitem_4, add, getitem)""", ) def test_aot_export_predispatch_reshape(self): @@ -5378,7 +5508,7 @@ def forward(self, arg0_1): view = torch.ops.aten.view.default(arg0_1, [4, 4]); arg0_1 = None sum_1 = torch.ops.aten.sum.default(view); view = None return (sum_1,)""", - ) # noqa: B950 + ) def test_aot_export_predispatch_contiguous(self): class Cont(torch.nn.Module): @@ -5396,7 +5526,7 @@ def forward(self, x): def forward(self, arg0_1): sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None return (sum_1,)""", - ) # noqa: B950 + ) def test_aot_export_module_joint(self): class ConvBatchnormRelu(torch.nn.Module): @@ -5466,7 +5596,7 @@ def forward( getitem_9: "f32[3, 1, 1, 1]" = convolution_backward[1] getitem_10: "f32[3]" = convolution_backward[2]; convolution_backward = None return (getitem_3, getitem_4, add, sum_1, detach_2, getitem_9, getitem_10, getitem_6, getitem_7) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -5481,19 +5611,19 @@ def forward( self.assertExpectedInline( str(signature.inputs_to_parameters), """{'arg0_1': 'conv.weight', 'arg1_1': 'conv.bias', 'arg2_1': 'bn.weight', 'arg3_1': 'bn.bias'}""", - ) # noqa: B950 + ) self.assertExpectedInline( str(signature.inputs_to_buffers), """{'arg4_1': 'bn.running_mean', 'arg5_1': 'bn.running_var', 'arg6_1': 'bn.num_batches_tracked'}""", - ) # noqa: B950 + ) self.assertExpectedInline( str(signature.buffers_to_mutate), """{'getitem_3': 'bn.running_mean', 'getitem_4': 'bn.running_var', 'add': 'bn.num_batches_tracked'}""", - ) # noqa: B950 + ) self.assertExpectedInline( str(signature.backward_signature.gradients_to_parameters), """{'getitem_9': 'conv.weight', 'getitem_10': 'conv.bias', 'getitem_6': 'bn.weight', 'getitem_7': 'bn.bias'}""", - ) # noqa: B950 + ) self.assertExpectedInline( str(signature.backward_signature.gradients_to_user_inputs), """{}""" ) @@ -5540,7 +5670,7 @@ def forward( sum_1, # PlainAOTOutput(idx=0) detach, # PlainAOTOutput(idx=1) ) -""", # noqa: B950 +""", ) # Some important characteristics of the exported graph below: # 8 arguments: 2 params from conv, 2 params from batchnorm, 2 buffers from 1 batchnorm, 1 user input @@ -5620,7 +5750,7 @@ def forward(self, arg0_1, arg1_1): sum_2 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None return (add, add_1)""", - ) # noqa: B950 + ) self.assertEqual(sig.user_inputs_to_mutate, {"add": "arg1_1"}) def test_aot_export_forward_mutation_multiple_mut(self): @@ -5653,7 +5783,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): add_2 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None sum_3 = torch.ops.aten.sum.default(add_1) return (add_1, add, add_2, sum_3)""", - ) # noqa: B950 + ) self.assertEqual(sig.user_inputs_to_mutate, {"add": "arg2_1"}) self.assertEqual(sig.buffers_to_mutate, {"add_1": "buffer1"}) @@ -5784,7 +5914,7 @@ def forward(self, arg0_1): getitem = cond[0]; cond = None add = torch.ops.aten.add.Tensor(getitem, 3) add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None - return (add, add_1)""", # noqa: B950 + return (add, add_1)""", ) self.assertExpectedInline( @@ -5851,7 +5981,7 @@ def forward(self): full = torch.ops.aten.full.default([], 11, device = device(type='cpu'), pin_memory = False) _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(full); full = None full_1 = torch.ops.aten.full.default([_local_scalar_dense], 0, device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None - return (full_1,)""", # noqa: B950 + return (full_1,)""", ) def test_aot_export_input_mutation(self): @@ -6909,7 +7039,6 @@ def _(x: torch.Tensor) -> torch.Tensor: subclass_inp_meta=[], subclass_fw_graph_out_meta=[], subclass_tangent_meta=[], - is_train=False, traced_tangents_descs=[], ) meta.tokens = {EffectType.ORDERED: torch.tensor([])} @@ -6938,6 +7067,148 @@ def _(x: torch.Tensor) -> torch.Tensor: finally: handle.destroy() + def test_collect_metadata_subclass_fw_outs_follow_input_mutation_type(self): + from torch._functorch._aot_autograd.collect_metadata_analysis import ( + run_functionalized_fw_and_collect_metadata, + ) + from torch._functorch._aot_autograd.descriptors import PlainAOTInput + from torch._functorch._aot_autograd.schemas import SubclassCreationMeta + + def f(x): + x.add_(1) + return [torch.sin(x)] + + fake_mode = FakeTensorMode() + subclass_arg = TwoTensor( + fake_mode.from_tensor(torch.ones(2)), + fake_mode.from_tensor(torch.ones(2)), + ) + + keep_input_mutations_meta = run_functionalized_fw_and_collect_metadata( + f, + flat_args_descs=[PlainAOTInput(0)], + keep_input_mutations=True, + static_input_indices=[], + )(subclass_arg) + self.assertEqual(keep_input_mutations_meta.mutated_inp_runtime_indices, []) + self.assertEqual(len(keep_input_mutations_meta.subclass_fw_graph_out_meta), 1) + self.assertIsInstance( + keep_input_mutations_meta.subclass_fw_graph_out_meta[0], + SubclassCreationMeta, + ) + self.assertEqual( + keep_input_mutations_meta.subclass_fw_graph_out_meta[ + 0 + ].flat_tensor_start_idx, + 0, + ) + + out_of_graph_mutation_meta = run_functionalized_fw_and_collect_metadata( + f, + flat_args_descs=[PlainAOTInput(0)], + keep_input_mutations=False, + static_input_indices=[], + )(subclass_arg) + self.assertEqual(out_of_graph_mutation_meta.mutated_inp_runtime_indices, [0]) + self.assertEqual(len(out_of_graph_mutation_meta.subclass_fw_graph_out_meta), 2) + self.assertIsInstance( + out_of_graph_mutation_meta.subclass_fw_graph_out_meta[0], + SubclassCreationMeta, + ) + self.assertIsInstance( + out_of_graph_mutation_meta.subclass_fw_graph_out_meta[1], + SubclassCreationMeta, + ) + self.assertEqual( + out_of_graph_mutation_meta.subclass_fw_graph_out_meta[ + 0 + ].flat_tensor_start_idx, + 0, + ) + self.assertEqual( + out_of_graph_mutation_meta.subclass_fw_graph_out_meta[ + 1 + ].flat_tensor_start_idx, + 2, + ) + + def test_collect_metadata_subclass_fw_outs_include_metadata_only_mutation(self): + from torch._functorch._aot_autograd.collect_metadata_analysis import ( + run_functionalized_fw_and_collect_metadata, + ) + from torch._functorch._aot_autograd.descriptors import PlainAOTInput + from torch._functorch._aot_autograd.schemas import ( + PlainTensorMeta, + SubclassCreationMeta, + ) + + def f(x): + x.transpose_(0, 1) + return [TwoTensor(torch.sin(x), torch.cos(x))] + + fake_mode = FakeTensorMode() + arg = fake_mode.from_tensor(torch.ones(2, 3)) + + metadata = run_functionalized_fw_and_collect_metadata( + f, + flat_args_descs=[PlainAOTInput(0)], + keep_input_mutations=True, + static_input_indices=[], + )(arg) + + self.assertEqual(metadata.mutated_inp_runtime_indices, [0]) + self.assertEqual(len(metadata.subclass_fw_graph_out_meta), 2) + self.assertIsInstance(metadata.subclass_fw_graph_out_meta[0], PlainTensorMeta) + self.assertEqual(metadata.subclass_fw_graph_out_meta[0].unwrapped_idx, 0) + self.assertIsInstance( + metadata.subclass_fw_graph_out_meta[1], + SubclassCreationMeta, + ) + self.assertEqual( + metadata.subclass_fw_graph_out_meta[1].flat_tensor_start_idx, + 1, + ) + + def test_collect_metadata_subclass_fw_outs_include_intermediate_bases(self): + from torch._functorch._aot_autograd.collect_metadata_analysis import ( + run_functionalized_fw_and_collect_metadata, + ) + from torch._functorch._aot_autograd.descriptors import PlainAOTInput + from torch._functorch._aot_autograd.schemas import PlainTensorMeta + + def f(x): + y = x.a + x.b + return [y.view(-1), y.view(-1)] + + fake_mode = FakeTensorMode() + subclass_arg = TwoTensor( + fake_mode.from_tensor(torch.ones(2, 3, requires_grad=True)), + fake_mode.from_tensor(torch.ones(2, 3, requires_grad=True)), + ) + + metadata = run_functionalized_fw_and_collect_metadata( + f, + flat_args_descs=[PlainAOTInput(0)], + keep_input_mutations=True, + static_input_indices=[], + )(subclass_arg) + + self.assertEqual(metadata.num_intermediate_bases, 1) + self.assertEqual(len(metadata.subclass_fw_graph_out_meta), 3) + self.assertTrue( + all( + isinstance(out_meta, PlainTensorMeta) + for out_meta in metadata.subclass_fw_graph_out_meta + ) + ) + self.assertEqual( + [ + out_meta.unwrapped_idx + for out_meta in metadata.subclass_fw_graph_out_meta + ], + [0, 1, 2], + ) + class TestAOTDispatch(AOTTestCase): # Tests to add cases for (non-exhaustive list, mostly for my notes): @@ -7099,7 +7370,7 @@ def f(a, b): """ During the backward, we encountered a tensor subclass where we guessed its metadata incorrectly. -""", # noqa: F541 +""", ): (out_test[0] + out_test[1]).sum().backward() @@ -8745,8 +9016,6 @@ def fn(x): xfail("nn.functional.gaussian_nll_loss"), xfail("tensor_split"), xfail("corrcoef"), - xfail("quantile"), - xfail("nanquantile"), skip("narrow"), xfail("istft"), xfail("linalg.eig"), @@ -8781,6 +9050,14 @@ def fn(x): # in AOTDispatcher to make them contiguous decorator=toleranceOverride({torch.float32: tol(atol=1e-02, rtol=1e-02)}), ), + decorate( + "cholesky_inverse", + # Numerical differences due to tangent stride differences between + # eager (.sum().backward() uses contiguous tangent) and compiled + # (tangent strides match output strides). With ill-conditioned inputs, + # matmul accumulates rounding errors differently for different layouts. + decorator=toleranceOverride({torch.float32: tol(atol=3e02, rtol=2e-03)}), + ), decorate( "nn.functional.interpolate", "bicubic", @@ -8825,9 +9102,6 @@ def fn(x): skip( "nn.functional.batch_norm", "" ), # '0 is not tracked with proxy for TwoTensor, backward should wrap grad inputs back + into TwoTensor via codegen'd epilogue. + """ + with self._capture_codegen_source("backward_subclass_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + return x * 2 + + a = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + tt = TwoTensor(a, b) + + # Run ref (eager) and test (compiled) to compare gradients + tt_ref = TwoTensor( + a.clone().detach().requires_grad_(True), + b.clone().detach().requires_grad_(True), + ) + out_ref = tt_ref * 2 + out_ref.sum().backward() + + out = f(tt) + out.sum().backward() + + self.assertIsInstance(out, TwoTensor) + self.assertIsInstance(tt.grad, TwoTensor) + self.assertEqual(tt.grad.a, tt_ref.grad.a) + self.assertEqual(tt.grad.b, tt_ref.grad.b) + + self.assertGreaterEqual( + len(captured), + 1, + "Expected backward_subclass_wrapper codegen artifact to be emitted", + ) + + def test_multi_input_backward_wraps_grad_inputs(self): + """ + f(TwoTensor, TwoTensor) -> TwoTensor, backward should wrap grad + inputs for each subclass input. + """ + with self._capture_codegen_source("backward_subclass_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x, y): + return x * y + + a = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + c = torch.randn(4, requires_grad=True) + d = torch.randn(4, requires_grad=True) + + tt1_ref = TwoTensor( + a.clone().detach().requires_grad_(True), + b.clone().detach().requires_grad_(True), + ) + tt2_ref = TwoTensor( + c.clone().detach().requires_grad_(True), + d.clone().detach().requires_grad_(True), + ) + out_ref = tt1_ref * tt2_ref + out_ref.sum().backward() + + tt1 = TwoTensor(a, b) + tt2 = TwoTensor(c, d) + out = f(tt1, tt2) + out.sum().backward() + + self.assertIsInstance(out, TwoTensor) + self.assertIsInstance(tt1.grad, TwoTensor) + self.assertIsInstance(tt2.grad, TwoTensor) + self.assertEqual(tt1.grad.a, tt1_ref.grad.a) + self.assertEqual(tt1.grad.b, tt1_ref.grad.b) + self.assertEqual(tt2.grad.a, tt2_ref.grad.a) + self.assertEqual(tt2.grad.b, tt2_ref.grad.b) + + self.assertGreaterEqual( + len(captured), + 1, + "Expected backward_subclass_wrapper codegen artifact to be emitted", + ) + + def test_nested_subclass_backward_wraps_grad_inputs(self): + """ + Nested TwoTensor backward should recursively wrap grad inputs. + """ + with self._capture_codegen_source("backward_subclass_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + return x.sin() + + a1 = torch.randn(4, requires_grad=True) + a2 = torch.randn(4, requires_grad=True) + a3 = torch.randn(4, requires_grad=True) + a4 = torch.randn(4, requires_grad=True) + + inner_a_ref = TwoTensor( + a1.clone().detach().requires_grad_(True), + a2.clone().detach().requires_grad_(True), + ) + inner_b_ref = TwoTensor( + a3.clone().detach().requires_grad_(True), + a4.clone().detach().requires_grad_(True), + ) + tt_ref = TwoTensor(inner_a_ref, inner_b_ref) + out_ref = tt_ref.sin() + out_ref.sum().backward() + + inner_a = TwoTensor(a1, a2) + inner_b = TwoTensor(a3, a4) + tt = TwoTensor(inner_a, inner_b) + out = f(tt) + out.sum().backward() + + self.assertIsInstance(out, TwoTensor) + self.assertIsInstance(out.a, TwoTensor) + self.assertIsInstance(out.b, TwoTensor) + self.assertIsInstance(tt.grad, TwoTensor) + self.assertEqual(tt.grad.a.a, tt_ref.grad.a.a) + self.assertEqual(tt.grad.a.b, tt_ref.grad.a.b) + self.assertEqual(tt.grad.b.a, tt_ref.grad.b.a) + self.assertEqual(tt.grad.b.b, tt_ref.grad.b.b) + + self.assertGreaterEqual( + len(captured), + 1, + "Expected backward_subclass_wrapper codegen artifact to be emitted", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/functorch/test_codegen_backward_prologue.py b/test/functorch/test_codegen_backward_prologue.py new file mode 100644 index 0000000000000..e53551146579a --- /dev/null +++ b/test/functorch/test_codegen_backward_prologue.py @@ -0,0 +1,139 @@ +# Owner(s): ["module: functorch"] + +""" +Tests for codegen'ing _backward_prologue_functional non-tangent subclass +unwrapping. + +_backward_prologue_functional unwraps saved tensors (non-tangent subclass +inputs) before passing them to the compiled backward. The tangent processing +(process_runtime_tangent) is runtime-dependent and NOT a codegen candidate, +but the non-tangent unwrapping is pure compile-time-determined subclass +unwrapping, identical to the forward input unwrapping already codegen'd. + +Tests verify that a "backward_subclass_unwrap" artifact is emitted via +trace_structured. +""" + +import logging +from contextlib import contextmanager + +import torch +import torch._functorch.config +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.two_tensor import TwoTensor + + +trace_log = logging.getLogger("torch.__trace") + + +class TestCodegenBackwardPrologue(TestCase): + @contextmanager + def _capture_codegen_source(self, artifact_name): + """Capture codegen artifacts from the structured trace log.""" + captured: list[str] = [] + + class _ArtifactHandler(logging.Handler): + def emit(self, record): + metadata = getattr(record, "metadata", {}) + if ( + "artifact" in metadata + and metadata["artifact"].get("name") == artifact_name + ): + payload = getattr(record, "payload", None) + if payload is not None: + captured.append(payload) + + handler = _ArtifactHandler() + handler.setLevel(logging.DEBUG) + old_level = trace_log.level + trace_log.setLevel(logging.DEBUG) + trace_log.addHandler(handler) + try: + yield captured + finally: + trace_log.removeHandler(handler) + trace_log.setLevel(old_level) + + def test_saved_subclass_tensors_unwrapped(self): + """ + When subclass tensors are saved for backward, the prologue should + codegen their unwrapping (non-tangent path). + """ + with self._capture_codegen_source("backward_subclass_unwrap") as captured: + + @torch.compile(backend="aot_eager") + def f(x, y): + return x * y + + a = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + c = torch.randn(4, requires_grad=True) + d = torch.randn(4, requires_grad=True) + + tt1_ref = TwoTensor( + a.clone().detach().requires_grad_(True), + b.clone().detach().requires_grad_(True), + ) + tt2_ref = TwoTensor( + c.clone().detach().requires_grad_(True), + d.clone().detach().requires_grad_(True), + ) + out_ref = tt1_ref * tt2_ref + out_ref.sum().backward() + + tt1 = TwoTensor(a, b) + tt2 = TwoTensor(c, d) + out = f(tt1, tt2) + out.sum().backward() + + self.assertEqual(tt1.grad.a, tt1_ref.grad.a) + self.assertEqual(tt1.grad.b, tt1_ref.grad.b) + self.assertEqual(tt2.grad.a, tt2_ref.grad.a) + self.assertEqual(tt2.grad.b, tt2_ref.grad.b) + + self.assertGreaterEqual( + len(captured), + 1, + "Expected backward_subclass_unwrap codegen artifact to be emitted", + ) + + def test_mixed_subclass_and_plain_saved_tensors(self): + """ + When both subclass and plain tensors are saved, the prologue should + codegen unwrapping only for the subclass ones. + """ + with self._capture_codegen_source("backward_subclass_unwrap") as captured: + + @torch.compile(backend="aot_eager") + def f(x, y): + return x * y + + a = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + + tt_ref = TwoTensor( + a.clone().detach().requires_grad_(True), + b.clone().detach().requires_grad_(True), + ) + plain_ref = torch.randn(4, requires_grad=True) + out_ref = tt_ref * plain_ref + out_ref.sum().backward() + + tt = TwoTensor(a, b) + plain = plain_ref.clone().detach().requires_grad_(True) + out = f(tt, plain) + out.sum().backward() + + self.assertEqual(tt.grad.a, tt_ref.grad.a) + self.assertEqual(tt.grad.b, tt_ref.grad.b) + self.assertEqual(plain.grad, plain_ref.grad) + + self.assertGreaterEqual( + len(captured), + 1, + "Expected backward_subclass_unwrap codegen artifact to be emitted", + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/functorch/test_codegen_debug_assert.py b/test/functorch/test_codegen_debug_assert.py new file mode 100644 index 0000000000000..ca82d045ccc99 --- /dev/null +++ b/test/functorch/test_codegen_debug_assert.py @@ -0,0 +1,123 @@ +# Owner(s): ["module: functorch"] + +""" +Tests for codegen'ing DebugAssertWrapper. + +The codegen'd assertion function emits checks only for the specific arg +indices where requires_grad was False at compile time. Positions where +requires_grad=True are safe and generate no runtime check, replacing +the closure that iterated over all args. + +Enabled via torch._functorch.config.debug_assert = True. + +Tests verify that a "debug_assert_wrapper" artifact is emitted via +trace_structured. +""" + +import logging +from contextlib import contextmanager + +import torch +import torch._functorch.config +from torch.testing._internal.common_utils import run_tests, TestCase + + +trace_log = logging.getLogger("torch.__trace") + + +class TestCodegenDebugAssert(TestCase): + @contextmanager + def _capture_codegen_source(self, artifact_name): + """Capture codegen artifacts from the structured trace log.""" + captured: list[str] = [] + + class _ArtifactHandler(logging.Handler): + def emit(self, record): + metadata = getattr(record, "metadata", {}) + if ( + "artifact" in metadata + and metadata["artifact"].get("name") == artifact_name + ): + payload = getattr(record, "payload", None) + if payload is not None: + captured.append(payload) + + handler = _ArtifactHandler() + handler.setLevel(logging.DEBUG) + old_level = trace_log.level + trace_log.setLevel(logging.DEBUG) + trace_log.addHandler(handler) + try: + yield captured + finally: + trace_log.removeHandler(handler) + trace_log.setLevel(old_level) + + @torch._functorch.config.patch(debug_assert=True) + def test_mixed_requires_grad(self): + """ + With debug_assert=True, the wrapper should codegen assertions for + inputs compiled without requires_grad, and skip those with + requires_grad=True. + """ + with self._capture_codegen_source("debug_assert_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x, y): + return x + y + + x = torch.randn(4, requires_grad=True) + y = torch.randn(4) + out = f(x, y) + + self.assertEqual(out, x + y) + + self.assertEqual(len(captured), 1) + + @torch._functorch.config.patch(debug_assert=True) + def test_all_requires_grad(self): + """ + All inputs with requires_grad=True. Codegen should emit no + assertions (all positions are safe), but the artifact should + still be emitted. + """ + with self._capture_codegen_source("debug_assert_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x, y, z): + return x + y + z + + x = torch.randn(4, requires_grad=True) + y = torch.randn(4, requires_grad=True) + z = torch.randn(4, requires_grad=True) + out = f(x, y, z) + + self.assertEqual(out, x + y + z) + + self.assertEqual(len(captured), 1) + + @torch._functorch.config.patch(debug_assert=True) + def test_some_no_grad_inputs(self): + """ + Mix of requires_grad and non-requires_grad inputs going through + the training path. Codegen should emit assertions only for the + non-grad positions. + """ + with self._capture_codegen_source("debug_assert_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x, y, z): + return x * y + z + + x = torch.randn(4, requires_grad=True) + y = torch.randn(4) + z = torch.randn(4) + out = f(x, y, z) + + self.assertEqual(out, x * y + z) + + self.assertEqual(len(captured), 1) + + +if __name__ == "__main__": + run_tests() diff --git a/test/functorch/test_codegen_dedup.py b/test/functorch/test_codegen_dedup.py new file mode 100644 index 0000000000000..335572b4b040b --- /dev/null +++ b/test/functorch/test_codegen_dedup.py @@ -0,0 +1,151 @@ +# Owner(s): ["module: functorch"] + +""" +Tests for codegen'ing AOTDedupeWrapper. + +The codegen'd remove_dupe_args emits straight-line index selections like +[args[0], args[2], args[5]] with all indices baked in as literals, replacing +the closure-based zip + filter over keep_arg_mask. + +Strategy 2 (the dedup post_compile path) is triggered when duplicate args +have mutations, so strategy 1 (leafification) can't handle them. Dynamo +already deduplicates inputs, so these tests use aot_function directly. + +Tests verify that a "dedup_wrapper" artifact is emitted via trace_structured. +""" + +import logging +from contextlib import contextmanager + +import torch +import torch._functorch.config +from torch._functorch.aot_autograd import aot_function +from torch.testing._internal.common_utils import run_tests, TestCase + + +trace_log = logging.getLogger("torch.__trace") + + +def _nop_compiler(gm, example_inputs): # type: ignore[no-untyped-def] + return gm.forward + + +class TestCodegenDedup(TestCase): + @contextmanager + def _capture_codegen_source(self, artifact_name): + """Capture codegen artifacts from the structured trace log.""" + captured: list[str] = [] + + class _ArtifactHandler(logging.Handler): + def emit(self, record): + metadata = getattr(record, "metadata", {}) + if ( + "artifact" in metadata + and metadata["artifact"].get("name") == artifact_name + ): + payload = getattr(record, "payload", None) + if payload is not None: + captured.append(payload) + + handler = _ArtifactHandler() + handler.setLevel(logging.DEBUG) + old_level = trace_log.level + trace_log.setLevel(logging.DEBUG) + trace_log.addHandler(handler) + try: + yield captured + finally: + trace_log.removeHandler(handler) + trace_log.setLevel(old_level) + + def test_duplicate_args_with_mutation(self): + """ + When the same tensor is passed as two args and one position mutates, + strategy 2 dedup kicks in. The codegen should emit straight-line + arg selection. + """ + with self._capture_codegen_source("dedup_wrapper") as captured: + + def f(a, b): + b.mul_(2) + return a + b + + compiled_f = aot_function(f, _nop_compiler) + x = torch.randn(4) + x_ref = x.clone() + out = compiled_f(x, x) + + self.assertEqual(x, x_ref * 2) + self.assertEqual(out, x_ref * 4) + + self.assertGreaterEqual( + len(captured), + 1, + "Expected dedup_wrapper codegen artifact to be emitted", + ) + source = captured[0] + self.assertIn("args[0]", source) + self.assertNotIn("args[1]", source) + + def test_three_way_duplicate_with_mutation(self): + """ + Three-way duplication where the last position mutates. + """ + with self._capture_codegen_source("dedup_wrapper") as captured: + + def f(a, b, c): + c.add_(1) + return a + b + c + + compiled_f = aot_function(f, _nop_compiler) + x = torch.randn(4) + x_ref = x.clone() + out = compiled_f(x, x, x) + + self.assertEqual(x, x_ref + 1) + self.assertEqual(out, (x_ref + 1) * 3) + + self.assertGreaterEqual( + len(captured), + 1, + "Expected dedup_wrapper codegen artifact to be emitted", + ) + source = captured[0] + self.assertIn("args[0]", source) + self.assertNotIn("args[1]", source) + self.assertNotIn("args[2]", source) + + def test_partial_duplicate_with_mutation(self): + """ + Two args are duplicates (with mutation), third is distinct. + Codegen should select [args[0], args[2]] dropping the duplicate. + """ + with self._capture_codegen_source("dedup_wrapper") as captured: + + def f(a, b, c): + b.mul_(3) + return a + b + c + + compiled_f = aot_function(f, _nop_compiler) + x = torch.randn(4) + y = torch.randn(4) + x_ref = x.clone() + y_ref = y.clone() + out = compiled_f(x, x, y) + + self.assertEqual(x, x_ref * 3) + self.assertEqual(out, x_ref * 3 + x_ref * 3 + y_ref) + + self.assertGreaterEqual( + len(captured), + 1, + "Expected dedup_wrapper codegen artifact to be emitted", + ) + source = captured[0] + self.assertIn("args[0]", source) + self.assertNotIn("args[1]", source) + self.assertIn("args[2]", source) + + +if __name__ == "__main__": + run_tests() diff --git a/test/functorch/test_codegen_mutation_epilogue.py b/test/functorch/test_codegen_mutation_epilogue.py new file mode 100644 index 0000000000000..15f798b13e3df --- /dev/null +++ b/test/functorch/test_codegen_mutation_epilogue.py @@ -0,0 +1,225 @@ +# Owner(s): ["module: functorch"] + +""" +Tests for codegen'ing the mutation epilogue in _create_runtime_wrapper. + +The codegen'd mutation epilogue emits one of as_strided_(), copy_(), +or detach().copy_() per mutated input, with the branch resolved at codegen +time from each input's mutation metadata (mutates_metadata, mutates_data, +is_leaf). + +Tests that exercise data-only mutations use torch.compile (dynamo handles +metadata mutations in-graph, so only data mutations reach the epilogue). + +Tests that exercise metadata mutations (metadata-only, data+metadata) +use aot_function directly so metadata mutations flow through the epilogue. + +Tests verify that a "mutation_epilogue" artifact is emitted via +trace_structured. +""" + +import logging +from contextlib import contextmanager + +import torch +import torch._functorch.config +from functorch.compile import nop +from torch._functorch.aot_autograd import aot_function +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase + + +trace_log = logging.getLogger("torch.__trace") + + +class TestCodegenMutationEpilogue(TestCase): + @contextmanager + def _capture_codegen_source(self, artifact_name): + """Capture codegen artifacts from the structured trace log.""" + captured: list[str] = [] + + class _ArtifactHandler(logging.Handler): + def emit(self, record): + metadata = getattr(record, "metadata", {}) + if ( + "artifact" in metadata + and metadata["artifact"].get("name") == artifact_name + ): + payload = getattr(record, "payload", None) + if payload is not None: + captured.append(payload) + + handler = _ArtifactHandler() + handler.setLevel(logging.DEBUG) + old_level = trace_log.level + trace_log.setLevel(logging.DEBUG) + trace_log.addHandler(handler) + try: + yield captured + finally: + trace_log.removeHandler(handler) + trace_log.setLevel(old_level) + + def test_single_data_mutation(self): + """ + Single input data mutation via mul_. Codegen should emit a direct + copy_() for this input. + """ + with self._capture_codegen_source("mutation_epilogue") as captured: + + @torch.compile(backend="aot_eager") + def f(x, y): + x.mul_(2) + return x + y + + x = torch.randn(4, requires_grad=True).clone() + x.retain_grad() + y = torch.randn(4) + x_ref = x.detach().clone() + y_ref = y.clone() + out = f(x, y) + + self.assertEqual(x.detach(), x_ref * 2) + self.assertEqual(out, x_ref * 2 + y_ref) + + self.assertEqual( + len(captured), + 1, + "Expected mutation_epilogue codegen artifact to be emitted", + ) + self.assertIn("copy_", captured[0]) + + def test_multiple_data_mutations(self): + """ + Multiple inputs mutated. Codegen should emit a copy_() per mutated + input, with non-mutated inputs skipped entirely. + """ + with self._capture_codegen_source("mutation_epilogue") as captured: + + @torch.compile(backend="aot_eager") + def f(a, b, c): + a.mul_(2) + c.add_(1) + return a + b + c + + a = torch.randn(4, requires_grad=True).clone() + a.retain_grad() + b = torch.randn(4) + c = torch.randn(4, requires_grad=True).clone() + c.retain_grad() + a_ref, c_ref = a.detach().clone(), c.detach().clone() + out = f(a, b, c) + + self.assertEqual(a.detach(), a_ref * 2) + self.assertEqual(c.detach(), c_ref + 1) + self.assertEqual(out, a_ref * 2 + b + c_ref + 1) + + self.assertEqual( + len(captured), + 1, + "Expected mutation_epilogue codegen artifact to be emitted", + ) + self.assertIn("copy_", captured[0]) + + def test_leaf_mutation_under_no_grad(self): + """ + Leaf tensor mutated under no_grad (e.g. via detach().mul_()). + Codegen should emit detach().copy_() for this case. + """ + with self._capture_codegen_source("mutation_epilogue") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + x.detach().mul_(2) + return x + 1 + + x = torch.randn(4, requires_grad=True) + x_ref = x.detach().clone() + out = f(x) + + self.assertEqual(x.detach(), x_ref * 2) + self.assertEqual(out, x_ref * 2 + 1) + + self.assertEqual( + len(captured), + 1, + "Expected mutation_epilogue codegen artifact to be emitted", + ) + self.assertIn("detach().copy_", captured[0]) + + @skipIfTorchDynamo( + "aot_function uses FX tracing which conflicts with dynamo wrapping" + ) + def test_metadata_only_mutation(self): + """ + Metadata-only mutation via transpose_(). Codegen should emit + as_strided_() without copy_(). Uses aot_function directly because + dynamo handles metadata mutations in-graph. + """ + with self._capture_codegen_source("mutation_epilogue") as captured: + + def f(a, b): + a.transpose_(1, 0) + return a + b + + a = torch.randn(3, 4, requires_grad=True).add(0) + b = torch.randn(4, 3) + compiled_f = aot_function(f, nop) + out = compiled_f(a, b) + + self.assertEqual(a.shape, (4, 3)) + self.assertEqual(out.shape, (4, 3)) + + self.assertEqual(len(captured), 1) + self.assertIn("as_strided_", captured[0]) + self.assertNotIn("copy_", captured[0]) + + @skipIfTorchDynamo( + "aot_function uses FX tracing which conflicts with dynamo wrapping" + ) + def test_data_and_metadata_mutation(self): + """ + Both data and metadata mutated (transpose_ then mul_). Codegen + should emit as_strided_() followed by copy_(). Uses aot_function + directly because dynamo handles metadata mutations in-graph. + """ + with self._capture_codegen_source("mutation_epilogue") as captured: + + def f(a): + a.transpose_(1, 0) + a.mul_(2) + return a + 1 + + a = torch.randn(3, 4, requires_grad=True).add(0) + a_ref = a.detach().clone() + compiled_f = aot_function(f, nop) + out = compiled_f(a) + + self.assertEqual(a.shape, (4, 3)) + self.assertEqual(a.detach(), a_ref.transpose(1, 0) * 2) + self.assertEqual(out, a_ref.transpose(1, 0) * 2 + 1) + + self.assertEqual(len(captured), 1) + self.assertIn("as_strided_", captured[0]) + self.assertIn("copy_", captured[0]) + + def test_no_mutation_no_epilogue(self): + """ + No mutations at all. No mutation_epilogue artifact should be + emitted. + """ + with self._capture_codegen_source("mutation_epilogue") as captured: + + @torch.compile(backend="aot_eager") + def f(x, y): + return x + y + + x = torch.randn(4, requires_grad=True) + y = torch.randn(4) + out = f(x, y) + + self.assertEqual(out, x + y) + self.assertEqual(len(captured), 0) + + +if __name__ == "__main__": + run_tests() diff --git a/test/functorch/test_codegen_output_alias.py b/test/functorch/test_codegen_output_alias.py new file mode 100644 index 0000000000000..173afff644c48 --- /dev/null +++ b/test/functorch/test_codegen_output_alias.py @@ -0,0 +1,574 @@ +# Owner(s): ["module: functorch"] + +""" +Tests for codegen'ing the output alias regeneration in +_create_runtime_wrapper. + +The codegen'd output alias handler inlines each handler type's logic per +output as straight-line code: NoopAliasHandler becomes a direct fw_outs[i] +reference, IsInputHandler becomes orig_inputs[base_idx], and +AliasOfInput/IntermediateHandler become inline gen_alias_from_base calls +with baked-in indices and metadata. + +Tests verify that an "output_alias_wrapper" artifact is emitted via +trace_structured. +""" + +import logging +from contextlib import contextmanager +from unittest.mock import patch + +import torch +import torch._functorch.config +from functorch.compile import nop +from torch._functorch.aot_autograd import aot_function +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase + + +trace_log = logging.getLogger("torch.__trace") + + +class TestCodegenOutputAlias(TestCase): + def setUp(self): + super().setUp() + torch._dynamo.reset() + + @contextmanager + def _capture_codegen_source(self, artifact_name): + """Capture codegen artifacts from the structured trace log.""" + captured: list[str] = [] + + class _ArtifactHandler(logging.Handler): + def emit(self, record): + metadata = getattr(record, "metadata", {}) + if ( + "artifact" in metadata + and metadata["artifact"].get("name") == artifact_name + ): + payload = getattr(record, "payload", None) + if payload is not None: + captured.append(payload) + + handler = _ArtifactHandler() + handler.setLevel(logging.DEBUG) + old_level = trace_log.level + trace_log.setLevel(logging.DEBUG) + trace_log.addHandler(handler) + try: + yield captured + finally: + trace_log.removeHandler(handler) + trace_log.setLevel(old_level) + + def test_output_is_view_of_input(self): + """ + Output that is a view of an input (alias_of_input). Codegen should + emit gen_alias_from_base(orig_inputs[i], ...) inline. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + return x.view(-1) + + x = torch.randn(2, 3) + out = f(x) + + self.assertEqual(out, x.view(-1)) + self.assertEqual(out.data_ptr(), x.data_ptr()) + + self.assertEqual( + len(captured), + 1, + "Expected output_alias_wrapper codegen artifact to be emitted", + ) + + def test_output_is_input(self): + """ + Output that IS the input (is_input). Codegen should emit a direct + reference to orig_inputs[i]. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + x.mul_(2) + return x + + x = torch.randn(4) + x_ref = x.clone() + out = f(x) + + self.assertEqual(x, x_ref * 2) + self.assertIs(out, x) + + self.assertEqual( + len(captured), + 1, + "Expected output_alias_wrapper codegen artifact to be emitted", + ) + + def test_mixed_alias_and_non_alias_outputs(self): + """ + Multiple outputs: one aliased, one not. Codegen should emit + gen_alias_from_base for the alias and a noop for the non-alias. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + return x * 2, x.view(-1) + + x = torch.randn(2, 3) + out1, out2 = f(x) + + self.assertEqual(out1, x * 2) + self.assertEqual(out2, x.view(-1)) + self.assertEqual(out2.data_ptr(), x.data_ptr()) + + self.assertEqual( + len(captured), + 1, + "Expected output_alias_wrapper codegen artifact to be emitted", + ) + + def test_output_alias_with_mutation(self): + """ + Input is mutated AND output is a view of the input. Codegen should + handle both mutation epilogue and alias regeneration. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + x.add_(1) + return x.view(-1) + + x = torch.randn(2, 3) + x_ref = x.clone() + out = f(x) + + self.assertEqual(x, x_ref + 1) + self.assertEqual(out, (x_ref + 1).view(-1)) + + self.assertEqual( + len(captured), + 1, + "Expected output_alias_wrapper codegen artifact to be emitted", + ) + + def test_output_aliases_intermediate(self): + """ + Output is a view of another output (intermediate), not of an input. + Triggers AliasOfIntermediateHandler in the codegen. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + y = x + 1 + return y, y.view(-1) + + x = torch.randn(2, 3, requires_grad=True) + out1, out2 = f(x) + + self.assertEqual(out1, x + 1) + self.assertEqual(out2, (x + 1).view(-1)) + + self.assertEqual( + len(captured), + 1, + "Expected output_alias_wrapper codegen artifact to be emitted", + ) + + def test_multiple_views_of_same_input(self): + """ + Two outputs both alias the same input. Codegen should emit two + separate gen_alias_from_base calls referencing the same base. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + return x.view(-1), x.reshape(6) + + x = torch.randn(2, 3) + out1, out2 = f(x) + + self.assertEqual(out1, x.view(-1)) + self.assertEqual(out2, x.reshape(6)) + self.assertEqual(out1.data_ptr(), x.data_ptr()) + self.assertEqual(out2.data_ptr(), x.data_ptr()) + + self.assertEqual( + len(captured), + 1, + "Expected output_alias_wrapper codegen artifact to be emitted", + ) + + def test_training_path_view_of_input(self): + """ + Training path (trace_joint=True): output is a view of input with + requires_grad=True. Codegen should use _unwrap_tensoralias in the + alias function. Also verifies backward correctness. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + return x.view(-1) + + x = torch.randn(2, 3, requires_grad=True) + out = f(x) + + self.assertEqual(out, x.view(-1)) + self.assertEqual(out.data_ptr(), x.data_ptr()) + + out.sum().backward() + self.assertEqual(x.grad, torch.ones(2, 3)) + + self.assertEqual( + len(captured), + 1, + "Expected output_alias_wrapper codegen artifact to be emitted", + ) + + def test_training_path_mixed_requires_grad(self): + """ + Training path with mixed differentiable and non-differentiable + outputs. Exercises non-differentiable output collection in + _transform_raw_returns codegen and backward correctness. + """ + with ( + self._capture_codegen_source("compiled_fn_wrapper") as xform_captured, + self._capture_codegen_source("output_alias_wrapper") as _alias_captured, + ): + + @torch.compile(backend="aot_eager") + def f(x, y): + return x * 2, y.view(-1) + + x = torch.randn(2, 3, requires_grad=True) + y = torch.randn(4) + out1, out2 = f(x, y) + + self.assertEqual(out1, x * 2) + self.assertEqual(out2, y.view(-1)) + self.assertTrue(out1.requires_grad) + self.assertFalse(out2.requires_grad) + + out1.sum().backward() + self.assertEqual(x.grad, torch.full((2, 3), 2.0)) + + self.assertEqual( + len(xform_captured), + 1, + "Expected compiled_fn_wrapper codegen artifact to be emitted", + ) + + def test_training_path_mutation_and_alias(self): + """ + Training path: input is mutated AND returned as a view. Exercises + both the mutation epilogue and alias codegen with trace_joint=True. + Uses a non-leaf tensor to allow in-place mutation with grad. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + x.add_(1) + return x.view(-1) + + base = torch.randn(2, 3, requires_grad=True) + x = base.clone() + x_ref = x.detach().clone() + out = f(x) + + self.assertEqual(x, x_ref + 1) + self.assertEqual(out, (x_ref + 1).view(-1)) + + self.assertEqual( + len(captured), + 1, + "Expected output_alias_wrapper codegen artifact to be emitted", + ) + + def test_training_path_is_input(self): + """ + Training path: output IS the input (mutation + return identity). + Exercises IsInputHandler with trace_joint=True. Uses a non-leaf + tensor to allow in-place mutation with grad. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + x.mul_(2) + return x + + base = torch.randn(4, requires_grad=True) + x = base.clone() + x_ref = x.detach().clone() + out = f(x) + + self.assertEqual(x, x_ref * 2) + self.assertIs(out, x) + + self.assertEqual( + len(captured), + 1, + "Expected output_alias_wrapper codegen artifact to be emitted", + ) + + def test_training_path_alias_of_intermediate_detach(self): + """ + Training path: one output is a detached view of an intermediate, + the other is a differentiable view. Exercises + AliasOfIntermediateHandler with trace_joint=True and the + base_is_user_output sub-path. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + y = x + 1 + return y.detach(), y.view(-1) + + x = torch.randn(3, 3, requires_grad=True) + out_detach, out_view = f(x) + + self.assertEqual(out_detach, x + 1) + self.assertEqual(out_view, (x + 1).view(-1)) + self.assertFalse(out_detach.requires_grad) + self.assertTrue(out_view.requires_grad) + self.assertEqual(out_detach.data_ptr(), out_view.data_ptr()) + + out_view.sum().backward() + self.assertEqual(x.grad, torch.ones(3, 3)) + + self.assertEqual( + len(captured), + 1, + "Expected output_alias_wrapper codegen artifact to be emitted", + ) + + def test_view_replay_config_false(self): + """ + Test that view_replay_for_aliased_outputs=False is correctly + baked into the codegen'd alias function. + """ + with patch( + "torch._functorch.config.view_replay_for_aliased_outputs", + False, + ): + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + return x.view(-1) + + x = torch.randn(2, 3) + out = f(x) + + self.assertEqual(out, x.view(-1)) + self.assertEqual(out.data_ptr(), x.data_ptr()) + + self.assertEqual(len(captured), 1) + self.assertIn("replay_views=False", captured[0]) + + def test_view_replay_config_true(self): + """ + Test that view_replay_for_aliased_outputs=True (default) is + correctly baked into the codegen'd alias function. + """ + with patch( + "torch._functorch.config.view_replay_for_aliased_outputs", + True, + ): + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + return x.view(-1) + + x = torch.randn(2, 3) + out = f(x) + + self.assertEqual(out, x.view(-1)) + self.assertEqual(out.data_ptr(), x.data_ptr()) + + self.assertEqual(len(captured), 1) + self.assertIn("replay_views=True", captured[0]) + + def test_codegen_source_contains_gen_alias(self): + """ + Verify the codegen'd source contains gen_alias_from_base for + alias-of-input outputs and orig_inputs for is_input outputs. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + x.mul_(2) + return x, x.view(-1) + + x = torch.randn(2, 3) + f(x) + + self.assertEqual(len(captured), 1) + source = captured[0] + self.assertIn("gen_alias_from_base", source) + self.assertIn("orig_inputs[", source) + + def test_codegen_source_noop_handler(self): + """ + Verify the codegen'd source contains fw_outs[i] for non-aliased + (NoopAliasHandler) outputs in a mixed scenario. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + return x * 2, x.view(-1) + + x = torch.randn(2, 3) + f(x) + + self.assertEqual(len(captured), 1) + source = captured[0] + self.assertIn("fw_outs[", source) + self.assertIn("gen_alias_from_base", source) + + def test_alias_of_intermediate_save_as_output(self): + """ + Two outputs aliasing the same intermediate (not an input). When + multiple outputs share the same intermediate base, the first triggers + alias_of_intermediate_save_as_output and the second triggers + alias_of_intermediate. Both use AliasOfIntermediateHandler in codegen. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + y = x + 1 + return y.view(-1), y.view(3, 3) + + x = torch.randn(3, 3, requires_grad=True) + out1, out2 = f(x) + + expected = x + 1 + self.assertEqual(out1, expected.view(-1)) + self.assertEqual(out2, expected.view(3, 3)) + self.assertEqual(out1.data_ptr(), out2.data_ptr()) + + out1.sum().backward() + self.assertEqual(x.grad, torch.ones(3, 3)) + + self.assertEqual(len(captured), 1) + source = captured[0] + self.assertIn("gen_alias_from_base", source) + + def test_xform_unsafe_view_output(self): + """ + _transform_raw_returns codegen: when an output is a view of an + intermediate and is the only output aliasing that intermediate + (unsafe_view_alias), the codegen emits an _unsafe_view call. + """ + with self._capture_codegen_source("compiled_fn_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + return (x + 1).view(-1) + + x = torch.randn(2, 3, requires_grad=True) + out = f(x) + + self.assertEqual(out, (x + 1).view(-1)) + self.assertTrue(out.requires_grad) + + out.sum().backward() + self.assertEqual(x.grad, torch.ones(2, 3)) + + self.assertEqual(len(captured), 1) + self.assertIn("_unsafe_view", captured[0]) + + @skipIfTorchDynamo("dynamo handles metadata mutations in-graph") + def test_xform_metadata_only_mutation(self): + """ + _transform_raw_returns codegen: when an input has a metadata-only + mutation (mutates_metadata=True, mutates_data=False), the codegen + wraps the corresponding mutated input return in TensorAlias. + Uses aot_function directly because dynamo handles metadata + mutations in-graph, so they never reach the _transform_raw_returns + codegen path. + """ + with self._capture_codegen_source("compiled_fn_wrapper") as captured: + + def f(a, b): + a.transpose_(1, 0) + return a + b + + a = torch.randn(3, 4, requires_grad=True).add(0) + b = torch.randn(4, 3) + compiled_f = aot_function(f, nop) + out = compiled_f(a, b) + + self.assertEqual(a.shape, (4, 3)) + self.assertEqual(out.shape, (4, 3)) + + self.assertEqual(len(captured), 1) + self.assertIn("TensorAlias", captured[0]) + + def test_cross_dtype_view_alias(self): + """ + Output is a cross-dtype view of the input (view_as_real on a + complex tensor). Exercises gen_alias_from_base's cross-dtype + handling through the codegen path. + """ + with self._capture_codegen_source("output_alias_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + return torch.view_as_real(x) + + x = torch.randn(4, dtype=torch.complex64) + out = f(x) + + self.assertEqual(out, torch.view_as_real(x)) + self.assertEqual(out.shape, (4, 2)) + self.assertEqual(out.dtype, torch.float32) + + self.assertEqual(len(captured), 1) + self.assertIn("gen_alias_from_base", captured[0]) + + def test_xform_aliased_output_tensoralias_wrapping(self): + """ + _transform_raw_returns codegen: aliased outputs get wrapped in + TensorAlias so autograd.Function doesn't treat them as regular + tensors. Verifies the TensorAlias wrapping path for aliased + outputs (distinct from the metadata-only mutation wrapping). + Needs a non-view computation (x * 2) to force the autograd + factory path; a pure view like x.view(-1) alone bypasses it. + """ + with self._capture_codegen_source("compiled_fn_wrapper") as captured: + + @torch.compile(backend="aot_eager") + def f(x): + return x * 2, x.view(-1) + + x = torch.randn(2, 3, requires_grad=True) + out1, out2 = f(x) + + self.assertEqual(out1, x * 2) + self.assertEqual(out2, x.view(-1)) + self.assertEqual(out2.data_ptr(), x.data_ptr()) + + out1.sum().backward() + self.assertEqual(x.grad, torch.full((2, 3), 2.0)) + + self.assertEqual(len(captured), 1) + self.assertIn("TensorAlias", captured[0]) + + +if __name__ == "__main__": + run_tests() diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 72df12ad1822f..af4ac447e4f73 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -611,7 +611,7 @@ def forward(self, pred_1, x_1): false_graph_1 = self.false_graph_1 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None getitem_1 = cond_1[0]; cond_1 = None - return (getitem_1,)""", # noqa: B950 + return (getitem_1,)""", ) def test_cond_autograd_complex(self): @@ -652,7 +652,7 @@ def forward(self, pred_1, x_1): false_graph_1 = self.false_graph_1 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None getitem_1 = cond_1[0]; cond_1 = None - return (getitem_1,)""", # noqa: B950 + return (getitem_1,)""", ) @skipIfTorchDynamo("Skip due to graph break when run with dynamo") @@ -762,7 +762,7 @@ def forward(self, pred_1, x_1, y_1, z_1): cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (z_1, y_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = z_1 = y_1 = ones_like = None getitem_1 = cond_1[0] getitem_2 = cond_1[1]; cond_1 = getitem_2 = None - return (getitem_1,)""", # noqa: B950 + return (getitem_1,)""", ) @skipIfTorchDynamo("Skip due to graph break when run with dynamo") @@ -825,7 +825,7 @@ def forward(self, pred_1, x_1): getitem_2 = cond_1[1] getitem_3 = cond_1[2]; getitem_3 = None getitem_4 = cond_1[3]; cond_1 = getitem_4 = None - return (getitem_2,)""", # noqa: B950 + return (getitem_2,)""", ) def test_cond_in_forloop(self): @@ -885,7 +885,7 @@ def forward(self, x_1): mul_7 = torch.ops.aten.mul.Tensor(add_4, x_1) mul_8 = torch.ops.aten.mul.Tensor(add_4, x_1); add_4 = x_1 = None add_5 = torch.ops.aten.add.Tensor(mul_8, mul_7); mul_8 = mul_7 = None - return (add_5,)""", # noqa: B950 + return (add_5,)""", ) @skipIfTorchDynamo("Skip due to graph break when run with dynamo") @@ -941,7 +941,7 @@ def forward(self, pred_1, a_1, b_1, c_1): getitem_1 = cond_1[0] getitem_2 = cond_1[1] getitem_3 = cond_1[2]; cond_1 = getitem_3 = None - return (getitem_1, getitem_2)""", # noqa: B950 + return (getitem_1, getitem_2)""", ) # Forward self.assertExpectedInline( @@ -1016,7 +1016,7 @@ def forward(self, pred_1): getitem_1 = cond_1[0] getitem_2 = cond_1[1] getitem_3 = cond_1[2]; cond_1 = getitem_3 = None - return (getitem_1, getitem_2)""", # noqa: B950 + return (getitem_1, getitem_2)""", ) def test_cond_autograd_different_pytree_output(self): @@ -1081,7 +1081,7 @@ def forward(self, pred_1): cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (_tensor_constant0, _tensor_constant1, _tensor_constant2)); pred_1 = true_graph_0 = false_graph_0 = _tensor_constant0 = _tensor_constant1 = _tensor_constant2 = None getitem = cond[0] getitem_1 = cond[1]; cond = None - return {'res': [getitem, (getitem_1,)]}""", # noqa: B950 + return {'res': [getitem, (getitem_1,)]}""", ) @skipIfTorchDynamo("Skip due to graph break when run with dynamo") @@ -1145,7 +1145,7 @@ def forward(self, pred_1, x_1): getitem_5 = cond_1[4]; getitem_5 = None getitem_6 = cond_1[5]; getitem_6 = None getitem_7 = cond_1[6]; cond_1 = getitem_7 = None - return (getitem_1,)""", # noqa: B950 + return (getitem_1,)""", ) def test_cond_autograd_user_nn_module(self): @@ -1197,7 +1197,7 @@ def forward(self, pred_1, x_1): false_graph_1 = self.false_graph_1 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None getitem_1 = cond_1[0]; cond_1 = None - return (getitem_1,)""", # noqa: B950 + return (getitem_1,)""", ) def test_cond_autograd_inner_fn(self): @@ -1252,7 +1252,7 @@ def forward(self, pred_1, x_1): false_graph_1 = self.false_graph_1 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None getitem_1 = cond_1[0]; cond_1 = None - return (getitem_1,)""", # noqa: B950 + return (getitem_1,)""", ) def test_cond_autograd_inner_tensor(self): @@ -1294,7 +1294,7 @@ def forward(self, pred_1, x_1): false_graph_1 = self.false_graph_1 cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None getitem_1 = cond_1[0]; cond_1 = None - return (getitem_1,)""", # noqa: B950 + return (getitem_1,)""", ) @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.") @@ -2388,6 +2388,44 @@ def test_scan_compile_cnt(self, reverse, device): ) self.assertEqual(cnt.frame_count, 6) + def test_scan_operator_call_count(self): + from torch._higher_order_ops.scan import generic_scan, wrap_combine_fn_flat + + counter = [0] + + def counting_combine(carry, x): + counter[0] += 1 + return carry + x, x + + init = [torch.zeros(3)] + xs = [torch.ones(5, 3)] + combine_flat = functools.partial( + wrap_combine_fn_flat, + combine_fn=counting_combine, + spec_init=pytree.tree_flatten(init)[1], + spec_xs=pytree.tree_flatten(xs)[1], + num_init_leaves=len(init), + num_inp_leaves=len(xs), + ) + + counter[0] = 0 + result = generic_scan(combine_flat, init, xs) # noqa: F841 + self.assertEqual(counter[0], 5) + + # Single-element scan should call operator exactly once. + xs_one = [torch.ones(1, 3)] + combine_flat_one = functools.partial( + wrap_combine_fn_flat, + combine_fn=counting_combine, + spec_init=pytree.tree_flatten(init)[1], + spec_xs=pytree.tree_flatten(xs_one)[1], + num_init_leaves=len(init), + num_inp_leaves=len(xs_one), + ) + counter[0] = 0 + generic_scan(combine_flat_one, init, xs_one) + self.assertEqual(counter[0], 1) + @skipIfTorchDynamo("don't test compile on compile") def test_scan_init_scanned_0(self): # Only init and no input @@ -2860,7 +2898,7 @@ def forward(self, child: "f32[1, 10, 2]", child_1: "f32[1, 10, 2]", child_2: "f3 child_3: "f32[1, 10, 2]" = a - b return [a, b, child_3] -""", # noqa: B950 +""", ) @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") @@ -3542,7 +3580,7 @@ def forward(self, fct_1, init_1, xs_1): getitem = scan[0] getitem_1 = scan[1]; scan = None flip_1 = torch.ops.aten.flip.default(getitem_1, [0]); getitem_1 = None - return (getitem, flip_1)""", # noqa: B950 + return (getitem, flip_1)""", ) # Check graph @@ -3562,7 +3600,7 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor): carry = scan[0] out = scan[1]; scan = None out_1 = out.flip([0]); out = None - return (carry, out_1)""", # noqa: B950 + return (carry, out_1)""", ) @requires_cuda @@ -4272,7 +4310,7 @@ def forward(self, L_xs_0_0_: "f32[3, 10, 2]", L_xs_0_1_0_: "f32[3, 10, 2]", L_xs movedim_4: "f32[3, 10, 2]" = torch.movedim(flip_4, 0, 0); flip_4 = None movedim_5: "f32[3, 10, 2]" = torch.movedim(flip_5, 0, 0); flip_5 = None return (movedim_3, movedim_4, movedim_5) -""", # noqa: B950 +""", ) @unittest.skipIf(not SM70OrLater, "triton") @@ -5477,7 +5515,7 @@ def forward(self, L_pred_ : torch.Tensor, L_x_ : torch.Tensor): grad_out = torch.ones_like(result) grad = torch.autograd.grad(result, (l_x_,), grad_out); result = l_x_ = grad_out = None getitem_1 = grad[0]; grad = None - return (getitem_1,)""", # noqa: B950 + return (getitem_1,)""", ) def test_while_loop_op_mismatch_in_meta(self): @@ -5670,7 +5708,7 @@ def forward(self, out_iter_1, it_1, y_1): getitem_1 = while_loop[1] getitem_2 = while_loop[2]; while_loop = None return (getitem, getitem_1, getitem_2) - """, # noqa: B950 + """, ) self.assertExpectedInline( graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"), @@ -5693,7 +5731,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): getitem_2 = while_loop[2]; while_loop = None add = torch.ops.aten.add.Tensor(getitem, 1); getitem = None return (add, getitem_1, getitem_2) - """, # noqa: B950 + """, ) def test_while_loop_pytree_carry(self): @@ -5721,7 +5759,7 @@ def forward(self, L_it_ : torch.Tensor, L_pytree_input_0_0_ : torch.Tensor, L_py getitem_1 = while_loop[1] value = while_loop[2] value_1 = while_loop[3]; while_loop = None - return (getitem, getitem_1, value, value_1)""", # noqa: B950 + return (getitem, getitem_1, value, value_1)""", ) def _wrap_with_functionalize(self, fn, func_type): @@ -5757,7 +5795,7 @@ def forward(self, x_1): while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (x_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x_1 = None getitem = while_loop[0]; while_loop = None return (getitem,) - """, # noqa: B950 + """, ) self.assertExpectedInline( graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"), @@ -5792,7 +5830,7 @@ def forward(self, arg0_1): while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (arg0_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = arg0_1 = None getitem = while_loop[0]; while_loop = None return (getitem,) - """, # noqa: B950 + """, ) self.assertExpectedInline( graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"), @@ -5827,7 +5865,7 @@ def forward(self, x_1): while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (x_1,), ()); while_loop_cond_graph_0 = while_loop_body_graph_0 = x_1 = None getitem = while_loop[0]; while_loop = None return (getitem,) - """, # noqa: B950 + """, ) self.assertExpectedInline( graphs["symbolic"].while_loop_cond_graph_0.code.strip("\n"), @@ -5910,10 +5948,9 @@ def test_while_loop_simple_with_linear_compile_check_graph(self): torch.compile(fn, backend=backend)(*inp) self.assertEqual(len(backend.graphs), 1) gm = backend.graphs[0] - if torch._dynamo.config.inline_inbuilt_nn_modules: - self.assertExpectedInline( - normalize_gm(gm.print_readable(print_output=False)), - """\ + self.assertExpectedInline( + normalize_gm(gm.print_readable(print_output=False)), + """\ class GraphModule(torch.nn.Module): def forward(self, L_iter_: "i64[]", L_x_: "f32[2, 2]", L_self_buffers_dec_: "i64[]", L_self_modules_linear_parameters_weight_: "f32[2, 2]", L_self_modules_linear_parameters_bias_: "f32[2]"): l_iter_ = L_iter_ @@ -5940,8 +5977,8 @@ def forward(self, child_2: "i64[]", child_3: "f32[2, 2]", l_self_buffers_dec__co child: "i64[]" = child_2 - 1; child_2 = None child_4: "f32[2, 2]" = torch._C._nn.linear(child_3, l_self_modules_linear_parameters_weight__body_fn, l_self_modules_linear_parameters_bias__body_fn); child_3 = l_self_modules_linear_parameters_weight__body_fn = l_self_modules_linear_parameters_bias__body_fn = None return (child, child_4) -""", # noqa: B950 - ) +""", + ) def test_while_loop_nested2_traced(self): fn, inp = WHILE_LOOP_TESTS["nested2"] @@ -5966,7 +6003,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1): getitem_2 = while_loop[2] getitem_3 = while_loop[3]; while_loop = None return (getitem, getitem_1, getitem_2, getitem_3) - """, # noqa: B950 + """, ) self.assertExpectedInline( outer_body.code.strip("\n"), @@ -5984,7 +6021,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1 mul = torch.ops.aten.mul.Tensor(getitem_2, 2); getitem_2 = None div = torch.ops.aten.div.Tensor(getitem_3, 2); getitem_3 = None return (sub, clone, mul, div) - """, # noqa: B950 + """, ) self.assertExpectedInline( outer_body.code.strip("\n"), @@ -6002,7 +6039,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1 mul = torch.ops.aten.mul.Tensor(getitem_2, 2); getitem_2 = None div = torch.ops.aten.div.Tensor(getitem_3, 2); getitem_3 = None return (sub, clone, mul, div) - """, # noqa: B950 + """, ) self.assertExpectedInline( inner_body.code.strip("\n"), @@ -6153,7 +6190,7 @@ def forward(self, a_1, b_1): false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None getitem = cond[0]; cond = None - return getitem""", # noqa: B950 + return getitem""", ) self.assertExpectedInline( gm.true_graph_0.code.strip(), @@ -6654,7 +6691,7 @@ def forward(self, x_1, pred_1, pred2_1): cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, (x_1,)); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None getitem_1 = cond_1[0]; cond_1 = None add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None - return add""", # noqa: B950 + return add""", ) self.assertExpectedInline( graph.true_graph_0.code.strip(), @@ -6842,7 +6879,7 @@ def forward(self, x_1, pred_1, pred2_1): cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, (x_1,)); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None getitem_1 = cond_1[0]; cond_1 = None add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None - return add""", # noqa: B950 + return add""", ) self.assertExpectedInline( graph.true_graph_0.code.strip(), @@ -7238,7 +7275,7 @@ def forward(self, arg0_1: "f32[3, 4]", arg1_1: "f32[3, 4]", arg2_1: "f32[3, 4]") clone: "f32[3, 4]" = torch.ops.aten.clone.default(arg2_1) clone_1: "f32[3, 4]" = torch.ops.aten.clone.default(arg2_1); arg2_1 = None return [clone, clone_1] -""", # noqa: B950 +""", ) self.assertEqual(res, res_compiled) @@ -7290,7 +7327,7 @@ def forward(self, arg0_1: "f32[3, 4]", arg1_1: "f32[3, 4]", arg2_1: "f32[3, 4]") mul: "f32[3, 4]" = torch.ops.aten.mul.Tensor(arg2_1, cos); arg2_1 = cos = None clone: "f32[3, 4]" = torch.ops.aten.clone.default(mul) return [mul, clone] -""", # noqa: B950 +""", ) self.assertEqual(res, res_compiled) @@ -7417,7 +7454,7 @@ def forward(self, x_1): false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None getitem = cond[0]; cond = None - return getitem""", # noqa: B950 + return getitem""", ) # We expect the traced graph module to work even if input size changes. @@ -7449,7 +7486,7 @@ def forward(self, x_1): false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x_1, sym_size_int_1)); gt = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = None getitem = cond[0]; cond = None - return getitem""", # noqa: B950 + return getitem""", ) def _check_closure_correctly_lifted(self, f, *, args, exp_res, exp_arg_num): @@ -7531,7 +7568,7 @@ def forward(self, x_1): _tensor_constant1 = self._tensor_constant1 cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int_1, sym_size_int, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int_1 = sym_size_int = _tensor_constant1 = None getitem = cond[0]; cond = None - return getitem""", # noqa: B950 + return getitem""", ) self.assertExpectedInline( gm.true_graph_0.code.strip(), @@ -7683,7 +7720,7 @@ def forward(self, arg0_1, arg1_1): false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, (arg0_1,)); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None getitem = cond[0]; cond = None - return (getitem,)""", # noqa: B950 + return (getitem,)""", ) @skipIfCrossRef # Arg order changes with crossref @@ -7772,7 +7809,7 @@ def forward(self, x_1): false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int_1, sym_size_int)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = sym_size_int = None getitem = cond[0]; cond = None - return getitem""", # noqa: B950 + return getitem""", ) self.assertExpectedInline( @@ -8124,7 +8161,7 @@ def forward(self, s97 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor): cond_false_0 = self.cond_false_0 cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, (l_a_, l_b_, s97)); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = s97 = None getitem = cond[0]; cond = None - return (getitem,)""", # noqa: B950 + return (getitem,)""", ) def test_two_hops_not_sharing_code_obj(self): @@ -8287,7 +8324,7 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_ scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = l_xs_ = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None carry = scan[0] out = scan[1]; scan = None - return (carry, out)""", # noqa: B950 + return (carry, out)""", ) else: self.assertExpectedInline( @@ -8302,7 +8339,7 @@ def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_ scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = l_xs_ = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None carry = scan[0] out = scan[1]; scan = None - return (carry, out)""", # noqa: B950 + return (carry, out)""", ) self.assertEqual(eager_out, exp_out) self.assertEqual(compiled_out, exp_out) @@ -8616,7 +8653,7 @@ def forward(self, it_1: "Sym(u0)", x_1: "f32[s77, 3]"): copy_: "f32[3]" = torch.ops.aten.copy_.default(select, add); select = add = copy_ = None add_1: "Sym(u0 + 1)" = it_1 + 1; it_1 = None return (add_1, clone) -""", # noqa: B950 +""", ) @skipIfTorchDynamo("Graph is not captured correctly when test with dynamo") @@ -8697,7 +8734,7 @@ def forward(self, unbacked_symint_0: "Sym(u1)", child_1: "f32[s77, s27]", s27: " add_1: "Sym(u1 + 1)" = unbacked_symint_0 + 1; unbacked_symint_0 = None return (add_1, x_clone) -""", # noqa: B950 +""", ) @skipIfTorchDynamo("Skip because we're testing export") @@ -8770,7 +8807,7 @@ def forward(self, a_1: "Sym(u1)", b_1: "Sym(u2)", c1_1: "Sym(u3)", c2_1: "Sym(u4 add: "Sym(u7 + 1)" = u0_1 + 1; u0_1 = None add_1: "f32[2, 3]" = torch.ops.aten.add.Tensor(x_1, 1); x_1 = None return (b_1, c1_1, c2_1, c3_1, a_1, 0, add, add_1) -""", # noqa: B950 +""", ) @skipIfTorchDynamo("Graph is not captured correctly when test with dynamo") @@ -8845,7 +8882,7 @@ def forward(self, unbacked_symint_6: "Sym(u8)", unbacked_symint_7: "Sym(u9)", un add: "Sym(u14 + 1)" = unbacked_symint_12 + 1; unbacked_symint_12 = None child: "f32[2, 3]" = child_1 + 1; child_1 = None return (unbacked_symint_7, unbacked_symint_8, unbacked_symint_9, unbacked_symint_10, unbacked_symint_6, 0, add, child) -""", # noqa: B950 +""", ) @skipIfTorchDynamo("Skip because we're testing export") @@ -8907,7 +8944,7 @@ def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", ar add_5: "f32[s6, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None return (add, add_1, add_2, add_3, add_4, add_5) -""", # noqa: B950 +""", ) @skipIfTorchDynamo("Graph is not captured correctly when test with dynamo") @@ -8977,7 +9014,7 @@ def forward(self, unbacked_symint_4: "Sym(u5)", unbacked_symint_5: "Sym(u6)", un child: "f32[s77, s27]" = child_2 + 1; child_2 = None return (add, add_1, add_2, add_3, add_4, child) -""", # noqa: B950 +""", ) @parametrize("dynamic", [True, False]) @@ -9102,7 +9139,7 @@ def forward(self, arg0_1: "f32[3, 3]", arg1_1: "f32[3]", arg2_1: "f32[3, 3]"): addmm: "f32[3, 3]" = torch.ops.aten.addmm.default(arg1_1, arg0_1, t); arg1_1 = arg0_1 = t = None add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, addmm); add = addmm = None return (add_1,) -""", # noqa: B950 +""", ) self.assertExpectedInline( @@ -9153,7 +9190,7 @@ def forward(self, arg0_1: "i64[]", arg1_1: "f32[3, 3]", arg2_1: "f32[3]", arg3_1 add_10: "f32[3]" = torch.ops.aten.add.Tensor(view, arg2_1); view = arg2_1 = None add_11: "f32[3, 3]" = torch.ops.aten.add.Tensor(t_4, arg3_1); t_4 = arg3_1 = None return (add_9, add_8, add_10, add_11) -""", # noqa: B950 +""", ) def test_input_output_alias(self): @@ -9327,7 +9364,7 @@ def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"): mul: "f32[10]" = torch.ops.aten.mul.Tensor(c, item); c = item = None return (mul,) -""", # noqa: B950 +""", ) def test_cond_merge_graph_preserves_ph_meta(self): @@ -9411,7 +9448,7 @@ def forward(self, x: "f32[s68, 3]", sym_size_int_4: "Sym(s17)", sym_size_int_5: add: "f32[s68, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None return (add,) -""", # noqa: B950 +""", ) # unbacked symint inputs are created during non-strict export, @@ -9610,7 +9647,7 @@ def forward(self, l_x_: "f32[s17, s94]", s94: "Sym(s94)", s17_true_branch: "Sym( getitem: "f32[2, s94]" = add[slice(None, 2, None)]; add = None clone: "f32[2, s94]" = getitem.clone(); getitem = None return (clone,) -""", # noqa: B950 +""", ) @parametrize("dynamic", [True, False]) @@ -9735,7 +9772,7 @@ def forward(self, arg0_1: "f32[3, 4]"): sin: "f32[3, 4]" = torch.ops.aten.sin.default(add) copy_: "f32[3, 4]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None return (sin,) -""", # noqa: B950 +""", ) @requires_cuda @@ -9801,7 +9838,7 @@ def forward(self, arg0_1: "f32[8]", arg1_1: "f32[1]"): copy_: "f32[8]" = torch.ops.aten.copy_.default(arg0_1, add_1); arg0_1 = add_1 = copy_ = None copy__1: "f32[1]" = torch.ops.aten.copy_.default(arg1_1, add); arg1_1 = add = copy__1 = None return (add_2,) -""", # noqa: B950 +""", ) @requires_cuda @@ -9874,7 +9911,7 @@ def forward(self, arg0_1: "f32[4, 3]", arg1_1: "f32[3, 4]"): mm: "f32[3, 3]" = torch.ops.aten.mm.default(sin, add); sin = None copy_: "f32[4, 3]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None return (mm,) -""", # noqa: B950 +""", ) @@ -9982,15 +10019,15 @@ def test_function_schema_gen(self): ) self.assertExpectedInline( str(schema1), - """test_op1(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""", # noqa: B950 + """test_op1(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""", ) self.assertExpectedInline( str(schema2), - """test_op2(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""", # noqa: B950 + """test_op2(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> Tensor""", ) self.assertExpectedInline( str(schema3), - """test_op3(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> (Tensor, Tensor)""", # noqa: B950, + """test_op3(bool bool_v, int int_v, float float_v, str str_v, Tensor Tensor_v, SymInt SymInt_v, SymBool SymBool_v, GraphModule GraphModule_v, __torch__.torch.classes._Foo ScriptObj_v) -> (Tensor, Tensor)""", ) self.assertEqual(schema1.parse(str(schema1)), schema1) self.assertEqual(schema2.parse(str(schema2)), schema2) @@ -10076,7 +10113,7 @@ def body_fn(x, y, z): ) self.assertExpectedInline( str(schema), - """while_loop(Any cond_fn, Any body_fn, Tensor carried_input0, Tensor carried_input1, Tensor additional_input0) -> (Tensor, Tensor)""", # noqa: B950 + """while_loop(Any cond_fn, Any body_fn, Tensor carried_input0, Tensor carried_input1, Tensor additional_input0) -> (Tensor, Tensor)""", ) def test_scan_gen_schema_tensor_inputs(self): @@ -10106,7 +10143,7 @@ def combine_fn(carry, x, scale): ) self.assertExpectedInline( str(schema), - """scan(Any combine_fn, Tensor init0, Tensor xs0, Tensor additional_input0) -> (Tensor, Tensor)""", # noqa: B950 + """scan(Any combine_fn, Tensor init0, Tensor xs0, Tensor additional_input0) -> (Tensor, Tensor)""", ) def test_scan_gen_schema_multiple_inputs(self): @@ -10121,7 +10158,7 @@ def combine_fn(carry1, carry2, x1, x2): ) self.assertExpectedInline( str(schema), - """scan(Any combine_fn, Tensor init0, Tensor init1, Tensor xs0, Tensor xs1) -> (Tensor, Tensor, Tensor, Tensor)""", # noqa: B950 + """scan(Any combine_fn, Tensor init0, Tensor init1, Tensor xs0, Tensor xs1) -> (Tensor, Tensor, Tensor, Tensor)""", ) def test_associative_scan_gen_schema_tensor_inputs(self): @@ -10181,7 +10218,7 @@ def body_fn(x, y, z, c): ) self.assertExpectedInline( str(schema), - """while_loop(Any cond_fn, Any body_fn, int carried_input0, int carried_input1, Tensor carried_input2, Tensor additional_input0) -> (int, int, Tensor, Tensor)""", # noqa: B950 + """while_loop(Any cond_fn, Any body_fn, int carried_input0, int carried_input1, Tensor carried_input2, Tensor additional_input0) -> (int, int, Tensor, Tensor)""", ) def test_while_loop_gen_schema_with_input_mutation(self): @@ -10205,7 +10242,7 @@ def body_fn(x, y, z, c): ) self.assertExpectedInline( str(schema), - """while_loop(Any cond_fn, Any body_fn, Tensor(a2!) carried_input0, Tensor(a3!) carried_input1, Tensor(a4!) carried_input2, Tensor(a5!) additional_input0) -> (Tensor, Tensor, Tensor)""", # noqa: B950 + """while_loop(Any cond_fn, Any body_fn, Tensor(a2!) carried_input0, Tensor(a3!) carried_input1, Tensor(a4!) carried_input2, Tensor(a5!) additional_input0) -> (Tensor, Tensor, Tensor)""", ) @@ -10292,7 +10329,7 @@ def func(pred, x): compiled_func = torch.compile(func, backend="cudagraphs") with self.assertRaisesRegex( RuntimeError, - "RNG within data-dependent conditional nodes is not supported yet", + "RNG op during graph capture but generator is not registered", ): compiled_func(pred, x) diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index f342e7ec9730b..6330640bf951a 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -76,7 +76,6 @@ TEST_CUDA_MEM_LEAK_CHECK, TEST_WITH_TORCHDYNAMO, TestCase, - xfailIfTorchDynamo, ) from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -1152,6 +1151,104 @@ def h(x): (z,) = torch.autograd.grad(y, x) self.assertEqual(z, 2) + @skipIfTorchDynamo("internal API test") + def test_pop_dynamic_layer_stack_to_depth_single(self, device): + ft = torch._C._functorch + ft._grad_increment_nesting() + self.assertEqual(ft.get_dynamic_layer_stack_depth(), 1) + ft.pop_dynamic_layer_stack_and_undo_to_depth(0) + self.assertEqual(ft.get_dynamic_layer_stack_depth(), 0) + + @skipIfTorchDynamo("internal API test") + def test_pop_dynamic_layer_stack_to_depth_mixed(self, device): + ft = torch._C._functorch + ft._vmap_increment_nesting(3, "error") + ft._grad_increment_nesting() + ft._jvp_increment_nesting() + self.assertEqual(ft.get_dynamic_layer_stack_depth(), 3) + # Pop only jvp — must remove exactly one layer + ft.pop_dynamic_layer_stack_and_undo_to_depth(2) + self.assertEqual(ft.get_dynamic_layer_stack_depth(), 2) + # Pop remaining + ft.pop_dynamic_layer_stack_and_undo_to_depth(0) + self.assertEqual(ft.get_dynamic_layer_stack_depth(), 0) + + def test_inference_mode_outside_grad(self, device): + x = torch.randn(3, device=device) + with torch.inference_mode(): + y = grad(lambda x: (x**2).sum())(x) + self.assertEqual(y, 2 * x) + + def test_inference_mode_nograd_outside_grad(self, device): + x = torch.randn(3, device=device) + with torch.inference_mode(): + with torch.no_grad(): + y = grad(lambda x: (x**2).sum())(x) + self.assertEqual(y, 2 * x) + + def test_inference_mode_outside_vjp(self, device): + x = torch.randn(3, device=device) + with torch.inference_mode(): + out, vjp_fn = vjp(lambda x: (x**2).sum(), x) + (y,) = vjp_fn(torch.tensor(1.0, device=device)) + self.assertEqual(y, 2 * x) + + def test_inference_mode_outside_jvp(self, device): + x = torch.randn(3, device=device) + t = torch.ones(3, device=device) + with torch.inference_mode(): + _, y = jvp(lambda x: (x**2).sum(), (x,), (t,)) + self.assertEqual(y, (2 * x * t).sum()) + + def test_inference_mode_outside_jacrev(self, device): + x = torch.randn(3, device=device) + with torch.inference_mode(): + y = jacrev(lambda x: x**2)(x) + self.assertEqual(y, torch.diag(2 * x)) + + def test_inference_mode_outside_vmap_grad(self, device): + xs = torch.randn(5, 3, device=device) + with torch.inference_mode(): + ys = vmap(grad(lambda x: (x**2).sum()))(xs) + self.assertEqual(ys, 2 * xs) + + def test_inference_mode_outside_grad_vmap(self, device): + x = torch.randn(3, device=device) + with torch.inference_mode(): + y = grad(lambda x: vmap(lambda x: (x**2).sum())(x).sum())(x) + self.assertEqual(y, 2 * x) + + def test_inference_mode_nested_grad(self, device): + x = torch.randn([], device=device) + with torch.inference_mode(): + y = grad(grad(lambda x: x**3))(x) + self.assertEqual(y, 6 * x) + + def test_inference_mode_jacrev_grad(self, device): + x = torch.randn(3, device=device) + with torch.inference_mode(): + H = jacrev(grad(lambda x: (x**3).sum()))(x) + self.assertEqual(H, torch.diag(6 * x)) + + def test_inference_mode_inside_grad(self, device): + def f(x): + with torch.inference_mode(): + c = x**2 + return x - c + + x = torch.randn(3, device=device) + with torch.inference_mode(): + y = grad(lambda x: f(x).sum())(x) + self.assertEqual(y, torch.ones_like(x)) + + def test_inference_mode_restored(self, device): + self.assertTrue(not torch.is_inference_mode_enabled()) + with torch.inference_mode(): + self.assertTrue(torch.is_inference_mode_enabled()) + grad(lambda x: (x**2).sum())(torch.randn(3, device=device)) + self.assertTrue(torch.is_inference_mode_enabled()) + self.assertTrue(not torch.is_inference_mode_enabled()) + @markDynamoStrictTest class TestAutogradFunction(TestCase): @@ -2417,7 +2514,6 @@ def f(x): self.assertEqual(actual, expected) # https://github.com/pytorch/pytorch/issues/127036 - @xfailIfTorchDynamo @parametrize("_preallocate_and_copy", (True, False)) def test_chunk_jacrev_chunksize_one(self, device, _preallocate_and_copy): # With chunk_size=1, we shouldn't `vmap` and hence not be limited diff --git a/test/functorch/test_leaf_function.py b/test/functorch/test_leaf_function.py index 7529242d5561d..7422d50d7563b 100644 --- a/test/functorch/test_leaf_function.py +++ b/test/functorch/test_leaf_function.py @@ -1,17 +1,29 @@ # Owner(s): ["oncall: pt2"] -"""Tests for @leaf_function with make_fx and aot_function.""" +"""Tests for @leaf_function with make_fx, aot_function, and torch.compile.""" +import copy +import re from functools import partial +from unittest.mock import patch import torch import torch._dynamo.config as config +import torch._dynamo.testing from functorch.compile import aot_function, nop from torch._dynamo.decorators import leaf_function from torch._dynamo.testing import normalize_gm from torch._higher_order_ops.invoke_leaf_function import invoke_leaf_function from torch.fx.experimental.proxy_tensor import make_fx -from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + skipIfCrossRef, + skipIfTorchDynamo, + TestCase, +) +from torch.testing._internal.dynamo_pytree_test_utils import PytreeRegisteringTestCase def extract_graph(fx_g, _, graph_cell): @@ -54,10 +66,10 @@ def forward(self, x_1: "f32[3, 3]", y_1: "f32[3, 3]"): _opaque_obj0 = self._opaque_obj0 _opaque_obj1 = self._opaque_obj1 _tree_spec_constant0 = self._tree_spec_constant0 - invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(_opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', x_1, y_1, requires_grad_indices = ()); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = x_1 = y_1 = None + invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(_opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', x_1, y_1, requires_grad_indices = ''); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = x_1 = y_1 = None getitem: "f32[3, 3]" = invoke_leaf_function[0]; invoke_leaf_function = None return (getitem,) -""", # noqa: B950 +""", ) x2 = torch.randn(3, 3) @@ -90,10 +102,10 @@ def forward(self, x_1: "f32[3, 3]"): _opaque_obj0 = self._opaque_obj0 _opaque_obj1 = self._opaque_obj1 _tree_spec_constant0 = self._tree_spec_constant0 - invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(_opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', x_1, requires_grad_indices = ()); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = x_1 = None + invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(_opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', x_1, requires_grad_indices = ''); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = x_1 = None getitem: "f32[3, 3]" = invoke_leaf_function[0]; invoke_leaf_function = None return (getitem,) -""", # noqa: B950 +""", ) # Closure change reflected at runtime @@ -150,10 +162,10 @@ def forward(self, x_1: "f32[3, 3]", y_1: "f32[3, 3]"): _opaque_obj0 = self._opaque_obj0 _opaque_obj1 = self._opaque_obj1 _tree_spec_constant0 = self._tree_spec_constant0 - invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(_opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', x_1, y_1, requires_grad_indices = ()); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = x_1 = y_1 = None + invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(_opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', x_1, y_1, requires_grad_indices = ''); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = x_1 = y_1 = None getitem: "f32[3, 3]" = invoke_leaf_function[0]; invoke_leaf_function = None return (getitem,) -""", # noqa: B950 +""", ) x2 = torch.randn(3, 3) @@ -235,10 +247,10 @@ def forward(self, x_1: "f32[3, 3]", y_1: "f32[3, 3]"): _opaque_obj0 = self._opaque_obj0 _opaque_obj1 = self._opaque_obj1 _tree_spec_constant0 = self._tree_spec_constant0 - invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(_opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', x_1, y_1, requires_grad_indices = ()); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = x_1 = y_1 = None + invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(_opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', x_1, y_1, requires_grad_indices = ''); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = x_1 = y_1 = None getitem: "f32[3, 3]" = invoke_leaf_function[0]; invoke_leaf_function = None return (getitem,) -""", # noqa: B950 +""", ) x2 = torch.randn(3, 3) @@ -320,12 +332,12 @@ def forward(self, primals_1: "f32[0]", primals_2: "f32[3, 3]", primals_3: "f32[3 _opaque_obj0 = self._opaque_obj0 _opaque_obj1 = self._opaque_obj1 _tree_spec_constant0 = self._tree_spec_constant0 - with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', primals_2, primals_3, requires_grad_indices = (0, 1)); primals_1 = _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = primals_2 = primals_3 = None + with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', primals_2, primals_3, requires_grad_indices = '0,1'); primals_1 = _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = primals_2 = primals_3 = None getitem: "f32[0]" = with_effects[0] getitem_1: "f32[3, 3]" = with_effects[1]; with_effects = None return (getitem, getitem_1) -""", # noqa: B950 +""", ) self.assertExpectedInline( normalize_gm(bw_graph_cell[0].print_readable(print_output=False)), @@ -335,12 +347,12 @@ def forward(self, tangents_1: "f32[3, 3]", tangents_token: "f32[0]"): _opaque_obj2 = self._opaque_obj2 _opaque_obj3 = self._opaque_obj3 _tree_spec_constant1 = self._tree_spec_constant1 - with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.higher_order.invoke_leaf_function, _opaque_obj2, _opaque_obj3, _tree_spec_constant1, '', tangents_1, requires_grad_indices = ()); tangents_token = _opaque_obj2 = _opaque_obj3 = _tree_spec_constant1 = tangents_1 = None + with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.higher_order.invoke_leaf_function, _opaque_obj2, _opaque_obj3, _tree_spec_constant1, '', tangents_1, requires_grad_indices = ''); tangents_token = _opaque_obj2 = _opaque_obj3 = _tree_spec_constant1 = tangents_1 = None getitem_2: "f32[0]" = with_effects_1[0] getitem_3: "f32[3, 3]" = with_effects_1[1] getitem_4: "f32[3, 3]" = with_effects_1[2]; with_effects_1 = None return (getitem_3, getitem_4, getitem_2) -""", # noqa: B950 +""", ) def test_aot_function_gradients(self): @@ -702,11 +714,11 @@ def forward(self, arg0_1, arg1_1: "f32[3, 3]", arg2_1: "f32[3, 3]"): _opaque_obj0 = self._opaque_obj0 _opaque_obj1 = self._opaque_obj1 _tree_spec_constant0 = self._tree_spec_constant0 - with_effects = torch.ops.higher_order.with_effects(None, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', arg1_1, arg2_1, requires_grad_indices = ()); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = arg1_1 = arg2_1 = None + with_effects = torch.ops.higher_order.with_effects(None, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', arg1_1, arg2_1, requires_grad_indices = ''); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = arg1_1 = arg2_1 = None getitem: "f32[0]" = with_effects[0] getitem_1: "f32[3, 3]" = with_effects[1]; with_effects = None return (getitem, getitem_1) -""", # noqa: B950 +""", ) x2 = torch.randn(3, 3) @@ -752,7 +764,7 @@ def forward(self, x_1: "f32[3, 3]"): _opaque_obj0 = self._opaque_obj0 _opaque_obj1 = self._opaque_obj1 _tree_spec_constant0 = self._tree_spec_constant0 - invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(_opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', getitem, requires_grad_indices = ()); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = getitem = None + invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(_opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', getitem, requires_grad_indices = ''); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = getitem = None getitem_1: "f32[3, 3]" = invoke_leaf_function[0]; invoke_leaf_function = None return (getitem_1,) @@ -760,7 +772,7 @@ class repeated_subgraph0(torch.nn.Module): def forward(self, arg0_1: "f32[3, 3]"): add: "f32[3, 3]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None return (add,) -""", # noqa: B950 +""", ) x2 = torch.randn(3, 3) @@ -806,7 +818,7 @@ def forward(self, x_1: "f32[3, 3]"): _opaque_obj0 = self._opaque_obj0 _opaque_obj1 = self._opaque_obj1 _tree_spec_constant0 = self._tree_spec_constant0 - invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(_opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', getitem_1, requires_grad_indices = ()); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = getitem_1 = None + invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(_opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', getitem_1, requires_grad_indices = ''); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = getitem_1 = None getitem_2: "f32[3, 3]" = invoke_leaf_function[0]; invoke_leaf_function = None return (getitem_2,) @@ -815,13 +827,1864 @@ def forward(self, arg0_1, arg1_1: "f32[3, 3]"): _opaque_obj0 = self._opaque_obj0 _opaque_obj1 = self._opaque_obj1 _tree_spec_constant0 = self._tree_spec_constant0 - with_effects = torch.ops.higher_order.with_effects(None, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', arg1_1, requires_grad_indices = ()); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = arg1_1 = None + with_effects = torch.ops.higher_order.with_effects(None, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', arg1_1, requires_grad_indices = ''); _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = arg1_1 = None getitem: "f32[0]" = with_effects[0] getitem_1: "f32[3, 3]" = with_effects[1]; with_effects = None return (getitem, getitem_1) -""", # noqa: B950 +""", ) +@skipIfTorchDynamo("leaf_function tests manage their own compilation") +class TestLeafFunctionDynamo(PytreeRegisteringTestCase): + def _assert_models_equal( + self, + model_expected, + model_test, + x_expected, + x_test, + ): + out_expected = model_expected(x_expected) + out_test = model_test(x_test) + self.assertEqual(out_expected, out_test) + + loss_expected = out_expected.sum() + loss_test = out_test.sum() + loss_expected.backward() + loss_test.backward() + self.assertEqual(x_expected.grad, x_test.grad) + + expected_grads = { + name: param.grad for name, param in model_expected.named_parameters() + } + test_grads = {name: param.grad for name, param in model_test.named_parameters()} + + self.assertEqual(set(expected_grads.keys()), set(test_grads.keys())) + for name in expected_grads: + if expected_grads[name] is not None: + self.assertEqual( + expected_grads[name], + test_grads[name], + msg=f"Gradient mismatch for parameter {name}", + ) + + def _test_leaf_function_helper(self, mod_class, args_fn, loss_fn): + import torch.utils._pytree as pytree + from torch._dynamo.testing import AotEagerAndRecordGraphs, EagerAndRecordGraphs + + mod_eager = mod_class() + mod_compile_eager = mod_class() + mod_compile_eager.load_state_dict(dict(mod_eager.state_dict())) + mod_compile_aot = mod_class() + mod_compile_aot.load_state_dict(dict(mod_eager.state_dict())) + + eager_backend = EagerAndRecordGraphs() + compiled_eager = torch.compile( + mod_compile_eager, backend=eager_backend, fullgraph=True + ) + + backend = AotEagerAndRecordGraphs() + compiled_aot = torch.compile(mod_compile_aot, backend=backend, fullgraph=True) + + for _ in range(2): + mod_eager.zero_grad() + mod_compile_eager.zero_grad() + mod_compile_aot.zero_grad() + + args = args_fn() + args_clone = pytree.tree_map( + lambda x: x.clone().detach().requires_grad_(x.requires_grad), args + ) + args_clone2 = pytree.tree_map( + lambda x: x.clone().detach().requires_grad_(x.requires_grad), args + ) + + out_eager = mod_eager(*args) + loss_fn(out_eager).backward() + + out_compile_eager = compiled_eager(*args_clone) + loss_fn(out_compile_eager).backward() + + out_compile_aot = compiled_aot(*args_clone2) + loss_fn(out_compile_aot).backward() + + self.assertEqual(out_eager, out_compile_eager) + self.assertEqual(out_eager, out_compile_aot) + + for (name_eager, param_eager), (_, param_compile_eager), ( + _, + param_compile_aot, + ) in zip( + mod_eager.named_parameters(), + mod_compile_eager.named_parameters(), + mod_compile_aot.named_parameters(), + ): + self.assertEqual( + param_eager.grad, + param_compile_eager.grad, + msg=f"Gradient mismatch for {name_eager} between eager and compile_eager", + ) + self.assertEqual( + param_eager.grad, + param_compile_aot.grad, + msg=f"Gradient mismatch for {name_eager} between eager and compile_aot", + ) + + pytree.tree_map( + lambda x, compile_x: self.assertEqual(x.grad, compile_x.grad) + if isinstance(x, torch.Tensor) and x.requires_grad + else None, + args, + args_clone, + ) + pytree.tree_map( + lambda x, compile_x: self.assertEqual(x.grad, compile_x.grad) + if isinstance(x, torch.Tensor) and x.requires_grad + else None, + args, + args_clone2, + ) + + def _normalize(gm): + s = normalize_gm(gm.print_readable(print_output=False)) + # Normalize nn_module_index which varies depending on whether an + # accelerator is available (stream reserves index 0). + return re.sub(r"'', \d+, ", "'', 0, ", s) + + return ( + _normalize(eager_backend.graphs[0]), + _normalize(backend.fw_graphs[0]), + _normalize(backend.bw_graphs[0]), + ) + + def test_leaf_function_simple(self): + @leaf_function + def non_tracable_forward(mod, x): + if x.sum() > 0: + return (mod.linear(x),) + else: + return (mod.linear(x) + x,) + + @non_tracable_forward.register_fake + def non_tracable_forward_fake(mod, x): + return (mod.linear(x),) + + class NonTracable(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return non_tracable_forward(self, x) + + def args_fn(): + return (torch.randn(3, 3, requires_grad=True),) + + def loss_fn(out): + return out[0].sum() + + dynamo_graph_str, fw_graph_str, bw_graph_str = self._test_leaf_function_helper( + NonTracable, args_fn, loss_fn + ) + self.assertExpectedInline( + dynamo_graph_str, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 3]", L_self_modules_linear_parameters_weight_: "f32[3, 3]", L_self_modules_linear_parameters_bias_: "f32[3]"): + l_x_ = L_x_ + l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_ + l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_ + + real_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.real_fn + fake_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.fake_fn + input_spec : torch.utils._pytree.TreeSpec = self.input_spec + invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(real_fn, fake_fn, input_spec, '', 0, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_, l_x_); real_fn = fake_fn = input_spec = l_self_modules_linear_parameters_weight_ = l_self_modules_linear_parameters_bias_ = l_x_ = None + getitem: "f32[3, 3]" = invoke_leaf_function[0]; invoke_leaf_function = None + return (getitem,) +""", + ) + self.assertExpectedInline( + fw_graph_str, + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[0]", primals_2: "f32[3, 3]", primals_3: "f32[3, 3]", primals_4: "f32[3]"): + _opaque_obj0 = self._opaque_obj0 + _opaque_obj1 = self._opaque_obj1 + _tree_spec_constant0 = self._tree_spec_constant0 + with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', 0, primals_3, primals_4, primals_2, requires_grad_indices = '1,2,3'); primals_1 = _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = primals_3 = primals_4 = primals_2 = None + + getitem: "f32[0]" = with_effects[0] + getitem_1: "f32[3, 3]" = with_effects[1]; with_effects = None + return (getitem, getitem_1) +""", + ) + self.assertExpectedInline( + bw_graph_str, + """\ +class GraphModule(torch.nn.Module): + def forward(self, tangents_1: "f32[3, 3]", tangents_token: "f32[0]"): + _opaque_obj2 = self._opaque_obj2 + _opaque_obj3 = self._opaque_obj3 + _tree_spec_constant1 = self._tree_spec_constant1 + with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.higher_order.invoke_leaf_function, _opaque_obj2, _opaque_obj3, _tree_spec_constant1, '', tangents_1, requires_grad_indices = ''); tangents_token = _opaque_obj2 = _opaque_obj3 = _tree_spec_constant1 = tangents_1 = None + getitem_2: "f32[0]" = with_effects_1[0] + getitem_4: "f32[3, 3]" = with_effects_1[2] + getitem_5: "f32[3]" = with_effects_1[3] + getitem_6: "f32[3, 3]" = with_effects_1[4]; with_effects_1 = None + return (getitem_6, getitem_4, getitem_5, getitem_2) +""", + ) + + def test_leaf_function_with_logging(self): + @leaf_function + def logging_forward(mod, x): + print("Processing input") + return (mod.linear(x),) + + @logging_forward.register_fake + def logging_forward_fake(mod, x): + return (mod.linear(x),) + + class LoggingModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return logging_forward(self, x) + + def args_fn(): + return (torch.randn(3, 3, requires_grad=True),) + + def loss_fn(out): + return out[0].sum() + + with patch("builtins.print") as mock_print: + self._test_leaf_function_helper(LoggingModule, args_fn, loss_fn) + mock_print.assert_any_call("Processing input") + self.assertEqual(mock_print.call_count, 6) + + def test_leaf_function_dynamic_autograd_module_config(self): + from torch._dynamo.testing import CompileCounterWithBackend + + @leaf_function + def configurable_scale(mod, x): + if mod.use_double_scale: + return (mod.linear(x) * 2,) + else: + return (mod.linear(x) * 3,) + + @configurable_scale.register_fake + def configurable_scale_fake(mod, x): + return (mod.linear(x),) + + class ConfigurableModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + self.use_double_scale = True + + def forward(self, x): + return configurable_scale(self, x) + + mod_eager = ConfigurableModule() + mod_compiled = ConfigurableModule() + mod_compiled.load_state_dict(dict(mod_eager.state_dict())) + + counter = CompileCounterWithBackend("aot_eager") + compiled_fn = torch.compile(mod_compiled, backend=counter, fullgraph=True) + + x_value = torch.randn(3, 3) + + mod_eager.use_double_scale = True + mod_compiled.use_double_scale = True + + x1 = x_value.clone().requires_grad_(True) + x1_clone = x_value.clone().requires_grad_(True) + + out_eager_1 = mod_eager(x1) + out_eager_1[0].sum().backward() + + out_compiled_1 = compiled_fn(x1_clone) + out_compiled_1[0].sum().backward() + + self.assertEqual(out_eager_1, out_compiled_1) + self.assertEqual(x1.grad, x1_clone.grad) + + mod_eager.zero_grad() + mod_compiled.zero_grad() + + mod_eager.use_double_scale = False + mod_compiled.use_double_scale = False + + x2 = x_value.clone().requires_grad_(True) + x2_clone = x_value.clone().requires_grad_(True) + + out_eager_2 = mod_eager(x2) + out_eager_2[0].sum().backward() + + out_compiled_2 = compiled_fn(x2_clone) + out_compiled_2[0].sum().backward() + + self.assertEqual(out_eager_2, out_compiled_2) + self.assertEqual(x2.grad, x2_clone.grad) + + self.assertNotEqual(x1.grad, x2.grad) + + self.assertEqual(counter.frame_count, 1) + + def test_leaf_function_dynamic_autograd_closure(self): + from torch._dynamo.testing import CompileCounterWithBackend + + closure_config = {"use_double_scale": True} + + @leaf_function + def configurable_scale(x, y): + if closure_config["use_double_scale"]: + return (x @ y * 2,) + else: + return (x @ y * 3,) + + @configurable_scale.register_fake + def configurable_scale_fake(x, y): + return (x @ y,) + + def fn(x, y): + return configurable_scale(x, y) + + counter = CompileCounterWithBackend("aot_eager") + compiled_fn = torch.compile(fn, backend=counter, fullgraph=True) + + x_value = torch.randn(3, 3) + y_value = torch.randn(3, 3) + + closure_config["use_double_scale"] = True + + x1 = x_value.clone().requires_grad_(True) + y1 = y_value.clone().requires_grad_(True) + x1_clone = x_value.clone().requires_grad_(True) + y1_clone = y_value.clone().requires_grad_(True) + + out_eager_1 = fn(x1, y1) + out_eager_1[0].sum().backward() + + out_compiled_1 = compiled_fn(x1_clone, y1_clone) + out_compiled_1[0].sum().backward() + + self.assertEqual(out_eager_1, out_compiled_1) + self.assertEqual(x1.grad, x1_clone.grad) + self.assertEqual(y1.grad, y1_clone.grad) + + closure_config["use_double_scale"] = False + + x2 = x_value.clone().requires_grad_(True) + y2 = y_value.clone().requires_grad_(True) + x2_clone = x_value.clone().requires_grad_(True) + y2_clone = y_value.clone().requires_grad_(True) + + out_eager_2 = fn(x2, y2) + out_eager_2[0].sum().backward() + + out_compiled_2 = compiled_fn(x2_clone, y2_clone) + out_compiled_2[0].sum().backward() + + self.assertEqual(out_eager_2, out_compiled_2) + self.assertEqual(x2.grad, x2_clone.grad) + self.assertEqual(y2.grad, y2_clone.grad) + + self.assertNotEqual(x1.grad, x2.grad) + self.assertNotEqual(y1.grad, y2.grad) + + self.assertEqual(counter.frame_count, 1) + + def test_leaf_function_closure_constants_without_grad(self): + closure_scale = 2.0 + closure_tensor = torch.tensor([1.0, 2.0, 3.0]) + + @leaf_function + def closure_forward(mod, x): + out = mod.linear(x) * closure_scale * mod.scale + out = out + closure_tensor + mod.offset + return (out,) + + @closure_forward.register_fake + def closure_forward_fake(mod, x): + return (mod.linear(x) + mod.offset,) + + class ClosureModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + self.scale = 3.0 + self.offset = torch.nn.Parameter(torch.ones(3)) + + def forward(self, x): + return closure_forward(self, x) + + def args_fn(): + return (torch.randn(3, 3, requires_grad=True),) + + def loss_fn(out): + return out[0].sum() + + dynamo_graph_str, fw_graph_str, bw_graph_str = self._test_leaf_function_helper( + ClosureModule, args_fn, loss_fn + ) + self.assertExpectedInline( + dynamo_graph_str, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 3]", L_self_parameters_offset_: "f32[3]", L_self_modules_linear_parameters_weight_: "f32[3, 3]", L_self_modules_linear_parameters_bias_: "f32[3]"): + l_x_ = L_x_ + l_self_parameters_offset_ = L_self_parameters_offset_ + l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_ + l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_ + + real_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.real_fn + fake_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.fake_fn + input_spec : torch.utils._pytree.TreeSpec = self.input_spec + invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(real_fn, fake_fn, input_spec, '', 0, l_self_parameters_offset_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_, l_x_); real_fn = fake_fn = input_spec = l_self_parameters_offset_ = l_self_modules_linear_parameters_weight_ = l_self_modules_linear_parameters_bias_ = l_x_ = None + getitem: "f32[3, 3]" = invoke_leaf_function[0]; invoke_leaf_function = None + return (getitem,) +""", + ) + self.assertExpectedInline( + fw_graph_str, + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[0]", primals_2: "f32[3, 3]", primals_3: "f32[3]", primals_4: "f32[3, 3]", primals_5: "f32[3]"): + _opaque_obj0 = self._opaque_obj0 + _opaque_obj1 = self._opaque_obj1 + _tree_spec_constant0 = self._tree_spec_constant0 + with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', 0, primals_3, primals_4, primals_5, primals_2, requires_grad_indices = '1,2,3,4'); primals_1 = _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = primals_3 = primals_4 = primals_5 = primals_2 = None + + getitem: "f32[0]" = with_effects[0] + getitem_1: "f32[3, 3]" = with_effects[1]; with_effects = None + return (getitem, getitem_1) +""", + ) + self.assertExpectedInline( + bw_graph_str, + """\ +class GraphModule(torch.nn.Module): + def forward(self, tangents_1: "f32[3, 3]", tangents_token: "f32[0]"): + _opaque_obj2 = self._opaque_obj2 + _opaque_obj3 = self._opaque_obj3 + _tree_spec_constant1 = self._tree_spec_constant1 + with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.higher_order.invoke_leaf_function, _opaque_obj2, _opaque_obj3, _tree_spec_constant1, '', tangents_1, requires_grad_indices = ''); tangents_token = _opaque_obj2 = _opaque_obj3 = _tree_spec_constant1 = tangents_1 = None + getitem_2: "f32[0]" = with_effects_1[0] + getitem_4: "f32[3]" = with_effects_1[2] + getitem_5: "f32[3, 3]" = with_effects_1[3] + getitem_6: "f32[3]" = with_effects_1[4] + getitem_7: "f32[3, 3]" = with_effects_1[5]; with_effects_1 = None + return (getitem_7, getitem_4, getitem_5, getitem_6, getitem_2) +""", + ) + + def test_leaf_function_pytree_inputs(self): + @leaf_function + def pytree_forward(mod, inputs): + if inputs["x"].sum() > 0: + return (mod.linear(inputs["x"]), inputs["y"] + 1) + return (mod.linear(inputs["x"]) + inputs["y"], inputs["y"] - 1) + + @pytree_forward.register_fake + def pytree_forward_fake(mod, inputs): + return (mod.linear(inputs["x"]), inputs["y"]) + + class PytreeModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, inputs): + return pytree_forward(self, inputs) + + def args_fn(): + return ( + { + "x": torch.randn(3, 3, requires_grad=True), + "y": torch.randn(3, 3, requires_grad=True), + }, + ) + + def loss_fn(out): + return out[0].sum() + out[1].sum() + + self._test_leaf_function_helper(PytreeModule, args_fn, loss_fn) + + def test_leaf_function_nested_annotations(self): + @leaf_function + def inner_leaf_forward(mod, x): + y = mod.linear(x) + return (y + x,) + + @inner_leaf_forward.register_fake + def inner_leaf_forward_fake(mod, x): + return (mod.linear(x),) + + class InnerLeaf(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return inner_leaf_forward(self, x) + + @leaf_function + def outer_leaf_forward(mod, x): + z = mod.linear(x) + return mod.inner(z + x) + + @outer_leaf_forward.register_fake + def outer_leaf_forward_fake(mod, x): + return mod.inner(mod.linear(x)) + + class OuterLeaf(torch.nn.Module): + def __init__(self): + super().__init__() + self.inner = InnerLeaf() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return outer_leaf_forward(self, x) + + def args_fn(): + return (torch.randn(3, 3, requires_grad=True),) + + def loss_fn(out): + return out[0].sum() + + dynamo_graph_str, fw_graph_str, bw_graph_str = self._test_leaf_function_helper( + OuterLeaf, args_fn, loss_fn + ) + self.assertExpectedInline( + dynamo_graph_str, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 3]", L_self_modules_inner_modules_linear_parameters_weight_: "f32[3, 3]", L_self_modules_inner_modules_linear_parameters_bias_: "f32[3]", L_self_modules_linear_parameters_weight_: "f32[3, 3]", L_self_modules_linear_parameters_bias_: "f32[3]"): + l_x_ = L_x_ + l_self_modules_inner_modules_linear_parameters_weight_ = L_self_modules_inner_modules_linear_parameters_weight_ + l_self_modules_inner_modules_linear_parameters_bias_ = L_self_modules_inner_modules_linear_parameters_bias_ + l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_ + l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_ + + real_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.real_fn + fake_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.fake_fn + input_spec : torch.utils._pytree.TreeSpec = self.input_spec + invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(real_fn, fake_fn, input_spec, '', 0, l_self_modules_inner_modules_linear_parameters_weight_, l_self_modules_inner_modules_linear_parameters_bias_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_, l_x_); real_fn = fake_fn = input_spec = l_self_modules_inner_modules_linear_parameters_weight_ = l_self_modules_inner_modules_linear_parameters_bias_ = l_self_modules_linear_parameters_weight_ = l_self_modules_linear_parameters_bias_ = l_x_ = None + getitem: "f32[3, 3]" = invoke_leaf_function[0]; invoke_leaf_function = None + return (getitem,) +""", + ) + self.assertExpectedInline( + fw_graph_str, + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[0]", primals_2: "f32[3, 3]", primals_3: "f32[3, 3]", primals_4: "f32[3]", primals_5: "f32[3, 3]", primals_6: "f32[3]"): + _opaque_obj0 = self._opaque_obj0 + _opaque_obj1 = self._opaque_obj1 + _tree_spec_constant0 = self._tree_spec_constant0 + with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', 0, primals_3, primals_4, primals_5, primals_6, primals_2, requires_grad_indices = '1,2,3,4,5'); primals_1 = _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = primals_3 = primals_4 = primals_5 = primals_6 = primals_2 = None + + getitem: "f32[0]" = with_effects[0] + getitem_1: "f32[3, 3]" = with_effects[1]; with_effects = None + return (getitem, getitem_1) +""", + ) + self.assertExpectedInline( + bw_graph_str, + """\ +class GraphModule(torch.nn.Module): + def forward(self, tangents_1: "f32[3, 3]", tangents_token: "f32[0]"): + _opaque_obj2 = self._opaque_obj2 + _opaque_obj3 = self._opaque_obj3 + _tree_spec_constant1 = self._tree_spec_constant1 + with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.higher_order.invoke_leaf_function, _opaque_obj2, _opaque_obj3, _tree_spec_constant1, '', tangents_1, requires_grad_indices = ''); tangents_token = _opaque_obj2 = _opaque_obj3 = _tree_spec_constant1 = tangents_1 = None + getitem_2: "f32[0]" = with_effects_1[0] + getitem_4: "f32[3, 3]" = with_effects_1[2] + getitem_5: "f32[3]" = with_effects_1[3] + getitem_6: "f32[3, 3]" = with_effects_1[4] + getitem_7: "f32[3]" = with_effects_1[5] + getitem_8: "f32[3, 3]" = with_effects_1[6]; with_effects_1 = None + return (getitem_8, getitem_4, getitem_5, getitem_6, getitem_7, getitem_2) +""", + ) + + def test_leaf_function_data_dependent_nonzero(self): + @leaf_function + def nonzero_forward(mod, x): + out = mod.linear(x) + nonzero_indices = (out > 0).nonzero() + return (out, nonzero_indices) + + @nonzero_forward.register_fake + def nonzero_forward_fake(mod, x): + out = mod.linear(x) + return out, (out > 0).nonzero() + + class NonzeroModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return nonzero_forward(self, x) + + class OuterModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pre_linear = torch.nn.Linear(3, 3) + self.nonzero_module = NonzeroModule() + self.scale = torch.nn.Parameter(torch.tensor(2.0)) + + def forward(self, x): + x = self.pre_linear(x) + x = torch.relu(x) + out, nonzero_indices = self.nonzero_module(x) + num_nonzero = nonzero_indices.shape[0] + scaled_out = out * self.scale + num_nonzero + return scaled_out, nonzero_indices + + def args_fn(): + return (torch.randn(3, 3, requires_grad=True),) + + def loss_fn(out): + return out[0].sum() + + self._test_leaf_function_helper(OuterModule, args_fn, loss_fn) + + @skipIfCrossRef + def test_leaf_function_data_dependent_item(self): + @leaf_function + def item_forward(mod, x): + out = mod.linear(x) + scalar_value = out.sum().item() + return (out, scalar_value) + + @item_forward.register_fake + def item_forward_fake(mod, x): + out = mod.linear(x) + return (out, out.sum().item()) + + class ItemModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return item_forward(self, x) + + def args_fn(): + return (torch.randn(3, 3, requires_grad=True),) + + def loss_fn(out): + return out[0].sum() + + self._test_leaf_function_helper(ItemModule, args_fn, loss_fn) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_leaf_function_multiple_compiled_submodules(self, backend): + @leaf_function + def leaf_forward(mod, x): + if x.sum() > 0: + return (mod.linear(x),) + else: + return (mod.linear(x) + x,) + + @leaf_forward.register_fake + def leaf_forward_fake(mod, x): + return (mod.linear(x),) + + class LeafModule(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + return leaf_forward(self, x) + + class CompiledSubmodule1(torch.nn.Module): + def __init__(self): + super().__init__() + self.pre_linear = torch.nn.Linear(4, 4) + self.leaf = LeafModule(4, 4) + + def forward(self, x): + x = self.pre_linear(x) + x = torch.relu(x) + out = self.leaf(x)[0] + return out + + class CompiledSubmodule2(torch.nn.Module): + def __init__(self): + super().__init__() + self.leaf = LeafModule(4, 4) + self.post_linear = torch.nn.Linear(4, 4) + + def forward(self, x): + out = self.leaf(x)[0] + out = self.post_linear(out) + return torch.sigmoid(out) + + class CompiledSubmodule3(torch.nn.Module): + def __init__(self): + super().__init__() + self.leaf1 = LeafModule(4, 4) + self.leaf2 = LeafModule(4, 4) + + def forward(self, x): + out1 = self.leaf1(x)[0] + out2 = self.leaf2(x)[0] + return out1 + out2 + + class TopLevelModule(torch.nn.Module): + def __init__(self, compile_submodules=False): + super().__init__() + self.submodule1 = CompiledSubmodule1() + self.submodule2 = CompiledSubmodule2() + self.submodule3 = CompiledSubmodule3() + self.final_linear = torch.nn.Linear(4, 4) + self.compile_submodules = compile_submodules + + def forward(self, x): + if self.compile_submodules: + out1 = torch.compile(self.submodule1, backend=backend)(x) + out2 = torch.compile(self.submodule2, backend=backend)(out1) + out3 = torch.compile(self.submodule3, backend=backend)(out2) + else: + out1 = self.submodule1(x) + out2 = self.submodule2(out1) + out3 = self.submodule3(out2) + final = self.final_linear(out3) + return final + + model_eager = TopLevelModule(compile_submodules=False) + model_compiled = TopLevelModule(compile_submodules=True) + model_compiled.load_state_dict(model_eager.state_dict()) + + x = torch.randn(2, 4, requires_grad=True) + x_compiled = x.clone().detach().requires_grad_(True) + + self._assert_models_equal( + model_eager, + model_compiled, + x, + x_compiled, + ) + + @parametrize("backend", ["eager", "aot_eager"]) + @parametrize("do_compile", [False, True]) + def test_leaf_function_with_graph_breaks(self, backend, do_compile): + @leaf_function + def leaf_forward(mod, x): + if x.sum() > 0: + return (mod.linear(x),) + else: + return (mod.linear(x) + 1,) + + @leaf_forward.register_fake + def leaf_forward_fake(mod, x): + return (mod.linear(x),) + + class LeafModule(torch.nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x): + return leaf_forward(self, x) + + class TopLevelModule(torch.nn.Module): + def __init__(self, do_compile=False, backend="eager"): + super().__init__() + self.leaf1 = LeafModule(4, 4) + self.leaf2 = LeafModule(4, 4) + self.leaf3 = LeafModule(4, 4) + self.final_linear = torch.nn.Linear(4, 4) + self.do_compile = do_compile + self.backend = backend + + def _forward(self, x): + out1 = self.leaf1(x)[0] + torch._dynamo.graph_break() + out2 = self.leaf2(out1)[0] + torch._dynamo.graph_break() + out3 = self.leaf3(out2)[0] + result = self.final_linear(out3) + return result + + def forward(self, x): + if self.do_compile: + return torch.compile( + self._forward, backend=self.backend, fullgraph=False + )(x) + else: + return self._forward(x) + + model_eager = TopLevelModule(do_compile=False) + model_test = TopLevelModule(do_compile=do_compile, backend=backend) + model_test.load_state_dict(model_eager.state_dict()) + + x = torch.randn(2, 4, requires_grad=True) + x_test = x.clone().detach().requires_grad_(True) + + self._assert_models_equal(model_eager, model_test, x, x_test) + + def test_leaf_function_with_module_in_pytree(self): + @leaf_function + def main_forward(modules_dict, x): + if x.sum() > 0: + return (modules_dict["first"](x) + modules_dict["second"](x),) + else: + return (modules_dict["first"](x) - modules_dict["second"](x),) + + @main_forward.register_fake + def main_forward_fake(modules_dict, x): + return (modules_dict["first"](x) + modules_dict["second"](x),) + + class HelperModule(torch.nn.Module): + def __init__(self, scale=1.0): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + self.scale = scale + + def forward(self, x): + return self.linear(x) * self.scale + + class WrapperModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.helper1 = HelperModule(scale=1.0) + self.helper2 = HelperModule(scale=0.5) + + def forward(self, x): + modules_dict = {"first": self.helper1, "second": self.helper2} + return main_forward(modules_dict, x) + + def args_fn(): + return (torch.randn(3, 3, requires_grad=True),) + + def loss_fn(out): + return out[0].sum() + + self._test_leaf_function_helper(WrapperModule, args_fn, loss_fn) + + def test_leaf_function_with_module_as_kwarg(self): + @leaf_function + def main_forward(x, helper_mod=None): + if x.sum() > 0: + return (helper_mod(x),) + else: + return (helper_mod(x) + x,) + + @main_forward.register_fake + def main_forward_fake(x, helper_mod=None): + return (helper_mod(x),) + + class HelperModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return self.linear(x) + + class WrapperModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.helper = HelperModule() + + def forward(self, x): + return main_forward(x, helper_mod=self.helper) + + def args_fn(): + return (torch.randn(3, 3, requires_grad=True),) + + def loss_fn(out): + return out[0].sum() + + self._test_leaf_function_helper(WrapperModule, args_fn, loss_fn) + + def test_leaf_function_missing_fake_impl_error(self): + @leaf_function + def no_fake_impl_forward(mod, x): + return (mod.linear(x),) + + class SimpleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return no_fake_impl_forward(self, x) + + mod = SimpleModule() + x = torch.randn(3, 3) + + with self.assertRaisesRegex(Exception, "requires a fake implementation"): + mod(x) + + compiled_mod = torch.compile(mod, backend="eager", fullgraph=True) + with self.assertRaisesRegex(Exception, "requires a fake implementation"): + compiled_mod(x) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_leaf_function_constant_tensor_closure_error(self, backend): + constant_weight = torch.randn(3, 3) + + @leaf_function + def constant_closure_forward(x): + return (x @ constant_weight,) + + @constant_closure_forward.register_fake + def constant_closure_forward_fake(x): + return (x @ constant_weight,) + + class ConstantClosureModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return constant_closure_forward(x) + + mod = ConstantClosureModule() + x = torch.randn(3, 3, requires_grad=True) + + result = mod(x) + expected = x @ constant_weight + self.assertEqual(result[0], expected) + + compiled_mod = torch.compile(mod, backend=backend, fullgraph=True) + with self.assertRaisesRegex( + Exception, "Please convert all Tensors to FakeTensors" + ): + compiled_mod(x) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_leaf_function_input_mutation_error(self, backend): + @leaf_function + def mutate_input(x): + x.add_(1) + return (x,) + + @mutate_input.register_fake + def mutate_input_fake(x): + x.add_(1) + return (x,) + + def fn(x): + return mutate_input(x) + + x = torch.randn(3, 3) + + x_eager = x.clone() + with self.assertRaisesRegex(RuntimeError, "Undeclared in-place mutation"): + fn(x_eager) + + x = torch.randn(3, 3, requires_grad=True) + + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + with self.assertRaisesRegex(RuntimeError, "leaf Variable that requires grad"): + compiled_fn(x.clone().requires_grad_(True)) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_leaf_function_validation_dtype_mismatch(self, backend): + @leaf_function + def dtype_mismatch_forward(mod, x): + return (mod.linear(x),) + + @dtype_mismatch_forward.register_fake + def dtype_mismatch_forward_fake(mod, x): + return (mod.linear(x).double(),) + + class DtypeMismatchModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return dtype_mismatch_forward(self, x) + + mod = DtypeMismatchModule() + x = torch.randn(3, 3) + + with config.patch(leaf_function_validate_outputs=True): + compiled_mod = torch.compile(mod, backend=backend) + with self.assertRaisesRegex(RuntimeError, "Dtype mismatch"): + compiled_mod(x) + + @parametrize("backend", ["eager", "aot_eager"]) + @parametrize("validate_outputs", [True, False]) + def test_leaf_function_validation_shape_mismatch(self, backend, validate_outputs): + @leaf_function + def mismatched_forward(mod, x): + return (mod.linear(x),) + + @mismatched_forward.register_fake + def mismatched_forward_fake(mod, x): + return (torch.zeros(x.shape[0], 6),) + + class MismatchedModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return mismatched_forward(self, x) + + mod = MismatchedModule() + x = torch.randn(3, 3) + + with config.patch(leaf_function_validate_outputs=validate_outputs): + compiled_mod = torch.compile(mod, backend=backend) + if validate_outputs: + with self.assertRaises((RuntimeError, AssertionError)): + compiled_mod(x) + else: + result = compiled_mod(x) + self.assertEqual(result[0].shape, (3, 3)) + + def test_leaf_function_no_module_inputs(self): + @leaf_function + def my_custom_fn(inputs: dict[str, torch.Tensor], scale: float, offset: int): + x = inputs["x"] + y = inputs["y"] + if x.sum() > 0: + return (x * scale + y + offset, x.sum() + y.sum()) + return (x * scale - y + offset, x.sum() - y.sum()) + + @my_custom_fn.register_fake + def my_custom_fn_fake( + inputs: dict[str, torch.Tensor], scale: float, offset: int + ): + x = inputs["x"] + y = inputs["y"] + return (x * scale + y + offset, x.sum() + y.sum()) + + class NoModuleInputsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.scale = 2.0 + self.offset = 1 + + def forward(self, x, y): + inputs = {"x": x, "y": y} + return my_custom_fn(inputs, self.scale, self.offset) + + def args_fn(): + return ( + torch.randn(3, 3, requires_grad=True), + torch.randn(3, 3, requires_grad=True), + ) + + def loss_fn(out): + return out[0].sum() + out[1].sum() + + self._test_leaf_function_helper(NoModuleInputsModule, args_fn, loss_fn) + + @parametrize("backend", ["eager", "aot_eager"]) + @parametrize("check_escaped_gradients", [True, False]) + def test_leaf_function_escaped_gradient_multiple_tensors( + self, backend, check_escaped_gradients + ): + weight1 = torch.randn(3, 3, requires_grad=True) + weight2 = torch.randn(3, 3, requires_grad=True) + + @leaf_function + def uses_multiple_closures(x): + return (x @ weight1 + x @ weight2,) + + @uses_multiple_closures.register_fake + def uses_multiple_closures_fake(x): + return (torch.empty(x.shape[0], 3),) + + def fn(x): + return uses_multiple_closures(x) + + x = torch.randn(2, 3, requires_grad=True) + + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + with config.patch( + leaf_function_check_escaped_gradients=check_escaped_gradients + ): + if check_escaped_gradients: + with self.assertRaisesRegex(RuntimeError, "2 tensor"): + compiled_fn(x) + else: + result = compiled_fn(x) + self.assertEqual(result[0].shape, (2, 3)) + + @parametrize("backend", ["eager", "aot_eager"]) + @parametrize("check_escaped_gradients", [True, False]) + def test_leaf_function_escaped_gradient_input_no_grad( + self, backend, check_escaped_gradients + ): + closure_weight = torch.randn(3, 3, requires_grad=True) + + @leaf_function + def uses_closure(x): + return (x @ closure_weight,) + + @uses_closure.register_fake + def uses_closure_fake(x): + return (torch.empty(x.shape[0], 3),) + + def fn(x): + return uses_closure(x) + + x = torch.randn(2, 3, requires_grad=False) + + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + with config.patch( + leaf_function_check_escaped_gradients=check_escaped_gradients + ): + result = compiled_fn(x) + self.assertEqual(result[0].shape, (2, 3)) + + @parametrize("backend", ["eager", "aot_eager"]) + @parametrize("check_escaped_gradients", [True, False]) + def test_leaf_function_escaped_gradient_mixed_inputs( + self, backend, check_escaped_gradients + ): + base1 = torch.randn(3, 3, requires_grad=True) + base2 = torch.randn(3, 4, requires_grad=True) + closure_weight1 = base1 * 2 + closure_weight2 = base2 * 3 + + @leaf_function + def mixed_inputs(x, y): + out1 = x @ closure_weight1 + y + out2 = x @ closure_weight2 + return (out1, out2) + + @mixed_inputs.register_fake + def mixed_inputs_fake(x, y): + return (torch.empty(x.shape[0], 3), torch.empty(x.shape[0], 4)) + + def fn(x, y): + return mixed_inputs(x, y) + + x = torch.randn(2, 3, requires_grad=True) + y = torch.randn(2, 3, requires_grad=False) + + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + with config.patch( + leaf_function_check_escaped_gradients=check_escaped_gradients + ): + if check_escaped_gradients: + with self.assertRaisesRegex(RuntimeError, "2 tensor"): + compiled_fn(x, y) + else: + result = compiled_fn(x, y) + self.assertEqual(result[0].shape, (2, 3)) + self.assertEqual(result[1].shape, (2, 4)) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_leaf_function_escaped_gradient_error_message_contains_tensor_info( + self, backend + ): + closure_weight = torch.randn(4, 5, dtype=torch.float32, requires_grad=True) + + @leaf_function + def uses_closure(x): + return (x @ closure_weight,) + + @uses_closure.register_fake + def uses_closure_fake(x): + return (torch.empty(x.shape[0], 5),) + + def fn(x): + return uses_closure(x) + + x = torch.randn(2, 4, requires_grad=True) + + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + with config.patch(leaf_function_check_escaped_gradients=True): + with self.assertRaisesRegex(RuntimeError, r"shape=\[4, 5\].*dtype="): + compiled_fn(x) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_leaf_function_escaped_gradient_actually_lost(self, backend): + closure_weight = torch.randn(3, 3, requires_grad=True) + + @leaf_function + def uses_closure(x): + return (x @ closure_weight,) + + @uses_closure.register_fake + def uses_closure_fake(x): + return (torch.empty(x.shape[0], 3),) + + def fn(x): + return uses_closure(x) + + x = torch.randn(2, 3, requires_grad=True) + + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + result = compiled_fn(x) + loss = result[0].sum() + loss.backward() + + self.assertIsNotNone(x.grad) + self.assertIsNone(closure_weight.grad) + + def test_leaf_function_and_nonstrict_trace_mutually_exclusive(self): + from torch._dynamo.decorators import leaf_function, nonstrict_trace + + with self.assertRaisesRegex( + ValueError, + "cannot be both marked as @leaf_function and @nonstrict_trace", + ): + + @leaf_function + @nonstrict_trace + def bad_fn1(x): + return (x,) + + with self.assertRaisesRegex( + ValueError, + "cannot be both marked as @leaf_function and @nonstrict_trace", + ): + + @nonstrict_trace + @leaf_function + def bad_fn2(x): + return (x,) + + @skipIfCrossRef + def test_leaf_function_no_return_value(self): + printed = [] + + @leaf_function + def fn_no_return(x): + print("processing") + + @fn_no_return.register_fake + def fn_no_return_fake(x): + pass + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + fn_no_return(x) + return (self.linear(x),) + + def args_fn(): + return (torch.randn(3, 3, requires_grad=True),) + + def loss_fn(out): + return out[0].sum() + + with patch("builtins.print", lambda *args, **kwargs: printed.append(args)): + eager_graph, fw_graph, bw_graph = self._test_leaf_function_helper( + Mod, args_fn, loss_fn + ) + self.assertTrue(any("processing" in p for p in printed)) + self.assertExpectedInline( + eager_graph, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[3, 3]", L_self_modules_linear_parameters_weight_: "f32[3, 3]", L_self_modules_linear_parameters_bias_: "f32[3]"): + l_x_ = L_x_ + l_self_modules_linear_parameters_weight_ = L_self_modules_linear_parameters_weight_ + l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_ + + real_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.real_fn + fake_fn : torch._higher_order_ops.invoke_leaf_function._LeafCallable = self.fake_fn + input_spec : torch.utils._pytree.TreeSpec = self.input_spec + invoke_leaf_function = torch.ops.higher_order.invoke_leaf_function(real_fn, fake_fn, input_spec, '', l_x_); real_fn = fake_fn = input_spec = invoke_leaf_function = None + + linear: "f32[3, 3]" = torch._C._nn.linear(l_x_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_); l_x_ = l_self_modules_linear_parameters_weight_ = l_self_modules_linear_parameters_bias_ = None + return (linear,) +""", + ) + self.assertExpectedInline( + fw_graph, + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[0]", primals_2: "f32[3, 3]", primals_3: "f32[3, 3]", primals_4: "f32[3]"): + _opaque_obj0 = self._opaque_obj0 + _opaque_obj1 = self._opaque_obj1 + _tree_spec_constant0 = self._tree_spec_constant0 + with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.higher_order.invoke_leaf_function, _opaque_obj0, _opaque_obj1, _tree_spec_constant0, '', primals_2, requires_grad_indices = '0'); primals_1 = _opaque_obj0 = _opaque_obj1 = _tree_spec_constant0 = None + + getitem: "f32[0]" = with_effects[0]; with_effects = None + + t: "f32[3, 3]" = torch.ops.aten.t.default(primals_3) + addmm: "f32[3, 3]" = torch.ops.aten.addmm.default(primals_4, primals_2, t); primals_4 = t = None + return (getitem, addmm, primals_2, primals_3) +""", + ) + self.assertExpectedInline( + bw_graph, + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_2: "f32[3, 3]", primals_3: "f32[3, 3]", tangents_1: "f32[3, 3]"): + t: "f32[3, 3]" = torch.ops.aten.t.default(primals_3); primals_3 = None + t_1: "f32[3, 3]" = torch.ops.aten.t.default(t); t = None + mm: "f32[3, 3]" = torch.ops.aten.mm.default(tangents_1, t_1); t_1 = None + t_2: "f32[3, 3]" = torch.ops.aten.t.default(tangents_1) + mm_1: "f32[3, 3]" = torch.ops.aten.mm.default(t_2, primals_2); t_2 = primals_2 = None + t_3: "f32[3, 3]" = torch.ops.aten.t.default(mm_1); mm_1 = None + sum_1: "f32[1, 3]" = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True); tangents_1 = None + view: "f32[3]" = torch.ops.aten.view.default(sum_1, [3]); sum_1 = None + t_4: "f32[3, 3]" = torch.ops.aten.t.default(t_3); t_3 = None + return (mm, t_4, view) +""", + ) + + def test_leaf_function_output_structure_mismatch(self): + @leaf_function + def mismatched_fn(x): + return {"a": x, "b": x * 2} + + @mismatched_fn.register_fake + def mismatched_fn_fake(x): + return (x, x * 2) + + def fn(x): + return mismatched_fn(x) + + x = torch.randn(3, 3) + with self.assertRaisesRegex(AssertionError, "output structure mismatch"): + torch.compile(fn, backend="eager")(x) + + def test_leaf_function_nested_output(self): + @leaf_function + def nested_output_fn(linear1, linear2, linear3, x): + if x.sum() > 0: + return { + "out": (linear1(x), linear2(x)), + "extra": linear3(x), + "count": 42, + } + else: + return { + "out": (linear1(x) + 1, linear2(x) + 1), + "extra": linear3(x) + 1, + "count": 42, + } + + @nested_output_fn.register_fake + def nested_output_fn_fake(linear1, linear2, linear3, x): + return { + "out": (linear1(x), linear2(x)), + "extra": linear3(x), + "count": 42, + } + + class NestedOutputModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(3, 3) + self.linear2 = torch.nn.Linear(3, 3) + self.linear3 = torch.nn.Linear(3, 3) + + def forward(self, x): + result = nested_output_fn(self.linear1, self.linear2, self.linear3, x) + return ( + result["out"][0] * result["count"] + + result["out"][1] + + result["extra"] + ) + + def args_fn(): + return (torch.randn(3, 3, requires_grad=True),) + + def loss_fn(out): + return out.sum() + + self._test_leaf_function_helper(NestedOutputModule, args_fn, loss_fn) + + def test_leaf_function_custom_pytree_output(self): + class Point: + x: torch.Tensor + y: torch.Tensor + + def __init__(self, x, y): + self.x = x + self.y = y + + self.register_pytree_node( + Point, + lambda p: ((p.x, p.y), ()), + lambda xy, _: Point(xy[0], xy[1]), + serialized_type_name=f"{Point.__module__}.{Point.__qualname__}", + ) + + @leaf_function + def point_fn(linear1, linear2, x): + return (Point(linear1(x), linear2(x)), 0.5) + + @point_fn.register_fake + def point_fn_fake(linear1, linear2, x): + return (Point(linear1(x), linear2(x)), 0.5) + + class PointModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(3, 3) + self.linear2 = torch.nn.Linear(3, 3) + + def forward(self, x): + p, scale = point_fn(self.linear1, self.linear2, x) + return (p.x * scale, p.y * scale) + + def args_fn(): + return (torch.randn(3, 3, requires_grad=True),) + + def loss_fn(out): + return out[0].sum() + out[1].sum() + + self._test_leaf_function_helper(PointModule, args_fn, loss_fn) + + def test_leaf_function_fake_requires_grad_ignored(self): + @leaf_function + def my_fn(x): + return (x * 2,) + + @my_fn.register_fake + def my_fn_fake(x): + return (torch.empty_like(x).requires_grad_(False),) + + from torch._dynamo.testing import EagerAndRecordGraphs + + backend = EagerAndRecordGraphs() + + @torch.compile(backend=backend, fullgraph=True) + def fn(x): + return my_fn(x) + + x = torch.randn(3, 3, requires_grad=True) + out = fn(x) + + self.assertTrue(out[0].requires_grad) + out[0].sum().backward() + self.assertIsNotNone(x.grad) + + graph = backend.graphs[0] + for node in graph.graph.nodes: + if node.op == "call_function" and "invoke_leaf_function" in str( + node.target + ): + example_value = node.meta.get("example_value") + self.assertIsNotNone(example_value) + self.assertTrue(example_value[0].requires_grad) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_leaf_function_input_mutation_non_grad(self, backend): + @leaf_function(mutates_args={"buf"}) + def mutate_buffer(x, buf): + buf.add_(1) + return (x + buf,) + + @mutate_buffer.register_fake + def mutate_buffer_fake(x, buf): + buf.add_(1) + return (x + buf,) + + def fn(x, buf): + return mutate_buffer(x, buf) + + x = torch.randn(3, 3) + buf = torch.randn(3, 3) + + buf_eager = buf.clone() + result_eager = fn(x, buf_eager) + expected = x + buf + 1 + self.assertEqual(result_eager[0], expected) + self.assertEqual(buf_eager, buf + 1) + + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + buf_compiled = buf.clone() + result_compiled = compiled_fn(x, buf_compiled) + self.assertEqual(result_compiled[0], expected) + self.assertEqual(buf_compiled, buf + 1) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_leaf_function_input_mutation_mixed(self, backend): + @leaf_function(mutates_args={"buf"}) + def mixed_fn(x, buf): + buf.mul_(2) + return (x * buf,) + + @mixed_fn.register_fake + def mixed_fn_fake(x, buf): + buf.mul_(2) + return (x * buf,) + + def fn(x, buf): + return mixed_fn(x, buf) + + x = torch.randn(3, 3, requires_grad=True) + buf = torch.randn(3, 3) + + buf_eager = buf.clone() + result_eager = fn(x, buf_eager) + expected = x * (buf * 2) + self.assertEqual(result_eager[0], expected) + self.assertEqual(buf_eager, buf * 2) + + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + buf_compiled = buf.clone() + result_compiled = compiled_fn(x, buf_compiled) + self.assertEqual(result_compiled[0], expected) + self.assertEqual(buf_compiled, buf * 2) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_leaf_function_input_mutation_module_buffer(self, backend): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("running_mean", torch.zeros(3)) + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x): + return update_stats(self, x) + + @leaf_function(mutates_args={"model.running_mean"}) + def update_stats(model, x): + model.running_mean.add_(x.mean(dim=0)) + return (model.linear(x),) + + @update_stats.register_fake + def update_stats_fake(model, x): + model.running_mean.add_(x.mean(dim=0)) + return (model.linear(x),) + + mod = MyModule() + x = torch.randn(4, 3) + + mod_eager = copy.deepcopy(mod) + result_eager = mod_eager(x) + expected_mean = torch.zeros(3) + x.mean(dim=0) + self.assertEqual(mod_eager.running_mean, expected_mean) + + mod_compiled = copy.deepcopy(mod) + compiled_mod = torch.compile(mod_compiled, backend=backend, fullgraph=True) + result_compiled = compiled_mod(x) + self.assertEqual(result_compiled, result_eager) + self.assertEqual(mod_compiled.running_mean, expected_mean) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_leaf_function_input_mutation_pytree(self, backend): + @leaf_function(mutates_args={"buffers"}) + def update_buffers(x, buffers): + for buf in buffers: + buf.add_(1) + return (x + sum(buffers),) + + @update_buffers.register_fake + def update_buffers_fake(x, buffers): + for buf in buffers: + buf.add_(1) + return (x + sum(buffers),) + + def fn(x, buffers): + return update_buffers(x, buffers) + + x = torch.randn(3, 3) + bufs = [torch.randn(3, 3), torch.randn(3, 3)] + + bufs_eager = [b.clone() for b in bufs] + result_eager = fn(x, bufs_eager) + expected = x + (bufs[0] + 1) + (bufs[1] + 1) + self.assertEqual(result_eager[0], expected) + self.assertEqual(bufs_eager[0], bufs[0] + 1) + self.assertEqual(bufs_eager[1], bufs[1] + 1) + + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + bufs_compiled = [b.clone() for b in bufs] + result_compiled = compiled_fn(x, bufs_compiled) + self.assertEqual(result_compiled[0], expected) + self.assertEqual(bufs_compiled[0], bufs[0] + 1) + self.assertEqual(bufs_compiled[1], bufs[1] + 1) + + @parametrize("backend", ["eager", "aot_eager"]) + def test_leaf_function_input_mutation_pytree_fine_grained(self, backend): + @leaf_function(mutates_args={"buffers[0]"}) + def update_first(x, buffers): + buffers[0].add_(1) + return (x + buffers[0] + buffers[1],) + + @update_first.register_fake + def update_first_fake(x, buffers): + buffers[0].add_(1) + return (x + buffers[0] + buffers[1],) + + def fn(x, buffers): + return update_first(x, buffers) + + x = torch.randn(3, 3) + bufs = [torch.randn(3, 3), torch.randn(3, 3)] + + bufs_eager = [b.clone() for b in bufs] + result_eager = fn(x, bufs_eager) + expected = x + (bufs[0] + 1) + bufs[1] + self.assertEqual(result_eager[0], expected) + self.assertEqual(bufs_eager[0], bufs[0] + 1) + self.assertEqual(bufs_eager[1], bufs[1]) + + compiled_fn = torch.compile(fn, backend=backend, fullgraph=True) + bufs_compiled = [b.clone() for b in bufs] + result_compiled = compiled_fn(x, bufs_compiled) + self.assertEqual(result_compiled[0], expected) + self.assertEqual(bufs_compiled[0], bufs[0] + 1) + self.assertEqual(bufs_compiled[1], bufs[1]) + + def test_leaf_function_mutates_args_invalid_parameter(self): + with self.assertRaisesRegex(ValueError, "refers to parameter 'buf'"): + + @leaf_function(mutates_args={"buf"}) + def bad_fn(x, buffers): + buffers.add_(1) + return (x + buffers,) + + with self.assertRaisesRegex(ValueError, "refers to parameter 'mdl'"): + + @leaf_function(mutates_args={"mdl.running_mean"}) + def bad_fn2(x, model): + model.running_mean.add_(1) + return (x,) + + def test_leaf_function_mutates_args_non_leaf_expression(self): + @leaf_function(mutates_args={"model"}) + def bad_fn(x, model): + model.running_mean.add_(1) + return (x,) + + @bad_fn.register_fake + def bad_fn_fake(x, model): + model.running_mean.add_(1) + return (x,) + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("running_mean", torch.zeros(3)) + + def forward(self, x): + return bad_fn(x, self) + + mod = MyModule() + x = torch.randn(3) + compiled_fn = torch.compile(mod, backend="eager", fullgraph=True) + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, "resolved to a non-leaf value" + ): + compiled_fn(x) + + +instantiate_parametrized_tests(TestLeafFunctionDynamo) + + +@skipIfTorchDynamo("leaf_function tests manage their own compilation") +class TestLeafFunctionRegisterHook(TestCase): + """Tests for @leaf_function's register_multi_grad_hook API.""" + + def test_hook_fires_on_backward(self): + hook_grads = [] + + @leaf_function + def my_fn(x): + return (x * 2,) + + @my_fn.register_fake + def my_fn_fake(x): + return (torch.empty_like(x),) + + @my_fn.register_multi_grad_hook + def my_fn_hook(x_grad): + hook_grads.append(x_grad.clone()) + + x = torch.randn(3, requires_grad=True) + out = my_fn(x)[0] + out.sum().backward() + + self.assertEqual(len(hook_grads), 1) + self.assertEqual(hook_grads[0], torch.full((3,), 2.0)) + + def test_hook_with_non_tensor_args(self): + hook_grads = [] + + @leaf_function + def my_fn(x, tag, scale): + return (x * scale,) + + @my_fn.register_fake + def my_fn_fake(x, tag, scale): + return (torch.empty_like(x),) + + @my_fn.register_multi_grad_hook + def my_fn_hook(x_grad): + hook_grads.append(x_grad.clone()) + + x = torch.randn(3, requires_grad=True) + out = my_fn(x, "hello", 5.0)[0] + out.sum().backward() + + self.assertEqual(len(hook_grads), 1) + self.assertEqual(hook_grads[0], torch.full((3,), 5.0)) + + def test_hook_multiple_tensor_inputs(self): + hook_calls = [] + + @leaf_function + def my_fn(x, y): + return (x * 2 + y * 3,) + + @my_fn.register_fake + def my_fn_fake(x, y): + return (torch.empty_like(x),) + + @my_fn.register_multi_grad_hook + def my_fn_hook(x_grad, y_grad): + hook_calls.append((x_grad.clone(), y_grad.clone())) + + x = torch.randn(3, requires_grad=True) + y = torch.randn(3, requires_grad=True) + out = my_fn(x, y)[0] + out.sum().backward() + + self.assertEqual(len(hook_calls), 1) + self.assertEqual(hook_calls[0][0], torch.full((3,), 2.0)) + self.assertEqual(hook_calls[0][1], torch.full((3,), 3.0)) + + def test_hook_only_fires_for_requires_grad_inputs(self): + hook_calls = [] + + @leaf_function + def my_fn(x, y): + return (x * 5 + y,) + + @my_fn.register_fake + def my_fn_fake(x, y): + return (torch.empty_like(x),) + + @my_fn.register_multi_grad_hook + def my_fn_hook(x_grad): + hook_calls.append(x_grad.clone()) + + x = torch.randn(3, requires_grad=True) + y = torch.randn(3, requires_grad=False) + out = my_fn(x, y)[0] + out.sum().backward() + + self.assertEqual(len(hook_calls), 1) + self.assertEqual(hook_calls[0], torch.full((3,), 5.0)) + + def test_hook_no_requires_grad_no_fire(self): + hook_count = [0] + + @leaf_function + def my_fn(x): + return (x * 2,) + + @my_fn.register_fake + def my_fn_fake(x): + return (torch.empty_like(x),) + + @my_fn.register_multi_grad_hook + def my_fn_hook(x_grad): + hook_count[0] += 1 + + x = torch.randn(3, requires_grad=False) + my_fn(x)[0] + self.assertEqual(hook_count[0], 0) + + def test_hook_side_effect_only_fn(self): + fwd_called = [False] + hook_grads = [] + + @leaf_function + def log_fn(x, tag): + fwd_called[0] = True + return None + + @log_fn.register_fake + def log_fn_fake(x, tag): + return None + + @log_fn.register_multi_grad_hook + def log_fn_hook(x_grad): + hook_grads.append(x_grad.clone()) + + x = torch.randn(4, requires_grad=True) + y = x * 2 + log_fn(y, "test") + y.sum().backward() + + self.assertTrue(fwd_called[0]) + self.assertEqual(len(hook_grads), 1) + + def test_hook_gradient_values_correct(self): + hook_grads = [] + + @leaf_function + def my_fn(x): + return (x**2,) + + @my_fn.register_fake + def my_fn_fake(x): + return (torch.empty_like(x),) + + @my_fn.register_multi_grad_hook + def my_fn_hook(x_grad): + hook_grads.append(x_grad.clone()) + + x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) + out = my_fn(x)[0] + out.sum().backward() + + self.assertEqual(hook_grads[0], torch.tensor([2.0, 4.0, 6.0])) + self.assertEqual(x.grad, torch.tensor([2.0, 4.0, 6.0])) + + def test_hook_with_downstream_computation(self): + hook_grads = [] + + @leaf_function + def my_fn(x): + return (x * 2,) + + @my_fn.register_fake + def my_fn_fake(x): + return (torch.empty_like(x),) + + @my_fn.register_multi_grad_hook + def my_fn_hook(x_grad): + hook_grads.append(x_grad.clone()) + + x = torch.tensor([1.0, 2.0], requires_grad=True) + y = my_fn(x)[0] + z = y * 3 + z.sum().backward() + + self.assertEqual(hook_grads[0], torch.tensor([6.0, 6.0])) + + def test_hook_with_retain_graph(self): + hook_count = [0] + + @leaf_function + def my_fn(x): + return (x * 2,) + + @my_fn.register_fake + def my_fn_fake(x): + return (torch.empty_like(x),) + + @my_fn.register_multi_grad_hook + def my_fn_hook(x_grad): + hook_count[0] += 1 + + x = torch.randn(3, requires_grad=True) + out = my_fn(x)[0] + out.sum().backward() + self.assertEqual(hook_count[0], 1) + + if __name__ == "__main__": run_tests() diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 955b1a8ada144..2c90e44391bd9 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1,5 +1,4 @@ # Owner(s): ["module: functorch"] -# ruff: noqa: F841 # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. @@ -981,9 +980,6 @@ def fn(inp, *args, **kwargs): "masked.softmax", device_type="cpu", ), - xfail( - "nanquantile", device_type="cpu" - ), # vmap not implemented for at::equal. xfail("native_layer_norm"), # vmap: inplace into a regular tensor # got a batched tensor as input while the running_mean or running_var, # which will be updated in place, were not batched. @@ -1035,9 +1031,6 @@ def fn(inp, *args, **kwargs): xfail("normal"), # calls random op xfail("normal", "number_mean"), # calls random op xfail("pca_lowrank"), # calls random op - xfail( - "quantile", device_type="cpu" - ), # Batching rule not implemented for `at::equal` xfail( "scatter_reduce", "prod" ), # vmap (looks like you are calling item/data-dependent) @@ -1061,9 +1054,6 @@ def fn(inp, *args, **kwargs): xfail("_native_batch_norm_legit"), # TODO: implement batching rule xfail("_batch_norm_with_update"), - xfail( - "unbind_copy" - ), # Batching rule not implemented for aten::unbind_copy.int. } ), ) @@ -1180,10 +1170,8 @@ def vjp_of_vjp(*args_and_cotangents): # TODO: implement batching rule skip("_batch_norm_with_update"), xfail("__getitem__", ""), # dynamic error - xfail("nanquantile", device_type="cpu"), # checks q via a .item() call xfail("nn.functional.gaussian_nll_loss"), # checks var for if any value < 0 xfail("narrow"), # .item() call - xfail("quantile", device_type="cpu"), # checks q via a .item() call xfail("view_as_complex"), # Tensor must have a last dimension with stride 1 # required rank 4 tensor to use channels_last format xfail("bfloat16"), @@ -1203,9 +1191,6 @@ def vjp_of_vjp(*args_and_cotangents): xfail("sparse.mm", "reduce"), xfail("as_strided_scatter", ""), # calls as_strided xfail("index_reduce", "prod"), # .item() call - xfail( - "unbind_copy" - ), # Batching rule not implemented for aten::unbind_copy.int. # --------------------------------------------------------------------- } ) @@ -1344,9 +1329,6 @@ def test_vmapvjp(self, device, dtype, op): xfail("_native_batch_norm_legit"), # TODO: implement batching rule xfail("_batch_norm_with_update"), - xfail( - "unbind_copy" - ), # Batching rule not implemented for aten::unbind_copy.int. # ---------------------------------------------------------------------- } @@ -1654,9 +1636,6 @@ def test(): xfail("__getitem__", ""), xfail("index_put", ""), xfail("view_as_complex"), - xfail( - "unbind_copy" - ), # Batching rule not implemented for aten::unbind_copy.int. xfail("nn.functional.gaussian_nll_loss"), xfail("masked_select"), xfail( @@ -1954,9 +1933,6 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): xfail( "as_strided_scatter" ), # AssertionError: Tensor-likes are not close! - xfail( - "unbind_copy" - ), # Batching rule not implemented for aten::unbind_copy.int. xfail("bernoulli"), # calls random op xfail("bfloat16"), # required rank 4 tensor to use channels_last format xfail("cdist"), # Forward AD not implemented and no decomposition diff --git a/test/functorch/test_subclass_codegen.py b/test/functorch/test_subclass_codegen.py index 11f3d3783c1f3..2d13984990933 100644 --- a/test/functorch/test_subclass_codegen.py +++ b/test/functorch/test_subclass_codegen.py @@ -70,11 +70,9 @@ def inner_fn(args): unwrapped_args.extend(args[1:]) args.clear() unwrapped_outs = compiled_fn(unwrapped_args) - wrapped_outs = [] _out_inner_2 = {'a': unwrapped_outs[0], 'b': unwrapped_outs[1]} _out_5 = _subclass_type_3.__tensor_unflatten__(_out_inner_2, _meta_4, (4,), (1,)) - wrapped_outs.append(_out_5) - return tuple(wrapped_outs)""", + return (_out_5,)""", ) def test_compile_nested_subclass(self): @@ -117,15 +115,13 @@ def inner_fn(args): unwrapped_args.extend(args[1:]) args.clear() unwrapped_outs = compiled_fn(unwrapped_args) - wrapped_outs = [] _out_inner_5 = {'a': unwrapped_outs[0], 'b': unwrapped_outs[1]} _out_8 = _subclass_type_6.__tensor_unflatten__(_out_inner_5, _meta_7, (4,), (1,)) _out_inner_9 = {'a': unwrapped_outs[2], 'b': unwrapped_outs[3]} _out_12 = _subclass_type_10.__tensor_unflatten__(_out_inner_9, _meta_11, (4,), (1,)) _out_inner_4 = {'a': _out_8, 'b': _out_12} _out_15 = _subclass_type_13.__tensor_unflatten__(_out_inner_4, _meta_14, (4,), (1,)) - wrapped_outs.append(_out_15) - return tuple(wrapped_outs)""", + return (_out_15,)""", ) def test_trailing_args_forwarded(self): @@ -194,7 +190,7 @@ def mock_compiled_fn(args): globals_dict["compiled_fn"] = mock_compiled_fn local_dict = {} - exec(compile(source, "", "exec"), globals_dict, local_dict) # noqa: S102 + exec(compile(source, "", "exec"), globals_dict, local_dict) wrapper = local_dict["inner_fn"] a = torch.randn(4) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 6799eab9f7c62..1e0a4fff0e59f 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -71,6 +71,7 @@ run_tests, skipIfTorchDynamo, subtest, + TEST_WITH_ROCM, TEST_WITH_TORCHDYNAMO, TestCase, unMarkDynamoStrictTest, @@ -683,8 +684,6 @@ def test_not_enough_in_dims_err_msg(self): vmap(torch.mul, (0, 0))(x, y) def test_integer_in_dim_but_not_tensor_input_err_msg(self): - # noqa: F841 - def foo(xy): return xy[0] * xy[1] @@ -3285,6 +3284,17 @@ def test_view_as(self): in_dims=(2, 0), ) + def test_view_dtype(self): + test = functools.partial(self._vmap_test, check_propagates_grad=False) + op = torch.ops.aten.view.dtype + + test(op, (torch.rand(2, 3, 4), torch.uint8), in_dims=(1, None), out_dims=1) + test(op, (torch.rand(5), torch.int32), in_dims=(0, None), out_dims=0) + with self.assertRaisesRegex( + RuntimeError, r"dim\(\) cannot be 0 to view Float as Byte" + ): + vmap(op, in_dims=(0, None))(torch.rand(6), torch.uint8) + def test_conv2d(self): conv_setups = [ (torch.nn.Conv1d, torch.conv1d, [2, 4, 15]), @@ -4423,9 +4433,6 @@ def sample_vmap_out_dim_numpy_split_copy_with_int( xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints # TypeError: expected Tensor as element 0 in argument 0, but got float xfail("item"), - xfail( - "unbind_copy" - ), # Batching rule not implemented for aten::unbind_copy.int. # RuntimeError: required rank 4 tensor to use channels_last format xfailIf( "to", @@ -4508,9 +4515,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("item"), xfail("tril"), # Exception not raised on error input xfail("triu"), # Exception not raised on error input - xfail( - "unbind_copy" - ), # Batching rule not implemented for aten::unbind_copy.int. xfail("__getitem__", ""), xfail("count_nonzero"), xfail( @@ -4827,6 +4831,60 @@ def test(): check_vmap_fallback(self, test, Tensor.fill_) + @parametrize( + "op,msg,extra_positional_args,extra_kwargs", + [ + subtest( + ( + Tensor.scatter_add_, + "out-of-place operators instead of scatter_add_", + (), + {}, + ), + name="scatter_add", + ), + subtest( + ( + Tensor.scatter_reduce_, + "out-of-place operators instead of scatter_reduce_", + ("sum",), + {"include_self": True}, + ), + name="scatter_reduce", + ), + ], + ) + def test_scatter_inplace_self_not_batched( + self, device, op, msg, extra_positional_args, extra_kwargs + ): + x = torch.zeros(5, device=device) + + def call_op(self, dim, index, src): + return op( + self, + dim, + index, + src, + *extra_positional_args, + **extra_kwargs, + ) + + with self.assertRaisesRegex(RuntimeError, msg): + vmap(call_op, in_dims=(None, None, None, 0))( + x, + 0, + torch.tensor([0, 1], device=device), + torch.randn(2, 2, device=device), + ) + + with self.assertRaisesRegex(RuntimeError, msg): + vmap(call_op, in_dims=(None, None, 0, None))( + x, + 0, + torch.tensor([[0, 1], [2, 3]], device=device), + torch.randn(2, device=device), + ) + @tf32_on_and_off(0.005) def test_conv_double_backward(self, device): images = torch.randn(2, 1, 5, 5, device=device) @@ -4943,6 +5001,23 @@ def test_group_norm(self, device): bias = torch.randn(B, C) test(self, op, (x, 4, weight, bias), in_dims=(0, None, 0, 0)) + def test_group_norm_layout_corruption(self, device): + # Regression test for https://github.com/pytorch/pytorch/issues/176432 + def op_function(cotangent): + input = torch.tensor([[-6.517, -6.264]], device=device, requires_grad=True) + result = F.group_norm(input, num_groups=1) + return torch.autograd.grad(result, input, grad_outputs=cotangent)[0] + + cotangent = torch.tensor([[-0.0236, -0.1431]], device=device) + result_in_dims_0 = vmap(op_function, in_dims=0)( + cotangent.unsqueeze(0).expand(2, -1, -1) + ) + result_in_dims_neg1 = vmap(op_function, in_dims=-1)( + cotangent.unsqueeze(-1).expand(-1, -1, 2) + ) + + self.assertEqual(result_in_dims_0, result_in_dims_neg1) + def test_index_put(self, device): def test(f, t, idx, values): base = f(t[0], idx[0], values[0]) @@ -5103,9 +5178,12 @@ def test_torch_return_types_returns(self, device): vmap(torch.topk, (0, None, None))(t, 1, 0), torch.return_types.topk ) ) - self.assertTrue( - isinstance(vmap(torch.linalg.eig, (0))(t), torch.return_types.linalg_eig) - ) + if not (TEST_WITH_ROCM and not torch.cuda.has_magma): + self.assertTrue( + isinstance( + vmap(torch.linalg.eig, (0))(t), torch.return_types.linalg_eig + ) + ) def test_namedtuple_returns(self, device): Point = namedtuple("Point", ["x", "y"]) diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index 22d79a997b84b..d9d12526ea833 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -61,6 +61,7 @@ "aten::cummaxmin_backward", "aten::data", "aten::diagflat", + "aten::dim", "aten::divide.out_mode", "aten::divide_.Scalar", "aten::dropout_", @@ -82,6 +83,7 @@ "aten::floor_divide_.Scalar", "aten::frobenius_norm", "aten::fused_moving_avg_obs_fake_quant", + "aten::get_device", "aten::get_gradients", "aten::greater_.Scalar", "aten::greater_.Tensor", @@ -145,6 +147,7 @@ "aten::norm_except_dim", "aten::not_equal_.Scalar", "aten::not_equal_.Tensor", + "aten::numel", "aten::one_hot", "aten::output_nr", "aten::pad_sequence", @@ -203,6 +206,7 @@ "aten::std_mean.names_dim", "aten::stft", "aten::stft.center", + "aten::storage_offset", "aten::stride.int", "aten::subtract.Scalar", "aten::subtract_.Scalar", diff --git a/test/fx/test_fx_split_node_finder.py b/test/fx/test_fx_split_node_finder.py index 8916140aa24a3..0007ed3885969 100644 --- a/test/fx/test_fx_split_node_finder.py +++ b/test/fx/test_fx_split_node_finder.py @@ -165,16 +165,22 @@ def _testTrackerMode(self, mode): self._assert_events_file(events_file) self._validate_file_content( nodes_file, - ["|-sup_f_2: init_acc|callable_and_operator_supported #"], + [ + "===== Tracking node sup_f_2 =====", + "|-sup_f_2: init_acc|callable_and_operator_supported #", + "===== End of tracking node sup_f_2 =====", + ], ) elif mode == 3: self._assert_events_file(events_file) self._validate_file_content( nodes_file, [ + "===== Tracking node sup_f_1 =====", "|-sup_f_1: init_acc|callable_and_operator_supported #", "|-sup_f_1: acc_del|non_tensor_output_with_cpu_user add", "| |-add: init_cpu|operator_support #", + "===== End of tracking node sup_f_1 =====", ], ) diff --git a/test/fx/test_graph_pickler.py b/test/fx/test_graph_pickler.py index 6784fbf502e7c..960bb23651a2e 100644 --- a/test/fx/test_graph_pickler.py +++ b/test/fx/test_graph_pickler.py @@ -129,7 +129,7 @@ def test_nested_unpicklable_in_list(self): When a lambda is nested in a list, debug_dumps should find the path to it (e.g., "root[1]"). """ - bad_obj = [1, lambda x: x, 3] # noqa: E731 + bad_obj = [1, lambda x: x, 3] result = self.GraphPickler.debug_dumps(bad_obj, verbose=False) self.assertIn("root[1]", result) @@ -138,7 +138,7 @@ def test_nested_unpicklable_in_dict(self): When a lambda is nested in a dict, debug_dumps should find the path to it (e.g., "root['bad_key']"). """ - bad_obj = {"good": 1, "bad": lambda x: x} # noqa: E731 + bad_obj = {"good": 1, "bad": lambda x: x} result = self.GraphPickler.debug_dumps(bad_obj, verbose=False) self.assertIn("root['bad']", result) @@ -146,7 +146,7 @@ def test_deeply_nested_unpicklable(self): """ debug_dumps should find unpicklables even when deeply nested. """ - bad_obj = {"level1": {"level2": {"level3": [1, 2, lambda x: x]}}} # noqa: E731 + bad_obj = {"level1": {"level2": {"level3": [1, 2, lambda x: x]}}} result = self.GraphPickler.debug_dumps(bad_obj, verbose=False) self.assertIn("level3", result) self.assertIn("[2]", result) @@ -155,7 +155,7 @@ def test_unpicklable_in_tuple(self): """ debug_dumps should handle tuples correctly. """ - bad_obj = (1, 2, lambda x: x) # noqa: E731 + bad_obj = (1, 2, lambda x: x) result = self.GraphPickler.debug_dumps(bad_obj, verbose=False) self.assertIn("root[2]", result) @@ -176,7 +176,7 @@ def test_max_depth_limit(self): def build_nested(depth): if depth == 0: - return lambda x: x # noqa: E731 + return lambda x: x return [build_nested(depth - 1)] deeply_nested = build_nested(100) @@ -193,7 +193,7 @@ def test_object_with_unpicklable_attribute(self): class Container: def __init__(self): self.good = 1 - self.bad = lambda x: x # noqa: E731 + self.bad = lambda x: x obj = Container() result = self.GraphPickler.debug_dumps(obj, verbose=False) @@ -210,7 +210,7 @@ class MyData: good: int bad: object - obj = MyData(good=1, bad=lambda x: x) # noqa: E731 + obj = MyData(good=1, bad=lambda x: x) result = self.GraphPickler.debug_dumps(obj, verbose=False) self.assertIn("bad", result) @@ -226,7 +226,7 @@ def test_verbose_output(self): old_stdout = sys.stdout sys.stdout = captured try: - self.GraphPickler.debug_dumps([1, lambda x: x], verbose=True) # noqa: E731 + self.GraphPickler.debug_dumps([1, lambda x: x], verbose=True) finally: sys.stdout = old_stdout @@ -554,7 +554,7 @@ def forward(self, x): gm = torch.fx.symbolic_trace(SimpleModule()) for node in gm.graph.nodes: - node.meta["lambda_fn"] = lambda x: x * 3 # noqa: E731 + node.meta["lambda_fn"] = lambda x: x * 3 options = self.Options(node_metadata_key_filter=None) serialized = self.GraphPickler.dumps(gm, options) @@ -858,6 +858,243 @@ def forward(self, x): self.assertIn("custom_key", pickle_data.meta) +class TestNodeStateSerialization(TestCase): + def test_type_entry_preserved_in_getstate(self): + class M(torch.nn.Module): + def forward(self, x): + y = torch.neg(x) + return y + 1 + + gm = torch.fx.symbolic_trace(M()) + node = next(n for n in gm.graph.nodes if n.op == "call_function") + node.type = torch.Tensor + state = node.__getstate__() + self.assertIs(state["type"], torch.Tensor) + + +@unittest.skipUnless(HAS_DILL, "dill not available") +class TestIgnoreRawNode(TestCase): + """Tests for the ignore_raw_node option in GraphPickler.Options.""" + + def setUp(self): + super().setUp() + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx._graph_pickler import GraphPickler, Options + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + self.GraphPickler = GraphPickler + self.Options = Options + self.fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + + def _make_graph_with_raw_node_in_meta(self): + """Return a graph module whose first call_function node has a raw + torch.fx.Node stored in its metadata under the key 'raw_ref'.""" + + class M(torch.nn.Module): + def forward(self, x): + return x + 1 + + gm = torch.fx.symbolic_trace(M()) + call_node = next((n for n in gm.graph.nodes if n.op == "call_function"), None) + self.assertIsNotNone(call_node) + # Store a raw Node reference in meta – this is the problematic case. + call_node.meta["raw_ref"] = call_node + return gm + + def test_raw_node_in_meta_raises_by_default(self): + """Pickling should raise AssertionError when a raw Node is in metadata + and ignore_raw_node is False (the default).""" + gm = self._make_graph_with_raw_node_in_meta() + with self.assertRaises(AssertionError) as cm: + self.GraphPickler.dumps(gm) + self.assertIn("raw Node", str(cm.exception)) + + def test_raw_node_in_meta_with_ignore_raw_node(self): + """With ignore_raw_node=True, pickling should succeed and the raw Node + should be replaced with None after round-trip deserialization.""" + gm = self._make_graph_with_raw_node_in_meta() + options = self.Options(ignore_raw_node=True) + data = self.GraphPickler.dumps(gm, options) + restored = self.GraphPickler.loads(data, self.fake_mode) + self.assertIsInstance(restored, torch.fx.GraphModule) + call_node = next( + (n for n in restored.graph.nodes if n.op == "call_function"), None + ) + self.assertIsNotNone(call_node) + self.assertIsNone(call_node.meta.get("raw_ref")) + + +class _WeakrefTarget: + """A simple picklable class that supports weak references, for use in tests. + + Plain dicts do not support weak references in Python, so tests must use + instances of a regular class instead. + """ + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +@unittest.skipUnless(HAS_DILL, "dill not available") +class TestWeakrefPickle(TestCase): + """Tests that weakref objects are properly serialized and reconstructed.""" + + def setUp(self): + super().setUp() + import weakref + + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx._graph_pickler import GraphPickler, Options + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + self.weakref = weakref + self.GraphPickler = GraphPickler + self.Options = Options + self.fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + + def _make_graph_with_weakref_in_meta(self, ref_obj): + """Return a graph module with a weakref stored in node metadata.""" + + class M(torch.nn.Module): + def forward(self, x): + return x + 1 + + gm = torch.fx.symbolic_trace(M()) + call_node = next((n for n in gm.graph.nodes if n.op == "call_function"), None) + self.assertIsNotNone(call_node) + call_node.meta["weak_ref"] = ref_obj + return gm + + def test_alive_weakref_in_meta_is_reconstructed(self): + """An alive weakref.ref in node metadata should be reconstructed as a weakref.""" + target = _WeakrefTarget(key="value") + weak = self.weakref.ref(target) + gm = self._make_graph_with_weakref_in_meta(weak) + # Also store a strong ref so the referent survives after unpickling + call_node = next((n for n in gm.graph.nodes if n.op == "call_function"), None) + call_node.meta["strong_ref"] = target + + options = self.Options(node_metadata_key_filter=None) + data = self.GraphPickler.dumps(gm, options) + restored = self.GraphPickler.loads(data, self.fake_mode) + + self.assertIsInstance(restored, torch.fx.GraphModule) + call_node = next( + (n for n in restored.graph.nodes if n.op == "call_function"), None + ) + self.assertIsNotNone(call_node) + restored_ref = call_node.meta.get("weak_ref") + self.assertIsInstance(restored_ref, self.weakref.ref) + self.assertEqual(restored_ref().key, "value") + + def test_dead_weakref_in_meta_unpickles_as_callable_none(self): + """A dead weakref should unpickle as a callable that returns None.""" + target = _WeakrefTarget() + weak = self.weakref.ref(target) + gm = self._make_graph_with_weakref_in_meta(weak) + # Kill the referent so the weakref is dead at pickle time + del target + + options = self.Options(node_metadata_key_filter=None) + data = self.GraphPickler.dumps(gm, options) + restored = self.GraphPickler.loads(data, self.fake_mode) + + self.assertIsInstance(restored, torch.fx.GraphModule) + call_node = next( + (n for n in restored.graph.nodes if n.op == "call_function"), None + ) + self.assertIsNotNone(call_node) + restored_ref = call_node.meta.get("weak_ref") + # Should be callable and return None, like a dead weakref + self.assertIsNotNone(restored_ref) + self.assertIsNone(restored_ref()) + + def test_keyed_ref_in_meta_is_reconstructed(self): + """A weakref.KeyedRef (from WeakValueDictionary) should be reconstructed.""" + wvd = self.weakref.WeakValueDictionary() + target = _WeakrefTarget(val=42) + wvd["k"] = target + keyed_ref = wvd.data["k"] + self.assertIsInstance(keyed_ref, self.weakref.KeyedRef) + + gm = self._make_graph_with_weakref_in_meta(keyed_ref) + # Also store a strong ref so the referent survives after unpickling + call_node = next((n for n in gm.graph.nodes if n.op == "call_function"), None) + call_node.meta["strong_ref"] = target + + options = self.Options(node_metadata_key_filter=None) + data = self.GraphPickler.dumps(gm, options) + restored = self.GraphPickler.loads(data, self.fake_mode) + + self.assertIsInstance(restored, torch.fx.GraphModule) + call_node = next( + (n for n in restored.graph.nodes if n.op == "call_function"), None + ) + self.assertIsNotNone(call_node) + restored_ref = call_node.meta.get("weak_ref") + self.assertIsInstance(restored_ref, self.weakref.ref) + self.assertEqual(restored_ref().val, 42) + + def test_weakref_in_module_dict_is_reconstructed(self): + """A weakref stored in the graph module's __dict__ should be reconstructed.""" + + class M(torch.nn.Module): + def forward(self, x): + return x + 1 + + gm = torch.fx.symbolic_trace(M()) + target = _WeakrefTarget(key="value") + gm._weak_cache = self.weakref.ref(target) + # Also store a strong ref so the referent survives after unpickling + gm._strong_cache = target + + options = self.Options(node_metadata_key_filter=None) + data = self.GraphPickler.dumps(gm, options) + restored = self.GraphPickler.loads(data, self.fake_mode) + + self.assertIsInstance(restored, torch.fx.GraphModule) + restored_ref = restored._weak_cache + self.assertIsInstance(restored_ref, self.weakref.ref) + self.assertEqual(restored_ref().key, "value") + + def test_weakref_and_strong_ref_share_same_object(self): + """When a weakref and a strong ref point to the same object, pickle's + memo should deduplicate them so they share identity after unpickling. + This also covers the case where the weakref is the first reference + pickle encounters — the memo must still work correctly.""" + + class M(torch.nn.Module): + def forward(self, x): + return x + 1 + + gm = torch.fx.symbolic_trace(M()) + target = _WeakrefTarget(key="value") + weak = self.weakref.ref(target) + + call_node = next((n for n in gm.graph.nodes if n.op == "call_function"), None) + self.assertIsNotNone(call_node) + # Put weakref first so it's the first reference pickle encounters + call_node.meta["weak_ref"] = weak + call_node.meta["strong_ref"] = target + + options = self.Options(node_metadata_key_filter=None) + data = self.GraphPickler.dumps(gm, options) + restored = self.GraphPickler.loads(data, self.fake_mode) + + self.assertIsInstance(restored, torch.fx.GraphModule) + restored_node = next( + (n for n in restored.graph.nodes if n.op == "call_function"), None + ) + self.assertIsNotNone(restored_node) + + restored_weak = restored_node.meta["weak_ref"] + restored_strong = restored_node.meta["strong_ref"] + + self.assertIsInstance(restored_weak, self.weakref.ref) + # The weakref's referent and the strong ref should be the same object + self.assertIs(restored_weak(), restored_strong) + + if __name__ == "__main__": from torch.testing._internal.common_utils import run_tests diff --git a/test/fx/test_subgraph_rewriter.py b/test/fx/test_subgraph_rewriter.py index d19ee70da5000..29803f4ee64bd 100644 --- a/test/fx/test_subgraph_rewriter.py +++ b/test/fx/test_subgraph_rewriter.py @@ -950,7 +950,7 @@ def replacement(x, arg0, arg1): def forward(self, x): _reshape_alias_copy_default_1 = torch.ops.aten._reshape_alias_copy.default(x, [3, 4], [1, 2]); x = None return _reshape_alias_copy_default_1""", - ) # noqa: B950 + ) def test_replacement_with_attrs(self): class M(torch.nn.Module): diff --git a/test/higher_order_ops/test_debug_log.py b/test/higher_order_ops/test_debug_log.py new file mode 100644 index 0000000000000..f6ece24227453 --- /dev/null +++ b/test/higher_order_ops/test_debug_log.py @@ -0,0 +1,75 @@ +# Owner(s): ["module: higher order operators"] +"""Tests for torch.utils.debug_log.debug_grad_log.""" + +import logging + +import torch +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.utils.debug_log import debug_grad_log + + +class TestDebugGradLog(TestCase): + def setUp(self): + super().setUp() + self._log_records: list[str] = [] + + class _Handler(logging.Handler): + def __init__(self, dest): + super().__init__() + self.dest = dest + + def emit(self, record): + self.dest.append(self.format(record)) + + self._handler = _Handler(self._log_records) + logger = logging.getLogger("torch.utils.debug_log") + logger.addHandler(self._handler) + logger.setLevel(logging.INFO) + self.addCleanup(logger.removeHandler, self._handler) + + @property + def bwd_logs(self) -> list[str]: + return [r for r in self._log_records if "[bwd]" in r] + + def test_single_tensor(self): + x = torch.randn(4, requires_grad=True) + y = x * 2 + debug_grad_log(y) + y.sum().backward() + + self.assertEqual(len(self.bwd_logs), 1) + self.assertIn("t0_grad_norm=", self.bwd_logs[0]) + + def test_multi_tensor(self): + x = torch.randn(4, requires_grad=True) + y = torch.randn(4, requires_grad=True) + debug_grad_log(x, y) + (x * 2 + y * 3).sum().backward() + + self.assertEqual(len(self.bwd_logs), 1) + self.assertIn("t0_grad_norm=", self.bwd_logs[0]) + self.assertIn("t1_grad_norm=", self.bwd_logs[0]) + + def test_gradient_values(self): + x = torch.tensor([1.0], requires_grad=True) + y = torch.tensor([1.0], requires_grad=True) + debug_grad_log(x, y) + (x * 2 + y * 3).sum().backward() + + self.assertEqual(len(self.bwd_logs), 1) + self.assertIn("t0_grad_norm=2.0000", self.bwd_logs[0]) + self.assertIn("t1_grad_norm=3.0000", self.bwd_logs[0]) + + def test_no_requires_grad_no_log(self): + x = torch.randn(3, requires_grad=False) + debug_grad_log(x) + self.assertEqual(len(self.bwd_logs), 0) + + def test_forward_is_noop(self): + x = torch.randn(3, requires_grad=True) + debug_grad_log(x) + self.assertEqual(len(self._log_records), 0) + + +if __name__ == "__main__": + run_tests() diff --git a/test/higher_order_ops/test_inline_asm_elementwise.py b/test/higher_order_ops/test_inline_asm_elementwise.py new file mode 100644 index 0000000000000..07201cc666854 --- /dev/null +++ b/test/higher_order_ops/test_inline_asm_elementwise.py @@ -0,0 +1,733 @@ +# Owner(s): ["module: higher order operators"] +""" +Tests for inline_asm_elementwise higher-order operator. + +Tests verify: +1. Bitwise equivalence between eager (Jiterator) and compiled (Inductor) paths +2. Correctness via approximate comparison with reference PyTorch ops +""" + +import unittest +from collections.abc import Callable +from dataclasses import dataclass + +import torch +from torch._higher_order_ops.inline_asm_elementwise import inline_asm_elementwise +from torch.testing._internal.common_cuda import evaluate_gfx_arch_within, SM70OrLater +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + MI200_ARCH, + MI300_ARCH, + NAVI_ARCH, + parametrize, + run_tests, + TEST_CUDA, + TestCase, +) + + +@dataclass +class AsmTestCase: + name: str + input_gen_fn: Callable + asm_str: str + constraints: str + dtype: torch.dtype + approx_fn: Callable + pack: int = 1 + compile_only: bool = False + min_sm: int = 70 + + +TEST_CASES = [ + # Basic float32 operations + AsmTestCase( + "identity_f32", + lambda: (torch.randn(100, device="cuda", dtype=torch.float32),), + "v_mov_b32 $0, $1" if torch.version.hip else "mov.f32 $0, $1;", + "=v, v" if torch.version.hip else "=f,f", + torch.float32, + lambda x: x, + ), + AsmTestCase( + "add_f32", + lambda: ( + torch.randn(100, device="cuda", dtype=torch.float32), + torch.randn(100, device="cuda", dtype=torch.float32), + ), + "v_add_f32 $0, $1, $2" if torch.version.hip else "add.f32 $0, $1, $2;", + "=v, v, v" if torch.version.hip else "=f,f,f", + torch.float32, + lambda x, y: x + y, + ), + AsmTestCase( + "mul_f32", + lambda: ( + torch.randn(100, device="cuda", dtype=torch.float32), + torch.randn(100, device="cuda", dtype=torch.float32), + ), + "v_mul_f32 $0, $1, $2" if torch.version.hip else "mul.f32 $0, $1, $2;", + "=v, v, v" if torch.version.hip else "=f,f,f", + torch.float32, + lambda x, y: x * y, + ), + AsmTestCase( + "fma_f32", + lambda: ( + torch.randn(100, device="cuda", dtype=torch.float32), + torch.randn(100, device="cuda", dtype=torch.float32), + torch.randn(100, device="cuda", dtype=torch.float32), + ), + "v_fma_f32 $0, $1, $2, $3" + if torch.version.hip + else "fma.rn.f32 $0, $1, $2, $3;", + "=v, v, v, v" if torch.version.hip else "=f,f,f,f", + torch.float32, + lambda a, b, c: a * b + c, + ), + # Multi-line inline asm. PTX uses curly braces; AMDGCN uses newlines. + AsmTestCase( + "double_multiline", + lambda: (torch.randn(100, device="cuda", dtype=torch.float32),), + ( + """ + v_mov_b32 $0, $1 + v_add_f32 $0, $0, $1 + """ + if torch.version.hip + else "{.reg .f32 tmp; mov.f32 tmp, $1; add.f32 $0, tmp, tmp;}" + ), + "=v, v" if torch.version.hip else "=f,f", + torch.float32, + lambda x: x * 2, + ), + # bf16/fp16 upcasting (compile-only: Jiterator can't handle dtype mismatch) + AsmTestCase( + "bf16_upcast", + lambda: (torch.randn(100, device="cuda", dtype=torch.bfloat16),), + "v_add_f32 $0, $1, $1" if torch.version.hip else "add.f32 $0, $1, $1;", + "=v, v" if torch.version.hip else "=f,f", + torch.float32, + lambda x: x.float() * 2, + compile_only=True, + min_sm=80, + ), + AsmTestCase( + "fp16_upcast", + lambda: (torch.randn(100, device="cuda", dtype=torch.float16),), + "v_add_f32 $0, $1, $1" if torch.version.hip else "add.f32 $0, $1, $1;", + "=v, v" if torch.version.hip else "=f,f", + torch.float32, + lambda x: x.float() * 2, + compile_only=True, + ), + # Integer operations + AsmTestCase( + "bitwise_and", + lambda: ( + torch.randint(0, 2**16, (100,), device="cuda", dtype=torch.int32), + torch.randint(0, 2**16, (100,), device="cuda", dtype=torch.int32), + ), + "v_and_b32 $0, $1, $2" if torch.version.hip else "and.b32 $0, $1, $2;", + "=v, v, v" if torch.version.hip else "=r,r,r", + torch.int32, + lambda x, y: x & y, + ), + AsmTestCase( + "bitwise_or", + lambda: ( + torch.randint(0, 2**16, (100,), device="cuda", dtype=torch.int32), + torch.randint(0, 2**16, (100,), device="cuda", dtype=torch.int32), + ), + "v_or_b32 $0, $1, $2" if torch.version.hip else "or.b32 $0, $1, $2;", + "=v, v, v" if torch.version.hip else "=r,r,r", + torch.int32, + lambda x, y: x | y, + ), + # Output dtype differs from input (compile-only: Jiterator returns input dtype) + # AMDGCN: v_bfe_u32 (bit-field extract) replaces PTX's multi-instruction + # shift-and-mask sequence in a single instruction. + AsmTestCase( + "exponent_extract", + lambda: ( + torch.tensor([1.0, 2.0, 0.5, 16.0], device="cuda", dtype=torch.float32), + ), + ( + "v_bfe_u32 $0, $1, 23, 8" + if torch.version.hip + else "{.reg .b32 t; mov.b32 t,$1; shr.u32 t,t,23; and.b32 $0,t,0xFF;}" + ), + "=v, v" if torch.version.hip else "=r,f", + torch.int32, + lambda x: ((x.view(torch.int32) >> 23) & 0xFF).to(torch.int32), + compile_only=True, + ), + # Truncate u32 -> u16 (compile-only). + # PTX: uses "h" (16-bit) output / "r" (32-bit) input constraints. + # AMDGCN: VGPRs are always 32-bit (no "h" equivalent), so we use "v" + # and extract the lower 16 bits via v_bfe_u32. + AsmTestCase( + "truncate_to_uint16", + lambda: (torch.randint(0, 256, (100,), device="cuda", dtype=torch.int32),), + "v_bfe_u32 $0, $1, 0, 16" if torch.version.hip else "cvt.u16.u32 $0, $1;", + "=v, v" if torch.version.hip else "=h,r", + torch.uint16, + lambda x: x.to(torch.uint16), + compile_only=True, + ), + # Broadcasting + AsmTestCase( + "broadcast_add", + lambda: ( + torch.randn(4, 1, device="cuda", dtype=torch.float32), + torch.randn(1, 8, device="cuda", dtype=torch.float32), + ), + "v_add_f32 $0, $1, $2" if torch.version.hip else "add.f32 $0, $1, $2;", + "=v, v, v" if torch.version.hip else "=f,f,f", + torch.float32, + lambda x, y: x + y, + ), + # Non-contiguous + AsmTestCase( + "noncontiguous", + lambda: (torch.randn(8, 16, device="cuda", dtype=torch.float32).t(),), + "v_mov_b32 $0, $1" if torch.version.hip else "mov.f32 $0, $1;", + "=v, v" if torch.version.hip else "=f,f", + torch.float32, + lambda x: x, + ), + # fp16/bf16 native asm (compile-only: inductor computes in fp32, needs downcast) + # ROCm: Inductor feeds f32 values (upcasted for computation). AMDGCN has no + # "h" constraint for 16-bit regs, so we add in f32 and convert to the target + # format. PTX "h" constraints tell Triton to downcast before the asm. + AsmTestCase( + "add_fp16_native", + lambda: ( + torch.randn(100, device="cuda", dtype=torch.float16), + torch.randn(100, device="cuda", dtype=torch.float16), + ), + ( + "v_add_f32 $0, $1, $2\nv_cvt_f16_f32 $0, $0" + if torch.version.hip + else "add.f16 $0, $1, $2;" + ), + "=v,v,v" if torch.version.hip else "=h,h,h", + torch.float16, + lambda x, y: x + y, + compile_only=True, + ), + # AMDGCN: v_cvt_pk_bf16_f32 packs two f32 values into bf16 in a single + # 32-bit register. We pass $0 twice — only the lower 16 bits (first + # bf16 slot) are used by Triton. + AsmTestCase( + "add_bf16_native", + lambda: ( + torch.randn(100, device="cuda", dtype=torch.bfloat16), + torch.randn(100, device="cuda", dtype=torch.bfloat16), + ), + ( + "v_add_f32 $0, $1, $2\nv_cvt_pk_bf16_f32 $0, $0, $0" + if torch.version.hip + else "add.bf16 $0, $1, $2;" + ), + "=v,v,v" if torch.version.hip else "=h,h,h", + torch.bfloat16, + lambda x, y: x + y, + compile_only=True, + min_sm=90, + ), + # pack=2: each asm invocation processes 2 elements (compile-only) + AsmTestCase( + "identity_pack2", + lambda: (torch.randn(128, device="cuda", dtype=torch.float32),), + ( + """ + v_mov_b32 $0, $2 + v_mov_b32 $1, $3 + """ + if torch.version.hip + else "mov.b32 $0, $2; mov.b32 $1, $3;" + ), + "=v,=v,v,v" if torch.version.hip else "=r,=r,r,r", + torch.float32, + lambda x: x, + pack=2, + compile_only=True, + ), + AsmTestCase( + "add_pack2", + lambda: ( + torch.randn(128, device="cuda", dtype=torch.float32), + torch.randn(128, device="cuda", dtype=torch.float32), + ), + ( + """ + v_add_f32 $0, $2, $4 + v_add_f32 $1, $3, $5 + """ + if torch.version.hip + else "add.f32 $0, $2, $4; add.f32 $1, $3, $5;" + ), + "=v,=v,v,v,v,v" if torch.version.hip else "=f,=f,f,f,f,f", + torch.float32, + lambda x, y: x + y, + pack=2, + compile_only=True, + ), +] +TEST_CASE_NAMES = [tc.name for tc in TEST_CASES] + + +@unittest.skipIf(not TEST_CUDA, "CUDA not available") +@unittest.skipIf(not SM70OrLater, "Requires SM70+") +@instantiate_parametrized_tests +class TestInlineAsmElementwise(TestCase): + """Parametrized tests for inline_asm_elementwise.""" + + @parametrize( + "case_idx", list(range(len(TEST_CASES))), name_fn=lambda i: TEST_CASE_NAMES[i] + ) + def test_eager_vs_compiled_bitwise(self, case_idx): + """Verify eager and compiled produce bitwise identical results.""" + tc = TEST_CASES[case_idx] + if not torch.version.hip and torch.cuda.get_device_capability() < ( + tc.min_sm // 10, + tc.min_sm % 10, + ): + self.skipTest(f"Requires SM{tc.min_sm}+") + + # Native bf16 conversion instruction not available before gfx950. + if ( + torch.version.hip + and tc.name == "add_bf16_native" + and evaluate_gfx_arch_within( + [ + *MI200_ARCH, + *MI300_ARCH, + *NAVI_ARCH, + ] + ) + ): + self.skipTest("Requires gfx950+") + + inputs = tc.input_gen_fn() + + def fn(*args): + return inline_asm_elementwise( + *args, + asm_str=tc.asm_str, + constraints=tc.constraints, + dtype=tc.dtype, + pack=tc.pack, + ) + + torch._dynamo.reset() + compiled_result = torch.compile(fn, backend="inductor")(*inputs) + + if tc.compile_only: + expected = tc.approx_fn(*inputs) + self.assertEqual( + compiled_result.float(), expected.float(), atol=1e-5, rtol=1e-5 + ) + else: + eager_result = fn(*inputs) + self.assertEqual(eager_result, compiled_result) + + @parametrize( + "case_idx", list(range(len(TEST_CASES))), name_fn=lambda i: TEST_CASE_NAMES[i] + ) + def test_correctness(self, case_idx): + """Verify result matches reference function.""" + tc = TEST_CASES[case_idx] + if not torch.version.hip and torch.cuda.get_device_capability() < ( + tc.min_sm // 10, + tc.min_sm % 10, + ): + self.skipTest(f"Requires SM{tc.min_sm}+") + + # Native bf16 conversion instruction not available before gfx950. + if ( + torch.version.hip + and tc.name == "add_bf16_native" + and evaluate_gfx_arch_within( + [ + *MI200_ARCH, + *MI300_ARCH, + *NAVI_ARCH, + ] + ) + ): + self.skipTest("Requires gfx950+") + + inputs = tc.input_gen_fn() + + def fn(*args): + return inline_asm_elementwise( + *args, + asm_str=tc.asm_str, + constraints=tc.constraints, + dtype=tc.dtype, + pack=tc.pack, + ) + + if tc.compile_only: + torch._dynamo.reset() + result = torch.compile(fn, backend="inductor")(*inputs) + else: + result = fn(*inputs) + expected = tc.approx_fn(*inputs) + + self.assertEqual(result.float(), expected.float(), atol=1e-5, rtol=1e-5) + + +@unittest.skipIf(not TEST_CUDA, "CUDA not available") +class TestInlineAsmElementwiseErrors(TestCase): + """Tests for error handling.""" + + def test_error_no_inputs(self): + with self.assertRaises(ValueError): + inline_asm_elementwise( + asm_str="v_mov_b32 $0, 1.0" + if torch.version.hip + else "mov.f32 $0, 1.0;", + constraints="=v" if torch.version.hip else "=f", + dtype=torch.float32, + ) + + def test_error_constraint_mismatch(self): + x = torch.randn(100, device="cuda", dtype=torch.float32) + y = torch.randn(100, device="cuda", dtype=torch.float32) + with self.assertRaises(ValueError): + inline_asm_elementwise( + x, + y, + asm_str="v_add_f32 $0, $1, $2" + if torch.version.hip + else "add.f32 $0, $1, $2;", + constraints="=v,v" if torch.version.hip else "=f,f", + dtype=torch.float32, + ) + + def test_error_mixed_dtypes(self): + x = torch.randn(100, device="cuda", dtype=torch.float32) + y = torch.randint(0, 10, (100,), device="cuda", dtype=torch.int32) + with self.assertRaises(ValueError): + inline_asm_elementwise( + x, + y, + asm_str="v_add_f32 $0, $1, $2" + if torch.version.hip + else "add.f32 $0, $1, $2;", + constraints="=v,v,v" if torch.version.hip else "=f,f,r", + dtype=torch.float32, + ) + + def test_error_cpu_tensor(self): + x = torch.randn(100, dtype=torch.float32) + with self.assertRaises(RuntimeError): + inline_asm_elementwise( + x, + asm_str="v_mov_b32 $0, $1" if torch.version.hip else "mov.f32 $0, $1;", + constraints="=v,v" if torch.version.hip else "=f,f", + dtype=torch.float32, + ) + + +@unittest.skipIf(not TEST_CUDA, "CUDA not available") +@unittest.skipIf(not SM70OrLater, "Requires SM70+") +class TestInlineAsmElementwiseEdgeCases(TestCase): + """Tests for edge cases.""" + + def test_empty_tensor(self): + x = torch.empty(0, device="cuda", dtype=torch.float32) + result = inline_asm_elementwise( + x, + asm_str="v_mov_b32 $0, $1" if torch.version.hip else "mov.f32 $0, $1;", + constraints="=v, v" if torch.version.hip else "=f,f", + dtype=torch.float32, + ) + self.assertEqual(result.shape, torch.Size([0])) + + def test_scalar_tensor(self): + x = torch.tensor(3.14, device="cuda", dtype=torch.float32) + result = inline_asm_elementwise( + x, + asm_str="v_mov_b32 $0, $1" if torch.version.hip else "mov.f32 $0, $1;", + constraints="=v, v" if torch.version.hip else "=f,f", + dtype=torch.float32, + ) + self.assertEqual(result.shape, torch.Size([])) + self.assertEqual(result, x) + + def test_4d_tensor(self): + x = torch.randn(2, 3, 4, 5, device="cuda", dtype=torch.float32) + result = inline_asm_elementwise( + x, + asm_str="v_mov_b32 $0, $1" if torch.version.hip else "mov.f32 $0, $1;", + constraints="=v, v" if torch.version.hip else "=f,f", + dtype=torch.float32, + ) + self.assertEqual(result.shape, x.shape) + self.assertEqual(result, x) + + def test_composition_with_pytorch_ops(self): + def fn(x, y): + z = x * 2 + w = inline_asm_elementwise( + z, + y, + asm_str="v_add_f32 $0, $1, $2" + if torch.version.hip + else "add.f32 $0, $1, $2;", + constraints="=v, v, v" if torch.version.hip else "=f,f,f", + dtype=torch.float32, + ) + return w + 1.0 + + x = torch.randn(100, device="cuda", dtype=torch.float32) + y = torch.randn(100, device="cuda", dtype=torch.float32) + + eager_result = fn(x, y) + compiled_fn = torch.compile(fn, backend="inductor") + compiled_result = compiled_fn(x, y) + + self.assertEqual(eager_result, compiled_result) + self.assertEqual(eager_result, x * 2 + y + 1.0) + + def test_output_strides_mixed_inputs(self): + """Verify fake mode output strides match eager (TensorIterator) strides.""" + from torch._subclasses.fake_tensor import FakeTensorMode + + # Two inputs with different strides: one contiguous, one transposed. + # This exercises TensorIterator's slow path for stride computation. + x = torch.randn(8, 16, device="cuda", dtype=torch.float32) + y = torch.randn(16, 8, device="cuda", dtype=torch.float32).t() + + eager_result = inline_asm_elementwise( + x, + y, + asm_str="v_add_f32 $0, $1, $2" + if torch.version.hip + else "add.f32 $0, $1, $2;", + constraints="=v, v, v" if torch.version.hip else "=f,f,f", + dtype=torch.float32, + ) + + with FakeTensorMode() as mode: + fake_x = mode.from_tensor(x) + fake_y = mode.from_tensor(y) + fake_result = inline_asm_elementwise( + fake_x, + fake_y, + asm_str="v_add_f32 $0, $1, $2" + if torch.version.hip + else "add.f32 $0, $1, $2;", + constraints="=v, v, v" if torch.version.hip else "=f,f,f", + dtype=torch.float32, + ) + + self.assertEqual(eager_result.shape, fake_result.shape) + self.assertEqual(eager_result.stride(), fake_result.stride()) + + def test_dynamic_shapes(self): + def fn(x, y): + return inline_asm_elementwise( + x, + y, + asm_str="v_add_f32 $0, $1, $2" + if torch.version.hip + else "add.f32 $0, $1, $2;", + constraints="=v, v, v" if torch.version.hip else "=f,f,f", + dtype=torch.float32, + ) + + compiled_fn = torch.compile(fn, backend="inductor", dynamic=True) + + for size in [50, 100, 200]: + x = torch.randn(size, device="cuda", dtype=torch.float32) + y = torch.randn(size, device="cuda", dtype=torch.float32) + eager_result = fn(x, y) + compiled_result = compiled_fn(x, y) + self.assertEqual(eager_result, compiled_result) + + +@unittest.skipIf(not TEST_CUDA, "CUDA not available") +@unittest.skipIf(not SM70OrLater, "Requires SM70+") +class TestInlineAsmPackPadding(TestCase): + """Test that pack padding works when block size < pack.""" + + def test_pack2_xblock1_padding(self): + """Force XBLOCK=1 with pack=2 so padding is needed.""" + from torch._inductor.choices import InductorChoices + from torch._inductor.codegen.triton import FixedTritonConfig + from torch._inductor.utils import run_and_get_code + from torch.testing import FileCheck + + class ForceXBlock1(InductorChoices): + def triton_kernel_kwargs(self, kernel_cls, features, groups, kernel_kwargs): + return { + **kernel_kwargs, + "fixed_config": FixedTritonConfig({"XBLOCK": 1}), + } + + def fn(x): + return inline_asm_elementwise( + x, + asm_str=( + """ + v_mov_b32 $0, $2 + v_mov_b32 $1, $3 + """ + if torch.version.hip + else "mov.b32 $0, $2; mov.b32 $1, $3;" + ), + constraints="=v,=v,v,v" if torch.version.hip else "=r,=r,r,r", + dtype=torch.float32, + pack=2, + ) + + x = torch.randn(128, device="cuda", dtype=torch.float32) + with torch._inductor.virtualized.V.set_choices_handler(ForceXBlock1()): + torch._dynamo.reset() + result, (code,) = run_and_get_code(torch.compile(fn, backend="inductor"), x) + + self.assertEqual(result, x) + # Verify padding helpers are emitted in the generated code + FileCheck().check("inline_asm_pack").check("inline_asm_unpack").run(code) + + def test_pack4_xblock1_padding(self): + """Force XBLOCK=1 with pack=4 so padding is needed.""" + from torch._inductor.choices import InductorChoices + from torch._inductor.codegen.triton import FixedTritonConfig + from torch._inductor.utils import run_and_get_code + from torch.testing import FileCheck + + class ForceXBlock1(InductorChoices): + def triton_kernel_kwargs(self, kernel_cls, features, groups, kernel_kwargs): + return { + **kernel_kwargs, + "fixed_config": FixedTritonConfig({"XBLOCK": 1}), + } + + def fn(x): + return inline_asm_elementwise( + x, + asm_str=( + """ + v_mov_b32 $0, $4 + v_mov_b32 $1, $5 + v_mov_b32 $2, $6 + v_mov_b32 $3, $7 + """ + if torch.version.hip + else "mov.b32 $0, $4; mov.b32 $1, $5; mov.b32 $2, $6; mov.b32 $3, $7;" + ), + constraints=( + "=v,=v,=v,=v,v,v,v,v" + if torch.version.hip + else "=r,=r,=r,=r,r,r,r,r" + ), + dtype=torch.float32, + pack=4, + ) + + x = torch.randn(128, device="cuda", dtype=torch.float32) + with torch._inductor.virtualized.V.set_choices_handler(ForceXBlock1()): + torch._dynamo.reset() + result, (code,) = run_and_get_code(torch.compile(fn, backend="inductor"), x) + + self.assertEqual(result, x) + FileCheck().check("inline_asm_pack").check("inline_asm_unpack").run(code) + + def test_pack4_xblock2_partial_padding(self): + """XBLOCK=2 < pack=4, so 1 round of padding is needed (not 2).""" + from torch._inductor.choices import InductorChoices + from torch._inductor.codegen.triton import FixedTritonConfig + from torch._inductor.utils import run_and_get_code + from torch.testing import FileCheck + + class ForceXBlock2(InductorChoices): + def triton_kernel_kwargs(self, kernel_cls, features, groups, kernel_kwargs): + return { + **kernel_kwargs, + "fixed_config": FixedTritonConfig({"XBLOCK": 2}), + } + + def fn(x): + return inline_asm_elementwise( + x, + asm_str=( + """ + v_mov_b32 $0, $4 + v_mov_b32 $1, $5 + v_mov_b32 $2, $6 + v_mov_b32 $3, $7 + """ + if torch.version.hip + else "mov.b32 $0, $4; mov.b32 $1, $5; mov.b32 $2, $6; mov.b32 $3, $7;" + ), + constraints=( + "=v,=v,=v,=v,v,v,v,v" + if torch.version.hip + else "=r,=r,=r,=r,r,r,r,r" + ), + dtype=torch.float32, + pack=4, + ) + + x = torch.randn(128, device="cuda", dtype=torch.float32) + with torch._inductor.virtualized.V.set_choices_handler(ForceXBlock2()): + torch._dynamo.reset() + result, (code,) = run_and_get_code(torch.compile(fn, backend="inductor"), x) + + self.assertEqual(result, x) + FileCheck().check("inline_asm_pack").check("inline_asm_unpack").run(code) + + def test_pack2_xblock1_yblock1_padding(self): + """Force XBLOCK=1, YBLOCK=1 with pack=2 on a 2D-tiled kernel.""" + from torch._inductor.choices import InductorChoices + from torch._inductor.codegen.triton import FixedTritonConfig + from torch._inductor.utils import run_and_get_code + from torch.testing import FileCheck + + class ForceXY1(InductorChoices): + def triton_kernel_kwargs(self, kernel_cls, features, groups, kernel_kwargs): + return { + **kernel_kwargs, + "fixed_config": FixedTritonConfig({"XBLOCK": 1, "YBLOCK": 1}), + } + + def fn(x, y): + return inline_asm_elementwise( + x, + y, + asm_str=( + """ + v_add_f32 $0, $2, $4 + v_add_f32 $1, $3, $5 + """ + if torch.version.hip + else "add.f32 $0, $2, $4; add.f32 $1, $3, $5;" + ), + constraints="=v,=v,v,v,v,v" if torch.version.hip else "=f,=f,f,f,f,f", + dtype=torch.float32, + pack=2, + ) + + x = torch.randn(8, 16, device="cuda", dtype=torch.float32) + # Transposed input triggers 2D tiling (different stride patterns) + y = torch.randn(16, 8, device="cuda", dtype=torch.float32).T + with torch._inductor.virtualized.V.set_choices_handler(ForceXY1()): + torch._dynamo.reset() + result, (code,) = run_and_get_code( + torch.compile(fn, backend="inductor"), x, y + ) + + self.assertEqual(result, x + y) + FileCheck().check("YBLOCK").check("inline_asm_pack").check( + "inline_asm_unpack" + ).run(code) + + +if __name__ == "__main__": + run_tests() diff --git a/test/higher_order_ops/test_invoke_quant.py b/test/higher_order_ops/test_invoke_quant.py index 7796a9e4a1685..52abe01277478 100644 --- a/test/higher_order_ops/test_invoke_quant.py +++ b/test/higher_order_ops/test_invoke_quant.py @@ -1,5 +1,4 @@ # Owner(s): ["module: higher order operators"] -# flake8: noqa: B950 import contextlib import logging diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 41c7013f3d8c0..34d2065e9d029 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -1,7 +1,7 @@ # Owner(s): ["module: higher order operators"] -# flake8: noqa: B950 # flake8: noqa: E731 +import contextlib import unittest import unittest.mock as mock @@ -339,11 +339,11 @@ def forward(self, L_x_: "f32[8]", L_y_: "f32[8]", L_mod_buffers_buf_: "f32[8]"): subgraph_0 = self.subgraph_0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_mod_buffers_buf_, l_x_, l_y_); subgraph_0 = None - getitem_8: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None subgraph_1 = self.subgraph_0 invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', l_mod_buffers_buf_, l_x_, l_y_); subgraph_1 = l_mod_buffers_buf_ = l_x_ = l_y_ = None - getitem_9: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None - add: "f32[8]" = getitem_8 + getitem_9; getitem_8 = getitem_9 = None + getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None + add: "f32[8]" = getitem + getitem_1; getitem = getitem_1 = None return (add,) class subgraph_0(torch.nn.Module): @@ -698,7 +698,7 @@ def fn(x): x = torch.randn(8, requires_grad=True) # Difficult to check the results here because we random does not match # between eager and Triton. - res = torch.compile(fn, backend="inductor", fullgraph=True)(x) # noqa: F841 + res = torch.compile(fn, backend="inductor", fullgraph=True)(x) torch.compiler.reset() backend = InductorAndRecordGraphs() @@ -1064,16 +1064,16 @@ def forward(self, L_x_: "f32[8]"): subgraph_0 = self.subgraph_0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_); subgraph_0 = l_x_ = None - getitem_2: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None - detach: "f32[8]" = getitem_2.detach(); getitem_2 = None + getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + detach: "f32[8]" = getitem.detach(); getitem = None return (detach,) class subgraph_0(torch.nn.Module): def forward(self, l_x_: "f32[8]"): wrap_body_0 = self.wrap_body_0 tag_activation_checkpoint = torch.ops.higher_order.tag_activation_checkpoint(wrap_body_0, l_x_, use_reentrant = False); wrap_body_0 = l_x_ = None - getitem_2: "f32[8]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None - return (getitem_2,) + getitem: "f32[8]" = tag_activation_checkpoint[0]; tag_activation_checkpoint = None + return (getitem,) class wrap_body_0(torch.nn.Module): def forward(self, l_x_: "f32[8]"): @@ -1130,12 +1130,12 @@ def forward(self, L_x_: "f32[8]", L_y_: "f32[8]"): subgraph_0 = self.subgraph_0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None - getitem_4: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None subgraph_1 = self.subgraph_1 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', getitem_4, l_y_); subgraph_1 = getitem_4 = l_y_ = None - getitem_5: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None - return (getitem_5,) + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', a, l_y_); subgraph_1 = a = l_y_ = None + getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None + return (getitem_1,) class subgraph_0(torch.nn.Module): def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"): @@ -1252,20 +1252,20 @@ def forward(self, L_x_: "f32[8]", L_y_: "f32[8]"): subgraph_0 = self.subgraph_0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None - getitem_25: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None subgraph_1 = self.subgraph_0 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', getitem_25, l_y_); subgraph_1 = getitem_25 = None - getitem_26: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', getitem, l_y_); subgraph_1 = getitem = None + x: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None subgraph_2 = self.subgraph_0 - invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', getitem_26, l_y_); subgraph_2 = getitem_26 = None - getitem_27: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', x, l_y_); subgraph_2 = x = None + x_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None subgraph_3 = self.subgraph_0 - invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', getitem_27, l_y_); subgraph_3 = getitem_27 = None - getitem_28: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None + invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_3, 'subgraph_0', x_1, l_y_); subgraph_3 = x_1 = None + x_2: "f32[8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None subgraph_4 = self.subgraph_0 - invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', getitem_28, l_y_); subgraph_4 = getitem_28 = l_y_ = None - getitem_29: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None - return (getitem_29,) + invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_4, 'subgraph_0', x_2, l_y_); subgraph_4 = x_2 = l_y_ = None + x_3: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None + return (x_3,) class subgraph_0(torch.nn.Module): def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"): @@ -1514,12 +1514,11 @@ def fn(x): x = torch.randn(8, requires_grad=False) opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) - # TODO When a filtered aliased intermediate is captured by side effects, - # it will fail later with "does not belong to this Graph" error - # because the proxy from the inner graph is used in the outer graph. + # When a filtered aliased intermediate is captured by side effects, + # the tainted proxy raises a clear error telling the user to clone. with self.assertRaisesRegex( torch._dynamo.exc.InternalTorchDynamoError, - "does not belong to this Graph", + "aliases an input or output.*clone", ): opt_fn(x) @@ -2468,12 +2467,12 @@ def forward(self, L_x_: "f32[5]", L_y_: "f32[5]"): subgraph_0 = self.subgraph_0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', x, y); subgraph_0 = x = None - getitem_4: "f32[5]" = invoke_subgraph[0]; invoke_subgraph = None + z: "f32[5]" = invoke_subgraph[0]; invoke_subgraph = None - subgraph_1 = self.subgraph_1 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', getitem_4, y); subgraph_1 = getitem_4 = y = None - getitem_5: "f32[5]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None - return (getitem_5,) + subgraph_1 = self.subgraph_0 + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', z, y); subgraph_1 = z = y = None + getitem_1: "f32[5]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None + return (getitem_1,) class subgraph_0(torch.nn.Module): def forward(self, x: "f32[5]", y: "f32[5]"): @@ -2481,15 +2480,6 @@ def forward(self, x: "f32[5]", y: "f32[5]"): triton_kernel_wrapper_mutation = torch.ops.higher_order.triton_kernel_wrapper_mutation(kernel_idx = 0, constant_args_idx = 0, grid = [(5, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'in_ptr0': x, 'in_ptr1': y, 'out_ptr': o}); x = y = triton_kernel_wrapper_mutation = None - sin: "f32[5]" = o.sin(); o = None - return (sin,) - - class subgraph_1(torch.nn.Module): - def forward(self, z: "f32[5]", y: "f32[5]"): - o: "f32[5]" = torch.zeros_like(z) - - triton_kernel_wrapper_mutation = torch.ops.higher_order.triton_kernel_wrapper_mutation(kernel_idx = 0, constant_args_idx = 1, grid = [(5, 1, 1)], tma_descriptor_metadata = {}, kwargs = {'in_ptr0': z, 'in_ptr1': y, 'out_ptr': o}); z = y = triton_kernel_wrapper_mutation = None - sin: "f32[5]" = o.sin(); o = None return (sin,) """, @@ -3101,6 +3091,821 @@ def forward(self, l_x_: "f32[64, 1]"): ) +@skipIfTorchDynamo("Not a torch._dynamo test") +class TestInvokeSubgraphReuse(TestCase): + @contextlib.contextmanager + def _count_speculate_calls(self): + count = 0 + orig = torch._dynamo.variables.higher_order_ops.speculate_subgraph_with_auto_output_flattening + + def _counting(*args, **kwargs): + nonlocal count + count += 1 + return orig(*args, **kwargs) + + with mock.patch.object( + torch._dynamo.variables.higher_order_ops, + "speculate_subgraph_with_auto_output_flattening", + _counting, + ): + yield lambda: count + + def test_subgraph_reuse_skips_tracing(self): + @nested_compile_region + def gn(x, y): + return torch.mul(x, y) + + def fn(x, y): + a = gn(x, y) + b = gn(x, y) + c = gn(x, y) + return a + b + c + + x = torch.randn(8) + y = torch.randn(8) + + with self._count_speculate_calls() as count: + torch.compile(fn, backend="aot_eager", fullgraph=True)(x, y) + + self.assertEqual(count(), 1) + + def test_subgraph_reuse_different_shapes(self): + @nested_compile_region + def gn(x): + return x.sin() + + def fn(x, y): + a = gn(x) + b = gn(y) + return a.sum() + b.sum() + + x = torch.randn(4) + y = torch.randn(8) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x, y) + + # Different shapes → two separate traces + self.assertEqual(count(), 2) + self.assertEqual(res, fn(x, y)) + + def test_subgraph_reuse_module(self): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 5 + + @nested_compile_region + def forward(self, x, y): + return torch.mul(x, y).sin() + self.c + + mod = Mod() + + def fn(x, y): + return mod(x, y) + mod(x, y) + + x = torch.randn(8, requires_grad=True) + y = torch.randn(8, requires_grad=True) + ref = fn(x, y) + + x_clone = x.detach().clone().requires_grad_(True) + y_clone = y.detach().clone().requires_grad_(True) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)( + x_clone, y_clone + ) + + # Second call reuses the first trace + self.assertEqual(count(), 1) + ref.sum().backward() + res.sum().backward() + + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + self.assertEqual(y.grad, y_clone.grad) + + def test_subgraph_reuse_module_different_instances(self): + class Mod(torch.nn.Module): + def __init__(self, c): + super().__init__() + self.c = c + + @nested_compile_region + def forward(self, x, y): + return torch.mul(x, y).sin() + self.c + + mod1 = Mod(5) + mod2 = Mod(5) + + def fn(x, y): + return mod1(x, y) + mod2(x, y) + + x = torch.randn(8, requires_grad=True) + y = torch.randn(8, requires_grad=True) + ref = fn(x, y) + + x_clone = x.detach().clone().requires_grad_(True) + y_clone = y.detach().clone().requires_grad_(True) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)( + x_clone, y_clone + ) + + # mod1 and mod2 have the same structure and c value; source replacement + # means only one trace is needed. + self.assertEqual(count(), 1) + ref.sum().backward() + res.sum().backward() + + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + self.assertEqual(y.grad, y_clone.grad) + + def test_subgraph_reuse_tuple_output(self): + @nested_compile_region + def gn(x, y): + return torch.sin(x), torch.cos(y) + + def fn(x, y): + a1, a2 = gn(x, y) + b1, b2 = gn(x, y) + return a1 + b1, a2 + b2 + + x = torch.randn(8, requires_grad=True) + y = torch.randn(8, requires_grad=True) + ref = fn(x, y) + + x_clone = x.detach().clone().requires_grad_(True) + y_clone = y.detach().clone().requires_grad_(True) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)( + x_clone, y_clone + ) + + # Second call reuses the first trace + self.assertEqual(count(), 1) + sum(r.sum() for r in ref).backward() + sum(r.sum() for r in res).backward() + + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + self.assertEqual(x.grad, x_clone.grad) + self.assertEqual(y.grad, y_clone.grad) + + def test_subgraph_reuse_mutated_attribute(self): + """Reuse must be skipped when a captured attribute is mutated between calls.""" + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 5 + + @nested_compile_region + def forward(self, x): + return x * self.c + + mod = Mod() + + def fn(x): + a = mod(x) + mod.c = 10 + b = mod(x) + return a + b + + x = torch.randn(8) + # Eager: first call uses c=5, then c is set to 10, second call uses c=10. + # Result = x*5 + x*10 = x*15. + mod.c = 5 + ref = fn(x) + self.assertEqual(ref, x * 15) + + # Compiled should produce the same result. If reuse incorrectly + # fires, both calls would use c=5, giving x*10 instead of x*15. + mod.c = 5 + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + self.assertEqual(ref, res) + # c=5 and c=10 are distinct constants → two separate traces + self.assertEqual(count(), 2) + + def test_subgraph_reuse_unrelated_attr_mutation(self): + """Reuse should still fire when a different attribute is mutated.""" + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 5 + self.d = 100 + + @nested_compile_region + def forward(self, x): + # Only reads self.c, never self.d + return x * self.c + + mod = Mod() + + def fn(x): + a = mod(x) + mod.d = 999 # unrelated attribute + b = mod(x) + return a + b + + x = torch.randn(8) + ref = fn(x) + # Both calls use c=5, so result = x*5 + x*5 = x*10. + self.assertEqual(ref, x * 10) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + # Mutating mod.d should not prevent reuse of the subgraph that only reads mod.c. + self.assertEqual(count(), 1) + + def test_subgraph_reuse_same_class_attr_mutated(self): + """Reuse must be skipped when a captured attr changes between calls. + + submod1 and submod2 are instances of the same class with the same + initial value for .c. The first call traces submod1; the second call + to submod2 could reuse the cache entry via source replacement. But + submod2.c is mutated between the two calls, so reuse must be skipped. + """ + + class Block(torch.nn.Module): + def __init__(self, c): + super().__init__() + self.c = c + + @nested_compile_region + def forward(self, x): + return x * self.c + + submod1 = Block(5) + submod2 = Block(5) # same initial .c as submod1 + + def fn(x): + a = submod1(x) # traces with c=5 + submod2.c = 10 # mutate submod2.c + b = submod2(x) # must NOT reuse the c=5 subgraph + return a + b + + x = torch.randn(8) + ref = fn(x) + # a = x*5, b = x*10 → x*15 + self.assertEqual(ref, x * 15) + + submod2.c = 5 # reset for compiled run + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + self.assertEqual(ref, res) + # submod1 traces with c=5; submod2 has c mutated to 10 → two separate traces + self.assertEqual(count(), 2) + + def test_subgraph_reuse_pre_existing_attr_guard(self): + """Guards installed before the subgraph trace must still block incorrect reuse. + + If ``block.c`` is read in a conditional before the nested compile region, + its guard is installed before ``guards_before`` is snapshotted and won't + appear in the delta. Reuse must still be rejected when a different module + with a different ``c`` is passed. + """ + + class Block(torch.nn.Module): + def __init__(self, c): + super().__init__() + self.c = c + + @nested_compile_region + def apply_block(mod, x): + return x * mod.c + + block1 = Block(5) + block2 = Block(10) + + def fn(x): + # The conditional installs EQUALS_MATCH on block1.c *before* + # the subgraph trace snapshots guards_before. + if block1.c == 5: + a = apply_block(block1, x) + else: + a = x + if block2.c == 10: + b = apply_block(block2, x) + else: + b = x + return a + b + + x = torch.randn(8) + ref = fn(x) + self.assertEqual(ref, x * 5 + x * 10) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + self.assertEqual(ref, res) + # block1.c=5 and block2.c=10 differ → two separate traces + self.assertEqual(count(), 2) + + def test_subgraph_reuse_mutated_captured_variable(self): + """Reuse must be skipped when a captured (non-input) variable is mutated.""" + + class Config: + def __init__(self, c): + self.c = c + + cfg = Config(5) + + @nested_compile_region + def apply(x): + # cfg is captured from closure, not an explicit input + return x * cfg.c + + def fn(x): + a = apply(x) + cfg.c = 10 + b = apply(x) + return a + b + + x = torch.randn(8) + cfg.c = 5 + ref = fn(x) + self.assertEqual(ref, x * 15) + + cfg.c = 5 + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + self.assertEqual(ref, res) + # cfg.c=5 and cfg.c=10 are distinct → two separate traces + self.assertEqual(count(), 2) + + def test_subgraph_reuse_synthetic_source(self): + """Reuse must handle TorchScriptObjectVariable with SyntheticLocalSource. + + Hoisted opaque value types get a SyntheticLocalSource that can't be + resolved via VariableBuilder. On cache hit, stamp_out_subgraph must + call synthetic_graph_input to create a fresh graph input. + """ + from test_opaque_obj_v2 import HoistedString, op_with_string + + @nested_compile_region + def gn(x): + return op_with_string(x, HoistedString("double")) + + def fn(x): + a = gn(x) + b = gn(x) + return a + b + + x = torch.randn(8) + ref = fn(x) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + self.assertEqual(count(), 1) + + def test_subgraph_reuse_synthetic_source_different_args(self): + """Reuse when hoisted opaque ctor args differ across submodules.""" + from test_opaque_obj_v2 import HoistedString, op_with_string + + class Layer(torch.nn.Module): + def __init__(self, name): + super().__init__() + self.name = name + + @nested_compile_region + def forward(self, x): + return op_with_string(x, HoistedString(self.name)) + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer0 = Layer("double") + self.layer1 = Layer("square") + self.layer2 = Layer("double") + + def forward(self, x): + x = self.layer0(x) + x = self.layer1(x) + x = self.layer2(x) + return x + + model = Model() + x = torch.randn(8) + ref = model(x) + + backend = EagerAndRecordGraphs() + with self._count_speculate_calls() as count: + res = torch.compile(model, backend=backend, fullgraph=True)(x) + + self.assertEqual(ref, res) + self.assertEqual(count(), 1) + + self.assertEqual(len(backend.graphs), 1) + if not TEST_WITH_CROSSREF: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[8]", SYNTHETIC_LOCAL_tmp_0_ : test_opaque_obj_v2_HoistedString, SYNTHETIC_LOCAL_tmp_2_ : test_opaque_obj_v2_HoistedString, SYNTHETIC_LOCAL_tmp_4_ : test_opaque_obj_v2_HoistedString): + l_x_ = L_x_ + synthetic_local_tmp_0_ = SYNTHETIC_LOCAL_tmp_0_ + synthetic_local_tmp_2_ = SYNTHETIC_LOCAL_tmp_2_ + synthetic_local_tmp_4_ = SYNTHETIC_LOCAL_tmp_4_ + + subgraph_0 = self.subgraph_0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, synthetic_local_tmp_0_); subgraph_0 = l_x_ = synthetic_local_tmp_0_ = None + x: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + + subgraph_1 = self.subgraph_0 + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', x, synthetic_local_tmp_2_); subgraph_1 = x = synthetic_local_tmp_2_ = None + x_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None + + subgraph_2 = self.subgraph_0 + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_2, 'subgraph_0', x_1, synthetic_local_tmp_4_); subgraph_2 = x_1 = synthetic_local_tmp_4_ = None + x_2: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None + return (x_2,) + + class subgraph_0(torch.nn.Module): + def forward(self, l_x_: "f32[8]", synthetic_local_tmp_0_ : test_opaque_obj_v2_HoistedString): + op_with_string_default: "f32[8]" = torch.ops.mylib.op_with_string.default(l_x_, synthetic_local_tmp_0_); l_x_ = synthetic_local_tmp_0_ = None + return (op_with_string_default,) +""", + ) + + def test_subgraph_reuse_different_list_lengths(self): + """Reuse must be skipped when list args have different lengths. + + The first call passes lists of length 2; the second passes lists of + length 3. The pytree treespec will differ, so the cache lookup must + fall through and trigger a second trace. + """ + + @nested_compile_region + def gn(xs, ys): + return [a + b for a, b in zip(xs, ys)] + + def fn(xs1, ys1, xs2, ys2): + a = gn(xs1, ys1) + b = gn(xs2, ys2) + return a, b + + xs1 = [torch.randn(4), torch.randn(4)] + ys1 = [torch.randn(4), torch.randn(4)] + xs2 = [torch.randn(4), torch.randn(4), torch.randn(4)] + ys2 = [torch.randn(4), torch.randn(4), torch.randn(4)] + + ref = fn(xs1, ys1, xs2, ys2) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)( + xs1, ys1, xs2, ys2 + ) + + # Different list lengths → treespec mismatch → two separate traces + self.assertEqual(count(), 2) + for r, e in zip(res, ref): + for ri, ei in zip(r, e): + self.assertEqual(ri, ei) + + def test_subgraph_reuse_different_constants_retrace(self): + """Constant args with different values each require a fresh trace. + + Three calls with three distinct scalar constants → call_count == 3. + """ + + @nested_compile_region + def gn(x, scale): + return x * scale + + def fn(x): + a = gn(x, 1) + b = gn(x, 2) + c = gn(x, 3) + return a, b, c + + x = torch.randn(4) + ref = fn(x) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + + for r, e in zip(res, ref): + self.assertEqual(r, e) + # Three distinct constants → three separate traces + self.assertEqual(count(), 3) + + def test_subgraph_reuse_tuple_destructure_with_intermediates(self): + class Layer(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(8, 8, bias=False) + + @nested_compile_region + def forward(self, x, residual): + h = self.linear(x) + h = h + residual + new_residual = h * 0.5 + return h, new_residual + + class Model(torch.nn.Module): + def __init__(self, num_layers): + super().__init__() + self.embed = torch.nn.Linear(8, 8, bias=False) + self.layers = torch.nn.ModuleList([Layer() for _ in range(num_layers)]) + + def forward(self, x, residual): + # embed gives x/residual requires_grad=True (same as layer outputs), + # so all layers see identical tensor metadata → single trace suffices. + x = self.embed(x) + residual = self.embed(residual) + for layer in self.layers: + # Must support extra outputs + hidden_states, residual = layer(x, residual) + x = hidden_states + return x, residual + + model = Model(3) + x = torch.randn(4, 8) + residual = torch.randn(4, 8) + ref = model(x, residual) + + torch._dynamo.reset() + + with self._count_speculate_calls() as count: + res = torch.compile(model, backend="aot_eager", fullgraph=True)( + x.clone(), residual.clone() + ) + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + # All layers see identical tensor metadata (requires_grad=True throughout) + # so layers[1] and layers[2] reuse layers[0]'s trace. + self.assertEqual(count(), 1) + + def test_subgraph_reuse_different_dynamic_symnodes(self): + @nested_compile_region + def gn(x, n): + return x * n + + def fn(x): + a = gn(x, x.shape[0]) + b = gn(x, x.shape[1]) + return a, b + + x = torch.ones(4, 6) + ref = fn(x) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", dynamic=True)(x) + + self.assertEqual(ref[0], res[0]) + self.assertEqual(ref[1], res[1]) + # s0 (dim 0) and s1 (dim 1) are distinct symbols → two separate traces + self.assertEqual(count(), 2) + + def test_subgraph_reuse_cache_multiple_entries(self): + @nested_compile_region + def gn(x): + return x.sin() + + def fn(x4, x8, x4_again): + a = gn(x4) + b = gn(x8) + c = gn(x4_again) + return a, b, c + + x4 = torch.randn(4) + x8 = torch.randn(8) + x4_again = torch.randn(4) + ref = fn(x4, x8, x4_again) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)( + x4, x8, x4_again + ) + + for r, e in zip(res, ref): + self.assertEqual(r, e) + self.assertEqual(count(), 2) + + def test_subgraph_reuse_kwargs(self): + @nested_compile_region + def gn(x, *, scale=1.0): + return x * scale + + def fn(x): + a = gn(x, scale=2.0) + b = gn(x, scale=2.0) + return a + b + + x = torch.randn(8) + ref = fn(x) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + self.assertEqual(count(), 1) + + def test_subgraph_reuse_max_entries_raises(self): + """Exceeding max_reuse_entries raises RuntimeError.""" + + @nested_compile_region(max_reuse_entries=2) + def gn(x, c): + return x * c + + def fn(x): + # Three distinct constants exceed the limit of 2 + return gn(x, 1) + gn(x, 2) + gn(x, 3) + + x = torch.randn(4) + with self.assertRaisesRegex(RuntimeError, "exceeded maximum reuse entries"): + torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + + def test_subgraph_reuse_module_different_instances_retrace(self): + """Different module instances with different weights require separate traces.""" + + class Mod(torch.nn.Module): + def __init__(self, c): + super().__init__() + self.c = c + + @nested_compile_region + def forward(self, x): + return x * self.c + + mod1 = Mod(5) + mod2 = Mod(10) + + def fn(x): + return mod1(x) + mod2(x) + + x = torch.randn(8) + ref = fn(x) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + # mod1.c=5 and mod2.c=10 differ → two separate traces + self.assertEqual(count(), 2) + + def test_subgraph_reuse_monkeypatch_forward(self): + """Monkeypatching forward with nested_compile_region should reuse cache. + + When a user wraps a module's forward with nested_compile_region via + monkeypatching, each module instance gets a distinct bound method and + thus a distinct function object passed to nested_compile_region. The + cache key should be based on fn.__code__ rather than id(fn) so that + reuse still works across instances whose forward shares the same code. + """ + + class Mod(torch.nn.Module): + def __init__(self, c): + super().__init__() + self.c = c + + def forward(self, x): + return x * self.c + + def apply_nested_compile_region(mod): + # Grab the unbound function and create a fresh copy so that each + # module instance gets a function with a *different* id() but the + # *same* __code__ object, mimicking a monkeypatch/code-gen scenario. + import types + + orig = type(mod).forward + fresh_fn = types.FunctionType( + orig.__code__, + orig.__globals__, + orig.__name__, + orig.__defaults__, + orig.__closure__, + ) + mod.forward = nested_compile_region(fresh_fn).__get__(mod, type(mod)) + return mod + + mod1 = apply_nested_compile_region(Mod(5)) + mod2 = apply_nested_compile_region(Mod(5)) + + # The inner functions passed to invoke_subgraph_placeholder are + # different objects but share the same __code__. + fn1 = mod1.forward.__func__.__marked_compile_region_fn__ + fn2 = mod2.forward.__func__.__marked_compile_region_fn__ + self.assertIsNot(fn1, fn2) + self.assertIs(fn1.__code__, fn2.__code__) + + def fn(x): + return mod1(x) + mod2(x) + + x = torch.randn(8) + ref = fn(x) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + # Both modules have the same code and same c=5, so the second call + # should reuse the first trace despite having different function ids. + self.assertEqual(count(), 1) + + def test_subgraph_reuse_module_apply(self): + """Using module.apply to wrap transformer layers with nested_compile_region.""" + + class Layer(torch.nn.Module): + def __init__(self, c): + super().__init__() + self.c = c + + def forward(self, x): + return x * self.c + + model = torch.nn.Sequential(Layer(5), Layer(5), Layer(5)) + ref_model = torch.nn.Sequential(Layer(5), Layer(5), Layer(5)) + + def wrap_layer(mod): + if isinstance(mod, Layer): + fwd = type(mod).forward + if not hasattr(fwd, "__marked_compile_region_fn__"): + type(mod).forward = nested_compile_region(fwd) + + model.apply(wrap_layer) + + x = torch.randn(8) + ref = ref_model(x) + + with self._count_speculate_calls() as count: + res = torch.compile(model, backend="aot_eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + # All three layers share the same class and c=5, only one trace needed. + self.assertEqual(count(), 1) + + def test_subgraph_reuse_class_level_wrap(self): + """Wrapping cls.forward with nested_compile_region for named_modules pattern.""" + + class Layer(torch.nn.Module): + def __init__(self, c): + super().__init__() + self.c = c + + def forward(self, x): + return x * self.c + + # Wrap at the class level — the pattern a user would use after + # iterating named_modules to wrap all layers of a given type. + Layer.forward = nested_compile_region(Layer.forward) + + mod1 = Layer(5) + mod2 = Layer(5) + + def fn(x): + return mod1(x) + mod2(x) + + x = torch.randn(8) + ref = fn(x) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + self.assertEqual(count(), 1) + + def test_subgraph_reuse_module_instance_as_callable(self): + """Passing nn.Module instances directly to nested_compile_region.""" + + class Layer(torch.nn.Module): + def __init__(self, c): + super().__init__() + self.c = c + + def forward(self, x): + return x * self.c + + mod1 = Layer(5) + mod2 = Layer(5) + + wrapped1 = nested_compile_region(mod1) + wrapped2 = nested_compile_region(mod2) + + def fn(x): + return wrapped1(x) + wrapped2(x) + + x = torch.randn(8) + ref = fn(x) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + # Both modules have the same forward code and same c=5, so the + # second call should reuse the first trace. + self.assertEqual(count(), 1) + + @skipIfTorchDynamo("Not a torch._dynamo test") @parameterized_class( [ @@ -3345,5 +4150,264 @@ def fn(x, y): self.assertEqual(y.grad, y2.grad) +@skipIfTorchDynamo("Not a torch._dynamo test") +class TestInvokeSubgraphReuseHashFn(TestCase): + @contextlib.contextmanager + def _count_speculate_calls(self): + count = 0 + orig = torch._dynamo.variables.higher_order_ops.speculate_subgraph_with_auto_output_flattening + + def _counting(*args, **kwargs): + nonlocal count + count += 1 + return orig(*args, **kwargs) + + with mock.patch.object( + torch._dynamo.variables.higher_order_ops, + "speculate_subgraph_with_auto_output_flattening", + _counting, + ): + yield lambda: count + + def test_reuse_hash_fn_module_distinct_hashes(self): + """nn.Module arg with hash fn returning different values per layer.""" + + def hash_fn(mod, x): + return mod.layer_id + + @nested_compile_region(reuse_hash_fn=hash_fn) + def layer_fn(mod, x): + return x.sin() + mod.weight + + class Layer(torch.nn.Module): + def __init__(self, layer_id): + super().__init__() + self.layer_id = layer_id + self.weight = torch.nn.Parameter(torch.randn(8)) + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([Layer(i) for i in range(4)]) + + def forward(self, x): + for layer in self.layers: + x = layer_fn(layer, x) + return x + + mod = Model() + x = torch.randn(8) + ref = mod(x) + + with self._count_speculate_calls() as count: + res = torch.compile(mod, backend="aot_eager", fullgraph=True)(x) + + # Each layer has a distinct layer_id → 4 separate traces + self.assertEqual(count(), 4) + self.assertEqual(ref, res) + + def test_reuse_hash_fn_module_same_hash(self): + """nn.Module arg with hash fn returning same value → single trace.""" + + def hash_fn(mod, x): + return 0 + + @nested_compile_region(reuse_hash_fn=hash_fn) + def layer_fn(mod, x): + return x.sin() + mod.weight + + class Layer(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(8)) + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([Layer() for _ in range(4)]) + + def forward(self, x): + for layer in self.layers: + x = layer_fn(layer, x) + return x + + mod = Model() + x = torch.randn(8) + ref = mod(x) + + with self._count_speculate_calls() as count: + res = torch.compile(mod, backend="aot_eager", fullgraph=True)(x) + + # All layers hash to 0 → single trace + 3 stamp-outs + self.assertEqual(count(), 1) + self.assertEqual(ref, res) + + def test_reuse_hash_fn_tensor_shape(self): + """Hash fn that uses tensor shape to differentiate inputs.""" + + def hash_fn(x): + return x.shape[0] + + @nested_compile_region(reuse_hash_fn=hash_fn) + def gn(x): + return x.sin() + + def fn(x4, x8a, x8b): + # x8a and x8b have the same shape → same hash → reuse + a = gn(x4) + b = gn(x8a) + c = gn(x8b) + return a.sum() + b.sum() + c.sum() + + x4 = torch.randn(4) + x8a = torch.randn(8) + x8b = torch.randn(8) + ref = fn(x4, x8a, x8b) + + with self._count_speculate_calls() as count: + res = torch.compile(fn, backend="aot_eager", fullgraph=True)(x4, x8a, x8b) + + # shape[0]=4 and shape[0]=8 → 2 traces, third call reuses shape=8 + self.assertEqual(count(), 2) + self.assertEqual(ref, res) + + def test_reuse_hash_fn_graph_break_raises(self): + """reuse_hash_fn with a graph break raises a clear error.""" + + def bad_hash_fn(mod, x): + torch._dynamo.graph_break() + return 0 + + @nested_compile_region(reuse_hash_fn=bad_hash_fn) + def layer_fn(mod, x): + return x.sin() + mod.weight + + class Layer(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(8)) + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([Layer(), Layer()]) + + def forward(self, x): + for layer in self.layers: + x = layer_fn(layer, x) + return x + + mod = Model() + x = torch.randn(8) + with self.assertRaisesRegex( + RuntimeError, "reuse_hash_fn must be fully traceable" + ): + torch.compile(mod, backend="aot_eager", fullgraph=True)(x) + + def test_reuse_hash_fn_if_cond_no_guard(self): + """if-condition in hash fn should not install guards on the module.""" + + def hash_fn(mod, x): + if mod.use_gelu: + return 1 + return 0 + + @nested_compile_region(reuse_hash_fn=hash_fn) + def layer_fn(mod, x): + return x.sin() + mod.weight + + class Layer(torch.nn.Module): + def __init__(self, use_gelu): + super().__init__() + self.use_gelu = use_gelu + self.weight = torch.nn.Parameter(torch.randn(8)) + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([Layer(True), Layer(False)]) + + def forward(self, x): + for layer in self.layers: + x = layer_fn(layer, x) + return x + + mod = Model() + x = torch.randn(8) + ref = mod(x) + + cnt = torch._dynamo.testing.CompileCounter() + compiled = torch.compile(mod, backend=cnt) + res = compiled(x) + self.assertEqual(ref, res) + + # Flip use_gelu on both layers — if hash fn guards leaked, this + # would cause a recompilation. With proper guard stripping it + # should not recompile (the outer frame guard count stays the same). + frame_count_before = cnt.frame_count + mod.layers[0].use_gelu = False + mod.layers[1].use_gelu = True + ref2 = mod(x) + res2 = compiled(x) + self.assertEqual(ref2, res2) + self.assertEqual(cnt.frame_count, frame_count_before) + + def test_reuse_hash_fn_side_effect_allowed(self): + """Side effects (attribute mutation) should not block reuse with reuse_hash_fn.""" + + def hash_fn(mod, x): + return 0 + + @nested_compile_region(reuse_hash_fn=hash_fn) + def layer_fn(mod, x): + mod.call_count += 1 + return x.sin() + mod.weight + + class Layer(torch.nn.Module): + def __init__(self): + super().__init__() + self.call_count = 0 + self.weight = torch.nn.Parameter(torch.randn(8)) + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([Layer() for _ in range(3)]) + + def forward(self, x): + for layer in self.layers: + x = layer_fn(layer, x) + return x + + mod = Model() + x = torch.randn(8) + + # Without reuse_hash_fn, the side effect (mod.call_count += 1) + # would prevent reuse. With reuse_hash_fn, it should still reuse. + with self._count_speculate_calls() as count: + torch.compile(mod, backend="aot_eager", fullgraph=True)(x) + + self.assertEqual(count(), 1) + + def test_reuse_hash_fn_unsupported_output_raises(self): + """Nested output (tuple of tuple of tensors) should raise with reuse_hash_fn.""" + + def hash_fn(x): + return 0 + + @nested_compile_region(reuse_hash_fn=hash_fn) + def gn(x): + return ((x.sin(), x.cos()),) + + def fn(x): + return gn(x) + + x = torch.randn(8) + with self.assertRaisesRegex( + RuntimeError, "reuse_hash_fn was provided but the subgraph is not eligible" + ): + torch.compile(fn, backend="aot_eager", fullgraph=True)(x) + + if __name__ == "__main__": run_tests() diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index 2eb818f8831ae..e3e8be1dee9fb 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -1,5 +1,4 @@ # Owner(s): ["module: higher order operators"] -# flake8: noqa: B950 import functools diff --git a/test/higher_order_ops/test_print.py b/test/higher_order_ops/test_print.py index a133044e081a6..c3c9342975fc8 100644 --- a/test/higher_order_ops/test_print.py +++ b/test/higher_order_ops/test_print.py @@ -628,7 +628,7 @@ def forward(self, primals_1, primals_2): with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.print, \ 'values {} {}', 3, add); getitem = None getitem_2 = with_effects_1[0]; with_effects_1 = None - return (getitem_2, add)""", # noqa: B950 + return (getitem_2, add)""", ) # Check backward graph - print HOP doesn't contribute to gradients @@ -680,7 +680,7 @@ def forward(self, arg1_1): with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.print, 'values {} {}', 3, add); getitem = None getitem_2 = with_effects_1[0]; with_effects_1 = None _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem_2]); getitem_2 = _sink_tokens_default = None - return (add,)""", # noqa: B950 + return (add,)""", ) diff --git a/test/higher_order_ops/test_with_effects.py b/test/higher_order_ops/test_with_effects.py index ab85ce5812bb6..d26e3c2bf1ad4 100644 --- a/test/higher_order_ops/test_with_effects.py +++ b/test/higher_order_ops/test_with_effects.py @@ -1,6 +1,5 @@ # Owner(s): ["module: functorch"] # ruff: noqa: F841 -# flake8: noqa: B950 import unittest from collections import deque from functools import partial @@ -138,7 +137,7 @@ def forward(self, arg1_1): with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None getitem_2 = with_effects_1[0]; with_effects_1 = None _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem_2]); getitem_2 = _sink_tokens_default = None - return (add,)""", # noqa: B950 + return (add,)""", ) def test_torchbind_custom_op(self): @@ -162,7 +161,7 @@ def forward(self, arg0_1, arg1_1): getitem = with_effects[0] getitem_1 = with_effects[1]; with_effects = None add = torch.ops.aten.add.Tensor(arg1_1, getitem_1); arg1_1 = getitem_1 = None - return (getitem, add)""", # noqa: B950 + return (getitem, add)""", ) self.assertEqual(len(gs.input_tokens), 1) self.assertEqual(len(gs.output_tokens), 1) @@ -1034,7 +1033,7 @@ def forward(self, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1, arg7_1, arg8_1 with_effects = torch.ops.higher_order.with_effects(getitem_4, torch.ops.mylib.record_memory.default, 'forward', 'N'); getitem_4 = None getitem_6 = with_effects[0]; with_effects = None _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem_6]); getitem_6 = _sink_tokens_default = None - return (getitem_5,)""", # noqa: B950 + return (getitem_5,)""", ) self.assertExpectedInline( str(gm.repeated_subgraph0.code).strip(), @@ -1049,7 +1048,7 @@ def forward(self, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1): t_1 = torch.ops.aten.t.default(arg4_1); arg4_1 = None addmm_1 = torch.ops.aten.addmm.default(arg5_1, relu, t_1); arg5_1 = relu = t_1 = None _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem]); getitem = _sink_tokens_default = None - return (addmm_1,)""", # noqa: B950 + return (addmm_1,)""", ) recorded_list.clear() @@ -1103,7 +1102,7 @@ def fn(x): def forward(self, primals_1): mul = torch.ops.aten.mul.Tensor(primals_1, primals_1); primals_1 = None sum_1 = torch.ops.aten.sum.default(mul); mul = None - return (sum_1,)""", # noqa: B950 + return (sum_1,)""", ) self.assertExpectedInline( @@ -1114,7 +1113,7 @@ def forward(self, tangents_1, tangents_token): with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.mylib.log_grad.default, expand); tangents_token = expand = None getitem = with_effects[0] getitem_1 = with_effects[1]; with_effects = None - return (getitem_1, getitem)""", # noqa: B950 + return (getitem_1, getitem)""", ) def test_with_effects_through_functional_tensor_mode(self): diff --git a/test/inductor/extension_backends/cpp/extension_codegen_backend.py b/test/inductor/extension_backends/cpp/extension_codegen_backend.py index f6afd87db75b9..75694f95d454e 100644 --- a/test/inductor/extension_backends/cpp/extension_codegen_backend.py +++ b/test/inductor/extension_backends/cpp/extension_codegen_backend.py @@ -1,4 +1,6 @@ +from torch._inductor import ir from torch._inductor.codegen import cpp, cpp_wrapper_cpu, wrapper +from torch._inductor.codegen.wrapper import PythonWrapperCodegen from torch._inductor.scheduler import BaseScheduling from torch._inductor.virtualized import V @@ -7,11 +9,47 @@ class ExtensionWrapperCodegen(wrapper.PythonWrapperCodegen): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: str | None, + parent_wrapper: PythonWrapperCodegen | None, + partition_signatures: ir.GraphPartitionSignature | None = None, + ): + return ExtensionWrapperCodegen() + + def _generate_kernel_call_helper( + self, kernel_name, call_args, *, device=None, **kwargs + ): + device = device or V.graph.get_current_device_or_throw() + if device.type == "extension_device": + import torch + + device = torch.device("cpu") + super()._generate_kernel_call_helper( + kernel_name, call_args, device=device, **kwargs + ) + class ExtensionCppWrapperCodegen(cpp_wrapper_cpu.CppWrapperCpu): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + @staticmethod + def create( + is_subgraph: bool, + subgraph_name: str | None, + parent_wrapper: PythonWrapperCodegen | None, + partition_signatures: ir.GraphPartitionSignature | None = None, + ): + return ExtensionCppWrapperCodegen() + + @staticmethod + def get_device_include_path(device: str) -> str: + if device == "extension_device": + return cpp_wrapper_cpu.CppWrapperCpu.get_device_include_path("cpu") + return cpp_wrapper_cpu.CppWrapperCpu.get_device_include_path(device) + class ExtensionScheduling(BaseScheduling): def __init__(self, scheduler): diff --git a/test/inductor/extension_backends/cpp/extension_device.cpp b/test/inductor/extension_backends/cpp/extension_device.cpp index 249ab38656689..0642f2ced2037 100644 --- a/test/inductor/extension_backends/cpp/extension_device.cpp +++ b/test/inductor/extension_backends/cpp/extension_device.cpp @@ -185,4 +185,7 @@ bool custom_op_called() { PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("custom_device", &get_custom_device, "get custom device object"); m.def("custom_op_called", &custom_op_called, "check if our custom function was called"); + m.def("is_available", []() { return true; }); + m.def("device_count", []() { return 1; }); + m.def("current_device", []() { return 0; }); } diff --git a/test/inductor/extension_backends/triton/device_interface.py b/test/inductor/extension_backends/triton/device_interface.py index 63d87575e3dfb..7537e785dc672 100644 --- a/test/inductor/extension_backends/triton/device_interface.py +++ b/test/inductor/extension_backends/triton/device_interface.py @@ -3,7 +3,7 @@ import time import torch -from torch._dynamo import device_interface # noqa: PLC2701 import-private-name +from torch._dynamo import device_interface class DeviceProperties: @@ -48,7 +48,7 @@ def query(self) -> None: def synchronize(self) -> None: pass - class device: # noqa: N801 invalid-class-name # pyright: ignore [reportIncompatibleVariableOverride] + class device: def __init__(self, device) -> None: self.device = device diff --git a/test/inductor/extension_backends/triton/extension_codegen_backend.py b/test/inductor/extension_backends/triton/extension_codegen_backend.py index 3e77a29caacc7..0bebb0ab1af68 100644 --- a/test/inductor/extension_backends/triton/extension_codegen_backend.py +++ b/test/inductor/extension_backends/triton/extension_codegen_backend.py @@ -39,11 +39,11 @@ class CPUDeviceOpOverrides(DeviceOpOverrides): def import_get_raw_stream_as(self, name: str) -> str: return f"def {name}(name): None\n" - def set_device(self, device_idx: int) -> str: # noqa: ARG002 unused-argument + def set_device(self, device_idx: int) -> str: return "" def synchronize(self) -> None: pass - def device_guard(self, device_idx: int) -> str: # noqa: ARG002 unused-argument + def device_guard(self, device_idx: int) -> str: return "" diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_208 b/test/inductor/pallas_expected_failures/CpuTests.test_cumprod_backward_cpu similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_208 rename to test/inductor/pallas_expected_failures/CpuTests.test_cumprod_backward_cpu diff --git a/test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_209 b/test/inductor/pallas_expected_failures/CpuTests.test_cumprod_backward_with_zeros_cpu similarity index 100% rename from test/dynamo_expected_failures/CPython313-test_patma-TestPatma.test_patma_209 rename to test/inductor/pallas_expected_failures/CpuTests.test_cumprod_backward_with_zeros_cpu diff --git a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward2_cpu b/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward2_cpu deleted file mode 100644 index a2f73f3214dfe..0000000000000 --- a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward2_cpu +++ /dev/null @@ -1,5 +0,0 @@ -ERROR - File "/home/oulgen/py312/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py", line 1238, in add - out = lax.add(x, y) - ^^^^^^^^^^^^^ -TypeError: add got incompatible shapes for broadcasting: (56,), (40,). diff --git a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward3_cpu b/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward3_cpu deleted file mode 100644 index 4af5f1937326f..0000000000000 --- a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward3_cpu +++ /dev/null @@ -1,5 +0,0 @@ -ERROR - File "/home/oulgen/py312/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py", line 1238, in add - out = lax.add(x, y) - ^^^^^^^^^^^^^ -TypeError: add got incompatible shapes for broadcasting: (38,), (37,). diff --git a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward4_cpu b/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward4_cpu deleted file mode 100644 index 3308d592808d9..0000000000000 --- a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward4_cpu +++ /dev/null @@ -1,5 +0,0 @@ -ERROR - File "/home/oulgen/py312/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py", line 1238, in add - out = lax.add(x, y) - ^^^^^^^^^^^^^ -TypeError: add got incompatible shapes for broadcasting: (4,), (3,). diff --git a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward_cpu b/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward_cpu deleted file mode 100644 index 53aaef6eda9cd..0000000000000 --- a/test/inductor/pallas_expected_failures/CpuTests.test_max_pool2d_with_indices_backward_cpu +++ /dev/null @@ -1,5 +0,0 @@ -ERROR - File "/home/oulgen/py312/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py", line 1238, in add - out = lax.add(x, y) - ^^^^^^^^^^^^^ -TypeError: add got incompatible shapes for broadcasting: (14,), (18,). diff --git a/test/inductor/pallas_expected_failures/CpuTests.test_slice_scatter_backward_with_overlapping_base_cpu b/test/inductor/pallas_expected_failures/CpuTests.test_slice_scatter_backward_with_overlapping_base_cpu new file mode 100644 index 0000000000000..2c3f456e070c8 --- /dev/null +++ b/test/inductor/pallas_expected_failures/CpuTests.test_slice_scatter_backward_with_overlapping_base_cpu @@ -0,0 +1,2 @@ +ERROR +ValueError: Incompatible shapes for broadcasting diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 6fa4daf1e9c4e..94fe11a57282a 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -61,6 +61,10 @@ e5m2_type, skipCUDAIf, ) +from torch.testing._internal.common_dtype import ( + highest_precision_complex, + highest_precision_float, +) from torch.testing._internal.common_quantization import ( _group_quantize_tensor, skip_if_no_torchvision, @@ -68,7 +72,9 @@ ) from torch.testing._internal.common_utils import ( DeterministicGuard, + IS_ARM64, IS_CI, + IS_CPU_CAPABILITY_SVE256, IS_FBCODE, IS_MACOS, IS_WINDOWS, @@ -84,6 +90,7 @@ skipIfXpu, TEST_MPS, TEST_WITH_ROCM, + xfailIf, ) from torch.testing._internal.custom_tensor import CustomTensorPlainOut from torch.testing._internal.inductor_utils import ( @@ -101,6 +108,22 @@ ) +f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+, XPU and CPU devices" + + +@contextlib.contextmanager +def caching_allocator_disabled(): + if GPU_TYPE == "cuda": + from torch.cuda import ( + caching_allocator_disabled as _cuda_caching_allocator_disabled, + ) + + with _cuda_caching_allocator_disabled(): + yield + else: + yield + + @contextlib.contextmanager def use_fa3(): try: @@ -509,6 +532,94 @@ def runner_call(*args, **kwargs): new_output = runner_call(test_inputs) self.assertEqual(expected, new_output) + def test_update_inactive_constant_buffer_with_interleaved_folded_constants(self): + if self.device == "mps": + raise unittest.SkipTest("MPS baseline mismatch") + + class Model(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.fc1 = nn.Linear(2, 2, bias=True, device=device) + self.post = nn.Linear(2, 2, bias=False, device=device) + self.register_buffer( + "uie_user_memory_network", + torch.randn(2, 2, device=device), + persistent=True, + ) + self.register_buffer( + "uie_item_memory_network", + torch.randn(2, 2, device=device), + persistent=True, + ) + self.register_buffer( + "late_bias", + torch.randn(2, device=device), + persistent=True, + ) + + def forward(self, x): + x = self.fc1(x) + direct_user = torch.matmul(x, self.uie_user_memory_network) + direct_item = torch.matmul(x, self.uie_item_memory_network) + folded_user = torch.relu(self.uie_user_memory_network.permute(1, 0)) + folded_item = torch.relu(self.uie_item_memory_network.permute(1, 0)) + out = direct_user + direct_item + out = out + torch.matmul(x, folded_user) + torch.matmul(x, folded_item) + return self.post(out + self.late_bias) + + example_inputs = (torch.randn(2, 2, device=self.device),) + with ( + torch.no_grad(), + config.patch( + { + "always_keep_tensor_constants": True, + "aot_inductor.use_runtime_constant_folding": True, + } + ), + ): + model = Model(self.device) + so_path, _ = run_and_get_cpp_code( + AOTIRunnerUtil.legacy_compile, model, example_inputs + ) + + runner = AOTIRunnerUtil.legacy_load_runner(self.device, so_path) + name_to_fqn = runner.get_constant_names_to_original_fqns() + self.assertTrue( + [name for name in name_to_fqn if name.startswith("_FOLDED_CONST_")], + msg="Expected runtime-folded constants in generated model", + ) + + def runner_call(x): + return runner.run([x])[0] + + test_inputs = torch.tensor([[1.0, -2.0], [3.0, -4.0]], device=self.device) + atol = 1e-3 + rtol = 1e-3 + expected = model(test_inputs) + self.assertEqual(expected, runner_call(test_inputs), atol=atol, rtol=rtol) + + with torch.no_grad(): + for p in model.parameters(): + p.add_(1.0) + for b in model.buffers(): + b.add_(2.0) + + state = {**dict(model.named_parameters()), **dict(model.named_buffers())} + new_weights = { + const_name: state[fqn].detach().clone() + for const_name, fqn in name_to_fqn.items() + if fqn in state + } + self.assertTrue(new_weights, msg="Expected non-empty constant update map") + + new_expected = model(test_inputs) + + runner.update_constant_buffer(new_weights, True, True) + self.assertEqual(expected, runner_call(test_inputs), atol=atol, rtol=rtol) + + runner.swap_constant_buffer() + self.assertEqual(new_expected, runner_call(test_inputs), atol=atol, rtol=rtol) + @requires_gpu def test_duplicate_constant_folding(self): class Model(torch.nn.Module): @@ -904,6 +1015,8 @@ def forward(self, y): IS_FBCODE, "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used", ) + @xfailIf(IS_ARM64 and IS_CPU_CAPABILITY_SVE256) + # see https://github.com/pytorch/pytorch/issues/177243 @tf32_on_and_off(0.005) def test_deconv_freezing(self): dtypes = [torch.float] @@ -1325,14 +1438,11 @@ def forward(self, x, y): example_inputs = (x, y) self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) - @unittest.skipIf( - not PLATFORM_SUPPORTS_FP8, - "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", - ) - @skipIfXpu + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @skipIfMPS def test_fp8(self): # cuda only - if self.device != "cuda": + if self.device not in ("cuda", "xpu"): return class Model(torch.nn.Module): @@ -1376,7 +1486,6 @@ def forward(self, x, weight, bias, scale_a, scale_b): not PLATFORM_SUPPORTS_FP8_GROUPED_GEMM, "scaled_grouped_mm is only supported on SM90 and MI300+ devices", ) - @skipIfXpu def test_scaled_grouped_mm(self): # Test torch._scaled_grouped_mm AOTI lowering # cuda only @@ -1438,11 +1547,8 @@ def forward(self, x, weight, scale_a, scale_b, offsets): (x_fp8, weight_fp8, scale_a, scale_b, offsets), ) - @unittest.skipIf( - not PLATFORM_SUPPORTS_FP8, - "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", - ) - @skipIfXpu + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @skipIfMPS def test_fp8_view_of_param(self): # cuda only if self.device != GPU_TYPE: @@ -1612,6 +1718,47 @@ def forward(self, a, b): dynamic_shapes=dynamic_shapes, ) + @unittest.skipIf( + not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)" + ) + def test_bmm_large_batch_dynamic(self): + if self.device == "cpu": + raise unittest.SkipTest("using triton backend only is not supported on CPU") + + class Model(torch.nn.Module): + def forward(self, a, b): + return torch.bmm(a, b) + + M, K, N = 64, 64, 64 + dtype = torch.float16 + model = Model() + + # Compile with small batch, then run with batch > 65535 (CUDA grid.y limit) + compile_batch = 100 + a = torch.randn(compile_batch, M, K, device=self.device, dtype=dtype) + b = torch.randn(compile_batch, K, N, device=self.device, dtype=dtype) + dim0_a = Dim("dim0_a", min=1, max=2**17) + dynamic_shapes = {"a": {0: dim0_a}, "b": {0: dim0_a}} + list_example_inputs = [(a, b)] + + # Large batch exceeding CUDA grid.y limit of 65535 + large_batch = 70000 + list_example_inputs.append( + ( + torch.randn(large_batch, M, K, device=self.device, dtype=dtype), + torch.randn(large_batch, K, N, device=self.device, dtype=dtype), + ), + ) + self.check_model_with_multiple_inputs( + model, + list_example_inputs, + options={ + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + }, + dynamic_shapes=dynamic_shapes, + ) + @skipIfWindows(msg="TODO: (xuhancn) confirm, Crash: access violation") def test_foreach_multiple_dynamic(self): class Model(torch.nn.Module): @@ -1730,6 +1877,8 @@ def forward(self, x): with config.patch({"aot_inductor.use_runtime_constant_folding": True}): self.check_model(Model(self.device), example_inputs) + @xfailIf(IS_ARM64) + # see https://github.com/pytorch/pytorch/issues/177254 @skipIfNoFBGEMM def test_quanatized_int8_linear(self): class Model(torch.nn.Module): @@ -1839,23 +1988,24 @@ def forward(self, x, y, lst): return cat * cat # Disable cuda caching allocator to check for IMA - torch.cuda.caching_allocator_enable(False) - model = Repro() - example_inputs = ( - # s0, s1 - torch.randn((100, 200), device=self.device), - # s2, s3 - torch.randn((100, 3), device=self.device), - # u0, u1, u2, u3, u100 - torch.tensor([200, 100, 0, 1, 300], device=self.device, dtype=torch.int), - ) - spec = { - "x": (Dim.DYNAMIC, Dim.DYNAMIC), - "y": (Dim.DYNAMIC, Dim.DYNAMIC), - "lst": (Dim.STATIC,), - } - self.check_model(model, example_inputs, dynamic_shapes=spec) - torch.cuda.caching_allocator_enable(True) + with caching_allocator_disabled(): + model = Repro() + example_inputs = ( + # s0, s1 + torch.randn((100, 200), device=self.device), + # s2, s3 + torch.randn((100, 3), device=self.device), + # u0, u1, u2, u3, u100 + torch.tensor( + [200, 100, 0, 1, 300], device=self.device, dtype=torch.int + ), + ) + spec = { + "x": (Dim.DYNAMIC, Dim.DYNAMIC), + "y": (Dim.DYNAMIC, Dim.DYNAMIC), + "lst": (Dim.STATIC,), + } + self.check_model(model, example_inputs, dynamic_shapes=spec) @skipIfMPS @config.patch({"unbacked_symint_fallback": 12}) @@ -1897,20 +2047,87 @@ def forward(self, x, y, lengths): relevant_embeddings += ones return relevant_embeddings * relevant_embeddings - torch.cuda.caching_allocator_enable(False) - model = Repro() - example_inputs = ( - torch.randn((1000, INNER_DIM), device=self.device), - torch.randn((2000, INNER_DIM), device=self.device), - torch.ones(3000), - ) - spec = { - "x": (Dim.DYNAMIC, Dim.STATIC), - "y": (Dim.DYNAMIC, Dim.STATIC), - "lengths": (Dim.DYNAMIC,), - } - self.check_model(model, example_inputs, dynamic_shapes=spec) - torch.cuda.caching_allocator_enable(True) + with caching_allocator_disabled(): + model = Repro() + example_inputs = ( + torch.randn((1000, INNER_DIM), device=self.device), + torch.randn((2000, INNER_DIM), device=self.device), + torch.ones(3000), + ) + spec = { + "x": (Dim.DYNAMIC, Dim.STATIC), + "y": (Dim.DYNAMIC, Dim.STATIC), + "lengths": (Dim.DYNAMIC,), + } + self.check_model(model, example_inputs, dynamic_shapes=spec) + + @skipIfMPS + @config.patch({"triton.autotune_at_compile_time": None}) + @torch.fx.experimental._config.patch("backed_size_oblivious", True) + def test_slice_independent_backed_symints_no_unbacked(self): + # x[0:s1] where x.size(0) = s0-1 should produce Min(s1, s0-1), + # not an unbacked symint with a bad fallback value. + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires triton") + + INNER_DIM = 4224 + + class Repro(torch.nn.Module): + def forward(self, x, y): + x_trimmed = x[:-1] + sliced = x_trimmed[: y.size(0)] + reshaped = sliced.reshape(-1, 128, 33) + expanded = reshaped.unsqueeze(3).expand(-1, 128, 33, 8) + shifts = torch.arange(0, 64, 8, device=x.device, dtype=torch.int64) + return (expanded >> shifts) & 255 + + with caching_allocator_disabled(): + model = Repro() + example_inputs = ( + torch.randint( + 0, 256, (200, INNER_DIM), device=self.device, dtype=torch.int64 + ), + torch.randn(50, 8, device=self.device), + ) + spec = { + "x": (Dim.DYNAMIC, Dim.STATIC), + "y": (Dim.DYNAMIC, Dim.STATIC), + } + self.check_model(model, example_inputs, dynamic_shapes=spec) + + @skipIfMPS + @config.patch({"triton.autotune_at_compile_time": None}) + @torch.fx.experimental._config.patch("backed_size_oblivious", True) + def test_slice_negative_index_backed_symints_no_unbacked(self): + # x[-s1:] where x.size(0) = s0-1 should produce Max(s0-1 - s1, 0), + # not an unbacked symint with a bad fallback value. + if self.device != GPU_TYPE: + raise unittest.SkipTest("requires triton") + + INNER_DIM = 4224 + + class Repro(torch.nn.Module): + def forward(self, x, y): + x_trimmed = x[:-1] + sliced = x_trimmed[-y.size(0) :] + reshaped = sliced.reshape(-1, 128, 33) + expanded = reshaped.unsqueeze(3).expand(-1, 128, 33, 8) + shifts = torch.arange(0, 64, 8, device=x.device, dtype=torch.int64) + return (expanded >> shifts) & 255 + + with caching_allocator_disabled(): + model = Repro() + example_inputs = ( + torch.randint( + 0, 256, (200, INNER_DIM), device=self.device, dtype=torch.int64 + ), + torch.randn(50, 8, device=self.device), + ) + spec = { + "x": (Dim.DYNAMIC, Dim.STATIC), + "y": (Dim.DYNAMIC, Dim.STATIC), + } + self.check_model(model, example_inputs, dynamic_shapes=spec) @config.patch({"triton.autotune_at_compile_time": None}) def test_stride_with_unbacked_expr(self): @@ -2076,7 +2293,9 @@ def forward(self, values, repeats, mask, embeddings, x, y, z, lst): } self.check_model(Repro(), example_inputs, dynamic_shapes=spec) - @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet") + @skipIfXpu( + msg="FlashAttentionForward headdim limitation on xpu - torch-xpu-ops: 2698" + ) @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA" ) @@ -2130,6 +2349,24 @@ def forward(self, q, k, v): aot_model = torch._export.aot_load(path, device=self.device) torch.testing.assert_close(m(*inputs), aot_model(*inputs)) + @unittest.skipIf(IS_MACOS, "fp8 is not supported on Mac") + def test_aoti_fp8(self): + if self.device != "cpu" and not PLATFORM_SUPPORTS_FP8: + raise unittest.SkipTest( + "FP8 is only supported on H100+, SM 8.9 and MI300+ devices" + ) + + class M(torch.nn.Module): + def forward(self, x1, x2): + return x1.to(torch.float32) + x2.to(torch.float32) + + m = M().eval().to(self.device) + x = torch.randn(16, 16, device=self.device) + x1 = x.to(torch.float8_e4m3fn) + x2 = x.to(torch.float8_e5m2) + + self.check_model(m, (x1, x2)) + def test_aoti_constant_tensor(self): class Foo(torch.nn.Module): def __init__(self, device): @@ -2527,8 +2764,8 @@ def test_cond_cpu_predicate_cuda_operands(self, max_autotune): determined device from [predicate] + operands, causing CPU predicates to force CUDA outputs onto CPU during autotuning. """ - if self.device != "cuda": - raise unittest.SkipTest("requires CUDA") + if self.device != "cuda" and self.device != "xpu": + raise unittest.SkipTest("requires CUDA or XPU") class Model(torch.nn.Module): def __init__(self, input_dim=4, hidden_dim=8): @@ -3626,7 +3863,7 @@ def forward(self, x): # Call eval() here so that batch_norm won't update the running stats # Use float64 to avoid numeric difference failure - dtype = torch.float32 if self.device == "mps" else torch.float64 + dtype = highest_precision_float(self.device) model = Model().to(device=self.device, dtype=dtype).eval() example_inputs = (torch.randn(4, 3, 64, 64, device=self.device, dtype=dtype),) self.check_model(model, example_inputs) @@ -4128,6 +4365,36 @@ def forward(self, x, y): dynamic_shapes=dynamic_shapes, ) + @common_utils.parametrize("threshold", [float("inf"), float("-inf"), float("nan")]) + def test_triton_kernel_inf_float_arg(self, threshold): + if self.device != GPU_TYPE or self.device == "mps": + raise unittest.SkipTest("requires GPU") + + @triton.jit + def clamp_kernel( + in_ptr, out_ptr, threshold, n_elements, BLOCK_SIZE: "tl.constexpr" + ): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + out = tl.where(x > threshold, threshold, x) + tl.store(out_ptr + offsets, out, mask=mask) + + class Model(torch.nn.Module): + def __init__(self, t): + super().__init__() + self.t = t + + def forward(self, x): + out = torch.empty_like(x) + n = x.numel() + clamp_kernel[(n,)](x, out, self.t, n, BLOCK_SIZE=16) + return out + + example_inputs = (torch.randn(16, device=self.device),) + self.check_model(Model(threshold), example_inputs) + def test_triton_kernel_weird_param_order(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") @@ -5053,10 +5320,7 @@ def forward(self, inputs): self.check_model(m, inputs) - @unittest.skipIf( - not PLATFORM_SUPPORTS_FP8, - "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", - ) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @patch.dict(os.environ, {"AOTI_RUNTIME_CHECK_INPUTS": "1"}) def test_runtime_checks_fp8(self): # cuda only @@ -5129,7 +5393,7 @@ def forward(self, x0, x1, x2): ) x2 = torch.tensor( 128, - dtype=torch.complex128 if self.device != "mps" else torch.complex64, + dtype=highest_precision_complex(self.device), device=self.device, ) inputs.append(x0) @@ -5796,11 +6060,8 @@ def forward(self, x): 2, ).run(code) - @unittest.skipIf( - not PLATFORM_SUPPORTS_FP8, - "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", - ) - @skipIfXpu + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @skipIfMPS def test_aoti_debug_printer_fp8_dtype(self): if self.device != GPU_TYPE: raise unittest.SkipTest("requires GPU") @@ -6953,7 +7214,6 @@ def forward(self, x): example_inputs = (torch.randn(500, device=self.device),) self.check_model(model, example_inputs) - @skipIfXpu def test_conv3d(self): if self.device != GPU_TYPE or not is_big_gpu(): raise unittest.SkipTest("requires modern GPU to run max-autotune") @@ -7261,34 +7521,22 @@ def forward(self, x, y, m): m = torch.tensor([4096], dtype=torch.int32, device=self.device) with config.patch("triton.autotune_with_sample_inputs", True): - if torch.version.hip: - # ROCm: Use dynamic grid checking (portable across different configs) - # Compile and get the generated code - _, src_code = run_and_get_cpp_code( - torch._export.aot_compile, Model(), (x, y, m) - ) - actual_grid, expected_grids = get_triton_grid_info( - strange_config_matmul_kernel, 4096 * 2046, src_code - ) - self.assertTrue( - actual_grid is not None, "Could not find grid_0 in generated code" - ) - self.assertIn( - actual_grid, - expected_grids, - f"grid_0={actual_grid} not in expected {expected_grids} from kernel configs", - ) - - else: - # CUDA/XPU: Keep existing hardcoded values - # The tuned best config on XPU is different with CUDA. - if GPU_TYPE == "xpu": - grid_0 = 32736 - else: - grid_0 = 1023 - self.code_check_count( - Model(), (x, y, m), f"uint32_t grid_0 = {grid_0}L;", 1 - ) + # Use dynamic grid checking (portable across different configs + # and Triton versions) + _, src_code = run_and_get_cpp_code( + torch._export.aot_compile, Model(), (x, y, m) + ) + actual_grid, expected_grids = get_triton_grid_info( + strange_config_matmul_kernel, 4096 * 2046, src_code + ) + self.assertTrue( + actual_grid is not None, "Could not find grid_0 in generated code" + ) + self.assertIn( + actual_grid, + expected_grids, + f"grid_0={actual_grid} not in expected {expected_grids} from kernel configs", + ) def test_triton_mutated_autotuning(self): if self.device != GPU_TYPE: @@ -7331,34 +7579,22 @@ def forward(self, x, y, m): m = torch.tensor([4095], dtype=torch.int32, device=self.device) with config.patch("triton.autotune_with_sample_inputs", True): - if torch.version.hip: - # ROCm: Use dynamic grid checking (portable across different configs) - # Compile and get the generated code - _, src_code = run_and_get_cpp_code( - torch._export.aot_compile, Model(), (x, y, m) - ) - actual_grid, expected_grids = get_triton_grid_info( - strange_config_matmul_kernel, 4096 * 2046, src_code - ) - self.assertTrue( - actual_grid is not None, "Could not find grid_0 in generated code" - ) - self.assertIn( - actual_grid, - expected_grids, - f"grid_0={actual_grid} not in expected {expected_grids} from kernel configs", - ) - - else: - # CUDA/XPU: Keep existing hardcoded values - # The tuned best config on XPU is different with CUDA. - if GPU_TYPE == "xpu": - grid_0 = 32736 - else: - grid_0 = 1023 - self.code_check_count( - Model(), (x, y, m), f"uint32_t grid_0 = {grid_0}L;", 1 - ) + # Use dynamic grid checking (portable across different configs + # and Triton versions) + _, src_code = run_and_get_cpp_code( + torch._export.aot_compile, Model(), (x, y, m) + ) + actual_grid, expected_grids = get_triton_grid_info( + strange_config_matmul_kernel, 4096 * 2046, src_code + ) + self.assertTrue( + actual_grid is not None, "Could not find grid_0 in generated code" + ) + self.assertIn( + actual_grid, + expected_grids, + f"grid_0={actual_grid} not in expected {expected_grids} from kernel configs", + ) @patch.dict(os.environ, {"TRITON_DEBUG": "1"}) def test_triton_dynamic_launcher_grid(self): @@ -7510,10 +7746,11 @@ def forward( add_18, add_13, ): + device = add_13.device arange_1 = torch.ops.aten.arange.start( 180, 181, - device=torch.device(type=GPU_TYPE, index=0), + device=device, pin_memory=False, ) add_14 = torch.ops.aten.add.Tensor(arange_1, 198) @@ -7818,8 +8055,8 @@ def test_codegen_int_array_var_fix_memory_leak(self): """ Fix https://github.com/pytorch/pytorch/issues/167630 """ - if self.device != "cuda": - raise unittest.SkipTest("test is only for cuda") + if self.device not in ("cuda", "xpu"): + raise unittest.SkipTest("test is only for cuda or xpu") def make_mlp(in_dim=128, hidden=256, out_dim=64, depth=3): layers = [] @@ -7840,7 +8077,7 @@ def make_mlp(in_dim=128, hidden=256, out_dim=64, depth=3): allocated_memory = [] for _ in range(3): - torch.cuda.reset_peak_memory_stats() + torch.accelerator.reset_peak_memory_stats() model = make_mlp(in_dim, hidden, out_dim, depth).to(self.device) example_inputs = (torch.randn(batch, in_dim, device=self.device),) @@ -7851,10 +8088,10 @@ def make_mlp(in_dim=128, hidden=256, out_dim=64, depth=3): torch._inductor.aoti_compile_and_package(ep) del model, example_inputs, ep - torch.cuda.synchronize() - torch.cuda.empty_cache() + torch.accelerator.synchronize() + torch.accelerator.empty_cache() gc.collect() - allocated_memory.append(torch.cuda.memory_allocated()) + allocated_memory.append(torch.accelerator.memory_allocated()) self.assertTrue(allocated_memory[1] == allocated_memory[2]) @@ -8216,6 +8453,37 @@ def forward(self, batch_sizes, embeddings): if os.path.exists(temp_stderr_path): os.unlink(temp_stderr_path) + def test_combo_kernel_grid_mixed_types(self): + if self.device != GPU_TYPE: + raise unittest.SkipTest("combo kernels require GPU") + + class Model(torch.nn.Module): + def forward(self, a, b, c, d): + return a + b, c + d + + # Non-contiguous (transposed) tensors force 2D Triton iteration, + # creating both x and y range trees in the combo kernel. + # Dynamic shapes on (a, b) make ynumel_0 dynamic (string in codegen), + # while static shapes on (c, d) keep ynumel_1 as int. + # This triggers GridExpr.maximum() with mixed [str, int] in C++ mode, + # which previously generated std::max(long, int) causing a template + # deduction error in the AOTInductor C++ wrapper. + example_inputs = ( + torch.randn(30, 20, device=self.device), + torch.randn(20, 30, device=self.device).t(), + torch.randn(40, 30, device=self.device), + torch.randn(30, 40, device=self.device).t(), + ) + dim0 = Dim("dim0", min=1, max=100) + dynamic_shapes = { + "a": {0: dim0, 1: None}, + "b": {0: dim0, 1: None}, + "c": None, + "d": None, + } + with config.patch({"combo_kernels": True}): + self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) + class AOTInductorLoggingTest(LoggingTestCase): @make_logging_test(dynamic=logging.DEBUG) @@ -8348,8 +8616,6 @@ def fail_gpu(suffixes: tuple[str, ...], is_skip=False): "test_quantized_linear": fail_gpu(("cuda", "xpu")), "test_quanatized_int8_linear": fail_gpu(("cuda", "xpu")), "test_quantized_linear_bias_none": fail_gpu(("cuda", "xpu")), - # No scaled_dot_product_efficient_attention implementation for XPU yet. - "test_scaled_dot_product_efficient_attention": fail_gpu(("xpu",)), } MPS_TEST_FAILURES = { diff --git a/test/inductor/test_aot_inductor_arrayref.py b/test/inductor/test_aot_inductor_arrayref.py index 2b1214c863409..908e05d8aba25 100644 --- a/test/inductor/test_aot_inductor_arrayref.py +++ b/test/inductor/test_aot_inductor_arrayref.py @@ -2,7 +2,11 @@ import sys import unittest +import torch +from torch._inductor import config from torch._inductor.test_case import TestCase +from torch._inductor.utils import run_and_get_cpp_code +from torch.testing import FileCheck from torch.testing._internal.common_utils import IS_CI, IS_FBCODE, IS_WINDOWS @@ -18,6 +22,7 @@ try: from .test_aot_inductor import ( AOTInductorTestsTemplate, + AOTIRunnerUtil, check_model, check_model_with_multiple_inputs, code_check_count, @@ -26,6 +31,7 @@ except ImportError: from test_aot_inductor import ( # @manual AOTInductorTestsTemplate, + AOTIRunnerUtil, check_model, check_model_with_multiple_inputs, code_check_count, @@ -57,32 +63,51 @@ def fail_minimal_arrayref_interface(is_skip=False): ) +class AOTInductorArrayRefTestsTemplate(AOTInductorTestsTemplate): + def test_simple_v2_interface(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(10, 10) + + def forward(self, x, y): + return x + self.linear(y) + + example_inputs = ( + torch.randn(10, 10, device=self.device), + torch.randn(10, 10, device=self.device), + ) + model = Model() + with config.patch( + { + "aot_inductor.allow_stack_allocation": self.allow_stack_allocation, + "aot_inductor.use_minimal_arrayref_interface": self.use_minimal_arrayref_interface, + } + ): + _, code = run_and_get_cpp_code( + AOTIRunnerUtil.compile, model, example_inputs + ) + + FileCheck().check("AOTInductorModelRunMinimalArrayrefInterfaceV2(").check( + "constexpr int32_t expected_num_inputs = 2;" + ).check("constexpr int32_t expected_num_outputs = 1;").check( + "if (num_inputs != expected_num_inputs)" + ).check("if (num_outputs != expected_num_outputs)").run(code) + self.code_check_count( + model, example_inputs, "AOTInductorModelRunMinimalArrayrefInterface(", 1 + ) + + # test_failures, xfail by default, set is_skip=True to skip CPU_TEST_FAILURES = { # TODO: error: ‘complex64’ was not declared in this scope "test_add_complex": fail_minimal_arrayref_interface(is_skip=True), "test_conv_freezing": fail_minimal_arrayref_interface(is_skip=True), "test_deconv_freezing": fail_minimal_arrayref_interface(is_skip=True), - "test_cond_nested": fail_minimal_arrayref_interface(), - "test_cond_simple": fail_minimal_arrayref_interface(), - "test_cond_symint_input": fail_minimal_arrayref_interface(), - "test_cond_use_buffers_from_outer_scope": fail_minimal_arrayref_interface(), - "test_cond_with_multiple_outputs": fail_minimal_arrayref_interface(), - "test_cond_with_parameters": fail_minimal_arrayref_interface(), - "test_cond_with_reinterpret_view_inputs_outputs": fail_minimal_arrayref_interface(), - "test_custom_op_in_subgraph": fail_minimal_arrayref_interface(), "test_cond_share_predicate": fail_stack_allocation(is_skip=True), "test_cond_predicate_on_cpu": fail_stack_allocation(is_skip=True), - "test_cond_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(), - "test_while_loop_with_unbacked_symint_closure_dynamic_True": fail_minimal_arrayref_interface(), - "test_while_loop_with_unbacked_symint_closure_dynamic_False": fail_minimal_arrayref_interface(), "test_while_loop_with_mixed_device_dynamic_True": fail_stack_allocation(), "test_while_loop_with_mixed_device_dynamic_False": fail_stack_allocation(), - "test_while_loop_with_sym_expr_cond_dynamic_True": fail_minimal_arrayref_interface(), - "test_while_loop_with_sym_expr_cond_dynamic_False": fail_minimal_arrayref_interface(), - "test_while_loop_with_conv_dynamic_True": fail_minimal_arrayref_interface(), - "test_while_loop_with_conv_dynamic_False": fail_minimal_arrayref_interface(), - "test_while_loop_with_parameters": fail_minimal_arrayref_interface(), "test_while_loop_with_pytree_inputs": fail_stack_allocation(), # FIXME: failed with Segfault while exiting the Python runtime "test_duplicate_constant_folding": fail_stack_allocation(is_skip=True), @@ -239,6 +264,12 @@ class AOTInductorTestABICompatibleCpuWithStackAllocationAndMinimalArrayRefInterf "cpu_with_stack_allocation_and_minimal_arrayref_interface", CPU_TEST_FAILURES, ) + copy_tests( + AOTInductorArrayRefTestsTemplate, + AOTInductorTestABICompatibleCpuWithStackAllocationAndMinimalArrayRefInterface, + "cpu_with_stack_allocation_and_minimal_arrayref_interface", + CPU_TEST_FAILURES, + ) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_aot_inductor_custom_ops.py b/test/inductor/test_aot_inductor_custom_ops.py index 1b608dc761702..e83d11fe0af38 100644 --- a/test/inductor/test_aot_inductor_custom_ops.py +++ b/test/inductor/test_aot_inductor_custom_ops.py @@ -505,7 +505,7 @@ def forward(self, x): args = (torch.randn(4, 4, device=self.device),) self.check_model(m, args) - @skipIfXpu + @skipIfXpu(msg="compile error - torch-xpu-ops: 2609") @unittest.skipIf(IS_FBCODE, "unable to find library -laoti_custom_ops") def test_custom_op_square(self) -> None: class Model(torch.nn.Module): diff --git a/test/inductor/test_aot_inductor_package.py b/test/inductor/test_aot_inductor_package.py index 887e8ab30ee38..6ae1dc0208480 100644 --- a/test/inductor/test_aot_inductor_package.py +++ b/test/inductor/test_aot_inductor_package.py @@ -308,8 +308,11 @@ def forward(self, x, y): if self.device == GPU_TYPE: kernel_bin = get_kernel_bin_format(self.device) self.assertTrue(not list(tmp_path.glob(f"*.{kernel_bin}"))) - # Check if .cubin.o files exist and use unique kernel names - self.assertTrue(list(tmp_path.glob(f"triton_*.{kernel_bin}.o"))) + # Check that cubin binaries are embedded as object files. + # Either individual per-kernel .o files or a single combined .o. + individual_objs = list(tmp_path.glob(f"triton_*.{kernel_bin}.o")) + combined_obj = list(tmp_path.glob("cubins_combined.o")) + self.assertTrue(individual_objs or combined_obj) # Check if the .so file was build successfully so_path = build_path / "libaoti_model.so" diff --git a/test/inductor/test_aoti_cross_compile_windows.py b/test/inductor/test_aoti_cross_compile_windows.py index d2e75041a0860..809cf74ed3d69 100644 --- a/test/inductor/test_aoti_cross_compile_windows.py +++ b/test/inductor/test_aoti_cross_compile_windows.py @@ -1,15 +1,18 @@ # Owner(s): ["module: inductor"] import os import platform +import shutil import tempfile import unittest from dataclasses import dataclass from pathlib import Path from typing import Any +from unittest.mock import MagicMock, patch import torch import torch._inductor.config from torch._environment import is_fbcode +from torch._inductor.cpp_builder import _ensure_mingw_cudart_import_lib from torch._inductor.test_case import TestCase from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu @@ -318,6 +321,137 @@ def forward(self, x): ) +class TestEnsureMingwCudartImportLib(TestCase): + """Unit tests for _ensure_mingw_cudart_import_lib.""" + + def setUp(self): + super().setUp() + self.tmp_dir = tempfile.mkdtemp(prefix="test_mingw_cudart_") + self.cuda_home = os.path.join(self.tmp_dir, "cuda") + self.lib_dir = os.path.join(self.tmp_dir, "lib") + os.makedirs(os.path.join(self.cuda_home, "bin", "x64"), exist_ok=True) + os.makedirs(self.lib_dir, exist_ok=True) + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + super().tearDown() + + def _create_fake_dll(self, version: str = "130") -> str: + dll_path = os.path.join(self.cuda_home, "bin", "x64", f"cudart64_{version}.dll") + Path(dll_path).touch() + return dll_path + + def _create_fake_cudart_lib(self) -> str: + lib_path = os.path.join(self.lib_dir, "cudart.lib") + Path(lib_path).touch() + return lib_path + + def test_noop_when_no_windows_cuda_home(self): + """Should skip gracefully when WINDOWS_CUDA_HOME is not set.""" + env = os.environ.copy() + env.pop("WINDOWS_CUDA_HOME", None) + with patch.dict(os.environ, env, clear=True): + result = _ensure_mingw_cudart_import_lib([self.lib_dir]) + self.assertEqual(result, []) + self.assertFalse(os.path.exists(os.path.join(self.lib_dir, "libcudart.a"))) + + def test_noop_when_libcudart_a_already_exists(self): + """Should skip generation if libcudart.a already exists.""" + existing = os.path.join(self.lib_dir, "libcudart.a") + Path(existing).touch() + with patch.dict(os.environ, {"WINDOWS_CUDA_HOME": self.cuda_home}): + result = _ensure_mingw_cudart_import_lib([self.lib_dir]) + self.assertEqual(result, []) + self.assertTrue(os.path.exists(existing)) + + @patch("torch._inductor.cpp_builder._create_msvc_gs_stubs_lib", return_value=None) + def test_gs_stubs_fallback_when_no_dll_found(self, mock_stubs): + """Should attempt GS stubs fallback when no cudart64_*.dll is found.""" + self._create_fake_cudart_lib() + with patch.dict(os.environ, {"WINDOWS_CUDA_HOME": self.cuda_home}): + _ensure_mingw_cudart_import_lib([self.lib_dir]) + mock_stubs.assert_called_once_with(self.lib_dir) + + @patch( + "torch._inductor.cpp_builder._create_msvc_gs_stubs_lib", + return_value="msvc_gs_stubs", + ) + def test_gs_stubs_returned_when_no_dll(self, mock_stubs): + """Should return the stubs library name when DLL is unavailable.""" + self._create_fake_cudart_lib() + with patch.dict(os.environ, {"WINDOWS_CUDA_HOME": self.cuda_home}): + result = _ensure_mingw_cudart_import_lib([self.lib_dir]) + self.assertEqual(result, ["msvc_gs_stubs"]) + + def test_noop_when_no_writable_dir_with_cudart_lib(self): + """Should skip gracefully when no writable directory contains cudart.lib.""" + self._create_fake_dll() + with patch.dict(os.environ, {"WINDOWS_CUDA_HOME": self.cuda_home}): + result = _ensure_mingw_cudart_import_lib([self.lib_dir]) + self.assertEqual(result, []) + self.assertFalse(os.path.exists(os.path.join(self.lib_dir, "libcudart.a"))) + + @patch("subprocess.run") + def test_successful_generation(self, mock_run: MagicMock): + """Should call gendef and dlltool when all conditions are met.""" + dll_path = self._create_fake_dll() + self._create_fake_cudart_lib() + + with patch.dict(os.environ, {"WINDOWS_CUDA_HOME": self.cuda_home}): + result = _ensure_mingw_cudart_import_lib([self.lib_dir]) + + self.assertEqual(result, []) + self.assertEqual(mock_run.call_count, 2) + + # Verify gendef call + gendef_call = mock_run.call_args_list[0] + gendef_cmd = gendef_call[0][0] + self.assertEqual(gendef_cmd[0], "gendef") + self.assertEqual(gendef_cmd[1], "-") + self.assertEqual(gendef_cmd[2], dll_path) + + # Verify dlltool call + dlltool_call = mock_run.call_args_list[1] + dlltool_cmd = dlltool_call[0][0] + self.assertEqual(dlltool_cmd[0], "x86_64-w64-mingw32-dlltool") + self.assertIn("-d", dlltool_cmd) + self.assertIn("-l", dlltool_cmd) + self.assertIn("-D", dlltool_cmd) + self.assertIn("cudart64_130.dll", dlltool_cmd) + + @patch( + "torch._inductor.cpp_builder._create_msvc_gs_stubs_lib", + return_value="msvc_gs_stubs", + ) + @patch("subprocess.run", side_effect=FileNotFoundError("gendef not found")) + def test_gs_stubs_fallback_on_gendef_not_found(self, mock_run, mock_stubs): + """Should fall back to GS stubs when gendef is not installed.""" + self._create_fake_dll() + self._create_fake_cudart_lib() + + with patch.dict(os.environ, {"WINDOWS_CUDA_HOME": self.cuda_home}): + result = _ensure_mingw_cudart_import_lib([self.lib_dir]) + + self.assertEqual(result, ["msvc_gs_stubs"]) + mock_stubs.assert_called_once() + + @patch("subprocess.run") + def test_bin_x64_fallback_to_bin(self, mock_run: MagicMock): + """Should fall back to bin/ when bin/x64/ does not exist.""" + shutil.rmtree(os.path.join(self.cuda_home, "bin", "x64")) + bin_dir = os.path.join(self.cuda_home, "bin") + dll_path = os.path.join(bin_dir, "cudart64_130.dll") + Path(dll_path).touch() + self._create_fake_cudart_lib() + + with patch.dict(os.environ, {"WINDOWS_CUDA_HOME": self.cuda_home}): + _ensure_mingw_cudart_import_lib([self.lib_dir]) + + self.assertEqual(mock_run.call_count, 2) + gendef_cmd = mock_run.call_args_list[0][0][0] + self.assertIn(dll_path, gendef_cmd) + + if __name__ == "__main__": import sys diff --git a/test/inductor/test_augmented_graph_helper.py b/test/inductor/test_augmented_graph_helper.py index b9406b0cf8550..6e251aec40d6d 100644 --- a/test/inductor/test_augmented_graph_helper.py +++ b/test/inductor/test_augmented_graph_helper.py @@ -318,7 +318,8 @@ def test_cycle_through_merge(self): # Merging b4 and c4 would create cycle tracker4.merge_to_set(b4, c4) - self.assertTrue(tracker4.has_cycle()) + # has_cycle() is not merge-aware + self.assertFalse(tracker4.has_cycle()) def test_cycle_with_extra_deps(self): """Test cycle detection with extra dependencies.""" diff --git a/test/inductor/test_auto_chunker.py b/test/inductor/test_auto_chunker.py index 25957a01bcc6d..50726e5ff0bf1 100644 --- a/test/inductor/test_auto_chunker.py +++ b/test/inductor/test_auto_chunker.py @@ -271,6 +271,110 @@ def f(x, w): self.assertTrue(same(expect, actual, tol=1e-3)) self.assertEqual(metrics.num_auto_chunking, 1) + @config.patch("auto_chunker.output_size_threshold", 1024) + @config.patch("auto_chunker.num_chunk", 2) + def test_propagate_amax_unsqueeze(self): + M, K, N = 256, 4, 256 + x = torch.randn(M, K, device=GPU_TYPE, requires_grad=True) + w = torch.randn(K, N, device=GPU_TYPE, requires_grad=True) + + def f(x, w): + out = (x * 2) @ w + max_val = out.amax(dim=-1) + out = out - max_val.unsqueeze(-1) + out = torch.exp(out) + loss = out.sum() + loss.backward() + return loss + + expect = (f(x, w), x.grad, w.grad) + x.grad = None + w.grad = None + opt_f = torch.compile(f) + actual = (opt_f(x, w), x.grad, w.grad) + + self.assertTrue(same(expect, actual, tol=1e-3)) + self.assertEqual(metrics.num_auto_chunking, 1) + + @config.patch("auto_chunker.output_size_threshold", 1024) + @config.patch("auto_chunker.num_chunk", 2) + def test_propagate_gather(self): + M, K, N = 256, 4, 256 + x = torch.randn(M, K, device=GPU_TYPE, requires_grad=True) + w = torch.randn(K, N, device=GPU_TYPE, requires_grad=True) + targets = torch.randint(0, N, (M,), device=GPU_TYPE) + + def f(x, w, targets): + out = (x * 2) @ w + out = out.softmax(dim=-1) + selected = out.gather(1, targets.unsqueeze(1)) + loss = selected.squeeze(1).sum() + loss.backward() + return loss + + expect = (f(x, w, targets), x.grad, w.grad) + x.grad = None + w.grad = None + opt_f = torch.compile(f) + actual = (opt_f(x, w, targets), x.grad, w.grad) + + self.assertTrue(same(expect, actual, tol=1e-3)) + self.assertEqual(metrics.num_auto_chunking, 1) + + @config.patch("auto_chunker.output_size_threshold", 1024) + @config.patch("auto_chunker.num_chunk", 2) + def test_propagate_scatter(self): + M, K, N = 256, 4, 256 + x = torch.randn(M, K, device=GPU_TYPE, requires_grad=True) + w = torch.randn(K, N, device=GPU_TYPE, requires_grad=True) + targets = torch.randint(0, N, (M,), device=GPU_TYPE) + + def f(x, w, targets): + out = (x * 2) @ w + out = out.softmax(dim=-1) + out = out.scatter(1, targets.unsqueeze(1), 0.0) + loss = out.sum() + loss.backward() + return loss + + expect = (f(x, w, targets), x.grad, w.grad) + x.grad = None + w.grad = None + opt_f = torch.compile(f) + actual = (opt_f(x, w, targets), x.grad, w.grad) + + self.assertTrue(same(expect, actual, tol=1e-3)) + self.assertEqual(metrics.num_auto_chunking, 1) + + @config.patch("auto_chunker.output_size_threshold", 1024) + @config.patch("auto_chunker.num_chunk", 2) + def test_propagate_manual_cross_entropy(self): + M, K, N = 256, 4, 256 + x = torch.randn(M, K, device=GPU_TYPE, requires_grad=True) + w = torch.randn(K, N, device=GPU_TYPE, requires_grad=True) + targets = torch.randint(0, N, (M,), device=GPU_TYPE) + + def f(x, w, targets): + logits = (x * 2) @ w + max_logits = logits.amax(dim=-1) + shifted = logits - max_logits.unsqueeze(-1) + exp_shifted = shifted.exp() + sum_exp = exp_shifted.sum(dim=-1) + log_probs = shifted - sum_exp.log().unsqueeze(-1) + target_log_probs = log_probs.gather(1, targets.unsqueeze(1)) + loss = -target_log_probs.squeeze(1).sum() / M + loss.backward() + return loss + + expect = (f(x, w, targets), x.grad, w.grad) + x.grad = None + w.grad = None + opt_f = torch.compile(f) + actual = (opt_f(x, w, targets), x.grad, w.grad) + + self.assertTrue(same(expect, actual, tol=1e-3)) + self.assertEqual(metrics.num_auto_chunking, 1) + def test_set_num_chunk_with_compile_options(self): B = 32 T = 1024 diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index 722d9d36ceaf9..da12a751137fe 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -186,7 +186,7 @@ def f(x, y, z, n): """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = foo_default = None - return ()""", # noqa: B950 + return ()""", ignore_comments=True, ) @@ -249,7 +249,7 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3 foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg0_1 = arg3_1 = arg4_1 = arg1_1 = arg2_1 = None getitem_4: "f32[3][1]cpu" = foo_default[0] getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None - return (getitem_4, getitem_5)""", # noqa: B950 + return (getitem_4, getitem_5)""", ignore_comments=True, ) @@ -419,7 +419,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77 foo_default = torch.ops.mylib.foo.default(arg1_1, [arg4_1, arg5_1], arg2_1, 2, arg3_1); arg4_1 = arg5_1 = arg3_1 = foo_default = None copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None copy__1: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None - return ()""", # noqa: B950 + return ()""", ignore_comments=True, ignore_empty_lines=True, ) @@ -431,7 +431,7 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3 foo_default = torch.ops.mylib.foo.default(arg0_1, [arg3_1, arg4_1], arg1_1, 2, arg2_1); arg3_1 = arg4_1 = arg2_1 = foo_default = None copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None - return ()""", # noqa: B950 + return ()""", ignore_comments=True, ignore_empty_lines=True, ) @@ -477,6 +477,35 @@ def run_inductor( return [compiled_args, result, graph] + def test_cumulative_out_preserves_out_dtype_under_compile(self): + test_cases = [ + (torch.cumsum, torch.tensor([0.9201, 0.1166], dtype=torch.float32)), + (torch.cumprod, torch.tensor([2.0, 0.6], dtype=torch.float32)), + ] + + for op, x in test_cases: + with self.subTest(op=op.__name__): + + def f(out, x): + return op(x, 0, out=out) + + eager_out = torch.tensor([0, 0], dtype=torch.int32) + aot_eager_out = eager_out.clone() + inductor_out = eager_out.clone() + + eager_ret = f(eager_out, x) + aot_eager_ret = torch.compile(f, backend="aot_eager", fullgraph=True)( + aot_eager_out, x + ) + inductor_ret = torch.compile(f, backend="inductor", fullgraph=True)( + inductor_out, x + ) + + self.assertEqual(aot_eager_out, eager_out) + self.assertEqual(inductor_out, eager_out) + self.assertEqual(aot_eager_ret, eager_ret) + self.assertEqual(inductor_ret, eager_ret) + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) def test_auto_functionalize_with_returns_v2(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: @@ -527,7 +556,7 @@ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3 getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None - return (getitem_4, getitem_5)""", # noqa: B950 + return (getitem_4, getitem_5)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -587,7 +616,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu", arg2_1: "f32[s77 add: "f32[s77][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None copy__1: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_2); arg2_1 = getitem_2 = copy__1 = None - return (add,)""", # noqa: B950 + return (add,)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -602,7 +631,7 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): add: "f32[2][1]cpu" = torch.ops.aten.add.Tensor(getitem_1, getitem_2) copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None - return (add,)""", # noqa: B950 + return (add,)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -684,7 +713,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"): auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg1_1]) getitem_1: "f32[s77][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = getitem_1 = copy_ = None - return ()""", # noqa: B950 + return ()""", ignore_comments=True, ignore_empty_lines=True, ) @@ -696,7 +725,7 @@ def forward(self, arg0_1: "f32[2][1]cpu"): auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_size = (), _x_stride = (), _x_storage_offset = 0, _y_base_index = 0, _y_size = (), _y_stride = (), _y_storage_offset = 1, _all_bases = [arg0_1]) getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None - return ()""", # noqa: B950 + return ()""", ignore_comments=True, ignore_empty_lines=True, ) @@ -713,7 +742,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"): as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg1_1, [], [], 1) foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None - return ()""", # noqa: B950 + return ()""", ignore_comments=True, ignore_empty_lines=True, ) @@ -726,7 +755,7 @@ def forward(self, arg0_1: "f32[2][1]cpu"): as_strided_default_1: "f32[][]cpu" = torch.ops.aten.as_strided.default(arg0_1, [], [], 1) foo_default = torch.ops.mylib.foo.default(as_strided_default, as_strided_default_1); as_strided_default = as_strided_default_1 = foo_default = None copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None - return ()""", # noqa: B950 + return ()""", ignore_comments=True, ignore_empty_lines=True, ) @@ -777,7 +806,7 @@ def forward(self, arg0_1: "f32[2][1]cpu"): copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None select_2: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 0) select_3: "f32[][]cpu" = torch.ops.aten.select.int(getitem_1, 0, 1); getitem_1 = None - return (select_2, select_3)""", # noqa: B950 + return (select_2, select_3)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -795,7 +824,7 @@ def forward(self, arg0_1: "f32[2][1]cpu"): copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None select_2: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) select_3: "f32[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None - return (select_2, select_3)""", # noqa: B950 + return (select_2, select_3)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -847,7 +876,7 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): getitem_2: "f32[2][1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None - return ()""", # noqa: B950 + return ()""", ignore_comments=True, ignore_empty_lines=True, ) @@ -861,7 +890,7 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu"): foo_default = torch.ops.mylib.foo.default(arg0_1, arg1_1); foo_default = None copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None - return ()""", # noqa: B950 + return ()""", ignore_comments=True, ignore_empty_lines=True, ) @@ -913,7 +942,7 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2 copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = getitem_1 = copy_ = None copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_2); arg1_1 = getitem_2 = copy__1 = None copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, getitem_3); arg2_1 = getitem_3 = copy__2 = None - return ()""", # noqa: B950 + return ()""", ignore_comments=True, ignore_empty_lines=True, ) @@ -931,7 +960,7 @@ def forward(self, arg0_1: "f32[2][1]cpu", arg1_1: "f32[2][1]cpu", arg2_1: "f32[2 copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None copy__1: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy__1 = None copy__2: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__2 = None - return ()""", # noqa: B950 + return ()""", ignore_comments=True, ignore_empty_lines=True, ) @@ -980,7 +1009,7 @@ def f(x, y, z, n): def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"): foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg0_1, 2, arg1_1); arg2_1 = arg3_1 = arg1_1 = foo_default = None copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); arg0_1 = copy_ = None - return ()""", # noqa: B950 + return ()""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1096,7 +1125,7 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None alias_2: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1) alias_3: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None - return (alias_2, alias_3)""", # noqa: B950 + return (alias_2, alias_3)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1110,7 +1139,7 @@ def forward(self, arg0_1: "f32[2][1]cpu"): copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None alias_2: "f32[2][1]cpu" = torch.ops.aten.alias.default(getitem_1) alias_3: "f32[2][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None - return (alias_2, alias_3)""", # noqa: B950 + return (alias_2, alias_3)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1127,7 +1156,7 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); \ alias_default = alias_default_1 = foo_default = None copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None - return (arg1_1, arg1_1)""", # noqa: B950 + return (arg1_1, arg1_1)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1141,7 +1170,7 @@ def forward(self, arg0_1: "f32[2][1]cpu"): foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); \ alias_default = alias_default_1 = foo_default = None copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None - return (arg0_1, arg0_1)""", # noqa: B950 + return (arg0_1, arg0_1)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1200,7 +1229,7 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): getitem_4: "f32[10, 4][10, 1]cpu" = split_with_sizes_1[0]; split_with_sizes_1 = None split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(getitem_3, [4, 6], 1); getitem_3 = None getitem_7: "f32[10, 6][10, 1]cpu" = split_with_sizes_2[1]; split_with_sizes_2 = None - return (getitem_4, getitem_7)""", # noqa: B950 + return (getitem_4, getitem_7)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1216,7 +1245,7 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): getitem_4: "f32[10, 4][10, 1]cpu" = split_with_sizes_1[0]; split_with_sizes_1 = None split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(getitem_3, [4, 6], 1); getitem_3 = None getitem_7: "f32[10, 6][10, 1]cpu" = split_with_sizes_2[1]; split_with_sizes_2 = None - return (getitem_4, getitem_7)""", # noqa: B950 + return (getitem_4, getitem_7)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1237,7 +1266,7 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): getitem_4: "f32[10, 4][10, 1]cpu" = split_with_sizes_1[0]; split_with_sizes_1 = None split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(arg0_1, [4, 6], 1); arg0_1 = None getitem_7: "f32[10, 6][10, 1]cpu" = split_with_sizes_2[1]; split_with_sizes_2 = None - return (getitem_4, getitem_7)""", # noqa: B950 + return (getitem_4, getitem_7)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1254,7 +1283,7 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): getitem_4: "f32[10, 4][10, 1]cpu" = split_with_sizes_1[0]; split_with_sizes_1 = None split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(arg0_1, [4, 6], 1); arg0_1 = None getitem_7: "f32[10, 6][10, 1]cpu" = split_with_sizes_2[1]; split_with_sizes_2 = None - return (getitem_4, getitem_7)""", # noqa: B950 + return (getitem_4, getitem_7)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1316,7 +1345,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, s77][s77, 1]cpu"): copy_: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None slice_3: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2) slice_4: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None - return (slice_3, slice_4)""", # noqa: B950 + return (slice_3, slice_4)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1330,7 +1359,7 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): copy_: "f32[10, 10][10, 1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None slice_3: "f32[2, 10][10, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2) slice_4: "f32[10, 1][10, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None - return (slice_3, slice_4)""", # noqa: B950 + return (slice_3, slice_4)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1350,7 +1379,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, s77][s77, 1]cpu"): copy_: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None slice_3: "f32[2, s77][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2) slice_4: "f32[s77, 1][s77, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4); arg1_1 = None - return (slice_3, slice_4)""", # noqa: B950 + return (slice_3, slice_4)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1365,7 +1394,7 @@ def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): copy_: "f32[10, 10][10, 1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None slice_3: "f32[2, 10][10, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 2) slice_4: "f32[10, 1][10, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 1, 3, 4); arg0_1 = None - return (slice_3, slice_4)""", # noqa: B950 + return (slice_3, slice_4)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1502,7 +1531,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"): copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None alias_1: "f32[s77][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None - return (alias_1, slice_2)""", # noqa: B950 + return (alias_1, slice_2)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1525,7 +1554,7 @@ def forward(self, arg0_1: "f32[2][1]cpu"): copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None alias_1: "f32[2][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None - return (alias_1, slice_2)""", # noqa: B950 + return (alias_1, slice_2)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1547,7 +1576,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77][1]cpu"): foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None copy_: "f32[s77][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None - return (arg1_1, slice_2)""", # noqa: B950 + return (arg1_1, slice_2)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -1568,7 +1597,7 @@ def forward(self, arg0_1: "f32[2][1]cpu"): foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None slice_2: "f32[u0, 1][1, u0]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None - return (arg0_1, slice_2)""", # noqa: B950 + return (arg0_1, slice_2)""", ignore_comments=True, ignore_empty_lines=True, ) diff --git a/test/inductor/test_autoheuristic.py b/test/inductor/test_autoheuristic.py index 270608086799b..692251f6fe2b2 100644 --- a/test/inductor/test_autoheuristic.py +++ b/test/inductor/test_autoheuristic.py @@ -1,6 +1,7 @@ # Owner(s): ["module: inductor"] import os import unittest +from unittest.mock import patch import torch import torch._inductor.config as inductor_config @@ -14,7 +15,6 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_A100, IS_H100 -@skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack") class AutoHeuristicTest(TestCase): def count_lines_in_file(self, file_path): with open(file_path) as file: @@ -35,12 +35,14 @@ def get_path_to_autoheuristic_log(self, name): path = cache_dir() + "/autoheuristic/" + device_name + "/" + name + ".txt" return path + @skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack") def test_autoheuristic_pad_mm_default(self): # this test ensures that data is not collected for pad_mm when autoheuristic config is set to its default value self.run_mm() self.assertFalse(os.path.exists(self.get_path_to_autoheuristic_log("pad_mm"))) - @inductor_config.patch(autoheuristic_collect="foo") + @skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack") + @inductor_config.patch("autoheuristic_collect.pad_mm", False) def test_autoheuristic_pad_mm_off(self): # this test ensures that data is not collected for pad_mm when autoheuristic_collect does not contain "pad_mm" self.run_mm() @@ -56,17 +58,20 @@ def assert_autoheuristic_collected_data(self): # 1 line for metadata, 1 line for header, 1 line per choice (orig, padded) self.assertEqual(num_lines, 4) - @inductor_config.patch(autoheuristic_collect="pad_mm") + @skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack") + @inductor_config.patch("autoheuristic_collect.pad_mm", True) def test_autoheuristic_pad_mm_collect_data(self): - # this test ensures that data is collected for pad_mm when autoheuristic_collect="pad_mm" + # this test ensures that data is collected for pad_mm when autoheuristic_collect.pad_mm=True self.assert_autoheuristic_collected_data() - @inductor_config.patch(autoheuristic_collect="foo,pad_mm") + @skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack") + @inductor_config.patch("autoheuristic_collect.pad_mm", True) def test_autoheuristic_pad_mm_collect_data2(self): - # this test ensures that data is collected for "pad_mm" when autoheuristic_collect contains "pad_mm" + # this test ensures that data is collected for pad_mm when autoheuristic_collect.pad_mm=True self.assert_autoheuristic_collected_data() - @inductor_config.patch(autoheuristic_collect="test") + @skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack") + @patch.dict(os.environ, {"TORCHINDUCTOR_AUTOHEURISTIC_COLLECT": "test"}) def test_autoheuristic(self): # test basic functionality of autoheuristic def fallback(): @@ -102,7 +107,9 @@ def feedback_fn(choice): self.assertEqual(num_lines, 5) shared_memory = get_gpu_shared_memory() - (fst, snd) = get_interface_for_device(GPU_TYPE).get_device_capability() + compute_cap = get_interface_for_device(GPU_TYPE).get_compute_capability() + # Convert single int compute capability (e.g., 90) to tuple (e.g., (9, 0)) + fst, snd = compute_cap // 10, compute_cap % 10 with open(path) as file: lines = file.readlines() @@ -116,15 +123,17 @@ def feedback_fn(choice): self.assertEqual("5,b,2", lines[3].rstrip()) self.assertEqual("5,c,3", lines[4].rstrip()) + @skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack") @unittest.skipIf(not IS_A100, "heuristic only run on A100") - @inductor_config.patch(autoheuristic_use="pad_mm") + @inductor_config.patch("autoheuristic_use.pad_mm", True) def test_autoheuristic_a100(self): # Make sure heuristic does not break anything # TODO (AlnisM): Find a way to check whether heuristic is used self.run_mm() + @skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack") @unittest.skipIf(not IS_H100, "heuristic only run on H100") - @inductor_config.patch(autoheuristic_use="pad_mm") + @inductor_config.patch("autoheuristic_use.pad_mm", True) def test_autoheuristic_h100(self): # Make sure heuristic does not break anything # TODO (AlnisM): Find a way to check whether heuristic is used @@ -145,12 +154,14 @@ def fn(a, b): # a choice made by the heuristic might be added to the list of choices # and if select_algorithm now creates a new precompile key, it will be # different from the precompile key created by autoheuristic - @inductor_config.patch( - autoheuristic_collect="mixed_mm", - autoheuristic_use="", - fx_graph_cache=False, - fx_graph_remote_cache=False, + @skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack") + @unittest.skip( + "mixed_mm autoheuristic collection is broken after mixed_mm special casing deletion (PR #147151)" ) + @inductor_config.patch("autoheuristic_collect.mixed_mm", True) + @inductor_config.patch("autoheuristic_use.mixed_mm", False) + @inductor_config.patch(fx_graph_cache=False) + @inductor_config.patch(fx_graph_remote_cache=False) def test_global_feedback(self): self.run_mixed_mm() path = self.get_path_to_autoheuristic_log("mixed_mm") @@ -161,12 +172,43 @@ def test_global_feedback(self): # 1 line for fallback + at least 1 config self.assertTrue(num_lines > 4) - @inductor_config.patch(autoheuristic_use="mixed_mm") + @skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack") + @inductor_config.patch("autoheuristic_use.mixed_mm", True) @unittest.skipIf(not IS_A100, "heuristic only run on A100") def test_mixed_mm_a100(self): self.run_mixed_mm() # TODO (AlnisM): Find a way to check whether heuristic is used + @skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack") + @unittest.skipIf(not IS_H100 and not IS_A100, "heuristic only run on H100") + @inductor_config.patch(deterministic=True) + @inductor_config.patch("autoheuristic_use.pad_mm", True) + def test_pad_mm_autoheuristic_deterministic_mode(self): + """Test that pad_mm AutoHeuristics works in deterministic mode.""" + from torch._dynamo.utils import counters + + counters.clear() + + def f(a, b): + return torch.mm(a, b) + + cf = torch.compile(f) + # Use shapes that would normally trigger padding but aren't well-aligned + a = torch.randn(2047, 2048, device=GPU_TYPE, dtype=torch.float16) + b = torch.randn(2048, 2049, device=GPU_TYPE, dtype=torch.float16) + + # Run the compiled function + result = cf(a, b) + + # Verify correctness with tolerance for potential padding differences + expected = torch.mm(a, b) + torch.testing.assert_close(result, expected, atol=0.1, rtol=0.05) + + # In deterministic mode with AutoHeuristics enabled, + # we should not do any benchmarking (pad_mm_bench should stay 0) + # because AutoHeuristics makes the decision without benchmarking + self.assertEqual(counters["inductor"]["pad_mm_bench"], 0) + if __name__ == "__main__": if HAS_GPU: diff --git a/test/inductor/test_block_ptr_store_dtype.py b/test/inductor/test_block_ptr_store_dtype.py new file mode 100644 index 0000000000000..e37446379e747 --- /dev/null +++ b/test/inductor/test_block_ptr_store_dtype.py @@ -0,0 +1,108 @@ +# Owner(s): ["module: inductor"] +""" +Unit test for block_ptr store dtype resolution with inplace buffers. + +Validates that codegen_block_ptr_store_line casts store values to match the +block pointer element type (the actual tensor dtype), not the graph's +intermediate buffer dtype, when storing into an inplace-mutated buffer. + +See T260349710 for the original issue: Inductor codegen for _to_copy ops +generated .to(tl.float32) for block_ptr stores into bf16 gradient buffers, +causing a dtype mismatch assertion in the Triton compiler. + +Usage: + buck test fbcode//caffe2/test/inductor:test_block_ptr_store_dtype +""" + +import torch +from torch._inductor.codegen.common import InplacedBuffer, REMOVED +from torch._inductor.codegen.triton import triton_store_type +from torch._inductor.test_case import run_tests, TestCase + + +class TestBlockPtrStoreDtype(TestCase): + """Test dtype resolution for block_ptr stores with inplace buffers. + + The logic under test (from codegen_block_ptr_store_line): + + store_dtype = V.graph.get_dtype(name) + if name in self.args.inplace_buffers: + buf = self.args.inplace_buffers[name] + if not isinstance(buf, RemovedArg): + store_dtype = V.graph.get_dtype(buf.other_names[0]) + value = f"{value}.to({triton_store_type(store_dtype)})" + """ + + @staticmethod + def _resolve_store_dtype(name, inplace_buffers, get_dtype): + """Reproduce the dtype resolution logic from codegen_block_ptr_store_line.""" + from torch._inductor.codegen.common import RemovedArg + + store_dtype = get_dtype(name) + if name in inplace_buffers: + buf = inplace_buffers[name] + if not isinstance(buf, RemovedArg): + store_dtype = get_dtype(buf.other_names[0]) + return store_dtype + + def test_non_inplace_uses_graph_dtype(self): + """Non-inplace buffer: store dtype matches graph dtype.""" + dtypes = {"buf0": torch.float32} + result = self._resolve_store_dtype("buf0", {}, lambda n: dtypes[n]) + self.assertEqual(result, torch.float32) + self.assertEqual(triton_store_type(result), "tl.float32") + + def test_inplace_buffer_uses_input_dtype(self): + """Inplace buffer with dtype mismatch: store dtype matches the input + buffer's dtype (the actual tensor), not the graph output dtype. + + This is the key scenario from T260349710: _to_copy produces fp32 + intermediate values stored into a bf16 gradient buffer via inplace + mutation. The block pointer element type comes from the actual tensor + (bf16), so the cast must be .to(tl.bfloat16), not .to(tl.float32). + """ + dtypes = {"buf0": torch.float32, "primals_1": torch.bfloat16} + inplace_buffers = { + "buf0": InplacedBuffer("in_out_ptr0", ["primals_1", "buf0"]), + "primals_1": InplacedBuffer("in_out_ptr0", ["primals_1", "buf0"]), + } + result = self._resolve_store_dtype("buf0", inplace_buffers, lambda n: dtypes[n]) + self.assertEqual(result, torch.bfloat16) + self.assertEqual(triton_store_type(result), "tl.bfloat16") + + def test_removed_inplace_falls_back_to_graph_dtype(self): + """Removed inplace buffer: falls back to graph dtype.""" + dtypes = {"buf0": torch.float32} + inplace_buffers = {"buf0": REMOVED} + result = self._resolve_store_dtype("buf0", inplace_buffers, lambda n: dtypes[n]) + self.assertEqual(result, torch.float32) + self.assertEqual(triton_store_type(result), "tl.float32") + + def test_same_dtype_inplace_is_unchanged(self): + """Same-dtype inplace buffer: dtype is unchanged (both bf16).""" + dtypes = {"buf0": torch.bfloat16, "primals_1": torch.bfloat16} + inplace_buffers = { + "buf0": InplacedBuffer("in_out_ptr0", ["primals_1", "buf0"]), + } + result = self._resolve_store_dtype("buf0", inplace_buffers, lambda n: dtypes[n]) + self.assertEqual(result, torch.bfloat16) + self.assertEqual(triton_store_type(result), "tl.bfloat16") + + def test_chained_inplace_uses_original_input_dtype(self): + """Chained inplace mutations: uses the original input buffer's dtype.""" + dtypes = { + "buf0": torch.float32, + "buf1": torch.float16, + "primals_1": torch.bfloat16, + } + # Chain: primals_1 -> buf1 -> buf0 + inplace_buffers = { + "buf0": InplacedBuffer("in_out_ptr0", ["primals_1", "buf1", "buf0"]), + } + result = self._resolve_store_dtype("buf0", inplace_buffers, lambda n: dtypes[n]) + self.assertEqual(result, torch.bfloat16) + self.assertEqual(triton_store_type(result), "tl.bfloat16") + + +if __name__ == "__main__": + run_tests() diff --git a/test/inductor/test_cache.py b/test/inductor/test_cache.py index 3c99cfb88a78b..aac5e1c1e4835 100644 --- a/test/inductor/test_cache.py +++ b/test/inductor/test_cache.py @@ -1,6 +1,7 @@ # Owner(s): ["module: inductor"] from __future__ import annotations +import json import pickle from concurrent.futures import ThreadPoolExecutor from inspect import isclass @@ -12,6 +13,7 @@ from typing_extensions import Self from unittest.mock import patch +import torch._inductor.config as inductor_config from torch._inductor import cache as icache from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_utils import ( @@ -815,5 +817,70 @@ def test_on_disk_cache_version_bump( self.assertEqual(cache.get(key), value) +class ConfigSerializationTest(TestCase): + def test_callable_config_not_json_serializable_1(self): + # Repro: setting callable configs to a non-None value + # save_config_portable() return a dict that is not JSON-serializable. + with inductor_config.patch( + bucket_all_gathers_fx_bucket_size_determinator=lambda: None + ): + portable = inductor_config.save_config_portable( + ignore_private_configs=False + ) + self.assertIn("bucket_all_gathers_fx_bucket_size_determinator", portable) + with self.assertRaises(TypeError, msg="not JSON serializable"): + json.dumps(portable) + + def test_callable_config_not_json_serializable_2(self): + # save_config_portable calls the factory and then .uuid(), + # producing a JSON-serializable value for the cache key. + from torch._inductor.choices import InductorChoices + + class ChoicesA(InductorChoices): + def uuid(self): + return "choices_a" + + class ChoicesB(InductorChoices): + def uuid(self): + return "choices_b" + + # None default stays as None in the config hash. + with inductor_config.patch(inductor_choices_class=None): + portable = inductor_config.save_config_portable( + ignore_private_configs=False + ) + self.assertIsNone(portable["inductor_choices_class"]) + json.dumps(portable) + + # Factory returning ChoicesA → uuid string + with inductor_config.patch(inductor_choices_class=ChoicesA): + portable = inductor_config.save_config_portable( + ignore_private_configs=False + ) + self.assertEqual(portable["inductor_choices_class"], "choices_a") + json_a = json.dumps(portable) + + # Factory returning ChoicesB → different uuid string + with inductor_config.patch(inductor_choices_class=ChoicesB): + portable = inductor_config.save_config_portable( + ignore_private_configs=False + ) + self.assertEqual(portable["inductor_choices_class"], "choices_b") + json_b = json.dumps(portable) + + self.assertNotEqual(json_a, json_b) + + def test_callable_config_without_uuid(self): + from torch._inductor.choices import InductorChoices + + # A subclass without uuid() raises RuntimeError. + class PlainChoices(InductorChoices): + pass + + with inductor_config.patch(inductor_choices_class=PlainChoices): + with self.assertRaises(RuntimeError): + inductor_config.save_config_portable(ignore_private_configs=False) + + if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_caching.py b/test/inductor/test_caching.py index 48602ef8a733a..e2de6d8e4d470 100644 --- a/test/inductor/test_caching.py +++ b/test/inductor/test_caching.py @@ -2177,9 +2177,10 @@ def _create_mock_match(self) -> Any: return mock_match + @patch("torch._prims_common.is_contiguous_or_false", return_value=True) @patch_on_disk_cache_base_dir @set_caching_module_enabled(True) - def test_should_pad_memoizer_caches_result(self) -> None: + def test_should_pad_memoizer_caches_result(self, mock_is_contiguous) -> None: """Test that the should_pad_memoizer caches function results. Verifies that when a function decorated with should_pad_memoizer.memoize @@ -2221,10 +2222,11 @@ def mock_should_pad( self.assertTrue(result1) self.assertTrue(result2) + @patch("torch._prims_common.is_contiguous_or_false", return_value=True) @patch_on_disk_cache_base_dir @set_caching_module_enabled(True) def test_should_pad_memoizer_different_shapes_different_cache_entries( - self, + self, mock_is_contiguous ) -> None: """Test that different tensor shapes result in different cache entries. @@ -2272,10 +2274,11 @@ def mock_should_pad( self.assertFalse(result_small) # 8 <= 10 self.assertTrue(result_large) # 12 > 10 + @patch("torch._prims_common.is_contiguous_or_false", return_value=True) @patch_on_disk_cache_base_dir @set_caching_module_enabled(True) def test_should_pad_memoizer_different_dtypes_different_cache_entries( - self, + self, mock_is_contiguous ) -> None: """Test that different tensor dtypes result in different cache entries. @@ -2321,10 +2324,11 @@ def mock_should_pad( self.assertTrue(result_fp32) self.assertFalse(result_fp16) + @patch("torch._prims_common.is_contiguous_or_false", return_value=True) @patch_on_disk_cache_base_dir @set_caching_module_enabled(True) def test_should_pad_memoizer_different_ops_different_cache_entries( - self, + self, mock_is_contiguous ) -> None: """Test that different operations result in different cache entries. @@ -2369,9 +2373,12 @@ def mock_should_pad( self.assertTrue(result_mm) self.assertFalse(result_addmm) + @patch("torch._prims_common.is_contiguous_or_false", return_value=True) @patch_on_disk_cache_base_dir @set_caching_module_enabled(True) - def test_should_pad_memoizer_replays_from_disk_cache(self) -> None: + def test_should_pad_memoizer_replays_from_disk_cache( + self, mock_is_contiguous + ) -> None: """Test that the memoizer replays results from disk cache after memory clear. Verifies that PersistentMemoizer correctly stores results to disk and @@ -2414,9 +2421,12 @@ def mock_should_pad( self.assertEqual(call_count, 1) # Function should NOT be called again self.assertTrue(result2) + @patch("torch._prims_common.is_contiguous_or_false", return_value=True) @patch_on_disk_cache_base_dir @set_caching_module_enabled(False) - def test_should_pad_memoizer_disabled_does_not_cache(self) -> None: + def test_should_pad_memoizer_disabled_does_not_cache( + self, mock_is_contiguous + ) -> None: """Test that the memoizer does not cache when caching is disabled. Verifies that when IS_CACHING_MODULE_ENABLED is False, the function @@ -2455,9 +2465,10 @@ def mock_should_pad( self.assertTrue(result1) self.assertTrue(result2) + @patch("torch._prims_common.is_contiguous_or_false", return_value=True) @patch_on_disk_cache_base_dir @set_caching_module_enabled(True) - def test_should_pad_memoizer_with_input_tensor(self) -> None: + def test_should_pad_memoizer_with_input_tensor(self, mock_is_contiguous) -> None: """Test that the memoizer correctly handles the optional input tensor. Verifies that different input tensors (for addmm) result in different @@ -2501,9 +2512,12 @@ def mock_should_pad( self.assertTrue(result_with_input) self.assertFalse(result_without_input) + @patch("torch._prims_common.is_contiguous_or_false", return_value=True) @patch_on_disk_cache_base_dir @set_caching_module_enabled(True) - def test_should_pad_params_encoder_produces_consistent_keys(self) -> None: + def test_should_pad_params_encoder_produces_consistent_keys( + self, mock_is_contiguous + ) -> None: """Test that the encoder produces consistent keys for the same inputs. Verifies that calling the encoder with the same tensor metadata produces @@ -2531,9 +2545,12 @@ def test_should_pad_params_encoder_produces_consistent_keys(self) -> None: self.assertEqual(encoded1["mat1"]["shape"], tuple(mat1.shape)) self.assertEqual(encoded1["mat2"]["shape"], tuple(mat2.shape)) + @patch("torch._prims_common.is_contiguous_or_false", return_value=True) @patch_on_disk_cache_base_dir @set_caching_module_enabled(True) - def test_should_pad_memoizer_same_shape_different_data_uses_cache(self) -> None: + def test_should_pad_memoizer_same_shape_different_data_uses_cache( + self, mock_is_contiguous + ) -> None: """Test that tensors with the same metadata but different data share cache. Verifies that the memoizer caches based on tensor metadata (shape, stride, @@ -3451,6 +3468,8 @@ def test_interim_result_with_future_pattern(self) -> None: call_count = 0 deferred_obj: interfaces.DeferredRecording | None = None + finalize_done = Event() + def future_encoder_factory(fn) -> object: def future_encoder(*args: object, **kwargs: object) -> object: def encode(future_result: Future[int]) -> interfaces.DeferredRecording: @@ -3462,6 +3481,7 @@ def encode(future_result: Future[int]) -> interfaces.DeferredRecording: def on_complete(completed_future: Future[int]) -> None: actual_result = completed_future.result() deferred.finalize(actual_result) + finalize_done.set() future_result.add_done_callback(on_complete) deferred_obj = deferred @@ -3503,6 +3523,11 @@ def work() -> int: result = future1.result(timeout=5) self.assertEqual(result, 10) + # Future.set_result() invokes done callbacks after releasing its + # condition lock, so future.result() can return before on_complete + # (and thus deferred.finalize()) has run. Wait for finalize. + self.assertTrue(finalize_done.wait(timeout=5)) + # Verify deferred recording completed self.assertIsNotNone(deferred_obj) self.assertIsNone(deferred_obj._callbacks) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 5e0bc60240b5c..9c79a8d694e9a 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -1,5 +1,9 @@ # Owner(s): ["module: inductor"] +import base64 +import copy import functools +import hashlib +import json import logging import os import pickle @@ -8,6 +12,7 @@ import sys import tempfile import textwrap +import types import unittest from contextlib import contextmanager from typing_extensions import override @@ -21,8 +26,17 @@ from torch._functorch import config as functorch_config from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache from torch._inductor import config, metrics +from torch._inductor.cache_key import ( + AUTOTUNE_CACHE_KEY_STRATEGY, + CacheKeyStrategy, + CODE_CACHE_KEY_STRATEGY, + COMPACT_CACHE_KEY_STRATEGY, + FX_GRAPH_CACHE_KEY_STRATEGY, + SYSTEM_CACHE_KEY_STRATEGY, +) from torch._inductor.codecache import ( BypassFxGraphCache, + CacheBase, CUDACodeCache, FxGraphCachePickler, FxGraphHashDetails, @@ -96,6 +110,131 @@ STATIC_LAUNCHER_DEVICES = ("cuda", "xpu") +class TestCacheKeyStrategy(TestCase): + def _compact_sha256(self, data: bytes) -> str: + return ( + base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower() + ) + + def test_strategy_formats_existing_key_shapes(self): + self.assertEqual( + COMPACT_CACHE_KEY_STRATEGY.key(b"kernel"), + self._compact_sha256(b"kernel"), + ) + self.assertEqual( + CODE_CACHE_KEY_STRATEGY.key("kernel", "extra"), + "c" + self._compact_sha256(b"kernel||extra"), + ) + self.assertEqual( + FX_GRAPH_CACHE_KEY_STRATEGY.key(b"graph"), + "f" + self._compact_sha256(b"graph"), + ) + self.assertEqual( + AUTOTUNE_CACHE_KEY_STRATEGY.key("kernel", b"torch"), + hashlib.sha256(b"kerneltorch").hexdigest(), + ) + + def test_custom_strategy_composes_components(self): + strategy = CacheKeyStrategy( + name="test", + digest_format="hex", + prefix="x", + separator=b":", + ) + + self.assertEqual( + strategy.key("left", b"right"), + "x" + hashlib.sha256(b"left:right").hexdigest(), + ) + + def test_code_hash_uses_code_strategy(self): + from torch._inductor.codecache import code_hash, sha256_hash + + self.assertEqual( + sha256_hash(b"kernel"), COMPACT_CACHE_KEY_STRATEGY.key(b"kernel") + ) + self.assertEqual( + code_hash("kernel", extra="extra"), + CODE_CACHE_KEY_STRATEGY.key("kernel", "extra"), + ) + + def test_system_strategy_hashes_sorted_json(self): + system = {"version": {"triton": "abc"}, "device": {"name": "gpu"}} + + self.assertEqual( + SYSTEM_CACHE_KEY_STRATEGY.key_from_json(system), + hashlib.sha256( + json.dumps(system, sort_keys=True).encode("utf-8") + ).hexdigest(), + ) + + def test_cache_base_get_system_uses_system_strategy(self): + class FakeStrategy: + value = None + sort_keys = None + + def key_from_json(self, value, *, sort_keys=True): + self.value = copy.deepcopy(value) + self.sort_keys = sort_keys + return "sentinel" + + fake_strategy = FakeStrategy() + device_properties = types.SimpleNamespace( + name="test-gpu", gcnArchName="test-gcn" + ) + + CacheBase.get_system.cache_clear() + try: + with ( + mock.patch( + "torch._inductor.codecache.SYSTEM_CACHE_KEY_STRATEGY", + fake_strategy, + ), + mock.patch("torch._inductor.runtime.triton_compat.HAS_TRITON", False), + mock.patch.object(torch.cuda, "current_device", return_value=0), + mock.patch.object( + torch.cuda, + "get_device_properties", + return_value=device_properties, + ), + mock.patch.object(torch.version, "cuda", "test-cuda"), + ): + self.assertEqual(CacheBase.get_system()["hash"], "sentinel") + finally: + CacheBase.get_system.cache_clear() + + self.assertEqual( + fake_strategy.value, + { + "device": {"name": "test-gpu"}, + "version": {"triton": None, "cuda": "test-cuda"}, + }, + ) + self.assertTrue(fake_strategy.sort_keys) + + def test_autotune_prepare_key_uses_strategy(self): + from torch._inductor.runtime.autotune_cache import AutotuneCache + + class FakeStrategy: + components = None + + def key(self, *components): + self.components = components + return "sentinel" + + fake_strategy = FakeStrategy() + with ( + mock.patch( + "torch._inductor.runtime.autotune_cache.AUTOTUNE_CACHE_KEY_STRATEGY", + fake_strategy, + ), + torch.compiler.config.patch({"cache_key_tag": "tag"}), + ): + self.assertEqual(AutotuneCache._prepare_key("/tmp/cabcdef.py"), "sentinel") + + self.assertEqual(fake_strategy.components, ("cabcdef.py:tag",)) + + class LogCaptureHandler(logging.Handler): def __init__(self, level): super().__init__(level) @@ -1766,6 +1905,60 @@ def forward(self, x): self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 2) + @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) + def test_cache_guard_sqrt_no_recompilation(self): + """ + Verify that guards containing OpaqueUnaryFn_sqrt (math.sqrt) + do not cause spurious recompilations when loaded from cache. + + Previously, the guard printer emitted math.sqrt(...) for + OpaqueUnaryFn_sqrt, which when re-evaluated with SymInt inputs during + cache hit would force concretization/specialization of the symbol, + creating guards like `sym_float(size) == ` that didn't + exist in the original program. This caused a cache miss + recompilation + for every unique input size. + + The fix emits torch._sym_sqrt(...) which propagates symbolically + without forcing specialization. + + See https://github.com/pytorch/pytorch/issues/152435 + """ + import math + + def func(x): + y = math.ceil((x.numel() // 5) / (math.ceil(math.sqrt(x.numel())))) > 64 + if y: + return x * 5, y + else: + return x * 10, y + + compiled_fn = torch.compile(func, dynamic=True) + + # Warm up the cache with one size + compiled_fn(torch.rand(1000000)) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + + # Simulate a new process loading from cache + self.reset() + counters.clear() + + compiled_fn2 = torch.compile(func, dynamic=True) + # First call should hit the cache + compiled_fn2(torch.rand(2000000)) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + + # Subsequent calls with different sizes should NOT recompile. + # Before the fix, each new size would create a spurious guard + # `sym_float(size) == ` causing a recompilation. + for size in [3000000, 5000000, 6000000, 7000000]: + compiled_fn2(torch.rand(size)) + + # All subsequent calls should reuse the already-loaded compiled graph + # without any additional cache misses or dynamo recompilations. + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 0) + @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @config.patch({"freezing": True}) @@ -1998,6 +2191,45 @@ def backend(gm, args, **kwargs): compiled_out = torch.compile(f, fullgraph=True, backend=backend)(x) self.assertEqual(eager_out, compiled_out) + @config.patch({"fx_graph_cache": True}) + @config.patch({"fx_graph_remote_cache": False}) + @functorch_config.patch({"enable_autograd_cache": True}) + @parametrize("donate", (False, True)) + def test_donate_graph_module(self, donate: bool) -> None: + mod = torch.nn.Linear(1, 3) + x = torch.randn(4, 1) + + def f(x): + with torch.no_grad(): + return mod(x) + + eager_out = f(x) + + with fresh_cache(): + gm, args, kwargs = self.capture(f)(x) + if kwargs: + raise AssertionError + + # compile_fx mutates the graph module (e.g. adds + # mutation_region_id to node metadata). Use this as a + # fingerprint to detect mutation. + def has_mutation_region_ids(gm): + return any("mutation_region_id" in n.meta for n in gm.graph.nodes) + + before = has_mutation_region_ids(gm) + + compiled_artifact = torch._inductor.standalone_compile( + gm, args, donate_graph_module=donate + ) + compiled_out = compiled_artifact(*args) + self.assertEqual(eager_out, compiled_out[0]) + + after = has_mutation_region_ids(gm) + if donate: + self.assertNotEqual(before, after) + else: + self.assertEqual(before, after) + @config.patch({"fx_graph_cache": True}) @config.patch({"fx_graph_remote_cache": False}) @functorch_config.patch({"enable_autograd_cache": True}) @@ -2676,6 +2908,70 @@ def test_hash_private_config_changes(self): pickler.dumps(details3), ) + def test_hash_provenance_tracking_level(self): + """ + Test that provenance_tracking_level affects hashes. + """ + with config.patch({"trace.provenance_tracking_level": 0}): + details1 = FxGraphHashDetails(None, [], {}, []) + details2 = FxGraphHashDetails(None, [], {}, []) + + with config.patch({"trace.provenance_tracking_level": 1}): + details3 = FxGraphHashDetails(None, [], {}, []) + + gm = torch.fx.GraphModule({}, torch.fx.Graph()) + pickler = FxGraphCachePickler(gm) + + self.assertEqual( + pickler.dumps(details1), + pickler.dumps(details2), + ) + self.assertNotEqual( + pickler.dumps(details1), + pickler.dumps(details3), + ) + + def test_provenance_tracking_level_causes_cache_miss(self): + """ + Test that changing provenance_tracking_level causes a cache miss. + """ + + class Mod(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.rand(4, 4)) + + def forward(self, x): + return x @ self.param + + mod = Mod() + mod_compiled = torch.compile(mod) + with torch.no_grad(): + x = torch.rand(4, 4) + with config.patch({"trace.provenance_tracking_level": 0}): + # miss + mod_compiled(x) + self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + # hit + torch._dynamo.reset() + mod_compiled(x) + self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) + torch._dynamo.reset() + counters.clear() + + with config.patch({"trace.provenance_tracking_level": 1}): + # miss (provenance_tracking_level changed) + mod_compiled(x) + self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0) + self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) + torch._dynamo.reset() + counters.clear() + def test_non_serializable_custom_passes_causes_cache_miss(self): class Mod(torch.nn.Module): def __init__(self) -> None: diff --git a/test/inductor/test_codegen_triton.py b/test/inductor/test_codegen_triton.py index 1567569606a3f..3075eb303d56d 100644 --- a/test/inductor/test_codegen_triton.py +++ b/test/inductor/test_codegen_triton.py @@ -1,18 +1,29 @@ # Owner(s): ["module: inductor"] import contextlib +import unittest import sympy import torch import torch._inductor.config as inductor_config from torch._inductor.codegen import triton_utils -from torch._inductor.codegen.common import CSEVariable, SizeArg -from torch._inductor.codegen.triton import TritonKernelOverrides +from torch._inductor.codegen.common import CSEVariable, SizeArg, TensorArg +from torch._inductor.codegen.triton import ( + _materialize_trunc_to_float_expr, + TritonKernelOverrides, +) from torch._inductor.dtype_propagation import DtypePropagationOpsHandler, promote_types from torch._inductor.graph import GraphLowering from torch._inductor.test_case import TestCase as InductorTestCase +from torch._inductor.utils import run_and_get_code from torch._inductor.virtualized import V -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CPU, + HAS_GPU, + HAS_GPU_AND_TRITON, +) +from torch.utils._sympy.functions import FloorDiv, TruncToFloat, TruncToInt from torch.utils._sympy.value_ranges import ValueRanges @@ -151,6 +162,154 @@ def constant(cls, value, dtype): "triton_helpers.pow_integer(custom_constant(3, torch.uint32), ks0)", ) + def test_materialize_trunc_to_float_expr_preserves_integer_subexpressions(self): + s0 = sympy.Symbol("s0") + + trunc_expr = TruncToInt(s0) + self.assertEqual( + _materialize_trunc_to_float_expr(trunc_expr, torch.float64), + TruncToFloat(s0), + ) + + integer_expr = FloorDiv(trunc_expr, sympy.Integer(5)) + self.assertEqual( + _materialize_trunc_to_float_expr(integer_expr, torch.float64), + integer_expr, + ) + + predicate_expr = sympy.Eq(trunc_expr, sympy.Integer(9007199254740993)) + self.assertEqual( + _materialize_trunc_to_float_expr(predicate_expr, torch.float64), + predicate_expr, + ) + + float_expr = sympy.Float(0.5) + trunc_expr + self.assertEqual( + _materialize_trunc_to_float_expr(float_expr, torch.float64), + sympy.Float(0.5) + TruncToFloat(s0), + ) + + @inductor_config.patch("triton.emit_pointer_range_32", True) + def test_config_of_emit_pointer_range_32_enabled(self): + from torch._inductor.utils import ( + get_triton_attrs_descriptor_version, + TritonAttrsDescriptorVersion, + ) + + sixteen = sympy.Integer(16) + s0 = sympy.Symbol("s0", positive=True, integer=True) + + config = triton_utils.config_of( + [SizeArg("A", sixteen), SizeArg("B", s0)], + pointer_range_override=(0,), + ) + + if get_triton_attrs_descriptor_version() in { + TritonAttrsDescriptorVersion.V0_NO_TRITON, + TritonAttrsDescriptorVersion.V1_COMPILER, + TritonAttrsDescriptorVersion.V2_BACKENDS, + TritonAttrsDescriptorVersion.V3_BACKENDS_TUPLE, + }: + self.assertEqual(config.pointer_range_32, (0,)) + else: + self.assertIsInstance(config, dict) + self.assertIn(["tt.pointer_range", 32], config[(0,)]) + + @inductor_config.patch("triton.emit_pointer_range_32", False) + def test_config_of_emit_pointer_range_32_disabled(self): + from torch._inductor.utils import ( + get_triton_attrs_descriptor_version, + TritonAttrsDescriptorVersion, + ) + + sixteen = sympy.Integer(16) + s0 = sympy.Symbol("s0", positive=True, integer=True) + + config = triton_utils.config_of( + [SizeArg("A", sixteen), SizeArg("B", s0)], + pointer_range_override=(), + ) + + if get_triton_attrs_descriptor_version() in { + TritonAttrsDescriptorVersion.V0_NO_TRITON, + TritonAttrsDescriptorVersion.V1_COMPILER, + TritonAttrsDescriptorVersion.V2_BACKENDS, + TritonAttrsDescriptorVersion.V3_BACKENDS_TUPLE, + }: + self.assertEqual(config.pointer_range_32, ()) + else: + self.assertIsInstance(config, dict) + if (0,) in config: + self.assertNotIn(["tt.pointer_range", 32], config[(0,)]) + + @unittest.skipUnless(torch.version.hip is not None, "pointer_range_32 is HIP-only") + @unittest.skipUnless(HAS_GPU_AND_TRITON, "requires GPU and Triton") + def test_pointer_range_in_generated_code(self): + """Verify tt.pointer_range=32 appears in generated Triton code on HIP.""" + + def fn(x): + return x + 1 + + x = torch.randn(64, 64, device=GPU_TYPE, dtype=torch.bfloat16) + _, code = run_and_get_code(torch.compile(fn), x) + code_str = " ".join(code) + self.assertIn("tt.pointer_range", code_str) + + def test_is_multiple_of_rules(self): + """Test structural divisibility rules in _is_multiple_of.""" + from torch.utils._sympy.functions import FloorDiv, Mod + + sv = V.graph.sizevars + shape_env = sv.shape_env + + s1 = sympy.Symbol("s1", positive=True, integer=True) + s2 = sympy.Symbol("s2", positive=True, integer=True) + s3 = sympy.Symbol("s3", positive=True, integer=True) + + # Product: any factor divisible → product divisible + self.assertTrue(sv.statically_known_multiple_of(16 * s1, 16)) + self.assertTrue(sv.statically_known_multiple_of(4 * 4 * s1, 16)) + shape_env.axioms[sympy.Eq(Mod(s1, 16), 0)] = sympy.true + self.assertTrue(sv.statically_known_multiple_of(s1 * s2, 16)) + self.assertFalse(sv.statically_known_multiple_of(s2 * s3, 16)) + + # Sum: all terms divisible → sum divisible + self.assertFalse(sv.statically_known_multiple_of(s1 + s2, 16)) + shape_env.axioms[sympy.Eq(Mod(s2, 16), 0)] = sympy.true + self.assertTrue(sv.statically_known_multiple_of(s1 + s2, 16)) + self.assertTrue(sv.statically_known_multiple_of(s1 + 32, 16)) + self.assertFalse(sv.statically_known_multiple_of(s1 + 3, 16)) + + # FloorDiv(a, b): a must be multiple of b*n + self.assertFalse(sv.statically_known_multiple_of(FloorDiv(s1, 3), 16)) + shape_env.axioms[sympy.Eq(Mod(s3, 48), 0)] = sympy.true + self.assertTrue(sv.statically_known_multiple_of(FloorDiv(s3, 3), 16)) + + # Mod(a, b): both a and b must be multiples of n + self.assertTrue(sv.statically_known_multiple_of(Mod(s1, 48), 16)) + s_nodiv = sympy.Symbol("s_nodiv", positive=True, integer=True) + self.assertFalse(sv.statically_known_multiple_of(Mod(s_nodiv, 32), 16)) + self.assertFalse(sv.statically_known_multiple_of(Mod(s1, 7), 16)) + + # Axiom fallback: bare symbol resolved via statically_known_true + s4 = sympy.Symbol("s4", positive=True, integer=True) + self.assertFalse(sv.statically_known_multiple_of(s4, 8)) + shape_env.axioms[sympy.Eq(Mod(s4, 8), 0)] = sympy.true + self.assertTrue(sv.statically_known_multiple_of(s4, 8)) + + def test_signature_of_fp8_dtypes(self): + """fp8 dtypes should produce correct Triton pointer signatures via _type_of.""" + expected = { + torch.float8_e4m3fn: "*fp8e4nv", + torch.float8_e5m2: "*fp8e5", + torch.float8_e4m3fnuz: "*fp8e4b8", + torch.float8_e5m2fnuz: "*fp8e5b16", + } + for dtype, expected_sig in expected.items(): + arg = TensorArg(name="x", buffer="buf0", dtype=dtype) + sig = triton_utils.signature_of(arg, size_dtype=None) + self.assertEqual(sig, expected_sig, f"wrong signature for {dtype}") + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_combo_kernels.py b/test/inductor/test_combo_kernels.py index f18ff7c8a8269..d595f78e6f24d 100644 --- a/test/inductor/test_combo_kernels.py +++ b/test/inductor/test_combo_kernels.py @@ -2,6 +2,8 @@ import contextlib import json +import logging +import re import sys import tempfile import unittest @@ -19,7 +21,11 @@ TestCase, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU_AND_TRITON -from torch.testing._internal.triton_utils import requires_gpu_and_triton +from torch.testing._internal.triton_utils import ( + requires_cuda_and_triton, + requires_gpu_and_triton, + requires_xpu_and_triton, +) aten = torch.ops.aten @@ -568,6 +574,68 @@ def fn(a, b, c, d): torch._inductor.metrics.generated_kernel_count, expected_kernel_count ) + @requires_gpu_and_triton + @parametrize( + "max_num_nodes,expected_kernel_count", + [(8, 1), (3, 2), (2, 3)], + ) + def test_combo_kernel_max_num_nodes(self, max_num_nodes, expected_kernel_count): + def fn(a, b, c, d, e, f): + return ( + a * 2.0, + b + 1.0, + c.sin(), + d.cos(), + e.exp(), + f.neg(), + ) + + inps = [ + torch.rand(1024, device=GPU_TYPE), + torch.rand(1024, device=GPU_TYPE), + torch.rand(1024, device=GPU_TYPE), + torch.rand(1024, device=GPU_TYPE), + torch.rand(1024, device=GPU_TYPE), + torch.rand(1024, device=GPU_TYPE), + ] + + out_eager = fn(*inps) + + torch._inductor.metrics.reset() + with torch._inductor.config.patch("combo_kernel_max_num_nodes", max_num_nodes): + fn_c = torch.compile(fn) + out_compiled, _ = run_and_get_code(fn_c, *inps) + self.assertEqual(out_eager, out_compiled) + self.assertEqual( + torch._inductor.metrics.generated_kernel_count, expected_kernel_count + ) + + # waves_per_eu, matrix_instr_nonkdim, and kpack are HIP-only Triton + # compile options, so only ROCm exercises this combo-kernel rewrite path. + @unittest.skipIf(not torch.version.hip, "ROCm only") + @requires_gpu_and_triton + @parametrize("max_autotune", [False, True]) + def test_combo_kernel_amd_special_config_args(self, max_autotune): + if not torch._inductor.config.combo_kernel_per_subkernel_blocks: + self.skipTest("requires combo_kernel_per_subkernel_blocks") + + def fn(a, b): + return a * 2.0, b + 1.0 + + inps = [ + torch.rand(1024, device=GPU_TYPE), + torch.rand(1024, device=GPU_TYPE), + ] + out_eager = fn(*inps) + + torch._inductor.metrics.reset() + with torch._inductor.config.patch("max_autotune", max_autotune): + fn_c = torch.compile(fn) + out_compiled, _ = run_and_get_code(fn_c, *inps) + + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + @skipIfXpu(msg="Profiler JSON traceEvents is not supported on XPU") @requires_gpu_and_triton @unittest.skipIf(not SM90OrLater, "Avoid oom on CI") @@ -710,7 +778,7 @@ def test_mutated(a, b, c, d): out_compiled = torch.compile(test_mutated)(*inps) self.assertEqual(out_eager, out_compiled) - self.assertTrue(torch._inductor.metrics.generated_kernel_count in [6, 9]) + self.assertTrue(4 < torch._inductor.metrics.generated_kernel_count <= 10) @requires_gpu_and_triton def test_round_robin_dispatch(self): @@ -995,7 +1063,7 @@ def fn(a0, a1, a2, b0, b1, b2): self.assertEqual(out_eager, out_compiled) self.assertTrue(5 <= torch._inductor.metrics.generated_kernel_count <= 6) - @requires_gpu_and_triton + @requires_cuda_and_triton @torch._dynamo.config.patch("automatic_dynamic_shapes", True) @torch._dynamo.config.patch("assume_static_by_default", True) @torch._inductor.config.patch("triton.autotune_at_compile_time", True) @@ -1016,6 +1084,27 @@ def fn(x, y, z): self.assertEqual(out_eager, out_compiled) + @requires_xpu_and_triton + @torch._dynamo.config.patch("automatic_dynamic_shapes", True) + @torch._dynamo.config.patch("assume_static_by_default", True) + @torch._inductor.config.patch("triton.autotune_at_compile_time", True) + def test_dynamic_shapes_persistent_reduction_mixed_x_dim_xpu(self): + def fn(x, y, z): + return x.sum(1), y.mean(1), z.max(1) + + inps = ( + torch.rand(16, 128, device=GPU_TYPE), + torch.rand(32, 128, device=GPU_TYPE), + torch.rand(32, 256, device=GPU_TYPE), + ) + torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256) + torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256) + torch._dynamo.mark_dynamic(inps[2], 0, min=1, max=256) + out_eager = fn(*inps) + out_compiled = torch.compile(fn)(*inps) + + self.assertEqual(out_eager, out_compiled) + @requires_gpu_and_triton def test_helper_fn_defined(self): def fn(x, y, z): @@ -1173,6 +1262,333 @@ def fn(x, y): self.assertEqual(out_eager, out_compiled) +class ComboKernelTestsMaxAutotune(TestCase): + def setUp(self): + super().setUp() + torch._inductor.metrics.reset() + self._test_stack = contextlib.ExitStack() + self._test_stack.enter_context( + torch._inductor.config.patch( + { + "combo_kernels": True, + "benchmark_combo_kernel": False, + "combo_kernel_per_subkernel_blocks": True, + "max_autotune": True, + "autotune_local_cache": False, + } + ) + ) + + def tearDown(self): + self._test_stack.close() + torch._inductor.metrics.reset() + super().tearDown() + + @requires_gpu_and_triton + def test_combo_kernel_max_autotune(self): + def fn(a, b, c): + a1 = torch.nn.functional.relu(a) + b1 = torch.nn.functional.sigmoid(b) + c1 = torch.nn.functional.tanh(c) + return a1, b1, c1 + + inps = [ + torch.rand(32, 1024, device=GPU_TYPE), + torch.rand(64, 512, device=GPU_TYPE), + torch.rand(16, 2048, device=GPU_TYPE), + ] + + out_eager = fn(*inps) + fn_c = torch.compile(fn) + + logger = logging.getLogger("torch._inductor.runtime.triton_heuristics") + with self.assertLogs(logger, level=logging.DEBUG) as cm: + out_compiled, code = run_and_get_code(fn_c, *inps) + chained_logs = [msg for msg in cm.output if "Combo sequential autotune" in msg] + self.assertGreater( + len(chained_logs), + 0, + "_combo_sequential_autotune was not invoked", + ) + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + @requires_gpu_and_triton + def test_combo_kernel_max_autotune_with_reduction(self): + def fn(x, y): + return x.sum(dim=-1), y.mean(dim=-1) + + inps = [ + torch.rand(128, 256, device=GPU_TYPE), + torch.rand(128, 256, device=GPU_TYPE), + ] + + out_eager = fn(*inps) + fn_c = torch.compile(fn) + + logger = logging.getLogger("torch._inductor.runtime.triton_heuristics") + with self.assertLogs(logger, level=logging.DEBUG) as cm: + out_compiled, code = run_and_get_code(fn_c, *inps) + chained_logs = [msg for msg in cm.output if "Combo sequential autotune" in msg] + self.assertGreater( + len(chained_logs), + 0, + "_combo_sequential_autotune was not invoked", + ) + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + @requires_gpu_and_triton + def test_combo_autotune_many_subkernels(self): + def fn(a, b, c, d, e, f): + return ( + a * 2.0, + b + 1.0, + c.sin(), + d.cos(), + e.exp(), + f.neg(), + ) + + inps = [ + torch.rand(8, 8192, device=GPU_TYPE), + torch.rand(128, 64, device=GPU_TYPE), + torch.rand(16, 4096, device=GPU_TYPE), + torch.rand(512, 16, device=GPU_TYPE), + torch.rand(32, 2048, device=GPU_TYPE), + torch.rand(256, 32, device=GPU_TYPE), + ] + + out_eager = fn(*inps) + fn_c = torch.compile(fn) + + logger = logging.getLogger("torch._inductor.runtime.triton_heuristics") + with self.assertLogs(logger, level=logging.DEBUG) as cm: + out_compiled, code = run_and_get_code(fn_c, *inps) + + chained_logs = [msg for msg in cm.output if "Combo sequential autotune" in msg] + self.assertGreater(len(chained_logs), 0) + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + @requires_gpu_and_triton + def test_combo_kernel_per_subkernel_reduction_hint(self): + def fn(x, y): + return x.sum(dim=-1), y.sum(dim=0) + + inps = [ + torch.rand(128, 256, device=GPU_TYPE), + torch.rand(128, 256, device=GPU_TYPE), + ] + + out_eager = fn(*inps) + out, code = run_and_get_code(torch.compile(fn), *inps) + self.assertEqual(out_eager, out) + # Verify per-subkernel reduction hints in generated code + found_hints = {} + for c in code: + for key in ["reduction_hint_0", "reduction_hint_1"]: + m = re.search(rf"'{key}':\s*'(\w+)'", c) + if m: + found_hints[key] = m.group(1) + + self.assertIn( + "reduction_hint_0", found_hints, "Missing per-subkernel reduction_hint_0" + ) + self.assertIn( + "reduction_hint_1", found_hints, "Missing per-subkernel reduction_hint_1" + ) + self.assertEqual(found_hints["reduction_hint_0"], "INNER") + self.assertEqual(found_hints["reduction_hint_1"], "OUTER") + + @requires_gpu_and_triton + @torch._inductor.config.patch("combo_kernel_autotune_grouping", True) + def test_combo_autotune_grouping(self): + def fn(a, b, c, d): + return a.cos(), b.sin(), c.exp(), d.neg() + + # a,b: numel=262144 → bs=1024, c,d: numel=32 → bs=256 + # Different bs → different configs → separate groups + inps = [ + torch.rand(4, 65536, device=GPU_TYPE), + torch.rand(4, 65536, device=GPU_TYPE), + torch.rand(4, 8, device=GPU_TYPE), + torch.rand(4, 8, device=GPU_TYPE), + ] + + out_eager = fn(*inps) + fn_c = torch.compile(fn) + + logger = logging.getLogger("torch._inductor.runtime.triton_heuristics") + with self.assertLogs(logger, level=logging.DEBUG) as cm: + out_compiled, code = run_and_get_code(fn_c, *inps) + + # Parse "Phase 1 group N SK[...]" lines to check grouping + group_lines = [ + msg for msg in cm.output if "Phase 1 group" in msg and "SK[" in msg + ] + group_indices = { + int(re.search(r"group (\d+)", line).group(1)) + for line in group_lines + if re.search(r"group (\d+)", line) + } + # Exact grouping count is hardware-dependent because pointwise candidate + # config sets can differ across environments. The stable regression for + # the new grouping key lives in the mocked test below. + self.assertGreater( + len(group_indices), + 0, + f"Expected at least one autotune group, got {group_lines}", + ) + self.assertEqual(out_eager, out_compiled) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + + @requires_gpu_and_triton + @torch._inductor.config.patch("combo_kernel_autotune_grouping", True) + def test_combo_autotune_grouping_uses_tiling_signature(self): + import triton + + inductor_meta = { + "combo_grid_meta": { + "num_kernels": 2, + "heuristic_0": "pointwise", + "heuristic_1": "pointwise", + "size_hints_0": {"x": 256, "y": 256}, + "size_hints_1": {"x": 256, "y": 256}, + "tile_hint_0": "TileHint.SQUARE", + "tile_hint_1": "TileHint.SQUARE", + "tiling_scores_0": {"x": 8, "y": 1}, + "tiling_scores_1": {"x": 1, "y": 8}, + } + } + + def pointwise_configs(*args, **kwargs): + return [ + triton.Config({"XBLOCK": 64, "YBLOCK": 32}, num_warps=4, num_stages=1), + triton.Config({"XBLOCK": 128, "YBLOCK": 32}, num_warps=4, num_stages=1), + ] + + with unittest.mock.patch( + "torch._inductor.runtime.triton_heuristics.pointwise", + side_effect=pointwise_configs, + ): + torch._inductor.runtime.triton_heuristics._handle_combo_kernel_per_subkernel_blocks( + {"x": 256, "y": 256}, + inductor_meta, + triton_meta={}, + ) + + groups = inductor_meta["combo_tuning_groups"] + self.assertEqual(len(groups), 2) + self.assertEqual([[0], [1]], [g["member_indices"] for g in groups]) + + @requires_gpu_and_triton + def test_combo_kernel_coordesc_tunes_largest_subkernel_first(self): + def fn(a, b, c): + return ( + torch.nn.functional.relu(a), + torch.nn.functional.sigmoid(b), + torch.nn.functional.tanh(c), + ) + + inps = [ + torch.rand(32, 1024, device=GPU_TYPE), + torch.rand(256, 256, device=GPU_TYPE), + torch.rand(16, 128, device=GPU_TYPE), + ] + + out_eager = fn(*inps) + + def parse_block_cfg(msg: str) -> dict[str, int]: + return { + m.group(1): int(m.group(2)) + for m in re.finditer(r"(\w+BLOCK_\d+): (\d+)", msg) + } + + logger = logging.getLogger("torch._inductor.runtime.coordinate_descent_tuner") + with torch._inductor.config.patch(coordinate_descent_tuning=True): + with self.assertLogs(logger, level=logging.DEBUG) as cm: + out_compiled = torch.compile(fn)(*inps) + + self.assertEqual(out_eager, out_compiled) + + baseline_log = next( + msg for msg in cm.output if "Baseline Config" in msg and "XBLOCK_" in msg + ) + baseline_cfg = parse_block_cfg(baseline_log) + try_logs = [ + msg for msg in cm.output if "Try config" in msg and "XBLOCK_" in msg + ] + self.assertGreater( + len(try_logs), 0, "Coordinate descent did not try combo fields" + ) + distinct_block_cfgs = { + tuple(sorted(parse_block_cfg(msg).items())) for msg in try_logs + } + self.assertGreater( + len(distinct_block_cfgs), + 1, + "Coordinate descent did not explore different suffixed block sizes.", + ) + + first_cfg = parse_block_cfg(try_logs[0]) + changed_fields = { + key for key, value in first_cfg.items() if baseline_cfg.get(key) != value + } + self.assertEqual( + changed_fields, + {"XBLOCK_1"}, + f"Expected the first combo coordesc step to tune the largest subkernel first, got {changed_fields}", + ) + + +@instantiate_parametrized_tests +class ComboKernelMetadataTests(TestCase): + + def setUp(self): + super().setUp() + torch._inductor.metrics.reset() + self._test_stack = contextlib.ExitStack() + self._test_stack.enter_context( + torch._inductor.config.patch( + { + "combo_kernels": True, + "benchmark_combo_kernel": False, + "combo_kernel_per_subkernel_blocks": True, + } + ) + ) + + def tearDown(self): + self._test_stack.close() + torch._inductor.metrics.reset() + super().tearDown() + + def _combo_code(self, fn, inps): + out_eager = fn(*inps) + out_compiled, code = run_and_get_code(torch.compile(fn), *inps) + self.assertEqual(out_eager, out_compiled) + return " ".join(code) + + @requires_gpu_and_triton + def test_combo_inductor_meta_has_optimize_mem(self): + def fn(a, b): + return torch.relu(a), torch.sigmoid(b) + + inps = [torch.rand(1024, device=GPU_TYPE) for _ in range(2)] + code = self._combo_code(fn, inps) + self.assertIn("'optimize_mem': True", code) + + @requires_gpu_and_triton + def test_combo_inductor_meta_optimize_mem_false_in_training_forward(self): + def fn(a, b): + return torch.relu(a), torch.sigmoid(b) + + inps = [torch.rand(1024, device=GPU_TYPE, requires_grad=True) for _ in range(2)] + code = self._combo_code(fn, inps) + self.assertIn("'optimize_mem': False", code) + + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/inductor/test_comm_analysis.py b/test/inductor/test_comm_analysis.py new file mode 100644 index 0000000000000..0ecf35dcec2de --- /dev/null +++ b/test/inductor/test_comm_analysis.py @@ -0,0 +1,134 @@ +# Owner(s): ["module: inductor"] + +import sys +import warnings + +import torch +import torch.distributed as dist + + +if not dist.is_available() or not dist.is_nccl_available(): + print("c10d NCCL not available, skipping tests", file=sys.stderr) + sys.exit(0) + +try: + from torch.testing._internal.common_distributed import ( + requires_nccl, + skip_if_lt_x_gpu, + ) +except ImportError: + print("common_distributed not importable, skipping tests", file=sys.stderr) + sys.exit(0) + +from torch.fx.experimental.proxy_tensor import make_fx +from torch.testing._internal.common_utils import run_tests, TestCase + + +def _get_all_gather_node(group_size, group_name): + """Trace a simple all_gather function and return the collective FX node.""" + + def func(inp, group_size, group_name): + out = torch.ops._c10d_functional.all_gather_into_tensor( + inp, group_size, group_name + ) + wait = torch.ops._c10d_functional.wait_tensor(out) + return wait + + gm = make_fx(func)(torch.ones(4, 4), group_size, group_name) + for n in gm.graph.nodes: + if n.op == "call_function" and "all_gather_into_tensor" in str(n.target): + return n + raise RuntimeError("No all_gather_into_tensor node found in traced graph") + + +class TestNcclEstimateDeviceResolution(TestCase): + """ + Tests for the device resolution fix in _nccl_estimate() inside + estimate_nccl_collective_runtime_from_fx_node. + """ + + def _init_pg(self, backend, world_size=2): + from torch.testing._internal.distributed.fake_pg import FakeStore + + store = FakeStore() + dist.init_process_group( + backend=backend, rank=0, world_size=world_size, store=store + ) + pg = dist.group.WORLD + group_name = "test_comm_analysis" + torch._C._distributed_c10d._register_process_group(group_name, pg) + return pg, group_name, pg.size() + + def _init_pg_real_store(self, backend, world_size=1): + store = dist.HashStore() + dist.init_process_group( + backend=backend, rank=0, world_size=world_size, store=store + ) + pg = dist.group.WORLD + group_name = "test_comm_analysis" + torch._C._distributed_c10d._register_process_group(group_name, pg) + return pg, group_name, pg.size() + + def _destroy_pg(self): + dist.destroy_process_group() + + def test_fake_backend_falls_back_to_analytical(self): + """FAKE backend: _nccl_estimate returns None, falls back to analytical formula.""" + pg, group_name, group_size = self._init_pg("fake") + try: + node = _get_all_gather_node(group_size, group_name) + from torch._inductor.comm_analysis import ( + estimate_nccl_collective_runtime_from_fx_node, + ) + + est_ms = estimate_nccl_collective_runtime_from_fx_node( + node, use_nccl_estimator=True + ) + self.assertGreater(est_ms, 0) + + est_ms_analytical = estimate_nccl_collective_runtime_from_fx_node( + node, use_nccl_estimator=False + ) + self.assertEqual(est_ms, est_ms_analytical) + finally: + self._destroy_pg() + + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_multi_backend_pg_resolves_to_nccl(self): + """ + Multi-backend PG ("cpu:gloo,cuda:nccl"): We should resolve to the cuda device's backend. + """ + torch.cuda.set_device(0) + pg, group_name, group_size = self._init_pg_real_store("cpu:gloo,cuda:nccl") + try: + from torch.distributed.distributed_c10d import _get_pg_default_device + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + default_device = _get_pg_default_device(pg) + self.assertEqual(default_device, torch.device("cpu")) + + nccl_backend = pg._get_backend(torch.device("cuda")) + self.assertTrue(nccl_backend.supports_time_estimate) + + gloo_backend = pg._get_backend(torch.device("cpu")) + self.assertFalse(gloo_backend.supports_time_estimate) + finally: + self._destroy_pg() + + @requires_nccl() + @skip_if_lt_x_gpu(1) + def test_single_nccl_backend_resolves_correctly(self): + """Single NCCL backend PG: cuda device resolves to NCCL with time estimation.""" + torch.cuda.set_device(0) + pg, group_name, group_size = self._init_pg_real_store("nccl") + try: + backend = pg._get_backend(torch.device("cuda")) + self.assertTrue(backend.supports_time_estimate) + finally: + self._destroy_pg() + + +if __name__ == "__main__": + run_tests() diff --git a/test/inductor/test_compile_subprocess.py b/test/inductor/test_compile_subprocess.py index 0248638480bdc..9344dc777be89 100644 --- a/test/inductor/test_compile_subprocess.py +++ b/test/inductor/test_compile_subprocess.py @@ -76,6 +76,11 @@ "test_weight_norm_conv2d": TestFailure(("cpu", "cuda"), is_skip=True), } +if TEST_WITH_ROCM and not torch.cuda.has_magma: + test_failures["test_linalg_eig_stride_consistency"] = TestFailure( + ("cuda",), is_skip=True + ) + class TestSubprocess(TestCase): def setUp(self): diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index a06ba7ed4faf8..35a6624bf682a 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -53,7 +53,7 @@ HAS_CPU, HAS_CUDA_AND_TRITON, HAS_GPU, - HAS_XPU_AND_TRITON, + HAS_GPU_AND_TRITON, ) from torch.testing._internal.logging_utils import logs_to_string from torch.testing._internal.triton_utils import ( @@ -1082,8 +1082,20 @@ def bytecode_hook(code, out_code): call_op = "CALL" insts = list(dis.get_instructions(out_code)) + # Find the CALL that invokes the compiled graph function + # (not an earlier CALL from e.g. store_user_object_weakrefs). + # The compiled fn is loaded via LOAD_GLOBAL __compiled_fn_*. + load_graph_idx = next( + i + for i, inst in enumerate(insts) + if inst.opname == "LOAD_GLOBAL" + and isinstance(inst.argval, str) + and inst.argval.startswith("__compiled_fn") + ) call_graph_idx = next( - i for i, inst in enumerate(insts) if inst.opname == call_op + i + for i, inst in enumerate(insts) + if i > load_graph_idx and inst.opname == call_op ) # pre-graph should alias: inputs_ref_0 = inputs[0] matches = [ @@ -1155,8 +1167,19 @@ def bytecode_hook(code, out_code): call_op = "CALL" insts = list(dis.get_instructions(out_code)) + # Find the CALL that invokes the compiled graph function + # (not an earlier CALL from e.g. store_user_object_weakrefs). + load_graph_idx = next( + i + for i, inst in enumerate(insts) + if inst.opname == "LOAD_GLOBAL" + and isinstance(inst.argval, str) + and inst.argval.startswith("__compiled_fn") + ) call_graph_idx = next( - i for i, inst in enumerate(insts) if inst.opname == call_op + i + for i, inst in enumerate(insts) + if i > load_graph_idx and inst.opname == call_op ) # pre-graph should alias: inputs_ref_0 = inputs[0] matches = [ @@ -1958,7 +1981,7 @@ def backward(ctx, gO): yield x.grad self.check_output_and_recompiles( - fn, count=[1, 3], compiler_fn=make_compiler_fn(fullgraph=False) + fn, count=[1, 2], compiler_fn=make_compiler_fn(fullgraph=False) ) def test_custom_fn_compiled_fw_graph_break(self): @@ -2012,9 +2035,9 @@ def backward(ctx, gO): yield x.grad self.check_output_and_recompiles( - fn, count=[1, 3], compiler_fn=make_compiler_fn(fullgraph=False) + fn, count=[1, 2], compiler_fn=make_compiler_fn(fullgraph=False) ) - self.assertEqual(counters["stats"]["unique_graphs"], 6) # 3 fw, 3 bw + self.assertEqual(counters["stats"]["unique_graphs"], 5) def test_mismatch_fake_tensor_mode(self, dynamic_shape=False): """ @@ -3312,7 +3335,6 @@ def fn(): @mock.patch( "torch._functorch.aot_autograd.AOT_COUNTER", new_callable=itertools.count ) - @mock.patch("torch._dynamo.config.inline_inbuilt_nn_modules", True) def test_verbose_logs_aot_id(self, _): def fn(): model = torch.nn.Sequential( @@ -3654,7 +3676,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data): getitem_27 = validate_outputs_1[0]; validate_outputs_1 = None getitem_28 = hooks[0]; getitem_28 = None - call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((getitem_1, getitem_2), [], [], getitem_27); getitem_1 = getitem_2 = getitem_27 = None + call_aot_bwd_prologue = torch__dynamo_compiled_autograd_call_aot_bwd_prologue((getitem_1, getitem_2), [], [], (getitem_27,)); getitem_1 = getitem_2 = getitem_27 = None aot0_primals_1 = call_aot_bwd_prologue[0] aot0_primals_2 = call_aot_bwd_prologue[1] aot0_tangents_1 = call_aot_bwd_prologue[2] @@ -3682,7 +3704,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data): _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None return [] -""", # noqa: B950 +""", ) # https://github.com/pytorch/pytorch/issues/138920 @@ -3947,7 +3969,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data): call_accumulate_grad = torch__dynamo_external_utils_call_accumulate_grad(getitem_1, getitem_32, False); getitem_1 = getitem_32 = call_accumulate_grad = None _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None return [] -""", # noqa: B950 +""", ) self.check_output_and_recompiles( @@ -4027,7 +4049,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data): accumulate_grad__default = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_15); getitem_1 = getitem_15 = accumulate_grad__default = None _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None return [] -""", # noqa: B950 +""", ) self.check_output_and_recompiles( @@ -4107,7 +4129,7 @@ def forward(self, inputs, sizes, scalars, hooks, packed_data): accumulate_grad__default = torch.ops.inductor.accumulate_grad_.default(getitem_1, getitem_12); getitem_1 = getitem_12 = accumulate_grad__default = None _exec_final_callbacks_stub = torch__dynamo_external_utils__exec_final_callbacks_stub(); _exec_final_callbacks_stub = None return [] -""", # noqa: B950 +""", ) # 1 graph break on torch.load -> 2 dynamo graphs @@ -4616,7 +4638,7 @@ def check(grad): self.check_output_and_recompiles( fn, compiler_fn=make_compiler_fn(fullgraph=False), - count=[1, 2], + count=[1, 1], ) # Case 1.5.1: Dense variable gradient layout contract @@ -4885,7 +4907,7 @@ def backward(ctx, grad_output): self.check_output_and_recompiles( fn, compiler_fn=make_compiler_fn(fullgraph=False), - count=[1, 2], + count=[1, 1], ) # Case 3.1: Sparse variable_grad + Dense new_grad (reorder into Dense + Sparse) @@ -4982,7 +5004,7 @@ def backward(ctx, grad_output): self.check_output_and_recompiles( fn, compiler_fn=make_compiler_fn(fullgraph=False), - count=[1, 3], + count=[1, 2], ) def test_torch_function_mode(self): @@ -5031,7 +5053,7 @@ def fwd(x, y, z): _set_multithreading_enabled backward _set_multithreading_enabled""", - ) # noqa: B950 + ) def test_torch_dispatch_mode(self): called_funcs = [] @@ -5089,7 +5111,7 @@ def fwd(x, y, z): mul.Tensor new_empty_strided.default copy_.default""", - ) # noqa: B950 + ) def load_test_module(name): @@ -5270,6 +5292,7 @@ def tearDown(self): "test_nested_checkpoint_same_graph_early_stop_False", # dynamo disable "test_nested_checkpoint_same_graph_early_stop_True", # dynamo disable "test_nested_checkpoint_set_early_stop", # dynamo disable + "test_nested_checkpoint_set_early_stop_no_recompution_needed", # TorchDispatchMode causes frame skip "test_nested_checkpoint_two_children_early_stop_False", # dynamo disable "test_nested_checkpoint_two_children_early_stop_True", # dynamo disable "test_custom_autograd_ac_early_stop", # marked as skipped @@ -5417,16 +5440,26 @@ def tearDown(self): skipped_tests.add("test_checkpoint_automatic_dynamic_mark_dynamic_workaround") skipped_tests.add("test_checkpoint_automatic_dynamic_lru_disabled_workaround") +# boxed_grads_call relies on eager C++ PyNode::apply, incompatible with compiled autograd +skipped_tests.add("test_custom_function_boxed_grads") +skipped_tests.add("test_custom_function_boxed_grads_multi_output") +skipped_tests.add("test_custom_function_boxed_grads_no_extra_refs") +skipped_tests.add("test_custom_function_boxed_grads_cleanup_on_error") +skipped_tests.add("test_custom_function_boxed_grads_chain") +skipped_tests.add("test_custom_function_boxed_grads_none_grads") +skipped_tests.add("test_custom_function_boxed_grads_materialize_grads") +skipped_tests.add("test_custom_function_boxed_grads_direct_apply") +skipped_tests.add("test_custom_function_boxed_grads_single_list_arg") + test_autograd = load_test_module("test_autograd") test_custom_ops = load_test_module("test_custom_ops") test_higher_order_ops = load_test_module("dynamo/test_higher_order_ops") -if not HAS_XPU_AND_TRITON: - TestAutogradWithCompiledAutograd = wrap_test_class(test_autograd.TestAutograd) + +TestAutogradWithCompiledAutograd = wrap_test_class(test_autograd.TestAutograd) TestNestedCheckpointWithCompiledAutograd = wrap_test_class( test_autograd.TestNestedCheckpoint ) -if not HAS_XPU_AND_TRITON: - TestCustomOpWithCompiledAutograd = wrap_test_class(test_custom_ops.TestCustomOp) +TestCustomOpWithCompiledAutograd = wrap_test_class(test_custom_ops.TestCustomOp) HigherOrderOpTestsWithCompiledAutograd = wrap_test_class( test_higher_order_ops.HigherOrderOpTests ) @@ -5437,7 +5470,7 @@ def tearDown(self): test_higher_order_ops.ActivationCheckpointingTests ) -if torch.distributed.is_available() and HAS_CUDA_AND_TRITON: +if torch.distributed.is_available() and HAS_GPU_AND_TRITON: test_dtensor = load_test_module("distributed/tensor/test_dtensor_compile") TestDTensorCompileWithCompiledAutograd = wrap_test_class( test_dtensor.TestDTensorCompile diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index 225aae2ee23f5..590d7089c8d6f 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -58,18 +58,19 @@ optim_db, optims, ) -from torch.testing._internal.common_utils import parametrize, skipIfRocm, skipIfWindows +from torch.testing._internal.common_utils import ( + parametrize, + skipIfRocm, + skipIfWindows, + skipIfXpu, +) from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CPU, HAS_GPU, has_triton, ) -from torch.testing._internal.triton_utils import ( - requires_cuda_and_triton, - requires_gpu, - requires_gpu_and_triton, -) +from torch.testing._internal.triton_utils import requires_gpu, requires_gpu_and_triton def get_inputs(optim): @@ -255,8 +256,8 @@ class KernelCounts(NamedTuple): "test_adamax_tensor_lr_weight_decay_capturable_xpu": lambda x: assert_expected_inline(x, """6"""), "test_asgd_tensor_lr_weight_decay_maximize_capturable_cuda": lambda x: assert_expected_inline(x, """5"""), "test_asgd_tensor_lr_weight_decay_maximize_capturable_xpu": lambda x: assert_expected_inline(x, """5"""), - "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), # noqa: B950 - "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": lambda x: assert_expected_inline(x, """6"""), # noqa: B950 + "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_cuda": lambda x: assert_expected_inline(x, """6"""), + "test_nadam_tensor_lr_weight_decay_momentum_decay_decoupled_weight_decay_capturable_xpu": lambda x: assert_expected_inline(x, """6"""), "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_cuda": lambda x: assert_expected_inline(x, """6"""), "test_radam_tensor_lr_capturable_weight_decay_decoupled_weight_decay_xpu": lambda x: assert_expected_inline(x, """6"""), "test_sgd_tensor_lr_cpu": lambda x: assert_expected_inline(x, """2"""), @@ -591,7 +592,7 @@ class CompiledOptimizerParityTests(TestCase): @optims(optim_db, dtypes=[torch.float32]) @parametrize("use_closure", [True, False]) def test_correctness(self, device, dtype, optim_info, use_closure): - torch.cuda.manual_seed_all(0) + torch.get_device_module(device).manual_seed_all(0) torch.manual_seed(0) random.seed(0) optim_cls = optim_info.optim_cls @@ -931,7 +932,7 @@ def fn(xs, ys): self.assertLess(end - start, 90) - @requires_cuda_and_triton + @requires_gpu_and_triton def test_S429861(self): # Just verify we can compile this function without error try: @@ -947,7 +948,7 @@ def test_S429861(self): from torch._inductor.utils import fresh_cache with fresh_cache(): - kwargs = aot_graph_input_parser(forward) + kwargs = aot_graph_input_parser(forward, device=GPU_TYPE) torch.compile(forward)(**kwargs) @requires_gpu_and_triton @@ -992,7 +993,7 @@ def loop(): @skipIfRocm(msg="ROCm may have different numerical behavior") -@requires_cuda_and_triton +@requires_gpu_and_triton class CompiledOptimizerBitwiseTests(TestCase): """ Tests that compiled optimizers produce bitwise identical results to eager @@ -1085,7 +1086,8 @@ def _test_optimizer_bitwise( def _make_bitwise_test(optim_cls, kernel_count=None, **optim_kwargs): @skipIfRocm(msg="ROCm may have different numerical behavior") - @requires_cuda_and_triton + @skipIfXpu(msg="AttributeError, torch-xpu-ops: #2999") + @requires_gpu_and_triton @config.patch( { "score_fusion_memory_threshold": 1, diff --git a/test/inductor/test_config.py b/test/inductor/test_config.py index f617a4d12f11f..e78723e2df10b 100644 --- a/test/inductor/test_config.py +++ b/test/inductor/test_config.py @@ -7,7 +7,8 @@ from torch._inductor import config from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.test_case import run_tests, TestCase -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_TRITON +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_TRITON +from torch.testing._internal.triton_utils import requires_gpu def dummy_fn(x): @@ -327,6 +328,68 @@ def fn2(x, y): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0) self.assertEqual(counters["inductor"]["fxgraph_lookup_write_file"], 0) + @requires_gpu + @torch._inductor.config.patch(fx_graph_cache=False) + def test_config_read_in_backwards(self): + @torch.compile + def f(x, y): + z = x @ y + return z.sin().sum() + + called = False + + def my_pass(graph): + nonlocal called + called = True + + x, y = ( + torch.randn(3, 3, device=GPU_TYPE, requires_grad=True), + torch.randn(3, 3, device=GPU_TYPE), + ) + z = f(x, y) + z.backward() + self.assertFalse(called) + torch._dynamo.reset() + z = f(x, y) + with torch._inductor.config.patch(post_grad_custom_pre_pass=my_pass): + z.backward() + + self.assertTrue(called) + + called = False + torch._dynamo.reset() + z = f(x, y) + with torch._inductor.config.patch(post_grad_custom_pre_pass=my_pass): + torch.autograd.grad(z, x) + self.assertTrue(called) + + @torch._inductor.config.patch(fx_graph_cache=False) + def test_config_read_in_grad_fn(self): + @torch.compile + def f(x, y): + z = x @ y + return z.sin().sum() + + called = False + + def my_pass(graph): + nonlocal called + called = True + + x, y = ( + torch.randn(3, 3, requires_grad=True), + torch.randn(3, 3), + ) + + with torch._inductor.config.patch(post_grad_custom_pre_pass=my_pass): + z = f(x, y) + self.assertTrue(called) + + # Make sure the context gets cleared after forward pass + called = False + z.grad_fn.apply(torch.tensor(0)) + self.assertFalse(called) + if __name__ == "__main__": run_tests() diff --git a/test/inductor/test_control_deps.py b/test/inductor/test_control_deps.py index 36b365e4f7d82..9adc3458fce75 100644 --- a/test/inductor/test_control_deps.py +++ b/test/inductor/test_control_deps.py @@ -8,7 +8,7 @@ from torch.testing._internal.common_utils import IS_LINUX from torch.testing._internal.inductor_utils import ( GPU_TYPE, - HAS_CUDA_AND_TRITON, + HAS_GPU_AND_TRITON, requires_gpu, ) @@ -259,5 +259,5 @@ def add_control_deps(graph): if __name__ == "__main__": - if IS_LINUX and HAS_CUDA_AND_TRITON: + if IS_LINUX and HAS_GPU_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_control_flow.py b/test/inductor/test_control_flow.py index 45761543b3186..680b503e3f815 100644 --- a/test/inductor/test_control_flow.py +++ b/test/inductor/test_control_flow.py @@ -14,7 +14,6 @@ decorateIf, instantiate_parametrized_tests, parametrize, - skipIfXpu, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU from torch.testing._internal.triton_utils import requires_gpu @@ -385,7 +384,6 @@ def test_cond_unbacked_symint_closure(self, device, dynamic): dynamic=dynamic, ) - @skipIfXpu(msg="Remove this skip after issue #154949 resolved.") @requires_gpu def test_cond_control_flow_with_precomputed_size(self): class TestModel(torch.nn.Module): diff --git a/test/inductor/test_cooperative_reductions.py b/test/inductor/test_cooperative_reductions.py index b211e273c2a79..fa395dac4b591 100644 --- a/test/inductor/test_cooperative_reductions.py +++ b/test/inductor/test_cooperative_reductions.py @@ -19,7 +19,7 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU -class TestingHeuristics(InductorChoices): +class _TestingHeuristics(InductorChoices): def __init__(self, *, cooperative: bool, persistent: bool, cfg: dict[str, int]): super().__init__() self.cooperative = cooperative @@ -255,7 +255,7 @@ class MultiKernelCooperativeReductionTests(CooperativeReductionTests): class TestFixedConfigs(TestCase): def _check(self, fn, args, *, persistent=False, cooperative=True, cfg): expected = fn(*args) - heuristic = TestingHeuristics( + heuristic = _TestingHeuristics( persistent=persistent, cooperative=cooperative, cfg=cfg ) with torch._inductor.virtualized.V.set_choices_handler(heuristic): diff --git a/test/inductor/test_coordinate_descent_tuner.py b/test/inductor/test_coordinate_descent_tuner.py index c5b39f4491d0c..1bb1c5133ac59 100644 --- a/test/inductor/test_coordinate_descent_tuner.py +++ b/test/inductor/test_coordinate_descent_tuner.py @@ -5,6 +5,7 @@ from unittest import mock import torch +from torch._inductor.runtime import triton_heuristics from torch._inductor.runtime.hints import TRITON_MAX_BLOCK from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_utils import IS_LINUX @@ -113,6 +114,97 @@ def test_value_too_large(self): self.assertFalse(tuner.value_too_large("R0_BLOCK", max_block["R0_"])) self.assertTrue(tuner.value_too_large("R0_BLOCK", max_block["R0_"] * 2)) + def test_value_too_large_combo_field_limits(self): + tuner = CoordescTuner( + size_hints={"x": 2**20, "r0_": 2**20}, + inductor_meta={ + "combo_coordesc_field_limits": { + "XBLOCK_0": 64, + "XBLOCK_1": 256, + "R0_BLOCK_1": 128, + } + }, + ) + + self.assertFalse(tuner.value_too_large("XBLOCK_0", 64)) + self.assertTrue(tuner.value_too_large("XBLOCK_0", 128)) + self.assertFalse(tuner.value_too_large("XBLOCK_1", 256)) + self.assertTrue(tuner.value_too_large("XBLOCK_1", 512)) + self.assertFalse(tuner.value_too_large("R0_BLOCK_1", 128)) + self.assertTrue(tuner.value_too_large("R0_BLOCK_1", 256)) + + def test_combo_metadata_orders_larger_subkernels_first_for_coordesc(self): + def make_configs(xblock, yblock): + return [ + triton.Config( + {"XBLOCK": xblock, "YBLOCK": yblock}, + num_warps=4, + num_stages=1, + ) + ] + + inductor_meta = { + "combo_grid_meta": { + "num_kernels": 3, + "heuristic_0": "pointwise", + "heuristic_1": "pointwise", + "heuristic_2": "pointwise", + "size_hints_0": {"x": 64, "y": 64}, + "size_hints_1": {"x": 256, "y": 256}, + "size_hints_2": {"x": 128, "y": 16}, + "tile_hint_0": "TileHint.SQUARE", + "tile_hint_1": "TileHint.SQUARE", + "tile_hint_2": "TileHint.SQUARE", + "no_x_dim_0": False, + "no_x_dim_1": False, + "no_x_dim_2": False, + } + } + + configs_by_size = { + (64, 64): make_configs(64, 32), + (256, 256): make_configs(256, 64), + (128, 16): make_configs(128, 16), + } + + def pointwise_side_effect(size_hints, *args, **kwargs): + return configs_by_size[(size_hints["x"], size_hints["y"])] + + with mock.patch.object( + triton_heuristics, + "pointwise", + side_effect=pointwise_side_effect, + ): + configs = triton_heuristics._handle_combo_kernel_per_subkernel_blocks( + {"x": 256, "y": 256}, + inductor_meta, + triton_meta={}, + ) + + self.assertIsNotNone(configs) + self.assertEqual( + inductor_meta["combo_coordesc_field_order"], + [ + "XBLOCK_1", + "YBLOCK_1", + "XBLOCK_0", + "YBLOCK_0", + "XBLOCK_2", + "YBLOCK_2", + ], + ) + self.assertEqual( + inductor_meta["combo_coordesc_field_limits"], + { + "XBLOCK_0": 64, + "YBLOCK_0": 64, + "XBLOCK_1": 256, + "YBLOCK_1": 256, + "XBLOCK_2": 128, + "YBLOCK_2": 16, + }, + ) + if __name__ == "__main__": if IS_LINUX and HAS_GPU: diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 09ac55d097572..687a9db681c4a 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -711,6 +711,10 @@ def _test_lstm_packed( ] ), ) + @unittest.skipIf( + IS_ARM64 and not IS_CPU_EXT_SVE_SUPPORTED, + "flaky on AArch64 (no SVE)", + ) def test_lstm_packed( self, unbatched, @@ -758,6 +762,10 @@ def test_lstm_packed( "unbatched, input_size, hidden_size, num_layers, bidirectional, bias, empty_state, batch_first, batch_size, seq_len", _test_lstm_packed_change_input_sizes_cpu_params, ) + @unittest.skipIf( + IS_ARM64 and not IS_CPU_EXT_SVE_SUPPORTED, + "flaky on AArch64 (no SVE)", + ) def test_lstm_packed_change_input_sizes_cpu( self, unbatched, @@ -1492,7 +1500,12 @@ def test_decomposed_dequant_relu_quant_uint8(self): def test_decomposed_dequant_relu_quant_int8(self): self._test_decomposed_dequant_relu_quant_helper(torch.int8) - def _test_dequant_quant_lowering_helper(self, dtype, dequant_out_dtype=None): + def _test_dequant_quant_lowering_helper( + self, + dtype, + input_dtype=torch.float32, + dequant_out_dtype=None, + ): def fn( x, scale, @@ -1553,7 +1566,7 @@ def fn( use_tensor_overload_list, ): x = torch.clamp( - torch.randn((1, 7, 7, 9), dtype=torch.float32) * 100, + torch.randn((1, 7, 7, 9), dtype=input_dtype) * 100, quant_min, quant_max, ) @@ -1590,6 +1603,9 @@ def fn( @requires_vectorization def test_dequant_quant_lowering_uint8(self): self._test_dequant_quant_lowering_helper(torch.uint8) + self._test_dequant_quant_lowering_helper( + torch.uint8, input_dtype=torch.bfloat16 + ) self._test_dequant_quant_lowering_helper( torch.uint8, dequant_out_dtype=torch.bfloat16 ) @@ -1597,6 +1613,7 @@ def test_dequant_quant_lowering_uint8(self): @requires_vectorization def test_dequant_quant_lowering_int8(self): self._test_dequant_quant_lowering_helper(torch.int8) + self._test_dequant_quant_lowering_helper(torch.int8, input_dtype=torch.bfloat16) self._test_dequant_quant_lowering_helper( torch.int8, dequant_out_dtype=torch.bfloat16 ) @@ -3654,6 +3671,36 @@ def fn2(x): f"Expected generated_cpp_vec_kernel_count == 1, got {metrics.generated_cpp_vec_kernel_count}" ) + @requires_vectorization + def test_argmax_argmin_cpptile2d_2d_input(self): + def fn(a, b): + return (a + b).max(dim=1) + + def fn_min(a, b): + return (a + b).min(dim=1) + + torch.manual_seed(0) + for sz in [8, 32, 35]: + for f in [fn, fn_min]: + torch._dynamo.reset() + a = torch.randn(sz, sz).transpose(0, 1) + b = torch.randn(sz, sz) + self.common(f, (a, b)) + + @requires_vectorization + def test_argmax_argmin_cpptile2d_3d_input(self): + def fn_3d(a, b): + return (a + b).max(dim=2) + + def fn_3d_min(a, b): + return (a + b).min(dim=2) + + for f in [fn_3d, fn_3d_min]: + torch._dynamo.reset() + a = torch.randn(4, 16, 16).permute(0, 2, 1) + b = torch.randn(4, 16, 16) + self.common(f, (a, b)) + # Currently, we enabled AVX2 and AVX512 for vectorization. If the platform is not # supported, the vectorization will not work and skip this test case. For ARM or # other platforms support, we just need to add the ISA info to the supported_vector_isa @@ -4421,6 +4468,92 @@ def forward(self, x): res2 = jit_func(x) self.assertEqual(res1, res2) + def test_sdpa_closure_mask_recompile(self): + # Regression test for closure-composed attention masks on CPU Inductor. + def causal_fn(b, h, q, kv): + return kv <= q + + def padding_fn(padding_mask): + def inner(b, h, q, kv): + return padding_mask[b, kv] + + return inner + + def and_masks(f1, f2): + def combined(b, h, q, kv): + return f1(b, h, q, kv) & f2(b, h, q, kv) + + return combined + + def make_mask_closure(batch_size, q_len, kv_len, q_offset, padding_mask): + fn = and_masks(causal_fn, padding_fn(padding_mask)) + b = torch.arange(batch_size)[:, None, None, None] + h = torch.arange(1)[None, :, None, None] + q = (torch.arange(q_len) + q_offset)[None, None, :, None] + kv = torch.arange(kv_len)[None, None, None, :] + return fn(b, h, q, kv).expand(batch_size, 1, q_len, kv_len) + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.proj = torch.nn.Linear(32, 32, bias=False) + + def forward(self, x, past_k, past_v, padding_mask, q_offset): + B, S, _ = x.shape + q = self.proj(x).view(B, S, 4, 8).transpose(1, 2) + k = x.view(B, S, 4, 8).transpose(1, 2) + v = k.clone() + if past_k is not None: + k = torch.cat([past_k, k], dim=2) + v = torch.cat([past_v, v], dim=2) + + mask = make_mask_closure(B, S, k.shape[2], q_offset, padding_mask) + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) + return out.transpose(1, 2).reshape(B, S, 32), k, v + + torch._dynamo.reset() + torch.manual_seed(0) + + eager_model = Model().eval() + compiled_model = Model().eval() + compiled_model.load_state_dict(eager_model.state_dict()) + compiled = torch.compile(compiled_model, backend="inductor") + + pad_mask = torch.ones(1, 8, dtype=torch.bool) + pad_mask[0, :2] = False + + # Prefill: seq_len=8 + x = torch.randn(1, 8, 32) + with torch.no_grad(): + ref_out, ref_k, ref_v = eager_model(x, None, None, pad_mask, q_offset=0) + out, k, v = compiled(x, None, None, pad_mask, q_offset=0) + torch.testing.assert_close(out, ref_out) + torch.testing.assert_close(k, ref_k) + torch.testing.assert_close(v, ref_v) + + # Decode: seq_len=1 and kv grows, which exercises recompilation. + for _ in range(3): + pad_mask = torch.cat([pad_mask, torch.ones(1, 1, dtype=torch.bool)], dim=1) + x = torch.randn(1, 1, 32) + with torch.no_grad(): + ref_out, ref_k, ref_v = eager_model( + x, + ref_k, + ref_v, + pad_mask, + q_offset=ref_k.shape[2], + ) + out, k, v = compiled( + x, + k, + v, + pad_mask, + q_offset=k.shape[2], + ) + torch.testing.assert_close(out, ref_out) + torch.testing.assert_close(k, ref_k) + torch.testing.assert_close(v, ref_v) + def test_scalar_mul_bfloat16(self): def f(x): return torch.ops.aten.mul.Tensor(x, 1.7015043497085571) @@ -4638,6 +4771,35 @@ def fn(x, weight, bias): torch.testing.assert_close(weight_cmp.grad, weight_ref.grad) torch.testing.assert_close(bias_cmp.grad, bias_ref.grad) + @torch._dynamo.config.patch( + capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True + ) + def test_backward_dynamic_item_from_size1_tensor(self): + def fn(x, w): + num = x.nonzero().numel() + num = x.new_tensor([num]).item() + y = w * x + return y.sum() / num + + torch._dynamo.reset() + metrics.reset() + + x = torch.tensor([0.0, 1.0, -1.0, 0.5]) + w_ref = torch.tensor(0.0, requires_grad=True) + w_cmp = w_ref.detach().clone().requires_grad_(True) + + eager_out = fn(x, w_ref) + eager_out.backward() + + compiled = torch.compile(fn, backend="inductor", fullgraph=True, dynamic=True) + compiled_out = compiled(x, w_cmp) + compiled_out.backward() + + self.assertEqual(eager_out.shape, torch.Size([])) + self.assertEqual(compiled_out.shape, eager_out.shape) + torch.testing.assert_close(compiled_out, eager_out) + torch.testing.assert_close(w_cmp.grad, w_ref.grad) + @config.patch(emulate_precision_casts=True) def test_emulate_precision_casts_cpp_backend_no_error(self): """ @@ -5040,6 +5202,26 @@ def fn(x): "at::vec::VectorizedN::loadu", 2, exactly=True ).run(code) + @requires_vectorization + def test_indirect_assert_scalar_mask_tail_vec_no_crash(self): + # https://github.com/pytorch/pytorch/issues/178136 + def fn(positions, cache): + x = cache[positions] + y = x[0].clone() + y[..., 1::3] = x[1, ..., 1::3] + y[..., 2::3] = x[2, ..., 2::3] + return y + + positions = torch.tensor([[0, 0], [0, 0], [0, 0]], dtype=torch.int64) + cache = torch.arange(3, dtype=torch.float32).reshape(1, 3) + + with config.patch({"cpp.enable_loop_tail_vec": True}): + expected = fn(positions, cache) + compiled_fn = torch.compile(fn, fullgraph=True) + actual = compiled_fn(positions, cache) + + torch.testing.assert_close(actual, expected) + def test_uint64_pointwise_vec(self): def fn(x): return x * x @@ -5924,6 +6106,31 @@ def func1(arg0_float, arg1_bf16): "Expected convert in generated code for func1", ) + @requires_vectorization + def test_bool_to_float8_e4m3fn(self): + """ + Test that bool to float8_e4m3fn cast succeeds. + Issue: https://github.com/pytorch/pytorch/issues/178095 + """ + + def fn(x): + return x.to(dtype=torch.float8_e4m3fn) + + x = torch.ones(64, dtype=torch.bool) + self.common(fn, (x,)) + + @requires_vectorization + def test_bool_to_float8_e5m2(self): + """ + Test that bool to float8_e5m2 cast succeeds. + """ + + def fn(x): + return x.to(dtype=torch.float8_e5m2) + + x = torch.ones(64, dtype=torch.bool) + self.common(fn, (x,)) + @config.patch("cpp.simdlen", 256) @requires_vectorization def test_avx2_bool_constant_pad_nd(self): @@ -6218,6 +6425,37 @@ def forward(x): ) ) + def test_mutation_transpose_reshape_ordering(self): + def fn(x, y): + return x.add_(y).reshape(-1, 1, 3) + + torch.manual_seed(0) + x1 = torch.ones([1, 2, 3]).transpose(1, 2) + y1 = torch.randn(1, 3, 1) + + x2 = x1.clone() + y2 = y1.clone() + + cfunc = torch.compile(fn, backend="inductor") + + out1 = fn(x1, y1) + out2 = cfunc(x2, y2) + + self.assertTrue(torch.allclose(out1, out2, equal_nan=True)) + + def test_indirect_index_transposed_tensor(self): + # https://github.com/pytorch/pytorch/issues/178521 + # transpose_mxn was referencing a tmp variable defined inside the inner + # loop when the index used indirect (SymT.TMP) indexing. + def f(buf, idx): + return buf[torch.arange(buf.shape[0]), idx, :] + + buf = torch.randn(16, 2, 16).permute(2, 1, 0) + idx = torch.randint(0, 2, (16,)) + expected = f(buf, idx) + actual = torch.compile(f, backend="inductor")(buf, idx) + self.assertEqual(actual, expected) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 845ad60ae0525..befadb7a331a7 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -26,10 +26,13 @@ _calculate_dynamic_per_channel_qparams, ) from torch.testing._internal.common_utils import ( + IS_ARM64, + IS_CPU_EXT_SVE_SUPPORTED, IS_MACOS, IS_WINDOWS, parametrize, TEST_MKL, + xfailIf, ) @@ -58,15 +61,17 @@ def skip_cache(self, choices, name, key, benchmark, hint_override=None): timings = benchmark(choices) for choice, timing in timings.items(): if isinstance(choice, select_algorithm.ExternKernelCaller): - # we intentionally make ATEN kernel slower to cover the cases + # We intentionally make ATEN kernel slower to cover the cases # where template kernels are always chosen with fusions applied - # and correctness checks at runtime. - timings[choice] = timing * 1000 + # and correctness checks at runtime. On k8s ARC runner pods, + # CPU contention from parallel tests can make cpp template + # benchmarks 2-3x slower than normal, so the multiplier needs + # to be large enough to still exceed those inflated times. + timings[choice] = timing * 1000000 return timings for patcher in [ dynamo_config.patch(verbose=True), - dynamo_config.patch(inline_inbuilt_nn_modules=True), inductor_config.patch( debug=True, max_autotune=True, @@ -801,7 +806,8 @@ def forward(self, arg7_1): self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0) @unittest.skipIf( - not torch.cpu._is_amx_tile_supported(), "AMX ISA support is required" + not isinstance(torch._inductor.cpu_vec_isa.pick_vec_isa(), VecAMX), + "AMX ISA support is required", ) @inductor_config.patch({"freezing": True}) @patches @@ -1524,7 +1530,8 @@ def forward(self, x0, x1, other): ) @unittest.skipIf( - not torch.cpu._is_amx_tile_supported(), "AMX ISA support is required" + not isinstance(torch._inductor.cpu_vec_isa.pick_vec_isa(), VecAMX), + "AMX ISA support is required", ) @inductor_config.patch({"freezing": True}) @patches @@ -1630,6 +1637,9 @@ def forward(self, qx, x_scale, x_zp): ) @parametrize("in_features", (128, 144, 1024)) @parametrize("out_features", (64, 65, 1024)) + @unittest.skipIf( + IS_ARM64 and not IS_CPU_EXT_SVE_SUPPORTED, "flaky on AArch64 (no SVE)" + ) def test_int8_woq_mm(self, dtype, batch_size, mid_dim, in_features, out_features): def _convert_weight_to_int8pack(w): scale, zp = _calculate_dynamic_per_channel_qparams( @@ -1692,6 +1702,9 @@ def forward(self, x, scale): ) @parametrize("in_features", (128,)) @parametrize("out_features", (64,)) + @unittest.skipIf( + IS_ARM64 and not IS_CPU_EXT_SVE_SUPPORTED, "flaky on AArch64 (no SVE)" + ) def test_int8_woq_mm_concat( self, dtype, batch_size, mid_dim, in_features, out_features ): @@ -1757,7 +1770,8 @@ def forward(self, x, scale1, scale2, scale3): self._check_amx_counter(vec_amx) @unittest.skipIf( - not torch.cpu._is_amx_tile_supported(), "AMX ISA support is required" + not isinstance(torch._inductor.cpu_vec_isa.pick_vec_isa(), VecAMX), + "AMX ISA support is required", ) @inductor_config.patch({"freezing": True}) @patches @@ -1841,9 +1855,8 @@ def forward(self, a): vec_amx = VecAMX() self._check_amx_counter(vec_amx) - if torch.cpu._is_amx_tile_supported(): - # Only AMX ISA based micro-kernel is currently supported for da8w8 - self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) + # Only AMX ISA based micro-kernel is currently supported for da8w8 + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) @inductor_config.patch({"freezing": True}) @patches @@ -1886,7 +1899,8 @@ def forward(self, x): ) @unittest.skipIf( - not torch.cpu._is_amx_tile_supported(), "AMX ISA support is required" + not isinstance(torch._inductor.cpu_vec_isa.pick_vec_isa(), VecAMX), + "AMX ISA support is required", ) @inductor_config.patch({"freezing": True}) @patches @@ -1957,7 +1971,8 @@ def forward(self, x): self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) @unittest.skipIf( - not torch.cpu._is_amx_tile_supported(), "AMX ISA support is required" + not isinstance(torch._inductor.cpu_vec_isa.pick_vec_isa(), VecAMX), + "AMX ISA support is required", ) @inductor_config.patch({"freezing": True}) @inductor_config.patch({"cpp.use_small_dequant_buffer": True}) @@ -2000,9 +2015,10 @@ def forward(self, x): _target_code_check = f"constexpr int64_t Kc_blocks = {group_size // kr};" torch._C.FileCheck().check(_target_code_check).run(code) - @unittest.expectedFailure # Int4 kernel numerical errors (5.4x rel diff, 5.8% mismatch) + @xfailIf(IS_ARM64) # Int4 kernel numerical errors (5.4x rel diff, 5.8% mismatch) @unittest.skipIf( - not torch.cpu._is_amx_tile_supported(), "AMX ISA support is required" + not isinstance(torch._inductor.cpu_vec_isa.pick_vec_isa(), VecAMX), + "AMX ISA support is required", ) @inductor_config.patch({"freezing": True}) @patches @@ -2071,9 +2087,10 @@ def forward(self, x): ) self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) - @unittest.expectedFailure # Int4 kernel numerical errors (43.5x rel diff, 10.7% mismatch) + @xfailIf(IS_ARM64) # Int4 kernel numerical errors (43.5x rel diff, 10.7% mismatch) @unittest.skipIf( - not torch.cpu._is_amx_tile_supported(), "AMX ISA support is required" + not isinstance(torch._inductor.cpu_vec_isa.pick_vec_isa(), VecAMX), + "AMX ISA support is required", ) @inductor_config.patch({"freezing": True}) @inductor_config.patch({"cpp.enable_concat_linear": True}) @@ -3066,7 +3083,8 @@ def forward(self, u, v): self.common(mod, (u, v)) @unittest.skipIf( - not torch.cpu._is_amx_tile_supported(), "AMX ISA support is required" + not isinstance(torch._inductor.cpu_vec_isa.pick_vec_isa(), VecAMX), + "AMX ISA support is required", ) @inductor_config.patch({"freezing": True}) @patches diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index f43960f064e79..f301da66d5c4e 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -31,6 +31,7 @@ from torch.testing import FileCheck from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, + PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, SM80OrLater, SM90OrLater, TEST_MULTIGPU, @@ -44,8 +45,11 @@ parametrize, skipIfRocm, skipIfRocmArch, + skipIfXpu, + TEST_CUDA, TEST_WITH_ASAN, TEST_WITH_ROCM, + TEST_XPU, xfailIfROCm, ) from torch.testing._internal.inductor_utils import IS_BIG_GPU @@ -89,14 +93,21 @@ aten = torch.ops.aten +device_type = ( + acc.type + if (acc := torch.accelerator.current_accelerator(check_available=True)) + else "cpu" +) + + @instantiate_parametrized_tests class CudaReproTests(TestCase): - device = "cuda" + device = device_type common = check_model_cuda def test_mm_out_dtype_compile(self): - a = torch.randn(1, 3, device="cuda", dtype=torch.float16) - b = torch.randn(3, 2, device="cuda", dtype=torch.float16) + a = torch.randn(1, 3, device=device_type, dtype=torch.float16) + b = torch.randn(3, 2, device=device_type, dtype=torch.float16) def fn(x, y): return torch.mm(x, y, out_dtype=torch.float32) @@ -135,7 +146,8 @@ def forward( (torch.Size([512, 768]), torch.float16), ] inps = [torch.zeros(())] + [ - torch.ones(shape, dtype=dtype, device="cuda") for (shape, dtype) in inps + torch.ones(shape, dtype=dtype, device=device_type) + for (shape, dtype) in inps ] mod = make_fx(forward)(*inps) compiled = compile_fx_inner(mod, inps) @@ -167,8 +179,8 @@ def forward(self, x: torch.Tensor): return {"ten0": out0, "ten1": out1} torch.manual_seed(0) - model = ReproModule().cuda() - inputs = torch.randn(36, 9, 7, 16, device="cuda", requires_grad=True) + model = ReproModule().to(device_type) + inputs = torch.randn(36, 9, 7, 16, device=device_type, requires_grad=True) eager_out = model(inputs) compiled_model = torch.compile( @@ -182,6 +194,7 @@ def forward(self, x: torch.Tensor): self.assertEqual(compiled_out["ten0"], eager_out["ten0"]) self.assertEqual(compiled_out["ten1"], eager_out["ten1"]) + @skipIfXpu(msg="RuntimeError, torch-xpu-ops: 2891") def test_effn_attn_bias_padding(self): batch_size, num_heads, seq_len, head_dim = 2, 32, 512, 128 @@ -207,15 +220,19 @@ def fn( scale=None, ) - query = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda") - key = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda") - value = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda") + query = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device_type + ) + key = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device_type) + value = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device_type + ) - input_tensor = torch.rand([2, 1, seq_len, 1], device="cuda") + input_tensor = torch.rand([2, 1, seq_len, 1], device=device_type) out, code = run_and_get_code(torch.compile(fn), query, key, value, input_tensor) - input_tensor2 = torch.rand([2, 32, seq_len, seq_len], device="cuda").copy_( + input_tensor2 = torch.rand([2, 32, seq_len, seq_len], device=device_type).copy_( input_tensor ) # even though the last dim is broadcasted, needs stride 1 for alignment @@ -230,6 +247,7 @@ def fn( # Greatest absolute difference: 0.07861328125 at index (14, 13, 1008, 36) (up to 1e-05 allowed) # Greatest relative difference: 2.90625 at index (14, 13, 1008, 36) (up to 0.016 allowed) @skipIfRocmArch(MI350_ARCH) + @skipIfXpu(msg="RuntimeError, torch-xpu-ops: 2697") def test_effn_attn_bias_padding_misaligned(self): seqlen_start = 1008 @@ -238,10 +256,18 @@ def test_effn_attn_bias_padding_misaligned(self): torch._dynamo.reset() bsz = 32 - q = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda") - k = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda") - v = torch.randn(bsz, 16, seqlen, 64, dtype=torch.bfloat16, device="cuda") - mask = torch.ones([bsz, 1, seqlen, seqlen], dtype=torch.bool, device="cuda") + q = torch.randn( + bsz, 16, seqlen, 64, dtype=torch.bfloat16, device=device_type + ) + k = torch.randn( + bsz, 16, seqlen, 64, dtype=torch.bfloat16, device=device_type + ) + v = torch.randn( + bsz, 16, seqlen, 64, dtype=torch.bfloat16, device=device_type + ) + mask = torch.ones( + [bsz, 1, seqlen, seqlen], dtype=torch.bool, device=device_type + ) inputs = [q, k, v, mask] def f(q, k, v, mask): @@ -262,12 +288,66 @@ def f(q, k, v, mask): self.assertEqual(out, f(*inputs)) + @unittest.skipIf( + not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, + "Does not support mem_eff_attention", + ) + def test_mha_mem_eff_attention_backward_large_batch_compile(self): + torch.manual_seed(0) + + eager = nn.MultiheadAttention( + embed_dim=8, + num_heads=1, + batch_first=True, + device=device_type, + dtype=torch.float16, + ) + compiled = copy.deepcopy(eager) + compiled = torch.compile(compiled) + + q = torch.rand(2**16, 1, 8, device=device_type, dtype=torch.float16) + kv = torch.rand(2**16, 2, 8, device=device_type, dtype=torch.float16) + mask = torch.randint(0, 2, (2**16, 2), device=device_type, dtype=torch.bool) + + eager_q = q.detach().clone().requires_grad_(True) + eager_kv = kv.detach().clone().requires_grad_(True) + compiled_q = q.detach().clone().requires_grad_(True) + compiled_kv = kv.detach().clone().requires_grad_(True) + + eager_out = eager( + query=eager_q, + key=eager_kv, + value=eager_kv, + key_padding_mask=mask, + attn_mask=None, + need_weights=False, + )[0] + eager_out.sum().backward() + + compiled_out = compiled( + query=compiled_q, + key=compiled_kv, + value=compiled_kv, + key_padding_mask=mask, + attn_mask=None, + need_weights=False, + )[0] + compiled_out.sum().backward() + + self.assertEqual(compiled_out, eager_out) + self.assertEqual(compiled_q.grad, eager_q.grad) + self.assertEqual(compiled_kv.grad, eager_kv.grad) + def test_input_channels_last(self): m = torch.nn.Sequential( torch.nn.Conv2d(3, 3, 1, 1), ToTuple(), - ).cuda() - inp = torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda() + ).to(device_type) + inp = ( + torch.randn([2, 3, 16, 16]) + .to(memory_format=torch.channels_last) + .to(device_type) + ) self.common( m, @@ -291,10 +371,12 @@ def forward(self, x, y): return [permute, add] inps = [ - rand_strided((12, 3, 512, 64), (64, 196608, 768, 1), torch.float32, "cuda"), + rand_strided( + (12, 3, 512, 64), (64, 196608, 768, 1), torch.float32, device_type + ), rand_strided((), (), torch.int64, "cpu"), ] - mod = make_fx(Repro().to(device="cuda"))(*inps) + mod = make_fx(Repro().to(device=device_type))(*inps) compiled = compile_fx_inner(mod, inps) compiled(inps) @@ -305,7 +387,7 @@ def test_backward_context(self): def fn(x): return x * 3 - x = torch.randn(4, device="cuda", requires_grad=True) + x = torch.randn(4, device=device_type, requires_grad=True) gO = torch.rand_like(x) opt_fn = torch.compile(fn) out = opt_fn(x) @@ -317,7 +399,7 @@ def forward(): randn = torch.ops.aten.randn.default( [12, 64, 1, 64], dtype=torch.float32, - device=torch.device(type="cuda", index=0), + device=torch.device(type=device_type, index=0), pin_memory=False, ) unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(randn, -1) @@ -325,9 +407,9 @@ def forward(): mod = make_fx(forward)() compiled = compile_fx_inner(mod, ()) - if compiled([])[0].device.type != "cuda": + if compiled([])[0].device.type != device_type: raise AssertionError( - f"Expected device type 'cuda', got {compiled([])[0].device.type!r}" + f"Expected device type {device_type}, got {compiled([])[0].device.type!r}" ) @config.patch({"triton.cudagraphs": True}) @@ -343,7 +425,7 @@ def forward(self): 1, dtype=torch.float32, layout=torch.strided, - device=torch.device(type="cuda", index=0), + device=torch.device(type=device_type, index=0), pin_memory=False, ) full_1 = torch.ops.aten.full.default( @@ -351,7 +433,7 @@ def forward(self): 0, dtype=torch.int64, layout=torch.strided, - device=torch.device(type="cuda", index=0), + device=torch.device(type=device_type, index=0), pin_memory=False, ) return (full_1, full) @@ -366,8 +448,8 @@ def fn(x, y): return x + y inputs = ( - rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"), - rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"), + rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device=device_type), + rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device=device_type), ) self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1])) @@ -386,14 +468,14 @@ def fn(x, y): return r, r.size(0) inputs = ( - torch.randn((5, 5), device="cuda"), - torch.randn((5, 5), device="cuda"), + torch.randn((5, 5), device=device_type), + torch.randn((5, 5), device=device_type), ) self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 5))) inputs = ( - torch.randn((6, 6), device="cuda"), - torch.randn((6, 6), device="cuda"), + torch.randn((6, 6), device=device_type), + torch.randn((6, 6), device=device_type), ) self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 6))) @@ -412,17 +494,59 @@ def max(x): print(f"compile {ms_c=:.03f}, eager {ms_eager=:.03f}") def test_split_reduction_transposed(self): - x = torch.randn(4096, 8192, dtype=torch.bfloat16, device="cuda") + x = torch.randn(4096, 8192, dtype=torch.bfloat16, device=device_type) x = x.t().contiguous().t() self._test_split_reduction_impl(x) def test_split_reduction_channels_last(self): - x = torch.randn(4096, 8192, dtype=torch.bfloat16, device="cuda") + x = torch.randn(4096, 8192, dtype=torch.bfloat16, device=device_type) x = x.reshape([256, 256, 256, 2]).to(memory_format=torch.channels_last) self._test_split_reduction_impl(x) + def test_split_with_sizes_reshape_cat_cantsplit_regression(self): + class Repro(nn.Module): + def __init__(self): + super().__init__() + self.embedding = nn.Embedding(num_embeddings=128, embedding_dim=32) + self.linear = nn.Linear(32, 96) + self.conv = nn.Conv2d(3, 16, 3, padding=1) + + def forward(self, x, indices): + embedded = self.embedding(indices) + linear_out = self.linear(embedded) + conv_out = self.conv(x) + batch_size = conv_out.shape[0] + conv_flat = conv_out.view(batch_size, -1) + seq_out = linear_out[:, -1, :] + combined = torch.cat([seq_out, conv_flat], dim=1) + chunks = torch.ops.aten.split_with_sizes.default( + combined, split_sizes=[32, 64, combined.size(1) - 96], dim=1 + ) + chunk0_reshaped = torch.ops.aten.reshape.default( + chunks[0], (batch_size, 4, 8) + ) + chunk1_reshaped = torch.ops.aten.reshape.default( + chunks[1], (batch_size, 8, 8) + ) + chunk2_reshaped = torch.ops.aten.reshape.default( + chunks[2], (batch_size, -1, 8) + ) + return torch.ops.aten.cat.default( + [chunk0_reshaped, chunk1_reshaped, chunk2_reshaped], dim=1 + ) + + model = Repro().to(device_type) + x = torch.randn(2, 3, 32, 32, dtype=torch.float32, device=device_type) + indices = torch.randint(0, 128, (2, 10), dtype=torch.long, device=device_type) + + with torch.no_grad(): + eager_out = model(x, indices) + compiled_out = torch.compile(model)(x, indices) + + torch.testing.assert_close(compiled_out, eager_out) + @config.patch({"emulate_precision_casts": True}) def test_bool_emulate_low_precision(self): from torch import device @@ -439,7 +563,7 @@ def forward(): pin_memory=False, ) device_put_3 = torch.ops.prims.device_put.default( - full_1, device(type="cuda", index=0) + full_1, device(type=device_type, index=0) ) full_1 = None @@ -460,13 +584,13 @@ def forward(): view_15 = torch.ops.aten.reshape.default(clone, [1536, 1536]) clone = None scalar_tensor = torch.ops.aten.scalar_tensor.default( - -inf, dtype=torch.float16, device=device(type="cuda", index=0) + -inf, dtype=torch.float16, device=device(type=device_type, index=0) ) scalar_tensor_1 = torch.ops.aten.scalar_tensor.default( 0.0, dtype=torch.float16, layout=torch.strided, - device=device(type="cuda", index=0), + device=device(type=device_type, index=0), ) where = torch.ops.aten.where.self(view_15, scalar_tensor_1, scalar_tensor) view_15 = scalar_tensor_1 = scalar_tensor = None @@ -482,7 +606,9 @@ def test_emulate_low_precision(self): def foo(x): return torch.nn.functional.gelu(x) * 10.0 - inp = torch.rand([32], device="cuda", requires_grad=True, dtype=torch.bfloat16) + inp = torch.rand( + [32], device=device_type, requires_grad=True, dtype=torch.bfloat16 + ) out, codes = run_fw_bw_and_get_code(lambda: torch.compile(foo)(inp)) # fwd, backward @@ -525,8 +651,8 @@ def fn(x, y): return x + y inputs = ( - rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"), - rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"), + rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device=device_type), + rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device=device_type), ) self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1])) @@ -547,11 +673,11 @@ def forward(self, x): from copy import deepcopy - model = Repro().cuda() + model = Repro().to(device_type) model_ref = deepcopy(model) model_opt = torch.compile(model, backend="inductor") - input = torch.randn(10, 10, device="cuda", requires_grad=True) + input = torch.randn(10, 10, device=device_type, requires_grad=True) for _ in range(2): output_ref = model_ref(input) @@ -574,7 +700,7 @@ def foo(x): foo_opt = torch.compile(foo, backend="inductor") - inpt = torch.randn(10, 10, device="cuda", requires_grad=True) + inpt = torch.randn(10, 10, device=device_type, requires_grad=True) # TODO: this is broken, fix later # out = foo_opt(inpt) # out.add_(2) @@ -602,14 +728,14 @@ def forward(self, start_positions: torch.Tensor, x: torch.Tensor): ) return cross_entropy - mod = Repro().cuda() + mod = Repro().to(device_type) opt_mod = torch.compile(mod, backend="inductor") mod.eval() opt_mod.eval() args = [ - ((1,), (1,), torch.int64, "cuda", False), - ((1, 128, 768), (98304, 768, 1), torch.float32, "cuda", True), + ((1,), (1,), torch.int64, device_type, False), + ((1, 128, 768), (98304, 768, 1), torch.float32, device_type, True), ] args = [ rand_strided(sh, st, dt, dev).requires_grad_(rg) @@ -628,7 +754,7 @@ def forward(add_1): getitem_1 = var_mean[1] return getitem_1 - x = torch.randn(1, 8, 768, device="cuda") + x = torch.randn(1, 8, 768, device=device_type) correct = forward(x) actual = torch.compile(forward, fullgraph=True)(x) self.assertEqual(actual, correct) @@ -640,7 +766,7 @@ def forward(x): 0, dtype=torch.float64, layout=torch.strided, - device="cuda", + device=device_type, pin_memory=False, ) return x + full_10.to("cpu") @@ -657,7 +783,11 @@ def test_autotune_inplace_kernel(self): https://github.com/triton-lang/triton/issues/781 https://github.com/pytorch/torchdynamo/issues/1670 """ - from torch._C import _cuda_getCurrentRawStream as get_cuda_stream + if TEST_XPU: + from torch._C import _xpu_getCurrentRawStream as get_gpu_stream + else: + from torch._C import _cuda_getCurrentRawStream as get_gpu_stream + from torch._inductor.runtime.hints import AttrsDescriptorWrapper, HeuristicType from torch._inductor.runtime.triton_heuristics import CachingAutotuner from torch._inductor.utils import triton_version_uses_attrs_dict @@ -695,7 +825,7 @@ def decorator(fn): "in_ptr0": "*fp32", "xnumel": "i32", }, - "device": DeviceProperties.create(torch.device("cuda")), + "device": DeviceProperties.create(torch.device(device_type)), "configs": [ AttrsDescriptorWrapper(divisible_by_16=(0, 1), equal_to_1=()) ], @@ -714,11 +844,11 @@ def kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr): tl.store(in_out_ptr0 + offsets, output, mask=mask) xnumel = 384 - in0 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32) - inout1 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32) + in0 = rand_strided((xnumel,), (1,), device=device_type, dtype=torch.float32) + inout1 = rand_strided((xnumel,), (1,), device=device_type, dtype=torch.float32) inout2 = inout1.clone() - stream0 = get_cuda_stream(0) + stream0 = get_gpu_stream(0) kernel.run(inout1, in0, xnumel, stream=stream0) kernel.run(inout2, in0, xnumel, stream=stream0) @@ -734,7 +864,7 @@ def forward(pred_objectness_logits_3_: torch.Tensor): getitem_12 = sort_3[0] return getitem_12 - args = [((1, 100), (0, 1), torch.float16, "cuda", False)] + args = [((1, 100), (0, 1), torch.float16, device_type, False)] args = [ rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args @@ -751,7 +881,7 @@ def fn(a): zero = torch.zeros((16,), device=a.device, dtype=torch.int64) return (a[zero],) - a = torch.randn((8,), dtype=torch.float32, device="cuda") + a = torch.randn((8,), dtype=torch.float32, device=device_type) fn_optimized = torch.compile(fn, backend="inductor") if not same(fn(a), fn_optimized(a)): @@ -768,8 +898,8 @@ def fn(x, y): out = torch.ops.aten.multiply(y, squeeze) return (out,) - a = torch.zeros((1, 128), dtype=torch.int64, device="cuda") - b = torch.zeros((1, 128), dtype=torch.int64, device="cuda") + a = torch.zeros((1, 128), dtype=torch.int64, device=device_type) + b = torch.zeros((1, 128), dtype=torch.int64, device=device_type) fn_optimized = torch.compile(fn, backend="inductor") if not same(fn(a, b), fn_optimized(a, b)): @@ -779,7 +909,9 @@ def test_simplify_dims(self): def fn(a): return (a + 1,) - self.common(fn, (torch.randn(2, 3, 10, 5, 6, device="cuda")[:, :, 2::2, :, :],)) + self.common( + fn, (torch.randn(2, 3, 10, 5, 6, device=device_type)[:, :, 2::2, :, :],) + ) @config.patch(permute_fusion=True) def test_permute_fusion(self): @@ -792,8 +924,8 @@ def forward(self, view, reshape_2): return (bmm,) args = [ - ((1024, 642, 160), (102720, 160, 1), torch.float32, "cuda", True), - ((1024, 642, 20), (12840, 20, 1), torch.float32, "cuda", True), + ((1024, 642, 160), (102720, 160, 1), torch.float32, device_type, True), + ((1024, 642, 20), (12840, 20, 1), torch.float32, device_type, True), ] args = [ rand_strided(sh, st, dt, dev).requires_grad_(rg) @@ -813,10 +945,10 @@ def fn(x, y): aten.add_.Tensor(x, y, alpha=0.55) return (x,) - x1 = torch.zeros(2, 3, 4, 10, device="cuda") - x2 = torch.zeros(2, 3, 4, 10, device="cuda") - x3 = torch.zeros(2, 3, 4, 10, device="cuda") - y = torch.randn(2, 3, 4, 10, device="cuda").to( + x1 = torch.zeros(2, 3, 4, 10, device=device_type) + x2 = torch.zeros(2, 3, 4, 10, device=device_type) + x3 = torch.zeros(2, 3, 4, 10, device=device_type) + y = torch.randn(2, 3, 4, 10, device=device_type).to( memory_format=torch.channels_last ) fn_fx = make_fx(fn)(x1, y) @@ -832,15 +964,18 @@ def foo(x, y, z): a = x @ y return a.unsqueeze(0).unsqueeze(0) + z - x = torch.zeros(5, 5, device="cuda") - y = torch.zeros(5, 5, device="cuda") - z = torch.zeros(1, 1, 5, 5, device="cuda").to(memory_format=torch.channels_last) + x = torch.zeros(5, 5, device=device_type) + y = torch.zeros(5, 5, device=device_type) + z = torch.zeros(1, 1, 5, 5, device=device_type).to( + memory_format=torch.channels_last + ) self.common( foo, (x, y, z), check_lowp=False, ) + @skipIfXpu(msg="TypeError, torch-xpu-ops: 3004") def test_memory_history_inductor(self): def called_inside_compile(x, w, b): a = x @ w + b @@ -851,16 +986,26 @@ def fn(x, w, b): x = called_inside_compile(x, w, b) return called_inside_compile(x, w, b) - w = torch.rand(3, 3, device="cuda") - b = torch.rand(3, device="cuda") - x = torch.rand(3, device="cuda") + w = torch.rand(3, 3, device=device_type) + b = torch.rand(3, device=device_type) + x = torch.rand(3, device=device_type) + + def record_memory_history(value: bool): + if torch.xpu.is_available(): + torch.xpu.memory._record_memory_history(value) + else: + torch.cuda.memory._record_memory_history(value) + try: - torch.cuda.memory.empty_cache() - torch.cuda.memory._record_memory_history(True) + torch.accelerator.memory.empty_cache() + record_memory_history(True) r = fn(x, w, b) finally: - torch.cuda.memory._record_memory_history(False) - snapshot = str(torch.cuda.memory._snapshot()) + record_memory_history(False) + if torch.xpu.is_available(): + snapshot = str(torch.xpu.memory._snapshot()) + else: + snapshot = str(torch.cuda.memory._snapshot()) self.assertTrue("called_inside_compile" in snapshot) def test_negative_arange_dynamic_shapes(self): @@ -893,10 +1038,14 @@ def forward(self, enc_out: torch.Tensor, dec_in: torch.Tensor): padmask = dec_in == 0 dec_mask = padmask.unsqueeze(-1) == padmask.unsqueeze(-2) dec_mask = dec_mask.to(dtype=torch.float32) - dec_mask = dec_mask.tril(diagonal=0).cuda() + dec_mask = dec_mask.tril(diagonal=0).to(device_type) - q_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda") - k_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda") + q_pos = torch.arange( + dec_in.size(1), dtype=torch.long, device=device_type + ) + k_pos = torch.arange( + dec_in.size(1), dtype=torch.long, device=device_type + ) rel_pos = k_pos[None, :] - q_pos[:, None] values = rel_pos.abs().neg().unsqueeze(0).unsqueeze(0) dec_bias = values * self.scales @@ -907,14 +1056,15 @@ def forward(self, enc_out: torch.Tensor, dec_in: torch.Tensor): out = self.dec_layer(out, enc_out, tgt_mask=dec_mask) return self.head(out) - mod = Repro().cuda() + mod = Repro().to(device_type) opt_mod = torch.compile(mod, backend="inductor", dynamic=True) mod.eval() opt_mod.eval() - enc_out = torch.rand(1, 512, 256).cuda() + enc_out = torch.rand(1, 512, 256).to(device_type) dec_inputs = [ - torch.randint(0, 512, (1, i + 1), dtype=torch.long).cuda() for i in range(8) + torch.randint(0, 512, (1, i + 1), dtype=torch.long).to(device_type) + for i in range(8) ] for dec_inp in dec_inputs: @@ -928,9 +1078,9 @@ def fn(arg3_1, relu, permute_1): return (cat_2,) args = [ - ((96,), (1,), torch.float32, "cuda"), - ((10, 256), (256, 1), torch.float32, "cuda"), - ((256, 96), (1, 256), torch.float32, "cuda"), + ((96,), (1,), torch.float32, device_type), + ((10, 256), (256, 1), torch.float32, device_type), + ((256, 96), (1, 256), torch.float32, device_type), ] args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args] correct = fn(*args) @@ -958,7 +1108,7 @@ def forward(self, x): inp = x / y[..., None] return self.layer(inp) - x = torch.rand([4, 4], device="cuda") + x = torch.rand([4, 4], device=device_type) m = MyModule() opt_m = torch.compile(backend="inductor")(m) self.assertEqual(opt_m(x), m(x)) @@ -971,10 +1121,10 @@ def fn(arg3_1, arg3_2, relu, permute_1): return (cat_2,) args = [ - ((96,), (1,), torch.float32, "cuda"), - ((96,), (1,), torch.float32, "cuda"), - ((10, 256), (256, 1), torch.float32, "cuda"), - ((256, 96), (1, 256), torch.float32, "cuda"), + ((96,), (1,), torch.float32, device_type), + ((96,), (1,), torch.float32, device_type), + ((10, 256), (256, 1), torch.float32, device_type), + ((256, 96), (1, 256), torch.float32, device_type), ] args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args] correct = fn(*args) @@ -1003,7 +1153,9 @@ def test_normalize_norm_leq_one(self): def fn(x: torch.Tensor) -> torch.Tensor: return torch.nn.functional.normalize(x, dim=-1) - inp = torch.tensor([[3.799999, 0.0, 0.0]], device="cuda", dtype=torch.float32) + inp = torch.tensor( + [[3.799999, 0.0, 0.0]], device=device_type, dtype=torch.float32 + ) compiled = torch.compile(fn, backend="inductor", fullgraph=True) out = compiled(inp) norm = out.norm(dim=-1) @@ -1015,7 +1167,7 @@ def test_libdevice_routing(self): def foo(x): return x.exp() - inp = torch.ones(64, device="cuda").to(torch.float64) + inp = torch.ones(64, device=device_type).to(torch.float64) out, code = run_and_get_code(torch.compile(foo), inp) FileCheck().check("libdevice.exp").run(code[0]) @@ -1029,7 +1181,7 @@ def foo(x): def foo(x): return x.sigmoid() - inp = torch.ones(64, device="cuda").to(torch.float64) + inp = torch.ones(64, device=device_type).to(torch.float64) out, code = run_and_get_code(torch.compile(foo), inp) FileCheck().check("libdevice.exp").run(code[0]) self.assertEqual(foo(inp), out) @@ -1041,7 +1193,7 @@ def view_copy(target, source): assert source.dtype == torch.uint16 # noqa: S101 target.view(torch.uint16).copy_(source) - target = torch.ones(1024, dtype=torch.bfloat16, device="cuda") + target = torch.ones(1024, dtype=torch.bfloat16, device=device_type) source = torch.full_like(target, 4, dtype=torch.uint16) out = target.view(torch.uint16).copy_(source).clone() @@ -1055,7 +1207,7 @@ def forward(arg0_1): 1, dtype=torch.float32, layout=torch.strided, - device=torch.device(type="cuda", index=0), + device=torch.device(type=device_type, index=0), pin_memory=False, ) convert_element_type_1 = torch.ops.prims.convert_element_type.default( @@ -1073,23 +1225,23 @@ def forward(arg0_1): ) return [var_mean[0], var_mean[1], add_2] - emb = torch.randn([2050, 768], device="cuda") + emb = torch.randn([2050, 768], device=device_type) gm = make_fx(forward)(emb) opt = torch._inductor.compile_fx.compile_fx_inner(gm, [emb]) opt([emb]) - torch.cuda.synchronize() + torch.accelerator.synchronize() def test_deterministic_algorithms(self): N = 10000 @torch.compile def fn(idx, values): - x = torch.zeros(1, device="cuda") + x = torch.zeros(1, device=device_type) x[idx] += values return x - idx = torch.zeros(N, dtype=torch.int64, device="cuda") - values = torch.randn(N, device="cuda") + idx = torch.zeros(N, dtype=torch.int64, device=device_type) + values = torch.randn(N, device=device_type) r0 = fn(idx, values) with DeterministicGuard(True): @@ -1106,10 +1258,10 @@ def __init__(self) -> None: self.linear = nn.Linear(4, 4) def forward(self, data): - data = data.to("cuda") + data = data.to(device_type) return self.linear(data) - mod = Model().cuda().eval() + mod = Model().to(device_type).eval() with torch.no_grad(): self.common(mod, (torch.randn(4, 4),)) @@ -1125,7 +1277,7 @@ def forward(self, x): return self.dropout(y) mod = Repro() - x = torch.randn((512, 1, 4096), requires_grad=True, device="cuda") + x = torch.randn((512, 1, 4096), requires_grad=True, device=device_type) y = torch.compile(mod)(x) # Inductor claims the output layout of gelu's saved variable for # backwards will be (4096, 4096, 1) but in actuality it is (4096, @@ -1152,9 +1304,9 @@ def forward(inductor_seeds, mul_4, view_15): sub_3 = torch.ops.aten.sub.Tensor(add_5, getitem_3) return (sub_3,) - buf0 = torch.zeros((37,), dtype=torch.int64, device="cuda") - buf1 = torch.zeros((2, 512, 768), device="cuda") - buf2 = torch.zeros((2, 512, 768), device="cuda") + buf0 = torch.zeros((37,), dtype=torch.int64, device=device_type) + buf1 = torch.zeros((2, 512, 768), device=device_type) + buf2 = torch.zeros((2, 512, 768), device=device_type) forward(buf0, buf1, buf2) def test_issue100806(self): @@ -1174,7 +1326,7 @@ def forward(self, x): x = self.relu(x) return x - device = "cuda" + device = device_type batch_size = 2 x = torch.randn(batch_size, 10).to(device) func = Model().to(device) @@ -1194,8 +1346,8 @@ def fn(x, y): add = mean + y return add - x = torch.rand(4, 4, 4, 4, 4, 4, device="cuda") - y = torch.rand((), device="cuda") + x = torch.rand(4, 4, 4, 4, 4, 4, device=device_type) + y = torch.rand((), device=device_type) expect = fn(x, y) opt_fn = torch.compile(fn) @@ -1214,8 +1366,8 @@ def test_bucketize_dynamic_dense(self): def fn(values, offsets): return torch.bucketize(values, offsets) - values = torch.rand((64, 64), device="cuda") - offsets = torch.tensor([0.05, 0.1, 0.5, 0.8, 0.85, 0.95], device="cuda") + values = torch.rand((64, 64), device=device_type) + offsets = torch.tensor([0.05, 0.1, 0.5, 0.8, 0.85, 0.95], device=device_type) expect = fn(values, offsets) @@ -1253,9 +1405,9 @@ def fn(x: torch.Tensor, y: torch.Tensor, buckets: torch.Tensor) -> torch.Tensor: z = torch.mm(x, y) return torch.bucketize(z, buckets) - buckets = torch.arange(-100, 100, 10, device="cuda") - x = torch.randn(64, 64, device="cuda").clamp(-99, 99) - y = torch.randn(64, 64, device="cuda").clamp(-99, 99) + buckets = torch.arange(-100, 100, 10, device=device_type) + x = torch.randn(64, 64, device=device_type).clamp(-99, 99) + y = torch.randn(64, 64, device=device_type).clamp(-99, 99) opt_fn = torch.compile(fn, mode="max-autotune") @@ -1268,7 +1420,7 @@ def test_float64_constants(self): def fn(): # NOTE: tensors of all the same value are constant folded, so we # need a tensor with two distinct values - a = torch.tensor([1 / 10, 2 / 10], dtype=torch.float64, device="cuda") + a = torch.tensor([1 / 10, 2 / 10], dtype=torch.float64, device=device_type) return a * 2e50 cfn = torch.compile(fn) @@ -1308,10 +1460,15 @@ def fn(arg7_1, add_1, permute_2, select_scatter, slice_8): return (bmm,) args = [] - args.append(torch.randn((2, 1, 4, 1200, 4), dtype=torch.float16, device="cuda")) + args.append( + torch.randn((2, 1, 4, 1200, 4), dtype=torch.float16, device=device_type) + ) args.append( rand_strided( - (1, 4, 1000, 4), (16000, 4, 16, 1), dtype=torch.float16, device="cuda" + (1, 4, 1000, 4), + (16000, 4, 16, 1), + dtype=torch.float16, + device=device_type, ) ) args.append( @@ -1319,7 +1476,7 @@ def fn(arg7_1, add_1, permute_2, select_scatter, slice_8): (3, 1, 4, 1000, 4), (16, 48000, 4, 48, 1), dtype=torch.float16, - device="cuda", + device=device_type, ) ) args.append( @@ -1327,7 +1484,7 @@ def fn(arg7_1, add_1, permute_2, select_scatter, slice_8): (2, 1, 4, 1000, 4), (16, 48000, 4, 48, 1), dtype=torch.float16, - device="cuda", + device=device_type, ) ) args.append( @@ -1335,7 +1492,7 @@ def fn(arg7_1, add_1, permute_2, select_scatter, slice_8): (2, 1, 4, 1000, 4), (19200, 19200, 4800, 4, 1), dtype=torch.float16, - device="cuda", + device=device_type, ) ) @@ -1352,9 +1509,9 @@ def fn(x, y, z): x = torch.zeros_like(x) return x.index_put_([y], z, True) - x = torch.zeros((512, 512), device="cuda", dtype=torch.bool) - y = torch.zeros((512,), device="cuda", dtype=torch.int64) - z = torch.ones((512, 512), device="cuda", dtype=torch.bool) + x = torch.zeros((512, 512), device=device_type, dtype=torch.bool) + y = torch.zeros((512,), device=device_type, dtype=torch.int64) + z = torch.ones((512, 512), device=device_type, dtype=torch.bool) opt_fn = torch.compile(fn, backend="inductor") @@ -1375,9 +1532,9 @@ def fn(x, y, z): x = torch.zeros_like(x) return x.index_put([y], z, True) - x = torch.zeros((512, 512), device="cuda", dtype=torch.bool) - y = torch.zeros((512,), device="cuda", dtype=torch.int64) - z = torch.ones((512, 512), device="cuda", dtype=torch.bool) + x = torch.zeros((512, 512), device=device_type, dtype=torch.bool) + y = torch.zeros((512,), device=device_type, dtype=torch.int64) + z = torch.ones((512, 512), device=device_type, dtype=torch.bool) opt_fn = torch.compile(fn, backend="inductor") @@ -1420,18 +1577,13 @@ def forward(self, x): cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor") - model = Model().cuda().half() + model = Model().to(device_type).half() model = torch.compile(model, backend=cnts, dynamic=True) - with torch.backends.cuda.sdp_kernel( - enable_flash=True, - enable_math=False, - enable_mem_efficient=False, - enable_cudnn=False, - ): - input1 = torch.rand(5, 512, 1024, device="cuda", dtype=torch.float16) - input2 = torch.rand(5, 513, 1024, device="cuda", dtype=torch.float16) - input3 = torch.rand(5, 514, 1024, device="cuda", dtype=torch.float16) + with sdpa_kernel(SDPBackend.FLASH_ATTENTION): + input1 = torch.rand(5, 512, 1024, device=device_type, dtype=torch.float16) + input2 = torch.rand(5, 513, 1024, device=device_type, dtype=torch.float16) + input3 = torch.rand(5, 514, 1024, device=device_type, dtype=torch.float16) out1 = model(input1) out2 = model(input2) @@ -1445,9 +1597,9 @@ def fn(x, y, z): x = torch.zeros_like(x) return x.index_put([y], z, True) - x = torch.zeros((512, 512), device="cuda", dtype=torch.int32) - y = torch.zeros((512,), device="cuda", dtype=torch.int64) - z = torch.ones((512, 512), device="cuda", dtype=torch.int32) + x = torch.zeros((512, 512), device=device_type, dtype=torch.int32) + y = torch.zeros((512,), device=device_type, dtype=torch.int64) + z = torch.ones((512, 512), device=device_type, dtype=torch.int32) opt_fn = torch.compile(fn, backend="inductor") @@ -1459,13 +1611,38 @@ def fn(x, y, z): self.assertEqual(ref, res) + @parametrize("lowp_dtype", [torch.bfloat16, torch.float16]) + @torch._inductor.config.patch(emulate_precision_casts=True) + def test_emulate_precision_casts_preserves_explicit_precision_cast( + self, lowp_dtype + ): + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) if TEST_CUDA else torch.xpu.manual_seed_all(0) + lowp_name = str(lowp_dtype).removeprefix("torch.") + x = torch.randn(4, 32, 32, device=device_type, dtype=torch.float32) + w = torch.randn(32, 32, device=device_type, dtype=torch.float32) + + def fn(x, w): + x = torch.matmul(x, w) + x = x.to(lowp_dtype).float() + x = x * torch.sigmoid(x) + return x.sum(dim=1) + + opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) + + expected = fn(x, w) + actual, (code,) = run_and_get_code(opt_fn, x.clone(), w.clone()) + self.assertEqual(expected, actual) + self.assertIn("args = (%convert_element_type, torch.float32)", code) + self.assertIn(f".to(tl.{lowp_name})", code) + @torch._inductor.config.patch(emulate_precision_casts=True) @torch._inductor.config.patch(pattern_matcher=False) def test_emulate_precision_casts_convert_element_type(self): torch.manual_seed(0) - torch.cuda.manual_seed_all(0) + torch.cuda.manual_seed_all(0) if TEST_CUDA else torch.xpu.manual_seed_all(0) - x = torch.rand(1000, device="cuda", dtype=torch.float32) + x = torch.rand(1000, device=device_type, dtype=torch.float32) def fn(x): x_bf16 = x.to(torch.bfloat16) @@ -1483,8 +1660,8 @@ def test_emulate_precision_casts_norm_rounding(self): torch.manual_seed(0) torch.cuda.manual_seed_all(0) - x = torch.rand(1000, device="cuda", dtype=torch.bfloat16) - scalar = torch.rand([], device="cuda", dtype=torch.float32) + x = torch.rand(1000, device=device_type, dtype=torch.bfloat16) + scalar = torch.rand([], device=device_type, dtype=torch.float32) def fn(inp, scale): y = inp.norm() @@ -1497,10 +1674,11 @@ def fn(inp, scale): self.assertEqual(expected, actual) + @skipIfXpu(msg="AssertionError, torch-xpu-ops: #2554") @torch._inductor.config.patch(emulate_precision_casts=True) def test_emulate_precision_casts_min_pow_chain(self): torch.manual_seed(0) - torch.cuda.manual_seed_all(0) + torch.cuda.manual_seed_all(0) if TEST_CUDA else torch.xpu.manual_seed_all(0) with dynamo_config.patch( capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True @@ -1508,17 +1686,17 @@ def test_emulate_precision_casts_min_pow_chain(self): arg0 = torch.rand( [383, 55, 2, 3], dtype=torch.float16, - device="cuda", + device=device_type, requires_grad=True, ) arg1 = torch.rand( - [383, 55], dtype=torch.bfloat16, device="cuda", requires_grad=True + [383, 55], dtype=torch.bfloat16, device=device_type, requires_grad=True ) arg2 = torch.rand( - [383, 55], dtype=torch.float32, device="cuda", requires_grad=True + [383, 55], dtype=torch.float32, device=device_type, requires_grad=True ) arg3 = torch.rand( - [383, 55], dtype=torch.float32, device="cuda", requires_grad=True + [383, 55], dtype=torch.float32, device=device_type, requires_grad=True ) def fn(a0, a1, a2, a3): @@ -1557,22 +1735,37 @@ def test_emulate_precision_casts_mean_ratio_chain(self): capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True ): arg0 = torch.rand( - [125070], dtype=torch.bfloat16, device="cuda", requires_grad=True + [125070], dtype=torch.bfloat16, device=device_type, requires_grad=True ) arg1 = torch.rand( - [1895, 3, 11], dtype=torch.float16, device="cuda", requires_grad=True + [1895, 3, 11], + dtype=torch.float16, + device=device_type, + requires_grad=True, ) arg2 = torch.rand( - [1895, 3, 11], dtype=torch.float32, device="cuda", requires_grad=True + [1895, 3, 11], + dtype=torch.float32, + device=device_type, + requires_grad=True, ) arg3 = torch.rand( - [1895, 3, 11], dtype=torch.float32, device="cuda", requires_grad=True + [1895, 3, 11], + dtype=torch.float32, + device=device_type, + requires_grad=True, ) arg4 = torch.rand( - [1895, 3, 11], dtype=torch.float32, device="cuda", requires_grad=True + [1895, 3, 11], + dtype=torch.float32, + device=device_type, + requires_grad=True, ) arg5 = torch.rand( - [5, 379, 165], dtype=torch.float32, device="cuda", requires_grad=True + [5, 379, 165], + dtype=torch.float32, + device=device_type, + requires_grad=True, ) def fn(a0, a1, a2, a3, a4, a5): @@ -1603,15 +1796,15 @@ def fn(a0, a1, a2, a3, a4, a5): @torch._inductor.config.patch(emulate_precision_casts=True) def test_dont_inplace_disjoint_accesses(self): # TODO - would not need mms if we could annotate donated buffer.. - def forward( # noqa: F821, F722 - arg0_1: "bf16[2048, 2048][2048, 1]cuda:0", # noqa: F821, F722 - arg1_1: "bf16[8, 4096, 2048][8388608, 2048, 1]cuda:0", # noqa: F821, F722 - arg2_1: "bf16[2048, 2048][2048, 1]cuda:0", # noqa: F821, F722 - arg3_1: "bf16[2048, 2048][2048, 1]cuda:0", # noqa: F821, F722 - arg4_1: "bf16[2048][1]cuda:0", # noqa: F821, F722 - arg5_1: "bf16[2048][1]cuda:0", # noqa: F821, F722 - arg6_1: "f32[4096, 128][128, 1]cuda:0", # noqa: F821, F722 - arg7_1: "f32[4096, 128][128, 1]cuda:0", # noqa: F821, F722 + def forward( + arg0_1: f"bf16[2048, 2048][2048, 1]{device_type}:0", + arg1_1: f"bf16[8, 4096, 2048][8388608, 2048, 1]{device_type}:0", + arg2_1: f"bf16[2048, 2048][2048, 1]{device_type}:0", + arg3_1: f"bf16[2048, 2048][2048, 1]{device_type}:0", + arg4_1: f"bf16[2048][1]{device_type}:0", + arg5_1: f"bf16[2048][1]{device_type}:0", + arg6_1: f"f32[4096, 128][128, 1]{device_type}:0", + arg7_1: f"f32[4096, 128][128, 1]{device_type}:0", ): permute = torch.ops.aten.permute.default(arg0_1, [1, 0]) arg0_1 = None @@ -1773,7 +1966,7 @@ def forward( # noqa: F821, F722 from torch._dynamo.debug_utils import aot_graph_input_parser - kwargs = aot_graph_input_parser(forward) + kwargs = aot_graph_input_parser(forward, device=device_type) out, code = run_and_get_code(torch.compile(forward), **kwargs) # ignore tiny values.. prior to this fix absolute error was ~28 self.assertEqual(forward(**kwargs), out, atol=0.01, rtol=2) @@ -1781,8 +1974,8 @@ def forward( # noqa: F821, F722 # https://github.com/pytorch/pytorch/issues/104937 def test_linear_with_zero_infeature_size(self): - m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda") - x = torch.rand(1, 1, 0, device="cuda") + m = nn.Linear(in_features=0, out_features=0, bias=True).to(device_type) + x = torch.rand(1, 1, 0, device=device_type) expect = m(x) opt_fn = torch.compile(m) actual = opt_fn(x) @@ -1791,7 +1984,7 @@ def test_linear_with_zero_infeature_size(self): @config.patch(fallback_random=True) def test_multi_output_layout_fallback(self): mod = nn.RReLU(lower=3.2350976, upper=8.4220314, inplace=True) - inp = torch.rand([4, 4]).cuda() + inp = torch.rand([4, 4]).to(device_type) m = torch.compile(mod) with freeze_rng_state(): @@ -1806,8 +1999,8 @@ def test_sorted_masks(self): def foo(x, y): return (x + y).sum(dim=1) - x = torch.rand([255, 255], device="cuda") - y = torch.rand([255, 255], device="cuda") + x = torch.rand([255, 255], device=device_type) + y = torch.rand([255, 255], device=device_type) _, code = run_and_get_code(foo, x, y) FileCheck().check("tl.load").check_same("r0_mask").check_same("xmask").run( @@ -1821,7 +2014,8 @@ def cat(inps): for dtype in [torch.uint8, torch.int8]: inps = [ - torch.empty([256, 256], dtype=dtype, device="cuda") for _ in range(4) + torch.empty([256, 256], dtype=dtype, device=device_type) + for _ in range(4) ] out, code = run_and_get_code(cat, inps) @@ -1860,20 +2054,20 @@ def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3): return (sub_2,) args = [ - torch.randn((8, 1024, 4, 4), device="cuda") > 0, # torch.bool tensor - torch.randn((1, 1024, 1, 1), device="cuda"), - torch.randn((8, 1024, 4, 4), device="cuda"), - torch.randn((8, 1024, 1, 1), dtype=torch.float16, device="cuda").expand( - (8, 1024, 4, 4) - ), - torch.randn((), device="cuda"), - torch.randn((1024,), device="cuda"), + torch.randn((8, 1024, 4, 4), device=device_type) > 0, # torch.bool tensor + torch.randn((1, 1024, 1, 1), device=device_type), + torch.randn((8, 1024, 4, 4), device=device_type), + torch.randn( + (8, 1024, 1, 1), dtype=torch.float16, device=device_type + ).expand((8, 1024, 4, 4)), + torch.randn((), device=device_type), + torch.randn((1024,), device=device_type), ] fn(*args) - torch.cuda.synchronize() # shake out Triton Error [CUDA]: misaligned address + torch.accelerator.synchronize() # shake out Triton Error [CUDA]: misaligned address def test_mutated_aligned_tensor(self): - t = torch.rand(4096, device="cuda", dtype=torch.float16) + t = torch.rand(4096, device=device_type, dtype=torch.float16) def foo(x): return x.add_(1) @@ -1893,8 +2087,8 @@ def foo(x): def test_non_commutative_scan_op(self): from torch._higher_order_ops.associative_scan import associative_scan - a = torch.randn(1024, 8192, dtype=torch.float64, device="cuda") - b = torch.randn(1024, 8192, dtype=torch.float64, device="cuda") + a = torch.randn(1024, 8192, dtype=torch.float64, device=device_type) + b = torch.randn(1024, 8192, dtype=torch.float64, device=device_type) def baseline(v, u): A = [] @@ -1922,7 +2116,7 @@ def inner_reduce(x): assert x.shape[1] <= 1024 # noqa: S101 return x.sum(1) - a = torch.randn(50, 600, device="cuda") + a = torch.randn(50, 600, device=device_type) out, code = run_and_get_code(inner_reduce, a) self.assertEqual(inner_reduce(a), out) self.assertTrue("for roffset" not in code) @@ -1979,7 +2173,7 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor: ) return attn_output - device = torch.device("cuda") + device = torch.device(device_type) num_attention_heads = 8 hidden_size = 512 attention_probs_dropout_prob = 0.0 @@ -2005,30 +2199,34 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor: def test_non_contiguous_unaligned_input_indices(self): from torch._inductor.compile_fx import remove_unaligned_input_idxs - inputs = [torch.ones(2, 2, device="cuda"), torch.ones(2, 2, device="cuda")[1:]] + inputs = [ + torch.ones(2, 2, device=device_type), + torch.ones(2, 2, device=device_type)[1:], + ] idxs = remove_unaligned_input_idxs(inputs, [1]) self.assertEqual(idxs, []) inputs = [ - torch.ones(2, 2, device="cuda"), - torch.ones(2, 2, device="cuda"), - torch.ones(2, 2, device="cuda")[1:], + torch.ones(2, 2, device=device_type), + torch.ones(2, 2, device=device_type), + torch.ones(2, 2, device=device_type)[1:], ] idxs = remove_unaligned_input_idxs(inputs, [0, 2]) self.assertEqual(idxs, [0]) + @skipIfXpu(msg="cudagraph is not supported on xpu") @config.patch("triton.cudagraphs", True) def test_unused_cpu_input_cudagraphs(self): def fn(x, y): return x.sin().sin().sin().sin().cos() + 1 fx_graph = torch.fx.symbolic_trace(fn) - inp = [torch.randn(64, device="cuda"), torch.randn(64, device="cpu")] + inp = [torch.randn(64, device=device_type), torch.randn(64, device="cpu")] compiled_fn, (graph,) = run_and_get_graph_lowering( torch._inductor.compile, fx_graph, inp ) self.assertEqual(graph.disable_cudagraphs_reason, None) - self.assertEqual(graph.device_types, {"cuda"}) + self.assertEqual(graph.device_types, {device_type}) self.assertEqual(compiled_fn(*inp), fn(*inp)) def test_epilogue_fusion_with_view(self): @@ -2044,8 +2242,8 @@ def forward(self, x): x = x.view(x.size(0), -1) return self.relu(self.linear(x)) - m = ToyModel().to(device="cuda:0") - input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0") + m = ToyModel().to(device=f"{device_type}:0") + input_tensor = torch.randn(32, 3, 64, 64).to(device=f"{device_type}:0") from torch._inductor.utils import fresh_cache with fresh_cache(): @@ -2054,6 +2252,7 @@ def forward(self, x): out2 = m(input_tensor) self.assertEqual(out, out2, atol=1e-3, rtol=1e-3) + @skipIfXpu(msg="cudagraph is not supported on xpu") @config.patch("triton.cudagraphs", True) def test_cpu_index(self): @torch.compile(fullgraph=True) @@ -2061,25 +2260,25 @@ def fn(x): return x[torch.arange(32)] result, (graph,) = run_and_get_graph_lowering( - fn, torch.randn(64, device="cuda") + fn, torch.randn(64, device=device_type) ) self.assertEqual(graph.disable_cudagraphs_reason, None) - self.assertEqual(graph.device_types, {"cuda"}) + self.assertEqual(graph.device_types, {device_type}) - inp = torch.randn(64, device="cuda", requires_grad=True) + inp = torch.randn(64, device=device_type, requires_grad=True) result, (graph,) = run_and_get_graph_lowering(fn, inp) self.assertEqual(graph.disable_cudagraphs_reason, None) - self.assertEqual(graph.device_types, {"cuda"}) + self.assertEqual(graph.device_types, {device_type}) result, (graph,) = run_and_get_graph_lowering(lambda: result.sum().backward()) self.assertEqual(graph.disable_cudagraphs_reason, None) - self.assertEqual(graph.device_types, {"cuda"}) + self.assertEqual(graph.device_types, {device_type}) @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode") def test_triton_interpret(self): import subprocess - script = """ + script = f""" import os os.environ["TRITON_INTERPRET"] = "1" import torch @@ -2089,8 +2288,9 @@ def foo(x): return x + 1 # somehow gives different results.. still, check that it doesn't error -foo(torch.rand([256], device="cuda")) +foo(torch.rand([256], device=\"{device_type}\")) """ + subprocess.run([sys.executable, "-c", script], check=True) def test_reflection_pad_loop_order(self): @@ -2100,8 +2300,8 @@ def fn(x, y): return a + b cfn = torch.compile(fn) - a = torch.rand((10, 10, 10), device="cuda") - b = torch.rand((10, 10, 10), device="cuda") + a = torch.rand((10, 10, 10), device=device_type) + b = torch.rand((10, 10, 10), device=device_type) expect = fn(a, b) actual, code = run_and_get_code(cfn, a, b) self.assertEqual(expect, actual) @@ -2126,7 +2326,7 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel, tmp0 = tl.load(in_ptr0 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last') tmp1 = tl.load(in_ptr1 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last') tmp2 = tmp0 + tmp1 - tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950 + tl.store(out_ptr0 + (x3), tmp2, xmask)""", ) @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80") @@ -2162,7 +2362,7 @@ def foo(inp): return cat_1 for mark_dynamic in [False, True]: - inp = torch.rand((65536, 8192), dtype=torch.bfloat16, device="cuda") + inp = torch.rand((65536, 8192), dtype=torch.bfloat16, device=device_type) if mark_dynamic: torch._dynamo.mark_dynamic(inp, 0) foo_c = torch.compile(foo) @@ -2172,7 +2372,7 @@ def foo(inp): not SM90OrLater, "uses bfloat16 atomic add instrs which requires SM >= 90" ) def test_float8_e8m0fnu(self): - device = "cuda" + device = device_type dtype = torch.float8_e8m0fnu hp_dtype = torch.float32 # and torch.bfloat16 @@ -2219,9 +2419,9 @@ def f(x, y): return torch.index_select(x, 0, y) x = torch.randn( - 2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True + 2000, 384, dtype=torch.bfloat16, device=device_type, requires_grad=True ) - y = torch.ones(713268, dtype=torch.int64, device="cuda") + y = torch.ones(713268, dtype=torch.int64, device=device_type) x_ref = x.clone().detach().requires_grad_(True) y_ref = y.clone().detach() @@ -2254,15 +2454,15 @@ def f(arg0_1, arg1_1): start=0, step=1, dtype=torch.int64, - device="cuda", + device=device_type, requires_grad=False, ) view_3 = torch.ops.aten.view.default(iota, [1, 36]) max_1 = torch.ops.aten.max.default(view_3) return (max_1,) - x = torch.ones(1, 64, device="cuda", dtype=torch.int64) - y = torch.randn(64, 3072, device="cuda", dtype=torch.bfloat16) + x = torch.ones(1, 64, device=device_type, dtype=torch.int64) + y = torch.randn(64, 3072, device=device_type, dtype=torch.bfloat16) out = f(x, y) self.assertEqual(torch.compile(f)(x, y), out) @@ -2278,9 +2478,9 @@ def f(x, y): return torch.index_select(x, 0, y) x = torch.randn( - 2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True + 2000, 384, dtype=torch.bfloat16, device=device_type, requires_grad=True ) - y = torch.ones(713268, dtype=torch.int64, device="cuda") + y = torch.ones(713268, dtype=torch.int64, device=device_type) x_ref = x.clone().detach().requires_grad_(True) y_ref = y.clone().detach() @@ -2291,6 +2491,25 @@ def f(x, y): self.assertEqual(f(x_ref, y_ref), out) + @skipCUDAIf( + not SM90OrLater, "uses bfloat16 atomic add instrs which requires SM >= 90" + ) + @unittest.skipIf( + config.is_fbcode(), + "bfloat16 atomic add is supported in fbcode, so we won't fallback", + ) + def test_index_add_fallback_direct(self): + def f(x, idx, src): + return torch.index_add(x, 0, idx, src) + + x = torch.randn(16, 256, dtype=torch.bfloat16, device=device_type) + idx = torch.randperm(8, device=device_type) + src = torch.randn(8, 256, dtype=torch.bfloat16, device=device_type) + + out = f(x, idx, src) + compiled_out = torch.compile(f)(x, idx, src) + self.assertEqual(out, compiled_out) + @requires_multigpu() def test_not_initializing_wrong_device(self): device_stats = torch.cuda.memory_stats("cuda:0") @@ -2322,7 +2541,7 @@ def test_3d_tiling(self): 1, 2, ) - GPU_TYPE = "cuda" + GPU_TYPE = device_type def get_input() -> torch.Tensor: device = torch.device(GPU_TYPE) @@ -2344,7 +2563,7 @@ def test_repeated_masked_load(self): mem_eff_temporal_upsampling_interp_chunks = 2 from functorch.einops import rearrange - x = torch.randn(1, 8, 12, 12, 4, dtype=torch.float16, device="cuda") + x = torch.randn(1, 8, 12, 12, 4, dtype=torch.float16, device=device_type) x = x.permute(0, 1, 4, 2, 3) # make non-contiguous x = rearrange(x, "b c t h w -> b c t (h w)") @@ -2406,8 +2625,8 @@ def forward(self, x): return x - model = ToyModel().to("cuda") - input_tensor = torch.randn((2, 4)).to("cuda") + model = ToyModel().to(device_type) + input_tensor = torch.randn((2, 4)).to(device_type) compile_default = torch.compile(model, mode="default") compile_max_autotune = torch.compile(model, mode="max-autotune") @@ -2434,7 +2653,7 @@ def forward(self, x): x = self.adaptive_pool(x) return x - model = Model().cuda() + model = Model().to(device_type) model.eval() test_cases = [ (1, 3, 8, 8), @@ -2446,7 +2665,7 @@ def forward(self, x): for batch, channels, h, w in test_cases: with self.subTest(input_shape=(batch, channels, h, w)): - input_tensor = torch.randn(batch, channels, h, w, device="cuda") + input_tensor = torch.randn(batch, channels, h, w, device=device_type) # Test eager mode with torch.no_grad(): @@ -2499,7 +2718,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: numel = 1 for dim in quantiles_shape: numel *= dim - data = torch.randn(numel, dtype=torch.float32, device="cuda") + data = torch.randn(numel, dtype=torch.float32, device=device_type) # Create tensor with specified shape and strides quantiles = torch.as_strided( @@ -2509,7 +2728,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: quantiles = torch.sort(quantiles, dim=0)[0] x_shape = (batch_size,) + quantiles_shape[1:] - x = torch.randn(*x_shape, dtype=torch.float32, device="cuda") + x = torch.randn(*x_shape, dtype=torch.float32, device=device_type) foo = Foo(quantiles) foo_compiled = torch.compile(Foo(quantiles), fullgraph=True) @@ -2522,7 +2741,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.assertEqual(eager, compiled) def test_identity_load(self): - device = "cuda" + device = device_type def f(x, y): y2 = torch.cat( @@ -2557,9 +2776,12 @@ def f(x, y): self.assertEqual(eager_out, compile_out) + @skipIfXpu( + msg="Explicit attn_mask should not be set when is_causal=True - torch-xpu-ops: 2802" + ) def test_qwen2_7b_sdpa_input_alignment_requires_recompile(self): # SDPA constraints ensures inputs have alignment (8). - device = "cuda" + device = device_type def forward(q_proj, k_proj, attn_mask): scale = 0.08838834764831845 # 1/sqrt(128) @@ -2677,8 +2899,8 @@ def compiled_divide(x, y): torch.float32, torch.float64, ]: - y_ten = torch.tensor([y], dtype=y_dtype, device="cuda") - x_ten = torch.tensor([x], dtype=x_dtype, device="cuda") + y_ten = torch.tensor([y], dtype=y_dtype, device=device_type) + x_ten = torch.tensor([x], dtype=x_dtype, device=device_type) torch._dynamo.reset() compiled_div = Decimal(compiled_divide(x_ten, y_ten).item()) @@ -2686,6 +2908,7 @@ def compiled_divide(x, y): self.assertEqual(eager_div, compiled_div) + @skipIfXpu(msg="triton dependency - torch-xpu-ops: 2554") @config.patch({"eager_numerics.division_rounding": False}) @xfailIfROCm def test_truediv_base_not_bitwise_equivalent(self): @@ -2693,8 +2916,8 @@ def test_truediv_base_not_bitwise_equivalent(self): y, x = 7.0, 11.0 - y_ten = torch.tensor([y], dtype=torch.float32, device="cuda") - x_ten = torch.tensor([x], dtype=torch.float32, device="cuda") + y_ten = torch.tensor([y], dtype=torch.float32, device=device_type) + x_ten = torch.tensor([x], dtype=torch.float32, device=device_type) compile_out, code = run_and_get_code( torch.compile(lambda x, y: x / y), @@ -2712,7 +2935,7 @@ def test_disabling_ftz_yields_subnormals(self): from decimal import Decimal x = -127.0 - x_ten = torch.tensor([x], dtype=torch.float32, device="cuda") + x_ten = torch.tensor([x], dtype=torch.float32, device=device_type) def fn(x): return 2.0**x @@ -2722,13 +2945,14 @@ def fn(x): self.assertTrue(compile_decimal > Decimal(0)) + @skipIfXpu(msg="Decimal object comparison failed - torch-xpu-ops: 2810") @skipIfRocm(msg="ROCm preserves subnormals by default") @config.patch({"eager_numerics.disable_ftz": False}) def test_not_disabling_ftz_yields_zero(self): from decimal import Decimal x = -128.0 - x_ten = torch.tensor([x], dtype=torch.float32, device="cuda") + x_ten = torch.tensor([x], dtype=torch.float32, device=device_type) def fn(x): return 2.0**x @@ -2738,6 +2962,7 @@ def fn(x): self.assertEqual(compile_decimal, Decimal(0)) + @skipIfXpu(msg="AssertionError: torch-xpu-ops: #3006") @config.patch( {"triton.use_block_ptr": True, "triton.codegen_upcast_to_fp32": False} ) @@ -2746,7 +2971,7 @@ def test_float16_reduction_with_int_output(self): def fn(input: torch.Tensor) -> torch.Tensor: return torch.argmax(input, dim=0) - input = torch.randn(20, 20, device="cuda", dtype=torch.float16) + input = torch.randn(20, 20, device=device_type, dtype=torch.float16) _, code = run_and_get_code(fn, input) # There should not be any conversions to float16 in this code, since the input # is already float16 and the output is int64. @@ -2760,13 +2985,13 @@ def test_reciprocal_precision_rounding(self): def fn(x): return torch.reciprocal(x) - x = torch.randn(1000, device="cuda", dtype=torch.float32) + 0.1 + x = torch.randn(1000, device=device_type, dtype=torch.float32) + 0.1 self.common(fn, [x]) if __name__ == "__main__": from torch._inductor.test_case import run_tests - from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON + from torch.testing._internal.inductor_utils import HAS_GPU_AND_TRITON - if HAS_CUDA_AND_TRITON and not TEST_WITH_ASAN: + if HAS_GPU_AND_TRITON and not TEST_WITH_ASAN: run_tests(needs="filelock") diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 8c356501852ca..8c040da118812 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -22,7 +22,7 @@ from torch._inductor.codecache import FxGraphCache from torch._inductor.compile_fx import compile_fx_inner from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl -from torch._inductor.cudagraph_utils import FunctionID, PlaceholderInfo +from torch._inductor.cudagraph_utils import PlaceholderInfo from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import run_and_get_code from torch._ops import OpOverload @@ -65,6 +65,16 @@ ) from io import StringIO +from torch._library.opaque_object import OpaqueBase, register_opaque_type + + +class _CudagraphTestScaleFactor(OpaqueBase): + def __init__(self, factor): + self.factor = factor + + +register_opaque_type(_CudagraphTestScaleFactor, typ="reference") + def get_compile_fn(backend): if backend == "cudagraphs": @@ -746,8 +756,7 @@ def foo2(x): self.assertEqual(foo_opt(ones), foo(ones)) # paths children = self.get_root_children() - # one root with two children - self.assertEqual(children, [2]) + self.assertEqual(children, [1]) def test_end_recording_early(self): def foo(x): @@ -1929,7 +1938,9 @@ def foo2(args): del x self.assertEqual(all_live_block_count(), 0) - @unittest.skipUnless(IS_X86 and IS_LINUX, "cpp contexts are linux only") + @unittest.skipUnless( + (IS_X86 or IS_ARM64) and IS_LINUX, "cpp contexts are linux x86/aarch64 only" + ) @torch._inductor.config.patch("triton.cudagraph_trees_history_recording", True) @blas_library_context("cublas") def test_workspace_allocation_error(self): @@ -2233,25 +2244,6 @@ def foo(mod, inp): node = self.get_manager().current_node self.assertEqual(len(list(node.path_live_weakrefs())), 1) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False) - @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False) - def test_unstable_ptr(self): - import torch - - @torch.compile(mode="reduce-overhead") - def foo(m, inp): - return m(inp) - - def f(): - l = [] - m = torch.nn.Linear(20, 20).cuda() - for _ in range(4): - inp = torch.rand([20, 20], device="cuda") - foo(m, inp) - m.weight.data = torch.rand([20, 20], device="cuda") - - self.assertRaises(RuntimeError, f) - @requires_multigpu() def test_manager_per_device(self): def test(): @@ -2312,6 +2304,63 @@ def foo(x): FileCheck().check("overwritten").check("x * x * x").run(repr(exc.exception)) + def test_output_node_has_stack_traces_inference(self): + """Test that output_stack_traces on the output node provides + stack traces even when a post-grad pass strips them from arg nodes + in inference mode.""" + + def strip_stack_traces(graph): + for node in graph.nodes: + if node.op not in ("placeholder", "output"): + node.meta.pop("stack_trace", None) + + with config.patch(post_grad_custom_post_pass=strip_stack_traces): + + @torch.compile(mode="reduce-overhead") + def foo(x): + return x * x * x + + inp = torch.rand([4], device="cuda") + out = foo(inp).detach() + out2 = foo(inp).detach() + + with self.assertRaises(Exception) as exc: + out + out + + self.assertIn("x * x * x", repr(exc.exception)) + + def test_output_node_has_stack_traces_training(self): + """Test that output_stack_traces on the output node provides + stack traces even when a post-grad pass strips them from arg nodes + in training mode (fwd/bwd graph split via partitioner).""" + + def strip_stack_traces(graph): + for node in graph.nodes: + if node.op not in ("placeholder", "output"): + node.meta.pop("stack_trace", None) + + with config.patch(post_grad_custom_post_pass=strip_stack_traces): + + @torch.compile(mode="reduce-overhead") + def foo(x): + return x * x * x + + inp = torch.rand([4], device="cuda", requires_grad=True) + # Complete fwd+bwd to compile both graphs + torch.compiler.cudagraph_mark_step_begin() + foo(inp).sum().backward() + + # Now trigger the dealloc error + torch.compiler.cudagraph_mark_step_begin() + out = foo(inp).detach() + torch.compiler.cudagraph_mark_step_begin() + out2 = foo(inp).detach() + + with self.assertRaises(Exception) as exc: + out + out + + self.assertIn("x * x * x", repr(exc.exception)) + @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") def test_conv_benchmark(self): with torch.backends.cudnn.flags( @@ -2446,9 +2495,12 @@ def foo(x): with warnings.catch_warnings(record=True) as w: out = foo(torch.rand([4, 4], device="cuda", requires_grad=True)) - FileCheck().check( - "Unable to hit fast path of CUDAGraphs because of pending" - ).run(str(w[0])) + # Match substring only; scan all warnings in case another warning fires first. + msgs = [str(x.message) for x in w] + self.assertTrue( + any("require backward" in m for m in msgs), + f"expected CUDAGraph pending-backward warning; got: {msgs}", + ) self.assertTrue(self.get_manager().new_graph_id().id == 0) def test_mark_step(self): @@ -2597,9 +2649,9 @@ def foo(x): t = torch.rand([32], device="cuda") self.assertEqual(foo(t), foo_c(t)) - FileCheck().check("skipping cudagraphs due to cpp wrapper enabled").run( - log_stream.getvalue() - ) + FileCheck().check( + "skipping cudagraphs due to cpp-wrapper does not support graph partition yet" + ).run(log_stream.getvalue()) self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) def test_storage_access_error(self): @@ -2609,6 +2661,25 @@ def test_storage_access_error(self): with self.assertRaisesRegex(Exception, "custom error msg"): device = x.untyped_storage() + def test_clear_storage_data_ptr_access_error(self): + x = torch.rand([4], device="cuda") + storage = x.untyped_storage() + storage_ptr = storage.data_ptr() + storage_impl_ptr = storage._cdata + + storage.resize_(0) + + torch._C._set_storage_data_ptr_access_error_msg( + storage_impl_ptr, "storage is dead" + ) + with self.assertRaisesRegex(Exception, "storage is dead"): + storage.data_ptr() + + torch._C._clear_storage_data_ptr_access_error_msg(storage_impl_ptr) + storage.resize_(4 * x.element_size()) + # Should not raise + storage.data_ptr() + def test_side_stream_memory_allocation(self): device = f"cuda:{self.device_idx}" @@ -2652,49 +2723,6 @@ def multi_stream_allocation(args): self.assertEqual(self.get_manager().new_graph_id().id, 1) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False) - @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False) - def test_static_inputs_address_mutation_log(self): - class Goo(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.linear = torch.nn.Linear(2, 2, device="cuda") - - def forward(self, x) -> torch.Tensor: - return self.linear(x) - - class Foo(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.static_tensor = torch.zeros((2, 2), device="cuda") - self.goo = Goo() - - def forward(self, x) -> torch.Tensor: - self.static_tensor.add_(torch.ones((2, 2), device="cuda")) - return self.static_tensor + x + self.goo(x) - - foo = Foo() - foo = torch.compile(foo, mode="reduce-overhead") - inp = torch.rand((2, 2), device="cuda") - - for _ in range(3): - foo(inp) - - # mutates static input tensors' addresses - foo.static_tensor = torch.ones((2, 2), device="cuda") - foo.goo.linear.bias = torch.nn.Parameter(torch.ones((2,), device="cuda")) - - with self.assertRaisesRegex( - Exception, - r"(?s)static input data pointer changed.\n" - r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*" - r"input name: primals_.*. data pointer changed from .* to .*. input stack trace:.*," - r" in forward\n.* self.static_tensor.add\_\(torch.ones\(\(2, 2\), device=\"cuda\"\)\).*\n", - ): - self.curr_node().run( - [foo.goo.linear.weight, foo.goo.linear.bias, foo.static_tensor, inp] - ) - def _run_iter(self, param, fn): fwd_output = fn(torch.ones(2, 2), param) fwd_output.sum().backward() @@ -2761,7 +2789,6 @@ def run_test(): self.assertEqual(self.get_manager().new_graph_id().id, 4) @torch._dynamo.config.patch("error_on_recompile", True) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_multi_dispatch_single_compile_param_inputs(self): # Verify that we can record multiple cudagraphs for a single # compiled function with param inputs @@ -2772,7 +2799,6 @@ def fn(x, y): self.run_static_input_param_test(fn, 4) @torch._dynamo.config.patch("error_on_recompile", True) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_multi_dispatch_single_compile_builtin_module(self): # Verify that we don't recompile when changing the param of a builtin module # and that we record another cudagraph @@ -2780,7 +2806,6 @@ def test_multi_dispatch_single_compile_builtin_module(self): self._module_test(torch.nn.Linear(2, 3, device="cuda")) @torch._dynamo.config.patch("error_on_recompile", True) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_multi_dispatch_single_compile_builtin_module_buffers(self): # Verify that we don't recompile when changing the buffer of a builtin module # and that we record another cudagraph @@ -2792,7 +2817,6 @@ def test_multi_dispatch_single_compile_builtin_module_buffers(self): @torch._inductor.config.patch("triton.cudagraphs", True) @torch._dynamo.config.patch("error_on_recompile", True) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_multi_dispatch_custom_module(self): # Test that we can correctly dispatch multiple graphs # if params of a custom module change @@ -2809,7 +2833,6 @@ def forward(self, x): ) @torch._dynamo.config.patch("error_on_recompile", True) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_multi_dispatch_custom_module_buffer(self): # Test that we can correctly dispatch multiple graphs # if buffers of a custom module change @@ -2833,7 +2856,6 @@ def forward(self, x): @torch._inductor.config.patch("triton.cudagraphs", True) @torch._dynamo.config.patch("error_on_recompile", True) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_multi_dispatch_child_node(self): # Test that we can correctly dispatch multiple graphs if a child node # in the tree has stable input pointers change @@ -2852,7 +2874,6 @@ def fn(x, p): self.run_static_input_param_test(fn, 5) @torch._dynamo.config.patch("error_on_recompile", True) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_multi_dispatch_parent_node(self): def fn(x, p): # Graph 1 @@ -2872,145 +2893,6 @@ def fn(x, p): self.run_static_input_param_test(fn, 6) @torch._dynamo.config.patch("error_on_recompile", True) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False) - @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True) - @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0) - def test_fallback_to_eager_if_recompiling_too_many_times(self): - class Foo(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.param = torch.nn.Parameter(torch.rand([2, 2], device="cuda")) - - def forward(self, x): - return x * self.param - - log_stream, ctx = logs_to_string( - "torch._inductor.cudagraph_utils", "cudagraphs" - ) - with ctx(): - # We have 3 graphs here - # None - # / \ - # (fwd w/ p1, Graph 0) (bwd w/p2, Graph2) - # (bwd w/ p1, Graph 1) - # All other graphs are skipped because we hit the max recording limit - # (=0 for each node and function pair) - fn_compiled = torch.compile(Foo(), mode="reduce-overhead") - for _ in range(3): - fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward() - fn_compiled.param.grad = None - - # Change static tensor address - fn_compiled.param.data = torch.rand([2, 2], device="cuda") - fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward() - self.assertEqual(self.get_manager().new_graph_id().id, 3) - - FileCheck().check( - "skipping cudagraph due to function 0 exceeding max re-recording limit (=0) " - "on cudagraph node None due to static input data pointer changed." - ).run(log_stream.getvalue()) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) - - @torch._dynamo.config.patch("error_on_recompile", True) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False) - @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True) - @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0) - def test_fallback_to_eager_if_recompiling_too_many_times_warn_only_once(self): - class Foo(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.param = torch.nn.Parameter(torch.rand([2, 2], device="cuda")) - - def forward(self, x): - return x * self.param - - log_stream, ctx = logs_to_string( - "torch._inductor.cudagraph_utils", "cudagraphs" - ) - with ctx(): - with torch.device("cuda"): - # We have 3 graphs here - # None - # / \ - # (fwd w/ p1, Graph 0) (bwd w/p2, Graph2) - # (bwd w/ p1, Graph 1) - # All other graphs are skipped because we hit the max recording limit - # (=0 for each node and function pair) - fn_compiled = torch.compile(Foo(), mode="reduce-overhead") - for _ in range(3): - fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward() - fn_compiled.param.grad = None - - for _ in range(5): - # Change static tensor address - fn_compiled.param.data = torch.rand([2, 2], device="cuda") - fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward() - fn_compiled.param.grad = None - - FileCheck().check_count( - "skipping cudagraph due to function 0 exceeding max re-recording limit (=0) " - "on cudagraph node None due to static input data pointer changed.", - 1, - exactly=True, - ).check_count( - "skipping cudagraph due to function 1 exceeding max re-recording limit (=0) " - "on cudagraph node None due to static input data pointer changed.", - 1, - exactly=True, - ).run(log_stream.getvalue()) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 2) - - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False) - @torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True) - @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 0) - def test_fallback_to_eager_if_recompiling_too_many_times_due_to_cudagraph_managed_tensor( - self, - ): - # By setting triton.cudagraph_support_input_mutation=True, we force re-record - # if cudagraph managed tensor addresses changed. - @torch.compile(mode="reduce-overhead") - def foo(x): - return x + 1 - - @torch.compile(mode="reduce-overhead") - def goo(x): - return x * 2 - - for _ in range(3): - torch.compiler.cudagraph_mark_step_begin() - inp = torch.rand((2, 3), device="cuda") - y = foo(inp) - z = goo(y) - - log_stream, ctx = logs_to_string( - "torch._inductor.cudagraph_utils", "cudagraphs" - ) - with ctx(): - torch.compiler.cudagraph_mark_step_begin() - x = torch.rand(2, 3, device="cuda") - y = foo(x) - y_clone = y.clone() - z = goo(y_clone) - - # eager function should run successfully - for _ in range(5): - torch.compiler.cudagraph_mark_step_begin() - x = torch.rand(2, 3, device="cuda") - y = foo(x) - y_clone = y.clone() - z = goo(y_clone) - - FileCheck().check_count( - "skipping cudagraph due to function 1 exceeding max re-recording limit (=0) " - "on cudagraph node 0 due to cudagraph managed tensor data pointer changed", - 1, - exactly=True, - ).run(log_stream.getvalue()) - self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) - - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", False) - @torch._dynamo.config.patch("error_on_recompile", True) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) @torch._inductor.config.patch("triton.cudagraph_unexpected_rerecord_limit", 1) def test_not_fallback_to_eager_if_have_not_recompiling_too_many_times(self): def fn(x, y): @@ -3025,7 +2907,6 @@ def fn(x, y): self.assertEqual(counters["inductor"]["cudagraph_skips"], 0) @torch._dynamo.config.patch("error_on_recompile", True) - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_no_rerecord_with_mark_static_address(self): class Mod(torch.nn.Module): def __init__(self): @@ -3455,20 +3336,8 @@ def forward(self, x) -> torch.Tensor: foo.static_tensor = torch.ones((2, 2), device="cuda") foo.goo.linear.bias = torch.nn.Parameter(torch.ones((2,), device="cuda")) - if torch._dynamo.config.inline_inbuilt_nn_modules: - for _ in range(3): - foo(inp) - else: - # Run with specific function id to avoid dynamo recompiling - self.get_manager().run( - [ - foo.goo.linear.weight, - foo.goo.linear.bias, - foo.static_tensor, - inp, - ], - FunctionID(0), - ) + for _ in range(3): + foo(inp) self.assertEqual(self.get_manager().new_graph_id().id, 2) @@ -3866,14 +3735,7 @@ def forward(self, x): self.assertEqual(self.get_manager().new_graph_id().id, 3) @torch._inductor.config.patch("graph_partition", True) - def test_graph_partition_backward_cpu_scalar_saved_tensor(self): - """ - With graph_partition, a CPU graph input (e.g. a scalar parameter) - may be device-copied to CUDA for CG partition triton kernels. - The graph output must still reference the original CPU tensor so - the backward's CPU C++ kernel can dereference it safely. - """ - + def test_graph_partition_cpu_scalar_used_in_cpu_op(self): class Mod(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -3912,6 +3774,145 @@ def forward(self, x): compiled_out = compiled_model(x) self.assertEqual(compiled_out, eager_out) + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_cpu_scalar_as_output(self): + """ + When a CPU scalar placeholder is moved to GPU by ConstructorMoverPass, + the forward graph's output must NOT replace the CPU placeholder with + the GPU copy. + """ + + class Mod(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(4, 4, device="cuda") + self.cpu_scale = torch.nn.Parameter(torch.tensor(1.0)) + + def forward(self, x): + return self.linear(x) * self.cpu_scale + + model = Mod() + x = torch.randn(4, 4, device="cuda") + + compiled_model = torch.compile(model, mode="reduce-overhead") + compiled_model(x) + + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + for _ in range(5): + output = compiled_model(x) + loss = criterion(output, torch.randint(0, 4, (4,)).cuda()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + eager_out = model(x) + compiled_out = compiled_model(x) + self.assertEqual(compiled_out, eager_out) + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_saved_activation_not_static(self): + # When the forward is partitioned, saved activations produced + # by inline code between partitions (e.g., DeviceCopy) are + # NOT at fixed addresses. The backward must not mark them as + # static inputs, or it would re-record on every iteration. + # Primals (params/buffers) should still be marked static. + from unittest.mock import patch + + from torch._inductor.utils import count_tangents, get_static_bw_input_idxs + + bw_graph = None + orig_bw = torch._inductor.compile_fx.compile_fx_backward + + def intercept_bw(gm, example_inputs, compiler_config_extra, **kwargs): + nonlocal bw_graph + bw_graph = gm + return orig_bw(gm, example_inputs, compiler_config_extra, **kwargs) + + class Mod(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(16, 16) + + def forward(self, x): + a = x * 2 + # CPU round-trip creates a DeviceCopy partition boundary. + # The .cuda() result is an activation saved for backward. + b = a.cpu().cuda() + c = b * b + return self.linear(c) + + model = Mod().cuda() + input_data = torch.randn(16, 16, device="cuda") + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + with patch("torch._inductor.compile_fx.compile_fx_backward", intercept_bw): + compiled_model = torch.compile(model, mode="reduce-overhead") + output = compiled_model(input_data) + loss = criterion(output, torch.randint(0, 10, (16,)).cuda()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + self.assertIsNotNone(bw_graph) + # count_tangents marks ALL saved tensors as static (old behavior) + all_static = list(range(count_tangents(bw_graph))) + # get_static_bw_input_idxs only marks primals as static + primal_static = get_static_bw_input_idxs(bw_graph) + # With a partitioned forward, only primals should be static, + # so primal_static should be a strict subset of all_static. + self.assertTrue(len(primal_static) < len(all_static)) + for idx in primal_static: + self.assertIn(idx, all_static) + + # Run a few more iterations to confirm stability + for _ in range(4): + output = compiled_model(input_data) + loss = criterion(output, torch.randint(0, 10, (16,)).cuda()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + @torch._inductor.config.patch("graph_partition", True) + def test_graph_partition_no_partition_keeps_static(self): + # When graph_partition is enabled but the forward has no unsafe + # ops, forward_is_partitioned should be False and all saved + # tensors remain static in the backward. + from unittest.mock import patch + + forward_partitioned = None + orig_bw = torch._inductor.compile_fx.compile_fx_backward + + def intercept_bw(gm, example_inputs, compiler_config_extra, **kwargs): + nonlocal forward_partitioned + forward_partitioned = compiler_config_extra.forward_is_partitioned.value + return orig_bw(gm, example_inputs, compiler_config_extra, **kwargs) + + class Mod(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear = torch.nn.Linear(16, 16) + + def forward(self, x): + return self.linear(x * x + 1) + + model = Mod().cuda() + input_data = torch.randn(16, 16, device="cuda") + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + + with patch("torch._inductor.compile_fx.compile_fx_backward", intercept_bw): + compiled_model = torch.compile(model, mode="reduce-overhead") + output = compiled_model(input_data) + loss = criterion(output, torch.randint(0, 10, (16,)).cuda()) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + self.assertFalse(forward_partitioned) + @torch._inductor.config.patch("graph_partition", True) def test_graph_partition_cpu_only(self): class Mod(torch.nn.Module): @@ -5108,6 +5109,221 @@ def run(n): run(10) run(25) + @torch._inductor.config.patch("triton.skip_cudagraph_warmup", True) + def test_opaque_value_input_cudagraph(self): + """Opaque reference-type objects (e.g. DeviceMesh, ProcessGroup) + passed alongside tensors must be marked "static" so they are + excluded from the tensor-copy path (non_static_input_idx). + "Static" here just means "don't copy as a tensor", not that + the object is semantically immutable.""" + + sf = _CudagraphTestScaleFactor(3.0) + + def foo(args): + obj = args[0] + x = args[1] + args.clear() + return (x * obj.factor,) + + inp = torch.rand([4], device="cuda") + foo_cg = self.cudagraphify_impl(foo, [sf, inp], ()) + result = foo_cg([sf, inp]) + self.assertEqual(result[0], inp * 3.0) + + result2 = foo_cg([sf, inp]) + self.assertEqual(result2[0], inp * 3.0) + + class TestCUDAGraphPolicy(TestCase): + def setUp(self): + super().setUp() + counters.clear() + self._stack = contextlib.ExitStack() + self._stack.enter_context( + config.patch( + { + "triton.cudagraphs": True, + "triton.cudagraph_trees": True, + } + ) + ) + torch._dynamo.reset() + + def tearDown(self): + super().tearDown() + torch._dynamo.reset() + self._stack.close() + + def test_policy_cudagraphify_called(self): + """Custom policy's cudagraphify is called instead of the default.""" + from torch._inductor.cudagraph_utils import CUDAGraphPolicy + + calls = [] + + class RecordingPolicy(CUDAGraphPolicy): + def cudagraphify(self, model, inputs, static_input_idxs, **kwargs): + calls.append( + { + "static_input_idxs": static_input_idxs, + "device_index": kwargs.get("device_index"), + } + ) + return model + + def foo(x): + return x * x + 1 + + with config.patch("cudagraph_policy", RecordingPolicy()): + compiled = torch.compile(foo) + x = torch.randn(4, device="cuda") + compiled(x) + compiled(x) + + self.assertGreater(len(calls), 0) + + def test_should_wrap_false_skips_cudagraph(self): + """When should_wrap returns False, cudagraph wrapping is skipped.""" + from torch._inductor.cudagraph_utils import CUDAGraphPolicy + + class NoWrapPolicy(CUDAGraphPolicy): + def should_wrap(self, compiled_graph): + return False + + def foo(x): + return x + 1 + + with config.patch("cudagraph_policy", NoWrapPolicy()): + compiled = torch.compile(foo) + x = torch.randn(4, device="cuda") + result = compiled(x) + result = compiled(x) + + self.assertEqual(result, x + 1) + self.assertGreater(counters["inductor"]["cudagraph_skips"], 0) + + def test_wrap_output_called(self): + """Policy's wrap_output is called in BundledOutputCodeLoadable.""" + from torch._inductor.cudagraph_utils import CUDAGraphPolicy + + wrapped = [] + + class WrapPolicy(CUDAGraphPolicy): + def wrap_output(self, output_code): + wrapped.append(type(output_code).__name__) + return output_code + + def foo(x): + return x * 2 + + with config.patch("cudagraph_policy", WrapPolicy()): + compiled = torch.compile(foo) + x = torch.randn(4, device="cuda") + result = compiled(x) + result = compiled(x) + + self.assertEqual(result, x * 2) + self.assertGreater(len(wrapped), 0) + self.assertIn("CompiledFxGraph", wrapped) + + def test_default_policy_matches_builtin(self): + """Default CUDAGraphPolicy produces same results as no policy.""" + from torch._inductor.cudagraph_utils import CUDAGraphPolicy + + def foo(x): + return x * x + x + + x = torch.randn(4, device="cuda") + + compiled_default = torch.compile(foo) + ref = compiled_default(x) + ref = compiled_default(x) + + torch._dynamo.reset() + + with config.patch("cudagraph_policy", CUDAGraphPolicy()): + compiled_policy = torch.compile(foo) + out = compiled_policy(x) + out = compiled_policy(x) + + self.assertEqual(ref, out) + + @torch._inductor.config.patch("graph_partition", True) + def test_default_policy_matches_builtin_partition(self): + """Default CUDAGraphPolicy matches builtin when graph_partition=True.""" + from torch._inductor.cudagraph_utils import CUDAGraphPolicy + + def foo(x): + return x * x + x + + x = torch.randn(4, device="cuda") + + compiled_default = torch.compile(foo) + ref = compiled_default(x) + ref = compiled_default(x) + + torch._dynamo.reset() + + with config.patch("cudagraph_policy", CUDAGraphPolicy()): + compiled_policy = torch.compile(foo) + out = compiled_policy(x) + out = compiled_policy(x) + + self.assertEqual(ref, out) + + @torch._inductor.config.patch("graph_partition", True) + def test_policy_cudagraphify_partition(self): + """Custom policy's cudagraphify is called when graph_partition is enabled.""" + from torch._inductor.cudagraph_utils import CUDAGraphPolicy + + calls = [] + + class RecordingPolicy(CUDAGraphPolicy): + def cudagraphify(self, model, inputs, static_input_idxs, **kwargs): + calls.append( + { + "static_input_idxs": static_input_idxs, + "device_index": kwargs.get("device_index"), + } + ) + return model + + def foo(x): + return x * x + 1 + + with config.patch("cudagraph_policy", RecordingPolicy()): + compiled = torch.compile(foo) + x = torch.randn(4, device="cuda") + compiled(x) + compiled(x) + + self.assertGreater(len(calls), 0) + + def test_should_wrap_false_with_wrap_output(self): + """should_wrap=False skips inner wrapping; wrap_output does outer wrapping.""" + from torch._inductor.cudagraph_utils import CUDAGraphPolicy + + wrap_calls = [] + + class OuterOnlyPolicy(CUDAGraphPolicy): + def should_wrap(self, compiled_graph): + return False + + def wrap_output(self, output_code): + wrap_calls.append(type(output_code).__name__) + return output_code + + def foo(x): + return x * x + 1 + + with config.patch("cudagraph_policy", OuterOnlyPolicy()): + compiled = torch.compile(foo) + x = torch.randn(4, device="cuda") + result = compiled(x) + result = compiled(x) + + self.assertEqual(result, x * x + 1) + self.assertGreater(counters["inductor"]["cudagraph_skips"], 0) + self.assertGreater(len(wrap_calls), 0) + class TestSAC(TestCase): def _make_observer_mode(self): class ObserverMode(TorchDispatchMode): @@ -5511,6 +5727,7 @@ def fn(x, y): ) instantiate_parametrized_tests(CudaGraphTreeTests) + instantiate_parametrized_tests(TestCUDAGraphPolicy) instantiate_parametrized_tests(TestSAC) # OpInfo-based test for index/scatter ops with cudagraphs diff --git a/test/inductor/test_custom_op_autotune.py b/test/inductor/test_custom_op_autotune.py index d92f572b68d07..d5b9572fa54f3 100644 --- a/test/inductor/test_custom_op_autotune.py +++ b/test/inductor/test_custom_op_autotune.py @@ -26,6 +26,7 @@ skipIfXpu, ) from torch.testing._internal.inductor_utils import ( + GPU_TYPE, HAS_CPU, HAS_GPU, HAS_TRITON, @@ -45,8 +46,12 @@ def setUp(self) -> None: """Set up test environment with appropriate device and dtype.""" super().setUp() torch._dynamo.reset() - self.device = "cuda" if HAS_GPU else "cpu" - self.dtype = torch.float16 if self.device == "cuda" else torch.float32 + self.device = GPU_TYPE if HAS_GPU else "cpu" + self.dtype = ( + torch.float16 + if self.device == "cuda" or self.device == "xpu" + else torch.float32 + ) # Clear any previous lowering registrations to ensure test isolation from torch._inductor.lowering import user_lowerings @@ -164,7 +169,6 @@ def _create_mlp_inputs( ) return input_tensor, gate_weight, up_weight, down_weight - @skipIfXpu def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self): """Test RMSNorm autotuning with multiple decomposition variants and dynamic shapes. @@ -247,22 +251,14 @@ def _create_decompose_k_inputs(self, m=256, k=65536, n=1024): """ # Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256] k = ((k + 255) // 256) * 256 # Round up to nearest multiple of 256 - sd = k**0.25 - a = ( - torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False) - / sd - ) - b = ( - torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False) - / sd - ) + a = torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False) + b = torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False) bias = ( torch.randn(n, device=self.device, dtype=self.dtype, requires_grad=False) * 0.1 ) return a, b, bias - @skipIfXpu def test_decompose_k_custom_op_autotune_dynamic_config_for_input_shape(self): """Test decompose_k autotuning with with epilogue fusion(matmul+bias+relu+scale) and dynamic config generation based on matmul input shapes. @@ -378,12 +374,11 @@ def reference_model(a, b, bias): torch.testing.assert_close( compiled_result, expected, - rtol=2e-3, - atol=5e-3, - # msg=f"Failed for shape ({m}, {k}, {n})", + rtol=2e-1, + atol=5e-1, + msg=f"Failed for shape ({m}, {k}, {n})", ) - @skipIfXpu def test_multi_parameter_tuning(self): """Test autotuning with multiple parameters for combinatorial parameter exploration. @@ -483,7 +478,6 @@ def _( multi_param_op, (test_x, test_factor), expected_result, "MultiParam" ) - @skipIfXpu def test_range_based_static_shape_no_cond_dispatch(self): """Test dispatch code generation for static vs dynamic shapes. @@ -1436,6 +1430,7 @@ def test_model(x, weight): torch.testing.assert_close(result, test_x @ test_weight, rtol=1e-1, atol=1e-1) + @skipIfXpu def test_cudagraph_memory_cleanup(self): """Test that CUDA graph destruction automatically cleans up cuBLAS workspaces.""" if self.device != "cuda": @@ -1480,6 +1475,7 @@ def test_cudagraph_memory_cleanup(self): f"Memory leak detected: baseline={baseline_memory}, after_cleanup={memory_after_cleanup}", ) + @skipIfXpu def test_cudagraph_memory_cleanup_benchmarker(self): """Test that CUDA graph benchmarking cleans up memory without leaking.""" if self.device != "cuda": diff --git a/test/inductor/test_custom_op_out_lowering.py b/test/inductor/test_custom_op_out_lowering.py index 96ffd12e6a4fe..e817c74bf2793 100644 --- a/test/inductor/test_custom_op_out_lowering.py +++ b/test/inductor/test_custom_op_out_lowering.py @@ -26,7 +26,7 @@ def _register_add_one_ops(self, lib): lib.define("add_one(Tensor x) -> Tensor") lib.define( "add_one.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)", - tags=(torch.Tag.out_variant,), + tags=(torch.Tag.out,), ) def _add_one_impl(x: torch.Tensor) -> torch.Tensor: @@ -69,7 +69,7 @@ def _register_split_add_ops(self, lib): lib.define("split_add(Tensor x, float a, float b) -> (Tensor, Tensor)") lib.define( "split_add.out(Tensor x, float a, float b, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", - tags=(torch.Tag.out_variant,), + tags=(torch.Tag.out,), ) def _split_add_impl(x, a, b): diff --git a/test/inductor/test_cutlass_backend.py b/test/inductor/test_cutlass_backend.py index 1b15ef2c02ce1..c8c3026e5c309 100644 --- a/test/inductor/test_cutlass_backend.py +++ b/test/inductor/test_cutlass_backend.py @@ -1,4 +1,5 @@ # Owner(s): ["module: inductor"] +import gc import itertools import logging import math @@ -18,6 +19,7 @@ ) from torch._inductor.utils import clear_caches from torch.export import Dim +from torch.testing._internal.common_utils import random_matrix_with_scaled_reduction_dim from torch.testing._internal.logging_utils import log_settings from torch.utils import _pytree as pytree @@ -33,6 +35,7 @@ from torch._dynamo import config as dynamo_config from torch._dynamo.utils import counters from torch._inductor import config +from torch._inductor.codecache import XPUCodeCache from torch._inductor.codegen.cutlass.kernel import CUTLASSTemplateCaller from torch._inductor.codegen.cutlass.utils import _gen_ops_cached, get_max_alignment from torch._inductor.exc import InductorError @@ -44,23 +47,32 @@ from torch.testing import FileCheck from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FP8, + SM100OrLater, SM80OrLater, SM90OrLater, ) +from torch.testing._internal.common_device_type import skipCUDAIf, skipXPUIf from torch.testing._internal.common_utils import ( IN_RE_WORKER, instantiate_parametrized_tests, IS_FBCODE, parametrize, ) +from torch.testing._internal.common_xpu import Xe2_Or_Later from torch.testing._internal.inductor_utils import ( _quantize_rowwise, _quantize_tensorwise, + GPU_TYPE, HAS_CPU, HAS_CUDA_AND_TRITON, ) +# We don't need triton in this test suite. +HAS_XPU = torch.xpu.is_available() +HAS_CUDA = torch.cuda.is_available() +HAS_GPU = HAS_CUDA or HAS_XPU + torch.set_float32_matmul_precision("high") if HAS_CUDA_AND_TRITON: torch.cuda.memory._set_allocator_settings("expandable_segments:False") @@ -122,7 +134,7 @@ def _check_if_instances_equal(op1, op2) -> bool: def gen_args(op, shape, dtype=torch.float16): if op in bin_ops_under_test: - return (torch.rand(*shape, device="cuda:0", dtype=dtype),) + return (torch.rand(*shape, device=f"{GPU_TYPE}:0", dtype=dtype),) else: return () @@ -133,7 +145,8 @@ def gen_args(op, shape, dtype=torch.float16): "max_autotune_gemm_backends": "CUTLASS", "cutlass.cutlass_max_profiling_configs": 1, "benchmark_epilogue_fusion": False, # EVT doesn't support benchmark fusion yet - "cutlass.cutlass_tma_only": True, + "cutlass.cutlass_tma_only": GPU_TYPE + == "cuda", # Only CUDA requires TMA for EVT. "cutlass.cutlass_epilogue_fusion_enabled": True, } ) @@ -158,9 +171,12 @@ def select_no_algorithm(*args, **kwargs): @instantiate_parametrized_tests class TestCutlassBackend(TestCase): + # device_type of each test case is necessary for skipCUDAIf decorator. + device_type = GPU_TYPE + def setUp(self): - if not HAS_CUDA_AND_TRITON: - self.skipTest("CUDA and triton are not available") + if not HAS_GPU: + self.skipTest(f"{GPU_TYPE} and triton are not available") if torch.version.hip: self.skipTest("CUTLASS backend is not supported on HIP") @@ -184,13 +200,19 @@ def setUp(self): def tearDown(self): super().tearDown() clear_caches() + if GPU_TYPE == "xpu": + for dll in XPUCodeCache.dll_cache.values(): + dll.close() + XPUCodeCache.dll_cache.clear() + gc.collect() + torch.xpu.empty_cache() def run_evt_test(self, model, op, shape, num_fusions=1): M, N = shape - a = torch.ones(M, N).cuda().half() - b = torch.ones(N, N).cuda().half().t() + a = torch.ones(M, N).to(GPU_TYPE).half() + b = torch.ones(N, N).to(GPU_TYPE).half().t() extra_args = gen_args(op, (M, N)) - model = model.cuda() + model = model.to(GPU_TYPE) result = torch.compile(model)(a, b, extra_args) ref_result = model(a, b, extra_args) @@ -214,7 +236,8 @@ def test_check_paths(self): self.assertTrue(os.path.exists(cutlass_mock_pydot_path)) self.assertTrue(os.path.exists(cutlass_mock_scipy_path)) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_max_autotune_cutlass_threshold(self): """ @@ -224,8 +247,8 @@ def test_max_autotune_cutlass_threshold(self): def mm(a, b): return a @ b - a = torch.randn(100, 10).cuda().half() - b = torch.randn(100, 10).cuda().half().t() + a = torch.randn(100, 10).to(GPU_TYPE).half() + b = torch.randn(100, 10).to(GPU_TYPE).half().t() with config.patch( { @@ -248,6 +271,7 @@ def mm(a, b): self.assertEqual(choices, []) + @skipXPUIf(not Xe2_Or_Later, "") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_import_cutlass(self): from torch._inductor.codegen.cutlass.utils import try_import_cutlass @@ -257,6 +281,7 @@ def test_import_cutlass(self): import cutlass_cppgen # type: ignore[import-not-found] # noqa: F401 import cutlass_library # noqa: F401 + @skipXPUIf(not Xe2_Or_Later, "") def test_cutlass_key(self): from torch._inductor.codegen.cutlass.utils import try_import_cutlass @@ -265,7 +290,8 @@ def test_cutlass_key(self): self.assertIsNotNone(cutlass_key()) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_subproc_mm(self): """ @@ -277,8 +303,8 @@ def test_cutlass_backend_subproc_mm(self): M, N, K = 4096, 2048, 25728 - a = torch.randn(M, K).cuda().half() - b = torch.randn(N, K).cuda().half().t() + a = torch.randn(M, K).to(GPU_TYPE).half() + b = torch.randn(N, K).to(GPU_TYPE).half().t() with config.patch( { @@ -291,9 +317,17 @@ def test_cutlass_backend_subproc_mm(self): ): Y_compiled = torch.compile(torch.mm)(a, b) Y = torch.mm(a, b) - torch.testing.assert_close(Y_compiled, Y) + if GPU_TYPE == "xpu": + atol = 1e-3 # default is 1e-5 + rtol = 1e-3 # default is 1e-3 + else: + atol = None + rtol = None - @unittest.skipIf(not SM90OrLater, "need sm_90") + torch.testing.assert_close(Y_compiled, Y, atol=atol, rtol=rtol) + + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) @parametrize("dtype", (torch.float16, torch.bfloat16)) def test_cutlass_backend_subproc_addmm(self, dtype): @@ -302,10 +336,16 @@ def test_cutlass_backend_subproc_addmm(self, dtype): """ M, N, K = 4096, 2048, 25728 - dtype = torch.float16 - a = torch.randn(M, K, dtype=dtype).cuda() - b = torch.randn(N, K, dtype=dtype).cuda().t() + # Scale inputs by 1/sqrt(K) so that the matmul output has O(1) + # magnitude, avoiding large accumulation errors in half precision + # that would require loose tolerances. + a = random_matrix_with_scaled_reduction_dim( + M, K, dtype=dtype, device=GPU_TYPE, reduction_dim=-1 + ) + b = random_matrix_with_scaled_reduction_dim( + N, K, dtype=dtype, device=GPU_TYPE, reduction_dim=-1 + ).t() x_shapes = [ (M, N), @@ -330,12 +370,20 @@ def test_cutlass_backend_subproc_addmm(self, dtype): torch._dynamo.reset() clear_caches() - x = torch.randn(x_shape).cuda().to(dtype) + x = torch.randn(x_shape).to(GPU_TYPE).to(dtype) Y_compiled = torch.compile(torch.addmm)(x, a, b, alpha=alpha, beta=beta) Y = torch.addmm(x, a, b, alpha=alpha, beta=beta) - torch.testing.assert_close(Y_compiled, Y) + if GPU_TYPE == "xpu": + atol = 1e-3 + rtol = 1e-3 + else: + # use default + atol = None + rtol = None + torch.testing.assert_close(Y_compiled, Y, atol=atol, rtol=rtol) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_subproc_bmm(self): """ @@ -344,8 +392,8 @@ def test_cutlass_backend_subproc_bmm(self): B, M, N, K = 10, 4096, 2048, 25728 - a = torch.randn(B, M, K).cuda().half() - b = torch.randn(B, N, K).cuda().half().permute(0, 2, 1) + a = torch.randn(B, M, K).to(GPU_TYPE).half() + b = torch.randn(B, N, K).to(GPU_TYPE).half().permute(0, 2, 1) with config.patch( { @@ -360,7 +408,8 @@ def test_cutlass_backend_subproc_bmm(self): Y = torch.bmm(a, b) torch.testing.assert_close(Y_compiled, Y) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @parametrize("dynamic", (False, True)) @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_diff_matmul_share_same_kernel(self, dynamic): @@ -376,9 +425,9 @@ def forward(self, a, b, c): return ab, ac model = MyModel() - a = torch.randn(128, 16).cuda().half() - b = torch.randn(128, 16).cuda().half().t() - c = torch.randn(512, 16).cuda().half().t() + a = torch.randn(128, 16).to(GPU_TYPE).half() + b = torch.randn(128, 16).to(GPU_TYPE).half().t() + c = torch.randn(512, 16).to(GPU_TYPE).half().t() with config.patch( { @@ -403,7 +452,8 @@ def forward(self, a, b, c): 2, ).run(codes[0]) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_number_mm_precompiles(self): torch._dynamo.utils.counters.clear() @@ -418,9 +468,9 @@ def forward(self, a, b, c): return ab model = MyModel() - a = torch.randn(128, 16).cuda().half() - b = torch.randn(128, 16).cuda().half().t() - c = torch.randn(512, 16).cuda().half().t() + a = torch.randn(128, 16).to(GPU_TYPE).half() + b = torch.randn(128, 16).to(GPU_TYPE).half().t() + c = torch.randn(512, 16).to(GPU_TYPE).half().t() with config.patch( { @@ -454,7 +504,8 @@ def forward(self, a, b, c): ) # NOTE: right now tuned_mm doesn't support cutlass 2x, which is used by A100 - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @parametrize("dynamic", (False, True)) @parametrize("use_aoti", (False, True)) @parametrize("dtype", (torch.float16, torch.bfloat16)) @@ -483,10 +534,13 @@ class MyModel(torch.nn.Module): def forward(self, a, b): return a @ b - model = MyModel().cuda() + model = MyModel().to(GPU_TYPE) inputs = [ - (torch.randn(M, K).cuda().to(dtype), torch.randn(K, N).cuda().to(dtype)) + ( + torch.randn(M, K).to(GPU_TYPE).to(dtype), + torch.randn(K, N).to(GPU_TYPE).to(dtype), + ) for (M, N, K) in shapes ] @@ -520,7 +574,8 @@ def forward(self, a, b): torch.testing.assert_close(actual, expected) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(True, "XPU SYCL-TLA has not supported fp8 yet") + @skipCUDAIf(not SM90OrLater, "need sm_90") @parametrize("dynamic", (False, True)) @parametrize("use_aoti", (False, True)) @parametrize("dtype", (torch.float8_e4m3fn,)) @@ -549,7 +604,7 @@ def test_max_autotune_cutlass_backend_fp8_scaled_mm( for shape in shapes: M, N, K = shape output_dtype = torch.bfloat16 - device = "cuda" + device = GPU_TYPE x = torch.randn(M, K, dtype=output_dtype, device=device) w = torch.randn(N, K, dtype=output_dtype, device=device) @@ -587,7 +642,7 @@ def forward(self, x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale): if dynamic else None ) - model = MyModel().cuda() + model = MyModel().to(GPU_TYPE) with ( config.patch( @@ -631,10 +686,12 @@ def test_cutlass_backend_fp8_scaled_mm_mixed_dtypes(self): m, k, n = 256, 256, 256 # Create mixed FP8 dtypes: e4m3fn x e5m2 - a8 = torch.randn(m, k, device="cuda", dtype=torch.float16).to( + a8 = torch.randn(m, k, device=GPU_TYPE, dtype=torch.float16).to( torch.float8_e4m3fn ) - b8 = torch.randn(k, n, device="cuda", dtype=torch.float16).to(torch.float8_e5m2) + b8 = torch.randn(k, n, device=GPU_TYPE, dtype=torch.float16).to( + torch.float8_e5m2 + ) # _scaled_mm requires mat2 to be column-major b8 = b8.t().contiguous().t() @@ -651,7 +708,8 @@ def scaled_mm_fn(a, b): expected = scaled_mm_fn(a8, b8) torch.testing.assert_close(actual, expected, rtol=1e-2, atol=0.05) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @parametrize("dynamic", (False, True)) @parametrize("use_aoti", (False, True)) @parametrize("dtype", (torch.float16, torch.bfloat16)) @@ -671,7 +729,7 @@ class MyModel(torch.nn.Module): def forward(self, x, a, b): return torch.addmm(x, a, b) - model = MyModel().cuda() + model = MyModel().to(GPU_TYPE) # M, N, K shapes = [ (128, 128, 16), @@ -691,9 +749,9 @@ def forward(self, x, a, b): inputs = [ ( - torch.randn(x_shape(M, N)).cuda().to(dtype), - torch.randn(M, K).cuda().to(dtype), - torch.randn(N, K).cuda().to(dtype).t(), + torch.randn(x_shape(M, N)).to(GPU_TYPE).to(dtype), + torch.randn(M, K).to(GPU_TYPE).to(dtype), + torch.randn(N, K).to(GPU_TYPE).to(dtype).t(), ) for (M, N, K) in shapes ] @@ -729,7 +787,18 @@ def forward(self, x, a, b): compiled_model = torch.compile(model, dynamic=dynamic) actual = [compiled_model(*input) for input in inputs] - torch.testing.assert_close(actual, expected) + assert_close_kwargs = {} + if dynamic and SM90OrLater: + # SM90+ CUTLASS addmm currently differs from eager by a small + # output-precision quantum on this test across multiple + # parametrizations. Keep the relaxation scoped to this test + # and stay tighter for float16 than bfloat16. + assert_close_kwargs = { + "rtol": 1.6e-2 if dtype == torch.bfloat16 else 1e-3, + "atol": 1e-2 if dtype == torch.bfloat16 else 2e-3, + } + + torch.testing.assert_close(actual, expected, **assert_close_kwargs) @unittest.skipIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) @@ -766,9 +835,9 @@ def tracking_benchmark_choices(cls, choices, autotune_args, **kwargs): AlgorithmSelectorCache.benchmark_choices = tracking_benchmark_choices try: M, K, N = 256, 3520, 2048 - bias = torch.randn(N, device="cuda", dtype=torch.bfloat16) - x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) - w = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) + bias = torch.randn(N, device=GPU_TYPE, dtype=torch.bfloat16) + x = torch.randn(M, K, device=GPU_TYPE, dtype=torch.bfloat16) + w = torch.randn(K, N, device=GPU_TYPE, dtype=torch.bfloat16) with config.patch( { @@ -792,7 +861,8 @@ def tracking_benchmark_choices(cls, choices, autotune_args, **kwargs): finally: AlgorithmSelectorCache.benchmark_choices = original_benchmark_choices - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @parametrize("dynamic", (False, True)) @parametrize("use_aoti", (False, True)) @parametrize("dtype", (torch.float16, torch.bfloat16)) @@ -814,7 +884,7 @@ class MyModel(torch.nn.Module): def forward(self, a, b): return torch.bmm(a, b) - model = MyModel().cuda() + model = MyModel().to(GPU_TYPE) # B, M, N, K shapes = [ (10, 4096, 2048, 25728), @@ -826,12 +896,18 @@ def forward(self, a, b): for B, M, N, K in shapes: if use_expand: # Create A using unsqueeze and expand - A = torch.randn(M, K).cuda().to(dtype).unsqueeze(0).expand(B, -1, -1) + A = ( + torch.randn(M, K) + .to(GPU_TYPE) + .to(dtype) + .unsqueeze(0) + .expand(B, -1, -1) + ) else: # Original method - A = torch.randn(B, M, K).cuda().to(dtype) + A = torch.randn(B, M, K).to(GPU_TYPE).to(dtype) - B_tensor = torch.randn(B, N, K).cuda().to(dtype).permute(0, 2, 1) + B_tensor = torch.randn(B, N, K).to(GPU_TYPE).to(dtype).permute(0, 2, 1) inputs.append((A, B_tensor)) dynamic_shapes = ( { @@ -858,7 +934,8 @@ def forward(self, a, b): actual = [compiled_model(*input) for input in inputs] torch.testing.assert_close(actual, expected) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(True, "streamk kernels not supported on xpu cutlass backend yet") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_max_autotune_cutlass_backend_regular_mm_streamk( self, dynamic: bool = False, max_autotune_gemm_backends: str = "CUTLASS" @@ -892,15 +969,16 @@ def test_max_autotune_cutlass_backend_regular_mm_streamk( 16384, ), ): - a = torch.randn(M, K).cuda().half() - b = torch.randn(N, K).cuda().half().t() + a = torch.randn(M, K).to(GPU_TYPE).half() + b = torch.randn(N, K).to(GPU_TYPE).half().t() Y_compiled = compiled_model(a, b) Y = torch.mm(a, b) # we need relaxed numerical limits due to the sheer size of the # matmuls involved. Many small addition differences add up. torch.testing.assert_close(Y_compiled, Y, atol=0.01, rtol=0.01) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(True, "streamk kernels not supported on xpu cutlass backend yet") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_streamk_with_dynamic( self, ): @@ -911,8 +989,8 @@ def test_streamk_with_dynamic( shape. Without a correct workspace, the kernel will fail at runtime. """ - a = torch.randn(128, 16).cuda().half() - b = torch.randn(128, 16).cuda().half().t() + a = torch.randn(128, 16).to(GPU_TYPE).half() + b = torch.randn(128, 16).to(GPU_TYPE).half().t() with config.patch( { @@ -924,7 +1002,8 @@ def test_streamk_with_dynamic( with self.assertRaisesRegex(InductorError, r".*NoValidChoicesError.*"): _ = torch.compile(torch.mm, dynamic=True)(a, b) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(True, "streamk kernels not supported on xpu cutlass backend yet") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_streamk_with_static( self, ): @@ -941,8 +1020,8 @@ def test_streamk_with_static( for shape in shapes: M, N, K = shape - a = torch.randn(M, K).cuda().half() - b = torch.randn(N, K).cuda().half().t() + a = torch.randn(M, K).to(GPU_TYPE).half() + b = torch.randn(N, K).to(GPU_TYPE).half().t() with config.patch( { @@ -969,11 +1048,11 @@ def _test_max_autotune_cutlass_backend_epilogue_fusion( # it can happen that no Cutlass 3.x op is available # that allows fusions if batch_size is None: - a = torch.randn(256, 32).cuda() - b = torch.randn(256, 32).cuda().t() + a = torch.randn(256, 32).to(GPU_TYPE) + b = torch.randn(256, 32).to(GPU_TYPE).t() else: - a = torch.randn(batch_size, 256, 32).cuda() - b = torch.randn(batch_size, 256, 32).cuda().permute(0, 2, 1) + a = torch.randn(batch_size, 256, 32).to(GPU_TYPE) + b = torch.randn(batch_size, 256, 32).to(GPU_TYPE).permute(0, 2, 1) if fp16: a = a.half() b = b.half() @@ -984,7 +1063,6 @@ def _test_max_autotune_cutlass_backend_epilogue_fusion( "autotune_in_subproc": True, "max_autotune_gemm_backends": max_autotune_gemm_backends, "cutlass.cutlass_max_profiling_configs": 4, - "cuda.version": "12.2", # required to enable the Kernels we need } ): counters["inductor"]["cutlass_epilogue_fusion_counter"] = 0 @@ -999,7 +1077,8 @@ def _test_max_autotune_cutlass_backend_epilogue_fusion( ) torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self): def mm(a, b): return (a @ b) * 3.0 @@ -1008,7 +1087,8 @@ def mm(a, b): fp16=True, expected_fuse_count=0, mm=mm ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_max_autotune_cutlass_backend_chained_fusion_fp16_fp32acc(self): def mm(a, b): return (a @ b) * 3.3 - 1.234 @@ -1017,7 +1097,8 @@ def mm(a, b): fp16=True, expected_fuse_count=0, mm=mm ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_max_autotune_cutlass_backend_relu_fusion_fp16_fp32acc(self): def mm(a, b): return torch.nn.functional.relu((a @ b) * 3.3 - 1.234) @@ -1027,7 +1108,8 @@ def mm(a, b): fp16=True, expected_fuse_count=0, mm=mm ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_max_autotune_cutlass_backend_relu6_fusion_fp16_fp32acc(self): def mm(a, b): return torch.clamp(torch.nn.functional.relu(a @ b), max=6.0) @@ -1037,7 +1119,8 @@ def mm(a, b): fp16=True, expected_fuse_count=0, mm=mm ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_max_autotune_cutlass_backend_no_fusion_dtype_mismatch(self): def mm(a, b): # this should not be fused, since the output dtype is different from the matmul dtype @@ -1047,7 +1130,8 @@ def mm(a, b): fp16=True, expected_fuse_count=0, mm=mm ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_max_autotune_cutlass_backend_shape_dependent_normalization_fusion(self): def mm(a, b): return (a @ b) / b.size(1) @@ -1056,8 +1140,9 @@ def mm(a, b): fp16=True, expected_fuse_count=0, mm=mm ) + @skipXPUIf(True, "int_mm not supported on xpu cutlass backend") # TODO: Enable dynamic test cases when dynamic support is added. - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipCUDAIf(not SM90OrLater, "need sm_90") @parametrize("dynamic", (False,)) @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_max_autotune_cutlass_backend_int_mm( @@ -1076,8 +1161,8 @@ def mm(a, b): # this combination, so it's excluded from the test). Also, # for CUTLASS alignment requirements, number of columns in # both tensors has to be divisible by 16. - a = torch.randint(0, 5, (100, 16), dtype=torch.int8).cuda() - b = torch.randint(0, 5, (32, 16), dtype=torch.int8).cuda().T + a = torch.randint(0, 5, (100, 16), dtype=torch.int8).to(GPU_TYPE) + b = torch.randint(0, 5, (32, 16), dtype=torch.int8).to(GPU_TYPE).T with config.patch( { @@ -1092,7 +1177,8 @@ def mm(a, b): torch.testing.assert_close(Y_compiled, Y) @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_force_cutlass_backend_aoti_dynamic(self): class MyModel(torch.nn.Module): def forward(self, x, w): @@ -1113,8 +1199,8 @@ def forward(self, x, w): "w": {0: K, 1: N}, } - x = torch.randn(M, K).cuda().half() - w = torch.randn(N, K).cuda().half().t() + x = torch.randn(M, K).to(GPU_TYPE).half() + w = torch.randn(N, K).to(GPU_TYPE).half().t() actual = AOTIRunnerUtil.run( model, @@ -1125,7 +1211,8 @@ def forward(self, x, w): torch.testing.assert_close(expected, actual) @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_force_cutlass_backend_aoti_cexpr_codegen(self): class MyModel(torch.nn.Module): def forward(self, x, w): @@ -1151,8 +1238,8 @@ def forward(self, x, w): "w": None, } - x = torch.randn(M, K).cuda().half() - w = torch.randn(N, K).cuda().half().t() + x = torch.randn(M, K).to(GPU_TYPE).half() + w = torch.randn(N, K).to(GPU_TYPE).half().t() actual = AOTIRunnerUtil.run( model, @@ -1163,7 +1250,8 @@ def forward(self, x, w): torch.testing.assert_close(expected, actual) @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(True, "streamk kernels not supported on xpu cutlass backend yet") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_aoti_workspace_ptr(self): class MyModel(torch.nn.Module): def forward(self, x, w): @@ -1174,15 +1262,15 @@ def forward(self, x, w): "max_autotune": True, "autotune_in_subproc": False, "max_autotune_gemm_backends": "CUTLASS", - "cutlass.cutlass_op_allowlist_regex": "128x256x64.*stream_k_warpspecialized_cooperative_epi_nosmem", + "cutlass.cutlass_op_allowlist_regex": "128x256x64.*stream_k", "cutlass.cutlass_max_profiling_configs": 1, } ): model = MyModel() M, N, K = 200, 5216, 10_432 - x = torch.randn(M, K).cuda().half() - w = torch.randn(N, K).cuda().half().t() + x = torch.randn(M, K).to(GPU_TYPE).half() + w = torch.randn(N, K).to(GPU_TYPE).half().t() actual = AOTIRunnerUtil.run( model, @@ -1208,10 +1296,10 @@ def mm(a, b): return torch.mm(a, b) m, n, k = 32, 8, 64 - mask = torch.tensor([0, 0, 1, 1]).tile(m, k // 4).cuda().half() - a = torch.rand(m, k).cuda().half() * mask + mask = torch.tensor([0, 0, 1, 1]).tile(m, k // 4).to(GPU_TYPE).half() + a = torch.rand(m, k).to(GPU_TYPE).half() * mask a_sparse = to_sparse_semi_structured(a) - b = torch.rand(k, n).cuda().half() + b = torch.rand(k, n).to(GPU_TYPE).half() with config.patch( { @@ -1245,7 +1333,8 @@ def mm(a, b): f"Expected cutlass_kernels_count > 0, got {cutlass_kernels_count}" ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_op_denylist( self, @@ -1253,9 +1342,9 @@ def test_cutlass_backend_op_denylist( def my_addmm(x, a, b, alpha, beta): return torch.addmm(x, a, b, alpha=beta, beta=alpha) - x = torch.randn((128, 128)).cuda().half() - a = torch.randn(128, 128).cuda().half() - b = torch.randn(128, 128).cuda().half().t() + x = torch.randn((128, 128)).to(GPU_TYPE).half() + a = torch.randn(128, 128).to(GPU_TYPE).half() + b = torch.randn(128, 128).to(GPU_TYPE).half().t() with fresh_cache(): with config.patch( @@ -1298,7 +1387,8 @@ def my_addmm(x, a, b, alpha, beta): if cuda_template_count <= 0: raise AssertionError("No CUTLASSTemplateCaller choices") - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(True, "Intel cutlass doesn't have pingpong kernels yet") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_op_allowlist( self, @@ -1306,9 +1396,9 @@ def test_cutlass_backend_op_allowlist( def addmm(x, a, b, alpha, beta): return torch.addmm(x, a, b, alpha=alpha, beta=beta) - x = torch.randn((128, 128)).cuda().half() - a = torch.randn(128, 128).cuda().half() - b = torch.randn(128, 128).cuda().half().t() + x = torch.randn((128, 128)).to(GPU_TYPE).half() + a = torch.randn(128, 128).to(GPU_TYPE).half() + b = torch.randn(128, 128).to(GPU_TYPE).half().t() with fresh_cache(): with config.patch( @@ -1351,7 +1441,8 @@ def addmm(x, a, b, alpha, beta): if cuda_template_count <= 0: raise AssertionError("No CUTLASSTemplateCaller choices") - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(True, "fp8 not supported on xpu cutlass backend yet") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_fp8_scaled_mm_fast_accum_filtering( self, @@ -1359,7 +1450,7 @@ def test_cutlass_backend_fp8_scaled_mm_fast_accum_filtering( float8_dtype = torch.float8_e4m3fn # Only bf16 output type is supported for row-wise scaling, not fp32 output_dtype: torch.dtype = torch.bfloat16 - device = "cuda" + device = GPU_TYPE M, K, N = 128, 128, 128 # Matmul Y = X [M, K] x W [N, K] x = torch.randn(M, K, dtype=output_dtype, device=device) w = torch.randn(N, K, dtype=output_dtype, device=device) @@ -1438,10 +1529,12 @@ def run_test(use_fast_accum): if cuda_template_count <= 0: raise AssertionError("No CUTLASSTemplateCaller choices") - run_test(True) + if not SM100OrLater: + run_test(True) run_test(False) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_shape_coverage_mm( self, @@ -1455,19 +1548,25 @@ def test_cutlass_backend_shape_coverage_mm( """ inputs = [ - (torch.randn(128, 500).cuda().half(), torch.randn(500, 576).cuda().half()), ( - torch.randn(500, 128).cuda().half(), - torch.randn(128, 576).cuda().half(), + torch.randn(128, 500).to(GPU_TYPE).half(), + torch.randn(500, 576).to(GPU_TYPE).half(), ), - (torch.randn(128, 250).cuda().half(), torch.randn(250, 576).cuda().half()), ( - torch.randn(250, 128).cuda().half(), - torch.randn(128, 576).cuda().half(), + torch.randn(500, 128).to(GPU_TYPE).half(), + torch.randn(128, 576).to(GPU_TYPE).half(), ), ( - torch.randn(125, 128).cuda().half(), - torch.randn(128, 576).cuda().half(), + torch.randn(128, 250).to(GPU_TYPE).half(), + torch.randn(250, 576).to(GPU_TYPE).half(), + ), + ( + torch.randn(250, 128).to(GPU_TYPE).half(), + torch.randn(128, 576).to(GPU_TYPE).half(), + ), + ( + torch.randn(125, 128).to(GPU_TYPE).half(), + torch.randn(128, 576).to(GPU_TYPE).half(), ), ] @@ -1519,7 +1618,8 @@ def test_cutlass_backend_shape_coverage_mm( f"M={M}, N={N}, K={K}", ) - @unittest.skipIf(not SM80OrLater, "need sm_80") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM80OrLater, "need sm_80") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_get_max_alignment(self): l4 = FixedLayout( @@ -1580,13 +1680,14 @@ def test_get_max_alignment(self): m4, 4, "Wrong max alignment. Should have been 4 (due to float32 dtype )." ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_standalone_runner(self): max_autotune_gemm_backends = "CUTLASS" - a = torch.randn(128, 16).cuda().half() - b = torch.randn(128, 16).cuda().half().t() + a = torch.randn(128, 16).to(GPU_TYPE).half() + b = torch.randn(128, 16).to(GPU_TYPE).half().t() with config.patch( { @@ -1598,14 +1699,14 @@ def test_standalone_runner(self): ): from tempfile import NamedTemporaryFile - from torch._inductor.codegen.cuda.compile_utils import ( - cuda_standalone_runner_compile_command, - CUDACompileSourceCapturingContext, + from torch._inductor.codegen.cutlass.utils import ( + cutlass_standalone_runner_compile_command, + CUTLASSCompileSourceCapturingContext, ) # Run compilation, check results just in case, and save # CUTLASS-based generated code. - with CUDACompileSourceCapturingContext() as ctx: + with CUTLASSCompileSourceCapturingContext(GPU_TYPE) as ctx: compiled = torch.compile(torch.mm, dynamic=False) expected = torch.mm(a, b) @@ -1619,7 +1720,8 @@ def test_standalone_runner(self): raise AssertionError(f"Expected len(sources) >= 1, got {len(sources)}") # Get names for temporary source and executable files. - cu_file = NamedTemporaryFile("w", suffix=".cu", delete=False) # noqa: SIM115 + suffix = ".cpp" if GPU_TYPE == "xpu" else ".cu" + cu_file = NamedTemporaryFile("w", suffix=suffix, delete=False) # noqa: SIM115 cu_file.close() exe_file = NamedTemporaryFile("w", suffix="", delete=False) # noqa: SIM115 exe_file.close() @@ -1630,8 +1732,8 @@ def test_standalone_runner(self): # Get command to compile .cu file, and run the # compilation. - command = cuda_standalone_runner_compile_command( - Path(cu_file.name), Path(exe_file.name) + command = cutlass_standalone_runner_compile_command( + GPU_TYPE, Path(cu_file.name), Path(exe_file.name) ) if IS_FBCODE: @@ -1659,7 +1761,8 @@ def test_standalone_runner(self): os.remove(cu_file.name) os.remove(exe_file.name) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_integration(self): """ @@ -1669,8 +1772,8 @@ def test_cutlass_backend_integration(self): def mm(a, b): return a @ b - a = torch.randn(128, 16).cuda().half() - b = torch.randn(128, 16).cuda().half().t() + a = torch.randn(128, 16).to(GPU_TYPE).half() + b = torch.randn(128, 16).to(GPU_TYPE).half().t() with config.patch( { @@ -1702,7 +1805,8 @@ def mm(a, b): num_ops = int(match.group(1)) self.assertTrue(num_ops > 0, "The number of ops should be greater than 0") - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_maybe_append_choice_caching(self): """ Test if maybe_append_choice's caching leads to correct results and @@ -1717,9 +1821,9 @@ def forward(self, A, B): A = A @ B / 32 return A - model = TestModule().cuda() - A = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda") - B = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda").t() + model = TestModule().to(GPU_TYPE) + A = torch.randn(1024, 1024, dtype=torch.bfloat16, device=GPU_TYPE) + B = torch.randn(1024, 1024, dtype=torch.bfloat16, device=GPU_TYPE).t() expected = model(A, B) @@ -1750,11 +1854,10 @@ def counting_render(self, *args, **kwargs): torch.testing.assert_close(actual, expected) - # Check render call count: render is called uniquely for each codegen - # and for each finalized codegen. self.assertEqual(render_call_count, NUM_ITERATIONS + 2) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_multiple_mm(self): """ @@ -1769,13 +1872,13 @@ def forward(self, a, b, c, d): mm2 = c @ d return mm1, mm2 - model = MultipleMMModel().cuda() + model = MultipleMMModel().to(GPU_TYPE) # Create tensors with different shapes - a = torch.randn(128, 64).cuda().half() - b = torch.randn(32, 64).cuda().half().t() - c = torch.randn(256, 128).cuda().half() - d = torch.randn(64, 128).cuda().half().t() + a = torch.randn(128, 64).to(GPU_TYPE).half() + b = torch.randn(32, 64).to(GPU_TYPE).half().t() + c = torch.randn(256, 128).to(GPU_TYPE).half() + d = torch.randn(64, 128).to(GPU_TYPE).half().t() # Track render calls from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate @@ -1810,9 +1913,11 @@ def counting_render(self, *args, **kwargs): torch.testing.assert_close(actual, expected) num_matmuls = 2 - self.assertEqual(render_call_count, num_matmuls + num_matmuls * 2) + expected_count = num_matmuls + num_matmuls * 2 + self.assertEqual(render_call_count, expected_count) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_multiple_mm_with_dynamic_shape(self): """ @@ -1822,8 +1927,8 @@ def test_multiple_mm_with_dynamic_shape(self): class MultipleMMDynamicModel(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.c = torch.randn(64, 256).cuda().half() - self.d = torch.randn(128, 256).cuda().half().t() + self.c = torch.randn(64, 256).to(GPU_TYPE).half() + self.d = torch.randn(128, 256).to(GPU_TYPE).half().t() def forward(self, a, b): # dynamic shape matmul @@ -1832,11 +1937,11 @@ def forward(self, a, b): mm2 = self.c @ self.d return mm1, mm2 - model = MultipleMMDynamicModel().cuda() + model = MultipleMMDynamicModel().to(GPU_TYPE) # Create tensors with different shapes - a = torch.randn(128, 64).cuda().half() - b = torch.randn(32, 64).cuda().half().t() + a = torch.randn(128, 64).to(GPU_TYPE).half() + b = torch.randn(32, 64).to(GPU_TYPE).half().t() # Track render calls from torch._inductor.codegen.cutlass.gemm_template import CUTLASSGemmTemplate @@ -1871,15 +1976,17 @@ def counting_render(self, *args, **kwargs): torch.testing.assert_close(actual, expected) num_matmuls = 2 - self.assertEqual(render_call_count, num_matmuls + num_matmuls * 2) + expected_count = num_matmuls + num_matmuls * 2 + self.assertEqual(render_call_count, expected_count) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_matmul_same_tensor(self): max_autotune_gemm_backends = "CUTLASS" M = 128 - A = torch.randn(M, M).cuda().half() + A = torch.randn(M, M).to(GPU_TYPE).half() with config.patch( { @@ -1892,13 +1999,14 @@ def test_cutlass_backend_matmul_same_tensor(self): torch.testing.assert_close(A @ A.t(), compiled(A, A.t())) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_cutlass_backend_matmul_nonzero_offset(self): max_autotune_gemm_backends = "CUTLASS" M = 129 - A = torch.randn(M, M - 1).cuda().half() + A = torch.randn(M, M - 1).to(GPU_TYPE).half() with config.patch( { @@ -1912,7 +2020,8 @@ def test_cutlass_backend_matmul_nonzero_offset(self): A[1:, :] @ A[1:, :].t(), compiled(A[1:, :], A[1:, :].t()) ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_flexible_layout(self): class TestModel(torch.nn.Module): @@ -1921,8 +2030,8 @@ def forward(self, B): return A @ B.t() M = 1024 - B = torch.randn(M, M).cuda().half() - model = TestModel().cuda() + B = torch.randn(M, M).to(GPU_TYPE).half() + model = TestModel().to(GPU_TYPE) with config.patch( { @@ -1933,7 +2042,8 @@ def forward(self, B): ): _ = torch.compile(model)(B) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) @use_evt_config def test_evt_flexible_layout(self): @@ -1943,8 +2053,8 @@ def forward(self, B): return (A @ B.t()).relu() M = 1024 - B = torch.randn(M, M).cuda().half() - model = TestModel().cuda().half() + B = torch.randn(M, M).to(GPU_TYPE).half() + model = TestModel().to(GPU_TYPE).half() with config.patch( { @@ -1960,7 +2070,8 @@ def forward(self, B): 1, ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) def test_filtered_ops_cache(self): class TestModel(torch.nn.Module): @@ -1971,8 +2082,8 @@ def forward(self, B): return A M = 1024 - B = torch.randn(M, M).cuda().half() - model = TestModel().cuda() + B = torch.randn(M, M).to(GPU_TYPE).half() + model = TestModel().to(GPU_TYPE) start_time = time.time() with config.patch( @@ -1983,21 +2094,27 @@ def forward(self, B): } ): _ = torch.compile(model)(B) - self.assertTrue(time.time() - start_time < 60) - @unittest.skipIf(not SM90OrLater, "need sm_90") + if GPU_TYPE == "xpu": + time_limit = 100 + else: + time_limit = 60 + self.assertTrue(time.time() - start_time < time_limit) + + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) @parametrize("use_aoti", (False, True)) def test_compilation_time(self, use_aoti): M = 1024 - A = torch.randn(M, M).cuda().half() - B = torch.randn(M, M).cuda().half().t() + A = torch.randn(M, M).to(GPU_TYPE).half() + B = torch.randn(M, M).to(GPU_TYPE).half().t() class MyModel(torch.nn.Module): def forward(self, a, b): return a @ b - model = MyModel().cuda() + model = MyModel().to(GPU_TYPE) expected = model(A, B) start_time = time.time() @@ -2016,10 +2133,19 @@ def forward(self, a, b): else: actual = torch.compile(model, fullgraph=True)(A, B) - torch.testing.assert_close(actual, expected) - self.assertTrue(time.time() - start_time < 50) - - @unittest.skipIf(not SM90OrLater, "need sm_90") + if GPU_TYPE == "xpu": + atol = 1e-4 # default is 1e-5 + rtol = 1e-3 # default is 1e-3 + expected_time = 100 + else: + atol = None + rtol = None + expected_time = 50 + torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) + self.assertTrue(time.time() - start_time < expected_time) + + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @use_evt_config @evt_all_ops @evt_all_shapes @@ -2031,7 +2157,8 @@ def forward(self, a, b, extra_args): self.run_evt_test(TestModel(), op, shape) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @use_evt_config @evt_bin_ops def test_evt_broadcasting(self, op): @@ -2042,10 +2169,10 @@ def forward(self, a, b, extra_args): M = 1024 N = 512 - a = torch.ones(M, N).cuda().half() - b = torch.ones(N, N).cuda().half().t() + a = torch.ones(M, N).to(GPU_TYPE).half() + b = torch.ones(N, N).to(GPU_TYPE).half().t() extra_args = gen_args(op, (M, N)) - model = TestModel().cuda() + model = TestModel().to(GPU_TYPE) result = torch.compile(model)(a, b, extra_args) ref_result = model(a, b, extra_args) @@ -2056,7 +2183,8 @@ def forward(self, a, b, extra_args): ) torch.testing.assert_close(result, ref_result) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @use_evt_config @evt_un_ops def test_evt_activations(self, op): @@ -2067,10 +2195,10 @@ def forward(self, a, b, extra_args): M = 1024 N = 512 - a = torch.ones(M, N).cuda().half() - b = torch.ones(N, N).cuda().half().t() + a = torch.ones(M, N).to(GPU_TYPE).half() + b = torch.ones(N, N).to(GPU_TYPE).half().t() extra_args = gen_args(op, (M, N)) - model = TestModel().cuda() + model = TestModel().to(GPU_TYPE) result = torch.compile(model)(a, b, extra_args) ref_result = model(a, b, extra_args) @@ -2081,14 +2209,15 @@ def forward(self, a, b, extra_args): ) torch.testing.assert_close(result, ref_result) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @use_evt_config @evt_all_ops def test_evt_mixed_dtypes(self, op): M = 1024 N = 256 - fp32_tensor = torch.ones(M, N).cuda().float() + fp32_tensor = torch.ones(M, N).to(GPU_TYPE).float() class TestModel(torch.nn.Module): def forward(self, a, b, extra_args): @@ -2097,9 +2226,9 @@ def forward(self, a, b, extra_args): out1 = torch.add(out0, fp32_tensor) return out1 - model = TestModel().cuda() - a = torch.ones(M, N).cuda().half() - b = torch.ones(N, N).cuda().half().t() + model = TestModel().to(GPU_TYPE) + a = torch.ones(M, N).to(GPU_TYPE).half() + b = torch.ones(N, N).to(GPU_TYPE).half().t() extra_args = gen_args(op, (M, N), dtype=torch.float16) # baseline is cutlass kernel + triton @@ -2122,7 +2251,8 @@ def forward(self, a, b, extra_args): torch.testing.assert_close(result, ref_result) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @use_evt_config @evt_all_ops def test_evt_multi_op(self, op): @@ -2133,7 +2263,8 @@ def forward(self, a, b, extra_args): self.run_evt_test(TestModel(), op, (1024, 512)) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @use_evt_config @evt_all_ops def test_evt_reuse_matmul_input(self, op): @@ -2144,7 +2275,8 @@ def forward(self, a, b, extra_args): self.run_evt_test(TestModel(), op, (1024, 1024)) # shape needs to be square - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @use_evt_config @evt_all_ops @parametrize( @@ -2164,10 +2296,10 @@ def forward(self, a, b, extra_args): shapes = [(512, 512)] if not dynamic else [(1024, 64), (128, 256)] for i, shape in enumerate(shapes): M, N = shape - a = torch.ones(M, N).cuda().half() - b = torch.ones(N, N).cuda().half().t() + a = torch.ones(M, N).to(GPU_TYPE).half() + b = torch.ones(N, N).to(GPU_TYPE).half().t() extra_args = gen_args(op, (M, N)) - model = TestModel().cuda() + model = TestModel().to(GPU_TYPE) result = torch.compile(model)(a, b, extra_args) ref_result = model(a, b, extra_args) @@ -2180,7 +2312,8 @@ def forward(self, a, b, extra_args): ) torch.testing.assert_close(result, ref_result) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") @use_evt_config def test_evt_return_accumulator(self): op = torch.add @@ -2192,10 +2325,10 @@ def forward(self, a, b, extra_args): M = 1024 N = 512 - a = torch.ones(M, N).cuda().half() - b = torch.ones(N, N).cuda().half().t() + a = torch.ones(M, N).to(GPU_TYPE).half() + b = torch.ones(N, N).to(GPU_TYPE).half().t() extra_args = gen_args(op, (M, N)) - model = TestModel().cuda() + model = TestModel().to(GPU_TYPE) result = torch.compile(model)(a, b, extra_args) ref_result = model(a, b, extra_args) @@ -2206,19 +2339,18 @@ def forward(self, a, b, extra_args): ) torch.testing.assert_close(result, ref_result) - @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) - @parametrize("arch", ("90", "100")) - @parametrize("cuda_version", ("12.4", "12.8")) - def test_gemm_operation_serialization(self, arch: str, cuda_version: str): + def _test_gemm_operation_serialization( + self, arch: str, cuda_version: str, min_ops=1000 + ): """ Testing serialization for GEMM operations generated by CUTLASS. This should cover GroupedGemmOperation as well. """ - full_ops = _gen_ops_cached(arch, cuda_version) + full_ops = _gen_ops_cached(arch, cuda_version, GPU_TYPE) ops = pytree.tree_flatten(full_ops)[0] # sanity check - self.assertGreater(len(ops), 1000, "Too few ops generated") + self.assertGreater(len(ops), min_ops, "Too few ops generated") # test if configuration name is unique op_config_names = [op.configuration_name() for op in ops] @@ -2234,11 +2366,30 @@ def test_gemm_operation_serialization(self, arch: str, cuda_version: str): for op, deserialized_op in zip(ops, deserialized_ops, strict=False): self.assertTrue(_check_if_instances_equal(op, deserialized_op)) + @unittest.skipIf(not HAS_CUDA, "CUDA not available") + @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + @parametrize("arch", ("90", "100")) + @parametrize("cuda_version", ("12.4", "12.8")) + def test_gemm_operation_serialization_cuda(self, arch: str, cuda_version: str): + self._test_gemm_operation_serialization(arch, cuda_version) + + @skipXPUIf(not Xe2_Or_Later, "") + @unittest.skipIf(not HAS_XPU, "XPU not available") + @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()}) + @parametrize("arch", ("Xe12", "Xe20")) + @parametrize("xpu_version", ("20250201", "20250301")) + def test_gemm_operation_serialization_xpu(self, arch: str, xpu_version: str): + from torch._inductor.codegen.cutlass.utils import _normalize_xpu_arch + + arch = _normalize_xpu_arch(arch) + self._test_gemm_operation_serialization(arch, xpu_version, min_ops=40) + @unittest.skipIf( torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+", ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(True, "XPU SYCL-TLA has not supported fp8 yet") + @skipCUDAIf(not SM90OrLater, "need sm_90") @fp8_config @parametrize("float8_dtype", (torch.float8_e4m3fn,)) @parametrize( @@ -2264,7 +2415,7 @@ def test_fp8_rowwise_scaling( ): # Only bf16 output type is supported for row-wise scaling, not fp32 output_dtype: torch.dtype = torch.bfloat16 - device = "cuda" + device = GPU_TYPE M, K, N = shape # Matmul Y = X [M, K] x W [N, K] x = torch.randn(M, K, dtype=input_dtype, device=device) w = torch.randn(N, K, dtype=input_dtype, device=device) @@ -2315,7 +2466,8 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+", ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(True, "XPU SYCL-TLA has not supported fp8 yet") + @skipCUDAIf(not SM90OrLater, "need sm_90") @fp8_config @parametrize("float8_dtype", (torch.float8_e4m3fn,)) @parametrize( @@ -2345,7 +2497,7 @@ def test_fp8_rowwise_scaling_multiple_linear( self.skipTest("Accuracy issues when both AOTI and dynamic are enabled") # Only bf16 output type is supported for row-wise scaling, not fp32 output_dtype: torch.dtype = torch.bfloat16 - device = "cuda" + device = GPU_TYPE M, N = shape # Matmul Y = X [M, K] x W [N, K] x = torch.randn(M, N, dtype=output_dtype, device=device) w1 = torch.randn(N, N, dtype=output_dtype, device=device) @@ -2384,7 +2536,7 @@ def forward(self, x): ) return y2 - model = TestModule(w1, w2, float8_dtype).cuda() + model = TestModule(w1, w2, float8_dtype).to(GPU_TYPE) dynamic_shapes = ( { @@ -2412,7 +2564,8 @@ def forward(self, x): torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+", ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(True, "XPU SYCL-TLA has not supported fp8 yet") + @skipCUDAIf(not SM90OrLater, "need sm_90") @fp8_config @parametrize("float8_dtype", (torch.float8_e4m3fn,)) @parametrize( @@ -2436,7 +2589,7 @@ def test_fp8_tensorwise_scaling( use_fast_accum: bool, input_dtype: torch.dtype, ): - device = "cuda" + device = GPU_TYPE M, K, N = shape # Matmul Y = X [M, K] x W [N, K] output_dtype = input_dtype # input and output dtypes of _scaled_mm do not need to be the same, but @@ -2489,7 +2642,8 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): # setting a small absolute tolerance in these tests torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "") + @skipCUDAIf(not SM90OrLater, "need sm_90") def test_config_number_post_filtering(self) -> None: """ Test if cutlass backend produces the same number of configs after filtering @@ -2502,8 +2656,8 @@ def test_config_number_post_filtering(self) -> None: for layout in layouts: for dtype in dtypes: - a = torch.randn(128, 128, dtype=dtype).cuda() - b = torch.randn(128, 128, dtype=dtype).cuda() + a = torch.randn(128, 128, dtype=dtype).to(GPU_TYPE) + b = torch.randn(128, 128, dtype=dtype).to(GPU_TYPE) if layout[0] == "c": a = a.t() if layout[1] == "c": @@ -2533,8 +2687,10 @@ def test_config_number_post_filtering(self) -> None: # Check that all config counts are equal all_counts = list(config_counts.values()) + # XPU has more configs for bf16 than f16. + expected_count = 2 if GPU_TYPE == "xpu" else 1 self.assertTrue( - len(set(all_counts)) == 1, + len(set(all_counts)) == expected_count, f"Config counts should be equal across all layout/dtype combinations. " f"Got counts: {config_counts}", ) @@ -2544,5 +2700,5 @@ def test_config_number_post_filtering(self) -> None: from torch._inductor.utils import is_big_gpu # Set env to make it work in CI. - if HAS_CUDA_AND_TRITON and HAS_CPU and is_big_gpu(): + if HAS_GPU and HAS_CPU and is_big_gpu(): run_tests() diff --git a/test/inductor/test_cutlass_evt.py b/test/inductor/test_cutlass_evt.py index dd296b7f75ac7..f814880ac045a 100644 --- a/test/inductor/test_cutlass_evt.py +++ b/test/inductor/test_cutlass_evt.py @@ -12,10 +12,18 @@ from torch._inductor.ir import ComputedBuffer, FixedLayout, PermuteView, Pointwise from torch._inductor.scheduler import BaseSchedulerNode from torch._inductor.utils import OrderedSet -from torch.testing._internal.common_cuda import SM90OrLater +from torch.testing._internal.common_cuda import ( + IS_SM100, + IS_SM90, + SM120OrLater, + SM90OrLater, +) +from torch.testing._internal.common_device_type import skipCUDAIf, skipXPUIf +from torch.testing._internal.common_xpu import Xe2_Or_Later from torch.testing._internal.inductor_utils import ( + GPU_TYPE, HAS_CPU, - HAS_CUDA_AND_TRITON, + HAS_GPU_AND_TRITON, MockGraphHandler, ) @@ -66,8 +74,16 @@ ), } + # For SM100 + class MockMathInstruction: + def __init__(self): + self.opcode_class = cutlass_lib.library.OpcodeClass.TensorOp + class MockTileDescription: threadblock_shape = (128, 128, 8) + # SM100 path + cluster_shape = (1, 1, 1) + math_instruction = MockMathInstruction() def _create_mock_buffer_name_map(example_tensors): name_to_buffer = {} @@ -104,7 +120,10 @@ def num_reads(self): class TestCutlassEVT(TestCase): - @unittest.skipIf(not SM90OrLater, "need sm_90") + device_type = GPU_TYPE + + @skipXPUIf(not Xe2_Or_Later, "Unsupported platform") + @skipCUDAIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_py_codegen_accumulator_return(self): from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen @@ -161,7 +180,8 @@ def fn(accum, buf1, buf2): return tmp_0, tmp_2, D""", ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "Unsupported platform") + @skipCUDAIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_py_codegen_disjoint_read_indexing(self): from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen @@ -210,7 +230,8 @@ def inner_fn_buf4(index): index strides [200, 60000, 1], and layout stride [60000, 200, 1]""", ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "Unsupported platform") + @skipCUDAIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_py_codegen_broadcasting(self): from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen @@ -270,7 +291,8 @@ def fn(accum, buf1, buf2): return tmp_0, tmp_2, D""", ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "Unsupported platform") + @skipCUDAIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_py_codegen(self): from torch._inductor.codegen.cutlass.python_evt import CutlassEVTCodegen @@ -326,7 +348,8 @@ def fn(accum, buf1, buf2): return tmp_1, D""", ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "Unsupported platform") + @skipCUDAIf(not SM90OrLater, "need sm_90") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_example_tensor_creation(self): from torch._inductor.codegen.cutlass.lib_extensions.evt_extensions import ( @@ -358,21 +381,47 @@ def test_example_tensor_creation(self): result["buf1"].element, torch_dtype_to_cutlass_type(torch.float32) ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "Unsupported platform") + @skipCUDAIf(not SM90OrLater or SM120OrLater, "need sm_90 or sm_100") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_evt_argument_codegen(self): - from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch - - cuda_arch = int(get_cuda_arch()) # type: ignore[arg-type] - epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS, cuda_arch) + from torch._inductor.codegen.cutlass.utils import cutlass_arch - self.assertExpectedInline( - _render_argument_type( - epilogue_functor, - _create_mock_buffer_name_map(EXAMPLE_TENSORS), - lambda x: int(x), - )[0], - """\ + arch = int(cutlass_arch(GPU_TYPE)) + epilogue_functor = _trace(BIAS_CODE, EXAMPLE_TENSORS, arch) + code = _render_argument_type( + epilogue_functor, + _create_mock_buffer_name_map(EXAMPLE_TENSORS), + lambda x: int(x), + )[0] + if GPU_TYPE == "xpu": + self.assertExpectedInline( + code, + """\ +{ /* thread */ + { /* F */ + { /* compute_1 */ + { /* compute_0 */ + {}, /* accum */ + {}, /* C */ + {}, /* compute_0 */ + }, + {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* aux */ + {}, /* compute_1 */ + }, + {/* ptr_aux */ (float*) (ptr_1 + ptr_1_offset), /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* F */ + }, + {/* ptr_col */ (float*) (ptr_2 + ptr_2_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */ + {}, /* compute_2 */ + {}, /* compute_3 */ + {}, /* compute_4 */ + } +""", + ) + else: + self.assertExpectedInline( + code, + """\ { /* thread */ { /* F */ { /* compute_1 */ @@ -392,12 +441,15 @@ def test_evt_argument_codegen(self): {}, /* compute_4 */ } """, - ) + ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "Unsupported platform") + @skipCUDAIf(not SM90OrLater or SM120OrLater, "need sm_90 or sm_100") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_evt_argument_codegen_return_accumulator(self): - from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch + from torch._inductor.codegen.cutlass.utils import cutlass_arch + + arch = int(cutlass_arch(GPU_TYPE)) code = """ def fn(accum, bias): @@ -420,16 +472,31 @@ def fn(accum, bias): ), } - cuda_arch = int(get_cuda_arch()) # type: ignore[arg-type] - epilogue_functor = _trace(code, example_tensors, cuda_arch) + epilogue_functor = _trace(code, example_tensors, arch) + code = _render_argument_type( + epilogue_functor, + _create_mock_buffer_name_map(example_tensors), + lambda x: int(x), + )[0] - self.assertExpectedInline( - _render_argument_type( - epilogue_functor, - _create_mock_buffer_name_map(example_tensors), - lambda x: int(x), - )[0], - """\ + if GPU_TYPE == "xpu": + self.assertExpectedInline( + code, + """\ +{ /* thread */ + { /* E */ + {}, /* accum */ + {/* ptr_aux */ (float*) (ptr_0 + ptr_0_offset), /* null_default */ float(0), /* dAux */ {2048, _1{}, _0{}}}, /* E */ + }, + {/* ptr_col */ (float*) (ptr_1 + ptr_1_offset), /* null_default */ float(0), /* dCol */ {}}, /* bias */ + {}, /* compute_0 */ + } +""", + ) + else: + self.assertExpectedInline( + code, + """\ { /* thread */ { /* E */ {}, /* accum */ @@ -439,9 +506,10 @@ def fn(accum, bias): {}, /* compute_0 */ } """, - ) + ) - @unittest.skipIf(not SM90OrLater, "need sm_90") + @skipXPUIf(not Xe2_Or_Later, "Unsupported platform") + @skipCUDAIf(not SM90OrLater or SM120OrLater, "need sm_90 or sm_100") @unittest.skipIf(not try_import_cutlass(), "requires cutlass") def test_evt_codegen(self): _, _, code, _ = trace( @@ -453,10 +521,12 @@ def test_evt_codegen(self): EpilogueScheduleType.ScheduleAuto, _create_mock_buffer_name_map(EXAMPLE_TENSORS), lambda x: x, # static shapes + device_type=GPU_TYPE, ) - self.assertExpectedInline( - code, - """\ + if IS_SM90: + self.assertExpectedInline( + code, + """\ using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< cute::Shape<_128, _128, _8>, cutlass::epilogue::collective::EpilogueTileAuto, @@ -554,11 +624,202 @@ def test_evt_codegen(self): using StrideD = cute::Stride, cute::Int<0>>; """, - ) + ) + + if IS_SM100: + self.assertExpectedInline( + code, + """\ + +using EpilogueDescriptor = cutlass::epilogue::collective::detail::Sm100EpilogueDescriptor< + cutlass::arch::OpClassTensorOp, + cute::Shape<_128, _128, _8>, cutlass::epilogue::collective::EpilogueTileAuto, + float, float, float, + cutlass::epilogue::collective::EpilogueScheduleAuto, cute::Stride, cute::Int<0>>, cute::Stride, cute::Int<0>>, + false /* IsPerColScaleSupported */, + false /* IsBlockScaleSupported */ +>; + +using ElementC = float; +using StrideC = cute::Stride, cute::Int<0>>; +using TensorC = cutlass::epilogue::fusion::Sm90SrcFetch; + +using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + +using AuxDescriptor = cutlass::epilogue::collective::detail::Sm100AuxLoadDescriptor, cute::Int<0>>, float>; + +using Aux = cutlass::epilogue::fusion::Sm90AuxLoad< + AuxDescriptor::Stages, typename AuxDescriptor::EpilogueTile, float, + cute::Stride, cute::Int<0>>, typename AuxDescriptor::SmemLayoutAtom, typename AuxDescriptor::CopyOpS2R +>; + +using Bias = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, float, float, + cute::Stride, cute::Int<0>, cute::Int<0>> +>; + +using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT< + Compute0, + Accum, + TensorC>; + +using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT< + Compute1, + EVTCompute0, + Aux>; + +using FDescriptor = cutlass::epilogue::collective::detail::Sm100AuxStoreDescriptor< + EpilogueDescriptor, cute::Stride, cute::Int<0>>, float +>; + +using F = cutlass::epilogue::fusion::Sm90AuxStore< + FDescriptor::Stages, typename FDescriptor::EpilogueTile, float, + cutlass::FloatRoundStyle::round_to_nearest, cute::Stride, cute::Int<0>>, typename FDescriptor::SmemLayoutAtom, + typename FDescriptor::CopyOpR2S +>; + +using EVTF = cutlass::epilogue::fusion::Sm90EVT< + F, + EVTCompute1>; + +using Compute2 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::epilogue::thread::ReLu, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using Compute3 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using Compute4 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::plus, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using DagCompute4 = cutlass::epilogue::fusion::Sm90TopologicalVisitor< + float, + cute::tuple< + cute::seq<>, + cute::seq<>, + cute::seq<0>, + cute::seq<2, 1>, + cute::seq<3, 0>, + >, + EVTF, + Bias, + Compute2, + Compute3, + Compute4 +>; + +using ElementD = float; +using StrideD = cute::Stride, cute::Int<0>>; + +""", + ) + if GPU_TYPE == "xpu": + self.assertExpectedInline( + code, + """\ + +using TileShape_MNK = cute::Shape<_128, _128, _8>; + +using ElementC = float; +using StrideC = cute::Stride, cute::Int<0>>; +using TensorC = cutlass::epilogue::fusion::XeSrcFetch; + +using Accum = cutlass::epilogue::fusion::XeAccFetch; + +using Aux = cutlass::epilogue::fusion::XeAuxLoad< + float, + cute::Stride, cute::Int<0>> +>; + +using Bias = cutlass::epilogue::fusion::XeColBroadcast< + 0 /*Stages*/, TileShape_MNK, float, float, + cute::Stride, cute::Int<0>, cute::Int<0>> +>; + +using Compute0 = cutlass::epilogue::fusion::XeCompute< + cutlass::plus, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute0 = cutlass::epilogue::fusion::XeEVT< + Compute0, + Accum, + TensorC>; + +using Compute1 = cutlass::epilogue::fusion::XeCompute< + cutlass::plus, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute1 = cutlass::epilogue::fusion::XeEVT< + Compute1, + EVTCompute0, + Aux>; + +using F = cutlass::epilogue::fusion::XeAuxStore< + float, + cute::Stride, cute::Int<0>> +>; + +using EVTF = cutlass::epilogue::fusion::XeEVT< + F, + EVTCompute1>; + +using Compute2 = cutlass::epilogue::fusion::XeCompute< + cutlass::epilogue::thread::ReLu, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using Compute3 = cutlass::epilogue::fusion::XeCompute< + cutlass::plus, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using Compute4 = cutlass::epilogue::fusion::XeCompute< + cutlass::plus, float, float, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using DagCompute4 = cutlass::epilogue::fusion::XeTopologicalVisitor< + float, + cute::tuple< + cute::seq<>, + cute::seq<>, + cute::seq<0>, + cute::seq<2, 1>, + cute::seq<3, 0> + >, + EVTF, + Bias, + Compute2, + Compute3, + Compute4 +>; + +using ElementD = float; +using StrideD = cute::Stride, cute::Int<0>>; + +""", + ) if __name__ == "__main__": from torch._dynamo.test_case import run_tests - if HAS_CPU or HAS_CUDA_AND_TRITON: + if HAS_CPU or HAS_GPU_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_deterministic.py b/test/inductor/test_deterministic.py index c75e3f9961790..04ed4001a3ac1 100644 --- a/test/inductor/test_deterministic.py +++ b/test/inductor/test_deterministic.py @@ -16,7 +16,6 @@ instantiate_parametrized_tests, IS_FBCODE, parametrize, - skipIfXpu, ) from torch.testing._internal.inductor_utils import ( GPU_TYPE, @@ -48,7 +47,6 @@ def test_use_deterministic_algorithsm(self): finally: torch.use_deterministic_algorithms(old_val, warn_only=True) - @skipIfXpu(msg="pad_mm is not enabled for XPU.") @parametrize("deterministic", [False, True]) def test_mm_padding(self, deterministic): with inductor_config.patch(deterministic=deterministic): @@ -115,6 +113,56 @@ def foo(x): else: self.assertTrue(counters["inductor"]["coordesc_tuning_bench"] > 0) + @unittest.skipIf(not HAS_GPU_AND_TRITON, "requires GPU + Triton") + @inductor_config.patch(batch_invariant=True) + def test_persistent_reduction_batch_invariance(self): + H = 768 + FULL = 1024 + + def fn(x, w, b): + return torch.nn.functional.layer_norm(x, (H,), weight=w, bias=b) + + torch.manual_seed(0) + w = torch.randn(H, device=GPU_TYPE, dtype=torch.bfloat16) + b = torch.randn(H, device=GPU_TYPE, dtype=torch.bfloat16) + x_full = torch.randn(FULL, H, device=GPU_TYPE, dtype=torch.bfloat16) + + compiled = torch.compile(fn) + torch._dynamo.reset() + out_full = compiled(x_full, w, b) + self.assertEqual(out_full, fn(x_full, w, b)) + + # Halving sweep, matching what the benchmark harness does. + size = FULL // 2 + while size >= 1: + torch._dynamo.reset() + out = compiled(x_full[:size].contiguous(), w, b) + ref = out_full[:size].contiguous() + self.assertTrue( + torch.equal(ref, out), + f"persistent reduction diverged at size={size} (FULL={FULL})", + ) + size //= 2 + + def test_reorder_for_locality_preserves_randint_order(self): + with inductor_config.patch(fallback_random=True): + + def fn(): + torch.manual_seed(0) + out = torch.randint(0, 100, (4, 1), dtype=torch.int64) + _ = torch.randint(0, 100, (2, 1), dtype=torch.int64) + return out + + compiled = torch.compile(fn, backend="inductor") + + torch.manual_seed(0) + eager = fn() + + torch.manual_seed(0) + compiled_out = compiled() + + torch.testing.assert_close(eager, compiled_out) + @unittest.skipIf(IS_FBCODE, "Skipping run2run determinism test in fbcode") @parametrize("model_name", ["GoogleFnet", "BertForMaskedLM", "DistillGPT2"]) @parametrize("training_or_inference", ["training", "inference"]) diff --git a/test/inductor/test_dropout_align_random_eager.py b/test/inductor/test_dropout_align_random_eager.py new file mode 100644 index 0000000000000..78746e0113377 --- /dev/null +++ b/test/inductor/test_dropout_align_random_eager.py @@ -0,0 +1,482 @@ +# Owner(s): ["module: inductor"] + +import struct +import time + +import pytest + +import torch +from torch._inductor import config +from torch._inductor.test_case import run_tests, TestCase as InductorTestCase +from torch._inductor.utils import run_and_get_code +from torch.testing import FileCheck +from torch.testing._internal.common_utils import IS_LINUX +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CUDA_AND_TRITON, + requires_gpu, +) + + +# ─────────────────────────────────────────────────────────────── +# Global config +# ─────────────────────────────────────────────────────────────── +BASE_SEED = 1234 +DROPOUT_P = 0.5 +FFN_DIM = 3072 +HIDDEN_DIM = 1024 +BATCH = 3 +SEQ_LEN = 512 + + +# ─────────────────────────────────────────────────────────────── +# Model under test +# ─────────────────────────────────────────────────────────────── +class LinearBlock(torch.nn.Module): + def __init__(self, hidden_dim: int, ffn_dim: int, dropout: float = DROPOUT_P): + super().__init__() + self.net = torch.nn.Sequential( + torch.nn.Linear(hidden_dim, ffn_dim), + torch.nn.Dropout(dropout), + torch.nn.ReLU(inplace=False), + torch.nn.Linear(ffn_dim, hidden_dim), + ) + + def forward(self, x: torch.Tensor): + return self.net(x) + + +class MultiDropoutBlock(torch.nn.Module): + """Block with multiple Dropout ops to stress RNG alignment.""" + + def __init__(self, hidden_dim: int, ffn_dim: int, dropout: float = DROPOUT_P): + super().__init__() + self.net = torch.nn.Sequential( + torch.nn.Linear(hidden_dim, ffn_dim), + torch.nn.Dropout(dropout), + torch.nn.ReLU(inplace=False), + torch.nn.Dropout(dropout), + torch.nn.Linear(ffn_dim, hidden_dim), + torch.nn.Dropout(dropout), + ) + + def forward(self, x: torch.Tensor): + return self.net(x) + + +def build_models(dropout: float, *, mode=None, dynamic: bool = False): + eager = LinearBlock(HIDDEN_DIM, FFN_DIM, dropout) + compiled = LinearBlock(HIDDEN_DIM, FFN_DIM, dropout) + compiled.load_state_dict(eager.state_dict()) + compiled = torch.compile(compiled, mode=mode, dynamic=dynamic) + return eager, compiled + + +# ─────────────────────────────────────────────────────────────── +# Helpers +# ─────────────────────────────────────────────────────────────── +def _set_seed(base: int = BASE_SEED): + torch.manual_seed(base) + if torch.cuda.is_available(): + torch.cuda.manual_seed(base) + + +def _sync(x: torch.Tensor): + if x.is_cuda: + torch.cuda.synchronize() + + +def _timed_run(model, x, backward: bool = False): + _sync(x) + t0 = time.time() + y = model(x) + if backward: + (y.square().mean()).backward() + _sync(x) + return (time.time() - t0) * 1e3, y + + +def _cuda_rng_u64_seed_off(): + """Return (seed, offset) extracted from torch.cuda.get_rng_state().""" + st = torch.cuda.get_rng_state() + seed = struct.unpack(" mark as XFAIL + # ─────────────────────────────────────────────────────────── + @requires_gpu() + @pytest.mark.xfail( + reason="primitive torch.rand parity is tracked as future work", + strict=False, + ) + def test_primitive_rand_parity(self): + device = torch.device(GPU_TYPE) + shape = (BATCH, SEQ_LEN, HIDDEN_DIM) + self._run_primitive_random_parity("rand", device, shape) + + @requires_gpu() + @pytest.mark.xfail( + reason="primitive torch.randn parity is tracked as future work", + strict=False, + ) + def test_primitive_randn_parity(self): + device = torch.device(GPU_TYPE) + shape = (BATCH, SEQ_LEN, HIDDEN_DIM) + self._run_primitive_random_parity("randn", device, shape) + + @requires_gpu() + @pytest.mark.xfail( + reason="primitive torch.randint parity is tracked as future work", + strict=False, + ) + def test_primitive_randint_parity(self): + device = torch.device(GPU_TYPE) + shape = (BATCH, SEQ_LEN, HIDDEN_DIM) + self._run_primitive_random_parity("randint", device, shape) + + # ─────────────────────────────────────────────────────────── + # nn.Dropout as primitive RNG consumer (should PASS) + # ─────────────────────────────────────────────────────────── + @requires_gpu() + def test_primitive_nn_dropout_parity(self): + device = torch.device(GPU_TYPE) + shape = (BATCH, SEQ_LEN, HIDDEN_DIM) + + x = torch.ones(shape, device=device) + + drop_eager = torch.nn.Dropout(DROPOUT_P).to(device).train() + drop_compiled = torch.nn.Dropout(DROPOUT_P).to(device).train() + drop_compiled.load_state_dict(drop_eager.state_dict()) + drop_compiled = torch.compile(drop_compiled) + + _set_seed(BASE_SEED) + out_eager = drop_eager(x) + + _set_seed(BASE_SEED) + out_comp = drop_compiled(x) + + torch.testing.assert_close(out_eager, out_comp, rtol=0.0, atol=0.0) + + # ─────────────────────────────────────────────────────────── + # Large seed (>32-bit) packing truncation + # Seed and base are packed into int64 as (seed << 32) | base. + # Seeds > 2^32 overflow. + # ─────────────────────────────────────────────────────────── + @requires_gpu() + def test_large_seed(self): + for seed in [2**33 + 1, 2**40 + 12345]: + with self.subTest(seed=seed): + masks_eq, _, _ = dropout_parity((1024,), seed=seed) + self.assertTrue(masks_eq, f"seed={seed}: mask mismatch") + + +if __name__ == "__main__": + if IS_LINUX and HAS_CUDA_AND_TRITON: + run_tests(needs="filelock") diff --git a/test/inductor/test_extension_backend.py b/test/inductor/test_extension_backend.py index 08b458d761f46..6b2148b51adca 100644 --- a/test/inductor/test_extension_backend.py +++ b/test/inductor/test_extension_backend.py @@ -1,6 +1,7 @@ # Owner(s): ["module: inductor"] import os import sys +import tempfile import unittest import torch @@ -11,7 +12,7 @@ try: - from extension_backends.cpp.extension_codegen_backend import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend # noqa: B950 + from extension_backends.cpp.extension_codegen_backend import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend ExtensionCppWrapperCodegen, ExtensionScheduling, ExtensionWrapperCodegen, @@ -23,8 +24,6 @@ ExtensionWrapperCodegen, ) -from filelock import FileLock, Timeout - import torch._inductor.config as config from torch._inductor import cpu_vec_isa, metrics from torch._inductor.codegen import cpp_utils @@ -32,7 +31,9 @@ get_scheduling_for_device, get_wrapper_codegen_for_device, register_backend_for_device, + register_device_op_overrides, ) +from torch._inductor.codegen.cpu_device_op_overrides import CpuDeviceOpOverrides from torch.testing._internal.common_utils import IS_FBCODE, IS_MACOS, xfailIfS390X @@ -55,22 +56,11 @@ class BaseExtensionBackendTests(TestCase): module = None - # Use a lock file so that only one test can build this extension at a time - lock_file = "extension_device.lock" - lock = FileLock(lock_file) - @classmethod def setUpClass(cls): super().setUpClass() - try: - cls.lock.acquire(timeout=600) - except Timeout: - # This shouldn't happen, still attempt to build the extension anyway - pass - - # Build Extension - torch.testing._internal.common_utils.remove_cpp_extensions_build_root() + cls._build_dir = tempfile.TemporaryDirectory() source_file_path = os.path.dirname(os.path.abspath(__file__)) source_file = os.path.join( source_file_path, "extension_backends/cpp/extension_device.cpp" @@ -82,6 +72,7 @@ def setUpClass(cls): ], extra_cflags=["-g"], verbose=True, + build_directory=cls._build_dir.name, ) @classmethod @@ -89,11 +80,7 @@ def tearDownClass(cls): cls._stack.close() super().tearDownClass() - torch.testing._internal.common_utils.remove_cpp_extensions_build_root() - - cls.lock.release() - if os.path.exists(cls.lock_file): - os.remove(cls.lock_file) + cls._build_dir.cleanup() def setUp(self): torch._dynamo.reset() @@ -110,7 +97,12 @@ def tearDown(self): super().tearDown() torch._dynamo.reset() - # return the working directory (see setUp) + backend_name = torch._C._get_privateuse1_backend_name() + if hasattr(torch, backend_name): + delattr(torch, backend_name) + if f"torch.{backend_name}" in sys.modules: + del sys.modules[f"torch.{backend_name}"] + os.chdir(self.old_working_dir) @@ -127,6 +119,7 @@ def test_open_device_registration(self): ExtensionWrapperCodegen, ExtensionCppWrapperCodegen, ) + register_device_op_overrides("extension_device", CpuDeviceOpOverrides()) self.assertTrue( get_scheduling_for_device("extension_device") == ExtensionScheduling ) @@ -181,4 +174,4 @@ def fn(a, b, c): # cpp_extension doesn't work in fbcode right now if HAS_CPU and not IS_MACOS and not IS_FBCODE: - run_tests(needs="filelock") + run_tests() diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index f38a6dc6896a1..73c440bdff4d5 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1,5 +1,4 @@ # Owner(s): ["module: inductor"] -# flake8: noqa: B950 import functools import json @@ -70,6 +69,11 @@ skipXPUIf, ) from torch.testing._internal.common_quantized import _snr +from torch.testing._internal.common_utils import ( # noqa: F401 + MI200_ARCH, + skipIfRocm, + skipIfRocmArch, +) from torch.testing._internal.inductor_utils import HAS_GPU from torch.utils._triton import has_triton, has_triton_tma_device @@ -178,6 +182,21 @@ def create_attention(score_mod, block_mask, enable_gqa=False, kernel_options=Non ) +def flex_attention_fwd(q, k, v, score_mod, block_mask, scale): + # Uses the HOP directly because flex_attention_backward expects lse in log2 + # scale, but the public flex_attention API converts lse to natural log. + out, lse, _ = flex_attention_hop( + q, + k, + v, + score_mod, + block_mask.as_tuple(), + scale, + {}, + ) + return out, lse + + def create_block_mask_test(score_mod, query, key): block_mask = create_block_mask( score_mod, @@ -4846,7 +4865,7 @@ class mask_fn_0(torch.nn.Module): def forward(self, child: "i32[]", child_1: "i32[]", child_2: "i32[]", child_3: "i32[]"): ge: "b8[]" = child_2 >= child_3; child_2 = child_3 = None return ge -""", # noqa: B950 +""", ) # Save the AOT graphs aot_graphs = [] @@ -4894,10 +4913,197 @@ class mask_graph0(torch.nn.Module): def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]"): full_default: "b8[]" = torch.ops.aten.full.default([], True, dtype = torch.bool, layout = torch.strided, device = device(type='GPU_TYPE', index=0), pin_memory = False) return full_default -""".replace( # noqa: B950 - "GPU_TYPE", torch.device(device).type - ), +""".replace("GPU_TYPE", torch.device(device).type), + ) + + @supported_platform + @skip_on_cpu + def test_direct_backward_preserves_explicit_buffers(self, device): + mask_buffer = torch.full((), 128, device=device, dtype=torch.int32) + + def score_mod(score, b, h, m, n): + return score + + def mask_mod(b, h, m, n): + return m + mask_buffer >= n + + block_mask = create_block_mask( + mask_mod, + B=2, + H=2, + Q_LEN=128, + KV_LEN=128, + device=device, + ) + scale = 1.0 / 16**0.5 + + dtype = torch.float32 + q = torch.randn( + (2, 2, 128, 16), + dtype=dtype, + device=device, + requires_grad=True, ) + k = torch.randn( + (2, 2, 128, 16), + dtype=dtype, + device=device, + requires_grad=True, + ) + v = torch.randn( + (2, 2, 128, 16), + dtype=dtype, + device=device, + requires_grad=True, + ) + q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) + q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) + + sdpa_partial = create_attention(score_mod, block_mask) + golden_out = sdpa_partial(q_gold, k_gold, v_gold) + ref_out = sdpa_partial(q_ref, k_ref, v_ref) + + backward_grad = torch.randn((2, 2, 128, 16), dtype=dtype, device=device) + golden_out.backward(backward_grad.to(torch.float64)) + ref_out.backward(backward_grad) + + out, logsumexp = flex_attention_fwd( + q, + k, + v, + score_mod, + block_mask, + scale, + ) + + @torch.compile(fullgraph=True) + def compiled_bw(query, key, value, fwd_out, lse, grad_out): + return torch.ops.higher_order.flex_attention_backward( + query, + key, + value, + fwd_out, + lse, + grad_out, + None, + score_mod, + None, + block_mask.as_tuple(), + scale, + {}, + (), + (), + ) + + with torch.no_grad(): + dq, dk, dv, _ = compiled_bw( + q, + k, + v, + out.detach(), + logsumexp.detach(), + backward_grad, + ) + + fudge_factor = 10.0 + self._check_equal(q_gold.grad, q_ref.grad, dq, fudge_factor, "Grad_Query") + self._check_equal(k_gold.grad, k_ref.grad, dk, fudge_factor, "Grad_Key") + self._check_equal(v_gold.grad, v_ref.grad, dv, fudge_factor, "Grad_Value") + + @supported_platform + @skip_on_cpu + def test_direct_backward_supports_symint_score_mod_buffers(self, device): + def score_mod(score, b, h, m, n, head_dim): + return score + + def mask_mod(b, h, m, n): + return m >= n + + block_mask = create_block_mask( + mask_mod, + B=2, + H=2, + Q_LEN=128, + KV_LEN=128, + device=device, + ) + scale = 1.0 / 16**0.5 + dtype = torch.float32 + q = torch.randn((2, 2, 128, 16), dtype=dtype, device=device) + k = torch.randn((2, 2, 128, 16), dtype=dtype, device=device) + v = torch.randn((2, 2, 128, 16), dtype=dtype, device=device) + backward_grad = torch.randn((2, 2, 128, 16), dtype=dtype, device=device) + + out, logsumexp = flex_attention_fwd( + q, + k, + v, + _identity, + block_mask, + scale, + ) + static_head_dim = q.shape[-1] + + @torch.compile(backend="aot_eager", fullgraph=True) + def compiled_literal_bw(query, key, value, fwd_out, lse, grad_out): + return torch.ops.higher_order.flex_attention_backward( + query, + key, + value, + fwd_out, + lse, + grad_out, + None, + score_mod, + None, + block_mask.as_tuple(), + scale, + {}, + (static_head_dim,), + (), + ) + + @torch.compile(backend="aot_eager", fullgraph=True, dynamic=True) + def compiled_bw(query, key, value, fwd_out, lse, grad_out): + return torch.ops.higher_order.flex_attention_backward( + query, + key, + value, + fwd_out, + lse, + grad_out, + None, + score_mod, + None, + block_mask.as_tuple(), + scale, + {}, + (query.shape[-1],), + (), + ) + + with torch.no_grad(): + ref_dq, ref_dk, ref_dv, ref_buffer_grads = compiled_literal_bw( + q, + k, + v, + out.detach(), + logsumexp.detach(), + backward_grad, + ) + compiled_dq, compiled_dk, compiled_dv, compiled_buffer_grads = compiled_bw( + q, + k, + v, + out.detach(), + logsumexp.detach(), + backward_grad, + ) + + torch.testing.assert_close(compiled_dq, ref_dq) + torch.testing.assert_close(compiled_dk, ref_dk) + torch.testing.assert_close(compiled_dv, ref_dv) + self.assertEqual(compiled_buffer_grads, ref_buffer_grads) @supported_platform def test_tensor_subclass_dispatch_order(self, device): diff --git a/test/inductor/test_flex_decoding.py b/test/inductor/test_flex_decoding.py index d172e4b565187..c546e2cbb8774 100644 --- a/test/inductor/test_flex_decoding.py +++ b/test/inductor/test_flex_decoding.py @@ -1,5 +1,4 @@ # Owner(s): ["module: inductor"] -# flake8: noqa: B950 import functools import sys @@ -2026,7 +2025,7 @@ def causal_offset_mask(b, h, q_idx, kv_idx): return causal_offset_mask - def noop(score, b, h, q_idx, kv_idx): # noqa: F841 + def noop(score, b, h, q_idx, kv_idx): return score mod = generate_causal_offset( diff --git a/test/inductor/test_flex_flash.py b/test/inductor/test_flex_flash.py index fbc7209ec1e12..2c6e1e268d0ed 100644 --- a/test/inductor/test_flex_flash.py +++ b/test/inductor/test_flex_flash.py @@ -25,6 +25,11 @@ from torch.testing._internal.common_cuda import ( IS_SM90, PLATFORM_SUPPORTS_FP8, + SM120OrLater, + SM80OrLater, + SM90OrLater, + xfailIfSM120OrLater, + xfailIfSM12X, xfailIfSM90, ) from torch.testing._internal.common_device_type import ( @@ -39,6 +44,9 @@ ) +IS_SM8X = SM80OrLater and not SM90OrLater + + def _times_two(score, _b, _h, _m, _n): return score * 2 @@ -180,6 +188,21 @@ def score_with_buffer(score, _b, h, _q_idx, _kv_idx): return score_with_buffer +def create_mask_mod_scalar_tensor(device="cuda"): + """mask_mod that captures a 0-dim (scalar) tensor. + + Regression test: loading a 0-dim tensor in CuTeDSL codegen produces a + constant index 0, which was incorrectly passed as a bare Python int to + ssa_to_indexable (expects TensorSSA). See #177813. + """ + offset = torch.tensor(5, dtype=torch.int32, device=device) + + def mask_with_scalar_tensor(_b, _h, q_idx, kv_idx): + return (q_idx + offset) >= kv_idx + + return mask_with_scalar_tensor + + def create_mask_mod_buffer(num_heads=4, dtype=torch.float16, device="cuda"): mask_bias = torch.randn(num_heads, device=device, dtype=dtype) * 0.1 @@ -265,7 +288,9 @@ def _create_block_mask_for_device( dev = torch.device(device) if dev.type == "cuda": major, _ = torch.cuda.get_device_capability(dev) - if major >= 10: + if major == 8: + kv_block = 64 + elif major == 10: q_block *= 2 return create_block_mask( mask_mod, @@ -598,6 +623,10 @@ def mask_case_name(case: MaskModCase): MASK_MOD_CASES = [ + MaskModCase( + "mask_mod_scalar_tensor", + lambda _dtype, device: create_mask_mod_scalar_tensor(device=device), + ), MaskModCase("block_mask_causal", lambda _dtype, _device: _causal_mask), MaskModCase( "block_mask_causal_score_times_two", @@ -782,7 +811,10 @@ def mask_case_name(case: MaskModCase): @unittest.skipIf( not ensure_flash_available(), "Flash attention (CUTE) library is not available" ) +@xfailIfSM12X class TestFlexFlash(InductorTestCase): + # `FlashAttentionForwardSm120` does not have `apply_score_mod`. + @xfailIfSM120OrLater @decorateIf( unittest.expectedFailure, lambda params: params["case"].requires_grad and IS_SM90, @@ -822,16 +854,42 @@ def test_flash_attention_backward_deterministic_score_mod_cases( device=device, requires_grad=case.requires_grad, ) - with DeterministicGuard(True): - flash_vs_triton( - q, - k, - v, - score_mod=case.score_mod_factory(dtype, device) - if case.score_mod_factory - else None, + if SM120OrLater: + cls, pattern = ( + AttributeError, + r"'FlashAttentionForwardSm120' object has no attribute 'apply_score_mod'", ) + if "gqa_basic" in case.name or "mqa_basic" in case.name: + cls, pattern = ValueError, "Operation creation failed" + with self.assertRaisesRegex(cls, pattern), DeterministicGuard(True): + flash_vs_triton( + q, + k, + v, + score_mod=case.score_mod_factory(dtype, device) + if case.score_mod_factory + else None, + ) + else: + with DeterministicGuard(True): + flash_vs_triton( + q, + k, + v, + score_mod=case.score_mod_factory(dtype, device) + if case.score_mod_factory + else None, + ) + @xfailIfSM120OrLater + @decorateIf( + unittest.expectedFailure, + lambda params: ( + IS_SM8X + and not params["case"].requires_grad + and params["case"].score_mod_factory is not None + ), + ) @dtypes(torch.float16, torch.bfloat16) @parametrize("case", MASK_MOD_CASES, name_fn=mask_case_name) def test_flash_attention_mask_mod_cases(self, device, dtype, case): @@ -898,8 +956,12 @@ def test_flash_attention_backward_deterministic_block_mask_raises( with DeterministicGuard(True): with self.assertRaisesRegex( - BackendCompilerFailed, - "Deterministic backward for flex_attention with block_mask using the FLASH backend", + (BackendCompilerFailed if not SM120OrLater else AssertionError), + ( + "Deterministic backward for flex_attention with block_mask using the FLASH backend" + if not SM120OrLater + else "Block sparsity not supported on SM 12.0" + ), ): out = compiled_fn( q, @@ -915,6 +977,20 @@ def test_flash_attention_backward_deterministic_block_mask_raises( ) out.sum().backward() + @decorateIf( + unittest.expectedFailure, + lambda params: ( + SM120OrLater + and params["case"].name + in { + "backward_block_mask_causal", + "backward_block_mask_causal_rel_bias", + "backward_block_mask_causal_score_squared", + "backward_block_mask_causal_score_times_two", + "mask_mod_view_buffer", + } + ), + ) @dtypes(torch.float16, torch.bfloat16) @parametrize("case", DETERMINISTIC_MASK_MOD_CASES, name_fn=mask_case_name) def test_flash_attention_backward_deterministic_warn_only_block_mask( @@ -963,6 +1039,35 @@ def test_flash_attention_backward_deterministic_warn_only_block_mask( ): out.sum().backward() + @decorateIf( + unittest.expectedFailure, + lambda params: ( + SM120OrLater + and params["case"].name + in { + "gqa_block_mask_causal", + "gqa_block_mask_causal_per_head", + "backward_gqa_block_mask_causal", + "backward_gqa_block_mask_causal_per_head", + "mqa_block_mask_causal", + "mqa_block_mask_causal_per_head", + "backward_mqa_block_mask_causal", + "backward_mqa_block_mask_causal_per_head", + } + ), + ) + @decorateIf( + unittest.expectedFailure, + lambda params: ( + IS_SM8X + and not params["case"].requires_grad + and params["case"].block_mask_num_heads == 1 + ), + ) + @decorateIf( + unittest.expectedFailure, + lambda params: SM120OrLater and params["case"].name.endswith("_dim128"), + ) @dtypes(torch.float16, torch.bfloat16) @parametrize("case", GQA_MQA_BLOCK_MASK_CASES, name_fn=mask_case_name) def test_flash_attention_gqa_mqa_block_mask_cases(self, device, dtype, case): @@ -1000,6 +1105,7 @@ def test_flash_attention_gqa_mqa_block_mask_cases(self, device, dtype, case): ), ) + @xfailIfSM120OrLater @dtypes(torch.float16, torch.bfloat16) def test_flash_attention_kernel_called(self, device, dtype): q, k, v = create_test_tensors(dtype=dtype, device=device) @@ -1088,6 +1194,7 @@ def score_mod_with_capture(score, b, h, q_idx, kv_idx): kernel_options={"BACKEND": "FLASH"}, ).sum().backward() + @xfailIfSM120OrLater @dtypes(torch.float16, torch.bfloat16) def test_flash_attention_backward_kernel_called(self, device, dtype): q, k, v = create_test_tensors(dim=128, dtype=dtype, device=device) @@ -1113,6 +1220,7 @@ def run_for_profile(): f"Flash attention backward kernel not found. Kernels: {prof_result['kernel_names']}", ) + @xfailIfSM120OrLater @dtypes(torch.float16, torch.bfloat16) def test_flash_attention_backward_forwards_deterministic_flag(self, device, dtype): q, k, v = create_test_tensors(dim=128, dtype=dtype, device=device) @@ -1131,6 +1239,7 @@ def run_for_code(): "Expected deterministic flag to be wired through flash backward", ) + @xfailIfSM120OrLater @dtypes(torch.float16, torch.bfloat16) def test_flash_attention_generates_cute_hash(self, device, dtype): q, k, v = create_test_tensors(dtype=dtype, device=device) @@ -1151,6 +1260,7 @@ def test_flash_attention_generates_cute_hash(self, device, dtype): "Generated code should set __cute_hash__ on score_mod for fast hashing", ) + @xfailIfSM120OrLater @dtypes(torch.float16, torch.bfloat16) def test_flash_attention_fused_qkv_reinterpret_view(self, device, dtype): B, M, H, D = 2, 256, 4, 64 @@ -1170,6 +1280,7 @@ def fn(x, weight): out = compiled_fn(x, weight) self.assertEqual(out.shape, (B, H, M, D)) + @xfailIfSM120OrLater @dtypes(torch.float16, torch.bfloat16) def test_gqa_expand_stride_zero_backward(self, device, dtype): """Test GQA backward with expand()-created K/V tensors (stride=0). @@ -1203,8 +1314,17 @@ def test_gqa_expand_stride_zero_backward(self, device, dtype): _causal_mask, batch_size, n_heads, seqlen, seqlen, device=device ) - flash_vs_triton(q, k, v, block_mask=block_mask) + if SM120OrLater: + # note: see [SM120 forward tile selection failure] + with self.assertRaisesRegex( + ValueError, + r"Block sparsity requires sparse_block_size[1]=64 to match tile_n", + ): + flash_vs_triton(q, k, v, block_mask=block_mask) + else: + flash_vs_triton(q, k, v, block_mask=block_mask) + @xfailIfSM120OrLater @dtypes(torch.float16, torch.bfloat16) def test_flash_backend_raises_on_grad_logsumexp(self, device, dtype): from torch._dynamo.exc import BackendCompilerFailed @@ -1238,6 +1358,8 @@ def test_flash_backend_raises_on_return_max_scores(self, device, dtype): kernel_options={"BACKEND": "FLASH"}, ) + # 'FlashAttentionForwardSm120' object has no attribute 'apply_score_mod' + @xfailIfSM120OrLater @decorateIf( unittest.expectedFailure, lambda params: IS_SM90, @@ -1299,10 +1421,13 @@ def _run_dynamic_test( def _flash_triton_dynamic(self, q, k, v, **kwargs): flash_vs_triton(q, k, v, dynamic=True, **kwargs) + # sm120: AttributeError: 'NoneType' object has no attribute '_trait' + @xfailIfSM120OrLater def test_dynamic_seq_len_no_score_mod(self): """Test dynamic sequence lengths without score_mod.""" self._run_dynamic_test(seq_lens=[128, 256, 512]) + @xfailIfSM120OrLater def test_dynamic_seq_len_inline_literal(self): """Test dynamic sequence lengths with inline literal score_mod.""" @@ -1311,6 +1436,7 @@ def score_mod(score, _b, _h, _q, _k): self._run_dynamic_test(seq_lens=[128, 256, 512], score_mod=score_mod) + @xfailIfSM120OrLater def test_dynamic_seq_len_captured_tensor_buffer(self): """Test dynamic sequence lengths with captured tensor buffer (ALiBi-style).""" num_heads = 4 @@ -1323,6 +1449,7 @@ def alibi_score_mod(score, b, h, q_idx, kv_idx): self._run_dynamic_test(seq_lens=[128, 256, 512], score_mod=alibi_score_mod) + @xfailIfSM120OrLater def test_dynamic_seq_len_with_block_mask(self): """Test dynamic sequence lengths with block mask.""" @@ -1335,6 +1462,7 @@ def block_mask_factory(seq_len): seq_lens=[128, 256, 512], block_mask_factory=block_mask_factory ) + @xfailIfSM120OrLater def test_dynamic_batch_size(self): """Test dynamic batch sizes.""" for batch_size in [1, 2, 4, 8]: @@ -1343,11 +1471,14 @@ def test_dynamic_batch_size(self): ) self._flash_triton_dynamic(q, k, v) + @xfailIfSM120OrLater @xfailIfSM90 def test_dynamic_backward(self): """Test backward with dynamic sequence lengths.""" self._run_dynamic_test(seq_lens=[128, 256, 512], requires_grad=True) + # 'FlashAttentionForwardSm120' object has no attribute 'apply_score_mod' + @xfailIfSM120OrLater @xfailIfSM90 def test_dynamic_backward_with_score_mod(self): """Test backward with score_mod and dynamic sequence lengths.""" @@ -1359,6 +1490,7 @@ def score_mod(score, _b, _h, _q, _k): seq_lens=[128, 256, 512], score_mod=score_mod, requires_grad=True ) + @xfailIfSM120OrLater def test_dynamic_backward_with_block_mask(self): """Test backward with block mask and dynamic sequence lengths.""" major, _ = torch.cuda.get_device_capability() @@ -1376,6 +1508,7 @@ def block_mask_factory(seq_len): requires_grad=True, ) + @xfailIfSM120OrLater def test_dynamic_gqa(self): """Test GQA with dynamic sequence lengths.""" q_heads, kv_heads = 8, 2 @@ -1389,6 +1522,7 @@ def test_dynamic_gqa(self): ) self._flash_triton_dynamic(q, k, v, score_mod=None, block_mask=None) + @xfailIfSM120OrLater def test_dynamic_mqa(self): """Test MQA with dynamic sequence lengths.""" q_heads, kv_heads = 8, 1 @@ -1402,6 +1536,7 @@ def test_dynamic_mqa(self): ) self._flash_triton_dynamic(q, k, v) + @xfailIfSM120OrLater def test_dynamic_non_divisible_seq_len(self): """Test non-block-divisible sequence lengths with dynamic shapes.""" for seq_len in [127, 255, 383, 511, 513]: @@ -1410,6 +1545,7 @@ def test_dynamic_non_divisible_seq_len(self): ) self._flash_triton_dynamic(q, k, v) + @xfailIfSM120OrLater def test_dynamic_asymmetric_qkv_lengths(self): """Test asymmetric Q and KV lengths with dynamic shapes.""" test_cases = [(256, 512), (512, 256), (128, 1024)] @@ -1453,6 +1589,7 @@ def score_mod(score, _b, _h, _q, _k): q, k, v, score_mod=score_mod, kernel_options={"BACKEND": "FLASH"} ) + @xfailIfSM120OrLater def test_captured_float_works_with_static(self): """Test that captured Python float works with dynamic=False.""" val = 2.0 # Captured float @@ -1468,6 +1605,7 @@ def score_mod(score, _b, _h, _q, _k): ) self.assertEqual(out.shape, q.shape) + @xfailIfSM120OrLater def test_dynamic_mask_from_input_lengths_single_graph(self): """Dynamic mask creation driven by input lengths should stay single-graph.""" counter = CompileCounterWithBackend("inductor") @@ -1514,6 +1652,7 @@ def forward(self, x, input_lengths): counter.frame_count, 1, f"Expected 1 graph, got {counter.frame_count}" ) + @xfailIfSM120OrLater def test_dynamic_free_symbol_mask_single_graph(self): """Free-symbol dense mask under dynamic=True should not recompile.""" counter = CompileCounterWithBackend("inductor") @@ -1550,6 +1689,7 @@ def run(q, k, v, block_mask): counter.frame_count, 1, f"Expected 1 graph, got {counter.frame_count}" ) + @xfailIfSM120OrLater def test_dynamic_max_autotune_with_block_mask(self): """Dynamic=True with max-autotune should succeed for FLASH backend.""" q, k, v = create_test_tensors( @@ -1576,6 +1716,7 @@ def test_dynamic_max_autotune_with_block_mask(self): ) self.assertEqual(out.shape, q.shape) + @xfailIfSM120OrLater @xfailIfSM90 def test_dynamic_captured_buffer_varying_heads(self): """Dynamic head_count with captured tensor buffer under FLASH/TRITON parity.""" @@ -1708,6 +1849,8 @@ def test_hierarchical_indexer_rank_mismatch(self): indexer([b]) self.assertIn("Rank mismatch", str(ctx.exception)) + # 'FlashAttentionForwardSm120' object has no attribute 'apply_score_mod' + @xfailIfSM120OrLater @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @unittest.skipIf( not ensure_flash_available(), "Flash attention (CUTE) library not available" @@ -1751,6 +1894,8 @@ def score_mod_2d(score, b, h, q_idx, kv_idx): f"Expected '{expected_pattern}' in generated code.\nExcerpt:\n{code_str[:2000]}", ) + # 'FlashAttentionForwardSm120' object has no attribute 'apply_score_mod' + @xfailIfSM120OrLater @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @unittest.skipIf( not ensure_flash_available(), "Flash attention (CUTE) library not available" @@ -1796,6 +1941,7 @@ def score_mod_3d(score, b, h, q_idx, kv_idx): f"Expected '{expected_pattern}' in generated code.\nExcerpt:\n{code_str[:2000]}", ) + @xfailIfSM120OrLater @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @unittest.skipIf( not ensure_flash_available(), "Flash attention (CUTE) library not available" diff --git a/test/inductor/test_foreach.py b/test/inductor/test_foreach.py index a1203c909338e..4e85153a6a0f5 100644 --- a/test/inductor/test_foreach.py +++ b/test/inductor/test_foreach.py @@ -157,6 +157,7 @@ def recipaddmul_op(x, y, z): torch._foreach_abs, torch._foreach_sqrt, torch._foreach_rsqrt, + torch._foreach_clone, *foreach_map_un_ops_under_test, ] diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index c91c8b186e7c4..ba694a0612820 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -18,7 +18,6 @@ PLATFORM_SUPPORTS_FP8, PLATFORM_SUPPORTS_MX_GEMM, SM100OrLater, - SM90OrLater, ) from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, @@ -27,12 +26,13 @@ skipCUDAIf, ) from torch.testing._internal.common_quantized import ceil_div, to_blocked -from torch.testing._internal.common_utils import parametrize, xfailIf +from torch.testing._internal.common_utils import parametrize, skipIfXpu, xfailIf from torch.testing._internal.inductor_utils import ( _quantize_blockwise, _quantize_rowwise, _quantize_tensorwise, _to_fp8_saturated, + GPU_TYPE, HAS_CPU, HAS_CUDA_AND_TRITON, is_big_gpu, @@ -43,7 +43,7 @@ torch.set_float32_matmul_precision("high") -f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ and XPU devices" +f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+, XPU and CPU devices" def _is_cuda_device(device) -> bool: @@ -138,7 +138,7 @@ def fp8_matmul_unwrapped(x): x_shape = (16, 16) x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type) - y_fp8 = compiled_fp8_matmul(x) # noqa: F841 + y_fp8 = compiled_fp8_matmul(x) x_shape = (15, 16) x = torch.rand(*x_shape, device=device, dtype=dtype).to(e4m3_type) @@ -168,6 +168,9 @@ def fp8_cast(x): torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1) torch.testing.assert_close(y1_fp8, x, rtol=5e-1, atol=5e-1) + @skipIfXpu( + msg="Conversions between float8_e5m2 and float8_e4m3fn is not supported, torch-xpu-ops: 2888" + ) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_bad_cast(self, device): def fp8_cast(x, dtype): @@ -331,7 +334,7 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2 ) - @onlyCUDA + @onlyOn(["cuda", "xpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) @parametrize("shape", ("4,2048,4096",)) @@ -342,7 +345,7 @@ def test_layernorm_fp8_quant_benchmark( shape: str, keepdim: bool, ): - float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device="cuda") + float8_dtype = _fix_fp8_dtype_for_rocm(float8_dtype, device=GPU_TYPE) shape = [int(dim) for dim in shape.split(",")] batch_size, sequence_length, hidden_size = shape @@ -373,11 +376,11 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor") x_shape = (batch_size, sequence_length, hidden_size) - x = torch.rand(*x_shape, device="cuda", dtype=torch.half) - scale = torch.tensor(0.2, device="cuda", dtype=torch.float) + x = torch.rand(*x_shape, device=GPU_TYPE, dtype=torch.half) + scale = torch.tensor(0.2, device=GPU_TYPE, dtype=torch.float) - amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half) - amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half) + amax_buffer_compiled = torch.zeros((1), device=GPU_TYPE, dtype=torch.half) + amax_buffer = torch.zeros((1), device=GPU_TYPE, dtype=torch.half) _ = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled) compiled_latency = utils.do_bench_using_profiling( functools.partial(compiled_ln_fp8_quant, x, scale, amax_buffer_compiled) @@ -396,10 +399,8 @@ def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor): f"LN only Inductor: {ln_latency}ms." ) - @unittest.skipIf( - not SM90OrLater or torch.version.hip, "PDL requires NVIDIA SM 9.0+" - ) - @onlyOn(["cuda", "xpu"]) + @skipCUDAIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @onlyOn(["cuda", "xpu", "cpu"]) def test_scaled_mm_pdl_handles_none_bias(self, device): dtype_float8 = _fix_fp8_dtype_for_rocm(torch.float8_e4m3fn, device) M, K, N = 32, 64, 32 @@ -440,7 +441,7 @@ class TestFP8Lowering(TestCase): @parametrize( "persistent_matmul", [False, True] if has_triton_tma_device() else [False] ) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) def test_tensorwise_scaling( self, dtype: torch.dtype, @@ -465,8 +466,8 @@ def test_tensorwise_scaling( if has_bias: bias = torch.randn(N, device=device, dtype=torch.bfloat16) - # if "xpu" in device and use_fast_accum: - self.skipTest("XPU does not support use_fast_accum=True for now") + if "xpu" in device and use_fast_accum: + self.skipTest("XPU does not support use_fast_accum=True for now") # quantize weight (prior to inference) w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8) @@ -517,7 +518,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_eager, y_compiled, rtol=1e-2, atol=0.05) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) def test_scaled_mm_preserves_strides(self, device): """Test that scaled_mm preserves stride ordering through a custom pass.""" @@ -696,7 +697,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) @parametrize("shape", ("16,16,32", "16,32,32", "1024,1024,512")) @parametrize("has_bias", (False, True)) @parametrize("use_fast_accum", (False, True)) @@ -989,7 +990,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) @parametrize("M", (1, 3, 33, 257, 1024)) @parametrize("K", (16, 32, 1024)) @parametrize("N", (16, 2048)) @@ -1299,7 +1300,7 @@ def forward( ) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) @parametrize("M", (1, 3, 33, 257, 1024)) @parametrize("K", (16, 32, 1024)) @parametrize("N", (16, 2048)) @@ -1371,7 +1372,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, "Not supported on non B200") def test_mx_fp8_max_autotune(self, device): M, K, N = 128, 32, 128 @@ -1455,7 +1456,7 @@ def linear(x, w_t_fp8, w_inverse_scale, bias): ) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) def test_unacceptable_scale_dims_rowwise_scaling(self, device): dtype: torch.dtype = torch.bfloat16 dtype_float8 = torch.float8_e4m3fn diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index 1e1b0410c2ea7..6ec6c45c71795 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -15,7 +15,14 @@ PLATFORM_SUPPORTS_FUSED_ATTENTION, SM80OrLater, ) -from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu, TEST_WITH_ROCM +from torch.testing._internal.common_utils import ( + IS_ARM64, + IS_CPU_CAPABILITY_SVE256, + IS_LINUX, + skipIfXpu, + TEST_WITH_ROCM, + xfailIf, +) from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_CPU, @@ -85,9 +92,9 @@ def _check_common( x.requires_grad = training if not self.use_static_shapes: - torch._dynamo.mark_dynamic(args2[0], 0) - torch._dynamo.mark_dynamic(args2[1], 0) - torch._dynamo.mark_dynamic(args2[2], 0) + for i in range(min(3, len(args2))): + if isinstance(args2[i], torch.Tensor): + torch._dynamo.mark_dynamic(args2[i], 0) dropout_arg = [training] if has_dropout else [] torch.manual_seed(1234) @@ -182,11 +189,6 @@ def dot_prod_attention( ) def _test_insignificant_strides(self): - if self.device == "xpu": - self.skipTest( - "The operator 'aten::_scaled_dot_product_efficient_attention'" - " is not currently implemented for the XPU device. " - ) f32 = torch.float32 # repro taken from https://github.com/pytorch/pytorch/issues/124289 @@ -363,10 +365,30 @@ def sfdp_pattern_5_v2(query, key, value): ) return attn_weight @ value + def sfdp_pattern_5_v3(query, key, value): + # https://github.com/pytorch/pytorch/issues/174049. + attn_mask = torch.ones( + query.size(-2), key.size(-2), dtype=torch.bool, device=query.device + ).tril(diagonal=0) + attn_mask = attn_mask.masked_fill( + torch.logical_not(attn_mask), -float("inf") + ) + attn_weight = torch.softmax( + (query @ key.transpose(-2, -1) / (math.sqrt(query.size(-1)) + 0.1)) + + attn_mask, + dim=-1, + ) + return attn_weight @ value + self._check_common(sfdp_pattern_5_v1, contains=False) self._check_common(checkpoint_wrapper(sfdp_pattern_5_v1), contains=False) self._check_common(sfdp_pattern_5_v2, contains=False) self._check_common(checkpoint_wrapper(sfdp_pattern_5_v2), contains=False) + self._check_common(sfdp_pattern_5_v3, contains=False) + self._check_common( + checkpoint_wrapper(sfdp_pattern_5_v3), + contains=False, + ) def _test_sdpa_rewriter_6(self): def sfdp_pattern_6(query, key, value, training): @@ -383,10 +405,31 @@ def sfdp_pattern_6(query, key, value, training): attn_weight = torch.nn.functional.dropout(attn_weight, 0.5, training) return attn_weight @ value + def sfdp_pattern_6_v2(query, key, value, training): + attn_mask = torch.ones( + query.size(-2), key.size(-2), dtype=torch.bool, device=query.device + ).tril(diagonal=0) + attn_mask = attn_mask.masked_fill( + torch.logical_not(attn_mask), -float("inf") + ) + attn_weight = torch.softmax( + (query @ key.transpose(-2, -1) / (math.sqrt(query.size(-1)) + 0.1)) + + attn_mask, + dim=-1, + ) + attn_weight = torch.nn.functional.dropout(attn_weight, 0.5, training) + return attn_weight @ value + self._check_common(sfdp_pattern_6, contains=False, has_dropout=True) self._check_common( checkpoint_wrapper(sfdp_pattern_6), contains=False, has_dropout=True ) + self._check_common(sfdp_pattern_6_v2, contains=False, has_dropout=True) + self._check_common( + checkpoint_wrapper(sfdp_pattern_6_v2), + contains=False, + has_dropout=True, + ) def _test_sdpa_rewriter_7(self): def sfdp_pattern_7(query, key, value, training): @@ -401,6 +444,18 @@ def sfdp_pattern_7(query, key, value, training): attn_weight = attn_weight.to(torch.float16) return attn_weight @ v + def sfdp_pattern_7_v2(query, key, value, training): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + div = q @ k.transpose(-2, -1) / (math.sqrt(q.size(-1)) + 0.1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + # Set to False + attn_weight = torch.dropout(attn_weight, 0.00000000001, training) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + args = ( torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), @@ -414,6 +469,19 @@ def sfdp_pattern_7(query, key, value, training): override_check_equal=True, atol=2e-3, ) + args = ( + torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), + ) + self._check_common( + sfdp_pattern_7_v2, + args, + contains=False, + has_dropout=True, + override_check_equal=True, + atol=2e-3, + ) args = ( torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), @@ -428,6 +496,19 @@ def sfdp_pattern_7(query, key, value, training): override_check_equal=True, atol=2e-3, ) + args = ( + torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), + ) + self._check_common( + checkpoint_wrapper(sfdp_pattern_7_v2), + args, + contains=SM80OrLater, + has_dropout=True, + override_check_equal=True, + atol=2e-3, + ) def _test_sdpa_rewriter_8(self): def sfdp_pattern_8(query, key, value): @@ -440,12 +521,28 @@ def sfdp_pattern_8(query, key, value): attn_weight = attn_weight.to(torch.float16) return attn_weight @ v + def sfdp_pattern_8_v2(query, key, value): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + div = q @ k.transpose(-2, -1) / (math.sqrt(q.size(-1)) + 0.1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + args = ( torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), ) self._check_common(sfdp_pattern_8, args, atol=2e-3) + args = ( + torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), + ) + self._check_common(sfdp_pattern_8_v2, args, atol=2e-3, contains=False) args = ( torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), @@ -453,6 +550,17 @@ def sfdp_pattern_8(query, key, value): torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), ) self._check_common(checkpoint_wrapper(sfdp_pattern_8), args, atol=2e-3) + args = ( + torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), + ) + self._check_common( + checkpoint_wrapper(sfdp_pattern_8_v2), + args, + atol=2e-3, + contains=False, + ) def _test_sdpa_rewriter_9(self): def sfdp_pattern_9(query, key, value, training): @@ -468,6 +576,19 @@ def sfdp_pattern_9(query, key, value, training): attn_weight = attn_weight.to(torch.float16) return attn_weight @ v + def sfdp_pattern_9_v2(query, key, value, training): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + q = q / (math.sqrt(q.size(-1)) + 0.1) + div = q @ k.transpose(-2, -1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + # very low dropout to make test pass + attn_weight = torch.dropout(attn_weight, 0.00000000001, training) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + args = ( torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), @@ -481,6 +602,19 @@ def sfdp_pattern_9(query, key, value, training): override_check_equal=True, atol=2e-3, ) + args = ( + torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), + ) + self._check_common( + sfdp_pattern_9_v2, + args, + contains=SM80OrLater, + has_dropout=True, + override_check_equal=True, + atol=2e-3, + ) args = ( torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), @@ -494,6 +628,19 @@ def sfdp_pattern_9(query, key, value, training): override_check_equal=True, atol=2e-3, ) + args = ( + torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), + ) + self._check_common( + checkpoint_wrapper(sfdp_pattern_9_v2), + args, + contains=SM80OrLater, + has_dropout=True, + override_check_equal=True, + atol=2e-3, + ) def _test_sdpa_rewriter_10(self): def sfdp_pattern_10(query, key, value): @@ -507,12 +654,34 @@ def sfdp_pattern_10(query, key, value): attn_weight = attn_weight.to(torch.float16) return attn_weight @ v + def sfdp_pattern_10_v2(query, key, value): + q = query.permute(0, 2, 1, 3) + k = key.permute(0, 2, 1, 3) + v = value.permute(0, 2, 1, 3) + q = q / (math.sqrt(q.size(-1)) + 0.1) + div = q @ k.transpose(-2, -1) + div = div.to(torch.float32) + attn_weight = torch.softmax(div, dim=-1) + attn_weight = attn_weight.to(torch.float16) + return attn_weight @ v + args = ( torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), ) self._check_common(sfdp_pattern_10, args, atol=2e-3) + args = ( + torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=self.device, dtype=torch.half), + ) + self._check_common( + sfdp_pattern_10_v2, + args, + atol=2e-3, + contains=False, + ) args = ( torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), @@ -520,6 +689,17 @@ def sfdp_pattern_10(query, key, value): torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), ) self._check_common(checkpoint_wrapper(sfdp_pattern_10), args, atol=2e-3) + args = ( + torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), + torch.randn((2, 8, 4, 16), device=GPU_TYPE, dtype=torch.half), + ) + self._check_common( + checkpoint_wrapper(sfdp_pattern_10_v2), + args, + atol=2e-3, + contains=False, + ) def _test_pattern_fails_with_tensor_factor(self): # https://github.com/pytorch/pytorch/issues/99124 @@ -888,6 +1068,37 @@ def dot_prod_attention( value.permute([0, 2, 1, 3]), ) + def dot_prod_attention_v2( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + causal_mask: torch.Tensor, + ) -> torch.Tensor: + query = query.permute([0, 2, 1, 3]) + key = key.permute([0, 2, 1, 3]) + value = value.permute([0, 2, 1, 3]) + attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) + inv_scale = torch.full( + (), + math.sqrt(value.size(-1)) + 0.1, + dtype=query.dtype, + device=query.device, + ) + attn_weights = attn_weights.div(inv_scale) + causal_mask_value = torch.full( + (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device + ) + attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value) + return ( + ( + torch.nn.functional.dropout( + attn_weights.softmax(dim=-1), 0.0 + ).matmul(value) + ), + key.permute([0, 2, 1, 3]), + value.permute([0, 2, 1, 3]), + ) + tensor_shape = (4, 2, 16, 32) causal_mask = torch.ones(2, 2, dtype=torch.bool, device=self.device).tril( diagonal=0 @@ -906,6 +1117,15 @@ def dot_prod_attention( check_train=False, ) + self._check_common( + dot_prod_attention_v2, + args1=args, + atol=2e-3, + contains=False, + has_dropout=False, + check_train=False, + ) + # also check batch_size=1 because the graph is slightly different tensor_shape = (1, 2, 16, 32) args = [ @@ -921,6 +1141,14 @@ def dot_prod_attention( has_dropout=False, check_train=False, ) + self._check_common( + dot_prod_attention_v2, + args1=args, + contains=False, + atol=2e-3, + has_dropout=False, + check_train=False, + ) def _test_sdpa_rewriter_19(self): def dot_prod_attention( @@ -952,6 +1180,35 @@ def dot_prod_attention( inplace=False, ).matmul(value) + def dot_prod_attention_v2( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + causal_mask: torch.Tensor, + attn_mask: torch.Tensor, + training, + ) -> torch.Tensor: + attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) + inv_scale = torch.full( + (), + math.sqrt(value.size(-1)) + 0.1, + dtype=attn_weights.dtype, + device=attn_weights.device, + ) + attn_weights = attn_weights.div(inv_scale) + causal_mask_value = torch.full( + (), torch.finfo(query.dtype).min, dtype=query.dtype, device=query.device + ) + attn_weights = torch.where(causal_mask, attn_weights, causal_mask_value) + attn_weights = attn_weights + attn_mask + attn_weights = attn_weights.softmax(dim=-1).type(value.dtype) + return torch.nn.functional.dropout( + attn_weights, + p=0.4, + training=training, + inplace=False, + ).matmul(value) + tensor_shape = (4, 2, 16, 32) causal_mask = torch.ones(16, 16, dtype=torch.bool, device=self.device).tril( diagonal=0 @@ -971,6 +1228,13 @@ def dot_prod_attention( has_dropout=True, check_train=False, ) + self._check_common( + dot_prod_attention_v2, + args1=args, + contains=False, + has_dropout=True, + check_train=False, + ) def _test_sdpa_rewriter_20(self): def dot_prod_attention( @@ -1315,6 +1579,32 @@ def dot_prod_attention( check_train=True, ) + def _test_sdpa_rewriter_28(self): + def dot_prod_attention( + qkv: torch.Tensor, + training: bool, + ) -> torch.Tensor: + q, k, v = qkv.permute(1, 0, 2, 4, 3).unbind(0) + scores = torch.matmul(q, k.transpose(-2, -1)) + scores = scores.mul(0.2) + attn_weights = scores.softmax(dim=-1) + attn_weights = torch.nn.functional.dropout( + attn_weights, p=0.1, training=training + ) + return attn_weights.matmul(v) + + tensor_shape = (2, 3, 4, 16, 8) + args = [ + torch.randn(tensor_shape, dtype=torch.half, device=self.device), + ] + self._check_common( + dot_prod_attention, + args1=args, + contains=False, + has_dropout=True, + check_train=True, + ) + if HAS_XPU_AND_TRITON or (HAS_CUDA_AND_TRITON and PLATFORM_SUPPORTS_FUSED_ATTENTION): @@ -1400,6 +1690,9 @@ class SDPAPatternRewriterGpuTests(TestSDPAPatternRewriterTemplate): test_cache_sdpa_constraint_shared_kv_gpu = ( TestSDPAPatternRewriterTemplate._test_cache_sdpa_constraint_shared_kv ) + test_sdpa_rewriter_28_gpu = functools.partialmethod( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_28 + ) if HAS_XPU_AND_TRITON: test_sdpa_rewriter_25_gpu = functools.partialmethod( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_25 @@ -1461,7 +1754,10 @@ class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): TestSDPAPatternRewriterTemplate._test_pattern_fails_with_reuse ) test_sdpa_rewriter_2_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_2 - test_sdpa_rewriter_5_cpu = TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5 + # see https://github.com/pytorch/pytorch/issues/177244 + test_sdpa_rewriter_5_cpu = xfailIf(IS_ARM64 and IS_CPU_CAPABILITY_SVE256)( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_5 + ) test_pattern_fails_with_tensor_factor_cpu = ( TestSDPAPatternRewriterTemplate._test_pattern_fails_with_tensor_factor ) @@ -1480,7 +1776,8 @@ class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): test_sdpa_rewriter_13_cpu = functools.partialmethod( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_13, dtype=torch.float32 ) - test_sdpa_rewriter_14_cpu = functools.partialmethod( + # see https://github.com/pytorch/pytorch/issues/177244 + test_sdpa_rewriter_14_cpu = xfailIf(IS_ARM64 and IS_CPU_CAPABILITY_SVE256)( TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_14 ) test_sdpa_rewriter_15_cpu = functools.partialmethod( @@ -1519,6 +1816,9 @@ class SDPAPatternRewriterCpuTests(TestSDPAPatternRewriterTemplate): test_cache_sdpa_constraint_shared_kv_cpu = ( TestSDPAPatternRewriterTemplate._test_cache_sdpa_constraint_shared_kv ) + test_sdpa_rewriter_28_cpu = functools.partialmethod( + TestSDPAPatternRewriterTemplate._test_sdpa_rewriter_28 + ) class SDPAPatternRewriterCpuDynamicTests(SDPAPatternRewriterCpuTests): use_static_shapes = False diff --git a/test/inductor/test_fuzzer.py b/test/inductor/test_fuzzer.py index d13662c6f666e..fb0838fed912e 100644 --- a/test/inductor/test_fuzzer.py +++ b/test/inductor/test_fuzzer.py @@ -181,7 +181,7 @@ def myfn(): self.assertEqual(len(new_results), 1) self.assertEqual( set(key_1.keys()), - {j for i in new_results for j in i} # noqa: SIM118 + {j for i in new_results for j in i} - set(MODULE_DEFAULTS["torch._dynamo.config"].keys()), ) diff --git a/test/inductor/test_fxir_backend.py b/test/inductor/test_fxir_backend.py index 3f95800007354..6e6c022bbce88 100644 --- a/test/inductor/test_fxir_backend.py +++ b/test/inductor/test_fxir_backend.py @@ -1043,7 +1043,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): cond = torch.ops.higher_order.cond(arg0_1, true_graph_0, false_graph_0, (arg1_1, arg2_1)); arg0_1 = true_graph_0 = false_graph_0 = arg1_1 = arg2_1 = None buf1 = cond[0] buf2 = cond[1]; cond = None - return [buf1, buf2]""", # noqa: B950 + return [buf1, buf2]""", ) def test_dims_dynamic_outer_static_padded_inner(self): diff --git a/test/inductor/test_gpu_cpp_wrapper.py b/test/inductor/test_gpu_cpp_wrapper.py index 714a0e65d2ce1..029bf1b8f2894 100644 --- a/test/inductor/test_gpu_cpp_wrapper.py +++ b/test/inductor/test_gpu_cpp_wrapper.py @@ -1,13 +1,24 @@ # Owner(s): ["module: inductor"] import itertools +import os +import subprocess import sys +import tempfile import unittest from typing import NamedTuple import torch from torch._inductor import config +from torch._inductor.codegen.common import TritonScratchWorkspace +from torch._inductor.codegen.cpp_wrapper_gpu import DeferredTritonCallWrapper +from torch._inductor.codegen.cuda.device_op_overrides import CUDADeviceOpOverrides from torch._inductor.test_case import TestCase as InductorTestCase -from torch.testing._internal.common_utils import slowTest +from torch._inductor.utils import IndentedBuffer +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + slowTest, +) from torch.testing._internal.inductor_utils import GPU_TYPE, RUN_GPU @@ -59,6 +70,7 @@ def test_fn(): comp = torch.compile( options={ "cpp_wrapper": True, + "cpp_wrapper_build_separate": True, "aot_inductor.debug_intermediate_value_printer": "2", } )(test_fn) @@ -77,6 +89,268 @@ def test_fn(x, s): _, code = test_torchinductor.run_and_get_cpp_code(compiled, x, 3) self.assertIn("torch.tensor(arg, device='cpu')", code) + def test_cpp_scratch_scales_with_grid_size_for_tma(self): + if GPU_TYPE != "cuda" or torch.version.hip: + self.skipTest("CUDA-only codegen test") + + scratch_def, scratch_var = CUDADeviceOpOverrides().cpp_scratch( + 0, + TritonScratchWorkspace( + size=256, generate_dtype_str=lambda: "at::ScalarType::Byte" + ), + prefix="global_scratch", + ) + self.assertEqual(scratch_var, "global_scratch_scratch_0") + self.assertIn( + "static_cast(256) * grid_0 * grid_1 * grid_2", scratch_def[0] + ) + + def test_triton_wrapper_scales_scratch_with_num_ctas(self): + if GPU_TYPE != "cuda" or torch.version.hip: + self.skipTest("CUDA-only codegen test") + + class FakeWrapper: + device = "cuda" + + def __init__(self): + self.scratch_spaces = None + + def generate_args_decl( + self, + prefix, + call_args, + arg_types, + arg_signatures, + is_triton_kernel=True, + scratch_spaces=None, + ): + self.scratch_spaces = scratch_spaces + + return "" + + wrapper = FakeWrapper() + prefix = IndentedBuffer() + params = { + "triton_meta": {"signature": {"x": "*fp32"}, "constants": {}}, + "def_args": ["x"], + "call_args": ["x"], + "config": {"num_ctas": 8}, + "num_warps": 4, + "shared_mem": 0, + "global_scratch": 256, + } + + DeferredTritonCallWrapper( + wrapper_name="wrapper", + kernel_name="kernel", + kernel_name_to_body={}, + arg_types=[torch.float32], + ).generate_launch_kernel(prefix, wrapper, "kernel_var", params) + + self.assertEqual(wrapper.scratch_spaces, {"global_scratch": 256 * 8}) + + @parametrize("per_subkernel_blocks", [False, True]) + def test_lazy_compile_combo_kernel_default_config(self, per_subkernel_blocks): + """Lazy compile should use default_config from combo_grid_meta for XBLOCK.""" + if not RUN_GPU: + self.skipTest("GPU not available") + + from unittest.mock import patch + + from torch._inductor.codegen.triton_combo_kernel import ( + DEFAULT_COMBO_BLOCK_SIZE_1D, + ) + from torch._inductor.runtime import triton_lazy_compile as tlc + + captured = {} + original = tlc.run_triton_kernel_with_autotune + + def capture(pending_kernels, kernel_name, stream, args): + result = original(pending_kernels, kernel_name, stream, args) + if "triton_for_fused" in kernel_name: + captured[kernel_name] = result.xblock + return result + + with patch.object(tlc, "run_triton_kernel_with_autotune", side_effect=capture): + params = [torch.randn(1024, device=self.device) for _ in range(4)] + grads = [torch.randn_like(p) for p in params] + + @torch.compile( + options={ + "cpp_wrapper": True, + "triton.autotune_at_compile_time": False, + "combo_kernels": True, + "combo_kernel_per_subkernel_blocks": per_subkernel_blocks, + } + ) + def fn(params, grads): + torch._foreach_add_(params, grads, alpha=-0.1) + + fn(params, grads) + + self.assertTrue(len(captured) > 0, "No combo kernels were lazy-compiled") + for name, xblock in captured.items(): + # When per_subkernel_blocks=False, default_config has a single XBLOCK + # that must be picked up correctly (not the hardcoded fallback of 128). + # When per_subkernel_blocks=True, default_config uses per-subkernel + # XBLOCK_N keys instead, so result.xblock is not used for grid + # computation; just verify compilation succeeded. + if not per_subkernel_blocks: + self.assertEqual( + xblock, + DEFAULT_COMBO_BLOCK_SIZE_1D, + f"{name} got XBLOCK={xblock}, expected {DEFAULT_COMBO_BLOCK_SIZE_1D}", + ) + + def test_cudagraph_no_partition(self): + if not RUN_GPU: + self.skipTest("GPU not available") + + def test_fn(x, s): + return (x + s).sum() + + x = torch.randn(4, device=self.device) + s = 3 + expected = test_fn(x, s) + + comp = torch.compile( + options={ + "cpp_wrapper": True, + "triton.cudagraphs": True, + "graph_partition": False, + } + )(test_fn) + for i in range(3): + res = comp(x, s) + self.assertEqual(res, expected) + + def test_many_args_fold_expression_nesting(self): + if not RUN_GPU: + self.skipTest("GPU not available") + if GPU_TYPE == "xpu": + self.skipTest("ocloc backend compiler crashes with too many kernel args") + + num_params = 130 + params = [torch.randn(64, device=self.device) for _ in range(num_params)] + grads = [torch.randn_like(p) for p in params] + expected = [p.clone() + (-0.1) * g for p, g in zip(params, grads)] + + @torch.compile( + options={ + "cpp_wrapper": True, + "combo_kernels": True, + "combo_kernel_max_num_args": 1000, + } + ) + def fn(params, grads): + torch._foreach_add_(params, grads, alpha=-0.1) + + fn(params, grads) + + for p, e in zip(params, expected): + self.assertEqual(p, e) + + def test_cpp_wrapper_backward_lazy_compile(self): + """Test that options={"cpp_wrapper": True} works with backward pass. + + Backward graphs may be compiled lazily (after compile_fx returns). + The cpp_wrapper triton config (store_cubin, autotune_at_compile_time) + must still be applied. See https://github.com/pytorch/pytorch/issues/178845 + """ + if not RUN_GPU: + self.skipTest("GPU not available") + + def fn(x, output_grad): + layer_norm = torch.nn.LayerNorm(normalized_shape=4).to(self.device) + output = layer_norm(x) + output.backward(output_grad) + return output + + x = torch.randn(2, 3, 4, device=self.device) + output_grad = torch.randn(2, 3, 4, device=self.device) + + opt_fn = torch.compile(options={"cpp_wrapper": True})(fn) + result = opt_fn(x, output_grad) + self.assertEqual(result.shape, x.shape) + + +instantiate_parametrized_tests(TestGpuWrapper) + +# Helper script for test_lazy_compile_kernel_name_collision_across_modules. +# Run as a subprocess so dlopen truly re-runs .so static initializers. +_LAZY_COMPILE_COLLISION_SCRIPT = """\ +import torch +from torch.testing._internal.inductor_utils import GPU_TYPE + +from torch._inductor import config + +config.cpp_wrapper = True +config.triton.autotune_at_compile_time = False + +def fn(x, y, z, w): + a = x.sin() + torch._dynamo.graph_break() + b = (a * y).cos() + torch._dynamo.graph_break() + c = (b * z).sin() + torch._dynamo.graph_break() + d = (c * w).cos() + return d.sum() + +args = [torch.randn(32, device=GPU_TYPE, requires_grad=True) for _ in range(4)] +ref_args = [a.detach().clone().requires_grad_(True) for a in args] +ref = fn(*ref_args) +ref.backward() + +compiled_fn = torch.compile(fn) +res = compiled_fn(*args) +res.backward() + +assert torch.allclose(res.detach(), ref.detach()), f"Forward mismatch: {res} vs {ref}" +for i, (a, r) in enumerate(zip(args, ref_args)): + assert torch.allclose(a.grad, r.grad), f"Grad mismatch for arg {i}" +""" + + +class TestLazyCompileKernelCollision(InductorTestCase): + device = GPU_TYPE + + def test_lazy_compile_kernel_name_collision_across_modules(self): + """The collision manifests when a fresh process loads .so modules from + warm on-disk caches: AOTAutograd cache hits cause both forward and + backward .so to be loaded (static initializers register kernels in + _pending_kernels) before either executes. If two modules share a + kernel name, the global dict collision corrupts the mapping. + + This requires two process invocations because dlopen within a single + process reuses loaded libraries without re-running static initializers. + """ + if not RUN_GPU: + self.skipTest("GPU not available") + + with tempfile.TemporaryDirectory() as cache_dir: + env = { + **os.environ, + "TORCHINDUCTOR_CACHE_DIR": cache_dir, + "INDUCTOR_TEST_DISABLE_FRESH_CACHE": "1", + } + # First run: cold compile, populates on-disk caches. + r1 = subprocess.run( + [sys.executable, "-c", _LAZY_COMPILE_COLLISION_SCRIPT], + capture_output=True, + text=True, + env=env, + ) + self.assertEqual(r1.returncode, 0, f"Cold run failed:\n{r1.stderr[-2000:]}") + # Second run: warm caches trigger the collision without the fix. + r2 = subprocess.run( + [sys.executable, "-c", _LAZY_COMPILE_COLLISION_SCRIPT], + capture_output=True, + text=True, + env=env, + ) + self.assertEqual(r2.returncode, 0, f"Warm run failed:\n{r2.stderr[-2000:]}") + class DynamicShapesGpuWrapperGpuTests(InductorTestCase): device = GPU_TYPE @@ -108,15 +382,22 @@ def test_fn(): "test_mm_plus_mm2_dynamic_shapes": test_torchinductor.TestFailure( ("gpu_wrapper",), is_skip=True ), - # ATen ops: scaled_dot_product_efficient_attention not implemented on XPU. - "test_scaled_dot_product_efficient_attention_xpu": test_torchinductor.TestFailure( - ("gpu_wrapper",), is_skip=False - ), - "test_scaled_dot_product_efficient_attention_xpu_dynamic_shapes": test_torchinductor.TestFailure( - ("gpu_wrapper",), is_skip=False - ), } +# XPU: complex add decomposition can return NotImplemented in cpp_wrapper path, +# which currently surfaces as InductorError in test_add_complex4_xpu_gpu_wrapper. +# Keep this targeted skip to XPU only. +if device_type == "xpu": + test_failures_gpu_wrapper["test_add_complex4_xpu"] = test_torchinductor.TestFailure( + ("gpu_wrapper",), is_skip=True + ) + test_failures_gpu_wrapper["test_add_complex_xpu"] = test_torchinductor.TestFailure( + ("gpu_wrapper",), is_skip=True + ) + test_failures_gpu_wrapper["test_adding_tensor_offsets_xpu"] = ( + test_torchinductor.TestFailure(("gpu_wrapper",), is_skip=True) + ) + # Skip only on CUDA as wrapper dynamic shapes passes on ROCm. # Per https://github.com/pytorch/pytorch/pull/172780 if not torch.version.hip: @@ -189,11 +470,6 @@ class BaseTest(NamedTuple): tests: InductorTestCase = test_torchinductor.GPUTests() check_code: bool = True - # XPU Not implemented yet - XPU_BASE_TEST_SKIP = [ - "test_dynamic_shapes_persistent_reduction_mixed_x_dim", - ] - # Maintain two separate test lists for cuda and cpp for now for item in [ BaseTest("test_add_complex"), @@ -306,8 +582,6 @@ class BaseTest(NamedTuple): tests=test_select_algorithm.TestSelectAlgorithm(), ), ]: - if item.device == "xpu" and item.name in XPU_BASE_TEST_SKIP: - continue make_test_case(item.name, item.device, item.tests, check_code=item.check_code) test_torchinductor.copy_tests( diff --git a/test/inductor/test_gpu_select_algorithm.py b/test/inductor/test_gpu_select_algorithm.py index af101d1113325..1c43906fa49b9 100644 --- a/test/inductor/test_gpu_select_algorithm.py +++ b/test/inductor/test_gpu_select_algorithm.py @@ -52,7 +52,6 @@ def skip_cache(self, choices, name, key, benchmark, hint_override=None): for patcher in [ dynamo_config.patch(verbose=True), - dynamo_config.patch(inline_inbuilt_nn_modules=True), inductor_config.patch( debug=True, max_autotune=True, diff --git a/test/inductor/test_halide.py b/test/inductor/test_halide.py index 884d15869f69d..aa24a63d891df 100644 --- a/test/inductor/test_halide.py +++ b/test/inductor/test_halide.py @@ -7,7 +7,7 @@ import unittest import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch._dynamo.testing import make_test_cls_with_patches from torch._inductor import config from torch._inductor.codecache import HalideCodeCache @@ -130,7 +130,7 @@ def generate(g): out_ptr0.set_estimates([hl.Range(1024, 1024)]) __name__ == '__main__' and hl.main() - """, # noqa: S101 + """, ), ) a = torch.randn(1024) @@ -204,7 +204,7 @@ def generate(g): tmp1.compute_inline() __name__ == '__main__' and hl.main() - """, # noqa: S101 + """, ), ) a = torch.randn(1024) diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 3b9dd05c71b47..e998754908645 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -6,6 +6,7 @@ import sympy import torch +from torch._dynamo.source import ConstantSource from torch._inductor.codegen.cpp import cexpr from torch._inductor.codegen.triton import texpr from torch._inductor.codegen.wrapper import pexpr @@ -22,6 +23,7 @@ from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU from torch.utils._sympy.functions import ( FloorDiv, + Identity, Mod, ModularIndexing, PythonMod, @@ -115,11 +117,63 @@ def test_indexing_simplification(self): self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr) var_ranges = {i2: 784} expr = ModularIndexing(ModularIndexing(i2, 1, 28), 7, 4) - expected = FloorDiv(ModularIndexing(i2, 1, 28), 7) + # FloorDiv(ModularIndexing(b, d1, m), d2) simplifies to + # ModularIndexing(b, d1*d2, m//d2) when d2 | m + expected = ModularIndexing(i2, 7, 4) self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expected) expr = ModularIndexing(ModularIndexing(i2, 1, 28) + 1, 7, 4) self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr) + def test_floordiv_modularindexing_simplification(self): + sizevars = SizeVarAllocator() + i0 = sympy.Symbol("i0", integer=True, nonneg=True) + + # FloorDiv(ModularIndexing(b, d1, m), d2) -> ModularIndexing(b, d1*d2, m//d2) + # when d2 divides m + self.assertEqual( + sizevars.simplify_with_ranges( + FloorDiv(ModularIndexing(i0, 1, 8192), 128), {} + ), + ModularIndexing(i0, 128, 64), + ) + self.assertEqual( + sizevars.simplify_with_ranges(FloorDiv(ModularIndexing(i0, 2, 120), 6), {}), + ModularIndexing(i0, 12, 20), + ) + # Does NOT simplify when d2 does not divide m + expr = FloorDiv(ModularIndexing(i0, 1, 28), 5) + self.assertEqual(sizevars.simplify_with_ranges(expr, {}), expr) + + # FloorDiv(base, divisor) -> 0 when 0 <= base < divisor + self.assertEqual( + sizevars.simplify_with_ranges(FloorDiv(ModularIndexing(i0, 1, 10), 10), {}), + sympy.S.Zero, + ) + + def test_remove_zero_terms_generalized(self): + sizevars = SizeVarAllocator() + i0 = sympy.Symbol("i0", integer=True, nonneg=True) + i1 = sympy.Symbol("i1", integer=True, nonneg=True) + + # FloorDiv(v + 128*i1, 8192): gcd(128*i1, 8192) = 128 + # Old rule fails (128 != 8192), new rule: v < 128 => drop v + self.assertEqual( + sizevars.simplify_with_ranges(FloorDiv(i0 + 128 * i1, 8192), {i0: 128}), + FloorDiv(128 * i1, 8192), + ) + # v range equals gcd exactly — still safe since v < gcd (strict) + # v=127 max, 127 < 128 + self.assertEqual( + sizevars.simplify_with_ranges(FloorDiv(i0 + 6 * i1, 18), {i0: 6}), + FloorDiv(6 * i1, 18), + ) + # v range exceeds gcd — cannot simplify + expr = FloorDiv(i0 + 128 * i1, 8192) + self.assertEqual( + sizevars.simplify_with_ranges(expr, {i0: 129}), + expr, + ) + def test_indexing_join(self): sizevars = SizeVarAllocator() i0 = sympy.Symbol("i0", integer=True) @@ -367,6 +421,29 @@ def test_print_round(self): texpr(expr), """libdevice.llrint((1/2)*x).to(tl.int64)""" ) + def test_print_nan(self): + expr = sympy.nan + self.assertExpectedInline(pexpr(expr), """math.nan""") + self.assertExpectedInline( + cexpr(expr), """std::numeric_limits::quiet_NaN()""" + ) + + def test_print_infinity(self): + expr = sympy.oo + self.assertExpectedInline(pexpr(expr), """math.inf""") + self.assertExpectedInline( + cexpr(expr), + """std::numeric_limits::infinity()""", + ) + + def test_print_negative_infinity(self): + expr = -sympy.oo + self.assertExpectedInline(pexpr(expr), """-math.inf""") + self.assertExpectedInline( + cexpr(expr), + """-std::numeric_limits::infinity()""", + ) + def test_print_integer(self): expr = sympy.S((-1) << 63) self.assertExpectedInline(cexpr(expr), f"""(-1{LONG_SUFFIX} << 63)""") @@ -481,7 +558,7 @@ def test_print_Min_Max(self): expr = f(x, 2 * x, 3 * x) self.assertEqual( texpr(expr), - f"((x) * ((x) {cmp}= (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x))))) + (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) * ((((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) {cmp} (x)))", # noqa: B950 line too long + f"((x) * ((x) {cmp}= (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x))))) + (((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) * ((((2*x) * ((2*x) {cmp}= (3*x)) + (3*x) * ((3*x) {cmp} (2*x)))) {cmp} (x)))", ) self.assertEqual( cexpr(expr), @@ -541,6 +618,138 @@ def test_guard_or_false_le_unbacked_symint_with_check(self): self.assertFalse(sizevars.guard_or_false(sympy.Lt(u0, 10 * u0))) +class TestPrecomputedSizeHinting(InductorTestCase): + """Tests for optimization_hint and guarding_hint_or_throw with PRECOMPUTED_SIZE symbols.""" + + def test_optimization_hint_with_precomputed_size(self): + """Test that optimization_hint correctly resolves PRECOMPUTED_SIZE symbols. + + When a complex expression is replaced with a precomputed size symbol (ps0, ps1, etc.), + optimization_hint must use inv_precomputed_replacements to resolve the symbol + back to its original expression before computing the hint. + """ + sizevars = SizeVarAllocator() + + # Create a backed symbol with a concrete hint value + s0 = sizevars.shape_env.create_symbol(168, source=ConstantSource("s0")) + sizevars.shape_env.var_to_val[s0] = sympy.Integer(168) + sizevars.backed_var_to_val[s0] = sympy.Integer(168) + + # Create a complex expression that would be precomputed + complex_expr = s0 * 8 # Should evaluate to 168 * 8 = 1344 + + # Register the expression as a precomputed size (simulating what Inductor does) + ps_symbol = sizevars.lookup_precomputed_size(complex_expr) + + # Verify the precomputed symbol was created + self.assertIn(complex_expr, sizevars.precomputed_replacements) + self.assertIn(ps_symbol, sizevars.inv_precomputed_replacements) + + # Test optimization_hint resolves the ps symbol correctly + hint = sizevars.optimization_hint(ps_symbol) + expected = 168 * 8 # The concrete value of s0 * 8 + self.assertEqual(hint, expected) + + def test_guarding_hint_or_throw_with_precomputed_size(self): + """Test that guarding_hint_or_throw correctly resolves PRECOMPUTED_SIZE symbols.""" + sizevars = SizeVarAllocator() + + # Create a backed symbol with a concrete hint value + s0 = sizevars.shape_env.create_symbol(42, source=ConstantSource("s0")) + sizevars.shape_env.var_to_val[s0] = sympy.Integer(42) + sizevars.backed_var_to_val[s0] = sympy.Integer(42) + + # Create a complex expression + complex_expr = s0 * 2 + + # Register as precomputed size + ps_symbol = sizevars.lookup_precomputed_size(complex_expr) + + # Test guarding_hint_or_throw resolves correctly + hint = sizevars.guarding_hint_or_throw(ps_symbol) + expected = 42 * 2 + self.assertEqual(hint, expected) + + def test_optimization_hint_with_expression_containing_precomputed_size(self): + """Test optimization_hint with an expression that contains a PRECOMPUTED_SIZE symbol.""" + sizevars = SizeVarAllocator() + + # Create a backed symbol + s0 = sizevars.shape_env.create_symbol(10, source=ConstantSource("s0")) + sizevars.shape_env.var_to_val[s0] = sympy.Integer(10) + sizevars.backed_var_to_val[s0] = sympy.Integer(10) + + # Register s0 * 5 as precomputed (ps0 = 50) + ps_symbol = sizevars.lookup_precomputed_size(s0 * 5) + + # Create an expression using the precomputed symbol: ps0 + 3 + expr = ps_symbol + 3 + + # optimization_hint should resolve ps0 -> s0*5 -> 50, then add 3 -> 53 + hint = sizevars.optimization_hint(expr) + self.assertEqual(hint, 53) + + +class TestOptimizationHintZeroDivision(InductorTestCase): + """Test that optimization_hint handles ZeroDivisionError from ModularIndexing with zero-valued unbacked symbols.""" + + def test_modular_indexing_with_zero_divisor(self): + sizevars = SizeVarAllocator() + u0 = sizevars.shape_env.create_unbacked_symint().node.expr + u1 = sizevars.shape_env.create_unbacked_symint().node.expr + + # u0 + 1 ensures base != 0 after substitution; u1 is the divisor. + # With fallback=0: u0->0, u1->0, so (0+1) // 0 -> ZeroDivisionError. + # optimization_hint catches ZeroDivisionError and returns fallback. + expr = ModularIndexing(u0 + 1, u1, 4) + hint = sizevars.optimization_hint(expr, fallback=0) + self.assertEqual(hint, 0) + + def test_floor_div_with_zero_divisor(self): + """optimization_hint should not crash when FloorDiv has an unbacked + symbol as divisor that gets substituted with 0.""" + sizevars = SizeVarAllocator() + u0 = sizevars.shape_env.create_unbacked_symint().node.expr + u1 = sizevars.shape_env.create_unbacked_symint().node.expr + + # With fallback=0: u0->0, u1->0, FloorDiv(0+1, 0) -> ZeroDivisionError. + # optimization_hint catches ZeroDivisionError and returns fallback. + expr = FloorDiv(u0 + 1, u1) + hint = sizevars.optimization_hint(expr, fallback=0) + self.assertEqual(hint, 0) + + def test_modular_indexing_zero_divisor_nonzero_fallback(self): + """When fallback is nonzero, the hint should still not crash.""" + sizevars = SizeVarAllocator() + u0 = sizevars.shape_env.create_unbacked_symint().node.expr + u1 = sizevars.shape_env.create_unbacked_symint().node.expr + + # With fallback=8192: u0->8192, u1->8192 + # (8192+1) // 8192 = 1, 1 % 4 = 1 + expr = ModularIndexing(u0 + 1, u1, 4) + hint = sizevars.optimization_hint(expr, fallback=8192) + self.assertEqual(hint, 1) + + +class TestOptimizationHintIdentityExpansion(InductorTestCase): + """Test that optimization_hint expands Identity wrappers after _sub_unbacked_exprs.""" + + def test_identity_wrapped_expr_resolves_to_int(self): + """An expression containing Identity-wrapped constants and an unbacked + symbol should resolve to a concrete int after substitution.""" + sizevars = SizeVarAllocator() + u0 = sizevars.shape_env.create_unbacked_symint().node.expr + + # Mirrors the real bug: -u0 * (-Identity(1) + Identity(0)) + # simplifies to -u0 * (0 - 1) = u0. + # Without expand(identity=True) after _sub_unbacked_exprs, + # subs({u0: fallback}) leaves -Identity(1) + Identity(0) unexpanded, + # causing RuntimeError("Failed to realize expression to int"). + expr = -u0 * (-Identity(sympy.Integer(1)) + Identity(sympy.Integer(0))) + hint = sizevars.optimization_hint(expr, fallback=42) + self.assertEqual(hint, 42) + + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py index d89dbb5b3d2ae..299532a9cee75 100644 --- a/test/inductor/test_inductor_freezing.py +++ b/test/inductor/test_inductor_freezing.py @@ -17,11 +17,7 @@ from torch._inductor.utils import override_lowering, run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_cuda import SM80OrLater, tf32_on_and_off -from torch.testing._internal.common_utils import ( - IS_FBCODE, - skipIfXpu, - TEST_WITH_SLOW_GRADCHECK, -) +from torch.testing._internal.common_utils import IS_FBCODE, TEST_WITH_SLOW_GRADCHECK # Make the helper files in test/ importable @@ -349,37 +345,15 @@ def foo(mod, inp): mod2.b1 = torch.nn.Parameter(torch.rand([15], device=self.device)) mod2.b2 = torch.nn.Parameter(torch.rand([20], device=self.device)) - # not fused - count = 3 if hasattr(mod2, "t3") else 2 - + # fused: weights share same dim 0 (in_features), different dim 1 is OK with torch.no_grad(): out_eager = mod2(inp) out, code = run_and_get_code(foo, mod2, inp) FileCheck().check_not(kernel_invoke).check_count( - mm_invoke, count=count, exactly=True + mm_invoke, count=1, exactly=True ).run(code[0]) self.assertEqual(out_eager, out) - # With inlining of inbuilt nn modules, Dynamo traces the innards of inbuilt - # module and does not modify the eager module. - @torch._dynamo.config.patch(inline_inbuilt_nn_modules=False) - def test_error_on_eager(self): - mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device) - - x = torch.rand(3, 3, 32, 32).to(self.device) - - @torch.compile() - def foo(mod, x): - return mod(x) - - with torch.no_grad(): - foo(mod, x) - - with self.assertRaisesRegex( - RuntimeError, "Trying to run Pytorch Eager Module after Dynamo Freezing" - ): - mod(x) - def test_static_indices_cudagraph(self): if self.device != "cuda": return @@ -788,7 +762,6 @@ def foo(mod, inp): mod_eager = mod(x) self.assertEqual(foo(mod, x), mod_eager) - @skipIfXpu @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") @unittest.skipIf( TEST_WITH_SLOW_GRADCHECK, diff --git a/test/inductor/test_inductor_utils.py b/test/inductor/test_inductor_utils.py index 12468a09103b9..2871a579fe577 100644 --- a/test/inductor/test_inductor_utils.py +++ b/test/inductor/test_inductor_utils.py @@ -11,13 +11,19 @@ log = logging.getLogger(__name__) +device_type = ( + acc.type + if (acc := torch.accelerator.current_accelerator(check_available=True)) + else "cpu" +) + class TestBench(TestCase): @classmethod def setUpClass(cls): super().setUpClass() - x = torch.rand(1024, 10).cuda().half() - w = torch.rand(512, 10).cuda().half() + x = torch.rand(1024, 10).to(device_type).half() + w = torch.rand(512, 10).to(device_type).half() cls._bench_fn = functools.partial(torch.nn.functional.linear, x, w) def test_benchmarker(self): @@ -32,4 +38,4 @@ def test_do_bench_using_profiling(self): if __name__ == "__main__": - run_tests("cuda") + run_tests(device_type) diff --git a/test/inductor/test_inplace_padding.py b/test/inductor/test_inplace_padding.py index d6aea40a8741f..c80671a1c4b9d 100644 --- a/test/inductor/test_inplace_padding.py +++ b/test/inductor/test_inplace_padding.py @@ -9,11 +9,11 @@ from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck -from torch.testing._internal.common_utils import serialTest +from torch.testing._internal.common_utils import serialTest, skipIfXpu from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_GPU, - requires_cuda_with_enough_memory, + requires_gpu_with_enough_memory, ) @@ -112,6 +112,7 @@ def f(x): self.assertEqual(num_inplace_padding(), 1) @inductor_config.patch(cpp_wrapper=True) + @inductor_config.patch("triton.autotune_at_compile_time", True) def test_pad_non_zero_cpp_wrapper(self): def f(x): x = x + 1 @@ -210,7 +211,7 @@ def f(x, y): self.assertEqual(num_inplace_padding(), 0) - @requires_cuda_with_enough_memory(2e10) + @requires_gpu_with_enough_memory(2e10) @inductor_config.patch(force_shape_pad=True) @serialTest() def test_linear_and_cel(self): @@ -255,9 +256,10 @@ def f(x, y): # Enable Max-Autotune to repro this test failure: # https://github.com/pytorch/pytorch/pull/140249#issuecomment-2556079406 - @requires_cuda_with_enough_memory(2e10) + @requires_gpu_with_enough_memory(2e10) @inductor_config.patch(max_autotune=True) @serialTest() + @skipIfXpu(msg="AssertionError: torch-xpu-ops: #2997") def test_linear_and_cel_max_autotune(self): self.test_linear_and_cel() diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py index 1f6ec150bdcc7..f2435966d5dfb 100644 --- a/test/inductor/test_kernel_benchmark.py +++ b/test/inductor/test_kernel_benchmark.py @@ -8,7 +8,7 @@ from unittest.mock import patch import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch._dynamo.testing import rand_strided from torch._inductor import config from torch._inductor.codecache import PyCodeCache diff --git a/test/inductor/test_lookup_table.py b/test/inductor/test_lookup_table.py index 8081b15b85f08..316d0f3393a2a 100644 --- a/test/inductor/test_lookup_table.py +++ b/test/inductor/test_lookup_table.py @@ -906,6 +906,10 @@ def validate_choices(choices): @fresh_cache() def test_valid_lookup_table_entry(self, operation): """Test when there's a valid entry for the operation""" + if operation == "addmm" and torch.version.hip: + self.skipTest( + "skipping on ROCm since https://github.com/pytorch/pytorch/issues/179955 didn't skip as expected" + ) k = 256 if operation == "mm_plus_mm" else 64 tensors = self.create_tensors(operation, k=k) @@ -920,9 +924,17 @@ def test_valid_lookup_table_entry(self, operation): config = self.create_basic_config(template_id) self.setup_lookup_table(operation, tensors, [config]) - add_preprocessing_fn( - partial(verify_choice_names, pattern="triton_", expected_count=1) - ) + + # TODO (paulzhan): Update LookupTableChoices to return empty + # (not fallback) when key matches + if operation == "addmm": + add_preprocessing_fn( + partial(verify_choice_names, pattern="triton_|addmm", expected_count=2) + ) + else: + add_preprocessing_fn( + partial(verify_choice_names, pattern="triton_", expected_count=1) + ) self.run_model(operation, tensors) @unittest.skipIf(not has_triton_tma_device(), "Need TMA support") @@ -936,13 +948,22 @@ def test_tma_lookup_table_entry(self, operation): ) self.setup_lookup_table(operation, tensors, [config]) - add_preprocessing_fn( - partial( - verify_choice_names, - pattern="triton_mm_persistent_tma_", - expected_count=1, + if operation == "addmm": + add_preprocessing_fn( + partial( + verify_choice_names, + pattern="triton_mm_persistent_tma_|addmm", + expected_count=2, + ) + ) + else: + add_preprocessing_fn( + partial( + verify_choice_names, + pattern="triton_mm_persistent_tma_", + expected_count=1, + ) ) - ) self.run_model( operation, tensors, {"triton.enable_persistent_tma_matmul": True} ) @@ -982,17 +1003,23 @@ def test_bias_addmm_lookup_table_entry(self): config = self.create_basic_config(torch._inductor.kernel.mm.aten_bias_addmm.uid) self.setup_lookup_table("addmm", tensors, [config]) - add_preprocessing_fn( - partial(verify_choice_names, pattern="bias_addmm", expected_count=1) - ) + # NOTE: This test passes bias_unexpanded (1D) to the model but sets up + # lookup key with expanded_bias (2D). The shapes differ so lookup will miss. + # We skip choice count verification here - just verify the model runs. - # Run with original unexpanded bias + # Run with expanded bias (stride[0] == 0) so the inductor sees + # bias_addmm-eligible inputs and the lookup key matches. + # Limit backends to ATEN so only the lookup table entry is selected. with inductor_config.patch( - {"max_autotune_gemm": True, "triton.autotune_cublasLt": True} + { + "max_autotune_gemm": True, + "triton.autotune_cublasLt": True, + "max_autotune_gemm_backends": "ATEN", + } ): model = UnifiedModel("addmm") compiled_model = torch.compile(model.to(self.device), mode="max-autotune") - compiled_model(bias_unexpanded, tensors[1], tensors[2]) + compiled_model(expanded_bias, tensors[1], tensors[2]) @unittest.skipIf(not has_triton_tma_device(), "Need TMA support") @fresh_cache() diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index 5e3ce5abb0020..5ed9b0640413b 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -23,7 +23,7 @@ from torch._inductor.utils import is_big_gpu, run_and_get_code, sympy_index_symbol from torch._inductor.virtualized import ops, V from torch.testing import FileCheck -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -416,7 +416,194 @@ def f(x): self.do_acc_test(f, x) self.assertEqual(1, metrics.generated_kernel_count) + def test_reshape_reindexing_for_reduction(self): + """ + RMS norm pattern where reshape(-1, head_dim) changes the loop + decomposition from [M, N] to [M*num_heads, head_dim]. Without + reindexing, the pointwise and reduction have different MemoryDep + indexing and can't fuse. With reindexing, the pointwise's + iteration is re-factored to match the reduction's, enabling fusion + into a single kernel. + """ + + def f(x): + head_dim = 128 + M, N = x.shape + x_reshaped = x.reshape(-1, head_dim) + x_f32 = x_reshaped.float() + variance = x_f32.pow(2).mean(dim=-1, keepdim=True) + x_normed = x_f32 * torch.rsqrt(variance + 1e-5) + return x_normed.reshape(M, N).to(x.dtype) + + if DO_PERF_TEST: + M = 1024 + else: + M = 16 + # Non-contiguous input (simulating a slice from qkv projection) + qkv = torch.randn(M, 10240, dtype=torch.bfloat16) + x = qkv[:, :8192] + + ref = f(x) + actual = torch.compile(f)(x) + torch.testing.assert_close(actual, ref, atol=1e-2, rtol=1e-2) + self.assertEqual(1, metrics.generated_kernel_count) + + if DO_PERF_TEST: + from triton.testing import do_bench + + optf = torch.compile(f) + print(f"ms={do_bench(lambda: optf(x))}") + + def test_reshape_reindexing_transposed_input(self): + """ + Same RMS norm pattern but with a transposed input. The reshape + sees transposed strides, so the reduction's memory access + pattern differs from the contiguous case. Reindexing should + still enable fusion. + """ + + def f(x): + head_dim = 128 + M, N = x.shape + x_reshaped = x.reshape(-1, head_dim) + x_f32 = x_reshaped.float() + variance = x_f32.pow(2).mean(dim=-1, keepdim=True) + x_normed = x_f32 * torch.rsqrt(variance + 1e-5) + return x_normed.reshape(M, N).to(x.dtype) + + M = 16 + # Transposed input: shape [M, 8192] but stride (1, M) + x = torch.randn(8192, M, dtype=torch.bfloat16).T + + self.do_acc_test(f, x) + self.assertEqual(1, metrics.generated_kernel_count) + + @inductor_config.patch("loop_ordering_after_fusion", False) + def test_reshape_reindexing_without_loop_ordering(self): + """ + Reindexing should enable fusion even when loop ordering is + disabled. Same RMS norm pattern as test_reshape_reindexing_for_reduction. + """ + + def f(x): + head_dim = 128 + M, N = x.shape + x_reshaped = x.reshape(-1, head_dim) + x_f32 = x_reshaped.float() + variance = x_f32.pow(2).mean(dim=-1, keepdim=True) + x_normed = x_f32 * torch.rsqrt(variance + 1e-5) + return x_normed.reshape(M, N).to(x.dtype) + + M = 16 + qkv = torch.randn(M, 10240, dtype=torch.bfloat16) + x = qkv[:, :8192] + + ref = f(x) + actual = torch.compile(f)(x) + torch.testing.assert_close(actual, ref, atol=1e-2, rtol=1e-2) + self.assertEqual(1, metrics.generated_kernel_count) + + def test_reindex_unfusable_write_read_dep(self): + """ + Grouped quantization where a custom op consumes both the + quantized tensor and the scale, forcing the reduction (amax) + and scale epilogue to fuse into a FusedSchedulerNode while + the quantize pointwise remains separate. + + The FusedSchedulerNode's dep on the input has 3 vars (from the + reduction + scale bodies), while the pointwise has 2 vars. + This num_vars mismatch causes _try_reorder_loops_for_candidates + to return a score based on the shared read (input), short- + circuiting reindexing. The write-read dep prioritization + detects the unfusable dep and returns -1, letting reindexing + fire. + """ + HIDDEN, GROUP_SIZE = 7168, 128 + + @torch.library.custom_op("test::opaque_gemm", mutates_args=()) + def opaque_gemm(x_q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + return torch.zeros( + x_q.shape[0], 1024, device=x_q.device, dtype=torch.bfloat16 + ) + + @opaque_gemm.register_fake + def _(x_q, scale): + return torch.zeros( + x_q.shape[0], 1024, device=x_q.device, dtype=torch.bfloat16 + ) + + def f(x): + grouped = x.reshape(-1, HIDDEN // GROUP_SIZE, GROUP_SIZE).float() + absmax = grouped.abs().amax(dim=-1, keepdim=True) + scale = (absmax / FP8_MAX).clamp(min=1e-6) + x_q = ( + (grouped / scale) + .clamp(-FP8_MAX, FP8_MAX) + .to(torch.float16) + .reshape(x.shape) + ) + scale = scale.squeeze(-1) + return torch.ops.test.opaque_gemm(x_q, scale) + + FP8_MAX = 448.0 # arbitrary clamp range + x = torch.randn(8, HIDDEN, dtype=torch.bfloat16) + self.do_acc_test(f, x) + self.assertEqual(1, metrics.generated_kernel_count) + + def test_reindex_rollback_on_no_improvement(self): + """ + When reindexing is attempted but doesn't improve the fusion + score, the node state should be rolled back. Here a reduction + and pointwise both read from x but at different slices (offsets). + They share the buffer and have the same iteration numel, so + reindexing is attempted, but the offset means deps still don't + match after reindexing. The rollback restores the original + node state so the pointwise isn't left with a wrong iteration + domain. + """ + M, N = 16, 128 + + def f(x): + r = x[:, :N].sum(dim=-1) + p = x[:, N:] * 2 + return r, p + + x = torch.randn(M, N * 2, device=GPU_TYPE) + self.do_acc_test(f, x) + + def test_reshape_reindexing_fused_pointwise(self): + """ + Redecomposition where the pointwise side is a FusedSchedulerNode. + realize() forces ops to materialize as separate nodes, so + add and clamp become two SchedulerNodes that fuse into a + FusedSchedulerNode before the reindexing fuses them with + the reduction. + """ + + def f(x, bias): + head_dim = 128 + M, N = x.shape + y = realize(x + bias) + z = realize(y.clamp(-1, 1)) + x_reshaped = z.reshape(-1, head_dim) + x_f32 = x_reshaped.float() + variance = x_f32.pow(2).mean(dim=-1, keepdim=True) + x_normed = x_f32 * torch.rsqrt(variance + 1e-5) + return x_normed.reshape(M, N).to(x.dtype) + + M = 16 + qkv = torch.randn(M, 10240, dtype=torch.bfloat16) + x = qkv[:, :8192] + bias = torch.randn(8192, dtype=torch.bfloat16) + + ref = f(x, bias) + actual = torch.compile(f)(x, bias) + torch.testing.assert_close(actual, ref, atol=1e-2, rtol=1e-2) + self.assertEqual(1, metrics.generated_kernel_count) + torch._dynamo.reset() + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+") + @unittest.skipIf(not SM90OrLater, "sm89 errors out on this test") def test_fp8_cast_and_t(self): """ This test repros the not able to fuses issue in @@ -585,11 +772,19 @@ def f(x, y): out, code = run_and_get_code(f, x, y) - # well when benchmark_kernel flag is on, we have one more .run - # call in the benchmarking code. - FileCheck().check("def call(").check_count( - ".run(", 1 + int(inductor_config.benchmark_kernel), exactly=True - ).run(code[0]) + FileCheck().check("def call(").run(code[0]) + # Prologue fused with mm: 1 kernel. Unfused: 2 kernels (expand+add + mm). + # With benchmark_kernel, add 1 for the benchmarking code path. + base_expected = 1 + int(inductor_config.benchmark_kernel) + run_count = code[0].count(".run(") + self.assertGreaterEqual( + run_count, base_expected, "Expected at least one kernel launch" + ) + self.assertLessEqual( + run_count, + base_expected + 1, + "Prologue fusion produces 1 kernel; unfused produces 2", + ) @inductor_config.patch( { @@ -664,6 +859,69 @@ def f(x): self.do_acc_test(f, x) self.assertEqual(0, metrics.num_loop_reordering) + def test_qknorm_rope_fusion(self): + """ + When qknorm (RMS norm) is followed by RoPE which uses cat, the cat + inputs read from the same buffers. Pointwise cat should be used so + everything fuses into a single kernel. + """ + B, H, S, D = 4, 8, 128, 64 + + def rms_norm(x, weight): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + 1e-6) + return x * weight + + def apply_rope(x, freqs_cos, freqs_sin): + half = x.shape[-1] // 2 + x1 = x[..., :half] + x2 = x[..., half:] + out1 = x1 * freqs_cos - x2 * freqs_sin + out2 = x2 * freqs_cos + x1 * freqs_sin + return torch.cat([out1, out2], dim=-1) + + def f(q, norm_weight, freqs_cos, freqs_sin): + q = rms_norm(q, norm_weight) + return apply_rope(q, freqs_cos, freqs_sin) + + q = torch.randn(B, H, S, D) + norm_weight = torch.randn(D) + freqs_cos = torch.randn(1, 1, S, D // 2) + freqs_sin = torch.randn(1, 1, S, D // 2) + + self.do_acc_test(f, q, norm_weight, freqs_cos, freqs_sin) + self.assertEqual(1, metrics.generated_kernel_count) + + def test_qknorm_interleaved_rope_fusion(self): + """ + Interleaved RoPE (stack + flatten) should also fuse with qknorm. + """ + B, H, S, D = 4, 8, 128, 64 + + def rms_norm(x, weight): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + 1e-6) + return x * weight + + def apply_interleaved_rope(x, cos, sin): + pairs = x.reshape(*x.shape[:-1], -1, 2) + a, b = pairs[..., 0], pairs[..., 1] + out_real = a * cos - b * sin + out_imag = a * sin + b * cos + return torch.stack([out_real, out_imag], dim=-1).flatten(-2) + + def f(q, norm_weight, cos, sin): + q = rms_norm(q, norm_weight) + return apply_interleaved_rope(q, cos, sin) + + q = torch.randn(B, H, S, D) + norm_weight = torch.randn(D) + cos = torch.randn(1, 1, S, D // 2) + sin = torch.randn(1, 1, S, D // 2) + + self.do_acc_test(f, q, norm_weight, cos, sin) + self.assertEqual(1, metrics.generated_kernel_count) + @inductor_config.patch( { @@ -1329,6 +1587,70 @@ def embedding_1d(idx, w): self.assertEqual(out, expected) +class TestSplitIterationRanges(MockSchedulerTest): + """Unit tests for SIMDKernel._split_iteration_ranges.""" + + def test_exact_match(self): + """Groups exactly match lengths — no splitting needed.""" + from torch._inductor.codegen.simd import SIMDKernel + + new_ranges, getters = SIMDKernel._split_iteration_ranges( + [sympy.Integer(4), sympy.Integer(8)], + [[sympy.Integer(4)], [sympy.Integer(8)]], + ) + self.assertEqual(len(new_ranges), 2) + self.assertEqual(len(new_ranges[0]), 1) + self.assertEqual(len(new_ranges[1]), 1) + + def test_two_way_split(self): + """A single large dimension splits across two groups.""" + from torch._inductor.codegen.simd import SIMDKernel + + new_ranges, getters = SIMDKernel._split_iteration_ranges( + [sympy.Integer(4), sympy.Integer(8)], + [[sympy.Integer(32)], []], + ) + # 32 should split into 4 * 8 across the two groups + self.assertEqual(len(new_ranges), 2) + + def test_groups_exhausted_raises_cant_split(self): + """When all groups are consumed but sizes remain, CantSplit is raised.""" + from torch._inductor.codegen.simd import CantSplit, SIMDKernel + + # groups=[1, 2, 2] can only absorb 2 sizes of 2 (consuming groups 1 and 2), + # leaving the third size=2 with no group to map to. + with self.assertRaises(CantSplit): + SIMDKernel._split_iteration_ranges( + [sympy.Integer(1), sympy.Integer(2), sympy.Integer(2)], + [[], [sympy.Integer(2), sympy.Integer(2), sympy.Integer(2)]], + ) + + def test_single_group_multiple_sizes(self): + """Multiple sizes fitting within a single group.""" + from torch._inductor.codegen.simd import SIMDKernel + + # groups=[8], lengths=[[2, 2, 2], []] — all 3 sizes fit in group 0 + new_ranges, getters = SIMDKernel._split_iteration_ranges( + [sympy.Integer(8)], + [[sympy.Integer(2), sympy.Integer(2), sympy.Integer(2)], []], + ) + self.assertEqual(len(new_ranges), 1) + self.assertEqual(len(new_ranges[0]), 3) + + def test_size_one_skipped(self): + """Dimensions of size 1 produce a zero-constant getter.""" + from torch._inductor.codegen.simd import SIMDKernel + + new_ranges, getters = SIMDKernel._split_iteration_ranges( + [sympy.Integer(4)], + [[sympy.Integer(1), sympy.Integer(4)], []], + ) + # Size-1 dim should not consume any range + self.assertEqual(len(getters[0]), 2) + # The first getter should return 0 for any input + self.assertEqual(getters[0][0]([sympy.Integer(99)]), sympy.Integer(0)) + + class TestIndexInversion(TestCase): @classmethod def setUpClass(cls): diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 85836ccb47d86..e1bcae986a83d 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -15,6 +15,8 @@ from unittest import mock from unittest.mock import patch +import sympy + import torch import torch._inductor.async_compile from torch import multiprocessing as mp, nn @@ -38,7 +40,7 @@ from torch._inductor.graph import GraphLowering from torch._inductor.ir import Buffer, ChoiceCaller, FixedLayout, FlexibleLayout from torch._inductor.kernel.mm_plus_mm import aten_mm_plus_mm -from torch._inductor.runtime.triton_heuristics import CachingAutotuner +from torch._inductor.runtime.triton_heuristics import CachingAutotuner, pointwise from torch._inductor.scheduler import Scheduler from torch._inductor.select_algorithm import ( add_feedback_saver, @@ -64,7 +66,8 @@ XPUMMTemplateConfigHeuristic, XPUPersistentTMATemplateConfigHeuristic, ) -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater +from torch.testing._internal.common_device_type import largeTensorTest from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, IS_WINDOWS, @@ -94,7 +97,7 @@ from torch._inductor.virtualized import V from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck -from torch.testing._internal.common_utils import MI300_ARCH, runOnRocmArch, skipIfXpu +from torch.testing._internal.common_utils import skipIfXpu from torch.testing._internal.inductor_utils import ( get_func_call, get_kernel_launch, @@ -176,6 +179,42 @@ def mm_plus_mm(a, b, c, d): ): torch.compile(mm_plus_mm, dynamic=dynamic)(a, b, c, d) + def test_max_autotune_includes_max_autotune_pointwise_configs(self): + """ + Verify that `max_autotune` includes all pointwise configs from + `max_autotune_pointwise` for 1D, 2D, and 3D pointwise kernels. + """ + triton_meta = {"device": object()} + inductor_meta_common = {"autotune_pointwise": False} + + for size_hints in ( + {"x": 2048}, + {"x": 128, "y": 256}, + {"x": 64, "y": 64, "z": 64}, + ): + with self.subTest(size_hints=size_hints): + max_autotune_configs = pointwise( + size_hints, + triton_meta=triton_meta, + inductor_meta={**inductor_meta_common, "max_autotune": True}, + return_configs=True, + ) + max_autotune_pointwise_configs = pointwise( + size_hints, + triton_meta=triton_meta, + inductor_meta={ + **inductor_meta_common, + "max_autotune_pointwise": True, + }, + return_configs=True, + ) + + self.assertEqual( + len(max_autotune_configs), + len(max_autotune_pointwise_configs), + "max_autotune should include all pointwise configs from max_autotune_pointwise", + ) + @unittest.skipIf( not has_triton_tma_device(), "Need device-side TMA support in Triton" ) @@ -253,6 +292,183 @@ def mm(a, b): torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2) + def test_use_triton_tma_template_rejects_descriptor_shapes_exceeding_int32(self): + from torch._inductor.utils import use_triton_tma_template + + int32_max = torch.iinfo(torch.int32).max + mat1 = mock.Mock() + mat1.get_size.return_value = [128, int32_max + 1] + mat2 = mock.Mock() + mat2.get_size.return_value = [int32_max + 1, 128] + output_layout = mock.Mock() + output_layout.size = [128, 128] + + with ( + config.patch({"triton.enable_persistent_tma_matmul": True}), + mock.patch("torch._inductor.utils.can_use_tma", return_value=True), + ): + self.assertFalse( + use_triton_tma_template(mat1, mat2, output_layout=output_layout) + ) + + mat1.get_size.return_value = [128, int32_max] + mat2.get_size.return_value = [int32_max, 128] + + with ( + config.patch({"triton.enable_persistent_tma_matmul": True}), + mock.patch("torch._inductor.utils.can_use_tma", return_value=True), + ): + self.assertTrue( + use_triton_tma_template(mat1, mat2, output_layout=output_layout) + ) + + def test_descriptor_shape_fits_in_int32_uses_expected_guarding(self): + from torch._inductor.utils import _descriptor_shape_fits_in_int32 + + gm = make_fx(lambda: torch.zeros(2, 3))() + graph = GraphLowering(gm) + size0 = sympy.Symbol("s0", integer=True, nonnegative=True) + size1 = sympy.Symbol("s1", integer=True, nonnegative=True) + condition = sympy.And( + sympy.Le(size0, torch.iinfo(torch.int32).max), + sympy.Le(size1, torch.iinfo(torch.int32).max), + ) + + with V.set_graph_handler(graph): + with ( + mock.patch.object( + V.graph.sizevars, "statically_known_true", return_value=True + ) as statically_known_true, + mock.patch.object( + V.graph.sizevars, "guard_or_false", return_value=True + ) as guard_or_false, + ): + self.assertTrue( + _descriptor_shape_fits_in_int32( + [128, size0, size1], add_guards=False + ) + ) + statically_known_true.assert_called_once_with(condition) + guard_or_false.assert_not_called() + + with ( + mock.patch.object( + V.graph.sizevars, "statically_known_true", return_value=True + ) as statically_known_true, + mock.patch.object( + V.graph.sizevars, "guard_or_false", return_value=True + ) as guard_or_false, + ): + self.assertTrue( + _descriptor_shape_fits_in_int32( + [128, size0, size1], add_guards=True + ) + ) + guard_or_false.assert_called_once_with(condition) + statically_known_true.assert_not_called() + + @unittest.skipIf(not torch.version.hip, "ROCM only") + @parametrize("a_transposed", (False, True)) + @parametrize("b_transposed", (False, True)) + @parametrize("dynamic", (False, True)) + def test_max_autotune_regular_mm_persistent( + self, + a_transposed: bool, + b_transposed: bool, + dynamic: bool, + ): + def mm(a, b): + a = a.repeat(8, 8) + b = b.repeat(8, 8) + + if a_transposed: + a = a.T + if b_transposed: + b = b.T + + return torch.mm(a, b) + + M, N, K = 21, 31, 11 + a = ( + torch.randn(*((K, M) if a_transposed else (M, K))) + .to(torch.float16) + .to(GPU_TYPE) + ) + b = ( + torch.randn(*((N, K) if b_transposed else (K, N))) + .to(torch.float16) + .to(GPU_TYPE) + ) + + with config.patch( + { + "max_autotune": True, + "triton.enable_persistent_tma_matmul": "1", + "triton.native_matmul": False, + "test_configs.autotune_choice_name_regex": "mm_persistent", + } + ): + c_actual, code = run_and_get_code(torch.compile(mm, dynamic=dynamic), a, b) + c_expected = mm(a, b) + + # Verify that we are using the non-TMA persistent implementation + FileCheck().check("triton_tem_fused_mm").check("NUM_SMS").run(code[0]) + + torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2) + + @unittest.skipIf(not torch.version.hip, "ROCM only") + @parametrize("a_transposed", (False, True)) + @parametrize("b_transposed", (False, True)) + @parametrize("dynamic", (False, True)) + def test_max_autotune_regular_addmm_persistent( + self, + a_transposed: bool, + b_transposed: bool, + dynamic: bool, + ): + def addmm(x, a, b): + x = x.repeat(8) + a = a.repeat(8, 8) + b = b.repeat(8, 8) + + if a_transposed: + a = a.T + if b_transposed: + b = b.T + + return torch.addmm(x, a, b) + + M, N, K = 21, 31, 11 + a = ( + torch.randn(*((K, M) if a_transposed else (M, K))) + .to(torch.float16) + .to(GPU_TYPE) + ) + b = ( + torch.randn(*((N, K) if b_transposed else (K, N))) + .to(torch.float16) + .to(GPU_TYPE) + ) + x = torch.randn(N).to(torch.float16).to(GPU_TYPE) + + with config.patch( + { + "max_autotune": True, + "triton.enable_persistent_tma_matmul": "1", + "triton.native_matmul": False, + "test_configs.autotune_choice_name_regex": "mm_persistent", + } + ): + c_actual, code = run_and_get_code( + torch.compile(addmm, dynamic=dynamic), x, a, b + ) + c_expected = addmm(x, a, b) + + # Verify that we are using the non-TMA persistent implementation + FileCheck().check("triton_tem_fused_addmm").check("NUM_SMS").run(code[0]) + + torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2) + @unittest.skipIf( not has_triton_tma_device(), "Need device-side TMA support in Triton" ) @@ -328,6 +544,75 @@ def mock_get_tma_workspace_arg(*args, **kwargs): mm_tma_heuristic.mm_configs = original_tma_configs mm_heuristic.mm_configs = original_mm_configs + @unittest.skipIf( + not has_triton_tma_device(), "Need device-side TMA support in Triton" + ) + def test_workspace_size_bytes_accounts_for_dtype(self): + """workspace_size passed to benchmark request must be in bytes, not elements.""" + import sympy + + from torch._inductor.codegen.common import WorkspaceZeroMode + from torch._inductor.utils import get_dtype_size + + count = 1024 + dtype = torch.float32 + expected_bytes = count * get_dtype_size(dtype) + + fake_ws = WorkspaceArg( + count=sympy.Integer(count), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + device=torch.device(GPU_TYPE), + outer_name="test_ws", + dtype=dtype, + ) + + captured_sizes = [] + orig_init = TritonBenchmarkRequest.__init__ + + def spy_init(self, *args, **kwargs): + captured_sizes.append(kwargs.get("workspace_size")) + orig_init(self, *args, **kwargs) + + M, K, N = 64, 64, 64 + a = torch.randn(M, K, device=GPU_TYPE, dtype=torch.bfloat16) + b = torch.randn(K, N, device=GPU_TYPE, dtype=torch.bfloat16) + + mm_tma_heuristic = CUDAPersistentTMATemplateConfigHeuristic() + mm_heuristic = CUDAMMTemplateConfigHeuristic() + original_tma_configs = mm_tma_heuristic.mm_configs + original_mm_configs = mm_heuristic.mm_configs + + try: + mm_heuristic.mm_configs = [] + mm_tma_heuristic.mm_configs = [GemmConfig(128, 128, 64, 4, 8, group_m=8)] + + with ( + config.patch( + { + "max_autotune_gemm": True, + "max_autotune_gemm_backends": "TRITON", + "triton.enable_persistent_tma_matmul": True, + } + ), + fresh_cache(), + patch( + "torch._inductor.template_heuristics.triton.get_tma_workspace_arg", + return_value=fake_ws, + ), + patch.object(TritonBenchmarkRequest, "__init__", spy_init), + ): + torch._dynamo.reset() + torch.compile(torch.mm, mode="max-autotune-no-cudagraphs")(a, b) + + finally: + mm_tma_heuristic.mm_configs = original_tma_configs + mm_heuristic.mm_configs = original_mm_configs + + ws_sizes = [s for s in captured_sizes if s is not None] + self.assertTrue(len(ws_sizes) > 0, "No workspace benchmark requests created") + for size in ws_sizes: + self.assertEqual(size, expected_bytes) + @unittest.skipIf( not has_triton_tma_device(), "Need device-side TMA support in Triton" ) @@ -476,7 +761,7 @@ def mm(a, b): a = a.repeat(8, 8) b = b.repeat(8, 8) - torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.maybe_mark_dynamic(a, 0) with config.patch( { @@ -491,6 +776,41 @@ def mm(a, b): torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2) + @unittest.skipIf( + not has_triton_tma_device(), "Need device-side TMA support in Triton" + ) + def test_persistent_tma_epilogue_fusion_store_cache(self): + # Regression test: when epilogue fusion runs with TMA store, the + # store_cache must be updated so that a subsequent epilogue load from + # the same buffer hits the cache. Otherwise remove_kernel_local_buffers + # strips the buffer pointer from the kernel signature, causing a + # NameError at Triton compile time. + def f(a, b): + a = a.repeat(8, 8) + b = b.repeat(8, 8) + mm = torch.mm(a, b) + return mm.relu() + + M, N, K = 21, 31, 11 + a = torch.randn(M, K, dtype=torch.float16, device=GPU_TYPE) + b = torch.randn(K, N, dtype=torch.float16, device=GPU_TYPE) + + with config.patch( + { + "max_autotune": True, + "epilogue_fusion": True, + "triton.enable_persistent_tma_matmul": "1", + "triton.native_matmul": False, + "triton.enable_template_tma_store": True, + "triton.disallow_failing_autotune_kernels_TESTING_ONLY": True, + "test_configs.autotune_choice_name_regex": "mm_persistent_tma", + } + ): + actual = torch.compile(f)(a, b) + expected = f(a, b) + + torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2) + @parametrize("dynamic", (False, True)) def test_max_autotune_regular_mm_zero_size_input(self, dynamic: bool): """ @@ -507,6 +827,34 @@ def mm(a, b): with config.patch({"max_autotune": True}): torch.compile(mm, dynamic=dynamic)(a, b) + @fresh_cache() + def test_addmm_1d_bias_no_reinterpret_tensor(self): + """ + Verify that aten addmm with 1D bias does not wrap bias in reinterpret_tensor. + This ensures cublasLt uses its optimized bias epilogue (requires dim==1). + """ + + def addmm(x, a, b): + return torch.addmm(x, a, b) + + x = torch.randn(100).to(GPU_TYPE) + a = torch.randn(100, 10).to(GPU_TYPE) + b = torch.randn(10, 100).to(GPU_TYPE) + + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "ATEN", + } + ): + Y_compiled, code = run_and_get_code(torch.compile(addmm), x, a, b) + Y = addmm(x, a, b) + torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2) + + # Verify addmm is called without reinterpret_tensor on bias + FileCheck().check("addmm").run(code[0]) + self.assertNotIn("addmm(reinterpret_tensor", code[0]) + @unittest.skipIf( not has_triton_tma_device(), "Need device-side TMA support in Triton" ) @@ -639,7 +987,7 @@ def addmm(x, a, b): a = a.repeat(8, 8) b = b.repeat(8, 8) - torch._dynamo.mark_dynamic(a, 0) + torch._dynamo.maybe_mark_dynamic(a, 0) with config.patch( { @@ -1201,6 +1549,39 @@ def f(x, y): act = f(x, y) torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2) + def test_broadcast_batch_bmm(self): + # Batch size > 1 with stride[0]=0 to exercise the broadcast batch + # guard that skips the Triton bmm template for stride-0 inputs. + x = rand_strided((4, 32, 64), (0, 64, 1), dtype=torch.bfloat16, device=GPU_TYPE) + y = rand_strided( + (4, 64, 16), (1024, 16, 1), dtype=torch.bfloat16, device=GPU_TYPE + ) + + @torch.compile(mode="max-autotune") + def f(x, y): + return torch.bmm(x, y) + + ref = torch.bmm(x, y) + act = f(x, y) + torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2) + + def test_broadcast_batch_baddbmm(self): + # Batch size > 1 with stride[0]=0 to exercise the broadcast batch + # guard that skips the Triton bmm template for stride-0 inputs. + x = rand_strided((4, 32, 64), (0, 64, 1), dtype=torch.bfloat16, device=GPU_TYPE) + y = rand_strided( + (4, 64, 16), (1024, 16, 1), dtype=torch.bfloat16, device=GPU_TYPE + ) + inp = torch.randn(4, 32, 16, dtype=torch.bfloat16, device=GPU_TYPE) + + @torch.compile(mode="max-autotune") + def f(inp, x, y): + return torch.baddbmm(inp, x, y) + + ref = torch.baddbmm(inp, x, y) + act = f(inp, x, y) + torch.testing.assert_close(act, ref, atol=2e-2, rtol=1e-2) + @unittest.skipIf( config.triton.native_matmul, "native matmul and Triton template both have accuracy fail (2.2%)", @@ -1370,7 +1751,7 @@ def check_divisors(code): ).run(code[0]) else: FileCheck().check("extern_kernels.bmm_dtype").check_regex( - "triton_.*_fused_0.run" + "triton_.*_fused_.*.run" ).check("decompose_k").run(code[0]) check_divisors(code) torch.testing.assert_close(out, a @ b, atol=atol, rtol=rtol) @@ -1384,7 +1765,7 @@ def check_divisors(code): ).run(code[0]) else: FileCheck().check("extern_kernels.bmm_dtype").check_regex( - "triton_.*_fused_mm_0.run" + "triton_.*_fused_.*.run" ).check("decompose_k").run(code[0]) check_divisors(code) torch.testing.assert_close( @@ -1521,7 +1902,7 @@ def f(a, b): out.backward() FileCheck().check("extern_kernels.bmm_dtype").check_regex( - "triton_.*_fused_0.run" + "triton_.*_fused_.*.run" ).check("decompose_k").check_regex(r"s[0-9]+ = s[0-9]+").check_regex( r"256\*s[0-9]+" ).check_regex("s[0-9]+ = 8").run( @@ -1857,20 +2238,32 @@ def f(a, b): a = torch.randn(2, 3, 4, device=GPU_TYPE, dtype=torch.float16) b = torch.randn(2, 4, 5, device=GPU_TYPE, dtype=torch.float16) + expected = torch.bmm(a.float(), b.float()) + with config.patch( + max_autotune=False, + max_autotune_gemm_backends="ATEN", + ): + compiled_f = torch.compile(f) + out, code = run_and_get_code(compiled_f, a, b) + FileCheck().check("extern_kernels.bmm_dtype").run(code[0]) + self.assertEqual(out, expected, atol=1e-3, rtol=1e-3) + + @unittest.skipIf(config.cpp_wrapper, "out_dtype override not supported for AOTI") + def test_triton_bmm_out_dtype(self): + def f(a, b, out_dtype=torch.float32): + return torch.bmm(a, b, out_dtype=out_dtype) + + a = torch.randn(2, 3, 4, device=GPU_TYPE, dtype=torch.float16) + b = torch.randn(2, 4, 5, device=GPU_TYPE, dtype=torch.float16) + expected = torch.bmm(a.float(), b.float()) with config.patch( max_autotune=True, max_autotune_gemm_backends="TRITON", ): compiled_f = torch.compile(f) - with self.assertRaisesRegex( - torch._inductor.exc.InductorError, - r"LoweringException: NoValidChoicesError: No choices to select", - ): - out, code = run_and_get_code(compiled_f, a, b) - - compiled_f = torch.compile(f) - out, code = run_and_get_code(compiled_f, a, b) - FileCheck().check("extern_kernels.bmm_dtype").run(code[0]) + out, code = run_and_get_code(compiled_f, a, b, out_dtype=torch.float32) + FileCheck().check("triton_tem_fused_bmm").run(code[0]) + self.assertEqual(out, expected, atol=1e-3, rtol=1e-3) def test_triton_template_generated_code_cache_key(self): generate_and_load_args = len( @@ -1978,7 +2371,7 @@ def func_test1(x, y, z, m): 'num_consumer_groups':0,'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','tma_store':False, 'transpose_discontiguous_tensor_descriptors_override':None, 'kwargs':{'EVEN_K':False,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32', - 'BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8,'ALLOW_TF32':True}, + 'OUT_DTYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32,'BLOCK_K':16,'GROUP_M':8,'ALLOW_TF32':True}, 'hint_override':None,'triton_meta':None}""" expected = expected.replace("cuda", GPU_TYPE) @@ -2019,7 +2412,7 @@ def func_test1(x, y, z, m): 'layout':"[[s77,s94],[s94,1],torch.float32,device(type='cuda',index=0),0]",'num_consumer_groups':0, 'num_buffers_warp_spec':0,'epilogue_fn_hash':'identity','tma_store':False, 'transpose_discontiguous_tensor_descriptors_override':None, - 'kwargs':{'EVEN_K':False,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32, + 'kwargs':{'EVEN_K':False,'USE_FAST_ACCUM':False,'ACC_TYPE':'tl.float32','OUT_DTYPE':'tl.float32','BLOCK_M':16,'BLOCK_N':32, 'BLOCK_K':16,'GROUP_M':8,'ALLOW_TF32':True},'hint_override':None,'triton_meta':None}""" expected = expected.replace("cuda", GPU_TYPE) self.assertExpectedInline( @@ -2184,7 +2577,6 @@ def misses(): self.assertEqual(misses(), 4) @fresh_cache() - @skipIfXpu @unittest.skipIf( config.cpp_wrapper, "decompose_k not supported for cpp_wrapper yet" ) @@ -2655,6 +3047,218 @@ def fn(a, b, c): FileCheck().check("triton_tem_fused").run(code[0]) + @fresh_cache() + @skipIfXpu + @unittest.skipIf(TEST_WITH_ROCM, "Test requires CUDA") + @unittest.skipIf( + not SM90OrLater, "Requires SM90+ (H100/B200) for sufficient GPU memory" + ) + @largeTensorTest("10 GB", device=GPU_TYPE) + def test_max_autotune_mm_large_input_tensor_int64_indexing(self): + """ + Test mm with input tensor exceeding 2^31 elements. + Regression test for https://github.com/pytorch/pytorch/issues/171389 + When input tensor storage exceeds 2^31 elements, tl.arange() must be + cast to INDEX_DTYPE (int64) to avoid integer overflow in pointer arithmetic. + """ + + def mm(a, b): + return torch.mm(a, b) + + M, K, N = 1280, 65536, 65536 + a = torch.randn(M, K, device=GPU_TYPE, dtype=torch.float16) + b = torch.randn(K, N, device=GPU_TYPE, dtype=torch.float16) + + self.assertTrue( + b.numel() > 2**31 - 1, + f"Test requires tensor with >2^31 elements, got {b.numel()}", + ) + + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + "test_configs.autotune_choice_name_regex": r"^triton_mm_", + } + ): + result = torch.compile(mm)(a, b) + + torch.testing.assert_close(result, torch.mm(a, b), rtol=1e-2, atol=1e-2) + + @fresh_cache() + @skipIfXpu + @unittest.skipIf(TEST_WITH_ROCM, "Test requires CUDA") + @largeTensorTest("6 GB", device=GPU_TYPE) + def test_max_autotune_mm_large_storage_offset_i64_indexing(self): + """ + Test mm with input having dynamic storage offset exceeding i32 range. + When a dynamic-shaped input is a slice with offset proportional to + the dynamic dim (e.g., offset=70000*s6), the ks parameter in the + triton template signature must use i64 to avoid overflow in pointer + arithmetic like `A = arg_A + 70000*ks0`. + """ + + def mm(x, w): + batch = x.shape[0] // 8 + a = x[7 * batch :] + return torch.mm(a, w) + + K, N = 10000, 32 + batch = 32768 + x = torch.randn(8 * batch, K, device=GPU_TYPE, dtype=torch.bfloat16) + w = torch.randn(K, N, device=GPU_TYPE, dtype=torch.bfloat16) + + expected_offset = 7 * batch * K + self.assertTrue( + expected_offset > 2**31 - 1, + f"Test requires offset > i32_max, got {expected_offset}", + ) + + torch._dynamo.mark_dynamic(x, 0) + + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + "test_configs.autotune_choice_name_regex": r"^triton_mm_", + } + ): + result = torch.compile(mm)(x, w) + + a = x[7 * batch :] + torch.testing.assert_close(result, torch.mm(a, w), rtol=1e-2, atol=1e-2) + + @fresh_cache() + @skipIfXpu + @unittest.skipIf(TEST_WITH_ROCM, "Test requires CUDA") + @unittest.skipIf( + not SM90OrLater, "Requires SM90+ (H100/B200) for sufficient GPU memory" + ) + @largeTensorTest("10 GB", device=GPU_TYPE) + def test_max_autotune_mm_large_output_tensor_int32_overflow(self): + """ + Test mm with output tensor exceeding 2^32 elements. + Regression test for https://github.com/pytorch/pytorch/issues/171389 + When M * N >= 2^32, the early exit check `if M * N == 0` can overflow + to 0 in int32 arithmetic, causing the kernel to return immediately + with all-zero output. + """ + + def mm(a, b): + return torch.mm(a, b) + + M, K, N = 65536, 32, 65536 + a = torch.randn(M, K, device=GPU_TYPE, dtype=torch.float16) + b = torch.randn(K, N, device=GPU_TYPE, dtype=torch.float16) + + self.assertTrue( + M * N >= 2**32, + f"Test requires M*N >= 2^32 for overflow, got {M * N}", + ) + + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + "test_configs.autotune_choice_name_regex": r"^triton_mm_", + } + ): + result = torch.compile(mm)(a, b) + + torch.testing.assert_close(result, torch.mm(a, b), rtol=1e-2, atol=1e-2) + + @fresh_cache() + @skipIfXpu + @unittest.skipIf(TEST_WITH_ROCM, "Test requires CUDA") + @unittest.skipIf( + not SM90OrLater, "Requires SM90+ (H100/B200) for sufficient GPU memory" + ) + @largeTensorTest("10 GB", device=GPU_TYPE) + def test_max_autotune_mm_persistent_tma_large_input_tensor_int64_indexing(self): + """ + Test persistent TMA mm with input tensor exceeding 2^31 elements. + Regression test for https://github.com/pytorch/pytorch/issues/171389. + Triton TMA descriptors require 32-bit block offsets even when the + surrounding kernel uses int64 indexing. + """ + + def mm(a, b): + return torch.mm(a, b) + + M, K, N = 1280, 65536, 65536 + a = torch.randn(M, K, device=GPU_TYPE, dtype=torch.float16) + b = torch.randn(K, N, device=GPU_TYPE, dtype=torch.float16) + + self.assertTrue( + b.numel() > 2**31 - 1, + f"Test requires tensor with >2^31 elements, got {b.numel()}", + ) + + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "TRITON", + "triton.enable_persistent_tma_matmul": "1", + "triton.native_matmul": False, + "test_configs.autotune_choice_name_regex": "mm_persistent_tma", + } + ): + result = torch.compile(mm)(a, b) + + torch.testing.assert_close(result, torch.mm(a, b), rtol=1e-2, atol=1e-2) + + @fresh_cache() + @config.patch( + { + "max_autotune": True, + "test_configs.max_mm_configs": 1, + } + ) + def test_deferred_layout_constraint_reinterpret_3d(self): + batch, m, k, n = 32, 40, 1053, 40 + batch_stride = 42176 + + a = torch.randn(batch, m, k, dtype=torch.bfloat16, device=GPU_TYPE) + b = torch.empty_strided( + size=(batch, k * n), + stride=(batch_stride, 1), + dtype=torch.bfloat16, + device=GPU_TYPE, + ) + b.copy_(torch.randn_like(b)) + c = torch.randn(batch, n, m, dtype=torch.bfloat16, device=GPU_TYPE) + idx0 = torch.tensor([0], device=GPU_TYPE) + idx1 = torch.tensor([0], device=GPU_TYPE) + value = torch.zeros(1, dtype=torch.bfloat16, device=GPU_TYPE) + + def fn(a, b, c, idx0, idx1, value): + # Mirror the 2D regression first: a simple pointwise op gives us a + # rank-2 FlexibleLayout. Then force that 2D value to realize before + # reinterpreting it as the rank-3 view from the original repro. + b_flex_2d = b + 0 + b_flex_2d = aten.index_put.default(b_flex_2d, (idx0, idx1), value, False) + b_view_3d = b_flex_2d.view(batch, k, n) + return ( + torch.bmm(a, b_view_3d).to(torch.float32), + b_flex_2d + 1.0, + torch.bmm(b_view_3d, c).to(torch.float32), + ) + + with ( + mock.patch( + "torch._inductor.autotune_process.run_autotune_in_subprocess", + mock_benchmark_choice_wrapper(aten_time=1.0, triton_time=0.1), + ), + mock.patch.object( + AlgorithmSelectorCache, + "benchmark_choice", + mock_benchmark_choice_wrapper(aten_time=1.0, triton_time=0.1), + ), + ): + compiled_fn = torch.compile(fn) + _, code = run_and_get_code(compiled_fn, a, b, c, idx0, idx1, value) + FileCheck().check("triton_tem_fused").run(code[0]) + @instantiate_parametrized_tests class TestTemplateConfigPruning(TestCase): @@ -3005,7 +3609,6 @@ def fn(a, b, c): self.assertEqual(counters["inductor"]["select_algorithm_precompile"], 0) @config.patch(autotune_local_cache=False, autotune_remote_cache=False) - @runOnRocmArch(MI300_ARCH) @unittest.skipIf(config.triton.native_matmul, "native matmul has counter 0") def test_precompilations(self): def fn(a, b, c): @@ -3014,7 +3617,9 @@ def fn(a, b, c): return (a @ b) @ c fn_c = torch.compile(mode="max-autotune-no-cudagraphs")(fn) - inputs = [torch.rand([256, 256], device=GPU_TYPE) for _ in range(3)] + # Scale down so float16 doesn't overflow: rand [0,1) -> (a@b)@c has elements in [0, 256^2). + # float16 max ~65504, so we keep values in a safe range (e.g. scale by 1/256). + inputs = [torch.rand([256, 256], device=GPU_TYPE) / 256.0 for _ in range(3)] torch.testing.assert_close(fn_c(*inputs), fn(*inputs), atol=1e-2, rtol=1e-2) @@ -4193,6 +4798,9 @@ def _setup_mm_heuristic(self, use_async_compile: bool): @unittest.skipIf(not has_triton_tma_device(), "Need TMA support in Triton") @skipIfXpu(msg="Bad tma config can be covered by XPU TMA") + @unittest.skipIf( + config.cpp_wrapper, "Skip static analysis codegen checks on cpp_wrapper" + ) @parametrize("use_async_compile", (True, False)) def test_template_bad_epilogue_fusion(self, use_async_compile: bool): def f(a, b): @@ -4287,8 +4895,7 @@ def precompile(self): "triton_poi_fused__to_copy" ).run(code[0]) - if not config.cpp_wrapper: - torch.testing.assert_close(out, f(a, b), atol=1e-2, rtol=1e-2) + torch.testing.assert_close(out, f(a, b), atol=1e-2, rtol=1e-2) finally: # Restore original configs tma_heuristic.mm_configs = original_tma_mm_configs @@ -4297,6 +4904,9 @@ def precompile(self): @unittest.skipIf( not HAS_CUDA_AND_TRITON, "Scheduler static analysis only tested on cuda" ) + @unittest.skipIf( + config.cpp_wrapper, "Skip static analysis codegen checks on cpp_wrapper" + ) @parametrize( "test_case", [ @@ -4356,9 +4966,7 @@ def test_template_epilogue_fusion_static_analysis( _, code = run_and_get_code(compiled_f, a, b) if expect_fusion: - FileCheck().check("triton_tem_fused__to_copy_add_mm_0.run").run( - code[0] - ) + FileCheck().check("triton_tem_fused__to_copy_add_mm_0").run(code[0]) elif triton_time < aten_time: FileCheck().check("triton_tem_fused_mm").check( "triton_poi_fused__to_copy" @@ -4371,6 +4979,9 @@ def test_template_epilogue_fusion_static_analysis( @unittest.skipIf( not HAS_CUDA_AND_TRITON, "Scheduler static analysis only tested on cuda" ) + @unittest.skipIf( + config.cpp_wrapper, "Skip static analysis codegen checks on cpp_wrapper" + ) @skipIfRocm(msg="Scheduler static analysis needs investigation on ROCm") @parametrize("fuse_epilogue", (True, False)) @parametrize("use_async_compile", (True, False)) @@ -4424,6 +5035,9 @@ def fn(x, w, bias, scale): @unittest.skipIf( not HAS_CUDA_AND_TRITON, "Scheduler static analysis only tested on cuda" ) + @unittest.skipIf( + config.cpp_wrapper, "Skip static analysis codegen checks on cpp_wrapper" + ) @skipIfRocm(msg="Scheduler static analysis needs investigation on ROCm") @parametrize( "test_case", @@ -4487,6 +5101,9 @@ def test_template_epilogue_fusion_occupancy_ratio( @unittest.skipIf( not HAS_CUDA_AND_TRITON, "Scheduler static analysis only tested on cuda" ) + @unittest.skipIf( + config.cpp_wrapper, "Skip static analysis codegen checks on cpp_wrapper" + ) @skipIfRocm(msg="Scheduler static analysis needs investigation on ROCm") @parametrize( "test_case", @@ -4606,22 +5223,11 @@ class TestMaxAutotuneAsyncPipelined(TestMaxAutotune, TestEpilogueFusionStaticAna """Tests for AsyncPipelinedAutotuning path.""" SKIP_TESTS = { - "test_max_autotune_decompose_k": "Subgraphs not supported with async pipelining", "test_inf_timing": "Logs not consistent with async pipelined autotuning", "test_non_contiguous_input_mm_plus_mm": "Flaky on trunk", "test_autotune_device_guard": "Flaky on trunk", "test_template_bad_epilogue_fusion": "Benchmarking path is different", - # Contiguous transform tests - SubgraphChoiceCaller not supported with async pipelining - "test_max_autotune_contiguous_transform_mm": "Subgraphs not supported with async pipelining", - "test_max_autotune_contiguous_transform_addmm": "Subgraphs not supported with async pipelining", - "test_max_autotune_contiguous_transform_non_contiguous_second_matrix": "Subgraphs not supported with async pipelining", - "test_max_autotune_contiguous_transform_with_epilogue": "Subgraphs not supported with async pipelining", - # XPU specific skips due to lack of multiprocess tensor reduction support (issue #170636) - "test_max_autotune_addmm_persistent_tma": "No XPU implementation for multiprocess tensor reduction", - "test_max_autotune_regular_mm_persistent_tma": "No XPU implementation for multiprocess tensor reduction", - "test_max_autotune_regular_mm_persistent_tma_strided": "No XPU implementation for multiprocess tensor reduction", - "test_max_autotune_addmm_tma_dynamic_outer_dim": "No XPU implementation for multiprocess tensor reduction", - "test_max_autotune_regular_mm_tma_dynamic_outer_dim": "No XPU implementation for multiprocess tensor reduction", + "test_persistent_tma_epilogue_fusion_store_cache": "Epilogue fusion disabled in async pipelining", } @classmethod @@ -4645,7 +5251,7 @@ def setUp(self): super().setUp() test_name = self._testMethodName for skip_test_name in self.SKIP_TESTS: - if skip_test_name in test_name or TEST_XPU: + if skip_test_name in test_name or TEST_XPU or config.cpp_wrapper: self.skipTest(self.SKIP_TESTS[skip_test_name]) def tearDown(self): diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index 35cc93c340bc1..eb362b6535a79 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -455,6 +455,104 @@ def replace_foreach(gm): code = run_and_get_triton_code(foo, inp, inp2) FileCheck().check("allocated=['buf0']").run(code) + @unittest.skipUnless(TRITON_AVAILABLE, "Triton is not available") + def test_torch_cond_ordering_consistency(self): + small_sz, large_sz = 256, 1024 + + class MultiCondModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("large_buffer", torch.zeros(large_sz)) + self.register_buffer("small_buffer1", torch.zeros(small_sz)) + self.register_buffer("small_buffer2", torch.zeros(small_sz)) + self.register_buffer("counter", torch.tensor(0, dtype=torch.long)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + condition = self.counter % 2 == 0 + + def true_fn_large(buf): + return buf.clone() * 2.0 + + def false_fn_large(buf): + return buf.clone() + + def true_fn_small(buf): + return buf.clone() * 2.0 + + def false_fn_small(buf): + return buf.clone() + + result_large = torch.cond( + condition, + lambda: true_fn_large(self.large_buffer), + lambda: false_fn_large(self.large_buffer), + ) + result_small1 = torch.cond( + condition, + lambda: true_fn_small(self.small_buffer1), + lambda: false_fn_small(self.small_buffer1), + ) + result_small2 = torch.cond( + condition, + lambda: true_fn_small(self.small_buffer2), + lambda: false_fn_small(self.small_buffer2), + ) + return ( + x + result_large.sum() + result_small1.sum() + result_small2.sum() + ) + + def extract_cond_order(code: str) -> list[tuple[str, int]]: + """ + Extract the order of torch.cond operations from generated code. + Returns list of (cond_name, buffer_size) tuples in execution order. + """ + import re + + cond_order = [] + # Look for patterns like "cond" or "cond_1" in the generated code + # along with their buffer sizes + lines = code.split("\n") + for i, line in enumerate(lines): + # Match true_graph buffer allocations which indicate cond execution + match = re.search(r"true_graph_(\d+)_buf0\s*=.*\((\d+),", line) + if match: + cond_idx = int(match.group(1)) + buf_size = int(match.group(2)) + cond_order.append((f"cond_{cond_idx}", buf_size)) + return cond_order + + model = MultiCondModel().to(GPU_TYPE) + x = torch.randn(10, device=GPU_TYPE) + + # Compile with base settings (no reordering) + torch._dynamo.reset() + with config.patch({"reorder_for_peak_memory": False}): + compiled_base = torch.compile(model) + code_base = run_and_get_triton_code(compiled_base, x) + + base_order = extract_cond_order(code_base) + + # Compile with reorder_for_peak_memory=True + torch._dynamo.reset() + with config.patch({"reorder_for_peak_memory": True}): + compiled_peak_mem = torch.compile(model) + code_peak_mem = run_and_get_triton_code(compiled_peak_mem, x) + + peak_mem_order = extract_cond_order(code_peak_mem) + + if base_order and peak_mem_order: + self.assertEqual( + base_order, + peak_mem_order, + msg=( + f"torch.cond operations were reordered by reorder_for_peak_memory!\n" + f"Base order: {base_order}\n" + f"Peak memory order: {peak_mem_order}\n" + f"This can cause NCCL hangs when torch.cond contains collective operations " + f"because different ranks may execute collectives in different orders." + ), + ) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index 17b863cc1bcb2..4130e2d65005d 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -13,7 +13,7 @@ ) if __name__ == "__main__": sys.exit(0) - raise unittest.SkipTest("requires sympy/functorch/filelock") # noqa: F821 + raise unittest.SkipTest("requires sympy/functorch/filelock") import torch from torch._C import FileCheck @@ -149,9 +149,9 @@ def forward(self, x, y): ).check_count("aoti_torch__alloc_from_pool(pool0", 1, exactly=True).run(code) FileCheck().check( - "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_int32, 0, int_array_1, int_array_1, &tmp_tensor_handle_0));" # noqa: B950 + "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_int32, 0, int_array_1, int_array_1, &tmp_tensor_handle_0));" ).check("RAIIAtenTensorHandle(tmp_tensor_handle_0);").check( - "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool0, 0, cached_torch_dtype_float32, 3, int_array_4, int_array_5, &tmp_tensor_handle_1));" # noqa: B950 + "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool(pool0, 0, cached_torch_dtype_float32, 3, int_array_4, int_array_5, &tmp_tensor_handle_1));" ).check("RAIIAtenTensorHandle(tmp_tensor_handle_1);").run(code) diff --git a/test/inductor/test_minifier_isolate.py b/test/inductor/test_minifier_isolate.py index 61cf6e3961133..f1862b65f9bce 100644 --- a/test/inductor/test_minifier_isolate.py +++ b/test/inductor/test_minifier_isolate.py @@ -8,7 +8,6 @@ IS_MACOS, skipIfRocm, skipIfWindows, - skipIfXpu, TEST_WITH_ASAN, ) from torch.testing._internal.inductor_utils import GPU_TYPE @@ -40,11 +39,13 @@ def test_after_aot_cpu_runtime_error(self): self._test_after_aot_runtime_error("cpu", "") @skipIfRocm - @skipIfXpu @requires_gpu @inductor_config.patch("triton.inject_relu_bug_TESTING_ONLY", "runtime_error") def test_after_aot_gpu_runtime_error(self): - self._test_after_aot_runtime_error(GPU_TYPE, "device-side assert") + expected_error = ( + "injected assert fail" if GPU_TYPE == "xpu" else "device-side assert" + ) + self._test_after_aot_runtime_error(GPU_TYPE, expected_error) if __name__ == "__main__": diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index 3e604949a8d41..2aed8dc9eff09 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -15,9 +15,8 @@ from torch.testing._internal.common_device_type import largeTensorTest from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, - isRocmArchAnyOf, - MI200_ARCH, parametrize, + skipIfXpu, ) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -59,6 +58,14 @@ def f(x): self.assertEqual(2, metrics.generated_kernel_count) +# Cooperative reductions disable split reductions, which are necessary for mix order +# reductions. +@inductor_config.patch( + { + "triton.cooperative_reductions": False, + "triton.force_cooperative_reductions": False, + } +) @instantiate_parametrized_tests class MixOrderReductionTest(TestBase): @parametrize( @@ -218,7 +225,13 @@ def f(x): self.check_numeric(f, (x,)) # We don't do mix order reduction for split redutions # with more than 2 layers - self.assertEqual(metrics.codegen_mix_order_reduction, 0) + self.assertEqual( + metrics.codegen_mix_order_reduction, + 1 + if inductor_config.triton.cooperative_reductions + or inductor_config.triton.force_cooperative_reductions + else 0, + ) def test_independent_split_size(self): """ @@ -294,6 +307,8 @@ def f(x, y): @parametrize("max_autotune", (False, True)) @parametrize("initial_xblock", (1, 2)) @parametrize("add_1dim", (False, True)) + # The test OOM in CI sometimes. Ask for more memory to make it stable. + @largeTensorTest("16GB", device=GPU_TYPE, inductor=True) def test_rms_norm_bwd( self, wdtype, @@ -534,9 +549,6 @@ def test_rms_norm_sharing_weights(self, split_reductions, dtype): if not inductor_config.triton.mix_order_reduction: self.skipTest("Mix order reduction not enabled") - if dtype is torch.bfloat16 and isRocmArchAnyOf(MI200_ARCH): - self.skipTest("Currently failing on rocm mi200") - def f(xs, w, eps): ys = [] for x in xs: @@ -554,13 +566,22 @@ def f(xs, w, eps): eps = 1e-5 ref = f(xs, w, eps) + + # use float64 to compute ref_grads for precision + # and cast back to original dtype + xs_f64 = [x.to(torch.float64) for x in xs] + w_f64 = w.to(torch.float64) + dys_f64 = [dy.to(torch.float64) for dy in dys] + ref_f64 = f(xs_f64, w_f64, eps) + ref_grads_f64 = torch.autograd.grad(ref_f64, [*xs_f64, w_f64], dys_f64) + ref_grads = [g.to(dtype) for g in ref_grads_f64] + act = torch.compile( f, options={ "split_reductions": split_reductions, }, )(xs, w, eps) - ref_grads = torch.autograd.grad(ref, [*xs, w], dys) act_grads, (wrapper,) = utils.run_and_get_code( lambda: torch.autograd.grad(act, [*xs, w], dys) ) @@ -663,6 +684,7 @@ def f(x): # the other is the piontwise kernel self.assertTrue(2, metrics.generated_kernel_count) + @patch("torch._inductor.scheduler.MixOrderReduction.is_split_reduction") @patch("torch._inductor.scheduler.MixOrderReduction.get_numel_rnumel") @patch("torch._inductor.scheduler.MixOrderReduction.get_common_read") @patch("torch._inductor.scheduler.MixOrderReduction.has_mix_reduction_orders") @@ -671,6 +693,7 @@ def test_mix_order_reduction_non_strict_mode( mock_has_mix_reduction_orders: mock.Mock, mock_get_common_read: mock.Mock, mock_get_numel_rnumel: mock.Mock, + mock_is_split_reduction: mock.Mock, ): """ This tests whether we can skip some non-critical checks @@ -703,6 +726,7 @@ def test_mix_order_reduction_non_strict_mode( from sympy import Integer mock_get_numel_rnumel.return_value = (Integer(1), Integer(1)) + mock_is_split_reduction.return_value = False mock_node_1.read_writes = mock.Mock() mock_node_1.read_writes.reads = [] @@ -724,7 +748,11 @@ def test_mix_order_reduction_non_strict_mode( self.assertFalse(MixOrderReduction.can_fuse(mock_node_1, mock_node_2)) with ( V.set_graph_handler(graph), - inductor_config.patch({"triton.mix_order_reduction_non_strict_mode": True}), + inductor_config.patch( + { + "triton.mix_order_reduction_non_strict_mode": True, + } + ), ): self.assertTrue(MixOrderReduction.can_fuse(mock_node_1, mock_node_2)) @@ -752,6 +780,7 @@ def f(x): compile_metrics = torch._dynamo.utils._compilation_metrics self.assertEqual(len(compile_metrics), 1, "Don't recompile") + @skipIfXpu(msg="https://github.com/intel/intel-xpu-backend-for-triton/issues/6398") def test_additive_rnumel(self): """ Fix https://github.com/pytorch/pytorch/issues/176375 @@ -931,6 +960,215 @@ def causal_mask(_b, _h, q, kv): loss.backward() self.assertTrue(metrics.codegen_mix_order_reduction > 1) + @inductor_config.patch("triton.mix_order_reduction", True) + @inductor_config.patch("triton.mix_order_reduction_non_strict_mode", True) + def test_dimension_refactoring_mismatch(self): + """ + This reproduces an issue where `simplify_and_reorder()` produces a different + dimension factorization than `_original_ranges` used during fusion decision. + For example, fusion might see (13, 8472) but codegen sees (26, 4236) after + the reduction split optimization adds a factor of 2 to the pointwise dimensions. + + We skip fusing split reductions for node1 in this case. + """ + + if not inductor_config.triton.mix_order_reduction: + self.skipTest("Mix order reduction not enabled") + + # Reproduce the RMSNorm backward pattern that triggered the bug. + # The key is: + # - Shape (M, N) = (13, 8472) where N=8472 is large enough to trigger split + # - RMSNorm backward creates reductions along both dimensions + # - The feature dimension reduction (8472) gets split with factor 2 + # - Mix order reduction tries to fuse these, but groups don't match after split + def f(x, w, eps): + orig_dtype = x.dtype + x = x.float() + # RMSNorm forward: y = x * rsqrt(mean(x^2) + eps) * w + rsqrt = torch.rsqrt((x * x).sum(dim=-1) / x.shape[-1] + eps) + y = (x * rsqrt[:, None] * w).to(dtype=orig_dtype) + return y + + def fwd_bwd(compiled_f): + x.grad = None + w.grad = None + out = compiled_f(x, w, eps) + out.backward(dy) + return x.grad, w.grad + + # Use the exact shape from the bug report: (13, 8472) + # 8472 = 2 * 4236, so split with factor 2 gives sub-reductions of 4236 + M, N = 13, 8472 + x = torch.randn(M, N, dtype=torch.float32, device=GPU_TYPE, requires_grad=True) + w = torch.randn(N, dtype=torch.float32, device=GPU_TYPE, requires_grad=True) + dy = torch.randn_like(x) + eps = 1e-5 + + opt_f = torch.compile(f) + + ref = fwd_bwd(f) + act = fwd_bwd(opt_f) + torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3) + self.assertGreaterEqual(metrics.codegen_mix_order_reduction, 0) + + def test_keepdim_shape_mismatch(self): + """ + Test that MixOrderReduction correctly handles keepdim=True reductions. + + This test reproduces a bug where the final reduction in MixOrderReduction + generates `view(nsplit, rnumel).sum(dim=0)` which produces shape [rnumel], + but the expected output should be [1, rnumel] when keepdim=True. + + The error manifests as: + RuntimeError: Function CompiledFunctionBackward returned an invalid gradient + at index N - got [2048] but expected shape compatible with [1, 2048] + """ + if not inductor_config.triton.mix_order_reduction: + self.skipTest("Mix order reduction not enabled") + + # Create a model that produces reductions with keepdim=True in the backward pass + # This pattern is common in normalization layers like RMSNorm/LayerNorm + class KeepDimReductionModel(nn.Module): + def __init__(self, hidden_size): + super().__init__() + # Using shape [1, hidden_size] to ensure keepdim=True in backward + self.weight = nn.Parameter(torch.ones(1, hidden_size)) + self.bias = nn.Parameter(torch.zeros(1, hidden_size)) + + def forward(self, x): + # x: [batch, hidden_size] + # Normalization-like operation that produces keepdim reductions in backward + mean = x.mean(dim=-1, keepdim=True) + var = ((x - mean) ** 2).mean(dim=-1, keepdim=True) + x_norm = (x - mean) / (var + 1e-5).sqrt() + return x_norm * self.weight + self.bias + + M, N = 32768, 2048 # Large batch to trigger mix order reduction + model = KeepDimReductionModel(N).to(GPU_TYPE) + + x = torch.randn(M, N, dtype=torch.float32, device=GPU_TYPE, requires_grad=True) + dy = torch.randn_like(x) + + def fwd_bwd(model, x, dy): + x.grad = None + model.zero_grad() + out = model(x) + out.backward(dy) + return x.grad, model.weight.grad, model.bias.grad + + # Reference (eager) + ref = fwd_bwd(model, x, dy) + + # Compiled with mix order reduction + compiled_model = torch.compile(model) + act = fwd_bwd(compiled_model, x, dy) + + # Verify numerical correctness + self.assertTrue(same(ref, act, tol=1e-3), f"ref:\n{ref}\nact:\n{act}") + + # Verify mix order reduction was used + self.assertGreater( + metrics.codegen_mix_order_reduction, + 0, + "Mix order reduction should be triggered", + ) + + +class OverFusionTest(TestBase): + """ + Regression test for mix-order reduction over-fusion in transformer backward + passes. When can_fuse_with absorbs too many pointwise nodes into a + mixed-order kernel, the resulting kernel has excessive buffer reads (loads) + per RSPLIT loop iteration, causing register spills and performance + regression. See #179423. + """ + + @inductor_config.patch( + { + "triton.mix_order_reduction": True, + "triton.mix_order_reduction_max_reads": 10, + } + ) + def test_max_reads_limits_fusion(self): + """ + Verify that max_reads limits over-fusion in a transformer backward + pass without disabling mix-order reduction entirely. + + Uses the exact model pattern from #179423: GQA attention with QK-norm + and squared leaky-relu MLP. The QK-norm creates extra intermediate + buffers in the backward pass that push read counts above the threshold. + """ + if not HAS_GPU: + self.skipTest("requires GPU") + + num_heads = 8 + num_kv_heads = 4 + dim = 512 + head_dim = dim // num_heads + + class Attention(nn.Module): + def __init__(self): + super().__init__() + self.c_q = nn.Linear(dim, dim, bias=False) + self.c_k = nn.Linear(dim, num_kv_heads * head_dim, bias=False) + self.c_v = nn.Linear(dim, num_kv_heads * head_dim, bias=False) + self.proj = nn.Linear(dim, dim, bias=False) + + def forward(self, x): + B, T, D = x.shape + q = self.c_q(x).reshape(B, T, num_heads, head_dim) + k = self.c_k(x).reshape(B, T, num_kv_heads, head_dim) + v = self.c_v(x).reshape(B, T, num_kv_heads, head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + q = q.transpose(1, 2) + k = k.transpose(1, 2).repeat_interleave( + num_heads // num_kv_heads, dim=1 + ) + v = v.transpose(1, 2).repeat_interleave( + num_heads // num_kv_heads, dim=1 + ) + y = F.scaled_dot_product_attention(q, k, v, is_causal=True) + return self.proj(y.transpose(1, 2).reshape(B, T, D)) + + class Block(nn.Module): + def __init__(self): + super().__init__() + self.attn_norm = nn.RMSNorm(dim) + self.mlp_norm = nn.RMSNorm(dim) + self.attn = Attention() + self.fc1 = nn.Linear(dim, dim * 4, bias=False) + self.fc2 = nn.Linear(dim * 4, dim, bias=False) + + def forward(self, x): + x = x + self.attn(self.attn_norm(x)) + h = self.mlp_norm(x) + x = x + self.fc2(F.leaky_relu(self.fc1(h), negative_slope=0.5).square()) + return x + + model = nn.Sequential(*[Block() for _ in range(3)]).to(GPU_TYPE).bfloat16() + + x = torch.randn( + 8, 2048, dim, device=GPU_TYPE, dtype=torch.bfloat16, requires_grad=True + ) + dy = torch.randn_like(x) + + out_ref = model(x) + out_ref.backward(dy) + grad_ref = x.grad.clone() + x.grad = None + + compiled = torch.compile(model, dynamic=False, fullgraph=True) + out_act = compiled(x) + out_act.backward(dy) + grad_act = x.grad.clone() + + self.assertTrue(same(grad_ref, grad_act, tol=5e-2)) + + # max_reads should limit over-fusion, not disable mix_order entirely + self.assertGreater(metrics.codegen_mix_order_reduction, 0) + self.assertGreater(metrics.rejected_mix_order_reduction_fusion, 0) + @inductor_config.patch( "triton.mix_order_reduction", not inductor_config.triton.mix_order_reduction diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 8be5f9876235f..a342e67deecbd 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -1564,7 +1564,6 @@ def forward(self, x): @unittest.skipIf(not TEST_MKL, "Test requires MKL") @xfailIfACL - @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True) def test_reproduce_121253_issue_addmm_fusion_check(self): class Mod(torch.nn.Module): def __init__(self, weight, bias, beta, alpha): diff --git a/test/inductor/test_mmdecomp.py b/test/inductor/test_mmdecomp.py index 6d5d012e733f1..140e43b150009 100644 --- a/test/inductor/test_mmdecomp.py +++ b/test/inductor/test_mmdecomp.py @@ -15,7 +15,12 @@ from torch.testing._internal.common_cuda import SM80OrLater from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_nn import NNTestCase -from torch.testing._internal.common_utils import IS_WINDOWS, parametrize, run_tests +from torch.testing._internal.common_utils import ( + IS_WINDOWS, + parametrize, + run_tests, + TEST_XPU, +) from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -132,7 +137,8 @@ def test_simple_mm(self, device, dtype): @unittest.skipIf(not HAS_GPU, "GPU tests require triton") @parametrize( - "dtype", [torch.float, torch.bfloat16] if SM80OrLater else [torch.float] + "dtype", + [torch.float, torch.bfloat16] if SM80OrLater or TEST_XPU else [torch.float], ) @parametrize("bs", [1, 2, 4, 10]) def test_batched_mm(self, device, dtype, bs): @@ -196,10 +202,14 @@ def test_bmm_outer_product_k_is_one_with_unbacked_k(self, device): rhs_unbacked_k = torch.empty((b, rhs_k_unbacked, n), device=device) self.assertIsNot( + decomp_bmm(lhs_static_k, rhs_static_k), + NotImplemented, + ) + self.assertIs( decomp_bmm(lhs_static_k, rhs_unbacked_k), NotImplemented, ) - self.assertIsNot( + self.assertIs( decomp_bmm(lhs_unbacked_k, rhs_static_k), NotImplemented, ) @@ -347,7 +357,9 @@ def test_dynamic_shape_mm(self, device, dtype): device_types = ("cpu", GPU_TYPE) -instantiate_device_type_tests(TestDecomp, globals(), only_for=device_types) +instantiate_device_type_tests( + TestDecomp, globals(), only_for=device_types, allow_xpu=True +) if __name__ == "__main__": # We don't support torch.compile() on Windows diff --git a/test/inductor/test_mps_basic.py b/test/inductor/test_mps_basic.py index a7e6ef8a68dfc..5d1d68d391a66 100644 --- a/test/inductor/test_mps_basic.py +++ b/test/inductor/test_mps_basic.py @@ -207,6 +207,44 @@ def fn(q, k, v): self.common(fn, (q, k, v), atol=1e-4, rtol=1e-4, check_lowp=False) + def test_nested_masked_cat(self): + # Regression test for YOLOv3 compilation failure on MPS. + # See https://github.com/pytorch/pytorch/actions/runs/23477894502 + # YOLOv3 detection heads do view/permute/clone, then in-place slice + # assignment (sigmoid+grid, exp*anchor, sigmoid) followed by cat across + # scales. The slice_scatter decomposition fused with cat produces nested + # ops.masked calls in Metal codegen. Without depth-aware variable + # prefixes, inner scoped variables shadow outer ones, causing: + # "variable 'tmp_scoped_1' declared with deduced type 'auto' + # cannot appear in its own initializer" + na, no = 3, 5 + + def head(p, grid, anchor_wh): + bs, _, ny, nx = p.shape + p = p.view(bs, na, no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() + io = p.clone() + io[..., :2] = torch.sigmoid(io[..., :2]) + grid + io[..., 2:4] = torch.exp(io[..., 2:4]) * anchor_wh + torch.sigmoid_(io[..., 4:]) + return io.view(bs, -1, no) + + def fn(p1, p2, grid1, grid2, anchor_wh1, anchor_wh2): + return torch.cat( + [head(p1, grid1, anchor_wh1), head(p2, grid2, anchor_wh2)], dim=1 + ) + + self.common( + fn, + ( + torch.randn(1, na * no, 4, 4, device="mps"), + torch.randn(1, na * no, 8, 8, device="mps"), + torch.randn(1, 1, 4, 4, 2, device="mps"), + torch.randn(1, 1, 8, 8, 2, device="mps"), + torch.randn(1, na, 1, 1, 2, device="mps"), + torch.randn(1, na, 1, 1, 2, device="mps"), + ), + ) + class MPSBasicTestsAOTI(TestCase): def check_model(self, m, inp, dynamic_shapes=None): diff --git a/test/inductor/test_multi_kernel.py b/test/inductor/test_multi_kernel.py index 55f54756913db..f8fc1c3df607c 100644 --- a/test/inductor/test_multi_kernel.py +++ b/test/inductor/test_multi_kernel.py @@ -70,6 +70,7 @@ def make_cpp_wrapper_test(orig_test, **extra_args): """ @config.patch("cpp_wrapper", True) + @config.patch("triton.autotune_at_compile_time", True) def fn(self): # The same kernel may have been compiled by previous tests with # cpp_wrapper disabled. Clear the cache so we go ahead to re-compile @@ -111,7 +112,7 @@ def test_softmax(self, expect_multi_kernel=True): # TODO: bobrenjc93 to fix multi-kernel for ROCM @skipIfRocm @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu") - @skipIfXpu(msg="https://github.com/intel/torch-xpu-ops/issues/2295") + @skipIfXpu(msg="driver issue, torch-xpu-ops: 2295") def test_triton_gemm(self): def fn(x, y): return x @ y @@ -136,7 +137,7 @@ def fn(x, y): self.assertEqual(ref, act) self.assertTrue(_contains_size_hint_multi_kernel_code(wrapper_code)) - @skipIfXpu(msg="https://github.com/intel/torch-xpu-ops/issues/2295") + @skipIfXpu(msg="driver issue, torch-xpu-ops: 2295") @requires_triton() # TODO: bobrenjc93 to fix multi-kernel for ROCM @skipIfRocm diff --git a/test/inductor/test_needs_exact_strides.py b/test/inductor/test_needs_exact_strides.py index ee3d4779881f2..dc7c3d55f967e 100644 --- a/test/inductor/test_needs_exact_strides.py +++ b/test/inductor/test_needs_exact_strides.py @@ -13,13 +13,15 @@ IS_LINUX, parametrize, ) -from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +from torch.testing._internal.inductor_utils import HAS_GPU_AND_TRITON class TestNeedsExactStrides(InductorTestCase): @parametrize("dtype", [torch.float, torch.float8_e8m0fnu]) def test_custom_op(self, dtype): - device = "cuda" # float8_e8m0fnu errors on "cpu" + device = ( + torch.accelerator.current_accelerator() + ) # float8_e8m0fnu errors on "cpu" x = torch.ones(4, 4, 2, 2, device=device, dtype=torch.float8_e8m0fnu) other = torch.ones(4, 4, 2, 2, device=device, dtype=torch.float8_e8m0fnu) @@ -101,5 +103,5 @@ def f(x, other): instantiate_parametrized_tests(TestNeedsExactStrides) if __name__ == "__main__": - if IS_LINUX and HAS_CUDA_AND_TRITON: + if IS_LINUX and HAS_GPU_AND_TRITON: run_tests(needs="filelock") diff --git a/test/inductor/test_nv_universal_gemm.py b/test/inductor/test_nv_universal_gemm.py index 0019435130f05..9a6512923db7b 100644 --- a/test/inductor/test_nv_universal_gemm.py +++ b/test/inductor/test_nv_universal_gemm.py @@ -13,6 +13,7 @@ ) from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import ( + ceildiv, ensure_nv_universal_gemm_available, ensure_nvmatmul_heuristics_available, run_and_get_code, @@ -24,6 +25,68 @@ from torch.utils._ordered_set import OrderedSet +def _round_up(x, multiple): + return ((x + multiple - 1) // multiple) * multiple + + +def _prep_k(K, scale_size): + """Prepare K dimension for swizzle requirements (round up ceildiv to multiple of 4).""" + return _round_up(ceildiv(K, scale_size), 4) + + +def _create_tensor_with_layout(layout, rows, cols, dtype, device="cuda"): + """Create a tensor with the specified layout and dtype. + + Supports float16, bfloat16, float8_e4m3fn, and float4_e2m1fn_x2. + """ + is_fp4 = dtype == torch.float4_e2m1fn_x2 + is_fp8 = dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + + def _make_flat(n): + if is_fp4: + return torch.randint(0, 256, (n,), device=device, dtype=torch.uint8).view( + torch.float4_e2m1fn_x2 + ) + elif is_fp8: + return torch.randint(-1, 2, (n,), device=device).to(dtype) + else: + return torch.randn(n, device=device, dtype=dtype) + + if layout == "contiguous": + if is_fp4: + return torch.randint( + 0, 256, (rows, cols), device=device, dtype=torch.uint8 + ).view(torch.float4_e2m1fn_x2) + elif is_fp8: + return torch.randint(-1, 2, (rows, cols), device=device).to(dtype) + else: + return torch.randn(rows, cols, device=device, dtype=dtype) + elif layout == "aligned_offset": + storage = _make_flat(rows * cols + 512) + offset = 16 // storage.element_size() + return torch.as_strided(storage[offset:], (rows, cols), (cols, 1)) + elif layout == "view": + return _make_flat(rows * cols).view(rows, cols) + elif layout == "padded": + row_pitch = cols + 8 + storage = _make_flat(rows * row_pitch) + return torch.as_strided(storage, (rows, cols), (row_pitch, 1)) + else: + raise ValueError(f"Unknown layout: {layout}") + + +def _nvgemm_config(**overrides): + """Standard NVGEMM test config. Always disables ATen fallback.""" + cfg = { + "max_autotune": True, + "max_autotune_gemm_backends": "NVGEMM", + "nvgemm_max_profiling_configs": 3, + "autotune_fallback_to_aten": False, + } + cfg.update(overrides) + return cfg + + # TODO(nikhilap): Remove Blackwell restriction once cutlass_api includes H100 kernels @unittest.skipIf( not (ensure_nv_universal_gemm_available() and is_datacenter_blackwell_arch()), @@ -39,59 +102,31 @@ class TestNVUniversalGemm(TestCase): ( ("contiguous", "contiguous"), ("aligned_offset", "contiguous"), + ("contiguous", "aligned_offset"), ("contiguous", "view"), ("aligned_offset", "view"), ("padded", "contiguous"), + ("contiguous", "padded"), ), ) def test_matmul(self, dtype, layout_a, layout_b): """Test matmul with various dtypes and tensor layouts. - These layouts test various alignment scenarios: - - contiguous/view/aligned_offset: Standard aligned layouts - - padded: Non-16-byte-aligned stride, Inductor pads to aligned size - M=513 tests that non-divisible M dimension works (only N and K must be divisible by 16). + M=513 tests that non-divisible M dimension works + (only N and K must be divisible by 16). """ m, n, k = 513, 512, 512 - device = "cuda" def matmul(a, b): return a @ b - def create_tensor_with_layout(layout, rows, cols): - """Create a tensor with the specified layout.""" - if layout == "contiguous": - return torch.randn(rows, cols, device=device, dtype=dtype) - elif layout == "aligned_offset": - # Allocate bigger buffer than needed, use 16-byte aligned offset - # offset=128 elements * 2 bytes = 256 bytes (16-byte aligned) - storage = torch.randn(rows * cols + 512, device=device, dtype=dtype) - offset = 128 - return torch.as_strided(storage[offset:], (rows, cols), (cols, 1)) - elif layout == "view": - storage = torch.randn(rows * cols, device=device, dtype=dtype) - return storage.view(rows, cols) - elif layout == "padded": - # Simulate row pitch > cols with non-16-byte-aligned stride - # row_stride = cols + 8 = 520, 520 * 2 bytes = 1040 bytes (not 16-byte aligned) - row_pitch = cols + 8 - storage = torch.randn(rows * row_pitch, device=device, dtype=dtype) - return torch.as_strided(storage, (rows, cols), (row_pitch, 1)) - - a = create_tensor_with_layout(layout_a, m, k) - b = create_tensor_with_layout(layout_b, k, n) - + a = _create_tensor_with_layout(layout_a, m, k, dtype) + b = _create_tensor_with_layout(layout_b, k, n, dtype) expected = matmul(a, b) torch._dynamo.reset() - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "NVGEMM", - "nvgemm_max_profiling_configs": 3, - } - ): + with config.patch(_nvgemm_config()): compiled_fn = torch.compile(matmul) result = compiled_fn(a, b) @@ -111,7 +146,6 @@ def test_unaligned_base_pointer_rejected(self): def matmul(a, b): return a @ b - # Create tensor with unaligned base pointer # offset=117 elements * 2 bytes = 234 bytes (NOT 16-byte aligned) storage = torch.randn(m * k + 512, device=device, dtype=dtype) a = torch.as_strided(storage[117:], (m, k), (k, 1)) @@ -119,13 +153,7 @@ def matmul(a, b): torch._dynamo.reset() - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "NVGEMM", - "nvgemm_max_profiling_configs": 3, - } - ): + with config.patch(_nvgemm_config()): compiled_fn = torch.compile(matmul) with self.assertRaisesRegex( Exception, "NoValidChoicesError|no valid choice" @@ -137,34 +165,23 @@ def test_reinterpret_view_from_slice(self, dtype): """Test that sliced tensors (creating ReinterpretViews) work correctly. When tensors are slices of a shared buffer (e.g., from a fused projection), - they become ReinterpretViews with non-contiguous strides. NVIDIA Universal GEMM must - handle these correctly. + they become ReinterpretViews with non-contiguous strides. """ m, n, k = 512, 512, 512 device = "cuda" def fn(x, weight): - # Fused projection creates a single large output projected = x @ weight # (m, 2*n) - # Slicing creates ReinterpretViews a, b = projected.split(n, dim=1) # Each is (m, n) return a @ b.t() # (m, m) x = torch.randn(m, k, device=device, dtype=dtype) - # Weight projects to 2*n so we can split weight = torch.randn(k, 2 * n, device=device, dtype=dtype) - expected = fn(x, weight) torch._dynamo.reset() - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "NVGEMM", - "nvgemm_max_profiling_configs": 3, - } - ): + with config.patch(_nvgemm_config()): compiled_fn = torch.compile(fn) result = compiled_fn(x, weight) @@ -174,8 +191,7 @@ def test_workspace_allocation(self): """Test that workspace allocation works correctly. Since no current CUTLASS kernels require a workspace, we mock the - kernel.get_workspace_size method to return a non-zero value. This - exercises the workspace allocation/deallocation code paths. + kernel.get_workspace_size method to return a non-zero value. """ m, n, k = 512, 512, 512 dtype = torch.bfloat16 @@ -186,12 +202,10 @@ def matmul(a, b): a = torch.randn(m, k, device=device, dtype=dtype) b = torch.randn(k, n, device=device, dtype=dtype) - expected = matmul(a, b) torch._dynamo.reset() - # Patch cutlass_api.Kernel.get_workspace_size to return non-zero import cutlass_api def patched_get_workspace_size(self, args): @@ -202,13 +216,7 @@ def patched_get_workspace_size(self, args): "get_workspace_size", patched_get_workspace_size, ): - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "NVGEMM", - "nvgemm_max_profiling_configs": 3, - } - ): + with config.patch(_nvgemm_config()): result, (code,) = run_and_get_code( torch.compile(matmul), a, @@ -216,7 +224,6 @@ def patched_get_workspace_size(self, args): ) self.assertIn("workspace=workspace", code) - torch.testing.assert_close(result, expected) @parametrize("dtype", (torch.float16, torch.bfloat16)) @@ -228,17 +235,13 @@ def test_bmm_non_standard_batch_stride(self, dtype): def bmm(a, b): return torch.bmm(a, b) - # Create tensors with non-largest batch stride by transposing - # a_base shape: (m, batch, k), stride: (batch*k, k, 1) - # After transpose: shape (batch, m, k), stride: (k, batch*k, 1) - # batch_stride = k = 128, but m*k = 64*128 = 8192, so batch_stride < m*k + # Transpose creates non-largest batch stride a_base = torch.randn(m, batch, k, device=device, dtype=dtype) a = a_base.transpose(0, 1) # (batch, m, k) with stride (k, batch*k, 1) b_base = torch.randn(k, batch, n, device=device, dtype=dtype) b = b_base.transpose(0, 1) # (batch, k, n) with stride (n, batch*n, 1) - # Verify batch stride is not largest (i.e., batch_stride_largest_or_zero would be False) if a.stride()[0] == a.shape[1] * a.shape[2]: raise AssertionError( "Test setup error: a should have non-standard batch stride" @@ -252,13 +255,7 @@ def bmm(a, b): torch._dynamo.reset() - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "NVGEMM", - "nvgemm_max_profiling_configs": 3, - } - ): + with config.patch(_nvgemm_config()): compiled_fn = torch.compile(bmm) result = compiled_fn(a, b) @@ -268,68 +265,39 @@ def bmm(a, b): "layout_a", ("contiguous", "aligned_offset", "view"), ) - def test_scaled_gemm_mxfp8(self, layout_a): - """Test MXFP8 scaled GEMM with NVGEMM backend. - - Note: Invalid inputs (wrong shapes, dtypes, K not divisible by 16, etc.) - are caught early by Dynamo's _check_scaled_mm_sizes in torch._meta_registrations. - NVGEMM can assume inputs are valid by the time they reach kernel selection. - """ - from torch._inductor.utils import ceildiv - - m, n, k = 256, 512, 1024 + @parametrize( + "m,n,k", + ( + (256, 512, 1024), + (256, 1024, 512), + (128, 256, 512), + (512, 256, 1024), + ), + ) + def test_scaled_gemm_mxfp8(self, layout_a, m, n, k): + """Test MXFP8 scaled GEMM with NVGEMM backend.""" block_size = 32 - device = "cuda" - - def _round_up(x, multiple): - return ((x + multiple - 1) // multiple) * multiple - - def _prep_k(K, scale_size): - """Prepare K dimension for 32-4-4 swizzle requirements.""" - return _round_up(ceildiv(K, scale_size), 4) def scaled_mm(a, b, scale_a, scale_b): return torch._scaled_mm( a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float32 ) - # Create FP8 tensor A with requested layout - if layout_a == "contiguous": - a_fp8 = torch.randint(-1, 2, (m, k), device=device).to(torch.float8_e4m3fn) - elif layout_a == "aligned_offset": - # 16 elements * 1 byte = 16 bytes (16-byte aligned) - storage = torch.randint(-1, 2, (m * k + 512,), device=device).to( - torch.float8_e4m3fn - ) - a_fp8 = torch.as_strided(storage[16:], (m, k), (k, 1)) - elif layout_a == "view": - storage = torch.randint(-1, 2, (m * k,), device=device).to( - torch.float8_e4m3fn - ) - a_fp8 = storage.view(m, k) - # B is N x K, then transposed to K x N for scaled_mm - b_fp8 = torch.randint(-1, 2, (n, k), device=device).to(torch.float8_e4m3fn).T + a_fp8 = _create_tensor_with_layout(layout_a, m, k, torch.float8_e4m3fn) + b_fp8 = torch.randint(-1, 2, (n, k), device="cuda").to(torch.float8_e4m3fn).T - # Scale factors in float8_e8m0fnu (MXFP8 format) - # Shape: (M, prep_k(K, 32)) for A, (prep_k(K, 32), N) for B - scale_a = torch.rand(m, _prep_k(k, block_size), device=device).to( + scale_a = torch.rand(m, _prep_k(k, block_size), device="cuda").to( torch.float8_e8m0fnu ) - scale_b = torch.rand(_prep_k(k, block_size), n, device=device).to( + scale_b = torch.rand(_prep_k(k, block_size), n, device="cuda").to( torch.float8_e8m0fnu ) - # Get reference result from eager mode (ATen) expected = scaled_mm(a_fp8, b_fp8, scale_a, scale_b) - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "NVGEMM", - "nvgemm_max_profiling_configs": 3, - "autotune_fallback_to_aten": False, - } - ): + torch._dynamo.reset() + + with config.patch(_nvgemm_config()): compiled_fn = torch.compile(scaled_mm) result = compiled_fn(a_fp8, b_fp8, scale_a, scale_b) @@ -350,44 +318,20 @@ def scaled_mm(a, b, scale_a, scale_b): ), ) def test_scaled_gemm_nvf4(self, out_dtype, layout_a, m, n, k): - """Test NVF4 (Float4 + Float8E4M3FN scales, block_size=16) with NVGEMM backend. - - NVF4 is the FP4 format supported end-to-end through torch._scaled_mm. - ATen requires float8_e4m3fn scale factors for FP4 blockwise scaling. - Uses autotune_choice_name_regex to force the vendored kernel. - """ - from torch._inductor.utils import ceildiv - + """Test NVF4 (Float4 + Float8E4M3FN scales, block_size=16) with NVGEMM backend.""" packed_k = k // 2 block_size = 16 - device = "cuda" - - def _round_up(x, multiple): - return ((x + multiple - 1) // multiple) * multiple def scaled_mm(a, b, scale_a, scale_b): return torch._scaled_mm( a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=out_dtype ) - # Create FP4 tensor A with requested layout - if layout_a == "contiguous": - a_fp4 = torch.randint( - 0, 256, (m, packed_k), device=device, dtype=torch.uint8 - ).view(torch.float4_e2m1fn_x2) - elif layout_a == "aligned_offset": - # 16 elements * 1 byte = 16 bytes (16-byte aligned) - storage = torch.randint( - 0, 256, (m * packed_k + 512,), device=device, dtype=torch.uint8 - ).view(torch.float4_e2m1fn_x2) - a_fp4 = torch.as_strided(storage[16:], (m, packed_k), (packed_k, 1)) - elif layout_a == "view": - storage = torch.randint( - 0, 256, (m * packed_k,), device=device, dtype=torch.uint8 - ).view(torch.float4_e2m1fn_x2) - a_fp4 = storage.view(m, packed_k) + a_fp4 = _create_tensor_with_layout( + layout_a, m, packed_k, torch.float4_e2m1fn_x2 + ) b_fp4 = torch.randint( - 0, 256, (n, packed_k), device=device, dtype=torch.uint8 + 0, 256, (n, packed_k), device="cuda", dtype=torch.uint8 ).view(torch.float4_e2m1fn_x2) b_fp4_t = b_fp4.T @@ -397,34 +341,66 @@ def scaled_mm(a, b, scale_a, scale_b): scale_a_numel = block_size_mn * ceildiv(m, block_size_mn) * padded_k_blocks scale_b_numel = block_size_mn * ceildiv(n, block_size_mn) * padded_k_blocks - scale_a = torch.rand(scale_a_numel, device=device).to(torch.float8_e4m3fn) - scale_b = torch.rand(scale_b_numel, device=device).to(torch.float8_e4m3fn) + scale_a = torch.rand(scale_a_numel, device="cuda").to(torch.float8_e4m3fn) + scale_b = torch.rand(scale_b_numel, device="cuda").to(torch.float8_e4m3fn) expected = scaled_mm(a_fp4, b_fp4_t, scale_a, scale_b) + torch._dynamo.reset() + with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "NVGEMM", - "nvgemm_max_profiling_configs": 3, - "autotune_fallback_to_aten": False, - # Force vendored kernel via description regex - "test_configs.autotune_choice_desc_regex": "inductor_vendored", - } + _nvgemm_config( + **{"test_configs.autotune_choice_desc_regex": "inductor_vendored"} + ) ): compiled_fn = torch.compile(scaled_mm) result = compiled_fn(a_fp4, b_fp4_t, scale_a, scale_b) - torch.testing.assert_close(result, expected) + # a_fp4 and b_fp4_t could come with NaNs. + torch.testing.assert_close(result, expected, equal_nan=True) + + @parametrize( + "layout_a", + ("contiguous", "aligned_offset", "view", "padded"), + ) + def test_grouped_gemm(self, layout_a): + """Test grouped GEMM with NVGEMM backend and various A layouts. + + GroupedGemm currently only supports TN layout (column-major B). + """ + g, k, n = 4, 256, 256 + dtype = torch.bfloat16 + device = "cuda" + + def grouped_mm(a, b, offsets): + return torch._grouped_mm(a, b, offs=offsets) - def test_grouped_gemm(self): - """Test grouped GEMM with NVGEMM backend. + b = torch.randn(g, n, k, device=device, dtype=dtype).permute(0, 2, 1) + + m_per_group = [64, 64, 64, 64] + total_m = sum(m_per_group) + offsets = torch.tensor( + [sum(m_per_group[: i + 1]) for i in range(g)], + device=device, + dtype=torch.int32, + ) + a = _create_tensor_with_layout(layout_a, total_m, k, dtype, device) + + expected = grouped_mm(a, b, offsets) + + torch._dynamo.reset() + + with config.patch(_nvgemm_config()): + compiled_fn = torch.compile(grouped_mm) + result = compiled_fn(a, b, offsets) - This test runs the same shape twice with different offsets to verify that - different offset distributions produce correct results. + torch.testing.assert_close(result, expected) - Note: GroupedGemm currently only supports TN layout (column-major B). - B is created with shape (g, k, n) but column-major inner layout via permute. + def test_grouped_gemm_varying_offsets(self): + """Test that different offset distributions produce correct results. + + Runs the same compiled function with two different offset distributions + (same total_m) to verify offsets are handled dynamically at runtime. """ g, k, n = 4, 256, 256 dtype = torch.bfloat16 @@ -436,35 +412,30 @@ def grouped_mm(a, b, offsets): b = torch.randn(g, n, k, device=device, dtype=dtype).permute(0, 2, 1) m_per_group_1 = [64, 64, 64, 64] - total_m_1 = sum(m_per_group_1) + total_m = sum(m_per_group_1) offsets_1 = torch.tensor( [sum(m_per_group_1[: i + 1]) for i in range(g)], device=device, dtype=torch.int32, ) - a_1 = torch.randn(total_m_1, k, device=device, dtype=dtype) + a_1 = torch.randn(total_m, k, device=device, dtype=dtype) m_per_group_2 = [32, 96, 48, 80] - total_m_2 = sum(m_per_group_2) - if total_m_1 != total_m_2: + if sum(m_per_group_2) != total_m: raise AssertionError("Total M must match for cache key test") offsets_2 = torch.tensor( [sum(m_per_group_2[: i + 1]) for i in range(g)], device=device, dtype=torch.int32, ) - a_2 = torch.randn(total_m_2, k, device=device, dtype=dtype) + a_2 = torch.randn(total_m, k, device=device, dtype=dtype) expected_1 = grouped_mm(a_1, b, offsets_1) expected_2 = grouped_mm(a_2, b, offsets_2) - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "NVGEMM", - "autotune_fallback_to_aten": False, - } - ): + torch._dynamo.reset() + + with config.patch(_nvgemm_config()): compiled_fn = torch.compile(grouped_mm) result_1 = compiled_fn(a_1, b, offsets_1) @@ -604,6 +575,33 @@ def test_fp4_heuristic_configs(self): self.assertGreater(cfg.tile_n, 0) self.assertGreater(cfg.estimated_runtime, 0) + def test_fp8_heuristic_configs(self): + """Test that nvMatmulHeuristics returns configs for FP8 GEMM.""" + heuristics = NVUniversalGemmHeuristics() + + m, n, k = 256, 512, 1024 + configs = heuristics._get_heuristic_configs( + m, + n, + k, + dtype_a=torch.float8_e4m3fn, + layout_a="row", + layout_b="col", + count=5, + valid_configs=OrderedSet(), + accumulator_type=torch.float32, + dtype_b=torch.float8_e4m3fn, + out_dtype=torch.float32, + ) + + self.assertGreater( + len(configs), 0, "nvMatmulHeuristics returned no FP8 configs" + ) + for cfg in configs: + self.assertGreater(cfg.tile_m, 0) + self.assertGreater(cfg.tile_n, 0) + self.assertGreater(cfg.estimated_runtime, 0) + @unittest.skipIf( not (ensure_nv_universal_gemm_available() and is_datacenter_blackwell_arch()), @@ -618,7 +616,6 @@ def test_unbacked_symint_rejected(self): def fn(x, w): nz = torch.nonzero(x) # Creates unbacked symint for nz.size(0) - # Use unbacked symint as M dimension in matmul a = torch.ones(nz.size(0), w.size(0), dtype=w.dtype, device=w.device) return a @ w @@ -627,13 +624,7 @@ def fn(x, w): torch._dynamo.reset() - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "NVGEMM", - "nvgemm_max_profiling_configs": 2, - } - ): + with config.patch(_nvgemm_config(nvgemm_max_profiling_configs=2)): compiled_fn = torch.compile(fn, dynamic=True) with self.assertRaisesRegex( Exception, "NoValidChoicesError|no valid choice" @@ -648,13 +639,7 @@ def matmul(a, b): torch._dynamo.reset() - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "NVGEMM", - "nvgemm_max_profiling_configs": 2, - } - ): + with config.patch(_nvgemm_config(nvgemm_max_profiling_configs=2)): compiled_fn = torch.compile(matmul, dynamic=True) shapes = [ diff --git a/test/inductor/test_pad_as_cat.py b/test/inductor/test_pad_as_cat.py index 4b4591dcf5b9d..81f0d02c016e3 100644 --- a/test/inductor/test_pad_as_cat.py +++ b/test/inductor/test_pad_as_cat.py @@ -1,9 +1,11 @@ # Owner(s): ["module: inductor"] -"""Tests for cat multi-consumer optimization (pytorch#125075).""" +"""Tests for cat multi-consumer and pad-as-cat optimizations.""" import torch +from torch._dynamo.utils import counters from torch._inductor import metrics from torch._inductor.test_case import TestCase +from torch._inductor.utils import run_and_get_code from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu @@ -74,6 +76,51 @@ def fn(x): ) +class TestPadAsCat(TestCase): + @requires_gpu() + def test_mul_pad_addmm(self): + """Multi-consumer F.pad uses ConcatKernel zero-copy.""" + counters.clear() + + def fn(x, scale, bias, weight): + mul_result = x * scale + padded = torch.nn.functional.pad(mul_result, [0, 192]) + mm_result = torch.addmm(bias, mul_result, weight) + return padded, mm_result + + x = torch.randn(128, 2880, device=GPU_TYPE, dtype=torch.bfloat16) + scale = torch.randn(128, 2880, device=GPU_TYPE, dtype=torch.bfloat16) + bias = torch.randn(1024, device=GPU_TYPE, dtype=torch.bfloat16) + weight = torch.randn(2880, 1024, device=GPU_TYPE, dtype=torch.bfloat16) + + compiled = torch.compile(fn) + result, (code,) = run_and_get_code(compiled, x, scale, bias, weight) + ref = fn(x, scale, bias, weight) + + self.assertEqual(result[0], ref[0]) + self.assertEqual(result[1], ref[1], atol=1e-2, rtol=1e-2) + self.assertIn("reinterpret_tensor", code) + self.assertGreater(counters["inductor"]["pad_rewritten_as_cat"], 0) + + @requires_gpu() + def test_single_consumer_pad(self): + """Single-consumer F.pad is decomposed into cat, which fuses via pointwise_cat.""" + counters.clear() + + def fn(x, scale): + return torch.nn.functional.pad(x * scale, [0, 192]) + + x = torch.randn(128, 2880, device=GPU_TYPE) + scale = torch.randn(128, 2880, device=GPU_TYPE) + + compiled = torch.compile(fn) + result = compiled(x, scale) + ref = fn(x, scale) + + self.assertEqual(result, ref) + self.assertGreater(counters["inductor"]["pad_rewritten_as_cat"], 0) + + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index c67bde87a369b..ee12020a311e4 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -778,6 +778,52 @@ def get_input(size: tuple[int], alignment_bytes: int) -> torch.Tensor: output_line = f"buf12 = empty_strided_{GPU_TYPE}({output_shape}, {output_stride}, torch.float32)" self.assertTrue(output_line in code[0]) + @requires_gpu() + def test_concat_output_no_redundant_copy_with_padding(self): + """ + When comprehensive_padding is enabled, ConcatKernel pads its output + buffer strides. The graph output should accept the padded strides + directly instead of generating a redundant copy kernel. + """ + + def f(x): + a = x + 1 + # a has two consumers (mul and cat), which forces ConcatKernel + # over pointwise_cat. Both inputs are Pointwise with + # FlexibleLayout so they realize directly into concat slices. + b = a * 2 + return torch.cat([a, b], dim=1) + + # Use dim=131 so concat output dim (262) is not aligned to + # padding_alignment_bytes/4=32, triggering stride padding. + x = torch.randn(128, 131, device=GPU_TYPE) + + with config.patch( + { + "comprehensive_padding": True, + "pad_outputs": True, + "padding_stride_threshold": 0, + "inplace_buffers": False, + } + ): + result, code = run_and_get_code(torch.compile(f), x) + + ref = f(x) + self.assertTrue(torch.allclose(ref, result, atol=1e-3, rtol=1e-3)) + # Only one output buffer should be allocated for the concat result. + # Without the fix, a second empty_strided is allocated and a copy + # kernel is generated to copy from the padded concat buffer to it. + # Count actual buffer allocations (not import lines) by matching + # "= empty_strided_(" pattern. + import re + + num_allocs = len(re.findall(rf"= empty_strided_{GPU_TYPE}\(", code[0])) + self.assertEqual( + num_allocs, + 1, + "Expected exactly one buffer allocation for concat output (no redundant copy)", + ) + @parametrize( "shape,alignment_bytes,enable_pad", [ @@ -908,6 +954,25 @@ def test_dynamic_shape_padding(self, shape, alignment_bytes, enable_pad): ) self.assertEqual(result.stride(), expected_stride) + def test_reduction_comprehensive_padding_stride(self): + """Comprehensive padding should not cause stride mismatches for + user-visible reductions. + + Regression test for https://github.com/pytorch/pytorch/issues/179931 + """ + + def program(x): + y = torch.nn.functional.adaptive_avg_pool2d(x, 7) + return y.flatten(1).sum(dim=-1) + + x = torch.randn(4, 2049, 8, 8, dtype=torch.float32, device=GPU_TYPE) + eager = program(x.clone()) + + with config.patch({"comprehensive_padding": True}): + compiled = torch.compile(program, backend="inductor")(x.clone()) + + self.assertEqual(eager, compiled) + if __name__ == "__main__": if HAS_GPU: diff --git a/test/inductor/test_pallas.py b/test/inductor/test_pallas.py index 4077030eada7d..178f83ca9740a 100644 --- a/test/inductor/test_pallas.py +++ b/test/inductor/test_pallas.py @@ -5,9 +5,11 @@ import sys import unittest +import numpy as np + import torch import torch._dynamo -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch._dynamo.testing import make_test_cls_with_patches from torch._inductor import config from torch._inductor.test_case import run_tests, TestCase @@ -628,7 +630,7 @@ def test_stride_non_contiguous_unsqueeze(self): x = base_2d[::2, ::2].unsqueeze(0) self.assertEqual(compiled(x), x * 2.0 + 1.0) - @skip_if_tpu + @skip_if_tpu(reason="TPU doesn't support float 64") def test_stride_non_contiguous_dtypes(self): """Test non-contiguous patterns with various dtypes.""" compiled = self._compile(lambda x: x * 2.0 + 1.0) @@ -660,7 +662,6 @@ def test_stride_expanded_tensors(self): x = torch.randn(1, 1, 16, device=self.DEVICE).expand(4, 8, 16) self.assertEqual(compiled(x, x), x + x) - @skip_if_tpu def test_stride_multiple_inputs(self): """Test multiple strided inputs and broadcasting.""" compiled = self._compile(lambda a, b, c: a * b + c) @@ -686,6 +687,80 @@ def test_stride_multiple_inputs(self): compiled_bcast = self._compile(lambda x, y, s: x + y * s) self.assertEqual(compiled_bcast(x, y, s), x + y * s) + @skip_if_cuda + def test_scalar_scalar_ops(self): + """Test scalar-scalar operations.""" + + def test_add(a, b): + return a + b + + def test_mul(a, b): + return a * b + + def test_sub(a, b): + return a - b + + def test_div(a, b): + return a / b + + for fn in [test_add, test_mul, test_sub, test_div]: + with self.subTest(op=fn.__name__): + compiled = self._compile(fn) + + # Test with 0-D tensors (scalars) + a = torch.tensor(3.5, dtype=torch.float32, device=self.DEVICE) + b = torch.tensor(2.0, dtype=torch.float32, device=self.DEVICE) + + result = compiled(a, b) + expected = fn(a, b) + self.assertEqual(result, expected) + + # Ensure result is also scalar + self.assertEqual(result.dim(), 0) + self.assertEqual(result.dtype, torch.float32) + + def test_scalar_tensor_ops(self): + """Test scalar-tensor operations.""" + + def test_scalar_add_tensor(s, t): + return s + t + + def test_tensor_add_scalar(t, s): + return t + s + + def test_scalar_mul_tensor(s, t): + return s * t + + def test_tensor_mul_scalar(t, s): + return t * s + + shapes = [(16,), (8, 8), (4, 4, 4)] + + for shape in shapes: + for fn in [ + test_scalar_add_tensor, + test_tensor_add_scalar, + test_scalar_mul_tensor, + test_tensor_mul_scalar, + ]: + with self.subTest(op=fn.__name__, shape=shape): + compiled = self._compile(fn) + + # Create 0-D scalar tensor + scalar = torch.tensor(2.5, dtype=torch.float32, device=self.DEVICE) + tensor = torch.randn(shape, dtype=torch.float32, device=self.DEVICE) + + if "scalar" in fn.__name__.split("_")[0]: + result = compiled(scalar, tensor) + expected = fn(scalar, tensor) + else: + result = compiled(tensor, scalar) + expected = fn(tensor, scalar) + + self.assertEqual(result, expected) + self.assertEqual(result.shape, shape) + self.assertEqual(result.dtype, torch.float32) + def test_non_power_of_2_sizes(self): """Test that non-power-of-2 tensor sizes work correctly. @@ -720,7 +795,7 @@ def fn(x, y): expected = fn(x, y) self.assertEqual(result, expected) - @skip_if_tpu + @skip_if_tpu(reason="Cannot do int indexing on TPU") @skip_if_cuda(reason="gather not supported in Pallas GPU (Mosaic) backend") def test_complex_indexing_gather(self): """Test complex indexing with gather-like operations.""" @@ -740,7 +815,7 @@ def fn(x, indices): expected = fn(x, indices) self.assertEqual(result, expected) - @skip_if_tpu + @skip_if_tpu(reason="Cannot do int indexing on TPU") # Pallas Mosaic backend doesn't support gather operations with array indices # This limitation is in the Pallas/Mosaic lowering, not our implementation @skip_if_cuda( @@ -954,7 +1029,6 @@ def fn(a, b): expected = fn(a, b) self.assertEqual(result, expected) - @skip_if_tpu def test_sign(self): """Test sign operation.""" @@ -995,7 +1069,9 @@ def fn(x): expected = fn(x) self.assertEqual(result, expected) - @skip_if_tpu + @skip_if_tpu( + reason="Pallas loweing crash: https://github.com/jax-ml/jax/issues/36149" + ) def test_erf(self): """Test erf operation.""" @@ -1009,7 +1085,9 @@ def fn(x): expected = fn(x) self.assertEqual(result, expected) - @skip_if_tpu + @skip_if_tpu( + reason="Pallas loweing crash: https://github.com/jax-ml/jax/issues/36149" + ) def test_atan2(self): """Test atan2 operation.""" @@ -1075,7 +1153,7 @@ def fn(x): expected = fn(x) self.assertEqual(result, expected) - @skip_if_tpu + @skip_if_tpu(reason="reduce_prod primitive not implemented in Pallas TPU lowering") @skip_if_cuda(reason="reduce_prod primitive not implemented in Pallas Mosaic GPU") def test_prod_reduction(self): """Test prod reduction.""" @@ -1154,7 +1232,6 @@ def fn(x, weight): self.assertEqual(result, expected) @skip_if_cuda - @skip_if_tpu def test_welford(self): """Test Welford variance/mean computation (two-pass fallback).""" @@ -1168,9 +1245,11 @@ def fn(x): compiled = self._compile(fn) x = torch.randn(shape, device=self.DEVICE) var_result, mean_result = compiled(x) - var_expected, mean_expected = fn(x) - self.assertEqual(mean_result, mean_expected) - self.assertEqual(var_result, var_expected) + # Eager mode torch_tpu doesn't support lowering var_mean, so comparing with numpy + var_expected = np.var(x.cpu().numpy(), axis=-1, keepdims=True, ddof=1) + mean_expected = np.mean(x.cpu().numpy(), axis=-1, keepdims=True) + self.assertEqual(mean_result.cpu().numpy(), mean_expected) + self.assertEqual(var_result.cpu().numpy(), var_expected) @skip_if_cuda def test_layer_norm(self): @@ -1722,7 +1801,14 @@ def transformer_block(x, w_q, w_k, w_v, w_proj, w_fc, w_out, mask): self.assertEqual(result, expected) def _run_transformer_layer( - self, seq_len, hidden_dim, num_heads, head_dim, ffn_dim, atol=1e-5, rtol=1.3e-6 + self, + seq_len, + hidden_dim, + num_heads, + head_dim, + ffn_dim, + atol=1e-5, + rtol=1.3e-6, ): """Run a Llama-style transformer layer forward pass and verify correctness. @@ -1776,8 +1862,9 @@ def transformer_layer( compiled = self._compile(transformer_layer) - # Initialize weights with small values for numerical stability - s = 0.02 + # Scale weights by 1/sqrt(fan_in) to keep activations O(1) and + # avoid rsqrt amplification of reduction-order diffs. + s = hidden_dim**-0.5 w_q = torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s w_k = torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s w_v = torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s @@ -1794,7 +1881,7 @@ def transformer_layer( diagonal=1, ) - x = torch.randn(seq_len, hidden_dim, device=self.DEVICE) * 0.02 + x = torch.randn(seq_len, hidden_dim, device=self.DEVICE) result = compiled( x, @@ -1844,8 +1931,8 @@ def test_transformer_layer_medium(self): num_heads=32, head_dim=128, ffn_dim=11008, - atol=1e-4, - rtol=1e-4, + atol=2e-2, + rtol=1e-2, ) @skip_if_cuda @@ -1857,8 +1944,8 @@ def test_transformer_layer_large(self): num_heads=128, head_dim=128, ffn_dim=53248, - atol=2e-3, - rtol=1e-3, + atol=2e-2, + rtol=1e-2, ) @skip_if_cuda @@ -2067,6 +2154,461 @@ def fn(x, p=perm): expected = fn(x) self.assertEqual(result, expected) + def _run_transformer( + self, + num_layers, + seq_len, + hidden_dim, + num_heads, + head_dim, + ffn_dim, + atol=1e-5, + rtol=1.3e-6, + ): + """Run a multi-layer Llama-style transformer and verify correctness.""" + torch._dynamo.reset() + + def transformer(x, mask, *layer_params): + T, C = x.shape + params_per_layer = ( + 9 # rms_w1, rms_w2, w_q, w_k, w_v, w_o, w_gate, w_up, w_down + ) + + for i in range(num_layers): + offset = i * params_per_layer + rms_w1 = layer_params[offset] + rms_w2 = layer_params[offset + 1] + w_q = layer_params[offset + 2] + w_k = layer_params[offset + 3] + w_v = layer_params[offset + 4] + w_o = layer_params[offset + 5] + w_gate = layer_params[offset + 6] + w_up = layer_params[offset + 7] + w_down = layer_params[offset + 8] + + # Pre-attention RMSNorm + variance = x.pow(2).mean(-1, keepdim=True) + h = x * torch.rsqrt(variance + 1e-6) * rms_w1 + + # Multi-head self-attention + q = (h @ w_q).view(T, num_heads, head_dim).permute(1, 0, 2) + k = (h @ w_k).view(T, num_heads, head_dim).permute(1, 0, 2) + v = (h @ w_v).view(T, num_heads, head_dim).permute(1, 0, 2) + + scale = 1.0 / (head_dim**0.5) + att = (q @ k.transpose(-2, -1)) * scale + att = att + mask + att = torch.softmax(att, dim=-1) + attn_out = (att @ v).permute(1, 0, 2).contiguous().view(T, C) + + x = x + (attn_out @ w_o) + + # Pre-FFN RMSNorm + variance = x.pow(2).mean(-1, keepdim=True) + h = x * torch.rsqrt(variance + 1e-6) * rms_w2 + + # SwiGLU FFN + gate = torch.nn.functional.silu(h @ w_gate) + up = h @ w_up + x = x + ((gate * up) @ w_down) + + return x + + compiled = self._compile(transformer) + + s = hidden_dim**-0.5 + all_params = [] + for _ in range(num_layers): + all_params.extend( + [ + torch.ones(hidden_dim, device=self.DEVICE), # rms_w1 + torch.ones(hidden_dim, device=self.DEVICE), # rms_w2 + torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s, # w_q + torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s, # w_k + torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s, # w_v + torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s, # w_o + torch.randn(hidden_dim, ffn_dim, device=self.DEVICE) * s, # w_gate + torch.randn(hidden_dim, ffn_dim, device=self.DEVICE) * s, # w_up + torch.randn(ffn_dim, hidden_dim, device=self.DEVICE) * s, # w_down + ] + ) + + mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), device=self.DEVICE), + diagonal=1, + ) + x = torch.randn(seq_len, hidden_dim, device=self.DEVICE) + + result = compiled(x, mask, *all_params) + expected = transformer(x, mask, *all_params) + self.assertEqual(result, expected, atol=atol, rtol=rtol) + + @skip_if_cuda + def test_transformer_tiny(self): + """Test a 4-layer Llama-style transformer at tiny dimensions.""" + self._run_transformer( + num_layers=4, + seq_len=32, + hidden_dim=64, + num_heads=2, + head_dim=32, + ffn_dim=256, + atol=1e-2, + rtol=1e-2, + ) + + @skip_if_cuda + def test_transformer_medium(self): + """Test a 4-layer transformer at Llama-7B-like dimensions.""" + self._run_transformer( + num_layers=4, + seq_len=128, + hidden_dim=4096, + num_heads=32, + head_dim=128, + ffn_dim=11008, + atol=0.1, + rtol=1e-2, + ) + + def _run_transformer_lm( + self, + num_layers, + seq_len, + vocab_size, + hidden_dim, + num_heads, + head_dim, + ffn_dim, + atol=1e-5, + rtol=1.3e-6, + ): + """Run a full Llama-style LM (embedding + layers + norm + lm_head).""" + torch._dynamo.reset() + + def transformer_lm( + token_ids, embed_table, final_rms_w, lm_head_w, mask, *layer_params + ): + # Token embedding via gather + x = embed_table[token_ids] # (T, C) + T, C = x.shape + params_per_layer = 9 + + for i in range(num_layers): + offset = i * params_per_layer + rms_w1 = layer_params[offset] + rms_w2 = layer_params[offset + 1] + w_q = layer_params[offset + 2] + w_k = layer_params[offset + 3] + w_v = layer_params[offset + 4] + w_o = layer_params[offset + 5] + w_gate = layer_params[offset + 6] + w_up = layer_params[offset + 7] + w_down = layer_params[offset + 8] + + # Pre-attention RMSNorm + variance = x.pow(2).mean(-1, keepdim=True) + h = x * torch.rsqrt(variance + 1e-6) * rms_w1 + + # Multi-head self-attention + q = (h @ w_q).view(T, num_heads, head_dim).permute(1, 0, 2) + k = (h @ w_k).view(T, num_heads, head_dim).permute(1, 0, 2) + v = (h @ w_v).view(T, num_heads, head_dim).permute(1, 0, 2) + + scale = 1.0 / (head_dim**0.5) + att = (q @ k.transpose(-2, -1)) * scale + att = att + mask + att = torch.softmax(att, dim=-1) + attn_out = (att @ v).permute(1, 0, 2).contiguous().view(T, C) + + x = x + (attn_out @ w_o) + + # Pre-FFN RMSNorm + variance = x.pow(2).mean(-1, keepdim=True) + h = x * torch.rsqrt(variance + 1e-6) * rms_w2 + + # SwiGLU FFN + gate = torch.nn.functional.silu(h @ w_gate) + up = h @ w_up + x = x + ((gate * up) @ w_down) + + # Final RMSNorm + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + 1e-6) * final_rms_w + + # LM head + logits = x @ lm_head_w # (T, V) + return logits + + compiled = self._compile(transformer_lm) + + s = hidden_dim**-0.5 + embed_table = torch.randn(vocab_size, hidden_dim, device=self.DEVICE) * s + final_rms_w = torch.ones(hidden_dim, device=self.DEVICE) + lm_head_w = torch.randn(hidden_dim, vocab_size, device=self.DEVICE) * s + + all_params = [] + for _ in range(num_layers): + all_params.extend( + [ + torch.ones(hidden_dim, device=self.DEVICE), + torch.ones(hidden_dim, device=self.DEVICE), + torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s, + torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s, + torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s, + torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s, + torch.randn(hidden_dim, ffn_dim, device=self.DEVICE) * s, + torch.randn(hidden_dim, ffn_dim, device=self.DEVICE) * s, + torch.randn(ffn_dim, hidden_dim, device=self.DEVICE) * s, + ] + ) + + mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), device=self.DEVICE), + diagonal=1, + ) + token_ids = torch.randint(0, vocab_size, (seq_len,), device=self.DEVICE) + + result = compiled( + token_ids, embed_table, final_rms_w, lm_head_w, mask, *all_params + ) + expected = transformer_lm( + token_ids, embed_table, final_rms_w, lm_head_w, mask, *all_params + ) + self.assertEqual(result, expected, atol=atol, rtol=rtol) + + @unittest.skip("numerical mismatch in embedding + RMSNorm fusion") + def test_transformer_lm_tiny(self): + """Test a full LM (embed + 4 layers + norm + lm_head) at tiny dims.""" + self._run_transformer_lm( + num_layers=4, + seq_len=32, + vocab_size=256, + hidden_dim=64, + num_heads=2, + head_dim=32, + ffn_dim=256, + ) + + @skip_if_cuda(reason="scalar prefetch not supported in Pallas GPU (Mosaic) backend") + def test_embedding_lookup(self): + """Test simple embedding table lookup (integer indexing).""" + + def fn(token_ids, embed_table): + return embed_table[token_ids] + + compiled = self._compile(fn) + embed_table = torch.randn(256, 64, device=self.DEVICE) + token_ids = torch.randint(0, 256, (32,), device=self.DEVICE) + result = compiled(token_ids, embed_table) + expected = fn(token_ids, embed_table) + self.assertEqual(result, expected) + + @skip_if_cuda(reason="scalar prefetch not supported in Pallas GPU (Mosaic) backend") + def test_indirect_access_basic(self): + """Test bare embedding lookup via indirect access detection.""" + + def fn(indices, table): + return table[indices] + + compiled = self._compile(fn) + table = torch.randn(256, 64, device=self.DEVICE) + indices = torch.randint(0, 256, (32,), device=self.DEVICE) + result = compiled(indices, table) + expected = fn(indices, table) + self.assertEqual(result, expected) + + @skip_if_cuda(reason="scalar prefetch not supported in Pallas GPU (Mosaic) backend") + def test_indirect_access_non_128_dim(self): + """Test indirect access with D not divisible by 128.""" + + def fn(indices, table): + return table[indices] + + compiled = self._compile(fn) + table = torch.randn(256, 100, device=self.DEVICE) + indices = torch.randint(0, 256, (32,), device=self.DEVICE) + result = compiled(indices, table) + expected = fn(indices, table) + self.assertEqual(result, expected) + + @skip_if_cuda(reason="scalar prefetch not supported in Pallas GPU (Mosaic) backend") + def test_indirect_access_large_vocab(self): + """Test indirect access with larger vocabulary.""" + + def fn(indices, table): + return table[indices] + + compiled = self._compile(fn) + table = torch.randn(2048, 128, device=self.DEVICE) + indices = torch.randint(0, 2048, (64,), device=self.DEVICE) + result = compiled(indices, table) + expected = fn(indices, table) + self.assertEqual(result, expected) + + @skip_if_cuda(reason="scalar prefetch not supported in Pallas GPU (Mosaic) backend") + def test_indirect_access_fused_add(self): + """Test embedding + pointwise add fused.""" + + def fn(indices, table, bias): + return table[indices] + bias + + compiled = self._compile(fn) + table = torch.randn(256, 64, device=self.DEVICE) + indices = torch.randint(0, 256, (32,), device=self.DEVICE) + bias = torch.randn(32, 64, device=self.DEVICE) + result = compiled(indices, table, bias) + expected = fn(indices, table, bias) + self.assertEqual(result, expected) + + @skip_if_cuda(reason="scalar prefetch not supported in Pallas GPU (Mosaic) backend") + def test_indirect_access_fused_mul(self): + """Test embedding + pointwise multiply with broadcast.""" + + def fn(indices, table, scale): + return table[indices] * scale + + compiled = self._compile(fn) + table = torch.randn(256, 64, device=self.DEVICE) + indices = torch.randint(0, 256, (32,), device=self.DEVICE) + scale = torch.randn(64, device=self.DEVICE) + result = compiled(indices, table, scale) + expected = fn(indices, table, scale) + self.assertEqual(result, expected) + + @skip_if_cuda(reason="scalar prefetch not supported in Pallas GPU (Mosaic) backend") + def test_indirect_access_fused_chain(self): + """Test embedding + chained add and multiply.""" + + def fn(indices, table, bias, scale): + x = table[indices] + bias + return x * scale + + compiled = self._compile(fn) + table = torch.randn(256, 64, device=self.DEVICE) + indices = torch.randint(0, 256, (32,), device=self.DEVICE) + bias = torch.randn(32, 64, device=self.DEVICE) + scale = torch.randn(64, device=self.DEVICE) + result = compiled(indices, table, bias, scale) + expected = fn(indices, table, bias, scale) + self.assertEqual(result, expected) + + @skip_if_cuda(reason="scalar prefetch not supported in Pallas GPU (Mosaic) backend") + def test_nn_embedding(self): + """Test nn.functional.embedding path through indirect access.""" + + def fn(indices, weight): + return torch.nn.functional.embedding(indices, weight) + + compiled = self._compile(fn) + weight = torch.randn(256, 64, device=self.DEVICE) + indices = torch.randint(0, 256, (32,), device=self.DEVICE) + result = compiled(indices, weight) + expected = fn(indices, weight) + self.assertEqual(result, expected) + + @skip_if_cuda(reason="scalar prefetch not supported in Pallas GPU (Mosaic) backend") + def test_indirect_access_single_token(self): + """Test indirect access with seq=1.""" + + def fn(indices, table): + return table[indices] + + compiled = self._compile(fn) + table = torch.randn(256, 64, device=self.DEVICE) + indices = torch.randint(0, 256, (1,), device=self.DEVICE) + result = compiled(indices, table) + expected = fn(indices, table) + self.assertEqual(result, expected) + + @skip_if_cuda(reason="scalar prefetch not supported in Pallas GPU (Mosaic) backend") + def test_indirect_access_duplicate_indices(self): + """Test indirect access with repeated indices.""" + + def fn(indices, table): + return table[indices] + + compiled = self._compile(fn) + table = torch.randn(256, 64, device=self.DEVICE) + indices = torch.tensor([0, 1, 0, 1, 2, 2, 3, 3], device=self.DEVICE) + result = compiled(indices, table) + expected = fn(indices, table) + self.assertEqual(result, expected) + + @skip_if_cuda + def test_transformer_with_final_norm_and_lm_head(self): + """Test multi-layer transformer + final RMSNorm + LM head (no embedding).""" + torch._dynamo.reset() + num_layers = 4 + seq_len = 32 + hidden_dim = 64 + num_heads = 2 + head_dim = 32 + ffn_dim = 256 + vocab_size = 256 + + def transformer_with_head(x, final_rms_w, lm_head_w, mask, *layer_params): + T, C = x.shape + params_per_layer = 9 + for i in range(num_layers): + offset = i * params_per_layer + rms_w1 = layer_params[offset] + rms_w2 = layer_params[offset + 1] + w_q = layer_params[offset + 2] + w_k = layer_params[offset + 3] + w_v = layer_params[offset + 4] + w_o = layer_params[offset + 5] + w_gate = layer_params[offset + 6] + w_up = layer_params[offset + 7] + w_down = layer_params[offset + 8] + variance = x.pow(2).mean(-1, keepdim=True) + h = x * torch.rsqrt(variance + 1e-6) * rms_w1 + q = (h @ w_q).view(T, num_heads, head_dim).permute(1, 0, 2) + k = (h @ w_k).view(T, num_heads, head_dim).permute(1, 0, 2) + v = (h @ w_v).view(T, num_heads, head_dim).permute(1, 0, 2) + scale = 1.0 / (head_dim**0.5) + att = (q @ k.transpose(-2, -1)) * scale + att = att + mask + att = torch.softmax(att, dim=-1) + attn_out = (att @ v).permute(1, 0, 2).contiguous().view(T, C) + x = x + (attn_out @ w_o) + variance = x.pow(2).mean(-1, keepdim=True) + h = x * torch.rsqrt(variance + 1e-6) * rms_w2 + gate = torch.nn.functional.silu(h @ w_gate) + up = h @ w_up + x = x + ((gate * up) @ w_down) + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + 1e-6) * final_rms_w + return x @ lm_head_w + + compiled = self._compile(transformer_with_head) + s = hidden_dim**-0.5 + final_rms_w = torch.ones(hidden_dim, device=self.DEVICE) + lm_head_w = torch.randn(hidden_dim, vocab_size, device=self.DEVICE) * s + all_params = [] + for _ in range(num_layers): + all_params.extend( + [ + torch.ones(hidden_dim, device=self.DEVICE), + torch.ones(hidden_dim, device=self.DEVICE), + torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s, + torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s, + torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s, + torch.randn(hidden_dim, hidden_dim, device=self.DEVICE) * s, + torch.randn(hidden_dim, ffn_dim, device=self.DEVICE) * s, + torch.randn(hidden_dim, ffn_dim, device=self.DEVICE) * s, + torch.randn(ffn_dim, hidden_dim, device=self.DEVICE) * s, + ] + ) + mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), device=self.DEVICE), + diagonal=1, + ) + x = torch.randn(seq_len, hidden_dim, device=self.DEVICE) + result = compiled(x, final_rms_w, lm_head_w, mask, *all_params) + expected = transformer_with_head(x, final_rms_w, lm_head_w, mask, *all_params) + self.assertEqual(result, expected) + def test_warpgroup_size_2d_aligned_32x8(self): """Test 2D tensor with 32x8 = 256 elements (2 warpgroups).""" diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 6598d02443d39..63a2236bbb6b7 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -30,6 +30,11 @@ from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code from torch._inductor.virtualized import V +from torch._library.opaque_object import ( + get_opaque_type_name, + OpaqueBase, + register_opaque_type, +) from torch.fx.experimental.proxy_tensor import make_fx from torch.testing import FileCheck from torch.testing._internal.common_cuda import SM80OrLater, xfailIfSM89 @@ -47,6 +52,26 @@ aten = torch.ops.aten +class OpaqueScaleFactor(OpaqueBase): + def __init__(self, val): + self.val = val + + def __eq__(self, other): + return isinstance(other, OpaqueScaleFactor) and self.val == other.val + + def __hash__(self): + return hash(self.val) + + def __fx_repr__(self): + return ( + f"OpaqueScaleFactor({self.val!r})", + {"OpaqueScaleFactor": OpaqueScaleFactor}, + ) + + +register_opaque_type(OpaqueScaleFactor, typ="value", hoist=True) + + @instantiate_parametrized_tests class TestPatternMatcher(TestCase): device_type = GPU_TYPE @@ -136,7 +161,6 @@ def _test_fused_int_mm_mul_impl(self, fn, args, fused_int_mm_mul_expected=True): ref[indices], test[indices] ) # also checks that dtype is correct - # @skipIfXpu @skipCUDAIf(not SM80OrLater, "need sm_80") @inductor_config.patch( { @@ -839,23 +863,36 @@ def f(x): joint_graph.joint_graph_passes(gm) self.assertEqual(count_calls(gm.graph), 2) - def test_pointless_convert(self): - def fn1(x): - x = torch.ops.prims.convert_element_type.default(x, torch.float16) - x = torch.ops.prims.convert_element_type.default(x, torch.float32) + @parametrize( + "input_dtype, intermediate_dtype, emulate_precision_casts, expected_calls", + [ + (torch.float32, torch.float16, False, 1), + (torch.float32, torch.float16, True, 2), + (torch.float16, torch.float32, True, 1), + ], + ) + def test_pointless_convert( + self, input_dtype, intermediate_dtype, emulate_precision_casts, expected_calls + ): + def fn(x): + x = torch.ops.prims.convert_element_type.default(x, intermediate_dtype) + x = torch.ops.prims.convert_element_type.default(x, input_dtype) return x - gm = torch.fx.symbolic_trace(fn1) + x = torch.randn(8, device=GPU_TYPE, dtype=input_dtype) + gm = make_fx(fn)(x) self.assertEqual(count_calls(gm.graph), 2) - joint_graph.joint_graph_passes(gm) - self.assertEqual(count_calls(gm.graph), 1) + with inductor_config.patch(emulate_precision_casts=emulate_precision_casts): + joint_graph.joint_graph_passes(gm) + self.assertEqual(count_calls(gm.graph), expected_calls) - def fn2(x): + def fn_int(x): x = torch.ops.prims.convert_element_type.default(x, torch.int32) x = torch.ops.prims.convert_element_type.default(x, torch.float32) return x - gm = torch.fx.symbolic_trace(fn2) + x = torch.randn(8, device=GPU_TYPE, dtype=torch.float32) + gm = make_fx(fn_int)(x) self.assertEqual(count_calls(gm.graph), 2) joint_graph.joint_graph_passes(gm) self.assertEqual(count_calls(gm.graph), 2) @@ -1216,6 +1253,12 @@ def fn2(inp, a, b): FileCheck().check_not("extern_kernels.addmm(").run(code[0]) @parametrize("dtype", [torch.bfloat16, torch.float16]) + @inductor_config.patch( + { + "fx_graph_remote_cache": False, + "keep_addmm_fused_for_half_dtypes": True, + } + ) def test_unfuse_bias_addmm_half_dtypes(self, dtype): args = [ torch.randn(20, device=GPU_TYPE, dtype=dtype), @@ -1232,6 +1275,27 @@ def fn(inp, a, b): _, (code) = run_and_get_code(fn, args[0], args[1], args[2]) FileCheck().check("extern_kernels.addmm(").run(code[0]) + @parametrize("dtype", [torch.bfloat16, torch.float16]) + @inductor_config.patch( + { + "fx_graph_remote_cache": False, + "keep_addmm_fused_for_half_dtypes": False, + } + ) + def test_unfuse_bias_addmm_half_dtypes_when_flag_disabled(self, dtype): + args = [ + torch.randn(20, device=GPU_TYPE, dtype=dtype), + torch.randn(10, 15, device=GPU_TYPE, dtype=dtype), + torch.randn(15, 20, device=GPU_TYPE, dtype=dtype), + ] + + @torch.compile() + def fn(inp, a, b): + return torch.nn.functional.gelu(torch.ops.aten.addmm(inp, a, b)) + + _, (code) = run_and_get_code(fn, args[0], args[1], args[2]) + FileCheck().check_not("extern_kernels.addmm(").run(code[0]) + def test_addmm_alpha_beta_with_pointwise(self): # Test that addmm with alpha/beta != 1 is unfused correctly with pointwise ops # See https://github.com/pytorch/pytorch/issues/167313 @@ -1321,6 +1385,7 @@ def remap_fake_tensor(x): "max_autotune_gemm_backends": "TRITON", } ) + @unittest.skipIf(not IS_BIG_GPU, "templates require big gpu") def test_original_aten_preserved_split_addmm(self): # addmm -> elementwise should be decomposed into mm -> add -> elementwise def fn(x, y, z): @@ -1895,6 +1960,60 @@ def f(x): self.assertEqual(len(sigmoid_nodes), 1) self.assertTrue("original_aten" in sigmoid_nodes[0].meta) + def test_fwd_only_uses_get_decomp_fn(self): + """fwd_only traces the pattern graph using the table from get_decomp_fn.""" + + def fn(x): + return F.gelu(x) + + x = torch.randn(4, 4, device=GPU_TYPE) + + # Default: gelu decomposes into erf/mul/add primitives. + gm = fwd_only(fn, args=[x]) + targets = {n.target for n in gm.graph.nodes if n.op == "call_function"} + self.assertNotIn(aten.gelu.default, targets) + self.assertIn(aten.erf.default, targets) + + # Empty decomp table: gelu stays intact as aten.gelu.default. + gm_nodec = fwd_only(fn, args=[x], get_decomp_fn=dict) + targets_nodec = { + n.target for n in gm_nodec.graph.nodes if n.op == "call_function" + } + self.assertIn(aten.gelu.default, targets_nodec) + self.assertNotIn(aten.erf.default, targets_nodec) + + def test_register_replacement_get_decomp_fn(self): + """A pattern registered with get_decomp_fn matches graphs traced with + the same decomposition table, and does not match graphs traced with a + different one.""" + + def gelu_pattern(x): + return F.gelu(x) + + def gelu_double(x): + return F.gelu(x) * 2.0 + + x = torch.randn(4, 4, device=GPU_TYPE) + + # Pattern traced without decompositions: stored as aten.gelu.default. + my_patterns = PatternMatcherPass() + register_replacement( + gelu_pattern, + gelu_double, + [x], + fwd_only, + my_patterns, + get_decomp_fn=dict, + ) + + # Graph where gelu is also not decomposed: pattern should match once. + gm_nodec = make_fx(gelu_pattern, {})(x) + self.assertEqual(my_patterns.apply(gm_nodec.graph), 1) + + # Graph where gelu is decomposed (default decomps): no match. + gm_decomposed = fwd_only(gelu_pattern, args=[x]) + self.assertEqual(my_patterns.apply(gm_decomposed.graph), 0) + @inductor_config.patch(is_predispatch=True) def test_remove_noop_pass_with_remove_passes(self): def fn_with_noop(x): @@ -2069,6 +2188,214 @@ def fn(x): self.assertEqual(result, x * 3) self.assertEqual(count, 1) + def test_register_replacement_single_tensor_input(self): + def pattern(x): + return x + 1 + + def replacement(x): + return x - 1 + + my_patterns = PatternMatcherPass() + + # Single tensor should fail fast instead of reaching tracing logic. + example_input = torch.randn(4, 4, device=GPU_TYPE, requires_grad=True) + with self.assertRaisesRegex( + TypeError, + f"example_inputs must be a list or tuple, got {type(example_input)}", + ): + register_replacement( + pattern, replacement, example_input, fwd_only, my_patterns + ) + + def _inject_test_metadata(self, graph): + """Inject identifiable metadata on all call_function nodes for testing.""" + for node in graph.nodes: + if node.op == "call_function": + node.meta["stack_trace"] = f"trace_for_{node.name}" + node.meta["nn_module_stack"] = {"test": ("m", "M")} + node.meta["_numeric_debug_handle"] = 42 + node.meta["custom"] = {"test_key": "test_value"} + + def test_metadata_propagation_register_replacement(self): + """Verify metadata from matched nodes transfers to replacement nodes.""" + + def pattern(x, y): + return x + y + + def replacement(x, y): + return x * y + + my_patterns = PatternMatcherPass() + inputs = [ + torch.randn(4, 4, device=GPU_TYPE), + torch.randn(4, 4, device=GPU_TYPE), + ] + register_replacement(pattern, replacement, inputs, fwd_only, my_patterns) + + def custom_pass(graph: torch.fx.Graph): + self._inject_test_metadata(graph) + # _transfer_meta runs inside replace_with_graph for + # each old->new pair to propagate metadata fields + my_patterns.apply(graph) + + def fn(x, y): + return x + y + + x = torch.randn(4, 4, device=GPU_TYPE) + y = torch.randn(4, 4, device=GPU_TYPE) + + compiled_fn = torch.compile( + fn, options={"post_grad_custom_post_pass": custom_pass} + ) + compiled_fn(x, y) + + def test_metadata_propagation_register_replacement_multinode(self): + """Verify metadata propagation for multi-node patterns.""" + + def pattern(x, y): + tmp = x + y + return tmp * 2 + + def replacement(x, y): + return (x + y) * 3 + + my_patterns = PatternMatcherPass() + inputs = [ + torch.randn(4, 4, device=GPU_TYPE), + torch.randn(4, 4, device=GPU_TYPE), + ] + register_replacement(pattern, replacement, inputs, fwd_only, my_patterns) + + def custom_pass(graph: torch.fx.Graph): + self._inject_test_metadata(graph) + my_patterns.apply(graph) + + def fn(x, y): + tmp = x + y + return tmp * 2 + + x = torch.randn(4, 4, device=GPU_TYPE) + y = torch.randn(4, 4, device=GPU_TYPE) + + compiled_fn = torch.compile( + fn, options={"post_grad_custom_post_pass": custom_pass} + ) + compiled_fn(x, y) + + def test_metadata_propagation_graph_pattern_replace_by_example(self): + """Verify metadata propagation for replace_by_example (single-node match).""" + + test_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunction(aten.add.Tensor, KeywordArg("x"), KeywordArg("y")), + pass_dict=test_pass, + ) + def add_to_mul(match: Match, x, y): + def repl(a, b): + return a * b + + with V.fake_mode: + match.replace_by_example(repl, [x, y]) + + def custom_pass(graph: torch.fx.Graph): + self._inject_test_metadata(graph) + test_pass.apply(graph) + + def fn(x, y): + return x + y + + x = torch.randn(4, 4, device=GPU_TYPE) + y = torch.randn(4, 4, device=GPU_TYPE) + + compiled_fn = torch.compile( + fn, options={"post_grad_custom_post_pass": custom_pass} + ) + compiled_fn(x, y) + + def test_metadata_propagation_replace_by_example_multinode(self): + """Verify metadata propagation for replace_by_example with a multi-node + match. The output node should inherit metadata from the matched output + node via replace_with_graph's replace() inner function.""" + + test_pass = PatternMatcherPass() + + @register_graph_pattern( + CallFunction( + aten.mul.Tensor, + CallFunction(aten.add.Tensor, KeywordArg("x"), KeywordArg("y")), + KeywordArg("z"), + ), + pass_dict=test_pass, + ) + def add_mul_to_sub(match: Match, x, y, z): + def repl(a, b, c): + return (a - b) * c + + with V.fake_mode: + match.replace_by_example(repl, [x, y, z]) + + def custom_pass(graph: torch.fx.Graph): + self._inject_test_metadata(graph) + test_pass.apply(graph) + + def fn(x, y, z): + return (x + y) * z + + x = torch.randn(4, 4, device=GPU_TYPE) + y = torch.randn(4, 4, device=GPU_TYPE) + z = torch.randn(4, 4, device=GPU_TYPE) + + compiled_fn = torch.compile( + fn, options={"post_grad_custom_post_pass": custom_pass} + ) + compiled_fn(x, y, z) + + def test_metadata_propagation_lowering_pattern(self): + """Verify metadata propagation for LoweringPatternEntry.apply. + + LoweringPatternEntry uses _transfer_meta to copy _COPY_META_FIELDS + and stack_trace from the matched node to the replacement. + """ + from torch._inductor.pattern_matcher import _transfer_meta + + test_pass = PatternMatcherPass() + + counter = 0 + + @register_graph_pattern( + CallFunction(aten.add.Tensor, KeywordArg("x"), KeywordArg("y")), + pass_dict=test_pass, + ) + def manual_lowering(match: Match, x, y): + nonlocal counter + # Manually exercise the LoweringPatternEntry code path: + # create a replacement node, propagate meta, and replace + node = match.output_node() + graph = match.graph + with graph.inserting_before(node): + replacement = graph.call_function(aten.mul.Tensor, (x, y)) + _transfer_meta(replacement.meta, node) + node.replace_all_uses_with(replacement) + match.erase_nodes() + counter += 1 + + def custom_pass(graph: torch.fx.Graph): + self._inject_test_metadata(graph) + test_pass.apply(graph) + + def fn(x, y): + return x + y + + x = torch.randn(4, 4, device=GPU_TYPE) + y = torch.randn(4, 4, device=GPU_TYPE) + + compiled_fn = torch.compile( + fn, options={"post_grad_custom_post_pass": custom_pass} + ) + compiled_fn(x, y) + self.assertEqual(counter, 1) + class TestPatternMatcherLogging(LoggingTestCase): device_type = GPU_TYPE @@ -2149,6 +2476,98 @@ def fn(x): specific_record.getMessage(), ) + @make_logging_test() + def test_pattern_match_debug_multiple_nodes(self, records): + def pattern_add(x, y): + return x + y + + def replacement_add(x, y): + return x * y + + def pattern_sub(x, y): + return x - y + + def replacement_sub(x, y): + return x * y + + my_patterns = PatternMatcherPass() + inputs = [ + torch.randn(4, 4, device=GPU_TYPE), + torch.randn(4, 4, device=GPU_TYPE), + ] + register_replacement( + pattern_add, replacement_add, inputs, fwd_only, my_patterns + ) + register_replacement( + pattern_sub, replacement_sub, inputs, fwd_only, my_patterns + ) + + def custom_pass(graph: torch.fx.Graph): + return my_patterns.apply(graph) + + def fn(x, y): + return (x + y) + (x - y) + + x = torch.randn(4, 4, device=GPU_TYPE) + y = torch.randn(4, 4, device=GPU_TYPE) + + # Debug both "add" and "sub" nodes + with unittest.mock.patch.dict( + os.environ, {"TORCHINDUCTOR_PATTERN_MATCH_DEBUG": "add,sub"} + ): + compiled_fn = torch.compile( + fn, options={"post_grad_custom_post_pass": custom_pass} + ) + _ = compiled_fn(x, y) + + self.assertTrue(self.hasRecord(records, "Specific pattern match: add")) + self.assertTrue(self.hasRecord(records, "Specific pattern match: sub")) + + @make_logging_test() + def test_pattern_match_debug_all_nodes(self, records): + def pattern_add(x, y): + return x + y + + def replacement_add(x, y): + return x * y + + def pattern_sub(x, y): + return x - y + + def replacement_sub(x, y): + return x * y + + my_patterns = PatternMatcherPass() + inputs = [ + torch.randn(4, 4, device=GPU_TYPE), + torch.randn(4, 4, device=GPU_TYPE), + ] + register_replacement( + pattern_add, replacement_add, inputs, fwd_only, my_patterns + ) + register_replacement( + pattern_sub, replacement_sub, inputs, fwd_only, my_patterns + ) + + def custom_pass(graph: torch.fx.Graph): + return my_patterns.apply(graph) + + def fn(x, y): + return (x + y) + (x - y) + + x = torch.randn(4, 4, device=GPU_TYPE) + y = torch.randn(4, 4, device=GPU_TYPE) + + with unittest.mock.patch.dict( + os.environ, {"TORCHINDUCTOR_PATTERN_MATCH_DEBUG": "all"} + ): + compiled_fn = torch.compile( + fn, options={"post_grad_custom_post_pass": custom_pass} + ) + _ = compiled_fn(x, y) + self.assertTrue(self.hasRecord(records, "Specific pattern match: add")) + self.assertTrue(self.hasRecord(records, "Specific pattern match: sub")) + def test_gumbel_max_trick(self): counters.clear() @@ -2186,6 +2605,135 @@ def sample(logits, temperature): self.assertTrue(counters["inductor"]["apply_gumbel_max_trick"] == 1) + def test_per_pattern_counter(self): + """Test that per-pattern counters track individual pattern matches""" + with inductor_config.patch(fx_graph_cache=False): + with unittest.mock.patch.dict( + os.environ, {"TORCHINDUCTOR_PATTERN_MATCH_DEBUG": "1"} + ): + counters.clear() + + def fn(x, y): + return torch.bmm(x, y) + + x = torch.randn(4, 10, 10, device=GPU_TYPE) + y = torch.randn(4, 10, 10, device=GPU_TYPE) + + compiled = torch.compile(fn) + compiled(x, y) + + counter_key = "inductor_pattern_matcher_per_pattern" + per_pattern = counters.get(counter_key, None) + + self.assertIsInstance(per_pattern, dict) + self.assertGreater(len(per_pattern), 0) + self.assertIn("CallFunction_aten.bmm.default", per_pattern) + self.assertEqual(per_pattern["CallFunction_aten.bmm.default"], 1) + + def test_per_pattern_counter_accumulation(self): + """Test that per-pattern counters accumulate across compilations""" + with inductor_config.patch(fx_graph_cache=False): + with unittest.mock.patch.dict( + os.environ, {"TORCHINDUCTOR_PATTERN_MATCH_DEBUG": "1"} + ): + counter_key = "inductor_pattern_matcher_per_pattern" + + counters.clear() + + x = torch.randn(2, 10, 10, device=GPU_TYPE) + y = torch.randn(2, 10, 10, device=GPU_TYPE) + + def fn1(a, b): + return torch.bmm(a, b) + + compiled1 = torch.compile(fn1) + compiled1(x, y) + count1 = sum(counters.get(counter_key, {}).values()) + + # Compile second function without clearing counters + def fn2(a, b): + return torch.bmm(a, b) * 2 + + compiled2 = torch.compile(fn2) + compiled2(x, y) + accumulated_count = sum(counters.get(counter_key, {}).values()) + + # Verify accumulation + counters.clear() + torch._dynamo.reset() + compiled2 = torch.compile(fn2) + compiled2(x, y) + count2 = sum(counters.get(counter_key, {}).values()) + + self.assertEqual(accumulated_count, count1 + count2) + + def test_opaque_obj_custom_op(self): + with torch.library._scoped_library("_test_pm", "FRAGMENT") as lib: + lib.define( + f"original_op(Tensor x, {get_opaque_type_name(OpaqueScaleFactor)} s) -> Tensor" + ) + lib.impl("original_op", lambda x, s: x * s.val, "CompositeExplicitAutograd") + + @torch.library.register_fake("_test_pm::original_op", lib=lib) + def _orig_fake(x, s): + return torch.empty_like(x) + + lib.define( + f"replacement_op(Tensor x, {get_opaque_type_name(OpaqueScaleFactor)} s) -> Tensor" + ) + lib.impl( + "replacement_op", lambda x, s: x + s.val, "CompositeExplicitAutograd" + ) + + @torch.library.register_fake("_test_pm::replacement_op", lib=lib) + def _repl_fake(x, s): + return torch.empty_like(x) + + def pattern(x, factor): + return torch.ops._test_pm.original_op(x, factor) + + def replacement(x, factor): + return torch.ops._test_pm.replacement_op(x, factor) + + patterns = PatternMatcherPass() + register_replacement( + pattern, + replacement, + [torch.randn(4, 4), OpaqueScaleFactor(2.0)], + fwd_only, + patterns, + ) + + count = 0 + post_pass_graph = None + + def custom_pass(graph): + nonlocal count, post_pass_graph + count = patterns.apply(graph) + post_pass_graph = graph + return graph + + def custom_backend(graph, example_inputs): + from torch._inductor.compile_fx import compile_fx + + current_config = inductor_config.get_config_copy() + current_config["post_grad_custom_post_pass"] = custom_pass + return compile_fx(graph, example_inputs, config_patches=current_config) + + @torch.compile(backend=custom_backend) + def f(x, s): + return torch.ops._test_pm.original_op(x, s) + + inp = torch.randn(4, 4) + result = f(inp, OpaqueScaleFactor(4.0)) + self.assertEqual(result, inp + 4.0) + self.assertEqual(count, 1) + op_targets = [ + n.target for n in post_pass_graph.nodes if n.op == "call_function" + ] + self.assertNotIn(torch.ops._test_pm.original_op.default, op_targets) + self.assertIn(torch.ops._test_pm.replacement_op.default, op_targets) + if __name__ == "__main__": if IS_LINUX and HAS_GPU: diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index cbb2a558b8865..4e168fd0b4c21 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -1,6 +1,7 @@ # Owner(s): ["module: inductor"] import contextlib import re +import unittest from unittest.mock import patch import functorch @@ -28,6 +29,7 @@ # performance for that setting. # # Defines all the kernels for tests +from torch.testing._internal.common_utils import skipIfXpu, TEST_WITH_ROCM from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU_AND_TRITON from torch.testing._internal.triton_utils import requires_gpu_and_triton @@ -337,6 +339,35 @@ def f(a, b): inp = (T(10), TI(10, mx=10)) self.assertExpectedInline(count_numel(f, *inp), """30""") + @requires_gpu_and_triton + def test_delay_realize_cheap_outputs_shared_mask(self): + # Shared tril mask across multiple users gets eagerly materialized + # as an output buffer, inflating downstream read counts. With + # delay_realize_cheap_outputs, the mask stays inlined as index + # arithmetic, reducing memory traffic. + def f(x1, x2, x3): + mask = torch.tril(torch.ones(32, 32, device=x1.device)) + return x1 + mask, x2 + mask, x3 + mask + + inp = ( + T(32, 32, grad=True), + T(32, 32, grad=True), + T(32, 32, grad=True), + ) + + # Without deferred realization: mask gets materialized as output buffer + with patch.object(config, "delay_realize_cheap_outputs", False): + metrics.reset() + torch.compile(f, backend=compile_but_use_eager)(*inp) + eager_bytes = metrics.num_bytes_accessed + + # Default (deferred realization on): mask stays inlined + metrics.reset() + torch.compile(f, backend=compile_but_use_eager)(*inp) + deferred_bytes = metrics.num_bytes_accessed + + self.assertLessEqual(deferred_bytes, eager_bytes) + class FusionTests(TestCase): """ @@ -484,6 +515,26 @@ def f(a, b, c): inp = (T(10, 10), T(10, 10), T(10, 10)) self.assertExpectedInline(count_numel(f, *inp), """500""") + @skipIfXpu(msg="copy_(cat()) fusion not supported on XPU") + @unittest.skipIf(TEST_WITH_ROCM, "copy_(cat()) fusion not supported on ROCm") + # TODO(ivankobzarev): enable copy_(cat()) fusion for CUDA 13+ + @unittest.skipIf( + torch.version.cuda + and tuple(int(x) for x in torch.version.cuda.split(".")) >= (13, 0), + "copy_(cat()) fusion not supported on CUDA 13+", + ) + def test_copy_cat_fusion(self): + """copy_(cat(...)) should fuse: no intermediate allocation for cat.""" + + def f(dst, a, b): + dst.copy_(torch.cat([a, b])) + + dst = T(20) + inp = (dst, T(10), T(10)) + # 10 (read a) + 10 (read b) + 20 (write dst) = 40 + # Without fusion cat would allocate intermediate: 80 + self.assertExpectedInline(count_numel(f, *inp), """40""") + def test_reduction_pointwise_multi_level_reduction(self): hidden_size = 4096 layer_norm = torch.nn.LayerNorm(hidden_size).to(GPU_TYPE).float() @@ -505,7 +556,10 @@ def f(x, scale, amax_keep_dim): expected_numel = ( 1 + hidden_size * 2 + 4 * 2048 * hidden_size * 2 + 4 * 2048 * 2 + 1 ) - if config.triton.cooperative_reductions: + if ( + config.triton.cooperative_reductions + or config.triton.force_cooperative_reductions + ): expected_numel = 134225922 self.assertExpectedInline(count_numel(f, *inp, True), str(expected_numel)) diff --git a/test/inductor/test_profiler.py b/test/inductor/test_profiler.py index 99f06d5647dcf..fefa7fa20f585 100644 --- a/test/inductor/test_profiler.py +++ b/test/inductor/test_profiler.py @@ -11,7 +11,7 @@ from torch import _dynamo as torchdynamo from torch._inductor import config from torch.profiler import ProfilerActivity, record_function -from torch.testing._internal.common_utils import skipIfXpu, TemporaryFileName +from torch.testing._internal.common_utils import TemporaryFileName from torch.testing._internal.inductor_utils import ( GPU_TYPE, HAS_GPU_AND_TRITON, @@ -26,10 +26,6 @@ class DynamoProfilerTests(torch._inductor.test_case.TestCase): - @skipIfXpu( - msg="AssertionError: False is not true, " - "https://github.com/intel/torch-xpu-ops/issues/2335" - ) @unittest.skipIf(not HAS_TRITON, "requires cuda & triton") def test_inductor_profiling_triton_launch(self): # Verify that we get some sort of CPU-side indication of triton kernel launches @@ -224,9 +220,6 @@ def fn(x, y): self.assertTrue(hooks_called["enter"]) self.assertTrue(hooks_called["exit"]) - @skipIfXpu( - msg="TypeError: list indices must be integers or slices, not str, https://github.com/intel/torch-xpu-ops/issues/2335" - ) @unittest.skipIf(not HAS_TRITON, "requires cuda & triton") def test_pt2_triton_attributes(self): from torch._inductor.codecache import code_hash @@ -349,7 +342,7 @@ def fn(x): sin: "f32[10][1]cpu" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None cos: "f32[10][1]cpu" = torch.ops.aten.cos.default(sin); sin = None add: "f32[10][1]cpu" = torch.ops.aten.add.Tensor(cos, 2); cos = None - return (add,)""", # noqa: B950 + return (add,)""", ignore_comments=True, ignore_empty_lines=True, ) diff --git a/test/inductor/test_provenance_tracing.py b/test/inductor/test_provenance_tracing.py index c58951ef8c7a1..a5e3e8dfc4ad0 100644 --- a/test/inductor/test_provenance_tracing.py +++ b/test/inductor/test_provenance_tracing.py @@ -29,10 +29,7 @@ from torch._inductor.virtualized import V from torch.testing._internal.common_utils import IS_MACOS from torch.testing._internal.inductor_utils import GPU_TYPE -from torch.testing._internal.triton_utils import ( - requires_cuda_and_triton, - requires_gpu_and_triton, -) +from torch.testing._internal.triton_utils import requires_gpu_and_triton try: @@ -613,7 +610,7 @@ def test_tlparse_kernel_stack_traces(self): @torch._inductor.config.patch( {"trace.provenance_tracking_level": 2, "max_autotune_gemm_backends": "ATEN"} ) - @requires_cuda_and_triton + @requires_gpu_and_triton def test_deferred_triton_kernels(self): def foo(m, inp): a = m(inp) @@ -621,8 +618,8 @@ def foo(m, inp): foo_c = torch.compile(mode="max-autotune-no-cudagraphs")(foo) - m = torch.nn.Linear(512, 512, bias=True).half().cuda() - inp = torch.rand([1, 512]).half().cuda() + m = torch.nn.Linear(512, 512, bias=True).half().to(GPU_TYPE) + inp = torch.rand([1, 512]).half().to(GPU_TYPE) with self._setup_provenance_capture() as payload_buffer: with torch.no_grad(): @@ -902,16 +899,21 @@ def forward(self, x, a, b, c): code ) - if self.device == "cuda": + if self.device == "cuda" or self.device == "xpu": + device_type = torch.accelerator.current_accelerator().type FileCheck().check( - """KernelContextGuard _ctx("aoti_torch_cuda_mm_out", R"(""" - ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(").check( + f"""KernelContextGuard _ctx("aoti_torch_{device_type}_mm_out", R"(""" + ).check( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_{device_type}_mm_out(" + ).check( """KernelContextGuard _ctx("triton_poi_fused_addmm_relu_sigmoid_0", R"(""" ).check("call_triton_poi_fused_addmm_relu_sigmoid_0(").check( """KernelContextGuard _ctx("triton_poi_fused_mul_1", R"(""" ).check("call_triton_poi_fused_mul_1(").check( - """KernelContextGuard _ctx("aoti_torch_cuda_mm_out", R"(""" - ).check("AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_cuda_mm_out(").check( + f"""KernelContextGuard _ctx("aoti_torch_{device_type}_mm_out", R""" + ).check( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_{device_type}_mm_out(" + ).check( """ KernelContextGuard _ctx("triton_poi_fused_addmm_gelu_2", R"(""" ).check("call_triton_poi_fused_addmm_gelu_2(").run(code) else: @@ -942,15 +944,17 @@ class TestProvenanceTracingKernelContextCpu(TestCase): @unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS") -@unittest.skipIf(not torch.cuda.is_available(), "No CUDA") +@unittest.skipIf( + not torch.cuda.is_available() and not torch.xpu.is_available(), "No CUDA and no XPU" +) class TestProvenanceTracingKernelContextGpu(TestCase): - device = "cuda" + device = GPU_TYPE copy_tests( ProvenanceTracingKernelContextTemplate, TestProvenanceTracingKernelContextGpu, - "cuda", + GPU_TYPE, ) diff --git a/test/inductor/test_quantization.py b/test/inductor/test_quantization.py index acc07c454c94d..51f7f4f205092 100644 --- a/test/inductor/test_quantization.py +++ b/test/inductor/test_quantization.py @@ -47,6 +47,29 @@ def forward(self, x): return x +class SharedOutputAndSavedModule(torch.nn.Module): + """A module where a 3D intermediate is both a user output and saved for backward. + + This triggers the bug in T264303372: the activation quantization pass would + quantize both the user output and saved-for-backward positions of the same + tensor, creating duplicate backward placeholders that shift the stride mapping. + """ + + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.ones(8, 8)) + self.W = torch.nn.Parameter(torch.randn(64, 32)) + + def forward(self, x): + h = x.view(x.shape[0], 8, 8) + attn = h * torch.sigmoid(h) + # attn is both returned (user output) and needed by backward (mul saves it) + scaled = attn * self.scale + flat = scaled.flatten(1) + result = flat @ self.W + return result, attn + + class LayernormNN(torch.nn.Module): def __init__(self): super().__init__() @@ -229,6 +252,37 @@ def test_activation_quantization_aten_without_scaling(self): self.assertTrue(torch.allclose(ref, res)) counters.clear() + @requires_gpu() + @torch._inductor.config.patch( + pre_grad_fusion_options={}, + post_grad_fusion_options={ + "activation_quantization_aten_pass": { + "quant_type": "torch.float8_e5m2", + "use_scaling": True, + "size_in_mb": 0.0, + "exclude_primals": True, + "allowed_dtypes": "torch.bfloat16;torch.float32", + }, + }, + ) + def test_activation_quantization_shared_output_and_saved(self): + """Test that activation quantization works when a tensor is both a user + output and a saved-for-backward activation (T264303372).""" + counters.clear() + module = SharedOutputAndSavedModule().to(GPU_TYPE) + x = torch.randn(8, 64, device=GPU_TYPE, requires_grad=True) + compiled = torch.compile(module) + result, attn = compiled(x) + loss = result.sum() + attn.sum() + loss.backward() + self.assertEqual( + counters["inductor"]["activation_quantization_fwd_aten_pass"], 1 + ) + self.assertEqual( + counters["inductor"]["activation_quantization_bwd_aten_pass"], 1 + ) + counters.clear() + if __name__ == "__main__": if IS_LINUX and HAS_GPU: diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py index c014c50747a27..f17084bacf49e 100644 --- a/test/inductor/test_select_algorithm.py +++ b/test/inductor/test_select_algorithm.py @@ -25,7 +25,9 @@ from torch._inductor.kernel_inputs import KernelInputs from torch._inductor.select_algorithm import ( autotune_select_algorithm, + ExternalTritonTemplateKernel, ExternKernelChoice, + PartialRender, TritonTemplate, TritonTemplateKernel, ) @@ -228,6 +230,23 @@ def foo(a, b): if not torch.version.hip: # autotuning is not guaranteed to run on ROCm self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + @patches + def test_bmm_small_m(self): + # Verify BMM works when M < BLOCK_M. The triton_bmm template's + # tl.max_contiguous/tl.multiple_of hints must be guarded by + # M >= BLOCK_M to avoid out-of-bounds vectorized loads (see #179267). + @torch.compile + def foo(a, b): + return torch.bmm(a, b) + + # M=2 is smaller than any BLOCK_M autotuning config (typically >= 16) + foo( + torch.randn(4, 2, 64, device=GPU_TYPE), + torch.randn(4, 64, 32, device=GPU_TYPE), + ) + if not torch.version.hip: + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + @patches def test_mm_not_even_k(self): @torch.compile @@ -801,6 +820,73 @@ def patch_lowering(lowering_overrides) -> Callable[[], None]: yield +class TestDtypeViewAutotuning(TestCase): + @requires_gpu() + @unittest.skipIf( + not hasattr(torch, "float4_e2m1fn_x2"), + "float4_e2m1fn_x2 dtype not available", + ) + @patches + def test_benchmark_example_value_preserves_dtype_view(self): + """ + Verify that benchmark_example_value preserves the dtype when the IR + node is a dtype view (e.g. uint8 storage viewed as float4_e2m1fn_x2). + """ + from torch._inductor import ir + + m, k = 256, 2048 + device = torch.device(GPU_TYPE) + + # We need a V.graph context for benchmark_example_value. + # Compile a trivial function to set up V.graph, then call the + # function under test within the compilation callback. + captured_results = {} + + orig_lowering = torch._inductor.lowering.lowerings[aten.add.Tensor] + + def patched_lowering(*args, **kwargs): + # We're inside compilation — V.graph is valid. + # Construct a dtype-viewed IR node and test benchmark_example_value. + base_layout = ir.FixedLayout( + device=device, dtype=torch.uint8, size=[m, k], stride=[k, 1] + ) + base_buf = ir.Buffer(name="test_buf", layout=base_layout) + + fp4_layout = ir.FixedLayout( + device=device, + dtype=torch.float4_e2m1fn_x2, + size=[m, k], + stride=[k, 1], + ) + fp4_view = ir.ReinterpretView(data=base_buf, layout=fp4_layout) + + example = select_algorithm.AlgorithmSelectorCache.benchmark_example_value( + fp4_view + ) + captured_results["ir_dtype"] = fp4_view.get_dtype() + captured_results["example_dtype"] = example.dtype + captured_results["example_shape"] = tuple(example.shape) + + return orig_lowering(*args, **kwargs) + + torch._dynamo.reset() + with patch.dict( + torch._inductor.lowering.lowerings, + {aten.add.Tensor: patched_lowering}, + ): + compiled_fn = torch.compile(lambda x: x + 1) + compiled_fn(torch.randn(4, device=device)) + + self.assertIn("ir_dtype", captured_results, "Patched lowering was not called") + self.assertEqual( + captured_results["example_dtype"], + torch.float4_e2m1fn_x2, + f"benchmark_example_value should preserve float4_e2m1fn_x2 dtype " + f"after unwrapping the view, but got {captured_results['example_dtype']}", + ) + self.assertEqual(captured_results["example_shape"], (m, k)) + + class TestTemplateRender(TestCase): @requires_gpu() @requires_triton() @@ -899,6 +985,171 @@ def add(a, b): kernels[0] ) + @requires_gpu() + @requires_triton() + @config.patch(cuda_backend="triton") + def test_external_template_prologue_epilogue_fusion(self): + """ + Tests prologue fusion, epilogue fusion, and extra inputs through the + ExternalTritonTemplateKernel render()-based path. + + Compiled function: relu(template_add(a, sigmoid(b))) * bias + - Prologue: sigmoid(b) fused into template as + - Epilogue: relu(...) * bias fused into template as + - Extra inputs: bias is read by the epilogue but is not among the + template's original inputs, exercising kernel._extra_inputs + """ + import torch._inductor.ir as ir + from torch._inductor.ir import OrderedSet + from torch._inductor.utils import Placeholder, run_and_get_code + + XBLOCK = 128 + render_called = [False] + + # Template source with placeholders filled in by _render() + _MOCK_ADD_KERNEL_TEMPLATE = ( + "import triton\n" + "import triton.language as tl\n" + "import torch\n" + "from torch._inductor.runtime import triton_helpers\n" + "\n" + "@triton.jit\n" + "def _mock_inner_add(A, B, {out_param}{extra_sig}," + " numel, XBLOCK: tl.constexpr):\n" + " xoffset = tl.program_id(0) * XBLOCK\n" + " xindex = xoffset + tl.arange(0, XBLOCK)\n" + " xmask = xindex < numel\n" + " a = tl.load(A + xindex, mask=xmask)\n" + "{prologue_load_b}" + " _kernel_val_0 = a + b\n" + " x_epilogue0_0 = xindex\n" + " _tile_mask_0 = xmask\n" + " \n" + "\n" + "def {kernel_name}(A, B, {out_param}{extra_sig}, numel):\n" + " grid = ((numel + {xblock} - 1) // {xblock},)\n" + " _mock_inner_add[grid](" + "A, B, {out_param}{extra_sig}, numel, XBLOCK={xblock})\n" + " return {out_param}\n" + ) + + class _MockExternalTemplateBuffer(ir.TemplateBuffer): + def __init__(self, layout, inputs): + tb_self = self + + def _make_kernel_render(out_node, hint_override=None): + kernel = ExternalTritonTemplateKernel(tb_self) + + def render(): + return tb_self._render(kernel) + + return kernel, render + + super().__init__( + layout, + inputs, + _make_kernel_render, + named_inputs={"A": inputs[0], "B": inputs[1]}, + ) + # Allow prologue fusion on input B (sigmoid(b) can be + # absorbed so the template reads b directly) + self.allowed_prologue_inps = OrderedSet([inputs[1].get_name()]) + self.epilogue_fusable_outputs = {self.name: "result"} + + def _render(self, kernel): + render_called[0] = True + + # Set up all fusion hooks in one call + kernel._setup_fusion_hooks() + + # --- Prologue handling for B --- + b_arg = self.inputs[1].get_name() + prologue_load_b = " b = tl.load(B + xindex, mask=xmask)\n" + if kernel._prologue_source_buffers.get("B") is not None: + b_arg = kernel._prologue_source_buffers["B"] + prologue_load_b = ( + " _prologue_B_xindex = xindex\n" + " _prologue_B_xmask = xmask\n" + " \n" + " b = _prologue_B_result\n" + ) + + call_args = [self.inputs[0].get_name(), b_arg] + + # --- Epilogue handling --- + out_param = "result" + out_arg = self.name + for buf, param in kernel._extra_store_targets.items(): + out_param = param + out_arg = buf + break + + call_args.append(out_arg) + + # --- Extra inputs --- + extra_params = [] + for buf_name, param_name in kernel._extra_inputs.items(): + extra_params.append(param_name) + call_args.append(buf_name) + + numel = self.get_size()[0] + call_args.append(str(numel)) + + extra_sig = ", " + ", ".join(extra_params) if extra_params else "" + + source = _MOCK_ADD_KERNEL_TEMPLATE.format( + out_param=out_param, + extra_sig=extra_sig, + prologue_load_b=prologue_load_b, + kernel_name=str(Placeholder.KERNEL_NAME), + xblock=XBLOCK, + ) + + kernel._call_preamble = [] + kernel._call_args = call_args + + return PartialRender(source, kernel.render_hooks) + + def add_override(a, b, alpha=None): + layout = FixedLayout(a.get_device(), a.get_dtype(), a.get_size()) + a = ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(a)) + b = ir.ExternKernel.require_stride1(ir.ExternKernel.realize_input(b)) + return ir.TensorBox.create(_MockExternalTemplateBuffer(layout, [a, b])) + + # (override_fn, decompose, type_promotion, convert_input_to_bool) + with patch_lowering( + { + torch.ops.aten.add.Tensor: ( + add_override, + True, + ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + False, + ) + } + ): + + @torch.compile + def f(a, b, bias): + # Use * for bias so it doesn't trigger add_override again + return torch.relu(a + torch.sigmoid(b)) * bias + + a = torch.randn(32, device=GPU_TYPE) + b = torch.randn(32, device=GPU_TYPE) + bias = torch.randn(32, device=GPU_TYPE) + + result, (code,) = run_and_get_code(f, a, b, bias) + expected = torch.relu(a + torch.sigmoid(b)) * bias + torch.testing.assert_close(result, expected) + + # Verify render() was called (new protocol) + self.assertTrue(render_called[0]) + # Verify template kernel was used + self.assertIn("_mock_inner_add", code) + # Verify epilogue fusion: relu fused via hook + self.assertIn("triton_helpers.maximum", code) + # Verify prologue fusion: sigmoid fused via hook + self.assertIn("tl.sigmoid", code) + if __name__ == "__main__": if IS_LINUX and HAS_GPU and is_big_gpu(): diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py index aae07ba53d61e..96a8ead033c98 100644 --- a/test/inductor/test_split_cat_fx_passes.py +++ b/test/inductor/test_split_cat_fx_passes.py @@ -806,7 +806,7 @@ def test_unbind_stack(self): def unbind_stack(x): return torch.stack(torch.unbind(x, 1), 1) - def unbind_cat(x): # noqa: F841 + def unbind_cat(x): return torch.cat(torch.unbind(x, dim=-3), 1) def unbind_stack_argspec1(x): diff --git a/test/inductor/test_static_triton_launcher.py b/test/inductor/test_static_triton_launcher.py index c630746fa6e74..fb22dbb613692 100644 --- a/test/inductor/test_static_triton_launcher.py +++ b/test/inductor/test_static_triton_launcher.py @@ -231,7 +231,7 @@ def test_implied_constant(self): def triton_red_fused_any_isinf_0( in_ptr0, out_ptr0, - xnumel, # noqa: F841 + xnumel, r0_numel, XBLOCK: tl.constexpr, R0_BLOCK: tl.constexpr, diff --git a/test/inductor/test_subgraph_choice.py b/test/inductor/test_subgraph_choice.py index 098b70b591787..27100f7b66b8f 100644 --- a/test/inductor/test_subgraph_choice.py +++ b/test/inductor/test_subgraph_choice.py @@ -7,7 +7,6 @@ from torch._inductor.lowering import register_lowering from torch._inductor.select_algorithm import autotune_select_algorithm from torch._inductor.test_case import run_tests, TestCase -from torch.testing._internal.common_utils import skipIfXpu from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU @@ -35,7 +34,6 @@ def _create_buffer(self, name, shape, dtype): layout=FixedLayout(torch.device(f"{GPU_TYPE}:0"), dtype=dtype, size=shape), ) - @skipIfXpu def test_subgraph_decompose_k(self): from torch._inductor.kernel.mm import aten_mm from torch._inductor.kernel.mm_common import mm_args @@ -96,7 +94,6 @@ def func(mat1, mat2): # Check same results of compiled result and regular torch.mm torch.testing.assert_close(res, a_in @ b_in, atol=1e-1, rtol=1e-1) - @skipIfXpu def test_subgraph_freeze_layout(self): from torch._inductor.kernel.mm_common import mm_args diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 056e67761da6c..a696d9582be5d 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -1,6 +1,7 @@ # Owner(s): ["module: inductor"] # ruff: noqa: F841 import contextlib +import contextvars import copy import dataclasses import functools @@ -83,7 +84,11 @@ expectedFailureXPU, largeTensorTest, ) -from torch.testing._internal.common_dtype import all_types, get_all_dtypes +from torch.testing._internal.common_dtype import ( + all_types, + get_all_dtypes, + highest_precision_float, +) from torch.testing._internal.common_quantization import ( _dynamically_quantize_per_channel, _group_quantize_tensor_symmetric, @@ -91,6 +96,8 @@ from torch.testing._internal.common_utils import ( DeterministicGuard, instantiate_parametrized_tests, + IS_ARM64, + IS_CPU_EXT_SVE_SUPPORTED, IS_FBCODE, IS_MACOS, IS_X86, @@ -99,7 +106,6 @@ NAVI_ARCH, parametrize, serialTest, - skipIfMPS, skipIfRocm, skipIfRocmArch, skipIfWindows, @@ -295,7 +301,7 @@ def register_ops_with_aoti_compile(ns, op_set, dispatch_key, torch_compile_op_li schema = torch._C._get_schema(qualified_op_name, overload_name) if schema.overload_name: reg_op_name = f"{qualified_op_name}.{schema.overload_name}" - torch_compile_op_lib_impl._impl_with_aoti_compile( # noqa: F821 + torch_compile_op_lib_impl._impl_with_aoti_compile( reg_op_name, dispatch_key ) except Exception as e: @@ -1052,6 +1058,11 @@ def xfail_if_triton_cpu(fn): return fn +def xfail_if_pallas(fn): + fn._expected_failure_pallas = True + return fn + + def skip_if_gpu_halide(fn): @functools.wraps(fn) def wrapper(self, *args, **kwargs): @@ -2320,6 +2331,15 @@ def fn(a): self.common(fn, (inp.view(-1),), rtol=1e-4, atol=1e-5, check_lowp=False) self.common(fn, (inp.view(10, -1),), rtol=1e-4, atol=1e-5, check_lowp=False) + def test_split_cumsum_broadcast(self): + # https://github.com/pytorch/pytorch/issues/180221 + def fn(x, b): + return torch.cumsum(x + b, dim=1) + + x = make_tensor(1, 129, 64, low=0, dtype=torch.float32, device=self.device) + b = make_tensor(64, low=0, dtype=torch.float32, device=self.device) + self.common(fn, (x, b)) + @skipCUDAIf(not SM80OrLater, "Requires sm80") @skip_if_gpu_halide # accuracy issue def test_split_cumsum_low_prec(self): @@ -2598,6 +2618,8 @@ def fn(a, b_int8pack, b_scales, c): @xfail_if_triton_cpu @skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA") @skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU") + # Pallas codegen doesn't handle reduction axis after FloorDiv(ModularIndexing) simplification + @xfail_if_pallas def test__dyn_quant_pack_4bit_weight_fp32(self): q_group = 32 k = 128 @@ -2634,6 +2656,8 @@ def fn(b, in_features, out_features): @skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA") @skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU") @skip_if_halide # bf16 + # Pallas codegen doesn't handle reduction axis after FloorDiv(ModularIndexing) simplification + @xfail_if_pallas def test__dyn_quant_pack_4bit_weight_bf16(self): k = 128 n = 128 @@ -2672,6 +2696,8 @@ def fn(b, in_features, out_features): @xfail_if_mps_unimplemented @xfail_if_triton_cpu + # Pallas codegen doesn't handle reduction axis after FloorDiv(ModularIndexing) simplification + @xfail_if_pallas @skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA") @skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU") def test__dyn_quant_matmul_4bit_fp32_input(self): @@ -2715,8 +2741,10 @@ def fn(a, q_group, in_features, out_features): self.common(fn, (a, q_group, in_features, out_features)) @skipCPUIf(IS_MACOS, "fails on M1, mismatch in bf16 support reporting") - @xfail_if_mps_unimplemented @xfail_if_triton_cpu + @xfailIf( + IS_ARM64 and not IS_CPU_EXT_SVE_SUPPORTED + ) # see https://github.com/pytorch/pytorch/issues/170787 @skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA") @skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU") @skip_if_halide # bf16 @@ -2916,6 +2944,92 @@ def fn(x): a = torch.rand(()) self.common(fn, (a,)) + def test_cumprod_backward(self): + if self.device == "mps": + raise unittest.SkipTest( + "MPS inductor codegen bug with argmax: threadgroup_argmax" + ) + + # Regression test for https://github.com/pytorch/pytorch/issues/136263 + # torch.compile used O(n^2) algorithm for cumprod backward with tensor + # subclasses (like FakeTensor), making it extremely slow. + def fn(x): + y = x.cumprod(dim=1) + return y.sum() + + for dim_size in [8, 32]: + x = torch.rand(4, dim_size, 64, device=self.device, requires_grad=True) + x_ref = x.clone().detach().requires_grad_(True) + + # Eager forward + backward + y_ref = fn(x_ref) + y_ref.backward() + + # Compiled forward + backward + compiled_fn = torch.compile(fn, fullgraph=True) + y = compiled_fn(x) + y.backward() + + # Check correctness + self.assertEqual(y, y_ref, atol=1e-4, rtol=1e-4) + self.assertEqual(x.grad, x_ref.grad, atol=1e-4, rtol=1e-4) + + def test_cumprod_backward_with_zeros(self): + if self.device == "mps": + raise unittest.SkipTest( + "MPS inductor codegen bug with argmax: threadgroup_argmax" + ) + + # Test cumprod backward with zeros in the input + # This exercises the more complex O(n) algorithm path + def fn(x): + y = x.cumprod(dim=1) + return y.sum() + + for dim_size in [8, 16]: + x = torch.rand(4, dim_size, 32, device=self.device, requires_grad=True) + # Insert some zeros to exercise the zero-handling path + x.data[:, dim_size // 2, :] = 0 + x_ref = x.clone().detach().requires_grad_(True) + + # Eager forward + backward + y_ref = fn(x_ref) + y_ref.backward() + + # Compiled forward + backward + compiled_fn = torch.compile(fn, fullgraph=True) + y = compiled_fn(x) + y.backward() + + # Check correctness + self.assertEqual(y, y_ref, atol=1e-4, rtol=1e-4) + self.assertEqual(x.grad, x_ref.grad, atol=1e-4, rtol=1e-4) + + def test_view_dtype_bool(self): + # Regression test for boolean dtype handling in view.dtype lowering + # torch.iinfo doesn't support bool, so we need special handling + def fn(x, mask): + # Create a computation that involves boolean tensors and where + result = torch.where(mask, x, torch.zeros_like(x)) + return result.sum() + + x = torch.rand(4, 8, device=self.device, requires_grad=True) + mask = torch.rand(4, 8, device=self.device) > 0.5 + + x_ref = x.clone().detach().requires_grad_(True) + + # Eager + y_ref = fn(x_ref, mask) + y_ref.backward() + + # Compiled + compiled_fn = torch.compile(fn, fullgraph=True) + y = compiled_fn(x, mask) + y.backward() + + self.assertEqual(y, y_ref) + self.assertEqual(x.grad, x_ref.grad) + def test_cumsum_inf(self): def fn(x): return x.cumsum(-1) @@ -3019,7 +3133,6 @@ def fn(a, b): self.common(fn, (torch.randn(4, 4), torch.randn(4, 4))) - @xfail_if_mps @skip_if_halide # different pow accuracies @xfail_if_triton_cpu def test_norm_constant_overflow(self): @@ -3097,6 +3210,15 @@ def fn(x): self.common(fn, (make_arg(1, dtype=torch.float32),)) self.common(fn, (make_arg(1, dtype=torch.int64),)) + def test_arange7(self): + def fn(x): + # Test aten.arange.start_step lowering with integer dtypes + return x + torch.ops.aten.arange.start_step( + 0, 10, 2, dtype=torch.int64, device=x.device + ) + + self.common(fn, (torch.zeros(5, dtype=torch.int64),), check_lowp=False) + def test_linspace1(self): def fn(x): return torch.linspace(0.125, 0.875, 7, device=x.device) + x @@ -3333,7 +3455,7 @@ def test_round_correctness(self): def fn(a): return torch.round(a) - dtype = torch.float64 if self.device != "mps" else torch.float32 + dtype = highest_precision_float(self.device) self.common( fn, [torch.arange(-10, 10, 0.1, dtype=dtype)], @@ -3471,6 +3593,8 @@ def fn(a): check_lowp=False, ) + test_one_hot._expected_failure_halide = True + def test_div1(self): def fn(a, b): return ( @@ -3553,6 +3677,7 @@ def fn(a, b): ) @skip_if_triton_cpu # divide by zero; cannot xfail because it crashes process + @skipIfXpu(msg="https://github.com/intel/intel-xpu-backend-for-triton/issues/6401") def test_div7(self): def fn(a, b): return ( @@ -3666,6 +3791,50 @@ def fn(a, b): b_neg = torch.full_like(a, -divisor) self.common(fn, (a, b_neg)) + @skip_if_cpu + def test_floordiv_div_by_zero_int(self): + # Integer floor division by zero is undefined behavior on CUDA/Triton. + # On CPU, integer division by zero correctly raises ZeroDivisionError. + # Eager (c10::div_floor_integer) and compiled (Triton floordiv) must + # both return 0 for elements where the divisor is zero on GPU. + def fn(a, b): + return torch.floor_divide(a, b) + + for dtype in [torch.int32, torch.int64]: + # All-zero divisor: every element should be 0 + for dividend in [0, 1, -1, 5, -5]: + a = torch.full((8,), dividend, device=self.device, dtype=dtype) + b = torch.full((8,), 0, device=self.device, dtype=dtype) + self.common(fn, (a, b)) + + def test_floordiv_float_accuracy(self): + # Triton uses an approximate reciprocal for fp32 division, so a naive + # floor(a / b) can be off by one when the true quotient is very close + # to an integer. Verify the corrected lowering matches eager. + def fn(a, b): + return torch.floor_divide(a, b) + + # The original repro: fp16 values whose exact quotient is 4.0, but + # Triton's approximate fp32 division yields 3.9999997… + a = torch.full((8,), 84.3125, dtype=torch.float16, device=self.device) + b = torch.full((8,), 21.078125, dtype=torch.float16, device=self.device) + self.common(fn, (a, b)) + + # Negative operands / mixed signs + for dtype in [torch.float16, torch.float32]: + for a_val, b_val in [ + (7.0, 2.0), + (7.0, -2.0), + (-7.0, 2.0), + (-7.0, -2.0), + (6.0, 3.0), + (0.0, 5.0), + (0.0, -5.0), + ]: + a = torch.full((8,), a_val, dtype=dtype, device=self.device) + b = torch.full((8,), b_val, dtype=dtype, device=self.device) + self.common(fn, (a, b)) + def test_div_precision(self): # Reproducer for https://github.com/pytorch/pytorch/issues/101039 @@ -4616,7 +4785,7 @@ def fn(x, y): x1 = torch.randn(30, device=self.device) x2 = torch.randn(36, device=self.device) - dtype = torch.float64 if self.device != "mps" else torch.float32 + dtype = highest_precision_float(self.device) y = torch.ones(1, dtype=dtype, device=self.device) self.assertEqual(torch.compile(fn)(x1, y), fn(x1, y)) @@ -4790,7 +4959,7 @@ def fn(x): with config.patch({"triton.use_block_ptr": use_block_ptr}): self.common(fn, (torch.randn(1, 3, *[10] * dim),)) - @skipIfMPS + @xfail_if_mps # aten::full with zero-sized dim triggers AcceleratorError on MPS def test_max_unpool_empty_output(self): class Unpool1d(nn.Module): def __init__(self): @@ -5249,6 +5418,22 @@ def fn(x): ) assertGeneratedKernelCountEqual(self, 0) + @requires_gpu() + @skip_if_gpu_halide # slow + @xfail_if_mps # Non-divisible input sizes are not implemented on MPS device + @parametrize("comprehensive_padding", (False, True)) + def test_adaptive_avg_pool2d_flatten_sum(self, comprehensive_padding): + def fn(x): + y = F.adaptive_avg_pool2d(x, 7) + return y.flatten(1).sum(dim=-1) + + with config.patch(comprehensive_padding=comprehensive_padding): + self.common( + fn, + (torch.randn(2, 33, 8, 8, device=self.device, dtype=torch.float64),), + check_lowp=False, + ) + @xfail_if_mps @skip_if_gpu_halide # slow def test_adaptive_max_pool2d1(self): @@ -5375,15 +5560,15 @@ def run_weights_sharing_model(m, inp): threads = [] compiled_m = torch.compile(model) for _ in range(1, numb_instance + 1): + ctx = contextvars.copy_context() thread = threading.Thread( - target=run_weights_sharing_model, args=(compiled_m, inp) + target=ctx.run, args=(run_weights_sharing_model, compiled_m, inp) ) threads.append(thread) thread.start() for thread in threads: thread.join() - @unittest.skipIf(config.is_fbcode(), "fbcode triton error, needs debugging") @skip_if_triton_cpu("Flaky on Triton CPU") @skip_if_gpu_halide # https://github.com/halide/Halide/issues/8311 def test_adaptive_avg_pool2d_low_prec(self): @@ -6122,6 +6307,24 @@ def test_layer_norm(self): if self.device != "cpu": assertGeneratedKernelCountEqual(self, 1) + def test_layer_norm_rejects_complex_inputs(self): + if self.device not in ("cpu", "cuda"): + raise unittest.SkipTest("Only validated on CPU/CUDA") + + m = torch.nn.LayerNorm(10).to(self.device) + x = torch.randn(1, 1, 10, device=self.device, dtype=torch.complex64) + + with self.assertRaises(RuntimeError): + m(x) + + with self.assertRaises(RuntimeError) as compiled_error: + torch.compile(m)(x) + + self.assertIn( + "native_layer_norm does not support complex inputs", + str(compiled_error.exception), + ) + @torch._functorch.config.patch("donated_buffer", True) def test_matmul_layer_norm(self): batch_size = 32 @@ -6280,6 +6483,94 @@ def fn(): actual = compiled() self.assertEqual(actual, expected) + def test_complex_uniform_constant_folding(self): + # Fix https://github.com/pytorch/pytorch/issues/174891 + # view.dtype with mismatched element sizes changes element count, + # so constant folding must not treat the result as uniform. + def fn(x): + mask = torch.ones(2, 1, dtype=torch.complex64, device=self.device) + return x + mask + + x = torch.full((2, 2), 1.0, dtype=torch.complex64, device=self.device) + expected = fn(x) + compiled = torch.compile(fn, backend="inductor") + actual = compiled(x) + self.assertEqual(actual, expected) + + def test_view_dtype_non_0d_larger_to_smaller_element_size(self): + # Non-0-d counterpart of test_view_dtype_0d_smaller_to_larger_element_size. + # element_size (8) > itemsize (4): complex64 -> float32. + import torch.fx as fx + from torch._inductor.fx_passes.joint_graph import UniformValueConstantFolder + + graph = fx.Graph() + + full_node = graph.call_function( + torch.ops.aten.full.default, + args=([2], 1 + 0j), + kwargs={ + "dtype": torch.complex64, + "layout": torch.strided, + "device": self.device, + "pin_memory": False, + }, + ) + full_node.meta["val"] = torch.full( + [2], 1 + 0j, dtype=torch.complex64, device=self.device + ) + + view_node = graph.call_function( + torch.ops.aten.view.dtype, args=(full_node, torch.float32) + ) + view_node.meta["val"] = torch.full( + [2], 1 + 0j, dtype=torch.complex64, device=self.device + ).view(torch.float32) + + graph.output(view_node) + gm = fx.GraphModule(torch.nn.Module(), graph) + + folder = UniformValueConstantFolder(gm) + folder.run() + + self.assertNotIn(view_node, folder.node_replacements) + + def test_view_dtype_non_0d_smaller_to_larger_element_size(self): + # Non-0-d counterpart of test_view_dtype_0d_smaller_to_larger_element_size. + # element_size (4) < itemsize (8): float32 -> complex64. + import torch.fx as fx + from torch._inductor.fx_passes.joint_graph import UniformValueConstantFolder + + graph = fx.Graph() + + full_node = graph.call_function( + torch.ops.aten.full.default, + args=([2], 1.0), + kwargs={ + "dtype": torch.float32, + "layout": torch.strided, + "device": self.device, + "pin_memory": False, + }, + ) + full_node.meta["val"] = torch.full( + [2], 1.0, dtype=torch.float32, device=self.device + ) + + view_node = graph.call_function( + torch.ops.aten.view.dtype, args=(full_node, torch.complex64) + ) + view_node.meta["val"] = torch.full( + [2], 1.0, dtype=torch.float32, device=self.device + ).view(torch.complex64) + + graph.output(view_node) + gm = fx.GraphModule(torch.nn.Module(), graph) + + folder = UniformValueConstantFolder(gm) + folder.run() + + self.assertNotIn(view_node, folder.node_replacements) + def test_uniform(self): def fn(x): return aten.uniform.default(x, 0, 1) @@ -6306,7 +6597,11 @@ def fn(x, y): reference_in_float=False, ) - @skipIfMPS + @unittest.skipIf( + TEST_WITH_ROCM and not torch.cuda.has_magma, + "ROCm hipsolver backend does not currently support eig", + ) + @xfail_if_mps_unimplemented # aten::linalg_eig not implemented for MPS def test_linalg_eig_stride_consistency(self): def fn(x): eigenvals, eigenvecs = torch.linalg.eig(x) @@ -6392,7 +6687,7 @@ def test_polar(self): def fn(dist, angle): return torch.polar(dist, angle) - dtype = torch.float64 if self.device != "mps" else torch.float32 + dtype = highest_precision_float(self.device) inp = ( torch.tensor([1, 2], dtype=dtype), torch.tensor([np.pi / 2, 5 * np.pi / 4], dtype=dtype), @@ -6561,7 +6856,6 @@ def fn(mask, value): ) self.assertEqual(fn(*inputs), opt_fn(*inputs)) - @xfail_if_mps # 'NullHandler' object has no attribute 'wrapper_code' def test_masked_scatter(self): def fn(value, mask, source): return torch.masked_scatter(value, mask, source) @@ -6968,10 +7262,10 @@ def matmul_with_op(x, y, fn): # test no-op fns = ( - lambda x: x + torch.zeros([256, 256], dtype=torch.float32, device=x.device), # noqa: E731 - lambda x: x - torch.zeros([256, 256], dtype=torch.float32, device=x.device), # noqa: E731 - lambda x: x * torch.ones([256, 256], dtype=torch.float32, device=x.device), # noqa: E731 - lambda x: x / torch.ones([256, 256], dtype=torch.float32, device=x.device), # noqa: E731 + lambda x: x + torch.zeros([256, 256], dtype=torch.float32, device=x.device), + lambda x: x - torch.zeros([256, 256], dtype=torch.float32, device=x.device), + lambda x: x * torch.ones([256, 256], dtype=torch.float32, device=x.device), + lambda x: x / torch.ones([256, 256], dtype=torch.float32, device=x.device), ) inps = [torch.rand([256, 256], device=self.device) for _ in range(2)] @@ -7057,7 +7351,7 @@ def f(x): def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", arg3_1: "f32[s77, s27, s53][s27*s53, s53, 1]{str(x.device)}"): add: "f32[s77, s27, s53][s27*s53, s53, 1]{str(x.device)}" = torch.ops.aten.add.Tensor(arg3_1, 1); arg3_1 = None add_9: "f32[s77, s27, s53][s27*s53, s53, 1]{str(x.device)}" = torch.ops.aten.add.Tensor(add, 1); add = None - return (add_9,)""" # noqa: B950 + return (add_9,)""" self.assertExpectedInline( post_grad_graph, expected_graph, @@ -7081,7 +7375,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "f32[s77, s27, add: "f32[s77, s27, 2][2*s27, 2, 1]{str(x.device)}" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None slice_1: "f32[s77, s27, 1][2*s27, 2, 1]{str(x.device)}" = torch.ops.aten.slice.Tensor(add, -1, 0, -1); add = None add_9: "f32[s77, s27, 1][s27, 1, 1]{str(x.device)}" = torch.ops.aten.add.Tensor(slice_1, 1); slice_1 = None - return (add_9,)""" # noqa: B950 + return (add_9,)""" self.assertExpectedInline( post_grad_graph, expected_graph, @@ -7110,7 +7404,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar empty: "f32[s77, s27, s53][s27*s53, s53, 1]{str(x.device)}" = torch.ops.aten.empty.memory_format([arg0_1, arg1_1, arg2_1], dtype = torch.float32, layout = torch.strided, device = {repr(x.device)}, pin_memory = False); arg0_1 = arg1_1 = arg2_1 = empty = None add: "f32[s77, s27, s53][s27*s53, s53, 1]{str(x.device)}" = torch.ops.aten.add.Tensor(arg3_1, 1); arg3_1 = None add_13: "f32[s77, s27, s53][s27*s53, s53, 1]{str(x.device)}" = torch.ops.aten.add.Tensor(add, 1); add = None - return (add_13,)""" # noqa: B950 + return (add_13,)""" self.assertExpectedInline( post_grad_graph, expected_graph, @@ -7318,8 +7612,6 @@ def fn(x): @config.patch(force_disable_caches=True) def test_deterministic_codegen(self): - if "cpu" in str(self.device) and config.is_fbcode(): - raise unittest.SkipTest("cpp packaging is wacky in fbcode") if "cpu" in str(self.device) and config.cpp_wrapper: raise unittest.SkipTest( "run_and_get_kernels can't extract kernels from CPU cpp_wrapper code" @@ -7371,8 +7663,6 @@ def c(x): @config.patch(force_disable_caches=True) def test_deterministic_codegen_on_graph_break(self): - if "cpu" in str(self.device) and config.is_fbcode(): - raise unittest.SkipTest("cpp packaging is wacky in fbcode") if "cpu" in str(self.device) and config.cpp_wrapper: raise unittest.SkipTest( "run_and_get_kernels can't extract kernels from CPU cpp_wrapper code" @@ -7389,7 +7679,7 @@ def b(x): return x x = torch.randn(16, 256, device=self.device) - if config.cpp_wrapper and config.triton.autotune_at_compile_time is False: + if config.cpp_wrapper and config.triton.autotune_at_compile_time is not True: # With lazy compile, both graph segments produce identical code # (no unique .cubin paths), so run_and_get_code deduplicates them # and only 1 kernel is returned. @@ -7409,8 +7699,6 @@ def b(x): } ) def test_deterministic_codegen_with_suffix(self): - if "cpu" in str(self.device) and config.is_fbcode(): - raise unittest.SkipTest("cpp packaging is wacky in fbcode") if "cpu" in str(self.device) and config.cpp_wrapper: raise unittest.SkipTest( "run_and_get_kernels can't extract kernels from CPU cpp_wrapper code" @@ -7589,6 +7877,46 @@ def fn(a, b): self.common(fn, (torch.randn(64), torch.randn(64))) + @skip_if_halide # cpp-only RuntimeError contract + @skip_if_pallas # cpp-only RuntimeError contract + @skip_if_triton_cpu # cpp-only RuntimeError contract + def test_fmod_uint8_zero_divisor_cpu_inductor_raises_error(self): + if self.device != "cpu": + raise unittest.SkipTest( + "CPU-only: uint8 Inductor integer divide-by-zero RuntimeError check" + ) + torch.manual_seed(0) + device = self.device + divisor = torch.randint(0, 255, (8,), dtype=torch.uint8, device=device) + inp = torch.randint(1, 255, (16, 8), dtype=torch.uint8, device=device) + divisor[2] = 0 # ensure a divisor is 0 + + torch._dynamo.reset() + opt = torch.compile(torch.fmod, fullgraph=True, backend="inductor", mode=None) + with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"): + opt(inp, divisor) + + @skip_if_halide # cpp-only RuntimeError contract + @skip_if_pallas # cpp-only RuntimeError contract + @skip_if_triton_cpu # cpp-only RuntimeError contract + def test_remainder_uint8_zero_divisor_cpu_inductor_raises_error(self): + if self.device != "cpu": + raise unittest.SkipTest( + "CPU-only: uint8 Inductor integer divide-by-zero RuntimeError check" + ) + torch.manual_seed(0) + device = self.device + divisor = torch.randint(0, 255, (8,), dtype=torch.uint8, device=device) + inp = torch.randint(1, 255, (16, 8), dtype=torch.uint8, device=device) + divisor[2] = 0 # ensure a divisor is 0 + + torch._dynamo.reset() + opt = torch.compile( + torch.remainder, fullgraph=True, backend="inductor", mode=None + ) + with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError"): + opt(inp, divisor) + def test_zeros(self): def fn(a): return ( @@ -9123,7 +9451,6 @@ def fn(x): actual_out = compiled_fn(view) self.assertEqual(reference_out.stride(), actual_out.stride()) - @skipIfMPS def test_nonzero_static_stride(self): from torch._subclasses import FakeTensorMode @@ -9181,7 +9508,6 @@ def fn(x, a, b): ], ) - @skipIfMPS # MPS does not support float64 def test_select_scatter_dtype_consistency(self): def fn(x, a): return (torch.select_scatter(x, a, 1, 0),) @@ -9190,6 +9516,8 @@ def fn(x, a): torch.int64, torch.float64, ]: + if not self.is_dtype_supported(dtype): + continue self.common( fn, [ @@ -9214,6 +9542,30 @@ def fn(x, a): ], ) + @skip_if_gpu_halide # accuracy issue + def test_slice_scatter_backward_with_overlapping_base(self): + def fn(x, y): + return torch.slice_scatter(x, y, dim=1, start=0, end=6).sum(dim=1) + + torch.manual_seed(0) + x = torch.randn(4, 13, 33, device=self.device, requires_grad=True) + y = torch.randn(4, 6, 33, device=self.device, requires_grad=True) + grad = torch.randn(4, 33, device=self.device) + + x_ref = x.detach().clone().requires_grad_(True) + y_ref = y.detach().clone().requires_grad_(True) + + out_ref = fn(x_ref, y_ref) + out_ref.backward(grad) + + compiled_fn = torch.compile(fn, backend="inductor", fullgraph=True) + out = compiled_fn(x, y) + out.backward(grad) + + self.assertEqual(out, out_ref) + self.assertEqual(x.grad, x_ref.grad) + self.assertEqual(y.grad, y_ref.grad) + def test_slice_scatter2(self): def fn(a, b): return aten.slice_scatter(a, b, 0, 0, 9223372036854775807) @@ -9302,7 +9654,6 @@ def forward(self, x, start_pos): else: assertGeneratedKernelCountEqual(self, 1) - @skipIfMPS def test_slice_scatter_dtype_consistency(self): # Test dtype consistency of slice_scatter def fn(x, y): @@ -9312,6 +9663,8 @@ def fn(x, y): torch.int64, torch.float64, ]: + if not self.is_dtype_supported(dtype): + continue self.common( fn, [ @@ -9575,6 +9928,33 @@ def fn(a, dim, index, b, reduce): check_lowp=check_lowp, ) + def test_scatter_reduce_fused_broadcast_non_power_of_2(self): + # https://github.com/pytorch/pytorch/issues/178871 + # When inductor fuses a scalar broadcast with scatter_reduce("sum"), + # the store index may simplify to a constant (e.g., output size 1). + # For non-power-of-2 input sizes, XBLOCK > numel causes OOB threads + # to execute tl.atomic_add without a mask, adding stale values. + def fn(x, bias, idx): + src = x + bias + out = torch.zeros(1, 1, device=x.device, dtype=x.dtype) + return out.scatter_reduce( + 0, idx.unsqueeze(-1).expand_as(src), src, "sum", include_self=True + ) + + check_lowp = self.device != "xpu" + for N in [3, 5, 7, 9, 17, 33, 48]: + self.common( + fn, + [ + torch.randn(N, 1), + torch.tensor([0.5]), + torch.zeros(N, dtype=torch.long), + ], + check_lowp=check_lowp, + atol=1e-3, + rtol=0.01, + ) + @skip_if_gpu_halide def test_dense_mask_index(self): r""" @@ -9613,18 +9993,35 @@ def fn(): @requires_gpu() @unittest.skipIf(IS_MACOS, "fails on macos") @parametrize( - "constructor", - [torch.empty, torch.ones, torch.zeros, torch.rand, torch.randn], - name_fn=lambda constructor: constructor.__name__, + "constructor_and_args", + [ + (torch.empty, ([1, 128, 128],)), + (torch.ones, ([1, 128, 128],)), + (torch.zeros, ([1, 128, 128],)), + ( + torch.full, + ( + [1, 128, 128], + 2, + ), + ), + (torch.rand, ([1, 128, 128],)), + (torch.randn, ([1, 128, 128],)), + ], + name_fn=lambda constructor_and_args: constructor_and_args[0].__name__, ) - def test_constructors_pin_memory(self, constructor): + def test_constructors_pin_memory(self, constructor_and_args): if self.device != "cpu": raise unittest.SkipTest("pin_memory is not supported on non-CPU devices") - failing_constructors = [torch.rand, torch.ones, torch.zeros] + constructor, args = constructor_and_args + + failing_constructors = [ + torch.rand, + ] def fn(): - return constructor([1, 128, 128], pin_memory=True, device=self.device) + return constructor(*args, pin_memory=True, device=self.device) result = torch.compile(fn, backend="inductor")() if constructor in failing_constructors: @@ -10181,7 +10578,6 @@ def fn(x): self.common(fn, [torch.zeros([20, 20])], exact_stride=True) @config.patch(fallback_random=True) - @xfail_if_mps # 100% are not close def test_like_rands_sliced(self): def fn(x): return ( @@ -10192,6 +10588,31 @@ def fn(x): self.common(fn, (torch.zeros([3, 4])[:, ::2].permute(1, 0),), exact_stride=True) + @config.patch(fallback_random=True) + def test_like_randn_non_contiguous_rng_consistency(self): + x = torch.zeros((3, 6), device=self.device).t() + self.assertFalse(x.is_contiguous()) + + randn_like_compiled = torch.compile(torch.randn_like, backend="inductor") + + torch.manual_seed(0) + eager_out = torch.randn_like(x) + torch.manual_seed(0) + compiled_out = randn_like_compiled(x) + + self.assertEqual(eager_out, compiled_out) + + @config.patch(fallback_random=False) + def test_fast_like_rands_decomps_use_non_eager_path(self): + x = torch.zeros(3, 6, device=self.device) + + randn_like_compiled = torch.compile(torch.randn_like, backend="inductor") + torch.manual_seed(0) + eager_randn_out = torch.randn_like(x) + torch.manual_seed(0) + compiled_randn_out = randn_like_compiled(x) + self.assertNotEqual(eager_randn_out, compiled_randn_out) + @config.patch(check_stack_no_cycles_TESTING_ONLY=True) def test_check_stack_no_cycles(self): if config.cpp_wrapper and self.device != "cpu": @@ -10350,11 +10771,12 @@ def fn(a, b, c): indices, ], ) - assertGeneratedKernelCountEqual(self, 1) + # Note: Kernel count varies by backend (CUDA ~3, ROCm ~2) due to fusion. + # Correctness is validated by self.common() above. + self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0) - @expectedFailureXPU def test_max_pool2d_with_indices_backward5(self): - # Window size is too big. Should fallback + # Large window size - decomposition handles via scatter_add def fn(a, b, c): return aten.max_pool2d_with_indices_backward( a, b, [13, 13], [1, 1], [2, 2], [1, 1], False, c @@ -10378,11 +10800,15 @@ def fn(a, b, c): indices, ], ) - assertGeneratedKernelCountEqual(self, 0) + # Note: Kernel count varies by backend (CUDA ~3, ROCm ~2) due to fusion. + # Correctness is validated by self.common() above. + # MPS: decomposition falls back to native kernel, so no inductor kernels generated + if self.device != "mps": + self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0) # From https://github.com/pytorch/pytorch/issues/93384 def test_max_pool2d_with_indices_backward6(self): - # dilation is not 1. Should fallback + # dilation != 1 - decomposition handles all dilation cases def fn(a, b, c): return aten.max_pool2d_with_indices_backward( a, b, [3, 2], [2, 1], [1, 1], [1, 2], False, c @@ -10406,7 +10832,11 @@ def fn(a, b, c): indices, ], ) - assertGeneratedKernelCountEqual(self, 0) + # Note: Kernel count varies by backend (CUDA ~3, ROCm ~2) due to fusion. + # Correctness is validated by self.common() above. + # MPS: decomposition falls back to native kernel, so no inductor kernels generated + if self.device != "mps": + self.assertGreater(torch._inductor.metrics.generated_kernel_count, 0) def test_issue102546(self): def fn(x): @@ -10630,6 +11060,10 @@ def fn(a): @xfail_if_mps def test_dropout2(self): + if is_mps_backend(self.device) and torch._inductor.config.align_random_eager: + raise AssertionError( + "MPS + align_random_eager: will pass but force failure for xfail_if_mps" + ) n = 100000 weight = torch.ones( n, device=self.device, dtype=torch.float32, requires_grad=True @@ -10667,10 +11101,13 @@ def check(r, g): torch.manual_seed(1234) weight.grad.zero_() r2, (fw_code, bw_code) = run_fw_bw_and_get_code(lambda: run(ones)) - if is_halide_backend(self.device): + if ( + is_halide_backend(self.device) + and not torch._inductor.config.align_random_eager + ): self.assertEqual(fw_code.count("halide_helpers.rand"), 1) self.assertEqual(bw_code.count("halide_helpers.rand"), 0) - elif self.device == GPU_TYPE: + elif self.device == GPU_TYPE and not torch._inductor.config.align_random_eager: self.assertEqual(fw_code.count("tl.rand"), 1) self.assertEqual(bw_code.count("tl.rand"), 0) g2 = weight.grad.clone() @@ -10690,6 +11127,10 @@ def check(r, g): @xfail_if_mps @config.patch(search_autotune_cache=False) def test_dropout3(self): + if is_mps_backend(self.device) and torch._inductor.config.align_random_eager: + raise AssertionError( + "MPS + align_random_eager + dynamic shapes: will pass but force failure for xfail_if_mps" + ) m = torch.nn.Sequential( torch.nn.Linear(32, 32, bias=False), torch.nn.Dropout(), @@ -10707,10 +11148,13 @@ def run(x): lambda: run(torch.randn([8, 32], device=self.device)) ) - if is_halide_backend(self.device): + if ( + is_halide_backend(self.device) + and not torch._inductor.config.align_random_eager + ): self.assertEqual(fw_code.count("halide_helpers.rand"), 2) self.assertEqual(bw_code.count("halide_helpers.rand"), 0) - elif self.device == GPU_TYPE: + elif self.device == GPU_TYPE and not torch._inductor.config.align_random_eager: # the load_seed_offset arg can be 1 or non-1; depending on whether # the triton signature specializes on 1 vs non-1, you might get 1 # or 2 kernels. In newer versions of triton, there's no specialization @@ -10784,7 +11228,6 @@ def fn(x): ], ) - @skipIfXpu(msg="Incorrect XPU reference") def test_argmax_argmin2(self): def fn(x): return ( @@ -10796,7 +11239,6 @@ def fn(x): self.common(fn, (torch.randn([144, 144]),)) - @skipIfXpu(msg="Incorrect XPU reference") def test_argmax_argmin_with_duplicates(self): def fn(x): return ( @@ -10818,7 +11260,6 @@ def fn(x): t1 = torch.randint(8, size=(1028, 1028)) self.common(fn, (t1,)) - @skipIfXpu(msg="# Incorrect XPU reference ") @xfail_if_mps # eager nan is wrong, see https://github.com/pytorch/pytorch/issues/130295 @skip_if_halide # nan behavior def test_argmax_argmin_with_nan(self): @@ -10943,7 +11384,6 @@ def shrink_rank(x, rank): [rank4_inps, rank3_inps, rank5_inps], ) - @skipIfXpu(msg="Incorrect XPU reference") def test_argmax_argmin3(self): def fn(x): return ( @@ -12564,6 +13004,14 @@ def _cases_resize_common(): x_strided = x[::2].reshape(25, 2).transpose(0, 1) yield x_strided, y_size, memory_format + def test_resize_overlapping_strides(self): + # Resize on a stride-0 view should read logical elements, not raw storage. + def fn(x): + view = torch.as_strided(x, (100,), (0,)) + return torch.ops.aten.resize(view, (50,)) + + self.common(fn, (torch.ones(10),)) + def test_resize(self): def fn(x, size, memory_format): # NOTE: Tensor.resize() =/= aten::resize() @@ -12681,8 +13129,30 @@ def fn(q, k, v): rtol=1e-2, # to pass lowp check on GPU ) + @skip_if_halide + @skip_if_pallas # cpp-only fusion path + @skip_if_triton_cpu + def test_group_norm_sdpa_bmm_cpu_cpp_fusion(self): + if self.device != "cpu": + raise unittest.SkipTest("cpu only") + + group_norm = nn.GroupNorm(1, 7).eval() + rrelu = nn.RReLU().eval() + + def fn(x): + y = torch.sigmoid(x) + z = group_norm(x) + z = F.scaled_dot_product_attention(z, z, z) + z = rrelu(z) + return torch.bmm(y, z) + + torch.manual_seed(0) + x = torch.randn((8, 7, 7)) + expected = fn(x) + actual = torch.compile(fn, backend="inductor", fullgraph=True)(x) + self.assertEqual(actual, expected) + @xfail_if_mps_unimplemented - @expectedFailureXPU @unittest.skipIf( not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA" ) @@ -13187,6 +13657,108 @@ def test_assert_alignment_op_name_fail(self): with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"): assert_alignment(tensor, 0, "torch.ops.dummy.op_name") + @requires_gpu() + @skip_if_not_triton + @unittest.skipIf( + config.cpp_wrapper, + "Inductor does not generate size/stride asserts for cpp_wrapper", + ) + def test_input_asserts_deferred_to_first_use(self): + def fn(x, y, z): + a = torch.mm(x, y) + b = torch.mm(a, z) + return b + + x = torch.randn(16, 32, device=self.device) + y = torch.randn(32, 64, device=self.device) + z = torch.randn(64, 8, device=self.device) + + _, code = run_and_get_code(torch.compile(fn), x, y, z) + # z's assert should appear after the first mm, not at the top + # with all the other asserts + FileCheck().check("def call").check_count( + "assert_size_stride", 2, exactly=True + ).check("extern_kernels.mm(").run(code[0]) + + @requires_gpu() + @skip_if_not_triton + @unittest.skipIf( + config.cpp_wrapper, + "Deferred alignment copies are not generated for cpp_wrapper", + ) + def test_alignment_copy_deferred_to_first_use(self): + def fn(x, y, z): + a = torch.mm(x, y) + b = torch.mm(a, z) + return b + + x = torch.randn(16, 32, device=self.device) + y = torch.randn(32, 64, device=self.device) + + z = torch.randn(64, 8, device=self.device) + + _, code = run_and_get_code(torch.compile(fn), x, y, z) + # z's alignment check should appear between the two mm calls: + # first mm (uses x, y) -> alignment clone (for z) -> second mm (uses z) + FileCheck().check("extern_kernels.mm(").check("copy_if_misaligned").check( + "extern_kernels.mm(" + ).run(code[0]) + + @requires_gpu() + @skip_if_not_triton + @torch._inductor.config.patch(cpp_wrapper=True) + def test_alignment_copy_not_emitted_for_cpp_wrapper(self): + def fn(x, y): + return torch.mm(x, y) + + x = torch.randn(16, 32, device=self.device) + y = torch.randn(32, 64, device=self.device) + + _, code = run_and_get_code(torch.compile(fn), x, y) + # cpp_wrapper should NOT contain Python-syntax alignment copies + self.assertNotIn("copy_if_misaligned", code[0]) + + def test_copy_if_misaligned_returns_same_tensor_when_aligned(self): + import weakref + + from torch._C._dynamo.guards import copy_if_misaligned + + x = torch.randn(32, 32, device=self.device) + ref = weakref.ref(x) + result = copy_if_misaligned(x) + self.assertIs(result, x) + del x, result + self.assertIsNone(ref(), "aligned tensor should be freed") + + def test_copy_if_misaligned_clones_when_misaligned(self): + import weakref + + from torch._C._dynamo.guards import copy_if_misaligned + + big = torch.randn(32 * 32 + 1, device=self.device) + x = big[1:].reshape(32, 32) + self.assertNotEqual(x.data_ptr() % 16, 0) + ref_orig = weakref.ref(x) + result = copy_if_misaligned(x) + self.assertIsNot(result, x) + self.assertEqual(result.data_ptr() % 16, 0) + self.assertEqual(result, x) + self.assertEqual(result.size(), x.size()) + self.assertEqual(result.stride(), x.stride()) + ref_clone = weakref.ref(result) + del x, result + self.assertIsNone(ref_clone(), "cloned tensor should be freed") + # orig kept alive by big's storage, that's fine + del big + self.assertIsNone(ref_orig(), "original tensor should be freed") + + def test_copy_if_misaligned_empty_tensor(self): + from torch._C._dynamo.guards import copy_if_misaligned + + x = torch.randn(0, device=self.device) + result = copy_if_misaligned(x) + self.assertEqual(result.size(), x.size()) + @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) @torch._inductor.config.patch(implicit_fallbacks=True) def test_custom_op_unbacked_symints(self): @@ -13543,7 +14115,7 @@ def fn(tensor, index, source): return out device = "cpu" - dtype = torch.double if self.device != "mps" else torch.float32 + dtype = highest_precision_float(self.device) tensor = torch.rand((1,), dtype=dtype, device=device) index = torch.tensor([0], dtype=torch.long, device=device) source = torch.rand((1,), dtype=dtype, device=device) @@ -14317,7 +14889,7 @@ def f(image_latent): torch.testing.assert_close(ref, act, atol=1e-3, rtol=1e-3) if is_dynamic_shape_enabled(): - size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s12, s80, s80., .3\*s12\*s80\*s80, s12\*s80\*s80, 1, s12\*s80, s1.." # noqa: B950 + size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, s12, s80, s80., .3\*s12\*s80\*s80, s12\*s80\*s80, 1, s12\*s80, s1.." else: size_assert_pattern = r"assert_size_stride.[a-z]+[0-9]+, .2, 3, 16, 32, 32., .49152, 16384, 1, 512, 16.." FileCheck().check_regex(size_assert_pattern).run(code) @@ -14932,7 +15504,7 @@ def f(x): expected_graph1 = f"""\ def forward(self, arg0_1: "f32[2, 3, 2][6, 2, 1]{str(x.device)}"): permute: "f32[2, 2, 3][6, 1, 2]{str(x.device)}" = torch.ops.aten.permute.default(arg0_1, [0, 2, 1]); arg0_1 = None - return (permute,)""" # noqa: B950 + return (permute,)""" post_grad_graph = get_post_grad_graph(f, (x,)) @@ -14948,7 +15520,7 @@ def forward(self, arg0_1: "f32[2, 3, 2][6, 2, 1]{str(x.device)}"): expected_graph2 = f"""\ def forward(self, arg0_1: "Sym(s77)", arg1_1: "f32[s77, 3, 2][6, 2, 1]{str(x.device)}"): permute: "f32[s77, 2, 3][6, 1, 2]{str(x.device)}" = torch.ops.aten.permute.default(arg1_1, [0, 2, 1]); arg1_1 = None - return (permute,)""" # noqa: B950 + return (permute,)""" post_grad_graph = get_post_grad_graph(f, (x,)) self.assertExpectedInline( post_grad_graph, @@ -14974,7 +15546,7 @@ def f(x): expected_graph = f"""\ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", arg3_1: "u8[s77, s27, s53][s27*s53, s53, 1]{str(x.device)}"): permute: "u8[s77, s53, s27][s27*s53, 1, s53]{str(x.device)}" = torch.ops.aten.permute.default(arg3_1, [0, 2, 1]); arg3_1 = None - return (permute,)""" # noqa: B950 + return (permute,)""" self.assertExpectedInline( post_grad_graph, expected_graph, @@ -15171,7 +15743,7 @@ def fn(inp, repeats, output_size): self.assertEqual(fn(*args), torch.compile(fn)(*args)) @unittest.skipIf( - config.triton.autotune_at_compile_time is not False, + config.triton.autotune_at_compile_time is True, "autotune_at_compile_time doesn't work for test with indexing", ) @parametrize("dtype", [torch.int32, torch.int64]) @@ -15382,7 +15954,7 @@ def fn(x): torch.testing.assert_close(out, fn(x)) @unittest.skipIf( - config.triton.autotune_at_compile_time is not False, + config.triton.autotune_at_compile_time is True, "autotune_at_compile_time doesn't work for test with indexing", ) @requires_gpu_and_triton @@ -15475,7 +16047,6 @@ def foo(x): torch.compile(foo, fullgraph=True)(torch.ones(3, 3)) self.assertTrue(len(log2), 1) - @skipIfMPS # Accuracy issue on MPS def test_weight_norm_conv2d(self): """ Verify fix for https://github.com/pytorch/pytorch/issues/165749 @@ -15495,7 +16066,7 @@ def test_weight_norm_conv2d(self): self.assertTrue(same((ref, ref_grad), (act, act_grad), tol=1e-3)) - @skipIfMPS + @xfail_if_mps # MPS codegen does not emit ReductionHint.OUTER def test_inner_reduction_detection(self): if self.device == "cpu": self.skipTest("Skip for CPU device") @@ -15511,7 +16082,7 @@ def f(x): self.assertFalse("ReductionHint.INNER" in code) @skip_if_halide - @requires_cuda_and_triton + @requires_gpu_and_triton def test_triton_argmin_argmax_transpose_logical_index(self): def fn(x): x.tan_() @@ -15548,7 +16119,7 @@ def fn(x): self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),)) @skip_if_halide - @requires_cuda_and_triton + @requires_gpu_and_triton def test_unbacked_float_item(self): def fn(x, max_val): return torch.clamp(x, 0, max_val.item()) @@ -15588,7 +16159,7 @@ def fn(a, b, c, d): if self.device.lower() == "cuda": self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) - @requires_cuda_and_triton + @requires_gpu_and_triton @config.patch(combo_kernels=True) @torch._dynamo.config.patch(assume_static_by_default=False) def test_combo_kernel_store_mask(self): @@ -15921,7 +16492,7 @@ class TestFailure: __test__: bool = False -def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None): # noqa: B902 +def copy_tests(my_cls, other_cls, suffix, test_failures=None, xfail_prop=None): for name, value in my_cls.__dict__.items(): if name.startswith("test_"): # You cannot copy functions in Python, so we use closures here to @@ -16095,7 +16666,10 @@ def fn(a: torch.Tensor) -> torch.Tensor: if config.triton.multi_kernel: self.assertEqual(len(kernels), 4) expected_divisible[2] = expected_divisible.pop(1) - elif config.triton.cooperative_reductions: + elif ( + config.triton.cooperative_reductions + or config.triton.force_cooperative_reductions + ): self.assertEqual(len(kernels), 1) expected_divisible = { # one kernel, with extra workspace/semaphore args @@ -16752,14 +17326,14 @@ def f(a, b): tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r0_2)), rmask, eviction_policy='evict_last', other=0.0) tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[262144, 512], strides=[1, 262144], block_shape=[XBLOCK, R0_BLOCK], order=[0, 1], offsets=[xoffset, roffset]), boundary_check=[1], padding_option='zero') tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r0_2)), rmask, eviction_policy='evict_last', other=0.0) - tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long + tmp1 = tl.load(block_ptr0, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", ) else: self.assertExpectedInline( "\n".join(lines), """\ tmp0 = tl.reshape(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last'), [XBLOCK, R0_BLOCK]) - tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long + tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", ) # Disable index propagation, so the indirect indexing isn't optimized away @@ -17184,8 +17758,8 @@ def fn(inp, weight): @torch._functorch.config.patch("donated_buffer", True) # The inplace updating does not happen after we fused the - # layernorm backward - @torch._inductor.config.patch("triton.mix_order_reduction", False) + # layernorm backward. + @torch._inductor.config.patch({"triton.mix_order_reduction": False}) def test_donated_buffer_inplace(self): batch_size = 32 seq_length = 50 @@ -17551,6 +18125,123 @@ def fn(x): self.assertIn("0x80000000", code[0]) torch.testing.assert_close(result, fn(inp)) + def test_3d_reductions_with_max_tiles_3(self): + # Inductor only supports at most two reduction iteration ranges, R0 and R1, which the + # reduction component of the kernel can be tiled across. + # When max_tiles>=3, SIMDScheduling.create_tiling would previously incorrectly allow the + # tiling of the kernel in three dimensions, despite there being no pointwise component + # of the kernel. + + @torch._inductor.config.patch( + { + "triton.prefer_nd_tiling": True, + "triton.max_tiles": 3, + "triton.tile_reductions": True, + "triton.use_tensor_descriptor": True, + } + ) + @torch.compile + def sum(x): + return torch.sum(x, [0, 1, 2]) + + shape = (2, 2, 4) + strides = (32, 8, 1) + + arg0_1_orig = torch.arange( + math.prod(shape), device=GPU_TYPE, dtype=torch.int32 + ).view(shape) + arg0_1 = torch.empty_strided( + shape, strides, device=GPU_TYPE, dtype=torch.int32 + ) + arg0_1.copy_(arg0_1_orig) + + actual, code = run_and_get_code(sum, arg0_1) + expected = torch.sum(arg0_1, [0, 1, 2]) + + torch.testing.assert_close(actual=actual, expected=expected) + + fc = FileCheck() + # There's no pointwise work to do, so xnumel should be 1... + fc.check("xnumel = 1") + fc.run(code[0]) + + @config.patch({"triton.decompose_sort_ops": True}) + def test_median_decompose_sort_ops(self): + def fn_default(a): + return torch.median(a) + + def fn_dim(a): + return torch.median(a, dim=1) + + inp = torch.randn(8, 16, device=GPU_TYPE) + for fn in (fn_default, fn_dim): + torch._dynamo.reset() + result, code = run_and_get_code(torch.compile(fn), inp) + self.assertIn( + "sort_with_index", + " ".join(code), + "Expected Triton sort codegen for median", + ) + torch.testing.assert_close(result, fn(inp)) + + @config.patch({"triton.decompose_sort_ops": True}) + def test_mode_decompose_sort_ops(self): + def fn(a): + return torch.mode(a, dim=1) + + # Use integers so ties are common, exercising run-length logic + inp = torch.randint( + 0, 5, size=[8, 16], dtype=torch.float32, device=GPU_TYPE + ) + torch._dynamo.reset() + result, code = run_and_get_code(torch.compile(fn), inp) + self.assertIn( + "sort_with_index", + " ".join(code), + "Expected Triton sort codegen for mode", + ) + expected = fn(inp) + torch.testing.assert_close(result[0], expected[0]) + torch.testing.assert_close(result[1], expected[1]) + + @config.patch({"triton.decompose_sort_ops": True}) + def test_topk_decompose_sort_ops(self): + def fn(a): + return torch.topk(a, 3, dim=-1) + + def fn_largest_false(a): + return torch.topk(a, 3, dim=-1, largest=False) + + inp = torch.randn(8, 16, device=GPU_TYPE) + for test_fn in (fn, fn_largest_false): + torch._dynamo.reset() + result, code = run_and_get_code(torch.compile(test_fn), inp) + self.assertIn( + "sort_with_index", + " ".join(code), + "Expected Triton sort codegen for topk", + ) + expected = test_fn(inp) + torch.testing.assert_close(result[0], expected[0]) + torch.testing.assert_close(result[1], expected[1]) + + @config.patch({"triton.decompose_sort_ops": True}) + def test_kthvalue_decompose_sort_ops(self): + def fn(a): + return torch.kthvalue(a, 3, dim=-1) + + inp = torch.randn(8, 16, device=GPU_TYPE) + torch._dynamo.reset() + result, code = run_and_get_code(torch.compile(fn), inp) + self.assertIn( + "sort_with_index", + " ".join(code), + "Expected Triton sort codegen for kthvalue", + ) + expected = fn(inp) + torch.testing.assert_close(result[0], expected[0]) + torch.testing.assert_close(result[1], expected[1]) + class RNNTest(TestCase): device_type = GPU_TYPE diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 38ae6dace6b92..75c432f72d5b4 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -113,6 +113,9 @@ def run(*ex, **kwargs): "test_as_strided_on_views_dynamic_shapes": TestFailure( ("cpu", "cuda", "xpu"), is_skip=True ), + "test_resize_overlapping_strides_dynamic_shapes": TestFailure( + ("cpu",), is_skip=True + ), # # Failed to find dynamic for loop variable: # @@ -128,6 +131,7 @@ def run(*ex, **kwargs): "test_arange3_dynamic_shapes": TestFailure(("cpu",)), "test_arange4_dynamic_shapes": TestFailure(("cpu",)), "test_arange6_dynamic_shapes": TestFailure(("cpu",)), + "test_arange7_dynamic_shapes": TestFailure(("cpu",)), "test_clamp_type_promotion_dynamic_shapes": TestFailure(("cpu",)), "test_conv2d_channels_last_dynamic_shapes": TestFailure(("cpu",)), "test_conv3d_dynamic_shapes": TestFailure(("cpu",)), @@ -216,12 +220,6 @@ def run(*ex, **kwargs): "test_linspace4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_logcumsumexp_dynamic_shapes": TestFailure(("cpu",)), "test_logcumsumexp_zero_dim_dynamic_shapes": TestFailure(("cpu",)), - "test_max_pool2d_with_indices_backward5_dynamic_shapes": TestFailure( - ("cpu", "cuda") - ), - "test_max_pool2d_with_indices_backward6_dynamic_shapes": TestFailure( - ("cpu", "cuda", "xpu") - ), "test_misaligned_address_issue1_dynamic_shapes": TestFailure(("cpu",)), "test_mm_views_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_new_empty_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 25575114f7b29..1144c1f9c4553 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -27,9 +27,10 @@ from torch.testing._internal.common_utils import ( IS_ARM64, IS_FBCODE, + MI350_ARCH, parametrize, serialTest, - skipIfRocm, + skipIfRocmArch, TEST_CUDA_MEM_LEAK_CHECK, TEST_WITH_ASAN, ) @@ -80,7 +81,6 @@ "test_reduction3_dynamic_shapes": TestFailure(("mps",)), "test_reduction5_dynamic_shapes": TestFailure(("mps",)), "test_roll_dynamic_shapes": TestFailure(("mps",)), - "test_select_scatter_dtype_consistency_dynamic_shapes": TestFailure(("mps",)), "test_std_dynamic_shapes": TestFailure(("mps",)), "test_var_correction_dynamic_shapes": TestFailure(("mps",)), "test_var_mean_div_by_dynamic_shapes": TestFailure(("mps",)), @@ -136,6 +136,16 @@ class DynamicShapesGPUTests(TestCase): DynamicShapesCommonTemplate, DynamicShapesGPUTests, GPU_TYPE, test_failures ) + if HAS_GPU and hasattr( + DynamicShapesGPUTests, "test_conv_with_as_strided_dynamic_shapes_cuda" + ): + # gfx950 shows a deterministic numerical mismatch for this generated test. + DynamicShapesGPUTests.test_conv_with_as_strided_dynamic_shapes_cuda = ( + skipIfRocmArch(MI350_ARCH)( + DynamicShapesGPUTests.test_conv_with_as_strided_dynamic_shapes_cuda + ) + ) + class TestInductorDynamic(TestCase): compile_fn = partial(torch.compile, dynamic=True) @@ -633,7 +643,6 @@ def f(x, w): torch.compile(fullgraph=True)(f)(x, w).sum().backward() self.assertEqual(orig_w, w.grad) - @skipIfRocm # regression in ROCm 7.2, XBLOCK should remain 64 (got 256) @torch._dynamo.config.patch( capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True ) diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 5df1aec2e1ab1..698a413222428 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -32,6 +32,7 @@ from torch.testing._internal.common_methods_invocations import op_db, skipOps from torch.testing._internal.common_utils import ( IS_CI, + IS_LINUX, IS_MACOS, IS_WINDOWS, IS_X86, @@ -232,6 +233,11 @@ def format_op(op): inductor_skips["xpu"] = {} +# torch-xpu-ops: #2956 +inductor_skips["xpu"]["lu"] = {f32} +inductor_skips["xpu"]["nn.functional.linear"] = {f16} +inductor_skips["xpu"]["masked.cumprod"] = {f16} + inductor_expected_failures_single_sample = defaultdict(dict) inductor_expected_failures_single_sample["cpu"] = { @@ -295,27 +301,6 @@ def format_op(op): i32, i64, }, # align with cuda. - ("linalg.pinv", "singular"): {f64}, - # could not create a primitive - "addmv": {f64}, - "fft.fft": {f16}, - "fft.fft2": {f16}, - "fft.fftn": {f16}, - "fft.hfft": {f16}, - "fft.hfft2": {f16}, - "fft.hfftn": {f16}, - "fft.rfft": {f16}, - "fft.rfft2": {f16}, - "fft.rfftn": {f16}, - "fft.ifft": {f16}, - "fft.ifft2": {f16}, - "fft.ifftn": {f16}, - "fft.ihfft": {f16}, - "fft.ihfft2": {f16}, - "fft.ihfftn": {f16}, - "fft.irfft": {f16}, - "fft.irfft2": {f16}, - "fft.irfftn": {f16}, } @@ -991,8 +976,33 @@ def wrapper_noop_set_seed(op, *args, **kwargs): # TODO: Fix these so strides match. inductor_skip_exact_stride = { + "complex", + "empty_permuted", + "fft.irfftn", + "fft.irfft2", + "linalg.diagonal", + "linalg.eigvals", # Fails for ROCM + "linalg.lu", + "linalg.lu_factor", + "linalg.lu_factor_ex", "linalg.matrix_norm", + "linalg.norm", + "linalg.norm.subgradients_at_zero", + "linalg.pinv.singular", + "linalg.svdvals", + "linalg.solve", + "linalg.solve_ex", + "linalg.qr", + "lu", + "matmul", + "__rmatmul__", + "nn.functional.adaptive_avg_pool1d", + "nn.functional.group_norm", + "nn.functional.linear", + "nn.functional.max_pool2d", + "nn.functional.unfold", "ormqr", + "pca_lowrank", "rot90", "sum", "tensordot", @@ -1222,6 +1232,7 @@ def tearDown(self): ) @torch._inductor.config.patch("test_configs.runtime_triton_dtype_assert", True) @torch._inductor.config.patch("test_configs.static_cpp_dtype_assert", True) + @torch._inductor.config.patch("shape_padding", False) @collection_decorator def test_comprehensive(self, device, dtype, op): device_type = torch.device(device).type @@ -1259,16 +1270,19 @@ def test_comprehensive(self, device, dtype, op): # print(f"CONSIDERING OP {op_name} on {device_type} with {dtype} | # {inductor_skips[device_type].get(op_name, set())}", flush=True) if dtype in inductor_skips[device_type].get(op_name, set()): - test_expect = ExpectedTestResult.SKIP # noqa: F841 + test_expect = ExpectedTestResult.SKIP # with open("test_output.txt", "a") as f: # print(f"SKIPPING OP {op_name} on {device_type}", flush=True, file=f) # print(f"SKIPPING OP {op_name} on {device_type}", flush=True) - elif dtype in inductor_expected_failures_single_sample[device_type].get( - op_name, set() + elif ( + device_type == "cpu" + and IS_LINUX + and dtype + in inductor_expected_failures_single_sample[device_type].get(op_name, set()) ) or dtype in inductor_gradient_expected_failures_single_sample[ device_type ].get(op_name, set()): - test_expect = ExpectedTestResult.XFAILURE # noqa: F841 + test_expect = ExpectedTestResult.XFAILURE else: test_expect = ExpectedTestResult.SUCCESS # noqa: F841 diff --git a/test/inductor/test_torchinductor_opinfo_properties.py b/test/inductor/test_torchinductor_opinfo_properties.py index f3e886382da79..362b11db6611b 100644 --- a/test/inductor/test_torchinductor_opinfo_properties.py +++ b/test/inductor/test_torchinductor_opinfo_properties.py @@ -57,7 +57,6 @@ parametrize, skipIfTorchDynamo, TEST_WITH_ASAN, - TEST_WITH_ROCM, ) from torch.testing._internal.inductor_utils import HAS_GPU @@ -242,6 +241,7 @@ def generate_sampled_pairs(num_samples, dtype, device, seed=42): "emulate_precision_casts": True, "eager_numerics.division_rounding": True, "eager_numerics.disable_ftz": True, + "eager_numerics.use_pytorch_libdevice": True, } @@ -437,29 +437,77 @@ def sample_operates_on_batch_dim(op_name, sample_input): "binary_numerical": BINARY_NUMERICAL_XFAILS, } -# Additional expected failures that only apply on ROCm. -# Same structure as the main xfail dicts: test_type -> backend -> op_name -> dtypes. -ROCM_XFAILS = { - "batch_invariance": { - "inductor_default": { - "log1p": {fp32}, - }, +ROCM_EAGER_EQUIV_XFAILS = { + "aot_eager_decomp_partition": { + "nn.functional.gelu": {fp32}, + "nn.functional.layer_norm": {fp32}, + "nn.functional.rms_norm": {fp32}, + "softmax": {fp32}, + "log_softmax": {fp32}, + }, + "inductor_default": { + "sigmoid": {fp32}, + "nn.functional.gelu": {fp32}, + "nn.functional.layer_norm": {fp32}, + "nn.functional.silu": {fp16, fp32}, + "softmax": {fp32}, + "log_softmax": {fp32}, + }, + "inductor_numerics": { + "sigmoid": {fp32}, + "sub": {ALL}, + "nn.functional.gelu": {fp32}, + "nn.functional.layer_norm": {fp32}, + "softmax": {fp32}, + "log_softmax": {fp32}, }, } +ROCM_DETERMINISM_XFAILS = {} -def is_expected_failure(device_type, op_name, backend, test_type, dtype=None): - """Check if a test is expected to fail.""" - xfails = XFAIL_DICTS.get(test_type, {}).get(backend, {}).get(op_name, set()) - is_xfail = dtype in xfails or ALL in xfails +ROCM_BATCH_INVARIANCE_XFAILS = { + "aot_eager_decomp_partition": { + "nn.functional.linear": {ALL}, + }, + "inductor_default": { + "nn.functional.linear": {ALL}, + "log1p": {fp32}, + }, + "inductor_numerics": { + "nn.functional.linear": {ALL}, + }, +} - if not is_xfail and torch.version.hip is not None: - rocm_xfails = ( - ROCM_XFAILS.get(test_type, {}).get(backend, {}).get(op_name, set()) - ) - is_xfail = dtype in rocm_xfails or ALL in rocm_xfails +ROCM_UNARY_NUMERICAL_XFAILS = { + "inductor_default": { + "log1p": {fp32}, + "rsqrt": {bf16, fp32}, + "sigmoid": {fp32}, + "sin": {fp32}, + "tan": {fp32}, + "tanh": {fp32}, + }, + "inductor_numerics": { + "sigmoid": {fp32}, + }, +} - return is_xfail +ROCM_BINARY_NUMERICAL_XFAILS = {} + +ROCM_XFAIL_DICTS = { + "eager_equivalence": ROCM_EAGER_EQUIV_XFAILS, + "determinism": ROCM_DETERMINISM_XFAILS, + "batch_invariance": ROCM_BATCH_INVARIANCE_XFAILS, + "unary_numerical": ROCM_UNARY_NUMERICAL_XFAILS, + "binary_numerical": ROCM_BINARY_NUMERICAL_XFAILS, +} + + +def is_expected_failure(device_type, op_name, backend, test_type, dtype=None): + """Check if a test is expected to fail.""" + xfail_dicts = ROCM_XFAIL_DICTS if torch.version.hip is not None else XFAIL_DICTS + xfails = xfail_dicts.get(test_type, {}).get(backend, {}).get(op_name, set()) + return dtype in xfails or ALL in xfails def compile_fn(fn, backend): @@ -476,7 +524,6 @@ def compile_fn(fn, backend): @unittest.skipIf(IS_WINDOWS, "Skipped on Windows") @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") -@unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm") @unittest.skipIf(not HAS_GPU, "Requires GPU") class TestOpInfoProperties(TestCase): """Test op properties under various inductor modes using OpInfo on CUDA.""" diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 07387bff3d2b8..961f785b28603 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -479,22 +479,22 @@ def test_pointwise_broadcast_nonzero_strides(self, prefer_nd_tiling: bool): load_lines, """\ tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), boundary_check=[0, 1]) - tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[:, None]""", # noqa: B950 + tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[YBLOCK], order=[0], offsets=[yoffset]), boundary_check=[0], eviction_policy='evict_last')[:, None]""", ) self.assertExpectedInline( store_lines, - """ tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), tl.broadcast_to(tmp2, [YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1])""", # noqa: B950 + """ tl.store(tl.make_block_ptr(out_ptr0, shape=[8, 8], strides=[8, 1], block_shape=[YBLOCK, XBLOCK], order=[1, 0], offsets=[yoffset, xoffset]), tl.broadcast_to(tmp2, [YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1])""", ) else: self.assertExpectedInline( load_lines, """\ tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0]) - tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[(7 + XBLOCK) // 8], order=[0], offsets=[xoffset // 8]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [(7 + XBLOCK) // 8, ((1) * ((1) <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950 + tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[(7 + XBLOCK) // 8], order=[0], offsets=[xoffset // 8]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [(7 + XBLOCK) // 8, ((1) * ((1) <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", ) self.assertExpectedInline( store_lines, - """ tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])""", # noqa: B950 + """ tl.store(tl.make_block_ptr(out_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), tl.broadcast_to(tmp2, [XBLOCK]).to(tl.float32), boundary_check=[0])""", ) @parametrize("prefer_nd_tiling", [False, True]) @@ -553,11 +553,6 @@ def test_reduction( view = self._discontiguous_tensor(view_size, self.device) - if num_triton_kernels == 2 and config.triton.cooperative_reductions: - # fewer kernels with cooperative reductions - num_triton_kernels = 1 - num_block_pointers -= 2 - # Expect at least 1 block pointer for the input. # Add 2 more if we generate 2 kernels. result, (code,) = self._run_and_compare( @@ -914,11 +909,15 @@ def test_welford_non_block_pointer( view = self._discontiguous_tensor((259, 311), self.device) # We expect many block pointers for this one. + cooperative_reductions = ( + config.triton.cooperative_reductions + or config.triton.force_cooperative_reductions + ) result, (code,) = self._run_and_compare( torch.var_mean, view, - expected_num_block_pointers=6, - expected_num_triton_kernels=2, + expected_num_block_pointers=0 if cooperative_reductions else 6, + expected_num_triton_kernels=1 if cooperative_reductions else 2, config_patches={"triton.prefer_nd_tiling": True}, ) @@ -1292,12 +1291,12 @@ def test_pointwise_index_order(self): load_lines, """\ tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2]) - tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2])""", # noqa: B950 + tmp1 = tl.load(tl.make_block_ptr(in_ptr1, shape=[5, 5, 5], strides=[100, 10, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), boundary_check=[0, 1, 2])""", ) self.assertExpectedInline( store_lines, - """ tl.store(tl.make_block_ptr(out_ptr0, shape=[5, 5, 5], strides=[25, 5, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), tl.broadcast_to(tmp2, [ZBLOCK, YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1, 2])""", # noqa: B950 + """ tl.store(tl.make_block_ptr(out_ptr0, shape=[5, 5, 5], strides=[25, 5, 1], block_shape=[ZBLOCK, YBLOCK, XBLOCK], order=[2, 1, 0], offsets=[zoffset, yoffset, xoffset]), tl.broadcast_to(tmp2, [ZBLOCK, YBLOCK, XBLOCK]).to(tl.float32), boundary_check=[0, 1, 2])""", ) # Check the indices. These are used for non-block pointers. @@ -1306,7 +1305,7 @@ def test_pointwise_index_order(self): """\ zindex = zoffset + tl.arange(0, ZBLOCK)[:, None, None] yindex = yoffset + tl.arange(0, YBLOCK)[None, :, None] - xindex = xoffset + tl.arange(0, XBLOCK)[None, None, :]""", # noqa: B950 + xindex = xoffset + tl.arange(0, XBLOCK)[None, None, :]""", ) def test_expand_clone_broadcast(self): diff --git a/test/inductor/test_triton_extension_backend.py b/test/inductor/test_triton_extension_backend.py index c3bd9bf14612c..734df5dad8685 100644 --- a/test/inductor/test_triton_extension_backend.py +++ b/test/inductor/test_triton_extension_backend.py @@ -14,15 +14,15 @@ try: - from extension_backends.triton.device_interface import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:device_interface # noqa: B950 + from extension_backends.triton.device_interface import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:device_interface DeviceInterface, ) - from extension_backends.triton.extension_codegen_backend import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend # noqa: B950 + from extension_backends.triton.extension_codegen_backend import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_codegen_backend CPUDeviceOpOverrides, ExtensionScheduling, ExtensionWrapperCodegen, ) - from extension_backends.triton.extension_triton_heuristics import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_triton_heuristics # noqa: B950 + from extension_backends.triton.extension_triton_heuristics import ( # @manual=fbcode//caffe2/test/inductor/extension_backends:extension_triton_heuristics EXTENSION_TRITON_META_FIELD, ) except ImportError: diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index 0d777ab25afe6..93c5366840764 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -23,12 +23,12 @@ GPU_TYPE, HAS_GPU, HAS_GPU_AND_TRITON, - requires_cuda_with_enough_memory, + requires_gpu_with_enough_memory, ) try: - import triton # noqa: F401 # @manual + import triton # @manual import triton.language as tl # @manual except ImportError: if __name__ == "__main__": @@ -129,11 +129,9 @@ def forward(primals_1, primals_2, primals_5): ] self.assertEqual(forward(*args), foo_c(*args)) - # @skipIfXpu def test_artificial_zgrid(self): self._test_artificial_zgrid() - # @skipIfXpu @config.patch("cpp_wrapper", True) def test_artificial_grid_cpp_wrapper(self): self._test_artificial_zgrid() @@ -179,7 +177,6 @@ def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): "inductor_meta": inductor_meta, } - # @skipIfXpu def test_pre_hook_assert(self): # assert if any of the configs passed to the CachingAutotuner have pre-hooks args = self._get_cos_kernel_caching_autotuner_args() @@ -273,7 +270,9 @@ def fn(x): res = torch.compile(fn)(x) self.assertEqual(ref, res) - @skipIfXpu(msg="https://github.com/intel/torch-xpu-ops/issues/2331") + @skipIfXpu( + msg="lack _get_exceeding_shared_memory_checker support - torch-xpu-ops: 2331" + ) @skipUnless(HAS_GPU_AND_TRITON, "requires gpu and triton") @parametrize("do_pruning", [False, True]) def test_prune_configs_over_shared_memory_limit(self, do_pruning): @@ -336,7 +335,7 @@ def _do_test(self, gpu_tensor): old_storage_offset = gpu_tensor.storage_offset() gpu_tensor_clone = clone_preserve_strides(gpu_tensor) - peak_mem_before = torch.cuda.max_memory_allocated() + peak_mem_before = torch.get_device_module(GPU_TYPE).max_memory_allocated() cpu_copies = autotuner.copy_args_to_cpu_if_needed(gpu_tensor) self.assertTrue(len(cpu_copies) == 1) @@ -363,21 +362,21 @@ def _do_test(self, gpu_tensor): # Avoid OOM in CI self.assertTrue(peak_mem_after < 1e10) - @requires_cuda_with_enough_memory(1e10) + @requires_gpu_with_enough_memory(1e10) def test_clone_contiguous_args(self): arg = self._create_tensor(pad=0) self.assertTrue(arg.is_contiguous()) self.assertTrue(arg.storage_offset() == 0) self._do_test(arg) - @requires_cuda_with_enough_memory(1e10) + @requires_gpu_with_enough_memory(1e10) def test_clone_non_contiguous_args(self): arg = self._create_tensor(pad=1) self.assertFalse(arg.is_contiguous()) self.assertTrue(arg.storage_offset() == 0) self._do_test(arg) - @requires_cuda_with_enough_memory(1e10) + @requires_gpu_with_enough_memory(1e10) def test_clone_args_with_non_zero_offset(self): arg = self._create_tensor(pad=1, with_offset=True) self.assertFalse(arg.is_contiguous()) @@ -694,6 +693,72 @@ def grid(meta): torch.testing.assert_close(y, expected) +class TestGridExprMaximum(TestCase): + def test_maximum_cpp_mode_casts_int_constants_to_long(self): + from torch._inductor.runtime.triton_heuristics import Grid1D + + grid = Grid1D(inductor_meta={}, mode="cpp") + # Mixed str/int: int constants must be cast to (long) for std::max + result = grid.maximum(["ynumel_0", "ynumel_1", 4480]) + self.assertIn("(long)4480", result) + self.assertIn("std::max", result) + # All strings: no cast needed + result = grid.maximum(["xnumel", "ynumel"]) + self.assertNotIn("(long)", result) + # All ints: constant-folds + self.assertEqual(grid.maximum([10, 20, 5]), 20) + + +class TestGrid2DWithYZOverflowZeroYnumel(TestCase): + """Regression test for https://github.com/pytorch/pytorch/issues/178530""" + + def test_grid2d_yz_overflow_zero_ynumel_python(self): + from torch._inductor.runtime.triton_heuristics import Grid2DWithYZOverflow + + grid = Grid2DWithYZOverflow(inductor_meta={}, mode="python") + grid.generate({"XBLOCK": 128, "YBLOCK": 128}) + # ynumel=0 must not raise ZeroDivisionError + x, y, z = grid.eval_slow( + {"xnumel": 256, "ynumel": 0, "XBLOCK": 128, "YBLOCK": 128} + ) + self.assertEqual(y, 0) + self.assertEqual(z, 0) + + def test_grid2d_yz_overflow_zero_ynumel_cpp(self): + from torch._inductor.runtime.triton_heuristics import Grid2DWithYZOverflow + + grid = Grid2DWithYZOverflow(inductor_meta={}, mode="cpp") + grid.generate({"XBLOCK": 128, "YBLOCK": 128}) + # cpp mode: the generated expression should contain a zero-guard + self.assertIn("== 0", str(grid.y_grid)) + + def test_grid2d_yz_overflow_nonzero_ynumel_unchanged(self): + from torch._inductor.runtime.triton_heuristics import Grid2DWithYZOverflow + + grid = Grid2DWithYZOverflow(inductor_meta={}, mode="python") + grid.generate({"XBLOCK": 128, "YBLOCK": 128}) + # Normal case: ynumel > 0 still works correctly + x, y, z = grid.eval_slow( + {"xnumel": 256, "ynumel": 256, "XBLOCK": 128, "YBLOCK": 128} + ) + self.assertEqual(x, 2) + self.assertEqual(y, 2) + self.assertEqual(z, 1) + + def test_grid2d_yz_overflow_large_ynumel(self): + from torch._inductor.runtime.triton_heuristics import Grid2DWithYZOverflow + + grid = Grid2DWithYZOverflow(inductor_meta={}, mode="python") + grid.generate({"XBLOCK": 128, "YBLOCK": 128}) + # Large ynumel that requires overflow splitting across y and z + x, y, z = grid.eval_slow( + {"xnumel": 128, "ynumel": 128 * 131070, "XBLOCK": 128, "YBLOCK": 128} + ) + self.assertEqual(x, 1) + # y * z must cover all y blocks + self.assertGreaterEqual(y * z, 131070) + + if __name__ == "__main__": if IS_LINUX and HAS_GPU: run_tests() diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index e72cd8b4c87fb..db7510023f663 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -100,11 +100,20 @@ def _triton_get_ast_equal_to_str(params): USE_TF32 = torch.backends.cuda.matmul.fp32_precision == "tf32" if hasattr(triton, "constexpr_function"): - + # Helper functions for triton kernels must be in globals. @triton.constexpr_function def log2(n): return len(bin(n)) - 3 + _get_int_dtype_test = triton.constexpr_function(tl.core.get_int_dtype) + + @triton.jit + def _dtype_helper_test(x): + idtype = _get_int_dtype_test( + bitwidth=x.dtype.primitive_bitwidth, signed=True + ) + return x.to(idtype, bitcast=True) + class KernelTests(torch._inductor.test_case.TestCase): def _kernel_launched_in_code(self, kernel_name: str, code: str) -> bool: @@ -1175,7 +1184,7 @@ def f(x): num_bufs_allocated = code.count(code_string) if ( inductor_config.cpp_wrapper - and inductor_config.triton.autotune_at_compile_time is False + and inductor_config.triton.autotune_at_compile_time is not True ): # Lazy compile emits aoti_torch_empty_strided for scratch space # allocation (global_scratch + profile_scratch) per unique kernel wrapper @@ -1587,6 +1596,50 @@ def f(x): self.assertIn("@triton.constexpr_function", triton_code) self.assertEqual(compiled_out, eager_out) + @unittest.skipIf( + not HAS_GPU or not hasattr(triton, "constexpr_function"), + "newer triton version required", + ) + def test_triton_kernel_with_constexpr_dtype_annotations(self): + """ + Test that constexpr functions with dtype type annotations work correctly. + This tests the fix for proper handling of: + 1. Type annotations using triton.language.core.dtype + 2. Dtype instances (int8, uint8, etc.) that lack __name__ attribute + 3. Function name aliasing + """ + + @triton.jit + def kernel_with_dtype_annotation(out_ptr, n: tl.constexpr): + offs = tl.arange(0, n) + x = tl.full([n], 1.0, dtype=tl.float32) + y = _dtype_helper_test(x) + tl.store(out_ptr + offs, y.to(tl.float32, bitcast=True)) + + def f(n): + out = torch.empty(n, device=GPU_TYPE) + kernel_with_dtype_annotation[(1,)](out, n) + return out + + n = 8 + eager_out = f(n) + compiled_out, (triton_code,) = run_and_get_code( + torch.compile(f, fullgraph=True), n + ) + + # Verify the generated code has proper imports + self.assertIn("from triton.language.core import dtype as dtype", triton_code) + self.assertIn("@triton.constexpr_function", triton_code) + self.assertIn("_get_int_dtype_test = get_int_dtype", triton_code) + # Verify dtype instances are emitted correctly + self.assertIn("int8 = tl.int8", triton_code) + self.assertIn("@triton.jit", triton_code) + self.assertIn("def _dtype_helper_test", triton_code) + + # Verify correctness + self.assertEqual(compiled_out, eager_out) + self.assertTrue(torch.all(compiled_out == 1.0).item()) + @requires_gpu def test_triton_kernel_with_imported_symbol_with_custom_name(self): @triton.jit @@ -2766,7 +2819,10 @@ def fn(sz): self.assertEqual(actual, expected) @requires_gpu - @skipIfXpu(msg="`tl.inline_asm_elementwise` is not yet supported on Intel GPUs") + @skipIfXpu( + msg="`tl.inline_asm_elementwise` is not yet supported on Intel GPUs, " + "https://github.com/pytorch/pytorch/pull/167786" + ) @inductor_config.patch({"triton.autotune_at_compile_time": True}) @parametrize("quotes", ["single", "double"]) def test_kernel_inline_asm(self, quotes): @@ -2843,6 +2899,49 @@ def fn(a, b): ) from e raise + @requires_gpu + def test_constexpr_handling(self): + @triton.jit + def copy_kernel( + src_ptr, + dst_ptr, + n_elements, + stride, + maybe_param, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < n_elements + x = tl.load(src_ptr + offs * stride, mask=mask) + scale = tl.where(maybe_param != 0, 0.5, 1.0) + x = x * scale + + tl.store(dst_ptr + offs * stride, x, mask=mask) + + t = torch.randn(1024, device=GPU_TYPE) + out = torch.empty(1024, device=GPU_TYPE) + + kwargs = { + "src_ptr": t, + "dst_ptr": out, + "n_elements": 1024, + "stride": 1, + "maybe_param": None, # semantically wrong, but testing frontend specialization + "BLOCK_SIZE": 256, + } + + ttir_module, _ = generate_ttir(copy_kernel, kwargs, tma_descriptor_metadata={}) + ttir_str = str(ttir_module) + + # `constexpr` and None values get inlined, and do not appear as function parameters. + self.assertIn("src_ptr", ttir_str) + self.assertIn("dst_ptr", ttir_str) + self.assertIn("n_elements", ttir_str) + self.assertIn("stride", ttir_str) + self.assertNotIn("BLOCK_SIZE", ttir_str) + self.assertNotIn("maybe_param", ttir_str) + def make_mutation_test(fn): @requires_gpu @@ -2878,6 +2977,26 @@ def helper_add_and_out(x, y, out_ptr): class MutationTests(torch._inductor.test_case.TestCase): # Tests injected below + # Test that a scalar args are not flagged as mutated when passed + # to a tt.call op. + @make_mutation_test + def test_scalar_via_nested_write(): + @triton.jit + def inner(ptr, scalar_offset): + tl.store(ptr + scalar_offset, 1.0) + + @triton.jit + def outer(out_ptr, n_elements): + inner(out_ptr, n_elements) + + t = torch.randn(4) + return ( + outer, + {"out_ptr": t, "n_elements": 4}, + {}, + ["out_ptr"], + ) + # Regression test for #169782 @make_mutation_test def test_with_none_arg(): @@ -3230,7 +3349,7 @@ def add_4_times_kernel( in_ptr0, in_ptr1, out_ptr, - n_elements, + n_elements: "tl.constexpr", BLOCK_SIZE: "tl.constexpr", ): pid = tl.program_id(axis=0) @@ -3239,7 +3358,7 @@ def add_4_times_kernel( mask = offsets < n_elements x = tl.load(in_ptr0 + offsets, mask=mask) y = tl.load(in_ptr1 + offsets, mask=mask) - output = tl.zeros((n_elements,), dtype=tl.float32) + output = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) for _ in range(4): output += x + y tl.store(out_ptr + offsets, output, mask=mask) @@ -3300,7 +3419,7 @@ def add_4_times_kernel( in_ptr0, in_ptr1, out_ptr, - n_elements, + n_elements: "tl.constexpr", BLOCK_SIZE: "tl.constexpr", ): pid = tl.program_id(axis=0) @@ -3309,7 +3428,7 @@ def add_4_times_kernel( mask = offsets < n_elements x = tl.load(in_ptr0 + offsets, mask=mask) y = tl.load(in_ptr1 + offsets, mask=mask) - output = tl.zeros((n_elements,), dtype=tl.float32) + output = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) for _ in range(2): for _ in range(2): output += x + y @@ -3336,7 +3455,7 @@ def add_4_times_kernel( in_ptr0, in_ptr1, out_ptr, - n_elements, + n_elements: "tl.constexpr", BLOCK_SIZE: "tl.constexpr", ): pid = tl.program_id(axis=0) @@ -3345,8 +3464,8 @@ def add_4_times_kernel( mask = offsets < n_elements x = tl.load(in_ptr0 + offsets, mask=mask) y = tl.load(in_ptr1 + offsets, mask=mask) - output1 = tl.zeros((n_elements,), dtype=tl.float32) - output2 = tl.zeros((n_elements,), dtype=tl.float32) + output1 = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + output2 = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) for _ in range(2): for _ in range(2): output1 += y @@ -3965,6 +4084,9 @@ def add_kernel_descriptor_method( while name in MutationTests.__dict__: name += "1" + if kernel.fn.__name__ == "add_kernel_2d_autotuned": + fn = unittest.skip("Fails with Triton update")(fn) + setattr(MutationTests, name, fn) @@ -4829,6 +4951,118 @@ def f(dst, src, add_float, N): self.assertEqual(counter.op_count, 2) + @requires_gpu + @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) + def test_triton_kernel_prune_configs_by_called_twice(self, backend): + def early_config_prune(configs, named_args, **kwargs): + return [configs[0]] + + @triton.autotune( + configs=[ + triton.Config(kwargs={"BLOCK_SIZE": 128}), + triton.Config(kwargs={"BLOCK_SIZE": 256}), + ], + key=["N"], + prune_configs_by={"early_config_prune": early_config_prune}, + ) + @triton.jit + def add_kernel( + x_ptr, + y_ptr, + out_ptr, + N, + BLOCK_SIZE: tl.constexpr, + ): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, x + y, mask=mask) + + def call_kernel(x, y): + out = torch.empty_like(x) + n_elements = x.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + add_kernel[grid](x, y, out, N=n_elements) + return out + + @torch.compile(fullgraph=True, backend=backend) + def f(x, y): + out = call_kernel(x, y) + return call_kernel(out, y) + + x = torch.randn(1024, device=GPU_TYPE) + y = torch.randn(1024, device=GPU_TYPE) + + self.assertEqual(f(x, y), x + y + y) + + @requires_gpu + @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) + def test_triton_kernel_heuristics_and_prune_configs_by(self, backend): + def noop_prune(configs, named_args, **kwargs): + return configs + + @triton.autotune( + configs=[ + triton.Config({"BLOCK": 128}), + triton.Config({"BLOCK": 256}), + ], + key=["N"], + prune_configs_by={"early_config_prune": noop_prune}, + ) + @triton.heuristics({"EVEN": lambda args: args["N"] % 128 == 0}) + @triton.jit + def kernel(x_ptr, out_ptr, N, BLOCK: tl.constexpr, EVEN: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < N + tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2, mask=mask) + + def f(x): + out = torch.empty_like(x) + kernel[(triton.cdiv(x.numel(), 128),)](x, out, x.numel()) + return out + + compiled_f = torch.compile(f, backend=backend, fullgraph=True) + x = torch.randn(1024, device=GPU_TYPE) + self.assertEqual(compiled_f(x), x * 2) + + @requires_gpu + @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) + def test_triton_kernel_run_with_prune_configs_by(self, backend): + def noop_prune(configs, named_args, **kwargs): + return configs + + @triton.autotune( + configs=[ + triton.Config({"BLOCK": 64}), + triton.Config({"BLOCK": 128}), + ], + key=["N"], + prune_configs_by={"early_config_prune": noop_prune}, + ) + @triton.jit + def kernel(x_ptr, out_ptr, N, BLOCK: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < N + tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask) + 1, mask=mask) + + def f(x): + out = torch.empty_like(x) + kernel.run( + x, + out, + x.numel(), + grid=(triton.cdiv(x.numel(), 64),), + warmup=False, + ) + return out + + compiled_f = torch.compile(f, backend=backend, fullgraph=True) + x = torch.randn(1024, device=GPU_TYPE) + self.assertEqual(compiled_f(x), x + 1) + # see: https://github.com/triton-lang/triton/blob/67ea999935f4511a535a25bdecb27e79e3c3af41/python/test/unit/language/test_decorator.py#L31 @requires_gpu @common_utils.parametrize("non_strict", [True, False]) @@ -5037,6 +5271,62 @@ def fn(a): self.assertEqual(out, fn(a), atol=0.05, rtol=0.05) self.check_code(code[0], num_kernels=1, num_allocs=1, num_deallocs=1) + @requires_cuda_and_triton + def test_fusion_with_ordering_constraints(self): + @triton.jit + def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < n_elements + x = tl.load(in_ptr0 + offs, mask=mask) + y = tl.load(in_ptr1 + offs, mask=mask) + tl.store(out_ptr + offs, x + y, mask=mask) + + def add(a, b): + out = torch.empty_like(a) + add_kernel[(a.numel(),)](a, b, out, a.numel(), BLOCK_SIZE=1) + return out + + def fn(a, b): + c = add(a, b) + c = add(c, b) + return c.relu() + + a = torch.randn(10, device="cuda") + b = torch.randn(10, device="cuda") + + out, code = run_and_get_code(torch.compile(fn), a, b) + self.assertEqual(out, fn(a, b)) + self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=3) + + @requires_cuda_and_triton + def test_fusion_cache(self): + @triton.jit + def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < n_elements + x = tl.load(in_ptr0 + offs, mask=mask) + y = tl.load(in_ptr1 + offs, mask=mask) + tl.store(out_ptr + offs, x + y, mask=mask) + + def add(a, b): + out = torch.empty_like(a) + add_kernel[(a.numel(),)](a, b, out, a.numel(), BLOCK_SIZE=1) + return out + + def fn(a, b): + c = add(a, b).sigmoid() + c = add(c, b).relu() + return c + + a = torch.randn(10, device="cuda") + b = torch.randn(10, device="cuda") + + out, code = run_and_get_code(torch.compile(fn), a, b) + self.assertEqual(out, fn(a, b)) + self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=3) + @requires_cuda_and_triton def test_fusion_custom_kernel_with_linebreaks(self): # we do AST manipulation / string manipulation of the kernel source code @@ -5276,6 +5566,32 @@ def fn(a, b): self.assertEqual(out, fn(a, b), atol=0.05, rtol=0.05) self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=3) + @requires_cuda_and_triton + def test_no_fusion_non_unary_epilogue(self): + @triton.jit + def add_kernel(a_ptr, b_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < numel + a = tl.load(a_ptr + offs, mask=mask) + b = tl.load(b_ptr + offs, mask=mask) + out = a + b + tl.store(out_ptr + offs, out, mask=mask) + + def fn(a, b, c): + out = torch.empty_like(a) + GRID = (a.numel(),) + add_kernel[GRID](a, b, out, a.numel(), 1) + return out + c + + a = torch.randn(10, dtype=torch.float32, device="cuda") + b = torch.randn(10, dtype=torch.float32, device="cuda") + c = torch.randn(10, dtype=torch.float32, device="cuda") + + out, code = run_and_get_code(torch.compile(fn), a, b, c) + self.assertEqual(out, fn(a, b, c)) + self.check_code(code[0], num_kernels=2, num_allocs=2, num_deallocs=4) + if HAS_CUDA_AND_TRITON: diff --git a/test/inductor/test_triton_wrapper.py b/test/inductor/test_triton_wrapper.py index b5e822fe4b3a7..dd44ff29a201a 100644 --- a/test/inductor/test_triton_wrapper.py +++ b/test/inductor/test_triton_wrapper.py @@ -6,7 +6,7 @@ import sys import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch._inductor.codecache import PyCodeCache from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU @@ -63,7 +63,7 @@ def f(x, y): N = 10 x = torch.rand(N).to(device=GPU_TYPE) y = torch.rand(N).to(device=GPU_TYPE) - f(x, y) # noqa: F841 + f(x, y) compiled_module = self.get_compiled_module() diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index a6cfcfc09514b..6c7a34732c5f4 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -4,6 +4,7 @@ import torch from torch._dynamo import config as dynamo_config +from torch._dynamo.exc import InternalTorchDynamoError from torch._inductor import config as inductor_config from torch._inductor.test_case import TestCase as InductorTestCase from torch.testing import make_tensor @@ -311,6 +312,23 @@ def fn(x, a, b): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) + @skipGPUIf(not HAS_GPU, "requires gpu and triton") + @dynamo_config.patch( + {"capture_scalar_outputs": True, "capture_dynamic_output_shape_ops": True} + ) + def test_repeat_interleave_with_unbacked_scalar(self, device): + def fn(x, repeats): + return x.repeat_interleave(repeats.item()) + + example_inputs = ( + torch.arange(4, device=device), + torch.scalar_tensor(3, dtype=torch.int64, device=device), + ) + + actual = torch.compile(fn, fullgraph=True, dynamic=True)(*example_inputs) + expected = fn(*example_inputs) + torch.testing.assert_close(actual, expected) + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_scalar_outputs": True}) @parametrize("dynamic", [False, True, None]) @@ -499,9 +517,9 @@ def fn(x): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) + @skipIfXpu(msg="FlashAttentionForward headdim limitation on xpu") @skipGPUIf(not HAS_GPU, "requires gpu and triton") @skipCUDAIf(not SM80OrLater, "Requires sm80 or later.") - @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet") @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) def test_sdpfa(self, device): if device == "cpu": @@ -526,9 +544,9 @@ def fn(x): x = torch.tensor([1.0, 0.0, 1.0, 0.0], device=device) torch.compile(fn, fullgraph=True)(x) + @skipIfXpu(msg="FlashAttentionForward headdim limitation on xpu") @skipGPUIf(not HAS_GPU, "requires gpu and triton") @skipCUDAIf(not SM80OrLater, "Requires sm80 or later.") - @skipIfXpu(msg="scaled_dot_product_attention is not supported on XPU yet") @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) def test_sdfpa_unbacked_strides(self, device): if device == "cpu": @@ -720,6 +738,41 @@ def fn(x): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) + @skipCPUIf(True, "Triton codegen bug only affects GPU") + @skipGPUIf(not HAS_GPU, "requires gpu and triton") + @dynamo_config.patch({"capture_scalar_outputs": True}) + def test_triton_trunc_large_float_scalar_tensor(self, device): + import math + + def fn(x): + r = math.sqrt(x.size(0)) + r = r**70 + return torch.tensor(math.trunc(r), dtype=torch.float64, device=device) + + example_inputs = (torch.randn(4, device=device),) + torch._dynamo.mark_dynamic(example_inputs[0], 0) + actual = torch.compile(fn, fullgraph=True, dynamic=True)(*example_inputs) + expected = fn(*example_inputs) + torch.testing.assert_close(actual, expected) + + @skipCPUIf(True, "Triton codegen bug only affects GPU") + @skipGPUIf(not HAS_GPU, "requires gpu and triton") + @dynamo_config.patch({"capture_scalar_outputs": True}) + def test_triton_trunc_float_scalar_tensor_preserves_positive_zero(self, device): + import math + + def fn(x): + r = math.sqrt(x.size(0)) - 2.5 + return torch.signbit( + torch.tensor(math.trunc(r), dtype=torch.float64, device=device) + ) + + example_inputs = (torch.randn(4, device=device),) + torch._dynamo.mark_dynamic(example_inputs[0], 0) + actual = torch.compile(fn, fullgraph=True, dynamic=True)(*example_inputs) + expected = fn(*example_inputs) + self.assertEqual(actual, expected) + @skipCPUIf(True, "Triton codegen bug only affects GPU") @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_scalar_outputs": True}) @@ -773,6 +826,188 @@ def fn(x, exponent_src): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) + @skipGPUIf(not HAS_GPU, "requires gpu and triton") + @dynamo_config.patch({"capture_scalar_outputs": True}) + def test_slice_unbacked_bindings_with_later_constraint(self, device): + # Regression test for https://github.com/pytorch/pytorch/issues/166460 + # When slicing with an unbacked symint end index (e.g. x[:total] where + # total = sum of data-dependent values), dynamo may allocate a fresh + # unbacked symbol for the output size because it can't prove bounds at + # trace time. Later operations (e.g. new_zeros with padding) may + # establish constraints that make the bounds provable at inductor time. + # Inductor must still define the unbacked symbol even when taking the + # efficient SliceView path. + def fn(x, sizes): + sizes_list = sizes.tolist() + total = sum(sizes_list) + num_padding = x.shape[0] - total + sliced = x[:total] + splits = torch.split(sliced, sizes_list, dim=0) + out = torch.cat([s * 2 for s in splits], dim=0) + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + return out + + example_inputs = ( + torch.randn(64, 16, device=device, dtype=torch.bfloat16), + torch.tensor([8, 8], device=device, dtype=torch.int64), + ) + + actual = torch.compile(fn, fullgraph=True)(*example_inputs) + expected = fn(*example_inputs) + torch.testing.assert_close(actual, expected) + + @dynamo_config.patch({"capture_scalar_outputs": True}) + def test_override_optimization_hint_eager(self, device): + """Test that override_optimization_hint updates var_to_hint_override eagerly.""" + t = torch.tensor([5], device=device) + torch._dynamo.decorators.mark_unbacked(t, 0) + + def fn(x): + u = x.item() + torch._dynamo.override_optimization_hint(u, 42) + return u + 1 + + result = fn(t) + self.assertEqual(result, 6) + + @dynamo_config.patch({"capture_scalar_outputs": True}) + def test_override_optimization_hint_compiled(self, device): + """Test override_optimization_hint inside a compiled function with fullgraph=True.""" + + def fn(x): + u = x.item() + torch._dynamo.override_optimization_hint(u, 42) + return u + 1 + + t = torch.tensor([5], device=device) + compiled_fn = torch.compile(fn, fullgraph=True) + result = compiled_fn(t) + self.assertEqual(result, 6) + + @dynamo_config.patch({"capture_scalar_outputs": True}) + def test_override_optimization_hint_compiled_tolist(self, device): + """Test that override_optimization_hint is a no-op on concrete ints from tolist().""" + + def fn(x): + vals = x.tolist() + for v in vals: + torch._dynamo.override_optimization_hint(v, 99) + return x.sum() + + t = torch.tensor([3, 4, 5], device=device) + compiled_fn = torch.compile(fn, fullgraph=True) + result = compiled_fn(t) + self.assertEqual(result, t.sum()) + + @dynamo_config.patch({"capture_scalar_outputs": True}) + def test_override_optimization_hint_multiple_items(self, device): + """Test override_optimization_hint on multiple unbacked symbols from separate .item() calls.""" + + def fn(x, y): + u = x.item() + v = y.item() + torch._dynamo.override_optimization_hint(u, 128) + torch._dynamo.override_optimization_hint(v, 256) + return u + v + + tx = torch.tensor([10], device=device) + ty = torch.tensor([20], device=device) + compiled_fn = torch.compile(fn, fullgraph=True) + result = compiled_fn(tx, ty) + self.assertEqual(result, 30) + + def test_override_optimization_hint_concrete_int_noop(self, device): + """Test that override_optimization_hint on a plain int is a no-op.""" + torch._dynamo.override_optimization_hint(42, 100) + + def test_override_optimization_hint_rejects_wrong_type(self, device): + """Test that override_optimization_hint raises TypeError on non-int/non-SymInt.""" + with self.assertRaisesRegex(TypeError, "expects a torch.SymInt or int"): + torch._dynamo.override_optimization_hint(3.14, 100) + with self.assertRaisesRegex(TypeError, "expects a torch.SymInt or int"): + torch._dynamo.override_optimization_hint("hello", 100) + + def test_override_optimization_hint_rejects_non_int_val(self, device): + """Test that override_optimization_hint rejects non-int val.""" + with self.assertRaisesRegex(TypeError, "val to be an int"): + torch._dynamo.override_optimization_hint(42, 3.14) + with self.assertRaisesRegex(TypeError, "val to be an int"): + torch._dynamo.override_optimization_hint(42, "hello") + + def test_override_optimization_hint_rejects_derived_expression(self, device): + """Test that override_optimization_hint rejects derived expressions like u0 + 1.""" + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + shape_env = ShapeEnv() + u = shape_env.create_unbacked_symint() + v = u + 1 # derived expression: u0 + 1 + with self.assertRaisesRegex(ValueError, "single unbacked symbol"): + torch._dynamo.override_optimization_hint(v, 42) + + def test_override_optimization_hint_rejects_backed_symbol(self, device): + """Test that override_optimization_hint rejects backed (non-unbacked) symbols.""" + + def fn(t): + s = t.size(0) # backed symbol inside compile + torch._dynamo.override_optimization_hint(s, 42) + return t.sum() + + t = torch.randn(5, device=device) + torch._dynamo.mark_dynamic(t, 0) + with self.assertRaisesRegex( + InternalTorchDynamoError, + "expects an unbacked symbol", + ): + torch.compile(fn, fullgraph=True)(t) + + @dynamo_config.patch({"capture_scalar_outputs": True}) + def test_override_optimization_hint_in_fx_pass(self, device): + """Test using override_optimization_hint in a custom FX pass (backend). + + This shows the intended use case: a custom backend or FX pass walks + the graph, finds unbacked symbols from .item() calls, and sets + optimization hints on them before handing the graph to inductor. + """ + + def fx_pass_backend(gm, example_inputs): + # Walk the graph and set hints on unbacked SymInts + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and node.target is torch.ops.aten.item.default + ): + # The node's example value is a SymInt from .item() + sym_val = node.meta.get("example_value", None) + if sym_val is not None and isinstance(sym_val, torch.SymInt): + torch._dynamo.override_optimization_hint(sym_val, 512) + + # Verify the hint was set on shape_env + shape_env = None + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and node.target is torch.ops.aten.item.default + ): + sym_val = node.meta.get("example_value", None) + if sym_val is not None and isinstance(sym_val, torch.SymInt): + shape_env = sym_val.node.shape_env + expr = sym_val.node.expr + if expr not in shape_env.var_to_hint_override: + raise AssertionError("hint not set on shape_env") + if shape_env.var_to_hint_override[expr] != 512: + raise AssertionError("hint value mismatch") + + return gm + + def fn(x): + u = x.item() + return u + 1 + + t = torch.tensor([7], device=device) + compiled_fn = torch.compile(fn, backend=fx_pass_backend, fullgraph=True) + result = compiled_fn(t) + self.assertEqual(result, 8) + instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True) diff --git a/test/inductor/test_user_streams.py b/test/inductor/test_user_streams.py index d86915c6c875d..bccd1acfe0124 100644 --- a/test/inductor/test_user_streams.py +++ b/test/inductor/test_user_streams.py @@ -11,6 +11,7 @@ import unittest import torch +import torch._inductor.config as inductor_config import torch._inductor.metrics from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm from torch._inductor.codegen.wrapper import ( @@ -27,8 +28,12 @@ from torch._inductor.stream_utils import get_stream_name from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import IndentedBuffer -from torch.testing._internal.common_cuda import TEST_CUDA -from torch.testing._internal.common_utils import instantiate_parametrized_tests +from torch.testing import FileCheck +from torch.testing._internal.common_cuda import SM90OrLater, TEST_CUDA +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + TEST_WITH_ROCM, +) def _extract_wrapper_body(code): @@ -217,9 +222,9 @@ def fn(x, y): # Verify correctness self.assertEqual(result, expected) - # Verify generated code contains stream handling - # Streams are acquired from a pool, so check for pool usage or context manager + # Verify generated code contains stream handling and synchronize survives self.assertIn("torch.cuda.stream", code) + self.assertIn("synchronize_stream", code) def test_compile_preserves_stream_semantics(self): """Test that compiled code preserves stream execution semantics.""" @@ -243,8 +248,9 @@ def fn(x): self.assertEqual(result, expected) - # Verify stream context is present in generated code + # Verify stream context and synchronize survive self.assertIn("torch.cuda.stream", code) + self.assertIn("synchronize_stream", code) def test_multiple_stream_contexts(self): """Test compilation with multiple stream context switches.""" @@ -274,12 +280,9 @@ def fn(x): self.assertEqual(result, expected) - # Verify multiple stream contexts in generated code - # The scheduler may optimize stream usage; check for at least 1 stream context - self.assertTrue( - code.count("torch.cuda.stream") >= 1 or "stream" in code.lower(), - "Expected stream context in generated code", - ) + # Verify stream contexts and synchronization survive + self.assertGreaterEqual(code.count("torch.cuda.stream"), 1) + self.assertIn("synchronize_stream", code) def test_nested_stream_contexts(self): """Test compilation with nested stream contexts.""" @@ -307,12 +310,9 @@ def fn(x): self.assertEqual(result, expected) - # Verify nested stream contexts - # The scheduler may optimize stream usage; check for at least 1 stream context - self.assertTrue( - code.count("torch.cuda.stream") >= 1 or "stream" in code.lower(), - "Expected stream context in generated code", - ) + # Verify nested stream contexts and synchronization survive + self.assertGreaterEqual(code.count("torch.cuda.stream"), 1) + self.assertIn("synchronize_stream", code) def test_stream_context_with_data_dependency(self): """Test stream contexts with data flowing between streams.""" @@ -339,8 +339,9 @@ def fn(x): self.assertEqual(result, expected) - # Verify stream context is present + # Verify stream context and synchronize survive self.assertIn("torch.cuda.stream", code) + self.assertIn("synchronize_stream", code) def test_event_record_and_wait(self): """Test compilation with explicit event record and wait.""" @@ -371,17 +372,9 @@ def fn(x): self.assertEqual(result, expected) - # Verify event operations in generated code - # Events may be generated as custom ops (torch.ops.streams.record_event/wait_event) - # or as internal event methods (.record_event()/.wait()) - self.assertTrue( - "record_event" in code or ".record(" in code, - "Expected record_event or .record( in generated code", - ) - self.assertTrue( - "wait_event" in code or ".wait(" in code, - "Expected wait_event or .wait( in generated code", - ) + # Verify event operations survive compilation as custom ops + self.assertIn("record_event", code) + self.assertIn("wait_event", code) def test_event_record_on_stream(self): """Test event recording on a specific stream.""" @@ -414,16 +407,9 @@ def fn(x): self.assertEqual(result, expected) - # Verify event record/wait with explicit stream args - # Events may be generated as custom ops or internal event methods - self.assertTrue( - "record_event" in code or ".record(" in code, - "Expected record_event or .record( in generated code", - ) - self.assertTrue( - "wait_event" in code or ".wait(" in code, - "Expected wait_event or .wait( in generated code", - ) + # Verify event operations survive compilation as custom ops + self.assertIn("record_event", code) + self.assertIn("wait_event", code) def test_multiple_events_multiple_streams(self): """Test multiple events synchronizing multiple streams.""" @@ -462,12 +448,9 @@ def fn(x): self.assertEqual(result, expected) - # Verify multiple events and streams - # Events may be internally managed, not explicitly constructed - record_count = code.count("record_event") + code.count(".record(") - wait_count = code.count("wait_event") + code.count(".wait(") - self.assertGreaterEqual(record_count, 2) - self.assertGreaterEqual(wait_count, 2) + # Verify multiple events and streams survive as custom ops + self.assertGreaterEqual(code.count("record_event"), 2) + self.assertGreaterEqual(code.count("wait_event"), 2) def test_event_wait_without_record(self): """Test that waiting on unrecorded event works (no-op).""" @@ -496,15 +479,9 @@ def fn(x): self.assertEqual(result, expected) - # Verify event operations (may appear as custom ops or methods) - self.assertTrue( - "record_event" in code or ".record(" in code, - "Expected record_event or .record( in generated code", - ) - self.assertTrue( - "wait_event" in code or ".wait(" in code, - "Expected wait_event or .wait( in generated code", - ) + # Verify event operations survive compilation as custom ops + self.assertIn("record_event", code) + self.assertIn("wait_event", code) def test_stream_wait_event(self): """Test stream.wait_event() method.""" @@ -533,11 +510,8 @@ def fn(x): self.assertEqual(result, expected) - # Verify stream.wait_event is present (may appear as custom ops or methods) - self.assertTrue( - "wait_event" in code or ".wait(" in code, - "Expected wait_event or .wait( in generated code", - ) + # Verify stream.wait_event survives compilation as custom op + self.assertIn("wait_event", code) def test_bidirectional_stream_sync(self): """Test bidirectional synchronization between streams.""" @@ -577,13 +551,9 @@ def fn(x): self.assertEqual(result, expected) - # Verify bidirectional sync - multiple records and waits - # These may appear as custom ops (torch.ops.streams.record_event/wait_event) - # or as internal event methods (.record_event()/.wait()) - record_count = code.count("record_event") + code.count(".record(") - wait_count = code.count("wait_event") + code.count(".wait(") - self.assertGreaterEqual(record_count, 2) - self.assertGreaterEqual(wait_count, 2) + # Verify bidirectional sync - multiple records and waits as custom ops + self.assertGreaterEqual(code.count("record_event"), 2) + self.assertGreaterEqual(code.count("wait_event"), 2) def test_three_streams_pipeline(self): """Test pipeline pattern with three streams.""" @@ -626,13 +596,9 @@ def fn(x): self.assertEqual(result, expected) # Verify three-stage pipeline with 3 streams - # Streams may be managed via pool, check for stream usage pattern - self.assertTrue( - code.count("torch.cuda.stream") >= 3 or "stream" in code.lower(), - "Expected stream context in generated code", - ) - record_count = code.count("record_event") + code.count(".record(") - self.assertGreaterEqual(record_count, 2) + self.assertGreaterEqual(code.count("torch.cuda.stream"), 3) + self.assertGreaterEqual(code.count("record_event"), 2) + self.assertGreaterEqual(code.count("wait_event"), 2) def test_parallel_streams_join(self): """Test parallel work on multiple streams joining at the end.""" @@ -679,15 +645,9 @@ def fn(x): self.assertEqual(result, expected) # Verify parallel streams joining - # Streams may be managed via pool, check for stream usage pattern - self.assertTrue( - code.count("torch.cuda.stream") >= 3 or "stream" in code.lower(), - "Expected stream context in generated code", - ) - record_count = code.count("record_event") + code.count(".record(") - wait_count = code.count("wait_event") + code.count(".wait(") - self.assertGreaterEqual(record_count, 3) - self.assertGreaterEqual(wait_count, 3) + self.assertGreaterEqual(code.count("torch.cuda.stream"), 3) + self.assertGreaterEqual(code.count("record_event"), 3) + self.assertGreaterEqual(code.count("wait_event"), 3) def test_fan_out_fan_in(self): """Test fan-out from one stream to multiple, then fan-in.""" @@ -733,10 +693,8 @@ def fn(x): self.assertEqual(result, expected) # Verify fan-out/fan-in pattern - record_count = code.count("record_event") + code.count(".record(") - wait_count = code.count("wait_event") + code.count(".wait(") - self.assertGreaterEqual(record_count, 3) - self.assertGreaterEqual(wait_count, 4) + self.assertGreaterEqual(code.count("record_event"), 3) + self.assertGreaterEqual(code.count("wait_event"), 4) def test_four_streams_diamond(self): """Test diamond pattern: one start, two parallel, one end.""" @@ -785,15 +743,9 @@ def fn(x): self.assertEqual(result, expected) # Verify diamond pattern - # Streams may be managed via pool, check for stream usage pattern - self.assertTrue( - code.count("torch.cuda.stream") >= 3 or "stream" in code.lower(), - "Expected stream context in generated code", - ) - record_count = code.count("record_event") + code.count(".record(") - wait_count = code.count("wait_event") + code.count(".wait(") - self.assertGreaterEqual(record_count, 3) - self.assertGreaterEqual(wait_count, 4) + self.assertGreaterEqual(code.count("torch.cuda.stream"), 3) + self.assertGreaterEqual(code.count("record_event"), 3) + self.assertGreaterEqual(code.count("wait_event"), 4) def test_stream_reuse_across_iterations(self): """Test that streams can be reused across loop iterations.""" @@ -821,16 +773,10 @@ def fn(x): self.assertEqual(result, expected) - # Verify stream reuse in loop + # Verify stream reuse in loop — events survive compilation self.assertIn("torch.cuda.stream", code) - self.assertTrue( - "record_event" in code or ".record(" in code, - "Expected record_event or .record( in generated code", - ) - self.assertTrue( - "wait_event" in code or ".wait(" in code, - "Expected wait_event or .wait( in generated code", - ) + self.assertIn("record_event", code) + self.assertIn("wait_event", code) def test_no_fusion_across_streams(self): """Test that operations on different streams are not fused together.""" @@ -871,7 +817,9 @@ def fn(x): self.assertEqual(result, expected) - self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) + # 3 kernels: s1 pointwise, s2 pointwise, and the final add on + # the default stream (which is a third stream context). + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 3) def test_no_fusion_across_streams_with_dependency(self): """Test no fusion when there's a data dependency across streams.""" @@ -935,6 +883,256 @@ def fn(x): # All pointwise ops on same stream should fuse into 1 kernel self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1) + def test_no_fusion_simple_dependency_across_streams(self): + """Regression: a single pointwise consumed across a stream boundary must not fuse.""" + from torch._inductor.utils import run_and_get_code + + def fn(x): + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + + with torch.cuda.stream(s1): + a = x + 1 + e = s1.record_event() + s2.wait_event(e) + with torch.cuda.stream(s2): + b = a * 2 + s1.synchronize() + s2.synchronize() + return b + + x = torch.randn(1024, device="cuda") + + expected = fn(x) + compiled_fn = torch.compile(fn) + torch._inductor.metrics.reset() + result, (code,) = run_and_get_code(compiled_fn, x) + + self.assertEqual(result, expected) + + # Must be 2 separate kernels on 2 streams, not fused into 1 + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) + + @torch._inductor.config.patch(combo_kernels=True) + def test_no_combo_kernel_fusion_across_streams(self): + """Combo kernels must not group nodes on different streams.""" + from torch._inductor.utils import run_and_get_code + + def fn(x, y, z, w): + s = torch.cuda.Stream() + event = torch.cuda.Event() + + # Independent pointwise ops on different streams at the same + # topological level — combo kernels must not merge them. + a = x + y + event.record() + with torch.cuda.stream(s): + event.wait() + b = z + w + s.synchronize() + return a, b + + x = torch.randn(1024, device="cuda") + y = torch.randn(1024, device="cuda") + z = torch.randn(1024, device="cuda") + w = torch.randn(1024, device="cuda") + + expected = fn(x, y, z, w) + compiled_fn = torch.compile(fn) + torch._inductor.metrics.reset() + result, (code,) = run_and_get_code(compiled_fn, x, y, z, w) + + self.assertEqual(result, expected) + # 2 kernels: one per stream. Without the stream-aware fix, combo + # kernels would merge them into 1. + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) + + @torch._inductor.config.patch(combo_kernels=True) + def test_combo_kernel_fusion_within_same_stream(self): + """Combo kernels should still group independent nodes on the same stream.""" + from torch._inductor.utils import run_and_get_code + + def fn(x, y): + s = torch.cuda.Stream() + + with torch.cuda.stream(s): + # Two independent pointwise ops on the same stream — eligible + # for combo kernel fusion. + a = x * 2 + b = y * 3 + + s.synchronize() + return a + b + + x = torch.randn(1024, device="cuda") + y = torch.randn(1024, device="cuda") + + expected = fn(x, y) + compiled_fn = torch.compile(fn) + torch._inductor.metrics.reset() + result, (code,) = run_and_get_code(compiled_fn, x, y) + + self.assertEqual(result, expected) + # With combo kernels, the two independent ops on the same stream + # should be combined, yielding fewer kernels than without. + self.assertLessEqual(torch._inductor.metrics.generated_kernel_count, 2) + + def test_cross_stream_stride_copy(self): + """A contiguous copy forced by a non-contiguous slice across streams + must run on the consumer's stream, not the producer's.""" + from torch._inductor.utils import run_and_get_code + + def fn(x): + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + + with torch.cuda.stream(s1): + a = x + 1 + b = a[:, ::2] # non-contiguous slice + e = s1.record_event() + s2.wait_event(e) + with torch.cuda.stream(s2): + c = b.contiguous() + d = c + 1 + s2.synchronize() + return d + + x = torch.randn(64, 64, device="cuda") + + expected = fn(x) + compiled_fn = torch.compile(fn) + torch._inductor.metrics.reset() + result, (code,) = run_and_get_code(compiled_fn, x) + + self.assertEqual(result, expected) + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2) + + # Verify: s1 gets the pointwise (x+1), s2 gets the fused copy + add. + # The contiguous copy is fused into the s2 triton kernel (which reads + # from the s1 output buffer with strided indexing). If the copy were + # incorrectly placed on s1, we'd see 2 kernels on s1 instead of 1. + wrapper = _extract_wrapper_body(code) + lines = wrapper.split("\n") + current_stream = None + stream_kernels: dict[str | None, list[str]] = {} + for line in lines: + stripped = line.strip() + if "with torch.cuda.stream(" in stripped: + if "stream1" in stripped: + current_stream = "s1" + elif "stream2" in stripped: + current_stream = "s2" + elif "default_stream" in stripped: + current_stream = "default" + elif ".run(" in stripped: + stream_kernels.setdefault(current_stream, []).append(stripped) + + self.assertEqual( + len(stream_kernels.get("s1", [])), + 1, + f"Expected 1 kernel on s1, got: {stream_kernels}", + ) + self.assertEqual( + len(stream_kernels.get("s2", [])), + 1, + f"Expected 1 kernel on s2, got: {stream_kernels}", + ) + + def test_no_buffer_reuse_across_streams(self): + """Buffer produced on one stream must not be reused in-place on another.""" + from torch._inductor.utils import run_and_get_code + + def fn(x): + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + with torch.cuda.stream(s1): + a = x + 1 + e = s1.record_event() + s2.wait_event(e) + with torch.cuda.stream(s2): + b = a + 2 + s2.synchronize() + return b + + x = torch.randn(1024, device="cuda") + expected = fn(x) + compiled_fn = torch.compile(fn) + result, (code,) = run_and_get_code(compiled_fn, x) + self.assertEqual(result, expected) + # The second kernel should allocate a fresh buffer, not reuse + # the one produced on the other stream + wrapper = _extract_wrapper_body(code) + self.assertIn("record_event", wrapper) + self.assertIn("wait_event", wrapper) + self.assertNotIn("buf0; del buf0", wrapper) + + def test_stream_record_wait_event_not_dropped(self): + """stream.record_event() and stream.wait_event() must survive compilation.""" + from torch._inductor.utils import run_and_get_code + + def fn(x): + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + with torch.cuda.stream(s1): + a = x + 1 + e = s1.record_event() + s2.wait_event(e) + with torch.cuda.stream(s2): + b = a * 2 + s2.synchronize() + return b + + x = torch.randn(1024, device="cuda") + expected = fn(x) + compiled_fn = torch.compile(fn) + result, (code,) = run_and_get_code(compiled_fn, x) + self.assertEqual(result, expected) + self.assertIn("record_event", code) + self.assertIn("wait_event", code) + self.assertIn("synchronize_stream", code) + + def test_stream_synchronize_not_dropped(self): + """stream.synchronize() must survive compilation and appear in wrapper code.""" + from torch._inductor.utils import run_and_get_code + + def fn(x): + s = torch.cuda.Stream() + with torch.cuda.stream(s): + a = x + 1 + s.synchronize() + b = a * 2 + return b + + x = torch.randn(1024, device="cuda") + expected = fn(x) + compiled_fn = torch.compile(fn) + result, (code,) = run_and_get_code(compiled_fn, x) + self.assertEqual(result, expected) + self.assertIn("synchronize_stream", code) + + def test_stream_wait_stream_not_dropped(self): + """stream.wait_stream() must survive compilation and appear in wrapper code.""" + from torch._inductor.utils import run_and_get_code + + def fn(x): + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + with torch.cuda.stream(s1): + a = x + 1 + s2.wait_stream(s1) + with torch.cuda.stream(s2): + b = a * 2 + s2.synchronize() + return b + + x = torch.randn(1024, device="cuda") + expected = fn(x) + compiled_fn = torch.compile(fn) + result, (code,) = run_and_get_code(compiled_fn, x) + self.assertEqual(result, expected) + self.assertIn("wait_stream", code) + self.assertIn("synchronize_stream", code) + def test_codegen_structure_single_stream(self): """Verify wrapper structure for pointwise ops with one side stream.""" from torch._inductor.utils import run_and_get_code @@ -961,13 +1159,13 @@ class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[1024]"): l_x_ = L_x_ - get_external_object_by_index = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(0) + get_external_object_by_index = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(1); get_external_object_by_index = None a: "f32[1024]" = l_x_ * 2 b: "f32[1024]" = l_x_ * 3; l_x_ = None - synchronize = get_external_object_by_index.synchronize(); get_external_object_by_index = synchronize = None + synchronize_stream = torch.ops.streams.synchronize_stream(1); synchronize_stream = None add: "f32[1024]" = a + b; a = b = None return (add,) @@ -982,13 +1180,18 @@ def forward(self, L_x_: "f32[1024]"): with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) default_stream = torch.cuda.current_stream() - stream1 = torch.cuda.Stream(device=0) + stream1 = get_external_object_by_index(1) with torch.cuda.stream(stream1): + arg0_1 = copy_if_misaligned(arg0_1) buf0 = empty_strided_cuda((1024, ), (1, ), torch.float32) - buf1 = buf0; del buf0 + raw_stream = get_raw_stream(0) + triton_kernel.run(arg0_1, buf0, 1024, stream=raw_stream) + with torch.cuda.stream(default_stream): + buf3 = empty_strided_cuda((1024, ), (1, ), torch.float32) stream0 = get_raw_stream(0) - triton_kernel.run(buf1, arg0_1, 1024, stream=stream0) - return (buf1, )""", + triton_kernel.run(arg0_1, buf0, buf3, 1024, stream=stream0) + torch.ops.streams.synchronize_stream.default(1) + return (buf3, )""", ) def test_codegen_structure_pipeline(self): @@ -1024,42 +1227,38 @@ def forward(self, L_x_: "f32[32, 32]", L_w1_: "f32[32, 32]", L_w2_: "f32[32, 32] l_w1_ = L_w1_ l_w2_ = L_w2_ - get_external_object_by_index = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(0) + get_external_object_by_index = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(1); get_external_object_by_index = None - get_external_object_by_index_1 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(1); get_external_object_by_index_1 = None + get_external_object_by_index_1 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(2); get_external_object_by_index_1 = None a: "f32[32, 32]" = l_x_ @ l_w1_; l_x_ = l_w1_ = None - get_external_object_by_index_2 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(2); get_external_object_by_index_2 = None - record_event = torch.ops.streams.record_event(1, 2); record_event = None + get_external_object_by_index_2 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(0); get_external_object_by_index_2 = None + record_event = torch.ops.streams.record_event(2, 0); record_event = None - wait_event = torch.ops.streams.wait_event(1, 0); wait_event = None + wait_event = torch.ops.streams.wait_event(2, 1); wait_event = None b: "f32[32, 32]" = a @ l_w2_; a = l_w2_ = None - synchronize = get_external_object_by_index.synchronize(); get_external_object_by_index = synchronize = None + synchronize_stream = torch.ops.streams.synchronize_stream(1); synchronize_stream = None return (b,) -""", # noqa: B950 +""", ) wrapper_body = _extract_wrapper_body(code) - self.assertExpectedInline( - wrapper_body, + FileCheck().run( """\ -arg0_1, arg1_1, arg2_1 = args -torch.ops.streams.record_event.default(1, 2) -torch.ops.streams.wait_event.default(1, 0) -with torch.cuda._DeviceGuard(0): - torch.cuda.set_device(0) - default_stream = torch.cuda.current_stream() - stream1 = torch.cuda.Stream(device=0) - with torch.cuda.stream(default_stream): - buf2 = empty_strided_cuda((32, 32), (32, 1), torch.float32) - extern_kernels.mm(arg0_1, arg1_1, out=buf2) - with torch.cuda.stream(stream1): - buf3 = empty_strided_cuda((32, 32), (32, 1), torch.float32) - extern_kernels.mm(buf2, arg2_1, out=buf3) - return (buf3, )""", +# CHECK: with torch.cuda.stream(default_stream): +# CHECK: copy_if_misaligned +# CHECK: extern_kernels.mm( +# CHECK: record_event +# CHECK: with torch.cuda.stream(stream1): +# CHECK: wait_event +# CHECK: copy_if_misaligned +# CHECK: extern_kernels.mm( +# CHECK: with torch.cuda.stream(default_stream): +# CHECK: synchronize_stream""", + wrapper_body, ) def test_codegen_structure_three_stream_pipeline(self): @@ -1107,64 +1306,60 @@ def forward(self, L_x_: "f32[32, 32]", L_w1_: "f32[32, 32]", L_w2_: "f32[32, 32] l_w2_ = L_w2_ l_w3_ = L_w3_ - get_external_object_by_index = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(0) + get_external_object_by_index = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(1); get_external_object_by_index = None - get_external_object_by_index_1 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(1) + get_external_object_by_index_1 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(2); get_external_object_by_index_1 = None - get_external_object_by_index_2 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(2) + get_external_object_by_index_2 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(3); get_external_object_by_index_2 = None - get_external_object_by_index_3 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(3); get_external_object_by_index_3 = None + get_external_object_by_index_3 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(4); get_external_object_by_index_3 = None - get_external_object_by_index_4 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(4); get_external_object_by_index_4 = None + get_external_object_by_index_4 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(5); get_external_object_by_index_4 = None a: "f32[32, 32]" = l_x_ @ l_w1_; l_x_ = l_w1_ = None - record_event = torch.ops.streams.record_event(3, 0); record_event = None + record_event = torch.ops.streams.record_event(4, 1); record_event = None - wait_event = torch.ops.streams.wait_event(3, 1); wait_event = None + wait_event = torch.ops.streams.wait_event(4, 2); wait_event = None b: "f32[32, 32]" = a @ l_w2_; a = l_w2_ = None - record_event_1 = torch.ops.streams.record_event(4, 1); record_event_1 = None + record_event_1 = torch.ops.streams.record_event(5, 2); record_event_1 = None - wait_event_1 = torch.ops.streams.wait_event(4, 2); wait_event_1 = None + wait_event_1 = torch.ops.streams.wait_event(5, 3); wait_event_1 = None c: "f32[32, 32]" = b @ l_w3_; b = l_w3_ = None - synchronize = get_external_object_by_index.synchronize(); get_external_object_by_index = synchronize = None + synchronize_stream = torch.ops.streams.synchronize_stream(1); synchronize_stream = None - synchronize_1 = get_external_object_by_index_1.synchronize(); get_external_object_by_index_1 = synchronize_1 = None + synchronize_stream_1 = torch.ops.streams.synchronize_stream(2); synchronize_stream_1 = None - synchronize_2 = get_external_object_by_index_2.synchronize(); get_external_object_by_index_2 = synchronize_2 = None + synchronize_stream_2 = torch.ops.streams.synchronize_stream(3); synchronize_stream_2 = None return (c,) -""", # noqa: B950 +""", ) wrapper_body = _extract_wrapper_body(code) - self.assertExpectedInline( - wrapper_body, + FileCheck().run( """\ -arg0_1, arg1_1, arg2_1, arg3_1 = args -torch.ops.streams.record_event.default(3, 0) -torch.ops.streams.wait_event.default(3, 1) -torch.ops.streams.record_event.default(4, 1) -torch.ops.streams.wait_event.default(4, 2) -with torch.cuda._DeviceGuard(0): - torch.cuda.set_device(0) - default_stream = torch.cuda.current_stream() - stream1 = torch.cuda.Stream(device=0) - stream2 = torch.cuda.Stream(device=0) - stream3 = torch.cuda.Stream(device=0) - with torch.cuda.stream(stream1): - buf4 = empty_strided_cuda((32, 32), (32, 1), torch.float32) - extern_kernels.mm(arg0_1, arg1_1, out=buf4) - with torch.cuda.stream(stream2): - buf5 = empty_strided_cuda((32, 32), (32, 1), torch.float32) - extern_kernels.mm(buf4, arg2_1, out=buf5) - with torch.cuda.stream(stream3): - buf6 = buf4; del buf4 - extern_kernels.mm(buf5, arg3_1, out=buf6) - return (buf6, )""", +# CHECK: with torch.cuda.stream(stream1): +# CHECK: copy_if_misaligned +# CHECK: extern_kernels.mm( +# CHECK: record_event +# CHECK: with torch.cuda.stream(stream2): +# CHECK: wait_event +# CHECK: copy_if_misaligned +# CHECK: extern_kernels.mm( +# CHECK: record_event +# CHECK: with torch.cuda.stream(stream3): +# CHECK: wait_event +# CHECK: copy_if_misaligned +# CHECK: extern_kernels.mm( +# CHECK: with torch.cuda.stream(default_stream): +# CHECK: synchronize_stream +# CHECK: synchronize_stream +# CHECK: synchronize_stream""", + wrapper_body, ) def test_codegen_structure_parallel_matmuls(self): @@ -1207,60 +1402,337 @@ def forward(self, L_x_: "f32[32, 32]", L_w1_: "f32[32, 32]", L_w2_: "f32[32, 32] l_w1_ = L_w1_ l_w2_ = L_w2_ - get_external_object_by_index = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(0) + get_external_object_by_index = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(1); get_external_object_by_index = None - get_external_object_by_index_1 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(1) + get_external_object_by_index_1 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(2); get_external_object_by_index_1 = None - get_external_object_by_index_2 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(2); get_external_object_by_index_2 = None + get_external_object_by_index_2 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(3); get_external_object_by_index_2 = None - get_external_object_by_index_3 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(3); get_external_object_by_index_3 = None + get_external_object_by_index_3 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(4); get_external_object_by_index_3 = None a: "f32[32, 32]" = l_x_ @ l_w1_; l_w1_ = None - record_event = torch.ops.streams.record_event(2, 0); record_event = None + record_event = torch.ops.streams.record_event(3, 1); record_event = None b: "f32[32, 32]" = l_x_ @ l_w2_; l_x_ = l_w2_ = None - record_event_1 = torch.ops.streams.record_event(3, 1); record_event_1 = None + record_event_1 = torch.ops.streams.record_event(4, 2); record_event_1 = None - get_external_object_by_index_4 = torch__dynamo_graph_bytecode_inputs_get_external_object_by_index(4); get_external_object_by_index_4 = None - wait_event = torch.ops.streams.wait_event(2, 4); wait_event = None + wait_event = torch.ops.streams.wait_event(3, 0); wait_event = None - wait_event_1 = torch.ops.streams.wait_event(3, 4); wait_event_1 = None + wait_event_1 = torch.ops.streams.wait_event(4, 0); wait_event_1 = None c: "f32[32, 32]" = a + b; a = b = None - synchronize = get_external_object_by_index.synchronize(); get_external_object_by_index = synchronize = None + synchronize_stream = torch.ops.streams.synchronize_stream(1); synchronize_stream = None - synchronize_1 = get_external_object_by_index_1.synchronize(); get_external_object_by_index_1 = synchronize_1 = None + synchronize_stream_1 = torch.ops.streams.synchronize_stream(2); synchronize_stream_1 = None return (c,) -""", # noqa: B950 +""", ) wrapper_body = _extract_wrapper_body(code) - self.assertExpectedInline( - wrapper_body, + FileCheck().run( """\ -arg0_1, arg1_1, arg2_1 = args -torch.ops.streams.record_event.default(2, 0) -torch.ops.streams.record_event.default(3, 1) -torch.ops.streams.wait_event.default(2, 4) -torch.ops.streams.wait_event.default(3, 4) -with torch.cuda._DeviceGuard(0): - torch.cuda.set_device(0) - default_stream = torch.cuda.current_stream() - stream1 = torch.cuda.Stream(device=0) - stream2 = torch.cuda.Stream(device=0) - with torch.cuda.stream(stream1): - buf4 = empty_strided_cuda((32, 32), (32, 1), torch.float32) - extern_kernels.mm(arg0_1, arg1_1, out=buf4) - with torch.cuda.stream(default_stream): - buf5 = empty_strided_cuda((32, 32), (32, 1), torch.float32) - extern_kernels.addmm(buf4, arg0_1, arg2_1, alpha=1, beta=1, out=buf5) - return (buf5, )""", # noqa: B950 +# CHECK: with torch.cuda.stream(stream1): +# CHECK: copy_if_misaligned +# CHECK: extern_kernels.mm( +# CHECK: record_event +# CHECK: with torch.cuda.stream(stream2): +# CHECK: copy_if_misaligned +# CHECK: extern_kernels.mm( +# CHECK: record_event +# CHECK: with torch.cuda.stream(default_stream): +# CHECK: wait_event +# CHECK: triton_kernel.run( +# CHECK: synchronize_stream +# CHECK: synchronize_stream""", + wrapper_body, ) +@unittest.skipUnless(TEST_CUDA, "requires CUDA") +class TestStreamOrderingStress(InductorTestCase): + """Stress tests verifying that interleaved event record/wait ops + produce correct ordering under compilation. Each test uses large + matmuls so there is real GPU work, and repeats many iterations so + that race conditions (if ordering is wrong) surface reliably.""" + + N = 4096 # matrix size — big enough for real GPU work + ITERS = 20 # repetitions per test + + def _check_compiled_matches_eager(self, fn, *args): + """Run fn eagerly and compiled, assert results match over ITERS runs.""" + compiled_fn = torch.compile(fn) + for _ in range(self.ITERS): + expected = fn(*args) + actual = compiled_fn(*args) + # Full device sync as a safety net to ensure all stream work + # is visible before comparing results. + torch.cuda.synchronize() + if not isinstance(expected, (tuple, list)): + expected, actual = [expected], [actual] + for e, a in zip(expected, actual): + self.assertEqual(a, e) + + @staticmethod + def _heavy_matmul_chain(x, w, depth=8): + """Chain of matmuls to create substantial GPU work (~ms). + Used to widen the race window between streams so that missing + synchronization is observable.""" + h = x + for _ in range(depth): + h = h @ w + return h + + # ------------------------------------------------------------------ + # 1. Race: producer does heavy work, consumer reads the result. + # Without the event.wait() the consumer would launch immediately + # and read stale memory because the producer chain hasn't finished. + # ------------------------------------------------------------------ + def test_race_producer_consumer(self): + N = self.N + + def fn(x, w): + s = torch.cuda.Stream() + e = torch.cuda.Event() + + # Heavy producer on default stream — takes real GPU time + a = TestStreamOrderingStress._heavy_matmul_chain(x, w) + e.record() + + with torch.cuda.stream(s): + e.wait() # removing this would cause a race + b = a + 1 + + s.synchronize() + return b + + x = torch.randn(N, N, device="cuda") + w = torch.eye(N, device="cuda") * 0.9 # use scaled identity for stability + self._check_compiled_matches_eager(fn, x, w) + + # ------------------------------------------------------------------ + # 2. Race: ping-pong where each direction has heavy work. + # Both event.wait() calls are load-bearing. + # ------------------------------------------------------------------ + def test_race_ping_pong(self): + N = self.N + + def fn(x, w): + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + e1 = torch.cuda.Event() + e2 = torch.cuda.Event() + + with torch.cuda.stream(s1): + a = TestStreamOrderingStress._heavy_matmul_chain(x, w) + e1.record(s1) + + with torch.cuda.stream(s2): + e1.wait(s2) + b = TestStreamOrderingStress._heavy_matmul_chain(a, w) + e2.record(s2) + + with torch.cuda.stream(s1): + e2.wait(s1) + c = b + a + + s1.synchronize() + s2.synchronize() + return c + + x = torch.randn(N, N, device="cuda") + w = torch.eye(N, device="cuda") * 0.9 + self._check_compiled_matches_eager(fn, x, w) + + # ------------------------------------------------------------------ + # 3. Race: fan-out where the producer is slow. + # All three consumers depend on the producer finishing. + # ------------------------------------------------------------------ + def test_race_fan_out(self): + N = self.N + + def fn(x, w): + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + s3 = torch.cuda.Stream() + e = torch.cuda.Event() + e1 = torch.cuda.Event() + e2 = torch.cuda.Event() + e3 = torch.cuda.Event() + + # Slow producer + a = TestStreamOrderingStress._heavy_matmul_chain(x, w) + e.record() + + with torch.cuda.stream(s1): + e.wait() + r1 = a * 2 + e1.record(s1) + + with torch.cuda.stream(s2): + e.wait() + r2 = a * 3 + e2.record(s2) + + with torch.cuda.stream(s3): + e.wait() + r3 = a * 4 + e3.record(s3) + + e1.wait() + e2.wait() + e3.wait() + result = r1 + r2 + r3 + + s1.synchronize() + s2.synchronize() + s3.synchronize() + return result + + x = torch.randn(N, N, device="cuda") + w = torch.eye(N, device="cuda") * 0.9 + self._check_compiled_matches_eager(fn, x, w) + + # ------------------------------------------------------------------ + # 4. Race: diamond pattern with heavy work on both branches. + # The join must wait for both branches. + # ------------------------------------------------------------------ + def test_race_diamond(self): + N = self.N + + def fn(x, w): + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + e_fork = torch.cuda.Event() + e1 = torch.cuda.Event() + e2 = torch.cuda.Event() + + base = x @ w + e_fork.record() + + with torch.cuda.stream(s1): + e_fork.wait() + branch1 = TestStreamOrderingStress._heavy_matmul_chain( + torch.relu(base), w + ) + e1.record(s1) + + with torch.cuda.stream(s2): + e_fork.wait() + branch2 = TestStreamOrderingStress._heavy_matmul_chain( + torch.sigmoid(base), w + ) + e2.record(s2) + + e1.wait() + e2.wait() + result = branch1 + branch2 + + s1.synchronize() + s2.synchronize() + return result + + x = torch.randn(N, N, device="cuda") + w = torch.eye(N, device="cuda") * 0.5 + self._check_compiled_matches_eager(fn, x, w) + + # ------------------------------------------------------------------ + # 5. Race: 4-stage pipeline where each stage is heavy. + # Every event.wait() is load-bearing. + # ------------------------------------------------------------------ + def test_race_pipeline(self): + N = self.N + + def fn(x, w): + streams = [torch.cuda.Stream() for _ in range(4)] + events = [torch.cuda.Event() for _ in range(3)] + + with torch.cuda.stream(streams[0]): + h = TestStreamOrderingStress._heavy_matmul_chain(x, w, depth=4) + events[0].record(streams[0]) + + with torch.cuda.stream(streams[1]): + events[0].wait(streams[1]) + h = TestStreamOrderingStress._heavy_matmul_chain(h, w, depth=4) + events[1].record(streams[1]) + + with torch.cuda.stream(streams[2]): + events[1].wait(streams[2]) + h = TestStreamOrderingStress._heavy_matmul_chain(h, w, depth=4) + events[2].record(streams[2]) + + with torch.cuda.stream(streams[3]): + events[2].wait(streams[3]) + h = h + x # quick consumer — races if wait is missing + + for s in streams: + s.synchronize() + return h + + x = torch.randn(N, N, device="cuda") + w = torch.eye(N, device="cuda") * 0.9 + self._check_compiled_matches_eager(fn, x, w) + + # ------------------------------------------------------------------ + # 6. Race: back-to-back sync, both directions carry heavy work + # ------------------------------------------------------------------ + def test_race_back_to_back(self): + N = self.N + + def fn(x, w): + s = torch.cuda.Stream() + e1 = torch.cuda.Event() + e2 = torch.cuda.Event() + + a = TestStreamOrderingStress._heavy_matmul_chain(x, w) + e1.record() + + with torch.cuda.stream(s): + e1.wait() + b = TestStreamOrderingStress._heavy_matmul_chain(a, w) + e2.record(s) + + e2.wait() + c = b + 1 + + s.synchronize() + return c + + x = torch.randn(N, N, device="cuda") + w = torch.eye(N, device="cuda") * 0.9 + self._check_compiled_matches_eager(fn, x, w) + + # ------------------------------------------------------------------ + # 7. Race: triton kernel on user stream. + # Without the triton stream fix the kernel launches on the default + # stream and reads stale/in-progress data from the user stream. + # ------------------------------------------------------------------ + def test_race_triton_on_user_stream(self): + N = self.N + + def fn(x, w): + s = torch.cuda.Stream() + e = torch.cuda.Event() + + with torch.cuda.stream(s): + # Heavy matmul chain produces data on user stream + a = TestStreamOrderingStress._heavy_matmul_chain(x, w) + # Triton pointwise on the same user stream — without fix + # this launches on the default stream + b = torch.relu(a) + e.record(s) + + e.wait() + s.synchronize() + return b + + x = torch.randn(N, N, device="cuda") + w = torch.eye(N, device="cuda") * 0.9 + self._check_compiled_matches_eager(fn, x, w) + + @unittest.skipUnless(TEST_CUDA, "requires CUDA") class TestGenericStreamCompile(InductorTestCase): """Tests for torch.compile with device-agnostic torch.Stream API.""" @@ -1330,11 +1802,8 @@ def fn(x): self.assertEqual(result, expected) - # Verify event operations - self.assertTrue( - "record_event" in code or ".record(" in code, - "Expected record_event or .record( in generated code", - ) + # Verify event operations survive compilation as custom ops + self.assertIn("record_event", code) def test_generic_stream_multiple(self): """Test compilation with multiple torch.Stream instances.""" @@ -1376,11 +1845,10 @@ def fn(x): self.assertEqual(result, expected) - # Verify stream handling - self.assertTrue( - "torch.cuda.stream" in code or "stream" in code.lower(), - "Expected stream context in generated code", - ) + # Verify stream handling and event ops survive + self.assertIn("torch.cuda.stream", code) + self.assertIn("record_event", code) + self.assertIn("wait_event", code) def test_generic_event_record_on_stream(self): """Test torch.Event.record() with explicit stream argument.""" @@ -1415,18 +1883,464 @@ def fn(x): self.assertEqual(result, expected) - # Verify event operations - self.assertTrue( - "record_event" in code or ".record(" in code, - "Expected record_event or .record( in generated code", - ) + # Verify event operations survive compilation as custom ops + self.assertIn("record_event", code) + + +@unittest.skipUnless(TEST_CUDA, "requires CUDA") +class TestStreamIdentity(InductorTestCase): + """Verify that compiled code uses the user's original stream objects.""" + + def test_single_stream_identity(self): + """Codegen should retrieve the user's stream via get_external_object_by_index.""" + from torch._inductor.utils import run_and_get_code + + user_stream = torch.cuda.Stream() + + def fn(x): + with torch.cuda.stream(user_stream): + return x * 2 + + x = torch.randn(1024, device="cuda") + result, (code,) = run_and_get_code(torch.compile(fn), x) + + self.assertEqual(result, fn(x)) + self.assertIn("get_external_object_by_index", code) + self.assertNotIn("torch.cuda.Stream(device=", code) + + def test_multiple_stream_identity(self): + """Each stream context should retrieve a different user stream object.""" + from torch._inductor.utils import run_and_get_code + + stream_a = torch.cuda.Stream() + stream_b = torch.cuda.Stream() + + def fn(x): + event = torch.cuda.Event() + with torch.cuda.stream(stream_a): + a = x * 2 + event.record() + with torch.cuda.stream(stream_b): + event.wait() + b = a + 1 + stream_b.synchronize() + return b + + x = torch.randn(1024, device="cuda") + result, (code,) = run_and_get_code(torch.compile(fn), x) + + self.assertEqual(result, fn(x)) + # Should have two distinct get_external_object_by_index calls + matches = re.findall(r"get_external_object_by_index\((\d+)\)", code) + self.assertEqual(len(matches), 2) + self.assertNotEqual(matches[0], matches[1]) + self.assertNotIn("torch.cuda.Stream(device=", code) + + +@unittest.skipUnless(TEST_CUDA, "requires CUDA") +class TestPDLWithMultiStream(InductorTestCase): + """Tests that PDL (Programmatic Dependent Launch) composes safely with + user-annotated multi-stream code under torch.compile. + + PDL's GDC intrinsics are stream-local: gdc_wait/gdc_launch_dependents + only govern the overlap between consecutive kernels on the *same* CUDA + stream. Cross-stream ordering is handled entirely by CUDA events at the + wrapper level. These tests verify that enabling PDL in the presence of + multi-stream code doesn't break correctness, doesn't interfere with + stream-level invariants (no cross-stream fusion, event ops preserved), + and still applies within each stream's own kernel sequence. + """ + + @unittest.skipIf(not SM90OrLater or TEST_WITH_ROCM, "PDL requires NVIDIA sm90+") + @inductor_config.patch({"triton.enable_pdl": True}) + def test_pdl_single_side_stream(self): + """PDL metadata is emitted for a kernel on a side stream.""" + from torch._inductor.utils import run_and_get_code, run_and_get_triton_code + + def fn(x): + s = torch.cuda.Stream() + with torch.cuda.stream(s): + a = x * 2 + b = a + 1 + s.synchronize() + return b + + x = torch.randn(1024, device="cuda") + expected = fn(x) + + compiled_fn = torch.compile(fn) + result, (wrapper_code,) = run_and_get_code(compiled_fn, x) + self.assertEqual(result, expected) + + self.assertIn("torch.cuda.stream", wrapper_code) + self.assertIn("synchronize_stream", wrapper_code) + + triton_code = run_and_get_triton_code(torch.compile(fn), x) + ( + FileCheck() + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + ).run(triton_code) + + @unittest.skipIf(not SM90OrLater or TEST_WITH_ROCM, "PDL requires NVIDIA sm90+") + @inductor_config.patch({"triton.enable_pdl": True}) + def test_pdl_correctness_with_multiple_streams(self): + """Enabling PDL with independent side streams produces correct results.""" + from torch._inductor.utils import run_and_get_triton_code + + def fn(x): + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + e1 = torch.cuda.Event() + e2 = torch.cuda.Event() + + with torch.cuda.stream(s1): + a = x * 2 + 1 + e1.record(s1) + + with torch.cuda.stream(s2): + b = x * 3 + 2 + e2.record(s2) + + e1.wait() + e2.wait() + return a + b + + x = torch.randn(1024, device="cuda") + expected = fn(x) + compiled_fn = torch.compile(fn) + self.assertEqual(compiled_fn(x), expected) + + triton_code = run_and_get_triton_code(torch.compile(fn), x) + # s1 kernel, s2 kernel, and default-stream add kernel + ( + FileCheck() + # s1 kernel + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + # s2 kernel + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + # default stream add + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + ).run(triton_code) + + @unittest.skipIf(not SM90OrLater or TEST_WITH_ROCM, "PDL requires NVIDIA sm90+") + @inductor_config.patch({"triton.enable_pdl": True}) + def test_pdl_cross_stream_events_preserved(self): + """Event record/wait for cross-stream sync must survive with PDL on. + + PDL is stream-local so it cannot replace event-based cross-stream + ordering. Verify the events are still in the generated code.""" + from torch._inductor.utils import run_and_get_code, run_and_get_triton_code + + def fn(x): + s = torch.cuda.Stream() + event = torch.cuda.Event() + + a = x * 2 + 1 + event.record() + + with torch.cuda.stream(s): + event.wait() + b = a + 3 + + s.synchronize() + return b + + x = torch.randn(1024, device="cuda") + expected = fn(x) + + compiled_fn = torch.compile(fn) + result, (code,) = run_and_get_code(compiled_fn, x) + self.assertEqual(result, expected) + + self.assertIn("record_event", code) + self.assertIn("wait_event", code) + + # Both kernels (default + side stream) get PDL intrinsics + triton_code = run_and_get_triton_code(torch.compile(fn), x) + ( + FileCheck() + # default stream kernel + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + # side stream kernel + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + ).run(triton_code) + + @unittest.skipIf(not SM90OrLater or TEST_WITH_ROCM, "PDL requires NVIDIA sm90+") + @inductor_config.patch({"triton.enable_pdl": True}) + def test_pdl_same_stream_consecutive_kernels(self): + """Two consecutive kernels on the same side stream should both get PDL. + + This is the case where PDL is actually useful: the second kernel can + overlap with the first via GDC intrinsics because they share a stream.""" + from torch._inductor.utils import run_and_get_triton_code + + def fn(x, y): + s = torch.cuda.Stream() + with torch.cuda.stream(s): + # Two separate fused groups on the same stream + a = x**2 + x + y.copy_(a) + s.synchronize() + return y + + x = torch.randn(1024, device="cuda") + y = torch.empty(1024, device="cuda") + expected = fn(x, y.clone()) + compiled_fn = torch.compile(fn) + self.assertEqual(compiled_fn(x, y.clone()), expected) + + triton_code = run_and_get_triton_code(torch.compile(fn), x, y.clone()) + ( + FileCheck() + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + ).run(triton_code) + + @unittest.skipIf(not SM90OrLater or TEST_WITH_ROCM, "PDL requires NVIDIA sm90+") + @inductor_config.patch({"triton.enable_pdl": True}) + def test_pdl_no_fusion_across_streams(self): + """PDL must not cause cross-stream ops to be fused.""" + from torch._inductor.utils import run_and_get_triton_code + + def fn(x): + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + e1 = torch.cuda.Event() + e2 = torch.cuda.Event() + + with torch.cuda.stream(s1): + a = x * 2 + b = a + 1 + e1.record(s1) + + with torch.cuda.stream(s2): + c = x * 3 + d = c + 2 + e2.record(s2) + + e1.wait() + e2.wait() + return b + d + + x = torch.randn(1024, device="cuda") + expected = fn(x) + + compiled_fn = torch.compile(fn) + torch._inductor.metrics.reset() + result = compiled_fn(x) + self.assertEqual(result, expected) + + # 3 kernels: s1 pointwise, s2 pointwise, default stream add + self.assertEqual(torch._inductor.metrics.generated_kernel_count, 3) + + # All 3 kernels should have PDL with GDC intrinsics + triton_code = run_and_get_triton_code(torch.compile(fn), x) + ( + FileCheck() + # s1 kernel + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + # s2 kernel + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + # default stream add + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + ).run(triton_code) + + @unittest.skipIf(not SM90OrLater or TEST_WITH_ROCM, "PDL requires NVIDIA sm90+") + @inductor_config.patch({"triton.enable_pdl": True}) + def test_pdl_stress_multistream_correctness(self): + """Stress test: heavy work across streams with PDL must produce + correct results over many iterations to surface any races. + + Uses 4096x4096 matmuls (matching TestStreamOrderingStress) so the + GPU work is long enough that a missing event.wait() would cause + the consumer to read stale data.""" + from torch._inductor.utils import run_and_get_code + + N = 4096 + ITERS = 20 + + def fn(x, w): + s = torch.cuda.Stream() + e = torch.cuda.Event() + + h = x + for _ in range(4): + h = h @ w + e.record() + + with torch.cuda.stream(s): + e.wait() + out = torch.relu(h) + 1.0 + + s.synchronize() + return out + + x = torch.randn(N, N, device="cuda") + w = torch.eye(N, device="cuda") * 0.9 + + # Verify codegen once before the stress loop + compiled_fn = torch.compile(fn) + result, (code,) = run_and_get_code(compiled_fn, x, w) + # Wrapper must have stream context and event sync + self.assertIn("torch.cuda.stream", code) + self.assertIn("wait_event", code) + # The relu+add pointwise kernel should have PDL + ( + FileCheck() + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + ).run(code) + + for _ in range(ITERS): + expected = fn(x, w) + actual = compiled_fn(x, w) + torch.cuda.synchronize() + self.assertEqual(actual, expected) + + @unittest.skipIf(not SM90OrLater or TEST_WITH_ROCM, "PDL requires NVIDIA sm90+") + @inductor_config.patch({"triton.enable_pdl": True}) + def test_pdl_mutation_across_streams(self): + """Buffer mutation on one stream, read on another, with PDL enabled. + + The mutation is on a locally-created buffer (not an input) to avoid + the dynamo guard that forbids event.record() after input mutation.""" + from torch._inductor.utils import run_and_get_code, run_and_get_triton_code + + def fn(x): + s = torch.cuda.Stream() + event = torch.cuda.Event() + + # Produce a new buffer (not input mutation) then record + a = x * 2 + event.record() + + with torch.cuda.stream(s): + event.wait() + # In-place add on side stream + a.add_(1) + + s.synchronize() + return a + + x = torch.randn(1024, device="cuda") + + expected = fn(x) + compiled_fn = torch.compile(fn) + result, (code,) = run_and_get_code(compiled_fn, x) + self.assertEqual(result, expected) + + self.assertIn("record_event", code) + self.assertIn("wait_event", code) + + # Both kernels (default + side stream) get PDL intrinsics + triton_code = run_and_get_triton_code(torch.compile(fn), x) + ( + FileCheck() + # default stream kernel + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + # side stream kernel + .check("'launch_pdl': True") + .check("gdc_wait") + .check("gdc_launch") + ).run(triton_code) + + +@unittest.skipIf(not TEST_CUDA, "requires CUDA") +@torch._inductor.config.patch({"triton.cudagraphs": True}) +class TestStreamCudagraphInteraction(InductorTestCase): + """Tests for user streams under cudagraph capture (reduce-overhead mode).""" + + def test_implicit_current_stream_with_cudagraphs(self): + """Event record/wait with implicit current stream must work under cudagraph capture. + + The implicit current stream resolves at runtime via torch.cuda.current_stream(), + which correctly returns the cudagraph capture stream during recording. + """ + s1 = torch.cuda.Stream() + ev = torch.cuda.Event() + ev2 = torch.cuda.Event() + + def fn(x, y): + ev.record() + with torch.cuda.stream(s1): + ev.wait() + z = x * 2 + ev2.record() + ev2.wait() + return z + y + + x = torch.randn(100, 100, device="cuda") + y = torch.randn(100, 100, device="cuda") + + expected = fn(x, y) + compiled_fn = torch.compile(fn) + # Warmup + capture + replay + for _ in range(3): + result = compiled_fn(x, y) + self.assertEqual(result, expected) + + def test_explicit_current_stream_with_cudagraphs(self): + """Passing torch.cuda.current_stream() explicitly must also work under capture. + + The user writes ev.record(torch.cuda.current_stream()) which is + semantically identical to ev.record() — both should resolve to the + capture stream during cudagraph recording, not the stale default stream. + """ + s1 = torch.cuda.Stream() + ev = torch.cuda.Event() + ev2 = torch.cuda.Event() + + def fn(x, y): + cur = torch.cuda.current_stream() + ev.record(cur) + with torch.cuda.stream(s1): + ev.wait() + z = x * 2 + ev2.record() + ev2.wait(cur) + return z + y + + x = torch.randn(100, 100, device="cuda") + y = torch.randn(100, 100, device="cuda") + + expected = fn(x, y) + compiled_fn = torch.compile(fn) + for _ in range(3): + result = compiled_fn(x, y) + self.assertEqual(result, expected) instantiate_parametrized_tests(TestStreamUtils) instantiate_parametrized_tests(TestWrapperCodegenStreams) instantiate_parametrized_tests(TestStreamCodegen) instantiate_parametrized_tests(TestUserStreamCompile) +instantiate_parametrized_tests(TestStreamOrderingStress) instantiate_parametrized_tests(TestGenericStreamCompile) +instantiate_parametrized_tests(TestStreamIdentity) +instantiate_parametrized_tests(TestPDLWithMultiStream) +instantiate_parametrized_tests(TestStreamCudagraphInteraction) if __name__ == "__main__": diff --git a/test/jit/_imported_class_test/bar.py b/test/jit/_imported_class_test/bar.py index 67f0996065ca6..052e517345af7 100644 --- a/test/jit/_imported_class_test/bar.py +++ b/test/jit/_imported_class_test/bar.py @@ -5,7 +5,7 @@ # They are used by test_jit.py to test ScriptClass imports -@torch.jit.script # noqa: B903 -class FooSameName: # noqa: B903 +@torch.jit.script +class FooSameName: def __init__(self, y): self.y = y diff --git a/test/jit/_imported_class_test/foo.py b/test/jit/_imported_class_test/foo.py index 8d5281567c815..218f5839865c6 100644 --- a/test/jit/_imported_class_test/foo.py +++ b/test/jit/_imported_class_test/foo.py @@ -7,7 +7,7 @@ # They are used by test_jit.py to test ScriptClass imports -@torch.jit.script # noqa: B903 +@torch.jit.script class FooSameName: def __init__(self, x): self.x = x diff --git a/test/jit/_imported_class_test/very/very/nested.py b/test/jit/_imported_class_test/very/very/nested.py index a931d37971d1b..599ae29a035cf 100644 --- a/test/jit/_imported_class_test/very/very/nested.py +++ b/test/jit/_imported_class_test/very/very/nested.py @@ -5,7 +5,7 @@ # They are used by test_jit.py to test ScriptClass imports -@torch.jit.script # noqa: B903 -class FooUniqueName: # noqa: B903 +@torch.jit.script +class FooUniqueName: def __init__(self, y): self.y = y diff --git a/test/jit/test_attr.py b/test/jit/test_attr.py index d9d5fab1615ae..e7980a1ed1739 100644 --- a/test/jit/test_attr.py +++ b/test/jit/test_attr.py @@ -18,7 +18,7 @@ def __init__(self) -> None: def forward(self, x): y = getattr(self, "init_attr_val") # noqa: B009 w: list[float] = [1.0] - z = getattr(self, "missing", w) # noqa: B009 + z = getattr(self, "missing", w) z.append(y) return z diff --git a/test/jit/test_await.py b/test/jit/test_await.py index 0f538fd9b909a..37f1d077b2751 100644 --- a/test/jit/test_await.py +++ b/test/jit/test_await.py @@ -194,7 +194,7 @@ def b(self): def C_wait_impl(self: C) -> C: return C(self._a * 2, self._b * 3) - def fn_arg_C(x: C) -> Tensor: # noqa: F841 + def fn_arg_C(x: C) -> Tensor: return x._a + x._b def fn(x: Tensor): diff --git a/test/jit/test_class_type.py b/test/jit/test_class_type.py index 4b5f2ad9a0d77..282361a69d9b8 100644 --- a/test/jit/test_class_type.py +++ b/test/jit/test_class_type.py @@ -78,7 +78,7 @@ def fn(x): self.assertEqual(fn(input), input) def test_get_attr(self): - class FooTest: # noqa: B903 + class FooTest: def __init__(self, x): self.foo = x @@ -91,7 +91,7 @@ def fn(x): self.assertEqual(fn(input), input) def test_in(self): - class FooTest: # noqa: B903 + class FooTest: def __init__(self) -> None: pass @@ -176,8 +176,8 @@ def test_type_annotations(self): RuntimeError, "Expected a value of type 'bool", "" ): - @torch.jit.script # noqa: B903 - class FooTest: # noqa: B903 + @torch.jit.script + class FooTest: def __init__(self, x: bool) -> None: self.foo = x @@ -199,7 +199,7 @@ def __init__(self, x): self.attr = x def test_class_type_as_param(self): - class FooTest: # noqa: B903 + class FooTest: def __init__(self, x): self.attr = x @@ -296,7 +296,7 @@ def forward(self, a): self.assertEqual(input, output) def test_save_load_with_classes_nested(self): - class FooNestedTest: # noqa: B903 + class FooNestedTest: def __init__(self, y): self.y = y @@ -334,7 +334,7 @@ def forward(self, a): self.assertEqual(2 * input, output) def test_python_interop(self): - class Foo: # noqa: B903 + class Foo: def __init__(self, x, y): self.x = x self.y = y @@ -360,7 +360,7 @@ def use_foo(foo: Foo) -> Foo: self.assertEqual(y, f2.y) def test_class_specialization(self): - class Foo: # noqa: B903 + class Foo: def __init__(self, x, y): self.x = x self.y = y @@ -384,7 +384,7 @@ def use_foo(foo: Foo, foo2: Foo, tup: Tuple[Foo, Foo]) -> torch.Tensor: FileCheck().check_count("prim::GetAttr", 4).run(graphstr) def test_class_sorting(self): - class Foo: # noqa: B903 + class Foo: def __init__(self, x: int) -> None: self.x = x @@ -961,8 +961,8 @@ def test(): print(1) def test_init_compiled_first(self): - @torch.jit.script # noqa: B903 - class Foo: # noqa: B903 + @torch.jit.script + class Foo: def __before_init__(self): # accessing this field should not throw, since __init__ should be compiled return self.x @@ -972,8 +972,8 @@ def __init__(self, x, y): self.y = y def test_class_constructs_itself(self): - @torch.jit.script # noqa: B903 - class LSTMStateStack: # noqa: B903 + @torch.jit.script + class LSTMStateStack: def __init__(self, num_layers: int, hidden_size: int) -> None: self.num_layers = num_layers self.hidden_size = hidden_size @@ -996,8 +996,8 @@ def __init__(self) -> None: self.x = 1 # should not throw - @torch.jit.script # noqa: B903 - class Tree: # noqa: B903 + @torch.jit.script + class Tree: def __init__(self) -> None: self.child = torch.jit.annotate(Optional[Leaf], None) @@ -1010,8 +1010,8 @@ def test_recursive_class(self): """ with self.assertRaises(RuntimeError): - @torch.jit.script # noqa: B903 - class Tree: # noqa: B903 + @torch.jit.script + class Tree: def __init__(self) -> None: self.parent = torch.jit.annotate(Optional[Tree], None) @@ -1229,7 +1229,7 @@ def method_defaults() -> float: self.checkScript(method_defaults, ()) # The constructor of this class below has some arguments without default values. - class ClassWithSomeDefaultArgs: # noqa: B903 + class ClassWithSomeDefaultArgs: def __init__( self, a: int, @@ -1251,7 +1251,7 @@ def set_b() -> int: # The constructor of this class below has mutable arguments. This should throw # an error. - class ClassWithMutableArgs: # noqa: B903 + class ClassWithMutableArgs: def __init__( self, a: List[int] = [1, 2, 3], # noqa: B006 @@ -1538,8 +1538,8 @@ def test_class_attribute_wrong_type(self): to an IValue that has an attribute of the wrong type. """ - @torch.jit.script # noqa: B903 - class ValHolder: # noqa: B903 + @torch.jit.script + class ValHolder: def __init__(self, val): self.val = val diff --git a/test/jit/test_cuda.py b/test/jit/test_cuda.py index 8cfe63faa0e6a..ade90b181e1c6 100644 --- a/test/jit/test_cuda.py +++ b/test/jit/test_cuda.py @@ -26,7 +26,7 @@ # If GPU is not available, then do not run the tests if not TEST_CUDA: print("CUDA not available, skipping tests", file=sys.stderr) - JitTestCase = NoTest # noqa: F811 + JitTestCase = NoTest TEST_LARGE_TENSOR = TEST_CUDA diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py index a695e0e9b3f7d..cbd21616f40e6 100644 --- a/test/jit/test_freezing.py +++ b/test/jit/test_freezing.py @@ -15,11 +15,13 @@ from torch.testing._internal.common_quantization import skipIfNoFBGEMM from torch.testing._internal.common_quantized import override_quantized_engine from torch.testing._internal.common_utils import ( + IS_ARM64, raise_on_run_directly, set_default_dtype, skipCUDAMemoryLeakCheckIf, skipIfTorchDynamo, TEST_WITH_ROCM, + xfailIf, ) from torch.testing._internal.jit_utils import JitTestCase from torch.utils import mkldnn as mkldnn_utils @@ -1050,8 +1052,8 @@ def forward(self, x): m_f = torch._C._freeze_module(m_s._c) def test_freeze_module_inlining(self): - @torch.jit.script # noqa: B903 - class Obj: # noqa: B903 + @torch.jit.script + class Obj: def __init__(self, x: int, y: int): self.x = x self.y = y @@ -1433,6 +1435,8 @@ def forward(self, x): self.assertTrue(fm.sub._has_method("method_a")) self.assertFalse(fm.sub._has_method("method_b")) + @xfailIf(IS_ARM64) + # see https://github.com/pytorch/pytorch/issues/177258 @skipIfNoFBGEMM def test_module_with_shared_type_instances(self): class Child(nn.Module): diff --git a/test/jit/test_generator.py b/test/jit/test_generator.py index 6fe3558206397..42dcba100104b 100644 --- a/test/jit/test_generator.py +++ b/test/jit/test_generator.py @@ -176,16 +176,16 @@ def forward(self, x): try: self.assertEqual(out1, out2) - except: # noqa: B001, E722 + except: print(f"Iteration {i}:\n{out1=}\n{out2=}") raise try: self.assertEqual(r1, r2) - except: # noqa: B001, E722 + except: print(f"Iteration {i}:\n{r1=}\n{r2=}") raise - except: # noqa: B001, E722 + except: print(loaded_module.forward.code) raise diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 1949ec46557dd..505e94d12bdc6 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -252,7 +252,7 @@ def foo(): self.checkScript(foo, ()) def foo2(): - x: List[int] = list() # noqa: C408 + x: List[int] = list() x.append(1) return (x,) @@ -328,7 +328,7 @@ def fn(): def test_dict_keyword_is_correctly_typed(self): def fn(): - x: Dict[str, int] = dict() # noqa: C408 + x: Dict[str, int] = dict() x["foo"] = 1 return x @@ -2027,7 +2027,7 @@ def no_args(): test_func(no_args, ()) def test_dict_constructor(): - a = dict() # noqa: C408 + a = dict() a["one"] = torch.tensor(1) return a, dict([(1, 2), (2, 3), (1, 4)]) # noqa: C406 @@ -2043,7 +2043,7 @@ def test_dict_initializer_list(): test_func(test_dict_initializer_list, ()) def test_dict_error(): - a = dict() # noqa: C408 + a = dict() a[1] = 2 return a diff --git a/test/jit/test_logging.py b/test/jit/test_logging.py index 37c379bde6c1b..26dcbfc6c6f26 100644 --- a/test/jit/test_logging.py +++ b/test/jit/test_logging.py @@ -1,5 +1,4 @@ # Owner(s): ["oncall: jit"] -# ruff: noqa: F841 import os import sys diff --git a/test/jit/test_pdt.py b/test/jit/test_pdt.py index ae48a0daa1df4..a16275098919c 100644 --- a/test/jit/test_pdt.py +++ b/test/jit/test_pdt.py @@ -19,7 +19,7 @@ "monkeytype is not installed. Skipping tests for Profile-Directed Typing", file=sys.stderr, ) - JitTestCase = NoTest # type: ignore[misc, assignment] # noqa: F811 + JitTestCase = NoTest # type: ignore[misc, assignment] class TestPDT(JitTestCase): diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py index f697e74ae9ac1..54c1148278bc0 100644 --- a/test/jit/test_save_load.py +++ b/test/jit/test_save_load.py @@ -181,7 +181,7 @@ class MyInterface: def not_bar(self, x: Tensor) -> Tensor: pass - @torch.jit.script # noqa: F811 + @torch.jit.script class ImplementInterface: # noqa: F811 def __init__(self) -> None: pass @@ -277,7 +277,7 @@ class MyInterface: def not_bar(self, x: Tensor) -> Tensor: pass - @torch.jit.script # noqa: F811 + @torch.jit.script class ImplementInterface: # noqa: F811 def __init__(self) -> None: pass @@ -900,7 +900,7 @@ class MyInterface: def not_bar(self, x: Tensor) -> Tensor: pass - @torch.jit.script # noqa: F811 + @torch.jit.script class ImplementInterface: # noqa: F811 def __init__(self) -> None: pass @@ -990,7 +990,7 @@ class MyInterface: def not_bar(self, x: Tensor) -> Tensor: pass - @torch.jit.script # noqa: F811 + @torch.jit.script class ImplementInterface: # noqa: F811 def __init__(self) -> None: pass diff --git a/test/jit/test_string_formatting.py b/test/jit/test_string_formatting.py index 295ae85e3fb98..06dc58053a2c4 100644 --- a/test/jit/test_string_formatting.py +++ b/test/jit/test_string_formatting.py @@ -184,7 +184,7 @@ def fn(arg1: str, arg2: str) -> str: def test_string_interpolation_with_unknown_format_specifier(self): @torch.jit.script def fn(arg1: str) -> str: - return "%a in template" % arg1 # noqa: F501 + return "%a in template" % arg1 with self.assertRaisesRegexWithHighlight( RuntimeError, diff --git a/test/jit/test_union.py b/test/jit/test_union.py index c5afa13463221..80eb73f09c552 100644 --- a/test/jit/test_union.py +++ b/test/jit/test_union.py @@ -125,8 +125,8 @@ def fn(x: Union[str, Color]) -> str: scripted(1) def test_union_in_class_constructor(self): - @torch.jit.script # noqa: B903 - class A: # noqa: B903 + @torch.jit.script + class A: def __init__(self, x: Union[int, str]) -> None: self.x = x diff --git a/test/jit/test_union_pep604.py b/test/jit/test_union_pep604.py index 953ce52c49786..619afe158d7bf 100644 --- a/test/jit/test_union_pep604.py +++ b/test/jit/test_union_pep604.py @@ -126,8 +126,8 @@ def fn(x: str | Color) -> str: scripted(1) def test_union_in_class_constructor(self): - @torch.jit.script # noqa: B903 - class A: # noqa: B903 + @torch.jit.script + class A: def __init__(self, x: int | str) -> None: self.x = x diff --git a/test/jit_hooks/CMakeLists.txt b/test/jit_hooks/CMakeLists.txt index ba32390b0b366..606625148f2e5 100644 --- a/test/jit_hooks/CMakeLists.txt +++ b/test/jit_hooks/CMakeLists.txt @@ -9,5 +9,5 @@ endif() find_package(Torch REQUIRED) add_executable(test_jit_hooks test_jit_hooks.cpp) -set_property(TARGET test_jit_hooks PROPERTY CXX_STANDARD 17) +set_property(TARGET test_jit_hooks PROPERTY CXX_STANDARD 20) target_link_libraries(test_jit_hooks "${TORCH_LIBRARIES}") diff --git a/test/lazy/test_functionalization.py b/test/lazy/test_functionalization.py index c563d1f99cb30..ad791ddd9473a 100644 --- a/test/lazy/test_functionalization.py +++ b/test/lazy/test_functionalization.py @@ -97,7 +97,7 @@ def text(lazyt): %0 = [Float[3]] lazy_tensors::device_data(), device=CPU0 %1 = [BFloat16[3]] aten::_to_copy(%0), dtype=BFloat16, layout=null, device=null, pin_memory=null, non_blocking=0, memory_format=null, ROOT=0 } -""", # noqa: B950 +""", ) diff --git a/test/lazy/test_generator.py b/test/lazy/test_generator.py index 36cf8c52df5d8..2a27ec328d2ac 100644 --- a/test/lazy/test_generator.py +++ b/test/lazy/test_generator.py @@ -3,7 +3,7 @@ import torch import torch._lazy.metrics as metrics import torch._lazy.ts_backend -from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase +from torch.testing._internal.common_utils import run_tests, TestCase torch._lazy.ts_backend.init() @@ -47,7 +47,6 @@ def generate_tensor(): if not torch.allclose(cpu_t2, lazy_t2.to("cpu")): raise AssertionError(f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}") - @skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type") def test_generator_causes_multiple_compiles(self): """ Test that inserting generators with different seed caused recompile diff --git a/test/lazy/test_ts_opinfo.py b/test/lazy/test_ts_opinfo.py index bc88867bd50b3..6d07e1c4f7907 100644 --- a/test/lazy/test_ts_opinfo.py +++ b/test/lazy/test_ts_opinfo.py @@ -201,7 +201,7 @@ class TestLazyOpInfo(TestCase): allowed_dtypes=(torch.float,), ) def test_dispatched_to_lazy(self, device, dtype, op): - def get_name(op): # noqa: F841 + def get_name(op): l = [op.name] if op.variant_test_name != "": l.append(op.variant_test_name) @@ -239,7 +239,7 @@ def get_name(op): # noqa: F841 and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST ], allowed_dtypes=(torch.float,), - ) # noqa: B950 + ) def test_correctness(self, device, dtype, op): test_device = get_test_device() @@ -284,7 +284,7 @@ def assert_allclose_rec(t): and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST ], allowed_dtypes=(torch.float,), - ) # noqa: B950 + ) def test_correctness_with_reusing_ir(self, device, dtype, op): torch._lazy.config.set_reuse_ir(True) test_device = get_test_device() diff --git a/test/metal/test_kernels.metal b/test/metal/test_kernels.metal new file mode 100644 index 0000000000000..eb337af74b838 --- /dev/null +++ b/test/metal/test_kernels.metal @@ -0,0 +1,12 @@ +#include +using namespace metal; + +kernel void square(device float *data [[buffer(0)]], + uint idx [[thread_position_in_grid]]) { + data[idx] = data[idx] * data[idx]; +} + +kernel void inc_inplace(device float *data [[buffer(0)]], + uint idx [[thread_position_in_grid]]) { + data[idx] = data[idx] + 1.0; +} diff --git a/test/metal/test_kernels.metallib b/test/metal/test_kernels.metallib new file mode 100644 index 0000000000000..3a74863cee321 Binary files /dev/null and b/test/metal/test_kernels.metallib differ diff --git a/test/mobile/custom_build/CMakeLists.txt b/test/mobile/custom_build/CMakeLists.txt index 52e713895ff8a..2629a8848d6b0 100644 --- a/test/mobile/custom_build/CMakeLists.txt +++ b/test/mobile/custom_build/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15) project(custom_build_project) -set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_CXX_STANDARD 20 CACHE STRING "The C++ standard whose features are requested to build this target.") # Find torch library find_package(Torch REQUIRED) diff --git a/test/mobile/test_upgraders.py b/test/mobile/test_upgraders.py index 3567e0d030b4c..340fc916e8b0f 100644 --- a/test/mobile/test_upgraders.py +++ b/test/mobile/test_upgraders.py @@ -31,8 +31,7 @@ def _try_fn(self, fn, *args, **kwargs): return e def test_versioned_div_tensor(self): - # noqa: F841 - def div_tensor_0_3(self, other): # noqa: F841 + def div_tensor_0_3(self, other): if self.is_floating_point() or other.is_floating_point(): return self.true_divide(other) return self.divide(other, rounding_mode="trunc") diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index 35f715a88c376..a7011253c0a84 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -58,6 +58,8 @@ def _get_cudnn_version(): GRADCHECK_NONDET_TOL, gradgradcheck, instantiate_parametrized_tests, + IS_ARM64, + IS_LINUX, MACOS_VERSION, MI300_ARCH, parametrize as parametrize_test, @@ -3197,6 +3199,8 @@ def _make_noncontiguous(inp): gradgradcheck(convolution, inputs, nondet_tol=gradcheck_nondet_tol) ) + @xfailIf(IS_LINUX and IS_ARM64) + # see https://github.com/pytorch/pytorch/issues/177245 @onlyCPU def test_conv_contiguous_for_oneDNN(self): # See https://github.com/pytorch/pytorch/issues/80837. @@ -3223,6 +3227,8 @@ def test_conv_contiguous_for_oneDNN(self): y_ = conv(x2) self.assertEqual(y, y_) + @xfailIf(IS_LINUX and IS_ARM64) + # see https://github.com/pytorch/pytorch/issues/177245 @onlyCPU def test_conv_ic1_channels_last_for_oneDNN(self): # See https://github.com/pytorch/pytorch/issues/82060, N > 1 will call in OneDNN path. @@ -3575,7 +3581,7 @@ def test_ConvTranspose3d_size_1_kernel(self, device): @dtypes(torch.float) @torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False) @torch.backends.miopen.flags(immediate=True) - @tf32_on_and_off(0.001) + @tf32_on_and_off(0.005) def test_Conv2d_naive_groups(self, device, dtype): # Check that grouped convolutions matches two half convolutions m = nn.Conv2d(4, 4, kernel_size=3, groups=2).to(device, dtype) @@ -3634,6 +3640,33 @@ def conv2d_depthwise(x, weight): with torch.backends.cudnn.flags(enabled=cudnn_enabled): torch.autograd.gradcheck(conv2d_depthwise, (x, weight)) + @onlyCUDA + @skipCUDAIfNoCudnn + @skipCUDAIfRocm + @dtypes(torch.half) + def test_Conv2d_depthwise_kernel_flag(self, device, dtype): + # Use shapes that qualify for the cuDNN depthwise path: + # FP16, depthwise (groups==channels), 4D, no dilation, >= 32 channels + channels = 32 + x = torch.randn(2, channels, 16, 16, device=device, dtype=dtype) + conv = nn.Conv2d( + channels, channels, kernel_size=3, padding=1, groups=channels + ).to(device, dtype) + + # All three modes should produce the same numerics + results = {} + for mode in ("auto", "cudnn", "native"): + with torch.backends.cudnn.flags( + enabled=True, + benchmark=False, + deterministic=True, + depthwise_kernel=mode, + ): + results[mode] = conv(x).detach().clone() + + self.assertEqual(results["cudnn"], results["native"], atol=1e-3, rtol=1e-3) + self.assertEqual(results["auto"], results["native"], atol=1e-3, rtol=1e-3) + @onlyCPU @dtypes(torch.float, torch.double) def test_conv_thnn_nhwc(self, device, dtype): diff --git a/test/nn/test_dropout.py b/test/nn/test_dropout.py index 9d26f3c9c16db..5110d87525634 100644 --- a/test/nn/test_dropout.py +++ b/test/nn/test_dropout.py @@ -11,7 +11,6 @@ from torch.testing._internal.common_device_type import ( dtypes, dtypesIfMPS, - expectedFailureMPS, expectedFailureXLA, instantiate_device_type_tests, ) @@ -216,7 +215,6 @@ def _test_dropoutNd_channel_zero(self, dropout, input): @expectedFailureXLA # seems like freeze_rng_state is not honoured by XLA @dtypes(torch.double) @dtypesIfMPS(torch.float32) - @expectedFailureMPS def test_Dropout1d(self, device, dtype): with set_default_dtype(dtype): N, C, L = ( @@ -289,7 +287,6 @@ def test_Dropout2d(self, device): self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5, inplace=True), input) @expectedFailureXLA # seems like freeze_rng_state is not honoured by XLA - @expectedFailureMPS # Failing on current pytorch MPS def test_Dropout3d(self, device): b = random.randint(1, 5) w = random.randint(1, 5) diff --git a/test/nn/test_embedding.py b/test/nn/test_embedding.py index 8f6847f18f5b2..297e0c5e6d5a5 100644 --- a/test/nn/test_embedding.py +++ b/test/nn/test_embedding.py @@ -323,6 +323,38 @@ def test_embedding_scalar_weight_error(self, device): with self.assertRaisesRegex(RuntimeError, "'weight' must be 2-D"): torch.nn.functional.embedding(indices, weight) + def test_embedding_float_indices_error(self, device): + # Regression test for https://github.com/pytorch/pytorch/issues/178042 + # torch.compile should raise the same dtype error as eager when + # nn.Embedding receives float indices, even if the result is unused. + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, kernel_size=1) + self.embedding = torch.nn.Embedding(100, 32) + self.fc = torch.nn.Linear(16 * 32 * 32, 10) + + def forward(self, x, token_ids): + conv_out = self.conv(x) + gelu_out = torch.nn.functional.gelu(conv_out) + self.embedding(token_ids) # result unused (dead code) + return self.fc(gelu_out.view(x.size(0), -1)) + + model = Model().to(device) + x = torch.randn(2, 3, 32, 32, device=device) + float_indices = torch.randn(2, 8, device=device) + + error_msg = "Expected tensor for argument #1 'indices'" + + with self.assertRaisesRegex(RuntimeError, error_msg): + model(x, float_indices) + + for backend in ["aot_eager", "inductor"]: + torch._dynamo.reset() + compiled = torch.compile(model, backend=backend, fullgraph=True) + with self.assertRaisesRegex(RuntimeError, error_msg): + compiled(x, float_indices) + @dtypesIfCUDA(torch.float16, torch.float64) @dtypesIfXPU(torch.float16, torch.float64) @dtypes(torch.float64) diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index 7d2564313a60f..00e14424f9ada 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -1542,7 +1542,7 @@ def hook_pre(mod, grad_output): # Input inplace error should throw an error with self.assertRaisesRegex( RuntimeError, - "Output 0 of BackwardHookFunctionBackward is " + "Output 0 of BackwardHookFunction is " "a view and is being modified inplace.", ): mod(inp.clone(), True) @@ -1554,7 +1554,7 @@ def hook_pre(mod, grad_output): local_inp[0] *= 1 with self.assertRaisesRegex( RuntimeError, - "Output 0 of BackwardHookFunctionBackward is " + "Output 0 of BackwardHookFunction is " "a view and its base or another view", ): # Any operation involving the view will fail here @@ -1564,8 +1564,7 @@ def hook_pre(mod, grad_output): out = mod(inp, False) with self.assertRaisesRegex( RuntimeError, - "BackwardHookFunctionBackward is a view " - "and is being modified inplace.", + "BackwardHookFunction is a view and is being modified inplace.", ): out += 1 diff --git a/test/onnx/torchlib/ops_test_data.py b/test/onnx/torchlib/ops_test_data.py index 916e180c0fdfd..510af85301fe5 100644 --- a/test/onnx/torchlib/ops_test_data.py +++ b/test/onnx/torchlib/ops_test_data.py @@ -1,5 +1,4 @@ # Owner(s): ["module: onnx"] -# flake8: noqa: B950 """Test op correctness by comparing with PyTorch results. ## Usage diff --git a/test/package/package_a/long_name.py b/test/package/package_a/long_name.py index dd315223e8562..681b708180d00 100644 --- a/test/package/package_a/long_name.py +++ b/test/package/package_a/long_name.py @@ -1,9 +1,8 @@ def add_function(d): - # noqa: B950 d.append( function_with_a_long_name_256charsplus_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx ) -def function_with_a_long_name_256charsplus_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx(): # noqa: B950 +def function_with_a_long_name_256charsplus_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx(): return 1337 diff --git a/test/package/package_a/std_sys_module_hacks.py b/test/package/package_a/std_sys_module_hacks.py index bb7435cb1243b..2158baa3dda17 100644 --- a/test/package/package_a/std_sys_module_hacks.py +++ b/test/package/package_a/std_sys_module_hacks.py @@ -1,7 +1,7 @@ -import os # noqa: F401 -import os.path # noqa: F401 -import typing # noqa: F401 -import typing.io # noqa: F401 +import os +import os.path +import typing +import typing.io import typing.re # noqa: F401 import torch diff --git a/test/package/package_a/std_sys_module_hacks_3_13.py b/test/package/package_a/std_sys_module_hacks_3_13.py index 1245dc79a0b8a..a16dc120bb64f 100644 --- a/test/package/package_a/std_sys_module_hacks_3_13.py +++ b/test/package/package_a/std_sys_module_hacks_3_13.py @@ -1,5 +1,5 @@ -import os # noqa: F401 -import os.path # noqa: F401 +import os +import os.path import typing # noqa: F401 import torch diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index a2e80c2d26e4b..75c0fb4504db6 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -6,6 +6,7 @@ import tempfile import unittest from typing import Any +from unittest.mock import patch import numpy as np @@ -222,11 +223,14 @@ def trace_handler(p): @unittest.skipIf(not kineto_available(), "Kineto is required") @skipIfHpu @skipIfTorchDynamo("profiler gets ignored if dynamo activated") + @patch.dict( + os.environ, + { + "ENABLE_PYTORCH_EXECUTION_TRACE": "1", + "ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS": "1", + }, + ) def test_execution_trace_env_enabled_with_kineto(self, device): - import os - - os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "1" - os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS"] = "1" trace_called_num = 0 def trace_handler(p): @@ -364,11 +368,14 @@ def test_execution_trace_alone(self, device): f"Expected {expected_loop_events} loop events, got {loop_count}" ) + @patch.dict( + os.environ, + { + "ENABLE_PYTORCH_EXECUTION_TRACE": "0", + "ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS": "0", + }, + ) def test_execution_trace_env_disabled(self, device): - import os - - os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "0" - os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS"] = "0" use_device = ( torch.profiler.ProfilerActivity.CUDA or torch.profiler.ProfilerActivity.HPU in supported_activities() @@ -466,17 +473,19 @@ def fn(a, b, c): "need triton and device(CUDA or XPU) availability to run", ) @skipCPUIf(True, "skip CPU device for testing profiling triton") + @patch.dict( + os.environ, + { + "ENABLE_PYTORCH_EXECUTION_TRACE": "1", + "ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS": "1", + }, + ) def test_execution_trace_env_enabled_with_pt2(self, device): # clean up the local cache for triton kernel from torch._inductor.codecache import PyCodeCache PyCodeCache.cache_clear(purge=True) - import os - - os.environ["ENABLE_PYTORCH_EXECUTION_TRACE"] = "1" - os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_EXTRAS"] = "1" - @torchdynamo.optimize("inductor") def fn(a, b, c): x = torch.nn.functional.linear(a, b) @@ -606,11 +615,11 @@ def fn(a, b, c): expected_graph = [ f'# %mm : Tensor "f32[4, 4][4, 1]{device}" = PlaceHolder[target=mm]', f'# %arg2_1 : Tensor "f32[4, 4][4, 1]{device}" = PlaceHolder[target=arg2_1]', - f'# %sin : Tensor "f32[4, 4][4, 1]{device}"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {{}})', # noqa: B950 - f'# %permute_1 : Tensor "f32[4, 4][1, 4]{device}"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {{}})', # noqa: B950 - f'# %mul : Tensor "f32[4, 4][4, 1]{device}"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {{}})', # noqa: B950 - f'# %add : Tensor "f32[4, 4][1, 4]{device}"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {{}})', # noqa: B950 - f'# %cos : Tensor "f32[4, 4][1, 4]{device}"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {{}})', # noqa: B950 + f'# %sin : Tensor "f32[4, 4][4, 1]{device}"[num_users=1] = call_function[target=torch.ops.aten.sin.default](args = (%mm,), kwargs = {{}})', + f'# %permute_1 : Tensor "f32[4, 4][1, 4]{device}"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%sin, [1, 0]), kwargs = {{}})', + f'# %mul : Tensor "f32[4, 4][4, 1]{device}"[num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg2_1, 1111), kwargs = {{}})', + f'# %add : Tensor "f32[4, 4][1, 4]{device}"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute_1, %mul), kwargs = {{}})', + f'# %cos : Tensor "f32[4, 4][1, 4]{device}"[num_users=1] = call_function[target=torch.ops.aten.cos.default](args = (%add,), kwargs = {{}})', "# return %cos", ] if len(fx_graph) < len(expected_graph): @@ -766,8 +775,10 @@ def fn(nt): not TEST_CUDA, "need CUDA device availability to run", ) + @patch.dict( + os.environ, {"ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_RANGE": "1"} + ) def test_execution_trace_record_integral_tensor_range(self): - os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_RANGE"] = "1" t1 = torch.tensor([[1, 2], [3, 4]]).cuda() t2 = torch.tensor([[0, 0], [1, 0]]).cuda() with ( @@ -805,13 +816,13 @@ def test_execution_trace_record_integral_tensor_range(self): not TEST_CUDA, "need CUDA device availability to run", ) + @patch.dict( + os.environ, + {"ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA": "aten::gather"}, + ) def test_execution_trace_record_integral_tensor_data(self): with tempfile.TemporaryDirectory() as temp_dir: fp_name = os.path.join(temp_dir, "test.et.json") - - os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_SAVE_INTEGRAL_TENSOR_DATA"] = ( - "aten::gather" - ) et = ExecutionTraceObserver() et.register_callback(fp_name) et.set_extra_resource_collection(True) diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py index c9b01054929eb..936e31225e82c 100644 --- a/test/profiler/test_memory_profiler.py +++ b/test/profiler/test_memory_profiler.py @@ -494,7 +494,7 @@ def f0(x, y): z = x.mul(y) return {"z": z.view_as(z)} - def f1(x, y): # noqa: F841 + def f1(x, y): with torch.no_grad(): return f0(x, y) @@ -1123,7 +1123,7 @@ def test_categories_e2e_simple_fwd(self) -> None: w1 = torch.ones((1,), requires_grad=True) def step_fn(_): - x = torch.ones((2, 2)) # noqa: F841 + x = torch.ones((2, 2)) y = torch.cat([x * w0, x * w1], dim=1) # noqa: F841 # NOTE: We expect that all unknown categories. This is simply a sanity diff --git a/test/profiler/test_memory_viz.js b/test/profiler/test_memory_viz.js new file mode 100644 index 0000000000000..45415e5acd238 --- /dev/null +++ b/test/profiler/test_memory_viz.js @@ -0,0 +1,1510 @@ +// Test cases for process_alloc_data in torch/utils/viz/MemoryViz.js +// Run: node test/profiler/test_memory_viz.js + +'use strict'; + +// Polyfill for Node < 18 +if (!Array.prototype.findLastIndex) { + Array.prototype.findLastIndex = function(pred) { + for (let i = this.length - 1; i >= 0; i--) { + if (pred(this[i], i, this)) return i; + } + return -1; + }; +} +if (!Array.prototype.at) { + Array.prototype.at = function(n) { + return n < 0 ? this[this.length + n] : this[n]; + }; +} + +// ============================================================ +// Load process_alloc_data from the actual MemoryViz.js source +// ============================================================ + +const fs = require('fs'); +const path = require('path'); +const vm = require('vm'); + +// Load process_alloc_data.js (ESM file) in Node.js CommonJS context. +// Strip the `export` line since Node CommonJS doesn't support ESM syntax. +const modPath = path.resolve(__dirname, '../../torch/utils/viz/process_alloc_data.js'); +let src = fs.readFileSync(modPath, 'utf-8'); +src = src.replace(/^export\s*\{[^}]*\};?\s*$/gm, ''); +const wrapper = `(function() { ${src}\nreturn { process_alloc_data, isPrivatePoolId, formatSize, formatAddr, elideRepeats }; })()`; +const { process_alloc_data, isPrivatePoolId, formatSize, formatAddr, elideRepeats } = vm.runInThisContext(wrapper, { filename: modPath }); + +// ============================================================ +// Test helpers +// ============================================================ + +function makeSnapshot({ traces = [], segments = [], categories = [] }) { + return { + device_traces: [traces], + segments, + categories, + }; +} + +let passed = 0; +let failed = 0; + +function assert(condition, msg) { + if (!condition) { + failed++; + console.error(` FAIL: ${msg}`); + console.trace(); + } else { + passed++; + } +} + +function assertEqual(actual, expected, msg) { + if (actual !== expected) { + failed++; + console.error(` FAIL: ${msg} — expected ${expected}, got ${actual}`); + } else { + passed++; + } +} + +function assertContains(str, substr, msg) { + if (!str.includes(substr)) { + failed++; + console.error(` FAIL: ${msg} — "${substr}" not found in "${str.slice(0, 200)}..."`); + } else { + passed++; + } +} + +// ============================================================ +// Tests +// ============================================================ + +function test_basic_alloc_free() { + console.log('test_basic_alloc_free'); + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 1000, size: 100, frames: [], stream: 0 }, + { action: 'free_completed', addr: 1000, size: 100, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0, total_size: 4096, segment_pool_id: [0, 0], + stream: 0, blocks: [], + }], + }); + const result = process_alloc_data(snapshot, 0, false, 15000, false); + assertEqual(result.max_size, 100, 'peak should be 100'); +} + +function test_free_completed_is_matched() { + console.log('test_free_completed_is_matched'); + // Only free_completed is matched (not free_requested). Verify alloc+free_completed works. + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 1000, size: 200, frames: [], stream: 0 }, + { action: 'free_completed', addr: 1000, size: 200, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0, total_size: 4096, segment_pool_id: [0, 0], + stream: 0, blocks: [], + }], + }); + const result = process_alloc_data(snapshot, 0, false, 15000, false); + assertEqual(result.max_size, 200, 'alloc+free_completed peak should be 200'); +} + +function test_pool_free_without_alloc_no_inflation() { + console.log('test_pool_free_without_alloc_no_inflation'); + // Simulate: a private pool block was allocated BEFORE trace recording, + // then freed during recording. The trace only has free_completed. + // With include_private_inactive=true, the block should be pre-loaded into + // pool state and then freed — NOT treated as a new allocation. + const poolId = [1, 42]; + const snapshot = makeSnapshot({ + traces: [ + // No alloc — it happened before recording started or got replaced in ring buffer + // Only the free_completed event is matched (free_requested is ignored). + { action: 'free_completed', addr: 0x788c2000000, size: 94371840, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x788c2000000, size: 256, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0x788c0000000, total_size: 83886080, segment_pool_id: poolId, + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + // The free_completed creates an element in initially_allocated + actions. + // Pre-loaded into pool: pool.active=90M, envelope=max(90M, 80M reserved)=90M. + // Then freed (pool.active=0), then small alloc (pool.active=256). + // Peak should be 94371840 (max of active and reserved), NOT double-counted. + assertEqual(result.max_size, 94371840, + 'pool free-without-alloc peak should be max(active, reserved)'); + + assertEqual(result.elements_length, 2, 'should have 2 elements'); + + assert(result.max_at_time.length > 0, 'max_at_time should not be empty'); + assertEqual(Math.max(...result.max_at_time), 94371840, + 'max_at_time peak should be max(active, reserved)'); + + const envelopes = result.allocations_over_time.filter( + d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envelopes.length, 1, 'should have 1 pool envelope'); + assertEqual(envelopes[0].elem, 'pool:1,42,s0', 'envelope key matches pool id and stream'); + + const stripes = result.allocations_over_time.filter( + d => typeof d.elem === 'number' && d.opacity === 0.5); + assertEqual(stripes.length, 2, 'should have 2 pool stripes'); + assertEqual(stripes[0].size, 94371840, 'first stripe: pre-loaded block (freed later)'); + assertEqual(stripes[1].size, 256, 'second stripe: small alloc after free'); +} + +function test_pool_alloc_then_free_normal() { + console.log('test_pool_alloc_then_free_normal'); + // Normal case: alloc + free both in trace, private pool, include_private_inactive=true + const poolId = [1, 7]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 2000, size: 500, frames: [], stream: 0 }, + { action: 'free_completed', addr: 2000, size: 500, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0, total_size: 8192, segment_pool_id: poolId, + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + // Envelope = segment reserved (8192), which is the pool's GPU footprint + assertEqual(result.max_size, 8192, 'pool envelope should be segment reserved (8192)'); +} + +function test_multiple_pool_frees_without_alloc() { + console.log('test_multiple_pool_frees_without_alloc'); + // Multiple blocks freed from the same pool without matching allocs. + // This is the FSDP scenario with many free_storage calls. + const poolId = [1, 99]; + const snapshot = makeSnapshot({ + traces: [ + // Each block has one free_completed (free_requested is ignored by the JS code) + { action: 'free_completed', addr: 1000, size: 500, frames: [], stream: 0 }, + { action: 'free_completed', addr: 2000, size: 500, frames: [], stream: 0 }, + { action: 'free_completed', addr: 3000, size: 500, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0, total_size: 16384, segment_pool_id: poolId, + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + // 3 blocks of 500 each were initially allocated. Pool envelope = 16384 (segment reserved). + // Then all 3 are freed. Peak should be 16384 (segment reserved, the initial state). + // BUG would give 32768+ (double-counted). + assert(result.max_size <= 16384, + `multiple pool frees should not inflate: got ${result.max_size}, expected <= 16384`); +} + +function test_non_pool_free_without_alloc() { + console.log('test_non_pool_free_without_alloc'); + // Non-pool block freed without matching alloc (ring buffer wrap). + // Should appear then disappear. + const snapshot = makeSnapshot({ + traces: [ + { action: 'free_completed', addr: 8000, size: 300, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0, total_size: 16384, segment_pool_id: [0, 0], + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, false); + // max_size is only updated inside the actions loop AFTER the free decrements total_mem, + // so for a free-without-alloc element, max_size ends up 0. + // The actual peak (300) is captured in max_at_time instead. + assertEqual(result.max_size, 0, 'non-pool free-without-alloc: max_size is 0 (peak is in max_at_time)'); + assert(Math.max(...result.max_at_time) === 300, + 'non-pool free-without-alloc: max_at_time peak should be 300'); +} + +function test_mixed_pool_and_nonpool() { + console.log('test_mixed_pool_and_nonpool'); + // Mix of pool and non-pool allocations + const poolId = [1, 5]; + const snapshot = makeSnapshot({ + traces: [ + // Non-pool alloc+free (addr within default pool segment) + { action: 'alloc', addr: 100, size: 200, frames: [], stream: 0 }, + // Pool alloc+free (addr within private pool segment) + { action: 'alloc', addr: 5000, size: 400, frames: [], stream: 0 }, + { action: 'free_completed', addr: 100, size: 200, frames: [], stream: 0 }, + { action: 'free_completed', addr: 5000, size: 400, frames: [], stream: 0 }, + ], + segments: [ + { device: 0, address: 0, total_size: 4096, segment_pool_id: [0, 0], stream: 0, blocks: [] }, + { device: 0, address: 4096, total_size: 8192, segment_pool_id: poolId, stream: 0, blocks: [] }, + ], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + // Peak: 200 (non-pool) + 8192 (pool envelope = segment reserved) = 8392 + assertEqual(result.max_size, 8392, 'mixed pool+nonpool peak: 200 + 8192 segment reserved'); +} + +function test_include_private_inactive_false_ignores_pools() { + console.log('test_include_private_inactive_false_ignores_pools'); + // When include_private_inactive=false, pool blocks should be treated as regular + const poolId = [1, 10]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 5000, size: 800, frames: [], stream: 0 }, + { action: 'free_completed', addr: 5000, size: 800, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 4096, total_size: 8192, segment_pool_id: poolId, + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, false); + // Without pool logic, it's just a regular alloc+free. Peak = 800. + assertEqual(result.max_size, 800, 'with include_private_inactive=false, peak = 800'); +} + +// ============================================================ +// formatSize tests +// ============================================================ + +function test_formatSize_bytes() { + console.log('test_formatSize_bytes'); + assertEqual(formatSize(0), '0.0B (0 bytes)', 'zero bytes'); + assertEqual(formatSize(512), '512.0B (512 bytes)', '512 bytes'); + assertEqual(formatSize(1), '1.0B (1 bytes)', '1 byte'); +} + +function test_formatSize_kib() { + console.log('test_formatSize_kib'); + // 1024 bytes = 1.0 KiB + assertEqual(formatSize(1024), '1.0KiB (1024 bytes)', '1 KiB'); + // 1536 = 1.5 * 1024 + assertEqual(formatSize(1536), '1.5KiB (1536 bytes)', '1.5 KiB'); +} + +function test_formatSize_mib_gib() { + console.log('test_formatSize_mib_gib'); + const mib = 1024 * 1024; + assertEqual(formatSize(mib), '1.0MiB (1048576 bytes)', '1 MiB'); + const gib = 1024 * 1024 * 1024; + assertEqual(formatSize(gib), '1.0GiB (1073741824 bytes)', '1 GiB'); +} + +function test_formatSize_no_bytes() { + console.log('test_formatSize_no_bytes'); + assertEqual(formatSize(1024, false), '1.0KiB', 'showBytes=false omits raw count'); + assertEqual(formatSize(512, false), '512.0B', 'showBytes=false for small values'); +} + +// ============================================================ +// formatAddr tests +// ============================================================ + +function test_formatAddr_block_event() { + console.log('test_formatAddr_block_event'); + const event = { action: 'alloc', addr: 0x7f4c00000, version: 3 }; + assertEqual(formatAddr(event), "b'7f4c00000_3", 'block alloc address'); +} + +function test_formatAddr_segment_event() { + console.log('test_formatAddr_segment_event'); + const event = { action: 'segment_alloc', addr: 0xabc, version: 0 }; + assertEqual(formatAddr(event), "s'abc_0", 'segment alloc address'); +} + +function test_formatAddr_free_event() { + console.log('test_formatAddr_free_event'); + const event = { action: 'free_completed', addr: 0xff, version: 5 }; + assertEqual(formatAddr(event), "b'ff_5", 'free_completed is a block event'); +} + +// ============================================================ +// elideRepeats tests +// ============================================================ + +function test_elideRepeats_no_repeats() { + console.log('test_elideRepeats_no_repeats'); + const result = elideRepeats(['a', 'b', 'c']); + assertEqual(result.join(','), 'a,b,c', 'no repeats passes through'); +} + +function test_elideRepeats_two_consecutive() { + console.log('test_elideRepeats_two_consecutive'); + // Two consecutive duplicates are kept as-is (not collapsed) + const result = elideRepeats(['a', 'a', 'b']); + assertEqual(result.join(','), 'a,a,b', 'two consecutive kept verbatim'); +} + +function test_elideRepeats_three_or_more() { + console.log('test_elideRepeats_three_or_more'); + // Three+ consecutive duplicates collapse to [frame, ""] + const result = elideRepeats(['x', 'x', 'x', 'x']); + assertEqual(result.length, 2, 'collapsed to 2 entries'); + assertEqual(result[0], 'x', 'first entry is the frame'); + assertEqual(result[1], '', 'second entry is repeat count'); +} + +function test_elideRepeats_mixed() { + console.log('test_elideRepeats_mixed'); + // Realistic: a stack trace with a recursive section in the middle + const result = elideRepeats(['top', 'recurse', 'recurse', 'recurse', 'recurse', 'recurse', 'bottom']); + assertEqual(result.join(','), 'top,recurse,,bottom', 'mixed with recursion'); +} + +function test_elideRepeats_empty() { + console.log('test_elideRepeats_empty'); + const result = elideRepeats([]); + assertEqual(result.length, 0, 'empty input gives empty output'); +} + +// ============================================================ +// context_for_id tests +// ============================================================ + +function test_context_for_id_with_pool() { + console.log('test_context_for_id_with_pool'); + // Alloc with a known private pool, stack frames, and stream + const poolId = [2, 7]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 0xabc000, size: 2048, version: 5, + frames: [{ filename: 'model.py', line: 42, name: 'forward' }], + stream: 3, timestamp: true, time_us: 1700000000000000 }, + ], + segments: [{ + device: 0, address: 0xa00000, total_size: 0x200000, segment_pool_id: poolId, + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, false); + const ctx = result.context_for_id(0); + + // Address should be formatted as hex with version + assertContains(ctx, "abc000", 'context includes hex address'); + assertContains(ctx, '_5', 'context includes version'); + // Size should be formatted + assertContains(ctx, '2.0KiB', 'context includes formatted size'); + assertContains(ctx, '2048 bytes', 'context includes raw byte count'); + // Pool ID resolved from segment + assertContains(ctx, 'pool_id (2, 7)', 'context includes resolved pool_id'); + // Stream + assertContains(ctx, 'stream 3', 'context includes stream'); + // Stack frame + assertContains(ctx, 'model.py:42:forward', 'context includes stack frame'); +} + +function test_context_for_id_unknown_pool() { + console.log('test_context_for_id_unknown_pool'); + // Alloc at address outside any segment → pool_id unknown + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 0xff0000, size: 512, version: 0, + frames: [{ filename: 'train.py', line: 10, name: 'step' }], + stream: 0, timestamp: null }, + ], + segments: [{ + device: 0, address: 0x100000, total_size: 0x100000, segment_pool_id: [0, 0], + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, false); + const ctx = result.context_for_id(0); + + assertContains(ctx, 'pool_id unknown', 'addr outside segments gives unknown pool'); + assertContains(ctx, 'ff0000', 'context includes hex address'); + assertContains(ctx, '512.0B', 'context includes size'); +} + +function test_context_for_id_free_without_alloc() { + console.log('test_context_for_id_free_without_alloc'); + // free_completed without matching alloc → "alloc not recorded" message + const snapshot = makeSnapshot({ + traces: [ + { action: 'free_completed', addr: 0xdead00, size: 4096, version: 2, + frames: [{ filename: 'fsdp.py', line: 750, name: 'free_storage' }], + stream: 0, timestamp: null }, + ], + segments: [{ + device: 0, address: 0xd00000, total_size: 0x100000, segment_pool_id: [0, 0], + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, false); + const ctx = result.context_for_id(0); + + assertContains(ctx, 'alloc not recorded', 'free-without-alloc shows warning'); + assertContains(ctx, 'dead00', 'context includes hex address'); + assertContains(ctx, '4.0KiB', 'context includes size'); + assertContains(ctx, 'fsdp.py:750:free_storage', 'context shows free stack trace'); +} + +// ============================================================ +// Post-PR#177717 tests: trace events carry segment_pool_id directly +// ============================================================ + +function test_post177717_pool_id_from_trace_event() { + console.log('test_post177717_pool_id_from_trace_event'); + // After PR#177717, trace events include segment_pool_id. The code should + // use it directly instead of falling back to find_pool_id from segments. + // Here the addr is OUTSIDE any segment, but pool_id is on the event itself. + const poolId = [3, 15]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 0xff0000, size: 1024, frames: [], stream: 0, + segment_pool_id: poolId }, + { action: 'free_completed', addr: 0xff0000, size: 1024, frames: [], stream: 0, + segment_pool_id: poolId }, + ], + // No segment covers 0xff0000 — pool_id comes from the trace event + segments: [], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + assertEqual(result.max_size, 1024, 'pool alloc/free with event-level pool_id'); + + const ctx = result.context_for_id(0); + assertContains(ctx, 'pool_id (3, 15)', 'pool_id resolved from trace event, not segment'); +} + +function test_post177717_pool_free_without_alloc_no_segment() { + console.log('test_post177717_pool_free_without_alloc_no_segment'); + // Post-177717: free_completed has segment_pool_id on the event. + // The segment was unmapped (not in segments list), but pool_id is still known. + const poolId = [1, 42]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'free_completed', addr: 0xdead00, size: 2048, frames: [], stream: 0, + segment_pool_id: poolId }, + ], + // Segment was unmapped — not present. Pool resolved from event. + segments: [], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + // Should be pre-loaded into pool and then freed, not inflated + assert(result.max_size <= 2048, + `post-177717 pool free-without-alloc should not inflate: got ${result.max_size}`); + + const ctx = result.context_for_id(0); + assertContains(ctx, 'pool_id (1, 42)', 'pool_id from event even without segment'); + assertContains(ctx, 'alloc not recorded', 'still shows alloc not recorded'); +} + +function test_post177717_mixed_events_with_and_without_pool_id() { + console.log('test_post177717_mixed_events_with_and_without_pool_id'); + // Some events have segment_pool_id (post-177717), others don't (pre-177717 + // or default pool). Verify both paths work together. + const poolId = [2, 8]; + const snapshot = makeSnapshot({ + traces: [ + // Default pool alloc — no segment_pool_id on event, resolved via segment + { action: 'alloc', addr: 100, size: 300, frames: [], stream: 0 }, + // Private pool alloc — segment_pool_id on the event + { action: 'alloc', addr: 0xf000, size: 500, frames: [], stream: 0, + segment_pool_id: poolId }, + { action: 'free_completed', addr: 100, size: 300, frames: [], stream: 0 }, + { action: 'free_completed', addr: 0xf000, size: 500, frames: [], stream: 0, + segment_pool_id: poolId }, + ], + segments: [ + // Only covers addr=100 (default pool). addr=0xf000 has no segment. + { device: 0, address: 0, total_size: 4096, segment_pool_id: [0, 0], + stream: 0, blocks: [] }, + ], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + // Peak: 300 (non-pool) + 500 (pool envelope) = 800 + assertEqual(result.max_size, 800, 'mixed pre/post-177717 events peak'); + + const ctx0 = result.context_for_id(0); + assertContains(ctx0, 'pool_id (0, 0)', 'default pool resolved from segment'); + + const ctx1 = result.context_for_id(1); + assertContains(ctx1, 'pool_id (2, 8)', 'private pool from event-level segment_pool_id'); +} + + +// ============================================================ +// Pool grouping by (pool_id, stream) tests +// ============================================================ + +function test_pool_grouped_by_stream() { + console.log('test_pool_grouped_by_stream'); + // Same pool_id but different streams should produce separate envelopes. + const poolId = [1, 5]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 1000, size: 500, frames: [], stream: 1 }, + { action: 'alloc', addr: 2000, size: 300, frames: [], stream: 2 }, + { action: 'free_completed', addr: 1000, size: 500, frames: [], stream: 1 }, + { action: 'free_completed', addr: 2000, size: 300, frames: [], stream: 2 }, + ], + segments: [ + { device: 0, address: 0, total_size: 4096, segment_pool_id: poolId, + stream: 0, blocks: [] }, + ], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + + const envelopes = result.allocations_over_time.filter( + d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envelopes.length, 2, 'should have 2 pool envelopes (one per stream)'); + + const keys = envelopes.map(e => e.elem).sort(); + assertEqual(keys[0], 'pool:1,5,s1', 'first envelope key includes stream 1'); + assertEqual(keys[1], 'pool:1,5,s2', 'second envelope key includes stream 2'); +} + +// ============================================================ +// Initially Added Blocks Tests +// ============================================================ + +function test_default_pool_ghost_block() { + console.log('test_segment_snapshot_no_trace'); + const poolId = [0, 0]; + const snapshot = makeSnapshot({ + traces: [], + segments: [{ + device: 0, address: 4096, total_size: 8192, segment_pool_id: poolId, + stream: 0, blocks: [ + { address: 5000, size: 1000, requested_size: 1000, state: 'active_allocated', + frames: [], version: 0 }, + ], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, false); + assertEqual(result.elements_length, 1, + 'snapshot-only block should not be added (include_private_inactive=false)'); + +} + +function test_segment_snapshot_with_trace_history() { + console.log('test_segment_snapshot_with_trace_history'); + const poolId = [1, 42]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 5000, size: 1000, frames: [], stream: 0 }, + { action: 'free_completed', addr: 5000, size: 1000, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 4096, total_size: 8192, segment_pool_id: poolId, + stream: 0, blocks: [ + { address: 5000, size: 1000, requested_size: 1000, state: 'inactive', + frames: [], version: 0 }, + ], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + assertEqual(result.elements_length, 1, + 'only trace element present, snapshot block not duplicated'); +} + +function test_segment_snapshot_no_trace() { + console.log('test_segment_snapshot_no_trace'); + const poolId = [1, 42]; + const snapshot = makeSnapshot({ + traces: [], + segments: [{ + device: 0, address: 4096, total_size: 8192, segment_pool_id: poolId, + stream: 0, blocks: [ + { address: 5000, size: 1000, requested_size: 1000, state: 'inactive', + frames: [], version: 0 }, + ], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + assertEqual(result.elements_length, 0, + 'snapshot-only block should not be added (include_private_inactive=true)'); + + const result2 = process_alloc_data(snapshot, 0, false, 15000, false); + assertEqual(result2.elements_length, 0, + 'snapshot-only block should not be added (include_private_inactive=false)'); +} + +function test_ghost_blocks() { + console.log('test_ghost_blocks'); + // Snapshot produced by (agent_space/test_ring_buffer_overflow.py): + // pre_record = torch.empty(1024 * 1024, device="cuda", dtype=torch.uint8) # 1 MiB + // torch.cuda.memory._record_memory_history(max_entries=10) + // early = torch.empty(2 * 1024 * 1024, device="cuda", dtype=torch.uint8) # 2 MiB + // for _ in range(15): # overflow the 10-entry ring buffer + // t = torch.empty(4 * 1024 * 1024, device="cuda", dtype=torch.uint8) # 4 MiB + // del t + // snap = torch.cuda.memory._snapshot() + // + // pre_record: allocated before recording → no trace event at all + // early: alloc event evicted from ring buffer by churn → no trace event + // Both are active_allocated in segment snapshot but invisible in trace. + const snapshot = makeSnapshot({ + traces: [ + { action: 'free_completed', addr: 0x6a00000, size: 4194304, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x6a00000, size: 4194304, frames: [], stream: 0 }, + { action: 'free_completed', addr: 0x6a00000, size: 4194304, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x6a00000, size: 4194304, frames: [], stream: 0 }, + { action: 'free_completed', addr: 0x6a00000, size: 4194304, frames: [], stream: 0 }, + ], + segments: [ + { device: 0, address: 0x1e00000, total_size: 2097152, segment_pool_id: [0, 0], + stream: 0, blocks: [ + { address: 0x1e00000, size: 1048576, requested_size: 1048576, + state: 'active_allocated', frames: [] }, + { address: 0x1f00000, size: 1048576, requested_size: 1048576, + state: 'inactive', frames: [] }, + ]}, + { device: 0, address: 0x6800000, total_size: 20971520, segment_pool_id: [0, 0], + stream: 0, blocks: [ + { address: 0x6800000, size: 2097152, requested_size: 2097152, + state: 'active_allocated', frames: [ + { filename: 'test.py', line: 10, name: 'early_alloc' }, + ]}, + { address: 0x6a00000, size: 18874368, requested_size: 18874368, + state: 'inactive', frames: [] }, + ]}, + ], + }); + + // Ghost blocks (active_allocated not in trace) show on both tabs. + for (const include_private of [false, true]) { + const label = `include_private_inactive=${include_private}`; + const result = process_alloc_data(snapshot, 0, false, 15000, include_private); + + // Trace creates 2 elements from alloc events + 1 from unmatched free. + // 2 active_allocated blocks from snapshot not in trace. + assertEqual(result.elements_length, 5, + `${label}: 3 trace elements + 2 snapshot blocks`); + + const aot = result.allocations_over_time; + const ghosts = aot.filter(d => d.ghost === true); + assertEqual(ghosts.length, 2, `${label}: should have 2 ghost block entries`); + + // Ghost block sizes match segment snapshot blocks + const ghost_sizes = ghosts.map(g => g.size).sort(); + assertEqual(ghost_sizes[0], 1048576, `${label}: ghost block 1 MiB (pre_record)`); + assertEqual(ghost_sizes[1], 2097152, `${label}: ghost block 2 MiB (ring buffer overflow)`); + + // context_for_id shows ghost explanation + const ghost_elem_ids = ghosts.map(g => g.elem); + for (const id of ghost_elem_ids) { + const ctx = result.context_for_id(id); + assertContains(ctx, '[Ghost block]', `${label}: context contains ghost label`); + assertContains(ctx, 'segment snapshot', `${label}: context explains source`); + } + } +} + +function test_ghost_blocks_not_created_for_traced_addrs() { + console.log('test_ghost_blocks_not_created_for_traced_addrs'); + // A block in the segment snapshot whose address DID appear in trace events + // should NOT be a ghost block. + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 1000, size: 512, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0, total_size: 4096, segment_pool_id: [0, 0], + stream: 0, blocks: [ + { address: 1000, size: 512, requested_size: 512, + state: 'active_allocated', frames: [] }, + ], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + const aot = result.allocations_over_time; + const ghosts = aot.filter(d => d.ghost === true); + assertEqual(ghosts.length, 0, 'no ghost blocks when addr is in trace'); + assertEqual(result.elements_length, 1, 'only the trace element'); +} + +function test_ghost_blocks_default_pool_collected() { + console.log('test_ghost_blocks_default_pool_collected'); + // Ghost blocks from default pool [0,0] should be collected + // from the segment snapshot when they have no trace events. + const snapshot = makeSnapshot({ + traces: [], + segments: [{ + device: 0, address: 0x1000, total_size: 8192, segment_pool_id: [0, 0], + stream: 0, blocks: [ + { address: 0x1000, size: 2048, requested_size: 2048, + state: 'active_allocated', frames: [] }, + { address: 0x1800, size: 1024, requested_size: 1024, + state: 'active_allocated', frames: [] }, + ], + }], + }); + + // Shows on both tabs + for (const include_private of [false, true]) { + const label = `include_private_inactive=${include_private}`; + const result = process_alloc_data(snapshot, 0, false, 15000, include_private); + const ghosts = result.allocations_over_time.filter(d => d.ghost === true); + assertEqual(ghosts.length, 2, `${label}: 2 ghost blocks from default pool`); + const sizes = ghosts.map(g => g.size).sort(); + assertEqual(sizes[0], 1024, `${label}: ghost 1024 bytes`); + assertEqual(sizes[1], 2048, `${label}: ghost 2048 bytes`); + } +} + +function test_ghost_blocks_not_in_segment_mode() { + console.log('test_ghost_blocks_not_in_segment_mode'); + // Ghost blocks should not be created in segment-level views + const snapshot = makeSnapshot({ + traces: [], + segments: [{ + device: 0, address: 0, total_size: 4096, segment_pool_id: [0, 0], + stream: 0, blocks: [ + { address: 100, size: 512, requested_size: 512, + state: 'active_allocated', frames: [] }, + ], + }], + }); + + const result_seg = process_alloc_data(snapshot, 0, true, 15000, false); + const ghosts_seg = result_seg.allocations_over_time.filter(d => d.ghost === true); + assertEqual(ghosts_seg.length, 0, 'no ghost blocks in segment_alloc mode'); + + // Even with include_private_inactive=true, segment modes should not create ghosts + const result_seg2 = process_alloc_data(snapshot, 0, true, 15000, true); + const ghosts_seg2 = result_seg2.allocations_over_time.filter(d => d.ghost === true); + assertEqual(ghosts_seg2.length, 0, 'no ghost blocks in segment_alloc mode (private pool tab)'); +} + +function test_ghost_blocks_private_pool() { + console.log('test_ghost_blocks_private_pool'); + // Snapshot produced by (agent_space/test_ghost_blocks_private_pool.py): + // pool = torch.cuda.MemPool() + // with torch.cuda.use_mem_pool(pool): + // pre_record_pool = torch.empty(1 MiB) # ghost in private pool (0,1) + // pre_record_default = torch.empty(2 MiB) # ghost in default pool (0,0) + // torch.cuda.memory._record_memory_history(max_entries=20) + // with torch.cuda.use_mem_pool(pool): + // traced_pool = torch.empty(3 MiB) # traced in private pool (0,1) + // traced_default = torch.empty(4 MiB) # traced in default pool (0,0) + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 0xe600000, size: 3145728, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x6a00000, size: 4194304, frames: [], stream: 0 }, + ], + segments: [ + // Private pool segment with ghost block + { device: 0, address: 0x1e00000, total_size: 2097152, segment_pool_id: [0, 1], + stream: 0, blocks: [ + { address: 0x1e00000, size: 1048576, requested_size: 1048576, + state: 'active_allocated', frames: [] }, + ]}, + // Default pool segment with ghost block + traced block + { device: 0, address: 0x6800000, total_size: 20971520, segment_pool_id: [0, 0], + stream: 0, blocks: [ + { address: 0x6800000, size: 2097152, requested_size: 2097152, + state: 'active_allocated', frames: [] }, + { address: 0x6a00000, size: 4194304, requested_size: 4194304, + state: 'active_allocated', frames: [] }, + ]}, + // Private pool segment with traced block + { device: 0, address: 0xe600000, total_size: 20971520, segment_pool_id: [0, 1], + stream: 0, blocks: [ + { address: 0xe600000, size: 3145728, requested_size: 3145728, + state: 'active_allocated', frames: [] }, + ]}, + ], + }); + + // With include_private_inactive=true, ghost blocks from BOTH pools are + // collected. Private pool ghosts go inside their pool envelope; default + // pool ghosts are rendered at the global bottom of the stacked area. + const result = process_alloc_data(snapshot, 0, false, 15000, true); + + const aot = result.allocations_over_time; + const ghosts = aot.filter(d => d.ghost === true); + assertEqual(ghosts.length, 2, '2 ghost blocks (one from each pool)'); + + // Default pool ghost (2 MiB): at global bottom, spans full timeline + const default_ghost = ghosts.find(g => g.size === 2097152); + assert(default_ghost !== undefined, 'default pool ghost (2 MiB) exists'); + assertEqual(default_ghost.offsets[0], 0, 'default ghost at offset 0'); + assertEqual(default_ghost.timesteps[0], 0, 'default ghost starts at timestep 0'); + assertEqual(default_ghost.timesteps.length, 2, 'default ghost has 2 timesteps'); + assert(default_ghost.timesteps[1] > 0, 'default ghost ends after timestep 0'); + + // Private pool ghost (1 MiB): inside the pool (0,1) envelope + const pool_ghost = ghosts.find(g => g.size === 1048576); + assert(pool_ghost !== undefined, 'private pool ghost (1 MiB) exists'); + assertEqual(pool_ghost.timesteps[0], 0, 'pool ghost starts at timestep 0'); + assert(pool_ghost.timesteps.at(-1) > 0, 'pool ghost ends after timestep 0'); + + // Both ghosts should end at the same final timestep + assertEqual(default_ghost.timesteps[1], pool_ghost.timesteps.at(-1), + 'both ghosts end at the same final timestep'); + + // Pool envelope only for (0,1) — default pool ghosts are at global bottom. + // Both pool segments use stream 0, so there should be exactly 1 envelope + // (not split by stream, and not "snull" from annotate_snapshot eliding streams). + const envelopes = aot.filter(d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envelopes.length, 1, '1 pool envelope'); + assertContains(envelopes[0].elem, '0,1,s0', 'envelope for pool (0,1)'); + + // Envelope initial size at timestep 0 = total pool reserved from snapshot + // (ghost segment 2 MiB + traced segment ~20 MiB = ~22 MiB) + const pool_reserved = 2097152 + 20971520; // sum of pool (0,1) segment total_sizes + const env = envelopes[0]; + assertEqual(env.timesteps[0], 0, 'envelope starts at timestep 0'); + assertEqual(env.size[0], pool_reserved, 'envelope initial size = total pool reserved'); + + // Ghost stripe should fit within the envelope + const env_offset = env.offsets[0]; + const ghost_offset = pool_ghost.offsets[0]; + assert(ghost_offset >= env_offset, 'ghost stripe offset >= envelope offset'); + assert(ghost_offset + 1048576 <= env_offset + env.size[0], + 'ghost stripe fits within envelope at timestep 0'); + + // Pool envelope max = total pool reserved (same as initial since no segment events) + const env_max_size = Array.isArray(env.size) + ? Math.max(...env.size) + : env.size; + assertEqual(env_max_size, pool_reserved, 'pool envelope max = total pool reserved'); + + // With include_private_inactive=false, ghosts still exist but no pool envelope + const result_false = process_alloc_data(snapshot, 0, false, 15000, false); + const ghosts_false = result_false.allocations_over_time.filter(d => d.ghost === true); + assertEqual(ghosts_false.length, 2, '2 ghost blocks on regular tab too'); + const envs_false = result_false.allocations_over_time.filter( + d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envs_false.length, 0, 'no pool envelopes when include_private_inactive=false'); +} + +function test_ghost_stripe_offset_with_multiple_pools() { + console.log('test_ghost_stripe_offset_with_multiple_pools'); + // When multiple private pools have initially_allocated blocks, pool envelopes + // are stacked. A ghost stripe must have its offset within its own envelope, + // not at the offset from when the stripe was first created (before other + // pools shifted it upward). + const snapshot = makeSnapshot({ + traces: [ + // Traced alloc in default pool so actions is non-empty + { action: 'alloc', addr: 0x100, size: 100, frames: [], stream: 0 }, + ], + segments: [ + // Default pool segment for the traced alloc + { device: 0, address: 0, total_size: 4096, segment_pool_id: [0, 0], + stream: 0, blocks: [] }, + // Pool (0,3) segment 1: one ghost block. + // Added to initially_allocated first among private pools. + { device: 0, address: 0x4000, total_size: 4096, segment_pool_id: [0, 3], + stream: 0, blocks: [ + { address: 0x4000, size: 3000, requested_size: 3000, + state: 'active_allocated', frames: [] }, + ]}, + // Pool (0,2): ghost block. Added to initially_allocated second. + { device: 0, address: 0x10000, total_size: 8192, segment_pool_id: [0, 2], + stream: 0, blocks: [ + { address: 0x10000, size: 2000, requested_size: 2000, + state: 'active_allocated', frames: [] }, + ]}, + // Pool (0,3) segment 2: another ghost block. + // Added to initially_allocated LAST. After reverse(), processed FIRST. + // This creates pool (0,3) at offset 0. Then pool (0,2) is processed + // (stripe at offset 1000). Then pool (0,3) segment 1's ghost block + // grows pool (0,3) from 1000 to 4000, shifting pool (0,2)'s envelope + // up by 3000 but (without fix) not its stripe. + { device: 0, address: 0x5000, total_size: 4096, segment_pool_id: [0, 3], + stream: 0, blocks: [ + { address: 0x5000, size: 1000, requested_size: 1000, + state: 'active_allocated', frames: [] }, + ]}, + ], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + const aot = result.allocations_over_time; + + const envelopes = aot.filter(d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envelopes.length, 2, '2 pool envelopes'); + + // Find the pool (0,2) envelope and its ghost stripe + const env02 = envelopes.find(e => e.elem.includes('0,2')); + assert(env02 !== undefined, 'pool (0,2) envelope exists'); + + const ghosts = aot.filter(d => d.ghost === true); + assertEqual(ghosts.length, 3, '3 ghost blocks (2 in pool 0,3 + 1 in pool 0,2)'); + // The ghost for pool (0,2) is the 2000-byte one + const ghost02 = ghosts.find(g => g.size === 2000); + assert(ghost02 !== undefined, 'pool (0,2) ghost exists'); + + // Ghost stripe offset must be within the envelope range + const env_offset = env02.offsets[0]; + const env_size = env02.size[0]; + const ghost_offset = ghost02.offsets[0]; + assertEqual(ghost_offset, env_offset, + `ghost offset should equal envelope offset (single block in pool)`); + assert(ghost_offset + 2000 <= env_offset + env_size, + `ghost fits within envelope: ${ghost_offset}+2000 <= ${env_offset}+${env_size}`); +} + +// ============================================================ +// Full snapshot integration test +// ============================================================ + +function test_full_snapshot_private_pools() { + console.log('test_full_snapshot_private_pools'); + // 512 KiB allocs so they're visible relative to the 2 MiB envelope + const S = 524288; + const snapshot = makeSnapshot({ + traces: [ + { action: 'segment_alloc', addr: 0x7f08cde00000, size: 2097152, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x7f08cde00000, size: S, frames: [], stream: 0 }, + { action: 'segment_alloc', addr: 0x7f08d2800000, size: 2097152, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x7f08d2800000, size: S, frames: [], stream: 0 }, + { action: 'free_requested', addr: 0x7f08cde00000, size: S, frames: [], stream: 0 }, + { action: 'free_completed', addr: 0x7f08cde00000, size: S, frames: [], stream: 0 }, + { action: 'free_requested', addr: 0x7f08d2800000, size: S, frames: [], stream: 0 }, + { action: 'free_completed', addr: 0x7f08d2800000, size: S, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x7f08cde00000, size: S, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x7f08d2800000, size: S, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x7f08d2880000, size: S, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x7f08cde80000, size: S, frames: [], stream: 0 }, + ], + segments: [ + { device: 0, address: 0x7f08cde00000, total_size: 2097152, + segment_pool_id: [0, 0], stream: 0, blocks: [] }, + { device: 0, address: 0x7f08d2800000, total_size: 2097152, + segment_pool_id: [0, 1], stream: 0, blocks: [] }, + ], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + + assertEqual(result.elements_length, 6, 'should have 6 elements'); + assertEqual(result.max_size, 2 * S + 2097152, 'peak = 2 default pool blocks + 2M envelope'); + + // Timeline (envelope = 2M from segment reserved, never shrinks): + // alloc 512K in default pool → 512K + // alloc 512K in private pool → 512K + 2M envelope = 2.5M + // free 512K from default pool → 2M (envelope stays) + // free 512K stripe in pool → 2M (stripe gone, envelope stays) + // alloc 512K in default pool → 2M + 512K = 2.5M + // alloc 512K stripe in pool → 2.5M (within envelope) + // alloc 512K stripe in pool → 2.5M (within envelope) + // alloc 512K in default pool → 2M + 2*512K = 3M + const expected_max_at_time = [ + S, S+2097152, S+2097152, S+2097152, S+2097152, S+2097152, + S+2097152, S+2097152, S+2097152, + 2097152, S+2097152, S+2097152, S+2097152, 2*S+2097152, + ]; + assertEqual(result.max_at_time.length, expected_max_at_time.length, + 'max_at_time length'); + for (let i = 0; i < expected_max_at_time.length; i++) { + assertEqual(result.max_at_time[i], expected_max_at_time[i], + `max_at_time[${i}]`); + } + + const aot = result.allocations_over_time; + + const envelopes = aot.filter(d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envelopes.length, 1, 'should have 1 pool envelope'); + assertEqual(envelopes[0].elem, 'pool:0,1,s0', 'envelope key matches pool (0,1)'); + + // Envelope size is driven by segment reserved (2 MiB), not just active allocs + const env_max = Array.isArray(envelopes[0].size) + ? Math.max(...envelopes[0].size) : envelopes[0].size; + assertEqual(env_max, 2097152, 'envelope = segment reserved (2 MiB)'); + + const stripes = aot.filter(d => typeof d.elem === 'number' && d.opacity === 0.5); + assertEqual(stripes.length, 3, 'should have 3 pool stripes (elements 1, 3, 4)'); + + const non_pool = aot.filter(d => typeof d.elem === 'number' && d.opacity === undefined); + assertEqual(non_pool.length, 3, 'should have 3 non-pool elements (0, 2, 5)'); + + const ctx0 = result.context_for_id(0); + assertContains(ctx0, '7f08cde00000', 'element 0 addr'); + assertContains(ctx0, 'pool_id (0, 0)', 'element 0 pool'); + + const ctx1 = result.context_for_id(1); + assertContains(ctx1, '7f08d2800000', 'element 1 addr'); + assertContains(ctx1, 'pool_id (0, 1)', 'element 1 pool'); +} + +function test_full_snapshot_no_private_pools() { + console.log('test_full_snapshot_no_private_pools'); + const snapshot = makeSnapshot({ + traces: [ + { action: 'segment_alloc', addr: 0x7f08cde00000, size: 2097152, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x7f08cde00000, size: 1024, frames: [], stream: 0 }, + { action: 'segment_alloc', addr: 0x7f08d2800000, size: 2097152, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x7f08d2800000, size: 1024, frames: [], stream: 0 }, + { action: 'free_requested', addr: 0x7f08cde00000, size: 1024, frames: [], stream: 0 }, + { action: 'free_completed', addr: 0x7f08cde00000, size: 1024, frames: [], stream: 0 }, + { action: 'free_requested', addr: 0x7f08d2800000, size: 1024, frames: [], stream: 0 }, + { action: 'free_completed', addr: 0x7f08d2800000, size: 1024, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x7f08cde00000, size: 1024, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x7f08d2800000, size: 1024, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x7f08d2800400, size: 1024, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x7f08cde00400, size: 1024, frames: [], stream: 0 }, + ], + segments: [ + { device: 0, address: 0x7f08cde00000, total_size: 2097152, + segment_pool_id: [0, 0], stream: 0, blocks: [] }, + { device: 0, address: 0x7f08d2800000, total_size: 2097152, + segment_pool_id: [0, 1], stream: 0, blocks: [] }, + ], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, false); + + assertEqual(result.elements_length, 6, 'should have 6 elements'); + assertEqual(result.max_size, 4096, 'peak memory should be 4096'); + + const expected_max_at_time = [ + 1024, 2048, 2048, 2048, 2048, 2048, + 1024, 1024, 2048, 3072, 4096, + ]; + assertEqual(result.max_at_time.length, expected_max_at_time.length, + 'max_at_time length'); + for (let i = 0; i < expected_max_at_time.length; i++) { + assertEqual(result.max_at_time[i], expected_max_at_time[i], + `max_at_time[${i}]`); + } + + const aot = result.allocations_over_time; + const envelopes = aot.filter(d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envelopes.length, 0, 'no pool envelopes'); + + const stripes = aot.filter(d => typeof d.elem === 'number' && d.opacity === 0.5); + assertEqual(stripes.length, 0, 'no pool stripes'); + + const regular = aot.filter(d => typeof d.elem === 'number'); + assertEqual(regular.length, 6, 'all 6 elements are regular'); + + const ctx0 = result.context_for_id(0); + assertContains(ctx0, '7f08cde00000', 'element 0 addr'); + assertContains(ctx0, 'Total memory used after allocation: 1.0KiB', 'element 0 total'); + assertContains(ctx0, 'pool_id (0, 0)', 'element 0 pool'); +} + +// ============================================================ +// Pool envelope reserved-memory tests +// ============================================================ + +function test_envelope_grows_on_segment_map() { + console.log('test_envelope_grows_on_segment_map'); + // When a private pool alloc triggers segment_map (fragmentation), the envelope + // should grow to the reserved size, not just the active allocation size. + // Scenario: alloc 500, free 500, alloc 400 triggers segment_map of 400 + // (can't reuse fragmented free blocks). Reserved = 500 + 400 = 900. + const poolId = [1, 10]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'segment_map', addr: 0x1000, size: 500, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x1000, size: 500, frames: [], stream: 0 }, + { action: 'free_completed', addr: 0x1000, size: 500, frames: [], stream: 0 }, + { action: 'segment_map', addr: 0x1200, size: 400, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x1200, size: 400, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0x1000, total_size: 900, segment_pool_id: poolId, + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + // Envelope should be 900 (reserved), not 500 (peak active) + const envelopes = result.allocations_over_time.filter( + d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envelopes.length, 1, 'should have 1 pool envelope'); + const env_max = Array.isArray(envelopes[0].size) + ? Math.max(...envelopes[0].size) : envelopes[0].size; + assertEqual(env_max, 900, 'envelope should be 900 (reserved), not 500 (peak active)'); + + // Stripes inside the envelope should reflect individual block sizes (500, 400), + // not grow to the envelope size. + const stripes = result.allocations_over_time.filter( + d => typeof d.elem === 'number' && d.opacity === 0.5); + assertEqual(stripes.length, 2, 'should have 2 pool stripes'); + const stripe_sizes = stripes.map(s => s.size).sort(); + assertEqual(stripe_sizes[0], 400, 'stripe for second alloc = 400'); + assertEqual(stripe_sizes[1], 500, 'stripe for first alloc = 500'); +} + +function test_envelope_from_initial_reserved() { + console.log('test_envelope_from_initial_reserved'); + // Segment reserved is 2000 but only 300 is actively allocated. + // No segment events in trace — initial reserved = snapshot reserved. + // Envelope should be 2000 (reserved), not 300 (active). + const poolId = [1, 20]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 0x2000, size: 300, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0x2000, total_size: 2000, segment_pool_id: poolId, + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + const envelopes = result.allocations_over_time.filter( + d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envelopes.length, 1, 'should have 1 envelope'); + const env_max = Array.isArray(envelopes[0].size) + ? Math.max(...envelopes[0].size) : envelopes[0].size; + assertEqual(env_max, 2000, 'envelope should be 2000 (segment reserved)'); +} + +function test_envelope_segment_map_no_double_count() { + console.log('test_envelope_segment_map_no_double_count'); + // Segment_map events in trace + snapshot reserved should not double-count. + // Trace has segment_map of 600. Snapshot total = 600. So initial reserved = 0. + // Envelope grows to 600 from the segment_map event. + const poolId = [1, 30]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'segment_map', addr: 0x3000, size: 600, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x3000, size: 600, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0x3000, total_size: 600, segment_pool_id: poolId, + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + const envelopes = result.allocations_over_time.filter( + d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envelopes.length, 1, 'should have 1 envelope'); + const env_max = Array.isArray(envelopes[0].size) + ? Math.max(...envelopes[0].size) : envelopes[0].size; + // Should be 600, not 1200 (double-counted) + assertEqual(env_max, 600, 'envelope should be 600 (no double count)'); +} + +function test_envelope_active_exceeds_reserved() { + console.log('test_envelope_active_exceeds_reserved'); + // Edge case: active > reserved (e.g., segment events lost from ring buffer). + // Envelope should use max(active, reserved). + // Scenario: segment_map of 500 in trace, snapshot total = 800. + // initial reserved = 800 - 500 = 300. After segment_map, reserved = 800. + // Then alloc of 800 → active=800 = reserved=800 → envelope=800. + // Now a second alloc of 200 with no segment event → active=1000 > reserved=800. + // Envelope should grow to 1000 (active exceeds reserved). + const poolId = [1, 40]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'segment_map', addr: 0x4000, size: 500, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x4000, size: 800, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x4400, size: 200, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0x4000, total_size: 1000, segment_pool_id: poolId, + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + const envelopes = result.allocations_over_time.filter( + d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envelopes.length, 1, 'should have 1 envelope'); + const env_max = Array.isArray(envelopes[0].size) + ? Math.max(...envelopes[0].size) : envelopes[0].size; + // active=1000 > reserved at time of second alloc (800), envelope = 1000 + assertEqual(env_max, 1000, 'envelope should be 1000 (active exceeds earlier reserved)'); +} + +function test_envelope_default_pool_unaffected() { + console.log('test_envelope_default_pool_unaffected'); + // Default pool (0,0) should NOT get envelope treatment regardless of segment events. + const snapshot = makeSnapshot({ + traces: [ + { action: 'segment_map', addr: 100, size: 5000, frames: [], stream: 0 }, + { action: 'alloc', addr: 100, size: 200, frames: [], stream: 0 }, + { action: 'free_completed', addr: 100, size: 200, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0, total_size: 5000, segment_pool_id: [0, 0], + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 15000, true); + const envelopes = result.allocations_over_time.filter( + d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envelopes.length, 0, 'default pool should not have envelope'); + assertEqual(result.max_size, 200, 'default pool peak = active only'); +} + +// ============================================================ +// Per-pool summarization tests +// ============================================================ + +function test_per_pool_summarization() { + console.log('test_per_pool_summarization'); + // 10 allocs in default pool (sizes 100..1000), 10 allocs in private pool + // (sizes 200..2000). Limit to 5 entries per pool. + // Expect: 5 drawn + 5 summarized per pool. + const poolId = [1, 50]; + const traces = []; + // Default pool: 10 allocs at 0x1000..0x1900, sizes 100,200,...,1000 + for (let i = 0; i < 10; i++) { + traces.push({ action: 'alloc', addr: 0x1000 + i * 0x100, size: (i + 1) * 100, + frames: [], stream: 0 }); + } + // Private pool: 10 allocs at 0x5000..0x5900, sizes 200,400,...,2000 + for (let i = 0; i < 10; i++) { + traces.push({ action: 'alloc', addr: 0x5000 + i * 0x100, size: (i + 1) * 200, + frames: [], stream: 0 }); + } + + const snapshot = makeSnapshot({ + traces, + segments: [ + { device: 0, address: 0x1000, total_size: 0x1000, segment_pool_id: [0, 0], + stream: 0, blocks: [] }, + { device: 0, address: 0x5000, total_size: 0x1000, segment_pool_id: poolId, + stream: 0, blocks: [] }, + ], + }); + + // Global top 5: the 5 largest across all pools. + // Private pool sizes: 200,400,...,2000. Default pool sizes: 100,200,...,1000. + // Top 5 globally = 2000,1800,1600,1400,1200 (all from private pool). + // Default pool: all 10 go to global summarized band. + // Private pool: top 5 drawn, bottom 5 (200+400+600+800+1000=3000) in per-pool summary. + const result = process_alloc_data(snapshot, 0, false, 5, true); + + const non_pool = result.allocations_over_time.filter( + d => typeof d.elem === 'number' && d.opacity === undefined); + assertEqual(non_pool.length, 0, 'default pool: 0 drawn (all smaller than top 5)'); + + const stripes = result.allocations_over_time.filter( + d => typeof d.elem === 'number' && d.opacity === 0.5); + assertEqual(stripes.length, 5, 'private pool: 5 drawn stripes'); + + const envelopes = result.allocations_over_time.filter( + d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + assertEqual(envelopes.length, 1, '1 pool envelope'); + + const pool_summaries = result.allocations_over_time.filter( + d => d.elem === 'summarized' && d.opacity === 0.3); + assertEqual(pool_summaries.length, 1, '1 per-pool summarized stripe'); + const ps_max = Math.max(...pool_summaries[0].size); + assertEqual(ps_max, 3000, 'per-pool summarized max = 3000'); + + // Global summarized band has ALL default pool allocs (100+200+...+1000=5500) + const global_summary = result.allocations_over_time.find( + d => d.elem === 'summarized' && d.opacity === undefined); + assert(global_summary !== undefined, 'global summarized band exists'); + const gs_max = Math.max(...global_summary.size); + assertEqual(gs_max, 5500, 'global summarized max = 5500 (all default pool allocs)'); + + assertEqual(result.elements_length, 20, 'elements_length = 20 (total elements)'); +} + +function test_per_pool_summarization_with_frees() { + console.log('test_per_pool_summarization_with_frees'); + // 6 allocs in private pool, limit to 3. Then free 2 drawn and 2 non-drawn. + // Verify summarized stripe shrinks on non-drawn frees. + const poolId = [1, 60]; + const snapshot = makeSnapshot({ + traces: [ + // 6 allocs: sizes 100,200,300,400,500,600 + // Top 3 drawn: 400,500,600. Summarized: 100,200,300. + { action: 'alloc', addr: 0x8100, size: 100, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x8200, size: 200, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x8300, size: 300, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x8400, size: 400, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x8500, size: 500, frames: [], stream: 0 }, + { action: 'alloc', addr: 0x8600, size: 600, frames: [], stream: 0 }, + // Free a drawn element (600) + { action: 'free_completed', addr: 0x8600, size: 600, frames: [], stream: 0 }, + // Free a non-drawn element (100) — summarized should shrink + { action: 'free_completed', addr: 0x8100, size: 100, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0x8000, total_size: 0x1000, segment_pool_id: poolId, + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 3, true); + + // After all events: drawn active = 400+500 = 900, summarized active = 200+300 = 500 + const stripes = result.allocations_over_time.filter( + d => typeof d.elem === 'number' && d.opacity === 0.5); + // 3 drawn stripes created (400, 500, 600), 600 was freed (closed out) + assertEqual(stripes.length, 3, '3 pool stripes created (one freed)'); + + const pool_summaries = result.allocations_over_time.filter( + d => d.elem === 'summarized' && d.opacity === 0.3); + assertEqual(pool_summaries.length, 1, '1 per-pool summarized stripe'); + + // Summarized stripe should show the shrink: peak was 600 (100+200+300), + // then 100 was freed → final = 500 (200+300) + const ps = pool_summaries[0]; + const ps_max = Math.max(...ps.size); + assertEqual(ps_max, 600, 'per-pool summarized peak = 600 (before non-drawn free)'); + const ps_final = ps.size.at(-1); + assertEqual(ps_final, 500, 'per-pool summarized final = 500 (after non-drawn free)'); +} + +function test_per_pool_summarization_initially_allocated() { + console.log('test_per_pool_summarization_initially_allocated'); + // 4 free_completed events in a private pool (no matching allocs). + // Limit to 2. The 2 largest should be drawn stripes, the 2 smallest + // should be in the per-pool summarized stripe. + const poolId = [1, 70]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'free_completed', addr: 0xa100, size: 100, frames: [], stream: 0 }, + { action: 'free_completed', addr: 0xa200, size: 200, frames: [], stream: 0 }, + { action: 'free_completed', addr: 0xa300, size: 300, frames: [], stream: 0 }, + { action: 'free_completed', addr: 0xa400, size: 400, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0xa000, total_size: 0x1000, segment_pool_id: poolId, + stream: 0, blocks: [], + }], + }); + + const result = process_alloc_data(snapshot, 0, false, 2, true); + + // Top 2 by size: 300, 400 → drawn stripes (pre-loaded then freed) + const stripes = result.allocations_over_time.filter( + d => typeof d.elem === 'number' && d.opacity === 0.5); + assertEqual(stripes.length, 2, '2 drawn stripes (largest initially_allocated)'); + const stripe_sizes = stripes.map(s => s.size).sort((a, b) => a - b); + assertEqual(stripe_sizes[0], 300, 'drawn stripe 300'); + assertEqual(stripe_sizes[1], 400, 'drawn stripe 400'); + + // Summarized: 100 + 200 = 300 initially, then freed to 0 + const pool_summaries = result.allocations_over_time.filter( + d => d.elem === 'summarized' && d.opacity === 0.3); + assertEqual(pool_summaries.length, 1, '1 per-pool summarized stripe'); + const ps_max = Math.max(...pool_summaries[0].size); + assertEqual(ps_max, 300, 'per-pool summarized peak = 300 (100+200)'); + assertEqual(pool_summaries[0].size.at(-1), 0, + 'per-pool summarized final = 0 (all freed)'); +} + +function test_per_pool_summarization_interleaved() { + console.log('test_per_pool_summarization_interleaved'); + // Drawn and non-drawn allocs interleaved: drawn stripes must not overlap + // with the summarized region. Alloc order: 1800 (drawn), 200 (non-drawn), + // 1600 (drawn), 100 (non-drawn). + const poolId = [1, 80]; + const snapshot = makeSnapshot({ + traces: [ + { action: 'alloc', addr: 0xc000, size: 1800, frames: [], stream: 0 }, + { action: 'alloc', addr: 0xc800, size: 200, frames: [], stream: 0 }, + { action: 'alloc', addr: 0xd000, size: 1600, frames: [], stream: 0 }, + { action: 'alloc', addr: 0xd800, size: 100, frames: [], stream: 0 }, + ], + segments: [{ + device: 0, address: 0xc000, total_size: 0x2000, segment_pool_id: poolId, + stream: 0, blocks: [], + }], + }); + + // max_entries=2: drawn = 1800, 1600. Non-drawn = 200, 100. + const result = process_alloc_data(snapshot, 0, false, 2, true); + + const stripes = result.allocations_over_time.filter( + d => typeof d.elem === 'number' && d.opacity === 0.5); + assertEqual(stripes.length, 2, '2 drawn stripes'); + + const pool_summaries = result.allocations_over_time.filter( + d => d.elem === 'summarized' && d.opacity === 0.3); + assertEqual(pool_summaries.length, 1, '1 per-pool summarized stripe'); + + // Summarized sits on top of drawn stripes (like global summarized band). + // Drawn stripes start at envelope base, summarized is above them. + const envelopes = result.allocations_over_time.filter( + d => typeof d.elem === 'string' && d.elem.startsWith('pool:')); + const env_base = envelopes[0].offsets.at(-1); + const sum_final = pool_summaries[0].size.at(-1); + assertEqual(sum_final, 300, 'summarized final = 300 (200+100)'); + + // Drawn stripes should start at envelope base + for (const stripe of stripes) { + const final_offset = stripe.offsets.at(-1); + assert(final_offset >= env_base, + `stripe offset ${final_offset} must be >= env_base(${env_base})`); + } + + // Summarized stripe offset should be at envelope_base + drawn_active + const sum_offset = pool_summaries[0].offsets.at(-1); + const drawn_tops = stripes.map(s => s.offsets.at(-1) + s.size); + const max_drawn_top = Math.max(...drawn_tops); + assert(sum_offset >= max_drawn_top - 1, + `summarized offset ${sum_offset} should be at or above top of drawn stripes ${max_drawn_top}`); +} + +// ============================================================ +// Run all tests +// ============================================================ + +test_basic_alloc_free(); +test_free_completed_is_matched(); +test_pool_free_without_alloc_no_inflation(); +test_pool_alloc_then_free_normal(); +test_multiple_pool_frees_without_alloc(); +test_non_pool_free_without_alloc(); +test_mixed_pool_and_nonpool(); +test_include_private_inactive_false_ignores_pools(); +test_formatSize_bytes(); +test_formatSize_kib(); +test_formatSize_mib_gib(); +test_formatSize_no_bytes(); +test_formatAddr_block_event(); +test_formatAddr_segment_event(); +test_formatAddr_free_event(); +test_elideRepeats_no_repeats(); +test_elideRepeats_two_consecutive(); +test_elideRepeats_three_or_more(); +test_elideRepeats_mixed(); +test_elideRepeats_empty(); +test_context_for_id_with_pool(); +test_context_for_id_unknown_pool(); +test_context_for_id_free_without_alloc(); +test_post177717_pool_id_from_trace_event(); +test_post177717_pool_free_without_alloc_no_segment(); +test_post177717_mixed_events_with_and_without_pool_id(); +test_pool_grouped_by_stream(); +test_segment_snapshot_with_trace_history(); +test_segment_snapshot_no_trace(); +test_default_pool_ghost_block(); +test_ghost_blocks(); +test_ghost_blocks_not_created_for_traced_addrs(); +test_ghost_blocks_default_pool_collected(); +test_ghost_blocks_not_in_segment_mode(); +test_ghost_blocks_private_pool(); +test_ghost_stripe_offset_with_multiple_pools(); +test_full_snapshot_private_pools(); +test_full_snapshot_no_private_pools(); +test_envelope_grows_on_segment_map(); +test_envelope_from_initial_reserved(); +test_envelope_segment_map_no_double_count(); +test_envelope_active_exceeds_reserved(); +test_envelope_default_pool_unaffected(); +test_per_pool_summarization(); +test_per_pool_summarization_with_frees(); +test_per_pool_summarization_initially_allocated(); +test_per_pool_summarization_interleaved(); + +console.log(`\n${passed} passed, ${failed} failed`); +process.exit(failed > 0 ? 1 : 0); diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 17af6338d0fc1..d4a0db0a1fe82 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -62,6 +62,7 @@ IS_JETSON, IS_LINUX, IS_WINDOWS, + IS_X86, parametrize, run_tests, serialTest, @@ -485,11 +486,11 @@ def join_threads(context: bool): self.assertEqual(len(observed_during_run), worker_threads) self.assertEqual(len(observed_during_run), len(set(observed_during_run))) - def payload(self, use_cuda=False): - x = torch.randn(10, 10) + def payload(self, use_cuda=False, tensor_size=10): + x = torch.randn(tensor_size, tensor_size) if use_cuda: x = x.cuda() - y = torch.randn(10, 10) + y = torch.randn(tensor_size, tensor_size) if use_cuda: y = y.cuda() z = torch.mm(x, y) @@ -1439,6 +1440,32 @@ def test_profiler_strides(self): if e["name"] == "aten::add": self.assertEqual(args["Input Strides"], [[17, 1], [25, 2], []]) + def test_profiler_strides_without_concrete_inputs(self): + torch._C._profiler._set_record_concrete_inputs_enabled_val(False) + try: + base_tensor = torch.randn(1024, dtype=torch.float32) + a = base_tensor.as_strided((16, 16), (17, 1), 0) + b = base_tensor.as_strided((16, 16), (25, 2), 272) + with _profile(record_shapes=True) as prof: + c = torch.add(a, b) + + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) + with open(fname) as f: + j = json.load(f) + op_events = [ + e for e in j["traceEvents"] if e.get("cat", "") == "cpu_op" + ] + for e in op_events: + args = e["args"] + if e["name"] == "aten::add": + self.assertIn("Input Strides", args) + self.assertEqual( + args["Input Strides"], [[17, 1], [25, 2], []] + ) + finally: + torch._C._profiler._set_record_concrete_inputs_enabled_val(True) + def test_profiler_fwd_bwd_link(self): with _profile(use_kineto=True) as prof: t1, t2 = ( @@ -2426,7 +2453,7 @@ def validate_json(prof, disable_external_correlation): disable_external_correlation=disable_external_correlation ), ) as prof: - self.payload(use_cuda=True) + self.payload(use_cuda=True, tensor_size=256) validate_json(prof, disable_external_correlation) @skipIfTorchDynamo("profiler gets ignored if dynamo activated") @@ -2699,21 +2726,28 @@ def test_public_api_post_processing_timeout_fails(self): y = torch.randn(10, 10) z = torch.mm(x, y) - @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") @unittest.skipIf(not kineto_available(), "Kineto is required") - def test_activity_filter_backward_compat(self): - """Plain activities=[CPU] still works unchanged.""" - with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as p: - x = torch.randn(10, 10).to("cuda") - y = torch.mm(x, x) + def test_profiler(self): + """Basic test for torch.profiler.profile API.""" + use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities() + activities = [ProfilerActivity.CPU] + if use_cuda: + activities.append(ProfilerActivity.CUDA) + with profile(activities=activities) as p: + self.payload(use_cuda=use_cuda) events = p.events() self.assertGreater(len(events), 0) - has_overhead = any( - "Lazy Function Loading" in e.name for e in events - ) # Lazy Function Loading is an OVERHEAD event - self.assertTrue(has_overhead) + found_mm = False + for e in events: + if "aten::mm" in e.name: + found_mm = True + self.assertTrue(found_mm) + if use_cuda: + gpu_events = [e for e in events if e.device_type == DeviceType.CUDA] + self.assertGreater(len(gpu_events), 0, "No GPU events captured by profiler") @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") + @unittest.skipIf(TEST_WITH_ROCM, "not supported on ROCm") @unittest.skipIf(not kineto_available(), "Kineto is required") def test_activity_filter_dict_syntax(self): """Dict syntax collects only the requested activity types.""" @@ -2790,6 +2824,42 @@ def test_activity_filter_empty_list(self): y = torch.mm(x, x) self.assertEqual(len(p.events()), 0) + @unittest.skipIf(not kineto_available(), "Kineto is required") + @unittest.skipIf(not TEST_CUDA, "CUDA is required") + def test_kineto_kernel_metadata_in_trace(self): + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + self.payload(use_cuda=True) + + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) + with open(fname) as f: + trace = json.load(f) + events = trace["traceEvents"] + kernel_events = [e for e in events if e.get("cat", "") == "kernel"] + self.assertGreater( + len(kernel_events), 0, "Error: No kernel events in trace" + ) + has_kernel_launch_metadata = False + for ke in kernel_events: + args = ke.get("args", {}) + name = ke.get("name", "") + for key in ["device", "stream", "correlation"]: + self.assertIn(key, args, f"kernel '{name}' missing '{key}'") + # Some kernel events on ROCm (__amd_rocclr...) do not have grid/block metadata + # so we just validate that it shows up for at least one event + has_grid = "grid" in args + has_block = "block" in args + self.assertEqual( + has_grid, + has_block, + f"kernel '{name}' should provide grid and block together", + ) + has_kernel_launch_metadata |= has_grid + self.assertTrue( + has_kernel_launch_metadata, + "Error: No kernel events in trace contained grid/block metadata", + ) + class SimpleNet(nn.Module): def __init__(self) -> None: @@ -3101,7 +3171,7 @@ def format_queue_depth(queue_depth_list, events): def test_utils_compute_queue_depth_when_no_cuda_events(self): # For traces with only cpu events, we expect empty queue depth list - x = torch.ones((1024, 1024)) + x = torch.ones((100, 100)) with profile() as prof: for _ in range(5): x = x @ x @@ -3151,7 +3221,7 @@ def test_utils_get_optimizable_events(self): ) def test_profiler_name_pattern(self): - x = torch.ones((4096, 4096)) + x = torch.ones((100, 100)) with profile() as prof: for _ in range(5): x = x @ x @@ -3415,7 +3485,9 @@ def test_profiler_pattern_matcher_json_report(self): actual_fields = sorted(event.keys()) self.assertEqual(expected_fields, actual_fields) - @unittest.skipIf(IS_ARM64 or not IS_LINUX, "x86 linux only cpp unwinding") + @unittest.skipIf( + not IS_LINUX or not (IS_X86 or IS_ARM64), "linux x86/aarch64 only cpp unwinding" + ) def test_fuzz_symbolize(self): # generate some random addresses in the text section and make sure the # symbolizers do not throw exceptions/crash @@ -3708,5 +3780,457 @@ def test_privateuse1_fallback_requires_use_cpu(self): ) +@unittest.skipIf(not kineto_available(), "Kineto is required") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required") +class TestProfilerEventsParity(TestCase): + """Tests validating parity between events() and export_chrome_trace() JSON.""" + + def test_python_function_events_in_events(self): + class DummyModule(nn.Module): + def forward(self, x): + return x + 1 + + mod = DummyModule() + with profile( + activities=[ProfilerActivity.CPU], + with_stack=True, + experimental_config=_ExperimentalConfig(verbose=True), + ) as prof: + mod(torch.randn(4, 4)) + + events = prof.events() + python_events = [e for e in events if e.is_python_function] + self.assertGreater(len(python_events), 0) + for e in python_events: + self.assertIsInstance(e.name, str) + self.assertGreater(e.time_range.end - e.time_range.start, 0) + + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) + with open(fname) as f: + trace = json.load(f) + + json_py = [ + e + for e in trace["traceEvents"] + if e.get("cat") == "python_function" and e.get("ph") == "X" + ] + self.assertEqual(len(python_events), len(json_py)) + + # Verify python_id/parent_id/module_id parity with JSON args + fe_mod = next((e for e in events if "DummyModule" in e.name), None) + self.assertIsNotNone(fe_mod) + self.assertGreater(fe_mod.python_id, 0) + self.assertGreaterEqual(fe_mod.python_module_id, 0) + + json_mod = next( + (e for e in json_py if "DummyModule" in e.get("name", "")), + None, + ) + self.assertIsNotNone(json_mod) + args = json_mod["args"] + self.assertEqual(fe_mod.python_id, args["Python id"]) + self.assertEqual(fe_mod.python_parent_id, args["Python parent id"]) + self.assertEqual(fe_mod.python_module_id, args["Python module id"]) + + def test_profiler_flow_events_parity(self): + """Verify that async CPU->GPU flow fields on events() match Chrome trace JSON.""" + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + x = torch.randn(32, 32, device="cuda") + torch.mm(x, x) + + # Collect async CPU->GPU flow info from events() + events_with_flow = [ + e for e in prof.events() if e.flow_id is not None and e.flow_id != 0 + ] + self.assertGreater( + len(events_with_flow), 0, "No flow events found via events()" + ) + + for e in events_with_flow: + self.assertIsInstance(e.flow_id, int) + self.assertIsInstance(e.flow_type, int) + self.assertIsInstance(e.flow_start, bool) + + # Verify parity with Chrome trace JSON for async CPU->GPU flow + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) + with open(fname) as f: + j = json.load(f) + + json_flow_events = [ + e + for e in j["traceEvents"] + if e.get("ph") in ("s", "f") and e.get("cat") == "ac2g" + ] + json_flow_starts = {e["id"] for e in json_flow_events if e["ph"] == "s"} + json_flow_ends = {e["id"] for e in json_flow_events if e["ph"] == "f"} + + # kLinkAsyncCpuGpu = 2 + ac2g_events = [e for e in events_with_flow if e.flow_type == 2] + events_flow_starts = {e.flow_id for e in ac2g_events if e.flow_start} + events_flow_ends = {e.flow_id for e in ac2g_events if not e.flow_start} + + self.assertEqual( + json_flow_starts, + events_flow_starts, + "Async CPU->GPU flow start IDs differ between events() and Chrome trace", + ) + self.assertEqual( + json_flow_ends, + events_flow_ends, + "Async CPU->GPU flow end IDs differ between events() and Chrome trace", + ) + + def test_profiler_fwdbwd_flow_events_parity(self): + """Verify that fwd->bwd flow fields on events() match Chrome trace JSON.""" + with profile(activities=[ProfilerActivity.CPU]) as prof: + t1 = torch.ones(1, requires_grad=True) + t2 = torch.ones(1, requires_grad=True) + z = torch.add(t1, t2) + y = torch.ones(1) + loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y) + loss.backward() + + fwdbwd_events = [ + e for e in prof.events() if e.flow_type == 1 and e.flow_id != 0 + ] + self.assertGreater( + len(fwdbwd_events), 0, "No fwdbwd flow events found via events()" + ) + + events_flow_starts = {e.flow_id for e in fwdbwd_events if e.flow_start} + events_flow_ends = {e.flow_id for e in fwdbwd_events if not e.flow_start} + + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) + with open(fname) as f: + j = json.load(f) + + json_flow_events = [ + e + for e in j["traceEvents"] + if e.get("ph") in ("s", "f") and e.get("cat") == "fwdbwd" + ] + json_flow_starts = {e["id"] for e in json_flow_events if e["ph"] == "s"} + json_flow_ends = {e["id"] for e in json_flow_events if e["ph"] == "f"} + + self.assertEqual( + json_flow_starts, + events_flow_starts, + "fwdbwd flow start IDs differ between events() and Chrome trace", + ) + self.assertEqual( + json_flow_ends, + events_flow_ends, + "fwdbwd flow end IDs differ between events() and Chrome trace", + ) + + def test_profiler_timestamp_consistency(self): + """Verify that FunctionEvent timestamps can reconstruct Chrome trace ts values.""" + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + x = torch.randn(32, 32, device="cuda") + torch.mm(x, x) + + trace_start_ns = prof.profiler.kineto_results.trace_start_ns() + + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) + with open(fname) as f: + j = json.load(f) + + # Chrome trace is relative to a different base time which is not exposed in Python. + # It's probably not important to do so as we still have the relative differences + # in duration. + base_time_ns = j.get("baseTimeNanoseconds", 0) + + # Grab mm timestamp from events() and json + fe_mm = next((e for e in prof.events() if e.name == "aten::mm"), None) + json_mm = next( + ( + e + for e in j["traceEvents"] + if e.get("name") == "aten::mm" and e.get("ph") == "X" + ), + None, + ) + + # Reconstruct Chrome trace ts from events(): + # absolute_ns = mm_op_start_us * 1000 + trace_start_ns + # chrome_ts = (absolute_ns - base_time_ns) / 1000 -> realign with json timeframe + absolute_ns = int(fe_mm.time_range.start * 1000) + trace_start_ns + recovered_ts = (absolute_ns - base_time_ns) / 1000 + self.assertEqual( + recovered_ts, + json_mm["ts"], + msg="Recovered Chrome trace ts doesn't match JSON for aten::mm", + ) + + def test_profiler_op_args_events_parity(self): + """Verify that cpu_op args on events() match Chrome trace JSON args.""" + base_tensor = torch.randn(1024, dtype=torch.float32) + a = base_tensor.as_strided((16, 16), (17, 1), 0) + b = base_tensor.as_strided((16, 16), (25, 2), 272) + t1 = torch.ones((64, 32)) + t2 = torch.ones((64, 32)) + with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof: + torch.add(a, b) + torch.cat([t1, t2]) + + fe_add = next((e for e in prof.events() if e.name == "aten::add"), None) + self.assertIsNotNone(fe_add) + fe_cat = next((e for e in prof.events() if e.name == "aten::cat"), None) + self.assertIsNotNone(fe_cat) + + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) + with open(fname) as f: + j = json.load(f) + + json_add = next( + ( + e + for e in j["traceEvents"] + if e.get("name") == "aten::add" and e.get("cat") == "cpu_op" + ), + None, + ) + self.assertIsNotNone(json_add) + args = json_add["args"] + self.assertEqual(fe_add.structured_input_shapes, args["Input Dims"]) + self.assertEqual(fe_add.structured_input_strides, args["Input Strides"]) + self.assertEqual(fe_add.input_dtypes, args["Input type"]) + + # Test a case with TensorList inputs -- structured_input_shapes + # should handle TensorList nesting correctly. + json_cat = next( + ( + e + for e in j["traceEvents"] + if e.get("name") == "aten::cat" and e.get("cat") == "cpu_op" + ), + None, + ) + self.assertIsNotNone(json_cat) + args_cat = json_cat["args"] + self.assertEqual(fe_cat.structured_input_shapes, args_cat["Input Dims"]) + self.assertEqual(fe_cat.structured_input_strides, args_cat["Input Strides"]) + self.assertEqual(fe_cat.input_dtypes, args_cat["Input type"]) + + def test_profiler_external_id_parity(self): + """Verify that FunctionEvent.external_id matches External id in Chrome trace JSON.""" + from collections import Counter + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + with torch.profiler.record_function("test_region"): + x = torch.randn(32, 32, device="cuda") + y = torch.mm(x, x) + z = y + x + z.cpu() + torch.cuda.synchronize() + + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) + with open(fname) as f: + j = json.load(f) + + json_name_ext = Counter( + (e["name"], e["args"]["External id"]) + for e in j["traceEvents"] + if e.get("args", {}).get("External id") is not None + ) + events_name_ext = Counter( + (ev.name, ev.external_id) for ev in prof.events() if ev.external_id != 0 + ) + + self.assertEqual( + events_name_ext, + json_name_ext, + "(name, external_id) pairs differ between events() and Chrome trace JSON", + ) + + def test_profiler_activity_type_parity(self): + """Verify activity_type on events() matches Chrome trace cat field.""" + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + x = torch.randn(32, 32, device="cuda") + torch.mm(x, x) + + events = prof.events() + for e in events: + self.assertIsInstance(e.activity_type, str) + self.assertGreater(len(e.activity_type), 0) + + mm_event = next((e for e in events if e.name == "aten::mm"), None) + self.assertIsNotNone(mm_event) + self.assertEqual(mm_event.activity_type, "cpu_op") + + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) + with open(fname) as f: + j = json.load(f) + + json_name_cats = { + (e["name"], e["cat"]) + for e in j["traceEvents"] + if e.get("ph") == "X" and "cat" in e + } + for e in events: + self.assertIn( + (e.name, e.activity_type), + json_name_cats, + f"activity_type mismatch for {e.name}", + ) + + def test_structured_metadata_matches_chrome_trace(self): + # Compare metadata fields between events() and Chrome trace JSON to make sure they stay in parity + # 1. Run a dummy workload with profiling enabled and collect the json/events() outputs + # 2. Parse each event instance in the json and events() to create a key->value mapping + # - The key is a tuple of metadata fields that should be unique for each event + # - The value is a dict of metadata fields for that event + # 3. Ensure that the keys and values match between the json and events() outputs + + from torch.autograd.profiler_util import _EVENT_METADATA_KEYS + + target_cats = ("cuda_runtime", "gpu_memcpy", "kernel") + allowed_non_structured_trace_keys = { + "External id", + "correlation", + "cbid", + "cid", + "device", + "kind", + "kernel", + "ptr", + "src", + "dst", + } + supported_trace_keys = set(_EVENT_METADATA_KEYS).union( + allowed_non_structured_trace_keys + ) + + def metadata_dict_from_trace_args(args): + out = {} + for kineto_key, (field_name, convert) in _EVENT_METADATA_KEYS.items(): + if kineto_key in args: + raw_value = args[kineto_key] + out[field_name] = ( + convert(raw_value) if isinstance(raw_value, str) else raw_value + ) + return out + + def metadata_dict_from_function_event(fe): + if fe.event_metadata is None: + return {} + + out = {} + for field_name, _ in _EVENT_METADATA_KEYS.values(): + val = getattr(fe.event_metadata, field_name) + if val is not None: + out[field_name] = val + return out + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + experimental_config=torch._C._profiler._ExperimentalConfig( + expose_kineto_event_metadata=True + ), + ) as prof: + x = torch.randn(10, 10, device="cuda") + y = torch.mm(x, x) + z = x + y + z.cpu() + + # Build a mapping from key to events() FunctionEvent metadata + event_records = {} + for fe in prof.events(): + if fe.external_id == 0 or fe.id == 0 or fe.activity_type not in target_cats: + continue + # Using just one of these keys could result in collisions, so try to uniquely identify the event with all of them + key = (fe.name, fe.activity_type, fe.external_id, fe.id) + self.assertNotIn( + key, + event_records, + f"Duplicate FunctionEvent record key encountered: {key}", + ) + event_records[key] = metadata_dict_from_function_event(fe) + + with TemporaryFileName(mode="w+") as fname: + prof.export_chrome_trace(fname) + with open(fname) as f: + trace = json.load(f) + + json_records = {} + # Track unexpected (event_name, cat, key) combos, deduplicated + unexpected_combos: set[tuple[str, str, str]] = set() + + # Loop through the trace events to perform a comparison + for te in trace["traceEvents"]: + cat = te.get("cat", "") + args = te.get("args", {}) + ext_id = args.get("External id") + correlation = args.get("correlation") + + if ext_id is None or correlation is None: + continue + if cat not in target_cats: + continue + + # Any metadata keys that show up in JSON should show up in events() + for k in set(args) - supported_trace_keys: + unexpected_combos.add((te["name"][:100], cat, k)) + + # Build the same key from JSON to try to match with a FunctionEvent + key = (te["name"], te["cat"], ext_id, correlation) + self.assertNotIn( + key, + json_records, + f"Duplicate Chrome trace record key encountered: {key}", + ) + json_records[key] = metadata_dict_from_trace_args(args) + + failure_msg = """\ +==================================================================================== +IMPORTANT: Are you making a Kineto change or bumping the third_party/kineto +submodule hash and seeing this message? + +New metadata keys (see below) were found in the Chrome trace JSON that are not +yet exposed through the profiler's events() API (i.e. EventMetadata in +torch/autograd/profiler_util.py). + +To fix this properly, you need to make sure the new Kineto data makes its way +to the events() property. The steps are: + +1. Add the new key(s) to _EVENT_METADATA_KEYS in torch/autograd/profiler_util.py + with the appropriate field name and type converter. +2. Add corresponding field(s) to the EventMetadata dataclass in the same file. +3. If the key should NOT be mapped (e.g. it duplicates an existing FunctionEvent + attribute), add it to allowed_non_structured_trace_keys in this test instead. + +For a model PR to follow, see: https://github.com/pytorch/pytorch/pull/180100 +====================================================================================""" + if unexpected_combos: + summary = "\n".join( + f" {name} ({cat}): {key!r}" + for name, cat, key in sorted(unexpected_combos) + ) + raise AssertionError(f"\n{failure_msg}\n\nUnmapped keys:\n{summary}") + + self.assertGreater(len(json_records), 0, "No device-side records were compared") + self.assertEqual( + set(event_records), + set(json_records), + "Device event identities differ between events() and Chrome trace JSON", + ) + + for key in json_records: + expected_meta = json_records[key] + actual_meta = event_records[key] + self.assertEqual( + actual_meta, + expected_meta, + f"{key}: structured metadata differs between events() and Chrome trace JSON", + ) + + if __name__ == "__main__": run_tests() diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index 29e3a61729c4f..04e1aad2e23d8 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -881,7 +881,7 @@ def test_profiler_experimental_tree_cuda(self): aten::add_ cudaLaunchKernel void at::native::vectorized_elementwise_kernel<...>(...) - [memory]""", # noqa: B950 + [memory]""", allow_failure=ALLOW_CUDA_FAILURE, ) @@ -1150,7 +1150,7 @@ def step(): enum.py(...): __hash__ - ...""", # noqa: B950 + ...""", allow_failure=ALLOW_CUDA_FAILURE, ) diff --git a/test/profiler/test_trace_validator.py b/test/profiler/test_trace_validator.py new file mode 100644 index 0000000000000..41646d1e5b4f1 --- /dev/null +++ b/test/profiler/test_trace_validator.py @@ -0,0 +1,273 @@ +# Owner(s): ["oncall: profiler"] + +import json +import os +import shutil +import tempfile +import unittest + +import torch +import torch.nn as nn +from torch._C._profiler import _ExperimentalConfig +from torch.profiler import profile, ProfilerActivity, record_function +from torch.profiler._trace_validator import ( + _check_backward_seq_id_uniqueness, + _check_gpu_kernel_causality, + _check_nccl_metadata, + _check_stream_sync_overlap, + _check_stream_wait_corr_id_in_past, + _check_stream_wait_corr_id_populated, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + skipIfTorchDynamo, + TestCase, +) + + +class _Bottleneck(nn.Module): + def __init__(self, in_ch, mid_ch, out_ch, stride=1): + super().__init__() + self.conv1 = nn.Conv2d(in_ch, mid_ch, 1, bias=False) + self.bn1 = nn.BatchNorm2d(mid_ch) + self.conv2 = nn.Conv2d(mid_ch, mid_ch, 3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(mid_ch) + self.conv3 = nn.Conv2d(mid_ch, out_ch, 1, bias=False) + self.bn3 = nn.BatchNorm2d(out_ch) + self.relu = nn.ReLU(inplace=True) + self.downsample = ( + nn.Sequential( + nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False), + nn.BatchNorm2d(out_ch), + ) + if in_ch != out_ch or stride != 1 + else None + ) + + def forward(self, x): + identity = x + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + if self.downsample is not None: + identity = self.downsample(x) + return self.relu(out + identity) + + +def _make_layer(in_ch, mid_ch, out_ch, n, stride=1): + layers = [_Bottleneck(in_ch, mid_ch, out_ch, stride=stride)] + for _ in range(n - 1): + layers.append(_Bottleneck(out_ch, mid_ch, out_ch)) + return nn.Sequential(*layers) + + +def _resnet50(): + return nn.Sequential( + nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.MaxPool2d(3, stride=2, padding=1), + _make_layer(64, 64, 256, n=3), + _make_layer(256, 128, 512, n=4, stride=2), + _make_layer(512, 256, 1024, n=6, stride=2), + _make_layer(1024, 512, 2048, n=3, stride=2), + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(2048, 1000), + ) + + +def _profile_resnet_payload(trace_path): + """Profile a ResNet50 training loop and return events.""" + device = torch.device("cuda:0") + model = _resnet50().to(device) + inputs = torch.randn(4, 3, 224, 224, device=device) + outputs = torch.rand_like(model(inputs)) + optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9) + loss_fn = nn.CrossEntropyLoss() + torch.cuda.synchronize() + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + experimental_config=_ExperimentalConfig(enable_cuda_sync_events=True), + ) as prof: + for _ in range(3): + optimizer.zero_grad(set_to_none=True) + with record_function("## forward ##"): + pred = model(inputs) + with record_function("## backward ##"): + loss_fn(pred, outputs).backward() + with record_function("## optimizer ##"): + optimizer.step() + + torch.cuda.synchronize() + prof.export_chrome_trace(trace_path) + return _load_events(trace_path) + + +def _profile_complex_payload(trace_path): + """Profile multi-stream + CUDA events + forward/backward and return events.""" + device = torch.device("cuda:0") + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + experimental_config=_ExperimentalConfig(enable_cuda_sync_events=True), + ) as prof: + x = torch.randn(32, 32, device=device, requires_grad=True) + y = torch.mm(x, x) + loss = y.sum() + loss.backward() + + s1 = torch.cuda.Stream() + s2 = torch.cuda.Stream() + event = torch.cuda.Event() + with torch.cuda.stream(s1): + a = torch.randn(64, 64, device=device) + _b = torch.mm(a, a) + event.record(s1) + s2.wait_event(event) + with torch.cuda.stream(s2): + _c = torch.mm(a, a) + s2.synchronize() + + torch.cuda.synchronize() + prof.export_chrome_trace(trace_path) + return _load_events(trace_path) + + +def _load_events(trace_path): + with open(trace_path) as f: + data = json.load(f) + return data.get("traceEvents", data) + + +class TestTraceValidatorRules(TestCase): + """Synthetic tests for rules not exercised by real E2E payloads.""" + + def test_nccl_metadata_pass(self): + events = [ + { + "ph": "X", + "name": "record_param_comms", + "ts": 100, + "dur": 10, + "args": { + "Collective name": "all_reduce", + "dtype": "float32", + "In msg nelems": 1024, + "Out msg nelems": 1024, + "Group size": 8, + }, + }, + ] + self.assertEqual(_check_nccl_metadata(events), []) + + def test_nccl_metadata_fail(self): + events = [ + { + "ph": "X", + "name": "record_param_comms", + "ts": 100, + "dur": 10, + "args": {"Collective name": "all_reduce"}, + }, + ] + v = _check_nccl_metadata(events) + self.assertEqual(len(v), 1) + + +@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") +@skipIfTorchDynamo("profiler tests do not work with dynamo") +@instantiate_parametrized_tests +class TestTraceValidatorE2E(TestCase): + _trace_dir: str = "" + _payloads: dict = {} + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls._trace_dir = tempfile.mkdtemp(prefix="profiler_e2e_trace_") + cls._payloads = { + "resnet": _profile_resnet_payload( + os.path.join(cls._trace_dir, "resnet.json") + ), + "complex": _profile_complex_payload( + os.path.join(cls._trace_dir, "complex.json") + ), + } + + @classmethod + def tearDownClass(cls): + if cls._trace_dir and os.path.isdir(cls._trace_dir): + shutil.rmtree(cls._trace_dir, ignore_errors=True) + super().tearDownClass() + + def _events(self, payload): + return self._payloads[payload] + + @staticmethod + def _fmt(violations, limit=5): + lines = [f" {v}" for v in violations[:limit]] + if len(violations) > limit: + lines.append(f" ... and {len(violations) - limit} more") + return "\n".join(lines) + + # TODO: unskip once kineto fixes CPU/GPU timestamp synchronization + @unittest.skip( + "kineto reports GPU kernel timestamps before cudaLaunchKernel due to clock skew" + ) + @parametrize("payload", ["resnet", "complex"]) + def test_gpu_kernel_causality(self, payload): + v = _check_gpu_kernel_causality(self._events(payload)) + self.assertEqual(len(v), 0, self._fmt(v)) + + # TODO: unskip once kineto populates wait_on_cuda_event_record_corr_id for cuStreamWaitEvent + @unittest.skip( + "kineto does not populate wait_on_cuda_event_record_corr_id for stream wait events (returns -1)" + ) + @parametrize("payload", ["resnet", "complex"]) + def test_stream_wait_corr_id_populated(self, payload): + v = _check_stream_wait_corr_id_populated(self._events(payload)) + self.assertEqual(len(v), 0, self._fmt(v)) + + # TODO: unskip once kineto stream sync event emission is verified in integration testing + @unittest.skip( + "kineto stream sync overlap detection not yet verified in kineto integration testing" + ) + @parametrize("payload", ["resnet", "complex"]) + def test_stream_sync_overlap(self, payload): + v = _check_stream_sync_overlap(self._events(payload)) + self.assertEqual(len(v), 0, self._fmt(v)) + + # TODO: unskip once kineto populates wait_on_cuda_event_record_corr_id for cuStreamWaitEvent + @unittest.skip( + "kineto wait_on_cuda_event_record_corr_id temporal ordering not yet verified in kineto integration testing" + ) + @parametrize("payload", ["resnet", "complex"]) + def test_stream_wait_corr_id_in_past(self, payload): + v = _check_stream_wait_corr_id_in_past(self._events(payload)) + self.assertEqual(len(v), 0, self._fmt(v)) + + # TODO: unskip once kineto NCCL collective metadata is verified in integration testing + @unittest.skip( + "kineto NCCL collective metadata not yet verified in kineto integration testing" + ) + @parametrize("payload", ["resnet", "complex"]) + def test_nccl_metadata(self, payload): + v = _check_nccl_metadata(self._events(payload)) + self.assertEqual(len(v), 0, self._fmt(v)) + + # TODO: unskip once kineto backward sequence ID emission is verified in integration testing + @unittest.skip( + "kineto backward sequence ID uniqueness not yet verified in kineto integration testing" + ) + @parametrize("payload", ["resnet", "complex"]) + def test_backward_seq_id_uniqueness(self, payload): + v = _check_backward_seq_id_uniqueness(self._events(payload)) + self.assertEqual(len(v), 0, self._fmt(v)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/python_native/test_dsl_registry.py b/test/python_native/test_dsl_registry.py new file mode 100644 index 0000000000000..8b6012aa49788 --- /dev/null +++ b/test/python_native/test_dsl_registry.py @@ -0,0 +1,362 @@ +# Owner(s): ["module: dsl-native-ops"] + +from unittest.mock import Mock + +from torch._vendor.packaging.version import Version +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestDSLRegistry(TestCase): + """Comprehensive tests for DSL registry functionality.""" + + def setUp(self): + """Set up clean registry state for each test""" + # Import registry here to avoid import-time side effects + from torch._native.dsl_registry import DSLRegistry + + # Save original registry state + self.original_registry = None + try: + from torch._native.dsl_registry import dsl_registry as original + + self.original_modules = original._dsl_modules.copy() + self.original_registry = original + except ImportError: + pass + + # Create isolated registry for testing + self.test_registry = DSLRegistry() + + def tearDown(self): + """Restore original registry state""" + if self.original_registry is not None: + # Restore original DSL modules + self.original_registry._dsl_modules.clear() + self.original_registry._dsl_modules.update(self.original_modules) + + # Clear any caches + if hasattr(self.original_registry.is_dsl_available, "cache_clear"): + self.original_registry.is_dsl_available.cache_clear() + if hasattr(self.original_registry.get_dsl_version, "cache_clear"): + self.original_registry.get_dsl_version.cache_clear() + if hasattr(self.original_registry.list_available_dsls, "cache_clear"): + self.original_registry.list_available_dsls.cache_clear() + + def create_valid_mock_dsl(self, name="test_dsl", available=True, version="1.0.0"): + """Helper to create valid mock DSL for testing""" + mock = Mock() + mock.runtime_available.return_value = available + mock.runtime_version.return_value = Version(version) if version else None + mock.deregister_op_overrides = Mock() + mock.register_op_override = Mock() + return mock + + def create_broken_mock_dsl( + self, break_method="runtime_available", error=ImportError + ): + """Helper to create mock DSL with broken methods for error testing""" + mock = self.create_valid_mock_dsl() + getattr(mock, break_method).side_effect = error("Simulated error") + return mock + + # Phase 1: Basic Registry Operations (4 methods) + + def test_register_dsl_basic(self): + """Test basic DSL registration functionality""" + mock_dsl = self.create_valid_mock_dsl() + self.test_registry.register_dsl("test_dsl", mock_dsl) + + # Verify registration + self.assertIn("test_dsl", self.test_registry.list_all_dsls()) + self.assertTrue(self.test_registry.is_dsl_available("test_dsl")) + + # Verify the DSL module is stored correctly + self.assertEqual(len(self.test_registry.list_all_dsls()), 1) + + def test_register_dsl_duplicates(self): + """Test duplicate DSL registration scenarios""" + scenarios = [ + ("same_module", "re-registered with same module"), + ("different_module", "re-registered with different module"), + ] + + for scenario, expected_log_contains in scenarios: + with self.subTest(scenario=scenario): + # Register initial DSL + mock_dsl1 = self.create_valid_mock_dsl() + dsl_name = f"duplicate_test_{scenario}" + self.test_registry.register_dsl(dsl_name, mock_dsl1) + + # Register duplicate based on scenario + if scenario == "same_module": + mock_dsl2 = mock_dsl1 # Same module object + else: + mock_dsl2 = self.create_valid_mock_dsl() # Different module object + + # Capture logging + with self.assertLogs( + "torch._native.dsl_registry", level="DEBUG" + ) as log_capture: + self.test_registry.register_dsl(dsl_name, mock_dsl2) + + # Verify appropriate logging + log_messages = " ".join(log_capture.output) + self.assertIn(expected_log_contains, log_messages) + + # Verify DSL is still registered + self.assertIn(dsl_name, self.test_registry.list_all_dsls()) + + def test_get_dsl_version(self): + """Test DSL version querying with various conditions""" + test_cases = [("1.2.3", Version), (None, type(None)), ("error", type(None))] + + for version_setup, expected_result_type in test_cases: + with self.subTest(version_setup=version_setup): + if version_setup == "error": + mock_dsl = self.create_broken_mock_dsl( + "runtime_version", RuntimeError + ) + else: + mock_dsl = self.create_valid_mock_dsl(version=version_setup) + + dsl_name = f"version_test_{version_setup}" + self.test_registry.register_dsl(dsl_name, mock_dsl) + + result = self.test_registry.get_dsl_version(dsl_name) + self.assertIsInstance(result, expected_result_type) + + if version_setup == "1.2.3": + self.assertEqual(result, Version("1.2.3")) + + def test_list_dsls_operations(self): + """Test list_all_dsls and list_available_dsls""" + # Register mix of available/unavailable DSLs + available_dsl = self.create_valid_mock_dsl("available", available=True) + unavailable_dsl = self.create_valid_mock_dsl("unavailable", available=False) + + self.test_registry.register_dsl("available_dsl", available_dsl) + self.test_registry.register_dsl("unavailable_dsl", unavailable_dsl) + + # Test list operations + all_dsls = self.test_registry.list_all_dsls() + available_dsls = self.test_registry.list_available_dsls() + + # Verify results + self.assertEqual(set(all_dsls), {"available_dsl", "unavailable_dsl"}) + self.assertEqual(set(available_dsls), {"available_dsl"}) + + # Verify available is subset of all + self.assertTrue(set(available_dsls).issubset(set(all_dsls))) + + # Phase 2: Input Validation (2 methods) + + def test_register_dsl_invalid_names(self): + """Test registration with invalid name inputs""" + test_cases = [ + (None, TypeError), + (123, TypeError), + ([], TypeError), + ({}, TypeError), + ("", ValueError), + (" ", ValueError), + ("\t\n", ValueError), + ] + + mock_dsl = self.create_valid_mock_dsl() + + for invalid_name, expected_exception in test_cases: + with self.subTest(invalid_name=repr(invalid_name)): + with self.assertRaises(expected_exception): + self.test_registry.register_dsl(invalid_name, mock_dsl) + + def test_register_dsl_valid_names(self): + """Test registration with valid name formats""" + valid_names = ["triton", "cutedsl", "my_dsl", "dsl_v2", "a", "test_dsl_123"] + + for valid_name in valid_names: + with self.subTest(valid_name=valid_name): + mock_dsl = self.create_valid_mock_dsl() + self.test_registry.register_dsl(valid_name, mock_dsl) + self.assertIn(valid_name, self.test_registry.list_all_dsls()) + + # Phase 3: Error Handling (3 methods) + + def test_is_dsl_available_errors(self): + """Test is_dsl_available when runtime_available() raises errors""" + error_types = [ImportError, ModuleNotFoundError, RuntimeError, AttributeError] + + for error_type in error_types: + with self.subTest(error_type=error_type.__name__): + mock_dsl = self.create_broken_mock_dsl("runtime_available", error_type) + dsl_name = f"broken_dsl_{error_type.__name__}" + self.test_registry.register_dsl(dsl_name, mock_dsl) + + result = self.test_registry.is_dsl_available(dsl_name) + self.assertEqual(result, False) # All errors should result in False + + def test_get_dsl_version_errors(self): + """Test get_dsl_version when runtime_version() raises errors""" + error_types = [ImportError, RuntimeError, AttributeError, TypeError] + + for error_type in error_types: + with self.subTest(error_type=error_type.__name__): + mock_dsl = self.create_broken_mock_dsl("runtime_version", error_type) + dsl_name = f"broken_version_dsl_{error_type.__name__}" + self.test_registry.register_dsl(dsl_name, mock_dsl) + + result = self.test_registry.get_dsl_version(dsl_name) + self.assertIsNone(result) + + def test_nonexistent_dsl_queries(self): + """Test querying non-existent DSLs returns appropriate defaults""" + # Test with empty registry + self.assertFalse(self.test_registry.is_dsl_available("nonexistent")) + self.assertIsNone(self.test_registry.get_dsl_version("nonexistent")) + self.assertEqual(self.test_registry.list_all_dsls(), []) + self.assertEqual(self.test_registry.list_available_dsls(), []) + + # Test with some DSLs registered but querying non-existent + mock_dsl = self.create_valid_mock_dsl() + self.test_registry.register_dsl("existing", mock_dsl) + + self.assertFalse(self.test_registry.is_dsl_available("still_nonexistent")) + self.assertIsNone(self.test_registry.get_dsl_version("still_nonexistent")) + + # Phase 4: Protocol & Integration (4 methods) + + def test_dsl_protocol_interface(self): + """Test DSL modules implement complete protocol""" + # Test with actual registered DSLs from the global registry + if self.original_registry: + all_dsls = self.original_registry.list_all_dsls() + + for dsl_name in all_dsls: + # Get the actual DSL module from registry + dsl_module = self.original_registry._dsl_modules.get(dsl_name) + if dsl_module: + # Verify all required methods exist and are callable + required_methods = [ + "runtime_available", + "runtime_version", + "deregister_op_overrides", + "register_op_override", + ] + + for method_name in required_methods: + self.assertTrue( + hasattr(dsl_module, method_name), + f"DSL '{dsl_name}' missing method '{method_name}'", + ) + method = getattr(dsl_module, method_name) + self.assertTrue( + callable(method), + f"DSL '{dsl_name}' method '{method_name}' is not callable", + ) + + def test_real_dsl_integration(self): + """Test integration with actual DSL modules""" + if not self.original_registry: + self.skipTest("Original registry not available") + + dsl_names = ["triton", "cutedsl"] + + for dsl_name in dsl_names: + with self.subTest(dsl_name=dsl_name): + # Verify DSL is registered in original registry + all_dsls = self.original_registry.list_all_dsls() + if dsl_name not in all_dsls: + self.skipTest( + f"DSL '{dsl_name}' not registered in original registry" + ) + + # Test all registry operations work + availability = self.original_registry.is_dsl_available(dsl_name) + self.assertIsInstance(availability, bool) + + version = self.original_registry.get_dsl_version(dsl_name) + self.assertTrue(version is None or isinstance(version, Version)) + + # Verify DSL appears in appropriate lists + if availability: + self.assertIn( + dsl_name, self.original_registry.list_available_dsls() + ) + self.assertIn(dsl_name, self.original_registry.list_all_dsls()) + + def test_common_utils_wrappers(self): + """Test common_utils wrapper functions work correctly""" + from torch.testing._internal.common_utils import ( + get_all_dsls, + get_available_dsls, + ) + + # Compare wrapper results with direct registry calls + if self.original_registry: + self.assertEqual(get_all_dsls(), self.original_registry.list_all_dsls()) + self.assertEqual( + get_available_dsls(), self.original_registry.list_available_dsls() + ) + + def test_skip_decorators(self): + """Test DSL skip decorators work with registry""" + from torch.testing._internal.common_utils import ( + skipIfDSLUnavailable, + skipUnlessDSLAvailable, + ) + + # Test decorators are callable + self.assertTrue(callable(skipIfDSLUnavailable)) + self.assertTrue(callable(skipUnlessDSLAvailable)) + + # Test decorator creation works + decorator1 = skipIfDSLUnavailable("nonexistent_dsl") + decorator2 = skipUnlessDSLAvailable("triton") + + self.assertTrue(callable(decorator1)) + self.assertTrue(callable(decorator2)) + + # Phase 5: Test Infrastructure (2 methods) + + def test_registry_isolation(self): + """Test registry state can be saved and restored""" + # Verify we start with clean test registry + self.assertEqual(self.test_registry.list_all_dsls(), []) + + # Register test DSL + mock_dsl = self.create_valid_mock_dsl() + self.test_registry.register_dsl("isolation_test", mock_dsl) + + # Verify registration + self.assertIn("isolation_test", self.test_registry.list_all_dsls()) + + # Verify original registry is unaffected + if self.original_registry: + original_dsls = self.original_registry.list_all_dsls() + self.assertNotIn("isolation_test", original_dsls) + + def test_mock_dsl_helpers(self): + """Test mock DSL creation utilities work correctly""" + # Test create_valid_mock_dsl + valid_mock = self.create_valid_mock_dsl("test", available=True, version="2.1.0") + + self.assertTrue(callable(valid_mock.runtime_available)) + self.assertTrue(callable(valid_mock.runtime_version)) + self.assertTrue(callable(valid_mock.deregister_op_overrides)) + self.assertTrue(callable(valid_mock.register_op_override)) + + # Test behavior + self.assertTrue(valid_mock.runtime_available()) + self.assertEqual(valid_mock.runtime_version(), Version("2.1.0")) + + # Test create_broken_mock_dsl + broken_mock = self.create_broken_mock_dsl("runtime_available", ImportError) + + with self.assertRaises(ImportError): + broken_mock.runtime_available() + + # Other methods should still work + self.assertIsInstance(broken_mock.runtime_version(), Version) + + +if __name__ == "__main__": + run_tests() diff --git a/test/python_native/test_native_dsl_ops.py b/test/python_native/test_native_dsl_ops.py index 9fe59bcb4726d..e01a2b608cb33 100644 --- a/test/python_native/test_native_dsl_ops.py +++ b/test/python_native/test_native_dsl_ops.py @@ -1,11 +1,19 @@ # Owner(s): ["module: dsl-native-ops"] +import importlib.util import os import subprocess import sys import textwrap +import uuid +from unittest.mock import patch -from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TestCase, +) def _subprocess_lastline(script, env=None): @@ -13,47 +21,129 @@ def _subprocess_lastline(script, env=None): result = subprocess.check_output( [sys.executable, "-c", script], cwd=os.path.dirname(os.path.realpath(__file__)), - env=env, - stderr=subprocess.DEVNULL, text=True, ).strip() return result.rsplit("\n", 1)[-1] +def _import_module_directly(module_name, file_name): + """Import a module directly without triggering package imports.""" + test_dir = os.path.dirname(os.path.abspath(__file__)) + pytorch_root = os.path.dirname(os.path.dirname(test_dir)) + module_path = os.path.join(pytorch_root, "torch", "_native", file_name) + + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + class TestNativeDSLOps(TestCase): """Tests for the torch._native DSL ops framework.""" + def setUp(self): + """Clear all caches before each test to ensure test isolation.""" + self._cache_functions_to_clear = [ + ( + "torch._native.common_utils", + ["check_native_jit_disabled", "check_native_version_skip"], + ), + ( + "torch._native.triton_utils", + [ + "_version_is_sufficient", + "check_native_jit_disabled", + "check_native_version_skip", + ], + ), + ( + "torch._native.cutedsl_utils", + [ + "_version_is_ok", + "check_native_jit_disabled", + "check_native_version_skip", + ], + ), + ] + self._clear_function_caches() + + def _clear_function_caches(self): + """Helper method to clear function caches with error handling.""" + for module_name, function_names in self._cache_functions_to_clear: + try: + module = __import__(module_name, fromlist=function_names) + for func_name in function_names: + if hasattr(module, func_name): + getattr(module, func_name).cache_clear() + except (AttributeError, ImportError): + # Some functions might not exist or be cached, ignore errors + pass + def test_consistent_helper_interface(self): - """triton_utils and cutedsl_utils expose the same public API.""" - from torch._native import cutedsl_utils, triton_utils + """Test all registered DSL utils expose consistent public APIs.""" + from torch.testing._internal.common_utils import get_all_dsls + + # Automatically discover all registered DSLs + dsl_names = get_all_dsls() + if not dsl_names: + # Fallback to hardcoded list if registry not available + dsl_names = ["triton", "cutedsl"] + + modules_info = [ + (f"{dsl}_utils.py", f"torch._native.{dsl}_utils") for dsl in dsl_names + ] - REQUIRED_METHODS = { + # Import modules directly to avoid dependency issues + modules = {} + for file_name, module_name in modules_info: + modules[module_name] = _import_module_directly(module_name, file_name) + + required_methods = { "runtime_available", "runtime_version", "register_op_override", + "deregister_op_overrides", } - for mod in (cutedsl_utils, triton_utils): - public = {name for name in dir(mod) if not name.startswith("_")} - self.assertTrue( - REQUIRED_METHODS <= public, - f"{mod.__name__} missing: {REQUIRED_METHODS - public}", - ) - for name in REQUIRED_METHODS: - self.assertTrue(callable(getattr(mod, name))) + # Test each module has required methods and they're callable + public_apis = {} + for module_name, mod in modules.items(): + with self.subTest(module=module_name, test="required_methods"): + public = {name for name in dir(mod) if not name.startswith("_")} + public_apis[module_name] = public - triton_public = {n for n in dir(triton_utils) if not n.startswith("_")} - cute_public = {n for n in dir(cutedsl_utils) if not n.startswith("_")} + self.assertTrue( + required_methods <= public, + f"{module_name} missing: {required_methods - public}", + ) - self.assertEqual(triton_public, cute_public) + for method_name in required_methods: + with self.subTest(module=module_name, method=method_name): + self.assertTrue(callable(getattr(mod, method_name))) - for mod in (cutedsl_utils, triton_utils): - self.assertIsInstance(mod.runtime_available(), bool) - ver = mod.runtime_version() - if ver is not None: - from packaging.version import Version + # Test modules expose identical public APIs + api_sets = list(public_apis.values()) + if len(api_sets) > 1: + for i, api_set in enumerate(api_sets[1:], 1): + self.assertEqual( + api_sets[0], + api_set, + f"Module {i} should have identical public API to module 0", + ) - self.assertIsInstance(ver, Version) + # Test runtime functions return expected types + for module_name, mod in modules.items(): + with self.subTest(module=module_name, test="runtime_functions"): + # runtime_available should return bool + self.assertIsInstance(mod.runtime_available(), bool) + + # runtime_version should return Version or None + ver = mod.runtime_version() + if ver is not None: + from torch._vendor.packaging.version import Version + + self.assertIsInstance(ver, Version) def test_no_dsl_imports_after_import_torch(self): """import torch must not transitively import DSL runtimes. @@ -72,153 +162,404 @@ def test_no_dsl_imports_after_import_torch(self): result = _subprocess_lastline(script) self.assertEqual(result, "[]", f"DSL modules leaked on import torch: {result}") - def test_check_native_jit_disabled_default(self): - """TORCH_DISABLE_NATIVE_JIT unset -> check returns False.""" + def test_no_external_packaging_dependency(self): + """torch._native must not import the external `packaging` package. + + It should use the vendored copy at torch._vendor.packaging instead. + This guards against ModuleNotFoundError in environments where the + external `packaging` is not installed (e.g. torchvision Windows CI). + """ script = textwrap.dedent("""\ - import os - os.environ.pop("TORCH_DISABLE_NATIVE_JIT", None) - from torch._native.common_utils import check_native_jit_disabled - print(check_native_jit_disabled()) + import sys + # Remove external packaging from sys.modules if already loaded + for mod_name in list(sys.modules): + if mod_name == "packaging" or mod_name.startswith("packaging."): + del sys.modules[mod_name] + # Block external packaging from being imported + import importlib.abc + import importlib.machinery + class BlockPackaging(importlib.abc.MetaPathFinder): + def find_module(self, fullname, path=None): + if fullname == "packaging" or fullname.startswith("packaging."): + return self + def load_module(self, fullname): + raise ImportError(f"External {fullname} is blocked") + sys.meta_path.insert(0, BlockPackaging()) + import torch + print("OK") """) result = _subprocess_lastline(script) - self.assertEqual(result, "False") + self.assertEqual(result, "OK") - def test_check_native_jit_disabled_set(self): - """TORCH_DISABLE_NATIVE_JIT=1 -> check returns True.""" - script = textwrap.dedent("""\ - from torch._native.common_utils import check_native_jit_disabled - print(check_native_jit_disabled()) - """) - env = os.environ.copy() - env["TORCH_DISABLE_NATIVE_JIT"] = "1" - result = _subprocess_lastline(script, env=env) - self.assertEqual(result, "True") + @parametrize("env_value, expected", [(None, False), ("1", True)]) + def test_check_native_jit_disabled_environment_variable(self, env_value, expected): + """Test TORCH_DISABLE_NATIVE_JIT environment variable behavior.""" + from torch._native.common_utils import check_native_jit_disabled + + if env_value is None: + os.environ.pop("TORCH_DISABLE_NATIVE_JIT", None) + else: + os.environ["TORCH_DISABLE_NATIVE_JIT"] = env_value + + try: + # Clear cache so function re-reads environment variable + check_native_jit_disabled.cache_clear() + self.assertEqual(check_native_jit_disabled(), expected) + finally: + # Clean up environment variable + os.environ.pop("TORCH_DISABLE_NATIVE_JIT", None) def test_unavailable_reason_missing(self): """Nonexistent package -> _unavailable_reason returns a string.""" - from torch._native.common_utils import _unavailable_reason - - reason = _unavailable_reason([("nonexistent_pkg_xyz", "nonexistent_pkg_xyz")]) + common_utils = _import_module_directly( + "torch._native.common_utils", "common_utils.py" + ) + reason = common_utils._unavailable_reason( + [("nonexistent_pkg_xyz", "nonexistent_pkg_xyz")] + ) self.assertIsNotNone(reason) self.assertIn("nonexistent_pkg_xyz", reason) - def test_available_version(self): - """_available_version returns a packaging.version.Version""" - from packaging.version import Version + def test_available_version_parsing(self): + """Test _available_version parses various version formats and handles invalid ones.""" + from torch._vendor.packaging.version import Version - from torch._native.common_utils import _available_version + common_utils = _import_module_directly( + "torch._native.common_utils", "common_utils.py" + ) - # Use typing_extensions which always has a clean major.minor.patch version, - # unlike torch which may have pre-release suffixes in dev builds. - ver = _available_version("typing_extensions") + # Test with real package that has clean version + ver = common_utils._available_version("typing_extensions") self.assertIsInstance(ver, Version) + # Test various version format scenarios + version_scenarios = [ + ("0.7.0rc1", Version("0.7.0rc1"), "pre-release version"), + ("3.1.0.post1", Version("3.1.0.post1"), "post-release version"), + ("2.4.0a1", Version("2.4.0a1"), "alpha version"), + ("1.2.3", Version("1.2.3"), "standard version"), + ("abc", None, "invalid version string"), + ] + + for version_str, expected_result, description in version_scenarios: + with self.subTest(version=version_str, scenario=description): + with patch("importlib.metadata.version", return_value=version_str): + result = common_utils._available_version("fake_package") + self.assertEqual( + result, + expected_result, + f"_available_version({version_str!r}) = {result}", + ) + def test_registry_mechanics(self): - """_get_library caches Library instances per (lib, dispatch_key).""" + """_get_or_create_library caches Library instances per (lib, dispatch_key).""" + import torch._native.registry as registry import torch.library - from torch._native.registry import _get_library, libs - key = ("_test_native_dsl_registry", "CPU") - libs.pop(key, None) + # Save original state for restoration + original_libs = dict(registry._libs) + original_filter_state = ( + set(registry._filter_state._dsl_names), + set(registry._filter_state._op_symbols), + set(registry._filter_state._dispatch_keys), + ) + + try: + key = ("_test_native_dsl_registry", "CPU") + registry._libs.pop(key, None) + + lib1 = registry._get_or_create_library(*key) + self.assertIsInstance(lib1, torch.library.Library) + lib2 = registry._get_or_create_library(*key) + self.assertIs(lib1, lib2, "should return cached instance") + + # Different dispatch key -> different Library + key2 = ("_test_native_dsl_registry", "CUDA") + registry._libs.pop(key2, None) + lib3 = registry._get_or_create_library(*key2) + self.assertIsNot(lib1, lib3) + + # cleanup + registry._libs.pop(key, None) + registry._libs.pop(key2, None) + finally: + # Restore original registry state + registry._libs.clear() + registry._libs.update(original_libs) + + # Restore filter state + filter_state = registry._filter_state + filter_state._dsl_names.clear() + filter_state._op_symbols.clear() + filter_state._dispatch_keys.clear() + filter_state._dsl_names.update(original_filter_state[0]) + filter_state._op_symbols.update(original_filter_state[1]) + filter_state._dispatch_keys.update(original_filter_state[2]) + + def test_deregister_op_overrides_functionality(self): + """Test deregister_op_overrides methods exist, are callable, and work correctly.""" + modules_to_test = [ + ("triton_utils.py", "torch._native.triton_utils"), + ("cutedsl_utils.py", "torch._native.cutedsl_utils"), + ] + + # Use the preserve_filter_state context manager pattern + from torch._native.registry import _filter_state + + original_filter_state = ( + set(_filter_state._dsl_names), + set(_filter_state._op_symbols), + set(_filter_state._dispatch_keys), + ) + + try: + for file_name, module_name in modules_to_test: + with self.subTest(module=module_name): + mod = _import_module_directly(module_name, file_name) + + # Test method exists and is callable + self.assertTrue(hasattr(mod, "deregister_op_overrides")) + self.assertTrue(callable(mod.deregister_op_overrides)) + + # Test method can be called without error (should be no-op when no overrides registered) + try: + mod.deregister_op_overrides() + except Exception as e: + self.fail( + f"deregister_op_overrides on {module_name} raised exception: {e}" + ) + finally: + # Restore original filter state + _filter_state._dsl_names.clear() + _filter_state._op_symbols.clear() + _filter_state._dispatch_keys.clear() + _filter_state._dsl_names.update(original_filter_state[0]) + _filter_state._op_symbols.update(original_filter_state[1]) + _filter_state._dispatch_keys.update(original_filter_state[2]) - lib1 = _get_library(*key) - self.assertIsInstance(lib1, torch.library.Library) - lib2 = _get_library(*key) - self.assertIs(lib1, lib2, "should return cached instance") + def test_register_op_skips_when_jit_disabled(self): + """register_op_override does not call through when TORCH_DISABLE_NATIVE_JIT=1.""" + from torch._native import cutedsl_utils, triton_utils - # Different dispatch key -> different Library - key2 = ("_test_native_dsl_registry", "CUDA") - libs.pop(key2, None) - lib3 = _get_library(*key2) - self.assertIsNot(lib1, lib3) + # Test the actual environment variable behavior to ensure it works + # Set TORCH_DISABLE_NATIVE_JIT=1 and clear caches + with patch.dict(os.environ, {"TORCH_DISABLE_NATIVE_JIT": "1"}): + # Import and clear caches for both modules + from torch._native.common_utils import check_native_jit_disabled - # cleanup - libs.pop(key, None) - libs.pop(key2, None) + check_native_jit_disabled.cache_clear() - def test_register_op_skips_when_jit_disabled(self): - """register_op_override does not call through when TORCH_DISABLE_NATIVE_JIT=1.""" - script = textwrap.dedent("""\ - from unittest.mock import patch - from torch._native import triton_utils, cutedsl_utils + # Import functions from each module and clear their caches too + triton_utils.check_native_jit_disabled.cache_clear() + cutedsl_utils.check_native_jit_disabled.cache_clear() - with patch('torch._native.registry._register_op_override') as mock_reg: - triton_utils.register_op_override("aten", "add.Tensor", "CUDA", lambda: None) - cutedsl_utils.register_op_override("aten", "add.Tensor", "CUDA", lambda: None) - print(mock_reg.call_count == 0) - """) - env = os.environ.copy() - env["TORCH_DISABLE_NATIVE_JIT"] = "1" - result = _subprocess_lastline(script, env=env) - self.assertEqual(result, "True") + # Verify the function returns True + self.assertTrue(check_native_jit_disabled()) + + # Mock the registry calls to count how many times they would be called + with patch("torch._native.registry.register_op_override") as registry_mock: + # Use a unique operation name + unique_op = f"test_jit_disabled_{uuid.uuid4().hex[:8]}.Tensor" + triton_utils.register_op_override( + "aten", unique_op, "CPU", lambda: None + ) + cutedsl_utils.register_op_override( + "aten", unique_op, "CPU", lambda: None + ) + # Should not call the registry function at all since JIT is disabled + self.assertEqual(registry_mock.call_count, 0) def test_version_skip_env_var_overrides(self): """TORCH_NATIVE_SKIP_VERSION_CHECK=1 allows non-blessed versions.""" - script = textwrap.dedent("""\ - from unittest.mock import patch, MagicMock - from packaging.version import Version - from torch._native import triton_utils, cutedsl_utils - - fake_version = Version("99.99.99") - - with patch.object(triton_utils, '_check_runtime_available', return_value=(True, fake_version)), \\ - patch.object(cutedsl_utils, '_check_runtime_available', return_value=(True, fake_version)), \\ - patch.object(triton_utils, '_register_op_override') as triton_mock, \\ - patch.object(cutedsl_utils, '_register_op_override') as cute_mock: - triton_utils.register_op_override("aten", "add.Tensor", "CUDA", lambda: None) - cutedsl_utils.register_op_override("aten", "add.Tensor", "CUDA", lambda: None) - print(triton_mock.call_count + cute_mock.call_count) - """) - env = os.environ.copy() - env["TORCH_NATIVE_SKIP_VERSION_CHECK"] = "1" - result = _subprocess_lastline(script, env=env) - self.assertEqual(result, "2") + from torch._vendor.packaging.version import Version - def test_check_native_version_skip_default(self): - """TORCH_NATIVE_SKIP_VERSION_CHECK unset -> returns False.""" - script = textwrap.dedent("""\ - import os - os.environ.pop("TORCH_NATIVE_SKIP_VERSION_CHECK", None) - from torch._native.common_utils import check_native_version_skip - print(check_native_version_skip()) - """) - result = _subprocess_lastline(script) - self.assertEqual(result, "False") + fake_version = Version("99.99.99") - def test_check_native_version_skip_set(self): - """TORCH_NATIVE_SKIP_VERSION_CHECK=1 -> returns True.""" - script = textwrap.dedent("""\ + # Set the environment variable and clear caches + with patch.dict(os.environ, {"TORCH_NATIVE_SKIP_VERSION_CHECK": "1"}): + # Import fresh modules to avoid cached state + from torch._native import cutedsl_utils, triton_utils from torch._native.common_utils import check_native_version_skip - print(check_native_version_skip()) - """) - env = os.environ.copy() - env["TORCH_NATIVE_SKIP_VERSION_CHECK"] = "1" - result = _subprocess_lastline(script, env=env) - self.assertEqual(result, "True") - def test_available_version_prerelease(self): - """_available_version parses valid versions and rejects unparsable ones.""" - from unittest.mock import patch - - from packaging.version import Version - - from torch._native.common_utils import _available_version - - valid_versions = ["0.7.0rc1", "3.1.0.post1", "2.4.0a1", "1.2.3"] - for version_str in valid_versions: - with patch("importlib.metadata.version", return_value=version_str): - result = _available_version("fake_package") + # Clear all relevant caches to ensure clean state + check_native_version_skip.cache_clear() + + # Clear module-specific caches for the imported modules + for module in [triton_utils, cutedsl_utils]: + for attr_name in dir(module): + attr = getattr(module, attr_name) + if hasattr(attr, "cache_clear"): + attr.cache_clear() + + with ( + patch.object( + triton_utils, + "_check_runtime_available", + return_value=(True, fake_version), + ), + patch.object( + cutedsl_utils, + "_check_runtime_available", + return_value=(True, fake_version), + ), + patch.object(triton_utils, "_register_op_override_impl") as triton_mock, + patch.object(cutedsl_utils, "_register_op_override_impl") as cute_mock, + ): + # Use unique operation names to avoid conflicts + op_name = f"test_version_skip_{uuid.uuid4().hex[:8]}.Tensor" + + # Call the register functions + triton_utils.register_op_override("aten", op_name, "CPU", lambda: None) + cutedsl_utils.register_op_override("aten", op_name, "CPU", lambda: None) + + # Verify both implementation functions were called self.assertEqual( - result, - Version(version_str), - f"_available_version({version_str!r}) = {result}", + triton_mock.call_count + cute_mock.call_count, + 2, + f"Expected 2 calls but got triton: {triton_mock.call_count}, cutedsl: {cute_mock.call_count}", ) - # Completely unparsable -> None - with patch("importlib.metadata.version", return_value="abc"): - result = _available_version("fake_package") - self.assertIsNone(result) + @parametrize("env_value, expected", [(None, False), ("1", True)]) + def test_check_native_version_skip_environment_variable(self, env_value, expected): + """Test TORCH_NATIVE_SKIP_VERSION_CHECK environment variable behavior.""" + from torch._native.common_utils import check_native_version_skip + + if env_value is None: + os.environ.pop("TORCH_NATIVE_SKIP_VERSION_CHECK", None) + else: + os.environ["TORCH_NATIVE_SKIP_VERSION_CHECK"] = env_value + + try: + # Clear cache so function re-reads environment variable + check_native_version_skip.cache_clear() + self.assertEqual(check_native_version_skip(), expected) + finally: + # Clean up environment variable + os.environ.pop("TORCH_NATIVE_SKIP_VERSION_CHECK", None) + def test_dsl_registry_functionality(self): + """Test that DSL registry works correctly""" + from torch.testing._internal.common_utils import ( + get_all_dsls, + get_available_dsls, + is_dsl_available, + ) + + # Test registry returns expected DSLs + all_dsls = get_all_dsls() + self.assertIsInstance(all_dsls, list) + self.assertIn("triton", all_dsls) + self.assertIn("cutedsl", all_dsls) + + # Test available DSLs are subset of all DSLs + available_dsls = get_available_dsls() + self.assertIsInstance(available_dsls, list) + for dsl in available_dsls: + self.assertIn(dsl, all_dsls) + + # Test availability check function + for dsl in all_dsls: + availability = is_dsl_available(dsl) + self.assertIsInstance(availability, bool) + # If DSL is in available list, it should return True + if dsl in available_dsls: + self.assertTrue(availability) + + def test_dsl_test_helpers(self): + """Test that DSL test helper decorators work""" + from torch.testing._internal.common_utils import ( + skipIfDSLUnavailable, + skipIfNoCuteDSL, + skipIfNoTritonDSL, + skipUnlessDSLAvailable, + ) + + # Test that decorators are callable + self.assertTrue(callable(skipIfNoTritonDSL)) + self.assertTrue(callable(skipIfNoCuteDSL)) + self.assertTrue(callable(skipIfDSLUnavailable)) + self.assertTrue(callable(skipUnlessDSLAvailable)) + + # Test dynamic decorators can be called + try: + decorator1 = skipIfDSLUnavailable("nonexistent_dsl") + decorator2 = skipUnlessDSLAvailable("triton") + self.assertTrue(callable(decorator1)) + self.assertTrue(callable(decorator2)) + except Exception as e: + self.fail(f"Dynamic DSL decorators failed: {e}") + + def test_cache_invalidation_after_re_registration(self): + """Test that caches are properly invalidated when DSLs are re-registered""" + from unittest.mock import Mock + + from torch._native.dsl_registry import DSLRegistry + + # Create a fresh registry for this test + registry = DSLRegistry() + + # Create mock DSL modules + mock_dsl_1 = Mock() + mock_dsl_1.runtime_available.return_value = False # Initially unavailable + mock_dsl_1.runtime_version.return_value = None + + mock_dsl_2 = Mock() + mock_dsl_2.runtime_available.return_value = True # Available + mock_dsl_2.runtime_version.return_value = None + + # Register first DSL and cache results + registry.register_dsl("test_cache_dsl", mock_dsl_1) + initial_available = registry.is_dsl_available("test_cache_dsl") + initial_list = registry.list_available_dsls() + + self.assertFalse(initial_available) + self.assertNotIn("test_cache_dsl", initial_list) + + # Re-register with different module that is available + registry.register_dsl("test_cache_dsl", mock_dsl_2) + + # Verify cache was invalidated and new results are returned + new_available = registry.is_dsl_available("test_cache_dsl") + new_list = registry.list_available_dsls() + + self.assertTrue( + new_available, "Cache should be invalidated and return new result" + ) + self.assertIn( + "test_cache_dsl", + new_list, + "Available DSLs list should reflect new registration", + ) + + def test_incomplete_protocol_implementation(self): + """Test that registration fails when module doesn't implement required protocol methods""" + from torch._native.dsl_registry import DSLRegistry + + # Create a fresh registry for this test + registry = DSLRegistry() + + # Create an object missing required protocol methods (not using Mock) + class IncompleteModule: + def runtime_available(self): + return True + + # Missing: runtime_version, register_op_override, deregister_op_overrides + + incomplete_module = IncompleteModule() + + # Attempt to register should raise TypeError due to missing methods + with self.assertRaises(TypeError) as cm: + registry.register_dsl("incomplete_dsl", incomplete_module) + + self.assertIn("missing required methods", str(cm.exception)) + self.assertIn("runtime_version", str(cm.exception)) + self.assertIn("register_op_override", str(cm.exception)) + + # Verify DSL was not registered + self.assertNotIn("incomplete_dsl", registry.list_all_dsls()) + + +instantiate_parametrized_tests(TestNativeDSLOps) if __name__ == "__main__": run_tests() diff --git a/test/python_native/test_registry.py b/test/python_native/test_registry.py new file mode 100644 index 0000000000000..f58ea56ad3c9c --- /dev/null +++ b/test/python_native/test_registry.py @@ -0,0 +1,258 @@ +# Owner(s): ["module: dsl-native-ops"] + +from unittest.mock import MagicMock, patch + +import torch._native.registry as registry_module +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase + + +@skipIfTorchDynamo("Registry tests don't need dynamo compilation") +class TestRegistry(TestCase): + """Tests for the torch._native.registry module.""" + + def setUp(self): + """Clean up registry state before each test.""" + self.registry = registry_module + + # Store original state for restoration + self._original_libs = dict(self.registry._libs) + self._original_graphs = dict(self.registry._graphs) + self._original_dsl_name_to_lib_graph = { + k: list(v) for k, v in self.registry._dsl_name_to_lib_graph.items() + } + self._original_dispatch_key_to_lib_graph = { + k: list(v) for k, v in self.registry._dispatch_key_to_lib_graph.items() + } + self._original_op_symbol_to_lib_graph = { + k: list(v) for k, v in self.registry._op_symbol_to_lib_graph.items() + } + + # Store original filter state + self._original_filter_state = ( + set(self.registry._filter_state._dsl_names), + set(self.registry._filter_state._op_symbols), + set(self.registry._filter_state._dispatch_keys), + ) + + # Clear global state + self.registry._libs.clear() + self.registry._graphs.clear() + self.registry._dsl_name_to_lib_graph.clear() + self.registry._dispatch_key_to_lib_graph.clear() + self.registry._op_symbol_to_lib_graph.clear() + + # Clear filter state to ensure clean start + self.registry._filter_state._dsl_names.clear() + self.registry._filter_state._op_symbols.clear() + self.registry._filter_state._dispatch_keys.clear() + + def tearDown(self): + """Restore original registry state after each test.""" + if hasattr(self, "registry"): + # Restore original state + self.registry._libs.clear() + self.registry._libs.update(self._original_libs) + + self.registry._graphs.clear() + self.registry._graphs.update(self._original_graphs) + + # Properly restore mapping dictionaries with new list instances + self.registry._dsl_name_to_lib_graph.clear() + for k, v in self._original_dsl_name_to_lib_graph.items(): + self.registry._dsl_name_to_lib_graph[k] = list(v) + + self.registry._dispatch_key_to_lib_graph.clear() + for k, v in self._original_dispatch_key_to_lib_graph.items(): + self.registry._dispatch_key_to_lib_graph[k] = list(v) + + self.registry._op_symbol_to_lib_graph.clear() + for k, v in self._original_op_symbol_to_lib_graph.items(): + self.registry._op_symbol_to_lib_graph[k] = list(v) + + # Restore filter state + self.registry._filter_state._dsl_names.clear() + self.registry._filter_state._op_symbols.clear() + self.registry._filter_state._dispatch_keys.clear() + self.registry._filter_state._dsl_names.update( + self._original_filter_state[0] + ) + self.registry._filter_state._op_symbols.update( + self._original_filter_state[1] + ) + self.registry._filter_state._dispatch_keys.update( + self._original_filter_state[2] + ) + + # Keep essential existing tests + def test_override_node_dataclass(self): + """Test _OverrideNode dataclass creation and defaults.""" + + def test_fn(x): + return x + + node = self.registry._OverrideNode("test_dsl", "add.Tensor", "CPU", test_fn) + self.assertEqual(node.dsl_name, "test_dsl") + self.assertEqual(node.op_symbol, "add.Tensor") + self.assertEqual(node.dispatch_key, "CPU") + self.assertEqual(node.override_fn, test_fn) + self.assertFalse(node.unconditional_override) + self.assertTrue(node.active) + + @patch("torch.library.Library") + def test_register_op_override_basic(self, mock_library_cls): + """Test basic register_op_override functionality.""" + + def impl_fn(x): + return x + + mock_lib = MagicMock() + mock_library_cls.return_value = mock_lib + + self.registry.register_op_override( + "test_backend", "aten", "add.Tensor", "CPU", impl_fn + ) + + key = ("add.Tensor", "CPU") + self.assertEqual(len(self.registry._graphs[key]), 1) + node = self.registry._graphs[key][0] + self.assertEqual(node.dsl_name, "test_backend") + self.assertEqual(node.override_fn, impl_fn) + + @patch("torch.library.Library") + def test_deregister_op_overrides_basic(self, mock_library_cls): + """Test basic deregister_op_overrides functionality.""" + + def impl_fn(x): + return x + + mock_lib = MagicMock() + mock_library_cls.return_value = mock_lib + + # Register first + self.registry.register_op_override( + "test_backend", "aten", "mul.Tensor", "CPU", impl_fn + ) + + key = ("mul.Tensor", "CPU") + self.assertTrue(self.registry._graphs[key][0].active) + + # Then deregister + self.registry.deregister_op_overrides(disable_dsl_names="test_backend") + self.assertFalse(self.registry._graphs[key][0].active) + + # NEW FUNCTIONALITY TESTS - ONLY THE ESSENTIAL ONES + + def test_reorder_graphs_from_user_function_basic(self): + """Test basic graph reordering functionality.""" + # Set up test data + key = ("test_reorder.Tensor", "CPU") + + def impl_fn(x): + return x + + # Create nodes in specific order + nodes = [ + self.registry._OverrideNode("dsl_c", "test_reorder.Tensor", "CPU", impl_fn), + self.registry._OverrideNode("dsl_a", "test_reorder.Tensor", "CPU", impl_fn), + self.registry._OverrideNode("dsl_b", "test_reorder.Tensor", "CPU", impl_fn), + ] + self.registry._graphs[key] = nodes + + # Define alphabetical ordering function + def alphabetical_order(op_symbol, dispatch_key, graph): + return sorted(graph, key=lambda n: n.dsl_name) + + # Apply reordering + self.registry.reorder_graphs_from_user_function(alphabetical_order) + + # Verify alphabetical order + reordered_graph = self.registry._graphs[key] + actual_names = [node.dsl_name for node in reordered_graph] + self.assertEqual(actual_names, ["dsl_a", "dsl_b", "dsl_c"]) + + def test_reorder_graphs_from_user_function_error_handling(self): + """Test error handling in graph reordering.""" + # Set up test data + key = ("test_error.Tensor", "CPU") + + def impl_fn(x): + return x + + node = self.registry._OverrideNode( + "test_dsl", "test_error.Tensor", "CPU", impl_fn + ) + original_graph = [node] + self.registry._graphs[key] = original_graph.copy() + + # Define failing ordering function + def failing_order_fn(op_symbol, dispatch_key, graph): + raise ValueError("Test exception") + + # Should handle the exception gracefully + with self.assertLogs("torch._native.registry", level="WARNING") as log: + self.registry.reorder_graphs_from_user_function(failing_order_fn) + + # Verify warning was logged and original graph preserved + self.assertEqual(len(log.records), 1) + self.assertIn("Graph transformation failed", log.records[0].getMessage()) + self.assertEqual(self.registry._graphs[key], original_graph) + + def test_get_user_ordering_fn_env_var_not_set(self): + """Test behavior when environment variable is not set.""" + with patch.dict("os.environ", {}, clear=True): + from torch._native import get_user_ordering_fn + + get_user_ordering_fn.cache_clear() + result = get_user_ordering_fn() + self.assertIsNone(result) + + def test_get_user_ordering_fn_invalid_path(self): + """Test handling of invalid environment variable paths.""" + with patch.dict( + "os.environ", + {"TORCH_PYTHON_NATIVE_USER_GRAPH_ORDER_FN": "nonexistent.module.function"}, + ): + from torch._native import get_user_ordering_fn + + get_user_ordering_fn.cache_clear() + + with self.assertRaises(ValueError) as cm: + get_user_ordering_fn() + self.assertIn("Could not resolve", str(cm.exception)) + + def test_integration_reorder_and_register(self): + """Integration test: reorder then register functionality.""" + + def impl_fn1(x): + return x + 1 + + def impl_fn2(x): + return x + 2 + + # Register multiple overrides + self.registry.register_op_override( + "backend_z", "aten", "test.Tensor", "CPU", impl_fn1 + ) + self.registry.register_op_override( + "backend_a", "aten", "test.Tensor", "CPU", impl_fn2 + ) + + key = ("test.Tensor", "CPU") + + # Verify initial order + initial_names = [node.dsl_name for node in self.registry._graphs[key]] + self.assertEqual(initial_names, ["backend_z", "backend_a"]) + + # Reorder alphabetically + def alphabetical_order(op_symbol, dispatch_key, graph): + return sorted(graph, key=lambda n: n.dsl_name) + + self.registry.reorder_graphs_from_user_function(alphabetical_order) + + # Verify reordered + final_names = [node.dsl_name for node in self.registry._graphs[key]] + self.assertEqual(final_names, ["backend_a", "backend_z"]) + + +if __name__ == "__main__": + run_tests() diff --git a/test/python_native/test_torch_backends.py b/test/python_native/test_torch_backends.py new file mode 100644 index 0000000000000..e3e8f1e46a246 --- /dev/null +++ b/test/python_native/test_torch_backends.py @@ -0,0 +1,544 @@ +# Owner(s): ["module: dsl-native-ops"] + +import torch.backends.python_native as pn +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase + + +class RegistryTestMixin: + """Mixin for tests that need to preserve registry state.""" + + def setUp(self): + """Set up test state with registry preservation.""" + super().setUp() if hasattr(super(), "setUp") else None + + # Use the new _preserve_filter_state but setup manually for tearDown + filter_state = pn._get_filter_state() + self._original_filter_state = ( + set(filter_state._dsl_names), + set(filter_state._op_symbols), + set(filter_state._dispatch_keys), + ) + + # Clear filter state for clean test start + filter_state._dsl_names.clear() + filter_state._op_symbols.clear() + filter_state._dispatch_keys.clear() + + # Ensure all DSLs start enabled + try: + for dsl_name in pn.all_dsls: + dsl = getattr(pn, dsl_name) + dsl.enable() + except Exception: + pass + + def tearDown(self): + """Restore original filter state.""" + try: + if hasattr(self, "_original_filter_state"): + filter_state = pn._get_filter_state() + + # Restore original filter state + filter_state._dsl_names.clear() + filter_state._op_symbols.clear() + filter_state._dispatch_keys.clear() + + filter_state._dsl_names.update(self._original_filter_state[0]) + filter_state._op_symbols.update(self._original_filter_state[1]) + filter_state._dispatch_keys.update(self._original_filter_state[2]) + except Exception: + pass + super().tearDown() if hasattr(super(), "tearDown") else None + + +@skipIfTorchDynamo("Backend tests don't need dynamo compilation") +class TestTorchBackendsPythonNative(RegistryTestMixin, TestCase): + """Tests for torch.backends.python_native user-facing API.""" + + def test_module_import(self): + """Test that torch.backends.python_native imports successfully.""" + # Should not raise any exceptions + import torch.backends.python_native as pn_import + + self.assertIsNotNone(pn_import) + + def test_dsl_discovery(self): + """Test DSL discovery functionality.""" + all_dsls = pn.all_dsls + available_dsls = pn.available_dsls + + # Should return lists + self.assertIsInstance(all_dsls, list) + self.assertIsInstance(available_dsls, list) + + # Available DSLs should be subset of all DSLs + self.assertTrue(set(available_dsls).issubset(set(all_dsls))) + + def test_dsl_access(self): + """Test accessing DSL controllers.""" + all_dsls = pn.all_dsls + + for dsl_name in all_dsls: + with self.subTest(dsl_name=dsl_name): + # Should be able to access DSL controller + dsl = getattr(pn, dsl_name) + self.assertIsNotNone(dsl) + + # Should have required attributes + self.assertEqual(dsl.name, dsl_name) + self.assertIsInstance(dsl.available, bool) + self.assertIsInstance(dsl.enabled, bool) + + def test_dsl_properties(self): + """Test DSL controller properties.""" + all_dsls = pn.all_dsls + + for dsl_name in all_dsls: + with self.subTest(dsl_name=dsl_name): + dsl = getattr(pn, dsl_name) + + # Test name property + self.assertEqual(dsl.name, dsl_name) + + # Test available property (should not raise) + available = dsl.available + self.assertIsInstance(available, bool) + + # Test version property (should not raise) + version = dsl.version + # Version can be None if DSL not available + self.assertTrue(version is None or hasattr(version, "major")) + + # Test enabled property + enabled = dsl.enabled + self.assertIsInstance(enabled, bool) + + def test_dsl_enable_disable(self): + """Test DSL enable/disable functionality.""" + all_dsls = pn.all_dsls + + for dsl_name in all_dsls: + with self.subTest(dsl_name=dsl_name): + dsl = getattr(pn, dsl_name) + original_state = dsl.enabled + + try: + # Test method-based disable/enable + dsl.disable() + self.assertEqual(dsl.enabled, False) + + dsl.enable() + self.assertEqual(dsl.enabled, True) + + # Test property-based disable/enable + dsl.disable() # Use method to avoid flags_frozen issue + self.assertEqual(dsl.enabled, False) + + dsl.enable() # Use method to avoid flags_frozen issue + self.assertEqual(dsl.enabled, True) + + finally: + # Restore original state + if original_state: + dsl.enable() + else: + dsl.disable() + + def test_dsl_context_managers(self): + """Test DSL context manager functionality.""" + all_dsls = pn.all_dsls + + for dsl_name in all_dsls: + with self.subTest(dsl_name=dsl_name): + dsl = getattr(pn, dsl_name) + original_state = dsl.enabled + + try: + # Ensure DSL starts enabled + dsl.enable() + + # Test disabled context manager + with dsl.disabled(): + self.assertEqual(dsl.enabled, False) + + # Should be restored after context + self.assertEqual(dsl.enabled, True) + + finally: + # Restore original state + if original_state: + dsl.enable() + else: + dsl.disable() + + def test_nested_context_managers(self): + """Test nested DSL context managers.""" + all_dsls = pn.all_dsls + + if len(all_dsls) >= 2: + dsl1_name, dsl2_name = all_dsls[0], all_dsls[1] + dsl1 = getattr(pn, dsl1_name) + dsl2 = getattr(pn, dsl2_name) + + original_state1 = dsl1.enabled + original_state2 = dsl2.enabled + + try: + # Ensure both start enabled + dsl1.enable() + dsl2.enable() + + with dsl1.disabled(): + self.assertEqual(dsl1.enabled, False) + self.assertEqual(dsl2.enabled, True) + + with dsl2.disabled(): + self.assertEqual(dsl1.enabled, False) + self.assertEqual(dsl2.enabled, False) + + # dsl2 should be restored + self.assertEqual(dsl1.enabled, False) + self.assertEqual(dsl2.enabled, True) + + # Both should be restored + self.assertEqual(dsl1.enabled, True) + self.assertEqual(dsl2.enabled, True) + + finally: + # Restore original states + if original_state1: + dsl1.enable() + else: + dsl1.disable() + if original_state2: + dsl2.enable() + else: + dsl2.disable() + + def test_operation_discovery(self): + """Test operation discovery functionality.""" + all_dsls = pn.all_dsls + + for dsl_name in all_dsls: + with self.subTest(dsl_name=dsl_name): + operations = pn.get_dsl_operations(dsl_name) + + # Should return a list + self.assertIsInstance(operations, list) + + # Each operation should be a string + for op in operations: + self.assertIsInstance(op, str) + self.assertTrue(len(op) > 0) + + def test_operation_control(self): + """Test operation-level control functionality.""" + # Get an operation to test with + all_dsls = pn.all_dsls + test_operation = None + + for dsl_name in all_dsls: + operations = pn.get_dsl_operations(dsl_name) + if operations: + test_operation = operations[0] + break + + if test_operation: + # Test operation disable/enable (should not raise) + pn.disable_operations(test_operation) + pn.enable_operations(test_operation) + + # Test multiple operations + pn.disable_operations(test_operation, "nonexistent_op") + pn.enable_operations(test_operation, "nonexistent_op") + + def test_operation_context_manager(self): + """Test operation-level context manager.""" + # Get an operation to test with + all_dsls = pn.all_dsls + test_operation = None + + for dsl_name in all_dsls: + operations = pn.get_dsl_operations(dsl_name) + if operations: + test_operation = operations[0] + break + + if test_operation: + # Should not raise exceptions + with pn.operations_disabled(test_operation): + pass # Operation should be disabled in this context + # Operation should be re-enabled after context + + def test_dispatch_key_control(self): + """Test dispatch key control functionality.""" + # Test basic dispatch key control (should not raise) + pn.disable_dispatch_keys("CUDA", "CPU") + pn.enable_dispatch_keys("CUDA", "CPU") + + def test_invalid_dsl_access(self): + """Test accessing invalid DSL names.""" + with self.assertRaises(AttributeError): + _ = pn.nonexistent_dsl + + def test_dsl_repr(self): + """Test DSL controller string representation.""" + all_dsls = pn.all_dsls + + for dsl_name in all_dsls: + with self.subTest(dsl_name=dsl_name): + dsl = getattr(pn, dsl_name) + repr_str = repr(dsl) + + # Should contain DSL name + self.assertIn(dsl_name, repr_str) + # Should contain status info + self.assertTrue( + any(status in repr_str for status in ["available", "unavailable"]) + ) + self.assertTrue( + any(status in repr_str for status in ["enabled", "disabled"]) + ) + + def test_module_dir(self): + """Test module __dir__ functionality.""" + attrs = dir(pn) + + # Should contain core attributes + expected_attrs = [ + "available_dsls", + "all_dsls", + "get_dsl_operations", + "disable_operations", + "enable_operations", + "disable_dispatch_keys", + "enable_dispatch_keys", + "operations_disabled", + ] + + for attr in expected_attrs: + self.assertIn(attr, attrs) + + # Should contain DSL names + for dsl_name in pn.all_dsls: + self.assertIn(dsl_name, attrs) + + def test_caching_functionality(self): + """Test caching integration and API contracts (not functools.lru_cache internals).""" + all_dsls = pn.all_dsls + + if all_dsls: + dsl_name = all_dsls[0] + + # Test operation caching - results should be identical + ops1 = pn.get_dsl_operations(dsl_name) + ops2 = pn.get_dsl_operations(dsl_name) + self.assertEqual(ops1, ops2) + + # Test DSL controller caching - same object returned (API contract) + controller1 = getattr(pn, dsl_name) + controller2 = getattr(pn, dsl_name) + self.assertIs(controller1, controller2) + + # Verify controller still works correctly + self.assertEqual(controller1.name, dsl_name) + self.assertIsInstance(controller1.enabled, bool) + + def test_error_handling(self): + """Test error handling with invalid inputs.""" + # Test with empty operation names - should not raise + pn.disable_operations("") + pn.enable_operations("") + + # Test with invalid dispatch keys - should not raise + pn.disable_dispatch_keys("") + pn.enable_dispatch_keys("") + + # Test invalid DSL access raises AttributeError + with self.assertRaises(AttributeError): + _ = pn.invalid_dsl_name + + # Test get_dsl_operations with invalid DSL - should return empty list or raise ValueError + invalid_ops = pn.get_dsl_operations("invalid_dsl") + self.assertIsInstance(invalid_ops, list) + self.assertEqual(len(invalid_ops), 0) + + def test_dispatch_behavior_verification(self): + """Test that DSL disable/enable actually affects operation dispatch.""" + all_dsls = pn.all_dsls + + for dsl_name in all_dsls: + with self.subTest(dsl_name=dsl_name): + dsl = getattr(pn, dsl_name) + + if not dsl.available: + continue # Skip unavailable DSLs + + operations = pn.get_dsl_operations(dsl_name) + if not operations: + continue # Skip DSLs with no operations + + original_state = dsl.enabled + + try: + # Enable DSL first + dsl.enable() + self.assertTrue( + dsl.enabled, f"{dsl_name} should be enabled after enable()" + ) + + # Disable DSL + dsl.disable() + self.assertFalse( + dsl.enabled, f"{dsl_name} should be disabled after disable()" + ) + + # Verify operations are actually deregistered by checking registry state + self.assertTrue( + pn.is_dsl_disabled(dsl_name), + f"{dsl_name} should be disabled in registry", + ) + + # Re-enable and verify registry state + dsl.enable() + self.assertTrue( + dsl.enabled, f"{dsl_name} should be enabled after re-enable()" + ) + self.assertFalse( + pn.is_dsl_disabled(dsl_name), + f"{dsl_name} should not be disabled in registry after re-enable", + ) + + finally: + # Restore original state + if original_state: + dsl.enable() + else: + dsl.disable() + + def test_context_manager_dispatch_behavior(self): + """Test that context managers actually affect dispatch, not just state tracking.""" + all_dsls = pn.all_dsls + + for dsl_name in all_dsls: + with self.subTest(dsl_name=dsl_name): + dsl = getattr(pn, dsl_name) + + if not dsl.available: + continue # Skip unavailable DSLs + + original_state = dsl.enabled + + try: + # Start with DSL enabled + dsl.enable() + self.assertFalse(pn.is_dsl_disabled(dsl_name)) + + # Use disabled context manager + with dsl.disabled(): + # Verify DSL is actually disabled in registry during context + self.assertTrue( + pn.is_dsl_disabled(dsl_name), + f"{dsl_name} should be disabled in registry during context", + ) + + # Verify DSL is re-enabled after context + self.assertFalse( + pn.is_dsl_disabled(dsl_name), + f"{dsl_name} should be re-enabled in registry after context", + ) + + finally: + # Restore original state + if original_state: + dsl.enable() + else: + dsl.disable() + + def test_operation_dispatch_behavior(self): + """Test that operation-level disable/enable affects registry state.""" + all_dsls = pn.all_dsls + test_operation = None + + # Find a real operation to test with + for dsl_name in all_dsls: + operations = pn.get_dsl_operations(dsl_name) + if operations: + test_operation = operations[0] + break + + if test_operation: + # Verify operation starts enabled + initial_disabled = pn.is_operation_disabled(test_operation) + + try: + # Disable operation and verify registry state + pn.disable_operations(test_operation) + self.assertTrue( + pn.is_operation_disabled(test_operation), + f"Operation {test_operation} should be disabled in registry", + ) + + # Re-enable operation and verify registry state + pn.enable_operations(test_operation) + # Note: enable_operations removes from disabled list + self.assertFalse( + pn.is_operation_disabled(test_operation), + f"Operation {test_operation} should not be disabled in registry after re-enable", + ) + + finally: + # Restore initial state + if initial_disabled: + pn.disable_operations(test_operation) + else: + pn.enable_operations(test_operation) + + +class TestTorchBackendsPythonNativeIntegration(TestCase): + """Integration tests for torch.backends.python_native with actual DSLs.""" + + def test_real_dsl_integration(self): + """Test integration with real DSL modules if available.""" + all_dsls = pn.all_dsls + + for dsl_name in all_dsls: + with self.subTest(dsl_name=dsl_name): + dsl = getattr(pn, dsl_name) + + if dsl.available: + # Test that we can actually disable/enable real DSLs + original_state = dsl.enabled + + try: + # This should call actual DSL deregister functions + dsl.disable() + + # This should call actual registry re-enable functions + dsl.enable() + + finally: + # Restore state + if original_state: + dsl.enable() + else: + dsl.disable() + + def test_operations_with_real_registry(self): + """Test operation discovery with real registry.""" + all_dsls = pn.all_dsls + + for dsl_name in all_dsls: + with self.subTest(dsl_name=dsl_name): + operations = pn.get_dsl_operations(dsl_name) + + # If DSL has operations, they should be valid + if operations: + # Operations should be strings + for op in operations: + self.assertIsInstance(op, str) + self.assertGreater(len(op), 0) + + +if __name__ == "__main__": + run_tests() diff --git a/test/quantization/core/experimental/test_floatx.py b/test/quantization/core/experimental/test_floatx.py index 75b542a78d095..67228279aa694 100644 --- a/test/quantization/core/experimental/test_floatx.py +++ b/test/quantization/core/experimental/test_floatx.py @@ -127,6 +127,8 @@ FLOAT8_DTYPES_WITH_INF = [torch.float8_e5m2] +FLOAT8_DTYPES_SATURATE_ON_OVERFLOW = [torch.float8_e4m3fn] + def _int_bits_to_float(x): y = struct.unpack("!f", struct.pack("!I", x))[0] @@ -178,9 +180,15 @@ def simulate_fp8_precision(input, variant): # Re-compose mantissa and exponent vals = (mantissa_val_rounded * 2.0 ** (-23 + exponent)).to(dtype) - # Replace overflows with inf/NaN as appropriate (no saturation) - have_inf = variant in FLOAT8_DTYPES_WITH_INF - vals[vals > torch.finfo(variant).max] = torch.inf if have_inf else torch.nan + # Replace overflows: inf for types that have it, saturate to max for types + # that use satfinite semantics, NaN otherwise + overflow = vals > torch.finfo(variant).max + if variant in FLOAT8_DTYPES_WITH_INF: + vals[overflow] = torch.inf + elif variant in FLOAT8_DTYPES_SATURATE_ON_OVERFLOW: + vals[overflow] = torch.finfo(variant).max + else: + vals[overflow] = torch.nan return vals * signs diff --git a/test/quantization/core/experimental/test_nonuniform_observer.py b/test/quantization/core/experimental/test_nonuniform_observer.py index 015645568343d..5e6205f295887 100644 --- a/test/quantization/core/experimental/test_nonuniform_observer.py +++ b/test/quantization/core/experimental/test_nonuniform_observer.py @@ -1,5 +1,4 @@ # Owner(s): ["oncall: quantization"] -# ruff: noqa: F841 from torch.ao.quantization.experimental.observer import APoTObserver import unittest diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index 910fc677fdafb..1c64e5738c90a 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -7848,6 +7848,137 @@ def test_qconv2d_sum_relu_float_output_pt2e(self): qconv_x2_dtype=qconv_x2_dtype, ) + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qconv2d_binary_no_output_aliasing(self): + # Verify that qconv2d_pointwise.binary mutates qaccum in-place + # but returns a tensor that does not alias it. + groups = 1 + in_ch, out_ch = 4, 4 + x = torch.randint(0, 8, (1, in_ch, 5, 5), dtype=torch.uint8) + x = x.to(memory_format=torch.channels_last) + w_raw = torch.randint(-4, 4, (out_ch, in_ch, 3, 3), dtype=torch.int8) + w_scale = torch.tensor([1.0], dtype=torch.float) + w_zp = torch.tensor([0], dtype=torch.int64) + packed_w = torch.ops.onednn.qconv_prepack( + w_raw, w_scale, 1.0, 0, [1, 1], [1, 1], [1, 1], groups, x.size(), + ) + dummy_out = torch.ops.onednn.qconv2d_pointwise( + x, 1.0, 0, packed_w, w_scale, w_zp, None, + [1, 1], [1, 1], [1, 1], groups, 1.0, 0, None, "none", [], None, + ) + for binary_attr, unary_attr in [("sum", "none"), ("sum", "relu")]: + qaccum = torch.randint(0, 8, dummy_out.size(), dtype=torch.uint8) + qaccum = qaccum.to(memory_format=torch.channels_last) + qaccum_data_before = qaccum.clone() + result = torch.ops.onednn.qconv2d_pointwise.binary( + x, 1.0, 0, packed_w, w_scale, w_zp, qaccum, None, + [1, 1], [1, 1], [1, 1], groups, 1.0, 0, None, + 1.0, 0, binary_attr, None, unary_attr, [], None, + ) + # qaccum was mutated + self.assertFalse(torch.equal(qaccum, qaccum_data_before)) + # result does not alias qaccum + self.assertNotEqual(result.data_ptr(), qaccum.data_ptr()) + + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qconv2d_binary_tensor_no_output_aliasing(self): + # Same as above but for the .binary_tensor overload (tensor scale/zp). + groups = 1 + in_ch, out_ch = 4, 4 + x = torch.randint(0, 8, (1, in_ch, 5, 5), dtype=torch.uint8) + x = x.to(memory_format=torch.channels_last) + w_raw = torch.randint(-4, 4, (out_ch, in_ch, 3, 3), dtype=torch.int8) + w_scale = torch.tensor([1.0], dtype=torch.float) + w_zp = torch.tensor([0], dtype=torch.int64) + packed_w = torch.ops.onednn.qconv_prepack( + w_raw, w_scale, 1.0, 0, [1, 1], [1, 1], [1, 1], groups, x.size(), + ) + dummy_out = torch.ops.onednn.qconv2d_pointwise( + x, 1.0, 0, packed_w, w_scale, w_zp, None, + [1, 1], [1, 1], [1, 1], groups, 1.0, 0, None, "none", [], None, + ) + qaccum = torch.randint(0, 8, dummy_out.size(), dtype=torch.uint8) + qaccum = qaccum.to(memory_format=torch.channels_last) + qaccum_data_before = qaccum.clone() + x_scale = torch.tensor([1.0]) + x_zp = torch.tensor([0]) + result = torch.ops.onednn.qconv2d_pointwise.binary_tensor( + x, x_scale, x_zp, packed_w, w_scale, w_zp, qaccum, None, + [1, 1], [1, 1], [1, 1], groups, 1.0, 0, None, + 1.0, 0, "sum", None, "none", [], None, + ) + self.assertFalse(torch.equal(qaccum, qaccum_data_before)) + self.assertNotEqual(result.data_ptr(), qaccum.data_ptr()) + + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qconv2d_binary_empty_output_no_aliasing(self): + # Trigger the early return path when output has 0 elements. + # Use batch_size=0 so the convolution output has numel == 0. + groups = 1 + in_ch, out_ch = 4, 4 + x = torch.randint(0, 8, (0, in_ch, 5, 5), dtype=torch.uint8) + x = x.to(memory_format=torch.channels_last) + w_raw = torch.randint(-4, 4, (out_ch, in_ch, 3, 3), dtype=torch.int8) + w_scale = torch.tensor([1.0], dtype=torch.float) + w_zp = torch.tensor([0], dtype=torch.int64) + packed_w = torch.ops.onednn.qconv_prepack( + w_raw, w_scale, 1.0, 0, [1, 1], [1, 1], [1, 1], groups, x.size(), + ) + qaccum = torch.empty(0, out_ch, 3, 3, dtype=torch.uint8) + qaccum = qaccum.to(memory_format=torch.channels_last) + result = torch.ops.onednn.qconv2d_pointwise.binary( + x, 1.0, 0, packed_w, w_scale, w_zp, qaccum, None, + [1, 1], [1, 1], [1, 1], groups, 1.0, 0, None, + 1.0, 0, "sum", None, "none", [], None, + ) + self.assertEqual(result.numel(), 0) + # Empty tensors both have data_ptr()==0; the key invariant is that + # the early-return path doesn't crash and returns a valid tensor. + + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qconv2d_empty_output_no_binary(self): + # Exercise the early return path with has_accum_postop_sum=false. + groups = 1 + in_ch, out_ch = 4, 4 + x = torch.randint(0, 8, (0, in_ch, 5, 5), dtype=torch.uint8) + x = x.to(memory_format=torch.channels_last) + w_raw = torch.randint(-4, 4, (out_ch, in_ch, 3, 3), dtype=torch.int8) + w_scale = torch.tensor([1.0], dtype=torch.float) + w_zp = torch.tensor([0], dtype=torch.int64) + packed_w = torch.ops.onednn.qconv_prepack( + w_raw, w_scale, 1.0, 0, [1, 1], [1, 1], [1, 1], groups, x.size(), + ) + result = torch.ops.onednn.qconv2d_pointwise( + x, 1.0, 0, packed_w, w_scale, w_zp, None, + [1, 1], [1, 1], [1, 1], groups, 1.0, 0, None, "none", [], None, + ) + self.assertEqual(result.numel(), 0) + + @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") + @skipIfNoONEDNN + def test_qconv1d_sum_exercises_1d_path(self): + # Exercise the is_1d squeeze path via a 1d convolution without + # binary postop (has_accum_postop_sum is unreachable for 1d since + # run_pointwise_binary enforces act.dim()==4). + groups = 1 + in_ch, out_ch = 4, 4 + x = torch.randint(0, 8, (1, in_ch, 8), dtype=torch.uint8) + w_raw = torch.randint(-4, 4, (out_ch, in_ch, 3), dtype=torch.int8) + w_scale = torch.tensor([1.0], dtype=torch.float) + w_zp = torch.tensor([0], dtype=torch.int64) + packed_w = torch.ops.onednn.qconv_prepack( + w_raw, w_scale, 1.0, 0, [1], [1], [1], groups, x.size(), + ) + result = torch.ops.onednn.qconv_pointwise( + x, 1.0, 0, packed_w, w_scale, w_zp, None, + [1], [1], [1], groups, 1.0, 0, None, "none", [], None, + ) + self.assertEqual(result.dim(), 3) + # Test qconv1d with post op relu @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode") @skipIfNoONEDNN @@ -8232,6 +8363,21 @@ def test_qconv3d_fp8(self): torch.manual_seed(0) # For reproducibility in 3D conv tests self._test_qconv_fp8_helper(3, pointwise_post_op) + @unittest.skipIf( + torch.backends.quantized.engine == "none", + "No default quantized engine available", + ) + def test_qconv1d_default_engine(self): + # Regression test for https://github.com/pytorch/pytorch/issues/177254 + # On aarch64, fbgemmSupportedCPU() incorrectly returned True, causing + # the default quantized engine to be X86 which crashed in FBGEMM with + # "RuntimeError: unknown architecure". # codespell:ignore architecure + qconv1d = torch.ao.nn.quantized.Conv1d(4, 8, 3) + x = torch.quantize_per_tensor( + torch.randn(1, 4, 16), scale=1.0, zero_point=0, dtype=torch.quint8 + ) + qconv1d(x) + class TestPadding(TestCase): diff --git a/test/quantization/core/test_workflow_ops.py b/test/quantization/core/test_workflow_ops.py index fd7e8516bf099..be6a2ae8b8cb0 100644 --- a/test/quantization/core/test_workflow_ops.py +++ b/test/quantization/core/test_workflow_ops.py @@ -972,13 +972,13 @@ def _test_learnable_backward_per_channel(self, X_base, device, scale_base, zero_ self.assertTrue( torch.allclose(dX_expected, dX_actual, rtol=tolerance, atol=tolerance), - f"Expected dX={dX_expected} to match X.grad={dX_actual}, X={X_curr}, s={scale_curr}, z={zero_point_curr}, dout={dout}, n_bits={n_bits}") # noqa: B950 + f"Expected dX={dX_expected} to match X.grad={dX_actual}, X={X_curr}, s={scale_curr}, z={zero_point_curr}, dout={dout}, n_bits={n_bits}") self.assertTrue( torch.allclose(dScale_expected * grad_factor, dScale_actual, rtol=tolerance, atol=tolerance), - f"Expected dScale={dScale_expected * grad_factor} to match scale.grad={dScale_actual}, X={X_curr}, s={scale_curr}, z={zero_point_curr}, dout={dout}, n_bits={n_bits}") # noqa: B950 + f"Expected dScale={dScale_expected * grad_factor} to match scale.grad={dScale_actual}, X={X_curr}, s={scale_curr}, z={zero_point_curr}, dout={dout}, n_bits={n_bits}") self.assertTrue( torch.allclose(dZeroPoint_expected * grad_factor, dZeroPoint_actual, rtol=tolerance, atol=tolerance), - f"Expected dZeroPoint={dZeroPoint_expected * grad_factor} to match zero_point.grad={dZeroPoint_actual}, X={X_curr}, s={scale_curr}, z={zero_point_curr}, dout={dout}, n_bits={n_bits}") # noqa: B950 + f"Expected dZeroPoint={dZeroPoint_expected * grad_factor} to match zero_point.grad={dZeroPoint_actual}, X={X_curr}, s={scale_curr}, z={zero_point_curr}, dout={dout}, n_bits={n_bits}") X_curr.grad.data.zero_() scale_curr.grad.data.zero_() zero_point_curr.grad.data.zero_() diff --git a/test/quantization/eager/test_quantize_eager_qat.py b/test/quantization/eager/test_quantize_eager_qat.py index 1507246c3b798..007060c053e17 100644 --- a/test/quantization/eager/test_quantize_eager_qat.py +++ b/test/quantization/eager/test_quantize_eager_qat.py @@ -977,7 +977,7 @@ def compose(functions): if not use_relu: - def relu_op(x): # noqa: F811 + def relu_op(x): return x if freeze_bn: diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index 8584b9f405d76..a5175356342cf 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -5851,7 +5851,7 @@ def forward(self, x): ) backend_config = BackendConfig() \ .set_backend_pattern_config(BackendPatternConfig(torch.nn.Linear) - .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E128 + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) .add_dtype_config(dtype_config) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(nnqr.Linear)) @@ -5903,7 +5903,7 @@ def forward(self, x): backend_config = BackendConfig() \ .set_backend_pattern_config(BackendPatternConfig(torch.nn.Linear) - .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E128 + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) .add_dtype_config(dtype_config) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(nnqr.Linear)) @@ -6276,7 +6276,7 @@ def root_node_getter(node_pattern): backend_pattern_configs.append( BackendPatternConfig() - ._set_pattern_complex_format((torch.reshape, torch.transpose, MatchAllNode)) # noqa: E131 + ._set_pattern_complex_format((torch.reshape, torch.transpose, MatchAllNode)) .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ._set_root_node_getter(root_node_getter) diff --git a/test/run_test.py b/test/run_test.py index ff2a617a6eb78..dc1203728d3b8 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -30,6 +30,7 @@ get_report_path, IS_CI, IS_MACOS, + IS_WINDOWS, isRocmArchAnyOf, retry_shell, set_cwd, @@ -104,6 +105,7 @@ def upload_adhoc_failure_json(*args, **kwargs): TEST_CONFIG = os.getenv("TEST_CONFIG", "") BUILD_ENVIRONMENT = os.getenv("BUILD_ENVIRONMENT", "") RERUN_DISABLED_TESTS = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1" +NUM_PYTEST_RERUNS = int(os.getenv("PYTORCH_NUM_PYTEST_RERUNS", "2")) DISTRIBUTED_TEST_PREFIX = "distributed" INDUCTOR_TEST_PREFIX = "inductor" IS_SLOW = "slow" in TEST_CONFIG or "slow" in BUILD_ENVIRONMENT @@ -196,6 +198,7 @@ def __contains__(self, item): "test_jit_legacy", "test_cuda_nvml_based_avail", "test_jit_cuda_fuser", + "distributed/pipelining/test_dtensor_pp_integration", ] # Add architecture-specific blocklist entries @@ -1266,9 +1269,9 @@ def get_pytest_args(options, is_cpp_test=False, is_distributed_test=False): # flakiness status. Default to 50 re-runs rerun_options = ["--flake-finder", f"--flake-runs={count}"] else: - # When under the normal mode, retry a failed test 2 more times. -x means stop at the first - # failure - rerun_options = ["-x", "--reruns=2"] + # When under the normal mode, retry a failed test NUM_PYTEST_RERUNS more times. + # -x means stop at the first failure. Set PYTORCH_NUM_PYTEST_RERUNS=0 to disable. + rerun_options = ["-x", f"--reruns={NUM_PYTEST_RERUNS}"] pytest_args = [ "-vv", @@ -1801,7 +1804,50 @@ def get_selected_tests(options) -> list[str]: selected_tests = exclude_tests(options.exclude, selected_tests) - if sys.platform == "win32" and not options.ignore_win_blocklist: + if IS_WINDOWS and not options.ignore_win_blocklist: + from torch.testing._internal.common_cuda import SM120OrLater, SM89OrLater + + # Disable tests on Windows for SM89 and later - tests failing in ci + # Enable tests after fixing the failures + if SM89OrLater: + WINDOWS_BLOCKLIST.extend( + [ + # Windows fatal exception / access violation + "functorch/test_aotdispatch", + "functorch/test_control_flow", + "nn/test_convolution", + "profiler/test_profiler", + "test_modules", + "test_expanded_weights", + "test_jit", + "test_nested_tensor", + "test_nestedtensor", + "test_nn", + # DLL load failed errors, missing dependencies + "test_custom_ops", + "test_testing", + # Features not supported on Windows ( e.g. rowwise scaling) + "test_decomp", + "test_transformers", + "test_ops", + # Output mismatch errors and long running tests + "test_linalg", + "test_matmul_cuda", + "functorch/test_ops", + "test_scaled_matmul_cuda", + ] + ) + + # Disable tests on Windows for SM120 and later - tests failing in ci + # Enable tests after fixing the failures + if SM120OrLater: + WINDOWS_BLOCKLIST.extend( + [ + # test_api fails on Windows SM120+. Triage pending. + "cpp/test_api", + ] + ) + target_arch = os.environ.get("VSCMD_ARG_TGT_ARCH") if target_arch != "x64": WINDOWS_BLOCKLIST.append("cpp_extensions_aot_no_ninja") @@ -1862,13 +1908,18 @@ def load_test_times_from_file(file: str) -> dict[str, Any]: with open(path) as f: test_times_file = cast(dict[str, Any], json.load(f)) - job_name = os.environ.get("JOB_NAME") + raw_job_name = os.environ.get("JOB_NAME") + build_env = os.environ.get("BUILD_ENVIRONMENT") + job_name = raw_job_name if job_name is None or job_name == "": # If job name isn't available, use build environment as a backup - job_name = os.environ.get("BUILD_ENVIRONMENT") + job_name = build_env else: job_name = job_name.split(" / test (")[0] test_config = os.environ.get("TEST_CONFIG") + print_to_stderr(f"JOB_NAME={raw_job_name}") + print_to_stderr(f"BUILD_ENVIRONMENT={build_env}") + print_to_stderr(f"test-times lookup key={job_name}, test_config={test_config}") if test_config in test_times_file.get(job_name, {}): print_to_stderr("Found test times from artifacts") return test_times_file[job_name][test_config] diff --git a/test/slow_tests.json b/test/slow_tests.json index abd30b98d47e1..975ad037c53f8 100644 --- a/test/slow_tests.json +++ b/test/slow_tests.json @@ -1,252 +1,277 @@ { - "EndToEndLSTM (__main__.RNNTest)": 184.69266764322916, - "MultiheadAttention (__main__.ModulesTest)": 137.28699747721353, - "test_3mm_add (__main__.TestTritonDotReduction)": 100.44633229573567, - "test_AllenaiLongformerBase_repro_cpu_halide (__main__.HalideCpuTests)": 213.74633280436197, - "test_StridedShard_to_shard_order (__main__.Test_StridedShard_with_shard_order)": 283.25299580891925, - "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 92.95833418104384, - "test_adaptive_max_pool2d1_cpu_halide (__main__.HalideCpuTests)": 116.62366739908855, - "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 62.366333855523, - "test_alexnet_prefix_cpu_halide (__main__.HalideCpuTests)": 177.8383331298828, - "test_aot_autograd_disable_functionalization_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 63.67699940999349, - "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 127.79400380452473, - "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 208.3969980875651, - "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 124.48566436767578, - "test_aot_autograd_disable_functionalization_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 73.43719082786923, - "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 71.86966705322266, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 145.66200256347656, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 223.1346689860026, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 136.3513387044271, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_unpool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 66.18100102742513, - "test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 61.71500015258789, - "test_aot_autograd_symbolic_exhaustive_ormqr_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 68.69900004069011, - "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.28733571370442, - "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 149.20133209228516, - "test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 91.24416859944661, - "test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 108.57016626993816, - "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 616.325, - "test_avg_pool3d_backward2_cpu (__main__.CpuTritonTests)": 272.4786682128906, - "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 156.27499934605189, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 529.6212259928385, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 542.7406548394097, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 64.20183245340984, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 124.58166758219402, - "test_avg_pool3d_backward_cpu_halide (__main__.HalideCpuTests)": 63.05166753133138, - "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 94.64399973551433, - "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 66.33616765340169, - "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 218.5576663547092, - "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 168.72183481852213, - "test_cat_2k_args (__main__.TestTEFuserStatic)": 64.58066653770705, - "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 331.3285556369358, - "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 418.45111083984375, - "test_collect_callgrind (__main__.TestBenchmarkUtils)": 317.6703355577257, - "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 105.86016591389973, - "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 101.96483357747395, - "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 67.77633285522461, - "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 69.17333221435547, - "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 102.28200022379558, - "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 93.95999908447266, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 430.5759989420573, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 449.9323323567708, - "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 253.6421686808268, - "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 271.6805013020833, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1271.2601725260417, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 77.94216664632161, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1118.7591756184895, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 82.65366744995117, - "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.36716715494792, - "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 71.19016647338867, - "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 73.25883356730144, - "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 75.40216700236003, - "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 66.99350102742513, - "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 69.02083269755046, - "test_comprehensive_logspace_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 60.92988925509982, - "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 113.01116689046223, - "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 107.1844991048177, - "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 108.23550160725911, - "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 65.7233346303304, - "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 64.3291670481364, - "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 62.670667012532554, - "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float64 (__main__.TestDecompCPU)": 62.265000661214195, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 111.73116556803386, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 113.69033304850261, - "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 117.5066655476888, - "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 118.56566619873047, - "test_comprehensive_nn_functional_grid_sample_cuda_bfloat16 (__main__.TestDecompCUDA)": 61.97433344523112, - "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 271.9861602783203, - "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 274.42466735839844, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 82.77866490681966, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 84.92516581217448, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 82.4509989420573, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 86.13049952189128, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 125.71083450317383, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 121.97066752115886, - "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 1262.5691528320312, - "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 1252.483642578125, - "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 1260.509501139323, - "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 604.4365030924479, - "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 614.1001586914062, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 67.40933354695638, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 66.7336654663086, - "test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 66.91466649373372, - "test_comprehensive_ormqr_cpu_complex64 (__main__.TestDecompCPU)": 65.80933380126953, - "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 117.8523343404134, - "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 145.19499842325845, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 73.76366806030273, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 78.31683349609375, - "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 76.74333190917969, - "test_comprehensive_ormqr_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 70.06755489773221, - "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 80.16916529337566, - "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 76.32833353678386, - "test_comprehensive_svd_lowrank_cuda_complex128 (__main__.TestDecompCUDA)": 65.76883252461751, - "test_comprehensive_svd_lowrank_cuda_complex64 (__main__.TestDecompCUDA)": 67.16433334350586, - "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 88.39966583251953, - "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 83.59666442871094, - "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 83.94522179497613, - "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 187.1534457736545, - "test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 91.1510009765625, - "test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 153.24633280436197, - "test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 62.278334299723305, - "test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 82.22033437093098, - "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 86.71616617838542, - "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 83.59499867757161, - "test_count_nonzero_all (__main__.TestBool)": 676.7384372287327, - "test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 448.8698336283366, - "test_diff_hyperparams_sharding_strategy_str_full_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 62.429334004720054, - "test_diff_hyperparams_sharding_strategy_str_no_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 61.66200065612793, - "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 89.41600036621094, - "test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestLocalDTensorOpsCPU)": 1525.8353271484375, - "test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 90.60733540852864, - "test_error_detection_and_propagation (__main__.NcclErrorHandlingTest)": 66.99349848429362, - "test_fail_arithmetic_ops.py (__main__.TestTyping)": 71.01377783881293, - "test_fail_creation_ops.py (__main__.TestTyping)": 205.9894081398293, - "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 79.61499913533528, - "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 92.82350031534831, - "test_fuse_large_params_cpu (__main__.CpuTests)": 139.3203353881836, - "test_fuse_large_params_cuda (__main__.GPUTests)": 63.387572152274, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 171.7693345811632, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 168.58877902560764, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 126.14983367919922, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 120.24416605631511, - "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 81.49983342488606, - "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 192.15400187174478, - "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 113.96200052897136, - "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 125.24116770426433, - "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 510.77049255371094, - "test_grid_sampler_2d_cpu_halide (__main__.HalideCpuTests)": 194.58466593424478, - "test_indirect_device_assert (__main__.TritonCodeGenTests)": 307.21116638183594, - "test_inductor_no_recursionerror_on_for_loops (__main__.ReproTests)": 65.38683425055609, - "test_inductor_no_recursionerror_on_for_loops_dynamic_shapes (__main__.DynamicShapesReproTests)": 70.7237786187066, - "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 91.1038335164388, - "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 132.38899824354382, - "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 78.77233250935872, - "test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 120.73766835530598, - "test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 124.48300170898438, - "test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 104.92066701253255, - "test_lstm_cpu (__main__.TestMkldnnCPU)": 86.04666392008464, - "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 136.2572224934896, - "test_max_pool2d2_cpu_halide (__main__.HalideCpuTests)": 417.986328125, - "test_max_pool2d3_cpu_halide (__main__.HalideCpuTests)": 134.87866719563803, - "test_max_pool2d5_cpu_halide (__main__.HalideCpuTests)": 357.8183288574219, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 64.54022216796875, - "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 66.25211164686415, - "test_ordered_distribute_all_combination (__main__.DistributeWithDeviceOrderTest)": 158.06944274902344, - "test_ordered_distribute_all_combination (__main__.DistributeWithDeviceOrderTestWithLocalTensor)": 78.09822167290582, - "test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTest)": 177.78066677517361, - "test_ordered_redistribute_with_partial (__main__.DistributeWithDeviceOrderTestWithLocalTensor)": 621.3526577419705, - "test_proper_exit (__main__.TestDataLoader)": 218.22477637396918, - "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 219.7027791341146, - "test_python_ref__refs_special_zeta_cuda_float64 (__main__.TestCommonCUDA)": 67.61199951171875, - "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 82.16616694132487, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 95.37699890136719, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 106.87333424886067, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 110.58233388264973, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 93.18999989827473, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 117.0056660970052, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 112.08600107828777, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 90.89966583251953, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 99.80333201090495, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 101.7693354288737, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 106.9923324584961, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 104.36800130208333, - "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 106.9153340657552, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False (__main__.TestPatternMatcher)": 100.8326644897461, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 107.92199961344402, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 104.72100067138672, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True (__main__.TestPatternMatcher)": 93.16766611735027, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 111.56466674804688, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 106.9739990234375, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False (__main__.TestPatternMatcher)": 88.35433451334636, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 100.52966817220052, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 102.96500142415364, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True (__main__.TestPatternMatcher)": 98.01133473714192, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 112.06033325195312, - "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 106.11466725667317, - "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 616.5533447265625, - "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1048.9943033854167, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 770.6356608072916, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1446.0608520507812, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 85.64499918619792, - "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 278.70049540201825, - "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 127.52366892496745, - "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 237.04200236002603, - "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 70.95300038655598, - "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 146.1635004679362, - "test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 72.04449971516927, - "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 99.4903335571289, - "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 177.31633504231772, - "test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 72.39499918619792, - "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 142.12999725341797, - "test_register_spills_cuda (__main__.BenchmarkFusionGpuTest)": 114.4384994506836, - "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_inference_precision_amp (__main__.DeterministicTest)": 70.59966659545898, - "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_inference_precision_bfloat16 (__main__.DeterministicTest)": 61.23133405049642, - "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_training_precision_amp (__main__.DeterministicTest)": 152.28532918294272, - "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_training_precision_bfloat16 (__main__.DeterministicTest)": 151.11583201090494, - "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_training_precision_float16 (__main__.DeterministicTest)": 170.14933268229166, - "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_training_precision_float32 (__main__.DeterministicTest)": 141.25400034586588, - "test_run2run_determinism_model_name_DistillGPT2_training_or_inference_training_precision_amp (__main__.DeterministicTest)": 78.44533411661784, - "test_run2run_determinism_model_name_DistillGPT2_training_or_inference_training_precision_bfloat16 (__main__.DeterministicTest)": 91.19933319091797, - "test_run2run_determinism_model_name_DistillGPT2_training_or_inference_training_precision_float16 (__main__.DeterministicTest)": 80.00733184814453, - "test_run2run_determinism_model_name_DistillGPT2_training_or_inference_training_precision_float32 (__main__.DeterministicTest)": 76.68100102742513, - "test_run2run_determinism_model_name_GoogleFnet_training_or_inference_training_precision_amp (__main__.DeterministicTest)": 92.68900044759114, - "test_run2run_determinism_model_name_GoogleFnet_training_or_inference_training_precision_bfloat16 (__main__.DeterministicTest)": 92.34583409627278, - "test_run2run_determinism_model_name_GoogleFnet_training_or_inference_training_precision_float16 (__main__.DeterministicTest)": 92.69583257039388, - "test_run2run_determinism_model_name_GoogleFnet_training_or_inference_training_precision_float32 (__main__.DeterministicTest)": 93.89749908447266, - "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 65.03716723124187, - "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 69.72733306884766, - "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 171.71016947428384, - "test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 107.34944322374132, - "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 129.4277776082357, - "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 118.2247797648112, - "test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 108.49850209554036, - "test_sort_stable_cpu (__main__.CpuTritonTests)": 1305.087646484375, - "test_sort_stable_cuda (__main__.GPUTests)": 102.32514299665179, - "test_split_cumsum_cpu (__main__.CpuTritonTests)": 90.37133534749348, - "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 285.8221740722656, - "test_tensor_split (__main__.TestVmapOperators)": 91.96030240144694, - "test_terminate_handler_on_crash (__main__.TestTorch)": 207.283443874783, - "test_terminate_signal (__main__.ForkTest)": 234.09955173068576, - "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 234.16177961561414, - "test_terminate_signal (__main__.SpawnTest)": 237.5818862915039, - "test_train_parity_with_activation_checkpointing (__main__.TestFullyShard1DTrainingCompose)": 64.86788728502061, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 88.8943322499593, - "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 215.28750356038412, - "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 214.9933344523112, - "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 150.96166483561197, - "test_upsample_bicubic2d_cpu_halide (__main__.HalideCpuTests)": 97.6326675415039, - "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 86.5586675008138, - "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 76.31433232625325, - "test_vec_compare_op_cpu_only (__main__.CPUReproTests)": 64.15555487738715, - "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 86.58700052897136, - "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 70.91766738891602, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 64.271666208903, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 81.8783327738444, - "test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 65.68414279392788, - "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 70.83733367919922, - "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 77.18000030517578, - "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 67.37900034586589, - "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 80.73283386230469, - "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 90.59783426920573, - "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 83.13833363850911, - "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 112.93466695149739 + "EndToEndLSTM (__main__.RNNTest)": 193.73400370279947, + "MultiheadAttention (__main__.ModulesTest)": 144.85599772135416, + "test_3mm_add (__main__.TestTritonDotReduction)": 129.92483266194662, + "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 90.46566390991211, + "test_aot_autograd_disable_functionalization_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 77.86966705322266, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 127.5673319498698, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 204.36866760253906, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 123.83066304524739, + "test_aot_autograd_disable_functionalization_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 66.83466720581055, + "test_aot_autograd_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 72.18625852796767, + "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 77.01166534423828, + "test_aot_autograd_symbolic_exhaustive_masked_norm_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 62.44844436645508, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 156.3863321940104, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 203.4173329671224, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 105.20900217692058, + "test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 65.38600158691406, + "test_aot_autograd_symbolic_exhaustive_ormqr_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 83.66033426920573, + "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 101.71733093261719, + "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 127.03733317057292, + "test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_False_cpu (__main__.AssociativeScanTests)": 88.39666748046875, + "test_associative_scan_partial_grad_combine_mode_generic_compile_mode_compile_dynamic_shape_reverse_True_cpu (__main__.AssociativeScanTests)": 102.40783437093098, + "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 763.9162445068359, + "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 152.19875049591064, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 514.7272237141927, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 520.711890326606, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 64.15366808573405, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 128.34783426920572, + "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 71.84533437093098, + "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 64.79450035095215, + "test_baddmm_search_space_EXHAUSTIVE (__main__.TestMaxAutotune)": 112.79600016276042, + "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 226.9800008138021, + "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 173.3750025431315, + "test_cat_2k_args (__main__.TestTEFuserDynamic)": 131.654914483428, + "test_checkpointing_without_reentrant_input_requires_grad_False (__main__.TestAutogradWithCompiledAutograd)": 347.29089016384546, + "test_checkpointing_without_reentrant_input_requires_grad_True (__main__.TestAutogradWithCompiledAutograd)": 448.10833062065973, + "test_collect_callgrind (__main__.TestBenchmarkUtils)": 285.1102227105035, + "test_compiled_dtensor_op_db_nn_functional_max_pool2d_cpu_float32 (__main__.TestCompiledDTensorOpsCPU)": 92.9769999186198, + "test_compiled_dtensor_op_db_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestCompiledDTensorOpsCPU)": 64.3003323872884, + "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 100.02783330281575, + "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 97.7408332824707, + "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 67.63183339436848, + "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 66.48316637674968, + "test_comprehensive_grid_sampler_2d_cpu_bfloat16 (__main__.TestDecompCPU)": 89.89199829101562, + "test_comprehensive_grid_sampler_2d_cpu_float16 (__main__.TestDecompCPU)": 90.3046646118164, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 432.8856608072917, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 431.4099934895833, + "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 244.3134994506836, + "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 277.8193359375, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1243.374491373698, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 71.73899968465169, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1143.0853271484375, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 75.46933364868164, + "test_comprehensive_linalg_lu_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 66.4915000597636, + "test_comprehensive_linalg_lu_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 67.09683386484782, + "test_comprehensive_linalg_lu_factor_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 62.22633425394694, + "test_comprehensive_linalg_lu_factor_ex_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 62.32066790262858, + "test_comprehensive_linalg_lu_factor_ex_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 62.156500498453774, + "test_comprehensive_linalg_lu_solve_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 80.1408322652181, + "test_comprehensive_linalg_lu_solve_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 79.4061648050944, + "test_comprehensive_linalg_matrix_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 95.98250071207683, + "test_comprehensive_linalg_matrix_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 94.55149968465169, + "test_comprehensive_linalg_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 98.94500128428142, + "test_comprehensive_linalg_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 115.80650011698405, + "test_comprehensive_linalg_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 115.1416654586792, + "test_comprehensive_linalg_norm_subgradients_at_zero_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 87.03200038274129, + "test_comprehensive_linalg_norm_subgradients_at_zero_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 110.46916611989339, + "test_comprehensive_linalg_norm_subgradients_at_zero_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 111.24783261617024, + "test_comprehensive_linalg_solve_triangular_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 83.31766764322917, + "test_comprehensive_linalg_solve_triangular_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 80.25166447957356, + "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 69.97166570027669, + "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 68.84450022379558, + "test_comprehensive_linalg_svd_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 96.60883108774821, + "test_comprehensive_linalg_svd_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 93.79333368937175, + "test_comprehensive_linalg_vector_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 256.18433062235516, + "test_comprehensive_linalg_vector_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 248.46516863505045, + "test_comprehensive_logspace_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 74.78166580200195, + "test_comprehensive_logspace_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 69.59733327229817, + "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 117.71599960327148, + "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 113.31033579508464, + "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 114.37183252970378, + "test_comprehensive_max_pool2d_with_indices_backward_cpu_float64 (__main__.TestDecompCPU)": 64.0938326517741, + "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 63.42050043741862, + "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 63.77633285522461, + "test_comprehensive_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDecompCPU)": 61.52166620890299, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 118.71116765340169, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 125.97533289591472, + "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 110.10399881998698, + "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 119.12533315022786, + "test_comprehensive_nn_functional_grid_sample_cuda_bfloat16 (__main__.TestDecompCUDA)": 63.85616683959961, + "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 263.6338348388672, + "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 260.0433349609375, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 80.7168337504069, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 83.66033426920573, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 78.34583409627278, + "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 83.96233240763347, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 120.8836669921875, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 124.52166748046875, + "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 819.3448282877604, + "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 800.9398295084635, + "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 798.9723409016927, + "test_comprehensive_nn_functional_max_pool3d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 650.4093424479166, + "test_comprehensive_nn_functional_max_pool3d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 672.7261657714844, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 70.85333506266277, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 73.07083384195964, + "test_comprehensive_nn_functional_max_unpool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 73.60183334350586, + "test_comprehensive_ormqr_cpu_complex128 (__main__.TestDecompCPU)": 63.091583251953125, + "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 119.43050130208333, + "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 129.18349838256836, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 73.39383252461751, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 77.46083196004231, + "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 100.95350011189778, + "test_comprehensive_ormqr_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 76.22900009155273, + "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 76.60616556803386, + "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 72.88816579182942, + "test_comprehensive_svd_lowrank_cuda_complex128 (__main__.TestDecompCUDA)": 66.14150047302246, + "test_comprehensive_svd_lowrank_cuda_complex64 (__main__.TestDecompCUDA)": 71.52549870808919, + "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 103.31566747029622, + "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 90.61466598510742, + "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 85.5044453938802, + "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 186.80122545030383, + "test_conv2d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 73.75333404541016, + "test_conv3d_binary_broadcast_shapes_cpu (__main__.TestPatternMatcherGenericCPU)": 165.67066955566406, + "test_conv3d_binary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 69.34033330281575, + "test_conv3d_unary_dynamic_shapes_cpu (__main__.TestDynamicPatternMatcherGenericCPU)": 73.0816667344835, + "test_conv_bn_fuse_cpu (__main__.CpuTests)": 62.08125019073486, + "test_conv_unary_fusion_nnc (__main__.TestMkldnnFusion)": 82.00311109754774, + "test_correctness_AdamW_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 89.69633356730144, + "test_correctness_Adam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 88.30749893188477, + "test_count_nonzero_all (__main__.TestBool)": 654.9627821180555, + "test_create_rand_mask_from_inputs_dynamic_shapes (__main__.DynamicShapesReproTests)": 92.5728333791097, + "test_dijkstra_expand_single_dim_strategy_to_mesh_hard_4d (__main__.TestDijkstraExpandSingleDimStrategy)": 303.05066935221356, + "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 95.07450103759766, + "test_dtensor_op_db__native_batch_norm_legit_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 172.50000508626303, + "test_dtensor_op_db_native_batch_norm_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 200.29400126139322, + "test_dtensor_op_db_nn_functional_batch_norm_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 199.36066691080728, + "test_dtensor_op_db_nn_functional_binary_cross_entropy_with_logits_cpu_float32 (__main__.TestMultiThreadedDTensorOpsCPU)": 183.6556650797526, + "test_eager_sequence_nr_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 195.9025821685791, + "test_fail_arithmetic_ops.py (__main__.TestTyping)": 85.94689008924696, + "test_fail_random.py (__main__.TestTyping)": 363.56556599934896, + "test_fail_torch_size.py (__main__.TestTyping)": 114.58616725787675, + "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 83.33750025431316, + "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 103.71716690063477, + "test_fuse_large_params_cpu (__main__.CpuTests)": 92.53525161743164, + "test_fuse_large_params_cuda (__main__.GPUTests)": 62.06562566757202, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 156.92011176215277, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 158.4638909233941, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 135.43400065104166, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 132.73733266194662, + "test_grad_nn_Transformer_cpu_float64 (__main__.TestModuleCPU)": 84.30703392028809, + "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 80.69633356730144, + "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 190.28999837239584, + "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 113.35149892171223, + "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 123.2726656595866, + "test_gradgrad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 494.573003133138, + "test_graph_based_redistribute_cost (__main__.DistributeWithDeviceOrderTest)": 67.4446652730306, + "test_graph_based_redistribute_cost (__main__.DistributeWithDeviceOrderTestWithLocalTensor)": 62.11899948120117, + "test_indirect_device_assert (__main__.TritonCodeGenTests)": 325.7254994710286, + "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 83.06583404541016, + "test_inputs_overlapping_with_mutation_stress_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 122.00066714816623, + "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 86.76616668701172, + "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 1066.7661743164062, + "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 72.56916681925456, + "test_linalg_solve_triangular_large_cuda_float64 (__main__.TestLinalgCUDA)": 151.44950103759766, + "test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 140.42566935221353, + "test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 150.54800160725912, + "test_longformer_chunk_dynamic_shapes (__main__.DynamicShapesReproTests)": 104.55644395616319, + "test_low_memory_max_pool_dilation_1_dim_3_use_block_ptr_False_cpu_halide (__main__.HalideCpuTests)": 583.5033365885416, + "test_low_memory_max_pool_dilation_2_dim_3_use_block_ptr_False_cpu_halide (__main__.HalideCpuTests)": 516.3889973958334, + "test_lstm_cpu (__main__.TestMkldnnCPU)": 68.73433176676433, + "test_many_overlapping_inputs_does_not_explode_guards_dynamic_shapes (__main__.DynamicShapesReproTests)": 108.97577582465277, + "test_max_autotune_remote_caching_dynamic_True (__main__.TestMaxAutotuneRemoteCache)": 61.8836669921875, + "test_memory_format_operators_cuda (__main__.TestTorchDeviceTypeCUDA)": 125.83300018310547, + "test_ops_composition_names_cpu (__main__.TestTestParametrizationDeviceTypeCPU)": 157.14929997324944, + "test_ops_decorator_applies_op_and_param_specific_decorators_cpu (__main__.TestTestParametrizationDeviceTypeCPU)": 189.11739701827366, + "test_pattern_matcher_multi_user_cpu (__main__.CpuTritonTests)": 143.3016611735026, + "test_pq_vs_full_expansion_data_driven (__main__.TestDijkstraExpandSingleDimStrategy)": 881.4066772460938, + "test_proper_exit (__main__.TestDataLoader)": 219.32978057861328, + "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 229.30333031548395, + "test_python_ref__refs_special_zeta_cuda_float64 (__main__.TestCommonCUDA)": 74.98866653442383, + "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 89.93916702270508, + "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 556.7093302408854, + "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 1110.48583984375, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 813.156005859375, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1398.7890014648438, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 84.88433329264323, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 268.56183369954425, + "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 127.7530008951823, + "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 241.10733795166016, + "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 77.3393325805664, + "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 143.2455037434896, + "test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 69.18416849772136, + "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 105.94433339436848, + "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 174.2881647745768, + "test_quick_core_backward_std_cpu_float64 (__main__.TestDecompCPU)": 77.57733408610027, + "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 158.21883392333984, + "test_register_spills_cuda (__main__.BenchmarkFusionGpuTest)": 108.23566563924153, + "test_resize_as_cuda (__main__.GPUTests)": 89.16333452860515, + "test_resize_cuda (__main__.GPUTests)": 67.50366735458374, + "test_reveal_opt_size.py (__main__.TestTyping)": 60.49586756983772, + "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_inference_precision_amp (__main__.DeterministicTest)": 104.33299763997395, + "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_inference_precision_bfloat16 (__main__.DeterministicTest)": 64.46683311462402, + "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_training_precision_amp (__main__.DeterministicTest)": 198.3463338216146, + "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_training_precision_bfloat16 (__main__.DeterministicTest)": 153.71866353352866, + "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_training_precision_float16 (__main__.DeterministicTest)": 167.72516632080078, + "test_run2run_determinism_model_name_BertForMaskedLM_training_or_inference_training_precision_float32 (__main__.DeterministicTest)": 188.2213338216146, + "test_run2run_determinism_model_name_DistillGPT2_training_or_inference_training_precision_amp (__main__.DeterministicTest)": 82.31683349609375, + "test_run2run_determinism_model_name_DistillGPT2_training_or_inference_training_precision_bfloat16 (__main__.DeterministicTest)": 115.22916666666667, + "test_run2run_determinism_model_name_DistillGPT2_training_or_inference_training_precision_float16 (__main__.DeterministicTest)": 84.5586649576823, + "test_run2run_determinism_model_name_DistillGPT2_training_or_inference_training_precision_float32 (__main__.DeterministicTest)": 122.09583409627278, + "test_run2run_determinism_model_name_GoogleFnet_training_or_inference_training_precision_amp (__main__.DeterministicTest)": 95.04883448282878, + "test_run2run_determinism_model_name_GoogleFnet_training_or_inference_training_precision_bfloat16 (__main__.DeterministicTest)": 94.20300165812175, + "test_run2run_determinism_model_name_GoogleFnet_training_or_inference_training_precision_float16 (__main__.DeterministicTest)": 118.3158327738444, + "test_run2run_determinism_model_name_GoogleFnet_training_or_inference_training_precision_float32 (__main__.DeterministicTest)": 119.6198336283366, + "test_runtime_checks_large_cpu (__main__.AOTInductorTestABICompatibleCpu)": 70.40016682942708, + "test_runtime_checks_large_cpu_with_stack_allocation (__main__.AOTInductorTestABICompatibleCpuWithStackAllocation)": 70.30955590142145, + "test_runtime_checks_large_cuda (__main__.AOTInductorTestABICompatibleGpu)": 171.0409952799479, + "test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 106.9195556640625, + "test_searchsorted_cuda (__main__.GPUTests)": 82.54333392779033, + "test_shared_memory_pruning_addmm_bfloat16_mat1_transposed_False_mat2_transposed_False_use_tma_False (__main__.TestTemplateConfigPruning)": 143.74433390299478, + "test_shared_memory_pruning_addmm_bfloat16_mat1_transposed_False_mat2_transposed_True_use_tma_False (__main__.TestTemplateConfigPruning)": 143.3086675008138, + "test_shared_memory_pruning_addmm_bfloat16_mat1_transposed_True_mat2_transposed_False_use_tma_False (__main__.TestTemplateConfigPruning)": 143.81799825032553, + "test_shared_memory_pruning_addmm_bfloat16_mat1_transposed_True_mat2_transposed_True_use_tma_False (__main__.TestTemplateConfigPruning)": 144.1403350830078, + "test_shared_memory_pruning_addmm_float32_mat1_transposed_False_mat2_transposed_False_use_tma_False (__main__.TestTemplateConfigPruning)": 153.8959986368815, + "test_shared_memory_pruning_addmm_float32_mat1_transposed_False_mat2_transposed_True_use_tma_False (__main__.TestTemplateConfigPruning)": 159.3713353474935, + "test_shared_memory_pruning_addmm_float32_mat1_transposed_True_mat2_transposed_False_use_tma_False (__main__.TestTemplateConfigPruning)": 151.32366689046225, + "test_shared_memory_pruning_addmm_float32_mat1_transposed_True_mat2_transposed_True_use_tma_False (__main__.TestTemplateConfigPruning)": 151.53400166829428, + "test_shared_memory_pruning_mm_bfloat16_mat1_transposed_False_mat2_transposed_False_use_tma_False (__main__.TestTemplateConfigPruning)": 137.7143351236979, + "test_shared_memory_pruning_mm_bfloat16_mat1_transposed_False_mat2_transposed_True_use_tma_False (__main__.TestTemplateConfigPruning)": 137.6969985961914, + "test_shared_memory_pruning_mm_bfloat16_mat1_transposed_True_mat2_transposed_False_use_tma_False (__main__.TestTemplateConfigPruning)": 139.831662495931, + "test_shared_memory_pruning_mm_bfloat16_mat1_transposed_True_mat2_transposed_True_use_tma_False (__main__.TestTemplateConfigPruning)": 138.6189982096354, + "test_shared_memory_pruning_mm_float32_mat1_transposed_False_mat2_transposed_False_use_tma_False (__main__.TestTemplateConfigPruning)": 149.81566874186197, + "test_shared_memory_pruning_mm_float32_mat1_transposed_False_mat2_transposed_True_use_tma_False (__main__.TestTemplateConfigPruning)": 154.44766743977866, + "test_shared_memory_pruning_mm_float32_mat1_transposed_True_mat2_transposed_False_use_tma_False (__main__.TestTemplateConfigPruning)": 146.81933085123697, + "test_shared_memory_pruning_mm_float32_mat1_transposed_True_mat2_transposed_True_use_tma_False (__main__.TestTemplateConfigPruning)": 147.39400227864584, + "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 122.32033454047308, + "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 124.18311055501302, + "test_softmax_cpu_float64 (__main__.TestSparseCPU)": 60.396761218706764, + "test_sort_bool_cpu (__main__.CpuTritonTests)": 348.36634318033856, + "test_sort_dynamic_shape_with_check_cuda (__main__.TestInductorDynamicCUDA)": 112.24066670735677, + "test_sort_stable_cuda (__main__.GPUTests)": 126.56412506103516, + "test_sort_transpose_cpu (__main__.CpuTritonTests)": 376.54832967122394, + "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 327.46799723307294, + "test_tensorwise_scaling_acceptable_input_dims_M_1024_K_1024_N_2048_persistent_matmul_False_cpu (__main__.TestFP8LoweringCPU)": 75.4329337477684, + "test_terminate_handler_on_crash (__main__.TestTorch)": 207.00677956475153, + "test_terminate_signal (__main__.ForkTest)": 232.13122049967447, + "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 232.00266604953342, + "test_terminate_signal (__main__.SpawnTest)": 222.40377510918512, + "test_torch_size_tensor_index_scalar_constant_dynamic_shapes (__main__.DynamicShapesMiscTests)": 601.1639200846354, + "test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 174.3316650390625, + "test_train_parity_multi_group_cpu_offload_eager (__main__.TestFullyShard1DTrainingCore)": 61.27133305867513, + "test_train_parity_multi_group_unshard_async_op (__main__.TestFullyShard1DTrainingCore)": 64.80400085449219, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 87.95033391316731, + "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 254.8540013631185, + "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 256.1475016276042, + "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 194.4748331705729, + "test_tuning_pool_timeout (__main__.TestTuningProcessPool)": 126.16133371988933, + "test_unbacked_dtensor_op_db_clamp_cpu_float32 (__main__.TestUnbackedDTensorOpsCPU)": 96.5290018717448, + "test_unbind (__main__.TestVmapOperators)": 82.44380033289393, + "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 88.58866628011067, + "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 78.71033223470052, + "test_vec_compare_op_cpu_only (__main__.CPUReproTests)": 62.92533323499892, + "test_views1_cuda (__main__.GPUTests)": 106.26166693369548, + "test_vmapjvpvjp_linalg_lstsq_grad_oriented_cpu_float32 (__main__.TestOperatorsCPU)": 113.2413330078125, + "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 68.83388929013853, + "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 67.28550148010254, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 86.25900014241536, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 78.42966588338216, + "test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 69.55499903361003, + "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 89.97300211588542, + "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 79.44366709391277, + "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 66.21299997965495, + "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 81.87016677856445, + "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 82.90099970499675, + "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 81.67116800944011, + "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 110.1931660970052 } \ No newline at end of file diff --git a/test/test_accelerator.py b/test/test_accelerator.py index 67b92969aefd0..393e2ff8d7e05 100644 --- a/test/test_accelerator.py +++ b/test/test_accelerator.py @@ -3,6 +3,7 @@ import gc import sys import unittest +from contextlib import nullcontext import torch from torch.testing._internal.common_utils import ( @@ -17,7 +18,7 @@ if not TEST_ACCELERATOR: print("No available accelerator detected, skipping tests", file=sys.stderr) - TestCase = NoTest # noqa: F811 + TestCase = NoTest # Skip because failing when run on cuda build with no GPU, see #150059 for example sys.exit() @@ -260,6 +261,36 @@ def test_get_memory_info(self): self.assertGreaterEqual(free_bytes, 0) self.assertGreaterEqual(total_bytes, 0) + def test_device_capability_supported_dtypes(self): + try: + caps = torch.accelerator.get_device_capability() + except RuntimeError: + self.skipTest("Backend doesn't support get_device_capability") + + supported_dtypes = caps["supported_dtypes"] + self.assertIsInstance(supported_dtypes, set) + self.assertGreater(len(supported_dtypes), 0) + + acc = torch.accelerator.current_accelerator() + reference_dtype = next(iter(supported_dtypes)) + + all_dtypes = [ + getattr(torch, name) + for name in dir(torch) + if isinstance(getattr(torch, name), torch.dtype) + ] + for dtype in all_dtypes: + with self.subTest(dtype=dtype): + ctx = ( + nullcontext() + if dtype in supported_dtypes + else self.assertRaises((RuntimeError, TypeError)) + ) + with ctx: + t = torch.empty(16, dtype=dtype, device=acc) + t = t.to(reference_dtype) + t = t.to(dtype) + if __name__ == "__main__": run_tests() diff --git a/test/test_autograd.py b/test/test_autograd.py index 2eb7c63e1b6f2..b78d46abd1c34 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -20,7 +20,7 @@ import uuid import warnings import weakref -from collections import OrderedDict +from collections import defaultdict, OrderedDict from copy import deepcopy from functools import partial, reduce from itertools import product @@ -82,6 +82,7 @@ skipIfXpu, slowTest, TEST_WITH_TORCHDYNAMO, + TEST_XPU, TestCase, ) from torch.utils._mode_utils import no_dispatch @@ -93,6 +94,7 @@ create_selective_checkpoint_contexts, ) from torch.utils.flop_counter import FlopCounterMode +from torch.utils.weak import WeakTensorKeyDictionary if TYPE_CHECKING: @@ -572,9 +574,10 @@ def my_function(x, y): gradgradcheck(my_function, (x, y)) def test_not_implemented_grad(self): + # Test that built-in functions with unimplemented gradients raise appropriate errors a = torch.rand(2, requires_grad=True) - # if grad for nextafter ends up being implemented, this should be changed - y = torch.nextafter(a, a).sum() + y = torch.acosh_(a.clone()).sum() + with self.assertRaisesRegex( NotImplementedError, "the derivative for .* is not implemented" ): @@ -4368,6 +4371,50 @@ def test_no_grad_modifies_version(self): RuntimeError, "modified by an inplace operation", lambda: z.backward() ) + def test_inplace_version_error_shows_forward_op_name(self): + # The error message should refer to the forward op (e.g. "Add"), + # not the backward node (e.g. "AddBackward0"). + # Use b * b (not b * 2) so MulBackward saves b and detects the + # version mismatch when unpacking. + + a = torch.randn(5, requires_grad=True) + b = a + 1 + c = b * b + with torch.no_grad(): + b += 1 + with self.assertRaisesRegex(RuntimeError, r"output 0 of Add,"): + c.backward(torch.ones(5)) + + # sum(dim=...) uses SumBackward1; the non-zero variant number is + # preserved so the message says "Sum1" not "Sum". + a = torch.randn(3, 4, requires_grad=True) + b = a.sum(dim=1) + c = b * b + with torch.no_grad(): + b += 1 + with self.assertRaisesRegex(RuntimeError, r"output 0 of Sum1,"): + c.backward(torch.ones(3)) + + # Custom autograd Function: "MyFunc" not "MyFuncBackward". + class MyFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x.clone() + + @staticmethod + def backward(ctx, grad): + (x,) = ctx.saved_tensors + return grad * x + + a = torch.randn(5, requires_grad=True) + b = MyFunc.apply(a) + c = b * b + with torch.no_grad(): + b += 1 + with self.assertRaisesRegex(RuntimeError, r"output 0 of MyFunc,"): + c.backward(torch.ones(5)) + def test_increment_version(self): a = torch.rand(5, requires_grad=True) v = a._version @@ -5075,6 +5122,19 @@ def f(x): self.assertTrue(torch.autograd.is_view_replay_enabled()) self.assertFalse(torch.autograd.is_view_replay_enabled()) + prev = torch.autograd.is_view_replay_enabled() + ctx = torch.autograd._force_original_view_tracking(not prev) + # Construction eagerly sets state (function-form behavior). + self.assertEqual(torch.autograd.is_view_replay_enabled(), not prev) + with ctx: + self.assertEqual(torch.autograd.is_view_replay_enabled(), not prev) + out = f(x) + self.assertTrue( + ("ViewBackward" if not prev else "AsStridedBackward") + in str(out.grad_fn) + ) + self.assertEqual(torch.autograd.is_view_replay_enabled(), prev) + # Test as a function torch.autograd._force_original_view_tracking(False) out = f(x) @@ -5086,6 +5146,20 @@ def f(x): self.assertTrue("ViewBackward" in str(out.grad_fn)) self.assertTrue(torch.autograd.is_view_replay_enabled()) + prev = torch.autograd.is_view_replay_enabled() + + @torch.autograd._force_original_view_tracking(not prev) + def g(x): + return f(x) + + # __call__ undoes the __init__ mutation, so ambient state is restored. + self.assertEqual(torch.autograd.is_view_replay_enabled(), prev) + out = g(x) + self.assertTrue( + ("ViewBackward" if not prev else "AsStridedBackward") in str(out.grad_fn) + ) + self.assertEqual(torch.autograd.is_view_replay_enabled(), prev) + def test_unsafe_set_version_counter(self): x = torch.ones(2, requires_grad=True).clone() x.add_(1) @@ -7356,10 +7430,10 @@ def foo(x, y, z): x = torch.randn(3, 3, requires_grad=True) y = torch.randn(3, 3, requires_grad=True) z = torch.randn(3, 3, requires_grad=True) - if device_type == "cuda": - x = x.cuda() - y = y.cuda() - z = z.cuda() + if device_type in ("cuda", "xpu"): + x = x.to(device_type) + y = y.to(device_type) + z = z.to(device_type) with torch.autocast( enabled=enabled, device_type=device_type, dtype=torch.bfloat16 @@ -7379,15 +7453,17 @@ def test_checkpointing_non_reentrant_autocast_cpu(self): self._test_checkpointing_non_reentrant_autocast(device_type="cpu") @unittest.skipIf( - not torch.cuda.is_available() or not torch.cuda.is_bf16_supported(), - "Test requires CUDA bf16 support", + (not torch.cuda.is_available() or not torch.cuda.is_bf16_supported()) + and (not torch.xpu.is_available() or not torch.xpu.is_bf16_supported()), + "Test requires CUDA or XPU bf16 support", ) def test_checkpointing_non_reentrant_autocast_gpu(self): """ Test that autocast args/kwargs such as the dtype are preserved during non-reentrant checkpoint recomputation on GPU. """ - self._test_checkpointing_non_reentrant_autocast(device_type="cuda") + device_type = "cuda" if torch.cuda.is_available() else "xpu" + self._test_checkpointing_non_reentrant_autocast(device_type=device_type) @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") @slowTest @@ -8070,6 +8146,28 @@ def test_checkpointing_without_reentrant_correct_grad(self): self.assertEqual(b_grad, c_grad) self.assertEqual(b_grad, d_grad) + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") + def test_checkpointing_without_reentrant_with_block_mask(self): + from torch.nn.attention.flex_attention import BlockMask, create_block_mask + from torch.utils._pytree import register_pytree_node, SUPPORTED_NODES + + if BlockMask not in SUPPORTED_NODES: + register_pytree_node( + BlockMask, + BlockMask._flatten, + BlockMask._unflatten, + flatten_with_keys_fn=BlockMask._flatten_with_keys, + serialized_type_name="torch.nn.attention.flex_attention.BlockMask", + ) + + block_mask = create_block_mask( + lambda b, h, q, kv: q >= kv, B=1, H=1, Q_LEN=128, KV_LEN=128 + ) + x = torch.randn(4, 128, device="cuda") + + result = checkpoint(lambda x, mask: x * 2, x, block_mask, use_reentrant=False) + self.assertEqual(result, x * 2) + @skipIfXpu(msg="torch._C._scatter Not implemented on XPU, issue #143239") def test_checkpointing_without_reentrant_dataparallel(self): """ @@ -8635,6 +8733,309 @@ def backward(ctx, g): self.assertEqual(y.grad_fn.saved_tensors, ()) self.assertEqual(y.grad_fn._raw_saved_tensors, ()) + @skipIfTorchDynamo("boxed_grads_call is incompatible with compiled autograd") + def test_custom_function_boxed_grads(self): + """Test boxed_grads_call mechanism without torch.compile. + + With boxed_grads_call, backward receives a single mutable list + of grads instead of individual grad arguments.""" + + class BoxedFunc(torch.autograd.Function): + boxed_grads_call = True + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x * 2 + + @staticmethod + def backward(ctx, grads): + self.assertIsInstance(grads, list) + self.assertEqual(len(grads), 1) + grad = grads[0] + grads.clear() + (x,) = ctx.saved_tensors + return grad * 2 + + x = torch.randn(4, requires_grad=True) + out = BoxedFunc.apply(x) + out.sum().backward() + # d/dx (2x).sum() = 2 + self.assertEqual(x.grad, torch.full_like(x, 2.0)) + + @skipIfTorchDynamo("boxed_grads_call is incompatible with compiled autograd") + def test_custom_function_boxed_grads_multi_output(self): + """Boxed grads with multiple outputs — grads list has one + entry per output.""" + + class MultiOut(torch.autograd.Function): + boxed_grads_call = True + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x * 2, x * 3 + + @staticmethod + def backward(ctx, grads): + self.assertIsInstance(grads, list) + self.assertEqual(len(grads), 2) + (x,) = ctx.saved_tensors + g1, g2 = grads + grads.clear() + return g1 * 2 + g2 * 3 + + x = torch.randn(4, requires_grad=True) + a, b = MultiOut.apply(x) + (a.sum() + b.sum()).backward() + # forward: a=2x, b=3x; backward: grad_a=1, grad_b=1 + # return grad_a*2 + grad_b*3 = 2+3 = 5 + self.assertEqual(x.grad, torch.full_like(x, 5.0)) + + @skipIfTorchDynamo("boxed_grads_call is incompatible with compiled autograd") + def test_custom_function_boxed_grads_no_extra_refs(self): + """Framework holds no extra refs to grads with boxed calling convention. + + After removing the grad from the boxed list, a weakref to it should + become dead — proving the framework released all its refs.""" + + class CheckRefs(torch.autograd.Function): + boxed_grads_call = True + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x * 2 + + @staticmethod + def backward(ctx, grads): + grad = grads[0] + grad_for_compute = grad.clone() + ref = weakref.ref(grad) + grads.clear() + del grad + self.assertTrue(ref() is None) + (x,) = ctx.saved_tensors + return grad_for_compute * 2 + + x = torch.randn(4, requires_grad=True) + out = CheckRefs.apply(x) + out.sum().backward() + self.assertEqual(x.grad, torch.full_like(x, 2.0)) + + @skipIfTorchDynamo("boxed_grads_call is incompatible with compiled autograd") + def test_custom_function_boxed_grads_cleanup_on_error(self): + """Grads list is not leaked when backward raises.""" + + class FailingBwd(torch.autograd.Function): + boxed_grads_call = True + + @staticmethod + def forward(ctx, x): + return x * 2 + + @staticmethod + def backward(ctx, grads): + raise RuntimeError("intentional failure") + + x = torch.randn(4, requires_grad=True) + out = FailingBwd.apply(x) + with self.assertRaisesRegex(RuntimeError, "intentional failure"): + out.sum().backward() + + @skipIfTorchDynamo("boxed_grads_call is incompatible with compiled autograd") + def test_custom_function_boxed_grads_chain(self): + """Two boxed-grads functions chained — each gets its own grads list.""" + + call_order = [] + + class Mul2(torch.autograd.Function): + boxed_grads_call = True + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x * 2 + + @staticmethod + def backward(ctx, grads): + call_order.append("Mul2") + grad = grads[0] + grads.clear() + (x,) = ctx.saved_tensors + return grad * 2 + + class Mul3(torch.autograd.Function): + boxed_grads_call = True + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x * 3 + + @staticmethod + def backward(ctx, grads): + call_order.append("Mul3") + grad = grads[0] + grads.clear() + (x,) = ctx.saved_tensors + return grad * 3 + + x = torch.randn(4, requires_grad=True) + out = Mul3.apply(Mul2.apply(x)) + out.sum().backward() + # backward order: Mul3 then Mul2 + self.assertEqual(call_order, ["Mul3", "Mul2"]) + # d/dx (3 * 2 * x).sum() = grad 1 → Mul3 bwd: 1*3=3 → Mul2 bwd: 3*2=6 + self.assertEqual(x.grad, torch.full_like(x, 6.0)) + + @skipIfTorchDynamo("boxed_grads_call is incompatible with compiled autograd") + def test_custom_function_boxed_grads_none_grads(self): + """Boxed grads with materialize_grads=False and partial None grads. + + When only some outputs are used in the loss and materialize_grads + is False, the grads list contains None for unused outputs. Verify + that boxed_grads_call handles this correctly.""" + + class TwoOutput(torch.autograd.Function): + boxed_grads_call = True + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + ctx.set_materialize_grads(False) + return x * 2, x * 3 + + @staticmethod + def backward(ctx, grads): + self.assertIsInstance(grads, list) + self.assertEqual(len(grads), 2) + # Only second output is used in loss, first grad is None + self.assertIsNone(grads[0]) + self.assertIsNotNone(grads[1]) + (x,) = ctx.saved_tensors + g2 = grads[1] + grads[1] = None + return g2 * 3 + + x = torch.randn(4, requires_grad=True) + a, b = TwoOutput.apply(x) + # Only use b in the loss — a gets no gradient + b.sum().backward() + # forward: b=3x; backward: grad_b=1, return 1*3=3 + self.assertEqual(x.grad, torch.full_like(x, 3.0)) + + @skipIfTorchDynamo("boxed_grads_call is incompatible with compiled autograd") + def test_custom_function_boxed_grads_materialize_grads(self): + """boxed_grads_call works with materialize_grads(True). + + When materialize_grads is True (default) and some outputs have no + gradient, the engine materializes zero tensors. Verify boxed_grads_call + delivers these materialized zeros in the grads list.""" + + class TwoOutputMaterialized(torch.autograd.Function): + boxed_grads_call = True + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x * 2, x * 3 + + @staticmethod + def backward(ctx, grads): + self.assertIsInstance(grads, list) + self.assertEqual(len(grads), 2) + # materialize_grads is True (default), so grads[0] is a zero + # tensor even though first output wasn't used in the loss + self.assertIsNotNone(grads[0]) + self.assertEqual(grads[0], torch.zeros_like(grads[1])) + self.assertIsNotNone(grads[1]) + (x,) = ctx.saved_tensors + g1, g2 = grads[0], grads[1] + grads.clear() + return g1 * 2 + g2 * 3 + + x = torch.randn(4, requires_grad=True) + a, b = TwoOutputMaterialized.apply(x) + # Only use b in the loss — a gets a materialized zero grad + b.sum().backward() + # forward: b=3x; backward: g1=0, g2=1, return 0*2 + 1*3 = 3 + self.assertEqual(x.grad, torch.full_like(x, 3.0)) + + @skipIfTorchDynamo("boxed_grads_call is incompatible with compiled autograd") + def test_custom_function_boxed_grads_direct_apply(self): + """Test both paths into backward with boxed_grads_call. + + Path 1 (C++ engine): .backward() goes through C++ PyNode::apply, + which calls apply_boxed with grads in a mutable list. + + Path 2 (direct grad_fn.apply()): bypasses C++, apply() boxes + grads into a list before calling backward.""" + + received_grads = [] + + class BoxedMultiOut(torch.autograd.Function): + boxed_grads_call = True + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return x * 2, x * 3 + + @staticmethod + def backward(ctx, grads): + received_grads.append(grads) + (x,) = ctx.saved_tensors + return grads[0] * 2 + grads[1] * 3 + + # Path 1: .backward() — C++ engine calls apply_boxed + x1 = torch.randn(4, requires_grad=True) + a1, b1 = BoxedMultiOut.apply(x1) + (a1.sum() + b1.sum()).backward() + self.assertIsInstance(received_grads[0], list) + self.assertEqual(len(received_grads[0]), 2) + self.assertEqual(x1.grad, torch.full_like(x1, 5.0)) + + # Path 2: grad_fn.apply() — apply() boxes grads + received_grads.clear() + x2 = torch.randn(4, requires_grad=True) + a2, b2 = BoxedMultiOut.apply(x2) + result = a2.grad_fn.apply(torch.ones(4), torch.ones(4) * 2) + self.assertIsInstance(received_grads[0], list) + self.assertEqual(len(received_grads[0]), 2) + # return 1*2 + 2*3 = 8 + self.assertEqual(result, torch.full_like(x2, 8.0)) + + @skipIfTorchDynamo("boxed_grads_call is incompatible with compiled autograd") + def test_custom_function_boxed_grads_single_list_arg(self): + """A plain list passed via grad_fn.apply() gets boxed into + a list wrapping it — the user's list becomes an element.""" + + received_grads = [] + + class SingleOut(torch.autograd.Function): + boxed_grads_call = True + + @staticmethod + def forward(ctx, x): + return x * 2 + + @staticmethod + def backward(ctx, grads): + received_grads.append(grads) + return grads[0] + + x = torch.randn(4, requires_grad=True) + out = SingleOut.apply(x) + + # grad_fn.apply() with a list — apply() boxes it, so backward + # receives [user_list] (the user's list is an element, not the grads) + user_list = [torch.ones(4)] + out.grad_fn.apply(user_list) + self.assertIsInstance(received_grads[0], list) + self.assertEqual(len(received_grads[0]), 1) + self.assertIs(received_grads[0][0], user_list) + @skipIfTorchDynamo("dynamo accesses saved_tensors multiple times") def test_clear_saved_tensors_on_access(self): class MyFn(Function): @@ -8769,9 +9170,7 @@ def maybe_check_raise(fn, should_raise): is_view=True, should_raise_tuple=(None, None, None), ) - inp_change_err = ( - "Output {} of UnbindBackward0 is a view and is being modified inplace." - ) + inp_change_err = "Output {} of Unbind is a view and is being modified inplace." run_test( grad_mode=True, requires_grad=True, @@ -8918,17 +9317,17 @@ def backward(ctx, grad): fn_id_to_inplace_on_view_err_msg = { "one_output": ( - "Output 0 of IdOneOutputBackward is a view and is being " + "Output 0 of IdOneOutput is a view and is being " "modified inplace. This view was created inside a custom Function" ), "two_output": ( - "Output 0 of IdTwoOutputBackward is a view and is being modified inplace." + "Output 0 of IdTwoOutput is a view and is being modified inplace." " This view is the output of a function that returns multiple views.", "Pure view custom Function can only have one input Tensor and one output Tensor." " Open an issue if you need to support more.", ), "view_of_temp": ( - "Output 0 of ViewOfTempBackward is a view and is being " + "Output 0 of ViewOfTemp is a view and is being " "modified inplace. This view was created inside a custom Function", "a view of a leaf Variable that requires grad is being used in an in-place operation", ), @@ -9228,7 +9627,7 @@ def backward(ctx, grad): out = ComplexView.apply(a.clone(), idx) with self.assertRaisesRegex( RuntimeError, - "Output 0 of ComplexViewBackward is a view and is being modified inplace", + "Output 0 of ComplexView is a view and is being modified inplace", ): out += 1 @@ -10845,25 +11244,31 @@ def test(get_input, cuda, pin_memory): ) test(lambda: x, cuda, pin_memory) - @unittest.skipIf(not TEST_CUDA, "test requires CUDA") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "test requires CUDA or XPU") def test_graph_save_on_cpu_cuda(self): + device_type = torch.accelerator.current_accelerator().type + def f(x): a = x + 1 return a * a # with grad - a = torch.ones(1, requires_grad=True, device="cuda") + a = torch.ones(1, requires_grad=True, device=device_type) y = f(a) - memory_with_grad = torch.cuda.memory_allocated() + memory_with_grad = ( + torch.cuda.memory_allocated() if TEST_CUDA else torch.xpu.memory_allocated() + ) del a del y # without grad - a = torch.ones(1, requires_grad=True, device="cuda") + a = torch.ones(1, requires_grad=True, device=device_type) with torch.no_grad(): y = f(a) - memory_without_grad = torch.cuda.memory_allocated() + memory_without_grad = ( + torch.cuda.memory_allocated() if TEST_CUDA else torch.xpu.memory_allocated() + ) self.assertGreater(memory_with_grad, memory_without_grad) @@ -10872,15 +11277,20 @@ def f(x): # with hooks with torch.autograd.graph.save_on_cpu(): - a = torch.ones(1, requires_grad=True, device="cuda") + a = torch.ones(1, requires_grad=True, device=device_type) y = f(a) - memory_with_hooks = torch.cuda.memory_allocated() + memory_with_hooks = ( + torch.cuda.memory_allocated() + if TEST_CUDA + else torch.xpu.memory_allocated() + ) self.assertEqual(memory_with_hooks, memory_without_grad) - @unittest.skipIf(not TEST_CUDA, "test requires CUDA") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "test requires CUDA and XPU") def test_scalar_grad_mixed_device(self): + device_type = torch.accelerator.current_accelerator().type x = torch.tensor(1.0, requires_grad=True) - y = torch.randn(2, 2, device="cuda") + y = torch.randn(2, 2, device=device_type) out = x * y out.sum().backward() @@ -12214,24 +12624,58 @@ def fn(): # Generic device type autograd tests. class TestAutogradDeviceType(TestCase): - def test_min_max_median_backprops_to_all_values(self, device): + def test_min_max_aminmax_median_backprops_to_all_values(self, device): + # 1) Test min/max/median/nanmedian on both a non NaN and all NaN tensor for f in [torch.min, torch.max, torch.median, torch.nanmedian]: - x1 = torch.tensor( - [1.0, 0.0, 1.0, 0.0, 1.0, 0.0], device=device, requires_grad=True - ) - x2 = torch.tensor( - [float("nan"), float("nan"), float("nan")], requires_grad=True - ) - for x in [x1, x2]: - y = f(x) + with self.subTest(f=f): + x1 = torch.tensor( + [1.0, 0.0, 1.0, 0.0, 1.0, 0.0], device=device, requires_grad=True + ) + x2 = torch.tensor( + [float("nan"), float("nan"), float("nan")], + device=device, + requires_grad=True, + ) + for x in [x1, x2]: + y = f(x) + y.backward() + self.assertEqual(x.grad.sum(), 1.0) + self.assertEqual((x.grad == 1 / 3).sum(), 3) + + # 2) Explicit amin/amax plus the two components of aminmax + def amin2(x): + return torch.aminmax(x)[0] # min part + + def amax2(x): + return torch.aminmax(x)[1] + + for f in [torch.amin, torch.amax, amax2, amin2]: + with self.subTest(f=f): + x1 = torch.tensor( + [1.0, 0.0, 1.0, 0.0, 1.0, 0.0], + device=device, + requires_grad=True, + ) + y = f(x1) y.backward() - self.assertEqual(x.grad.sum(), 1.0) - self.assertEqual((x.grad == 1 / 3).sum(), 3) - - def test_scatter_index_reduce_amin_amax_backprops_to_all_values(self, device): + self.assertEqual(x1.grad.sum(), 1.0) + self.assertEqual((x1.grad == 1.0 / 3.0).sum(), 3) + + # 3) Both min and max grads active simultaneously — exercises the add_ path + # in aminmax_backward when both grad_min and grad_max are defined. + # min ties at indices 1,3 → each gets 0.5; max ties at indices 0,2 → each gets 0.5 + with self.subTest("aminmax_both_grads"): + x = torch.tensor([3.0, 1.0, 3.0, 1.0], device=device, requires_grad=True) + min_val, max_val = torch.aminmax(x) + (min_val + max_val).backward() + self.assertEqual(x.grad, torch.tensor([0.5, 0.5, 0.5, 0.5], device=device)) + + def test_scatter_index_reduce_amin_amax_aminmax_backprops_to_all_values( + self, device + ): # tests that gradients are evenly distributed when there are multiple max/min values # tested here instead of adding a SampleInput as the backward for this case is non-differentiable for gradgrad - # as is the case for test_min_max_median_backprops_to_all_values above + # as is the case for test_min_max_aminmax_median_backprops_to_all_values above fns = (torch.scatter_reduce, torch.index_reduce) reduces = ("amin", "amax") for fn, reduction in product(fns, reduces): @@ -15117,6 +15561,74 @@ def fn2(x): self.assertEqual(counter[0], 1) +@contextlib.contextmanager +def _counter_op(name): + """Yields (op, counts, idx_log) where op is a custom op that counts + invocations. idx_log maps call_idx (passed by caller) to replay count.""" + counts = [0] + idx_log: dict = {} + with torch.library._scoped_library("test_ckpt", "FRAGMENT"): + + @torch.library.custom_op(f"test_ckpt::{name}", mutates_args=()) + def op(x: torch.Tensor, call_idx: int) -> torch.Tensor: + counts[0] += 1 + idx_log[call_idx] = idx_log.get(call_idx, 0) + 1 + return x.sin() + + def setup_context(ctx, inputs, output): + ctx.save_for_backward(inputs[0]) + + def backward(ctx, grad): + (x,) = ctx.saved_tensors + return grad * x.cos(), None + + op.register_autograd(backward, setup_context=setup_context) + + yield op, counts, idx_log + + +class _AutoNamingMode(TorchDispatchMode): + """Test helper: names output tensors as ``fqn_op_count[_outputidx]``.""" + + def __init__(self): + from torch.utils.module_tracker import ModuleTracker + + self._tracker = ModuleTracker() + self._func_counter: dict = defaultdict(int) + self.names = WeakTensorKeyDictionary() + + def __enter__(self): + self._tracker.__enter__() + return super().__enter__() + + def __exit__(self, *args): + self._tracker.__exit__(*args) + return super().__exit__(*args) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + out = func(*args, **(kwargs or {})) + parents = self._tracker.parents - {"Global"} + fqn = max(parents, key=len) if parents else "Global" + op_name = func.__name__ if hasattr(func, "__name__") else str(func) + key = (fqn, func) + count = self._func_counter[key] + self._func_counter[key] += 1 + multi_output = ( + isinstance(out, (tuple, list)) + and sum(isinstance(o, torch.Tensor) for o in out) > 1 + ) + if isinstance(out, torch.Tensor): + self.names[out] = f"{fqn}_{op_name}_{count}" + elif isinstance(out, (tuple, list)): + for i, o in enumerate(out): + if isinstance(o, torch.Tensor): + name = f"{fqn}_{op_name}_{count}" + if multi_output: + name += f"_{i}" + self.names[o] = name + return out + + class TestSelectiveActivationCheckpoint(TestCase): @unittest.skipIf(not TEST_CUDA, "requires CUDA") def test_flops_and_mem(self): @@ -15536,6 +16048,83 @@ def fn(x): with self.assertRaisesRegex(RuntimeError, "Trying to backward an extra time"): out.sum().backward(retain_graph=True) + @skipIfTorchDynamo("torch dispatch modes don't support compile") + def test_auto_naming_mode_names(self): + with _counter_op("my_op") as (my_op, my_count, idx_log): + + class Block(torch.nn.Module): + def forward(self, x, counter): + x = my_op(x, counter[0]) + counter[0] += 1 + x = my_op(x, counter[0]) + counter[0] += 1 + x = my_op(x, counter[0]) + counter[0] += 1 + return x + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.layers = torch.nn.ModuleList([Block(), Block()]) + + def forward(self, x): + counter = [0] + for layer in self.layers: + x = layer(x, counter) + return x + + mod = Model() + naming = _AutoNamingMode() + + save_names = { + "Model.layers.0_my_op.default_1", + "Model.layers.1_my_op.default_0", + "Model.layers.1_my_op.default_2", + } + + fwd_decisions: list = [] + fwd_idx = [0] + + def policy_fn(ctx, op, *args, **kwargs): + if ctx.is_recompute: + decision = fwd_decisions[fwd_idx[0]] + fwd_idx[0] += 1 + return decision + out = ctx.op_output + decision = CheckpointPolicy.PREFER_RECOMPUTE + if isinstance(out, torch.Tensor): + name = naming.names.get(out) + if name in save_names: + decision = CheckpointPolicy.MUST_SAVE + fwd_decisions.append(decision) + return decision + + x = torch.randn(4, requires_grad=True) + context_fn = functools.partial( + create_selective_checkpoint_contexts, policy_fn + ) + with naming: + out = checkpoint( + lambda x: mod(x), + x, + use_reentrant=False, + context_fn=context_fn, + ) + out.sum().backward() + + self.assertEqual( + idx_log, + { + 0: 2, # Model.layers.0_my_op.default_0 -> recomputed + 1: 1, # Model.layers.0_my_op.default_1 -> saved + 2: 2, # Model.layers.0_my_op.default_2 -> recomputed + 3: 1, # Model.layers.1_my_op.default_0 -> saved + 4: 2, # Model.layers.1_my_op.default_1 -> recomputed + 5: 1, # Model.layers.1_my_op.default_2 -> saved + }, + ) + self.assertEqual(my_count[0], 9) + class TestAutogradMultipleDispatch(TestCase): def test_autograd_multiple_dispatch_registrations(self, device): diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 29dff031d6e5a..28056c5b396bc 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -4049,7 +4049,7 @@ def test_dx(sizes, dim, dx, device): t = torch.randn(sizes, device=device) actual = torch.trapezoid(t, dx=dx, dim=dim) if int(np.__version__.split(".")[0]) >= 2: - expected = np.trapezoid(t.cpu().numpy(), dx=dx, axis=dim) # noqa: NPY201 + expected = np.trapezoid(t.cpu().numpy(), dx=dx, axis=dim) else: expected = np.trapz(t.cpu().numpy(), dx=dx, axis=dim) # noqa: NPY201 self.assertEqual(expected.shape, actual.shape) @@ -4059,7 +4059,7 @@ def test_x(sizes, dim, x, device): t = torch.randn(sizes, device=device) actual = torch.trapezoid(t, x=torch.tensor(x, device=device), dim=dim) if int(np.__version__.split(".")[0]) >= 2: - expected = np.trapezoid(t.cpu().numpy(), x=x, axis=dim) # noqa: NPY201 + expected = np.trapezoid(t.cpu().numpy(), x=x, axis=dim) else: expected = np.trapz(t.cpu().numpy(), x=x, axis=dim) # noqa: NPY201 self.assertEqual(expected.shape, actual.shape) diff --git a/test/test_bmm_outer_product.py b/test/test_bmm_outer_product.py new file mode 100644 index 0000000000000..88742ba0708df --- /dev/null +++ b/test/test_bmm_outer_product.py @@ -0,0 +1,103 @@ +# Owner(s): ["module: nn"] + +import unittest + +import torch +from torch._native.ops.bmm_outer_product.triton_impl import _is_outer_product +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU + + +@unittest.skipIf(not HAS_GPU, "requires GPU") +class TestBmmOuterProduct(TestCase): + def _check_bmm(self, a, b, **kwargs): + self.assertEqual(torch.bmm(a, b), a @ b, **kwargs) + + def test_shapes(self): + shapes = [ + (4, 8, 16), + (32, 8, 256), + (16, 128, 512), + (1, 64, 128), + (8, 1, 1), + (64, 256, 512), + (256, 8, 2048), + ] + for B, M, N in shapes: + with self.subTest(B=B, M=M, N=N): + a = torch.randn(B, M, 1, device=GPU_TYPE) + b = torch.randn(B, 1, N, device=GPU_TYPE) + self._check_bmm(a, b) + + def test_basic_dtypes(self): + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + with self.subTest(dtype=dtype): + a = torch.randn(4, 8, 1, device=GPU_TYPE, dtype=dtype) + b = torch.randn(4, 1, 16, device=GPU_TYPE, dtype=dtype) + self.assertEqual(torch.bmm(a, b), a @ b) + + def test_permuted_inputs(self): + B, M, N = 4, 8, 16 + cases = [ + ( + torch.randn(M, B, 1, device=GPU_TYPE).permute(1, 0, 2), + torch.randn(B, 1, N, device=GPU_TYPE), + ), + ( + torch.randn(B, M, 1, device=GPU_TYPE), + torch.randn(N, B, 1, device=GPU_TYPE).permute(1, 2, 0), + ), + ( + torch.randn(M, B, 1, device=GPU_TYPE).permute(1, 0, 2), + torch.randn(N, B, 1, device=GPU_TYPE).permute(1, 2, 0), + ), + ] + for a, b in cases: + self.assertEqual(torch.bmm(a, b), a @ b) + + def test_fallback_non_outer_product(self): + a = torch.randn(4, 8, 16, device=GPU_TYPE) + b = torch.randn(4, 16, 32, device=GPU_TYPE) + self.assertEqual(torch.bmm(a, b), a @ b, atol=1e-5, rtol=1.3e-6) + + def test_batch_one(self): + a = torch.randn(1, 64, 1, device=GPU_TYPE) + b = torch.randn(1, 1, 128, device=GPU_TYPE) + self.assertEqual(torch.bmm(a, b), a @ b) + + def test_m_one_n_one(self): + a = torch.randn(8, 1, 1, device=GPU_TYPE) + b = torch.randn(8, 1, 1, device=GPU_TYPE) + self.assertEqual(torch.bmm(a, b), a @ b) + + def test_gradient_flow(self): + a = torch.randn(4, 8, 1, device=GPU_TYPE, requires_grad=True) + b = torch.randn(4, 1, 16, device=GPU_TYPE, requires_grad=True) + result = torch.bmm(a, b) + result.sum().backward() + self.assertIsNotNone(a.grad) + self.assertIsNotNone(b.grad) + self.assertEqual(a.grad.shape, a.shape) + self.assertEqual(b.grad.shape, b.shape) + + +class TestOuterProductDetection(TestCase): + def test_is_outer_product(self): + self.assertTrue(_is_outer_product(torch.empty(4, 8, 1), torch.empty(4, 1, 16))) + self.assertTrue(_is_outer_product(torch.empty(4, 8, 1), torch.empty(4, 1, 1))) + self.assertFalse( + _is_outer_product(torch.empty(4, 8, 16), torch.empty(4, 16, 32)) + ) + self.assertFalse(_is_outer_product(torch.empty(8, 1), torch.empty(1, 16))) + self.assertFalse(_is_outer_product(torch.empty(4, 8, 1), torch.empty(4, 2, 16))) + self.assertFalse(_is_outer_product(torch.empty(4, 8, 3), torch.empty(4, 1, 16))) + self.assertFalse( + _is_outer_product( + torch.empty(4, 8, 1, dtype=torch.complex64), + torch.empty(4, 1, 16, dtype=torch.complex64), + ) + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_cpp_extensions_jit.py b/test/test_cpp_extensions_jit.py index 2c57587f757dc..db640fc9eb9ce 100644 --- a/test/test_cpp_extensions_jit.py +++ b/test/test_cpp_extensions_jit.py @@ -1552,7 +1552,7 @@ def test_torch_check_eq_stacktrace(self): self.assertIn( "C++ CapturedTraceback:", error_message, - f"Expected C++ stack trace info in error message when TORCH_SHOW_CPP_STACKTRACES=1, got: {error_message}", # noqa: B950 + f"Expected C++ stack trace info in error message when TORCH_SHOW_CPP_STACKTRACES=1, got: {error_message}", ) self.assertRegex( error_message, @@ -1562,7 +1562,7 @@ def test_torch_check_eq_stacktrace(self): self.assertNotIn( "C++ CapturedTraceback:", error_message, - f"Did not expect 'C++ CapturedTraceback:' in error message when TORCH_SHOW_CPP_STACKTRACES=0, got: {error_message}", # noqa: B950 + f"Did not expect 'C++ CapturedTraceback:' in error message when TORCH_SHOW_CPP_STACKTRACES=0, got: {error_message}", ) diff --git a/test/test_cuda.py b/test/test_cuda.py index 55acd8de897f7..acfce2b8821e2 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -3,6 +3,7 @@ import contextlib import ctypes +import functools import gc import json import os @@ -38,7 +39,9 @@ _get_torch_cuda_version, blas_library_context, PLATFORM_SUPPORTS_GREEN_CONTEXT, + PLATFORM_SUPPORTS_WORKQUEUE_CONFIG, SM70OrLater, + SM89OrLater, TEST_CUDNN, TEST_MULTIGPU, tf32_on_and_off, @@ -63,6 +66,7 @@ gcIfJetson, get_cycles_per_ms, instantiate_parametrized_tests, + IS_ARM64, IS_FBCODE, IS_JETSON, IS_LINUX, @@ -72,7 +76,6 @@ load_tests, MI200_ARCH, MI300_ARCH, - MI350_ARCH, parametrize, recover_orig_fp32_precision, run_tests, @@ -98,7 +101,7 @@ requiresCppContext = unittest.skipUnless( - IS_X86 and IS_LINUX, "cpp contexts are x86 linux only" + (IS_X86 or IS_ARM64) and IS_LINUX, "cpp contexts are linux x86/aarch64 only" ) # load_tests from common_utils is used to automatically filter tests for @@ -106,8 +109,9 @@ load_tests = load_tests # noqa: PLW0127 try: - # import torchvision.models # noqa: F401 - # from torchvision.models import resnet18 # noqa: F401 + import torchvision.models # noqa: F401 + + # from torchvision.models import resnet18 HAS_TORCHVISION = True except ImportError: @@ -133,6 +137,16 @@ _wait_for_cpu_kernel = None +def skip_background_threads_on_windows(f): + @functools.wraps(f) + def wrapped(self, **kwargs): + if IS_WINDOWS and SM89OrLater and kwargs.get("use_background_threads"): + raise unittest.SkipTest("using background threads fails on Windows") + return f(self, **kwargs) + + return wrapped + + def get_wait_for_cpu_kernel(): """Returns a compiled CUDA spin-wait kernel that blocks the GPU stream until the host sets a pinned int32 flag to non-zero. Requires SM70+. @@ -324,6 +338,9 @@ def test_pinned_memory_empty_cache(self): "pinned_use_cuda_host_register:False" ) + # Pinned allocator background thread does not shut down cleanly on Windows + # Python process hangs + @unittest.skipIf(IS_WINDOWS and SM89OrLater, "Fails on windows with SM89+") def test_pinned_memory_use_background_threads(self): script = """ import torch @@ -363,6 +380,10 @@ def test_memory_allocation(self): torch.cuda.caching_allocator_delete(mem) self.assertEqual(torch.cuda.memory_allocated(), prev) + def test_caching_allocator_alloc_negative_size(self): + with self.assertRaisesRegex(ValueError, "Invalid memory size"): + torch.cuda.memory.caching_allocator_alloc(-1024) + def test_memory_stats(self): gc.collect() torch.cuda.empty_cache() @@ -466,6 +487,9 @@ def test_out_of_memory(self): tensor.fill_(1) self.assertTrue((tensor == 1).all()) + # CUDA memory allocations on windows do not OOM on rtx even when they cross allowed memory + # Skip test until this is investigated + @unittest.skipIf(IS_WINDOWS and SM89OrLater, "Fails on windows with SM89+") @unittest.skipIf( TEST_CUDAMALLOCASYNC or IS_JETSON, "Segmentation fault (core dumped)" ) @@ -647,13 +671,17 @@ def test_serialization_array_with_storage(self): q_copy[1].fill_(10) self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10)) + @unittest.skipIf( + IS_WINDOWS and SM89OrLater, "preferred_blas_library not supported on Windows" + ) @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Does not work in fbcode yet") @setBlasBackendsToDefaultFinally def test_preferred_blas_library_settings(self): def _check_default(): default = torch.backends.cuda.preferred_blas_library() if torch.version.cuda: - self.assertTrue(default == torch._C._BlasBackend.Cublaslt) + # CUDA logic is easy, it's always cublas + self.assertTrue(default == torch._C._BlasBackend.Cublas) else: # ROCm logic is less so, it's cublaslt for some Instinct, cublas for all else gcn_arch = str( @@ -715,6 +743,9 @@ def _check_default(): torch.backends.cuda.preferred_blas_library("default") _check_default() + @unittest.skipIf( + IS_WINDOWS and SM89OrLater, "preferred_blas_library not supported on Windows" + ) @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async") @serialTest() @blas_library_context("cublas") @@ -745,19 +776,22 @@ def check_workspace_size(inp): return finish - start # check default - os.environ["CUBLAS_WORKSPACE_CONFIG"] = "" - self.assertLess(check_workspace_size(a) - default_workspace_size, 524288) self.assertLess(abs(check_workspace_size(a) - default_workspace_size), 524288) - # check default with bad user config - os.environ["CUBLAS_WORKSPACE_CONFIG"] = "-1" - self.assertLess(check_workspace_size(a) - default_workspace_size, 524288) - self.assertLess(abs(check_workspace_size(a) - default_workspace_size), 524288) + # check explicit size via API + explicit_size = 3072 * 1024 + torch.backends.cuda.cublas_workspace_size(explicit_size) + self.assertLess(abs(check_workspace_size(a) - explicit_size), 524288) - # check valid config - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":128:8:64:16:32:32" - self.assertLess(check_workspace_size(a) - (3072 * 1024), 524288) - self.assertLess(abs(check_workspace_size(a) - (3072 * 1024)), 524288) + # check invalid size rejected + with self.assertRaisesRegex( + RuntimeError, "cublas workspace size must be non-negative" + ): + torch.backends.cuda.cublas_workspace_size(-1) + + # restore default + torch._C._cuda_resetCublasWorkspaceSize() + self.assertLess(abs(check_workspace_size(a) - default_workspace_size), 524288) torch._C._cuda_clearCublasWorkspaces() @@ -781,6 +815,118 @@ def test_cublas_unified_workspace(self): # switching to Lt, otherwise the temporary allocation would bump the peak self.assertEqual(warmed_alloc, lt_alloc) + @setBlasBackendsToDefaultFinally + def test_cublas_workspace_size_api(self): + # Test getter returns a positive default + original_size = torch.backends.cuda.cublas_workspace_size() + self.assertGreater(original_size, 0) + + original_lt_size = torch.backends.cuda.cublaslt_workspace_size() + self.assertGreater(original_lt_size, 0) + + # Test setter changes the value and getter reflects it + new_size = 64 * 1024 * 1024 # 64 MiB + result = torch.backends.cuda.cublas_workspace_size(new_size) + self.assertEqual(result, new_size) + self.assertEqual(torch.backends.cuda.cublas_workspace_size(), new_size) + + new_lt_size = 2 * 1024 * 1024 # 2 MiB + result_lt = torch.backends.cuda.cublaslt_workspace_size(new_lt_size) + self.assertEqual(result_lt, new_lt_size) + self.assertEqual(torch.backends.cuda.cublaslt_workspace_size(), new_lt_size) + + # Test validation rejects negative values + with self.assertRaisesRegex( + RuntimeError, "cublas workspace size must be non-negative" + ): + torch.backends.cuda.cublas_workspace_size(-1) + with self.assertRaisesRegex( + RuntimeError, "cublaslt workspace size must be non-negative" + ): + torch.backends.cuda.cublaslt_workspace_size(-1) + + @setBlasBackendsToDefaultFinally + def test_blas_workspace_size_api(self): + # Dispatches to the correct backend based on preferred_blas_library() + pref = torch.backends.cuda.preferred_blas_library() + if pref == torch._C._BlasBackend.Cublaslt: + expected = torch.backends.cuda.cublaslt_workspace_size() + else: + # Default and Cublas both map to cuBLAS + expected = torch.backends.cuda.cublas_workspace_size() + self.assertEqual(torch.backends.cuda.blas_workspace_size(), expected) + + # Explicit backend= parameter + self.assertEqual( + torch.backends.cuda.blas_workspace_size(backend="cublas"), + torch.backends.cuda.cublas_workspace_size(), + ) + self.assertEqual( + torch.backends.cuda.blas_workspace_size(backend="cublaslt"), + torch.backends.cuda.cublaslt_workspace_size(), + ) + self.assertEqual( + torch.backends.cuda.blas_workspace_size( + backend=torch._C._BlasBackend.Cublas + ), + torch.backends.cuda.cublas_workspace_size(), + ) + + # Setting via the dispatcher updates the underlying function + new_size = 16 * 1024 * 1024 + torch.backends.cuda.blas_workspace_size(new_size, backend="cublas") + self.assertEqual(torch.backends.cuda.cublas_workspace_size(), new_size) + + torch.backends.cuda.blas_workspace_size(new_size, backend="cublaslt") + self.assertEqual(torch.backends.cuda.cublaslt_workspace_size(), new_size) + + # CK backend has no workspace + with self.assertRaisesRegex( + RuntimeError, "CK backend does not use a workspace." + ): + torch.backends.cuda.blas_workspace_size(backend="ck") + + # Invalid string + with self.assertRaisesRegex( + RuntimeError, + "Unknown backend string. Choose from: default, cublas, hipblas, cublaslt, hipblaslt, ck.", + ): + torch.backends.cuda.blas_workspace_size(backend="invalid") + + # Invalid type + with self.assertRaisesRegex(RuntimeError, "Unknown backend type."): + torch.backends.cuda.blas_workspace_size(backend=42) + + @unittest.skipIf(TEST_CUDAMALLOCASYNC, "temporarily disabled for async") + @setBlasBackendsToDefaultFinally + def test_cublas_workspace_lazy_reallocation(self): + torch.backends.cuda.preferred_blas_library("cublas") + + original_size = torch.backends.cuda.cublas_workspace_size() + torch._C._cuda_clearCublasWorkspaces() + + # Trigger initial allocation with matmul + a = torch.randn(7, 7, device="cuda", requires_grad=False) + with torch.no_grad(): + torch.matmul(a, a) + + mem_after_first = torch.cuda.memory_stats()["active_bytes.all.allocated"] + + # Increase workspace size + bigger_size = original_size + 32 * 1024 * 1024 # +32 MiB + torch.backends.cuda.cublas_workspace_size(bigger_size) + + # No immediate memory change (lazy reallocation) + mem_after_set = torch.cuda.memory_stats()["active_bytes.all.allocated"] + self.assertEqual(mem_after_first, mem_after_set) + + # Next matmul triggers reallocation + with torch.no_grad(): + torch.matmul(a, a) + + mem_after_realloc = torch.cuda.memory_stats()["active_bytes.all.allocated"] + self.assertGreater(mem_after_realloc, mem_after_first) + def test_cublas_allow_tf32_get_set(self): skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int( os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"] @@ -898,6 +1044,35 @@ def test_cudnn_allow_tf32_get_set(self): ): self.assertTrue(torch.backends.cudnn.allow_tf32) + def test_cudnn_depthwise_kernel_get_set(self): + self.assertEqual(torch.backends.cudnn.depthwise_kernel, "auto") + + # Test all valid values via the flags() context manager + for mode in ("auto", "cudnn", "native"): + with torch.backends.cudnn.flags( + enabled=None, + benchmark=None, + deterministic=None, + allow_tf32=None, + depthwise_kernel=mode, + ): + self.assertEqual(torch.backends.cudnn.depthwise_kernel, mode) + + # Invalid value should raise + with self.assertRaises(RuntimeError): + torch._C._set_cudnn_depthwise_kernel("invalid") + + # Verify the flags() context manager restores the previous value + with torch.backends.cudnn.flags( + enabled=None, + benchmark=None, + deterministic=None, + allow_tf32=None, + depthwise_kernel="native", + ): + self.assertEqual(torch.backends.cudnn.depthwise_kernel, "native") + self.assertEqual(torch.backends.cudnn.depthwise_kernel, "auto") + @recover_orig_fp32_precision def test_fp32_precision_with_tf32(self): with torch.backends.cudnn.flags( @@ -1176,6 +1351,7 @@ def test_generic_stream_event(self): self.assertTrue(issubclass(type(cuda_event), torch.Event)) self.assertTrue(torch.Event in type(cuda_event).mro()) + @unittest.skip("Fails with Triton 3.7") def test_stream_event_compatibility(self): s1 = torch.cuda.Stream() s2 = torch.cuda.Stream() @@ -2173,8 +2349,10 @@ def test_graph_is_current_stream_capturing(self): with torch.cuda.stream(s): g = torch.cuda.CUDAGraph() self.assertFalse(torch.cuda.is_current_stream_capturing()) + self.assertFalse(s.is_capturing()) g.capture_begin() self.assertTrue(torch.cuda.is_current_stream_capturing()) + self.assertTrue(s.is_capturing()) g.capture_end() @unittest.skipIf( @@ -2277,6 +2455,100 @@ def get_final_offsets_of_states(generator_state): # Compare the states generated outside and inside the graph self.assertEqual(random_values, graphed_random_values) + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_rng_after_failed_capture(self): + """Test that a stream can be captured again for RNG after a failed capture.""" + if TEST_WITH_ROCM and self.expandable_segments: + self.skipTest( + "ROCm expandable segments has known issue with graph capture recovery - #179911" + ) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + x = torch.ones(1, device="cuda") + + with torch.cuda.stream(stream): + graph.capture_begin() + with self.assertRaises(RuntimeError): + (x + 1).item() + with self.assertRaises(RuntimeError): + graph.capture_end() + + torch.cuda.current_stream().wait_stream(stream) + + result = torch.randn(4, device="cuda") + self.assertEqual(result.shape, (4,)) + + new_graph = torch.cuda.CUDAGraph() + buf = torch.empty(4, device="cuda") + with torch.cuda.stream(stream): + new_graph.capture_begin() + buf.copy_(torch.randn_like(buf)) + new_graph.capture_end() + torch.cuda.current_stream().wait_stream(stream) + buf.zero_() + new_graph.replay() + torch.cuda.synchronize() + self.assertFalse(torch.allclose(buf, torch.zeros_like(buf))) + + @unittest.skipIf( + not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" + ) + def test_graph_rng_concurrent_replay_on_different_streams(self): + """Concurrent replay of two graphs sharing a generator on different streams. + + With per-generator state (old code), this would race on the shared + rng_state_offset_extragraph_ tensor. With per-(generator, capture_id) + state, each graph has its own tensors. + """ + seed = 1234 + shape = (64,) + + torch.manual_seed(seed) + ref0 = torch.randn(shape, device="cuda") + ref1 = torch.randn(shape, device="cuda") + + torch.manual_seed(seed) + g0 = torch.cuda.CUDAGraph() + g1 = torch.cuda.CUDAGraph() + s_cap = torch.cuda.Stream() + buf0 = torch.empty(shape, device="cuda") + buf1 = torch.empty(shape, device="cuda") + + with torch.cuda.stream(s_cap): + g0.capture_begin() + buf0.copy_(torch.randn_like(buf0)) + g0.capture_end() + + torch.cuda.current_stream().wait_stream(s_cap) + + with torch.cuda.stream(s_cap): + g1.capture_begin() + buf1.copy_(torch.randn_like(buf1)) + g1.capture_end() + + torch.cuda.current_stream().wait_stream(s_cap) + + s0 = torch.cuda.Stream() + s1 = torch.cuda.Stream() + + buf0.zero_() + buf1.zero_() + + s0.wait_stream(torch.cuda.current_stream()) + s1.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(s0): + g0.replay() + with torch.cuda.stream(s1): + g1.replay() + + torch.cuda.synchronize() + + self.assertEqual(buf0, ref0) + self.assertEqual(buf1, ref1) + @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" ) @@ -2287,14 +2559,16 @@ def clear_cuda_cache(): torch.cuda.empty_cache() # Executes a simple graph task which includes capturing and executing a random number generation within a CUDA graph. - def simple_graph_task(graph): + def simple_graph_task(graph, default_generator, generator_states): s = torch.cuda.Stream() with torch.cuda.stream(s): graph.capture_begin() - torch.rand(1, device="cuda") + for generator_state in generator_states: + default_generator.graphsafe_set_state(generator_state) + torch.rand(1, device="cuda") graph.capture_end() torch.cuda.current_stream().wait_stream(s) - graph.replay() # Replays the captured operations + graph.replay() def get_memory_stats(): stats = torch.cuda.memory_stats() @@ -2311,22 +2585,22 @@ def test(num_graphs, num_generators): # Allocate and manage generator states default_generator = torch.cuda.default_generators[0] - generators = [default_generator.graphsafe_get_state()] + generator_states = [default_generator.graphsafe_get_state()] - # Starts from 1 as one state is already added for _ in range(1, num_generators): - generators.append(default_generator.clone_state()) + generator_states.append(default_generator.clone_state()) for graph in graphs: - for generator_state in generators: + for generator_state in generator_states: graph.register_generator_state(generator_state) - simple_graph_task(graph) + simple_graph_task(graph, default_generator, generator_states) # Assert conditions after graph tasks num_blocks, total_size = get_memory_stats() - # The allocated blocks should only be proportional to the number of generators - expected_blocks_diff = 2 * num_generators - expected_size_diff = 2 * 512 * num_generators # Each block's size is 512 + # Each (generator, graph) pair gets 2 tensors (seed + offset) + expected_captured_states = num_generators * num_graphs + expected_blocks_diff = 2 * expected_captured_states + expected_size_diff = 2 * 512 * expected_captured_states self.assertEqual( (num_blocks - baseline_num_blocks), @@ -2999,9 +3273,12 @@ def test_graph_memory_stats_and_use_result_after_destroy_graph(self): elem = 4 # this was annoying to write but stresses the expectations pretty rigorously + # For small_pool cases, delta_cudaMallocs and delta_cudaMalloc_bytes include + # an extra kSmallBuffer segment for the per-capture RNG state tensors, which + # are allocated on the default stream (separate from stream s). cases = ( - (512 // elem, 1, kSmallBuffer, kSmallBuffer, "small_pool"), - (kSmallSize // elem, 2, 2 * kSmallBuffer, kSmallBuffer, "small_pool"), + (512 // elem, 2, 2 * kSmallBuffer, kSmallBuffer, "small_pool"), + (kSmallSize // elem, 3, 3 * kSmallBuffer, kSmallBuffer, "small_pool"), ((kSmallSize + 512) // elem, 1, kLargeBuffer, kLargeBuffer, "large_pool"), ( (kMinLargeAlloc - 512) // elem, @@ -3049,9 +3326,9 @@ def test_graph_memory_stats_and_use_result_after_destroy_graph(self): g = torch.cuda.CUDAGraph() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): - # Allocation stat estimates assume input is created on the same stream as capture_begin() - # (in other words, the same stream silo as the rng offset holder, which is not allocated from the - # capture's private pool). + # Per-capture RNG state tensors are allocated on the default stream + # (not the capture stream), so they occupy a separate segment from + # user tensors created here on stream s. a = torch.ones((numel,), device="cuda") precapture_stats = torch.cuda.memory_stats() @@ -3270,7 +3547,7 @@ def test_graph_manual_seed_mismatch_raises(self): g = torch.cuda.CUDAGraph() with self.assertRaisesRegex( RuntimeError, - "CUDAGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.", # noqa: B950 + "CUDAGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.", ): with torch.cuda.graph(g): torch.cuda.manual_seed(1) @@ -3710,14 +3987,14 @@ def raw_malloc(): try: with torch.cuda.stream(stream): mem = torch.cuda.caching_allocator_alloc(1024) - except BaseException: # noqa: B036 + except BaseException: if mem is None: return try: torch.cuda.caching_allocator_delete(mem) mem = None return None - except BaseException: # noqa: B036 + except BaseException: pass def throws_on_cuda_event(capture_error_mode): @@ -4135,6 +4412,9 @@ def test_gds_fails_in_ci(self): with self.assertRaisesRegex(RuntimeError, error_msg): torch.cuda.gds.GdsFile(f, os.O_CREAT | os.O_RDWR) + @unittest.skipIf( + IS_WINDOWS, "test relies on fork; Windows multiprocessing uses spawn" + ) def test_is_pinned_no_context(self): test_script = """\ import torch @@ -4185,7 +4465,7 @@ def worker(conn): @unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") @torch.testing._internal.common_utils.markDynamoStrictTest -class TestCudaMallocAsync(TestCase): +class TestCudaAllocator(TestCase): @unittest.skipIf( TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" ) @@ -4269,7 +4549,9 @@ def test_allocation_traceback_no_recording(self): x = torch.rand(64, device="cuda") self.assertIsNone(torch.cuda.memory._allocation_traceback(x.data_ptr())) - @unittest.skipUnless(IS_X86 and IS_LINUX, "x86 linux only cpp unwinding") + @unittest.skipUnless( + (IS_X86 or IS_ARM64) and IS_LINUX, "linux x86/aarch64 only cpp unwinding" + ) def test_direct_traceback(self): from torch._C._profiler import gather_traceback, symbolize_tracebacks # @manual @@ -4287,7 +4569,7 @@ def test_memory_snapshot_with_cpp(self): try: torch.cuda.memory.empty_cache() torch.cuda.memory._record_memory_history("state", stacks="all") - x = torch.rand(311, 411, device="cuda") # noqa: F841 + x = torch.rand(311, 411, device="cuda") ss = torch.cuda.memory._snapshot()["segments"] found_it = False @@ -4402,7 +4684,6 @@ def foo(p): TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" ) @requiresCppContext - @skipIfRocmArch(MI350_ARCH) def test_memory_plots(self): for context, stacks in ( ("all", "all" if IS_LINUX else "python"), @@ -4617,7 +4898,7 @@ def test_memory_snapshot_script(self): def foo(): return torch.rand(311, 411, device="cuda") - x = foo() # noqa: F841 + x = foo() ss = torch.cuda.memory._snapshot()["segments"] found_it = False @@ -4880,6 +5161,55 @@ def power2_div(size, div_factor): "pinned_num_register_threads:1024" ) + # Test throw_on_cudamalloc_oom config parsing - valid formats + torch.cuda.memory._set_allocator_settings("throw_on_cudamalloc_oom:True") + torch.cuda.memory._set_allocator_settings("throw_on_cudamalloc_oom:False") + + # Test throw_on_cudamalloc_oom config parsing - invalid formats + with self.assertRaises(ValueError): + torch._C._accelerator_setAllocatorSettings("throw_on_cudamalloc_oom:maybe") + + @unittest.skipIf(TEST_CUDAMALLOCASYNC, "throw_on_cudamalloc_oom not supported") + @serialTest() + def test_throw_on_cudamalloc_oom(self): + """Test that throw_on_cudamalloc_oom + per_process_memory_fraction works correctly.""" + torch.cuda.empty_cache() + device = torch._C._cuda_getDevice() + torch._C._cuda_resetAccumulatedMemoryStats(device) + + try: + # Test 1: With rejection disabled (default), allocations should succeed + torch._C._accelerator_setAllocatorSettings("throw_on_cudamalloc_oom:False") + x = torch.empty(10 * 1024 * 1024, dtype=torch.int8, device="cuda") + del x + torch.cuda.empty_cache() + + # Test 2: With throw_on_cudamalloc_oom enabled and a tight memory + # fraction, allocations that exceed the fraction limit should be + # preemptively rejected with OutOfMemoryError. + # Both settings must go through _accelerator_setAllocatorSettings so + # they are read from CUDAAllocatorConfig. + fraction = 0.005 + torch._C._accelerator_setAllocatorSettings( + f"throw_on_cudamalloc_oom:True,per_process_memory_fraction:{fraction}" + ) + + total_mem = torch.cuda.get_device_properties(0).total_memory + # Allocate the allowed threshold + 1 MiB to guarantee rejection + alloc_bytes = int(total_mem * fraction) + 1024 * 1024 + with self.assertRaises(torch.cuda.OutOfMemoryError): + torch.empty(alloc_bytes, dtype=torch.int8, device="cuda") + + # Check that rejection counter was incremented + stats = torch.cuda.memory_stats() + self.assertGreater(stats["num_oom_rejections"], 0) + + finally: + torch.cuda.empty_cache() + torch._C._accelerator_setAllocatorSettings( + "throw_on_cudamalloc_oom:False,per_process_memory_fraction:1.0" + ) + def test_allocator_backend(self): def check_output(script: str) -> str: return ( @@ -4997,7 +5327,7 @@ def test_raises_oom(self, max_split_size_mb_setting): torch.empty(1024 * 1024 * 1024 * 1024, device="cuda") @unittest.skipIf( - not (IS_LINUX and os.uname().machine == "x86_64"), "cpp traces only on linux" + not ((IS_X86 or IS_ARM64) and IS_LINUX), "cpp traces are linux x86/aarch64 only" ) @unittest.skipIf( TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" @@ -5123,7 +5453,11 @@ def test_temperature(self): @unittest.skipIf(not TEST_PYNVML, "pynvml/amdsmi is not available") def test_device_memory_used(self): """ - Verify used device memory in bytes + Verify used device memory in bytes. + On Windows the NVML used value has been observed not to increase after + a CUDA allocation (delta 0); we only assert API sanity there (non-negative, + non-decreasing after alloc, <= total memory). Need to investigate expected behavior + with Windows WDDM """ torch.cuda.synchronize() gc.collect() @@ -5134,9 +5468,20 @@ def test_device_memory_used(self): torch.cuda.synchronize() torch.cuda.empty_cache() b = torch.cuda.device_memory_used() - mem_bytes = b - a - # test the order of magnitude - self.assertTrue(num_bytes // 32 <= mem_bytes <= num_bytes * 32) + if IS_WINDOWS: + # NVML used memory does not reflect CUDA allocations on WDDM; only check API sanity + self.assertGreaterEqual(a, 0, "device_memory_used should be non-negative") + self.assertGreaterEqual(b, 0, "device_memory_used should be non-negative") + self.assertGreaterEqual( + b, a, "used memory should not decrease after allocation" + ) + total = torch.cuda.get_device_properties(0).total_memory + self.assertLessEqual(a, total, "used should not exceed total device memory") + self.assertLessEqual(b, total, "used should not exceed total device memory") + else: + mem_bytes = b - a + # test the order of magnitude + self.assertTrue(num_bytes // 32 <= mem_bytes <= num_bytes * 32) @unittest.skipIf(not TEST_PYNVML, "pynvml/amdsmi is not available") def test_power_draw(self): @@ -5828,6 +6173,7 @@ def test_pin_memory_use(self, use_cuda_host_register): "use_memory, delete_memory", [(True, True), (True, False), (False, True), (False, False)], ) + @skip_background_threads_on_windows def test_two_graphs( self, use_background_threads, use_cuda_host_register, use_memory, delete_memory ): @@ -5931,6 +6277,43 @@ def test_mempool_id(self): # increments the id self.assertTrue(abs(pool2[1] - pool1[1]) > 0) + @unittest.skipIf( + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" + ) + def test_pool_id_in_snapshot(self): + try: + torch.cuda.memory.empty_cache() + torch.cuda.memory._record_memory_history("all") + + pool = torch.cuda.MemPool() + with torch.cuda.use_mem_pool(pool): + x = torch.rand(64, device="cuda") + + ss = torch.cuda.memory._snapshot() + + # segment_pool_id should match the MemPool id + found_segment = False + for seg in ss["segments"]: + if seg["segment_pool_id"] == pool.id: + found_segment = True + break + self.assertTrue(found_segment) + + # trace entries for this allocation should carry pool_id + found_trace = False + for trace in ss["device_traces"]: + for te in trace: + if "pool_id" not in te: + continue + if te["pool_id"] == pool.id and te["action"] == "alloc": + found_trace = True + break + self.assertTrue(found_trace) + + del x + finally: + torch.cuda.memory._record_memory_history(None) + def get_dummy_allocator(self, check_vars): dummy_allocator_source_vars = """ #include @@ -6680,6 +7063,12 @@ def test_graph_capture_pre_capture_stream_use(self): "graph_capture_record_stream_reuse:False" ) + # expandable_segments not supported (PYTORCH_C10_DRIVER_API_SUPPORTED not defined for windows builds) + @unittest.skipIf( + IS_WINDOWS and SM89OrLater, + "expandable_segments not supported (PYTORCH_C10_DRIVER_API_SUPPORTED not defined for windows builds)", + ) + @skipIfRocm(msg="expandable_segments mode is not supported on ROCm") @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Load_inline doesn't work in fbcode") def test_mempool_expandable(self): torch.cuda.empty_cache() @@ -6880,6 +7269,203 @@ def get_total_allocated(segments): "Allocated memory should be same with or without traces", ) + @unittest.skipIf(TEST_WITH_ROCM, "not enabled by default on rocm") + @serialTest() + def test_multi_threads_alloc_in_same_order(self): + def alloc_tensors( + stream: torch.cuda.Stream, + sizes: [int], + device: str, + alloc_results: [], + lock, + thread_id=0, + ): + while len(sizes) > 0: + with lock: + if len(sizes) == 0: + return + sz = sizes.pop() + alloc_results.append( + ( + torch.empty(sz, dtype=torch.int8, device=device), + thread_id, + sz, + ) + ) + + def alloc_with_threads( + thread_count: int, stream: torch.cuda.Stream, mem_sizes: [int], device: str + ): + tensors_list = [] + threads = [] + mem_sizes = mem_sizes[:] + lock = threading.Lock() + for i in range(thread_count): + t = threading.Thread( + target=alloc_tensors, + args=(stream, mem_sizes, "cuda", tensors_list, lock, i), + ) + t.start() + threads.append(t) + for t in threads: + t.join() + tensors_ptrs = [ + (t.data_ptr(), thread_id, sz) for (t, thread_id, sz) in tensors_list + ] + return tensors_ptrs + + stream = torch.cuda.current_stream() + thread_count = 4 + mem_sizes = [ + s * 2 * 1024 * 1024 for s in range(1, 5) for _ in range(thread_count) + ] + tensor_ptrs_round_1 = alloc_with_threads( + thread_count, stream, mem_sizes, "cuda" + ) + tensor_ptrs_round_2 = alloc_with_threads( + thread_count, stream, mem_sizes, "cuda" + ) + ptrs_round_1 = torch.tensor([ptr for ptr, _, _ in tensor_ptrs_round_1]) + ptrs_round_2 = torch.tensor([ptr for ptr, _, _ in tensor_ptrs_round_2]) + self.assertTrue((ptrs_round_1 == ptrs_round_2).all().item()) + + @unittest.skipIf(TEST_WITH_ROCM, "not enabled by default on rocm") + @serialTest() + def test_nccl_mem_alloc_addresses_in_random_order(self): + """ + Test NCCL mem allocator with non-increasing allocation addresses. + + This test uses a custom allocator to simulate ncclMemAlloc scenario where + allocated memory addresses may not be monotonically increasing across + different ranks due to fragmentation, OS allocation policies, or prior + allocations. + + The registration counter-based comparator ensures consistent block ordering + regardless of actual memory addresses, which is critical for NCCL + symmetric memory alignment. + """ + + from cuda.bindings import runtime + + ALLOC_FN = ctypes.CFUNCTYPE( + ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_void_p + ) + FREE_FN = ctypes.CFUNCTYPE( + None, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, ctypes.c_void_p + ) + + class AllocState: + def __init__(self, test_instance): + self.first_stream = None + self.second_stream = None + self.allocated_addrs = [] + self.buffer_size = 1 * 1024 * 1024 * 1024 # 1GB + self.base_ptr = -1 + err, base_ptr = runtime.cudaMalloc(self.buffer_size) + test_instance.assertEqual( + err, + runtime.cudaError_t.cudaSuccess, + "init allocation for test_nccl_mem_alloc_addresses_in_random_order should be successful", + ) + self.base_ptr = base_ptr + self.head = base_ptr + self.tail = base_ptr + self.buffer_size + + def __del__(self): + if self.base_ptr > 0: + runtime.cudaFree(self.base_ptr) + + state = AllocState(self) + + def my_alloc(size, device, stream, _runtime=runtime): + nonlocal state + if state.first_stream is None: + state.first_stream = stream + elif state.second_stream is None: + state.second_stream = stream + + if state.first_stream == stream: + ptr = state.head + state.head += size + else: + state.tail -= size + ptr = state.tail + state.allocated_addrs.append(ptr) + return ptr + + def my_free(ptr, size, device, stream, _runtime=runtime): + pass + + # Must keep these alive for the lifetime of the allocator + c_alloc = ALLOC_FN(my_alloc) + c_free = FREE_FN(my_free) + alloc_ptr = ctypes.cast(c_alloc, ctypes.c_void_p).value + free_ptr = ctypes.cast(c_free, ctypes.c_void_p).value + allocator = torch._C._cuda_customAllocator(alloc_ptr, free_ptr) + pool = torch.cuda.MemPool(allocator) + + first_stream = torch.cuda.Stream() + second_stream = torch.cuda.Stream() + + tensor_sizes = [24 * 1024 * 1024, 32 * 1024 * 1024] + alloc_cases = [] + case_no = 0 + for idx in range(3): + for stream_idx, stream in enumerate([first_stream, second_stream]): + # for second stream in reverse order + reverse_order = stream_idx % 2 == 1 + tensor_sizes.sort(reverse=reverse_order) + for size in tensor_sizes: + alloc_cases.append( + {"stream": stream, "size": size, "case_no": case_no} + ) + case_no += 1 + + # Test allocation and deallocation patterns + def alloc_tensors(): + all_tensors = [] + for case in alloc_cases: + with torch.cuda.stream(case["stream"]): + all_tensors.append( + torch.empty(case["size"], dtype=torch.uint8, device="cuda") + ) + return all_tensors + + with torch.cuda.use_mem_pool(pool): + first_round_tensors = alloc_tensors() + tensor_ptrs = [t.data_ptr() for t in first_round_tensors] + # for t in first_round_tensors: + # del t + del first_round_tensors + + with torch.cuda.use_mem_pool(pool): + second_round_tensors = alloc_tensors() + second_round_tensor_ptrs = [t.data_ptr() for t in second_round_tensors] + del second_round_tensors + + mem_snapshot = pool.snapshot() + self.assertEqual( + len(mem_snapshot), + len(tensor_ptrs), + f"expected to have {len(tensor_ptrs)} segments, but actually got {len(mem_snapshot)}", + ) + + for idx, first_addr in enumerate(tensor_ptrs): + second_addr = second_round_tensor_ptrs[idx] + self.assertEqual( + second_addr, + first_addr, + "mem addr allocated for second round should be same with first round", + ) + self.assertEqual( + second_addr, + state.allocated_addrs[idx], + f"{second_round_tensor_ptrs[idx]=} != {state.allocated_addrs[idx]=}", + ) + del pool + del state + torch.cuda.empty_cache() + @unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") @torch.testing._internal.common_utils.markDynamoStrictTest @@ -7244,25 +7830,6 @@ def test_graph_grad_scaling(self, device, dtype, optim_info, foreach, fused): self.assertEqual(scaler._growth_tracker, growth_tracker) -@unittest.skipIf( - not PLATFORM_SUPPORTS_GREEN_CONTEXT, "Green context not available, skipping tests" -) -class TestGreenContext(TestCase): - def test_greencontext_restores_stream(self): - # need to start on a side stream as we are comparing pointers and want to avoid - # two NULL streams... - s = torch.cuda.Stream() - with torch.cuda.stream(s): - start_stream = torch.cuda.current_stream() - ctx = torch.cuda.green_contexts.GreenContext.create(1, 0) - ctx.set_context() - context_stream = torch.cuda.current_stream() - ctx.pop_context() - end_stream = torch.cuda.current_stream() - self.assertEqual(start_stream.cuda_stream, end_stream.cuda_stream) - self.assertNotEqual(start_stream.cuda_stream, context_stream.cuda_stream) - - @unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests") class TestGDS(TestCase): def _get_tmp_dir_fs_type(self): @@ -8396,7 +8963,7 @@ def test_graph_external_wait_and_record(self): # cudaEventQuery() will succeed before that happens. # See: - # "Before the first call to cudaEventRecord(), an event represents an empty set of work, so for example cudaEventQuery() would return cudaSuccess." # noqa: B950 + # "Before the first call to cudaEventRecord(), an event represents an empty set of work, so for example cudaEventQuery() would return cudaSuccess." # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html self.assertTrue(start_event.query(), "Start event's work should be empty") @@ -8436,7 +9003,7 @@ def test_graph_external_wait_and_record(self): # This writes allows wait_for_cpu to proceed # This is an atomic store at system scope according to this rule: - # "the scope is thread_scope_system and it is a load or store that affects a naturally-aligned object of sizes 1, 2, 4, 8, or 16 bytes on mapped memory" # noqa: B950 + # "the scope is thread_scope_system and it is a load or store that affects a naturally-aligned object of sizes 1, 2, 4, 8, or 16 bytes on mapped memory" # https://nvidia.github.io/cccl/libcudacxx/extended_api/memory_model.html#atomicity # Note that every CPU store is implicitly system scope, @@ -8581,12 +9148,163 @@ def test_fx_memory_profiler_augmentation(self): self.assertIn("e = self.relu(d)", frame["fx_original_trace"]) +@unittest.skipIf( + not PLATFORM_SUPPORTS_GREEN_CONTEXT, "Green contexts are not supported" +) +class TestCudaGreenContexts(TestCase): + def setUp(self): + super().setUp() + + def tearDown(self): + super().tearDown() + + def test_greencontext_restores_stream(self): + # need to start on a side stream as we are comparing pointers and want to avoid + # two NULL streams... + s = torch.cuda.Stream() + with torch.cuda.stream(s): + start_stream = torch.cuda.current_stream() + ctx = torch.cuda.green_contexts.GreenContext.create(num_sms=1) + ctx.set_context() + context_stream = torch.cuda.current_stream() + ctx.pop_context() + end_stream = torch.cuda.current_stream() + self.assertEqual(start_stream.cuda_stream, end_stream.cuda_stream) + self.assertNotEqual(start_stream.cuda_stream, context_stream.cuda_stream) + + @serialTest() + def test_greencontext_carveout(self): + # By default, everything is performed on the current device + a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16) + ctx = torch.cuda.green_contexts.GreenContext.create(num_sms=1) + ctx.set_context() + torch.matmul(a, a) + torch.cuda.synchronize() + t0 = time.perf_counter() + partial_res = torch.matmul(a, a) + torch.cuda.synchronize() + t1 = time.perf_counter() + ctx.pop_context() + torch.matmul(a, a) + torch.cuda.synchronize() + t2 = time.perf_counter() + full_res = torch.matmul(a, a) + torch.cuda.synchronize() + t3 = time.perf_counter() + self.assertEqual(partial_res, full_res) + self.assertGreater(t1 - t0, t3 - t2) + + @serialTest() + def test_greencontext_stream_carveout(self): + # By default, everything is performed on the current device + a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16) + ctx = torch.cuda.green_contexts.GreenContext.create(num_sms=1) + ctx_stream = ctx.Stream() + with torch.cuda.stream(ctx_stream): + torch.matmul(a, a) + torch.cuda.synchronize() + t0 = time.perf_counter() + partial_res = torch.matmul(a, a) + torch.cuda.synchronize() + t1 = time.perf_counter() + torch.matmul(a, a) + torch.cuda.synchronize() + t2 = time.perf_counter() + full_res = torch.matmul(a, a) + torch.cuda.synchronize() + t3 = time.perf_counter() + self.assertEqual(partial_res, full_res) + self.assertGreater(t1 - t0, t3 - t2) + + @serialTest() + def test_greencontext_graphs(self): + # By default, everything is performed on the current device + a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16) + ctx = torch.cuda.green_contexts.GreenContext.create(num_sms=1) + ctx.set_context() + partial_res = torch.matmul(a, a) + ctx.pop_context() + full_res = torch.matmul(a, a) + full_res.zero_() + partial_res.zero_() + torch.cuda.synchronize() + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + ctx.set_context() + partial_res = torch.matmul(a, a) + ctx.pop_context() + full_res = torch.matmul(a, a) + g.replay() + self.assertEqual(partial_res, full_res) + + @unittest.skipIf( + not PLATFORM_SUPPORTS_WORKQUEUE_CONFIG, "Workqueue config is not supported" + ) + @serialTest() + def test_greencontext_workqueue_concurrency_limit(self): + # By default, everything is performed on the current device + n_streams = 4 + GreenContext = torch.cuda.green_contexts.GreenContext + max_wq = GreenContext.max_workqueue_concurrency() + if max_wq < n_streams: + self.skipTest( + f"Device has {max_wq} workqueue(s), need >{n_streams} to test concurrency limiting" + ) + + wq_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_streams * 2)] + + def run_multi_stream_sleep(streams): + for i, s in enumerate(streams): + with torch.cuda.stream(s): + # note we need to record timing events to ensure that the + # workqueue is actually used + wq_events[i * 2].record(s) + # 100M cycles is enough on all currently supported GPUs + # to ensure that the test runs significantly longer than + # host overhead to 1) activate the different streams, + # 2) record timing events, and 3) synchronize with the GPU. + torch.cuda._sleep(100_000_000) + wq_events[i * 2 + 1].record(s) + torch.cuda.synchronize() + + # Note: in case we have lazy module loading, ensure that we called the + # sleep kernel at least once s.t. it is loaded in memory. + torch.cuda._sleep(1_000) + + # Baseline: n_streams streams from default context, full workqueue concurrency + baseline_streams = [torch.cuda.Stream() for _ in range(n_streams)] + # note: torch.cuda.synchronize() will wait for all kernels on all streams + # so we can safely use CPU based timing here. + t0 = time.perf_counter() + run_multi_stream_sleep(baseline_streams) + t1 = time.perf_counter() + baseline_time = t1 - t0 + + # Green context with workqueue concurrency limited to 1 + ctx = GreenContext.create( + workqueue_scope="balanced", + workqueue_concurrency_limit=1, + ) + ctx.set_context() + green_streams = [ctx.Stream() for _ in range(n_streams)] + t2 = time.perf_counter() + run_multi_stream_sleep(green_streams) + t3 = time.perf_counter() + ctx.pop_context() + limited_time = t3 - t2 + + self.assertGreater(limited_time, baseline_time) + + instantiate_parametrized_tests(TestCuda) -instantiate_parametrized_tests(TestCudaMallocAsync) +instantiate_parametrized_tests(TestCudaAllocator) instantiate_parametrized_tests(TestCompileKernel) instantiate_parametrized_tests(TestCachingHostAllocatorCudaGraph) instantiate_device_type_tests(TestCudaOptims, globals()) instantiate_device_type_tests(TestCudaDeviceParametrized, globals()) +instantiate_device_type_tests(TestCudaGreenContexts, globals(), except_for="cpu") + if __name__ == "__main__": run_tests() diff --git a/test/test_cuda_expandable_segments.py b/test/test_cuda_expandable_segments.py index f22b50c64313e..25c2e9eaff2c5 100644 --- a/test/test_cuda_expandable_segments.py +++ b/test/test_cuda_expandable_segments.py @@ -7,7 +7,7 @@ from test_cuda import ( # noqa: F401 TestBlockStateAbsorption, TestCuda, - TestCudaMallocAsync, + TestCudaAllocator, ) import torch diff --git a/test/test_cuda_graph_annotations.py b/test/test_cuda_graph_annotations.py new file mode 100644 index 0000000000000..f29bf78a2ab34 --- /dev/null +++ b/test/test_cuda_graph_annotations.py @@ -0,0 +1,398 @@ +# Owner(s): ["module: cuda graphs"] + +"""Tests for CUDA graph kernel annotation via mark_kernels.""" + +import unittest + +import torch +from torch.cuda._graph_annotations import ( + _is_tools_id_unavailable, + clear_kernel_annotations, + enable_annotations, + get_kernel_annotations, + mark_kernels, + remap_to_exec_graph, + resolve_pending_annotations, +) +from torch.testing._internal.common_utils import run_tests, TestCase + + +TEST_CUDA = torch.cuda.is_available() + +try: + import cuda.bindings.runtime # noqa: F401 + + TEST_CUDA_BINDINGS = True +except ImportError: + TEST_CUDA_BINDINGS = False + + +@unittest.skipUnless(TEST_CUDA, "CUDA not available") +@unittest.skipUnless(TEST_CUDA_BINDINGS, "cuda.bindings not available") +@unittest.skipIf( + _is_tools_id_unavailable(), + "cudaGraphNodeGetToolsId not available (needs cuda-compat >= 13.1)", +) +class TestMarkKernels(TestCase): + def setUp(self): + enable_annotations() + clear_kernel_annotations() + + def tearDown(self): + clear_kernel_annotations() + + def test_noop_outside_capture(self): + x = torch.randn(8, device="cuda") + with mark_kernels("test"): + _ = x + 1 + self.assertEqual(len(get_kernel_annotations()), 0) + + def test_single_scope(self): + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + with torch.cuda.graph(graph): + with mark_kernels("phase_a"): + _ = x + 1 + resolve_pending_annotations() + + annotations = get_kernel_annotations() + self.assertGreater(len(annotations), 0) + for anns in annotations.values(): + for ann in anns: + self.assertEqual(ann, {"str": "phase_a"}) + + def test_multiple_scopes_no_overlap(self): + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + with torch.cuda.graph(graph): + with mark_kernels("scope_1"): + _ = x + 1 + with mark_kernels("scope_2"): + _ = x * 2 + resolve_pending_annotations() + + annotations = get_kernel_annotations() + scope_1_ids = set() + scope_2_ids = set() + for tid, anns in annotations.items(): + self.assertEqual(len(anns), 1) + if anns[0] == {"str": "scope_1"}: + scope_1_ids.add(tid) + elif anns[0] == {"str": "scope_2"}: + scope_2_ids.add(tid) + + self.assertGreater(len(scope_1_ids), 0) + self.assertGreater(len(scope_2_ids), 0) + self.assertEqual(len(scope_1_ids & scope_2_ids), 0) + + def test_dict_annotation(self): + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + annotation = {"name": "all_gather", "Group size": 2, "dtype": "bfloat16"} + with torch.cuda.graph(graph): + with mark_kernels(annotation): + _ = x + 1 + resolve_pending_annotations() + + annotations = get_kernel_annotations() + self.assertGreater(len(annotations), 0) + for anns in annotations.values(): + self.assertEqual(anns[0]["name"], "all_gather") + self.assertEqual(anns[0]["Group size"], 2) + + def test_clear_resets_state(self): + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + with torch.cuda.graph(graph): + with mark_kernels("test"): + _ = x + 1 + resolve_pending_annotations() + + self.assertGreater(len(get_kernel_annotations()), 0) + clear_kernel_annotations() + self.assertEqual(len(get_kernel_annotations()), 0) + + def test_resolve_without_scopes_is_noop(self): + resolve_pending_annotations() + self.assertEqual(len(get_kernel_annotations()), 0) + + def test_scope_with_no_kernels(self): + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + with torch.cuda.graph(graph): + _ = x + 1 + with mark_kernels("empty"): + pass + _ = x * 2 + resolve_pending_annotations() + + for anns in get_kernel_annotations().values(): + for ann in anns: + self.assertNotEqual(ann, "empty") + + def test_only_annotates_scope_kernels(self): + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + with torch.cuda.graph(graph): + _ = x + 1 + _ = x * 2 + with mark_kernels("tagged"): + _ = x + 3 + _ = x - 1 + resolve_pending_annotations() + + annotations = get_kernel_annotations() + total_annotated = sum(len(anns) for anns in annotations.values()) + self.assertGreater(total_annotated, 0) + for anns in annotations.values(): + for ann in anns: + self.assertEqual(ann, {"str": "tagged"}) + + def test_nested_scopes_innermost_wins(self): + """With nested string scopes, the innermost name wins.""" + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + with torch.cuda.graph(graph): + with mark_kernels("outer"): + _ = x + 1 # outer only + with mark_kernels("inner"): + _ = x * 2 # nested: inner should win + _ = x - 1 # outer only + resolve_pending_annotations() + + annotations = get_kernel_annotations() + outer_ids = set() + inner_ids = set() + for tid, anns in annotations.items(): + self.assertEqual( + len(anns), 1, f"toolsId {hex(tid)} has {len(anns)} annotations" + ) + ann = anns[0] + self.assertIsInstance(ann, dict) + if ann["str"] == "outer": + outer_ids.add(tid) + elif ann["str"] == "inner": + inner_ids.add(tid) + + self.assertGreater(len(outer_ids), 0, "Should have outer-only kernels") + self.assertGreater(len(inner_ids), 0, "Should have inner kernels") + self.assertEqual(len(outer_ids & inner_ids), 0) + + def test_nested_dict_scopes_inner_wins_common_keys(self): + """With truly nested dict scopes, inner wins for common keys, + outer-only keys are preserved.""" + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + outer_ann = {"name": "ag_collective", "stream": 71} + inner_ann = { + "name": "all_gather", + "stream": 62, + "In msg nelems": 1024, + "dtype": "bfloat16", + } + + with torch.cuda.graph(graph): + with mark_kernels(outer_ann): + _ = x + 1 # outer only + with mark_kernels(inner_ann): + _ = x * 2 # nested + _ = x - 1 # outer only + resolve_pending_annotations() + + annotations = get_kernel_annotations() + outer_only_ids = set() + nested_ids = set() + for tid, anns in annotations.items(): + self.assertEqual(len(anns), 1) + ann = anns[0] + self.assertIsInstance(ann, dict) + if ann["name"] == "ag_collective": + outer_only_ids.add(tid) + elif ann["name"] == "all_gather": + nested_ids.add(tid) + # Inner wins for common keys + self.assertEqual(ann["stream"], 62) + # Inner-only keys preserved + self.assertEqual(ann["In msg nelems"], 1024) + self.assertEqual(ann["dtype"], "bfloat16") + + self.assertGreater(len(outer_only_ids), 0, "Should have outer-only kernels") + self.assertGreater(len(nested_ids), 0, "Should have nested kernels") + + def test_same_range_scopes_inner_wins_common_keys(self): + """With same-range scopes (inner ctx exits first), inner wins + for common keys, outer-only keys are preserved.""" + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + outer_ann = {"name": "ag_collective", "stream": 71} + inner_ann = { + "name": "all_gather", + "stream": 62, + "In msg nelems": 1024, + "dtype": "bfloat16", + } + + with torch.cuda.graph(graph): + # Both scopes wrap the same kernels; inner exits first. + with mark_kernels(outer_ann): + with mark_kernels(inner_ann): + _ = x + 1 + resolve_pending_annotations() + + annotations = get_kernel_annotations() + self.assertGreater(len(annotations), 0) + for anns in annotations.values(): + self.assertEqual(len(anns), 1) + ann = anns[0] + self.assertIsInstance(ann, dict) + # Inner wins for common keys + self.assertEqual(ann["name"], "all_gather", "Inner name should win") + self.assertEqual(ann["stream"], 62, "Inner stream should win") + # Inner-only keys preserved + self.assertEqual(ann["In msg nelems"], 1024) + self.assertEqual(ann["dtype"], "bfloat16") + + def test_remap_to_exec_graph(self): + from cuda.bindings import runtime as cuda_runtime + + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + with torch.cuda.graph(graph): + with mark_kernels("test"): + _ = x + 1 + resolve_pending_annotations() + + annotations_before = dict(get_kernel_annotations()) + self.assertGreater(len(annotations_before), 0) + + exec_handle = cuda_runtime.cudaGraphExec_t( + init_value=graph.raw_cuda_graph_exec() + ) + _, exec_graph_id = cuda_runtime.cudaGraphExecGetId(exec_handle) + + remap_to_exec_graph(graph) + + annotations_after = get_kernel_annotations() + self.assertEqual(len(annotations_after), len(annotations_before)) + for tools_id in annotations_after: + self.assertEqual(tools_id >> 32, exec_graph_id) + + def test_disabled_is_noop(self): + from torch.cuda._graph_annotations import disable_annotations + + disable_annotations() + + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + with torch.cuda.graph(graph): + with mark_kernels("should_not_appear"): + _ = x + 1 + resolve_pending_annotations() + + self.assertEqual(len(get_kernel_annotations()), 0) + + # Re-enable for other tests + enable_annotations() + + def test_enable_annotations_kwarg(self): + """enable_annotations on torch.cuda.graph auto-resolves annotations.""" + from torch.cuda._graph_annotations import disable_annotations + + # Start with annotations disabled to verify the kwarg enables them. + disable_annotations() + clear_kernel_annotations() + + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + with torch.cuda.graph(graph, enable_annotations=True): + with mark_kernels("auto"): + _ = x + 1 + + annotations = get_kernel_annotations() + self.assertGreater(len(annotations), 0) + for anns in annotations.values(): + for ann in anns: + self.assertEqual(ann, {"str": "auto"}) + + def test_enable_annotations_does_not_clear(self): + """Annotations from a previous graph survive a second capture.""" + from torch.cuda._graph_annotations import disable_annotations + + disable_annotations() + clear_kernel_annotations() + + graph1 = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + with torch.cuda.graph(graph1, enable_annotations=True): + with mark_kernels("first"): + _ = x + 1 + + first_count = len(get_kernel_annotations()) + self.assertGreater(first_count, 0) + + graph2 = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph2, enable_annotations=True): + with mark_kernels("second"): + _ = x * 2 + + # Both graphs' annotations should be present. + self.assertGreater(len(get_kernel_annotations()), first_count) + + def test_enable_annotations_remaps_to_exec_graph(self): + """enable_annotations=True must remap toolsIds to the exec graph ID.""" + from cuda.bindings import runtime as cuda_runtime + + clear_kernel_annotations() + + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + with torch.cuda.graph(graph, enable_annotations=True): + with mark_kernels("remap_test"): + _ = x + 1 + + exec_handle = cuda_runtime.cudaGraphExec_t( + init_value=graph.raw_cuda_graph_exec() + ) + _, exec_graph_id = cuda_runtime.cudaGraphExecGetId(exec_handle) + + annotations = get_kernel_annotations() + self.assertGreater(len(annotations), 0) + for tools_id in annotations: + graph_id = tools_id >> 32 + self.assertEqual( + graph_id, + exec_graph_id, + f"toolsId 0x{tools_id:016x} has graph_id {graph_id}, " + f"expected exec_graph_id {exec_graph_id}", + ) + + def test_enable_annotations_false_does_not_auto_resolve(self): + """Without enable_annotations, pending scopes are not resolved.""" + graph = torch.cuda.CUDAGraph() + x = torch.randn(8, device="cuda") + + # enable_annotations=False (default): no auto-resolve. + with torch.cuda.graph(graph): + with mark_kernels("unresolved"): + _ = x + 1 + + # Annotations should be empty because resolve was never called. + self.assertEqual(len(get_kernel_annotations()), 0) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_cuda_multigpu.py b/test/test_cuda_multigpu.py index 579ca1675f92b..199b403e00f92 100644 --- a/test/test_cuda_multigpu.py +++ b/test/test_cuda_multigpu.py @@ -42,7 +42,7 @@ if not TEST_CUDA: print("CUDA not available, skipping tests", file=sys.stderr) - TestCase = NoTest # noqa: F811 + TestCase = NoTest class TestCudaMultiGPU(TestCase): diff --git a/test/test_cuda_nvml_based_avail.py b/test/test_cuda_nvml_based_avail.py index eaf2365315d24..4f3be6e3c3d9b 100644 --- a/test/test_cuda_nvml_based_avail.py +++ b/test/test_cuda_nvml_based_avail.py @@ -32,7 +32,7 @@ TEST_CUDA = torch.cuda.is_available() if not TEST_CUDA: print("CUDA not available, skipping tests", file=sys.stderr) - TestCase = NoTest # type: ignore[misc, assignment] # noqa: F811 + TestCase = NoTest # type: ignore[misc, assignment] @torch.testing._internal.common_utils.markDynamoStrictTest diff --git a/test/test_cuda_primary_ctx.py b/test/test_cuda_primary_ctx.py index 60d4f36e0c16e..c8eccd478a1c8 100644 --- a/test/test_cuda_primary_ctx.py +++ b/test/test_cuda_primary_ctx.py @@ -12,7 +12,7 @@ if not TEST_CUDA: print("CUDA not available, skipping tests", file=sys.stderr) - TestCase = NoTest # noqa: F811 + TestCase = NoTest @torch.testing._internal.common_utils.markDynamoStrictTest diff --git a/test/test_cuda_sanitizer.py b/test/test_cuda_sanitizer.py index 35720176901d4..db77e1dcef0b8 100644 --- a/test/test_cuda_sanitizer.py +++ b/test/test_cuda_sanitizer.py @@ -13,7 +13,7 @@ if not TEST_CUDA: print("CUDA not available, skipping tests", file=sys.stderr) - TestCase = NoTest # noqa: F811 + TestCase = NoTest class TestArgumentHandler(TestCase): diff --git a/test/test_cuda_trace.py b/test/test_cuda_trace.py index 0794683f4ef26..6feb1dbf34885 100644 --- a/test/test_cuda_trace.py +++ b/test/test_cuda_trace.py @@ -14,7 +14,7 @@ if not TEST_CUDA: print("CUDA not available, skipping tests", file=sys.stderr) - TestCase = NoTest # noqa: F811 + TestCase = NoTest @torch.testing._internal.common_utils.markDynamoStrictTest diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index c33465466790b..23d5c9bce3468 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -53,8 +53,10 @@ run_tests, scoped_load_inline, skipIfTorchDynamo, + skipIfXpu, subtest, TemporaryFileName, + TEST_XPU, TestCase, ) from torch.testing._internal.custom_op_db import numpy_nonzero @@ -68,6 +70,12 @@ MyList = list MyTensor = torch.Tensor +device_type = ( + acc.type + if (acc := torch.accelerator.current_accelerator(check_available=True)) + else "cpu" +) + def requires_compile(fun): fun = unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")(fun) @@ -1148,7 +1156,7 @@ def foo(x: torch.Tensor) -> torch.Tensor: with self.assertRaisesRegex(RuntimeError, "multiple times"): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") - def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + def foo(x: torch.Tensor) -> torch.Tensor: raise NotImplementedError # Unless we delete the original op. @@ -1156,14 +1164,14 @@ def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 # Smoke test @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") - def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + def foo(x: torch.Tensor) -> torch.Tensor: raise NotImplementedError custom_ops._destroy(f"{TestCustomOp.test_ns}::foo") def test_autograd_notimplemented(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") - def foo(x: torch.Tensor) -> torch.Tensor: # noqa: F811 + def foo(x: torch.Tensor) -> torch.Tensor: raise NotImplementedError x = torch.randn(3, requires_grad=True) @@ -1585,7 +1593,8 @@ def foo_backward(ctx, saved, grad0, grad1): with self.assertRaisesRegex(RuntimeError, "is not a Tensor"): op(x) - @unittest.skipIf(not TEST_CUDA, "requires CUDA") + @skipIfXpu(msg="Deprecated torch.custom_ops API") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "requires CUDA or XPU") def test_impl_separate(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor) -> torch.Tensor: @@ -1595,7 +1604,7 @@ def foo(x: torch.Tensor) -> torch.Tensor: def foo_cpu(x): return x.sin() - @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types="cuda") + @custom_ops.impl(f"{TestCustomOp.test_ns}::foo", device_types=device_type) def foo_cuda(x): return x.cos() @@ -1604,12 +1613,13 @@ def foo_cuda(x): result = op(x) self.assertEqual(result, foo_cpu(x)) - x_cuda = x.cuda() + x_cuda = x.to(device_type) op = self.get_op(f"{self.test_ns}::foo") result = op(x_cuda) self.assertEqual(result, foo_cuda(x_cuda)) - @unittest.skipIf(not TEST_CUDA, "requires CUDA") + @skipIfXpu(msg="Deprecated torch.custom_ops API") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "requires CUDA or XPU") def test_impl_multiple(self): @custom_ops.custom_op(f"{TestCustomOp.test_ns}::foo") def foo(x: torch.Tensor) -> torch.Tensor: @@ -1624,7 +1634,7 @@ def foo_impl(x): result = op(x) self.assertEqual(result, foo_impl(x)) - x_cuda = x.cuda() + x_cuda = x.to(device_type) result = op(x_cuda) self.assertEqual(result, foo_impl(x_cuda)) @@ -1765,7 +1775,7 @@ def forward(self, x_1): sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1) sym_size_int_2 = torch.ops.aten.sym_size.int(x_1, 2) numpy_view_copy = torch.ops._torch_testing.numpy_view_copy.default(x_1, [sym_size_int, sym_size_int_1, sym_size_int_2]); x_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = None - return numpy_view_copy""", # noqa: B950 + return numpy_view_copy""", ) @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work on windows") @@ -2168,11 +2178,11 @@ def test_impl_device_cpu(self): self._test_impl_device("foo2", ["cpu"], "cpu") self._test_impl_device("foo3", ["cpu", "cuda"], "cpu") - @unittest.skipIf(not TEST_CUDA, "requires cuda") + @unittest.skipIf(not TEST_CUDA and not TEST_XPU, "requires cuda or xpu") def test_impl_device_cuda(self): - self._test_impl_device("foo4", "default", "cuda") - self._test_impl_device("foo5", ["cuda"], "cuda") - self._test_impl_device("foo6", ["cpu", "cuda"], "cuda") + self._test_impl_device("foo4", "default", device_type) + self._test_impl_device("foo5", [device_type], device_type) + self._test_impl_device("foo6", ["cpu", device_type], device_type) def test_impl_device_function(self): lib = self.lib() @@ -2371,6 +2381,56 @@ def test_load_library(self): ): torch.ops.load_library("libnoexist.so") + def test_list_scalar_type(self): + lib = self.lib() + lib.define("scalar_list(Tensor x, ScalarType[] dts) -> Tensor") + + received = None + + @torch.library.impl(lib, "scalar_list", "CPU") + def _(x, dts): + nonlocal received + received = dts + return x.clone() + + x = torch.randn(3) + torch.ops._test_custom_op.scalar_list(x, [torch.float32, torch.bfloat16]) + self.assertEqual(received, [torch.float32, torch.bfloat16]) + + def test_list_layout(self): + lib = self.lib() + lib.define("layout_list(Tensor x, Layout[] layouts) -> Tensor") + + received = None + + @torch.library.impl(lib, "layout_list", "CPU") + def _(x, layouts): + nonlocal received + received = layouts + return x.clone() + + x = torch.randn(3) + torch.ops._test_custom_op.layout_list(x, [torch.strided, torch.sparse_coo]) + self.assertEqual(received, [torch.strided, torch.sparse_coo]) + + def test_list_memory_format(self): + lib = self.lib() + lib.define("memfmt_list(Tensor x, MemoryFormat[] fmts) -> Tensor") + + received = None + + @torch.library.impl(lib, "memfmt_list", "CPU") + def _(x, fmts): + nonlocal received + received = fmts + return x.clone() + + x = torch.randn(3) + torch.ops._test_custom_op.memfmt_list( + x, [torch.contiguous_format, torch.channels_last] + ) + self.assertEqual(received, [torch.contiguous_format, torch.channels_last]) + def op_with_incorrect_schema(testcase, name): lib = testcase.lib() @@ -3001,6 +3061,27 @@ def f( continue self.assertGreater(after, prev) + @skipIfTorchDynamo("Expected to fail due to no FakeTensor support; not a bug") + def test_mutated_optional_arg_default_none(self): + @torch.library.custom_op( + "_torch_testing::copy_optional_out", mutates_args={"out"} + ) + def copy_optional_out(x: Tensor, out: Optional[Tensor] = None) -> Tensor: + if out is not None: + out.copy_(x) + return x.new_empty(0) + return x.clone() + + x = torch.randn(3) + self.assertEqual(copy_optional_out(x), x) + + out = torch.empty_like(x) + version = out._version + result = copy_optional_out(x, out=out) + self.assertEqual(result.numel(), 0) + self.assertEqual(out, x) + self.assertGreater(out._version, version) + def test_mutated_no_warning(self): # Run in subprocess since the warning is emitted only once script = """\ @@ -4500,6 +4581,32 @@ def test_library_get_kernel_invalid(self): torch.library.get_kernel("test_invalid_kernel::cpu_only_op", "CUDA") +class TestLibrarySourceLocation(TestCase): + def test_library_source_location(self): + # Library.__init__ uses sys._getframe(1) to capture the caller's + # filename and line number. Verify this works correctly by creating + # a Library and checking the source location in the error message + # that appears when a duplicate DEF library is created. + script = """\ +import torch +lib1 = torch.library.Library("_test_loc", "DEF") +lib1.define("foo(Tensor x) -> Tensor") +try: + lib2 = torch.library.Library("_test_loc", "DEF") +except RuntimeError as e: + print(str(e)) +""" + result = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, + text=True, + ) + self.assertEqual(result.returncode, 0, result.stderr) + # The error message should reference :2, since + # lib1 = torch.library.Library(...) is on line 2 of the script. + self.assertIn(":2", result.stdout) + + class MiniOpTestOther(CustomOpTestCaseBase): test_ns = "mini_op_test" diff --git a/test/test_decomp.py b/test/test_decomp.py index a5fc1ee423591..281ff0bb5aa7a 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -611,6 +611,29 @@ def test_quick_core_backward(self, device, dtype, op): def test_comprehensive(self, device, dtype, op): self.do_cross_ref(device, dtype, op, run_all=True) + def test_hann_window_decomp(self, device): + # Verify the hann_window decomp matches the native kernel for all four + # overloads: .default, .periodic, .out, .periodic_out. + from torch._decomp.decompositions import hann_window, hann_window_periodic + + for n in (0, 1, 8, 9): + # .default (periodic=True) + ref = torch.hann_window(n, device=device) + res = hann_window(n, device=device) + self.assertEqual(ref, res) + + # .periodic overload, explicit periodic flag + for periodic in (True, False): + ref = torch.hann_window(n, periodic, device=device) + res = hann_window_periodic(n, periodic, device=device) + self.assertEqual(ref, res) + + # dtype forwarding + ref = torch.hann_window(8, dtype=torch.float64, device=device) + res = hann_window_periodic(8, dtype=torch.float64, device=device) + self.assertEqual(ref, res) + self.assertEqual(res.dtype, torch.float64) + def test_uniform(self, device): size = (2, 3, 4, 5) dtype = torch.float32 @@ -1063,6 +1086,24 @@ def run_without_python_dispatcher(mode): "only backwards is decomposed, but dtype doesn't support AD" ) + def test_binary_cross_entropy_with_logits_decomp(self, device): + op_config = { + "self": torch.randn([4, 5, 6], dtype=torch.bfloat16, device=device), + "target": torch.randn([4, 5, 6], dtype=torch.bfloat16, device=device), + "weight": torch.randn([6], dtype=torch.float32, device=device), + "reduction": 2, + } + + ref = torch.ops.aten.binary_cross_entropy_with_logits.default(**op_config) + + decomp_table = torch._inductor.decomposition.select_decomp_table() + bce_decomp = decomp_table[ + torch.ops.aten.binary_cross_entropy_with_logits.default + ] + res = bce_decomp(**op_config) + + torch.testing.assert_close(ref, res, check_dtype=True) + instantiate_device_type_tests(TestDecomp, globals()) @@ -1286,6 +1327,27 @@ def forward_pass_fn(): in generated_codes[1] ) + @onlyCUDA + @skipIfCrossRef + def test_addmm_out_dtype_decomp(self, device): + cases = [ + {"beta": 1, "alpha": 1}, + {"beta": 0, "alpha": 1}, + {"beta": 2, "alpha": 3}, + ] + for kwargs in cases: + a = torch.randn(4, 8, dtype=torch.bfloat16, device=device) + b = torch.randn(8, 4, dtype=torch.bfloat16, device=device) + c = torch.randn(4, 4, dtype=torch.float32, device=device) + + ref = torch.ops.aten.addmm.dtype(c, a, b, torch.float32, **kwargs) + res = torch._decomp.decompositions.addmm_dtype( + c, a, b, out_dtype=torch.float32, **kwargs + ) + + self.assertEqual(res.dtype, torch.float32) + self.assertEqual(res, ref) + instantiate_device_type_tests(DecompOneOffTests, globals()) diff --git a/test/test_dlpack.py b/test/test_dlpack.py index ad7889a360c78..e5cd153d5b270 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -10,8 +10,10 @@ onlyCPU, onlyCUDA, onlyNativeDeviceTypes, + onlyOn, skipCUDAIfNotRocm, skipMeta, + skipXPUIf, ) from torch.testing._internal.common_dtype import ( all_mps_types_and, @@ -101,26 +103,26 @@ def _dlpack_conversion_with_streams(self, stream, x): # DLPack protocol that establishes correct stream order # does not behave as expected on Jetson stream.synchronize() - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): + stream = torch.Stream() + with stream: z = from_dlpack(x) stream.synchronize() return z @skipMeta - @onlyCUDA + @onlyOn(["xpu", "cuda"]) @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) def test_dlpack_conversion_with_streams(self, device, dtype): # Create a stream where the tensor will reside - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): + stream = torch.Stream() + with stream: # Do an operation in the actual stream x = make_tensor((5,), dtype=dtype, device=device) + 1 z = self._dlpack_conversion_with_streams(stream, x) self.assertEqual(z, x) @skipMeta - @onlyCUDA + @onlyOn(["xpu", "cuda"]) @dtypes( torch.float8_e5m2, torch.float8_e5m2fnuz, @@ -130,8 +132,8 @@ def test_dlpack_conversion_with_streams(self, device, dtype): torch.float4_e2m1fn_x2, ) def test_dlpack_conversion_with_streams_narrow_precision(self, device, dtype): - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): + stream = torch.Stream() + with stream: x = make_tensor((5,), dtype=torch.uint8, device=device) + 1 x = x.view(dtype) z = self._dlpack_conversion_with_streams(stream, x) @@ -361,7 +363,7 @@ def test_dlpack_invalid_cpu_stream(self): x.__dlpack__(stream=0) @skipMeta - @onlyCUDA + @onlyOn(["xpu", "cuda"]) @deviceCountAtLeast(2) def test_dlpack_tensor_on_different_device(self, devices): dev0, dev1 = devices[:2] @@ -372,7 +374,7 @@ def test_dlpack_tensor_on_different_device(self, devices): with self.assertRaisesRegex( BufferError, r"Can't export tensors on a different CUDA device" ): - with torch.cuda.device(dev1): + with torch.device(dev1): x.__dlpack__() # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required @@ -521,7 +523,8 @@ def test_copy(self, device): self._test_from_dlpack(device, out_device="cpu", copy=True) @skipMeta - @onlyCUDA + @onlyOn(["xpu", "cuda"]) + @skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/3077") def test_no_copy(self, device): # No copy, since tensor lives in the same device. self._test_from_dlpack(device) @@ -576,6 +579,7 @@ def test_dlpack_unsupported_dtype_error(self, device): @skipMeta @onlyNativeDeviceTypes + @skipXPUIf(True, "https://github.com/intel/torch-xpu-ops/issues/3074") def test_dlpack_exchange_api(self, device): """Comprehensive test of all DLPack Exchange API functions using inline C++""" # Check that the C API capsule exists and get it @@ -802,10 +806,13 @@ def test_dlpack_exchange_api(self, device): functions=["test_dlpack_exchange_api"], verbose=False, with_cuda=device.startswith("cuda"), + with_sycl=device.startswith("xpu"), ) # Run the comprehensive C++ test - module.test_dlpack_exchange_api(tensor, api_capsule, device.startswith("cuda")) + module.test_dlpack_exchange_api( + tensor, api_capsule, device.startswith(("cuda", "xpu")) + ) @skipMeta @onlyCUDA @@ -858,7 +865,7 @@ def test_numpy_cross_device_transfer(self, device): self.assertEqual(np_array2[0], 999) @skipMeta - @onlyCUDA + @onlyOn(["xpu", "cuda"]) @deviceCountAtLeast(2) def test_numpy_cross_device_multi_gpu(self, devices): """Test cross-device transfer to specific CUDA devices (cuda:0, cuda:1, etc).""" @@ -883,7 +890,9 @@ def test_numpy_cross_device_multi_gpu(self, devices): self.assertNotEqual(t0.device, t1.device) -instantiate_device_type_tests(TestTorchDlPack, globals(), allow_mps=True) +instantiate_device_type_tests( + TestTorchDlPack, globals(), allow_mps=True, allow_xpu=True +) if __name__ == "__main__": run_tests() diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 98a97fdf4d181..0236b0da53289 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -23,20 +23,23 @@ from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node from torch.fx.experimental.symbolic_shapes import ( _constrain_range_for_size, + _iterate_exprs, DimConstraints, DimDynamic, expect_true, + free_symbols, guard_bool, guard_float, guard_int, + guarding_hint_or_throw, GuardOnDataDependentSymNode, has_free_symbols, is_symbolic, ShapeEnv, - size_hint, StatelessSymbolicContext, statically_known_false, statically_known_true, + SYMPY_INTERP, ) from torch.testing._internal.common_dtype import all_types_and from torch.testing._internal.common_utils import ( @@ -918,6 +921,46 @@ def test_non_overlapping_and_dense_unbacked(self): ) ) + def test_sympy_interp_is_non_overlapping_and_dense_flat_args(self): + # SYMPY_INTERP is used as the eval() namespace for guard code strings. + # Guard code prints IsNonOverlappingAndDenseIndicator(s0, s1, ..., st0, st1, ...) + # with flat args, so the SYMPY_INTERP function must accept flat args. + interp_fn = SYMPY_INTERP["IsNonOverlappingAndDenseIndicator"] + + # 1D contiguous: sizes=(5,), strides=(1,) + self.assertEqual(interp_fn(5, 1), 1) + # 1D non-contiguous: sizes=(5,), strides=(2,) + self.assertEqual(interp_fn(5, 2), 0) + # 1D single element: sizes=(1,), strides=(42,) + self.assertEqual(interp_fn(1, 42), 1) + + # 2D contiguous: sizes=(3, 4), strides=(4, 1) + self.assertEqual(interp_fn(3, 4, 4, 1), 1) + # 2D non-contiguous: sizes=(3, 4), strides=(5, 1) -- gap in memory + self.assertEqual(interp_fn(3, 4, 5, 1), 0) + # 2D transposed but still dense: sizes=(4, 3), strides=(1, 4) + self.assertEqual(interp_fn(4, 3, 1, 4), 1) + + # 4D contiguous (the exact scenario from the MAST job failure): + # sizes=(2, 3, 4, 5), strides=(60, 20, 5, 1) + self.assertEqual(interp_fn(2, 3, 4, 5, 60, 20, 5, 1), 1) + # 4D non-contiguous: + self.assertEqual(interp_fn(2, 3, 4, 5, 100, 20, 5, 1), 0) + + def test_sympy_interp_guard_eval_simulation(self): + # Simulate the actual guard eval() path: guard code strings use + # IsNonOverlappingAndDenseIndicator as a function name, and SYMPY_INTERP + # provides the binding in the eval namespace. + guard_code = "IsNonOverlappingAndDenseIndicator(2, 3, 4, 5, 60, 20, 5, 1) == 1" + result = eval(guard_code, SYMPY_INTERP) # noqa: P204 + self.assertTrue(result) + + guard_code_false = ( + "IsNonOverlappingAndDenseIndicator(2, 3, 4, 5, 100, 20, 5, 1) == 1" + ) + result_false = eval(guard_code_false, SYMPY_INTERP) # noqa: P204 + self.assertFalse(result_false) + def test_prims_is_non_overlapping_and_dense_or_false(self): shape_env = ShapeEnv() cf = torch._prims_common.is_non_overlapping_and_dense_or_false @@ -1350,7 +1393,7 @@ def forward(self, a_1: "f32[s75, s96]", b_1: "f32[s57, s96]"): native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None getitem: "f32[s57 + s75, 2*s96]" = native_dropout[0] getitem_1: "b8[s57 + s75, 2*s96]" = native_dropout[1]; native_dropout = None - return (getitem, getitem_1)""", # noqa: B950 + return (getitem, getitem_1)""", ) def test_statically_known_true(self): @@ -2005,6 +2048,13 @@ def check(l, r): with self.assertRaises(ZeroDivisionError): y % x + # test pow operations preserve DynamicInt type + check(w**2, 1) # DynamicInt ** int + check(2**z, 4) # int ** DynamicInt + check(y**x, 1) # DynamicInt ** DynamicInt + check(pow(z, 2), 4) # pow(DynamicInt, int) + self.assertTrue(isinstance(pow(y, 3, 5), DynamicInt)) # pow with modulo + # math, numpy self.assertEqual(math.cos(x), y) self.assertEqual(math.prod([z, z], start=z), 8) @@ -3161,16 +3211,24 @@ def test_guards_gt_lt(self): guards = shape_env.produce_guards_expression([s0]) - self.assertTrue(shape_env.evaluate_guards_expression(guards, [size_hint(s0)])) - self.assertFalse(shape_env.evaluate_guards_expression(guards, [size_hint(s1)])) - self.assertFalse(shape_env.evaluate_guards_expression(guards, [size_hint(s2)])) + self.assertTrue( + shape_env.evaluate_guards_expression(guards, [guarding_hint_or_throw(s0)]) + ) + self.assertFalse( + shape_env.evaluate_guards_expression(guards, [guarding_hint_or_throw(s1)]) + ) + self.assertFalse( + shape_env.evaluate_guards_expression(guards, [guarding_hint_or_throw(s2)]) + ) def test_guards_float_print(self): shape_env = ShapeEnv() s0 = create_symint(shape_env, 3) guard_bool(2 / s0 == 2 / 3) guards = shape_env.produce_guards_expression([s0]) - self.assertTrue(shape_env.evaluate_guards_expression(guards, [size_hint(s0)])) + self.assertTrue( + shape_env.evaluate_guards_expression(guards, [guarding_hint_or_throw(s0)]) + ) @skipIfTorchDynamo("Not a TorchDynamo suitable test") @torch._dynamo.config.patch("capture_scalar_outputs", True) @@ -3314,8 +3372,12 @@ def test_guards_float_div(self): self.assertIn("math.trunc(", guards) self.assertIn("float(", guards) - self.assertTrue(shape_env.evaluate_guards_expression(guards, [size_hint(s0)])) - self.assertFalse(shape_env.evaluate_guards_expression(guards, [size_hint(s1)])) + self.assertTrue( + shape_env.evaluate_guards_expression(guards, [guarding_hint_or_throw(s0)]) + ) + self.assertFalse( + shape_env.evaluate_guards_expression(guards, [guarding_hint_or_throw(s1)]) + ) @unittest.skipIf( TEST_XPU, "Skipped on XPU" @@ -3511,6 +3573,28 @@ def test_has_free_symbols(self): self.assertTrue(has_free_symbols(sympy.sympify("a*2"))) self.assertTrue(has_free_symbols(sympy.sympify("a+b"))) + def test_iterate_exprs_dict(self): + """Test that _iterate_exprs handles dict values (e.g. from triton_kernel_wrapper_functional).""" + a, b = sympy.symbols("a b") + + # dict with tensor values and string keys — should not crash + t = torch.randn(3, 4) + result = list(_iterate_exprs({"Out": t})) + # concrete tensor has no symbolic exprs + self.assertEqual(len(result), 0) + + # dict with sympy keys — should iterate over both keys and values + result = list(_iterate_exprs({a: 1, b: 2})) + self.assertEqual(len(result), 2) + self.assertIn(a, result) + self.assertIn(b, result) + + # free_symbols works on dicts with sympy keys + self.assertEqual(free_symbols({a + b: 1}), {a, b}) + + # empty dict + self.assertEqual(list(_iterate_exprs({})), []) + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135") @torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) @@ -3713,6 +3797,39 @@ def test_meta_copy(self): meta_copy_(self_tensor, src_tensor) + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_unbacked_symbool_propagate_real_tensors(self): + """ + Test that propagate_real_tensors properly handles SymBool from boolean .item() calls. + + When tracing with propagate_real_tensors=True, if a boolean .item() returns False + during tracing, a runtime assertion should be generated that throws when the + boolean becomes True at runtime. + + This is a regression test for a bug where int(real_t) was used instead of real_t, + causing boolean values to be incorrectly converted to integers and losing the + proper runtime assertion. + """ + + @torch.compile(backend="eager", fullgraph=True) + def f(x): + if x.eq(0.1).any().item(): + return x + return x + 1 + + with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True): + # First call with 0.2 - item() returns False, traces the x + 1 branch + result1 = f(torch.ones(2) * 0.2) + torch.testing.assert_close(result1, torch.ones(2) * 1.2) + + # Second call with 0.1 - item() returns True, should throw runtime assertion + # because the traced graph assumed False + with self.assertRaisesRegex( + RuntimeError, + r"Runtime assertion failed for expression Ne\(u0, 1\)", + ): + f(torch.ones(2) * 0.1) + class TestUbackedOps(TestCase): @fresh_cache() @@ -3773,7 +3890,7 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", clone: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.clone.default(view_2); view_2 = None mul_11: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None mul_14: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None - return (mul_11, mul_14, clone)""", # noqa: B950 + return (mul_11, mul_14, clone)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -3814,7 +3931,7 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1] clone: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.clone.default(view_2); view_2 = None mul_6: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None - return (mul_6, mul_9, clone)""", # noqa: B950 + return (mul_6, mul_9, clone)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -3884,7 +4001,7 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None - return (mul_21,)""", # noqa: B950 + return (mul_21,)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -3915,7 +4032,7 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() self.assertExpectedInline( aot_graphs, - """""", # noqa: B950 + """""", ignore_comments=True, ignore_empty_lines=True, ) @@ -4018,7 +4135,7 @@ def f3(x, xs): sym_storage_offset_default: "Sym(u3)" = torch.ops.aten.sym_storage_offset.default(slice_1) ge_2: "Sym(u3 >= 0)" = sym_storage_offset_default >= 0; sym_storage_offset_default = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_2 = None - return (slice_1,)""", # noqa: B950 + return (slice_1,)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -4308,7 +4425,7 @@ def func(x): clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None - return (mul_6,)""", # noqa: B950 + return (mul_6,)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -4336,7 +4453,7 @@ def func(x): _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_1 = None add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None - return (mul_5,)""", # noqa: B950 + return (mul_5,)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -4374,7 +4491,7 @@ def func(x, y): select: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense) select_1: "f32[s77, s77][s77**2, 1]cpu" = torch.ops.aten.select.int(arg2_1, 1, _local_scalar_dense) select_2: "f32[s77, s77][s77**2, s77]cpu" = torch.ops.aten.select.int(arg2_1, 2, _local_scalar_dense); arg2_1 = _local_scalar_dense = None - return (select, select_1, select_2)""", # noqa: B950 + return (select, select_1, select_2)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -4525,7 +4642,7 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None _reshape_copy: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten._reshape_copy.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None - return (_reshape_copy,)""", # noqa: B950 + return (_reshape_copy,)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -4561,7 +4678,7 @@ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1] eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None _reshape_copy: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten._reshape_copy.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None - return (_reshape_copy,)""", # noqa: B950 + return (_reshape_copy,)""", ignore_comments=True, ignore_empty_lines=True, ) @@ -4838,15 +4955,21 @@ def fn(x): def test_hint_override_consistent_stride1(self): @torch.compile(fullgraph=True, dynamic=True) def func(x): - a = torch.fx.experimental.symbolic_shapes.size_hint(x.size()[2]) - b = torch.fx.experimental.symbolic_shapes.size_hint(x.stride()[1]) + a = torch.fx.experimental.symbolic_shapes.guarding_hint_or_throw( + x.size()[2] + ) + b = torch.fx.experimental.symbolic_shapes.guarding_hint_or_throw( + x.stride()[1] + ) torch._check(a == b) torch._check(a == 6) - a = torch.fx.experimental.symbolic_shapes.size_hint( + a = torch.fx.experimental.symbolic_shapes.guarding_hint_or_throw( x.size()[1] * x.size()[2] ) - b = torch.fx.experimental.symbolic_shapes.size_hint(x.stride()[0]) + b = torch.fx.experimental.symbolic_shapes.guarding_hint_or_throw( + x.stride()[0] + ) torch._check(a == b) torch._check(a == 120) @@ -4864,10 +4987,12 @@ def test_hint_override_consistent_stride2(self): @torch.compile(fullgraph=True, dynamic=True) def func(x): # only one of the sizes has hint overridden. - a = torch.fx.experimental.symbolic_shapes.size_hint( + a = torch.fx.experimental.symbolic_shapes.guarding_hint_or_throw( x.size()[1] * x.size()[2] ) - b = torch.fx.experimental.symbolic_shapes.size_hint(x.stride()[0]) + b = torch.fx.experimental.symbolic_shapes.guarding_hint_or_throw( + x.stride()[0] + ) torch._check(a == b) torch._check(a == 24) @@ -4880,20 +5005,20 @@ def func(x): torch._dynamo.mark_dynamic(x, 2, hint_override=6) func(x) - def test_size_hint(self): + def test_optimization_hint(self): @torch.compile(fullgraph=True) def func(x): u0 = x.item() a = torch.ones([u0]) torch._check( - torch.fx.experimental.symbolic_shapes.size_hint( + torch.fx.experimental.symbolic_shapes.optimization_hint( a.size()[0], fallback=300 ) == 300 ) b = torch.ones([x.item() * 2]) torch._check( - torch.fx.experimental.symbolic_shapes.size_hint( + torch.fx.experimental.symbolic_shapes.optimization_hint( b.size()[0], fallback=300 ) == 600 @@ -4903,13 +5028,16 @@ def func(x): func(torch.tensor([33])) - def test_size_hint_no_fallback(self): + def test_guarding_hint_or_throw(self): @torch.compile(fullgraph=True) def func(x): u0 = x.item() a = torch.ones([u0]) torch._check( - torch.fx.experimental.symbolic_shapes.size_hint(a.size()[0]) == 300 + torch.fx.experimental.symbolic_shapes.guarding_hint_or_throw( + a.size()[0] + ) + == 300 ) return a * 10 diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 18a4ad502e4fb..a543ea56258e7 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -10,6 +10,8 @@ import io import itertools import pickle +import subprocess +import sys import unittest import weakref from unittest.mock import patch @@ -64,7 +66,9 @@ skipIfCrossRef, skipIfTorchDynamo, skipIfWindows, + skipIfXpu, TemporaryFileName, + TEST_ACCELERATOR, TEST_WITH_TORCHDYNAMO, TestCase, xfailIfTorchDynamo, @@ -82,6 +86,8 @@ torch._dynamo.config.fake_tensor_cache_enabled = True torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True +device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu" + def expectedFailurePropagateRealTensors(fn): fn._expected_failure_propagate_real_tensors = True @@ -350,6 +356,7 @@ def test_device_inplace_copy(self): if y.copy_(x).device.type != "cuda": raise AssertionError("expected cuda device") + @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_fake_device(self): t = torch.ones(3) t = t.view(1, 3) @@ -600,6 +607,16 @@ def fn(x): out_eager = fn(torch.empty((0,))) self.checkMetaProps(out_fake, out_eager) + def test_as_strided_negative_stride_error(self): + error = ( + r"as_strided: Negative strides are not supported at the " + r"moment, got strides: \[-?[0-9]+(, -?[0-9]+)*\]" + ) + with FakeTensorMode(): + x = torch.empty(0) + with self.assertRaisesRegex(RuntimeError, error): + torch.as_strided(x, (17, 18), (-80, 1), 1) + @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_cpu_fallback(self): with FakeTensorMode(allow_fallback_kernels=False): @@ -1244,6 +1261,123 @@ def test_unbind_copy_out(self): self.assertEqual(out[1].dtype, eye.dtype) self.assertEqual(out[2].dtype, eye.dtype) + @unittest.skipIf(not torch.cuda._is_compiled(), "requires CUDA-compiled PyTorch") + def test_fake_device_guard_no_use_after_free(self): + # Regression test: when CUDA is compiled but no devices are visible, + # FakeTensorMode installs a FakeGuardImpl into the global + # device_guard_impl_registry. The impl must outlive all threads that + # use it; previously it was stored as a thread-local unique_ptr, so + # the pointer became dangling after the thread exited, causing a + # segfault when a later thread called deviceCount() on it. + # + # CUDA_VISIBLE_DEVICES must be set before torch is imported, so we + # run the repro in a subprocess. + + script = """\ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "" +import threading +import torch + +def f(x): + return x * x + +def run_compile(): + g = torch.compile(f, backend="eager") + for i in range(10): + g(torch.randn(i)) + +threads = [threading.Thread(target=run_compile) for _ in range(2)] +for t in threads: + t.start() +for t in threads: + t.join() + +torch._dynamo.reset() + +threads = [threading.Thread(target=run_compile) for _ in range(2)] +for t in threads: + t.start() +for t in threads: + t.join() +""" + result = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, + timeout=60, + ) + self.assertEqual( + result.returncode, + 0, + msg=f"subprocess failed:\n{result.stderr.decode()}", + ) + + @unittest.skipIf( + TEST_ACCELERATOR, "Only execute when an accelerator is not present" + ) + def test_avoid_device_init_without_backends(self): + fake_mode = FakeTensorMode() + + self.assertTrue( + fake_mode.avoid_device_init, + "Expected avoid_device_init to return True when no backends are registered", + ) + + @parametrize("is_available", [False, True]) + @skipIfTorchDynamo( + "TorchDynamo exposes https://github.com/pytorch/pytorch/issues/166696" + ) + @unittest.skipIf( + TEST_ACCELERATOR, "Only execute when an accelerator is not present" + ) + def test_avoid_device_init_with_privateuse1_backend(self, is_available): + class _DummyPrivateUse1Module: + @staticmethod + def is_available() -> bool: + return is_available + + backend_name = "privateuseone" + + try: + torch._register_device_module(backend_name, _DummyPrivateUse1Module) + + fake_mode = FakeTensorMode() + + self.assertEqual(fake_mode.avoid_device_init, not is_available) + finally: + delattr(torch, backend_name) + del sys.modules[f"torch.{backend_name}"] + + def test_unique_output_dtype(self): + shape_env = ShapeEnv() + for input_dtype in [torch.float32, torch.float16, torch.bfloat16]: + x_real = torch.randn(10, dtype=input_dtype) + real_unique, real_inverse, real_counts = torch.unique( + x_real, return_inverse=True, return_counts=True + ) + with FakeTensorMode(shape_env=shape_env): + x = torch.randn(10, dtype=input_dtype) + fake_unique, fake_inverse, fake_counts = torch.unique( + x, return_inverse=True, return_counts=True + ) + self.assertEqual(fake_unique.dtype, real_unique.dtype) + self.assertEqual(fake_inverse.dtype, real_inverse.dtype) + self.assertEqual(fake_counts.dtype, real_counts.dtype) + + # Also test with dim argument + x_real = torch.randn(3, 4, dtype=torch.float32) + real_unique, real_inverse, real_counts = torch.unique( + x_real, dim=0, return_inverse=True, return_counts=True + ) + with FakeTensorMode(shape_env=shape_env): + x = torch.randn(3, 4, dtype=torch.float32) + fake_unique, fake_inverse, fake_counts = torch.unique( + x, dim=0, return_inverse=True, return_counts=True + ) + self.assertEqual(fake_unique.dtype, real_unique.dtype) + self.assertEqual(fake_inverse.dtype, real_inverse.dtype) + self.assertEqual(fake_counts.dtype, real_counts.dtype) + instantiate_parametrized_tests(FakeTensorTest) @@ -1499,6 +1633,62 @@ def test_no_ref_cycle(self): if y_weak() is not None: raise AssertionError("expected y_weak() is None") + def test_grad_dtype_preserved(self): + t = torch.randn(4, dtype=torch.bfloat16, requires_grad=True) + t.grad_dtype = torch.float32 + + mode = FakeTensorMode() + fake_t = mode.from_tensor(t) + self.assertEqual(fake_t.grad_dtype, torch.float32) + + def test_grad_dtype_none_preserved(self): + t = torch.randn(4, dtype=torch.bfloat16, requires_grad=True) + t.grad_dtype = None + + mode = FakeTensorMode() + fake_t = mode.from_tensor(t) + self.assertIsNone(fake_t.grad_dtype) + + def test_grad_dtype_functional_tensor_no_crash(self): + from torch._subclasses.functional_tensor import ( + FunctionalTensor, + FunctionalTensorMode, + ) + + t = torch.randn(4, dtype=torch.bfloat16, requires_grad=True) + t.grad_dtype = torch.float32 + + mode = FakeTensorMode() + fake_t = mode.from_tensor(t) + self.assertEqual(fake_t.grad_dtype, torch.float32) + + with FunctionalTensorMode(): + func_t = FunctionalTensor.to_functional(fake_t) + # Re-fakifying a FunctionalTensor should not crash even though + # the inner tensor has a custom grad_dtype. + re_faked = mode.from_tensor(func_t) + self.assertTrue(re_faked.requires_grad) + + @skipIfTorchDynamo("make_fx tracing is incompatible with dynamo") + def test_grad_dtype_make_fx(self): + def train_step(w): + y = (w.float() * 2).sum() + (g,) = torch.autograd.grad(y, w) + return g + + w = torch.randn(4, dtype=torch.bfloat16, requires_grad=True) + w.grad_dtype = torch.float32 + + g_eager = train_step(w) + + fake_mode = FakeTensorMode() + w_fake = fake_mode.from_tensor(w) + with fake_mode: + gm = make_fx(train_step)(w_fake) + g_traced = gm(w) + + self.assertEqual(g_eager.dtype, g_traced.dtype) + make_propagate_real_tensors_cls(FakeTensorConverterTest) @@ -1644,6 +1834,7 @@ def test_cross_entropy_loss(self): self.assertEqual(ref.size(), meta_out.size()) + @skipIfXpu(msg="MetadataMismatchError, torch-xpu-ops: 2802") @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware", @@ -1660,14 +1851,14 @@ def forward(self, arg1, arg2, arg3): args_new = [ [ - ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"), - ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"), - ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, "cuda"), + ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, device_type), + ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, device_type), + ((1, 48, 64, 64), (0, 4096, 64, 1), torch.float16, device_type), ], [ - ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"), - ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"), - ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, "cuda"), + ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, device_type), + ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, device_type), + ((4, 2, 16, 32), (1024, 512, 32, 1), torch.float16, device_type), ], ] for args_list in args_new: @@ -2861,6 +3052,23 @@ def test_fake_tensor_prefer_device_type_cpu_only(self): self.assertTrue(isinstance(result, FakeTensor)) +class FakeTensorMetaDevicePropagation(TestCase): + @parametrize("device", ["cpu", "cuda"]) + def test_inplace_add_with_meta_rhs_keeps_destination_device(self, device): + if device == "cuda" and not RUN_CUDA: + self.skipTest("requires cuda") + + with FakeTensorMode(): + log_det = torch.zeros(2, device=device) + log_det += torch.zeros(2, device="meta") + + self.assertEqual(log_det.device.type, device) + self.assertTrue(isinstance(log_det, FakeTensor)) + + +instantiate_parametrized_tests(FakeTensorMetaDevicePropagation) + + class FakeTensorViewCopy(TestCase): def test_expand_then_view_copy_matches_eager_mode(self): x = torch.arange(7) diff --git a/test/test_flop_counter.py b/test/test_flop_counter.py index d06d295790233..912883173b8fa 100644 --- a/test/test_flop_counter.py +++ b/test/test_flop_counter.py @@ -20,6 +20,7 @@ TestCase, ) from torch.testing._internal.triton_utils import requires_cuda_and_triton +from torch.utils.flop_counter import sdpa_flop_count try: @@ -51,8 +52,6 @@ def T(*shape, requires_grad=False): class TestFlopCounter(TestCase): def test_sdpa_flop_count_gqa(self): """sdpa_flop_count should handle GQA where KV heads < Q heads.""" - from torch.utils.flop_counter import sdpa_flop_count - # MHA: q_heads == kv_heads q_shape = (2, 32, 128, 64) k_shape = (2, 32, 128, 64) @@ -1175,6 +1174,94 @@ def test_scaled_mm(self): self.assertExpectedInline(get_total_flops(mode), """860160""") + @unittest.skipIf(not HAS_CUDA, "CUDA not available") + @unittest.skipIf( + not PLATFORM_SUPPORTS_FLASH_ATTENTION, + "Flash attention not supported (pre-SM80 hardware on CUDA)", + ) + def test_varlen_attn(self): + import torch.nn.attention.varlen + + n_heads = 8 + head_dim = 64 + dtype = torch.float16 + seq_lens = [128, 64] + total_tokens = sum(seq_lens) + cu_seqs = torch.tensor( + [0] + list(torch.tensor(seq_lens).cumsum(0).tolist()), + dtype=torch.int32, + device="cuda", + ) + max_s = max(seq_lens) + + query = torch.randn( + total_tokens, + n_heads, + head_dim, + device="cuda", + dtype=dtype, + requires_grad=True, + ) + key = torch.randn( + total_tokens, + n_heads, + head_dim, + device="cuda", + dtype=dtype, + requires_grad=True, + ) + value = torch.randn( + total_tokens, + n_heads, + head_dim, + device="cuda", + dtype=dtype, + requires_grad=True, + ) + + mode = FlopCounterMode() + with mode: + out, _, _ = torch.ops.torch_attn._varlen_attn( + query, + key, + value, + cu_seqs, + cu_seqs, + max_s, + max_s, + is_causal=True, + ) + fw_flops = int(get_total_flops(mode)) + expected_fw = sum( + sdpa_flop_count( + (1, n_heads, s, head_dim), + (1, n_heads, s, head_dim), + (1, n_heads, s, head_dim), + ) + for s in seq_lens + ) + self.assertEqual(fw_flops, expected_fw) + # 2 bmms per sequence, each 2*h*s*d*s; total = 2048*(128^2 + 64^2) = 41943040 + self.assertExpectedInline(str(fw_flops), """41943040""") + + mode_bw = FlopCounterMode() + with mode_bw: + out, _, _ = torch.ops.torch_attn._varlen_attn( + query, + key, + value, + cu_seqs, + cu_seqs, + max_s, + max_s, + is_causal=True, + ) + out.sum().backward() + fw_bw_flops = int(get_total_flops(mode_bw)) + # fw=2 bmms, bw=5 bmms (flash recomputes scores), fw+bw = fw * 7/2 + self.assertEqual(fw_bw_flops, fw_flops * 7 // 2) + self.assertExpectedInline(str(fw_bw_flops), """146800640""") + if __name__ == "__main__": run_tests() diff --git a/test/test_foreach.py b/test/test_foreach.py index 5b6a84932f4e4..a1511eee1368a 100644 --- a/test/test_foreach.py +++ b/test/test_foreach.py @@ -859,12 +859,12 @@ def test_binary_op_float_inf_nan(self, device, dtype, op): # note: Below three tests (postfixed with `_tensors_on_different_devices`) # checks whether foreach works with lists of tensors on different devices - # but tensors of the same index are on the same device, e.g., ['cuda', 'cpu]. + # but tensors of the same index are on the same device, e.g., ['cuda', 'cpu']. @onlyCUDA @ops(foreach_unary_op_db) def test_unary_op_tensors_on_different_devices(self, device, dtype, op): method, ref, inplace_method, ref_inplace = self._get_funcs(op) - # tensors: ['cuda', 'cpu] + # tensors: ['cuda', 'cpu'] tensors = next( iter( op.sample_inputs( @@ -876,26 +876,28 @@ def test_unary_op_tensors_on_different_devices(self, device, dtype, op): ) ).input tensors[1] = tensors[1].to("cpu") - if not op.supports_out: - try: - actual = method((tensors,), False, False, zero_size=False) - except RuntimeError as e: - with self.assertRaisesRegex(type(e), str(e).splitlines()[0]): - ref((tensors,)) - else: - expected = ref((tensors,)) - self.assertEqual(expected, actual) try: - inplace_method((tensors,), False, False, zero_size=False) + actual = method((tensors,), False, False, zero_size=False) except RuntimeError as e: with self.assertRaisesRegex(type(e), str(e).splitlines()[0]): - ref_inplace((tensors,)) + ref((tensors,)) else: - if not op.supports_out: - self.assertEqual(expected, tensors) + expected = ref((tensors,)) + self.assertEqual(expected, actual) + + # Some foreach functions (e.g. _foreach_clone) don't have an inplace variant, so + # we explicitly test for that here. + if not inplace_method.is_inplace: + self.assertIsNone(ref_inplace.func) + else: + try: + inplace_method((tensors,), False, False, zero_size=False) + except RuntimeError as e: + with self.assertRaisesRegex(type(e), str(e).splitlines()[0]): + ref_inplace((tensors,)) else: - self.assertEqual([torch.zeros_like(t) for t in tensors], tensors) + self.assertEqual(expected, tensors) @onlyCUDA @ops(filter(lambda op: op.supports_out, foreach_binary_op_db)) @@ -1057,11 +1059,11 @@ def test_big_num_tensors(self, device, dtype, op, use_cuda_graph, w_empty): # foreach_max cannot handle empty tensors as max requires an identity intersperse_empty_tensors = w_empty and op.name != "_foreach_max" - N = 600 + N = 4000 indices_with_empty_tensors = ( set() if not intersperse_empty_tensors - else {200, 300, 301, 400, 401, 402, 404, 598} + else {200, 1500, 1501, 2800, 2801, 2802, 3500, 3998} ) tensorlist = [ make_tensor((2, 3), dtype=dtype, device=device, noncontiguous=False) @@ -1074,7 +1076,7 @@ def test_big_num_tensors(self, device, dtype, op, use_cuda_graph, w_empty): import math if op.name == "_foreach_norm": - ords = [1, 2] + ords = [0, 1, 2] if not intersperse_empty_tensors: # inf norm over an empty tensor is not defined by vector norm as it expects an identity ords.append(math.inf) @@ -1109,7 +1111,7 @@ def test_big_num_tensors(self, device, dtype, op, use_cuda_graph, w_empty): @ops(foreach_reduce_op_db) @parametrize("w_empty", (False, True)) def test_foreach_reduce_large_input(self, device, dtype, op, w_empty): - # test inputs larger than kChunkSize (65536) * max_num_blocks (320) + # test inputs larger than kChunkSize (65536) * max_num_blocks (2240 for the 32kb config) N = 65536 * 320 * 2 disable_fastpath = False kwargs = {} @@ -1168,6 +1170,25 @@ def test_foreach_norm_empty_tensor_inf_error(self, device, dtype, op): self.assertEqual(result_l1[0].item(), 0.0) self.assertEqual(result_l2[0].item(), 0.0) + @ops( + [o for o in foreach_reduce_op_db if o.name == "_foreach_max"], + ) + def test_foreach_max_empty_tensor_error(self, device, dtype, op): + """Test that _foreach_max errors on empty tensors""" + # Test with single empty tensor + empty_tensor = torch.empty(0, dtype=dtype, device=device) + err_re = ( + "_foreach_max cannot compute the maximum of an empty tensor; " + "max over zero elements is undefined\\." + ) + with self.assertRaisesRegex(RuntimeError, err_re): + torch._foreach_max([empty_tensor]) + + # Test with mixed empty and non-empty tensors + non_empty_tensor = make_tensor((4,), dtype=dtype, device=device) + with self.assertRaisesRegex(RuntimeError, err_re): + torch._foreach_max([empty_tensor, non_empty_tensor]) + @onlyCUDA @ops( foreach_unary_op_db diff --git a/test/test_functionalization.py b/test/test_functionalization.py index 65e74297a531f..87034ca34338b 100644 --- a/test/test_functionalization.py +++ b/test/test_functionalization.py @@ -270,7 +270,7 @@ def forward(self, arg0_1): detach_copy_1 = torch.ops.aten.detach_copy.default(view_copy_11); view_copy_11 = None return detach_copy_1 """, - ) # noqa: B950 + ) def test_simple(self): def f(x): @@ -561,7 +561,7 @@ def forward(self, arg0_1): getitem_5 = _fused_moving_avg_obs_fq_helper_functional[5]; _fused_moving_avg_obs_fq_helper_functional = None copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_5); arg0_1 = getitem_5 = copy_ = None return (getitem, getitem_1) - """, # noqa: B950 + """, ) def test_as_strided(self): @@ -820,7 +820,7 @@ def forward(self, arg0_1): copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None return diagonal_copy_1 """, - ) # noqa: B950 + ) # NB: even with reapply_views=True, we expect to see scatter op reinplaced_logs = self.get_logs( @@ -852,7 +852,7 @@ def forward(self, arg0_1): copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None return diagonal_1 """, - ) # noqa: B950 + ) def test_split_with_sizes(self): def f(x): @@ -892,7 +892,7 @@ def forward(self, arg0_1): copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None return diagonal_copy_1 """, - ) # noqa: B950 + ) # NB: even with reapply_views=True, we expect to see scatter op reinplaced_logs = self.get_logs( @@ -924,7 +924,7 @@ def forward(self, arg0_1): copy_ = torch.ops.aten.copy_.default(arg0_1, slice_scatter); arg0_1 = slice_scatter = copy_ = None return diagonal_1 """, - ) # noqa: B950 + ) def test_slice(self): def f(x): @@ -955,7 +955,7 @@ def forward(self, arg0_1): transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None return transpose_copy_4 """, - ) # noqa: B950 + ) # NB: even with reapply_views=True, we expect to see scatter op reinplaced_logs = self.get_logs( @@ -980,7 +980,7 @@ def forward(self, arg0_1): transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None return transpose_4 """, - ) # noqa: B950 + ) def test_view_inplace(self): def f(x): @@ -1012,7 +1012,7 @@ def forward(self, arg0_1): transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None return transpose_copy_4 """, - ) # noqa: B950 + ) # NB: even with reapply_views=True, we expect to see scatter op reinplaced_logs = self.get_logs( @@ -1037,7 +1037,7 @@ def forward(self, arg0_1): transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None return transpose_4 """, - ) # noqa: B950 + ) def test_unbind(self): def f(x): @@ -1073,7 +1073,7 @@ def forward(self, arg0_1): transpose_copy_4 = torch.ops.aten.transpose_copy.int(transpose_copy_2, 1, 0); transpose_copy_2 = None return transpose_copy_4 """, - ) # noqa: B950 + ) # NB: even with reapply_views=True, we expect to see scatter op reinplaced_logs = self.get_logs( @@ -1102,7 +1102,7 @@ def forward(self, arg0_1): transpose_4 = torch.ops.aten.transpose.int(transpose_2, 1, 0); transpose_2 = None return transpose_4 """, - ) # noqa: B950 + ) def test_optional_tensor_list(self): def f(x): @@ -1132,7 +1132,7 @@ def forward(self, arg0_1): copy_ = torch.ops.aten.copy_.default(arg0_1, view_copy_1); arg0_1 = view_copy_1 = copy_ = None return view_copy_2 """, - ) # noqa: B950 + ) def test_scalars(self): def f(x): @@ -1205,7 +1205,7 @@ def forward(self, arg0_1): _to_copy = torch.ops.aten._to_copy.default(ge, dtype = torch.float32, layout = torch.strided); ge = None return _to_copy """, - ) # noqa: B950 + ) @skipIfTorchDynamo("Test does not work with TorchDynamo") def test_metadata_change_out_op(self): @@ -1318,7 +1318,7 @@ def forward(self, arg0_1): add_2 = torch.ops.aten.add.Tensor(select_copy_1, view_copy_13); select_copy_1 = view_copy_13 = add_2 = None return getitem_2 """, - ) # noqa: B950 + ) reinplaced_logs = self.get_logs( f, torch.ones(4, 2), reapply_views=True, run_reinplace=True @@ -1535,7 +1535,7 @@ def forward(self, arg0_1): diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None return diagonal_copy_2 """, - ) # noqa: B950 + ) reinplaced_logs = self.get_logs( f, torch.ones(2, dtype=torch.long), reapply_views=True, run_reinplace=True @@ -1555,7 +1555,7 @@ def forward(self, arg0_1): diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None return diagonal_2 """, - ) # noqa: B950 + ) # Test 4: copy_() with different dtype, different shape self.assert_functionalization(f, torch.ones(1, dtype=torch.long)) @@ -1577,7 +1577,7 @@ def forward(self, arg0_1): diagonal_copy_2 = torch.ops.aten.diagonal_copy.default(diagonal_scatter_1); diagonal_scatter_1 = None return diagonal_copy_2 """, - ) # noqa: B950 + ) reinplaced_logs = self.get_logs( f, torch.ones(1, dtype=torch.long), reapply_views=True, run_reinplace=True @@ -1597,7 +1597,7 @@ def forward(self, arg0_1): diagonal_2 = torch.ops.aten.diagonal.default(zeros); zeros = None return diagonal_2 """, - ) # noqa: B950 + ) def test_expand_symint(self): # Once some existing SymInt bugs are ironed out, we should update @@ -1700,7 +1700,7 @@ def forward(self, arg0_1): as_strided_copy_3 = torch.ops.aten.as_strided_copy.default(view_copy_7, [3, 3], [3, 1]); view_copy_7 = None add_2 = torch.ops.aten.add.Tensor(as_strided_copy_3, 1); as_strided_copy_3 = None return add_2 - """, # noqa: B950 + """, ) reinplaced_logs = self.get_logs( @@ -1888,7 +1888,7 @@ def forward(self, arg0_1): select_copy_1 = torch.ops.aten.select_copy.int(select_scatter, 0, 5); select_copy_1 = None return select_scatter """, - ) # noqa: B950 + ) reinplaced_logs = self.get_logs( f, torch.ones(2), reapply_views=True, run_reinplace=True @@ -1970,7 +1970,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): copy_ = torch.ops.aten.copy_.default(arg1_1, alias_copy_1); arg1_1 = alias_copy_1 = copy_ = None copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_copy_4); arg2_1 = alias_copy_4 = copy__1 = None return view_copy_5 - """, # noqa: B950 + """, ) reinplaced_logs = self.get_logs( @@ -2016,7 +2016,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): copy_ = torch.ops.aten.copy_.default(arg1_1, alias_1); arg1_1 = alias_1 = copy_ = None copy__1 = torch.ops.aten.copy_.default(arg2_1, alias_4); arg2_1 = alias_4 = copy__1 = None return view_5 - """, # noqa: B950 + """, ) def test_mutation_overlapping_mem(self): @@ -2064,7 +2064,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = copy_ = None copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = copy__1 = None return getitem - """, # noqa: B950 + """, ) reinplaced_logs = self.get_logs( @@ -2092,7 +2092,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): copy_ = torch.ops.aten.copy_.default(arg1_1, getitem_3); arg1_1 = getitem_3 = copy_ = None copy__1 = torch.ops.aten.copy_.default(arg2_1, getitem_4); arg2_1 = getitem_4 = copy__1 = None return getitem - """, # noqa: B950 + """, ) # This tests our python shims around C++ Functionalization: FunctionalTensor and FunctionalTensorMode diff --git a/test/test_fx.py b/test/test_fx.py index 7e0cf1138b1f5..9a7cea8cd269e 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -71,9 +71,12 @@ find_library_location, IS_FBCODE, IS_MACOS, + IS_ARM64, + IS_LINUX, IS_WINDOWS, run_tests, skipIfTorchDynamo, + xfailIf, ) from torch.testing._internal.jit_utils import JitTestCase @@ -371,7 +374,7 @@ def forward(self, *args, **kwargs): def test_args_kwargs_no_self(self): class T(torch.nn.Module): - def forward(*args, **kwargs): # noqa: B902 + def forward(*args, **kwargs): self = args[0] return torch.relu(args[1]) @@ -2054,6 +2057,7 @@ def is_leaf_module(self, m, module_qualified_name): f"got {tensor_meta[1].shape}" ) + @xfailIf(IS_ARM64 and IS_LINUX) # RuntimeError: label is too far def test_shape_prop_layout_3d(self): class ConvTest3d(torch.nn.Module): def __init__(self) -> None: @@ -4427,6 +4431,7 @@ def fn(a, b, c, d): # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag + # This only fails on navi31 @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @torch.fx.experimental._config.patch("enrich_profiler_metadata", True) @blas_library_context("cublaslt") @@ -4478,8 +4483,29 @@ def forward(self, x): else: kernel_event = "cudaLaunchKernel" kernel_event_relu = "cudaLaunchKernel" - - expected = f"""\ + if IS_WINDOWS: + expected = f"""\ +event=aten::t node=t stack_trace=return F.linear(input, self.weight, self.bias) +event=aten::transpose node=t stack_trace=return F.linear(input, self.weight, self.bias) +event=aten::as_strided node=t stack_trace=return F.linear(input, self.weight, self.bias) +event=aten::addmm node=addmm stack_trace=return F.linear(input, self.weight, self.bias) +event=aten::expand node=addmm stack_trace=return F.linear(input, self.weight, self.bias) +event=aten::as_strided node=addmm stack_trace=return F.linear(input, self.weight, self.bias) +event={kernel_event} node=addmm stack_trace=return F.linear(input, self.weight, self.bias) +event={kernel_event} node=addmm stack_trace=return F.linear(input, self.weight, self.bias) +event=aten::relu node=relu stack_trace=return F.relu(input, inplace=self.inplace) +event=aten::clamp_min node=relu stack_trace=return F.relu(input, inplace=self.inplace) +event={kernel_event_relu} node=relu stack_trace=return F.relu(input, inplace=self.inplace) +event=aten::t node=t_1 stack_trace=return F.linear(input, self.weight, self.bias) +event=aten::transpose node=t_1 stack_trace=return F.linear(input, self.weight, self.bias) +event=aten::as_strided node=t_1 stack_trace=return F.linear(input, self.weight, self.bias) +event=aten::addmm node=addmm_1 stack_trace=return F.linear(input, self.weight, self.bias) +event=aten::expand node=addmm_1 stack_trace=return F.linear(input, self.weight, self.bias) +event=aten::as_strided node=addmm_1 stack_trace=return F.linear(input, self.weight, self.bias) +event={kernel_event} node=addmm_1 stack_trace=return F.linear(input, self.weight, self.bias) +event={kernel_event} node=addmm_1 stack_trace=return F.linear(input, self.weight, self.bias)""" + else: + expected = f"""\ event=aten::t node=t stack_trace=x = self.linear1(x) event=aten::transpose node=t stack_trace=x = self.linear1(x) event=aten::as_strided node=t stack_trace=x = self.linear1(x) diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index d63a916971d55..0ff6bb62ab53d 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -45,6 +45,7 @@ from torch.fx.passes import graph_manipulation from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes from torch.fx.passes.shape_prop import ShapeProp +from torch.fx._lazy_graph_module import _use_lazy_graph_module from torch.fx.passes.split_module import split_module from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes from torch.testing._internal.common_device_type import ( @@ -54,7 +55,14 @@ ) from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_nn import module_tests, get_new_module_tests -from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase, TEST_WITH_CROSSREF +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + TEST_Z3, + run_tests, + TestCase, + TEST_WITH_CROSSREF, +) from torch.testing._internal.jit_utils import JitTestCase import torch.utils._pytree as pytree @@ -754,7 +762,8 @@ def forward(self, a, b): # Confirm that the output is correct self.assertEqual(traced(3, 3), m(3, 3)) - def test_subgraph_creation(self): + @parametrize("use_lazy", [True, False]) + def test_subgraph_creation(self, use_lazy): class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -786,9 +795,10 @@ def mod_partition(node: Node): return partition # split module in module with submodules - module_with_submodules = split_module( - my_module_traced, my_module, mod_partition - ) + with _use_lazy_graph_module(use_lazy): + module_with_submodules = split_module( + my_module_traced, my_module, mod_partition + ) # Check that test_meta_info was still on all nodes. submodules = dict(module_with_submodules.named_modules()) @@ -897,6 +907,40 @@ def cb(_): else: raise RuntimeError("Expected the subgraph to have an output node.") + def test_split_module_tuple_return(self): + from torch._inductor.compile_fx import graph_returns_tuple + + class M(torch.nn.Module): + def forward(self, x, y): + a = x + y + return a * x + + gm = torch.fx.symbolic_trace(M()) + + # Assign ops to different partitions so a single-output submodule exists. + def partition_fn(node): + return 0 if node.target == operator.add else 1 + + # Without tuple_return: single-output submodules return a bare value. + sp = split_module(gm, None, partition_fn) + self.assertTrue( + any( + not graph_returns_tuple(submod) + for submod in sp.children() + ), + "expected at least one non-tuple-returning submodule", + ) + x, y = torch.randn(4), torch.randn(4) + self.assertEqual(sp(x, y), gm(x, y)) + + # With tuple_return: all submodules return a tuple. + sp_boxed = split_module(gm, None, partition_fn, tuple_return=True) + self.assertTrue( + all(graph_returns_tuple(submod) for submod in sp_boxed.children()), + "all submodules should return a tuple with tuple_return=True", + ) + self.assertEqual(sp_boxed(x, y), gm(x, y)) + def test_split_module_kwargs_expansion(self): class ModuleWithKwargsExpansion(torch.nn.Module): @@ -2152,6 +2196,7 @@ def test_z3str(self): self.assertEqual(z3str(expr), expected) +instantiate_parametrized_tests(TestFXExperimental) instantiate_device_type_tests(TestNormalizeOperators, globals()) if __name__ == "__main__": diff --git a/test/test_fx_reinplace_pass.py b/test/test_fx_reinplace_pass.py index 8837cea3535c4..7a45445c9d8f6 100644 --- a/test/test_fx_reinplace_pass.py +++ b/test/test_fx_reinplace_pass.py @@ -267,7 +267,7 @@ def forward(self, a__1): select_int = torch.ops.aten.select.int(as_strided, 0, 0) copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = copy__default = None return as_strided - """) # noqa: B950 + """) def test_reinplace_scatter_twice_with_different_view_op_invalid2(self): def f(a_): @@ -299,7 +299,7 @@ def forward(self, a__1): select_int = torch.ops.aten.select.int(as_strided, 0, 1) copy__default = torch.ops.aten.copy_.default(select_int, add); select_int = add = copy__default = None return as_strided - """) # noqa: B950 + """) def test_out_node_updated(self): diff --git a/test/test_indexing.py b/test/test_indexing.py index e7ce521a79639..87c3ddbc56e74 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -32,6 +32,7 @@ all_types_and, all_types_and_complex_and, all_types_complex_float8_and, + highest_precision_float, ) from torch.testing._internal.common_utils import ( DeterministicGuard, @@ -1217,7 +1218,7 @@ def func1(x, i, v): @onlyNativeDeviceTypes def test_index_put_accumulate_duplicate_indices(self, device): - dtype = torch.float if device.startswith("mps") else torch.double + dtype = highest_precision_float(device) for i in range(1, 512): # generate indices by random walk, this will create indices with # lots of duplicates interleaved with each other diff --git a/test/test_inspect_utils.py b/test/test_inspect_utils.py new file mode 100644 index 0000000000000..b9480b62507b8 --- /dev/null +++ b/test/test_inspect_utils.py @@ -0,0 +1,150 @@ +# Owner(s): ["module: fx"] + +import inspect + +from torch.testing._internal.common_utils import run_tests, TestCase +from torch.utils._inspect import _fast_bind + + +class TestFastBind(TestCase): + def _assert_fast_bind_matches_sig_bind(self, sig, args, kwargs): + try: + ref = sig.bind(*args, **kwargs) + except TypeError: + # _fast_bind should raise the same error + with self.assertRaises(TypeError): + _fast_bind(sig, *args, **kwargs) + return + # Success path: compare BoundArguments + got = _fast_bind(sig, *args, **kwargs) + self.assertEqual(ref.arguments, got.arguments) + self.assertEqual(ref.args, got.args) + self.assertEqual(ref.kwargs, got.kwargs) + + # Also validate default population matches + ref.apply_defaults() + got.apply_defaults() + self.assertEqual(ref.arguments, got.arguments) + self.assertEqual(ref.args, got.args) + self.assertEqual(ref.kwargs, got.kwargs) + + def test_positional_or_keyword_and_defaults(self): + def f(a, b=1, c=2): + pass + + sig = inspect.signature(f) + self._assert_fast_bind_matches_sig_bind(sig, (10,), {"c": 3}) + self._assert_fast_bind_matches_sig_bind(sig, (10, 20), {}) + self._assert_fast_bind_matches_sig_bind(sig, (), {"a": 1}) + + def test_missing_required(self): + def f(a, b): + pass + + sig = inspect.signature(f) + self._assert_fast_bind_matches_sig_bind(sig, (), {}) + self._assert_fast_bind_matches_sig_bind(sig, (1,), {}) + + def test_too_many_positional(self): + def f(a, b=1): + pass + + sig = inspect.signature(f) + self._assert_fast_bind_matches_sig_bind(sig, (1, 2, 3), {}) + + def test_multiple_values_for_argument(self): + def f(a, b): + pass + + sig = inspect.signature(f) + self._assert_fast_bind_matches_sig_bind(sig, (1,), {"a": 2, "b": 3}) + + def test_unexpected_keyword(self): + def f(a, b=1): + pass + + sig = inspect.signature(f) + self._assert_fast_bind_matches_sig_bind(sig, (1,), {"c": 3}) + + def test_keyword_only(self): + def f(a, *, b, c=1): + pass + + sig = inspect.signature(f) + self._assert_fast_bind_matches_sig_bind(sig, (1,), {"b": 2}) + self._assert_fast_bind_matches_sig_bind(sig, (1, 2), {}) + self._assert_fast_bind_matches_sig_bind(sig, (1,), {}) + + def test_positional_only(self): + # def f(x, /, y, *, z=1): ... + sig = inspect.Signature( + [ + inspect.Parameter("x", inspect.Parameter.POSITIONAL_ONLY), + inspect.Parameter("y", inspect.Parameter.POSITIONAL_OR_KEYWORD), + inspect.Parameter("z", inspect.Parameter.KEYWORD_ONLY, default=1), + ] + ) + self._assert_fast_bind_matches_sig_bind(sig, (1, 2), {"z": 3}) + self._assert_fast_bind_matches_sig_bind(sig, (), {"x": 1, "y": 2}) + self._assert_fast_bind_matches_sig_bind(sig, (1,), {"x": 2, "y": 3}) + + def test_from_keyword_positional_only_pattern(self): + # Mirrors a common TorchScript schema normalization pattern in operator_schemas.py + # where "from" is treated as positional-only. + sig = inspect.Signature( + [ + inspect.Parameter("input", inspect.Parameter.POSITIONAL_ONLY), + inspect.Parameter( + "from", inspect.Parameter.POSITIONAL_ONLY, default=0.0 + ), + inspect.Parameter( + "to", inspect.Parameter.POSITIONAL_OR_KEYWORD, default=1.0 + ), + inspect.Parameter( + "generator", + inspect.Parameter.KEYWORD_ONLY, + default=None, + ), + ] + ) + self._assert_fast_bind_matches_sig_bind(sig, ("t",), {}) + self._assert_fast_bind_matches_sig_bind(sig, ("t", 0.25), {}) + self._assert_fast_bind_matches_sig_bind( + sig, ("t", 0.25, 0.75), {"generator": None} + ) + self._assert_fast_bind_matches_sig_bind(sig, (), {"input": "t", "from": 0.1}) + + def test_varargs_and_varkw_fallback(self): + def f(a, *args, b=0, **kwargs): + pass + + sig = inspect.signature(f) + # Extra positional should be captured by *args; extra keywords by **kwargs + self._assert_fast_bind_matches_sig_bind(sig, (1, 2, 3), {"b": 4, "x": 5}) + self._assert_fast_bind_matches_sig_bind(sig, (), {"a": 1, "x": 2}) + + def test_zero_params(self): + def f(): + pass + + sig = inspect.signature(f) + self._assert_fast_bind_matches_sig_bind(sig, (), {}) + self._assert_fast_bind_matches_sig_bind(sig, (1,), {}) # too many args + self._assert_fast_bind_matches_sig_bind(sig, (), {"x": 1}) # unexpected kwarg + self._assert_fast_bind_matches_sig_bind( + sig, (1,), {"b": 1} + ) # unexpected arg & kwarg + + def test_mutable_defaults(self): + # This test case reproduces an issue where unhashable default values (like list) + # caused _fast_bind to fail because of lru_cache on _signature_metadata + def f(a, b=[]): # noqa: B006 + pass + + sig = inspect.signature(f) + self._assert_fast_bind_matches_sig_bind(sig, (1,), {}) + self._assert_fast_bind_matches_sig_bind(sig, (1, [2]), {}) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_jit.py b/test/test_jit.py index 23ad6d1b61b0d..4fa07cb29cc5d 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -103,11 +103,11 @@ # Testing utils from torch.testing._internal import jit_utils from torch.testing._internal.common_jit import check_against_reference -from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, \ +from torch.testing._internal.common_utils import run_tests, IS_ARM64, IS_WINDOWS, \ GRAPH_EXECUTOR, suppress_warnings, IS_SANDCASTLE, ProfilingMode, \ TestCase, freeze_rng_state, slowTest, TemporaryFileName, \ enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ - skipIfCrossRef, skipIfTorchDynamo + skipIfCrossRef, skipIfTorchDynamo, xfailIf from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \ _trace, do_input_map, get_execution_plan, make_global, \ execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \ @@ -3135,7 +3135,7 @@ def forward(self, x): def test_oneline_func(self): - def fn(x): return x # noqa: E704 + def fn(x): return x self.checkScript(fn, (torch.ones(2, 2), )) @@ -3261,8 +3261,8 @@ def fct_loop(x): def test_no_self_arg_ignore_function(self): class MyModule(nn.Module): - @torch.jit.ignore # noqa: B902 - def call_np(): # noqa: B902 + @torch.jit.ignore + def call_np(): # type: () -> int return np.random.choice(2, p=[.95, .05]) @@ -3870,7 +3870,7 @@ def invalid4(a): def test_calls_in_type_annotations(self): with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"): def spooky(a): - # type: print("Hello") -> Tensor # noqa: F723 + # type: print("Hello") -> Tensor return a + 2 print(torch.__file__) torch.jit.annotations.get_signature(spooky, None, 1, True) @@ -6122,7 +6122,7 @@ def test_not_cast(x): self.checkScript(test_not_cast, (torch.tensor(1),)) self.checkScript(test_not_cast, (torch.tensor(0),)) - with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[Tensor, Tensor\]"): # noqa: W605 + with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[Tensor, Tensor\]"): @torch.jit.script def test_mult(x, y): return not (x, y) @@ -6147,7 +6147,7 @@ def test_cast_float(x): self.checkScript(test_cast_float, (0.,)) self.checkScript(test_cast_float, (-1.,)) - with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[int, int\] to bool"): # noqa: W605 + with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[int, int\] to bool"): @torch.jit.script def test_bad_conditional(x): @@ -11393,8 +11393,8 @@ def foo(a: Tuple[int, int], b: Tuple[int, int, int]) -> int: def test_method_no_self(self): with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'): class MethodNoSelf(torch.jit.ScriptModule): - @torch.jit.script_method # noqa: B902 - def forward(): # noqa: B902 + @torch.jit.script_method + def forward(): return torch.zeros(3, 4) MethodNoSelf() @@ -11751,7 +11751,7 @@ def test_list_comprehension_variable_write(self): # i in comprehension doesn't write to function scope def foo(): i = 1 - x = [i if i != 5 else 3 for i in range(7)] # noqa: C416 + x = [i if i != 5 else 3 for i in range(7)] return i, x self.assertEqual(foo(), torch.jit.script(foo)()) @@ -13004,7 +13004,7 @@ def bad_type_line(a, # type: Tensor c # type: Tensor ): # type: (int, int, int) -> Tensor - # type: bad type line # noqa: F723 + # type: bad type line return a + b + c @@ -14362,7 +14362,7 @@ def test_non_primitive_types(x): self.assertEqual(out, torch.tensor(6.0)) def test_namedtuple_type_inference(self): - _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)]) # noqa: UP014 + _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)]) _UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value']) def test_check_named_tuple_value(): @@ -14510,12 +14510,12 @@ def test_function_overloads(self): # decorators. This is fixed on master but not on version 2.1.1. # Next version update remove noqa and add @typing.overload annotation - @torch.jit._overload # noqa: F811 - def test_simple(x1): # noqa: F811 + @torch.jit._overload + def test_simple(x1): # type: (int) -> int pass - @torch.jit._overload # noqa: F811 + @torch.jit._overload def test_simple(x1): # noqa: F811 # type: (float) -> float pass @@ -14537,8 +14537,8 @@ def invoke_function(): old_func = test_simple # testing that new functions added work with caching - @torch.jit._overload # noqa: F811 - def test_simple(x1): # noqa: F811 + @torch.jit._overload + def test_simple(x1): # type: (str) -> str pass @@ -14547,8 +14547,8 @@ def my_func(): return old_func("hi") # testing new function same qualified name - @torch.jit._overload # noqa: F811 - def test_simple(a, b): # noqa: F811 + @torch.jit._overload + def test_simple(a, b): # type: (int, int) -> int pass @@ -14564,12 +14564,12 @@ def fn(): # currently we take the default values have to be specified in the # overload as well - TODO take them from implementation and apply # where the type is valid. - @torch.jit._overload # noqa: F811 - def identity(x1): # noqa: F811 + @torch.jit._overload + def identity(x1): # type: (str) -> str pass - @torch.jit._overload # noqa: F811 + @torch.jit._overload def identity(x1): # noqa: F811 # type: (float) -> float pass @@ -14596,12 +14596,12 @@ def schema_match_failure(): with self.assertRaisesRegex(Exception, "cannot be directly compiled"): torch.jit.script(identity) - @torch.jit._overload # noqa: F811 - def impl_compile_failure(x, y): # noqa: F811 + @torch.jit._overload + def impl_compile_failure(x, y): # type: (str, str) -> (str) pass - @torch.jit._overload # noqa: F811 + @torch.jit._overload def impl_compile_failure(x, y): # noqa: F811 # type: (int, int) -> (int) pass @@ -14616,8 +14616,8 @@ def test(): with self.assertRaisesRegex(Exception, "Arguments for call are not valid"): torch.jit.script(test) - @torch.jit._overload # noqa: F811 - def good_overload(x=1): # noqa: F811 + @torch.jit._overload + def good_overload(x=1): # type: (int) -> (int) pass @@ -14632,8 +14632,8 @@ def foo(): with self.assertRaisesRegex(Exception, "must equal to the default parameter"): - @torch.jit._overload # noqa: F811 - def bad_default_on_overload(x, y=2): # noqa: F811 + @torch.jit._overload + def bad_default_on_overload(x, y=2): # type: (int, int) -> (int) pass @@ -14645,12 +14645,12 @@ def bad_default_on_overload(x, y=1): # noqa: F811 def test(): return bad_default_on_overload(1, 2) - @torch.jit._overload # noqa: F811 - def diff_default(x): # noqa: F811 + @torch.jit._overload + def diff_default(x): # type: (int) -> int pass - @torch.jit._overload # noqa: F811 + @torch.jit._overload def diff_default(x): # noqa: F811 # type: (str) -> str pass @@ -14663,12 +14663,12 @@ def test(): self.assertEqual(test(), torch.jit.script(test)()) - @torch.jit._overload # noqa: F811 - def diff_num_params(x): # noqa: F811 + @torch.jit._overload + def diff_num_params(x): # type: (float) -> float pass - @torch.jit._overload # noqa: F811 + @torch.jit._overload def diff_num_params(x, y): # noqa: F811 # type: (int, int) -> int pass @@ -14682,7 +14682,7 @@ def test(): self.assertEqual(test(), torch.jit.script(test)()) - @torch.jit._overload # noqa: F811 + @torch.jit._overload def diff_num_params_no_annot(): # type: () -> int pass @@ -14709,9 +14709,9 @@ def method(self): return 0 @torch.jit._overload - def null_overload(x: int) -> int: ... # noqa: E704 + def null_overload(x: int) -> int: ... - @torch.jit._overload # noqa: F811 + @torch.jit._overload def null_overload(x: str) -> str: # noqa: F811 pass @@ -14726,7 +14726,7 @@ class OverloadMisuse(torch.nn.Module): def forward(self, x: int): pass - @torch.jit._overload_method # noqa: F811 + @torch.jit._overload_method def forward(self, x: Tensor): # noqa: F811 pass @@ -14754,12 +14754,12 @@ def forward(self, x): self.assertEqual(out2, ref_out) def test_function_overloading_isinstance(self): - @torch.jit._overload # noqa: F811 - def my_conv(x, y): # noqa: F811 + @torch.jit._overload + def my_conv(x, y): # type: (float, str) -> (float) pass - @torch.jit._overload # noqa: F811 + @torch.jit._overload def my_conv(x, y): # noqa: F811 # type: (float, float) -> (float) pass @@ -14780,12 +14780,12 @@ def test_uses(): def test_method_overloading(self): class Over(torch.nn.Module): - @torch.jit._overload_method # noqa: F811 - def forward(self, x): # noqa: F811 + @torch.jit._overload_method + def forward(self, x): # type: (Tuple[Tensor, Tensor]) -> Tensor pass - @torch.jit._overload_method # noqa: F811 + @torch.jit._overload_method def forward(self, x): # noqa: F811 # type: (Tensor) -> Tensor pass @@ -14814,11 +14814,11 @@ def forward(self, x): self.assertEqual(over(x), x + 20) class Unannotated(torch.nn.Module): - @torch.jit._overload_method # noqa: F811 - def hello(self, x): # noqa: F811 + @torch.jit._overload_method + def hello(self, x): pass - @torch.jit._overload_method # noqa: F811 + @torch.jit._overload_method def hello(self, x): # noqa: F811 # type: (int) -> (int) pass @@ -14834,12 +14834,12 @@ def forward(self): torch.jit.script(w) class CompileOverloadError(torch.nn.Module): - @torch.jit._overload_method # noqa: F811 - def hello(self, x): # noqa: F811 + @torch.jit._overload_method + def hello(self, x): # type: (str) -> (int) pass - @torch.jit._overload_method # noqa: F811 + @torch.jit._overload_method def hello(self, x): # noqa: F811 # type: (int) -> (int) pass @@ -14858,12 +14858,12 @@ def forward(self): if sys.version_info < (3, 13): # test broken in 3.13 with self.assertRaisesRegex(Exception, "Overloads are not usable when a module"): class W3(torch.nn.Module): - @torch.jit._overload_method # noqa: F811 - def forward(self, x): # noqa: F811 + @torch.jit._overload_method + def forward(self, x): # type: (int) -> int pass - @torch.jit._overload_method # noqa: F811 + @torch.jit._overload_method def forward(self, x): # noqa: F811 # type: (Tensor) -> Tensor pass @@ -14875,7 +14875,7 @@ def forward(self, x): # noqa: F811 b = torch.jit.script(a) class W3(torch.nn.Module): - def forward(self, x): # noqa: F811 + def forward(self, x): return x + 5 + 10 a = W3() @@ -14893,11 +14893,11 @@ def forward(self, x): self.assertEqual(a(torch.tensor(1)), torch.tensor(2)) class W2(torch.nn.Module): - @torch.jit._overload_method # noqa: F811 - def hello(self, x): # noqa: F811 + @torch.jit._overload_method + def hello(self, x): pass - @torch.jit._overload_method # noqa: F811 + @torch.jit._overload_method def hello(self, x): # noqa: F811 # type: (int) -> (int) pass @@ -15634,7 +15634,7 @@ def foo(self): return 1 @torch.jit._overload_method - def hi(self, x: Tensor): ... # noqa: E704 + def hi(self, x: Tensor): ... def hi(self, x): # noqa: F811 return 2 @@ -15796,6 +15796,8 @@ def split_two(tensor): y = torch.randn(3, 6) self.checkScript(split_two, [(x + y)]) + @xfailIf(IS_ARM64) + # see https://github.com/pytorch/pytorch/issues/177255 def test_conv_error(self): @torch.jit.script def fn(x, y): diff --git a/test/test_jit_autocast.py b/test/test_jit_autocast.py index 93d03be9bae95..076d71a0f9901 100644 --- a/test/test_jit_autocast.py +++ b/test/test_jit_autocast.py @@ -6,7 +6,15 @@ import sys import unittest from torch.testing._internal.common_cuda import TEST_CUDA -from torch.testing._internal.common_utils import parse_cmd_line_args, run_tests, skipIfTorchDynamo +from torch.testing._internal.common_utils import ( + IS_ARM64, + IS_LINUX, + IS_CPU_EXT_SVE_SUPPORTED, + parse_cmd_line_args, + run_tests, + skipIfTorchDynamo, + xfailIf, +) from torch.testing import FileCheck from jit.test_models import MnistNet @@ -808,6 +816,8 @@ def test_generate_autocast_jit_trace_model(model, x): for i in range(self.models.__len__()): test_generate_autocast_jit_trace_model(self.models[i], self.inputs[i]) + @xfailIf(IS_ARM64 and IS_LINUX and not IS_CPU_EXT_SVE_SUPPORTED) + # see https://github.com/pytorch/pytorch/issues/177247 def test_nchw_autocast_jit_trace_model(self): def test_nchw_autocast_jit_trace_model(model, x): model.eval() @@ -822,6 +832,8 @@ def test_nchw_autocast_jit_trace_model(model, x): for i in range(self.models.__len__()): test_nchw_autocast_jit_trace_model(self.models[i], self.inputs[i]) + @xfailIf(IS_ARM64 and IS_LINUX and not IS_CPU_EXT_SVE_SUPPORTED) + # see https://github.com/pytorch/pytorch/issues/177247 def test_nhwc_autocast_jit_trace_model(self): def test_nhwc_autocast_jit_trace_model(model, x): model = model.to(memory_format=torch.channels_last) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index c99c6ee2f4ce3..af6532d3ce1da 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -61,7 +61,6 @@ skipIfTorchDynamo, slowTest, TEST_WITH_ASAN, - TEST_WITH_ROCM, ) from torch.testing._internal.jit_metaprogramming_utils import create_traced_fn from torch.testing._internal.jit_utils import ( @@ -2406,32 +2405,34 @@ def eager(x): @skipIfTorchDynamo("too slow") @unittest.skipIf(TEST_WITH_ASAN, "takes 10+ minutes on asan") - @unittest.skipIf(TEST_WITH_ROCM, "Tensor-likes are not close for nans") def test_batch_norm(self): def test(fn, args): trace = torch.jit.trace(fn, args) self.assertAllFused(trace.graph_for(*args)) - # TODO: Are `NaN`'s actually ok here or did this pass silently before, because `equal_nan=True` was the - # default? torch.testing.assert_close(fn(*args), trace(*args), equal_nan=True) - def bn(i, x): - return torch.batch_norm(i, x, x, x, x, False, 0.1, 1e-4, False).relu() + def bn(i, x, rv): + return torch.batch_norm(i, x, x, x, rv, False, 0.1, 1e-4, False).relu() - def bn_no_weight(i, x): - return torch.batch_norm(i, None, x, x, x, False, 0.1, 1e-4, False).relu() + def bn_no_weight(i, x, rv): + return torch.batch_norm(i, None, x, x, rv, False, 0.1, 1e-4, False).relu() - def bn_no_bias(i, x): - return torch.batch_norm(i, x, None, x, x, False, 0.1, 1e-4, False).relu() + def bn_no_bias(i, x, rv): + return torch.batch_norm(i, x, None, x, rv, False, 0.1, 1e-4, False).relu() - def bn_neither(i, x): - return torch.batch_norm(i, None, None, x, x, False, 0.1, 1e-4, False).relu() + def bn_neither(i, x, rv): + return torch.batch_norm( + i, None, None, x, rv, False, 0.1, 1e-4, False + ).relu() for device in self.devices: i = torch.randn(4, 16, 32, 40, device=device) x = torch.randn(16, device=device) + rv = torch.randn( + 16, device=device + ).abs() # running_var must be non-negative for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]: - test(fn, (i, x)) + test(fn, (i, x, rv)) def test_profiler(self): @torch.jit.script diff --git a/test/test_jit_legacy.py b/test/test_jit_legacy.py index 90827ca8f465a..df871a1bbf61d 100644 --- a/test/test_jit_legacy.py +++ b/test/test_jit_legacy.py @@ -9,7 +9,7 @@ # before instantiating tests. parse_cmd_line_args() -from test_jit import * # noqa: F403, F401 +from test_jit import * # noqa: F403 if __name__ == '__main__': if sys.version_info < (3, 14): diff --git a/test/test_jiterator.py b/test/test_jiterator.py index 7adc8a1df0c87..d46db1257a3cb 100644 --- a/test/test_jiterator.py +++ b/test/test_jiterator.py @@ -12,7 +12,7 @@ if not TEST_CUDA: print('CUDA not available, skipping tests', file=sys.stderr) - TestCase = NoTest # noqa: F811 + TestCase = NoTest code_string = "template T my_fused_kernel(T x, T y, T alpha, T beta) { return alpha * x + beta * y; }" diff --git a/test/test_linalg.py b/test/test_linalg.py index a42300cf8c3b6..82c5ee64b51d8 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -659,6 +659,22 @@ def run_test(shape, batch, contiguous): actual = torch.linalg.cholesky(A, upper=True) self.assertEqual(expected, actual) + @skipCUDAIfNoCusolver + @skipCPUIfNoLapack + @dtypes(*floating_and_complex_types()) + def test_cholesky_upper_reconstructs(self, device, dtype): + batch_dims = (1,) + matrix_size = 65 + A = torch.randn( + *(batch_dims + (matrix_size, matrix_size)), dtype=dtype, device=device + ) + pd_matrix = A @ A.mH + torch.eye(matrix_size, dtype=dtype, device=device) + pd_matrix = pd_matrix.squeeze(0) + U = torch.linalg.cholesky(pd_matrix, upper=True) + self.assertEqual(U, torch.triu(U)) + reconstructed = U.mH @ U + self.assertEqual(pd_matrix, reconstructed, atol=1e-4, rtol=1e-5) + @skipCUDAIfNoCusolver @skipCPUIfNoLapack @dtypes(*floating_and_complex_types()) @@ -1050,6 +1066,16 @@ def test_det_backward(self, device, dtype): input = torch.tensor([[0.]], device=device, dtype=dtype, requires_grad=True) self.assertTrue(torch.autograd.gradcheck(torch.det, inputs=input)) + # When A has 0 elements (e.g. empty batch), backward should return a + # zeros tensor with the same shape as A, not an undefined tensor. + for shape in [(0, 3, 3), (2, 0, 0)]: + A = torch.randn(shape, device=device, dtype=dtype, requires_grad=True) + det = torch.linalg.det(A) + det.backward(torch.ones_like(det)) + self.assertIsNotNone(A.grad) + self.assertEqual(A.grad.shape, A.shape) + self.assertEqual(A.grad, torch.zeros_like(A)) + @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(*floating_and_complex_types()) @@ -5151,9 +5177,6 @@ def test_matmul_small_brute_force_tunableop(self, device, dtype): # We set the TunableOp numerical check environment variable here because it is # possible to hit some invalid numerical solutions due to the small matrix sizes. - if torch.version.hip and isRocmArchAnyOf(MI350_ARCH) and dtype is torch.double: - self.skipTest("Currently hangs on rocm mi350") - with self._tunableop_ctx(): torch.cuda.tunable.set_rotating_buffer_size(0) # Numerical check adds significant overhead, unsure if this is needed @@ -6123,6 +6146,7 @@ def test_blaslog_tunableop(self, device, dtype): @onlyCUDA @skipCUDAIfNotRocm + # Fails with triton 3.7 @dtypes(torch.float) def test_mm_submatrix_offline_tunableop(self, device, dtype): import os @@ -6253,8 +6277,12 @@ def test_mm_submatrix_offline_tunableop(self, device, dtype): # This stores total number of cumulative results total_num_results = new_results - ref_results - # There must be a new tuning results - self.assertEqual(total_num_results, 10) + preferred_blas = str(torch.backends.cuda.preferred_blas_library()) + # With hipBLASLt preferred, the two linear/addmm+bias calls are + # tracked as GemmAndBias tunables (+2). With rocBLAS preferred, + # they fall back to regular GEMM signatures and don't add entries. + expected_num_results = 10 if preferred_blas == "_BlasBackend.Cublaslt" else 8 + self.assertEqual(total_num_results, expected_num_results) results_filename = torch.cuda.tunable.get_filename() self.assertTrue(os.path.exists(results_filename)) @@ -6465,9 +6493,9 @@ def test_call_count_tunableop(self, device, dtype): # launched per PyTorch API. The kernels have string # that always starts with `Cijk*` mm_key = 'Cijk' - events = prof.key_averages() + events = prof.events() for evt in events: - if mm_key in evt.key: + if mm_key in evt.name: self.assertEqual(evt.count, 1) kernel_count = kernel_count + 1 diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index a68a5310ab671..ca99d00e7068a 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -2,7 +2,6 @@ import contextlib import os -import time import unittest from itertools import product from functools import partial @@ -21,7 +20,6 @@ from torch.testing._internal.common_cuda import ( blas_library_context, PLATFORM_SUPPORTS_BF16, - PLATFORM_SUPPORTS_GREEN_CONTEXT, SM80OrLater, SM90OrLater, SM100OrLater, @@ -48,6 +46,7 @@ runOnRocmArch, serialTest, skipIfRocm, + skipIfRocmArch, TEST_CUDA, TEST_WITH_ROCM, TestCase, @@ -242,6 +241,8 @@ def test_cublas_addmm_bias_shapes(self, size: int, dtype: torch.dtype, backend): @onlyCUDA + # Fails with triton 3.7 + @skipIfRocmArch(NAVI_ARCH) @dtypes(torch.float16) # m == 4 chooses OUTPUT_TYPE reduction on H200 # m == 8 chooses OUTPUT_TYPE reduction on A100 @@ -711,43 +712,56 @@ def test_grouped_gemm_compiled(self, op, a_row_major, b_row_major, max_autotune) self.assertEqual(C, C_ref) @skipCUDAIfNotRocm + # Fails with triton 3.7 def test_grouped_gemm_rocm_ck_flag(self): - CK_HINT = "kernel_grouped_gemm_xdl_splitk" + CK_EQUAL_K_HINT = "kernel_grouped_gemm_xdl_splitk" + CK_UNEQUAL_K_HINT = "kernel_grouped_gemm_xdl_splitk" HIPBLASLT_HINT = "Cijk_Alik_Bljk_BBS_BH_Bias_HA_S_SAV_UserArgs" - def uses_ck(kernels: set[str]) -> bool: - return any(CK_HINT in k for k in kernels) + def has_ck_kernel(kernels: set[str], hint: str) -> bool: + return any(hint in k for k in kernels) def uses_hipblaslt(kernels: set[str]) -> bool: return any(HIPBLASLT_HINT in k for k in kernels) - def run_grouped_mm(): + def run_grouped_mm(equal_k: bool): device = "cuda" dtype = torch.bfloat16 - # row-major 3d-3d - G, M, N, K = 4, 16, 32, 64 - a = torch.randn(G, M, K, device=device, dtype=dtype) - b = torch.randn(G, N, K, device=device, dtype=dtype) - # 3d-3d grouped GEMM: [G, M, K] @ [G, K, N] - out = F.grouped_mm(a, b.transpose(-2, -1), out_dtype=dtype) - return out - - def collect_kernel_names(): + if equal_k: + # 3d-3d grouped GEMM with identical K for all groups + G, M, N, K = 4, 16, 32, 64 + a = torch.randn(G, M, K, device=device, dtype=dtype) + b = torch.randn(G, N, K, device=device, dtype=dtype) + return F.grouped_mm(a, b.transpose(-2, -1), out_dtype=dtype) + + # 2d-2d grouped GEMM with non-uniform offs, i.e. per-group K is not equal + M, N = 16, 32 + offs = torch.tensor([64, 136], device=device, dtype=torch.int32) + K_total = offs[-1].item() + a = torch.randn(M, K_total, device=device, dtype=dtype) + b = torch.randn(N, K_total, device=device, dtype=dtype) + return F.grouped_mm(a, b.transpose(-2, -1), offs=offs, out_dtype=dtype) + + def collect_kernel_names(equal_k: bool): kernels = set() with profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=False, with_stack=False, ) as prof: - run_grouped_mm() + run_grouped_mm(equal_k=equal_k) for evt in prof.key_averages(group_by_input_shape=False): kernels.add(evt.key) return kernels with rocm_group_gemm_ck_env(None): - self.assertTrue(uses_hipblaslt(collect_kernel_names())) + self.assertTrue(uses_hipblaslt(collect_kernel_names(equal_k=True))) with rocm_group_gemm_ck_env("1"): - self.assertTrue(uses_ck(collect_kernel_names())) + ck_equal_kernels = collect_kernel_names(equal_k=True) + self.assertTrue(has_ck_kernel(ck_equal_kernels, CK_EQUAL_K_HINT)) + + ck_unequal_kernels = collect_kernel_names(equal_k=False) + self.assertTrue(has_ck_kernel(ck_unequal_kernels, CK_UNEQUAL_K_HINT)) @onlyCUDA @parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16]) @@ -993,73 +1007,6 @@ def is_batched(): op(a, mismatch_batch_dim_b, out_dtype=torch.float32) - @unittest.skipIf(not PLATFORM_SUPPORTS_GREEN_CONTEXT, "Green contexts are not supported") - @serialTest() - def test_greencontext_carveout(self): - a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16) - ctx = torch.cuda.green_contexts.GreenContext.create(1, 0) - ctx.set_context() - torch.matmul(a, a) - torch.cuda.synchronize() - t0 = time.perf_counter() - partial_res = torch.matmul(a, a) - torch.cuda.synchronize() - t1 = time.perf_counter() - ctx.pop_context() - torch.matmul(a, a) - torch.cuda.synchronize() - t2 = time.perf_counter() - full_res = torch.matmul(a, a) - torch.cuda.synchronize() - t3 = time.perf_counter() - self.assertEqual(partial_res, full_res) - self.assertGreater(t1 - t0, t3 - t2) - - @unittest.skipIf(not PLATFORM_SUPPORTS_GREEN_CONTEXT, "Green contexts are not supported") - @serialTest() - def test_greencontext_stream_carveout(self): - a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16) - ctx = torch.cuda.green_contexts.GreenContext.create(1, 0) - ctx_stream = ctx.Stream() - with torch.cuda.stream(ctx_stream): - torch.matmul(a, a) - torch.cuda.synchronize() - t0 = time.perf_counter() - partial_res = torch.matmul(a, a) - torch.cuda.synchronize() - t1 = time.perf_counter() - torch.matmul(a, a) - torch.cuda.synchronize() - t2 = time.perf_counter() - full_res = torch.matmul(a, a) - torch.cuda.synchronize() - t3 = time.perf_counter() - self.assertEqual(partial_res, full_res) - self.assertGreater(t1 - t0, t3 - t2) - - @unittest.skipIf(not PLATFORM_SUPPORTS_GREEN_CONTEXT, "Green contexts are not supported") - @serialTest() - def test_greencontext_graphs(self): - a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16) - ctx = torch.cuda.green_contexts.GreenContext.create(1, 0) - ctx.set_context() - partial_res = torch.matmul(a, a) - ctx.pop_context() - full_res = torch.matmul(a, a) - full_res.zero_() - partial_res.zero_() - torch.cuda.synchronize() - - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - ctx.set_context() - partial_res = torch.matmul(a, a) - ctx.pop_context() - full_res = torch.matmul(a, a) - g.replay() - self.assertEqual(partial_res, full_res) - - @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS") @unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions") @unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x") diff --git a/test/test_meta.py b/test/test_meta.py index 6ff1639ed1b77..ca5869722c98f 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -20,6 +20,7 @@ from torch.testing._internal.common_utils import ( TestCase, skipIfCrossRef, + skipIfTorchDynamo, suppress_warnings, TEST_WITH_TORCHDYNAMO, run_tests, @@ -1698,6 +1699,32 @@ def test_embedding_bag_dense_backward_per_sample_weights(self): ) self.assertEqual(grad_weight.to('meta'), meta_grad_weight) + def _assert_fft_meta_stride_matches_eager(self, op, *args): + to_meta = MetaConverter() + meta_args = tree_map_only(torch.Tensor, to_meta, args) + ref_out = op(*args) + meta_out = op(*meta_args) + self.assertEqual(ref_out.size(), meta_out.size()) + self.assertEqual(ref_out.stride(), meta_out.stride()) + + @onlyCUDA + @unittest.skipIf(torch.version.hip, "cuFFT-specific stride behavior") + def test_fft_multi_dim_cufft_stride_matches_meta(self, device): + self._assert_fft_meta_stride_matches_eager( + aten._fft_c2c.default, + torch.randn((5, 5, 5, 5, 5), device=device, dtype=torch.complex64), + [1, 2, 3, 4], + 0, + True, + ) + self._assert_fft_meta_stride_matches_eager( + aten._fft_c2r.default, + torch.randn((5, 5, 5, 5, 3), device=device, dtype=torch.complex64), + [0, 1, 2, 3, 4], + 0, + 5, + ) + # opinfo test is using aten.fill_, it's not testing aten.fill @onlyCUDA def test_fill_stride(self): @@ -1869,8 +1896,87 @@ def fn(input, weight, bias, need_grad_input): else: self.assertEqual(out_dtype, [in_dtype,]) +class TestMetaKernelConv(TestCase): + @skipIfTorchDynamo("tests raw meta kernel, not dynamo") + def test_convolution_backward_meta_kernel_channels_last(self): + """Test the meta kernel directly (device='meta', no FakeTensorMode). + This exercises the @register_meta path used by torch.export, which + does NOT go through the FakeTensor intercept in fake_impls.py. + """ + # channels_last grad_output + contiguous input/weight -> contiguous + grad_out = torch.empty(2, 3, 4, 4, device="meta").to( + memory_format=torch.channels_last + ) + inp = torch.empty(2, 3, 4, 4, device="meta") + w = torch.empty(3, 3, 3, 3, device="meta") + gi, gw, _ = torch.ops.aten.convolution_backward( + grad_out, + inp, + w, + [3], + [1, 1], + [1, 1], + [1, 1], + False, + [0, 0], + 1, + [True, True, True], + ) + self.assertTrue(gi.is_contiguous()) + self.assertTrue(gw.is_contiguous()) + + # contiguous grad_output + channels_last input -> channels_last + grad_out2 = torch.empty(2, 3, 4, 4, device="meta") + inp2 = torch.empty(2, 3, 4, 4, device="meta").to( + memory_format=torch.channels_last + ) + gi2, gw2, _ = torch.ops.aten.convolution_backward( + grad_out2, + inp2, + w, + [3], + [1, 1], + [1, 1], + [1, 1], + False, + [0, 0], + 1, + [True, True, True], + ) + self.assertTrue(gi2.is_contiguous(memory_format=torch.channels_last)) + self.assertTrue(gw2.is_contiguous(memory_format=torch.channels_last)) + + + + +class TestMetaKernelRegistrations(TestCase): + @skipIfTorchDynamo("tests raw meta kernel, not dynamo") + def test_make_dep_token(self): + cpu_result = torch.ops.aten._make_dep_token(device=torch.device("cpu")) + meta_result = torch.ops.aten._make_dep_token(device=torch.device("meta")) + self.assertEqual(cpu_result.shape, meta_result.shape) + self.assertEqual(cpu_result.dtype, meta_result.dtype) + + @skipIfTorchDynamo("tests raw meta kernel, not dynamo") + def test_rrelu_backward_small_range(self): + from torch._decomp.decompositions import rrelu_with_noise_backward + + x = torch.randn(5, requires_grad=True) + lower, upper = 0.125, 0.125 + torch.finfo(torch.float32).eps + noise = torch.rand(5) + grad = torch.ones(5) + cpp_result = torch.ops.aten.rrelu_with_noise_backward( + grad, x, noise, lower, upper, True, False + ) + decomp_result = rrelu_with_noise_backward( + grad, x, noise, lower, upper, True, False + ) + self.assertEqual(cpp_result, decomp_result) + + instantiate_device_type_tests(TestMeta, globals()) + def print_op_str_if_not_supported(op_str): op = OperatorName.parse(op_str) packet = getattr(torch.ops.aten, str(op.name)) diff --git a/test/test_mps.py b/test/test_mps.py index 66ba23cf46731..0fba38f57f31d 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -21,6 +21,7 @@ from collections import defaultdict from torch import inf from torch.nn import Buffer, Parameter +from torch.export import Dim, export from torch.testing._internal import opinfo from torch.testing._internal.common_utils import \ (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, MACOS_VERSION, IS_CI, @@ -76,8 +77,8 @@ # Same logic as test_cuda.py if not torch.backends.mps.is_available(): print('MPS not available, skipping tests', file=sys.stderr) - TestCase = NoTest # noqa: F811 - NNTestCase = NoTest # noqa: F811 + TestCase = NoTest + NNTestCase = NoTest total_memory = int(subprocess.check_output(["sysctl", "-n", "hw.memsize"])) @@ -973,6 +974,17 @@ def helper(val, shape, dtype): helper(0, [1024], torch.float32) helper(0.2, [2, 3], torch.float32) helper(0.2 + 0.5j, [2, 3], torch.complex64) + helper(2**63 + 100, [3], torch.uint64) + + def test_fill_strided(self): + # Regression test: strided fill_ must convert byte strides to element strides + for dtype in [torch.cfloat, torch.float32, torch.int32, torch.uint16, torch.uint8]: + x_mps = torch.zeros(4, 4, device='mps', dtype=dtype) + x_cpu = x_mps.cpu() + # Every other row — non-contiguous view + for t in (x_mps, x_cpu): + t[::2].fill_(1) + self.assertEqual(x_mps, x_cpu) def test_fill_storage_offset(self): shape = [2, 10] @@ -1222,6 +1234,27 @@ def test_bmm(self): self.assertEqual(output_cpu, output_mps) self.assertEqual(output_cpu.size(), output_mps.size()) + def test_bmm_conj(self): + # bmm must respect the conjugate bit on input tensors. + # See https://github.com/pytorch/pytorch/issues/177474 + a = torch.randn(4, 3, 5, dtype=torch.complex64, device="mps") + b = torch.randn(4, 5, 2, dtype=torch.complex64, device="mps") + result_mps = torch.bmm(a, b.conj()) + result_cpu = torch.bmm(a.cpu(), b.cpu().conj()) + self.assertEqual(result_cpu, result_mps) + result_mps = torch.bmm(a.conj(), b) + result_cpu = torch.bmm(a.cpu().conj(), b.cpu()) + self.assertEqual(result_cpu, result_mps) + + def test_addmm_conj(self): + # Regression test: addmm must respect the conjugate bit on the bias tensor. + bias = torch.randn(3, 2, dtype=torch.complex64, device="mps") + a = torch.randn(3, 5, dtype=torch.complex64, device="mps") + b = torch.randn(5, 2, dtype=torch.complex64, device="mps") + result_mps = torch.addmm(bias.conj(), a, b) + result_cpu = torch.addmm(bias.cpu().conj(), a.cpu(), b.cpu()) + self.assertEqual(result_cpu, result_mps) + @xfailIf(MACOS_VERSION < 15.0) @parametrize("dtype", [torch.float16, torch.bfloat16]) def test_large_bmm(self, dtype): @@ -1373,6 +1406,19 @@ def test_linear_non_contiguous(self): result_contig = torch.nn.functional.linear(input_s, weight_contiguous_equiv) self.assertEqual(result_contig, result_sliced) + def test_linear_backward_channels_last_grad(self): + # Regression test for https://github.com/pytorch/pytorch/issues/178222 + # Linear backward crashed when grad_output had channels_last strides, + # because suggest_memory_format() returned ChannelsLast which was then + # applied to 2D weight grad and 1D bias grad tensors (requires rank 4). + x = torch.randn(2, 8, 3, 4, device='mps', requires_grad=True) + proj = torch.nn.Linear(4, 4, device='mps') + y = proj(x) + z = y.permute(0, 2, 3, 1).contiguous() + target = torch.randn_like(z) + loss = (z - target).pow(2).sum() + loss.backward() + def _linear_helper(self, in_features, out_features, shape, bias=True, backward_pass=False): cpu_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="cpu", bias=bias) mps_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, device="mps", bias=bias) @@ -1981,6 +2027,19 @@ def test_batch_norm_backward_weight_bias_gradients(self): self.assertEqual(bn_cpu.weight.grad, bn_mps.weight.grad, atol=1e-5, rtol=1e-5) self.assertEqual(bn_cpu.bias.grad, bn_mps.bias.grad, atol=1e-5, rtol=1e-5) + def test_batch_norm_mixed_dtype(self): + # Regression test for https://github.com/pytorch/pytorch/issues/178770 + # BatchNorm with float32 weights and float16 input should work + x = torch.rand(2, 32, 15, 15, device="mps", dtype=torch.float16, requires_grad=True) + model = nn.BatchNorm2d(32, device="mps", dtype=torch.float32).train() + y_mps = model(x) + # Compare against CPU using a copy of the model + import copy + y_cpu = copy.deepcopy(model).cpu()(x.detach().cpu()) + self.assertEqual(y_cpu, y_mps.cpu(), atol=1e-3, rtol=1e-3) + y_mps.sum().backward() + self.assertIsNotNone(x.grad) + def test_layer_norm_backward(self): inputs = torch.rand(4, 4, device="mps", requires_grad=True) x = torch.nn.LayerNorm(4).to("mps") @@ -4930,6 +4989,45 @@ def test_reduction_sum_max_long_val(self): res_cpu = torch.sum(x_cpu) self.assertEqual(res_mps, res_cpu) + # TODO: fold into OpInfo-based consistency tests once reduction kernels expose + # explicit precision expectations per dtype. + def test_mean_low_prec_order(self): + """Regression test for fp16/bf16/half2 mean: the division must happen + in opmath_t (fp32 for half/bfloat, float2 for half2) BEFORE the cast + to the output dtype. Otherwise the fp32 accumulation done inside the + sum kernel is immediately discarded by a half/bfloat divide. Seen + in inductor/test_adaptive_avg_pool2d_low_prec_mps before the fix.""" + torch.manual_seed(0) + x_fp32 = torch.randn(4, 3, 7, 7) + for dtype in (torch.half, torch.bfloat16, torch.complex32): + if dtype == torch.complex32: + src = torch.complex(x_fp32, x_fp32 + 1).to(torch.complex64) + ref = src.mean(dim=(-2, -1), keepdim=True).to(dtype) + mps = src.to("mps").to(dtype).mean(dim=(-2, -1), keepdim=True) + else: + src = x_fp32 + ref = src.to(dtype).to(torch.float).mean(dim=(-2, -1), keepdim=True).to(dtype) + mps = src.to("mps").to(dtype).mean(dim=(-2, -1), keepdim=True) + self.assertEqual(mps.cpu(), ref, msg=f"mean({dtype}) diverges from fp32-intermediate reference") + + # TODO: fold into OpInfo-based consistency tests once there's a hook to + # exercise the two-pass reduction path for scalar outputs. + def test_sum_full_reduction_repeated(self): + """Regression test for the two-pass full reduction: when + reduction_size doesn't divide evenly into num_groups, the last TG + must not read past the input. Before the fix, the MPS caching + allocator sometimes placed the `partials` buffer in the OOB region, + causing repeated sum() calls to accumulate (1x, 2x, 3x, ...).""" + # 102400 is the exact size that triggered the bug in + # inductor/test_dense_mask_index_mps: num_groups=13 initially, but + # 102400 / 13 has a remainder, so the pre-fix kernel read 14 + # elements past the logical end. + for N in (102400, 99991, 1_000_003): + x = torch.randn(N, device="mps") + ref = x.sum().item() + for _ in range(4): + self.assertEqual(x.sum().item(), ref, msg=f"unstable sum for N={N}") + # Test forward max # Note - don't test grad now def test_max_el(self): @@ -6060,7 +6158,7 @@ def run_cholesky_test(size, *batch_dims, upper=False, check_errors=False): input_mps = input_cpu.to('mps') output_cpu = torch.linalg.cholesky_ex(input_cpu, upper=upper) output_mps = torch.linalg.cholesky_ex(input_mps, upper=upper) - self.assertEqual(output_cpu, output_mps, atol=2e-5, rtol=1e-6) + self.assertEqual(output_cpu, output_mps, atol=3e-5, rtol=1e-6) # test with different even/odd matrix sizes matrix_sizes = [1, 2, 3, 4, 8, 17, 64, 128, 154] @@ -8183,6 +8281,18 @@ def helper(shape, alpha, op_name, inplace): # Regression test for https://github.com/pytorch/pytorch/issues/160208 self.assertEqual(torch.add(y, x, alpha=2).cpu(), torch.add(y.cpu(), x.cpu(), alpha=2)) + def test_add_sub_alpha_cast(self): + # In-place with alpha when self is promoted (e.g. float16 + float32) + for op in [torch.Tensor.add_, torch.Tensor.sub_]: + for dtype_a, dtype_b in [(torch.float16, torch.float32), (torch.bfloat16, torch.float32)]: + a_mps = torch.arange(16, dtype=dtype_a, device='mps') + b_mps = torch.arange(16, dtype=dtype_b, device='mps') + a_cpu, b_cpu = a_mps.cpu(), b_mps.cpu() + alpha = torch.tensor(0.33, dtype=dtype_b) + op(a_mps, b_mps, alpha=alpha) + op(a_cpu, b_cpu, alpha=alpha) + self.assertEqual(a_mps, a_cpu) + # Test add def test_add_scalars(self): def helper(alpha): @@ -8364,6 +8474,19 @@ def test_non_contiguous_sampling_variation(self): # indicating the sampling is working properly on non-contiguous tensors self.assertNotEqual(len(samples), 1) + def test_multinomial_large_input_no_segfault(self): + # Regression test for https://github.com/pytorch/pytorch/issues/178579 + # The previous MPSGraph kernel materialized an N x N ones matrix and + # segfaulted for N >= 2**17 + # this test just checks that it doesn't segfault so we don't regress + n = 2**17 + probs = torch.rand(n, device="mps") + out = torch.multinomial(probs, 100, replacement=True) + torch.mps.synchronize() + self.assertEqual(out.shape, torch.Size([100])) + self.assertEqual(out.dtype, torch.int64) + self.assertEqual(out.device.type, "mps") + def test_cumsum_dim_check(self): x = torch.rand((3, 3), device="mps") self.assertEqual(x.cumsum(1), x.cumsum(-1)) @@ -9231,6 +9354,25 @@ def test_conv3d_backward_collision(self): # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion y2.sum().backward() + def test_channels_last_channel_slice(self): + # Regression test for https://github.com/pytorch/pytorch/issues/180984 + # A channel-slice view of a channels_last tensor has channels-last-like + # strides but is not packed NHWC in memory. MPSGraph ops that support NHWC + # are only work with packed NHWC buffer, giving wrong results. + shared = torch.randn(2, 4, 8, 8, device="mps").contiguous(memory_format=torch.channels_last) + mps_slice = shared[:, :2] + + weight = torch.randn(3, 2, 3, 3, device="mps") + self.assertEqual(F.conv2d(mps_slice.cpu(), weight.cpu()), F.conv2d(mps_slice, weight).cpu()) + + self.assertEqual(F.avg_pool2d(mps_slice.cpu(), 2), F.avg_pool2d(mps_slice, 2).cpu()) + self.assertEqual(F.adaptive_avg_pool2d(mps_slice.cpu(), 2), F.adaptive_avg_pool2d(mps_slice, 2).cpu()) + + bn = nn.BatchNorm2d(2).eval() + bn_mps = nn.BatchNorm2d(2).to("mps").eval() + bn_mps.load_state_dict(bn.state_dict()) + self.assertEqual(bn(mps_slice.cpu()), bn_mps(mps_slice).cpu()) + # Regression test for https://github.com/pytorch/pytorch/issues/141471 def test_conv3d_channels_last_3d(self): m_cpu = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0), device="cpu") @@ -9695,6 +9837,33 @@ def weight_int8pack_mm(a, b_int8pack, b_scales): mean_err = ((res - ref).abs() / ref).mean() self.assertLess(mean_err, 0.05) + def test_loradown_regression_original_case(self): + a = torch.rand(2, 1025, device='mps', dtype=torch.half) + b = torch.rand(2, 1041, device='mps', dtype=torch.half)[:, :1025].t() + result = a @ b + self.assertEqual(result.shape, (2, 2)) + + self.assertFalse(torch.isnan(result).any()) + self.assertFalse(torch.isinf(result).any()) + + @parametrize("padding", [0, 3, 4, 7, 8, 15, 16]) + @parametrize("vector_dim", [2, 15, 16, 24]) + def test_loradown_correctness_vs_cpu(self, padding, vector_dim): + torch.manual_seed(13) + + base_size = 64 + physical_size = base_size + padding + + a_mps = torch.rand(vector_dim, base_size, device='mps', dtype=torch.half) + b_mps = torch.rand(vector_dim, physical_size, device='mps', dtype=torch.half)[:, :base_size].t() + + a_cpu = a_mps.cpu() + b_cpu = b_mps.cpu() + + result_cpu = (a_cpu @ b_cpu) + result_mps = (a_mps @ b_mps).cpu() + + torch.testing.assert_close(result_mps, result_cpu, rtol=1e-3, atol=1e-3) class TestSDPA(TestCaseMPS): def _compare_tensors(self, y, ref, tol=0.01): @@ -9745,6 +9914,38 @@ def test_sdpa_no_mask_causal_fp32(self): def test_sdpa_no_mask_causal_fp16(self): self._test_sdpa_no_mask(True, torch.float16) + def test_sdpa_export_dynamic_seq_len(self): + # Regression test for https://github.com/pytorch/pytorch/issues/177603 + class M(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("k", torch.zeros(1, 4, 64, 128)) + self.register_buffer("v", torch.zeros(1, 4, 64, 128)) + + def forward(self, q, mask): + out, _ = torch.ops.aten._scaled_dot_product_attention_math_for_mps( + q, self.k, self.v, mask, 0.0, False, None + ) + return out + + model = M().to(device="mps", dtype=torch.float32).eval() + seq = Dim("seq", min=1, max=64) + + q = torch.randn(1, 4, 4, 128, device="mps", dtype=torch.float32) + mask = torch.zeros(4, 64, device="mps", dtype=torch.bool) + + ep = export( + model, + (q, mask), + dynamic_shapes={"q": {2: seq}, "mask": {0: seq}}, + strict=True, + ) + + q2 = torch.randn(1, 4, 7, 128, device="mps", dtype=torch.float32) + mask2 = torch.zeros(7, 64, device="mps", dtype=torch.bool) + out = ep.module()(q2, mask2) + self.assertEqual(out.shape, (1, 4, 7, 128)) + def test_sdpa_no_mask_causal_fp16_L7(self): self._test_sdpa_no_mask(True, torch.float16, 7) @@ -10098,8 +10299,10 @@ def format_res(res): for t in pytree.tree_flatten(res)[0] ] - # Format the output so that we only look at the tensor metadata - self.test.assertEqual(format_res(res), format_res(meta_res)) + # Format the output so that we only look at the tensor metadata. + # Only compare the first returned value for this op; the second output + # is not consumed and inconsistent across paths. + self.test.assertEqual(format_res(res)[0], format_res(meta_res)[0]) return res @@ -11611,6 +11814,18 @@ def test_nonzero_discontiguous(self): self.assertEqual(dst1, dst4, atol=0, rtol=0) self.assertEqual(strides, dst4.stride()) + def test_nonzero_large(self): + # Regression test: with 2M elements and threadgroup size 1024, the + # prefix_sum_blocks kernel must handle ~1954 blocks via its + # multi-element-per-thread loop (each thread processes + # ceil(num_blocks/tg_size) blocks). This verifies correctness when + # the block count exceeds a single threadgroup. + x = torch.rand(2_000_000, device="mps") + x[x > 0.5] = 0 + result = torch.nonzero(x) + expected = torch.nonzero(x.cpu()) + self.assertEqual(result.cpu(), expected) + def test_nonzero_non_diff(self): device = "mps" x = torch.randn(10, requires_grad=True, device=device) @@ -12472,6 +12687,27 @@ def test_lstm_backward(self, device="mps", dtype=torch.float32): for test_options in self.LSTM_TEST_CASES: self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options) + def test_lstm_eval_after_train_same_shape(self): + # Regression test for https://github.com/pytorch/pytorch/issues/180744 + # The MPS LSTM graph cache key did not include the `train` flag, so a + # graph built with dropout during train() was reused in eval() at the + # same input shape, silently applying dropout during inference + torch.manual_seed(0) + lstm = nn.LSTM( + input_size=2, hidden_size=4, num_layers=2, + dropout=0.1, batch_first=True, + ).to("mps") + opt = torch.optim.SGD(lstm.parameters(), lr=1e-2) + lstm(torch.randn(3, 5, 2, device="mps"))[0].mean().backward() + opt.step() + + lstm.eval() + probe = torch.randn(10, 5, 2, device="mps") + with torch.no_grad(): + full = lstm(probe)[0] + part = lstm(probe[:3])[0] + self.assertEqual(full[:3], part) + def test_RNN_cell_no_broadcasting(self): def test(cell_module, input, hx, input_size, hidden_size): cell = cell_module(input_size, hidden_size, device='mps') @@ -12767,6 +13003,10 @@ class TestConsistency(TestCaseMPS): 'matmul', '__rmatmul__', 'linalg.multi_dot', 'addbmm', + # Accumulates sigmoid + log + weighted sum rounding; CPU and MPS + # end up within ~3e-5 of fp64 but differ from each other by more + # than the default fp32 tolerance. + 'nn.functional.binary_cross_entropy_with_logits', } def _compute_tolerances(self, op, dtype): @@ -12782,8 +13022,10 @@ def _compute_tolerances(self, op, dtype): if op.name in ['nn.functional.conv_transpose1d', 'nn.functional.conv_transpose2d', 'nn.functional.conv_transpose3d', - '__rmatmul__', 'addbmm', 'addmv', - 'baddbmm', 'cov', 'matmul', 'mv'] and dtype in [torch.float16, torch.bfloat16]: + '__rmatmul__', 'addbmm', 'addmm', 'addmv', + 'baddbmm', 'corrcoef', 'cov', 'linalg.multi_dot', + 'matmul', 'mv', + 'nn.functional.linear'] and dtype in [torch.float16, torch.bfloat16]: return (5e-2, 5e-2) if dtype == torch.float16 else (5e-2, 1e-1) if op.name == "masked.mean": return (7e-4, 2e-3) @@ -12805,6 +13047,12 @@ def _compute_tolerances(self, op, dtype): NEW_ALLOW_LIST_GRAD = defaultdict(list) def _run_op(self, op, mps_sample, dtype=None): + # MPS uses float32 intermediates for these ops, so the CPU reference + # must also run in float32 to avoid comparing against less-precise + # native half-precision CPU results. + if op.name in ["grid_sampler_2d", "grid_sampler_3d"] and dtype is None and mps_sample.input.dtype in [torch.float16, torch.bfloat16]: + dtype = torch.float32 + cpu_sample = transform_opinfo_sample_to_cpu(mps_sample, dtype) with warnings.catch_warnings(): @@ -12845,9 +13093,10 @@ def test_output_match(self, device, dtype, op): set_seed=True): opt_dtype = None - # CPU implementation is less precise than MPS one so compare MPS to full fp32 - if dtype in [torch.float16, torch.bfloat16] and op.name in ["grid_sampler_2d", "grid_sampler_3d"]: - opt_dtype = torch.float32 + + if op.name == "histc" and not dtype.is_floating_point and not dtype.is_complex: + opt_dtype = dtype + mps_out, cpu_out, cpu_sample = self._run_op(op, mps_sample, opt_dtype) atol, rtol = self._compute_tolerances(op, dtype) @@ -12963,7 +13212,7 @@ def req_grad(t): # The CPU impl of grid_sampler_3d gives a large amount of error for half # precision types. So instead of testing MPS-vs-CPU outputs, test - # full-vs-half precision dtypes for MPS. + # full-vs-half precision dtypes for MPS (both forward and backward). @dtypes(torch.float16, torch.bfloat16) def test_grid_sampler_3d_half_precision(self, device, dtype): op = next((op for op in test_consistency_op_db if op.name == "grid_sampler_3d"), None) @@ -12982,18 +13231,32 @@ def get_samples(): half_input = half_sample.input half_grid, mode, padding_mode, align_corners = half_sample.args - full_input = half_input.to(torch.float).detach() - full_grid = half_grid.to(torch.float).detach() + full_input = half_input.to(torch.float).detach().requires_grad_(True) + full_grid = half_grid.to(torch.float).detach().requires_grad_(True) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) half_out = op(half_input, half_grid, mode, padding_mode, align_corners) full_out = op(full_input, full_grid, mode, padding_mode, align_corners) - atol, rtol = 1e-4, 1e-4 - + atol, rtol = (1e-4, 1e-4) if dtype == torch.float16 else (5e-3, 5e-3) self.assertEqual(half_out, full_out.to(dtype), atol=atol, rtol=rtol) + grad_output = torch.rand_like(half_out) + full_grad_output = grad_output.to(torch.float) + + half_grad_input, half_grad_grid = torch.autograd.grad( + half_out, [half_input, half_grid], grad_outputs=grad_output) + full_grad_input, full_grad_grid = torch.autograd.grad( + full_out, [full_input, full_grid], grad_outputs=full_grad_output) + + # grad_grid uses direct stores — same tolerance as forward + self.assertEqual(half_grad_grid, full_grad_grid.to(dtype), atol=atol, rtol=rtol) + # grad_input uses atomic half-precision accumulation, which is + # non-deterministic and introduces more rounding than forward + bwd_atol, bwd_rtol = (5e-3, 2e-2) if dtype == torch.float16 else (1e-1, 5e-1) + self.assertEqual(half_grad_input, full_grad_input.to(dtype), atol=bwd_atol, rtol=bwd_rtol) + def test_grid_sampler_3d_nan(self, device): input = torch.ones(1, 1, 3, 3, 3) grid_nan = torch.tensor([[[[[torch.nan, 1., 1.], [1., 1., 1.]]]]]) @@ -13042,6 +13305,16 @@ def test_fmax_mixed_dtypes(self, device): # Broadcast self.assertEqual(op(x, y[0]), op(x.to("mps"), y.to("mps")[0]).cpu()) + def test_mm_stride_zero(self): + # Regression test for https://github.com/pytorch/pytorch/issues/180201 + # MPSGraph matrixMultiplication produces incorrect results with stride-0 + # inputs on macOS < 26.4 (only every 16th row is correct). + expanded = torch.ones(1, 1, device="mps").expand(64, 1) + w = torch.randn(1, 16, device="mps") + result = expanded.mm(w) + expected = torch.ones(64, 1, device="mps").mm(w) + self.assertEqual(result, expected) + class TestErrorInputs(TestCase): _ignore_not_implemented_error = True @@ -13242,6 +13515,21 @@ def test_metal_error_checking(self): self.assertRaises(ValueError, lambda: lib.full(mps_tensor, threads=(3, max_thread_group_size), group_size=(3, max_thread_group_size))) + def test_metal_randn(self): + lib = torch.mps.compile_shader(""" + #include + kernel void randn(device float* out, constant long2& seed_offset, + uint idx [[thread_position_in_grid]]) { + out[idx] = c10::metal::randn(seed_offset.x, seed_offset.y + idx); + } + """) + N = 100_000 + out = torch.empty(N, device="mps") + seed_offset = torch.tensor([42, 0], dtype=torch.long, device="mps") + lib.randn(out, seed_offset) + self.assertEqual(out.mean(), torch.tensor(0.0, device="mps"), atol=0.01, rtol=0) + self.assertEqual(out.std(), torch.tensor(1.0, device="mps"), atol=0.02, rtol=0) + def test_metal_include(self): # Checks that includes embedding works lib = torch.mps.compile_shader("#include ") @@ -13293,6 +13581,45 @@ def test_reduction_utils(self, dtype): self.assertTrue(z0.isnan().all().item(), f"results are {z0}, but all elements should have been nan") self.assertTrue((z1 == idx).all().item(), f"results are {z1}, but all elements should have been {idx}") + def test_reduction_utils_complex(self): + """Test simd_sum and simd_prod for float2 (complex64).""" + lib = torch.mps.compile_shader(""" + #include + kernel void do_sum(device float2* out, + constant float2* inp, + uint idx [[thread_position_in_grid]]) { + out[idx] = c10::metal::simd_sum(inp[idx]); + } + + kernel void do_prod(device float2* out, + constant float2* inp, + uint idx [[thread_position_in_grid]]) { + out[idx] = c10::metal::simd_prod(inp[idx]); + } + """) + + # Test simd_sum: all 32 lanes get the same total + x = torch.randn(28, device="mps", dtype=torch.complex64) + y = torch.empty_like(x) + lib.do_sum(y, x) + x_sum = x.sum() + max_err = (y - x_sum).abs().max().item() + self.assertLess(max_err, 1e-4, f"simd_sum error {max_err}, expected {x_sum}") + + # Test simd_prod: product of a few small complex numbers + # Use only 4 non-unit values to keep the product numerically stable + x_prod = torch.ones(32, device="mps", dtype=torch.complex64) + x_prod[0] = 1 + 2j + x_prod[1] = 3 - 1j + x_prod[2] = -1 + 1j + x_prod[3] = 2 + 0j + y_prod = torch.empty_like(x_prod) + lib.do_prod(y_prod, x_prod, threads=(32,), group_size=(32,)) + expected_prod = x_prod.prod() + # Only lane 0 has the final result for shuffle-down reduction + max_err = (y_prod[0] - expected_prod).abs().item() + self.assertLess(max_err, 1e-4, f"simd_prod error {max_err}, expected {expected_prod}") + @parametrize("dtype", [torch.float32, torch.float16, torch.int32, torch.bfloat16]) def test_atomic_add(self, dtype): from torch._inductor.codegen.mps import DTYPE_TO_METAL @@ -13421,6 +13748,40 @@ def test_metal_compiler_bug_workaround(self): for i in [0, 5, 6, 7, 63, 64]: self.assertEqual(out[i], 0) + def test_load_precompiled_metallib(self): + # Load a checked-in precompiled metallib containing square and inc_inplace kernels + metallib_path = os.path.join(os.path.dirname(__file__), "metal", "test_kernels.metallib") + with open(metallib_path, "rb") as f: + lib = torch.mps.load_metallib(f.read()) + + # Verify kernel discovery + kernel_names = set(dir(lib)) + self.assertIn("square", kernel_names) + self.assertIn("inc_inplace", kernel_names) + + # Test square kernel: [1, 2, 3, 4] -> [1, 4, 9, 16] + x = torch.tensor([1.0, 2.0, 3.0, 4.0], device="mps") + lib.square(x) + self.assertEqual(x, torch.tensor([1.0, 4.0, 9.0, 16.0], device="mps")) + + # Test inc_inplace kernel: [1, 4, 9, 16] -> [2, 5, 10, 17] + lib.inc_inplace(x) + self.assertEqual(x, torch.tensor([2.0, 5.0, 10.0, 17.0], device="mps")) + + def test_load_precompiled_metallib_from_path(self): + # Load metallib directly from file path (uses newLibraryWithURL:) + metallib_path = os.path.join(os.path.dirname(__file__), "metal", "test_kernels.metallib") + lib = torch.mps.load_metallib(metallib_path) + + # Verify kernel discovery + kernel_names = set(dir(lib)) + self.assertIn("square", kernel_names) + self.assertIn("inc_inplace", kernel_names) + + # Test square kernel + x = torch.tensor([1.0, 2.0, 3.0, 4.0], device="mps") + lib.square(x) + self.assertEqual(x, torch.tensor([1.0, 4.0, 9.0, 16.0], device="mps")) # TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing. diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py index 45a09a9312ced..6a8d828acf03f 100644 --- a/test/test_multiprocessing.py +++ b/test/test_multiprocessing.py @@ -1,5 +1,4 @@ # Owner(s): ["module: multiprocessing"] -# ruff: noqa: F841 import contextlib import copy import gc diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 7a7cdd0ee91d1..e96b351eac8ae 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -8769,6 +8769,11 @@ def f(values, offsets): ), name="unimplemented_masked_fill", ), + XFailRule( + sample_match_fn=lambda device, sample: "(T, NT)" in sample.name, + op_match_fn=lambda device, op: op.full_name == "nextafter", + name="nextafter_backward_not_implemented", + ), ] COMPILE_FORWARD_SKIPS_AND_XFAILS = [ diff --git a/test/test_nn.py b/test/test_nn.py index 8a168ef2a6796..3a5258d8ff9f6 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2,6 +2,8 @@ # ruff: noqa: F841 import contextlib +import ctypes +import ctypes.util import math import random import unittest @@ -35,7 +37,7 @@ skipIfNoLapack, skipIfRocm, MI300_ARCH, skipIfRocmArch, \ TEST_NUMPY, TEST_SCIPY, TEST_WITH_CROSSREF, TEST_WITH_ROCM, \ download_file, get_function_arglist, load_tests, skipIfMPS, \ - IS_PPC, \ + IS_PPC, IS_ARM64, IS_MACOS, IS_WINDOWS, IS_CPU_CAPABILITY_SVE256, IS_CPU_EXT_SVE_SUPPORTED, xfailIf, \ parametrize as parametrize_test, subtest, instantiate_parametrized_tests, \ skipIfTorchDynamo, gcIfJetson, set_default_dtype from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \ @@ -2223,6 +2225,7 @@ def test_threshold_bfloat16_half(self): res_bf16 = F.threshold(x.to(dtype=dtype), threshold, 0).float() self.assertEqual(res_bf16, expected) + @xfailIf(IS_ARM64 and not IS_CPU_EXT_SVE_SUPPORTED) # SIGILL on AArch64 without SVE @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, 'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs' ' with instruction set support avx2 or newer.') @@ -4011,22 +4014,26 @@ def test_rnn_check_device(self): hidden = torch.randn(correct_hidden_shape) # input and weights are not at the same device - with self.assertRaisesRegex(RuntimeError, - "Input and parameter tensors are not at the same device"): + rnn_param_device_msg = ( + r"(?:Input and parameter tensors are not at the same device|" + r"Expected all tensors to be on the same device)" + ) + with self.assertRaisesRegex(RuntimeError, rnn_param_device_msg): model(input.to('cuda:0')) - with self.assertRaisesRegex(RuntimeError, - "Input and parameter tensors are not at the same device"): + with self.assertRaisesRegex(RuntimeError, rnn_param_device_msg): model_cuda(input) # input and hiddens are not at the same device - with self.assertRaisesRegex(RuntimeError, - r"Input and hidden tensors are not at the same device"): + rnn_hidden_device_msg = ( + r"(?:Input and hidden tensors are not at the same device|" + r"Expected all tensors to be on the same device)" + ) + with self.assertRaisesRegex(RuntimeError, rnn_hidden_device_msg): if mode == 'LSTM': model(input, (hidden.to('cuda:0'), hidden.to('cuda:0'))) else: model(input, (hidden.to('cuda:0'))) - with self.assertRaisesRegex(RuntimeError, - r"Input and hidden tensors are not at the same device"): + with self.assertRaisesRegex(RuntimeError, rnn_hidden_device_msg): if mode == 'LSTM': model_cuda(input.to('cuda:0'), (hidden, hidden)) else: @@ -4034,8 +4041,7 @@ def test_rnn_check_device(self): # hidden tensors are not at the same CUDA device if mode == 'LSTM': - with self.assertRaisesRegex(RuntimeError, - "Input and hidden tensors are not at the same device"): + with self.assertRaisesRegex(RuntimeError, rnn_hidden_device_msg): model(input.to('cuda:0'), (hidden.to('cuda:0'), hidden.to('cuda:1'))) @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") @@ -6844,6 +6850,8 @@ def test_upsampling_small_scale(self): expected_out_t = torch.tensor([[[[2.5]]]]) self.assertEqual(expected_out_t, out_t) + @xfailIf(IS_ARM64 and IS_CPU_EXT_SVE_SUPPORTED and not IS_CPU_CAPABILITY_SVE256) + # see https://github.com/pytorch/pytorch/issues/177250 def test_upsampling_bfloat16(self, dtype=torch.bfloat16): def helper(size, scale_factor, mode, device, memory_format=torch.contiguous_format): input = torch.randn(size, device=device, dtype=dtype).to(memory_format=memory_format).detach().requires_grad_(True) @@ -6977,6 +6985,77 @@ def helper(size, dtype, mode, device, is_channels_last): helper(size, dtype, mode, device, is_channels_last) + @unittest.skipIf(IS_WINDOWS, "requires mmap/mprotect") + @unittest.skipUnless(TEST_NUMPY, "requires numpy") + def test_interpolate_uint8_overread(self): + # Regression test for vectorized resize overreads (NEON vld3_u8 and + # similar AVX2 paths). The vectorized block-of-4/8 loops may load + # more bytes than needed; on the last row this can read past the + # buffer. We detect this by placing tensor data right before an + # unmapped guard page so any overread triggers SIGBUS/SIGSEGV. + + page_size = os.sysconf("SC_PAGE_SIZE") + libc = ctypes.CDLL(ctypes.util.find_library("c")) + libc.mmap.restype = ctypes.c_void_p + libc.mmap.argtypes = [ + ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int, + ctypes.c_int, ctypes.c_int, ctypes.c_longlong, + ] + libc.mprotect.restype = ctypes.c_int + libc.mprotect.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int] + libc.munmap.restype = ctypes.c_int + libc.munmap.argtypes = [ctypes.c_void_p, ctypes.c_size_t] + + MAP_PRIVATE = 0x02 + MAP_ANON = 0x1000 if IS_MACOS else 0x20 + PROT_RW = 0x01 | 0x02 + PROT_NONE = 0x00 + + for num_channels in (3, 4): # 3ch → NEON on aarch64, 4ch → AVX2 on x86 + for mode in ("bilinear", "bicubic"): + for w_in in range(4, 60): + for w_out in [1, 2, 3]: + h_in = 1 + tensor_bytes = h_in * w_in * num_channels + num_data_pages = (tensor_bytes + page_size - 1) // page_size + total_pages = num_data_pages + 1 + total_size = total_pages * page_size + + addr = libc.mmap(0, total_size, PROT_RW, + MAP_PRIVATE | MAP_ANON, -1, 0) + self.assertTrue( + addr not in (0, 2**64 - 1, -1), "mmap failed" + ) + + guard_start = addr + num_data_pages * page_size + libc.mprotect(guard_start, page_size, PROT_NONE) + + data_start = guard_start - tensor_bytes + ArrayType = ctypes.c_uint8 * tensor_bytes + c_arr = ArrayType.from_address(data_start) + np_arr = np.frombuffer(c_arr, dtype=np.uint8) + np_arr[:] = 128 + + flat = torch.from_numpy(np_arr) + # channels-last strides: (H*W*C, 1, W*C, C) + t = flat.as_strided( + size=[1, num_channels, h_in, w_in], + stride=[ + h_in * w_in * num_channels, + 1, + w_in * num_channels, + num_channels, + ], + ) + + # Overread past the guard page will SIGBUS/SIGSEGV + F.interpolate( + t, size=(h_in, w_out), mode=mode, + antialias=True, align_corners=False, + ) + + libc.munmap(addr, total_size) + @set_default_dtype(torch.double) def test_interpolate(self): def _test_interpolate_non_integer_size_warning(in_t, out_size, dim, **kwargs): @@ -9415,6 +9494,49 @@ def group_norm_ref(X, gamma, beta, groups, channels, eps): Y_cpu = group_norm(X.cpu()) self.assertEqual(Y_cpu, Y, rtol=0, atol=1e-5) + @onlyCUDA + @dtypes(torch.float32, torch.bfloat16) + def test_GroupNorm_backward_large_batch(self, device, dtype): + # Test GroupNorm backward with N > 128, which triggers + # GammaBetaBackwardCUDAKernel2. This kernel uses a (32, 16) block + # with WarpReduceSum after a shared memory transpose. On AMD + # wavefront-64, WarpReduceSum incorrectly summed across two tile + # columns, producing wrong dgamma/dbeta. + # bfloat16 tolerances are loose because the forward pass computes + # intermediate values (mean, var, x_norm) in bfloat16 vs float32 on CPU, + # causing accumulated differences in dgamma/dbeta. The AMD wavefront-64 + # bug this test targets produces ~100% error on all elements, so atol=1.0 + # still catches it with wide margin. + rtol = 1.0 if dtype == torch.bfloat16 else 1e-3 + atol = 1.0 if dtype == torch.bfloat16 else 1e-3 + for N in [129, 256, 512]: + for C, G in [(32, 4), (64, 8), (128, 32)]: + x = torch.randn(N, C, 16, device=device, dtype=dtype, requires_grad=True) + gamma = torch.randn(C, device=device, dtype=dtype, requires_grad=True) + beta = torch.randn(C, device=device, dtype=dtype, requires_grad=True) + + y = F.group_norm(x, G, gamma, beta) + grad = torch.randn_like(y) + y.backward(grad) + dgamma_gpu = gamma.grad.clone() + dbeta_gpu = beta.grad.clone() + + # CPU reference in float32 + x_cpu = x.detach().float().cpu().requires_grad_(True) + gamma_cpu = gamma.detach().float().cpu().requires_grad_(True) + beta_cpu = beta.detach().float().cpu().requires_grad_(True) + y_cpu = F.group_norm(x_cpu, G, gamma_cpu, beta_cpu) + y_cpu.backward(grad.float().cpu()) + + self.assertEqual( + dgamma_gpu.float().cpu(), gamma_cpu.grad, atol=atol, rtol=rtol, + msg=f"dgamma mismatch: N={N} C={C} G={G} dtype={dtype}", + ) + self.assertEqual( + dbeta_gpu.float().cpu(), beta_cpu.grad, atol=atol, rtol=rtol, + msg=f"dbeta mismatch: N={N} C={C} G={G} dtype={dtype}", + ) + @expectedFailureMPS # Double is not supported on MPS @onlyNativeDeviceTypes @dtypes(torch.float64, torch.complex128) @@ -10565,7 +10687,7 @@ def test_upsamplingNearestExact3d_correctness(self, device, memory_format, isize @parametrize_test("antialias", [True, False]) @parametrize_test("align_corners", [True, False]) - @parametrize_test("mode", ["bilinear", "bicubic"]) + @parametrize_test("mode", ["bilinear", "bicubic", "lanczos"]) @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) @expectedFailureMPS # double device type @onlyNativeDeviceTypes @@ -10573,6 +10695,14 @@ def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory # Forward AD does not support XLA because XLA tensors don't have storage check_forward_ad = torch.device(device).type != 'xla' + if mode == "lanczos": + if torch.device(device).type != "cpu": + raise SkipTest("Lanczos mode is only supported on CPU") + if not antialias: + raise SkipTest("Lanczos mode requires antialias=True") + if align_corners: + raise SkipTest("Lanczos mode does not support align_corners=True") + kwargs = dict(mode=mode, align_corners=align_corners, antialias=antialias) # test float scale factor up & downsampling for scale_factor in [0.5, 1.5, 2]: @@ -10635,7 +10765,7 @@ def test_upsamplingBiMode2d(self, device, antialias, align_corners, mode, memory @parametrize_test("antialias", [True, False]) @parametrize_test("num_channels", [3, 5]) - @parametrize_test("mode", ["nearest", "nearest-exact", "bilinear", "bicubic"]) + @parametrize_test("mode", ["nearest", "nearest-exact", "bilinear", "bicubic", "lanczos"]) @parametrize_test("dtype", integral_types() + floating_types()) @skipIfMPS # Error message is wrong for some dtypes @onlyNativeDeviceTypes @@ -10650,6 +10780,14 @@ def test_upsamplingBiMode2d_nonsupported_dtypes(self, device, antialias, num_cha if dtype in (torch.uint8, ) + floating_types(): should_raise_runtime_error = False + elif mode == "lanczos": + if torch.device(device).type != "cpu": + raise SkipTest("Lanczos mode is only supported on CPU") + if not antialias: + raise SkipTest("Lanczos mode requires antialias=True") + if dtype in floating_types() or (device == "cpu" and dtype == torch.uint8): + should_raise_runtime_error = False + elif mode in ("bilinear", "bicubic"): if dtype in floating_types() or (device == "cpu" and dtype == torch.uint8): should_raise_runtime_error = False @@ -10683,7 +10821,7 @@ def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format): # Partially passes. NotImplementedError: aten::upsample_bicubic2d.out https://github.com/pytorch/pytorch/issues/77764 @skipIfMPS @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) - @parametrize_test("mode", ["bilinear", "bicubic"]) + @parametrize_test("mode", ["bilinear", "bicubic", "lanczos"]) @parametrize_test("antialias", [True, False]) @parametrize_test("align_corners", [True, False]) @parametrize_test("num_channels", [3, 5]) @@ -10708,13 +10846,19 @@ def test_upsamplingBiMode2d_consistency( if torch.device(device).type == "cuda": raise SkipTest("CUDA implementation is not yet supporting uint8") + if mode == "lanczos": + if not antialias: + raise SkipTest("Lanczos mode requires antialias=True") + if align_corners: + raise SkipTest("Lanczos mode does not support align_corners=True") + torch.manual_seed(0) - # - input range is set to [30, 220] for bicubic mode, because the bicubic kernel may create - # [intermediate] values outside of the [0, 255] range, which need - # to be clipped in uint8 path, but not in float path. This isn't - # an issue with bilinear kernel. - input_range = (30, 220) if mode == "bicubic" else (0, 256) + # - input range is set to [30, 220] for bicubic and lanczos modes, + # because these kernels may create [intermediate] values outside of + # the [0, 255] range, which need to be clipped in uint8 path, but + # not in float path. This isn't an issue with bilinear kernel. + input_range = (30, 220) if mode in ("bicubic", "lanczos") else (0, 256) input_ui8 = torch.randint(*input_range, size=(batch_size, num_channels, 400, 400), dtype=torch.uint8, device=device) input_ui8 = input_ui8.contiguous(memory_format=memory_format) @@ -10752,6 +10896,7 @@ def test_upsamplingBiMode2d_consistency( if mode == "bilinear": torch.testing.assert_close(output_f32, output_ui8.float(), rtol=0, atol=1) else: + # bicubic and lanczos diff = (output_f32 - output_ui8.float()).abs() self.assertLess(diff.max(), 15) @@ -10820,6 +10965,50 @@ def test_upsamplingBicubic2d_aa_correctness(self, device, memory_format): t_out = F.interpolate(t_in, size=(2, 2), mode="bicubic", align_corners=False, antialias=True) self.assertEqual(expected_out, t_out) + @onlyCPU + @parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) + def test_upsamplingLanczos2d_aa_correctness(self, device, memory_format): + t_in = torch.arange(3 * 8 * 8, dtype=torch.float, device=device).reshape(1, 3, 8, 8) + t_in = t_in.contiguous(memory_format=memory_format) + # This expected result is obtained using PIL.Image.resize + # for c in range(3): + # a_in = t_in.numpy()[0, c, ...] + # pil_in = Image.fromarray(a_in) + # pil_out = pil_in.resize((2, 2), resample=Image.LANCZOS) + expected_out = torch.tensor([ + 14.267621, 18.097038, 44.902962, 48.732376, 78.267616, 82.097038, + 108.902962, 112.732384, 142.267624, 146.097031, 172.902969, 176.732376 + ], device=device, dtype=t_in.dtype).reshape(1, 3, 2, 2) + t_out = F.interpolate(t_in, size=(2, 2), mode="lanczos", align_corners=False, antialias=True) + self.assertEqual(expected_out, t_out) + + @onlyCPU + def test_upsamplingLanczos2d_errors(self, device): + # 3D input (1D spatial) not supported + x_3d = torch.randn(1, 3, 8, device=device) + with self.assertRaisesRegex(ValueError, "4-D tensor"): + F.interpolate(x_3d, size=(4,), mode="lanczos", antialias=True) + + # 5D input (3D spatial) not supported + x_5d = torch.randn(1, 3, 8, 8, 8, device=device) + with self.assertRaisesRegex(ValueError, "4-D tensor"): + F.interpolate(x_5d, size=(4, 4, 4), mode="lanczos", antialias=True) + + # antialias=False not supported + x_4d = torch.randn(1, 3, 8, 8, device=device) + with self.assertRaisesRegex(ValueError, "antialias=True"): + F.interpolate(x_4d, size=(4, 4), mode="lanczos", antialias=False) + + # align_corners=True not supported + with self.assertRaisesRegex(ValueError, "align_corners=True"): + F.interpolate(x_4d, size=(4, 4), mode="lanczos", align_corners=True, antialias=True) + + @onlyCPU + def test_upsamplingLanczos2d_identity(self, device): + x = torch.randn(1, 3, 8, 8, device=device) + out = F.interpolate(x, size=(8, 8), mode="lanczos", align_corners=False, antialias=True) + self.assertEqual(x, out) + @onlyCUDA def test_upsamplingBicubic2d_many_channels(self, device): # Exercises the parallelized batch/channel kernel for small spatial diff --git a/test/test_opaque_obj_v2.py b/test/test_opaque_obj_v2.py index bfe2cc185a61c..ef9eba4c674e1 100644 --- a/test/test_opaque_obj_v2.py +++ b/test/test_opaque_obj_v2.py @@ -1,8 +1,10 @@ # Owner(s): ["module: custom-operators"] import contextlib +import enum import gc import random +import re import unittest from contextlib import ExitStack from dataclasses import dataclass @@ -11,7 +13,6 @@ import torch.distributed as dist import torch.utils._pytree as pytree from torch._dynamo.functional_export import _dynamo_graph_capture_for_export -from torch._dynamo.test_case import run_tests, TestCase from torch._dynamo.testing import ( AotEagerAndRecordGraphs, CompileCounter, @@ -28,8 +29,9 @@ aot_export_module, ) from torch._inductor import config as inductor_config -from torch._inductor.compile_fx import compile_fx -from torch._inductor.utils import fresh_inductor_cache +from torch._inductor.compile_fx import compile_fx, compile_fx_inner +from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import fresh_inductor_cache, run_and_get_code from torch._library.effects import EffectType from torch._library.fake_class_registry import FakeScriptObject, maybe_to_fake_obj from torch._library.opaque_object import ( @@ -50,6 +52,7 @@ instantiate_parametrized_tests, parametrize, ) +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU class Color(OpaqueBase): @@ -343,11 +346,7 @@ def _(x, s): ) register_opaque_type(AddModule, typ="reference") register_opaque_type(ValueConfig, typ="value") -register_opaque_type( - SizeStore, - typ="value", - members={"size": MemberType.USE_REAL, "increment_size": MemberType.USE_REAL}, -) +register_opaque_type(SizeStore, typ="value") register_opaque_type(NestedValueSize, typ="value") register_opaque_type(OpaqueMultiplier, typ="reference") register_opaque_type(Color, typ="reference") @@ -440,6 +439,11 @@ def get_counter(self): class TestOpaqueObject(TestCase): def setUp(self): + # Must run first: super().setUp() can raise SkipTest (e.g. under + # PYTORCH_TEST_SKIP_FAST), and unittest skips tearDown when setUp + # raises. Any registrations before this would leak into the next test. + super().setUp() + self.lib = torch.library.Library("_TestOpaqueObject", "FRAGMENT") # noqa: TOR901 self._opaque_types_before_test = set(_OPAQUE_TYPES_BY_NAME.keys()) @@ -811,6 +815,27 @@ def backward(ctx, grad_output: torch.Tensor): lib=self.lib, ) + torch.library.define( + "_TestOpaqueObject::create_multiplier", + f"(Tensor scale) -> {opaque_multiplier_type}", + tags=torch.Tag.pt2_compliant_tag, + lib=self.lib, + ) + + @torch.library.impl( + "_TestOpaqueObject::create_multiplier", + "CompositeExplicitAutograd", + lib=self.lib, + ) + def create_multiplier_impl(scale: torch.Tensor) -> OpaqueMultiplier: + return OpaqueMultiplier(scale.item()) + + @torch.library.register_fake( + "_TestOpaqueObject::create_multiplier", lib=self.lib + ) + def create_multiplier_fake(scale: torch.Tensor) -> OpaqueMultiplier: + return OpaqueMultiplier(0.0) + counter_type = get_opaque_type_name(Counter) torch.library.define( "_TestOpaqueObject::create_counter", @@ -862,8 +887,6 @@ def counter_start_impl(a: Counter) -> torch.Tensor: def counter_start_fake(a: Counter) -> torch.Tensor: return torch.scalar_tensor(0, dtype=torch.int64) - super().setUp() - def tearDown(self): self.lib._destroy() @@ -951,6 +974,31 @@ def test_fake_script_object_isinstance_per_type(self): self.assertIsInstance(fake_queue, FakeScriptObject) self.assertIsInstance(fake_rng, FakeScriptObject) + def test_isinstance_opaque_base_covers_all_opaque_types(self): + # isinstance(x, OpaqueBase) should match all registered opaque types, + # not just classes that directly subclass OpaqueBase. + + # Value-type opaque (Enum) — registered but doesn't subclass OpaqueBase + class MyEnum(enum.Enum): + A = 1 + + self.assertIsInstance(MyEnum.A, OpaqueBase) + + # Reference-type opaque (subclasses OpaqueBase) — sanity check + queue = OpaqueQueue([], torch.zeros(3)) + self.assertIsInstance(queue, OpaqueBase) + + # FakeScriptObject wrapping a reference-type opaque + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + with fake_mode: + fake_queue = maybe_to_fake_obj(fake_mode, queue) + self.assertIsInstance(fake_queue, FakeScriptObject) + self.assertIsInstance(fake_queue, OpaqueBase) + + # Non-opaque value should not match + self.assertNotIsInstance(42, OpaqueBase) + self.assertNotIsInstance("hello", OpaqueBase) + @parametrize("make_fx_tracing_mode", ["fake", "symbolic"]) def test_make_fx(self, make_fx_tracing_mode): class M(torch.nn.Module): @@ -997,6 +1045,132 @@ def forward(self, arg0_1, arg1_1): """, ) + @parametrize("make_fx_tracing_mode", ["fake", "symbolic"]) + def test_make_fx_value_type(self, make_fx_tracing_mode): + def f(x, cfg): + return torch.ops._TestOpaqueObject.process_with_config(x, cfg) + + x = torch.randn(3, 3) + cfg = ValueConfig("square") + gm = make_fx(f, tracing_mode=make_fx_tracing_mode)(x, cfg) + self.assertEqual(gm(x, cfg), f(x, cfg)) + + self.assertExpectedInline( + gm.code.strip("\n"), + """\ +def forward(self, x_1, cfg_1): + process_with_config = torch.ops._TestOpaqueObject.process_with_config.default(x_1, ValueConfig(mode='square')); x_1 = None + return process_with_config + """, + ) + + def test_subclass_opaque_output_reuses_input_proxy(self): + # Regression test: when a tensor subclass's __torch_dispatch__ wraps + # the output with the real OpaqueBase (not the FakeScriptObject proxy), + # the AOTAutograd forward graph should still reference the opaque via + # its input placeholder — not create a duplicate get_attr constant. + # + # This mirrors DTensor where C++ dispatch creates new DTensors storing + # the real DeviceMesh, so output flattening calls maybe_to_fake_obj + # and mints a fresh FakeScriptObject for the same underlying object. + + class TensorWithRealCounter(torch.Tensor): + @staticmethod + def __new__(cls, data, counter): + return torch.Tensor._make_wrapper_subclass( + cls, + data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + device=data.device, + dtype=data.dtype, + layout=data.layout, + requires_grad=data.requires_grad, + ) + + def __init__(self, data, counter): + self._data = data + self._counter = counter + + def __repr__(self): + return "TensorWithRealCounter(...)" + + def __tensor_flatten__(self): + return ["_data", "_counter"], () + + @staticmethod + def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): + return TensorWithRealCounter( + inner_tensors["_data"], inner_tensors["_counter"] + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if kwargs is None: + kwargs = {} + + counter = None + for arg in torch.utils._pytree.tree_leaves(args): + if isinstance(arg, TensorWithRealCounter): + counter = arg._counter + break + + def unwrap(x): + return x._data if isinstance(x, TensorWithRealCounter) else x + + out = func( + *torch.utils._pytree.tree_map(unwrap, args), + **torch.utils._pytree.tree_map(unwrap, kwargs), + ) + + # Unwrap FakeScriptObject to real OpaqueBase, simulating what + # happens in DTensor's C++ dispatch path. + real_counter = counter + if isinstance(counter, FakeScriptObject): + real_counter = counter.real_obj + + return torch.utils._pytree.tree_map( + lambda x: TensorWithRealCounter(x, real_counter) + if isinstance(x, torch.Tensor) + else x, + out, + ) + + counter = Counter(start=0, end=10) + x = TensorWithRealCounter(torch.randn(4), counter) + + backend = AotEagerAndRecordGraphs() + torch.compile(lambda x: x + 1, fullgraph=True, backend=backend)(x) + + fw = backend.fw_graphs[0] + get_attr_nodes = [n for n in fw.graph.nodes if n.op == "get_attr"] + self.assertEqual( + get_attr_nodes, + [], + "Opaque output should reuse the input placeholder, not create a get_attr constant", + ) + + def test_guard_pickle_subclass_with_opaque_inner_attr(self): + # Regression test: the guard state pickler serializes tensor subclasses + # by iterating over __tensor_flatten__ inner attrs. Opaque inner attrs + # (e.g. Counter) must be handled correctly — they are pickled by the + # normal pickle machinery alongside tensor inner attrs. + from torch._dynamo.guards import GuardsStatePickler + + a = torch.randn(4) + b = torch.randn(4) + counter = Counter(start=0, end=10) + size = SizeStore(4) + x = TensorWithCounter(a, b, counter, size) + + import io + + buf = io.BytesIO() + pickler = GuardsStatePickler({id(x): x}, {}, {}, buf) + func, args = pickler.reducer_override(x) + obj = func(*args) + self.assertIsInstance(obj, torch.Tensor) + @parametrize("make_fx_tracing_mode", ["fake", "symbolic"]) def test_bad_fake(self, make_fx_tracing_mode): torch.library.define( @@ -1071,7 +1245,7 @@ def forward(self, arg0_1, arg1_1): mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg0_1); mul = arg0_1 = None add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None - return (add,)""", # noqa: B950 + return (add,)""", ) torch.library._register_effectful_op( @@ -1093,7 +1267,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): getitem_2 = with_effects_1[0] getitem_3 = with_effects_1[1]; with_effects_1 = None add = torch.ops.aten.add.Tensor(getitem_3, getitem_3); getitem_3 = None - return (getitem_2, add)""", # noqa: B950 + return (getitem_2, add)""", ) finally: torch.library._register_effectful_op( @@ -1116,7 +1290,7 @@ def forward(self, x): l_flat_args_0_ = arg_0 l__self____export_root___closure___0_cell_contents = self.L__self____export_root___closure___0_cell_contents res = torch.ops._TestOpaqueObject.noisy_inject(l_flat_args_0_, l__self____export_root___closure___0_cell_contents); l_flat_args_0_ = l__self____export_root___closure___0_cell_contents = None - return pytree.tree_unflatten((res,), self._out_spec)""", # noqa: B950 + return pytree.tree_unflatten((res,), self._out_spec)""", ) def test_compile1(self): @@ -1151,7 +1325,7 @@ def forward(self, L_rng_state_ : {fx_class}, L_x_ : torch.Tensor): x_1 = x * x; x = None x_2 = torch.ops._TestOpaqueObject.noisy_inject(x_1, l_rng_state_); x_1 = l_rng_state_ = None x_3 = x_2 + x_2; x_2 = None - return (x_3,)""", # noqa: B950 + return (x_3,)""", ) self.assertExpectedInline( backend.fw_graphs[0].code.strip(), @@ -1161,7 +1335,7 @@ def forward(self, arg0_1, arg1_1): mul = torch.ops.aten.mul.Tensor(noisy_inject, noisy_inject); noisy_inject = None noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg0_1); mul = arg0_1 = None add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None - return (add,)""", # noqa: B950 + return (add,)""", ) def test_compile_inline_methods(self): @@ -1188,7 +1362,7 @@ def forward(self, arg0_1, arg1_1): mul = torch.ops.aten.mul.Tensor(noisy_inject, 1); noisy_inject = None noisy_inject_1 = torch.ops._TestOpaqueObject.noisy_inject.default(mul, arg0_1); mul = arg0_1 = None add = torch.ops.aten.add.Tensor(noisy_inject_1, noisy_inject_1); noisy_inject_1 = None - return (add,)""", # noqa: B950 + return (add,)""", ) res = torch.compile(foo, fullgraph=True, backend="inductor")(rng, x) @@ -1253,15 +1427,15 @@ def foo(nested_counter, x): self.assertExpectedInline( backend.graphs[0].code.strip(), f"""\ -def forward(self, L_x_ : torch.Tensor, object_getattribute_L_nested_counter_c_0_ : {fx_class}, object_getattribute_L_nested_counter_c_1_ : {fx_class}): +def forward(self, L_x_ : torch.Tensor, L_nested_counter_c_0_ : {fx_class}, L_nested_counter_c_1_ : {fx_class}): l_x_ = L_x_ - object_getattribute_l_nested_counter_c_0_ = object_getattribute_L_nested_counter_c_0_ - object_getattribute_l_nested_counter_c_1_ = object_getattribute_L_nested_counter_c_1_ - x = torch.ops._TestOpaqueObject.increment_counter(object_getattribute_l_nested_counter_c_0_, l_x_); object_getattribute_l_nested_counter_c_0_ = l_x_ = None - x_1 = torch.ops._TestOpaqueObject.increment_counter(object_getattribute_l_nested_counter_c_1_, x); object_getattribute_l_nested_counter_c_1_ = x = None + l_nested_counter_c_0_ = L_nested_counter_c_0_ + l_nested_counter_c_1_ = L_nested_counter_c_1_ + x = torch.ops._TestOpaqueObject.increment_counter(l_nested_counter_c_0_, l_x_); l_nested_counter_c_0_ = l_x_ = None + x_1 = torch.ops._TestOpaqueObject.increment_counter(l_nested_counter_c_1_, x); l_nested_counter_c_1_ = x = None x_2 = x_1 + 1; x_1 = None x_3 = x_2 + 2; x_2 = None - return (x_3,)""", # noqa: B950 + return (x_3,)""", ) def test_nested_reference_trace(self): @@ -1286,25 +1460,25 @@ def foo(nested_queue, x): self.assertExpectedInline( backend.graphs[0].code.strip(), f"""\ -def forward(self, L_x_ : torch.Tensor, object_getattribute_L_nested_queue_q_ : {fx_class}): +def forward(self, L_x_ : torch.Tensor, L_nested_queue_q : {fx_class}): l_x_ = L_x_ - object_getattribute_l_nested_queue_q_ = object_getattribute_L_nested_queue_q_ + l_nested_queue_q = L_nested_queue_q tan = l_x_.tan() - queue_push = torch.ops._TestOpaqueObject.queue_push(object_getattribute_l_nested_queue_q_, tan); tan = queue_push = None + queue_push = torch.ops._TestOpaqueObject.queue_push(l_nested_queue_q, tan); tan = queue_push = None cos = l_x_.cos(); l_x_ = None - queue_push_1 = torch.ops._TestOpaqueObject.queue_push(object_getattribute_l_nested_queue_q_, cos); cos = queue_push_1 = None - pop1 = torch.ops._TestOpaqueObject.queue_pop(object_getattribute_l_nested_queue_q_) + queue_push_1 = torch.ops._TestOpaqueObject.queue_push(l_nested_queue_q, cos); cos = queue_push_1 = None + pop1 = torch.ops._TestOpaqueObject.queue_pop(l_nested_queue_q) sym_size_int = torch.ops.aten.sym_size.int(pop1, 0) ge = sym_size_int >= 0 _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge = _assert_scalar_default = None - pop2 = torch.ops._TestOpaqueObject.queue_pop(object_getattribute_l_nested_queue_q_); object_getattribute_l_nested_queue_q_ = None + pop2 = torch.ops._TestOpaqueObject.queue_pop(l_nested_queue_q); l_nested_queue_q = None sym_size_int_1 = torch.ops.aten.sym_size.int(pop2, 0) ge_1 = sym_size_int_1 >= 0 _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default_1 = None eq = sym_size_int == sym_size_int_1; sym_size_int = sym_size_int_1 = None _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u0, u1) on node 'eq'"); eq = _assert_scalar_default_2 = None add = pop1 + pop2; pop1 = pop2 = None - return (add,)""", # noqa: B950 + return (add,)""", ) # inputs: (token, nested_queue.q, x) @@ -1333,7 +1507,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): eq_2 = sym_size_int == sym_size_int_1; sym_size_int = sym_size_int_1 = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq_2, "Runtime assertion failed for expression Eq(u0, u1) on node 'eq'"); eq_2 = _assert_scalar_2 = None add_4 = torch.ops.aten.add.Tensor(getitem_5, getitem_7); getitem_5 = getitem_7 = None - return (getitem_6, add_4)""", # noqa: B950 + return (getitem_6, add_4)""", ) def test_compile_global(self): @@ -1374,7 +1548,7 @@ def forward(self, arg0_1, arg1_1, arg2_1): getitem_3 = auto_functionalized_v2_1[1]; auto_functionalized_v2_1 = None add = torch.ops.aten.add.Tensor(mul, getitem_2); mul = getitem_2 = None copy_ = torch.ops.aten.copy_.default(arg0_1, getitem_3); arg0_1 = getitem_3 = copy_ = None - return (add,)""", # noqa: B950 + return (add,)""", ) def test_compile_create_intermediate(self): @@ -1428,23 +1602,107 @@ def foo(counter, x): NestedCounters(Counter(1, 5)), torch.ones(2, 3) ) - config = ValueConfig("double") + def test_compile_fixed_stride_order(self): + hs_name = get_opaque_type_name(HoistedString) - def foo(mode, x): - return config.mode + torch.library.define( + "_TestOpaqueObject::stride_op", + f"(Tensor x, {hs_name} s) -> Tensor", + tags=(torch.Tag.needs_fixed_stride_order,), + lib=self.lib, + ) - with self.assertRaisesRegex( - RuntimeError, "Attempted to access unregistered member on an OpaqueObject" - ): - torch.compile(foo, backend="eager")(config, torch.ones(2, 3)) + @torch.library.impl( + "_TestOpaqueObject::stride_op", + "CompositeExplicitAutograd", + lib=self.lib, + ) + def stride_op_impl(x: torch.Tensor, s: HoistedString) -> torch.Tensor: + return x * 2.0 - def bar(mode, x): - config.print_mode() + @torch.library.register_fake("_TestOpaqueObject::stride_op", lib=self.lib) + def stride_op_fake(x, s): + return torch.empty_like(x) - with self.assertRaisesRegex( - RuntimeError, "Attempted to access unregistered member on an OpaqueObject" - ): - torch.compile(bar, backend="eager")(config, torch.ones(2, 3)) + def fn(x, s): + return torch.ops._TestOpaqueObject.stride_op(x, s) + + s = HoistedString("double") + x = torch.randn(4, 4) + + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(x, s) + + expected = x * 2.0 + self.assertEqual(result, expected) + + def test_compile_exact_strides(self): + hs_name = get_opaque_type_name(HoistedString) + + torch.library.define( + "_TestOpaqueObject::exact_op", + f"(Tensor x, {hs_name} s) -> Tensor", + tags=(torch.Tag.needs_exact_strides,), + lib=self.lib, + ) + + @torch.library.impl( + "_TestOpaqueObject::exact_op", + "CompositeExplicitAutograd", + lib=self.lib, + ) + def exact_op_impl(x: torch.Tensor, s: HoistedString) -> torch.Tensor: + return x * 3.0 + + @torch.library.register_fake("_TestOpaqueObject::exact_op", lib=self.lib) + def exact_op_fake(x, s): + return torch.empty_like(x) + + def fn(x, s): + return torch.ops._TestOpaqueObject.exact_op(x, s) + + s = HoistedString("double") + x = torch.randn(4, 4) + + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(x, s) + + expected = x * 3.0 + self.assertEqual(result, expected) + + def test_compile_contiguous_strides(self): + hs_name = get_opaque_type_name(HoistedString) + + torch.library.define( + "_TestOpaqueObject::contig_op", + f"(Tensor x, {hs_name} s) -> Tensor", + tags=(torch.Tag.needs_contiguous_strides,), + lib=self.lib, + ) + + @torch.library.impl( + "_TestOpaqueObject::contig_op", + "CompositeExplicitAutograd", + lib=self.lib, + ) + def contig_op_impl(x: torch.Tensor, s: HoistedString) -> torch.Tensor: + return x * 4.0 + + @torch.library.register_fake("_TestOpaqueObject::contig_op", lib=self.lib) + def contig_op_fake(x, s): + return torch.empty_like(x) + + def fn(x, s): + return torch.ops._TestOpaqueObject.contig_op(x, s) + + s = HoistedString("double") + x = torch.randn(4, 4) + + compiled_fn = torch.compile(fn, backend="inductor") + result = compiled_fn(x, s) + + expected = x * 4.0 + self.assertEqual(result, expected) def test_export_joint(self): torch.library.define( @@ -1502,7 +1760,7 @@ def forward(self, primals, tangents): _opaque_obj0 = self._opaque_obj0 module_mul = torch.ops._TestOpaqueObject.module_mul.default(_opaque_obj0, primals_1, _local_scalar_dense); _opaque_obj0 = primals_1 = None mul_1 = torch.ops.aten.mul.Tensor(tangents_1, _local_scalar_dense); tangents_1 = _local_scalar_dense = None - return pytree.tree_unflatten([module_mul, mul_1, None], self._out_spec)""", # noqa: B950 + return pytree.tree_unflatten([module_mul, mul_1, None], self._out_spec)""", ) compiled_fn = aot_compile_joint_with_descriptors(joint) @@ -1676,6 +1934,20 @@ def foo(x, cfg): self.assertEqual(res, x + x) self.assertEqual(cnt.frame_count, 2) + @parametrize("backend", ["eager", "aot_eager", "inductor"]) + def test_value_type_graph_output(self, backend): + def foo(x): + return x * x, ValueConfig("square") + + x = torch.randn(3, 3) + opt_f = torch.compile(foo, fullgraph=True, backend=backend) + res = opt_f(x) + self.assertEqual(res[1], ValueConfig("square")) + + gm = _dynamo_graph_capture_for_export(foo)(x) + res = gm(x) + self.assertEqual(res[1], ValueConfig("square")) + def test_value_type_graph_input(self): # Even though cfg is an input, it should not be an input to the dynamo # graph. Instead it should directly put in the graph argument as a @@ -1695,14 +1967,14 @@ def foo(x, cfg): def forward(self, L_x_ : torch.Tensor): l_x_ = L_x_ process_with_config = torch.ops._TestOpaqueObject.process_with_config(l_x_, ValueConfig(mode='square')); l_x_ = None - return (process_with_config,)""", # noqa: B950 + return (process_with_config,)""", ) self.assertExpectedInline( backend.fw_graphs[0].code.strip(), """\ def forward(self, arg0_1): process_with_config = torch.ops._TestOpaqueObject.process_with_config.default(arg0_1, ValueConfig(mode='square')); arg0_1 = None - return (process_with_config,)""", # noqa: B950 + return (process_with_config,)""", ) opt_f(x, ValueConfig("double")) @@ -1712,7 +1984,7 @@ def forward(self, arg0_1): """\ def forward(self, arg0_1): process_with_config = torch.ops._TestOpaqueObject.process_with_config.default(arg0_1, ValueConfig(mode='double')); arg0_1 = None - return (process_with_config,)""", # noqa: B950 + return (process_with_config,)""", ) def test_value_type_graph_intermediate(self): @@ -1731,14 +2003,14 @@ def foo(x, config): def forward(self, L_x_ : torch.Tensor): l_x_ = L_x_ process_with_config = torch.ops._TestOpaqueObject.process_with_config(l_x_, ValueConfig(mode='square')); l_x_ = None - return (process_with_config,)""", # noqa: B950 + return (process_with_config,)""", ) self.assertExpectedInline( backend.fw_graphs[0].code.strip(), """\ def forward(self, arg0_1): process_with_config = torch.ops._TestOpaqueObject.process_with_config.default(arg0_1, ValueConfig(mode='square')); arg0_1 = None - return (process_with_config,)""", # noqa: B950 + return (process_with_config,)""", ) opt_f(x, "double") @@ -1747,7 +2019,7 @@ def forward(self, arg0_1): """\ def forward(self, arg0_1): process_with_config = torch.ops._TestOpaqueObject.process_with_config.default(arg0_1, ValueConfig(mode='double')); arg0_1 = None - return (process_with_config,)""", # noqa: B950 + return (process_with_config,)""", ) opt_f = torch.compile(foo, fullgraph=True, backend="inductor") @@ -1776,7 +2048,27 @@ def forward(self, arg0_1): ones = torch.ops.aten.ones.default([3], device = device(type='cpu'), pin_memory = False) cat = torch.ops.aten.cat.default([arg0_1, ones]); arg0_1 = ones = None add = torch.ops.aten.add.Tensor(cat, 3); cat = None - return (add,)""", # noqa: B950 + return (add,)""", + ) + + def test_value_type_unregistered_method(self): + # Unregistered methods on value types should inline (no error) + def foo(x): + cfg = ValueConfig("square") + return x + len(cfg.mode) + + x = torch.randn(3) + backend = AotEagerAndRecordGraphs() + opt_f = torch.compile(foo, fullgraph=True, backend=backend) + res = opt_f(x) + self.assertEqual(res, foo(x)) + + self.assertExpectedInline( + backend.fw_graphs[0].code.strip(), + """\ +def forward(self, arg0_1): + add = torch.ops.aten.add.Tensor(arg0_1, 6); arg0_1 = None + return (add,)""", ) def test_weakref_cleanup(self): @@ -1828,7 +2120,7 @@ def foo(x, config): """\ def forward(self, arg0_1): process_nested_config = torch.ops._TestOpaqueObject.process_nested_config.default(arg0_1, NestedValueSize(size=SizeStore(size=3), config=ValueConfig(mode='square'))); arg0_1 = None - return (process_nested_config,)""", # noqa: B950 + return (process_nested_config,)""", ) opt_f = torch.compile(foo, fullgraph=True, backend="inductor") @@ -1853,7 +2145,7 @@ def foo(x): """\ def forward(self, arg0_1): process_multiple_sizes = torch.ops._TestOpaqueObject.process_multiple_sizes.default(arg0_1, [SizeStore(size=3), SizeStore(size=3)]); arg0_1 = None - return (process_multiple_sizes,)""", # noqa: B950 + return (process_multiple_sizes,)""", ) opt_f = torch.compile(foo, fullgraph=True, backend="inductor") @@ -1898,7 +2190,7 @@ def forward(self, L_x_: "TensorWithCounter(i64[])"): get_start_tensor: "i64[]" = getitem.get_start_tensor(); getitem = None mul_1: "TensorWithCounter(i64[])" = y * get_start_tensor; y = get_start_tensor = None return (mul_1,) -""", # noqa: B950 +""", ) def test_tensor_subclass_with_opaque_attr(self): @@ -2020,6 +2312,71 @@ def fn(x): self.assertIs(x2.grad._counter, counter2) self.assertEqual(x2.grad._size_store, size2) + def test_tangent_primal_proxy_collision_for_opaque_inner_attr(self): + """Regression test for tangent/primal proxy collision. + + When a tensor subclass has an opaque inner attr, joint graph tracing + creates separate FakeScriptObject wrappers for the primal and tangent + that share the same underlying real object. set_proxy_slot must map + the tangent wrapper to the *primal* proxy so that forward outputs + don't spuriously depend on tangent placeholders (which would crash the + partitioner with 'Node tangents_N was invalid, but is output'). + """ + from torch._library.fake_class_registry import FakeScriptObject + from torch.fx.experimental.proxy_tensor import ( + _GraphAppendingTracerEx, + set_proxy_slot, + ) + + counter = Counter(start=3, end=10) + fso_primal = FakeScriptObject(counter, "Counter", counter) + fso_tangent = FakeScriptObject(counter, "Counter", counter) + # Sanity: different wrappers, same real_obj + self.assertIsNot(fso_primal, fso_tangent) + self.assertIs( + object.__getattribute__(fso_primal, "real_obj"), + object.__getattribute__(fso_tangent, "real_obj"), + ) + + graph = torch.fx.Graph() + tracer = _GraphAppendingTracerEx(graph) + + primal_node = graph.placeholder("primals_1") + tangent_node = graph.placeholder("tangents_1") + primal_proxy = torch.fx.Proxy(primal_node, tracer) + tangent_proxy = torch.fx.Proxy(tangent_node, tracer) + + # Register primal first, then tangent (mirrors joint graph tracing) + set_proxy_slot(fso_primal, tracer, primal_proxy) + set_proxy_slot(fso_tangent, tracer, tangent_proxy) + + # Both wrappers should resolve to the primal proxy + self.assertIs(tracer.opaque_tracker[fso_primal].node, primal_node) + self.assertIs(tracer.opaque_tracker[fso_tangent].node, primal_node) + + def test_opaque_produced_by_call_function_saved_for_backward(self): + """Test that an opaque object produced by a call_function node + (not a placeholder) is correctly saved for backward. + + Without is_opaque_node() in the min_cut partitioner, this crashes + with 'AOT Autograd failed to partition' because the opaque node + is a non-tensor call_function that can't be saved or recomputed.""" + + def fn(x): + scale = torch.tensor(2.5) + multiplier = torch.ops._TestOpaqueObject.create_multiplier(scale) + return torch.ops._TestOpaqueObject.mul_with_scale(multiplier, x) + + x = torch.randn(3, 3, requires_grad=True) + opt_fn = torch.compile(fn, backend="aot_eager", fullgraph=True) + out = opt_fn(x) + self.assertTrue(torch.allclose(out, x * 2.5)) + + out.backward(torch.ones_like(out)) + self.assertIsNotNone(x.grad) + expected_grad = torch.ones_like(x) * 2.5 + self.assertTrue(torch.allclose(x.grad, expected_grad)) + def test_tensor_subclass_opaque_backward_compiled_autograd(self): """Test opaque objects work with compiled autograd backward.""" import torch._dynamo.compiled_autograd @@ -2584,7 +2941,7 @@ def forward(self, L_x_ : torch.Tensor, L_scale_obj_ : {_illegal_char_regex.sub(" l_scale_obj_ = L_scale_obj_ result = torch.ops._TestOpaqueObject.mul_with_scale(l_scale_obj_, l_x_); l_scale_obj_ = l_x_ = None result_1 = result * 2; result = None - return (result_1,)""", # noqa: B950 + return (result_1,)""", ) backend = AotEagerAndRecordGraphs() @@ -2685,7 +3042,7 @@ def forward(self, L_x_ : torch.Tensor, G_Color_GREEN : {_illegal_char_regex.sub( l_x_ = L_x_ g_color_green = G_Color_GREEN apply_color_scale = torch.ops._TestOpaqueObject.apply_color_scale(g_color_green, l_x_); g_color_green = l_x_ = None - return (apply_color_scale,)""", # noqa: B950 + return (apply_color_scale,)""", ) def test_hoist_basic(self): @@ -2749,6 +3106,23 @@ def g(x): self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1) self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1) + def test_hoisted_value_type_make_fx(self): + def foo(x, hoisted_str): + return op_with_string(x, hoisted_str) + + x = torch.randn(3, 3) + gm = make_fx(foo)(x, HoistedString("double")) + + self.assertExpectedInline( + gm.code.strip(), + """\ +def forward(self, x_1, hoisted_str_1): + op_with_string = torch.ops.mylib.op_with_string.default(x_1, hoisted_str_1); x_1 = hoisted_str_1 = None + return op_with_string""", + ) + self.assertEqual(gm(x, HoistedString("double")), x * 2) + self.assertEqual(gm(x, HoistedString("square")), x * x) + def test_opaque_class_literal_attribute_inlined(self): """Test that literal attributes on opaque classes are inlined without source tracking. @@ -2918,28 +3292,31 @@ def gn(scale_obj, x): res.sum().backward() actual = normalize_gm(backend.graphs[0].print_readable(print_output=False)) + # Normalize module-qualified opaque type names since they differ + # depending on how the test is invoked (__main__ vs test_opaque_obj_v2). + actual = re.sub(r"\w+_OpaqueMultiplier", "OpaqueMultiplier", actual) self.assertExpectedInline( actual, """\ class GraphModule(torch.nn.Module): - def forward(self, L_x_: "f32[2, 2]", L_scale_obj_ : __main___OpaqueMultiplier): + def forward(self, L_x_: "f32[2, 2]", L_scale_obj_ : OpaqueMultiplier): l_x_ = L_x_ l_scale_obj_ = L_scale_obj_ subgraph_0 = self.subgraph_0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_scale_obj_, l_x_); subgraph_0 = l_scale_obj_ = l_x_ = None - getitem_2: "f32[2, 2]" = invoke_subgraph[0]; invoke_subgraph = None + getitem: "f32[2, 2]" = invoke_subgraph[0]; invoke_subgraph = None - add: "f32[2, 2]" = getitem_2 + getitem_2; getitem_2 = None + add: "f32[2, 2]" = getitem + getitem; getitem = None return (add,) class subgraph_0(torch.nn.Module): - def forward(self, l_scale_obj_ : __main___OpaqueMultiplier, l_x_: "f32[2, 2]"): + def forward(self, l_scale_obj_ : OpaqueMultiplier, l_x_: "f32[2, 2]"): result: "f32[2, 2]" = torch.ops._TestOpaqueObject.mul_with_scale(l_scale_obj_, l_x_); l_scale_obj_ = l_x_ = None result_1: "f32[2, 2]" = result * 2; result = None return (result_1,) -""", # noqa: B950 +""", ) self.assertEqual(ref, res) @@ -3032,7 +3409,7 @@ def forward(self, x): def forward(self, p_linear_weight, p_linear_bias, obj_lifted_custom_0, x): noisy_inject = torch.ops._TestOpaqueObject.noisy_inject.default(x, obj_lifted_custom_0); obj_lifted_custom_0 = noisy_inject = None linear = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias); x = p_linear_weight = p_linear_bias = None - return (linear,)""", # noqa: B950 + return (linear,)""", ) def test_hoist_no_recompile_on_different_string(self): @@ -3052,6 +3429,152 @@ def f(x, label): self.assertEqual(res2, f(x, "square")) self.assertEqual(cnt.frame_count, 1) + def test_opaque_multi_output_not_tensor_irnode(self): + """OpaqueMultiOutput must not be classified as a tensor IR node. + + _is_tensor_irnode is used to select nodes for stride constraint + functions (require_stride1, require_contiguous, etc). If + OpaqueMultiOutput passes the check, those functions would call + get_stride/get_dtype/make_loader on it and crash.""" + from unittest.mock import MagicMock + + from torch._inductor import ir + from torch._inductor.lowering import _is_tensor_irnode + + opaque_multi = ir.OpaqueMultiOutput.__new__(ir.OpaqueMultiOutput) + self.assertIsInstance(opaque_multi, ir.IRNode) + self.assertFalse(_is_tensor_irnode(opaque_multi)) + + non_tensor = MagicMock(spec=ir.NonTensorObj) + non_tensor.__class__ = ir.NonTensorObj + self.assertFalse(_is_tensor_irnode(non_tensor)) + + self.assertFalse(_is_tensor_irnode(42)) + + @unittest.skipIf(not dist.is_available(), "requires distributed") + def test_fake_script_object_process_group_pybind(self): + """FakeScriptObject wrapping ProcessGroup must be unwrapped in the + pybind toIValue before casting to c10::intrusive_ptr. + + During Dynamo tracing with CooR, mesh_get_process_group returns a + FakeScriptObject-wrapped ProcessGroup. When that flows into a C++ + custom op, toIValue needs to unwrap real_obj before the pybind + cast. This test exercises that path by calling + mesh_get_process_group (which returns a wrapped PG during tracing) + and passing it to another op that consumes the ProcessGroup.""" + from torch.distributed.device_mesh import ( + _register_distributed_opaque_types, + DeviceMesh, + ) + from torch.testing._internal.distributed.fake_pg import FakeStore + + already_initialized = dist.is_initialized() + if already_initialized: + dist.destroy_process_group() + + dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=2) + try: + _register_distributed_opaque_types() + mesh = DeviceMesh("cpu", torch.arange(2)) + + def f(mesh, x): + pg = torch.ops._dtensor.mesh_get_process_group(mesh, 0) + return x + pg.size() + + x = torch.randn(4) + compiled_f = torch.compile(f, backend="aot_eager", fullgraph=True) + result = compiled_f(mesh, x) + expected = f(mesh, x) + self.assertEqual(result, expected) + finally: + dist.destroy_process_group() + if already_initialized: + dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=2) + + def test_enum_export(self): + class Direction(enum.Enum): + UP = 0 + DOWN = 1 + + class Mod(torch.nn.Module): + def forward(self, x, d): + return x + d.value + + ep = torch.export.export(Mod(), (torch.randn(4, 4), Direction.UP), strict=False) + self.assertEqual( + ep.module()(torch.ones(4, 4), Direction.UP), + torch.ones(4, 4) + Direction.UP.value, + ) + self.assertExpectedInline( + normalize_gm(ep.graph_module.print_readable(False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, x: "f32[4, 4]", d): + add: "f32[4, 4]" = torch.ops.aten.add.Tensor(x, 0); x = None + return (add,) +""", + ) + + backend = EagerAndRecordGraphs() + opt_fn = torch.compile(Mod(), backend=backend) + x = torch.randn(4, 4) + res = opt_fn(x, Direction.UP) + self.assertEqual( + res, + x + Direction.UP.value, + ) + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[4, 4]"): + l_x_ = L_x_ + + add: "f32[4, 4]" = l_x_ + 0; l_x_ = None + return (add,) +""", + ) + + def test_enum_custom_op(self): + def get_color(): + class Color(enum.Enum): + RED = 0 + GREEN = 1 + BLUE = 2 + + return Color + + Color = get_color() + + @torch.library.custom_op("test_enum::add_color", mutates_args=()) + def add_color(x: torch.Tensor, c: Color) -> torch.Tensor: + return x + c.value + + @add_color.register_fake + def _(x, c): + return torch.empty_like(x) + + def fn(x, c): + return add_color(x, c) + + x = torch.randn(4, 4) + ref = fn(x, Color.GREEN) + backend = EagerAndRecordGraphs() + opt_fn = torch.compile(fn, backend=backend, fullgraph=True) + res = opt_fn(x, Color.GREEN) + self.assertEqual(ref, res) + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[4, 4]"): + l_x_ = L_x_ + + add_color_default: "f32[4, 4]" = torch.ops.test_enum.add_color.default(l_x_, Color.GREEN); l_x_ = None + return (add_color_default,) +""", + ) + def test_subclass_parametrization_with_opaque_attrs(self): """unwrap_tensor_subclass_parameters should handle non-tensor attrs.""" from torch._functorch._aot_autograd.subclass_parametrization import ( @@ -3154,9 +3677,389 @@ def fn2(x): self.assertEqual(out2._size_store, size2) self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1) + def test_op_passthrough_counter_in_tuple(self): + # When a fake kernel returns its Counter input directly, the getitem + # proxy's example_value is already a FakeScriptObject. + counter_type = get_opaque_type_name(Counter) + torch.library.define( + "_TestOpaqueObject::passthrough_counter", + f"({counter_type} c, Tensor x) -> ({counter_type}, Tensor)", + tags=torch.Tag.pt2_compliant_tag, + lib=self.lib, + ) + + @torch.library.impl( + "_TestOpaqueObject::passthrough_counter", + "CompositeExplicitAutograd", + lib=self.lib, + ) + def passthrough_impl(c: Counter, x: torch.Tensor): + return c, x * c.start + + @torch.library.register_fake( + "_TestOpaqueObject::passthrough_counter", lib=self.lib + ) + def passthrough_fake(c: Counter, x: torch.Tensor): + return c, torch.empty_like(x) + + def fn(c, x): + out_c, out_x = torch.ops._TestOpaqueObject.passthrough_counter(c, x) + return torch.ops._TestOpaqueObject.counter_start(out_c) + out_x + + c = Counter(3, 10) + x = torch.randn(4) + ref = fn(c, x) + opt_fn = torch.compile(fn, fullgraph=True, backend="eager") + res = opt_fn(c, x) + self.assertEqual(ref, res) + + @torch._dynamo.config.patch(skip_fwd_side_effects_in_bwd_under_checkpoint=True) + def test_script_object_intermediate_exposed_from_checkpoint(self): + # A TorchScriptObjectVariable created inside an AC region and accessed + # outside via a list side effect must be exposed as a subgraph output. + import torch.utils.checkpoint + + def gn(x, results): + counter = torch.ops._TestOpaqueObject.create_counter(x.shape[0], x.shape[0]) + results.append(counter) + return x * 2 + + def fn(x): + results = [] + out = torch.utils.checkpoint.checkpoint(gn, x, results, use_reentrant=False) + return torch.ops._TestOpaqueObject.counter_start(results[0]) + out + + x = torch.randn(3, 4) + ref = fn(x) + opt_fn = torch.compile(fn, fullgraph=True, backend="aot_eager_decomp_partition") + res = opt_fn(x) + self.assertEqual(ref, res) + + def test_reference_type_opaque_object_state(self): + """When compile_fx_inner receives an FX graph whose placeholder meta['val'] + is a raw opaque reference type (not wrapped in FakeScriptObject), inductor + handles it via OpaqueObjectState. This codepath is used by compile-on-one-rank + (COOR) where the FX graph is constructed with real objects.""" + m = OpaqueMultiplier(2.0) + x = torch.ones(3) + + graph = torch.fx.Graph() + m_node = graph.placeholder("m") + m_node.meta["val"] = m + fake_mode = FakeTensorMode() + x_node = graph.placeholder("x") + x_node.meta["val"] = fake_mode.from_tensor(x) + out = graph.call_function( + torch.ops._TestOpaqueObject.mul_with_scale.default, (m_node, x_node) + ) + out.meta["val"] = fake_mode.from_tensor(x) + graph.output((out,)) + + gm = torch.fx.GraphModule({}, graph) + compiled = compile_fx_inner(gm, [m, x]) + result = compiled([m, x]) + self.assertEqual(result, (x * 2,)) + + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + def test_benchmark_harness_no_pickle_for_opaque_inputs(self): + """Opaque graph inputs must not be pickled in the benchmark harness.""" + a = torch.randn(4, 4, device=GPU_TYPE) + b = torch.randn(4, 4, device=GPU_TYPE) + twc = TensorWithCounter(a, b, Counter(0, 10), SizeStore(4)) + + def fn(x): + return x + 1 + + compiled_fn = torch.compile(fn, backend="inductor", fullgraph=True) + _, codes = run_and_get_code(compiled_fn, twc) + self.assertGreater(len(codes), 0) + for code in codes: + self.assertNotIn("pickle", code) + + def test_opaque_object_state_in_graph_output(self): + """When compile_fx_inner receives a graph where a raw opaque reference + type appears in the outputs (not just inputs), inductor must handle it. + This happens in CooR (compile-on-one-rank) precompile: aot_autograd + partitions the joint graph so the forward graph saves opaque objects + (e.g. DeviceMesh) for the backward graph by including them in forward + outputs. Inductor must: + 1. Accept OpaqueObjectState in the output type assertion + (GraphLowering.output). + 2. Handle NonTensorObj graph inputs in the memory planner + (get_dep_size_hint) without calling get_numel/get_dtype on them.""" + m = OpaqueMultiplier(2.0) + x = torch.ones(3) + + graph = torch.fx.Graph() + m_node = graph.placeholder("m") + m_node.meta["val"] = m + fake_mode = FakeTensorMode() + x_node = graph.placeholder("x") + x_node.meta["val"] = fake_mode.from_tensor(x) + out = graph.call_function( + torch.ops._TestOpaqueObject.mul_with_scale.default, (m_node, x_node) + ) + out.meta["val"] = fake_mode.from_tensor(x) + # Include the opaque object in the output tuple, simulating how + # aot_autograd's forward graph passes saved-for-backward objects + # through as outputs. + graph.output((out, m_node)) + + gm = torch.fx.GraphModule({}, graph) + compiled = compile_fx_inner(gm, [m, x]) + result = compiled([m, x]) + self.assertEqual(result[0], x * 2) + self.assertIs(result[1], m) + + def test_reconstruct_fn_sets_meta_val(self): + """Opaque nodes created via reconstruct_fn have meta['val'] set. + + When _try_reconstruct_opaque dispatches through a custom op via + Proxy.__torch_function__, the resulting node does not get + meta['val'] set automatically (unlike the __torch_dispatch__ + path which calls track_tensor_tree). The fix in + _try_reconstruct_opaque ensures set_meta is called so that + downstream consumers like the min-cut partitioner can classify + the node correctly via is_opaque_node(). + """ + from torch._functorch._aot_autograd.graph_compile import is_opaque_node + + # Use the already-registered OpaqueMultiplier type. + # Register a reconstruct_fn that derives one multiplier from + # another via a custom op — mirrors how DeviceMesh submeshes + # are derived from a parent mesh via _get_submesh. + multiplier_type = get_opaque_type_name(OpaqueMultiplier) + + self.lib.define( + f"derive_multiplier({multiplier_type} parent, float scale)" + f" -> {multiplier_type}", + ) + + @torch.library.impl( + "_TestOpaqueObject::derive_multiplier", + "CompositeExplicitAutograd", + lib=self.lib, + ) + def derive_impl(parent, scale): + return OpaqueMultiplier(parent.multiplier * scale) + + @torch.library.register_fake( + "_TestOpaqueObject::derive_multiplier", lib=self.lib + ) + def derive_fake(parent, scale): + return OpaqueMultiplier(0.0) + + parent = OpaqueMultiplier(3.0) + # The "child" is derived at eager time before tracing. + # It's NOT passed as an input — it's captured in a closure, + # simulating how DeviceMesh submeshes are captured in DTensor + # backward closures. This forces _try_reconstruct_opaque to + # be called when make_fx encounters the child. + child = OpaqueMultiplier(parent.multiplier * 0.5) + child._parent = parent + child._scale = 0.5 + + from torch._library.opaque_object import _OPAQUE_TYPES + + original_reconstruct_fn = _OPAQUE_TYPES[OpaqueMultiplier].reconstruct_fn + + def multiplier_reconstruct_fn(obj, get_tracked_proxy, tracer): + if not hasattr(obj, "_parent"): + return None + parent_proxy = get_tracked_proxy(obj._parent) + if parent_proxy is None: + return None + return torch.ops._TestOpaqueObject.derive_multiplier( + parent_proxy, obj._scale + ) + + _OPAQUE_TYPES[OpaqueMultiplier].reconstruct_fn = multiplier_reconstruct_fn + try: + # child is captured from the closure, NOT an input + def fn(parent, x): + return torch.ops._TestOpaqueObject.mul_with_scale(child, x) + + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + with fake_mode: + fake_x = torch.randn(4) + + gm = make_fx(fn, tracing_mode="fake")(parent, fake_x) + + # Find the derive_multiplier node (created by reconstruct_fn) + derive_nodes = [ + n + for n in gm.graph.nodes + if n.op == "call_function" and "derive_multiplier" in str(n.target) + ] + self.assertGreater(len(derive_nodes), 0, "No derive_multiplier node found") + + for node in derive_nodes: + self.assertIn( + "val", + node.meta, + f"Node {node.name} created via reconstruct_fn is missing " + f"meta['val']. This would cause the partitioner to fail " + f"with 'Expected {node.name} to be a tensor'.", + ) + self.assertTrue( + is_opaque_node(node), + f"Node {node.name} should be classified as opaque", + ) + finally: + _OPAQUE_TYPES[OpaqueMultiplier].reconstruct_fn = original_reconstruct_fn + + def test_partitioner_must_save_opaque_node(self): + """Opaque nodes tagged MUST_SAVE go to saved_opaque_nodes. + + When activation checkpointing tags an opaque node with + CheckpointPolicy.MUST_SAVE, the default_partition function + must route it to saved_opaque_nodes (not saved_values). + Otherwise the runtime assertion in save_from_forward will + fail because it expects all saved_values to be Tensors. + """ + from functorch.compile import default_partition + from torch._functorch._aot_autograd.graph_compile import is_opaque_node + from torch.utils.checkpoint import CheckpointPolicy + + m = OpaqueMultiplier(2.0) + + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + + # Build a joint fwd+bwd graph where the opaque node is tagged + # MUST_SAVE — simulates activation checkpointing. + graph = torch.fx.Graph() + m_node = graph.placeholder("m") + m_node.meta["val"] = m + x_node = graph.placeholder("x") + with fake_mode: + x_node.meta["val"] = torch.randn(3) + + mul_node = graph.call_function( + torch.ops._TestOpaqueObject.mul_with_scale.default, + (m_node, x_node), + ) + with fake_mode: + mul_node.meta["val"] = torch.randn(3) + + # Tag the opaque node as MUST_SAVE + m_node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + + # Create backward portion: identity backward + bw_node = graph.call_function(torch.ops.aten.mul.Scalar, (mul_node, 1.0)) + with fake_mode: + bw_node.meta["val"] = torch.randn(3) + + graph.output((mul_node, bw_node)) + joint_gm = torch.fx.GraphModule({}, graph) + + # Run default_partition — this should NOT raise even though + # the opaque node is tagged MUST_SAVE. + num_fwd_outputs = 1 + fw_module, bw_module = default_partition( + joint_gm, [], num_fwd_outputs=num_fwd_outputs + ) + + # Verify the opaque node ended up in the forward graph as + # a pass-through (not as a saved tensor that would fail the + # save_from_forward assertion). + fw_opaque_nodes = [ + n + for n in fw_module.graph.nodes + if n.op == "placeholder" and is_opaque_node(n) + ] + self.assertGreater( + len(fw_opaque_nodes), + 0, + "Forward graph should have opaque placeholder nodes", + ) + instantiate_parametrized_tests(TestOpaqueObject) +@unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") +class TestOpaqueGenerator(TestCase): + def test_make_fx_with_generator(self): + """make_fx should trace through Generator inputs as opaque values.""" + from torch._prims.rng_prims import graphsafe_run_with_rng_state + + class M(torch.nn.Module): + def forward(self, q, k, v, rng_state): + out = graphsafe_run_with_rng_state( + torch.ops.aten._scaled_dot_product_efficient_attention.default, + q, + k, + v, + None, + True, + 0.1, + True, + rng_state=rng_state, + ) + return out[0] + + q = torch.randn(2, 8, 64, 32, device="cuda", dtype=torch.float16) + k = torch.randn(2, 8, 64, 32, device="cuda", dtype=torch.float16) + v = torch.randn(2, 8, 64, 32, device="cuda", dtype=torch.float16) + gen = torch.cuda.default_generators[0].clone_state() + + gm = make_fx(M(), tracing_mode="real")(q, k, v, gen) + + # The last placeholder (generator) should be used by graphsafe_run_with_rng_state + placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] + gen_placeholder = placeholders[-1] + self.assertEqual(len(gen_placeholder.users), 1) + user = next(iter(gen_placeholder.users)) + self.assertIs(user.target, graphsafe_run_with_rng_state) + + # Verify the traced graph produces the same result as eager. + # Use dropout_p=0.0 so the result is deterministic. + class M0(torch.nn.Module): + def forward(self, q, k, v, rng_state): + out = graphsafe_run_with_rng_state( + torch.ops.aten._scaled_dot_product_efficient_attention.default, + q, + k, + v, + None, + False, + 0.0, + True, + rng_state=rng_state, + ) + return out[0] + + gen1 = torch.cuda.default_generators[0].clone_state() + gen2 = torch.cuda.default_generators[0].clone_state() + gm0 = make_fx(M0(), tracing_mode="real")(q, k, v, gen1) + expected = M0()(q, k, v, gen2) + gen3 = torch.cuda.default_generators[0].clone_state() + actual = gm0(q, k, v, gen3) + self.assertEqual(actual, expected) + + def test_make_fx_randn_with_generator(self): + """make_fx should trace torch.randn with a Generator input.""" + + def fn(a, generator): + return torch.randn([20, 20], generator=generator, device=a.device) + + gen = torch.Generator("cuda") + gm = make_fx(fn, tracing_mode="real")(torch.randn(4, device="cuda"), gen) + + # Generator is baked in as a get_attr constant (not a placeholder input) + # because torch.randn passes it directly to C++ without going through + # proxy dispatch. The generator placeholder has 0 users. + self.assertExpectedInline( + normalize_gm(gm.print_readable(False)), + """\ +class fn(torch.nn.Module): + def forward(self, a_1: "f32[4]", generator_1): + _opaque_obj0 = self._opaque_obj0 + randn: "f32[20, 20]" = torch.ops.aten.randn.generator([20, 20], generator = _opaque_obj0, device = device(type='cuda', index=0), pin_memory = False); _opaque_obj0 = None + return randn +""", + ) + + if __name__ == "__main__": run_tests() diff --git a/test/test_ops.py b/test/test_ops.py index 83e7d0d311afc..e89df86607076 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -39,6 +39,7 @@ from torch.testing._internal.common_dtype import ( all_types_and_complex_and, floating_and_complex_types_and, + highest_precision_float, integral_types_and, ) from torch.testing._internal.common_methods_invocations import ( @@ -467,6 +468,7 @@ def test_reduction_ops_reduce(self, device, op): # resulting in possible equality check failures. # skip windows case on CPU due to https://github.com/pytorch/pytorch/issues/129947 # XPU test will be enabled step by step. Skip the tests temporarily. + # MPS does not support double precision, so single precision has to be used instead. @skipXPU @onlyNativeDeviceTypesAnd(["hpu"]) @suppress_warnings @@ -479,11 +481,11 @@ def test_numpy_ref(self, device, dtype, op): and dtype == torch.float64 and ("cuda" in device or "xpu" in device) or "cpu" in device - ): # noqa: E121 + ): raise unittest.SkipTest("XXX: raises tensor-likes are not close.") # Sets the default dtype to NumPy's default dtype of double - with set_default_dtype(torch.double): + with set_default_dtype(highest_precision_float(device)): for sample_input in op.reference_inputs(device, dtype): self.compare_with_reference( op, op.ref, sample_input, exact_dtype=(dtype is not torch.long) @@ -635,9 +637,13 @@ def _ref_test_helper( # precise dtypes -- they simply must be close precise_dtype = dtype if prims.utils.is_float_dtype(dtype): - precise_dtype = torch.double + precise_dtype = highest_precision_float(device) if prims.utils.is_complex_dtype(dtype): - precise_dtype = torch.cdouble + precise_dtype = ( + torch.complex32 + if torch.device(device).type == "mps" + else torch.cdouble + ) # Checks if the results are close try: diff --git a/test/test_optim.py b/test/test_optim.py index 6de3cabefc417..23094907f944b 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -43,6 +43,7 @@ markDynamoStrictTest, parametrize, run_tests, + serialTest, TEST_WITH_TORCHDYNAMO, TestCase, ) @@ -65,6 +66,38 @@ def drosenbrock(tensor): return torch.stack((-400 * x * (y - x**2) - 2 * (1 - x), 200 * (y - x**2))) +def _bf16_state_init_hook(optimizer, args, kwargs): + """Step pre-hook that initializes Adam/AdamW states in bfloat16. + + Pre-populates optimizer state before Adam's lazy initialization so that + ``_init_group`` finds non-empty state and skips its own fp32 allocation. + The fused CUDA kernel then dispatches to its mixed-precision path. + """ + for group in optimizer.param_groups: + for p in group["params"]: + if p.grad is None: + continue + state = optimizer.state[p] + if len(state) == 0: + state["step"] = ( + torch.zeros((), dtype=torch.float32, device=p.device) + if group.get("capturable") or group.get("fused") + else torch.tensor(0.0, dtype=torch.float32) + ) + state["exp_avg"] = torch.zeros_like( + p, dtype=torch.bfloat16, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = torch.zeros_like( + p, dtype=torch.bfloat16, memory_format=torch.preserve_format + ) + if group.get("amsgrad"): + state["max_exp_avg_sq"] = torch.zeros_like( + p, + dtype=torch.bfloat16, + memory_format=torch.preserve_format, + ) + + @markDynamoStrictTest class TestOptimRenewed(TestCase): """ @@ -964,6 +997,7 @@ def test_set_default_dtype_works_with_foreach(self, device, dtype, optim_info): @onlyCUDA @largeTensorTest("72GB", "cuda") + @serialTest() @optims( [optim for optim in optim_db if "foreach" in optim.supported_impls], dtypes=[torch.float16], @@ -1129,6 +1163,7 @@ def test_fused_error_on_params_on_meta(self, device, dtype, optim_info): @onlyNativeDeviceTypes @largeTensorTest("64GB") + @serialTest() @optims( [optim for optim in optim_db if "fused" in optim.supported_impls], dtypes=[torch.float16], @@ -2313,6 +2348,160 @@ def test_non_empty_state(self, device, dtype, optim_info): for state in optim.state.values(): self.assertGreater(len(state), 0) + @onlyCUDA + @parametrize("amsgrad", [False, True]) + @optims( + [o for o in optim_db if o.optim_cls.__name__ in ["Adam", "AdamW"]], + dtypes=[torch.float32], + ) + def test_fused_mixed_precision_state_init(self, device, dtype, optim_info, amsgrad): + optim_cls = optim_info.optim_cls + params = [torch.rand(20, 7, device=device, dtype=dtype) for _ in range(5)] + for p in params: + p.grad = torch.rand_like(p) + + optim = optim_cls(params, lr=1e-3, fused=True, amsgrad=amsgrad) + optim.register_step_pre_hook(_bf16_state_init_hook) + + optim.step() + + for p in params: + self.assertEqual(p.dtype, torch.float32) + state = optim.state[p] + self.assertEqual(state["step"].dtype, torch.float32) + self.assertEqual(state["exp_avg"].dtype, torch.bfloat16) + self.assertEqual(state["exp_avg_sq"].dtype, torch.bfloat16) + if amsgrad: + self.assertEqual(state["max_exp_avg_sq"].dtype, torch.bfloat16) + + # Second step: hook should be idempotent (skips already-populated state) + for p in params: + p.grad = torch.rand_like(p) + optim.step() + + for p in params: + state = optim.state[p] + self.assertEqual(state["step"].dtype, torch.float32) + self.assertEqual(state["exp_avg"].dtype, torch.bfloat16) + self.assertEqual(state["exp_avg_sq"].dtype, torch.bfloat16) + if amsgrad: + self.assertEqual(state["max_exp_avg_sq"].dtype, torch.bfloat16) + + @onlyCUDA + @parametrize("amsgrad", [False, True]) + @optims( + [o for o in optim_db if o.optim_cls.__name__ in ["Adam", "AdamW"]], + dtypes=[torch.float32], + ) + def test_fused_mixed_precision_hook_skips_existing_state( + self, device, dtype, optim_info, amsgrad + ): + optim_cls = optim_info.optim_cls + + # Two param groups: group 1 gets f32 state pre-populated (hook should + # skip it), group 2 has no state (hook should initialize it in bf16). + # This exercises the fused kernel handling two groups whose states have + # different dtypes within the same optimizer.step() call. + g1_params = [torch.rand(10, 5, device=device, dtype=dtype) for _ in range(2)] + g2_params = [torch.rand(10, 5, device=device, dtype=dtype) for _ in range(2)] + for p in g1_params + g2_params: + p.grad = torch.rand_like(p) + + optim = optim_cls( + [{"params": g1_params}, {"params": g2_params}], + lr=1e-3, + fused=True, + amsgrad=amsgrad, + ) + + for p in g1_params: + optim.state[p]["step"] = torch.zeros( + (), dtype=torch.float32, device=p.device + ) + optim.state[p]["exp_avg"] = torch.zeros_like(p) + optim.state[p]["exp_avg_sq"] = torch.zeros_like(p) + if amsgrad: + optim.state[p]["max_exp_avg_sq"] = torch.zeros_like(p) + + optim.register_step_pre_hook(_bf16_state_init_hook) + optim.step() + + # Group 1: hook skipped (state was non-empty), dtypes stay f32. + for p in g1_params: + state = optim.state[p] + self.assertEqual(state["step"].dtype, torch.float32) + self.assertEqual(state["exp_avg"].dtype, torch.float32) + self.assertEqual(state["exp_avg_sq"].dtype, torch.float32) + if amsgrad: + self.assertEqual(state["max_exp_avg_sq"].dtype, torch.float32) + + # Group 2: hook initialized state in bf16. + for p in g2_params: + state = optim.state[p] + self.assertEqual(state["step"].dtype, torch.float32) + self.assertEqual(state["exp_avg"].dtype, torch.bfloat16) + self.assertEqual(state["exp_avg_sq"].dtype, torch.bfloat16) + if amsgrad: + self.assertEqual(state["max_exp_avg_sq"].dtype, torch.bfloat16) + + @onlyCUDA + @optims( + [o for o in optim_db if o.optim_cls.__name__ in ["Adam", "AdamW"]], + dtypes=[torch.float32], + ) + def test_fused_mixed_precision_numerics(self, device, dtype, optim_info): + optim_inputs = optim_info.optim_inputs_func(device=device, dtype=dtype) + optim_cls = optim_info.optim_cls + for optim_input in optim_inputs: + kwargs = {**optim_input.kwargs, "fused": True} + + params = [torch.rand(20, 7, device=device, dtype=dtype) for _ in range(10)] + for p in params: + p.grad = torch.rand_like(p) + + params_c = [p.clone() for p in params] + for p, pc in zip(params, params_c): + pc.grad = p.grad.clone() + + ref_optim = optim_cls(params, **kwargs) + bf16_optim = optim_cls(params_c, **kwargs) + bf16_optim.register_step_pre_hook(_bf16_state_init_hook) + + # Simulate bf16 storage: after each ref step, quantize states to + # bf16 and back so the reference matches the mixed-precision kernel. + tracker = TensorTracker() + for i in range(7): + ref_optim.step() + bf16_optim.step() + for p in params: + tracker.add(p) + tracker.add(p.grad) + for d in ref_optim.state.values(): + exp_avg_bf16 = d["exp_avg"].to(torch.bfloat16) + tracker.add(exp_avg_bf16) + d["exp_avg"] = exp_avg_bf16.to(torch.float32) + exp_avg_sq_bf16 = d["exp_avg_sq"].to(torch.bfloat16) + tracker.add(exp_avg_sq_bf16) + d["exp_avg_sq"] = exp_avg_sq_bf16.to(torch.float32) + if "max_exp_avg_sq" in d: + max_exp_avg_sq_bf16 = d["max_exp_avg_sq"].to(torch.bfloat16) + tracker.add(max_exp_avg_sq_bf16) + d["max_exp_avg_sq"] = max_exp_avg_sq_bf16.to(torch.float32) + + for e, pc in enumerate(params_c): + tracker.pop_check_set(pc, self) + tracker.pop_check_set(pc.grad, self) + + for p, pc in zip(params, params_c): + self.assertEqual(p, pc) + + for dc in bf16_optim.state.values(): + tracker.pop_check_set(dc["exp_avg"], self) + tracker.pop_check_set(dc["exp_avg_sq"], self) + if "max_exp_avg_sq" in dc: + tracker.pop_check_set(dc["max_exp_avg_sq"], self) + self.assertTrue(tracker.all_popped()) + @parametrize("dtype", [torch.float32]) def test_step_iteration(self, device, dtype): def _get_model_and_input_tensor(device, dtype): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 8296f386f0977..fe3c5347dadaf 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -656,7 +656,7 @@ def __init__( self.layer_norm = torch.nn.LayerNorm(input_dim) - def forward(mod_self, x): # noqa: B902 + def forward(mod_self, x): self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor)) y = mod_self.layer_norm(x) self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor)) @@ -1212,7 +1212,7 @@ def forward(self, x_1): _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None randn = torch.ops.aten.randn.default([3, _local_scalar_dense, 3], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None cumsum = torch.ops.aten.cumsum.default(randn, 0); randn = None - return cumsum""" # noqa: B950 + return cumsum""" ) @@ -1229,7 +1229,7 @@ def forward(self, x_1, y_1): _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(sum_1); sum_1 = None repeat_interleave = torch.ops.aten.repeat_interleave.Tensor(x_1, output_size = _local_scalar_dense); x_1 = _local_scalar_dense = None index_select = torch.ops.aten.index_select.default(y_1, 0, repeat_interleave); y_1 = repeat_interleave = None - return index_select""" # noqa: B950 + return index_select""" ) def test_arange_unbacked_output_size(self): @@ -1242,7 +1242,7 @@ def f(x): def forward(self, x_1): _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None arange = torch.ops.aten.arange.start(0, _local_scalar_dense, device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None - return arange""" # noqa: B950 + return arange""" ) def test_adv_index_batch(self): @@ -1347,7 +1347,7 @@ def f(a): def forward(self, a_1): _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1); a_1 = None empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None - return empty""" # noqa: B950 + return empty""" ) @@ -1366,7 +1366,7 @@ def forward(self, x_1): scalar_tensor = torch.ops.aten.scalar_tensor.default(sym_size_int, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')); sym_size_int = None select = torch.ops.aten.select.int(x_1, 0, 0) copy_ = torch.ops.aten.copy_.default(select, scalar_tensor); select = scalar_tensor = copy_ = None - return x_1""" # noqa: B950 + return x_1""" ) def test_dynamic_pointwise_scalar(self): @@ -1422,7 +1422,7 @@ def forward(self, crop_camera_1, mask_1): mm = torch.ops.aten.mm.default(view_3, eye); view_3 = eye = None _unsafe_view = torch.ops.aten._unsafe_view.default(mm, [sym_size_int, 3, 3]); mm = sym_size_int = None index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], _unsafe_view); crop_camera_1 = mask_1 = _unsafe_view = index_put_ = None - return None""") # noqa: B950 + return None""") def test_unbacked_slice(self): def f(x, m): @@ -1501,7 +1501,7 @@ def forward(self, x_1, y_1): _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None add = torch.ops.aten.add.Tensor(zeros, y_1); zeros = y_1 = None - return add""") # noqa: B950 + return add""") def test_reshape_divisibility_unbacked(self): def f(x): @@ -1545,7 +1545,7 @@ def forward(self, x_1, y_1): _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = zeros = None add = torch.ops.aten.add.Tensor(y_1, 2); y_1 = None - return add""") # noqa: B950 + return add""") @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') @unittest.expectedFailure @@ -1569,7 +1569,7 @@ def f(x1, x2, y): gm.recompile() r = str(gm.code).strip() # self.assertExpectedInline( - # r, """""" # noqa: B950 + # r, """""" # ) @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') @@ -1624,7 +1624,7 @@ def forward(self, lengths_1, values_1): getitem = split_with_sizes[0] getitem_1 = split_with_sizes[1] getitem_2 = split_with_sizes[2]; split_with_sizes = None - return (getitem, getitem_1, getitem_2)""") # noqa: B950 + return (getitem, getitem_1, getitem_2)""") def test_invalidate_nonzero(self): ok = False @@ -1792,7 +1792,7 @@ def f(x): def forward(self, x_1): _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x_1); x_1 = None zeros = torch.ops.aten.zeros.default([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None - return zeros""") # noqa: B950 + return zeros""") def test_expand(self): def f(a): @@ -1870,12 +1870,12 @@ def f(a, b): fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8)) from torch._dynamo.source import LocalSource self.assertExpectedInline( - str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)), # noqa: B950 - """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]""" # noqa: B950 + str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)), + """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]""" ) self.assertExpectedInline( - str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)), # noqa: B950 - """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]""" # noqa: B950 + str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)), + """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]""" ) def test_guard_upperbound_range_refinement(self): @@ -2029,8 +2029,6 @@ def f(t): xfail('cov'), xfail('nn.functional.gaussian_nll_loss'), xfail('corrcoef'), - xfail('quantile'), - xfail('nanquantile'), # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse xfail('sparse.sampled_addmm'), @@ -2061,11 +2059,8 @@ def f(t): xfail('geqrf', ''), # aten.geqrf.default - couldn't find symbolic meta function/decomposition xfail('histogram', ''), # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c... xfail('histogramdd', ''), # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition - xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. - xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom... xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition - xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. xfail('max_pool2d_with_indices_backward', ''), # Expected a value of type 'List[int]' for argument 'kernel_size' but... } diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py index 2eac18eead5ed..f47915361415d 100644 --- a/test/test_public_bindings.py +++ b/test/test_public_bindings.py @@ -184,7 +184,6 @@ def test_no_new_bindings(self): "PyTorchFileReader", "PyTorchFileWriter", "qscheme", - "read_vitals", "RRefType", "ScriptClass", "ScriptClassFunction", @@ -211,7 +210,6 @@ def test_no_new_bindings(self): "set_flush_denormal", "set_num_interop_threads", "set_num_threads", - "set_vital", "Size", "StaticModule", "Stream", @@ -232,7 +230,6 @@ def test_no_new_bindings(self): "Value", "set_autocast_gpu_dtype", "get_autocast_gpu_dtype", - "vitals_enabled", "wait", "Tag", "set_autocast_xla_enabled", @@ -415,9 +412,10 @@ def onerror(modname): errors = [] for mod, exc in failures: - if mod in private_allowlist: - # make sure mod is actually private - if not any(t.startswith("_") for t in mod.split(".")): + if mod in private_allowlist or ( + mod.startswith("torch._native.ops.") and "triton" in mod + ): + if self._is_mod_public(mod): raise AssertionError( f"Expected private module name to include '_' segments: {mod}" ) diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py index 0c69b1c61003e..71afd0c09851e 100644 --- a/test/test_python_dispatch.py +++ b/test/test_python_dispatch.py @@ -26,7 +26,9 @@ from torch.testing._internal.common_methods_invocations import op_db from torch.testing._internal.common_utils import ( first_sample, + instantiate_parametrized_tests, IS_WINDOWS, + parametrize, run_tests, TEST_WITH_ROCM, TestCase, @@ -47,6 +49,8 @@ _get_current_dispatch_mode, _get_current_dispatch_mode_stack, is_in_torch_dispatch_mode, + is_traceable_wrapper_subclass, + is_traceable_wrapper_subclass_type, TorchDispatchMode, ) from torch.utils._pytree import tree_map, tree_map_only @@ -57,6 +61,49 @@ def _identity(x): return x +class _TraceableWrapperSubclassTestBase(torch.Tensor): + elem: torch.Tensor + + __slots__ = ["elem"] + + @staticmethod + def __new__(cls, elem): + r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, + elem.size(), + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=elem.requires_grad, + strides=elem.stride(), + storage_offset=elem.storage_offset(), + ) + r.elem = elem + return r + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + raise RuntimeError("NYI") + + +class _StaticUnflattenWrapper(_TraceableWrapperSubclassTestBase): + def __tensor_flatten__(self): + return ["elem"], {"kind": "static"} + + @staticmethod + def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): + return _StaticUnflattenWrapper(inner_tensors["elem"]) + + +class _ClassmethodUnflattenWrapper(_TraceableWrapperSubclassTestBase): + def __tensor_flatten__(self): + return ["elem"], {"kind": "classmethod"} + + @classmethod + def __tensor_unflatten__(cls, inner_tensors, metadata, outer_size, outer_stride): + return cls(inner_tensors["elem"]) + + class TestDispatcherPythonBindings(TestCase): def test_call_boxed(self) -> None: sin = torch._C._dispatch_find_schema_or_throw("aten::sin", "") @@ -219,9 +266,6 @@ def my_fallback(op, *args, **kwargs): self.assertEqual(c, a + b) self.assertTrue(is_called) - @unittest.skip( - "Causing flakiness, see https://github.com/pytorch/pytorch/issues/145108" - ) def test_fallthrough_for_dense_key_with_meta_in_tls(self) -> None: # This tests that if meta is included in TlS dispatch key set, # then a meta kernel should be called regardless if a dense @@ -244,6 +288,29 @@ def sum_meta(*args, **kwargs): torch.ops.custom.sum.default(a) self.assertTrue(meta_is_called) + def test_include_dispatch_key_guard_restores_tls_exactly(self) -> None: + before = torch._C._dispatch_tls_local_include_set().raw_repr() + with torch._C._IncludeDispatchKeyGuard(torch.DispatchKey.Meta): + pass + after = torch._C._dispatch_tls_local_include_set().raw_repr() + self.assertEqual(before, after) + + @parametrize( + "key", + [ + torch.DispatchKey.Meta, + torch.DispatchKey.CUDA, + torch.DispatchKey.CPU, + ], + ) + def test_exclude_dispatch_key_guard_restores_tls_exactly(self, key) -> None: + keyset = torch._C.DispatchKeySet(key) + before = torch._C._dispatch_tls_local_exclude_set().raw_repr() + with torch._C._ExcludeDispatchKeyGuard(keyset): + pass + after = torch._C._dispatch_tls_local_exclude_set().raw_repr() + self.assertEqual(before, after) + def test_dispatchkeyset_pickle(self) -> None: keyset = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU) serialized = pickle.dumps(keyset) @@ -689,6 +756,9 @@ def test_register_fallthrough(self): self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16) +instantiate_parametrized_tests(TestPythonRegistration) + + class TestPythonDispatch(TestCase): def test_basic(self) -> None: with capture_logs() as logs: @@ -1055,6 +1125,49 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): self.assertEqual(type(torch.full_like(MyTensor(2), 1.0)), MyTensor) self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor) + def test_traceable_wrapper_subclass_protocol_runtime_check(self) -> None: + @torch._dynamo.disable + def run_checks() -> None: + base = torch.randn(2, 2) + static = _StaticUnflattenWrapper(base) + classmethod_wrapper = _ClassmethodUnflattenWrapper(base) + + self.assertFalse(is_traceable_wrapper_subclass(base)) + self.assertTrue(is_traceable_wrapper_subclass(static)) + self.assertTrue(is_traceable_wrapper_subclass(classmethod_wrapper)) + self.assertTrue(is_traceable_wrapper_subclass_type(type(static))) + self.assertTrue( + is_traceable_wrapper_subclass_type(type(classmethod_wrapper)) + ) + self.assertFalse(is_traceable_wrapper_subclass_type(torch.Tensor)) + + static_attrs, static_meta = static.__tensor_flatten__() + static_rebuilt = type(static).__tensor_unflatten__( + {attr: getattr(static, attr) for attr in static_attrs}, + static_meta, + static.size(), + static.stride(), + ) + self.assertIs(type(static_rebuilt), _StaticUnflattenWrapper) + self.assertEqual(static_rebuilt.elem, static.elem) + + classmethod_attrs, classmethod_meta = ( + classmethod_wrapper.__tensor_flatten__() + ) + classmethod_rebuilt = type(classmethod_wrapper).__tensor_unflatten__( + { + attr: getattr(classmethod_wrapper, attr) + for attr in classmethod_attrs + }, + classmethod_meta, + classmethod_wrapper.size(), + classmethod_wrapper.stride(), + ) + self.assertIs(type(classmethod_rebuilt), _ClassmethodUnflattenWrapper) + self.assertEqual(classmethod_rebuilt.elem, classmethod_wrapper.elem) + + run_checks() + def test_make_fx_with_subclass(self) -> None: def f(x, y): # Returns (TwoTensor, Tensor) diff --git a/test/test_quantization.py b/test/test_quantization.py index 42e145edbab3f..0355246d32e59 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -72,7 +72,7 @@ except ImportError as e: # In FBCode we separate FX out into a separate target for the sake of dev # velocity. These are covered by a separate test target `quantization_fx` - log.warning(e) # noqa:G200 + log.warning(e) try: from quantization.fx.test_numeric_suite_fx import TestFXGraphMatcher # noqa: F401 @@ -81,7 +81,7 @@ from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteNShadows # noqa: F401 from quantization.fx.test_numeric_suite_fx import TestFXNumericSuiteCoreAPIsModels # noqa: F401 except ImportError as e: - log.warning(e) # noqa:G200 + log.warning(e) # Test the model report module try: @@ -93,19 +93,19 @@ from quantization.fx.test_model_report_fx import TestFxDetectOutliers # noqa: F401 from quantization.fx.test_model_report_fx import TestFxModelReportVisualizer # noqa: F401 except ImportError as e: - log.warning(e) # noqa:G200 + log.warning(e) # Equalization for FX mode try: from quantization.fx.test_equalize_fx import TestEqualizeFx # noqa: F401 except ImportError as e: - log.warning(e) # noqa:G200 + log.warning(e) # Backward Compatibility. Tests serialization and BC for quantized modules. try: from quantization.bc.test_backward_compatibility import TestSerialization # noqa: F401 except ImportError as e: - log.warning(e) # noqa:G200 + log.warning(e) # JIT Graph Mode Quantization from quantization.jit.test_quantize_jit import TestQuantizeJit # noqa: F401 @@ -124,29 +124,29 @@ try: from quantization.ao_migration.test_quantization_fx import TestAOMigrationQuantizationFx # noqa: F401 except ImportError as e: - log.warning(e) # noqa:G200 + log.warning(e) # Experimental functionality try: from quantization.core.experimental.test_bits import TestBitsCPU # noqa: F401 except ImportError as e: - log.warning(e) # noqa:G200 + log.warning(e) try: from quantization.core.experimental.test_bits import TestBitsCUDA # noqa: F401 except ImportError as e: - log.warning(e) # noqa:G200 + log.warning(e) try: from quantization.core.experimental.test_floatx import TestFloat8DtypeCPU # noqa: F401 except ImportError as e: - log.warning(e) # noqa:G200 + log.warning(e) try: from quantization.core.experimental.test_floatx import TestFloat8DtypeCUDA # noqa: F401 except ImportError as e: - log.warning(e) # noqa:G200 + log.warning(e) try: from quantization.core.experimental.test_floatx import TestFloat8DtypeCPUOnlyCPU # noqa: F401 except ImportError as e: - log.warning(e) # noqa:G200 + log.warning(e) if __name__ == '__main__': run_tests() diff --git a/test/test_reductions.py b/test/test_reductions.py index a6afa48308557..81dcee3d8bfcc 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -14,17 +14,19 @@ from torch import inf, nan from torch.testing import make_tensor from torch.testing._internal.common_dtype import ( - all_types_and_complex_and, get_all_math_dtypes, integral_types, complex_types, floating_types_and, + all_types_and_complex_and, get_all_math_dtypes, highest_precision_float, + integral_types, complex_types, floating_types_and, integral_types_and, floating_and_complex_types_and, all_types_and, all_types, ) from torch.testing._internal.common_utils import ( TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict, parametrize, + skipIfMPS, skipIfTorchDynamo, IS_WINDOWS) from torch.testing._internal.common_device_type import ( - OpDTypes, expectedFailureMeta, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, - dtypesIfXPU, onlyNativeDeviceTypes, onlyCUDA, onlyOn, largeTensorTest, ops, precisionOverride) + OpDTypes, expectedFailureMeta, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, + dtypesIfCPU, dtypesIfXPU, onlyNativeDeviceTypes, onlyCUDA, onlyOn, largeTensorTest, ops, precisionOverride) from torch.testing._internal.common_methods_invocations import ( ReductionOpInfo, ReductionPythonRefInfo, reduction_ops, reference_masked_ops) @@ -126,24 +128,28 @@ def test_dim_default(self, device, op: ReductionOpInfo): for ndim in range(3): self._test_dim_keepdim(op, device, ndim=ndim) + @skipIfMPS @ops(reduction_ops, dtypes=OpDTypes.none) def test_dim_default_keepdim(self, device, op: ReductionOpInfo): """Tests that the default dim, when keepdim=True, reduces all dimensions to size 1.""" for ndim in range(3): self._test_dim_keepdim(op, device, ndim=ndim, keepdim=True) + @skipIfMPS @ops(reduction_ops, dtypes=OpDTypes.none) def test_dim_none(self, device, op: ReductionOpInfo): """Tests that dim=None reduces all dimensions.""" for ndim in range(3): self._test_dim_keepdim(op, device, ndim=ndim, dim=None) + @skipIfMPS @ops(reduction_ops, dtypes=OpDTypes.none) def test_dim_none_keepdim(self, device, op: ReductionOpInfo): """Tests that dim=None, when keepdim=True, reduces all dimensions to size 1.""" for ndim in range(3): self._test_dim_keepdim(op, device, ndim=ndim, dim=None, keepdim=True) + @skipIfMPS @ops(reduction_ops, dtypes=OpDTypes.none) def test_dim_single(self, device, op: ReductionOpInfo): """Tests that dim=i reduces dimension i.""" @@ -152,6 +158,7 @@ def test_dim_single(self, device, op: ReductionOpInfo): self._test_dim_keepdim(op, device, ndim=2, dim=-1) self._test_dim_keepdim(op, device, ndim=3, dim=1) + @skipIfMPS @ops(reduction_ops, dtypes=OpDTypes.none) def test_dim_single_keepdim(self, device, op: ReductionOpInfo): """Tests that dim=i, when keepdim=True, reduces dimension i to size 1.""" @@ -160,58 +167,68 @@ def test_dim_single_keepdim(self, device, op: ReductionOpInfo): self._test_dim_keepdim(op, device, ndim=2, dim=-1, keepdim=True) self._test_dim_keepdim(op, device, ndim=3, dim=1, keepdim=True) + @skipIfMPS @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) def test_dim_empty(self, device, op: ReductionOpInfo): """Tests that dim=[] is a no-op""" self._test_dim_keepdim(op, device, ndim=0, dim=[]) self._test_dim_keepdim(op, device, ndim=2, dim=[]) + @skipIfMPS @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) def test_dim_empty_keepdim(self, device, op: ReductionOpInfo): """Tests that dim=[], when keepdim=True, is a no-op""" self._test_dim_keepdim(op, device, ndim=0, dim=[], keepdim=True) self._test_dim_keepdim(op, device, ndim=2, dim=[], keepdim=True) + @skipIfMPS @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) def test_dim_multi(self, device, op: ReductionOpInfo): """Tests that dim=[i, j, ...] reduces dimensions i, j, ....""" self._test_dim_keepdim(op, device, ndim=1, dim=[0]) self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2]) + @skipIfMPS @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) def test_dim_multi_keepdim(self, device, op: ReductionOpInfo): """Tests that dim=[i, j, ...], when keepdim=True, reduces dimensions i, j, .... to size 1.""" self._test_dim_keepdim(op, device, ndim=1, dim=[0], keepdim=True) self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2], keepdim=True) + @skipIfMPS @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) def test_dim_multi_unsorted(self, device, op: ReductionOpInfo): """Tests that operator correctly handles unsorted dim list.""" self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2]) + @skipIfMPS @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) def test_dim_multi_unsorted_keepdim(self, device, op: ReductionOpInfo): """Tests that operator correctly handles unsorted dim list when keepdim=True.""" self._test_dim_keepdim(op, device, ndim=4, dim=[3, 0, 2], keepdim=True) + @skipIfMPS @ops(filter(lambda op: op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) def test_dim_multi_duplicate(self, device, op: ReductionOpInfo): """Tests that an error is raised if dim has duplicate entries.""" with self.assertRaises(RuntimeError): self._test_dim_keepdim(op, device, ndim=3, dim=[0, 1, 1, 2]) + @skipIfMPS @ops(filter(lambda op: not op.supports_multiple_dims, reduction_ops), dtypes=OpDTypes.none) def test_dim_multi_unsupported(self, device, op: ReductionOpInfo): """Tests that ops claiming to not support multi dim actually don't.""" with self.assertRaises(TypeError): self._test_dim_keepdim(op, device, ndim=3, dim=[0, 2]) + @skipIfMPS @ops(reduction_ops, dtypes=OpDTypes.none) def test_dim_offbounds(self, device, op: ReductionOpInfo): """Tests that passing an off-bounds dim throws""" with self.assertRaises(IndexError): self._test_dim_keepdim(op, device, ndim=2, dim=2) + @skipIfMPS @ops(reduction_ops, dtypes=OpDTypes.none) def test_dim_ndim_limit(self, device, op: ReductionOpInfo): """Tests that an exception is raised when reducing a tensor with more @@ -220,6 +237,7 @@ def test_dim_ndim_limit(self, device, op: ReductionOpInfo): with self.assertRaisesRegex(RuntimeError, "only tensors with up to 64 dims are supported"): op(t, dim=0) + @skipIfMPS @ops(filter(lambda op: op.identity is not None, reduction_ops), dtypes=OpDTypes.supported) def test_identity(self, device, dtype, op: ReductionOpInfo): """Tests that the identity value is an identity for the operator""" @@ -256,6 +274,7 @@ def test_nan_policy_omit(self, device, dtype, op: ReductionOpInfo): result_with_nan = op(t, *args, **kwargs) self.assertEqual(result, result_with_nan) + @skipIfMPS @ops(reduction_ops, dtypes=OpDTypes.supported) def test_result_dtype(self, device, dtype, op: ReductionOpInfo): """Tests that the result has the correct dtype""" @@ -279,6 +298,7 @@ def test_result_dtype(self, device, dtype, op: ReductionOpInfo): else: self.assertEqual(result.dtype, dtype) + @skipIfMPS @ops(reduction_ops, dtypes=OpDTypes.none) def test_empty_tensor_empty_slice(self, device, op: ReductionOpInfo): """Tests for consistent behavior when reducing over an empty slice. @@ -312,6 +332,7 @@ def test_empty_tensor_empty_slice(self, device, op: ReductionOpInfo): with self.assertRaises(IndexError): op(t, *args, dim=dim, **kwargs) + @skipIfMPS @ops(reduction_ops, dtypes=OpDTypes.none) def test_empty_tensor_nonempty_slice(self, device, op: ReductionOpInfo): """Tests that reducing a nonempty slice of an empty tensor returns an @@ -334,30 +355,35 @@ def _test_noncontiguous(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_ expected = op(t_contig, *args, **kwargs) self.assertEqual(result, expected) + @skipIfMPS @ops(reduction_ops) def test_noncontiguous_innermost(self, device, dtype, op: ReductionOpInfo): """Tests reducing along noncontiguous innermost dimension.""" t = make_tensor((10, 10), dtype=dtype, device=device, low=-1, high=1) self._test_noncontiguous(op, t[:, ::2], dim=1) + @skipIfMPS @ops(reduction_ops) def test_noncontiguous_outermost(self, device, dtype, op: ReductionOpInfo): """Tests reducing along noncontiguous outermost dimension.""" t = make_tensor((10, 10), dtype=dtype, device=device, low=-1, high=1) self._test_noncontiguous(op, t[::2, :], dim=0) + @skipIfMPS @ops(reduction_ops) def test_noncontiguous_all(self, device, dtype, op: ReductionOpInfo): """Tests reducing all dimensions of a noncontiguous tensor.""" t = make_tensor((5, 5, 5), dtype=dtype, device=device, low=-1, high=1) self._test_noncontiguous(op, t[::2, ::3, 1:-1:2]) + @skipIfMPS @ops(reduction_ops) def test_noncontiguous_transposed(self, device, dtype, op: ReductionOpInfo): """Tests reducing a transposed tensor.""" t = make_tensor((5, 5), dtype=dtype, device=device, low=-1, high=1) self._test_noncontiguous(op, t.T) + @skipIfMPS @ops(reduction_ops) def test_noncontiguous_expanded(self, device, dtype, op: ReductionOpInfo): """Tests reducing a tensor with expanded singleton dimensions.""" @@ -378,12 +404,14 @@ def _test_ref(self, op: ReductionOpInfo, t: torch.Tensor, **reduction_kwargs): expected = op.ref(t.detach().cpu().numpy(), *args, **kwargs) self.assertEqual(result, expected, exact_dtype=False) + @skipIfMPS @ops(filter(lambda op: op.ref is not None, reduction_ops), allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool)) def test_ref_scalar_input(self, device, dtype, op: ReductionOpInfo): """Compares op against reference for scalar input tensors""" self._test_ref(op, make_tensor([], dtype=dtype, device=device)) + @skipIfMPS @ops(filter(lambda op: op.ref is not None, reduction_ops), allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool)) def test_ref_small_input(self, device, dtype, op: ReductionOpInfo): @@ -413,6 +441,7 @@ def test_ref_large_input_64bit_indexing(self, device, dtype, op: ReductionOpInfo """Compares op against reference for a very large input tensor that requires 64 bit indexing""" self._test_ref(op, make_tensor((275000000,), dtype=dtype, device=device, low=-1, high=1, exclude_zero=True)) + @skipIfMPS @ops(filter(lambda op: op.ref is not None, reduction_ops), allowed_dtypes=all_types_and_complex_and(torch.half, torch.bool)) def test_ref_duplicate_values(self, device, dtype, op: ReductionOpInfo): @@ -423,6 +452,7 @@ def test_ref_duplicate_values(self, device, dtype, op: ReductionOpInfo): self._test_ref(op, t, dim=0) self._test_ref(op, t, dim=1) + @skipIfMPS @ops(filter(lambda op: op.ref is not None, reduction_ops), allowed_dtypes=[torch.float32, torch.complex64]) def test_ref_extremal_values(self, device, dtype, op: ReductionOpInfo): @@ -471,6 +501,7 @@ def test_sum_dim_reduction_uint8_overflow(self, device): torch.sum(x, 0, out=y) self.assertEqual(x.sum(0, dtype=torch.uint8), y) + @skipIfMPS def test_dim_reduction_less_than_64(self, device): sizes = [1] * 65 x = torch.randn(sizes, device=device) @@ -496,6 +527,7 @@ def test_dim_reduction_lastdim(self, device, dtype): @skipIfNoSciPy @dtypes(torch.float32, torch.double, torch.complex64, torch.complex128) + @skipIfMPS def test_logsumexp(self, device, dtype): from scipy.special import logsumexp a = torch.randn(5, 4, device=device, dtype=dtype) @@ -516,6 +548,7 @@ def test_logsumexp(self, device, dtype): self.assertEqual(expected, b[:, 0]) @skipIfNoSciPy + @skipIfMPS def test_logsumexp_integral_promotion(self, device): from scipy.special import logsumexp # check integral inputs is promoted to floating point @@ -529,6 +562,7 @@ def test_logsumexp_integral_promotion(self, device): @dtypes(torch.complex64, torch.complex128) @dtypesIfXPU(torch.complex128) # Skip the torch.complex64 for XPU, see https://github.com/intel/torch-xpu-ops/issues/2279 for details + @skipIfMPS def test_logcumsumexp_complex(self, device, dtype): # logcumsumexp is a more precise way to compute than ``log(cumsum(exp(a)))`` # and faster than ``[log(sum(exp(a[:i]))) for i in range(a.shape[0])]`` @@ -972,6 +1006,7 @@ def test_cumprod_integer_upcast(self, device): self._test_reduce_integer_upcast(lambda x, **kwargs: torch.cumprod(x, 0, **kwargs)) @dtypes(*all_types()) + @skipIfMPS def test_mode(self, device, dtype): SIZE = 10 x = torch.arange(1., SIZE * SIZE + 1, device=device, dtype=dtype).clone().resize_(SIZE, SIZE) @@ -1041,6 +1076,7 @@ def testset_for_shape(shape, i): # Naive kernel for big slice sizes (> 2048) testset_for_shape((10, 4096), 10) + @skipIfMPS def test_mode_boolean(self, device): shapes = [ (10, 10), @@ -1068,6 +1104,7 @@ def test_mode_boolean(self, device): @expectedFailureMeta # mode only supports CPU and CUDA device type @onlyNativeDeviceTypes + @skipIfMPS def test_mode_wrong_dtype(self, device): def test_for_dtypes(x_ty, v_ty, i_ty, message): x = torch.ones(10, device=device, dtype=x_ty) @@ -1154,6 +1191,7 @@ def test_all_issue117215(self, device): @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double) @dtypesIfXPU(torch.half, torch.bfloat16, torch.float, torch.double) @dtypes(torch.half, torch.bfloat16, torch.float, torch.double) + @skipIfMPS def test_max_with_inf(self, device, dtype): a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device) self.assertTrue(torch.all(torch.max(a, dim=1).values == inf).item()) @@ -1164,6 +1202,7 @@ def test_max_with_inf(self, device, dtype): @dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double) @dtypesIfXPU(torch.half, torch.bfloat16, torch.float, torch.double) @dtypes(torch.half, torch.float, torch.bfloat16, torch.double) + @skipIfMPS def test_min_with_inf(self, device, dtype): a = torch.tensor([[-inf, -inf, inf, 3], [inf, inf, -inf, -1]], dtype=dtype, device=device) self.assertTrue(torch.all(torch.min(a, dim=1).values == (-inf)).item()) @@ -1224,6 +1263,7 @@ def get_values(x): @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool) @dtypesIfXPU(torch.half, torch.float, torch.long, torch.bool) @dtypes(torch.half, torch.float, torch.double) + @skipIfMPS def test_max(self, device, dtype): self._test_minmax_helper(torch.max, np.amax, device, dtype) @@ -1231,6 +1271,7 @@ def test_max(self, device, dtype): @dtypesIfCUDA(torch.half, torch.float, torch.long, torch.bool) @dtypesIfXPU(torch.half, torch.float, torch.long, torch.bool) @dtypes(torch.half, torch.float, torch.double) + @skipIfMPS def test_min(self, device, dtype): self._test_minmax_helper(torch.min, np.amin, device, dtype) @@ -1238,6 +1279,7 @@ def test_min(self, device, dtype): @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool) @dtypesIfXPU(torch.half, torch.float, torch.int, torch.long, torch.bool) @dtypes(torch.half, torch.float, torch.double) + @skipIfMPS def test_amin(self, device, dtype): self._test_minmax_helper(torch.amin, np.amin, device, dtype) @@ -1245,6 +1287,7 @@ def test_amin(self, device, dtype): @dtypesIfCUDA(torch.half, torch.float, torch.int, torch.long, torch.bool) @dtypesIfXPU(torch.half, torch.float, torch.int, torch.long, torch.bool) @dtypes(torch.float, torch.double) + @skipIfMPS def test_amax(self, device, dtype): self._test_minmax_helper(torch.amax, np.amax, device, dtype) @@ -1252,6 +1295,7 @@ def test_amax(self, device, dtype): @dtypes(torch.float, torch.double, torch.bfloat16, torch.half) @dtypesIfCUDA(torch.half, torch.float, torch.bfloat16) @dtypesIfXPU(torch.half, torch.float, torch.bfloat16) + @skipIfMPS def test_aminmax(self, device, dtype): def _amin_wrapper(x, dim=None, keepdims=False): @@ -1265,12 +1309,14 @@ def _amax_wrapper(x, dim=None, keepdims=False): @onlyNativeDeviceTypes @dtypes(*complex_types()) + @skipIfMPS def test_invalid_0dim_aminmax(self, device, dtype): with self.assertRaisesRegex(RuntimeError, 'not implemented'): torch.aminmax(torch.tensor(1., dtype=dtype, device=device), dim=0) # TODO: bincount isn't a classic reduction -- maybe this test suite is # reductions and summary ops? + @skipIfMPS def test_bincount(self, device): # negative input throws with self.assertRaisesRegex(RuntimeError, '1-d non-negative integral'): @@ -1575,6 +1621,7 @@ def test_min_mixed_devices(self, device): lambda: torch.amin(a, 0, out=values)) # TODO: consider refactoring with bincount test + @skipIfMPS # AssertionError: UserWarning not triggered by def test_bucketization(self, device): values_1d = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9], device=device) values_3d = torch.tensor([[[1, 3, 5], [2, 4, 6]], [[1, 2, 3], [4, 5, 6]]], device=device) @@ -1605,8 +1652,8 @@ def test_bucketization(self, device): self.assertEqual(torch.bucketize(values_0_el, boundaries), expected_result) # nan input - values_nan = torch.tensor([1.0, float('nan'), 2.0, float('nan')], device=device, dtype=torch.float64) - boundaries = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device, dtype=torch.float64) + values_nan = torch.tensor([1.0, float('nan'), 2.0, float('nan')], device=device, dtype=highest_precision_float(device)) + boundaries = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device, dtype=highest_precision_float(device)) expected_result = torch.tensor([1, 4, 2, 4], device=device) self.assertEqual(torch.searchsorted(boundaries, values_nan), expected_result) expected_result = torch.tensor([2, 4, 3, 4], device=device) @@ -1615,7 +1662,7 @@ def test_bucketization(self, device): # type promotion and non contiguous tensors values_3d_permute = values_3d.permute(2, 1, 0).to(torch.int32) - boundaries_permute = values_3d.permute(2, 1, 0).to(torch.float64) + boundaries_permute = values_3d.permute(2, 1, 0).to(highest_precision_float(device)) expected_result = torch.tensor([[[0, 0], [0, 1]], [[2, 0], [0, 1]], [[2, 0], [0, 0]]], device=device) if self.device_type != 'xla': self.assertWarnsRegex( @@ -1699,6 +1746,7 @@ def test_dtype_bfloat16(values_bf16=False, boundaries_bf16=False): test_dtype_bfloat16(True, True) @dtypes(*all_types_and(torch.half, torch.bfloat16)) + @skipIfMPS def test_nansum(self, device, dtype): args = product( (True, False), # noncontiguous @@ -1752,6 +1800,7 @@ def _test_reduction_function_with_numpy(self, torch_func, np_func, device, dtype atol=atol, rtol=rtol, exact_dtype=exact_dtype) @dtypes(*all_types_and_complex_and(torch.half)) + @skipIfMPS def test_count_nonzero(self, device, dtype): self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype) self._test_reduction_function_with_numpy(torch.count_nonzero, np.count_nonzero, device, dtype, True) @@ -1794,6 +1843,7 @@ def is_integral(dtype): @onlyNativeDeviceTypes @dtypes(*set(all_types_and(torch.half)) - {torch.uint8}) + @skipIfMPS def test_sum_vs_numpy(self, device, dtype): self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype) self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_extremal=True) @@ -1801,6 +1851,7 @@ def test_sum_vs_numpy(self, device, dtype): @onlyNativeDeviceTypes @dtypes(*set(all_types_and(torch.half)) - {torch.uint8}) + @skipIfMPS def test_nansum_vs_numpy(self, device, dtype): self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype) self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_extremal=True) @@ -1814,6 +1865,7 @@ def test_nansum_complex(self, device, dtype): torch.nansum(x) @dtypes(*all_types_and(torch.half)) + @skipIfMPS def test_nansum_out_dtype(self, device, dtype): out_dtype = dtype inp_dtypes = all_types_and(torch.half) if out_dtype.is_floating_point else integral_types() @@ -1830,6 +1882,7 @@ def test_nansum_out_dtype(self, device, dtype): @dtypes(*all_types_and(torch.half)) @dtypesIfXPU(torch.half, torch.int8, torch.uint8, torch.float32) # Acc issue for other types on xpu, see https://github.com/intel/torch-xpu-ops/issues/2295 + @skipIfMPS def test_argminmax_multiple(self, device, dtype): # Case: All Ones t = torch.ones(3, 3, device=device, dtype=dtype) @@ -1915,6 +1968,7 @@ def verify_against_numpy(t): verify_against_numpy(t) @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) + @skipIfMPS def test_all_any_vs_numpy(self, device, dtype): # Note [all, any uint8 compatibility]: However for compatibility reason, # for `uint8`, they return Tensor of same dtype `uint8`. @@ -2027,6 +2081,7 @@ def _test_output_dtype(x): # TODO: part of this test covers torch.norm, with should be covered by test_linalg @onlyNativeDeviceTypes + @skipIfMPS def test_repeated_dim(self, device): ops = [torch.mean, torch.sum, torch.nansum, torch.std, torch.logsumexp, torch.std, torch.var, torch.norm] @@ -2112,6 +2167,7 @@ def test_sum_cpu_device_mismatch(self, device): # Assert for illegal dtype would not be raised on XLA @onlyNativeDeviceTypes + @skipIfMPS def test_minmax_illegal_dtype(self, device): x = torch.randn(5, 5, dtype=torch.float32, device=device) valid_values = torch.empty(5, dtype=torch.float32, device=device) @@ -2137,6 +2193,7 @@ def test_minmax_illegal_dtype(self, device): torch.min(x, dim=0, out=(illegal_values, illegal_indices)) @dtypes(*all_types_and(torch.half, torch.bfloat16)) + @skipIfMPS def test_dim_arg_reduction_scalar(self, device, dtype): example = 4.0 @@ -2155,6 +2212,7 @@ def test_dim_arg_reduction_scalar(self, device, dtype): @precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2}) @dtypes(*set(all_types_and(torch.half, torch.bfloat16)) - {torch.uint8}) + @skipIfMPS def test_dim_reduction(self, device, dtype): example = [[-1, 2, 1], [5, 3, 6]] @@ -2299,6 +2357,7 @@ def test_nanmean_integral_types(self, device, dtype): ): torch.nanmean(t) + @skipIfMPS @precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2}) @dtypes(*set(all_types_and(torch.half, torch.bfloat16)) - {torch.uint8}) @parametrize("fn_name", [ @@ -2523,6 +2582,7 @@ def test_argminmax_axis_with_dim_one(self, device): @dtypes(torch.int, torch.long, torch.float, torch.double) @dtypesIfCUDA(torch.int, torch.long, torch.half, torch.float, torch.double) @dtypesIfXPU(torch.int, torch.long, torch.half, torch.float, torch.double) + @skipIfMPS def test_median_real_values(self, device, dtype): # Generate random 0-3D sizes sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)] @@ -2553,6 +2613,7 @@ def test_median_real_values(self, device, dtype): @dtypes(torch.float, torch.double) @dtypesIfCUDA(torch.half, torch.float, torch.double) @dtypesIfXPU(torch.half, torch.float, torch.double) + @skipIfMPS def test_median_nan_values(self, device, dtype): # Generate random 0-3D sizes sizes = [random.sample(range(1, 32), i) for i in range(4) for _ in range(2)] @@ -2591,6 +2652,7 @@ def test_median_nan_values(self, device, dtype): ref = numpy_op(t_numpy, dim, keepdims=True)[mask.cpu().numpy()] self.assertEqual(res, torch.from_numpy(ref)) + @skipIfMPS def test_median_corner_cases(self, device): def check(op, a, args, key): t = torch.tensor(a, device=device) @@ -2642,6 +2704,7 @@ def check(op, a, args, key): @skipIfTorchDynamo("https://github.com/pytorch/pytorch/pull/138657 discovers a latent bug") @onlyNativeDeviceTypes @dtypes(torch.float, torch.double) + @skipIfMPS def test_quantile(self, device, dtype): # Generate some random test cases ops = ['quantile', 'nanquantile'] @@ -2683,6 +2746,7 @@ def test_quantile(self, device, dtype): torch_op(a, q, dim=dim, keepdim=keepdim, interpolation=interpolation, out=out) self.assertEqual(out.cpu(), result.cpu()) + @skipIfMPS # Fails only on macos-m1-stable def test_quantile_backward(self, device): def check(a, q, dim, expected_grad, ops=(torch.quantile, torch.nanquantile)): for op in ops: @@ -2823,6 +2887,7 @@ def _compare_std_var_with_numpy(self, op, device, dtype, input, dim, @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) # Driver issue for float64 on XPU, see https://github.com/intel/torch-xpu-ops/issues/2295 @dtypesIfXPU(torch.float, torch.cfloat, torch.cdouble) + @skipIfMPS def test_var_vs_numpy(self, device, dtype): _size = (20, 20) @@ -2836,6 +2901,7 @@ def test_var_vs_numpy(self, device, dtype): @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) # Driver issue for float64 on XPU, see https://github.com/intel/torch-xpu-ops/issues/2295 @dtypesIfXPU(torch.float, torch.cfloat, torch.cdouble) + @skipIfMPS def test_std_vs_numpy(self, device, dtype): _size = (20, 20) @@ -2849,6 +2915,7 @@ def test_std_vs_numpy(self, device, dtype): @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) # Driver issue for float64 on XPU, see https://github.com/intel/torch-xpu-ops/issues/2295 @dtypesIfXPU(torch.float, torch.cfloat, torch.cdouble) + @skipIfMPS def test_var_correction_vs_numpy(self, device, dtype): _size = (20, 20) test_args = [ @@ -2885,6 +2952,7 @@ def test_var_correction_vs_numpy(self, device, dtype): @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) # Driver issue for float64 on XPU, see https://github.com/intel/torch-xpu-ops/issues/2295 @dtypesIfXPU(torch.float, torch.cfloat, torch.cdouble) + @skipIfMPS def test_std_correction_vs_numpy(self, device, dtype): _size = (20, 20) test_args = [ @@ -2921,6 +2989,7 @@ def test_std_correction_vs_numpy(self, device, dtype): @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) # Driver issue on XPU, see https://github.com/intel/torch-xpu-ops/issues/2295 @dtypesIfXPU(torch.float, torch.cfloat) + @skipIfMPS def test_std_mean_correction(self, device, dtype): _size = (20, 20) test_args = [ @@ -2954,6 +3023,7 @@ def test_std_mean_correction(self, device, dtype): @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) # Driver issue on XPU, see https://github.com/intel/torch-xpu-ops/issues/2295 @dtypesIfXPU(torch.float, torch.cfloat) + @skipIfMPS def test_var_mean_correction(self, device, dtype): _size = (20, 20) test_args = [ @@ -2985,6 +3055,7 @@ def test_var_mean_correction(self, device, dtype): self.assertEqual(mean1, mean2) @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) + @skipIfMPS def test_warn_invalid_degrees_of_freedom(self, device, dtype): def _assert_warning(_func, _tensor, _correction): with warnings.catch_warnings(record=True) as w: @@ -3017,6 +3088,7 @@ def test_amin_amax_some_dims(self, device): self.assertEqual(amin1, amin2) self.assertEqual(amax1, amax2) + @skipIfMPS def test_histc(self, device): # negative nbins throws with self.assertRaisesRegex(RuntimeError, 'bins must be > 0'): @@ -3155,11 +3227,13 @@ def test_histc_lowp(self, device, dtype): self.assertEqual(actual.dtype, dtype) @dtypes(torch.uint8, torch.int8, torch.int, torch.long, torch.float, torch.double) + @skipIfMPS def test_histc_min_max_errors(self, device, dtype): with self.assertRaisesRegex(RuntimeError, "max must be larger than min"): torch.histc(torch.tensor([1., 2., 3.], dtype=dtype, device=device), bins=4, min=5, max=1) @dtypes(torch.float, torch.double) + @skipIfMPS def test_histc_min_max_corner_cases(self, device, dtype): actual = torch.histc( torch.tensor([1., 2, 1], dtype=dtype, device=device), @@ -3579,6 +3653,7 @@ def test_tensor_compare_ops_empty(self, device): # test_tensot_compare_ops_empty because not specifying a `dim` parameter in the former tests does # not throw errors. Also, checking the return type of argmax requires supplying a different dtype # argument than that for the input tensor. There is also variation in numpy testing. + @skipIfMPS def test_tensor_compare_ops_argmax_argmix_kthvalue_dim_empty(self, device): shape = (2, 0, 4) master_input = torch.randn(shape, device=device) @@ -3683,6 +3758,7 @@ def test_tensor_reduce_ops_empty(self, device): # Tests to ensure that any() and all() functions work with zero-dim tensors. Kept separate from # other tests for checking reduction with zero-dim tensors because these tests have significantly # different testing behaviour than that used for the former tests. + @skipIfMPS def test_reduction_empty_any_all(self, device): shape = (2, 0, 4) x = torch.randn(shape, device=device) @@ -3711,6 +3787,7 @@ def test_reduction_empty_any_all(self, device): self.assertEqual(torch.ones((), device=device, dtype=out_dtype), xb.all()) # TODO: can these be merged with their respective OpInfos? + @skipIfMPS def test_reduce_dtype(self, device): def test_reduction(op, has_no_dim, takes_dtype=True): x = torch.randn(3, 3, dtype=torch.float, requires_grad=True, device=device) @@ -3736,6 +3813,7 @@ def test_reduction(op, has_no_dim, takes_dtype=True): test_reduction(torch.cumprod, False) test_reduction(torch.logcumsumexp, False, takes_dtype=False) + @skipIfMPS @ops(reference_masked_ops) def test_reference_masked(self, device, dtype, op): """Test masked reduction operations on strided-only tensors using @@ -3829,7 +3907,7 @@ def foo_compile(): self.assertEqual(result_eager.shape, result_compiled.shape) self.assertEqual(result_eager.shape, torch.Size([2, 2])) -instantiate_device_type_tests(TestReductions, globals(), allow_xpu=True) +instantiate_device_type_tests(TestReductions, globals(), allow_xpu=True, allow_mps=True) if __name__ == '__main__': run_tests() diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index 47048f8a4091a..5aafa1399ff12 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -46,11 +46,14 @@ from torch.testing._internal.common_utils import ( IS_WINDOWS, + MI350_ARCH, parametrize, run_tests, + runOnRocmArch, skipIfRocm, TEST_CUDA, TestCase, + skipIfXpu, ) from torch.testing._internal.common_quantized import ( _bfloat16_to_float4_e2m1fn_x2, @@ -69,7 +72,7 @@ if TEST_CUDA: _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8 -f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+ and XPU devices" +f8_msg = "FP8 is only supported on H100+, SM 8.9 and MI300+, XPU and CPU devices" f8_grouped_msg = "FP8 grouped is only supported on SM90 and MI300/MI350 devices" mx_skip_msg = "MX gemm is only supported on CUDA capability 10.0+" mxfp8_grouped_mm_skip_msg = "MXFP8 grouped GEMM is only supported when PyTorch is built with USE_MSLK=1 on SM100+" @@ -653,23 +656,49 @@ def _test_tautological_mm(self, device: str = "cuda", # Skip on XPU due to known oneDNN accuracy issue (#169772) @skipXPU - def test_float8_basics(self, device) -> None: + @parametrize( + "test_case", + [ + # test_case tuple schema: + # (case_name, x_dtype, y_dtype, out_dtype, size) + # "default" in case_name means out_dtype=None (backend default output dtype path). + ("e4m3_e4m3_default", e4m3_type, e4m3_type, None, 16), + ("e5m2_e5m2_default", e5m2_type, e5m2_type, None, 16), + ("e5m2_e5m2_f32", e5m2_type, e5m2_type, torch.float32, 16), + ("e4m3_e5m2_default", e4m3_type, e5m2_type, None, 32), + ("e5m2_e4m3_default", e5m2_type, e4m3_type, None, 48), + ("e4m3_e4m3_f16", e4m3_type, e4m3_type, torch.float16, 64), + ("e4m3_e4m3_f32", e4m3_type, e4m3_type, torch.float32, 96), + ("e4m3_e4m3_bf16", e4m3_type, e4m3_type, torch.bfloat16, 80), + ], + name_fn=lambda test_case: f"cuda_{test_case[0]}", + ) + def test_float8_basics(self, device, test_case) -> None: if not _device_supports_scaled_mm_fp8(device): raise unittest.SkipTest(f8_msg) - self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16) + _, x_dtype, y_dtype, out_dtype, size = test_case + expect_e5m2_cuda_error = x_dtype == e5m2_type and y_dtype == e5m2_type # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported # supported on ROCm but fails on CUDA - ctx = self.assertRaises(ValueError) if torch.version.hip is None and "cuda" in device else contextlib.nullcontext() + ctx = ( + self.assertRaises(ValueError) + if expect_e5m2_cuda_error and torch.version.hip is None and "cuda" in device + else contextlib.nullcontext() + ) with ctx: - self._test_tautological_mm(device, e5m2_type, e5m2_type) - - self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32) - self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48) - - self._test_tautological_mm(device, size=64, out_dtype=torch.float16) - self._test_tautological_mm(device, size=96, out_dtype=torch.float32) - self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16) + self._test_tautological_mm( + device, + x_dtype=x_dtype, + y_dtype=y_dtype, + out_dtype=out_dtype, + size=size, + ) + # Skip on XPU due to known oneDNN accuracy issue (#169772) + @skipXPU + def test_float8_basics_layout_permutations(self, device) -> None: + if not _device_supports_scaled_mm_fp8(device): + raise unittest.SkipTest(f8_msg) if torch.cuda.is_available(): for (x_cm, y_cm) in itertools.product([True, False], repeat=2): # SM 10 and 11 support all permutations, SM 12 TT and TN, SM 9 only TN @@ -683,6 +712,11 @@ def test_float8_basics(self, device) -> None: with contextlib.nullcontext() if layouts_supported else self.assertRaises(RuntimeError): self._test_tautological_mm(device, size=64, out_dtype=torch.bfloat16, x_cm=x_cm, y_cm=y_cm) + # Skip on XPU due to known oneDNN accuracy issue (#169772) + @skipXPU + def test_float8_basics_invalid_out_dtype(self, device) -> None: + if not _device_supports_scaled_mm_fp8(device): + raise unittest.SkipTest(f8_msg) with self.assertRaises( AssertionError if (torch.version.hip or "xpu" in device or "cpu" in device) else RuntimeError @@ -922,15 +956,15 @@ def _2d_to_blocked_scaled(X, K, G, offs, format): # Assert outputs are close. torch.testing.assert_close(y_lp, y_bf16, atol=8.0e-2, rtol=8.0e-2) - @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32]) @parametrize("x_cm", [True, False]) @parametrize("y_cm", [True, False]) - def test_scaled_mm_vs_emulated(self, base_dtype, x_cm, y_cm, device="cuda"): + def test_scaled_mm_vs_emulated(self, base_dtype, x_cm, y_cm, device): # Blackwell (SM_10) supports all possible layout permutations, while Hopper only TN - if (x_cm, y_cm) != (True, False) and torch.cuda.get_device_properties(0).major != 10: - raise unittest.SkipTest("Unsupported layout on the architecture") + if torch.cuda.is_available(): + if (x_cm, y_cm) != (True, False) and torch.cuda.get_device_properties(0).major != 10: + raise unittest.SkipTest("Unsupported layout on the architecture") torch.manual_seed(42) input_dtype = e4m3_type output_dtype = base_dtype @@ -1031,7 +1065,7 @@ def test_scaled_mm_change_stride(self, base_dtype, device="cuda"): torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) @skipCUDAIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_float8_bias(self, device) -> None: (k, l, m) = (16, 48, 32) @@ -1048,7 +1082,7 @@ def test_float8_bias(self, device) -> None: difference = torch.abs(out_fp32 - outb_fp32) self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32)) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("bias", [True, False]) def test_non_divisible_leading_dim(self, device, bias: bool) -> None: @@ -1061,7 +1095,7 @@ def test_non_divisible_leading_dim(self, device, bias: bool) -> None: input_bias = torch.rand((16,), device=device).to(torch.bfloat16) _ = scaled_mm_wrap(x, y, scale_a, scale_b, bias=input_bias) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_float8_bias_relu_edgecase(self, device) -> None: (k, l, m) = (16, 48, 32) @@ -1074,7 +1108,7 @@ def test_float8_bias_relu_edgecase(self, device) -> None: outb_fp32 = outb_fp8.to(torch.float32) self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32)) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_float32_output_errors_with_bias(self, device) -> None: (k, l, m) = (16, 48, 32) @@ -1083,8 +1117,8 @@ def test_float32_output_errors_with_bias(self, device) -> None: scale_a = torch.tensor(1.0, device=device) scale_b = torch.tensor(1.0, device=device) bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16) - # XPU supports the case when out_dtype is fp32 + bias. So we just test it with normal run. - if "xpu" not in device: + # XPU and CPU supports the case when out_dtype is fp32 + bias. So we just test it with normal run. + if "xpu" not in device and "cpu" not in device: self.assertRaisesRegex( ValueError if torch.cuda.is_available() else RuntimeError, "Bias is not supported when out_dtype is set to Float32", @@ -1105,6 +1139,7 @@ def test_error_message_fp8_pre_sm89(self, device) -> None: lambda: scaled_mm_wrap(x, y, scale_a, scale_b, out_dtype=torch.float32), ) + @skipIfXpu(msg="AssertionError, torch-xpu-ops: 2862") @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @unittest.skipIf(SM100OrLater, "fast_accum is SM90-only") def test_float8_scale_fast_accum(self, device) -> None: @@ -1120,7 +1155,7 @@ def test_float8_scale_fast_accum(self, device) -> None: out_fp8_s = scaled_mm_wrap(x, y, scale_a=scale_a, scale_b=scale_b, out_dtype=e4m3_type, use_fast_accum=True) self.assertEqual(out_fp8, out_fp8_s) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) @skipCUDAIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific") @parametrize("use_fast_accum", [True, False]) @@ -1151,7 +1186,7 @@ def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> No out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device) ) - @onlyOn(["cuda", "xpu"]) + @onlyOn(["cuda", "xpu", "cpu"]) @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) def test_float8_error_messages(self, device) -> None: M, K, N = (1024, 512, 2048) @@ -1227,7 +1262,7 @@ def e5m2(): (torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0) and torch.version.cuda and - torch.version.cuda >= "12.9")): + torch.version.cuda >= "12.9") or (not torch.cuda.is_available() and torch.cpu.is_available())): out = e5m2() self.assertEqual(out, torch.ones_like(out) * 128.) else: @@ -1626,20 +1661,21 @@ def test_scaled_mm_vs_emulated_block_wise_verify_small_shapes( output_dtype ) - @skipIfRocm @onlyCUDA @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg) @unittest.skipIf(IS_SM90, "DeepSeek style (1x128, 128x128) blockwise scaling works on SM90 (Hopper)") @unittest.skipIf( - _get_torch_cuda_version() < (12, 9), + not torch.version.hip and _get_torch_cuda_version() < (12, 9), "cuBLAS blockwise scaling added in CUDA 12.9", ) + @runOnRocmArch(MI350_ARCH) @parametrize("output_dtype", [torch.bfloat16, ]) @parametrize("lhs_block,rhs_block", [(1, 1), (128, 1), (1, 128)]) @parametrize("M,N,K", [(256, 256, 256), (256, 256, 512)]) def test_scaled_mm_deepseek_error_messages( self, output_dtype, lhs_block, rhs_block, M, N, K ): + torch.manual_seed(42) x = torch.randn(M, K, device="cuda", dtype=output_dtype).pow(3) @@ -1661,10 +1697,18 @@ def test_scaled_mm_deepseek_error_messages( else: rhs_recipe = ScalingType.BlockWise128x128 - # Verify that actual F8 mm raises expected error on non-SM90 + # Verify that actual F8 mm raises expected error + if torch.version.hip: + # ROCm does not yet support DeepSeek-style blockwise scaling + expected_error = NotImplementedError + expected_pattern = "1x128 and 128x128 scaling not available with ROCm" + else: + # CUDA non-SM90 should raise NotImplementedError + expected_error = NotImplementedError + expected_pattern = ".*DeepSeek.*scaling.*only supported in CUDA for SM90.*" with self.assertRaisesRegex( - NotImplementedError, - ".*DeepSeek.*scaling.*only supported in CUDA for SM90.*" + expected_error, + expected_pattern ): scaled_mm_wrap( x_fp8, @@ -2128,7 +2172,9 @@ def test_passed_swizzle_arrays(self, device) -> None: # No swizzle passed - must fail on swizzle_a with self.assertRaisesRegex( ValueError, - "swizzle_a must have 1 values, got 0", + "swizzle_a and swizzle_b must each have 1 value" + if torch.version.hip + else "swizzle_a must have 1 value, got 0", ): _ = torch.nn.functional.scaled_mm( x, @@ -2142,7 +2188,9 @@ def test_passed_swizzle_arrays(self, device) -> None: # swizzle_a passed, not b, must fail on swizzle_b with self.assertRaisesRegex( ValueError, - "swizzle_b must have 1 values, got 0", + "swizzle_a and swizzle_b must each have 1 value" + if torch.version.hip + else "swizzle_b must have 1 value, got 0", ): _ = torch.nn.functional.scaled_mm( x, @@ -2153,6 +2201,21 @@ def test_passed_swizzle_arrays(self, device) -> None: ScalingType.BlockWise1x32, swizzle_a=SwizzleType.SWIZZLE_32_4_4, ) + if torch.version.hip: + with self.assertRaisesRegex( + ValueError, + "swizzle_a and swizzle_b must both be NO_SWIZZLE", + ): + _ = torch.nn.functional.scaled_mm( + x, + w.t(), + x_scale, + ScalingType.BlockWise1x32, + w_scale, + ScalingType.BlockWise1x32, + swizzle_a=SwizzleType.NO_SWIZZLE, + swizzle_b=SwizzleType.SWIZZLE_32_4_4, + ) # NVFP4 two-level: swizzle=[SWIZZLE_32_4_4, NO_SWIZZLE] x = _bfloat16_to_float4_e2m1fn_x2(x.to(torch.bfloat16)) @@ -2165,8 +2228,10 @@ def test_passed_swizzle_arrays(self, device) -> None: # No swizzles passed - must fail on swizzle_a with self.assertRaisesRegex( - ValueError, - "swizzle_a must have 2 values, got 0", + NotImplementedError if torch.version.hip else ValueError, + "NVFP4 scaling not supported on ROCM" + if torch.version.hip + else "swizzle_a must have 2 values, got 0", ): _ = torch.nn.functional.scaled_mm( x, @@ -2179,8 +2244,10 @@ def test_passed_swizzle_arrays(self, device) -> None: # Not enough swizzles passed - must fail on swizzle_a with self.assertRaisesRegex( - ValueError, - "swizzle_a must have 2 values, got 1", + NotImplementedError if torch.version.hip else ValueError, + "NVFP4 scaling not supported on ROCM" + if torch.version.hip + else "swizzle_a must have 2 values, got 1", ): _ = torch.nn.functional.scaled_mm( x, @@ -2194,8 +2261,10 @@ def test_passed_swizzle_arrays(self, device) -> None: # Not enough swizzles passed to b - must fail on swizzle_b with self.assertRaisesRegex( - ValueError, - "swizzle_b must have 2 values, got 1", + NotImplementedError if torch.version.hip else ValueError, + "NVFP4 scaling not supported on ROCM" + if torch.version.hip + else "swizzle_b must have 2 values, got 1", ): _ = torch.nn.functional.scaled_mm( x, @@ -2516,7 +2585,7 @@ def test_blockwise_nvfp4_compile(self) -> None: torch.testing.assert_close(C, C_ref, atol=0, rtol=0) -instantiate_device_type_tests(TestFP8Matmul, globals(), except_for="cpu", allow_xpu=True) +instantiate_device_type_tests(TestFP8Matmul, globals(), allow_xpu=True) if __name__ == '__main__': TestCase._default_dtype_check_enabled = True diff --git a/test/test_serialization.py b/test/test_serialization.py index 51ff9182fa9af..7a9bcb63f297a 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -60,7 +60,7 @@ TEST_WITH_MTIA, TestCase, ) -from torch.testing._internal.two_tensor import TwoTensor # noqa: F401 +from torch.testing._internal.two_tensor import TwoTensor from torch.utils._import_utils import import_dill from pickle import UnpicklingError @@ -4673,7 +4673,7 @@ def test_get_unsafe_globals_in_checkpoint(self): f.seek(0) try: old_get_allowed_globals = torch._weights_only_unpickler._get_allowed_globals - torch._weights_only_unpickler._get_allowed_globals = lambda: dict() # noqa: PIE807 + torch._weights_only_unpickler._get_allowed_globals = lambda: dict() unsafe_all_globals = torch.serialization.get_unsafe_globals_in_checkpoint(f) self.assertEqual(set(unsafe_all_globals), expected_all_global_strs) finally: diff --git a/test/test_sparse.py b/test/test_sparse.py index 5e21192e19147..aa444db285778 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -23,12 +23,13 @@ from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, ops, dtypes, dtypesIfCUDA, dtypesIfMPS, onlyCPU, onlyCUDA, precisionOverride, deviceCountAtLeast, OpDTypes, onlyNativeDeviceTypes, skipCUDAIf, expectedFailureMPS, - expectedFailureMPSComplex, largeTensorTest) + largeTensorTest) from torch.testing._internal.common_methods_invocations import \ (op_db, reduction_ops, sparse_unary_ufuncs, sparse_masked_reduction_ops, binary_ufuncs) from torch.testing._internal.common_dtype import ( all_types, all_types_and_complex, all_mps_types, all_types_and_complex_and, floating_and_complex_types, - floating_and_complex_types_and, integral_types, floating_types_and, + floating_and_complex_types_and, highest_precision_complex, highest_precision_float, + integral_types, floating_types_and, ) from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse from torch.testing._internal.opinfo.refs import ( @@ -339,7 +340,7 @@ def _test_print(self, device, dtype, coalesced): if values.dtype == torch.double: dtypes.append(torch.float) else: - dtypes.append(torch.double if values.device != torch.device("mps:0") else torch.float32) + dtypes.append(highest_precision_float(values.device)) for dtype in dtypes: printed.append(f"########## {dtype} ##########") x = sp_tensor.detach().to(dtype) @@ -589,7 +590,7 @@ def fn(x): x.requires_grad_(True) gradcheck(fn, (x,)) - values_types = [torch.double, torch.cdouble] if device != "mps:0" else [torch.float32, torch.complex64] + values_types = [highest_precision_float(device), highest_precision_complex(device)] for value_type in values_types: i = self.index_tensor([ [0, 1, 2, 2], @@ -634,7 +635,7 @@ def fn(x): def test_to_sparse(self, device, dtype, coalesced): shape = [5, 2, 10, 4] max_nnz = 1 - dtypes = [torch.double, torch.cdouble] if device != "mps:0" else [torch.float32, torch.complex64] + dtypes = [highest_precision_float(device), highest_precision_complex(device)] for value_type in dtypes: for dim, dim_sz in enumerate(shape, 1): max_nnz *= dim_sz @@ -969,7 +970,7 @@ def test_Sparse_to_Sparse_copy_(self, device, dtype, coalesced): x1.copy_(x2) self.assertEqual(x1_dtype, x1.dtype) - x2 = x2.to(torch.float64) if device != "mps:0" else x2.to(torch.float32) + x2 = x2.to(highest_precision_float(device)) x1_dtype = x1.dtype x1.copy_(x2) self.assertEqual(x1_dtype, x1.dtype) @@ -1819,19 +1820,25 @@ def fn(S, D): @unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupported triggers assertion error") @gradcheck_semantics() def test_sparse_mul(self, device, dtype, coalesced, gradcheck): + # check_batched_grad=False: slow gradcheck's batched/vmap Jacobian + # path calls aten::view which is unsupported for sparse tensors. # https://github.com/pytorch/pytorch/issues/79914 a = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True) b = torch.tensor([[0., 1]], dtype=dtype, device=device).to_sparse().requires_grad_(True) - gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(masked_grad=gradcheck.masked), [a, b]) + gradcheck( + lambda x, y: torch.sparse.sum(x * y).to_dense(masked_grad=gradcheck.masked), + [a, b], + check_batched_grad=False, + ) def test_shape(sparse_dims, nnz, with_shape): a = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True) b = self._gen_sparse(sparse_dims, nnz, with_shape, dtype, device, coalesced)[0].requires_grad_(True) self.assertEqual((a * b).to_dense(), a.to_dense() * b.to_dense()) - gradcheck(lambda x, y: (x * y).to_dense(), [a, b]) + gradcheck(lambda x, y: (x * y).to_dense(), [a, b], check_batched_grad=False) # Issues with 0-dim indices/values - gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], masked=True) + gradcheck(lambda x, y: torch.sparse.sum(x * y).to_dense(), [a, b], masked=True, check_batched_grad=False) test_shape(2, 3, [2, 3, 4, 5]) test_shape(2, 3, [2, 2, 0]) @@ -1967,7 +1974,6 @@ def test_sparse_add_out_bfloat16(self, device, dtype, coalesced): self.assertEqual(res_fp32, res_bf16, atol=1e-2, rtol=0) @coalescedonoff - @expectedFailureMPSComplex @dtypes(torch.double, torch.cdouble) @dtypesIfMPS(torch.float32, torch.complex64) def test_norm(self, device, dtype, coalesced): diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index 7fbc4d9f451b8..42d41a39ad4ea 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -1031,6 +1031,16 @@ def _npref_block_addmm_addmv(c, a, b, alpha, beta): class TestSparseCSR(TestCase): + @onlyCPU + @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + def test_empty_plain_indices_with_stride_zero(self, device, dtype): + # Test that empty plain_indices with stride 0 works. + crow_indices = torch.tensor([0, 0], dtype=torch.int32, device=device) + col_indices = torch.as_strided(torch.empty((0,), device=device, dtype=torch.int32), (0,), (0,)) + values = torch.empty(0, dtype=dtype, device=device) + t = torch.sparse_csr_tensor(crow_indices, col_indices, values, (1, 100), dtype=dtype, device=device) + self.assertEqual(t._nnz(), 0) + def test_csr_stride(self): a = self.genSparseCSRTensor((3, 3), 3, dtype=torch.float, device=self.device_type, index_dtype=torch.int64) diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py index 22255120670be..881470944aee3 100644 --- a/test/test_sparse_semi_structured.py +++ b/test/test_sparse_semi_structured.py @@ -4,45 +4,43 @@ import random import unittest +import pytest import torch -from torch import nn +import torch._dynamo.test_case import torch.nn.functional as F - +from torch import nn from torch.sparse import ( SparseSemiStructuredTensor, SparseSemiStructuredTensorCUSPARSELT, SparseSemiStructuredTensorCUTLASS, to_sparse_semi_structured, ) - from torch.sparse._semi_structured_conversions import ( - sparse_semi_structured_from_dense_cutlass, - _sparse_semi_structured_tile, _compute_compressed_swizzled_bitmask, + _sparse_semi_structured_tile, + sparse_semi_structured_from_dense_cutlass, ) - from torch.testing import make_tensor -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, xfailIfSM89PreCUDA13 +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_FP8, + PLATFORM_SUPPORTS_FP8_SPARSE, + xfailIfSM89PreCUDA13, +) from torch.testing._internal.common_device_type import ( dtypes, instantiate_device_type_tests, ) - from torch.testing._internal.common_dtype import all_types_and_complex -import torch._dynamo.test_case from torch.testing._internal.common_utils import ( + IS_WINDOWS, parametrize, run_tests, subtest, - TestCase, TEST_WITH_ROCM, - IS_WINDOWS, + TestCase, ) - from torch.testing._internal.inductor_utils import HAS_GPU -import pytest - SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict() _IS_SM8X = False @@ -50,16 +48,28 @@ _IS_HIPSPARSELT_AVAILABLE = False if torch.cuda.is_available(): - _IS_SM8X = torch.version.cuda is not None and (torch.cuda.get_device_capability(0)[0] == 8) - _IS_SM9X = torch.version.cuda is not None and (torch.cuda.get_device_capability(0)[0] == 9) - _IS_HIPSPARSELT_AVAILABLE = torch.version.hip is not None and tuple(int(v) for v in torch.version.hip.split('.')[:2]) > (6, 4) + _IS_SM8X = torch.version.cuda is not None and ( + torch.cuda.get_device_capability(0)[0] == 8 + ) + _IS_SM9X = torch.version.cuda is not None and ( + torch.cuda.get_device_capability(0)[0] == 9 + ) + _IS_HIPSPARSELT_AVAILABLE = torch.version.hip is not None and tuple( + int(v) for v in torch.version.hip.split(".")[:2] + ) >= (7, 12) # CUTLASS kernels only work for Ampere if _IS_SM8X: - SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS + SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = ( + SparseSemiStructuredTensorCUTLASS + ) # add cuSPASRELt tests if available - if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X or _IS_HIPSPARSELT_AVAILABLE): - SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT + if torch.backends.cusparselt.is_available() and ( + _IS_SM8X or _IS_SM9X or _IS_HIPSPARSELT_AVAILABLE + ): + SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = ( + SparseSemiStructuredTensorCUSPARSELT + ) inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8) training_dtypes = dtypes(torch.float16, torch.bfloat16) @@ -76,13 +86,16 @@ }, } + def sparse24_largest_mask_2d(original): sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(original) return sparse.to_dense().bool() + def sparsify24_dense(original): return sparse24_largest_mask_2d(original) * original + def rand_sparse_semi_structured_mask( r, c, dtype=torch.float16, device="cuda", choice=None ): @@ -100,15 +113,13 @@ def rand_sparse_semi_structured_mask( .contiguous() ) + def rand_sparse_semi_structured(r, c, dtype, device, choice=None): - pattern = '2by4' if dtype != torch.float32 else '1by2' - if pattern == '1by2': + pattern = "2by4" if dtype != torch.float32 else "1by2" + if pattern == "1by2": ksparse = 2 - choices = [ - [0, 1], - [1, 0] - ] - elif pattern == '2by4': + choices = [[0, 1], [1, 0]] + elif pattern == "2by4": ksparse = 4 choices = [ [1, 1, 0, 0], @@ -116,7 +127,7 @@ def rand_sparse_semi_structured(r, c, dtype, device, choice=None): [1, 0, 0, 1], [0, 1, 1, 0], [0, 1, 0, 1], - [0, 0, 1, 1] + [0, 0, 1, 1], ] mask_entries = [choice or random.choice(choices) for i in range(r * c // ksparse)] mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device) @@ -127,16 +138,16 @@ def rand_sparse_semi_structured(r, c, dtype, device, choice=None): def rand_sparse_semi_structured_all_patterns(r, c, dtype, device): - pattern = '2by4' if dtype != torch.float32 else '1by2' - if pattern == '1by2': + pattern = "2by4" if dtype != torch.float32 else "1by2" + if pattern == "1by2": ksparse = 2 choices = [ [[0, 0], [0, 1]], [[0, 1], [0, 1]], [[1, 0], [1, 0]], - [[1, 1], [1, 0]] + [[1, 1], [1, 0]], ] - elif pattern == '2by4': + elif pattern == "2by4": ksparse = 4 choices = [ [[0, 0, 0, 0], [0, 0, 1, 1]], @@ -164,7 +175,7 @@ def rand_sparse_semi_structured_all_patterns(r, c, dtype, device): mask_inv = torch.tensor(mask_entries_inv, dtype=torch.bool).view(r, c).to(device) mask_val = torch.tensor(mask_entries_val, dtype=torch.bool).view(r, c).to(device) dense = make_tensor(r, c, dtype=dtype, device=device) - dense[dense == 0] = 1 # To prevent zeros except where mask below applied. + dense[dense == 0] = 1 # To prevent zeros except where mask below applied. dense_inv = dense.masked_fill(~mask_inv, 0) dense_val = dense_inv.masked_fill(~mask_val, 0) @@ -172,10 +183,9 @@ def rand_sparse_semi_structured_all_patterns(r, c, dtype, device): class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase): - def setUp(self): if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0: - self.skipTest('semi-structured sparsity has no available backend!') + self.skipTest("semi-structured sparsity has no available backend!") super().setUp() def tearDown(self): @@ -209,32 +219,45 @@ def forward(self, x): mod_linear.weight = nn.Parameter(mod_linear.weight * mask) dense_result = model(input) - mod_linear.weight = nn.Parameter(SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].from_dense(mod_linear.weight)) + mod_linear.weight = nn.Parameter( + SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].from_dense(mod_linear.weight) + ) sparse_result = model(input) model = torch.compile(model, backend="inductor", fullgraph=True) sparse_compile_result = model(input) # test that sparse_compile_result and dense_result are numerically close - torch.testing.assert_close(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3) + torch.testing.assert_close( + dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3 + ) # assert sparse and sparse_compile have the same strides, # as meta registrations may return contiguous tensors when the output is transposed # https://github.com/pytorch/pytorch/pull/114477 if sparse_result.stride() != sparse_compile_result.stride(): - raise AssertionError(f"stride mismatch: {sparse_result.stride()} != {sparse_compile_result.stride()}") + raise AssertionError( + f"stride mismatch: {sparse_result.stride()} != {sparse_compile_result.stride()}" + ) @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") - @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine") + @unittest.skipIf( + "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, + "cusparselt not supported on this machine", + ) @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_mlp_contiguous_relu_compile_cusparselt(self): """ test for cuSPASRELt meta registrations (_cslt_sparse_mm) + torch.compile """ for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]: - SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cusparselt", dense_input_shape) - + SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile( + "cusparselt", dense_input_shape + ) - @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine") + @unittest.skipIf( + "cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, + "cutlass not supported on this machine", + ) @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_mlp_contiguous_relu_compile_cutlass(self): @@ -242,18 +265,24 @@ def test_mlp_contiguous_relu_compile_cutlass(self): test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile """ for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]: - SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape) - + SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile( + "cutlass", dense_input_shape + ) @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") - @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine") + @unittest.skipIf( + "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, + "cusparselt not supported on this machine", + ) @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") @unittest.skipIf( "RelWithAssert" in torch.__config__.show(), "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context", ) def test_sp24_compile(self) -> None: - x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True) + x = torch.randn( + [1024, 512], device="cuda", dtype=torch.float16, requires_grad=True + ) def fn(x): y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x) @@ -267,18 +296,18 @@ def fn(x): output = torch.compile(fn)(x) output.backward(output) -class TestSparseSemiStructured(TestCase): +class TestSparseSemiStructured(TestCase): def setUp(self): if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0: - self.skipTest('semi-structured sparsity has no available backend!') + self.skipTest("semi-structured sparsity has no available backend!") if IS_WINDOWS: self.skipTest("torch.compile not supported on windows") @inference_dtypes @parametrize_backends def test_to_sparse_semi_structured(self, dtype, backend): - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype) A_sparse = to_sparse_semi_structured(A) @@ -292,7 +321,9 @@ def test_to_sparse_semi_structured(self, dtype, backend): if not isinstance(A, torch.Tensor): raise AssertionError(f"A should be torch.Tensor, got {type(A)}") if not isinstance(A_sparse, SparseSemiStructuredTensor): - raise AssertionError(f"A_sparse should be SparseSemiStructuredTensor, got {type(A_sparse)}") + raise AssertionError( + f"A_sparse should be SparseSemiStructuredTensor, got {type(A_sparse)}" + ) @inference_dtypes @parametrize_backends @@ -301,7 +332,7 @@ def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend): """ Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8 """ - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) A_sparse = to_sparse_semi_structured(A) @@ -310,16 +341,26 @@ def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend): # Currently we don't support int matmul on GPU, so evaluate on CPU and copy over if dtype is torch.int8: if backend == "cutlass": - with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"): + with self.assertRaisesRegex( + RuntimeError, "spgemm_cutlass_dispatch_layouts" + ): sparse_result = torch.mm(A_sparse, B) else: - with self.assertRaisesRegex(RuntimeError, - "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"): + if torch.version.hip: + self.skipTest( + "Skipping int8 sparse mm (NN, cuSPARSELt) test on ROCm" + ) + with self.assertRaisesRegex( + RuntimeError, + "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit", + ): sparse_result = torch.mm(A_sparse, B) else: dense_result = torch.mm(A, B) sparse_result = torch.mm(A_sparse, B) - torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + torch.testing.assert_close( + dense_result, sparse_result, rtol=1e-3, atol=1e-3 + ) @inference_dtypes @parametrize_backends @@ -329,7 +370,7 @@ def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend): Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16 and will throw an error for int8 + padding """ - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) A_sparse = to_sparse_semi_structured(A) @@ -340,22 +381,34 @@ def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend): # padding with int8 throws an error because transposing B yields a contiguous output # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS. if backend == "cutlass": - with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"): + with self.assertRaisesRegex( + RuntimeError, "spgemm_cutlass_dispatch_layouts" + ): sparse_result = torch.mm(A_sparse, B.t()) else: - with self.assertRaisesRegex(RuntimeError, - "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"): + if torch.version.hip: + self.skipTest( + "Skipping int8 sparse mm (NT, cusparselt, shape=(1,128)) test on ROCm" + ) + with self.assertRaisesRegex( + RuntimeError, + "CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit", + ): sparse_result = torch.mm(A_sparse, B.t()) elif dtype is torch.int8: # test transpose dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8) sparse_result = torch.mm(A_sparse, B.t()) - torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + torch.testing.assert_close( + dense_result, sparse_result, rtol=1e-3, atol=1e-3 + ) else: # test transpose dense_result = torch.mm(A, B.t()) sparse_result = torch.mm(A_sparse, B.t()) - torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) + torch.testing.assert_close( + dense_result, sparse_result, rtol=1e-3, atol=1e-3 + ) @inference_dtypes @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)]) @@ -364,7 +417,7 @@ def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend): """ Ensure torch.mm(A_sparse.t(), B) throws error """ - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" if backend == "cutlass" and IS_WINDOWS: self.skipTest("CUTLASS not supported on Windows") A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype) @@ -385,7 +438,7 @@ def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend): """ Ensure torch.mm(A, B_sparse.t()) is correct """ - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" if backend == "cutlass" and IS_WINDOWS: self.skipTest("CUTLASS not supported on Windows") B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) @@ -410,7 +463,7 @@ def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend): """ Ensure torch.mm(A, B_sparse) throws error """ - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" if backend == "cutlass" and IS_WINDOWS: self.skipTest("CUTLASS not supported on Windows") B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) @@ -431,7 +484,7 @@ def test_linear(self, dense_input_shape, inference_mode, device, backend): """ Test nn.Linear has the same numerics """ - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" if backend == "cutlass" and IS_WINDOWS: self.skipTest("CUTLASS not supported on Windows") input = torch.rand((dense_input_shape), device=device).half() @@ -456,7 +509,7 @@ def test_linear(self, dense_input_shape, inference_mode, device, backend): @parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)]) @parametrize_backends def test_mlp(self, device, dense_input_shape, backend): - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" input = torch.rand(dense_input_shape, device=device).half() model = ( nn.Sequential( @@ -486,34 +539,44 @@ def test_mlp(self, device, dense_input_shape, backend): @parametrize_backends def test_values(self, backend): - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" if backend == "cutlass" and IS_WINDOWS: self.skipTest("CUTLASS not supported on Windows") A = rand_sparse_semi_structured_mask(128, 128) A_sparse = to_sparse_semi_structured(A) if A_sparse.values().shape != (128, 64): - raise AssertionError(f"values shape should be (128, 64), got {A_sparse.values().shape}") + raise AssertionError( + f"values shape should be (128, 64), got {A_sparse.values().shape}" + ) if not (A_sparse.values() == 1).all(): raise AssertionError("values should all be 1") @parametrize_backends def test_indices(self, backend): - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" if backend == "cutlass" and IS_WINDOWS: self.skipTest("CUTLASS not supported on Windows") A = rand_sparse_semi_structured_mask(128, 128) A_sparse = to_sparse_semi_structured(A) if A_sparse.indices().shape != (128, 8): - raise AssertionError(f"indices shape should be (128, 8), got {A_sparse.indices().shape}") + raise AssertionError( + f"indices shape should be (128, 8), got {A_sparse.indices().shape}" + ) @inference_dtypes @parametrize_backends def test_min_sparse_shape(self, dtype, device, backend): - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") - config = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS[dtype] - A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device) + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" + config = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS[ + dtype + ] + A = rand_sparse_semi_structured_mask( + config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device + ) A_sparse = to_sparse_semi_structured(A) - B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype) + B = torch.rand( + (config.sparse_min_cols, config.dense_min_cols), device=device + ).to(dtype) if dtype == torch.int8: dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int8) # int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R @@ -527,7 +590,7 @@ def test_min_sparse_shape(self, dtype, device, backend): @inference_dtypes @parametrize_backends def test_unsupported_shape(self, dtype, device, backend): - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" if backend == "cutlass" and IS_WINDOWS: self.skipTest("CUTLASS not supported on Windows") A = rand_sparse_semi_structured_mask(2, 2, dtype=dtype, device=device) @@ -537,12 +600,15 @@ def test_unsupported_shape(self, dtype, device, backend): @dtypes(*all_types_and_complex()) @parametrize_backends def test_unsupported_dtype(self, dtype, device, backend): - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" if backend == "cutlass" and IS_WINDOWS: self.skipTest("CUTLASS not supported on Windows") A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device) - if dtype not in SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS: + if ( + dtype + not in SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS + ): with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"): A_sparse = to_sparse_semi_structured(A) else: @@ -550,7 +616,7 @@ def test_unsupported_dtype(self, dtype, device, backend): @parametrize_backends def test_unsupported_dim(self, device, backend): - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" if backend == "cutlass" and IS_WINDOWS: self.skipTest("CUTLASS not supported on Windows") A = torch.rand(128, 128, 128, device=device, dtype=torch.float16) @@ -577,18 +643,23 @@ def create_random_mask(shape) -> torch.Tensor: mask[line, col : col + 4] = torch.tensor(sparsity, dtype=torch.bool) return mask -class TestSparseSemiStructuredTraining(TestCase): +class TestSparseSemiStructuredTraining(TestCase): def setUp(self): if not _IS_SM8X: - self.skipTest("SparseSemiStructuredTensor training only supported on SM8x (Ampere)") + self.skipTest( + "SparseSemiStructuredTensor training only supported on SM8x (Ampere)" + ) if IS_WINDOWS: - self.skipTest('CUTLASS not supported on windows') - + self.skipTest("CUTLASS not supported on windows") @training_dtypes @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") + @unittest.skipIf( + not torch.backends.cusparselt.is_available(), + "cuSPARSELt not available", + ) @unittest.skipIf( "RelWithAssert" in torch.__config__.show(), "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context", @@ -600,40 +671,57 @@ def test_prune_dense_static_sort(self, dtype) -> None: pruned = _sparse_semi_structured_tile(dense) # CUTLASS - reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned, algorithm="largest_abs_values_greedy") + reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort( + pruned, algorithm="largest_abs_values_greedy" + ) torch.testing.assert_close(pruned, reference_cutlass.to_dense()) packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned) - packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous()) - meta_cutlass = meta_cutlass.as_strided(reference_cutlass.meta.shape, reference_cutlass.meta.stride()) - meta_t_cutlass = meta_t_cutlass.as_strided(reference_cutlass.meta_t.shape, reference_cutlass.meta_t.stride()) + packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass( + pruned.t().contiguous() + ) + meta_cutlass = meta_cutlass.as_strided( + reference_cutlass.meta.shape, reference_cutlass.meta.stride() + ) + meta_t_cutlass = meta_t_cutlass.as_strided( + reference_cutlass.meta_t.shape, reference_cutlass.meta_t.stride() + ) compressed_swizzled_bitmask = _compute_compressed_swizzled_bitmask(pruned) - compressed_swizzled_bitmask = compressed_swizzled_bitmask.as_strided(reference_cutlass.compressed_swizzled_bitmask.shape, - reference_cutlass.compressed_swizzled_bitmask.stride()) - cutlass = SparseSemiStructuredTensorCUTLASS(dense.shape, - packed_cutlass, - meta_cutlass, - packed_t_cutlass, - meta_t_cutlass, - compressed_swizzled_bitmask) + compressed_swizzled_bitmask = compressed_swizzled_bitmask.as_strided( + reference_cutlass.compressed_swizzled_bitmask.shape, + reference_cutlass.compressed_swizzled_bitmask.stride(), + ) + cutlass = SparseSemiStructuredTensorCUTLASS( + dense.shape, + packed_cutlass, + meta_cutlass, + packed_t_cutlass, + meta_t_cutlass, + compressed_swizzled_bitmask, + ) torch.testing.assert_close(reference_cutlass.to_dense(), cutlass.to_dense()) # CUSPARSELT - reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned, - algorithm="largest_abs_values_greedy") + reference_cusparselt = ( + SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort( + pruned, algorithm="largest_abs_values_greedy" + ) + ) torch.testing.assert_close(pruned, reference_cusparselt.to_dense()) packed_cusparselt = torch._cslt_compress(pruned) packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous()) - cusparselt = SparseSemiStructuredTensorCUSPARSELT(dense.shape, - packed_cusparselt, - None, - packed_t_cusparselt, - None, - compressed_swizzled_bitmask) - torch.testing.assert_close(reference_cusparselt.to_dense(), cusparselt.to_dense()) - - + cusparselt = SparseSemiStructuredTensorCUSPARSELT( + dense.shape, + packed_cusparselt, + None, + packed_t_cusparselt, + None, + compressed_swizzled_bitmask, + ) + torch.testing.assert_close( + reference_cusparselt.to_dense(), cusparselt.to_dense() + ) @training_dtypes @parametrize_backends @@ -649,7 +737,9 @@ def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None: dtype=dtype, ) inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1) - sInp = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(inp, algorithm="largest_abs_values_greedy") + sInp = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort( + inp, algorithm="largest_abs_values_greedy" + ) mask = sInp.to_dense() / inp expected = [ @@ -659,9 +749,15 @@ def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None: [1, 0, 0, 1], ] if mask[:4, :4].int().tolist() != expected: - raise AssertionError(f"mask[:4, :4] mismatch: {mask[:4, :4].int().tolist()} != {expected}") + raise AssertionError( + f"mask[:4, :4] mismatch: {mask[:4, :4].int().tolist()} != {expected}" + ) @training_dtypes + @unittest.skipIf( + not torch.backends.cusparselt.is_available(), + "cuSPARSELt not available", + ) def test_gemm(self, dtype) -> None: M, N, K = 32, 32, 64 a = torch.randn([M, K], device="cuda", dtype=dtype) @@ -677,7 +773,6 @@ def test_gemm(self, dtype) -> None: sp24_out = a_sparse @ b torch.testing.assert_close(ref_out, sp24_out, **atol_rtol_kw[dtype]) - @training_dtypes @parametrize_backends @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") @@ -695,15 +790,20 @@ def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None: a = a.cuda().to(dtype) b = torch.randn([a.shape[1], 128], device="cuda", dtype=dtype) - a_sparse = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(a) + a_sparse = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort( + a + ) mask_dense = sparse24_largest_mask_2d(a).to(dtype) if backend == "cutlass": if not isinstance(a_sparse, SparseSemiStructuredTensorCUTLASS): - raise AssertionError(f"a_sparse should be SparseSemiStructuredTensorCUTLASS, got {type(a_sparse)}") - (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile( - mask_dense, use_cutlass=True) + raise AssertionError( + f"a_sparse should be SparseSemiStructuredTensorCUTLASS, got {type(a_sparse)}" + ) + (packed, meta, packed_t, meta_t, bitmask) = ( + torch._sparse_semi_structured_tile(mask_dense, use_cutlass=True) + ) sparse_mask = SparseSemiStructuredTensorCUTLASS( mask_dense.shape, @@ -713,7 +813,9 @@ def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None: meta_t=meta_t, compressed_swizzled_bitmask=bitmask, ) - torch.testing.assert_close(a_sparse.meta.view(torch.short), sparse_mask.meta) + torch.testing.assert_close( + a_sparse.meta.view(torch.short), sparse_mask.meta + ) ref_gemm = (mask_dense * a) @ b pack_gemm = a_sparse @ b @@ -731,9 +833,7 @@ def test_pack_both_ways_id(self, dtype) -> None: a = torch.randn([N, N], dtype=dtype, device="cuda") b = torch.eye(N, dtype=dtype, device="cuda") - packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[ - :4 - ] + packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4] # Heuristic to ensure we pack the same values torch.testing.assert_close( packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum() @@ -746,7 +846,8 @@ def test_pack_both_ways_id(self, dtype) -> None: pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t() max_diff = (ref_gemm - pack_gemm).abs().argmax() torch.testing.assert_close( - ref_gemm, pack_gemm, + ref_gemm, + pack_gemm, **atol_rtol_kw[dtype], msg=f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})", ) @@ -755,7 +856,8 @@ def test_pack_both_ways_id(self, dtype) -> None: max_diff = (ref_gemm - pack_gemm).abs().argmax() torch.testing.assert_close( - ref_gemm, pack_gemm, + ref_gemm, + pack_gemm, **atol_rtol_kw[dtype], msg=f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})", ) @@ -789,9 +891,13 @@ def test_pack_both_ways_edge_case1(self, dtype) -> None: raise AssertionError(f"packed[0, 1] should be 0, got {packed[0, 1].item()}") # And first column in A.t if packed_t[0, 0].item() != 2: - raise AssertionError(f"packed_t[0, 0] should be 2, got {packed_t[0, 0].item()}") + raise AssertionError( + f"packed_t[0, 0] should be 2, got {packed_t[0, 0].item()}" + ) if packed_t[0, 1].item() != 0: - raise AssertionError(f"packed_t[0, 1] should be 0, got {packed_t[0, 1].item()}") + raise AssertionError( + f"packed_t[0, 1] should be 0, got {packed_t[0, 1].item()}" + ) @training_dtypes @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") @@ -854,7 +960,6 @@ def test_sp24_apply_dense(self, dtype) -> None: torch.testing.assert_close(dense, expected) torch.testing.assert_close(sparse.to_dense(), expected) - @training_dtypes @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") @unittest.skipIf( @@ -867,7 +972,9 @@ def test_sp24_matmuls(self, dtype) -> None: b = torch.randn([K, N], device="cuda", dtype=dtype) a_m = sparse24_largest_mask_2d(a) b_m = sparse24_largest_mask_2d(b) - (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(a) + (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile( + a + ) a_s = SparseSemiStructuredTensorCUTLASS( a.shape, packed=packed, @@ -876,7 +983,9 @@ def test_sp24_matmuls(self, dtype) -> None: meta_t=meta_t, compressed_swizzled_bitmask=bitmask, ) - (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(b) + (packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile( + b + ) b_s = SparseSemiStructuredTensorCUTLASS( b.shape, packed=packed, @@ -891,11 +1000,13 @@ def test_sp24_matmuls(self, dtype) -> None: torch.testing.assert_close( a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1.5e-1 ) - torch.testing.assert_close( - a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1 - ) + torch.testing.assert_close(a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1) @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") + @unittest.skipIf( + not torch.backends.cusparselt.is_available(), + "cuSPARSELt not available", + ) @unittest.skipIf( "RelWithAssert" in torch.__config__.show(), "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context", @@ -910,6 +1021,10 @@ def test_sp24_matmuls_mat_vec(self) -> None: torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype]) @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") + @unittest.skipIf( + not torch.backends.cusparselt.is_available(), + "cuSPARSELt not available", + ) @unittest.skipIf( "RelWithAssert" in torch.__config__.show(), "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context", @@ -923,34 +1038,53 @@ def test_sp24_matmuls_bmm(self) -> None: with pytest.raises(NotImplementedError): torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype]) + class TestSparseSemiStructuredCUTLASS(TestCase): """ This contains CUTLASS specific tests for - torch._sparse_semi_structured_linear """ + def setUp(self): SparseSemiStructuredTensor._FORCE_CUTLASS = True if "cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS: - self.skipTest('CUTLASS not enabled') + self.skipTest("CUTLASS not enabled") def tearDown(self): SparseSemiStructuredTensor._FORCE_CUTLASS = False super().tearDown() - @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS") + @unittest.skipIf( + TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS" + ) @inference_dtypes def test_linear_cutlass(self, device, dtype): - - def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol): + def run_test( + batch_shape, + m, + n, + k, + device, + dtype, + dtype_out, + add_bias, + activation, + rtol, + atol, + ): weight = rand_sparse_semi_structured(m, k, dtype, device) input = make_tensor((*batch_shape, n, k), dtype=dtype, device=device) - bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None + bias = ( + make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None + ) dtype_dense = torch.float32 input_dense = input.to(dtype_dense) weight_dense = weight.to(dtype_dense) bias_dense = bias.to(dtype_dense) if add_bias else None - output0 = torch.nn.functional.linear(input_dense, weight_dense, bias=bias_dense) + output0 = torch.nn.functional.linear( + input_dense, weight_dense, bias=bias_dense + ) if activation == "relu": relu = torch.nn.ReLU() output0 = relu(output0) @@ -963,9 +1097,17 @@ def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activatio weight_sparse = compressed.values() meta = compressed.indices() - output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation, - out_dtype=dtype_out if dtype == torch.int8 else None) - torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol) + output1 = torch._sparse_semi_structured_linear( + input, + weight_sparse, + meta, + bias=bias, + activation=activation, + out_dtype=dtype_out if dtype == torch.int8 else None, + ) + torch.testing.assert_close( + output1.to(dtype_dense), output0, rtol=rtol, atol=atol + ) if dtype == torch.float32: # Inputs are converted to TF32 internally for sparse GEMM, @@ -974,32 +1116,51 @@ def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activatio torch.backends.cuda.matmul.allow_tf32 = True batch_shapes = [[], [3], [3, 1]] - dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32} + dtype_out = { + torch.int8: torch.int32, + torch.half: torch.half, + torch.bfloat16: torch.bfloat16, + torch.float32: torch.float32, + } activations = [None, "relu", "silu"] rtol, atol = 1e-3, 1e-3 if dtype == torch.bfloat16: rtol, atol = 5e-3, 5e-3 elif dtype == torch.float32: rtol, atol = 1e-3, 75e-2 - for batch_shape, m, n, k, add_bias, activation in \ - itertools.product(batch_shapes, range(3), range(3), range(3), (False, True), activations): + for batch_shape, m, n, k, add_bias, activation in itertools.product( + batch_shapes, range(3), range(3), range(3), (False, True), activations + ): if activation == "silu" and dtype == torch.int8: continue # SiLU not supported for integer inputs - m = 2 ** m * 32 - n = 2 ** n * 32 - k = 2 ** k * 128 - run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation, rtol, atol) + m = 2**m * 32 + n = 2**n * 32 + k = 2**k * 128 + run_test( + batch_shape, + m, + n, + k, + device, + dtype, + dtype_out[dtype], + add_bias, + activation, + rtol, + atol, + ) if dtype == torch.float32: torch.backends.cuda.matmul.allow_tf32 = orig - - @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS") + @unittest.skipIf( + TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS" + ) @parametrize("backend", ["cutlass"]) @inference_dtypes def test_sparse_semi_structured_ops_cutlass(self, device, dtype, backend): - SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + SparseSemiStructuredTensor._FORCE_CUTLASS = backend == "cutlass" if backend == "cutlass" and IS_WINDOWS: self.skipTest("CUTLASS not supported on Windows") @@ -1007,7 +1168,9 @@ def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol): mat1 = rand_sparse_semi_structured(m, k, dtype, device) # mat2 transposed as int8 case supports only row-major/column-major combination mat2 = make_tensor((n, k), dtype=dtype, device=device).t() - input = make_tensor((m,), dtype=dtype_out, device=device) if use_input else None + input = ( + make_tensor((m,), dtype=dtype_out, device=device) if use_input else None + ) if use_input: if dtype.is_floating_point: @@ -1024,7 +1187,9 @@ def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol): output0 = torch.mm(mat1_dense, mat2_dense) else: input_dense = input.to(dtype_dense)[:, None] - output0 = torch.addmm(input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta) + output0 = torch.addmm( + input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta + ) compressed = to_sparse_semi_structured(mat1) @@ -1032,12 +1197,22 @@ def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol): mat1_meta = compressed.indices() if not use_input: - output1 = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out) + output1 = torch._sparse_semi_structured_mm( + mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out + ) else: output1 = torch._sparse_semi_structured_addmm( - input, mat1_sparse, mat1_meta, mat2, alpha=alpha, beta=beta, out_dtype=dtype_out + input, + mat1_sparse, + mat1_meta, + mat2, + alpha=alpha, + beta=beta, + out_dtype=dtype_out, ) - torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol) + torch.testing.assert_close( + output1.to(dtype_dense), output0, rtol=rtol, atol=atol + ) if dtype == torch.float32: # Inputs are converted to TF32 internally for sparse GEMM, @@ -1045,28 +1220,32 @@ def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol): orig = torch.backends.cuda.matmul.allow_tf32 torch.backends.cuda.matmul.allow_tf32 = True - dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32} + dtype_out = { + torch.int8: torch.int32, + torch.half: torch.half, + torch.bfloat16: torch.bfloat16, + torch.float32: torch.float32, + } rtol, atol = 1e-3, 1e-3 if dtype == torch.bfloat16: rtol, atol = 5e-3, 5e-3 elif dtype == torch.float32: rtol, atol = 1e-3, 75e-2 - for m, n, k, use_input in \ - itertools.product(range(3), range(3), range(3), (False, True)): - m = 2 ** m * 32 - n = 2 ** n * 32 - k = 2 ** k * 128 + for m, n, k, use_input in itertools.product( + range(3), range(3), range(3), (False, True) + ): + m = 2**m * 32 + n = 2**n * 32 + k = 2**k * 128 run_test(m, n, k, device, dtype, dtype_out[dtype], use_input, rtol, atol) if dtype == torch.float32: torch.backends.cuda.matmul.allow_tf32 = orig - @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @inference_dtypes @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_conversions(self, device, dtype): - def run_test(r, c, device, dtype): dense_ref = rand_sparse_semi_structured(r, c, dtype, device) @@ -1097,7 +1276,9 @@ def run_test(r, c, device, dtype): def test_conversions_all_patterns(self, device, dtype): r, c = 32, 128 - dense_inv, dense_val = rand_sparse_semi_structured_all_patterns(r, c, dtype, device) + dense_inv, dense_val = rand_sparse_semi_structured_all_patterns( + r, c, dtype, device + ) compressed = to_sparse_semi_structured(dense_inv) dense = compressed.to_dense() @@ -1107,6 +1288,7 @@ def test_conversions_all_patterns(self, device, dtype): CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32] + def to_float8(x, dtype=torch.float8_e4m3fn): finfo = torch.finfo(dtype) # Calculate the scale as dtype max divided by absmax @@ -1119,17 +1301,20 @@ def to_float8(x, dtype=torch.float8_e4m3fn): # as both required as inputs to torch._scaled_mm return x_scl_sat.to(dtype), scale.float().reciprocal() + class TestSparseSemiStructuredCUSPARSELT(TestCase): """ This contains cuSPARSELt specific tests for torch._cslt_compress torch._cslt_sparse_mm """ + def setUp(self): SparseSemiStructuredTensor._FORCE_CUTLASS = False if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS: - self.skipTest('cuSPARSELt not enabled') + self.skipTest("cuSPARSELt not enabled") + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") @unittest.skipIf( not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", @@ -1152,6 +1337,7 @@ def test_sparse_fp8fp8_mm(self, dense_input_shape, device): ): dense_result = torch.mm(A_fp8_sparse, B_fp8) + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") @unittest.skipIf( not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", @@ -1159,19 +1345,26 @@ def test_sparse_fp8fp8_mm(self, dense_input_shape, device): @xfailIfSM89PreCUDA13 def test_sparse_semi_structured_scaled_mm_fp8(self, device) -> None: (k, l, m) = (32, 64, 32) - x = rand_sparse_semi_structured_mask(k, l, dtype=torch.float8_e4m3fn, device=device) - y = torch.full((m, l), .25, device=device, dtype=torch.float8_e4m3fn).t() + x = rand_sparse_semi_structured_mask( + k, l, dtype=torch.float8_e4m3fn, device=device + ) + y = torch.full((m, l), 0.25, device=device, dtype=torch.float8_e4m3fn).t() scale_a = torch.tensor(1.0, device=device) scale_b = torch.tensor(1.0, device=device) - out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn) + out_fp8 = torch._scaled_mm( + x, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn + ) x_sparse = to_sparse_semi_structured(x) - out_fp8_sparse = torch._scaled_mm(x_sparse, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn) + out_fp8_sparse = torch._scaled_mm( + x_sparse, y, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.float8_e4m3fn + ) # this fails on ROCm currently because hipblaslt doesn't have amax op out_fp32 = out_fp8.to(torch.float32) out_fp32_sparse = out_fp8_sparse.to(torch.float32) torch.testing.assert_close(out_fp32, out_fp32_sparse, rtol=1e-1, atol=1e-1) + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") @unittest.skipIf( not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", @@ -1206,7 +1399,9 @@ def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device): B = torch.rand(dense_input_shape, device=device).to(torch.int8) - dense_result = torch.mm(A.cpu().to(torch.int64), B.t().cpu().to(torch.int64)).to(device, dtype=out_dtype) + dense_result = torch.mm( + A.cpu().to(torch.int64), B.t().cpu().to(torch.int64) + ).to(device, dtype=out_dtype) sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), out_dtype=out_dtype) torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3) @@ -1216,7 +1411,7 @@ def test_cslt_sparse_mm_mixed_dtype(self, dense_input_shape, out_dtype, device): def test_cslt_sparse_mm_alpha(self, dtype, device): A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda() B = torch.ones((256, 128), device=device).to(dtype) - alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda() + alpha = torch.Tensor([2 ** (-i) for i in range(128)]).cuda() bias = torch.ones(128, device=device).to(dtype) A_compressed = torch._cslt_compress(A) @@ -1233,40 +1428,52 @@ def test_cslt_sparse_mm_alpha(self, dtype, device): def test_cslt_sparse_mm_alpha_compile_autotune(self, device, out_dtype): A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).to(device) B = torch.ones((128, 256), device=device, dtype=torch.int8).t() - alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda() + alpha = torch.Tensor([2 ** (-i) for i in range(128)]).cuda() A_compressed = torch._cslt_compress(A) cslt_sparse_mm_c = torch.compile(torch._cslt_sparse_mm, mode="max-autotune") - sparse_result = cslt_sparse_mm_c(A_compressed, B, alpha=alpha, out_dtype=out_dtype) + sparse_result = cslt_sparse_mm_c( + A_compressed, B, alpha=alpha, out_dtype=out_dtype + ) # disable this otherwise inductor will attempt to reorder strides and pass a contiguous B @torch.compiler.disable def get_dense_result(): alpha_scaled = torch.stack([alpha] * 128).t().cpu().float() - dense_result = alpha_scaled * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu()) + dense_result = alpha_scaled * torch.mm( + A.to(torch.int64).cpu(), B.to(torch.int64).cpu() + ) dense_result = dense_result.to(out_dtype) return dense_result - torch.testing.assert_close(sparse_result.cpu(), get_dense_result(), rtol=1e-3, atol=1e-3) + torch.testing.assert_close( + sparse_result.cpu(), get_dense_result(), rtol=1e-3, atol=1e-3 + ) @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32]) @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device): A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda() B = torch.ones((128, 256), device=device).to(torch.int8).t() - alpha = torch.Tensor([2**(-i) if out_dtype is not torch.int32 else 1 - for i in range(128)]).cuda() + alpha = torch.Tensor( + [2 ** (-i) if out_dtype is not torch.int32 else 1 for i in range(128)] + ).cuda() A_compressed = torch._cslt_compress(A) - sparse_result = torch._cslt_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=out_dtype).cpu() + sparse_result = torch._cslt_sparse_mm( + A_compressed, B, alpha=alpha, out_dtype=out_dtype + ).cpu() alpha_scaled = torch.stack([alpha] * 128).t() - dense_result = alpha_scaled.cpu() * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu()) + dense_result = alpha_scaled.cpu() * torch.mm( + A.to(torch.int64).cpu(), B.to(torch.int64).cpu() + ) dense_result = dense_result.to(out_dtype) torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") @inference_dtypes def test_cslt_sparse_mm_search(self, device, dtype): A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) @@ -1280,6 +1487,7 @@ def test_cslt_sparse_mm_search(self, device, dtype): dense_result = dense_result.to(dtype) torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") @inference_dtypes def test_csrc_cslt_sparse_mm_search(self, device, dtype): A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype) @@ -1287,22 +1495,30 @@ def test_csrc_cslt_sparse_mm_search(self, device, dtype): B = torch.ones((128, 128), device=device).to(dtype) A_compressed = torch._cslt_compress(A) - alg_id, split_k, split_k_mode, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False) - sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), - alg_id=alg_id, - split_k=split_k, - split_k_mode=split_k_mode) + alg_id, split_k, split_k_mode, _ = torch._C._cusparselt.mm_search( + A_compressed, B.t(), None, None, None, False + ) + sparse_result = torch._cslt_sparse_mm( + A_compressed, + B.t(), + alg_id=alg_id, + split_k=split_k, + split_k_mode=split_k_mode, + ) dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32)) dense_result = dense_result.to(dtype) torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_cusparselt_backend(self): if not torch.backends.cusparselt.is_available(): raise AssertionError("cusparselt backend should be available") # PyTorch CUDA 12.4+ using cuSPARSELt v0.6.2+ if torch.backends.cusparselt.version() < 602: - raise AssertionError(f"cusparselt version should be >= 602, got {torch.backends.cusparselt.version()}") + raise AssertionError( + f"cusparselt version should be >= 602, got {torch.backends.cusparselt.version()}" + ) @inference_dtypes def test_cslt_sparse_tensor_clone(self, dtype): @@ -1312,13 +1528,151 @@ def test_cslt_sparse_tensor_clone(self, dtype): self.assertNotEqual(A_sparse.packed.data_ptr, A_clone.packed.data_ptr) self.assertEqual(A_sparse.packed, A_clone.packed) + @dtypes(torch.float16, torch.bfloat16) + def test_semi_sparse_to_device_cpu(self, dtype): + """Test .to('cpu') converts sparse semi-structured tensor to dense on CPU.""" + A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype).cuda() + A_sparse = to_sparse_semi_structured(A) + A_dense = A_sparse.to_dense() + + result = A_sparse.to("cpu") + self.assertEqual(result.device.type, "cpu") + self.assertFalse(isinstance(result, SparseSemiStructuredTensor)) + torch.testing.assert_close(result, A_dense.cpu(), rtol=1e-3, atol=1e-3) + + @dtypes(torch.float16, torch.bfloat16) + def test_semi_sparse_to_device_cuda(self, dtype): + """Test .to('cuda') / .to(device) raises NotImplementedError.""" + A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype).cuda() + A_sparse = to_sparse_semi_structured(A) + + with self.assertRaises(NotImplementedError): + A_sparse.to(A_sparse.device, copy=True) + + @dtypes(torch.float16, torch.bfloat16) + def test_semi_sparse_to_dtype(self, dtype): + """Test .to(dtype) raises NotImplementedError (only to('cpu') is supported).""" + A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype).cuda() + A_sparse = to_sparse_semi_structured(A) + + with self.assertRaises(NotImplementedError): + A_sparse.to(torch.float32) + + @dtypes(torch.float16, torch.bfloat16) + def test_semi_sparse_to_device_and_dtype(self, dtype): + """Test .to(device, dtype) converts both device and dtype.""" + A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype).cuda() + A_sparse = to_sparse_semi_structured(A) + A_dense = A_sparse.to_dense() + + result = A_sparse.to("cpu", torch.float32, copy=True) + self.assertEqual(result.device.type, "cpu") + self.assertEqual(result.dtype, torch.float32) + self.assertFalse(isinstance(result, SparseSemiStructuredTensor)) + torch.testing.assert_close( + result, A_dense.cpu().to(torch.float32), rtol=1e-3, atol=1e-3 + ) + + @dtypes(torch.float16, torch.bfloat16) + def test_semi_sparse_to_same_dtype_noop(self, dtype): + """Test .to(same_dtype) raises NotImplementedError (only to('cpu') is supported).""" + A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype).cuda() + A_sparse = to_sparse_semi_structured(A) + + with self.assertRaises(NotImplementedError): + A_sparse.to(dtype, copy=True) + + @dtypes(torch.float16, torch.bfloat16) + def test_semi_sparse_to_kwargs(self, dtype): + """Test .to() with keyword arguments (device=, dtype=).""" + A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype).cuda() + A_sparse = to_sparse_semi_structured(A) + A_dense = A_sparse.to_dense() + + result = A_sparse.to(device="cpu", dtype=torch.float32) + self.assertEqual(result.device.type, "cpu") + self.assertEqual(result.dtype, torch.float32) + self.assertFalse(isinstance(result, SparseSemiStructuredTensor)) + torch.testing.assert_close( + result, A_dense.cpu().to(torch.float32), rtol=1e-3, atol=1e-3 + ) + + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") + def test_cslt_sparse_tensor_with_alg_id(self, device): + A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16).cuda() + A_compressed = torch._cslt_compress(A) + B = torch.ones((128, 128), device=device).to(torch.float16) + alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t()) + A_sparse = to_sparse_semi_structured(A, alg_id=alg_id) + self.assertEqual(A_sparse.alg_id_cusparselt, alg_id) + dense_result = torch.mm(A, B).to(torch.float16) + sparse_result = torch.mm(A_sparse, B).to(torch.float16) + torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) + + + @unittest.skipIf( + not PLATFORM_SUPPORTS_FP8_SPARSE, + "FP8 sparse requires cuSPARSELt v0.6.2+ on SM 8.9+ or MI350+ (gfx950) on ROCm", + ) + def test_cslt_compress_fp8(self, device): + A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16) + A_fp8, _ = to_float8(A) + compressed = torch._cslt_compress(A_fp8) + self.assertEqual(compressed.dtype, torch.float8_e4m3fn) + + @unittest.skipIf( + not PLATFORM_SUPPORTS_FP8_SPARSE, + "FP8 sparse requires cuSPARSELt v0.6.2+ on SM 8.9+ or MI350+ (gfx950) on ROCm", + ) + @parametrize("dense_input_shape", [(256, 128)]) + def test_cslt_sparse_mm_fp8_to_fp32(self, dense_input_shape, device): + A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16) + B = torch.rand(dense_input_shape, device=device).to(torch.float16).t() + + A_fp8, _ = to_float8(A) + B_fp8, _ = to_float8(B) + + compressed = torch._cslt_compress(A_fp8) + sparse_result = torch._cslt_sparse_mm(compressed, B_fp8, out_dtype=torch.float32) + + self.assertEqual(sparse_result.dtype, torch.float32) + + dense_result = torch.mm(A_fp8.to(torch.float32), B_fp8.to(torch.float32)) + torch.testing.assert_close(sparse_result, dense_result, rtol=1e-1, atol=1e-1) + + @unittest.skipIf( + not PLATFORM_SUPPORTS_FP8_SPARSE, + "FP8 sparse requires cuSPARSELt v0.6.2+ on SM 8.9+ or MI350+ (gfx950) on ROCm", + ) + @unittest.skipIf(not TEST_WITH_ROCM, "ROCm-specific out_dtype restriction") + @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) + def test_cslt_sparse_mm_fp8_unsupported_out_dtype_rocm(self, out_dtype, device): + A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16) + B = torch.rand(128, 128, device=device).to(torch.float16).t() + + A_fp8, _ = to_float8(A) + B_fp8, _ = to_float8(B) + + compressed = torch._cslt_compress(A_fp8) + with self.assertRaisesRegex( + RuntimeError, + "Unsupported out_dtype passed, must be float32 for fp8 inputs on ROCm", + ): + torch._cslt_sparse_mm(compressed, B_fp8, out_dtype=out_dtype) + if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) > 0: instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda") if "cutlass" in SEMI_STRUCTURED_SUPPORTED_BACKENDS: - instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda") - instantiate_device_type_tests(TestSparseSemiStructuredTraining, globals(), only_for="cuda") + instantiate_device_type_tests( + TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda" + ) + instantiate_device_type_tests( + TestSparseSemiStructuredTraining, globals(), only_for="cuda" + ) if "cusparselt" in SEMI_STRUCTURED_SUPPORTED_BACKENDS: - instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda") + instantiate_device_type_tests( + TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda" + ) if __name__ == "__main__": run_tests() diff --git a/test/test_stateless_rng.py b/test/test_stateless_rng.py new file mode 100644 index 0000000000000..08631a9eaa974 --- /dev/null +++ b/test/test_stateless_rng.py @@ -0,0 +1,567 @@ +# Owner(s): ["module: random"] + +import torch +import torch.func._random as random +from torch.testing._internal.common_device_type import ( + dtypes, + instantiate_device_type_tests, +) +from torch.testing._internal.common_dtype import floating_types_and +from torch.testing._internal.common_utils import parametrize, run_tests, TestCase + + +all_floating_dtypes = floating_types_and(torch.half, torch.bfloat16) + + +class TestStatelessRNGKey(TestCase): + def test_basic_shape_and_dtype(self, device): + key = random.key(42, device=device) + self.assertEqual(key.shape, (2,)) + self.assertEqual(key.dtype, torch.uint64) + self.assertEqual(key.device, torch.device(device)) + + def test_different_seeds(self, device): + key1 = random.key(42, device=device) + key2 = random.key(43, device=device) + self.assertNotEqual(key1, key2) + + def test_determinism(self, device): + key1 = random.key(42, device=device) + key2 = random.key(42, device=device) + self.assertEqual(key1, key2) + + def test_error_unsupported_impl(self, device): + with self.assertRaisesRegex( + NotImplementedError, "does not support PRNG impl 'unsupported'" + ): + random.key(42, impl="unsupported", device=device) + + +class TestStatelessRNGKeySplit(TestCase): + def test_basic_shape_and_dtype(self, device): + key = random.key(42, device=device) + splits = random.split(key, 4) + self.assertEqual(splits.shape, (4, 2)) + self.assertEqual(splits.dtype, torch.uint64) + self.assertEqual(splits.device, key.device) + + def test_single_split(self, device): + key = random.key(42, device=device) + splits = random.split(key, 1) + self.assertEqual(splits.shape, (1, 2)) + + def test_large_num_splits(self, device): + key = random.key(42, device=device) + splits = random.split(key, 10000) + self.assertEqual(splits.shape, (10000, 2)) + + def test_determinism(self, device): + key = random.key(42, device=device) + splits1 = random.split(key, 8) + splits2 = random.split(key, 8) + self.assertEqual(splits1, splits2) + + def test_all_keys_unique(self, device): + key = random.key(42, device=device) + splits = random.split(key, 100) + unique_keys = torch.unique(splits, dim=0) + self.assertEqual(unique_keys.shape[0], 100) + + def test_different_seeds_produce_different_outputs(self, device): + key1 = random.key(42, device=device) + key2 = random.key(43, device=device) + splits1 = random.split(key1, 4) + splits2 = random.split(key2, 4) + self.assertNotEqual(splits1, splits2) + + def test_different_offsets_produce_different_outputs(self, device): + key1 = random.key(42, device=device) + key2 = random.fold_in(key1, 1) + splits1 = random.split(key1, 4) + splits2 = random.split(key2, 4) + self.assertNotEqual(splits1, splits2) + + def test_offset_zero_vs_one_produce_different_splits(self, device): + key1 = random.key(42, device=device) + key2 = torch.tensor([42, 1], dtype=torch.uint64, device=device) + splits1 = random.split(key1, 4) + splits2 = random.split(key2, 4) + self.assertNotEqual(splits1, splits2) + + def test_batched(self, device): + key = random.key(42, device=device) + keys = random.split(key, 4) # (4, 2) + num_splits = 3 + batched = random.split(keys, num_splits) # (3, 4, 2) + self.assertEqual(batched.shape, (num_splits, 4, 2)) + for k in range(4): + individual = random.split(keys[k], num_splits) + for s in range(num_splits): + self.assertEqual(batched[s][k], individual[s]) + + def test_multi_batch(self, device): + key = random.key(42, device=device) + keys = random.split(key, 12).reshape(3, 4, 2) + num_splits = 5 + batched = random.split(keys, num_splits) # (5, 3, 4, 2) + self.assertEqual(batched.shape, (num_splits, 3, 4, 2)) + for i in range(3): + for j in range(4): + individual = random.split(keys[i][j], num_splits) + for s in range(num_splits): + self.assertEqual(batched[s][i][j], individual[s]) + + def test_error_wrong_shape(self, device): + key = torch.tensor([42, 0, 1], dtype=torch.uint64, device=device) + with self.assertRaisesRegex( + RuntimeError, r"key must have shape \(\*batch, 2\)" + ): + random.split(key, 4) + + def test_error_wrong_dtype(self, device): + key = torch.tensor([42, 0], dtype=torch.float32, device=device) + with self.assertRaisesRegex(RuntimeError, "key must have dtype uint64"): + random.split(key, 4) + + def test_error_wrong_device(self, device): + key = random.key(42) # CPU key + with self.assertRaisesRegex( + NotImplementedError, + "Could not run .* with arguments from the 'CPU' backend", + ): + random.split(key, 4) + + def test_error_invalid_num_splits(self, device): + key = random.key(42, device=device) + with self.assertRaisesRegex(RuntimeError, "num_splits must be positive"): + random.split(key, 0) + with self.assertRaisesRegex(RuntimeError, "num_splits must be positive"): + random.split(key, -1) + + def test_error_batched_last_dim_not_2(self, device): + key = torch.tensor([[42, 0, 1], [43, 0, 1]], dtype=torch.uint64, device=device) + with self.assertRaisesRegex( + RuntimeError, r"key must have shape \(\*batch, 2\)" + ): + random.split(key, 4) + + def test_offset_overflow(self, device): + near_max = (1 << 64) - 1 + key = torch.tensor([42, near_max], dtype=torch.uint64, device=device) + splits = random.split(key, 3) + # split_idx=1 wraps offset to 0, split_idx=2 wraps to 1 + key0 = torch.tensor([42, 0], dtype=torch.uint64, device=device) + self.assertEqual(splits[1], random.fold_in(key0, 0)) + self.assertEqual(splits[2], random.fold_in(key0, 1)) + + +class TestStatelessRNGKeyFoldIn(TestCase): + def test_basic_shape_and_dtype(self, device): + key = random.key(42, device=device) + result = random.fold_in(key, 7) + self.assertEqual(result.shape, (2,)) + self.assertEqual(result.dtype, torch.uint64) + self.assertEqual(result.device, key.device) + + def test_determinism(self, device): + key = random.key(42, device=device) + result1 = random.fold_in(key, 7) + result2 = random.fold_in(key, 7) + self.assertEqual(result1, result2) + + def test_fold_in_produces_new_key_for_zero_data(self, device): + key = random.key(42, device=device) + folded = random.fold_in(key, 0) + self.assertNotEqual(folded, key) + + def test_different_data_produces_different_outputs(self, device): + key = random.key(42, device=device) + result1 = random.fold_in(key, 0) + result2 = random.fold_in(key, 1) + self.assertNotEqual(result1, result2) + + def test_consistency_with_split(self, device): + key = random.key(42, device=device) + splits = random.split(key, 10) + for i in range(10): + folded = random.fold_in(key, i) + self.assertEqual(folded, splits[i]) + + def test_batched(self, device): + key = random.key(42, device=device) + keys = random.split(key, 4) # (4, 2) + data = 7 + batched = random.fold_in(keys, data) # (4, 2) + self.assertEqual(batched.shape, (4, 2)) + for k in range(4): + individual = random.fold_in(keys[k], data) + self.assertEqual(batched[k], individual) + + def test_multi_batch(self, device): + key = random.key(42, device=device) + keys = random.split(key, 12).reshape(3, 4, 2) + data = 7 + batched = random.fold_in(keys, data) # (3, 4, 2) + self.assertEqual(batched.shape, (3, 4, 2)) + for i in range(3): + for j in range(4): + individual = random.fold_in(keys[i][j], data) + self.assertEqual(batched[i][j], individual) + + def test_error_wrong_shape(self, device): + key = torch.tensor([42, 0, 1], dtype=torch.uint64, device=device) + with self.assertRaisesRegex( + RuntimeError, r"key must have shape \(\*batch, 2\)" + ): + random.fold_in(key, 0) + + def test_error_wrong_dtype(self, device): + key = torch.tensor([42, 0], dtype=torch.float32, device=device) + with self.assertRaisesRegex(RuntimeError, "key must have dtype uint64"): + random.fold_in(key, 0) + + def test_error_wrong_device(self, device): + key = random.key(42) # CPU key + with self.assertRaisesRegex( + NotImplementedError, + "Could not run .* with arguments from the 'CPU' backend", + ): + random.fold_in(key, 0) + + def test_error_batched_last_dim_not_2(self, device): + key = torch.tensor([[42, 0, 1], [43, 0, 1]], dtype=torch.uint64, device=device) + with self.assertRaisesRegex( + RuntimeError, r"key must have shape \(\*batch, 2\)" + ): + random.fold_in(key, 0) + + def test_offset_overflow(self, device): + near_max = (1 << 64) - 1 + key = torch.tensor([42, near_max], dtype=torch.uint64, device=device) + # fold_in(data=1) wraps offset to 0, so it should match fold_in on + # a key with offset=0 and data=0. + result = random.fold_in(key, 1) + key0 = torch.tensor([42, 0], dtype=torch.uint64, device=device) + self.assertEqual(result, random.fold_in(key0, 0)) + + +class TestStatelessRNGDistribution(TestCase): + def _gen(self, gen_fn_name, *args, **kwargs): + return getattr(random, gen_fn_name)(*args, **kwargs) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + @dtypes(*all_floating_dtypes) + def test_basic_shape(self, device, dtype, gen_fn_name): + key = random.key(42, device=device) + result = self._gen(gen_fn_name, key, (100,), dtype=dtype) + self.assertEqual(result.shape, (100,)) + self.assertEqual(result.dtype, dtype) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + @dtypes(*all_floating_dtypes) + def test_determinism(self, device, dtype, gen_fn_name): + key = random.key(42, device=device) + a = self._gen(gen_fn_name, key, (1000,), dtype=dtype) + b = self._gen(gen_fn_name, key, (1000,), dtype=dtype) + self.assertEqual(a, b) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + @dtypes(*all_floating_dtypes) + def test_different_keys(self, device, dtype, gen_fn_name): + key1 = random.key(42, device=device) + key2 = random.key(43, device=device) + a = self._gen(gen_fn_name, key1, (1000,), dtype=dtype) + b = self._gen(gen_fn_name, key2, (1000,), dtype=dtype) + self.assertNotEqual(a, b) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + @dtypes(*all_floating_dtypes) + def test_batched_keys(self, device, dtype, gen_fn_name): + key = random.key(42, device=device) + keys = random.split(key, 4).unsqueeze(-2) # (4, 1, 2) + result = self._gen(gen_fn_name, keys, (4, 100), dtype=dtype) + for i in range(4): + individual = self._gen(gen_fn_name, keys[i], (100,), dtype=dtype) + self.assertEqual(result[i], individual) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + @dtypes(*all_floating_dtypes) + def test_batched_keys_large(self, device, dtype, gen_fn_name): + # Large event_numel to exercise the multi-key tiled kernel path. + key = random.key(42, device=device) + keys = random.split(key, 4).unsqueeze(-2) # (4, 1, 2) + result = self._gen(gen_fn_name, keys, (4, 10000), dtype=dtype) + for i in range(4): + individual = self._gen(gen_fn_name, keys[i], (10000,), dtype=dtype) + self.assertEqual(result[i], individual) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + @dtypes(*all_floating_dtypes) + def test_multi_batch(self, device, dtype, gen_fn_name): + key = random.key(42, device=device) + keys = random.split(key, 6).view(2, 3, 1, 2) + result = self._gen(gen_fn_name, keys, (2, 3, 50), dtype=dtype) + for i in range(2): + for j in range(3): + individual = self._gen(gen_fn_name, keys[i][j], (50,), dtype=dtype) + self.assertEqual(result[i][j], individual) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + @dtypes(*all_floating_dtypes) + def test_key_broadcasting_semantics(self, device, dtype, gen_fn_name): + key = random.key(42, device=device) + + # Broadcast key dim: size-1 dims replicate, other dims index keys. + keys = random.split(key, 3).unsqueeze(0).unsqueeze(-2) # (1, 3, 1, 2) + result = self._gen(gen_fn_name, keys, (4, 3, 100), dtype=dtype) + for i in range(1, 4): + self.assertEqual(result[0], result[i]) + for j in range(1, 3): + self.assertNotEqual(result[0][0], result[0][j]) + + # All-broadcast key matches unbatched. + batched = self._gen(gen_fn_name, key.view(1, 1, 2), (4, 100), dtype=dtype) + unbatched = self._gen(gen_fn_name, key, (400,), dtype=dtype) + self.assertEqual(batched.flatten(), unbatched) + + # Multiple trailing size-1 dims to broadcast over. + keys = random.split(key, 4).view(4, 1, 1, 2) + result = self._gen(gen_fn_name, keys, (4, 10, 100), dtype=dtype) + for i in range(4): + individual = self._gen(gen_fn_name, keys[i], (10, 100), dtype=dtype) + self.assertEqual(result[i], individual) + keys_flat = random.split(key, 4).unsqueeze(-2) # (4, 1, 2) + flat = self._gen(gen_fn_name, keys_flat, (4, 1000), dtype=dtype) + self.assertEqual(result.view(4, 1000), flat) + + # No generation dims: every element gets its own key. + keys = random.split(key, 12).view(4, 3, 2) + result = self._gen(gen_fn_name, keys, (4, 3), dtype=dtype) + for i in range(4): + for j in range(3): + individual = self._gen(gen_fn_name, keys[i][j], (1,), dtype=dtype) + self.assertEqual(result[i][j], individual.squeeze()) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + def test_error_wrong_key_dtype(self, device, gen_fn_name): + key = torch.tensor([42, 0], dtype=torch.float32, device=device) + with self.assertRaisesRegex(RuntimeError, "key must have dtype uint64"): + self._gen(gen_fn_name, key, (100,)) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + def test_error_key_shape(self, device, gen_fn_name): + key = random.key(42, device=device) + # Last dim must be 2. + bad_key = torch.tensor([42, 0, 1], dtype=torch.uint64, device=device) + with self.assertRaisesRegex( + RuntimeError, r"key must have shape \(2,\) or \(\*batch, 2\)" + ): + self._gen(gen_fn_name, bad_key, (100,)) + # Key batch ndim must equal output ndim (too few). + with self.assertRaisesRegex( + RuntimeError, "batched key must have ndim == output ndim \\+ 1" + ): + self._gen(gen_fn_name, random.split(key, 3), (3, 4, 100)) + # Key batch ndim must equal output ndim (too many). + with self.assertRaisesRegex( + RuntimeError, "batched key must have ndim == output ndim \\+ 1" + ): + self._gen(gen_fn_name, random.split(key, 3).view(3, 1, 1, 2), (3, 100)) + # Key batch dims must be broadcastable with output. + with self.assertRaisesRegex( + RuntimeError, "is not broadcastable with output shape" + ): + self._gen(gen_fn_name, random.split(key, 5).unsqueeze(-2), (3, 100)) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + @dtypes(*all_floating_dtypes) + def test_offset_shift_consistency(self, device, dtype, gen_fn_name): + seed = 42 + n = 100 + key0 = random.key(seed, device=device) + ref = self._gen(gen_fn_name, key0, (n,), dtype=dtype) + + # as a key's offset shifts, we expect the stream to shift by + # the number of elements per philox call (2 for double; 4 otherwise) + for offset in range(1, 4): + key = torch.tensor([seed, offset], dtype=torch.uint64, device=device) + elems_per_call = 2 if dtype == torch.float64 else 4 + expected_shift = offset * elems_per_call + result = self._gen(gen_fn_name, key, (n,), dtype=dtype) + self.assertEqual(ref[expected_shift:], result[:-expected_shift]) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + @dtypes(*all_floating_dtypes) + def test_offset_overflow(self, device, dtype, gen_fn_name): + seed = 42 + n = 100 + last_offset_before_wrap = (1 << 64) - 1 + key = torch.tensor( + [seed, last_offset_before_wrap], dtype=torch.uint64, device=device + ) + result = self._gen(gen_fn_name, key, (n,), dtype=dtype) + + # ensure offset wraps around to 0 by comparing with 0-offset key results + key0 = random.key(seed, device=device) + result0 = self._gen(gen_fn_name, key0, (n,), dtype=dtype) + elems_per_call = 2 if dtype == torch.float64 else 4 + self.assertEqual(result[elems_per_call:], result0[:-elems_per_call]) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + @dtypes(*all_floating_dtypes) + def test_small_output_sizes(self, device, dtype, gen_fn_name): + key = random.key(42, device=device) + large = self._gen(gen_fn_name, key, (100,), dtype=dtype) + for n in [0, 1, 2, 3, 4, 5, 7]: + result = self._gen(gen_fn_name, key, (n,), dtype=dtype) + self.assertEqual(result.shape, (n,)) + # Determinism. + result2 = self._gen(gen_fn_name, key, (n,), dtype=dtype) + self.assertEqual(result, result2) + # Prefix consistency: first n elements of a larger output. + if n > 0: + self.assertEqual(result, large[:n]) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + @parametrize("layout", ["contiguous", "noncontiguous", "unaligned"]) + @dtypes(*all_floating_dtypes) + def test_inplace(self, device, dtype, gen_fn_name, layout): + key = random.key(42, device=device) + if layout == "contiguous": + result = torch.empty(1000, dtype=dtype, device=device) + elif layout == "noncontiguous": + result = torch.empty(2000, dtype=dtype, device=device)[::2] + else: + # Contiguous but data pointer is not aligned to vectorized write width. + result = torch.empty(1001, dtype=dtype, device=device)[1:] + inplace_fn = getattr(random, gen_fn_name + "_") + out = inplace_fn(key, result) + self.assertIs(out, result) + functional = self._gen(gen_fn_name, key, (1000,), dtype=dtype) + self.assertEqual(result, functional) + + @parametrize("gen_fn_name", ["normal", "uniform"]) + @dtypes(*all_floating_dtypes) + def test_empty_output(self, device, dtype, gen_fn_name): + key = random.key(42, device=device) + result = self._gen(gen_fn_name, key, (0,), dtype=dtype) + self.assertEqual(result.shape, (0,)) + self.assertEqual(result.dtype, dtype) + result = self._gen(gen_fn_name, key, (3, 0), dtype=dtype) + self.assertEqual(result.shape, (3, 0)) + self.assertEqual(result.dtype, dtype) + + # Distribution-specific tests + + @dtypes(*all_floating_dtypes) + def test_standard_normal_statistics(self, device, dtype): + key = random.key(42, device=device) + result = random.normal(key, (100000,), dtype=dtype) + self.assertTrue(abs(result.mean().item()) < 0.05) + self.assertTrue(abs(result.std().item() - 1.0) < 0.05) + + @dtypes(*all_floating_dtypes) + def test_custom_mean_std(self, device, dtype): + key = random.key(42, device=device) + result = random.normal(key, (100000,), mean=5.0, std=2.0, dtype=dtype) + self.assertTrue(abs(result.mean().item() - 5.0) < 0.1) + self.assertTrue(abs(result.std().item() - 2.0) < 0.1) + + @dtypes(*all_floating_dtypes) + def test_standard_uniform_statistics(self, device, dtype): + key = random.key(42, device=device) + result = random.uniform(key, (100000,), dtype=dtype) + self.assertTrue(abs(result.mean().item() - 0.5) < 0.05) + self.assertTrue(result.min().item() >= 0.0) + self.assertTrue(result.max().item() < 1.0) + + @dtypes(*all_floating_dtypes) + def test_custom_low_high(self, device, dtype): + key = random.key(42, device=device) + result = random.uniform(key, (100000,), low=2.0, high=5.0, dtype=dtype) + self.assertTrue(abs(result.mean().item() - 3.5) < 0.1) + self.assertTrue(result.min().item() >= 2.0) + self.assertTrue(result.max().item() <= 5.0) + + +class TestStatelessRNGCompile(TestCase): + def test_split_fullgraph(self, device): + key = random.key(42, device=device) + + @torch.compile(backend="aot_eager", fullgraph=True) + def f(key): + return random.split(key, 4) + + self.assertEqual(f(key), random.split(key, 4)) + + def test_fold_in_fullgraph(self, device): + key = random.key(42, device=device) + + @torch.compile(backend="aot_eager", fullgraph=True) + def f(key): + return random.fold_in(key, 7) + + self.assertEqual(f(key), random.fold_in(key, 7)) + + def test_uniform_fullgraph(self, device): + key = random.key(42, device=device) + + @torch.compile(backend="aot_eager", fullgraph=True) + def f(key): + return random.uniform(key, (100,)) + + self.assertEqual(f(key), random.uniform(key, (100,))) + + def test_normal_fullgraph(self, device): + key = random.key(42, device=device) + + @torch.compile(backend="aot_eager", fullgraph=True) + def f(key): + return random.normal(key, (100,)) + + self.assertEqual(f(key), random.normal(key, (100,))) + + def test_batched_normal_fullgraph(self, device): + key = random.key(42, device=device) + keys = random.split(key, 4).unsqueeze(-2) # (4, 1, 2) + + @torch.compile(backend="aot_eager", fullgraph=True) + def f(keys): + return random.normal(keys, (4, 50)) + + self.assertEqual(f(keys), random.normal(keys, (4, 50))) + + def test_split_then_normal_fullgraph(self, device): + key = random.key(42, device=device) + + @torch.compile(backend="aot_eager", fullgraph=True) + def f(key): + keys = random.split(key, 4).unsqueeze(-2) + return random.normal(keys, (4, 100)) + + self.assertEqual( + f(key), random.normal(random.split(key, 4).unsqueeze(-2), (4, 100)) + ) + + def test_fold_in_then_uniform_fullgraph(self, device): + key = random.key(42, device=device) + + @torch.compile(backend="aot_eager", fullgraph=True) + def f(key): + k = random.fold_in(key, 3) + return random.uniform(k, (100,)) + + self.assertEqual(f(key), random.uniform(random.fold_in(key, 3), (100,))) + + +instantiate_device_type_tests(TestStatelessRNGKey, globals(), only_for=("cuda",)) +instantiate_device_type_tests(TestStatelessRNGKeySplit, globals(), only_for=("cuda",)) +instantiate_device_type_tests(TestStatelessRNGKeyFoldIn, globals(), only_for=("cuda",)) +instantiate_device_type_tests( + TestStatelessRNGDistribution, globals(), only_for=("cuda",) +) +instantiate_device_type_tests(TestStatelessRNGCompile, globals(), only_for=("cuda",)) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py index 002157b4a7d40..e91aab10ac323 100644 --- a/test/test_static_runtime.py +++ b/test/test_static_runtime.py @@ -563,7 +563,7 @@ def test_fusion_outputs(self): torch.testing.assert_close(o_ref[i], o_test[i]) def test_create_object(self): - class Foo: # noqa: B903 + class Foo: def __init__(self, x: torch.Tensor) -> None: self.x = x diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 287ee3cb7e421..87e7db57318b6 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -4138,7 +4138,8 @@ def _check(self, original, cvt=lambda t: t, is_alias=True, same_dtype=True, same 3. Whether the result lives in the expected device 4. Whether the result has its 'requires_grad' set or not """ - result = torch.asarray(cvt(original), **kwargs) + converted_original = cvt(original) + result = torch.asarray(converted_original, **kwargs) self.assertTrue(isinstance(result, torch.Tensor)) # 1. The storage pointers should be equal only if 'is_alias' is set @@ -4169,8 +4170,10 @@ def _check(self, original, cvt=lambda t: t, is_alias=True, same_dtype=True, same if device.index is not None: self.assertEqual(device.index, result.device.index) - # 4. By default, 'requires_grad' is unset - self.assertEqual(result.requires_grad, kwargs.get("requires_grad", False)) + # 4. By default, 'requires_grad' mirrors the original tensor's requires_grad, if + # present. + original_requires_grad = converted_original.requires_grad if isinstance(converted_original, torch.Tensor) else False + self.assertEqual(result.requires_grad, kwargs.get("requires_grad", original_requires_grad)) def _test_alias_with_cvt(self, cvt, device, dtype, shape=(5, 5), only_with_dtype=False): original = make_tensor(shape, dtype=dtype, device=device) @@ -4335,7 +4338,7 @@ def test_retain_autograd_history(self, device, dtype): def check(**kwargs): a = torch.asarray(cloned, **kwargs) - requires_grad = kwargs.get("requires_grad", False) + requires_grad = kwargs.get("requires_grad", cloned.requires_grad) self.assertEqual(a.requires_grad, requires_grad) # Autograd history shouldn't be retained when requires_grad is False self.assertEqual(a.grad_fn is None, not requires_grad) diff --git a/test/test_tensorboard.py b/test/test_tensorboard.py index 8b63014f46749..2805f47520520 100644 --- a/test/test_tensorboard.py +++ b/test/test_tensorboard.py @@ -740,6 +740,7 @@ def test_torchvision_smoke(self): class TestTensorBoardFigure(BaseTestCase): @skipIfNoMatplotlib + @skipIfTorchDynamo("dynamo fails to trace matplotlib WRITEABLE flag and slice.indices") def test_figure(self): writer = self.createSummaryWriter() @@ -765,6 +766,7 @@ def test_figure(self): writer.close() @skipIfNoMatplotlib + @skipIfTorchDynamo("dynamo fails to trace matplotlib WRITEABLE flag and slice.indices") def test_figure_list(self): writer = self.createSummaryWriter() @@ -781,13 +783,13 @@ def test_figure_list(self): writer.add_figure("add_figure/figure_list", figures, 0, close=False) self.assertTrue( all(plt.fignum_exists(figure.number) is True for figure in figures) - ) # noqa: F812 + ) writer.add_figure("add_figure/figure_list", figures, 1) if matplotlib.__version__ != "3.3.0": self.assertTrue( all(plt.fignum_exists(figure.number) is False for figure in figures) - ) # noqa: F812 + ) else: print( "Skipping fignum_exists, see https://github.com/matplotlib/matplotlib/issues/18163" diff --git a/test/test_testing.py b/test/test_testing.py index 21d7c0017b664..12b55c30b77e8 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -2139,7 +2139,7 @@ def test_op_parametrized(self, device, dtype, op, flag): for op in op_db: for dtype in op.supported_dtypes(torch.device(device).type): for flag_part in ('flag_disabled', 'flag_enabled'): - expected_name = f'{device_cls.__name__}.test_op_parametrized_{op.formatted_name}_{flag_part}_{device}_{dtype_name(dtype)}' # noqa: B950 + expected_name = f'{device_cls.__name__}.test_op_parametrized_{op.formatted_name}_{flag_part}_{device}_{dtype_name(dtype)}' expected_test_names.append(expected_name) test_names = _get_test_names_for_test_class(device_cls) @@ -2397,6 +2397,8 @@ def _check_python_output(cls, program) -> str: cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8") @skipIfXpu(msg="The test is flaky on XPU, see https://github.com/pytorch/pytorch/issues/110040") + # The test is flaky on ROCm/XPU and has been open and close multiple times + # https://github.com/pytorch/pytorch/issues/110040 def test_circular_dependencies(self) -> None: """ Checks that all modules inside torch can be imported Prevents regression reported in https://github.com/pytorch/pytorch/issues/77441 """ @@ -2408,6 +2410,7 @@ def test_circular_dependencies(self) -> None: "torch.ao.pruning._experimental.", # depends on pytorch_lightning, not user-facing "torch.onnx._internal", # depends on onnx-script "torch._inductor.runtime.triton_helpers", # depends on triton + "torch._native.ops.bmm_outer_product.triton_kernels", # depends on triton "torch._inductor.codegen.cuda", # depends on cutlass "torch._inductor.codegen.cutedsl", # depends on cutlass "torch.distributed.benchmarks", # depends on RPC and DDP Optim diff --git a/test/test_torch.py b/test/test_torch.py index 8485055ebb5ac..325ce2094198a 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -15,7 +15,6 @@ import random import re import copy -import os import tempfile import unittest import warnings @@ -84,52 +83,6 @@ AMPERE_OR_ROCM = TEST_WITH_ROCM or torch.cuda.is_tf32_supported() -@contextlib.contextmanager -def torch_vital_set(value): - stash = None - if 'TORCH_VITAL' in os.environ: - stash = os.environ['TORCH_VITAL'] - os.environ['TORCH_VITAL'] = value - try: - yield - finally: - if stash: - os.environ['TORCH_VITAL'] = stash - else: - del os.environ['TORCH_VITAL'] - -# Tests Vital Signs for Torch -# FIXME: document or deprecate whatever this is -class TestBasicVitalSigns(TestCase): - def test_basic_vitals(self): - with torch_vital_set(''): - self.assertFalse(torch.vitals_enabled()) - with torch_vital_set('ON'): - self.assertTrue(torch.vitals_enabled()) - - def test_basic_vitals_read_write(self): - with torch_vital_set('ON'): - self.assertTrue(torch.vitals_enabled()) - # This tests the code path of setting a vital - self.assertTrue(torch.set_vital('Dataloader', 'basic_unit_test', 'TEST_VALUE_STRING')) - self.assertIn('TEST_VALUE_STRING', torch.read_vitals()) - self.assertIn('CUDA.used', torch.read_vitals()) - - def test_dataloader_vitals(self): - with torch_vital_set('ON'): - inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) - tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) - dataset = torch.utils.data.TensorDataset(inps, tgts) - torch.utils.data.DataLoader(dataset, batch_size=2) - self.assertIn('Dataloader.enabled\t\t True', torch.read_vitals()) - -# FIXME: document or deprecate whatever this is -class TestVitalSignsCuda(TestCase): - @onlyCUDA - def test_cuda_vitals_gpu_only(self, device): - with torch_vital_set('ON'): - self.assertIn('CUDA.used\t\t true', torch.read_vitals()) - is_cuda_sm86 = torch.cuda.is_available() and torch.cuda.get_device_capability(0) == (8, 6) @@ -1861,6 +1814,22 @@ def test_nondeterministic_alert_grid_sample_2d(self, device): 'grid_sampler_2d_backward_cuda', torch.device(device).type == 'cuda') + @unittest.skipIf(not TEST_CUDNN, "CUDNN not available") + @skipIfRocm + @onlyCUDA + @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") + def test_nondeterministic_alert_grid_sample_2d_cudnn(self, device): + def fn(): + input = torch.empty(1, 1, 2, 2, device=device, requires_grad=True) + grid = torch.empty(1, 1, 1, 2, device=device) + with torch.backends.cudnn.flags(enabled=True): + res = torch.nn.functional.grid_sample(input, grid, align_corners=True) + res.backward(torch.ones_like(res)) + + self.check_nondeterministic_alert( + fn, + 'cudnn_grid_sampler_backward') + @skipIfMPS @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") def test_nondeterministic_alert_grid_sample_3d(self, device): @@ -5004,6 +4973,27 @@ def test_storage_all_devices(self, devices, non_blocking): # that they have to materialize in the expected order. @skipXLA @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) + def test_const_data_ptr(self, device, dtype): + t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype) + + # For a regular tensor, const_data_ptr and data_ptr return the same address + self.assertEqual(t.const_data_ptr(), t.data_ptr()) + + clone = t._lazy_clone() + + self.assertTrue(torch._C._is_cow_tensor(t)) + self.assertTrue(torch._C._is_cow_tensor(clone)) + + # const_data_ptr should not trigger COW materialization + addr = clone.const_data_ptr() + self.assertEqual(addr, t.const_data_ptr()) + + self.assertTrue(torch._C._is_cow_tensor(t)) + self.assertTrue(torch._C._is_cow_tensor(clone)) + + # See Note [lazy_clone_ tests with inductor enabled] + @skipXLA + @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) def test_lazy_clone(self, device, dtype): t = torch.tensor([[0, 1], [2, 3]], device=device, dtype=dtype) t_orig_storage_addr = torch._C._storage_address(t) @@ -10593,31 +10583,6 @@ def callback(w): self.assertTrue(called) - def test_storage_thread_safety(self): - import threading - from concurrent.futures import ThreadPoolExecutor - - NUM_ITERS = 10 - NUM_THREADS = 4 - - # Concurrent calls to tensor.untyped_storage() - def access_untyped_storage(tensor, barrier): - barrier.wait() - return weakref.ref(tensor.untyped_storage()) - - for i in range(NUM_ITERS): - tensor = torch.tensor([1.0, 2.0, 3.0]) - barrier = threading.Barrier(NUM_THREADS) - with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: - futures = [ - executor.submit(access_untyped_storage, tensor, barrier) - for _ in range(NUM_THREADS) - ] - - # Check that all the storages returned were the same - for future in futures: - self.assertEqual(future.result()(), tensor.untyped_storage()) - # FIXME: move to test_linalg @torch.inference_mode() def test_bmm_multithreaded(self): @@ -10671,6 +10636,23 @@ def generate_inputs(num_batches): finally: torch.set_num_threads(num_threads) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_bmm_matmul_mixed_dtype_error(self): + a = torch.randn(2, 8, 8, device="cuda", dtype=torch.float16) + b = torch.randn(2, 8, 64, device="cuda", dtype=torch.float32) + + with self.assertRaisesRegex(RuntimeError, "expected scalar type .* but found"): + torch.bmm(a, b) + + with self.assertRaisesRegex(RuntimeError, "expected scalar type .* but found"): + torch.compile(lambda x, y: torch.bmm(x, y), fullgraph=True)(a, b) + + with self.assertRaisesRegex(RuntimeError, "expected scalar type .* but found"): + torch.matmul(a, b) + + with self.assertRaisesRegex(RuntimeError, "expected scalar type .* but found"): + torch.compile(lambda x, y: torch.matmul(x, y), fullgraph=True)(a, b) + def test_conj_neg_tolist(self): x = torch.randn(2, dtype=torch.cfloat) y1 = x.conj() @@ -10993,7 +10975,6 @@ class TestTensorDeviceOps(TestCase): # pytest will fail. add_neg_dim_tests() instantiate_device_type_tests(TestViewOps, globals()) -instantiate_device_type_tests(TestVitalSignsCuda, globals()) instantiate_device_type_tests(TestTensorDeviceOps, globals()) instantiate_device_type_tests(TestTorchDeviceType, globals()) instantiate_device_type_tests(TestDevicePrecision, globals(), except_for='cpu') diff --git a/test/test_torch_config_hash_determinism.py b/test/test_torch_config_hash_determinism.py index 74865ebc89226..0f81ea386b5a6 100644 --- a/test/test_torch_config_hash_determinism.py +++ b/test/test_torch_config_hash_determinism.py @@ -19,6 +19,10 @@ class TestConfigModule(TestCase): + # Config keys that legitimately contain absolute paths, for example, + # /opt/clang-15/bin/clang + KNOWN_PATH_CONFIGS = {"cpp.cxx"} + def check_deterministic(self, key: str, value: object): if isinstance(value, (int, float, bool)) or value is None: return @@ -79,6 +83,8 @@ def test_inductor_config_hash_portable_deterministic(self): torch_config = inductor_config.save_config_portable() for key, value in torch_config.items(): + if key in self.KNOWN_PATH_CONFIGS: + continue self.check_deterministic(key, value) def test_inductor_config_hash_portable_without_ignore(self): @@ -93,6 +99,8 @@ def test_inductor_config_hash_portable_without_ignore(self): f"Detected path in config value '.*', key='{cutlass_dir_key}'", ): for key, value in changed_torch_config.items(): + if key in self.KNOWN_PATH_CONFIGS: + continue self.check_deterministic(key, value) finally: inductor_config._cache_config_ignore_prefix.insert(idx, cutlass_dir_key) diff --git a/test/test_transformers.py b/test/test_transformers.py index da070bb2cb5d4..9d54dfceb3d2d 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -52,9 +52,12 @@ PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, PLATFORM_SUPPORTS_FUSED_ATTENTION, PLATFORM_SUPPORTS_CUDNN_ATTENTION, + PLATFORM_SUPPORTS_CK_SDPA, tf32_on_and_off, tf32_enabled, ) +from torch.testing._internal.common_device_type import skipXPUIf +from torch.testing._internal.common_xpu import PLATFORM_SUPPORTS_FLASH_ATTENTION_XPU if TEST_FAIRSEQ: import fairseq.models.transformer as fairseq_transformer @@ -88,7 +91,6 @@ def use_deterministic_algorithims(mode: bool, warn_only: bool): isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5 isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8 -TEST_WITH_CK = TEST_WITH_ROCM and torch.backends.cuda.preferred_rocm_fa_library() == torch.backends.cuda._ROCmFABackends['ck'] def _check_equal( golden: torch.Tensor, @@ -457,7 +459,7 @@ def hook(module, inputs, output): handle.remove() @skipIfRocmArch(MI300_ARCH) - @tf32_on_and_off(0.001) + @tf32_on_and_off(0.002) @parametrize("use_torchscript", [False]) @parametrize("enable_nested_tensor", [True, False]) @parametrize("use_autocast", [True, False]) @@ -2926,6 +2928,16 @@ def test_cudnn_attention_fail_d128(self, device): with self.assertRaisesRegex(RuntimeError, "No available kernel."): torch.nn.functional.scaled_dot_product_attention(q, k, v) + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system") + @parametrize("shape", [(65536, 1, 4, 8), (1, 65536, 4, 8)]) + def test_cudnn_attention_fail_large_batch_or_num_heads(self, device, shape): + q = torch.randn(shape, device=device, dtype=torch.float16) + k = torch.randn(shape, device=device, dtype=torch.float16) + v = torch.randn(shape, device=device, dtype=torch.float16) + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): + with self.assertRaisesRegex(RuntimeError, "No available kernel."): + torch.nn.functional.scaled_dot_product_attention(q, k, v) + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_trivial_output_transpose(self, device): # see also: https://github.com/pytorch/pytorch/issues/134001 @@ -3143,6 +3155,83 @@ def test_cudnn_attention_broken_166211(self): self.assertFalse(dk.isnan().any()) self.assertFalse(dv.isnan().any()) + @skipIfRocm + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") + def test_cudnn_attention_mask_broken_177842(self): + # https://github.com/pytorch/pytorch/issues/177842 + q = torch.randn(1, 10, 8, 8, dtype=torch.bfloat16, device='cuda') + k = torch.randn(1, 10, 1, 8, dtype=torch.bfloat16, device='cuda') + v = torch.randn(1, 10, 1, 8, dtype=torch.bfloat16, device='cuda') + + attention_mask_custom = torch.zeros(10, 10, dtype=torch.bool).to("cuda") + attention_mask_custom[:7, :7] = torch.tril(torch.ones(7, 7, dtype=torch.bool), diagonal=0) + + with sdpa_kernel(SDPBackend.MATH): + attn_output_math = torch.nn.functional.scaled_dot_product_attention( + query=q.transpose(1, 2), + key=k.transpose(1, 2), + value=v.transpose(1, 2), + attn_mask=attention_mask_custom, + is_causal=False, + enable_gqa=True, + ) + with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): + attn_output_cudnn = torch.nn.functional.scaled_dot_product_attention( + query=q.transpose(1, 2), + key=k.transpose(1, 2), + value=v.transpose(1, 2), + attn_mask=attention_mask_custom, + is_causal=False, + enable_gqa=True, + ) + self.assertEqual(attn_output_math, attn_output_cudnn, atol=5e-3, rtol=3e-3) + + attention_mask_custom = torch.zeros(10, 10, dtype=torch.bool).to("cuda") + attention_mask_custom[:7, :7] = torch.triu(torch.ones(7, 7, dtype=torch.bool), diagonal=0) + + with sdpa_kernel(SDPBackend.MATH): + attn_output_math = torch.nn.functional.scaled_dot_product_attention( + query=q.transpose(1, 2), + key=k.transpose(1, 2), + value=v.transpose(1, 2), + attn_mask=attention_mask_custom, + is_causal=False, + enable_gqa=True, + ) + with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): + attn_output_cudnn = torch.nn.functional.scaled_dot_product_attention( + query=q.transpose(1, 2), + key=k.transpose(1, 2), + value=v.transpose(1, 2), + attn_mask=attention_mask_custom, + is_causal=False, + enable_gqa=True, + ) + self.assertEqual(attn_output_math, attn_output_cudnn, atol=5e-3, rtol=3e-3) + + attention_mask_custom = torch.zeros(10, 10, dtype=torch.bool).to("cuda") + attention_mask_custom[:7, :10] = torch.ones(7, 10, dtype=torch.bool) + + with sdpa_kernel(SDPBackend.MATH): + attn_output_math = torch.nn.functional.scaled_dot_product_attention( + query=q.transpose(1, 2), + key=k.transpose(1, 2), + value=v.transpose(1, 2), + attn_mask=attention_mask_custom, + is_causal=False, + enable_gqa=True, + ) + with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): + attn_output_cudnn = torch.nn.functional.scaled_dot_product_attention( + query=q.transpose(1, 2), + key=k.transpose(1, 2), + value=v.transpose(1, 2), + attn_mask=attention_mask_custom, + is_causal=False, + enable_gqa=True, + ) + self.assertEqual(attn_output_math, attn_output_cudnn, atol=5e-3, rtol=3e-3) + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("mask_dim", [1, 2, 3, 4]) def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]): @@ -3912,10 +4001,12 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, @parametrize("scale", [None, "l1"]) @parametrize("enable_gqa", [True, False]) @parametrize("n_heads", [[16, 8], [10, 2]]) + @parametrize("sdpa_backend", ["aotriton", "ck"] if PLATFORM_SUPPORTS_CK_SDPA else ["aotriton"]) @tf32_enabled() def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, - head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, - scale: str, enable_gqa: bool, n_heads: list[int]): + head_dim: int, is_causal: bool, dropout_p: float, + dtype: torch.dtype, scale: str, enable_gqa: bool, + n_heads: list[int], sdpa_backend: str): if isSM8XDevice or isSM120Device and head_dim in range(193, 256 + 1): self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled") if is_causal and seq_len_q != seq_len_k: @@ -3925,8 +4016,14 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le if max(seq_len_q, seq_len_k) >= 2048 and torch.cuda.get_device_properties('cuda').total_memory < 40 * 2**30: unittest.skip("Reference implementation OOM") return - if TEST_WITH_CK and dropout_p != 0: - self.skipTest("CK does not support tensor format dropout masks") + + # ROCm now supports 2 different backends for SDPA that require different set up. + TEST_WITH_CK = False + if TEST_WITH_ROCM: + torch.backends.cuda.preferred_rocm_fa_library(sdpa_backend) + # When no args are given to preferred_rocm_fa_library, it acts as a getter + TEST_WITH_CK = (torch.backends.cuda.preferred_rocm_fa_library() == torch._C._ROCmFABackend.Ck) + if TEST_WITH_CK and head_dim > 128: self.skipTest("CK does not support head dims over 128") @@ -4000,15 +4097,24 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le softmax_mask = self.convert_flash_attn_S_to_softmax( dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask, causal=is_causal)[:, :, :seq_len_q, :seq_len_k] + + # This is the default implementation for the mask but we need to match CK if we are using it dropout_mask = softmax_mask >= 0 + + # This logic matches how CK calculates the dropout mask. + # This is necessary because CK doesn't support passing in custom dropout masks + # So we use this logic to ensure we are comparing apples to apples. + if TEST_WITH_CK: + dropout_mask = (softmax_mask <= int((1.0 - dropout_p) * 255.0)).to(torch.float32) + # High Precision Math Reference out_ref = torch.ops.aten._scaled_dot_product_attention_math( query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0] # Low Precision Math Reference out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math( - query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, - dropout_mask=dropout_mask, enable_gqa=enable_gqa)[0] + query, key, value, dropout_mask=dropout_mask, dropout_p=dropout_p, + is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)[0] upstream_grad = torch.rand_like(out, requires_grad=False) @@ -4028,6 +4134,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le 'grad_value': 4, } if TEST_WITH_ROCM: + fudge_factors['grad_value'] = 6.0 if TEST_WITH_CK: fudge_factors['out'] = 5.0 @@ -4499,7 +4606,6 @@ class TestSDPAXpuOnly(NNTestCase): Mostly migrate from TestSDPACudaOnly in test/test_transformers.py """ - PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION = torch.xpu.is_available() and torch._C._is_flash_attention_available() @parametrize("type", ["dense"]) @parametrize("dropout", [0.0, 0.7]) @@ -4806,7 +4912,7 @@ def test_onednn_attention_mask_vs_math( self.assertEqual(actual.float(), math_ref, atol=tol.atol, rtol=tol.rtol) - @unittest.skipIf(not PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION, "XPU Flash Attention is not supported") + @skipXPUIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION_XPU, "XPU Flash Attention is not supported") @parametrize("dtype", [torch.float32, torch.float64]) def test_flash_attention_unsupport_dtypes(self, device, dtype): make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) @@ -4824,7 +4930,7 @@ def test_flash_attention_unsupport_dtypes(self, device, dtype): with self.assertRaisesRegex(RuntimeError, "No available kernel"): F.scaled_dot_product_attention(q, k, v) - @unittest.skipIf(not PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION, "XPU Flash Attention is not supported") + @skipXPUIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION_XPU, "XPU Flash Attention is not supported") def test_flash_attention_unsupport_dropout(self, device): dtype = torch.bfloat16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) @@ -4842,7 +4948,7 @@ def test_flash_attention_unsupport_dropout(self, device): with self.assertRaisesRegex(RuntimeError, "No available kernel"): F.scaled_dot_product_attention(q, k, v, dropout_p=0.1) - @unittest.skipIf(not PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION, "XPU Flash Attention is not supported") + @skipXPUIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION_XPU, "XPU Flash Attention is not supported") def test_flash_attention_headdim_size(self, device): dtype = torch.bfloat16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=False) @@ -4866,7 +4972,7 @@ def test_flash_attention_headdim_size(self, device): with self.assertRaisesRegex(RuntimeError, "No available kernel"): F.scaled_dot_product_attention(q, k, v) - @unittest.skipIf(not PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION, "XPU Flash Attention is not supported") + @skipXPUIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION_XPU, "XPU Flash Attention is not supported") def test_flash_attention_fail_with_non_square_causal_attention(self, device): dtype = torch.bfloat16 q_shape = SdpaShape(1, 1, 8, 16) @@ -4880,7 +4986,7 @@ def test_flash_attention_fail_with_non_square_causal_attention(self, device): self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention( q, k, v, None, 0.0, is_causal=True)) - @unittest.skipIf(not PLATFORM_SUPPORTS_XPU_FLASH_ATTENTION, "XPU Flash Attention is not supported") + @skipXPUIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION_XPU, "XPU Flash Attention is not supported") @parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION]) @parametrize("dtype", [torch.half, torch.bfloat16]) @parametrize("batch_size", [1, 2, 4]) @@ -5046,6 +5152,7 @@ def run_test( "shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)], ) + @skipXPUIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION_XPU, "XPU Flash Attention is not supported") def test_causal_variants(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]): make_tensor = partial( torch.rand, device=device, dtype=torch.float16, requires_grad=True @@ -5078,6 +5185,7 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: lis [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)], ) @skipIfTorchDynamo("This function already calls torch.compile.") + @skipXPUIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION_XPU, "XPU Flash Attention is not supported") def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]): cnts = CompileCounterWithBackend("aot_eager") make_tensor = partial( diff --git a/test/test_tsan.py b/test/test_tsan.py new file mode 100644 index 0000000000000..9cb4c51db8a86 --- /dev/null +++ b/test/test_tsan.py @@ -0,0 +1,48 @@ +# Owner(s): ["module: multithreading"] + +import weakref + +import torch +from torch.testing._internal.common_utils import run_concurrently, run_tests, TestCase + + +class TestTSan(TestCase): + def test_storage_thread_safety(self): + # Concurrent calls to tensor.untyped_storage() + def access_untyped_storage(tensor): + return weakref.ref(tensor.untyped_storage()) + + for _ in range(10): + tensor = torch.tensor([1.0, 2.0, 3.0]) + weakrefs = run_concurrently( + access_untyped_storage, args=(tensor,), num_threads=4 + ) + for wr in weakrefs: + self.assertEqual(wr(), tensor.untyped_storage()) + + def test_concurrent_profiling(self): + """Repeatedly start/stop profiling while background threads are active. + + On free-threaded Python (3.14t+), this exercises concurrent access to + the profiler's per-thread state without GIL protection. Without the + thread-safety fixes (setprofileAllThreads, per-thread ValueCache, + StopTheWorldGuard), this crashes from heap corruption due to data + races on the shared hash maps. + """ + + def work(): + for _ in range(100): + d = {str(i): list(range(i % 10)) for i in range(20)} + _ = sorted(d.items(), key=lambda x: len(x[1])) + torch.ones(10) + torch.zeros(10) + + def profile_work(): + for _ in range(30): + with torch.profiler.profile(with_stack=True, with_modules=True): + torch.ones(10) + + run_concurrently([profile_work] + [work] * 8) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_typing.py b/test/test_typing.py index b2129b2f7867e..077fbb076f4c8 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -52,7 +52,7 @@ def _strip_filename(msg: str) -> str: def _run_mypy() -> dict[str, list[str]]: """Clears the cache and run mypy before running any of the typing tests.""" if os.path.isdir(CACHE_DIR): - shutil.rmtree(CACHE_DIR) + shutil.rmtree(CACHE_DIR, ignore_errors=True) rc: dict[str, list[str]] = {} for directory in (REVEAL_DIR, PASS_DIR, FAIL_DIR): @@ -186,7 +186,7 @@ def test_success(self, path) -> None: name_fn=lambda b: os.path.relpath(b, start=FAIL_DIR), ) def test_fail(self, path): - __tracebackhide__ = True # noqa: F841 + __tracebackhide__ = True with open(path) as fin: lines = fin.readlines() @@ -225,7 +225,7 @@ def test_fail(self, path): name_fn=lambda b: os.path.relpath(b, start=REVEAL_DIR), ) def test_reveal(self, path): - __tracebackhide__ = True # noqa: F841 + __tracebackhide__ = True with open(path) as fin: lines = _parse_reveals(fin) diff --git a/test/test_unary_ufuncs.py b/test/test_unary_ufuncs.py index 4d440ad51bcc3..6fa4bd4705701 100644 --- a/test/test_unary_ufuncs.py +++ b/test/test_unary_ufuncs.py @@ -1911,6 +1911,33 @@ def test_mvlgamma_integer_promotion(self, device, dtype): self.assertTrue(result.dtype.is_floating_point) self.assertTrue(torch.all(torch.isfinite(result))) + @onlyCUDA + @dtypes(torch.float32, torch.float16, torch.bfloat16) + def test_fp8_e4m3fn_conversion_subnormals(self, device, dtype): + # Regression test for ptxas codegen bug on sm_100 where FADD in the + # subnormal conversion path gets wrong source register for odd elements + # in the 8-wide unrolled vectorized_elementwise_kernel. + # e4m3fn subnormals: |x| < 2^-6 + torch.manual_seed(0) + N = 2**20 + x = (torch.randn(N, dtype=dtype, device=device) * 1e-3).clamp(-448, 448) + y = x.to(torch.float8_e4m3fn) + ref = x.cpu().float().to(torch.float8_e4m3fn) + self.assertEqual(y.cpu().view(torch.uint8), ref.view(torch.uint8)) + + @onlyCUDA + @dtypes(torch.float32, torch.float16, torch.bfloat16) + def test_fp8_e5m2_conversion_subnormals(self, device, dtype): + # Same regression test for e5m2. + # e5m2 subnormals: |x| < 2^-14 + torch.manual_seed(0) + N = 2**20 + x = (torch.randn(N, dtype=dtype, device=device) * 1e-4).clamp(-57344, 57344) + y = x.to(torch.float8_e5m2) + ref = x.cpu().float().to(torch.float8_e5m2) + self.assertEqual(y.cpu().view(torch.uint8), ref.view(torch.uint8)) + + instantiate_device_type_tests(TestUnaryUfuncs, globals()) if __name__ == "__main__": diff --git a/test/test_utils_config_module.py b/test/test_utils_config_module.py index 8c2f1b0eace7b..97ff0effa450b 100644 --- a/test/test_utils_config_module.py +++ b/test/test_utils_config_module.py @@ -1,6 +1,8 @@ # Owner(s): ["module: unknown"] import os import pickle +import queue +import threading import warnings from unittest.mock import patch @@ -23,7 +25,7 @@ class TestConfigModule(TestCase): def tearDown(self): # Config changes get persisted between test cases for k in config._config: - config._config[k].user_override = _UNSET_SENTINEL + config._config[k].user_override.set(_UNSET_SENTINEL) config._hash_digest = None # Reset deprecation warning flags for k in config._config: @@ -86,7 +88,7 @@ def test_none_override_semantics(self): config.e_bool = None self.assertIsNone(config.e_bool) for k in config._config: - config._config[k].user_override = _UNSET_SENTINEL + config._config[k].user_override.set(_UNSET_SENTINEL) def test_reference_semantics(self): config.e_list.append(2) @@ -157,6 +159,42 @@ def test_save_config(self): self.assertTrue(config.e_bool) self.assertFalse(config.e_ignored) + def test_save_config_with_patch(self): + self.assertTrue(config.e_bool) + with config.patch(e_bool=False): + p = config.save_config() + self.assertDictEqual( + pickle.loads(p), + { + "_cache_config_ignore_prefix": ["magic_cache_config"], + "e_bool": False, + "e_dict": {1: 2}, + "e_float": 1.0, + "e_int": 1, + "e_list": [1], + "e_none": None, + "e_set": {1}, + "e_string": "string", + "e_tuple": (1,), + "nested.e_bool": True, + "_e_ignored": True, + "e_compile_ignored": True, + "magic_cache_config_ignored": True, + "_save_config_ignore": ["e_ignored"], + "e_config": True, + "e_jk": True, + "e_jk_false": False, + "e_env_default": True, + "e_env_default_FALSE": False, + "e_env_default_str": "1234", + "e_env_default_str_empty": "", + "e_env_force": True, + "e_optional": True, + "e_deprecated": True, + "e_not_deprecated": False, + }, + ) + def test_save_config_portable(self): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=FutureWarning) @@ -389,6 +427,65 @@ def test_patch(self): with config.patch("does_not_exist"): pass + def _assert_patch_reentrant_across_threads(self, fn): + barrier = threading.Barrier(2, timeout=5) + errors: queue.SimpleQueue[str] = queue.SimpleQueue() + + def worker(): + try: + fn(barrier) + except Exception as e: + errors.put(repr(e)) + + threads = [threading.Thread(target=worker) for _ in range(2)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + error_messages = [] + while True: + try: + error_messages.append(errors.get_nowait()) + except queue.Empty: + break + + self.assertFalse( + error_messages, f"concurrent patch usage failed: {error_messages}" + ) + self.assertTrue(config.e_bool) + + def test_patch_context_manager_is_reentrant_across_threads(self): + patcher = config.patch(e_bool=False) + + def fn(barrier): + with patcher: + self.assertFalse(config.e_bool) + barrier.wait() + self.assertFalse(config.e_bool) + + self._assert_patch_reentrant_across_threads(fn) + + def test_patch_context_manager_is_reentrant_when_nested(self): + patcher = config.patch(e_bool=False) + + with patcher: + self.assertFalse(config.e_bool) + with patcher: + self.assertFalse(config.e_bool) + self.assertFalse(config.e_bool) + + self.assertTrue(config.e_bool) + + def test_patch_decorator_is_reentrant_across_threads(self): + @config.patch(e_bool=False) + def fn(barrier): + self.assertFalse(config.e_bool) + barrier.wait() + self.assertFalse(config.e_bool) + + self._assert_patch_reentrant_across_threads(fn) + def test_make_closur_patcher(self): revert = config._make_closure_patcher(e_bool=False)() self.assertFalse(config.e_bool) @@ -408,7 +505,7 @@ def test_bad_jk_type(self): AssertionError, msg="AssertionError: justknobs only support booleans, thisisnotvalid is not a boolean", ): - _ConfigEntry(Config(default="bad", justknob="fake_knob")) + _ConfigEntry(Config(default="bad", justknob="fake_knob"), "test") def test_alias(self): self.assertFalse(config2.e_aliasing_bool) @@ -428,13 +525,15 @@ def test_reference_is_default(self): def test_invalid_config_int(self): with self.assertRaises(AssertionError): _ConfigEntry( - Config(default=2, env_name_default="FAKE_DISABLE", value_type=int) + Config(default=2, env_name_default="FAKE_DISABLE", value_type=int), + "test", ) def test_invalid_config_float(self): with self.assertRaises(AssertionError): _ConfigEntry( - Config(default=2, env_name_force="FAKE_DISABLE", value_type=float) + Config(default=2, env_name_force="FAKE_DISABLE", value_type=float), + "test", ) def test_deprecated_config(self): @@ -522,6 +621,49 @@ def fn(x): f"Unexpected config deprecation warnings: {[str(x.message) for x in deprecation_warnings]}", ) + def test_patch_then_global(self): + self.assertTrue(config.e_bool) + with config.patch(e_bool=False): + self.assertFalse(config.e_bool) + + config.e_bool = False + self.assertFalse(config.e_bool) + + def test_is_default_patch(self): + self.assertTrue(config.e_bool) + with config.patch(e_bool=False): + self.assertFalse(config._is_default("e_bool")) + + def test_dict_patch(self): + self.assertTrue(config.e_bool) + with config.patch(e_bool=False): + d = config._get_dict() + self.assertFalse(d["e_bool"]) + + def test_set_in_patch(self): + self.assertEqual(config.e_int, 1) + with config.patch(e_int=2): + self.assertEqual(config.e_int, 2) + config.e_int = 3 + self.assertEqual(config.e_int, 3) + # Exiting the patch resets e_int to 1, losing the fact that we explicitly set it to 4 - see ConfigModule._do_setattr + self.assertEqual(config.e_int, 1) + + def test_nested_patch(self): + self.assertEqual(config.e_int, 1) + self.assertEqual(config.e_string, "string") + with config.patch(e_int=2, e_string="inner"): + self.assertEqual(config.e_int, 2) + self.assertEqual(config.e_string, "inner") + with config.patch(e_int=3): + self.assertEqual(config.e_int, 3) + self.assertEqual(config.e_string, "inner") + config.e_int = 4 + self.assertEqual(config.e_int, 4) + self.assertEqual(config.e_int, 2) + self.assertEqual(config.e_int, 1) + self.assertEqual(config.e_string, "string") + if __name__ == "__main__": run_tests() diff --git a/test/test_varlen_attention.py b/test/test_varlen_attention.py index 97ea94ba26a34..325001c1b8823 100644 --- a/test/test_varlen_attention.py +++ b/test/test_varlen_attention.py @@ -1,7 +1,7 @@ # Owner(s): ["module: sdpa"] import unittest from collections import namedtuple -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext import torch import torch.nn as nn @@ -11,10 +11,22 @@ restore_flash_attention_impl, ) from torch.nn.attention.varlen import varlen_attn, varlen_attn_out -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION +from torch.testing._internal.common_cuda import ( + IS_SM90, + PLATFORM_SUPPORTS_CK_SDPA, + PLATFORM_SUPPORTS_FLASH_ATTENTION, + SM100OrLater, + SM120OrLater, + SM90OrLater, +) from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_nn import NNTestCase -from torch.testing._internal.common_utils import parametrize, run_tests, TEST_WITH_ROCM +from torch.testing._internal.common_utils import ( + decorateIf, + parametrize, + run_tests, + TEST_WITH_ROCM, +) from torch.utils._python_dispatch import TorchDispatchMode @@ -32,8 +44,33 @@ def use_fa3(): restore_flash_attention_impl() +@contextmanager +def use_fa4(): + try: + activate_flash_attention_impl("FA4") + except (ModuleNotFoundError, RuntimeError) as err: + raise unittest.SkipTest("FA4 backend not available") from err + try: + yield + finally: + restore_flash_attention_impl() + + +def _use_backend(backend): + return {"fa2": nullcontext, "fa3": use_fa3, "fa4": use_fa4}[backend]() + + +def _varlen_backends(*, include_fa4_paged_kv: bool) -> list[str]: + fa4_supported = ( + SM100OrLater if include_fa4_paged_kv else SM90OrLater + ) and not SM120OrLater + return ["fa2"] + (["fa3"] if IS_SM90 else []) + (["fa4"] if fa4_supported else []) + + VarlenShape = namedtuple( - "VarlenShape", ["batch_size", "max_seq_len", "embed_dim", "num_heads"] + "VarlenShape", + ["batch_size", "max_seq_len", "embed_dim", "num_heads", "num_kv_heads"], + defaults=[None], ) @@ -51,31 +88,48 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): class AttentionBlock(nn.Module): def __init__( - self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype + self, + embed_dim: int, + num_heads: int, + device: torch.device, + dtype: torch.dtype, + num_kv_heads: int | None = None, ): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads self.head_dim = embed_dim // num_heads - self.qkv_proj = nn.Linear( - embed_dim, 3 * embed_dim, bias=False, device=device, dtype=dtype + self.q_proj = nn.Linear( + embed_dim, + num_heads * self.head_dim, + bias=False, + device=device, + dtype=dtype, + ) + self.kv_proj = nn.Linear( + embed_dim, + 2 * self.num_kv_heads * self.head_dim, + bias=False, + device=device, + dtype=dtype, ) self.out_proj = nn.Linear( embed_dim, embed_dim, bias=False, device=device, dtype=dtype ) + @property + def enable_gqa(self): + return self.num_kv_heads != self.num_heads + def get_varlen_qkv( self, x_packed: torch.Tensor, ): - qkv = self.qkv_proj(x_packed) - q, k, v = qkv.chunk(3, dim=-1) - - q = q.view(-1, self.num_heads, self.head_dim) - k = k.view(-1, self.num_heads, self.head_dim) - v = v.view(-1, self.num_heads, self.head_dim) - + q = self.q_proj(x_packed).view(-1, self.num_heads, self.head_dim) + kv = self.kv_proj(x_packed).view(-1, 2, self.num_kv_heads, self.head_dim) + k, v = kv[:, 0], kv[:, 1] return q, k, v def forward_varlen( @@ -98,6 +152,7 @@ def forward_varlen( max_len, scale=scale, window_size=window_size, + enable_gqa=self.enable_gqa, ) attn_out = attn_out.view(-1, self.embed_dim) @@ -112,8 +167,11 @@ def forward_sdpa( ): batch_size, seq_len, _ = x_padded.shape - qkv = self.qkv_proj(x_padded) - q, k, v = qkv.chunk(3, dim=-1) + q = self.q_proj(x_padded) + kv = self.kv_proj(x_padded) + k, v = kv.view(batch_size, seq_len, 2, self.num_kv_heads, self.head_dim).unbind( + dim=2 + ) padding_mask = ( torch.arange(seq_len, device=x_padded.device)[None, :] @@ -144,14 +202,34 @@ def forward_sdpa( attn_mask = attn_mask & window_mask[None, None, :, :] q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) - - # Don't pass is_causal since we already incorporated it into attn_mask - attn_out = F.scaled_dot_product_attention( - q, k, v, attn_mask=attn_mask, scale=scale + k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose( + 1, 2 + ) + v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose( + 1, 2 ) + # Don't pass is_causal since we already incorporated it into attn_mask. + if self.enable_gqa: + # Force math backend for GQA + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + attn_out = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + scale=scale, + enable_gqa=True, + ) + else: + attn_out = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + scale=scale, + ) + attn_out = ( attn_out.transpose(1, 2) .contiguous() @@ -215,7 +293,18 @@ class TestVarlenAttention(NNTestCase): not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" ) @parametrize("dtype", [torch.bfloat16, torch.float16]) - def test_basic_functionality(self, device, dtype): + @parametrize( + "sdpa_backend", + ["aotriton", "ck"] if PLATFORM_SUPPORTS_CK_SDPA else ["aotriton"], + ) + @parametrize( + "backend", + _varlen_backends(include_fa4_paged_kv=False), + ) + def test_basic_functionality(self, device, dtype, backend, sdpa_backend=None): + if TEST_WITH_ROCM: + torch.backends.cuda.preferred_rocm_fa_library(sdpa_backend) + torch.manual_seed(42) shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16) @@ -236,45 +325,59 @@ def test_basic_functionality(self, device, dtype): [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32 ) - output = attention_block.forward_varlen(x_packed, cu_seq, shape.max_seq_len) + with _use_backend(backend): + output = attention_block.forward_varlen(x_packed, cu_seq, shape.max_seq_len) - self.assertEqual(output.shape, (total_tokens, shape.embed_dim)) - self.assertEqual(output.device, torch.device(device)) - self.assertEqual(output.dtype, dtype) + self.assertEqual(output.shape, (total_tokens, shape.embed_dim)) + self.assertEqual(output.device, torch.device(device)) + self.assertEqual(output.dtype, dtype) - # varlen_attn_out should produce the same result and write into the buffer - with torch.no_grad(): - q, k, v = attention_block.get_varlen_qkv(x_packed) - expected = varlen_attn( - q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len - ) - out_buf = torch.empty_like(expected) - actual = varlen_attn_out( - out_buf, q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len - ) - self.assertEqual(actual.data_ptr(), out_buf.data_ptr()) - self.assertEqual(out_buf, expected) - - varlen_grad_out = torch.ones_like(output) - - varlen_grad = torch.autograd.grad( - outputs=output, - inputs=x_packed, - grad_outputs=varlen_grad_out, - retain_graph=True, - create_graph=False, - allow_unused=False, - )[0] + # varlen_attn_out should produce the same result and write into the buffer + with torch.no_grad(): + q, k, v = attention_block.get_varlen_qkv(x_packed) + expected = varlen_attn( + q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len + ) + out_buf = torch.empty_like(expected) + actual = varlen_attn_out( + out_buf, + q, + k, + v, + cu_seq, + cu_seq, + shape.max_seq_len, + shape.max_seq_len, + ) + self.assertEqual(actual.data_ptr(), out_buf.data_ptr()) + self.assertEqual(out_buf, expected) - self.assertIsNotNone(varlen_grad) - self.assertEqual(varlen_grad.shape, x_packed.shape) - self.assertEqual(varlen_grad.dtype, x_packed.dtype) + varlen_grad_out = torch.ones_like(output) + + varlen_grad = torch.autograd.grad( + outputs=output, + inputs=x_packed, + grad_outputs=varlen_grad_out, + retain_graph=True, + create_graph=False, + allow_unused=False, + )[0] + + self.assertIsNotNone(varlen_grad) + self.assertEqual(varlen_grad.shape, x_packed.shape) + self.assertEqual(varlen_grad.dtype, x_packed.dtype) @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" ) + @parametrize( + "sdpa_backend", + ["aotriton", "ck"] if PLATFORM_SUPPORTS_CK_SDPA else ["aotriton"], + ) @parametrize("dtype", [torch.bfloat16, torch.float16]) - def test_custom_op_compliance(self, device, dtype): + def test_custom_op_compliance(self, device, dtype, sdpa_backend=None): + if TEST_WITH_ROCM: + torch.backends.cuda.preferred_rocm_fa_library(sdpa_backend) torch.manual_seed(42) shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16) @@ -348,8 +451,14 @@ def test_custom_op_compliance(self, device, dtype): @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" ) + @parametrize( + "sdpa_backend", + ["aotriton", "ck"] if PLATFORM_SUPPORTS_CK_SDPA else ["aotriton"], + ) @parametrize("dtype", [torch.bfloat16, torch.float16]) - def test_custom_op_registration(self, device, dtype): + def test_custom_op_registration(self, device, dtype, sdpa_backend=None): + if TEST_WITH_ROCM: + torch.backends.cuda.preferred_rocm_fa_library(sdpa_backend) torch.manual_seed(42) shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16) @@ -371,7 +480,7 @@ def test_custom_op_registration(self, device, dtype): ) compiled_forward = torch.compile( - attention_block.forward_varlen, backend="eager", fullgraph=True + attention_block.forward_varlen, backend="eager" ) with OpLoggingMode() as mode: output = compiled_forward(x_packed, cu_seq, shape.max_seq_len) @@ -402,7 +511,7 @@ def run_varlen_out(q, k, v, cu_seq, max_len): varlen_attn_out(out_buf, q, k, v, cu_seq, cu_seq, max_len, max_len) return out_buf - compiled_out = torch.compile(run_varlen_out, backend="eager", fullgraph=True) + compiled_out = torch.compile(run_varlen_out, backend="eager") with OpLoggingMode() as out_mode: compiled_out(q, k, v, cu_seq, shape.max_seq_len) @@ -412,6 +521,10 @@ def run_varlen_out(q, k, v, cu_seq, max_len): @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" ) + @parametrize( + "sdpa_backend", + ["aotriton", "ck"] if PLATFORM_SUPPORTS_CK_SDPA else ["aotriton"], + ) @parametrize("dtype", [torch.bfloat16, torch.float16]) @parametrize("scale", [None, 0.1]) @parametrize( @@ -430,11 +543,27 @@ def run_varlen_out(q, k, v, cu_seq, max_len): (1025, 1025), ], ) - def test_varlen_vs_sdpa(self, device, dtype, scale, window_size): + @parametrize( + "backend", + _varlen_backends(include_fa4_paged_kv=False), + ) + @parametrize("enable_gqa", [False, True]) + def test_varlen_vs_sdpa( + self, device, dtype, scale, window_size, backend, enable_gqa, sdpa_backend=None + ): + if TEST_WITH_ROCM: + torch.backends.cuda.preferred_rocm_fa_library(sdpa_backend) + torch.manual_seed(42) + num_heads = 16 + num_kv_heads = 4 if enable_gqa else num_heads shape = VarlenShape( - batch_size=4, max_seq_len=1024, embed_dim=1024, num_heads=16 + batch_size=4, + max_seq_len=1024, + embed_dim=1024, + num_heads=num_heads, + num_kv_heads=num_kv_heads, ) batch_data = create_variable_length_batch(shape, device, dtype) @@ -446,26 +575,38 @@ def test_varlen_vs_sdpa(self, device, dtype, scale, window_size): x_padded_ref = batch_data["x_padded_ref"] golden_attention_block = AttentionBlock( - shape.embed_dim, shape.num_heads, device, torch.float32 + shape.embed_dim, + shape.num_heads, + device, + torch.float32, + num_kv_heads=num_kv_heads, ) attention_block = AttentionBlock( - shape.embed_dim, shape.num_heads, device, dtype + shape.embed_dim, + shape.num_heads, + device, + dtype, + num_kv_heads=num_kv_heads, ) with torch.no_grad(): - attention_block.qkv_proj.weight.copy_( - golden_attention_block.qkv_proj.weight.to(dtype) + attention_block.q_proj.weight.copy_( + golden_attention_block.q_proj.weight.to(dtype) + ) + attention_block.kv_proj.weight.copy_( + golden_attention_block.kv_proj.weight.to(dtype) ) attention_block.out_proj.weight.copy_( golden_attention_block.out_proj.weight.to(dtype) ) - varlen_output = attention_block.forward_varlen( - x_packed, - cu_seq, - max_len, - scale=scale, - window_size=window_size, - ) + with _use_backend(backend): + varlen_output = attention_block.forward_varlen( + x_packed, + cu_seq, + max_len, + scale=scale, + window_size=window_size, + ) sdpa_output = attention_block.forward_sdpa( x_padded, seq_lengths, @@ -499,29 +640,30 @@ def test_varlen_vs_sdpa(self, device, dtype, scale, window_size): start_idx = end_idx - grad_out = torch.randn_like(varlen_output) - sdpa_grad_out = torch.zeros_like(sdpa_output) - golden_sdpa_grad_out = torch.zeros( - shape.batch_size, - max_len, - shape.embed_dim, - device=device, - dtype=torch.float32, - ) - start_idx = 0 - for i, seq_len in enumerate(seq_lengths): - end_idx = start_idx + seq_len - sdpa_grad_out[i, :seq_len] = grad_out[start_idx:end_idx] - golden_sdpa_grad_out[i, :seq_len] = grad_out[start_idx:end_idx].to( - torch.float32 + with _use_backend(backend): + grad_out = torch.randn_like(varlen_output) + sdpa_grad_out = torch.zeros_like(sdpa_output) + golden_sdpa_grad_out = torch.zeros( + shape.batch_size, + max_len, + shape.embed_dim, + device=device, + dtype=torch.float32, ) - start_idx = end_idx + start_idx = 0 + for i, seq_len in enumerate(seq_lengths): + end_idx = start_idx + seq_len + sdpa_grad_out[i, :seq_len] = grad_out[start_idx:end_idx] + golden_sdpa_grad_out[i, :seq_len] = grad_out[start_idx:end_idx].to( + torch.float32 + ) + start_idx = end_idx - varlen_grad = torch.autograd.grad( - outputs=varlen_output, - inputs=x_packed, - grad_outputs=grad_out, - )[0] + varlen_grad = torch.autograd.grad( + outputs=varlen_output, + inputs=x_packed, + grad_outputs=grad_out, + )[0] sdpa_grad = torch.autograd.grad( outputs=sdpa_output, @@ -573,12 +715,17 @@ def test_varlen_vs_sdpa(self, device, dtype, scale, window_size): (384, 0), # edge case ], ) + @parametrize( + "backend", + ["fa2"] + (["fa3"] if IS_SM90 else []) + (["fa4"] if SM100OrLater else []), + ) def test_batch_invariance( - self, device, dtype, num_splits, window_size, sdpa_backend=None + self, device, dtype, num_splits, window_size, backend, sdpa_backend=None ): if TEST_WITH_ROCM: + if num_splits is not None: + self.skipTest("num_splits is not supported on ROCm") torch.backends.cuda.preferred_rocm_fa_library(sdpa_backend) - torch.manual_seed(42) num_heads, head_dim = 2, 128 @@ -618,7 +765,8 @@ def test_batch_invariance( all_k = torch.cat([target_k, extra_k], dim=0) all_v = torch.cat([target_v, extra_v], dim=0) - with use_fa3(), torch.no_grad(): + # fa4 is batch invariant (num_splits=1) by default + with _use_backend(backend), torch.no_grad(): solo_output = varlen_attn( target_q, target_k, @@ -670,19 +818,28 @@ def test_batch_invariance( window_size=window_size, num_splits=num_splits, ) - if num_splits == 1: self.assertEqual(solo_output, batched_output[:target_seq_len]) self.assertEqual(solo_out_buf, batched_out_buf[:target_seq_len]) self.assertEqual(solo_output, solo_out_buf) else: - self.assertNotEqual(solo_output, batched_output[:target_seq_len]) - self.assertNotEqual(solo_out_buf, batched_out_buf[:target_seq_len]) + if backend == "fa3": + self.assertNotEqual(solo_output, batched_output[:target_seq_len]) + self.assertNotEqual(solo_out_buf, batched_out_buf[:target_seq_len]) @unittest.skipIf( not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" ) @unittest.skipIf(TEST_WITH_ROCM, "ROCm does not support seqused_k") + @decorateIf( + unittest.expectedFailure, + lambda params: params["backend"] != "fa2" + and any(kv_len < 128 for kv_len in params["actual_kv_lens"]), + ) + @parametrize( + "sdpa_backend", + ["aotriton", "ck"] if PLATFORM_SUPPORTS_CK_SDPA else ["aotriton"], + ) @parametrize("dtype", [torch.bfloat16, torch.float16]) @parametrize( "actual_kv_lens", @@ -694,7 +851,13 @@ def test_batch_invariance( [127, 63, 33, 17], ], ) - def test_seqused_k_kv_cache(self, device, dtype, actual_kv_lens): + @parametrize("backend", _varlen_backends(include_fa4_paged_kv=False)) + def test_seqused_k_kv_cache( + self, device, dtype, actual_kv_lens, backend, sdpa_backend=None + ): + if TEST_WITH_ROCM: + torch.backends.cuda.preferred_rocm_fa_library(sdpa_backend) + torch.manual_seed(42) batch_size = 4 @@ -748,7 +911,7 @@ def test_seqused_k_kv_cache(self, device, dtype, actual_kv_lens): ) seqused_k = torch.tensor(actual_kv_lens, device=device, dtype=torch.int32) - with torch.no_grad(): + with _use_backend(backend), torch.no_grad(): output_cached = varlen_attn( q_packed, k_cache_packed, @@ -763,7 +926,7 @@ def test_seqused_k_kv_cache(self, device, dtype, actual_kv_lens): k_real_packed, cu_seq_k_real, max_k_real = pack_sequences(k_seqs, device) v_real_packed = torch.cat(v_seqs, dim=0) - with torch.no_grad(): + with _use_backend(backend), torch.no_grad(): output_reference = varlen_attn( q_packed, k_real_packed, @@ -778,7 +941,7 @@ def test_seqused_k_kv_cache(self, device, dtype, actual_kv_lens): self.assertEqual(output_cached, output_reference) # varlen_attn_out with seqused_k should match - with torch.no_grad(): + with _use_backend(backend), torch.no_grad(): out_buf = torch.empty_like(q_packed) output_out = varlen_attn_out( out_buf, @@ -811,9 +974,16 @@ def test_seqused_k_kv_cache(self, device, dtype, actual_kv_lens): [127, 63, 33, 17], ], ) + @parametrize( + "backend", + _varlen_backends(include_fa4_paged_kv=True), + ) def test_block_table_kv_cache( - self, device, dtype, page_size, compile, actual_kv_lens + self, device, dtype, page_size, compile, actual_kv_lens, backend ): + if backend == "fa2" and page_size % 256 != 0: + self.skipTest("FA2 paged KV requires page_size divisible by 256") + torch.manual_seed(42) batch_size = 4 @@ -860,7 +1030,8 @@ def test_block_table_kv_cache( attn_fn = torch.compile(varlen_attn, fullgraph=True) if compile else varlen_attn - with torch.no_grad(): + # Reference: no block_table + with _use_backend(backend), torch.no_grad(): output_reference = varlen_attn( q_packed, k_real_packed, @@ -879,41 +1050,26 @@ def test_block_table_kv_cache( dtype=torch.int32, ) - # FA2 path: paged KV with block_table (page_size % 256 == 0) - if page_size % 256 == 0: - with torch.no_grad(): - output_fa2 = attn_fn( - q_packed, - k_pages, - v_pages, - cu_seq_q, - cu_seq_k, - max_q, - cache_size, - seqused_k=seqused_k, - block_table=block_table, - ) - - self.assertEqual(output_fa2, output_reference) + # FA2 requires cu_seq_k for paged KV; FA3/FA4 pass None + cu_seq_k_paged = cu_seq_k if backend == "fa2" else None - # FA3 path: paged KV with block_table - with use_fa3(), torch.no_grad(): - output_fa3 = attn_fn( + with _use_backend(backend), torch.no_grad(): + output_paged = attn_fn( q_packed, k_pages, v_pages, cu_seq_q, - None, + cu_seq_k_paged, max_q, cache_size, seqused_k=seqused_k, block_table=block_table, ) - self.assertEqual(output_fa3, output_reference) + self.assertEqual(output_paged, output_reference) # varlen_attn_out with paged KV cache should match - with use_fa3(), torch.no_grad(): + with _use_backend(backend), torch.no_grad(): out_buf = torch.empty_like(q_packed) output_out = varlen_attn_out( out_buf, @@ -921,21 +1077,21 @@ def test_block_table_kv_cache( k_pages, v_pages, cu_seq_q, - None, + cu_seq_k_paged, max_q, cache_size, seqused_k=seqused_k, block_table=block_table, ) self.assertEqual(output_out.data_ptr(), out_buf.data_ptr()) - self.assertEqual(out_buf, output_fa3) + self.assertEqual(out_buf, output_paged) - # compile the lower level aten op, will cause graph break - if compile: + # compile the lower level aten op (FA3 only, will cause graph break) + if compile and backend != "fa2": compiled_aten_op = torch.compile( torch.ops.aten._flash_attention_forward_no_dropout_inplace ) - with use_fa3(), torch.no_grad(): + with _use_backend(backend), torch.no_grad(): out_buf = torch.empty_like(q_packed) compiled_aten_op( out_buf, @@ -954,6 +1110,81 @@ def test_block_table_kv_cache( ) self.assertEqual(out_buf, output_reference) + # With num_splits=1, paged and contiguous must be bit-identical + if backend == "fa2": + with _use_backend(backend), torch.no_grad(): + ref_num_splits = varlen_attn( + q_packed, + k_real_packed, + v_real_packed, + cu_seq_q, + cu_seq_k_real, + max_q, + max_k_real, + num_splits=1, + ) + paged_num_splits = varlen_attn( + q_packed, + k_pages, + v_pages, + cu_seq_q, + cu_seq_k_paged, + max_q, + cache_size, + seqused_k=seqused_k, + block_table=block_table, + num_splits=1, + ) + self.assertTrue(torch.equal(paged_num_splits, ref_num_splits)) + + @unittest.skipIf( + not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" + ) + @parametrize("dtype", [torch.bfloat16, torch.float16]) + @parametrize( + "backend", + ["fa2"] + (["fa3"] if IS_SM90 else []) + (["fa4"] if SM100OrLater else []), + ) + def test_enable_gqa(self, device, dtype, backend): + torch.manual_seed(42) + + head_dim = 64 + seq_len = 512 + num_heads_q, num_heads_k = 16, 4 + total_tokens = 2 * seq_len + + q = torch.randn(total_tokens, num_heads_q, head_dim, device=device, dtype=dtype) + k = torch.randn(total_tokens, num_heads_k, head_dim, device=device, dtype=dtype) + v = torch.randn(total_tokens, num_heads_k, head_dim, device=device, dtype=dtype) + cu_seq = torch.tensor( + [0, seq_len, total_tokens], device=device, dtype=torch.int32 + ) + + with self.assertRaisesRegex(ValueError, "enable_gqa=True"): + varlen_attn(q, k, v, cu_seq, cu_seq, seq_len, seq_len) + + with self.assertRaisesRegex(ValueError, "enable_gqa=True"): + varlen_attn_out( + torch.empty_like(q), q, k, v, cu_seq, cu_seq, seq_len, seq_len + ) + + k_bad = torch.randn(total_tokens, 3, head_dim, device=device, dtype=dtype) + v_bad = torch.randn(total_tokens, 3, head_dim, device=device, dtype=dtype) + with self.assertRaisesRegex(ValueError, "multiple of kv heads"): + varlen_attn( + q, k_bad, v_bad, cu_seq, cu_seq, seq_len, seq_len, enable_gqa=True + ) + + with _use_backend(backend), torch.no_grad(): + out = varlen_attn( + q, k, v, cu_seq, cu_seq, seq_len, seq_len, enable_gqa=True + ) + out_buf = torch.empty_like(q) + varlen_attn_out( + out_buf, q, k, v, cu_seq, cu_seq, seq_len, seq_len, enable_gqa=True + ) + self.assertEqual(out_buf, out) + device_types = ("cuda",) diff --git a/test/test_weak.py b/test/test_weak.py index 28fa1436b5c23..061b85a28f94b 100644 --- a/test/test_weak.py +++ b/test/test_weak.py @@ -175,11 +175,11 @@ def check_threaded_weak_dict_copy(self, type_, deepcopy): # Cannot give these slots as weakrefs weren't supported # on these objects until later versions of Python - class DummyKey: # noqa: B903 + class DummyKey: def __init__(self, ctr): self.ctr = ctr - class DummyValue: # noqa: B903 + class DummyValue: def __init__(self, ctr): self.ctr = ctr diff --git a/test/test_xpu.py b/test/test_xpu.py index b63b5defca026..f47c50b4f9be2 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -6,6 +6,7 @@ import ctypes import gc import json +import os import random import re import subprocess @@ -17,12 +18,14 @@ import warnings from copy import deepcopy from itertools import product +from unittest.mock import patch import torch import torch.xpu._gpu_trace as gpu_trace from torch.testing import make_tensor from torch.testing._internal.autocast_test_lists import AutocastTestLists, TestAutocast from torch.testing._internal.common_device_type import ( + dtypes, instantiate_device_type_tests, OpDTypes, ops, @@ -126,6 +129,8 @@ def test_get_device_properties(self): self.assertTrue(device_capability["max_work_group_size"] > 0) self.assertTrue(device_capability["max_num_sub_groups"] > 0) self.assertTrue(device_capability["local_mem_size"] > 0) + self.assertTrue(device_capability["memory_clock_rate"] > 0) + self.assertTrue(device_capability["memory_bus_width"] > 0) self.assertEqual( device_properties.driver_version, device_capability["driver_version"] ) @@ -244,6 +249,81 @@ def test_multi_process(model, input): rc = check_output(test_script).splitlines()[-1] self.assertEqual(rc, str(torch.xpu.device_count())) + def test_parse_visible_devices(self): + def _parse_visible_devices(val): + with patch.dict(os.environ, {"ZE_AFFINITY_MASK": val}, clear=True): + return torch.xpu._parse_visible_devices() + + # Tokens with trailing non-numeric characters are invalid; entire list is rejected + self.assertEqual(_parse_visible_devices("1a, 2b"), []) + # Negative indices are silently skipped; valid indices before and after are kept + self.assertEqual(_parse_visible_devices("0, 1, 2, -1, 3"), [0, 1, 2, 3]) + # Duplicate indices are silently ignored; each ordinal appears at most once + self.assertEqual(_parse_visible_devices("0, 1, 2, 1"), [0, 1, 2]) + # Leading '+'/'-' on an integer are accepted; '-0' is treated as 0 + self.assertEqual(_parse_visible_devices("2, +3, -0, 5"), [2, 3, 0, 5]) + # Purely alphabetic tokens make the entire list invalid + self.assertEqual(_parse_visible_devices("one, two, 3, 4"), []) + + def test_device_count_respects_affinity_mask(self): + try: + import pyzes # noqa: F401 + except ImportError: + self.skipTest("pyzes is required for this test") + + def _run(mask: str) -> str: + script = f"""\ +import torch +import os +os.environ['ZE_AFFINITY_MASK'] = {mask!r} +r1 = torch.xpu._device_count_zes() +r2 = torch._C._xpu_getDeviceCount() +print(f"{{r1}}, {{r2}}") +""" + return ( + subprocess.check_output([sys.executable, "-c", script]) + .decode("ascii") + .strip() + .splitlines()[-1] + ) + + # Index 128 is out of range → both return 0 + self.assertEqual(_run("128"), "0, 0") + # COMPOSITE-style mask → _device_count_zes returns -1 + self.assertEqual(_run("0.0").split(",")[0].strip(), "-1") + # Valid mask selecting device 0 on a single-GPU system → both return 1 + self.assertEqual(_run("0"), "1, 1") + if TEST_MULTIXPU: + # Valid mask selecting device 1 on a multi-GPU system → both return 1 + self.assertEqual(_run("1"), "1, 1") + + @unittest.skipIf(not TEST_MULTIXPU, "requires multiple devices") + def test_device_count_not_cached_pre_init(self): + try: + import pyzes # noqa: F401 + except ImportError: + self.skipTest("pyzes is required for this test") + + test_script = """\ +import torch +import os +r1 = torch.xpu.device_count() +os.environ['ZE_AFFINITY_MASK'] = '0' +r2 = torch.xpu.device_count() +torch.empty(10, device='xpu') +print(f"{r1}, {r2}") +""" + + r = ( + subprocess.check_output([sys.executable, "-c", test_script]) + .decode("ascii") + .strip() + .splitlines()[-1] + ) + + x = torch.xpu.device_count() + self.assertEqual(f"{x}, 1", r) + @unittest.skipIf( IS_WINDOWS, "Only for lazy initialization on Linux, not applicable on Windows." ) @@ -277,12 +357,12 @@ def counting_getDeviceCount(): """ rc = check_output(test_script).splitlines() self.assertEqual( - rc[0], + rc[-2], "0", "Importing torch._inductor.lowering should not query XPU device count", ) self.assertEqual( - rc[1], + rc[-1], "False", "Importing torch._inductor.lowering should not initialize XPU", ) @@ -548,7 +628,7 @@ def test_serialization_array_with_empty(self): def test_out_of_memory(self): if self.expandable_segments: self.skipTest("Skipping OOM test for expandable segments allocator.") - tensor = torch.zeros(1024, device="xpu") # noqa: F841 + tensor = torch.zeros(1024, device="xpu") with self.assertRaisesRegex(RuntimeError, "Tried to allocate 800000000.00 GiB"): torch.empty(1024 * 1024 * 1024 * 800000000, dtype=torch.int8, device="xpu") @@ -1155,6 +1235,8 @@ def get_dummy_allocator(self, check_vars): return allocator, dummy_allocator def test_xpu_pluggable_allocator(self): + from torch.utils.cpp_extension import load_inline + torch.xpu.init() allocator, dummy_allocator = self.get_dummy_allocator(True) alloc_lib = ctypes.CDLL(dummy_allocator) @@ -1229,6 +1311,40 @@ def check_output(script: str) -> str: self.assertEqual(called_dummy_alloc_value, "123") self.assertEqual(called_dummy_free_value, "321") + cpp_source = r""" + #include + #include + // Mimics what torchcomms' get_mem_allocator("xccl") does: + // creates an XPUPluggableAllocator and returns it as c10::Allocator*. + std::shared_ptr get_xpu_allocator() { + auto allocator = + torch::xpu::XPUPluggableAllocator::createCustomAllocator( + // alloc_fn + [](size_t size, int device, sycl::queue* queue) -> void* { + void* ptr = sycl::malloc_device(size, *queue); + return ptr; + }, + // free_fn + [](void* ptr, size_t size, int device, sycl::queue* queue) { + sycl::free(ptr, *queue); + }); + return allocator; + } + """ + ext = load_inline( + name="repro_xpu_alloc", + cpp_sources=[cpp_source], + functions=["get_xpu_allocator"], + verbose=True, + is_python_module=True, + with_sycl=True, + ) + # Verify that the XPUPluggableAllocator returned as + # std::shared_ptr is correctly recognized as Python type. + # A TypeError here would mean the custom allocator lacks proper + # c10::Allocator pybind11 bindings. + allocator = ext.get_xpu_allocator() + def test_torch_version_xpu(self): self.assertEqual(len(torch.version.xpu), 8) compiler_version = int(torch.version.xpu) @@ -1263,8 +1379,10 @@ def test_graph_is_current_stream_capturing(self): with torch.xpu.stream(s): g = torch.xpu.XPUGraph() self.assertFalse(torch.xpu.is_current_stream_capturing()) + self.assertFalse(s.is_capturing()) g.capture_begin() self.assertTrue(torch.xpu.is_current_stream_capturing()) + self.assertTrue(s.is_capturing()) g.capture_end() def test_graph_capture_simple(self): @@ -1285,6 +1403,21 @@ def test_graph_capture_simple(self): self.assertEqual(b.sum().item(), 11000.0) + def test_accelerator_graph_simple(self): + s = torch.Stream() + g = torch.accelerator.Graph() + + with s, g: + a = torch.full((1000,), 1, device="xpu") + b = a + for _ in range(10): + b = b + 1 + torch.accelerator.current_stream().wait_stream(s) + + g.replay() + + self.assertEqual(b.sum().item(), 11000.0) + def test_graphsafe_set_get_rng_state(self): # Define a function to create generator states, with optional graph registration def create_states(generator): @@ -1951,7 +2084,7 @@ def test_graph_manual_seed_mismatch_raises(self): g = torch.xpu.XPUGraph() with self.assertRaisesRegex( RuntimeError, - "XPUGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.", # noqa: B950 + "XPUGeneratorImpl::set_current_seed can be called during stream capture only if new seed is the same as the original seed.", ): with torch.xpu.graph(g): torch.xpu.manual_seed(1) @@ -2307,14 +2440,14 @@ def raw_malloc(): try: with torch.xpu.stream(stream): mem = torch.xpu.caching_allocator_alloc(1024) - except BaseException: # noqa: B036 + except BaseException: if mem is None: return try: torch.xpu.caching_allocator_delete(mem) mem = None return None - except BaseException: # noqa: B036 + except BaseException: pass def throws_on_xpu_event(): @@ -2770,6 +2903,78 @@ def convert_boolean_tensors(x): self.assertEqual(expect, actual) + @dtypes(torch.float16, torch.bfloat16, torch.float32) + def test_fused_rms_norm(self, device, dtype): + # Verify _fused_rms_norm is dispatched to XPU kernel (not fallback) + has_xpu_kernel = torch._C._dispatch_has_kernel_for_dispatch_key( + "aten::_fused_rms_norm", + torch._C._dispatch_key_name(torch._C.DispatchKey.XPU), + ) + self.assertTrue(has_xpu_kernel, "_fused_rms_norm XPU kernel is not registered") + has_xpu_kernel = torch._C._dispatch_has_kernel_for_dispatch_key( + "aten::_fused_rms_norm_backward", + torch._C._dispatch_key_name(torch._C.DispatchKey.XPU), + ) + self.assertTrue( + has_xpu_kernel, "_fused_rms_norm_backward XPU kernel is not registered" + ) + + shapes = [ + (2, 16), # small 2D + (4, 8, 32), # 3D + (1, 1, 64), # degenerate batch + (8, 128), # typical sequence hidden + (2, 16, 512), # typical LLM hidden dim + (4, 32, 1024), # larger hidden dim + (1, 1, 4096), # LLM-scale hidden + (3, 7, 17), # non-power-of-2 + ] + eps = 1e-5 + atol_fwd = 1e-1 if dtype in [torch.float16, torch.bfloat16] else 1e-5 + atol_bwd = 1e-1 if dtype in [torch.float16, torch.bfloat16] else 1e-5 + + for shape in shapes: + normalized_shape = list(shape[-1:]) + x = torch.randn(*shape, dtype=dtype, device=device, requires_grad=True) + w = torch.randn( + *normalized_shape, dtype=dtype, device=device, requires_grad=True + ) + grad_out = torch.randn(*shape, dtype=dtype, device=device) + x_cpu = x.detach().cpu().requires_grad_(True) + w_cpu = w.detach().cpu().requires_grad_(True) + grad_out_cpu = grad_out.detach().cpu() + + # Forward + y, _ = torch.ops.aten._fused_rms_norm(x, normalized_shape, w, eps) + y_cpu, _ = torch.ops.aten._fused_rms_norm( + x_cpu, normalized_shape, w_cpu, eps + ) + self.assertEqual( + y, + y_cpu, + atol=atol_fwd, + rtol=0, + msg=f"forward shape={shape}, dtype={dtype}", + ) + + # Backward + y.backward(grad_out) + y_cpu.backward(grad_out_cpu) + self.assertEqual( + x.grad.cpu(), + x_cpu.grad, + atol=atol_bwd, + rtol=0, + msg=f"x_grad shape={shape}, dtype={dtype}", + ) + self.assertEqual( + w.grad.cpu(), + w_cpu.grad, + atol=atol_bwd, + rtol=0, + msg=f"w_grad shape={shape}, dtype={dtype}", + ) + instantiate_device_type_tests(TestXpuOps, globals(), only_for="xpu", allow_xpu=True) @@ -2944,6 +3149,8 @@ def test_torch_config_for_xpu(self): self.assertTrue(value.group(1) in ["ON", "1"]) else: self.assertTrue(value.group(1) in ["OFF", "0"]) + value = re.search(r"SYCL_COMPILER_VERSION=([^,]+)", config) + self.assertEqual(value.group(1), torch.version.xpu) else: self.assertTrue(value.group(1) in ["OFF", "0"]) self.assertFalse(torch.distributed.is_xccl_available()) diff --git a/test/torch_np/numpy_tests/core/test_multiarray.py b/test/torch_np/numpy_tests/core/test_multiarray.py index 2f1fe86c2905d..c05e60ecf15d7 100644 --- a/test/torch_np/numpy_tests/core/test_multiarray.py +++ b/test/torch_np/numpy_tests/core/test_multiarray.py @@ -4443,6 +4443,10 @@ def test_index_getset(self): if it.index != it.base.size: raise AssertionError(f"index mismatch: {it.index} != {it.base.size}") + def test_flat_cumsum(self): + x = np.array([[1.0, 2.0], [3.0, 4.0]]) + assert_array_equal(np.cumsum(x.flat), np.array([1.0, 3.0, 6.0, 10.0])) + class TestResize(TestCase): @_no_tracing @@ -6385,7 +6389,7 @@ def test_scalar_interface(self, val, iface, expected): class TestDelMisc(TestCase): - @xpassIfTorchDynamo_np # (reason="TODO") + @xfail # torch._numpy .flat returns ravel() instead of flatiter, so del is not supported def test_flat_element_deletion(self): it = np.ones(3).flat try: @@ -6834,7 +6838,7 @@ def test_choose_mod_raise(self): np.choose(a, choices, out=out, mode="raise") assert_equal(out, np.array([[10, -10, 10], [-10, 10, -10], [10, -10, 10]])) - @xpassIfTorchDynamo_np # (reason="XXX: ndarray.flat") + @xfail # torch._numpy ndarray doesn't implement __array__ def test_flatiter__array__(self): a = np.arange(9).reshape(3, 3) b = a.T.flat diff --git a/test/torch_np/test_nep50_examples.py b/test/torch_np/test_nep50_examples.py index a3ad346bf9f1c..964683bd74c81 100644 --- a/test/torch_np/test_nep50_examples.py +++ b/test/torch_np/test_nep50_examples.py @@ -31,6 +31,7 @@ from torch._numpy.testing import assert_allclose from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, + IS_WINDOWS, parametrize, run_tests, TestCase, @@ -216,11 +217,36 @@ def test_compare_ufuncs(self, name, scalar, array): # TypeError: ufunc 'hypot' not supported for the input types result_numpy = None + type_mismatch = False + expected_numpy_dtype = None + expected_torch_dtype = None + if result is not None and result_numpy is not None: - if result.tensor.numpy().dtype != result_numpy.dtype: - raise AssertionError( - f"Expected result dtype == {result_numpy.dtype}, got {result.tensor.numpy().dtype}" - ) + expected_numpy_dtype = result_numpy.dtype + expected_torch_dtype = result.tensor.numpy().dtype + if IS_WINDOWS: + if ( + array.tensor.numpy().dtype != _np.bool_ + and result.tensor.numpy().dtype != result_numpy.dtype + ): + type_mismatch = True + + if ( + array.tensor.numpy().dtype == _np.bool_ + and result_numpy.dtype == _np.int32 + and result.tensor.numpy().dtype != _np.int64 + ): + expected_numpy_dtype = _np.int32 + expected_torch_dtype = tnp.int64 + type_mismatch = True + else: + if result.tensor.numpy().dtype != result_numpy.dtype: + type_mismatch = True + + if type_mismatch: + raise AssertionError( + f"Expected result numpy dtype == {expected_numpy_dtype}, torch dtype == {expected_torch_dtype}" + ) finally: _np._set_promotion_state(state) diff --git a/test/xpu/test_conv.py b/test/xpu/test_conv.py index d25905f5804d1..3e069f2d92637 100644 --- a/test/xpu/test_conv.py +++ b/test/xpu/test_conv.py @@ -1357,6 +1357,39 @@ def test_channels_last_ouput_stride(self, device, dtype): # input NHWC, output NHWC assert_size_stride(out, (2, 512, 7, 7), (25088, 1, 3584, 512)) + @dtypes(torch.float16, torch.bfloat16, torch.float32, torch.float64) + def test_conv_misaligned_input(self, device, dtype): + N, C, OUT_CHANNELS = 2, 3, 4 + + def make_misaligned_tensor(tensor, shape): + numel = math.prod(shape) + base = torch.empty(numel + 1, device=device, dtype=dtype) + tensor_device = base[1:].reshape(shape) + tensor_device.copy_(tensor.to(device)) + self.assertTrue(tensor_device.data_ptr() % 64 != 0) + return tensor_device + + for conv_fn, spatial in [ + (F.conv1d, (16,)), + (F.conv2d, (8, 8)), + (F.conv3d, (4, 4, 4)), + ]: + kernel = (3,) * len(spatial) + input_shape = (N, C) + spatial + weight_shape = (OUT_CHANNELS, C) + kernel + + input_cpu = torch.randn(input_shape, dtype=dtype) + weight_cpu = torch.randn(weight_shape, dtype=dtype) + bias_cpu = torch.randn(OUT_CHANNELS, dtype=dtype) + + input_device = make_misaligned_tensor(input_cpu, input_shape) + weight_device = make_misaligned_tensor(weight_cpu, weight_shape) + bias_device = make_misaligned_tensor(bias_cpu, (OUT_CHANNELS,)) + + output_cpu = conv_fn(input_cpu, weight_cpu, bias_cpu) + output_device = conv_fn(input_device, weight_device, bias_device) + self.assertEqual(output_device.cpu(), output_cpu) + @onlyXPU def test_onednn_allow_tf32_get_set(self): with torch.backends.mkldnn.flags( diff --git a/test/xpu/test_gemm.py b/test/xpu/test_gemm.py index 117d0133fc6df..4ceb858a4712c 100644 --- a/test/xpu/test_gemm.py +++ b/test/xpu/test_gemm.py @@ -26,6 +26,7 @@ _dynamically_quantize_per_channel, ) from torch.testing._internal.common_utils import ( + DeterministicGuard, iter_indices, parametrize, run_tests, @@ -175,6 +176,12 @@ def _test_addmm_addmv( self.assertEqual(res1, res2) self.assertEqual(res1, res3) + # Test inplace versions if they exist. + if hasattr(t, f.__name__ + "_"): + out_tensor = torch.broadcast_to(t, res1.shape).clone() + getattr(out_tensor, f.__name__ + "_")(m, v, alpha=alpha, beta=beta) + self.assertEqual(res1, out_tensor) + def _test_addmm_impl(self, func, activation, device, dtype): M = torch.randn(10, 25, device="cpu", dtype=torch.float32).to(dtype).to(device) m1 = torch.randn(10, 50, device="cpu", dtype=torch.float32).to(dtype).to(device) @@ -238,13 +245,13 @@ def maybe_transpose(cond, m): activation=activation, ) - @precisionOverride({torch.float: 1e-4, torch.double: 1e-6, torch.half: 1e-1}) + @precisionOverride({torch.float: 1e-4, torch.half: 1e-1}) @dtypes(torch.float32, torch.half, torch.double, torch.complex64) @tf32_on_and_off(0.05) def test_addmm(self, device, dtype): self._test_addmm_impl(torch.addmm, None, device, dtype) - @precisionOverride({torch.float: 1e-4, torch.double: 1e-6, torch.half: 1e-1}) + @precisionOverride({torch.float: 1e-4, torch.half: 1e-1}) @dtypes(torch.float, torch.half, torch.double) def test_addmm_badmm_scalar_tnesor_input(self, device, dtype): input = torch.tensor(1).to(device=device, dtype=dtype) @@ -638,7 +645,7 @@ def generate_tensor(): for b1, b2, ref, out_tensor in generate_tensor(): self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor) - @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5, torch.float64: 1e-6}) + @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5}) @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half, torch.complex64) @tf32_on_and_off(0.01) def test_baddbmm(self, device, dtype): @@ -864,6 +871,15 @@ def test_matmul_45724(self, device): torch.matmul(a, b, out=c) self.assertEqual(c, cpu_result) + @parametrize("shape", [513, 767]) + @dtypes(torch.bfloat16, torch.half, torch.float, torch.double) + def test_matmul_deterministic_mode(self, device, shape, dtype): + with DeterministicGuard(True): + inp = torch.randn(shape, shape, device=device, dtype=dtype) + first = torch.matmul(inp, inp) + for _ in range(10): + self.assertEqual(first, torch.matmul(inp, inp), atol=0.0, rtol=0.0) + @dtypes( torch.int16, torch.int32, @@ -908,7 +924,6 @@ def test_baddbmm_nan_input_with_zero_beta(self, device, dtype): y = torch.baddbmm(input, mat1, mat2, beta=0.0, out=out) self.assertEqual(y_ref, y) - @precisionOverride({torch.double: 1e-6}) @dtypes(torch.float, torch.double) @tf32_on_and_off(0.005) def test_addmm_sizes(self, device, dtype): @@ -933,7 +948,6 @@ def test_addmm_sizes(self, device, dtype): @precisionOverride( { - torch.double: 1e-6, torch.float: 1e-4, torch.bfloat16: 5e-2, torch.half: 5e-2, @@ -948,7 +962,6 @@ def test_addmm_gelu(self, device, dtype): @precisionOverride( { - torch.double: 1e-6, torch.float: 1e-4, torch.bfloat16: 5e-2, torch.half: 5e-2, @@ -997,7 +1010,6 @@ def _test(row_major, incx, incy, lda_tail): @precisionOverride( { - torch.double: 1e-8, torch.float: 1e-4, torch.bfloat16: 0.6, torch.half: 1e-1, diff --git a/third_party/BUILD b/third_party/BUILD deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/third_party/cpp-httplib b/third_party/cpp-httplib index bd95e67c23493..4d7c9a788de13 160000 --- a/third_party/cpp-httplib +++ b/third_party/cpp-httplib @@ -1 +1 @@ -Subproject commit bd95e67c234930cd6d6bb11309588c5462c63cec +Subproject commit 4d7c9a788de136071ccf0dd4e96239151e2adadb diff --git a/third_party/cpp-httplib.BUILD b/third_party/cpp-httplib.BUILD deleted file mode 100644 index 3cd0c3dbe94ba..0000000000000 --- a/third_party/cpp-httplib.BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -cc_library( - name = "cpp-httplib", - hdrs = ["httplib.h"], - includes = [ - "/", - ], - visibility = ["//visibility:public"], -) diff --git a/third_party/cpuinfo b/third_party/cpuinfo index f858c30bcb16f..bc3c01e230c69 160000 --- a/third_party/cpuinfo +++ b/third_party/cpuinfo @@ -1 +1 @@ -Subproject commit f858c30bcb16f8effd5ff46996f0514539e17abc +Subproject commit bc3c01e230c6974283e4b89421cfb0e232435589 diff --git a/third_party/cuda.BUILD b/third_party/cuda.BUILD deleted file mode 100644 index 4767231b55830..0000000000000 --- a/third_party/cuda.BUILD +++ /dev/null @@ -1,82 +0,0 @@ -# Adopted from: https://github.com/tensorflow/runtime/blob/master/third_party/rules_cuda/private/BUILD.local_cuda -# Library targets are created corresponding to BUILD.bazel's needs. - -cc_library( - name = "cuda_headers", - hdrs = glob([ - "include/**", - "targets/x86_64-linux/include/**", - ]), - includes = [ - "include", - "targets/x86_64-linux/include", - ], - visibility = ["//visibility:public"], -) - -cc_library( - name = "cuda_driver", - srcs = ["lib64/stubs/libcuda.so"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "cuda", - srcs = ["targets/x86_64-linux/lib/libcudart.so"], - visibility = ["//visibility:public"], - deps = [":cuda_headers"], -) - -cc_library( - name = "cufft", - srcs = ["targets/x86_64-linux/lib/libcufft.so"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "cublas", - srcs = [ - "targets/x86_64-linux/lib/libcublasLt.so", - "targets/x86_64-linux/lib/libcublas.so", - ], - visibility = ["//visibility:public"], -) - -cc_library( - name = "curand", - srcs = ["targets/x86_64-linux/lib/libcurand.so"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "cusolver", - srcs = ["targets/x86_64-linux/lib/libcusolver.so"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "cusparse", - srcs = ["targets/x86_64-linux/lib/libcusparse.so"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "cufile", - srcs = ["targets/x86_64-linux/lib/libcufile.so"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "nvrtc", - srcs = [ - "targets/x86_64-linux/lib/libnvrtc.so", - "targets/x86_64-linux/lib/libnvrtc-builtins.so", - ], - visibility = ["//visibility:public"], -) - -cc_library( - name = "nvToolsExt", - srcs = [ "lib64/libnvToolsExt.so"], - visibility = ["//visibility:public"], -) diff --git a/third_party/cudnn.BUILD b/third_party/cudnn.BUILD deleted file mode 100644 index 82b4526c21e6c..0000000000000 --- a/third_party/cudnn.BUILD +++ /dev/null @@ -1,26 +0,0 @@ -# Adopted from: https://github.com/NVIDIA/TRTorch/blob/master/third_party/cudnn/local/BUILD - -cc_library( - name = "cudnn_headers", - hdrs = ["include/cudnn.h"] + glob([ - "include/cudnn+.h", - "include/cudnn_*.h", - ]), - includes = ["include/"], - visibility = ["//visibility:private"], -) - -cc_import( - name = "cudnn_lib", - shared_library = "lib64/libcudnn.so", - visibility = ["//visibility:private"], -) - -cc_library( - name = "cudnn", - visibility = ["//visibility:public"], - deps = [ - "cudnn_headers", - "cudnn_lib", - ], -) diff --git a/third_party/cudnn_frontend b/third_party/cudnn_frontend index b8c0656e6f6c8..a91f0e04dcea1 160000 --- a/third_party/cudnn_frontend +++ b/third_party/cudnn_frontend @@ -1 +1 @@ -Subproject commit b8c0656e6f6c84fc194f4d57329b55d609eff596 +Subproject commit a91f0e04dcea10515f0f776fc5a89535e316a9c8 diff --git a/third_party/cudnn_frontend.BUILD b/third_party/cudnn_frontend.BUILD deleted file mode 100644 index 0af69797a3e2d..0000000000000 --- a/third_party/cudnn_frontend.BUILD +++ /dev/null @@ -1,22 +0,0 @@ -# Adopted from: https://github.com/tensorflow/tensorflow/blob/master/third_party/cudnn_frontend.BUILD - -# Description: -# The cuDNN Frontend API is a C++ header-only library that demonstrates how -# to use the cuDNN C backend API. - -load("@rules_cc//cc:defs.bzl", "cc_library") - -package( - default_visibility = ["//visibility:public"], -) - -licenses(["notice"]) # MIT - -exports_files(["LICENSE.txt"]) - -cc_library( - name = "cudnn_frontend", - hdrs = glob(["include/**"]), - includes = ["include/"], - include_prefix = "third_party/cudnn_frontend", -) diff --git a/third_party/cutlass b/third_party/cutlass index 0d2b201e8c1c4..da5e086dab31d 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit 0d2b201e8c1c4a03efa6e9c468161916e2334725 +Subproject commit da5e086dab31d63815acafdac9a9c5893b1c69e2 diff --git a/third_party/cutlass.BUILD b/third_party/cutlass.BUILD deleted file mode 100644 index 10100531d9be6..0000000000000 --- a/third_party/cutlass.BUILD +++ /dev/null @@ -1,26 +0,0 @@ -# Description: -# CUDA Templates for Linear Algebra Subroutines - -load("@rules_cc//cc:defs.bzl", "cc_library") - -cc_library( - name = "cutlass", - hdrs = glob([ - "include/**/*.h", - "include/**/*.hpp", - "include/**/*.inl", - "tools/util/include/**/*.h", - "tools/util/include/**/*.hpp", - "tools/util/include/**/*.inl", - ]), - defines = [ - "CUTLASS_ENABLE_TENSOR_CORE_MMA=1", - "CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1", - "CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", - ], - includes = [ - "include/", - "tools/util/include/", - ], - visibility = ["//visibility:public"], -) diff --git a/third_party/eigen.BUILD b/third_party/eigen.BUILD deleted file mode 100644 index a6a735360633b..0000000000000 --- a/third_party/eigen.BUILD +++ /dev/null @@ -1,91 +0,0 @@ -# This is BUILD file is derived from https://github.com/tensorflow/tensorflow/blob/master/third_party/eigen.BUILD - -# Description: -# Eigen is a C++ template library for linear algebra: vectors, -# matrices, and related algorithms. - -load("@rules_cc//cc:defs.bzl", "cc_library") - -licenses([ - # Note: Eigen is an MPL2 library that includes GPL v3 and LGPL v2.1+ code. - # We've taken special care to not reference any restricted code. - "reciprocal", # MPL2 - "notice", # Portions BSD -]) - -exports_files(["COPYING.MPL2"]) - -# License-restricted (i.e. not reciprocal or notice) files inside Eigen/... -EIGEN_RESTRICTED_FILES = [ - "Eigen/src/OrderingMethods/Amd.h", - "Eigen/src/SparseCholesky/**", -] - -# Notable transitive dependencies of restricted files inside Eigen/... -EIGEN_RESTRICTED_DEPS = [ - "Eigen/Eigen", - "Eigen/IterativeLinearSolvers", - "Eigen/MetisSupport", - "Eigen/Sparse", - "Eigen/SparseCholesky", - "Eigen/SparseLU", -] - -EIGEN_FILES = [ - "Eigen/**", - "unsupported/Eigen/CXX11/**", - "unsupported/Eigen/FFT", - "unsupported/Eigen/KroneckerProduct", - "unsupported/Eigen/src/FFT/**", - "unsupported/Eigen/src/KroneckerProduct/**", - "unsupported/Eigen/MatrixFunctions", - "unsupported/Eigen/SpecialFunctions", - "unsupported/Eigen/Splines", - "unsupported/Eigen/src/MatrixFunctions/**", - "unsupported/Eigen/src/SpecialFunctions/**", - "unsupported/Eigen/src/Splines/**", - "unsupported/Eigen/NonLinearOptimization", - "unsupported/Eigen/NumericalDiff", - "unsupported/Eigen/src/**", - "unsupported/Eigen/Polynomials", -] - -# List of files picked up by glob but actually part of another target. -EIGEN_EXCLUDE_FILES = ["Eigen/src/Core/arch/AVX/PacketMathGoogleTest.cc"] - -# Disallowed eigen modules/files in rNA: -# * Using the custom STL and memory support, it is not needed and should -# not be used with c++17. -# * We will only support the EulerAnglesZYX provided by //atg/geometry so -# just don't allow people to access the unsupported eigen module. -EIGEN_DISALLOW_FILES = [ - "Eigen/StlSupport/*.h", - "unsupported/Eigen/EulerAngles", - "unsupported/Eigen/src/EulerAngles/**", -] - -# Files known to be under MPL2 license. -EIGEN_MPL2_HEADER_FILES = glob( - EIGEN_FILES, - exclude = EIGEN_EXCLUDE_FILES + - EIGEN_RESTRICTED_FILES + - EIGEN_DISALLOW_FILES + - EIGEN_RESTRICTED_DEPS + [ - # Guarantees any file missed by excludes above will not compile. - "Eigen/src/Core/util/NonMPL2.h", - "Eigen/**/CMakeLists.txt", - ], -) - -cc_library( - name = "eigen", - hdrs = EIGEN_MPL2_HEADER_FILES, - defines = [ - # This define (mostly) guarantees we don't link any problematic - # code. We use it, but we do not rely on it, as evidenced above. - "EIGEN_MPL2_ONLY", - "EIGEN_MAX_ALIGN_BYTES=64", - ], - includes = ["."], - visibility = ["//visibility:public"], -) diff --git a/third_party/fbgemm b/third_party/fbgemm index c246916f9e380..d08742c6602ef 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit c246916f9e3804eacc3c95058e51cce02ae00fff +Subproject commit d08742c6602efedc6a3c9fca124b96ad555316e4 diff --git a/third_party/flash-attention b/third_party/flash-attention index fec3a6a18460c..b322ae2675065 160000 --- a/third_party/flash-attention +++ b/third_party/flash-attention @@ -1 +1 @@ -Subproject commit fec3a6a18460c1b40f097208d4c16fe8964a679d +Subproject commit b322ae2675065ad96ad3c248fd6ef0f32252808f diff --git a/third_party/fmt.BUILD b/third_party/fmt.BUILD deleted file mode 100644 index ea8c566b98a53..0000000000000 --- a/third_party/fmt.BUILD +++ /dev/null @@ -1,9 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -cc_library( - name = "fmt", - hdrs = glob(["include/fmt/*.h",]), - defines = ["FMT_HEADER_ONLY=1"], - includes = ["include"], - visibility = ["//visibility:public"], -) diff --git a/third_party/gloo.BUILD b/third_party/gloo.BUILD deleted file mode 100644 index ff7eda654b3ad..0000000000000 --- a/third_party/gloo.BUILD +++ /dev/null @@ -1,88 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") -load("@pytorch//tools/rules:cu.bzl", "cu_library") -load("@pytorch//third_party:substitution.bzl", "template_rule") -load("@pytorch//tools/config:defs.bzl", "if_cuda") - -template_rule( - name = "gloo_config_cmake_macros", - src = "gloo/config.h.in", - out = "gloo/config.h", - substitutions = { - "@GLOO_VERSION_MAJOR@": "0", - "@GLOO_VERSION_MINOR@": "5", - "@GLOO_VERSION_PATCH@": "0", - "cmakedefine01 GLOO_USE_CUDA": "define GLOO_USE_CUDA 1", - "cmakedefine01 GLOO_USE_NCCL": "define GLOO_USE_NCCL 0", - "cmakedefine01 GLOO_USE_ROCM": "define GLOO_USE_ROCM 0", - "cmakedefine01 GLOO_USE_RCCL": "define GLOO_USE_RCCL 0", - "cmakedefine01 GLOO_USE_REDIS": "define GLOO_USE_REDIS 0", - "cmakedefine01 GLOO_USE_IBVERBS": "define GLOO_USE_IBVERBS 0", - "cmakedefine01 GLOO_USE_MPI": "define GLOO_USE_MPI 0", - "cmakedefine01 GLOO_USE_AVX": "define GLOO_USE_AVX 0", - "cmakedefine01 GLOO_USE_LIBUV": "define GLOO_USE_LIBUV 0", - # The `GLOO_HAVE_TRANSPORT_TCP_TLS` line should go above the `GLOO_HAVE_TRANSPORT_TCP` in order to properly substitute the template. - "cmakedefine01 GLOO_HAVE_TRANSPORT_TCP_TLS": "define GLOO_HAVE_TRANSPORT_TCP_TLS 1", - "cmakedefine01 GLOO_HAVE_TRANSPORT_TCP": "define GLOO_HAVE_TRANSPORT_TCP 1", - "cmakedefine01 GLOO_HAVE_TRANSPORT_IBVERBS": "define GLOO_HAVE_TRANSPORT_IBVERBS 0", - "cmakedefine01 GLOO_HAVE_TRANSPORT_UV": "define GLOO_HAVE_TRANSPORT_UV 0", - }, -) - -cc_library( - name = "gloo_headers", - hdrs = glob( - [ - "gloo/*.h", - "gloo/common/*.h", - "gloo/rendezvous/*.h", - "gloo/transport/*.h", - "gloo/transport/tcp/*.h", - "gloo/transport/tcp/tls/*.h", - ], - exclude = [ - "gloo/rendezvous/redis_store.h", - ], - ) + ["gloo/config.h"], - includes = [ - ".", - ], -) - -cu_library( - name = "gloo_cuda", - srcs = [ - "gloo/cuda.cu", - "gloo/cuda_private.cu", - ], - visibility = ["//visibility:public"], - deps = [ - ":gloo_headers", - ], - alwayslink = True, -) - -cc_library( - name = "gloo", - srcs = glob( - [ - "gloo/*.cc", - "gloo/common/*.cc", - "gloo/rendezvous/*.cc", - "gloo/transport/*.cc", - "gloo/transport/tcp/*.cc", - ], - exclude = [ - "gloo/cuda*.cc", - "gloo/common/win.cc", - "gloo/rendezvous/redis_store.cc", - ] - ) + if_cuda(glob(["gloo/cuda*.cc"])), - copts = [ - "-std=c++20", - ], - visibility = ["//visibility:public"], - deps = [":gloo_headers"] + if_cuda( - [":gloo_cuda"], - [], - ), -) diff --git a/third_party/ideep b/third_party/ideep index 8e7ddd65df95f..e539e0f9774e2 160000 --- a/third_party/ideep +++ b/third_party/ideep @@ -1 +1 @@ -Subproject commit 8e7ddd65df95f13e41f0a40c820c5f35ae4a0ea3 +Subproject commit e539e0f9774e2018f0d56fe865da66581f692e3d diff --git a/third_party/ideep.BUILD b/third_party/ideep.BUILD deleted file mode 100644 index 882d5cb342a41..0000000000000 --- a/third_party/ideep.BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -cc_library( - name = "ideep", - hdrs = glob([ - "include/**/*.hpp", - "include/**/*.h", - ]), - defines = [ - "IDEEP_USE_MKL", - ], - includes = [ - "include/", - ], - visibility = ["//visibility:public"], - deps = ["@mkl_dnn//:mkl-dnn"], -) diff --git a/third_party/kineto b/third_party/kineto index 03ab8cb08c1ba..23b5bb5764b3d 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit 03ab8cb08c1ba1c57c917c1c2fa39ab24b67bcda +Subproject commit 23b5bb5764b3dec988e25c52098407e508d84bb4 diff --git a/third_party/kineto.BUILD b/third_party/kineto.BUILD deleted file mode 100644 index d8e484ae80b6b..0000000000000 --- a/third_party/kineto.BUILD +++ /dev/null @@ -1,10 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -cc_library( - name = "kineto", - hdrs = glob(["libkineto/include/*.h",]), - includes = [ - "libkineto/include/", - ], - visibility = ["//visibility:public"], -) diff --git a/third_party/mkl-dnn.BUILD b/third_party/mkl-dnn.BUILD deleted file mode 100644 index f62612ba18b86..0000000000000 --- a/third_party/mkl-dnn.BUILD +++ /dev/null @@ -1,164 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") -load("@pytorch//third_party:substitution.bzl", "template_rule") - -_DNNL_RUNTIME_OMP = { - "#cmakedefine DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_${DNNL_CPU_THREADING_RUNTIME}": "#define DNNL_CPU_THREADING_RUNTIME DNNL_RUNTIME_OMP", - "#cmakedefine DNNL_CPU_RUNTIME DNNL_RUNTIME_${DNNL_CPU_RUNTIME}": "#define DNNL_CPU_RUNTIME DNNL_RUNTIME_OMP", - "#cmakedefine DNNL_GPU_RUNTIME DNNL_RUNTIME_${DNNL_GPU_RUNTIME}": "#define DNNL_GPU_RUNTIME DNNL_RUNTIME_NONE", - "#cmakedefine DNNL_GPU_VENDOR DNNL_VENDOR_${DNNL_GPU_VENDOR}": "/* undef DNNL_GPU_VENDOR */", - "#cmakedefine DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE": "/* undef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE */", - "#cmakedefine DNNL_WITH_SYCL": "/* #undef DNNL_WITH_SYCL */", - "#cmakedefine DNNL_WITH_LEVEL_ZERO": "/* #undef DNNL_WITH_LEVEL_ZERO */", - "#cmakedefine DNNL_SYCL_CUDA": "/* #undef DNNL_SYCL_CUDA */", - "#cmakedefine DNNL_SYCL_HIP": "/* #undef DNNL_SYCL_HIP */", - "#cmakedefine DNNL_SYCL_GENERIC": "/* #undef DNNL_SYCL_GENERIC */", - "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", - "#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "/* undef DNNL_EXPERIMENTAL_UKERNEL */", - "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", - "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH", - "#cmakedefine DNNL_EXPERIMENTAL_PROFILING": "#undef DNNL_EXPERIMENTAL_PROFILING", - "#cmakedefine DNNL_EXPERIMENTAL_LOGGING": "#undef DNNL_EXPERIMENTAL_LOGGING", - "#cmakedefine DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER": "#undef DNNL_EXPERIMENTAL_SYCL_KERNEL_COMPILER", - "#cmakedefine DNNL_DISABLE_GPU_REF_KERNELS": "#undef DNNL_DISABLE_GPU_REF_KERNELS", - "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", - "#cmakedefine01 BUILD_INFERENCE": "#define BUILD_INFERENCE 0", - "#cmakedefine01 BUILD_PRIMITIVE_ALL": "#define BUILD_PRIMITIVE_ALL 1", - "#cmakedefine01 BUILD_BATCH_NORMALIZATION": "#define BUILD_BATCH_NORMALIZATION 0", - "#cmakedefine01 BUILD_BINARY": "#define BUILD_BINARY 0", - "#cmakedefine01 BUILD_CONCAT": "#define BUILD_CONCAT 0", - "#cmakedefine01 BUILD_CONVOLUTION": "#define BUILD_CONVOLUTION 0", - "#cmakedefine01 BUILD_DECONVOLUTION": "#define BUILD_DECONVOLUTION 0", - "#cmakedefine01 BUILD_ELTWISE": "#define BUILD_ELTWISE 0", - "#cmakedefine01 BUILD_GROUP_NORMALIZATION": "#define BUILD_GROUP_NORMALIZATION 0", - "#cmakedefine01 BUILD_INNER_PRODUCT": "#define BUILD_INNER_PRODUCT 0", - "#cmakedefine01 BUILD_LAYER_NORMALIZATION": "#define BUILD_LAYER_NORMALIZATION 0", - "#cmakedefine01 BUILD_LRN": "#define BUILD_LRN 0", - "#cmakedefine01 BUILD_MATMUL": "#define BUILD_MATMUL 0", - "#cmakedefine01 BUILD_POOLING": "#define BUILD_POOLING 0", - "#cmakedefine01 BUILD_PRELU": "#define BUILD_PRELU 0", - "#cmakedefine01 BUILD_REDUCTION": "#define BUILD_REDUCTION 0", - "#cmakedefine01 BUILD_REORDER": "#define BUILD_REORDER 0", - "#cmakedefine01 BUILD_RESAMPLING": "#define BUILD_RESAMPLING 0", - "#cmakedefine01 BUILD_RNN": "#define BUILD_RNN 0", - "#cmakedefine01 BUILD_SDPA": "#define BUILD_SDPA 0", - "#cmakedefine01 BUILD_SHUFFLE": "#define BUILD_SHUFFLE 0", - "#cmakedefine01 BUILD_SOFTMAX": "#define BUILD_SOFTMAX 0", - "#cmakedefine01 BUILD_SUM": "#define BUILD_SUM 0", - "#cmakedefine01 BUILD_PRIMITIVE_CPU_ISA_ALL": "#define BUILD_PRIMITIVE_CPU_ISA_ALL 1", - "#cmakedefine01 BUILD_SSE41": "#define BUILD_SSE41 0", - "#cmakedefine01 BUILD_AVX2": "#define BUILD_AVX2 0", - "#cmakedefine01 BUILD_AVX512": "#define BUILD_AVX512 0", - "#cmakedefine01 BUILD_AMX": "#define BUILD_AMX 0", - "#cmakedefine01 BUILD_PRIMITIVE_GPU_ISA_ALL": "#define BUILD_PRIMITIVE_GPU_ISA_ALL 1", - "#cmakedefine01 BUILD_XELP": "#define BUILD_XELP 0", - "#cmakedefine01 BUILD_XEHPG": "#define BUILD_XEHPG 0", - "#cmakedefine01 BUILD_XEHPC": "#define BUILD_XEHPC 0", - "#cmakedefine01 BUILD_XEHP": "#define BUILD_XEHP 0", - "#cmakedefine01 BUILD_XE2": "#define BUILD_XE2 0", - "#cmakedefine01 BUILD_XE3": "#define BUILD_XE3 0", - "#cmakedefine01 BUILD_GEMM_KERNELS_ALL": "#define BUILD_GEMM_KERNELS_ALL 0", - "#cmakedefine01 BUILD_GEMM_KERNELS_NONE": "#define BUILD_GEMM_KERNELS_NONE 0", - "#cmakedefine01 BUILD_GEMM_SSE41": "#define BUILD_GEMM_SSE41 0", - "#cmakedefine01 BUILD_GEMM_AVX2": "#define BUILD_GEMM_AVX2 0", - "#cmakedefine01 BUILD_GEMM_AVX512": "#define BUILD_GEMM_AVX512 0", -} - -template_rule( - name = "include_dnnl_version", - src = "include/oneapi/dnnl/dnnl_version.h.in", - out = "include/oneapi/dnnl/dnnl_version.h", - substitutions = { - "@DNNL_VERSION_MAJOR@": "3", - "@DNNL_VERSION_MINOR@": "10", - "@DNNL_VERSION_PATCH@": "2", - }, -) - -template_rule( - name = "include_dnnl_config", - src = "include/oneapi/dnnl/dnnl_config.h.in", - out = "include/oneapi/dnnl/dnnl_config.h", - substitutions = _DNNL_RUNTIME_OMP, -) - -template_rule( - name = "include_dnnl_version_hash", - src = "include/oneapi/dnnl/dnnl_version_hash.h.in", - out = "include/oneapi/dnnl/dnnl_version_hash.h", - substitutions = {"@DNNL_VERSION_HASH@": "f1d471933dc852f956fd05389f9313c7148783d5",} -) - -cc_library( - name = "mkl-dnn", - srcs = glob([ - "src/common/*.cpp", - "src/cpu/**/*.cpp", - "src/cpu/**/**/*.cpp", - ], exclude=[ - "src/cpu/aarch64/**/*.cpp", - "src/cpu/rv64/**/*.cpp", - "src/cpu/sycl/**/*.cpp", - "src/cpu/ppc64/**/*.cpp", - ]), - hdrs = glob([ - "include/oneapi/dnnl/*.h", - "include/oneapi/dnnl/*.hpp", - "include/*.h", - "include/*.hpp", - "src/cpu/**/*.hpp", - "src/cpu/**/*.h", - "src/cpu/**/**/*.h", - "src/common/*.hpp", - "src/common/**/**/*.h", - "third_party/xbyak/*.h", - "third_party/ittnotify/jitprofiling.h", - "third_party/spdlog/**/*.h", - ], exclude=[ - "src/cpu/aarch64/**/*.hpp", - "src/cpu/aarch64/**/*.h", - "src/cpu/rv64/**/*.hpp", - "src/cpu/rv64/**/*.h", - "src/cpu/sycl/**/*.hpp", - "src/cpu/ppc64/**/*.hpp", - ]) + [ - "include/oneapi/dnnl/dnnl_config.h", - "include/oneapi/dnnl/dnnl_version.h", - "include/oneapi/dnnl/dnnl_version_hash.h", - ], - copts = [ - "-DDNNL_DLL", - "-DDNNL_DLL_EXPORTS", - "-DDNNL_ENABLE_CONCURRENT_EXEC", - "-D__STDC_CONSTANT_MACROS", - "-D__STDC_LIMIT_MACROS", - "-fno-strict-overflow", - "-fopenmp", - ] + select({ - "@pytorch//tools/config:thread_sanitizer": ["-DDNNL_CPU_RUNTIME=0"], - "//conditions:default": ["-DDNNL_CPU_RUNTIME=2"], - }), - includes = [ - "include/", - "include/oneapi/", - "include/oneapi/dnnl/", - "src/", - "src/common/", - "src/cpu/", - "third_party/", - ], - visibility = ["//visibility:public"], - linkopts = [ - "-lgomp", - ], - deps = [ - "@mkl", - ], - defines = [ - "DNNL_ENABLE_MAX_CPU_ISA", - "DNNL_ENABLE_CONCURRENT_EXEC", - "DNNL_ENABLE_PRIMITIVE_CACHE", - "DNNL_ENABLE_CPU_ISA_HINTS", - "DNNL_EXPERIMENTAL_UKERNEL", - "ONEDNN_BUILD_GRAPH", - ], -) diff --git a/third_party/mkl.BUILD b/third_party/mkl.BUILD deleted file mode 100644 index b7abb0e035ad9..0000000000000 --- a/third_party/mkl.BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -cc_library( - name = "mkl", - srcs = [ - "libmkl_avx2.so", - "libmkl_core.so", - "libmkl_def.so", - "libmkl_intel_lp64.so", - "libmkl_rt.so", - "libmkl_sequential.so", - "libmkl_vml_avx2.so", - "libmkl_vml_avx512.so", - "libmkl_vml_def.so", - ], - visibility = ["//visibility:public"], - deps = ["@mkl_headers"], -) diff --git a/third_party/mkl_headers.BUILD b/third_party/mkl_headers.BUILD deleted file mode 100644 index 965801c91aa97..0000000000000 --- a/third_party/mkl_headers.BUILD +++ /dev/null @@ -1,8 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -cc_library( - name = "mkl_headers", - hdrs = glob(["include/*.h"]), - includes = ["include/"], - visibility = ["//visibility:public"], -) diff --git a/third_party/moodycamel.BUILD b/third_party/moodycamel.BUILD deleted file mode 100644 index d3028205016fb..0000000000000 --- a/third_party/moodycamel.BUILD +++ /dev/null @@ -1,7 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -cc_library( - name = "moodycamel", - hdrs = glob(["**/*.h"]), - visibility = ["//visibility:public"], -) diff --git a/third_party/nlohmann.BUILD b/third_party/nlohmann.BUILD deleted file mode 100644 index 64dfbbab2b6e9..0000000000000 --- a/third_party/nlohmann.BUILD +++ /dev/null @@ -1,18 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") - -cc_library(name = "nlohmann", - includes = ["include"], - deps = ["nlohmann-internal"], - visibility = ["//visibility:public"], -) - -cc_import(name = "nlohmann-internal", - hdrs = glob(["include/**/*.hpp"]), - visibility = ["//visibility:private"], -) - -cc_library( - name = "nlohmann_single_include", - hdrs = glob(["single_include/nlohmann/*.hpp"]), - visibility = ["//visibility:public"], -) diff --git a/third_party/onnx.BUILD b/third_party/onnx.BUILD deleted file mode 100644 index c4134e98d0fe2..0000000000000 --- a/third_party/onnx.BUILD +++ /dev/null @@ -1,108 +0,0 @@ -load("@rules_proto//proto:defs.bzl", "proto_library") -load("@rules_cc//cc:defs.bzl", "cc_library", "cc_proto_library") -load("@rules_python//python:defs.bzl", "py_binary") - -py_binary( - name = "gen_proto", - srcs = ["onnx/gen_proto.py"], - data = [ - "onnx/onnx.in.proto", - "onnx/onnx-operators.in.proto", - "onnx/onnx-data.in.proto", - ], -) - -genrule( - name = "generate_onnx_proto", - outs = [ - "onnx/onnx_onnx_torch-ml.proto", - "onnx/onnx-ml.pb.h", - ], - cmd = "$(location :gen_proto) -p onnx_torch -o $(@D)/onnx onnx -m >/dev/null && sed -i 's/onnx_onnx_torch-ml.pb.h/onnx\\/onnx_onnx_torch-ml.pb.h/g' $(@D)/onnx/onnx-ml.pb.h", - tools = [":gen_proto"], -) - -genrule( - name = "generate_onnx_operators_proto", - outs = [ - "onnx/onnx-operators_onnx_torch-ml.proto", - "onnx/onnx-operators-ml.pb.h", - ], - cmd = "$(location :gen_proto) -p onnx_torch -o $(@D)/onnx onnx-operators -m >/dev/null && sed -i 's/onnx-operators_onnx_torch-ml.pb.h/onnx\\/onnx-operators_onnx_torch-ml.pb.h/g' $(@D)/onnx/onnx-operators-ml.pb.h", - tools = [":gen_proto"], -) - -genrule( - name = "generate_onnx_data_proto", - outs = [ - "onnx/onnx-data_onnx_torch.proto", - "onnx/onnx-data.pb.h", - ], - cmd = "$(location :gen_proto) -p onnx_torch -o $(@D)/onnx onnx-data -m >/dev/null && sed -i 's/onnx-data_onnx_torch.pb.h/onnx\\/onnx-data_onnx_torch.pb.h/g' $(@D)/onnx/onnx-data.pb.h", - tools = [":gen_proto"], -) - -cc_library( - name = "onnx", - srcs = glob( - [ - "onnx/*.cc", - "onnx/common/*.cc", - "onnx/defs/**/*.cc", - "onnx/shape_inference/*.cc", - "onnx/version_converter/*.cc", - ], - exclude = [ - "onnx/cpp2py_export.cc", - ], - ), - hdrs = glob([ - "onnx/*.h", - "onnx/version_converter/*.h", - "onnx/common/*.h", - "onnx/defs/**/*.h", - "onnx/shape_inference/*.h", - "onnx/version_converter/adapters/*.h", - ]) + [ - "onnx/onnx-ml.pb.h", - "onnx/onnx-operators-ml.pb.h", - "onnx/onnx-data.pb.h", - ], - defines = [ - "ONNX_ML=1", - "ONNX_NAMESPACE=onnx_torch", - ], - includes = [ - ".", - "onnx/", - ], - visibility = ["//visibility:public"], - deps = [ - ":onnx_proto_lib", - ], -) - -cc_library( - name = "onnx_proto_headers", - hdrs = glob([ - "onnx/*_pb.h", - ]), - visibility = ["//visibility:public"], - deps = [ - ":onnx_proto_lib", - ], -) - -proto_library( - name = "onnx_proto", - srcs = [ - "onnx/onnx-operators_onnx_torch-ml.proto", - "onnx/onnx_onnx_torch-ml.proto", - "onnx/onnx-data_onnx_torch.proto", - ], -) - -cc_proto_library( - name = "onnx_proto_lib", - deps = [":onnx_proto"], -) diff --git a/third_party/pybind11 b/third_party/pybind11 index f5fbe867d2d26..d03662f0984f6 160000 --- a/third_party/pybind11 +++ b/third_party/pybind11 @@ -1 +1 @@ -Subproject commit f5fbe867d2d26e4a0a9177a51f6e568868ad3dc8 +Subproject commit d03662f0984f652b60e7ddce53d3868002275197 diff --git a/third_party/sleef.BUILD b/third_party/sleef.BUILD deleted file mode 100644 index f22a6e905e2be..0000000000000 --- a/third_party/sleef.BUILD +++ /dev/null @@ -1,487 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library") -load("@pytorch//third_party:sleef.bzl", "sleef_cc_library") - -SLEEF_COPTS = [ - "-DHAVE_MALLOC_USABLE_SIZE=1", - "-DHAVE_MMAP=1", - "-DHAVE_SHM_OPEN=1", - "-DHAVE_SHM_UNLINK=1", - "-DIDEEP_USE_MKL", - "-DDNNL_CPU_RUNTIME=TBB", - "-DONNX_ML=1", - "-DONNX_NAMESPACE=onnx", - "-D_FILE_OFFSET_BITS=64", - "-ffp-contract=off", - "-fno-math-errno", - "-fno-trapping-math", - "-DCAFFE2_USE_GLOO", - "-std=gnu99", -] - -SLEEF_COMMON_TARGET_COPTS = [ - "-DSLEEF_STATIC_LIBS=1", - "-DENABLE_ALIAS=1", -] - -SLEEF_PRIVATE_HEADERS = glob([ - "build/include/*.h", - "src/arch/*.h", - "src/common/*.h", - "src/libm/*.h", - "src/libm/include/*.h", -]) - -SLEEF_PUBLIC_HEADERS = [ - ":sleef_h", -] - -SLEEF_PRIVATE_INCLUDES = [ - "-Iexternal/sleef/src/arch", - "-Iexternal/sleef/src/common", - "-Iexternal/sleef/src/libm", -] - -SLEEF_PUBLIC_INCLUDES = [ - "build/include", -] - -SLEEF_VISIBILITY = [ - "//visibility:public", -] - -cc_binary( - name = "mkalias", - srcs = [ - "src/libm/funcproto.h", - "src/libm/mkalias.c", - ], -) - -genrule( - name = "alias_avx512f_h", - outs = ["alias_avx512f.h"], - cmd = "{ " + "; ".join([ - "$(location :mkalias) -16 __m512 __m512i e avx512f", - "$(location :mkalias) 8 __m512d __m256i e avx512f", - ]) + "; } > $@", - tools = [":mkalias"], -) - -cc_binary( - name = "mkdisp", - srcs = [ - "src/libm/funcproto.h", - "src/libm/mkdisp.c", - ], - copts = SLEEF_COPTS, -) - -genrule( - name = "dispavx_c", - srcs = ["src/libm/dispavx.c.org"], - outs = ["dispavx.c"], - cmd = "{ cat $(location src/libm/dispavx.c.org); $(location :mkdisp) 4 8 __m256d __m256 __m128i avx fma4 avx2; } > $@", - tools = [":mkdisp"], -) - -genrule( - name = "dispsse_c", - srcs = ["src/libm/dispsse.c.org"], - outs = ["dispsse.c"], - cmd = "{ cat $(location src/libm/dispsse.c.org); $(location :mkdisp) 2 4 __m128d __m128 __m128i sse2 sse4 avx2128; } > $@", - tools = [":mkdisp"], -) - -cc_binary( - name = "mkrename", - srcs = [ - "src/libm/funcproto.h", - "src/libm/mkrename.c", - ], -) - -genrule( - name = "renameavx_h", - outs = ["renameavx.h"], - cmd = "$(location :mkrename) cinz_ 4 8 avx > $@", - tools = [":mkrename"], -) - -genrule( - name = "renameavx2_h", - outs = ["renameavx2.h"], - cmd = "$(location :mkrename) finz_ 4 8 avx2 > $@", - tools = [":mkrename"], -) - -genrule( - name = "renameavx2128_h", - outs = ["renameavx2128.h"], - cmd = "$(location :mkrename) finz_ 2 4 avx2128 > $@", - tools = [":mkrename"], -) - -genrule( - name = "renameavx512f_h", - outs = ["renameavx512f.h"], - cmd = "$(location :mkrename) finz_ 8 16 avx512f > $@", - tools = [":mkrename"], -) - -genrule( - name = "renameavx512fnofma_h", - outs = ["renameavx512fnofma.h"], - cmd = "$(location :mkrename) cinz_ 8 16 avx512fnofma > $@", - tools = [":mkrename"], -) - -genrule( - name = "renamefma4_h", - outs = ["renamefma4.h"], - cmd = "$(location :mkrename) finz_ 4 8 fma4 > $@", - tools = [":mkrename"], -) - -genrule( - name = "renamepurec_scalar_h", - outs = ["renamepurec_scalar.h"], - cmd = "$(location :mkrename) cinz_ 1 1 purec > $@", - tools = [":mkrename"], -) - -genrule( - name = "renamepurecfma_scalar_h", - outs = ["renamepurecfma_scalar.h"], - cmd = "$(location :mkrename) finz_ 1 1 purecfma > $@", - tools = [":mkrename"], -) - -genrule( - name = "renamesse2_h", - outs = ["renamesse2.h"], - cmd = "$(location :mkrename) cinz_ 2 4 sse2 > $@", - tools = [":mkrename"], -) - -genrule( - name = "renamesse4_h", - outs = ["renamesse4.h"], - cmd = "$(location :mkrename) cinz_ 2 4 sse4 > $@", - tools = [":mkrename"], -) - -genrule( - name = "sleef_h", - srcs = [ - "src/libm/sleeflibm_header.h.org.in", - "src/libm/sleeflibm_footer.h.org", - ], - outs = ["build/include/sleef.h"], - cmd = "{ " + "; ".join([ - "cat $(location src/libm/sleeflibm_header.h.org.in)", - "$(location :mkrename) cinz_ 2 4 __m128d __m128 __m128i __m128i __SSE2__", - "$(location :mkrename) cinz_ 2 4 __m128d __m128 __m128i __m128i __SSE2__ sse2", - "$(location :mkrename) cinz_ 2 4 __m128d __m128 __m128i __m128i __SSE2__ sse4", - "$(location :mkrename) cinz_ 4 8 __m256d __m256 __m128i \"struct { __m128i x, y; }\" __AVX__", - "$(location :mkrename) cinz_ 4 8 __m256d __m256 __m128i \"struct { __m128i x, y; }\" __AVX__ avx", - "$(location :mkrename) finz_ 4 8 __m256d __m256 __m128i \"struct { __m128i x, y; }\" __AVX__ fma4", - "$(location :mkrename) finz_ 4 8 __m256d __m256 __m128i __m256i __AVX__ avx2", - "$(location :mkrename) finz_ 2 4 __m128d __m128 __m128i __m128i __SSE2__ avx2128", - "$(location :mkrename) finz_ 8 16 __m512d __m512 __m256i __m512i __AVX512F__", - "$(location :mkrename) finz_ 8 16 __m512d __m512 __m256i __m512i __AVX512F__ avx512f", - "$(location :mkrename) cinz_ 8 16 __m512d __m512 __m256i __m512i __AVX512F__ avx512fnofma", - "$(location :mkrename) cinz_ 1 1 double float int32_t int32_t __STDC__ purec", - "$(location :mkrename) finz_ 1 1 double float int32_t int32_t FP_FAST_FMA purecfma", - "cat $(location src/libm/sleeflibm_footer.h.org)", - ]) + "; } > $@", - tools = [":mkrename"], -) - -cc_library( - name = "sleef", - srcs = [ - "src/libm/rempitab.c", - "src/libm/sleefdp.c", - "src/libm/sleefsp.c", - ], - hdrs = SLEEF_PUBLIC_HEADERS, - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DDORENAME=1", - "-DENABLEFLOAT128=1", - "-Wno-unused-result", - ], - includes = SLEEF_PUBLIC_INCLUDES, - # -lgcc resolves - # U __addtf3 - # U __eqtf2 - # U __fixtfdi - # U __floatditf - # U __gttf2 - # U __lttf2 - # U __multf3 - # U __subtf3 - # in bazel-bin/external/sleef/_objs/sleef/sleefqp.pic.o - linkopts = [ - "-lgcc", - ], - linkstatic = True, - visibility = SLEEF_VISIBILITY, - # The purpose of the lists in deps is to keep related pairs of - # libraries together. In particular, each pair that contains a *det* - # library originates with a sleef_cc_library(). - deps = [ - ":common", - ":dispavx", - ":dispsse", - ] + [ - ":sleefavx", - ":sleefdetavx", - ] + [ - ":sleefavx2", - ":sleefdetavx2", - ] + [ - ":sleefavx2128", - ":sleefdetavx2128", - ] + [ - ":sleefavx512f", - ":sleefdetavx512f", - ] + [ - ":sleefavx512fnofma", - ":sleefdetavx512fnofma", - ] + [ - ":sleeffma4", - ":sleefdetfma4", - ] + [ - ":sleefsse2", - ":sleefdetsse2", - ] + [ - ":sleefsse4", - ":sleefdetsse4", - ] + [ - ":sleefpurec_scalar", - ":sleefdetpurec_scalar", - ] + [ - ":sleefpurecfma_scalar", - ":sleefdetpurecfma_scalar", - ], - alwayslink = True, -) - -cc_library( - name = "common", - srcs = SLEEF_PRIVATE_HEADERS + [ - "src/common/common.c", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + [ - "-Wno-unused-result", - ], - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) - -cc_library( - name = "dispavx", - srcs = SLEEF_PRIVATE_HEADERS + SLEEF_PUBLIC_HEADERS + [ - ":dispavx_c", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DENABLE_AVX2=1", - "-DENABLE_FMA4=1", - "-mavx", - ], - includes = SLEEF_PUBLIC_INCLUDES, - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) - -cc_library( - name = "dispsse", - srcs = SLEEF_PRIVATE_HEADERS + SLEEF_PUBLIC_HEADERS + [ - ":dispsse_c", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DENABLE_AVX2=1", - "-DENABLE_FMA4=1", - "-msse2", - ], - includes = SLEEF_PUBLIC_INCLUDES, - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) - -sleef_cc_library( - name = "sleefavx512f", - srcs = SLEEF_PRIVATE_HEADERS + [ - "src/libm/sleefsimddp.c", - "src/libm/sleefsimdsp.c", - ":alias_avx512f_h", - ":renameavx512f_h", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DDORENAME=1", - "-DALIAS_NO_EXT_SUFFIX=\\\"alias_avx512f.h\\\"", - "-DENABLE_AVX512F=1", - "-mavx512f", - ], - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) - -sleef_cc_library( - name = "sleefavx512fnofma", - srcs = SLEEF_PRIVATE_HEADERS + [ - "src/libm/sleefsimddp.c", - "src/libm/sleefsimdsp.c", - ":renameavx512fnofma_h", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DDORENAME=1", - "-DENABLE_AVX512FNOFMA=1", - "-mavx512f", - ], - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) - -sleef_cc_library( - name = "sleefavx", - srcs = SLEEF_PRIVATE_HEADERS + [ - "src/libm/sleefsimddp.c", - "src/libm/sleefsimdsp.c", - ":renameavx_h", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DDORENAME=1", - "-DENABLE_AVX=1", - "-mavx", - ], - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) - -sleef_cc_library( - name = "sleefavx2", - srcs = SLEEF_PRIVATE_HEADERS + [ - "src/libm/sleefsimddp.c", - "src/libm/sleefsimdsp.c", - ":renameavx2_h", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DDORENAME=1", - "-DENABLE_AVX2=1", - "-mavx2", - "-mfma", - ], - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) - -sleef_cc_library( - name = "sleefavx2128", - srcs = SLEEF_PRIVATE_HEADERS + [ - "src/libm/sleefsimddp.c", - "src/libm/sleefsimdsp.c", - ":renameavx2128_h", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DDORENAME=1", - "-DENABLE_AVX2128=1", - "-mavx2", - "-mfma", - ], - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) - -sleef_cc_library( - name = "sleeffma4", - srcs = SLEEF_PRIVATE_HEADERS + [ - "src/libm/sleefsimddp.c", - "src/libm/sleefsimdsp.c", - ":renamefma4_h", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DDORENAME=1", - "-DENABLE_FMA4=1", - "-mfma4", - ], - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) - -sleef_cc_library( - name = "sleefsse2", - srcs = SLEEF_PRIVATE_HEADERS + [ - "src/libm/sleefsimddp.c", - "src/libm/sleefsimdsp.c", - ":renamesse2_h", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DDORENAME=1", - "-DENABLE_SSE2=1", - "-msse2", - ], - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) - -sleef_cc_library( - name = "sleefsse4", - srcs = SLEEF_PRIVATE_HEADERS + [ - "src/libm/sleefsimddp.c", - "src/libm/sleefsimdsp.c", - ":renamesse4_h", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DDORENAME=1", - "-DENABLE_SSE4=1", - "-msse4.1", - ], - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) - -sleef_cc_library( - name = "sleefpurec_scalar", - srcs = SLEEF_PRIVATE_HEADERS + [ - "src/libm/sleefsimddp.c", - "src/libm/sleefsimdsp.c", - ":renamepurec_scalar_h", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DDORENAME=1", - "-DENABLE_PUREC_SCALAR=1", - ], - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) - -sleef_cc_library( - name = "sleefpurecfma_scalar", - srcs = SLEEF_PRIVATE_HEADERS + [ - "src/libm/sleefsimddp.c", - "src/libm/sleefsimdsp.c", - ":renamepurecfma_scalar_h", - ], - copts = SLEEF_PRIVATE_INCLUDES + SLEEF_COPTS + SLEEF_COMMON_TARGET_COPTS + [ - "-DDORENAME=1", - "-DENABLE_PURECFMA_SCALAR=1", - "-mavx2", - "-mfma", - ], - linkstatic = True, - visibility = SLEEF_VISIBILITY, - alwayslink = True, -) diff --git a/third_party/substitution.bzl b/third_party/substitution.bzl deleted file mode 100644 index 9310985e50ea5..0000000000000 --- a/third_party/substitution.bzl +++ /dev/null @@ -1,81 +0,0 @@ -# This Bazel rules file is derived from https://github.com/tensorflow/tensorflow/blob/master/third_party/common.bzl - -# Rule for simple expansion of template files. This performs a simple -# search over the template file for the keys in substitutions, -# and replaces them with the corresponding values. -# -# Typical usage: -# load("/tools/build_rules/template_rule", "template_rule") -# template_rule( -# name = "ExpandMyTemplate", -# src = "my.template", -# out = "my.txt", -# substitutions = { -# "$VAR1": "foo", -# "$VAR2": "bar", -# } -# ) -# -# Args: -# name: The name of the rule. -# template: The template file to expand -# out: The destination of the expanded file -# substitutions: A dictionary mapping strings to their substitutions - -def template_rule_impl(ctx): - ctx.actions.expand_template( - template = ctx.file.src, - output = ctx.outputs.out, - substitutions = ctx.attr.substitutions, - ) - -template_rule = rule( - attrs = { - "out": attr.output(mandatory = True), - "src": attr.label( - mandatory = True, - allow_single_file = True, - ), - "substitutions": attr.string_dict(mandatory = True), - }, - # output_to_genfiles is required for header files. - output_to_genfiles = True, - implementation = template_rule_impl, -) - -# Header template rule is an extension of template substitution rule -# That also makes this header a valid dependency for cc_library -# From https://stackoverflow.com/a/55407399 -def header_template_rule_impl(ctx): - ctx.actions.expand_template( - template = ctx.file.src, - output = ctx.outputs.out, - substitutions = ctx.attr.substitutions, - ) - return [ - # create a provider which says that this - # out file should be made available as a header - CcInfo(compilation_context = cc_common.create_compilation_context( - - # pass out the include path for finding this header - system_includes = depset([ctx.attr.include, ctx.outputs.out.dirname, ctx.bin_dir.path]), - - # and the actual header here. - headers = depset([ctx.outputs.out]), - )), - ] - -header_template_rule = rule( - attrs = { - "include": attr.string(), - "out": attr.output(mandatory = True), - "src": attr.label( - mandatory = True, - allow_single_file = True, - ), - "substitutions": attr.string_dict(mandatory = True), - }, - # output_to_genfiles is required for header files. - output_to_genfiles = True, - implementation = header_template_rule_impl, -) diff --git a/third_party/tensorflow_cuda_bazel_build/cuda/build_defs.bzl b/third_party/tensorflow_cuda_bazel_build/cuda/build_defs.bzl deleted file mode 100755 index a394b6ce92044..0000000000000 --- a/third_party/tensorflow_cuda_bazel_build/cuda/build_defs.bzl +++ /dev/null @@ -1,31 +0,0 @@ -# Macros for building CUDA code. -def if_cuda(if_true, if_false = []): - """Shorthand for select()'ing on whether we're building with CUDA. - - Returns a select statement which evaluates to if_true if we're building - with CUDA enabled. Otherwise, the select statement evaluates to if_false. - - """ - return select({ - "@local_config_cuda//cuda:using_clang": if_true, - "@local_config_cuda//cuda:using_nvcc": if_true, - "//conditions:default": if_false, - }) - -def cuda_default_copts(): - """Default options for all CUDA compilations.""" - return if_cuda(["-x", "cuda", "-DGOOGLE_CUDA=1"] + []) - -def cuda_is_configured(): - """Returns true if CUDA was enabled during the configure process.""" - return True - -def if_cuda_is_configured(x): - """Tests if the CUDA was enabled during the configure process. - - Unlike if_cuda(), this does not require that we are building with - --config=cuda. Used to allow non-CUDA code to depend on CUDA libraries. - """ - if cuda_is_configured(): - return x - return [] diff --git a/third_party/tensorpipe.BUILD b/third_party/tensorpipe.BUILD deleted file mode 100644 index 9c074fcd54ac8..0000000000000 --- a/third_party/tensorpipe.BUILD +++ /dev/null @@ -1,178 +0,0 @@ -load("@rules_cc//cc:defs.bzl", "cc_library") -load("@pytorch//third_party:substitution.bzl", "header_template_rule") - -LIBUV_COMMON_SRCS = [ - "third_party/libuv/src/fs-poll.c", - "third_party/libuv/src/idna.c", - "third_party/libuv/src/inet.c", - "third_party/libuv/src/random.c", - "third_party/libuv/src/strscpy.c", - "third_party/libuv/src/strtok.c", - "third_party/libuv/src/threadpool.c", - "third_party/libuv/src/timer.c", - "third_party/libuv/src/uv-common.c", - "third_party/libuv/src/uv-data-getter-setters.c", - "third_party/libuv/src/version.c", -] - -LIBUV_POSIX_SRCS = [ - "third_party/libuv/src/unix/async.c", - "third_party/libuv/src/unix/core.c", - "third_party/libuv/src/unix/dl.c", - "third_party/libuv/src/unix/fs.c", - "third_party/libuv/src/unix/getaddrinfo.c", - "third_party/libuv/src/unix/getnameinfo.c", - "third_party/libuv/src/unix/loop.c", - "third_party/libuv/src/unix/loop-watcher.c", - "third_party/libuv/src/unix/pipe.c", - "third_party/libuv/src/unix/poll.c", - "third_party/libuv/src/unix/process.c", - "third_party/libuv/src/unix/random-devurandom.c", - "third_party/libuv/src/unix/signal.c", - "third_party/libuv/src/unix/stream.c", - "third_party/libuv/src/unix/tcp.c", - "third_party/libuv/src/unix/thread.c", - "third_party/libuv/src/unix/tty.c", - "third_party/libuv/src/unix/udp.c", -] - -LIBUV_LINUX_SRCS = LIBUV_POSIX_SRCS + [ - "third_party/libuv/src/unix/proctitle.c", - "third_party/libuv/src/unix/linux.c", - "third_party/libuv/src/unix/procfs-exepath.c", - "third_party/libuv/src/unix/random-getrandom.c", - "third_party/libuv/src/unix/random-sysctl-linux.c", -] - -cc_library( - name = "libuv", - srcs = LIBUV_COMMON_SRCS + LIBUV_LINUX_SRCS, - includes = [ - "third_party/libuv/include", - "third_party/libuv/src", - ], - hdrs = glob( - [ - "third_party/libuv/include/*.h", - "third_party/libuv/include/uv/*.h", - "third_party/libuv/src/*.h", - "third_party/libuv/src/unix/*.h", - ], - ), - copts = ["-D_GNU_SOURCE"], - visibility = ["//visibility:public"], -) - -cc_library( - name = "libnop", - srcs = [], - includes = ["third_party/libnop/include"], - hdrs = glob(["third_party/libnop/include/**/*.h"]), -) - -header_template_rule( - name = "tensorpipe_cpu_config_header", - src = "tensorpipe/config.h.in", - out = "tensorpipe/config.h", - substitutions = { - "#cmakedefine01 TENSORPIPE_HAS_SHM_TRANSPORT": "#define TENSORPIPE_HAS_SHM_TRANSPORT 1", - "#cmakedefine01 TENSORPIPE_HAS_IBV_TRANSPORT": "#define TENSORPIPE_HAS_IBV_TRANSPORT 1", - "#cmakedefine01 TENSORPIPE_HAS_CMA_CHANNEL": "#define TENSORPIPE_HAS_CMA_CHANNEL 1", - }, -) - -header_template_rule( - name = "tensorpipe_cuda_config_header", - src = "tensorpipe/config_cuda.h.in", - out = "tensorpipe/config_cuda.h", - substitutions = { - "#cmakedefine01 TENSORPIPE_HAS_CUDA_IPC_CHANNEL": "#define TENSORPIPE_HAS_CUDA_IPC_CHANNEL 1", - "#cmakedefine01 TENSORPIPE_HAS_CUDA_GDR_CHANNEL": "#define TENSORPIPE_HAS_CUDA_GDR_CHANNEL 1", - }, -) - -# We explicitly list the CUDA headers & sources, and we consider everything else -# as CPU (using a catch-all glob). This is both because there's fewer CUDA files -# (thus making it easier to list them exhaustively) and because it will make it -# more likely to catch a misclassified file: if we forget to mark a file as CUDA -# we'll try to build it on CPU and that's likely to fail. - -TENSORPIPE_CUDA_HEADERS = [ - "tensorpipe/tensorpipe_cuda.h", - "tensorpipe/channel/cuda_basic/*.h", - "tensorpipe/channel/cuda_gdr/*.h", - "tensorpipe/channel/cuda_ipc/*.h", - "tensorpipe/channel/cuda_xth/*.h", - "tensorpipe/common/cuda.h", - "tensorpipe/common/cuda_buffer.h", - "tensorpipe/common/cuda_lib.h", - "tensorpipe/common/cuda_loop.h", - "tensorpipe/common/nvml_lib.h", -] - -TENSORPIPE_CUDA_SOURCES = [ - "tensorpipe/channel/cuda_basic/*.cc", - "tensorpipe/channel/cuda_gdr/*.cc", - "tensorpipe/channel/cuda_ipc/*.cc", - "tensorpipe/channel/cuda_xth/*.cc", - "tensorpipe/common/cuda_buffer.cc", - "tensorpipe/common/cuda_loop.cc", -] - -TENSORPIPE_CPU_HEADERS = glob( - [ - "tensorpipe/*.h", - "tensorpipe/channel/*.h", - "tensorpipe/channel/*/*.h", - "tensorpipe/common/*.h", - "tensorpipe/core/*.h", - "tensorpipe/transport/*.h", - "tensorpipe/transport/*/*.h", - ], - exclude=TENSORPIPE_CUDA_HEADERS) - -TENSORPIPE_CPU_SOURCES = glob( - [ - "tensorpipe/*.cc", - "tensorpipe/channel/*.cc", - "tensorpipe/channel/*/*.cc", - "tensorpipe/common/*.cc", - "tensorpipe/core/*.cc", - "tensorpipe/transport/*.cc", - "tensorpipe/transport/*/*.cc", - ], - exclude=TENSORPIPE_CUDA_SOURCES) - -cc_library( - name = "tensorpipe_cpu", - srcs = TENSORPIPE_CPU_SOURCES, - hdrs = TENSORPIPE_CPU_HEADERS + [":tensorpipe_cpu_config_header"], - includes = [ - ".", - ], - copts = [ - "-std=c++20", - ], - visibility = ["//visibility:public"], - deps = [ - ":libnop", - ":libuv", - ], -) - -cc_library( - name = "tensorpipe_cuda", - srcs = glob(TENSORPIPE_CUDA_SOURCES), - hdrs = glob(TENSORPIPE_CUDA_HEADERS) + [":tensorpipe_cuda_config_header"], - includes = [ - ".", - ], - copts = [ - "-std=c++20", - ], - visibility = ["//visibility:public"], - deps = [ - ":tensorpipe_cpu", - "@cuda", - ], -) diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index aea099c6b53d2..4ea930d029ddf 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -147,6 +147,29 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F deps = XNN_COMMON_MICROKERNEL_EXPORTED_DEPS, ) + fb_xplat_cxx_library( + name = "ukernels_wasm", + srcs = select({ + "DEFAULT": [], + "ovr_config//runtime:wasm-emscripten": prod_srcs_for_arch_wrapper("wasm"), + }), + headers = get_xnnpack_headers(), + header_namespace = "", + apple_sdks = (IOS, MACOSX), + compiler_flags = [ + "-O2", + "-fno-fast-math", + "-fno-math-errno", + "-ffp-contract=off", + ] + WASM_EMSCRIPTEN_COMPILER_FLAGS, + labels = labels, + fbandroid_link_whole = True, + preferred_linkage = "static", + preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, + visibility = ["PUBLIC"], + deps = XNN_COMMON_MICROKERNEL_EXPORTED_DEPS, + ) + fb_xplat_cxx_library( name = "ukernels_sse", srcs = select({ @@ -1828,6 +1851,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F "ovr_config//runtime:arm64-linux-ubuntu-neon": [":arm64_lib"], "ovr_config//runtime:fbcode-arm64": [":arm64_lib"], "ovr_config//runtime:platform010": [":x86_and_x86_64_lib"], + "ovr_config//runtime:wasm-emscripten": [":ukernels_wasm"], }), ) diff --git a/third_party/xpu.txt b/third_party/xpu.txt index 03c3d150c8f7c..a475727240ef6 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -c6351981d75efe2cd213e97e0b720f818da84e5e +5aa1a01d53822166ff64db98d566bb871e419cc6 diff --git a/tools/amd_build/build_amd.py b/tools/amd_build/build_amd.py index adceecd1c1c8c..7f9289cae498d 100755 --- a/tools/amd_build/build_amd.py +++ b/tools/amd_build/build_amd.py @@ -109,6 +109,7 @@ "torch/*", "tools/autograd/templates/python_variable_methods.cpp", "torch/csrc/stable/*", + "test/cpp/c10d/*", ] includes = [os.path.join(proj_dir, include) for include in includes] diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 2894cff012dec..7a6df216205da 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -769,13 +769,22 @@ - name: grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor input, grid: "grad.defined() ? grid_sampler_2d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple()" +- name: grid_sampler_2d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor) + grad_output, input, grid: grid_sampler_2d_double_backward(grads[0], grads[1], grad_output, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) + - name: grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor input, grid: "grad.defined() ? grid_sampler_3d_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) : std::tuple()" +- name: grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor) + grad_output, input, grid: grid_sampler_3d_double_backward(grads[0], grads[1], grad_output, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) + # See NOTE [ grid_sample CPU fallback ] - name: _grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor input, grid: "grad.defined() ? _grid_sampler_2d_cpu_fallback_backward(grad, input, grid, interpolation_mode, padding_mode, align_corners) : std::tuple()" +- name: _grid_sampler_2d_cpu_fallback_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor) + grad_output, input, grid: grid_sampler_2d_double_backward(grads[0], grads[1], grad_output, input, grid, interpolation_mode, padding_mode, align_corners, grad_input_mask) + - name: gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) self: zeros_like(self) result: self_t.zero_() @@ -1201,6 +1210,11 @@ self: scale_grad_by_count(restore_reduced_dims(grad, dim, keepdim), restore_reduced_dims(result, dim, keepdim) == self, dim) result: amaxamin_jvp(self_p, self_t, result, dim, keepdim) +- name: aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max) + self: aminmax_backward(self, dim, keepdim, grad_min, grad_max, min, max) + min: aminmax_jvp(self_p, self_t, min, dim, keepdim) + max: aminmax_jvp(self_p, self_t, max, dim, keepdim) + - name: mm(Tensor self, Tensor mat2) -> Tensor self: mm_mat1_backward(grad, mat2, self.sym_sizes(), self.sym_strides(), self.layout(), 1) mat2: mm_mat2_backward(grad, self, mat2.sym_sizes(), mat2.sym_strides(), mat2.layout(), 1) @@ -1306,8 +1320,9 @@ reserve: not_implemented("batch_norm_backward reserve") - name: nextafter(Tensor self, Tensor other) -> Tensor - self: not_implemented("nextafter") - other: not_implemented("nextafter") + self: at::where(self != other, grad, 0) + other: zeros_like(other) + result: at::where(self_p != other_p, self_t, 0) - name: norm.Scalar(Tensor self, Scalar p=2) -> Tensor self: norm_backward(grad, self, p, result) @@ -2277,6 +2292,10 @@ self: _upsample_bicubic2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) result: auto_linear +- name: _upsample_lanczos2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + self: _upsample_lanczos2d_aa_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_h, scales_w) + result: auto_linear + - name: upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor self: upsample_trilinear3d_backward_symint(grad, output_size, self.sym_sizes(), align_corners, scales_d, scales_h, scales_w) result: auto_linear @@ -2676,6 +2695,10 @@ grad_output: _upsample_bicubic2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w) result: auto_linear +- name: _upsample_lanczos2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + grad_output: _upsample_lanczos2d_aa_symint(grad, output_size, align_corners, scales_h, scales_w) + result: auto_linear + - name: upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor grad_output: upsample_trilinear3d_symint(grad, output_size, align_corners, scales_d, scales_h, scales_w) result: auto_linear @@ -2733,6 +2756,9 @@ - name: cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output self, grid: "grad.defined() ? cudnn_grid_sampler_backward(self, grid, grad) : std::tuple()" +- name: cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) -> (Tensor grad_self, Tensor grad_grid) + grad_output, self, grid: grid_sampler_2d_double_backward(grads[0], grads[1], grad_output, self, grid, 0, 0, true, grad_input_mask) + - name: cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid theta: cudnn_affine_grid_generator_backward(grad, N, C, H, W) diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index 56e622d38d65d..7513e26c817e7 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -869,7 +869,6 @@ def save_var(var: SavedAttribute, is_output: bool) -> None: saved_variables.append(f"{type.cpp_type()} {name};") if type in MISC_GETTER_DEFS: - # pyrefly: ignore [bad-index, index-error] getter_def, body = MISC_GETTER_DEFS[type] getter_definitions.append( getter_def.substitute(op=info.op, name=name, body=body) @@ -1040,7 +1039,6 @@ def emit_derivative( unpack_ivalues = [] for typ, name in zip(apply_functional_args_ref_types, apply_functional_args): typ = typ.removesuffix("&") - # pyrefly: ignore [bad-argument-type] unpack_ivalues.append(f"auto {name} = packed_args.unpack<{typ}>();") schema_args = [f"std::array"] diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index a18485a902d35..ad3238feb2d42 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -92,10 +92,15 @@ _SKIP_PYTHON_BINDINGS = [ "alias", "contiguous", + "dim", + "get_device", + "is_contiguous", "is_cuda", "is_sparse", "is_sparse_csr", + "numel", "size", + "storage_offset", "stride", "sym_is_contiguous", "sym_size", @@ -972,7 +977,7 @@ def gen_has_torch_function_check( if noarg: if method: return f"""\ -if(check_has_torch_function(self_)) {{ +if (has_torch_function(self_)) {{ return handle_torch_function(self_, "{name}"); }} """ diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 0f41f72bc9106..7a7b75b87c346 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -262,6 +262,7 @@ "alias", "atan", "ldexp", + "linear", "log", "log10", "log1p", @@ -1902,14 +1903,12 @@ def emit_any_has_forward_grad() -> list[str]: ) ) cur_derivative_conditions.append( - # pyrefly: ignore [bad-argument-type] FW_DERIVATIVE_CHECK_TEMPLATE.substitute( req_inp=inp_name + "[i]" ) ) else: cur_derivative_conditions.append( - # pyrefly: ignore [bad-argument-type] FW_DERIVATIVE_CHECK_TEMPLATE.substitute(req_inp=inp_name) ) diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index e35f66dbe173a..9aa0093d87f11 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -97,10 +97,8 @@ def add_view_copy_derivatives( # prefer manually-defined derivatives if any # pyrefly: ignore [unbound-name] if len(view_copy_differentiability_infos) > 0 and fn_schema not in infos: - # pyrefly: ignore [unbound-name] if fn_schema is None: raise AssertionError("Expected fn_schema to be non-None") - # pyrefly: ignore [unbound-name] view_infos[fn_schema] = view_copy_differentiability_infos infos.update(view_infos) @@ -176,7 +174,7 @@ def load_derivatives( add_view_copy_derivatives(infos, view_groups) - # cache both loaded infos as well a a set of all the dispatch_keys/aliases + # cache both loaded infos as well as a set of all the dispatch_keys/aliases # that appear in derivatives.yaml. used_dispatch_keys is useful for generating # VariableType.cpp where we need a TORCH_LIBRARY_IMPL for every autograd dispatch key used _GLOBAL_LOAD_DERIVATIVE_CACHE[key] = infos, used_dispatch_keys @@ -409,7 +407,6 @@ def repl(m: Any) -> str: for arg_name in all_arg_names: if arg_name in diff_arg_names: arg_name = arg_name + "_t" - # pyrefly: ignore [bad-argument-type] new_args.append(arg_name) # TODO we are trolling @@ -961,7 +958,6 @@ def stride_expr(name: str) -> str: + f".sym_strides(), which returned a c10::SymIntArrayRef. formula={formula}" ) for nctype in nctypes: - # pyrefly: ignore [bad-assignment] name = ( nctype.name.name if isinstance(nctype.name, SpecialArgName) else nctype.name ) diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index bfc5b80835c4b..a5e2061c14a73 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -63,7 +63,7 @@ namespace torch::autograd { static PyObject * THPVariable__is_view(PyObject *self, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "_is_view", args); } auto& self_ = THPVariable_Unpack(self); @@ -80,7 +80,7 @@ static PyObject * THPVariable__is_view(PyObject *self, PyObject* args) static PyObject * THPVariable_apply_(PyObject* self, PyObject* arg) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { auto args = py::make_tuple(py::handle(arg)); return handle_torch_function(self, "apply_", args.ptr()); } @@ -171,7 +171,7 @@ static PyObject * THPVariable_stride(PyObject* self, PyObject* args, PyObject* k static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self_)) { + if (has_torch_function(self_)) { return handle_torch_function(self_, "get_device", args, nullptr); } auto& self = THPVariable_Unpack(self_); @@ -182,7 +182,7 @@ static PyObject * THPVariable_get_device(PyObject* self_, PyObject* args) static PyObject * THPVariable_has_names(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self_)) { + if (has_torch_function(self_)) { return handle_torch_function(self_, "has_names", args); } auto& self = THPVariable_Unpack(self_); @@ -194,7 +194,7 @@ static PyObject * THPVariable_has_names(PyObject* self_, PyObject* args) static PyObject * THPVariable_data_ptr(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self_)) { + if (has_torch_function(self_)) { return handle_torch_function(self_, "data_ptr", args); } auto& self = THPVariable_Unpack(self_); @@ -203,10 +203,24 @@ static PyObject * THPVariable_data_ptr(PyObject* self_, PyObject* args) } // implemented on the python object to avoid dispatch overhead -static PyObject * THPVariable_storage_offset(PyObject* self_, PyObject* args) +// Unlike data_ptr(), this is a read-only access that does not trigger +// copy-on-write materialization. +static PyObject * THPVariable_const_data_ptr(PyObject* self_, PyObject* args) { HANDLE_TH_ERRORS if (check_has_torch_function(self_)) { + return handle_torch_function(self_, "const_data_ptr", args); + } + auto& self = THPVariable_Unpack(self_); + return wrap(const_cast(self.const_data_ptr())); + END_HANDLE_TH_ERRORS +} + +// implemented on the python object to avoid dispatch overhead +static PyObject * THPVariable_storage_offset(PyObject* self_, PyObject* args) +{ + HANDLE_TH_ERRORS + if (has_torch_function(self_)) { return handle_torch_function(self_, "storage_offset"); } auto& self = THPVariable_Unpack(self_); @@ -218,7 +232,7 @@ static PyObject * THPVariable_storage_offset(PyObject* self_, PyObject* args) static PyObject * THPVariable_dim(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "dim", args); } auto& self_ = THPVariable_Unpack(self); @@ -230,7 +244,7 @@ static PyObject * THPVariable_dim(PyObject* self, PyObject* args) static PyObject * THPVariable_numel(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "numel", args); } auto& self_ = THPVariable_Unpack(self); @@ -328,7 +342,7 @@ static T dispatch_to(const Tensor & self) { static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "__float__", args); } jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW); @@ -340,7 +354,7 @@ static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) { static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "__complex__", args); } jit::tracer::warn("Converting a tensor to a Python complex", jit::tracer::WARN_PYTHON_DATAFLOW); @@ -352,7 +366,7 @@ static PyObject * THPVariable_complex_scalar(PyObject* self, PyObject* args) { static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "__int__", args); } jit::tracer::warn("Converting a tensor to a Python integer", jit::tracer::WARN_PYTHON_DATAFLOW); @@ -371,7 +385,7 @@ static PyObject * THPVariable_integral_scalar(PyObject* self, PyObject* args) { // called when used as a slice. static PyObject * THPVariable_index_scalar(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "__index__", args); } auto& self_ = THPVariable_Unpack(self); @@ -392,7 +406,7 @@ static Tensor dispatch_invert(const Tensor & self) { static PyObject * THPVariable_invert(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "__invert__", args); } auto& self_ = THPVariable_Unpack(self); @@ -789,7 +803,7 @@ static PyObject * THPVariable_bfloat16(PyObject* self, PyObject* args, PyObject* static PyObject * THPVariable_element_size(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "element_size", args); } auto& self_ = THPVariable_Unpack(self); @@ -867,7 +881,7 @@ static PyObject * THPVariable_is_contiguous(PyObject* self_, PyObject* args, PyO auto r = parser.parse(self_, args, kwargs, parsed_args); if(r.has_torch_function()){ - return handle_torch_function(r, self_, args, kwargs, PyObject_Type(self_), "torch.Tensor"); + return handle_torch_function(r, self_, args, kwargs, reinterpret_cast(Py_TYPE(self_)), "torch.Tensor"); } auto memory_format = r.memoryformat(0); @@ -880,7 +894,7 @@ static PyObject * THPVariable_is_contiguous(PyObject* self_, PyObject* args, PyO static PyObject * THPVariable_item(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "item", args); } jit::tracer::warn("Converting a tensor to a Python number", jit::tracer::WARN_PYTHON_DATAFLOW); @@ -952,7 +966,7 @@ static PyObject * THPVariable_map2_(PyObject* self, PyObject* args, PyObject* kw static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "new", args, kwargs); } auto& self_ = THPVariable_Unpack(self); @@ -964,7 +978,7 @@ static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwar static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "new_tensor", args, kwargs); } auto& self_ = THPVariable_Unpack(self); @@ -976,7 +990,7 @@ static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObjec static PyObject * THPVariable_storage(PyObject* self, PyObject* arg) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "untyped_storage"); } auto& self_ = THPVariable_Unpack(self); @@ -1027,7 +1041,7 @@ static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwarg static PyObject * THPVariable_tolist(PyObject* self, PyObject* args) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function(self, "tolist", args); } jit::tracer::warn("Converting a tensor to a Python list", jit::tracer::WARN_PYTHON_DATAFLOW); @@ -1092,7 +1106,7 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa ${py_methods} static PyObject * THPVariable_bool_scalar(PyObject* self, PyObject* args) { - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { HANDLE_TH_ERRORS return handle_torch_function(self, "__bool__", args); END_HANDLE_TH_ERRORS @@ -1297,6 +1311,7 @@ PyMethodDef variable_methods[] = { {"mtia", castPyCFunctionWithKeywords(THPVariable_mtia), METH_VARARGS | METH_KEYWORDS, nullptr}, {"xpu", castPyCFunctionWithKeywords(THPVariable_xpu), METH_VARARGS | METH_KEYWORDS, nullptr}, {"ipu", castPyCFunctionWithKeywords(THPVariable_ipu), METH_VARARGS | METH_KEYWORDS, nullptr}, + {"const_data_ptr", THPVariable_const_data_ptr, METH_NOARGS, nullptr}, {"data_ptr", THPVariable_data_ptr, METH_NOARGS, nullptr}, {"dim", THPVariable_dim, METH_NOARGS, nullptr}, {"has_names", THPVariable_has_names, METH_NOARGS, nullptr}, diff --git a/tools/bazel_tools/BUILD.bazel b/tools/bazel_tools/BUILD.bazel deleted file mode 100644 index f4c37fb7389bd..0000000000000 --- a/tools/bazel_tools/BUILD.bazel +++ /dev/null @@ -1,5 +0,0 @@ -sh_binary( - name = "shellwrap", - srcs = ["shellwrap.sh"], - visibility = ["//visibility:public"], -) diff --git a/tools/bazel_tools/shellwrap.sh b/tools/bazel_tools/shellwrap.sh deleted file mode 100755 index 712788ae09e06..0000000000000 --- a/tools/bazel_tools/shellwrap.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash - -# This script is helpful in entering an interactive shell from a bazel build -# before running a given bazel executable. -# This can provide a quick way to explore the sandbox directory and filesystem. -# Typical use is with -# -# bazel run --run_under=//tools/bazel:shell_wrapper //:target -# OR -# bazel run --config=shell //:target - -shell='/bin/bash' -rcfile='/tmp/pytorch_bazel_tools_shellwrap' -while [[ $# -gt 0 ]] ; do - case "$1" in - --shell_bin_path) - # path for the shell executable - shell="$2" - shift 2 - ;; - --rcfile) - # path for the file used to write the environment - rcfile="$2" - shift 2 - ;; - *) - # remaining arguments are part of the command for execution - break - ;; - esac -done - -if ! tty -s; then - echo 'A tty is not available.' - echo "Use \`bazel run\`, not \`bazel test\`." - exit 1 -fi - -NOCOLOR='\033[0m' -YELLOW='\033[1;33m' - -# store the environment in a file -export PYTORCH_SHELL_COMMAND=$* -echo "alias run=\"$*\"" > "$rcfile" -echo "PS1='\s-\v\$ '" >> "$rcfile" - -echo ===== -# print the execution command (command is yellow) -echo -e "alias run=${YELLOW}$PYTORCH_SHELL_COMMAND${NOCOLOR}" -echo ===== - -echo "Entering interactive shell at the execution root:" - -# quote escape all the arguments to use as a single input string -cmd="'$shell' --noprofile --rcfile '$rcfile'" - -# run the command in a script pseudo terminal and dump to null -/usr/bin/script -c "$cmd" -q /dev/null diff --git a/tools/build_pytorch_libs.py b/tools/build_pytorch_libs.py index 9d43de80f1298..2e30eb7e1e598 100644 --- a/tools/build_pytorch_libs.py +++ b/tools/build_pytorch_libs.py @@ -4,14 +4,8 @@ import platform import subprocess -from .optional_submodules import checkout_nccl from .setup_helpers.cmake import CMake, USE_NINJA -from .setup_helpers.env import ( - check_env_flag, - check_negative_env_flag, - IS_64BIT, - IS_WINDOWS, -) +from .setup_helpers.env import check_negative_env_flag, IS_64BIT, IS_WINDOWS def _get_vc_env(vc_arch: str) -> dict[str, str]: @@ -87,13 +81,6 @@ def build_pytorch( cmake: CMake, ) -> None: my_env = _create_build_env() - if ( - not check_negative_env_flag("USE_DISTRIBUTED") - and not check_negative_env_flag("USE_CUDA") - and not check_negative_env_flag("USE_NCCL") - and not check_env_flag("USE_SYSTEM_NCCL") - ): - checkout_nccl() build_test = not check_negative_env_flag("BUILD_TEST") cmake.generate( version, cmake_python_library, build_python, build_test, my_env, rerun_cmake diff --git a/tools/clean.py b/tools/clean.py new file mode 100644 index 0000000000000..e4f2b31adb8eb --- /dev/null +++ b/tools/clean.py @@ -0,0 +1,23 @@ +import glob +import os +import shutil +from pathlib import Path + + +def clean(): + """Clean, that is remove all files in .gitignore except in the NOT-CLEAN-FILES section.""" + ignores = Path(".gitignore").read_text(encoding="utf-8") + for wildcard in filter(None, ignores.splitlines()): + if wildcard.strip().startswith("#"): + if "BEGIN NOT-CLEAN-FILES" in wildcard: + # Marker is found and stop reading .gitignore. + break + # Ignore lines which begin with '#'. + else: + # Don't remove absolute paths from the system + wildcard = wildcard.lstrip("./") + for filename in glob.iglob(wildcard): + try: + os.remove(filename) + except OSError: + shutil.rmtree(filename, ignore_errors=True) diff --git a/tools/code_analyzer/gen_operators_yaml.py b/tools/code_analyzer/gen_operators_yaml.py index ff2284e29d462..078d56f1e6071 100644 --- a/tools/code_analyzer/gen_operators_yaml.py +++ b/tools/code_analyzer/gen_operators_yaml.py @@ -200,7 +200,6 @@ def create_debug_info_from_selected_models( asset = model_info["asset"] hash = model_info["md5_hash"] - # pyrefly: ignore [missing-attribute] asset_info = model_dict["asset_info"].setdefault(asset, {}) # pyrefly: ignore [missing-attribute] diff --git a/tools/code_coverage/package/tool/clang_coverage.py b/tools/code_coverage/package/tool/clang_coverage.py index 36c353558927e..53cc43887fd40 100644 --- a/tools/code_coverage/package/tool/clang_coverage.py +++ b/tools/code_coverage/package/tool/clang_coverage.py @@ -77,7 +77,7 @@ def export_target( if binary_file is None: raise Exception( # noqa: TRY002 f"{merged_file} doesn't have corresponding binary!" - ) # noqa: TRY002 + ) print_log("start to export: ", merged_file) # run export cmd_shared_library = ( diff --git a/tools/code_coverage/package/tool/summarize_jsons.py b/tools/code_coverage/package/tool/summarize_jsons.py index aebff7a922da8..331b59d8d6fc9 100644 --- a/tools/code_coverage/package/tool/summarize_jsons.py +++ b/tools/code_coverage/package/tool/summarize_jsons.py @@ -68,7 +68,7 @@ def is_intrested_file( # ignore files that are not belong to pytorch if platform == TestPlatform.OSS: - # pyrefly: ignore [import-error, missing-import] + # pyrefly: ignore [missing-import] from package.oss.utils import get_pytorch_folder if not file_path.startswith(get_pytorch_folder()): diff --git a/tools/config/BUILD b/tools/config/BUILD deleted file mode 100644 index ba13eda2bba7b..0000000000000 --- a/tools/config/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -load("@bazel_skylib//lib:selects.bzl", "selects") - -config_setting( - name = "cuda", - define_values = { - "cuda": "true", - }, -) - -# Even when building with --config=cuda, host targets should be built with cuda disabled -# as these targets will run on CI machines that have no GPUs. -selects.config_setting_group( - name = "cuda_enabled_and_capable", - match_all = [ - ":cuda", - ], -) - -# Configures the system to build with cuda using clang. -config_setting( - name = "cuda_clang", - define_values = { - "cuda_clang": "true", - }, -) - -# Indicates that cuda code should be compiled with nvcc -# Mostly exists to support _analysis_ of tensorflow; more work is needed to actually make this -# setting work. -config_setting( - name = "cuda_nvcc", - define_values = { - "cuda_nvcc": "true", - }, -) - -config_setting( - name = "thread_sanitizer", - define_values = {"thread_sanitizer": "1"}, - visibility = ["//visibility:public"], -) diff --git a/tools/config/defs.bzl b/tools/config/defs.bzl deleted file mode 100644 index f8a1e9dc16f26..0000000000000 --- a/tools/config/defs.bzl +++ /dev/null @@ -1,65 +0,0 @@ -""" - Macros for selecting with / without various GPU libraries. Most of these are meant to be used - directly by tensorflow in place of their build's own configure.py + bazel-gen system. -""" - -load("@bazel_skylib//lib:selects.bzl", "selects") - -def if_cuda(if_true, if_false = []): - """Helper for selecting based on the whether CUDA is configured. """ - return selects.with_or({ - "@//tools/config:cuda_enabled_and_capable": if_true, - "//conditions:default": if_false, - }) - -def if_tensorrt(if_true, if_false = []): - """Helper for selecting based on the whether TensorRT is configured. """ - return select({ - "//conditions:default": if_false, - }) - -def if_rocm(if_true, if_false = []): - """Helper for selecting based on the whether ROCM is configured. """ - return select({ - "//conditions:default": if_false, - }) - -def if_sycl(if_true, if_false = []): - """Helper for selecting based on the whether SYCL/ComputeCPP is configured.""" - - # NOTE: Tensorflow expects some strange behavior (see their if_sycl) if we - # actually plan on supporting this at some point. - return select({ - "//conditions:default": if_false, - }) - -def if_ccpp(if_true, if_false = []): - """Helper for selecting based on the whether ComputeCPP is configured. """ - return select({ - "//conditions:default": if_false, - }) - -def cuda_default_copts(): - return if_cuda(["-DGOOGLE_CUDA=1"]) - -def cuda_default_features(): - return if_cuda(["-per_object_debug_info", "-use_header_modules", "cuda_clang"]) - -def rocm_default_copts(): - return if_rocm(["-x", "rocm"]) - -def rocm_copts(opts = []): - return rocm_default_copts() + if_rocm(opts) - -def cuda_is_configured(): - # FIXME(dcollins): currently only used by tensorflow's xla stuff, which we aren't building. However bazel - # query hits it so this needs to be defined. Because bazel doesn't actually resolve config at macro expansion - # time, `select` can't be used here (since xla expects lists of strings and not lists of select objects). - # Instead, the xla build rules must be rewritten to use `if_cuda_is_configured` - return False - -def if_cuda_is_configured(x): - return if_cuda(x, []) - -def if_rocm_is_configured(x): - return if_rocm(x, []) diff --git a/tools/download_mnist.py b/tools/download_mnist.py index 7c9432eddee9f..d9bbe1f413f2a 100644 --- a/tools/download_mnist.py +++ b/tools/download_mnist.py @@ -24,7 +24,6 @@ def report_download_progress( file_size: int, ) -> None: if file_size != -1: - # pyrefly: ignore [no-matching-overload] percent = min(1, (chunk_number * chunk_size) / file_size) bar = "#" * int(64 * percent) sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%") diff --git a/tools/dynamo/gb_id_mapping.py b/tools/dynamo/gb_id_mapping.py index 35ec5fe6a0517..d6327222f897b 100644 --- a/tools/dynamo/gb_id_mapping.py +++ b/tools/dynamo/gb_id_mapping.py @@ -120,7 +120,6 @@ def extract_info_from_keyword(source: str, kw: ast.keyword) -> Any: evaluated_context = [] for value in kw.value.values: if isinstance(value, ast.FormattedValue): - # pyrefly: ignore [bad-argument-type] evaluated_context.append(f"{{{ast.unparse(value.value)}}}") elif isinstance(value, ast.Constant): # pyrefly: ignore [bad-argument-type] @@ -172,7 +171,6 @@ def find_unimplemented_calls( for kw in node.keywords: if kw.arg in info: - # pyrefly: ignore [unsupported-operation] info[kw.arg] = extract_info_from_keyword(source, kw) if info["gb_type"] is None: diff --git a/tools/dynamo/verify_dynamo.py b/tools/dynamo/verify_dynamo.py index 5591baa5552ec..b6a71d90b33df 100644 --- a/tools/dynamo/verify_dynamo.py +++ b/tools/dynamo/verify_dynamo.py @@ -6,7 +6,7 @@ import warnings -MIN_CUDA_VERSION = "11.6" +MIN_CUDA_VERSION = "12.1" MIN_ROCM_VERSION = "5.4" MIN_PYTHON_VERSION = (3, 10) @@ -174,7 +174,6 @@ def forward(self, x): return x + x mod = Module() - # pyrefly: ignore [bad-argument-type] opt_mod = dynamo.optimize(backend, nopython=True)(mod) for f in (fn, opt_mod): diff --git a/tools/experimental/torchfuzz/checks.py b/tools/experimental/torchfuzz/checks.py index 5b7b2e9da0e99..e1fc7452c1c63 100644 --- a/tools/experimental/torchfuzz/checks.py +++ b/tools/experimental/torchfuzz/checks.py @@ -25,6 +25,22 @@ def codegen(self, args_tuple: str) -> list[str]: ] +class EagerVsFullGraphDynamicCompileWithBackwardCheck(Check): + """Check that runs eager then fullgraph+dynamic compilation with backward pass.""" + + def codegen(self, args_tuple: str) -> list[str]: + return [ + f"args = {args_tuple}", + "result_original = fuzzed_program(*args)", + "result_original.sum().backward()", + "print('✅ eager + backward success')", + "compiled_program = torch.compile(fuzzed_program, fullgraph=True, dynamic=True)", + "result_compiled = compiled_program(*args)", + "result_compiled.sum().backward()", + "print('✅ compile + backward success')", + ] + + class EagerVsFullGraphDynamicCompileWithNumericsCheck(Check): """Check that runs eager and compiled, compares forward numerics.""" diff --git a/tools/experimental/torchfuzz/codegen.py b/tools/experimental/torchfuzz/codegen.py index da41392623bd9..b8a9c313ff1e8 100644 --- a/tools/experimental/torchfuzz/codegen.py +++ b/tools/experimental/torchfuzz/codegen.py @@ -1,5 +1,6 @@ # mypy: ignore-errors import os +import random import torch @@ -501,6 +502,176 @@ def epilogue_codegen(self): return [] +class StreamFuzzTemplate(DefaultFuzzTemplate): + """Template that wraps operations in random CUDA stream contexts. + + Reuses the same operator set as DefaultFuzzTemplate but partitions non-leaf + operations across 2-3 CUDA streams, inserting proper wait_stream + synchronization between dependent operations on different streams. + """ + + def __init__(self): + super().__init__() + from torchfuzz.checks import EagerVsFullGraphDynamicCompileWithBackwardCheck + + self.check = EagerVsFullGraphDynamicCompileWithBackwardCheck() + + def imports_codegen(self): + return [ + "import torch", + ] + + def flags_codegen(self): + return [ + "torch.set_default_device('cuda')", + "torch._dynamo.config.capture_scalar_outputs = True", + ] + + def args_codegen(self, arg_operations): + """Generate args with requires_grad=True on float tensors. + + This ensures the backward pass traces through stream-wrapped operations, + exercising Inductor's stream handling in the backward graph. + """ + code_lines = super().args_codegen(arg_operations) + if arg_operations: + for i, (node_id, spec) in enumerate(arg_operations): + if isinstance(spec, TensorSpec) and spec.dtype in [ + torch.float32, + torch.float64, + torch.float16, + torch.bfloat16, + ]: + code_lines.append(f"arg_{i} = arg_{i}.requires_grad_(True)") + return code_lines + + @staticmethod + def wrap_body_with_streams( + generated_code_lines: list[str], + graph: OperationGraph, + ) -> list[str]: + """Wrap generated function body lines with CUDA stream contexts. + + Assigns each non-leaf operation to one of 2-3 random streams, wraps each + in ``with torch.cuda.stream(sN):``, and inserts ``wait_stream`` calls + between dependent operations on different streams. + """ + topo_order = graph.get_topological_order() + + # Identify leaf vs non-leaf node ids + leaf_ids = set() + non_leaf_ids = [] + for nid in topo_order: + node = graph.nodes[nid] + if ( + node.op_name == "arg" + or node.op_name.startswith("arg_") + or node.op_name == "constant" + ): + leaf_ids.add(nid) + else: + non_leaf_ids.append(nid) + + if not non_leaf_ids: + return generated_code_lines + + num_streams = random.randint(2, 3) + stream_names = [f"s{i + 1}" for i in range(num_streams)] + + # Decide sync strategy: wait_stream or event-based (record + wait_event) + use_events = random.choice([True, False]) + event_counter = 0 + + # Assign each non-leaf node to a random stream + node_stream: dict[str, str] = {} + for nid in non_leaf_ids: + node_stream[nid] = random.choice(stream_names) + + # Build a mapping from node_id -> the original code lines for that node. + # Each node produces lines prefixed with " " (4-space indent for the + # function body). We identify nodes by their ``var_{node_id} =`` pattern. + node_lines: dict[str, list[str]] = {} + current_node: str | None = None + current_buf: list[str] = [] + + for line in generated_code_lines: + stripped = line.strip() + # Detect lines like "var_node_3 = ..." or "var_node_3, _ = ..." + matched_node = None + for nid in topo_order: + if stripped.startswith((f"var_{nid} =", f"var_{nid},")): + matched_node = nid + break + + if matched_node is not None: + # Flush previous node buffer + if current_node is not None: + node_lines[current_node] = current_buf + current_node = matched_node + current_buf = [line] + else: + current_buf.append(line) + + # Flush last node + if current_node is not None: + node_lines[current_node] = current_buf + + # Rebuild the body with stream contexts and synchronization + new_lines: list[str] = [] + + # Stream variable declarations at the top of the function body + for sname in stream_names: + new_lines.append(f" {sname} = torch.cuda.Stream()") + + for nid in topo_order: + lines_for_node = node_lines.get(nid, []) + if nid in leaf_ids: + # Leaf nodes (args) stay on the default stream + new_lines.extend(lines_for_node) + continue + + stream = node_stream[nid] + node = graph.nodes[nid] + + # Insert synchronization for cross-stream dependencies + waited: set[str] = set() + for dep_id in node.input_nodes: + if dep_id in node_stream and node_stream[dep_id] != stream: + dep_stream = node_stream[dep_id] + if dep_stream not in waited: + if use_events: + ename = f"e{event_counter}" + event_counter += 1 + new_lines.append(f" {ename} = torch.cuda.Event()") + new_lines.append(f" {ename}.record({dep_stream})") + new_lines.append(f" {stream}.wait_event({ename})") + else: + new_lines.append(f" {stream}.wait_stream({dep_stream})") + waited.add(dep_stream) + + # Wrap the operation in a stream context + new_lines.append(f" with torch.cuda.stream({stream}):") + for code_line in lines_for_node: + # Each line already has 4-space indent; add 4 more for the with block + new_lines.append(" " + code_line) + + # Synchronize all streams before the return statement + if use_events: + for sname in stream_names: + ename = f"e{event_counter}" + event_counter += 1 + new_lines.append(f" {ename} = torch.cuda.Event()") + new_lines.append(f" {ename}.record({sname})") + new_lines.append(f" torch.cuda.current_stream().wait_event({ename})") + else: + for sname in stream_names: + new_lines.append( + f" torch.cuda.current_stream().wait_stream({sname})" + ) + + return new_lines + + class DTensorFuzzPlacementsTemplate(DTensorFuzzTemplate): """DTensor template with randomized placements (Replicate, Shard, Partial). @@ -656,6 +827,8 @@ def convert_graph_to_python_code( fuzz_template = DTensorFuzzPlacementsTemplate() elif template == "unbacked": fuzz_template = UnbackedFuzzTemplate() + elif template == "streams": + fuzz_template = StreamFuzzTemplate() else: fuzz_template = DefaultFuzzTemplate() @@ -730,6 +903,12 @@ def convert_graph_to_python_code( # Track this node's variable node_variables[node_id] = (output_var_name, output_spec) + # Wrap body with stream contexts if using the streams template + if template == "streams": + generated_code_lines = StreamFuzzTemplate.wrap_body_with_streams( + generated_code_lines, operation_graph + ) + # The final result comes from the root node root_node_id = operation_graph.root_node_id if root_node_id not in node_variables: diff --git a/tools/experimental/torchfuzz/fuzzer.py b/tools/experimental/torchfuzz/fuzzer.py index e1e1be477ea0e..d824fe0c6d0bb 100644 --- a/tools/experimental/torchfuzz/fuzzer.py +++ b/tools/experimental/torchfuzz/fuzzer.py @@ -262,7 +262,7 @@ def log(success: bool) -> None: ) parser.add_argument( "--template", - choices=["default", "dtensor", "dtensor_placements", "unbacked"], + choices=["default", "dtensor", "dtensor_placements", "unbacked", "streams"], default="default", help="Template to use for code generation (default: default)", ) diff --git a/tools/experimental/torchfuzz/operators/constant.py b/tools/experimental/torchfuzz/operators/constant.py index 18a3591146305..adc299982e58a 100644 --- a/tools/experimental/torchfuzz/operators/constant.py +++ b/tools/experimental/torchfuzz/operators/constant.py @@ -109,7 +109,6 @@ def codegen( ]: # Clamp integer values to [0, 3] to avoid index overflow in multiplication # Even with multiplication, indices should stay in reasonable range - # pyrefly: ignore [bad-argument-type] fill_value = max(0, min(3, abs(fill_value))) tensor_creation = ( diff --git a/tools/experimental/torchfuzz/ops_fuzzer.py b/tools/experimental/torchfuzz/ops_fuzzer.py index 9bea882bc5101..795c4fff60694 100644 --- a/tools/experimental/torchfuzz/ops_fuzzer.py +++ b/tools/experimental/torchfuzz/ops_fuzzer.py @@ -51,6 +51,10 @@ def _get_template_filtered_operators( from torchfuzz.codegen import UnbackedFuzzTemplate fuzz_template = UnbackedFuzzTemplate() + elif template == "streams": + from torchfuzz.codegen import StreamFuzzTemplate + + fuzz_template = StreamFuzzTemplate() else: from torchfuzz.codegen import DefaultFuzzTemplate @@ -252,6 +256,10 @@ def fuzz_spec(template: str = "default") -> Spec: from torchfuzz.codegen import UnbackedFuzzTemplate fuzz_template = UnbackedFuzzTemplate() + elif template == "streams": + from torchfuzz.codegen import StreamFuzzTemplate + + fuzz_template = StreamFuzzTemplate() else: from torchfuzz.codegen import DefaultFuzzTemplate diff --git a/tools/experimental/torchfuzz/test_determinism.py b/tools/experimental/torchfuzz/test_determinism.py index b260e561d0979..7c621d2e0cf21 100644 --- a/tools/experimental/torchfuzz/test_determinism.py +++ b/tools/experimental/torchfuzz/test_determinism.py @@ -36,7 +36,6 @@ def run_fuzzer_with_seed(seed): j = i + 1 block_lines = [] while j < len(lines) and not lines[j].startswith("==="): - # pyrefly: ignore [bad-argument-type] block_lines.append(lines[j]) j += 1 src_block = "\n".join(block_lines) diff --git a/tools/experimental/torchfuzz/test_streams_template.py b/tools/experimental/torchfuzz/test_streams_template.py new file mode 100644 index 0000000000000..b8fbf7e3fcb58 --- /dev/null +++ b/tools/experimental/torchfuzz/test_streams_template.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +"""Tests for the streams fuzzing template codegen.""" + +import os +import random +import sys +import unittest + + +# Add parent directory to path so we can import torchfuzz as a module +current_dir = os.path.dirname(os.path.abspath(__file__)) +parent_dir = os.path.dirname(current_dir) +if parent_dir not in sys.path: + sys.path.insert(0, parent_dir) + +import torch +from torchfuzz.codegen import convert_graph_to_python_code, StreamFuzzTemplate +from torchfuzz.ops_fuzzer import fuzz_operation_graph, fuzz_spec + + +class TestStreamsFuzzTemplate(unittest.TestCase): + def _generate_code(self, seed): + random.seed(seed) + torch.manual_seed(seed) + target_spec = fuzz_spec("streams") + graph = fuzz_operation_graph( + target_spec, max_depth=3, seed=seed, template="streams" + ) + return convert_graph_to_python_code(graph, seed=seed, template="streams") + + def test_template_inherits_default_ops(self): + template = StreamFuzzTemplate() + self.assertGreater(len(template.supported_ops), 0) + self.assertIn("torch.add", template.supported_ops) + self.assertIn("torch.matmul", template.supported_ops) + + def test_template_uses_backward_check(self): + from torchfuzz.checks import EagerVsFullGraphDynamicCompileWithBackwardCheck + + template = StreamFuzzTemplate() + self.assertIsInstance( + template.check, EagerVsFullGraphDynamicCompileWithBackwardCheck + ) + + def test_codegen_creates_streams(self): + code = self._generate_code(seed=999) + self.assertIn("torch.cuda.Stream()", code) + + def test_codegen_has_stream_context(self): + code = self._generate_code(seed=999) + self.assertIn("with torch.cuda.stream(", code) + + def test_codegen_has_final_sync(self): + code = self._generate_code(seed=999) + # Should sync all streams before return (wait_stream or wait_event) + self.assertTrue( + "torch.cuda.current_stream().wait_stream(" in code + or "torch.cuda.current_stream().wait_event(" in code + ) + + def test_codegen_has_backward(self): + code = self._generate_code(seed=999) + self.assertIn(".sum().backward()", code) + + def test_codegen_has_requires_grad(self): + """Float tensor args should have requires_grad for backward testing.""" + code = self._generate_code(seed=999) + self.assertIn("requires_grad_(True)", code) + + def test_codegen_cross_stream_sync(self): + """Seeds with cross-stream deps should have inter-stream sync.""" + # seed 999 produces a graph with ops on different streams + code = self._generate_code(seed=999) + # Should have either wait_stream or event-based sync between streams + has_wait_stream = ".wait_stream(s" in code + has_wait_event = ".wait_event(" in code + self.assertTrue( + has_wait_stream or has_wait_event, + "Expected cross-stream synchronization in generated code", + ) + + def test_codegen_event_based_sync(self): + """Some seeds should use event-based synchronization.""" + found_events = False + for seed in range(50, 70): + code = self._generate_code(seed=seed) + if "torch.cuda.Event()" in code and ".wait_event(" in code: + found_events = True + # When events are used, should have record + wait pattern + self.assertIn(".record(", code) + break + self.assertTrue(found_events, "No seed in range produced event-based sync") + + def test_codegen_deterministic(self): + """Same seed should produce identical code.""" + code1 = self._generate_code(seed=42) + code2 = self._generate_code(seed=42) + self.assertEqual(code1, code2) + + def test_codegen_is_valid_python(self): + """Generated code should be syntactically valid Python.""" + code = self._generate_code(seed=999) + compile(code, "", "exec") + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/gen_vulkan_spv.py b/tools/gen_vulkan_spv.py index 2cf8f7f799ac8..3a39f320f5c04 100644 --- a/tools/gen_vulkan_spv.py +++ b/tools/gen_vulkan_spv.py @@ -688,18 +688,15 @@ def genCppFiles( name = getName(spvPath).replace("_spv", "") sizeBytes, spv_bin_str = generateSpvBinStr(spvPath, name) - # pyrefly: ignore [bad-argument-type] spv_bin_strs.append(spv_bin_str) shader_info = getShaderInfo(srcPath) register_shader_info_strs.append( - # pyrefly: ignore [bad-argument-type] generateShaderInfoStr(shader_info, name, sizeBytes) ) if shader_info.register_for is not None: - # pyrefly: ignore [bad-argument-type] shader_registry_strs.append(generateShaderDispatchStr(shader_info, name)) spv_bin_arrays = "\n".join(spv_bin_strs) diff --git a/tools/jit/gen_unboxing.py b/tools/jit/gen_unboxing.py index 805bfa979aa41..2199f0d073e72 100644 --- a/tools/jit/gen_unboxing.py +++ b/tools/jit/gen_unboxing.py @@ -132,14 +132,12 @@ def __call__(self, f: NativeFunction) -> str: else: arg_cpp = f"c10::IValue({arg_default})" args_code.append( - # pyrefly: ignore [bad-argument-type] f"""c10::Argument("{arg.name}", nullptr, ::std::nullopt, {arg_cpp})""" ) returns = f.func.returns returns_code = [] for ret in returns: - # pyrefly: ignore [bad-argument-type] returns_code.append(f"""c10::Argument("{ret.name if ret.name else ""}")""") return f""" // aten::{schema} diff --git a/tools/linter/adapters/actionlint_linter.py b/tools/linter/adapters/actionlint_linter.py index 019f0fe896bcd..2884fe732625b 100644 --- a/tools/linter/adapters/actionlint_linter.py +++ b/tools/linter/adapters/actionlint_linter.py @@ -75,6 +75,10 @@ def check_file( '"runs-on" section must be sequence node but got mapping node with "!!map" tag', "-ignore", 'input "freethreaded" is not defined in action "actions/setup-python@v', + # GitHub increased workflow_dispatch limit to 25 inputs (Dec 2025). + # actionlint fixed this in v1.7.10; remove after upgrading from v1.7.7. + "-ignore", + 'maximum number of inputs for "workflow_dispatch" event is 10 but', file, ] ) diff --git a/tools/linter/adapters/bazel_linter.py b/tools/linter/adapters/bazel_linter.py deleted file mode 100644 index 926628d3d76a9..0000000000000 --- a/tools/linter/adapters/bazel_linter.py +++ /dev/null @@ -1,197 +0,0 @@ -""" -This linter ensures that users don't set a SHA hash checksum in Bazel for the http_archive. -Although the security practice of setting the checksum is good, it doesn't work when the -archive is downloaded from some sites like GitHub because it can change. Specifically, -GitHub gives no guarantee to keep the same value forever. Check for more details at -https://github.com/community/community/discussions/46034. -""" - -from __future__ import annotations - -import argparse -import json -import re -import shlex -import subprocess -import sys -import xml.etree.ElementTree as ET -from enum import Enum -from typing import NamedTuple -from urllib.parse import urlparse - - -LINTER_CODE = "BAZEL_LINTER" -SHA256_REGEX = re.compile(r"\s*sha256\s*=\s*['\"](?P[a-zA-Z0-9]{64})['\"]\s*,") -DOMAINS_WITH_UNSTABLE_CHECKSUM = {"github.com"} - - -class LintSeverity(str, Enum): - ERROR = "error" - WARNING = "warning" - ADVICE = "advice" - DISABLED = "disabled" - - -class LintMessage(NamedTuple): - path: str | None - line: int | None - char: int | None - code: str - severity: LintSeverity - name: str - original: str | None - replacement: str | None - description: str | None - - -def is_required_checksum(urls: list[str | None]) -> bool: - if not urls: - return False - - for url in urls: - if not url: - continue - - parsed_url = urlparse(url) - if parsed_url.hostname in DOMAINS_WITH_UNSTABLE_CHECKSUM: - return False - - return True - - -def get_disallowed_checksums( - binary: str, -) -> set[str]: - """ - Return the set of disallowed checksums from all http_archive rules - """ - # Use bazel to get the list of external dependencies in XML format - proc = subprocess.run( - [binary, "query", "kind(http_archive, //external:*)", "--output=xml"], - capture_output=True, - check=True, - text=True, - ) - - root = ET.fromstring(proc.stdout) - - disallowed_checksums = set() - # Parse all the http_archive rules in the XML output - for rule in root.findall('.//rule[@class="http_archive"]'): - urls_node = rule.find('.//list[@name="urls"]') - if urls_node is None: - continue - urls = [n.get("value") for n in urls_node.findall(".//string")] - - checksum_node = rule.find('.//string[@name="sha256"]') - if checksum_node is None: - continue - checksum = checksum_node.get("value") - - if not checksum: - continue - - if not is_required_checksum(urls): - disallowed_checksums.add(checksum) - - return disallowed_checksums - - -def check_bazel( - filename: str, - disallowed_checksums: set[str], -) -> list[LintMessage]: - original = "" - replacement = "" - - with open(filename) as f: - for line in f: - original += f"{line}" - - m = SHA256_REGEX.match(line) - if m: - sha256 = m.group("sha256") - - if sha256 in disallowed_checksums: - continue - - replacement += f"{line}" - - if original == replacement: - return [] - - return [ - LintMessage( - path=filename, - line=None, - char=None, - code=LINTER_CODE, - severity=LintSeverity.ADVICE, - name="format", - original=original, - replacement=replacement, - description="Found redundant SHA checksums. Run `lintrunner -a` to apply this patch.", - ) - ] - - -def main() -> None: - parser = argparse.ArgumentParser( - description="A custom linter to detect redundant SHA checksums in Bazel", - fromfile_prefix_chars="@", - ) - parser.add_argument( - "--binary", - required=True, - help="bazel binary path", - ) - parser.add_argument( - "filenames", - nargs="+", - help="paths to lint", - ) - args = parser.parse_args() - - try: - disallowed_checksums = get_disallowed_checksums(args.binary) - except subprocess.CalledProcessError as err: - err_msg = LintMessage( - path=None, - line=None, - char=None, - code=__file__, - severity=LintSeverity.ADVICE, - name="command-failed", - original=None, - replacement=None, - description=( - f"COMMAND (exit code {err.returncode})\n" - f"{shlex.join(err.cmd)}\n\n" - f"STDERR\n{err.stderr or '(empty)'}\n\n" - f"STDOUT\n{err.stdout or '(empty)'}" - ), - ) - print(json.dumps(err_msg._asdict())) - return - except Exception as e: - err_msg = LintMessage( - path=None, - line=None, - char=None, - code=LINTER_CODE, - severity=LintSeverity.ERROR, - name="command-failed", - original=None, - replacement=None, - description=(f"Failed due to {e.__class__.__name__}:\n{e}"), - ) - print(json.dumps(err_msg._asdict()), flush=True) - sys.exit(0) - - for filename in args.filenames: - for lint_message in check_bazel(filename, disallowed_checksums): - print(json.dumps(lint_message._asdict()), flush=True) - - -if __name__ == "__main__": - main() diff --git a/tools/linter/adapters/clangformat_linter.py b/tools/linter/adapters/clangformat_linter.py index 0d82ddd939b15..9289dcd6375f3 100644 --- a/tools/linter/adapters/clangformat_linter.py +++ b/tools/linter/adapters/clangformat_linter.py @@ -73,7 +73,7 @@ def run_command( if remaining_retries == 0: raise err remaining_retries -= 1 - logging.warning( # noqa: G200 + logging.warning( "(%s/%s) Retrying because command failed with: %r", retries - remaining_retries, retries, diff --git a/tools/linter/adapters/clangtidy_linter.py b/tools/linter/adapters/clangtidy_linter.py index 62db5bb06fd8c..00c7eec5bd9a4 100644 --- a/tools/linter/adapters/clangtidy_linter.py +++ b/tools/linter/adapters/clangtidy_linter.py @@ -19,13 +19,11 @@ # PyTorch directory root def scm_root() -> str: path = os.path.abspath(os.getcwd()) - # pyrefly: ignore [bad-assignment] while True: if os.path.exists(os.path.join(path, ".git")): return path if os.path.isdir(os.path.join(path, ".hg")): return path - # pyrefly: ignore [bad-argument-type] n = len(path) path = os.path.dirname(path) if len(path) == n: diff --git a/tools/linter/adapters/codespell_linter.py b/tools/linter/adapters/codespell_linter.py index 527e3d4fbfb61..3c080b2359f5d 100644 --- a/tools/linter/adapters/codespell_linter.py +++ b/tools/linter/adapters/codespell_linter.py @@ -142,7 +142,7 @@ def check_dictionary(filename: str) -> list[LintMessage]: words_set = set(words) if len(words) != len(words_set): raise ValueError("The dictionary file contains duplicate entries.") - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type] uncased_words = list(map(str.lower, words)) if uncased_words != sorted(uncased_words): raise ValueError( diff --git a/tools/linter/adapters/docstring_linter-grandfather.json b/tools/linter/adapters/docstring_linter-grandfather.json index f691f77008102..92fb43f07ede2 100644 --- a/tools/linter/adapters/docstring_linter-grandfather.json +++ b/tools/linter/adapters/docstring_linter-grandfather.json @@ -15,6 +15,14 @@ "class MixedMMH100": 133, "def MixedMMH100.get_best_choices()": 86 }, + "torch/_inductor/autoheuristic/artifacts/_PadMMA100.py": { + "class PadMMA100": 249, + "def PadMMA100.get_best_choices()": 221 + }, + "torch/_inductor/autoheuristic/artifacts/_PadMMH200.py": { + "class PadMMH200": 159, + "def PadMMH200.get_best_choices()": 131 + }, "torch/_inductor/bounds.py": { "class ValueRangeAnalysis": 108 }, @@ -31,7 +39,7 @@ }, "torch/_inductor/codegen/cpp.py": { "class CppKernelProxy": 616, - "class CppOverrides": 437, + "class CppOverrides": 496, "class CppScheduling": 819, "class CppVecKernel": 867, "class OuterLoopFusedSchedulerNode": 159, @@ -79,7 +87,7 @@ }, "torch/_inductor/codegen/halide.py": { "class HalideKernel": 999, - "class HalideOverrides": 339, + "class HalideOverrides": 384, "class HalidePrinter": 128, "def HalideKernel.halide_kernel_meta()": 82 }, diff --git a/tools/linter/adapters/flake8_linter.py b/tools/linter/adapters/flake8_linter.py index 621e060e2a284..1a869dd9df78d 100644 --- a/tools/linter/adapters/flake8_linter.py +++ b/tools/linter/adapters/flake8_linter.py @@ -188,7 +188,7 @@ def run_command( ): raise err remaining_retries -= 1 - logging.warning( # noqa: G200 + logging.warning( "(%s/%s) Retrying because command failed with: %r", retries - remaining_retries, retries, diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index 82bd495df952a..50de963461f39 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -20,7 +20,6 @@ from pathlib import Path from typing import NamedTuple -# pyrefly: ignore [import-error] import isort import usort diff --git a/tools/linter/adapters/pyrefly_linter.py b/tools/linter/adapters/pyrefly_linter.py index ad9442c92bf05..27ae0cc2f4c3b 100644 --- a/tools/linter/adapters/pyrefly_linter.py +++ b/tools/linter/adapters/pyrefly_linter.py @@ -5,7 +5,7 @@ # "numpy==2.1.0 ; python_version >= '3.12' and python_version <= '3.13'", # "numpy==2.3.4 ; python_version >= '3.14'", # "expecttest==0.3.0", -# "pyrefly==0.52.0", +# "pyrefly==0.58.0", # "sympy==1.13.3", # "types-requests==2.27.25", # "types-pyyaml==6.0.2", diff --git a/tools/linter/adapters/ruff_linter.py b/tools/linter/adapters/ruff_linter.py index bad80bef925a2..77e1aa717855d 100644 --- a/tools/linter/adapters/ruff_linter.py +++ b/tools/linter/adapters/ruff_linter.py @@ -118,7 +118,7 @@ def run_command( if remaining_retries == 0: raise err remaining_retries -= 1 - logging.warning( # noqa: G200 + logging.warning( "(%s/%s) Retrying because command failed with: %r", retries - remaining_retries, retries, diff --git a/tools/linter/adapters/s3_init.py b/tools/linter/adapters/s3_init.py index 2d84cf719f310..e60546a6342c8 100644 --- a/tools/linter/adapters/s3_init.py +++ b/tools/linter/adapters/s3_init.py @@ -55,7 +55,6 @@ def report_download_progress( Pretty printer for file download progress. """ if file_size != -1: - # pyrefly: ignore [no-matching-overload] percent = min(1, (chunk_number * chunk_size) / file_size) bar = "#" * int(64 * percent) sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%") diff --git a/tools/linter/adapters/s3_init_config.json b/tools/linter/adapters/s3_init_config.json index 28843aa49fd4e..67e50fd2c347d 100644 --- a/tools/linter/adapters/s3_init_config.json +++ b/tools/linter/adapters/s3_init_config.json @@ -59,15 +59,5 @@ "download_url": "https://oss-clang-format.s3.us-east-2.amazonaws.com/actionlint/1.7.7/Linux_arm64/actionlint", "hash": "446687e63fac45472b0a66bae28975c28678af062670af119c11a7087baf35cc" } - }, - "bazel": { - "Darwin": { - "download_url": "https://raw.githubusercontent.com/bazelbuild/bazelisk/v1.16.0/bazelisk.py", - "hash": "1f6d76d023ddd5f1625f34d934418e7334a267318d084f31be09df8a8835ed16" - }, - "Linux": { - "download_url": "https://raw.githubusercontent.com/bazelbuild/bazelisk/v1.16.0/bazelisk.py", - "hash": "1f6d76d023ddd5f1625f34d934418e7334a267318d084f31be09df8a8835ed16" - } } } diff --git a/tools/linter/adapters/stable_shim_usage_linter.py b/tools/linter/adapters/stable_shim_usage_linter.py index 9b130ce5755b1..c2cb92ee419f7 100644 --- a/tools/linter/adapters/stable_shim_usage_linter.py +++ b/tools/linter/adapters/stable_shim_usage_linter.py @@ -82,7 +82,7 @@ def get_shim_functions( functions: dict[str, tuple[int, int]] = {} # Match function declarations like: AOTI_TORCH_EXPORT ... function_name( - function_pattern = re.compile(r"AOTI_TORCH_EXPORT\s+\w+\s+(\w+)\s*\(") + function_pattern = re.compile(r"AOTI_TORCH_EXPORT.+?(\w+)\s*\(") # Also match typedef function pointers typedef_pattern = re.compile(r"typedef\s+.*\(\*(\w+)\)") # Match using declarations like: using TypeName = ... diff --git a/tools/linter/adapters/test_has_main_linter.py b/tools/linter/adapters/test_has_main_linter.py index 7f7da52ba5fb6..460cabb45d991 100644 --- a/tools/linter/adapters/test_has_main_linter.py +++ b/tools/linter/adapters/test_has_main_linter.py @@ -20,10 +20,7 @@ from enum import Enum from typing import NamedTuple -# pyrefly: ignore [import-error] import libcst as cst - -# pyrefly: ignore [import-error] import libcst.matchers as m diff --git a/tools/linter/adapters/workflow_consistency_linter.py b/tools/linter/adapters/workflow_consistency_linter.py index 54687a1ff4214..f931a117b6978 100644 --- a/tools/linter/adapters/workflow_consistency_linter.py +++ b/tools/linter/adapters/workflow_consistency_linter.py @@ -116,6 +116,15 @@ def get_jobs_with_sync_tag( # and ['with']['tests-to-include'], since dispatch filters differ if "tests-to-include" in job.get("with", {}): del job["with"]["tests-to-include"] + # and ['with']['build-environment'], since GPU-specific suffixes differ for ROCm + if ( + "build-environment" in job.get("with", {}) + and "rocm" in job["with"]["build-environment"] + ): + del job["with"]["build-environment"] + # and ['name'], since ROCm jobs append a GPU-specific suffix to the job name + if "name" in job and "rocm" in job.get("name", ""): + del job["name"] # normalize needs: remove helper job-filter so comparisons ignore it needs = job.get("needs") diff --git a/tools/lite_interpreter/gen_selected_mobile_ops_header.py b/tools/lite_interpreter/gen_selected_mobile_ops_header.py index 5d6889d275551..f90d33c5ba452 100644 --- a/tools/lite_interpreter/gen_selected_mobile_ops_header.py +++ b/tools/lite_interpreter/gen_selected_mobile_ops_header.py @@ -73,7 +73,6 @@ def get_selected_kernel_dtypes_code( for kernel_tag, dtypes in selective_builder.kernel_metadata.items(): conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes] body_parts.append( - # pyrefly: ignore [bad-argument-type] if_condition_template.substitute( kernel_tag_name=kernel_tag, dtype_checks=" || ".join(conditions), diff --git a/tools/lldb/pytorch_lldb.py b/tools/lldb/pytorch_lldb.py index 57fa1d7e61b92..4907bc1e6854a 100644 --- a/tools/lldb/pytorch_lldb.py +++ b/tools/lldb/pytorch_lldb.py @@ -4,7 +4,6 @@ def get_target() -> Any: - # pyrefly: ignore [missing-attribute] target = lldb.debugger.GetSelectedTarget() if not target: print("[-] error: no target available. please add a target to lldb.") diff --git a/tools/nightly.py b/tools/nightly.py index 88ba7ccb2ba81..0789241e0ca30 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -288,7 +288,7 @@ def site_packages(self, python: Path | str | None = None) -> Path: python=python, capture_output=True, ).stdout - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type] candidates = list(map(Path, filter(None, map(str.strip, output.splitlines())))) candidates = [p for p in candidates if p.is_dir() and p.name == "site-packages"] if not candidates: @@ -723,7 +723,7 @@ def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, N logging_record_exception(e) print(f"log file: {log_file}") sys.exit(1) - except BaseException as e: # noqa: B036 + except BaseException as e: # You could logging.debug here to suppress the backtrace # entirely, but there is no reason to hide it from technically # savvy users. diff --git a/tools/nightly_hotpatch.py b/tools/nightly_hotpatch.py index f4d3ab4e95fe9..97c9dac1625b1 100644 --- a/tools/nightly_hotpatch.py +++ b/tools/nightly_hotpatch.py @@ -118,7 +118,6 @@ def download_patch(pr_number: int, repo_url: str, download_dir: str) -> str: urllib.request.urlopen(patch_url) as response, open(patch_file, "wb") as out_file, ): - # pyrefly: ignore [bad-specialization] shutil.copyfileobj(response, out_file) if not os.path.isfile(patch_file): print(f"Failed to download patch for PR #{pr_number}") diff --git a/tools/optional_submodules.py b/tools/optional_submodules.py index c3c4dacaa4032..5164acf69058b 100644 --- a/tools/optional_submodules.py +++ b/tools/optional_submodules.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from subprocess import check_call @@ -26,16 +27,18 @@ def _checkout_by_tag(repo: str, tag: str) -> None: ) -def read_nccl_pin() -> str: +def read_nccl_pin(cuda_version: str = "") -> str: # Default NCCL version nccl_file = "nccl.txt" # If NCCL version diverges for different CUDA versions, uncomment the # following block and add the appropriate file (using CUDA 11 as an example) - # cuda_version = os.getenv("DESIRED_CUDA", os.getenv("CUDA_VERSION", "")) - # if cuda_version.startswith("11"): - # nccl_file = "nccl-cu11.txt" + # 12.6 builds for sm50, needs a lower version + if not cuda_version: + cuda_version = os.getenv("DESIRED_CUDA", os.getenv("CUDA_VERSION", "")) + if cuda_version.startswith("12.6") or cuda_version == "cu126": + nccl_file = "nccl-cu126.txt" nccl_pin_path = repo_root / ".ci" / "docker" / "ci_commit_pins" / nccl_file return _read_file(nccl_pin_path) diff --git a/tools/packaging/build_wheel.py b/tools/packaging/build_wheel.py index 5f6f262ab8204..dad2d80849673 100644 --- a/tools/packaging/build_wheel.py +++ b/tools/packaging/build_wheel.py @@ -114,7 +114,7 @@ def _find_manylinux_interpreters() -> list[str]: ) except subprocess.CalledProcessError as e: - logger.debug("Failed to get version for %s: %s", python_path, e) # noqa:G200 + logger.debug("Failed to get version for %s: %s", python_path, e) continue return interpreters diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py index 44f0360c1fb33..816480d22ca75 100644 --- a/tools/pyi/gen_pyi.py +++ b/tools/pyi/gen_pyi.py @@ -256,9 +256,12 @@ def sig_for_ops(opname: str) -> list[str]: ] elif name in arithmetic_ops: if name.startswith("i"): - # In-place binary-operation dunder methods, like `__iadd__`, should return `Self` + # In-place binary-operation dunder methods, like `__iadd__`, should return `Self`. + # `__idiv__` is not a real Python 3 in-place dunder (Python 3 uses `__itruediv__` / + # `__ifloordiv__`), so ruff's PYI034 doesn't fire on it and the noqa would be unused. + suffix = "" if name == "idiv" else " # noqa: PYI034" return [ - f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ... # noqa: PYI034" + f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ...{suffix}" ] return [f"def {opname}(self, other: Tensor | Number | _complex) -> Tensor: ..."] elif name in logic_ops: @@ -996,7 +999,7 @@ def add_docstr_to_hint(docstr: str, hint: str) -> str: hint = hint.removesuffix("...").rstrip() # remove "..." content = hint + "\n" + textwrap.indent(f'r"""\n{docstr}\n"""', prefix=" ") # Remove trailing whitespace on each line - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type] return "\n".join(map(str.rstrip, content.splitlines())).rstrip() # attribute or property @@ -1072,7 +1075,7 @@ def gen_pyi( "dtype: _dtype | None = None", "device: DeviceLikeType | None = None", "copy: _bool | None = None", - "requires_grad: _bool = False", + "requires_grad: _bool | None = None", ], "Tensor", ) diff --git a/tools/rules/BUILD b/tools/rules/BUILD deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tools/rules/cu.bzl b/tools/rules/cu.bzl deleted file mode 100644 index 8e137f63304cd..0000000000000 --- a/tools/rules/cu.bzl +++ /dev/null @@ -1,42 +0,0 @@ -load("@rules_cuda//cuda:defs.bzl", "cuda_library") - -NVCC_COPTS = [ - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--compiler-options=-Werror=all", - # The following warnings come from -Wall. We downgrade them from - # error to warnings here. - # - # sign-compare has a tremendous amount of violations in the - # codebase. It will be a lot of work to fix them, just disable it - # for now. - "--compiler-options=-Wno-sign-compare", - # We intentionally use #pragma unroll, which is compiler specific. - "--compiler-options=-Wno-error=unknown-pragmas", - "--compiler-options=-Werror=extra", - # The following warnings come from -Wextra. We downgrade them from - # error to warnings here. - # - # unused-parameter-compare has a tremendous amount of violations - # in the codebase. It will be a lot of work to fix them, just - # disable it for now. - "--compiler-options=-Wno-unused-parameter", - # missing-field-parameters has both a large number of violations - # in the codebase, but it also is used pervasively in the Python C - # API. There are a couple of catches though: - # * we use multiple versions of the Python API and hence have - # potentially multiple different versions of each relevant - # struct. They may have different numbers of fields. It will be - # unwieldy to support multiple versions in the same source file. - # * Python itself for many of these structs recommends only - # initializing a subset of the fields. We should respect the API - # usage conventions of our dependencies. - # - # Hence, we just disable this warning altogether. We may want to - # clean up some of the clear-cut cases that could be risky, but we - # still likely want to have this disabled for the most part. - "-Wno-missing-field-initializers", -] - -def cu_library(name, srcs, copts = [], **kwargs): - cuda_library(name, srcs = srcs, copts = NVCC_COPTS + copts, **kwargs) diff --git a/tools/rules/workspace.bzl b/tools/rules/workspace.bzl deleted file mode 100644 index 178cc256b1557..0000000000000 --- a/tools/rules/workspace.bzl +++ /dev/null @@ -1,54 +0,0 @@ -def _impl(repository_ctx): - archive = repository_ctx.attr.name + ".tar" - reference = Label("@%s_unpatched//:README" % repository_ctx.attr.name) - dirname = repository_ctx.path(reference).dirname - repository_ctx.execute(["tar", "hcf", archive, "-C", dirname, "."]) - repository_ctx.extract(archive) - for patch in repository_ctx.attr.patches: - repository_ctx.patch(repository_ctx.path(patch), repository_ctx.attr.patch_strip) - build_file = repository_ctx.path(repository_ctx.attr.build_file) - repository_ctx.execute(["cp", build_file, "BUILD.bazel"]) - -_patched_rule = repository_rule( - implementation = _impl, - attrs = { - "build_file": attr.label(), - "patch_strip": attr.int(), - "patches": attr.label_list(), - }, -) - -def new_patched_local_repository(name, path, **kwargs): - native.new_local_repository( - name = name + "_unpatched", - build_file_content = """ -pkg_tar(name = "content", srcs = glob(["**"])) -""", - path = path, - ) - _patched_rule(name = name, **kwargs) - -def _new_empty_repository_impl(repo_ctx): - build_file = repo_ctx.attr.build_file - build_file_content = repo_ctx.attr.build_file_content - if not (bool(build_file) != bool(build_file_content)): - fail("Exactly one of 'build_file' or 'build_file_content' is required") - - if build_file_content: - repo_ctx.file("BUILD", build_file_content) - elif build_file: - repo_ctx.template("BUILD", repo_ctx.attr.build_file, {}) - -new_empty_repository = repository_rule( - attrs = { - "build_file": attr.label(allow_files = True), - "build_file_content": attr.string(), - }, - implementation = _new_empty_repository_impl, -) - -"""Create an empty repository with the supplied BUILD file. - -This is mostly useful to create wrappers for specific target that we want -to be used with the '@' syntax. -""" diff --git a/tools/rules_cc/cc_library_shim.patch b/tools/rules_cc/cc_library_shim.patch new file mode 100644 index 0000000000000..5ade9601cc165 --- /dev/null +++ b/tools/rules_cc/cc_library_shim.patch @@ -0,0 +1,9 @@ +diff --git cc/cc_library.bzl cc/cc_library.bzl +new file mode 100644 +--- /dev/null ++++ cc/cc_library.bzl +@@ -0,0 +1,4 @@ ++"""Shim: re-export native cc_library for rules_cc 0.2.x compat.""" ++ ++def cc_library(**kwargs): ++ native.cc_library(**kwargs) diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index 88f0fe5d3094a..8ae8378b6aaff 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -12,7 +12,6 @@ import sysconfig from pathlib import Path from subprocess import CalledProcessError, check_call, check_output, DEVNULL -from typing import cast from .cmake_utils import CMakeValue, get_cmake_cache_variables_from_file from .env import ( @@ -238,157 +237,48 @@ def generate( toolset_expr = ",".join([f"{k}={v}" for k, v in toolset_dict.items()]) args.append("-T" + toolset_expr) + # base_dir is used as cmake's source-dir arg and install prefix; + # make it relative to build_dir so these are worktree-independent + # (ccache/re-cc friendly). cmake runs with cwd=build_dir so the + # relative path resolves correctly. base_dir = str(Path(__file__).absolute().parents[2]) + if os.environ.get("USE_RELATIVE_PATHS"): + base_dir = os.path.relpath( + str(Path(__file__).resolve().parents[2]), self.build_dir + ) install_dir = os.path.join(base_dir, "torch") _mkdir_p(install_dir) _mkdir_p(self.build_dir) - # Store build options that are directly stored in environment variables - build_options: dict[str, CMakeValue] = {} - - # Build options that do not start with "BUILD_", "USE_", or "CMAKE_" and are directly controlled by env vars. - # This is a dict that maps environment variables to the corresponding variable name in CMake. - additional_options = { - # Key: environment variable name. Value: Corresponding variable name to be passed to CMake. If you are - # adding a new build option to this block: Consider making these two names identical and adding this option - # in the block below. - "CUDNN_LIB_DIR": "CUDNN_LIBRARY", - "USE_CUDA_STATIC_LINK": "CAFFE2_STATIC_LINK_CUDA", - } - additional_options.update( - { - # Build options that have the same environment variable name and CMake variable name and that do not start - # with "BUILD_", "USE_", or "CMAKE_". If you are adding a new build option, also make sure you add it to - # CMakeLists.txt. - var: var - for var in ( - "UBSAN_FLAGS", - "BLAS", - "WITH_BLAS", - "CUDA_HOST_COMPILER", - "CUDA_NVCC_EXECUTABLE", - "CUDA_SEPARABLE_COMPILATION", - "CUDNN_LIBRARY", - "CUDNN_INCLUDE_DIR", - "CUDNN_ROOT", - "EXPERIMENTAL_SINGLE_THREAD_POOL", - "INSTALL_TEST", - "JAVA_HOME", - "INTEL_MKL_DIR", - "INTEL_OMP_DIR", - "MKL_THREADING", - "MKLDNN_CPU_RUNTIME", - "MSVC_Z7_OVERRIDE", - "CAFFE2_USE_MSVC_STATIC_RUNTIME", - "Numa_INCLUDE_DIR", - "Numa_LIBRARIES", - "ONNX_ML", - "ONNX_NAMESPACE", - "ATEN_THREADING", - "WERROR", - "OPENSSL_ROOT_DIR", - "STATIC_DISPATCH_BACKEND", - "SELECTED_OP_LIST", - "TORCH_CUDA_ARCH_LIST", - "TORCH_XPU_ARCH_LIST", - "TRACING_BASED", - "PYTHON_LIB_REL_PATH", - ) - } - ) + # Environment variable forwarding (BUILD_*, USE_*, CMAKE_*, aliases, + # passthrough vars, CMAKE_PREFIX_PATH, low-priority aliases) is now + # handled by cmake/EnvVarForwarding.cmake, which is included early in + # the top-level CMakeLists.txt. Only options that require Python-side + # detection are passed here. - # Aliases which are lower priority than their canonical option - low_priority_aliases = { - "CUDA_HOST_COMPILER": "CMAKE_CUDA_HOST_COMPILER", - "CUDAHOSTCXX": "CUDA_HOST_COMPILER", - "CMAKE_CUDA_HOST_COMPILER": "CUDA_HOST_COMPILER", - "CMAKE_CUDA_COMPILER": "CUDA_NVCC_EXECUTABLE", - "CUDACXX": "CUDA_NVCC_EXECUTABLE", - } - for var, val in my_env.items(): - # We currently pass over all environment variables that start with "BUILD_", "USE_", and "CMAKE_". This is - # because we currently have no reliable way to get the list of all build options we have specified in - # CMakeLists.txt. (`cmake -L` won't print dependent options when the dependency condition is not met.) We - # will possibly change this in the future by parsing CMakeLists.txt ourselves (then additional_options would - # also not be needed to be specified here). - true_var = additional_options.get(var) - if true_var is not None: - build_options[true_var] = val - elif var.startswith(("BUILD_", "USE_", "CMAKE_")) or var.endswith( - ("EXITCODE", "EXITCODE__TRYRUN_OUTPUT") - ): - build_options[var] = val - - if var in low_priority_aliases: - key = low_priority_aliases[var] - if key not in build_options: - build_options[key] = val - - # The default value cannot be easily obtained in CMakeLists.txt. We set it here. - py_lib_path = sysconfig.get_path("purelib") - cmake_prefix_path = build_options.get("CMAKE_PREFIX_PATH") - if cmake_prefix_path: - build_options["CMAKE_PREFIX_PATH"] = ( - py_lib_path + ";" + cast(str, cmake_prefix_path) - ) - else: - build_options["CMAKE_PREFIX_PATH"] = py_lib_path - - # Some options must be post-processed. Ideally, this list will be shrunk to only one or two options in the - # future, as CMake can detect many of these libraries pretty comfortably. We have them here for now before CMake - # integration is completed. They appear here not in the CMake.defines call below because they start with either - # "BUILD_" or "USE_" and must be overwritten here. - use_numpy = not check_negative_env_flag("USE_NUMPY") - build_options.update( - { - # Note: Do not add new build options to this dict if it is directly read from environment variable -- you - # only need to add one in `CMakeLists.txt`. All build options that start with "BUILD_", "USE_", or "CMAKE_" - # are automatically passed to CMake; For other options you can add to additional_options above. - "BUILD_PYTHON": build_python, - "BUILD_TEST": build_test, - # Most library detection should go to CMake script, except this one, which Python can do a much better job - # due to NumPy's inherent Pythonic nature. - "USE_NUMPY": use_numpy, - } - ) - - # Detect build dependencies from python lib path (in order to set *_HOME variables) - # NVSHMEM - nvshmem_py_dir = py_lib_path + "/nvidia/nvshmem" - if os.path.exists(nvshmem_py_dir): - build_options["NVSHMEM_PY_DIR"] = nvshmem_py_dir - - # Options starting with CMAKE_ - cmake__options = { + build_options: dict[str, CMakeValue] = { "CMAKE_INSTALL_PREFIX": install_dir, + "BUILD_PYTHON": build_python, + "BUILD_TEST": build_test, } - # We set some CMAKE_* options in our Python build code instead of relying on the user's direct settings. Emit an - # error if the user also attempts to set these CMAKE options directly. - specified_cmake__options = set(build_options).intersection(cmake__options) - if len(specified_cmake__options) > 0: - eprint( - ", ".join(specified_cmake__options) - + " should not be specified in the environment variable. They are directly set by PyTorch build script." - ) - sys.exit(1) - build_options.update(cmake__options) - + use_numpy = not check_negative_env_flag("USE_NUMPY") + build_options["USE_NUMPY"] = use_numpy if use_numpy: try: - # This helps CMake find the correct include directory for NumPy - # This is especially useful in cross compiled environments import numpy - Python_NumPy_INCLUDE_DIR = numpy.get_include() - build_options.update( - dict(Python_NumPy_INCLUDE_DIR=Python_NumPy_INCLUDE_DIR) - ) + build_options["Python_NumPy_INCLUDE_DIR"] = numpy.get_include() except ImportError: - # use_numpy is just a hint.... so we can fail silently here pass + # NVSHMEM detection from Python lib path + py_lib_path = sysconfig.get_path("purelib") + nvshmem_py_dir = py_lib_path + "/nvidia/nvshmem" + if os.path.exists(nvshmem_py_dir): + build_options["NVSHMEM_PY_DIR"] = nvshmem_py_dir + CMake.defines( args, Python_EXECUTABLE=sys.executable, diff --git a/tools/stats/monitor.py b/tools/stats/monitor.py index f2764aeb91134..059854bd462b6 100644 --- a/tools/stats/monitor.py +++ b/tools/stats/monitor.py @@ -459,7 +459,6 @@ def _initial_gpu_handler(self) -> None: self._gpu_lib_detected = "amdsmi" self._gpu_handles = amdsmi.amdsmi_get_processor_handles() - # pyrefly: ignore[bad-assignment] self._num_of_cpus = psutil.cpu_count(logical=True) # update summary info self._metadata.gpu_count = len(self._gpu_handles) diff --git a/tools/stats/upload_utilization_stats/upload_utilization_stats.py b/tools/stats/upload_utilization_stats/upload_utilization_stats.py index 66348e42a08a0..c2b16e9f5358f 100644 --- a/tools/stats/upload_utilization_stats/upload_utilization_stats.py +++ b/tools/stats/upload_utilization_stats/upload_utilization_stats.py @@ -59,7 +59,6 @@ def generate( df[time_col_name] = pd.to_datetime(df[time_col_name], unit="s", utc=True) # get unique cmd names - # pyrefly: ignore [bad-argument-type] unique_cmds_df = pd.DataFrame(df[cmd_col_name].unique(), columns=[cmd_col_name]) # get all detected python cmds diff --git a/tools/test/gen_operators_yaml_test.py b/tools/test/gen_operators_yaml_test.py index eaaebfbc06273..c801cd8012aef 100644 --- a/tools/test/gen_operators_yaml_test.py +++ b/tools/test/gen_operators_yaml_test.py @@ -7,7 +7,7 @@ from collections import defaultdict from unittest.mock import Mock, patch -# pyrefly: ignore [import-error, missing-import] +# pyrefly: ignore [missing-import] from gen_operators_yaml import ( fill_output, get_parser_options, diff --git a/tools/test/stable_shim_usage_linter_data/sample_shim.h b/tools/test/stable_shim_usage_linter_data/sample_shim.h index 1d9e0e913f31d..1eac43427e97e 100644 --- a/tools/test/stable_shim_usage_linter_data/sample_shim.h +++ b/tools/test/stable_shim_usage_linter_data/sample_shim.h @@ -98,6 +98,13 @@ AOTI_TORCH_EXPORT int primary_path(int arg); AOTI_TORCH_EXPORT int secondary_path(int arg); #endif + +// Function with a return type that consists of multiple words. +#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_12_0 +AOTI_TORCH_EXPORT const char* function_that_returns_constchar(); +#endif // TORCH_FEATURE_VERSION >= TORCH_VERSION_2_12_0 + + #ifdef __cplusplus } // extern "C" #endif diff --git a/tools/test/test_cmake.py b/tools/test/test_cmake.py index 24a4eb56e8b7f..47cdfd52dbc0f 100644 --- a/tools/test/test_cmake.py +++ b/tools/test/test_cmake.py @@ -7,7 +7,7 @@ import unittest.mock import tools.setup_helpers.cmake -import tools.setup_helpers.env # noqa: F401 unused but resolves circular import +import tools.setup_helpers.env if typing.TYPE_CHECKING: diff --git a/tools/test/test_gen_backend_stubs.py b/tools/test/test_gen_backend_stubs.py index 5c7b5e4cfc83c..f69a33b215019 100644 --- a/tools/test/test_gen_backend_stubs.py +++ b/tools/test/test_gen_backend_stubs.py @@ -7,7 +7,7 @@ import expecttest -from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE # noqa: F401 +from torchgen.gen import _GLOBAL_PARSE_NATIVE_YAML_CACHE from torchgen.gen_backend_stubs import run @@ -123,7 +123,7 @@ def test_backend_invalid_dispatch_key(self) -> None: """\ unknown dispatch key NOT_XLA The provided value for "backend" must be a valid DispatchKey, but got NOT_XLA.""", - ) # noqa: B950 + ) def test_missing_cpp_namespace(self) -> None: yaml_str = """\ @@ -183,7 +183,7 @@ def test_backend_has_no_autograd_key_but_provides_entries(self) -> None: output_error = self.get_errors_from_gen_backend_stubs(yaml_str) self.assertExpectedInline( output_error, """Found an invalid operator name: add""" - ) # noqa: B950 + ) # in an operator group, currently all operators must either be registered to the backend or autograd kernel. # Here, functional and out mismatch @@ -198,7 +198,7 @@ def test_backend_autograd_kernel_mismatch_out_functional(self) -> None: output_error = self.get_errors_from_gen_backend_stubs(yaml_str) self.assertExpectedInline( output_error, - """Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add_out is listed under "autograd".""", # noqa: B950 + """Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add_out is listed under "autograd".""", ) # in an operator group, currently all operators must either be registered to the backend or autograd kernel. @@ -214,7 +214,7 @@ def test_backend_autograd_kernel_mismatch_functional_inplace(self) -> None: output_error = self.get_errors_from_gen_backend_stubs(yaml_str) self.assertExpectedInline( output_error, - """Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add_ is listed under "autograd".""", # noqa: B950 + """Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add_ is listed under "autograd".""", ) # Currently, the same operator can't be listed under both 'supported' and 'autograd', which would @@ -231,7 +231,7 @@ def test_op_appears_in_supported_and_autograd_lists(self) -> None: output_error = self.get_errors_from_gen_backend_stubs(yaml_str) self.assertExpectedInline( output_error, - """Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add is listed under "autograd".""", # noqa: B950 + """Currently, all variants of an op must either be registered to a backend key, or to a backend's autograd key. They cannot be mix and matched. If this is something you need, feel free to create an issue! add is listed under "supported", but add is listed under "autograd".""", ) # unrecognized extra yaml key @@ -245,7 +245,7 @@ def test_unrecognized_key(self) -> None: output_error = self.get_errors_from_gen_backend_stubs(yaml_str) self.assertExpectedInline( output_error, - """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native, ir_gen, symint""", # noqa: B950 + """ contains unexpected keys: invalid_key. Only the following keys are supported: backend, class_name, cpp_namespace, extra_headers, supported, autograd, full_codegen, non_native, ir_gen, symint""", ) # if use_out_as_primary is provided, it must be a bool @@ -260,7 +260,7 @@ def test_use_out_as_primary_non_bool(self) -> None: self.assertExpectedInline( output_error, """You must provide either True or False for use_out_as_primary. Provided: frue""", - ) # noqa: B950 + ) # if device_guard is provided, it must be a bool def test_device_guard_non_bool(self) -> None: @@ -274,7 +274,7 @@ def test_device_guard_non_bool(self) -> None: self.assertExpectedInline( output_error, """You must provide either True or False for device_guard. Provided: frue""", - ) # noqa: B950 + ) def test_incorrect_kernel_name(self) -> None: yaml_str = """\ diff --git a/tools/test/test_stable_shim_usage_linter.py b/tools/test/test_stable_shim_usage_linter.py index 6655dbf305fcc..22c515c3ec7b6 100644 --- a/tools/test/test_stable_shim_usage_linter.py +++ b/tools/test/test_stable_shim_usage_linter.py @@ -71,6 +71,8 @@ def test_get_shim_functions(self): # Primary path (2.10) and secondary path (2.9) from #if/#elif "primary_path": (2, 10), "secondary_path": (2, 9), + # Function with a return type made up of two words. + "function_that_returns_constchar": (2, 12), } self.assertEqual(result, expected) diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py index 70dc67888af30..ea8d3e208db54 100644 --- a/tools/test/test_test_selections.py +++ b/tools/test/test_test_selections.py @@ -473,7 +473,6 @@ def test_split_shards_random(self) -> None: self.assertTrue(sharded_tests[0].time is None) else: # x.time is not None because of the above check - # pyrefly: ignore [no-matching-overload] self.assertAlmostEqual( random_times[test], sum(x.time for x in sharded_tests), # type: ignore[misc] diff --git a/tools/testing/target_determination/heuristics/correlated_with_historical_failures.py b/tools/testing/target_determination/heuristics/correlated_with_historical_failures.py index 6cb3d52b531ff..1fe93a3ef601e 100644 --- a/tools/testing/target_determination/heuristics/correlated_with_historical_failures.py +++ b/tools/testing/target_determination/heuristics/correlated_with_historical_failures.py @@ -19,7 +19,6 @@ class CorrelatedWithHistoricalFailures(HeuristicInterface): def __init__(self, **kwargs: dict[str, Any]) -> None: - # pyrefly: ignore [missing-attribute] super().__init__(**kwargs) def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: diff --git a/tools/testing/target_determination/heuristics/edited_by_pr.py b/tools/testing/target_determination/heuristics/edited_by_pr.py index ed612273e9856..cc4641d85ecc0 100644 --- a/tools/testing/target_determination/heuristics/edited_by_pr.py +++ b/tools/testing/target_determination/heuristics/edited_by_pr.py @@ -29,7 +29,6 @@ class EditedByPR(HeuristicInterface): def __init__(self, **kwargs: dict[str, Any]) -> None: - # pyrefly: ignore [missing-attribute] super().__init__(**kwargs) def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: diff --git a/tools/testing/target_determination/heuristics/filepath.py b/tools/testing/target_determination/heuristics/filepath.py index a775d2e786e8c..0bdc7c226d913 100644 --- a/tools/testing/target_determination/heuristics/filepath.py +++ b/tools/testing/target_determination/heuristics/filepath.py @@ -112,7 +112,6 @@ class Filepath(HeuristicInterface): # Heuristic based on folders in the file path. Takes each folder of each # changed file and attempts to find matches based on those folders def __init__(self, **kwargs: dict[str, Any]) -> None: - # pyrefly: ignore [missing-attribute] super().__init__(**kwargs) def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: diff --git a/tools/testing/target_determination/heuristics/historical_class_failure_correlation.py b/tools/testing/target_determination/heuristics/historical_class_failure_correlation.py index 6a21e791cf3b9..326072dc1f47e 100644 --- a/tools/testing/target_determination/heuristics/historical_class_failure_correlation.py +++ b/tools/testing/target_determination/heuristics/historical_class_failure_correlation.py @@ -29,7 +29,6 @@ class HistoricalClassFailurCorrelation(HeuristicInterface): """ def __init__(self, **kwargs: Any) -> None: - # pyrefly: ignore [missing-attribute] super().__init__(**kwargs) def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: diff --git a/tools/testing/target_determination/heuristics/historical_edited_files.py b/tools/testing/target_determination/heuristics/historical_edited_files.py index 28ec7769353fe..be855538ca254 100644 --- a/tools/testing/target_determination/heuristics/historical_edited_files.py +++ b/tools/testing/target_determination/heuristics/historical_edited_files.py @@ -24,7 +24,6 @@ # a correlation dict is built based on what files were edited in commits on main. class HistorialEditedFiles(HeuristicInterface): def __init__(self, **kwargs: Any) -> None: - # pyrefly: ignore [missing-attribute] super().__init__(**kwargs) def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: diff --git a/tools/testing/target_determination/heuristics/llm.py b/tools/testing/target_determination/heuristics/llm.py index 376909c686d49..6c6a4b1be21eb 100644 --- a/tools/testing/target_determination/heuristics/llm.py +++ b/tools/testing/target_determination/heuristics/llm.py @@ -21,7 +21,6 @@ class LLM(HeuristicInterface): def __init__(self, **kwargs: dict[str, Any]) -> None: - # pyrefly: ignore [missing-attribute] super().__init__(**kwargs) def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: diff --git a/tools/testing/target_determination/heuristics/mentioned_in_pr.py b/tools/testing/target_determination/heuristics/mentioned_in_pr.py index c0066f9138ea7..38a2ae6d5474b 100644 --- a/tools/testing/target_determination/heuristics/mentioned_in_pr.py +++ b/tools/testing/target_determination/heuristics/mentioned_in_pr.py @@ -22,7 +22,6 @@ # mentions "test_foo", test_foo will be rated 1. class MentionedInPR(HeuristicInterface): def __init__(self, **kwargs: Any) -> None: - # pyrefly: ignore [missing-attribute] super().__init__(**kwargs) def _search_for_linked_issues(self, s: str) -> list[str]: diff --git a/tools/testing/target_determination/heuristics/previously_failed_in_pr.py b/tools/testing/target_determination/heuristics/previously_failed_in_pr.py index f407353205350..bf0a9549cc9fd 100644 --- a/tools/testing/target_determination/heuristics/previously_failed_in_pr.py +++ b/tools/testing/target_determination/heuristics/previously_failed_in_pr.py @@ -25,7 +25,6 @@ class PreviouslyFailedInPR(HeuristicInterface): def __init__(self, **kwargs: dict[str, Any]) -> None: - # pyrefly: ignore [missing-attribute] super().__init__(**kwargs) def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: diff --git a/tools/testing/target_determination/heuristics/profiling.py b/tools/testing/target_determination/heuristics/profiling.py index aad00efe99245..8f17c51ca11ee 100644 --- a/tools/testing/target_determination/heuristics/profiling.py +++ b/tools/testing/target_determination/heuristics/profiling.py @@ -22,7 +22,6 @@ # dict (where all ratings are 1). class Profiling(HeuristicInterface): def __init__(self, **kwargs: Any) -> None: - # pyrefly: ignore [missing-attribute] super().__init__(**kwargs) def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: diff --git a/tools/testing/target_determination/heuristics/public_bindings.py b/tools/testing/target_determination/heuristics/public_bindings.py index c1ae5397ccd94..efae0370956ea 100644 --- a/tools/testing/target_determination/heuristics/public_bindings.py +++ b/tools/testing/target_determination/heuristics/public_bindings.py @@ -18,7 +18,6 @@ class PublicBindings(HeuristicInterface): additional_files = ["test/allowlist_for_publicAPI.json"] def __init__(self, **kwargs: dict[str, Any]) -> None: - # pyrefly: ignore [missing-attribute] super().__init__(**kwargs) def get_prediction_confidence(self, tests: list[str]) -> TestPrioritizations: diff --git a/tools/testing/upload_artifacts.py b/tools/testing/upload_artifacts.py index 21a67f0786e2c..633c3f885efd7 100644 --- a/tools/testing/upload_artifacts.py +++ b/tools/testing/upload_artifacts.py @@ -41,13 +41,11 @@ def concated_logs() -> str: for log_file in glob.glob( f"{REPO_ROOT}/test/test-reports/**/*.log", recursive=True ): - # pyrefly: ignore [bad-argument-type] logs.append(f"=== {log_file} ===") with open(log_file) as f: # For every line, prefix with fake timestamp for log classifier for line in f: line = line.rstrip("\n") # Remove any trailing newline - # pyrefly: ignore [bad-argument-type] logs.append(f"2020-01-01T00:00:00.0000000Z {line}") return "\n".join(logs) @@ -131,7 +129,7 @@ def trigger_upload_test_stats_intermediate_workflow() -> None: # The GITHUB_TOKEN cannot trigger workflow so this isn't used for now print("Triggering upload_test_stats_intermediate workflow") x = requests.post( - "https://api.github.com/repos/pytorch/pytorch/actions/workflows/upload_test_stats_intermediate.yml/dispatches", # noqa: B950 @lint-ignore + "https://api.github.com/repos/pytorch/pytorch/actions/workflows/upload_test_stats_intermediate.yml/dispatches", headers={ "Accept": "application/vnd.github.v3+json", "Authorization": f"Bearer {os.environ.get('GITHUB_TOKEN')}", diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 6898a5a5462bf..050ec815f194d 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -147,6 +147,7 @@ class Stream: ) -> Self: ... def query(self) -> _bool: ... def synchronize(self) -> None: ... + def is_capturing(self) -> _bool: ... def wait_event(self, event: Event) -> None: ... def wait_stream(self, other: Stream) -> None: ... def record_event(self, event: Event | None = None) -> Event: ... @@ -1234,6 +1235,10 @@ def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn def _set_mkldnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledMkldnn def _get_cudnn_benchmark() -> _bool: ... # THPModule_benchmarkCuDNN def _set_cudnn_benchmark(arg: _bool) -> None: ... # THPModule_setBenchmarkCuDNN +def _get_cudnn_depthwise_kernel() -> str: ... # THPModule_getCuDNNDepthwiseKernel +def _set_cudnn_depthwise_kernel( + arg: str, +) -> None: ... # THPModule_setCuDNNDepthwiseKernel def _get_miopen_immediate() -> _bool: ... # THPModule_userImmediateMiopen def _set_miopen_immediate(arg: _bool) -> None: ... # THPModule_setUserImmediateMiopen def _get_cudnn_deterministic() -> _bool: ... # THPModule_deterministicCuDNN @@ -1419,6 +1424,7 @@ class BatchNormBackend(Enum): ... # type: ignore[misc] def _get_blas_preferred_backend() -> _BlasBackend: ... def _set_blas_preferred_backend(arg: _BlasBackend): ... +def _get_blas_default_backend() -> _BlasBackend: ... class _BlasBackend: Default: _BlasBackend @@ -2087,6 +2093,12 @@ def _cuda_getDefaultStream(device: _int) -> tuple: ... def _cuda_getStreamFromExternal(data_ptr: _int, device_index: _int) -> tuple: ... def _cuda_getCurrentBlasHandle() -> _int: ... def _cuda_clearCublasWorkspaces() -> None: ... +def _cuda_getCublasWorkspaceSize() -> _int: ... +def _cuda_setCublasWorkspaceSize(size: _int) -> None: ... +def _cuda_getCublasLtWorkspaceSize() -> _int: ... +def _cuda_setCublasLtWorkspaceSize(size: _int) -> None: ... +def _cuda_resetCublasWorkspaceSize() -> None: ... +def _cuda_resetCublasLtWorkspaceSize() -> None: ... def _cuda_setDevice(device: _int) -> None: ... def _cuda_exchangeDevice(device: _int) -> _int: ... def _cuda_maybeExchangeDevice(device: _int) -> _int: ... @@ -2106,6 +2118,7 @@ def _host_emptyCache() -> None: ... def _cuda_cudaCachingAllocator_raw_alloc(size: _int, cuda_stream: _int) -> _int: ... def _cuda_cudaCachingAllocator_raw_delete(ptr: _int) -> None: ... def _cuda_cudaCachingAllocator_enable(val: _bool) -> None: ... +def _cuda_cudaCachingAllocator_is_enabled() -> _bool: ... def _cuda_beginAllocateToPool(device: _int, mempool_id: tuple[_int, _int]) -> None: ... def _cuda_beginAllocateCurrentThreadToPool( device: _int, @@ -2187,6 +2200,7 @@ def _construct_CUDA_Tensor_From_Storage_And_Metadata( ) -> Tensor: ... def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ... def _set_storage_data_ptr_access_error_msg(storage_ptr: _int, s: str) -> None: ... +def _clear_storage_data_ptr_access_error_msg(storage_ptr: _int) -> None: ... def _free_And_Remove_DeleterFn(storage_ptr: _int) -> None: ... def _has_Standard_Deleter(storage_ptr: _int) -> _bool: ... @@ -2195,8 +2209,6 @@ class _cuda_CUDAAllocator: ... def _cuda_customAllocator(alloc_fn: _int, free_fn: _int) -> _cuda_CUDAAllocator: ... def _cuda_changeCurrentAllocator(allocator: _cuda_CUDAAllocator) -> None: ... def _cuda_getAllocator() -> _cuda_CUDAAllocator: ... -def _cuda_lock_mutex() -> None: ... -def _cuda_unlock_mutex() -> None: ... def _cuda_canDeviceAccessPeer(device: _int, peer_device: _int) -> _bool: ... def _cuda_jiterator_compile_and_launch_kernel( code_string: str, @@ -2322,6 +2334,7 @@ def _is_flash_attention_available() -> _bool: ... def _can_use_cudnn_attention(params: _SDPAParams, debug: _bool) -> _bool: ... def _can_use_flash_attention(params: _SDPAParams, debug: _bool) -> _bool: ... def _can_use_mem_efficient_attention(params: _SDPAParams, debug: _bool) -> _bool: ... +def _is_ck_sdpa_available() -> _bool: ... # Defined in torch/csrc/cuda/GdsFile.cpp def _gds_register_buffer(t: Storage) -> None: ... @@ -2491,6 +2504,8 @@ class _XpuDeviceProperties: gpu_eu_count: _int max_work_group_size: _int max_num_sub_groups: _int + memory_clock_rate: _int + memory_bus_width: _int sub_group_sizes: list[_int] local_mem_size: _int has_fp16: _bool @@ -2604,6 +2619,7 @@ def _accelerator_exchangeDevice(device_index: _int) -> _int: ... def _accelerator_maybeExchangeDevice(device_index: _int) -> _int: ... def _accelerator_isAllocatorInitialized() -> _bool: ... def _accelerator_emptyCache() -> None: ... +def _accelerator_emptyHostCache() -> None: ... def _accelerator_getDeviceStats(device_index: _int) -> dict[str, Any]: ... def _accelerator_resetAccumulatedStats(device_index: _int) -> None: ... def _accelerator_resetPeakStats(device_index: _int) -> None: ... @@ -2611,6 +2627,24 @@ def _accelerator_getMemoryInfo(device_index: _int) -> tuple[_int, _int]: ... def _accelerator_getAllocatorSettings() -> str: ... def _accelerator_setAllocatorSettings(env: str) -> None: ... +class _acceleratorGraph: + def __new__(cls, _bool) -> Self: ... + def __init__(self, _bool) -> None: ... + def capture_begin( + self, + pool: tuple[_int, _int] | None = None, + capture_error_mode: Literal[ + "default", "global", "thread_local", "relaxed" + ] = "default", + ) -> None: ... + def capture_end(self) -> None: ... + def instantiate(self) -> None: ... + def replay(self) -> None: ... + def reset(self) -> None: ... + def pool(self) -> tuple[_int, _int]: ... + def enable_debug_mode(self) -> None: ... + def debug_dump(self, path: str) -> None: ... + # Defined in torch/csrc/jit/python/python_tracer.cpp class TracingState: def push_scope(self, scope_name: str) -> None: ... diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 814ab243bde2a..733fb3410db9b 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -62,6 +62,7 @@ class _KinetoEvent: def duration_ns(self) -> int: ... def is_async(self) -> bool: ... def linked_correlation_id(self) -> int: ... + def external_id(self) -> int: ... def shapes(self) -> list[list[int]]: ... def dtypes(self) -> list[str]: ... def concrete_inputs(self) -> list[Any]: ... @@ -78,8 +79,19 @@ class _KinetoEvent: def cuda_elapsed_us(self) -> int: ... def privateuse1_elapsed_us(self) -> int: ... def is_user_annotation(self) -> bool: ... + def is_python_function(self) -> bool: ... def is_hidden_event(self) -> bool: ... def metadata_json(self) -> str: ... + def activity_type(self) -> str: ... + def extra_meta(self) -> dict[str, str]: ... + def flow_id(self) -> int: ... + def flow_type(self) -> int: ... + def flow_start(self) -> bool: ... + def structured_input_shapes(self) -> list[list[int] | list[list[int]]]: ... + def structured_input_strides(self) -> list[list[int] | list[list[int]]]: ... + def python_id(self) -> int: ... + def python_parent_id(self) -> int: ... + def python_module_id(self) -> int: ... class _ProfilerResult: def events(self) -> list[_KinetoEvent]: ... diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 613ec83afa722..0c4aefc3f265d 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -55,6 +55,7 @@ class Reducer: skip_all_reduce_unused_params: bool = ..., use_python_reducer: bool = ..., bucket_bytes_cap_list: list[int] = ..., + batched_grad_copy: bool = ..., ) -> None: ... def prepare_for_forward(self) -> None: ... def prepare_for_backward(self, output: list[Tensor]) -> None: ... @@ -146,7 +147,9 @@ class ReduceOp: # stub with zero members. There is a chance this is due to a recent change # in the semantics of enum membership. If so, use `member = value` to mark # an enum member, instead of `member: type` - class RedOpType(Enum): ... # type: ignore[misc] + class RedOpType(Enum): + def __call__(self, factor: float | int | Tensor) -> ReduceOp: + """Create a PREMUL_SUM ReduceOp with the given factor. Only PREMUL_SUM supports this.""" class BroadcastOptions: rootRank: int @@ -797,6 +800,8 @@ class _SymmetricMemory: @staticmethod def get_backend(device: torch.device) -> str | None: ... @staticmethod + def is_symm_mem_tensor(tensor: torch.Tensor) -> bool: ... + @staticmethod def get_mempool_allocator(device: torch.device) -> Any: ... signal_pad_size: int @property @@ -893,3 +898,5 @@ class _Response: def _register_handler( name: str, handler: Callable[[_Request, _Response], None] ) -> None: ... +def _set_comm_profiling_name(name: str) -> None: ... +def _get_comm_profiling_name() -> str: ... diff --git a/torch/_C/_distributed_rpc.pyi b/torch/_C/_distributed_rpc.pyi index 48f636d852463..ff0e5b8379a20 100644 --- a/torch/_C/_distributed_rpc.pyi +++ b/torch/_C/_distributed_rpc.pyi @@ -78,8 +78,8 @@ class _TensorPipeRpcBackendOptionsBase(RpcBackendOptions): _channels: list | None, rpc_timeout: float = ..., init_method: str = ..., - device_maps: dict[str, dict[torch.device, torch.device]] = {}, # noqa: B006 - devices: list[torch.device] = [], # noqa: B006 + device_maps: dict[str, dict[torch.device, torch.device]] = {}, + devices: list[torch.device] = [], ) -> None: ... def _set_device_map( self, diff --git a/torch/_C/_dynamo/__init__.pyi b/torch/_C/_dynamo/__init__.pyi index 67d515697cbe4..ab9470c2afac9 100644 --- a/torch/_C/_dynamo/__init__.pyi +++ b/torch/_C/_dynamo/__init__.pyi @@ -2,3 +2,68 @@ from . import compiled_autograd, eval_frame, guards # noqa: F401 def strip_function_call(name: str) -> str: ... def is_valid_var_name(name: str) -> bool | int: ... +def get_type_slots(obj: type | object) -> tuple[int, int, int, int]: ... +def has_slot(slots: int, slot_bit: int) -> bool: ... + +class PySequenceSlots: + SQ_LENGTH: int + SQ_CONCAT: int + SQ_REPEAT: int + SQ_ITEM: int + SQ_CONTAINS: int + SQ_ASS_ITEM: int + SQ_INPLACE_CONCAT: int + SQ_INPLACE_REPEAT: int + +class PyMappingSlots: + MP_LENGTH: int + MP_SUBSCRIPT: int + MP_ASS_SUBSCRIPT: int + +class PyNumberSlots: + NB_ADD: int + NB_SUBTRACT: int + NB_MULTIPLY: int + NB_REMAINDER: int + NB_POWER: int + NB_NEGATIVE: int + NB_POSITIVE: int + NB_ABSOLUTE: int + NB_BOOL: int + NB_INVERT: int + NB_LSHIFT: int + NB_RSHIFT: int + NB_AND: int + NB_XOR: int + NB_OR: int + NB_INT: int + NB_FLOAT: int + NB_INPLACE_ADD: int + NB_INPLACE_SUBTRACT: int + NB_INPLACE_MULTIPLY: int + NB_INPLACE_REMAINDER: int + NB_INPLACE_POWER: int + NB_INPLACE_LSHIFT: int + NB_INPLACE_RSHIFT: int + NB_INPLACE_AND: int + NB_INPLACE_XOR: int + NB_INPLACE_OR: int + NB_FLOOR_DIVIDE: int + NB_TRUE_DIVIDE: int + NB_INPLACE_FLOOR_DIVIDE: int + NB_INPLACE_TRUE_DIVIDE: int + NB_INDEX: int + NB_MATRIX_MULTIPLY: int + NB_INPLACE_MATRIX_MULTIPLY: int + +class PyTypeSlots: + TP_HASH: int + TP_ITER: int + TP_ITERNEXT: int + TP_CALL: int + TP_REPR: int + TP_RICHCOMPARE: int + TP_GETATTRO: int + TP_SETATTRO: int + TP_DESCR_GET: int + TP_DESCR_SET: int diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index ea51e8cc8fe9f..a57eae9fd8779 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -95,10 +95,5 @@ def _load_precompile_entry( ) -> None: ... def _reset_precompile_entries(code: types.CodeType) -> None: ... def _debug_get_precompile_entries(code: types.CodeType) -> list[_PrecompileEntry]: ... - -class _EvalFrameOverride(enum.IntEnum): - NONE = 0 - SKIP = 1 - ERROR = 2 - -def set_eval_frame_override(override: _EvalFrameOverride) -> _EvalFrameOverride: ... +def set_fullgraph_compiled_frame_count(value: int) -> int: ... +def set_fullgraph_error_on_nested_compile(value: bool) -> bool: ... diff --git a/torch/_C/_dynamo/guards.pyi b/torch/_C/_dynamo/guards.pyi index 4ef645ffc3e61..ad443ae262573 100644 --- a/torch/_C/_dynamo/guards.pyi +++ b/torch/_C/_dynamo/guards.pyi @@ -490,6 +490,7 @@ def assert_alignment( alignment: int, op_name: str | None = None, ) -> None: ... +def copy_if_misaligned(item: torch.Tensor) -> torch.Tensor: ... def check_obj_id(obj: object, expected: int) -> bool: ... def check_type_id(obj: object, expected: int) -> bool: ... def dict_version(d: dict[Any, Any]) -> int: ... diff --git a/torch/_C/_nn.pyi.in b/torch/_C/_nn.pyi.in index 7be3dcff4da67..cf6e98fe42e30 100644 --- a/torch/_C/_nn.pyi.in +++ b/torch/_C/_nn.pyi.in @@ -129,6 +129,12 @@ def _upsample_bicubic2d_aa( align_corners: bool, scale_factors: Sequence[float] | None, ) -> Tensor: ... +def _upsample_lanczos2d_aa( + input: Tensor, + output_size: Sequence[int] | None, + align_corners: bool, + scale_factors: Sequence[float] | None, +) -> Tensor: ... def upsample_bicubic2d( input: Tensor, output_size: Sequence[int] | None, diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi index 04c3f4a460086..f66764a449616 100644 --- a/torch/_C/_profiler.pyi +++ b/torch/_C/_profiler.pyi @@ -62,7 +62,13 @@ class _ExperimentalConfig: verbose: bool = ..., performance_events: list[str] = ..., enable_cuda_sync_events: bool = ..., + adjust_profiler_step: bool = ..., + disable_external_correlation: bool = ..., profile_all_threads: bool = ..., + capture_overload_names: bool = ..., + record_python_gc_info: bool = ..., + expose_kineto_event_metadata: bool = ..., + custom_profiler_config: str = ..., ) -> None: ... class ProfilerConfig: @@ -109,6 +115,8 @@ class _ProfilerEvent: @property def name(self) -> str: ... @property + def overload_name(self) -> str: ... + @property def tag(self) -> _EventType: ... @property def id(self) -> int: ... diff --git a/torch/__init__.py b/torch/__init__.py index 1205819aa2594..5cbb4d949cdcf 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -29,6 +29,7 @@ get_origin as _get_origin, overload as _overload, TYPE_CHECKING, + TypeGuard as _TypeGuard, TypeVar as _TypeVar, ) from typing_extensions import ( @@ -1171,7 +1172,7 @@ def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]: return isinstance(obj, torch.Tensor) -def is_storage(obj: _Any, /) -> builtins.bool: +def is_storage(obj: _Any, /) -> _TypeGuard["TypedStorage | UntypedStorage"]: r"""Returns True if `obj` is a PyTorch storage object. Args: @@ -1711,7 +1712,7 @@ def _check_with( error_type, cond: builtins.bool | SymBool, message: _Callable[[], str], -): # noqa: F811 +): if not isinstance(cond, (builtins.bool, SymBool)): raise TypeError(f"cond must be a bool, but got {type(cond)}") @@ -1742,7 +1743,7 @@ def _check_with( raise error_type(message_evaluated) -def _check(cond, message=None): # noqa: F811 +def _check(cond, message=None): r"""Throws error containing an optional message if the specified condition is False. @@ -1796,7 +1797,7 @@ def _check_is_size(i, message=None, *, max=None): _advise_is_bounded(i, max) -def _check_index(cond, message=None): # noqa: F811 +def _check_index(cond, message=None): r"""Throws error containing an optional message if the specified condition is False. @@ -1814,7 +1815,7 @@ def _check_index(cond, message=None): # noqa: F811 _check_with(IndexError, cond, message) # pyrefly: ignore [bad-argument-type] -def _check_value(cond, message=None): # noqa: F811 +def _check_value(cond, message=None): r"""Throws error containing an optional message if the specified condition is False. @@ -1832,7 +1833,7 @@ def _check_value(cond, message=None): # noqa: F811 _check_with(ValueError, cond, message) # pyrefly: ignore [bad-argument-type] -def _check_type(cond, message=None): # noqa: F811 +def _check_type(cond, message=None): r"""Throws error containing an optional message if the specified condition is False. @@ -1850,7 +1851,7 @@ def _check_type(cond, message=None): # noqa: F811 _check_with(TypeError, cond, message) # pyrefly: ignore [bad-argument-type] -def _check_not_implemented(cond, message=None): # noqa: F811 +def _check_not_implemented(cond, message=None): r"""Throws error containing an optional message if the specified condition is False. @@ -1873,7 +1874,7 @@ def _check_not_implemented(cond, message=None): # noqa: F811 ) -def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811 +def _check_tensor_all_with(error_type, cond, message=None): if not is_tensor(cond): raise TypeError(f"cond must be a tensor, but got {type(cond)}") @@ -1884,7 +1885,7 @@ def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811 # C++ equivalent: `TORCH_CHECK_TENSOR_ALL` -def _check_tensor_all(cond, message=None): # noqa: F811 +def _check_tensor_all(cond, message=None): r"""Throws error containing an optional message if the specified condition is False. @@ -2225,7 +2226,10 @@ def _manager_path(): __all__.extend( - name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype) + # pyrefly: ignore [unresolvable-dunder-all] + name + for name in dir(torch) + if isinstance(getattr(torch, name), torch.dtype) ) ################################################################################ @@ -2403,11 +2407,12 @@ def compiled_with_cxx11_abi() -> builtins.bool: class _TorchCompileInductorWrapper: compiler_name = "inductor" - def __init__(self, mode, options, dynamic): + def __init__(self, mode, options, dynamic, name=None): from torch._inductor.compiler_bisector import CompilerBisector self.config: dict[str, _Any] = {} self.dynamic = dynamic + self.name = name self.apply_mode(mode) self.apply_options(options) self.apply_options(CompilerBisector.get_config_change("inductor")) @@ -2434,6 +2439,7 @@ def __eq__(self, other): isinstance(other, _TorchCompileInductorWrapper) and self.config == other.config and self.dynamic == other.dynamic + and self.name == other.name ) def apply_mode(self, mode: str | None): @@ -2473,7 +2479,12 @@ def __call__(self, model_, inputs_, *, config_patches=None): from torch._inductor.compile_fx import compile_fx all_patches = {**self.config, **(config_patches or {})} - return compile_fx(model_, inputs_, config_patches=all_patches) + return compile_fx( + model_, + inputs_, + config_patches=all_patches, + compile_region_name=self.name, + ) def get_compiler_config(self): from torch._inductor.compile_fx import get_patched_config_dict @@ -2493,8 +2504,8 @@ def reset(self): class _TorchCompileAOTInductorWrapper(_TorchCompileInductorWrapper): compiler_name = "aotinductor" - def __init__(self, mode, options, dynamic): - super().__init__(mode, options, dynamic) + def __init__(self, mode, options, dynamic, name=None): + super().__init__(mode, options, dynamic, name) self.apply_options({"cpp_wrapper": True}) self.apply_options({"aot_inductor.package": True}) @@ -2567,6 +2578,7 @@ def compile( backend: str | _Callable = "inductor", mode: str | None = None, options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None, + name: str | None = None, disable: builtins.bool = False, ) -> _Callable[_InputT, _RetT]: ... @@ -2580,6 +2592,7 @@ def compile( backend: str | _Callable = "inductor", mode: str | None = None, options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None, + name: str | None = None, disable: builtins.bool = False, ) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ... @@ -2589,10 +2602,12 @@ def compile( *, fullgraph: builtins.bool = False, dynamic: builtins.bool | None = None, - backend: str | _Callable = "inductor", + backend: str | _Callable | None = None, mode: str | None = None, options: dict[str, str | builtins.int | builtins.bool | _Callable] | None = None, + name: str | None = None, disable: builtins.bool = False, + recompile_limit: builtins.int | None = None, ) -> ( _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]] | _Callable[_InputT, _RetT] @@ -2682,6 +2697,8 @@ def compile( - `torch.compiler.keep_tensor_guards_unsafe` - For inductor you can see the full list of configs that it supports by calling `torch._inductor.list_options()` + name (str or None): Optional identifier for the compiled region. When supported by downstream + tooling, this is surfaced on wrapped compiled-region higher-order operators and other debug metadata. disable (bool): Turn torch.compile() into a no-op for testing Example:: @@ -2706,6 +2723,11 @@ def foo(x): "Please use Python 3.13.3+." ) + if backend is None: + from torch._dynamo.backends.registry import get_default_backend + + backend = get_default_backend() + # Decorator mode if model is None: @@ -2719,6 +2741,7 @@ def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]: backend=backend, mode=mode, options=options, + name=name, disable=disable, ) @@ -2764,9 +2787,9 @@ def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]: if backend == "inductor": if use_aoti: - backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic) + backend = _TorchCompileAOTInductorWrapper(mode, options, dynamic, name) else: - backend = _TorchCompileInductorWrapper(mode, options, dynamic) + backend = _TorchCompileInductorWrapper(mode, options, dynamic, name) else: backend = _TorchCompileWrapper(backend, mode, options, dynamic) @@ -2776,6 +2799,7 @@ def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]: dynamic=dynamic, disable=disable, guard_filter_fn=guard_filter_fn, + recompile_limit=recompile_limit, )(model) # type: ignore[return-value] @@ -2893,6 +2917,13 @@ def __getattr__(name): if name in _lazy_modules: return importlib.import_module(f".{name}", __name__) + # set_vital + if name == "set_vital": + import warnings + + warnings.warn(f"'{name}' is deprecated, please do not call", stacklevel=2) + return lambda *args: None + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index 9cbbb34e07dee..2553f9e60de96 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -371,6 +371,7 @@ def _core_aten_decompositions_post_autograd() -> dict[ aten.hardswish_backward, aten.hardtanh_, aten.hardtanh_backward, + aten.hann_window, aten.heaviside, aten.heaviside_, aten.huber_loss, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 6c0bfda496cd7..75087863b52f9 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -330,7 +330,7 @@ def rrelu_with_noise_backward( training: bool, self_is_result: bool, ) -> Tensor: - if training and upper - lower > 1e-6: + if training: return grad_output.mul(noise) else: negative_slope = (lower + upper) / 2 @@ -1517,7 +1517,7 @@ def tensor_split_tensor_indices_or_sections_py_impl( # TODO: this doesn't appear to have enough precision in bfloat16 -@register_decomposition(aten.addmm) +@register_decomposition([aten.addmm.default, aten.addmm.out]) @out_wrapper(exact_dtype=True) @pw_cast_for_opmath def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1): @@ -1537,6 +1537,22 @@ def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = return out + beta * self +@register_decomposition([aten.addmm.dtype, aten.addmm.dtype_out]) +@out_wrapper(exact_dtype=True) +def addmm_dtype( + self: Tensor, + mat1: Tensor, + mat2: Tensor, + out_dtype: torch.dtype, + beta: int = 1, + alpha: int = 1, +): + out = alpha * torch.mm(mat1, mat2, out_dtype=out_dtype) + if beta == 0: + return out + return out + beta * self.to(out_dtype) + + @register_decomposition(aten._addmm_activation) @out_wrapper() @pw_cast_for_opmath @@ -3953,6 +3969,18 @@ def upsample_bicubic2d_aa_vec(input, output_size, align_corners, scale_factors): ) +@register_decomposition(aten._upsample_lanczos2d_aa.vec) +@aten._upsample_lanczos2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd) +@aten._upsample_lanczos2d_aa.vec.py_impl(DispatchKey.Autograd) +def upsample_lanczos2d_aa_vec(input, output_size, align_corners, scale_factors): + osize = upsample_compute_output_size(input.size(), output_size, scale_factors) + scale_h = get_scale_value(scale_factors, 0) + scale_w = get_scale_value(scale_factors, 1) + return torch.ops.aten._upsample_lanczos2d_aa( + input, osize, align_corners, scale_h, scale_w + ) + + @register_decomposition(aten.upsample_bilinear2d.vec) @register_decomposition(aten.upsample_trilinear3d.vec) @aten.upsample_linear1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd) @@ -4631,6 +4659,9 @@ def binary_cross_entropy_with_logits( if weight is not None: loss = loss * weight + # this is to align the resulted data type with the in-place + # operation in binary_cross_entropy_with_logits of aten/src/ATen/native/Loss.cpp + loss = loss.to(target.dtype) return apply_loss_reduction(loss, reduction) @@ -5384,7 +5415,9 @@ def isin(elements, test_elements, *, assume_unique=False, invert=False): else: return torch.eq(elements, test_elements) - if test_elements.numel() < 10.0 * pow(elements.numel(), 0.145): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + if guard_or_false(test_elements.numel() < 10.0 * pow(elements.numel(), 0.145)): return isin_default(elements, test_elements, invert=invert) else: return isin_sorting( @@ -5469,6 +5502,219 @@ def resize_as(self, other, memory_format=None): return aten.resize(self, other.shape, memory_format=memory_format) +@register_decomposition(aten.max_pool2d_with_indices_backward) +def max_pool2d_with_indices_backward( + grad_output: Tensor, + self: Tensor, + kernel_size, + stride, + padding, + dilation, + ceil_mode: bool, + indices: Tensor, +): + """ + Decomposition of max_pool2d_with_indices_backward using scatter_add. + + This replaces the native implementation with a high-level decomposition + that uses scatter_add for gradient accumulation. The scatter-based approach + provides automatic optimization opportunities for Inductor and handles all + pooling configurations without requiring specialized fallback paths. + + Algorithm: + For each output gradient position, use the corresponding index from the + forward pass to scatter the gradient to the input position. When multiple + output positions select the same input position as max, scatter_add + automatically accumulates their gradients. + + Complexity: O(B * C * H_out * W_out) + Independent of kernel size, unlike traditional O(B * C * H_in * W_in * K²) + approaches that iterate over input positions and kernel windows. + + Known Limitations: + - FP16/BF16: Uses FP32 accumulation internally to preserve precision when + many gradients accumulate to the same position (overlapping pooling windows). + This adds slight overhead but ensures numerical stability. + - Deterministic mode: Falls back to native implementation to ensure + consistent results across runs + + Args: + grad_output: Gradient w.r.t. pooling output [B, C, H_out, W_out] + self: Original input tensor (for shape) [B, C, H_in, W_in] + kernel_size: Pooling kernel size + stride: Pooling stride + padding: Pooling padding + dilation: Pooling dilation + ceil_mode: Whether to use ceil for output size calculation + indices: Indices from forward pass (per-channel linear positions) + + Returns: + Gradient w.r.t. input [B, C, H_in, W_in] + """ + # Use native kernel in deterministic mode + if torch.are_deterministic_algorithms_enabled(): + return NotImplemented + + # MPS: Use native kernel. scatter_add has correctness issues on macOS 14 + # (#163327) and numerical differences on macOS 15+. + if grad_output.device.type == "mps": + return NotImplemented + + # Get spatial dimensions + in_height = self.size(-2) + in_width = self.size(-1) + out_height = grad_output.size(-2) + out_width = grad_output.size(-1) + + # Handle both 3D (C, H, W) and 4D (B, C, H, W) cases by treating 3D as 4D + is_batched = self.dim() == 4 + if not is_batched: + self = self.unsqueeze(0) + grad_output = grad_output.unsqueeze(0) + indices = indices.unsqueeze(0) + + batch_size = self.size(0) + channels = self.size(1) + + # For FP16/BF16, use FP32 accumulation to avoid precision loss + # This is critical when many gradients accumulate to the same position + # (overlapping pooling windows with large kernels or stride < kernel_size) + use_fp32_accum = grad_output.dtype in (torch.float16, torch.bfloat16) + accum_dtype = torch.float32 if use_fp32_accum else grad_output.dtype + + # Create grad_input with correct accumulation dtype from the start + grad_input_flat = torch.zeros( + batch_size * channels, + in_height * in_width, + dtype=accum_dtype, + device=grad_output.device, + ) + + # Reshape grad_output and indices to (B*C, H_out*W_out) + grad_output_flat = grad_output.reshape( + batch_size * channels, out_height * out_width + ) + indices_flat = indices.reshape(batch_size * channels, out_height * out_width) + + # Convert grad_output to accumulation dtype if needed + if use_fp32_accum: + grad_output_flat = grad_output_flat.to(torch.float32) + + # Scatter gradients to input positions + grad_input_flat = grad_input_flat.scatter_add(1, indices_flat, grad_output_flat) + + # Reshape back to original input shape + grad_input = grad_input_flat.reshape(batch_size, channels, in_height, in_width) + + # Convert back to original dtype if we used FP32 accumulation + if use_fp32_accum: + grad_input = grad_input.to(grad_output.dtype) + + # Preserve memory format from input (channels_last vs channels_first) + memory_format = utils.suggest_memory_format(self) + grad_input = grad_input.contiguous(memory_format=memory_format) + + # Remove batch dimension for 3D case + if not is_batched: + grad_input = grad_input.squeeze(0) + + return grad_input + + +@register_decomposition([aten.hann_window.default, aten.hann_window.out]) +@out_wrapper() +def hann_window( + window_length: int, + *, + dtype: torch.dtype | None = None, + layout: torch.layout | None = None, + device: torch.device | None = None, + pin_memory: bool | None = None, +) -> Tensor: + """hann_window(window_length, *, dtype=None, layout=None, device=None, pin_memory=False) -> Tensor + + Returns a Hann window of size :attr:`window_length` with ``periodic=True``. + + Equivalent to :func:`torch.hann_window` with ``periodic=True``. + + Args: + window_length (int): the size of returned window. + + Keyword args: + dtype (:class:`torch.dtype`, optional): desired dtype. Default: global default. + layout (:class:`torch.layout`, optional): desired layout. Default: ``torch.strided``. + device (:class:`torch.device`, optional): desired device. Default: current device. + pin_memory (bool, optional): if ``True``, pins the returned tensor. Default: ``False``. + """ + return aten.hann_window.periodic( + window_length, + True, + dtype=dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + ) + + +@register_decomposition([aten.hann_window.periodic, aten.hann_window.periodic_out]) +@out_wrapper() +def hann_window_periodic( + window_length: int, + periodic: bool = True, + *, + dtype: torch.dtype | None = None, + layout: torch.layout | None = None, + device: torch.device | None = None, + pin_memory: bool | None = None, +) -> Tensor: + r"""hann_window(window_length, periodic=True, *, dtype=None, layout=None, device=None, pin_memory=False) -> Tensor + + Returns a Hann window of size :attr:`window_length`. + + .. math:: + w[n] = 0.5 - 0.5 \cos\!\left(\frac{2\pi n}{N-1}\right) + + where :math:`N` is ``window_length + 1`` when ``periodic=True`` (for spectral analysis), + or ``window_length`` when ``periodic=False`` (symmetric window). + + Low-precision dtypes (``bfloat16``, ``float16``) are computed in ``float32`` then cast. + + Args: + window_length (int): the size of returned window. + periodic (bool, optional): if ``True``, returns a periodic window for use with STFT. + Default: ``True``. + + Keyword args: + dtype (:class:`torch.dtype`, optional): desired dtype. Default: global default. + layout (:class:`torch.layout`, optional): desired layout. Default: ``torch.strided``. + device (:class:`torch.device`, optional): desired device. Default: current device. + pin_memory (bool, optional): if ``True``, pins the returned tensor. Default: ``False``. + """ + dtype = dtype if dtype is not None else torch.get_default_dtype() + if window_length == 0: + return torch.empty( + (0,), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + if window_length == 1: + return torch.ones( + (1,), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory + ) + compute_dtype = utils.get_computation_dtype(dtype) + n = window_length + 1 if periodic else window_length + t = torch.arange( + n, + dtype=compute_dtype, + layout=layout, + device=device, + pin_memory=pin_memory, + ) + t = t * (2.0 * torch.pi / (n - 1)) + t = torch.cos(t) + t = t * -0.5 + 0.5 + window = t.narrow(0, 0, window_length) if periodic else t + return window.to(dtype) + + register_inplace(aten.addbmm_, aten.addbmm) register_inplace(aten.addmm_, aten.addmm) register_inplace(aten.addmv_, aten.addmv) @@ -5495,3 +5741,22 @@ def resize_as(self, other, memory_format=None): register_inplace(aten.scatter_add_, aten.scatter_add) register_inplace(aten.scatter_reduce_, aten.scatter_reduce) register_inplace(aten.silu_, aten.silu) + + +@aten.one_hot.default.py_impl(DispatchKey.CompositeImplicitAutograd) +def one_hot(self: Tensor, num_classes: int = -1) -> Tensor: + if num_classes == -1: + num_classes = int(self.max().item()) + 1 + # _assert_async is side-effectful and won't be DCE'd + aten._assert_async.msg( + torch.all(self >= 0), + "one_hot: Class values must be non-negative.", + ) + aten._assert_async.msg( + torch.all(self < num_classes), + "one_hot: Class values must be smaller than num_classes.", + ) + return ( + self.unsqueeze(-1) + == torch.arange(num_classes, dtype=self.dtype, device=self.device) + ).to(torch.int64) diff --git a/torch/_decomp/decompositions_for_rng.py b/torch/_decomp/decompositions_for_rng.py index 27b29ca5f85aa..d1753b1baf5a8 100644 --- a/torch/_decomp/decompositions_for_rng.py +++ b/torch/_decomp/decompositions_for_rng.py @@ -78,7 +78,7 @@ def reset(self): self.seed = torch.tensor(()) self.base_offset = torch.tensor(()) self.relative_offset = 0 - self.offset_advanced_alteast_once = False + self.offset_advanced_at_least_once = False def validate_state(self): if self.seed.numel() == 0 or self.base_offset.numel() == 0: @@ -88,7 +88,7 @@ def validate_state(self): ) def advance_offset(self, consumed_offset): - self.offset_advanced_alteast_once = True + self.offset_advanced_at_least_once = True self.relative_offset = self.relative_offset + consumed_offset def set_state(self, seed, base_offset, relative_offset=0): @@ -205,7 +205,7 @@ def multiple_of_4(offset): @classmethod def get_updated_fwd_offset(cls): # Short circuit if no rand ops were observed - if not cls.fwd_state.offset_advanced_alteast_once: + if not cls.fwd_state.offset_advanced_at_least_once: return cls.fwd_state.base_offset return cls.multiple_of_4( cls.fwd_state.base_offset + cls.fwd_state.relative_offset @@ -214,7 +214,7 @@ def get_updated_fwd_offset(cls): @classmethod def get_updated_bwd_offset(cls): # Short circuit if no rand ops were observed - if not cls.bwd_state.offset_advanced_alteast_once: + if not cls.bwd_state.offset_advanced_at_least_once: return cls.bwd_state.base_offset return cls.multiple_of_4( cls.bwd_state.base_offset + cls.bwd_state.relative_offset diff --git a/torch/_dynamo/CLAUDE.md b/torch/_dynamo/CLAUDE.md index 73aecf050fe1f..e72321277fdb9 100644 --- a/torch/_dynamo/CLAUDE.md +++ b/torch/_dynamo/CLAUDE.md @@ -47,7 +47,7 @@ Key fields: `source` (where the value came from, for guards) and Key subclass families in `variables/`: `TensorVariable` / `SymNodeVariable` (tensor.py), `ConstantVariable` (constant.py), `ListVariable` / -`TupleVariable` (lists.py), `ConstDictVariable` / `SetVariable` (dicts.py), +`TupleVariable` (lists.py), `ConstDictVariable` (dicts.py), `SetVariable` (sets.py), `UserFunctionVariable` (functions.py), `BuiltinVariable` (builtin.py), `NNModuleVariable` (nn_module.py), `UserDefinedObjectVariable` (user_defined.py), `TorchHigherOrderOperatorVariable` (higher_order_ops.py), diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index f258338c3baf9..692151a382d75 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -40,6 +40,7 @@ maybe_mark_dynamic, nonstrict_trace, override_cudagraphs, + override_optimization_hint, patch_dynamo_config, run, set_stance, @@ -75,7 +76,7 @@ # Register polyfill functions -from .polyfills import loader as _ # usort: skip # noqa: F401 +from .polyfills import loader as _ # usort: skip __all__ = [ @@ -100,6 +101,7 @@ "mark_static", "mark_static_address", "nonstrict_trace", + "override_optimization_hint", "optimize", "optimize_assert", "OptimizedModule", diff --git a/torch/_dynamo/aot_compile.py b/torch/_dynamo/aot_compile.py index f1d3b30c5c4e5..7ac01d6cb5488 100644 --- a/torch/_dynamo/aot_compile.py +++ b/torch/_dynamo/aot_compile.py @@ -387,12 +387,20 @@ def new_guard_filter_fn( compiled_fn = backend( backend_input.graph_module, backend_input.example_inputs ) - # If Inductor backend is used, grab the compiled_fn from PrecompileContext + # If Inductor backend or AOTAutograd-based backend is used, + # wrap the compiled_fn for serialization. # TODO: this should be replaced once we make the backend return the SerializableCallable directly. - if isinstance(backend, torch._TorchCompileInductorWrapper) or ( - hasattr(backend, "compiler_fn") - and isinstance( - backend.compiler_fn, torch._dynamo.backends.common.AotAutograd + if ( + isinstance(backend, torch._TorchCompileInductorWrapper) + or ( + hasattr(backend, "compiler_fn") + and isinstance( + backend.compiler_fn, torch._dynamo.backends.common.AotAutograd + ) + ) + or ( + hasattr(compiled_fn, "serialize") + and compiled_fn.serialize is not None ) ): compiled_fn = BundledAOTAutogradSerializableCallable(compiled_fn) diff --git a/torch/_dynamo/backends/common.py b/torch/_dynamo/backends/common.py index 0ba0588007ef3..0b60e0d74a623 100644 --- a/torch/_dynamo/backends/common.py +++ b/torch/_dynamo/backends/common.py @@ -19,7 +19,7 @@ import contextlib import functools import logging -from collections.abc import Callable, Iterable +from collections.abc import Callable, Iterable, Sequence from typing import Any from typing_extensions import ParamSpec, TypeVar from unittest.mock import patch @@ -47,7 +47,7 @@ def __init__(self, **kwargs: Any) -> None: self.kwargs = kwargs def __call__( - self, gm: torch.fx.GraphModule, example_inputs: Iterable[Any], **kwargs: Any + self, gm: torch.fx.GraphModule, example_inputs: Sequence[Any], **kwargs: Any ) -> Callable[..., Any]: if kwargs: log.warning("aot_autograd-based backend ignoring extra kwargs %s", kwargs) diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index e6642584d7ccd..eacb304abf7b1 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -36,6 +36,7 @@ from torch._dynamo.output_graph import GraphCompileReason from torch._functorch import config as functorch_config from torch._functorch.compilers import ts_compile +from torch._inductor.output_code import OutputCode from .common import aot_autograd from .registry import CompiledFn, CompilerFn, register_debug_backend as register_backend @@ -166,6 +167,7 @@ def invoke_subgraph_inner_compiler( from torch._higher_order_ops.invoke_subgraph import invoke_subgraph_infer @disable + # pyrefly: ignore [deprecated] @torch._dynamo.allow_in_graph def invoke_subgraph_wrapper_unboxed(*operands: Any) -> Any: return invoke_subgraph_infer(subgraph, *operands) @@ -260,6 +262,54 @@ def invoke_subgraph( )(gm, fake_tensor_inputs) +@dataclasses.dataclass +class AOTEagerOutputCode(OutputCode): + """ + An OutputCode that wraps a GraphModule for eager-mode execution. + + This allows non-inductor backends (like aot_eager) to participate in + the bundled autograd cache and aot_compile serialization flow. + """ + + gm: torch.fx.GraphModule | None = None + _serialized_gm: bytes | None = dataclasses.field(default=None, init=False) + + def __call__(self, inputs: Any) -> Any: + assert self.gm is not None + return self.gm.forward(inputs) + + def prepare_for_serialization(self) -> None: + from torch.fx._graph_pickler import GraphPickler, Options + + assert self.gm is not None + for node in self.gm.graph.nodes: + node.meta.pop("nn_module_stack", None) + node.meta.pop("source_fn_stack", None) + node.meta.pop("example_value", None) + + self._serialized_gm = GraphPickler.dumps(self.gm, Options(ops_filter=None)) + self.gm = None + + def post_compile(self, *args: Any, **kwargs: Any) -> None: + if self.gm is None and self._serialized_gm is not None: + from torch._subclasses import FakeTensorMode + from torch.fx._graph_pickler import GraphPickler + from torch.fx.experimental.symbolic_shapes import ShapeEnv + from torch.fx.graph import _BoxedCodeGen + + fake_mode = FakeTensorMode(shape_env=ShapeEnv()) + gm = GraphPickler.loads(self._serialized_gm, fake_mode) + assert isinstance(gm, torch.fx.GraphModule) + self.gm = gm + assert isinstance(self.gm, torch.fx.GraphModule) + self.gm.graph.set_codegen(_BoxedCodeGen()) + self.gm.recompile() + self._serialized_gm = None + + def set_triton_bundle(self, triton_bundle: Any) -> None: + pass + + # used boxed call to discard inputs when they are no longer needed def boxed_nop( fx_g: torch.fx.GraphModule, example_inputs: list[torch.Tensor] @@ -270,6 +320,11 @@ def boxed_nop( fx_g.graph.set_codegen(_BoxedCodeGen()) fx_g.recompile() + if functorch_config.force_autograd_cache or functorch_config.bundled_autograd_cache: + result = AOTEagerOutputCode(gm=fx_g) + result._boxed_call = True # type: ignore[attr-defined] + return result + # Wrap the forward method in a function so we can set _boxed_call attribute forward_fn = fx_g.forward diff --git a/torch/_dynamo/backends/distributed.py b/torch/_dynamo/backends/distributed.py index bbd8dc53eaa70..95db743ff28a7 100644 --- a/torch/_dynamo/backends/distributed.py +++ b/torch/_dynamo/backends/distributed.py @@ -184,9 +184,11 @@ def __init__( module: fx.GraphModule, compiler: CompilerFn, fake_mode: torch._subclasses.fake_tensor.FakeTensorMode, + **compiler_configs: Any, ) -> None: super().__init__(module) self.compiler = compiler + self.compiler_configs = compiler_configs self.fake_mode = fake_mode # See Note [DDPOptimizer and fw_metadata] ctx = torch._guards.TracingContext.try_get() @@ -239,7 +241,7 @@ def forward(self, *args: Any) -> Any: ) wrapper = WrapperModule( - self.compiler(input_mod, args), + self.compiler(input_mod, args, **self.compiler_configs), unwrap_singleton_tuple, ) return wrapper @@ -485,10 +487,13 @@ def add_param_args(self, bucket: Bucket, node: fx.Node) -> None: self.add_param(bucket, param, str(arg.target)) def compile_fn( - self, gm: fx.GraphModule, example_inputs: list[torch.Tensor] + self, + gm: fx.GraphModule, + example_inputs: list[torch.Tensor], + **compiler_configs: Any, ) -> CompiledFn: """ - Implements graph splitting, first determining a set of of buckets by counting + Implements graph splitting, first determining a set of buckets by counting parameter sizes in reverse graph order, then invoking the user/backend compiler to compile each subgraph. Finally, stitches compiled graphs into one graphmodule and returns its callable. @@ -566,7 +571,7 @@ def compile_fn( if len(buckets) == 1: # bypass split/fuse logic if there is only one bucket - return self.backend_compile_fn(gm, example_inputs) + return self.backend_compile_fn(gm, example_inputs, **compiler_configs) # 2: partition the graphmodule according to bucket capacity partition_map = {} @@ -611,7 +616,9 @@ def compile_fn( if fake_mode is None: fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() - submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode) + submod_compiler = SubmodCompiler( + split_gm, self.backend_compile_fn, fake_mode, **compiler_configs + ) with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(): submod_compiler.run(*example_inputs) split_gm.recompile() diff --git a/torch/_dynamo/backends/registry.py b/torch/_dynamo/backends/registry.py index f8b311ceb52f8..0a3ee776e96d8 100644 --- a/torch/_dynamo/backends/registry.py +++ b/torch/_dynamo/backends/registry.py @@ -79,6 +79,7 @@ def __call__(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: ... _BACKENDS: dict[str, EntryPoint | None] = {} _COMPILER_FNS: dict[str, CompilerFn] = {} +_default_backend: str | CompilerFn = "inductor" def register_backend( @@ -204,3 +205,22 @@ def _is_registered_backend(compiler_fn: CompilerFn) -> bool: return compiler_fn.compiler_fn in _COMPILER_FNS.values() return False + + +def set_default_backend(backend: str | CompilerFn | None) -> None: + """Set the default backend used by torch.compile when no backend is explicitly specified. + + Pass None to reset to the default ("inductor"). + """ + global _default_backend + if backend is None: + _default_backend = "inductor" + return + if not isinstance(backend, str) and not callable(backend): + raise TypeError(f"backend must be a string or callable, got {type(backend)}") + _default_backend = backend + + +def get_default_backend() -> str | CompilerFn: + """Return the current default backend for torch.compile.""" + return _default_backend diff --git a/torch/_dynamo/bytecode_debugger.py b/torch/_dynamo/bytecode_debugger.py index 42cf59b58de38..30b6e5afda895 100644 --- a/torch/_dynamo/bytecode_debugger.py +++ b/torch/_dynamo/bytecode_debugger.py @@ -71,7 +71,7 @@ def __repr__(self) -> str: # Import NULL_STACK_VALUE sentinel from C++ module # This is returned by _get_frame_value_stack_with_depth for NULL stack slots -from torch._C._dynamo.eval_frame import NULL_STACK_VALUE # noqa: F401 +from torch._C._dynamo.eval_frame import NULL_STACK_VALUE @dataclass diff --git a/torch/_dynamo/callback.py b/torch/_dynamo/callback.py index 25e9f260e34b3..c9466f17edf98 100644 --- a/torch/_dynamo/callback.py +++ b/torch/_dynamo/callback.py @@ -29,7 +29,7 @@ def my_end_callback(): import threading from collections.abc import Callable, Generator from contextlib import contextmanager -from dataclasses import dataclass, field # noqa: F811 +from dataclasses import dataclass, field from typing import Any diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index ce8eac2b2a527..3190eade91bff 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -223,6 +223,7 @@ def add( assert not hasattr(self, name) result = Op(name, fn, is_custom_function) if is_traceable: + # pyrefly: ignore [deprecated] setattr(self, name, torch._dynamo.allow_in_graph(result)) else: # C++ autograd function was not marked as traceable @@ -309,7 +310,7 @@ def begin_capture( self, inputs: list[torch.Tensor], sizes: list[int], - scalars: list[int | float], + scalars: list[IntLikeType | FloatLikeType], origins: list[list[tuple[int, str]]], accumulate_grad: bool, check_nans: bool, @@ -381,7 +382,8 @@ def begin_capture( (proxies[i],), {}, ) - self.symnode_proxy_lookup[symint.node] = proxies[i] + if not isinstance(symint, int): + self.symnode_proxy_lookup[symint.node] = proxies[i] proxies = self.bind_objects_to_proxies(sym_sizes, proxies, sizes_origins) for idx, val in enumerate(scalars): @@ -489,7 +491,7 @@ def call_aot_bwd_prologue( ctx_saved_tensors: Sequence[torch.Tensor], ctx_symints: Sequence[IntLikeType], ctx_opaque_objs: Sequence[Any], - *flat_args: Sequence[Any], + flat_args: Sequence[Any], ) -> Any: out = torch._functorch._aot_autograd.runtime_wrappers._backward_prologue_functional( ctx_saved_tensors, @@ -497,18 +499,19 @@ def call_aot_bwd_prologue( ctx_opaque_objs, metadata, maybe_subclass_metadata, - *flat_args, + flat_args, ) return out pgrads = self.fx_tracer.create_proxy( kind="call_function", + # pyrefly: ignore [bad-argument-type] target=call_aot_bwd_prologue, args=( psaved_tensors, psymints, popaque_objects, - *pinputs, + pinputs, ), kwargs={}, ) @@ -633,6 +636,7 @@ def make_subclass(*unwrapped_args: Any) -> Any: poutput = self.fx_tracer.create_proxy( kind="call_function", + # pyrefly: ignore [bad-argument-type] target=make_subclass, args=tuple(punwrapped_args), kwargs={}, @@ -677,6 +681,11 @@ def proxy_call_backward( opaque_object_indices, ) else: + if getattr(ctx._forward_cls, "boxed_grads_call", False): # type: ignore[attr-defined] + raise RuntimeError( + f"boxed_grads_call=True on {ctx._forward_cls.__name__} " # type: ignore[attr-defined] + "is not supported with compiled autograd. " + ) proxies = self.fx_tracer.create_proxy( kind="call_function", target=call_backward, @@ -982,7 +991,7 @@ def dce(self) -> None: # Dynamo guards will error instead of creating aliasing guards unless we unpack them in the graph unpack_nodes: OrderedSet[torch.fx.Node] = OrderedSet() i: int | None = None - for i, node in enumerate(self.fx_tracer.graph.find_nodes(op="placeholder")): # noqa: B007 + for i, node in enumerate(self.fx_tracer.graph.find_nodes(op="placeholder")): unpack_nodes.update(node.users.keys()) assert i == len(_graph_placeholders) - 1 diff --git a/torch/_dynamo/synthetic_function_graph_break.py b/torch/_dynamo/comprehension_graph_break.py similarity index 99% rename from torch/_dynamo/synthetic_function_graph_break.py rename to torch/_dynamo/comprehension_graph_break.py index beb3a7a43cebe..d0e15cf5fa134 100644 --- a/torch/_dynamo/synthetic_function_graph_break.py +++ b/torch/_dynamo/comprehension_graph_break.py @@ -6,11 +6,12 @@ import functools import logging import sys -import types from typing import Any, TYPE_CHECKING if TYPE_CHECKING: + import types + from collections.abc import Callable from .symbolic_convert import InstructionTranslatorBase @@ -629,13 +630,7 @@ def update(instructions: list[Instruction], code_options: dict[str, Any]) -> Non skip_code(new_code) # Install as global - if new_code.co_freevars: - tx.output.install_global_unsafe(fn_name, new_code) - else: - tx.output.install_global_unsafe( - fn_name, - types.FunctionType(new_code, tx.f_globals, fn_name), - ) + tx.output.install_resume_function_global(fn_name, new_code, tx.f_globals) return new_code, fn_name diff --git a/torch/_dynamo/comptime.py b/torch/_dynamo/comptime.py index 34ac2e287f305..811ed4a46e482 100644 --- a/torch/_dynamo/comptime.py +++ b/torch/_dynamo/comptime.py @@ -228,7 +228,7 @@ def parent(self) -> "ComptimeContext": def __get_tx(self, stacklevel: int) -> Any: tx = self.__tx - # pyrefly: ignore [bad-assignment] + # pyrefly: ignore [bad-assignment, non-convergent-recursion] for _ in range(stacklevel): tx = tx.parent # type: ignore[assignment] return tx diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index d38105046f50a..6ff444419f617 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -180,6 +180,26 @@ # Valid options: "dynamic", "unbacked" automatic_dynamic_shapes_mark_as: Literal["dynamic", "unbacked"] = "dynamic" +# When True, adds exclusion guards for tensor dims and scalars that transition +# from static to dynamic via automatic_dynamic_shapes. +# +# Invariant: when enabled, automatic_dynamic recompilation preserves graph +# selection — inputs that matched a previous static cache entry will continue +# to use that entry, not be intercepted by a newer dynamic entry. This holds +# as long as recompilations are caused solely by the same variable being +# observed with different static values (progressive dynamism). A recompilation +# triggered by a different reason (e.g., a guard failure unrelated to shape +# transitions) will clear the exclusion state for that entry. +# +# Mechanism: the exclusion guard rejects inputs matching the prior static +# graph's sizes, so those inputs fall through to the more specialized static +# graph instead of being captured by the newer dynamic graph. +# +# Scope: applies only to graph-input-level dimension and scalar transitions. +# Does NOT handle data-dependent branching (if x.size(0) > k), graph breaks, +# or other recompilation triggers where no dimension actually transitions. +automatic_dynamic_exclusion_guard = False + # log graph in/out metadata # This is only turned on for export today since we # know we are tracing a flat callable. later, this @@ -442,7 +462,7 @@ skip_tensor_guards_with_matching_dict_tags = True # Skips guards on func.__defaults__ if the element to be guarded is a constant -skip_guards_on_constant_func_defaults = True +skip_guards_on_constant_func_defaults = False # The recursive-dict-tag guard relies on the class/function identity staying @@ -544,11 +564,13 @@ inline_inbuilt_nn_modules = Config( # type: ignore[var-annotated] default=True, justknob="pytorch/compiler:inline_inbuilt_nn_modules", + deprecated=True, + deprecation_message="does not do anything, inline_inbuilt_nn_modules is always True", ) # Resume tracing in nested frames if a nested graph break occurs # Old behavior is to bubble up the graph break to the top level frame. -nested_graph_breaks = False +nested_graph_breaks: bool = False # If True, error if Dynamo attempts to trace more code while running compiled code in fullgraph=True. # If Dynamo determines that it should skip tracing the code (either at the C/C++ or Python level), @@ -862,7 +884,8 @@ def default_debug_dir_root() -> str: # and AOTAutograd runtime wrapper. record_runtime_overhead = True -enable_aot_compile = False +# Flag to enable the use of torch.compile().aot_compile() API. Should be always True. +enable_aot_compile = True # HACK: this is for testing custom ops profiling only _custom_ops_profile: Any | None = None @@ -887,7 +910,7 @@ def default_debug_dir_root() -> str: invalidate_compile_context_weakrefs: bool | None = None if TYPE_CHECKING: - from torch.utils._config_typing import * # noqa: F401, F403 + from torch.utils._config_typing import * # noqa: F403 def _make_closure_patcher(**changes: Any) -> Any: ... diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 637857979fd14..b379d968038d4 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -99,6 +99,7 @@ exceeds_recompile_limit, is_recompilation, ) +from .code_context import code_context from .eval_frame import ( always_optimize_code_objects, Constraint, @@ -219,6 +220,28 @@ def _clear_fake_mode_weakrefs( describer.lookup_storage.clear() +def clear_compile_context_weakrefs( + tracer_output: DynamoTracerOutput | None, + compiler_fn: CompilerFn, +) -> None: + """Clear WeakIdRef entries that can block swap_tensors after compile.""" + should_clear = config.invalidate_compile_context_weakrefs + if should_clear is None: + should_clear = _is_registered_backend(innermost_backend(compiler_fn)) + if not should_clear or not tracer_output: + return + # Use output_graph_for_cleanup which is set even on error paths + # (output_graph is None when the compilation errored). + output_graph = tracer_output.output_graph_for_cleanup + if output_graph is None: + return + tc = output_graph.tracing_context + tc.tensor_to_context.clear() + _clear_fake_mode_weakrefs(tc.fake_mode) + if hasattr(output_graph, "_old_fake_mode"): + _clear_fake_mode_weakrefs(output_graph._old_fake_mode) + + class Tracker: def __init__(self) -> None: self.seen: list[ReferenceType[CodeType]] = [] @@ -585,6 +608,7 @@ def __init__( export: bool = False, export_constraints: Any | None = None, package: CompilePackage | None = None, + recompile_limit: int | None = None, ) -> None: # assert export_constraints is None reset_graph_break_dup_checker() @@ -593,6 +617,7 @@ def __init__( self._export = export self._export_constraints = export_constraints self._package = package + self._recompile_limit = recompile_limit self._box = ConvertFrameBox() @property @@ -602,6 +627,7 @@ def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]: self._one_graph, self._export, self._export_constraints, + recompile_limit=self._recompile_limit, ) def __call__( @@ -720,7 +746,15 @@ def __call__( dynamo_tls.traced_frame_infos.append(info) try: - with compile_context(CompileContext(compile_id)): + compile_ctx = compile_context(CompileContext(compile_id)) + # When recompile_limit is set, temporarily override the global + # config so the existing exceeds_recompile_limit check uses it. + recompile_ctx = ( + config.patch(recompile_limit=self._recompile_limit) + if self._recompile_limit is not None + else contextlib.nullcontext() + ) + with compile_ctx, recompile_ctx: result = _compile( frame.f_code, frame.f_globals, @@ -759,10 +793,16 @@ def convert_frame_assert( export: bool = False, export_constraints: Any | None = None, package: CompilePackage | None = None, + recompile_limit: int | None = None, ) -> ConvertFrameAssert: """Fully convert a frame into an FX graph, raising an exception if we fail.""" return ConvertFrameAssert( - compiler_fn, one_graph, export, export_constraints, package + compiler_fn, + one_graph, + export, + export_constraints, + package, + recompile_limit, ) @@ -773,12 +813,47 @@ def convert_frame_assert( # we have to use `OrderedDict` to make `RemovableHandle` work. _bytecode_hooks: dict[int, BytecodeHook] = OrderedDict() +_BYTECODE_HOOK_SIDE_EFFECTS_CONTEXT_KEY = "bytecode_hook_side_effects" + + +def get_compiled_code_side_effects( + code: types.CodeType, +) -> tuple[str, ...] | None: + """Return Dynamo's replayed Python side-effect sources for compiled code. + + Returns ``None`` when ``code`` was not produced by Dynamo or no metadata was + attached to it. + """ + if not code_context.has_context(code): + return None + side_effects = code_context.get_context(code).get( + _BYTECODE_HOOK_SIDE_EFFECTS_CONTEXT_KEY + ) + if side_effects is None: + return None + return side_effects + + +def compiled_code_has_side_effects(code: types.CodeType) -> bool: + """Return whether Dynamo recorded replayed Python side effects for compiled code.""" + return bool(get_compiled_code_side_effects(code)) + + +def _copy_code_context(src_code: types.CodeType, dst_code: types.CodeType) -> None: + if not code_context.has_context(src_code): + return + src_context = code_context.get_context(src_code) + if _BYTECODE_HOOK_SIDE_EFFECTS_CONTEXT_KEY in src_context: + code_context.get_context(dst_code)[_BYTECODE_HOOK_SIDE_EFFECTS_CONTEXT_KEY] = ( + src_context[_BYTECODE_HOOK_SIDE_EFFECTS_CONTEXT_KEY] + ) def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle: """Register hooks for bytecode generated by Dynamo. The hook can do some - logging, as well as return a new code object to be used. Please refer - to `BytecodeHook` for the hook signature. + logging, as well as return a new code object to be used. Hooks can query + `get_compiled_code_side_effects(new_code)` for replayed Python side effect + sources. Please refer to `BytecodeHook` for the hook signature. """ handle = RemovableHandle(_bytecode_hooks) _bytecode_hooks[handle.id] = hook @@ -1004,6 +1079,20 @@ def _check_external_refs(self, f_globals: dict[str, Any]) -> None: ) +def _safe_builtins_dict(builtins_dict: dict[str, Any]) -> dict[str, Any]: + """Filter a builtins dict to only picklable entries for serialization.""" + import pickle + + result = {} + for k, v in builtins_dict.items(): + try: + pickle.dumps(v) + result[k] = v + except Exception: + pass + return result + + @dataclass class GraphCaptureOutput: """ @@ -1053,6 +1142,17 @@ def get_runtime_env(self) -> GraphRuntimeEnv: # Scan bytecode for all external references external_refs = self._get_external_refs(self.bytecode) + # Best-effort serialization of builtins referenced by the bytecode. + # Similar to how guards prune __builtins_dict__ to only used entries. + import builtins as _builtins + + for ref in external_refs: + if ref not in used_globals: + if ref.startswith("__builtins_dict__") and ref in self.f_globals: + used_globals[ref] = _safe_builtins_dict(self.f_globals[ref]) + elif hasattr(_builtins, ref): + used_globals[ref] = getattr(_builtins, ref) + return GraphRuntimeEnv( bytecode=self.bytecode, import_sources=self.import_sources, @@ -1435,7 +1535,7 @@ def transform( failed_tracer_output = getattr(e, "_torch_dynamo_tracer_output", None) if failed_tracer_output: failed_tracer_output._cleanup_output_graph() - log.debug( # noqa: G200 + log.debug( "Received signal to skip frame (without graph break): %s %s \ %s %s", e, @@ -1598,17 +1698,23 @@ def log_bytecode( payload_fn=lambda: dis.Bytecode(out_code).dis(), ) + assert tracer_output.output_graph is not None + output = tracer_output.output_graph + code_context.get_context(out_code)[_BYTECODE_HOOK_SIDE_EFFECTS_CONTEXT_KEY] = ( + tuple(output.get_replayed_side_effect_source_refs()) + ) + for idx, hook in enumerate(_bytecode_hooks.values()): with dynamo_timed(f"bytecode_hooks_{idx}", log_pt2_compile_event=True): hook_output = hook(code, out_code) if hook_output is not None: + if hook_output is not out_code: + _copy_code_context(out_code, hook_output) out_code = hook_output orig_code_map[out_code] = code output_codes.add(out_code) dynamo_time_before_restart = last_attempt_start_time - start_time - assert tracer_output.output_graph is not None - output = tracer_output.output_graph from .bytecode_debugger import BREAKPOINT_MARKER @@ -1709,7 +1815,7 @@ def count_args(code: CodeType) -> int: return wrap_guarded_code(guarded_code), tracer_output metrics_context = get_metrics_context() - code_context = ( + package_code_context = ( package.code_context(code) if package is not None else contextlib.nullcontext() ) with ( @@ -1725,7 +1831,7 @@ def count_args(code: CodeType) -> int: phase_name="entire_frame_compile", dynamo_compile_column_us="dynamo_cumulative_compile_time_us", ), - code_context, + package_code_context, ): restart_reasons: set[str] = set() if compile_pg := get_compile_pg(): @@ -1742,29 +1848,10 @@ def count_args(code: CodeType) -> int: recompile_reason = ( "Unable to find recompilation reasons" if not reasons else reasons[0] ) - # Recheck for recompilation, for when inline_inbuilt_nn_modules is set to False - inline_inbuilt_nn_modules_candidate = False - if not config.inline_inbuilt_nn_modules and frame: - inbuilt_nn_reasons = get_and_maybe_log_recompilation_reasons( - cache_entry, frame, innermost_fn(compiler_fn), skip_logging=True - ) - inbuilt_nn_recompile_reason = ( - None if not inbuilt_nn_reasons else inbuilt_nn_reasons[0] - ) - - if ( - inbuilt_nn_recompile_reason is not None - and "[inline-inbuilt-nn-modules-candidate]" - in inbuilt_nn_recompile_reason - ): - inline_inbuilt_nn_modules_candidate = True - - # Set if the recompile is a candidate for inline_inbuilt_nn_modules - # regardless of whether inline_inbuilt_nn_modules is set or not metrics_context.update_outer( { "recompile_reason": recompile_reason, - "inline_inbuilt_nn_modules_candidate": inline_inbuilt_nn_modules_candidate, + "inline_inbuilt_nn_modules_candidate": False, } ) @@ -1914,6 +2001,7 @@ def raise_unimplemented_cache_limit_exceeded() -> NoReturn: if recompile_reason and "size mismatch at index" in recompile_reason: _log_size_mismatch_recompile() + clear_compile_context_weakrefs(tracer_output, compiler_fn) return guarded_code except Exception as e: # NB: e's msg is mutated here to add user stack, but we DON'T want @@ -2059,7 +2147,7 @@ def raise_unimplemented_cache_limit_exceeded() -> NoReturn: ) # Cleanup guards unless if in export, which will return guards - # Make sure to to do this after collecting metrics + # Make sure to do this after collecting metrics if ( tracer_output is not None and tracer_output.output_graph is not None @@ -2067,24 +2155,7 @@ def raise_unimplemented_cache_limit_exceeded() -> NoReturn: ): tracer_output.output_graph.tracing_context.guards_context.dynamo_guards.clear() - # Clear WeakIdRef entries that can block swap_tensors after compile. - # Determine whether to clear based on config and backend type. - should_clear = config.invalidate_compile_context_weakrefs - if should_clear is None: - # Default: clear for registered backends, don't clear for custom - # Unwrap the compiler_fn to get the actual backend function - should_clear = _is_registered_backend(innermost_backend(compiler_fn)) - if should_clear: - if tracer_output and tracer_output.output_graph: - tc = tracer_output.output_graph.tracing_context - tc.tensor_to_context.clear() - # Clear both the current fake_mode and the old_fake_mode - # (the original is stored before backend_fake_mode replaces it) - _clear_fake_mode_weakrefs(tc.fake_mode) - if hasattr(tracer_output.output_graph, "_old_fake_mode"): - _clear_fake_mode_weakrefs( - tracer_output.output_graph._old_fake_mode - ) + clear_compile_context_weakrefs(tracer_output, compiler_fn) class ConvertFrame: @@ -2093,18 +2164,24 @@ def __init__( compiler_fn: CompilerFn, hooks: Hooks, package: CompilePackage | None = None, + recompile_limit: int | None = None, ) -> None: self._torchdynamo_orig_backend = compiler_fn self._inner_convert = convert_frame_assert( - compiler_fn, one_graph=False, package=package + compiler_fn, + one_graph=False, + package=package, + recompile_limit=recompile_limit, ) self._hooks = hooks + self._recompile_limit = recompile_limit @property def _clone_with_backend(self) -> Callable[[WrapBackendDebug], ConvertFrame]: return lambda backend: convert_frame( backend, self._hooks, + recompile_limit=self._recompile_limit, ) def __call__( @@ -2229,9 +2306,12 @@ def convert_frame( compiler_fn: CompilerFn, hooks: Hooks, package: CompilePackage | None = None, + recompile_limit: int | None = None, ) -> ConvertFrame: """Try to convert a frame into an FX graph, if error leave frame unmodified""" - return ConvertFrame(compiler_fn, hooks, package=package) + return ConvertFrame( + compiler_fn, hooks, package=package, recompile_limit=recompile_limit + ) # TODO mlazos: add support for same args, or record them diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py index 34cdc7eb2c5fb..7b8496a97ba53 100644 --- a/torch/_dynamo/debug_utils.py +++ b/torch/_dynamo/debug_utils.py @@ -588,6 +588,9 @@ def const(self, name: str) -> None: def unsupported(self, name: str) -> None: pass + def generator(self, device_type: str, device_index: int) -> None: + pass + def opaque(self, script_class_name: str) -> None: self.total += 1 @@ -680,6 +683,12 @@ def const(self, name: str) -> None: def unsupported(self, name: str) -> None: self.args.append(None) + def generator(self, device_type: str, device_index: int) -> Any: + device = torch.device(device_type, device_index) + gen = torch.Generator(device=device) + self.args.append(gen) + return gen + def opaque(self, script_class_name: str) -> None: self.args.append(None) @@ -821,6 +830,12 @@ def symint(self, name: str, val: Any) -> None: val = val.node.hint self._lines.append(f"reader.symint({val!r}) # {name}") + def generator(self, name: str, arg: torch._C.Generator) -> None: + device = arg.device + self._lines.append( + f"reader.generator({device.type!r}, {device.index!r}) # {name}" + ) + def opaque(self, name: str, script_class_name: str) -> None: self._lines.append(f"reader.opaque({script_class_name!r}) # {name}") diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index e5a8740ef7882..46cc9584d6666 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from types import TracebackType from typing import Any, overload, TYPE_CHECKING, TypeVar -from typing_extensions import deprecated, ParamSpec +from typing_extensions import ParamSpec import torch import torch.utils._pytree as pytree @@ -178,11 +178,6 @@ def assume_constant_result(fn): # type: ignore[no-untyped-def] return fn -@deprecated( - "torch._dynamo.allow_in_graph is deprecated and will be removed in a future version. " - "Use torch._dynamo.nonstrict_trace instead.", - category=FutureWarning, -) def allow_in_graph(fn): # type: ignore[no-untyped-def] """ Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function @@ -320,6 +315,8 @@ def _invoke_leaf_function_python( args: tuple[Any, ...], kwargs: dict[str, Any], mutates_args: frozenset[str] | None = None, + hook_fn: Callable[..., Any] | None = None, + hook_fake_fn: Callable[..., Any] | None = None, ) -> Any: """Call invoke_leaf_function HOP directly from Python. @@ -361,6 +358,10 @@ def _invoke_leaf_function_python( real_fn_callable = _LeafCallable(wrapped_real) fake_fn_callable = _LeafCallable(wrapped_fake) + if hook_fn is not None: + real_fn_callable._leaf_hook_real_fn = hook_fn # type: ignore[attr-defined] + real_fn_callable._leaf_hook_fake_fn = hook_fake_fn # type: ignore[attr-defined] + mutated_flat_indices = "" if mutates_args: from torch._higher_order_ops.invoke_leaf_function import ( @@ -513,6 +514,37 @@ def count_calls_fake(x): To validate that your fake implementation matches the real function's outputs, set ``torch._dynamo.config.leaf_function_validate_outputs = True``. + **register_multi_grad_hook (optional)**: + You can register a backward hook via ``@fn.register_multi_grad_hook`` + to run code when gradients have been computed + for all requires_grad tensor inputs during backward. The hook fires exactly once + per backward pass. The hook function has the same signature as the leaf function; + each requires_grad tensor argument receives the corresponding gradient instead + of the original tensor. Non-tensor arguments and tensors without requires_grad + are passed through unchanged. The hook must return ``None``. The hook is called + as a leaf function itself, so it is also opaque to the compiler. + + Example:: + + >>> @leaf_function + ... def debug_log(t, tag): + ... print(f"[{tag}][fwd] norm={t.norm().item()}") + ... return None + ... + >>> @debug_log.register_fake + ... def debug_log_fake(t, tag): + ... return None + ... + >>> @debug_log.register_multi_grad_hook + ... def debug_log_hook(t_grad, tag): + ... print(f"[{tag}][bwd] norm={t_grad.norm().item()}") + ... + >>> x = torch.randn(4, requires_grad=True) + >>> debug_log(x, "intermediate") # no assignment needed + [intermediate][fwd] norm=... + >>> (x * 2).sum().backward() + [intermediate][bwd] norm=... + Limitations: Currently, inductor backend and :func:`torch.export.export` are not yet supported. @@ -714,6 +746,8 @@ def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: args, kwargs, mutates_args=inner._torchdynamo_leaf_mutates_args, # pyrefly: ignore [missing-attribute] + hook_fn=inner._torchdynamo_leaf_hook_fn, # type: ignore[attr-defined] + hook_fake_fn=inner._torchdynamo_leaf_hook_fake_fn, # type: ignore[attr-defined] ) # type: ignore[attr-defined] inner._torchdynamo_leaf_real_fn = fn # type: ignore[attr-defined] @@ -721,6 +755,8 @@ def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: inner._torchdynamo_leaf_mutates_args = ( # pyrefly: ignore [missing-attribute] frozenset(mutates_args) if mutates_args else frozenset() ) # type: ignore[attr-defined] + inner._torchdynamo_leaf_hook_fn = None # type: ignore[attr-defined] + inner._torchdynamo_leaf_hook_fake_fn = None # type: ignore[attr-defined] # Follow nonstrict_trace implementation wrapped_id = id(inner) @@ -739,6 +775,13 @@ def register_fake_setter(fake_fn: Callable[..., Any]) -> Callable[..., Any]: inner.register_fake = register_fake_setter # type: ignore[attr-defined] + def register_hook_setter(hook_fn: Callable[..., Any]) -> Callable[..., Any]: + inner._torchdynamo_leaf_hook_fn = hook_fn # type: ignore[attr-defined] + inner._torchdynamo_leaf_hook_fake_fn = lambda *args, **kwargs: None # type: ignore[attr-defined] + return inner + + inner.register_multi_grad_hook = register_hook_setter # type: ignore[attr-defined] + return inner @@ -1644,6 +1687,80 @@ def override_cudagraphs( return CudagraphOverrideContextManager(fwd=fwd, bwd=bwd) +def override_optimization_hint(x: Any, val: int) -> None: + """Override the optimization hint for a scalar unbacked symbol. + + When the compiler or runtime needs a non-guarding integer hint for an + unbacked ``SymInt`` — for example during FX passes, graph partitioning, + or inductor autotuning — it calls + ``_optimization_hint_base`` + (see ``torch/fx/experimental/_size_hinting.py``). By default + that function uses internal heuristics to choose a hint and a global fixed + fallback; + this function lets user code override that choice. This is similar to + the ``hint_override`` parameter in ``mark_unbacked``, but applies to + symbols that already exist (e.g. from ``.item()`` calls). + + Typical usage:: + + u = x.item() # unbacked SymInt + torch._dynamo.override_optimization_hint(u, 42) + # From now on, any call to shape_env.optimization_hint(u, ...) + # returns 42 instead of the default heuristic value. + + This updates ``shape_env.var_to_hint_override`` so that any consumer + of ``_optimization_hint_base`` sees *val* as the hint for the + unbacked symbol behind *x*. + + Works both eagerly (during FX passes or outside dynamo) and inside + ``torch.compile`` regions. + + Behavior during compilation: + The dynamo handler applies the hint as a **side effect** on + ``shape_env.var_to_hint_override`` during tracing. No FX graph + node is emitted — the call is fully consumed at trace time and + does not appear in the pre-grad, joint, or post-autograd graphs. + ``FXGraphCache`` includes ``var_to_hint_override`` in its cache + key, so cache hits/misses correctly reflect hint changes. + + .. note:: + + To maximize performance, it is recommended to pass hints for + **all** unbacked symbols in the program to guide optimizations. + + Args: + x: A ``torch.SymInt`` wrapping an unbacked symbol (e.g. from + ``.item()``), or a plain ``int``. If *x* is a plain ``int`` + the call is a no-op. + val: The integer hint value to record. + """ + if not isinstance(val, int): + raise TypeError( + f"override_optimization_hint expects val to be an int, got {type(val)}" + ) + if isinstance(x, int): + return + if not isinstance(x, torch.SymInt): + raise TypeError( + f"override_optimization_hint expects a torch.SymInt or int, got {type(x)}" + ) + shape_env = x.node.shape_env + expr = x.node.expr + import sympy + + if not isinstance(expr, sympy.Symbol): + raise ValueError( + f"override_optimization_hint expects a single unbacked symbol, " + f"got derived expression: {expr}" + ) + if not shape_env.is_unbacked_symint(expr): + raise ValueError( + f"override_optimization_hint expects an unbacked symbol, " + f"but {expr} is backed" + ) + shape_env.var_to_hint_override[expr] = val + + def is_dynamo_disable_recursive(method: Callable[[Any], Any]) -> bool | None: """ Check if a method is marked as `dynamo_disable` recursively. It returns: diff --git a/torch/_dynamo/dynamo_profiler.py b/torch/_dynamo/dynamo_profiler.py index aa587c3f6cb9d..7c645b7f17071 100644 --- a/torch/_dynamo/dynamo_profiler.py +++ b/torch/_dynamo/dynamo_profiler.py @@ -380,9 +380,7 @@ def generate_svg( _, gprof2dot_err = gprof2dot.communicate() if gprof2dot.returncode != 0: - print( - f"gprof2dot failed: {gprof2dot_err.decode()}" # noqa: B950 - ) + print(f"gprof2dot failed: {gprof2dot_err.decode()}") return None if dot.returncode != 0: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 1f9aacacd4afc..7e5ac5477e90a 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -56,11 +56,11 @@ # see discussion at https://github.com/pytorch/pytorch/issues/120699 from torch._C._dynamo.eval_frame import ( # noqa: F401 - _EvalFrameOverride, reset_code, set_code_exec_strategy, set_eval_frame, - set_eval_frame_override, + set_fullgraph_compiled_frame_count, + set_fullgraph_error_on_nested_compile, set_guard_complete_hook, set_guard_error_hook, set_skip_guard_eval_unsafe, @@ -278,7 +278,12 @@ def fail_callback( cache_entries = _debug_get_cache_entry_list(frame.f_code) if cache_entries: reasons = get_and_maybe_log_recompilation_reasons( - cache_entries[0], frame, innermost_fn(callback), skip_logging=True + # pyrefly: ignore [bad-argument-type] + cache_entries[0], + frame, + # pyrefly: ignore [bad-argument-type] + innermost_fn(callback), + skip_logging=True, ) if reasons: failures = textwrap.indent("\n".join(reasons), "- ") @@ -721,12 +726,6 @@ def guard_collectives_hook(guard_eval_result: bool) -> bool: _not_set = object() -def _get_eval_frame_override() -> _EvalFrameOverride: - if torch._dynamo.config.error_on_dynamo_callback_in_fullgraph_compiled_code: - return _EvalFrameOverride.ERROR - return _EvalFrameOverride.SKIP - - class _TorchDynamoContext: def __init__( self, @@ -921,16 +920,19 @@ def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any: f"A callable function is expected, but {type(fn)} is provided." ) - # NOTE [Top-level TorchInGraph functions] + # NOTE [Top-level TorchInGraph and polyfilled functions] # Some callables (e.g. torch.exp) are represented as TorchInGraphFunctionVariable # when traced inside a frame. When such a function is passed directly to # torch.compile, we detect it here so we can force it through wrap_inline. + # Similarly, functions registered via substitute_in_graph have a polyfill + # that Dynamo can trace, so they also need wrap_inline. from .variables import TorchInGraphFunctionVariable rule = trace_rules.lookup(fn) top_level_in_graph = isinstance(rule, type) and issubclass( rule, TorchInGraphFunctionVariable ) + has_polyfill = trace_rules.is_polyfilled_callable(fn) try: filename = inspect.getsourcefile(fn) @@ -939,7 +941,12 @@ def aot_compile(example_inputs: tuple[tuple[Any, ...], dict[str, Any]]) -> Any: if config.debug_force_nested_calls: fn = external_utils.wrap_inline(fn) elif config.wrap_top_frame or ( - (filename is None or trace_rules.check(fn) or top_level_in_graph) + ( + filename is None + or trace_rules.check(fn) + or top_level_in_graph + or has_polyfill + ) and ( getattr(fn, "__name__", "") not in ["_call_impl", "_wrapped_call_impl", "_lazy_forward"] @@ -968,11 +975,14 @@ def compile_wrapper(*args: Any, **kwargs: Any) -> Any: # Unlike in eval_frame_cpp.cpp/convert_frame.py, we don't attempt to restore global state # due to additional overhead costs. prior = set_eval_frame(None) - prior_eval_frame_override: _EvalFrameOverride | None = None + prior_error_on_nested_compile: bool | None = None + fullgraph_count_enabled = False if self.fullgraph: - prior_eval_frame_override = set_eval_frame_override( - _get_eval_frame_override() + prior_error_on_nested_compile = set_fullgraph_error_on_nested_compile( + torch._dynamo.config.error_on_dynamo_callback_in_fullgraph_compiled_code ) + if not self.export: + fullgraph_count_enabled = set_fullgraph_compiled_frame_count(0) < 0 try: # We shouldn't compile inside kernel invocation. if tracing_context := torch._guards.TracingContext.try_get(): @@ -1032,8 +1042,10 @@ def compile_wrapper(*args: Any, **kwargs: Any) -> Any: _maybe_set_eval_frame(_callback_from_stance(callback)) + call_succeeded = False try: - return fn(*args, **kwargs) + result = fn(*args, **kwargs) + call_succeeded = True except (Unsupported, UncapturedHigherOrderOpError, UserError) as e: if config.verbose: raise @@ -1051,10 +1063,20 @@ def compile_wrapper(*args: Any, **kwargs: Any) -> Any: finally: # Restore the dynamic layer stack depth if necessary. set_eval_frame(None) + if fullgraph_count_enabled and call_succeeded: + count = set_fullgraph_compiled_frame_count(-1) + if count == 0: + raise RuntimeError( + "torch.compile with fullgraph=True found no compiled frames. " + "The frame was likely skipped (e.g., a non-infra torch dispatch " + "mode was active, dynamo was disabled, or the frame was skipped." + ) if prior_error_on_graph_break is not None: _set_error_on_graph_break(prior_error_on_graph_break) - if prior_eval_frame_override is not None: - set_eval_frame_override(prior_eval_frame_override) + if prior_error_on_nested_compile is not None: + set_fullgraph_error_on_nested_compile( + prior_error_on_nested_compile + ) torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth( saved_dynamic_layer_stack_depth ) @@ -1062,7 +1084,10 @@ def compile_wrapper(*args: Any, **kwargs: Any) -> Any: set_skip_guard_eval_unsafe(prior_skip_guard_eval_unsafe) for cleanup in cleanups: cleanup() + return result finally: + if fullgraph_count_enabled: + set_fullgraph_compiled_frame_count(-1) _maybe_set_eval_frame(prior) # hooks to properly handle inlining @@ -1510,6 +1535,7 @@ def _optimize( disable: bool = False, dynamic: bool | None = None, package: CompilePackage | None = None, + recompile_limit: int | None = None, ) -> OptimizeContext | _NullDecorator: """ The main entrypoint of TorchDynamo. Do graph capture and call @@ -1568,6 +1594,7 @@ def toy_example(a, b): ... hooks=hooks, rebuild_ctx=rebuild_ctx, package=package, + recompile_limit=recompile_limit, ) backend = get_compiler_fn(backend) @@ -1591,6 +1618,7 @@ def toy_example(a, b): ... backend, hooks, package=package, + recompile_limit=recompile_limit, ), hooks, backend_ctx_ctor, @@ -2443,6 +2471,7 @@ def _optimize_assert( export_constraints: Any | None = None, dynamic: bool | None = None, package: CompilePackage | None = None, + recompile_limit: int | None = None, ) -> OptimizeContext: """ Guarantees single-graph capture. @@ -2473,6 +2502,7 @@ def _optimize_assert( export=export, export_constraints=export_constraints, package=package, + recompile_limit=recompile_limit, ), hooks, backend_ctx_ctor, diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index b3bdf014581ce..e2da1f37cb1d4 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -112,6 +112,14 @@ class AutogradGradRestartAnalysis(RestartAnalysis): """ +class RequiresGradRestartAnalysis(RestartAnalysis): + """Raised when a source-less requires_grad_() intermediate leaks as output. + + On restart, requires_grad_() will graph break instead of being traced, + preserving partial acceleration for code before the call. + """ + + class UnspecializeRestartAnalysis(RestartAnalysis): pass @@ -419,16 +427,24 @@ def raise_observed_exception( exc_type: type[Exception], tx: InstructionTranslatorBase, *, - args: list[VariableTracker] | None = None, + args: list[VariableTracker] | list[str] | None = None, kwargs: dict[str, VariableTracker] | None = None, ) -> NoReturn: from .symbolic_convert import ExceptionVals from .variables.builder import SourcelessBuilder + if args: + args_ = [ + SourcelessBuilder.create(tx, arg) if isinstance(arg, str) else arg + for arg in args + ] + else: + args_: list[VariableTracker] = [] + # CPython here raises an exception. Since there is no python code, we have to manually setup the exception # stack and raise the exception. exception_vt = SourcelessBuilder.create(tx, exc_type).call_function( - tx, args or [], kwargs or {} + tx, args_, kwargs or {} ) assert isinstance(exception_vt, ExceptionVals) tx._attach_traceback_to_exception(exception_vt) @@ -436,10 +452,15 @@ def raise_observed_exception( raised_exc = get_dynamo_observed_exception(exc_type) # Store the original exception arguments for better error messages if args: - raise raised_exc(*args) + raise raised_exc(*args_) raise raised_exc +def raise_type_error(tx: InstructionTranslatorBase, msg: str) -> NoReturn: + """Raise a TypeError as an observed exception during tracing.""" + raise_observed_exception(TypeError, tx, args=[msg]) + + def handle_observed_exception(tx: Any) -> None: # This is essentially exception handling code, equivalent of this pseudo code # @@ -775,13 +796,13 @@ def filter_stack(stack: StackSummary) -> StackSummary: return user_stack -def remove_resume_prefix(name: str) -> str | None: +def remove_resume_prefix(name: str) -> str: from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX match = re.match(f"{TORCH_DYNAMO_RESUME_IN_PREFIX}_(\\w+)_at_\\d+", name) if match: return match.group(1) - return None + return name def collapse_resume_frames(stack: StackSummary | list[FrameSummary]) -> StackSummary: @@ -813,6 +834,7 @@ def collapse_resume_frames(stack: StackSummary | list[FrameSummary]) -> StackSum new_stack[-1] = frame frame.name = name else: + frame.name = name new_stack.append(frame) return new_stack diff --git a/torch/_dynamo/functional_export.py b/torch/_dynamo/functional_export.py index 31f06d09dff01..41de630917a5f 100644 --- a/torch/_dynamo/functional_export.py +++ b/torch/_dynamo/functional_export.py @@ -605,13 +605,11 @@ def create_fx_graph_from_captured_output( graph_module = backend_input.graph_module if isinstance(root, torch.nn.Module): - graph_module._parameters = root._parameters.copy() - graph_module._buffers = root._buffers.copy() + graph_module._parameters = root._parameters + graph_module._buffers = root._buffers assert all(not hasattr(graph_module, m) for m in root._modules) graph_module._modules.update(root._modules) - graph_module._non_persistent_buffers_set = ( - root._non_persistent_buffers_set.copy() - ) + graph_module._non_persistent_buffers_set = root._non_persistent_buffers_set if sys.version_info >= (3, 14): import annotationlib # added in 3.14 @@ -757,10 +755,16 @@ def gen_var_bindings( {", ".join(without_annotation)}, = self._dynamo_bytecode_flatten(*_fn_args)""" def generate_output( - self, output_args: torch.fx.node.Argument, *, descs: object | None = None + self, + output_args: torch.fx.node.Argument, + *, + descs: object | None = None, + repr_fn: Any | None = None, ) -> str: + if repr_fn is None: + repr_fn = repr # pyrefly: ignore [not-iterable] - returned = f"self._dynamo_bytecode_unflatten(({', '.join([str(a) for a in output_args])},), _fn_args)" + returned = f"self._dynamo_bytecode_unflatten(({', '.join([repr_fn(a) for a in output_args])},), _fn_args)" if self.wrap_tuple: returned = f"({returned},)" return f"return {returned}" @@ -968,9 +972,7 @@ def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule: ) transformed_graph.recompile() - clean_nn_module_stack_and_source_fn( - transformed_graph, torch._dynamo.config.inline_inbuilt_nn_modules - ) + clean_nn_module_stack_and_source_fn(transformed_graph, True) clean_export_root(transformed_graph) transformed_graph.meta["module_call_specs"] = module_call_spec diff --git a/torch/_dynamo/graph_break_registry.json b/torch/_dynamo/graph_break_registry.json index 59e4b032a0eeb..a1535309bdca1 100644 --- a/torch/_dynamo/graph_break_registry.json +++ b/torch/_dynamo/graph_break_registry.json @@ -3,7 +3,7 @@ { "Gb_type": "All __torch_function__ overrides returned NotImplemented due to TypeError from user code", "Context": "fn={fn}, args={args}, kwargs={kwargs}", - "Explanation": "All __torch_function__ overrides for for function {fn} returned NotImplemented", + "Explanation": "All __torch_function__ overrides for function {fn} returned NotImplemented", "Hints": [ "Your code may result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled. You can do this by removing the `torch.compile` call, or by using `torch.compiler.set_stance(\"force_eager\")`. " ] @@ -396,6 +396,16 @@ ] } ], + "GB6999": [ + { + "Gb_type": "unsupported variable type for __dict__ access", + "Context": "VariableTracker type: {type(vt)}", + "Explanation": "Dynamo does not know how to get __dict__ from {type(vt)}", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], "GB0037": [ { "Gb_type": "Dynamic shape operator (no meta kernel)", @@ -567,6 +577,16 @@ ] } ], + "GB4880": [ + { + "Gb_type": "missing_mp_subscript", + "Context": "mp_subscript_impl not defined for {type(self).__name__}", + "Explanation": "Dynamo does not yet support subscripting '{self.python_type_name()}'.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], "GB9741": [ { "Gb_type": "skip frame due to being in functorh mode", @@ -640,6 +660,16 @@ "Hints": [] } ], + "GB5195": [ + { + "Gb_type": "elementwise_dtypes unsupported arg type", + "Context": "str(arg)", + "Explanation": "elementwise_dtypes requires tensor or constant arguments, got {type(arg).__name__}", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], "GB0054": [ { "Gb_type": "Failed to construct Enum variable", @@ -679,6 +709,18 @@ ] } ], + "GB7299": [ + { + "Gb_type": "_autograd_grad with lost grad_fn linkage", + "Context": "outputs lost autograd linkage during tracing", + "Explanation": "_autograd_grad() received tensors whose grad_fn was lost during tracing - this silently produces zero gradients.", + "Hints": [ + "Compile the full transform instead of the returned ", + "closure: torch.compile(lambda x: torch.func.vjp(f, x))", + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], "GB0058": [ { "Gb_type": "Failed to set tensor attribute", @@ -704,6 +746,16 @@ ] } ], + "GB6270": [ + { + "Gb_type": "Unhandled tensor method", + "Context": "call_method {self} {name} {args} {kwargs}", + "Explanation": "Tensor method `{name}` is not defined on {check_type.__name__} and does not have an explicit handler in TensorVariable.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], "GB0060": [ { "Gb_type": "Failed to trace unittest method", @@ -718,7 +770,7 @@ "GB8656": [ { "Gb_type": "Opaque object with custom __getattr__ not supported", - "Context": "{value_type.__name__} with custom __getattr__", + "Context": "{real_obj_type.__name__} with custom __getattr__", "Explanation": "Dynamo does not support opaque objects types with custom __getattr__ methods", "Hints": [] } @@ -865,17 +917,6 @@ "Hints": [] } ], - "GB0719": [ - { - "Gb_type": "Reconstruction of FakeIdVariable", - "Context": "str(self.value)", - "Explanation": "A fake id produced by id() on a compile-time container cannot be reconstructed across a graph break.", - "Hints": [ - "Avoid using id() on containers in code that may graph-break.", - "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." - ] - } - ], "GB0072": [ { "Gb_type": "Inplace op on input tensor", @@ -887,6 +928,22 @@ ] } ], + "GB7230": [ + { + "Gb_type": "iter() with no arguments", + "Context": "iter()", + "Explanation": "iter() requires at least one argument", + "Hints": [ + "Your code may result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled. You can do this by removing the `torch.compile` call, or by using `torch.compiler.set_stance(\"force_eager\")`. " + ] + }, + { + "Gb_type": "iter() with no arguments", + "Context": "iter()", + "Explanation": "iter() requires at least one argument", + "Hints": [] + } + ], "GB0073": [ { "Gb_type": "Invoking an nn.Module inside a HigherOrderOperator", @@ -903,6 +960,16 @@ "Hints": [] } ], + "GB6012": [ + { + "Gb_type": "Cannot trace user-defined __len__", + "Context": "{self.python_type_name()}.__len__()", + "Explanation": "Dynamo cannot trace len() on {self.python_type_name()} because the __len__ method is either not traceable (e.g., defined in C or built-in) or returns a non-constant value.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], "GB0075": [ { "Gb_type": "LOAD_BUILD_CLASS bytecode not supported", @@ -945,6 +1012,16 @@ ] } ], + "GB3775": [ + { + "Gb_type": "nb_int_impl not implemented", + "Context": "{type(self).__name__} has nb_int slot but no nb_int_impl override", + "Explanation": "The type {self.python_type_name()} has an nb_int C slot but the corresponding VariableTracker doesn't implement nb_int_impl.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], "GB0079": [ { "Gb_type": "Missing CALL_INTRINSIC_1 handler", @@ -1023,17 +1100,6 @@ ] } ], - "GB7581": [ - { - "Gb_type": "copy.deepcopy(tensor)", - "Context": "copy.deepcopy({self})", - "Explanation": "Dynamo does not support copy.deepcopy() on tensors.", - "Hints": [ - "Avoid calling copy.deepcopy() on tensors inside compiled regions.", - "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." - ] - } - ], "GB0085": [ { "Gb_type": "Non-function or method in subclass of torch.autograd.Function", @@ -1062,6 +1128,14 @@ } ], "GB0088": [ + { + "Gb_type": "Observed exception", + "Context": "raised exception {curr_exc.debug_repr()}", + "Explanation": "observed_exn_gb_explanation", + "Hints": [ + "Your code may result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled. You can do this by removing the `torch.compile` call, or by using `torch.compiler.set_stance(\"force_eager\")`. " + ] + }, { "Gb_type": "Observed exception", "Context": "raised exception {curr_exc.python_type_name()}({curr_exc.args})", @@ -1409,6 +1483,16 @@ "Hints": [] } ], + "GB7650": [ + { + "Gb_type": "torch.Generator method", + "Context": "torch.Generator.{name}", + "Explanation": "torch.Generator.{name}() is a stateful RNG operation that cannot be soundly traced in the FX graph.", + "Hints": [ + "This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround." + ] + } + ], "GB0118": [ { "Gb_type": "Unimplemented next() call", @@ -1684,6 +1768,22 @@ ] } ], + "GB1380": [ + { + "Gb_type": "Illegal __getitem__ invocation in strict mode", + "Context": "mp_subscript_impl {self} {key}", + "Explanation": "Dynamo currently does not support __getitem__ invocation in strict mode.", + "Hints": [] + } + ], + "GB4240": [ + { + "Gb_type": "unsupported __getitem__", + "Context": "mp_subscript_impl {self} {key}", + "Explanation": "Dynamo does not know how to handle __getitem__ on {self}", + "Hints": [] + } + ], "GB0141": [ { "Gb_type": "Unsupported call_id() without source", @@ -1842,6 +1942,12 @@ } ], "GB0153": [ + { + "Gb_type": "Unsupported key type for nn.Module.__getitem__", + "Context": "mp_subscript_impl: {self} {key}", + "Explanation": "Dynamo does not support getitem on `nn.Module` with non-constant key.", + "Hints": [] + }, { "Gb_type": "Unsupported key type for nn.Module.__getitem__", "Context": "call_method: {self} {name} {args} {kwargs}", @@ -1954,6 +2060,16 @@ ] } ], + "GB4841": [ + { + "Gb_type": "Unsupported python_type() call", + "Context": "{obj} does not implement python_type()", + "Explanation": "This VariableTracker does not implement python_type(), which is required for object protocol operations.", + "Hints": [ + "This is likely to be a Dynamo bug. Please report an issue to PyTorch." + ] + } + ], "GB8335": [ { "Gb_type": "Generator reconstruction with mutations", @@ -2085,6 +2201,12 @@ } ], "GB0170": [ + { + "Gb_type": "Data-dependent branching", + "Context": "attempted to jump with {value}", + "Explanation": "\"Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). \" \"Dynamo does not support tracing dynamic control flow.\" + trace_info", + "Hints": [] + }, { "Gb_type": "Data-dependent branching", "Context": "attempted to jump with {value}", @@ -2456,6 +2578,15 @@ } ], "GB8843": [ + { + "Gb_type": "Opaque object member with method-type USE_REAL returned a reference-type opaque object.", + "Context": "Opaque object type: {real_obj_type}. Method name: '{name}'", + "Explanation": "To properly guard reference-type opaque objects, we must lift them as inputs to the graph. In order to do this, they must all have a source, meaning they come from a global value or are an attribute of an input.", + "Hints": [ + "Register member '{name}' with MemberType.INLINED in ", + "register_opaque_type({real_obj_type}, members=...)." + ] + }, { "Gb_type": "Opaque object member with method-type USE_REAL returned a reference-type opaque object.", "Context": "Opaque object type: {value_type}. Method name: '{name}'", @@ -2876,6 +3007,16 @@ ] } ], + "GB6025": [ + { + "Gb_type": "Failed to trace list()", + "Context": "list({arg_types})", + "Explanation": "Dynamo does not know how to construct a list from argument types {arg_types}", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], "GB0236": [ { "Gb_type": "Invalid input type for nonstrict_trace-ed function", @@ -3192,6 +3333,16 @@ } ], "GB4927": [ + { + "Gb_type": "autograd.grad consumed returned tensor's grad_fn", + "Context": "Leaked output tensors: {leaked_str}", + "Explanation": "torch.autograd.grad() consumes grad_fns that are needed by tensors returned from this compiled function. This would cause 'backward through graph a second time' errors.\n The following returned tensors have consumed grad_fns: {leaked_str}", + "Hints": [ + "Detach the problematic tensor(s) before returning: e.g. `{leaked[0]}.detach()`", + "Call .detach() on the tensor before returning.", + "If you need to backward through the returned tensor, use retain_graph=True in autograd.grad()." + ] + }, { "Gb_type": "autograd.grad consumed returned tensor's grad_fn", "Context": "", @@ -3448,6 +3599,20 @@ ] } ], + "GB7324": [ + { + "Gb_type": "returning intermediate with requires_grad_()", + "Context": "graph output depends on source-less requires_grad_()", + "Explanation": "msg", + "Hints": [ + "If you only need the tensor values without gradients, ", + "call .detach() before returning.", + "Consume the gradient inside the compiled function ", + "(call backward() and use .grad), ", + "or move requires_grad_() outside torch.compile." + ] + } + ], "GB0280": [ { "Gb_type": "1-arg super not implemented", @@ -3458,7 +3623,23 @@ ] } ], + "GB3383": [ + { + "Gb_type": "Pydantic dataclass constructor", + "Context": "{self.value}", + "Explanation": "Dynamo graph breaks on pydantic dataclass constructors because validation mutates the instance outside traced bytecode.", + "Hints": [] + } + ], "GB0281": [ + { + "Gb_type": "Invalid or non-const argument in nn.Module __getitem__", + "Context": "mp_subscript_impl: {self} {key}", + "Explanation": "Dynamo does not support calling method `__getitem__` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.", + "Hints": [ + "Use constant arguments of type str or int for __getitem__" + ] + }, { "Gb_type": "Invalid or non-const argument in nn.Module __getitem__", "Context": "call_method: {self} {name} {args} {kwargs}", @@ -3588,6 +3769,16 @@ ] } ], + "GB7906": [ + { + "Gb_type": "Missing len_impl", + "Context": "len({type(self).__name__})", + "Explanation": "Dynamo does not support len() on {type(self).__name__}. Add len_impl to this VariableTracker subclass.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], "GB0291": [ { "Gb_type": "logging.Logger method not supported for non-export cases", @@ -3850,6 +4041,14 @@ ] } ], + "GB5458": [ + { + "Gb_type": "Custom metaclass with __getattribute__", + "Context": "type({self.value}) = {metacls}", + "Explanation": "Dynamo does not trace attribute access on classes whose metaclass overrides __getattribute__", + "Hints": [] + } + ], "GB0310": [ { "Gb_type": "torch.cond: improper predicate", @@ -4063,6 +4262,16 @@ ] } ], + "GB1140": [ + { + "Gb_type": "_autograd_grad with unsupported argument type", + "Context": "got {type(var).__name__}", + "Explanation": "_autograd_grad() received an argument of type {type(var).__name__} which is not supported. Expected tensor or sequence of tensors.", + "Hints": [ + "Ensure outputs and inputs arguments are tensors or sequences of tensors." + ] + } + ], "GB0325": [ { "Gb_type": "torch.map: kwargs not supported", @@ -4083,6 +4292,16 @@ ] } ], + "GB8435": [ + { + "Gb_type": "missing_mp_subscript_impl", + "Context": "mp_subscript_impl not defined for {type(self).__name__}", + "Explanation": "'{self.python_type_name()}' subscript is not yet supported by Dynamo.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], "GB0327": [ { "Gb_type": "executorch_call_delegate: kwargs not supported", @@ -4247,6 +4466,26 @@ ] } ], + "GB4501": [ + { + "Gb_type": "Reconstruction of FakeIdVariable", + "Context": "str(self.value)", + "Explanation": "A fake id produced by id() on a compile-time container cannot be reconstructed across a graph break.", + "Hints": [ + "Avoid using id() on containers in code that may graph-break." + ] + } + ], + "GB4198": [ + { + "Gb_type": "Attempted to copy.deepcopy a tensor", + "Context": "copy.deepcopy({self})", + "Explanation": "Dynamo does not support copy.deepcopy() on tensors.", + "Hints": [ + "Avoid calling copy.deepcopy() on tensors inside compiled regions." + ] + } + ], "GB0344": [ { "Gb_type": "wrap_with_autocast: expected constant arg", @@ -4258,6 +4497,14 @@ } ], "GB0345": [ + { + "Gb_type": "strict_mode: improper args", + "Context": "args: {args}, kwargs: {kwargs}", + "Explanation": "strict_mode higher order op expects flat inputs (list/tuple/dict/set)", + "Hints": [ + "Your code may result in an error when running in eager. Please double check that your code doesn't contain a similar error when actually running eager/uncompiled. You can do this by removing the `torch.compile` call, or by using `torch.compiler.set_stance(\"force_eager\")`. " + ] + }, { "Gb_type": "strict_mode: improper args", "Context": "args: {args}, kwargs: {kwargs}", @@ -4299,6 +4546,14 @@ "Hints": [] } ], + "GB8887": [ + { + "Gb_type": "NamedTupleVariable.__setattr__ bad args", + "Context": "{len(args)} args and {len(kwargs)} kwargs", + "Explanation": "Expected exactly 2 positional args for __setattr__.", + "Hints": [] + } + ], "GB0349": [ { "Gb_type": "cannot unwrap variable for check_meta_consistency", @@ -4397,6 +4652,19 @@ ] } ], + "GB8985": [ + { + "Gb_type": "requires_grad_() intermediate leaked as output", + "Context": "call_method {self} requires_grad_", + "Explanation": "An intermediate tensor with requires_grad_() called on it (or a tensor derived from it) is returned from the compiled region. Graph breaking here to preserve partial acceleration.", + "Hints": [ + "Call .detach() before returning if you only need values.", + "Consume the gradient inside the compiled function ", + "(call backward() and use .grad), ", + "or move requires_grad_() outside torch.compile." + ] + } + ], "GB0358": [ { "Gb_type": "optimizer: pending mutation on parameter", @@ -4423,6 +4691,16 @@ "Hints": [] } ], + "GB7286": [ + { + "Gb_type": "nb_float_impl not implemented", + "Context": "{type(self).__name__} has nb_float slot but no nb_float_impl override", + "Explanation": "The type {self.python_type_name()} has an nb_float C slot but the corresponding VariableTracker doesn't implement nb_float_impl.", + "Hints": [ + "It may be possible to write Dynamo tracing rules for this code. Please report an issue to PyTorch if you encounter this graph break often and it is causing performance issues." + ] + } + ], "GB0361": [ { "Gb_type": "triton kernel unsupported feature", diff --git a/torch/_dynamo/graph_bytecode_inputs.py b/torch/_dynamo/graph_bytecode_inputs.py index b70ab1002e051..921757b0645a1 100644 --- a/torch/_dynamo/graph_bytecode_inputs.py +++ b/torch/_dynamo/graph_bytecode_inputs.py @@ -30,6 +30,15 @@ def stash_graph_created_object(obj: Any) -> Any: return obj +CURRENT_STREAM_INDEX = 0 + + +def set_external_object_by_index(index: int, value: Any) -> None: + """Update an entry in the external object registry at runtime.""" + keep_alive.append(value) + index_to_external_object_weakref[index] = weakref.ref(value) + + def get_external_object_by_index(index: int) -> Any: assert index in index_to_external_object_weakref, ( "Index not registered in index_to_user_object_weakref" diff --git a/torch/_dynamo/graph_id_filter.py b/torch/_dynamo/graph_id_filter.py index d13cf21619663..173005605b2db 100644 --- a/torch/_dynamo/graph_id_filter.py +++ b/torch/_dynamo/graph_id_filter.py @@ -185,7 +185,8 @@ class GraphBackendRouter(_GraphRouterBase[Any]): The router parses a configuration string with rules in the format: "filter1:backend1;filter2:backend2;..." - Rules are evaluated in order, and the first matching rule wins. + If a graph ID matches multiple rules with different backends, a ValueError + is raised. Examples: "0-5:eager;>5:inductor" - IDs 0-5 use eager, rest use inductor @@ -197,6 +198,7 @@ class GraphBackendRouter(_GraphRouterBase[Any]): """ def __init__(self, config_str: str) -> None: + self._backend_names: dict[int, str] = {} super().__init__(config_str, "backend") def _parse_value_str(self, value_str: str) -> Any | None: @@ -209,8 +211,21 @@ def _parse_value_str(self, value_str: str) -> Any | None: # Register the backend so its reset() is called during torch._dynamo.reset() assert backend is not None, "Invalid override backend: " + value_str cached_backends.setdefault(id(backend), backend) + self._backend_names[id(backend)] = value_str return backend + def _match_rules(self, graph_id: int) -> Any | None: + """Match rules with conflict detection for overlapping filters.""" + matches = {id(backend): backend for f, backend in self._rules if graph_id in f} + if len(matches) > 1: + names = [self._backend_names[bid] for bid in matches] + raise ValueError( + f"Conflicting backend override for graph {graph_id}: matched {names}" + ) + if matches: + return next(iter(matches.values())) + return None + def __repr__(self) -> str: if not self._rules: return "GraphBackendRouter(empty)" diff --git a/torch/_dynamo/graph_region_tracker.py b/torch/_dynamo/graph_region_tracker.py index 6df9feb6bad38..05567a84b8ea6 100644 --- a/torch/_dynamo/graph_region_tracker.py +++ b/torch/_dynamo/graph_region_tracker.py @@ -271,7 +271,7 @@ def track_node(self, tx: InstructionTranslatorBase, node: Node) -> None: duplicates.append(node) self.node_to_duplicates[node] = duplicates except NodeHashException as e: - log.debug("Unable to hash node %s with exception %s", node, e) # noqa: G200 + log.debug("Unable to hash node %s with exception %s", node, e) def track_node_mutations( self, diff --git a/torch/_dynamo/graph_utils.py b/torch/_dynamo/graph_utils.py index 99345af67b388..68aea6f96e102 100644 --- a/torch/_dynamo/graph_utils.py +++ b/torch/_dynamo/graph_utils.py @@ -1,4 +1,3 @@ -from collections import deque from typing import Any import torch @@ -35,46 +34,48 @@ def _get_flat_args_unique( def _detect_cycles( graph: Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]] ) -> str: - current_path: deque[Node] = deque() - current_path_set: set[Node] = set() - pending: deque[tuple[Node, Node]] = deque() - - def add_to_current_path(node: Node) -> None: - current_path.append(node) - current_path_set.add(node) - - def pop_current_path() -> None: - node = current_path.pop() - current_path_set.remove(node) - - def current_path_head() -> Node: - return current_path[-1] - - for origin in graph.find_nodes(op="output"): - current_path.clear() - current_path_set.clear() - add_to_current_path(origin) - for child in _get_flat_args_unique(origin, node_to_additional_deps): - pending.append((child, origin)) - - while pending: - cur_node, parent = pending.pop() - - # handle backtracking - while current_path and current_path_head() != parent: - pop_current_path() - - if not isinstance(cur_node, Node): - continue - - if cur_node in current_path_set: - current_path.append(cur_node) - return f"cycle detected in path: {current_path}" - - add_to_current_path(cur_node) - - for child in _get_flat_args_unique(cur_node, node_to_additional_deps): - pending.append((child, cur_node)) + # States: 0=Unvisited, 1=Visiting, 2=Visited(Safe) + state: dict[Node, int] = {} + + for root in reversed(graph.nodes): + if root in state: + continue + + # Stack holds (current_node, children_iterator). + # Using an iterator allows us to pause and resume processing a node's children. + stack = [(root, iter(_get_flat_args_unique(root, node_to_additional_deps)))] + state[root] = 1 # Visiting + + while stack: + parent, children = stack[-1] + + try: + child = next(children) + + if not isinstance(child, Node): + continue + + child_state = state.get(child, 0) + + if child_state == 1: + # Back-edge: child is on the current DFS path -> cycle + cycle_path = [node for node, _ in stack] + [child] + return f"cycle detected in path: {cycle_path}" + + if child_state == 0: + state[child] = 1 + stack.append( + ( + child, + iter(_get_flat_args_unique(child, node_to_additional_deps)), + ) + ) + # child_state == 2 means already verified safe; skip. + + except StopIteration: + # All children processed — mark safe and pop. + stack.pop() + state[parent] = 2 return "no cycle detected" diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index e5fd1f4e3f8e6..bf3b47f890b1b 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -26,6 +26,7 @@ import importlib import inspect import io +import itertools import logging import math import pickle @@ -178,19 +179,18 @@ dataclass_fields, dict_keys, get_current_stream, - get_custom_getattr, get_torch_function_mode_stack, get_torch_function_mode_stack_at, guard_failures, istype, key_is_id, key_to_id, + normalize_count_iter, normalize_range_iter, orig_code_map, tensor_always_has_static_shape, tuple_iterator_getitem, tuple_iterator_len, - unpatched_nn_module_getattr, verify_guard_fn_signature, ) @@ -200,6 +200,7 @@ guard_manager_testing_hook_fn: Callable[[Any, Any, Any], Any] | None = None +_COUNT_ITERATOR_TYPE = type(itertools.count()) try: import numpy as np @@ -582,7 +583,7 @@ def visit(node: GuardManager) -> list[GuardManager]: def populate_diff_guard_manager(self) -> None: self.diff_guard_root = self.clone_with_chosen_sources(self.diff_guard_sources) - # Ensure that that C++ side points to the updated diff guard manager. + # Ensure that C++ side points to the updated diff guard manager. # When a new GuardManagerWrapper is created, it does not have a # cache_entry attribute, so it relies on the CacheEntry constructor to # set the diff_guard_root in C++. But once it is saved in the Dynamo @@ -751,6 +752,7 @@ def _get_closure_vars() -> dict[str, object]: "___dict_version": dict_version, "___dict_contains": lambda a, b: dict.__contains__(b, a), "___tuple_iterator_len": tuple_iterator_len, + "___normalize_count_iter": normalize_count_iter, "___normalize_range_iter": normalize_range_iter, "___tuple_iterator_getitem": tuple_iterator_getitem, "___dataclass_fields": dataclass_fields, @@ -910,13 +912,7 @@ def raise_local_type_error(obj: Any) -> NoReturn: def should_optimize_getattr_on_nn_module(value: Any) -> bool: - # If inline_inbuilt_nn_modules flag is True, Dynamo has already traced - # through the __getattr__, and therefore it is always safe to optimize - # getattr on nn modules. - return isinstance(value, torch.nn.Module) and ( - config.inline_inbuilt_nn_modules - or get_custom_getattr(value) is unpatched_nn_module_getattr - ) + return isinstance(value, torch.nn.Module) @dataclasses.dataclass(frozen=True) @@ -1103,6 +1099,16 @@ def check_closure(value: Any, metadata: Any) -> bool: return id(value) == metadata +def _constant_subclass_base_value(value: Any) -> Any: + """Extract the base constant value from a constant subclass instance.""" + from .variables.user_defined import _CONSTANT_BASE_TYPES + + for t in _CONSTANT_BASE_TYPES: + if isinstance(value, t): + return t(value) # pyrefly: ignore[bad-argument-type] + raise TypeError(f"Not a constant subclass: {type(value)}") + + def register_guard_check_spec( get_metadata_fn, eval_fn, @@ -1168,9 +1174,7 @@ def __init__( self.source_ref = source_ref self.lookup_weakrefs = lookup_weakrefs self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope} - self.src_get_value_cache: weakref.WeakKeyDictionary[Source, object] = ( - weakref.WeakKeyDictionary() - ) + self.src_get_value_cache: dict[Source, object] = {} self.runtime_global_scope = runtime_global_scope or global_scope self.scope["__builtins__"] = builtins.__dict__.copy() for ( @@ -2644,6 +2648,42 @@ def CONSTANT_MATCH(self, guard: Guard) -> None: else: self.EQUALS_MATCH(guard) + @register_guard_check_spec( + get_metadata_fn=lambda guard, value: _constant_subclass_base_value(value), + eval_fn=lambda value, metadata: _constant_subclass_base_value(value) + == metadata, + ) + def CONSTANT_SUBCLASS_MATCH(self, guard: Guard) -> None: + """Guard for subclasses of constant types (int, float, str, etc.). + + Extracts the base value using the base type's converter (e.g., + int.__int__) to avoid calling user-overridden __eq__. + """ + from .variables.user_defined import _CONSTANT_BASE_TYPES + + val = self.get(guard) + ref = self.arg_ref(guard) + + # Find the constant base type + base_type = None + for t in _CONSTANT_BASE_TYPES: + if isinstance(val, t): + base_type = t + break + assert base_type is not None + + base_value = base_type(val) + code = [f"{base_type.__name__}({ref}) == {base_value!r}"] + + def check_fn(x: Any) -> bool: + return base_type(x) == base_value + + self.get_guard_manager(guard).add_lambda_guard( + check_fn, + get_verbose_code_parts(code, guard), + guard.user_stack, + ) + @register_guard_check_spec( get_metadata_fn=lambda guard, value: value, eval_fn=lambda value, metadata: value is metadata, @@ -2816,6 +2856,30 @@ def RANGE_ITERATOR_MATCH(self, guard: Guard) -> None: guard.user_stack, ) + @register_guard_check_spec( + get_metadata_fn=lambda guard, value: (type(value), normalize_count_iter(value)), + eval_fn=lambda value, metadata: ( + type(value) is metadata[0] and normalize_count_iter(value) == metadata[1] + ), + ) + def COUNT_ITERATOR_MATCH(self, guard: Guard) -> None: + ref = self.arg_ref(guard) + value = self.get(guard) + count_type = type(value) + normalized_count_iter = normalize_count_iter(value) + + def guard_fn(x: Any) -> bool: + return ( + type(x) is count_type + and normalize_count_iter(x) == normalized_count_iter + ) + + code = [f"___normalize_count_iter({ref}) == {normalized_count_iter}"] + self._set_guard_export_info(guard, code) + self.get_guard_manager(guard).add_lambda_guard( + guard_fn, get_verbose_code_parts(code, guard), guard.user_stack + ) + # Multi-source guard (two inputs aliasing) — not expressible as a # single source → value check. # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards @@ -3257,8 +3321,7 @@ def TENSOR_MATCH(self, guard: Guard, value: Any | None = None) -> None: if config._unsafe_skip_fsdp_module_guards and guard.is_fsdp_module(): return # For tensors that are part of the Dynamo extracted Fx graph module, an - # ID_MATCH suffices. Once we turn on inline_inbuilt_nn_modules, these - # will be lifted as inputs and have a TENSOR_MATCH guard. + # ID_MATCH suffices. if match_on_id_for_tensor(guard): self.ID_MATCH(guard) else: @@ -3425,7 +3488,7 @@ def TENSOR_MATCH(self, guard: Guard, value: Any | None = None) -> None: if not static: if hasattr(value, "_dynamo_dynamic_indices"): dynamic_indices = value._dynamo_dynamic_indices - code_part = f"(({tensor_name}._dynamo_dynamic_indices.issubset({dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)" # noqa: B950 + code_part = f"(({tensor_name}._dynamo_dynamic_indices.issubset({dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)" code.append(code_part) self.get_guard_manager(guard).add_dynamic_indices_guard( dynamic_indices, @@ -3452,7 +3515,7 @@ def TENSOR_MATCH(self, guard: Guard, value: Any | None = None) -> None: # tensors that have the attribute when compile-time didn't. if hasattr(value, "_dynamo_unbacked_indices"): shape_ids = getattr(value, "_dynamo_shape_ids", None) - code_part = f"((getattr({tensor_name}, '_dynamo_shape_ids', None) == {shape_ids!r}) if hasattr({tensor_name}, '_dynamo_unbacked_indices') else True)" # noqa: B950 + code_part = f"((getattr({tensor_name}, '_dynamo_shape_ids', None) == {shape_ids!r}) if hasattr({tensor_name}, '_dynamo_unbacked_indices') else True)" code.append(code_part) self.get_guard_manager(guard).add_lambda_guard( lambda x, expected=shape_ids: ( @@ -3471,7 +3534,7 @@ def TENSOR_MATCH(self, guard: Guard, value: Any | None = None) -> None: # tensors that have the attribute when compile-time didn't. if hasattr(value, "_dynamo_unbacked_indices"): unbacked_bounds = getattr(value, "_dynamo_unbacked_bounds", None) - code_part = f"((getattr({tensor_name}, '_dynamo_unbacked_bounds', None) == {unbacked_bounds!r}) if hasattr({tensor_name}, '_dynamo_unbacked_indices') else True)" # noqa: B950 + code_part = f"((getattr({tensor_name}, '_dynamo_unbacked_bounds', None) == {unbacked_bounds!r}) if hasattr({tensor_name}, '_dynamo_unbacked_indices') else True)" code.append(code_part) self.get_guard_manager(guard).add_lambda_guard( lambda x, expected=unbacked_bounds: ( @@ -3762,12 +3825,9 @@ def _unpickle_traceable_wrapper_subclass( pytype: type, dispatch_keys_raw: int, ctx: Any, - inner_data: list[tuple[str, Callable[..., Any], tuple[Any, ...]]], + inner_data: list[tuple[str, Any]], ) -> torch.Tensor: - # Unpickle the inner tensor components. These could also be subclass instances. - inner_tensors = {} - for attr, unpickle_func, unpickle_func_args in inner_data: - inner_tensors[attr] = unpickle_func(*unpickle_func_args) + inner_tensors = dict(inner_data) outer_size, outer_stride = meta_tensor.shape, meta_tensor.stride() out = type(meta_tensor).__tensor_unflatten__( # type: ignore[attr-defined] @@ -3801,6 +3861,10 @@ def _unpickle_mapping_proxy( def _unpickle_dict_keys(cls, elems: list[Any]) -> Any: return dict.fromkeys(elems).keys() + @classmethod + def _unpickle_count_iter(cls, item: int, step: int) -> itertools.count[int]: + return itertools.count(item, step) + @classmethod def _unpickle_fsdp_module_type( cls, original_type: type[torch.nn.Module] @@ -3902,8 +3966,7 @@ def reducer_override( inner = getattr(obj, attr) if isinstance(inner, torch.Tensor): self.guard_tree_values[id(inner)] = inner - func, args_tuple = self.reducer_override(inner) - inner_data.append((attr, func, args_tuple)) + inner_data.append((attr, inner)) return type(self)._unpickle_traceable_wrapper_subclass, ( torch.empty_like(obj, device="meta"), @@ -3984,6 +4047,11 @@ def reducer_override( elif isinstance(obj, types.MappingProxyType): return type(self)._unpickle_mapping_proxy, (obj.copy(),) + elif type(obj) is _COUNT_ITERATOR_TYPE: + item, step = normalize_count_iter(obj) + if item is not NotImplemented and step is not NotImplemented: + return type(self)._unpickle_count_iter, (item, step) + elif isinstance(obj, torch._dynamo.utils.dict_keys): return type(self)._unpickle_dict_keys, (list(obj),) @@ -4083,7 +4151,7 @@ def make_guard_filter_entry(guard: Guard, builder: GuardBuilder) -> GuardFilterE # doesn't exist. value = builder.get(guard) has_value = True - except: # noqa: B001,E722 + except: # noqa: E722 value = MISSING has_value = False is_global = get_global_source_name(guard.originating_source) is not None @@ -4115,7 +4183,7 @@ def pickle_guards_state( try: type(base).__new__(type(base)) empty_values[id(base)] = base - except: # noqa: E722, B001 + except: # noqa: E722 pass elif id(leaf) not in guard_tree_values: # TODO See if we have lift this branch as the first one. @@ -5179,7 +5247,7 @@ def guard_error_hook( for guard in guard_manager.code_parts: try: eval(guard, guard_manager.global_scope, local_scope) - except: # noqa: B001,E722 + except: # noqa: E722 print(f"Malformed guard:\n{guard}") diff --git a/torch/_dynamo/mutation_guard.py b/torch/_dynamo/mutation_guard.py index 0467ea1ba1164..70f2d7dfc56fa 100644 --- a/torch/_dynamo/mutation_guard.py +++ b/torch/_dynamo/mutation_guard.py @@ -117,10 +117,8 @@ def is_dynamic_nn_module(obj: Any, is_export: bool) -> bool: return True if hasattr(obj, "torchdynamo_force_dynamic"): return obj.torchdynamo_force_dynamic - if ( - isinstance(obj, torch.nn.Module) - and config.inline_inbuilt_nn_modules - and (not is_export or config.install_free_tensors) + if isinstance(obj, torch.nn.Module) and ( + not is_export or config.install_free_tensors ): return True diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index ac25a1f977c33..67fc3b00b4f06 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -32,6 +32,7 @@ import sys import time import traceback +import types import warnings import weakref from collections.abc import Callable, Generator, Sequence @@ -741,6 +742,14 @@ def __init__( # and restore_graphstate self.timestamp = 0 + # Maps stream id (id(stream_value)) → user stack trace for input + # mutations on that stream. Used to error when an event records on a + # stream that already has an input mutation (the epilogue copy_() + # wouldn't be captured). We key by id() of the underlying + # torch.Stream so we can peek lazy variables without realizing them. + self._input_mutation_streams: dict[int, traceback.StackSummary] = {} + self._last_checked_input_versions: dict[int, int] | None = None + # A list of register_finalizer_fns to apply to the output graph module self.register_finalizer_fns: list[Callable[[fx.GraphModule], None]] = [] @@ -864,6 +873,62 @@ def __init__( self.used_inlined_inbuilt_modules_names: OrderedSet[str] = OrderedSet() self.attr_source_cache: dict[tuple[Source, str], AttrSource] = {} + self._cached_replayed_side_effect_source_refs: tuple[str, ...] | None = None + + def get_replayed_side_effect_source_refs( + self, *, populate_export_metadata: bool = False + ) -> list[str]: + """Return Python-side effect sources that Dynamo replays outside the FX graph.""" + if ( + not populate_export_metadata + and not self.side_effects.id_to_variable + and self._cached_replayed_side_effect_source_refs is not None + ): + return list(self._cached_replayed_side_effect_source_refs) + + from torch.export._trace import _ExportModuleSpecTrackerDict + + potential_side_effects = [] + for var in self.side_effects._get_modified_vars(): + if hasattr(var, "mutation_type"): + mut_type = var.mutation_type + # Skip codegen-specific mutations that never materialize as + # externally visible Python side effects. + if isinstance( + mut_type, (AttributeMutationExisting, ValueMutationExisting) + ): + if isinstance(var, UserDefinedDictVariable) and isinstance( + var.value, _ExportModuleSpecTrackerDict + ): + if populate_export_metadata: + assert var._base_vt is not None + for ( + k, + v, + ) in ( + var._base_vt.items.items() # pyrefly: ignore[missing-attribute] + ): + # pyrefly: ignore [implicit-any] + specs = {} + # pyrefly: ignore[missing-attribute] + for k_spec, val in v.items.items(): + specs[k_spec.vt.as_python_constant()] = ( + val.as_python_constant() + ) + assert ["in_spec", "out_spec"] == list(specs.keys()) + self.export_metadata.module_call_spec[ + # pyrefly: ignore[missing-attribute] + k.vt.as_python_constant() + ] = specs + # export uses tracepoint pass to dump submodule inp/out spec + # into global state, so we filter it here + if not ( + isinstance(var, UserDefinedDictVariable) + and isinstance(var.value, _ExportModuleSpecTrackerDict) + ): + potential_side_effects.append(var) + + return [_get_source_debug_name(var.source) for var in potential_side_effects] def get_chained_attr_source(self, base: Source, path: str) -> AttrSource: parts = path.split(".") @@ -1092,6 +1157,75 @@ def is_root_tracer(self) -> bool: # Helper to tell if we are inside the higher order operator tracing. return len(self.tracers) == 1 + def check_input_mutation_on_current_stream( + self, tx: "InstructionTranslatorBase" + ) -> None: + """Record which stream index has input mutations by comparing current + tensor versions against the versions captured at graph input creation.""" + if not hasattr(tx, "symbolic_stream_state"): + return + if not tx.symbolic_stream_state.in_stream_context(): + return + + tracer = self.root_tracer + if self._last_checked_input_versions is None: + self._last_checked_input_versions = dict( + enumerate(tracer._input_versions_at_beginning) + ) + + cur_stream_index = tx.symbolic_stream_state.cur_stream_id() + input_idx = 0 + for node in tracer.graph.nodes: + if node.op != "placeholder": + break + example_value = node.meta.get("example_value") + if not isinstance(example_value, torch.Tensor): + continue + prev_version = self._last_checked_input_versions.get(input_idx) + cur_version = example_value._version + if prev_version is not None and cur_version > prev_version: + if cur_stream_index not in self._input_mutation_streams: + self._input_mutation_streams[cur_stream_index] = ( + TracingContext.extract_stack() + ) + self._last_checked_input_versions[input_idx] = cur_version + input_idx += 1 + + _EVENT_INPUT_MUTATION_FIX = ( + "To fix this, either:\n" + " 1. Move the input mutation after the event.record() call.\n" + " 2. Record the event outside the compiled region:\n" + " compiled_fn(x)\n" + " event.record(stream) # after torch.compile returns\n" + " 3. Insert a graph break before recording:\n" + " torch._dynamo.graph_break()\n" + " event.record(stream)\n" + " 4. Record the event on a stream that has no input mutations." + ) + + def check_event_record_after_input_mutation(self, stream_index: int) -> None: + """Error if an event is being recorded on a stream that already has + an input mutation. Called at record time so ordering is naturally + respected — records before mutations won't trigger this.""" + if stream_index not in self._input_mutation_streams: + return + + mutation_stack = self._input_mutation_streams[stream_index] + record_stack = TracingContext.extract_stack() + + msg = ( + "An event was recorded on a stream where a graph input was " + "previously mutated. The input mutation is applied via copy_() " + "in the runtime epilogue after the graph executes, so the event " + "would not capture the mutation, leading to incorrect " + "synchronization.\n\n" + "Input mutation occurred here:\n" + f"{''.join(mutation_stack.format())}\n" + "Event record occurred here:\n" + f"{''.join(record_stack.format())}\n" + self._EVENT_INPUT_MUTATION_FIX + ) + raise RuntimeError(msg) + @property def graph(self) -> torch.fx.Graph: return self.current_tracer.graph @@ -1737,6 +1871,7 @@ def compile_subgraph( # i.e. last element corresponds to root frame (1), # first element corresponds to current frame (N) all_stack_values = [] + # pyrefly: ignore [implicit-any] all_stack_locals_metas = [] cur_tx: InstructionTranslatorBase | None = tx while cur_tx is not None: @@ -1937,7 +2072,7 @@ def compile_subgraph( elif ( vt.source is not None and (source := getattr(vt.source, "base", None)) # type: ignore[assignment] - and source.is_input + and getattr(source, "is_input", False) ): self.export_metadata.output_return_type[idx] = ( "input", @@ -2096,43 +2231,9 @@ def compile_subgraph( ) if torch._dynamo.config.side_effect_replay_policy in ["warn", "error"]: - from torch.export._trace import _ExportModuleSpecTrackerDict - - potential_side_effects = [] - for var in self.side_effects._get_modified_vars(): - if hasattr(var, "mutation_type"): - mut_type = var.mutation_type - # Make sure to skip codegen specific mutations - if isinstance( - mut_type, (AttributeMutationExisting, ValueMutationExisting) - ): - if isinstance(var, UserDefinedDictVariable) and isinstance( - var.value, _ExportModuleSpecTrackerDict - ): - for k, v in var.items.items(): - # pyrefly: ignore [implicit-any] - specs = {} - # pyrefly: ignore[missing-attribute] - for k_spec, val in v.items.items(): - specs[k_spec.vt.as_python_constant()] = ( - val.as_python_constant() - ) - assert ["in_spec", "out_spec"] == list(specs.keys()) - self.export_metadata.module_call_spec[ - # pyrefly: ignore[missing-attribute] - k.vt.as_python_constant() - ] = specs - # export uses tracepoint pass to dump submodule inp/out spec - # into global state, so we filter it here - if not ( - isinstance(var, UserDefinedDictVariable) - and isinstance(var.value, _ExportModuleSpecTrackerDict) - ): - potential_side_effects.append(var) - side_effect_refs = [ - _get_source_debug_name(var.source) for var in potential_side_effects - ] - + side_effect_refs = self.get_replayed_side_effect_source_refs( + populate_export_metadata=True + ) if side_effect_refs: if torch._dynamo.config.side_effect_replay_policy == "warn": warnings.warn( @@ -2381,11 +2482,80 @@ def _validate_outputs_safe_for_autograd_nodes( # if any node was consumed by autograd.grad reachable_grad_fns = collect_reachable_grad_fns([(fake_tensor, None)]) if reachable_grad_fns & self.autograd_grad_consumed_grad_fns: - # Set the flag to graph break at autograd.grad on retry - tx.speculation_log.graph_break_on_autograd_grad = True - raise exc.AutogradGradRestartAnalysis( - restart_reason="autograd.grad consumed grad_fns of returned tensors" + # Record info about the leaked tensor for the error message + tensor_name = str(var.source) if var.source else var.proxy.node.name + tx.speculation_log.autograd_grad_leaked_tensors.append(tensor_name) + + if tx.speculation_log.autograd_grad_leaked_tensors: + # Set the flag to graph break at autograd.grad on retry + tx.speculation_log.graph_break_on_autograd_grad = True + raise exc.AutogradGradRestartAnalysis( + restart_reason="autograd.grad consumed grad_fns of returned tensors" + ) + + def _check_requires_grad_intermediate_outputs( + self, rv: list["VariableTracker"], tx: "InstructionTranslatorBase" + ) -> None: + """Skip frame if a source-less requires_grad_() intermediate leaks as output. + + AOTAutograd's functionalization drops requires_grad_() on intermediates, + so returning them (or tensors derived from them) produces wrong results. + We detect this via FX graph reachability: find the requires_grad_() nodes + for source-less intermediates, then check if any output is downstream. + """ + from .variables.tensor import TensorVariable + + # Collect FX nodes for source-less requires_grad_() intermediates + tainted_nodes: set[torch.fx.Node] = set() + for v in self.leaf_var_creation_order: + if isinstance(v, TensorVariable) and not v.source: + tainted_nodes.add(v.as_proxy().node) + + if not tainted_nodes: + return + + # Propagate taint forward through the FX graph + for node in self.graph.nodes: + if node in tainted_nodes: + continue + if any(inp in tainted_nodes for inp in node.all_input_nodes): + tainted_nodes.add(node) + + # Check leaked outputs: tainted + requires_grad means the output + # carries autograd state that AOTAutograd would silently drop. + # Detached outputs (requires_grad=False) are fine — no autograd to lose. + for var in rv: + if ( + isinstance(var, TensorVariable) + and var.requires_grad + and var.as_proxy().node in tainted_nodes + ): + msg = ( + "An intermediate tensor that had requires_grad_() called " + "on it (or a tensor derived from it) is being returned " + "from the compiled region. AOTAutograd's functionalization " + "drops the requires_grad_() effect on graph outputs, " + "producing wrong results. If you only need the tensor " + "values without gradients, call .detach() before returning." ) + if tx.one_graph: + unimplemented( + gb_type="returning intermediate with requires_grad_()", + context="graph output depends on source-less requires_grad_()", + explanation=msg, + hints=[ + "If you only need the tensor values without gradients, " + "call .detach() before returning.", + "Consume the gradient inside the compiled function " + "(call backward() and use .grad), " + "or move requires_grad_() outside torch.compile.", + ], + ) + else: + tx.speculation_log.graph_break_on_requires_grad_ = True + raise exc.RequiresGradRestartAnalysis( + restart_reason="source-less requires_grad_() intermediate leaked as output" + ) def compile_and_call_fx_graph( self, @@ -2414,6 +2584,11 @@ def compile_and_call_fx_graph( assert isinstance(rv, list) assert isinstance(root, FakeRootModule) + # Error on source-less requires_grad_() outputs. + # Must run before autograd validation since detaching resolves the + # "consumed grad_fn" conflict for backward-consumed intermediates. + self._check_requires_grad_intermediate_outputs(rv, tx) + # Check if autograd.grad is used with outputs that require grad # This would cause double backward issues in aot_autograd self._validate_outputs_safe_for_autograd_nodes(rv, tx) @@ -2483,17 +2658,18 @@ def compile_and_call_fx_graph( # If dynamo produces a graph with parameters, skip package stuff # Bypass output graph self.bypass_package( - "Graph contains named parameters: either inline_inbuilt_nn_modules=False or there are static addresses.", - inline_builtin_nn_modules=torch._dynamo.config.inline_inbuilt_nn_modules, + "Graph contains named parameters due to static addresses.", gm=gm.print_readable( print_output=False, include_stride=True, include_device=True ), ) if self.package is not None: - gm._backend_id = name + gm._backend_id = name # pyrefly: ignore[bad-argument-type] + # pyrefly: ignore[bad-argument-type] gm.compile_subgraph_reason = self.compile_subgraph_reason + # pyrefly: ignore[bad-argument-type] gm.meta["dynamo_flat_name_to_original_fqn"] = ( self.dynamo_flat_name_to_original_fqn.copy() ) @@ -2545,6 +2721,7 @@ def compile_and_call_fx_graph( # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode self.tracing_context.fake_mode = backend_fake_mode + gm.graph.lint() with self.restore_global_state(): compiled_fn = self.call_user_compiler(gm, self.example_inputs()) @@ -3035,6 +3212,34 @@ def add_output_instructions(self, prefix: list[Instruction]) -> None: self.output_instructions.extend(prefix) self.should_exit = True + def install_resume_function_global( + self, + name: str, + code: types.CodeType, + f_globals: dict[str, Any], + ) -> None: + """Install a resume function as a global. + + When the code has freevars, installs a factory that creates the + function with correct globals and closure (since MAKE_FUNCTION + inherits the current frame's globals, which is wrong for resume + functions from inlined frames). Otherwise installs the function + directly. + """ + if code.co_freevars: + + def _make_fn( + closure: tuple[types.CellType, ...], + ) -> types.FunctionType: + return types.FunctionType(code, f_globals, name, None, closure) + + self.install_global_unsafe(name, _make_fn) + else: + self.install_global_unsafe( + name, + types.FunctionType(code, f_globals, name), + ) + def install_global_unsafe(self, name: str, value: Any) -> None: """ WARNING: prefer the safer `install_global_by_id/install_global`. @@ -3086,6 +3291,9 @@ def cleanup(self) -> None: del node.meta["grapharg"] self.real_value_cache.clear() self.input_name_to_proxy.clear() + self._cached_replayed_side_effect_source_refs = tuple( + self.get_replayed_side_effect_source_refs() + ) self.side_effects.clear() self.variable_tracker_cache.clear() self.mro_source_cache.clear() @@ -3403,10 +3611,10 @@ def __init__( "Inference mode is supposed to be disabled during compilation. Please open an issue." ) - self.tracked_tensor_or_symint_vt: OrderedSet[VariableTracker] = OrderedSet() + self.tracked_proxyable_vt: OrderedSet[VariableTracker] = OrderedSet() - def record_tensor_or_symint_vt(self, vt: VariableTracker) -> None: - self.tracked_tensor_or_symint_vt.add(vt) + def record_proxyable_vt(self, vt: VariableTracker) -> None: + self.tracked_proxyable_vt.add(vt) # preserve original meta if it is available def _maybe_preserve_original_meta( @@ -3579,7 +3787,6 @@ def get_trace_call_log_str() -> str: rv.node.meta["source_fn_stack"] = self.source_fn_stack + [stack] elif kind == "call_module": if self.parent is not None: - # TODO can remove once inline_inbuilt_nn_modules is always True unimplemented( gb_type="Invoking an nn.Module inside a higher order operator", context=f"Higher order op name: {self.source_target}", @@ -3613,7 +3820,6 @@ def get_trace_call_log_str() -> str: ] elif kind == "call_module": if self.parent is not None: - # TODO can remove once inline_inbuilt_nn_modules is always True unimplemented( gb_type="Invoking an nn.Module inside a HigherOrderOperator", context="", @@ -3783,7 +3989,7 @@ def create_graph_input( # # 1. When create_graph_input for a tensor that has symbolic shapes, # we look for basic symbols in its size and stride, we check if the symbol is bound - # in current graph (i.e. bound_symbols), it it's not bound, we'll create a placeholder + # in current graph (i.e. bound_symbols), if it's not bound, we'll create a placeholder # for it then recursively check its parent, creates ph if not bound at parent until. # reachting the top-level, where we require a source is attached to the proxy. # diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py index 39aaf7b0bb1e2..33ce8f55b922a 100644 --- a/torch/_dynamo/package.py +++ b/torch/_dynamo/package.py @@ -555,6 +555,7 @@ def _compile_frame_context( # under the same cache entry, so we don't have recompile ids # i.e. If cold start had 0/0, 0/1, 1/0, 1/1, these would be # collapsed into 0/0, 1/0 on warm. + # pyrefly: ignore [deprecated] @contextlib.contextmanager def _ctx() -> Iterator[None]: increment_frame() @@ -757,12 +758,12 @@ def add_resume_function( self, python_code: types.CodeType, python_module: str, - function_name: str | None, + function_name: str, ) -> None: self._add_function( python_code, python_module, - function_name=_FunctionId(function_name) if function_name else None, + function_name=_FunctionId(function_name), install_to_global=True, ) self._resume_codes.add(python_code) @@ -828,8 +829,29 @@ def install(self, backends: dict[_BackendId, Any]) -> None: target_code = code if entry.install_to_global: for function_name in entry.function_names: - fn = types.FunctionType(code, module.__dict__, function_name) - self._install_global(module, function_name, fn) + if code.co_freevars: + # Resume functions with freevars need a factory + # that takes a closure tuple, matching + # install_resume_function_global in output_graph.py. + f_globals = module.__dict__ + fn_name = function_name + + def _make_fn( + closure: tuple[types.CellType, ...], + _code: types.CodeType = code, + _globals: dict[str, Any] = f_globals, + _name: str = fn_name, + ) -> types.FunctionType: + return types.FunctionType( + _code, _globals, _name, None, closure + ) + + self._install_global(module, function_name, _make_fn) + else: + fn = types.FunctionType( + code, module.__dict__, function_name + ) + self._install_global(module, function_name, fn) if entry.code_source: target_code = _lookup_code(entry) diff --git a/torch/_dynamo/pgo.py b/torch/_dynamo/pgo.py index ee27fbd91947c..80bb6ddfda1cc 100644 --- a/torch/_dynamo/pgo.py +++ b/torch/_dynamo/pgo.py @@ -259,6 +259,10 @@ class FrameStateSizeEntry: stride: AutoDynamic | AutoUnset | tuple[int | AutoDynamic | InferStride, ...] = ( dataclasses.field(default=auto_unset) ) + excluded_sizes: tuple[int | None, ...] | None = dataclasses.field( + default=None, compare=False + ) + excluded_scalar: int | None = dataclasses.field(default=None, compare=False) def render(self) -> str: # Special cases @@ -384,8 +388,33 @@ def _merge_atom_tup( return tuple(cls._merge_atom(x, y) for x, y in zip(xs, ys)) def __ior__(self, other: Self) -> Self: + # Record current static sizes before merge. For dims that become + # dynamic, the exclusion guard will reject these values so inputs + # fall through to the earlier, more specialized cache entry. + # Already-dynamic dims become None and are ignored by the guard. + # When no dim transitions, clear stale excluded_sizes so later + # compilations don't inherit exclusions from earlier transitions. + new_size = self._merge_atom_tup(self.size, other.size) + if isinstance(self.size, tuple): + if new_size != self.size: + self.excluded_sizes = tuple( + s if type(s) is int else None for s in self.size + ) + elif self.excluded_sizes is not None: + self.excluded_sizes = None + self.size = new_size + # Same idea for scalars: record the static value about to become dynamic. + # Re-derive like excluded_sizes: only set when transitioning from a + # concrete int, clear when already dynamic. + if ( + type(self.scalar) is int + and type(other.scalar) is int + and self.scalar != other.scalar + ): + self.excluded_scalar = self.scalar + elif self.scalar is auto_dynamic and self.excluded_scalar is not None: + self.excluded_scalar = None self.scalar = self._merge_atom(self.scalar, other.scalar) - self.size = self._merge_atom_tup(self.size, other.size) self.stride = self._merge_atom_tup(self.stride, other.stride) return self diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index e856c0d0c3369..c206572aee856 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -11,7 +11,8 @@ from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, Sequence from itertools import repeat as _repeat from operator import eq, ne -from typing import Any, TYPE_CHECKING, TypeVar +from typing import Any, TYPE_CHECKING, TypeGuard, TypeVar +from typing_extensions import TypeIs import torch @@ -62,6 +63,16 @@ def __enter__(self) -> None: pass +# Used by WrappedUserFunctionVariable and similar to inline decorated function +# calls with bytecode backing. Without this, the context enter/exit happens in +# Python-level VT code, so a nested graph break inside `fn` would skip applying +# the context in the compiled fn/resume. By inlining through this polyfill, the +# `with` statement has real bytecode that the resume function can continue from. +def _fn_with_ctx(ctx: Any, fn: Callable[..., T], *args: Any, **kwargs: Any) -> T: + with ctx: + return fn(*args, **kwargs) + + def index( iterator: Iterator[T], item: T, start: int = 0, end: int | None = None ) -> int: @@ -85,11 +96,11 @@ def radians(x: float) -> float: return math.pi / 180.0 * x -def impl_IS_MAPPING(a: object) -> bool: +def impl_IS_MAPPING(a: object) -> TypeIs[Mapping[Any, Any]]: return isinstance(a, Mapping) -def impl_MATCH_SEQUENCE(a: object) -> bool: +def impl_MATCH_SEQUENCE(a: object) -> TypeGuard[Sequence[Any]]: return isinstance(a, Sequence) and not isinstance(a, (str, bytes, bytearray)) @@ -316,7 +327,7 @@ def set_union( set_update(union_set, set2) # frozenset also uses this function - # pyrefly: ignore[not-callable] + # pyrefly: ignore [bad-argument-count, not-callable] return cls(union_set) @@ -404,9 +415,12 @@ def instantiate_user_defined_class_object( ) -> T: obj = cls.__new__(cls, *args, **kwargs) - # Only call __init__ if the object is an instance of the class + # Only call __init__ if the object's type is a subclass of cls. + # CPython uses PyType_IsSubtype(Py_TYPE(obj), type) at the C level, which does NOT + # go through metaclass __instancecheck__. Using isinstance() here would be wrong + # for classes with custom __instancecheck__ (e.g. torch.ByteStorage). # Reference: https://github.com/python/cpython/blob/3.12/Objects/typeobject.c#L1670-L1673 - if isinstance(obj, cls): + if issubclass(type(obj), cls): obj.__init__(*args, **kwargs) return obj @@ -550,20 +564,47 @@ def cmp_eq(a: object, b: object) -> bool: # slow in some corner cases. # if a is b: # return True - result = a.__eq__(b) + if isinstance(a, type): + # Default metaclass equality is identity-based. Preserve the reflected + # operand fallback without tracing through type.__eq__. + if type(a).__eq__ is type.__eq__: + result = True if a is b else NotImplemented + else: + result = type(a).__eq__(a, b) + else: + result = a.__eq__(b) if result is NotImplemented: - result = b.__eq__(a) + if isinstance(b, type): + if type(b).__eq__ is type.__eq__: + result = True if a is b else NotImplemented + else: + result = type(b).__eq__(b, a) + else: + result = b.__eq__(a) return result is not NotImplemented and result def cmp_ne(a: object, b: object) -> bool: - # Check if __ne__ is overridden - if isinstance(type(a).__ne__, types.FunctionType): + if isinstance(a, type): + if type(a).__ne__ is type.__ne__: + result = False if a is b else NotImplemented + else: + result = type(a).__ne__(a, b) + if result is not NotImplemented: + return result + elif isinstance(type(a).__ne__, types.FunctionType): result = a.__ne__(b) if result is not NotImplemented: return result # Fall through to try b.__ne__(a) or cmp_eq - if isinstance(type(b).__ne__, types.FunctionType): + if isinstance(b, type): + if type(b).__ne__ is type.__ne__: + result = False if a is b else NotImplemented + else: + result = type(b).__ne__(b, a) + if result is not NotImplemented: + return result + elif isinstance(type(b).__ne__, types.FunctionType): result = b.__ne__(a) if result is not NotImplemented: return result diff --git a/torch/_dynamo/polyfills/builtins.py b/torch/_dynamo/polyfills/builtins.py index 0c32e8ba5d524..386a96f6db4ef 100644 --- a/torch/_dynamo/polyfills/builtins.py +++ b/torch/_dynamo/polyfills/builtins.py @@ -7,6 +7,7 @@ import builtins import functools import operator +import typing from collections.abc import Callable from typing import TYPE_CHECKING, TypeVar @@ -20,6 +21,7 @@ __all__ = [ "all", "any", + "cast", "enumerate", "sum", ] @@ -120,3 +122,8 @@ def sequence_protocol(iterable): # type: ignore[no-untyped-def] raise TypeError("iter(v, w): v must be a callable") return _CallableIterator(fn, sentinel) + + +@substitute_in_graph(typing.cast, can_constant_fold_through=True) +def cast(typ: type, val: _T) -> _T: # type: ignore[type-var] + return val diff --git a/torch/_dynamo/polyfills/copy.py b/torch/_dynamo/polyfills/copy.py new file mode 100644 index 0000000000000..0188945f038cb --- /dev/null +++ b/torch/_dynamo/polyfills/copy.py @@ -0,0 +1,36 @@ +""" +Python polyfills for copy +""" + +from __future__ import annotations + +from typing import TypeVar + +from ..decorators import substitute_in_graph + + +__all__ = [ + "reduce_ex_user_defined_object", +] + +T = TypeVar("T") + + +@substitute_in_graph(object.__reduce_ex__, skip_signature_check=True) # type: ignore[arg-type] +def reduce_ex_user_defined_object(obj: T, protocol: int, /) -> tuple: # type: ignore[type-arg] + """Traceable polyfill for object.__reduce_ex__ on user-defined objects. + + Returns the same tuple that CPython's _common_reduce produces: + (copyreg.__newobj__, (cls,), obj.__dict__, None, None). + copy._reconstruct then calls cls.__new__(cls) and updates __dict__. + """ + import copyreg + + cls = type(obj) + return ( + copyreg.__newobj__, # pyrefly: ignore[missing-attribute] + (cls,), + obj.__dict__, + None, + None, + ) diff --git a/torch/_dynamo/polyfills/loader.py b/torch/_dynamo/polyfills/loader.py index 9b24295ada3c9..cfce3edfb0ab0 100644 --- a/torch/_dynamo/polyfills/loader.py +++ b/torch/_dynamo/polyfills/loader.py @@ -17,6 +17,7 @@ POLYFILLED_MODULE_NAMES: tuple[str, ...] = ( "_collections", "builtins", + "copy", "functools", "itertools", "operator", diff --git a/torch/_dynamo/polyfills/operator.py b/torch/_dynamo/polyfills/operator.py index cae61df2c0430..9b0eb3e64ae73 100644 --- a/torch/_dynamo/polyfills/operator.py +++ b/torch/_dynamo/polyfills/operator.py @@ -12,11 +12,11 @@ if TYPE_CHECKING: - from collections.abc import Callable, Iterable + from collections.abc import Callable, Iterable, Sequence # Most unary and binary operators are handled by BuiltinVariable (e.g., `pos`, `add`) -__all__ = ["attrgetter", "itemgetter", "methodcaller", "countOf"] +__all__ = ["attrgetter", "concat", "countOf", "iconcat", "itemgetter", "methodcaller"] _T = TypeVar("_T") @@ -69,6 +69,25 @@ def getter(obj: Any) -> tuple[Any, ...]: # type: ignore[misc] return getter +# Reference: https://docs.python.org/3/library/operator.html#operator.concat +@substitute_in_graph(operator.concat, can_constant_fold_through=True) # type: ignore[arg-type] +def concat(a: Sequence[_T], b: Sequence[_T2], /) -> Sequence[_T | _T2]: + return a + b # type: ignore[operator] + + +# Reference: https://docs.python.org/3/library/operator.html#operator.countOf +@substitute_in_graph(operator.countOf, can_constant_fold_through=True) # type: ignore[arg-type,misc] +def countOf(a: Iterable[_T], b: _T, /) -> int: + return sum(it is b or it == b for it in a) + + +# Reference: https://docs.python.org/3/library/operator.html#operator.iconcat +@substitute_in_graph(operator.iconcat) # type: ignore[arg-type] +def iconcat(a: Sequence[_T], b: Sequence[_T2], /) -> Sequence[_T | _T2]: + a += b # type: ignore[operator] + return a # type: ignore[return-value] + + @overload # pyrefly: ignore [inconsistent-overload] def itemgetter(item: _T, /) -> Callable[[Any], _U]: ... @@ -111,9 +130,3 @@ def caller(obj: Any) -> Any: return getattr(obj, name)(*args, **kwargs) return caller - - -# Reference: https://docs.python.org/3/library/operator.html#operator.countOf -@substitute_in_graph(operator.countOf, can_constant_fold_through=True) # type: ignore[arg-type,misc] -def countOf(a: Iterable[_T], b: _T, /) -> int: - return sum(it is b or it == b for it in a) diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index 65c1ff3a2cd84..d834ce1770667 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -1,3 +1,5 @@ +# Owner(s): ["module: pytree"] + """ Python polyfills for torch.utils.pytree """ diff --git a/torch/_dynamo/polyfills/sys.py b/torch/_dynamo/polyfills/sys.py index 4475b7b41ef62..d04a8b2febc64 100644 --- a/torch/_dynamo/polyfills/sys.py +++ b/torch/_dynamo/polyfills/sys.py @@ -26,10 +26,14 @@ def getrecursionlimit() -> int: return sys.getrecursionlimit() -if hasattr(sys, "get_int_max_str_digits"): +if sys.version_info >= (3, 11): @substitute_in_graph(sys.get_int_max_str_digits, can_constant_fold_through=True) def get_int_max_str_digits() -> int: return sys.get_int_max_str_digits() - __all__ += ["get_int_max_str_digits"] + @substitute_in_graph(sys.set_int_max_str_digits, can_constant_fold_through=True) + def set_int_max_str_digits(maxdigits: int) -> None: + sys.set_int_max_str_digits(maxdigits) + + __all__ += ["get_int_max_str_digits", "set_int_max_str_digits"] diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 37b2344d959b0..6b8f3499184ca 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -30,6 +30,7 @@ import subprocess import sys import textwrap +import typing import uuid from importlib import import_module from tempfile import TemporaryFile @@ -98,6 +99,39 @@ class TritonConstexpr: # type: ignore[no-redef] from .. import config +def _find_repeat_interleave_constraints( + gm: torch.fx.GraphModule, +) -> list[tuple[str, str]]: + """ + Find repeat_interleave operations with output_size constraints. + + Returns list of (repeats_placeholder_name, output_size_placeholder_name) pairs. + These represent constraints where sum(repeats) must equal output_size. + """ + constraints = [] + for node in gm.graph.nodes: + if ( + node.op != "call_function" + or "repeat_interleave" not in str(node.target) + or not node.args + ): + continue + + output_size_node = node.kwargs.get("output_size") + repeats_node = node.args[0] + + # Both must be FX nodes (not constants) and direct placeholders + if ( + isinstance(repeats_node, torch.fx.Node) + and isinstance(output_size_node, torch.fx.Node) + and repeats_node.op == "placeholder" + and output_size_node.op == "placeholder" + ): + constraints.append((str(repeats_node.target), str(output_size_node.target))) + + return constraints + + if TYPE_CHECKING: from collections.abc import Callable, Sequence @@ -121,6 +155,7 @@ def _extract_distributed_info( Returns a dict mapping group names to dicts with 'size' and 'rank' keys. Example: {'tp': {'size': 4, 'rank': 0}, 'dp': {'size': 2, 'rank': 0}} """ + from torch.distributed import GroupName from torch.fx.operator_schemas import normalize_function group_info: dict[str, dict[str, int]] = {} @@ -143,9 +178,10 @@ def _extract_distributed_info( continue _, kwargs = opt_args_kwargs - group_name = kwargs.get("group_name") - if group_name is None: + group_name_ = kwargs.get("group_name") + if not isinstance(group_name_, str): continue + group_name = typing.cast(GroupName, group_name_) if group_name in group_info: continue @@ -253,11 +289,16 @@ def wrap_compiler_debug( def debug_wrapper( gm: torch.fx.GraphModule, example_inputs: Sequence[InputType], + compile_region_name: str | None = None, **kwargs: Unpack[_CompileFxKwargs], ) -> OutputCode: from torch._subclasses import FakeTensorMode - compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) + compiler_fn = functools.partial( + unconfigured_compiler_fn, + compile_region_name=compile_region_name, + **kwargs, + ) from torch._functorch.aot_autograd import get_aot_graph_name @@ -546,11 +587,8 @@ def generate_compiler_repro_string( ) def get_fn_name(kernel: Any) -> str: - fn_name = ( - # pyrefly: ignore [missing-attribute] - kernel._fn_name if isinstance(kernel, JITFunction) else kernel.fn._fn_name - ) - return fn_name.split(".")[-1] + fn: Any = kernel if isinstance(kernel, JITFunction) else kernel.fn + return fn.__name__.split(".")[-1] def write_kernel_dependencies( kernel: Any, @@ -653,6 +691,8 @@ def write_kernel_dependencies( writer.const(placeholder) elif isinstance(arg, FakeScriptObject): writer.opaque(placeholder, arg.script_class_name) + elif isinstance(arg, torch._C.Generator): + writer.generator(placeholder, arg) else: writer.unsupported(placeholder, arg) @@ -701,6 +741,32 @@ def write_kernel_dependencies( ) model_str = f"{hint_lines}\n\n{model_str}" + # Add fixup code for repeat_interleave constraints + # When inputs are regenerated randomly, sum(repeats) != output_size + # This fixup adjusts the repeats tensor to satisfy the constraint + constraints = _find_repeat_interleave_constraints(gm) + if constraints: + placeholder_to_idx = {name: idx for idx, name in enumerate(placeholder_targets)} + for repeats_name, output_size_name in constraints: + repeats_idx = placeholder_to_idx.get(repeats_name) + output_size_idx = placeholder_to_idx.get(output_size_name) + if repeats_idx is not None and output_size_idx is not None: + # Guard with hasattr since NopInputReader doesn't have args + writer._lines.append( + "# Fixup: ensure sum(repeats) == output_size for repeat_interleave" + ) + writer._lines.append("if hasattr(reader, 'args'):") + writer._lines.append(f" _repeats = reader.args[{repeats_idx}]") + writer._lines.append( + f" _output_size = reader.args[{output_size_idx}]" + ) + writer._lines.append( + " if isinstance(_repeats, torch.Tensor) and _repeats.dtype == torch.int64:" + ) + writer._lines.append(" _n = _repeats.numel()") + writer._lines.append(" _repeats.fill_(_output_size // _n)") + writer._lines.append(" _repeats[:_output_size % _n] += 1") + load_args_lines = writer.lines() load_args_code = "\n".join(load_args_lines) model_str += load_args_code + "\n" @@ -922,7 +988,8 @@ def sync() -> None: sync() try: - compile_mod = compile_fx_inner(fx_g, args) + compile_args = _get_compile_args(fx_g, args) + compile_mod = compile_fx_inner(fx_g, compile_args) assert not isinstance(compile_mod, str) compile_mod(args) sync() @@ -944,10 +1011,15 @@ def inductor_accuracy_fails( ) -> bool: from torch._inductor.compile_fx import compile_fx_inner + def _compile_with_symbolic_args( + gm: torch.fx.GraphModule, inputs: list[Any] + ) -> torch.fx.GraphModule: + return compile_fx_inner(gm, _get_compile_args(gm, inputs)) # type: ignore[return-value] + return backend_aot_accuracy_fails( fx_g, args, # type: ignore[arg-type] - compile_fx_inner, # type: ignore[arg-type] + _compile_with_symbolic_args, # type: ignore[arg-type] require_fp64=require_fp64, ignore_non_fp=ignore_non_fp, ) @@ -963,7 +1035,7 @@ def inductor_accuracy_fails( def repro_common( options: Any, mod: nn.Module, load_args: Any -) -> tuple[torch.fx.GraphModule, Sequence[Any]]: +) -> tuple[torch.fx.GraphModule, list[Any]]: # Invariant for graphs we generate with the repro script assert not any(mod.named_parameters()) for n, b in mod.named_buffers(): @@ -1007,6 +1079,33 @@ def repro_common( return mod, args +def _get_compile_args(mod: torch.fx.GraphModule, args: Sequence[Any]) -> Sequence[Any]: + """Extract FakeTensor/SymInt args from the traced graph for compilation. + + When repro_common traces with tracing_mode='symbolic', the resulting + GraphModule's placeholder nodes carry FakeTensor/SymInt metadata. + compile_fx_inner needs these (not the concrete args) so that Inductor + generates proper symbolic-size bindings in the output code. + + For tracing_mode='real', concrete args are fine — we must NOT extract + FakeTensor metadata because different nodes may have FakeTensors from + different FakeTensorModes, causing a FakeTensorMode mismatch assertion + in Inductor. We detect symbolic tracing by checking for SymInt values, + which only exist when tracing_mode='symbolic'. + """ + placeholders = [n for n in mod.graph.nodes if n.op == "placeholder"] + if not placeholders: + return args + # Only extract metadata if the graph was traced with symbolic mode. + # SymInt values in placeholder metadata are the reliable indicator — + # FakeTensors appear in both real and symbolic modes, but only symbolic + # tracing creates SymInts for integer inputs. + has_symint = any(isinstance(n.meta.get("val"), torch.SymInt) for n in placeholders) + if not has_symint: + return args + return [n.meta.get("val", a) for n, a in zip(placeholders, args)] + + ACCURACY_FAILS: dict[str, Callable[[torch.fx.GraphModule, Any], bool]] = { "": inductor_fails, # This might look inverted but it's not. strict_accuracy means "we will @@ -1081,8 +1180,9 @@ def repro_analyze(options: Any, mod: nn.Module, load_args: Any) -> None: # It is certainly faster though! It probably makes sense to let the # user specify the offload strategy. + compile_args = _get_compile_args(mod, args) with tqdm(desc="Compiling"): - compiled = compile_fx_inner(mod, args) + compiled = compile_fx_inner(mod, compile_args) total = counters["inductor"]["intermediate_hooks"] known_names = set() @@ -1212,7 +1312,7 @@ def repro_get_args( options: Any, mod: nn.Module, load_args: Any ) -> tuple[torch.fx.GraphModule, list[Any]]: mod, args = repro_common(options, mod, load_args) - return mod, args # type: ignore[return-value] + return mod, args def repro_run(options: Any, mod: nn.Module, load_args: Any) -> None: @@ -1222,7 +1322,8 @@ def repro_run(options: Any, mod: nn.Module, load_args: Any) -> None: from torch.cuda import synchronize - compiled = compile_fx_inner(mod, args) + compile_args = _get_compile_args(mod, args) + compiled = compile_fx_inner(mod, compile_args) assert not isinstance(compiled, str) if options.accuracy != "": diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index cc3d0068e3152..fe540d8d1a5db 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -106,7 +106,7 @@ def _try_except_tf_mode_template(dummy: Any, stack_var_name: Any) -> None: global __import_torch_dot__dynamo_dot_utils try: dummy - except: # noqa: E722, B001 + except: __import_torch_dot__dynamo_dot_utils.set_torch_function_mode_stack( # type: ignore[name-defined] stack_var_name ) diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 17d9deb916208..efc9e12dde14f 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -24,6 +24,7 @@ import collections import contextlib import inspect +import logging import textwrap import traceback import warnings @@ -36,6 +37,7 @@ import torch.nn from torch._dynamo.variables.misc import AutogradFunctionContextVariable from torch.utils._ordered_set import OrderedSet +from torch.utils._pytree import is_structseq_class from . import config, graph_break_hints, utils, variables from .bytecode_transformation import ( @@ -47,7 +49,7 @@ from .codegen import PyCodegen from .exc import collapse_resume_frames, get_stack_above_dynamo, unimplemented from .source import AttrSource, GlobalSource, LocalCellSource, Source, TempLocalSource -from .utils import is_frozen_dataclass, nn_module_new, object_new +from .utils import is_frozen_dataclass, is_namedtuple_cls, nn_module_new, object_new from .variables.base import ( AttributeMutation, AttributeMutationExisting, @@ -429,9 +431,7 @@ def is_modified(self, item: VariableTracker) -> bool: if isinstance(item, variables.UserDefinedObjectVariable): # Checks if the underlying dict or tuple vt has been modified - return item in self.store_attr_mutations or item.is_underlying_vt_modified( - self - ) + return item in self.store_attr_mutations or item.is_base_vt_modified(self) if self.is_attribute_mutation(item): return item in self.store_attr_mutations @@ -517,12 +517,19 @@ def get_variable_cls(self, user_cls: type) -> type: variable_cls = GenericContextWrappingVariable elif issubclass(user_cls, torch.nn.Module): variable_cls = variables.UnspecializedNNModuleVariable - elif issubclass(user_cls, (dict, collections.OrderedDict)): + elif issubclass(user_cls, collections.defaultdict): + variable_cls = variables.DefaultDictVariable + elif issubclass(user_cls, collections.OrderedDict): + variable_cls = variables.OrderedDictVariable + elif issubclass(user_cls, dict): variable_cls = variables.UserDefinedDictVariable elif issubclass(user_cls, (set, frozenset)): variable_cls = variables.UserDefinedSetVariable elif issubclass(user_cls, tuple): - variable_cls = variables.UserDefinedTupleVariable + if is_namedtuple_cls(user_cls): + variable_cls = variables.UserDefinedTupleVariable.get_vt_cls(user_cls) + else: + variable_cls = variables.UserDefinedTupleVariable elif issubclass(user_cls, list): variable_cls = variables.UserDefinedListVariable elif issubclass(user_cls, MutableMapping): @@ -531,6 +538,11 @@ def get_variable_cls(self, user_cls: type) -> type: variable_cls = FrozenDataClassVariable elif issubclass(user_cls, BaseException): variable_cls = variables.UserDefinedExceptionObjectVariable + elif issubclass( + user_cls, + variables.user_defined._CONSTANT_BASE_TYPES, + ): + variable_cls = variables.UserDefinedConstantVariable elif variables.InspectVariable.is_matching_class(user_cls): variable_cls = variables.InspectVariable assert issubclass(variable_cls, variables.UserDefinedObjectVariable) @@ -549,6 +561,10 @@ def get_example_value( else: if isinstance(base_cls_vt, variables.BuiltinVariable): base_cls = base_cls_vt.fn + elif isinstance(base_cls_vt, variables.DictBuiltinVariable): + base_cls = dict + elif isinstance(base_cls_vt, variables.ListBuiltinVariable): + base_cls = list elif isinstance(base_cls_vt, variables.UserDefinedClassVariable): base_cls = base_cls_vt.value else: @@ -557,14 +573,28 @@ def get_example_value( assert variables.UserDefinedClassVariable.is_supported_new_method( base_cls.__new__ ) - # TODO(anijain2305) - Consider adding get_example_value method to - # each VT to get an example value for all args. As we expand the - # scope to other __new__ methods, we might need to call __new__ with - # init_args (like functools.partial) - # init_args = [arg.get_example_value() for arg in init_args] - # obj = base_cls.__new__(user_cls, *init_args) - - obj = base_cls.__new__(user_cls) + if is_structseq_class(user_cls): + # Structseq tp_new requires a sequence argument and rejects + # tuple.__new__, so create a dummy with None placeholders. + obj = user_cls([None] * user_cls.n_fields) + elif init_args and issubclass( + user_cls, + variables.user_defined._CONSTANT_BASE_TYPES, + ): + example_args = [arg.as_python_constant() for arg in init_args] + try: + obj = base_cls.__new__( # pyrefly: ignore[bad-specialization] + user_cls, *example_args + ) + except Exception: + # __new__ can raise (e.g., exceeding int str digit limits). + # Fall back to creating without args — the example value is + # only used for tracing, not for correctness. + obj = base_cls.__new__( # pyrefly: ignore[bad-specialization] + user_cls + ) + else: + obj = base_cls.__new__(user_cls) return obj def track_new_user_defined_object( @@ -738,11 +768,7 @@ def mutation(self, var: VariableTracker) -> None: var.mutation_type.is_modified = True if var.source is not None: self.mutated_sources.add(var.source) - if ( - var.source - and isinstance(var, variables.ConstDictVariable) - and not isinstance(var, variables.SetVariable) - ): + if var.source and isinstance(var, variables.ConstDictVariable): self._has_existing_dict_mutation = True def has_existing_dict_mutation(self) -> bool: @@ -763,6 +789,19 @@ def codegen_save_tempvars(self, cg: PyCodegen) -> None: assert var.source is not None continue + # Namedtuples/structseqs with no pending mutations should skip + # codegen_save_tempvars so that restore_stack handles them. In + # export, restore_stack uses value_from_source=False which makes + # child tensors become graph outputs. If we processed them here, + # add_cache would assign a TempLocalSource and restore_stack would + # load from cache with value_from_source=True, hiding the tensors + # from export. + if isinstance( + var, + (variables.NamedTupleVariable, variables.StructSequenceVariable), + ) and not self.has_pending_mutation(var): + continue + if isinstance(var, variables.CellVariable): # Cells created in the root frame are created either by # `MAKE_CELL` or by them being in `co_cellvars`, so we only emit @@ -835,6 +874,26 @@ def load_new_method() -> None: cg.add_cache(var) var.source = TempLocalSource(cg.tempvars[var]) + # For frozen dataclasses, we must emit object.__setattr__ + # immediately after __new__ — before any other code can + # access the object. The suffix-based codegen in + # codegen_update_mutated runs too late: if intervening code + # calls __repr__ (e.g. f-strings), the attributes won't be + # set yet. + if ( + isinstance(var, variables.FrozenDataClassVariable) + and var in self.store_attr_mutations + ): + for name, value in self.store_attr_mutations[var].items(): + cg.load_import_from("builtins", "object") + cg.load_method("__setattr__") + cg(var.source) + cg(variables.ConstantVariable(name)) + cg(value) + cg.extend_output( + [*create_call_method(3), create_instruction("POP_TOP")] + ) + for ctx, args in self.save_for_backward: cg(ctx.source) cg.load_method("save_for_backward") @@ -991,6 +1050,24 @@ def _format_side_effect_message(self, var: VariableTracker) -> str: return log_str + def _emit_side_effect_messages(self, side_effect_messages: list[str]) -> None: + if not side_effect_messages: + return + + for msg in side_effect_messages: + side_effects_log.debug(msg) + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_side_effects", + "encoding": "string", + }, + payload_fn=lambda: "\n\n========================================\n\n".join( + side_effect_messages + ), + ) + def codegen_update_mutated( self, cg: PyCodegen, log_side_effects: bool = False ) -> None: @@ -1001,8 +1078,6 @@ def _maybe_log_side_effect(var: VariableTracker) -> None: if config.side_effect_replay_policy != "silent" and log_side_effects: msg = self._format_side_effect_message(var) side_effect_messages.append(msg) - # Log individual side effects for granular debugging - side_effects_log.debug(msg) suffixes = [] for var in self._get_modified_vars(): @@ -1059,7 +1134,7 @@ def _maybe_log_side_effect(var: VariableTracker) -> None: ) _maybe_log_side_effect(var) - elif isinstance(var, variables.ConstDictVariable): + elif isinstance(var, (variables.ConstDictVariable, variables.SetVariable)): # Reconstruct works as follow: # (1) Skip codegen if there are no new items # (2) codegen(...) each pair of key/value @@ -1132,10 +1207,25 @@ def _maybe_log_side_effect(var: VariableTracker) -> None: _maybe_log_side_effect(var) elif self.is_attribute_mutation(var): - if isinstance( - var, - variables.UserDefinedDictVariable, - ) and self.is_modified(var._dict_vt): + # FrozenDataClassVariable attributes were emitted in + # codegen_save_tempvars right after __new__. Skip here to + # avoid double-emitting. + if isinstance(var.mutation_type, AttributeMutationNew) and isinstance( + var, variables.FrozenDataClassVariable + ): + continue + + if ( + isinstance( + var, + variables.UserDefinedDictVariable, + ) + and self.is_modified( + var._base_vt # pyrefly: ignore[bad-argument-type] + ) + and var._base_vt.has_new_items( # pyrefly: ignore[union-attr,missing-attribute] + ) + ): # Do dict related update manually here. The store_attr # mutations will be applied later. varname_map = {} @@ -1167,7 +1257,11 @@ def _maybe_log_side_effect(var: VariableTracker) -> None: ] ) - cg(var._dict_vt, allow_cache=False) # Don't codegen via source + # Reconstruct all items — _manual_dict_setitem clears + # dict_to first, so we need every key/value, not just + # the ones that differ from original_items. + var._base_vt.should_reconstruct_all = True # type: ignore[union-attr] + cg(var._base_vt, allow_cache=False) # Don't codegen via source cg.extend_output( [ create_instruction( @@ -1186,11 +1280,15 @@ def _maybe_log_side_effect(var: VariableTracker) -> None: create_instruction("POP_TOP"), ] ) - _maybe_log_side_effect(var._dict_vt) + _maybe_log_side_effect( + var._base_vt # pyrefly: ignore[bad-argument-type] + ) elif isinstance( var, variables.UserDefinedListVariable, - ) and self.is_modified(var._list_vt): + ) and self.is_modified( + var._base_vt # pyrefly: ignore[bad-argument-type] + ): # Update the list to the updated items. Be careful in # calling the list methods and not the overridden methods. varname_map = {} @@ -1206,7 +1304,7 @@ def _maybe_log_side_effect(var: VariableTracker) -> None: ] ) - cg(var._list_vt, allow_cache=False) # Don't codegen via source + cg(var._base_vt, allow_cache=False) # Don't codegen via source cg.extend_output( [ create_instruction( @@ -1225,7 +1323,9 @@ def _maybe_log_side_effect(var: VariableTracker) -> None: create_instruction("POP_TOP"), ] ) - _maybe_log_side_effect(var._list_vt) + _maybe_log_side_effect( + var._base_vt # pyrefly: ignore[bad-argument-type] + ) # Applying mutations involves two steps: 1) Push all # reconstructed objects onto the stack. 2) Call STORE_ATTR to @@ -1315,6 +1415,15 @@ def _maybe_log_side_effect(var: VariableTracker) -> None: cg.call_function(1, False) cg.pop_top() _maybe_log_side_effect(var) + elif isinstance(var, variables.CountIteratorVariable): + for _ in range(var.advance_count): + cg.add_push_null( + lambda: cg.load_import_from(utils.__name__, "iter_next") + ) + cg(var.source) # type: ignore[attr-defined] + cg.call_function(1, False) + cg.pop_top() + _maybe_log_side_effect(var) elif isinstance(var, variables.RandomVariable): # set correct random seed state def gen_fn() -> None: @@ -1340,17 +1449,16 @@ def gen_fn() -> None: # Send batched structured trace for all side effects in this compilation if log_side_effects and side_effect_messages: - combined_msg = "\n\n========================================\n\n".join( - side_effect_messages - ) - torch._logging.trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "dynamo_side_effects", - "encoding": "string", - }, - payload_fn=lambda: combined_msg, - ) + self._emit_side_effect_messages(side_effect_messages) + + def log_side_effects_summary(self) -> None: + if config.side_effect_replay_policy == "silent": + return + if not side_effects_log.isEnabledFor(logging.DEBUG): + return + for var in self._get_modified_vars(): + msg = self._format_side_effect_message(var) + side_effects_log.debug(msg) def is_empty(self) -> bool: return not ( diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 78492139bbe49..813f5c2466485 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -90,6 +90,7 @@ ) from .code_context import code_context from .codegen import PyCodegen +from .comprehension_graph_break import maybe_setup_comprehension_speculation from .exc import ( augment_exc_message_with_hop_name, BackendCompilerFailed, @@ -128,7 +129,6 @@ SkipGuardSource, Source, ) -from .synthetic_function_graph_break import maybe_setup_comprehension_speculation from .trace_rules import is_builtin_constant, is_forbidden from .utils import ( _get_error_on_graph_break, @@ -141,17 +141,17 @@ LazyString, proxy_args_kwargs, ) -from .variables.base import typestr, ValueMutationNew, VariableTracker +from .variables.base import SourceLocation, typestr, ValueMutationNew, VariableTracker from .variables.builder import FrameStateSizeEntry, VariableBuilder, wrap_fx_proxy -from .variables.builtin import BuiltinVariable -from .variables.constant import CONSTANT_VARIABLE_NONE, ConstantVariable +from .variables.builtin import BuiltinVariable, DictBuiltinVariable +from .variables.constant import ConstantVariable from .variables.ctx_manager import ( ContextWrappingVariable, GenericContextWrappingVariable, WithEnterFunctionVariable, WithExitFunctionVariable, ) -from .variables.dicts import ConstDictVariable, SetVariable +from .variables.dicts import ConstDictVariable from .variables.functions import ( BaseUserFunctionVariable, LocalGeneratorFunctionVariable, @@ -181,8 +181,9 @@ UnknownVariable, ) from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable +from .variables.sets import SetVariable from .variables.streams import SymbolicStreamState -from .variables.tensor import supported_comparison_ops, SymNodeVariable +from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable from .variables.torch_function import ( SymbolicTorchFunctionState, TorchFunctionModeVariable, @@ -284,6 +285,12 @@ class SpeculationLog: # If True, graph break at autograd.grad instead of tracing it. # Set when we detect that autograd.grad consumed grad_fns that are returned. graph_break_on_autograd_grad: bool = False + # Names of output tensors whose grad_fns were consumed by autograd.grad. + autograd_grad_leaked_tensors: list[str] = dataclasses.field(default_factory=list) + # If True, graph break at requires_grad_() on source-less intermediates + # instead of tracing it. Set when we detect that such an intermediate + # leaks as a graph output with requires_grad=True. + graph_break_on_requires_grad_: bool = False def restart(self) -> None: self.index = 0 @@ -582,19 +589,154 @@ def _detect_and_normalize_assert_statement( # 3. Any compile_subgraph call should be preceded immediately by a log in the form of "... triggered compile". +def get_node_source_info(n: torch.fx.Node) -> str: + """Extract the innermost user source location from an FX node's stack trace.""" + st = n.meta.get("stack_trace", "") or getattr(n, "stack_trace", "") + if not st: + return "" + trace_lines = st.strip().split("\n") + last_file_idx = -1 + for i, line in enumerate(trace_lines): + if line.strip().startswith("File "): + last_file_idx = i + if last_file_idx < 0: + return "" + file_line = trace_lines[last_file_idx].strip() + code = "" + if last_file_idx + 1 < len(trace_lines): + code = trace_lines[last_file_idx + 1].strip() + return f"{file_line}" + (f", code: {code}" if code else "") + + +def format_tensor_computation_trace(value: VariableTracker, max_lines: int = 20) -> str: + """Walk the FX graph backwards from a TensorVariable to show how it was + computed, with user source locations. Graph inputs (placeholders) are always + shown first, followed by operations in dataflow order ending at the branch + condition.""" + if not isinstance(value, TensorVariable): + return "" + try: + node = value.proxy.node + + # Collect operations and placeholders separately so placeholders (root + # cause) always appear first regardless of traversal order. + op_blocks: list[list[str]] = [] + placeholder_blocks: list[list[str]] = [] + visited: set[str] = set() + + def fmt_arg(a: object) -> str: + return a.name if isinstance(a, torch.fx.Node) else repr(a) + + def walk(n: torch.fx.Node) -> None: + if n.name in visited: + return + visited.add(n.name) + source = get_node_source_info(n) + + if n.op == "placeholder": + block = [] + if source: + block.append(f"# {source}") + block.append(f"{n.name}: graph input ({n.target})") + placeholder_blocks.append(block) + return + + if n.op == "call_function" and len(op_blocks) < max_lines: + target_name = getattr(n.target, "__name__", str(n.target)) + args_str = ", ".join(fmt_arg(a) for a in n.args) + block = [] + if source: + block.append(f"# {source}") + block.append(f"{n.name} = {target_name}({args_str})") + op_blocks.append(block) + + for a in n.args: + if isinstance(a, torch.fx.Node): + walk(a) + + walk(node) + + if not op_blocks and not placeholder_blocks: + return "" + + # Placeholders first (root cause), then ops in dataflow order (reversed + # from the DFS collection order) ending at the branch condition. + all_lines: list[str] = [] + for block in placeholder_blocks + list(reversed(op_blocks)): + all_lines.extend(block) + all_lines.append("") + + return ( + "\n\n The branch condition involves a tensor computed as follows:\n" + + "\n".join(f" {line}" for line in all_lines) + ) + except Exception: + log.debug("format_tensor_computation_trace failed", exc_info=True) + return "" + + def generic_jump( truth_fn: Callable[[object], bool], push: bool ) -> Callable[[InstructionTranslatorBase, Instruction], None]: def raise_jump_graph_break(value: VariableTracker) -> NoReturn: + trace_info = format_tensor_computation_trace(value) + hints: list[str] = [] + if isinstance(value, TensorVariable): + try: + node = value.proxy.node + example = node.meta.get("example_value") + if ( + example is not None + and example.dim() == 0 + and example.dtype + in ( + torch.int32, + torch.int64, + ) + ): + hints.append( + "The branch condition uses a scalar integer tensor. " + "Consider rewriting the computation to use plain Python " + "ints (e.g. use int attributes instead of tensor buffers) " + "so the condition becomes a shape guard instead of " + "data-dependent branching." + ) + if ( + example is not None + and example.dim() == 0 + and example.dtype == torch.bool + ): + hints.append( + "For the common pattern `if tensor_cond: x = transform(x)` " + "(e.g. clamping inf/nan values), consider making the code " + "branchless by always applying the transform. Operations like " + "torch.clamp, torch.nan_to_num, and torch.where are typically " + "no-ops on well-behaved inputs and compile without graph breaks." + ) + # Detect boolean reductions (any/all) which are a telltale sign + # of `tensor.any() or other_tensor.any()` patterns. + # node.target is a str for call_method nodes (e.g. tensor.any()) + # and a callable for call_function nodes (e.g. torch.any()). + target_name = getattr(node.target, "__name__", None) or ( + node.target if isinstance(node.target, str) else None + ) + if target_name in ("any", "all", "bitwise_and", "bitwise_or"): + hints.append( + "Note: Python `or`/`and` between tensor expressions (e.g. " + "`tensor.any() or other_tensor.any()`) triggers implicit bool " + "conversion. Use `torch.logical_or`/`torch.logical_and` or the " + "`|`/`&` operators instead." + ) + except Exception: + pass + hints.extend(graph_break_hints.FUNDAMENTAL) + hints.append("Use `torch.cond` to express dynamic control flow.") unimplemented( gb_type="Data-dependent branching", context=f"attempted to jump with {value}", explanation="Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). " - "Dynamo does not support tracing dynamic control flow.", - hints=[ - *graph_break_hints.FUNDAMENTAL, - "Use `torch.cond` to express dynamic control flow.", - ], + "Dynamo does not support tracing dynamic control flow." + trace_info, + hints=hints, ) def jump_graph_break( @@ -668,6 +810,7 @@ def inner(self: InstructionTranslatorBase, inst: Instruction) -> None: return self.jump(inst) elif self.should_compile_partial_graph(): jump_graph_break(self, inst, value) + return else: unimplemented( gb_type="Data-dependent assertion failed (cannot compile partial graph)", @@ -767,17 +910,19 @@ def inner(self: InstructionTranslatorBase, inst: Instruction) -> None: x = None # __bool__ or __len__ is function - if isinstance(x, UserMethodVariable): + if isinstance(x, (GetAttrVariable, UserMethodVariable)): result = x.call_function(self, [], {}) # type: ignore[arg-type, assignment] method_name = getattr(getattr(x, "fn", None), "__name__", None) if result.is_python_constant(): result_value = result.as_python_constant() if method_name == "__bool__" and not isinstance(result_value, bool): - msg = VariableTracker.build( + exc.raise_observed_exception( + TypeError, self, - f"__bool__ should return bool, returned {type(result_value).__name__}", + args=[ + f"__bool__ should return bool, returned {type(result_value).__name__}" + ], ) - exc.raise_observed_exception(TypeError, self, args=[msg]) if isinstance(result_value, (bool, int)) and truth_fn(result_value): if push: self.push(value) @@ -1079,7 +1224,8 @@ def _set_context_recursive( if len(self._exc_stack) + prev_idx > 0: prev = self._exc_stack[prev_idx] self._set_context_recursive(prev, prev_idx - 1) - val.set_context(prev) # type: ignore[union-attr, arg-type] + if prev is not val: + val.set_context(prev) # type: ignore[union-attr, arg-type] return val def _break_context_reference_cycle(self, val: ExceptionVals) -> None: @@ -1095,7 +1241,7 @@ def _break_context_reference_cycle(self, val: ExceptionVals) -> None: break if context is val: - o.set_context(CONSTANT_VARIABLE_NONE) # type: ignore[union-attr, arg-type] + o.set_context(ConstantVariable.create(None)) # type: ignore[union-attr, arg-type] break o = context # type: ignore[assignment] @@ -1604,9 +1750,8 @@ def step_graph_break(self, continue_inst: Instruction) -> None: *create_copy(2), cg.create_load_const(0), cg.create_binary_subscr(), - create_dup_top(), *create_binary_slice(num_stack, None), - *create_swap(2), + *create_copy(3), cg.create_load_const(0), create_instruction("STORE_SUBSCR"), ] @@ -1728,6 +1873,9 @@ def run(self) -> None: exc=e, ) + if not isinstance(e, exc.RestartAnalysis): + self.output.side_effects.log_side_effects_summary() + if hasattr(e, "msg") and "Data-dependent" in e.msg: readable_graph = torch.fx.GraphModule( self.output.nn_modules, self.output.graph @@ -1741,6 +1889,8 @@ def run(self) -> None: except Exception as e: if self.exec_recorder: e.exec_record = self.exec_recorder.get_record() # type: ignore[attr-defined] + if not isinstance(e, exc.RestartAnalysis): + self.output.side_effects.log_side_effects_summary() raise finally: @@ -1762,6 +1912,25 @@ def push(self, val: VariableTracker) -> None: assert isinstance(val, VariableTracker), ( f"push expects VariableTracker, got {typestr(val)}" ) + if val.source_location is None: + inst = self.current_instruction + if inst.positions is not None and inst.positions.lineno is not None: + val.set_source_location( + SourceLocation( + filename=self.f_code.co_filename, + lineno=inst.positions.lineno, + end_lineno=inst.positions.end_lineno, + col_offset=inst.positions.col_offset, + end_col_offset=inst.positions.end_col_offset, + ) + ) + elif inst.starts_line is not None: + val.set_source_location( + SourceLocation( + filename=self.f_code.co_filename, + lineno=inst.starts_line, + ) + ) self.stack.append(val) def push_many(self, vals: list[VariableTracker]) -> None: @@ -1829,8 +1998,54 @@ def STORE_FAST(self, inst: Instruction) -> None: self.is_tracing_resume_prologue = val def DELETE_FAST(self, inst: Instruction) -> None: + var = self.symbolic_locals.get(inst.argval) + if isinstance(var, TensorVariable): + self._maybe_emit_sync_dealloc(var) del self.symbolic_locals[inst.argval] + def _maybe_emit_sync_dealloc(self, var: TensorVariable) -> None: + from .variables.streams import get_current_stream, new_event + + device = var.device + if device is None or device.type not in ("cuda", "xpu"): + return + + node = var.proxy.node + alloc_stream = node.meta.get("custom", {}).get("stream", None) + + users = set(node.users.keys()) + if not users: + return + last_user = None + for n in self.output.graph.nodes: + if n in users: + last_user = n + if last_user is None or last_user.op == "output": + return + + last_use_stream = last_user.meta.get("custom", {}).get("stream", None) + if alloc_stream == last_use_stream: + return + + if alloc_stream is None: + alloc_stream = get_current_stream(device) + if last_use_stream is None: + last_use_stream = get_current_stream(device) + + event_idx = new_event() + self.output.create_proxy( + "call_function", + torch.ops.streams.record_event, + (event_idx, last_use_stream), + {}, + ) + self.output.create_proxy( + "call_function", + torch.ops.streams.sync_dealloc, + (event_idx, alloc_stream, var.as_proxy()), + {}, + ) + def STORE_DEREF(self, inst: Instruction) -> None: # type: ignore[override] assert inst.argval in self.cell_and_freevars() cell = self.symbolic_locals[inst.argval] @@ -2026,7 +2241,7 @@ def IMPORT_NAME(self, inst: Instruction) -> None: self.exec_recorder.add_local_mod(recorded_name, value) # pyrefly: ignore [unbound-name] - if istype(value, (types.ModuleType, DummyModule)): + if isinstance(value, (types.ModuleType, DummyModule)): # pyrefly: ignore [unbound-name, bad-argument-type] self.push(PythonModuleVariable(value, source=source)) else: @@ -2126,7 +2341,7 @@ def FOR_ITER(self, inst: Instruction) -> None: # and performs the action of END_FOR as part of FOR_ITER. We jump # to the END_FOR and run it, so we need to make sure 2 values are # on the stack for it to pop. - self.push(CONSTANT_VARIABLE_NONE) + self.push(ConstantVariable.create(None)) else: # pop the iterator in Python < 3.12 self.pop() @@ -2198,25 +2413,23 @@ def _raise_exception_variable(self, val: VariableTracker) -> NoReturn: # Pass the stored python_stack to preserve the original exception location python_stack = getattr(val, "python_stack", None) raise observed_exception_type( - f"raised exception {val}", real_stack=python_stack + f"raised exception {val.debug_repr()}", real_stack=python_stack ) exc.raise_observed_exception( TypeError, self, args=[ - VariableTracker.build( - self, - f"exceptions must derive from BaseException, not {val.python_type_name()}", - ) + f"exceptions must derive from BaseException, not {val.python_type_name()}", ], ) def RAISE_VARARGS(self, inst: Instruction) -> None: if inst.arg == 0: if not len(self.exn_vt_stack): - msg = VariableTracker.build(self, "No active exception to reraise") - exc.raise_observed_exception(RuntimeError, self, args=[msg]) + exc.raise_observed_exception( + RuntimeError, self, args=["No active exception to reraise"] + ) # re-raise the previous exception. Here CPython refers to the exception # on top of the exception stack @@ -2342,7 +2555,7 @@ def bubble_exception_to_interpreter() -> None: assert isinstance(raised_exception, dynamo_exc) # sanity check unimplemented( gb_type="Observed exception", - context=f"raised exception {curr_exc.python_type_name()}({curr_exc.args})", # type: ignore[union-attr] + context=f"raised exception {curr_exc.debug_repr()}", explanation=observed_exn_gb_explanation, hints=[ *graph_break_hints.USER_ERROR, @@ -2439,9 +2652,9 @@ def bubble_exception_to_interpreter() -> None: self.push(variables.BuiltinVariable(old_exception.exc_type)) else: # Push empty exception tb, value, type - self.push(variables.CONSTANT_VARIABLE_NONE) - self.push(variables.CONSTANT_VARIABLE_NONE) - self.push(variables.CONSTANT_VARIABLE_NONE) + self.push(ConstantVariable.create(None)) + self.push(ConstantVariable.create(None)) + self.push(ConstantVariable.create(None)) # Push new exception - tb, val, type # Traceback is currently mapped to UnknownVariable @@ -2482,7 +2695,7 @@ def PUSH_EXC_INFO(self, inst: Instruction) -> None: val = self.pop() if len(self.exn_vt_stack) == 0: - prev_exc: VariableTracker = CONSTANT_VARIABLE_NONE + prev_exc: VariableTracker = ConstantVariable.create(None) else: prev_exc = self.exn_vt_stack[-1] self.push(prev_exc) @@ -2660,7 +2873,7 @@ def CALL_FUNCTION_EX(self, inst: Instruction) -> None: # Unpack for cases like fn(**obj) where obj is a map if isinstance(kwargsvars, UserDefinedObjectVariable): - kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars) # type: ignore[arg-type] + kwargsvars = DictBuiltinVariable.call_custom_dict(self, dict, kwargsvars) # type: ignore[arg-type] # pyrefly: ignore [unbound-name] if not isinstance(argsvars, BaseListVariable) or not isinstance( @@ -2766,12 +2979,22 @@ def STORE_ATTR(self, inst: Instruction) -> None: def DELETE_ATTR(self, inst: Instruction) -> None: obj = self.pop() + self._maybe_sync_dealloc_attr(obj, inst.argval) VariableTracker.build(self, delattr).call_function( self, # type: ignore[arg-type] [obj, VariableTracker.build(self, inst.argval)], {}, ) + def _maybe_sync_dealloc_attr(self, obj: VariableTracker, name: str) -> None: + # Only check side_effects — a pure dict lookup with no observable + # side effects. We intentionally avoid var_getattr here because it + # can trigger __getattr__, add graph nodes, or cause graph breaks. + if self.output.side_effects.has_pending_mutation_of_attr(obj, name): + attr_var = self.output.side_effects.load_attr(obj, name) + if isinstance(attr_var, TensorVariable): + self._maybe_emit_sync_dealloc(attr_var) + @staticmethod def codegen_return_with_pops( inst: Instruction, num_stack: int @@ -2978,21 +3201,12 @@ def create_resume( ) # add resume function to the global scope - if new_code.co_freevars: - # expose code object for debugging purposes - self.output.install_global_unsafe(resume_name, new_code) - package_name = None - else: - # This is safe: we pre-generate a unique name - self.output.install_global_unsafe( - resume_name, - types.FunctionType(new_code, self.f_globals, resume_name), - ) - package_name = resume_name - + self.output.install_resume_function_global( + resume_name, new_code, self.f_globals + ) if self.package is not None: self.package.add_resume_function( - new_code, self.f_globals["__name__"], package_name + new_code, self.f_globals["__name__"], resume_name ) counters["resumes"][new_code.co_name] += 1 @@ -3207,7 +3421,15 @@ def codegen_call_resume( cg.create_binary_subscr(), ] ) - cg.make_function_with_closure(name, code) + # Call the factory function (stored under resume_name) to + # create the resume function with correct globals and closure. + cg.extend_output( + [ + cg.create_load_global(name, add=True), + *create_swap(2), + *create_call_function(1, push_null=True), + ] + ) else: cg.extend_output(cg.load_function_name(name, False, 0)) cg.extend_output(create_swap(2)) @@ -3228,7 +3450,15 @@ def codegen_call_resume( cg.create_binary_subscr(), ] ) - cg.make_function_with_closure(resume_names[-1], resume_codes[-1]) + # Call the factory function to create the resume function with + # correct globals and closure. + cg.extend_output( + [ + cg.create_load_global(resume_names[-1], add=True), + *create_swap(2), + *create_call_function(1, push_null=True), + ] + ) cg.extend_output( [ *create_rot_n(3), @@ -3313,8 +3543,32 @@ def STORE_SUBSCR(self, inst: Instruction) -> None: def DELETE_SUBSCR(self, inst: Instruction) -> None: obj, key = self.popn(2) + # Check for tensor items using side-effect-free internal lookups + # only. We avoid call_method("__getitem__") because it can execute + # user code and add unwanted graph nodes. + self._maybe_sync_dealloc_subscr(obj, key) obj.call_method(self, "__delitem__", [key], {}) + def _maybe_sync_dealloc_subscr( + self, obj: VariableTracker, key: VariableTracker + ) -> None: + from .variables.dicts import ConstDictVariable + from .variables.lists import BaseListVariable + + item_var = None + try: + if isinstance(obj, BaseListVariable): + item_var = obj.getitem_const( + self, # pyrefly: ignore [bad-argument-type] + key, + ) + elif isinstance(obj, ConstDictVariable): + item_var = obj.maybe_getitem_const(key) + except Exception: + pass + if isinstance(item_var, TensorVariable): + self._maybe_emit_sync_dealloc(item_var) + def BUILD_TUPLE(self, inst: Instruction) -> None: items = self.popn(inst.argval) self.push(TupleVariable(items)) @@ -3853,6 +4107,7 @@ def LOAD_BUILD_CLASS(self, inst: Instruction) -> None: BINARY_REMAINDER = stack_op(operator.mod) BINARY_ADD = stack_op(operator.add) BINARY_SUBTRACT = stack_op(operator.sub) + BINARY_SUBSCR = break_graph_if_unsupported( push=True, msg_prefix="Encountered graph break when attempting to trace BINARY_SUBSCR: a binary subscript, e.g. x[attr]", @@ -4226,7 +4481,8 @@ def SET_FUNCTION_ATTRIBUTE(self, inst: Instruction) -> None: self.push(fn) def CONVERT_VALUE(self, inst: Instruction) -> None: - self.push(self._convert_value(self.pop(), inst.argval)) + assert inst.arg is not None + self.push(self._convert_value(self.pop(), inst.arg)) def FORMAT_SIMPLE(self, inst: Instruction) -> None: self._format_value(VariableTracker.build(self, ""), 0) @@ -4391,6 +4647,25 @@ def _get_frame_loc_chain( frame_loc_chain_list.append(frame_loc) return tuple(frame_loc_chain_list) + def _format_stack_source_attribution(self) -> str: + """Format bytecode source locations for stack values involved in a graph break.""" + seen: set[SourceLocation] = set() + parts: list[str] = [] + for vt in self.stack: + source_location = vt.source_location + if source_location is None: + continue + if source_location in seen: + continue + seen.add(source_location) + parts.append( + f" {vt!r} originated from:\n{source_location.format().rstrip()}" + ) + + if not parts: + return "" + return "Stack variable source attribution:\n" + "\n".join(parts) + def log_graph_break( self, code_options: dict[str, Any], @@ -4440,11 +4715,15 @@ def log_graph_break( if exc is not None: reason = augment_exc_message_with_hop_name(exc, reason) - user_stack_trace = ( - f"Graph break in user code at {frame_loc[0]}:{frame_loc[1]}\n" - f"Graph Break Reason: {reason}\n" - "\nUser code traceback:\n" - ) + stack_source_attribution = self._format_stack_source_attribution() + user_stack_trace_parts = [ + f"Graph break in user code at {frame_loc[0]}:{frame_loc[1]}", + f"Graph Break Reason: {reason}", + ] + if stack_source_attribution: + user_stack_trace_parts.extend(["", stack_source_attribution]) + user_stack_trace_parts.extend(["", "User code traceback:"]) + user_stack_trace = "\n".join(user_stack_trace_parts) + "\n" if config.verbose: user_stack_trace += ( @@ -4900,7 +5179,6 @@ def _return(self, inst: Instruction) -> None: and not self.error_on_graph_break and not self.is_tracing_resume_prologue ): - # TODO graph break if one_graph is set - this might break things raise exc.SkipFrame( "No ops traced for the FX graph. `torch.compile` will skip the frame and fall back to eager.\n" f"Frame info: {format_frame_info(self.f_code)}" @@ -5272,7 +5550,7 @@ def inline_call_(self) -> VariableTracker: if self.output.should_exit: # graph break - return CONSTANT_VARIABLE_NONE # return dummy variable + return ConstantVariable.create(None) # return dummy variable assert self.symbolic_result is not None @@ -5531,7 +5809,7 @@ def YIELD_VALUE(self, inst: Instruction) -> None: self.generated_items.append(top) if len(self.generated_items) > MAX_ITERATOR_LIMIT: raise exc.InfiniteGeneratorError - self.push(CONSTANT_VARIABLE_NONE) + self.push(ConstantVariable.create(None)) if ( config.enable_faithful_generator_behavior or self.is_generator_from_ctx_manager diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py index c5a9ae15e8e42..f8047cc1b043e 100644 --- a/torch/_dynamo/test_case.py +++ b/torch/_dynamo/test_case.py @@ -60,6 +60,7 @@ def run_tests(needs: str | tuple[str, ...] = ()) -> None: importlib.import_module(need) except ImportError: return + run_tests() @@ -85,6 +86,8 @@ def setUpClass(cls) -> None: def setUp(self) -> None: self._prior_is_grad_enabled = torch.is_grad_enabled() + self._prior_nested_graph_breaks = config.nested_graph_breaks + config.nested_graph_breaks = True super().setUp() reset() utils.counters.clear() @@ -102,6 +105,7 @@ def tearDown(self) -> None: if self._prior_is_grad_enabled is not torch.is_grad_enabled(): log.warning("Running test changed grad mode") torch.set_grad_enabled(self._prior_is_grad_enabled) + config.nested_graph_breaks = self._prior_nested_graph_breaks def assertEqual(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override] if ( @@ -116,21 +120,6 @@ def assertEqual(self, x: Any, y: Any, *args: Any, **kwargs: Any) -> None: # typ # graph break tests -# NB: multiple inheritance with LoggingTestCase is possible - this should be fine -# since there is no overlap in overridden methods. -class TestCaseWithNestedGraphBreaks(TestCase): - def setUp(self) -> None: - super().setUp() - self.prev_nested_graph_breaks = torch._dynamo.config.nested_graph_breaks - # pyrefly: ignore [bad-assignment] - torch._dynamo.config.nested_graph_breaks = True - - def tearDown(self) -> None: - super().tearDown() - # pyrefly: ignore [bad-assignment] - torch._dynamo.config.nested_graph_breaks = self.prev_nested_graph_breaks - - class CPythonTestCase(TestCase): """ Test class for CPython tests located in "test/dynamo/CPython/Py_version/*". diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 06294862762d5..22badc5814b6e 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -604,6 +604,7 @@ def _testing_capture_invoke_subgraph_inductor_compile_gms() -> Generator[ # captured_gms will contain the list of captured graph modules """ global _testing_invoke_subgraph_inductor_compile_captured_gms + # pyrefly: ignore [implicit-any] _testing_invoke_subgraph_inductor_compile_captured_gms = [] try: yield _testing_invoke_subgraph_inductor_compile_captured_gms diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 54e499e45cddd..0633f70c874f9 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -34,11 +34,12 @@ import re import sys import types +import typing import unittest from collections import defaultdict from collections.abc import Callable, Iterator from pathlib import Path -from typing import Any, cast +from typing import Any import torch import torch._inductor.test_operators @@ -58,9 +59,13 @@ ) from .variables import ( BuiltinVariable, + DictBuiltinVariable, FunctionalCallVariable, FunctorchHigherOrderVariable, + GetAttrBuiltinVariable, InspectSignatureVariable, + IterBuiltinVariable, + ListBuiltinVariable, LocalGeneratorFunctionVariable, LocalGeneratorObjectVariable, NestedUserFunctionVariable, @@ -188,18 +193,16 @@ "torch._C._group_tensors_by_device_and_dtype": TorchInGraphFunctionVariable, "torch.to_dlpack": SkipFunctionVariable, "torch._check": TorchInGraphFunctionVariable, + "torch._dynamo.decorators.override_optimization_hint": TorchInGraphFunctionVariable, # We graph break on RNG state setters or getters like - # `torch.get_rng_state` or `torch.set_rng_state`. These functions - # are not aten operations and therefore they are completely ignored - # by the AOT dispatcher. As a result, the AOT graph does not have - # these setter or getter functions, producing an incorrect graph - # when it comes to rng states. - "torch.default_generator#get_state": SkipFunctionVariable, - "torch._C.Generator#get_state": SkipFunctionVariable, + # `torch.get_rng_state`, `torch.set_rng_state`, and + # `torch.Generator.manual_seed`. These functions are not aten + # operations and therefore they are completely ignored by the AOT + # dispatcher. As a result, the AOT graph does not have these setter + # or getter functions, producing an incorrect graph when it comes + # to rng states. "torch.get_rng_state": SkipFunctionVariable, "torch.cuda.get_rng_state": SkipFunctionVariable, - "torch.default_generator#set_state": SkipFunctionVariable, - "torch._C.Generator#set_state": SkipFunctionVariable, "torch.set_rng_state": SkipFunctionVariable, "torch.cuda.set_rng_state": SkipFunctionVariable, # https://github.com/pytorch/pytorch/issues/107187 @@ -235,18 +238,13 @@ "torch.Tensor#split": TorchInGraphFunctionVariable, "torch.cuda.set_device": SkipFunctionVariable, "torch.cuda.current_device": TorchInGraphFunctionVariable, - "torch._C.autocast_decrement_nesting": SkipFunctionVariable, - "torch._C.autocast_increment_nesting": SkipFunctionVariable, "torch.autograd.grad": TorchInGraphFunctionVariable, "torch.autograd.backward": SkipFunctionVariable, - "torch._C.clear_autocast_cache": SkipFunctionVariable, "torch.distributions.constraints.is_dependent": SkipFunctionVariable, "torch.jit.isinstance": SkipFunctionVariable, "torch._C.set_anomaly_enabled": SkipFunctionVariable, - "torch._C.set_autocast_cache_enabled": SkipFunctionVariable, "torch._C.set_autocast_cpu_dtype": SkipFunctionVariable, "torch._C.set_autocast_cpu_enabled": SkipFunctionVariable, - "torch._C.set_autocast_enabled": SkipFunctionVariable, "torch._C.set_autocast_gpu_dtype": SkipFunctionVariable, "torch._C.set_autocast_ipu_dtype": SkipFunctionVariable, "torch._C.set_autocast_ipu_enabled": SkipFunctionVariable, @@ -345,6 +343,13 @@ "torch._C._functorch.unwrap_if_dead": TorchInGraphFunctionVariable, "torch._functorch.predispatch._vmap_increment_nesting": TorchInGraphFunctionVariable, "torch._functorch.predispatch._vmap_decrement_nesting": TorchInGraphFunctionVariable, + "torch._functorch.predispatch._jvp_increment_nesting": TorchInGraphFunctionVariable, + "torch._functorch.predispatch._jvp_decrement_nesting": TorchInGraphFunctionVariable, + "torch._functorch.predispatch._unwrap_for_grad": TorchInGraphFunctionVariable, + "torch._functorch.predispatch._make_dual": TorchInGraphFunctionVariable, + "torch._functorch.predispatch._unpack_dual": TorchInGraphFunctionVariable, + "torch._functorch.predispatch._enter_dual_level": TorchInGraphFunctionVariable, + "torch._functorch.predispatch._exit_dual_level": TorchInGraphFunctionVariable, # everything else "torch._functorch.pyfunctorch.coerce_cinterpreter": TorchInGraphFunctionVariable, "torch._higher_order_ops.triton_kernel_wrap.do_prune_configs": UserFunctionVariable, @@ -362,7 +367,8 @@ "torch._dynamo.override_cudagraphs": UserFunctionVariable, "torch.fx.experimental.symbolic_shapes.guard_size_oblivious": TorchInGraphFunctionVariable, "torch.fx.experimental.symbolic_shapes.is_nested_int": UserFunctionVariable, - "torch.fx.experimental.symbolic_shapes.size_hint": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.guarding_hint_or_throw": TorchInGraphFunctionVariable, + "torch.fx.experimental.symbolic_shapes.optimization_hint": TorchInGraphFunctionVariable, "torch.fx.experimental.symbolic_shapes.guard_or_true": TorchInGraphFunctionVariable, "torch.fx.experimental.symbolic_shapes.guard_or_false": TorchInGraphFunctionVariable, "torch.fx.experimental.symbolic_shapes.statically_known_true": TorchInGraphFunctionVariable, @@ -397,6 +403,27 @@ "inspect.signature": InspectSignatureVariable, } +# Keep this in sync with the stateful generator methods exposed from +# torch/csrc/Generator.cpp. +_GENERATOR_METHODS_THAT_GRAPH_BREAK = ( + "clone_state", + "get_offset", + "get_state", + "graphsafe_get_state", + "graphsafe_set_state", + "initial_seed", + "manual_seed", + "seed", + "set_offset", + "set_state", +) + +for generator_prefix in ("torch.default_generator", "torch._C.Generator"): + for method_name in _GENERATOR_METHODS_THAT_GRAPH_BREAK: + manual_torch_name_rule_map[f"{generator_prefix}#{method_name}"] = ( + SkipFunctionVariable + ) + # In graph functions (including constant folding) that are C bindings torch_c_binding_in_graph_functions = dict.fromkeys( @@ -477,6 +504,8 @@ "torch._C._add_docstr", "torch._C._are_functorch_transforms_active", "torch._C._autograd_init", + "torch._C._autograd._saved_tensors_hooks_disable", + "torch._C._autograd._saved_tensors_hooks_enable", "torch._C._awaitable_nowait", "torch._C._awaitable_wait", "torch._C._awaitable", @@ -548,7 +577,6 @@ "torch._C._cuda_isHistoryEnabled", "torch._C._cuda_isInBadFork", "torch._C._cuda_jiterator_compile_and_launch_kernel", - "torch._C._cuda_lock_mutex", "torch._C._cuda_maybeExchangeDevice", "torch._C._cuda_memorySnapshot", "torch._C._cuda_memoryStats", @@ -567,7 +595,6 @@ "torch._C._cuda_setStream", "torch._C._cuda_sleep", "torch._C._cuda_synchronize", - "torch._C._cuda_unlock_mutex", "torch._C._cudnn_set_conv_benchmark_empty_cache", "torch._C._cudnn.getCompileVersion", "torch._C._cudnn.getRuntimeVersion", @@ -657,6 +684,9 @@ "torch._C._from_dlpack", "torch._C._functionality_to_backend_keys", "torch._C._functionalization_reapply_views_tls", + "torch._C._functorch._grad_decrement_nesting", + "torch._C._functorch._grad_increment_nesting", + "torch._C._functorch.set_inplace_requires_grad_allowed", "torch._C._fuse_to_static_module", "torch._C._gather_out", "torch._C._gather", @@ -743,6 +773,7 @@ "torch._C._initExtension", "torch._C._is_alias_of", "torch._C._is_any_autocast_enabled", + "torch._C._is_autocast_available", "torch._C._is_cached_tensor", "torch._C._is_flash_attention_available", "torch._C._is_fwd_grad_enabled", @@ -1080,6 +1111,7 @@ "torch._C._nn._test_warn_in_autograd", "torch._C._nn._upsample_bicubic2d_aa", "torch._C._nn._upsample_bilinear2d_aa", + "torch._C._nn._upsample_lanczos2d_aa", "torch._C._nn._upsample_nearest_exact1d", "torch._C._nn._upsample_nearest_exact2d", "torch._C._nn._upsample_nearest_exact3d", @@ -1397,6 +1429,9 @@ "torch._C._xpu_resetPeakMemoryStats", "torch._C._xpu_setStream", "torch._C._xpu_synchronize", + "torch._C.autocast_decrement_nesting", + "torch._C.autocast_increment_nesting", + "torch._C.clear_autocast_cache", "torch._C.fork", "torch._C.get_autocast_cpu_dtype", "torch._C.get_autocast_dtype", @@ -1422,10 +1457,9 @@ "torch._C.parse_ir", "torch._C.parse_schema", "torch._C.parse_type_comment", - "torch._C.read_vitals", - "torch._C.set_vital", + "torch._C.set_autocast_cache_enabled", + "torch._C.set_autocast_enabled", "torch._C.unify_type_list", - "torch._C.vitals_enabled", "torch._C.wait", "torch._cast_Byte", "torch._cast_Char", @@ -1504,6 +1538,7 @@ "torch._foreach_clamp_max", "torch._foreach_clamp_min_", "torch._foreach_clamp_min", + "torch._foreach_clone", "torch._foreach_copy_", "torch._foreach_cos_", "torch._foreach_cos", @@ -1586,6 +1621,7 @@ "torch._functionalize_replace", "torch._functionalize_sync", "torch._functionalize_was_storage_changed", + "torch._fused_adagrad_", "torch._fused_adam_", "torch._fused_adamw_", "torch._fused_dropout", @@ -2221,6 +2257,7 @@ "torch.select", "torch.selu_", "torch.selu", + "torch.set_autocast_dtype", "torch.sgn", "torch.sigmoid_", "torch.sigmoid", @@ -2387,6 +2424,13 @@ "torch._functorch.predispatch._vmap_decrement_nesting", "torch._functorch.predispatch._add_batch_dim", "torch._functorch.predispatch._remove_batch_dim", + "torch._functorch.predispatch._jvp_increment_nesting", + "torch._functorch.predispatch._jvp_decrement_nesting", + "torch._functorch.predispatch._unwrap_for_grad", + "torch._functorch.predispatch._make_dual", + "torch._functorch.predispatch._unpack_dual", + "torch._functorch.predispatch._enter_dual_level", + "torch._functorch.predispatch._exit_dual_level", "torch._guards.compile_context", "torch._guards.detect_fake_mode", "torch._guards.tracing", @@ -2444,8 +2488,6 @@ "torch.accelerator.set_stream", "torch.accelerator.synchronize", "torch.align_tensors", - "torch.amp.autocast_mode._enter_autocast", - "torch.amp.autocast_mode._exit_autocast", "torch.amp.autocast_mode.autocast_decorator", "torch.amp.autocast_mode.custom_bwd", "torch.amp.autocast_mode.custom_fwd", @@ -2619,6 +2661,7 @@ "torch.cuda.get_device_properties", "torch.cuda.get_gencode_flags", "torch.cuda.get_sync_debug_mode", + "torch.cuda.synchronize", "torch.cuda.graphs.graph_pool_handle", "torch.cuda.graphs.is_current_stream_capturing", "torch.cuda.graphs.make_graphed_callables", @@ -2631,7 +2674,6 @@ "torch.cuda.jiterator._create_multi_output_jit_fn", "torch.cuda.memory_usage", "torch.cuda.memory._dump_snapshot", - "torch.cuda.memory._free_mutex", "torch.cuda.memory._get_current_allocator", "torch.cuda.memory._host_allocator", "torch.cuda.memory._record_memory_history_impl", @@ -3174,11 +3216,6 @@ def _builtin_function_ids() -> dict[int, str]: if not k.startswith("_") and callable(v) } ) - rv.update( - { - id(cast): "typing.cast", - } - ) return rv @@ -3813,6 +3850,7 @@ def f3(x, y): _force_inline_flag = False +# pyrefly: ignore [deprecated] @contextlib.contextmanager def _force_inline() -> Iterator[None]: """ @@ -3899,6 +3937,12 @@ def check_verbose( assert filename is not None fi = FunctionInfo(obj, None, filename, None) + # typing.cast is a polyfilled no-op, but unlike C builtins it has a code + # object that PEP 523 can intercept as a standalone frame after a graph + # break. Skip it at the top level to avoid installing unnecessary guards. + if fi.code is not None and fi.code is typing.cast.__code__: + return SkipResult(True, "typing.cast is a no-op, skip at top level") + # Consulte the central trace rules defined in torch._dynamo.trace_rules. reasons: set[str] = set() rule = lookup_inner(fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons) @@ -3970,6 +4014,13 @@ def is_torch(filename: str) -> bool: Main entry point for looking up the trace rule (the Dynamo variable) for a given callable object. """ +BUILTIN_CALLABLES = { + dict: DictBuiltinVariable, + getattr: GetAttrBuiltinVariable, + iter: IterBuiltinVariable, + list: ListBuiltinVariable, +} + def lookup_callable(obj: Callable[..., Any]) -> type[VariableTracker] | None: if not hashable(obj): @@ -3981,6 +4032,8 @@ def lookup_callable(obj: Callable[..., Any]) -> type[VariableTracker] | None: return TorchInGraphFunctionVariable if is_polyfilled_callable(obj): return PolyfilledFunctionVariable + if obj in BUILTIN_CALLABLES: + return BUILTIN_CALLABLES[obj] if is_builtin_callable(obj): return BuiltinVariable return None diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index f78de310497fa..67cc71ba7952c 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -1740,6 +1740,7 @@ def clean_for_json(d: dict[str, Any]) -> dict[str, Any]: "_autograd_backward_strict_mode_banned_ops", "reorderable_logging_functions", "ignore_logger_methods", + "ignore_logging_functions", "traceable_tensor_subclasses", "nontraceable_tensor_subclasses", "_custom_ops_profile", @@ -1864,7 +1865,7 @@ def record_compilation_metrics( ), "dynamo_config": _get_dynamo_config_for_logging(), "config_suppress_errors": config.suppress_errors, - "config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules, + "config_inline_inbuilt_nn_modules": True, "inductor_config": _scrubbed_inductor_config_for_logging(), "compiler_config": _compiler_config_for_logging(), "cuda_version": torch.version.cuda, @@ -2919,6 +2920,24 @@ def get_items_from_dict(obj: dict[K, V]) -> Iterable[tuple[K, V | Any]]: return [(k, dict.__getitem__(obj, k)) for k in dict.keys(obj)] +def enumerate_items_with_dict_position( + obj: dict[K, V], +) -> Iterable[tuple[int, K, V | Any]]: + """Enumerate dict items yielding (dict_keys_position, key, value). + + For OrderedDicts where move_to_end/prepend has been used, the OrderedDict + iteration order can differ from dict.keys() order. We iterate in + OrderedDict order (correct execution semantics) but return each key's + dict.keys() position so that ConstDictKeySource indices stay consistent + with PyDict_Next / C++ DictGuardManager. + """ + items = get_items_from_dict(obj) + if isinstance(obj, OrderedDict): + key_to_pos = {k: i for i, k in enumerate(dict.keys(obj))} + return ((key_to_pos[k], k, v) for k, v in items) + return ((i, k, v) for i, (k, v) in enumerate(items)) + + def nn_module_new(cls: Any) -> Any: obj = object_new(cls) torch.nn.Module.__init__(obj) @@ -2941,6 +2960,29 @@ def dataclass_fields(cls: Any) -> Any: iter_next = next +def normalize_count_iter(count_iter: Iterator[Any]) -> tuple[Any, Any]: + try: + _, args = count_iter.__reduce__() + except TypeError: + # Python 3.14 no longer pickles itertools.count, so fall back to the + # repr and only recover literal arguments. Non-literal arguments still + # fall back to user-defined handling via the NotImplemented sentinel. + import ast + + count_repr = repr(count_iter) + if not count_repr.startswith("count(") or not count_repr.endswith(")"): + return (NotImplemented, NotImplemented) + try: + args = ast.literal_eval(f"({count_repr[6:-1]},)") + except (SyntaxError, ValueError): + return (NotImplemented, NotImplemented) + if not isinstance(args, tuple) or not 1 <= len(args) <= 2: + return (NotImplemented, NotImplemented) + if len(args) == 1: + return (args[0], 1) + return (args[0], args[1]) + + def normalize_range_iter(range_iter: Any) -> tuple[int, int, int]: _, (range_obj,), maybe_idx = range_iter.__reduce__() # In 3.12+, `maybe_idx` could be None, and `range_obj.start` would've been @@ -2964,12 +3006,10 @@ def to_subclass(t: Any, cls: type) -> Any: @torch.fx.wrap def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any: - # Call dict(d) to prevent calling overridden __iter__/keys - dict_class = dict - if isinstance(d, OrderedDict): - dict_class = OrderedDict + # Use dict.keys() to match the iteration order of PyDict_Next used by + # the C++ DictGuardManager and by ConstDictKeySource._name_template. # pyrefly: ignore [bad-argument-type] - return next(itertools.islice(dict_class.keys(d), n, n + 1)) + return next(itertools.islice(dict.keys(d), n, n + 1)) def set_getitem(s: set[T], n: int) -> T: @@ -3034,7 +3074,6 @@ def raise_args_mismatch( actual: str = "", ) -> None: from torch._dynamo.exc import raise_observed_exception - from torch._dynamo.variables import ConstantVariable msg_str = ( f"wrong number of arguments or keyword arguments for {name}() call.\n" @@ -3045,7 +3084,7 @@ def raise_args_mismatch( raise_observed_exception( TypeError, tx, - args=[ConstantVariable(msg_str)], + args=[msg_str], ) @@ -3056,7 +3095,6 @@ def iter_contains( check_tensor_identity: bool = False, ) -> Any: from .variables import ConstantVariable - from .variables.constant import CONSTANT_VARIABLE_FALSE, CONSTANT_VARIABLE_TRUE if search.is_python_constant(): found_const = any( @@ -3077,7 +3115,7 @@ def iter_contains( if must_check_tensor_id: if x.is_tensor(): if search is _get_fake_tensor(x): # Object equivalence - return CONSTANT_VARIABLE_TRUE + return ConstantVariable.create(True) else: from torch._dynamo.variables.builder import SourcelessBuilder @@ -3091,7 +3129,7 @@ def iter_contains( tx, [check, found], {} ) if found is None: - found = CONSTANT_VARIABLE_FALSE + found = ConstantVariable.create(False) return found @@ -3147,7 +3185,7 @@ def dict_keys_repr(const_keys: Any, *, local: Any) -> str: GLOBAL_KEY_PREFIX = "__dict_key" -from torch._subclasses import UnsupportedFakeTensorException # noqa: F401 +from torch._subclasses import UnsupportedFakeTensorException def get_safe_global_name(tx: InstructionTranslatorBase, root: str, obj: Any) -> str: @@ -4407,18 +4445,16 @@ def defake(x: Any) -> Any: size: torch._prims_common.ShapeType stride: torch._prims_common.StrideType if x._has_symbolic_sizes_strides: - size = [] - for s in x.size(): - if isinstance(s, torch.SymInt): - size.append(s.node.shape_env.size_hint(s.node.expr)) - else: - size.append(s) - stride = [] - for s in x.stride(): - if isinstance(s, torch.SymInt): - stride.append(s.node.shape_env.size_hint(s.node.expr)) - else: - stride.append(s) + # optimization_hint is appropriate here because defake only needs a + # plausible concrete shape to allocate a real tensor; it does not need + # to install guards. For unbacked symbols the heuristic fallback is fine. + size = [ + torch.fx.experimental.symbolic_shapes.optimization_hint(s) for s in x.size() + ] + stride = [ + torch.fx.experimental.symbolic_shapes.optimization_hint(s) + for s in x.stride() + ] else: size = x.size() stride = x.stride() diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index a56d8151c91ca..a5454e5e3974c 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -17,14 +17,15 @@ """ from .base import VariableTracker -from .builtin import BuiltinVariable -from .constant import ( - CONSTANT_VARIABLE_FALSE, - CONSTANT_VARIABLE_NONE, - CONSTANT_VARIABLE_TRUE, - ConstantVariable, - EnumVariable, +from .builtin import ( + BaseBuiltinVariable, + BuiltinVariable, + DictBuiltinVariable, + GetAttrBuiltinVariable, + IterBuiltinVariable, + ListBuiltinVariable, ) +from .constant import ConstantVariable from .ctx_manager import ( CatchWarningsCtxManagerVariable, ContextWrappingVariable, @@ -51,19 +52,14 @@ ) from .dicts import ( ConstDictVariable, - DefaultDictVariable, DictItemsVariable, - DictKeySetVariable, DunderDictVariable, - FrozensetVariable, MappingProxyVariable, NNModuleHooksDictVariable, - OrderedSetClassVariable, - OrderedSetVariable, - SetVariable, ) from .distributed import BackwardHookVariable, DistributedVariable from .functions import ( + BaseUserFunctionVariable, BuiltinMethodVariable, CollectionsNamedTupleFunction, CreateTMADescriptorExperimentalVariable, @@ -81,7 +77,7 @@ SparseTensorCreationSkipVariable, TMADescriptorExperimentalVariable, TMADescriptorStableVariable, - TritonSetAllocatorSkipVariable, + TritonSetAllocatorVariable, UserFunctionVariable, UserMethodVariable, WrapperUserFunctionVariable, @@ -108,7 +104,6 @@ BaseListVariable, ListIteratorVariable, ListVariable, - NamedTupleVariable, RangeVariable, SliceVariable, TupleIteratorVariable, @@ -145,7 +140,19 @@ ) from .optimizer import OptimizerVariable from .sdpa import SDPAParamsVariable -from .streams import EventVariable, StreamContextVariable, StreamVariable +from .sets import ( + DictKeySetVariable, + FrozensetVariable, + OrderedSetClassVariable, + OrderedSetVariable, + SetVariable, +) +from .streams import ( + CudaStreamVariable, + EventVariable, + StreamContextVariable, + StreamVariable, +) from .tensor import ( DataPtrVariable, FakeItemVariable, @@ -157,11 +164,16 @@ ) from .torch import TorchCtxManagerClassVariable, TorchInGraphFunctionVariable from .user_defined import ( + DefaultDictVariable, FrozenDataClassVariable, InspectVariable, MutableMappingVariable, + NamedTupleVariable, + OrderedDictVariable, RemovableHandleVariable, + StructSequenceVariable, UserDefinedClassVariable, + UserDefinedConstantVariable, UserDefinedDictVariable, UserDefinedExceptionClassVariable, UserDefinedExceptionObjectVariable, @@ -169,6 +181,7 @@ UserDefinedObjectVariable, UserDefinedSetVariable, UserDefinedTupleVariable, + UserDefinedVariable, ) @@ -176,14 +189,13 @@ "AutogradFunctionContextVariable", "AutogradFunctionVariable", "BackwardHookVariable", + "BaseBuiltinVariable", "BaseListVariable", "BuiltinVariable", "CatchWarningsCtxManagerVariable", - "CONSTANT_VARIABLE_FALSE", - "CONSTANT_VARIABLE_NONE", - "CONSTANT_VARIABLE_TRUE", "ConstantVariable", "ConstDictVariable", + "DictBuiltinVariable", "ContextWrappingVariable", "CountIteratorVariable", "CreateTMADescriptorExperimentalVariable", @@ -195,25 +207,27 @@ "DeletedVariable", "DictKeySetVariable", "DynamoConfigPatchVariable", - "EnumVariable", "FakeItemVariable", + "GetAttrBuiltinVariable", "GetAttrVariable", "GradModeVariable", "InspectSignatureVariable", "InspectVariable", + "IterBuiltinVariable", "IteratorVariable", "ItertoolsVariable", "LambdaVariable", "LazyConstantVariable", "LazyVariableTracker", + "ListBuiltinVariable", "ListIteratorVariable", "ListVariable", - "NamedTupleVariable", "NestedUserFunctionVariable", "CellVariable", "NewGlobalVariable", "NNModuleVariable", "NumpyNdarrayVariable", + "OrderedDictVariable", "NumpyVariable", "OptimizerVariable", "PolyfilledFunctionVariable", @@ -241,6 +255,8 @@ "UntypedStorageVariable", "UserDefinedClassVariable", "UserDefinedTupleVariable", + "NamedTupleVariable", + "StructSequenceVariable", "UserDefinedObjectVariable", "UserFunctionVariable", "UserMethodVariable", diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 0a6303fe8088f..9d08ce96e25ae 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -13,17 +13,18 @@ computations. """ +from __future__ import annotations + import collections +import dataclasses import functools +import linecache import logging from collections.abc import Callable, ItemsView, KeysView, Sequence, ValuesView from contextvars import ContextVar from enum import Enum from typing import Any, NoReturn, TYPE_CHECKING -from torch._guards import Guard -from torch.fx.proxy import Node - from .. import graph_break_hints, variables from ..current_scope_id import current_scope_id from ..exc import raise_observed_exception, unimplemented @@ -33,7 +34,11 @@ if TYPE_CHECKING: + from torch._guards import Guard + from torch.fx.proxy import Node + from ..codegen import PyCodegen + from ..side_effects import SideEffects from ..symbolic_convert import InstructionTranslator from .constant import ConstantVariable from .functions import UserFunctionVariable @@ -41,6 +46,28 @@ log = logging.getLogger(__name__) + +@dataclasses.dataclass(frozen=True) +class SourceLocation: + """Source position of the bytecode instruction that generated a VariableTracker.""" + + filename: str + lineno: int + end_lineno: int | None = None + col_offset: int | None = None + end_col_offset: int | None = None + + def format(self) -> str: + line = linecache.getline(self.filename, self.lineno).rstrip() + result = f' File "{self.filename}", line {self.lineno}\n' + if line: + result += f" {line}\n" + if line and self.col_offset is not None and self.end_col_offset is not None: + num_carets = max(1, self.end_col_offset - self.col_offset) + result += " " + " " * self.col_offset + "^" * num_carets + "\n" + return result + + # Tracks active method calls on VariableTracker instances to detect self-referential # calls (e.g., as_python_constant on a list that contains itself). Maps # (id(instance), id(original_method)) tuples to track which calls are in progress. @@ -219,13 +246,17 @@ def is_side_effect_safe(m: MutationType) -> bool: return m.scope == scope_id +class NO_SUCH_SUBOBJ: + """Sentinel indicating no concrete Python object is available.""" + + # This helps users of `as_python_constant` to catch unimplemented error with # more information; it inherits `NotImplementedError` for backward # compatibility reasons. class AsPythonConstantNotImplementedError(NotImplementedError): - vt: "VariableTracker" + vt: VariableTracker - def __init__(self, vt: "VariableTracker", msg: str | None = None) -> None: + def __init__(self, vt: VariableTracker, msg: str | None = None) -> None: msg = f"{vt} is not a constant" if msg is None else msg super().__init__(msg) self.vt = vt @@ -278,17 +309,23 @@ class VariableTracker(metaclass=VariableTrackerMeta): Prefer the factory function VariableTracker.build() over VariableTracker.__init__(). """ + # The CPython type(s) this VT is designed to represent. + # Single type or tuple of types. None means no static CPython type mapping + # (e.g., dynamic types like UserDefinedObjectVariable, or Dynamo-internal VTs). + _cpython_type: type | tuple[type, ...] | None = None + # fields to leave unmodified in apply() _nonvar_fields = { "value", "guards", "source", + "source_location", "mutation_type", "parents_tracker", "user_code_variable_name", } - def clone(self, **kwargs: Any) -> "VariableTracker": + def clone(self, **kwargs: Any) -> VariableTracker: """Shallow copy with some (optional) changes""" args = dict(self.__dict__) args.update(kwargs) @@ -297,12 +334,17 @@ def clone(self, **kwargs: Any) -> "VariableTracker": @classmethod def visit( cls, - fn: Callable[["VariableTracker"], None], + fn: Callable[[VariableTracker], None], value: Any, cache: dict[int, Any] | None = None, + side_effects: SideEffects | None = None, ) -> None: """ - Walk value and call fn on all the VariableTracker instances + Walk value and call fn on all the VariableTracker instances. + + When side_effects is provided, also walks attributes stored in + store_attr_mutations (e.g. dataclass fields set during tracing + that aren't in the VT's __dict__). """ if cache is None: cache = {} @@ -320,13 +362,16 @@ def visit( nonvars = value._nonvar_fields for key, subvalue in value.__dict__.items(): if key not in nonvars: - cls.visit(fn, subvalue, cache) + cls.visit(fn, subvalue, cache, side_effects) + if side_effects is not None and value in side_effects.store_attr_mutations: + for attr_vt in side_effects.store_attr_mutations[value].values(): + cls.visit(fn, attr_vt, cache, side_effects) elif istype(value, (list, tuple)): for subvalue in value: - cls.visit(fn, subvalue, cache) + cls.visit(fn, subvalue, cache, side_effects) elif istype(value, (dict, collections.OrderedDict)): for subvalue in value.values(): - cls.visit(fn, subvalue, cache) + cls.visit(fn, subvalue, cache, side_effects) def __repr__(self) -> str: return f"{self.__class__.__name__}()" @@ -396,6 +441,14 @@ def is_python_constant(self) -> bool: except NotImplementedError: return False + def bool_impl(self, tx: InstructionTranslator) -> VariableTracker | None: + # Mirrors CPython's tp_as_number->nb_bool slot. + # https://github.com/python/cpython/blob/c09ccd9c429/Objects/object.c#L2135-L2158 + # + # Returns None when the type has no nb_bool, causing generic_bool to + # fall through to length check, then truthy default. + return None + def is_constant_match(self, *values: Any) -> bool: """ Check if this variable is a python constant matching one of the given values. @@ -418,7 +471,7 @@ def make_guard(self, fn: Callable[..., Any]) -> Guard: # TODO[@lucaskabela] - change this type to `InstructionTranslatorBase` # and cascade that (large blast radius) - def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: + def const_getattr(self, tx: InstructionTranslator, name: str) -> Any: """getattr(self, name) returning a python constant""" raise NotImplementedError @@ -430,7 +483,7 @@ def is_tensor(self) -> bool: """Return True for TensorVariable instances""" return False - def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + def var_getattr(self, tx: InstructionTranslator, name: str) -> VariableTracker: """getattr(self, name) returning a new variable""" value = self.const_getattr(tx, name) if not variables.ConstantVariable.is_literal(value): @@ -467,7 +520,7 @@ def _contains_self_reference(self) -> bool: """Check if this variable references itself (directly or indirectly).""" found_self = False - def check(vt: "VariableTracker") -> None: + def check(vt: VariableTracker) -> None: nonlocal found_self if vt is self: found_self = True @@ -479,13 +532,13 @@ def check(vt: "VariableTracker") -> None: return found_self - def reconstruct(self, codegen: "PyCodegen") -> None: + def reconstruct(self, codegen: PyCodegen) -> None: raise NotImplementedError - def unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]: + def unpack_var_sequence(self, tx: Any) -> list[VariableTracker]: raise NotImplementedError - def force_unpack_var_sequence(self, tx: Any) -> list["VariableTracker"]: + def force_unpack_var_sequence(self, tx: Any) -> list[VariableTracker]: # like unpack_var_sequence, but should only be used when it is # safe to eagerly (vs. lazily) unpack this variable. # e.g. map(f, x) is normally evaluated lazily but sometimes @@ -509,15 +562,15 @@ def has_force_unpack_var_sequence(self, tx: Any) -> bool: # Only use when it is safe to eagerly unpack this variable (like force_unpack_var_sequence). # INVARIANT: variable must satisfy has_force_unpack_var_sequence() == True! def force_apply_to_var_sequence( - self, tx: Any, fn: Callable[["VariableTracker"], Any] + self, tx: Any, fn: Callable[[VariableTracker], Any] ) -> None: assert self.has_force_unpack_var_sequence(tx) for v in self.unpack_var_sequence(tx): fn(v) def call_obj_hasattr( - self, tx: "InstructionTranslator", name: str - ) -> "ConstantVariable": + self, tx: InstructionTranslator, name: str + ) -> ConstantVariable: unimplemented( gb_type="Unsupported hasattr call", context=f"call_obj_hasattr {self} {name}", @@ -531,9 +584,9 @@ def call_obj_hasattr( def call_function( self, tx: Any, - args: Sequence["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: unimplemented( gb_type="Unsupported function call", context=f"call_function {self} {args} {kwargs}", @@ -544,16 +597,59 @@ def call_function( ], ) + def sq_length(self, tx: Any) -> VariableTracker: + """Called when sq_length is not implemented.""" + raise_observed_exception( + TypeError, + tx, + args=[f"object of type '{self.python_type_name()}' has no len()"], + ) + + def mp_length(self, tx: Any) -> VariableTracker: + """Called when mp_length is not implemented.""" + raise_observed_exception( + TypeError, + tx, + args=[f"object of type '{self.python_type_name()}' has no len()"], + ) + + def mp_subscript_impl( + self, + tx: InstructionTranslator, + key: VariableTracker, + ) -> VariableTracker: + # PyObject_GetItem: https://github.com/python/cpython/blob/62a6e898e01/Objects/abstract.c#L155-L206 + # TODO: raise TypeError for non-subscriptable objects (blocked on + # branch 3 __class_getitem__ support for type objects). + unimplemented( + gb_type="missing_mp_subscript", + context=f"mp_subscript_impl not defined for {type(self).__name__}", + explanation=f"Dynamo does not yet support subscripting '{self.python_type_name()}'.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + def call_method( self, tx: Any, name: str, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": - if name == "__len__" and self.has_unpack_var_sequence(tx): - assert not (args or kwargs) - return variables.ConstantVariable.create(len(self.unpack_var_sequence(tx))) + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__getitem__": + if len(args) == 1 and not kwargs: + return self.mp_subscript_impl(tx, args[0]) + from ..utils import raise_args_mismatch + + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + elif name == "__len__" and not (args or kwargs): + from .object_protocol import generic_len + + return generic_len(tx, self) elif ( name == "__getattr__" and len(args) == 1 @@ -561,6 +657,12 @@ def call_method( and not kwargs ): return self.var_getattr(tx, args[0].as_python_constant()) + elif name == "__index__" and not args and not kwargs: + return self.nb_index_impl(tx) + elif name == "__int__" and not args and not kwargs: + return self.nb_int_impl(tx) + elif name == "__float__" and not args and not kwargs: + return self.nb_float_impl(tx) elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs: other = args[0] if not isinstance(self, type(other)) and not ( @@ -598,8 +700,22 @@ def call_method( raise_observed_exception( type(e), tx, - args=list(map(variables.ConstantVariable.create, e.args)), + args=list(e.args), ) + # __reduce_ex__ is a C builtin (object.__reduce_ex__) that Dynamo + # cannot trace into. Constant-fold it for VTs backed by a real + # Python object so that copy.deepcopy can trace through. + if ( + name == "__reduce_ex__" + and len(args) == 1 + and not kwargs + and self.is_python_constant() + ): + protocol = args[0].as_python_constant() + return VariableTracker.build( + tx, self.as_python_constant().__reduce_ex__(protocol) + ) + hints = [ f"Avoid calling `{self.python_type_name()}.{name}` in your code.", "Please report an issue to PyTorch.", @@ -636,11 +752,11 @@ def call_method( def call_tree_map( self, tx: Any, - tree_map_fn: "UserFunctionVariable", - map_fn: "VariableTracker", - rest: Sequence["VariableTracker"], - tree_map_kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tree_map_fn: UserFunctionVariable, + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: """Performance optimization to implement optree.tree_map faster than tracing it""" is_leaf_var = tree_map_kwargs.get("is_leaf") if is_leaf_var is not None and not is_leaf_var.is_constant_none(): @@ -669,11 +785,11 @@ def call_tree_map( def call_tree_map_branch( self, tx: Any, - tree_map_fn: "UserFunctionVariable", - map_fn: "VariableTracker", - rest: Sequence["VariableTracker"], - tree_map_kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tree_map_fn: UserFunctionVariable, + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: """Emulate optree.tree_map without is_leaf/none_is_leaf checks (handled above)""" return self._tree_map_fallback( tx, @@ -686,11 +802,11 @@ def call_tree_map_branch( def _tree_map_fallback( self, tx: Any, - tree_map_fn: "UserFunctionVariable", - map_fn: "VariableTracker", - rest: Sequence["VariableTracker"], - tree_map_kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tree_map_fn: UserFunctionVariable, + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], + ) -> VariableTracker: tree_map_fn_copy = tree_map_fn.clone() tree_map_fn_copy._maybe_call_tree_map_fastpath = lambda *args, **kwargs: None # type: ignore[missing-attribute] log.debug( @@ -708,12 +824,12 @@ def _tree_map_fallback( def call_tree_map_with_path( self, tx: Any, - tree_map_fn: "UserFunctionVariable", - map_fn: "VariableTracker", - rest: Sequence["VariableTracker"], - tree_map_kwargs: dict[str, "VariableTracker"], + tree_map_fn: UserFunctionVariable, + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], keypath: tuple[Any, ...], - ) -> "VariableTracker": + ) -> VariableTracker: """Performance optimization to implement tree_map_with_path faster than tracing it""" is_leaf_var = tree_map_kwargs.get("is_leaf") if is_leaf_var is not None and not is_leaf_var.is_constant_none(): @@ -747,12 +863,12 @@ def call_tree_map_with_path( def call_tree_map_with_path_branch( self, tx: Any, - tree_map_fn: "UserFunctionVariable", - map_fn: "VariableTracker", - rest: Sequence["VariableTracker"], - tree_map_kwargs: dict[str, "VariableTracker"], + tree_map_fn: UserFunctionVariable, + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], keypath: tuple[Any, ...], - ) -> "VariableTracker": + ) -> VariableTracker: """Handle tree_map_with_path for leaf nodes (default behavior)""" keypath_var = variables.TupleVariable( [VariableTracker.build(tx, k) for k in keypath] @@ -762,12 +878,12 @@ def call_tree_map_with_path_branch( def _tree_map_with_path_fallback( self, tx: Any, - tree_map_fn: "UserFunctionVariable", - map_fn: "VariableTracker", - rest: Sequence["VariableTracker"], - tree_map_kwargs: dict[str, "VariableTracker"], + tree_map_fn: UserFunctionVariable, + map_fn: VariableTracker, + rest: Sequence[VariableTracker], + tree_map_kwargs: dict[str, VariableTracker], keypath: tuple[Any, ...], - ) -> "VariableTracker": + ) -> VariableTracker: tree_map_fn_copy = tree_map_fn.clone() tree_map_fn_copy._maybe_call_tree_map_fastpath = lambda *args, **kwargs: None # type: ignore[missing-attribute] log.debug( @@ -777,9 +893,6 @@ def _tree_map_with_path_fallback( tree_map_kwargs, keypath, ) - # For fallback, we need to reconstruct the subtree rooted at this node - # and call tree_map_with_path on it. Since we're in the middle of the tree, - # we fall back to tracing the tree_map_with_path function. return tree_map_fn_copy.call_function( tx, [map_fn, self, *rest], @@ -789,11 +902,14 @@ def _tree_map_with_path_fallback( def set_name_hint(self, name: str) -> None: pass - def realize(self) -> "VariableTracker": + def set_source_location(self, source_location: SourceLocation) -> None: + self.source_location = source_location + + def realize(self) -> VariableTracker: """Used by LazyVariableTracker to build the real VariableTracker""" return self - def unwrap(self) -> "VariableTracker": + def unwrap(self) -> VariableTracker: """Used by LazyVariableTracker to return the real VariableTracker if it already exists""" return self @@ -801,7 +917,7 @@ def is_realized(self) -> bool: """Used by LazyVariableTracker to indicate an unrealized node""" return True - def next_variable(self, tx: Any) -> "VariableTracker": + def next_variable(self, tx: Any) -> VariableTracker: unimplemented( gb_type="Unsupported next() call", context=f"next({self})", @@ -879,6 +995,13 @@ def get_python_hash(self) -> int: ], ) + def get_real_python_backed_value(self) -> object: + """Return the Python object this VT wraps, for `is` comparison. + + Returns NO_SUCH_SUBOBJ if no concrete Python object is available. + """ + return NO_SUCH_SUBOBJ + def is_python_equal(self, other: object) -> bool: """ NB - Deliberately not overriding the __eq__ method because that can @@ -896,14 +1019,69 @@ def is_python_equal(self, other: object) -> bool: ], ) + def nb_index_impl( + self, + tx: Any, + ) -> VariableTracker: + """Mirrors CPython's PyNumber_Index / nb_index slot. + + https://github.com/python/cpython/blob/c09ccd9c429/Objects/abstract.c#L1411-L1450 + + The base implementation raises TypeError, matching CPython's behavior + when tp_as_number->nb_index is NULL (_PyIndex_Check fails). + """ + raise_observed_exception( + TypeError, + tx, + args=[ + f"'{self.python_type_name()}' object cannot be interpreted as an integer" + ], + ) + + def nb_int_impl( + self, + tx: Any, + ) -> VariableTracker: + """Mirrors CPython's tp_as_number->nb_int slot. + + Called when type_implements_nb_int returns True for this type. + Subclasses override to provide the actual conversion. + """ + unimplemented( + gb_type="nb_int_impl not implemented", + context=f"{type(self).__name__} has nb_int slot but no nb_int_impl override", + explanation=f"The type {self.python_type_name()} has an nb_int C slot but " + "the corresponding VariableTracker doesn't implement nb_int_impl.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + def nb_float_impl( + self, + tx: Any, + ) -> VariableTracker: + """Mirrors CPython's tp_as_number->nb_float slot. + + Called when type_implements_nb_float returns True for this type. + Subclasses override to provide the actual conversion. + """ + unimplemented( + gb_type="nb_float_impl not implemented", + context=f"{type(self).__name__} has nb_float slot but no nb_float_impl override", + explanation=f"The type {self.python_type_name()} has an nb_float C slot but " + "the corresponding VariableTracker doesn't implement nb_float_impl.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + def __init__( self, *, source: Source | None = None, mutation_type: MutationType | None = None, + source_location: SourceLocation | None = None, ) -> None: super().__init__() self.source = source + self.source_location = source_location self.mutation_type = mutation_type # NOTE sometimes mutation_type is set afterwards for implementation @@ -961,9 +1139,9 @@ def reconstruct_failure(self) -> NoReturn: @staticmethod def _add_call_once_guard( - cls: type["VariableTracker"], + cls: type[VariableTracker], method: str, - callback: Callable[["VariableTracker"], Any], + callback: Callable[[VariableTracker], Any], ) -> None: original_method = getattr(cls, method) @@ -993,11 +1171,6 @@ def guarded_method(self, *args: Any, **kwargs: Any) -> VariableTracker: setattr(cls, method, guarded_method) -def raise_type_error_exc(tx: Any, msg_str: str) -> NoReturn: - msg = variables.ConstantVariable.create(msg_str) - raise_observed_exception(TypeError, tx, args=[msg]) - - def typestr(*objs: object) -> str: if len(objs) == 1: (obj,) = objs @@ -1010,5 +1183,52 @@ def typestr(*objs: object) -> str: instancecheck = type.__instancecheck__ + + +_CPYTHON_BASE_URL = "https://github.com/python/cpython/blob/v3.13.0/" + + +def print_cpython_to_vt_mapping() -> None: + """Print the mapping from CPython types to Dynamo VariableTracker classes. + + Reads _cpython_type from each VT subclass (own attribute only, not inherited). + """ + # Ensure all VT modules are imported so all_subclasses is complete + from . import ( # noqa: F401 + builtin, + constant, + ctx_manager, + dicts, + distributed, + functions, + higher_order_ops, + iter, + lazy, + lists, + misc, + nn_module, + streams, + tensor, + torch, + torch_function, + user_defined, + ) + + mapping: dict[type, list[type]] = {} + for vt_cls in VariableTrackerMeta.all_subclasses: + cpython_type = vt_cls.__dict__.get("_cpython_type") + if cpython_type is None: + continue + types = cpython_type if isinstance(cpython_type, tuple) else (cpython_type,) + for t in types: + mapping.setdefault(t, []).append(vt_cls) + + for py_type, vt_classes in sorted(mapping.items(), key=lambda x: x[0].__qualname__): + for vt_cls in vt_classes: + print( + f"{py_type.__module__}.{py_type.__qualname__:30s} -> {vt_cls.__name__}" + ) + + from . import builder from .lazy import LazyVariableTracker diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 987c75e804354..f516cdc38b731 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -45,6 +45,7 @@ from torch import SymInt from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.graph_bytecode_inputs import ( + CURRENT_STREAM_INDEX, get_external_object_by_index, register_user_object, ) @@ -52,6 +53,7 @@ get_metrics_context, is_int_specialization_case, is_torch_sym, + normalize_count_iter, set_feature_use, ) from torch._guards import TracingContext @@ -116,6 +118,7 @@ ChainedSource, ConstDictKeySource, ConvertIntSource, + CurrentStreamSource, DictGetItemSource, DictSubclassGetItemSource, DynamicScalarSource, @@ -148,8 +151,8 @@ clone_input, common_constant_types, dict_keys, + enumerate_items_with_dict_position, get_fake_value, - get_items_from_dict, get_locals_to_steal, get_static_address_type, is_frozen_dataclass, @@ -185,7 +188,7 @@ VariableTrackerMeta, ) from .builtin import BuiltinVariable -from .constant import ConstantVariable, EnumVariable +from .constant import ConstantVariable from .ctx_manager import ( AutocastModeVariable, CudagraphOverrideVariable, @@ -194,16 +197,7 @@ NullContextVariable, PreserveVersionContextVariable, ) -from .dicts import ( - ConstDictVariable, - DefaultDictVariable, - DictKeySetVariable, - FrozensetVariable, - MappingProxyVariable, - OrderedSetClassVariable, - OrderedSetVariable, - SetVariable, -) +from .dicts import ConstDictVariable, MappingProxyVariable, SetVariable from .distributed import WorldMetaClassVariable from .functions import ( BuiltinMethodVariable, @@ -214,7 +208,7 @@ FunctoolsPartialVariable, SysFunctionVariable, TritonKernelVariable, - TritonSetAllocatorSkipVariable, + TritonSetAllocatorVariable, UserFunctionVariable, WrapperUserFunctionVariable, ) @@ -222,13 +216,12 @@ LocalMapWrappedHigherOrderVariable, TorchHigherOrderOperatorVariable, ) -from .iter import ItertoolsVariable +from .iter import CountIteratorVariable, ItertoolsVariable from .lazy import LazyConstantVariable, LazyVariableTracker from .lists import ( BaseListVariable, ListIteratorVariable, ListVariable, - NamedTupleVariable, RangeVariable, SizeVariable, SliceVariable, @@ -256,6 +249,7 @@ RandomClassVariable, RandomVariable, SavedTensorBox, + StringFormatVariable, TorchVersionVariable, TypingVariable, WeakRefVariable, @@ -268,6 +262,12 @@ from .optimizer import OptimizerVariable from .script_object import OpaqueObjectClassVariable, TorchScriptObjectVariable from .sdpa import SDPAParamsVariable +from .sets import ( + DictKeySetVariable, + FrozensetVariable, + OrderedSetClassVariable, + OrderedSetVariable, +) from .streams import EventVariable, StreamContextVariable, StreamVariable from .tensor import ( NumpyNdarrayVariable, @@ -289,15 +289,17 @@ TorchFunctionModeVariable, ) from .user_defined import ( + DefaultDictVariable, FrozenDataClassVariable, InspectVariable, IntWrapperVariable, KeyedJaggedTensorVariable, MutableMappingVariable, + OrderedDictVariable, SourcelessGraphModuleVariable, UserDefinedClassVariable, + UserDefinedConstantVariable, UserDefinedDictVariable, - UserDefinedEnumClassVariable, UserDefinedExceptionClassVariable, UserDefinedListVariable, UserDefinedObjectVariable, @@ -594,6 +596,7 @@ def _type_dispatch_impl(cls, trace_numpy: bool) -> dict[object, Callable[..., An (tuple, list, odict_values, collections.deque, torch.Size), cls.wrap_listlike, ), + (itertools.count, cls.wrap_itertools_count), (tuple_iterator, cls.wrap_tuple_iterator), (range_iterator, cls.wrap_range_iterator), ((slice, range), cls.wrap_slice_range), @@ -762,17 +765,17 @@ def set_allocator() -> None: pass if has_triton_experimental_host_tma(): - from triton.tools.experimental_descriptor import ( # noqa: F811 + from triton.tools.experimental_descriptor import ( create_1d_tma_descriptor, create_2d_tma_descriptor, ) if has_triton_tensor_descriptor_host_tma(): - from triton.tools.tensor_descriptor import TensorDescriptor # noqa: F811 + from triton.tools.tensor_descriptor import TensorDescriptor if has_triton(): import triton as triton_mod if hasattr(triton_mod, "set_allocator"): - set_allocator = triton_mod.set_allocator # noqa: F811 + set_allocator = triton_mod.set_allocator # Handle exact type() match type_dispatch = self._type_dispatch().get(type(value)) @@ -820,9 +823,8 @@ def set_allocator() -> None: source=self.source, mutation_type=ValueMutationExisting(), ) - result = NamedTupleVariable( - output, - tuple_cls=type(value), + result = UserDefinedTupleVariable.get_vt_cls(type(value))( + value, source=self.source, tuple_vt=tuple_vt, ) @@ -855,7 +857,7 @@ def set_allocator() -> None: self.tx.output.guard_on_key_order.add(self.source) # We need all the keys to be hashable. We do this within the - # _HashableTracker class in dicts.py + # HashableTracker class in hashable.py def build_key_value( i: Any, k: Any, v: Any ) -> tuple[VariableTracker, VariableTracker]: @@ -871,29 +873,39 @@ def build_key_value( return key, res_value - # Ensure that we call dict.keys and not value.keys (which can call - # overridden keys method). In the C++ guards, we relied on - # PyDict_Next to traverse the dictionary, which uses the internal - # data structure and does not call the overridden keys method. result = dict( build_key_value(i, k, v) - for i, (k, v) in enumerate(get_items_from_dict(value)) + for i, k, v in enumerate_items_with_dict_position(value) ) if istype(value, collections.defaultdict): factory_source = AttrSource(self.source, "default_factory") - result = DefaultDictVariable( + dict_vt = ConstDictVariable( result, # type: ignore[arg-type] - type(value), + mutation_type=ValueMutationExisting(), + source=self.source, + ) + result = DefaultDictVariable( + value, default_factory=VariableBuilder(self.tx, factory_source)( value.default_factory ), + dict_vt=dict_vt, source=self.source, ) + return self.tx.output.side_effects.track_object_existing(value, result) + elif istype(value, collections.OrderedDict): + dict_vt = ConstDictVariable( + result, # type: ignore[arg-type] + user_cls=collections.OrderedDict, + mutation_type=ValueMutationExisting(), + source=self.source, + ) + result = OrderedDictVariable(value, dict_vt=dict_vt, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) else: result = ConstDictVariable( result, # type: ignore[arg-type] - user_cls=type(value), source=self.source, ) @@ -969,10 +981,11 @@ def build_key_value( self.install_guards(GuardBuilder.EQUALS_MATCH) return FrozensetVariable(items, source=self.source) elif isinstance( - value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType) + value, + (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType), ): self.install_guards(GuardBuilder.ID_MATCH) - return EnumVariable(value=value, source=self.source) + return UserDefinedObjectVariable(value, source=self.source) elif DebuggingVariable.is_reorderable_logging_function(value): # Put this above builtin_callable so that print() can be handled # along with other builtin debugging functions @@ -1133,6 +1146,7 @@ def build_key_value( source=AttrSource(self.source, member="__self__"), ), "apply", + py_type=type(value), ) elif isinstance(value, torch._C._ImperativeEngine): self.install_guards(GuardBuilder.ID_MATCH) @@ -1183,9 +1197,12 @@ def build_key_value( # type: ignore[arg-type] return StreamContextVariable.create(self.tx, stream_var) elif isinstance(value, torch.Stream): - # This refers to the device-agnostic torch.Stream self.install_guards(GuardBuilder.TYPE_MATCH) - index = register_user_object(value, self.source) + if isinstance(self.source, CurrentStreamSource): + # Reuse the index pre-allocated in SymbolicStreamState.__init__ + index = CURRENT_STREAM_INDEX + else: + index = register_user_object(value, self.source) stream_proxy = self.tx.output.create_proxy( "call_function", get_external_object_by_index, (index,), {} ) @@ -1267,6 +1284,8 @@ def build_key_value( hint=value.real, # type: ignore[attr-defined] source=source, ) + if not isinstance(node, SymInt): + raise AssertionError(f"Expected SymInt, got {type(node)}") # Bind to graph input sym_node_proxy = self.tx.output.root_tracer.create_graph_input( @@ -1329,6 +1348,8 @@ def build_key_value( hints=[*graph_break_hints.SUPPORTABLE], ) assert new_symint is not None + if not isinstance(new_symint, SymInt): + raise AssertionError(f"Expected SymInt, got {type(new_symint)}") sym_node_proxy = self.tx.output.root_tracer.create_graph_input( re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(new_symint), @@ -1371,18 +1392,21 @@ def build_key_value( elif value is TensorDescriptor.from_tensor: return CreateTMADescriptorStableVariable() elif value is set_allocator: - return TritonSetAllocatorSkipVariable(value) + return TritonSetAllocatorVariable(value) elif isinstance(value, torch.amp.autocast_mode.autocast): - self.install_guards(GuardBuilder.ID_MATCH) - return AutocastModeVariable( - target_values=[ - value.device, - value.fast_dtype, - value._enabled, - value._cache_enabled, - ], - source=self.source, - ) + if isinstance(value, torch.amp.autocast_mode._UnmanagedAutocast): + return self.wrap_user_defined(value) + else: + self.install_guards(GuardBuilder.ID_MATCH) + return AutocastModeVariable( + target_values=[ + value.device, + value.fast_dtype, + value._enabled, + value._cache_enabled, + ], + source=self.source, + ) elif TorchCtxManagerClassVariable.is_matching_cls(value): if inspect.isclass(value): self.install_guards(GuardBuilder.CLASS_MATCH) @@ -1432,6 +1456,7 @@ def build_key_value( return GetAttrVariable( BuiltinVariable(float, source=self.source), value.__name__, + py_type=type(value), ) elif is_function_or_wrapper(value): value, attr_name = unwrap_with_attr_name_if_wrapper(value) @@ -1521,14 +1546,14 @@ def build_key_value( # ID_MATCH even if its a global variable. self.install_guards(GuardBuilder.CLASS_MATCH) - if is_opaque_type(value): - return OpaqueObjectClassVariable( + if isinstance(value, type) and issubclass(value, enum.Enum): + return UserDefinedClassVariable( value, source=self.source, ) - if isinstance(value, type) and issubclass(value, enum.Enum): - return UserDefinedEnumClassVariable( + if is_opaque_type(value): + return OpaqueObjectClassVariable( value, source=self.source, ) @@ -1537,6 +1562,18 @@ def build_key_value( value, source=self.source, ) + elif type(value) is torch._C.Generator: + # Generator is registered as an opaque reference type for make_fx + # tracing, but in dynamo we handle it as a regular object so that + # trace_rules-based graph breaks (e.g. initial_seed, manual_seed) + # work gracefully — allowing dynamo to compile code before and + # after the generator call. TorchScriptObjectVariable's var_getattr + # and call_method are decorated with @_raise_hard_error_if_graph_break, + # which turns any graph break into a hard error that falls back to + # eager for the entire function. Generator methods intentionally + # graph-break (they mutate/read RNG state), so they need the + # UserDefinedObjectVariable path which supports graceful graph breaks. + return self.wrap_user_defined(value) elif TorchScriptObjectVariable.is_matching_cls(type(value)): from ..source import ( FlattenScriptObjectSource, @@ -1610,6 +1647,7 @@ def build_key_value( self.tx.output.fake_mode, value ) if is_opaque_value_type(type(value)) and not should_hoist(type(value)): + fake_script_obj = value proxy = value elif config.install_free_tensors and ( @@ -1657,7 +1695,7 @@ def build_key_value( self.tx.output.guard_on_key_order.add(self.source) # We need all the keys to be hashable. We do this within the - # _HashableTracker class in dicts.py + # HashableTracker class in hashable.py def build_key_value( i: Any, k: Any, v: Any ) -> tuple[VariableTracker, VariableTracker]: @@ -1670,22 +1708,15 @@ def build_key_value( return key, res_value - # Ensure that we call dict.keys and not value.keys (which can call - # overridden keys method). In the C++ guards, we relied on - # PyDict_Next to traverse the dictionary, which uses the internal - # data structure and does not call the overridden keys method. result = dict( build_key_value(i, k, v) - for i, (k, v) in enumerate(get_items_from_dict(value)) + for i, k, v in enumerate_items_with_dict_position(value) ) + is_ordered_dict = isinstance(value, collections.OrderedDict) dict_vt = ConstDictVariable( result, - user_cls=( - collections.OrderedDict - if isinstance(value, collections.OrderedDict) - else dict - ), + user_cls=(collections.OrderedDict if is_ordered_dict else dict), mutation_type=ValueMutationExisting(), source=self.source, ) @@ -1693,7 +1724,12 @@ def build_key_value( # bytecode simple dict_vt.should_reconstruct_all = True - result = UserDefinedDictVariable(value, dict_vt=dict_vt, source=self.source) + if is_ordered_dict: + result = OrderedDictVariable(value, dict_vt=dict_vt, source=self.source) + else: + result = UserDefinedDictVariable( + value, dict_vt=dict_vt, source=self.source + ) return self.tx.output.side_effects.track_object_existing(value, result) elif isinstance(value, tuple): self.install_guards(GuardBuilder.TYPE_MATCH) @@ -1767,7 +1803,7 @@ def build_key_value( return self.tx.output.side_effects.track_object_existing(value, result) elif is_frozen_dataclass(value): self.install_guards(GuardBuilder.TYPE_MATCH) - result = FrozenDataClassVariable.create(self.tx, value, source=self.source) + result = FrozenDataClassVariable(value, source=self.source) return self.tx.output.side_effects.track_object_existing(value, result) elif isinstance(value, dict_keys): if all(ConstantVariable.is_literal(k) for k in value): @@ -1825,6 +1861,8 @@ def build_key_value( return self.wrap_user_defined(value) def wrap_user_defined(self, value: Any) -> VariableTracker: + from .user_defined import _CONSTANT_BASE_TYPES + self.install_guards(GuardBuilder.TYPE_MATCH) if InspectVariable.is_matching_object(value): # Skip guards on inspect related variable trackers because they are @@ -1832,6 +1870,12 @@ def wrap_user_defined(self, value: Any) -> VariableTracker: # cause recompiles) and can cause a large number of OBJECT_ALIASING # guards. result = InspectVariable(value, source=SkipGuardSource(self.source)) + elif ( + isinstance(value, _CONSTANT_BASE_TYPES) + and type(value) not in common_constant_types + ): + self.install_guards(GuardBuilder.CONSTANT_SUBCLASS_MATCH) + result = UserDefinedConstantVariable(value, source=self.source) else: result = UserDefinedObjectVariable(value, source=self.source) if not SideEffects.cls_supports_mutation_side_effects(type(value)): @@ -1968,10 +2012,28 @@ def wrap_range_iterator(self, value: range_iterator) -> VariableTracker: self.install_guards(GuardBuilder.RANGE_ITERATOR_MATCH) # Get all the values from the range iterator; no need to install guards # on items since `RANGE_ITERATOR_MATCH` guarantees the same items. - items = [ConstantVariable.create(v) for v in copy.deepcopy(value)] + items: list[VariableTracker] = [ + ConstantVariable.create(v) for v in copy.deepcopy(value) + ] result = ListIteratorVariable(items, source=self.source) return self.tx.output.side_effects.track_mutable(value, result) + def wrap_itertools_count(self, value: Any) -> VariableTracker: + current_item, step = normalize_count_iter(value) + if not ( + ConstantVariable.is_literal(current_item) + and ConstantVariable.is_literal(step) + ): + return self.wrap_user_defined(value) + + self.install_guards(GuardBuilder.COUNT_ITERATOR_MATCH) + result = CountIteratorVariable( + ConstantVariable.create(current_item), + ConstantVariable.create(step), + source=self.source, + ) + return self.tx.output.side_effects.track_mutable(value, result) + def wrap_slice_range(self, value: slice | range) -> SliceVariable | RangeVariable: items = [ VariableBuilder(self.tx, AttrSource(self.get_source(), k))( @@ -2102,62 +2164,54 @@ def wrap_module(self, value: torch.nn.Module) -> VariableTracker: value = value.get_base() self.source = AttrProxySource(self.source) - if torch._dynamo.config.inline_inbuilt_nn_modules: - freezing = is_parameter_freezing() + freezing = is_parameter_freezing() - # Guard against the case where user may overwrite named parameters - # / named buffers - # NOTE: This is not likely to happen but worth guarding to avoid - # exception - if ( - callable(value.named_parameters) + # Guard against the case where user may overwrite named parameters + # / named buffers + # NOTE: This is not likely to happen but worth guarding to avoid + # exception + if ( + callable(value.named_parameters) + # type: ignore[attr-defined] + and value.named_parameters.__func__ is og_module_named_parameters_fn_ptr + ): + try: # catch TypeErrors in named_parameters() from unserializable nn modules # type: ignore[attr-defined] - and value.named_parameters.__func__ - is og_module_named_parameters_fn_ptr - ): - try: # catch TypeErrors in named_parameters() from unserializable nn modules - # type: ignore[attr-defined] - for _, p in value.named_parameters(): - self.mark_static_input(p, guard=freezing) - except TypeError as e: - raise_observed_exception(type(e), self.tx, args=list(e.args)) - - if ( - callable(value.named_buffers) + for _, p in value.named_parameters(): + self.mark_static_input(p, guard=freezing) + except TypeError as e: + raise_observed_exception(type(e), self.tx, args=list(e.args)) + + if ( + callable(value.named_buffers) + # type: ignore[attr-defined] + and value.named_buffers.__func__ is og_module_named_buffers_fn_ptr + ): + try: # catch TypeErrors in named_parameters() from unserializable nn modules # type: ignore[attr-defined] - and value.named_buffers.__func__ is og_module_named_buffers_fn_ptr - ): - try: # catch TypeErrors in named_parameters() from unserializable nn modules - # type: ignore[attr-defined] - for _, b in value.named_buffers(): - self.mark_static_input(b, guard=freezing) - except TypeError as e: - raise_observed_exception(type(e), self.tx, args=list(e.args)) - - if freezing: - # we need to add the module to tracing context - # in order to allow its params to get invalidated - # this will get cleaned up once compile ends - self.tx.output.nn_modules[self.name] = value + for _, b in value.named_buffers(): + self.mark_static_input(b, guard=freezing) + except TypeError as e: + raise_observed_exception(type(e), self.tx, args=list(e.args)) + + if freezing: + # we need to add the module to tracing context + # in order to allow its params to get invalidated + # this will get cleaned up once compile ends + self.tx.output.nn_modules[self.name] = value if ( value.__module__.startswith(("torch.nn.modules", "torch.ao.")) and not value.__module__.startswith("torch.nn.modules.container") ) or getattr(value.__class__, "_dynamo_marked_static", False): new_source = self.source - if config.inline_inbuilt_nn_modules and ( - not self.tx.output.export or config.install_free_tensors - ): - # Export corner case - look at test_repros.py test_inlining_cornercase + if not self.tx.output.export or config.install_free_tensors: new_source = UnspecializedBuiltinNNModuleSource(self.source) result = UnspecializedBuiltinNNModuleVariable(value, source=new_source) install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) else: new_source = self.source - if config.inline_inbuilt_nn_modules and ( - not self.tx.output.export or config.install_free_tensors - ): - # Export corner case - look at test_repros.py test_inlining_cornercase + if not self.tx.output.export or config.install_free_tensors: new_source = UnspecializedNNModuleSource(self.source) result = UnspecializedNNModuleVariable(value, source=new_source) install_guard(new_source.make_guard(GuardBuilder.TYPE_MATCH)) @@ -2302,15 +2356,10 @@ def wrap_tensor(self, value: torch.Tensor) -> VariableTracker: is_static_input = get_static_address_type(value) is not None - if ( - config.inline_inbuilt_nn_modules - and not is_static_input - and ( - isinstance(value, torch.nn.Parameter) - # mark tensor attributes of nn modules static. This is done to keep inline_inbuilt_nn_modules behavior - # compatible with previous behavior. - or (source and source.guard_source.is_unspecialized_nn_module()) - ) + if not is_static_input and ( + isinstance(value, torch.nn.Parameter) + # mark tensor attributes of nn modules static + or (source and source.guard_source.is_unspecialized_nn_module()) ): self.mark_static_input(value, guard=is_parameter_freezing()) is_static_input = True @@ -2326,9 +2375,7 @@ def wrap_tensor(self, value: torch.Tensor) -> VariableTracker: ) make_graph_attribute = is_static_input and ( - not config.inline_inbuilt_nn_modules - or is_parameter_freezing() - or torch._dynamo.config.prepare_freezing + is_parameter_freezing() or torch._dynamo.config.prepare_freezing ) if should_install_free_tensor or ( @@ -2506,25 +2553,22 @@ def wrap_tensor(self, value: torch.Tensor) -> VariableTracker: if is_dtensor: self.install_guards(GuardBuilder.TYPE_MATCH) - # The inner tensor name is always _local_tensor. If its not, we - # raise assertion to update the check accordingly. - inner_tensor_name = value.__tensor_flatten__()[0][0] - if inner_tensor_name != "_local_tensor": + inner_attrs = value.__tensor_flatten__()[0] + if inner_attrs != ["_local_tensor", "device_mesh"]: raise RuntimeError( - "Expecting Dtensor inner tensor name to be _local_tensor" + "Expecting DTensor inner attrs to be ['_local_tensor', 'device_mesh']" ) - # Now selectively guard on the flattening context flattening_ctx = value.__tensor_flatten__()[1] - # This is supposed to be (self._spec, self.requires_grad) if not ( - len(flattening_ctx) == 2 - and flattening_ctx[0] == value._spec - and flattening_ctx[1] == value.requires_grad + len(flattening_ctx) == 4 + and flattening_ctx[0] == value._spec.placements + and flattening_ctx[1] == value._spec.tensor_meta + and flattening_ctx[2] == value._spec.shard_order + and flattening_ctx[3] == value.requires_grad ): - # If not, raise an assertion to update to the new guards raise RuntimeError( - "Expecting Dtensor flattening ctx to be _spec, requires_grad" + "Expecting DTensor flattening ctx to be (placements, tensor_meta, shard_order, requires_grad)" ) # Guard on the dtensor spec install_guard( @@ -2548,9 +2592,17 @@ def wrap_tensor(self, value: torch.Tensor) -> VariableTracker: attrs, _ = value.__tensor_flatten__() for attr in attrs: inner_value = getattr(value, attr) + # FakeScriptObject wraps the real opaque object during + # fake-mode tracing; unwrap before the type check. + inner_type = type(inner_value) + if isinstance( + inner_value, + torch._library.fake_class_registry.FakeScriptObject, + ): + inner_type = type(inner_value.real_obj) if not isinstance( inner_value, torch.Tensor - ) and not is_opaque_reference_type(type(inner_value)): + ) and not is_opaque_reference_type(inner_type): raise RuntimeError( f"{type(inner_value).__name__!r} found in tensor attrs of " f"{type(value).__name__}.__tensor_flatten__(). " @@ -2673,6 +2725,7 @@ def wrap_symint( return self.tx.output.unspec_variable_map[self.name] shape_env = self.tx.output.shape_env + frame_state_entry: FrameStateSizeEntry | None = None if TracingContext.get().force_unspec_int_unbacked_size_like: wrapped_value = shape_env.create_unbacked_symint() _constrain_range_for_size(wrapped_value) @@ -2736,11 +2789,20 @@ def wrap_symint( self.install_guards(GuardBuilder.CONSTANT_MATCH) return ConstantVariable.create(value=value) + excluded_scalar = ( + frame_state_entry.excluded_scalar + if config.automatic_dynamic_exclusion_guard + and frame_state_entry is not None + else None + ) wrapped_value = shape_env.create_unspecified_symint_and_symbol( value, source=self.source, dynamic_dim=dynamic_dim, + excluded_value=excluded_scalar, ) + if not isinstance(wrapped_value, SymInt): + raise AssertionError(f"Expected SymInt, got {type(wrapped_value)}") self.tx.output.tracked_fakes.append( TrackedFake(wrapped_value, self.source, context) @@ -2811,7 +2873,7 @@ def wrap_symfloat(self, value: float) -> VariableTracker: # break because they expect all cuda inputs but our tensorified # float will be a f64[] cpu tensor. Fixes the following test # when specialize_float=False - # python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda # noqa: B950 + # python test/inductor/test_compiled_optimizers.py CompiledOptimizerTests.test_rmsprop_weight_decay_maximize_capturable_cuda or torch._inductor.config.triton.cudagraphs or justknobs_check("pytorch/compiler:unspecialize_float_killswitch", False) or ( @@ -3151,7 +3213,7 @@ def wrap_fx_proxy_cls( ) and proxy.node.op != "placeholder" ): - tx.output.current_tracer.record_tensor_or_symint_vt(out) + tx.output.current_tracer.record_proxyable_vt(out) return out @@ -3360,6 +3422,7 @@ def handle_traced_output( # WARNING: this assumes the same target_cls as this tuple/list call unpacked.append( + # pyrefly: ignore [bad-argument-type] wrap_fx_proxy_cls( # pyrefly: ignore[bad-argument-type] target_cls=target_cls, @@ -3378,13 +3441,18 @@ def handle_traced_output( elif istype(example_value, (list, immutable_list)): return ListVariable(unpacked, **options) else: - assert ( - example_value.__class__.__module__ == "torch.return_types" - or hasattr(example_value, "_fields") - ), ( - f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}" + assert is_namedtuple(example_value), ( + f"expected namedtuple or structseq but got {type(example_value)}" + ) + tuple_vt = TupleVariable( + unpacked, + mutation_type=options.get("mutation_type", ValueMutationNew()), + ) + return UserDefinedTupleVariable.get_vt_cls(type(example_value))( + example_value, + tuple_vt=tuple_vt, + **options, # type: ignore[arg-type] ) - return NamedTupleVariable(unpacked, example_value.__class__, **options) # type: ignore[arg-type] elif example_value is None or proxy.node.target is torch.manual_seed: return ConstantVariable.create(None, **options) elif isinstance(example_value, (torch.SymInt, torch.SymFloat, torch.SymBool)): @@ -3483,6 +3551,7 @@ def handle_traced_output( torch._C._get_mem_efficient_sdp_enabled, torch._C._get_math_sdp_enabled, torch._C._get_overrideable_sdp_enabled, + torch._C._is_autocast_available, "is_integer", ] + list(supported_const_comparison_op_values.keys()) @@ -3499,10 +3568,21 @@ def handle_traced_output( elif isinstance(example_value, float) or proxy.node.target in ["hex", "__round__"]: set_example_value(proxy.node, example_value) return ConstantVariable.create(example_value, **options) + elif isinstance(example_value, torch._library.fake_class_registry.FakeScriptObject): + # example_value is already a FakeScriptObject (e.g. returned by getitem + # on a container whose fake kernel returns a FakeScriptObject). No need + # to convert it — just wrap the proxy directly. + return TorchScriptObjectVariable.create( + proxy, + example_value, + ) elif is_opaque_type(type(example_value)): # This is for handling opaque objects in custom ops if is_opaque_value_type(type(example_value)): - proxy = example_value # pyrefly: ignore[bad-assignment] + return TorchScriptObjectVariable.create( + example_value, # pyrefly: ignore[bad-argument-type] + example_value, + ) fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj( tx.output.fake_mode, example_value ) @@ -3597,6 +3677,7 @@ def construct_tensor_variable( if proxy.node.op != "placeholder": tx.output.current_tracer.track_produced_symints(example_value, proxy) options.update(get_specialized_props(target_cls, tx, example_value, subclass_type)) + # pyrefly: ignore [bad-argument-count] return target_cls(proxy, **options) @@ -4004,6 +4085,7 @@ def update_dim2constraint( shape_env_to_source_to_symbol_cache=shape_env_to_source_to_symbol_cache, shape_ids=getattr(e, "_dynamo_shape_ids", None), unbacked_bounds=getattr(e, "_dynamo_unbacked_bounds", None), + excluded_sizes=frame_state_entry.excluded_sizes, ) @@ -4205,13 +4287,15 @@ def create(tx: "InstructionTranslatorBase", value: Any) -> VariableTracker: if isinstance(value, VariableTracker): # This is always valid to call, and useful for recursive calls. return value - elif is_opaque_type(type(value)): + elif is_opaque_value_type(type(value)) and not isinstance(value, enum.Enum): + return TorchScriptObjectVariable.create(value, value) + elif is_opaque_reference_type(type(value)): # This is for handling opaque objects in custom ops fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj( tx.output.fake_mode, value ) return TorchScriptObjectVariable.create( - value, + value, # pyrefly: ignore[bad-argument-type] fake_script_obj, ) # type: ignore[attr-defined] @@ -4230,34 +4314,38 @@ def create(tx: "InstructionTranslatorBase", value: Any) -> VariableTracker: # NamedTuple._make uses an alias of tuple.__new__ # pyrefly: ignore[not-callable, bad-argument-count, missing-attribute] obj = trace_rules.lookup_callable(value.__self__)(value.__self__) - return GetAttrVariable(obj, "__new__") + return GetAttrVariable(obj, "__new__", py_type=type(value)) elif is_function_or_wrapper(value): # pyrefly: ignore[not-callable, bad-argument-count] return trace_rules.lookup(value)(value) elif isinstance( - value, (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType) + value, + (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType), ): - return EnumVariable(value) + return UserDefinedObjectVariable(value) elif isinstance(value, (type, abc.ABCMeta)): - if isinstance(value, type) and issubclass(value, enum.Enum): - return UserDefinedEnumClassVariable(value) + if issubclass(type(value), type) and issubclass(value, BaseException): + return UserDefinedExceptionClassVariable(value) return UserDefinedClassVariable(value) elif isinstance(value, types.MethodWrapperType): return MethodWrapperVariable(value) - elif ( - isinstance(value, types.MethodType) - # We only want to support sourceless class objects here - # An instance variable is not allowed and it should have source - and isinstance(value.__self__, (type, abc.ABCMeta)) - ): - # value is a classmethod - assert getattr(value.__self__, value.__func__.__name__) == value - cls_obj_vt = SourcelessBuilder.create(tx, value.__self__) - try: - # pyrefly: ignore[bad-argument-type] - return cls_obj_vt.var_getattr(tx, value.__func__.__name__) - except NotImplementedError: - pass # failthrough to unimplemented branch + elif isinstance(value, types.MethodType): + if isinstance(value.__self__, (type, abc.ABCMeta)): + # value is a classmethod + assert getattr(value.__self__, value.__func__.__name__) == value + cls_obj_vt = SourcelessBuilder.create(tx, value.__self__) + try: + # pyrefly: ignore[bad-argument-type] + return cls_obj_vt.var_getattr(tx, value.__func__.__name__) + except NotImplementedError: + pass # failthrough to unimplemented branch + else: + # Instance method — look up the VT for __self__ via side effects + obj_vt = tx.output.side_effects.id_to_variable.get(id(value.__self__)) + if obj_vt is not None: + return torch._dynamo.variables.UserMethodVariable( + value.__func__, obj_vt + ) elif isinstance(value, torch.fx.graph_module.GraphModule): return SourcelessGraphModuleVariable(value) elif isinstance(value, torch.utils._pytree.TreeSpec): @@ -4265,7 +4353,19 @@ def create(tx: "InstructionTranslatorBase", value: Any) -> VariableTracker: elif isinstance(value, re.Pattern): return ConstantLikeVariable(value) elif isinstance(value, torch._dynamo.variables.lazy.LazySymNodeFormatString): - return ConstantVariable.create(str(value)) + try: + return ConstantVariable.create(str(value)) + # If we cannot create due to error in str() call, we should + # try explicitly for string format variable + except ( + torch._dynamo.exc.UserError, + torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode, + ): + return StringFormatVariable.create( + value.fmt_var.as_python_constant(), + [value.sym_node_var], + {}, + ) elif isinstance(value, type(torch._higher_order_ops.flex_attention_backward)): return torch._dynamo.variables.higher_order_ops.FlexAttentionBackwardHighOrderVariable( value @@ -4277,13 +4377,22 @@ def create(tx: "InstructionTranslatorBase", value: Any) -> VariableTracker: SourcelessBuilder.create(tx, getattr(value, name)) for name in namedtuple_fields(type(value)) ] - return NamedTupleVariable(output, tuple_cls=type(value)) + tuple_vt = TupleVariable(output, mutation_type=ValueMutationNew()) + return UserDefinedTupleVariable.get_vt_cls(type(value))( + value, tuple_vt=tuple_vt + ) elif ( isinstance(value, torch.SymInt) and value.node.expr in tx.output.bound_symbols ): proxy = tx.output.bound_symbols[value.node.expr] return SymNodeVariable.create(tx, proxy) + elif isinstance(value, slice): + items = [ + SourcelessBuilder.create(tx, getattr(value, k)) + for k in ("start", "stop", "step") + ] + return SliceVariable(items, tx) # pyrefly: ignore[bad-argument-type] elif istype(value, object): return ObjectVariable(value) unimplemented( diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 3bd26c992568c..1bc1d17d13761 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -32,10 +32,9 @@ import unittest from collections import defaultdict, OrderedDict from collections.abc import Callable, Iterable, KeysView, Sequence -from typing import Any, cast, Literal, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING import torch -from torch import sym_float, sym_int from torch._subclasses.meta_utils import is_sparse_any from torch.overrides import BaseTorchFunctionMode from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -65,7 +64,6 @@ check_numpy_ndarray_args, check_unspec_or_constant_args, check_unspec_python_args, - cmp_name_to_op_mapping, dict_methods, extract_fake_example_value, frozenset_methods, @@ -78,26 +76,24 @@ proxy_args_kwargs, raise_args_mismatch, set_methods, + specialize_symnode, str_methods, tensortype_to_dtype, ) -from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker -from .constant import ( - CONSTANT_VARIABLE_FALSE, - CONSTANT_VARIABLE_NONE, - ConstantVariable, - EnumVariable, +from .base import ( + AsPythonConstantNotImplementedError, + NO_SUCH_SUBOBJ, + ValueMutationNew, + VariableTracker, ) +from .constant import ConstantVariable, FakeIdVariable from .dicts import ( ConstDictVariable, - DefaultDictVariable, + DictItemsVariable, DictKeysVariable, DictViewVariable, - FrozensetVariable, - is_hashable, - OrderedSetClassVariable, - SetVariable, ) +from .hashable import is_hashable from .lists import ( BaseListVariable, ListIteratorVariable, @@ -107,7 +103,8 @@ TupleIteratorVariable, TupleVariable, ) -from .streams import EventVariable, StreamVariable +from .misc import NullVariable +from .sets import FrozensetVariable, OrderedSetClassVariable, SetVariable from .tensor import ( FakeItemVariable, supported_comparison_ops, @@ -147,6 +144,14 @@ operator.ixor: operator.xor, } +_BUILTIN_CONSTANT_FOLDABLE_METHODS: dict[type, frozenset[str]] = { + int: frozenset({"__new__", "from_bytes"}), + bool: frozenset({"__new__", "from_bytes"}), + float: frozenset({"fromhex", "hex"}), +} +if sys.version_info >= (3, 14): + _BUILTIN_CONSTANT_FOLDABLE_METHODS[complex] = frozenset({"from_number"}) + _HandlerCallback = Callable[ ["InstructionTranslator", typing.Any, typing.Any], VariableTracker | None @@ -213,6 +218,13 @@ operator.length_hint, ) +_SET_LIKE_OP_SUPPORT: tuple[type[VariableTracker], ...] = ( + DictItemsVariable, + DictKeysVariable, + SetVariable, + UserDefinedObjectVariable, +) + BUILTIN_TO_TENSOR_FN_MAP: dict[Callable[..., Any], Callable[..., Any]] = {} # These functions represent the r* versions of the above ops @@ -288,7 +300,59 @@ def __torch_function__( BUILTIN_TO_TENSOR_RFN_MAP[op] = most_recent_func -class BuiltinVariable(VariableTracker): +class BaseBuiltinVariable(VariableTracker): + """ + Common base class for all builtin variable trackers (BuiltinVariable, + DictBuiltinVariable, IterBuiltinVariable, and future specialized builtins). + + Provides shared implementations for guard installation, hasattr tracing, + and Python-level hashability/equality. + + Specialized subclasses (e.g. DictBuiltinVariable) set `_fn` as a class + attribute. BuiltinVariable stores the callable on the instance as `self.fn` + and overrides as_python_constant / reconstruct / var_getattr accordingly. + """ + + _fn: Any = None + + @classmethod + def create_with_source(cls, value: Any, source: Source) -> "BaseBuiltinVariable": + install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) + return cls(source=source) + + def as_python_constant(self) -> Any: + return self._fn + + def reconstruct(self, codegen: "PyCodegen") -> None: + name = self.as_python_constant().__name__ + assert name not in codegen.tx.f_globals, "shadowed global" + codegen.append_output(codegen.create_load_global(name, add=True)) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + source = self.source and AttrSource(self.source, name) + attr = getattr(self._fn, name, None) + return variables.GetAttrVariable( + self, name, py_type=type(attr) if attr is not None else None, source=source + ) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + return VariableTracker.build(tx, hasattr(self.as_python_constant(), name)) # type: ignore[return-value] + + def is_python_hashable(self) -> bool: + return True + + def get_python_hash(self) -> int: + return hash(self.as_python_constant()) + + def is_python_equal(self, other: object) -> bool: + return isinstance(other, BaseBuiltinVariable) and ( + self.as_python_constant() is other.as_python_constant() # type: ignore[union-attr] + ) + + +class BuiltinVariable(BaseBuiltinVariable): """ A VariableTracker that represents a built-in value (functions and operators). A lot of the code here assumes it will be a function object. @@ -317,17 +381,22 @@ def _constant_fold_functions() -> set[Callable[..., Any]]: abs, all, any, + ascii, + bin, bool, callable, chr, complex, divmod, float, + format, getattr, + hex, int, len, max, min, + oct, ord, pow, repr, @@ -484,15 +553,11 @@ def _binop_handlers() -> dict[ # combinations. Handlers are attempted in order, and will be used if the type checks # match. They are expected to have the signature: # fn(tx, arg0: VariableTracker, arg1: VariableTracker) -> VariableTracker - from .functions import BaseUserFunctionVariable, UserFunctionVariable + from .functions import BaseUserFunctionVariable from .nn_module import NNModuleVariable from .tensor import supported_const_comparison_ops from .torch import BaseTorchVariable - from .user_defined import ( - UserDefinedClassVariable, - UserDefinedObjectVariable, - UserDefinedVariable, - ) + from .user_defined import UserDefinedVariable # Override table contains: op_fn -> [list of handlers] op_handlers: dict[Any, list[Any]] = {} @@ -708,7 +773,7 @@ def expand_list_like( raise_observed_exception( type(exc), tx, - args=[VariableTracker.build(tx, a) for a in exc.args], + args=list(exc.args), ) list_like_expansion_handlers: list[ @@ -736,7 +801,7 @@ def compare_by_value( raise_observed_exception( type(exc), tx, - args=[VariableTracker.build(tx, a) for a in exc.args], + args=list(exc.args), ) result: list[ @@ -755,7 +820,6 @@ def compare_by_value( # time benchmark - add_loop_eager. result = [ ((ConstantVariable, ConstantVariable), compare_by_value), - ((EnumVariable, EnumVariable), compare_by_value), ] op_var = BuiltinVariable(op) @@ -825,41 +889,6 @@ def never( op_var = BuiltinVariable(op) result.extend( [ - ( - ( - (UserFunctionVariable, BuiltinVariable), - (UserFunctionVariable, BuiltinVariable), - ), - lambda tx, a, b: VariableTracker.build(tx, op(a.fn, b.fn)), - ), - ( - ( - NNModuleVariable, - NNModuleVariable, - ), - lambda tx, a, b: VariableTracker.build( - tx, - op( - tx.output.get_submodule(a.module_key), - tx.output.get_submodule(b.module_key), - ), - ), - ), - ( - (UserDefinedObjectVariable, UserDefinedObjectVariable), - compare_by_value, - ), - ( - (UserDefinedClassVariable, UserDefinedClassVariable), - compare_by_value, - ), - ( - ( - (StreamVariable, EventVariable, ConstantVariable), - (StreamVariable, EventVariable, ConstantVariable), - ), - compare_by_value, - ), ( (TensorVariable, VariableTracker), op_var._comparison_with_tensor, @@ -884,28 +913,15 @@ def handle_is( left: VariableTracker, right: VariableTracker, ) -> VariableTracker | None: - # If the two objects are of different type, we can safely return False - # and True for `is` and `is not`, respectively - if type(left) is not type(right): - return VariableTracker.build(tx, op.__name__ != "is_") - if left is right: - return VariableTracker.build(tx, op(left, right)) - # VT identity is a reliable proxy for Python identity for - # mutable containers created during tracing. For types - # like EnumVariable two distinct VTs can wrap the same - # singleton, so we must not claim "is False" there. - if isinstance(left, (ConstDictVariable, ListVariable)): - return VariableTracker.build(tx, op(left, right)) - if istype(left, variables.ObjectVariable) and istype( - right, variables.ObjectVariable - ): - return VariableTracker.build(tx, op(left.value, right.value)) - if ( - istype(left, variables.ExceptionVariable) - and istype(right, variables.ExceptionVariable) - and left.exc_type is not right.exc_type - ): - return VariableTracker.build(tx, op(left, right)) + from .object_protocol import vt_identity_compare + + result = vt_identity_compare(left, right) + if result is None: + return None + is_same = result.as_python_constant() + return VariableTracker.build( + tx, is_same if op.__name__ == "is_" else not is_same + ) result.append(((VariableTracker, VariableTracker), handle_is)) # type: ignore[arg-type] @@ -935,7 +951,15 @@ def _find_binop_handler( def can_insert_in_graph(self) -> bool: return self.fn in self._fx_graph_functions() + # Builtins that have been promoted to their own VT classes. Creating a + # BuiltinVariable for these is a bug; use the specialized class instead. + MUST_USE_SPECIALIZED: frozenset[Any] = frozenset({dict, getattr, iter, list}) + def __init__(self, fn: Any, **kwargs: Any) -> None: + assert fn not in self.MUST_USE_SPECIALIZED, ( + f"Use the specialized VT class for {fn!r}, not BuiltinVariable. " + f"E.g. DictBuiltinVariable for dict." + ) super().__init__(**kwargs) self.fn = fn @@ -950,6 +974,9 @@ def __repr__(self) -> str: def as_python_constant(self) -> Any: return self.fn + def get_real_python_backed_value(self) -> Any: + return self.fn + def as_proxy(self) -> Any: DTYPE = { bool: torch.bool, @@ -1021,7 +1048,7 @@ def _make_handler( ) -> Callable[ [ "InstructionTranslator", - tuple[VariableTracker, ...], + list[VariableTracker], dict[str, VariableTracker], ], VariableTracker | None, @@ -1046,7 +1073,7 @@ def _make_handler( def create_exception_class_object( tx: "InstructionTranslator", - args: tuple[VariableTracker, ...], + args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: if fn is AssertionError and not all( @@ -1118,7 +1145,7 @@ def call_self_handler( except TypeError as e: has_constant_handler = obj.has_constant_handler(args, kwargs) if not has_constant_handler: - log.warning( # noqa: G200 + log.warning( "incorrect arg count %s %s and no constant handler", self_handler, e, @@ -1161,7 +1188,7 @@ def constant_fold_handler( raise_observed_exception( type(exc), tx, - args=[VariableTracker.build(tx, a) for a in exc.args], + args=list(exc.args), ) except AsPythonConstantNotImplementedError as exc: unimplemented( @@ -1201,7 +1228,7 @@ def constant_fold_handler( raise_observed_exception( type(exc), tx, - args=[VariableTracker.build(tx, a) for a in exc.args], + args=list(exc.args), ) return VariableTracker.build(tx, res) return None @@ -1259,12 +1286,7 @@ def builtin_dispatch( def call_vars(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker: if len(args) == 0: - unimplemented( - gb_type="unimplemented builtin op vars() with no arguments", - context=f"vars: {self} {args}", - explanation=f"Dynamo does not know how to trace builtin operator {self.fn} with no arguments", - hints=[*graph_break_hints.SUPPORTABLE], - ) + return self._call_frame_locals_snapshot(tx) assert len(args) == 1 # vars(obj) is obj.__dict__ if __dict__ is present else TypeError try: @@ -1272,6 +1294,34 @@ def call_vars(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker: except ObservedAttributeError: raise_observed_exception(TypeError, tx) + def call_locals( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + if len(args) != 0: + raise_observed_exception(TypeError, tx) + return self._call_frame_locals_snapshot(tx) + + @staticmethod + def _call_frame_locals_snapshot(tx: "InstructionTranslator") -> VariableTracker: + frame_local_names = set(tx.f_code.co_varnames) | set(tx.cell_and_freevars()) + cell_and_freevars = set(tx.cell_and_freevars()) + frame_locals = {} + for name, value in tx.symbolic_locals.items(): + if name not in frame_local_names: + continue + if name in cell_and_freevars: + value = tx.output.side_effects.load_cell(value) + if type.__instancecheck__(NullVariable, value) or isinstance( + value, variables.DeletedVariable + ): + continue + frame_locals[ConstantVariable.create(name)] = value + return ConstDictVariable( + frame_locals, + dict, + mutation_type=ValueMutationNew(), + ) + def _handle_insert_op_in_graph( self, tx: "InstructionTranslator", @@ -1476,19 +1526,6 @@ def call_method( self, args[0], args[1:] ) - if self.fn is dict and len(args) == 1 and not kwargs: - dict_vt = ConstDictVariable({}, dict, mutation_type=ValueMutationNew()) - if isinstance(args[0], BuiltinVariable) and args[0].fn is dict: - return dict_vt - # We don't have to set the underlying dict_vt in - # UserDefinedDictVariable because it will be set to empty - # ConstDictVariableTracker in the constructor. - return tx.output.side_effects.track_new_user_defined_object( - self, - args[0], - args[1:], - ) - if ( self.fn is tuple and len(args) == 2 @@ -1507,56 +1544,34 @@ def call_method( args[1:], ) - if self.fn is list: - list_vt = ListVariable([], mutation_type=ValueMutationNew()) - if isinstance(args[0], BuiltinVariable) and args[0].fn is list: - return list_vt - return tx.output.side_effects.track_new_user_defined_object( - self, - args[0], - args[1:], - ) - - if ( - self.fn in (float, complex) - and len(args) == 1 - and ( - (self.fn is float and name in ("fromhex", "hex")) - or (name == "from_number" and sys.version_info >= (3, 14)) - ) - ): - if args[0].is_python_constant(): + if name in _BUILTIN_CONSTANT_FOLDABLE_METHODS.get(self.fn, ()): + if all(a.is_python_constant() for a in args) and all( + v.is_python_constant() for v in kwargs.values() + ): try: fn = getattr(self.fn, name) - res = fn(args[0].as_python_constant()) + res = fn( + *(a.as_python_constant() for a in args), + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) return VariableTracker.build(tx, res) - except (OverflowError, ValueError) as e: + except Exception as e: raise_observed_exception( type(e), tx, - args=[VariableTracker.build(tx, a) for a in e.args], + args=list(e.args), ) if self.fn is object and name == "__init__": # object.__init__ is a no-op - return variables.CONSTANT_VARIABLE_NONE - - if self.fn is dict and name == "fromkeys": - return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) - - if self.fn is dict: - resolved_fn = getattr(self.fn, name) - if resolved_fn in dict_methods: - if isinstance(args[0], variables.UserDefinedDictVariable): - return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs) - elif isinstance(args[0], variables.ConstDictVariable): - return args[0].call_method(tx, name, args[1:], kwargs) + return variables.ConstantVariable.create(None) if self.fn is set: resolved_fn = getattr(self.fn, name) if resolved_fn in set_methods: if isinstance(args[0], variables.UserDefinedSetVariable): - return args[0]._set_vt.call_method(tx, name, args[1:], kwargs) + assert args[0]._base_vt is not None + return args[0]._base_vt.call_method(tx, name, args[1:], kwargs) elif isinstance(args[0], variables.SetVariable): return args[0].call_method(tx, name, args[1:], kwargs) @@ -1580,69 +1595,64 @@ def call_method( tx, getattr(float, name)(args[0].as_python_constant()) ) + if name == "__len__" and len(args) == 1 and not kwargs: + # type.__len__(instance) → len(instance) + # e.g. list.__len__(my_list) → len(my_list) + from .object_protocol import generic_len + + return generic_len(tx, args[0]) + return super().call_method(tx, name, args, kwargs) - def _call_int_float( + def call_int( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker | None: - # Handle cases like int(torch.seed()) - # Also handle sym_float to sym_int cases - if arg.is_tensor() or isinstance(arg, SymNodeVariable): - if arg.is_tensor(): - item = arg.call_method(tx, "item", [], {}) - else: - item = arg - fn_ = sym_int if self.fn is int else sym_float - from torch._dynamo.variables.builder import wrap_fx_proxy + from .object_protocol import generic_int - return wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( - "call_function", - fn_, - (item.as_proxy(),), - {}, - ), - ) - return None + return generic_int(tx, arg) - call_int = _call_int_float - call_float = _call_int_float + def call_float( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker | None: + from .object_protocol import generic_float + + return generic_float(tx, arg) def call_bool( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker | None: - if arg.is_tensor(): - item = arg.call_method(tx, "item", [], {}) - if isinstance(item, SymNodeVariable) and isinstance( - item.sym_num, torch.SymBool - ): - return item - if isinstance(item, variables.ConstantVariable): - return VariableTracker.build(tx, bool(item.value)) - return SymNodeVariable.create(tx, item.as_proxy() != 0) - # Emulate `PyBool_Type.tp_vectorcall` which boils down to `PyObject_IsTrue`. - # https://github.com/python/cpython/blob/3.12/Objects/object.c#L1674-L1697 - if isinstance(arg, SymNodeVariable): - # Note that we delay specializing on symbolic values to avoid - # unnecessary guards. Specialization will happen later if, e.g., the - # resulting boolean is used for branching. - if isinstance(arg.sym_num, torch.SymBool): - return arg - - # Emulate `nb_bool` of int/float objects - # - https://github.com/python/cpython/blob/3.12/Objects/longobject.c#L4940-L4944 - # - https://github.com/python/cpython/blob/3.12/Objects/floatobject.c#L878-L882 - assert istype(arg.sym_num, (torch.SymInt, torch.SymFloat)) - return SymNodeVariable.create(tx, arg.as_proxy() != 0) - - # TODO handle more cases and merge this with this with `generic_jump`. - return None + # Emulate PyBool_Type.tp_vectorcall which boils down to PyObject_IsTrue. + from .object_protocol import generic_bool + + return generic_bool(tx, arg) def call_repr( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker | None: """Handle repr() on user defined objects.""" + if isinstance( + arg, + (variables.ExceptionVariable, variables.UserDefinedExceptionObjectVariable), + ): + try: + const_args = tuple(a.as_python_constant() for a in arg.args) + except NotImplementedError: + return None + if len(const_args) == 0: + value = f"{arg.exc_type.__name__}()" + elif len(const_args) == 1: + value = f"{arg.exc_type.__name__}({const_args[0]!r})" + else: + value = f"{arg.exc_type.__name__}{const_args!r}" + return VariableTracker.build(tx, value) + if isinstance(arg, variables.UserDefinedDictVariable): + assert arg._base_vt is not None + try: + return VariableTracker.build( + tx, repr(arg._base_vt.as_python_constant()) + ) + except Exception: + pass if isinstance(arg, variables.UserDefinedObjectVariable): repr_method = arg.value.__repr__ @@ -1669,7 +1679,7 @@ def call_repr( ( RangeVariable, ConstDictVariable, - DefaultDictVariable, + variables.DefaultDictVariable, OrderedSetClassVariable, DictViewVariable, ), @@ -1680,6 +1690,18 @@ def call_repr( def call_str( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker | None: + if isinstance( + arg, + (variables.ExceptionVariable, variables.UserDefinedExceptionObjectVariable), + ): + if len(arg.args) == 0: + return VariableTracker.build(tx, "") + elif len(arg.args) == 1: + return BuiltinVariable(str).call_function(tx, [arg.args[0]], {}) + else: + tuple_var = variables.TupleVariable(list(arg.args)) + return BuiltinVariable(str).call_function(tx, [tuple_var], {}) + # Handle `str` on a user defined function or object if isinstance(arg, (variables.UserFunctionVariable)): return VariableTracker.build(tx, str(arg.fn)) @@ -1728,12 +1750,6 @@ def call_str( # Inline the user function return user_func_variable.call_function(tx, [arg], {}) - elif isinstance(arg, (variables.ExceptionVariable,)): - if len(arg.args) == 0: - value = f"{arg.exc_type}" - else: - value = ", ".join(a.as_python_constant() for a in arg.args) - return VariableTracker.build(tx, value) return None def _call_min_max( @@ -1893,17 +1909,10 @@ def call_pos( def call_index( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker: - if arg.is_tensor(): - unimplemented( - gb_type="unsupported index(Tensor)", - context="", - explanation="Dynamo does not support tracing builtin index() on a Tensor", - hints=[], - ) - - arg = guard_if_dyn(arg) - constant_value = operator.index(arg) - return VariableTracker.build(tx, constant_value) + # Specialize SymNodeVariable to a constant first, matching CPython's + # PyNumber_Index which forces a concrete int. + arg = specialize_symnode(arg) + return arg.nb_index_impl(tx) def call_round( self, @@ -1979,10 +1988,13 @@ def _call_iter_tuple_list( obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN) ) else: - if ( - getattr(obj, "source", False) - and isinstance(obj, ConstDictVariable) - and not istype(obj, (SetVariable, FrozensetVariable)) + if getattr(obj, "source", False) and isinstance( + obj, + ( + ConstDictVariable, + variables.OrderedSetVariable, + variables.DictKeySetVariable, + ), ): tx.output.guard_on_key_order.add(obj.source) @@ -2039,48 +2051,7 @@ def _call_tuple_list( else: return self._call_iter_tuple_list(tx, obj, *args, **kwargs) - def call_iter( - self, - tx: "InstructionTranslator", - obj: VariableTracker, - *args: VariableTracker, - **kwargs: VariableTracker, - ) -> VariableTracker: - # avoid the overhead of tracing the polyfill if we already know the class implemented __iter__ - if isinstance( - obj, - ( - variables.ListVariable, - variables.RangeVariable, - variables.IteratorVariable, - variables.ConstDictVariable, - variables.NNModuleVariable, - variables.TensorVariable, - variables.TupleVariable, - DictViewVariable, - ), - ): - return obj.call_method(tx, "__iter__", [], {}) - else: - # If the object doesn't implement a __iter__ method, it will be an error in eager mode when calling iter on it anyway. - # If the object implements a __iter__ method, inlining effectively forwards the call to another iter call - # (e.g. when __iter__ just returns iter(self.list)) or return a user-defined iterator. - # If the object implements a __getitem__ method, iter(...) will call obj.__getitem__() - # with an integer argument starting at 0, until __getitem__ raises IndexError - ret = variables.UserFunctionVariable( - polyfills.builtins.iter_ # type: ignore[arg-type] - ).call_function(tx, [obj, *args], {}) - - if args: - # iter(obj, sentinel) returns an object that implements - # __iter__ and __next__ methods (UserDefinedObjectVariable) - # Wrap the return value in a IteratorVariable subclass (LazyObjectIteratorVariable) - # that forwards the next_variable call to the object. - ret = variables.ObjectIteratorVariable(ret) - return ret - call_tuple = _call_tuple_list - call_list = _call_tuple_list def call_callable( self, tx: "InstructionTranslator", arg: VariableTracker @@ -2097,7 +2068,7 @@ def call_callable( NNModuleVariable, ), ): - return variables.CONSTANT_VARIABLE_TRUE + return variables.ConstantVariable.create(True) elif isinstance(arg, UserDefinedVariable): return VariableTracker.build(tx, callable(arg.value)) elif isinstance( @@ -2111,7 +2082,7 @@ def call_callable( ListIteratorVariable, ), ): - return variables.CONSTANT_VARIABLE_FALSE + return variables.ConstantVariable.create(False) else: return None @@ -2135,218 +2106,95 @@ def call_dir( return VariableTracker.build(tx, dir(arg.value)) if isinstance(arg, BuiltinVariable): return VariableTracker.build(tx, dir(arg.fn)) + # Enable specialized VTs for constants to work with dir() + if arg.is_python_constant(): + return VariableTracker.build(tx, dir(arg.as_python_constant())) return None - def call_dict( + def call_set( self, tx: "InstructionTranslator", - /, *args: VariableTracker, **kwargs: VariableTracker, ) -> VariableTracker: - return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) + from .builder import SourcelessBuilder - @staticmethod - def call_custom_dict( + assert not kwargs + if not args: + return SetVariable([], mutation_type=ValueMutationNew()) + if len(args) != 1: + raise_observed_exception( + TypeError, + tx, + args=[f"set() takes 1 positional argument but {len(args)} were given"], + ) + arg = args[0] + if istype(arg, variables.SetVariable): + return arg.clone(mutation_type=ValueMutationNew()) + elif arg.has_force_unpack_var_sequence(tx): + items = arg.force_unpack_var_sequence(tx) + return SetVariable(items, mutation_type=ValueMutationNew()) + elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( + arg.value, KeysView + ): + iter_fn = arg.var_getattr(tx, "__iter__") + if isinstance(iter_fn, variables.UserMethodVariable): + out = tx.inline_user_function_return(iter_fn, args, kwargs) + if isinstance(out, SetVariable): + return out + return SourcelessBuilder.create(tx, set).call_set(tx, out) + raise_observed_exception( + TypeError, + tx, + args=["failed to construct builtin set()"], + ) + + def call_frozenset( + self, tx: "InstructionTranslator", - user_cls: type, - /, *args: VariableTracker, **kwargs: VariableTracker, ) -> VariableTracker: - args_list = list(args) - if ( - len(args_list) == 1 - and isinstance(args_list[0], variables.GetAttrVariable) - and isinstance(args_list[0].obj, variables.UserDefinedClassVariable) - and not tx.output.side_effects.has_pending_mutation(args_list[0].obj) - ): - # Forward the GetAttrVariable(foo, "__dict__") to a realized vt of - # VT(foo.__dict__). This simplifies the construction of the new - # dict. - args_list[0] = args_list[0].get_forwarded_dict(tx) - return tx.inline_user_function_return( - VariableTracker.build(tx, polyfills.construct_dict), - [VariableTracker.build(tx, user_cls), *args_list], - kwargs, + assert not kwargs + if not args: + return FrozensetVariable([]) + if len(args) != 1: + raise_observed_exception( + TypeError, + tx, + args=[ + f"frozenset() takes 1 positional argument but {len(args)} were given" + ], + ) + arg = args[0] + if istype(arg, variables.FrozensetVariable): + return FrozensetVariable([x.vt for x in arg.set_items]) + elif arg.has_force_unpack_var_sequence(tx): + items = arg.force_unpack_var_sequence(tx) + return FrozensetVariable(items) + raise_observed_exception( + TypeError, + tx, + args=["failed to construct builtin frozenset()"], ) - @staticmethod - def call_custom_dict_fromkeys( + def call_zip( + self, tx: "InstructionTranslator", - user_cls: type, - /, *args: VariableTracker, **kwargs: VariableTracker, ) -> VariableTracker: - if user_cls not in {dict, OrderedDict, defaultdict}: - unimplemented( - gb_type="Unsupported dict type for fromkeys()", - context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}", - explanation=f"Failed to call {user_cls.__name__}.fromkeys() because " - f"{user_cls.__name__} is not any type of dict, OrderedDict, or defaultdict", - hints=[ - f"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict.", - ], - ) + from .builder import SourcelessBuilder + if kwargs: - # Only `OrderedDict.fromkeys` accepts `value` passed by keyword - if ( - user_cls is not OrderedDict - or len(args) != 1 - or len(kwargs) != 1 - or "value" not in kwargs - ): + if not (len(kwargs) == 1 and "strict" in kwargs): raise_args_mismatch( tx, - f"{user_cls.__name__}.fromkeys", - "1 args and 1 kwargs (`value`)", - f"{len(args)} args and {len(kwargs)} kwargs", + "zip", + "1 kwargs (`strict`)", + f"{len(kwargs)} kwargs", ) - args = (*args, kwargs.pop("value")) - if len(args) == 0: - raise_args_mismatch( - tx, - f"{user_cls.__name__}.fromkeys", - "at least 1 args", - f"{len(args)} args", - ) - if len(args) == 1: - args = (*args, CONSTANT_VARIABLE_NONE) - if len(args) != 2: - raise_args_mismatch( - tx, - f"{user_cls.__name__}.fromkeys", - "2 args", - f"{len(args)} args", - ) - - arg, value = args - DictVariableType = ( - ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable - ) - - if isinstance(arg, dict): - arg_list = [VariableTracker.build(tx, k) for k in arg] - return DictVariableType( - dict.fromkeys(arg_list, value), - user_cls, - mutation_type=ValueMutationNew(), - ) - elif arg.has_force_unpack_var_sequence(tx): - keys = arg.force_unpack_var_sequence(tx) - if all(is_hashable(v) for v in keys): - return DictVariableType( - dict.fromkeys(keys, value), - user_cls, - mutation_type=ValueMutationNew(), - ) - - unimplemented( - gb_type="failed to call dict.fromkeys()", - context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}", - explanation=f"Failed to call {user_cls.__name__}.fromkeys() because " - "arguments could not be automatically converted to a list, " - "or some dict key is not hashable.", - hints=[ - "Manually convert the argument to a list.", - "Ensure all keys are hashable.", - ], - ) - - def call_set( - self, - tx: "InstructionTranslator", - *args: VariableTracker, - **kwargs: VariableTracker, - ) -> VariableTracker: - from .builder import SourcelessBuilder - - # Can we merge this implementation and call_dict's one? - assert not kwargs - if not args: - return SetVariable([], mutation_type=ValueMutationNew()) - if len(args) != 1: - raise_observed_exception( - TypeError, - tx, - args=[ - VariableTracker.build( - tx, - f"set() takes 1 positional argument but {len(args)} were given", - ) - ], - ) - arg = args[0] - if istype(arg, variables.SetVariable): - return arg.clone(mutation_type=ValueMutationNew()) - elif arg.has_force_unpack_var_sequence(tx): - items = arg.force_unpack_var_sequence(tx) - return SetVariable(items, mutation_type=ValueMutationNew()) - elif isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( - arg.value, KeysView - ): - iter_fn = arg.var_getattr(tx, "__iter__") - if isinstance(iter_fn, variables.UserMethodVariable): - out = tx.inline_user_function_return(iter_fn, args, kwargs) - if isinstance(out, SetVariable): - return out - return SourcelessBuilder.create(tx, set).call_set(tx, out) - raise_observed_exception( - TypeError, - tx, - args=[VariableTracker.build(tx, "failed to construct builtin set()")], - ) - - def call_frozenset( - self, - tx: "InstructionTranslator", - *args: VariableTracker, - **kwargs: VariableTracker, - ) -> VariableTracker: - assert not kwargs - if not args: - return FrozensetVariable([]) - if len(args) != 1: - raise_observed_exception( - TypeError, - tx, - args=[ - VariableTracker.build( - tx, - f"frozenset() takes 1 positional argument but {len(args)} were given", - ) - ], - ) - arg = args[0] - if istype(arg, variables.FrozensetVariable): - return FrozensetVariable([x.vt for x in arg.set_items]) - elif arg.has_force_unpack_var_sequence(tx): - items = arg.force_unpack_var_sequence(tx) - return FrozensetVariable(items) - raise_observed_exception( - TypeError, - tx, - args=[VariableTracker.build(tx, "failed to construct builtin frozenset()")], - ) - - def call_zip( - self, - tx: "InstructionTranslator", - *args: VariableTracker, - **kwargs: VariableTracker, - ) -> VariableTracker: - from .builder import SourcelessBuilder - - if kwargs: - if not (len(kwargs) == 1 and "strict" in kwargs): - raise_args_mismatch( - tx, - "zip", - "1 kwargs (`strict`)", - f"{len(kwargs)} kwargs", - ) - strict = kwargs.pop("strict", CONSTANT_VARIABLE_FALSE) + strict = kwargs.pop("strict", ConstantVariable.create(False)) iter_args = [ SourcelessBuilder.create(tx, iter).call_function(tx, [arg], {}) for arg in args @@ -2363,10 +2211,9 @@ def call_len( *args: VariableTracker, **kwargs: VariableTracker, ) -> VariableTracker: - try: - return args[0].call_method(tx, "__len__", list(args[1:]), kwargs) - except AttributeError as e: - raise_observed_exception(type(e), tx, args=list(e.args)) + from .object_protocol import generic_len + + return generic_len(tx, args[0]) def call_getitem( self, @@ -2374,7 +2221,9 @@ def call_getitem( *args: VariableTracker, **kwargs: VariableTracker, ) -> VariableTracker: - return args[0].call_method(tx, "__getitem__", list(args[1:]), kwargs) + from .object_protocol import vt_getitem + + return vt_getitem(tx, args[0], args[1]) def call_isinstance( self, @@ -2461,13 +2310,12 @@ def check_type(ty: Any) -> bool: ): isinstance_type_tuple = isinstance_type else: - msg = VariableTracker.build( - tx, "isinstance() arg 2 must be a type, a tuple of types, or a union" - ) raise_observed_exception( TypeError, tx, - args=[msg], + args=[ + "isinstance() arg 2 must be a type, a tuple of types, or a union" + ], ) try: @@ -2535,8 +2383,6 @@ def call_hasattr( ) -> VariableTracker | None: if attr.is_python_constant(): name = attr.as_python_constant() - if isinstance(obj, variables.BuiltinVariable): - return VariableTracker.build(tx, hasattr(obj.fn, name)) return obj.call_obj_hasattr(tx, name) return None @@ -2547,7 +2393,7 @@ def call_map( *seqs: VariableTracker, **kwargs: VariableTracker, ) -> VariableTracker: - strict = CONSTANT_VARIABLE_FALSE + strict = ConstantVariable.create(False) if kwargs: if sys.version_info >= (3, 14): if not (len(kwargs) == 1 and "strict" in kwargs): @@ -2557,7 +2403,7 @@ def call_map( "1 kwargs (`strict`)", f"{len(kwargs)} kwargs", ) - strict = kwargs.pop("strict", CONSTANT_VARIABLE_FALSE) + strict = kwargs.pop("strict", ConstantVariable.create(False)) else: raise_args_mismatch( tx, @@ -2591,6 +2437,8 @@ def call_filter( def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: source = self.source and AttrSource(self.source, name) + if name == "__name__": + return VariableTracker.build(tx, self.fn.__name__, source) if self.fn is object: # for object, we can just directly read the attribute try: @@ -2599,200 +2447,36 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker raise_observed_exception(AttributeError, tx) if not callable(value): return VariableTracker.build(tx, value, source) - return variables.GetAttrVariable(self, name, source=source) + attr = getattr(self.fn, name, None) + return variables.GetAttrVariable( + self, name, py_type=type(attr) if attr is not None else None, source=source + ) - def call_getattr( + def call_setattr( self, tx: "InstructionTranslator", obj: VariableTracker, name_var: VariableTracker, - default: VariableTracker | None = None, + val: VariableTracker, ) -> VariableTracker | None: - if not name_var.is_python_constant(): - unimplemented( - gb_type="getattr() with non-constant name argument", - context=f"getattr({obj}, {name_var}, {default})", - explanation="getattr() with non-constant name argument is not supported", - hints=["Ensure the name argument of getattr() is a string"], - ) - - name = name_var.as_python_constant() - - # See NOTE [Tensor "grad" and "_grad" attr] - if obj.is_tensor() and name == "_grad": - name = "grad" - - if tx.output.side_effects.is_attribute_mutation(obj): - if isinstance(obj, variables.UnspecializedNNModuleVariable): - if ( - name - in ( - "named_parameters", - "parameters", - "named_buffers", - "buffers", - "named_modules", - "modules", - ) - and obj.is_state_mutated - and tx.output.side_effects.has_pending_mutation(obj) - ): - unimplemented( - gb_type="getattr() on nn.Module with pending mutation", - context=f"getattr({obj}, {name}, {default})", - explanation="Intentionally graph breaking on getattr() on a nn.Module " - "with a pending mutation", - hints=[], - ) - - if tx.output.side_effects.has_pending_mutation_of_attr(obj, name): - return tx.output.side_effects.load_attr(obj, name) - - if default is not None: - hasattr_var = self.call_hasattr(tx, obj, name_var) - if hasattr_var is not None: - assert hasattr_var.is_constant_match(True, False) - if not hasattr_var.as_python_constant(): - return default - else: - return default - - source = obj.source and AttrSource(obj.source, name) - if name in {"__bases__", "__base__", "__flags__"}: - try: - value = obj.as_python_constant() - if isinstance(value, type): - if name == "__bases__": - tuple_args = [ - VariableTracker.build( - tx, b, source and GetItemSource(source, i) - ) - for i, b in enumerate(value.__bases__) - ] - return variables.TupleVariable(tuple_args, source=source) - if name == "__base__": - return VariableTracker.build(tx, value.__base__, source) - if name == "__flags__": - return VariableTracker.build(tx, value.__flags__) - except NotImplementedError: - pass - - if isinstance(obj, variables.NNModuleVariable): - return obj.var_getattr(tx, name) - elif isinstance( + if isinstance( obj, ( - variables.TensorVariable, - variables.NamedTupleVariable, - variables.ConstantVariable, variables.DefaultDictVariable, - variables.DistributedVariable, - variables.UserDefinedClassVariable, variables.UserDefinedObjectVariable, + variables.NestedUserFunctionVariable, + variables.ExceptionVariable, + variables.TracebackVariable, ), ): - if ( - isinstance(obj, variables.UserDefinedObjectVariable) - and issubclass(obj.value.__class__, unittest.TestCase) - and config.enable_trace_unittest - and name - in ( - "assertRaisesRegex", - "assertNotWarns", - "assertWarnsRegex", - "assertWarns", - ) - ): - unimplemented( - gb_type="Failed to trace unittest method", - context=f"function: unittest.TestCase.{name}", - explanation=f"Dynamo does not know how to trace unittest method `{name}` ", - hints=[ - f"Avoid calling `TestCase.{name}`. " - "Please report an issue to PyTorch.", - ], - ) + return obj.call_method(tx, "__setattr__", [name_var, val], {}) + elif ( + tx.output.side_effects.is_attribute_mutation(obj) + and name_var.is_python_constant() + ): + name = name_var.as_python_constant() if obj.is_tensor(): - # pyrefly: ignore[missing-attribute] - fake_val = obj.as_proxy().node.meta["example_value"] - if ( - isinstance(fake_val, torch.Tensor) - and is_sparse_any(fake_val) - and (not tx.export or not config.capture_sparse_compute) - ): - unimplemented( - gb_type="Attempted to wrap sparse Tensor", - context="", - explanation="torch.compile does not support sparse Tensors", - hints=[*graph_break_hints.SPARSE_TENSOR], - ) - - try: - return obj.var_getattr(tx, name) - except AsPythonConstantNotImplementedError: - # dont fallback on as_python_constant error because this leads - # to a failure later on, and leads to a wrong stacktrace - raise - except NotImplementedError: - return variables.GetAttrVariable(obj, name, source=source) - elif isinstance(obj, variables.TorchInGraphFunctionVariable): - # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default. - try: - member = getattr(obj.value, name) - except AttributeError: - raise_observed_exception(AttributeError, tx) - raise - - if isinstance( - member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) - ) and torch._dynamo.trace_rules.is_aten_op_or_tensor_method(member): - return variables.TorchInGraphFunctionVariable(member, source=source) - elif name in cmp_name_to_op_mapping: - return variables.GetAttrVariable(obj, name, source=source) - else: - return None - elif isinstance(obj, DummyModule): - # TODO(mlazos) - Do we need this? - if obj.is_torch or name not in obj.value.__dict__: - member = getattr(obj.value, name) - else: - member = obj.value.__dict__[name] - - if config.replay_record_enabled: - tx.exec_recorder.record_module_access(obj.value, name, member) # type: ignore[arg-type, union-attr] - return VariableTracker.build(tx, member, source) - else: - try: - return obj.var_getattr(tx, name) - except NotImplementedError: - return variables.GetAttrVariable(obj, name, source=source) - - def call_setattr( - self, - tx: "InstructionTranslator", - obj: VariableTracker, - name_var: VariableTracker, - val: VariableTracker, - ) -> VariableTracker | None: - if isinstance( - obj, - ( - variables.DefaultDictVariable, - variables.NamedTupleVariable, - variables.UserDefinedObjectVariable, - variables.NestedUserFunctionVariable, - variables.ExceptionVariable, - variables.TracebackVariable, - ), - ): - return obj.call_method(tx, "__setattr__", [name_var, val], {}) - elif ( - tx.output.side_effects.is_attribute_mutation(obj) - and name_var.is_python_constant() - ): - name = name_var.as_python_constant() - if obj.is_tensor(): - from .builder import wrap_fx_proxy + from .builder import wrap_fx_proxy # Some special handling for tensor attributes. if name == "requires_grad": @@ -3031,40 +2715,29 @@ def call_id( nn_mod_variable = args[0] mod = tx.output.get_submodule(nn_mod_variable.module_key) return VariableTracker.build(tx, id(mod)) - elif len(args) == 1 and isinstance( - args[0], - (variables.UserDefinedClassVariable, variables.UserDefinedObjectVariable), - ): - if args[0].source: - if isinstance(args[0], variables.UserDefinedClassVariable): - install_guard(args[0].source.make_guard(GuardBuilder.CLASS_MATCH)) - else: - install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH)) - constant_result = id(args[0].value) - return VariableTracker.build(tx, constant_result) elif len(args) == 1 and args[0].is_tensor(): tensor_variable = cast(TensorVariable, args[0]) return tensor_variable.call_id(tx) - elif istype(args[0], variables.UserFunctionVariable): - return VariableTracker.build(tx, id(args[0].fn)) - elif istype(args[0], variables.SkipFunctionVariable): - return VariableTracker.build(tx, id(args[0].value)) elif istype(args[0], variables.FunctoolsPartialVariable): return VariableTracker.build(tx, id(args[0].fake_value)) - elif isinstance( - args[0], - ( - ConstantVariable, - ConstDictVariable, - ListVariable, - TupleVariable, - SetVariable, - SymNodeVariable, - ), - ): - from .constant import FakeIdVariable - - return FakeIdVariable(id(args[0])) + elif len(args) == 1: + arg = args[0] + if isinstance( + arg, + ( + variables.UserDefinedClassVariable, + variables.UserDefinedObjectVariable, + ), + ): + if arg.source: + if isinstance(arg, variables.UserDefinedClassVariable): + install_guard(arg.source.make_guard(GuardBuilder.CLASS_MATCH)) + else: + install_guard(arg.source.make_guard(GuardBuilder.ID_MATCH)) + real_val = arg.get_real_python_backed_value() + if real_val is not NO_SUCH_SUBOBJ: + return VariableTracker.build(tx, id(real_val)) + return FakeIdVariable(id(arg)) else: unimplemented( gb_type="id() with unsupported args", @@ -3161,8 +2834,15 @@ def _comparison_with_symnode( hints=[*graph_break_hints.SUPPORTABLE], ) - # This is seen in inspect signature where we check if the value is a default value - if isinstance(right, variables.UserDefinedClassVariable): + # SymNodes are numeric (int/float/bool). The non-SymNode operand + # must be a type that can participate in a traced numeric comparison. + # Anything else (classes, DataPtrVariable, etc.) is a different type + # entirely — the comparison result is known at compile time. + non_symnode = right if isinstance(left, SymNodeVariable) else left + if not isinstance( + non_symnode, (SymNodeVariable, ConstantVariable, TensorVariable) + ): + # pyrefly: ignore [bad-argument-type] return VariableTracker.build(tx, op(object(), None)) proxy = tx.output.create_proxy( @@ -3189,31 +2869,28 @@ def call_xor( sym_num=None, ) - if isinstance( - a, - (DictKeysVariable, SetVariable, UserDefinedObjectVariable), - ): + if isinstance(a, _SET_LIKE_OP_SUPPORT): return a.call_method(tx, "__xor__", [b], {}) return None def call_ixor( self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker ) -> VariableTracker | None: - if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + if isinstance(a, _SET_LIKE_OP_SUPPORT): return a.call_method(tx, "__ixor__", [b], {}) return None def call_sub( self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker ) -> VariableTracker | None: - if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + if isinstance(a, _SET_LIKE_OP_SUPPORT): return a.call_method(tx, "__sub__", [b], {}) return None def call_isub( self, tx: "InstructionTranslator", a: VariableTracker, b: VariableTracker ) -> VariableTracker | None: - if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + if isinstance(a, _SET_LIKE_OP_SUPPORT): return a.call_method(tx, "__isub__", [b], {}) return None @@ -3231,7 +2908,7 @@ def call_and_( ), sym_num=None, ) - if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + if isinstance(a, _SET_LIKE_OP_SUPPORT): return a.call_method(tx, "__and__", [b], {}) # None no-ops this handler and lets the driving function proceed return None @@ -3250,7 +2927,7 @@ def call_iand( ), sym_num=None, ) - if isinstance(a, (DictKeysVariable, SetVariable, UserDefinedObjectVariable)): + if isinstance(a, _SET_LIKE_OP_SUPPORT): return a.call_method(tx, "__iand__", [b], {}) return None @@ -3286,12 +2963,10 @@ def call_or_( if isinstance( a, ( + *_SET_LIKE_OP_SUPPORT, ConstDictVariable, - DictKeysVariable, MutableMappingVariable, - SetVariable, UserDefinedDictVariable, - UserDefinedObjectVariable, ), ): # TODO(guilhermeleobas): forward the call to b.__ror__(a) if @@ -3320,11 +2995,9 @@ def call_ior( if isinstance( a, ( + *_SET_LIKE_OP_SUPPORT, ConstDictVariable, - DictKeysVariable, MutableMappingVariable, - SetVariable, - UserDefinedObjectVariable, ), ): return a.call_method(tx, "__ior__", [b], {}) @@ -3347,8 +3020,11 @@ def call_not_( # Unwrap the underlying ConstDictVariable if isinstance(a, DictViewVariable): a = a.dv_dict - if isinstance(a, (ListVariable, ConstDictVariable)): + if isinstance(a, (ListVariable, ConstDictVariable, SetVariable)): return VariableTracker.build(tx, len(a.items) == 0) + if isinstance(a, UserDefinedObjectVariable): + bool_result = self.call_bool(tx, a) + return VariableTracker.build(tx, not bool_result.value) # type: ignore[missing-attribute] return None @@ -3357,16 +3033,550 @@ def call_contains( ) -> VariableTracker: return a.call_method(tx, "__contains__", [b], {}) - def is_python_hashable(self) -> Literal[True]: - return True - - def get_python_hash(self) -> int: - return hash(self.fn) - def is_python_equal(self, other: object) -> bool: return isinstance(other, variables.BuiltinVariable) and self.fn is other.fn +class DictBuiltinVariable(BaseBuiltinVariable): + """Variable tracker for the `dict` builtin constructor.""" + + _fn = dict + + def __init__(self, value: type = dict, **kwargs: Any) -> None: + assert value is dict + super().__init__(**kwargs) + + def __repr__(self) -> str: + return "DictBuiltinVariable()" + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + return DictBuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__new__": + if args: + # dict.__new__ (tp_new) ignores extra args — only the first + # arg (the type) matters. Pass init_args=[] so reconstruction + # emits base_cls.__new__(cls) without extras. + # https://github.com/python/cpython/blob/v3.13.0/Objects/dictobject.c#L4735-L4768 + dict_vt = ConstDictVariable({}, dict, mutation_type=ValueMutationNew()) + if isinstance(args[0], DictBuiltinVariable): + return dict_vt + return tx.output.side_effects.track_new_user_defined_object( + self, args[0], [] + ) + + if name == "fromkeys": + return DictBuiltinVariable.call_custom_dict_fromkeys( + tx, dict, *args, **kwargs + ) + + resolved_fn = getattr(dict, name, None) + if resolved_fn is not None and resolved_fn in dict_methods: + if isinstance(args[0], variables.UserDefinedDictVariable): + assert args[0]._base_vt is not None + return args[0]._base_vt.call_method(tx, name, args[1:], kwargs) + elif isinstance(args[0], ConstDictVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + return super().call_method(tx, name, args, kwargs) + + @staticmethod + def call_custom_dict( + tx: "InstructionTranslator", + user_cls: type, + /, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + args_list = list(args) + return tx.inline_user_function_return( + VariableTracker.build(tx, polyfills.construct_dict), + [VariableTracker.build(tx, user_cls), *args_list], + kwargs, + ) + + @staticmethod + def call_custom_dict_fromkeys( + tx: "InstructionTranslator", + user_cls: type, + /, + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + if user_cls not in {dict, OrderedDict, defaultdict}: + unimplemented( + gb_type="Unsupported dict type for fromkeys()", + context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}", + explanation=f"Failed to call {user_cls.__name__}.fromkeys() because " + f"{user_cls.__name__} is not any type of dict, OrderedDict, or defaultdict", + hints=[ + f"Ensure {user_cls.__name__} is a type of dict, OrderedDict, or defaultdict.", + ], + ) + if kwargs: + # Only `OrderedDict.fromkeys` accepts `value` passed by keyword + if ( + user_cls is not OrderedDict + or len(args) != 1 + or len(kwargs) != 1 + or "value" not in kwargs + ): + raise_args_mismatch( + tx, + f"{user_cls.__name__}.fromkeys", + "1 args and 1 kwargs (`value`)", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + args = (*args, kwargs.pop("value")) + if len(args) == 0: + raise_args_mismatch( + tx, + f"{user_cls.__name__}.fromkeys", + "at least 1 args", + f"{len(args)} args", + ) + if len(args) == 1: + args = (*args, ConstantVariable.create(None)) + if len(args) != 2: + raise_args_mismatch( + tx, + f"{user_cls.__name__}.fromkeys", + "2 args", + f"{len(args)} args", + ) + + arg, value = args + + def _make_result( + items: dict[VariableTracker, VariableTracker], + ) -> VariableTracker: + if user_cls is OrderedDict: + from .builder import SourcelessBuilder + from .user_defined import OrderedDictVariable + + result = tx.output.side_effects.track_new_user_defined_object( + SourcelessBuilder.create(tx, dict), + SourcelessBuilder.create(tx, OrderedDict), + [], + ) + assert isinstance(result, OrderedDictVariable) + result._base_vt = ConstDictVariable( + items, + user_cls=OrderedDict, + mutation_type=ValueMutationNew(), + ) + return result + elif user_cls is defaultdict: + from .builder import SourcelessBuilder + from .user_defined import DefaultDictVariable + + result = tx.output.side_effects.track_new_user_defined_object( + SourcelessBuilder.create(tx, dict), + SourcelessBuilder.create(tx, defaultdict), + [], + ) + assert isinstance(result, DefaultDictVariable) + result._base_vt = ConstDictVariable( + items, mutation_type=ValueMutationNew() + ) + return result + else: + return ConstDictVariable(items, mutation_type=ValueMutationNew()) + + if isinstance(arg, dict): + arg_list = [VariableTracker.build(tx, k) for k in arg] + return _make_result(dict.fromkeys(arg_list, value)) + elif arg.has_force_unpack_var_sequence(tx): + keys = arg.force_unpack_var_sequence(tx) + if all(is_hashable(v) for v in keys): + return _make_result(dict.fromkeys(keys, value)) + + unimplemented( + gb_type="failed to call dict.fromkeys()", + context=f"{user_cls.__name__}.fromkeys(): {args} {kwargs}", + explanation=f"Failed to call {user_cls.__name__}.fromkeys() because " + "arguments could not be automatically converted to a list, " + "or some dict key is not hashable.", + hints=[ + "Manually convert the argument to a list.", + "Ensure all keys are hashable.", + ], + ) + + +class IterBuiltinVariable(BaseBuiltinVariable): + """Variable tracker for the `iter` builtin.""" + + _fn = iter + + def __init__(self, value: Any = iter, **kwargs: Any) -> None: + assert value is iter + super().__init__(**kwargs) + + def __repr__(self) -> str: + return "IterBuiltinVariable()" + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if not args: + unimplemented( + gb_type="iter() with no arguments", + context="iter()", + explanation="iter() requires at least one argument", + hints=[*graph_break_hints.USER_ERROR], + ) + obj, *rest = args + + # Fast path: for known iterable VT types, call __iter__ directly + # instead of going through the polyfill, saving tracing overhead. + if ( + not rest + and not kwargs + and isinstance( + obj, + ( + variables.ListVariable, + variables.RangeVariable, + variables.IteratorVariable, + variables.ConstDictVariable, + variables.NNModuleVariable, + variables.TensorVariable, + variables.TupleVariable, + variables.UserDefinedClassVariable, + DictViewVariable, + ), + ) + ): + return obj.call_method(tx, "__iter__", [], {}) + + # General case: inline the polyfill which handles __iter__ and __getitem__ + ret = variables.UserFunctionVariable( + polyfills.builtins.iter_ # type: ignore[arg-type] + ).call_function(tx, [obj, *rest], {}) + + if rest: + # iter(obj, sentinel) returns a callable iterator; wrap it so + # Dynamo knows to forward __next__ calls to the returned object. + ret = variables.ObjectIteratorVariable(ret) + return ret + + +class GetAttrBuiltinVariable(BaseBuiltinVariable): + """Variable tracker for the `getattr` builtin.""" + + _fn = getattr + + def __init__(self, value: Any = getattr, **kwargs: Any) -> None: + assert value is getattr + super().__init__(**kwargs) + + def __repr__(self) -> str: + return "GetAttrBuiltinVariable()" + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .lazy import LazyVariableTracker + + if any(isinstance(a, LazyVariableTracker) for a in args): + args = [ + a.realize() if isinstance(a, LazyVariableTracker) else a for a in args + ] + try: + return self._call_getattr(tx, args, kwargs) + except Unsupported: + # Replicate the constant-fold fallback from BuiltinVariable._make_handler: + # if all args are python constants, evaluate getattr() directly rather + # than propagating a graph break from var_getattr. + if not check_unspec_or_constant_args(args, kwargs): + raise + try: + result = getattr(*[a.as_python_constant() for a in args]) + except AttributeError: + raise_observed_exception(AttributeError, tx) + raise + except AsPythonConstantNotImplementedError: + raise + except Exception as exc: + raise_observed_exception(type(exc), tx, args=list(exc.args)) + raise + return VariableTracker.build(tx, result) + + def _call_getattr( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + obj = args[0] + name_var = args[1] + default = args[2] if len(args) > 2 else None + + if not name_var.is_python_constant(): + unimplemented( + gb_type="getattr() with non-constant name argument", + context=f"getattr({obj}, {name_var}, {default})", + explanation="getattr() with non-constant name argument is not supported", + hints=["Ensure the name argument of getattr() is a string"], + ) + + name = name_var.as_python_constant() + + # See NOTE [Tensor "grad" and "_grad" attr] + if obj.is_tensor() and name == "_grad": + name = "grad" + + if tx.output.side_effects.is_attribute_mutation(obj): + if isinstance(obj, variables.UnspecializedNNModuleVariable): + if ( + name + in ( + "named_parameters", + "parameters", + "named_buffers", + "buffers", + "named_modules", + "modules", + ) + and obj.is_state_mutated + and tx.output.side_effects.has_pending_mutation(obj) + ): + unimplemented( + gb_type="getattr() on nn.Module with pending mutation", + context=f"getattr({obj}, {name}, {default})", + explanation="Intentionally graph breaking on getattr() on a nn.Module " + "with a pending mutation", + hints=[], + ) + + if tx.output.side_effects.has_pending_mutation_of_attr(obj, name): + return tx.output.side_effects.load_attr(obj, name) + + if default is not None: + hasattr_var = obj.call_obj_hasattr(tx, name) + if hasattr_var is not None: + assert hasattr_var.is_constant_match(True, False) + if not hasattr_var.as_python_constant(): + return default + else: + return default + + source = obj.source and AttrSource(obj.source, name) + if name in {"__bases__", "__base__", "__flags__"}: + try: + value = obj.as_python_constant() + if isinstance(value, type): + if name == "__bases__": + tuple_args = [ + VariableTracker.build( + tx, b, source and GetItemSource(source, i) + ) + for i, b in enumerate(value.__bases__) + ] + return variables.TupleVariable(tuple_args, source=source) + if name == "__base__": + return VariableTracker.build(tx, value.__base__, source) + if name == "__flags__": + return VariableTracker.build(tx, value.__flags__) + except NotImplementedError: + pass + + if isinstance(obj, variables.NNModuleVariable): + return obj.var_getattr(tx, name) + elif isinstance( + obj, + ( + variables.TensorVariable, + variables.NamedTupleVariable, + variables.ConstantVariable, + variables.DefaultDictVariable, + variables.DistributedVariable, + variables.UserDefinedClassVariable, + variables.UserDefinedObjectVariable, + ), + ): + if ( + isinstance(obj, variables.UserDefinedObjectVariable) + and issubclass(obj.value.__class__, unittest.TestCase) + and config.enable_trace_unittest + and name + in ( + "assertRaisesRegex", + "assertNotWarns", + "assertWarnsRegex", + "assertWarns", + ) + ): + unimplemented( + gb_type="Failed to trace unittest method", + context=f"function: unittest.TestCase.{name}", + explanation=f"Dynamo does not know how to trace unittest method `{name}` ", + hints=[ + f"Avoid calling `TestCase.{name}`. " + "Please report an issue to PyTorch.", + ], + ) + if obj.is_tensor(): + # pyrefly: ignore[missing-attribute] + fake_val = obj.as_proxy().node.meta["example_value"] + if ( + isinstance(fake_val, torch.Tensor) + and is_sparse_any(fake_val) + and (not tx.export or not config.capture_sparse_compute) + ): + unimplemented( + gb_type="Attempted to wrap sparse Tensor", + context="", + explanation="torch.compile does not support sparse Tensors", + hints=[*graph_break_hints.SPARSE_TENSOR], + ) + + try: + return obj.var_getattr(tx, name) + except AsPythonConstantNotImplementedError: + # dont fallback on as_python_constant error because this leads + # to a failure later on, and leads to a wrong stacktrace + raise + except NotImplementedError: + return variables.GetAttrVariable(obj, name, source=source) + elif isinstance(obj, variables.TorchInGraphFunctionVariable): + # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default. + try: + member = getattr(obj.value, name) + except AttributeError: + raise_observed_exception(AttributeError, tx) + raise + + if isinstance( + member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) + ) and torch._dynamo.trace_rules.is_aten_op_or_tensor_method(member): + return variables.TorchInGraphFunctionVariable(member, source=source) + else: + return variables.GetAttrVariable(obj, name, source=source) + elif isinstance(obj, DummyModule): + # TODO(mlazos) - Do we need this? + if obj.is_torch or name not in obj.value.__dict__: + member = getattr(obj.value, name) + else: + member = obj.value.__dict__[name] + + if config.replay_record_enabled: + tx.exec_recorder.record_module_access(obj.value, name, member) # type: ignore[arg-type, union-attr] + return VariableTracker.build(tx, member, source) + else: + try: + return obj.var_getattr(tx, name) + except NotImplementedError: + return variables.GetAttrVariable(obj, name, source=source) + + +class ListBuiltinVariable(BaseBuiltinVariable): + """Variable tracker for the `list` builtin constructor.""" + + _fn = list + + def __init__(self, value: type = list, **kwargs: Any) -> None: + assert value is list + super().__init__(**kwargs) + + def __repr__(self) -> str: + return "ListBuiltinVariable()" + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .user_defined import UserDefinedObjectVariable + + obj = args[0] if args else None + + if isinstance( + obj, (variables.IteratorVariable, variables.LocalGeneratorObjectVariable) + ) or ( + isinstance(obj, UserDefinedObjectVariable) + and obj.has_force_unpack_var_sequence(tx) + ): + return ListVariable( + list(obj.force_unpack_var_sequence(tx)), + mutation_type=ValueMutationNew(), + ) + + if obj is None: + return ListVariable([], mutation_type=ValueMutationNew()) + + if obj.has_unpack_var_sequence(tx): + if obj.source and not is_constant_source(obj.source): + if isinstance(obj, TupleIteratorVariable): + install_guard( + obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN) + ) + else: + if isinstance(obj, ConstDictVariable) and not istype( + obj, (SetVariable, FrozensetVariable) + ): + tx.output.guard_on_key_order.add(obj.source) + if isinstance(obj, variables.MappingProxyVariable): + install_guard( + obj.source.make_guard(GuardBuilder.MAPPING_KEYS_CHECK) + ) + elif not isinstance(obj, variables.UnspecializedNNModuleVariable): + install_guard( + obj.source.make_guard(GuardBuilder.SEQUENCE_LENGTH) + ) + return ListVariable( + list(obj.unpack_var_sequence(tx)), + mutation_type=ValueMutationNew(), + ) + + arg_types = [type(a).__name__ for a in args] + unimplemented( + gb_type="Failed to trace list()", + context=f"list({arg_types})", + explanation=f"Dynamo does not know how to construct a list from argument types {arg_types}", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name == "__new__": + if len(args) == 1 and not kwargs: + list_vt = ListVariable([], mutation_type=ValueMutationNew()) + if isinstance(args[0], ListBuiltinVariable): + return list_vt + return tx.output.side_effects.track_new_user_defined_object( + self, args[0], args[1:] + ) + + return super().call_method(tx, name, args, kwargs) + + +# pyrefly: ignore [deprecated] @contextlib.contextmanager def dynamo_disable_grad(tx: "InstructionTranslator") -> typing.Iterator[None]: from . import GradModeVariable diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index 1e0cd1cba61bc..c8061f4e7336a 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -1,34 +1,29 @@ """ -Constant and enum variable tracking in Dynamo. +Constant variable tracking in Dynamo. This module is fundamental to Dynamo's ability to track and propagate constant values during compilation, ensuring proper handling of Python literals and maintaining type safety through the compilation process. """ -import enum +from __future__ import annotations + import operator -from collections.abc import Sequence -from typing import Any, Literal, Optional, overload, TYPE_CHECKING -from typing_extensions import Never, override +from typing import Any, Literal, overload, TYPE_CHECKING +from typing_extensions import override import torch -from torch._dynamo.source import AttrSource, GetItemSource +from torch._dynamo.source import GetItemSource -from .. import graph_break_hints, variables +from .. import variables from ..exc import raise_observed_exception, unimplemented -from ..utils import ( - cmp_name_to_op_mapping, - common_constant_types, - istype, - np, - raise_args_mismatch, - raise_on_overridden_hash, -) +from ..utils import common_constant_types, istype, np, raise_args_mismatch from .base import ValueMutationNew, VariableTracker if TYPE_CHECKING: + from collections.abc import Sequence + from torch._dynamo.symbolic_convert import InstructionTranslator from .functions import UserFunctionVariable @@ -43,23 +38,23 @@ class ConstantVariable(VariableTracker): nested collections. """ - @overload - @staticmethod - def create(value: None) -> Never: ... - - @overload - @staticmethod - def create(value: Literal[True]) -> Never: ... + # PyLong_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/longobject.c#L6585 + # PyFloat_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/floatobject.c#L1880 + # PyBool_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/boolobject.c#L171 + # PyUnicode_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/unicodeobject.c#L14931 + # PyBytes_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/bytesobject.c#L3017 + # PyComplex_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/complexobject.c#L1099 + # _PyNone_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/object.c#L2022 + _cpython_type = (int, float, str, bytes, bool, type(None), complex, type(...)) @overload @staticmethod - def create(value: Literal[False]) -> Never: ... + def create(value: None) -> ConstantVariable: ... @overload @staticmethod - def create(value: bool) -> "ConstantVariable": ... + def create(value: bool) -> ConstantVariable: ... - # TODO: Refactor to make these return ConstantVariable @overload @staticmethod def create(value: Any, **kwargs: Any) -> VariableTracker: ... @@ -74,6 +69,17 @@ def create(value: Any, **kwargs: Any) -> VariableTracker: NOTE: the caller must install the proper guards if needed; most often the guard will be `CONSTANT_MATCH`. """ + # Return pre-allocated sentinels for None/True/False when there are + # no extra kwargs (source, etc.) that would differentiate the instance. + if not kwargs: + match value: + case None: + return CONSTANT_VARIABLE_NONE + case True: + return CONSTANT_VARIABLE_TRUE + case False: + return CONSTANT_VARIABLE_FALSE + source = kwargs.get("source") # Routing for supported collection literals. @@ -147,7 +153,7 @@ def items(self) -> list[VariableTracker]: return self.unpack_var_sequence(tx=None) def getitem_const( - self, tx: "InstructionTranslator", arg: VariableTracker + self, tx: InstructionTranslator, arg: VariableTracker ) -> VariableTracker: return ConstantVariable.create( self.value[arg.as_python_constant()], @@ -170,17 +176,31 @@ def is_literal(obj: object, cache: dict[int, object] | None = None) -> bool: return ConstantVariable.is_base_literal(obj) def unpack_var_sequence( - self, tx: Optional["InstructionTranslator"] + self, tx: InstructionTranslator | None ) -> list[VariableTracker]: try: return [ConstantVariable.create(x) for x in self.as_python_constant()] except TypeError as e: raise NotImplementedError from e - def const_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + def len_impl(self, tx: InstructionTranslator) -> VariableTracker: + """Generic len for any constant value (sequence or mapping).""" + try: + return ConstantVariable.create(len(self.value)) + except TypeError as e: + raise_observed_exception(type(e), tx, args=list(e.args)) + + def sq_length(self, tx: InstructionTranslator) -> VariableTracker: + """Sequence length - delegates to len_impl for constants.""" + return self.len_impl(tx) + + def mp_length(self, tx: InstructionTranslator) -> VariableTracker: + """Mapping length - delegates to len_impl for constants.""" + return self.len_impl(tx) + + def const_getattr(self, tx: InstructionTranslator, name: str) -> VariableTracker: if not hasattr(self.value, name): - name_variable = variables.ConstantVariable.create(name) - raise_observed_exception(AttributeError, tx, args=[name_variable]) + raise_observed_exception(AttributeError, tx, args=[name]) member = getattr(self.value, name) if callable(member): raise NotImplementedError @@ -188,7 +208,7 @@ def const_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTrack def call_method( self, - tx: "InstructionTranslator", + tx: InstructionTranslator, name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], @@ -248,7 +268,7 @@ def call_method( raise_observed_exception( type(exc), tx, - args=list(map(ConstantVariable.create, exc.args)), + args=list(exc.args), ) if ( hasattr(operator, name) @@ -269,9 +289,7 @@ def call_method( try: return ConstantVariable.create(op(self.value, add_target)) except Exception as e: - raise_observed_exception( - type(e), tx, args=list(map(ConstantVariable.create, e.args)) - ) + raise_observed_exception(type(e), tx, args=list(e.args)) elif isinstance(self.value, bytes) and name == "decode": method = getattr(self.value, name) return ConstantVariable.create(method(*const_args, **const_kwargs)) @@ -282,20 +300,13 @@ def call_method( except Exception as e: raise_observed_exception(type(e), tx) - if name == "__len__" and not (args or kwargs): - try: - return ConstantVariable.create(len(self.value)) - except TypeError as e: - raise_observed_exception(type(e), tx, args=list(e.args)) - elif name == "__round__" and len(args) == 1 and args[0].is_python_constant(): + if name == "__round__" and len(args) == 1 and args[0].is_python_constant(): try: return ConstantVariable.create( round(self.value, args[0].as_python_constant()) ) except Exception as e: - raise_observed_exception( - type(e), tx, args=list(map(ConstantVariable.create, e.args)) - ) + raise_observed_exception(type(e), tx, args=list(e.args)) elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant(): assert not kwargs search = args[0].as_python_constant() @@ -303,15 +314,13 @@ def call_method( result = search in self.value return ConstantVariable.create(result) except TypeError as e: - raise_observed_exception( - type(e), tx, args=list(map(ConstantVariable.create, e.args)) - ) + raise_observed_exception(type(e), tx, args=list(e.args)) return super().call_method(tx, name, args, kwargs) def call_tree_map( self, - tx: "InstructionTranslator", - tree_map_fn: "UserFunctionVariable", + tx: InstructionTranslator, + tree_map_fn: UserFunctionVariable, map_fn: VariableTracker, rest: Sequence[VariableTracker], tree_map_kwargs: dict[str, VariableTracker], @@ -362,8 +371,8 @@ def call_tree_map( @override def call_obj_hasattr( - self, tx: "InstructionTranslator", name: str - ) -> "ConstantVariable": + self, tx: InstructionTranslator, name: str + ) -> ConstantVariable: result = hasattr(self.value, name) return variables.ConstantVariable.create(result) @@ -374,7 +383,6 @@ def get_python_hash(self) -> int: return hash(self.value) def is_python_equal(self, other: object) -> bool: - # Could be an EnumVariable as well from .tensor import SymNodeVariable if isinstance(other, SymNodeVariable): @@ -384,6 +392,37 @@ def is_python_equal(self, other: object) -> bool: and self.as_python_constant() == other.as_python_constant() ) + def get_real_python_backed_value(self) -> object: + return self.value + + def nb_index_impl( + self, + tx: Any, + ) -> VariableTracker: + # CPython: int and bool define nb_index (returns self for int, + # int(self) for bool). All other constant types do not. + if isinstance(self.value, (int, bool)): + return ConstantVariable.create(operator.index(self.value)) + return super().nb_index_impl(tx) + + def nb_int_impl( + self, + tx: Any, + ) -> VariableTracker: + # CPython: int defines nb_int (long_long, returns copy). + # bool inherits nb_int from int via slot inheritance. + # float defines nb_int (truncates toward zero via PyLong_FromDouble). + return ConstantVariable.create(int(self.value)) + + def nb_float_impl( + self, + tx: Any, + ) -> VariableTracker: + # CPython: float defines nb_float (float_float, returns copy). + # int defines nb_float (long_float, converts to float). + # bool inherits nb_float from int via slot inheritance. + return ConstantVariable.create(float(self.value)) + CONSTANT_VARIABLE_NONE = ConstantVariable(None) CONSTANT_VARIABLE_TRUE = ConstantVariable(True) @@ -402,6 +441,9 @@ class FakeIdVariable(VariableTracker): graph break does not silently bake a stale id into the resumed bytecode. """ + # PyLong_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/longobject.c#L6585 + _cpython_type = int + def __init__(self, value: int, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value @@ -436,68 +478,5 @@ def reconstruct(self, codegen: Any) -> None: ), hints=[ "Avoid using id() on containers in code that may graph-break.", - *graph_break_hints.SUPPORTABLE, ], ) - - -class EnumVariable(VariableTracker): - """VariableTracker for enum.Enum and enum.IntEnum instances - - Provides specialized handling for Python enum types, supporting - both standard Enum and IntEnum with proper value tracking and comparison. - """ - - def __init__(self, value: enum.Enum | enum.IntEnum, **kwargs: Any) -> None: - super().__init__(**kwargs) - self.value = value - - @classmethod - def create( - cls, cls_type: Any, value_vt: VariableTracker, options: Any - ) -> "EnumVariable": - if value_vt.is_python_constant(): - for member in list(cls_type): - if member.value == value_vt.as_python_constant(): - return cls(member, **options) - unimplemented( - gb_type="Failed to construct Enum variable", - context=f"value: {value_vt}, allowed enum values: {list(cls_type)}", - explanation="Attempted to construct an Enum value that is non-constant (e.g. int, string) " - "or is not an acceptable value for the Enum. " - f"Acceptable values for Enum `{cls_type}`: {list(cls_type)}.", - hints=[*graph_break_hints.USER_ERROR, *graph_break_hints.SUPPORTABLE], - ) - - def as_proxy(self) -> enum.Enum | int: - if isinstance(self.value, int): - return int(self.value) # convert IntEnum to a normal int - return self.value - - def __repr__(self) -> str: - return f"EnumVariable({type(self.value)})" - - def as_python_constant(self) -> enum.Enum | enum.IntEnum: - return self.value - - def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: - if not hasattr(self.value, name): - raise NotImplementedError - if name in cmp_name_to_op_mapping: - return variables.GetAttrVariable(self, name) - member = getattr(self.value, name) - source = self.source and AttrSource(self.source, name) - return VariableTracker.build(tx, member, source=source) - - def is_python_hashable(self) -> Literal[True]: - raise_on_overridden_hash(self.value, self) - return True - - def get_python_hash(self) -> int: - return hash(self.as_python_constant()) - - def is_python_equal(self, other: object) -> bool: - return ( - isinstance(other, VariableTracker) - and self.as_python_constant() == other.as_python_constant() - ) diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 756a052d88a85..f1e9af5254466 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -18,8 +18,10 @@ restoring state changes. """ +import contextlib import inspect import logging +import types import warnings from collections.abc import Callable, Sequence, Sized from contextlib import AbstractContextManager, ExitStack @@ -77,7 +79,7 @@ def enter(self, tx: "InstructionTranslator") -> VariableTracker: if hasattr(self, "_call_func"): self._call_func(tx, self.target_values) self.set_cleanup_hook(tx) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def set_cleanup_hook( self, tx: "InstructionTranslator", fn: Callable[..., Any] | None = None @@ -95,7 +97,7 @@ def exit( self, tx: "InstructionTranslator", *args: VariableTracker ) -> VariableTracker: self.cleanup_assert() - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def reconstruct_type(self, codegen: "PyCodegen") -> None: codegen( @@ -271,7 +273,7 @@ def enter(self, tx: "InstructionTranslator") -> VariableTracker: (enabled,), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def exit( self, tx: "InstructionTranslator", *args: VariableTracker @@ -283,7 +285,7 @@ def exit( (self.prev_state,), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable): @@ -311,7 +313,7 @@ def enter(self, tx: "InstructionTranslator") -> VariableTracker: (), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def exit( self, tx: "InstructionTranslator", *args: VariableTracker @@ -323,7 +325,7 @@ def exit( (self.proxy,), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable): @@ -355,7 +357,7 @@ def enter(self, tx: "InstructionTranslator") -> VariableTracker: ) self.proxy = tx.output.create_node( "call_function", - torch._C._functorch._jvp_increment_nesting, + torch._functorch.predispatch._jvp_increment_nesting, (), {}, ) @@ -366,9 +368,9 @@ def exit( ) -> VariableTracker: self.cleanup() tx.output.create_node( - "call_function", torch._C._functorch._jvp_decrement_nesting, (), {} + "call_function", torch._functorch.predispatch._jvp_decrement_nesting, (), {} ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) class SetFwdGradEnabledContextManager(ContextWrappingVariable): @@ -398,7 +400,7 @@ def enter(self, tx: "InstructionTranslator") -> VariableTracker: (mode,), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def exit( self, tx: "InstructionTranslator", *args: VariableTracker @@ -410,7 +412,10 @@ def exit( (self.prev_state,), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) + + def python_type(self) -> type: + return torch.autograd.forward_ad._set_fwd_grad_enabled class DualLevelContextManager(ContextWrappingVariable): @@ -434,7 +439,7 @@ def enter(self, tx: "InstructionTranslator") -> VariableTracker: ) self.proxy = tx.output.create_node( "call_function", - torch._C._enter_dual_level, + torch._functorch.predispatch._enter_dual_level, (), {}, ) @@ -446,11 +451,14 @@ def exit( self.cleanup() tx.output.create_node( "call_function", - torch._C._exit_dual_level, - (self.new_level,), - {}, + torch._functorch.predispatch._exit_dual_level, + (), + {"level": self.new_level}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) + + def python_type(self) -> type: + return torch.autograd.forward_ad.dual_level class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable): @@ -493,7 +501,7 @@ def exit( tx.output.create_node( "call_function", torch._C._functorch._grad_decrement_nesting, (), {} ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) class CatchWarningsCtxManagerVariable(ContextWrappingVariable): @@ -538,6 +546,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: keys = tuple(self.catch_warnings_args.keys()) codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, False)) + def python_type(self) -> type: + return warnings.catch_warnings + class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable): """represents torch VMap increment/decrement nesting""" @@ -592,7 +603,7 @@ def exit( (), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) class GradModeVariable(ContextWrappingVariable): @@ -630,13 +641,13 @@ def __init__( def enter(self, tx: "InstructionTranslator") -> VariableTracker: self._call_func(tx, self.target_values) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def exit( self, tx: "InstructionTranslator", *args: VariableTracker ) -> VariableTracker: self._call_func(tx, self.initial_values) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def call_function( self, @@ -663,6 +674,9 @@ def module_name(self) -> str: def fn_name(self) -> str: return "set_grad_enabled" + def python_type(self) -> type: + return torch.set_grad_enabled + class InferenceModeVariable(ContextWrappingVariable): @staticmethod @@ -697,7 +711,7 @@ def exit( (self.proxy,), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def enter(self, tx: "InstructionTranslator") -> VariableTracker: disabled_inference_mode_forcibly = False @@ -727,7 +741,7 @@ def cleanup_hook() -> None: (*self.target_values,), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def module_name(self) -> str: return "torch" @@ -735,6 +749,9 @@ def module_name(self) -> str: def fn_name(self) -> str: return "inference_mode" + def python_type(self) -> type: + return torch.inference_mode + class CUDADeviceVariable(ContextWrappingVariable): """represents torch.cuda.device""" @@ -770,7 +787,7 @@ def exit( (self.proxy,), {}, ) - return variables.CONSTANT_VARIABLE_FALSE + return variables.ConstantVariable.create(False) def enter(self, tx: "InstructionTranslator") -> VariableTracker: prev_idx = torch.cuda._exchange_device(*self.target_values) @@ -781,7 +798,7 @@ def enter(self, tx: "InstructionTranslator") -> VariableTracker: (*self.target_values,), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def module_name(self) -> str: return "torch.cuda" @@ -789,6 +806,9 @@ def module_name(self) -> str: def fn_name(self) -> str: return "device" + def python_type(self) -> type: + return torch.cuda.device + class TorchFunctionDisableVariable(ContextWrappingVariable): """represents whether torch function overrides are enabled or not""" @@ -863,6 +883,11 @@ def fn_name(self) -> str: return "DisableTorchFunctionSubclass" return "DisableTorchFunction" + def python_type(self) -> type: + if self.only_subclass: + return torch._C.DisableTorchFunctionSubclass # pyrefly: ignore[bad-return] + return torch._C.DisableTorchFunction # pyrefly: ignore[bad-return] + class DisabledSavedTensorsHooksVariable(ContextWrappingVariable): """represents torch.autograd.graph.disable_saved_tensors_hook.""" @@ -893,7 +918,7 @@ def __init__( ) def enter(self, tx: "InstructionTranslator") -> VariableTracker: - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def _call_func( self, tx: "InstructionTranslator", values: Sequence[str | None] @@ -924,6 +949,9 @@ def module_name(self) -> str: def fn_name(self) -> str: return "disable_saved_tensors_hooks" + def python_type(self) -> type: + return contextlib._GeneratorContextManager + class AutocastModeVariable(ContextWrappingVariable): @staticmethod @@ -980,7 +1008,7 @@ def exit( tx.output.create_node( "call_function", torch.amp._exit_autocast, (self.proxy,), {} ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def enter(self, tx: "InstructionTranslator") -> VariableTracker: ctx = torch.amp._enter_autocast(*self.target_values) @@ -988,7 +1016,7 @@ def enter(self, tx: "InstructionTranslator") -> VariableTracker: self.proxy = tx.output.create_node( "call_function", torch.amp._enter_autocast, (*self.target_values,), {} ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def module_name(self) -> str: return "torch.amp.autocast_mode" @@ -996,6 +1024,9 @@ def module_name(self) -> str: def fn_name(self) -> str: return "autocast" + def python_type(self) -> type: + return torch.amp.autocast_mode.autocast + class NullContextVariable(ContextWrappingVariable): """ @@ -1006,13 +1037,13 @@ def __init__(self, target_values: Any | None = None, **kwargs: Any) -> None: super().__init__(target_values=target_values, **kwargs) def enter(self, tx: "InstructionTranslator") -> VariableTracker: - none = variables.CONSTANT_VARIABLE_NONE + none = variables.ConstantVariable.create(None) return self.target_values if self.target_values else none def exit( self, tx: "InstructionTranslator", *args: VariableTracker ) -> VariableTracker: - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def module_name(self) -> str: return "contextlib" @@ -1020,6 +1051,9 @@ def module_name(self) -> str: def fn_name(self) -> str: return "nullcontext" + def python_type(self) -> type: + return contextlib.nullcontext + class ProfilerContextVariable(ContextWrappingVariable): """ @@ -1033,13 +1067,16 @@ class ProfilerContextVariable(ContextWrappingVariable): def __init__(self, **kwargs: Any) -> None: super().__init__(target_values=None, **kwargs) + def python_type(self) -> type: + return torch.profiler.profile + def enter(self, tx: "InstructionTranslator") -> VariableTracker: return self def exit( self, tx: "InstructionTranslator", *args: VariableTracker ) -> VariableTracker: - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def module_name(self) -> str: return "contextlib" @@ -1116,6 +1153,9 @@ def __init__( target_values=target_values, initial_values=initial_values, **kwargs ) + def python_type(self) -> type: + return torch.autograd.profiler.record_function + def enter(self, tx: "InstructionTranslator") -> VariableTracker: if config.capture_profiler_record_function: name, args = self.target_values @@ -1139,7 +1179,7 @@ def exit( (self.proxy,), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def module_name(self) -> str: return ( @@ -1220,8 +1260,11 @@ def __init__( if self.prev_versions.is_symnode_like(): self.prev_versions = variables.TupleVariable([self.prev_versions]) + def python_type(self) -> type: + return torch.autograd.grad_mode._unsafe_preserve_version_counter + def enter(self, tx: "InstructionTranslator") -> VariableTracker: - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def exit( self, tx: "InstructionTranslator", *args: VariableTracker @@ -1279,13 +1322,13 @@ def __init__( def enter(self, tx: "InstructionTranslator") -> VariableTracker: self._call_func(tx, self.target_values) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def exit( self, tx: "InstructionTranslator", *args: VariableTracker ) -> VariableTracker: self._call_func(tx, self.initial_values) # type: ignore[arg-type] - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def call_function( self, @@ -1388,7 +1431,7 @@ def enter(self, tx: "InstructionTranslator") -> VariableTracker: (arg, bool(self.set_priority)), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def exit( self, tx: "InstructionTranslator", *args: VariableTracker @@ -1401,7 +1444,7 @@ def exit( (arg, bool(self.set_priority)), {}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def module_name(self) -> str: return "torch.nn.attention" @@ -1411,6 +1454,9 @@ def module_name(self) -> str: def fn_name(self) -> str: return "_sdpa_kernel_variadic" + def python_type(self) -> type: + return contextlib._GeneratorContextManager + class FxTracebackAnnotateVariable(ContextWrappingVariable): """ @@ -1439,7 +1485,7 @@ def enter( stack.enter_context(torch.fx.traceback.annotate(self.target_values)) stack.enter_context(torch.fx.traceback.preserve_node_meta()) self.set_cleanup_hook(tx, lambda: stack.close()) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def module_name(self) -> str: return "torch.fx.traceback" @@ -1447,6 +1493,9 @@ def module_name(self) -> str: def fn_name(self) -> str: return "annotate" + def python_type(self) -> type: + return contextlib._GeneratorContextManager + def reconstruct_type(self, codegen: "PyCodegen") -> None: unimplemented( gb_type="torch.fx.traceback.annotate escaped from compiled region", @@ -1491,6 +1540,11 @@ def module_name(self) -> str: def fn_name(self) -> str: return "patch_dynamo_config" + def python_type(self) -> type: + from torch._dynamo.decorators import DynamoConfigPatchProxy + + return DynamoConfigPatchProxy + class ErrorOnGraphBreakVariable(ContextWrappingVariable): """represents torch._dynamo.error_on_graph_break""" @@ -1512,6 +1566,11 @@ def module_name(self) -> str: def fn_name(self) -> str: return "error_on_graph_break" + def python_type(self) -> type: + from torch._dynamo.decorators import ErrorOnGraphBreakDecoratorContextManager + + return ErrorOnGraphBreakDecoratorContextManager + class CudagraphOverrideVariable(ContextWrappingVariable): """represents torch._dynamo.override_cudagraphs""" @@ -1552,6 +1611,11 @@ def module_name(self) -> str: def fn_name(self) -> str: return "override_cudagraphs" + def python_type(self) -> type: + from torch._dynamo.decorators import CudagraphOverrideContextManager + + return CudagraphOverrideContextManager + def exit_on_graph_break(self) -> bool: # Annotation persists until graph is compiled; each resume function # will reconstruct the context manager and call enter() again @@ -1575,6 +1639,9 @@ def __init__( super().__init__(**kwargs) self.ctx = ctx + def python_type(self) -> type: + return types.MethodType + def call_function( self, tx: "InstructionTranslator", @@ -1626,6 +1693,9 @@ def __init__( self.ctx = ctx self.target = target + def python_type(self) -> type: + return types.MethodType + def call_function( self, tx: "InstructionTranslator", diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 005eaf0d8c0ae..4df5586ade681 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -6,26 +6,23 @@ - Ordered dictionaries (collections.OrderedDict) - Default dictionaries (collections.defaultdict) - Dictionary views (keys and values) -- Sets and frozensets (implemented internally using dictionaries) These classes are responsible for tracking dictionary operations during graph compilation, maintaining proper guards for dictionary mutations and key existence checks. They handle dictionary creation, modification, key/value access, and view operations while ensuring correct behavior in the compiled code through appropriate guard installation. -The implementation uses a special _HashableTracker wrapper to handle dictionary keys -while preserving proper aliasing semantics. Sets are implemented as dictionaries with -None values for efficiency and code reuse. +The implementation uses a special HashableTracker wrapper to handle +dictionary keys while preserving proper aliasing semantics. Set-related classes live +in sets.py. """ import collections import functools -import operator import types -from collections.abc import Callable, Iterable, Iterator, Sequence -from typing import Any, Literal, Optional, TYPE_CHECKING, Union +from collections.abc import Callable, Iterator, Sequence +from typing import Any, Literal, TYPE_CHECKING, Union -from torch.utils._ordered_set import OrderedSet from torch.utils._pytree import MappingKey from .. import graph_break_hints, polyfills, variables @@ -37,7 +34,12 @@ ) from ..exc import raise_observed_exception, unimplemented from ..guards import GuardBuilder, install_guard -from ..source import AttrSource, is_constant_source, is_from_local_source +from ..source import ( + AttrSource, + DictGetItemSource, + is_constant_source, + is_from_local_source, +) from ..utils import ( cmp_name_to_op_mapping, dict_items, @@ -45,27 +47,22 @@ dict_values, istype, raise_args_mismatch, - specialize_symnode, ) from .base import ( AttributeMutationExisting, AttributeMutationNew, + NO_SUCH_SUBOBJ, ValueMutationNew, VariableTracker, ) -from .constant import ( - CONSTANT_VARIABLE_FALSE, - CONSTANT_VARIABLE_NONE, - CONSTANT_VARIABLE_TRUE, - ConstantVariable, -) +from .constant import ConstantVariable +from .hashable import HashableTracker, is_hashable, raise_unhashable +from .sets import SetVariable if TYPE_CHECKING: from torch._dynamo.codegen import PyCodegen - from torch._dynamo.side_effects import SideEffects from torch._dynamo.symbolic_convert import InstructionTranslator - from torch._dynamo.variables.builtin import BuiltinVariable from .functions import UserFunctionVariable @@ -75,51 +72,10 @@ # - Implement get_python_hash() and is_python_equal() methods for hashable types -def was_instancecheck_override(obj: Any) -> bool: - return type(obj).__dict__.get("__instancecheck__", False) - - -def raise_unhashable( - arg: VariableTracker, tx: Optional["InstructionTranslator"] = None -) -> None: - from .builder import SourcelessBuilder - - if tx is None: - from torch._dynamo.symbolic_convert import InstructionTranslator - - tx = InstructionTranslator.current_tx() - try: - arg_type = arg.python_type() - except Exception: - arg_type = type(arg) - - raise_observed_exception( - TypeError, - tx, - args=[ - SourcelessBuilder.create( - tx, - f"unhashable type: {arg_type!r} and variable tracker = {type(arg.realize())}", - ) - ], - ) - - -def is_hashable(x: VariableTracker) -> bool: - # NB - performing isinstance check on a LazVT realizes the VT, accidentally - # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at - # the underlying value without realizing the VT. Consider updating the - # lazyVT `is_hashable` method if you see unnecessary guarding for a key VT. - if ( - isinstance(x, variables.LazyVariableTracker) - and not x.is_realized() - and x.is_hashable() - ): - return True - return x.is_python_hashable() - - class ConstDictVariable(VariableTracker): + # PyDict_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/dictobject.c#L4825 + _cpython_type = dict + CONTAINS_GUARD = GuardBuilder.DICT_CONTAINS NOT_CONTAINS_GUARD = GuardBuilder.DICT_NOT_CONTAINS @@ -128,60 +84,6 @@ class ConstDictVariable(VariableTracker): *VariableTracker._nonvar_fields, } - class _HashableTracker: - """ - Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable - This should not be seen or touched by anything outside of ConstDictVariable and its children - Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing - """ - - def __init__(self, vt: VariableTracker) -> None: - # We specialize SymNodes - vt = specialize_symnode(vt) - - # If Dynamo does not know the hashability of the vt, it will raise unsupported here - if not is_hashable(vt): - raise_unhashable(vt) - self.vt = vt - - def __hash__(self) -> int: - """ - Computes the hash value for the wrapped VariableTracker. - - For unrealized LazyVariableTrackers, uses the hash of the original value - to avoid realizing the tracker and inserting unnecessary guards. - For all other cases, delegates to the VariableTracker's get_python_hash method. - - Returns: - The hash value of the underlying variable tracker - """ - if ( - isinstance(self.vt, variables.LazyVariableTracker) - and not self.vt.is_realized() - and self.vt.is_hashable() - ): - return hash(self.vt.original_value()) - return self.vt.get_python_hash() - - def __eq__(self, other: object) -> bool: - """ - Checks equality between two _HashableTracker instances. - - Delegates to the VariableTracker's is_python_equal method to compare - the underlying variable trackers for Python-level equality. - - Args: - other: Another _HashableTracker instance to compare with - - Returns: - True if the underlying variable trackers are Python-equal, False otherwise - """ - if not isinstance(other, ConstDictVariable._HashableTracker): - return False - if self.vt is other.vt: - return True - return self.vt.is_python_equal(other.vt) - def __init__( self, items: dict[VariableTracker, VariableTracker], @@ -197,7 +99,7 @@ def __init__( super().__init__(**kwargs) - Hashable = ConstDictVariable._HashableTracker + Hashable = HashableTracker # Keys will just be HashableTrackers when cloning, in any other case they'll be VariableTrackers assert all( @@ -207,8 +109,8 @@ def __init__( ) def make_hashable( - key: Union[VariableTracker, "ConstDictVariable._HashableTracker"], - ) -> "ConstDictVariable._HashableTracker": + key: Union[VariableTracker, "HashableTracker"], + ) -> "HashableTracker": return key if isinstance(key, Hashable) else Hashable(key) dict_cls = self._get_dict_cls_from_user_cls(user_cls) @@ -266,11 +168,12 @@ def python_type(self) -> type: def __contains__(self, vt: VariableTracker) -> bool: assert isinstance(vt, VariableTracker) - Hashable = ConstDictVariable._HashableTracker - return ( - vt.is_python_hashable() - and Hashable(vt) in self.items - and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable) + Hashable = HashableTracker + if not is_hashable(vt): + return False + key = Hashable(vt) + return key in self.items and not isinstance( + self.items[key], variables.DeletedVariable ) def call_tree_map_branch( @@ -458,23 +361,15 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def getitem_const_raise_exception_if_absent( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker: - key = ConstDictVariable._HashableTracker(arg) + key = HashableTracker(arg) if key not in self.items: - try: - error_message = ( - f"Dict key lookup failed for {str(arg)}. " - f"Debug representation of the key is {arg.debug_repr()!r}" - ) - except Exception: - error_message = f"Dict key lookup failed for {str(arg)}" - error_message = VariableTracker.build(tx, error_message) - raise_observed_exception(KeyError, tx, args=[error_message]) + raise_observed_exception(KeyError, tx, args=[arg]) return self.items[key] def getitem_const( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker: - key = ConstDictVariable._HashableTracker(arg) + key = HashableTracker(arg) if key not in self.items: msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined] unimplemented( @@ -489,7 +384,7 @@ def getitem_const( return self.items[key] def maybe_getitem_const(self, arg: VariableTracker) -> VariableTracker | None: - key = ConstDictVariable._HashableTracker(arg) + key = HashableTracker(arg) if key not in self.items: return None return self.items[key] @@ -497,7 +392,7 @@ def maybe_getitem_const(self, arg: VariableTracker) -> VariableTracker | None: def realize_key_vt(self, arg: VariableTracker) -> None: # Realize the LazyVT on a particular index assert arg in self - key = ConstDictVariable._HashableTracker(arg) + key = HashableTracker(arg) index = tuple(self.items.keys()).index(key) original_key_vt = tuple(self.original_items.keys())[index] if isinstance(original_key_vt, variables.LazyVariableTracker): @@ -549,6 +444,15 @@ def install_dict_contains_guard( else: self.install_dict_keys_match_guard() + def mp_subscript_impl( + self, + tx: "InstructionTranslator", + key: VariableTracker, + ) -> VariableTracker: + # dict_subscript: https://github.com/python/cpython/blob/62a6e898e01/Objects/dictobject.c#L3673-L3706 + # Unhashable key check happens inside _HashableTracker (raise_unhashable → TypeError). + return self.getitem_const_raise_exception_if_absent(tx, key) + def call_method( self, tx: "InstructionTranslator", @@ -563,23 +467,18 @@ def call_method( # corresponding value VT. For __contains__, we add a DICT_CONTAINS # guard. But for all the other methods, we insert the DICT_KEYS_MATCH # guard to be conservative. - from . import BuiltinVariable + from . import DictBuiltinVariable from .builder import SourcelessBuilder - Hashable = ConstDictVariable._HashableTracker + Hashable = HashableTracker if name == "__init__": - temp_dict_vt = SourcelessBuilder.create(tx, dict).call_dict( - tx, *args, **kwargs + temp_dict_vt = DictBuiltinVariable.call_custom_dict( + tx, dict, *args, **kwargs ) tx.output.side_effects.mutation(self) self.items.update(temp_dict_vt.items) # type: ignore[attr-defined] - return CONSTANT_VARIABLE_NONE - elif name == "__getitem__": - # Key guarding - Nothing to do. LazyVT for value will take care. - if len(args) != 1: - raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") - return self.getitem_const_raise_exception_if_absent(tx, args[0]) + return ConstantVariable.create(None) elif name == "items": if args or kwargs: raise_args_mismatch( @@ -625,16 +524,6 @@ def call_method( return self.clone( items=self.items.copy(), mutation_type=ValueMutationNew(), source=None ) - elif name == "__len__": - if args or kwargs: - raise_args_mismatch( - tx, - name, - "0 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - self.install_dict_keys_match_guard() - return VariableTracker.build(tx, len(self.items)) elif name == "__setitem__" and self.is_mutable(): arg_hashable = args and is_hashable(args[0]) if not arg_hashable: @@ -650,7 +539,7 @@ def call_method( ) tx.output.side_effects.mutation(self) self.items[Hashable(args[0])] = args[1] - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif name == "__delitem__" and self.is_mutable(): arg_hashable = args and is_hashable(args[0]) if arg_hashable: @@ -658,7 +547,7 @@ def call_method( self.should_reconstruct_all = True tx.output.side_effects.mutation(self) self.items.__delitem__(Hashable(args[0])) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) else: return super().call_method(tx, name, args, kwargs) elif name == "get": @@ -673,7 +562,7 @@ def call_method( self.install_dict_contains_guard(tx, args) if len(args) == 1: # if default is not given, return None - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) return args[1] # Key guarding - Nothing to do. return self.getitem_const(tx, args[0]) @@ -697,30 +586,21 @@ def call_method( tx.output.side_effects.mutation(self) return self.items.pop(Hashable(args[0])) elif name == "popitem" and self.is_mutable(): - if ( - issubclass(self.user_cls, dict) - and not issubclass(self.user_cls, collections.OrderedDict) - and len(args) - ): + # dict.popitem() takes no args. OrderedDict.popitem(last=) is + # handled by OrderedDictVariable.call_method. + if len(args): raise_args_mismatch(tx, name) if not self.items: - msg = VariableTracker.build(tx, "popitem(): dictionary is empty") - raise_observed_exception(KeyError, tx, args=[msg]) - - if self.user_cls is collections.OrderedDict and ( - len(args) == 1 or "last" in kwargs - ): - if len(args) == 1 and args[0].is_python_constant(): - last = args[0].as_python_constant() - elif (v := kwargs.get("last")) and v.is_python_constant(): - last = v.as_python_constant() - else: - raise_args_mismatch(tx, name) - k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined] - else: - k, v = self.items.popitem() + raise_observed_exception( + KeyError, + tx, + args=[ + "popitem(): dictionary is empty", + ], + ) + k, v = self.items.popitem() self.should_reconstruct_all = True tx.output.side_effects.mutation(self) @@ -736,7 +616,7 @@ def call_method( self.should_reconstruct_all = True tx.output.side_effects.mutation(self) self.items.clear() - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif name == "update" and self.is_mutable(): # In general, this call looks like `a.update(b, x=1, y=2, ...)`. # Either `b` or the kwargs is omittable, but not both. @@ -746,13 +626,16 @@ def call_method( if has_arg or has_kwargs: tx.output.side_effects.mutation(self) if has_arg: + dict_vt: VariableTracker if isinstance(args[0], ConstDictVariable): # NB - Guard on all the keys of the other dict to ensure # correctness. args[0].install_dict_keys_match_guard() - dict_vt: ConstDictVariable = args[0] + dict_vt = args[0] else: - dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment] + dict_vt = DictBuiltinVariable.call_custom_dict( + tx, dict, args[0] + ) self.items.update(dict_vt.items) # type: ignore[attr-defined] if has_kwargs: # Handle kwargs @@ -761,7 +644,7 @@ def call_method( for k, v in kwargs.items() } self.items.update(kwargs_hashable) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) else: return super().call_method(tx, name, args, kwargs) elif name == "__contains__": @@ -806,28 +689,12 @@ def call_method( return value else: if len(args) == 1: - x = CONSTANT_VARIABLE_NONE + x = ConstantVariable.create(None) else: x = args[1] tx.output.side_effects.mutation(self) self.items[Hashable(args[0])] = x return x - elif name == "move_to_end": - self.install_dict_keys_match_guard() - tx.output.side_effects.mutation(self) - if args[0] not in self: - raise_observed_exception(KeyError, tx) - - last = True - if len(args) == 2 and args[1].is_python_constant(): - last = args[1].as_python_constant() - - if kwargs and "last" in kwargs and kwargs["last"].is_python_constant(): - last = kwargs.get("last").as_python_constant() # type: ignore[union-attr] - - key = Hashable(args[0]) - self.items.move_to_end(key, last=last) - return CONSTANT_VARIABLE_NONE elif name == "__eq__" and istype( self, ConstDictVariable ): # don't let Set use this function @@ -864,14 +731,19 @@ def call_method( # defaultdict. # TODO(guilhermeleobas): this check should be on builtin.py::call_or_ - if istype( + if isinstance( other, ( ConstDictVariable, variables.UserDefinedDictVariable, - variables.DefaultDictVariable, ), ): + # Unwrap UserDefinedDictVariable to its underlying ConstDictVariable + if isinstance(other, variables.UserDefinedDictVariable): + assert other._base_vt is not None + assert isinstance(other._base_vt, ConstDictVariable) + other = other._base_vt + # Always return the specialized dictionary, and in the case # both are specialized, take the first to be the type of the # new dictionary @@ -879,7 +751,6 @@ def call_method( user_cls = self.user_cls to_cpy = self else: - assert isinstance(other, ConstDictVariable) user_cls = other.user_cls to_cpy = other @@ -892,17 +763,20 @@ def call_method( ) # NB - Guard on all the keys of the other dict to ensure - # correctness. - args[0].install_dict_keys_match_guard() # type: ignore[attr-defined] - new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined] + # correctness. Use `other` (already unwrapped from + # UserDefinedDictVariable to ConstDictVariable above). + other.install_dict_keys_match_guard() # type: ignore[union-attr] + new_dict_vt.items.update(other.items) # type: ignore[union-attr] return new_dict_vt else: - err_msg = VariableTracker.build( + raise_observed_exception( + TypeError, tx, - f"unsupported operand type(s) for |: '{self.python_type().__name__}'" - f"and '{other.python_type().__name__}'", + args=[ + f"unsupported operand type(s) for |: '{self.python_type().__name__}'" + f"and '{other.python_type().__name__}'" + ], ) - raise_observed_exception(TypeError, tx, args=[err_msg]) elif name == "__ior__": self.call_method(tx, "update", args, kwargs) return self @@ -921,6 +795,11 @@ def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTrack self.install_dict_keys_match_guard() return [x.vt for x in self.items] + def mp_length(self, tx: "InstructionTranslator") -> VariableTracker: + """Mapping length for dict objects.""" + self.install_dict_keys_match_guard() + return VariableTracker.build(tx, len(self.items)) + def call_obj_hasattr( self, tx: "InstructionTranslator", name: str ) -> ConstantVariable: @@ -931,9 +810,9 @@ def call_obj_hasattr( for t in (dict, collections.OrderedDict, collections.defaultdict) ): if hasattr(self.user_cls, name): - return CONSTANT_VARIABLE_TRUE + return ConstantVariable.create(True) if self.user_cls is dict: - return CONSTANT_VARIABLE_FALSE + return ConstantVariable.create(False) msg = f"hasattr on {self.user_cls} is not supported" unimplemented( @@ -956,8 +835,16 @@ def is_python_hashable(self) -> bool: """ return False + def var_getattr(self, tx: "InstructionTranslator", name: str): + if name == "__class__": + return VariableTracker.build(tx, self.python_type()) + return super().var_getattr(tx, name) + class MappingProxyVariable(VariableTracker): + # PyDictProxy_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/descrobject.c#L1995 + _cpython_type = types.MappingProxyType + # proxies to the original dict_vt def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -997,13 +884,7 @@ def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.dv_dict) codegen.extend_output(create_call_function(1, False)) - def call_method( - self, - tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: + def _check_mutation_guard(self, tx: "InstructionTranslator") -> None: if self.source and tx.output.side_effects.has_existing_dict_mutation(): msg = ( "A dict has been modified while we have an existing mappingproxy object. " @@ -1024,609 +905,16 @@ def call_method( "Or avoid using the mapping proxy objects after modifying its underlying dictionary", ], ) - return self.dv_dict.call_method(tx, name, args, kwargs) - - def call_obj_hasattr( - self, tx: "InstructionTranslator", name: str - ) -> ConstantVariable: - if self.python_type() is types.MappingProxyType: - return VariableTracker.build(tx, name in types.MappingProxyType.__dict__) - return super().call_obj_hasattr(tx, name) - - -class NNModuleHooksDictVariable(ConstDictVariable): - # Special class to avoid adding any guards on the nn module hook ids. - def install_dict_keys_match_guard(self) -> None: - pass - - def install_dict_contains_guard( - self, tx: "InstructionTranslator", args: list[VariableTracker] - ) -> None: - pass - - -class DefaultDictVariable(ConstDictVariable): - def __init__( - self, - items: dict[VariableTracker, VariableTracker], - user_cls: type, - default_factory: VariableTracker | None = None, - **kwargs: Any, - ) -> None: - super().__init__(items, user_cls, **kwargs) - assert user_cls is collections.defaultdict - if default_factory is None: - default_factory = CONSTANT_VARIABLE_NONE - self.default_factory = default_factory - - def is_python_constant(self) -> bool: - # Return false for unsupported defaults. This ensures that a bad handler - # path is not taken in BuiltinVariable for getitem. - if self.default_factory not in [list, tuple, dict] and not self.items: - return False - return super().is_python_constant() - - def debug_repr(self) -> str: - assert self.default_factory is not None - return ( - f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})" - ) - - @staticmethod - def is_supported_arg(arg: VariableTracker) -> bool: - return isinstance( - arg, - ( - variables.BuiltinVariable, - variables.functions.BaseUserFunctionVariable, - variables.functions.PolyfilledFunctionVariable, - ), - ) or (isinstance(arg, variables.ConstantVariable) and arg.value is None) - - def call_method( - self, - tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - if name == "__getitem__": - if len(args) != 1: - raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") - - if args[0] in self: - return self.getitem_const(tx, args[0]) - else: - if ( - istype(self.default_factory, ConstantVariable) - and self.default_factory.value is None - ): - raise_observed_exception(KeyError, tx, args=[args[0]]) - else: - default_var = self.default_factory.call_function(tx, [], {}) - super().call_method( - tx, "__setitem__", [args[0], default_var], kwargs - ) - return default_var - elif name == "__setattr__" and self.is_mutable: - if len(args) != 2: - raise_args_mismatch(tx, name, "2 args", f"{len(args)} args") - # Setting a default factory must be a callable or None type - if ( - istype(args[0], ConstantVariable) and args[0].value == "default_factory" - ) and self.is_supported_arg(args[1]): - tx.output.side_effects.mutation(self) - self.default_factory = args[1] - return CONSTANT_VARIABLE_NONE - return super().call_method(tx, name, args, kwargs) - elif name == "__eq__": - if len(args) != 1: - raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") - - return VariableTracker.build(tx, polyfills.dict___eq__).call_function( - tx, [self, args[0]], {} - ) - else: - return super().call_method(tx, name, args, kwargs) - - def var_getattr( - self, - tx: "InstructionTranslator", - name: str, - ) -> VariableTracker: - if name == "default_factory": - return self.default_factory - return super().var_getattr(tx, name) - - def reconstruct(self, codegen: "PyCodegen") -> None: - # emit `defaultdict(default_factory, new_dict)` - codegen.add_push_null( - lambda: codegen.extend_output( - [ - codegen.create_load_python_module(collections), - codegen.create_load_attr("defaultdict"), - ] - ) - ) - codegen(self.default_factory) - codegen.extend_output( - [ - *create_call_function(1, False), - create_dup_top(), - ] - ) - codegen.add_cache(self) - - codegen.append_output(create_dup_top()) - codegen.load_method("update") - self.reconstruct_kvs_into_new_dict(codegen) - codegen.extend_output( - [ - *create_call_method(1), - create_instruction("POP_TOP"), - ] - ) - - -# TODO: Implementing this via inheritance rather than composition is a -# footgun, because self method calls in dict will route back to the set -# implementation, which is almost assuredly wrong -class SetVariable(ConstDictVariable): - """We model a sets as dictionary with None values""" - - CONTAINS_GUARD = GuardBuilder.SET_CONTAINS - NOT_CONTAINS_GUARD = GuardBuilder.SET_NOT_CONTAINS - - def __init__( - self, - items: Iterable[VariableTracker], - **kwargs: Any, - ) -> None: - # Items can be either VariableTrackers or _HashableTrackers (from set ops). - # For VariableTrackers, realize them to ensure aliasing guards are installed - # when the same object appears multiple times. - realized_items = [] - for item in items: - if isinstance(item, ConstDictVariable._HashableTracker): - # Already a _HashableTracker from a set operation - realized_items.append(item) - else: - # VariableTracker - realize to install guards - # pyrefly: ignore [bad-argument-type] - realized_items.append(item.realize()) - items = dict.fromkeys(realized_items, SetVariable._default_value()) - super().__init__(items, **kwargs) - - def debug_repr(self) -> str: - if not self.items: - return "set()" - else: - items: list[str] = [] - for v in self.items: - vt = v.vt if isinstance(v, ConstDictVariable._HashableTracker) else v - val_str = repr(vt.value) if hasattr(vt, "value") else vt.debug_repr() - items.append(val_str) - return "{" + ",".join(items) + "}" - - @property - def set_items(self) -> set["ConstDictVariable._HashableTracker"]: - return set(self.items.keys()) - - @staticmethod - def _default_value() -> VariableTracker: - # Variable to fill in he keys of the dictionary - return CONSTANT_VARIABLE_NONE - - def as_proxy(self) -> Any: - return {k.vt.as_proxy() for k in self.set_items} - - def python_type(self) -> type: - return set - - def as_python_constant(self) -> Any: - return {k.vt.as_python_constant() for k in self.set_items} - - def reconstruct(self, codegen: "PyCodegen") -> None: - codegen.foreach([x.vt for x in self.set_items]) - codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) - - def _fast_set_method( - self, - tx: "InstructionTranslator", - fn: Any, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - try: - res = fn( - *[x.as_python_constant() for x in [self, *args]], - **{k: v.as_python_constant() for k, v in kwargs.items()}, - ) - except Exception as exc: - raise_observed_exception( - type(exc), tx, args=[VariableTracker.build(tx, a) for a in exc.args] - ) - return VariableTracker.build(tx, res) - - def call_method( - self, - tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - # We forward the calls to the dictionary model - from ..utils import check_constant_args - from .builder import SourcelessBuilder - - if ( - name - in ( - "isdisjoint", - "union", - "intersection", - "difference", - "symmetric_difference", - ) - and check_constant_args(args, kwargs) - and self.python_type() is set - ): - py_type = self.python_type() - return self._fast_set_method(tx, getattr(py_type, name), args, kwargs) - - if name == "__init__": - temp_set_vt = SourcelessBuilder.create(tx, set).call_set( - tx, *args, **kwargs - ) - tx.output.side_effects.mutation(self) - self.items.clear() - self.items.update(temp_set_vt.items) # type: ignore[attr-defined] - return CONSTANT_VARIABLE_NONE - elif name == "add": - if kwargs or len(args) != 1: - raise_args_mismatch( - tx, - name, - "1 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - name = "__setitem__" - args = [args[0], SetVariable._default_value()] - elif name == "pop": - if kwargs or args: - raise_args_mismatch( - tx, - name, - "0 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - # Choose an item at random and pop it via the Dict.pop method - try: - result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment] - except KeyError as e: - raise_observed_exception( - KeyError, tx, args=[VariableTracker.build(tx, a) for a in e.args] - ) - super().call_method(tx, name, [result], kwargs) - return result - elif name == "isdisjoint": - if kwargs or len(args) != 1: - raise_args_mismatch( - tx, - name, - "1 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - return SourcelessBuilder.create(tx, polyfills.set_isdisjoint).call_function( - tx, [self, args[0]], {} - ) - elif name == "intersection": - if kwargs: - raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") - return SourcelessBuilder.create( - tx, polyfills.set_intersection - ).call_function( - tx, - [self, *args], - {"cls": self.python_type_var()}, - ) - elif name == "intersection_update": - if kwargs: - raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") - return SourcelessBuilder.create( - tx, polyfills.set_intersection_update - ).call_function(tx, [self, *args], {}) - elif name == "union": - if kwargs: - raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") - return SourcelessBuilder.create(tx, polyfills.set_union).call_function( - tx, - [self, *args], - {"cls": self.python_type_var()}, - ) - elif name == "difference": - if kwargs: - raise_args_mismatch( - tx, name, f"Expect: 0 kwargs, Actual: {len(kwargs)} kwargs" - ) - return SourcelessBuilder.create(tx, polyfills.set_difference).call_function( - tx, - [self, *args], - {"cls": self.python_type_var()}, - ) - elif name == "difference_update": - if kwargs: - raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") - return SourcelessBuilder.create( - tx, polyfills.set_difference_update - ).call_function(tx, [self, *args], {}) - elif name == "symmetric_difference": - if kwargs or len(args) != 1: - raise_args_mismatch( - tx, - name, - "1 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - return SourcelessBuilder.create( - tx, polyfills.set_symmetric_difference - ).call_function( - tx, - [self, *args], - {"cls": self.python_type_var()}, - ) - elif name == "symmetric_difference_update": - if kwargs or len(args) != 1: - raise_args_mismatch( - tx, - name, - "1 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - return SourcelessBuilder.create( - tx, polyfills.set_symmetric_difference_update - ).call_function(tx, [self, *args], {}) - elif name == "update" and self.is_mutable(): - if kwargs: - raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") - return SourcelessBuilder.create(tx, polyfills.set_update).call_function( - tx, [self, *args], {} - ) - elif name == "remove": - if kwargs or len(args) != 1: - raise_args_mismatch( - tx, - name, - "1 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - if args[0] not in self: - raise_observed_exception(KeyError, tx, args=args) - return super().call_method(tx, "pop", args, kwargs) - elif name == "discard": - if kwargs or len(args) != 1: - raise_args_mismatch( - tx, - name, - "1 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - if args[0] in self: - return super().call_method(tx, "pop", args, kwargs) - else: - return CONSTANT_VARIABLE_NONE - elif name in ("issubset", "issuperset"): - if len(args) != 1: - raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") - - op = { - "issubset": operator.le, - "issuperset": operator.ge, - } - other = args[0].realize() - if not istype(other, SetVariable): - other = SourcelessBuilder.create(tx, set).call_function(tx, [other], {}) - return SourcelessBuilder.create(tx, op.get(name)).call_function( - tx, [self, other], {} - ) - elif name in ("__and__", "__or__", "__xor__", "__sub__"): - m = { - "__and__": "intersection", - "__or__": "union", - "__xor__": "symmetric_difference", - "__sub__": "difference", - }.get(name) - if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): - msg = VariableTracker.build( - tx, - f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'", - ) - raise_observed_exception(TypeError, tx, args=[msg]) - assert m is not None - return self.call_method(tx, m, args, kwargs) - elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"): - if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): - msg = VariableTracker.build( - tx, - f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'", - ) - raise_observed_exception(TypeError, tx, args=[msg]) - m = { - "__iand__": "intersection_update", - "__ior__": "update", - "__ixor__": "symmetric_difference_update", - "__isub__": "difference_update", - }.get(name) - assert m is not None - self.call_method(tx, m, args, kwargs) - return self - elif name == "__eq__": - if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): - return CONSTANT_VARIABLE_FALSE - r = self.call_method(tx, "symmetric_difference", args, kwargs) - return VariableTracker.build(tx, len(r.set_items) == 0) # type: ignore[attr-defined] - elif name in cmp_name_to_op_mapping: - if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): - return VariableTracker.build(tx, NotImplemented) - return VariableTracker.build( - tx, - cmp_name_to_op_mapping[name](self.set_items, args[0].set_items), # type: ignore[attr-defined] - ) - return super().call_method(tx, name, args, kwargs) - - def python_type_var(self) -> "BuiltinVariable": - return variables.BuiltinVariable(set) - - def getitem_const( - self, tx: "InstructionTranslator", arg: VariableTracker - ) -> VariableTracker: - raise RuntimeError("Illegal to getitem on a set") - def install_dict_keys_match_guard(self) -> None: - # Already EQUALS_MATCH guarded - pass - - -class OrderedSetClassVariable(VariableTracker): - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) - - def as_python_constant(self) -> type[OrderedSet[Any]]: - return OrderedSet - - def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: - if name == "__new__": - from .misc import GetAttrVariable - - if self.source: - attr_source = AttrSource(self.source, name) - else: - attr_source = None - return GetAttrVariable(self, name, source=attr_source) - else: - return super().var_getattr(tx, name) - - def call_method( + def mp_subscript_impl( self, tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], + key: VariableTracker, ) -> VariableTracker: - from .builtin import set_methods - - if name == "__new__": - if len(args) != 2 or kwargs: - raise_args_mismatch( - tx, - name, - "OrderedSet.__new__ only accepts one arg" - f"{len(args)} args and {len(kwargs)} kwargs", - ) - - return variables.OrderedSetVariable([], mutation_type=ValueMutationNew()) - - resolved_fn = getattr(set, name) - if resolved_fn in set_methods and isinstance(args[0], variables.SetVariable): - return args[0].call_method(tx, name, args[1:], kwargs) - - return super().call_method(tx, name, args, kwargs) - - def call_function( - self, - tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> "OrderedSetVariable": - if len(args) > 1 or kwargs: - raise_args_mismatch( - tx, - "OrderedSet", - "OrderedSet only accepts one arg" - f"{len(args)} args and {len(kwargs)} kwargs", - ) - - if len(args) == 0: - # pyrefly: ignore [implicit-any] - items = [] - else: - items = args[0].force_unpack_var_sequence(tx) - return variables.OrderedSetVariable(items, mutation_type=ValueMutationNew()) - - -class OrderedSetVariable(SetVariable): - def debug_repr(self) -> str: - if not self.items: - return "OrderedSet([])" - else: - items: list[str] = [] - for k, v in self.items: - key_str = ( - repr(k.vt.value) if hasattr(k.vt, "value") else k.vt.debug_repr() - ) - items.append(key_str) - return "OrderedSet([" + ",".join(items) + "])" - - def as_python_constant(self) -> OrderedSet[Any]: - return OrderedSet([k.vt.as_python_constant() for k in self.set_items]) - - def python_type(self) -> type[OrderedSet[Any]]: - return OrderedSet - - # pyrefly: ignore[bad-override] - def python_type_var(self) -> OrderedSetClassVariable: - return OrderedSetClassVariable() - - def reconstruct(self, codegen: "PyCodegen") -> None: - codegen.add_push_null( - lambda: codegen.load_import_from("torch.utils._ordered_set", "OrderedSet") - ) - codegen.foreach([x.vt for x in self.set_items]) - codegen.append_output(create_instruction("BUILD_LIST", arg=len(self.set_items))) - codegen.extend_output(create_call_function(1, False)) - - -class FrozensetVariable(SetVariable): - def debug_repr(self) -> str: - if not self.items: - return "frozenset()" - else: - items: list[str] = [] - for k in self.items: - key_str = ( - repr(k.vt.value) if hasattr(k.vt, "value") else k.vt.debug_repr() - ) - items.append(key_str) - return "{" + ",".join(items) + "}" - - @property - def set_items(self) -> set["ConstDictVariable._HashableTracker"]: - return self.items.keys() - - def python_type(self) -> type: - return frozenset - - def python_type_var(self) -> "BuiltinVariable": - return variables.BuiltinVariable(frozenset) - - def as_python_constant(self) -> Any: - return frozenset({k.vt.as_python_constant() for k in self.set_items}) - - def reconstruct(self, codegen: "PyCodegen") -> None: - codegen.add_push_null( - lambda: codegen.extend_output( - [ - codegen.create_load_global("frozenset"), - ] - ) - ) - codegen.foreach([x.vt for x in self.set_items]) - codegen.extend_output( - [ - create_instruction("BUILD_LIST", arg=len(self.set_items)), - *create_call_function(1, False), - ] - ) + # mappingproxy_getitem: https://github.com/python/cpython/blob/62a6e898e01/Objects/descrobject.c#L1052-L1056 + # TODO(follow-up): add tests for invalid key type, missing key + self._check_mutation_guard(tx) + return self.dv_dict.mp_subscript_impl(tx, key) def call_method( self, @@ -1635,89 +923,30 @@ def call_method( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - if name in ["add", "pop", "update", "remove", "discard", "clear"]: - raise RuntimeError(f"Illegal call_method {name} on a frozenset") - elif name == "__init__": - # frozenset is immutable. Calling __init__ again shouldn't have any effect - # In[1]: s = frozenset([1, 2]) - # - # In[2]: s.__init__([3, 4]) - # - # In[3]: s - # frozenset({1, 2}) - return CONSTANT_VARIABLE_NONE - elif name in ( - "copy", - "difference", - "intersection", - "symmetric_difference", - ): - r = super().call_method(tx, name, args, kwargs) - return FrozensetVariable(r.items) # type: ignore[attr-defined] - return super().call_method(tx, name, args, kwargs) - - def is_python_hashable(self) -> Literal[True]: - """ - Frozensets are immutable and hashable in Python. - """ - return True - - def get_python_hash(self) -> int: - return hash(self.as_python_constant()) + self._check_mutation_guard(tx) + return self.dv_dict.call_method(tx, name, args, kwargs) - def is_python_equal(self, other: object) -> bool: - return ( - isinstance(other, VariableTracker) - and self.as_python_constant() == other.as_python_constant() - ) + def mp_length(self, tx: "InstructionTranslator") -> VariableTracker: + return self.dv_dict.mp_length(tx) + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + if self.python_type() is types.MappingProxyType: + return VariableTracker.build(tx, name in types.MappingProxyType.__dict__) + return super().call_obj_hasattr(tx, name) -class DictKeySetVariable(SetVariable): - def debug_repr(self) -> str: - if not self.items: - return "dict_keys([])" - else: - items: list[str] = [] - for k in self.items: - key_str = ( - repr(k.vt.value) if hasattr(k.vt, "value") else k.vt.debug_repr() - ) - items.append(key_str) - return "dict_keys([" + ",".join(items) + "])" +class NNModuleHooksDictVariable(ConstDictVariable): + # Special class to avoid adding any guards on the nn module hook ids. def install_dict_keys_match_guard(self) -> None: - # Already EQUALS_MATCH guarded pass def install_dict_contains_guard( self, tx: "InstructionTranslator", args: list[VariableTracker] ) -> None: - # Already EQUALS_MATCH guarded pass - @property - def set_items(self) -> Any: - return self.items - - def python_type(self) -> type: - return dict_keys - - def as_python_constant(self) -> Any: - return dict.fromkeys( - {k.vt.as_python_constant() for k in self.set_items}, None - ).keys() - - def call_method( - self, - tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - if name in ["add", "pop", "update", "remove", "discard", "clear"]: - raise RuntimeError(f"Illegal call_method {name} on a dict_keys") - return super().call_method(tx, name, args, kwargs) - class DictViewVariable(VariableTracker): """ @@ -1759,8 +988,8 @@ def call_obj_hasattr( ) -> ConstantVariable: assert self.kv is not None if name in self.python_type().__dict__: - return CONSTANT_VARIABLE_TRUE - return CONSTANT_VARIABLE_FALSE + return ConstantVariable.create(True) + return ConstantVariable.create(False) def call_method( self, @@ -1769,9 +998,7 @@ def call_method( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - if name == "__len__": - return self.dv_dict.call_method(tx, name, args, kwargs) - elif name == "__iter__": + if name == "__iter__": from .lists import ListIteratorVariable return ListIteratorVariable( @@ -1781,8 +1008,15 @@ def call_method( return VariableTracker.build(tx, self.debug_repr()) return super().call_method(tx, name, args, kwargs) + def sq_length(self, tx: "InstructionTranslator") -> VariableTracker: + """Sequence length for dict view objects.""" + return VariableTracker.build(tx, len(self.view_items)) + class DictKeysVariable(DictViewVariable): + # PyDictKeys_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/dictobject.c#L6365 + _cpython_type = dict_keys + kv = "keys" @property @@ -1832,8 +1066,16 @@ def call_method( m = getattr(self.set_items, name) r = m(args[0].set_items) # type: ignore[attr-defined] return SetVariable(r) - if name in cmp_name_to_op_mapping: - if not isinstance(args[0], (SetVariable, DictKeysVariable)): + elif name in cmp_name_to_op_mapping: + if not isinstance( + args[0], + ( + SetVariable, + variables.UserDefinedSetVariable, + DictItemsVariable, + DictKeysVariable, + ), + ): return VariableTracker.build(tx, NotImplemented) return VariableTracker.build( tx, @@ -1843,6 +1085,9 @@ def call_method( class DictValuesVariable(DictViewVariable): + # PyDictValues_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/dictobject.c#L6567 + _cpython_type = dict_values + # DictValuesVariable is an iterable but cannot be compared. kv = "values" @@ -1865,8 +1110,18 @@ def debug_repr(self) -> str: class DictItemsVariable(DictViewVariable): + # PyDictItems_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/dictobject.c#L6477 + _cpython_type = dict_items + kv = "items" + @property + def set_items(self) -> set["HashableTracker"]: + return { + HashableTracker(variables.TupleVariable([k.vt, v])) + for k, v in self.view_items + } + @property def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items @@ -1902,13 +1157,55 @@ def call_method( raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") if isinstance(args[0], DictItemsVariable): return self.dv_dict.call_method(tx, "__eq__", [args[0].dv_dict], {}) - return CONSTANT_VARIABLE_FALSE + elif isinstance( + args[0], + ( + SetVariable, + variables.UserDefinedSetVariable, + DictItemsVariable, + DictKeysVariable, + ), + ): + return VariableTracker.build( + tx, + len(self.set_items ^ args[0].set_items) == 0, + ) + return ConstantVariable.create(False) elif name == "__iter__": from .lists import ListIteratorVariable return ListIteratorVariable( self.view_items_vt, mutation_type=ValueMutationNew() ) + elif name in ( + "__and__", + "__iand__", + "__or__", + "__ior__", + "__sub__", + "__isub__", + "__xor__", + "__ixor__", + ): + # These methods always returns a set + fn_hdl = getattr(self.set_items, name) + ret_val = fn_hdl(args[0].set_items) # type: ignore[attr-defined] + return SetVariable(ret_val) + elif name in cmp_name_to_op_mapping: + if not isinstance( + args[0], + ( + SetVariable, + variables.UserDefinedSetVariable, + DictItemsVariable, + DictKeysVariable, + ), + ): + return VariableTracker.build(tx, NotImplemented) + return VariableTracker.build( + tx, + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items), # type: ignore[attr-defined] + ) return super().call_method(tx, name, args, kwargs) def is_python_hashable(self) -> Literal[False]: @@ -1918,7 +1215,7 @@ def is_python_hashable(self) -> Literal[False]: return False -kV = ConstDictVariable._HashableTracker | str +kV = HashableTracker | str class SideEffectsProxyDict(collections.abc.MutableMapping[kV, VariableTracker]): @@ -1927,20 +1224,68 @@ class SideEffectsProxyDict(collections.abc.MutableMapping[kV, VariableTracker]): effects table as storage. """ - def __init__(self, item: VariableTracker, side_effects: "SideEffects") -> None: + @staticmethod + def get_example_value_dict(vt: VariableTracker) -> dict[str, object]: + if istype(vt, variables.NestedUserFunctionVariable): + # NestedUserFunctionVariable is created with MAKE_FUNCTION and its + # __dict__ starts empty. Any mutation will actually be recorded in + # the side effects table. + return {} + elif isinstance(vt, variables.LocalGeneratorFunctionVariable): + return SideEffectsProxyDict.get_example_value_dict(vt.vt) + else: + value = vt.get_real_python_backed_value() + if value is not NO_SUCH_SUBOBJ: + if isinstance(vt, variables.UserDefinedObjectVariable): + return vt._getattr_static("__dict__") # type: ignore[bad-return] + else: + return object.__getattribute__(value, "__dict__") + else: + unimplemented( + gb_type="unsupported variable type for __dict__ access", + context=f"VariableTracker type: {type(vt)}", + explanation=f"Dynamo does not know how to get __dict__ from {type(vt)}", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + @staticmethod + def get_value___dict__( + tx: "InstructionTranslator", vt: VariableTracker + ) -> dict[str, VariableTracker]: + example_value_dict = SideEffectsProxyDict.get_example_value_dict(vt) + + return { + key: VariableTracker.build( + tx, + value, + source=vt.source + and DictGetItemSource(AttrSource(vt.source, "__dict__"), key), + ) + for key, value in example_value_dict.items() + } + + def __init__(self, item: VariableTracker, tx: "InstructionTranslator") -> None: self.item = item - self.side_effects = side_effects + self.side_effects = tx.output.side_effects + self.item_dict = self.get_value___dict__(tx, item) def _maybe_unwrap_key(self, key: kV) -> str: - Hasher = ConstDictVariable._HashableTracker + Hasher = HashableTracker return key.vt.as_python_constant() if istype(key, Hasher) else key + def side_effects_table(self) -> dict[str, VariableTracker]: + return self.side_effects.store_attr_mutations.get(self.item, {}) + def __getitem__(self, key: kV) -> VariableTracker: name = self._maybe_unwrap_key(key) - return self.side_effects.load_attr(self.item, name) + if self.side_effects.has_pending_mutation_of_attr(self.item, name): + return self.side_effects.load_attr(self.item, name, deleted_ok=True) + return self.item_dict[name] def __setitem__(self, key: kV, value: VariableTracker) -> None: - # Find a way to not hash the key using _HashableTracker + # Find a way to not hash the key using HashableTracker name = self._maybe_unwrap_key(key) assert istype(name, str) self.side_effects.store_attr(self.item, name, value) @@ -1951,19 +1296,29 @@ def __delitem__(self, key: kV) -> None: def __contains__(self, key: kV) -> bool: # type: ignore[bad-override] name = self._maybe_unwrap_key(key) - return name in self.side_effects.store_attr_mutations.get(self.item, {}) + table = self.side_effects_table() + # if name in side effects, then it is only contained if it's not a DeletedVariable + # even if the original dict contains it + if name in table: + return not isinstance(table[name], variables.DeletedVariable) + else: + return name in self.item_dict def __len__(self) -> int: - return len(self.side_effects.store_attr_mutations.get(self.item, {})) + return sum(1 for _ in self) - def __iter__(self) -> Iterator[ConstDictVariable._HashableTracker]: - Hasher = ConstDictVariable._HashableTracker - d = self.side_effects.store_attr_mutations.get(self.item, {}) + def __iter__(self) -> Iterator[HashableTracker]: + Hasher = HashableTracker + d = self.side_effects_table() for k, v in d.items(): if isinstance(v, variables.DeletedVariable): continue yield Hasher(ConstantVariable.create(k)) + for k, v in self.item_dict.items(): + if k not in d: + yield Hasher(ConstantVariable.create(k)) + class DunderDictVariable(ConstDictVariable): """represents object.__dict__""" @@ -1973,14 +1328,13 @@ def create( cls, tx: "InstructionTranslator", vt: VariableTracker, - dict_proxy: dict[str, VariableTracker], ) -> "DunderDictVariable": mutation = AttributeMutationExisting() if vt.source else AttributeMutationNew() source = vt.source and AttrSource(vt.source, "__dict__") + return cls( vt, - dict_proxy=types.MappingProxyType(dict_proxy), - side_effects=tx.output.side_effects, + tx=tx, mutation_type=mutation, source=source, ) @@ -1988,28 +1342,20 @@ def create( def __init__( self, vt: VariableTracker, - dict_proxy: types.MappingProxyType[str, VariableTracker], # object __dict__ - side_effects: "SideEffects", + tx: "InstructionTranslator", **kwargs: Any, ) -> None: super().__init__({}, **kwargs) - self.items = SideEffectsProxyDict(vt, side_effects) - # Saves a "proxy" dict to the original __dict__ of the object - # This allows track mutations on __dict__ (using side effects) without - # modifying the original __dict__ - self.dict_proxy = dict_proxy + self.items = SideEffectsProxyDict(vt, tx) def setitem(self, name: str, value: VariableTracker) -> None: self.items[name] = value def getitem(self, name: str) -> VariableTracker: - if name in self.items: - return self.items[name] - else: - return self.dict_proxy[name] + return self.items[name] def contains(self, name: str) -> bool: - return name in self.items or name in self.dict_proxy + return name in self.items def getitem_or_default( self, @@ -2023,96 +1369,6 @@ def getitem_or_default( self.items[name] = value return value - # We need to overload the three functions below: - # - __contains__ - # - getitem_const - # - maybe_getitem_const - # - getitem_const_raise_exception_if_absent - # because the default implementation in ConstDictVariable will directly look - # up the name in self.items, which might add undesired guards. - def __contains__(self, vt: VariableTracker) -> bool: - name = vt.as_python_constant() - return self.contains(name) - - def getitem_const( - self, tx: "InstructionTranslator", arg: VariableTracker - ) -> VariableTracker: - name = arg.as_python_constant() - if self.contains(name): - return self.getitem(name) - return super().getitem_const(tx, arg) - - def maybe_getitem_const(self, arg: VariableTracker) -> VariableTracker | None: - name = arg.as_python_constant() - if self.contains(name): - return self.getitem(name) - return None - - def getitem_const_raise_exception_if_absent(self, tx, arg): - name = arg.as_python_constant() - if self.contains(name): - return self.getitem(name) - return super().getitem_const_raise_exception_if_absent(tx, arg) - - def call_method( - self, - tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - if name in ("items", "keys", "values"): - if args or kwargs: - raise_args_mismatch( - tx, - name, - "0 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - if self.source: - tx.output.guard_on_key_order.add(self.source) - merged_items = self._get_merged_dict(tx) - merged_dict = ConstDictVariable(merged_items, user_cls=dict) - - if name == "items": - return DictItemsVariable(merged_dict) - elif name == "keys": - return DictKeysVariable(merged_dict) - elif name == "values": - return DictValuesVariable(merged_dict) - elif name == "get": - if len(args) not in (1, 2): - raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") - name = args[0].as_python_constant() - if self.contains(name): - return self.getitem(name) - return CONSTANT_VARIABLE_NONE if len(args) == 1 else args[1] - return super().call_method(tx, name, args, kwargs) - - def _get_merged_dict( - self, tx: "InstructionTranslator" - ) -> dict[VariableTracker, VariableTracker]: - """Get all items as a proper dict, merging dict_proxy and side effects.""" - Hasher = ConstDictVariable._HashableTracker - - def make_key(k): - return Hasher(VariableTracker.build(tx, k)) - - merged = {} - - for k, v in self.dict_proxy.items(): - merged[make_key(k)] = v - - d = self.items.side_effects.store_attr_mutations.get(self.items.item, {}) - for k, v in d.items(): - if isinstance(v, variables.DeletedVariable): - key_obj = make_key(k) - merged.pop(key_obj, None) - else: - merged[make_key(k)] = v - - return merged - # Mutations to __dict__ are tracked through side effects (SideEffectsProxyDict), # so we don't need to install guards. Guard installation is overridden to no-op. def install_dict_keys_match_guard(self) -> None: diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index d6bd2f49a8d58..9714a9960ed0d 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -129,6 +129,9 @@ def is_group_member_type(cls, value: object) -> bool: return type(value) is _WorldMeta + def python_type(self) -> type: + return type(self.value) + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "WORLD": from .builder import SourcelessBuilder diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index b2e12c7e6c763..05b5fcf51cb16 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -53,6 +53,7 @@ ObservedGeneratorExit, ObservedUserStopIteration, raise_observed_exception, + raise_type_error, StepUnsupported, unimplemented, Unsupported, @@ -82,16 +83,10 @@ from .base import ( AsPythonConstantNotImplementedError, AttributeMutationNew, - raise_type_error_exc, ValueMutationNew, VariableTracker, ) -from .constant import ( - CONSTANT_VARIABLE_FALSE, - CONSTANT_VARIABLE_NONE, - CONSTANT_VARIABLE_TRUE, - ConstantVariable, -) +from .constant import ConstantVariable from .user_defined import UserDefinedObjectVariable @@ -394,17 +389,7 @@ def get_source(self) -> Source | None: def get_dict_vt(self, tx: "InstructionTranslator") -> "DunderDictVariable": if self.dict_vt is None: - dict_proxy: dict[str, VariableTracker] = {} - - if not istype(self, NestedUserFunctionVariable): - fn = self.get_function() - dict_proxy = { - name: VariableTracker.build( - tx, value, source=self.source and AttrSource(self.source, name) - ) - for name, value in fn.__dict__.items() - } - self.dict_vt = variables.DunderDictVariable.create(tx, self, dict_proxy) + self.dict_vt = variables.DunderDictVariable.create(tx, self) return self.dict_vt def call_method( @@ -500,7 +485,7 @@ def call_function( self.get_name() == "patch_track_step_called" and self.get_filename().endswith("torch/optim/lr_scheduler.py") ): - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs) # type: ignore[attr-defined] def call_obj_hasattr( @@ -530,6 +515,9 @@ def should_allow_nested_graph_breaks(self) -> bool: class UserFunctionVariable(BaseUserFunctionVariable): """Some unsupported user-defined global function""" + # PyFunction_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/funcobject.c#L1046 + _cpython_type = types.FunctionType + _nonvar_fields = { "fn", "is_constant", @@ -587,6 +575,11 @@ def as_python_constant(self) -> Any: # subclasses (such as methods) usually aren't a constant return super().as_python_constant() + def get_real_python_backed_value(self) -> Any: + if istype(self, UserFunctionVariable): + return self.fn + return super().get_real_python_backed_value() + def self_args(self) -> list[VariableTracker]: return [] @@ -678,7 +671,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker if name == "__dict__": return super().var_getattr(tx, name) elif name in cmp_name_to_op_mapping: - return variables.GetAttrVariable(self, name) + return variables.GetAttrVariable( + self, name, py_type=type(getattr(self.fn, name)) + ) source = self.get_source() return fn_var_getattr(tx, self.fn, source, name) @@ -742,7 +737,7 @@ def call_function( ) from e elif self.fn is torch._dynamo.bytecode_debugger.breakpoint: tx.output._emit_debugger_breakpoint = True - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) # Handle a `nonstrict_trace(fn)` call elif self.fn is torch._dynamo.nonstrict_trace: bound = inspect.signature(self.fn).bind(*args, **kwargs) @@ -761,7 +756,7 @@ def call_function( if not isinstance(fn_var, UserFunctionVariable): fn_name = fn_var.get_name() - msg = f"Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." # noqa: B950 + msg = f"Applying `nonstrict_trace` to function <{fn_name}>; however, `nonstrict_trace` currently requires the function to be defined outside `torch.compile` region." unimplemented( gb_type="Limitation of `nonstrict_trace", context=f"{self}", @@ -1071,6 +1066,9 @@ def __init__( assert isinstance(fn, types.BuiltinMethodType) self.fn = fn + def python_type(self) -> type: + return types.BuiltinMethodType + @staticmethod def is_supported_builtin_method(obj: Any) -> bool: method_self = obj.__self__ @@ -1096,6 +1094,9 @@ def call_function( class LocalGeneratorObjectVariable(VariableTracker): + # PyGen_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/genobject.c#L814 + _cpython_type = types.GeneratorType + def __init__( self, code: types.CodeType, @@ -1192,8 +1193,8 @@ def call_obj_hasattr( self, tx: "InstructionTranslator", name: str ) -> ConstantVariable: if name in self.python_type().__dict__: - return CONSTANT_VARIABLE_TRUE - return CONSTANT_VARIABLE_FALSE + return ConstantVariable.create(True) + return ConstantVariable.create(False) def has_unpack_var_sequence(self, tx: "InstructionTranslator") -> bool: return False @@ -1280,13 +1281,13 @@ def call_method( tracer = self.inline_tracer if self._is_generator_just_started() or self._is_generator_exhausted(): tracer.generator_exhausted = True - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) # Raise GeneratorExit to see if user code catches it. Any other exception # is propagated to the parent frame. try: self._setup_exception( - tx, variables.ExceptionVariable(GeneratorExit, ()) + tx, variables.ExceptionVariable(GeneratorExit, []) ) # There's an extra block on Python 3.12+ to handle StopIteration # see: https://github.com/python/cpython/blob/8f93dd8a8f237b277abad20d566df90c5cbd7f1e/Objects/genobject.c#L394-L397 @@ -1309,11 +1310,11 @@ def call_method( and tracer.next_instruction.opname == "CALL_INTRINSIC_1" ): tracer.generator_exhausted = True - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) except ObservedGeneratorExit: # If it doesn't catch, we just return None, as per the text above tracer.generator_exhausted = True - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) try: # Raise RuntimeError if the generator yields any other value @@ -1321,7 +1322,7 @@ def call_method( raise_observed_exception(RuntimeError, tx) except ObservedGeneratorExit: tracer.generator_exhausted = True - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) except ObservedUserStopIteration: # In Python 3.13+, one can capture GeneratorExit and return a value # See test_generator.py::test_close_capture_GeneratorExit_return @@ -1413,7 +1414,7 @@ def call_method( exc_type = type("__InternalThrowException", (Exception,), {}) try: - self._setup_exception(tx, variables.ExceptionVariable(exc_type, ())) + self._setup_exception(tx, variables.ExceptionVariable(exc_type, [])) self.next_variable(tx) except get_dynamo_observed_exception(exc_type): # We should get back the exception raised before. @@ -1446,6 +1447,9 @@ class LocalGeneratorFunctionVariable(BaseUserFunctionVariable): This is a wrapper around (Nested)UserFunctionVariable """ + def python_type(self) -> type: + return types.FunctionType + def __init__( self, vt: BaseUserFunctionVariable, @@ -1517,6 +1521,9 @@ def call_function( source=self.source, ) + def get_real_python_backed_value(self) -> object: + return self.vt.get_real_python_backed_value() + class FunctionDecoratedByContextlibContextManagerVariable( LocalGeneratorFunctionVariable @@ -1556,6 +1563,9 @@ def _build_inline_tracer( class UserMethodVariable(UserFunctionVariable): """Some unsupported user-defined method""" + # PyMethod_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/classobject.c#L332 + _cpython_type = types.MethodType + def __init__( self, fn: Callable[..., Any], @@ -1668,6 +1678,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return VariableTracker.build(tx, self.fn, self.source_fn) # type: ignore[arg-type] return super().var_getattr(tx, name) + def get_real_python_backed_value(self) -> Any: + return self.fn + class WrappedUserMethodVariable(UserMethodVariable): def __init__( @@ -1688,6 +1701,11 @@ def call_function( args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: + if config.nested_graph_breaks: + wrapper_fn = UserFunctionVariable(polyfills._fn_with_ctx) + return wrapper_fn.call_function( + tx, [self.context, self.wrapped] + list(args), kwargs + ) self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) @@ -1717,6 +1735,11 @@ def call_function( args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: + if config.nested_graph_breaks: + wrapper_fn = UserFunctionVariable(polyfills._fn_with_ctx) + return wrapper_fn.call_function( + tx, [self.context, self.wrapped] + list(args), kwargs + ) self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) @@ -1913,7 +1936,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker d = getattr(self, "defaults", None) return d.as_python_constant() if d else ConstantVariable.create(None) elif name in cmp_name_to_op_mapping: - return variables.GetAttrVariable(self, name) + return variables.GetAttrVariable( + self, name, py_type=type(getattr(types.FunctionType, name)) + ) else: return super().var_getattr(tx, name) @@ -1939,7 +1964,7 @@ def call_obj_hasattr( return VariableTracker.build(tx, hasattr(self, "defaults")) vt = ConstantVariable.create(name) if vt in self.get_dict_vt(tx): - return CONSTANT_VARIABLE_TRUE + return ConstantVariable.create(True) return super().call_obj_hasattr(tx, name) def has_self(self) -> bool: @@ -2068,6 +2093,11 @@ def call_function( args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: + if config.nested_graph_breaks: + wrapper_fn = UserFunctionVariable(polyfills._fn_with_ctx) + return wrapper_fn.call_function( + tx, [self.context, self.wrapped] + list(args), kwargs + ) self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) @@ -2094,6 +2124,9 @@ def __init__(self, value: Any, reason: str | None = None, **kwargs: Any) -> None def as_python_constant(self) -> Any: return self.value + def get_real_python_backed_value(self) -> Any: + return self.value + @classmethod def create_with_source(cls, value: Any, source: Source) -> "SkipFunctionVariable": # Use closure match guard (i.e. guard on __code__ object instead of @@ -2113,7 +2146,14 @@ def create_with_source(cls, value: Any, source: Source) -> "SkipFunctionVariable guard_on_source.make_guard(GuardBuilder.CLOSURE_MATCH) elif inspect.isbuiltin(value): - install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) + # Bound builtin methods (e.g. obj.__reduce_ex__) are created fresh + # on every attribute access, so their id() is unstable. Skip the + # id-based BUILTIN_MATCH guard for them — the type guard on + # the owner object is sufficient. + if not hasattr(value, "__self__") or isinstance( + value.__self__, types.ModuleType + ): + install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) elif not is_wrapper_or_member_descriptor(value): # These descriptors are not guaranteed to return the same object on # attribute lookup. They are unlikely to be changed, so we can skip @@ -2271,12 +2311,14 @@ def call_function( torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints)) if qualname == "allow_in_graph": explanation = ( - "Found an allow_in_graph decorator to a function which " - "is created inside the parent function that is getting " - "compiled. This is not supported for now." + "torch.compiler.allow_in_graph (or torch._dynamo.allow_in_graph) " + "was called inside a compiled region. Dynamically annotating functions " + "inside a compiled region is not supported." ) - # pyrefly: ignore [implicit-any] - hints = [] + hints = [ + "Apply @torch.compiler.allow_in_graph as a decorator before compilation, " + "not inside the compiled function.", + ] if self.reason: reason = self.reason else: @@ -2297,7 +2339,9 @@ def call_obj_hasattr( def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name in cmp_name_to_op_mapping: - return variables.GetAttrVariable(self, name) + return variables.GetAttrVariable( + self, name, py_type=type(getattr(self.value, name)) + ) return fn_var_getattr(tx, self.value, self.source, name) @@ -2333,6 +2377,11 @@ def call_function( args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: + if config.nested_graph_breaks: + wrapper_fn = UserFunctionVariable(polyfills._fn_with_ctx) + return wrapper_fn.call_function( + tx, [self.context, self.wrapped] + list(args), kwargs + ) self.context.enter(tx) result = super().call_function(tx, args, kwargs) self.context.exit(tx) @@ -2352,6 +2401,9 @@ class WrapperUserFunctionVariable(BaseUserFunctionVariable): __script_if_tracing_wrapper have the original attr at "__original_fn". """ + def python_type(self) -> type: + return types.FunctionType + def __init__(self, wrapper_obj: Any, attr_to_trace: str, **kwargs: Any) -> None: super().__init__(**kwargs) self.wrapper_obj = wrapper_obj @@ -2426,6 +2478,9 @@ def call_function( kwargs, ) + def get_real_python_backed_value(self) -> object: + return getattr(self.wrapper_obj, self.attr_to_trace) + class WrapperUserMethodVariable(WrapperUserFunctionVariable): """ @@ -2434,6 +2489,9 @@ class WrapperUserMethodVariable(WrapperUserFunctionVariable): WrapperUserFunctionVariable in `call_function` method. """ + def python_type(self) -> type: + return types.MethodType + def __init__( self, wrapper_obj: Any, @@ -2565,7 +2623,7 @@ def call_function( "`P2POp` used incorrectly" ) - ops = list() + ops: list[VariableTracker] = list() peers = list() tags = list() tensors = list() @@ -2662,7 +2720,7 @@ def call_function( raise_observed_exception( type(exc), tx, - args=[VariableTracker.build(tx, a) for a in exc.args], + args=list(exc.args), ) return variables.UserDefinedClassVariable( value, @@ -2679,6 +2737,9 @@ def call_function( class FunctoolsPartialVariable(VariableTracker): + # partial_type_spec: https://github.com/python/cpython/blob/v3.13.0/Modules/_functoolsmodule.c#L538 + _cpython_type = functools.partial + _nonvar_fields = { "original_cache_hash", *VariableTracker._nonvar_fields, @@ -2752,7 +2813,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker items = {VariableTracker.build(tx, k): v for k, v in self.keywords.items()} return variables.ConstDictVariable(items, source=source) if name in cmp_name_to_op_mapping: - return variables.GetAttrVariable(self, name) + return variables.GetAttrVariable( + self, name, py_type=type(getattr(functools.partial, name)) + ) raise_observed_exception(AttributeError, tx) def as_python_constant(self) -> Any: @@ -2927,7 +2990,7 @@ def call_method( method = getattr(self.fn, name, None) if not (method or is_function(method)): - raise_type_error_exc(tx, f"Cannot find callable {name} in {self.fn}") + raise_type_error(tx, f"Cannot find callable {name} in {self.fn}") options = {} if self.source: options["source"] = AttrSource(self.source, name) @@ -2943,6 +3006,9 @@ def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value + def python_type(self) -> type: + return types.BuiltinFunctionType + def exc_info(self, tx: "InstructionTranslator") -> "variables.TupleVariable": if len(tx.exn_vt_stack): exn = tx.exn_vt_stack[-1] @@ -2951,9 +3017,9 @@ def exc_info(self, tx: "InstructionTranslator") -> "variables.TupleVariable": items = [VariableTracker.build(tx, typ), exn, tb] else: items = [ - variables.CONSTANT_VARIABLE_NONE, - variables.CONSTANT_VARIABLE_NONE, - variables.CONSTANT_VARIABLE_NONE, + ConstantVariable.create(None), + ConstantVariable.create(None), + ConstantVariable.create(None), ] return variables.TupleVariable(items) # type: ignore[arg-type] @@ -3045,9 +3111,12 @@ def wrap_user_defined_obj( from .builder import VariableBuilder assert tx is not None + # Route through VariableBuilder.__call__ so already-tracked mutable + # objects (for example autotuner config lists) are reused instead of + # being registered for mutation twice in the same trace. wrapped_user_obj = VariableBuilder( tx, AttrSource(variable.kernel_source, f"{name}") - )._wrap(user_obj) + )(user_obj) return wrapped_user_obj def maybe_unpack_configs( @@ -3086,7 +3155,7 @@ def call_getitem( # type: ignore[override] kernel=variable.kernel, kernel_idx=variable.kernel_idx, grid=args[0], - kernel_source=variable.source, + kernel_source=variable.kernel_source, ) def call_HOP( @@ -3171,15 +3240,18 @@ class TritonKernelVariable(VariableTracker): grid: "TritonGridType" kernel: "TritonKernelType" kernel_idx: int | None - kernel_source: "AttrSource" + kernel_source: Source | None def __init__( self, kernel: Any, kernel_idx: int | None, grid: Any, **kwargs: Any ) -> None: - self.kernel_source = kwargs.pop("kernel_source", None) + self.kernel_source = kwargs.pop("kernel_source", kwargs.get("source")) super().__init__(**kwargs) dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) + def python_type(self) -> type: + return type(self.kernel) + def call_function( self, tx: "InstructionTranslator", @@ -3190,6 +3262,15 @@ def call_function( self, args, kwargs, tx ) + def mp_subscript_impl( + self, + tx: "InstructionTranslator", + key: VariableTracker, + ) -> VariableTracker: + # Triton kernel[grid] — triton-specific, not a CPython slot. + # TODO(follow-up): add test for invalid key type + return dynamo_triton_hopifier_singleton.call_getitem(self, [key]) + def call_method( self, tx: "InstructionTranslator", @@ -3197,9 +3278,7 @@ def call_method( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - if name == "__getitem__": - return dynamo_triton_hopifier_singleton.call_getitem(self, args) - elif name == "run": + if name == "run": return dynamo_triton_hopifier_singleton.call_run(self, args, kwargs, tx) # type: ignore[return-value] # Bail out to parent's implementation @@ -3297,6 +3376,9 @@ def __init__( super().__init__(**kwargs) self.rank = rank + def python_type(self) -> type: + return types.FunctionType + def call_function( self, tx: "InstructionTranslator", @@ -3321,7 +3403,7 @@ def call_function( if self.rank == 1: if len(args) + len(kwargs) != 4: - raise_type_error_exc( + raise_type_error( tx, f"TMA metadata rank=1 requires exactly 4 arguments, got {len(args) + len(kwargs)}", ) @@ -3333,7 +3415,7 @@ def call_function( ] else: if len(args) + len(kwargs) != 6: - raise_type_error_exc( + raise_type_error( tx, f"TMA metadata rank=2 requires exactly 6 arguments, got {len(args) + len(kwargs)}", ) @@ -3359,6 +3441,9 @@ def call_function( class CreateTMADescriptorStableVariable(VariableTracker): + def python_type(self) -> type: + return types.FunctionType + def call_function( self, tx: "InstructionTranslator", @@ -3396,9 +3481,8 @@ def call_function( kwargs: dict[str, VariableTracker], ) -> VariableTracker: if len(args) != 1: - raise_type_error_exc( - tx, - f"pytree_get_node_type requires exactly 1 argument, got {len(args)}", + raise_type_error( + tx, f"pytree_get_node_type requires exactly 1 argument, got {len(args)}" ) type_source = None if args[0].source: @@ -3435,13 +3519,12 @@ def call_function( ) -> VariableTracker: # tree_is_leaf(tree, is_leaf=None) if len(args) < 1 or len(args) > 2: - raise_type_error_exc( - tx, - f"tree_is_leaf requires 1 or 2 arguments, got {len(args)}", + raise_type_error( + tx, f"tree_is_leaf requires 1 or 2 arguments, got {len(args)}" ) # Check if is_leaf parameter is provided - is_leaf = kwargs.get("is_leaf", CONSTANT_VARIABLE_NONE) + is_leaf = kwargs.get("is_leaf", ConstantVariable.create(None)) if len(args) == 2: is_leaf = args[1] @@ -3495,14 +3578,67 @@ def call_function( ) -class TritonSetAllocatorSkipVariable(SkipFunctionVariable): - """ - Skip variable for triton.set_allocator with a clear message to move it outside the compiled region. +def emit_noargs_leaf_function_to_graph( + tx: "InstructionTranslator", + real_impl: Callable[[], None], + name: str, +) -> None: + """Emit an invoke_leaf_function node for a side-effectful function with no + tensor inputs or outputs. + + The function is captured as a closure inside _LeafCallable objects and + registered as a static attribute on the graph module. Because + invoke_leaf_function is registered as EffectType.ORDERED, effect tokens + prevent DCE and maintain execution ordering relative to other ops. + + Use this when Dynamo needs to preserve a pure-side-effect call (like + setting global runtime state) in the compiled graph so that it replays + at the correct position at runtime. """ + import torch.utils._pytree as pytree + from torch._higher_order_ops.invoke_leaf_function import ( + _LeafCallable, + invoke_leaf_function, + make_leaf_function_wrappers, + ) + + def fake_impl(): + return None + + captured_out_spec: list[pytree.TreeSpec | None] = [None] + wrapped_real, wrapped_fake = make_leaf_function_wrappers( + real_impl, fake_impl, captured_out_spec + ) + + real_callable = _LeafCallable(wrapped_real) + fake_callable = _LeafCallable(wrapped_fake) + input_spec = pytree.tree_flatten(((), {}))[1] + + def make_proxy(attr_name: str, val: Any) -> Any: + proxy = tx.output.register_static_attr_and_return_proxy(attr_name, val) + proxy.node.type = type(val) + return proxy + + invoke_args = ( + make_proxy(f"{name}_real_fn", real_callable), + make_proxy(f"{name}_fake_fn", fake_callable), + make_proxy(f"{name}_input_spec", input_spec), + "", # mutated_flat_indices + ) + tx.output.create_proxy("call_function", invoke_leaf_function, invoke_args, {}) + + +class TritonSetAllocatorVariable(VariableTracker): + """Trace triton.set_allocator as an invoke_leaf_function node in the + graph so that it executes at the right point at runtime, ordered by + effect tokens.""" def __init__(self, value: Any, **kwargs: Any) -> None: - reason = "triton.set_allocator is not supported inside torch.compile" - super().__init__(value, reason=reason, **kwargs) + super().__init__(**kwargs) + self.value = value + + def python_type(self) -> type: + return type(self.value) def call_function( self, @@ -3510,15 +3646,16 @@ def call_function( args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - unimplemented( - gb_type="triton.set_allocator not supported", - context="triton.set_allocator called inside compiled region", - explanation=( - "triton.set_allocator is not supported inside torch.compile. " - "It modifies global Triton allocator state and cannot be traced." - ), - hints=[ - "Move triton.set_allocator() outside of the torch.compile region " - "(call it before the compiled function)." - ], - ) + assert len(args) == 1 and not kwargs + alloc_fn = args[0].as_python_constant() + + # Emit an invoke_leaf_function node so it runs at runtime. + set_allocator = self.value + + def real_impl(): + set_allocator(alloc_fn) + return None + + emit_noargs_leaf_function_to_graph(tx, real_impl, "set_alloc") + + return ConstantVariable.create(None) diff --git a/torch/_dynamo/variables/hashable.py b/torch/_dynamo/variables/hashable.py new file mode 100644 index 0000000000000..5896479eaa536 --- /dev/null +++ b/torch/_dynamo/variables/hashable.py @@ -0,0 +1,163 @@ +""" +Hashability utilities for PyTorch Dynamo variable tracking. + +This module provides the HashableTracker wrapper class and associated utilities +for making VariableTracker instances usable as dictionary keys and set elements +during symbolic execution. Used by both ConstDictVariable and SetVariable. +""" + +from typing import TYPE_CHECKING + +import torch + +from .. import variables +from ..exc import raise_observed_exception +from ..utils import specialize_symnode +from .base import VariableTracker + + +if TYPE_CHECKING: + from torch._dynamo.symbolic_convert import InstructionTranslator + + +def raise_unhashable( + arg: VariableTracker, tx: "InstructionTranslator | None" = None +) -> None: + if tx is None: + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + try: + arg_type = arg.python_type() + except Exception: + arg_type = type(arg) + + raise_observed_exception( + TypeError, + tx, + args=[ + f"unhashable type: {arg_type!r} and variable tracker = {type(arg.realize())}", + ], + ) + + +def is_hashable(x: VariableTracker) -> bool: + # NB - performing isinstance check on a LazVT realizes the VT, accidentally + # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at + # the underlying value without realizing the VT. Consider updating the + # lazyVT `is_hashable` method if you see unnecessary guarding for a key VT. + if ( + isinstance(x, variables.LazyVariableTracker) + and not x.is_realized() + and x.is_hashable() + ): + return True + return x.is_python_hashable() + + +class HashableTracker: + """ + Class that wraps a VariableTracker and makes it hashable. + Note that it's fine to put VTs into dictionaries and sets, but doing so + does not take into account aliasing. + """ + + _MISSING = object() + + def __init__(self, vt: VariableTracker) -> None: + # We specialize SymNodes + vt = specialize_symnode(vt) + + # If Dynamo does not know the hashability of the vt, it will raise unsupported here + # TODO(follow-up): check tp_hash via C-level slot detection — unhashable keys + # (e.g. list) should raise TypeError, not graph break via is_python_hashable/unimplemented. + if not is_hashable(vt): + raise_unhashable(vt) + self.vt = vt + + @classmethod + def _maybe_constant_torch_size(cls, vt: VariableTracker) -> object: + from .lists import SizeVariable + from .tensor import TensorVariable + + if ( + isinstance(vt, variables.LazyVariableTracker) + and not vt.is_realized() + and isinstance(vt.original_value(), torch.Size) + ): + return vt.original_value() + + if not isinstance(vt, SizeVariable): + return cls._MISSING + + items = [] + for item in vt.items: + if item.is_python_constant(): + items.append(item.as_python_constant()) + continue + + if isinstance(item, TensorVariable): + proxy = getattr(item, "proxy", None) + node = getattr(proxy, "node", None) + meta = getattr(node, "meta", None) if node is not None else None + example_value = ( + meta.get("example_value") if isinstance(meta, dict) else None + ) + constant = getattr(example_value, "constant", None) + + if isinstance(constant, torch.Tensor) and constant.numel() == 1: + items.append(constant.item()) + continue + + return cls._MISSING + + return torch.Size(items) + + def __hash__(self) -> int: + """ + Computes the hash value for the wrapped VariableTracker. + + For unrealized LazyVariableTrackers, uses the hash of the original value + to avoid realizing the tracker and inserting unnecessary guards. + For all other cases, delegates to the VariableTracker's get_python_hash method. + + Returns: + The hash value of the underlying variable tracker + """ + if ( + isinstance(self.vt, variables.LazyVariableTracker) + and not self.vt.is_realized() + and self.vt.is_hashable() + ): + return hash(self.vt.original_value()) + + maybe_constant = self._maybe_constant_torch_size(self.vt) + if maybe_constant is not self._MISSING: + return hash(maybe_constant) + + return self.vt.get_python_hash() + + def __eq__(self, other: object) -> bool: + """ + Checks equality between two HashableTracker instances. + + Delegates to the VariableTracker's is_python_equal method to compare + the underlying variable trackers for Python-level equality. + + Args: + other: Another HashableTracker instance to compare with + + Returns: + True if the underlying variable trackers are Python-equal, False otherwise + """ + if not isinstance(other, HashableTracker): + return False + if self.vt is other.vt: + return True + + self_constant = self._maybe_constant_torch_size(self.vt) + other_constant = self._maybe_constant_torch_size(other.vt) + if self_constant is not self._MISSING and other_constant is not self._MISSING: + return self_constant == other_constant + + return self.vt.is_python_equal(other.vt) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index cef823174fd44..653272d586e8d 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -28,17 +28,18 @@ import warnings from collections.abc import Sequence from dataclasses import dataclass, field -from typing import Any, Literal, Optional, TYPE_CHECKING, Union +from typing import Any, cast, Literal, Optional, TYPE_CHECKING, Union import torch._C import torch.fx import torch.nn from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.utils import get_fake_value -from torch._dynamo.variables.constant import CONSTANT_VARIABLE_NONE, ConstantVariable +from torch._dynamo.variables.constant import ConstantVariable from torch._dynamo.variables.ctx_manager import RepararametrizeModuleContextVariable from torch._dynamo.variables.functions import UserFunctionVariable from torch._dynamo.variables.nn_module import UnspecializedNNModuleVariable +from torch._dynamo.variables.script_object import TorchScriptObjectVariable from torch._dynamo.variables.tensor import SymNodeVariable, TensorVariable from torch._guards import Source from torch._higher_order_ops.invoke_subgraph import NestedCompileRegionOptions @@ -63,6 +64,7 @@ from .dicts import ConstDictVariable from .lazy import LazyVariableTracker from .lists import ListVariable, TupleVariable +from .sets import SetVariable if TYPE_CHECKING: @@ -232,6 +234,9 @@ def find_mismatched_vars( elif isinstance(var, ConstDictVariable): for value in var.items.values(): mismatched_vars.update(find_mismatched_vars(value, types, allow_none)) + elif isinstance(var, SetVariable): + for key in var.items: + mismatched_vars.update(find_mismatched_vars(key.vt, types, allow_none)) else: if not isinstance(var, types) and not (allow_none and var.is_constant_none()): mismatched_vars.add(var) @@ -322,8 +327,17 @@ def overwrite_tensor_vt_proxy( # while still allowing `body_r` to contain arbitrary Python objects. # pyrefly: ignore[missing-attribute] for orig_vt, subgraph_vt in zip(graph_output_vts, flat_variable.items): - if isinstance(orig_vt, (variables.SymNodeVariable, variables.TensorVariable)): - assert subgraph_vt.is_tensor() or isinstance(subgraph_vt, SymNodeVariable) + if isinstance( + orig_vt, + ( + variables.SymNodeVariable, + variables.TensorVariable, + TorchScriptObjectVariable, + ), + ): + assert subgraph_vt.is_tensor() or isinstance( + subgraph_vt, (SymNodeVariable, TorchScriptObjectVariable) + ) orig_vt.proxy = subgraph_vt.proxy @@ -558,6 +572,24 @@ def check_and_track(self, proxy_node: Proxy) -> bool: return True +def taint_filtered_vt(vt: VariableTracker) -> None: + """Mark a VT as filtered due to aliasing so it raises a clear error if used.""" + original_as_proxy = vt.as_proxy + + def tainted_as_proxy() -> Any: + proxy = original_as_proxy() + raise RuntimeError( + f"An intermediate tensor '{proxy.node.name}' created inside a " + f"higher-order op subgraph aliases an input or output, so it was " + f"not included in the subgraph outputs. However, it is being used " + f"in the outer graph (e.g., via a side effect like list.append). " + f"To fix this, clone the tensor before capturing it " + f"(e.g., use tensor.clone() instead of tensor)." + ) + + vt.as_proxy = tainted_as_proxy # type: ignore[method-assign] + + def collect_intermediate_outputs( tx: "InstructionTranslator", subtracer: "SubgraphTracer", @@ -574,7 +606,7 @@ def collect_intermediate_outputs( tracker.collect_from_inputs(tx) tracker.collect_from_outputs(graph_output_vts) - for out in subtracer.tracked_tensor_or_symint_vt: + for out in subtracer.tracked_proxyable_vt: proxy = out.as_proxy() # Skip if already in output @@ -590,12 +622,11 @@ def collect_intermediate_outputs( else: # Filter out intermediates that alias with inputs or outputs. # This is needed for HOPs like invoke_subgraph that don't support aliasing. - # TODO: If a filtered intermediate is captured by side effects (e.g., appended - # to a list), it will fail later with "does not belong to this Graph" error - # when the outer graph tries to use it. See test_side_effect_with_aliased_intermediate. assert tracker is not None if tracker.check_and_track(proxy.node): extra_outputs.append(out) + else: + taint_filtered_vt(out) return extra_outputs @@ -898,6 +929,7 @@ def are_same_graph_modules( from torch._subclasses.fake_tensor import extract_tensor_metadata # Maps the equivalent nodes from a to b + # pyrefly: ignore [implicit-any] node_map = {} def check_all_args(a_nodes: Iterable[Any], b_nodes: Iterable[Any]) -> bool: @@ -1696,10 +1728,12 @@ def gn(x): graph_output_vt_list = [] def visit(vt: VariableTracker) -> None: - if vt.is_tensor() or isinstance(vt, SymNodeVariable): + if vt.is_tensor() or isinstance( + vt, (SymNodeVariable, TorchScriptObjectVariable) + ): graph_output_vt_list.append(vt) - VariableTracker.visit(visit, output) + VariableTracker.visit(visit, output, side_effects=tx.output.side_effects) graph_output_vts = tuple(graph_output_vt_list) # NOTE - [Return subgraph intermediates as subgraph outputs] @@ -1740,11 +1774,13 @@ def visit(vt: VariableTracker) -> None: # nested_compile_region and autograd.Function. Today, its safe # because we error out on seeing a side-effect. - allow_side_effects = ( - allow_side_effects - or tx.output.current_tracer.traced_with_externally_visible_side_effects + traced_externally = ( + tx.output.current_tracer.traced_with_externally_visible_side_effects ) - if allow_side_effects: + has_side_effects = ( + subtracer.side_effect_stack is not None or traced_externally + ) + if (allow_side_effects or traced_externally) and has_side_effects: extra_outputs = collect_intermediate_outputs( tx, subtracer, graph_output_vts, filter_aliased_intermediates ) @@ -1823,7 +1859,7 @@ def visit(vt: VariableTracker) -> None: f"fall back to eager-mode PyTorch, which could lead to a slowdown." ) log.info(msg) - log.info(ex) # noqa: G200 + log.info(ex) raise ex @@ -2023,7 +2059,7 @@ def speculate_subgraph( f"fall back to eager-mode PyTorch, which could lead to a slowdown." ) log.info(msg) - log.info(ex) # noqa: G200 + log.info(ex) raise ex @@ -2476,7 +2512,7 @@ def _call_function( def validate_subgraph_output_types( output: VariableTracker | Sequence[VariableTracker], ) -> None: - """Verify that that the output of the subgraph is a tensor, + """Verify that the output of the subgraph is a tensor, int, bool, SymBool, or SymInt. """ from . import TensorVariable @@ -2486,10 +2522,12 @@ def validate_subgraph_output_types( ): for out in non_tensor_output: if ( - isinstance(out, SymNodeVariable) and out.python_type() in (int, bool) - ) or ( - out.is_python_constant() - and isinstance(out.as_python_constant(), (int, bool)) + (isinstance(out, SymNodeVariable) and out.python_type() in (int, bool)) + or ( + out.is_python_constant() + and isinstance(out.as_python_constant(), (int, bool)) + ) + or isinstance(out, TorchScriptObjectVariable) ): continue unimplemented( @@ -3254,15 +3292,6 @@ def call_function( args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - if not torch._dynamo.config.inline_inbuilt_nn_modules: - unimplemented( - gb_type="torch.func.functional_call capture is disabled", - context="", - explanation="torch.func.functional_call capture is disabled", - hints=[ - "Set `torch._dynamo.config.inline_inbuilt_nn_modules=True` to enable.", - ], - ) return super().call_function(tx, args, kwargs) @@ -3749,11 +3778,13 @@ def _call_function( unpacked_sequence = args[1].unpack_var_sequence(tx) # TODO (tmanlaibaatar) support pytree here for arg in unpacked_sequence: - if isinstance(arg, (ListVariable, TupleVariable, ConstDictVariable)): + if isinstance( + arg, (ListVariable, TupleVariable, ConstDictVariable, SetVariable) + ): unimplemented( gb_type="strict_mode: improper args", context=f"args: {args}, kwargs: {kwargs}", - explanation="strict_mode higher order op expects flat inputs (list/tuple/dict)", + explanation="strict_mode higher order op expects flat inputs (list/tuple/dict/set)", hints=[ *graph_break_hints.USER_ERROR, ], @@ -3989,6 +4020,31 @@ def _call_function( ) +class InlineAsmElementwiseHigherOrderVariable(TorchHigherOrderOperatorVariable): + _HOP_NAME = "inline_asm_elementwise" + + def _call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .builder import wrap_fx_proxy + + p_args = tuple(arg.as_proxy() for arg in args) + p_kwargs = {key: arg.as_proxy() for key, arg in kwargs.items()} + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=p_args, + kwargs=p_kwargs, + ), + example_value=None, + ) + + class AutoFunctionalizeHigherOrderVariable(TorchHigherOrderOperatorVariable): _HOP_NAME = "torch.ops.higher_order.auto_functionalized" @@ -4017,6 +4073,14 @@ def _call_function( class FlexAttentionBackwardHighOrderVariable(TorchHigherOrderOperatorVariable): _HOP_NAME = "torch.ops.higher_order.flex_attention_backward" + @staticmethod + def _uses_pretraced_graphs( + fw_graph: VariableTracker, joint_graph: VariableTracker + ) -> bool: + return not joint_graph.is_constant_none() or isinstance( + fw_graph, UnspecializedNNModuleVariable + ) + def proxy_submod( self, tx: "InstructionTranslator", arg: UnspecializedNNModuleVariable ) -> Proxy: @@ -4036,6 +4100,95 @@ def to_proxy(self, tx: "InstructionTranslator", arg: VariableTracker) -> Any: else: return arg.as_proxy() + def create_wrapped_node( + self, + tx: "InstructionTranslator", + query: VariableTracker, + fn: VariableTracker, + fn_name: str, + other_buffers: Sequence[VariableTracker], + ) -> tuple[Proxy, tuple[Proxy, ...], torch.fx.GraphModule]: + from .._trace_wrapped_higher_order_op import TransformGetItemToIndex + + def create_scalar() -> VariableTracker: + return query.call_method( + tx, + "new_empty", + [VariableTracker.build(tx, [])], + {"dtype": VariableTracker.build(tx, torch.int32)}, + ) + + with discard_graph_changes(tx): + bhmn = [create_scalar() for _ in range(4)] + if fn_name == "score_mod": + scores_require_grad: bool = query.requires_grad # type: ignore[attr-defined] + score = query.call_method( + tx, + "new_empty", + [VariableTracker.build(tx, [])], + {"requires_grad": VariableTracker.build(tx, scores_require_grad)}, + ) + new_args = [score, *bhmn, *other_buffers] + else: + assert fn_name == "mask_fn", "Illegal function name: " + fn_name + new_args = [*bhmn, *other_buffers] + + with TransformGetItemToIndex(): + ( + (_body_output, _body_spec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + fn, + new_args, + {}, + description=f"{self._HOP_NAME}: {fn_name}", + source_target=self.value, + set_subgraph_inputs="flatten_manual", + ) + + gm = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = tx.output.install_subgraph(fn_name, gm) + return make_attr(tx, body_name), tuple(body_lifted_freevars), gm + + @staticmethod + def _buffer_example_value(buffer: VariableTracker) -> Any: + proxy = buffer.as_proxy() + if isinstance(proxy, Proxy): + return proxy.node.meta["example_value"] + return proxy + + def _derive_joint_graph( + self, + tx: "InstructionTranslator", + query: VariableTracker, + fw_graph_gm: torch.fx.GraphModule, + score_mod_other_buffers: TupleVariable, + fw_graph_lifted_args: tuple[Proxy, ...], + ) -> Proxy: + from torch._higher_order_ops.flex_attention import create_fw_bw_graph + + query_example = query.as_proxy().node.meta["example_value"] + example_vals = ( + query_example.new_zeros((), requires_grad=True), + query_example.new_zeros((), dtype=torch.int), + query_example.new_zeros((), dtype=torch.int), + query_example.new_zeros((), dtype=torch.int), + query_example.new_zeros((), dtype=torch.int), + ) + all_buffer_examples = tuple( + self._buffer_example_value(buf) for buf in score_mod_other_buffers.items + ) + tuple(buf.node.meta["example_value"] for buf in fw_graph_lifted_args) + + _, joint_gm = create_fw_bw_graph(fw_graph_gm, example_vals, all_buffer_examples) + joint_gm = cast(GraphModule, joint_gm) + + submod_name = tx.output.install_subgraph("joint_graph", joint_gm) + p_submod = make_attr(tx, submod_name) + set_example_value(p_submod.node, joint_gm) + return p_submod + def _call_function( self, tx: "InstructionTranslator", @@ -4044,6 +4197,130 @@ def _call_function( ) -> VariableTracker: from .builder import wrap_fx_proxy + if len(args) != 14: + return self._call_function_fallback(tx, args, kwargs) + + ( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) = args + + if ( + not isinstance(block_mask, TupleVariable) + or not isinstance(score_mod_other_buffers, TupleVariable) + or not isinstance(mask_mod_other_buffers, TupleVariable) + or len(block_mask.items) < 1 + ): + return self._call_function_fallback(tx, args, kwargs) + + if self._uses_pretraced_graphs(fw_graph, joint_graph): + return self._call_function_fallback(tx, args, kwargs) + + fw_graph_node, fw_graph_lifted_args, fw_graph_gm = self.create_wrapped_node( + tx, query, fw_graph, "score_mod", score_mod_other_buffers.items + ) + + joint_graph_node = self._derive_joint_graph( + tx, + query, + fw_graph_gm, + score_mod_other_buffers, + fw_graph_lifted_args, + ) + + mask_fn = block_mask.items[-1] + if mask_fn.is_python_constant() and mask_fn.as_python_constant() is None: + mask_fn = VariableTracker.build( + tx, + torch.nn.attention.flex_attention.noop_mask, + source=mask_fn.source, + ) + mask_fn_node, mask_fn_lifted_args, _ = self.create_wrapped_node( + tx, query, mask_fn, "mask_fn", mask_mod_other_buffers.items + ) + + proxied_args = [ + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + TupleVariable(block_mask.items[:-1], source=block_mask.source), + scale, + kernel_options, + ] + inp_args, _ = proxy_args_kwargs(proxied_args, {}) + proxied_score_mod_other_buffers = tuple( + self.to_proxy(tx, arg) for arg in score_mod_other_buffers.items + ) + proxied_mask_mod_other_buffers = tuple( + self.to_proxy(tx, arg) for arg in mask_mod_other_buffers.items + ) + + ( + inp_q, + inp_k, + inp_v, + inp_out, + inp_lse, + inp_grad_out, + inp_grad_lse, + inp_block_mask, + inp_scale, + inp_kernel_options, + ) = inp_args + + block_mask_proxy = tuple(inp_block_mask + (mask_fn_node,)) + + with torch.fx.experimental.proxy_tensor.set_original_aten_op(self.value): + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + self.value, + args=( + inp_q, + inp_k, + inp_v, + inp_out, + inp_lse, + inp_grad_out, + inp_grad_lse, + fw_graph_node, + joint_graph_node, + block_mask_proxy, + inp_scale, + inp_kernel_options, + proxied_score_mod_other_buffers + fw_graph_lifted_args, + proxied_mask_mod_other_buffers + mask_fn_lifted_args, + ), + kwargs={}, + ), + example_value=None, + ) + + def _call_function_fallback( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .builder import wrap_fx_proxy + p_args, p_kwargs = None, None try: p_args = tuple(self.to_proxy(tx, arg) for arg in args) @@ -4264,6 +4541,9 @@ def __init__( self.bwd_fn = bwd_fn self.parent_source = parent_source + def python_type(self) -> type: + return types.BuiltinMethodType + def call_function( self, tx: "InstructionTranslator", @@ -4502,6 +4782,7 @@ def trace_forward_graph( enable_grad=None, set_subgraph_inputs="automatic", allow_side_effects=True, + filter_aliased_intermediates=True, tracer=fwd_tracer, ) ) @@ -4573,7 +4854,7 @@ def trace_backward_graph( if i.is_tensor(): bwd_args.append(i) else: - bwd_args.append(CONSTANT_VARIABLE_NONE) + bwd_args.append(ConstantVariable.create(None)) bwd_fn, bwd_args = self.prepare_fn_vt(tx, ctx, "backward", bwd_args) @@ -5135,150 +5416,6 @@ def _call_function( ) -class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable): - _HOP_NAME = "torch.ops.higher_order.invoke_subgraph" - _ALLOW_FALLBACK_TO_EAGER = False - supports_input_mutation = True - supports_aliasing = False - allow_side_effects = True - # invoke_subgraph is NOT desugared in AOTAutograd, so the HOP input/output - # shouldn't alias. For checkpoint HOP, we inline it so we don't need - # alias analysis as functionalization would just work on the flat graph. - filter_aliased_intermediates = True - - # pyrefly: ignore[bad-override] - def install_subgraph_in_output_graph( - self, - tx: "InstructionTranslator", - fn_vt: VariableTracker, - fn_args_vt: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - body_gmod: GraphModule, - attr_name: str, - ) -> str: - # Check if the subgraph from speculate_subgraph (body_gmod) and the fake - # inputs have already been seen before. If yes, the subgraph is already - # installed in the output graph and we can just access the subgraph - # using the saved attr name. - - if not isinstance(fn_vt, (UnspecializedNNModuleVariable, UserFunctionVariable)): - unimplemented( - gb_type="Encountered non user function variable during invoke_subgraph HOP tracing", - context=str(fn_vt), - explanation="invoke_subgraph does not support non user function variable", - hints=[*graph_break_hints.SUPPORTABLE], - ) - - invoke_subgraph_cache = ( - tx.output.tracing_context.hop_dispatch_set_cache.get_cache( - torch._higher_order_ops.invoke_subgraph - ) - ) - - if isinstance(fn_vt, UserFunctionVariable): - fn_id = id(fn_vt.get_function()) - fn_name = fn_vt.get_function().__name__ - else: - assert isinstance(fn_vt, UnspecializedNNModuleVariable) - fn_id = id(fn_vt.value.forward.__func__) # type: ignore[attr-defined] - fn_name = fn_vt.value.forward.__name__ # type: ignore[attr-defined] - # pyrefly: ignore [implicit-any] - previously_installed_submodules = [] - if invoke_subgraph_cache: - previously_installed_submodules = ( - invoke_subgraph_cache.get_dynamo_installed_submodules(fn_id) - ) - current_mod = body_gmod - # NB - reverse is more likely to cause a hit sooner because first - # graph can have requires_grad=False for a few inputs - for submodule_name in reversed(previously_installed_submodules): - assert submodule_name in tx.output.nn_modules - previous_mod = tx.output.nn_modules[submodule_name] - assert tx.fake_mode - if are_same_graph_modules( - fn_name, previous_mod, current_mod, tx.fake_mode - ): - return submodule_name - - body_name = super().install_subgraph_in_output_graph( - tx, fn_vt, fn_args_vt, kwargs, body_gmod, "subgraph" - ) - hc_log.debug( - "%s: Installing subgraph with identifier '%s', bringing total count for '%s' function to %s", - fn_name, - body_name, - fn_name, - len(previously_installed_submodules) + 1, - ) - if invoke_subgraph_cache: - invoke_subgraph_cache.add_dynamo_installed_submodule(fn_id, body_name) - - return body_name - - def _call_function( - self, - tx: "InstructionTranslator", - args: Sequence[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - fn_var = args[0] - fn_args_vt = args[1:] - - config = None - if hasattr(fn_var, "get_function"): - try: - fn = fn_var.get_function() - config = getattr(fn, "__marked_compile_region_config__", None) - except Exception: - log.warning( - "Failed to extract nested_compile_region() config from InvokeSubgraphHigherOrderVariable. ", - exc_info=True, - ) - raise - - assert self._HOP_NAME is not None - ( - p_args, - p_kwargs, - example_value, - body_r, - body_gmod, - body_name, - body_graph_output_vts, - tracing_info, - ) = self.create_wrapped_node(tx, fn_var, fn_args_vt, kwargs, self._HOP_NAME) - - if len(p_kwargs) > 0: - unimplemented( - gb_type="invoke_subgraph: kwargs unexpected", - context=f"args: {args}, kwargs: {kwargs}", - explanation="kwargs should have been flattened into lifted args.", - hints=[ - *graph_break_hints.DYNAMO_BUG, - ], - ) - - if isinstance(config, NestedCompileRegionOptions): - body_gmod.meta["nested_region_config"] = config - - p_args = ( - p_args[0], - body_name, - *p_args[1:], - ) - - return _call_function_with_auto_output_flattening( # type: ignore[return-value] - tx, - torch._higher_order_ops.invoke_subgraph, - tuple(p_args), - p_kwargs, - example_value, - body_r, - body_graph_output_vts, - config=config, - ) - - class LocalMapWrappedHigherOrderVariable(WrapHigherOrderVariable): _HOP_NAME = "torch.ops.higher_order.local_map_hop" supports_input_mutation = False @@ -5424,6 +5561,7 @@ def check_none_last(placements: Sequence[Any | None]) -> int: priors[vt] = global_tensor vt.as_proxy().node.meta["example_value"] = local_tensor + # pyrefly: ignore [missing-attribute] vt.synchronize_attributes(tx) # Step 3: Trace local_map subgraph with local tensors @@ -5508,6 +5646,7 @@ def make_error_msg(*args: Any) -> str: # Step 6: Restore inputs and outputs to global shapes for vt, global_tensor in priors.items(): vt.as_proxy().node.meta["example_value"] = global_tensor + # pyrefly: ignore [missing-attribute] vt.synchronize_attributes(tx) outs = out.items if isinstance(out, TupleVariable) else [out] @@ -5547,6 +5686,9 @@ def make_error_msg(*args: Any) -> str: return out +from .invoke_subgraph import InvokeSubgraphHigherOrderVariable + + # Map operator names to their corresponding variable for fast TorchHigherOrderOperatorVariable.make() _hop_name_to_variable_class = { "cond": CondHigherOrderVariable, @@ -5574,6 +5716,7 @@ def make_error_msg(*args: Any) -> str: "dynamo_bypassing_wrapper": DynamoBypassingWrapperHigherOrderVariable, "auto_functionalized": AutoFunctionalizeHigherOrderVariable, "auto_functionalized_v2": AutoFunctionalizeHigherOrderVariable, + "inline_asm_elementwise": InlineAsmElementwiseHigherOrderVariable, "invoke_subgraph": InvokeSubgraphHigherOrderVariable, "custom_function_call": CustomFunctionHigherOrderOperatorVariable, "local_map_hop": LocalMapWrappedHigherOrderVariable, diff --git a/torch/_dynamo/variables/invoke_subgraph.py b/torch/_dynamo/variables/invoke_subgraph.py new file mode 100644 index 0000000000000..0c135b3060d29 --- /dev/null +++ b/torch/_dynamo/variables/invoke_subgraph.py @@ -0,0 +1,1344 @@ +""" +This module contains the InvokeSubgraphHigherOrderVariable class and its +supporting helpers for subgraph reuse (auto-cache) in Dynamo's invoke_subgraph +higher-order operator. +""" + +import enum +import logging +import traceback +import types +from dataclasses import dataclass +from typing import Any, cast, NamedTuple, TYPE_CHECKING + +import torch +import torch._higher_order_ops +from torch._dynamo import graph_break_hints +from torch._dynamo.exc import unimplemented +from torch._dynamo.guards import ( + extract_tensor_metadata, + GUARD_VALUE_DISPATCH, + GuardCheckSpec, + SKIP_GUARD, + UnsupportedGuardCheckSpec, +) +from torch._dynamo.source import SyntheticLocalSource +from torch._dynamo.variables.base import VariableTracker +from torch._dynamo.variables.constant import ConstantVariable +from torch._dynamo.variables.functions import UserFunctionVariable +from torch._dynamo.variables.higher_order_ops import WrapHigherOrderVariable +from torch._dynamo.variables.lists import ListVariable, TupleVariable +from torch._dynamo.variables.nn_module import UnspecializedNNModuleVariable +from torch._dynamo.variables.tensor import SymNodeVariable, TensorVariable +from torch._guards import ( + Guard, + InvokeSubgraphReuseCondition, + InvokeSubgraphReuseEntry, + Source, +) +from torch._higher_order_ops.invoke_subgraph import NestedCompileRegionOptions +from torch.fx.graph_module import GraphModule +from torch.fx.proxy import Proxy +from torch.utils import _pytree as pytree +from torch.utils._ordered_set import OrderedSet + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from torch._dynamo.symbolic_convert import InstructionTranslator + from torch._dynamo.variables.higher_order_ops import SubgraphTracingInfo + +log = logging.getLogger(__name__) +hc_log = torch._logging.getArtifactLogger(__name__, "hierarchical_compile") + +# Note: [invoke_subgraph subgraph reuse] +# +# When mark_compile_region wraps a function called N times (e.g. 80 identical +# transformer layers), Dynamo traces the subgraph once and stamps out cached +# copies for subsequent calls. It does safety checks to ensure that a subgraph +# is reusable, if not (e.g. side-effect), it will fallback to tracing the +# next invocation. +# +# HIGH-LEVEL FLOW +# =============== +# User code: model.layers[0](x), model.layers[1](x), ..., model.layers[79](x) +# | | | +# v v v +# +--------------+ +--------------+ +--------------+ +# | First Call | | Second Call | ... | 80th Call | +# +------+-------+ +------+-------+ +------+-------+ +# | | | +# v v v +# +--------------+ +--------------+ +--------------+ +# | Full subgraph| | Cache lookup | | Cache lookup | +# | trace | | (is_reusable) | | (is_reusable) | +# +------+-------+ +------+-------+ +------+-------+ +# | | | +# v v v +# +--------------+ +--------------+ +--------------+ +# | save_reuse_ | | stamp_out_ | | stamp_out_ | +# | entry() | | subgraph() | | subgraph() | +# +--------------+ +--------------+ +--------------+ +# +# WHAT GETS CACHED +# ================ +# After the first trace, save_reuse_entry stores an InvokeSubgraphReuseEntry +# (in _guards.py) containing: +# - body_name/body_gmod: the traced subgraph +# - arg_sources: sources of the original call's arguments +# - subgraph_input_mapping: how each lifted arg maps back to user inputs or captures +# - output_metadata: shape/stride/dtype/device of outputs +# +# Paired with an InvokeSubgraphReuseCondition containing: +# - input_checks: (tag, tensor_metadata) per input +# - guards: (source, handler, expected, guard) tuples +# - treespec: pytree structure of the args +# - traced_sources: sources accessed during the trace +# +# CACHE LOOKUP (is_reusable) +# ========================== +# On subsequent calls: +# 1. Input structure match -- same treespec, tags, tensor metadata. +# 2. Source replacement -- clone each guard's source with a replacement map +# (old: L['self'].layers[0].weight -> new: L['self'].layers[1].weight), +# then evaluate against the new source's runtime value. +# 3. Mutation check -- reject if the subgraph mutated any captured var. +# +# A shared resolve_cache memoizes intermediate source resolution (e.g. +# L['self'].layers is evaluated once and reused across all guards). +# +# STAMP OUT (stamp_out_subgraph) +# ============================== +# On cache hit, reconstruct the argument list using the freevar mapping +# (list[LiftedArgOrigin]): +# +# LiftedUserArg(index) +# User arg (activation / explicit input). +# Looked up from new call's flat proxies. +# +# LiftedCapturedSource(source) +# Sourceful captured var (weight, param, etc). +# Source is cloned with replacement map, resolved via +# VariableBuilder. Deduplicates via input_source_to_var. +# +# LiftedSyntheticObject(ctor_fn, ctor_args, ctor_arg_sources) +# Synthetic object (opaque type with SyntheticLocalSource). +# Reconstructed via synthetic_graph_input with cached constructor info. +# +# SAFETY +# ====== +# In normal Dynamo compilation, safety is enforced at runtime: guards are +# installed during tracing and re-evaluated on every subsequent call against +# real Python objects. Subgraph reuse operates differently — we are in the +# middle of tracing, there are no real Python objects, only VariableTrackers. +# We must answer: what could cause the second invocation of a nested compile +# region to produce a different trace than the first? +# +# VariableTrackers fall into two categories: +# +# 1. Intermediates — values produced during tracing with no originating source +# (e.g. the result of a prior FX op). These can reach a nested compile region +# only via (a) the region's explicit function arguments, or (b) closure +# capture. We do not support nested-function regions that close over tensors, +# so only (a) applies. For explicit arguments, the set of types that can +# appear is small and well-defined: TensorVariable, SymNodeVariable, +# ConstantVariable, and NNModuleVariable. Each has a cheap structural +# comparison (tensor metadata, symnode identity, constant value equality). +# We also snapshot the pytree treespec of the argument list and verify it +# matches on lookup, ensuring the flattened structure is identical. +# +# 2. Sourceful variables — values with a known originating source (e.g. a +# module attribute or a local variable visible in the outer frame). For these +# we collect the guard delta from the first trace, parameterize the guard +# sources by replacing the original arg sources with the new arg sources, and +# re-evaluate the guards by resolving each source against the live f_locals / +# f_globals. The one extra hazard here is mutation: if the outer trace +# mutates a sourceful object between the first and second invocations, the +# cached guards would evaluate against stale values. We therefore also check +# that none of the sources read by the cached subgraph have been mutated in +# the outer SideEffects tracker before accepting a reuse. +# +# - max_reuse_entries (default 8, configurable via nested_compile_region arg) +# caps cache entries per function. Exceeding it raises RuntimeError. +# - Guard failures logged with guard type + user stack trace. +# Enable: TORCH_LOGS='+hierarchical_compile' +# --------------------------------------------------------------------------- +# Auto-cache helpers for invoke_subgraph +# --------------------------------------------------------------------------- + + +class InputTag(enum.Enum): + TENSOR = "tensor" + SYMNODE = "symnode" + CONSTANT = "constant" + MODULE = "module" + + +class InputFingerprint(NamedTuple): + # (InputTag, VariableTracker) pairs for each leaf input. + flat_vts: list[tuple[InputTag, VariableTracker]] + # 1-1 mapping to flat_vts: source for each leaf, or None if the VT has no source. + arg_sources: list[Source | None] + # True if any leaf VT had an unsupported type for reuse. + has_unknown: bool = False + # TreeSpec from pytree.tree_flatten of the (args, kwargs) structure. + treespec: pytree.TreeSpec | None = None + + +def classify_vt(vt: Any) -> InputTag | None: + """Return the tag for a leaf VT, or None if unsupported.""" + if isinstance(vt, TensorVariable): + return InputTag.TENSOR + elif isinstance(vt, SymNodeVariable): + return InputTag.SYMNODE + elif isinstance(vt, ConstantVariable): + return InputTag.CONSTANT + elif isinstance(vt, UnspecializedNNModuleVariable): + return InputTag.MODULE + return None + + +def build_input_fingerprint( + tx: "InstructionTranslator", + fn_args_vt: Any, + kwargs: dict[str, Any], +) -> InputFingerprint: + """Build an InputFingerprint by flattening (args, kwargs) via pytree. + + Uses _make_inlined(tx, pytree.tree_flatten) to recursively flatten + the argument structure into leaf VTs, classifying each leaf as + tensor/symnode/constant/module. Also records the TreeSpec so that + cache lookups can verify structural equivalence. + + Fast path: when kwargs is empty and all args are already leaf VTs + (tensor/symnode/constant/module), skip the expensive pytree flatten. + """ + # Fast path: flat args, no kwargs — skip pytree machinery. + if not kwargs: + all_leaf = True + for vt in fn_args_vt: + if classify_vt(vt) is None: + all_leaf = False + break + if all_leaf: + return build_fingerprint_fast(fn_args_vt) + + return build_fingerprint_with_pytree(tx, fn_args_vt, kwargs) + + +def build_fingerprint_fast(fn_args_vt: Any) -> InputFingerprint: + """Build fingerprint for the common case of flat leaf args, no kwargs.""" + flat_vts: list[tuple[InputTag, VariableTracker]] = [] + arg_sources: list[Source | None] = [] + for vt in fn_args_vt: + tag = classify_vt(vt) + assert tag is not None + flat_vts.append((tag, vt)) + # Always append (even None) to keep positional alignment with flat_vts + # so that source_replacement zip pairing is correct across calls. + arg_sources.append(getattr(vt, "source", None)) + return InputFingerprint(flat_vts, arg_sources) + + +def build_fingerprint_with_pytree( + tx: "InstructionTranslator", + fn_args_vt: Any, + kwargs: dict[str, Any], +) -> InputFingerprint: + """Build fingerprint via pytree flatten for nested/kwargs cases.""" + from torch._dynamo.variables.builder import SourcelessBuilder + from torch._dynamo.variables.higher_order_ops import _make_inlined + + container_vt = SourcelessBuilder.create(tx, (list(fn_args_vt), kwargs)) + flat_list_vt, treespec_vt = _make_inlined(tx, pytree.tree_flatten)( + container_vt + ).unpack_var_sequence(tx) + treespec = treespec_vt.as_python_constant() + + flat_vts: list[tuple[InputTag, VariableTracker]] = [] + arg_sources: list[Source | None] = [] + has_unknown = False + + for vt in flat_list_vt.unpack_var_sequence(tx): + tag = classify_vt(vt) + if tag is not None: + flat_vts.append((tag, vt)) + else: + has_unknown = True + continue + + # Always append (even None) to keep positional alignment with flat_vts. + arg_sources.append(getattr(vt, "source", None)) + + return InputFingerprint(flat_vts, arg_sources, has_unknown, treespec) + + +def get_flat_proxies(fingerprint: InputFingerprint) -> list[Proxy]: + """Collect deduplicated proxies from tensor/symnode leaves.""" + seen: set[torch.fx.Node] = set() + flat_proxies: list[Proxy] = [] + for tag, vt in fingerprint.flat_vts: + if tag in (InputTag.TENSOR, InputTag.SYMNODE): + proxy = vt.as_proxy() + if proxy.node not in seen: + seen.add(proxy.node) + flat_proxies.append(proxy) + return flat_proxies + + +@dataclass +class LiftedUserArg: + """Lifted arg that came from a user argument (intermediate activation or explicit input).""" + + index: int + + +@dataclass +class LiftedCapturedSource: + """Lifted arg that is a captured variable (e.g. a weight or parameter) with a Source.""" + + source: Any # Source + + +@dataclass +class LiftedSyntheticObject: + """Lifted arg that is a TorchScriptObject with a SyntheticLocalSource.""" + + ctor_fn: Any # Callable + ctor_args: tuple[Any, ...] + ctor_arg_sources: tuple[Any, ...] | None + + +@dataclass +class LiftedBoundSymbol: + """Lifted arg that is a SymInt already bound as a graph input. + + SymInt graph inputs are created during tensor wrapping (not through + VariableBuilder.wrap_symint), so they aren't registered in + unspec_variable_map or variable_tracker_cache. Using LiftedCapturedSource + for these would resolve the source to a concrete Python int via + source.get_value() instead of reusing the existing symbolic proxy. + """ + + expr: Any # sympy.Expr + + +LiftedArgOrigin = ( + LiftedUserArg | LiftedCapturedSource | LiftedSyntheticObject | LiftedBoundSymbol +) + + +def get_fn_code(fn_var: Any) -> types.CodeType | None: + if isinstance(fn_var, UserFunctionVariable): + return fn_var.get_function().__code__ + elif isinstance(fn_var, UnspecializedNNModuleVariable): + return ( + fn_var.value.forward.__func__.__code__ # pyrefly: ignore[missing-attribute] + ) + return None + + +def has_mutated_vars( + tx: "InstructionTranslator", + traced_sources: OrderedSet[Source], +) -> bool: + """Check if any source accessed by the subgraph has been mutated. + + SideEffects.mutated_sources records the exact AttrSource for every + store_attr call. A simple set intersection with traced_sources tells + us whether any source the subgraph read was later written to. + """ + overlap = tx.output.side_effects.mutated_sources & traced_sources + if overlap: + hc_log.debug( + "subgraph_reuse: mutated sources detected -- %s", + overlap, + ) + return True + return False + + +def is_reuse_eligible( + tx: "InstructionTranslator", + body_r: Any, + fingerprint: InputFingerprint, + tracing_info: "SubgraphTracingInfo", + traced_sources: OrderedSet[Source] | None = None, + has_reuse_hash_fn: bool = False, +) -> bool: + """Best-effort check for whether a traced subgraph result can be reused. + + It is possible that a subgraph is morally reusable but does not fall + into the limited support that Dynamo has today. Current limitations: + - The subgraph must not have side effects. + - No sourceful variable accessed by the subgraph may have been + mutated, because guards are snapshotted on source values at trace + time — if the underlying object changed since then, the cached + guards would silently evaluate against stale values. + - Output must be a single tensor, or a tuple/list of plain tensors. + - All flattened inputs must be one of: tensor, symnode, constant, + unspecialized NN module — for sourceless or other input types we + rely on the treespec and tags for structural matching, so only + types with well-defined comparison semantics are supported. + + When ``has_reuse_hash_fn`` is True, side-effect and mutation checks are + skipped because the hash key replaces guards — there are no guards to + go stale from mutations. + """ + if not has_reuse_hash_fn: + if tracing_info.side_effect_stack is not None: + stack_msg = "\n" + "".join( + traceback.format_list(tracing_info.side_effect_stack) + ) + hc_log.debug( + "subgraph_reuse: not eligible -- subgraph has side effects%s", + stack_msg, + ) + return False + + if traced_sources and has_mutated_vars(tx, traced_sources): + return False + + if isinstance(body_r, TensorVariable): + pass + elif isinstance(body_r, (TupleVariable, ListVariable)): + non_tensor = [ + type(item).__name__ + for item in body_r.items + if not isinstance(item, TensorVariable) + ] + if non_tensor: + hc_log.debug( + "subgraph_reuse: not eligible -- output contains non-tensor types: %s", + non_tensor, + ) + return False + else: + hc_log.debug( + "subgraph_reuse: not eligible -- output type %s is not tensor or tuple/list", + type(body_r).__name__, + ) + return False + + if fingerprint.has_unknown: + hc_log.debug( + "subgraph_reuse: not eligible -- unsupported input VT types", + ) + return False + + return True + + +def build_reuse_condition( + tx: "InstructionTranslator", + fingerprint: InputFingerprint, + traced_sources: OrderedSet[Source], +) -> InvokeSubgraphReuseCondition | None: + """Build an InvokeSubgraphReuseCondition from a traced subgraph. + + A reuse condition is a mix of two kinds of checks: + + 1. **Input tag checks** (from flat_vts): For each flattened leaf VT, + we record its tag (_VtTag.TENSOR/SYMNODE/CONSTANT/MODULE) and + metadata (e.g. tensor shape/stride/dtype/device/requires_grad). + At lookup time, the treespec ensures structural equivalence, and + then we compare tags and metadata leaf-by-leaf. + + 2. **Guard checks** (from traced_sources): During the subgraph trace, + every source accessed via VariableBuilder is recorded. We look up + all guards installed on those sources (and on the arg_sources) to + build the set of guards that must be re-evaluated on cache hit. + This is more robust than guard diffing because it catches guards + that were already installed before the subgraph trace began. + + Raise if any guard type is unsupported, as a feedback for compiler + developers to support that guard type. + """ + from torch._guards import InvokeSubgraphReuseCondition + + input_checks: list[tuple[InputTag, object]] = [] + for tag, vt in fingerprint.flat_vts: + if tag == InputTag.TENSOR: + assert isinstance(vt, TensorVariable) + example = vt.proxy.node.meta.get("example_value", None) + if example is None: + hc_log.debug( + "subgraph_reuse: cannot build condition -- tensor input has no example_value" + ) + return None + input_checks.append((InputTag.TENSOR, extract_tensor_metadata(example))) + elif tag == InputTag.SYMNODE: + assert isinstance(vt, SymNodeVariable) + # Store the SymInt/SymFloat/SymBool object itself. Two accesses to + # the same symbolic dimension (e.g. x.shape[0] twice) produce the + # same Python object, so identity comparison in is_reusable is + # correct and avoids false matches between distinct symbols. + input_checks.append((InputTag.SYMNODE, vt.sym_num)) + elif tag == InputTag.CONSTANT: + assert isinstance(vt, ConstantVariable) + input_checks.append((InputTag.CONSTANT, vt.value)) + elif tag == InputTag.MODULE: + input_checks.append((InputTag.MODULE, None)) + else: + raise AssertionError( + f"Unexpected input tag '{tag}' for {type(vt).__name__} -- " + f"is_reuse_eligible should have rejected this" + ) + + # Collect all guards for sources accessed during the subgraph trace + # and for the flattened arg sources. + all_sources = set(traced_sources) + all_sources.update(s for s in fingerprint.arg_sources if s is not None) + all_relevant_guards: set[Guard] = set() + for source in all_sources: + all_relevant_guards.update(tx.output.guards.get_guards_for_source(source)) + + guard_tuples: list[tuple[Source, GuardCheckSpec, object, Guard]] = [] + for guard in all_relevant_guards: + source = guard.originating_source + type_str = guard.create_fn_name() + handler = GUARD_VALUE_DISPATCH.get(type_str) + + if handler is SKIP_GUARD: + continue + + if handler is None or isinstance(handler, UnsupportedGuardCheckSpec): + raise RuntimeError( + f"subgraph_reuse: unsupported guard type '{type_str}' on source '{source.name}'" + ) + + try: + value = tx.output.resolve_source_value(source) + except Exception: + raise RuntimeError( + f"subgraph_reuse: failed to resolve source '{source.name}' for {type_str} guard" + ) from None + + # TODO(anijain2305): vLLM workaround -- skip CONSTANT_MATCH on + # strings. Re-evaluate once vLLM migrates off this pattern. + # if type_str == "CONSTANT_MATCH" and isinstance(value, str): + # continue + + handler = cast(GuardCheckSpec, handler) + expected = handler.get_metadata_fn(guard, value) + guard_tuples.append((source, handler, expected, guard)) + + hc_log.debug("Number of guards %s", len(guard_tuples)) + + return InvokeSubgraphReuseCondition( + input_checks=input_checks, + guards=guard_tuples, + treespec=fingerprint.treespec, + traced_sources=traced_sources, + ) + + +def build_source_replacement( + old_arg_sources: list[Source | None], + new_arg_sources: list[Source | None], +) -> dict[Source, Source]: + """Map old arg sources to new arg sources for remapping captured variable sources.""" + return { + old: new + for old, new in zip(old_arg_sources, new_arg_sources) + if old is not None and new is not None and old != new + } + + +def is_reusable( + tx: "InstructionTranslator", + condition: "InvokeSubgraphReuseCondition", + fingerprint: InputFingerprint, + cached_entry: InvokeSubgraphReuseEntry, +) -> bool: + """Check if a cached subgraph can be reused for the current call. + + Three-phase check: + (1) Verify that intermediates (tensor metadata, symnode types, constant + values) match the cached input_checks — these are lightweight + structural comparisons that don't require source resolution. + (2) Check for mutations on the remapped traced_sources — if any source + the subgraph read has been mutated since the original trace, the + cached guards would evaluate against stale values. + (3) Build a source replacement mapping (old sources → new sources) and + re-evaluate the snapshotted guards under the new sources. + """ + # Structural check: treespec must match first. + if condition.treespec is not None and fingerprint.treespec != condition.treespec: + hc_log.debug( + "subgraph_reuse: reuse failed -- treespec mismatch", + ) + return False + + # Input count, tags, and metadata must match. + # Tensor metadata (shape, stride, dtype, device, requires_grad) is checked + # here because TENSOR_MATCH guards for subgraph inputs typically already + # exist in the outer graph before tracing and thus won't appear in the + # guard delta. + if len(condition.input_checks) != len(fingerprint.flat_vts): + hc_log.debug( + "subgraph_reuse: reuse failed -- input count mismatch: cached %d vs current %d", + len(condition.input_checks), + len(fingerprint.flat_vts), + ) + return False + + for i, ((cached_tag, cached_val), (cur_tag, cur_vt)) in enumerate( + zip(condition.input_checks, fingerprint.flat_vts) + ): + if cached_tag != cur_tag: + hc_log.debug( + "subgraph_reuse: reuse failed -- input %d tag mismatch: cached '%s' vs current '%s'", + i, + cached_tag, + cur_tag, + ) + return False + if cached_tag == InputTag.TENSOR: + assert isinstance(cur_vt, TensorVariable) + example = cur_vt.proxy.node.meta.get("example_value", None) + if example is None: + hc_log.debug( + "subgraph_reuse: reuse failed -- input %d tensor has no example_value", + i, + ) + return False + cur_meta = extract_tensor_metadata(example) + if cur_meta != cached_val: + hc_log.debug( + "subgraph_reuse: reuse failed -- input %d tensor metadata mismatch", + i, + ) + return False + elif cached_tag == InputTag.SYMNODE: + assert isinstance(cur_vt, SymNodeVariable) + if cur_vt.sym_num is not cached_val: + return False + elif cached_tag == InputTag.CONSTANT: + assert isinstance(cur_vt, ConstantVariable) + if cur_vt.value != cached_val: + # If both the cached and current arg have sources, source + # replacement in stamp_out will resolve the correct value. + cached_src = ( + cached_entry.arg_sources[i] + if i < len(cached_entry.arg_sources) + else None + ) + new_src = ( + fingerprint.arg_sources[i] + if i < len(fingerprint.arg_sources) + else None + ) + if cached_src is None or new_src is None: + return False + + source_replacement = build_source_replacement( + cached_entry.arg_sources, fingerprint.arg_sources + ) + + # Parameterized source - this function gives you new sources parameterized + # on the arg_sources. For example, if the input to the nested compile region + # is a nn Module layer with source `layers[0]`, then old source + # `layers[0].weight` gets remapped to `layers[1].weight`. This + # parameterization is central in getting the new sources and then running + # guards on them. + def replacement_fn(s: Source) -> Source: + return source_replacement.get(s, s) + + # Check for mutations on remapped traced_sources. + if source_replacement: + remapped = OrderedSet(s.clone(replacement_fn) for s in condition.traced_sources) + else: + remapped = condition.traced_sources + if has_mutated_vars(tx, remapped): + return False + + # If no sources changed, all guards were already checked during the + # original trace and will trivially pass again. + if not source_replacement: + return True + + # Shared resolution context so source.get_value memoizes intermediate + # results (e.g. common base sources) across all guards in this check. + resolve_globals: dict[str, Any] = { + "G": tx.output.root_tx.f_globals, + "L": tx.output.root_tx.f_locals, + } + resolve_locals: dict[str, Any] = {} + resolve_cache: dict[Source, Any] = {} + + for source, handler, expected, guard in condition.guards: + new_source = source.clone(replacement_fn) + # Source unchanged after replacement — guard already passed during + # the original trace, skip re-evaluation. + if new_source == source: + continue + + try: + value = new_source.get_value(resolve_globals, resolve_locals, resolve_cache) + except Exception: + hc_log.debug( + "subgraph_reuse: reuse failed -- cannot resolve source\n" + " guard type: %s\n" + " guard source: %s\n" + " guard source name: %s\n" + " user stack:\n%s", + guard.create_fn_name(), + new_source, + new_source.name, + "".join(guard.user_stack.format()) + if guard.user_stack + else "", + ) + return False + + if not handler.eval_fn(value, expected): + hc_log.debug( + "subgraph_reuse: reuse failed --\n" + " guard type: %s\n" + " guard source: %s\n" + " guard source name: %s\n" + " expected: %s\n" + " got: %s\n" + " user stack:\n%s", + guard.create_fn_name(), + new_source, + new_source.name, + expected, + value, + "".join(guard.user_stack.format()) + if guard.user_stack + else "", + ) + return False + + return True + + +def has_reuse_entries( + tx: "InstructionTranslator", + fn_var: Any, +) -> bool: + """Cheap check: does the cache have any entries for this function?""" + from torch._guards import InvokeSubgraphCache + + invoke_subgraph_cache = tx.output.tracing_context.hop_dispatch_set_cache.get_cache( + torch._higher_order_ops.invoke_subgraph + ) + if not isinstance(invoke_subgraph_cache, InvokeSubgraphCache): + return False + fn_code = get_fn_code(fn_var) + return fn_code is not None and fn_code in invoke_subgraph_cache.subgraph_reuse_cache + + +def find_reuse_match( + tx: "InstructionTranslator", + fn_var: Any, + fingerprint: InputFingerprint, +) -> InvokeSubgraphReuseEntry | None: + from torch._guards import InvokeSubgraphCache + + invoke_subgraph_cache = tx.output.tracing_context.hop_dispatch_set_cache.get_cache( + torch._higher_order_ops.invoke_subgraph + ) + if not isinstance(invoke_subgraph_cache, InvokeSubgraphCache): + return None + fn_code = get_fn_code(fn_var) + if fn_code is None: + return None + + # this evaluator function is called one by one for all the invoke subgraph + # reuse entries - the one that evaluates to True is stamped out in the + # graph. + def evaluator( + cond: "InvokeSubgraphReuseCondition", entry: InvokeSubgraphReuseEntry + ) -> bool: + return is_reusable(tx, cond, fingerprint, entry) + + return invoke_subgraph_cache.find_reuse_entry(fn_code, evaluator) + + +def save_reuse_entry( + tx: "InstructionTranslator", + fn_var: Any, + fingerprint: InputFingerprint, + body_name: str, + body_gmod: torch.fx.GraphModule, + config: NestedCompileRegionOptions | None, + p_args: tuple[Any, ...], + body_r: VariableTracker, + example_value: Any, + max_reuse_entries: int = 8, + condition: "InvokeSubgraphReuseCondition | None" = None, + hash_key: int | None = None, +) -> None: + """Save a traced subgraph into the reuse cache for future cache hits. + + Builds an InvokeSubgraphReuseEntry with the freevar mapping (how each + lifted arg maps back to user inputs or captured variables), output + metadata, and arg sources. On a future cache hit, stamp_out_subgraph + uses this entry to emit a new invoke_subgraph call without re-tracing. + + Exactly one of ``condition`` or ``hash_key`` must be provided. + ``condition`` stores the entry in the guard-based cache (linear scan); + ``hash_key`` stores it in the hash-key cache (O(1) lookup). + """ + from torch._guards import InvokeSubgraphCache + + assert (condition is None) != (hash_key is None), ( + "Exactly one of condition or hash_key must be provided" + ) + + invoke_subgraph_cache = tx.output.tracing_context.hop_dispatch_set_cache.get_cache( + torch._higher_order_ops.invoke_subgraph + ) + if not isinstance(invoke_subgraph_cache, InvokeSubgraphCache): + return + + fn_code = get_fn_code(fn_var) + if fn_code is None: + return + + subgraph_input_mapping = build_subgraph_input_mapping( + tx, p_args, fingerprint.flat_vts + ) + single_tensor_output = isinstance(body_r, TensorVariable) + + # Count user-visible outputs from body_r. The graph may have additional + # outputs from side-effect intermediates that stamp_out_subgraph must + # not include when reconstructing the user-visible return value. + user_output_vts: list[VariableTracker] = [] + VariableTracker.visit( + lambda vt: user_output_vts.append(vt) + if vt.is_tensor() or isinstance(vt, SymNodeVariable) + else None, + body_r, + ) + num_user_outputs = len(user_output_vts) + + # Cache output tensor metadata so we can construct fresh FakeTensors on + # cache hit without re-running the subgraph. This is safe because + # invoke_subgraph does not support aliasing between inputs and outputs + # (speculate_subgraph will fail if that happens). + # example_value may contain SymInts (e.g. shape values for backward); + # only record metadata for actual tensors. + output_metadata = [ + (t.shape, t.stride(), t.dtype, t.device, t.requires_grad) + for t in example_value + if isinstance(t, torch.Tensor) + ] + + entry = InvokeSubgraphReuseEntry( + body_name=body_name, + body_gmod=body_gmod, + config=config, + subgraph_input_mapping=subgraph_input_mapping, + single_tensor_output=single_tensor_output, + output_metadata=output_metadata, + # Record arg sources so that on cache hit we can build a + # source replacement mapping (old sources → new sources) to + # rewrite captured variable sources for the current invocation. + arg_sources=fingerprint.arg_sources, + num_user_outputs=num_user_outputs, + ) + if condition is not None: + invoke_subgraph_cache.add_reuse_entry( + fn_code, condition, entry, max_reuse_entries + ) + else: + assert hash_key is not None + invoke_subgraph_cache.add_reuse_entry_by_key( + fn_code, hash_key, entry, max_reuse_entries + ) + + +def trace_reuse_hash_fn( + tx: "InstructionTranslator", + reuse_hash_fn: Any, + fn_args_vt: "Sequence[VariableTracker]", + kwargs: dict[str, VariableTracker], +) -> int: + """Trace the user's reuse_hash_fn to get a constant integer hash key. + + Guards installed during the hash function tracing are skipped — the hash + key itself is the reuse condition, not the guards. + """ + from torch._dynamo.exc import Unsupported + from torch._dynamo.utils import _make_inlined + + with tx.output.tracing_context.guards_context.skip_guard_install(): + try: + result = _make_inlined(tx, reuse_hash_fn)(*fn_args_vt, **kwargs) + except Unsupported as e: + raise RuntimeError( + f"reuse_hash_fn must be fully traceable without graph breaks. Got: {e}" + ) from e + + if not isinstance(result, ConstantVariable) or not isinstance(result.value, int): + raise RuntimeError( + f"reuse_hash_fn must return a constant integer, got {result}" + ) + + return result.value + + +def find_reuse_entry_by_key( + tx: "InstructionTranslator", + fn_var: Any, + hash_key: int, +) -> InvokeSubgraphReuseEntry | None: + from torch._guards import InvokeSubgraphCache + + invoke_subgraph_cache = tx.output.tracing_context.hop_dispatch_set_cache.get_cache( + torch._higher_order_ops.invoke_subgraph + ) + if not isinstance(invoke_subgraph_cache, InvokeSubgraphCache): + return None + fn_code = get_fn_code(fn_var) + if fn_code is None: + return None + return invoke_subgraph_cache.find_reuse_entry_by_key(fn_code, hash_key) + + +def stamp_out_subgraph( + tx: "InstructionTranslator", + fingerprint: InputFingerprint, + cached: InvokeSubgraphReuseEntry, +) -> VariableTracker: + """Emit a new invoke_subgraph call by stamping out a cached subgraph. + + Sources in the cached entry are parameterized: they refer to the original + call's sources and must be rewritten to the current call's sources via + source replacement before we can look up or create the corresponding + graph placeholders. + """ + from torch._dynamo.variables.builder import VariableBuilder + from torch._dynamo.variables.higher_order_ops import add_call_function, make_attr + + flat_proxies = get_flat_proxies(fingerprint) + new_arg_sources = fingerprint.arg_sources + + source_replacement = build_source_replacement(cached.arg_sources, new_arg_sources) + + new_lifted_args = [] + # Shared resolution context so get_value memoizes intermediate results + # (e.g. L['self'].layers) across all freevars in this stamp-out. + resolve_globals: dict[str, Any] = { + "G": tx.output.root_tx.f_globals, + "L": tx.output.root_tx.f_locals, + } + resolve_locals: dict[str, Any] = {} + resolve_cache: dict[Source, Any] = {} + + # Find the args for the about-to-be-inserted invoke_subgraph call. + for subgraph_input in cached.subgraph_input_mapping: + if isinstance(subgraph_input, LiftedUserArg): + new_lifted_args.append(flat_proxies[subgraph_input.index]) + elif isinstance(subgraph_input, LiftedBoundSymbol): + from torch._dynamo.output_graph import LazyProxy + + proxy = tx.output.current_tracer.bound_symbols[subgraph_input.expr] + if isinstance(proxy, LazyProxy): + proxy = proxy() + tx.output.current_tracer.bound_symbols[subgraph_input.expr] = proxy + new_lifted_args.append(proxy) + elif isinstance(subgraph_input, LiftedSyntheticObject): + ctor_args = subgraph_input.ctor_args + ctor_arg_sources = subgraph_input.ctor_arg_sources + if ctor_arg_sources and source_replacement: + new_ctor_args = [] + new_ctor_arg_sources = [] + for val, arg_src in zip(ctor_args, ctor_arg_sources): + if arg_src is not None: + new_src = arg_src.clone(lambda s: source_replacement.get(s, s)) + val = new_src.get_value( + resolve_globals, resolve_locals, resolve_cache + ) + arg_src = new_src + new_ctor_args.append(val) + new_ctor_arg_sources.append(arg_src) + ctor_args = tuple(new_ctor_args) + ctor_arg_sources = tuple(new_ctor_arg_sources) + vt = tx.output.synthetic_graph_input( + subgraph_input.ctor_fn, ctor_args, ctor_arg_sources + ) + new_lifted_args.append(vt.as_proxy()) + elif isinstance(subgraph_input, LiftedCapturedSource): + new_source = subgraph_input.source + if source_replacement: + new_source = new_source.clone(lambda s: source_replacement.get(s, s)) + # VariableBuilder deduplicates via input_source_to_var, + # so this reuses existing graph placeholders automatically. + value = new_source.get_value(resolve_globals, resolve_locals, resolve_cache) + vt = VariableBuilder(tx, new_source)(value) + new_lifted_args.append(vt.as_proxy()) + + # Generate fake tensor outputs + assert tx.fake_mode is not None + with tx.fake_mode: + example_value = tuple( + torch.empty_strided( + shape, + stride, + dtype=dtype, + device=device, + requires_grad=req_grad, + ) + for shape, stride, dtype, device, req_grad in cached.output_metadata + ) + + # Install the invoke_subgraph call + body_node = make_attr(tx, cached.body_name) + p_args = (body_node, cached.body_name, *new_lifted_args) + flat_variable = add_call_function( + tx, + torch._higher_order_ops.invoke_subgraph, + tuple(p_args), + {}, + example_value, + cached.config, + ) + + # Return only the user-visible outputs. The graph may have extra + # intermediate outputs from side effects (allow_side_effects=True) + # that should not be part of the user-facing return value. + if cached.single_tensor_output: + items = flat_variable.items # pyrefly: ignore[missing-attribute] + assert isinstance(items[0], TensorVariable), ( + f"Expected tensor output but got {type(items[0]).__name__}" + ) + return items[0] + + items = flat_variable.items # pyrefly: ignore[missing-attribute] + n = cached.num_user_outputs + if n > 0 and n < len(items): + from .builder import SourcelessBuilder + + return SourcelessBuilder.create(tx, tuple(items[:n])) + return flat_variable + + +def build_subgraph_input_mapping( + tx: "InstructionTranslator", + p_args: tuple[Any, ...], + flat_vts: list[tuple[InputTag, VariableTracker]], +) -> list[LiftedArgOrigin]: + """Build a mapping that records the origin of each lifted arg for a subgraph. + + On a cache hit, we stamp out a new invoke_subgraph call and need to + reconstruct its argument list in the correct order. Each lifted arg + (p_args[2:], skipping body_node and body_name) comes from one of: + + - LiftedUserArg: a user argument (intermediate activation or explicit input) + - LiftedCapturedSource: a captured variable (e.g. a weight or parameter) + - LiftedSyntheticObject: a TorchScriptObject with a SyntheticLocalSource + - LiftedBoundSymbol: a SymInt already bound as a graph input + """ + proxy_node_to_idx: dict[torch.fx.Node, int] = {} + idx = 0 + for tag, vt in flat_vts: + if tag in (InputTag.TENSOR, InputTag.SYMNODE): + node = vt.as_proxy().node + if node not in proxy_node_to_idx: + proxy_node_to_idx[node] = idx + idx += 1 + + subgraph_input_mapping: list[LiftedArgOrigin] = [] + for outer_proxy in p_args[2:]: + matched_idx = proxy_node_to_idx.get(outer_proxy.node, -1) + if matched_idx >= 0: + subgraph_input_mapping.append(LiftedUserArg(matched_idx)) + else: + grapharg = outer_proxy.node.meta.get("grapharg", None) + source = grapharg.source if grapharg is not None else None + # SymInt freevars must reuse the existing symbolic proxy rather + # than resolving via source.get_value() (which returns the + # concrete int). They appear as either: + # - placeholder nodes with grapharg.example being a SymInt + # - call_function nodes (e.g. sym_size_int) with no grapharg + # In both cases, store the sympy expression and look it up in + # bound_symbols during stamp-out. + example = ( + grapharg.example + if grapharg is not None + else outer_proxy.node.meta.get("example_value", None) + ) + if isinstance(example, torch.SymInt): + subgraph_input_mapping.append(LiftedBoundSymbol(example.node.expr)) + continue + assert source is not None, ( + f"Freevar has no source: node.op={outer_proxy.node.op} " + f"node.name={outer_proxy.node.name} -- this likely means a " + f"function argument was not included in the proxy matching" + ) + if isinstance(source, SyntheticLocalSource): + ctor_info = tx.output.synthetic_source_ctor_info.get(source) + if ctor_info is not None: + ctor_fn, ctor_args, ctor_arg_sources = ctor_info + subgraph_input_mapping.append( + LiftedSyntheticObject(ctor_fn, ctor_args, ctor_arg_sources) + ) + continue + subgraph_input_mapping.append(LiftedCapturedSource(source)) + return subgraph_input_mapping + + +class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable): + _HOP_NAME = "torch.ops.higher_order.invoke_subgraph" + _ALLOW_FALLBACK_TO_EAGER = False + supports_input_mutation = True + supports_aliasing = False + allow_side_effects = True + # invoke_subgraph is NOT desugared in AOTAutograd, so the HOP input/output + # shouldn't alias. For checkpoint HOP, we inline it so we don't need + # alias analysis as functionalization would just work on the flat graph. + filter_aliased_intermediates = True + + # pyrefly: ignore[bad-override] + def install_subgraph_in_output_graph( + self, + tx: "InstructionTranslator", + fn_vt: VariableTracker, + fn_args_vt: "Sequence[VariableTracker]", + kwargs: dict[str, VariableTracker], + body_gmod: GraphModule, + attr_name: str, + ) -> str: + # Check if the subgraph from speculate_subgraph (body_gmod) and the fake + # inputs have already been seen before. If yes, the subgraph is already + # installed in the output graph and we can just access the subgraph + # using the saved attr name. + + if not isinstance(fn_vt, (UnspecializedNNModuleVariable, UserFunctionVariable)): + unimplemented( + gb_type="Encountered non user function variable during invoke_subgraph HOP tracing", + context=str(fn_vt), + explanation="invoke_subgraph does not support non user function variable", + hints=[*graph_break_hints.SUPPORTABLE], + ) + + invoke_subgraph_cache = ( + tx.output.tracing_context.hop_dispatch_set_cache.get_cache( + torch._higher_order_ops.invoke_subgraph + ) + ) + + if isinstance(fn_vt, UserFunctionVariable): + fn_code = fn_vt.get_function().__code__ + fn_name = fn_vt.get_function().__name__ + else: + assert isinstance(fn_vt, UnspecializedNNModuleVariable) + fn_code = fn_vt.value.forward.__func__.__code__ # type: ignore[attr-defined] + fn_name = fn_vt.value.forward.__name__ # type: ignore[attr-defined] + # pyrefly: ignore [implicit-any] + previously_installed_submodules = [] + if invoke_subgraph_cache: + previously_installed_submodules = ( + invoke_subgraph_cache.get_dynamo_installed_submodules(fn_code) + ) + current_mod = body_gmod + # NB - reverse is more likely to cause a hit sooner because first + # graph can have requires_grad=False for a few inputs + for submodule_name in reversed(previously_installed_submodules): + assert submodule_name in tx.output.nn_modules + previous_mod = tx.output.nn_modules[submodule_name] + assert tx.fake_mode + from torch._dynamo.variables.higher_order_ops import ( + are_same_graph_modules, + ) + + if are_same_graph_modules( + fn_name, previous_mod, current_mod, tx.fake_mode + ): + return submodule_name + + body_name = super().install_subgraph_in_output_graph( + tx, fn_vt, fn_args_vt, kwargs, body_gmod, "subgraph" + ) + hc_log.debug( + "%s: Installing subgraph with identifier '%s', bringing total count for '%s' function to %s", + fn_name, + body_name, + fn_name, + len(previously_installed_submodules) + 1, + ) + if invoke_subgraph_cache: + invoke_subgraph_cache.add_dynamo_installed_submodule(fn_code, body_name) + + return body_name + + def _call_function( + self, + tx: "InstructionTranslator", + args: "Sequence[VariableTracker]", + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from torch._dynamo.utils import dynamo_timed + from torch._dynamo.variables.higher_order_ops import ( + _call_function_with_auto_output_flattening, + ) + + fn_var = args[0] + fn_args_vt = args[1:] + + config = None + max_reuse_entries = 8 + reuse_hash_fn = None + if hasattr(fn_var, "get_function"): + try: + fn = fn_var.get_function() + config = getattr(fn, "__marked_compile_region_config__", None) + max_reuse_entries = getattr( + fn, "__marked_compile_region_max_reuse_entries__", 8 + ) + reuse_hash_fn = getattr( + fn, "__marked_compile_region_reuse_hash_fn__", None + ) + except Exception: + log.warning( + "Failed to extract nested_compile_region() config from InvokeSubgraphHigherOrderVariable. ", + exc_info=True, + ) + raise + + # TODO (anijain2305) - Collect issues why this does not work for export, + # and enable if request arises. + reuse = not tx.output.export + + # User-provided reuse_hash_fn path: hash key determines cache lookup. + if reuse and reuse_hash_fn is not None: + with dynamo_timed("invoke_subgraph_reuse_hash_fn"): + hash_key = trace_reuse_hash_fn(tx, reuse_hash_fn, fn_args_vt, kwargs) + + cached = find_reuse_entry_by_key(tx, fn_var, hash_key) + if cached is not None: + hc_log.debug( + "subgraph_reuse: hash key %d hit for '%s', reusing subgraph '%s'", + hash_key, + fn_var, + cached.body_name, + ) + fingerprint = build_input_fingerprint(tx, fn_args_vt, kwargs) + with dynamo_timed("invoke_subgraph_reuse_stamp_out"): + return stamp_out_subgraph(tx, fingerprint, cached) + + # Automatic reuse lookup (guard-based): check fn_code first (cheap) to + # avoid the expensive pytree flatten in build_input_fingerprint on + # the first call when there's nothing in the cache yet. + elif reuse and has_reuse_entries(tx, fn_var): + with dynamo_timed("invoke_subgraph_reuse_lookup"): + fingerprint = build_input_fingerprint(tx, fn_args_vt, kwargs) + match = find_reuse_match( + tx, + fn_var, + fingerprint, + ) + if match is not None: + hc_log.debug( + "subgraph_reuse: cache hit for '%s', reusing subgraph '%s'", + fn_var, + match.body_name, + ) + with dynamo_timed("invoke_subgraph_reuse_stamp_out"): + return stamp_out_subgraph(tx, fingerprint, match) + + assert self._HOP_NAME is not None + with dynamo_timed("invoke_subgraph_trace"): + ( + p_args, + p_kwargs, + example_value, + body_r, + body_gmod, + body_name, + body_graph_output_vts, + tracing_info, + ) = self.create_wrapped_node(tx, fn_var, fn_args_vt, kwargs, self._HOP_NAME) + + if len(p_kwargs) > 0: + unimplemented( + gb_type="invoke_subgraph: kwargs unexpected", + context=f"args: {args}, kwargs: {kwargs}", + explanation="kwargs should have been flattened into lifted args.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + # Store config in the body graph module meta + if isinstance(config, NestedCompileRegionOptions): + body_gmod.meta["nested_region_config"] = config + + p_args = ( + p_args[0], + body_name, + *p_args[1:], + ) + + # Subgraph reuse: save entry for future cache hits + if reuse: + fingerprint = build_input_fingerprint(tx, fn_args_vt, kwargs) + if reuse_hash_fn is not None: + traced_sources = tracing_info.traced_sources + if not is_reuse_eligible( + tx, + body_r, + fingerprint, + tracing_info, + traced_sources, + has_reuse_hash_fn=True, + ): + raise RuntimeError( + "reuse_hash_fn was provided but the subgraph is not " + "eligible for reuse. Check the logs with " + "TORCH_LOGS='+hierarchical_compile' for details." + ) + save_reuse_entry( + tx, + fn_var, + fingerprint, + body_name, + body_gmod, + config, + p_args, + body_r, + example_value, + max_reuse_entries, + hash_key=hash_key, # type: ignore[possibly-undefined] + ) + else: + traced_sources = tracing_info.traced_sources + if is_reuse_eligible( + tx, body_r, fingerprint, tracing_info, traced_sources + ): + condition = build_reuse_condition( + tx, + fingerprint, + traced_sources, + ) + if condition is not None: + save_reuse_entry( + tx, + fn_var, + fingerprint, + body_name, + body_gmod, + config, + p_args, + body_r, + example_value, + max_reuse_entries, + condition=condition, + ) + + return _call_function_with_auto_output_flattening( # type: ignore[return-value] + tx, + torch._higher_order_ops.invoke_subgraph, + tuple(p_args), + p_kwargs, + example_value, + body_r, + body_graph_output_vts, + config=config, + ) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 10f01ca87f5e7..e56f470685ae9 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -55,6 +55,9 @@ def __repr__(self) -> str: def as_python_constant(self) -> Any: return self.value + def get_real_python_backed_value(self) -> Any: + return self.value + def call_function( self, tx: "InstructionTranslator", @@ -195,10 +198,20 @@ def keyfunc(x: VariableTracker) -> Any: return tx.inline_user_function_return( VariableTracker.build(tx, polyfills.repeat), args, kwargs ) - elif self.value is itertools.count: - return variables.CountIteratorVariable( - *args, mutation_type=ValueMutationNew() - ) + elif self.value is itertools.count and not kwargs: + if len(args) == 0: + return variables.CountIteratorVariable(mutation_type=ValueMutationNew()) + if len(args) == 1: + return variables.CountIteratorVariable( + item=args[0], mutation_type=ValueMutationNew() + ) + if len(args) == 2: + return variables.CountIteratorVariable( + item=args[0], + step=args[1], + mutation_type=ValueMutationNew(), + ) + return super().call_function(tx, args, kwargs) elif ( self.value is itertools.permutations and (len(args) == 1 or (len(args) == 2 and args[1].is_python_constant())) @@ -264,7 +277,7 @@ def call_obj_hasattr( self, tx: "InstructionTranslator", name: str ) -> "ConstantVariable": if name == "__iter__" or name == "__next__": - return variables.CONSTANT_VARIABLE_TRUE + return variables.ConstantVariable.create(True) return super().call_obj_hasattr(tx, name) def call_method( @@ -318,6 +331,9 @@ def __init__(self, item: VariableTracker, **kwargs: Any) -> None: super().__init__(**kwargs) self.item = item + def python_type(self) -> type: + return itertools.repeat + # Repeat needs no mutation, clone self def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: return self.item @@ -336,10 +352,21 @@ def reconstruct(self, codegen: "PyCodegen") -> None: class CountIteratorVariable(IteratorVariable): + # advance_count tracks how many next() calls were made during tracing, + # used by side_effects.py to replay them on the real iterator post-execution. + _nonvar_fields = { + "advance_count", + *IteratorVariable._nonvar_fields, + } + + def python_type(self) -> type: + return itertools.count + def __init__( self, item: int | VariableTracker = 0, step: int | VariableTracker = 1, + advance_count: int = 0, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -349,12 +376,14 @@ def __init__( step = ConstantVariable.create(step) self.item = item self.step = step + self.advance_count = advance_count def next_variable(self, tx: "InstructionTranslator") -> VariableTracker: assert self.is_mutable() old_item = self.item tx.output.side_effects.mutation(self) self.item = self.item.call_method(tx, "__add__", [self.step], {}) + self.advance_count += 1 return old_item def reconstruct(self, codegen: "PyCodegen") -> None: @@ -376,6 +405,9 @@ class ZipVariable(IteratorVariable): Represents zip(*iterables) """ + # PyZip_Type: https://github.com/python/cpython/blob/v3.13.0/Python/bltinmodule.c#L3011 + _cpython_type = zip + _nonvar_fields = { "index", "strict", @@ -439,7 +471,7 @@ def get_item( idx: int | None = None try: - for idx, it in enumerate(self.iterables): # noqa:B007 + for idx, it in enumerate(self.iterables): args.append(get_item(it)) except ObservedUserStopIteration: if self.strict: @@ -497,6 +529,9 @@ class MapVariable(ZipVariable): Represents map(fn, *iterables) """ + # PyMap_Type: https://github.com/python/cpython/blob/v3.13.0/Python/bltinmodule.c#L1484 + _cpython_type = map + def __init__( self, fn: VariableTracker, @@ -544,6 +579,9 @@ class FilterVariable(IteratorVariable): Represents filter(fn, iterable) """ + # PyFilter_Type: https://github.com/python/cpython/blob/v3.13.0/Python/bltinmodule.c#L630 + _cpython_type = filter + _nonvar_fields = { "index", *IteratorVariable._nonvar_fields, diff --git a/torch/_dynamo/variables/lazy.py b/torch/_dynamo/variables/lazy.py index def844adda73a..af023b42403c7 100644 --- a/torch/_dynamo/variables/lazy.py +++ b/torch/_dynamo/variables/lazy.py @@ -6,7 +6,7 @@ from typing import Any, TYPE_CHECKING from ..utils import is_function_or_wrapper -from .base import VariableTracker, VariableTrackerMeta +from .base import SourceLocation, VariableTracker, VariableTrackerMeta if TYPE_CHECKING: @@ -25,6 +25,7 @@ def __init__(self, value: Any, source: Any) -> None: self.value = value self.source = source self.name_hint: str | None = None + self.source_location: SourceLocation | None = None self.vt: VariableTracker | None = None def realize(self) -> None: @@ -47,9 +48,13 @@ def realize(self) -> None: if self.name_hint is not None: self.vt.set_name_hint(self.name_hint) + if self.source_location is not None and self.vt.source_location is None: + self.vt.set_source_location(self.source_location) + del self.value del self.source del self.name_hint + del self.source_location class LazyVariableTracker(VariableTracker, metaclass=VariableTrackerMeta): @@ -137,6 +142,13 @@ def set_name_hint(self, name: str) -> None: else: self._cache.name_hint = name + def set_source_location(self, source_location: SourceLocation) -> None: + self.source_location = source_location + if self.is_realized(): + self._cache.vt.set_source_location(source_location) # type: ignore[union-attr] + else: + self._cache.source_location = source_location + def __str__(self) -> str: variable_info = "LazyVariableTracker(" if self.is_realized(): diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index cda1dda994657..4149d56f45076 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -22,7 +22,7 @@ class that handles its unique behaviors while integrating with Dynamo's import torch import torch.fx -from torch.utils._pytree import GetAttrKey, SequenceKey +from torch.utils._pytree import SequenceKey from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import ( @@ -31,27 +31,24 @@ class that handles its unique behaviors while integrating with Dynamo's create_call_method, create_dup_top, create_instruction, - create_rot_n, ) -from ..exc import raise_observed_exception, unimplemented -from ..source import AttrSource, NamedTupleFieldsSource +from ..exc import raise_observed_exception, raise_type_error, unimplemented +from ..source import AttrSource from ..utils import ( cmp_name_to_op_mapping, cmp_name_to_op_str_mapping, get_fake_value, guard_if_dyn, iter_contains, - namedtuple_fields, odict_values, raise_args_mismatch, range_iterator, set_example_value, ) from .base import AsPythonConstantNotImplementedError, ValueMutationNew, VariableTracker -from .constant import CONSTANT_VARIABLE_FALSE, CONSTANT_VARIABLE_NONE, ConstantVariable +from .constant import ConstantVariable from .functions import UserFunctionVariable from .iter import IteratorVariable -from .user_defined import UserDefinedTupleVariable if TYPE_CHECKING: @@ -113,6 +110,9 @@ def as_proxy(self) -> Any: def getitem_const( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker: + # TODO(follow-up): this assumes the caller (mp_subscript_impl) has already + # run _PyIndex_Check → nb_index_impl. Direct callers bypassing + # mp_subscript_impl will skip that validation. from .tensor import SymNodeVariable if isinstance(arg, SymNodeVariable): @@ -122,8 +122,9 @@ def getitem_const( if isinstance(index, slice): if index.step == 0: - msg = VariableTracker.build(tx, "slice step cannot be zero") - raise_observed_exception(ValueError, tx, args=[msg]) + raise_observed_exception( + ValueError, tx, args=["slice step cannot be zero"] + ) # Set source to None because slicing a list gives a new local return self.clone( items=self.items[index], @@ -135,12 +136,17 @@ def getitem_const( try: return self.items[index] except IndexError: - error_message = VariableTracker.build(tx, "list index out of range") - raise_observed_exception(IndexError, tx, args=[error_message]) + raise_observed_exception( + IndexError, tx, args=["list index out of range"] + ) def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return list(self.items) + def sq_length(self, tx: "InstructionTranslator") -> VariableTracker: + """Sequence length for lists, tuples, and range objects.""" + return VariableTracker.build(tx, len(self.items)) + def call_tree_map_branch( self, tx: "InstructionTranslator", @@ -232,6 +238,33 @@ def call_tree_map_with_path_branch( mutation_type=ValueMutationNew(), ) + def mp_subscript_impl( + self, + tx: "InstructionTranslator", + key: VariableTracker, + ) -> VariableTracker: + # list_subscript: https://github.com/python/cpython/blob/62a6e898e01/Objects/listobject.c#L3689-L3710 + # _PyIndex_Check: https://github.com/python/cpython/blob/62a6e898e01/Include/internal/pycore_abstract.h#L13-L17 + # TODO(follow-up): replace hasattr(key_type, "__index__") with + # has_slot(num_slots, PyNumberSlots.NB_INDEX) for C extension types. + try: + key_type = key.python_type() + except NotImplementedError: + key_type = None + if key_type not in (int, bool, slice): + if key_type is not None and not hasattr(key_type, "__index__"): + container_name = self.python_type_name() + raise_observed_exception( + TypeError, + tx, + args=[ + f"{container_name} indices must be integers or slices, not {key.python_type_name()}" + ], + ) + key = key.nb_index_impl(tx) + + return self.getitem_const(tx, key) + def call_method( self, tx: "InstructionTranslator", @@ -241,39 +274,7 @@ def call_method( ) -> VariableTracker: from .builder import SourcelessBuilder - if name == "__getitem__": - if kwargs or len(args) != 1: - raise_args_mismatch( - tx, - name, - "1 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - - if args[0].is_tensor(): - value = get_fake_value(args[0].as_proxy().node, tx) - if value.constant is not None and value.constant.numel() == 1: - value = VariableTracker.build(tx, value.constant.item()) - else: - unimplemented( - gb_type="Indexing list with non-scalar tensor", - context=f"call_method {self} {name} {args} {kwargs}", - explanation=( - "Attempted to index list-like object with tensor with > 1 element." - ), - hints=[*graph_break_hints.USER_ERROR], - ) - else: - value = args[0] - - if value.python_type() not in (int, slice): - msg = f"indices must be integers or slices, not {value.python_type()}" - raise_observed_exception( - TypeError, tx, args=[SourcelessBuilder.create(tx, msg)] - ) - - return self.getitem_const(tx, value) - elif name == "__contains__": + if name == "__contains__": if kwargs or len(args) != 1: raise_args_mismatch( tx, @@ -304,7 +305,7 @@ def call_method( raise_observed_exception( ValueError, tx, - args=[VariableTracker.build(tx, "tuple.index()")], + args=["tuple.index()"], ) except AsPythonConstantNotImplementedError: return tx.inline_user_function_return( @@ -337,10 +338,13 @@ def call_method( if type(self) is not type(args[0]): tp_name = self.python_type_name() other = args[0].python_type_name() - msg_vt = VariableTracker.build( - tx, f'can only concatenate {tp_name} (not "{other}") to {tp_name}' + raise_observed_exception( + TypeError, + tx, + args=[ + f'can only concatenate {tp_name} (not "{other}") to {tp_name}' + ], ) - raise_observed_exception(TypeError, tx, args=[msg_vt]) if name == "__add__": return type(self)(self.items + args[0].items, source=self.source) # type: ignore[attr-defined] @@ -357,11 +361,13 @@ def call_method( ) if not (args[0].is_python_constant() and args[0].python_type() is int): - msg_vt = VariableTracker.build( + raise_observed_exception( + TypeError, tx, - f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'", + args=[ + f"can't multiply sequence by non-int type of '{args[0].python_type_name()}'" + ], ) - raise_observed_exception(TypeError, tx, args=[msg_vt]) val = args[0].as_python_constant() @@ -400,11 +406,13 @@ def call_method( op_str = cmp_name_to_op_str_mapping[name] left_ty = left.python_type_name() right_ty = right.python_type_name() - msg = VariableTracker.build( + raise_observed_exception( + TypeError, tx, - f"{op_str} not supported between instances of '{left_ty}' and '{right_ty}'", + args=[ + f"{op_str} not supported between instances of '{left_ty}' and '{right_ty}'" + ], ) - raise_observed_exception(TypeError, tx, args=[msg]) return SourcelessBuilder.create(tx, polyfills.list_cmp).call_function( tx, @@ -422,6 +430,9 @@ def call_method( class RangeVariable(BaseListVariable): + # PyRange_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/rangeobject.c#L767 + _cpython_type = range + def __init__(self, items: Sequence[VariableTracker], **kwargs: Any) -> None: items_to_map = items start = variables.ConstantVariable.create(0) @@ -534,8 +545,6 @@ def _get_slice_indices(self, length: int, slice: slice) -> list[int]: return [start, stop, step] def apply_index(self, tx: "InstructionTranslator", index: int) -> VariableTracker: - from .builder import SourcelessBuilder - length = self.range_length() if index < 0: index = length + index @@ -544,7 +553,7 @@ def apply_index(self, tx: "InstructionTranslator", index: int) -> VariableTracke raise_observed_exception( IndexError, tx, - args=[SourcelessBuilder.create(tx, "range object index out of range")], + args=["range object index out of range"], ) return VariableTracker.build(tx, self.start() + (index * self.step())) @@ -576,9 +585,10 @@ def as_python_constant(self) -> range: def getitem_const( self, tx: "InstructionTranslator", arg: VariableTracker ) -> VariableTracker: - from .builder import SourcelessBuilder - - # implementations mimics https://github.com/python/cpython/blob/main/Objects/rangeobject.c + # range_subscript: https://github.com/python/cpython/blob/main/Objects/rangeobject.c + # TODO(follow-up): this assumes the caller (mp_subscript_impl) has already + # run _PyIndex_Check → nb_index_impl. Direct callers bypassing + # mp_subscript_impl will skip that validation. index = arg.as_python_constant() if isinstance(index, slice): @@ -586,10 +596,9 @@ def getitem_const( elif isinstance(index, int): return self.apply_index(tx, index) else: - msg = SourcelessBuilder.create( - tx, "range indices must be integers or slices" + raise_observed_exception( + TypeError, tx, args=["range indices must be integers or slices"] ) - raise_observed_exception(TypeError, tx, args=[msg]) def as_proxy(self) -> range: return self.python_type()(*self._as_proxy()) @@ -599,6 +608,13 @@ def unpack_var_sequence( ) -> list[VariableTracker]: return [variables.ConstantVariable.create(x) for x in self.as_python_constant()] + def sq_length(self, tx: "InstructionTranslator") -> VariableTracker: + """Sequence length for range objects.""" + length = self.range_length() + if length > sys.maxsize: + raise_observed_exception(OverflowError, tx) + return VariableTracker.build(tx, length) + def reconstruct(self, codegen: "PyCodegen") -> None: assert "range" not in codegen.tx.f_globals codegen.add_push_null( @@ -647,6 +663,29 @@ def range_count(self, x: VariableTracker) -> int: return int(re) return 0 + def mp_subscript_impl( + self, + tx: "InstructionTranslator", + key: VariableTracker, + ) -> VariableTracker: + # range_subscript: https://github.com/python/cpython/blob/62a6e898e01/Objects/rangeobject.c#L729-L748 + # CPython: range_subscript checks _PyIndex_Check → PyNumber_Index for non-slice keys + try: + key_type = key.python_type() + except NotImplementedError: + key_type = None + if key_type not in (int, bool, slice): + if key_type is not None and not hasattr(key_type, "__index__"): + raise_observed_exception( + TypeError, + tx, + args=[ + f"range indices must be integers or slices, not {key.python_type_name()}" + ], + ) + key = key.nb_index_impl(tx) + return self.getitem_const(tx, key) + def call_method( self, tx: "InstructionTranslator", @@ -665,11 +704,6 @@ def call_method( return RangeIteratorVariable( self.start(), self.stop(), self.step(), self.range_length() ) - elif name == "__len__": - length = self.range_length() - if length > sys.maxsize: - raise_observed_exception(OverflowError, tx) - return VariableTracker.build(tx, self.range_length()) elif name in ("count", "__contains__"): return SourcelessBuilder.create(tx, self.range_count(*args)) elif name == "index": @@ -681,10 +715,8 @@ def call_method( raise_observed_exception( ValueError, tx, - args=[VariableTracker.build(tx, f"{x} is not in range")], + args=[f"{x} is not in range"], ) - elif name == "__getitem__": - return self.getitem_const(tx, *args) elif name in cmp_name_to_op_mapping: other = args[0] pt = other.python_type() @@ -693,7 +725,7 @@ def call_method( raise_observed_exception( TypeError, tx, - args=[VariableTracker.build(tx, msg)], + args=[msg], ) if pt is not range: @@ -762,7 +794,7 @@ def call_method( (arg,) = args tx.output.side_effects.mutation(self) self.items.append(arg) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif name == "extend" and self.is_mutable(): if kwargs or len(args) != 1: raise_args_mismatch( @@ -773,16 +805,15 @@ def call_method( ) if not args[0].has_force_unpack_var_sequence(tx): - msg = VariableTracker.build( - tx, f"{type(args[0])} object is not iterable" + raise_observed_exception( + TypeError, tx, args=[f"{type(args[0])} object is not iterable"] ) - raise_observed_exception(TypeError, tx, args=[msg]) (arg,) = args arg.force_apply_to_var_sequence( tx, lambda item: self.call_method(tx, "append", [item], {}) ) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif name == "insert" and self.is_mutable(): if kwargs or len(args) != 2: raise_args_mismatch( @@ -799,7 +830,7 @@ def call_method( tx.output.side_effects.mutation(self) # type: ignore[arg-type] self.items.insert(const_idx, value) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif name == "pop" and self.is_mutable(): if kwargs or len(args) > 1: raise_args_mismatch( @@ -810,14 +841,14 @@ def call_method( ) if len(self.items) == 0: - msg = VariableTracker.build(tx, "pop from empty list") - raise_observed_exception(IndexError, tx, args=[msg]) + raise_observed_exception(IndexError, tx, args=["pop from empty list"]) if len(args): idx = args[0].as_python_constant() - if idx > len(self.items): - msg = VariableTracker.build(tx, "pop index out of range") - raise_observed_exception(IndexError, tx, args=[msg]) + if idx >= len(self.items): + raise_observed_exception( + IndexError, tx, args=["pop index out of range"] + ) tx.output.side_effects.mutation(self) return self.items.pop(*[a.as_python_constant() for a in args]) elif name == "clear" and self.is_mutable(): @@ -830,7 +861,7 @@ def call_method( ) tx.output.side_effects.mutation(self) self.items.clear() - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif name == "__setitem__" and self.is_mutable() and args: # Realize args[0] to get the concrete type for proper type checking key = args[0].realize() @@ -870,7 +901,7 @@ def call_method( self.items[items_slice] = list(value.items) # type: ignore[attr-defined] else: self.items[key.as_python_constant()] = value - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif name == "__delitem__" and self.is_mutable(): if kwargs or len(args) != 1: raise_args_mismatch( @@ -896,15 +927,12 @@ def call_method( raise_observed_exception( type(exc), tx, - args=[VariableTracker.build(tx, a) for a in exc.args], + args=list(exc.args), ) else: - msg = VariableTracker.build( - tx, - f"list indices must be integers or slices, not {args[0].python_type_name()}", - ) - raise_observed_exception(TypeError, tx, args=[msg]) - return CONSTANT_VARIABLE_NONE + msg = f"list indices must be integers or slices, not {args[0].python_type_name()}" + raise_type_error(tx, msg) + return ConstantVariable.create(None) elif name == "copy": # List copy() doesn't have args and kwargs if args or kwargs: @@ -926,7 +954,7 @@ def call_method( ) self.items.reverse() tx.output.side_effects.mutation(self) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif name == "remove" and self.is_mutable(): if kwargs or len(args) != 1: raise_args_mismatch( @@ -938,12 +966,16 @@ def call_method( idx = self.call_method(tx, "index", args, kwargs) self.call_method(tx, "pop", [idx], {}) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) else: return super().call_method(tx, name, args, kwargs) class ListVariable(CommonListMethodsVariable): + # PyList_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/listobject.c#L3776 + _cpython_type = list + _has_instance_dict = False + def python_type(self) -> type: return list @@ -1002,13 +1034,15 @@ def call_method( tx.output.side_effects.mutation(self) if isinstance(key, SliceVariable): if not value.has_force_unpack_var_sequence(tx): - msg = VariableTracker.build(tx, "can only assign an iterable") - raise_observed_exception(TypeError, tx, args=[msg]) + raise_observed_exception( + TypeError, tx, args=["can only assign an iterable"] + ) key_as_const = key.as_python_constant() if key_as_const.step == 0: - msg = VariableTracker.build(tx, "slice step cannot be zero") - raise_observed_exception(ValueError, tx, args=[msg]) + raise_observed_exception( + ValueError, tx, args=["slice step cannot be zero"] + ) value_unpack = value.force_unpack_var_sequence(tx) try: @@ -1017,7 +1051,7 @@ def call_method( raise_observed_exception( type(exc), tx, - args=[VariableTracker.build(tx, a) for a in exc.args], + args=list(exc.args), ) else: if isinstance(key, SymNodeVariable): @@ -1032,14 +1066,14 @@ def call_method( raise_observed_exception( type(e), tx, args=[VariableTracker.build(tx, a) for a in e.args] ) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) if name == "sort" and self.is_mutable(): if len(args) != 0: raise_args_mismatch(tx, name, "0 args", f"{len(args)} args") - key_fn_var = kwargs.pop("key", CONSTANT_VARIABLE_NONE) + key_fn_var = kwargs.pop("key", ConstantVariable.create(None)) reverse = kwargs.pop( - "reverse", CONSTANT_VARIABLE_FALSE + "reverse", ConstantVariable.create(False) ).as_python_constant() if len(kwargs) != 0: raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") @@ -1090,18 +1124,18 @@ def call_method( self.items[:] = [x for x, *_ in sorted_items_with_keys] except Exception as e: raise_observed_exception(type(e), tx, args=list(e.args)) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) if name == "__init__" and self.is_mutable(): if kwargs: raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") if len(args) == 0: - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): (arg,) = args tx.output.side_effects.mutation(self) self.items[:] = arg.force_unpack_var_sequence(tx) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) return super().call_method(tx, name, args, kwargs) @@ -1126,7 +1160,15 @@ def is_python_hashable(self) -> bool: return False +# TODO(follow-up): DequeVariable inherits BaseListVariable.mp_subscript_impl which +# accepts slices. CPython's deque only has sq_item (Modules/_collectionsmodule.c:1888), +# not mp_subscript — deque[slice] should raise TypeError. Override mp_subscript_impl +# to reject slices and only accept integer-like keys via _PyIndex_Check → nb_index_impl. +# Also add tests for: negative index, __index__ object key, invalid type key. class DequeVariable(CommonListMethodsVariable): + # deque_spec: https://github.com/python/cpython/blob/v3.13.0/Modules/_collectionsmodule.c#L1866 + _cpython_type = collections.deque + def __init__( self, items: list[VariableTracker], @@ -1134,7 +1176,7 @@ def __init__( **kwargs: Any, ) -> None: if maxlen is None: - maxlen = CONSTANT_VARIABLE_NONE + maxlen = ConstantVariable.create(None) assert maxlen.is_python_constant(), ( f"maxlen must be a constant, got: {maxlen.debug_repr()}" ) @@ -1219,7 +1261,7 @@ def call_method( assert isinstance(key.as_python_constant(), int) tx.output.side_effects.mutation(self) self.items[key.as_python_constant()] = value - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) maxlen = self.maxlen.as_python_constant() if maxlen is not None: @@ -1246,7 +1288,7 @@ def call_method( tx, lambda item: self.call_method(tx, "appendleft", [item], {}) ) slice_within_maxlen = slice(None, maxlen) - result = CONSTANT_VARIABLE_NONE + result = ConstantVariable.create(None) elif name == "popleft" and self.is_mutable(): if kwargs or len(args) > 0: raise_args_mismatch( @@ -1268,7 +1310,7 @@ def call_method( tx.output.side_effects.mutation(self) self.items[:] = [args[0], *self.items] slice_within_maxlen = slice(None, maxlen) - result = CONSTANT_VARIABLE_NONE + result = ConstantVariable.create(None) elif name == "insert" and len(args) > 0 and self.is_mutable(): if kwargs or len(args) != 2: raise_args_mismatch( @@ -1278,10 +1320,9 @@ def call_method( f"{len(args)} args and {len(kwargs)} kwargs", ) if maxlen is not None and len(self.items) == maxlen: - error_message = VariableTracker.build( - tx, "deque already at its maximum size" + raise_observed_exception( + IndexError, tx, args=["deque already at its maximum size"] ) - raise_observed_exception(IndexError, tx, args=[error_message]) result = super().call_method(tx, name, args, kwargs) else: result = super().call_method(tx, name, args, kwargs) @@ -1303,6 +1344,9 @@ def call_obj_hasattr( class TupleVariable(BaseListVariable): + # PyTuple_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/tupleobject.c#L846 + _cpython_type = tuple + def python_type(self) -> type[tuple]: # type: ignore[type-arg] return tuple @@ -1346,6 +1390,8 @@ def is_python_equal(self, other: object) -> bool: class SizeVariable(TupleVariable): """torch.Size(...)""" + _cpython_type = torch.Size + _nonvar_fields = { "proxy", *TupleVariable._nonvar_fields, @@ -1453,6 +1499,29 @@ def numel(self, tx: "InstructionTranslator") -> VariableTracker: result = mul.call_function(tx, [result, v], {}) return result + def mp_subscript_impl( + self, + tx: "InstructionTranslator", + key: VariableTracker, + ) -> VariableTracker: + # tuple_subscript: https://github.com/python/cpython/blob/62a6e898e01/Objects/tupleobject.c#L877-L930 + # CPython: tuplesubscript checks _PyIndex_Check → PyNumber_AsSsize_t for non-slice keys + try: + key_type = key.python_type() + except NotImplementedError: + key_type = None + if key_type not in (int, bool, slice): + if key_type is not None and not hasattr(key_type, "__index__"): + raise_observed_exception( + TypeError, + tx, + args=[ + f"tuple indices must be integers or slices, not {key.python_type_name()}" + ], + ) + key = key.nb_index_impl(tx) + return self.get_item_dyn(tx, key) + def call_method( self, tx: "InstructionTranslator", @@ -1460,17 +1529,7 @@ def call_method( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - if name == "__getitem__": - if kwargs or len(args) != 1: - raise_args_mismatch( - tx, - name, - "1 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - out = self.get_item_dyn(tx, args[0]) - return out - elif name == "numel": + if name == "numel": if args or kwargs: raise_args_mismatch( tx, @@ -1516,324 +1575,10 @@ def call_obj_hasattr( return VariableTracker.build(tx, hasattr(torch.Size, name)) -class NamedTupleVariable(UserDefinedTupleVariable): - _nonvar_fields = { - "tuple_cls", - "dynamic_attributes", - *UserDefinedTupleVariable._nonvar_fields, - } - - def __init__( - self, - items: list[VariableTracker], - # pyrefly: ignore [implicit-any] - tuple_cls: type[tuple], - dynamic_attributes: dict[str, VariableTracker] | None = None, - tuple_vt: TupleVariable | None = None, - **kwargs: Any, - ) -> None: - if tuple_vt is None: - assert getattr(kwargs, "source", None) is None - tuple_vt = variables.TupleVariable( - items, mutation_type=kwargs.get("mutation_type", ValueMutationNew()) - ) - - if tuple_cls.__module__ == "torch.return_types": - # Structseq: single iterable argument - dummy_value = tuple_cls(items) - else: - # Namedtuple: positional arguments - dummy_value = tuple_cls(*items) # type: ignore[arg-type] - - super().__init__( - value=dummy_value, - tuple_vt=tuple_vt, - init_args=items, - **kwargs, - ) - self.tuple_cls = tuple_cls - if len(self.tuple_cls.__mro__) < 3: - raise ValueError("NamedTuple should inherit from Tuple and Object.") - self.dynamic_attributes = dynamic_attributes if dynamic_attributes else {} - - @property - def items(self) -> list[VariableTracker]: - return self._tuple_vt.items - - def is_namedtuple(self) -> bool: - return isinstance(getattr(self.tuple_cls, "_fields", None), tuple) and callable( - getattr(self.tuple_cls, "_make", None) - ) - - def is_structseq(self) -> bool: - return not self.is_namedtuple() - - def fields(self) -> tuple[str, ...]: - return namedtuple_fields(self.tuple_cls) - - def as_python_constant(self) -> Any: - if self.is_structseq(): - # StructSequenceType(iterable) - result = self.python_type()([x.as_python_constant() for x in self.items]) - else: - # NamedTupleType(*iterable) - result = self.python_type()(*[x.as_python_constant() for x in self.items]) - - # Apply dynamic attributes if any were set - if self.dynamic_attributes: - for attr_name, attr_value in self.dynamic_attributes.items(): - # Convert VariableTracker to Python constant if needed - if hasattr(attr_value, "as_python_constant"): - python_value = attr_value.as_python_constant() - else: - raise NotImplementedError( - "Can not convert dynamic attribute without python constant value to python constant." - ) - setattr(result, attr_name, python_value) - - return result - - def as_proxy(self) -> Any: - if self.is_structseq(): - return self.python_type()([x.as_proxy() for x in self._tuple_vt.items]) - return self.python_type()(*[x.as_proxy() for x in self._tuple_vt.items]) - - def call_tree_map( - self, - tx: "InstructionTranslator", - tree_map_fn: UserFunctionVariable, - map_fn: VariableTracker, - rest: Sequence[VariableTracker], - tree_map_kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - is_leaf_var = tree_map_kwargs.get("is_leaf") - if is_leaf_var is not None and not is_leaf_var.is_constant_none(): - pred_result = is_leaf_var.call_function(tx, [self], {}) - try: - leaf_decision = pred_result.as_python_constant() - except NotImplementedError: - return self._tree_map_fallback( - tx, tree_map_fn, map_fn, rest, tree_map_kwargs - ) - if leaf_decision: - return map_fn.call_function(tx, [self, *rest], {}) - - return self.call_tree_map_branch( - tx, - tree_map_fn, - map_fn, - rest, - tree_map_kwargs, - ) - - def call_tree_map_branch( - self, - tx: "InstructionTranslator", - tree_map_fn: UserFunctionVariable, - map_fn: VariableTracker, - rest: Sequence[VariableTracker], - tree_map_kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - other_tuples: list[NamedTupleVariable] = [] - for candidate in rest: - if ( - not isinstance(candidate, NamedTupleVariable) - or len(candidate.items) != len(self.items) - or candidate.tuple_cls is not self.tuple_cls - ): - return self._tree_map_fallback( - tx, tree_map_fn, map_fn, rest, tree_map_kwargs - ) - other_tuples.append(candidate) - - new_items: list[VariableTracker] = [] - for idx, item in enumerate(self.items): - sibling_leaves = [candidate.items[idx] for candidate in other_tuples] - new_items.append( - item.call_tree_map( - tx, - tree_map_fn, - map_fn, - sibling_leaves, - tree_map_kwargs, - ) - ) - - return NamedTupleVariable( - new_items, - self.tuple_cls, - mutation_type=ValueMutationNew(), - ) - - def call_tree_map_with_path( - self, - tx: "InstructionTranslator", - tree_map_fn: UserFunctionVariable, - map_fn: VariableTracker, - rest: Sequence[VariableTracker], - tree_map_kwargs: dict[str, VariableTracker], - keypath: tuple[Any, ...], - ) -> VariableTracker: - is_leaf_var = tree_map_kwargs.get("is_leaf") - if is_leaf_var is not None and not is_leaf_var.is_constant_none(): - pred_result = is_leaf_var.call_function(tx, [self], {}) - try: - leaf_decision = pred_result.as_python_constant() - except NotImplementedError: - # For namedtuples, they're always pytree containers, so is_leaf - # should return False. Assume False and proceed with fast path. - leaf_decision = False - if leaf_decision: - keypath_var = variables.TupleVariable( - [VariableTracker.build(tx, k) for k in keypath] - ) - return map_fn.call_function(tx, [keypath_var, self, *rest], {}) - - return self.call_tree_map_with_path_branch( - tx, - tree_map_fn, - map_fn, - rest, - tree_map_kwargs, - keypath, - ) - - def call_tree_map_with_path_branch( - self, - tx: "InstructionTranslator", - tree_map_fn: UserFunctionVariable, - map_fn: VariableTracker, - rest: Sequence[VariableTracker], - tree_map_kwargs: dict[str, VariableTracker], - keypath: tuple[Any, ...], - ) -> VariableTracker: - other_tuples: list[NamedTupleVariable] = [] - for candidate in rest: - if ( - not isinstance(candidate, NamedTupleVariable) - or len(candidate.items) != len(self.items) - or candidate.tuple_cls is not self.tuple_cls - ): - return self._tree_map_with_path_fallback( - tx, tree_map_fn, map_fn, rest, tree_map_kwargs, keypath - ) - other_tuples.append(candidate) - - fields = self.fields() - new_items: list[VariableTracker] = [] - for idx, item in enumerate(self.items): - sibling_leaves = [candidate.items[idx] for candidate in other_tuples] - child_keypath = keypath + (GetAttrKey(fields[idx]),) - new_items.append( - item.call_tree_map_with_path( - tx, - tree_map_fn, - map_fn, - sibling_leaves, - tree_map_kwargs, - child_keypath, - ) - ) - - return NamedTupleVariable( - new_items, - self.tuple_cls, - mutation_type=ValueMutationNew(), - ) - - def reconstruct(self, codegen: "PyCodegen") -> None: - if self.is_structseq(): - create_fn = self.tuple_cls - else: - create_fn = self.tuple_cls._make # type: ignore[attr-defined] - - codegen.add_push_null( - lambda: codegen.append_output( - codegen.create_load_const_unchecked(create_fn) - ) - ) - codegen.foreach(self._tuple_vt.items) - codegen.extend_output( - [ - create_build_tuple(len(self._tuple_vt.items)), - ] - + create_call_function(1, False) - ) - - # Apply initial dynamic attributes after construction (if any) - # Runtime dynamic attributes are tracked via side effects system - for name, value in self.dynamic_attributes.items(): - codegen.dup_top() - codegen(value) - codegen.extend_output(create_rot_n(2)) - codegen.store_attr(name) - - def _is_method_overridden(self, method_name: str) -> bool: - if len(self.tuple_cls.__mro__) < 3: - raise ValueError("NamedTuple should inherit from Tuple and Object.") - if getattr(self.tuple_cls, method_name, None) == getattr( - self.tuple_cls.__mro__[-3], method_name, None - ): - return False - return True - - def call_method( - self, - tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - if self._is_method_overridden(name): - # Fall back to UserDefinedTupleVariable - return super().call_method(tx, name, args, kwargs) - elif name == "__setattr__": - if kwargs or len(args) != 2: - raise_args_mismatch( - tx, - name, - "2 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - attr_var, value = args - attr = attr_var.as_python_constant() - - if ( - # structseq is immutable - self.is_structseq() - # namedtuple directly created by `collections.namedtuple` is immutable - or self.tuple_cls.__bases__ == (tuple,) - or attr in self.fields() - ): - raise_observed_exception(AttributeError, tx) - - result = self.method_setattr_standard(tx, attr_var, value) - # Also update self.dynamic_attributes - self.dynamic_attributes[attr] = value - return result - - return super().call_method(tx, name, args, kwargs) - - def python_type(self) -> type: - return self.tuple_cls - - def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": - if name == "_fields": - source = NamedTupleFieldsSource(self.source) if self.source else None - return VariableTracker.build(tx, self.fields(), source=source) - - if name in self.dynamic_attributes: - return self.dynamic_attributes[name] - - fields = self.fields() - if name in fields: - field_index = fields.index(name) - return self._tuple_vt.items[field_index] - - return super().var_getattr(tx, name) - - class SliceVariable(VariableTracker): + # PySlice_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/sliceobject.c#L689 + _cpython_type = slice + def __init__( self, items: Sequence[VariableTracker], @@ -1841,7 +1586,7 @@ def __init__( **kwargs: Any, ) -> None: items_to_map = items - start, stop, step = [variables.CONSTANT_VARIABLE_NONE] * 3 + start, stop, step = [variables.ConstantVariable.create(None)] * 3 if len(items_to_map) == 1: (stop,) = items_to_map @@ -1892,7 +1637,9 @@ def reconstruct(self, codegen: "PyCodegen") -> None: def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name in cmp_name_to_op_mapping: - return variables.GetAttrVariable(self, name) + return variables.GetAttrVariable( + self, name, py_type=type(getattr(slice, name)) + ) fields = ["start", "stop", "step"] if name not in fields: unimplemented( @@ -1906,6 +1653,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker class ListIteratorVariable(IteratorVariable): + # PyListIter_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/listobject.c#L3842 + _cpython_type = type(iter([])) + _nonvar_fields = { "index", *IteratorVariable._nonvar_fields, @@ -1981,10 +1731,14 @@ def reconstruct(self, codegen: "PyCodegen") -> None: class TupleIteratorVariable(ListIteratorVariable): - pass + # PyTupleIter_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/tupleobject.c#L1067 + _cpython_type = type(iter(())) class RangeIteratorVariable(IteratorVariable): + # PyRangeIter_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/rangeobject.c#L896 + _cpython_type = type(iter(range(0))) + # only needed for isinstance(..., range_iterator) to work _nonvar_fields = { "iter_obj", diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 03cbb774eca25..93f3601cc1002 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -15,6 +15,7 @@ - DebuggingVariable: Handles print and logging """ +import builtins import dataclasses import enum import functools @@ -30,7 +31,7 @@ from collections.abc import Callable, Sequence from random import Random from types import BuiltinFunctionType -from typing import Any, Literal, NoReturn, TYPE_CHECKING, TypeGuard, Union +from typing import Any, Literal, TYPE_CHECKING, TypeGuard, Union import torch._C import torch._numpy as tnp @@ -46,12 +47,11 @@ create_instruction, ) from ..create_parameter_op import do_not_convert_to_tracable_parameter -from ..exc import raise_observed_exception, unimplemented +from ..exc import raise_observed_exception, raise_type_error, unimplemented from ..guards import GuardBuilder, install_guard from ..mutation_guard import unpatched_nn_module_init from ..source import ( AttrSource, - DictGetItemSource, GenericAttrSource, GetItemSource, TypeMROSource, @@ -64,17 +64,11 @@ identity, is_tensor_base_attr_getter, istype, - list_methods, proxy_args_kwargs, raise_args_mismatch, - tuple_methods, ) -from .base import ( - AsPythonConstantNotImplementedError, - raise_type_error_exc, - VariableTracker, -) -from .constant import CONSTANT_VARIABLE_FALSE, CONSTANT_VARIABLE_NONE, ConstantVariable +from .base import AsPythonConstantNotImplementedError, NO_SUCH_SUBOBJ, VariableTracker +from .constant import ConstantVariable from .functions import NestedUserFunctionVariable, UserFunctionVariable from .user_defined import call_random_fn, is_standard_setattr, UserDefinedObjectVariable @@ -84,11 +78,10 @@ from torch._dynamo.symbolic_convert import InstructionTranslator -class NO_SUCH_SUBOBJ: - pass - - class SuperVariable(VariableTracker): + # PySuper_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/typeobject.c#L11511 + _cpython_type = super + _nonvar_fields = { *VariableTracker._nonvar_fields, } @@ -110,6 +103,9 @@ def __init__( # cls for a classmethod) self.objvar = objvar + def python_type(self) -> type: + return builtins.super + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: codegen(variables.BuiltinVariable(super))) codegen(self.typevar) @@ -198,7 +194,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker # not just AttrSource). value, source = self._resolved_getattr_and_source(tx, name) if not variables.ConstantVariable.is_literal(value): - return GetAttrVariable(self, name) + return GetAttrVariable(self, name, py_type=type(value)) if source: install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH)) return variables.ConstantVariable.create(value, source=source) @@ -343,52 +339,33 @@ def call_method( tx.output.side_effects.store_attr( self.objvar, attr, variables.DeletedVariable() ) - return variables.CONSTANT_VARIABLE_NONE - elif ( - isinstance(self.objvar, variables.UserDefinedDictVariable) - and inner_fn in self.objvar._dict_methods - ): - return self.objvar._dict_vt.call_method(tx, name, args, kwargs) + return variables.ConstantVariable.create(None) elif ( - isinstance(self.objvar, variables.UserDefinedSetVariable) - and inner_fn in self.objvar._set_methods + isinstance(self.objvar, variables.UserDefinedObjectVariable) + and self.objvar._base_vt is not None + and self.objvar._base_methods is not None + and inner_fn in self.objvar._base_methods ): - return self.objvar._set_vt.call_method(tx, name, args, kwargs) - elif ( - isinstance(self.objvar, variables.UserDefinedTupleVariable) - and inner_fn in tuple_methods - ): - return self.objvar._tuple_vt.call_method(tx, name, args, kwargs) - elif ( - isinstance(self.objvar, variables.UserDefinedListVariable) - and inner_fn in list_methods - ): - return self.objvar._list_vt.call_method(tx, name, args, kwargs) + return self.objvar._base_vt.call_method(tx, name, args, kwargs) elif inner_fn is object.__getattribute__: - # object.__getattribute__ has no side-effects. We can directly call - # __getattribute__ to access the attribute. attr_name = args[0].value # type: ignore[attr-defined] - if tx.output.side_effects.has_pending_mutation_of_attr( - self.objvar, attr_name - ): - result = tx.output.side_effects.load_attr( - self.objvar, attr_name, deleted_ok=True - ) - if isinstance(result, variables.DeletedVariable): - raise_observed_exception(AttributeError, tx) - return result + # object.__getattribute__ IS PyObject_GenericGetAttr. Delegate + # to the shared implementation so that __dict__, __class__, + # polyfilled C descriptors, etc. are all handled consistently. + if isinstance(self.objvar, UserDefinedObjectVariable): + return self.objvar.generic_getattr(tx, attr_name) attr_value = None try: - # NB - use object.__getattribute__ to prevent running any user code - # type: ignore[attr-defined] - attr_value = object.__getattribute__(self.objvar.value, attr_name) + attr_value = object.__getattribute__( + self.objvar.value, # pyrefly: ignore[missing-attribute] + attr_name, + ) except AttributeError: raise_observed_exception(AttributeError, tx) attr_source = None if self.objvar.source is not None: - # setup a object.__getattribute__(self.objvar, name) source attr_source = GenericAttrSource(self.objvar.source, attr_name) return VariableTracker.build(tx, attr_value, attr_source) elif inner_fn is torch._C._disabled_torch_function_impl: @@ -484,9 +461,7 @@ def from_frame_summary( @staticmethod def is_valid_traceback(obj: VariableTracker) -> bool: - return istype(obj, TracebackVariable) or ( - istype(obj, ConstantVariable) and obj.is_constant_none() - ) + return istype(obj, TracebackVariable) or obj.is_constant_none() def extract_tb(self) -> list[traceback.FrameSummary | FrameSummaryVariable]: if istype(self.tb_next, ConstantVariable): @@ -521,7 +496,7 @@ def call_setattr( ): raise_observed_exception(ValueError, tx) self.tb_next = val - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "tb_next": @@ -555,11 +530,14 @@ def call_method( class ExceptionVariable(VariableTracker): + # _PyExc_BaseException: https://github.com/python/cpython/blob/v3.13.0/Objects/exceptions.c + _cpython_type = BaseException + # The ExceptionVariable corresponds to the BaseException class in Python def __init__( self, exc_type: Any, - args: tuple[VariableTracker, ...], + args: list[VariableTracker], init_kwargs: dict[str, VariableTracker] | None = None, source: Source | None = None, mutation_type: MutationType | None = None, @@ -577,14 +555,14 @@ def __init__( # When raising a new exception while another exception is already being # handled, the new exception's __context__ attribute is automatically # set to the handled exception. - self.__context__: VariableTracker = CONSTANT_VARIABLE_NONE + self.__context__: VariableTracker = ConstantVariable.create(None) # Set when user raised an exception from another: # raise ... from ... - self.__cause__: VariableTracker = CONSTANT_VARIABLE_NONE + self.__cause__: VariableTracker = ConstantVariable.create(None) # Boolean flag that controls whether the __context__ attribute is set - self.__suppress_context__: VariableTracker = CONSTANT_VARIABLE_FALSE + self.__suppress_context__: VariableTracker = ConstantVariable.create(False) # Contains the call stack where the exception was raised. - self.__traceback__: VariableTracker = CONSTANT_VARIABLE_NONE + self.__traceback__: VariableTracker = ConstantVariable.create(None) # The user stack at the time this exception was first raised. # Used to preserve the original exception location when re-raising. self.python_stack: traceback.StackSummary | None = None @@ -622,15 +600,17 @@ def call_setattr( name_var: VariableTracker, val: VariableTracker, ) -> VariableTracker: - def raise_error(msg: str) -> NoReturn: - raise_observed_exception( - TypeError, tx, args=[VariableTracker.build(tx, msg)] - ) - name = name_var.as_python_constant() if name == "__context__": # Constant can be either an Exceptior or None - assert isinstance(val, (ExceptionVariable, ConstantVariable)) + assert val.is_constant_none() or isinstance( + val, + ( + variables.ExceptionVariable, + variables.UserDefinedExceptionClassVariable, + variables.UserDefinedExceptionObjectVariable, + ), + ), f"{val} is not a valid exception context" self.set_context(val) elif name == "__cause__": if val.is_constant_none() or isinstance( @@ -643,25 +623,21 @@ def raise_error(msg: str) -> NoReturn: ), ): self.__cause__ = val - self.__suppress_context__ = variables.CONSTANT_VARIABLE_TRUE + self.__suppress_context__ = variables.ConstantVariable.create(True) else: - raise_error("exception cause must be None or derive from BaseException") + raise_type_error( + tx, "exception cause must be None or derive from BaseException" + ) elif name == "__suppress_context__": if val.is_constant_match(True, False): self.__suppress_context__ = val else: - raise_error("exception cause must be None or derive from BaseException") + raise_type_error( + tx, "exception cause must be None or derive from BaseException" + ) elif name == "__traceback__": if not TracebackVariable.is_valid_traceback(val): - raise_observed_exception( - TypeError, - tx, - args=[ - VariableTracker.build( - tx, "__traceback__ must be a traceback object or None" - ) - ], - ) + raise_type_error(tx, "__traceback__ must be a traceback object or None") self.__traceback__ = val else: unimplemented( @@ -672,7 +648,7 @@ def raise_error(msg: str) -> NoReturn: "`__cause__`, `__suppress_context__`, and `__traceback__` are supported.", hints=[*graph_break_hints.SUPPORTABLE], ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def call_method( self, @@ -691,7 +667,9 @@ def call_method( return super().call_method(tx, name, args, kwargs) def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: - if name == "__context__": + if name == "__class__": + return VariableTracker.build(tx, self.exc_type) + elif name == "__context__": return self.__context__ elif name == "__cause__": return self.__cause__ @@ -700,7 +678,11 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker elif name == "__traceback__": return self.__traceback__ elif name == "args": - return variables.ListVariable(list(self.args), source=self.source) + return VariableTracker.build( + tx, + tuple(self.args), + source=self.source and AttrSource(self.source, "args"), + ) return super().var_getattr(tx, name) def __str__(self) -> str: @@ -708,6 +690,17 @@ def __str__(self) -> str: __repr__ = __str__ + @staticmethod + def _debug_format_arg(arg: VariableTracker) -> str: + try: + return repr(arg.as_python_constant()) + except Exception: + return arg.debug_repr() + + def debug_repr(self) -> str: + args = ", ".join(self._debug_format_arg(arg) for arg in self.args) + return f"{self.python_type_name()}({args})" + class UnknownVariable(VariableTracker): """ @@ -782,7 +775,7 @@ def call_function( # We have to manually bind the freevars ourselves code = fn.get_code() if fn.closure: - raise_type_error_exc( + raise_type_error( tx, f"comptime function must not have free variables, but these variables were free: {code.co_freevars}", ) @@ -803,10 +796,13 @@ def call_function( else: raise RuntimeError(f"unsupported argument to comptime: {type(fn)}") - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) class CellVariable(VariableTracker): + # PyCell_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/cellobject.c#L151 + _cpython_type = types.CellType + # If the cell existed before Dynamo tracing started, this will be the # VariableTracker that represents the cell content. # @@ -825,6 +821,9 @@ def __init__( super().__init__(**kwargs) self.pre_existing_contents = pre_existing_contents + def python_type(self) -> type: + return types.CellType + class NewGlobalVariable(VariableTracker): def __init__(self, **kwargs: Any) -> None: @@ -852,6 +851,9 @@ def __init__(self, fn_cls: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.fn_cls = fn_cls + def python_type(self) -> type: + return type + def call_apply( self, tx: "InstructionTranslator", @@ -1149,7 +1151,7 @@ def call_method( if kwargs: raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") self.non_differentiable = proxy_args_kwargs(args, {})[0] - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) if name != "save_for_backward": unimplemented( @@ -1176,7 +1178,7 @@ def call_method( assert self.saved_tensors is not None if not self.inference: if kwargs or not self.source: - raise_type_error_exc( + raise_type_error( tx, "save_for_backward() requires a source and no keyword arguments" ) tx.output.side_effects.track_save_for_backward(self, args) @@ -1187,7 +1189,7 @@ def call_method( self.saved_tensors.tensors = [] for arg in args: self.saved_tensors.tensors.append(arg) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name in ["save_for_backward", "mark_non_differentiable"]: @@ -1266,6 +1268,9 @@ def __init__(self, fn: Callable[..., VariableTracker], **kwargs: Any) -> None: super().__init__(**kwargs) self.fn = fn + def python_type(self) -> type: + return types.FunctionType + def call_function( self, tx: "InstructionTranslator", @@ -1342,94 +1347,15 @@ def call_function( ) -> VariableTracker: return self.obj.call_method(tx, self.name, list(args), kwargs) - def call_method( - self, - tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - if ( - name in ("__getitem__", "get") - and self.name == "__dict__" - and not kwargs - and args[0].is_python_constant() - and isinstance( - self.obj, - ( - variables.NNModuleVariable, - variables.UserDefinedClassVariable, - ), - ) - ): - obj = self.obj - key = args[0].as_python_constant() - if obj.has_key_in_generic_dict(tx, key): - if tx.output.side_effects.has_pending_mutation_of_attr(obj, key): - return tx.output.side_effects.load_attr(obj, key) - - # For instance dicts, read directly from __dict__ - if isinstance(obj.value.__dict__, dict): - raw_value = obj.value.__dict__[key] - raw_source = ( - DictGetItemSource(AttrSource(obj.source, "__dict__"), key) - if obj.source - else None - ) - return VariableTracker.build(tx, raw_value, raw_source) - - return obj.var_getattr(tx, key) - - # Return the default value for get - if name == "get": - if len(args) == 2: - return args[1] - else: - return variables.CONSTANT_VARIABLE_NONE - - elif ( - name == "__contains__" - and self.name == "__dict__" - and len(args) == 1 - and args[0].is_python_constant() - and not kwargs - and isinstance( - self.obj, - ( - variables.NNModuleVariable, - variables.UserDefinedClassVariable, - ), - ) - ): - obj = self.obj - key = args[0].as_python_constant() - if obj.has_key_in_generic_dict(tx, key): - return variables.CONSTANT_VARIABLE_TRUE - else: - return variables.CONSTANT_VARIABLE_FALSE - - elif name == "__setitem__" and self.name == "__dict__" and not kwargs: - if isinstance(self.obj, variables.NNModuleVariable): - # This matches how `setattr` is handled for NNModuleVariable - self.obj.convert_to_unspecialized(tx) - - return super().call_method(tx, name, args, kwargs) - - def get_forwarded_dict(self, tx: "InstructionTranslator") -> VariableTracker: - assert ( - self.name == "__dict__" - and isinstance(self.obj, variables.UserDefinedClassVariable) - and not tx.output.side_effects.has_pending_mutation(self.obj) - ) - self.obj.ban_mutation = True - return VariableTracker.build(tx, self.obj.value.__dict__, self.source) - class MethodWrapperVariable(VariableTracker): def __init__(self, method_wrapper: types.MethodWrapperType, **kwargs: Any) -> None: super().__init__(**kwargs) self.method_wrapper = method_wrapper + def get_real_python_backed_value(self) -> types.MethodWrapperType: + return self.method_wrapper + def call_function( self, tx: "InstructionTranslator", @@ -1440,7 +1366,7 @@ def call_function( args[0], variables.TensorVariable ): if not (len(args) == 1 and len(kwargs) == 0): - raise_type_error_exc( + raise_type_error( tx, "tensor attribute getter takes exactly one argument" ) # type: ignore[arg-type, attr-defined] @@ -1567,6 +1493,9 @@ def __init__(self, desc: types.GetSetDescriptorType, **kwargs: Any) -> None: super().__init__(**kwargs) self.desc = desc + def get_real_python_backed_value(self) -> types.GetSetDescriptorType: + return self.desc + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: if name == "__get__" and self.source: source = AttrSource(self.source, "__get__") @@ -1585,6 +1514,9 @@ def as_python_constant(self) -> types.GetSetDescriptorType: class PythonModuleVariable(VariableTracker): + # PyModule_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/moduleobject.c#L1203 + _cpython_type = types.ModuleType + _nonvar_fields = { "value", "is_torch", @@ -1602,6 +1534,9 @@ def python_type(self) -> type[types.ModuleType]: def as_python_constant(self) -> types.ModuleType: return self.value + def get_real_python_backed_value(self) -> types.ModuleType: + return self.value + def __repr__(self) -> str: return f"PythonModuleVariable({self.value})" @@ -1633,6 +1568,16 @@ def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value + def mp_subscript_impl( + self, + tx: "InstructionTranslator", + key: VariableTracker, + ) -> VariableTracker: + # e.g., List[int] → typing.List[int] + # TODO(follow-up): add test for invalid subscript type + new_typing = self.value[key.as_python_constant()] + return TypingVariable(new_typing) + def call_method( self, tx: "InstructionTranslator", @@ -1640,11 +1585,7 @@ def call_method( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - # Create a new typing variable, e.g., `List[int]` - if name == "__getitem__" and len(args) == 1: - new_typing = self.value[args[0].as_python_constant()] - return TypingVariable(new_typing) - elif name == "__eq__": + if name == "__eq__": if len(args) == 1 and not kwargs: result = istype(args[0], TypingVariable) and self.value == args[0].value return variables.ConstantVariable.create(result) @@ -1662,7 +1603,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker from .builder import SourcelessBuilder, VariableBuilder if name in cmp_name_to_op_mapping: - return variables.GetAttrVariable(self, name) + return variables.GetAttrVariable( + self, name, py_type=type(getattr(self.value, name)) + ) if tx.output.side_effects.has_pending_mutation_of_attr(self, name): return tx.output.side_effects.load_attr(self, name) @@ -1677,6 +1620,9 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker def as_python_constant(self) -> Any: return self.value + def get_real_python_backed_value(self) -> Any: + return self.value + def reconstruct(self, codegen: "PyCodegen") -> None: if not isinstance(self.value, types.GenericAlias): return super().reconstruct(codegen) @@ -1760,6 +1706,9 @@ def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value + def get_real_python_backed_value(self) -> Any: + return self.value + @classmethod def can_constant_fold_through(cls, fn: types.FunctionType) -> bool: mod = fn.__module__.split(".") @@ -1943,6 +1892,9 @@ class StringFormatVariable(VariableTracker): _nonvar_fields = {"format_string", *VariableTracker._nonvar_fields} + def python_type(self) -> type: + return str + @classmethod def create( cls, @@ -1978,6 +1930,26 @@ def __init__( def __repr__(self) -> str: return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})" + @staticmethod + def _debug_format_arg(arg: VariableTracker) -> object: + try: + return arg.as_python_constant() + except Exception: + return arg.debug_repr() + + def debug_repr(self) -> str: + try: + rendered = self.format_string.format( + *[self._debug_format_arg(arg) for arg in self.sym_args], + **{ + key: self._debug_format_arg(value) + for key, value in self.sym_kwargs.items() + }, + ) + except Exception: + return repr(self) + return repr(rendered) + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null( lambda: codegen.extend_output( @@ -1997,11 +1969,17 @@ def reconstruct(self, codegen: "PyCodegen") -> None: class ObjectVariable(VariableTracker): + # PyBaseObject_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/typeobject.c#L7243 + _cpython_type = object + # placeholder for unknown / opaque values def __init__(self, value: object, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value + def get_real_python_backed_value(self) -> object: + return self.value + def python_type(self) -> type[object]: return object @@ -2016,6 +1994,9 @@ def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value + def python_type(self) -> type: + return type(self.value) + @staticmethod def is_reorderable_logging_function( obj: Any, @@ -2034,7 +2015,7 @@ def call_function( ) -> VariableTracker: if tx.export: # For export cases, we can just make debugging functions no-ops - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) if not self.can_reorder_logs(self.value, args, kwargs): unimplemented( @@ -2048,7 +2029,7 @@ def call_function( ) tx.debug_locals.append((self, list(args))) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) def reconstruct(self, codegen: "PyCodegen") -> None: assert self.source is not None @@ -2084,13 +2065,19 @@ def __init__(self, value: Any, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value + def python_type(self) -> type: + return type(self.value) + + def get_real_python_backed_value(self) -> Any: + return self.value + def call_function( self, tx: "InstructionTranslator", args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) class LoggingLoggerVariable(VariableTracker): @@ -2102,6 +2089,12 @@ def __init__(self, value: logging.Logger, **kwargs: Any) -> None: super().__init__(**kwargs) self.value = value + def python_type(self) -> type: + return type(self.value) + + def get_real_python_backed_value(self) -> logging.Logger: + return self.value + def call_method( self, tx: "InstructionTranslator", @@ -2111,7 +2104,7 @@ def call_method( ) -> VariableTracker: if tx.export: # For export cases, we can just make logging functions no-ops. - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) method = getattr(self.value, name, None) function = getattr(method, "__func__", None) @@ -2120,7 +2113,7 @@ def call_method( ignore_set = torch._dynamo.config.ignore_logging_functions if method in ignore_set or function in ignore_set: - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) unimplemented( gb_type="logging.Logger method not supported for non-export cases", @@ -2215,7 +2208,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return NumpyVariable(result) if variables.ConstantVariable.is_literal(result): return VariableTracker.build(tx, result) - return GetAttrVariable(self, name) + return GetAttrVariable(self, name, py_type=type(result)) class TorchVersionVariable(ConstantLikeVariable): @@ -2250,6 +2243,9 @@ class RandomClassVariable(VariableTracker): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) + def python_type(self) -> type: + return type + def call_function( self, tx: "InstructionTranslator", @@ -2265,7 +2261,7 @@ def call_function( *graph_break_hints.USER_ERROR, ], ) - seed = variables.CONSTANT_VARIABLE_NONE if len(args) == 0 else args[0] + seed = variables.ConstantVariable.create(None) if len(args) == 0 else args[0] return RandomVariable( seed=seed, mutation_type=variables.base.ValueMutationNew() ) @@ -2279,6 +2275,8 @@ class RandomVariable(VariableTracker): Assumes that random objects behave the same given a set seed or state. """ + _cpython_type = random.Random + _nonvar_fields = { "random", *VariableTracker._nonvar_fields, @@ -2373,13 +2371,13 @@ def call_method( *[x.as_python_constant() for x in args], **{key: val.as_python_constant() for key, val in kwargs.items()}, ) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) elif name == "getstate": return self.wrap_state(self.random.getstate()) elif name == "setstate": tx.output.side_effects.mutation(self) self.random.setstate(self.unwrap_state(args[0])) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) elif name in self._supported_fn_names: tx.output.side_effects.mutation(self) state = self.random.getstate() @@ -2419,8 +2417,11 @@ def reconstruct(self, codegen: "PyCodegen") -> None: class WeakRefVariable(VariableTracker): + def python_type(self) -> type: + return weakref.ref + @staticmethod - # pyrefly: ignore[bad-param-name-override] + # pyrefly: ignore [bad-override, bad-param-name-override] def build( tx: "InstructionTranslator", weakref_value: weakref.ReferenceType[Any], diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index 2f0784ca2b9d1..dd47a266bd82c 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -38,6 +38,7 @@ handle_observed_exception, ObservedAttributeError, raise_observed_exception, + raise_type_error, unimplemented, UnspecializeRestartAnalysis, Unsupported, @@ -54,6 +55,7 @@ UnspecializedNNModuleSource, ) from ..utils import ( + enumerate_items_with_dict_position, get_custom_getattr, get_fake_value, is_lazy_module, @@ -69,10 +71,9 @@ unpatched_nn_module_call, unpatched_nn_module_call_impl, ) -from .base import raise_type_error_exc, typestr, ValueMutationNew, VariableTracker +from .base import typestr, ValueMutationNew, VariableTracker from .functions import invoke_and_store_as_constant from .lazy import LazyVariableTracker -from .lists import SliceVariable from .user_defined import UserDefinedObjectVariable @@ -80,6 +81,7 @@ from torch._dynamo.symbolic_convert import InstructionTranslator from .constant import ConstantVariable + from .dicts import DunderDictVariable def initialize_lazy_module( @@ -114,15 +116,12 @@ def convert_to_fake(x: Any) -> Any: fake_kwargs = {k: convert_to_fake(v) for k, v in proxy_kwargs.items()} try: mod._infer_parameters(mod, fake_args, fake_kwargs) # type: ignore[operator] - except AttributeError as e: + except AttributeError: # Re-raise with the original error message from the AttributeError - error_message = VariableTracker.build( - tx, str(e) or "AttributeError during lazy module initialization" - ) raise_observed_exception( AttributeError, tx, - args=[error_message], + args=["AttributeError during lazy module initialization"], ) @@ -204,6 +203,11 @@ def __init__( self.source: Source = self.source self.nn_module_stack_source = self.source + def get_dict_vt(self, tx: "InstructionTranslator") -> "DunderDictVariable": + if not hasattr(self, "dict_vt"): + self.dict_vt = variables.DunderDictVariable.create(tx, self) + return self.dict_vt + def get_nn_module_stack_source(self) -> Source: res = self.nn_module_stack_source or self.source assert res @@ -215,6 +219,22 @@ def set_nn_module_stack_source(self, source: Source) -> None: def python_type(self) -> type: return self.module_type + def get_real_python_backed_value(self) -> object: + return self.value + + def bool_impl(self, tx: "InstructionTranslator") -> VariableTracker: + """nb_bool for nn.Module. + + nn.Module itself has no __bool__ or __len__, so bare modules are always + truthy. Subclasses like ModuleList/ModuleDict define __len__, so + bool(module) calls PyObject_IsTrue which falls through nb_bool (NULL) + to sq_length/mp_length. We evaluate on the real module to capture this. + """ + from .constant import ConstantVariable + + mod = tx.output.get_submodule(self.module_key) + return ConstantVariable.create(bool(mod)) + def _wrap_submodule( self, tx: "InstructionTranslator", @@ -279,18 +299,6 @@ def convert_to_unspecialized(self, tx: "InstructionTranslator") -> None: GenerationTracker.mark_class_dynamic(type(mod)) raise UnspecializeRestartAnalysis - def has_key_in_generic_dict(self, tx: "InstructionTranslator", key: str) -> bool: - base = tx.output.get_submodule(self.module_key) - - if tx.output.side_effects.has_pending_mutation_of_attr(self, key): - mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) - return not isinstance(mutated_attr, variables.DeletedVariable) - - # Use object.__getattribute__ to access __dict__ directly, - # bypassing any custom __getattribute__ on the module. - base_dict = object.__getattribute__(base, "__dict__") - return key in base_dict - def _custom_getattr_fallback( self, base: torch.nn.Module, @@ -378,7 +386,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker ) if name == "__dict__": - return variables.GetAttrVariable(self, name, source=source) + return self.get_dict_vt(tx) subobj = None if name in base_dict: @@ -405,13 +413,10 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker if result is not None: return result # if we can't find a __getattr__, we can't parse this, raise attribute error - error_message = VariableTracker.build( - tx, f"'{type(base).__name__}' object has no attribute '{name}'" - ) raise_observed_exception( AttributeError, tx, - args=[error_message], + args=[f"'{type(base).__name__}' object has no attribute '{name}'"], ) if name == "forward": @@ -472,7 +477,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker ], ) - return variables.GetAttrVariable(self, name, source=source) + return super().var_getattr(tx, name) def call_function( self, @@ -583,6 +588,104 @@ def call_function( kwargs, ) + def mp_subscript_impl( + self, + tx: "InstructionTranslator", + key: "VariableTracker", + ) -> "VariableTracker": + # nn.Module containers (ModuleList/Dict/Sequential/ParameterDict/ParameterList) + # These are Python-level __getitem__, not CPython C slots. + # TODO(follow-up): add tests for ModuleList negative index, ModuleList/Sequential + # slice, ModuleDict missing key, invalid key types + from .lists import SliceVariable + from .tensor import SymNodeVariable + + module = tx.output.get_submodule(self.module_key) + + builtin_supported = ( + torch.nn.ModuleDict.__getitem__, + torch.nn.ModuleList.__getitem__, + torch.nn.ParameterDict.__getitem__, + torch.nn.ParameterList.__getitem__, + torch.nn.Sequential.__getitem__, + ) + # pyrefly: ignore[missing-attribute] + if type(module).__getitem__ not in builtin_supported: + if not ( + key.is_python_constant() + and isinstance(key.as_python_constant(), (str, int)) + ): + unimplemented( + gb_type="Invalid or non-const argument in nn.Module __getitem__", + context=f"mp_subscript_impl: {self} {key}", + explanation="Dynamo does not support calling " + f"method `__getitem__` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.", + hints=["Use constant arguments of type str or int for __getitem__"], + ) + fn = module.__getitem__.__func__ # pyrefly: ignore[missing-attribute] + + assert isinstance(fn, types.FunctionType) + + src = AttrSource(AttrSource(self.source, "__getitem__"), "__func__") # type: ignore[arg-type] + return tx.inline_user_function_return( + variables.UserFunctionVariable(fn, source=src), + [self, key], + {}, + ) + + if isinstance(key, SliceVariable): + # TODO(anijain2305,export-team) - Remove this if condition when inlining of inbuilt nn modules is + # enabled for export. + if tx.output.export: + result = [] + keys = list(range(len(module)))[key.as_python_constant()] # type: ignore[arg-type] + for idx, submod in enumerate(module[key.as_python_constant()]): # type: ignore[arg-type] + k = keys[idx] + src = NNModuleSource(GetItemSource(self.source, k)) + result.append( + tx.output.register_attr_or_module( + submod, + k, + source=src, + ) + ) + + new_module = module[key.as_python_constant()] # type: ignore[index] + new_module_variable = tx.output.register_attr_or_module( + new_module, + f"{self}.__getitem__(slice)", + source=NNModuleSource( + GetItemSource(self.source, key.as_python_constant()) + ), + ) + return new_module_variable + else: + # slice on nn module results in a creation of new module instance, so we need to make it sourceless. + # Convert to unspecialized so that UnspecializedNNModule variable can take care of it. + self.convert_to_unspecialized(tx) + + key_value = 0 + if isinstance(key, SymNodeVariable): + key_value = key.evaluate_expr(tx.output) + elif key.is_python_constant(): + key_value = key.as_python_constant() + else: + unimplemented( + gb_type="Unsupported key type for nn.Module.__getitem__", + context=f"mp_subscript_impl: {self} {key}", + explanation="Dynamo does not support getitem on " + "`nn.Module` with non-constant key.", + hints=[], + ) + + submod = module[key_value] # type: ignore[index] + return tx.output.register_attr_or_module( + submod, + self.module_key, + key_value, + source=NNModuleSource(GetItemSource(self.source, key_value)), + ) + def call_method( self, tx: "InstructionTranslator", @@ -592,7 +695,7 @@ def call_method( constant: bool = False, ) -> VariableTracker: from . import ListIteratorVariable, TupleVariable - from .constant import CONSTANT_VARIABLE_TRUE + from .constant import ConstantVariable key = self.module_key module = tx.output.get_submodule(key) @@ -639,16 +742,16 @@ def generic_call_method_helper(name: str) -> VariableTracker: if name == "_check_input_dim" and trace_rules.is_torch_inline_allowed( inspect.getfile(module.__class__._check_input_dim) # type: ignore[union-attr] ): - return CONSTANT_VARIABLE_TRUE + return ConstantVariable.create(True) if name == "_get_item_by_idx": if not args[1].is_python_constant(): - raise_type_error_exc( + raise_type_error( tx, f"``nn.Module`` {module}'s call method {name} requires a constant index argument", ) if not isinstance(args[0], TupleVariable): - raise_type_error_exc( + raise_type_error( tx, f"``nn.Module`` {module}'s call method {name} requires a tuple as first argument", ) @@ -827,15 +930,6 @@ def gen_source(source: Source, name: str) -> Source: for name, submod in module.items(): # type: ignore[operator] items_result.append(named_embed(name, submod)) return ListIteratorVariable(items_result, mutation_type=ValueMutationNew()) - elif name == "__len__": - if args or kwargs: - raise_args_mismatch( - tx, - name, - "0 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - return VariableTracker.build(tx, len(module)) # type: ignore[arg-type] elif name == "__iter__": return ListIteratorVariable( self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() @@ -849,104 +943,6 @@ def gen_source(source: Source, name: str) -> Source: return VariableTracker.build( tx, args[0].as_python_constant() in module._modules ) - elif name == "__getitem__": - if kwargs or len(args) != 1: - raise_args_mismatch( - tx, - name, - "1 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - builtin_supported = ( - torch.nn.ModuleDict.__getitem__, - torch.nn.ModuleList.__getitem__, - torch.nn.ParameterDict.__getitem__, - torch.nn.ParameterList.__getitem__, - torch.nn.Sequential.__getitem__, - ) - # pyrefly: ignore[missing-attribute] - if type(module).__getitem__ not in builtin_supported: - if not ( - args[0].is_python_constant() - and isinstance(args[0].as_python_constant(), (str, int)) - ): - unimplemented( - gb_type="Invalid or non-const argument in nn.Module __getitem__", - context=f"call_method: {self} {name} {args} {kwargs}", - explanation="Dynamo does not support calling " - f"method `{name}` of ``nn.Module`` {module} with a non-constant or non-(str, int) key.", - hints=[ - "Use constant arguments of type str or int for __getitem__" - ], - ) - fn = getattr(module, name).__func__ - - assert isinstance(fn, types.FunctionType) - - src = AttrSource(AttrSource(self.source, name), "__func__") # type: ignore[arg-type] - return tx.inline_user_function_return( - variables.UserFunctionVariable(fn, source=src), - [self] + list(args), - kwargs, - ) - - if isinstance(args[0], SliceVariable): - # TODO(anijain2305,export-team) - Remove this if condition when inlining of inbuilt nn modules is - # enabled for export. - if tx.output.export: - # Build a TupleVariable of NNModules - result = [] - - # Turn the slice into the list of integers - keys = list(range(len(module)))[args[0].as_python_constant()] # type: ignore[arg-type] - for idx, submod in enumerate(module[args[0].as_python_constant()]): # type: ignore[arg-type] - key = keys[idx] - src = NNModuleSource(GetItemSource(self.source, key)) - result.append( - tx.output.register_attr_or_module( - submod, - key, - source=src, - ) - ) - - new_module = module[args[0].as_python_constant()] # type: ignore[index] - new_module_variable = tx.output.register_attr_or_module( - new_module, - f"{self}.__getitem__(slice)", - source=NNModuleSource( - GetItemSource(self.source, args[0].as_python_constant()) - ), - ) - return new_module_variable - else: - # slice on nn module results in a creation of new module instance, so we need to make it sourceless. - # Convert to unspecialized so that UnspecializedNNModule variable can take care of it. - self.convert_to_unspecialized(tx) - - from .tensor import SymNodeVariable - - key_value = 0 - if isinstance(args[0], SymNodeVariable): - key_value = args[0].evaluate_expr(tx.output) - elif args[0].is_python_constant(): - key_value = args[0].as_python_constant() - else: - unimplemented( - gb_type="Unsupported key type for nn.Module.__getitem__", - context=f"call_method: {self} {name} {args} {kwargs}", - explanation="Dynamo does not support getitem on " - "`nn.Module` with non-constant key.", - hints=[], - ) - - submod = module[key_value] # type: ignore[index] - return tx.output.register_attr_or_module( - submod, - self.module_key, - key_value, - source=NNModuleSource(GetItemSource(self.source, key_value)), - ) elif ( name == "_get_abs_string_index" or ( @@ -977,6 +973,11 @@ def gen_source(source: Source, name: str) -> Source: else: return super().call_method(tx, name, list(args), kwargs) + def sq_length(self, tx: "InstructionTranslator") -> "VariableTracker": + """Sequence length for container modules (e.g., nn.Sequential).""" + module = tx.output.get_submodule(self.module_key) + return VariableTracker.build(tx, len(module)) # type: ignore[arg-type] + class UnspecializedNNModuleVariable(UserDefinedObjectVariable): _nonvar_fields = { @@ -1092,31 +1093,40 @@ def call_function( and istype(mod._call_impl, types.MethodType) # type: ignore[attr-defined] and mod.__call__.__func__ is unpatched_nn_module_call # type: ignore[operator] and mod._call_impl.__func__ is unpatched_nn_module_call_impl # type: ignore[attr-defined] - and "forward" not in mod.__dict__ + # Consult pending STORE_ATTR side effects too. During tracing the + # patched forward may not be visible in mod.__dict__ yet. + and not self.has_key_in_generic_dict(tx, "forward") ): forward_method = inspect.getattr_static(mod, "forward") if isinstance(forward_method, types.FunctionType): globals_vt = tx.nn_modules_globals_vt - if not ( - self.var_getattr(tx, "_backward_hooks").realize().len() # type: ignore[attr-defined] - or self.var_getattr(tx, "_backward_pre_hooks").realize().len() # type: ignore[attr-defined] - or self.var_getattr(tx, "_forward_hooks").realize().len() # type: ignore[attr-defined] - or self.var_getattr(tx, "_forward_hooks_with_kwargs") # type: ignore[attr-defined] - .realize() - .len() - or self.var_getattr(tx, "_forward_pre_hooks").realize().len() # type: ignore[attr-defined] - or self.var_getattr(tx, "_forward_pre_hooks_with_kwargs") # type: ignore[attr-defined] - .realize() - .len() - or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len() # type: ignore[attr-defined] - or globals_vt.var_getattr(tx, "_global_backward_hooks").len() # type: ignore[attr-defined] - or globals_vt.var_getattr(tx, "_global_forward_hooks").len() # type: ignore[attr-defined] - or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len() # type: ignore[attr-defined] - or globals_vt.var_getattr(tx, "_global_backward_pre_hooks").len() # type: ignore[attr-defined] - or globals_vt.var_getattr(tx, "_global_backward_hooks").len() # type: ignore[attr-defined] - or globals_vt.var_getattr(tx, "_global_forward_hooks").len() # type: ignore[attr-defined] - or globals_vt.var_getattr(tx, "_global_forward_pre_hooks").len() # type: ignore[attr-defined] - ): + + def _hooks_dict_len(obj: VariableTracker, attr: str) -> int: + vt = obj.var_getattr(tx, attr) + vt = vt.realize() if hasattr(vt, "realize") else vt + return vt.len() # type: ignore[union-attr] + + has_hooks = any( + _hooks_dict_len(self, attr) + for attr in ( + "_backward_hooks", + "_backward_pre_hooks", + "_forward_hooks", + "_forward_hooks_with_kwargs", + "_forward_pre_hooks", + "_forward_pre_hooks_with_kwargs", + ) + ) or any( + _hooks_dict_len(globals_vt, attr) + for attr in ( + "_global_backward_pre_hooks", + "_global_backward_hooks", + "_global_forward_hooks", + "_global_forward_pre_hooks", + ) + ) + + if not has_hooks: name = "forward" fn = self.value_type.forward # type: ignore[attr-defined] @@ -1167,7 +1177,7 @@ def call_method( fn_vt = VariableTracker.build(tx, fn, source=source, realize=True) return fn_vt.call_function(tx, [self] + list(args), kwargs) - if name not in getattr(self.value, "__dict__", {}): + if not self.has_key_in_generic_dict(tx, name): try: method = inspect.getattr_static(type(self.value), name) except AttributeError: @@ -1248,6 +1258,8 @@ def getattr_helper( self, tx: "InstructionTranslator", field: str, name_vt: VariableTracker ) -> VariableTracker | None: dict_vt = self.var_getattr(tx, field) + if isinstance(dict_vt, variables.UserDefinedDictVariable): + dict_vt = dict_vt._base_vt if isinstance(dict_vt, variables.ConstDictVariable): return dict_vt.maybe_getitem_const(name_vt) return None @@ -1311,7 +1323,8 @@ def build_key_value( return key, value result = dict( - build_key_value(i, k, v) for i, (k, v) in enumerate(hooks_dict.items()) + build_key_value(i, k, v) + for i, k, v in enumerate_items_with_dict_position(hooks_dict) ) return variables.NNModuleHooksDictVariable( @@ -1337,14 +1350,12 @@ def manually_trace_nn_module_getattr( if out is None: out = self.getattr_helper(tx, "_buffers", name_vt) if out is None: - error_message = VariableTracker.build( - tx, - f"'{type(self.value).__name__}' object has no attribute '{name}'", - ) raise_observed_exception( AttributeError, tx, - args=[error_message], + args=[ + f"'{type(self.value).__name__}' object has no attribute '{name}'" + ], ) assert out is not None return out diff --git a/torch/_dynamo/variables/object_protocol.py b/torch/_dynamo/variables/object_protocol.py new file mode 100644 index 0000000000000..217a7a796f582 --- /dev/null +++ b/torch/_dynamo/variables/object_protocol.py @@ -0,0 +1,319 @@ +""" +Dynamo implementations of CPython's PyObject_* default slot algorithms. + +Analogous to CPython's Objects/object.c, this module holds the general +dispatch machinery that is independent of any specific type. +Per-type hook implementations (bool_impl, richcompare_impl, etc.) +live in their respective VT files. +""" + +from functools import lru_cache +from typing import TYPE_CHECKING + +from torch._C._dynamo import ( + get_type_slots, + has_slot, + PyMappingSlots, + PyNumberSlots, + PySequenceSlots, +) + +from .. import graph_break_hints +from ..exc import ( + handle_observed_exception, + ObservedTypeError, + raise_observed_exception, + raise_type_error, + unimplemented, +) +from ..utils import istype +from .base import NO_SUCH_SUBOBJ, VariableTracker +from .constant import ConstantVariable + + +if TYPE_CHECKING: + from ..symbolic_convert import InstructionTranslator + + +def vt_identity_compare( + left: VariableTracker, + right: VariableTracker, +) -> "VariableTracker | None": + """Try to determine Python identity (left is right) at trace time. + + Returns ConstantVariable(True/False) if determinable, else None. + Mirrors the logic in BuiltinVariable's handle_is handler. + """ + if left is right: + return ConstantVariable.create(True) + + left_val = left.get_real_python_backed_value() + right_val = right.get_real_python_backed_value() + left_known = left_val is not NO_SUCH_SUBOBJ + right_known = right_val is not NO_SUCH_SUBOBJ + + if left_known and right_known: + return ( + ConstantVariable.create(True) + if left_val is right_val + else ConstantVariable.create(False) + ) + + # One side has a concrete backing object, the other doesn't — they can't + # be the same object. + if left_known != right_known: + return ConstantVariable.create(False) + + # Mutable containers created during tracing: VT identity = Python identity. + from .dicts import ConstDictVariable + from .lists import ListVariable + from .sets import SetVariable + + if isinstance(left, (ConstDictVariable, ListVariable, SetVariable)): + return ConstantVariable.create(False) + + # Different Python types can never be the same object. + try: + if left.python_type() is not right.python_type(): + return ConstantVariable.create(False) + except NotImplementedError: + pass + + # Different exception types are never identical. + from .. import variables + + if ( + istype(left, variables.ExceptionVariable) + and istype(right, variables.ExceptionVariable) + and left.exc_type is not right.exc_type # type: ignore[attr-defined] + ): + return ConstantVariable.create(False) + + return None + + +@lru_cache(maxsize=256) +def _get_cached_slots(obj_type: type) -> tuple[int, int, int, int]: + """Get all type slots for a type (cached).""" + return get_type_slots(obj_type) + + +def type_implements_sq_length(obj_type: type) -> bool: + """Check whether obj_type implements __len__ as sequence protocol""" + seq_slots, _, _, _ = _get_cached_slots(obj_type) + return has_slot(seq_slots, PySequenceSlots.SQ_LENGTH) + + +def type_implements_mp_length(obj_type: type) -> bool: + """Check whether obj_type implements __len__ as mapping protocol""" + _, map_slots, _, _ = _get_cached_slots(obj_type) + return has_slot(map_slots, PyMappingSlots.MP_LENGTH) + + +def type_implements_nb_bool(obj_type: type) -> bool: + """Check whether obj_type implements the nb_bool slot (i.e. has __bool__ or __len__).""" + _, _, number_slots, _ = _get_cached_slots(obj_type) + return has_slot(number_slots, PyNumberSlots.NB_BOOL) + + +def type_implements_nb_int(obj_type: type) -> bool: + """Check whether obj_type implements the nb_int slot.""" + _, _, number_slots, _ = _get_cached_slots(obj_type) + return has_slot(number_slots, PyNumberSlots.NB_INT) + + +def type_implements_nb_index(obj_type: type) -> bool: + """Check whether obj_type implements the nb_index slot.""" + _, _, number_slots, _ = _get_cached_slots(obj_type) + return has_slot(number_slots, PyNumberSlots.NB_INDEX) + + +def type_implements_nb_float(obj_type: type) -> bool: + """Check whether obj_type implements the nb_float slot.""" + _, _, number_slots, _ = _get_cached_slots(obj_type) + return has_slot(number_slots, PyNumberSlots.NB_FLOAT) + + +def maybe_get_python_type(obj: VariableTracker) -> type: + try: + return obj.python_type() + except NotImplementedError: + unimplemented( + gb_type="Unsupported python_type() call", + context=f"{obj} does not implement python_type()", + explanation="This VariableTracker does not implement python_type(), " + "which is required for object protocol operations.", + hints=[ + *graph_break_hints.DYNAMO_BUG, + ], + ) + + +def vt_mapping_size( + tx: "InstructionTranslator", obj: "VariableTracker" +) -> "VariableTracker": + # ref: https://github.com/python/cpython/blob/v3.13.3/Objects/abstract.c#L2308-L2330 + T = maybe_get_python_type(obj) + if type_implements_mp_length(T): + return obj.mp_length(tx) + + if type_implements_sq_length(T): + raise_type_error(tx, f"{obj.python_type_name()} is not a mapping") + + raise_type_error(tx, f"object of type {obj.python_type_name()} has no len()") + + +def generic_len( + tx: "InstructionTranslator", obj: "VariableTracker" +) -> "VariableTracker": + # ref: https://github.com/python/cpython/blob/v3.13.3/Objects/abstract.c#L53-L69 + """ + Implements PyObject_Size/PyObject_Length semantics for VariableTracker objects. + Dispatches to sq_length (sequences) or mp_length (mappings) depending on the VT type. + """ + + T = maybe_get_python_type(obj) + if type_implements_sq_length(T): + return obj.sq_length(tx) + return vt_mapping_size(tx, obj) + + +def generic_bool(tx: "InstructionTranslator", obj: VariableTracker) -> VariableTracker: + """Mirrors PyObject_IsTrue. + + https://github.com/python/cpython/blob/c09ccd9c429/Objects/object.c#L2135-L2158 + + Resolution order: constants → nb_bool → mp_length/sq_length → truthy. + """ + from .constant import ConstantVariable + + if obj.is_python_constant(): + return ConstantVariable.create(bool(obj.as_python_constant())) + + obj_type = maybe_get_python_type(obj) + + if type_implements_nb_bool(obj_type): + result = obj.bool_impl(tx) + if result is not None: + return result + + try: + length = generic_len(tx, obj) + from .tensor import SymNodeVariable + + if isinstance(length, SymNodeVariable): + return SymNodeVariable.create(tx, length.as_proxy() > 0) + return ConstantVariable.create(length.as_python_constant() > 0) + except ObservedTypeError: + handle_observed_exception(tx) + + return ConstantVariable.create(True) + + +def vt_getitem( + tx: "InstructionTranslator", + obj: VariableTracker, + key: VariableTracker, +) -> VariableTracker: + """CPython's PyObject_GetItem — dispatch to the type's mp_subscript/sq_item. + + PyObject_GetItem: https://github.com/python/cpython/blob/62a6e898e01/Objects/abstract.c#L155-L206 + + CPython checks three branches in order: + 1. tp_as_mapping->mp_subscript (L161-166) + 2. tp_as_sequence->sq_item (L168-181) — only if key passes _PyIndex_Check + 3. PyType_Check(o) (L183-203) — type[int] → GenericAlias/__class_getitem__ + + Branch 1 is the common path (list, tuple, dict, range all have mp_subscript). + TODO(follow-up): use has_slot(map_slots, PyMappingSlots.MP_SUBSCRIPT) to gate + Branch 1 and has_slot(seq_slots, PySequenceSlots.SQ_ITEM) to gate Branch 2, + matching CPython's dispatch order. + TODO(follow-up): Branch 2 (sq_item) for C extension types that only have + tp_as_sequence (e.g. deque — Modules/_collectionsmodule.c:1888). + Branch 3 is handled by TypingVariable.mp_subscript_impl for typing module types + and by BuiltinVariable for builtin types like list[int]. + + Types that work via constant fold fallback (no dedicated mp_subscript_impl): + TODO(follow-up): str (unicode_subscript, Objects/unicodeobject.c:13809) + TODO(follow-up): bytes (bytes_subscript, Objects/bytesobject.c) + """ + return obj.mp_subscript_impl(tx, key) + + +def generic_int(tx: "InstructionTranslator", obj: VariableTracker) -> VariableTracker: + """Mirrors PyNumber_Long (int(x) dispatch). + + https://github.com/python/cpython/blob/v3.13.0/Objects/abstract.c#L1520-L1632 + + Resolution: nb_int → nb_index → str/bytes/bytearray parsing → TypeError. + """ + from .constant import ConstantVariable + + # Fast path for int (sub)class instances — mirrors PyLong_Check at the + # top of PyNumber_Long (abstract.c:1531). Avoids infinite recursion for + # int subclasses like IntEnum whose __int__ calls int() again. + if obj.is_python_constant() and isinstance(obj.as_python_constant(), int): + return ConstantVariable.create(int(obj.as_python_constant())) + + obj_type = maybe_get_python_type(obj) + + if type_implements_nb_int(obj_type): + return obj.nb_int_impl(tx) + + if type_implements_nb_index(obj_type): + return obj.nb_index_impl(tx) + + # String/bytes/bytearray parsing fallback. + # https://github.com/python/cpython/blob/v3.13.0/Objects/abstract.c#L1598-L1612 + if obj.is_python_constant() and isinstance( + obj.as_python_constant(), (str, bytes, bytearray) + ): + try: + return ConstantVariable.create(int(obj.as_python_constant())) + except ValueError as e: + raise_observed_exception(ValueError, tx, args=[str(e)]) + + raise_type_error( + tx, + f"int() argument must be a string, a bytes-like object " + f"or a real number, not '{obj.python_type_name()}'", + ) + + +def generic_float(tx: "InstructionTranslator", obj: VariableTracker) -> VariableTracker: + """Mirrors PyNumber_Float (float(x) dispatch). + + https://github.com/python/cpython/blob/v3.13.0/Objects/abstract.c#L1635-L1692 + + Resolution: nb_float → nb_index → str parsing → TypeError. + """ + from .constant import ConstantVariable + + # Fast path: if the value is already a float constant, return it directly. + # Mirrors PyFloat_CheckExact fast path at the top of PyNumber_Float + # (abstract.c:1641-1643). + if obj.is_python_constant() and isinstance(obj.as_python_constant(), float): + return ConstantVariable.create(float(obj.as_python_constant())) + + obj_type = maybe_get_python_type(obj) + + if type_implements_nb_float(obj_type): + return obj.nb_float_impl(tx) + + # https://github.com/python/cpython/blob/v3.13.0/Objects/abstract.c#L1674-L1685 + if type_implements_nb_index(obj_type): + return obj.nb_index_impl(tx) + + # PyFloat_FromString fallback — handles str and bytes. + # https://github.com/python/cpython/blob/v3.13.0/Objects/abstract.c#L1691 + if obj.is_python_constant() and isinstance(obj.as_python_constant(), (str, bytes)): + try: + return ConstantVariable.create(float(obj.as_python_constant())) + except ValueError as e: + raise_observed_exception(ValueError, tx, args=[str(e)]) + + raise_type_error( + tx, + f"float() argument must be a string or a real number, " + f"not '{obj.python_type_name()}'", + ) diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index e376dcd53442c..137f6f158f580 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -42,8 +42,9 @@ ) from ..utils import GLOBAL_KEY_PREFIX from .base import VariableTracker -from .constant import CONSTANT_VARIABLE_TRUE, ConstantVariable +from .constant import ConstantVariable from .dicts import ConstDictVariable +from .hashable import HashableTracker from .lists import ListVariable from .misc import GetAttrVariable from .user_defined import UserDefinedObjectVariable @@ -148,7 +149,12 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker # which will directly inline if name in ("_init_group"): assert self.source - return GetAttrVariable(self, name, source=AttrSource(self.source, name)) + return GetAttrVariable( + self, + name, + py_type=type(getattr(self.value, name)), + source=AttrSource(self.source, name), + ) if name == "param_groups": from ..decorators import mark_static_address @@ -207,10 +213,8 @@ def safe_to_set_capturable(group: dict[str, Any]) -> bool: VariableTracker.build(tx, self.value.param_groups, source) ) for param_group_vt in param_groups_vt.items: - key = ConstDictVariable._HashableTracker( - ConstantVariable.create("capturable") - ) - param_group_vt.items[key] = CONSTANT_VARIABLE_TRUE + key = HashableTracker(ConstantVariable.create("capturable")) + param_group_vt.items[key] = ConstantVariable.create(True) def get_python_args( self, *args: Any, **kwargs: Any diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index 41cd0715b9971..5cccd43ab0fbd 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -18,6 +18,7 @@ by limiting operations to known-safe patterns and failing fast for unsafe usage. """ +import enum import functools import inspect import types @@ -89,6 +90,10 @@ class OpaqueObjectClassVariable(UserDefinedVariable): """ def __init__(self, value: Any, **kwargs: Any) -> None: + assert not (isinstance(value, type) and issubclass(value, enum.Enum)), ( + f"Enum class {value} should use UserDefinedClassVariable, " + "not OpaqueObjectClassVariable" + ) super().__init__(**kwargs) self.value = value @@ -104,6 +109,9 @@ def is_python_constant(self) -> bool: def is_python_hashable(self) -> bool: return is_opaque_value_type(self.value) # pyrefly: ignore[bad-argument-type] + def get_python_hash(self) -> int: + return hash(self.value) + def as_proxy(self) -> Any: return self.value @@ -129,12 +137,11 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker elif isinstance(obj, property): obj = obj.__get__(None, self.value) # pyrefly: ignore[no-matching-overload] elif hasattr(obj, "__get__"): - # Check for pybind11 static properties (common in PyTorch C++ bindings) - # Reference: https://github.com/python/mypy/blob/131f9d92da58294bb2f273425e8778bd7d5b861f/mypy/stubgenc.py#L590 - type_name = type(obj).__name__ - if type_name == "pybind11_static_property": - obj = obj.__get__(None, self.value) + if not isinstance(type(obj).__dict__.get("__get__"), types.FunctionType): + # C-level descriptors are safe to resolve dynamically. + obj = getattr(self.value, name) else: + type_name = type(obj).__name__ unimplemented( gb_type="Unsupported descriptor on opaque class", context=f"class={self.value}, attr={name}, descriptor={type_name}", @@ -160,8 +167,11 @@ def call_function( # disallow creating reference-type opaque objects in the middle of the # program if is_opaque_reference_type(self.value): - # Skip __init__ to prevent dynamo from tracing it during resume - skip_code(self.value.__init__.__code__) + # Skip __init__ to prevent dynamo from tracing it during resume. + # C extension types (e.g. torch._C.Generator) have wrapper_descriptor + # __init__ without __code__, so guard the skip_code call. + if hasattr(self.value.__init__, "__code__"): + skip_code(self.value.__init__.__code__) unimplemented( gb_type="An opaque object was created in the middle of the program.", @@ -190,14 +200,18 @@ def call_function( opaque_obj = self.value( # pyrefly: ignore[not-callable] *constant_args, **constant_kwargs ) - fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj( - tx.output.fake_mode, opaque_obj - ) # Capture sources from the VT args so subgraph reuse can apply # source replacement to resolve new ctor arg values on stamp-out. ctor_arg_sources = tuple(getattr(a, "source", None) for a in args) + if is_opaque_value_type(type(opaque_obj)): + fake_script_obj = opaque_obj + else: + fake_script_obj = torch._library.fake_class_registry.maybe_to_fake_obj( + tx.output.fake_mode, opaque_obj + ) + return TorchScriptObjectVariable.create( opaque_obj, fake_script_obj, @@ -225,9 +239,18 @@ def create( ctor_arg_sources: tuple[Source | None, ...] | None = None, **options: Any, ) -> "TorchScriptObjectVariable": - return TorchScriptObjectVariable( + assert not isinstance(value, enum.Enum), ( + f"Enum {type(value)} should use UserDefinedObjectVariable, not TorchScriptObjectVariable" + ) + out = TorchScriptObjectVariable( proxy, value, ctor_args_kwargs, ctor_arg_sources=ctor_arg_sources, **options ) + if isinstance(proxy, torch.fx.Proxy) and proxy.node.op != "placeholder": + from torch._dynamo.symbolic_convert import InstructionTranslator + + tx = InstructionTranslator.current_tx() + tx.output.current_tracer.record_proxyable_vt(out) + return out def __init__( self, @@ -293,31 +316,10 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker from .higher_order_ops import TorchHigherOrderOperatorVariable - if hasattr(self.value, "script_class_name") and is_opaque_type( - self.value.script_class_name - ): - real_obj = self.value.real_obj # pyrefly: ignore[missing-attribute] - - member_type = get_member_type( - type(real_obj), - name, - ) - if member_type is None: - # Special case: __bool__ and __len__ are used for truthiness checks. - # If they're not registered and the real object doesn't have them, - # raise ObservedAttributeError so the caller can fall back to - # treating the object as truthy (Python default behavior - if name in ("__bool__", "__len__") and not hasattr(real_obj, name): - raise_observed_exception(AttributeError, tx) - - unimplemented( - gb_type="Attempted to access unregistered member on an OpaqueObject", - context=f"value={real_obj}, attr={name}", - explanation=f"Member '{name}' is not registered for this opaque object type.", - hints=[ - f"Register '{name}' with a MemberType in register_opaque_type(members=...).", - ], - ) + real_obj = self.as_python_constant() + real_obj_type = type(real_obj) + if is_opaque_type(real_obj_type): + member_type = get_member_type(real_obj_type, name) if member_type == MemberType.USE_REAL: value = getattr(real_obj, name) @@ -343,6 +345,26 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker ) return super().var_getattr(tx, name) + elif is_opaque_value_type(real_obj_type): + return super().var_getattr(tx, name) + + elif name in ("__bool__", "__len__") and not hasattr(real_obj, name): + # Special case: __bool__ and __len__ are used for truthiness checks. + # If they're not registered and the real object doesn't have them, + # raise ObservedAttributeError so the caller can fall back to + # treating the object as truthy (Python default behavior + raise_observed_exception(AttributeError, tx) + + else: + unimplemented( + gb_type="Attempted to access unregistered member on an OpaqueObject", + context=f"value={real_obj}, attr={name}", + explanation=f"Member '{name}' is not registered for this opaque object type.", + hints=[ + f"Register '{name}' with a MemberType in register_opaque_type(members=...).", + ], + ) + method = getattr(self.value, name, None) if method is None: unimplemented( @@ -373,9 +395,18 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker method_name=name, ) + def mp_subscript_impl( + self, + tx: "InstructionTranslator", + key: "VariableTracker", + ) -> "VariableTracker": + # Call call_method directly on this class to avoid the __getitem__ → + # mp_subscript_impl loop in VariableTracker.call_method. + return TorchScriptObjectVariable.call_method(self, tx, "__getitem__", [key], {}) + # We only support method calls on script objects. Interpreting the bytecodes # should go through var_getattr then call_function instead of call_method. - # + # However, it's possible for call_method to be used directly e.g. for __setattr__. @_raise_hard_error_if_graph_break( "Dynamo cannot safely trace script object due to graph break." @@ -389,61 +420,31 @@ def call_method( ) -> VariableTracker: from .builder import wrap_fx_proxy - if hasattr(self.value, "script_class_name") and is_opaque_type( - self.value.script_class_name - ): - real_obj = self.value.real_obj # pyrefly: ignore[missing-attribute] - value_type = type(real_obj) + real_obj = self.as_python_constant() + real_obj_type = type(real_obj) + if is_opaque_type(real_obj_type): + member_type = get_member_type(real_obj_type, name) - member_type = get_member_type( - value_type, - name, - ) - if member_type is None: - unimplemented( - gb_type="Attempted to access unregistered member on an OpaqueObject", - context=f"value={real_obj}, attr={name}", - explanation=f"Member '{name}' is not registered for this opaque object type.", - hints=[ - f"Register '{name}' with a MemberType in register_opaque_type(members=...).", - ], - ) - - if member_type == MemberType.INLINED: - proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) - - proxy = tx.output.create_proxy( - "call_method", - name, - args=(self.proxy, *proxy_args), - kwargs=proxy_kwargs, - ) - - return wrap_fx_proxy(tx=tx, proxy=proxy) - - elif member_type == MemberType.USE_REAL: - if inspect.getattr_static(value_type, "__getattr__", None) is not None: + if member_type == MemberType.USE_REAL: + if ( + inspect.getattr_static(real_obj_type, "__getattr__", None) + is not None + ): unimplemented( gb_type="Opaque object with custom __getattr__ not supported", - context=f"{value_type.__name__} with custom __getattr__", + context=f"{real_obj_type.__name__} with custom __getattr__", explanation="Dynamo does not support opaque objects types with custom __getattr__ methods", hints=[], ) - def get_real_value(x: VariableTracker) -> Any: - # For TorchScriptObjectVariable, get the real object directly - if isinstance(x, TorchScriptObjectVariable): - return x.get_real_value() - return x.as_python_constant() - - args_const = [get_real_value(x) for x in args] - kwargs_const = {k: get_real_value(v) for k, v in kwargs.items()} + args_const = [x.as_python_constant() for x in args] + kwargs_const = {k: v.as_python_constant() for k, v in kwargs.items()} method = getattr(real_obj, name) if name == "__setattr__": method(*args_const, **kwargs_const) - return real_obj + return real_obj # pyrefly: ignore[bad-return] constant_val = method(*args_const, **kwargs_const) @@ -453,7 +454,7 @@ def get_real_value(x: VariableTracker) -> Any: ): unimplemented( gb_type="Opaque object member with method-type USE_REAL returned a reference-type opaque object.", - context=f"Opaque object type: {value_type}. Method name: '{name}'", + context=f"Opaque object type: {real_obj_type}. Method name: '{name}'", explanation=( "To properly guard reference-type opaque objects, " "we must lift them as inputs to the graph. In order " @@ -462,18 +463,34 @@ def get_real_value(x: VariableTracker) -> Any: ), hints=[ f"Register member '{name}' with MemberType.INLINED in " - "register_opaque_type({value_type}, members=...).", + f"register_opaque_type({real_obj_type}, members=...).", ], ) return VariableTracker.build(tx, constant_val) + elif member_type == MemberType.INLINED or is_opaque_value_type( + real_obj_type + ): + proxy_args, proxy_kwargs = proxy_args_kwargs(args, kwargs) + + proxy = tx.output.create_proxy( + "call_method", + name, + args=(self.proxy, *proxy_args), + kwargs=proxy_kwargs, + ) + + return wrap_fx_proxy(tx=tx, proxy=proxy) + else: unimplemented( - gb_type="Unsupported member type on OpaqueObject", - context=f"value={real_obj}, attr={name}, member_type={member_type}", - explanation=f"Member type '{member_type}' is not supported for this operation.", - hints=[], + gb_type="Attempted to access unregistered member on an OpaqueObject", + context=f"value={real_obj}, attr={name}", + explanation=f"Member '{name}' is not registered for this opaque object type.", + hints=[ + f"Register '{name}' with a MemberType in register_opaque_type(members=...).", + ], ) unimplemented( @@ -489,43 +506,30 @@ def get_real_value(x: VariableTracker) -> Any: ) def as_python_constant(self) -> Any: - if is_opaque_value_type( - type(self.value.real_obj) # pyrefly: ignore[missing-attribute] - ): - return self.value.real_obj # pyrefly: ignore[missing-attribute] + if isinstance(self.value, FakeScriptObject): + return self.value.real_obj + elif is_opaque_value_type(type(self.value)): + return self.value + elif isinstance(self.value, torch.ScriptObject): + return self.value return super().as_python_constant() def is_python_hashable(self) -> bool: try: - hash(self.value.real_obj) # pyrefly: ignore[missing-attribute] + self.get_python_hash() return True except TypeError: return False def get_python_hash(self) -> int: - real_obj = ( - self.value.real_obj - if isinstance(self.value, FakeScriptObject) - else self.value - ) + real_obj = self.as_python_constant() return hash(real_obj) def is_python_equal(self, other: object) -> bool: - if not isinstance(other, TorchScriptObjectVariable): - return False - real_self = ( - self.value.real_obj - if isinstance(self.value, FakeScriptObject) - else self.value - ) - real_other = ( - other.value.real_obj - if isinstance(other.value, FakeScriptObject) - else other.value - ) + assert isinstance(other, VariableTracker) + real_self = self.as_python_constant() + real_other = other.as_python_constant() return real_self == real_other def get_real_value(self) -> Any: - if isinstance(self.value, FakeScriptObject): - return self.value.real_obj return self.as_python_constant() diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 1a7006f5d56ab..48de0f9ab4e3c 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -50,6 +50,9 @@ def __init__( self.param_vars = param_vars super().__init__(**kwargs) + def python_type(self) -> type: + return SDPAParams + def reconstruct(self, codegen: "PyCodegen") -> None: assert self.source is None assert self.param_vars is not None diff --git a/torch/_dynamo/variables/sets.py b/torch/_dynamo/variables/sets.py new file mode 100644 index 0000000000000..acbcd62333e32 --- /dev/null +++ b/torch/_dynamo/variables/sets.py @@ -0,0 +1,797 @@ +""" +Set-related variable tracking classes for PyTorch Dynamo. + +This module implements variable tracking for different types of set-like objects: +- Regular Python sets (set) +- Frozen sets (frozenset) +- Ordered sets (torch.utils._ordered_set.OrderedSet) +- Dictionary key sets (dict_keys views used as sets) + +These classes are responsible for tracking set operations during graph compilation, +maintaining proper guards for set mutations and element existence checks. + +The implementation uses a special HashableTracker wrapper to handle set elements +while preserving proper aliasing semantics. Sets are modeled internally as +dictionaries with None values. +""" + +import functools +import operator +from collections.abc import Iterable, Sequence +from typing import Any, Literal, TYPE_CHECKING + +from torch.utils._ordered_set import OrderedSet + +from .. import polyfills, variables +from ..bytecode_transformation import create_call_function, create_instruction +from ..exc import raise_observed_exception +from ..guards import GuardBuilder, install_guard +from ..source import AttrSource, is_constant_source, is_from_local_source +from ..utils import cmp_name_to_op_mapping, istype, raise_args_mismatch +from .base import ValueMutationNew, VariableTracker +from .constant import ConstantVariable +from .hashable import HashableTracker, is_hashable, raise_unhashable + + +if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen + from torch._dynamo.symbolic_convert import InstructionTranslator + from torch._dynamo.variables.builtin import BuiltinVariable + + +# [Adding a new supported class within the keys of SetVariable] +# see steps outlined for ConstDictVariable + + +class SetVariable(VariableTracker): + """Represents a Python set during symbolic execution.""" + + # PySet_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/setobject.c#L2436 + _cpython_type = set + + CONTAINS_GUARD = GuardBuilder.SET_CONTAINS + NOT_CONTAINS_GUARD = GuardBuilder.SET_NOT_CONTAINS + + def __init__( + self, + items: Iterable[VariableTracker | HashableTracker], + **kwargs: Any, + ) -> None: + # .clone() passes these arguments in kwargs but they're recreated below + if "original_items" in kwargs: + kwargs.pop("original_items") + if "should_reconstruct_all" in kwargs: + kwargs.pop("should_reconstruct_all") + + super().__init__(**kwargs) + + # Items can be either VariableTrackers or HashableTrackers (from set ops). + # For VariableTrackers, realize them to ensure aliasing guards are installed + # when the same object appears multiple times. + hashable_items = [] + for item in items: + if isinstance(item, HashableTracker): + # Already a HashableTracker from a set operation + hashable_items.append(item) + else: + # VariableTracker - realize to install guards, then wrap + # pyrefly: ignore [bad-argument-type] + hashable_items.append(HashableTracker(item.realize())) + self.items = dict.fromkeys(hashable_items, SetVariable._default_value()) + self.should_reconstruct_all = ( + not is_from_local_source(self.source) if self.source else True + ) + self.original_items = dict.fromkeys( + hashable_items, SetVariable._default_value() + ) + + def debug_repr(self) -> str: + if not self.items: + return "set()" + else: + items: list[str] = [] + for v in self.items: + vt = v.vt if isinstance(v, HashableTracker) else v + val_str = repr(vt.value) if hasattr(vt, "value") else vt.debug_repr() + items.append(val_str) + return "{" + ",".join(items) + "}" + + @property + def set_items(self) -> set["HashableTracker"]: + return set(self.items.keys()) + + @staticmethod + def _default_value() -> VariableTracker: + # Variable to fill in the keys of the dictionary + return ConstantVariable.create(None) + + def as_proxy(self) -> Any: + return {k.vt.as_proxy() for k in self.set_items} + + def python_type(self) -> type: + return set + + def as_python_constant(self) -> Any: + return {k.vt.as_python_constant() for k in self.set_items} + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.foreach([x.vt for x in self.set_items]) + codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) + + def __contains__(self, vt: VariableTracker) -> bool: + assert isinstance(vt, VariableTracker) + if not is_hashable(vt): + return False + key = HashableTracker(vt) + return key in self.items and not isinstance( + self.items[key], variables.DeletedVariable + ) + + def len(self) -> int: + return sum( + not isinstance(x, variables.DeletedVariable) for x in self.items.values() + ) + + def has_new_items(self) -> bool: + return self.should_reconstruct_all or any( + self.is_new_item(self.original_items.get(key.vt), value) + for key, value in self.items.items() + ) + + def is_new_item( + self, value: VariableTracker | None, other: VariableTracker + ) -> bool: + if value and value.is_realized() and other.is_realized(): + return id(value.realize()) != id(other.realize()) + return id(value) != id(other) + + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + return [x.vt for x in self.items] + + def clone(self, **kwargs: Any) -> VariableTracker: + return super().clone(**kwargs) + + def is_python_hashable(self) -> bool: + return False + + def var_getattr(self, tx: "InstructionTranslator", name: str): + if name == "__class__": + return VariableTracker.build(tx, self.python_type()) + return super().var_getattr(tx, name) + + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> ConstantVariable: + return VariableTracker.build(tx, hasattr(set, name)) + + def install_set_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: + if not self.source: + return + + if tx.output.side_effects.is_modified(self): + return + + contains = args[0] in self + if args[0].source is None and args[0].is_python_constant(): + guard_fn = ( + type(self).CONTAINS_GUARD if contains else type(self).NOT_CONTAINS_GUARD + ) + install_guard( + self.make_guard( + functools.partial( + guard_fn, + key=args[0].as_python_constant(), + ) + ) + ) + + def _fast_set_method( + self, + tx: "InstructionTranslator", + fn: Any, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + try: + res = fn( + *[x.as_python_constant() for x in [self, *args]], + **{k: v.as_python_constant() for k, v in kwargs.items()}, + ) + except Exception as exc: + raise_observed_exception(type(exc), tx, args=list(exc.args)) + return VariableTracker.build(tx, res) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from ..utils import check_constant_args + from .builder import SourcelessBuilder + + if ( + name + in ( + "isdisjoint", + "union", + "intersection", + "difference", + "symmetric_difference", + ) + and check_constant_args(args, kwargs) + and self.python_type() is set + ): + py_type = self.python_type() + return self._fast_set_method(tx, getattr(py_type, name), args, kwargs) + + # Lazy imports to avoid circular dependencies + from .dicts import DictItemsVariable, DictKeysVariable + + if name == "__init__": + temp_set_vt = SourcelessBuilder.create(tx, set).call_set( + tx, *args, **kwargs + ) + tx.output.side_effects.mutation(self) + self.items.clear() + self.items.update(temp_set_vt.items) # type: ignore[attr-defined] + return ConstantVariable.create(None) + elif name == "add": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + # Convert add to __setitem__ with None value + if not is_hashable(args[0]): + raise_unhashable(args[0], tx) + tx.output.side_effects.mutation(self) + self.items[HashableTracker(args[0])] = SetVariable._default_value() + return ConstantVariable.create(None) + elif name == "pop": + if kwargs or args: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + # Choose an item at random and pop it + try: + result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment] + except KeyError as e: + raise_observed_exception(KeyError, tx, args=list(e.args)) + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.pop(HashableTracker(result)) + return result + elif name == "isdisjoint": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return SourcelessBuilder.create(tx, polyfills.set_isdisjoint).call_function( + tx, [self, args[0]], {} + ) + elif name == "intersection": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return SourcelessBuilder.create( + tx, polyfills.set_intersection + ).call_function( + tx, + [self, *args], + {"cls": self.python_type_var()}, + ) + elif name == "intersection_update": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return SourcelessBuilder.create( + tx, polyfills.set_intersection_update + ).call_function(tx, [self, *args], {}) + elif name == "union": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return SourcelessBuilder.create(tx, polyfills.set_union).call_function( + tx, + [self, *args], + {"cls": self.python_type_var()}, + ) + elif name == "difference": + if kwargs: + raise_args_mismatch( + tx, name, f"Expect: 0 kwargs, Actual: {len(kwargs)} kwargs" + ) + return SourcelessBuilder.create(tx, polyfills.set_difference).call_function( + tx, + [self, *args], + {"cls": self.python_type_var()}, + ) + elif name == "difference_update": + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return SourcelessBuilder.create( + tx, polyfills.set_difference_update + ).call_function(tx, [self, *args], {}) + elif name == "symmetric_difference": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return SourcelessBuilder.create( + tx, polyfills.set_symmetric_difference + ).call_function( + tx, + [self, *args], + {"cls": self.python_type_var()}, + ) + elif name == "symmetric_difference_update": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return SourcelessBuilder.create( + tx, polyfills.set_symmetric_difference_update + ).call_function(tx, [self, *args], {}) + elif name == "update" and self.is_mutable(): + if kwargs: + raise_args_mismatch(tx, name, "0 kwargs", f"{len(kwargs)} kwargs") + return SourcelessBuilder.create(tx, polyfills.set_update).call_function( + tx, [self, *args], {} + ) + elif name == "remove": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + if args[0] not in self: + raise_observed_exception(KeyError, tx, args=args) + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.pop(HashableTracker(args[0])) + return ConstantVariable.create(None) + elif name == "discard": + if kwargs or len(args) != 1: + raise_args_mismatch( + tx, + name, + "1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + if args[0] in self: + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.pop(HashableTracker(args[0])) + return ConstantVariable.create(None) + elif name in ("issubset", "issuperset"): + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + + op = { + "issubset": operator.le, + "issuperset": operator.ge, + } + other = args[0].realize() + if not istype(other, SetVariable): + other = SourcelessBuilder.create(tx, set).call_function(tx, [other], {}) + return SourcelessBuilder.create(tx, op.get(name)).call_function( + tx, [self, other], {} + ) + elif name in ("__and__", "__or__", "__xor__", "__sub__"): + m = { + "__and__": "intersection", + "__or__": "union", + "__xor__": "symmetric_difference", + "__sub__": "difference", + }.get(name) + if not isinstance( + args[0], + ( + SetVariable, + variables.UserDefinedSetVariable, + DictItemsVariable, + DictKeysVariable, + ), + ): + raise_observed_exception( + TypeError, + tx, + args=[ + f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" + ], + ) + assert m is not None + return self.call_method(tx, m, args, kwargs) + elif name in ("__rand__", "__ror__", "__rxor__", "__rsub__"): + m = { + "__rand__": "__and__", + "__ror__": "__or__", + "__rxor__": "__xor__", + "__rsub__": "__sub__", + }.get(name) + if not isinstance( + args[0], + ( + SetVariable, + variables.UserDefinedSetVariable, + DictItemsVariable, + DictKeysVariable, + ), + ): + raise_observed_exception( + TypeError, + tx, + args=[ + f"unsupported operand type(s) for {name}: '{args[0].python_type_name()}' and '{self.python_type_name()}'" + ], + ) + assert m is not None + return args[0].call_method(tx, m, [self], kwargs) + elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"): + if not isinstance( + args[0], + ( + SetVariable, + variables.UserDefinedSetVariable, + DictItemsVariable, + DictKeysVariable, + ), + ): + raise_observed_exception( + TypeError, + tx, + args=[ + f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" + ], + ) + m = { + "__iand__": "intersection_update", + "__ior__": "update", + "__ixor__": "symmetric_difference_update", + "__isub__": "difference_update", + }.get(name) + assert m is not None + self.call_method(tx, m, args, kwargs) + return self + elif name == "__eq__": + if not isinstance( + args[0], + ( + SetVariable, + variables.UserDefinedSetVariable, + DictItemsVariable, + DictKeysVariable, + ), + ): + return ConstantVariable.create(False) + r = self.call_method(tx, "symmetric_difference", args, kwargs) + return VariableTracker.build(tx, len(r.set_items) == 0) # type: ignore[attr-defined] + elif name == "__ne__": + eq_result = self.call_method(tx, "__eq__", args, kwargs) + return VariableTracker.build(tx, not eq_result.value) # type: ignore[attr-defined] + elif name in cmp_name_to_op_mapping: + if not isinstance( + args[0], + ( + SetVariable, + variables.UserDefinedSetVariable, + DictItemsVariable, + DictKeysVariable, + ), + ): + return VariableTracker.build(tx, NotImplemented) + return VariableTracker.build( + tx, + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items), # type: ignore[attr-defined] + ) + elif name == "__contains__": + if not len(args): + raise_args_mismatch( + tx, + name, + "more than 1 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + if not (args and is_hashable(args[0])): + raise_unhashable(args[0], tx) + self.install_set_contains_guard(tx, args) + contains = args[0] in self + return VariableTracker.build(tx, contains) + elif name == "__len__": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return VariableTracker.build(tx, len(self.items)) + elif name == "copy": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + return self.clone( + items=self.items.copy(), mutation_type=ValueMutationNew(), source=None + ) + elif name == "clear": + if args or kwargs: + raise_args_mismatch( + tx, + name, + "0 args and 0 kwargs", + f"{len(args)} args and {len(kwargs)} kwargs", + ) + self.should_reconstruct_all = True + tx.output.side_effects.mutation(self) + self.items.clear() + return ConstantVariable.create(None) + elif name == "__iter__": + from .lists import ListIteratorVariable + + if self.source and not is_constant_source(self.source): + tx.output.guard_on_key_order.add(self.source) + return ListIteratorVariable( + self.unpack_var_sequence(tx), mutation_type=ValueMutationNew() + ) + return super().call_method(tx, name, args, kwargs) + + def python_type_var(self) -> "BuiltinVariable": + return variables.BuiltinVariable(set) + + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: + raise RuntimeError("Illegal to getitem on a set") + + def sq_length(self, tx: "InstructionTranslator") -> VariableTracker: + return VariableTracker.build(tx, len(self.set_items)) + + +class OrderedSetClassVariable(VariableTracker): + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + def as_python_constant(self) -> type[OrderedSet[Any]]: + return OrderedSet + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if name == "__new__": + from .misc import GetAttrVariable + + if self.source: + attr_source = AttrSource(self.source, name) + else: + attr_source = None + return GetAttrVariable( + self, name, py_type=type(getattr(OrderedSet, name)), source=attr_source + ) + else: + return super().var_getattr(tx, name) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .builtin import set_methods + + if name == "__new__": + if len(args) != 2 or kwargs: + raise_args_mismatch( + tx, + name, + "OrderedSet.__new__ only accepts one arg" + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + return variables.OrderedSetVariable([], mutation_type=ValueMutationNew()) + + resolved_fn = getattr(set, name) + if resolved_fn in set_methods and isinstance(args[0], variables.SetVariable): + return args[0].call_method(tx, name, args[1:], kwargs) + + return super().call_method(tx, name, args, kwargs) + + def call_function( + self, + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> "OrderedSetVariable": + if len(args) > 1 or kwargs: + raise_args_mismatch( + tx, + "OrderedSet", + "OrderedSet only accepts one arg" + f"{len(args)} args and {len(kwargs)} kwargs", + ) + + if len(args) == 0: + # pyrefly: ignore [implicit-any] + items = [] + else: + items = args[0].force_unpack_var_sequence(tx) + return variables.OrderedSetVariable(items, mutation_type=ValueMutationNew()) + + +class OrderedSetVariable(SetVariable): + def debug_repr(self) -> str: + if not self.items: + return "OrderedSet([])" + else: + items: list[str] = [] + for k in self.items: + key_str = ( + repr(k.vt.value) if hasattr(k.vt, "value") else k.vt.debug_repr() + ) + items.append(key_str) + return "OrderedSet([" + ",".join(items) + "])" + + def as_python_constant(self) -> OrderedSet[Any]: + return OrderedSet([k.vt.as_python_constant() for k in self.set_items]) + + def python_type(self) -> type[OrderedSet[Any]]: + return OrderedSet + + # pyrefly: ignore[bad-override] + def python_type_var(self) -> OrderedSetClassVariable: + return OrderedSetClassVariable() + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.load_import_from("torch.utils._ordered_set", "OrderedSet") + ) + codegen.foreach([x.vt for x in self.set_items]) + codegen.append_output(create_instruction("BUILD_LIST", arg=len(self.set_items))) + codegen.extend_output(create_call_function(1, False)) + + +class FrozensetVariable(SetVariable): + # PyFrozenSet_Type: https://github.com/python/cpython/blob/v3.13.0/Objects/setobject.c#L2526 + _cpython_type = frozenset + + def debug_repr(self) -> str: + if not self.items: + return "frozenset()" + else: + items: list[str] = [] + for k in self.items: + key_str = ( + repr(k.vt.value) if hasattr(k.vt, "value") else k.vt.debug_repr() + ) + items.append(key_str) + return "{" + ",".join(items) + "}" + + @property + def set_items(self) -> set["HashableTracker"]: + return set(self.items.keys()) + + def python_type(self) -> type: + return frozenset + + def python_type_var(self) -> "BuiltinVariable": + return variables.BuiltinVariable(frozenset) + + def as_python_constant(self) -> Any: + return frozenset({k.vt.as_python_constant() for k in self.set_items}) + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.add_push_null( + lambda: codegen.extend_output( + [ + codegen.create_load_global("frozenset"), + ] + ) + ) + codegen.foreach([x.vt for x in self.set_items]) + codegen.extend_output( + [ + create_instruction("BUILD_LIST", arg=len(self.set_items)), + *create_call_function(1, False), + ] + ) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name in ["add", "pop", "update", "remove", "discard", "clear"]: + raise RuntimeError(f"Illegal call_method {name} on a frozenset") + elif name == "__init__": + # frozenset is immutable. Calling __init__ again shouldn't have any effect + return ConstantVariable.create(None) + elif name in ( + "copy", + "difference", + "intersection", + "symmetric_difference", + ): + r = super().call_method(tx, name, args, kwargs) + return FrozensetVariable(r.items) # type: ignore[attr-defined] + return super().call_method(tx, name, args, kwargs) + + def is_python_hashable(self) -> Literal[True]: + """ + Frozensets are immutable and hashable in Python. + """ + return True + + def get_python_hash(self) -> int: + return hash(self.as_python_constant()) + + def is_python_equal(self, other: object) -> bool: + return ( + isinstance(other, VariableTracker) + and self.as_python_constant() == other.as_python_constant() + ) + + +class DictKeySetVariable(SetVariable): + def debug_repr(self) -> str: + if not self.items: + return "dict_keys([])" + else: + items: list[str] = [] + for k in self.items: + key_str = ( + repr(k.vt.value) if hasattr(k.vt, "value") else k.vt.debug_repr() + ) + items.append(key_str) + return "dict_keys([" + ",".join(items) + "])" + + def install_set_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: + # Already EQUALS_MATCH guarded + pass + + @property + def set_items(self) -> Any: + return self.items + + def python_type(self) -> type: + from ..utils import dict_keys + + return dict_keys + + def as_python_constant(self) -> Any: + return dict.fromkeys( + {k.vt.as_python_constant() for k in self.set_items}, None + ).keys() + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + if name in ["add", "pop", "update", "remove", "discard", "clear"]: + raise RuntimeError(f"Illegal call_method {name} on a dict_keys") + return super().call_method(tx, name, args, kwargs) diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index f8f5d97b3720d..c42e6a84d1aaa 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -11,12 +11,15 @@ from ..bytecode_transformation import create_call_function from ..exc import TYPE_CHECKING, unimplemented from ..graph_bytecode_inputs import ( + CURRENT_STREAM_INDEX, get_external_object_by_index, register_graph_created_object, + register_user_object, + reset_user_object_tracking, ) from ..source import CurrentStreamSource from .base import VariableTracker -from .constant import CONSTANT_VARIABLE_NONE, ConstantVariable +from .constant import ConstantVariable from .ctx_manager import FxTracebackAnnotateVariable from .lazy import LazyVariableTracker @@ -125,7 +128,7 @@ def _( def record_event(event_index: int, stream_index: int) -> None: event = _get_event_by_index(event_index) stream = _get_stream_by_index(stream_index) - stream.record_event(event) + event.record(stream) @record_event.register_fake @@ -143,7 +146,7 @@ def _( def wait_event(event_index: int, stream_index: int) -> None: event = _get_event_by_index(event_index) stream = _get_stream_by_index(stream_index) - stream.wait_event(event) + event.wait(stream) @wait_event.register_fake @@ -157,6 +160,47 @@ def _( has_side_effect(torch.ops.streams.wait_event.default) +@custom_op("streams::synchronize_event", mutates_args=()) +def synchronize_event(event_index: int) -> None: + event = _get_event_by_index(event_index) + event.synchronize() + + +@synchronize_event.register_fake +def _(event_index: int) -> None: + pass + + +has_side_effect(torch.ops.streams.synchronize_event.default) + + +@custom_op("streams::synchronize_device", mutates_args=()) +def synchronize_device(device_type: str, device_index: int) -> None: + torch.accelerator.synchronize(torch.device(device_type, device_index)) + + +@synchronize_device.register_fake +def _(device_type: str, device_index: int) -> None: + pass + + +has_side_effect(torch.ops.streams.synchronize_device.default) + + +@custom_op("streams::synchronize_stream", mutates_args=()) +def synchronize_stream(stream_index: int) -> None: + stream = _get_stream_by_index(stream_index) + stream.synchronize() + + +@synchronize_stream.register_fake +def _(stream_index: int) -> None: + pass + + +has_side_effect(torch.ops.streams.synchronize_stream.default) + + @custom_op("streams::wait_stream", mutates_args=()) def wait_stream(waiting_stream_index: int, waited_on_stream_index: int) -> None: waiting = _get_stream_by_index(waiting_stream_index) @@ -188,6 +232,15 @@ def sync_dealloc( torch.ops.streams.wait_event.default(wait_event_index, src_stream_index) +@sync_dealloc.register_fake +def _( + wait_event_index: int, + src_stream_index: int, + to_dealloc: torch.Tensor, +) -> None: + pass + + has_side_effect(torch.ops.streams.sync_dealloc.default) @@ -198,13 +251,15 @@ def record_stream(tensor: torch.Tensor, stream_index: int) -> None: @record_stream.register_fake def _( - src_stream_index: int, - wait_event_index: int, - to_dealloc: torch.Tensor, + tensor: torch.Tensor, + stream_index: int, ) -> None: pass +has_side_effect(torch.ops.streams.record_stream.default) + + class SymbolicStreamState: """Track the currently entered stream if any""" @@ -213,10 +268,23 @@ def __init__(self) -> None: cur_stack: list[StreamVariable] = [] if torch.accelerator.is_available(): - stream_var = LazyVariableTracker.create( - torch.accelerator.current_stream(), - source=CurrentStreamSource(torch.accelerator.current_stream().device), + # Reset the registry so the current stream is guaranteed index 0. + reset_user_object_tracking() + stream = torch.accelerator.current_stream() + source = CurrentStreamSource(stream.device) + # Register the current stream so it gets index 0 (registry is + # fresh at tracing start). The inductor wrapper updates this + # entry at runtime so cudagraph capture uses the capture stream + # instead of this stale trace-time stream. + index = register_user_object(stream, source) + assert index == CURRENT_STREAM_INDEX, ( + f"Current stream must be registered at index {CURRENT_STREAM_INDEX}, " + f"got {index}" ) + stream_var = LazyVariableTracker.create(stream, source=source) + # Set user_object_index as an instance attribute so accessing it + # does NOT trigger LazyVariableTracker realization. + stream_var.user_object_index = index # type: ignore[union-attr] cur_stack = [stream_var] # type: ignore[list-item] self.cur_stream_stack: collections.deque[StreamVariable] = collections.deque( @@ -240,6 +308,13 @@ def cur_stream(self, device: torch.device | None = None) -> "StreamVariable": def in_stream_context(self) -> bool: return len(self.cur_stream_stack) > 0 + def cur_stream_id(self) -> int: + """Get a Python object id for the current stream without realizing lazy variables.""" + stream = self.cur_stream_stack[-1] + if isinstance(stream, LazyVariableTracker) and not stream.is_realized(): + return id(stream.peek_value()) + return id(stream.value) + class StreamContextVariable(FxTracebackAnnotateVariable): """This represents torch.cuda.StreamContext""" @@ -279,6 +354,9 @@ def exit( tx.symbolic_stream_state.exit_stream() return super().exit(tx, *args) + def python_type(self) -> type: + return torch.cuda.StreamContext + def supports_graph_breaks(self) -> bool: return True @@ -290,6 +368,8 @@ def get_stream(self) -> "StreamVariable": class StreamVariable(StreamContextVariable): """Represents the device-agnostic torch.Stream class""" + _cpython_type = torch.Stream + def __init__( self, proxy: Proxy, @@ -312,6 +392,9 @@ def __init__( def python_type(self) -> type: return torch.Stream + def get_real_python_backed_value(self) -> object: + return self.value + def call_method( self, tx: "InstructionTranslator", @@ -324,11 +407,34 @@ def call_method( from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs from .builder import wrap_fx_proxy_cls - if name in ("wait_stream", "synchronize", "wait_event"): + if name == "wait_event": + event_arg = args[0] + assert isinstance(event_arg, EventVariable) + tx.output.create_proxy( + "call_function", + torch.ops.streams.wait_event, + (event_arg.user_object_index, self.user_object_index), + {}, + ) + return ConstantVariable.create(None) + elif name == "wait_stream": + other_stream = args[0] + assert isinstance(other_stream, StreamVariable) + tx.output.create_proxy( + "call_function", + torch.ops.streams.wait_stream, + (self.user_object_index, other_stream.user_object_index), + {}, + ) + return ConstantVariable.create(None) + elif name == "synchronize": tx.output.create_proxy( - "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + "call_function", + torch.ops.streams.synchronize_stream, + (self.user_object_index,), + {}, ) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif name == "query": return wrap_fx_proxy_cls( target_cls=ConstantVariable, @@ -338,11 +444,34 @@ def call_method( ), ) elif name == "record_event": - return wrap_fx_proxy_cls( - target_cls=EventVariable, + from .builder import wrap_fx_proxy + + tx.output.check_event_record_after_input_mutation(id(self.value)) + if args and isinstance(args[0], EventVariable): + event_var = args[0] + event = event_var.value + event_index = event_var.user_object_index + else: + event = self.value.record_event() + event_index = register_graph_created_object( + event, + EventVariable.make_construct_in_graph_event_fn( + TupleVariable([]), ConstDictVariable({}) + ), + ) + tx.output.create_proxy( + "call_function", + torch.ops.streams.record_event, + (event_index, self.user_object_index), + {}, + ) + return wrap_fx_proxy( tx=tx, proxy=tx.output.create_proxy( - "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + "call_function", + get_external_object_by_index, + (event_index,), + {}, ), ) elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs: @@ -422,6 +551,32 @@ def fn(index: int, codegen: "PyCodegen") -> None: return fn +class CudaStreamVariable(StreamVariable): + """Represents torch.cuda.Stream, preserving device-specific type and attributes.""" + + _cpython_type = torch.cuda.Stream + + def python_type(self) -> type: + return torch.cuda.Stream + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + from . import ConstantVariable + + if name == "cuda_stream": + from ..guards import GuardBuilder, install_guard + + if self.source: + install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) + + if hasattr(self.value, "cuda_stream"): + return ConstantVariable.create(self.value.cuda_stream) + + if hasattr(self.value, "native_handle"): + return ConstantVariable.create(self.value.native_handle) + + return super().var_getattr(tx, name) + + class EventVariable(VariableTracker): def __init__( self, @@ -437,6 +592,12 @@ def __init__( self.value = value self.user_object_index = user_object_index + def python_type(self) -> type: + return torch.Event + + def get_real_python_backed_value(self) -> object: + return self.value + def call_method( self, tx: "InstructionTranslator", @@ -448,32 +609,38 @@ def call_method( from .builder import wrap_fx_proxy_cls if name == "wait": + _, stream_index = EventVariable._get_stream_arg(tx, args, kwargs) tx.output.create_proxy( "call_function", torch.ops.streams.wait_event, ( self.user_object_index, - EventVariable._get_stream_arg(tx, args, kwargs).user_object_index, + stream_index, ), {}, ) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif name == "record": + stream_arg, stream_index = EventVariable._get_stream_arg(tx, args, kwargs) + tx.output.check_event_record_after_input_mutation(id(stream_arg.value)) tx.output.create_proxy( "call_function", torch.ops.streams.record_event, ( self.user_object_index, - EventVariable._get_stream_arg(tx, args, kwargs).user_object_index, + stream_index, ), {}, ) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif name == "synchronize": tx.output.create_proxy( - "call_method", name, *proxy_args_kwargs([self] + args, kwargs) + "call_function", + torch.ops.streams.synchronize_event, + (self.user_object_index,), + {}, ) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) elif name == "query": return wrap_fx_proxy_cls( target_cls=ConstantVariable, @@ -504,7 +671,14 @@ def _get_stream_arg( tx: "InstructionTranslator", args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "StreamVariable": + ) -> tuple["StreamVariable", int]: + """Returns (stream_variable, stream_index_for_op). + + The ambient current stream is registered at index 0 in the external + object registry. The inductor wrapper updates index 0 at runtime so + that cudagraph capture sees the capture stream, not the stale + trace-time default stream. + """ stream_arg = None if args: stream_arg = args[0] @@ -512,9 +686,10 @@ def _get_stream_arg( stream_arg = kwargs.get("stream") if not stream_arg: - stream_arg = tx.symbolic_stream_state.cur_stream() + stream_var = tx.symbolic_stream_state.cur_stream() + return stream_var, stream_var.user_object_index # type: ignore[return-value] - return stream_arg # type: ignore[return-value] + return stream_arg, stream_arg.user_object_index # type: ignore[return-value] @staticmethod def make_construct_in_graph_event_fn( diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 5b58c9edbf7e4..66c6175895053 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -32,6 +32,7 @@ import torch._numpy as tnp import torch.fx import torch.random +from torch import sym_float, sym_int from torch._dynamo import compiled_autograd from torch._library.opaque_object import is_opaque_reference_type from torch._opaque_base import OpaqueBase @@ -73,7 +74,7 @@ tensortype_to_dtype, ) from .base import AttributeMutationNew, ValueMutationNew, VariableTracker -from .constant import CONSTANT_VARIABLE_NONE, CONSTANT_VARIABLE_TRUE, ConstantVariable +from .constant import ConstantVariable from .lists import ListIteratorVariable, SizeVariable from .script_object import TorchScriptObjectVariable from .user_defined import UserDefinedClassVariable @@ -253,9 +254,10 @@ def _sync_if_inplace_mutation( version_before is not None and version_after is not None and version_after > version_before - and has_tensor_arg ): - self.synchronize_attributes(tx) + if has_tensor_arg: + self.synchronize_attributes(tx) + tx.output.check_input_mutation_on_current_stream(tx) def debug_repr(self) -> str: # TODO: strip off fake tensor from repr here @@ -270,6 +272,19 @@ def python_type(self) -> type: def is_tensor(self) -> bool: return True + def bool_impl(self, tx: "InstructionTranslator") -> VariableTracker: + # THPVariable_bool calls at::Tensor::is_nonzero(), i.e. .item() != 0. + from .constant import ConstantVariable + + item = self.call_method(tx, "item", [], {}) + if isinstance(item, SymNodeVariable) and isinstance( + item.sym_num, torch.SymBool + ): + return item + if isinstance(item, ConstantVariable): + return VariableTracker.build(tx, bool(item.value)) + return SymNodeVariable.create(tx, item.as_proxy() != 0) + @staticmethod def specialize(value: torch.Tensor) -> dict[str, Any]: props: dict[str, Any] = { @@ -385,14 +400,12 @@ def dynamic_getattr( try: real_value = getattr(_input_associated_real_value, name) except AttributeError: - error_message = VariableTracker.build( - tx, - f"'{type(_input_associated_real_value).__name__}' object has no attribute '{name}'", - ) raise_observed_exception( AttributeError, tx, - args=[error_message], + args=[ + f"'{type(_input_associated_real_value).__name__}' object has no attribute '{name}'" + ], ) attr_source = AttrSource(self.source, name) @@ -485,6 +498,12 @@ def method_attr_retain_grad(self, tx: "InstructionTranslator") -> NoReturn: hints=[], ) + def method_attr_grad(self, tx: "InstructionTranslator") -> VariableTracker | None: + if tx.output.side_effects.has_pending_mutation_of_attr(self, "grad"): + return tx.output.side_effects.load_attr(self, "grad") + # None tells var_getattr to use default .grad handling + return None + def method_attr_data(self, tx: "InstructionTranslator") -> VariableTracker: return variables.TorchInGraphFunctionVariable( torch._C._autograd._get_data_attr # type: ignore[attr-defined] @@ -501,7 +520,7 @@ def method_attr_grad_fn( hints=[], ) else: - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def method_attr__version(self, tx: "InstructionTranslator") -> VariableTracker: from ..tensor_version_op import _tensor_version @@ -520,14 +539,14 @@ def call_obj_hasattr( # attributes and existing attributes. This is a bug and requires more # deep dive. if name in all_tensor_attrs: - return CONSTANT_VARIABLE_TRUE + return ConstantVariable.create(True) try: var = VariableTracker.build(tx, getattr).call_function( tx, [self, VariableTracker.build(tx, name)], {} ) # in the event that TensorVariable returns NotImplemented - # BuiltinVariable.call_getattr returns GetAttrVariable + # GetAttrBuiltinVariable.call_function returns GetAttrVariable ret_val = not isinstance(var, GetAttrVariable) except (AttributeError, ObservedAttributeError): ret_val = False @@ -555,7 +574,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker ) elif name in self._strict_mode_conditional_banned_ops(): raise UnknownPropertiesDuringBackwardTrace( - f"Unknown property {name} during speculating backward, dynamo will insert contiguous call ahead and speculate it again" # noqa: B950 + f"Unknown property {name} during speculating backward, dynamo will insert contiguous call ahead and speculate it again" ) if name == "__class__": @@ -757,6 +776,41 @@ def _strict_mode_conditional_banned_ops(self) -> list[str]: torch._dynamo.config._autograd_backward_strict_mode_conditional_banned_ops ) + def mp_subscript_impl( + self, + tx: "InstructionTranslator", + key: VariableTracker, + ) -> VariableTracker: + # Tensor.__getitem__ is a custom C slot, not CPython's mp_subscript. + # TODO(follow-up): add tests for negative index, bool index, invalid key type + from .builder import SourcelessBuilder, VariableBuilder + from .torch_function import can_dispatch_torch_function, dispatch_torch_function + + if self.is_strict_mode(tx) and "__getitem__" in self._strict_mode_banned_ops(): + unimplemented( + gb_type="Illegal __getitem__ invocation in strict mode", + context=f"mp_subscript_impl {self} {key}", + explanation="Dynamo currently does not support __getitem__ " + "invocation in strict mode.", + hints=[], + ) + + empty_kwargs: dict[str, VariableTracker] = {} + static_attr = all_tensor_attrs.get("__getitem__", None) + if static_attr is not None and can_dispatch_torch_function( + tx, (self, key), empty_kwargs + ): + if self.source: + func_var = VariableBuilder( + tx, + AttrSource(AttrSource(self.source, "__class__"), "__getitem__"), + )(static_attr) + else: + func_var = SourcelessBuilder.create(tx, torch.Tensor.__getitem__) + return dispatch_torch_function(tx, func_var, (self, key), empty_kwargs) + + return self.method___getitem__(tx, key) + def call_method( self, tx: "InstructionTranslator", @@ -778,12 +832,11 @@ def call_method( if name == "__deepcopy__": unimplemented( - gb_type="copy.deepcopy(tensor)", + gb_type="Attempted to copy.deepcopy a tensor", context=f"copy.deepcopy({self})", explanation="Dynamo does not support copy.deepcopy() on tensors.", hints=[ "Avoid calling copy.deepcopy() on tensors inside compiled regions.", - *graph_break_hints.SUPPORTABLE, ], ) @@ -816,7 +869,7 @@ def call_method( # This is seen in inspect signature where we check if the value is a default value if name == "__eq__" and isinstance(args[0], UserDefinedClassVariable): - return variables.CONSTANT_VARIABLE_FALSE + return variables.ConstantVariable.create(False) if name == "wait": if args or kwargs: @@ -880,6 +933,23 @@ def call_method( from_exc=e, ) + # Guard against unknown methods reaching the generic proxy path. + # For traceable wrapper subclasses (DTensor, NestedTensor), class_type + # is torch.Tensor, so check the example_value's actual type instead. + example_value = self.proxy.node.meta.get("example_value") + check_type = ( + type(example_value) if example_value is not None else self.class_type + ) + if not hasattr(check_type, name): + unimplemented( + gb_type="Unhandled tensor method", + context=f"call_method {self} {name} {args} {kwargs}", + explanation=f"Tensor method `{name}` is not defined on " + f"{check_type.__name__} and does not have an explicit " + "handler in TensorVariable.", + hints=[*graph_break_hints.SUPPORTABLE], + ) + from .builder import wrap_fx_proxy proxy = tx.output.create_proxy( @@ -1214,7 +1284,11 @@ def _collect_backward_inputs( # (Dynamo creates GetAttrVariable instead of TensorVariable) # # In-graph created tensors without proper source also can't be handled - # because subguards_allowed() returns False for SyntheticLocalSource. + # when user explicitly passes them as inputs, because + # subguards_allowed() returns False for SyntheticLocalSource. + # However, in auto-detect mode (error_on_non_leaf=False), source-less + # leaves are valid backward targets — they gained requires_grad via + # requires_grad_() and accumulate_grad_ writes to .grad directly. if var.has_grad_fn: if error_on_non_leaf: unimplemented( @@ -1225,17 +1299,18 @@ def _collect_backward_inputs( "Only pass leaf tensors (parameters, graph inputs) to backward(inputs=...)", ], ) - elif not var.source or isinstance(var.source, SyntheticLocalSource): - if error_on_non_leaf: - unimplemented( - gb_type="backward() with in-graph created tensor", - context=f"backward(inputs=[...]) with in-graph created tensor: {var}", - explanation="backward(inputs=[...]) with tensors created inside the " - "compiled function is not yet supported.", - hints=[ - "Only pass tensors that are inputs to the compiled function or captured from outside", - ], - ) + elif error_on_non_leaf and ( + not var.source or isinstance(var.source, SyntheticLocalSource) + ): + unimplemented( + gb_type="backward() with in-graph created tensor", + context=f"backward(inputs=[...]) with in-graph created tensor: {var}", + explanation="backward(inputs=[...]) with tensors created inside the " + "compiled function is not yet supported.", + hints=[ + "Only pass tensors that are inputs to the compiled function or captured from outside", + ], + ) else: node = var.proxy.node if node not in seen_nodes: @@ -1303,7 +1378,7 @@ def method_backward( # No leaf tensors found - nothing to accumulate gradients into. # This matches eager behavior where backward() is a no-op if there # are no leaves requiring grad. - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) else: provided_vars = ( inputs.items @@ -1359,7 +1434,7 @@ def method_backward( grad_mode_var.exit(tx) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) def method_data_ptr( self, @@ -1369,6 +1444,33 @@ def method_data_ptr( ) -> "DataPtrVariable": return DataPtrVariable(self) + def method_const_data_ptr( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> "DataPtrVariable": + return DataPtrVariable(self, method_name="const_data_ptr") + + def method_record_stream( + self, + tx: "InstructionTranslator", + stream: VariableTracker, + ) -> VariableTracker: + from .streams import StreamVariable + + if not isinstance(stream, StreamVariable): + raise RuntimeError( + f"record_stream() expects a Stream argument, got {stream.python_type().__name__}" + ) + tx.output.create_proxy( + "call_function", + torch.ops.streams.record_stream, + (self.as_proxy(), stream.user_object_index), + {}, + ) + return ConstantVariable.create(None) + def method_item( self, tx: "InstructionTranslator", @@ -1391,6 +1493,95 @@ def method_item( ) return None + def nb_index_impl( + self, + tx: "InstructionTranslator", + ) -> VariableTracker: + # CPython: only integer tensors of a single element can be converted + # to an index. Mirrors THPVariable_index_scalar in + # https://github.com/pytorch/pytorch/blob/7cfd054075b/tools/autograd/templates/python_variable_methods.cpp#L372-L385 + if self.dtype is not None and ( + not self.dtype.is_floating_point and not self.dtype.is_complex + ): + item = self.call_method(tx, "item", [], {}) + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + sym_int, + (item.as_proxy(),), + {}, + ), + ) + raise_observed_exception( + TypeError, + tx, + args=[ + "only integer tensors of a single element can be converted to an index" + ], + ) + + def nb_int_impl( + self, + tx: "InstructionTranslator", + ) -> VariableTracker: + # CPython: THPVariable_integral_scalar handles both integer and float + # tensors (floats are truncated to int). Complex tensors raise + # RuntimeError at runtime. + if self.dtype is not None and self.dtype.is_complex: + raise_observed_exception( + RuntimeError, + tx, + args=["value cannot be converted to type int64_t without overflow"], + ) + # For known non-complex dtypes and unknown dtype (None), emit the + # proxy and let it fail at runtime if the dtype is unsupported. + item = self.call_method(tx, "item", [], {}) + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + sym_int, + (item.as_proxy(),), + {}, + ), + ) + + def method___int__(self, tx: "InstructionTranslator") -> VariableTracker: + return self.nb_int_impl(tx) + + def nb_float_impl( + self, + tx: "InstructionTranslator", + ) -> VariableTracker: + # CPython: THPVariable_float_scalar dispatches to double. + # Complex tensors raise RuntimeError at runtime. + if self.dtype is not None and self.dtype.is_complex: + raise_observed_exception( + RuntimeError, + tx, + args=["value cannot be converted to type double without overflow"], + ) + item = self.call_method(tx, "item", [], {}) + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + sym_float, + (item.as_proxy(),), + {}, + ), + ) + + def method___float__(self, tx: "InstructionTranslator") -> VariableTracker: + return self.nb_float_impl(tx) + def method___getitem__( self, tx: "InstructionTranslator", @@ -1443,6 +1634,10 @@ def _warn_capture_scalar_outputs() -> None: ) def method___len__(self, tx: "InstructionTranslator") -> VariableTracker: + return self.sq_length(tx) + + def sq_length(self, tx: "InstructionTranslator") -> VariableTracker: + """Sequence length for tensors (size along first dimension).""" return self.call_method(tx, "size", [VariableTracker.build(tx, 0)], {}) def method___iter__(self, tx: "InstructionTranslator") -> ListIteratorVariable: @@ -1499,7 +1694,7 @@ def method___setitem__( if config.use_graph_deduplication or config.track_nodes_for_deduplication: tx.output.region_tracker.add_node_mutation(proxy.node, 0) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) def method_resize_( self, @@ -1792,16 +1987,100 @@ def method_requires_grad_( if requires_grad is not True: requires_grad = requires_grad.as_python_constant() # type: ignore[attr-defined] - if self.as_proxy().node.meta["example_value"].requires_grad != requires_grad: - unimplemented( - gb_type="Unsupported Tensor.requires_grad_() call", - context=f"call_method {self} requires_grad_", - explanation="Dynamo does not support changes to a Tensor's " - "`requires_grad` through calling `requires_grad_()`.", - hints=[], - ) - else: - return self + node = self.as_proxy().node + example_value = node.meta["example_value"] + if example_value.requires_grad != requires_grad: + # For graph inputs (tensors with source), requires_grad_() is a + # metadata mutation that we can't trace — graph break as before. + if self.source: + unimplemented( + gb_type="Unsupported Tensor.requires_grad_() call", + context=f"call_method {self} requires_grad_", + explanation="Dynamo does not support changes to a Tensor's " + "`requires_grad` through calling `requires_grad_()`.", + hints=[], + ) + # On a previous attempt, we traced through requires_grad_() but + # discovered at compile time that the tainted intermediate leaked + # as a graph output. Graph break here to preserve partial + # acceleration for code before requires_grad_(). + if tx.speculation_log.graph_break_on_requires_grad_: + unimplemented( + gb_type="requires_grad_() intermediate leaked as output", + context=f"call_method {self} requires_grad_", + explanation="An intermediate tensor with requires_grad_() called " + "on it (or a tensor derived from it) is returned from the " + "compiled region. Graph breaking here to preserve partial " + "acceleration.", + hints=[ + "Call .detach() before returning if you only need values.", + "Consume the gradient inside the compiled function " + "(call backward() and use .grad), " + "or move requires_grad_() outside torch.compile.", + ], + ) + # AOTAutograd re-traces the FX graph under functorch transforms + # (functionalization). Functorch's checkSupportsInplaceRequiresGrad() + # rejects requires_grad_() when the dynamic layer stack is non-empty. + # We wrap the call with set_inplace_requires_grad_allowed(True) to + # bypass this check, matching GradInplaceRequiresGradCtxManagerVariable + # in ctx_manager.py (which handles the explicit context manager case). + # + # Lines below do two things in parallel: + # 1. Mutate trace-time state so example_value.requires_grad_() works + # 2. Emit FX nodes so the same toggle happens during AOTAutograd re-trace + prev_state = torch._C._functorch.get_inplace_requires_grad_allowed() + torch._C._functorch.set_inplace_requires_grad_allowed(True) + try: + tx.output.create_node( + "call_function", + torch._C._functorch.set_inplace_requires_grad_allowed, + (True,), + {}, + ) + tx.output.create_proxy( + "call_method", + "requires_grad_", + (self.as_proxy(),), + {}, + ) + tx.output.create_node( + "call_function", + torch._C._functorch.set_inplace_requires_grad_allowed, + (prev_state,), + {}, + ) + finally: + torch._C._functorch.set_inplace_requires_grad_allowed(prev_state) + example_value.requires_grad_(requires_grad) + self.requires_grad = requires_grad + if requires_grad: + tx.output.leaf_var_creation_order.append(self) + # For source-less intermediates, initialize .grad = None in + # side effects so the accumulate_grad polyfill can read/write + # .grad naturally. Graph inputs don't need this — they handle + # .grad through their source. + if not self.source and tx.output.side_effects.is_attribute_mutation( + self + ): + tx.output.side_effects.store_attr( + self, "grad", variables.ConstantVariable.create(None) + ) + return self + + def method_detach_(self, tx: "InstructionTranslator") -> "TensorVariable": + from .builder import wrap_fx_proxy + + proxy = tx.output.create_proxy( + "call_method", + "detach_", + (self.as_proxy(),), + {}, + ) + # Run the fake op so the proxy metadata reflects the detached tensor state. + wrap_fx_proxy(tx, proxy) + self.synchronize_attributes(tx) + return self def method_share_memory_(self) -> NoReturn: unimplemented( @@ -1923,7 +2202,7 @@ def create( out = SymNodeVariable(proxy, sym_num, **options) if proxy.node.op != "placeholder": - tx.output.current_tracer.record_tensor_or_symint_vt(out) + tx.output.current_tracer.record_proxyable_vt(out) return out def __init__(self, proxy: Any, sym_num: Any, **kwargs: Any) -> None: @@ -1945,6 +2224,18 @@ def is_symnode_like(self) -> bool: def as_proxy(self) -> Any: return self.proxy + def bool_impl( + self, + tx: "InstructionTranslatorBase", + ) -> VariableTracker: + # long_bool / float_bool: non-zero check. SymBool is already boolean. + # https://github.com/python/cpython/blob/c09ccd9c429/Objects/longobject.c#L5200 + # https://github.com/python/cpython/blob/c09ccd9c429/Objects/floatobject.c#L853 + if isinstance(self.sym_num, torch.SymBool): + return self + assert isinstance(self.sym_num, (torch.SymInt, torch.SymFloat)) + return SymNodeVariable.create(tx, self.as_proxy() != 0) + def as_tensor(self, tx: "InstructionTranslatorBase", dtype: Any) -> TensorVariable: if self._tensor_var is None: self._tensor_var = VariableTracker.build( @@ -1985,6 +2276,52 @@ def call_method( ), ) + def nb_int_impl( + self, + tx: "InstructionTranslator", + ) -> VariableTracker: + # SymInt.__int__: https://github.com/pytorch/pytorch/blob/ee336ca5440939b8ad65e916d47421f849e56178/torch/__init__.py#L462 + # SymFloat.__int__: https://github.com/pytorch/pytorch/blob/ee336ca5440939b8ad65e916d47421f849e56178/torch/__init__.py#L682 + # SymBool.__int__: https://github.com/pytorch/pytorch/blob/ee336ca5440939b8ad65e916d47421f849e56178/torch/__init__.py#L784 + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + sym_int, + (self.as_proxy(),), + {}, + ), + ) + + def method___int__( + self, tx: "InstructionTranslator", *args: Any, **kwargs: Any + ) -> VariableTracker: + return self.nb_int_impl(tx) + + def nb_float_impl( + self, + tx: "InstructionTranslator", + ) -> VariableTracker: + # SymFloat.__float__: https://github.com/pytorch/pytorch/blob/ee336ca5440939b8ad65e916d47421f849e56178/torch/__init__.py#L679 + from .builder import wrap_fx_proxy + + return wrap_fx_proxy( + tx=tx, + proxy=tx.output.create_proxy( + "call_function", + sym_float, + (self.as_proxy(),), + {}, + ), + ) + + def method___float__( + self, tx: "InstructionTranslator", *args: Any, **kwargs: Any + ) -> VariableTracker: + return self.nb_float_impl(tx) + def is_python_hashable(self) -> bool: return True @@ -2043,7 +2380,7 @@ def insert_into_graph() -> VariableTracker: ), ) - if name in ["T", "real", "imag"]: + if name in ["T", "real", "imag", "flat"]: proxy = tx.output.create_proxy( "call_function", numpy_attr_wrapper, @@ -2291,6 +2628,9 @@ def __init__( # Example_value will always have device="meta" self.example_value = example_value + def python_type(self) -> type: + return torch.UntypedStorage + def call_method( self, tx: "InstructionTranslator", @@ -2346,12 +2686,17 @@ class DataPtrVariable(VariableTracker): def __init__( self, from_tensor: TensorVariable, + method_name: str = "data_ptr", **kwargs: Any, ) -> None: super().__init__(**kwargs) self.from_tensor = from_tensor + self.method_name = method_name + + def python_type(self) -> type: + return int def reconstruct(self, codegen: "PyCodegen") -> None: codegen(self.from_tensor) - codegen.load_method("data_ptr") + codegen.load_method(self.method_name) codegen.call_method(0) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 1f93a2842c9d6..7e5920cff8a28 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -42,12 +42,7 @@ import torch.nn import torch.utils._pytree as _pytree from torch._C import DispatchKeySet -from torch._dynamo.variables.constant import ( - CONSTANT_VARIABLE_FALSE, - CONSTANT_VARIABLE_NONE, - CONSTANT_VARIABLE_TRUE, - ConstantVariable, -) +from torch._dynamo.variables.constant import ConstantVariable from torch._dynamo.variables.streams import StreamVariable from torch._dynamo.variables.torch_function import TorchFunctionModeVariable from torch._guards import Guard, Source, TracingContext @@ -63,7 +58,13 @@ tracable_create_parameter, ) from ..device_interface import get_registered_device_interfaces -from ..exc import raise_observed_exception, unimplemented, UserError, UserErrorType +from ..exc import ( + raise_observed_exception, + raise_type_error, + unimplemented, + UserError, + UserErrorType, +) from ..guards import GuardBuilder, install_guard from ..source import ( AttrSource, @@ -82,7 +83,7 @@ proxy_args_kwargs, unwrap_if_wrapper, ) -from .base import raise_type_error_exc, typestr, VariableTracker +from .base import typestr, VariableTracker from .ctx_manager import ( AutocastModeVariable, ProfilerContextVariable, @@ -91,7 +92,7 @@ ) from .distributed import DistributedVariable from .functions import bind_args_cached, NestedUserFunctionVariable -from .lists import ListVariable, NamedTupleVariable, TupleVariable +from .lists import ListVariable, TupleVariable from .script_object import TorchScriptObjectVariable from .torch_function import ( can_dispatch_torch_function, @@ -99,6 +100,7 @@ TensorWithTFOverrideVariable, TorchFunctionModeStackVariable, ) +from .user_defined import UserDefinedTupleVariable try: @@ -114,7 +116,7 @@ if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator - from torch._library.opaque_object import OpaqueType + from torch._opaque_base import OpaqueBase from torch.utils._pytree import TreeSpec @@ -171,6 +173,7 @@ torch.cuda.is_initialized, torch.xpu.current_device, torch.xpu.is_initialized, + torch.__future__.get_overwrite_module_params_on_conversion, ] constant_fold_functions = [ @@ -251,7 +254,7 @@ def _check_for_gradient_edge(var: VariableTracker, arg_name: str) -> None: """ from .lists import BaseListVariable - if isinstance(var, NamedTupleVariable) and var.tuple_cls is GradientEdge: + if isinstance(var, UserDefinedTupleVariable) and type(var.value) is GradientEdge: # Try to get source info for context source_info = var.source.name if var.source else None context = f"GradientEdge in {arg_name}" @@ -284,7 +287,7 @@ def _collect_all_grad_fns(tensor: torch.Tensor) -> set[torch.autograd.graph.Node grad_fns: set[torch.autograd.graph.Node] = set() - plain_tensors: list[torch.SymInt | torch.Tensor | int | OpaqueType] = [] + plain_tensors: list[torch.SymInt | torch.Tensor | int | OpaqueBase] = [] # Get all plain tensors (handles nested subclasses) if is_traceable_wrapper_subclass(tensor): get_plain_tensors(tensor, out=plain_tensors) @@ -313,6 +316,8 @@ def _collect_tensors_with_sources( Used by handle_autograd_grad to collect tensors from the outputs and inputs arguments for grad_fn reachability analysis. """ + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + from .lazy import LazyVariableTracker from .lists import BaseListVariable from .tensor import TensorVariable @@ -320,7 +325,29 @@ def _collect_tensors_with_sources( results: list[tuple[torch.Tensor, str | None]] = [] if isinstance(var, TensorVariable): fake_tensor = var.as_proxy().node.meta.get("example_value") - assert isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor) + assert isinstance(fake_tensor, torch.Tensor) + if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor): + pass + elif is_traceable_wrapper_subclass(fake_tensor): + # For tensor subclasses (e.g. DTensor), verify the inner tensors + # are FakeTensors but keep the original subclass for grad_fn + # reachability analysis. + plain: list[object] = [] + torch._subclasses.fake_tensor.get_plain_tensors( + fake_tensor, # pyrefly: ignore[bad-argument-type] + out=plain, # pyrefly: ignore[bad-argument-type] + ) + assert all( + isinstance(t, torch._subclasses.fake_tensor.FakeTensor) + for t in plain + if isinstance(t, torch.Tensor) + ), ( + f"Expected all plain tensors to be FakeTensors, got {[type(t) for t in plain]}" + ) + else: + raise AssertionError( + f"Expected FakeTensor or subclass, got {type(fake_tensor)}" + ) source_name = var.source.name if var.source else None results.append((fake_tensor, source_name)) elif isinstance(var, LazyVariableTracker): @@ -344,6 +371,43 @@ def _collect_tensors_with_sources( return results +def _collect_placeholder_nodes(var: "VariableTracker") -> list[torch.fx.Node]: + """Recursively collect FX placeholder nodes from a VariableTracker. + + The returned placeholder nodes carry grapharg.example (real tensor) and + example_value (FakeTensor) metadata — comparing these reveals lost + autograd linkage (e.g., grad_fn dropped during tracing). + See NOTE [Detecting lost autograd linkage in closure-captured tensors]. + """ + from .lazy import LazyVariableTracker + from .lists import BaseListVariable + from .tensor import TensorVariable + + result: list[torch.fx.Node] = [] + if isinstance(var, TensorVariable): + node = var.as_proxy().node + if node.op == "placeholder": + result.append(node) + elif isinstance(var, LazyVariableTracker): + result.extend(_collect_placeholder_nodes(var.realize())) + elif isinstance(var, BaseListVariable): + for item in var.items: + result.extend(_collect_placeholder_nodes(item)) + else: + unimplemented( + gb_type="_autograd_grad with unsupported argument type", + context=f"got {type(var).__name__}", + explanation=( + f"_autograd_grad() received an argument of type {type(var).__name__} " + "which is not supported. Expected tensor or sequence of tensors." + ), + hints=[ + "Ensure outputs and inputs arguments are tensors or sequences of tensors.", + ], + ) + return result + + @functools.cache def get_overridable_functions() -> set[Callable[..., Any]]: from itertools import chain @@ -407,6 +471,9 @@ def as_proxy(self) -> Any: def as_python_constant(self) -> Any: return self.value + def get_real_python_backed_value(self) -> Any: + return self.value + def call_obj_hasattr( self, tx: "InstructionTranslator", name: str ) -> ConstantVariable: @@ -651,6 +718,7 @@ def get_function(self) -> Callable[..., Any]: def _get_handlers() -> dict[Callable[..., Any], Callable[..., Any]]: """Build a dict from function -> method to handle it so that we are O(1) in terms of the number of function with special handling.""" + # pyrefly: ignore [implicit-any] handlers = {} def register( @@ -817,9 +885,9 @@ def handle_is_tensor( and isinstance(arg, UserDefinedObjectVariable) and hasattr(arg.value, "__torch_function__") ): - return CONSTANT_VARIABLE_TRUE + return ConstantVariable.create(True) else: - return CONSTANT_VARIABLE_FALSE + return ConstantVariable.create(False) @register( torch.is_floating_point, @@ -941,7 +1009,101 @@ def handle_use_deterministic_algorithms( "call_function", torch._C._set_deterministic_algorithms, (value,), {} ) torch._C._set_deterministic_algorithms(value) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) + + @register(torch.autocast_increment_nesting) + def handle_autocast_increment_nesting( + self, tx: "InstructionTranslator" + ) -> VariableTracker: + tx.output.create_node( + "call_function", torch.autocast_increment_nesting, (), {} + ) + prev = torch.autocast_increment_nesting() + tx.output.add_cleanup_hook(lambda: torch.autocast_decrement_nesting()) + return VariableTracker.build(tx, prev) + + @register(torch.autocast_decrement_nesting) + def handle_autocast_decrement_nesting( + self, tx: "InstructionTranslator" + ) -> VariableTracker: + tx.output.create_node( + "call_function", torch.autocast_decrement_nesting, (), {} + ) + prev = torch.autocast_decrement_nesting() + tx.output.add_cleanup_hook(lambda: torch.autocast_increment_nesting()) + return VariableTracker.build(tx, prev) + + @register(torch.set_autocast_enabled) + def handle_set_autocast_enabled( + self, + tx: "InstructionTranslator", + device_type: VariableTracker, + enabled: VariableTracker, + ) -> VariableTracker: + tx.output.create_node( + "call_function", + torch.set_autocast_enabled, + (device_type.as_proxy(), enabled.as_proxy()), + ) + dev_py_const = device_type.as_python_constant() + prev = torch.is_autocast_enabled(dev_py_const) + torch.set_autocast_enabled(dev_py_const, enabled.as_python_constant()) + tx.output.add_cleanup_hook( + lambda: torch.set_autocast_enabled(dev_py_const, prev) + ) + return ConstantVariable.create(None) + + @register(torch.set_autocast_cache_enabled) + def handle_set_autocast_cache_enabled( + self, tx: "InstructionTranslator", enabled: VariableTracker + ) -> VariableTracker: + tx.output.create_node( + "call_function", torch.set_autocast_cache_enabled, (enabled.as_proxy(),) + ) + prev = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(enabled.as_python_constant()) + tx.output.add_cleanup_hook(lambda: torch.set_autocast_cache_enabled(prev)) + return ConstantVariable.create(None) + + @register(torch._C._functorch._grad_increment_nesting) + def handle_grad_increment_nesting( + self, tx: "InstructionTranslator" + ) -> VariableTracker: + tx.output.create_node( + "call_function", torch._C._functorch._grad_increment_nesting + ) + level = torch._C._functorch._grad_increment_nesting() + tx.output.add_cleanup_hook(torch._C._functorch._grad_decrement_nesting) + return VariableTracker.build(tx, level) + + @register(torch._C._functorch._grad_decrement_nesting) + def handle_grad_decrement_nesting( + self, tx: "InstructionTranslator" + ) -> VariableTracker: + tx.output.create_node( + "call_function", torch._C._functorch._grad_decrement_nesting + ) + level = torch._C._functorch._grad_decrement_nesting() + tx.output.add_cleanup_hook(torch._C._functorch._grad_increment_nesting) + return VariableTracker.build(tx, level) + + @register(torch._C._functorch.set_inplace_requires_grad_allowed) + def handle_set_inplace_requires_grad_allowed( + self, tx: "InstructionTranslator", allowed: VariableTracker + ) -> VariableTracker: + tx.output.create_node( + "call_function", + torch._C._functorch.set_inplace_requires_grad_allowed, + (allowed.as_proxy(),), + ) + prev = torch._C._functorch.get_inplace_requires_grad_allowed() + torch._C._functorch.set_inplace_requires_grad_allowed( + allowed.as_python_constant() + ) + tx.output.add_cleanup_hook( + lambda: torch._C._functorch.set_inplace_requires_grad_allowed(prev) + ) + return ConstantVariable.create(None) @register(torch.are_deterministic_algorithms_enabled) def handle_are_deterministic_algorithms_enabled( @@ -1213,7 +1375,7 @@ def handle_assert( isinstance(condition, variables.SymNodeVariable) and condition.evaluate_expr() ): - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) return None @register(SDPAParams) @@ -1358,8 +1520,26 @@ def handle_one_hot( ) return None - @register(torch.fx.experimental.symbolic_shapes.size_hint) - def handle_size_hint( + @register(torch.fx.experimental.symbolic_shapes.guarding_hint_or_throw) + def handle_guarding_hint_or_throw( + self, + tx: "InstructionTranslator", + expr: VariableTracker, + ) -> VariableTracker | None: + if isinstance(expr, SymNodeVariable): + return VariableTracker.build( + tx, + torch.fx.experimental.symbolic_shapes.guarding_hint_or_throw( + expr.sym_num + ), + ) + elif expr.is_python_constant(): + return expr + else: + return None + + @register(torch.fx.experimental.symbolic_shapes.optimization_hint) + def handle_optimization_hint( self, tx: "InstructionTranslator", expr: VariableTracker, @@ -1369,7 +1549,7 @@ def handle_size_hint( if isinstance(expr, SymNodeVariable): return VariableTracker.build( tx, - torch.fx.experimental.symbolic_shapes.size_hint( + torch.fx.experimental.symbolic_shapes.optimization_hint( expr.sym_num, fallback_int ), ) @@ -1646,14 +1826,14 @@ def handle_push_torch_function( **kwargs: VariableTracker, ) -> VariableTracker: if len(args) != 1 or kwargs: - raise_type_error_exc( + raise_type_error( tx, f"push_torch_function takes exactly one argument ({len(args)} given)", ) TorchFunctionModeStackVariable.register_mutation(tx) # type: ignore[arg-type] tx.symbolic_torch_function_state.push_torch_function_mode(args[0]) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) @register(torch._C._len_torch_function_stack) def handle_len_torch_function( @@ -1663,7 +1843,7 @@ def handle_len_torch_function( **kwargs: VariableTracker, ) -> VariableTracker: if args or kwargs: - raise_type_error_exc(tx, "len_torch_function_stack takes no arguments") + raise_type_error(tx, "len_torch_function_stack takes no arguments") return VariableTracker.build( tx, len(tx.symbolic_torch_function_state.mode_stack) ) @@ -1676,7 +1856,7 @@ def handle_get_stack_at( **kwargs: VariableTracker, ) -> TorchFunctionModeVariable: if len(args) != 1 or kwargs: - raise_type_error_exc( + raise_type_error( tx, f"get_function_stack_at takes exactly one argument ({len(args)} given)", ) @@ -1735,6 +1915,8 @@ def handle_current_stream( *args: VariableTracker, **kwargs: VariableTracker, ) -> StreamVariable: + from .streams import CudaStreamVariable + if len(args) + len(kwargs) > 1 or (kwargs and "device" not in kwargs): unimplemented( gb_type="unsupported arguments to torch.accelerator.current_stream", @@ -1752,7 +1934,17 @@ def handle_current_stream( else: device = None - return tx.symbolic_stream_state.cur_stream(device) + stream_var = tx.symbolic_stream_state.cur_stream(device) + if self.value is torch.cuda.current_stream and not isinstance( + stream_var, CudaStreamVariable + ): + stream_var = CudaStreamVariable( + stream_var.proxy, + stream_var.value, + stream_var.user_object_index, + source=stream_var.source, + ) + return stream_var except Exception as e: unimplemented( gb_type="bad device argument to torch.accelerator.current_stream", @@ -1762,6 +1954,53 @@ def handle_current_stream( from_exc=e, ) + _synchronize_fn_to_device_type = { + torch.cuda.synchronize: "cuda", + torch.xpu.synchronize: "xpu", + torch.mps.synchronize: "mps", + torch.cpu.synchronize: "cpu", + } + + @register( + torch.accelerator.synchronize, + torch.cuda.synchronize, + torch.xpu.synchronize, + torch.mps.synchronize, + torch.cpu.synchronize, + ) + def handle_synchronize( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + device = None + if kwargs and "device" in kwargs: + device = torch.device(kwargs["device"].as_python_constant()) + elif args: + device = torch.device(args[0].as_python_constant()) + + if device is None: + device_type = _synchronize_fn_to_device_type.get(self.value) + if device_type is None: + # torch.accelerator.synchronize with no args + accelerator = torch.accelerator.current_accelerator() + assert accelerator is not None + device_type = accelerator.type + device = torch.device(device_type) + + # CPU synchronize is a no-op, skip emitting the op + if device.type == "cpu": + return ConstantVariable.create(None) + + tx.output.create_proxy( + "call_function", + torch.ops.streams.synchronize_device, + (device.type, device.index or 0), + {}, + ) + return ConstantVariable.create(None) + @register(torch.set_default_device) def handle_set_default_device( self, @@ -1782,7 +2021,40 @@ def handle_set_default_device( else: TorchFunctionModeStackVariable.register_device_context_insertion(tx) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) + + from torch._prims_common import elementwise_dtypes + + @register(elementwise_dtypes) + def handle_elementwise_dtypes( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + from .builder import SourcelessBuilder + + type_promotion_kind = kwargs["type_promotion_kind"].as_python_constant() + real_args = [] + for arg in args: + if isinstance(arg, TensorVariable): + real_args.append(arg.as_proxy().node.meta["example_value"]) + elif arg.is_python_constant(): + real_args.append(arg.as_python_constant()) + else: + unimplemented( + gb_type="elementwise_dtypes unsupported arg type", + context=str(arg), + explanation=( + "elementwise_dtypes requires tensor or constant arguments, " + f"got {type(arg).__name__}" + ), + hints=[*graph_break_hints.SUPPORTABLE], + ) + result = elementwise_dtypes( + *real_args, type_promotion_kind=type_promotion_kind + ) + return SourcelessBuilder.create(tx, result) @register(torch._check) def handle_check( @@ -1849,7 +2121,7 @@ def handle_check( if predicate_vt.is_python_constant(): self.value(predicate_vt.as_python_constant(), message_eager) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) predicate_proxy = predicate_vt.as_proxy() @@ -1869,6 +2141,82 @@ def handle_check( ), ) + def exchange_device_helper( + tx: "InstructionTranslator", + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + fn: Callable[[int], int | None], + ) -> VariableTracker: + if len(args) != 1 or kwargs: + raise_type_error( + tx, + f"{fn.__name__} takes exactly one argument ({len(args)} given)", + ) + current_device_source = CallFunctionNoArgsSource( + AttrSource(AttrSource(ImportSource("torch"), "cuda"), "current_device") + ) + install_guard(current_device_source.make_guard(GuardBuilder.EQUALS_MATCH)) + arg = args[0].as_python_constant() + prev = fn(arg) + tx.output.create_node( + "call_function", + fn, + (arg,), + {}, + ) + tx.output.add_cleanup_hook(lambda: torch.cuda.set_device(prev)) + return VariableTracker.build(tx, prev) + + @register(torch.cuda._exchange_device) + def handle_exchange_device( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + return exchange_device_helper(tx, args, kwargs, torch.cuda._exchange_device) + + @register(torch.cuda._maybe_exchange_device) + def handle_maybe_exchange_device( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + return exchange_device_helper( + tx, args, kwargs, torch.cuda._maybe_exchange_device + ) + + @register(torch._dynamo.decorators.override_optimization_hint) + def handle_override_optimization_hint( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker: + from .constant import ConstantVariable + from .tensor import SymNodeVariable + + x_vt = args[0] + val_vt = args[1] + val = val_vt.as_python_constant() + + if x_vt.is_python_constant(): + torch._dynamo.decorators.override_optimization_hint( + x_vt.as_python_constant(), val + ) + return ConstantVariable.create(None) + + if isinstance(x_vt, SymNodeVariable): + torch._dynamo.decorators.override_optimization_hint(x_vt.sym_num, val) + return ConstantVariable.create(None) + + raise UserError( + UserErrorType.INVALID_INPUT, + "override_optimization_hint expects a SymInt or int argument, " + f"got {type(x_vt).__name__}", + ) + @register(torch.autograd.grad) def handle_autograd_grad(self, tx: "InstructionTranslator", *args, **kwargs): """ @@ -1947,17 +2295,21 @@ def fn(x): # consumed grad_fns of returned tensors. This gives better compile # coverage than failing the entire compile. if tx.speculation_log.graph_break_on_autograd_grad: + leaked = tx.speculation_log.autograd_grad_leaked_tensors + leaked_str = ", ".join(leaked) if leaked else "unknown" unimplemented( gb_type="autograd.grad consumed returned tensor's grad_fn", - context="", + context=f"Leaked output tensors: {leaked_str}", explanation=( "torch.autograd.grad() consumes grad_fns that are needed by tensors " "returned from this compiled function. This would cause 'backward " - "through graph a second time' errors." + "through graph a second time' errors.\n" + f" The following returned tensors have consumed grad_fns: {leaked_str}" ), hints=[ - "If you don't need to backward through the returned tensor, " - "call .detach() before returning: `return loss.detach()`", + f"Detach the problematic tensor(s) before returning: e.g. `{leaked[0]}.detach()`" + if leaked + else "Call .detach() on the tensor before returning.", "If you need to backward through the returned tensor, use retain_graph=True in autograd.grad().", ], ) @@ -2111,14 +2463,78 @@ def fn(x): ) tx.output.autograd_grad_consumed_grad_fns.update(non_leaf_consumed) - return wrap_fx_proxy( - tx=tx, - proxy=tx.output.create_proxy( + with ( + torch.fx.traceback.preserve_node_meta(), + torch.fx.traceback._set_autograd_backward(), + ): + proxy = tx.output.create_proxy( "call_function", torch.autograd.grad, *proxy_args_kwargs(args, kwargs), - ), - ) + ) + return wrap_fx_proxy(tx=tx, proxy=proxy) + + @register(torch._functorch.eager_transforms._autograd_grad) + def handle_functorch_autograd_grad( + self, + tx: "InstructionTranslator", + *args: VariableTracker, + **kwargs: VariableTracker, + ) -> VariableTracker | None: + """Graph-break when closure-captured tensors lose their grad_fn. + + NOTE [Detecting lost autograd linkage in closure-captured tensors] + + Functorch transforms (vjp, grad, jacrev) return closures that capture + tensors with grad_fn. When such a closure is compiled separately, those + tensors become graph placeholders whose FakeTensors lose grad_fn, + causing _autograd_grad to silently return zeros. + + _collect_placeholder_nodes gathers placeholder nodes from the + outputs/inputs args. For each, we compare grapharg.example (the real + tensor, retains grad_fn) against example_value (FakeTensor, grad_fn + is None). A mismatch means autograd linkage was lost, so we + graph-break. + + This is a pre-check only: kwargs (retain_graph, create_graph, + grad_outputs) don't affect linkage detection and are handled by + the default proxy path when this returns None. + """ + outputs_var = args[0] if len(args) >= 1 else None + inputs_var = args[1] if len(args) >= 2 else None + + if outputs_var is None or inputs_var is None: + return None + + output_placeholder_nodes = _collect_placeholder_nodes(outputs_var) + input_placeholder_nodes = _collect_placeholder_nodes(inputs_var) + + if output_placeholder_nodes and input_placeholder_nodes: + for node in output_placeholder_nodes: + fake = node.meta.get("example_value") + grapharg = node.meta.get("grapharg") + if ( + grapharg is not None + and isinstance(fake, torch.Tensor) + and fake.grad_fn is None + ): + real = grapharg.example + if isinstance(real, torch.Tensor) and real.grad_fn is not None: + unimplemented( + gb_type="_autograd_grad with lost grad_fn linkage", + context="outputs lost autograd linkage during tracing", + explanation=( + "_autograd_grad() received tensors whose grad_fn " + "was lost during tracing - this silently produces " + "zero gradients." + ), + hints=[ + "Compile the full transform instead of the returned " + "closure: torch.compile(lambda x: torch.func.vjp(f, x))", + *graph_break_hints.SUPPORTABLE, + ], + ) + return None return handlers @@ -2161,7 +2577,7 @@ def call_function( raise_observed_exception( type(exc), tx, - args=[VariableTracker.build(tx, a) for a in exc.args], + args=list(exc.args), ) if self.is_tensor_method(): @@ -2748,6 +3164,11 @@ def _call_leaf_function( real_impl_callable = _LeafCallable(wrapped_real_impl) fake_impl_callable = _LeafCallable(wrapped_fake_impl) + hook_fn = getattr(decorated_fn, "_torchdynamo_leaf_hook_fn", None) + if hook_fn is not None: + hook_fake_fn = getattr(decorated_fn, "_torchdynamo_leaf_hook_fake_fn", None) + real_impl_callable._leaf_hook_real_fn = hook_fn # type: ignore[attr-defined] + real_impl_callable._leaf_hook_fake_fn = hook_fake_fn # type: ignore[attr-defined] def make_callable_proxy(name: str, spec: Any) -> Any: proxy = tx.output.register_static_attr_and_return_proxy(name, spec) diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 99a20f19d0146..699371b029af4 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -57,7 +57,7 @@ set_torch_function_mode_stack, ) from .base import VariableTracker -from .constant import CONSTANT_VARIABLE_NONE +from .constant import ConstantVariable from .ctx_manager import GenericContextWrappingVariable from .functions import UserMethodVariable from .lazy import LazyVariableTracker @@ -203,12 +203,12 @@ def enter(self, tx: "InstructionTranslator") -> VariableTracker: from .torch import TorchInGraphFunctionVariable if isinstance(self.value, NoEnterTorchFunctionMode): - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) TorchInGraphFunctionVariable( torch._C._push_on_torch_function_stack ).call_function(tx, [self], {}) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) def exit(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker: from .torch import TorchInGraphFunctionVariable @@ -216,7 +216,7 @@ def exit(self, tx: "InstructionTranslator", *args: Any) -> VariableTracker: TorchInGraphFunctionVariable(torch._C._pop_torch_function_stack).call_function( tx, [], {} ) - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) def reconstruct_type(self, codegen: "PyCodegen") -> None: ty = NoEnterTorchFunctionMode @@ -489,7 +489,7 @@ def call_torch_function( # # Also notice the `cls` is not explicitly passed in the reference # implementations: - # 1. https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/csrc/utils/python_arg_parser.cpp#L368-L374 # noqa: B950 + # 1. https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/csrc/utils/python_arg_parser.cpp#L368-L374 # 2. https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/overrides.py#L1741-L1743 tf_args = [ fn, @@ -576,7 +576,7 @@ def dispatch_torch_function( unimplemented( gb_type="All __torch_function__ overrides returned NotImplemented due to TypeError from user code", context=f"{fn=}, {args=}, {kwargs=}", - explanation=f"All __torch_function__ overrides for for function {fn} returned NotImplemented", + explanation=f"All __torch_function__ overrides for function {fn} returned NotImplemented", hints=[ *graph_break_hints.USER_ERROR, ], diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 5cb2718f034ad..f6b5aecfa054a 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -29,7 +29,6 @@ import enum import functools import inspect -import itertools import random import sys import threading @@ -45,6 +44,7 @@ import torch.nn from torch._guards import Source, TracingContext from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type +from torch.utils._pytree import GetAttrKey, is_structseq_class from .. import config, graph_break_hints, polyfills, variables from ..bytecode_transformation import create_call_function @@ -56,6 +56,7 @@ ObservedTypeError, ObservedUserStopIteration, raise_observed_exception, + raise_type_error, unimplemented, ) from ..graph_bytecode_inputs import get_external_object_by_index @@ -63,7 +64,6 @@ from ..source import ( AttrSource, CallFunctionNoArgsSource, - DataclassFieldsSource, DictGetItemSource, GetItemSource, RandomValueSource, @@ -76,11 +76,9 @@ check_constant_args, cmp_name_to_op_mapping, dict_methods, - enum_type_methods, frozenset_methods, get_custom_getattr, has_torch_function, - is_frozen_dataclass, is_lru_cache_wrapped_function, is_namedtuple_cls, is_wrapper_or_member_descriptor, @@ -96,8 +94,10 @@ tuple_methods, unpatched_nn_module_getattr, ) -from .base import MutationType, raise_type_error_exc, ValueMutationNew, VariableTracker -from .dicts import ConstDictVariable, DefaultDictVariable, SetVariable +from .base import MutationType, NO_SUCH_SUBOBJ, ValueMutationNew, VariableTracker +from .dicts import ConstDictVariable +from .hashable import HashableTracker +from .sets import SetVariable try: @@ -155,6 +155,14 @@ def is_cython_function(obj: object) -> bool: ) +def is_pydantic_dataclass_cls(value: object) -> bool: + return ( + inspect.isclass(value) + and dataclasses.is_dataclass(value) + and "__is_pydantic_dataclass__" in getattr(value, "__dict__", {}) + ) + + # Types whose instances are data descriptors (have __get__ + (__set__ or __delete__)). # CPython invokes data descriptors found on the type MRO *before* checking # the instance __dict__. This set is used by is_data_descriptor for a fast @@ -270,109 +278,295 @@ def supported_c_new_functions() -> set[Any]: frozenset.__new__, tuple.__new__, list.__new__, + int.__new__, + float.__new__, + str.__new__, } return c_new_fns.union(exceptions) @staticmethod def is_supported_new_method(value: object) -> bool: - # TODO(anijain2305) - Extend this to support objects with default tp_new - # functions. - return value in UserDefinedClassVariable.supported_c_new_functions() + if value in UserDefinedClassVariable.supported_c_new_functions(): + return True + # Structseq types each define their own C tp_new. + owner = getattr(value, "__self__", None) + return isinstance(owner, type) and is_structseq_class(owner) def can_constant_fold_through(self) -> bool: - return self.value in self._constant_fold_classes() + if self.value in self._constant_fold_classes(): + return True + # Enum class calls (e.g., Color(1)) are value lookups that return + # existing singleton members, so they can always be constant-folded. + return isinstance(self.value, type) and issubclass(self.value, enum.Enum) - def has_key_in_generic_dict(self, tx: "InstructionTranslator", key: str) -> bool: - if tx.output.side_effects.has_pending_mutation_of_attr(self, key): - mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) - return not isinstance(mutated_attr, variables.DeletedVariable) + def lookup_cls_mro_attr(self, name: str) -> object: + """Walk cls.__mro__ only (not the metaclass chain) to find *name*.""" + for base in self.value.__mro__: + if name in base.__dict__: + return base.__dict__[name] + return NO_SUCH_SUBOBJ - return key in self.value.__dict__ + def lookup_metaclass_attr(self, name: str) -> object: + """Walk type(cls).__mro__ (the metaclass chain) to find *name*.""" + for base in type(self.value).__mro__: + if name in base.__dict__: + return base.__dict__[name] + return NO_SUCH_SUBOBJ - def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: - from . import ConstantVariable + def bool_impl( + self, + tx: "InstructionTranslator", + ) -> "VariableTracker": + from .constant import ConstantVariable + # bool() on a class consults the metaclass __bool__. + # If the metaclass is the default `type`, all classes are truthy. + metaclass = type(self.value) + if hasattr(metaclass, "__bool__") and metaclass is not type: + return self.call_method(tx, "__bool__", [], {}) + return ConstantVariable.create(True) + + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: source = AttrSource(self.source, name) if self.source is not None else None - if name == "__name__": - return VariableTracker.build(tx, self.value.__name__) - elif name == "__qualname__": - return VariableTracker.build(tx, self.value.__qualname__) - elif name == "__dict__": - options = {"source": source} - return variables.GetAttrVariable(self, name, None, **options) - elif name == "__mro__": - attr_source = self.source and TypeMROSource(self.source) - return VariableTracker.build(tx, self.value.__mro__, attr_source) + # --- Dynamo-specific pre-checks --- - # Special handling of collections.OrderedDict.fromkeys() - # Wrap it as GetAttrVariable(collections.OrderedDict, "fromkeys") to make it consistent with - # collections.defaultdict, and both will be handled at UserDefinedClassVariable.call_method(). - # Otherwise, it would be wrapped as UserDefinedObjectVariable(collections.OrderedDict.fromkeys), - # and we need duplicate code to handle both cases. + # Wrap OrderedDict/defaultdict.fromkeys as GetAttrVariable so it's + # handled uniformly in call_method(). if ( self.value in {collections.OrderedDict, collections.defaultdict} and name == "fromkeys" ): return super().var_getattr(tx, name) - obj = None - try: - obj = inspect.getattr_static(self.value, name) - except AttributeError: - if type(self.value) is type: - error_message = VariableTracker.build( - tx, f"type object '{self.value.__name__}' has no attribute '{name}'" - ) - raise_observed_exception( - AttributeError, - tx, - args=[error_message], + # Custom metaclasses that override __getattribute__ replace the entire + # lookup algorithm; bail out for those. Standard metaclasses (ABCMeta, + # EnumType, etc.) that don't override __getattribute__ use + # type.__getattribute__ which is the algorithm we implement below. + metacls = type(self.value) + if metacls is not type and "__getattribute__" in metacls.__dict__: + unimplemented( + gb_type="Custom metaclass with __getattribute__", + context=f"type({self.value}) = {metacls}", + explanation="Dynamo does not trace attribute access on classes whose " + "metaclass overrides __getattribute__", + hints=graph_break_hints.SUPPORTABLE, + ) + + # ---- CPython type_getattro algorithm ---- + # https://github.com/python/cpython/blob/3.13/Objects/typeobject.c#L5417-L5505 + # 1. meta_attr = lookup name in type(cls).__mro__ (metaclass chain) + # 2. if meta_attr is a DATA descriptor → invoke + # 3. cls_attr = lookup name in cls.__mro__ (class chain) + # 4. if cls_attr has __get__ → invoke cls_attr.__get__(None, cls) + # 5. if cls_attr exists (plain) → return as-is + # 6. if meta_attr is a non-data descriptor or plain → return + # 7. raise AttributeError + + # Step 1-2: Metaclass data descriptors. + # For type(cls) is type, these are C-level getset/member descriptors + # for __dict__, __mro__, __name__, __qualname__, __doc__, etc. + meta_attr = self.lookup_metaclass_attr(name) + if meta_attr is not NO_SUCH_SUBOBJ and is_data_descriptor(meta_attr): + return self.resolve_meta_data_descriptor(tx, name, meta_attr, source) + + # Step 3-5: Class MRO lookup. + cls_attr = self.lookup_cls_mro_attr(name) + if cls_attr is not NO_SUCH_SUBOBJ: + if hasattr(type(cls_attr), "__get__"): + # Step 4: Descriptor — invoke __get__(None, cls). + return self.resolve_cls_descriptor(tx, name, cls_attr, source) + # Step 5: Plain attribute. + return self.resolve_cls_plain_attr(tx, name, cls_attr, source) + + # Step 6: Metaclass non-data descriptor or plain attr. + # These are non-data descriptors on the metaclass (e.g. type.__call__, + # type.__subclasses__, type.mro). We use GetAttrVariable to defer to + # runtime rather than VariableTracker.build, because build would create + # a variable for the raw C-level descriptor which then fails when + # called (e.g. type.__subclasses__ is a method_descriptor that dynamo + # can't trace). GetAttrVariable defers the access and lets + # call_method handle it. + if meta_attr is not NO_SUCH_SUBOBJ: + return variables.GetAttrVariable(self, name, type(meta_attr), source=source) + + # __getattr__ on metaclass (not part of type_getattro proper — + # CPython handles this via slot_tp_getattr_hook). + metacls = type(self.value) + if metacls is not type: + meta_getattr = self.lookup_metaclass_attr("__getattr__") + if meta_getattr is not NO_SUCH_SUBOBJ and isinstance( + meta_getattr, types.FunctionType + ): + return variables.UserMethodVariable(meta_getattr, self).call_function( + tx, [variables.ConstantVariable.create(name)], {} ) - if name == "__new__" and UserDefinedClassVariable.is_supported_new_method(obj): - return super().var_getattr(tx, name) + # Step 7: AttributeError. + raise_observed_exception( + AttributeError, + tx, + args=[f"type object '{self.value.__name__}' has no attribute '{name}'"], + ) + + def resolve_meta_data_descriptor( + self, + tx: "InstructionTranslator", + name: str, + meta_attr: object, + source: Source | None, + ) -> VariableTracker: + """Handle data descriptors from the metaclass MRO (type.__dict__ slots).""" + if name == "__dict__": + return VariableTracker.build( + tx, + self.value.__dict__, + source=self.source and AttrSource(self.source, "__dict__"), + ) + if name == "__mro__": + attr_source = self.source and TypeMROSource(self.source) + return VariableTracker.build(tx, self.value.__mro__, attr_source) + # __name__, __qualname__, __doc__, __module__, __bases__, + # __abstractmethods__, etc. — all C-level getset descriptors on type. + resolved = type.__getattribute__(self.value, name) + if source: + return VariableTracker.build(tx, resolved, source) + from . import ConstantVariable + + if ConstantVariable.is_literal(resolved): + return VariableTracker.build(tx, resolved) + return variables.GetAttrVariable(self, name, type(resolved), source=source) - if name in cmp_name_to_op_mapping and not isinstance(obj, types.FunctionType): - return variables.GetAttrVariable(self, name, None, source=source) + def resolve_cls_descriptor( + self, + tx: "InstructionTranslator", + name: str, + cls_attr: object, + source: Source | None, + ) -> VariableTracker: + """Handle descriptors found in cls.__mro__.""" + if isinstance(cls_attr, staticmethod): + return VariableTracker.build(tx, cls_attr.__get__(self.value), source) - if isinstance(obj, staticmethod): - return VariableTracker.build(tx, obj.__get__(self.value), source) - elif isinstance(obj, classmethod): - if isinstance(obj.__func__, property): - fget_vt = VariableTracker.build(tx, obj.__func__.fget) + if isinstance(cls_attr, classmethod): + if isinstance(cls_attr.__func__, property): + fget_vt = VariableTracker.build(tx, cls_attr.__func__.fget) return fget_vt.call_function(tx, [self], {}) - return variables.UserMethodVariable(obj.__func__, self, source=source) - elif isinstance(obj, types.ClassMethodDescriptorType): - # e.g.: inspect.getattr_static(dict, "fromkeys") - # inspect.getattr_static(itertools.chain, "from_iterable") - func = obj.__get__(None, self.value) + return variables.UserMethodVariable(cls_attr.__func__, self, source=source) + + if isinstance(cls_attr, types.ClassMethodDescriptorType): + func = cls_attr.__get__(None, self.value) return VariableTracker.build(tx, func, source) - elif source: - if inspect.ismemberdescriptor(obj): - return VariableTracker.build(tx, obj.__get__(self.value), source) - - if ConstantVariable.is_literal(obj): - return VariableTracker.build(tx, obj) - elif isinstance(obj, enum.Enum): - return VariableTracker.build(tx, obj, source) - elif self.value is collections.OrderedDict: - return variables.GetAttrVariable(self, name) - elif name in getattr(self.value, "__dict__", {}) or ( - self.value.__module__.startswith("torch.") - or self.value.__module__ == "torch" - ): + + # property and _tuplegetter accessed on the class return the + # descriptor itself (descriptor.__get__(None, cls) is descriptor). + # Build directly — no need to invoke __get__. + if isinstance(cls_attr, (property, _collections._tuplegetter)): if source: - return VariableTracker.build(tx, obj, source) + return VariableTracker.build(tx, cls_attr, source) + return UserDefinedObjectVariable(cls_attr) - if ( - source - and not inspect.ismethoddescriptor(obj) - and not is_wrapper_or_member_descriptor(obj) + # Comparison dunders inherited from object — defer to runtime. + if name in cmp_name_to_op_mapping and not isinstance( + cls_attr, types.FunctionType ): - return VariableTracker.build(tx, obj, source) + return variables.GetAttrVariable( + self, name, py_type=type(cls_attr), source=source + ) - return super().var_getattr(tx, name) + # User-defined descriptor with Python __get__. + # For torch-internal classes or attributes in the class's own __dict__, + # defer descriptor invocation to runtime via VariableTracker.build to + # avoid compile-time side effects (e.g. deprecation warnings from + # _ClassPropertyDescriptor on torch.FloatStorage.dtype). + get_fn = inspect.getattr_static(type(cls_attr), "__get__", None) + if isinstance(get_fn, types.FunctionType): + if source and ( + name in getattr(self.value, "__dict__", {}) + or self.value.__module__.startswith("torch.") + or self.value.__module__ == "torch" + ): + return VariableTracker.build(tx, cls_attr, source) + return self.invoke_cls_descriptor_get(tx, name, cls_attr, source) + + # C-level descriptors (WrapperDescriptor, MethodDescriptor, etc.) + # Build directly when the attribute lives in the class's own __dict__ + # or the class belongs to torch (needed for e.g. torch.Tensor.dim). + # OrderedDict's C-level methods are handled at runtime. + if inspect.ismethoddescriptor(cls_attr) or is_wrapper_or_member_descriptor( + cls_attr + ): + if ( + source + and self.value is not collections.OrderedDict + and ( + name in getattr(self.value, "__dict__", {}) + or self.value.__module__.startswith("torch.") + or self.value.__module__ == "torch" + ) + ): + return VariableTracker.build(tx, cls_attr, source) + return variables.GetAttrVariable(self, name, type(cls_attr), source=source) + + # Everything else: FunctionType, etc. + return VariableTracker.build(tx, cls_attr, source) + + def resolve_cls_plain_attr( + self, + tx: "InstructionTranslator", + name: str, + cls_attr: object, + source: Source | None, + ) -> VariableTracker: + """Handle non-descriptor attributes from cls.__mro__.""" + if name == "__new__" and UserDefinedClassVariable.is_supported_new_method( + cls_attr + ): + return super().var_getattr(tx, name) + if self.value is collections.OrderedDict: + return variables.GetAttrVariable(self, name, py_type=type(cls_attr)) + return VariableTracker.build(tx, cls_attr, source) + + def invoke_cls_descriptor_get( + self, + tx: "InstructionTranslator", + name: str, + descriptor: object, + source: Source | None, + ) -> VariableTracker: + """Trace a class-MRO descriptor's __get__(None, cls) call.""" + from .constant import ConstantVariable + + descriptor_source = None + descriptor_get_source = None + if self.source: + descriptor_source = AttrSource(self.source, name) + descriptor_get_source = AttrSource(TypeSource(descriptor_source), "__get__") + descriptor_var = VariableTracker.build(tx, descriptor, descriptor_source) + else: + descriptor_var = UserDefinedObjectVariable(descriptor) + + none_var = ConstantVariable.create(None) + return variables.UserMethodVariable( + descriptor.__get__.__func__, # type: ignore[union-attr] + descriptor_var, + source=descriptor_get_source, + ).call_function(tx, [none_var, self], {}) + + def len_impl(self, tx: "InstructionTranslator") -> VariableTracker: + m = self._maybe_get_baseclass_method("__len__") + if m: + source = self.source and AttrSource(self.source, "__len__") + return variables.UserMethodVariable( + m, self, source_fn=source + ).call_function(tx, [], {}) + raise_type_error(tx, f"object of type {self.python_type_name()} has no length") + + def sq_length(self, tx: "InstructionTranslator") -> VariableTracker: + return self.len_impl(tx) + + def mp_length(self, tx: "InstructionTranslator") -> VariableTracker: + return self.len_impl(tx) def _call_cross_entropy_loss( self, @@ -389,13 +583,13 @@ def _call_cross_entropy_loss( non functional loss call: input, target, optional_output """ - from . import CONSTANT_VARIABLE_NONE, ConstantVariable + from . import ConstantVariable def normalize_args( - weight: VariableTracker = CONSTANT_VARIABLE_NONE, - size_average: VariableTracker = CONSTANT_VARIABLE_NONE, + weight: VariableTracker = ConstantVariable.create(None), + size_average: VariableTracker = ConstantVariable.create(None), ignore_index: VariableTracker = ConstantVariable.create(-100), - reduce: VariableTracker = CONSTANT_VARIABLE_NONE, + reduce: VariableTracker = ConstantVariable.create(None), reduction: VariableTracker = ConstantVariable.create("mean"), label_smoothing: VariableTracker = ConstantVariable.create(0.0), ) -> tuple[VariableTracker, ...]: @@ -469,11 +663,15 @@ def call_method( self.value in {collections.OrderedDict, collections.defaultdict} and name == "fromkeys" ): - return variables.BuiltinVariable.call_custom_dict_fromkeys( + return variables.DictBuiltinVariable.call_custom_dict_fromkeys( tx, self.value, *args, **kwargs ) elif self.value is collections.OrderedDict and name == "move_to_end": return args[0].call_method(tx, name, [*args[1:]], kwargs) + elif name == "__len__" and len(args) == 1 and not kwargs: + from .object_protocol import generic_len + + return generic_len(tx, args[0]) elif name == "__eq__" and len(args) == 1 and hasattr(args[0], "value"): return VariableTracker.build(tx, self.value == args[0].value) elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"): @@ -486,22 +684,6 @@ def call_method( elif issubclass(self.value, (set, frozenset)) and name != "__new__": # __new__ is handled below return SourcelessBuilder.create(tx, set).call_method(tx, name, args, kwargs) - elif ( - name == "__new__" - and self.value is collections.OrderedDict - and isinstance(args[0], UserDefinedClassVariable) - and args[0].value is collections.OrderedDict - ): - if kwargs and len(args) != 1: - raise_args_mismatch( - tx, - name, - "1 args and 0 kwargs", - f"{len(args)} args and {len(kwargs)} kwargs", - ) - return ConstDictVariable( - {}, collections.OrderedDict, mutation_type=ValueMutationNew() - ) elif ( len(args) == 1 and isinstance(args[0], variables.GenericContextWrappingVariable) @@ -511,10 +693,20 @@ def call_method( elif name == "__new__" and UserDefinedClassVariable.is_supported_new_method( self.value.__new__ ): + # Some C-level tp_new functions (dict.__new__, set.__new__) ignore + # extra args — only the type arg matters. Pass init_args=[] for + # those so reconstruction emits base_cls.__new__(cls) without + # unreconstructable args (e.g. generators). Other tp_new functions + # (tuple.__new__, BaseException.__new__) use the extra args. + new_fn = self.value.__new__ + if new_fn in (dict.__new__, set.__new__): + init_args: list[VariableTracker] = [] + else: + init_args = list(args[1:]) return tx.output.side_effects.track_new_user_defined_object( self, args[0], - args[1:], + init_args, ) elif name == "__setattr__" and self.ban_mutation: unimplemented( @@ -523,8 +715,31 @@ def call_method( explanation="Dyanmo does not support tracing mutations on a class when its __dict__ is materialized", hints=graph_break_hints.SUPPORTABLE, ) + + # Dispatch dunder methods defined on the metaclass (e.g., EnumType.__contains__). + # In Python, `x in Color` calls `type(Color).__contains__(Color, x)`. + metaclass = type(self.value) + if metaclass is not type: + # Look up the method on the metaclass MRO, not the class MRO + for klass in metaclass.__mro__: + if name in klass.__dict__: + method = klass.__dict__[name] + if isinstance(method, types.FunctionType): + source = self.source and AttrSource(self.source, name) + return variables.UserMethodVariable( + method, self, source=source + ).call_function(tx, args, kwargs) + break + return super().call_method(tx, name, args, kwargs) + def unpack_var_sequence( + self, tx: "InstructionTranslator" + ) -> list["VariableTracker"]: + if isinstance(self.value, type) and issubclass(self.value, enum.Enum): + return [VariableTracker.build(tx, item) for item in self.value] + raise NotImplementedError + def call_function( self, tx: "InstructionTranslator", @@ -547,7 +762,6 @@ def call_function( "Set TORCHDYNAMO_ENABLE_P2P_COMPILATION=1 to enable.", ], ) - var = tx.output.side_effects.track_new_user_defined_object( SourcelessBuilder.create(tx, object), self, @@ -572,33 +786,19 @@ def call_function( from .ctx_manager import NullContextVariable return NullContextVariable(*args, **kwargs) - elif self.value is collections.OrderedDict: - return tx.inline_user_function_return( - VariableTracker.build(tx, polyfills.construct_dict), - [self, *args], - kwargs, - ) elif self.value is collections.defaultdict: - if len(args) == 0: - default_factory = variables.CONSTANT_VARIABLE_NONE - elif len(args) == 1: - # In the case the argument is a builtin, then we can take the callable as the factory method. - # Otherwise, it must be a ConstantVariable holding None. - if not DefaultDictVariable.is_supported_arg(args[0]): - raise_observed_exception(TypeError, tx, args=[args[0]]) - default_factory = args[0] - args = [] - else: - default_factory, *args = args - dict_vt = variables.BuiltinVariable.call_custom_dict( - tx, dict, *args, **kwargs - ) - return DefaultDictVariable( - dict_vt.items, # type: ignore[attr-defined] - collections.defaultdict, - default_factory, - mutation_type=ValueMutationNew(), + # defaultdict construction — use track_new_user_defined_object + # which creates DefaultDictVariable. __init__ handler extracts + # default_factory and populates items. + from .builder import SourcelessBuilder + + result = tx.output.side_effects.track_new_user_defined_object( + SourcelessBuilder.create(tx, dict), + self, + [], ) + result.call_method(tx, "__init__", list(args), kwargs) + return result elif is_typeddict(self.value): if self.value.__optional_keys__: # type: ignore[attr-defined] unimplemented( @@ -610,9 +810,11 @@ def call_function( *graph_break_hints.SUPPORTABLE, ], ) - return SourcelessBuilder.create(tx, dict).call_dict(tx, *args, **kwargs) + return variables.DictBuiltinVariable.call_custom_dict( + tx, dict, *args, **kwargs + ) elif self.value is collections.deque: - maxlen = variables.CONSTANT_VARIABLE_NONE + maxlen = variables.ConstantVariable.create(None) def deque_signature( iterable: Iterable[Any] | None = None, maxlen: int | None = None @@ -675,7 +877,7 @@ def deque_signature( if len(args) > 1: callback = args[1] else: - callback = variables.CONSTANT_VARIABLE_NONE + callback = variables.ConstantVariable.create(None) return variables.WeakRefVariable(args[0], callback) elif self.value is functools.partial: if not args: @@ -700,9 +902,7 @@ def deque_signature( return variables.CatchWarningsCtxManagerVariable.create(tx, kwargs) elif self.value is torch.cuda.device and not kwargs and len(args) == 1: if not args[0].is_python_constant(): - raise_type_error_exc( - tx, "torch.cuda.device() requires a constant argument" - ) + raise_type_error(tx, "torch.cuda.device() requires a constant argument") return variables.CUDADeviceVariable.create(tx, args[0].as_python_constant()) elif ( issubclass(type(self.value), type) @@ -792,17 +992,15 @@ def deque_signature( ) ] + args[1:] - cm_obj = tx.output.side_effects.track_new_user_defined_object( - SourcelessBuilder.create(tx, object), - self, - arg_new, # type: ignore[arg-type] + return tx.inline_user_function_return( + VariableTracker.build( + tx, polyfills.instantiate_user_defined_class_object + ), + [self, *arg_new], + kwargs, ) - cm_obj.call_method(tx, "__init__", arg_new, kwargs) # type: ignore[arg-type] - return cm_obj elif is_namedtuple_cls(self.value): - fields = namedtuple_fields(self.value) # type: ignore[arg-type] - # check if this a quasi-namedtuple or a real one - if self.value.__module__ == "torch.return_types": + if is_structseq_class(self.value): if kwargs or len(args) != 1: raise_args_mismatch( tx, @@ -810,91 +1008,42 @@ def deque_signature( "1 args and 0 kwargs", f"{len(args)} args and {len(kwargs)} kwargs", ) - items = args[0].force_unpack_var_sequence(tx) + # Structseq tp_new is a C function, so we can't trace into + # it like namedtuples. Use track_new_user_defined_object + # directly with self as both base_cls_vt and cls_vt. + return tx.output.side_effects.track_new_user_defined_object( + self, + self, + list(args), + ) else: - field_defaults = self.value._field_defaults # type: ignore[attr-defined] - - items = list(args) - # pyrefly: ignore[bad-argument-type] - items.extend([None] * (len(fields) - len(items))) - - var_tracker_kwargs: dict[str, VariableTracker] = {} - for field_name, var_tracker in zip(fields, items): - if var_tracker is None: - if field_name in kwargs: - field_var = kwargs[field_name] - else: - assert field_name in field_defaults - field_var = VariableTracker.build( - tx, field_defaults[field_name] - ) - var_tracker_kwargs[field_name] = field_var - - for name, value in var_tracker_kwargs.items(): - assert name in fields - items[fields.index(name)] = value # type: ignore[call-overload] - - assert all(x is not None for x in items) - - # Modify mutability of namedtuple for sourcelesss instantiations. - from .base import AttributeMutationNew - from .lists import NamedTupleVariable - - return NamedTupleVariable( - items, - self.value, # type: ignore[arg-type] - mutation_type=AttributeMutationNew(), - ) + # Namedtuple __new__ is a Python function that calls + # tuple.__new__(cls, (field_values,)). Let Dynamo trace + # into it so default values and kwargs are handled by + # the generated __new__ itself. + return tx.inline_user_function_return( + VariableTracker.build( + tx, polyfills.instantiate_user_defined_class_object + ), + [self, *args], + kwargs, + ) elif self.value is torch.Size: # This simulates `THPSize_pynew`, the C impl for `Size.__new__`. from .lists import SizeVariable tup = SourcelessBuilder.create(tx, tuple).call_function(tx, args, kwargs) return SizeVariable(tup.items) # type: ignore[missing-attribute] - elif is_frozen_dataclass(self.value) and self.is_standard_new(): - fields = dataclasses.fields(self.value) # type: ignore[arg-type] - assert self.source is not None - fields_source = DataclassFieldsSource(self.source) - items = list(args) - items.extend([None] * (len(fields) - len(items))) # type: ignore[arg-type] - - default_kwargs = {} - for ind, field, var_tracker in zip(itertools.count(), fields, items): - if var_tracker is None: - if field.name in kwargs: - var_tracker = kwargs[field.name] - else: - if not field.init: - continue - - if field.default is not dataclasses.MISSING: - var_tracker = VariableTracker.build( - tx, - field.default, - source=AttrSource( - GetItemSource(fields_source, ind), "default" - ), - ) - elif field.default_factory is not dataclasses.MISSING: - factory_fn = VariableTracker.build( - tx, field.default_factory - ) - var_tracker = factory_fn.call_function(tx, [], {}) - else: - # if we are subclass, the constructor could possibly - # be missing args - continue - - default_kwargs[field.name] = var_tracker - kwargs.update(default_kwargs) - - var = tx.output.side_effects.track_new_user_defined_object( - SourcelessBuilder.create(tx, object), - self, - args, # type: ignore[arg-type] + elif is_pydantic_dataclass_cls(self.value): + # Pydantic populates dataclass fields through an external validator, + # so tracing through the constructor misses the instance mutations. + unimplemented( + gb_type="Pydantic dataclass constructor", + context=f"{self.value}", + explanation="Dynamo graph breaks on pydantic dataclass constructors " + "because validation mutates the instance outside traced bytecode.", + hints=graph_break_hints.SUPPORTABLE, ) - var.call_method(tx, "__init__", args, kwargs) # type: ignore[arg-type] - return var elif ( self.value in self._in_graph_classes() or is_traceable_wrapper_subclass_type(self.value) @@ -993,23 +1142,24 @@ def deque_signature( seed = None random_object = random.Random(seed) return RandomVariable(random_object) - elif ( - self.value is types.MappingProxyType - and len(args) == 1 - and isinstance(args[0], ConstDictVariable) - ): + elif self.value is types.MappingProxyType and len(args) == 1: # types.MappingProxyType is a read-only proxy of the dict. If the # original dict changes, the changes are reflected in proxy as well. - return variables.MappingProxyVariable(args[0]) + dict_arg = args[0] + if isinstance(dict_arg, variables.UserDefinedDictVariable): + dict_arg = dict_arg._base_vt + if isinstance(dict_arg, ConstDictVariable): + return variables.MappingProxyVariable(dict_arg) elif SideEffects.cls_supports_mutation_side_effects(self.value) and self.source: with do_not_convert_to_tracable_parameter(): - return tx.inline_user_function_return( + result = tx.inline_user_function_return( VariableTracker.build( tx, polyfills.instantiate_user_defined_class_object ), [self, *args], kwargs, ) + return result return super().call_function(tx, args, kwargs) @@ -1029,8 +1179,7 @@ def call_obj_hasattr( functools.partial(GuardBuilder.HASATTR, attr=name) ) ) - return VariableTracker.build(tx, hasattr(self.value, name)) - return super().call_obj_hasattr(tx, name) + return VariableTracker.build(tx, hasattr(self.value, name)) def const_getattr(self, tx: "InstructionTranslator", name: str) -> Any: if name == "__name__": @@ -1049,71 +1198,34 @@ def is_python_equal(self, other: object) -> bool: and self.value is other.value ) + def get_real_python_backed_value(self) -> object: + return self.value + class UserDefinedExceptionClassVariable(UserDefinedClassVariable): @property def fn(self) -> type[object]: return self.value - -class UserDefinedEnumClassVariable(UserDefinedClassVariable): - """ - Represents Enum class objects (the class itself, not instances). - - Handles Enum metaclass methods like __contains__ by checking if the method - is from the standard EnumType metaclass and executing it directly. - - Not yet supported: - - Flag enum membership checks (e.g., `Flag.A in combined_flags`) - """ - - # pyrefly: ignore[bad-override] - value: type[enum.Enum] - - def call_method( + def call_function( self, tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], + args: Sequence[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - method = self._maybe_get_baseclass_method(name) - if method in enum_type_methods: - if name == "__contains__" and len(args) == 1 and not kwargs: - arg = args[0] - if isinstance(arg, variables.EnumVariable): - # Check if the enum value is a member of this enum class - return VariableTracker.build(tx, arg.value in self.value) - elif arg.is_python_constant(): - return VariableTracker.build( - tx, arg.as_python_constant() in self.value - ) - elif isinstance(method, types.FunctionType): - if name == "__contains__" and len(args) == 1 and not kwargs: - source = self.source and AttrSource(self.source, name) - return variables.UserMethodVariable( - method, self, source=source - ).call_function(tx, args, kwargs) - - return super().call_method(tx, name, args, kwargs) - - def unpack_var_sequence( - self, tx: "InstructionTranslator" - ) -> list["VariableTracker"]: - return [VariableTracker.build(tx, item) for item in self.value] - - def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": - method = self._maybe_get_baseclass_method(name) - if method in enum_type_methods: - # __iter__ is a bound method which is not correctly handled by the parent var_getattr, so need to handle it here - if name == "__iter__": - source = self.source and AttrSource(self.source, name) - return variables.UserMethodVariable(method, self, source=source) - return super().var_getattr(tx, name) - + from .builder import SourcelessBuilder -class NO_SUCH_SUBOBJ: - pass + if self.source is None: + # NB: If source is added via side effects, create the exception + # object through side_effects as well. See FrozenDataClass creation + var = tx.output.side_effects.track_new_user_defined_object( + SourcelessBuilder.create(tx, BaseException), + self, + list(args), + ) + var.call_method(tx, "__init__", list(args), dict(kwargs)) + return var + return super().call_function(tx, args, kwargs) class RemovableHandleClass: @@ -1150,6 +1262,15 @@ class UserDefinedObjectVariable(UserDefinedVariable): Mostly objects of defined type. Catch-all for something where we only know the type. """ + # VT representing the base built-in type's data for subclassed built-in types + # (e.g., ConstDictVariable for dict subclasses, ListVariable for list subclasses). + # None for plain user-defined objects that don't subclass a built-in container. + _base_vt: VariableTracker | None = None + + # Set of base class methods that can be delegated to _base_vt. + # Used to check whether a method is overridden before delegating. + _base_methods: set[Any] | None = None + _nonvar_fields = { "value", "value_type", @@ -1218,25 +1339,27 @@ def __repr__(self) -> str: def get_dict_vt(self, tx: "InstructionTranslator") -> "DunderDictVariable": if self.dict_vt is None: - dict_proxy = { - key: VariableTracker.build( - tx, - value, - source=self.source - and DictGetItemSource(AttrSource(self.source, "__dict__"), key), - ) - for key, value in self.value.__dict__.items() - } - self.dict_vt = variables.DunderDictVariable.create(tx, self, dict_proxy) + self.dict_vt = variables.DunderDictVariable.create(tx, self) return self.dict_vt - def is_underlying_vt_modified(self, side_effects: "SideEffects") -> bool: + def is_base_vt_modified(self, side_effects: "SideEffects") -> bool: + if self._base_vt is not None: + return side_effects.is_modified(self._base_vt) return False def python_type(self) -> type: return self.value_type # type: ignore[return-value] + def get_real_python_backed_value(self) -> object: + return self.value + def as_python_constant(self) -> object: + if isinstance( + self.value, + (enum.Enum, torch.DispatchKey, torch._C._functorch.TransformType), + ): + return self.value + if self.is_pytree_constant_class and self.source: # NOTE pytree constants created in the torch.compile region will # NOT be guarded (even though they have a source set) @@ -1263,39 +1386,141 @@ def as_python_constant(self) -> object: return super().as_python_constant() + def as_proxy(self) -> object: + if isinstance(self.value, enum.Enum): + if isinstance(self.value, int): + return int(self.value) + return self.value + return super().as_proxy() + def guard_as_python_constant(self) -> object: if self.source: install_guard(self.source.make_guard(GuardBuilder.ID_MATCH)) return self.value return super().guard_as_python_constant() - def torch_function_check(self) -> None: - assert has_torch_function(self), ( - f"calling torch function on object without __torch_function__ {self}" - ) - - def get_torch_fn(self, tx: "InstructionTranslator") -> VariableTracker: - self.torch_function_check() - from .torch_function import get_torch_function_fn - - return get_torch_function_fn(tx, self) + def bool_impl( + self, + tx: "InstructionTranslator", + ) -> "VariableTracker | None": + # Mirrors slot_nb_bool: + # https://github.com/python/cpython/blob/c09ccd9c429/Objects/typeobject.c#L9408-L9458 + if self._maybe_get_baseclass_method("__bool__"): + result = self.call_method(tx, "__bool__", [], {}) + if result.is_python_constant(): + result_value = result.as_python_constant() + if not isinstance(result_value, bool): + raise_observed_exception( + TypeError, + tx, + args=[ + f"__bool__ should return bool, returned {type(result_value).__name__}" + ], + ) + return result + return None - def call_torch_function( + def nb_index_impl( self, tx: "InstructionTranslator", - fn: VariableTracker, - types: "TupleVariable", - args: Sequence[Any], - kwargs: dict[str, Any], ) -> VariableTracker: - self.torch_function_check() - - from .torch_function import call_torch_function + # CPython: PyNumber_Index checks tp_as_number->nb_index. + # For user-defined types, __index__ in tp_dict means nb_index is set. + type_attr = inspect.getattr_static(type(self.value), "__index__", None) + if type_attr is None: + return super().nb_index_impl(tx) + source = self.source and self.get_source_by_walking_mro(tx, "__index__") + method_var = self.resolve_type_attr(tx, "__index__", type_attr, source) + result = method_var.call_function(tx, [], {}) + # CPython validates that __index__ returns an int. + # https://github.com/python/cpython/blob/c09ccd9c429/Objects/abstract.c#L1433-L1438 + if result.is_python_constant() and not isinstance( + result.as_python_constant(), int + ): + raise_observed_exception( + TypeError, + tx, + args=[ + f"__index__ returned non-int (type {type(result.as_python_constant()).__name__})" + ], + ) + return result - return call_torch_function( + def nb_int_impl( + self, + tx: "InstructionTranslator", + ) -> VariableTracker: + # CPython: slot_nb_int calls __int__(), PyNumber_Long validates the return type. + # https://github.com/python/cpython/blob/v3.13.0/Objects/abstract.c#L1538-L1550 + source = self.source and self.get_source_by_walking_mro(tx, "__int__") + method_var = self.resolve_type_attr( tx, - self.get_torch_fn(tx), - fn, + "__int__", + inspect.getattr_static(type(self.value), "__int__"), + source, + ) + result = method_var.call_function(tx, [], {}) + if not issubclass(result.python_type(), int): + raise_observed_exception( + TypeError, + tx, + args=[ + f"__int__ returned non-int (type {result.python_type().__name__})" + ], + ) + return result + + def nb_float_impl( + self, + tx: "InstructionTranslator", + ) -> VariableTracker: + # CPython: slot_nb_float calls __float__(), PyNumber_Float validates the return type. + # https://github.com/python/cpython/blob/v3.13.0/Objects/abstract.c#L1647-L1658 + source = self.source and self.get_source_by_walking_mro(tx, "__float__") + method_var = self.resolve_type_attr( + tx, + "__float__", + inspect.getattr_static(type(self.value), "__float__"), + source, + ) + result = method_var.call_function(tx, [], {}) + if not issubclass(result.python_type(), float): + raise_observed_exception( + TypeError, + tx, + args=[ + f"__float__ returned non-float (type {result.python_type().__name__})" + ], + ) + return result + + def torch_function_check(self) -> None: + assert has_torch_function(self), ( + f"calling torch function on object without __torch_function__ {self}" + ) + + def get_torch_fn(self, tx: "InstructionTranslator") -> VariableTracker: + self.torch_function_check() + from .torch_function import get_torch_function_fn + + return get_torch_function_fn(tx, self) + + def call_torch_function( + self, + tx: "InstructionTranslator", + fn: VariableTracker, + types: "TupleVariable", + args: Sequence[Any], + kwargs: dict[str, Any], + ) -> VariableTracker: + self.torch_function_check() + + from .torch_function import call_torch_function + + return call_torch_function( + tx, + self.get_torch_fn(tx), + fn, types, args, kwargs, @@ -1312,6 +1537,28 @@ def _supported_random_functions() -> set[Any]: } return fns + def mp_subscript_impl( + self, + tx: "InstructionTranslator", + key: VariableTracker, + ) -> VariableTracker: + # PyObject_GetItem: https://github.com/python/cpython/blob/62a6e898e01/Objects/abstract.c#L155-L206 + method = self._maybe_get_baseclass_method("__getitem__") + if ( + self._base_vt is not None + and self._base_methods is not None + and method in self._base_methods + ): + return self._base_vt.mp_subscript_impl(tx, key) + if isinstance(method, types.FunctionType): + source_fn = self.source and self.get_source_by_walking_mro( + tx, "__getitem__" + ) + return variables.UserMethodVariable( + method, self, source_fn=source_fn, source=self.source + ).call_function(tx, [key], {}) + return super().mp_subscript_impl(tx, key) + def call_method( self, tx: "InstructionTranslator", @@ -1319,12 +1566,14 @@ def call_method( args: list[Any], kwargs: dict[str, Any], ) -> VariableTracker: - from . import CONSTANT_VARIABLE_NONE, UserMethodVariable + from .. import trace_rules + from . import UserMethodVariable + from .constant import ConstantVariable method = self._maybe_get_baseclass_method(name) if method is not None: if method is object.__init__: - return CONSTANT_VARIABLE_NONE + return ConstantVariable.create(None) if is_standard_setattr(method) or isinstance(self.value, threading.local): return self.method_setattr_standard(tx, *args, **kwargs) @@ -1358,6 +1607,28 @@ def call_method( ], ) + # torch.Generator methods like manual_seed(), get_state(), etc. + # are stateful RNG operations that cannot be soundly traced. + if ( + isinstance(self.value, torch._C.Generator) + and name in trace_rules._GENERATOR_METHODS_THAT_GRAPH_BREAK + ): + unimplemented( + gb_type="torch.Generator method", + context=f"torch.Generator.{name}", + explanation=f"torch.Generator.{name}() is a stateful RNG " + "operation that cannot be soundly traced in the FX graph.", + hints=[*graph_break_hints.FUNDAMENTAL], + ) + + # Delegate to _base_vt for non-overridden base-class methods + if ( + self._base_vt is not None + and self._base_methods is not None + and method in self._base_methods + ): + return self._base_vt.call_method(tx, name, args, kwargs) + # check for methods implemented in C++ if isinstance(method, types.FunctionType): source = self.source @@ -1379,6 +1650,46 @@ def call_method( return super().call_method(tx, name, args, kwargs) + def len_impl(self, tx: "InstructionTranslator") -> VariableTracker: + method = self._maybe_get_baseclass_method("__len__") + if method is not None: + type_attr = self.lookup_class_mro_attr("__len__") + source = self.source and self.get_source_by_walking_mro(tx, "__len__") + method_var = self.resolve_type_attr(tx, "__len__", type_attr, source) + if not isinstance(method_var, variables.GetAttrVariable): + return method_var.call_function(tx, [], {}) + + unimplemented( + gb_type="Cannot trace user-defined __len__", + context=f"{self.python_type_name()}.__len__()", + explanation=( + f"Dynamo cannot trace len() on {self.python_type_name()} because the __len__ " + "method is either not traceable (e.g., defined in C or built-in) or returns a " + "non-constant value." + ), + hints=[ + *graph_break_hints.SUPPORTABLE, + ], + ) + + def sq_length(self, tx: "InstructionTranslator") -> VariableTracker: + if ( + self._base_vt is not None + and self._base_methods is not None + and self._maybe_get_baseclass_method("__len__") in self._base_methods + ): + return self._base_vt.sq_length(tx) + return self.len_impl(tx) + + def mp_length(self, tx: "InstructionTranslator") -> VariableTracker: + if ( + self._base_vt is not None + and self._base_methods is not None + and self._maybe_get_baseclass_method("__len__") in self._base_methods + ): + return self._base_vt.mp_length(tx) + return self.len_impl(tx) + def method_setattr_standard( self, tx: "InstructionTranslator", @@ -1468,9 +1779,22 @@ def method_setattr_standard( # NOTE: else we assume the descriptor (if any) has a # side-effect-free `__set__` as far as Dynamo tracing is concerned. - # Emulate the standard setattr on instance dict. + # If the code reaches here, the attribute is either: + # 1) a slot descriptor + # 2) a plain attribute with no descriptor + # If the object has no __dict__, only slot descriptors (member_descriptor) + # allow mutation. Any other attribute assignment raises AttributeError. + if not hasattr(self.value, "__dict__"): + descriptor = self.lookup_class_mro_attr(name_str) + if not inspect.ismemberdescriptor(descriptor): + error_msg = VariableTracker.build( + tx, + f"'{type(self.value).__name__}' object has no attribute '{name_str}'", + ) + raise_observed_exception(AttributeError, tx, args=[error_msg]) + tx.output.side_effects.store_attr(self, name_str, value) - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) def needs_slow_setattr(self) -> bool: return not is_standard_setattr( @@ -1478,6 +1802,10 @@ def needs_slow_setattr(self) -> bool: ) and not isinstance(self.value, threading.local) def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: + if self._base_vt is not None and self._base_methods is not None: + iter_method = self._maybe_get_baseclass_method("__iter__") + if iter_method is not None and iter_method in self._base_methods: + return self._base_vt.unpack_var_sequence(tx) if ( self.source and self._maybe_get_baseclass_method("__iter__") is list.__iter__ @@ -1616,10 +1944,18 @@ def _getattr_static(self, name: str) -> object: # has side-effect free __getattribute__ and the attribute is not visible without a dynamic lookup. # NOTE we assume the following descriptors are side-effect-free as far # as Dynamo tracing is concerned. - if not self._object_has_getattribute and ( + # + # C-level descriptors (getset_descriptor for __dict__, member_descriptor + # for __slots__) are always safe to resolve — their __get__ is + # implemented in C and doesn't run user code, so __getattribute__ + # overrides are irrelevant. The NO_SUCH_SUBOBJ and + # _is_c_defined_property cases DO require the absence of a custom + # __getattribute__ because they fall back to + # type(self.value).__getattribute__ which could be user-overridden. + if inspect.ismemberdescriptor(subobj) or inspect.isgetsetdescriptor(subobj): + subobj = type(self.value).__getattribute__(self.value, name) + elif not self._object_has_getattribute and ( subobj is NO_SUCH_SUBOBJ # e.g., threading.local - or inspect.ismemberdescriptor(subobj) # e.g., __slots__ - or inspect.isgetsetdescriptor(subobj) # e.g., __dict__ or self._is_c_defined_property(subobj) ): # Call __getattribute__, we have already checked that this is not overridden and side-effect free. We don't @@ -1687,6 +2023,7 @@ def has_key_in_generic_dict(self, tx: "InstructionTranslator", key: str) -> bool mutated_attr = tx.output.side_effects.load_attr(self, key, deleted_ok=True) return not isinstance(mutated_attr, variables.DeletedVariable) + # TODO(guilhermeleobas): This can trigger a side effect return key in self.value.__dict__ def get_source_by_walking_mro( @@ -1769,42 +2106,37 @@ def get_source_by_walking_mro( ], ) - def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: - source: Source | None = AttrSource(self.source, name) if self.source else None + def generic_getattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: + """Dynamo implementation of CPython's PyObject_GenericGetAttr. - if self._object_has_getattribute: - getattribute_fn = inspect.getattr_static( - type(self.value), "__getattribute__" - ) - new_source: AttrSource | None = ( - AttrSource(self.source, "__getattribute__") if self.source else None - ) + This mirrors object.__getattribute__ and is called from: + - var_getattr (for objects without a custom __getattribute__) + - SuperVariable.call_method (when super().__getattribute__() resolves + to object.__getattribute__) - try: - return variables.UserMethodVariable( - getattribute_fn, - self, - source=new_source, - ).call_function(tx, [VariableTracker.build(tx, name)], {}) - except ObservedAttributeError: - # Pass through to __getattr__ if __getattribute__ fails - handle_observed_exception(tx) + The algorithm: MRO walk → data descriptor → instance __dict__ → + non-data descriptor / plain class attr → dynamic fallback → + __getattr__ → AttributeError. + """ + source: Source | None = AttrSource(self.source, name) if self.source else None if tx.output.side_effects.has_pending_mutation_of_attr(self, name): result = tx.output.side_effects.load_attr(self, name, deleted_ok=True) if isinstance(result, variables.DeletedVariable): - error_message = VariableTracker.build( - tx, - f"'{type(self.value).__name__}' object has no attribute '{name}'", - ) raise_observed_exception( AttributeError, tx, - args=[error_message], + args=[ + f"'{type(self.value).__name__}' object has no attribute '{name}'", + ], ) return result if name == "__dict__": + if not hasattr(self.value, "__dict__"): + raise_observed_exception(AttributeError, tx) return self.get_dict_vt(tx) # TODO(anijain2305) - Investigate if we need specialization for more @@ -1846,6 +2178,8 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return self.resolve_data_descriptor(tx, name, type_attr, source) # Step 3: Instance __dict__ — return as-is, no descriptor invocation. + # TODO(guilhermeleobas): step 3 should look into dict_vt and not self.value.__dict__ + # as the object could have mutated an attribute via setattr if hasattr(self.value, "__dict__") and name in self.value.__dict__: subobj = self.value.__dict__[name] source = self.maybe_wrap_nn_module_source_for_instance(tx, name, source) @@ -1921,15 +2255,33 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker ) # Step 7: AttributeError. - error_message = VariableTracker.build( - tx, f"'{type(self.value).__name__}' object has no attribute '{name}'" - ) raise_observed_exception( AttributeError, tx, - args=[error_message], + args=[f"'{type(self.value).__name__}' object has no attribute '{name}'"], ) + def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: + if self._object_has_getattribute: + getattribute_fn = inspect.getattr_static( + type(self.value), "__getattribute__" + ) + new_source: AttrSource | None = ( + AttrSource(self.source, "__getattribute__") if self.source else None + ) + + try: + return variables.UserMethodVariable( + getattribute_fn, + self, + source=new_source, + ).call_function(tx, [VariableTracker.build(tx, name)], {}) + except ObservedAttributeError: + # Pass through to __getattr__ if __getattribute__ fails + handle_observed_exception(tx) + + return self.generic_getattr(tx, name) + def resolve_data_descriptor( self, tx: "InstructionTranslator", @@ -1962,13 +2314,12 @@ def resolve_data_descriptor( try: resolved = type(self.value).__getattribute__(self.value, name) except AttributeError: - error_message = VariableTracker.build( - tx, f"'{type(self.value).__name__}' object has no attribute '{name}'" - ) raise_observed_exception( AttributeError, tx, - args=[error_message], + args=[ + f"'{type(self.value).__name__}' object has no attribute '{name}'" + ], ) return VariableTracker.build(tx, resolved, source) @@ -2041,6 +2392,9 @@ def resolve_type_attr( # MethodDescriptorType: e.g. list.append (PyMethodDef) # WrapperDescriptorType: e.g. list.__add__ (slot wrappers) # MethodWrapperType: e.g. [].__add__ (bound slot wrappers) + # + # Exception: if the descriptor has a registered polyfill, return the + # polyfill as a bound method so Dynamo can trace through it. if ( isinstance( type_attr, @@ -2053,7 +2407,17 @@ def resolve_type_attr( or torch._C._dynamo.utils.is_instancemethod(type_attr) # type: ignore[attr-defined] or is_cython_function(type_attr) ): - return variables.GetAttrVariable(self, name, None, source=source) + from .. import trace_rules + + if trace_rules.is_polyfilled_callable(type_attr): # type: ignore[arg-type] + from .functions import PolyfilledFunctionVariable + + polyfill_handlers = PolyfilledFunctionVariable._get_polyfill_handlers() + wrapped: Any = polyfill_handlers.get(type_attr) # type: ignore[arg-type] + if wrapped is not None: + traceable_fn = wrapped.__torch_dynamo_polyfill__ + return variables.UserMethodVariable(traceable_fn, self) + return variables.GetAttrVariable(self, name, type(type_attr), source=source) # Plain class variable (or MethodType, C-level non-data descriptor # without __get__, etc.). @@ -2095,11 +2459,7 @@ def maybe_wrap_nn_module_source_for_instance( ) -> Source | None: """Wrap source for nn.Module instance dict attribute access if needed.""" if ( - ( - torch._dynamo.config.inline_inbuilt_nn_modules - or isinstance(self, variables.FSDPManagedNNModuleVariable) - ) - and source + source and isinstance(self, variables.UnspecializedNNModuleVariable) and (not tx.output.export or torch._dynamo.config.install_free_tensors) ): @@ -2126,17 +2486,26 @@ def call_obj_hasattr( ) except ObservedAttributeError: handle_observed_exception(tx) - return variables.CONSTANT_VARIABLE_FALSE + return variables.ConstantVariable.create(False) def is_python_hashable(self) -> bool: raise_on_overridden_hash(self.value, self) + if self._base_vt is not None: + return self._base_vt.is_python_hashable() return True def get_python_hash(self) -> int: - # default hash + if self._base_vt is not None: + return self._base_vt.get_python_hash() return hash(self.value) def is_python_equal(self, other: object) -> bool: + if ( + isinstance(other, VariableTracker) + and self.is_python_constant() + and other.is_python_constant() + ): + return self.as_python_constant() == other.as_python_constant() # id check if not isinstance(other, UserDefinedVariable): return False @@ -2319,36 +2688,20 @@ def call_tree_map_with_path_branch( class FrozenDataClassVariable(UserDefinedObjectVariable): - @staticmethod - def create( - tx: "InstructionTranslator", value: object, source: Source - ) -> "FrozenDataClassVariable": - from dataclasses import fields + """Frozen dataclass variable for as_proxy/as_python_constant/hashability. - assert is_frozen_dataclass(value) + Construction is handled by the generic polyfill path (tracing through + the auto-generated __init__). Field values are retrieved dynamically + via var_getattr using InstructionTranslator.current_tx(). + """ - field_map = {} - for field in fields(value): # type: ignore[arg-type] - if hasattr(value, field.name): - field_map[field.name] = VariableTracker.build( - tx, - getattr(value, field.name), - source and AttrSource(source, field.name), - ) + def _get_field_vt(self, field_name: str) -> VariableTracker: + from torch._dynamo.symbolic_convert import InstructionTranslator - return FrozenDataClassVariable(value, fields=field_map, source=source) - - def __init__( - self, value: object, fields: dict[str, Any] | None = None, **kwargs: Any - ) -> None: - super().__init__(value, **kwargs) - if fields is None: - fields = {} - self.fields = fields + tx = InstructionTranslator.current_tx() + return self.var_getattr(tx, field_name) def as_python_constant(self) -> object: - # NOTE: this is an intentionally limited version of - # `as_python_constant` for `nonstrict_trace` implementation. from dataclasses import fields import torch.utils._pytree as pytree @@ -2356,154 +2709,74 @@ def as_python_constant(self) -> object: if not istype( self.value, (pytree.TreeSpec, pytree.LeafSpec, pytree.ConstantNode) ): - # TODO loosen this restriction and fix `as_proxy`. raise NotImplementedError( "currently can't reconstruct arbitrary frozen dataclass instances" ) - # LeafSpec is deprecated, use treespec_leaf() instead if istype(self.value, pytree.LeafSpec): return pytree.treespec_leaf() - args = [] - kwargs = {} + args: list[object] = [] + kwargs: dict[str, object] = {} for field in fields(self.value): # type: ignore[arg-type] if field.init: - data = self.fields[field.name].as_python_constant() + data = self._get_field_vt(field.name).as_python_constant() if getattr(field, "kw_only", False): kwargs[field.name] = data else: args.append(data) - # This is safe because we know the TreeSpec classes constructors don't - # have external side effects. - ctor = self.python_type() - return ctor(*args, **kwargs) + return self.python_type()(*args, **kwargs) def as_proxy(self) -> object: from dataclasses import fields - args = [] - kwargs = {} + args: list[object] = [] + kwargs: dict[str, object] = {} for field in fields(self.value): # type: ignore[arg-type] - proxy = self.fields[field.name].as_proxy() + proxy = self._get_field_vt(field.name).as_proxy() if hasattr(field, "kw_only") and field.kw_only: kwargs[field.name] = proxy else: args.append(proxy) - # TODO this isn't really safe, because - # 1. it could invoke a user defined `__post_init__`. - # 2. it could invoke a user defined `__init__` if the class _subclasses_ - # a frozen dataclass. - # Either of the above could end up mutating external state. - ctor = self.python_type() - return ctor(*args, **kwargs) + return self.python_type()(*args, **kwargs) def reconstruct(self, codegen: "PyCodegen") -> None: - from dataclasses import fields - - # Handle specific pytree classes - import torch.utils._pytree as pytree - - if isinstance(self.value, pytree.TreeSpec) and self.value.is_leaf(): - # Create a new LeafSpec instance by calling the constructor - codegen.add_push_null( - lambda: codegen.load_import_from("torch.utils._pytree", "LeafSpec") - ) - codegen.extend_output(create_call_function(0, False)) + if self.source is not None: + codegen(self.source) return - - # For general frozen dataclasses, reconstruct by calling the constructor - # with the field values as arguments - dataclass_cls = self.python_type() - - if hasattr(dataclass_cls, "__post_init__"): - unimplemented( - gb_type="Frozen dataclass with __post_init__", - context=f"dataclass={dataclass_cls.__name__}", - explanation="Cannot reconstruct frozen dataclass with __post_init__ method, " - "as it may have side effects that would be incorrectly replayed.", - hints=[ - "Remove the __post_init__ method from the frozen dataclass.", - *graph_break_hints.SUPPORTABLE, - ], - ) - - # Collect positional and keyword-only arguments - pos_args = [] - # pyrefly: ignore [implicit-any] - kw_args = [] - for field in fields(dataclass_cls): - if not field.init: - continue - field_vt = self.fields.get(field.name) - if field_vt is None: - unimplemented( - gb_type="Frozen dataclass with missing field", - context=f"dataclass={dataclass_cls.__name__}, field={field.name}", - explanation=f"Cannot reconstruct frozen dataclass: field '{field.name}' " - "was not tracked during tracing.", - hints=[*graph_break_hints.SUPPORTABLE], - ) - if getattr(field, "kw_only", False): - kw_args.append((field.name, field_vt)) - else: - pos_args.append(field_vt) - - # Load the dataclass constructor - codegen.add_push_null( - lambda: codegen.append_output( - codegen.create_load_const_unchecked(dataclass_cls) - ) + codegen.append_output( + codegen.create_load_const_unchecked(self.as_python_constant()) ) - # Reconstruct all arguments - for arg_vt in pos_args: - codegen(arg_vt) - for _, arg_vt in kw_args: - codegen(arg_vt) - # Call the constructor - total_args = len(pos_args) + len(kw_args) - if kw_args: - kw_names = tuple(name for name, _ in kw_args) - codegen.extend_output( - codegen.create_call_function_kw(total_args, kw_names, push_null=False) - ) - else: - codegen.extend_output(create_call_function(total_args, False)) - - # NB: This is called during __init__ for a frozen dataclass - # use this to accumulate the most up-to-date field values - def method_setattr_standard( - self, - tx: "InstructionTranslator", - name: VariableTracker, - value: VariableTracker, - directly_update_dict: bool = False, - ) -> VariableTracker: - self.fields[name.as_python_constant()] = value - return super().method_setattr_standard(tx, name, value, directly_update_dict) def __repr__(self) -> str: return f"{self.__class__.__name__}({self.value_type.__name__})" def is_python_hashable(self) -> Literal[True]: - # TODO - Check corner cases like eq=False, hash=False etc return True def get_python_hash(self) -> int: - return hash(tuple(arg.get_python_hash() for arg in self.fields.values())) + from dataclasses import fields as dc_fields + + return hash( + tuple( + self._get_field_vt(f.name).get_python_hash() + for f in dc_fields(self.value) # type: ignore[arg-type] + ) + ) def is_python_equal(self, other: object) -> bool: if not isinstance(other, FrozenDataClassVariable): return False - is_class_same = self.python_type() is other.python_type() - is_field_name_same = self.fields.keys() == other.fields.keys() - is_field_value_same = all( - value_a.is_python_equal(value_b) - for value_a, value_b in zip(self.fields.values(), other.fields.values()) + if self.python_type() is not other.python_type(): + return False + from dataclasses import fields as dc_fields + + return all( + self._get_field_vt(f.name).is_python_equal(other._get_field_vt(f.name)) + for f in dc_fields(self.value) # type: ignore[arg-type] ) - return is_class_same and is_field_name_same and is_field_value_same class SourcelessGraphModuleVariable(UserDefinedObjectVariable): @@ -2533,7 +2806,8 @@ def call_method( class UserDefinedExceptionObjectVariable(UserDefinedObjectVariable): def __init__(self, value: object, **kwargs: Any) -> None: super().__init__(value, **kwargs) - self.exc_vt = variables.ExceptionVariable(self.value_type, ()) + init_args = kwargs.get("init_args", []) + self.exc_vt = variables.ExceptionVariable(self.value_type, init_args) @property def fn(self) -> Callable[..., object]: @@ -2552,10 +2826,7 @@ def call_method( and inspect.ismethoddescriptor(method) and len(kwargs) == 0 ): - self.exc_vt.args = tuple(args) - # pyrefly: ignore[missing-attribute] - self.value.args = args - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) elif ( name == "__setattr__" and len(args) == 2 @@ -2568,13 +2839,24 @@ def call_method( return self.exc_vt.call_method(tx, name, args, kwargs) return super().call_method(tx, name, args, kwargs) + def var_getattr(self, tx: "InstructionTranslator", name: str): + if name in ( + "args", + "__cause__", + "__context__", + "__suppress_context__", + "__traceback__", + ): + return self.exc_vt.var_getattr(tx, name) + return super().var_getattr(tx, name) + @property def __context__(self) -> "ConstantVariable": # type: ignore[return-value] return self.exc_vt.__context__ @property - def args(self) -> tuple[VariableTracker, ...]: + def args(self) -> list[VariableTracker]: return self.exc_vt.args def set_context(self, context: "variables.ExceptionVariable") -> None: @@ -2588,6 +2870,9 @@ def exc_type(self) -> type[BaseException]: def python_stack(self) -> traceback.StackSummary | None: return self.exc_vt.python_stack + def debug_repr(self) -> str: + return self.exc_vt.debug_repr() + @python_stack.setter def python_stack(self, value: traceback.StackSummary) -> None: self.exc_vt.python_stack = value @@ -2645,6 +2930,40 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker return super().var_getattr(tx, name) +_CONSTANT_BASE_TYPES = (int, float, str) + +_constant_base_methods: dict[type, set[Any]] = { + t: {m for m in t.__dict__.values() if callable(m)} for t in _CONSTANT_BASE_TYPES +} + + +class UserDefinedConstantVariable(UserDefinedObjectVariable): + """ + Represents user-defined objects that subclass immutable constant types + (int, float, str). + + Uses a ConstantVariable as _base_vt for the underlying constant value. + """ + + def __init__(self, value: Any, **kwargs: Any) -> None: + from .constant import ConstantVariable + + super().__init__(value, **kwargs) + for base in type(value).__mro__: + if base in _CONSTANT_BASE_TYPES: + self._base_vt = ConstantVariable.create(base(value)) + self._base_methods = _constant_base_methods[base] + break + assert self._base_vt is not None + + def as_python_constant(self) -> Any: + return self.value + + def as_proxy(self) -> object: + assert self._base_vt is not None + return self._base_vt.as_proxy() + + class IntWrapperVariable(UserDefinedObjectVariable): # Dummy class to check if the object is an IntWrapper, and turn it into a # symint @@ -2680,7 +2999,7 @@ def call_method( assert self.idx is not None tx.output.side_effects.remove_hook(self.idx) self.idx = self.REMOVED - return variables.CONSTANT_VARIABLE_NONE + return variables.ConstantVariable.create(None) return super().call_method(tx, name, args, kwargs) def reconstruct(self, codegen: "PyCodegen") -> None: @@ -2720,12 +3039,47 @@ def __init__( assert self.source is None, ( "dict_vt must be constructed by builder.py when source is present" ) - self._dict_vt = ConstDictVariable( - {}, type(value), mutation_type=ValueMutationNew() + self._base_vt = ConstDictVariable( + {}, + mutation_type=ValueMutationNew(), ) else: - self._dict_vt = dict_vt - self._dict_methods = dict_methods + self._base_vt = dict_vt + self._base_methods = dict_methods + assert self._base_vt is not None + + def len(self) -> int: + # Used by nn_module.py to short-circuit the nn.Module forward method + # when no hooks are registered. Calling .len() directly avoids the + # overhead of full call_method("__len__") dispatch during tracing. + assert self._base_vt is not None + return self._base_vt.len() # type: ignore[union-attr] + + def sq_length(self, tx: "InstructionTranslator") -> VariableTracker: + # Dict implements __len__ via mp_length (mapping protocol), not + # sq_length (sequence protocol). Redirect so generic_len works. + return self.mp_length(tx) + + def mp_subscript_impl( + self, + tx: "InstructionTranslator", + key: VariableTracker, + ) -> VariableTracker: + # dict_subscript: https://github.com/python/cpython/blob/62a6e898e01/Objects/dictobject.c#L3673-L3706 + # TODO(follow-up): add test for unhashable/invalid key type, Counter missing key + method = self._maybe_get_baseclass_method("__getitem__") + if method in self._base_methods: + assert self._base_vt is not None + try: + return self._base_vt.mp_subscript_impl(tx, key) + except ObservedKeyError: + if issubclass( + self.python_type(), dict + ) and self._maybe_get_baseclass_method("__missing__"): + return self.call_method(tx, "__missing__", [key], {}) + else: + raise + return super().mp_subscript_impl(tx, key) def call_method( self, @@ -2734,54 +3088,322 @@ def call_method( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - method = self._maybe_get_baseclass_method(name) - if method in self._dict_methods: - # Dict subclasses can override __missing__ to provide fallback - # behavior instead of raising a KeyError. This is used, for example, - # by collections.Counter. + # Dict subclasses can override __missing__ to provide fallback + # behavior instead of raising a KeyError. This is used, for example, + # by collections.Counter. + if ( + name == "__getitem__" + and self._maybe_get_baseclass_method("__getitem__") in self._base_methods + and self._maybe_get_baseclass_method("__missing__") + ): + assert self._base_vt is not None try: - return self._dict_vt.call_method(tx, name, args, kwargs) + return self._base_vt.call_method(tx, name, args, kwargs) except ObservedKeyError: - if ( - name == "__getitem__" - and issubclass(self.python_type(), dict) - and self._maybe_get_baseclass_method("__missing__") - ): - return self.call_method(tx, "__missing__", args, kwargs) - else: - raise + handle_observed_exception(tx) + return self.call_method(tx, "__missing__", args, kwargs) return super().call_method(tx, name, args, kwargs) - def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: - if type(self.value).__iter__ in ( # type: ignore[attr-defined] - dict.__iter__, - collections.OrderedDict.__iter__, - ): - return self._dict_vt.unpack_var_sequence(tx) - raise NotImplementedError - def is_underlying_vt_modified(self, side_effects: "SideEffects") -> bool: - return side_effects.is_modified(self._dict_vt) +# TODO: move to dicts.py alongside ConstDictVariable and DefaultDictVariable. +# Currently blocked by circular imports (dicts.py ↔ user_defined.py). +class OrderedDictVariable(UserDefinedDictVariable): + """ + Represents collections.OrderedDict instances. - @property - def user_cls(self) -> type[object]: - return self._dict_vt.user_cls + CPython has both a pure-Python implementation: + https://github.com/python/cpython/blob/v3.13.0/Lib/collections/__init__.py#L86-L339 + and a C accelerator that replaces it at runtime: + https://github.com/python/cpython/blob/v3.13.0/Objects/odictobject.c - @property - def items(self) -> dict[VariableTracker, VariableTracker]: - return self._dict_vt.items + The C accelerator is always active, so methods like move_to_end and + popitem are C-level method_descriptors, not Python functions. - def install_dict_keys_match_guard(self) -> None: - return self._dict_vt.install_dict_keys_match_guard() + Dict storage is delegated to _base_vt (a ConstDictVariable) via + UserDefinedDictVariable. + """ - def install_dict_contains_guard( - self, tx: "InstructionTranslator", args: list[VariableTracker] + def __init__( + self, + value: object, + dict_vt: "ConstDictVariable | None" = None, + **kwargs: Any, ) -> None: - return self._dict_vt.install_dict_contains_guard(tx, args) + if dict_vt is None: + from .dicts import ConstDictVariable - def is_python_hashable(self) -> Literal[False]: - raise_on_overridden_hash(self.value, self) - return False + dict_vt = ConstDictVariable( + {}, + user_cls=collections.OrderedDict, + mutation_type=ValueMutationNew(), + ) + super().__init__(value, dict_vt=dict_vt, **kwargs) + + def is_python_constant(self) -> bool: + assert self._base_vt is not None + return self._base_vt.is_python_constant() + + def as_python_constant(self) -> Any: + assert self._base_vt is not None + return collections.OrderedDict(self._base_vt.as_python_constant()) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .constant import ConstantVariable + from .dicts import HashableTracker + + # OrderedDict-exclusive C methods that ConstDictVariable doesn't handle. + # https://github.com/python/cpython/blob/v3.13.0/Objects/odictobject.c + if name == "move_to_end": + assert self._base_vt is not None + self._base_vt.install_dict_keys_match_guard() # type: ignore[union-attr] + tx.output.side_effects.mutation(self._base_vt) + if args[0] not in self._base_vt: # type: ignore[operator] + raise_observed_exception(KeyError, tx) + + last = True + if len(args) == 2 and args[0].is_python_constant(): + last = args[1].as_python_constant() + if kwargs and "last" in kwargs and kwargs["last"].is_python_constant(): + last = kwargs["last"].as_python_constant() + + key = HashableTracker(args[0]) + self._base_vt.items.move_to_end(key, last=last) # type: ignore[union-attr] + return ConstantVariable.create(None) + elif name == "popitem": + assert self._base_vt is not None + if not self._base_vt.items: # type: ignore[union-attr] + raise_observed_exception( + KeyError, tx, args=["popitem(): dictionary is empty"] + ) + + last = True + if len(args) == 1 and args[0].is_python_constant(): + last = args[0].as_python_constant() + if kwargs and "last" in kwargs and kwargs["last"].is_python_constant(): + last = kwargs["last"].as_python_constant() + + k, v = self._base_vt.items.popitem(last=last) # type: ignore[union-attr] + self._base_vt.should_reconstruct_all = True # type: ignore[union-attr] + tx.output.side_effects.mutation(self._base_vt) + return variables.TupleVariable([k.vt, v]) + return super().call_method(tx, name, args, kwargs) + + +# TODO: move to dicts.py alongside ConstDictVariable. +# Currently blocked by circular imports (dicts.py ↔ user_defined.py). +class DefaultDictVariable(UserDefinedDictVariable): + """ + Represents collections.defaultdict instances. + + CPython's defaultdict is implemented in C: + https://github.com/python/cpython/blob/v3.13.3/Modules/_collectionsmodule.c#L2177-L2180 + + default_factory is a field on the C struct (defdictobject.default_factory), + not a Python instance attribute, so we model it as a field on the VT. + + Dict storage is delegated to _base_vt (a ConstDictVariable) via + UserDefinedDictVariable. + """ + + _cpython_type = collections.defaultdict + + def __init__( + self, + value: object, + default_factory: VariableTracker | None = None, + dict_vt: ConstDictVariable | None = None, + **kwargs: Any, + ) -> None: + if dict_vt is None: + from .dicts import ConstDictVariable + + dict_vt = ConstDictVariable( + {}, + mutation_type=ValueMutationNew(), + ) + super().__init__(value, dict_vt=dict_vt, **kwargs) + if default_factory is None: + from .constant import ConstantVariable + + default_factory = ConstantVariable.create(None) + self.default_factory = default_factory + + @staticmethod + def is_supported_factory(arg: VariableTracker) -> bool: + """Check if arg is a valid default_factory (callable or None). + + CPython's defaultdict.__init__ checks ``callable(factory)`` and + raises TypeError if not. We mirror this by checking the + underlying Python value when possible. + """ + if isinstance(arg, variables.ConstantVariable): + return arg.value is None + # Check the real Python value for callable() + try: + val = arg.as_python_constant() + return val is None or callable(val) + except Exception: + pass + # Callables (functions, builtins, classes) are supported + return isinstance( + arg, + ( + variables.BaseBuiltinVariable, + variables.functions.BaseUserFunctionVariable, + variables.functions.PolyfilledFunctionVariable, + variables.UserDefinedClassVariable, + ), + ) + + def is_python_constant(self) -> bool: + assert self._base_vt is not None + # An empty defaultdict with a non-constant factory can't be + # constant-folded (we can't serialize the factory). + if not self.default_factory.is_python_constant() and not self._base_vt.items: # type: ignore[union-attr] + return False + return self._base_vt.is_python_constant() + + def as_python_constant(self) -> Any: + assert self._base_vt is not None + factory = None + if self.default_factory.is_python_constant(): + factory = self.default_factory.as_python_constant() + return collections.defaultdict(factory, self._base_vt.as_python_constant()) + + def debug_repr(self) -> str: + assert self.default_factory is not None + assert self._base_vt is not None + return ( + f"defaultdict({self.default_factory.debug_repr()}, " + f"{self._base_vt.debug_repr()})" + ) + + def var_getattr( + self, + tx: "InstructionTranslator", + name: str, + ) -> VariableTracker: + if name == "default_factory": + return self.default_factory + return super().var_getattr(tx, name) + + def _missing_impl( + self, + tx: "InstructionTranslator", + key: "VariableTracker", + ) -> "VariableTracker": + """defaultdict.__missing__: auto-vivification via default_factory. + + https://github.com/python/cpython/blob/v3.13.3/Modules/_collectionsmodule.c#L2233-L2254 + """ + from .constant import ConstantVariable + + if ( + istype(self.default_factory, ConstantVariable) + and self.default_factory.value is None + ): + raise_observed_exception(KeyError, tx, args=[key]) + default_var = self.default_factory.call_function(tx, [], {}) + assert self._base_vt is not None + self._base_vt.call_method(tx, "__setitem__", [key, default_var], {}) + return default_var + + def mp_subscript_impl( + self, + tx: "InstructionTranslator", + key: "VariableTracker", + ) -> "VariableTracker": + """defaultdict.__getitem__: dict lookup with __missing__ fallback.""" + assert self._base_vt is not None + if key in self._base_vt: # type: ignore[operator] + return self._base_vt.getitem_const(tx, key) # type: ignore[union-attr] + return self._missing_impl(tx, key) + + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + from .constant import ConstantVariable + + if name == "__init__": + # defaultdict.__init__(self, default_factory=None, *args, **kwargs) + # https://github.com/python/cpython/blob/v3.13.3/Modules/_collectionsmodule.c#L2072 + # Extract default_factory, delegate rest to dict.__init__ + if len(args) >= 1: + if self.is_supported_factory(args[0]): + self.default_factory = args[0] + tx.output.side_effects.store_attr( + self, + "default_factory", + self.default_factory, + ) + args = list(args[1:]) + else: + # CPython raises TypeError for non-callable first arg + raise_observed_exception( + TypeError, + tx, + args=["first argument must be callable or None"], + ) + assert self._base_vt is not None + return self._base_vt.call_method(tx, "__init__", args, kwargs) + elif name == "__getitem__": + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + return self.mp_subscript_impl(tx, args[0]) + elif name == "__missing__": + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + return self._missing_impl(tx, args[0]) + elif name == "copy": + # defaultdict.copy() creates a new defaultdict with same factory + # https://github.com/python/cpython/blob/v3.13.3/Modules/_collectionsmodule.c#L2282 + from .builder import SourcelessBuilder + + assert self._base_vt is not None + new_dd = tx.output.side_effects.track_new_user_defined_object( + SourcelessBuilder.create(tx, dict), + SourcelessBuilder.create(tx, collections.defaultdict), + [], + ) + assert isinstance(new_dd, DefaultDictVariable) + new_dd.default_factory = self.default_factory + new_dd._base_vt = self._base_vt.clone( + mutation_type=ValueMutationNew(), + source=None, + ) + tx.output.side_effects.store_attr( + new_dd, "default_factory", new_dd.default_factory + ) + return new_dd + elif name == "__setattr__": + if len(args) != 2: + raise_args_mismatch(tx, name, "2 args", f"{len(args)} args") + if ( + istype(args[0], ConstantVariable) and args[0].value == "default_factory" + ) and self.is_supported_factory(args[1]): + self.default_factory = args[1] + tx.output.side_effects.store_attr( + self, "default_factory", self.default_factory + ) + return ConstantVariable.create(None) + return super().call_method(tx, name, args, kwargs) + elif name == "__eq__": + if len(args) != 1: + raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") + return VariableTracker.build(tx, polyfills.dict___eq__).call_function( + tx, [self, args[0]], {} + ) + return super().call_method(tx, name, args, kwargs) class UserDefinedSetVariable(UserDefinedObjectVariable): @@ -2802,7 +3424,7 @@ def __init__( super().__init__(value, **kwargs) python_type = set if isinstance(value, set) else frozenset - self._set_methods = set_methods if python_type is set else frozenset_methods + self._base_methods = set_methods if python_type is set else frozenset_methods if set_vt is None: assert self.source is None, ( @@ -2810,7 +3432,7 @@ def __init__( ) if python_type is set: # set is initialized later - self._set_vt = variables.SetVariable( + self._base_vt = variables.SetVariable( set(), mutation_type=ValueMutationNew(), ) @@ -2818,65 +3440,32 @@ def __init__( init_args = kwargs.get("init_args", {}) if tx is None: tx = torch._dynamo.symbolic_convert.InstructionTranslator.current_tx() - self._set_vt = SourcelessBuilder.create(tx, python_type).call_function( # type: ignore[assignment] + self._base_vt = SourcelessBuilder.create(tx, python_type).call_function( # type: ignore[assignment] tx, init_args, {} ) else: - self._set_vt = set_vt - - def call_method( - self, - tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - method = self._maybe_get_baseclass_method(name) - if method in self._set_methods: - return self._set_vt.call_method(tx, name, args, kwargs) - return super().call_method(tx, name, args, kwargs) + self._base_vt = set_vt + assert self._base_vt is not None def as_python_constant(self) -> object: - return self._set_vt.as_python_constant() - - def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: - if inspect.getattr_static(self.value, "__iter__") in ( - set.__iter__, - frozenset.__iter__, - ): - return self._set_vt.unpack_var_sequence(tx) - raise NotImplementedError + assert self._base_vt is not None + return self._base_vt.as_python_constant() @property def set_items(self) -> set[Any]: - return self._set_vt.set_items + assert self._base_vt is not None + return self._base_vt.set_items # pyrefly: ignore[missing-attribute] @property - def items(self) -> list[VariableTracker]: - return self._set_vt.items - - def is_underlying_vt_modified(self, side_effects: "SideEffects") -> bool: - return side_effects.is_modified(self._set_vt) - - def install_dict_keys_match_guard(self) -> None: - return self._set_vt.install_dict_keys_match_guard() - - def install_dict_contains_guard( - self, tx: "InstructionTranslator", args: list[VariableTracker] - ) -> None: - return self._set_vt.install_dict_contains_guard(tx, args) - - def is_python_hashable(self) -> bool: - raise_on_overridden_hash(self.value, self) - return self._set_vt.is_python_hashable() - - def get_python_hash(self) -> int: - return self._set_vt.get_python_hash() + def items(self) -> dict[HashableTracker, VariableTracker]: + assert self._base_vt is not None + return self._base_vt.items # pyrefly: ignore[missing-attribute] def is_python_equal(self, other: object) -> bool: + assert self._base_vt is not None return isinstance( other, UserDefinedSetVariable - ) and self._set_vt.is_python_equal(other._set_vt) + ) and self._base_vt.is_python_equal(other._base_vt) class UserDefinedListVariable(UserDefinedObjectVariable): @@ -2898,35 +3487,11 @@ def __init__( assert self.source is None, ( "list_vt must be constructed by builder.py when source is present" ) - self._list_vt = ListVariable([], mutation_type=ValueMutationNew()) + self._base_vt = ListVariable([], mutation_type=ValueMutationNew()) else: - self._list_vt = list_vt - - def call_method( - self, - tx: "InstructionTranslator", - name: str, - args: list[VariableTracker], - kwargs: dict[str, VariableTracker], - ) -> VariableTracker: - assert self._list_vt is not None - method = self._maybe_get_baseclass_method(name) - if method in list_methods: - return self._list_vt.call_method(tx, name, args, kwargs) - return super().call_method(tx, name, args, kwargs) - - def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: - assert self._list_vt is not None - if type(self.value).__iter__ is list.__iter__: # type: ignore[attr-defined] - return self._list_vt.unpack_var_sequence(tx) - raise NotImplementedError - - def is_underlying_vt_modified(self, side_effects: "SideEffects") -> bool: - return side_effects.is_modified(self._list_vt) - - def is_python_hashable(self) -> Literal[False]: - raise_on_overridden_hash(self.value, self) - return False + self._base_vt = list_vt + self._base_methods = list_methods + assert self._base_vt is not None class UserDefinedTupleVariable(UserDefinedObjectVariable): @@ -2936,9 +3501,21 @@ class UserDefinedTupleVariable(UserDefinedObjectVariable): Internally, it uses a TupleVariable to represent the tuple part of the variable tracker. For everything else, it falls back to UserDefinedObjectVariable. + + NamedTupleVariable and StructSequenceVariable are subclasses that handle + namedtuples and structseqs (torch.return_types.*) respectively. """ - _nonvar_fields = UserDefinedObjectVariable._nonvar_fields + _nonvar_fields = { + "tuple_cls", + *UserDefinedObjectVariable._nonvar_fields, + } + + @staticmethod + def get_vt_cls(cls: type) -> type["UserDefinedTupleVariable"]: + if is_structseq_class(cls): + return StructSequenceVariable + return NamedTupleVariable def __init__(self, value, tuple_vt=None, init_args=None, **kwargs): # type: ignore[all] from .lists import TupleVariable @@ -2959,24 +3536,17 @@ def __init__(self, value, tuple_vt=None, init_args=None, **kwargs): # type: ign tx = InstructionTranslator.current_tx() elems = init_args[0].force_unpack_var_sequence(tx) - self._tuple_vt = TupleVariable(elems, mutation_type=ValueMutationNew()) + self._base_vt = TupleVariable(elems, mutation_type=ValueMutationNew()) else: - self._tuple_vt = tuple_vt + self._base_vt = tuple_vt + self.tuple_cls = type(value) + self._base_methods = tuple_methods + assert self._base_vt is not None - def resolve_data_descriptor( - self, - tx: "InstructionTranslator", - name: str, - type_attr: object, - source: Source | None, - ) -> VariableTracker: - if isinstance(type_attr, _collections._tuplegetter): - # namedtuple fields are _tuplegetter descriptors implemented in C. - # We emulate _tuplegetter.__get__ by indexing into the tracked - # tuple items, because self.value may not hold actual runtime values. - _, (idx, _) = type_attr.__reduce__() - return self._tuple_vt.items[idx] # type: ignore[union-attr] - return super().resolve_data_descriptor(tx, name, type_attr, source) + @property + def items(self) -> list[VariableTracker]: + assert self._base_vt is not None + return self._base_vt.items # type: ignore[return-value] def call_method( self, @@ -2985,7 +3555,6 @@ def call_method( args: list[VariableTracker], kwargs: dict[str, VariableTracker], ) -> VariableTracker: - assert self._tuple_vt is not None if name == "__eq__": if len(args) != 1 or kwargs: raise ValueError("Improper arguments for method.") @@ -2994,29 +3563,201 @@ def call_method( if len(args) != 1 or kwargs: raise ValueError("Improper arguments for method.") return VariableTracker.build(tx, not self.is_python_equal(args[0])) - method = self._maybe_get_baseclass_method(name) - if method in tuple_methods: - return self._tuple_vt.call_method(tx, name, args, kwargs) return super().call_method(tx, name, args, kwargs) - def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: - assert self._tuple_vt is not None - if type(self.value).__iter__ is tuple.__iter__: # type: ignore[attr-defined] - return self._tuple_vt.unpack_var_sequence(tx) + def reconstruct(self, codegen: "PyCodegen") -> None: + # Sourceless namedtuples/structseqs (e.g. tensor subclass metadata from + # SourcelessBuilder) aren't in id_to_variable so codegen_save_tempvars + # never processes them. When they appear in return values, codegen falls + # through to call_reconstruct. This is the same pattern as other + # sourceless containers (ConstDictVariable, TupleVariable, etc.). + # UserDefinedDictVariable doesn't need this because it's never created + # sourceless — it only comes from VariableBuilder which always has a + # source. + assert self.source is None + create_fn = self.get_construct_fn() + codegen.add_push_null( + lambda: codegen.append_output( + codegen.create_load_const_unchecked(create_fn) + ) + ) + codegen(self._base_vt) + codegen.extend_output(create_call_function(1, False)) + + def get_construct_fn(self) -> Callable[..., Any]: raise NotImplementedError - def is_python_hashable(self) -> bool: - raise_on_overridden_hash(self.value, self) - return self._tuple_vt.is_python_hashable() + def _validate_rest_for_tree_map( + self, rest: "collections.abc.Sequence[VariableTracker]" + ) -> list["UserDefinedTupleVariable"] | None: + """Validate that rest args are compatible for tree_map fast-path.""" + others: list[UserDefinedTupleVariable] = [] + n = len(self.items) + for candidate in rest: + if ( + not isinstance(candidate, UserDefinedTupleVariable) + or len(candidate.items) != n + or candidate.tuple_cls is not self.tuple_cls + ): + return None + others.append(candidate) + return others - def get_python_hash(self) -> int: - return self._tuple_vt.get_python_hash() + def _make_tree_map_result( + self, new_items: list[VariableTracker] + ) -> "UserDefinedTupleVariable": + from .lists import TupleVariable - def is_python_equal(self, other: object) -> bool: - other = ( - other._tuple_vt if isinstance(other, UserDefinedTupleVariable) else other + tuple_vt = TupleVariable(new_items, mutation_type=ValueMutationNew()) + return type(self)( + self.value, + tuple_vt=tuple_vt, + mutation_type=ValueMutationNew(), ) - return self._tuple_vt.is_python_equal(other) + + def _is_pytree_node(self) -> bool: + from torch.utils._pytree import is_namedtuple_class + + return is_namedtuple_class(self.tuple_cls) or is_structseq_class(self.tuple_cls) + + def call_tree_map_branch( + self, + tx: "InstructionTranslator", + tree_map_fn: "variables.functions.UserFunctionVariable", + map_fn: "VariableTracker", + rest: "collections.abc.Sequence[VariableTracker]", + tree_map_kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + if not self._is_pytree_node(): + return super().call_tree_map_branch( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + others = self._validate_rest_for_tree_map(rest) + if others is None: + return self._tree_map_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs + ) + + new_items: list[VariableTracker] = [] + for idx, item in enumerate(self.items): + sibling_leaves = [o.items[idx] for o in others] + new_items.append( + item.call_tree_map( + tx, tree_map_fn, map_fn, sibling_leaves, tree_map_kwargs + ) + ) + + return self._make_tree_map_result(new_items) + + def call_tree_map_with_path_branch( + self, + tx: "InstructionTranslator", + tree_map_fn: "variables.functions.UserFunctionVariable", + map_fn: "VariableTracker", + rest: "collections.abc.Sequence[VariableTracker]", + tree_map_kwargs: "dict[str, VariableTracker]", + keypath: "tuple[Any, ...]", + ) -> "VariableTracker": + if not self._is_pytree_node(): + return super().call_tree_map_with_path_branch( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs, keypath + ) + others = self._validate_rest_for_tree_map(rest) + if others is None: + return self._tree_map_with_path_fallback( + tx, tree_map_fn, map_fn, rest, tree_map_kwargs, keypath + ) + + fields = namedtuple_fields(self.tuple_cls) + new_items: list[VariableTracker] = [] + for idx, item in enumerate(self.items): + sibling_leaves = [o.items[idx] for o in others] + child_keypath = keypath + (GetAttrKey(fields[idx]),) + new_items.append( + item.call_tree_map_with_path( + tx, + tree_map_fn, + map_fn, + sibling_leaves, + tree_map_kwargs, + child_keypath, + ) + ) + + return self._make_tree_map_result(new_items) + + def is_python_equal(self, other: object) -> bool: + assert self._base_vt is not None + other = other._base_vt if isinstance(other, UserDefinedTupleVariable) else other + return self._base_vt.is_python_equal(other) + + +class NamedTupleVariable(UserDefinedTupleVariable): + """Represents Python namedtuples (created via collections.namedtuple). + + Namedtuples use _tuplegetter descriptors for field access and + Type(*args) / Type._make(iterable) for construction. + """ + + def resolve_data_descriptor( + self, + tx: "InstructionTranslator", + name: str, + type_attr: object, + source: Source | None, + ) -> VariableTracker: + if isinstance(type_attr, _collections._tuplegetter): + # namedtuple fields are _tuplegetter descriptors implemented in C. + # We emulate _tuplegetter.__get__ by indexing into the tracked + # tuple items, because self.value may not hold actual runtime values. + _, (idx, _) = type_attr.__reduce__() + return self.items[idx] # type: ignore[union-attr] + return super().resolve_data_descriptor(tx, name, type_attr, source) + + def get_construct_fn(self) -> Callable[..., Any]: + return self.tuple_cls._make # type: ignore[attr-defined] + + def as_python_constant(self) -> Any: + items = [x.as_python_constant() for x in self.items] + return self.tuple_cls(*items) # type: ignore[arg-type] + + def as_proxy(self) -> Any: + items = [x.as_proxy() for x in self.items] + return self.tuple_cls(*items) # type: ignore[arg-type] + + +class StructSequenceVariable(UserDefinedTupleVariable): + """Represents C-implemented PyStructSequence types (torch.return_types.*). + + Structseqs use Type(iterable) calling convention and reject tuple.__new__. + """ + + def resolve_data_descriptor( + self, + tx: "InstructionTranslator", + name: str, + type_attr: object, + source: Source | None, + ) -> VariableTracker: + if isinstance(type_attr, types.MemberDescriptorType): + # Structseq fields are member_descriptor objects. We emulate + # field access by looking up the field name in _fields and + # indexing into the tracked tuple items. + fields = namedtuple_fields(self.tuple_cls) + if name in fields: + return self.items[fields.index(name)] + return super().resolve_data_descriptor(tx, name, type_attr, source) + + def get_construct_fn(self) -> Callable[..., Any]: + return self.tuple_cls + + def as_python_constant(self) -> Any: + items = [x.as_python_constant() for x in self.items] + return self.tuple_cls(items) + + def as_proxy(self) -> Any: + items = [x.as_proxy() for x in self.items] + return self.tuple_cls(items) class MutableMappingVariable(UserDefinedObjectVariable): @@ -3074,6 +3815,11 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker else: return super().var_getattr(tx, name) + def mp_length(self, tx: "InstructionTranslator") -> VariableTracker: + if self._maybe_get_baseclass_method("__len__") in dict_methods: + return VariableTracker.build(tx, len(self.value)) # type: ignore[bad-argument-type] + return super().mp_length(tx) + class RandomVariable(UserDefinedObjectVariable): pass diff --git a/torch/_export/config.py b/torch/_export/config.py index ec3963eaa34ac..4f800487b2ab6 100644 --- a/torch/_export/config.py +++ b/torch/_export/config.py @@ -37,7 +37,7 @@ if TYPE_CHECKING: - from torch.utils._config_typing import * # noqa: F401, F403 + from torch.utils._config_typing import * # noqa: F403 def _make_closure_patcher(**changes: Any) -> Any: ... diff --git a/torch/_export/converter.py b/torch/_export/converter.py index e69ca58d3ae7e..a212d8064d61a 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -663,7 +663,7 @@ def convert_aten_Float(self, node: torch._C.Node): def to_float_tensor(t): return t.to(dtype=torch.float).item() - inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] # noqa: C416 + inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] fx_node = self.fx_graph.call_function( to_float_tensor, tuple(inp_list), @@ -750,7 +750,7 @@ def convert_prim_Constant(self, node: torch._C.Node): self.name_to_constant[name] = value def convert_prim_CallMethod(self, node: torch._C.Node): - inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] # noqa: C416 + inp_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()] fx_node = self.fx_graph.call_method( node.s("name"), tuple(inp_list), diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 83d7450e0fb85..2f1aae39d0a1a 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -9,7 +9,7 @@ from collections import defaultdict from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, TypeGuard import torch import torch.utils._pytree as pytree @@ -623,7 +623,7 @@ def produce_guards_and_solve_constraints( raise constraint_violation_error -def is_int(x: object) -> bool: +def is_int(x: object) -> TypeGuard[int | torch.SymInt]: return isinstance(x, int) or (isinstance(x, torch.SymInt) and x.node.expr.is_number) diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 7c5690d7c6b41..7b91eb28c7ff1 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -137,7 +137,7 @@ def _unused_constant(node: torch.fx.Node) -> list[torch.fx.Node] | None: This function returns None if this constant is being used, otherwise it returns the lift_fresh and detach node to be removed later. - """ # noqa: B950 + """ if len(node.users) > 1: return None diff --git a/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py b/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py index 9be342f0d188d..ba36183b108d6 100644 --- a/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py +++ b/torch/_export/passes/replace_quantized_ops_with_standard_ops_pass.py @@ -285,7 +285,7 @@ def _conv1d_op_with_squeeze( ) -> torch.Tensor: # In quantized version, conv1d is emulated using conv2d with squeeze and unsqueeze # operations before and after the conv2d operation to match the dimension of weights. - # Reference: https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/quantized/cpu/qconv.cpp#L1827 # noqa: B950 + # Reference: https://github.com/pytorch/pytorch/blob/eca0cb0fbe84bb0a34fa94afe261bceecd52c436/aten/src/ATen/native/quantized/cpu/qconv.cpp#L1827 s_inp = torch.ops.aten.unsqueeze(inp, 2) conv1d_res = torch.ops.aten.conv2d( s_inp, diff --git a/torch/_export/serde/export_schema.thrift b/torch/_export/serde/export_schema.thrift index f07ff26f8b351..5268efc883809 100644 --- a/torch/_export/serde/export_schema.thrift +++ b/torch/_export/serde/export_schema.thrift @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<<4810951b41da4955307f49a67e053a7371ef6c8d56f886080b3a6f18b2fd7ec4>> +// checksum<> namespace py3 torch._export namespace cpp2 torch._export.schema @@ -53,6 +53,8 @@ enum ScalarType { FLOAT8E4M3FNUZ = 31, FLOAT8E5M2FNUZ = 32, FLOAT8E8M0FNU = 33, + UINT32 = 34, + UINT64 = 35, } @@ -171,6 +173,7 @@ union Argument { 270: list> as_nested_tensors; 280: list> as_int_lists; 290: map as_string_to_argument; + 300: list> as_float_lists; } struct NamedArgument { diff --git a/torch/_export/serde/schema.py b/torch/_export/serde/schema.py index 92f5250501359..c7ee7b43495bd 100644 --- a/torch/_export/serde/schema.py +++ b/torch/_export/serde/schema.py @@ -9,7 +9,7 @@ # NOTE: Please update this value if any modifications are made to the schema -SCHEMA_VERSION = (8, 18) +SCHEMA_VERSION = (8, 20) TREESPEC_VERSION = 1 @@ -41,6 +41,8 @@ class ScalarType(IntEnum): FLOAT8E4M3FNUZ = 31 FLOAT8E5M2FNUZ = 32 FLOAT8E8M0FNU = 33 + UINT32 = 34 + UINT64 = 35 class Layout(IntEnum): @@ -82,6 +84,8 @@ class MemoryFormat(IntEnum): ScalarType.FLOAT8E4M3FNUZ: "Float8_e4m3fnuz", ScalarType.FLOAT8E5M2FNUZ: "Float8_e5m2fnuz", ScalarType.FLOAT8E8M0FNU: "Float8_e8m0fnu", + ScalarType.UINT32: "UInt32", + ScalarType.UINT64: "UInt64", } LAYOUT_TO_C10: dict[int, str] = { @@ -261,6 +265,7 @@ class Argument(_Union): as_nested_tensors: Annotated[list[list[TensorArgument]], 270] as_int_lists: Annotated[list[list[int]], 280] as_string_to_argument: Annotated[dict[str, "Argument"], 290] + as_float_lists: Annotated[list[list[float]], 300] class ArgumentKind(IntEnum): diff --git a/torch/_export/serde/schema.yaml b/torch/_export/serde/schema.yaml index 734fcafe309b3..3b9c4d3051a87 100644 --- a/torch/_export/serde/schema.yaml +++ b/torch/_export/serde/schema.yaml @@ -1,5 +1,5 @@ # @generated by update_schema.py -# checksum<> +# checksum<> AOTInductorModelPickleData: kind: struct fields: @@ -81,6 +81,8 @@ Argument: type: List[List[int]] as_string_to_argument: type: Dict[str, Argument] + as_float_lists: + type: List[List[float]] ArgumentKind: kind: enum fields: @@ -449,6 +451,8 @@ ScalarType: FLOAT8E4M3FNUZ: 31 FLOAT8E5M2FNUZ: 32 FLOAT8E8M0FNU: 33 + UINT32: 34 + UINT64: 35 SchemaVersion: kind: struct fields: @@ -561,5 +565,5 @@ UserOutputSpec: type: Argument SCHEMA_VERSION: - 8 -- 18 +- 20 TREESPEC_VERSION: 1 diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 402db208f30e3..12d373c11c7a7 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -153,6 +153,8 @@ def _reverse_map(d: dict[Any, Enum]): torch.float8_e4m3fnuz: ScalarType.FLOAT8E4M3FNUZ, torch.float8_e5m2fnuz: ScalarType.FLOAT8E5M2FNUZ, torch.float8_e8m0fnu: ScalarType.FLOAT8E8M0FNU, + torch.uint32: ScalarType.UINT32, + torch.uint64: ScalarType.UINT64, } @@ -1520,6 +1522,12 @@ def serialize_optional_tensor_args(a): ): # list of int tuples return Argument.create(as_int_lists=[list(t) for t in arg]) + elif all( + isinstance(a, (list, tuple)) and all(isinstance(x, float) for x in a) + for a in arg + ): + # list of float lists (List[List[float]]) + return Argument.create(as_float_lists=[list(t) for t in arg]) else: raise SerializeError( f"Unsupported list/tuple argument type: {[type(a) for a in arg]}" @@ -2175,7 +2183,14 @@ def __init__( self.pickle_protocol = pickle_protocol - def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram: + def serialize( + self, + exported_program: ep.ExportedProgram, + *, + serialize_state_dict: bool = True, + serialize_constants: bool = True, + serialize_example_inputs: bool = True, + ) -> _SerializedProgram: """ Args: exported_program: Exported Program to serialize @@ -2219,13 +2234,29 @@ def serialize(self, exported_program: ep.ExportedProgram) -> _SerializedProgram: new_state_dict = remove_proxy_from_state_dict( exported_program.state_dict, in_place=False ) + serialized_state_dict = b"" + if serialize_state_dict: + serialized_state_dict = serialize_torch_artifact( + new_state_dict, self.pickle_protocol + ) + + serialized_constants = b"" + if serialize_constants: + serialized_constants = serialize_torch_artifact( + constants, self.pickle_protocol + ) + + serialized_example_inputs = b"" + if serialize_example_inputs: + serialized_example_inputs = serialize_torch_artifact( + exported_program.example_inputs, self.pickle_protocol + ) + return _SerializedProgram( serialized_ep, - serialize_torch_artifact(new_state_dict, self.pickle_protocol), - serialize_torch_artifact(constants, self.pickle_protocol), - serialize_torch_artifact( - exported_program.example_inputs, self.pickle_protocol - ), + serialized_state_dict, + serialized_constants, + serialized_example_inputs, ) @@ -3071,6 +3102,8 @@ def deserialize_input(self, inp: Argument) -> Any: elif typ_ == "as_int_lists": # Convert list of lists back to list of tuples for Triton grids return [tuple(dims) for dims in value] + elif typ_ == "as_float_lists": + return [list(floats) for floats in value] elif typ_ == "as_nested_tensors": # nested list of tensors (List[List[Tensor]]) return [ @@ -3590,11 +3623,20 @@ def serialize( exported_program: ep.ExportedProgram, opset_version: dict[str, int] | None = None, pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, + *, + serialize_state_dict: bool = True, + serialize_constants: bool = True, + serialize_example_inputs: bool = True, ) -> SerializedArtifact: with _enable_graph_inputs_of_type_nn_module(exported_program.example_inputs): serialized_program = ExportedProgramSerializer( opset_version, pickle_protocol - ).serialize(exported_program) + ).serialize( + exported_program, + serialize_state_dict=serialize_state_dict, + serialize_constants=serialize_constants, + serialize_example_inputs=serialize_example_inputs, + ) if not isinstance(serialized_program.exported_program, ExportedProgram): raise AssertionError( f"expected ExportedProgram, got {type(serialized_program.exported_program).__name__}" @@ -3760,6 +3802,8 @@ def _get_argument(a: Argument): return None elif a.type == "as_int_lists": return None + elif a.type == "as_float_lists": + return None elif a.type == "as_string_to_argument": return None elif a.type == "as_nested_tensors": diff --git a/torch/_export/utils.py b/torch/_export/utils.py index c0ebd5d5938fb..fab8d7f576585 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -216,6 +216,9 @@ def _getattr(model: torch.fx.GraphModule, attr_name: str): def _maybe_find_pre_dispatch_tf_mode_for_export(): + if not torch.compiler.is_exporting(): + return None + if not torch._C._is_torch_function_mode_enabled(): return None @@ -397,7 +400,7 @@ def _check_symint( path = get_keystr(keypath) if i is not None: path += f".shape[{i}]" - raise RuntimeError( # noqa: B904 + raise RuntimeError( f"Expected input {path} = {arg} to be " f"of the form {symint.node.expr}, where {symbol} is an integer" ) @@ -472,7 +475,17 @@ def _check_input_constraints_for_graph( ) elif isinstance(node_val, (int, float, str)): - if type(arg) is not type(node_val) or arg != node_val: + if type(arg) is not type(node_val): + raise RuntimeError( + f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}", + ) + # NaN != NaN in Python, so use math.isnan for NaN-to-NaN comparison + if isinstance(node_val, float) and math.isnan(node_val): + if not isinstance(arg, float) or not math.isnan(arg): + raise RuntimeError( + f"Expected input at {get_keystr(key_path)} to be nan, but got {arg}", + ) + elif arg != node_val: raise RuntimeError( f"Expected input at {get_keystr(key_path)} to be equal to {node_val}, but got {arg}", ) diff --git a/torch/_export/verifier.py b/torch/_export/verifier.py index 9dbf6dc3a5c3c..7327e908ac753 100644 --- a/torch/_export/verifier.py +++ b/torch/_export/verifier.py @@ -246,6 +246,13 @@ def _allowed_op_types() -> tuple[type[Any], ...]: torch._functorch.predispatch._vmap_increment_nesting, torch._functorch.predispatch._vmap_decrement_nesting, torch._functorch.predispatch.lazy_load_decompositions, + torch._functorch.predispatch._make_dual, + torch._functorch.predispatch._unpack_dual, + torch._functorch.predispatch._jvp_increment_nesting, + torch._functorch.predispatch._jvp_decrement_nesting, + torch._functorch.predispatch._unwrap_for_grad, + torch._functorch.predispatch._enter_dual_level, + torch._functorch.predispatch._exit_dual_level, ) if not isinstance(op, _allowed_op_types()): diff --git a/torch/_functorch/_activation_checkpointing/remat_using_tags_for_fwd_loss_bwd_graph_pass.py b/torch/_functorch/_activation_checkpointing/remat_using_tags_for_fwd_loss_bwd_graph_pass.py index 0eff65d9bad41..a7e7bfbb2310c 100644 --- a/torch/_functorch/_activation_checkpointing/remat_using_tags_for_fwd_loss_bwd_graph_pass.py +++ b/torch/_functorch/_activation_checkpointing/remat_using_tags_for_fwd_loss_bwd_graph_pass.py @@ -1,5 +1,6 @@ """AC rematerialize pass: Duplicates recompute nodes for backward, then DCE removes unused forward versions.""" +import itertools import logging from typing import Any, overload @@ -17,6 +18,7 @@ log = logging.getLogger(__name__) +_EMPTY_CUSTOM_META: dict[str, object] = {} def is_impure_node_for_dce(node: fx.Node) -> bool: @@ -34,29 +36,73 @@ def is_impure_node_for_dce(node: fx.Node) -> bool: return node.is_impure(impure_random) -def _is_backward_node(node: fx.Node) -> bool: - """Check if node is in backward region via annotation""" - return node.meta.get("custom", {}).get("remat_pass_tag", None) == "is_backward" +def _is_backward_node(node: fx.Node, use_phase: bool = False) -> bool: + """Check if node is in backward region. - -def remat_using_tags_for_fwd_loss_bwd_graph(gm: fx.GraphModule) -> fx.GraphModule: - """ - Duplicate recompute nodes for backward use. DCE removes unused forward versions. We assume that - you already annotated your backward region with fx.traceback.annotate({"remat_pass_tag": "is_backward"}) - which helps us identify the backward region. + If use_phase is True, only checks custom["phase"] == "backward" + (user annotation). Otherwise falls back to node.meta["autograd_backward"], + which Dynamo adds when tracing torch.autograd.grad. """ - if not has_recomputable_ops(gm): - return gm + custom = node.meta.get("custom", _EMPTY_CUSTOM_META) + if use_phase: + return custom.get("phase") == "backward" + return node.meta.get("autograd_backward", False) + + +def _has_user_phase_annotation(gm: fx.GraphModule) -> bool: + """Check if any node has the user-level phase: backward annotation.""" + return any( + node.meta.get("custom", _EMPTY_CUSTOM_META).get("phase") == "backward" + for node in gm.graph.nodes + ) + + +def _collect_backward_regions( + gm: fx.GraphModule, use_phase: bool +) -> list[tuple[int, int, bool]]: + """Returns (bwd_start, bwd_end, needs_remat) for each backward region. - # Find backward boundary and build ordering + Regions are maximal contiguous runs of backward nodes, as [start, end) + indices into the graph node list. + """ + regions: list[tuple[int, int, bool]] = [] bwd_start: int | None = None - order = {} + needs_remat = False + for idx, node in enumerate(gm.graph.nodes): - order[node] = idx - if _is_backward_node(node) and bwd_start is None: - bwd_start = idx + if _is_backward_node(node, use_phase=use_phase): + if bwd_start is None: + bwd_start = idx + needs_remat = False + if not needs_remat and any( + must_recompute(inp) for inp in node.all_input_nodes + ): + needs_remat = True + elif bwd_start is not None: + regions.append((bwd_start, idx, needs_remat)) + bwd_start = None + + if bwd_start is not None: + regions.append((bwd_start, idx + 1, needs_remat)) + + return regions + + +def remat_using_tags_for_fwd_loss_bwd_graph(gm: fx.GraphModule) -> fx.GraphModule: + """ + Duplicate recompute nodes for backward use. DCE removes unused forward versions. - if bwd_start is None: + Backward regions are identified by custom["phase"] == "backward" (user + annotation) or node.meta["autograd_backward"] == True (set automatically when + Dynamo traces torch.autograd.grad). When the user provides phase + annotations, only those annotated regions are used. + + The graph may contain multiple disjoint backward regions (e.g. chunked + loss). Regions that do not depend on recomputable forward nodes are + skipped. Only one region may require remat; if multiple do, we error + and ask the user to annotate which region to rematerialize. + """ + if not has_recomputable_ops(gm): return gm if has_recomputable_rng_ops(gm): @@ -71,12 +117,41 @@ def remat_using_tags_for_fwd_loss_bwd_graph(gm: fx.GraphModule) -> fx.GraphModul force_save_bw_mutation_src(gm) + # must_recompute (used inside _collect_backward_regions) requires + # cleanup_recompute_tags to have run first. + use_phase = _has_user_phase_annotation(gm) + regions = _collect_backward_regions(gm, use_phase) + if not regions: + return gm + + # User-annotated phase regions: multiple annotations is always an error. + if use_phase and len(regions) > 1: + raise RuntimeError( + f"Detected {len(regions)} disjoint backward regions annotated with " + 'phase: "backward" but remat only supports a single backward region. ' + "Please ensure only one contiguous region is annotated." + ) + + remat_regions = [(s, e) for s, e, needs in regions if needs] + + if len(remat_regions) > 1: + raise RuntimeError( + f"Detected {len(remat_regions)} disjoint backward regions that require recomputation, " + "but remat only supports one such region in a forward-loss-backward graph." + ) + + if not remat_regions: + return gm + + bwd_start, bwd_end = remat_regions[0] + + order = {node: idx for idx, node in enumerate(gm.graph.nodes)} new_graph = fx.Graph() env: dict[fx.Node, fx.Node] = {} recomputed_nodes: dict[fx.Node, fx.Node] = {} # Insert forward nodes - for node in list(gm.graph.nodes)[:bwd_start]: + for node in itertools.islice(gm.graph.nodes, 0, bwd_start): env[node] = new_graph.node_copy(node, lambda x: env[x]) @overload @@ -107,7 +182,7 @@ def _gather(n: fx.Node) -> None: return deps # Insert backward nodes - for node in list(gm.graph.nodes)[bwd_start:]: + for node in itertools.islice(gm.graph.nodes, bwd_start, bwd_end): # Gather all deps that need to be recomputed for this node deps = gather_recompute_deps(node) @@ -128,6 +203,9 @@ def _gather(n: fx.Node) -> None: env[node] = new_graph.node_copy(node, remat_input) + for node in itertools.islice(gm.graph.nodes, bwd_end, None): + env[node] = new_graph.node_copy(node, lambda x: env[x]) + new_gm = torch.fx.GraphModule(gm, new_graph) # DCE with custom is_impure_node (like default_partition) diff --git a/torch/_functorch/_activation_offloading/activation_offloading.py b/torch/_functorch/_activation_offloading/activation_offloading.py index 6296c9f4525a3..8f9cd2203d12f 100644 --- a/torch/_functorch/_activation_offloading/activation_offloading.py +++ b/torch/_functorch/_activation_offloading/activation_offloading.py @@ -1,14 +1,8 @@ -""" -Activation offloading for memory optimization in (more like post) partitioners. - -This module provides functionality to offload activations to CPU during forward pass -and reload them during backward pass, reducing GPU memory usage. +"""Activation offloading for memory optimization during compilation. -Additional TODO: -* given the fact that PT2 stream support is in active development, testings should - be done once that is more finalized. A issue currently known is that with streams, - each iteration will have its own offload streams, but the streams should be shared - across the iterations. +This module provides functionality to offload activations to CPU during the forward +pass and reload them during the backward pass, reducing GPU memory usage. It can be +applied to graphs produced by both AOT Autograd partitioners and make_fx-based tracing. """ import logging @@ -17,8 +11,11 @@ import torch import torch.fx as fx -from torch._dynamo.variables.streams import get_current_stream, new_event, new_stream -from torch._inductor import config as inductor_config +from torch._functorch._activation_offloading.offload_ops import ( # noqa: F401 -- registers ao::offload, ao::reload, ao::wait_tensor ops + offload, + reload, + wait_tensor, +) from torch._inductor.fx_passes.overlap_scheduling import benchmark_node, is_compute_node from torch._subclasses.fake_tensor import extract_tensor_metadata from torch.utils._ordered_set import OrderedSet @@ -36,20 +33,31 @@ GPU_RELOAD_PREFIX = "gpu_reload_" +def _find_all_effective_users(node: fx.Node, op_types: OpTypes) -> OrderedSet[fx.Node]: + """Find all effective users of a node, where view ops extend the lifetime + of the original node. If a user is a view op, recursively find users of + the view.""" + effective_users: OrderedSet[fx.Node] = OrderedSet() + for user in node.users: + if user.op == "output": + continue + effective_users.add(user) + if op_types.is_view(user): + effective_users.update(_find_all_effective_users(user, op_types)) + return effective_users + + @dataclass class ReloadNodeInfo: """ Information about backward reload related nodes for each reload operation. - Pattern: fork → wait_stream → device_put → record_event → join → wait_event + Pattern: ao.reload → ao.wait_tensor - This pattern is divided into two logical groups for optimization purposes: - - Reload group (fork → wait_stream → device_put → record_event → join): - Performs the actual asynchronous data transfer on a separate stream. - These nodes can be moved earlier in the graph to overlap with computation. - - Wait group (wait_event): - Synchronization point that blocks until the data transfer completes. - This must remain at the point where the reloaded data is first needed. + - Reload group (ao.reload): Performs the actual asynchronous data transfer. + Can be moved earlier in the graph to overlap with computation. + - Wait node (ao.wait_tensor): Synchronization point that blocks until the data + transfer completes. Must remain at the point where the data is first needed. """ reload_group_nodes: list[fx.Node] @@ -92,21 +100,6 @@ def offload_activation_fw(graph: fx.Graph) -> None: op_types: OpTypes = get_default_op_list() - def find_all_effective_users(node: fx.Node) -> OrderedSet[fx.Node]: - """ - Find all effective users of a node, where view ops extend the lifetime of the - original node. If a user is a view op, recursively find users of the view. - """ - effective_users: OrderedSet[fx.Node] = OrderedSet() - for user in node.users: - if user.op == "output": - continue - effective_users.add(user) - if op_types.is_view(user): - effective_users.update(find_all_effective_users(user)) - - return effective_users - output_node: fx.Node = graph.find_nodes(op="output")[0] # pyrefly: ignore [bad-assignment] fwd_outputs: tuple[fx.Node, ...] = output_node.args[ @@ -122,8 +115,7 @@ def find_all_effective_users(node: fx.Node) -> OrderedSet[fx.Node]: continue # Find insertion point, which is the last use - all_effective_users: OrderedSet[fx.Node] = find_all_effective_users(node) - if all_effective_users := find_all_effective_users(node): + if all_effective_users := _find_all_effective_users(node, op_types): last_user = max(all_effective_users, key=lambda n: node_to_index[n]) else: last_user: fx.Node = node @@ -193,6 +185,116 @@ def reload_activation_bw(graph: fx.Graph) -> None: user.replace_input_with(node, gpu_node) +def offload_activation_fw_async(graph: fx.Graph) -> None: + """Insert async CPU offload operations in the forward pass graph. + + Uses ao.offload + ao.wait_tensor ops which encapsulate stream management + internally, producing a clean 2-node IR per offloaded tensor. + """ + + op_types: OpTypes = get_default_op_list() + + output_node: fx.Node = graph.find_nodes(op="output")[0] + # pyrefly: ignore [bad-assignment] + fwd_outputs: tuple[fx.Node, ...] = output_node.args[ + 0 + ] # pyrefly: ignore [bad-assignment] + node_to_offload: dict[fx.Node, fx.Node] = dict() + node_to_index: dict[fx.Node, int] = { + node: idx for idx, node in enumerate(graph.nodes) + } + + if not any(n.meta.get("saved_for_offloading", False) for n in fwd_outputs): + return + + for node in fwd_outputs: + if node.meta.get("saved_for_offloading", False) is False: + continue + + if all_effective_users := _find_all_effective_users(node, op_types): + last_user = max(all_effective_users, key=lambda n: node_to_index[n]) + else: + last_user: fx.Node = node + + with graph.inserting_after(last_user): + offload_node: fx.Node = graph.call_function( + torch.ops.ao.offload.default, + args=(node,), + name=f"async_{CPU_OFFLOAD_PREFIX}{node.name}", + ) + offload_node.meta["val"] = node.meta["val"].to(torch.device("cpu")) + offload_node.meta["tensor_meta"] = extract_tensor_metadata( + offload_node.meta["val"] + ) + # The keepalive=node arg extends the GPU tensor's lifetime in the + # graph so the allocator doesn't reclaim it before the async D2H + # copy completes. + with graph.inserting_after(offload_node): + wait_node: fx.Node = graph.call_function( + torch.ops.ao.wait_tensor.default, + args=(offload_node, node), + name=CPU_OFFLOAD_PREFIX + str(node.name), + ) + wait_node.meta["val"] = offload_node.meta["val"] + wait_node.meta["tensor_meta"] = offload_node.meta["tensor_meta"] + + node_to_offload[node] = wait_node + + output_node.update_arg( + 0, tuple(node_to_offload.get(node, node) for node in fwd_outputs) + ) + + +def reload_activation_bw_async(graph: fx.Graph) -> None: + """Insert async GPU reload operations in the backward pass graph. + + Uses ao.reload + ao.wait_tensor ops which encapsulate stream management internally, + producing a clean 2-node IR per reloaded tensor. + """ + + node_to_index: dict[fx.Node, int] = { + node: idx for idx, node in enumerate(graph.nodes) + } + + nodes_to_reload = [ + n + for n in graph.find_nodes(op="placeholder") + if n.meta.get("saved_for_offloading", False) + ] + if not nodes_to_reload: + return + + for node in nodes_to_reload: + if not node.users: + raise RuntimeError( + f"Offloaded tensor {node.name} has no users in the backward graph" + ) + insert_point: fx.Node = min(node.users.keys(), key=lambda n: node_to_index[n]) + + original_device: torch.device = node.meta["original_device"] + with graph.inserting_before(insert_point): + reload_node: fx.Node = graph.call_function( + torch.ops.ao.reload.default, + args=(node, original_device), + name=f"async_{str(node.name).replace(CPU_OFFLOAD_PREFIX, GPU_RELOAD_PREFIX)}", + ) + reload_node.meta["val"] = node.meta["val"].to(original_device) + reload_node.meta["tensor_meta"] = extract_tensor_metadata( + reload_node.meta["val"] + ) + wait_node: fx.Node = graph.call_function( + torch.ops.ao.wait_tensor.default, + args=(reload_node,), + name=str(node.name).replace(CPU_OFFLOAD_PREFIX, GPU_RELOAD_PREFIX), + ) + wait_node.meta["val"] = reload_node.meta["val"] + wait_node.meta["tensor_meta"] = reload_node.meta["tensor_meta"] + + for user in list(node.users.keys()): + if user != reload_node: + user.replace_input_with(node, wait_node) + + def can_offload( node: fx.Node, fwd_outputs: OrderedSet[fx.Node], @@ -341,245 +443,129 @@ def offload_chosen_sets( reload_activation_bw(bwd_module.graph) -def add_forward_offload_stream_ops(graph: fx.Graph) -> None: +def offload_chosen_sets_async( + fwd_module: fx.GraphModule, + bwd_module: fx.GraphModule, +) -> None: """ - Add stream operations for forward pass CPU offloading. - - Pattern: record_event → fork → wait_event → record_stream → device_put → record_event_2 → join → wait_event_2 - - This ensures that: - 1. Offloading waits for the last use to complete (record_event on default stream) - 2. Offloading happens on a separate stream (fork → wait_event → device_put) - 3. The tensor is marked as used in the offload stream (record_stream) - 4. Execution returns to the default stream after offloading and - waits for offload to complete (record_event_2 → join → wait_event_2) + Add async offload and reload nodes using ao ops. - NOTE: For stream optimization and overlapping compute with communication, - the "wait_event_2" ops can be sinked to the end of the graph. - - Args: - graph: The forward graph to modify + Uses ao.offload/ao.reload + ao.wait_tensor which encapsulate stream management, + instead of device_put + explicit stream operations. Can be applied to + partitioned forward/backward graphs or to a joint graph produced by make_fx. """ - # Find all CPU offload nodes - offload_nodes: list[fx.Node] = [ - node - for node in graph.nodes - if CPU_OFFLOAD_PREFIX in node.name and node.op == "call_function" - ] - if not offload_nodes: - return - - # Get default stream id and offload stream id - current_stream_id: int = get_current_stream( - offload_nodes[0].args[0].meta["val"].device # type: ignore[assignment] - ) - offload_stream_id: int = new_stream() + offload_activation_fw_async(fwd_module.graph) - for offload_node in offload_nodes: - offload_ready_event_id: int = new_event() - offload_completion_event_id: int = new_event() - - # Get the tensor being offloaded - tensor_node: fx.Node = offload_node.args[0] # type: ignore[assignment] - - with graph.inserting_before(offload_node): - # Record event on default stream to ensure last use completes - graph.call_function( - torch.ops.streams.record_event.default, - args=(offload_ready_event_id, current_stream_id), - ) - # Fork to offload stream - graph.call_function( - torch.ops.streams.fork.default, - args=(current_stream_id, offload_stream_id), - name=f"stream_in_{offload_node.name}", - ) - # Wait for the event on offload stream - graph.call_function( - torch.ops.streams.wait_event.default, - args=(offload_ready_event_id, offload_stream_id), - ) - # Inform the CUDA Caching Allocator that this tensor will be accessed in the - # offload stream. Without this, the program may prematurely free its memory - # even though the async offload operation is still in progress, and this can - # lead to memory corruption, especially with reordering for compute and - # communication overlaps. - graph.call_function( - torch.ops.streams.record_stream.default, - args=(tensor_node, offload_stream_id), - name=f"record_stream_{tensor_node.name}", - ) - with graph.inserting_after(offload_node): - # Record event on offload stream after device_put completes - record_event_node = graph.call_function( - torch.ops.streams.record_event.default, - args=(offload_completion_event_id, offload_stream_id), - ) - with graph.inserting_after(record_event_node): - # Join back to default stream - join_node = graph.call_function( - torch.ops.streams.join.default, - args=(offload_stream_id, current_stream_id), - name=f"stream_out_{offload_node.name}", - ) - with graph.inserting_after(join_node): - # Wait for the offload to complete on default stream - graph.call_function( - torch.ops.streams.wait_event.default, - args=(offload_completion_event_id, current_stream_id), - ) + # Replace backward graph placeholders with their offloaded (CPU) counterparts. + # For each offloaded forward output, find the matching backward input and swap + # it with a new placeholder carrying the CPU tensor's metadata, then mark it + # for reloading. + bwd_inputs: dict[str, fx.Node] = { + node.name: node for node in bwd_module.graph.find_nodes(op="placeholder") + } + for fwd_node in fwd_module.graph.find_nodes(op="output")[0].args[0]: + if CPU_OFFLOAD_PREFIX not in fwd_node.name: + continue + bwd_node: fx.Node = bwd_inputs[fwd_node.name.replace(CPU_OFFLOAD_PREFIX, "")] + with bwd_module.graph.inserting_after(bwd_node): + bwd_offload_node: fx.Node = bwd_module.graph.placeholder(name=fwd_node.name) -def add_backward_reload_stream_ops(graph: fx.Graph) -> None: - """ - Add stream operations for backward pass GPU reloading. + bwd_offload_node.meta.update(fwd_node.meta) + bwd_offload_node.meta["saved_for_offloading"] = True + bwd_offload_node.meta["original_device"] = bwd_node.meta["val"].device + bwd_node.replace_all_uses_with(bwd_offload_node) + bwd_module.graph.erase_node(bwd_node) - Pattern: fork → wait_stream → device_put → record_event → join → wait_event + reload_activation_bw_async(bwd_module.graph) - This ensures that: - 1. Reloading doesn't start prematurely (fork → wait_stream) - 2. Reloading happens on a separate stream (device_put) - 3. First use waits for reload completion (record_event → join → wait_event) - NOTE: The pattern consists of two logical groups: - - First group (fork → wait_stream → device_put → record_event → join): - Performs asynchronous data transfer on a separate stream - - Second group (wait_event): - Data transfer completion check when the data is actually needed +def activation_offload_sink_wait_async(fwd_module: fx.GraphModule) -> None: + """Sink ao.wait_tensor operations for offload completion to the end of the graph. - For prefetch optimization, the first group can be moved earlier in the graph - to overlap computation with data transfer, while the wait_event must remain - at its current position to prevent blocking computation unnecessarily. + This allows computation to overlap with offload operations. - Args: - graph: The backward graph to modify + NOTE: Sinking waits to the end delays GPU memory release of the source + tensor (kept alive via the wait's keepalive arg) until the end of the + compiled graph. For per-layer compile this is fine (one layer's worth of + memory), but for full-model compile this means offloaded GPU tensors are + not freed until the entire forward pass completes. """ + graph: fx.Graph = fwd_module.graph + output_node: fx.Node = graph.find_nodes(op="output")[0] - # Find all GPU reload nodes - reload_nodes: list[fx.Node] = [ + wait_nodes_to_sink: list[fx.Node] = [ node for node in graph.nodes - if GPU_RELOAD_PREFIX in node.name and node.op == "call_function" + if node.op == "call_function" + and node.target == torch.ops.ao.wait_tensor.default + and isinstance(node.args[0], fx.Node) + and node.args[0].op == "call_function" + and node.args[0].target == torch.ops.ao.offload.default ] - if not reload_nodes: - return - - # Get default stream id and offload stream id - current_stream_id: int = get_current_stream( - reload_nodes[0].args[0].meta["original_device"] # type: ignore[assignment] - ) - reload_stream_id: int = new_stream() - - for reload_node in reload_nodes: - event_id: int = new_event() - - with graph.inserting_before(reload_node): - # Fork to reload stream - graph.call_function( - torch.ops.streams.fork.default, - args=(current_stream_id, reload_stream_id), - name=f"stream_in_{reload_node.name}", - ) - # Wait for default stream to prevent premature reloading - graph.call_function( - torch.ops.streams.wait_stream.default, - args=(reload_stream_id, current_stream_id), - ) - with graph.inserting_after(reload_node): - # Record event on reload stream after device_put - record_event_node = graph.call_function( - torch.ops.streams.record_event.default, - args=(event_id, reload_stream_id), - ) - with graph.inserting_after(record_event_node): - # Join back to default stream - join_node = graph.call_function( - torch.ops.streams.join.default, - args=(reload_stream_id, current_stream_id), - name=f"stream_out_{reload_node.name}", - ) - with graph.inserting_after(join_node): - # Wait for the event on default stream - graph.call_function( - torch.ops.streams.wait_event.default, - args=(event_id, current_stream_id), - ) - -def put_offload_nodes_on_separate_stream( - fwd_module: fx.GraphModule, - bwd_module: fx.GraphModule, -) -> None: - """ - Add stream and event related operations around offload nodes. - - Args: - fwd_module: Forward module graph - bwd_module: Backward module graph - """ - - add_forward_offload_stream_ops(fwd_module.graph) - add_backward_reload_stream_ops(bwd_module.graph) + # prepend moves the node from its current position (no manual removal needed) + for wait_node in wait_nodes_to_sink: + output_node.prepend(wait_node) -def _validate_pattern_nodes( - fork_node: fx.Node, - wait_stream_node: fx.Node, - record_event_node: fx.Node, - join_node: fx.Node, - wait_event_node: fx.Node, -) -> None: +def activation_reload_prefetch_async(bwd_module: fx.GraphModule) -> None: """ - Validate that the pattern nodes match the expected structure. - - Raises ValueError if any node doesn't match expectations. + Prefetch backward reload operations by moving ao.reload nodes earlier + in the graph to overlap data transfer with computation, while keeping + ao.wait_tensor at its original position. """ + graph: fx.Graph = bwd_module.graph + nodes_list: list[fx.Node] = list(graph.nodes) - if not ( - fork_node.op == "call_function" - and fork_node.target == torch.ops.streams.fork.default - ): - raise ValueError("Expected fork node two nodes before device_put node") - - if not ( - wait_stream_node.op == "call_function" - and wait_stream_node.target == torch.ops.streams.wait_stream.default - ): - raise ValueError("Expected wait_stream node one node before device_put node") - - if not ( - record_event_node.op == "call_function" - and record_event_node.target == torch.ops.streams.record_event.default - ): - raise ValueError("Expected record_event node one node after device_put node") - - if not ( - join_node.op == "call_function" - and join_node.target == torch.ops.streams.join.default - ): - raise ValueError("Expected join node two nodes after device_put node") + # Identify reload + wait pairs + reload_patterns: dict[fx.Node, ReloadNodeInfo] = {} + for node in graph.nodes: + if not ( + node.op == "call_function" and node.target == torch.ops.ao.reload.default + ): + continue + wait_node = next( + (u for u in node.users if u.target == torch.ops.ao.wait_tensor.default), + None, + ) + if wait_node is None: + continue + transfer_size_bytes: int = _calculate_transfer_size(node) + transfer_time_ms: float = _estimate_transfer_time_in_ms(transfer_size_bytes) + reload_patterns[node] = ReloadNodeInfo( + reload_group_nodes=[node], + wait_event_node=wait_node, + transfer_size_bytes=transfer_size_bytes, + transfer_time_ms=transfer_time_ms, + ) - if not ( - wait_event_node.op == "call_function" - and wait_event_node.target == torch.ops.streams.wait_event.default - ): - raise ValueError("Expected wait_event node three nodes after device_put node") + reorder_for_prefetch(nodes_list, reload_patterns) def _calculate_transfer_size(device_put_node: fx.Node) -> int: """Calculate the size in bytes of data being transferred.""" - return _size_of(device_put_node.args[0]) # pyrefly: ignore [bad-argument-type] + # ao.offload(tensor) -> tensor at args[0] + # ao.reload(tensor, device) -> tensor at args[0] + if device_put_node.target in ( + torch.ops.ao.offload.default, + torch.ops.ao.reload.default, + ): + return _size_of(device_put_node.args[0]) # pyrefly: ignore [bad-argument-type] + raise ValueError(f"Unexpected transfer op: {device_put_node.target}") def _estimate_transfer_time_in_ms(transfer_size_bytes: int) -> float: - """ - Estimate transfer time in milliseconds based on size and bandwidth. - NOTE: potentially could be standardized in node estimator class - """ + """Estimate transfer time in milliseconds based on size and bandwidth. - return transfer_size_bytes / (1024**3) * 1_000 / inductor_config.cpu_gpu_bw + Uses config.activation_offload_cpu_gpu_bw (GB/s) which should be set by + the user to match their hardware. + """ + return ( + transfer_size_bytes / (1024**3) * 1_000 / config.activation_offload_cpu_gpu_bw + ) def identify_reload_patterns( @@ -628,7 +614,8 @@ def identify_reload_patterns( wait_event_node: fx.Node = nodes_list[reload_node_idx + 3] # Validate the nodes are what we expect - _validate_pattern_nodes( + # Removed in follow-up commit + _validate_pattern_nodes( # noqa: F821 # pyrefly: ignore [unknown-name] fork_node, wait_stream_node, record_event_node, @@ -811,15 +798,16 @@ def enable_activation_offloading( return # Step 2: Add offload and reload nodes to the graphs - offload_chosen_sets(fwd_module, bwd_module) - - # Step 3: Put offload nodes on separate stream if configured if config.activation_offload_separate_stream: - put_offload_nodes_on_separate_stream(fwd_module, bwd_module) + # Use async ao ops (2 nodes each: offload/reload + wait_tensor) + offload_chosen_sets_async(fwd_module, bwd_module) if config.activation_offload_sink_wait: - activation_offload_sink_wait(fwd_module) + activation_offload_sink_wait_async(fwd_module) if config.activation_reload_prefetch: - activation_reload_prefetch(bwd_module) + activation_reload_prefetch_async(bwd_module) + else: + # Use synchronous device_put (1 node each) + offload_chosen_sets(fwd_module, bwd_module) fwd_module.graph.lint() bwd_module.graph.lint() diff --git a/torch/_functorch/_activation_offloading/offload_ops.py b/torch/_functorch/_activation_offloading/offload_ops.py new file mode 100644 index 0000000000000..5259b97adca3d --- /dev/null +++ b/torch/_functorch/_activation_offloading/offload_ops.py @@ -0,0 +1,180 @@ +"""Custom ops for async activation offloading between GPU and CPU. + +These ops encapsulate stream management internally, producing a clean 2-node +IR pattern (offload/reload + wait_tensor) similar to c10d functional collectives. + +A single dedicated transfer stream handles all D2H/H2D copies. +Completion events are keyed by output tensor data_ptr() and stored in a +module-level registry, so ``ao.wait_tensor`` takes only the tensor itself +(plus an optional keepalive). + +Offload pattern: + cpu_tensor = ao.offload(gpu_tensor) + cpu_tensor = ao.wait_tensor(cpu_tensor, gpu_tensor) + (keepalive arg extends gpu_tensor lifetime past the async D2H copy) + +Reload pattern: + gpu_tensor = ao.reload(cpu_tensor, device) + gpu_tensor = ao.wait_tensor(gpu_tensor) +""" + +import torch +from torch._library.custom_ops import custom_op +from torch.fx import has_side_effect + + +# --- Global transfer stream (one per device, lazily created) --- +_transfer_streams: dict[torch.device, torch.Stream] = {} + + +def _get_or_create_transfer_stream(device: torch.device) -> torch.Stream: + if device not in _transfer_streams: + _transfer_streams[device] = torch.Stream(device=device) + return _transfer_streams[device] + + +# --- Wait registry: maps data_ptr() -> (completion_event, device) --- +# Created by ao.offload / ao.reload, consumed (popped) by ao.wait_tensor. +# Not thread-safe — graph execution is single-threaded Python. +_wait_registry: dict[int, tuple[torch.Event, torch.device]] = {} + + +def _register_wait(tensor: torch.Tensor, device: torch.device) -> torch.Event: + """Create an event for an async transfer and register it for wait_tensor.""" + event = torch.Event() + _wait_registry[tensor.data_ptr()] = (event, device) + return event + + +def _pop_wait(tensor: torch.Tensor) -> tuple[torch.Event, torch.device]: + key = tensor.data_ptr() + try: + return _wait_registry.pop(key) + except KeyError: + raise RuntimeError( + f"ao.wait_tensor: no pending transfer for tensor with data_ptr={key}. " + "Every ao.wait_tensor must be paired with a preceding ao.offload or ao.reload." + ) from None + + +def _clear_wait_registry() -> None: + _wait_registry.clear() + + +@custom_op("ao::offload", mutates_args=()) +def offload(tensor: torch.Tensor) -> torch.Tensor: + """Async offload a GPU tensor to CPU on the dedicated transfer stream. + + Callers MUST pair this with an ``ao.wait_tensor`` that passes the source GPU + tensor as ``keepalive`` to extend its lifetime past the async D2H copy. + Do NOT use ``record_stream`` — it causes memory fragmentation and + unbounded memory growth. + + Uses pinned-memory allocation + copy_ so the transfer is compatible + with CUDA graph capture. + """ + device = tensor.device + transfer_stream = _get_or_create_transfer_stream(device) + current_stream = torch.accelerator.current_stream(device) + + transfer_stream.wait_stream(current_stream) + + torch.accelerator.set_stream(transfer_stream) + result = torch.empty_like(tensor, device="cpu", pin_memory=True) + completion_event = _register_wait(result, device) + result.copy_(tensor, non_blocking=True) + transfer_stream.record_event(completion_event) + torch.accelerator.set_stream(current_stream) + + return result + + +@offload.register_fake +def _(tensor: torch.Tensor) -> torch.Tensor: + return torch.empty_like(tensor, device="cpu") + + +@custom_op("ao::reload", mutates_args=()) +def reload( + tensor: torch.Tensor, + device: torch.device, +) -> torch.Tensor: + """Async reload a CPU tensor to GPU on the dedicated transfer stream. + + The GPU tensor is allocated on the compute stream to avoid cross-stream + allocator ownership issues. The H2D copy runs on the transfer stream. + The completion event is keyed by the output tensor's data_ptr. + """ + transfer_stream = _get_or_create_transfer_stream(device) + current_stream = torch.accelerator.current_stream(device) + + # Allocate on compute stream so the allocator tracks ownership correctly + result = torch.empty_like(tensor, device=device) + completion_event = _register_wait(result, device) + + transfer_stream.wait_stream(current_stream) + + torch.accelerator.set_stream(transfer_stream) + result.copy_(tensor, non_blocking=True) + transfer_stream.record_event(completion_event) + torch.accelerator.set_stream(current_stream) + + return result + + +@reload.register_fake +def _( + tensor: torch.Tensor, + device: torch.device, +) -> torch.Tensor: + return torch.empty_like(tensor, device=device) + + +# ao::wait_tensor is defined via torch.library with an aliasing schema so the +# output can alias the input (custom_op forbids this). +# +# Uses CompositeExplicitAutograd (single impl for all devices) because the +# offload case has mixed-device args: ``tensor`` is CPU (the offload result) +# while ``keepalive`` is CUDA (the source GPU tensor). A single impl avoids +# relying on device-priority dispatch ordering. +# +# Synchronization details (completion event, device) are looked up from +# ``_wait_registry`` keyed on ``tensor.data_ptr()``. +# +# ``keepalive`` is not read by the op — its sole purpose is to create a graph +# dependency that extends the tensor's lifetime in the FX graph. For offload, +# this keeps the source GPU tensor alive until the compute stream has waited +# on the D2H completion event, preventing the allocator from reclaiming it +# while the async copy is still in flight. +_lib = torch.library.Library("ao", "DEF") +_lib.define("wait_tensor(Tensor(a) tensor, Tensor? keepalive=None) -> Tensor(a)") + + +@torch.library.impl("ao::wait_tensor", "CompositeExplicitAutograd") +def _ao_wait_tensor( + tensor: torch.Tensor, + keepalive: torch.Tensor | None = None, +) -> torch.Tensor: + completion_event, device = _pop_wait(tensor) + current_stream = torch.accelerator.current_stream(device) + current_stream.wait_event(completion_event) + return tensor + + +@torch.library.register_fake("ao::wait_tensor") +def _ao_wait_tensor_fake( + tensor: torch.Tensor, + keepalive: torch.Tensor | None = None, +) -> torch.Tensor: + return tensor + + +has_side_effect(torch.ops.ao.wait_tensor.default) + + +def wait_tensor( + tensor: torch.Tensor, + keepalive: torch.Tensor | None = None, +) -> torch.Tensor: + """Callable wrapper so ``wait_tensor`` can be imported by name for op registration.""" + return torch.ops.ao.wait_tensor.default(tensor, keepalive) diff --git a/torch/_functorch/_aot_autograd/aot_autograd_result.py b/torch/_functorch/_aot_autograd/aot_autograd_result.py index ea56c74798016..83ab2240b01c9 100644 --- a/torch/_functorch/_aot_autograd/aot_autograd_result.py +++ b/torch/_functorch/_aot_autograd/aot_autograd_result.py @@ -39,6 +39,7 @@ from .runtime_wrappers import ( AOTDispatchAutograd, + AOTDispatchAutogradCompileSpec, AOTDispatchSubclassWrapper, CachedAutogradLazyBackwardCompileInfo, CompilerWrapper, @@ -124,9 +125,22 @@ def post_compile( payload_fn=lambda: json.dumps(cache_info), ) result = graph # type: ignore[assignment] + result.compile_region_name = ( # pyrefly: ignore[missing-attribute] + fx_config.get("compile_region_name") + ) # Run normal post compile result.post_compile(self.example_inputs, constants, fx_config) + + # Let the CUDAGraph policy do outer-level wrapping (e.g. wrapping + # an entire RegionalOutputCode as a single CUDA graph instead of + # per-inner-region). + import torch._inductor.config as _inductor_config + + policy = _inductor_config.cudagraph_policy + if policy is not None: + result = policy.wrap_output(result) + return result @@ -218,7 +232,17 @@ def post_compile( """ Called after FXGraphCacheLoadable.load, mutates fx_config """ + result.compile_region_name = fx_config.get( # pyrefly: ignore[bad-assignment] + "compile_region_name" + ) result.post_compile(self.example_inputs, self.constants, fx_config) + + import torch._inductor.config as _inductor_config + + policy = _inductor_config.cudagraph_policy + if policy is not None: + result = policy.wrap_output(result) + return result @@ -238,27 +262,28 @@ class GenericCompiledBackward(InductorOutput[TOut]): backward_state_indices: list[int] num_symints_saved_for_bw_: int + def post_compile(self, result: TOut, fx_config: _CompileFxKwargs) -> TOut: + # The concrete post_compile comes from the loadable mixin in each subclass MRO. + compiled_bw = super().post_compile( # pyrefly: ignore[missing-attribute] + result, fx_config + ) + # See note [Wrapping bw_compiler in disable] + # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py + # But since on cache hit we do not call the bw_compiler, we need to reapply the disable + return torch._dynamo.disable( # type: ignore[return-value] + compiled_bw, reason="do not trace generated backwards pass" + ) + @dataclass class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoadable): """ - Cacheable entry for a forward function + Cacheable entry for a backward function """ def _is_backward(self) -> bool: return True - def post_compile( - self, result: CompiledFxGraph, fx_config: _CompileFxKwargs - ) -> CompiledFxGraph: - compiled_bw = super().post_compile(result, fx_config) - # See note [Wrapping bw_compiler in disable] - # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py - # But since on cache hit we do not call the bw_compiler, we need to reapply the disable - return torch._dynamo.disable( # type: ignore[return-value] - compiled_bw, reason="do not trace generated backwards pass" - ) - # Generic bundled forward/backward classes that work with any OutputCode type @dataclass @@ -282,17 +307,6 @@ class BundledCompiledBackward( Works with any OutputCode type (CompiledFxGraph, RegionalOutputCode, etc.) """ - def post_compile( - self, result: TOutputCode, fx_config: _CompileFxKwargs - ) -> TOutputCode: - compiled_bw = super().post_compile(result, fx_config) - # See note [Wrapping bw_compiler in disable] - # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py - # But since on cache hit we do not call the bw_compiler, we need to reapply the disable - return torch._dynamo.disable( # type: ignore[return-value] - compiled_bw, reason="do not trace generated backwards pass" - ) - @dataclass class SerializedGraphModule: @@ -382,123 +396,115 @@ def pre_save(self) -> None: if self.compiled_bw is not None: self.compiled_bw.pre_save() - # Turn result into the original callable - def wrap_post_compile( - self, - args: list[torch.Tensor], - aot_config: AOTConfig, - fx_config: _CompileFxKwargs, - # pyrefly: ignore [implicit-any] - ) -> Callable: - """ - This function takes a result and carefully reconstructs the original callable - that AOTAutograd returned the first time it was run. It does this by running the various - post compile steps that AOTAutograd runs on its compiled artifact after running the fw/bw compilers. - - In the inference path, this consists of the Subclass, FunctionalzedRngRuntime, and RuntimeWrappers. - In the autograd path, this consists of AOTAutogradDispatch.post_compile. - - The steps here should match exactly the steps that are run in aot_dispatch_base and aot_dispatch_autograd. + def _log_cached_graphs(self, aot_config: AOTConfig) -> None: + if not aot_config.enable_log: + return - Notably absent from the cached path are: - - DebugAssertWrapper - - FakifiedOutWrapper - - Which we'll handle separately later on, if necessary. - """ - from torch._dynamo.utils import CompileEventLogger, dynamo_timed - - # Log the output of AOTAutogradCache - if aot_config.enable_log: - if self.aot_joint_graph_str is not None: - torch._logging.trace_structured( - "aot_joint_graph", payload_fn=lambda: self.aot_joint_graph_str - ) - aot_graphs_log.info( - "Joint graph (from cache)\n\n%s", self.aot_joint_graph_str - ) + if self.aot_joint_graph_str is not None: + torch._logging.trace_structured( + "aot_joint_graph", payload_fn=lambda: self.aot_joint_graph_str + ) + aot_graphs_log.info( + "Joint graph (from cache)\n\n%s", self.aot_joint_graph_str + ) - if self.aot_forward_graph_str is not None: - from torchgen.utils import dataclass_repr + if self.aot_forward_graph_str is not None: + from torchgen.utils import dataclass_repr + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(self.runtime_metadata), + ) + if self.maybe_subclass_meta is not None: torch._logging.trace_structured( "artifact", metadata_fn=lambda: { - "name": "aot_forward_graph_fw_metadata", + "name": "aot_forward_graph_fw_subclass_metadata", "encoding": "string", }, - payload_fn=lambda: dataclass_repr(self.runtime_metadata), - ) - if self.maybe_subclass_meta is not None: - torch._logging.trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "aot_forward_graph_fw_subclass_metadata", - "encoding": "string", - }, - payload_fn=lambda: dataclass_repr(self.maybe_subclass_meta), - ) - - # It's called an inference graph if not running with autograd - has_backward = self.aot_backward_graph_str is not None - torch._logging.trace_structured( - "aot_forward_graph" if has_backward else "aot_inference_graph", - payload_fn=lambda: self.aot_forward_graph_str, - ) - aot_graphs_log.info( - "Forward graph (from cache)\n\n%s", - self.aot_forward_graph_str, + payload_fn=lambda: dataclass_repr(self.maybe_subclass_meta), ) - if self.aot_backward_graph_str is not None: - torch._logging.trace_structured( - "aot_backward_graph", payload_fn=lambda: self.aot_backward_graph_str - ) - aot_graphs_log.info( - "Backward graph (from cache)\n\n%s", - self.aot_backward_graph_str, - ) - with dynamo_timed("AOTAutogradCache.inductor_load"): - compiled_fw_func = self.compiled_fw.load(args) - compiled_bw_func = None - if self.compiled_bw is not None: - compiled_bw_func = self.compiled_bw.load(args) - needs_autograd = True - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="autograd" - ) - # Now that we've loaded forward and backward, call post compile on both - # This avoids setting things like BoxedBools in fx_config until - # after both forward and backward cache hit - fw_fx_config: _CompileFxKwargs = { - **fx_config, - "is_backward": False, - } - bw_fx_config: _CompileFxKwargs = { - **fx_config, - "is_backward": True, - } - compiled_fw_func = self.compiled_fw.post_compile( - compiled_fw_func, fw_fx_config - ) - compiled_bw_func = self.compiled_bw.post_compile( - compiled_bw_func, bw_fx_config - ) - else: - inference_fx_config: _CompileFxKwargs = { - **fx_config, - "is_backward": False, - } - - needs_autograd = False - CompileEventLogger.try_add_pt2_compile( - "backend_compile", dispatch_mode="inference" - ) - compiled_fw_func = self.compiled_fw.post_compile( - compiled_fw_func, inference_fx_config - ) + # It's called an inference graph if not running with autograd + has_backward = self.aot_backward_graph_str is not None + torch._logging.trace_structured( + "aot_forward_graph" if has_backward else "aot_inference_graph", + payload_fn=lambda: self.aot_forward_graph_str, + ) + aot_graphs_log.info( + "Forward graph (from cache)\n\n%s", + self.aot_forward_graph_str, + ) + + if self.aot_backward_graph_str is not None: + torch._logging.trace_structured( + "aot_backward_graph", payload_fn=lambda: self.aot_backward_graph_str + ) + aot_graphs_log.info( + "Backward graph (from cache)\n\n%s", + self.aot_backward_graph_str, + ) + + def _load_and_post_compile( + self, + args: list[torch.Tensor], + fx_config: _CompileFxKwargs, + ) -> tuple[Callable[..., Any], Callable[..., Any] | None, bool]: + from torch._dynamo.utils import CompileEventLogger + + compiled_fw_func = self.compiled_fw.load(args) + if self.compiled_bw is not None: + compiled_bw_func = self.compiled_bw.load(args) + needs_autograd = True + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="autograd" + ) + # Now that we've loaded forward and backward, call post compile on both + # This avoids setting things like BoxedBools in fx_config until + # after both forward and backward cache hit + fw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } + bw_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": True, + } + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, fw_fx_config + ) + compiled_bw_func = self.compiled_bw.post_compile( + compiled_bw_func, bw_fx_config + ) + return compiled_fw_func, compiled_bw_func, needs_autograd + + inference_fx_config: _CompileFxKwargs = { + **fx_config, + "is_backward": False, + } + + needs_autograd = False + CompileEventLogger.try_add_pt2_compile( + "backend_compile", dispatch_mode="inference" + ) + compiled_fw_func = self.compiled_fw.post_compile( + compiled_fw_func, inference_fx_config + ) + return compiled_fw_func, None, needs_autograd + + def _apply_runtime_wrappers( + self, + compiled_fw_func: Callable[..., Any], + compiled_bw_func: Callable[..., Any] | None, + needs_autograd: bool, + aot_config: AOTConfig, + ) -> Callable[..., Any]: + from torch._dynamo.utils import CompileEventLogger - # Wrap the forward function in post compile wrappers compiled_fw_func = AOTDispatchSubclassWrapper( trace_joint=needs_autograd, fw_only=None, @@ -538,19 +544,20 @@ def wrap_post_compile( # 1. the bw is already compiled # 2. we don't need to save to the cache again # so those corresponding arguments are set to None. - compiled_function = AOTDispatchAutograd.post_compile( - compiled_fw_func, - compiled_bw_func, - self.maybe_subclass_meta, - self.compiled_bw.num_symints_saved_for_bw_, - self.compiled_bw.backward_state_indices, - disable_amp, - self.indices_of_inps_to_detach, - cached_lazy_backward, - aot_config, + compile_spec = AOTDispatchAutogradCompileSpec( + compiled_fw_func=compiled_fw_func, + compiled_bw_func=compiled_bw_func, + maybe_subclass_meta=self.maybe_subclass_meta, + num_symints_saved_for_bw=self.compiled_bw.num_symints_saved_for_bw_, + backward_state_indices=self.compiled_bw.backward_state_indices, + disable_amp=disable_amp, + indices_of_inps_to_detach=self.indices_of_inps_to_detach, + lazy_backward_info=cached_lazy_backward, + aot_config=aot_config, fw_metadata=self.runtime_metadata, try_save_cache_entry=None, ) + compiled_function = AOTDispatchAutograd.post_compile(compile_spec) else: compiled_function = RuntimeWrapper( @@ -568,9 +575,9 @@ def wrap_post_compile( aot_config, runtime_metadata=self.runtime_metadata, ) + return compiled_function - # Now that we're pretty sure it's a successful load, add guards - # to the existing shape environment from the cache + def _check_guards(self, args: list[torch.Tensor]) -> None: if self.guards_expr: from .autograd_cache import AOTAutogradCache @@ -579,6 +586,43 @@ def wrap_post_compile( if check is not True: raise AssertionError(f"guards check failed: {check}") + # Turn result into the original callable + def wrap_post_compile( + self, + args: list[torch.Tensor], + aot_config: AOTConfig, + fx_config: _CompileFxKwargs, + ) -> Callable[..., Any]: + """ + This function takes a result and carefully reconstructs the original callable + that AOTAutograd returned the first time it was run. It does this by running the various + post compile steps that AOTAutograd runs on its compiled artifact after running the fw/bw compilers. + + In the inference path, this consists of the Subclass, FunctionalzedRngRuntime, and RuntimeWrappers. + In the autograd path, this consists of AOTAutogradDispatch.post_compile. + + The steps here should match exactly the steps that are run in aot_dispatch_base and aot_dispatch_autograd. + + Notably absent from the cached path are: + - DebugAssertWrapper + - FakifiedOutWrapper + + Which we'll handle separately later on, if necessary. + """ + from torch._dynamo.utils import dynamo_timed + + self._log_cached_graphs(aot_config) + with dynamo_timed("AOTAutogradCache.inductor_load"): + compiled_fw_func, compiled_bw_func, needs_autograd = ( + self._load_and_post_compile(args, fx_config) + ) + + compiled_function = self._apply_runtime_wrappers( + compiled_fw_func, compiled_bw_func, needs_autograd, aot_config + ) + # Now that we're pretty sure it's a successful load, add guards + # to the existing shape environment from the cache. + self._check_guards(args) return compiled_function diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 85edc69358356..6af2fed98fa0b 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -7,6 +7,7 @@ import base64 import contextlib import functools +import hashlib import json import logging import os @@ -22,11 +23,17 @@ import torch from torch._dynamo.precompile_context import PrecompileContext from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions -from torch._dynamo.utils import chromium_event_log_active, CompileEventLogger, counters +from torch._dynamo.utils import ( + chromium_event_log_active, + CompileEventLogger, + counters, + warn_once, +) from torch._functorch import config from torch._inductor.codecache import ( _ident, add_ephemeral_timeout_increase_for_distributed, + AOTAUTOGRAD_CACHE_PREFIX, BypassFxGraphCache, create_cache, extract_tensor_metadata_for_cache_key, @@ -52,7 +59,7 @@ CacheArtifactFactory, CacheArtifactManager, ) -from torch.fx.experimental.symbolic_shapes import size_hint +from torch.fx.experimental.symbolic_shapes import guarding_hint_or_throw from torch.fx.node import Node from torch.utils._triton import has_triton_package @@ -72,14 +79,13 @@ SerializableCompiledFunction, SubclassMeta, ) -from .schemas import AOTAutogradCacheInfo, AOTConfig, ViewAndMutationMeta # noqa: F401 +from .schemas import AOTAutogradCacheInfo, AOTConfig, ViewAndMutationMeta if TYPE_CHECKING: from collections.abc import Callable, Generator, Sequence - from torch._inductor.compile_fx import _CompileFxKwargs - from torch._inductor.cudagraph_utils import BoxedDeviceIndex + from torch._inductor.compile_fx import _CompileFxKwargs, CompilerConfigExtra from torch._inductor.remote_cache import JsonDataTy, RemoteCache @@ -315,6 +321,14 @@ def _get_context_fn_cache_hash(context_fn: Callable[..., Any]) -> str | None: return None +def _iter_graph_modules( + gm: torch.fx.GraphModule, +) -> Generator[torch.fx.GraphModule, None, None]: + for module in gm.modules(): + if isinstance(module, torch.fx.GraphModule): + yield module + + def _collect_context_fn_hashes(gm: torch.fx.GraphModule) -> list[str]: """ Collect cache hashes from all context_fn used in SAC HOPs within the graph module. @@ -323,9 +337,7 @@ def _collect_context_fn_hashes(gm: torch.fx.GraphModule) -> list[str]: lacks a cache_hash attribute. """ hashes = [] - for module in gm.modules(): - if not isinstance(module, torch.fx.GraphModule): - continue + for module in _iter_graph_modules(gm): context_fn = module.meta.get("_checkpoint_context_fn") if context_fn is not None: cache_hash = _get_context_fn_cache_hash(context_fn) @@ -342,6 +354,32 @@ def _collect_context_fn_hashes(gm: torch.fx.GraphModule) -> list[str]: return hashes +def _collect_wrapped_user_cache_hashes(gm: torch.fx.GraphModule) -> list[str]: + wrapped_user_cache_hashes = [] + for node in gm.graph.nodes: + if node.meta and node.meta.get("is_wrapped", False): + wrapped_user_cache_hashes.append(node.meta["user_cache_hash"]) + return wrapped_user_cache_hashes + + +def _collect_saved_tensors_hooks_fx_wrap_cache_hashes( + gm: torch.fx.GraphModule, +) -> tuple[list[str], list[str]]: + if not hasattr(gm, "saved_tensors_hooks_pack_0"): + return ([], []) + + return ( + _collect_wrapped_user_cache_hashes( + # pyrefly: ignore[bad-argument-type] + gm.saved_tensors_hooks_pack_0 + ), + _collect_wrapped_user_cache_hashes( + # pyrefly: ignore[bad-argument-type] + gm.saved_tensors_hooks_unpack_0 + ), + ) + + def _get_custom_estimator_solver_uuids( autograd_config: Any, ) -> tuple[object | None, object | None]: @@ -401,6 +439,20 @@ class AOTAutogradCacheDetails(FxGraphHashDetails): a safe and stable cache key for AOTAutograd. """ + def _iter_triton_kernels_from_node(self, node: Node) -> Generator[Any, None, None]: + if isinstance(node.target, torch._ops.OpOverloadPacket): + for attr in node.target._dir: + custom_op = getattr(node.target, attr, None) + if custom_op is not None: + yield from torch._library.triton.get_triton_kernels_for_op( + custom_op._name + ) + return + if isinstance(node.target, torch._ops.OpOverload): + yield from torch._library.triton.get_triton_kernels_for_op( + node.target._name + ) + def get_triton_source_codes_from_gm( self, gm: torch.fx.GraphModule, @@ -409,23 +461,9 @@ def get_triton_source_codes_from_gm( raise AssertionError("Triton is not available") triton_kernels = [] - for module in gm.modules(): - if not isinstance(module, torch.fx.GraphModule): - continue + for module in _iter_graph_modules(gm): for node in module.graph.nodes: - if isinstance(node.target, torch._ops.OpOverloadPacket): - attrs = node.target._dir - for attr in attrs: - if custom_op := getattr(node.target, attr, None): - kernels = torch._library.triton.get_triton_kernels_for_op( - custom_op._name - ) - triton_kernels.extend(kernels) - elif isinstance(node.target, torch._ops.OpOverload): - kernels = torch._library.triton.get_triton_kernels_for_op( - node.target._name - ) - triton_kernels.extend(kernels) + triton_kernels.extend(self._iter_triton_kernels_from_node(node)) triton_kernel_source_codes = [] from torch._inductor.codegen.wrapper import ( @@ -452,47 +490,37 @@ def __init__( aot_config: AOTConfig, fx_config: _CompileFxKwargs, ) -> None: - # FxGraphHashDetails contains all the keys related to inductor. Also includes some system info + # FxGraphHashDetails contains all the keys related to inductor. Also + # includes some system info. self.aot_config = aot_config + self._record_runtime_state(gm) + self.saved_tensors_hooks_fx_wrap_cache_hashes = ( + _collect_saved_tensors_hooks_fx_wrap_cache_hashes(gm) + ) + self.sac_context_fn_hashes = _collect_context_fn_hashes(gm) + + # Note: We use the live config module, not self.autograd_config (the + # saved config), because activation_memory_budget_runtime_estimator and + # activation_memory_budget_solver are excluded from save_config (in + # _save_config_ignore) since they're not serializable. We must access the + # config module directly to get the patched runtime values. + self.custom_estimator_solver_uuids = _get_custom_estimator_solver_uuids(config) + self._init_fx_graph_hash_details(gm, example_inputs, fx_config) + + def _record_runtime_state(self, gm: torch.fx.GraphModule) -> None: self.grad_enabled = torch.is_grad_enabled() self.disable_amp = torch._C._is_any_autocast_enabled() self.deterministic_algorithms = torch.are_deterministic_algorithms_enabled() self.autograd_config = config.save_config() - self.saved_tensors_hooks_fx_wrap_cache_hashes: tuple[list[str], list[str]] = ( - [], - [], - ) if has_triton_package(): self.triton_kernel_source_codes = self.get_triton_source_codes_from_gm(gm) - if hasattr(gm, "saved_tensors_hooks_pack_0"): - - def _add_wrapped_user_cache_hashes( - _gm: torch.fx.GraphModule, _l: list[str] - ) -> None: - for node in _gm.graph.nodes: - if node.meta and node.meta.get("is_wrapped", False): - _l.append(node.meta["user_cache_hash"]) - - _add_wrapped_user_cache_hashes( - # pyrefly: ignore[bad-argument-type] - gm.saved_tensors_hooks_pack_0, - self.saved_tensors_hooks_fx_wrap_cache_hashes[0], - ) - _add_wrapped_user_cache_hashes( - # pyrefly: ignore[bad-argument-type] - gm.saved_tensors_hooks_unpack_0, - self.saved_tensors_hooks_fx_wrap_cache_hashes[1], - ) - - self.sac_context_fn_hashes: list[str] = _collect_context_fn_hashes(gm) - - # Note: We use the live config module, not self.autograd_config (the saved config), - # because activation_memory_budget_runtime_estimator and activation_memory_budget_solver - # are excluded from save_config (in _save_config_ignore) since they're not serializable. - # We must access the config module directly to get the patched runtime values. - self.custom_estimator_solver_uuids = _get_custom_estimator_solver_uuids(config) - + def _init_fx_graph_hash_details( + self, + gm: torch.fx.GraphModule, + example_inputs: Sequence[Any], + fx_config: _CompileFxKwargs, + ) -> None: try: # FXGraphCache has constraints on what can be pickled in its inductor # config. Check that the gm is cacheable by inductor first, @@ -516,6 +544,116 @@ def __init__(self, gm: torch.fx.GraphModule) -> None: } ) + # pyrefly: ignore [bad-override] + def reducer_override(self, obj: Any) -> Any: + """ + Override to handle tensor subclasses (like DTensor) that aren't caught + by the dispatch_table's exact type matching. + + The dispatch_table only matches exact types, so subclasses like DTensor + fall through to the default __reduce_ex__ which includes non-deterministic + storage addresses. This method catches those cases using isinstance checks. + """ + # Handle tensor subclasses that aren't exactly torch.Tensor + # dispatch_table already handles torch.Tensor exactly + if isinstance(obj, torch.Tensor) and type(obj) is not torch.Tensor: + return self._reduce_tensor_subclass(obj) + # Return NotImplemented to fall back to default behavior + return NotImplemented + + # [NOTE] Tensor subclass stable hashing for AOT autograd cache + # Python's hash() varies with PYTHONHASHSEED, making cache keys unstable + # across processes. We use blake2b for cross-process determinism. + # + # EXTENSION POINT: Traceable wrapper subclasses can override cache key + # generation by implementing _stable_hash_for_caching(self) -> str. + # This method should return a deterministic string that uniquely identifies + # the tensor's metadata for caching purposes. See DTensor for an example. + # + # We can't define a default method on subclasses because there is no abstract + # base subclass, and we don't want to pollute torch.Tensor. Instead, we provide + # a default implementation here that uses __tensor_flatten__ to recursively + # hash inner tensors and metadata. + + def _hash_bytes_for_cache(self, data: bytes) -> str: + return hashlib.blake2b(data, digest_size=16).hexdigest() + + def _hash_pickled_value_for_cache(self, value: Any) -> str: + return self._hash_bytes_for_cache(pickle.dumps(value)) + + def _stable_hash_for_cache_value(self, obj: Any) -> str: + """Get a stable hash for an object used inside tensor subclass metadata.""" + from torch._opaque_base import OpaqueBase + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if hasattr(obj, "_stable_hash_for_caching"): + return obj._stable_hash_for_caching() + if isinstance(obj, torch.Tensor) and is_traceable_wrapper_subclass(obj): + return self._default_stable_hash_for_caching(obj) + if isinstance(obj, OpaqueBase): + # Opaque objects are runtime pass-throughs; only the type matters + # for cache key purposes, not the instance identity or value. + return self._hash_bytes_for_cache(type(obj).__qualname__.encode()) + if isinstance(obj, torch.Tensor): + return self._hash_pickled_value_for_cache( + extract_tensor_metadata_for_cache_key(obj) + ) + return self._hash_pickled_value_for_cache(obj) + + def _reduce_tensor_subclass( + self, tensor: torch.Tensor + ) -> tuple[Callable[..., Any], tuple[Any]]: + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + if hasattr(tensor, "_stable_hash_for_caching"): + return (_ident, (tensor._stable_hash_for_caching(),)) + if is_traceable_wrapper_subclass(tensor): + warn_once( + f"{type(tensor).__name__} does not implement _stable_hash_for_caching. " + "For PT2-compatible tensor subclasses, it is recommended to implement " + "_stable_hash_for_caching(self) -> str for stable AOT autograd caching." + ) + return (_ident, (self._default_stable_hash_for_caching(tensor),)) + return self._reduce_tensor(tensor) + + def _collect_inner_tensor_hashes( + self, tensor: torch.Tensor, inner_tensor_names: Sequence[str] + ) -> dict[str, str]: + inner_hashes = {} + for name in inner_tensor_names: + inner_hashes[name] = self._stable_hash_for_cache_value( + getattr(tensor, name) + ) + return inner_hashes + + def _stabilize_tensor_subclass_metadata(self, obj: Any) -> Any: + from torch._opaque_base import OpaqueBase + + if isinstance(obj, OpaqueBase): + return type(obj).__qualname__ + if isinstance(obj, tuple): + return tuple(self._stabilize_tensor_subclass_metadata(x) for x in obj) + if isinstance(obj, list): + return [self._stabilize_tensor_subclass_metadata(x) for x in obj] + if isinstance(obj, dict): + return { + k: self._stabilize_tensor_subclass_metadata(v) for k, v in obj.items() + } + return obj + + def _default_stable_hash_for_caching(self, tensor: torch.Tensor) -> str: + """ + Default stable hash implementation for traceable wrapper subclasses. + """ + inner_tensor_names, subclass_metadata = tensor.__tensor_flatten__() # type: ignore[attr-defined] + cache_payload = ( + tensor.shape, + tensor.requires_grad, + self._stabilize_tensor_subclass_metadata(subclass_metadata), + self._collect_inner_tensor_hashes(tensor, inner_tensor_names), + ) + return self._hash_pickled_value_for_cache(cache_payload) + def _reduce_aot_config( self, aot_config: AOTConfig ) -> tuple[Callable[..., Any], tuple[Any, ...]]: @@ -600,55 +738,95 @@ def normalize_placeholder_names( gm.recompile() +def create_fx_config( + compiler_config_extra: CompilerConfigExtra | None = None, + compile_region_name: str | None = None, +) -> _CompileFxKwargs: + if compiler_config_extra is None: + cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) + boxed_forward_device_index = None + else: + cudagraphs = compiler_config_extra.cudagraphs + boxed_forward_device_index = compiler_config_extra.forward_device + return { + "cudagraphs": cudagraphs, + "boxed_forward_device_index": boxed_forward_device_index, + "compile_region_name": compile_region_name, # pyrefly: ignore[bad-typed-dict-key] + } + + +def _check_triton_cache_version() -> None: + if not has_triton_package(): + return + + # Due to https://github.com/triton-lang/triton/issues/3729, if triton is < + # 3.2.0, AOTAutogradCache may cause us to attempt to load a cache entry + # without initializing the CUDA context on the autograd thread. + # + # Without caching, we naturally do this initialization when tracing through + # the graph with the autograd engine. + import triton + + if triton.__version__ < "3.2.0": + raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0") + + +def _get_debug_lines_for_cache_key( + pickler: AOTAutogradCachePickler, + details: AOTAutogradCacheDetails, + key: str, +) -> list[str]: + # debug_lines re-hashes every attribute individually and is expensive. Only + # compute when debug logging is enabled. + if not log.isEnabledFor(logging.DEBUG): + return [] + + debug_lines = pickler.debug_lines(details) + log.debug( + "Autograd graph cache hash details for key %s:\n%s", + key, + LazyString(lambda: "\n".join(debug_lines)), + ) + return debug_lines + + def autograd_cache_key( - gm: torch.fx.GraphModule, + mod: torch.fx.GraphModule | torch._dynamo.utils.GmWrapper, example_inputs: Sequence[Any], config: AOTConfig, - fx_config: _CompileFxKwargs, + compiler_config_extra: CompilerConfigExtra | None = None, # TODO: add args and parameters ) -> tuple[str, list[str]]: """ Generate a unique hash of the FX graph for caching. """ - try: - check_cacheable(gm) - if has_triton_package(): - # Due to https://github.com/triton-lang/triton/issues/3729, - # if triton is < 3.2.0, AOTAutogradCache may cause us to - # attempt to load a cache entry without initializing - # the CUDA context on the autograd thread. - - # Without caching, we naturally do this initialization when - # tracing through the graph with the autograd engine. - import triton - - if triton.__version__ < "3.2.0": - raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0") - details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config) - pickler = AOTAutogradCachePickler(gm) - # The prefix distinguishes among the other kinds of objects we cache - key = "a" + pickler.get_hash(details) - debug_lines = pickler.debug_lines(details) - log.debug( - "Autograd graph cache hash details for key %s:\n%s", - key, - LazyString(lambda: "\n".join(debug_lines)), - ) - return key, debug_lines - except Exception: - # If enable_aot_compile is set, we're in AOT precompile mode where we always - # want to use fallback nonce keys. Unlike caching, it's fine if we can't generate - # a proper key because we are guaranteed in an AOT precompile world users are in - # complete control of distributing and loading artifacts. - if torch._functorch.config.bypass_autograd_cache_key: - log.info( - "Failed to generate AOTAutograd cache key; falling back to nonce due to enable_aot_compile", - exc_info=True, + gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod + with sanitize_gm_for_cache(gm): + try: + check_cacheable(gm) + _check_triton_cache_version() + details = AOTAutogradCacheDetails( + gm, example_inputs, config, create_fx_config(compiler_config_extra) ) - return str(random.random()), [] - else: - raise + pickler = AOTAutogradCachePickler(gm) + # The prefix distinguishes among the other kinds of objects we cache + key = AOTAUTOGRAD_CACHE_PREFIX + pickler.get_hash(details) + debug_lines = _get_debug_lines_for_cache_key(pickler, details, key) + return key, debug_lines + except Exception: + # If enable_aot_compile is set, we're in AOT precompile mode where we always + # want to use fallback nonce keys. Unlike caching, it's fine if we can't generate + # a proper key because we are guaranteed in an AOT precompile world users are in + # complete control of distributing and loading artifacts. + if torch._functorch.config.bypass_autograd_cache_key: + log.info( + "Failed to generate AOTAutograd cache key; falling back to nonce due to enable_aot_compile", + exc_info=True, + ) + return str(random.random()), [] + else: + raise @contextlib.contextmanager @@ -667,6 +845,7 @@ def sanitize_gm_for_cache( """ # Mapping from each field to a default value IGNORED_FIELDS: dict[str, Any] = { + # pyrefly: ignore [implicit-any] "meta": {}, # metadata used by export "compile_subgraph_reason": None, # Used by dynamo only for logging, no change in inductor/autograd behavior "_param_name_to_source": None, # Encapsulated by aot_config.aot_autograd_arg_pos_to_source @@ -736,170 +915,154 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradResult[Any, Any]]): @staticmethod def clear() -> None: """Clear the cache""" - try: - shutil.rmtree(AOTAutogradCache._get_tmp_dir()) - except FileNotFoundError: - pass + shutil.rmtree(AOTAutogradCache._get_tmp_dir(), ignore_errors=True) @staticmethod def try_load( mod: torch.fx.GraphModule | torch._dynamo.utils.GmWrapper, args: list[Any], aot_config: AOTConfig, - cudagraphs: BoxedBool, - boxed_forward_device_index: BoxedDeviceIndex | None, + compiler_config_extra: CompilerConfigExtra | None, local: bool, remote: bool, + compile_region_name: str | None = None, ) -> Callable[..., Any] | None: """ Load a result from the cache, and reconstruct a runtime wrapper around the object """ - gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod - with sanitize_gm_for_cache(gm): - compiled_fn = None - cache_info: dict[str, Any] = {} - cache_key = None - debug_lines: list[str] = [] - cache_event_time = time.time_ns() - cache_state = None - fx_config: _CompileFxKwargs = { - "cudagraphs": cudagraphs, - "boxed_forward_device_index": boxed_forward_device_index, - } - try: - cache_key, debug_lines = autograd_cache_key( - gm, args, aot_config, fx_config + compiled_fn = None + cache_info: dict[str, Any] = {} + cache_key = None + debug_lines: list[str] = [] + cache_event_time = time.time_ns() + cache_state = None + try: + cache_key, debug_lines = autograd_cache_key( + mod, args, aot_config, compiler_config_extra + ) + result: tuple[GenericAOTAutogradResult[Any, Any], bytes] | None = ( + AOTAutogradCache._lookup( + cache_key, local, remote, args, cache_info, aot_config ) - result: tuple[GenericAOTAutogradResult[Any, Any], bytes] | None = ( - AOTAutogradCache._lookup( - cache_key, local, remote, args, cache_info, aot_config - ) + ) + if result is not None: + (entry, pickled_content) = result + fx_config = create_fx_config(compiler_config_extra, compile_region_name) + compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config) + # Make the compiled_fn serializable, where the serialize function just + # makes a copy of the original entry before post compile via the pickled content + compiled_fn = SerializableCompiledFunction( + compiled_fn, lambda: pickle.loads(pickled_content) ) - if result is not None: - (entry, pickled_content) = result - compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config) - # Make the compiled_fn serializable, where the serialize function just - # makes a copy of the original entry before post compile via the pickled content - compiled_fn = SerializableCompiledFunction( - compiled_fn, lambda: pickle.loads(pickled_content) - ) - log.info("AOTAutograd cache hit for key %s", cache_key) - - counters["aot_autograd"]["autograd_cache_hit"] += 1 - cache_state = "hit" - cache_event_time = time.time_ns() - forward_time_saved = entry.forward_time_taken_ns // 1e6 - backward_time_saved = entry.backward_time_taken_ns // 1e6 - cache_info.update( - { - "forward_time_saved_ms": forward_time_saved, - "backward_time_saved_ms": backward_time_saved, - "time_saved_ms": forward_time_saved + backward_time_saved, - } - ) - time_saved_ns = ( - entry.forward_time_taken_ns + entry.backward_time_taken_ns - ) - # TODO: should we use the same field for remote cache time saved for both - # FXGraphCache and AOTAutogradCache? - # get_metrics_context().increment(...) - if ( - ephemeral_increase - := add_ephemeral_timeout_increase_for_distributed(time_saved_ns) - ) != 0: - cache_info["ephemeral_timeout_increase"] = ephemeral_increase - - if compiled_fn is None: - log.info("AOTAutograd cache miss for key %s", cache_key) - counters["aot_autograd"]["autograd_cache_miss"] += 1 - cache_state = "miss" - cache_event_time = time.time_ns() - # Count missing the FXGraphCache as a miss not a bypass - except FXGraphCacheMiss as e: - counters["aot_autograd"]["autograd_cache_miss"] += 1 - cache_state = "miss" - if ( - config.strict_autograd_cache - or torch._dynamo.config.strict_precompile - ): - raise e - # Most often this is BypassAOTAutogradCache, but - # if there's ever different reason we can't cache, - # we still never want to hard throw an exception, since - # we can always fallback to a cache bypass. - # As an example, if the user calls autograd via - # standalone inductor, we will sometimes get a GraphModule - # that doesn't actually have a `.graph` on it. Instead - # of checking every single case, we safely catch the exception - # in those cases. - except Exception as e: - cache_key = None - counters["aot_autograd"]["autograd_cache_bypass"] += 1 - log.info("Bypassing autograd cache due to: %s", e) # noqa: G200 - cache_state = "bypass" + log.info("AOTAutograd cache hit for key %s", cache_key) + + counters["aot_autograd"]["autograd_cache_hit"] += 1 + cache_state = "hit" cache_event_time = time.time_ns() - cache_info["cache_bypass_reason"] = str(e) - cache_info["cache_bypass_exception_type"] = type(e).__name__ - cache_info["cache_bypass_traceback"] = traceback.format_exc().split( - "\n" + forward_time_saved = entry.forward_time_taken_ns // 1e6 + backward_time_saved = entry.backward_time_taken_ns // 1e6 + cache_info.update( + { + "forward_time_saved_ms": forward_time_saved, + "backward_time_saved_ms": backward_time_saved, + "time_saved_ms": forward_time_saved + backward_time_saved, + } ) - # TODO: this gets logged implicitly by cache_bypass_reason, - # and here we explicitly log it into tlparse. - # We may want to log this as an extra column in Scuba, though. - cache_info["cache_bypass_hard_exception"] = not isinstance( - e, BypassAOTAutogradCache + time_saved_ns = ( + entry.forward_time_taken_ns + entry.backward_time_taken_ns ) - if remote: - log_cache_bypass("bypass_aot_autograd", str(e)) + # TODO: should we use the same field for remote cache time saved for both + # FXGraphCache and AOTAutogradCache? + # get_metrics_context().increment(...) if ( - config.strict_autograd_cache - or torch._dynamo.config.strict_precompile - ): - raise e - if compiled_fn is None: - # Set the cache key so we can save a cache result later - symints = AOTAutogradCache._filter_backed_symints(args) - if cache_key is not None: - aot_config.cache_info = AOTAutogradCacheInfo( - cache_key, - time.time_ns(), - forward_symints=symints, - ) + ephemeral_increase + := add_ephemeral_timeout_increase_for_distributed(time_saved_ns) + ) != 0: + cache_info["ephemeral_timeout_increase"] = ephemeral_increase - cache_info.update( - { - "key": cache_key, - "cache_state": cache_state, - "components": debug_lines, - } + if compiled_fn is None: + log.info("AOTAutograd cache miss for key %s", cache_key) + counters["aot_autograd"]["autograd_cache_miss"] += 1 + cache_state = "miss" + cache_event_time = time.time_ns() + # Count missing the FXGraphCache as a miss not a bypass + except FXGraphCacheMiss as e: + counters["aot_autograd"]["autograd_cache_miss"] += 1 + cache_state = "miss" + if config.strict_autograd_cache or torch._dynamo.config.strict_precompile: + raise e + # Most often this is BypassAOTAutogradCache, but + # if there's ever different reason we can't cache, + # we still never want to hard throw an exception, since + # we can always fallback to a cache bypass. + # As an example, if the user calls autograd via + # standalone inductor, we will sometimes get a GraphModule + # that doesn't actually have a `.graph` on it. Instead + # of checking every single case, we safely catch the exception + # in those cases. + except Exception as e: + cache_key = None + counters["aot_autograd"]["autograd_cache_bypass"] += 1 + log.info("Bypassing autograd cache due to: %s", e) + cache_state = "bypass" + cache_event_time = time.time_ns() + cache_info["cache_bypass_reason"] = str(e) + cache_info["cache_bypass_exception_type"] = type(e).__name__ + cache_info["cache_bypass_traceback"] = traceback.format_exc().split("\n") + # TODO: this gets logged implicitly by cache_bypass_reason, + # and here we explicitly log it into tlparse. + # We may want to log this as an extra column in Scuba, though. + cache_info["cache_bypass_hard_exception"] = not isinstance( + e, BypassAOTAutogradCache ) - if chromium_event_log_active(): - CompileEventLogger.instant( - f"autograd_cache_{cache_state}", - metadata=cache_info, - time_ns=cache_event_time, - ) - CompileEventLogger.try_add_pt2_compile( - "backend_compile", - cache_state=cache_state, - cache_event_time=cache_event_time, - key=cache_info.get("key"), - components=cache_info.get("components"), - cache_bypass_reason=cache_info.get("cache_bypass_reason"), - remote_cache_enabled=remote, - local_cache_enabled=local, + if remote: + log_cache_bypass("bypass_aot_autograd", str(e)) + if config.strict_autograd_cache or torch._dynamo.config.strict_precompile: + raise e + if compiled_fn is None: + # Set the cache key so we can save a cache result later + symints = AOTAutogradCache._filter_backed_symints(args) + if cache_key is not None: + aot_config.cache_info = AOTAutogradCacheInfo( + cache_key, + time.time_ns(), + forward_symints=symints, ) - torch._logging.trace_structured( - "artifact", - metadata_fn=lambda: { - "name": f"aotautograd_cache_{cache_state}", - "encoding": "json", - }, - payload_fn=lambda: json.dumps(cache_info), + cache_info.update( + { + "key": cache_key, + "cache_state": cache_state, + "components": debug_lines, + } + ) + if chromium_event_log_active(): + CompileEventLogger.instant( + f"autograd_cache_{cache_state}", + metadata=cache_info, + time_ns=cache_event_time, ) + CompileEventLogger.try_add_pt2_compile( + "backend_compile", + cache_state=cache_state, + cache_event_time=cache_event_time, + key=cache_info.get("key"), + components=cache_info.get("components"), + cache_bypass_reason=cache_info.get("cache_bypass_reason"), + remote_cache_enabled=remote, + local_cache_enabled=local, + ) + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": f"aotautograd_cache_{cache_state}", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(cache_info), + ) - return compiled_fn + return compiled_fn @classmethod def generate_guards_expression( @@ -986,7 +1149,7 @@ def _lookup( remote_cache = AOTAutogradCache.get_remote_cache() symints = AOTAutogradCache._filter_backed_symints(args) - hints = [size_hint(s) for s in symints] + hints = [guarding_hint_or_throw(s) for s in symints] entry = None pickled_content = None try: @@ -1019,7 +1182,7 @@ def _lookup( ), ) except Exception as e: - log.info("AOTAutograd cache unable to load compiled graph: %s", e) # noqa: G200 + log.info("AOTAutograd cache unable to load compiled graph: %s", e) if config.strict_autograd_cache: raise e if entry is not None: @@ -1074,7 +1237,7 @@ def _pickle_entry( except (pickle.PicklingError, TypeError, AttributeError) as e: bad_field = AOTAutogradCache._find_unpicklable_field(entry) error_str = str(e) - log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e) # noqa: G200 + log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e) torch._logging.trace_structured( "artifact", metadata_fn=lambda: { @@ -1096,10 +1259,10 @@ def _handle_save_error(e: Exception, remote: bool, is_bypass: bool) -> None: """Handle exceptions during save, re-raising if strict mode is enabled.""" if is_bypass: counters["aot_autograd"]["autograd_cache_bypass"] += 1 - log.info("Bypassing autograd cache due to: %s", e) # noqa: G200 + log.info("Bypassing autograd cache due to: %s", e) bypass_reason = str(e) else: - log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e) # noqa: G200 + log.warning("AOTAutograd cache unable to serialize compiled graph: %s", e) bypass_reason = "Unable to serialize: " + str(e) if remote: log_cache_bypass("bypass_aot_autograd", bypass_reason) diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index f6edf2f4036a4..d4d56919424eb 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -168,8 +168,6 @@ def run_functionalized_fw_and_collect_metadata( *, flat_args_descs: list[AOTInput], keep_input_mutations: bool, - # TODO: refactor to kill this flag - is_train: bool = False, # Note: this is guaranteed to be set when running under dynamo static_input_indices: list[int] | None = None, pre_dispatch: bool = False, @@ -835,25 +833,15 @@ def _is_subclass_mutated_input_tangent_always_subclass(inp: object) -> bool: for inp, info in zip(flat_f_args, input_info) if info.mutation_type == MutationType.MUTATED_OUT_GRAPH ] - f_metadata_mutated_inputs = [ - inp for inp, info in zip(flat_f_args, input_info) if info.mutates_metadata - ] - # This logic (annoyingly) re-figures out exactly what the outputs to the compiled fw graph will be. - # When handling subclasses, we need info about **all** outputs of compiled forward graph, - # so we know precisely which graph outputs to wrap back into tensor subclasses - # Ideally we would refactor this so not have an is_train flag, and have the separate - # inference and training paths decide which inputs/output to ask for subclass info on. - # However, we currently stash indexing information on each SubclassMeta about its order - # in the graph outputs list. - f_fw_graph_outs = list(flat_f_outs) - if is_train or not keep_input_mutations: - f_fw_graph_outs = f_mutated_inputs + f_fw_graph_outs - else: - # even when "keep_input_mutations" is True, - # we never keep metadata-only mutations in the fw graph - f_fw_graph_outs = f_metadata_mutated_inputs + f_fw_graph_outs - if is_train: - f_fw_graph_outs = f_fw_graph_outs + intermediate_bases + # Build the full list of forward graph outputs so the subclass wrapping + # code knows exactly which graph outputs to wrap back into subclasses. + # Including intermediate_bases unconditionally is safe: they are only + # populated when outputs require grad (line ~539), so they are naturally + # empty during pure inference. In the "downgrade from training to + # inference" path, num_intermediate_bases > 0 is already gated behind + # `assert not req_subclass_dispatch` (aot_autograd.py), so the subclass + # wrapping code that consumes subclass_fw_graph_out_meta never sees them. + f_fw_graph_outs = [*f_mutated_inputs, *flat_f_outs, *intermediate_bases] fw_graph_outs = pytree.tree_map(from_fun, f_fw_graph_outs) grad_enabled_mutation = None @@ -870,6 +858,12 @@ def _is_subclass_mutated_input_tangent_always_subclass(inp: object) -> bool: grad_enabled_mutation, ) + subclass_inp_meta = create_subclass_meta(flat_args) + subclass_fw_graph_out_meta = create_subclass_meta(fw_graph_outs) + subclass_tangent_meta = create_subclass_meta( + traced_tangents, count_symints=False, with_memory_format=True + ) + metadata = ViewAndMutationMeta( input_info=input_info, output_info=output_info, @@ -877,14 +871,9 @@ def _is_subclass_mutated_input_tangent_always_subclass(inp: object) -> bool: keep_input_mutations=keep_input_mutations, traced_tangents=traced_tangents, traced_tangents_descs=traced_tangents_descs, - subclass_inp_meta=create_subclass_meta(flat_args), - subclass_fw_graph_out_meta=create_subclass_meta(fw_graph_outs), - subclass_tangent_meta=create_subclass_meta( - traced_tangents, - count_symints=False, - with_memory_format=True, - ), - is_train=is_train, + subclass_inp_meta=subclass_inp_meta, + subclass_fw_graph_out_meta=subclass_fw_graph_out_meta, + subclass_tangent_meta=subclass_tangent_meta, grad_enabled_mutation=grad_enabled_mutation, static_input_indices=static_input_indices, tokens=mode._tokens, diff --git a/torch/_functorch/_aot_autograd/frontend_utils.py b/torch/_functorch/_aot_autograd/frontend_utils.py index ee7624f886c8a..5691a2641a301 100644 --- a/torch/_functorch/_aot_autograd/frontend_utils.py +++ b/torch/_functorch/_aot_autograd/frontend_utils.py @@ -34,7 +34,48 @@ def process_inputs( fake_mode: FakeTensorMode, shape_env: ShapeEnv | None, ignore_shape_env: bool = False, -) -> FakifiedFlatArgs: +) -> tuple[FakifiedFlatArgs, list[int]]: + """Convert real tensor inputs into fake tensors for AOT autograd tracing. + + Called at compile time (not runtime) to produce the fake inputs that AOT + autograd traces through. Each real tensor is converted to a FakeTensor + via ``fake_mode.from_tensor``, preserving shape, dtype, device, and + symbolic shape information from the ShapeEnv. Non-tensor inputs (ints, + SymInts, ScriptObjects) are converted or passed through as appropriate. + + Tensor subclass inputs (DTensor, etc.) are fakified recursively by + walking their ``__tensor_flatten__`` attrs. AsyncCollectiveTensors are + resolved via ``trigger_wait()`` before fakification so they don't appear + in the traced metadata (see below). + + Called from ``aot_function``, ``aot_module_simplified``, and + ``aot_export_module`` — anywhere AOT autograd needs fake inputs before + graph capture. + + Returns: + A tuple of (fakified_args, act_input_indices) where act_input_indices + records which positions held AsyncCollectiveTensors. These indices are + stored on ViewAndMutationMeta so that the runtime wrapper can emit + direct trigger_wait() calls on those positions. + """ + # Resolve AsyncCollectiveTensors before tracing. ACTs are transient + # eager-mode wrappers for async collective overlap; if they leak into the + # traced graph as input types, AOT autograd records them in + # SubclassCreationMeta for output tangent metadata. At runtime, autograd + # produces plain tensor tangents, causing a type mismatch. Unwrapping + # here prevents ACT from appearing in the traced metadata. + try: + from torch.distributed._functional_collectives import AsyncCollectiveTensor + except ImportError: + AsyncCollectiveTensor = None + + act_input_indices: list[int] = [] + if AsyncCollectiveTensor is not None: + for i, a in enumerate(flat_args): + if isinstance(a, AsyncCollectiveTensor): + act_input_indices.append(i) + flat_args[i] = a.trigger_wait() + with fake_mode: def convert(idx: int, x: Any) -> Any: @@ -120,7 +161,9 @@ def convert(idx: int, x: Any) -> Any: ) return result - return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)]) + return FakifiedFlatArgs( + [convert(idx, x) for idx, x in enumerate(flat_args)] + ), act_input_indices def construct_fake_mode( diff --git a/torch/_functorch/_aot_autograd/graph_capture.py b/torch/_functorch/_aot_autograd/graph_capture.py index 3100ba9bf91f4..f1752bd4d13e7 100644 --- a/torch/_functorch/_aot_autograd/graph_capture.py +++ b/torch/_functorch/_aot_autograd/graph_capture.py @@ -191,6 +191,96 @@ def _detach_and_copy_item_memo(t: torch.Tensor) -> torch.Tensor: return detached_t +@dataclasses.dataclass +class _GraphCaptureTracingResult: + fn_to_trace: Callable[..., Any] + flat_args: Any + flat_args_descs: Any + maybe_subclass_meta: SubclassMeta | None + + +def _detach_traced_inputs(flat_args: Any) -> Any: + if detect_fake_mode(): + detach_tensor = _detach_and_copy_item_memo + else: + + def detach_tensor(t: torch.Tensor) -> torch.Tensor: + return t.detach() + + return pytree.tree_map_only(torch.Tensor, detach_tensor, flat_args) + + +def _prepare_graph_capture_tracing( + fn_to_trace: Callable[..., Any], + flat_args: Any, + flat_args_descs: Any, + flat_fn: TraceFn, + *, + fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, + trace_joint: bool, + joint_fn_handle: Any | None = None, +) -> _GraphCaptureTracingResult: + if aot_config.disable_functionalization: + updated_flat_args, updated_flat_args_descs = flat_args, flat_args_descs + else: + fn_to_trace, updated_flat_args, updated_flat_args_descs = ( + create_functionalized_fn( + fn_to_trace, + flat_args, + flat_args_descs, + meta=fw_metadata, + aot_config=aot_config, + trace_joint=trace_joint, + joint_fn_handle=joint_fn_handle, + ) + ) + + subclass_tracing_info = aot_dispatch_subclass( + fn_to_trace, + updated_flat_args, + updated_flat_args_descs, + is_joint_structure=trace_joint, + meta=fw_metadata, + fw_only=flat_fn, + ) + fn_to_trace = subclass_tracing_info.plain_tensor_trace_fn + updated_flat_args = subclass_tracing_info.plain_tensor_args + updated_flat_args_descs = subclass_tracing_info.plain_tensor_args_descs + + if not aot_config.disable_functionalization: + fn_to_trace, updated_flat_args, updated_flat_args_descs = ( + handle_effect_tokens_fn( + fn_to_trace, + updated_flat_args, + updated_flat_args_descs, + meta=fw_metadata, + trace_joint=trace_joint, + ) + ) + + return _GraphCaptureTracingResult( + fn_to_trace=fn_to_trace, + flat_args=updated_flat_args, + flat_args_descs=updated_flat_args_descs, + maybe_subclass_meta=subclass_tracing_info.maybe_subclass_meta, + ) + + +def _create_graph_and_save_traced_inputs( + fn_to_trace: Callable[..., Any], + flat_args: Any, + flat_args_descs: Any, + *, + aot_config: AOTConfig, +) -> tuple[torch.fx.GraphModule, Any]: + saved_flat_args = _detach_traced_inputs(flat_args) + return ( + _create_graph(fn_to_trace, flat_args, flat_args_descs, aot_config=aot_config), + saved_flat_args, + ) + + def aot_dispatch_base_graph( flat_fn: TraceFn, flat_args: list[FxValue], @@ -212,59 +302,28 @@ def aot_dispatch_base_graph( fw_metadata, keep_data_input_mutations=aot_config.keep_inference_input_mutations, ) - - if aot_config.disable_functionalization: - updated_flat_args, updated_flat_args_descs = ( - flat_args, - flat_args_descs, - ) - else: - fn_to_trace, updated_flat_args, updated_flat_args_descs = ( - create_functionalized_fn( - fn_to_trace, - flat_args, - flat_args_descs, - meta=fw_metadata, - aot_config=aot_config, - trace_joint=False, - ) - ) - # TODO: replace with AOTDispatchSubclassWrapper once we refactor # fn_input_mutations_to_outputs and create_functionalized_fn # into CompilerWrappers. - ( + tracing_state = _prepare_graph_capture_tracing( fn_to_trace, - updated_flat_args_subclasses_desugared, - updated_flat_args_subclasses_desugared_descs, - maybe_subclass_meta, - ) = aot_dispatch_subclass( - fn_to_trace, - updated_flat_args, - updated_flat_args_descs, - is_joint_structure=False, - meta=fw_metadata, - fw_only=flat_fn, + flat_args, + flat_args_descs, + flat_fn, + fw_metadata=fw_metadata, + aot_config=aot_config, + trace_joint=False, ) - - if not aot_config.disable_functionalization: - ( - fn_to_trace, - updated_flat_args_subclasses_desugared, - updated_flat_args_subclasses_desugared_descs, - ) = handle_effect_tokens_fn( - fn_to_trace, - updated_flat_args_subclasses_desugared, - updated_flat_args_subclasses_desugared_descs, - meta=fw_metadata, - trace_joint=False, - ) + fn_to_trace = tracing_state.fn_to_trace + updated_flat_args_subclasses_desugared = tracing_state.flat_args + updated_flat_args_subclasses_desugared_descs = tracing_state.flat_args_descs + maybe_subclass_meta = tracing_state.maybe_subclass_meta aot_graphs_log.debug( "aot_config id: %s, fw_metadata=%s,subclass_metadata=%s", - str(aot_config.aot_id), - str(fw_metadata), - str(maybe_subclass_meta), + aot_config.aot_id, + fw_metadata, + maybe_subclass_meta, ) # We track buffer assignments when exporting in non-strict mode. @@ -278,27 +337,18 @@ def aot_dispatch_base_graph( mod_when_exporting_non_strict, assigned_buffers ) - fake_mode = detect_fake_mode() - if fake_mode: - saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only( - torch.Tensor, - _detach_and_copy_item_memo, - updated_flat_args_subclasses_desugared, - ) - else: - saved_updated_flat_args_subclasses_desugared = pytree.tree_map_only( - torch.Tensor, lambda t: t.detach(), updated_flat_args_subclasses_desugared - ) - saved_updated_flat_args_subclasses_desugared_descs = ( - updated_flat_args_subclasses_desugared_descs - ) - - fw_module = _create_graph( + ( + fw_module, + saved_updated_flat_args_subclasses_desugared, + ) = _create_graph_and_save_traced_inputs( fn_to_trace, updated_flat_args_subclasses_desugared, updated_flat_args_subclasses_desugared_descs, aot_config=aot_config, ) + saved_updated_flat_args_subclasses_desugared_descs = ( + updated_flat_args_subclasses_desugared_descs + ) if aot_config.is_export and mod_when_exporting_non_strict is not None: # We update metadata to consider any assigned buffers as buffer mutations. @@ -328,6 +378,8 @@ def aot_dispatch_base_graph( if not aot_config.disable_functionalization: copy_count = assert_functional_graph(fw_module.graph) assign_epilogue_copy_streams(fw_module) + # Wrap sync nodes with control_deps to prevent reordering + wrap_all_sync_nodes_with_control_deps(fw_module) # Populate fw_metadata with stream indices from the compiled graph populate_fw_metadata_with_stream_indices(fw_module, fw_metadata) fw_module.graph.eliminate_dead_code() @@ -397,11 +449,10 @@ def aot_dispatch_base_graph( ) # TODO: should factor this into a separate function for export that always only returns just the graph. - if aot_config.is_export: - if maybe_subclass_meta is not None: - raise AssertionError( - "aot_export_module does not support tensor subclass inputs for now." - ) + if aot_config.is_export and maybe_subclass_meta is not None: + raise AssertionError( + "aot_export_module does not support tensor subclass inputs for now." + ) return ( fw_module, saved_updated_flat_args_subclasses_desugared, @@ -447,51 +498,23 @@ def aot_dispatch_autograd_graph( ) # pyrefly: ignore[missing-attribute] joint_fn_handle = joint_fn_to_trace.handle - - if aot_config.disable_functionalization: - updated_joint_inputs, updated_joint_inputs_descs = ( - joint_inputs, - joint_inputs_descs, - ) - else: - joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs = ( - create_functionalized_fn( - joint_fn_to_trace, - joint_inputs, - joint_inputs_descs, - meta=fw_metadata, - aot_config=aot_config, - trace_joint=True, - joint_fn_handle=joint_fn_handle, - ) - ) - # TODO: replace with AOTDispatchSubclassWrapper once we refactor # fn_input_mutations_to_outputs and create_functionalized_fn # into CompilerWrappers. - subclass_tracing_info = aot_dispatch_subclass( + tracing_state = _prepare_graph_capture_tracing( joint_fn_to_trace, - updated_joint_inputs, - updated_joint_inputs_descs, - is_joint_structure=True, - meta=fw_metadata, - fw_only=flat_fn, + joint_inputs, + joint_inputs_descs, + flat_fn, + fw_metadata=fw_metadata, + aot_config=aot_config, + trace_joint=True, + joint_fn_handle=joint_fn_handle, ) - - joint_fn_to_trace = subclass_tracing_info.plain_tensor_trace_fn - updated_joint_inputs = subclass_tracing_info.plain_tensor_args - updated_joint_inputs_descs = subclass_tracing_info.plain_tensor_args_descs - - if not aot_config.disable_functionalization: - (joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs) = ( - handle_effect_tokens_fn( - joint_fn_to_trace, - updated_joint_inputs, - updated_joint_inputs_descs, - meta=fw_metadata, - trace_joint=True, - ) - ) + joint_fn_to_trace = tracing_state.fn_to_trace + updated_joint_inputs = tracing_state.flat_args + updated_joint_inputs_descs = tracing_state.flat_args_descs + maybe_subclass_meta = tracing_state.maybe_subclass_meta # When we call _create_graph, this may mutate the metadata of joint # inputs. But callers are expecting to get the original joint inputs. So @@ -501,19 +524,7 @@ def aot_dispatch_autograd_graph( # This destroys requires_grad/grad_fn information. However, backends # beneath AOTAutograd are indifferent to this information, so it doesn't # matter. - - fake_mode = detect_fake_mode() - if fake_mode: - saved_updated_joint_inputs = pytree.tree_map_only( - torch.Tensor, _detach_and_copy_item_memo, updated_joint_inputs - ) - else: - saved_updated_joint_inputs = pytree.tree_map_only( - torch.Tensor, lambda t: t.detach(), updated_joint_inputs - ) - maybe_subclass_meta = subclass_tracing_info.maybe_subclass_meta - - fx_g = _create_graph( + fx_g, saved_updated_joint_inputs = _create_graph_and_save_traced_inputs( joint_fn_to_trace, updated_joint_inputs, updated_joint_inputs_descs, @@ -562,11 +573,10 @@ def aot_dispatch_autograd_graph( # TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect # when we need to manually detach() some inputs in the forward. # Higher order ops might eventually need to do the same. - if aot_config.is_export: - if maybe_subclass_meta is not None: - raise AssertionError( - "aot_export_module does not support tensor subclass inputs for now." - ) + if aot_config.is_export and maybe_subclass_meta is not None: + raise AssertionError( + "aot_export_module does not support tensor subclass inputs for now." + ) return ( fx_g, saved_updated_joint_inputs, diff --git a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py index 71a92a7f0d647..bb26cca92171c 100644 --- a/torch/_functorch/_aot_autograd/graph_capture_wrappers.py +++ b/torch/_functorch/_aot_autograd/graph_capture_wrappers.py @@ -10,6 +10,7 @@ 4. dispatching subclasses """ +import typing import warnings from collections.abc import Callable, Generator from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext @@ -751,7 +752,7 @@ def apply_in_graph_mutations( if input_info.mutates_storage_metadata: if mcs is None or mcs.mc_storage > applied_mcs.mc_storage: # type: ignore[union-attr] with torch.no_grad(): - # pyrefly: ignore[no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] inpt_old.set_(inpt_new) # Note [Ordering of resize_() and set_()] @@ -1089,6 +1090,7 @@ def _post_forward(primals: Any) -> None: ): apply_in_graph_mutations( inpt_info, + # pyrefly: ignore [bad-argument-type] before, after, f_inpt, @@ -1356,11 +1358,15 @@ def inner_fn(fn: Callable[..., Any], args: Any, *, use_trace_joint: bool) -> Any # Add extra symints as outputs to the forward/backward graphs # ignore nested ints here forward_outs, forward_outs_descs = unwrap_tensor_subclasses( - wrapped_outs[0], wrapped_outs_descs[0], append_symints=True + wrapped_outs[0], + wrapped_outs_descs[0], + append_symints=True, ) # ignore nested ints here backward_outs, backward_outs_descs = unwrap_tensor_subclasses( - wrapped_outs[1], wrapped_outs_descs[1], append_symints=True + wrapped_outs[1], + wrapped_outs_descs[1], + append_symints=True, ) return ( (forward_outs, backward_outs), @@ -1394,42 +1400,53 @@ def inner_fw_only(*args: Any) -> Any: return inner_fn(inner_fw_only, primals, use_trace_joint=False) if is_joint_structure: + primals_wrapped: list[FxValue] = typing.cast(list[FxValue], args[0]) + primals_wrapped_descs: list[AOTInput] = typing.cast( + list[AOTInput], args_descs[0] + ) + tangents_wrapped: list[FxValue] = typing.cast(list[FxValue], args[1]) + tangents_wrapped_descs: list[AOTInput] = typing.cast( + list[AOTInput], args_descs[1] + ) + # Add extra symints (size/strides) as input to the forward graph primals_unwrapped_pair = unwrap_tensor_subclasses( - args[0], # type: ignore[arg-type] - args_descs[0], # type: ignore[arg-type] + primals_wrapped, + primals_wrapped_descs, append_symints=True, ) # We pass append_symints=False here because the partitioner will # capture and add any extra argument. tangents_unwrapped_pair = unwrap_tensor_subclasses( - args[1], # type: ignore[arg-type] - args_descs[1], # type: ignore[arg-type] + tangents_wrapped, + tangents_wrapped_descs, append_symints=False, ) args_unwrapped = (primals_unwrapped_pair[0], tangents_unwrapped_pair[0]) args_descs_unwrapped = (primals_unwrapped_pair[1], tangents_unwrapped_pair[1]) remapped_static_indices = remap_unwrapped_subclass_arg_indices( - args[0], # type: ignore[arg-type] + primals_wrapped, meta.static_input_indices, # type: ignore[arg-type] ) + + primals_unwrapped = args_unwrapped[0] # type: ignore[assignment] + primals_unwrapped_descs = args_descs_unwrapped[0] # type: ignore[assignment] + fn_to_trace = joint_fn # type: ignore[assignment] else: + primals_wrapped: list[FxValue] = typing.cast(list[FxValue], args) + primals_wrapped_descs: list[AOTInput] = typing.cast(list[AOTInput], args_descs) + args_unwrapped, args_descs_unwrapped = unwrap_tensor_subclasses( # type: ignore[assignment] - args, # type: ignore[arg-type] - args_descs, # type: ignore[arg-type] + primals_wrapped, + primals_wrapped_descs, append_symints=True, ) remapped_static_indices = remap_unwrapped_subclass_arg_indices( - args, # type: ignore[arg-type] + primals_wrapped, meta.static_input_indices, # type: ignore[arg-type] ) - if is_joint_structure: - primals_unwrapped = args_unwrapped[0] # type: ignore[assignment] - primals_unwrapped_descs = args_descs_unwrapped[0] # type: ignore[assignment] - fn_to_trace = joint_fn # type: ignore[assignment] - else: primals_unwrapped = args_unwrapped # type: ignore[assignment] primals_unwrapped_descs = args_descs_unwrapped # type: ignore[assignment] fn_to_trace = fw_fn # type: ignore[assignment] @@ -1457,7 +1474,6 @@ def inner_fw_only(*args: Any) -> Any: flat_args_descs=primals_unwrapped_descs, static_input_indices=remapped_static_indices, keep_input_mutations=meta.keep_input_mutations, - is_train=meta.is_train, # pyrefly: ignore [not-iterable] )(*primals_unwrapped) diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py index 51d2600af6777..c1e57730ecb81 100644 --- a/torch/_functorch/_aot_autograd/graph_compile.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -13,22 +13,13 @@ import itertools import logging import operator +import threading import time import traceback from collections import defaultdict from collections.abc import Callable, Generator -from contextlib import nullcontext -from typing import Any, TYPE_CHECKING - -from torch._library.fake_class_registry import FakeScriptObject -from torch._opaque_base import OpaqueBase - - -if TYPE_CHECKING: - from collections.abc import Sequence - -import threading -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext +from typing import Any import torch import torch.utils._pytree as pytree @@ -41,7 +32,10 @@ lazy_format_graph_code, ) from torch._guards import CompileContext, TracingContext +from torch._library.fake_class_registry import FakeScriptObject +from torch._library.opaque_object import is_opaque_value from torch._logging import getArtifactLogger, trace_structured +from torch._opaque_base import OpaqueBase from torch._subclasses import FakeTensor from torch._subclasses.meta_utils import is_sparse_any from torch.fx.experimental._backward_state import BackwardState @@ -67,6 +61,7 @@ from .runtime_wrappers import ( AOTDedupeWrapper, AOTDispatchAutograd, + AOTDispatchAutogradCompileSpec, AOTDispatchSubclassWrapper, AOTSyntheticBaseWrapper, AutogradLazyBackwardCompileInfo, @@ -102,9 +97,31 @@ ) +def is_opaque_node(node: Any) -> bool: + """Check if a node contains an opaque or non-tensor value (e.g., ProcessGroup).""" + from torch._library.fake_class_registry import FakeScriptObject + + if not isinstance(node, torch.fx.Node): + return False + if "val" not in getattr(node, "meta", {}): + return False + val = node.meta["val"] + if is_opaque_value(val): + return True + if isinstance(val, (torch.ScriptObject, FakeScriptObject)): + return True + return False + + _thread_local = threading.local() +def _should_save_cache(*compiled_fns: Callable[..., Any]) -> bool: + if should_bundle_autograd_cache(): + return True + return all(hasattr(fn, "_fx_graph_cache_key") for fn in compiled_fns) + + @contextmanager def maybe_skip_decompose(aot_config: AOTConfig) -> Generator[None, None, None]: old_decomp = aot_config.decompositions @@ -482,14 +499,8 @@ def _cache_inference_info( cache_info = aot_config.cache_info - def should_save_cache() -> bool: - if should_bundle_autograd_cache(): - return True - else: - return hasattr(compiled_fw, "_fx_graph_cache_key") - entry: GenericAOTAutogradResult[Any, Any] | None = None - if cache_info is not None and should_save_cache(): + if cache_info is not None and _should_save_cache(compiled_fw): time_taken_ns = time.time_ns() - cache_info.start_time_ns guards_expr = AOTAutogradCache.generate_guards_expression(cache_info) entry = AOTAutogradCache.make_entry( @@ -571,7 +582,7 @@ def collect_fw_donated_buffer_idxs( t = saved_tensors[i] if ( t is not None - and isinstance(t, torch.Tensor) + and isinstance(t, FakeTensor) and not is_sparse_any(t) and StorageWeakRef(t.untyped_storage()) not in storage_refs ): @@ -1649,9 +1660,9 @@ def _log_fw_bw_graphs( ) aot_graphs_log.info( "aot_config id: %s, fw_metadata=%s, inner_meta=%s", - str(aot_config.aot_id), - str(fw_metadata), - str(_get_inner_meta(maybe_subclass_meta, fw_metadata)), + aot_config.aot_id, + fw_metadata, + _get_inner_meta(maybe_subclass_meta, fw_metadata), ) aot_graphs_log.info( @@ -1718,6 +1729,265 @@ def _log_fw_bw_graphs( return fw_module_str, bw_module_str +def _partition_joint_graph_into_fw_bw( + fx_g: torch.fx.GraphModule, + joint_inputs: list[Any] | tuple[list[Any], list[Any]], + inner_meta: ViewAndMutationMeta, + fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, +) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule, int]: + # See Note: [Partitioner handling for Subclasses, Part 1] + # See Note: [Recomputing subclass mutation handling] + mutated_inp_runtime_indices = compute_inner_mutated_inp_indices_from_subclass_meta( + fw_metadata, inner_meta + ) + num_tokens = len(fw_metadata.tokens) + num_inner_fwd_outputs = ( + len(mutated_inp_runtime_indices) + + inner_meta.num_outputs + + inner_meta.num_intermediate_bases + + inner_meta.num_outputs_rng_offset + + num_tokens # See Note [Side-Effectful Tokens in AOTAutograd] + ) + + fx_g = run_joint_graph_passes_on_hops(fx_g, joint_inputs, aot_config) + + # apply joint_gm callback here + if callable(torch._functorch.config.joint_custom_pass): + # pyrefly: ignore [bad-assignment] + fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs) + + if aot_config.partition_fn is None: + raise AssertionError("aot_config.partition_fn must not be None") + fw_module, bw_module = aot_config.partition_fn( + fx_g, + joint_inputs, + num_fwd_outputs=num_inner_fwd_outputs, + static_lifetime_input_indices=fw_metadata.static_input_indices, + ) + + rng_states = [ + n + for n in fw_module.graph.find_nodes(op="placeholder") + if "fwd_rng_state" in n.name + ] + fw_metadata.num_graphsafe_rng_states = len(rng_states) + if rng_states: + fw_metadata.graphsafe_rng_state_index = rng_states[0].meta["val"].device.index + + return fw_module, bw_module, num_inner_fwd_outputs + + +def _joint_inputs_for_forward( + joint_inputs: list[Any] | tuple[list[Any], list[Any]], +) -> list[Any]: + return joint_inputs[0] if isinstance(joint_inputs, tuple) else joint_inputs + + +def _maybe_unlift_partitioned_effect_tokens( + fw_module: torch.fx.GraphModule, + bw_module: torch.fx.GraphModule, + joint_inputs: list[Any] | tuple[list[Any], list[Any]], + fw_metadata: ViewAndMutationMeta, + aot_config: AOTConfig, + num_inner_fwd_outputs: int, +) -> tuple[int, list[Any] | tuple[list[Any], list[Any]]]: + num_tokens = len(fw_metadata.tokens) + + # See Note [Side-Effectful Tokens in AOTAutograd] + if config.unlift_effect_tokens and ( + num_tokens > 0 or fw_metadata.num_backward_tokens > 0 + ): + unlift_tokens(fw_module, fw_metadata, aot_config, bw_module) + num_inner_fwd_outputs -= num_tokens + if isinstance(joint_inputs, tuple): + joint_inputs = ( + _joint_inputs_for_forward(joint_inputs)[num_tokens:], + joint_inputs[1], + ) + else: + joint_inputs = joint_inputs[num_tokens:] + + return num_inner_fwd_outputs, joint_inputs + + +def _categorize_saved_tensors_for_backward( + fw_module: torch.fx.GraphModule, + bw_module: torch.fx.GraphModule, + inner_meta: ViewAndMutationMeta, + fw_metadata: ViewAndMutationMeta, + num_inner_fwd_outputs: int, +) -> tuple[int, int]: + fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0] + # we only need to bookkeep the symints that are saved for bw, not any symints + # the user forward might have returned in its own output + fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] + num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw) + + num_symints_saved_for_bw = 0 + num_opaque_objects_saved_for_bw = 0 + for idx, node in enumerate(fw_outs_saved_for_bw): + if is_sym_node(node): + num_symints_saved_for_bw += 1 + elif is_opaque_node(node): + num_opaque_objects_saved_for_bw += 1 + elif isinstance(node, torch.fx.Node) and "val" in getattr(node, "meta", {}): + if isinstance(node.meta["val"], FakeTensor): + # record dynamic tensor activations + dynamic_dims: set[int] = { + dim + for dim, size in enumerate(node.meta["val"].shape) + if not isinstance(size, int) + } + if dynamic_dims: + fw_metadata.dynamic_saved_tensors_idxs[idx] = dynamic_dims + elif isinstance(node.meta["val"], (FakeScriptObject, OpaqueBase)): + num_opaque_objects_saved_for_bw += 1 + + fw_metadata.num_symints_saved_for_bw = num_symints_saved_for_bw + fw_metadata.num_opaque_objects_saved_for_bw = num_opaque_objects_saved_for_bw + inner_meta.num_symints_saved_for_bw = num_symints_saved_for_bw + inner_meta.num_opaque_objects_saved_for_bw = num_opaque_objects_saved_for_bw + + # See Note [Activations with no version counter checks in eager] + # Count tensors saved with no version counter check. + # These are tensors that were stashed on ctx (e.g., ctx.x = x) rather than + # via save_for_backward in an autograd.Function. + # The partitioner sorts these to be at the end of saved_values. + num_tensors_saved_with_no_vc_check = sum( + 1 + for node in fw_outs_saved_for_bw + if isinstance(node, torch.fx.Node) + and node.meta.get("saved_tensor_with_no_vc_check", False) + ) + fw_metadata.num_tensors_saved_with_no_vc_check = num_tensors_saved_with_no_vc_check + inner_meta.num_tensors_saved_with_no_vc_check = num_tensors_saved_with_no_vc_check + + if torch._functorch.config.donated_buffer: + fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs( + fw_module, + bw_module, + inner_meta, + ) + inner_meta.bw_donated_idxs = fw_metadata.bw_donated_idxs + + return num_fw_outs_saved_for_bw, num_symints_saved_for_bw + + +# Note [Detaching inputs that never need gradients] +# See https://github.com/pytorch/pytorch/issues/97745 +# Suppose we have a function like this that we want to compile: +# +# def f(x, y): +# return torch.mul(x, y.detach()) +# +# What gradients should we compute for x and y? +# By default, AOTAutograd will compute a gradient for **every** input that requires gradients, +# and so we'll compute: +# x_grad_input = y +# y_grad_input = None +# Does this preserve the semantics of eager mode? +# Unfortunately, no. +# Doing the above will cause autograd to **continue** to backprop the autograd tape +# that was generated from constructing y. +# +# This is **different** from what would have happened in eager mode. +# In eager mode, if we backprop through the output of this function, autograd will only traverse +# the bit of the autograd tape corresponding to "x". +# In particular, if a user had previously backpropped through y's autograd tape, +# And then they try to backprop through the output of the above function, +# then we'll hit the dreaded "Trying to backward through the graph a second time" error. +# +# You might think: If autograd sees that a gradient is None, shouldn't it stop early, +# instead of continuing the backprop through the ancestors of that node in the graph? +# +# Autograd has two passes: +# (1) a first pass that traverses the autograd graph and figures out which nodes need to be executed +# (2) a second pass that actually goes ahead and executes each node when it becomes ready, +# propagating gradients +# By the time we're executing a node and we see that it produces a None, the set of nodes to execute +# is already locked-in. +# +# The fix: instead, we can recognize statically that the graph we're compiling will never contribute +# gradients to y, and prevent autograd from trying to traverse y's autograd tape at all. +# We can do this by manually detach'ing y before sending it through the `CompiledFunction`. +# +# Note that this solution is not bulletproof. +# It's possible to construct a case where eager may or may not have have tried to autograd through y, +# depending on the actual grad_outputs that were passed in during the backward. +# There is no easy fix for this: the simplest fix would be to run with `retain_graph=True`, +# allowing autograd to reuse the graph. +# +# An example of this case is: +# def f(x): +# return x.detach() * 2, x * 3 +# If we were to only backprop through outs[0], in eager, we would stop +# If we backward only on the first output, we shouldn't send a grad through x. +# But the custom autograd function doesn't know that: it will materialize zero grads for x * 3 +# and we will end up with a zero grad at x. +# If we later backprop through the second output, this will also require backprop'ing through x. +# Meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time. +def _compute_indices_of_inps_to_detach( + bw_module: torch.fx.GraphModule, + maybe_subclass_meta: SubclassMeta | None, + inner_meta: ViewAndMutationMeta, + fw_metadata: ViewAndMutationMeta, +) -> list[int]: + # TODO: we should apply the below "detach inputs if their gradients are statically known to be None" + # optimization even if we have subclass inputs/outputs (we do not handle this today). + # Computing which our our inputs get None gradients is a bit more complicated, + # if any of our inputs are subclasses. Why? + # (a) we need to make sure that we call .detach() on the input subclasses, since autograd sees subclasses. + # (b) The grad_outputs that we AOT computed in our backward graph are the desugared tensor tensors, + # so we need to figure out which subclass fw inputs they map to. + if maybe_subclass_meta is not None: + return [] + + indices_of_inps_to_detach: list[int] = [] + + # reversed() since we expect output at end of graph + bw_output = next(reversed(bw_module.graph.find_nodes(op="output"))) + bw_outs = bw_output.args[0] + + num_backward_tokens = inner_meta.num_backward_tokens + expected_bw_outs = ( + len(fw_metadata.input_info) + + inner_meta.num_outputs_rng_offset + + num_backward_tokens + ) + if len(bw_outs) != expected_bw_outs: + raise AssertionError( + f"expected len(bw_outs) == {expected_bw_outs}, got {len(bw_outs)}" + ) + + bw_outs_no_rng_no_tokens = bw_outs + if (inner_meta.num_outputs_rng_offset + num_backward_tokens) > 0: + bw_outs_no_rng_no_tokens = bw_outs[ + : -(inner_meta.num_outputs_rng_offset + num_backward_tokens) + ] + if len(bw_outs_no_rng_no_tokens) != len(fw_metadata.input_info): + raise AssertionError( + f"expected len(bw_outs_no_rng_no_tokens) == {len(fw_metadata.input_info)}, " + f"got {len(bw_outs_no_rng_no_tokens)}" + ) + + for i, bw_out in enumerate(bw_outs_no_rng_no_tokens): + # If our input experiences a metadata mutation inside the graph (e.g. set_()), + # we *must* not detach, otherwise it will be the detach'd input that gets the metadata mutation + metadata_mutation_in_graph = ( + fw_metadata.input_info[i].mutation_type == MutationType.MUTATED_IN_GRAPH + and fw_metadata.input_info[i].mutates_storage_metadata + ) + is_non_leaf = ( + fw_metadata.input_info[i].requires_grad + and not fw_metadata.input_info[i].is_leaf + ) + if bw_out is None and not metadata_mutation_in_graph and is_non_leaf: + indices_of_inps_to_detach.append(i) + + return indices_of_inps_to_detach + + def _aot_stage2a_partition( fx_g: torch.fx.GraphModule, joint_inputs: list[Any] | tuple[list[Any], list[Any]], @@ -1738,60 +2008,25 @@ def _aot_stage2a_partition( with torch.no_grad(): context = torch._C._DisableAutocast if disable_amp else nullcontext with context(), track_graph_compiling(aot_config, "joint"): - # See Note: [Partitioner handling for Subclasses, Part 1] - # See Note: [Recomputing subclass mutation handling] - mutated_inp_runtime_indices = ( - compute_inner_mutated_inp_indices_from_subclass_meta( - fw_metadata, inner_meta + fw_module, bw_module, num_inner_fwd_outputs = ( + _partition_joint_graph_into_fw_bw( + fx_g, + joint_inputs, + inner_meta, + fw_metadata, + aot_config, ) ) - num_tokens = len(fw_metadata.tokens) - num_mutated_inp_runtime_indices = len(mutated_inp_runtime_indices) - num_inner_fwd_outputs = ( - num_mutated_inp_runtime_indices - + inner_meta.num_outputs - + inner_meta.num_intermediate_bases - + inner_meta.num_outputs_rng_offset - + num_tokens # See Note [Side-Effectful Tokens in AOTAutograd] - ) - fx_g = run_joint_graph_passes_on_hops(fx_g, joint_inputs, aot_config) - - # apply joint_gm callback here - if callable(torch._functorch.config.joint_custom_pass): - # pyrefly: ignore [bad-assignment] - fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs) - - static_lifetime_input_indices = fw_metadata.static_input_indices - if aot_config.partition_fn is None: - raise AssertionError("aot_config.partition_fn must not be None") - fw_module, bw_module = aot_config.partition_fn( - fx_g, - joint_inputs, - num_fwd_outputs=num_inner_fwd_outputs, - static_lifetime_input_indices=static_lifetime_input_indices, - ) - rng_states = [ - n - for n in fw_module.graph.find_nodes(op="placeholder") - if "fwd_rng_state" in n.name - ] - fw_metadata.num_graphsafe_rng_states = len(rng_states) - if rng_states: - fw_metadata.graphsafe_rng_state_index = ( - rng_states[0].meta["val"].device.index - ) - - # See Note [Side-Effectful Tokens in AOTAutograd] - if config.unlift_effect_tokens and ( - num_tokens > 0 or fw_metadata.num_backward_tokens > 0 - ): - unlift_tokens(fw_module, fw_metadata, aot_config, bw_module) - - num_inner_fwd_outputs -= num_tokens - joint_inputs = ( - joint_inputs[0][num_tokens:], - joint_inputs[1], + num_inner_fwd_outputs, joint_inputs = ( + _maybe_unlift_partitioned_effect_tokens( + fw_module, + bw_module, + joint_inputs, + fw_metadata, + aot_config, + num_inner_fwd_outputs, ) + ) maybe_inline_graph_saved_tensors_hooks( fw_module, @@ -1801,170 +2036,22 @@ def _aot_stage2a_partition( aot_config, fw_metadata.static_input_indices, ) - static_lifetime_input_indices = fw_metadata.static_input_indices - - fw_outs = next(iter(fw_module.graph.find_nodes(op="output"))).args[0] - # we only need to bookkeep the symints that are saved for bw, not any symints - # the user forward might have returned in its own output - fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:] - num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw) - symint_outs_saved_for_bw = [] - opaque_outs_saved_for_bw = [] - for idx, node in enumerate(fw_outs_saved_for_bw): - if is_sym_node(node): - symint_outs_saved_for_bw.append(node) - elif isinstance(node, torch.fx.Node) and "val" in getattr( - node, "meta", {} - ): - if isinstance(node.meta["val"], FakeTensor): - # record dynamic tensor activations - dynamic_dims: set[int] = { - dim - for dim, size in enumerate(node.meta["val"].shape) - if not isinstance(size, int) - } - if dynamic_dims: - fw_metadata.dynamic_saved_tensors_idxs[idx] = dynamic_dims - elif isinstance(node.meta["val"], (FakeScriptObject, OpaqueBase)): - opaque_outs_saved_for_bw.append(node) - - num_symints_saved_for_bw = len(symint_outs_saved_for_bw) - num_opaque_objects_saved_for_bw = len(opaque_outs_saved_for_bw) - fw_metadata.num_symints_saved_for_bw = num_symints_saved_for_bw - fw_metadata.num_opaque_objects_saved_for_bw = ( - num_opaque_objects_saved_for_bw - ) - inner_meta.num_symints_saved_for_bw = num_symints_saved_for_bw - inner_meta.num_opaque_objects_saved_for_bw = num_opaque_objects_saved_for_bw - - # See Note [Activations with no version counter checks in eager] - # Count tensors saved with no version counter check. - # These are tensors that were stashed on ctx (e.g., ctx.x = x) rather than - # via save_for_backward in an autograd.Function. - # The partitioner sorts these to be at the end of saved_values. - num_tensors_saved_with_no_vc_check = 0 - for node in fw_outs_saved_for_bw: - if isinstance(node, torch.fx.Node) and node.meta.get( - "saved_tensor_with_no_vc_check", False - ): - num_tensors_saved_with_no_vc_check += 1 - fw_metadata.num_tensors_saved_with_no_vc_check = ( - num_tensors_saved_with_no_vc_check - ) - inner_meta.num_tensors_saved_with_no_vc_check = ( - num_tensors_saved_with_no_vc_check - ) - - if torch._functorch.config.donated_buffer: - fw_metadata.bw_donated_idxs = collect_bw_donated_buffer_idxs( + num_fw_outs_saved_for_bw, num_symints_saved_for_bw = ( + _categorize_saved_tensors_for_backward( fw_module, bw_module, inner_meta, + fw_metadata, + num_inner_fwd_outputs, ) - inner_meta.bw_donated_idxs = fw_metadata.bw_donated_idxs - - # Note [Detaching inputs that never need gradients] - # See https://github.com/pytorch/pytorch/issues/97745 - # Suppose we have a function like this that we want to compile: - # - # def f(x, y): - # return torch.mul(x, y.detach()) - # - # What gradients should we compute for x and y? - # By default, AOTAutograd will compute a gradient for **every** input that requires gradients, - # and so we'll compute: - # x_grad_input = y - # y_grad_input = None - # Does this preserve the semantics of eager mode? - # Unfortunately, no. - # Doing the above will cause autograd to **continue** to backprop the autograd tape - # that was generated from constructing y. - # - # This is **different** from what would have happened in eager mode. - # In eager mode, if we backprop through the output of this function, autograd will only traverse - # the bit of the autograd tape corresponding to "x". - # In particular, if a user had previously backpropped through y's autograd tape, - # And then they try to backprop through the output of the above function, - # then we'll hit the dreaded "Trying to backward through the graph a second time" error. - # - # You might think: If autograd sees that a gradient is None, shouldn't it stop early, - # instead of continuing the backprop through the ancestors of that node in the graph? - # - # Autograd has two passes: - # (1) a first pass that traverses the autograd graph and figures out which nodes need to be executed - # (2) a second pass that actually goes ahead and executes each node when it becomes ready, - # propagating gradients - # By the time we're executing a node and we see that it produces a None, the set of nodes to execute - # is already locked-in. - # - # The fix: instead, we can recognize statically that the graph we're compiling will never contribute - # gradients to y, and prevent autograd from trying to traverse y's autograd tape at all. - # We can do this by manually detach'ing y before sending it through the `CompiledFunction`. - # - # Note that this solution is not bulletproof. - # It's possible to construct a case where eager may or may not have have tried to autograd through y, - # depending on the actual grad_outputs that were passed in during the backward. - # There is no easy fix for this: the simplest fix would be to run with `retain_graph=True`, - # allowing autograd to reuse the graph. - # - # An example of this case is: - # def f(x): - # return x.detach() * 2, x * 3 - # If we were to only backprop through outs[0], in eager, we would stop - # If we backward only on the first output, we shouldn't send a grad through x. - # But the custom autograd function doesn't know that: it will materialize zero grads for x * 3 - # and we will end up with a zero grad at x. - # If we later backprop through the second output, this will also require backprop'ing through x. - # Meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time. - _indices_of_inps_to_detach: list[int] = [] - - # reversed() since we expect output at end of graph - bw_output = next(reversed(bw_module.graph.find_nodes(op="output"))) - bw_outs: Sequence[torch.fx.Node] = bw_output.args[0] # type: ignore[assignment] - - # TODO: we should apply the below "detach inputs if their gradients are statically known to be None" - # optimization even if we have subclass inputs/outputs (we do not handle this today). - # Computing which our our inputs get None gradients is a bit more complicated, - # if any of our inputs are subclasses. Why? - # (a) we need to make sure that we call .detach() on the input subclasses, since autograd sees subclasses. - # (b) The grad_outputs that we AOT computed in our backward graph are the desugared tensor tensors, - # so we need to figure out which subclass fw inputs they map to. - if maybe_subclass_meta is None: - num_backward_tokens: int = inner_meta.num_backward_tokens - expected_bw_outs = ( - len(fw_metadata.input_info) - + inner_meta.num_outputs_rng_offset - + num_backward_tokens ) - if len(bw_outs) != expected_bw_outs: - raise AssertionError( - f"expected len(bw_outs) == {expected_bw_outs}, got {len(bw_outs)}" - ) - bw_outs_no_rng_no_tokens = bw_outs - if (inner_meta.num_outputs_rng_offset + num_backward_tokens) > 0: - bw_outs_no_rng_no_tokens = bw_outs[ - : -(inner_meta.num_outputs_rng_offset + num_backward_tokens) - ] - if len(bw_outs_no_rng_no_tokens) != len(fw_metadata.input_info): - raise AssertionError( - f"expected len(bw_outs_no_rng_no_tokens) == {len(fw_metadata.input_info)}, " - f"got {len(bw_outs_no_rng_no_tokens)}" - ) - for i, (bw_out) in enumerate(bw_outs_no_rng_no_tokens): - # If our input experiences a metadata mutation inside the graph (e.g. set_()), - # we *must* not detach, otherwise it will be the detach'd input that gets the metadata mutation - metadata_mutation_in_graph = ( - fw_metadata.input_info[i].mutation_type - == MutationType.MUTATED_IN_GRAPH - and fw_metadata.input_info[i].mutates_storage_metadata - ) - is_non_leaf = ( - fw_metadata.input_info[i].requires_grad - and not fw_metadata.input_info[i].is_leaf - ) - if bw_out is None and not metadata_mutation_in_graph and is_non_leaf: - _indices_of_inps_to_detach.append(i) + _indices_of_inps_to_detach = _compute_indices_of_inps_to_detach( + bw_module, + maybe_subclass_meta, + inner_meta, + fw_metadata, + ) return ( fw_module, @@ -1972,7 +2059,7 @@ def _aot_stage2a_partition( num_fw_outs_saved_for_bw, num_symints_saved_for_bw, _indices_of_inps_to_detach, - joint_inputs[0], + _joint_inputs_for_forward(joint_inputs), ) @@ -2270,19 +2357,20 @@ def _aot_stage2c_make_autograd_function( ) disable_amp = torch._C._is_any_autocast_enabled() - compiled_fn = AOTDispatchAutograd.post_compile( - compiled_fw_func, - compiled_bw_func, - maybe_subclass_meta, - num_symints_saved_for_bw, - backward_state_indices, - disable_amp, - _indices_of_inps_to_detach, - lazy_backward_info, - aot_config, + compile_spec = AOTDispatchAutogradCompileSpec( + compiled_fw_func=compiled_fw_func, + compiled_bw_func=compiled_bw_func, + maybe_subclass_meta=maybe_subclass_meta, + num_symints_saved_for_bw=num_symints_saved_for_bw, + backward_state_indices=backward_state_indices, + disable_amp=disable_amp, + indices_of_inps_to_detach=_indices_of_inps_to_detach, + lazy_backward_info=lazy_backward_info, + aot_config=aot_config, fw_metadata=fw_metadata, try_save_cache_entry=try_save_cache_entry, ) + compiled_fn = AOTDispatchAutograd.post_compile(compile_spec) if entry is not None: compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: entry) @@ -2342,7 +2430,7 @@ def _cache_autograd_info( # NB: aot_config here is technically not needed as an argument: we could just # close over aot_config.cache_info, since aot_config never changes. # But closing over random variables is confusing IMO, so I'm leaving it. - def try_save_cache_entry( # noqa: F811 + def try_save_cache_entry( compiled_bw_func: Callable[..., Any], bw_module: torch.fx.GraphModule, _fw_metadata: ViewAndMutationMeta, @@ -2350,15 +2438,9 @@ def try_save_cache_entry( # noqa: F811 ) -> GenericAOTAutogradResult[Any, Any] | None: cache_info = aot_config.cache_info - def should_save_cache() -> bool: - if should_bundle_autograd_cache(): - return True - else: - return hasattr(compiled_fw_func, "_fx_graph_cache_key") and hasattr( - compiled_bw_func, "_fx_graph_cache_key" - ) - - if cache_info is not None and should_save_cache(): + if cache_info is not None and _should_save_cache( + compiled_fw_func, compiled_bw_func + ): if forward_time_taken_ns is None: raise AssertionError("forward_time_taken_ns must not be None") # TODO: technically, AOTAutograd does a *little* bit of post processing work diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py index b8f9ba46ffba4..0c0721fcb73f5 100644 --- a/torch/_functorch/_aot_autograd/input_output_analysis.py +++ b/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -109,7 +109,6 @@ def remove_dupe_metadata( subclass_inp_meta=[], subclass_fw_graph_out_meta=[], subclass_tangent_meta=subclass_tangent_meta, - is_train=m.is_train, ) @@ -308,7 +307,6 @@ def create_synthetic_base_metadata( subclass_inp_meta=[], subclass_fw_graph_out_meta=[], subclass_tangent_meta=subclass_tangent_meta, - is_train=m.is_train, ), outer_aliased_arg_idx_with_metadata_mutations, ) diff --git a/torch/_functorch/_aot_autograd/logging_utils.py b/torch/_functorch/_aot_autograd/logging_utils.py index 550f99e6dd213..a5dda6c67dfa7 100644 --- a/torch/_functorch/_aot_autograd/logging_utils.py +++ b/torch/_functorch/_aot_autograd/logging_utils.py @@ -117,6 +117,7 @@ def prehook(grad_output: Any) -> None: fx_traceback.set_stack_trace(stack_) fx_traceback.set_grad_fn_seq_nr(seq_nr) + fx_traceback._mark_autograd_backward() return prehook @@ -126,6 +127,7 @@ def get_posthook( def posthook(grad_input: Any, grad_output: Any) -> None: fx_traceback.set_stack_trace(special_stack_) fx_traceback.reset_grad_fn_seq_nr() + fx_traceback._reset_autograd_backward() return posthook @@ -139,6 +141,16 @@ def posthook(grad_input: Any, grad_output: Any) -> None: node.register_hook(get_posthook(special_stack, node._sequence_nr())) +def setup_stacktrace_preservation_hooks_from_tensors(outputs: Any) -> None: + roots = [ + t.grad_fn + for t in (outputs if isinstance(outputs, (list, tuple)) else (outputs,)) + if isinstance(t, torch.Tensor) and t.grad_fn is not None + ] + if roots: + setup_stacktrace_preservation_hooks(roots) + + def describe_input(i: int, aot_config: AOTConfig) -> str: if i < aot_config.num_params_buffers: return f"parameter/buffer {i}" diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 41eb18c504b6a..ddacdaefcd030 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -6,7 +6,6 @@ 4. deduplicate inputs and consolidate views into their bases (see input_output_analysis) """ -import builtins import collections import contextlib import copy @@ -105,21 +104,31 @@ ) +def _unwrap_tensor_subclasses_no_symints( + args: list[Any], +) -> list[Any]: + return runtime_unwrap_tensor_subclasses(args, append_symints=False) # type: ignore[arg-type] + + zip = strict_zip aot_graphs_log = getArtifactLogger(__name__, "aot_graphs") +def _unwrap_no_symints(args: list[Any]) -> list[Any]: + return runtime_unwrap_tensor_subclasses(args, append_symints=False) + + def _describe_arg_for_logging(arg: object) -> str: from torch._library import opaque_object try: - is_dtensor = isinstance(arg, torch.distributed._tensor.DTensor) + is_dtensor = isinstance(arg, torch.distributed.tensor.DTensor) except AttributeError: is_dtensor = False if is_dtensor: - arg = typing.cast(torch.distributed._tensor.DTensor, arg) + arg = typing.cast(torch.distributed.tensor.DTensor, arg) mesh = arg.device_mesh return ( f"DTensor(shape={arg.shape}, dtype={arg.dtype}, " @@ -462,6 +471,253 @@ def __call__(self) -> AbstractContextManager[Any]: return nullcontext() +@dataclass +class _RuntimeCompiledFnInvoker: + compiled_fn: Callable[..., Any] + indices_of_inps_to_detach: list[int] + trace_joint: bool + disable_amp: bool + first_invocation_ctx: _FirstInvocationContext = field( + default_factory=_FirstInvocationContext + ) + + def __post_init__(self) -> None: + if not getattr(self.compiled_fn, "_boxed_call", False): + self.compiled_fn = make_boxed_func(self.compiled_fn) + + def run(self, args: list[Any], *, on_before_call: Callable[[], None]) -> list[Any]: + with self.first_invocation_ctx(): + if self.trace_joint: + args_ = list(args) + # See Note [Detaching inputs that never need gradients] + for idx in self.indices_of_inps_to_detach: + if isinstance(args_[idx], torch.Tensor): + args_[idx] = args_[idx].detach() + + # It's possible to have trace_joint inside user specified with no_grad() region, + # if there is a nested with enable_grad(), that forces some outputs to require gradients. + # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution. + with ( + torch.autograd._force_original_view_tracking(True), + torch.enable_grad(), + ): + on_before_call() + return call_func_at_runtime_with_args( + self.compiled_fn, + args_, + disable_amp=self.disable_amp, + steal_args=True, + ) + + # When we have an inference graph, we run with grad disabled. + # It's possible to get an inference graph with inputs that require grad, + # in which case we want to make sure autograd is disabled + # (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on) + # NOTE: We use _set_grad_enabled directly to reduce runtime overhead + grad_enabled = torch.is_grad_enabled() + try: + if grad_enabled: + torch._C._set_grad_enabled(False) + on_before_call() + return call_func_at_runtime_with_args( + self.compiled_fn, + args, + disable_amp=self.disable_amp, + steal_args=True, + ) + finally: + if grad_enabled: + torch._C._set_grad_enabled(True) + + +@dataclass +class _RuntimeForwardEpilogue: + runtime_metadata: ViewAndMutationMeta + trace_joint: bool + keep_input_mutations: bool + epilogue_args_idx: tuple[int, ...] = field(init=False) + output_handlers: tuple[Any, ...] = field(init=False) + + def __post_init__(self) -> None: + epilogue_args_idx = list(self.runtime_metadata.mutated_inp_runtime_indices) + for info in self.runtime_metadata.output_info: + if ( + info.output_type == OutputType.alias_of_input + or info.output_type == OutputType.is_input + ): + if not isinstance(info.base_idx, int): + raise AssertionError( + f"expected info.base_idx to be int, got {type(info.base_idx)}" + ) + epilogue_args_idx.append(info.base_idx) + self.epilogue_args_idx = tuple(epilogue_args_idx) + + if config.unlift_effect_tokens: + if len(self.runtime_metadata.tokens) != 0: + raise AssertionError( + "expected no tokens when unlift_effect_tokens is True, " + f"got {len(self.runtime_metadata.tokens)}" + ) + + if self.runtime_metadata.num_outputs_aliased > 0: + self.output_handlers = tuple( + make_output_handler(info, self.runtime_metadata, self.trace_joint) + for info in self.runtime_metadata.output_info + ) + else: + self.output_handlers = () + + def capture_orig_inputs(self, args: list[Any]) -> dict[int, Any]: + return {i: args[i] for i in self.epilogue_args_idx} + + def increment_mutation_versions(self, args: list[Any]) -> None: + if self.keep_input_mutations: + mutated_args = ( + args[i] + for i in self.runtime_metadata.mutated_graph_handled_indices_seen_by_autograd + ) + torch.autograd.graph.increment_version(mutated_args) + + def finalize(self, orig_inputs: dict[int, Any], all_outs: list[Any]) -> Any: + self._validate_compiled_output_arity(all_outs) + updated_inputs, fw_outs = self._split_mutated_inputs(all_outs) + if updated_inputs is not None: + self._apply_input_mutations(orig_inputs, updated_inputs) + + ret_outs = self._replay_output_aliases(orig_inputs, fw_outs) + if self.runtime_metadata.dynamic_outputs: + for t, o in zip(ret_outs, self.runtime_metadata.output_info): + if o.dynamic_dims is None: + continue + maybe_mark_dynamic_helper(t, o.dynamic_dims) + if self.runtime_metadata.grad_enabled_mutation is not None: + torch._C._set_grad_enabled(self.runtime_metadata.grad_enabled_mutation) + return ret_outs + + def _validate_compiled_output_arity(self, all_outs: list[Any]) -> None: + expected_outs = ( + self.runtime_metadata.num_mutated_inp_runtime_indices + + self.runtime_metadata.num_outputs + + self.runtime_metadata.num_intermediate_bases + ) + if len(all_outs) != expected_outs: + raise AssertionError( + f"expected {expected_outs} outputs, got {len(all_outs)}" + ) + + def _split_mutated_inputs( + self, all_outs: list[Any] + ) -> tuple[list[Any] | None, list[Any]]: + num_mutated_runtime_inps = self.runtime_metadata.num_mutated_inp_runtime_indices + if num_mutated_runtime_inps == 0: + return None, all_outs + return ( + all_outs[:num_mutated_runtime_inps], + all_outs[num_mutated_runtime_inps:], + ) + + def _apply_input_mutations( + self, orig_inputs: dict[int, Any], updated_inputs: list[Any] + ) -> None: + for i, inpt_idx in enumerate(self.runtime_metadata.mutated_inp_runtime_indices): + meta = self.runtime_metadata.input_info[inpt_idx] + if not meta.mutates_data and not meta.mutates_metadata: + continue + original_inpt = orig_inputs[inpt_idx] + updated_inpt = updated_inputs[i] + if meta.mutates_storage_metadata: + # See Note [set_() Input Mutations in AOTAutograd] + # mutates_storage_metadata means our input saw a x.set_(y) call. + # What if x **also** saw a data and/or a metadata mutation? + # (1) If the [meta]data mutation occurred after the set_(), + # then there is no need to copy_() the data. + # When we perform x.set_(x_updated), we are guaranteed that + # x_updated already has the final version of the data/metadata + # (2) If a data mutation occurred before the set_(). + # This case seems very difficult to support. + # TODO: discuss on the PR and decide if we want to tr to + # either support it, or detect and ban it. + if self.trace_joint: + if not isinstance(updated_inpt, TensorAlias): + raise AssertionError( + f"expected TensorAlias for updated_inpt, got {type(updated_inpt)}" + ) + updated_inpt = updated_inpt.alias + with torch.no_grad(): + original_inpt.set_(updated_inpt) + continue + if meta.mutates_metadata and not meta.mutates_data: + if self.trace_joint: + if not isinstance(updated_inpt, TensorAlias): + raise AssertionError( + f"expected TensorAlias for updated_inpt, got {type(updated_inpt)}" + ) + updated_inpt = updated_inpt.alias + # We need to grab the size/stride/storage_offset from the compiled forward, + # and use that to mutate the metadata of the input + original_inpt.as_strided_( + updated_inpt.size(), + updated_inpt.stride(), + updated_inpt.storage_offset(), + ) + else: + if meta.mutates_data and meta.mutates_metadata: + original_inpt.as_strided_( + updated_inpt.size(), + updated_inpt.stride(), + updated_inpt.storage_offset(), + ) + else: + if not meta.mutates_data: + raise AssertionError("expected meta.mutates_data to be True") + if meta.is_leaf and original_inpt.requires_grad: + # We can hit this situation in this case: + # def f(x): + # x.detach().mul_(2) + # return x + 1 + # AOTAutograd will see a mutation in the above case, and try to + # apply a copy_() here, in the epilogue. + # But if x required gradients, and is a leaf, then autograd + # will yell at us for trying to mutate it. + # However, it's only possible to end up in this scenario (like the above) + # if all of the mutations to the leaf input were non-autograd-tracking mutations + # (aka mutations under no_grad(), or on detached views). + # In that case, we fully want to hide the mutation from autograd, so detaching is ok. + original_inpt.detach().copy_(updated_inpt) + else: + # Check if we have stream index information for this mutated input + if ( + self.runtime_metadata.mutated_inp_stream_indices is not None + and i < len(self.runtime_metadata.mutated_inp_stream_indices) + and self.runtime_metadata.mutated_inp_stream_indices[i] + is not None + ): + raise RuntimeError( + "Mutations on inputs with user-specified streams are not yet supported. " + "See: https://github.com/pytorch/pytorch/issues/172522" + ) + original_inpt.copy_(updated_inpt) + + def _replay_output_aliases( + self, orig_inputs: dict[int, Any], fw_outs: list[Any] + ) -> Any: + if self.runtime_metadata.num_outputs_aliased == 0: + return fw_outs + + # The compiled forward also returned intermediate bases. We don't want to return them to the user. + expect_num_outputs = ( + len(self.output_handlers) + self.runtime_metadata.num_intermediate_bases + ) + if len(fw_outs) != expect_num_outputs: + raise AssertionError( + f"expected {expect_num_outputs} fw_outs, got {len(fw_outs)}" + ) + return [ + handler(orig_inputs, fw_outs, out) + for out, handler in zip(fw_outs, self.output_handlers) + ] + + def _create_runtime_wrapper( compiled_fn: Callable[..., Any], *, @@ -471,48 +727,89 @@ def _create_runtime_wrapper( keep_input_mutations: bool, disable_amp: bool, ) -> Callable[..., Any]: - if not getattr(compiled_fn, "_boxed_call", False): - compiled_fn = make_boxed_func(compiled_fn) - - # We only want to run debugmode on custom ops at the first invocation of - # runtime wrapper. For all subsequent uses, we should no-op for performance - # See: https://github.com/pytorch/pytorch/issues/165349 - first_invocation_ctx = _FirstInvocationContext() - - # Note [Inputs needed in runtime epilogue after list clearing] - # In Python functions, you can't free the input arguments of a function within the scope of that function. A workaround is to - # wrap the input arguments in a list, and clear the list from within the function. - # Here, this is implemented as `call_func_at_runtime_with_args(..., steal_args=True)`. - # - # This is needed for Compiled Autograd since some of the inputs (activations) should be freed early. - # However, we cannot blindly clear the entire list, because AOTAutograd may need access to some of the graph inputs - # **after** the compiled function has finished running. There are two main cases: - # (1) Input mutations: If there are an input mutations that we must run outside of the graph, we need access to the input. - # (2) Output aliasing: Outputs that aliases graph inputs generally must be regenerated outside of the `autograd.Function`, - # and doing so requires us accessing the corresponding input after the compiled artifact has run. - epilogue_args_idx = [] - epilogue_args_idx.extend(runtime_metadata.mutated_inp_runtime_indices) - for info in runtime_metadata.output_info: - if ( - info.output_type == OutputType.alias_of_input - or info.output_type == OutputType.is_input - ): - if not isinstance(info.base_idx, int): + compiled_invoker = _RuntimeCompiledFnInvoker( + compiled_fn=compiled_fn, + indices_of_inps_to_detach=indices_of_inps_to_detach, + trace_joint=trace_joint, + disable_amp=disable_amp, + ) + runtime_epilogue = _RuntimeForwardEpilogue( + runtime_metadata=runtime_metadata, + trace_joint=trace_joint, + keep_input_mutations=keep_input_mutations, + ) + + # Codegen output alias regeneration: emit straight-line code per output + # with all handler branches resolved at compile time. + if runtime_metadata.num_outputs_aliased > 0: + output_handlers = runtime_epilogue.output_handlers + alias_lines = ["def _alias_fn(orig_inputs, fw_outs):"] + alias_lines.append(" ret_outs = []") + alias_globals: dict[str, object] = { + "gen_alias_from_base": gen_alias_from_base, + "_unwrap_tensoralias": _unwrap_tensoralias, + } + for i, handler in enumerate(output_handlers): + if isinstance(handler, NoopAliasHandler): + alias_lines.append(f" ret_outs.append(fw_outs[{i}])") + elif isinstance(handler, IsInputHandler): + alias_lines.append( + f" ret_outs.append(orig_inputs[{handler.base_idx}])" + ) + elif isinstance(handler, AliasOfInputHandler): + vms_name = f"_vms_{i}" + alias_globals[vms_name] = handler.view_meta_sequence + out_expr = ( + f"_unwrap_tensoralias(fw_outs[{i}])" + if trace_joint + else f"fw_outs[{i}]" + ) + alias_lines.append( + f" ret_outs.append(gen_alias_from_base(" + f"orig_inputs[{handler.base_idx}], {out_expr}, " + f"{handler.requires_grad!r}, {vms_name}, " + f"replay_views={handler.replay_views!r}))" + ) + elif isinstance(handler, AliasOfIntermediateHandler): + vms_name = f"_vms_{i}" + alias_globals[vms_name] = handler.view_meta_sequence + out_expr = ( + f"_unwrap_tensoralias(fw_outs[{i}])" + if trace_joint + else f"fw_outs[{i}]" + ) + base_unwrap = handler._unwrap_aliased_base_tensor is _unwrap_tensoralias + base_expr = ( + f"_unwrap_tensoralias(fw_outs[{handler.base_idx}])" + if base_unwrap + else f"fw_outs[{handler.base_idx}]" + ) + alias_lines.append( + f" ret_outs.append(gen_alias_from_base(" + f"{base_expr}, {out_expr}, " + f"{handler.requires_grad!r}, {vms_name}, " + f"replay_views={handler.replay_views!r}))" + ) + else: raise AssertionError( - f"expected info.base_idx to be int, got {type(info.base_idx)}" + f"unhandled output handler type: {type(handler).__name__}" ) - epilogue_args_idx.append(info.base_idx) + alias_lines.append(" return ret_outs") + alias_source = "\n".join(alias_lines) - if config.unlift_effect_tokens: - if len(runtime_metadata.tokens) != 0: - raise AssertionError( - f"expected no tokens when unlift_effect_tokens is True, got {len(runtime_metadata.tokens)}" - ) + from .subclass_codegen import _compile_and_exec_source - if runtime_metadata.num_outputs_aliased > 0: - output_handlers = tuple( - make_output_handler(info, runtime_metadata, trace_joint) - for info in runtime_metadata.output_info + _codegen_alias_fn = _compile_and_exec_source( + alias_source, alias_globals, "_alias_fn", "output_alias_wrapper" + ) + import types + + def _replay_alias(self, orig_inputs, fw_outs): + return _codegen_alias_fn(orig_inputs, fw_outs) + + runtime_epilogue._replay_output_aliases = types.MethodType( # type: ignore[attr-defined] + _replay_alias, + runtime_epilogue, ) def record_runtime_wrapper_prologue_enter() -> AbstractContextManager[None] | None: @@ -533,188 +830,103 @@ def record_runtime_wrapper_prologue_exit( if cm is not None: cm.__exit__(None, None, None) - @simple_wraps(compiled_fn) - def runtime_wrapper(args: list[Any]) -> Any: - # Create context manager for profiler - cm = record_runtime_wrapper_prologue_enter() - - # stash a ref to each input tensor we plan to use after the compiled function - orig_inputs = {i: args[i] for i in epilogue_args_idx} - - if keep_input_mutations: - mutated_args = ( - args[i] - for i in runtime_metadata.mutated_graph_handled_indices_seen_by_autograd - ) - torch.autograd.graph.increment_version(mutated_args) - - # Enable _AnalyzeCustomOpInputOutputMode on first invocation to check aliasing constraints for custom ops - with first_invocation_ctx(): - if trace_joint: - args_ = list(args) - # See Note [Detaching inputs that never need gradients] - for idx in indices_of_inps_to_detach: - if isinstance(args_[idx], torch.Tensor): - args_[idx] = args_[idx].detach() - - # It's possible to have trace_joint inside user specified with no_grad() region, - # if there is a nested with enable_grad(), that forces some outputs to require gradients. - # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution. - with ( - torch.autograd._force_original_view_tracking(True), - torch.enable_grad(), - ): - record_runtime_wrapper_prologue_exit(cm) - all_outs = call_func_at_runtime_with_args( - compiled_fn, args_, disable_amp=disable_amp, steal_args=True - ) + # Codegen mutation epilogue: emit straight-line code per mutated input + # with all branches resolved at compile time. + if runtime_metadata.num_mutated_inp_runtime_indices > 0: + mut_lines = ["def _apply_mutations(orig_inputs, updated_inputs):"] + mut_globals: dict[str, object] = { + "torch": torch, + "_unwrap_tensoralias": _unwrap_tensoralias, + } + for i, inpt_idx in enumerate(runtime_metadata.mutated_inp_runtime_indices): + meta = runtime_metadata.input_info[inpt_idx] + if not meta.mutates_data and not meta.mutates_metadata: + continue + oi = f"orig_inputs[{inpt_idx}]" + ui = f"updated_inputs[{i}]" + if meta.mutates_storage_metadata: + if trace_joint: + mut_lines.append(f" _u{i} = _unwrap_tensoralias({ui})") + else: + mut_lines.append(f" _u{i} = {ui}") + mut_lines.append(f" with torch.no_grad(): {oi}.set_(_u{i})") + elif meta.mutates_metadata and not meta.mutates_data: + if trace_joint: + mut_lines.append(f" _u{i} = _unwrap_tensoralias({ui})") + else: + mut_lines.append(f" _u{i} = {ui}") + mut_lines.append( + f" {oi}.as_strided_(_u{i}.size(), _u{i}.stride(), _u{i}.storage_offset())" + ) else: - # When we have an inference graph, we run with grad disabled. - # It's possible to get an inference graph with inputs that require grad, - # in which case we want to make sure autograd is disabled - # (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on) - # NOTE: We use _set_grad_enabled directly to reduce runtime overhead - grad_enabled = torch.is_grad_enabled() - try: - if grad_enabled: - torch._C._set_grad_enabled(False) - record_runtime_wrapper_prologue_exit(cm) - all_outs = call_func_at_runtime_with_args( - compiled_fn, args, disable_amp=disable_amp, steal_args=True + if meta.mutates_data and meta.mutates_metadata: + mut_lines.append( + f" {oi}.as_strided_({ui}.size(), {ui}.stride(), {ui}.storage_offset())" ) - finally: - if grad_enabled: - torch._C._set_grad_enabled(True) - - del args + else: + assert meta.mutates_data, ( # noqa: S101 + f"expected mutates_data for input {inpt_idx}" + ) + if meta.is_leaf: + mut_lines.append( + f" if {oi}.requires_grad: {oi}.detach().copy_({ui})" + ) + mut_lines.append(f" else: {oi}.copy_({ui})") + else: + has_stream = ( + runtime_metadata.mutated_inp_stream_indices is not None + and i < len(runtime_metadata.mutated_inp_stream_indices) + and runtime_metadata.mutated_inp_stream_indices[i] is not None + ) + if has_stream: + msg_name = f"_stream_err_{i}" + mut_globals[msg_name] = ( + "Mutations on inputs with user-specified streams are not yet supported. " + "See: https://github.com/pytorch/pytorch/issues/172522" + ) + mut_lines.append(f" raise RuntimeError({msg_name})") + else: + mut_lines.append(f" {oi}.copy_({ui})") + if len(mut_lines) == 1: + mut_lines.append(" pass") + mut_source = "\n".join(mut_lines) - num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices - num_intermediate_bases = runtime_metadata.num_intermediate_bases + from .subclass_codegen import _compile_and_exec_source - expected_outs = ( - num_mutated_runtime_inps - + runtime_metadata.num_outputs - + num_intermediate_bases + codegen_apply_mutations = _compile_and_exec_source( + mut_source, mut_globals, "_apply_mutations", "mutation_epilogue" ) - if len(all_outs) != expected_outs: - raise AssertionError( - f"expected {expected_outs} outputs, got {len(all_outs)}" - ) + import types - # Step 3: After running the compiled fw, apply updates to mutated inputs - if num_mutated_runtime_inps > 0: - updated_inputs = all_outs[:num_mutated_runtime_inps] - fw_outs = all_outs[num_mutated_runtime_inps:] + runtime_epilogue._apply_input_mutations = types.MethodType( # type: ignore[attr-defined] + lambda self, orig_inputs, updated_inputs: codegen_apply_mutations( + orig_inputs, updated_inputs + ), + runtime_epilogue, + ) - for i, inpt_idx in enumerate(runtime_metadata.mutated_inp_runtime_indices): - meta = runtime_metadata.input_info[inpt_idx] - if not meta.mutates_data and not meta.mutates_metadata: - continue - original_inpt = orig_inputs[inpt_idx] - updated_inpt = updated_inputs[i] - if meta.mutates_storage_metadata: - # See Note [set_() Input Mutations in AOTAutograd] - # mutates_storage_metadata means our input saw a x.set_(y) call. - # What if x **also** saw a data and/or a metadata mutation? - # (1) If the [meta]data mutation occurred after the set_(), - # then there is no need to copy_() the data. - # When we perform x.set_(x_updated), we are guaranteed that - # x_updated already has the final version of the data/metadata - # (2) If a data mutation occurred before the set_(). - # This case seems very difficult to support. - # TODO: discuss on the PR and decide if we want to tr to - # either support it, or detect and ban it. - if trace_joint: - if not isinstance(updated_inpt, TensorAlias): - raise AssertionError( - f"expected TensorAlias for updated_inpt, got {type(updated_inpt)}" - ) - updated_inpt = updated_inpt.alias - with torch.no_grad(): - original_inpt.set_(updated_inpt) - continue - if meta.mutates_metadata and not meta.mutates_data: - if trace_joint: - if not isinstance(updated_inpt, TensorAlias): - raise AssertionError( - f"expected TensorAlias for updated_inpt, got {type(updated_inpt)}" - ) - updated_inpt = updated_inpt.alias - # We need to grab the size/stride/storage_offset from the compiled forward, - # and use that to mutate the metadata of the input - original_inpt.as_strided_( - updated_inpt.size(), - updated_inpt.stride(), - updated_inpt.storage_offset(), - ) - else: - if meta.mutates_data and meta.mutates_metadata: - original_inpt.as_strided_( - updated_inpt.size(), - updated_inpt.stride(), - updated_inpt.storage_offset(), - ) - else: - if not meta.mutates_data: - raise AssertionError( - "expected meta.mutates_data to be True" - ) - if meta.is_leaf and original_inpt.requires_grad: - # We can hit this situation in this case: - # def f(x): - # x.detach().mul_(2) - # return x + 1 - # AOTAutograd will see a mutation in the above case, and try to - # apply a copy_() here, in the epilogue. - # But if x required gradients, and is a leaf, then autograd - # will yell at us for trying to mutate it. - # However, it's only possible to end up in this scenario (like the above) - # if all of the mutations to the leaf input were non-autograd-tracking mutations - # (aka mutations under no_grad(), or on detached views). - # In that case, we fully want to hide the mutation from autograd, so detaching is ok. - original_inpt.detach().copy_(updated_inpt) - else: - # Check if we have stream index information for this mutated input - if ( - runtime_metadata.mutated_inp_stream_indices is not None - and i < len(runtime_metadata.mutated_inp_stream_indices) - and runtime_metadata.mutated_inp_stream_indices[i] - is not None - ): - raise RuntimeError( - "Mutations on inputs with user-specified streams are not yet supported. " - "See: https://github.com/pytorch/pytorch/issues/172522" - ) - original_inpt.copy_(updated_inpt) - else: - fw_outs = all_outs - - # Step 4: Manually regenerate any outputs that are aliased to inputs, instead of - # compiling them. - if runtime_metadata.num_outputs_aliased > 0: - # The compiled forward also returned intermediate bases. We don't want to return them to the user. - expect_num_outputs = ( - len(output_handlers) + runtime_metadata.num_intermediate_bases - ) - if len(fw_outs) != expect_num_outputs: - raise AssertionError( - f"expected {expect_num_outputs} fw_outs, got {len(fw_outs)}" - ) - ret_outs = [ - handler(orig_inputs, fw_outs, out) - for out, handler in builtins.zip(fw_outs, output_handlers) - ] - else: - ret_outs = fw_outs + @simple_wraps(compiled_invoker.compiled_fn) + def runtime_wrapper(args: list[Any]) -> Any: + # Create context manager for profiler + cm = record_runtime_wrapper_prologue_enter() + prologue_exited = False + + def exit_prologue() -> None: + nonlocal prologue_exited + if not prologue_exited: + record_runtime_wrapper_prologue_exit(cm) + prologue_exited = True + + try: + # stash a ref to each input tensor we plan to use after the compiled function + orig_inputs = runtime_epilogue.capture_orig_inputs(args) + runtime_epilogue.increment_mutation_versions(args) + all_outs = compiled_invoker.run(args, on_before_call=exit_prologue) + finally: + exit_prologue() - if runtime_metadata.dynamic_outputs: - for t, o in zip(ret_outs, runtime_metadata.output_info): - if o.dynamic_dims is None: - continue - maybe_mark_dynamic_helper(t, o.dynamic_dims) - if runtime_metadata.grad_enabled_mutation is not None: - torch._C._set_grad_enabled(runtime_metadata.grad_enabled_mutation) - return ret_outs + del args + return runtime_epilogue.finalize(orig_inputs, all_outs) if not (trace_joint and _should_disable_saved_tensors_hooks()): return runtime_wrapper @@ -951,7 +1163,7 @@ def post_compile( *, runtime_metadata: ViewAndMutationMeta, ) -> Callable[..., Any]: - if self.maybe_subclass_meta is None: + if self.maybe_subclass_meta is None and not runtime_metadata.act_input_indices: return compiled_fn from .subclass_codegen import codegen_subclass_wrapper @@ -962,6 +1174,7 @@ def post_compile( out_metas=runtime_metadata.subclass_fw_graph_out_meta, num_fw_outs_saved_for_bw=self.num_fw_outs_saved_for_bw, frozen_inp_indices=self._get_frozen_inp_indices(), + act_input_indices=runtime_metadata.act_input_indices, ) inner_fn._boxed_call = True # type: ignore[attr-defined] return inner_fn @@ -1250,7 +1463,6 @@ def wrapped_flat_fn( flat_args_descs=deduped_flat_args_descs, static_input_indices=aot_config.static_input_indices, keep_input_mutations=fw_metadata.keep_input_mutations, - is_train=fw_metadata.is_train, )(*deduped_flat_args) if ref_fw_metadata != updated_fw_metadata: raise AssertionError( @@ -1274,11 +1486,23 @@ def post_compile( if not self.needs_post_compile: return compiled_fn - @wraps(compiled_fn) - def wrapped_compiled_fn(args: list[Any]) -> Any: - deduped_args = self.remove_dupe_args(args) - args.clear() - return compiled_fn(deduped_args) + keep_indices = [i for i, keep in enumerate(self.keep_arg_mask) if keep] + idx_list = ", ".join(f"args[{i}]" for i in keep_indices) + source = ( + f"def inner_fn(args):\n" + f" deduped_args = [{idx_list}]\n" + f" args.clear()\n" + f" return compiled_fn(deduped_args)\n" + ) + from .subclass_codegen import _compile_and_exec_source + + wrapped_compiled_fn: Callable[..., Any] = _compile_and_exec_source( # type: ignore[assignment] + source, + {"compiled_fn": compiled_fn}, + "inner_fn", + "dedup_wrapper", + wrapped_fn=compiled_fn, + ) wrapped_compiled_fn._boxed_call = True # type: ignore[attr-defined] @@ -1473,7 +1697,6 @@ def wrapped_flat_fn(*args: Any) -> Any: flat_args_descs=flat_args_descs_with_synthetic_bases, static_input_indices=aot_config.static_input_indices, keep_input_mutations=fw_metadata.keep_input_mutations, - is_train=fw_metadata.is_train, )(*flat_args_with_synthetic_bases) if ref_fw_metadata != fw_metadata_updated: raise AssertionError( @@ -1844,6 +2067,7 @@ def make_hashable(arg: Any) -> Any: raise AssertionError( "every argument in the inner calling convention should be accounted for" ) + # pyrefly: ignore [bad-return] return ( args_to_functionalization, args_to_functionalization_descs, @@ -1894,7 +2118,8 @@ def _backward_prologue_functional( ctx_opaque_objects: Sequence[Any], metadata: ViewAndMutationMeta, maybe_subclass_metadata: SubclassMeta | None, - *flat_args: Any, + flat_args: Sequence[Any], + codegen_unwrap_fn: Callable[..., Any] | None = None, ) -> list[Any]: # Calling convention: we expect a grad_out passed to the backward: # - for every output of the fw that does *not* alias an input or graph intermediate @@ -1938,6 +2163,12 @@ def _backward_prologue_functional( ], flat_args[num_mutated_runtime_inps + metadata.num_outputs :], ) + # Release grad refs from the caller's list (boxed calling convention). + # Slicing already copied refs into sub-lists above, so clearing the + # original list only drops redundant refs. The isinstance guard skips + # this when flat_args is a tuple (non-boxed path from compiled_autograd). + if isinstance(flat_args, list): + flat_args.clear() # input_info contains info on *every* input, # But in the backward(), we are only given grad outputs for every mutated input # We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad @@ -1989,7 +2220,7 @@ def _backward_prologue_functional( bw_tokens = [None] * metadata.num_backward_tokens - # - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors) showing up first + # - note: donated buffer logic requires (*ctx.symints, *ctx.saved_tensors, *ctx.opaques) showing up first # in the bw output order. # Every dereference of ctx.saved_tensors incurs saved_tensors_hooks calls @@ -2075,22 +2306,14 @@ def _backward_prologue_functional( ) ) + if codegen_unwrap_fn is not None: + unwrap = codegen_unwrap_fn + else: + unwrap = _unwrap_no_symints all_args = ( - runtime_unwrap_tensor_subclasses( - all_args[:tangents_start_idx], # type: ignore[arg-type] - # SymInts that are inputs to the backward graph are - # already included in the "all_args" list. - # Any symints coming from tensor subclasses should always - # come from primals, and so they will show up as extra - # arguments to the forward graph, and they will be saved - # as activation in the backward graph. - append_symints=False, - ) + unwrap(all_args[:tangents_start_idx]) + flat_processed_tangents - + runtime_unwrap_tensor_subclasses( - all_args[tangents_end_idx:], # type: ignore[arg-type] - append_symints=False, - ) + + unwrap(all_args[tangents_end_idx:]) ) else: stack_traces = metadata.tangent_source_stack_traces or () @@ -2172,6 +2395,7 @@ def _backward_epilogue_functional( *, ctx_opaque_objects: Sequence[Any] = (), make_subclass_override: Callable[..., Any] | None = None, + codegen_wrap_fn: Callable[..., Any] | None = None, ) -> tuple[Any, ...]: # Toss out the backward output tokens num_bw_tokens = metadata.num_backward_tokens @@ -2205,6 +2429,8 @@ def _backward_epilogue_functional( if maybe_subclass_metadata is not None: if maybe_subclass_metadata.grad_input_metas is None: raise AssertionError("grad_input_metas must not be None") + if codegen_wrap_fn is not None and make_subclass_override is None: + return codegen_wrap_fn(out) outs_wrapped = wrap_tensor_subclasses( out, subclass_metas=maybe_subclass_metadata.grad_input_metas, @@ -2311,240 +2537,528 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.compiled_fn(*args, **kwargs) -# This is wrapped in a class just for namespacing purposes -# No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly -class AOTDispatchAutograd: - @staticmethod - def _raise_tangent_metadata_error( - expected_type: type | None, - expected_meta: Any, - runtime_type: type, - runtime_meta: Any, - orig_x: torch.Tensor, - tangent_idx: int | None, - tangent_desc: Any | None, - compile_id_str: str | None, - tangent_stack_trace: str | None, - ) -> RuntimeError: - expected_subclass_got_plain_tensor = ( - expected_type is not None - and expected_type is not torch.Tensor - and runtime_type is torch.Tensor - ) - if expected_subclass_got_plain_tensor: - tangent_msg = "" - if tangent_idx is not None: - tangent_msg = f" (tangent index: {tangent_idx})" +@dataclass +class AOTDispatchAutogradCompileSpec: + compiled_fw_func: Callable[..., Any] + compiled_bw_func: Callable[..., Any] | None + maybe_subclass_meta: SubclassMeta | None + num_symints_saved_for_bw: int + backward_state_indices: list[int] + disable_amp: bool + indices_of_inps_to_detach: list[int] + lazy_backward_info: ( + AutogradLazyBackwardCompileInfo | CachedAutogradLazyBackwardCompileInfo | None + ) + aot_config: AOTConfig + fw_metadata: ViewAndMutationMeta + try_save_cache_entry: Callable[..., Any] | None - output_hint = "" - if tangent_desc is not None: - from .descriptors import PlainAOTOutput, TangentAOTInput - if isinstance(tangent_desc, TangentAOTInput) and isinstance( - tangent_desc.output, PlainAOTOutput - ): - idx = tangent_desc.output.idx - output_hint = f"\n\nThe problematic output is: forward output at index {idx} (0-indexed)" - else: - output_hint = ( - f"\n\nThe problematic output is: {tangent_desc.expr()}" - ) +@dataclass +class _AutogradSavedState: + metadata: ViewAndMutationMeta - graph_hint = "" - if compile_id_str is not None: - graph_hint = ( - f"\n\nThis error occurred in compiled graph [{compile_id_str}]." - ) + def save_from_forward(self, ctx: Any, fw_outs: Sequence[Any]) -> None: + tensors_saved_with_vc_check = fw_outs[ + self.metadata.tensors_saved_for_backwards_with_vc_check_slice + ] + tensors_saved_no_vc_check = fw_outs[ + self.metadata.tensors_saved_for_backwards_no_vc_check_slice + ] + if not all(isinstance(x, torch.Tensor) for x in tensors_saved_with_vc_check): + raise AssertionError( + "expected all tensors_saved_with_vc_check to be Tensors, " + f"got types: {[type(x) for x in tensors_saved_with_vc_check]}" + ) + if not all(isinstance(x, torch.Tensor) for x in tensors_saved_no_vc_check): + raise AssertionError( + "expected all tensors_saved_no_vc_check to be Tensors, " + f"got types: {[type(x) for x in tensors_saved_no_vc_check]}" + ) - stack_trace_hint = "" - if tangent_stack_trace is not None: - stack_trace_hint = ( - f"\n\nThe forward output was created here:\n{tangent_stack_trace}" - ) + # See Note [Detaching saved tensors in AOTAutograd] + num_vc_check = len(tensors_saved_with_vc_check) + tensors_to_save = [ + x.detach() if x._is_view() else x for x in tensors_saved_with_vc_check + ] + tensors_no_vc_check = [ + x.detach() if x._is_view() else x for x in tensors_saved_no_vc_check + ] - return RuntimeError( - f""" -During the backward, we encountered a tensor subclass where we guessed its -metadata incorrectly. -Expected a {expected_type.__name__} tangent but got a plain Tensor{tangent_msg}. -This happens when a compiled function returns multiple outputs that -require gradients, but .backward() is only called on some of them. -To fix: call .detach() on forward outputs you don't need gradients for.{output_hint}{graph_hint}{stack_trace_hint} + # dynamic_saved_tensors_idxs has indices relative to all saved tensors + # (vc_check + no_vc_check combined). Mark dynamics on the detached tensors. + for idx, dims in self.metadata.dynamic_saved_tensors_idxs.items(): + if idx < num_vc_check: + maybe_mark_dynamic_helper(tensors_to_save[idx], dims) + else: + maybe_mark_dynamic_helper(tensors_no_vc_check[idx - num_vc_check], dims) -This error is also more likely to occur if your compiled model is suffering -from a large number of graph breaks. For more advice on finding and fixing -graph breaks, see: -https://docs.pytorch.org/docs/stable/user_guide/torch_compiler/compile/programming_model.graph_breaks_index.html + ctx.save_for_backward(*tensors_to_save) + ctx._tensors_no_vc_check = tensors_no_vc_check -For more info about this error, see: -https://github.com/pytorch/pytorch/issues/172556""" + symint_outs = fw_outs[self.metadata.symints_saved_for_backwards_slice] + if not all( + isinstance(x, (int, float, torch.SymInt, torch.SymFloat)) + for x in symint_outs + ): + raise AssertionError( + "expected all symint_outs to be int/float/SymInt/SymFloat, " + f"got types: {[type(x) for x in symint_outs]}" ) - else: - return RuntimeError( - f""" -During the backward, we encountered a tensor subclass where we guessed its -metadata incorrectly. -Expected: {expected_meta} (type {expected_type}), -got: {runtime_meta} (type {runtime_type}), shape: {orig_x.shape}. -Your tensor subclass must implement __coerce_same_metadata_as_tangent__.""" + ctx.symints = symint_outs + + opaque_object_outs = fw_outs[ + self.metadata.opaque_objects_saved_for_backwards_slice + ] + if not all( + is_opaque_type(type(obj)) or isinstance(obj, OpaqueBase) + for obj in opaque_object_outs + ): + raise AssertionError( + "expected all opaque_object_outs to be opaque types, " + f"got types: {[type(obj) for obj in opaque_object_outs]}" ) + ctx.opaque_objects = opaque_object_outs - @staticmethod - def process_runtime_tangent( - x: Any, - meta: PlainTensorMeta | SubclassCreationMeta, - tangent_idx: int | None = None, - tangent_desc: Any | None = None, - compile_id_str: str | None = None, - tangent_stack_trace: str | None = None, - ) -> tuple[Any, list[Any]]: - if not isinstance(x, torch.Tensor): - return x, [x] + def load_tensors(self, ctx: Any) -> Sequence[torch.Tensor]: + if len(ctx._tensors_no_vc_check) > 0: + return list(ctx.saved_tensors) + ctx._tensors_no_vc_check + return ctx.saved_tensors - if isinstance(x, FakeTensor): - if not meta.memory_format: + +@dataclass +class _AutogradForwardEpilogue: + metadata: ViewAndMutationMeta + + def finalize(self, ctx: Any, fw_outs: Sequence[Any]) -> tuple[Any, ...]: + num_outputs = self.metadata.num_outputs + num_outputs_aliased = self.metadata.num_outputs_aliased + num_mutated_runtime_inps = self.metadata.num_mutated_inp_runtime_indices + num_forward_returns = self.metadata.num_forward_returns + + raw_returns = list(fw_outs[:num_forward_returns]) + + # Wrap all autograd.Function.forward() outputs that are aliases + # so that autograd.Function doesn't treat them as tensors + if num_mutated_runtime_inps > 0: + for i, idx in enumerate(self.metadata.mutated_inp_runtime_indices): + # We could make this faster by only looping over inputs with metadata-only mutations + # (instead of looping over inputs with either data or metadata mutations), but there shouldn't be many. + info = self.metadata.input_info[idx] + if info.mutates_metadata and not info.mutates_data: + raw_returns[i] = TensorAlias(raw_returns[i]) + + if config.debug_assert: + user_mutated_inputs_raw = raw_returns[0:num_mutated_runtime_inps] + mut_inp_infos = [ + x + for x in self.metadata.input_info + if x.mutates_data or x.mutates_metadata + ] + if len(user_mutated_inputs_raw) != len(mut_inp_infos): + raise AssertionError( + "expected len(user_mutated_inputs_raw) == len(mut_inp_infos), " + f"got {len(user_mutated_inputs_raw)} != {len(mut_inp_infos)}" + ) + + if self.metadata.num_unsafe_view_outputs > 0: + for idx in self.metadata.unsafe_view_out_indices: + raw_return_idx = num_mutated_runtime_inps + idx + o = raw_returns[raw_return_idx] + raw_returns[raw_return_idx] = torch.ops.aten._unsafe_view(o, o.shape) + + if num_outputs_aliased > 0: + for idx in self.metadata.aliased_out_indices: + raw_return_idx = num_mutated_runtime_inps + idx + raw_returns[raw_return_idx] = TensorAlias(raw_returns[raw_return_idx]) + + if config.debug_assert: + intermediates_raw = raw_returns[ + num_mutated_runtime_inps + num_outputs : + ] + if any(isinstance(x, TensorAlias) for x in intermediates_raw): + raise AssertionError("expected no TensorAlias in intermediates_raw") + + # invariant: intermediate bases always require gradients, so we don't have to + # consider marking them as non-differentiable. + raw_returns_not_including_intermediate_bases = raw_returns[ + : num_mutated_runtime_inps + num_outputs + ] + raw_returns_meta = [ + x + for x in self.metadata.input_info + if x.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] + self.metadata.output_info + + fw_outs_not_requiring_grad = [ + x + for (i, x) in enumerate(raw_returns_not_including_intermediate_bases) + if isinstance(x, torch.Tensor) and not raw_returns_meta[i].requires_grad + ] + ctx.mark_non_differentiable(*fw_outs_not_requiring_grad) + ctx._materialize_non_diff_grads = False + return tuple(raw_returns) + + +@dataclass +class _AutogradRngStateTracker: + num_rng: int + graphsafe_idx: int | None + fwd_rng_states: list[torch.Generator] = field(default_factory=list) + bwd_rng_states: list[torch.Generator] = field(default_factory=list) + curr_fwd_iter: Any = field(default_factory=lambda: itertools.count(0)) + backward_state_position: int = 0 + pending_forwards: set[int] = field(default_factory=set) + saved_backward_tensor_states: dict[int, list[torch.Tensor]] = field( + default_factory=dict + ) + + def add_forward_args(self, ctx: Any, args: tuple[Any, ...]) -> tuple[Any, ...]: + if self.num_rng == 0: + return args + + if len(self.fwd_rng_states) == 0: + if self.graphsafe_idx is None: + raise AssertionError("graphsafe_idx must not be None when num_rng > 0") + initialize_rng_states( + self.num_rng, + self.graphsafe_idx, + self.fwd_rng_states, + self.bwd_rng_states, + ) + + curr_iter = next(self.curr_fwd_iter) + ctx._curr_iter = curr_iter + + # if this state is not contained in the backward, + # we need to save it for when its backward pass happens + if curr_iter != self.backward_state_position: + self.saved_backward_tensor_states[curr_iter] = [ + rng_state.get_state() for rng_state in self.fwd_rng_states + ] + + self.pending_forwards.add(curr_iter) + return (*args, *self.fwd_rng_states) + + def add_backward_args(self, ctx: Any, all_args: list[Any]) -> None: + if self.num_rng == 0: + return + + curr_backward_iter = ctx._curr_iter + retain_graph = torch._C._autograd._get_current_graph_task_keep_graph() + + # Save current state if we have a pending forward that needs this state + # or this state may be needed again because of retain graph + if ( + self.backward_state_position in self.pending_forwards + and self.backward_state_position not in self.saved_backward_tensor_states + and (self.backward_state_position != curr_backward_iter or retain_graph) + ): + self.saved_backward_tensor_states[self.backward_state_position] = [ + rng_state.get_state() for rng_state in self.bwd_rng_states + ] + + # Restore saved states if needed + if curr_backward_iter in self.saved_backward_tensor_states: + if self.backward_state_position != curr_backward_iter: + for bwd_state, saved_state in zip( + self.bwd_rng_states, + self.saved_backward_tensor_states[curr_backward_iter], + ): + bwd_state.set_state(saved_state) + if not retain_graph: + del self.saved_backward_tensor_states[curr_backward_iter] + else: + if self.backward_state_position != curr_backward_iter: raise AssertionError( - "meta.memory_format must not be None for FakeTensor" + "expected backward_state_position == curr_backward_iter, " + f"got {self.backward_state_position} != {curr_backward_iter}" ) - x = coerce_to_expected_memory_format(x, meta.memory_format) - return x, [x] - expected_type: type | None = torch.Tensor - expected_meta = None - if isinstance(meta, SubclassCreationMeta): - expected_type = meta.original_subclass_type - expected_meta = meta.meta + self.backward_state_position = curr_backward_iter + 1 + if not retain_graph: + self.pending_forwards.remove(curr_backward_iter) + all_args.extend(self.bwd_rng_states) - runtime_type = type(x) - # When we're inside compiled autograd's AOTDispatcher step, - # regular Tensors look like FunctionalTensors. - # Tensor subclasses still look like Tensor subclasses though. - if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor): - runtime_type = torch.Tensor - runtime_meta = None - runtime_subclass_keys: Sequence[str] = [] +@dataclass +class _AutogradBackwardCompiler: + compiled_bw: Callable[..., Any] | None + lazy_backward_info: ( + AutogradLazyBackwardCompileInfo | CachedAutogradLazyBackwardCompileInfo | None + ) + disable_amp: bool + aot_config: AOTConfig + fw_metadata: ViewAndMutationMeta + try_save_cache_entry: Callable[..., Any] | None - if is_traceable_wrapper_subclass(x): - runtime_subclass_keys, runtime_meta = x.__tensor_flatten__() + def get_or_compile(self, *, saved_tensors_use_once: bool) -> Callable[..., Any]: + if self.compiled_bw is not None: + return self.compiled_bw - def maybe_coerce(x: torch.Tensor) -> torch.Tensor | None: - same_type: bool = expected_type == runtime_type - same_meta: bool = expected_meta == runtime_meta + if self.lazy_backward_info is None: + raise AssertionError("lazy_backward_info must not be None") + if not isinstance(self.lazy_backward_info, AutogradLazyBackwardCompileInfo): + raise AssertionError( + "expected AutogradLazyBackwardCompileInfo, " + f"got {type(self.lazy_backward_info)}" + ) - if same_type and same_meta: - return x + self._prepare_lazy_backward_context(saved_tensors_use_once) + + bw_module = self.lazy_backward_info.bw_module + placeholder_list = self.lazy_backward_info.placeholder_list + saved_context = self.lazy_backward_info.saved_context + saved_compile_context = self.lazy_backward_info.saved_compile_context + + context = torch._C._DisableAutocast if self.disable_amp else nullcontext + metrics_context = get_metrics_context() + with ( + tracing(saved_context), + compile_context(saved_compile_context), + context(), + track_graph_compiling(self.aot_config, "backward"), + metrics_context, + dynamo_timed( + "backward._backward_impl", + phase_name="entire_backward_compile", + log_pt2_compile_event=True, + dynamo_compile_column_us="backward_cumulative_compile_time_us", + log_waitcounter=True, + waitcounter_name_override="entire_backward_compile", + ), + callback_handler.install_callbacks( + CallbackTrigger.LAZY_BACKWARD, + str(CompileContext.current_compile_id()), + ), + ): + CompileEventLogger.compilation_metric(is_forward=False) + # See Note: [Backward graph lazy lowering] + if self.aot_config.bw_compiler is None: + raise AssertionError("aot_config.bw_compiler must not be None") + self.compiled_bw = self.aot_config.bw_compiler( + copy.deepcopy(bw_module), placeholder_list + ) + # Maybe save cache entry + if self.try_save_cache_entry is not None: + self.try_save_cache_entry( + self.compiled_bw, + bw_module, + self.fw_metadata, + self.aot_config, + ) - if not hasattr(x, "__coerce_same_metadata_as_tangent__"): - return None + return self.compiled_bw - if same_type: - # Backward Compatibility, as some Subclass impls can have original 1-arg function. - return x.__coerce_same_metadata_as_tangent__(expected_meta) + def _prepare_lazy_backward_context(self, saved_tensors_use_once: bool) -> None: + if self.lazy_backward_info is None: + raise AssertionError("lazy_backward_info must not be None") + if not isinstance(self.lazy_backward_info, AutogradLazyBackwardCompileInfo): + raise AssertionError( + "expected AutogradLazyBackwardCompileInfo, " + f"got {type(self.lazy_backward_info)}" + ) - return x.__coerce_same_metadata_as_tangent__(expected_meta, expected_type) + if ( + hasattr(self.lazy_backward_info, "saved_context") + and self.lazy_backward_info.saved_context is not None + ): + if not isinstance(self.lazy_backward_info.saved_context, TracingContext): + raise AssertionError( + f"expected TracingContext, got {type(self.lazy_backward_info.saved_context)}" + ) + ddp_ctx = self.lazy_backward_info.saved_context.ddp_optimizer_ctx + if ddp_ctx is not None: + if ddp_ctx.curr_bucket < 0: + raise AssertionError( + "expected same # of fw and bw compiles, " + f"but found bucket {ddp_ctx.curr_bucket}" + ) + curr_fw_meta = ddp_ctx.metadata_per_bucket[ddp_ctx.curr_bucket] + # Note [DDPOptimizer and fw_metadata] + # When using the DDPOptimizer, we have a single dynamo graph (and TracingContext), + # but multiple AOTDispatcher graph. + # + # One consequence is that there will be **multiple** fw_metadata objects, one per AOT graph, + # which we stash the fw_metadata on the TracingContext. + # + # Normally what happens is that as we compile AOT graphs 1...N, we clobber the fw_metadata + # for graph i-1 when we start running AOT for graph i. + # Ordinarily this is fine, because inductor no longer needs the metadata from graph i-1. + # + # However, this is a problem for lazy compilation of the backward. During backward compilation, + # we compile the backward lazily at backward runtime, meaning that we will first compile + # backward graph N, N-1, ..., 1. + # We need to ensure that at the time inductor compiles bw graph N-1, it can access + # the corresponding fw_metadta for graph N-1. + # + # We do this by stashing a DDPOptimizerContext, which tracks: + # - the metadata of all N graphs + # - the graph we are currently compiling in our DDPOptimizer region. + ddp_ctx.curr_bucket -= 1 + self.lazy_backward_info.saved_context.fw_metadata = curr_fw_meta + + if not saved_tensors_use_once: + self.fw_metadata.bw_donated_idxs = [] + # Update bw_donated_idxs if using lazy_backward_info from `aot_dispatch_autograd` + if ( + hasattr(self.lazy_backward_info, "saved_context") + and hasattr(self.lazy_backward_info.saved_context, "fw_metadata") + and hasattr( + self.lazy_backward_info.saved_context.fw_metadata, # type: ignore[union-attr] + "bw_donated_idxs", + ) + ): + self.lazy_backward_info.saved_context.fw_metadata.bw_donated_idxs = ( # type: ignore[union-attr] + # pyrefly: ignore [implicit-any] + [] + ) - # Coerce to expected type and metadata - orig_x = x - x = maybe_coerce(x) - if x is None: - raise AOTDispatchAutograd._raise_tangent_metadata_error( - expected_type, - expected_meta, - runtime_type, - runtime_meta, - orig_x, - tangent_idx, - tangent_desc, - compile_id_str, - tangent_stack_trace, + +@dataclass +class _AOTDispatchAutogradFunctionFactory: + spec: AOTDispatchAutogradCompileSpec + + def build(self) -> type[torch.autograd.Function]: + compile_id = CompileContext.current_compile_id() + compile_id_str = str(compile_id) if compile_id is not None else None + self.spec.fw_metadata.compile_id_str = compile_id_str + + saved_state = _AutogradSavedState(self.spec.fw_metadata) + forward_epilogue = _AutogradForwardEpilogue(self.spec.fw_metadata) + rng_state = _AutogradRngStateTracker( + num_rng=self.spec.fw_metadata.num_graphsafe_rng_states, + graphsafe_idx=self.spec.fw_metadata.graphsafe_rng_state_index, + ) + backward_compiler = _AutogradBackwardCompiler( + compiled_bw=self.spec.compiled_bw_func, + lazy_backward_info=self.spec.lazy_backward_info, + disable_amp=self.spec.disable_amp, + aot_config=self.spec.aot_config, + fw_metadata=self.spec.fw_metadata, + try_save_cache_entry=self.spec.try_save_cache_entry, + ) + + compiled_fw_func = self.spec.compiled_fw_func + compiled_bw_func = self.spec.compiled_bw_func + maybe_subclass_meta = self.spec.maybe_subclass_meta + num_symints_saved_for_bw_ = self.spec.num_symints_saved_for_bw + backward_state_indices = self.spec.backward_state_indices + disable_amp = self.spec.disable_amp + lazy_backward_info = self.spec.lazy_backward_info + aot_config = self.spec.aot_config + fw_metadata = self.spec.fw_metadata + + _codegen_bw_unwrap_fn = None + _codegen_bw_wrap_fn = None + if maybe_subclass_meta is not None: + from .subclass_codegen import codegen_backward_subclass_fns + + _codegen_bw_unwrap_fn, _codegen_bw_wrap_fn = codegen_backward_subclass_fns( + grad_input_metas=maybe_subclass_meta.grad_input_metas, ) - # Coerce to expected memory format - if not meta.memory_format: - raise AssertionError("meta.memory_format must not be None") - x = coerce_to_expected_memory_format(x, meta.memory_format) + # Codegen for CompiledFunction.forward: emit straight-line TensorAlias + # wrapping, _unsafe_view, and non-differentiable output collection with + # all indices resolved at compile time. + num_mutated_runtime_inps = fw_metadata.num_mutated_inp_runtime_indices + num_outputs = fw_metadata.num_outputs + num_outputs_aliased = fw_metadata.num_outputs_aliased + + _xform_lines = ["def _transform_raw_returns(raw_returns):"] + _xform_globals: dict[str, object] = { + "TensorAlias": TensorAlias, + "torch": torch, + "Tensor": Tensor, + } + + for i, idx in enumerate(fw_metadata.mutated_inp_runtime_indices): + info = fw_metadata.input_info[idx] + if info.mutates_metadata and not info.mutates_data: + _xform_lines.append( + f" raw_returns[{i}] = TensorAlias(raw_returns[{i}])" + ) - if not is_traceable_wrapper_subclass(x): - return x, [x] + if fw_metadata.num_unsafe_view_outputs > 0: + for idx in fw_metadata.unsafe_view_out_indices: + ri = num_mutated_runtime_inps + idx + _xform_lines.append(f" _o = raw_returns[{ri}]") + _xform_lines.append( + f" raw_returns[{ri}] = torch.ops.aten._unsafe_view(_o, _o.shape)" + ) - if not isinstance(meta, SubclassCreationMeta): - raise AssertionError(f"expected SubclassCreationMeta, got {type(meta)}") - if orig_x is not x: - runtime_subclass_keys = x.__tensor_flatten__()[0] + if num_outputs_aliased > 0: + for idx in fw_metadata.aliased_out_indices: + ri = num_mutated_runtime_inps + idx + _xform_lines.append( + f" raw_returns[{ri}] = TensorAlias(raw_returns[{ri}])" + ) - if len(meta.attrs) != len(runtime_subclass_keys): - raise AssertionError( - f"expected len(meta.attrs) == len(runtime_subclass_keys), " - f"got {len(meta.attrs)} != {len(runtime_subclass_keys)}" + # Non-differentiable output collection: build a list of specific indices + # at compile time rather than iterating at runtime. + _non_diff_indices: list[int] = [] + _returns_meta = [ + x + for x in fw_metadata.input_info + if x.mutation_type == MutationType.MUTATED_OUT_GRAPH + ] + list(fw_metadata.output_info) + for i, meta in enumerate(_returns_meta): + if i < num_mutated_runtime_inps + num_outputs and not meta.requires_grad: + _non_diff_indices.append(i) + if _non_diff_indices: + checks = " + ".join( + f"([raw_returns[{i}]] if isinstance(raw_returns[{i}], Tensor) else [])" + for i in _non_diff_indices ) - leaves = [] - for attr, attr_meta in meta.attrs.items(): - if isinstance(attr_meta, OpaqueMeta): - # Opaques aren't differentiable but occupy a flat arg slot. - leaves.append(getattr(x, attr)) - continue - elem = getattr(x, attr) - new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent( - elem, attr_meta + _xform_lines.append(f" non_diff = {checks}") + else: + _xform_lines.append(" non_diff = []") + _xform_lines.append(" return non_diff") + + _xform_source = "\n".join(_xform_lines) + + from .subclass_codegen import _compile_and_exec_source + + _codegen_transform_raw_returns: Callable[..., list[Any]] = ( + _compile_and_exec_source( # type: ignore[assignment] + _xform_source, + _xform_globals, + "_transform_raw_returns", + "compiled_fn_wrapper", ) - if new_elem is not elem: - setattr(x, attr, new_elem) - leaves.extend(elem_leaves) + ) - return x, leaves + # Monkey-patch forward_epilogue.finalize to use codegen'd transform + def _codegen_finalize(ctx: Any, fw_outs: Any) -> tuple[Any, ...]: + num_forward_returns = fw_metadata.num_forward_returns + raw_returns = list(fw_outs[:num_forward_returns]) + fw_outs_not_requiring_grad = _codegen_transform_raw_returns(raw_returns) + if config.debug_assert: + if num_mutated_runtime_inps > 0: + user_mutated_inputs_raw = raw_returns[0:num_mutated_runtime_inps] + mut_inp_infos = [ + x + for x in fw_metadata.input_info + if x.mutates_data or x.mutates_metadata + ] + if len(user_mutated_inputs_raw) != len(mut_inp_infos): + raise AssertionError( + f"expected len(user_mutated_inputs_raw) == len(mut_inp_infos), " + f"got {len(user_mutated_inputs_raw)} != {len(mut_inp_infos)}" + ) + if num_outputs_aliased > 0: + intermediates_raw = raw_returns[ + num_mutated_runtime_inps + num_outputs : + ] + if any(isinstance(x, TensorAlias) for x in intermediates_raw): + raise AssertionError( + "expected no TensorAlias in intermediates_raw" + ) + ctx.mark_non_differentiable(*fw_outs_not_requiring_grad) + ctx._materialize_non_diff_grads = False + return tuple(raw_returns) - @staticmethod - def post_compile( - compiled_fw_func: Callable[..., Any], # fw_module after compilation + wrappers - compiled_bw_func: Callable[..., Any] - | None, # bw_module after compilation + wrappers - maybe_subclass_meta: SubclassMeta | None, - num_symints_saved_for_bw_: int, - backward_state_indices: list[int], - disable_amp: bool, - indices_of_inps_to_detach: list[int], - lazy_backward_info: AutogradLazyBackwardCompileInfo - | CachedAutogradLazyBackwardCompileInfo - | None, - aot_config: AOTConfig, - *, - fw_metadata: ViewAndMutationMeta, # runtime metadata - try_save_cache_entry: Callable[..., Any] | None, # Serialization function - ) -> Callable[..., Any]: - # For additional context see Note [CUDA Graph Safe RNG Functionalization] - # Each pair forward, backward rng states must be equal prior to its invocation on any - # iteration of forward, backward. Because they are initialized equal, and are computing the same rng op, - # running forward then backward advances them the same amount and keeps them equal. - # However, a user may invoke multiple forwards, then backwards, such that they are not in sync. - # Initially we have: - # fwd_state0 == bwd_state0. - # Lets say we run: - # fwd0: fwd_state0 -> fwd_state1 - # fwd1: fwd_state1 -> fwd_state2 - # fwd2: fwd_state2 -> fwd_state3 - # If we now invoke bwd2, - # we need to update bwd_state equal to the rng that was observed in fwd2. - # we save the rng_state fwd_state2 in forward because we detect that it is not the - # current backward state and therefore would not be accessible if we do not save it. - # Similarly, if we are going to update the backward state to a new value, and there is a pending - # forwards which needs its current state, we will save it. - # Within the autograd context, we keep track of the curr iteration so that on backward - # we know what the generator state must be before the backward is run. - num_rng = fw_metadata.num_graphsafe_rng_states - graphsafe_idx = fw_metadata.graphsafe_rng_state_index - fwd_rng_states: list[torch.Generator] = [] - bwd_rng_states: list[torch.Generator] = [] - curr_fwd_iter = itertools.count(0) - backward_state_position = 0 - pending_forwards: set[int] = set() - saved_backward_tensor_states: dict[int, list[torch.Tensor]] = {} - - # capture the compile_id at compile time for error messages - _compile_id = CompileContext.current_compile_id() - _compile_id_str = str(_compile_id) if _compile_id is not None else None - # store on metadata so it's accessible during backward error handling - fw_metadata.compile_id_str = _compile_id_str + forward_epilogue.finalize = _codegen_finalize # type: ignore[method-assign] class CompiledFunction(torch.autograd.Function): compiled_fw = compiled_fw_func @@ -2554,6 +3068,9 @@ class CompiledFunction(torch.autograd.Function): num_symints_saved_for_bw = num_symints_saved_for_bw_ _aot_id = aot_config.aot_id _lazy_backward_info = lazy_backward_info + _bw_epilogue_wrap_fn = _codegen_bw_wrap_fn + _bw_prologue_unwrap_fn = _codegen_bw_unwrap_fn + boxed_grads_call = True @staticmethod def _compiled_autograd_key(ctx: Any) -> tuple[Any, ...]: @@ -2571,28 +3088,7 @@ def forward(ctx: Any, *deduped_flat_tensor_args: Any) -> Any: ) ctx._compiled_autograd_backward_state = bw_state - if num_rng: - if len(fwd_rng_states) == 0: - if graphsafe_idx is None: - raise AssertionError( - "graphsafe_idx must not be None when num_rng > 0" - ) - initialize_rng_states( - num_rng, graphsafe_idx, fwd_rng_states, bwd_rng_states - ) - - _curr_iter = next(curr_fwd_iter) - ctx._curr_iter = _curr_iter - - # if this state is not contained in the backward, - # we need to save it for when its backward pass happens - if _curr_iter != backward_state_position: - saved_backward_tensor_states[_curr_iter] = [ - rng_state.get_state() for rng_state in fwd_rng_states - ] - - pending_forwards.add(_curr_iter) - args = (*args, *fwd_rng_states) + args = rng_state.add_forward_args(ctx, args) # There is a pretty complicated calling convention around what the compiled fw returns. # The full list of outputs and their relative order is: @@ -2608,231 +3104,39 @@ def forward(ctx: Any, *deduped_flat_tensor_args: Any) -> Any: disable_amp=disable_amp, ) - num_outputs = CompiledFunction.metadata.num_outputs - num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased - num_mutated_runtime_inps = ( - CompiledFunction.metadata.num_mutated_inp_runtime_indices - ) - num_forward_returns = CompiledFunction.metadata.num_forward_returns - - # See Note [Activations with no version counter checks in eager] - # Partitioners must put symint arguments at the end separate from tensor arguments - # Split tensors into those that need VC checks (via save_for_backward) - # and those that don't (stashed directly on ctx). - # The partitioner sorts tensors so that no-VC-check tensors are at the end. - tensors_saved_with_vc_check = fw_outs[ - CompiledFunction.metadata.tensors_saved_for_backwards_with_vc_check_slice - ] - tensors_saved_no_vc_check = fw_outs[ - CompiledFunction.metadata.tensors_saved_for_backwards_no_vc_check_slice - ] - if not all( - isinstance(x, torch.Tensor) for x in tensors_saved_with_vc_check - ): - raise AssertionError( - f"expected all tensors_saved_with_vc_check to be Tensors, " - f"got types: {[type(x) for x in tensors_saved_with_vc_check]}" - ) - if not all( - isinstance(x, torch.Tensor) for x in tensors_saved_no_vc_check - ): - raise AssertionError( - f"expected all tensors_saved_no_vc_check to be Tensors, " - f"got types: {[type(x) for x in tensors_saved_no_vc_check]}" - ) - - # See Note [Detaching saved tensors in AOTAutograd] - num_vc_check = len(tensors_saved_with_vc_check) - tensors_to_save = [ - x.detach() if x._is_view() else x - for x in tensors_saved_with_vc_check - ] - tensors_no_vc = [ - x.detach() if x._is_view() else x for x in tensors_saved_no_vc_check - ] - - # dynamic_saved_tensors_idxs has indices relative to all saved tensors - # (vc_check + no_vc_check combined). Mark dynamics on the detached tensors. - for ( - idx, - dims, - ) in CompiledFunction.metadata.dynamic_saved_tensors_idxs.items(): - if idx < num_vc_check: - maybe_mark_dynamic_helper(tensors_to_save[idx], dims) - else: - maybe_mark_dynamic_helper( - tensors_no_vc[idx - num_vc_check], dims - ) - - # Only save tensors that need VC checks via save_for_backward - ctx.save_for_backward(*tensors_to_save) - ctx._tensors_no_vc_check = tensors_no_vc - - symint_outs = fw_outs[ - CompiledFunction.metadata.symints_saved_for_backwards_slice - ] - if not all( - isinstance(x, (int, float, torch.SymInt, torch.SymFloat)) - for x in symint_outs - ): - raise AssertionError( - f"expected all symint_outs to be int/float/SymInt/SymFloat, " - f"got types: {[type(x) for x in symint_outs]}" - ) - ctx.symints = symint_outs - - opaque_object_outs = fw_outs[ - CompiledFunction.metadata.opaque_objects_saved_for_backwards_slice - ] - if not all( - is_opaque_type(type(obj)) or isinstance(obj, OpaqueBase) - for obj in opaque_object_outs - ): - raise AssertionError( - f"expected all opaque_object_outs to be opaque types, " - f"got types: {[type(obj) for obj in opaque_object_outs]}" - ) - ctx.opaque_objects = opaque_object_outs - - raw_returns = fw_outs[0:num_forward_returns] - - # Wrap all autograd.Function.forward() outputs that are aliases - # so that autograd.Function doesn't treat them as tensors - if num_mutated_runtime_inps > 0: - for i, idx in enumerate( - CompiledFunction.metadata.mutated_inp_runtime_indices - ): - # We could make this faster by only looping over inputs with metadata-only mutations - # (instead of looping over inputs with either data or metadata mutations), but there shouldn't be many. - info = CompiledFunction.metadata.input_info[idx] - if info.mutates_metadata and not info.mutates_data: - raw_return_idx = i - raw_returns[raw_return_idx] = TensorAlias( - raw_returns[raw_return_idx] - ) - - if config.debug_assert: - user_mutated_inputs_raw = raw_returns[ - 0:num_mutated_runtime_inps - ] - mut_inp_infos = [ - x - for x in CompiledFunction.metadata.input_info - if x.mutates_data or x.mutates_metadata - ] - if len(user_mutated_inputs_raw) != len(mut_inp_infos): - raise AssertionError( - f"expected len(user_mutated_inputs_raw) == len(mut_inp_infos), " - f"got {len(user_mutated_inputs_raw)} != {len(mut_inp_infos)}" - ) - - if CompiledFunction.metadata.num_unsafe_view_outputs > 0: - for idx in CompiledFunction.metadata.unsafe_view_out_indices: - raw_return_idx = num_mutated_runtime_inps + idx - o = raw_returns[raw_return_idx] - raw_returns[raw_return_idx] = torch.ops.aten._unsafe_view( - o, o.shape - ) - - if num_outputs_aliased > 0: - for idx in CompiledFunction.metadata.aliased_out_indices: - raw_return_idx = num_mutated_runtime_inps + idx - raw_returns[raw_return_idx] = TensorAlias( - raw_returns[raw_return_idx] - ) - - if config.debug_assert: - intermediates_raw = raw_returns[ - num_mutated_runtime_inps + num_outputs : - ] - if any(isinstance(x, TensorAlias) for x in intermediates_raw): - raise AssertionError( - "expected no TensorAlias in intermediates_raw" - ) - - # invariant: intermediate bases always require gradients, so we don't have to - # consider marking them as non-differentiable. - raw_returns_not_including_intermediate_bases = raw_returns[ - : num_mutated_runtime_inps + num_outputs - ] - raw_returns_meta = [ - x - for x in CompiledFunction.metadata.input_info - if x.mutation_type == MutationType.MUTATED_OUT_GRAPH - ] + CompiledFunction.metadata.output_info - - fw_outs_not_requiring_grad = [ - x - for (i, x) in enumerate( - raw_returns_not_including_intermediate_bases - ) - if isinstance(x, torch.Tensor) - and not raw_returns_meta[i].requires_grad - ] - ctx.mark_non_differentiable(*fw_outs_not_requiring_grad) - ctx._materialize_non_diff_grads = False - return tuple(raw_returns) + saved_state.save_from_forward(ctx, fw_outs) + return forward_epilogue.finalize(ctx, fw_outs) @staticmethod def backward(ctx: Any, *flat_args: Any) -> tuple[Any, ...]: - # Combine tensors from both sources: - # 1. ctx.saved_tensors - tensors that went through save_for_backward (with VC check) - # 2. ctx._tensors_no_vc_check - tensors stashed directly on ctx (no VC check) + # With boxed_grads_call, grads arrive as a single mutable + # list (not *args) so backward can free them individually + # to reduce peak memory. + if CompiledFunction.boxed_grads_call: + if len(flat_args) != 1 or not isinstance(flat_args[0], list): + raise AssertionError( + "boxed_grads_call is set but backward received " + f"{len(flat_args)} args instead of a single mutable " + "list. When boxed_grads_call=True, grads must be " + "passed as a single list argument [grad0, grad1, ...] " + "to allow freeing individual grads mid-backward." + ) + grad_args = flat_args[0] + else: + # Non-boxed path: used by subclasses of CompiledFunction + # that override boxed_grads_call to False. + grad_args = list(flat_args) + del flat_args all_args = _backward_prologue_functional( - ( - list(ctx.saved_tensors) + ctx._tensors_no_vc_check - if len(ctx._tensors_no_vc_check) > 0 - else ctx.saved_tensors - ), + saved_state.load_tensors(ctx), ctx.symints, ctx.opaque_objects, CompiledFunction.metadata, CompiledFunction.maybe_subclass_metadata, - *flat_args, + grad_args, + codegen_unwrap_fn=CompiledFunction._bw_prologue_unwrap_fn, ) - - if num_rng: - nonlocal backward_state_position, bwd_rng_states - curr_backward_iter = ctx._curr_iter - retain_graph = ( - torch._C._autograd._get_current_graph_task_keep_graph() - ) - - # Save current state if we have a pending forward that needs this state - # or this state may be needed again because of retain graph - if ( - backward_state_position in pending_forwards - and backward_state_position not in saved_backward_tensor_states - and ( - backward_state_position != curr_backward_iter - or retain_graph - ) - ): - saved_backward_tensor_states[backward_state_position] = [ - rng_state.get_state() for rng_state in bwd_rng_states - ] - - # Restore saved states if needed - if curr_backward_iter in saved_backward_tensor_states: - if backward_state_position != curr_backward_iter: - for bwd_state, saved_state in zip( - bwd_rng_states, - saved_backward_tensor_states[curr_backward_iter], - ): - bwd_state.set_state(saved_state) - if not retain_graph: - del saved_backward_tensor_states[curr_backward_iter] - else: - if backward_state_position != curr_backward_iter: - raise AssertionError( - f"expected backward_state_position == curr_backward_iter, " - f"got {backward_state_position} != {curr_backward_iter}" - ) - - backward_state_position = curr_backward_iter + 1 - if not retain_graph: - pending_forwards.remove(curr_backward_iter) - all_args.extend(bwd_rng_states) + rng_state.add_backward_args(ctx, all_args) def impl_fn(double_ctx: Any = None) -> Any: out = CompiledFunction._backward_impl(ctx, all_args) @@ -2840,16 +3144,22 @@ def impl_fn(double_ctx: Any = None) -> Any: CompiledFunction.metadata, CompiledFunction.maybe_subclass_metadata, out, + codegen_wrap_fn=CompiledFunction._bw_epilogue_wrap_fn, ) + if ( + torch._C._is_key_in_tls("context") + and (config_ctx := torch._C._get_obj_in_tls("context")) is not None + ): + impl_fn = functools.partial(config_ctx.run, impl_fn) + needs_grad = torch.is_grad_enabled() and any( t.requires_grad for t in all_args if isinstance(t, torch.Tensor) ) if needs_grad: # double backward return CompiledFunction._double_backward(ctx, impl_fn, all_args) - else: - return impl_fn() + return impl_fn() @staticmethod def _double_backward( @@ -2889,118 +3199,10 @@ def _backward_impl(ctx: Any, all_args: list[Any]) -> Any: saved_tensors_use_once = ( not torch._C._autograd._get_current_graph_task_keep_graph() ) - - if CompiledFunction.compiled_bw is None: - if lazy_backward_info is None: - raise AssertionError("lazy_backward_info must not be None") - if not isinstance( - lazy_backward_info, AutogradLazyBackwardCompileInfo - ): - raise AssertionError( - f"expected AutogradLazyBackwardCompileInfo, got {type(lazy_backward_info)}" - ) - - if ( - hasattr(lazy_backward_info, "saved_context") - and lazy_backward_info.saved_context is not None - ): - if not isinstance( - lazy_backward_info.saved_context, TracingContext - ): - raise AssertionError( - f"expected TracingContext, got {type(lazy_backward_info.saved_context)}" - ) - ddp_ctx = lazy_backward_info.saved_context.ddp_optimizer_ctx - if ddp_ctx is not None: - if ddp_ctx.curr_bucket < 0: - raise AssertionError( - f"expected same # of fw and bw compiles, but found bucket {ddp_ctx.curr_bucket}" - ) - curr_fw_meta = ddp_ctx.metadata_per_bucket[ - ddp_ctx.curr_bucket - ] - # Note [DDPOptimizer and fw_metadata] - # When using the DDPOptimizer, we have a single dynamo graph (and TracingContext), - # but multiple AOTDispatcher graph. - # - # One consequence is that there will be **multiple** fw_metadata objects, one per AOT graph, - # which we stash the fw_metadata on the TracingContext. - # - # Normally what happens is that as we compile AOT graphs 1...N, we clobber the fw_metadata - # for graph i-1 when we start running AOT for graph i. - # Ordinarily this is fine, because inductor no longer needs the metadata from graph i-1. - # - # However, this is a problem for lazy compilation of the backward. During backward compilation, - # we compile the backward lazily at backward runtime, meaning that we will first compile - # backward graph N, N-1, ..., 1. - # We need to ensure that at the time inductor compiles bw graph N-1, it can access - # the corresponding fw_metadta for graph N-1. - # - # We do this by stashing a DDPOptimizerContext, which tracks: - # - the metadata of all N graphs - # - the graph we are currently compiling in our DDPOptimizer region. - ddp_ctx.curr_bucket -= 1 - lazy_backward_info.saved_context.fw_metadata = curr_fw_meta - - if not saved_tensors_use_once: - fw_metadata.bw_donated_idxs = [] - # Update bw_donated_idxs if using lazy_backward_info from `aot_dispatch_autograd` - if ( - hasattr(lazy_backward_info, "saved_context") - and hasattr(lazy_backward_info.saved_context, "fw_metadata") - and hasattr( - lazy_backward_info.saved_context.fw_metadata, # type: ignore[union-attr] - "bw_donated_idxs", - ) - ): - lazy_backward_info.saved_context.fw_metadata.bw_donated_idxs = ( # type: ignore[union-attr] - # pyrefly: ignore [implicit-any] - [] - ) - - bw_module = lazy_backward_info.bw_module - placeholder_list = lazy_backward_info.placeholder_list - saved_context = lazy_backward_info.saved_context - saved_compile_context = lazy_backward_info.saved_compile_context - - context = torch._C._DisableAutocast if disable_amp else nullcontext - metrics_context = get_metrics_context() - with ( - tracing(saved_context), - compile_context(saved_compile_context), - context(), - track_graph_compiling(aot_config, "backward"), - metrics_context, - dynamo_timed( - "backward._backward_impl", - phase_name="entire_backward_compile", - log_pt2_compile_event=True, - dynamo_compile_column_us="backward_cumulative_compile_time_us", - log_waitcounter=True, - waitcounter_name_override="entire_backward_compile", - ), - callback_handler.install_callbacks( - CallbackTrigger.LAZY_BACKWARD, - str(CompileContext.current_compile_id()), - ), - ): - CompileEventLogger.compilation_metric(is_forward=False) - # See Note: [Backward graph lazy lowering] - if aot_config.bw_compiler is None: - raise AssertionError( - "aot_config.bw_compiler must not be None" - ) - CompiledFunction.compiled_bw = aot_config.bw_compiler( - copy.deepcopy(bw_module), placeholder_list - ) - # Maybe save cache entry - if try_save_cache_entry is not None: - try_save_cache_entry( - CompiledFunction.compiled_bw, - bw_module, - fw_metadata, - aot_config, - ) + compiled_bw = backward_compiler.get_or_compile( + saved_tensors_use_once=saved_tensors_use_once + ) + CompiledFunction.compiled_bw = compiled_bw if ( torch._functorch.config.donated_buffer @@ -3019,26 +3221,212 @@ def _backward_impl(ctx: Any, all_args: list[Any]) -> Any: ), ) - out = call_func_at_runtime_with_args( - CompiledFunction.compiled_bw, + return call_func_at_runtime_with_args( + compiled_bw, all_args, steal_args=True, disable_amp=disable_amp, ) - return out - compiled_function = RuntimeWrapper( - indices_of_inps_to_detach=indices_of_inps_to_detach, + return CompiledFunction + + +# This is wrapped in a class just for namespacing purposes +# No need to make it into an actual CompilerWrapper because it doesn't fit the abstract as cleanly +class AOTDispatchAutograd: + @staticmethod + def _raise_tangent_metadata_error( + expected_type: type | None, + expected_meta: Any, + runtime_type: type, + runtime_meta: Any, + orig_x: torch.Tensor, + tangent_idx: int | None, + tangent_desc: Any | None, + compile_id_str: str | None, + tangent_stack_trace: str | None, + ) -> RuntimeError: + expected_subclass_got_plain_tensor = ( + expected_type is not None + and expected_type is not torch.Tensor + and runtime_type is torch.Tensor + ) + if expected_subclass_got_plain_tensor: + tangent_msg = "" + if tangent_idx is not None: + tangent_msg = f" (tangent index: {tangent_idx})" + + output_hint = "" + if tangent_desc is not None: + from .descriptors import PlainAOTOutput, TangentAOTInput + + if isinstance(tangent_desc, TangentAOTInput) and isinstance( + tangent_desc.output, PlainAOTOutput + ): + idx = tangent_desc.output.idx + output_hint = f"\n\nThe problematic output is: forward output at index {idx} (0-indexed)" + else: + output_hint = ( + f"\n\nThe problematic output is: {tangent_desc.expr()}" + ) + + graph_hint = "" + if compile_id_str is not None: + graph_hint = ( + f"\n\nThis error occurred in compiled graph [{compile_id_str}]." + ) + + stack_trace_hint = "" + if tangent_stack_trace is not None: + stack_trace_hint = ( + f"\n\nThe forward output was created here:\n{tangent_stack_trace}" + ) + + return RuntimeError( + f""" +During the backward, we encountered a tensor subclass where we guessed its +metadata incorrectly. +Expected a {expected_type.__name__} tangent but got a plain Tensor{tangent_msg}. +This happens when a compiled function returns multiple outputs that +require gradients, but .backward() is only called on some of them. +To fix: call .detach() on forward outputs you don't need gradients for.{output_hint}{graph_hint}{stack_trace_hint} + +This error is also more likely to occur if your compiled model is suffering +from a large number of graph breaks. For more advice on finding and fixing +graph breaks, see: +https://docs.pytorch.org/docs/stable/user_guide/torch_compiler/compile/programming_model.graph_breaks_index.html + +For more info about this error, see: +https://github.com/pytorch/pytorch/issues/172556""" + ) + else: + return RuntimeError( + f""" +During the backward, we encountered a tensor subclass where we guessed its +metadata incorrectly. +Expected: {expected_meta} (type {expected_type}), +got: {runtime_meta} (type {runtime_type}), shape: {orig_x.shape}. +Your tensor subclass must implement __coerce_same_metadata_as_tangent__.""" + ) + + @staticmethod + def process_runtime_tangent( + x: Any, + meta: PlainTensorMeta | SubclassCreationMeta, + tangent_idx: int | None = None, + tangent_desc: Any | None = None, + compile_id_str: str | None = None, + tangent_stack_trace: str | None = None, + ) -> tuple[Any, list[Any]]: + if not isinstance(x, torch.Tensor): + return x, [x] + + if isinstance(x, FakeTensor): + if not meta.memory_format: + raise AssertionError( + "meta.memory_format must not be None for FakeTensor" + ) + x = coerce_to_expected_memory_format(x, meta.memory_format) + return x, [x] + + expected_type: type | None = torch.Tensor + expected_meta = None + if isinstance(meta, SubclassCreationMeta): + expected_type = meta.original_subclass_type + expected_meta = meta.meta + + runtime_type = type(x) + # When we're inside compiled autograd's AOTDispatcher step, + # regular Tensors look like FunctionalTensors. + # Tensor subclasses still look like Tensor subclasses though. + if isinstance(x, torch._subclasses.functional_tensor.FunctionalTensor): + runtime_type = torch.Tensor + + runtime_meta = None + runtime_subclass_keys: Sequence[str] = [] + + if is_traceable_wrapper_subclass(x): + runtime_subclass_keys, runtime_meta = x.__tensor_flatten__() + + def maybe_coerce(x: torch.Tensor) -> torch.Tensor | None: + same_type: bool = expected_type == runtime_type + same_meta: bool = expected_meta == runtime_meta + + if same_type and same_meta: + return x + + if not hasattr(x, "__coerce_same_metadata_as_tangent__"): + return None + + if same_type: + # Backward Compatibility, as some Subclass impls can have original 1-arg function. + return x.__coerce_same_metadata_as_tangent__(expected_meta) + + return x.__coerce_same_metadata_as_tangent__(expected_meta, expected_type) + + # Coerce to expected type and metadata + orig_x = x + x = maybe_coerce(x) + if x is None: + raise AOTDispatchAutograd._raise_tangent_metadata_error( + expected_type, + expected_meta, + runtime_type, + runtime_meta, + orig_x, + tangent_idx, + tangent_desc, + compile_id_str, + tangent_stack_trace, + ) + + # Coerce to expected memory format + if not meta.memory_format: + raise AssertionError("meta.memory_format must not be None") + x = coerce_to_expected_memory_format(x, meta.memory_format) + + if not is_traceable_wrapper_subclass(x): + return x, [x] + + if not isinstance(meta, SubclassCreationMeta): + raise AssertionError(f"expected SubclassCreationMeta, got {type(meta)}") + if orig_x is not x: + runtime_subclass_keys = x.__tensor_flatten__()[0] + + if len(meta.attrs) != len(runtime_subclass_keys): + raise AssertionError( + f"expected len(meta.attrs) == len(runtime_subclass_keys), " + f"got {len(meta.attrs)} != {len(runtime_subclass_keys)}" + ) + leaves = [] + for attr, attr_meta in meta.attrs.items(): + if isinstance(attr_meta, OpaqueMeta): + # Opaques aren't differentiable but occupy a flat arg slot. + leaves.append(getattr(x, attr)) + continue + elem = getattr(x, attr) + new_elem, elem_leaves = AOTDispatchAutograd.process_runtime_tangent( + elem, attr_meta + ) + if new_elem is not elem: + setattr(x, attr, new_elem) + leaves.extend(elem_leaves) + + return x, leaves + + @staticmethod + def post_compile(spec: AOTDispatchAutogradCompileSpec) -> Callable[..., Any]: + compiled_function_cls = _AOTDispatchAutogradFunctionFactory(spec).build() + return RuntimeWrapper( + indices_of_inps_to_detach=spec.indices_of_inps_to_detach, trace_joint=True, - disable_amp=disable_amp, + disable_amp=spec.disable_amp, ).post_compile( - CompiledFunction.apply, - aot_config, - runtime_metadata=fw_metadata, + compiled_function_cls.apply, + spec.aot_config, + runtime_metadata=spec.fw_metadata, ) - return compiled_function - @dataclass class DebugAssertWrapper(CompilerWrapper): @@ -3051,36 +3439,38 @@ def post_compile( *, runtime_metadata: ViewAndMutationMeta, ) -> Callable[..., Any]: - @wraps(compiled_fn) - def debug_compiled_function(args: list[Any]) -> Any: - # TODO: T253242027 Check aliasing relationships - # TODO: Check strides for metadata mutation - # (NB: ideally, this logic is factored out of this function and - # you move these debug checks there) - - # Check requires grad. Bad case is when we compiled with - # requires_grad = False, but input requires_grad = True - # (vice versa is OK; we compute a gradient and then throw - # it away when it hits the input.) - for i, a in enumerate(args): - can_require_grad = self.flat_requires_grad[i] - if can_require_grad is None: - if isinstance(a, Tensor): - raise AssertionError( - f"expected non-Tensor for arg {i}, got Tensor" - ) - elif not can_require_grad: - if a.requires_grad: - raise AssertionError( - format_guard_bug_msg( - aot_config, - f"{describe_input(i, aot_config)} would not require grad", - ) - ) + lines = ["def inner_fn(args):"] + globals_dict: dict[str, object] = {"compiled_fn": compiled_fn} + for i, can_require_grad in enumerate(self.flat_requires_grad): + if can_require_grad is None: + lines.append( + f" if isinstance(args[{i}], Tensor):" + f" raise AssertionError(" + f"'expected non-Tensor for arg {i}, got Tensor')" + ) + elif not can_require_grad: + msg_name = f"_msg_{i}" + globals_dict[msg_name] = format_guard_bug_msg( + aot_config, + f"{describe_input(i, aot_config)} would not require grad", + ) + lines.append( + f" if args[{i}].requires_grad: raise AssertionError({msg_name})" + ) + lines.append(" return compiled_fn(args)") + + source = "\n".join(lines) + globals_dict["Tensor"] = Tensor - return compiled_fn(args) + from .subclass_codegen import _compile_and_exec_source - return debug_compiled_function + return _compile_and_exec_source( + source, + globals_dict, + "inner_fn", + "debug_assert_wrapper", + wrapped_fn=compiled_fn, + ) def pre_compile( diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index ea69810b92b78..d1d1e76008d5f 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -186,14 +186,18 @@ class MemoryFormatMeta: memory_format: torch.memory_format | None = None @staticmethod - def from_tensor(t: torch.Tensor) -> MemoryFormatMeta | None: + def from_tensor( + t: torch.Tensor, force_use_memory_format: bool = False + ) -> MemoryFormatMeta | None: # We only memorize expected memory format for # 1. Traceable wrapper subclasses # We can not create restrided subclass tensor, as torch.empty_strided works only with dense tensors. # 2. Dynamic shape tensors # Support for symbolic shapes is not implemented yet. + # 3. force_use_memory_format=True (e.g., local_map where shapes change) use_memory_format: bool = ( - not torch._functorch.config.guess_tangent_strides_as_outputs + force_use_memory_format + or not torch._functorch.config.guess_tangent_strides_as_outputs or is_traceable_wrapper_subclass(t) ) if not use_memory_format: @@ -378,7 +382,7 @@ def _make_size_runtime_safe(x: None | int | torch.SymInt) -> int | None: # `_make_size_runtime_safe` replaces any nested int with a dummy value (-1) # to prevent serializing a SymInt at runtime. Internally, nested tensor __tensor_unflatten__ # is designed to safely ignore this dummy value. - # For more details, see: https://github.com/pytorch/pytorch/blob/5141ade8e30c64e873e14dcc8de233da45d15025/torch/nested/_internal/nested_tensor.py#L266-L299 # noqa: B950 + # For more details, see: https://github.com/pytorch/pytorch/blob/5141ade8e30c64e873e14dcc8de233da45d15025/torch/nested/_internal/nested_tensor.py#L266-L299 self.outer_size = tuple(map(_make_size_runtime_safe, self.outer_size)) self.outer_stride = tuple(map(_make_size_runtime_safe, self.outer_stride)) @@ -455,9 +459,6 @@ class ViewAndMutationMeta: subclass_fw_graph_out_meta: list[PlainTensorMeta | SubclassCreationMeta] # length = # backward graph inputs subclass_tangent_meta: list[PlainTensorMeta | SubclassCreationMeta] - # TODO: we should kill this - # (need to default it to not break internal) - is_train: bool = False # length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors) # + (# intermediate bases) @@ -493,6 +494,11 @@ class ViewAndMutationMeta: # Keeps track of which input indices store parameters (which we will treat as static) static_input_indices: list[int] = field(default_factory=list) + # Input indices that held AsyncCollectiveTensors at compile time. + # Used to emit direct trigger_wait() calls at runtime instead of + # scanning every arg on every graph invocation. + act_input_indices: list[int] = field(default_factory=list) + # Map of effect type (ex. _EffectType.ORDERED) to token. If there are # side-effectful operators, FunctionalTensorMode will populate this # dictionary telling us how many tokens we will need during tracing. diff --git a/torch/_functorch/_aot_autograd/streams.py b/torch/_functorch/_aot_autograd/streams.py index ed2b62d57b26d..2f8dd0e8b896e 100644 --- a/torch/_functorch/_aot_autograd/streams.py +++ b/torch/_functorch/_aot_autograd/streams.py @@ -6,6 +6,7 @@ import torch.utils._pytree as pytree from torch._dynamo.graph_utils import _get_flat_args from torch._dynamo.variables.streams import get_current_stream, new_event +from torch.fx.node import map_arg from torch.utils._runtime_estimation import ( _FLOAT_TYPES, _IGNORE_OPS, @@ -15,7 +16,7 @@ if TYPE_CHECKING: - from .schemas import ViewAndMutationMeta # noqa: TC004 + from .schemas import ViewAndMutationMeta from .indexed_dict import IndexedDict @@ -28,6 +29,9 @@ _SYNC_OPS = ( torch.ops.streams.record_event.default, torch.ops.streams.wait_event.default, + torch.ops.streams.synchronize_event.default, + torch.ops.streams.synchronize_device.default, + torch.ops.streams.synchronize_stream.default, ) @@ -363,11 +367,12 @@ def _wrap_sync_node( sync_node: Node, deps_before_sync: list[Node], visited: set[Node], -) -> Node: +) -> tuple[Node, list[Node]]: """ Core logic: wrap a single sync node in control_deps. - Returns the control_deps node that replaced the sync node. + Returns (control_deps_node, passthrough_getitems) where passthrough_getitems + are the getitem nodes that thread dependencies through the control_deps node. ``visited`` is the set of nodes at or before the sync node in graph order, used to distinguish pre-sync vs post-sync users. """ @@ -425,22 +430,29 @@ def _wrap_sync_node( replacements[dep] = getitem_node visited.add(getitem_node) - # Replace uses of dependencies that come after sync_node + # Replace uses of dependencies that come after sync_node. + # Use map_arg to handle nested structures (e.g. output node's list args). for dep, getitem_node in replacements.items(): for user in list(dep.users.keys()): if user is control_deps_node: continue if user in visited: continue - user.args = tuple(getitem_node if arg is dep else arg for arg in user.args) - user.kwargs = { - k: getitem_node if v is dep else v for k, v in user.kwargs.items() - } + # Don't replace forward outputs in the output node — they belong + # to the forward partition and must not reference backward nodes. + if user.op == "output" and not is_bwd_node(dep): + continue + + def _replace(n: Node) -> Node: + return getitem_node if n is dep else n + + user.args = map_arg(user.args, _replace) + user.kwargs = map_arg(user.kwargs, _replace) # Remove original sync node sync_node.replace_all_uses_with(control_deps_node) graph.erase_node(sync_node) - return control_deps_node + return control_deps_node, list(replacements.values()) def wrap_all_sync_nodes_with_control_deps(gm: torch.fx.GraphModule) -> None: @@ -457,8 +469,14 @@ def wrap_all_sync_nodes_with_control_deps(gm: torch.fx.GraphModule) -> None: raise RuntimeError("Expected a non-empty graph") stream_to_nodes: dict[int | None, list[Node]] = {} # Maps event_index -> control_deps node that wrapped its record_event, - # so the corresponding wait_event can depend on the record. + # so the corresponding wait_event/synchronize_event can depend on the record. event_to_ctrl: dict[int, Node] = {} + # Maps event_index -> getitem nodes threaded through record_event's control_deps, + # so synchronize_event can thread them through to subsequent ops. + event_to_passthrough: dict[int, list[Node]] = {} + # Maps event_index -> stream that the event was recorded on, + # so synchronize_event can infer its stream. + event_to_stream: dict[int, int | None] = {} visited: set[Node] = set() found_sync = False @@ -471,15 +489,59 @@ def wrap_all_sync_nodes_with_control_deps(gm: torch.fx.GraphModule) -> None: if node.op == "call_function": if node.target in _SYNC_OPS: + # synchronize_device and synchronize_stream block the CPU, + # so all subsequent kernel launches are host-ordered after + # them. Treat both as full barriers across all streams. + if node.target in ( + torch.ops.streams.synchronize_device.default, + torch.ops.streams.synchronize_stream.default, + ): + all_stream_deps: list[Node] = [ + n for nodes in stream_to_nodes.values() for n in nodes + ] + if all_stream_deps: + found_sync = True + _wrap_sync_node(gm, node, all_stream_deps, visited) + stream_to_nodes.clear() + node = next_node + continue + event_index: int = node.args[0] # type: ignore[assignment] - sync_stream: int | None = node.args[1] # type: ignore[assignment] - deps_before_sync = stream_to_nodes.get(sync_stream, []) - # For wait_events, add a cross-event dependency on the - # matching record_event's control_deps node so the wait - # cannot be reordered before the record. + # synchronize_event blocks the CPU thread, so it acts + # as a barrier across all streams. Collect deps from every + # stream and reset them all afterward. If the event was + # recorded externally, thread the graph inputs through so + # that any post-sync uses depend on the synchronize. + if node.target is torch.ops.streams.synchronize_event.default: + sync_stream: int | None = event_to_stream.get(event_index) + all_stream_deps: list[Node] = [ + n for nodes in stream_to_nodes.values() for n in nodes + ] + if event_index not in event_to_stream: + placeholders = [n for n in graph.nodes if n.op == "placeholder"] + deps_before_sync = [*placeholders, *all_stream_deps] + else: + deps_before_sync = all_stream_deps + else: + sync_stream = node.args[1] # type: ignore[assignment] + deps_before_sync = list(stream_to_nodes.get(sync_stream, ())) + # Nodes without explicit stream annotation (custom.stream=None) + # run on the current/default stream. Include them when the sync + # op references a stream, since the unannotated nodes are + # implicitly on that stream. + if None in stream_to_nodes and sync_stream is not None: + deps_before_sync.extend(stream_to_nodes[None]) + + # For wait_event and synchronize_event, add a cross-event + # dependency on the matching record_event's control_deps node + # so they cannot be reordered before the record. if ( - node.target is torch.ops.streams.wait_event.default + node.target + in ( + torch.ops.streams.wait_event.default, + torch.ops.streams.synchronize_event.default, + ) and event_index in event_to_ctrl ): deps_before_sync = [ @@ -487,22 +549,43 @@ def wrap_all_sync_nodes_with_control_deps(gm: torch.fx.GraphModule) -> None: *deps_before_sync, ] + # For synchronize_event, also include the getitem nodes + # threaded through record_event's control_deps. This ensures + # subsequent ops that depend on recorded values get rewired + # through synchronize_event. + if ( + node.target is torch.ops.streams.synchronize_event.default + and event_index in event_to_passthrough + ): + deps_before_sync = [ + *deps_before_sync, + *event_to_passthrough[event_index], + ] + if deps_before_sync: found_sync = True - ctrl_node = _wrap_sync_node(gm, node, deps_before_sync, visited) + ctrl_node, passthrough = _wrap_sync_node( + gm, node, deps_before_sync, visited + ) else: ctrl_node = None + passthrough: list[torch.fx.Node] = [] - if ( - node.target is torch.ops.streams.record_event.default - and ctrl_node is not None - ): - event_to_ctrl[event_index] = ctrl_node + if node.target is torch.ops.streams.record_event.default: + event_to_stream[event_index] = sync_stream + if ctrl_node is not None: + event_to_ctrl[event_index] = ctrl_node + event_to_passthrough[event_index] = passthrough # Reset: ops between this sync and the next will accumulate # fresh. Ordering with prior ops is already enforced because # their uses were rewired through getitems from control_deps. - stream_to_nodes[sync_stream] = [] + if node.target is torch.ops.streams.synchronize_event.default: + stream_to_nodes.clear() + else: + stream_to_nodes[sync_stream] = [] + if None in stream_to_nodes: + stream_to_nodes[None] = [] elif "val" in node.meta: stream = get_stream(node) stream_to_nodes.setdefault(stream, []).append(node) diff --git a/torch/_functorch/_aot_autograd/subclass_codegen.py b/torch/_functorch/_aot_autograd/subclass_codegen.py index c416e76171301..38b19e2691d72 100644 --- a/torch/_functorch/_aot_autograd/subclass_codegen.py +++ b/torch/_functorch/_aot_autograd/subclass_codegen.py @@ -70,6 +70,7 @@ def _codegen_unwrap_subclass( meta: SubclassCreationMeta, var: str, indent: int = 1, + include_symints: bool = True, ) -> None: """Emit code to recursively unwrap a single subclass input.""" for attr, attr_meta in meta.attrs.items(): @@ -84,26 +85,35 @@ def _codegen_unwrap_subclass( state.emit( f"{inner_var} = {_safe_attr_access(var, attr)}", indent=indent ) - _codegen_unwrap_subclass(state, attr_meta, inner_var, indent=indent) + _codegen_unwrap_subclass( + state, + attr_meta, + inner_var, + indent=indent, + include_symints=include_symints, + ) # Emit symint extraction - size_placeholders = _compute_placeholders(meta.outer_size) - stride_placeholders = _compute_placeholders(meta.outer_stride) - has_size_symints = any(size_placeholders) - has_stride_symints = any(stride_placeholders) - - if has_size_symints or has_stride_symints: - size_var = state.fresh_name("_size") - state.emit(f"{size_var} = {var}.size()", indent=indent) - for i, is_sym in enumerate(size_placeholders): - if is_sym: - state.emit(f"unwrapped_args.append({size_var}[{i}])", indent=indent) - - stride_var = state.fresh_name("_stride") - state.emit(f"{stride_var} = {var}.stride()", indent=indent) - for i, is_sym in enumerate(stride_placeholders): - if is_sym: - state.emit(f"unwrapped_args.append({stride_var}[{i}])", indent=indent) + if include_symints: + size_placeholders = _compute_placeholders(meta.outer_size) + stride_placeholders = _compute_placeholders(meta.outer_stride) + has_size_symints = any(size_placeholders) + has_stride_symints = any(stride_placeholders) + + if has_size_symints or has_stride_symints: + size_var = state.fresh_name("_size") + state.emit(f"{size_var} = {var}.size()", indent=indent) + for i, is_sym in enumerate(size_placeholders): + if is_sym: + state.emit(f"unwrapped_args.append({size_var}[{i}])", indent=indent) + + stride_var = state.fresh_name("_stride") + state.emit(f"{stride_var} = {var}.stride()", indent=indent) + for i, is_sym in enumerate(stride_placeholders): + if is_sym: + state.emit( + f"unwrapped_args.append({stride_var}[{i}])", indent=indent + ) def _concrete_value(val: None | int | SymInt) -> int: @@ -178,23 +188,42 @@ def _build_tuple( return result_var -def _codegen_subclass_wrapper_source( - inp_metas: list[PlainTensorMeta | SubclassCreationMeta], +def _emit_output_wrapping( + state: _CodegenState, out_metas: list[PlainTensorMeta | SubclassCreationMeta], - num_fw_outs_saved_for_bw: int | None, - frozen_inp_indices: frozenset[int] = frozenset(), -) -> tuple[str, dict[str, object]]: - """Generate source and globals for a subclass wrapper. +) -> tuple[list[str], int]: + """Emit wrapping code for output metas. - Returns (source, globals_dict). The globals_dict will NOT contain - ``compiled_fn`` — the caller is responsible for adding it before exec. + Returns (result_exprs, num_args_tallied) where result_exprs are Python + expression strings referencing each wrapped output. """ - state = _CodegenState() + out_idx_ref = [0] + result_exprs: list[str] = [] + num_args_tallied = 0 - state.emit("def inner_fn(args):", indent=0) + for meta in out_metas: + if isinstance(meta, PlainTensorMeta): + result_exprs.append(f"unwrapped_outs[{meta.unwrapped_idx}]") + num_args_tallied += 1 + out_idx_ref[0] = max(out_idx_ref[0], meta.unwrapped_idx + 1) + else: + result_var = _codegen_wrap_subclass(state, meta, out_idx_ref) + result_exprs.append(result_var) + num_args_tallied += meta.arg_count - # --- Input unwrapping --- - state.emit("unwrapped_args = []") + return result_exprs, num_args_tallied + + +def _emit_input_unwrapping( + state: _CodegenState, + inp_metas: list[PlainTensorMeta | SubclassCreationMeta], + frozen_inp_indices: frozenset[int] = frozenset(), + include_symints: bool = True, +) -> None: + """Emit unwrapping code for input metas into unwrapped_args. + + Caller must have already emitted ``unwrapped_args = []``. + """ for i, meta in enumerate(inp_metas): if isinstance(meta, PlainTensorMeta): state.emit(f"unwrapped_args.append(args[{i}])") @@ -212,7 +241,38 @@ def _codegen_subclass_wrapper_source( f"assert type({inp_var}) is {type_name}, " f"f'expected {{{type_name}}}, got {{type({inp_var})}}'", ) - _codegen_unwrap_subclass(state, meta, inp_var, indent=1) + _codegen_unwrap_subclass( + state, meta, inp_var, indent=1, include_symints=include_symints + ) + + +def _codegen_subclass_wrapper_source( + inp_metas: list[PlainTensorMeta | SubclassCreationMeta], + out_metas: list[PlainTensorMeta | SubclassCreationMeta], + num_fw_outs_saved_for_bw: int | None, + frozen_inp_indices: frozenset[int] = frozenset(), + act_input_indices: list[int] | None = None, +) -> tuple[str, dict[str, object]]: + """Generate source and globals for a subclass wrapper. + + Returns (source, globals_dict). The globals_dict will NOT contain + ``compiled_fn`` — the caller is responsible for adding it before exec. + """ + state = _CodegenState() + + state.emit("def inner_fn(args):", indent=0) + + # --- Resolve AsyncCollectiveTensors --- + # ACTs are transient eager-mode wrappers for async collective overlap. + # Inductor triton kernels bypass __torch_dispatch__, so we must call + # trigger_wait() before the compiled graph uses the data. + if act_input_indices: + for i in act_input_indices: + state.emit(f"args[{i}] = args[{i}].trigger_wait()") + + # --- Input unwrapping --- + state.emit("unwrapped_args = []") + _emit_input_unwrapping(state, inp_metas, frozen_inp_indices=frozen_inp_indices) # Pass through any trailing args not covered by inp_metas # (e.g. rng seed/offset added by FunctionalizedRngRuntimeWrapper). @@ -224,38 +284,104 @@ def _codegen_subclass_wrapper_source( state.emit("unwrapped_outs = compiled_fn(unwrapped_args)") # --- Output wrapping --- - state.emit("wrapped_outs = []") - out_idx_ref = [0] - num_args_tallied = 0 - - for meta in out_metas: - if isinstance(meta, PlainTensorMeta): - state.emit(f"wrapped_outs.append(unwrapped_outs[{meta.unwrapped_idx}])") - num_args_tallied += 1 - out_idx_ref[0] = max(out_idx_ref[0], meta.unwrapped_idx + 1) - else: - result_var = _codegen_wrap_subclass(state, meta, out_idx_ref) - state.emit(f"wrapped_outs.append({result_var})") - num_args_tallied += meta.arg_count - - # Append activations saved for backward + result_exprs, num_args_tallied = _emit_output_wrapping(state, out_metas) + result_tuple = f"({', '.join(result_exprs)},)" if result_exprs else "()" if num_fw_outs_saved_for_bw is not None: state.emit( - f"return tuple(wrapped_outs) + tuple(unwrapped_outs[{num_args_tallied}:])" + f"return {result_tuple} + tuple(unwrapped_outs[{num_args_tallied}:])" ) else: - state.emit("return tuple(wrapped_outs)") + state.emit(f"return {result_tuple}") source = "\n".join(state.lines) return source, state.globals +def _codegen_subclass_wrap_source( + out_metas: list[PlainTensorMeta | SubclassCreationMeta], +) -> tuple[str, dict[str, object]]: + """Generate source for wrapping flat outputs into subclasses. + + Used for the backward epilogue. Shares output-wrapping logic with + _codegen_subclass_wrapper_source via _emit_output_wrapping. + """ + state = _CodegenState() + state.emit("def wrap_fn(unwrapped_outs):", indent=0) + result_exprs, _ = _emit_output_wrapping(state, out_metas) + result_tuple = f"({', '.join(result_exprs)},)" if result_exprs else "()" + state.emit(f"return {result_tuple}") + source = "\n".join(state.lines) + return source, state.globals + + +def _compile_and_exec_source( + source: str, + globals_dict: dict[str, object], + fn_name: str, + artifact_name: str, + wrapped_fn: Callable[..., object] | None = None, +) -> Callable[..., object]: + """Compile generated source, exec it, and return the named function. + + If wrapped_fn is provided, applies functools.update_wrapper so that + __wrapped__ and __dict__ (e.g. _fx_graph_cache_key) propagate to the + generated function. + """ + if log.isEnabledFor(logging.DEBUG): + log.debug("Generated %s:\n%s", artifact_name, source) + + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": artifact_name, + "encoding": "string", + }, + payload_fn=lambda: source, + ) + + code = compile(source, f"<{artifact_name}>", "exec") + local_dict: dict[str, object] = {} + exec(code, globals_dict, local_dict) + fn = local_dict[fn_name] + if wrapped_fn is not None: + functools.update_wrapper(fn, wrapped_fn) # type: ignore[arg-type] + return fn # type: ignore[return-value] + + +def codegen_backward_subclass_fns( + grad_input_metas: list[PlainTensorMeta | SubclassCreationMeta] | None = None, +) -> tuple[Callable[..., object], Callable[..., object] | None]: + """Generate codegen'd unwrap and wrap functions for the backward pass. + + Returns (unwrap_fn, wrap_fn). unwrap_fn is used by the backward prologue + to unwrap non-tangent subclass inputs (always an identity in AOT dispatch + since the compiled forward operates on unwrapped inner tensors). wrap_fn + is used by the backward epilogue to wrap flat grad inputs back into + subclasses; it is None when grad_input_metas is None. + """ + source = "def unwrap_fn(args):\n return list(args)" + globals_dict: dict[str, object] = {} + unwrap_fn = _compile_and_exec_source( + source, globals_dict, "unwrap_fn", "backward_subclass_unwrap" + ) + + wrap_fn = None + if grad_input_metas is not None: + wrap_source, wrap_globals = _codegen_subclass_wrap_source(grad_input_metas) + wrap_fn = _compile_and_exec_source( + wrap_source, wrap_globals, "wrap_fn", "backward_subclass_wrapper" + ) + + return unwrap_fn, wrap_fn + + def codegen_subclass_wrapper( compiled_fn: Callable[..., object], inp_metas: list[PlainTensorMeta | SubclassCreationMeta], out_metas: list[PlainTensorMeta | SubclassCreationMeta], num_fw_outs_saved_for_bw: int | None, frozen_inp_indices: frozenset[int] = frozenset(), + act_input_indices: list[int] | None = None, ) -> Callable[..., object]: """Generate a specialized wrapper function for subclass unwrap/wrap.""" source, globals_dict = _codegen_subclass_wrapper_source( @@ -263,27 +389,9 @@ def codegen_subclass_wrapper( out_metas, num_fw_outs_saved_for_bw, frozen_inp_indices, + act_input_indices=act_input_indices, ) globals_dict["compiled_fn"] = compiled_fn - - if log.isEnabledFor(logging.DEBUG): - log.debug("Generated subclass wrapper:\n%s", source) - - torch._logging.trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "subclass_wrapper", - "encoding": "string", - }, - payload_fn=lambda: source, + return _compile_and_exec_source( + source, globals_dict, "inner_fn", "subclass_wrapper", wrapped_fn=compiled_fn ) - - code = compile(source, "", "exec") - local_dict: dict[str, object] = {} - exec(code, globals_dict, local_dict) # noqa: S102 - inner_fn = local_dict["inner_fn"] - # Replicate @wraps(compiled_fn): sets __wrapped__ (so autograd_cache can - # unwrap to the underlying OutputCode) and copies __dict__ (so attributes - # like _fx_graph_cache_key are visible on the wrapper). - functools.update_wrapper(inner_fn, compiled_fn) # type: ignore[arg-type] - return inner_fn # type: ignore[return-value] diff --git a/torch/_functorch/_aot_autograd/subclass_parametrization.py b/torch/_functorch/_aot_autograd/subclass_parametrization.py index a820adad1a65a..0911526b91242 100644 --- a/torch/_functorch/_aot_autograd/subclass_parametrization.py +++ b/torch/_functorch/_aot_autograd/subclass_parametrization.py @@ -124,7 +124,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul """ for name, tensor in itertools.chain( list(module.named_parameters(recurse=False)), - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] list(module.named_buffers(recurse=False)), ): if is_traceable_wrapper_subclass(tensor): diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py index b9548107eb6fe..8e5b0da115f12 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -5,9 +5,8 @@ """ import collections -import typing from collections.abc import Callable, Iterable, Sequence -from typing import Any, TYPE_CHECKING, TypeGuard, TypeVar +from typing import Any, TypeGuard, TypeVar import torch import torch.utils._pytree as pytree @@ -17,7 +16,10 @@ from torch._opaque_base import OpaqueBase from torch._subclasses.fake_tensor import get_plain_tensors from torch.types import IntLikeType -from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + TraceableWrapperSubclass, +) from .descriptors import ( AOTInput, @@ -42,10 +44,6 @@ from .utils import strict_zip -if TYPE_CHECKING: - from torch._library.opaque_object import OpaqueType - - zip = strict_zip T = TypeVar("T", bound=torch.Tensor) @@ -257,7 +255,7 @@ def compute_symint_placeholders(lst: Iterable[None | int | SymInt]) -> list[bool # this function below. def unwrap_tensor_subclasses( wrapped_args: list[FxValue], - wrapped_args_descs: list[AOTDescriptor], + wrapped_args_descs: Sequence[AOTDescriptor], *, append_symints: bool, ) -> tuple[list[FxValue], list[AOTDescriptor]]: @@ -289,14 +287,21 @@ def flatten_subclass( attrs, _ = t.__tensor_flatten__() + SubclassGetAttr: Callable[[AOTInput | AOTOutput, str], AOTDescriptor] + SubclassSize: Callable[[AOTInput | AOTOutput, int], AOTDescriptor] + SubclassStride: Callable[[AOTInput | AOTOutput, int], AOTDescriptor] + if isinstance(desc, AOTInput): + SubclassGetAttr = SubclassGetAttrAOTInput # type: ignore[bad-assignment] + SubclassSize = SubclassSizeAOTInput # type: ignore[bad-assignment] + SubclassStride = SubclassStrideAOTInput # type: ignore[bad-assignment] + else: + SubclassGetAttr = SubclassGetAttrAOTOutput # type: ignore[bad-assignment] + SubclassSize = SubclassSizeAOTOutput # type: ignore[bad-assignment] + SubclassStride = SubclassStrideAOTOutput # type: ignore[bad-assignment] + for attr in attrs: inner_value = getattr(t, attr) - n_desc: Any = ( - SubclassGetAttrAOTInput(desc, attr) - if isinstance(desc, AOTInput) - # pyrefly: ignore [bad-argument-type] - else SubclassGetAttrAOTOutput(desc, attr) - ) + n_desc: Any = SubclassGetAttr(desc, attr) flatten_subclass(inner_value, n_desc, out=out) if append_symints: @@ -304,19 +309,14 @@ def flatten_subclass( strides = enumerate_filter_symints(t.stride()) out[0].extend(s for _, s in sizes) out[0].extend(s for _, s in strides) - if isinstance(desc, AOTInput): - out[1].extend(SubclassSizeAOTInput(desc, i) for i, _ in sizes) # type: ignore[misc] - out[1].extend(SubclassStrideAOTInput(desc, i) for i, _ in strides) # type: ignore[misc] - else: - out[1].extend(SubclassSizeAOTOutput(desc, i) for i, _ in sizes) # type: ignore[misc] - out[1].extend(SubclassStrideAOTOutput(desc, i) for i, _ in strides) # type: ignore[misc] + out[1].extend(SubclassSize(desc, i) for i, _ in sizes) + out[1].extend(SubclassStride(desc, i) for i, _ in strides) xs_inner: list[FxValue] = [] descs_inner: list[AOTDescriptor] = [] for x, desc in zip(wrapped_args, wrapped_args_descs): - # pyrefly: ignore [bad-argument-type] - flatten_subclass(typing.cast(Tensor, x), desc, out=(xs_inner, descs_inner)) + flatten_subclass(x, desc, out=(xs_inner, descs_inner)) return xs_inner, descs_inner @@ -328,16 +328,21 @@ def runtime_unwrap_tensor_subclasses( *, append_symints: bool, subclass_metas: list[PlainTensorMeta | SubclassCreationMeta] | None = None, -) -> list[Any]: +) -> list[int | Tensor | SymInt | OpaqueBase]: def flatten_subclass( - x: Tensor, meta: SubclassCreationMeta | None, *, out: list[Any] - ) -> list[Any]: + x: Tensor | TraceableWrapperSubclass, + subclass_meta: PlainTensorMeta | SubclassCreationMeta | OpaqueMeta | None, + *, + out: list[OpaqueBase | SymInt | Tensor | int], + ) -> list[OpaqueBase | SymInt | Tensor | int]: if not is_traceable_wrapper_subclass(x): out.append(x) return out if not isinstance(x, Tensor): raise AssertionError(f"expected Tensor, got {type(x)}") + if not isinstance(subclass_meta, SubclassCreationMeta): + raise AssertionError("subclass_meta should be a SubclassCreationMeta") attrs, _ = x.__tensor_flatten__() @@ -347,8 +352,7 @@ def flatten_subclass( case OpaqueBase(): out.append(inner_value) case Tensor(): - # pyrefly: ignore [missing-attribute] - inner_meta = meta.attrs.get(attr) + inner_meta = subclass_meta.attrs.get(attr) flatten_subclass(inner_value, inner_meta, out=out) case _: raise AssertionError( @@ -356,11 +360,9 @@ def flatten_subclass( ) if append_symints: - if not isinstance(meta, SubclassCreationMeta): - raise AssertionError(f"expected SubclassCreationMeta, got {type(meta)}") # outer_size size = x.size() - symint_placeholders = compute_symint_placeholders(meta.outer_size) + symint_placeholders = compute_symint_placeholders(subclass_meta.outer_size) if len(size) != len(symint_placeholders): raise AssertionError( f"size length mismatch: {len(size)} != {len(symint_placeholders)}" @@ -371,7 +373,9 @@ def flatten_subclass( # outer_stride stride = x.stride() - symint_placeholders = compute_symint_placeholders(meta.outer_stride) + symint_placeholders = compute_symint_placeholders( + subclass_meta.outer_stride + ) if len(stride) != len(symint_placeholders): raise AssertionError( f"stride length mismatch: {len(stride)} != {len(symint_placeholders)}" @@ -381,7 +385,7 @@ def flatten_subclass( ) return out - xs_inner: list[int | Tensor | SymInt | OpaqueType] = [] + xs_inner: list[int | Tensor | SymInt | OpaqueBase] = [] if append_symints: if subclass_metas is None: @@ -395,12 +399,14 @@ def flatten_subclass( continue if subclass_metas is None: - get_plain_tensors(typing.cast(Tensor, x), out=xs_inner) + get_plain_tensors(x, out=xs_inner) else: - meta = subclass_metas[idx] - if not isinstance(meta, SubclassCreationMeta): - raise AssertionError(f"expected SubclassCreationMeta, got {type(meta)}") - flatten_subclass(typing.cast(Tensor, x), meta, out=xs_inner) + subclass_meta = subclass_metas[idx] + if not isinstance(subclass_meta, SubclassCreationMeta): + raise AssertionError( + f"expected SubclassCreationMeta, got {type(subclass_meta)}" + ) + flatten_subclass(x, subclass_meta, out=xs_inner) return xs_inner @@ -431,7 +437,7 @@ def remap_unwrapped_subclass_arg_indices( num_indices = 1 if is_traceable_wrapper_subclass(arg): num_indices = ( - len(get_plain_tensors(typing.cast(Tensor, arg), out=[])) + len(get_plain_tensors(arg, out=[])) + len(enumerate_filter_symints(arg.size())) + len(enumerate_filter_symints(arg.stride())) ) diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index 2dbda56bfd295..a06c82307ffdf 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -48,18 +48,6 @@ strict_zip = partial(zip, strict=True) -def _get_symint_hints(exprs: Any) -> Any: - """ - Get the hints of a list/tuple of int/SymInt. - """ - if isinstance(exprs, (list, tuple)): - return type(exprs)(_get_symint_hints(e) for e in exprs) - elif isinstance(exprs, torch.SymInt): - return exprs.node.shape_env.size_hint(exprs.node.expr) - else: - return exprs - - def partial_flatten_asdict(obj: object) -> Any: if dataclasses.is_dataclass(obj): return { diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 0ee89cd521846..cd486a37d6744 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -5,7 +5,7 @@ import time from contextlib import nullcontext from functools import wraps -from typing import Any, TYPE_CHECKING +from typing import Any, Literal, TYPE_CHECKING from typing_extensions import ParamSpec, TypeVar from unittest.mock import patch @@ -25,23 +25,21 @@ set_feature_use, ) from torch._guards import detect_fake_mode -from torch._inductor.utils import BoxedBool +from torch._inductor.codecache import resolve_pre_grad_pass_timing + +# Runtime annotation consumers still resolve BoxedBool from module globals. from torch._subclasses import FakeTensor, FakeTensorMode from torch.export._tree_utils import reorder_kwargs from torch.fx.experimental.proxy_tensor import make_fx - -static_inputs_log = torch._logging.getArtifactLogger( - __name__, "cudagraph_static_inputs" -) from . import config -from ._aot_autograd.autograd_cache import ( # noqa: F401 +from ._aot_autograd import autograd_cache +from ._aot_autograd.autograd_cache import ( AOTAutogradCache, - autograd_cache_key, should_use_local_autograd_cache, should_use_remote_autograd_cache, ) -from ._aot_autograd.collect_metadata_analysis import ( # noqa: F401 +from ._aot_autograd.collect_metadata_analysis import ( run_functionalized_fw_and_collect_metadata, ) from ._aot_autograd.descriptors import ( @@ -79,7 +77,7 @@ fn_input_mutations_to_outputs, fn_prepped_for_autograd, ) -from ._aot_autograd.graph_compile import ( # noqa: F401 +from ._aot_autograd.graph_compile import ( aot_stage1_graph_capture, aot_stage2_compile, aot_stage2_export, @@ -141,7 +139,6 @@ ) from ._aot_autograd.utils import ( # noqa: F401 _get_autocast_states, - _get_symint_hints, call_func_at_runtime_with_args, create_tree_flattened_fn, KNOWN_TYPES, @@ -161,7 +158,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterable, Sequence - from torch._inductor.cudagraph_utils import BoxedDeviceIndex + from torch._inductor.compile_fx import CompilerConfigExtra from torch._inductor.output_code import OutputCode from torch._inductor.utils import InputType from torch._ops import OpOverload @@ -185,6 +182,7 @@ # may involve compiling multiple subgraphs; e.g., for forwards/backwards) AOT_COUNTER = itertools.count() + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # @@ -586,7 +584,6 @@ def _dup_fake_script_obj(fake_flat_args: FakifiedFlatArgs) -> list[Any]: flat_args_descs=flat_args_descs, static_input_indices=aot_config.static_input_indices, keep_input_mutations=aot_config.keep_inference_input_mutations, - is_train=needs_autograd, pre_dispatch=aot_config.pre_dispatch, )(*_dup_fake_script_obj(fake_flat_args)) @@ -622,35 +619,6 @@ def _dup_fake_script_obj(fake_flat_args: FakifiedFlatArgs) -> list[Any]: # and none of the inputs that require grad are mutated. # so we actually have an inference graph. needs_autograd = False - # A bit silly: right now in the subclass codepath, our ViewAndMutationMeta - # changes depending on whether we pass in is_train / keep_input_mutations, - # so we're forced to recompute the metadata. - # TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata - # so that this is unnecessary. - if req_subclass_dispatch: - fw_metadata = run_functionalized_fw_and_collect_metadata( - flat_fn, - flat_args_descs=flat_args_descs, - keep_input_mutations=aot_config.keep_inference_input_mutations, - is_train=False, - pre_dispatch=aot_config.pre_dispatch, - static_input_indices=aot_config.static_input_indices, - )(*fake_flat_args) - else: - fw_metadata = ViewAndMutationMeta( - input_info=fw_metadata.input_info, - output_info=fw_metadata.output_info, - num_intermediate_bases=fw_metadata.num_intermediate_bases, - keep_input_mutations=aot_config.keep_inference_input_mutations, - traced_tangents=fw_metadata.traced_tangents, - traced_tangents_descs=fw_metadata.traced_tangents_descs, - subclass_inp_meta=fw_metadata.subclass_inp_meta, - subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta, - subclass_tangent_meta=fw_metadata.subclass_tangent_meta, - is_train=False, - tokens=fw_metadata.tokens, - static_input_indices=fw_metadata.static_input_indices, - ) if fw_metadata.num_intermediate_bases > 0: if req_subclass_dispatch: @@ -824,7 +792,8 @@ def returned_function(*args: _P.args, **kwargs: _P.kwargs) -> Any: if cached_res is None: flat_fn, out_spec = create_tree_flattened_fn(fn, args, kwargs) (fake_mode, shape_env) = construct_fake_mode(flat_args, aot_config) - fake_flat_args: FakifiedFlatArgs = process_inputs( + fake_flat_args: FakifiedFlatArgs + fake_flat_args, act_input_indices = process_inputs( flat_args, aot_config, fake_mode, shape_env ) # TODO: We actually could use the pytree path to make better descs. @@ -842,6 +811,7 @@ def returned_function(*args: _P.args, **kwargs: _P.kwargs) -> Any: fake_mode, shape_env, ) + aot_state.fw_metadata.act_input_indices = act_input_indices aot_graph_capture = aot_stage1_graph_capture(aot_state, flat_fn) compiled_fn, _ = aot_stage2_compile( aot_state, @@ -921,39 +891,60 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: return AOTModule() -def prepare_aot_module_simplified( +def autograd_cache_key( + graph, + example_inputs, + ignore_shape_env: bool, + decompositions, + compiler_config_extra: CompilerConfigExtra, + keep_inference_input_mutations: bool = False, + disable_functionalization: bool = False, +): + ( + _params_buffers_flat, + _params_spec, + _buffers_spec, + full_args, + _full_args_descs, + aot_config, + ) = prepare_aot_config( + graph, + example_inputs, + decompositions, + keep_inference_input_mutations, + ignore_shape_env, + force_non_lazy_backward_lowering=config.force_non_lazy_backward_lowering, + disable_functionalization=disable_functionalization, + ) + + fake_mode, shape_env = construct_fake_mode(full_args, aot_config) + fake_flat_args, _act_input_indices = process_inputs( + full_args, aot_config, fake_mode, shape_env, ignore_shape_env + ) + + return autograd_cache.autograd_cache_key( + graph, fake_flat_args, aot_config, compiler_config_extra + ) + + +def prepare_aot_config( mod: nn.Module, args: Iterable[Any], - kwargs: dict[str, Any] | None, decompositions: dict[OpOverload, Callable[..., Any]] | None, keep_inference_input_mutations: bool, - boxed_forward_device_index: BoxedDeviceIndex | None, ignore_shape_env: bool, - flatten: bool, *, force_non_lazy_backward_lowering: bool = False, disable_functionalization: bool = False, - _record_nn_module_stack: bool = False, _disable_torch_fn_metadata_mode: bool = False, ) -> tuple[ - Any, list[torch.nn.Parameter | Tensor], list[str], list[str], - FakifiedFlatArgs, + list[Any], list[Any], AOTConfig, - FakeTensorMode, - ShapeEnv | None, - pytree.TreeSpec | None, - PytreeThunk | None, ]: - if not flatten: - if kwargs is not None: - raise AssertionError("kwargs must be None when flatten=False") - elif kwargs is None: - kwargs = {} - # TODO: There's something a bit suspicious here; typically simplified # module shouldn't actually have any parameters... params = dict(mod.named_parameters(remove_duplicate=False)) @@ -967,41 +958,14 @@ def prepare_aot_module_simplified( params_buffers = {**params, **buffers} params_buffers_flat = params_flat + buffers_flat - params_buffers_spec = params_spec + buffers_spec - - # Take a break to figure what we're doing with the module - - # NB: This doesn't change the in/out convention, except adding the - # parameters as explicit arguments - functional_call = create_functional_call( - mod, - params_buffers_spec, - params_len + buffers_len, - strict_out_tuple=not flatten, - # We need this for export to run ModuleStackTracer - # instead of PythonKeyTracer - store_orig_mod=_record_nn_module_stack, - ) full_args = [*params_flat, *buffers_flat, *args] - in_spec, out_spec = None, None - if flatten: - functional_call, out_spec = create_tree_flattened_fn( - functional_call, full_args, kwargs - ) - full_args, in_spec = pytree.tree_flatten((full_args, kwargs)) - - del kwargs # OK, set up the descs full_args_descs: list[DifferentiableAOTInput] = [] full_args_descs.extend(ParamAOTInput(fqn) for fqn in params_spec) full_args_descs.extend(BufferAOTInput(fqn) for fqn in buffers_spec) - # TODO: it would be better to put pytree information in here - full_args_descs.extend( - PlainAOTInput(i) for i in range(len(full_args) - len(full_args_descs)) - ) # TODO: These tracing_context fields should become unnecessary once we # always maintain sources on all arguments @@ -1050,9 +1014,97 @@ def prepare_aot_module_simplified( disable_functionalization=disable_functionalization, _disable_torch_fn_metadata_mode=_disable_torch_fn_metadata_mode, ) + + return ( + params_buffers_flat, + params_spec, + buffers_spec, + full_args, + full_args_descs, + aot_config, + ) + + +def prepare_aot_module_simplified( + mod: nn.Module, + args: Iterable[Any], + kwargs: dict[str, Any] | None, + decompositions: dict[OpOverload, Callable[..., Any]] | None, + keep_inference_input_mutations: bool, + ignore_shape_env: bool, + flatten: bool, + *, + force_non_lazy_backward_lowering: bool = False, + disable_functionalization: bool = False, + _record_nn_module_stack: bool = False, + _disable_torch_fn_metadata_mode: bool = False, +) -> tuple[ + Any, + list[torch.nn.Parameter | Tensor], + list[str], + list[str], + FakifiedFlatArgs, + list[Any], + AOTConfig, + FakeTensorMode, + ShapeEnv | None, + pytree.TreeSpec | None, + PytreeThunk | None, + list[int], +]: + if not flatten: + if kwargs is not None: + raise AssertionError("kwargs must be None when flatten=False") + elif kwargs is None: + kwargs = {} + + ( + params_buffers_flat, + params_spec, + buffers_spec, + full_args, + full_args_descs, + aot_config, + ) = prepare_aot_config( + mod, + args, + decompositions, + keep_inference_input_mutations, + ignore_shape_env, + force_non_lazy_backward_lowering=force_non_lazy_backward_lowering, + disable_functionalization=disable_functionalization, + _disable_torch_fn_metadata_mode=_disable_torch_fn_metadata_mode, + ) + + params_buffers_spec = params_spec + buffers_spec + + # NB: This doesn't change the in/out convention, except adding the + # parameters as explicit arguments + functional_call = create_functional_call( + mod, + params_buffers_spec, + aot_config.num_params_buffers, + strict_out_tuple=not flatten, + # We need this for export to run ModuleStackTracer + # instead of PythonKeyTracer + store_orig_mod=_record_nn_module_stack, + ) + + in_spec, out_spec = None, None + if flatten: + functional_call, out_spec = create_tree_flattened_fn( + functional_call, full_args, kwargs + ) + full_args, in_spec = pytree.tree_flatten((full_args, kwargs)) + + # TODO: it would be better to put pytree information in here + full_args_descs.extend( + PlainAOTInput(i) for i in range(len(full_args) - len(full_args_descs)) + ) + fake_mode, shape_env = construct_fake_mode(full_args, aot_config) # NB: full_args_descs not needed here, fake_flat_args is 1:1 with full_args - fake_flat_args = process_inputs( + fake_flat_args, act_input_indices = process_inputs( full_args, aot_config, fake_mode, shape_env, ignore_shape_env ) @@ -1068,12 +1120,13 @@ def prepare_aot_module_simplified( shape_env, in_spec, out_spec, + act_input_indices, ) def aot_module_simplified( mod: torch.fx.GraphModule | torch._dynamo.utils.GmWrapper, - args: Iterable[Any], + args: Sequence[Any], fw_compiler: AOTDispatchCompiler, bw_compiler: AOTDispatchCompiler | None = None, partition_fn: Callable[..., Any] = default_partition, @@ -1082,8 +1135,7 @@ def aot_module_simplified( inference_compiler: AOTDispatchCompiler | None = None, # TODO: This doesn't seem to be used in any nontrivial way, check if it's # actually needed - cudagraphs: BoxedBool | None = None, - boxed_forward_device_index: BoxedDeviceIndex | None = None, + compiler_config_extra: CompilerConfigExtra | None = None, ignore_shape_env: bool = False, disable_functionalization: bool = False, # Optional callback to run passes on the module at the start of AOT autograd. @@ -1091,6 +1143,7 @@ def aot_module_simplified( [torch.fx.GraphModule, Sequence[InputType]], torch.fx.GraphModule ] | None = None, + compile_region_name: str | None = None, ) -> Callable[..., Any]: """ This is the simplified or low overhead version of aot_module. For frontends @@ -1103,8 +1156,14 @@ def aot_module_simplified( :func:`aot_module_simplified` removes these overheads. """ - if cudagraphs is None: - cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs) + pre_grad_pass_timing: Literal["early", "late"] = resolve_pre_grad_pass_timing() + + if ( + pre_grad_pass_timing == "early" + and pre_grad_passes + and isinstance(mod, torch.fx.GraphModule) + ): + mod = pre_grad_passes(mod, args) with contextlib.ExitStack() as stack: ( @@ -1119,13 +1178,13 @@ def aot_module_simplified( shape_env, _in_spec, _out_spec, + act_input_indices, ) = prepare_aot_module_simplified( mod, args, None, decompositions, keep_inference_input_mutations, - boxed_forward_device_index, ignore_shape_env, flatten=False, force_non_lazy_backward_lowering=config.force_non_lazy_backward_lowering, @@ -1146,16 +1205,19 @@ def aot_module_simplified( mod, fake_flat_args, aot_config, - cudagraphs, - boxed_forward_device_index, + compiler_config_extra, local, remote, + compile_region_name=compile_region_name, ) if compiled_fn is None: - # Run pre-grad passes after cache lookup to cache pre-grad transforms. - if pre_grad_passes is not None and isinstance(mod, torch.fx.GraphModule): - mod = pre_grad_passes(mod, fake_flat_args) + if ( + pre_grad_pass_timing == "late" + and pre_grad_passes + and isinstance(mod, torch.fx.GraphModule) + ): + mod = pre_grad_passes(mod, args) stack.enter_context(compiled_autograd._disable()) aot_state = create_aot_state( @@ -1167,6 +1229,7 @@ def aot_module_simplified( fake_mode, shape_env, ) + aot_state.fw_metadata.act_input_indices = act_input_indices aot_graph_capture = aot_stage1_graph_capture(aot_state, functional_call) compiled_fn, _ = aot_stage2_compile( aot_state, @@ -1230,7 +1293,7 @@ def grab_serialize_fn(fn: Any) -> Callable[..., Any] | None: def boxed_nop_preserve_node_meta( gm: torch.fx.GraphModule, example_inputs: Sequence[InputType] ) -> Any: - def run(args: Sequence[Any]) -> OutputCode: + def run(args: list[Any]) -> OutputCode: with torch.fx.traceback.preserve_node_meta(): return torch.fx.Interpreter(gm).boxed_run(args) @@ -1323,6 +1386,7 @@ def aot_export_joint_with_descriptors( shape_env, in_spec, out_spec, + act_input_indices, ) = prepare_aot_module_simplified( mod, args, @@ -1330,7 +1394,6 @@ def aot_export_joint_with_descriptors( # In contrast, decompositions are needed at this stage. decompositions, keep_inference_input_mutations, - None, ignore_shape_env, flatten=True, # Without this, we will attempt to "compile" the backward lazily @@ -1357,6 +1420,7 @@ def aot_export_joint_with_descriptors( fake_mode, shape_env, ) + aot_state.fw_metadata.act_input_indices = act_input_indices # NB: no cache lookup! aot_graph_capture = aot_stage1_graph_capture(aot_state, functional_call) @@ -1770,7 +1834,7 @@ def aot_export_joint_simple( if config.debug_assert: # Smoke test that after partitioning, we can run the forward without any calling convention changes. - fw_module, _bw_module = default_partition( # noqa: F821 + fw_module, _bw_module = default_partition( # type: ignore[bad-argument-type] fx_g, args, @@ -1847,9 +1911,6 @@ def _aot_export_function( decompositions=decompositions, num_params_buffers=num_params_buffers, aot_id=next(AOT_COUNTER), - # For now there's no use case involving keeping input mutations in the graph - # (which we can only do in the inference case anyway). - # We can add this later if we need to. keep_inference_input_mutations=keep_input_mutations, dynamic_shapes=dynamic_shapes, aot_autograd_arg_pos_to_source=None, @@ -1862,7 +1923,9 @@ def _aot_export_function( fake_mode, shape_env = construct_fake_mode(flat_args, aot_config) else: shape_env = fake_mode.shape_env - fake_flat_args = process_inputs(flat_args, aot_config, fake_mode, shape_env) + fake_flat_args, act_input_indices = process_inputs( + flat_args, aot_config, fake_mode, shape_env + ) # TODO: Improve the descs here with pytree information fake_flat_args_descs: list[AOTInput] = [ PlainAOTInput(i) for i in range(len(fake_flat_args)) @@ -1878,6 +1941,7 @@ def _aot_export_function( fake_mode, shape_env, ) + aot_state.fw_metadata.act_input_indices = act_input_indices aot_graph_capture = aot_stage1_graph_capture(aot_state, flat_fn) fx_g, meta = aot_stage2_export(aot_state, aot_graph_capture) diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py index f80d81b2f2958..cf85187c6b84e 100644 --- a/torch/_functorch/compile_utils.py +++ b/torch/_functorch/compile_utils.py @@ -116,7 +116,7 @@ def checkable_node(node: fx.Node) -> bool: # replace the meta of downstream users. eg. one bug we've seen is: # # _local_scalar_dense_11: "Sym(u14)" = torch.ops.aten._local_scalar_dense.default(select_10); - # sym_sum_2: "Sym(u19 + u20 + u21)" = torch.sym_sum((_local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13)) # noqa: B950 + # sym_sum_2: "Sym(u19 + u20 + u21)" = torch.sym_sum((_local_scalar_dense_11, _local_scalar_dense_12, _local_scalar_dense_13)) # # Notice how _local_scalar_dense_11 is u14 but sym_sum_2's meta is incorrectly the old # pre-cse value of u19. @@ -219,7 +219,7 @@ def strip_overloads(gm: fx.GraphModule) -> None: gm.recompile() -def get_placeholders(graph: fx.Graph) -> fx.graph._node_list: +def get_placeholders(graph: fx.Graph) -> list[Any]: return graph.find_nodes(op="placeholder") diff --git a/torch/_functorch/compilers.py b/torch/_functorch/compilers.py index 687aabbb9c9cf..63eb37585f035 100644 --- a/torch/_functorch/compilers.py +++ b/torch/_functorch/compilers.py @@ -447,8 +447,8 @@ def graph_saver_helper( if dump_example_input: torch.save( args, - f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt", # noqa: B950 - ) # noqa: E501 + f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt", + ) def graph_saver_forward( gm: fx.GraphModule, example_inputs: list[torch.Tensor] diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index a779f5105531d..398e188cd8111 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -230,6 +230,10 @@ def remote_autograd_cache_default() -> bool | None: # activation reloading with prefetching when using separate streams (bwd graph) activation_reload_prefetch = False +# CPU ↔ GPU bandwidth in GB/s, used to estimate transfer times for prefetch +# scheduling. This is hardware-specific and should be set by the user. +activation_offload_cpu_gpu_bw: float = 50.0 + # If FakeTensor.data_ptr() should error. # This option is independent of AOTAutograd and torch.compile, but our policy # is to turn it off during torch.compile. @@ -399,7 +403,8 @@ def remote_autograd_cache_default() -> bool | None: # At runtime non contiguous tangents will be coerced to be contiguous. # This config changes this guess for tangents strides to be the same as outputs. # TODO(ivankobzarev): Remove this config once extra memory usage is investigated. -guess_tangent_strides_as_outputs = False +guess_tangent_strides_as_outputs = not is_fbcode() + # This is a temporary config to ensure all ranks take the same decision in the partitioner # it will ultimately be removed once we share size_hints across ranks through compiler collectives @@ -432,7 +437,7 @@ def remote_autograd_cache_default() -> bool | None: if TYPE_CHECKING: - from torch.utils._config_typing import * # noqa: F401, F403 + from torch.utils._config_typing import * # noqa: F403 # adds patch, save_config, invalid config checks, etc diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index fd225747147b7..8c112adfdfc03 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -19,10 +19,7 @@ _func_increment_nesting, # type: ignore[attr-defined] _grad_decrement_nesting, _grad_increment_nesting, - _jvp_decrement_nesting, - _jvp_increment_nesting, _propagate_functional_input_mutation, # type: ignore[attr-defined] - _unwrap_for_grad, _unwrap_functional_tensor, _wrap_for_grad, _wrap_functional_tensor, @@ -31,6 +28,11 @@ is_functorch_wrapped_tensor, set_inplace_requires_grad_allowed, ) +from torch._functorch.predispatch import ( + _jvp_decrement_nesting, + _jvp_increment_nesting, + _unwrap_for_grad, +) from torch._functorch.utils import argnums_t, exposed_in from torch._subclasses.functional_tensor import FunctionalTensor from torch.fx.experimental import const_fold @@ -335,6 +337,23 @@ def vjp( return _vjp_with_argnums(func, *primals, has_aux=has_aux) +@contextlib.contextmanager +def _disable_inference_mode() -> Generator[None, None, None]: + # Disable inference_mode without clobbering grad_mode / fw_grad_mode. + # torch.inference_mode(False) unconditionally sets grad_mode=True and + # fw_grad_mode=True; we save and restore those to avoid that. + # No-op when inference_mode is already off. + if not torch.is_inference_mode_enabled(): + yield + return + prev_grad = torch.is_grad_enabled() + prev_fw_grad = torch._C._is_fwd_grad_enabled() + with torch.inference_mode(False): + torch._C._set_grad_enabled(prev_grad) + torch._C._set_fwd_grad_enabled(prev_fw_grad) + yield + + @contextlib.contextmanager def grad_increment_nesting() -> Generator[int, None, None]: try: @@ -438,13 +457,23 @@ def wrapper( f"cotangents: {treespec_pprint(cotangents_spec)}, " f"primal output: {treespec_pprint(primals_out_spec)}" ) - result = _autograd_grad( - flat_primals_out, - flat_diff_primals, - flat_cotangents, - retain_graph=retain_graph, - create_graph=create_graph, + # This closure runs after grad_increment_nesting exits, so + # inference_mode may have been restored. Disable it for autograd. + # Skip under Dynamo — tracing through the generator CM emits + # spurious _enter_inference_mode nodes. + ctx = ( + contextlib.nullcontext() + if torch.compiler.is_compiling() + else _disable_inference_mode() ) + with ctx: + result = _autograd_grad( + flat_primals_out, + flat_diff_primals, + flat_cotangents, + retain_graph=retain_graph, + create_graph=create_graph, + ) return tree_unflatten(result, primals_spec) if has_aux: diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 9fc1d14321035..7786d7e5dcc06 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -44,7 +44,7 @@ find_symbol_binding_fx_nodes, free_symbols, is_symbol_binding_fx_node, - size_hint, + optimization_hint, statically_known_false, statically_known_true, ) @@ -67,6 +67,7 @@ SavedForBackwardsNoVcCheckAOTOutput, ) from ._aot_autograd.functional_utils import _is_functional_graph +from ._aot_autograd.graph_compile import is_opaque_node from ._aot_autograd.logging_utils import get_aot_graph_name from ._aot_autograd.utils import ( _is_bwd_seed_offset, @@ -382,6 +383,12 @@ def _extract_graph_with_inputs_outputs( output_values.append(x) out = new_graph.output(tuple(output_values)) out.meta["desc"] = outputs_descs + # Snapshot stack traces on the output node before passes run, + # as later passes may strip stack_trace from individual nodes. + out.meta["output_stack_traces"] = [ + v.meta.get("stack_trace") if isinstance(v, fx.Node) else None + for v in output_values + ] new_graph.eliminate_dead_code() new_graph.lint() @@ -697,7 +704,7 @@ def calculate_range(dtype: torch.dtype) -> tuple[float, float]: return info.min, info.max -def quantize_activation_fw(graph: torch.fx.Graph) -> None: +def quantize_activation_fw(graph: torch.fx.Graph, num_fwd_outputs: int = 0) -> None: output = graph.find_nodes(op="output")[0] fwd_outputs = output.args[0] quant_type = get_quant_type() @@ -706,6 +713,14 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None: tensor_scale_nodes: list[fx.Node] = [] sym_scale_nodes: list[fx.Node] = [] for position, node in enumerate(fwd_outputs): + # Don't quantize user-visible forward outputs. A tensor may appear as + # both a user output and a saved-for-backward activation (same FX node + # at two positions). Quantizing the user output position would: + # 1. Return fp8 to the user instead of the original precision + # 2. Create duplicate fp8_quant/fp8_scale backward placeholders that + # shift the stride mapping in _aot_stage2b_bw_compile (T264303372) + if position < num_fwd_outputs: + continue # check if the activation node is the node saved for quantization if node.meta.get("saved_for_quantization", False): # case: use scaling @@ -846,6 +861,7 @@ def perform_fp8_activation_quantization( fwd_module: fx.GraphModule, bwd_module: fx.GraphModule, bwd_module_inputs: dict[str, fx.Node], + num_fwd_outputs: int = 0, ) -> None: trace_structured( "artifact", @@ -858,7 +874,7 @@ def perform_fp8_activation_quantization( ), ) - quantize_activation_fw(fwd_module.graph) + quantize_activation_fw(fwd_module.graph, num_fwd_outputs) trace_structured( "artifact", @@ -937,6 +953,7 @@ def enable_activation_quantization( fwd_module: fx.GraphModule, bwd_module: fx.GraphModule, static_lifetime_input_nodes: OrderedSet[fx.Node] | None = None, + num_fwd_outputs: int = 0, ) -> None: static_input_names: list[str] = ( [node.name for node in static_lifetime_input_nodes] @@ -968,13 +985,16 @@ def enable_activation_quantization( should_perform_fp8_quant = True if should_perform_fp8_quant: - perform_fp8_activation_quantization(fwd_module, bwd_module, bwd_module_inputs) + perform_fp8_activation_quantization( + fwd_module, bwd_module, bwd_module_inputs, num_fwd_outputs + ) def _extract_fwd_bwd_modules( joint_module: fx.GraphModule, saved_values: list[fx.Node], saved_sym_nodes: list[fx.Node], + saved_opaque_nodes: list[fx.Node] | None = None, *, num_fwd_outputs: int, static_lifetime_input_nodes: OrderedSet[fx.Node] | None = None, @@ -1007,9 +1027,16 @@ def _extract_fwd_bwd_modules( bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)] backward_state_inputs = [*filter(_is_backward_state, placeholders)] + if saved_opaque_nodes is None: + saved_opaque_nodes = [] + bwd_graph = _extract_graph_with_inputs_outputs( joint_module.graph, - saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs, + saved_sym_nodes + + saved_opaque_nodes + + saved_values + + tangent_inputs + + bwd_seed_offset_inputs, bwd_outputs, bwd_outputs_descs, "backward", @@ -1023,6 +1050,7 @@ def _extract_fwd_bwd_modules( if not node.users: _remove_by_name(saved_values, node.name) _remove_by_name(saved_sym_nodes, node.name) + _remove_by_name(saved_opaque_nodes, node.name) # wait_tensor is a bit special: if we have a "dead activation" that is not used in the bw, # but this dead activation is actually a collective, # then the collective will generally by followed by a wait_tensor() call. @@ -1034,6 +1062,7 @@ def _extract_fwd_bwd_modules( ): _remove_by_name(saved_values, node.name) _remove_by_name(saved_sym_nodes, node.name) + _remove_by_name(saved_opaque_nodes, node.name) elif _is_backward_state(node): # BackwardState is saved directly _remove_by_name(saved_values, node.name) @@ -1117,18 +1146,25 @@ def _extract_fwd_bwd_modules( # Now, we re-generate the fwd/bwd graphs. # NB: This might increase compilation time, but I doubt it matters - # Convention for saved acts is (tensors_with_vc_check, tensors_no_vc_check, opaque_objects, symints) + # Convention for saved acts is (tensors_with_vc_check, tensors_no_vc_check, opaque_objects, symints, opaque_nodes) fwd_graph = _extract_graph_with_inputs_outputs( joint_module.graph, primal_inputs + fwd_seed_offset_inputs, - fwd_outputs + saved_values + saved_opaque_objects + saved_sym_nodes, + fwd_outputs + + saved_values + + saved_opaque_objects + + saved_opaque_nodes + + saved_sym_nodes, fwd_outputs_descs + [ SavedForBackwardsNoVcCheckAOTOutput(i) if i >= no_vc_check_start_idx and i < len(saved_values) else SavedForBackwardsAOTOutput(i) for i in range( - len(saved_values) + len(saved_opaque_objects) + len(saved_sym_nodes) + len(saved_values) + + len(saved_opaque_objects) + + len(saved_opaque_nodes) + + len(saved_sym_nodes) ) ], "forward", @@ -1139,6 +1175,7 @@ def _extract_fwd_bwd_modules( saved_sym_nodes + saved_values + saved_opaque_objects + + saved_opaque_nodes + tangent_inputs + bwd_seed_offset_inputs + backward_state_inputs, @@ -1182,7 +1219,11 @@ def _extract_fwd_bwd_modules( is not None ): enable_activation_quantization( - saved_values, fwd_module, bwd_module, static_lifetime_input_nodes + saved_values, + fwd_module, + bwd_module, + static_lifetime_input_nodes, + num_fwd_outputs, ) return fwd_module, bwd_module @@ -1267,6 +1308,7 @@ def default_partition( saved_values = [] saved_sym_nodes = [] + saved_opaque_nodes = [] distributed_enabled = torch.distributed.is_available() @@ -1326,14 +1368,23 @@ def is_impure(node: fx.Node) -> bool: # Must be ordered before MUST_SAVE tags to avoid saving tuples marked MUST_SAVE. continue if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE: - saved_values.append(node) + if is_opaque_node(node): + saved_opaque_nodes.append(node) + else: + saved_values.append(node) continue if is_impure(node): if graph_has_recomputable_ops: raise AssertionError( f"Trying to apply AC on a graph with impure op: {node}, {node.target}" ) - saved_values.append(node) + if is_opaque_node(node): + saved_opaque_nodes.append(node) + else: + saved_values.append(node) + continue + if is_opaque_node(node): + saved_opaque_nodes.append(node) continue if not is_tensor(node) and node.op == "call_function": raise AssertionError(f"Expected {node} to be a tensor") @@ -1354,6 +1405,7 @@ def is_impure(node: fx.Node) -> bool: saved_values = list(dict.fromkeys(saved_values).keys()) saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) + saved_opaque_nodes = list(dict.fromkeys(saved_opaque_nodes).keys()) if config._sync_decision_cross_ranks: saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values) @@ -1364,6 +1416,7 @@ def is_impure(node: fx.Node) -> bool: joint_module, saved_values, saved_sym_nodes=saved_sym_nodes, + saved_opaque_nodes=saved_opaque_nodes, num_fwd_outputs=num_fwd_outputs, static_lifetime_input_nodes=static_lifetime_input_nodes, ) @@ -1415,7 +1468,7 @@ def _size_of(node: fx.Node) -> int: def object_nbytes(x: object) -> int: if not isinstance(x, torch.Tensor): return 0 - return _tensor_nbytes(size_hint(x.numel(), fallback=4096), x.dtype) + return _tensor_nbytes(optimization_hint(x.numel(), fallback=4096), x.dtype) if "val" in node.meta: val = node.meta["val"] @@ -1706,7 +1759,7 @@ def get_device(node: fx.Node) -> torch.device | None: return torch.device("cpu") def get_sample_rng_state(device: torch.device | None) -> torch.Tensor: - from torch._guards import detect_fake_mode # noqa: F401 + from torch._guards import detect_fake_mode fake_mode = detect_fake_mode() if fake_mode is None: @@ -2300,9 +2353,12 @@ def ban_recomputation_if_allowed(node: fx.Node, reason: str = "") -> bool: weight = float(sym_node_size(node)) cannot_save_reason = None elif is_non_tensor_node: - # FakeScriptObjects (opaque objects) should have weight 0.0 so they can be - # properly partitioned between forward and backward, like BackwardState. - if isinstance(node.meta.get("val"), (BackwardState, FakeScriptObject)): + # FakeScriptObjects and opaque objects should have weight 0.0 + # so they can be properly partitioned between forward and + # backward, like BackwardState. + if isinstance( + node.meta.get("val"), (BackwardState, FakeScriptObject) + ) or is_opaque_node(node): weight = 0.0 cannot_save_reason = None else: @@ -2804,7 +2860,7 @@ def get_default_op_list() -> OpTypes: aten.unsqueeze, aten.rsub, aten._to_copy, - ] # noqa: E501,B950 + ] recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias] recomputable_view_ops += [ aten.view, @@ -2854,7 +2910,7 @@ def get_default_op_list() -> OpTypes: aten.maximum, prims.iota, prims._low_memory_max_pool_offsets_to_indices, - ] # noqa: E501,B950 + ] # Natalia said that we should allow recomputing indexing :) default_recomputable_ops += [aten.index, aten.gather] default_recomputable_ops += view_ops @@ -2883,7 +2939,7 @@ def get_default_op_list() -> OpTypes: aten._efficient_attention_forward, aten.upsample_bilinear2d, aten._scaled_mm, - ] # noqa: E501,B950 + ] fusible_ops = recomputable_ops | random_ops return OpTypes( @@ -2958,7 +3014,7 @@ def _remove_symbols_without_guarding(x: torch.Tensor, fallback: int) -> torch.Te shape = list(x.shape) def realize_symbol(d: torch.SymInt | int) -> int: - return size_hint(d, fallback=fallback) + return optimization_hint(d, fallback=fallback) shape = [realize_symbol(s) for s in shape] stride = [realize_symbol(s) for s in x.stride()] @@ -2972,7 +3028,7 @@ def materialize_arg(x: Any) -> Any: if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor): return _remove_symbols_without_guarding(x.meta["val"], fallback=4096) elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt): - return size_hint(x.meta["val"], fallback=4096) + return optimization_hint(x.meta["val"], fallback=4096) elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat): return 1.0 elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool): @@ -3155,7 +3211,7 @@ def get_saved_values_knapsack( # if idx in all_recomputable_banned_nodes: try: dont_ban.add(all_recomputable_banned_nodes[idx]) - except BaseException: # noqa: B036 + except BaseException: pass if not dont_ban.issubset(all_recomputable_banned_nodes): @@ -3587,9 +3643,27 @@ def min_cut_rematerialization_partition( # pyrefly: ignore [unbound-name] if config._sync_decision_cross_ranks: saved_values = _sync_decision_cross_ranks(joint_graph, saved_values) + # save_for_backward on tensors and stashes symints in autograd .ctx - saved_sym_nodes = list(filter(is_sym_node, saved_values)) - saved_values = list(filter(lambda n: not is_sym_node(n), saved_values)) + # Skip SymBool nodes whose only consumers are _assert_scalar calls. + # These are runtime assertion intermediates and are not needed in backward + # for any real computation. + def _is_assert_only_symbool(n: fx.Node) -> bool: + return ( + isinstance(n.meta.get("val"), torch.SymBool) + and len(n.users) > 0 + and all(u.target is torch.ops.aten._assert_scalar.default for u in n.users) + ) + + saved_sym_nodes = list( + filter( + lambda n: is_sym_node(n) and not _is_assert_only_symbool(n), saved_values + ) + ) + saved_opaque_nodes = list(filter(is_opaque_node, saved_values)) + saved_values = list( + filter(lambda n: not is_sym_node(n) and not is_opaque_node(n), saved_values) + ) # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols fw_module, bw_module = _extract_fwd_bwd_modules( @@ -3597,6 +3671,7 @@ def min_cut_rematerialization_partition( saved_values, # pyrefly: ignore [bad-argument-type] saved_sym_nodes=saved_sym_nodes, + saved_opaque_nodes=saved_opaque_nodes, num_fwd_outputs=num_fwd_outputs, static_lifetime_input_nodes=node_info.static_lifetime_input_nodes, ) diff --git a/torch/_functorch/predispatch.py b/torch/_functorch/predispatch.py index ffa66ad8c2b74..c482ff18cefa5 100644 --- a/torch/_functorch/predispatch.py +++ b/torch/_functorch/predispatch.py @@ -14,12 +14,23 @@ from typing import TYPE_CHECKING import torch +from torch._C import ( + _enter_dual_level as _enter_dual_level_impl, + _exit_dual_level as _exit_dual_level_impl, +) from torch._C._functorch import ( _add_batch_dim as _add_batch_dim_impl, + _jvp_decrement_nesting as _jvp_decrement_nesting_impl, + _jvp_increment_nesting as _jvp_increment_nesting_impl, _remove_batch_dim as _remove_batch_dim_impl, + _unwrap_for_grad as _unwrap_for_grad_impl, _vmap_decrement_nesting as _vmap_decrement_nesting_impl, _vmap_increment_nesting as _vmap_increment_nesting_impl, ) +from torch._VF import ( # type: ignore[attr-defined] + _make_dual as _make_dual_impl, + _unpack_dual as _unpack_dual_impl, +) if TYPE_CHECKING: @@ -167,3 +178,120 @@ def _register_python_decomposition_vmap(decomp: torch._ops.OpOverload) -> None: _register_python_decomposition_vmap(torch.ops.aten.addr.default) DECOMPOSITIONS_LOADED = True + + +def _make_dual( + tensor: torch.Tensor, tangent: torch.Tensor, *, level: int = 0 +) -> torch.Tensor: + """ + Thin wrapper around torch._VF._make_dual that is used to proxy in + PT2 export/compile fx graph for forward-mode AD. + """ + from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export + + mode = _maybe_find_pre_dispatch_tf_mode_for_export() + + if mode: + return torch.overrides.handle_torch_function( + _make_dual, (tensor, tangent), tensor, tangent, level=level + ) + + return _make_dual_impl(tensor, tangent, level=level) + + +def _unpack_dual( + tensor: torch.Tensor, *, level: int = 0 +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Thin wrapper around torch._VF._unpack_dual that is used to proxy in + PT2 export/compile fx graph for forward-mode AD. + """ + from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export + + mode = _maybe_find_pre_dispatch_tf_mode_for_export() + + if mode: + return torch.overrides.handle_torch_function( + _unpack_dual, (tensor,), tensor, level=level + ) + + return _unpack_dual_impl(tensor, level=level) + + +def _jvp_increment_nesting() -> int: + """ + Thin wrapper around torch._C._functorch._jvp_increment_nesting that is + used to proxy in export/compile graph for forward-mode AD. + """ + from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export + + mode = _maybe_find_pre_dispatch_tf_mode_for_export() + + if mode: + return torch.overrides.handle_torch_function( + _jvp_increment_nesting, + (), + ) + return _jvp_increment_nesting_impl() + + +def _jvp_decrement_nesting() -> int: + """ + Thin wrapper around torch._C._functorch._jvp_decrement_nesting that is + used to proxy in export/compile graph for forward-mode AD. + """ + from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export + + mode = _maybe_find_pre_dispatch_tf_mode_for_export() + + if mode: + return torch.overrides.handle_torch_function( + _jvp_decrement_nesting, + (), + ) + return _jvp_decrement_nesting_impl() + + +def _unwrap_for_grad(tensor: torch.Tensor, level: int) -> torch.Tensor: + """ + Thin wrapper around torch._C._functorch._unwrap_for_grad that is used + to proxy in PT2 export/compile fx graph for functorch transforms (grad, vjp, jvp). + """ + from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export + + mode = _maybe_find_pre_dispatch_tf_mode_for_export() + + if mode: + return torch.overrides.handle_torch_function( + _unwrap_for_grad, (tensor,), tensor, level + ) + + return _unwrap_for_grad_impl(tensor, level) + + +def _enter_dual_level() -> int: + """ + Thin wrapper around torch._C._enter_dual_level that is used to proxy in + PT2 export/compile fx graph for forward-mode AD. + """ + from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export + + mode = _maybe_find_pre_dispatch_tf_mode_for_export() + + if mode: + return torch.overrides.handle_torch_function(_enter_dual_level, ()) + return _enter_dual_level_impl() + + +def _exit_dual_level(*, level: int) -> None: + """ + Thin wrapper around torch._C._exit_dual_level that is used to proxy in + PT2 export/compile fx graph for forward-mode AD. + """ + from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export + + mode = _maybe_find_pre_dispatch_tf_mode_for_export() + + if mode: + return torch.overrides.handle_torch_function(_exit_dual_level, (), level=level) + return _exit_dual_level_impl(level=level) diff --git a/torch/_guards.py b/torch/_guards.py index 955fe18f760b4..1ae422d999c6e 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -36,7 +36,9 @@ from torch._dynamo.backends.distributed import DDPOptimizerContext from torch._dynamo.codegen import PyCodegen + from torch._dynamo.guards import GuardCheckSpec from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta + from torch._higher_order_ops.invoke_subgraph import NestedCompileRegionOptions from torch._subclasses.fake_tensor import FakeTensorMode @@ -739,10 +741,12 @@ def restore_graphstate(self, state: GuardsCheckpointState) -> None: class HopSubgraphCache: @abstractmethod - def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: ... + def add_dynamo_installed_submodule( + self, fn_code: CodeType, identifier: str + ) -> None: ... @abstractmethod - def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ... + def get_dynamo_installed_submodules(self, fn_code: CodeType) -> list[str]: ... @abstractmethod def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: ... @@ -770,23 +774,89 @@ def get_lazy_bwd_entry( ) -> tuple[torch.fx.GraphModule | None, int | None]: ... +@dataclass +class InvokeSubgraphReuseEntry: + body_name: str + body_gmod: torch.fx.GraphModule + config: NestedCompileRegionOptions | None + subgraph_input_mapping: list[ + Any + ] # list[LiftedArgOrigin] - defined in invoke_subgraph.py + single_tensor_output: bool + # Per-output tensor metadata (shape, stride, dtype, device, requires_grad) + # cached from the first trace so we can construct fresh FakeTensors on + # cache hit without re-running body_gmod. + output_metadata: list[ + tuple[torch.Size, tuple[int, ...], torch.dtype, torch.device, bool] + ] + # 1-1 mapping to flat_vts: source for each flattened arg/kwarg, or None if + # the VT has no source. On cache hit, we build a source replacement mapping + # (old arg sources → new arg sources) to rewrite captured variable sources + # for the current invocation. + arg_sources: list[Source | None] + # Number of user-visible outputs (from the function return value). + # The graph may have additional outputs from side-effect intermediates; + # stamp_out_subgraph uses this to return only the user-visible slice. + num_user_outputs: int = 0 + + +@dataclass +class InvokeSubgraphReuseCondition: + # Per flattened input VT: (InputTag, metadata). + # (InputTag.TENSOR, TensorMetadata) + # (InputTag.SYMNODE, sym_num — same object implies same symbol) + # (InputTag.CONSTANT, value) + # (InputTag.MODULE, None) + # Tensor metadata is checked here because TENSOR_MATCH guards for + # subgraph inputs may already exist before tracing and thus won't + # appear in the guard delta. + input_checks: list[tuple[Any, object]] # list[tuple[InputTag, object]] + + # Guards captured during the trace (delta + source-mapped). + # Each entry: (source, handler, expected_value, guard) + # handler is a pre-resolved GuardCheckSpec from GUARD_VALUE_DISPATCH. + guards: list[tuple[Source, GuardCheckSpec, object, Guard]] + + # TreeSpec from pytree.tree_flatten of the (args, kwargs) structure. + # On cache hit, we verify the new call has the same treespec. + treespec: pytree.TreeSpec | None = None + + # All sources accessed via VariableBuilder during the subgraph trace. + # On cache hit, we check if any modified VT's source is a base of one + # of these to detect mutations on captured variables. + traced_sources: OrderedSet[Source] = dataclasses.field(default_factory=OrderedSet) + + class InvokeSubgraphCache(HopSubgraphCache): def __init__(self) -> None: self.autograd_cache: dict[str, Callable] = {} self.proxy_dispatch_cache: dict[str, Callable] = {} - self.dynamo_installed_submodules: dict[int, list[str]] = defaultdict(list) + self.dynamo_installed_submodules: dict[CodeType, list[str]] = defaultdict(list) self.lazy_bwd_cache: dict[ str, dict[tuple[object], tuple[torch.fx.GraphModule, int]] ] = defaultdict(dict) self.effects_cache: dict[ str, set ] = {} # Maps identifier -> set of effect types + # fn.__code__ → list of (condition, cache_entry) pairs. Walked linearly + # on lookup; first matching condition wins. + self.subgraph_reuse_cache: dict[ + CodeType, + list[tuple[InvokeSubgraphReuseCondition, InvokeSubgraphReuseEntry]], + ] = defaultdict(list) + # fn_code → {hash_key → cache_entry}. Used by user-provided + # reuse_hash_fn for O(1) subgraph reuse lookup. + self.subgraph_reuse_key_cache: dict[ + CodeType, dict[int, InvokeSubgraphReuseEntry] + ] = defaultdict(dict) - def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: - self.dynamo_installed_submodules[fn_id].append(identifier) + def add_dynamo_installed_submodule( + self, fn_code: CodeType, identifier: str + ) -> None: + self.dynamo_installed_submodules[fn_code].append(identifier) - def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: - return self.dynamo_installed_submodules.get(fn_id, []) + def get_dynamo_installed_submodules(self, fn_code: CodeType) -> list[str]: + return self.dynamo_installed_submodules.get(fn_code, []) def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: self.autograd_cache[identifier] = key @@ -835,6 +905,65 @@ def get_effects(self, identifier: str) -> set | None: """Retrieve the effect types for a given invoke_subgraph identifier.""" return self.effects_cache.get(identifier, None) + def add_reuse_entry( + self, + fn_code: CodeType, + condition: InvokeSubgraphReuseCondition, + entry: InvokeSubgraphReuseEntry, + max_reuse_entries: int = 8, + ) -> None: + entries = self.subgraph_reuse_cache[fn_code] + if len(entries) >= max_reuse_entries: + raise RuntimeError( + f"invoke_subgraph: exceeded maximum reuse entries " + f"({max_reuse_entries}) for function code {fn_code}. " + f"This most likely means a guard keeps failing on every " + f"invocation, preventing subgraph reuse. " + f"Set TORCH_LOGS='+hierarchical_compile' to identify which " + f"guard is failing. If reuse is genuinely not possible and " + f"you need more cache entries, increase the limit via the " + f"max_reuse_entries argument to nested_compile_region()." + ) + entries.append((condition, entry)) + + def find_reuse_entry( + self, + fn_code: CodeType, + evaluator: Callable[ + [InvokeSubgraphReuseCondition, InvokeSubgraphReuseEntry], bool + ], + ) -> InvokeSubgraphReuseEntry | None: + entries = self.subgraph_reuse_cache.get(fn_code, []) + for i, (condition, entry) in enumerate(entries): + if evaluator(condition, entry): + # MRU: move the hit entry to the front for faster future lookups + if i > 0: + entries.insert(0, entries.pop(i)) + return entry + return None + + def find_reuse_entry_by_key( + self, fn_code: CodeType, hash_key: int + ) -> InvokeSubgraphReuseEntry | None: + return self.subgraph_reuse_key_cache.get(fn_code, {}).get(hash_key) + + def add_reuse_entry_by_key( + self, + fn_code: CodeType, + hash_key: int, + entry: InvokeSubgraphReuseEntry, + max_reuse_entries: int = 8, + ) -> None: + key_cache = self.subgraph_reuse_key_cache[fn_code] + if len(key_cache) >= max_reuse_entries and hash_key not in key_cache: + raise RuntimeError( + f"invoke_subgraph: exceeded maximum reuse entries " + f"({max_reuse_entries}) for function code {fn_code} (hash-key path). " + f"Increase the limit via the max_reuse_entries argument to " + f"nested_compile_region()." + ) + key_cache[hash_key] = entry + class HopDispatchSetCache: def __init__(self) -> None: @@ -1188,7 +1317,7 @@ def __reduce__(self): # The _hash is a cached value that can be nondeterministically computed # (e.g., based on id() of objects), so it should not affect pickling. fields = dataclasses.fields(self) - field_values = tuple(getattr(self, f.name) for f in fields) + field_values = tuple(getattr(self, f.name) for f in fields if f.init) return (self.__class__, field_values) new_cls.__hash__ = __hash__ @@ -1240,7 +1369,7 @@ def get_value( self, globals: dict[str, Any], locals: dict[str, Any], - cache: weakref.WeakKeyDictionary[Source, Any], + cache: dict[Source, Any], ) -> Any: if self in cache: return cache[self] @@ -1298,7 +1427,7 @@ def get_value( self, globals: dict[str, Any], locals: dict[str, Any], - cache: weakref.WeakKeyDictionary[Source, Any], + cache: dict[Source, Any], ) -> Any: if self in cache: return cache[self] diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index 29b2fc1ba4b33..09bd8d188112e 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -20,6 +20,7 @@ ) from torch._higher_order_ops.foreach_map import _foreach_map, foreach_map from torch._higher_order_ops.hints_wrap import hints_wrapper +from torch._higher_order_ops.inline_asm_elementwise import inline_asm_elementwise from torch._higher_order_ops.invoke_leaf_function import invoke_leaf_function from torch._higher_order_ops.invoke_subgraph import invoke_subgraph from torch._higher_order_ops.local_map import local_map_hop @@ -81,4 +82,5 @@ "local_map_hop", "print", "inductor_compiled_code", + "inline_asm_elementwise", ] diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index 7c2d8d9828f87..bf50d44a70e01 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -460,12 +460,17 @@ def can_auto_functionalize( # Skip schema returns -> None return True if isinstance(op, OpOverload): - # The returns of OpOverload must not alias anything - for ret in schema.returns: - if ret.alias_info is None and type(ret.type) is torch.TensorType: - continue - # Not yet supported: List[Tensor] return. - return False + if torch._library.utils.is_out(op): + # Out ops have aliased returns (returns alias the mutable args). + # This is fine because the mutable args are write-only output buffers. + pass + else: + # The returns of OpOverload must not alias anything + for ret in schema.returns: + if ret.alias_info is None and type(ret.type) is torch.TensorType: + continue + # Not yet supported: List[Tensor] return. + return False if torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), "Functionalize"): return False return True @@ -675,11 +680,74 @@ def _maybe_functionalize(name: str, arg: Any) -> Any: else: normalized_kwargs[arg.name] = arg.default_value - # List of the name of args that get mutated (according to the schema) + if isinstance(op, OpOverload) and torch._library.utils.is_out(op): + return _do_auto_functionalize_v2_for_out_operator( + ctx, op, schema, normalized_kwargs + ) + return _do_auto_functionalize_v2_for_generic_mutable_operator( + ctx, op, schema, normalized_kwargs + ) + + +def _do_auto_functionalize_v2_for_out_operator( + ctx: Any, + op: OpOverload, + schema: Any, + normalized_kwargs: dict[str, Any], +) -> Any: + """Handle functionalization for out= operators. + + Out= operators have write-only mutable args. These are not inputs to + auto_functionalized_v2; we encode their tensor properties so the dense + impl can create empty tensors. + """ + mutable_args_names, _ = get_mutable_args_from_schema(schema) + + # Save references to the original FunctionalTensors for post-call sync + out_arg_originals = [normalized_kwargs[name] for name in mutable_args_names] + + # Encode tensor properties and remove the actual out tensors from kwargs + for arg_name in mutable_args_names: + arg = normalized_kwargs[arg_name] + normalized_kwargs[f"_{arg_name}_size"] = arg.size() + normalized_kwargs[f"_{arg_name}_stride"] = arg.stride() + normalized_kwargs[f"_{arg_name}_dtype"] = arg.dtype + normalized_kwargs[f"_{arg_name}_device"] = arg.device + del normalized_kwargs[arg_name] + + unwrapped_kwargs = ctx.unwrap_tensors(normalized_kwargs) # type: ignore[arg-type] + auto_func_kwargs = dict(unwrapped_kwargs, _all_bases=[]) + + with ctx.redispatch_to_next(): + unwrapped_outs = auto_functionalized_v2( + op, + **auto_func_kwargs, # type: ignore[arg-type] + ) + + # For out= ops, the functional HOP returns exactly the out tensors. + # auto_functionalized_v2 returns a bare tensor for single return, + # tuple for multiple. + results = unwrapped_outs if isinstance(unwrapped_outs, tuple) else (unwrapped_outs,) + for orig_arg, result in zip(out_arg_originals, results, strict=True): + ctx.replace(orig_arg, result) + ctx.commit_update(orig_arg) + ctx.sync(orig_arg) + + return ctx.wrap_tensors(unwrapped_outs) # type: ignore[arg-type] + + +def _do_auto_functionalize_v2_for_generic_mutable_operator( + ctx: Any, + op: Any, + schema: Any, + normalized_kwargs: dict[str, Any], +) -> Any: + """Handle functionalization for generic mutable operators via the + auto_functionalized_v2 HOP.""" mutable_args_names, mutable_args_types = get_mutable_args_from_schema(schema) # A list of all bases of mutable args without duplication - all_bases = [] + all_bases: list[Any] = [] all_bases_addresses: list[int] = [] # Map arg_name to the index of its base in all_bases. @@ -922,11 +990,16 @@ def auto_functionalized_v2_dense( raise AssertionError(f"Expected HopSchema, got {type(schema)}") _callable_op = HopInstance(_mutable_op, schema) + _is_out = isinstance(_mutable_op, OpOverload) and torch._library.utils.is_out( + _mutable_op + ) + op_kwargs_new, all_bases_new = _generate_new_op_kwargs_from_bases( schema, kwargs, _all_bases, _only_clone_these_bases, + _is_out, ) out = call_op( @@ -935,6 +1008,9 @@ def auto_functionalized_v2_dense( op_kwargs_new, ) + if _is_out: + return out # type: ignore[return-value] + if isinstance(out, tuple): return (*out, *all_bases_new) # type: ignore[return-value] else: @@ -942,9 +1018,25 @@ def auto_functionalized_v2_dense( def _generate_new_op_kwargs_from_bases( - schema, kwargs, all_bases, _only_clone_these_bases + schema, kwargs, all_bases, _only_clone_these_bases, _is_out ): mutable_args_names, mutable_args_types = get_mutable_args_from_schema(schema) + + if _is_out: + # For out= ops, _all_bases is empty. Create empty tensors from the + # metadata that was encoded by _do_auto_functionalize_v2_for_out_operator. + new_kwargs = dict(**kwargs) + created_out_tensors = [] + for arg_name in mutable_args_names: + size = new_kwargs.pop(f"_{arg_name}_size") + stride = new_kwargs.pop(f"_{arg_name}_stride") + dtype = new_kwargs.pop(f"_{arg_name}_dtype") + device = new_kwargs.pop(f"_{arg_name}_device") + t = torch.empty_strided(size, stride, dtype=dtype, device=device) + new_kwargs[arg_name] = t + created_out_tensors.append(t) + return new_kwargs, created_out_tensors + args_view_info = read_view_information_from_args( mutable_args_names, mutable_args_types, kwargs, all_bases ) @@ -1031,6 +1123,7 @@ def auto_functionalized_v2_proxy( {k: v for k, v in kwargs.items() if k not in ("_all_bases", "_op_schema")}, all_bases, _only_clone_these_bases, + _is_out=False, ) _, materialized_kwargs = materialize_callable_in_args( diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index 208abce90a0b0..4f05618d38806 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -241,12 +241,9 @@ def _validate_input(pred, true_fn, false_fn, operands): def _cond_op_wrapper(*args, **kwargs): return cond_op(*args, **kwargs) - from torch._higher_order_ops.utils import setup_compilation_env + from torch._higher_order_ops.utils import _hop_compile_and_call - with setup_compilation_env() as backend: - return torch.compile(_cond_op_wrapper, backend=backend, fullgraph=True)( - pred, true_fn, false_fn, operands - ) + return _hop_compile_and_call(_cond_op_wrapper, (pred, true_fn, false_fn, operands)) def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index 7ecda5dcabd3b..9949aba8b69b2 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -86,7 +86,7 @@ def _permute_strides(out: torch.Tensor, query_strides: tuple[int, ...]) -> torch fill_order = [last_dim] + fill_order out_strides = _construct_strides(out.shape, fill_order) - new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides) + new_out = out.new_empty_strided(out.shape, out_strides) new_out.copy_(out) return new_out @@ -966,6 +966,17 @@ def sdpa_dense_backward( f"got query.dtype={query.dtype}, key.dtype={key.dtype}, " f"and value.dtype={value.dtype}" ) + if joint_graph is None: + example_vals = ( + query.new_zeros((), requires_grad=True), + query.new_zeros((), dtype=torch.int), + query.new_zeros((), dtype=torch.int), + query.new_zeros((), dtype=torch.int), + query.new_zeros((), dtype=torch.int), + ) + _, joint_graph = create_fw_bw_graph( + fw_graph, example_vals, score_mod_other_buffers + ) from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex Bq, Hq, seq_len_q, qk_head_dim = query.shape @@ -1012,7 +1023,9 @@ def _maybe_new_buffer( if grad_logsumexp is None: grad_logsumexp = torch.zeros_like(logsumexp) - # We're undoing the log -> log2 change of base in the forwards + # logsumexp is expected in log2 scale (as returned by the forward HOP). + # The public flex_attention API converts lse to natural log before returning, + # so callers using the public API must not pass that value here directly. logsumexp = logsumexp * math.log(2) # The backwards formula for the log -> log2 change of base in the forwards grad_logsumexp = grad_logsumexp / math.log(2) diff --git a/torch/_higher_order_ops/inline_asm_elementwise.py b/torch/_higher_order_ops/inline_asm_elementwise.py new file mode 100644 index 0000000000000..f23f8dc626503 --- /dev/null +++ b/torch/_higher_order_ops/inline_asm_elementwise.py @@ -0,0 +1,299 @@ +# mypy: allow-untyped-defs +import functools +import re + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._ops import HigherOrderOperator +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree + + +__all__ = ["inline_asm_elementwise"] + + +class InlineAsmElementwiseOp(HigherOrderOperator): + """Execute inline PTX assembly elementwise over tensors. + + This is an elementwise map where the function body is inline assembly. + Input tensors are implicitly broadcast to the same shape. + + Each invocation of the inline asm processes ``pack`` elements at a time. + Exactly which set of inputs a given invocation receives is unspecified. + + Output strides follow PyTorch's standard pointwise striding propagation + rules. + + In eager mode, the assembly is executed via the CUDA Jiterator. Under + ``torch.compile`` the assembly is lowered to Triton's + ``tl.inline_asm_elementwise`` via Inductor, which allows fusion with + surrounding operators. + + Args: + *inputs: Input tensors whose values are passed to the asm block. + asm_str: PTX assembly string. Operands use ``$N`` syntax + (e.g. ``$0`` for the first output, ``$1`` for the first input). + constraints: Inline-asm constraints in LLVM format. Output constraints + are prefixed with ``=`` (e.g. ``"=f,f,f"`` for one float output + and two float inputs). + dtype: Element type of the returned tensor. + is_pure: Must be ``True``. If true, the compiler may assume the asm + block has no side-effects. + pack: Number of elements processed per asm invocation. When + ``pack > 1``, the constraint string must list ``pack`` outputs + and ``pack`` copies of each input. Requires ``torch.compile``. + + Returns: + A tensor with the broadcast shape of the inputs and the given dtype. + + Example:: + + >>> # xdoctest: +SKIP(requires CUDA) + >>> # Float32 fused multiply-add via PTX + >>> result = inline_asm_elementwise( + ... a, b, c, + ... asm_str="fma.rn.f32 $0, $1, $2, $3;", + ... constraints="=f,f,f,f", + ... dtype=torch.float32, + ... ) + + >>> # xdoctest: +SKIP(requires CUDA) + >>> # pack=2: each asm invocation processes two elements + >>> result = inline_asm_elementwise( + ... x, + ... asm_str="mov.b32 $0, $2; mov.b32 $1, $3;", + ... constraints="=r,=r,r,r", + ... dtype=torch.float32, + ... pack=2, + ... ) + """ + + def __init__(self): + super().__init__("inline_asm_elementwise") + + def __call__( + self, + *inputs: torch.Tensor, + asm_str: str, + constraints: str, + dtype: torch.dtype, + is_pure: bool = True, + pack: int = 1, + ) -> torch.Tensor: + if not is_pure: + raise ValueError("inline_asm_elementwise only supports is_pure=True") + # pyrefly: ignore [missing-attribute] + return super().__call__( + *inputs, + asm_str=asm_str, + constraints=constraints, + dtype=dtype, + is_pure=True, + pack=pack, + ) + + +inline_asm_elementwise = InlineAsmElementwiseOp() + + +def _parse_constraints(constraints: str) -> tuple[int, int]: + parts = [p.strip() for p in constraints.split(",")] + n_outputs = sum(1 for p in parts if p.startswith("=")) + n_inputs = len(parts) - n_outputs + return n_outputs, n_inputs + + +_DTYPE_TO_CUDA_TYPE = { + torch.float32: "float", + torch.float64: "double", + torch.float16: "__half", + torch.bfloat16: "__nv_bfloat16", + torch.int32: "int", + torch.int64: "long long", + torch.int16: "short", + torch.int8: "signed char", + torch.uint8: "unsigned char", + torch.uint16: "unsigned short", + torch.uint32: "unsigned int", + torch.bool: "bool", +} + + +_TRITON_ARG_RE = re.compile(r"\$(\d+)") + + +def _triton_asm_to_cuda_asm(asm_str: str) -> str: + return _TRITON_ARG_RE.sub(r"%\1", asm_str) + + +@functools.lru_cache +def _get_jiterator_fn( + asm_str: str, + constraints: str, + n_inputs: int, + input_dtype: torch.dtype, + output_dtype: torch.dtype, +): + from torch.cuda.jiterator import _create_jit_fn + + cuda_asm = _triton_asm_to_cuda_asm(asm_str) + + constraint_parts = [p.strip() for p in constraints.split(",")] + output_constraints = [p.lstrip("=") for p in constraint_parts if p.startswith("=")] + input_constraints = [p for p in constraint_parts if not p.startswith("=")] + + if input_dtype not in _DTYPE_TO_CUDA_TYPE: + raise ValueError(f"Unsupported input dtype for inline asm: {input_dtype}") + if output_dtype not in _DTYPE_TO_CUDA_TYPE: + raise ValueError(f"Unsupported output dtype for inline asm: {output_dtype}") + + input_type = _DTYPE_TO_CUDA_TYPE[input_dtype] + output_type = _DTYPE_TO_CUDA_TYPE[output_dtype] + + input_params = ", ".join(f"{input_type} in{i}" for i in range(n_inputs)) + out_constraints_str = ", ".join(f'"={c}"(result)' for c in output_constraints) + in_constraints_str = ", ".join( + f'"{c}"(in{i})' for i, c in enumerate(input_constraints) + ) + escaped_asm = ( + cuda_asm.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n") + ) + + code = f""" +template +{output_type} inline_asm_kernel({input_params}) {{ + {output_type} result; + asm volatile ( + "{escaped_asm}" + : {out_constraints_str} + : {in_constraints_str} + ); + return result; +}} +""" + + return _create_jit_fn(code) + + +def _inline_asm_dense(*inputs, asm_str, constraints, dtype, is_pure, pack): + if not inputs: + raise ValueError("inline_asm_elementwise requires at least one input tensor") + + inputs = torch.broadcast_tensors(*inputs) + + if not inputs[0].is_cuda: + raise RuntimeError("inline_asm_elementwise only supports CUDA tensors") + + if pack > 1: + raise RuntimeError( + "inline_asm_elementwise with pack > 1 requires torch.compile" + ) + + n_outputs, n_inputs = _parse_constraints(constraints) + + if n_outputs != 1: + raise ValueError(f"Expected 1 output constraint, got {n_outputs}") + + if n_inputs != len(inputs): + raise ValueError( + f"Constraint string specifies {n_inputs} inputs but got " + f"{len(inputs)} tensor(s)" + ) + + # Jiterator generates a single input type for all inputs — mixed dtypes + # would produce incorrect CUDA code. + input_dtypes = {inp.dtype for inp in inputs} + if len(input_dtypes) > 1: + raise ValueError( + f"All inputs must have the same dtype for eager execution, " + f"got {sorted(str(d) for d in input_dtypes)}" + ) + + jit_fn = _get_jiterator_fn( + asm_str=asm_str, + constraints=constraints, + n_inputs=len(inputs), + input_dtype=inputs[0].dtype, + output_dtype=dtype, + ) + + return jit_fn(*inputs) + + +@inline_asm_elementwise.py_impl(DispatchKey.CompositeExplicitAutograd) +def _(*inputs, asm_str, constraints, dtype, is_pure=True, pack=1): + return _inline_asm_dense( + *inputs, + asm_str=asm_str, + constraints=constraints, + dtype=dtype, + is_pure=is_pure, + pack=pack, + ) + + +inline_asm_elementwise.py_autograd_impl( + autograd_not_implemented(inline_asm_elementwise, deferred_error=True) +) + + +def _elementwise_output_like(*inputs, dtype): + from torch._prims_common import compute_elementwise_output_logical_to_physical_perm + + broadcasted = torch.broadcast_tensors(*inputs) + l2p_perm, _ = compute_elementwise_output_logical_to_physical_perm(*broadcasted) + return torch.empty_permuted( + broadcasted[0].shape, l2p_perm, dtype=dtype, device=broadcasted[0].device + ) + + +@inline_asm_elementwise.py_impl(FakeTensorMode) +def _(mode, *inputs, asm_str, constraints, dtype, is_pure=True, pack=1): + with mode: + return _elementwise_output_like(*inputs, dtype=dtype) + + +@inline_asm_elementwise.py_impl(ProxyTorchDispatchMode) +def _(mode, *inputs, asm_str, constraints, dtype, is_pure=True, pack=1): + proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, inputs) + + out_proxy = mode.tracer.create_proxy( + "call_function", + inline_asm_elementwise, + proxy_args, + { + "asm_str": asm_str, + "constraints": constraints, + "dtype": dtype, + "is_pure": is_pure, + "pack": pack, + }, + name="inline_asm_elementwise", + ) + + out = inline_asm_elementwise( + *inputs, + asm_str=asm_str, + constraints=constraints, + dtype=dtype, + is_pure=is_pure, + pack=pack, + ) + return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer) + + +@inline_asm_elementwise.py_functionalize_impl +def _(ctx, *inputs, asm_str, constraints, dtype, is_pure=True, pack=1): + unwrapped_inputs = ctx.unwrap_tensors(inputs) + + with ctx.redispatch_to_next(): + res = inline_asm_elementwise( + *unwrapped_inputs, + asm_str=asm_str, + constraints=constraints, + dtype=dtype, + pack=pack, + ) + return ctx.wrap_tensors(res) diff --git a/torch/_higher_order_ops/invoke_leaf_function.py b/torch/_higher_order_ops/invoke_leaf_function.py index bd8822809e161..a4f17b182d90f 100644 --- a/torch/_higher_order_ops/invoke_leaf_function.py +++ b/torch/_higher_order_ops/invoke_leaf_function.py @@ -150,7 +150,7 @@ def _lms_to_attr_dict(val: Any) -> Any: indices: list[int] = [] for expr in mutates_args: # Empty __builtins__ prevents access to builtins like __import__, open, exec. - result = eval(expr, {"__builtins__": {}}, namespace) # noqa: S307 + result = eval(expr, {"__builtins__": {}}, namespace) leaves = pytree.tree_leaves(result) for sentinel in leaves: if not isinstance(sentinel, int): @@ -520,7 +520,7 @@ def __call__( input_spec, mutated_arg_indices, *flat_args, - requires_grad_indices=(), + requires_grad_indices="", ): """ real_fn_callable: _LeafCallable wrapping the real function @@ -529,7 +529,8 @@ def __call__( mutated_arg_indices: comma-separated string of flat-arg indices that are declared as mutated (e.g. "1,2"), or "" for no mutations. Encoded as a string so it is a pytree leaf for the HOP schema infrastructure. - requires_grad_indices: tuple of indices for inputs that require grad + requires_grad_indices: comma-separated string of flat-arg indices that + require grad (e.g. "0,1"), or "" for none. """ return super().__call__( # type: ignore[attr-defined] real_fn_callable, @@ -548,7 +549,7 @@ def gen_schema( input_spec, mutated_arg_indices, *flat_args, - requires_grad_indices=(), + requires_grad_indices="", ): from torch._higher_order_ops.schema import HopSchemaGenerator from torch._higher_order_ops.utils import _maybe_fake_prop_ignore_unbacked @@ -583,6 +584,12 @@ def run_fake(*unfunc_flat_args): gen.add_arg("mutated_arg_indices", mutated_arg_indices) for i, arg in enumerate(flat_args): gen.add_arg(f"arg{i}", arg, is_mutated=i in mutated_set) + gen.add_arg( + "requires_grad_indices", + requires_grad_indices, + default_value="", + kw_only=True, + ) if isinstance(fake_outputs, tuple): for out in fake_outputs: @@ -641,8 +648,8 @@ def forward( include_keys = torch._C._dispatch_tls_local_include_set() exclude_keys = torch._C._dispatch_tls_local_exclude_set() - requires_grad_indices = tuple( - i + requires_grad_indices = ",".join( + str(i) for i, arg in enumerate(flat_args) if isinstance(arg, torch.Tensor) and arg.requires_grad ) @@ -698,6 +705,41 @@ def fake_backward(*grads): requires_grad_indices=requires_grad_indices, ) + hook_real = getattr(real_fn_callable, "_leaf_hook_real_fn", None) + hook_fake = getattr(real_fn_callable, "_leaf_hook_fake_fn", None) + if hook_real is not None: + assert hook_fake is not None # noqa: S101 + hook_captured_out_spec: list[pytree.TreeSpec | None] = [None] + wrapped_hook_real, wrapped_hook_fake = make_leaf_function_wrappers( + hook_real, hook_fake, hook_captured_out_spec + ) + hook_real_callable = _LeafCallable(wrapped_hook_real) + hook_fake_callable = _LeafCallable(wrapped_hook_fake) + + grad_tensors = [ + arg + for arg in flat_args + if isinstance(arg, torch.Tensor) and arg.requires_grad + ] + if grad_tensors: + + @torch._dynamo.disable + def _multi_grad_callback( + grads: Sequence[torch.Tensor], + ) -> None: + _, hook_spec = pytree.tree_flatten((tuple(grads), {})) + invoke_leaf_function( + hook_real_callable, + hook_fake_callable, + hook_spec, + "", + *grads, + ) + + torch.autograd.graph.register_multi_grad_hook( + grad_tensors, _multi_grad_callback + ) + ctx.real_backward = real_backward ctx.fake_backward = fake_backward @@ -722,7 +764,7 @@ def invoke_leaf_function_autograd( input_spec, mutated_arg_indices, *flat_args, - requires_grad_indices=(), + requires_grad_indices="", ): return InvokeLeafFunctionAutogradOp.apply( real_fn_callable, fake_fn_callable, input_spec, mutated_arg_indices, *flat_args @@ -848,7 +890,7 @@ def invoke_leaf_function_fake( input_spec, mutated_arg_indices, *flat_args, - requires_grad_indices=(), + requires_grad_indices="", ): with unflatten_args_with_modules(flat_args, input_spec) as (args, kwargs): return fake_fn_callable(*args, **kwargs) @@ -861,7 +903,7 @@ def invoke_leaf_function_dense( input_spec, mutated_arg_indices, *flat_args, - requires_grad_indices=(), + requires_grad_indices="", ): from torch._dynamo import config as dynamo_config @@ -872,7 +914,7 @@ def invoke_leaf_function_dense( flat_args = tuple( arg.detach() if isinstance(arg, torch.Tensor) else arg for arg in flat_args ) - requires_grad_indices_set = set(requires_grad_indices) + requires_grad_indices_set = _parse_mutated_arg_indices(requires_grad_indices) flat_args = tuple( arg.requires_grad_(True) if idx in requires_grad_indices_set else arg for idx, arg in enumerate(flat_args) diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py index 2c03ebc6f6f51..6a79caeb6315a 100644 --- a/torch/_higher_order_ops/invoke_subgraph.py +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -252,19 +252,19 @@ def invoke_subgraph_placeholder(func, *args, **kwargs): def _invoke_subgraph_placeholder_wrapper(func, args): return invoke_subgraph_placeholder(func, *args) - from torch._higher_order_ops.utils import setup_compilation_env + from torch._higher_order_ops.utils import _hop_compile_and_call - with setup_compilation_env() as backend: - return torch.compile( - _invoke_subgraph_placeholder_wrapper, - backend=backend, - fullgraph=True, - )(func, args) + return _hop_compile_and_call(_invoke_subgraph_placeholder_wrapper, (func, args)) return func(*args, **kwargs) -def mark_compile_region(fn=None, options: NestedCompileRegionOptions | None = None): +def mark_compile_region( + fn=None, + options: NestedCompileRegionOptions | None = None, + max_reuse_entries: int = 8, + reuse_hash_fn=None, +): """ This wrapper instructs torch.compile to compile the wrapped region once and reuse the compiled artifact, instead of the usual way of aggressively @@ -278,6 +278,9 @@ def mark_compile_region(fn=None, options: NestedCompileRegionOptions | None = No options: Optional config to use for compiling the subgraph. Warning: this is an experimental feature under development and not ready for use yet. + max_reuse_entries: Maximum number of reuse cache entries per function + before raising an error. If this limit is hit, guards keep failing + across invocations and hierarchical compilation is not effective. """ def wrap(func): @@ -290,6 +293,8 @@ def inner(*args, **kwargs): inner.__marked_compile_region_fn__ = func # type: ignore[attr-defined] func.__marked_compile_region_config__ = options # type: ignore[attr-defined] + func.__marked_compile_region_max_reuse_entries__ = max_reuse_entries # type: ignore[attr-defined] + func.__marked_compile_region_reuse_hash_fn__ = reuse_hash_fn # type: ignore[attr-defined] return inner diff --git a/torch/_higher_order_ops/local_map.py b/torch/_higher_order_ops/local_map.py index efc99a9755e57..e691eda36b057 100644 --- a/torch/_higher_order_ops/local_map.py +++ b/torch/_higher_order_ops/local_map.py @@ -477,8 +477,12 @@ def forward( saved_activations = fw_outs_with_saved_activations[num_fw_outs:] save_values_for_backward(ctx, saved_activations) + # Force memory_format path (not exact size/stride) because local_map forward + # operates on local shapes but backward receives global-shaped tangents. + # TODO(ivankobzarev): Support exact size/stride by converting between local/global shapes. ctx.expected_tangent_metadata = { - i: MemoryFormatMeta.from_tensor(fw_outs[i]) for i in filtered_grads_idx + i: MemoryFormatMeta.from_tensor(fw_outs[i], force_use_memory_format=True) + for i in filtered_grads_idx } return fw_outs diff --git a/torch/_higher_order_ops/partitioner.py b/torch/_higher_order_ops/partitioner.py index 3517ca2b35127..f580a464fdf14 100644 --- a/torch/_higher_order_ops/partitioner.py +++ b/torch/_higher_order_ops/partitioner.py @@ -65,7 +65,7 @@ def _check_partition_boundary(self) -> None: if len(fw_outputs) != self.n_fw_outputs + self.n_intermediates: invalid_reasons.append( - f"len(fw_outputs) ({len(fw_outputs)}) != n_fw_outputs ({self.n_fw_outputs}) + n_intermediates ({self.n_intermediates})" # noqa: B950 + f"len(fw_outputs) ({len(fw_outputs)}) != n_fw_outputs ({self.n_fw_outputs}) + n_intermediates ({self.n_intermediates})" ) bw_phs = list(self.bw_gm.graph.find_nodes(op="placeholder")) diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index 40fbf14d64b33..0d7015f266316 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -298,22 +298,26 @@ def _scan(init, xs): return carry, [] num_elems = xs[0].shape[dim] - ind = 0 - - # Compute dummy shapes for the pre-allocation num_init_leaves = len(init) - dummy_carry, dummy_out = _extract_carry_and_out( + + # Process element 0 to infer output shapes for pre-allocation + # AND produce the first real result in a single call. The previous + # approach used first_slice_copy() for shape inference and then + # re-processed element 0 in the main loop, calling the operator + # num_elems+1 times. That extra invocation is incorrect for + # operators with side effects. + carry, out_0 = _extract_carry_and_out( call_operator( operator, *carry, - *[first_slice_copy(elem, dim) for elem in xs], + *[elem.select(dim, 0) for elem in xs], *additional_inputs, ), num_init_leaves, ) - out_tensor_mask = get_tensor_mask(dummy_out) - dummy_out_masked = mask_list(out_tensor_mask, dummy_out) + out_tensor_mask = get_tensor_mask(out_0) + out_0_masked = mask_list(out_tensor_mask, out_0) # Pre-allocate # outs -> Output matrix @@ -321,16 +325,15 @@ def _scan(init, xs): # out: (num_elems, M, N, ...) # idx: (1, M, N) outs = [ - torch.zeros( + torch.empty( [num_elems] + list(e.size()), dtype=e.dtype, device=e.device, ) - for i, e in enumerate(dummy_out_masked) + for e in out_0_masked ] idxs = [ - torch.ones_like(e, dtype=torch.int64).unsqueeze(0) - for i, e in enumerate(dummy_out_masked) + torch.ones_like(e, dtype=torch.int64).unsqueeze(0) for e in out_0_masked ] def store_out_in_outs(out, ind): @@ -342,20 +345,21 @@ def store_out_in_outs(out, ind): # essentially: o[ind][n][k] = x[0][n][k] o.scatter_(0, ind * idx, x.unsqueeze(0)) - for i in range(num_elems): - ind = i + # Store element 0's result, then continue from element 1. + store_out_in_outs(out_0_masked, 0) + + for i in range(1, num_elems): carry, out = _extract_carry_and_out( call_operator( operator, *carry, - *[elem.select(dim, ind) for elem in xs], + *[elem.select(dim, i) for elem in xs], *additional_inputs, ), num_init_leaves, ) - # Store the inits in the outs matrix. - store_out_in_outs(mask_list(out_tensor_mask, out), ind) + store_out_in_outs(mask_list(out_tensor_mask, out), i) # Expand outs with None depending on the tensor mask of the output outs_expanded = [outs.pop(0) if out_m else None for out_m in out_tensor_mask] diff --git a/torch/_higher_order_ops/strict_mode.py b/torch/_higher_order_ops/strict_mode.py index f9ae59c37bc17..ac0f2952f2c23 100644 --- a/torch/_higher_order_ops/strict_mode.py +++ b/torch/_higher_order_ops/strict_mode.py @@ -22,12 +22,9 @@ def strict_mode(callable, operands): if torch.compiler.is_dynamo_compiling(): return strict_mode_op(callable, operands) - from torch._higher_order_ops.utils import setup_compilation_env + from torch._higher_order_ops.utils import _hop_compile_and_call - with setup_compilation_env() as backend: - return torch.compile(strict_mode_op, backend=backend, fullgraph=True)( - callable, operands - ) + return _hop_compile_and_call(strict_mode_op, (callable, operands)) class StrictMode(HigherOrderOperator): diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index bdbc435192405..ca4354ae41079 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -44,6 +44,7 @@ from torch._dynamo.symbolic_convert import InstructionTranslator from torch._dynamo.variables.constant import ConstantVariable from torch._dynamo.variables.functions import TritonKernelVariable + from torch._guards import Source from torch._inductor.dependencies import ReadWrites from torch._subclasses.functional_tensor import BaseFunctionalizeAPI from torch.fx.proxy import Proxy @@ -334,10 +335,9 @@ def is_stable_tensor_descriptor_arg(arg: Any) -> bool: return True return False - def is_tensor_like_arg(arg: Any) -> bool: - if isinstance(arg, Tensor) or is_stable_tensor_descriptor_arg(arg): - return True - return False + def _is_constexpr_or_none(name: str, arg: Any) -> bool: + param_idx = kernel.arg_names.index(name) + return kernel.params[param_idx].is_constexpr or arg is None # Note: one would expect that each input to the triton kernel maps to # one input parameter in the TTIR. This is _not_ true for TMA descriptors: @@ -347,9 +347,15 @@ def is_tensor_like_arg(arg: Any) -> bool: # * N sizes, for a rank-N tensor # To account for this, we inject some fake arg names as placeholders for # the stride and size parameters. - def get_tensor_names(name: str, arg: Any) -> list[str]: - if isinstance(arg, Tensor): - return [name] + # + # Additionally, tensors and scalars are both included as TTIR parameters, + # whereas `constexpr` are inlined, and None are excluded. We both preserve + # scalars and tensors as this matters for "odd" ordering, + # eg. [tensor, scalar, tensor]. + def get_arg_names(name: str, arg: Any) -> list[str]: + if _is_constexpr_or_none(name, arg): + return [] + if is_stable_tensor_descriptor_arg(arg): stable_meta = maybe_unpack_tma_stable_metadata( tma_descriptor_metadata[name] @@ -362,11 +368,12 @@ def get_tensor_names(name: str, arg: Any) -> list[str]: names.extend(name + f" STRIDE PLACEHOLDER {i}" for i in range(tensor_rank)) names.extend(name + f" SIZE PLACEHOLDER {i}" for i in range(tensor_rank)) return names - return [] - ordered_tensor_names = list( + return [name] + + ordered_arg_names = list( itertools.chain.from_iterable( - get_tensor_names(name, arg) for name, arg in ordered_args.items() + get_arg_names(name, arg) for name, arg in ordered_args.items() ) ) @@ -456,8 +463,12 @@ def _native_specialize_impl( return attrs specialization = _get_specialization(ordered_args.values()) + # Triton explicitly interprets ASTSource.constants entries as constexpr + # Thus, only None and arguments marked `is_constexpr` should be treated as such. constants = { - name: arg for name, arg in ordered_args.items() if not is_tensor_like_arg(arg) + name: arg + for name, arg in ordered_args.items() + if _is_constexpr_or_none(name, arg) } if (mangle_type := getattr(triton.runtime.jit, "mangle_type", None)) is not None: @@ -529,7 +540,7 @@ def get_signature_value(idx: int, arg: Any) -> str: if not ttir_module.verify(): raise RuntimeError("Verification for TTIR module has failed") - return ttir_module, ordered_tensor_names + return ttir_module, ordered_arg_names def ttir_to_functions( @@ -904,6 +915,7 @@ def analyze_kernel_access( fn_name: str, num_args: int, tensor_names: tuple[str, ...], + tensor_arg_indices: frozenset[int] | None, ) -> TensorAccesses: """ Analyzes the graph to detect which arguments are written to and which are read. @@ -978,12 +990,16 @@ def analyze_kernel_access( ) # Create placeholder names for nested function arguments nested_names = tuple(f"_arg{i}" for i in range(len(op.args))) + + # Do not pass tensor_arg_indices, most outer call of + # analyze_kernel_access will filter Param nodes. accesses = analyze_kernel_access( functions, # pyrefly: ignore [bad-argument-type] op.fn_call_name, len(op.args), nested_names, + None, ) # Map back from StarDep names to args written_set = {dep.name for dep in accesses.read_writes.writes} @@ -997,6 +1013,15 @@ def analyze_kernel_access( write_stack.extend(op.args[idx] for idx in WRITE_OPS.get(op.name, [])) read_stack.extend(op.args[idx] for idx in READ_OPS.get(op.name, [])) + # For these ops, only the first argument (base pointer) refers to actual + # memory. The remaining arguments are shape/stride/offset metadata and + # should not be traced during mutation analysis. + POINTER_ONLY_OPS = { + "tt.make_tensor_ptr", + "tt.advance", + "tt.make_tensor_descriptor", + } + def _find_arg_access_count( initial_stack: list[Param | Intermediate], skip_loads: bool, @@ -1011,6 +1036,8 @@ def _find_arg_access_count( if isinstance(arg, Param): if arg.idx >= num_args: continue + if tensor_arg_indices is not None and arg.idx not in tensor_arg_indices: + continue if arg.idx not in access_count: access_count[arg.idx] = 1 else: @@ -1019,7 +1046,10 @@ def _find_arg_access_count( for op in ops[arg]: if skip_loads and op.name == "tt.load": continue - stack.extend(op.args) + if op.name in POINTER_ONLY_OPS: + stack.append(op.args[0]) + else: + stack.extend(op.args) return access_count @@ -1072,12 +1102,14 @@ def identify_accessed_tensors( 2) Parses the TTIR and creates a control flow graph 3) Analyzes the graph to detect which input tensors are read and/or written """ + from torch._inductor.dependencies import Dep, ReadWrites, StarDep + from torch._inductor.ir import TensorBox ttir_module = None functions = None try: - ttir_module, ordered_tensor_names = generate_ttir( + ttir_module, ordered_arg_names = generate_ttir( kernel, kwargs, tma_descriptor_metadata ) @@ -1099,21 +1131,30 @@ def identify_accessed_tensors( # detection, so each top level invocation needs a clean cache analyze_kernel_access.reset() get_tma_stores.reset() + + # Build frozenset of indices corresponding to tensor args only. + # Used to filter out scalars which are transitively captured as mutated + # during traversal. + tensor_arg_indices = frozenset( + i + for i, name in enumerate(ordered_arg_names) + if isinstance(kwargs.get(name), (Tensor, TensorBox)) + ) + return analyze_kernel_access( functions, kernel_name, - len(ordered_tensor_names), - tuple(ordered_tensor_names), + len(ordered_arg_names), + tuple(ordered_arg_names), + tensor_arg_indices, ) except Exception: - import torch._inductor.ir - log.warning( "Encountered an exception in identify_accessed_tensors, assuming every input is mutated", exc_info=True, ) if ttir_module is not None: - log.debug("TTIR:\n%s", str(ttir_module)) + log.debug("TTIR:\n%s", ttir_module) if functions is not None: log.debug("functions:") for name, fn in functions.items(): @@ -1124,7 +1165,7 @@ def identify_accessed_tensors( all_tensor_names = [ key for key, value in kwargs.items() - if isinstance(value, (Tensor, torch._inductor.ir.TensorBox)) + if isinstance(value, (Tensor, TensorBox)) ] all_deps = OrderedSet(StarDep(name) for name in all_tensor_names) all_deps = typing.cast(OrderedSet[Dep], all_deps) @@ -1158,8 +1199,10 @@ def identify_triton_stores(source_code: str) -> TritonStores: tl.store signature: store(pointer, value, mask=None, boundary_check=(), ...) """ + return identify_triton_stores_from_ast(ast.parse(source_code)) + - tree = ast.parse(source_code) +def identify_triton_stores_from_ast(tree: ast.Module) -> TritonStores: stores = [] def _extract_arg(node, arg_name, positional_index): @@ -1298,7 +1341,7 @@ def triton_kernel_wrapper_mutation_dense( for k, v in tma_descriptor_metadata.items(): tensor = kwargs[k] if (exp_meta := maybe_unpack_tma_experimental_metadata(v)) is not None: - from triton.tools.experimental_descriptor import ( # noqa: F401 + from triton.tools.experimental_descriptor import ( create_1d_tma_descriptor, create_2d_tma_descriptor, ) @@ -1435,7 +1478,13 @@ def get_mutated_tensors( tensor_accesses = identify_accessed_tensors( kernel, {**kwargs, **constant_args}, tma_descriptor_metadata ) - return [dep.name for dep in tensor_accesses.read_writes.writes] + # Filter to only tensor kwargs: with Triton 3.7+, ordered_arg_names + # includes scalars, so writes may reference non-tensor args like SymInts. + return [ + dep.name + for dep in tensor_accesses.read_writes.writes + if isinstance(kwargs.get(dep.name), Tensor) + ] @triton_kernel_wrapper_mutation.py_functionalize_impl @@ -1691,7 +1740,7 @@ def do_prune_configs( # type: ignore[no-untyped-def] kwargs: dict, ) -> list["TritonConfig"]: # Reimplement autotuner.prune_configs(...) here - # see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # noqa: E501,B950 + # see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # We do this to avoid calling prune_configs, which in turn calls early_config_prune and perf_model # These are both user-defined functions which can contain side effects, so we want to sandbox them in Dynamo @@ -1818,6 +1867,30 @@ def defaults_ok( "pre_hook and post_hook are not supported in triton.Autotune or triton.Config" ) + @staticmethod + def get_kernel_source( + variable: "TritonKernelVariable | TraceableTritonKernelWrapper", + ) -> "Source | None": + kernel_source = getattr(variable, "kernel_source", None) + if kernel_source is None: + kernel_source = getattr(variable, "source", None) + return kernel_source + + def recreate_variable( + self, + variable: "TritonKernelVariable | TraceableTritonKernelWrapper", + *, + kernel: "TritonKernelType", + kernel_idx: int | None, + grid: "TritonGridType | None", + ) -> "TritonKernelVariable | TraceableTritonKernelWrapper": + return type(variable)( + kernel=kernel, + kernel_idx=kernel_idx, + grid=grid, + kernel_source=self.get_kernel_source(variable), + ) + def call_getitem( self, variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"], @@ -1830,7 +1903,8 @@ def call_getitem( "Triton kernels should be called with only a single grid" ) - return type(variable)( + return self.recreate_variable( + variable, kernel=variable.kernel, kernel_idx=variable.kernel_idx, grid=args[0], @@ -1849,8 +1923,11 @@ def call_run( kwargs.pop("warmup", None) # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args) return self.call_triton_kernel( - type(variable)( - kernel=variable.kernel, kernel_idx=variable.kernel_idx, grid=grid + self.recreate_variable( + variable, + kernel=variable.kernel, + kernel_idx=variable.kernel_idx, + grid=grid, ), args, kwargs, @@ -1955,7 +2032,12 @@ def call_triton_kernel( )(iter_kernel) # create a new variable to contain the new (wrapped) kernel; # skip kernel_idx to get a new record in the kernel side table - new_var = type(variable)(new_kernel, None, variable.grid) + new_var = self.recreate_variable( + variable, + kernel=new_kernel, + kernel_idx=None, + grid=variable.grid, + ) return self.call_triton_kernel(new_var, args, kwargs, tx) SPECIAL_CONFIG_NAMES = { @@ -2000,7 +2082,12 @@ def call_triton_kernel( # create a new variable to contain the new (wrapped) kernel; # skip kernel_idx to get a new record in the kernel side table - new_var = type(variable)(new_kernel, None, variable.grid) + new_var = self.recreate_variable( + variable, + kernel=new_kernel, + kernel_idx=None, + grid=variable.grid, + ) return self.call_triton_kernel(new_var, args, kwargs, tx) if isinstance(variable.kernel, Autotuner): @@ -2040,11 +2127,16 @@ def call_triton_kernel( new_kernel = autotune( configs=new_configs, prune_configs_by=prune_configs_by, key=[] )(variable.kernel.fn) - new_var = type(variable)(new_kernel, None, variable.grid) + new_var = self.recreate_variable( + variable, + kernel=new_kernel, + kernel_idx=None, + grid=variable.grid, + ) return self.call_triton_kernel(new_var, args, kwargs, tx) # These are the default values in upstream Triton - # see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py # noqa: E501,B950 + # see: https://github.com/triton-lang/triton/blob/e57b46897191b3b3061c78d0d60e58e94be565b6/python/triton/runtime/autotuner.py default_perf_model = None default_early_config_prune = None @@ -2100,7 +2192,12 @@ def call_triton_kernel( new_kernel = autotune(configs=pruned_configs, key=[])(variable.kernel.fn) # create a new variable to contain the new (wrapped) kernel; # skip kernel_idx to get a new record in the kernel side table - new_var = type(variable)(new_kernel, None, variable.grid) + new_var = self.recreate_variable( + variable, + kernel=new_kernel, + kernel_idx=None, + grid=variable.grid, + ) return self.call_triton_kernel(new_var, args, kwargs, tx) # Both for grid's meta as well as for the kernel, we need combined @@ -2308,15 +2405,18 @@ class TraceableTritonKernelWrapper: kernel: "TritonKernelType" kernel_idx: int | None grid: Optional["TritonGridType"] + kernel_source: "Source | None" def __init__( self, kernel: "TritonKernelType", kernel_idx: int | None, grid: Optional["TritonGridType"], + kernel_source: "Source | None" = None, ) -> None: self.kernel = None self.grid = None + self.kernel_source = kernel_source tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) if self.kernel is None: raise AssertionError("kernel was not initialized properly") diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 7db6a25fdbc68..9c198fe7038c6 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -96,10 +96,29 @@ def graph_with_interpreter(*args): return maybe_interpreted_fn +def _hop_compile_and_call(fn, args, kwargs=None): + """Compile and call fn with fullgraph=True for HOP eager execution. + + Pre-activates the fullgraph counter so that compile_wrapper treats this as + a nested compile. This avoids erroring when a non-infra dispatch mode + causes the frame to be skipped — the function still executes eagerly within + compile_wrapper and returns normally. + """ + from torch._dynamo.eval_frame import set_fullgraph_compiled_frame_count + + with setup_compilation_env() as backend: + old_count = set_fullgraph_compiled_frame_count(0) + try: + return torch.compile(fn, backend=backend, fullgraph=True)( + *args, **(kwargs or {}) + ) + finally: + set_fullgraph_compiled_frame_count(old_count) + + def _maybe_compile_and_run_fn(fn, *args): if not torch.compiler.is_dynamo_compiling(): - with setup_compilation_env() as backend: # type: ignore[attr-defined] - return torch.compile(fn, backend=backend, fullgraph=True)(*args) + return _hop_compile_and_call(fn, args) else: return fn(*args) @@ -466,9 +485,9 @@ def _check_alias_and_mutation(graph_module, inputs_fake, name, pre_dispatch): graph_module, inputs_fake, pre_dispatch=pre_dispatch ) if aliases: - raise RuntimeError(f"{name} might be aliasing the input or the output!") # noqa: F541 + raise RuntimeError(f"{name} might be aliasing the input or the output!") if inp_mutation: - raise RuntimeError(f"{name} might be modifying the input!") # noqa: F541 + raise RuntimeError(f"{name} might be modifying the input!") def unique_graph_id(proxy_mode, prefix): @@ -1030,8 +1049,11 @@ def check_input_alias_and_mutation_return_outputs( def _get_example_value(n): if not isinstance(n, torch.fx.Node): return n - else: - return n.meta["val"] if "val" in n.meta else n.meta["example_value"] + if "val" in n.meta: + return n.meta["val"] + if "example_value" in n.meta: + return n.meta["example_value"] + return None fake_args = [ _get_example_value(n) diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index 9cc617079c367..5484c7ab6e486 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -248,12 +248,12 @@ def _validate_input(cond_fn, body_fn, carried_inputs): def _while_loop_op_wrapper(*args, **kwargs): return while_loop_op(*args, **kwargs) - from torch._higher_order_ops.utils import setup_compilation_env + from torch._higher_order_ops.utils import _hop_compile_and_call - with setup_compilation_env() as backend: - return torch.compile(_while_loop_op_wrapper, backend=backend, fullgraph=True)( - flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple() - ) + return _hop_compile_and_call( + _while_loop_op_wrapper, + (flat_cond_fn, flat_body_fn, tuple(flat_inputs), tuple()), + ) @while_loop_op.py_impl(DispatchKey.CompositeExplicitAutograd) diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py index 756b2d7a3cd33..8268d724609dd 100644 --- a/torch/_higher_order_ops/wrap.py +++ b/torch/_higher_order_ops/wrap.py @@ -68,9 +68,9 @@ class InductorCompiledCode(HigherOrderOperator): def __init__(self) -> None: super().__init__("inductor_compiled_code") - def __call__(self, func, *args, **kwargs): + def __call__(self, func, inputs, *, name: str | None = None): # pyrefly: ignore [missing-attribute] - return super().__call__(func, *args, **kwargs) + return super().__call__(func, inputs, name=name) inductor_compiled_code = InductorCompiledCode() @@ -88,10 +88,16 @@ class InductorCompiledCallable: Each instance gets a globally unique idx at creation (via atomic itertools.count). """ - def __init__(self, compiled_callable, original_gm=None): + def __init__( + self, + compiled_callable, + original_gm=None, + compile_region_name: str | None = None, + ): self.idx = next(_inductor_compiled_callable_id) self.compiled_callable = compiled_callable self.original_gm = original_gm + self.compile_region_name = compile_region_name # AOT autograd needs this to know inputs are passed as a list self._boxed_call = True @@ -160,7 +166,7 @@ def _resolve_inductor_callable( @inductor_compiled_code.py_impl(DispatchKey.CompositeExplicitAutograd) -def inductor_compiled_code_impl(func, inputs): +def inductor_compiled_code_impl(func, inputs, *, name=None): resolved = _resolve_inductor_callable(func) return resolved.compiled_callable(inputs) @@ -171,7 +177,7 @@ def inductor_compiled_code_impl(func, inputs): @register_fake(inductor_compiled_code) -def inductor_compiled_code_fake(func, inputs): +def inductor_compiled_code_fake(func, inputs, *, name=None): resolved = _resolve_inductor_callable(func) if resolved.original_gm is None: raise RuntimeError( @@ -184,22 +190,24 @@ def inductor_compiled_code_fake(func, inputs): @inductor_compiled_code.py_functionalize_impl -def inductor_compiled_code_functionalize(ctx, func, inputs): +def inductor_compiled_code_functionalize(ctx, func, inputs, *, name=None): # Unwrap the functional tensors to get the underlying tensors unwrapped_inputs = ctx.unwrap_tensors(inputs) # Redispatch to the next handler in the dispatch chain with ctx.redispatch_to_next(): - result = inductor_compiled_code(func, unwrapped_inputs) + kwargs = {"name": name} if name is not None else {} + result = inductor_compiled_code(func, unwrapped_inputs, **kwargs) return ctx.wrap_tensors(result) @inductor_compiled_code.py_impl(ProxyTorchDispatchMode) -def inductor_compiled_code_proxy(mode, func, inputs): +def inductor_compiled_code_proxy(mode, func, inputs, *, name=None): resolved = _resolve_inductor_callable(func) # Run the fake impl to get example outputs for tracing - example_out = inductor_compiled_code(func, inputs) + kwargs = {"name": name} if name is not None else {} + example_out = inductor_compiled_code(func, inputs, **kwargs) # Register in side table so the FX node stores a serializable int callable_idx = inductor_code_side_table.add_callable(resolved) @@ -210,7 +218,7 @@ def inductor_compiled_code_proxy(mode, func, inputs): "call_function", inductor_compiled_code, (callable_idx, proxy_inputs), - {}, + kwargs, ) return track_tensor_tree(example_out, out_proxy, constant=None, tracer=mode.tracer) @@ -229,7 +237,7 @@ def __call__( ) -> _R: # Dynamo already traces the body of HigherOrderOp beforehand when it # so no need to trace into it. - import torch._dynamo # noqa: F401 + import torch._dynamo from torch._dynamo import disable @disable @@ -262,7 +270,7 @@ def __call__( ) -> _R: # Dynamo already traces the body of HigherOrderOp beforehand when it # so no need to trace into it. - import torch._dynamo # noqa: F401 + import torch._dynamo from torch._dynamo import disable @disable @@ -295,7 +303,7 @@ def __call__( ): # Dynamo already traces the body of HigherOrderOp beforehand when it # so no need to trace into it. - import torch._dynamo # noqa: F401 + import torch._dynamo from torch._dynamo import disable is_compiling = isinstance(wrapper_fn_or_key, str) @@ -419,22 +427,6 @@ def divide_kwargs(kwargs): } return checkpoint_kwargs, gmod_kwargs - @staticmethod - def tag_nodes(gmod, is_sac): - from torch.utils.checkpoint import CheckpointPolicy - - unique_graph_id = next(uid) - for node in gmod.graph.nodes: - if node.op in ("call_function", "call_method", "call_module"): - node.meta["ac_graph_id"] = unique_graph_id - if is_sac: - # For selective checkpointing, we will populate this tag later in _CachingTorchDispatchMode. - node.meta["recompute"] = None - else: - # Under vanilla activation checkpointing, all nodes should be recomputed. - node.meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE - return gmod - def __call__(self, gmod, *args, **kwargs): dispatch_key_set = torch._ops._compute_keyset( args, kwargs, self.non_fallthrough_keys @@ -450,11 +442,22 @@ def __call__(self, gmod, *args, **kwargs): tag_activation_checkpoint = TagActivationCheckpoint() +def _always_prefer_recompute(ctx, op, *args, **kwargs): + from torch.utils.checkpoint import CheckpointPolicy + + return CheckpointPolicy.PREFER_RECOMPUTE + + def tag_activation_checkpoint_impl(gmod, *args, **kwargs): + import functools + import torch.fx.traceback as fx_traceback from torch.fx import Interpreter + from torch.utils.checkpoint import create_selective_checkpoint_contexts + unique_graph_id = next(uid) if "_checkpoint_context_fn" in gmod.meta: + context_fn = gmod.meta["_checkpoint_context_fn"] warning_once( log, """ @@ -462,33 +465,47 @@ def tag_activation_checkpoint_impl(gmod, *args, **kwargs): Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_). """, ) - # use_reentrant is set to False because this op is going to be traced. - # And we ensure that AOT Autograd traces through the non reentrant - # version of checkpointing. - kwargs["use_reentrant"] = False - # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through - # `torch.random.fork_rng` op (which is not supported yet under CUDA). - # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state - # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor - # instead of in AOTAutograd). - kwargs["preserve_rng_state"] = False - kwargs["context_fn"] = gmod.meta["_checkpoint_context_fn"] - # We first tag all nodes as "recompute" in this graph, and then we undo the "recompute" tag - # for specific nodes in _CachingTorchDispatchMode in torch/utils/checkpoint.py. - gmod = TagActivationCheckpoint.tag_nodes(gmod, is_sac=True) - # Using interpreter allows preservation of metadata through torch.compile stack. - with fx_traceback.preserve_node_meta(): - from torch.utils.checkpoint import checkpoint - - return checkpoint(Interpreter(gmod).run, *args, **kwargs) else: - gmod = TagActivationCheckpoint.tag_nodes(gmod, is_sac=False) - # Using interpreter allows preservation of metadata through torch.compile stack. - # TODO: We want to use the same `checkpoint(Interpreter(gmod).run, *args, **kwargs)` here - # as the `context_fn != None` case, but that depends on in-place op support in TorchDispatchMode + torch.compile. - # (for details on in-place op issue, run `test_compile_selective_checkpoint_inplace_op` unit test) - with fx_traceback.preserve_node_meta(): - return Interpreter(gmod).run(*args) + # Vanilla AC: use the SAC path with a policy that always recomputes. + context_fn = functools.partial( + create_selective_checkpoint_contexts, _always_prefer_recompute + ) + + def context_fn_with_graph_id(): + fwd_ctx, recomp_ctx = context_fn() + # Plumb ac_graph_id so _CachingTorchDispatchMode tags all nodes + # (including ops from desugared HOPs like custom autograd.Function). + fwd_ctx.ac_graph_id = unique_graph_id + return fwd_ctx, recomp_ctx + + # use_reentrant is set to False because this op is going to be traced. + # And we ensure that AOT Autograd traces through the non reentrant + # version of checkpointing. + kwargs["use_reentrant"] = False + # preserve_rng_state is set to False because we want to prevent AOTAutograd from tracing through + # `torch.random.fork_rng` op (which is not supported yet under CUDA). + # This doesn't mean that we don't preserve RNG state. Instead, we will always preserve RNG state + # regardless of this flag (by doing RNG functionalization via `replace_random_passes` in Inductor + # instead of in AOTAutograd). + kwargs["preserve_rng_state"] = False + kwargs["context_fn"] = context_fn_with_graph_id + # Disable early stop to prevent _StopRecomputationError from interrupting + # recomputation between _vmap_increment_nesting and _vmap_decrement_nesting, + # which would leak a functorch dynamic layer. + kwargs["early_stop"] = False + # Using interpreter allows preservation of metadata through torch.compile stack. + # We use a wrapper instead of passing Interpreter(gmod).run directly because + # checkpoint's recompute_fn captures the function in a closure. A bound method + # reference would keep the Interpreter alive, whose env dict retains the output + # tensors and prevents the autograd graph from being freed. + + def run_with_interpreter(*args): + return Interpreter(gmod).run(*args) + + with fx_traceback.preserve_node_meta(): + from torch.utils.checkpoint import checkpoint + + return checkpoint(run_with_interpreter, *args, **kwargs) @tag_activation_checkpoint.py_impl(ProxyTorchDispatchMode) diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 4b8cce828e41f..11fb1a0c5d7f4 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -9,7 +9,7 @@ import torch.fx -from .standalone_compile import CompiledArtifact # noqa: TC001 +from .standalone_compile import CompiledArtifact, DynamicShapesType # noqa: TC001 if TYPE_CHECKING: @@ -407,11 +407,10 @@ def standalone_compile( gm: torch.fx.GraphModule, example_inputs: list[InputType], *, - dynamic_shapes: Literal[ - "from_example_inputs", "from_tracing_context", "from_graph" - ] = "from_graph", + dynamic_shapes: DynamicShapesType = "from_graph", options: dict[str, Any] | None = None, aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache + donate_graph_module: bool = False, ) -> CompiledArtifact: """ Precompilation API for inductor. @@ -435,6 +434,9 @@ def standalone_compile( If "from_example_inputs", we will specialize the graph on the example_inputs. options: Inductor compilation options + donate_graph_module: If True, standalone_compile takes ownership of + the graph module and may mutate it, avoiding an internal deepcopy. + Defaults to False for backwards compatibility. Returns: CompiledArtifact that can be saved to disk or invoked directly. @@ -443,7 +445,12 @@ def standalone_compile( options = options if options else {} return standalone_compile( - gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options, aot=aot + gm, + example_inputs, + dynamic_shapes=dynamic_shapes, + options=options, + aot=aot, + donate_graph_module=donate_graph_module, ) diff --git a/torch/_inductor/analysis/profile_analysis.py b/torch/_inductor/analysis/profile_analysis.py index 3dd20fba88a92..44d8220555b9f 100644 --- a/torch/_inductor/analysis/profile_analysis.py +++ b/torch/_inductor/analysis/profile_analysis.py @@ -354,18 +354,26 @@ def _augment_trace_helper(data: dict[str, Any]) -> dict[str, Any]: return data -_dtype_map = { +_dtype_map: dict[str, torch.dtype] = { "float": torch.float, "float32": torch.float, + "double": torch.double, + "float64": torch.double, "int": torch.int, "int8": torch.int8, "int16": torch.int16, "int32": torch.int, "long": torch.long, "long int": torch.long, + "signed char": torch.int8, + "unsigned char": torch.uint8, + "bool": torch.bool, "bfloat16": torch.bfloat16, "float16": torch.float16, - "float64": torch.double, + "c10::BFloat16": torch.bfloat16, + "c10::Half": torch.float16, + "c10::complex": torch.complex64, + "c10::complex": torch.complex128, } diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 9e3213a5e0159..1e6dcbece6b63 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -36,6 +36,7 @@ ROCmCodeCache, StaticAutotunerFuture, torch_key, + XPUCodeCache, ) from torch._inductor.compile_worker.subproc_pool import ( AnyPool, @@ -236,6 +237,7 @@ class AsyncCompile: """ _ready_future: Future[Any] | None = None + _metal_sources: list[tuple[str, str, list[str]]] | None = None def __init__(self) -> None: pass @@ -535,19 +537,25 @@ def cpp_pybinding(self, argtypes: list[str], source_code: str): ) return LambdaFuture(get_result) - def cuda(self, source_code, dst_file_ext, aot_compile=False): - kernel_code_log.info("CUDA Kernel:\n%s", source_code) - + def cutlass(self, cache_cls, source_code, dst_file_ext, aot_compile=False): def task(): if aot_compile: # We rely on JITInductor to compile the CUDA code, # so that we can load it into AOTInductor. - output_path, *_ = CUDACodeCache.compile(source_code, "o") - CUDACodeCache.aot_kernels_o.append(output_path) - return CUDACodeCache.load(source_code, dst_file_ext)[0] + output_path, *_ = cache_cls.compile(source_code, "o") + cache_cls.aot_kernels_o.append(output_path) + return cache_cls.load(source_code, dst_file_ext)[0] return self.submit(task) + def cuda(self, source_code, dst_file_ext, aot_compile=False): + kernel_code_log.info("CUDA Kernel:\n%s", source_code) + return self.cutlass(CUDACodeCache, source_code, dst_file_ext, aot_compile) + + def xpu(self, source_code, dst_file_ext, aot_compile=False): + kernel_code_log.info("XPU Kernel:\n%s", source_code) + return self.cutlass(XPUCodeCache, source_code, dst_file_ext, aot_compile) + def rocm( self, source_code, @@ -695,6 +703,12 @@ def task(): future = self.submit(task) return LambdaFuture(lambda: future.result()) + def metal(self, kernel_name: str, source: str, headers: list[str]) -> None: + """Register a Metal kernel body; wait() compiles all registered kernels into one library.""" + if self._metal_sources is None: + self._metal_sources = [] + self._metal_sources.append((kernel_name, source, headers)) + def wait(self, scope: dict[str, Any]) -> None: if get_compile_threads() > 1: with dynamo_timed( @@ -706,6 +720,12 @@ def wait(self, scope: dict[str, Any]) -> None: ): self._wait_futures(scope) + if self._metal_sources: + from torch._inductor.runtime.runtime_utils import compile_mps_shaders + + scope.update(compile_mps_shaders(self._metal_sources)) + self._metal_sources.clear() + _compile_end() def _wait_futures(self, scope: dict[str, Any]) -> None: diff --git a/torch/_inductor/augmented_graph_helper.py b/torch/_inductor/augmented_graph_helper.py index b582ea868e15e..e6ff7cbe01e41 100644 --- a/torch/_inductor/augmented_graph_helper.py +++ b/torch/_inductor/augmented_graph_helper.py @@ -1,10 +1,15 @@ +import logging from collections import defaultdict import torch import torch.fx as fx +from torch._logging import trace_structured from torch.utils._ordered_set import OrderedSet +log = logging.getLogger(__name__) + + class AugmentedGraphHelper: """ Graph helper that augments the original graph with additional @@ -89,8 +94,22 @@ def get_merged_deps(self, node: fx.Node) -> OrderedSet[fx.Node]: return deps def has_cycle(self) -> bool: - merged_deps = {n: self.get_merged_deps(n) for n in self.graph.nodes} - return torch._dynamo.graph_deduplication._has_cycle(self.graph, merged_deps) + return torch._dynamo.graph_deduplication._has_cycle( + self.graph, self.get_all_extra_deps() + ) + + def _get_all_ancestors(self, node: fx.Node) -> OrderedSet[fx.Node]: + """Transitive ancestors through both data deps and extra deps.""" + ancestors: OrderedSet[fx.Node] = OrderedSet() + stack: list[fx.Node] = list(node.all_input_nodes) + stack.extend(self.extra_deps.get(node, ())) + while stack: + n = stack.pop() + if n not in ancestors: + ancestors.add(n) + stack.extend(n.all_input_nodes) + stack.extend(self.extra_deps.get(n, ())) + return ancestors def has_path(self, source: fx.Node, target: fx.Node) -> bool: """Check if there's a path from source to target.""" @@ -137,6 +156,10 @@ def transfer_erased_node_deps( """ Transfer all extra dependencies from erased nodes to their replacements, handling cross-dependencies between erased nodes correctly. + + Skips deps where both endpoints resolve to replacement nodes from the + same erasure batch — these are intra-bucket deps that would create + cycles (e.g. new_start <-> new_wait within the same bucket). """ erased_merge_sets: dict[fx.Node, fx.Node | None] = {} @@ -161,6 +184,11 @@ def transfer_erased_node_deps( for extra_dep in self.extra_deps[old_node]: updated_dep = erased_merge_sets.get(extra_dep, extra_dep) if updated_dep is not None and updated_dep != new_node: + # Skip if reverse dep already exists (extra or data) + if new_node in self.extra_deps.get( + updated_dep, () + ) or new_node in OrderedSet(updated_dep.all_input_nodes): + continue self.extra_deps[new_node].add(updated_dep) self.extra_uses[updated_dep].discard(old_node) self.extra_uses[updated_dep].add(new_node) @@ -169,6 +197,11 @@ def transfer_erased_node_deps( for extra_use in self.extra_uses[old_node]: updated_use = erased_merge_sets.get(extra_use, extra_use) if updated_use is not None and updated_use != new_node: + # Skip if reverse dep already exists (extra or data) + if updated_use in self.extra_deps.get( + new_node, () + ) or updated_use in OrderedSet(new_node.all_input_nodes): + continue self.extra_deps[updated_use].discard(old_node) self.extra_deps[updated_use].add(new_node) self.extra_uses[new_node].add(updated_use) @@ -179,6 +212,58 @@ def transfer_erased_node_deps( self.extra_uses[old_node].clear() del self.merge_sets[old_node] + def remove_erased_extra_deps(self) -> None: + """Remove extra deps referencing erased nodes.""" + for node in list(self.extra_deps): + if node._erased: + for dep in list(self.extra_deps[node]): + self.remove_extra_dep(n=node, dep=dep) + continue + for dep in list(self.extra_deps[node]): + if dep._erased: + self.remove_extra_dep(n=node, dep=dep) + + def check_and_maybe_autofix_cyclic_extra_deps( + self, *, autofix: bool = False + ) -> None: + """Check for and optionally remove extra deps that create cycles. + + Args: + autofix: If True, silently remove cyclic deps. If False (default), + raise an error so the root cause gets investigated. + """ + if not self.has_cycle(): + return + removed = [] + for node in list(self.extra_deps): + for dep in list(self.extra_deps[node]): + ancestors = self._get_all_ancestors(dep) + if node in ancestors: + removed.append((node.name, dep.name)) + self.remove_extra_dep(n=node, dep=dep) + if not removed: + return + msg = ( + f"Overlap scheduling: detected {len(removed)} cyclic extra " + f"dep(s): {removed}. Please report this to the overlap " + f"scheduling developers." + ) + log.warning(msg) + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_overlap_cyclic_extra_deps", + "encoding": "string", + }, + payload_fn=lambda: msg, + ) + if not autofix: + raise RuntimeError( + f"{msg}\nTo unblock, set " + f"torch._inductor.config.aten_distributed_optimizations" + f".overlap_scheduling_autofix_cycles = True" + ) + def get_all_extra_deps(self) -> dict[fx.Node, OrderedSet[fx.Node]]: """ Get all extra dependencies in a format suitable for topological sort. diff --git a/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py b/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py index b61f8a9dd1e99..c19817f98484a 100644 --- a/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py +++ b/torch/_inductor/autoheuristic/artifacts/_PadMMA100.py @@ -2,16 +2,23 @@ # fmt: off # This file was generated by AutoHeuristic. Do not modify it manually! # To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/ -from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata, Choice, CHOICE_COL +from typing import Optional + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) from torch._inductor.autoheuristic.learnedheuristic_interface import ( - LearnedHeuristicRegression, + LearnedHeuristicDecision, ) -class PadMMA100(LearnedHeuristicRegression): +class PadMMA100(LearnedHeuristicDecision): def __init__(self) -> None: - pass + self.choices: list[Choice] = [] + self.fill_choices() def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: return ( @@ -20,90 +27,239 @@ def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: and str(metadata.device_capa) == "(8, 0)" ) - def get_feedback(self, context: AHContext, choice: Choice) -> float: - context.context_dict[CHOICE_COL] = choice - return self.predict(context) - def get_confidence_threshold(self) -> float: - return 1.7025303314066 + return 0.9294871794871795 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('orig') + self.choices.append('pad') def get_name(self) -> str: return 'pad_mm' - def predict(self, context: AHContext) -> float: - if str(context.get_value('choice')) != 'pad': - if str(context.get_value('using_tf32')) != 'False': - if context.get_value('m*n') <= 4171264.0: - if context.get_value('m*k') <= 3999308.0: - return 1.8751469764071178 + def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + if str(context.get_value('mat1_innermost_needs_padding')) != 'False': + if context.get_value('arith_intensity') <= 880.0238037109375: + if str(context.get_value('m_multiple_2')) != 'True': + if context.get_value('n') <= 652.0: + if context.get_value('m') <= 2022.0: + if str(context.get_value('using_tf32')) != 'False': + return [(0.579, 1), (0.421, 0)] + else: + return [(1.000, 0)] + else: + return [(1.000, 0)] else: - if str(context.get_value('n_multiple_32')) != 'True': - return 0.9117231355626345 + if context.get_value('m*k') <= 107278336.0: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m') <= 23691.0: + return [(0.993, 1), (0.007, 0)] + else: + return [(0.840, 1), (0.160, 0)] + else: + if context.get_value('arith_intensity') <= 793.3185424804688: + return [(0.958, 0), (0.042, 1)] + else: + return [(0.792, 1), (0.208, 0)] else: - return 1.1607689608873861 + if context.get_value('arith_intensity') <= 795.6242370605469: + if context.get_value('mat2_stride_1') <= 2048.5: + return [(0.929, 0), (0.071, 1)] + else: + return [(1.000, 1)] + else: + if context.get_value('arith_intensity') <= 796.3460388183594: + return [(0.957, 1), (0.043, 0)] + else: + return [(0.778, 0), (0.222, 1)] else: - if str(context.get_value('n_multiple_2')) != 'True': - if str(context.get_value('using_tf32')) != 'True': - return 0.7430382200435992 + if context.get_value('mat2_stride_0') <= 2432.0: + if str(context.get_value('k_multiple_2')) != 'False': + if context.get_value('n') <= 1024.5: + if str(context.get_value('prepadded_mat1')) != 'False': + return [(0.580, 0), (0.420, 1)] + else: + return [(0.986, 0), (0.014, 1)] + else: + if context.get_value('mat1_stride_0') <= 5125.0: + return [(0.551, 0), (0.449, 1)] + else: + return [(0.916, 0), (0.084, 1)] else: - return 0.8531269794448678 + if context.get_value('mat2_align_size') <= 6.0: + if str(context.get_value('using_tf32')) != 'True': + return [(1.000, 0)] + else: + return [(0.800, 0), (0.200, 1)] + else: + if context.get_value('mat2_stride_1') <= 3820.0: + return [(0.986, 1), (0.014, 0)] + else: + return [(0.532, 0), (0.468, 1)] else: - if str(context.get_value('k_multiple_2')) != 'True': - return 0.7577181972719917 + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*n') <= 5244928.0: + if str(context.get_value('k_multiple_2')) != 'True': + return [(0.971, 1), (0.029, 0)] + else: + return [(0.646, 1), (0.354, 0)] + else: + if context.get_value('k/(m*n)') <= 9.468618827668251e-06: + return [(0.800, 1), (0.200, 0)] + else: + return [(1.000, 1)] else: - return 0.8977349440424219 + if context.get_value('mat1_stride_1') <= 1288.0: + if context.get_value('k') <= 5717.0: + return [(0.983, 0), (0.017, 1)] + else: + return [(0.800, 0), (0.200, 1)] + else: + return [(0.588, 0), (0.412, 1)] else: - if context.get_value('m*n') <= 1299712.0: - return 1.1669723418995592 - else: - if context.get_value('mat2_stride_1') <= 45217.5: - if context.get_value('m*n') <= 55884158.0: - return 1.0262769936909601 + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('n') <= 2640.0: + if context.get_value('m_padded_length') <= 1.5: + if context.get_value('mat2_stride_1') <= 6021.0: + return [(1.000, 1)] + else: + if context.get_value('mat2_stride_1') <= 6069.0: + return [(0.900, 1), (0.100, 0)] + else: + return [(1.000, 1)] else: - return 1.0022677428470845 + if str(context.get_value('m_multiple_2')) != 'False': + if context.get_value('m*k') <= 24444928.0: + return [(0.593, 0), (0.407, 1)] + else: + return [(0.923, 0), (0.077, 1)] + else: + return [(1.000, 1)] else: - if context.get_value('m') <= 18478.0: - return 1.1127066261894312 + if context.get_value('m*k') <= 404182016.0: + if str(context.get_value('mat2_innermost_needs_padding')) != 'False': + if context.get_value('m*k') <= 12328960.0: + return [(0.732, 1), (0.268, 0)] + else: + return [(0.989, 1), (0.011, 0)] + else: + if context.get_value('m*k') <= 389028864.0: + return [(0.998, 1), (0.002, 0)] + else: + return [(0.922, 1), (0.078, 0)] else: - return 1.0337740659894263 + if context.get_value('m*n') <= 137631744.0: + if context.get_value('m*k') <= 405715968.0: + return [(0.611, 1), (0.389, 0)] + else: + return [(0.946, 1), (0.054, 0)] + else: + return [(0.714, 0), (0.286, 1)] + else: + if context.get_value('mat2_stride_0') <= 3902.5: + return [(0.941, 0), (0.059, 1)] + else: + return [(0.583, 1), (0.417, 0)] else: - if str(context.get_value('mat1_dtype')) != 'torch.float32': - if str(context.get_value('n_multiple_2')) != 'False': + if context.get_value('n_padded_length') <= 0.5: + if str(context.get_value('mat2_innermost_needs_padding')) != 'False': if str(context.get_value('k_multiple_2')) != 'True': - if context.get_value('mat1_stride_0') <= 561.0: - return 1.2900382135142956 + if context.get_value('arith_intensity') <= 884.5185852050781: + if context.get_value('arith_intensity') <= 743.931884765625: + return [(1.000, 0)] + else: + return [(0.583, 0), (0.417, 1)] else: - return 1.5761737616057887 + return [(1.000, 1)] else: - if context.get_value('num_dims_needs_padding') <= 1.5: - return 1.0472263310239422 + if context.get_value('k/(m*n)') <= 0.00023481580865336582: + return [(0.900, 0), (0.100, 1)] else: - return 1.1727673465762514 + return [(1.000, 0)] else: - if context.get_value('k') <= 28238.5: - if context.get_value('k/(m*n)') <= 0.00026227018679492176: - return 1.6770542505397175 + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m*k') <= 93734912.0: + if context.get_value('mat1_stride_0') <= 1344.0: + if context.get_value('n') <= 7168.0: + return [(1.000, 0)] + else: + return [(0.970, 0), (0.030, 1)] + else: + if context.get_value('m') <= 22883.5: + return [(0.977, 0), (0.023, 1)] + else: + return [(0.800, 0), (0.200, 1)] else: - return 1.3974785435105923 + if context.get_value('arith_intensity') <= 1914.3681030273438: + if str(context.get_value('prepadded_mat1')) != 'False': + return [(0.981, 0), (0.019, 1)] + else: + return [(0.995, 0), (0.005, 1)] + else: + return [(1.000, 0)] else: - if str(context.get_value('mat1_dtype')) != 'torch.bfloat16': - return 1.3952699800111992 + if str(context.get_value('prepadded_mat1')) != 'False': + if context.get_value('mat2_stride_1') <= 256.5: + if context.get_value('m') <= 5880.5: + return [(1.000, 0)] + else: + return [(0.800, 0), (0.200, 1)] + else: + if context.get_value('m*k') <= 6318080.0: + return [(0.618, 1), (0.382, 0)] + else: + return [(0.880, 0), (0.120, 1)] else: - return 1.5759286511628336 + if context.get_value('k/(m*n)') <= 0.0009747986623551697: + if context.get_value('k/(m*n)') <= 0.0006397514371201396: + return [(0.951, 0), (0.049, 1)] + else: + return [(0.857, 0), (0.143, 1)] + else: + return [(1.000, 0)] else: if str(context.get_value('using_tf32')) != 'False': - if context.get_value('m*n') <= 14119424.0: - return 0.8875772670422478 + if str(context.get_value('n_multiple_2')) != 'False': + if context.get_value('m') <= 2024.0: + if context.get_value('mat2_stride_0') <= 1629.0: + if context.get_value('k*n') <= 1288704.0: + return [(0.600, 0), (0.400, 1)] + else: + return [(0.982, 0), (0.018, 1)] + else: + if context.get_value('mat1_stride_0') <= 768.0: + return [(0.619, 0), (0.381, 1)] + else: + return [(0.812, 1), (0.188, 0)] + else: + if context.get_value('m*n') <= 5803008.0: + return [(0.500, 0), (0.500, 1)] + else: + if context.get_value('mat2_stride_1') <= 896.0: + return [(1.000, 1)] + else: + return [(0.818, 1), (0.182, 0)] else: - if str(context.get_value('mat2_innermost_needs_padding')) != 'True': - return 1.1467728924377265 + if context.get_value('mat2_stride_1') <= 2560.0: + if context.get_value('num_dims_needs_padding') <= 1.5: + if context.get_value('mat2_stride_1') <= 896.0: + return [(1.000, 1)] + else: + return [(0.857, 1), (0.143, 0)] + else: + return [(0.727, 1), (0.273, 0)] else: - return 1.215842963532998 + return [(0.667, 1), (0.333, 0)] else: - if context.get_value('arith_intensity') <= 396.8774871826172: - return 0.89940161869551 + if context.get_value('k/(m*n)') <= 0.00015462777810171247: + return [(0.857, 0), (0.143, 1)] else: - if context.get_value('mat2_stride_1') <= 45217.5: - return 0.9964328169353532 + if context.get_value('k/(m*n)') <= 0.0019917909521609545: + return [(1.000, 0)] else: - return 0.9493479238294826 + return [(0.900, 0), (0.100, 1)] diff --git a/torch/_inductor/autoheuristic/artifacts/_PadMMH200.py b/torch/_inductor/autoheuristic/artifacts/_PadMMH200.py new file mode 100644 index 0000000000000..1527b02319c10 --- /dev/null +++ b/torch/_inductor/autoheuristic/artifacts/_PadMMH200.py @@ -0,0 +1,175 @@ +# flake8: noqa: B950 +# fmt: off +# This file was generated by AutoHeuristic. Do not modify it manually! +# To regenerate this file, take a look at the steps in the README.md file inside torchgen/_autoheuristic/pad_mm/ +from typing import Optional + +from torch._inductor.autoheuristic.autoheuristic_utils import ( + AHContext, + AHMetadata, + Choice, +) +from torch._inductor.autoheuristic.learnedheuristic_interface import ( + LearnedHeuristicDecision, +) + + +class PadMMH200(LearnedHeuristicDecision): + + def __init__(self) -> None: + self.choices: list[Choice] = [] + self.fill_choices() + + def check_precondition(self, metadata: AHMetadata, context: AHContext,) -> bool: + return ( + metadata.name == self.get_name() + and metadata.shared_memory == 232448 + and str(metadata.device_capa) == "(9, 0)" + ) + + def get_confidence_threshold(self) -> float: + return 0.7710651828298887 + + def get_choice(self, idx: int) -> Optional[str]: + if idx < len(self.choices): + return self.choices[idx] + return None + + def fill_choices(self) -> None: + self.choices.append('orig') + self.choices.append('pad') + + def get_name(self) -> str: + return 'pad_mm' + + def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]: + if str(context.get_value('mat1_innermost_needs_padding')) != 'True': + if str(context.get_value('mat2_innermost_needs_padding')) != 'True': + if context.get_value('n_padded_length') <= 0.5: + if str(context.get_value('prepadded_mat1')) != 'True': + if str(context.get_value('using_tf32')) != 'False': + return [(1.000, 0)] + else: + if context.get_value('mat1_stride_0') <= 3584.0: + return [(1.000, 0)] + else: + if context.get_value('mat2_stride_0') <= 3584.0: + return [(1.000, 0)] + else: + return [(0.528, 0), (0.472, 1)] + else: + if context.get_value('n') <= 2304.0: + if context.get_value('m*k') <= 25198592.0: + if context.get_value('arith_intensity') <= 1103.9319458007812: + return [(1.000, 0)] + else: + return [(0.885, 0), (0.115, 1)] + else: + if context.get_value('m*k') <= 25688064.0: + return [(1.000, 1)] + else: + return [(0.771, 0), (0.229, 1)] + else: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('m') <= 27825.0: + return [(0.948, 0), (0.052, 1)] + else: + return [(0.855, 0), (0.145, 1)] + else: + if context.get_value('mat2_stride_0') <= 3584.0: + return [(1.000, 0)] + else: + return [(0.917, 1), (0.083, 0)] + else: + if context.get_value('m') <= 1823.5: + if str(context.get_value('n_multiple_2')) != 'False': + if context.get_value('k*n') <= 7859200.0: + return [(0.600, 0), (0.400, 1)] + else: + return [(1.000, 0)] + else: + if context.get_value('k/(m*n)') <= 0.00040277576772496104: + return [(1.000, 1)] + else: + return [(0.800, 1), (0.200, 0)] + else: + if context.get_value('n') <= 3602.0: + return [(0.800, 1), (0.200, 0)] + else: + return [(1.000, 1)] + else: + if str(context.get_value('using_tf32')) != 'False': + if str(context.get_value('n_multiple_16')) != 'False': + if str(context.get_value('k_multiple_2')) != 'True': + if context.get_value('arith_intensity') <= 744.8332214355469: + return [(0.600, 0), (0.400, 1)] + else: + return [(1.000, 1)] + else: + if context.get_value('m*n') <= 8912896.0: + if context.get_value('m*k') <= 5934080.0: + return [(0.800, 0), (0.200, 1)] + else: + return [(1.000, 0)] + else: + return [(1.000, 1)] + else: + return [(1.000, 1)] + else: + return [(1.000, 0)] + else: + if context.get_value('arith_intensity') <= 895.8767395019531: + if str(context.get_value('m_multiple_2')) != 'False': + if context.get_value('mat1_stride_1') <= 3421.0: + if str(context.get_value('using_tf32')) != 'False': + if context.get_value('mat2_stride_1') <= 10706.5: + if context.get_value('mat2_stride_0') <= 1024.5: + return [(0.816, 1), (0.184, 0)] + else: + return [(1.000, 1)] + else: + if str(context.get_value('k_multiple_2')) != 'True': + return [(0.905, 1), (0.095, 0)] + else: + return [(1.000, 0)] + else: + if str(context.get_value('prepadded_mat2')) != 'True': + if str(context.get_value('mat2_innermost_needs_padding')) != 'False': + return [(1.000, 0)] + else: + return [(0.932, 0), (0.068, 1)] + else: + if context.get_value('arith_intensity') <= 742.1241760253906: + return [(0.889, 0), (0.111, 1)] + else: + return [(0.765, 1), (0.235, 0)] + else: + if context.get_value('n') <= 1216.0: + if str(context.get_value('using_tf32')) != 'True': + if context.get_value('mat1_stride_1') <= 5567.0: + return [(0.896, 0), (0.104, 1)] + else: + return [(0.999, 0), (0.001, 1)] + else: + return [(1.000, 1)] + else: + if str(context.get_value('using_tf32')) != 'False': + return [(1.000, 1)] + else: + return [(1.000, 0)] + else: + if str(context.get_value('using_tf32')) != 'False': + return [(1.000, 1)] + else: + if context.get_value('mat2_stride_1') <= 2688.0: + return [(1.000, 0)] + else: + return [(0.500, 0), (0.500, 1)] + else: + if str(context.get_value('using_tf32')) != 'False': + return [(1.000, 1)] + else: + if str(context.get_value('mat2_innermost_needs_padding')) != 'True': + return [(1.000, 0)] + else: + return [(0.800, 0), (0.200, 1)] diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index ab60a8b2b66f5..92cb149c1683d 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -2,6 +2,7 @@ from __future__ import annotations import atexit +import contextvars import ctypes import dataclasses import functools @@ -22,7 +23,7 @@ from typing import Any, IO, TYPE_CHECKING import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.testing import rand_strided from torch._inductor import ir @@ -32,6 +33,7 @@ DLLWrapper, get_hash, PyCodeCache, + XPUCodeCache, ) from torch._inductor.compile_worker.timer import Timer from torch._inductor.utils import ( @@ -522,7 +524,7 @@ def benchmark( bench_elapse = time.time() - start_ts # type: ignore[possibly-undefined] autotuning_log.debug( "InChildProcess %s: load %f, create tensor %f, bench %f", - str(self), + self, load_elapse, # type: ignore[possibly-undefined] create_tensor_elapse, # type: ignore[possibly-undefined] bench_elapse, @@ -828,6 +830,53 @@ class ExternKernelCPUBenchmarkRequest( pass +class SubgraphBenchmarkRequest(BenchmarkRequest): + """ + Benchmark request for subgraph choices. + + Pre-compiles the subgraph in the main process and stores + the module path/cache key for loading in subprocess. + """ + + def __init__( + self, + kernel_name: str, + input_tensor_meta: TensorMeta | list[TensorMeta], + output_tensor_meta: TensorMeta | list[TensorMeta], + extra_args: Iterable[Any], + module_path: str, + module_cache_key: str, + sym_input_values: list[int], + ) -> None: + super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args) + self.module_path = module_path + self.module_cache_key = module_cache_key + self.sym_input_values = sym_input_values + + def make_run_fn( + self, *input_tensors: torch.Tensor, out: torch.Tensor + ) -> Callable[[], None]: + mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path) + sym_input_values = self.sym_input_values + # Create a new list each call since mod.call does args.clear() + return lambda: mod.call([*sym_input_values, *input_tensors]) + + def precompile(self) -> None: + # Module is already compiled in main process, no precompilation needed + pass + + def __str__(self) -> str: + return f"SubgraphBenchmarkRequest({self.kernel_name}, {self.module_path})" + + +class SubgraphGPUBenchmarkRequest(GPUDeviceBenchmarkMixin, SubgraphBenchmarkRequest): + pass + + +class SubgraphCPUBenchmarkRequest(CPUDeviceBenchmarkMixin, SubgraphBenchmarkRequest): + pass + + class CUTLASSBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest): """ A class to handle CUDA (CUTLASS) benchmark requests. This class is for @@ -856,7 +905,7 @@ def __init__( self.hash_key: str = "" self.source_file: str = "" self.device_type = device_type - self.codecache_cls = CUDACodeCache + self.codecache_cls = XPUCodeCache if device_type == "xpu" else CUDACodeCache self.device_interface = get_interface_for_device(device_type) self.hash_key, self.source_file = self.codecache_cls.write( self.source_code, "so" @@ -868,7 +917,7 @@ def precompile(self): This may happen in a separate thread pool. """ autotuning_log.debug("Precompiling %s", self) - CUDACodeCache.compile(self.source_code, "so") + self.codecache_cls.compile(self.source_code, "so") autotuning_log.debug("Done precompiling %s", self) def make_run_fn( @@ -890,8 +939,9 @@ def make_run_fn( args, self.extra_args, ) - current_stream = self.device_interface.current_stream() - stream_ptr = c_void_p(current_stream.cuda_stream) # type: ignore[attr-defined] + stream_ptr = c_void_p( + self.device_interface.get_raw_stream(self.device_interface.current_device()) + ) run_method = getattr(self.DLL, self.kernel_name) workspace_ptr = c_void_p(0) if self.workspace_size > 0: @@ -934,8 +984,9 @@ def update_workspace_size(self) -> None: dict.fromkeys(meta.name for meta in self.input_tensor_meta) ) args = [c_void_p(None) for _ in range(unique_input_count + 1)] - current_stream = self.device_interface.current_stream() - stream_ptr = c_void_p(current_stream.cuda_stream) # type: ignore[attr-defined] + stream_ptr = c_void_p( + self.device_interface.get_raw_stream(self.device_interface.current_device()) + ) run_method = getattr(self.DLL, self.kernel_name) # Retrieve workspace_size and initialize workspace. @@ -952,7 +1003,7 @@ def update_workspace_size(self) -> None: self.device_interface.synchronize() # shake out any device errors self.workspace_size = c_workspace_size.value autotuning_log.debug( - "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950 + "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", self.workspace_size, self.kernel_name, self.source_file, @@ -968,6 +1019,7 @@ def ensure_dll_loaded(self): self.DLL, self.hash_key, self.source_file = self.codecache_cls.load( self.source_code, "so" ) + self.DLL.open() def cleanup_run_fn(self) -> None: if self.DLL is not None: @@ -1301,9 +1353,10 @@ def run_autotune_in_subprocess( return timing except Exception: - autotuning_log.error( + autotuning_log.warning( "Failed to benchmark choice %s", benchmark_request, + exc_info=True, ) # Use infinity for failed benchmarks so they're not selected return float("inf") @@ -1334,6 +1387,9 @@ def get_instance(cls) -> PrecompileThreadPool: return cls._instance def submit(self, fn, *args, **kwargs): + ctx = contextvars.copy_context() + # Need to copy context so workers have access to the correct config settings + fn = functools.partial(ctx.run, fn) return self._executor.submit(fn, *args, **kwargs) def _shutdown(self, wait: bool = False): diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index d5b7e4eb2a76d..50b0fbaf96c12 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -195,6 +195,7 @@ def index_expr(cls, index: Any, dtype: torch.dtype) -> ValueRanges[Any]: return cls.to_dtype(index, dtype) @staticmethod + # pyrefly: ignore [bad-override] def to_dtype( x: Any, dtype: torch.dtype, @@ -237,10 +238,12 @@ def cast(x: Any, dtype: torch.dtype) -> sympy.Expr: return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype)) @staticmethod + # pyrefly: ignore [bad-override] def square(x: Any) -> ValueRanges[Any]: return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) @staticmethod + # pyrefly: ignore [bad-override] def neg(x: Any) -> ValueRanges[Any]: return ValueRanges.decreasing_map(x, operator.neg) diff --git a/torch/_inductor/cache_key.py b/torch/_inductor/cache_key.py new file mode 100644 index 0000000000000..cfef4170a15b4 --- /dev/null +++ b/torch/_inductor/cache_key.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import base64 +import dataclasses +import hashlib +import json +from typing import Any, Literal, Protocol +from typing_extensions import assert_never + + +CacheKeyComponent = str | bytes | bytearray | memoryview + + +class _HashLike(Protocol): + def update(self, data: bytes, /) -> None: ... + def digest(self) -> bytes: ... + def hexdigest(self) -> str: ... + + +@dataclasses.dataclass(frozen=True) +class CacheKeyStrategy: + """ + Describes how an Inductor cache turns stable components into a cache key. + + Different caches intentionally use different key formats for compatibility + with existing on-disk and remote cache layouts. Keeping those choices in + named strategies makes the composition explicit at the call site. + """ + + # Human-readable label for repr/debugging; it is not part of the cache key. + name: str + digest_format: Literal["base32", "hex"] + prefix: str = "" + separator: bytes | None = None + base32_length: int = 51 + + @staticmethod + def _to_bytes(component: CacheKeyComponent) -> bytes: + if isinstance(component, str): + return component.encode("utf-8") + if isinstance(component, bytes): + return component + if isinstance(component, bytearray): + return bytes(component) + if isinstance(component, memoryview): + return component.tobytes() + raise TypeError(f"Unsupported cache key component: {type(component)!r}") + + def _hasher(self, components: tuple[CacheKeyComponent, ...]) -> _HashLike: + hasher = hashlib.sha256() + for idx, component in enumerate(components): + if idx > 0 and self.separator is not None: + hasher.update(self.separator) + hasher.update(self._to_bytes(component)) + return hasher + + def digest(self, *components: CacheKeyComponent) -> str: + hasher = self._hasher(components) + if self.digest_format == "hex": + return hasher.hexdigest() + if self.digest_format == "base32": + # [:51] strips the "Q====" suffix common to every SHA256 base32 digest. + return ( + base64.b32encode(hasher.digest())[: self.base32_length] + .decode("utf-8") + .lower() + ) + assert_never(self.digest_format) + + def key(self, *components: CacheKeyComponent) -> str: + return f"{self.prefix}{self.digest(*components)}" + + def key_from_json(self, value: Any, *, sort_keys: bool = True) -> str: + return self.key(json.dumps(value, sort_keys=sort_keys)) + + +COMPACT_CACHE_KEY_STRATEGY = CacheKeyStrategy( + name="compact", + digest_format="base32", +) + +CODE_CACHE_KEY_STRATEGY = CacheKeyStrategy( + name="code", + digest_format="base32", + prefix="c", + separator=b"||", +) + +FX_GRAPH_CACHE_KEY_STRATEGY = CacheKeyStrategy( + name="fx_graph", + digest_format="base32", + prefix="f", +) + +SYSTEM_CACHE_KEY_STRATEGY = CacheKeyStrategy( + name="system", + digest_format="hex", +) + +AUTOTUNE_CACHE_KEY_STRATEGY = CacheKeyStrategy( + name="autotune", + digest_format="hex", + # Preserve the existing autotune cache format, which concatenates components. + separator=None, +) diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index b9dc99bc113f8..1d48aea09b32f 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -95,7 +95,7 @@ def __lt__(self, other): class InductorChoices: """ - This class contains a collection of default heuristics that effect performance of our generated + This class contains a collection of default heuristics that affect performance of our generated code. We try to not put correctness requirements in this file. You can override the choices made here by doing: @@ -104,6 +104,9 @@ class MyHeuristics(InductorChoices): ... torch._inductor.virtualized.V.set_choices_handler(MyHeuristics()) + + Subclasses used with inductor_choices_class must implement uuid() for + cache key computation. """ def get_config_heuristics( @@ -240,11 +243,6 @@ def _need_to_fix_layout( Returns: True if we need to fix the layout, False otherwise """ - # TLX force mode uses Triton templates which require fixed layouts - # This check is independent of max_autotune - if config.is_fbcode() and config.triton.tlx_mode == "force": - return True - # TODO: debug and fix # NOTE: on mps, we see issues with flexible layouts on baddmm. This check just makes sure # that for mps, everything stays as it was before this optimization @@ -352,18 +350,29 @@ def triton_kernel_kwargs( """Hook to change the kwargs passed to TritonKernel, used to apply fixed configurations""" return kernel_kwargs + def override_best_choice( + self, + best_choice: ChoiceCaller, + timings: dict[ChoiceCaller, float], + ) -> ChoiceCaller: + """Hook to override the autotuning best choice after benchmarking.""" + return best_choice + + def customize_fused_kernel_name(self, fused_name: str, src_code: str) -> str: + """Hook to transform fused kernel names during codegen""" + return fused_name + @staticmethod - def should_use_cooperative_reduction(features: SIMDKernelFeatures) -> bool: + def should_use_cooperative_reduction( + device: torch.device, numel: sympy.Expr, reduction_numel: sympy.Expr + ) -> bool: """Heuristic to decide if a cooperative reduction should be used.""" if config.triton.force_cooperative_reductions: return True - if ( - not config.triton.cooperative_reductions - or V.graph.get_current_device_or_throw().type == "cpu" - ): + if not config.triton.cooperative_reductions or device.type == "cpu": return False - xhint = V.graph.sizevars.optimization_hint(features.numel, fallback=2) + xhint = V.graph.sizevars.optimization_hint(numel, fallback=2) if xhint <= 8: threshold = 32768 * xhint elif xhint <= 16: @@ -371,11 +380,9 @@ def should_use_cooperative_reduction(features: SIMDKernelFeatures) -> bool: else: return False # TODO(jansel): should this default on for dynamic shapes? - # TODO(laith) What if hint(features.reduction_numel) >= threshold ? + # TODO(laith) What if hint(reduction_numel) >= threshold ? # shall we compare hints instead - return V.graph.sizevars.statically_known_geq( - features.reduction_numel, threshold - ) + return V.graph.sizevars.statically_known_geq(reduction_numel, threshold) @staticmethod def should_use_persistent_reduction( @@ -410,7 +417,7 @@ def should_use_persistent_reduction( lower = next_power_of_2(int(lower)) upper = next_power_of_2(int(upper)) - # If we are are coalescing on xblock (not ReductionHint.INNER) and this is not a tiny kernel + # If we are coalescing on xblock (not ReductionHint.INNER) and this is not a tiny kernel # (not ReductionHint.OUTER_TINY), do not use persistent reduction if it induces tile # quantization. Persistent reduction forces rblock == rnumel, if the bounds between lower # and upper are large, for the lower values we will be masking off large % of read/writes, @@ -468,7 +475,12 @@ def reduction_split_factor( # we leak reduction autotune configs here, and will need to refactor to avoid this later if numel_hint >= 2 * num_sm: # don't split if there are enough outputs return 1 - if reduction_numel_hint <= 8192: + # based on sum(x[N]) on GB200, split reduction provides higher performance when N >= 1M + # TODO: test more hardwares + no_split_threshold = ( + 524288 if props.major is not None and props.major >= 10 else 8192 + ) + if reduction_numel_hint <= no_split_threshold: return 1 if reduction_numel_hint * numel_hint <= min_elements_per_device: split_size = min_elements_per_thread @@ -648,11 +660,10 @@ def score_fusion( - Fusions closer together in original graph order """ - memory_score, buffer_overlap_score, is_mix_order_reduction = typing.cast( - tuple[int, int, bool], + memory_score, buffer_overlap_score, is_mix_order_reduction = ( scheduler.score_fusion_memory( node1, node2, return_is_mix_order_reduction=True - ), + ) ) proximity_score = -max( abs(node1.min_order - node2.max_order), diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 3127057e0dae2..10aba27bba242 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -34,7 +34,7 @@ from tempfile import _TemporaryFileWrapper from time import time, time_ns from types import ModuleType -from typing import Any, cast, Generic, NoReturn, TYPE_CHECKING, TypeVar +from typing import Any, cast, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar from typing_extensions import override, Self import torch @@ -65,11 +65,13 @@ _LINKER_SCRIPT, _set_gpu_runtime_env, _TORCH_PATH, + batch_convert_cubins_to_obj, convert_cubin_to_obj, CppBuilder, CppOptions, CppTorchDeviceOptions, get_compiler_version_info, + get_cpp_compiler, get_ld_and_objcopy, get_name_and_dir_from_output_file_path, normalize_path_separator, @@ -95,6 +97,7 @@ XPU_KERNEL_FORMAT, ) from torch._library.fake_class_registry import FakeScriptObject +from torch._library.opaque_object import is_opaque_reference_type from torch._logging import trace_structured from torch._subclasses.fake_tensor import ( extract_tensor_metadata, @@ -110,9 +113,19 @@ ) from torch.export.pt2_archive._package_weights import TensorProperties, Weights from torch.export.pt2_archive.constants import CUSTOM_OBJ_FILENAME_PREFIX -from torch.fx.experimental.symbolic_shapes import has_hint, ShapeEnv, size_hint +from torch.fx.experimental.symbolic_shapes import ( + guarding_hint_or_throw, + has_guarding_hint, + ShapeEnv, +) from torch.utils._ordered_set import OrderedSet +from .cache_key import ( + CODE_CACHE_KEY_STRATEGY, + COMPACT_CACHE_KEY_STRATEGY, + FX_GRAPH_CACHE_KEY_STRATEGY, + SYSTEM_CACHE_KEY_STRATEGY, +) from .output_code import CompiledFxGraph from .remote_cache import create_cache from .runtime import autotune_cache @@ -145,6 +158,8 @@ autotuning_log = torch._logging.getArtifactLogger(__name__, "autotuning") log = logging.getLogger(__name__) +AOTAUTOGRAD_CACHE_PREFIX = "a" + def get_cpp_wrapper_cubin_path_name() -> str: return "cubin_path" if torch.version.hip is None else "hsaco_path" @@ -207,9 +222,7 @@ def get_system() -> dict[str, Any]: # If cuda is not installed, none of the above config is relevant. system = {} - system["hash"] = hashlib.sha256( - json.dumps(system, sort_keys=True).encode("utf-8") - ).hexdigest() + system["hash"] = SYSTEM_CACHE_KEY_STRATEGY.key_from_json(system) return system @@ -327,16 +340,13 @@ def get_lock_dir() -> str: def sha256_hash(data: bytes) -> str: - # [:51] to strip off the "Q====" suffix common to every hash value. - return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower() + return COMPACT_CACHE_KEY_STRATEGY.key(data) def code_hash(code: str | bytes, extra: str | bytes = "") -> str: - hashing_str = code if isinstance(code, bytes) else code.encode("utf-8") if extra: - extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8") - hashing_str = hashing_str + b"||" + extra_b - return "c" + sha256_hash(hashing_str) + return CODE_CACHE_KEY_STRATEGY.key(code, extra) + return CODE_CACHE_KEY_STRATEGY.key(code) def get_path( @@ -494,14 +504,19 @@ def __init__( self, gm: torch.fx.GraphModule, has_user_defined_triton_kernels: bool = False, + device_id_agnostic: bool = False, ) -> None: """ Create an FX graph pickler. If include_non_inlined=True, then pickling will include the _values_ for all Tensors. (Note that any tensors are constants attached as attributes to the GraphModule). Otherwise, pickling will include only the metadata for these tensors. + + If device_id_agnostic=True, device indices in TensorMetadata are normalized + to 0, so that the same graph on different GPUs produces identical bytes. """ self._stream = io.BytesIO() + self._device_id_agnostic = device_id_agnostic super().__init__(self._stream) self.dispatch_table = copyreg.dispatch_table.copy() @@ -511,6 +526,7 @@ def __init__( torch.Tensor: functools.partial(self._reduce_tensor), torch.nn.parameter.Parameter: functools.partial(self._reduce_tensor), torch.SymInt: functools.partial(self._reduce_symint), + torch.SymBool: functools.partial(self._reduce_symbool), torch.fx.experimental._backward_state.BackwardState: functools.partial( self._reduce_unsupported ), @@ -534,6 +550,10 @@ def _reduce_fake_tensor( Custom reducer to pickle FakeTensors. """ metadata = extract_tensor_metadata_for_cache_key(t) + if self._device_id_agnostic: + metadata = dataclasses.replace( + metadata, device=torch.device(metadata.device.type, 0) + ) return (_ident, (metadata,)) def _reduce_tensor( @@ -552,6 +572,10 @@ def _reduce_tensor( raise BypassFxGraphCache("mkldnn tensors unpickleable") metadata = extract_tensor_metadata_for_cache_key(t) + if self._device_id_agnostic: + metadata = dataclasses.replace( + metadata, device=torch.device(metadata.device.type, 0) + ) # If this is a non-inlined frozen parameter, we consider the metadata only. if is_frozen_param(t) and not GraphLowering.can_inline_constant(t): @@ -579,6 +603,14 @@ def _reduce_symint(self, s: SymInt) -> tuple[Callable[[T], T], tuple[str]]: # entity with SymInt args is safe to reuse. return (_ident, (str(s),)) + def _reduce_symbool(self, s: torch.SymBool) -> tuple[Callable[[T], T], tuple[str]]: + """ + Custom reducer to pickle SymBools. + """ + # Same approach as _reduce_symint: use the string representation for + # hashing. Guards ensure correctness on cache reload. + return (_ident, (str(s),)) + def _reduce_unsupported(self, s: Any) -> NoReturn: """ Custom reducer to handle any objects that we don't support and therefore @@ -649,7 +681,14 @@ def get_hash(self, obj: Any) -> str: Serialize an object and return a hash of the bytes. """ serialized_data = self.dumps(obj) - return sha256_hash(serialized_data) + return COMPACT_CACHE_KEY_STRATEGY.key(serialized_data) + + def get_key(self, obj: Any) -> str: + """ + Serialize an object and return an FX graph cache key. + """ + serialized_data = self.dumps(obj) + return FX_GRAPH_CACHE_KEY_STRATEGY.key(serialized_data) def debug_lines(self, inp: FxGraphHashDetails) -> list[str]: """ @@ -781,14 +820,71 @@ class BypassFxGraphCache(Exception): """ +_warned_pre_grad_pass_missing_uuid: OrderedSet[str] = OrderedSet() + + +def resolve_pre_grad_pass_timing() -> Literal["early", "late"]: + """Resolve the effective pre-grad pass timing from the config. + + "default" is resolved based on whether the custom pass provides a UUID: + passes with a UUID (or no custom pass) run "late" (after cache lookup), + passes without a UUID run "early" (before cache lookup). + + Raises RuntimeError if a custom pass without a UUID is explicitly set to + run "late", since the cache key cannot account for it. + """ + timing: Literal["early", "late", "default"] = config.pre_grad_pass_timing + custom_pass = config.pre_grad_custom_pass + has_uuid = ( + custom_pass + and isinstance(custom_pass, CustomGraphPass) + and custom_pass.uuid() is not None + ) + + if timing == "default": + supports_late = custom_pass is None or has_uuid + timing = "late" if supports_late else "early" + if timing == "early" and custom_pass: + pass_name = type(custom_pass).__qualname__ + if pass_name not in _warned_pre_grad_pass_missing_uuid: + _warned_pre_grad_pass_missing_uuid.add(pass_name) + log.warning( + "pre_grad_custom_pass %s does not implement uuid(); " + "falling back to early timing (pre-grad pass cache will be bypassed). " + "Implement uuid() on your CustomGraphPass to enable caching.", + pass_name, + ) + CompileEventLogger.try_add_pt2_compile( + "backend_compile", + pre_grad_pass_missing_uuid=True, + pre_grad_pass_name=pass_name, + ) + + if timing == "late" and custom_pass and not has_uuid: + raise RuntimeError( + "pre_grad_custom_pass must implement uuid() to run late " + "(after cache lookup). Either implement uuid() or set " + "pre_grad_pass_timing to 'early'." + ) + + return timing + + +@dataclasses.dataclass +class HashableOpaqueValue: + ordinal: int + + class FxGraphHashDetails: """ Object to capture all the details for a compiled FX graph relevant to computing a safe and stable cache key. """ - # Excluded kwargs param that are not stable between runs - EXCLUDED_KWARGS = ["graph_id"] + # Excluded kwargs param that are not stable between runs or that + # don't affect compiled output (like compile_region_name which is + # just a debug label). + EXCLUDED_KWARGS = ["graph_id", "compile_region_name"] def __init__( self, @@ -798,7 +894,19 @@ def __init__( inputs_to_check: Sequence[int], ) -> None: self.gm = gm - self.example_inputs = example_inputs + # Replace opaque references with hashable ordinals. What's important + # is that if the same reference appears twice then it's the same hash + # value for each. + processed_inputs: list[InputType | HashableOpaqueValue] = [] + seen_opaques: dict[int, HashableOpaqueValue] = {} + for inp in example_inputs: + if is_opaque_reference_type(type(inp)): + if id(inp) not in seen_opaques: + seen_opaques[id(inp)] = HashableOpaqueValue(len(seen_opaques)) + processed_inputs.append(seen_opaques[id(inp)]) + else: + processed_inputs.append(inp) + self.example_inputs = processed_inputs self.cache_key_tag = cconfig.cache_key_tag # Order kwargs so hashing is stable to changes in kwarg order. Although @@ -881,6 +989,12 @@ def __init__( torch.utils.deterministic.fill_uninitialized_memory, # type: ignore[attr-defined] ) + # Provenance tracking level affects whether provenance data is stored + # in the CompiledFxGraph, so it must be part of the cache key. + # Note: the "trace" prefix is excluded from _cache_config_ignore_prefix, + # so we add this explicitly. + self.provenance_tracking_level = config.trace.provenance_tracking_level + # Global settings affecting matmul codegen. self.cuda_matmul_settings = ( torch.backends.cuda.matmul.fp32_precision, @@ -908,7 +1022,11 @@ def __init__( self.torch_version = torch_key() self.system_info = CacheBase.get_system() self.inductor_config = config.save_config_portable(ignore_private_configs=False) - # Custom post grad passes should provide an ID to hash. + # Custom passes should provide an ID to hash when they run late (after cache lookup). + if resolve_pre_grad_pass_timing() != "early": + self.pre_grad_custom_pass = self._get_custom_pass_detail( + config.pre_grad_custom_pass + ) self.post_grad_custom_pre_pass = self._get_custom_pass_detail( config.post_grad_custom_pre_pass ) @@ -1016,7 +1134,7 @@ def compiled_fx_graph_hash( # The prefix distinguishes among the other kinds of objects we # cache in this module. - key = "f" + pickler.get_hash(details) + key = pickler.get_key(details) debug_lines = pickler.debug_lines(details) debug_str = "\n".join(debug_lines) log.debug(f"FX graph cache hash details for key {key}:\n{debug_str}") # noqa: G004 @@ -1194,7 +1312,9 @@ def _filter_backed_symints( Get the backed SymInt objects from the input list. Note that we can never have guards that depend on unbacked symint. """ - return [s for s in inputs if isinstance(s, torch.SymInt) and has_hint(s)] + return [ + s for s in inputs if isinstance(s, torch.SymInt) and has_guarding_hint(s) + ] @classmethod def _get_shape_env(cls: type[GuardedCache[T]]) -> ShapeEnv | None: @@ -1431,7 +1551,7 @@ def _lookup_graph( assert shape_env is not None symints = FxGraphCache._filter_backed_symints(example_inputs) - hints = [size_hint(s) for s in symints] + hints = [guarding_hint_or_throw(s) for s in symints] # If this config is turned on, everything is a guard hit and we check nothing if config.unsafe_skip_cache_dynamic_shape_guards: @@ -1569,6 +1689,11 @@ def _check_for_hop(gm: torch.fx.GraphModule) -> None: raise BypassFxGraphCache( f"Can't cache HigherOrderOperator: {node.target.name()}" ) + # TODO: this check is broken in two ways: + # 1. FX uses "get_attr" (with underscore), not "getattr" + # 2. It only checks for ScriptObject, not FakeScriptObject + # Fixing it would also bypass AOTAutogradCache (which calls + # _check_can_cache), so we'd need to decouple the two first. if node.op == "getattr" and isinstance( getattr(gm, node.target), torch._C.ScriptObject ): @@ -1580,8 +1705,15 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None: Check some conditions that would preclude caching and raise BypassFxGraphCache to bypass in case caching is not possible. """ - # Post grad custom passes must implement the CustomGraphPass or we don't + # Custom passes must implement the CustomGraphPass or we don't # know how to include them in the cache key calculation. + # When timing is EARLY, pre-grad passes already ran before the cache + # lookup so there's nothing to validate here. + if resolve_pre_grad_pass_timing() != "early": + assert not config.pre_grad_custom_pass or ( + isinstance(config.pre_grad_custom_pass, CustomGraphPass) + and config.pre_grad_custom_pass.uuid() + ), "Unsupported pre grad custom pass" for p in (config.post_grad_custom_pre_pass, config.post_grad_custom_post_pass): if p and (not isinstance(p, CustomGraphPass) or not p.uuid()): raise BypassFxGraphCache("Unsupported post grad custom pass") @@ -1650,7 +1782,7 @@ def prepare_key( ) except BypassFxGraphCache as e: counters["inductor"]["fxgraph_cache_bypass"] += 1 - log.info("Bypassing FX Graph Cache because '%s'", e) # noqa: G200 + log.info("Bypassing FX Graph Cache because '%s'", e) if remote: log_cache_bypass("bypass_fx_graph", str(e)) cache_info = { @@ -2477,9 +2609,13 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: f.write(json.dumps(qual_name_to_id)) generated_files.append(constants_config_json) - gpu_codecache: ROCmCodeCache | CUDACodeCache = ( - ROCmCodeCache() if torch.version.hip else CUDACodeCache() - ) + cache_cls = { + "rocm": ROCmCodeCache, + "cuda": CUDACodeCache, + "xpu": XPUCodeCache, + }.get("rocm" if torch.version.hip else device_type, CUDACodeCache) + + gpu_codecache = cache_cls() gpu_kernels_o = gpu_codecache.aot_kernels_o.copy() # clear the list of aot kernels after each linking gpu_codecache.aot_kernels_o.clear() @@ -2491,7 +2627,9 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: cubins_o = [] asm_files = [] + fatbin_cmds: list[tuple[str, str]] = [] if not _IS_WINDOWS: + cubins_to_embed: list[tuple[str, str]] = [] ld, objcopy = get_ld_and_objcopy(use_relative_path) kernels = getattr(V.graph.wrapper_code, "_kernel_name_to_body", {}) for kernel_name, value in CudaKernelParamCache.cache.items(): @@ -2509,30 +2647,7 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: and device_type == "cuda" ): if torch.version.hip is None: - current_arch = ( - cuda_compile_utils._nvcc_arch_as_compile_option() - ) - cmd = ( - # pyrefly: ignore [unbound-name] - f"{cuda_compile_utils._cuda_compiler()} -fatbin {asm_file} -o {cubin_file} " - # Triton only allows generating PTX version as same as the current arch - f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " - # Include SASS for the current specific arch - f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " - ) - try: - subprocess.run( - cmd.split(), - capture_output=True, - text=True, - check=True, - ) - except subprocess.CalledProcessError as e: - print( - f"{cmd} failed with:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}", - file=sys.stderr, - ) - raise + fatbin_cmds.append((asm_file, cubin_file)) else: # ROCm multi-arch: compile LLVM IR to multi-arch bundle @@ -2566,10 +2681,59 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: log.info("Created multi-arch bundle: %s", cubin_file) if config.aot_inductor.embed_kernel_binary: - # Embed cubin files into model.so using objcopy - cubins_o.append( - convert_cubin_to_obj(cubin_file, kernel_name, ld, objcopy) + cubins_to_embed.append((cubin_file, kernel_name)) + + # Compile PTX → fatbin in parallel (each nvcc call is independent). + # Must complete before cubin embedding below. + if fatbin_cmds: + from concurrent.futures import ThreadPoolExecutor + + current_arch = cuda_compile_utils._nvcc_arch_as_compile_option() + nvcc = cuda_compile_utils._cuda_compiler() + + def _compile_fatbin(asm_and_cubin: tuple[str, str]) -> None: + asm_f, cubin_f = asm_and_cubin + cmd = ( + f"{nvcc} -fatbin {asm_f} -o {cubin_f} " + f"-gencode arch=compute_{current_arch},code=compute_{current_arch} " + f"-gencode arch=compute_{current_arch},code=sm_{current_arch} " + ) + try: + subprocess.run( + cmd.split(), capture_output=True, text=True, check=True + ) + except subprocess.CalledProcessError as e: + print( + f"{cmd} failed with:\nstdout:\n{e.stdout}\nstderr:\n{e.stderr}", + file=sys.stderr, + ) + raise + + with ThreadPoolExecutor() as pool: + list(pool.map(_compile_fatbin, fatbin_cmds)) + + if cubins_to_embed: + # Batch all cubins into a single .o using .incbin assembly. + # This replaces N * 3 subprocess calls (ld + 2x objcopy per + # cubin) with a single compiler invocation. + try: + combined_obj = batch_convert_cubins_to_obj( + cubins_to_embed, + os.path.dirname(output_so), + cpp_compiler=get_cpp_compiler(), + ) + cubins_o.append(combined_obj) + except subprocess.CalledProcessError: + log.warning( + "Batched cubin embedding failed, " + "falling back to per-cubin objcopy" ) + for cubin_file, kernel_name in cubins_to_embed: + cubins_o.append( + convert_cubin_to_obj( + cubin_file, kernel_name, ld, objcopy + ) + ) output_name, output_dir = get_name_and_dir_from_output_file_path(output_so) so_build_options = CppTorchDeviceOptions( @@ -2579,6 +2743,26 @@ def _pad_to_alignment(raw_bytes: bytes) -> bytes: use_relative_path=use_relative_path, ) + if gpu_kernels_o and device_type == "xpu": + so_build_options = CppTorchDeviceOptions( + compiler="icpx", + vec_isa=picked_vec_isa, + device_type=device_type, + aot_mode=graph.aot_mode, + use_relative_path=use_relative_path, + extra_flags=[ + "-fsycl", + "-fsycl-targets=intel_gpu_pvc", + "-Xspirv-translator", + ( + "-spirv-ext=" + "+SPV_INTEL_split_barrier," + "+SPV_INTEL_2d_block_io," + "+SPV_INTEL_subgroup_matrix_multiply_accumulate" + ), + ], + ) + obj_srcs = [wrapper_o, kernel_o, consts_o, *gpu_kernels_o, *cubins_o] so_builder = CppBuilder( name=output_name, @@ -2938,9 +3122,12 @@ def load_async( # the optimized_code argument is present at all, since that's how the user of # this function opts in, but we do compilation and linking in one step if the # optimized_code argument is empty (as a micro-optimization). + # On GPU the C++ wrapper is just glue — the real kernels are compiled + # separately by Triton/CUDA. Always use -O1 to cut compile time. + min_optimize = optimized_code is not None or device_type != "cpu" main_build_option = CppTorchDeviceOptions( compile_only=bool(optimized_code), - min_optimize=optimized_code is not None, + min_optimize=min_optimize, # pyrefly: ignore [bad-argument-type] **compile_command, ) @@ -2988,7 +3175,7 @@ def get_hashable_command_line(build_option: BuildOptionsBase) -> str: main_build_option.precompiled_header = _precompile_header( header, main_cmd_line, - min_optimize=optimized_code is not None, + min_optimize=min_optimize, **compile_command, ) @@ -3088,7 +3275,15 @@ def _worker_compile_cpp( @clear_on_fresh_cache class CppPythonBindingsCodeCache(CppCodeCache): cache: dict[str, Callable[[], CDLL | ModuleType]] = {} - cache_clear = staticmethod(cache.clear) + _loaded_module_names: OrderedSet[str] = OrderedSet() + + @staticmethod + def cache_clear() -> None: + CppPythonBindingsCodeCache.cache.clear() + for name in CppPythonBindingsCodeCache._loaded_module_names: + sys.modules.pop(name, None) + CppPythonBindingsCodeCache._loaded_module_names.clear() + cpp_compile_command_flags = { # kernels have no dependency on libtorch "include_pytorch": False, @@ -3201,6 +3396,7 @@ def _load_library_inner(cls, path: str, key: str) -> ModuleType: assert spec is not None module = importlib.util.module_from_spec(spec) sys.modules[module_name] = module + CppPythonBindingsCodeCache._loaded_module_names.add(module_name) assert spec.loader is not None spec.loader.exec_module(module) return module @@ -3270,7 +3466,11 @@ def load_pybinding(cls, *args: Any, **kwargs: Any) -> Any: @clear_on_fresh_cache class CppWrapperCodeCache(CppPythonBindingsCodeCache): cache: dict[str, Callable[[], CDLL | ModuleType]] = {} - cache_clear = staticmethod(cache.clear) + + @staticmethod + def cache_clear() -> None: + CppWrapperCodeCache.cache.clear() + cpp_compile_command_flags = { "include_pytorch": True, "shared": True, @@ -3884,8 +4084,12 @@ def __init__( ) -> None: self.lib_path = lib_path self.is_open = False - self.DLL = cdll.LoadLibrary(lib_path) - self.is_open = True + self.open() + + def open(self) -> None: + if not self.is_open: + self.DLL = cdll.LoadLibrary(self.lib_path) + self.is_open = True def close(self) -> None: if self.is_open: @@ -3894,7 +4098,16 @@ def close(self) -> None: def _dlclose(self) -> None: f_dlclose = None + # During Python interpreter shutdown, importing modules or calling + # dlclose is unsafe. Silently skip cleanup in that case. + try: + import sys + if sys.is_finalizing(): + return + except Exception: + # import machinery may already be torn down + return if is_linux(): syms = CDLL(None) if not hasattr(syms, "dlclose"): @@ -4119,7 +4332,7 @@ def compile( f.write(f"// {cls._BACKEND} {operation_name} cmd\n// {cmd}\n") start_time = time() log.debug("%s %s: %s", cls._BACKEND, operation_name, cmd) - cmd_parts = cmd.split(" ") + cmd_parts = shlex.split(cmd) try: if cls._use_re_build(): from triton.fb.re_build_helper import run_build_command @@ -4268,6 +4481,63 @@ def _source_code_extra(cls) -> str: return extra +from torch._inductor.codegen.xpu import compile_utils as xpu_compile_utils + + +@clear_on_fresh_cache +class XPUCodeCache(CUTLASSCodeCache): + _SOURCE_CODE_SUFFIX = "cpp" + _BACKEND = "XPU" + dll_cache: dict[str, DLLWrapper] = {} + + @classmethod + def _use_re_build(cls) -> bool: + return False + + @classmethod + def _compile_command( + cls, + src_files: list[str], + dst_file: str, + dst_file_ext: str, + extra_args: list[str] | None = None, + ) -> str: + return xpu_compile_utils.xpu_compile_command( + src_files, dst_file, dst_file_ext, extra_args=extra_args + ) + + @classmethod + def _source_code_extra(cls) -> str: + extra = repr( + [ + xpu_compile_utils._sycl_compiler(), + xpu_compile_utils._sycl_compiler_options(), + cutlass_key(), + ] + ) + return extra + + @classmethod + def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]: + """ + Compiles source code and loads the generated .so file. + Returns a tuple of DLLWrapper, hash_key, source_code_path + """ + + if dst_file_ext != "so": + raise RuntimeError( + f"Only support loading a .so file for now. " + f"Requested file extension: {dst_file_ext}. Source code: {source_code}" + ) + dst_file_path, hash_key, source_code_path = cls.compile( + source_code, dst_file_ext + ) + if dst_file_path not in cls.dll_cache: + cls.dll_cache[dst_file_path] = DLLWrapper(dst_file_path) + + return (cls.dll_cache[dst_file_path], hash_key, source_code_path) + + @clear_on_fresh_cache class ROCmCodeCache: @dataclasses.dataclass diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 8359c75f74f1f..78df950947969 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -371,6 +371,7 @@ def cpp_scratch( device_op_overrides_dict: dict[str, DeviceOpOverrides] = {} +_device_op_overrides_initialized = False custom_backend_passes: dict[str, CustomGraphModulePass | None] = {} custom_backend_codegen_configs: dict[str, ConfigModule | None] = {} @@ -506,6 +507,7 @@ def init_backend_registration() -> None: from .triton import TritonScheduling from .wrapper import PythonWrapperCodegen from .wrapper_fxir import WrapperFxCodegen + from .xpu.xpu_combined_scheduling import XPUCombinedScheduling if get_scheduling_for_device("cpu") is None: cpu_backends = { @@ -551,7 +553,7 @@ def init_backend_registration() -> None: if get_scheduling_for_device("xpu") is None: register_backend_for_device( "xpu", - TritonScheduling, + XPUCombinedScheduling, PythonWrapperCodegen, CppWrapperGpu, WrapperFxCodegen, @@ -626,24 +628,28 @@ def register_device_op_overrides( device_op_overrides_dict[device] = device_op_overrides -def get_device_op_overrides(device: str) -> DeviceOpOverrides: - assert isinstance(device, str), type(device) +def _initialize_device_op_overrides(): + # Use a flag rather than checking device_op_overrides_dict, since external/test + # code may partially populate it before we are called. + global _device_op_overrides_initialized + if _device_op_overrides_initialized: + return - if not device_op_overrides_dict: - from . import ( # noqa: F401 # noqa: F401 - cpu_device_op_overrides, - mps_device_op_overrides, - ) - from .cuda import device_op_overrides # noqa: F401 - from .mtia import device_op_overrides as mtia_op_overrides # noqa: F401 - from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401 + from . import mps_device_op_overrides # noqa: F401 + from .cpu_device_op_overrides import CpuDeviceOpOverrides + from .cuda import device_op_overrides # noqa: F401 + from .mtia import device_op_overrides as mtia_op_overrides # noqa: F401 + from .xpu import device_op_overrides as xpu_op_overrides # noqa: F401 - if device not in device_op_overrides_dict: - # For backends like TPU that only need no-op overrides (Pallas handles codegen) - from .cpu_device_op_overrides import CpuDeviceOpOverrides + # TPU uses Pallas for codegen and only needs no-op overrides + register_device_op_overrides("tpu", CpuDeviceOpOverrides()) - register_device_op_overrides(device, CpuDeviceOpOverrides()) + _device_op_overrides_initialized = True + +def get_device_op_overrides(device: str) -> DeviceOpOverrides: + assert isinstance(device, str), type(device) + _initialize_device_op_overrides() return device_op_overrides_dict[device] @@ -1011,34 +1017,42 @@ def constant(value: bool | float | int, dtype: torch.dtype) -> OpVarT: return repr(value) @staticmethod + # pyrefly: ignore [bad-override] def bitwise_not(x: OpVarT) -> OpVarT: return f"~{OpOverrides.paren(x)}" @staticmethod + # pyrefly: ignore [bad-override] def logical_not(a: OpVarT) -> OpVarT: return f"{OpOverrides.paren(a)} == 0" @staticmethod + # pyrefly: ignore [bad-override] def bitwise_and(x: OpVarT, y: OpVarT) -> OpVarT: return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}" @staticmethod + # pyrefly: ignore [bad-override] def bitwise_or(x: OpVarT, y: OpVarT) -> OpVarT: return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}" @staticmethod + # pyrefly: ignore [bad-override] def bitwise_xor(x: OpVarT, y: OpVarT) -> OpVarT: return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}" @staticmethod + # pyrefly: ignore [bad-override] def bitwise_left_shift(x: OpVarT, y: OpVarT) -> OpVarT: return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}" @staticmethod + # pyrefly: ignore [bad-override] def bitwise_right_shift(x: OpVarT, y: OpVarT) -> OpVarT: return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}" @staticmethod + # pyrefly: ignore [bad-override] def int_truediv(a: OpVarT, b: OpVarT) -> OpVarT: # TODO: this is wrong # TODO: an easy bandaid is to generate runtime asserts that it's @@ -1154,6 +1168,7 @@ def inline_asm_elementwise( dtype: torch.dtype = torch.float32, is_pure: bool = True, pack: int = 1, + input_dtypes: tuple[torch.dtype, ...] | None = None, ) -> OpVarT: raise NotImplementedError( f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend" @@ -2598,7 +2613,7 @@ def maybe_append_choice( choices.append(self.generate(**kwargs)) return None except NotImplementedError as e: - log.info( # noqa: G200 + log.info( "Cannot Append Choice: %s. KernelTemplate type is %s", e, type(self), @@ -2850,7 +2865,8 @@ def store( self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None ) -> None: self.kernel.store_buffer_names.add(name) - if mode is None: + # Update store cache when mode is None or "tma" + if mode != "atomic_add": self._update_store_cache(name, value) if name not in V.graph.removed_buffers: self.kernel.store(name, index, value, mode=mode) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 0481c5e662a9e..4125ce4bcb103 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -746,6 +746,7 @@ def to_dtype(x, dtype, src_dtype=None, use_compute_types=True): return csevar @staticmethod + # pyrefly: ignore [bad-override] def round_to_int(x, dtype, src_dtype=None, use_compute_types=True): assert isinstance(x, CppCSEVariable) if src_dtype is None: @@ -763,14 +764,17 @@ def to_dtype_bitcast(x, dtype, src_dtype): return f"c10::bit_cast<{DTYPE_TO_CPP[dtype]}>({x})" @staticmethod + # pyrefly: ignore [bad-override] def abs(x): return f"std::abs({x})" @staticmethod + # pyrefly: ignore [bad-override] def sin(x): return f"std::sin({x})" @staticmethod + # pyrefly: ignore [bad-override] def cos(x): return f"std::cos({x})" @@ -779,6 +783,7 @@ def neg(x): return f"decltype({x})(-{x})" @staticmethod + # pyrefly: ignore [bad-override] def exp(x): # return f"Sleef_expf_u10({x})" return f"std::exp({x})" @@ -792,6 +797,7 @@ def expm1(x): return f"std::expm1({x})" @staticmethod + # pyrefly: ignore [bad-override] def erf(x): return f"std::erf({x})" @@ -800,14 +806,17 @@ def erfc(x): return f"std::erfc({x})" @staticmethod + # pyrefly: ignore [bad-override] def erfinv(x): return f"calc_erfinv({x})" @staticmethod + # pyrefly: ignore [bad-override] def sqrt(x): return f"std::sqrt({x})" @staticmethod + # pyrefly: ignore [bad-override] def rsqrt(x): return f"1 / std::sqrt({x})" @@ -824,14 +833,17 @@ def log1p(x): ) @staticmethod + # pyrefly: ignore [bad-override] def tan(x): return f"std::tan({x})" @staticmethod + # pyrefly: ignore [bad-override] def tanh(x): return f"std::tanh({x})" @staticmethod + # pyrefly: ignore [bad-override] def signbit(x): """ On windows std::signbit only support float type. @@ -848,14 +860,17 @@ def pow(a, b): return f"std::pow({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def log(x): return f"std::log({x})" @staticmethod + # pyrefly: ignore [bad-override] def round(x): return f"std::nearbyint({x})" @staticmethod + # pyrefly: ignore [bad-override] def floor(x): return f"std::floor({x})" @@ -867,75 +882,93 @@ def floordiv(a, b): return f"(({a} < 0) != ({b} < 0) ? ({rem} != 0 ? {quot} - 1 : {quot}) : {quot})" @staticmethod + # pyrefly: ignore [bad-override] def ceil(x): return f"std::ceil({x})" @staticmethod + # pyrefly: ignore [bad-override] def trunc(x): return f"std::trunc({x})" @staticmethod + # pyrefly: ignore [bad-override] def truncdiv(a, b): # a and b are integer type return f"{a} / {b}" @staticmethod + # pyrefly: ignore [bad-override] def fmod(a, b): return f"std::fmod({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def isinf(x): return f"std::isinf({x})" @staticmethod + # pyrefly: ignore [bad-override] def isnan(x): return f"std::isnan({x})" @staticmethod + # pyrefly: ignore [bad-override] def lgamma(x): return f"std::lgamma({x})" @staticmethod + # pyrefly: ignore [bad-override] def acos(x): return f"std::acos({x})" @staticmethod + # pyrefly: ignore [bad-override] def acosh(x): return f"std::acosh({x})" @staticmethod + # pyrefly: ignore [bad-override] def cosh(x): return f"std::cosh({x})" @staticmethod + # pyrefly: ignore [bad-override] def sinh(x): return f"std::sinh({x})" @staticmethod + # pyrefly: ignore [bad-override] def asin(x): return f"std::asin({x})" @staticmethod + # pyrefly: ignore [bad-override] def asinh(x): return f"std::asinh({x})" @staticmethod + # pyrefly: ignore [bad-override] def atan2(x, y): return f"std::atan2({x}, {y})" @staticmethod + # pyrefly: ignore [bad-override] def atan(x): return f"std::atan({x})" @staticmethod + # pyrefly: ignore [bad-override] def atanh(x): return f"std::atanh({x})" @staticmethod + # pyrefly: ignore [bad-override] def copysign(x, y): return f"std::copysign({x}, {y})" @staticmethod + # pyrefly: ignore [bad-override] def frexp(x): cache_keys = f"frexp({x})[0]", f"frexp({x})[1]" if all(V.kernel.cse.try_get(cache_key) is not None for cache_key in cache_keys): @@ -953,6 +986,7 @@ def frexp(x): return mantissa, exponent @staticmethod + # pyrefly: ignore [bad-override] def hypot(x, y): return f"std::hypot({x}, {y})" @@ -965,10 +999,12 @@ def log2(x): return f"std::log2({x})" @staticmethod + # pyrefly: ignore [bad-override] def ldexp(x, n): return f"std::ldexp({x}, {n})" @staticmethod + # pyrefly: ignore [bad-override] def nextafter(x, y): return f"std::nextafter({x}, {y})" @@ -989,14 +1025,17 @@ def relu(x): ) @staticmethod + # pyrefly: ignore [bad-override] def minimum(a, b): return f"min_propagate_nan({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def maximum(a, b): return f"max_propagate_nan({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def where(a, b, c): return f"{a} ? {b} : {c}" @@ -1034,6 +1073,7 @@ def masked(mask, body, other): return f"{mask} ? {body_var}() : {other_code}" @staticmethod + # pyrefly: ignore [bad-override] def logical_and(a, b): return f"{a} && {b}" @@ -1042,10 +1082,12 @@ def logical_not(a): return f"!{a}" @staticmethod + # pyrefly: ignore [bad-override] def logical_or(a, b): return f"{a} || {b}" @staticmethod + # pyrefly: ignore [bad-override] def logical_xor(a, b): return f"{a} != {b}" @@ -1107,6 +1149,19 @@ def bitwise_right_shift(a, b): def rand(seed: sympy.Expr, offset: sympy.Expr): return f"normalized_rand_cpu({seed}, {offset})" + @staticmethod + def rand_eager( + seed: sympy.Expr, + base_offset: sympy.Expr, + threads_per_round: sympy.Expr, + tid: sympy.Expr, + vec: sympy.Expr, + ): + # NOTE: This is a codegen fallback used by the C++ backend for eager random. + # It is not intended to provide CPU parity; parity target is CUDA eager vs compiled. + # Keep this hook to satisfy codegen paths and CI. + return f"normalized_rand_cpu({seed}, {base_offset})" + @staticmethod def randn(seed: sympy.Expr, offset: sympy.Expr): return f"randn_cpu({seed}, {offset})" @@ -1120,6 +1175,7 @@ def sigmoid(x): return f"decltype({x})(1) / (decltype({x})(1) + std::exp(-{x}))" @staticmethod + # pyrefly: ignore [bad-override] def sign(x): code = BracesBuffer() scalar_zero = f"decltype({x})(0)" @@ -1465,6 +1521,12 @@ def remainder(a, b): assert a.dtype == b.dtype, ( "remainder vec implementation expect the same inputs' dtype." ) + if is_integer_dtype(a.dtype): + # Doing blend to set the remaining bits of b to non-zero + _t = f"decltype({a})" + if V.kernel._get_raw_num_vectors(b.dtype) < 1: + b = f"{_t}::blend<{(1 << V.kernel.tiling_factor) - 1}>({_t}(1), {b})" + return f"remainder_integral({a}, {b})" return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}" @staticmethod @@ -2131,14 +2193,24 @@ def check_bounds( csevar = ops.index_expr(expr, torch.int64).value buffer = V.kernel.compute else: - # indexing in loads - prior_compute = V.kernel.compute - try: - V.kernel.compute = self.loads + # Prefer to put the assert in loads so it runs before the actual + # memory access. However, if the index expression may have already + # been CSE'd into compute by a prior ops.index_expr call, placing a + # reference to it in loads would be a forward reference (loads are + # emitted before compute in the kernel body). In that case fall + # back to compute. + idx_str = cexpr(self.rename_indexing(expr)) + if self.cse.try_get(idx_str) is not None: csevar = ops.index_expr(expr, torch.int64).value - finally: - V.kernel.compute = prior_compute - buffer = self.loads + buffer = V.kernel.compute + else: + prior_compute = V.kernel.compute + try: + V.kernel.compute = self.loads + csevar = ops.index_expr(expr, torch.int64).value + finally: + V.kernel.compute = prior_compute + buffer = self.loads size_str = V.kernel.sexpr(self.rename_indexing(size)) if upper else None @@ -3010,6 +3082,9 @@ def store(self, name, index, value, mode=None): else: raise NotImplementedError(f"store mode={mode}") + def _adjust_argreduce_index(self, index: sympy.Expr) -> sympy.Expr: + return index + def reduction(self, dtype, src_dtype, reduction_type, value): """ Perform vectorized reduction operation. @@ -3331,18 +3406,20 @@ def reduction_init_vec(self, reduction_type, dtype): return f"Welford<{vec_type}>()" if reduction_type in ("argmin", "argmax"): - cdtype = DTYPE_TO_CPP[scalar_type] + # For bool argmin/argmax, we use float for computations + compute_dtype = torch.float if dtype == torch.bool else scalar_type + cdtype = DTYPE_TO_CPP[compute_dtype] acc_type = self.reduction_acc_type_vec(reduction_type, dtype) if reduction_type == "argmin": val = ( f"std::numeric_limits<{cdtype}>::infinity()" - if is_float_dtype(dtype) + if is_float_dtype(dtype) or dtype == torch.bool else f"std::numeric_limits<{cdtype}>::max()" ) else: val = ( f"-std::numeric_limits<{cdtype}>::infinity()" - if is_float_dtype(dtype) + if is_float_dtype(dtype) or dtype == torch.bool else f"std::numeric_limits<{cdtype}>::min()" ) return f"{acc_type}({val})" @@ -3363,10 +3440,13 @@ def reduction_acc_type_vec(self, reduction_type, dtype): if is_welford_reduction(reduction_type): return f"Welford<{vec_type}>" if reduction_type in ("argmin", "argmax"): - n_src = self._get_num_vectors(scalar_type) n_idx = self._get_num_vectors(torch.int64) if dtype == torch.bool: + # For bool argmin/argmax, we use float for computations + # so n_src must be computed from float, not bool + n_src = self._get_num_vectors(torch.float) return f"IndexValueVec<{DTYPE_TO_CPP[torch.float]}, {n_src}, {n_idx}>" + n_src = self._get_num_vectors(scalar_type) return f"IndexValueVec<{DTYPE_TO_CPP[scalar_type]}, {n_src}, {n_idx}>" if dtype == torch.bool: assert reduction_type in ("min", "max", "any", "sum") @@ -3449,16 +3529,22 @@ def reduction_combine_vec( elif reduction_type in ("argmin", "argmax"): assert src_dtype is not None cdtype = DTYPE_TO_CPP[src_dtype] + compute_dtype = src_dtype if src_dtype == torch.bool: + # For bool argmin/argmax, we use float for computations cdtype = DTYPE_TO_CPP[torch.float] - n_src = self._get_num_vectors(src_dtype) + compute_dtype = torch.float + # Convert bool VecMask to float vector for argmax_combine_vec + if isinstance(next_value, CppCSEVariable) and next_value.is_vec: + (next_value,) = unify_mask_base_type(self.compute, (next_value,)) + n_src = self._get_num_vectors(compute_dtype) n_idx = self._get_num_vectors(torch.int64) t_extra = "" arg_extra = "" if index is not None: assert horizontal_reduction is not None t_extra = f", {str(horizontal_reduction).lower()}" - arg_extra = f", {index}" + arg_extra = f", {self._adjust_argreduce_index(index)}" if self.tail_size: return ( f"{reduction_type}_combine_vec<{cdtype}, {n_src}, {n_idx}{t_extra}>" @@ -3503,7 +3589,7 @@ def indirect_assert(self, var, lower, upper, mask=None): cond = f"{self._get_mask_type(var.dtype)}({cond})" if mask: if not mask.is_vec: - mask = f"{self._get_mask_type(var.dtype)}({mask})" + mask = f"{self._get_mask_type(var.dtype)}::from({mask})" # We need not check when the mask is False cond = f"({cond}) | ~({mask})" if self.tail_size: @@ -3609,6 +3695,10 @@ def inner_itervar(self): def need_vec_transpose(self, index): outer_var = self.itervars[self.outer_idx] inner_var = self.itervars[self.tiling_idx] + # Indirect indexing (SymT.TMP) variables are declared inside the inner + # loop, but transpose_mxn is emitted into preloads (before the loop). + if free_symbol_is_type(index, SymT.TMP): + return False outer_stride = stride_at_vec_range(index, outer_var, self.tiling_factor) inner_stride = stride_at_vec_range(index, inner_var, self.tiling_factor) return ( @@ -3768,6 +3858,9 @@ def transform_indexing(self, index: sympy.Expr) -> sympy.Expr: offset=self.inner_itervar(), ) + def _adjust_argreduce_index(self, index: sympy.Expr) -> sympy.Expr: + return self.transform_indexing(index) + def get_loop_body_lowp_fp(_body: LoopBody) -> tuple[torch.dtype | None, bool]: """ @@ -4733,6 +4826,77 @@ def group_fn(self, sizes): def reset_kernel_group(self): self.kernel_group = KernelGroup() + def _get_indexing_ranges_exprs(self, node): + if isinstance(node, FusedSchedulerNode): + assert len(node.snodes) > 0, node.snodes + var_ranges = None + indexing_exprs = OrderedSet[Any]() + for snode in node.snodes: + v, exprs = self._get_indexing_ranges_exprs(snode) + if var_ranges is None: + var_ranges = v + assert var_ranges == v, (var_ranges, v, node.snodes) + indexing_exprs.update(exprs) + return var_ranges, list(indexing_exprs) + + assert isinstance(node, SchedulerNode) + comp_buffer = node.node + assert isinstance(comp_buffer, ir.ComputedBuffer) + _, body, _ = comp_buffer.get_default_sizes_body() + return body.var_ranges, list(body.indexing_exprs.values()) + + def _snapshot_node_loop_states(self, node): + if isinstance(node, SchedulerNode): + return [(node, node.snapshot_loop_state())] + + assert isinstance(node, FusedSchedulerNode) + snapshots = [] + for snode in node.snodes: + assert isinstance(snode, SchedulerNode) + snapshots.append((snode, snode.snapshot_loop_state())) + return snapshots + + def _align_compatible_range_nodes(self, node1, node2): + assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) + assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) + + _, (vars1, reduce1) = node1.group + _, (vars2, reduce2) = node2.group + assert reduce1 == () and reduce2 == (), (reduce1, reduce2) + + node_to_recomp = node1 if len(vars1) < len(vars2) else node2 + ref_node = node2 if len(vars1) < len(vars2) else node1 + assert isinstance(node_to_recomp, SchedulerNode) + + ref_indexing_constraints = self._get_indexing_ranges_exprs(ref_node) + node_to_recomp.recompute_size_and_body( + extra_indexing_constraints=ref_indexing_constraints + ) + + _, (vars1, _) = node1.group + _, (vars2, _) = node2.group + if vars1 == vars2: + return True + + node_to_recomp_indexing_constraints = self._get_indexing_ranges_exprs( + node_to_recomp + ) + if isinstance(ref_node, SchedulerNode): + ref_node.recompute_size_and_body( + extra_indexing_constraints=node_to_recomp_indexing_constraints + ) + else: + assert isinstance(ref_node, FusedSchedulerNode) + for snode in ref_node.snodes: + assert isinstance(snode, SchedulerNode) + snode.recompute_size_and_body( + extra_indexing_constraints=node_to_recomp_indexing_constraints + ) + + _, (vars1, _) = node1.group + _, (vars2, _) = node2.group + return vars1 == vars2 + def fuse(self, node1, node2): if node1.is_foreach() or node2.is_foreach(): return ForeachKernelSchedulerNode.fuse(node1, node2) @@ -4744,69 +4908,10 @@ def fuse(self, node1, node2): self._why_fuse_nodes(node1, node2) == ReasonFusedNodes.COMPATIBLE_RANGES_NO_REDUCTION ): - assert isinstance(node1, (SchedulerNode, FusedSchedulerNode)) - assert isinstance(node2, (SchedulerNode, FusedSchedulerNode)) - - _, (vars1, reduce1) = node1.group - _, (vars2, reduce2) = node2.group - assert reduce1 == () and reduce2 == (), (reduce1, reduce2) - - def get_indexing_ranges_exprs(node): - if isinstance(node, FusedSchedulerNode): - assert len(node.snodes) > 0, node.snodes - var_ranges = None - indexing_exprs = OrderedSet[Any]() - for snode in node.snodes: - v, exprs = get_indexing_ranges_exprs(snode) - if var_ranges is None: - var_ranges = v - assert var_ranges == v, (var_ranges, v, node.snodes) - indexing_exprs.update(exprs) - return var_ranges, list(indexing_exprs) - else: - assert isinstance(node, SchedulerNode) - comp_buffer = node.node - assert isinstance(comp_buffer, ir.ComputedBuffer) - _, body, _ = comp_buffer.get_default_sizes_body() - return body.var_ranges, list(body.indexing_exprs.values()) - - node_to_recomp = node1 if len(vars1) < len(vars2) else node2 - assert isinstance(node_to_recomp, SchedulerNode) - - ref_node = node2 if len(vars1) < len(vars2) else node1 - - ref_indexing_constraints = get_indexing_ranges_exprs(ref_node) - - node_to_recomp.recompute_size_and_body( - extra_indexing_constraints=ref_indexing_constraints - ) - - _, (vars1, _) = node1.group - _, (vars2, _) = node2.group - - if vars1 == vars2: - return FusedSchedulerNode.fuse(node1, node2) - - # recompute ref_node if its ranges are also changed - node_to_recomp_indexing_constraints = get_indexing_ranges_exprs( - node_to_recomp + assert self._align_compatible_range_nodes(node1, node2), ( + node1.group, + node2.group, ) - if isinstance(ref_node, SchedulerNode): - ref_node.recompute_size_and_body( - extra_indexing_constraints=node_to_recomp_indexing_constraints - ) - else: - assert isinstance(ref_node, FusedSchedulerNode) - for snode in ref_node.snodes: - assert isinstance(snode, SchedulerNode) - snode.recompute_size_and_body( - extra_indexing_constraints=node_to_recomp_indexing_constraints - ) - ref_node = FusedSchedulerNode(ref_node.scheduler, ref_node.snodes) - - _, (vars1, _) = node1.group - _, (vars2, _) = node2.group - assert vars1 == vars2, (vars1, vars2) return FusedSchedulerNode.fuse(node1, node2) elif self.can_fuse_vertical_outer_loop(node1, node2): return OuterLoopFusedSchedulerNode.fuse( @@ -4865,6 +4970,7 @@ def _can_fuse_nodes_with_compatible_ranges(self, node1, node2): if isinstance(ref_node, FusedSchedulerNode): ranges_set = OrderedSet[tuple[Any, ...]]() for snode in ref_node.snodes: + assert isinstance(snode, SchedulerNode) if isinstance(snode.node, ir.TemplateBuffer): break assert isinstance(snode.node, ir.ComputedBuffer) @@ -4882,7 +4988,13 @@ def _can_fuse_nodes_with_compatible_ranges(self, node1, node2): if ranges1 != ranges2: return False - return True + snapshots = self._snapshot_node_loop_states(node_to_recomp) + snapshots.extend(self._snapshot_node_loop_states(ref_node)) + try: + return self._align_compatible_range_nodes(node1, node2) + finally: + for node, state in reversed(snapshots): + node.restore_loop_state(state) def _can_fuse_horizontal_impl(self, node1, node2): assert isinstance( @@ -5179,7 +5291,7 @@ def get_call_ranges(node: BaseSchedulerNode): for _node in node.get_outer_nodes() ): # Ref to the typical case of local buffer in - # https://github.com/pytorch/pytorch/blob/1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 # noqa: B950 + # https://github.com/pytorch/pytorch/blob/1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 # where the buffer is with size of last dim and contiguous. # Only support this typical case at first. visited_scheduler_nodes: OrderedSet[str] = OrderedSet() @@ -5594,6 +5706,10 @@ def codegen_group(self, name=None) -> str: # 3. Function body with code.indent(): + code.writeline("std::atomic inductor_cpu_integer_div_error{0};") + code.writeline( + "inductor_cpu_integer_div_error_flag = &inductor_cpu_integer_div_error;" + ) if enable_kernel_profile: graph_id = V.graph.graph_id prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else "" @@ -5608,6 +5724,10 @@ def codegen_group(self, name=None) -> str: for old, new in self.args.aliases(): code.writeline(f"auto {old} = {new};") code.splice(self.loops_code) + code.writeline("inductor_cpu_integer_div_error_flag = nullptr;") + code.writeline( + "inductor_cpu_throw_if_integer_div_error(inductor_cpu_integer_div_error);" + ) return code.getvalue() def call_kernel(self, wrapper, kernel_name): diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 1df9b20862e47..e5299897ead7e 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -104,6 +104,7 @@ def get_promote_dtype(args): # pyrefly: ignore [no-matching-overload] functools.reduce( torch.promote_types, # type: ignore[arg-type] + # pyrefly: ignore [bad-argument-type] [n.dtype for n in args if isinstance(n, CppCSEVariable)], ) if all(n.dtype is not None for n in args if isinstance(n, CppCSEVariable)) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 6cb3ff894fd28..93be898a8cf7f 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -14,7 +14,7 @@ import torch import torch._higher_order_ops.torchbind -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile import torch._ops from torch._inductor.runtime.runtime_utils import dynamo_timed from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes @@ -123,15 +123,16 @@ def _generate_temporary_array_pointer( # e.g. const double** is possible, but not const double* const*. This means # that an array containing pointers must _already_ be properly const-qualified # by the c_type, and not add additional const-ness. - # MSVC does not support implicitly converting a const iterator to a const pointer. - ptr_call = ( - "data()" - if force_mutable or c_type.endswith("*") or cpp_builder.is_msvc_cl() - else "cbegin()" - ) - return ( - f"std::array<{c_type}, {len(elements)}>{{{', '.join(elements)}}}.{ptr_call}" - ) + array_expr = f"std::array<{c_type}, {len(elements)}>{{{', '.join(elements)}}}" + if force_mutable or c_type.endswith("*"): + return f"{array_expr}.data()" + + # MSVC does not support implicitly converting a const iterator to a const + # pointer, so use data() and cast to keep const qualification. + if cpp_builder.is_msvc_cl(): + return f"static_cast({array_expr}.data())" + + return f"{array_expr}.cbegin()" def _generate_kernel_call_helper( self, @@ -147,6 +148,7 @@ def _generate_kernel_call_helper( inductor_meta=None, graph_name="", original_fxnode_name=None, + current_stream_idx=None, ): """ Generates kernel call code. @@ -2733,7 +2735,7 @@ def parse_arg(arg_type: torch.JitType, codegen_arg: str) -> str: dispatch_lines.writeline("AOTI_TORCH_ERROR_CODE_CHECK(") with dispatch_lines.indent(): dispatch_lines.writeline( - f'aoti_torch_call_dispatcher("{op_overload._schema.name}", "{op_overload._schema.overload_name}", dispatch_vars.data())' # noqa: B950 + f'aoti_torch_call_dispatcher("{op_overload._schema.name}", "{op_overload._schema.overload_name}", dispatch_vars.data())' ) dispatch_lines.writeline(");") @@ -2806,7 +2808,7 @@ def generate_fallback_kernel_with_runtime_lookup_python( for idx, output_arg in enumerate(output_args): if output_arg is None: continue - lines += f"{output_arg} = reinterpret_cast(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));\n" # noqa: B950 + lines += f"{output_arg} = reinterpret_cast(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));\n" if raw_outputs: declarations_before_scope = [ diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index 4f1a0989c8097..9263b909086b4 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -5,11 +5,12 @@ import sympy import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile import torch._ops +from torch.utils._ordered_set import OrderedSet from .. import config, ir -from ..utils import sympy_product +from ..utils import IndentedBuffer, sympy_product from ..virtualized import V from .cpp_utils import DTYPE_TO_CPP from .cpp_wrapper_cpu import CppWrapperCpu @@ -46,6 +47,8 @@ def __init__(self): assert self.device == "cpu", "ArrayRefTensor only supported on CPU!" self.allow_stack_allocation = config.aot_inductor.allow_stack_allocation self.stack_allocated_buffers: dict[BufferName, BufferLike] = {} + self.v2_raw_wrapper_body = IndentedBuffer() + self.v2_raw_output_refs: list[str] | None = None @staticmethod def create( @@ -70,6 +73,16 @@ def get_input_cpp_type(input): return DTYPE_TO_CPP[dtype] return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>" + @staticmethod + def get_input_element_cpp_type(input): + if isinstance(input, sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype(input) + assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}" + return DTYPE_TO_CPP[dtype] + return DTYPE_TO_CPP[input.get_dtype()] + @staticmethod def get_device_include_path(device: str) -> str: assert device == "cpu", "ArrayRef only supported on CPU!" @@ -77,7 +90,8 @@ def get_device_include_path(device: str) -> str: return "#include " return "#include " - def codegen_input_numel_asserts(self): + def codegen_input_numel_asserts(self, indented_buffer=None): + writer = indented_buffer or self.prefix for name, buf in V.graph.graph_inputs.items(): if isinstance(buf, sympy.Expr): continue @@ -86,7 +100,154 @@ def codegen_input_numel_asserts(self): if sympy_product(buf.get_size()) == 0: continue numel = buf.get_numel() - self.prefix.writeline(f"assert_numel({name}, {numel});") + writer.writeline(f"assert_numel({name}, {numel});") + + def _codegen_v2_raw_input_bindings(self, code: IndentedBuffer): + for idx, (input_key, input_value) in enumerate(V.graph.graph_inputs.items()): + input_cpp_type = CppWrapperCpuArrayRef.get_input_element_cpp_type( + input_value + ) + if isinstance(input_value, sympy.Expr): + # cond / symint wrappers can surface symbolic scalar inputs here. + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype(input_value) + assert dtype is not None, "Fails to get the dtype of the sympy.Expr" + input_tensor = f"{input_key}_arrayref_tensor" + code.writeline( + f"auto {input_tensor} = torch::aot_inductor::c_to_arrayref_tensor<{input_cpp_type}>(c_inputs[{idx}]);" + ) + self.codegen_tensor_item(dtype, input_tensor, input_key, code) + else: + code.writeline( + f"auto {input_key} = torch::aot_inductor::c_to_arrayref_tensor<{input_cpp_type}>(c_inputs[{idx}]);" + ) + + def _codegen_v2_raw_input_symbols(self, code: IndentedBuffer) -> None: + bound_vars = OrderedSet[sympy.Symbol]() + graph_inputs = self.get_graph_inputs() + inputs = [ + (k, v) for k, v in graph_inputs.items() if isinstance(v, sympy.Symbol) + ] + [(k, v) for k, v in graph_inputs.items() if not isinstance(v, sympy.Symbol)] + + # Temporarily redirect self.prefix so the base class + # codegen_input_symbol_assignment writes into our buffer. + orig_prefix = self.prefix + self.prefix = code + try: + for name, value in inputs: + self.codegen_input_symbol_assignment(name, value, bound_vars) + finally: + self.prefix = orig_prefix + + for _, value in inputs: + if not isinstance(value, ir.TensorBox): + continue + for expr in [*value.get_size(), *value.get_stride()]: + if not isinstance(expr, sympy.Expr) or isinstance(expr, sympy.Symbol): + continue + undefined_symbols = [ + sym for sym in expr.free_symbols if sym not in bound_vars + ] + if len(undefined_symbols) > 0: + raise AssertionError( + f"For {expr}, expected {undefined_symbols} to have been codegen-ed." + ) + + def _codegen_v2_raw_prelude(self, code: IndentedBuffer): + self._codegen_v2_raw_input_bindings(code) + + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + for idx, constants_key in enumerate(V.graph.constants.keys()): + code.writeline(f"""auto {constants_key} = constants_->at({idx});""") + + self._codegen_v2_raw_input_symbols(code) + + self.codegen_input_numel_asserts(code) + code.writeline( + "[[maybe_unused]] auto& kernels = static_cast(*this->kernels_.get());" + ) + + def _codegen_v2_raw_outputs( + self, code: IndentedBuffer, output_refs: list[str] + ) -> None: + cst_names = V.graph.constants.keys() + + def write_output_to_c_array(idx: int, output: str) -> None: + output_arrayref_name = f"output_arrayref_{idx}" + code.splice( + f""" + std::tuple_element_t<{idx}, AOTInductorModelOutputs> {output_arrayref_name}; + convert_handle_to_arrayref_tensor({output}, {output_arrayref_name}); + torch::aot_inductor::arrayref_tensor_to_c({output_arrayref_name}, c_outputs[{idx}]); + """ + ) + + for idx, output in enumerate(output_refs): + if output == "nullptr": + continue + + is_constant_buffer = output in cst_names + output_buffer = V.graph.graph_outputs[idx] + if isinstance(output_buffer, ir.BaseView): + output_storage = output_buffer.unwrap_view() + assert isinstance(output_storage, (ir.BaseView, ir.MutableBox)) + if isinstance(output_storage.data, ir.ConstantBuffer): + is_constant_buffer = True + + if isinstance(output_buffer, ir.ShapeAsConstantBuffer): + output_tensor = f"scalar_to_tensor_{next(self.scalar_to_tensor_id)}" + code.writeline( + f"RAIIAtenTensorHandle {output_tensor} = scalar_to_tensor_handle({output});" + ) + write_output_to_c_array(idx, output_tensor) + continue + + output_is_tensor_handle_expr = ( + f"std::is_same_v," + "RAIIAtenTensorHandle> || " + f"std::is_same_v," + "AtenTensorHandle> || " + f"std::is_same_v," + "ConstantHandle>" + ) + code.writeline(f"if constexpr ({output_is_tensor_handle_expr}) {{") + with code.indent(): + cached_output_name = f"cached_output_{next(self.cached_output_id)}" + code.writeline( + f"thread_local RAIIAtenTensorHandle {cached_output_name};" + ) + if is_constant_buffer: + code.splice( + f""" + AtenTensorHandle {cached_output_name}_tmp; + aoti_torch_clone({output}, &{cached_output_name}_tmp); + {cached_output_name} = {cached_output_name}_tmp; + """ + ) + else: + code.writeline(f"{cached_output_name} = {output}.release();") + write_output_to_c_array(idx, cached_output_name) + code.writeline("} else {") + with code.indent(): + cached_output_name = f"cached_output_{next(self.cached_output_id)}" + output_arrayref_type = f"output_arrayref_{idx}_type" + output_element_type = f"output_arrayref_{idx}_element_type" + output_arrayref_name = f"output_arrayref_{idx}" + code.splice( + f""" + thread_local ThreadLocalCachedOutputArray> + {cached_output_name}({output}); + {cached_output_name}.copy_data_from({output}); + using {output_arrayref_type} = std::tuple_element_t<{idx}, AOTInductorModelOutputs>; + using {output_element_type} = typename {output_arrayref_type}::value_type; + auto {output_arrayref_name} = {cached_output_name}.arrayref_tensor<{output_element_type}>(); + torch::aot_inductor::arrayref_tensor_to_c({output_arrayref_name}, c_outputs[{idx}]); + """ + ) + code.writeline("}") def generate_extern_kernel_alloc(self, *args, **kwargs): # Disable stack allocation for extern kernels. @@ -117,6 +278,7 @@ def _generate_kernel_call_helper( inductor_meta=None, graph_name="", original_fxnode_name=None, + current_stream_idx=None, ): """ Generates kernel call code. @@ -153,6 +315,7 @@ def _generate_kernel_call_helper( self.writeline(self.wrap_kernel_call(kernel_name, new_args)) def write_wrapper_decl(self): + """Declare the generated AOTI wrapper entry points.""" inputs_len = len(V.graph.graph_inputs.keys()) if V.graph.aot_mode: if ( @@ -271,6 +434,54 @@ def write_wrapper_decl(self): } """ ) + + self.suffix.splice( + f""" + // C-ABI-safe variant: uses flat AOTInductorArrayRefTensor arrays + // instead of std::tuple across the DSO boundary, and + // runs directly on the descriptor arrays to avoid + // DSO-side tuple marshaling. + extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterfaceV2( + AOTInductorModelHandle model_handle, + int32_t num_inputs, + const AOTInductorArrayRefTensor* c_inputs, + int32_t num_outputs, + AOTInductorArrayRefTensor* c_outputs) {{ + constexpr int32_t expected_num_inputs = {len(V.graph.graph_inputs)}; + constexpr int32_t expected_num_outputs = {len(V.graph.graph_outputs)}; + auto model = reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({{ + if (num_inputs != expected_num_inputs) {{ + throw std::runtime_error( + std::string("AOTInductorModelRunMinimalArrayrefInterfaceV2 expected ") + + std::to_string(expected_num_inputs) + + " inputs but got " + + std::to_string(num_inputs)); + }} + if (num_outputs != expected_num_outputs) {{ + throw std::runtime_error( + std::string("AOTInductorModelRunMinimalArrayrefInterfaceV2 expected ") + + std::to_string(expected_num_outputs) + + " outputs but got " + + std::to_string(num_outputs)); + }} + if (num_inputs > 0 && c_inputs == nullptr) {{ + throw std::runtime_error( + "AOTInductorModelRunMinimalArrayrefInterfaceV2 received null input descriptors"); + }} + if (num_outputs > 0 && c_outputs == nullptr) {{ + throw std::runtime_error( + "AOTInductorModelRunMinimalArrayrefInterfaceV2 received null output descriptors"); + }} + model->run_impl_minimal_arrayref_interface_v2_raw( + c_inputs, + c_outputs, + (torch::aot_inductor::DeviceStreamType)nullptr, + nullptr); + }}) + }} + """ + ) else: self.prefix.splice(run_impl_proto) else: @@ -368,6 +579,11 @@ def generate_return(self, output_refs: list[str]): and config.aot_inductor.use_minimal_arrayref_interface ) # For brevity. + if arr_iface and V.graph.aot_mode: + self.v2_raw_wrapper_body.clear() + self.v2_raw_wrapper_body.splice(self.wrapper_call) + self.v2_raw_output_refs = list(output_refs) + def use_thread_local_cached_output_tensor(idx, output): cached_output_name = f"cached_output_{next(self.cached_output_id)}" cache_type = "Array" if arr_iface else "Tensor" @@ -490,6 +706,31 @@ def use_thread_local_cached_output_tensor(idx, output): if arr_iface: self.wrapper_call.writeline("return output_arrayref_tensors;") + def generate_before_suffix(self, result): + super().generate_before_suffix(result) + if self.v2_raw_output_refs is None: + return + + raw_impl = IndentedBuffer() + raw_impl.splice( + """ + void AOTInductorModel::run_impl_minimal_arrayref_interface_v2_raw( + const AOTInductorArrayRefTensor* c_inputs, + AOTInductorArrayRefTensor* c_outputs, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + with raw_impl.indent(): + self._codegen_v2_raw_prelude(raw_impl) + raw_impl.splice(self.v2_raw_wrapper_body) + self._codegen_v2_raw_outputs(raw_impl, self.v2_raw_output_refs) + raw_impl.writeline( + "} // AOTInductorModel::run_impl_minimal_arrayref_interface_v2_raw" + ) + result.splice(raw_impl) + def memory_plan(self): from .memory_planning import MemoryPlanner @@ -669,6 +910,90 @@ def _assert_safe_to_use_borrow_arrayref_tensor_as_tensor(self): def is_safe_to_use_borrow_arrayref_tensor_as_tensor(self): return not self.allow_stack_allocation and not self.stack_allocated_buffers + def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): + assert len(subgraph.graph.graph_inputs) == len(outer_inputs) + + for (inner_input, inner_input_val), outer_input in zip( + subgraph.graph.graph_inputs.items(), outer_inputs + ): + if not isinstance(inner_input_val, ir.TensorBox): + continue + + # Wrap with a generic lambda so if constexpr can discard the + # ill-formed branch (if constexpr only discards in dependent + # contexts, i.e. templates / generic lambdas). + self.writeline(f"AtenTensorHandle {inner_input}_handle;") + self.writeline( + f"[&](auto&& _t) {{ " + f"if constexpr (::torch::aot_inductor::is_arrayref_tensor_type_v" + f">) {{ " + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out(" + f"borrow_arrayref_tensor_as_tensor(_t), &{inner_input}_handle)); " + f"}} else {{ " + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out(" + f"_t, &{inner_input}_handle)); " + f"}} }}({outer_input});" + ) + self.writeline(f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);") + + def codegen_while_loop(self, while_loop, stack_output=False): + if stack_output: + raise NotImplementedError("NYI cpp wrapper for while_loop_stack_output") + is_bool_pred = isinstance( + while_loop.cond_subgraph.graph.graph_outputs[0], ir.ShapeAsConstantBuffer + ) + name = while_loop.get_name() + outer_carried_inputs = [ + buf.codegen_reference() for buf in while_loop.carried_inputs + ] + outer_additional_inputs = [ + buf.codegen_reference() for buf in while_loop.additional_inputs + ] + cond_result_name = f"{name}_cond_result" + if is_bool_pred: + self.writeline(f"bool {cond_result_name};") + else: + self.writeline(f"RAIIAtenTensorHandle {cond_result_name};") + + cond_outer_inputs = [] + for inp, out in zip(outer_carried_inputs, while_loop.outputs): + out_name = out.get_name() + self.writeline(f"AtenTensorHandle {out_name}_handle;") + self.writeline( + "AOTI_TORCH_ERROR_CODE_CHECK(" + f"aoti_torch_assign_tensors_out(borrow_arrayref_tensor_as_tensor({inp}), " + f"&{out_name}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);") + cond_outer_inputs.append(out_name) + + cond_outer_inputs.extend(outer_additional_inputs) + + cond_outer_outputs = [cond_result_name] + body_outer_inputs = list(cond_outer_inputs) + body_outer_outputs = body_outer_inputs[: len(outer_carried_inputs)] + + self.writeline("while (1) {") + self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph)) + self.codegen_subgraph( + while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs + ) + + if is_bool_pred: + cond_result = f"{cond_result_name}" + else: + cond_result = f"{cond_result_name}_scalar" + self.codegen_tensor_item(torch.bool, cond_result_name, cond_result) + self.writeline(f"if (!{cond_result}) break;") + + self.writeline(ExitSubgraphLine(self)) + self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph)) + self.codegen_subgraph( + while_loop.body_subgraph, body_outer_inputs, body_outer_outputs + ) + self.writeline(ExitSubgraphLine(self)) + self.writeline("}") + def generate_c_shim_extern_kernel_call( self, kernel: str, args: list[str], device: str, **_ ) -> None: diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 0c32caab0e451..2caa18b89f230 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -77,151 +77,6 @@ def signature_is_tma_desc(sig: str | None) -> bool: return False -# Lazy compile helper code - only included in JIT mode -LAZY_COMPILE_HELPER = """ -#include -#include -#include - -struct LazyKernelCompileResult { - std::string cubin_path; - std::string mangled_name; - int num_warps; - int shared_mem; - int xblock; - int yblock; - int zblock; - int r0block; - int rsplit; - int rsplit_size; - int config_index; - int global_scratch; - int profile_scratch; -}; - -// Cached module and function references -static PyObject* triton_lazy_compile_module = nullptr; -static PyObject* start_kernel_compile = nullptr; -static PyObject* run_triton_kernel_with_autotune = nullptr; - -static inline void loadLazyCompileFuncs() { - if (triton_lazy_compile_module == nullptr) { - triton_lazy_compile_module = PyImport_ImportModule("torch._inductor.runtime.triton_lazy_compile"); - AOTI_TORCH_CHECK(triton_lazy_compile_module, "Failed to import triton_lazy_compile"); - - start_kernel_compile = PyObject_GetAttrString(triton_lazy_compile_module, "start_kernel_compile"); - AOTI_TORCH_CHECK(start_kernel_compile, "Failed to get start_kernel_compile"); - - run_triton_kernel_with_autotune = PyObject_GetAttrString(triton_lazy_compile_module, "run_triton_kernel_with_autotune"); - AOTI_TORCH_CHECK(run_triton_kernel_with_autotune, "Failed to get run_triton_kernel_with_autotune"); - } -} - -static inline std::string getStringAttr(PyObject* obj, const char* attr) { - RAIIPyObject val = PyObject_GetAttrString(obj, attr); - AOTI_TORCH_CHECK(val, "Failed to get attribute"); - return PyUnicode_AsUTF8(val); -} - -static inline int getIntAttr(PyObject* obj, const char* attr) { - RAIIPyObject val = PyObject_GetAttrString(obj, attr); - AOTI_TORCH_CHECK(val, "Failed to get attribute"); - return THPUtils_unpackLong(val); -} - -static inline int getOptionalIntAttr(PyObject* obj, const char* attr, int sentinel = -1) { - RAIIPyObject val = PyObject_GetAttrString(obj, attr); - AOTI_TORCH_CHECK(val, "Failed to get attribute"); - return (val.get() != Py_None) ? THPUtils_unpackLong(val) : sentinel; -} - -static inline LazyKernelCompileResult extractCompileResult(PyObject* result) { - LazyKernelCompileResult compile_result; - compile_result.cubin_path = getStringAttr(result, "cubin_path"); - compile_result.mangled_name = getStringAttr(result, "mangled_name"); - compile_result.num_warps = getIntAttr(result, "num_warps"); - compile_result.shared_mem = getIntAttr(result, "shared_mem"); - compile_result.xblock = getIntAttr(result, "xblock"); - compile_result.yblock = getIntAttr(result, "yblock"); - compile_result.zblock = getIntAttr(result, "zblock"); - compile_result.r0block = getIntAttr(result, "r0block"); - compile_result.rsplit = getIntAttr(result, "rsplit"); - compile_result.rsplit_size = getIntAttr(result, "rsplit_size"); - compile_result.config_index = getOptionalIntAttr(result, "config_index"); - compile_result.global_scratch = getOptionalIntAttr(result, "global_scratch"); - compile_result.profile_scratch = getOptionalIntAttr(result, "profile_scratch"); - return compile_result; -} - -template -static inline PyObject* convertArgToPython(const T& arg) { - using DecayedT = std::decay_t; - if constexpr (std::is_same_v) { - at::Tensor* tensor_ptr = torch::aot_inductor::tensor_handle_to_tensor_pointer(arg); - return THPVariable_Wrap(*tensor_ptr); - } else if constexpr (std::is_same_v) { - at::Tensor* tensor_ptr = torch::aot_inductor::tensor_handle_to_tensor_pointer(arg.get()); - return THPVariable_Wrap(*tensor_ptr); - } else if constexpr (std::is_same_v) { - PyObject* py_arg = arg ? Py_True : Py_False; - Py_INCREF(py_arg); - return py_arg; - } else if constexpr (std::is_integral_v) { - return PyLong_FromLongLong(static_cast(arg)); - } else if constexpr (std::is_floating_point_v) { - return PyFloat_FromDouble(static_cast(arg)); - } else { - AOTI_TORCH_CHECK(false, "Invalid input type to convertArgToPython"); - } -} - -template -static inline LazyKernelCompileResult runTritonKernelWithAutotune( - const std::string& kernel_name, - cudaStream_t stream, - const Args&... kernel_args) { - py::gil_scoped_acquire_simple acquire; - - constexpr size_t num_args = sizeof...(Args); - RAIIPyObject py_args_list = PyList_New(num_args); - AOTI_TORCH_CHECK(py_args_list, "Failed to create args list"); - - size_t idx = 0; - auto add_arg = [&py_args_list, &idx](PyObject* py_arg) { - AOTI_TORCH_CHECK(py_arg, "Failed to convert argument"); - PyList_SetItem(py_args_list, idx++, py_arg); - }; - (add_arg(convertArgToPython(kernel_args)), ...); - - RAIIPyObject call_args = PyTuple_Pack(3, - PyUnicode_FromString(kernel_name.c_str()), - PyLong_FromVoidPtr(stream), - py_args_list.get() - ); - AOTI_TORCH_CHECK(call_args, "Failed to create call args"); - - RAIIPyObject result = PyObject_CallObject(run_triton_kernel_with_autotune, call_args); - AOTI_TORCH_CHECK(result, "Failed to run kernel with autotuning"); - - return extractCompileResult(result); -} - -static inline void startKernelCompile(const std::string& kernel_name, const std::string& kernel_source) { - py::gil_scoped_acquire_simple acquire; - - RAIIPyObject py_name = PyUnicode_FromString(kernel_name.c_str()); - RAIIPyObject py_source = PyUnicode_FromString(kernel_source.c_str()); - AOTI_TORCH_CHECK(py_name && py_source, "Failed to create Python strings"); - - RAIIPyObject call_args = PyTuple_Pack(2, py_name.get(), py_source.get()); - AOTI_TORCH_CHECK(call_args, "Failed to create call args"); - - RAIIPyObject result = PyObject_CallObject(start_kernel_compile, call_args); - AOTI_TORCH_CHECK(result, "Failed to start kernel compilation"); -} -""" - - def _unpack_tma_descriptor_args(var_name: str, sig_type: str) -> list[str]: """Unpack a StableTMADescriptor into kernel launch args. @@ -334,7 +189,7 @@ def _write_wrapper_signature( ) # Write function signature - prefix.writeline(f"static inline void {self.wrapper_name}(") + prefix.writeline(f"static __attribute__((noinline)) void {self.wrapper_name}(") with prefix.indent(): for i, param in enumerate(param_lines): comma = "," if i < len(param_lines) - 1 else "" @@ -539,12 +394,14 @@ def _generate_lazy_scratch( device_type, _ = wrapper.codegen_device(torch.device(get_gpu_type())).split( ", " ) + device_ptr_type = wrapper.device_codegen.cpp_device_ptr() for scratch_name in ("global_scratch", "profile_scratch"): size_expr = f"{kernel_name}_result.{scratch_name}" var = f"{scratch_name}_ptr" prefix.splice( - f"""\ - CUdeviceptr {var} = 0; + maybe_hipify_code_wrapper( + f"""\ + {device_ptr_type} {var} = 0; RAIIAtenTensorHandle {var}_tensor; if ({size_expr} > 0) {{ int64_t {var}_size[] = {{{size_expr}}}; @@ -554,9 +411,10 @@ def _generate_lazy_scratch( 1, {var}_size, {var}_stride, {dtype_str}, {device_type}, device_idx_, &{var}_handle)); {var}_tensor = RAIIAtenTensorHandle({var}_handle); - {var} = reinterpret_cast({var}_tensor.data_ptr()); + {var} = reinterpret_cast<{device_ptr_type}>({var}_tensor.data_ptr()); }} """ + ) ) call_args_str += f", &{var}" return call_args_str @@ -607,11 +465,17 @@ def _generate_lazy_launch( ) call_args_str = self._generate_lazy_scratch(prefix, wrapper, call_args_str) + launch_args = ( + f"{kernel_name}, grid_0, grid_1, grid_2," + f" {kernel_name}_result.num_warps," + f" {kernel_name}_result.shared_mem," + f" kernel_args_, stream_" + ) + prefix.splice( f"""\ void* kernel_args_[] = {{{call_args_str}}}; - launchKernel({kernel_name}, grid_0, grid_1, grid_2, - {kernel_name}_result.num_warps, {kernel_name}_result.shared_mem, kernel_args_, stream_); + launchKernel({launch_args}); """ ) @@ -620,10 +484,6 @@ def generate_lazy(self, wrapper: CppWrapperGpu): Generate C++ code that embeds Triton source and compiles it at runtime. """ prefix = wrapper.prefix - if not wrapper._lazy_compile_helper_emitted: - prefix.splice(LAZY_COMPILE_HELPER) - wrapper._lazy_compile_helper_emitted = True - kernel_name = self.kernel_name # Track kernel names for parallel initialization wrapper._lazy_kernel_names.append(kernel_name) @@ -638,9 +498,8 @@ def generate_lazy(self, wrapper: CppWrapperGpu): ) prefix.writeline(kernel_var_decl) # Use delimited raw string to handle )" in kernel source - kernel_body = ( - f'R"TRITON(\n{self.kernel_name_to_body.get(kernel_name, "")}\n)TRITON"' - ) + kernel_source_str = self.kernel_name_to_body.get(kernel_name, "") + kernel_body = f'R"TRITON(\n{kernel_source_str}\n)TRITON"' prefix.writeline(f"static const char* {kernel_name}_source = {kernel_body};") prefix.writeline(f"static LazyKernelCompileResult {kernel_name}_result;") @@ -656,19 +515,32 @@ def generate_lazy(self, wrapper: CppWrapperGpu): tma_tensor_args = self.tma_tensor_args num_autotune_args = len(wrapper_arg_names) - len(tma_tensor_args) autotune_arg_list = [] - for name in wrapper_arg_names[:num_autotune_args]: + # Track which args need scalar extraction for the autotune call. + # UnwrapUnspecArg args are 0-dim tensors in C++ that Triton expects + # as Python scalars; we use codegen_tensor_item to extract them. + scalar_extractions: list[tuple[str, str, torch_dtype]] = [] + for idx, name in enumerate(wrapper_arg_names[:num_autotune_args]): if name in tma_signature_types: autotune_arg_list.append(f"_tma_tensor_{name}") + elif isinstance(self.arg_types[idx], UnwrapUnspecArg): + scalar_var = f"_autotune_scalar_{name}" + scalar_extractions.append((name, scalar_var, self.arg_types[idx].dtype)) + autotune_arg_list.append(scalar_var) else: autotune_arg_list.append(name) autotune_args = ", ".join(autotune_arg_list) # Lazy compile with autotuning on first invocation with prefix.indent(): - prefix.splice( - f"""\ - if ({kernel_name} == nullptr) {{ + prefix.writeline(f"if ({kernel_name} == nullptr) {{") + with prefix.indent(): + for tensor_name, scalar_var, dtype in scalar_extractions: + wrapper.codegen_tensor_item( + dtype, tensor_name, scalar_var, indented_buffer=prefix + ) + prefix.splice( + f"""\ {kernel_name}_result = runTritonKernelWithAutotune( - "{kernel_name}", stream_, {autotune_args}); + _module_pending_kernels, "{kernel_name}", stream_, {autotune_args}); {kernel_name} = loadKernel( {kernel_name}_result.cubin_path, @@ -677,9 +549,9 @@ def generate_lazy(self, wrapper: CppWrapperGpu): // First invocation already ran the kernel, so return early return; - }} - """ - ) + """ + ) + prefix.writeline("}") self._generate_lazy_grid(prefix) self._generate_lazy_launch( @@ -762,8 +634,9 @@ def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params): ] arg_types = [arg_type_lookup[name] for name in call_args] arg_signatures = [triton_meta["signature"][name] for name in call_args] + num_ctas = params.get("config", {}).get("num_ctas", 1) scratch_spaces = { - name: params[name] + name: params[name] * num_ctas for name in ["global_scratch", "profile_scratch"] if params.get(name, None) is not None } @@ -785,8 +658,6 @@ def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params): "kernel_args_", "stream_", ] - if wrapper.device == "xpu": - launch_kernel_args.append(str(params["threads_per_warp"])) enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ "linux", @@ -941,7 +812,6 @@ def __init__(self) -> None: self._kernel_name_to_body: dict[str, str] = {} self._triton_call_wrappers: dict[str, DeferredTritonCallWrapper] = {} self.autotune_input_prefix = "_REAL_AUTOTUNE_INPUT" - self._lazy_compile_helper_emitted = False self._lazy_kernel_names: list[str] = [] @staticmethod @@ -1062,7 +932,7 @@ def finalize_prefix(self): # Generate parallel kernel compilation initialization function if self._lazy_kernel_names: start_compile_calls = "\n ".join( - f'startKernelCompile("{name}", {name}_source);' + f'startKernelCompile(_module_pending_kernels, "{name}", {name}_source);' for name in self._lazy_kernel_names ) self.prefix.splice( @@ -1070,6 +940,8 @@ def finalize_prefix(self): // Start parallel compilation of all Triton kernels static inline void start_all_triton_kernel_compiles() {{ loadLazyCompileFuncs(); + _module_pending_kernels = PyDict_New(); + AOTI_TORCH_CHECK(_module_pending_kernels, "Failed to create pending kernels dict"); {start_compile_calls} }} @@ -1303,6 +1175,7 @@ def _generate_kernel_call_helper( inductor_meta=None, graph_name="", original_fxnode_name=None, + current_stream_idx=None, ): """ Override the default value of argument 'gpu' to True here. @@ -1410,7 +1283,7 @@ def _generate_kernel_call_helper( self.writeline(f"{wrapper_name}({', '.join(call_args)});") else: casted = [] - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] for arg_type, arg in zip(arg_types, call_args): new_arg = arg if arg_type.endswith("*") and arg != "nullptr": @@ -1420,9 +1293,8 @@ def _generate_kernel_call_helper( call_args_str = ", ".join(casted) self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});") - @staticmethod def prepare_triton_wrapper_args( - call_args: list[Any], arg_types: list[Any] + self, call_args: list[Any], arg_types: list[Any] ) -> tuple[list[Any], list[Any]]: assert len(call_args) == len(arg_types), (call_args, arg_types) new_args = [] @@ -1436,7 +1308,10 @@ def prepare_triton_wrapper_args( elif isinstance(arg, bool): new_args.append(str(arg).lower()) elif isinstance(arg, (int, float, SymbolicCallArg)): - new_args.append(str(arg)) + if isinstance(arg, float): + new_args.append(self.generate_float_value(arg)) + else: + new_args.append(str(arg)) else: new_args.append(cexpr(V.graph.sizevars.simplify(arg))) new_args_types.append(arg_type) diff --git a/torch/_inductor/codegen/cpp_wrapper_mps.py b/torch/_inductor/codegen/cpp_wrapper_mps.py index 7fa8e171b84c9..86b6cf92870cf 100644 --- a/torch/_inductor/codegen/cpp_wrapper_mps.py +++ b/torch/_inductor/codegen/cpp_wrapper_mps.py @@ -45,6 +45,7 @@ def _generate_kernel_call_helper( inductor_meta: dict[str, Any] | None = None, graph_name: str = "", original_fxnode_name: str | None = None, + current_stream_idx: int | None = None, ) -> None: """ Generates MPS kernel call code. It should look something like: diff --git a/torch/_inductor/codegen/cuda/compile_utils.py b/torch/_inductor/codegen/cuda/compile_utils.py index 79852d5eeed66..1da1990bb454d 100644 --- a/torch/_inductor/codegen/cuda/compile_utils.py +++ b/torch/_inductor/codegen/cuda/compile_utils.py @@ -2,7 +2,6 @@ import logging import os import shutil -from pathlib import Path import torch from torch._inductor import config @@ -156,7 +155,7 @@ def _nvcc_compiler_options() -> list[str]: "-w", f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", config.cutlass.compile_opt_level, - "-std=c++17", + "-std=c++20", "--expt-relaxed-constexpr", "-DNDEBUG", ] @@ -228,44 +227,3 @@ def cuda_compile_command( else: autotuning_log.debug("CUDA command: %s", res) return res - - -class CUDACompileSourceCapturingContext: - # Helper class for Benchmarking and Testing CUTLASS Kernels in isolation. - # Can be used to capture the sourcecode passed to CUDACodeCache.compile - - def __init__(self): - self.sources = [] - self._compile_patch = None - - def __enter__(self, *args, **kwargs): - import unittest.mock as mock - - import torch._inductor.codecache - - _compile_method_orig = torch._inductor.codecache.CUDACodeCache.compile - - def my_compile(source_code, dst_file_ext, extra_args: list[str] | None = None): - self.sources.append(source_code) - return _compile_method_orig(source_code, dst_file_ext) - - # pyrefly: ignore [bad-assignment] - self._compile_patch = mock.patch( - "torch._inductor.codecache.CUDACodeCache.compile", my_compile - ) - self._compile_patch.__enter__(*args, **kwargs) # type: ignore[union-attr] - return self - - def __exit__(self, *args, **kwargs): - self._compile_patch.__exit__(*args, **kwargs) # type: ignore[union-attr] - - -def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path): - # returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run - # Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled. - - extra_args = ["-DGENERATE_STANDALONE_RUNNER=1", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"] - compile_command = cuda_compile_command( - [str(srcpath)], str(exepath), "exe", extra_args=extra_args - ) - return compile_command diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py index d271d9f57fee5..8d0f210fd6c25 100644 --- a/torch/_inductor/codegen/cuda/device_op_overrides.py +++ b/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -336,7 +336,10 @@ def cpp_scratch( prefix = f"{prefix}_" if prefix else "" var_name = f"{prefix}scratch_{idx}" if workspace.size > 0: - size_array = f"int64_t {var_name}_size[] = {{{workspace.size}}};" + size_expr = ( + f"static_cast({workspace.size}) * grid_0 * grid_1 * grid_2" + ) + size_array = f"int64_t {var_name}_size[] = {{{size_expr}}};" stride_array = f"int64_t {var_name}_stride[] = {{1}};" device_type = "cached_torch_device_type_cuda" device_idx = "device_idx_" diff --git a/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py b/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py index 709bbaf515887..175860efbaff0 100644 --- a/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py +++ b/torch/_inductor/codegen/cutedsl/cutedsl_kernel.py @@ -544,11 +544,19 @@ def _emit_scalar_fragment( self, expr_str: str, cute_dtype: str, torch_dtype: torch.dtype ) -> str: """ - Convert SSA expression to indexable scalar for tensor loads. + Convert expression to indexable scalar for tensor loads. Workaround for lack of gather support: SSA values cannot be used directly - as indices. This generates code to convert SSA → indexable scalar. + as indices in tensor loads. This generates code to convert SSA → indexable + scalar. Compile-time integer constants are already indexable and are + returned directly without the SSA round-trip. """ + # Constant integer expressions (e.g. sympy-folded offsets like "0") + # are already valid indices — skip the ssa_to_indexable round-trip + # which only accepts TensorSSA, not bare Python ints. + if expr_str.lstrip("-").isdigit(): + return expr_str + result = self.kernel.cse.newvar(dtype=torch_dtype) self.kernel.body.writeline( f"{result} = ssa_to_indexable({expr_str}, {cute_dtype})" diff --git a/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py b/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py index 77d6e8585f934..ebfb2e0c08507 100644 --- a/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py +++ b/torch/_inductor/codegen/cutedsl/cutedsl_op_overrides.py @@ -254,6 +254,7 @@ def remainder(a, b): return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} % {b})") @staticmethod + # pyrefly: ignore [bad-override] def exp(x: CuteDSLArg) -> CuteDSLArg: """Exponential using CuteDSL cute.math.exp2 with log2(e) scaling.""" if CuteDSLOpOverrides._get_cse_var(x) is None: @@ -263,26 +264,31 @@ def exp(x: CuteDSLArg) -> CuteDSLArg: ) @staticmethod + # pyrefly: ignore [bad-override] def sqrt(x: CuteDSLArg) -> CuteDSLArg: """Square root using CuteDSL cute.math.sqrt function.""" return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.sqrt({x})") @staticmethod + # pyrefly: ignore [bad-override] def log(x: CuteDSLArg) -> CuteDSLArg: """Natural logarithm using CuteDSL cute.math.log function.""" return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.log({x})") @staticmethod + # pyrefly: ignore [bad-override] def cos(x: CuteDSLArg) -> CuteDSLArg: """Cosine using CuteDSL cute.math.cos function.""" return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.cos({x})") @staticmethod + # pyrefly: ignore [bad-override] def sin(x: CuteDSLArg) -> CuteDSLArg: """Sine using CuteDSL cute.math.sin function.""" return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.sin({x})") @staticmethod + # pyrefly: ignore [bad-override] def erf(x: CuteDSLArg) -> CuteDSLArg: """Error function using CuteDSL cute.math.erf function.""" return CuteDSLOpOverrides._apply_unary_op(x, "cute.math.erf({x})") @@ -338,14 +344,17 @@ def _minmax(a: CuteDSLArg, b: CuteDSLArg, *, op: str) -> CuteDSLArg: return f"({lhs} if {lhs} {op} {rhs} else {rhs})" @staticmethod + # pyrefly: ignore [bad-override] def maximum(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: return CuteDSLOpOverrides._minmax(a, b, op=">") @staticmethod + # pyrefly: ignore [bad-override] def minimum(a: CuteDSLArg, b: CuteDSLArg) -> CuteDSLArg: return CuteDSLOpOverrides._minmax(a, b, op="<") @staticmethod + # pyrefly: ignore [bad-override] def where( condition: CuteDSLArg, a: CuteDSLArg, @@ -384,6 +393,7 @@ def pow(a: CuteDSLArg, b: CuteDSLArg): return CuteDSLOpOverrides._apply_binary_op(a, b, "({a} ** {b})") @staticmethod + # pyrefly: ignore [bad-override] def abs(x: CuteDSLArg) -> CuteDSLArg: """Absolute value using CuteDSL cute.math.abs function.""" if isinstance(x, CSEVariable): diff --git a/torch/_inductor/codegen/cutedsl/cutedsl_template.py b/torch/_inductor/codegen/cutedsl/cutedsl_template.py index 17f681a302730..85ceb9d014558 100644 --- a/torch/_inductor/codegen/cutedsl/cutedsl_template.py +++ b/torch/_inductor/codegen/cutedsl/cutedsl_template.py @@ -57,10 +57,10 @@ def maybe_append_choice( choices.append(self.generate(**kwargs)) return None except NotImplementedError as e: - log.debug("CuteDSL template choice generation failed: %s", e) # noqa: G200 + log.debug("CuteDSL template choice generation failed: %s", e) return e except Exception as e: - log.debug("CuteDSL template choice generation error: %s", e) # noqa: G200 + log.debug("CuteDSL template choice generation error: %s", e) return NotImplementedError(f"CuteDSL template failed: {e}") def generate(self, **kwargs: Any) -> ChoiceCaller: diff --git a/torch/_inductor/codegen/cutlass/cache.py b/torch/_inductor/codegen/cutlass/cache.py index d1ee4a6006a9d..b990f052e1d16 100644 --- a/torch/_inductor/codegen/cutlass/cache.py +++ b/torch/_inductor/codegen/cutlass/cache.py @@ -10,7 +10,6 @@ import torch._inductor.config as config from torch._inductor.codecache import cutlass_key -from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version from torch._inductor.codegen.cutlass import serialization, utils from torch._inductor.codegen.cutlass.serialization import ( get_cutlass_operation_serializer, @@ -27,11 +26,11 @@ def get_config_request_key( arch: str, - cuda_version: str, + toolkit_version: str, instantiation_level: str, ) -> str: """ - Return a key for the full ops, based on cutlass key, arch, cuda version, instantiation level, and serialization.py file hash. + Return a key for the full ops, based on cutlass key, arch, toolkit version, instantiation level, and serialization.py file hash. """ # Get hash of serialization.py and cutlass_utils.py files using their module file paths @@ -47,7 +46,7 @@ def get_file_hash(file_module): [ cutlass_key().hex(), arch, - cuda_version, + toolkit_version, instantiation_level, serialization_hash, cutlass_utils_hash, @@ -65,7 +64,7 @@ def _generate_config_filename(request_key: str) -> str: @clear_on_fresh_cache @functools.cache -def maybe_fetch_ops() -> list[Any] | None: +def maybe_fetch_ops(device_type: str) -> list[Any] | None: """ Fetch ops from databases. """ @@ -73,10 +72,12 @@ def maybe_fetch_ops() -> list[Any] | None: return None # setup - arch: str = get_cuda_arch() - # get_cuda_version might return "12.4.0" or "12.4" - # but we want to use "12.4" - version: str = ".".join(get_cuda_version().split(".")[:2]) + arch: str = utils.cutlass_arch(device_type) + version: str = utils.toolkit_version(device_type) + if device_type == "cuda": + # get_cuda_version might return "12.4.0" or "12.4" + # but we want to use "12.4" + version = ".".join(version.split(".")[:2]) instantiation_level: str = config.cutlass.cutlass_instantiation_level # filename and filepath diff --git a/torch/_inductor/codegen/cutlass/gemm_template.py b/torch/_inductor/codegen/cutlass/gemm_template.py index a27af3d9225e0..3b9135cea2880 100644 --- a/torch/_inductor/codegen/cutlass/gemm_template.py +++ b/torch/_inductor/codegen/cutlass/gemm_template.py @@ -32,6 +32,7 @@ from ...utils import is_dynamic, Placeholder from ...virtualized import V from ..common import IndentedBuffer +from ..cuda import cuda_env from . import utils as cutlass_utils from .kernel import CUTLASSTemplateKernel from .python_evt import CutlassEVTCodegen, scaled_mm_evt @@ -394,12 +395,17 @@ for (int i=0; i IndentedBuffer: #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" - #include "cutlass/gemm/device/gemm_sparse.h" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" @@ -636,6 +651,13 @@ def header(self) -> IndentedBuffer: #include "cutlass/util/tensor_view_io.h" """ ) + if self.device_type != "xpu": + # XPU SYCL-TLA does not support sparse gemm yet + res.splice( + """ + #include "cutlass/gemm/device/gemm_sparse.h" + """ + ) if inductor_cutlass_config.generate_test_runner and not is_dynamic( *self.input_nodes, self.output_node ): @@ -713,12 +735,14 @@ def set_alignment(torch_layout, op_element) -> bool: bool: True if the alignment was successfully updated, False otherwise. """ alignment = cutlass_utils.get_max_alignment(torch_layout) - cuda_arch = cutlass_utils.get_cuda_arch() - if cuda_arch and int(cuda_arch) >= 90 and alignment < op_element.alignment: - return False - else: - op_element.alignment = alignment - return True + if torch.cuda.is_available(): + cuda_arch = cuda_env.get_cuda_arch() + cuda_arch = cutlass_utils._normalize_cuda_arch(cuda_arch) + if cuda_arch and int(cuda_arch) >= 90 and alignment < op_element.alignment: + return False + + op_element.alignment = alignment + return True @staticmethod def should_swap_XW( @@ -794,7 +818,7 @@ def fix_op_layout( if all_match: return op log.warning( - f"Cutlass GEMM Layout change: Input and/or output layouts have changed between autotuning/retuning and call to render on {self}. Applying workaround. This can lead to suboptimal performance. Match List: {match_list}" # noqa: G004, B950 + f"Cutlass GEMM Layout change: Input and/or output layouts have changed between autotuning/retuning and call to render on {self}. Applying workaround. This can lead to suboptimal performance. Match List: {match_list}" # noqa: G004 ) new_op = copy.deepcopy(op) @@ -961,7 +985,10 @@ def filter_op( # TODO: update epilogue functor according to epilogues. op.element_epilogue = op.accumulator_type() - if self.use_fast_accum is not None: + if ( + self.use_fast_accum is not None + and int(cutlass_utils._normalize_cuda_arch(cuda_env.get_cuda_arch())) == 90 + ): is_op_fast_accum = "fastaccum" in op.configuration_name() if self.use_fast_accum ^ is_op_fast_accum: return None @@ -989,6 +1016,12 @@ def filter_op( ): return None + # `_procedural_name` is decorated with @functools.cached_property in cutlass, and its value is + # cached based on the key `self`. After we modify some attributes of + # `self` (e.g., layout or alignment), the `self` itself doesn’t change, so the + # cached value remains stale. We therefore need to clear the cached value so that + # `_procedural_name` can be recomputed with the updated attributes. + del op._procedural_name return op def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: ignore[name-defined] # noqa: F821 @@ -1010,10 +1043,10 @@ def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type: return self.filtered_ops_cache[self.cache_key] with dynamo_timed("CUTLASSGemmTemplate.maybe_fetch_ops"): - maybe_ops = maybe_fetch_ops() + maybe_ops = maybe_fetch_ops(self.device_type) if maybe_ops is None: log.debug("Cannot fetch ops from cache, generating ops from scratch") - full_ops = cutlass_utils.gen_ops() + full_ops = cutlass_utils.gen_ops(self.device_type) ops = pytree.tree_flatten(full_ops)[0] else: log.debug("Using cached ops from cache") @@ -1101,7 +1134,7 @@ def render( # type: ignore[override] ) -> str: """ The primary entry point for the code rendering process used in this template. - Renders the Cutlass based CUDA C++ code for the GEMM Kernel that this template is designed to implement, + Renders the Cutlass based CUDA/XPU C++ code for the GEMM Kernel that this template is designed to implement, including potentially fused epilogues. Args: @@ -1111,7 +1144,7 @@ def render( # type: ignore[override] **kwargs: Additional keyword arguments. Currently unused. Returns: - str: Cutlass based CUDA C++ code fragment as a string, to be used by the current + str: Cutlass based CUDA/XPU C++ code fragment as a string, to be used by the current CUTLASSTemplateKernel or autotuning code. Note: @@ -1126,10 +1159,11 @@ def render( # type: ignore[override] "op argument is required and has to be an instance of GemmOperation" ) - if epilogue_nodes and not self._has_tma_epilogue(op): - raise NotImplementedError( - "Non-TMA epilogue visitor tree is not supported in Cutlass." - ) + if epilogue_nodes: + if self.device_type == "cuda" and not self._has_tma_epilogue(op): + raise NotImplementedError( + "Non-TMA epilogue visitor tree is not supported in NV-Cutlass." + ) assert len(self.input_nodes) >= 2 and self.output_node is not None X, W = self.input_nodes[0], self.input_nodes[1] @@ -1332,7 +1366,7 @@ def test_call_statement( names_str: str = "", ) -> str: """ - Helper method to render the Cutlass CUDA C++ code required for calling the GEMM operation in the standalone + Helper method to render the Cutlass CUDA/XPU C++ code required for calling the GEMM operation in the standalone test runner that might also be generated along with the rest of the code, if the corresponding config is enabled. @@ -1347,7 +1381,7 @@ def test_call_statement( f"(({arg_type}){arg_name}_data.get())" for arg_type, arg_name in zip(arg_types, arg_names) ] - return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, 0, 0, 0, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" # noqa: B950 + return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, B, lda, ldb, ldc, ldd, 0, 0, 0, swizzle, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);" def _render_evt( self, @@ -1357,7 +1391,7 @@ def _render_evt( name_to_buffer: dict[str, Buffer], output_dtype: torch.dtype, accumulator_dtype: torch.dtype, - ) -> tuple[str, str, str, EVTArgRenames]: # type: ignore[name-defined] # noqa: F821 + ) -> tuple[str, str, str, EVTArgRenames]: # type: ignore[name-defined] raise NotImplementedError("_render_evt in CUTLASSGemmTemplate not implemented") @@ -1407,7 +1441,7 @@ def _get_template_args( return (GEMM_ARGS_CUTLASS_3X, GEMM_ARGS_CUTLASS_3X_EPILOGUE) @staticmethod - def _has_tma_epilogue( # noqa: F821 # type: ignore[arg-type,name-defined] + def _has_tma_epilogue( # type: ignore[arg-type,name-defined] op: "cutlass_library.gemm_op.GemmOperation", # type: ignore[name-defined,arg-type] # noqa: F821 ) -> bool: # type: ignore[name-defined] """Helper method: Determine whether a given Cutlass GEMM op has a TMA Epilogue""" @@ -1421,7 +1455,9 @@ def _has_tma_epilogue( # noqa: F821 # type: ignore[arg-type,name-defined] return result @staticmethod - def supports_epilogue_fusion(op: GemmOperation) -> bool: + def supports_epilogue_fusion(op: GemmOperation, device_type: str) -> bool: + if device_type == "xpu": + return True return CUTLASS3xGemmTemplate._has_tma_epilogue(op) def _are_inputs_layout_compatible(self, layouts: list[Layout]) -> bool: @@ -1535,6 +1571,7 @@ def _render_evt( {k: name_to_buffer[v] for k, v in var_name_to_buffer_name.items()}, # type: ignore[arg-type,misc] V.graph.sizevars.guarding_hint_or_throw, kernel_schedule=op.kernel_schedule, + device_type=self.device_type, ) return ( @@ -1587,7 +1624,7 @@ def _define_gemm_instance( op: GemmOperation, evt_name: str | None = None, ) -> tuple[str, str]: - """Defines and renders the Cutlass / CUDA C++ code for a given GEMM operation instance. + """Defines and renders the Cutlass / CUDA/XPU C++ code for a given GEMM operation instance. This function uses the Cutlass library to generate key parts of the codegen process. General Matrix Multiply forms a core part of a number of scientific applications, so this efficient and adaptable implementation is @@ -1605,7 +1642,9 @@ def _define_gemm_instance( from .lib_extensions import gemm_operation_extensions as gemm_extensions - emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT(evt_name=evt_name) # type: ignore[call-arg] + emitter = gemm_extensions.EmitGemmUniversal3xInstanceWithEVT( + evt_name=evt_name, device_type=self.device_type + ) # type: ignore[call-arg] if not hasattr(op, "epilogue_functor") or not isinstance( op.epilogue_functor, enum.Enum @@ -1666,7 +1705,7 @@ def render_gemm_arguments( epilogue_args, ) -> str: """ - Render the Cutlass CUDA C++ code required for passing arguments to the GEMM operation. + Render the Cutlass CUDA/XPU C++ code required for passing arguments to the GEMM operation. Args: argument_template (str): Template for the GEMM operation arguments. @@ -1679,11 +1718,11 @@ def render_gemm_arguments( Y (IRNode): The output tensor. alpha (float): Scaling factor for the product of the inputs. beta (float): Scaling factor for the output tensor. - kernel (CUTLASSTemplateKernel): CUDA Template kernel for the operation. + kernel (CUTLASSTemplateKernel): CUDA/XPU Template kernel for the operation. epilogue_args (any): Additional arguments for the epilogue state. Returns: - str: A block of CUDA C++ code as a string, ready to be used as arguments for the GEMM operation. + str: A block of CUDA/XPU C++ code as a string, ready to be used as arguments for the GEMM operation. Note: If `should_swap_xw` is True, a transpose operation will be applied to the X, W, Bias, and Y tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped @@ -1845,7 +1884,7 @@ def _alignment_match( # SparseGemm in CUTLASS has specific alignment check that for # small k could make some of the choices throw kMisalignedOperand # CUTLASS error when run, see: - # https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/include/cutlass/gemm/kernel/sparse_gemm.h#L198-L200 # noqa: B950 + # https://github.com/NVIDIA/cutlass/blob/e01b9b5029b7caca5a43c29f7d2714d7cf1dcae8/include/cutlass/gemm/kernel/sparse_gemm.h#L198-L200 # So, let's skip these choices if that would be the case. X = self.input_nodes[0] return (X.get_size()[1] * 2) % op.tile_description.tile_shape[2] == 0 diff --git a/torch/_inductor/codegen/cutlass/kernel.py b/torch/_inductor/codegen/cutlass/kernel.py index c9af268303e7d..f0bf04a107a06 100644 --- a/torch/_inductor/codegen/cutlass/kernel.py +++ b/torch/_inductor/codegen/cutlass/kernel.py @@ -11,6 +11,7 @@ import torch._inductor.config as config from torch import dtype as torch_dtype +from torch._inductor.codegen.common import get_device_op_overrides from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch._inductor.scheduler import BaseSchedulerNode from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder @@ -111,6 +112,25 @@ def find_layout_arg( ): raise AssertionError("All matching layout args should be identical") return first_match + attr_values = node.get_size() if attr == "size" else node.get_stride() + if dim >= len(attr_values): + return None + expr = attr_values[dim] + fallback_matches = [] + for arg in itertools.chain.from_iterable(self.layout_args.values()): + if arg.attr != attr: + continue + if arg.node.get_name() != node.get_name(): + continue + arg_values = ( + arg.node.get_size() if arg.attr == "size" else arg.node.get_stride() + ) + if arg.dim >= len(arg_values): + continue + if arg_values[arg.dim] == expr: + fallback_matches.append(arg) + if fallback_matches: + return fallback_matches[0] return None def add_layout_arg( @@ -194,13 +214,12 @@ class CUTLASSTemplateKernel(CUTLASSKernel): Template kernels defined by Cutlass in C++. """ - _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream" - def __init__( self, kernel_name: str, runtime_arg_info: list["ArgInfo"], runtime_arg_values: list[Any], + device_type: str = "cuda", # type: ignore[assignment] ) -> None: """ Initializes a new instance of the CUTLASSTemplateKernel class. @@ -212,6 +231,9 @@ def __init__( self.kernel_name = kernel_name self.runtime_arg_info = runtime_arg_info self.runtime_arg_values = runtime_arg_values + self.device_type = device_type + self.device_codegen = get_device_op_overrides(self.device_type) + self._EXTRA_CPP_ARGS = f"size_t* workspace_size, uint8_t* workspace, {self.device_codegen.cpp_stream_type()} stream" def check_not_null(self, node: IRNode) -> str: """ @@ -244,6 +266,22 @@ def check_not_null(self, node: IRNode) -> str: def get_signature(self) -> str: return self.signature + def _collect_unbound_layout_free_symbols(self, node: IRNode) -> OrderedSet[Expr]: + free_symbols: OrderedSet[Expr] = OrderedSet() + for attr_name, values in ( + ("size", node.get_size()), + ("stride", node.get_stride()), + ): + attr = attr_name # help mypy narrow the Literal argument below + for dim, expr in enumerate(values): + if not isinstance(expr, Expr): + continue + if self.find_layout_arg(node, attr, dim) is not None: + continue + for symbol in expr.free_symbols: + free_symbols.add(symbol) # type: ignore[arg-type] + return free_symbols + def def_kernel( self, inputs: list[IRNode], @@ -284,27 +322,18 @@ def def_kernel( self.named_nodes[name] = node self.args.input_buffers[node.get_name()] = name - free_symbols: OrderedSet[Expr] = OrderedSet() for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs): if node is not None: # NB: named nodes must be populated in the order of names self.named_nodes[name] = node self.args.output_buffers[node.get_name()] = name - if name not in ( - "X", - "W", - "Bias", - "Y", - ): # we handle these symbolic shapes explicitly - for expr in itertools.chain(node.get_size(), node.get_stride()): - if isinstance(expr, Expr): - for s in expr.free_symbols: - free_symbols.add(s) # type: ignore[arg-type] - arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE) self.init_layout_args() + free_symbols: OrderedSet[Expr] = OrderedSet() + for node in self.named_nodes.values(): + free_symbols |= self._collect_unbound_layout_free_symbols(node) size_vars = ["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"] size_vars.extend(str(s) for s in free_symbols) self.size_args.extend(free_symbols) diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/evt_extensions.py b/torch/_inductor/codegen/cutlass/lib_extensions/evt_extensions.py index 5357dc0bd98a5..9dd810445668c 100644 --- a/torch/_inductor/codegen/cutlass/lib_extensions/evt_extensions.py +++ b/torch/_inductor/codegen/cutlass/lib_extensions/evt_extensions.py @@ -10,7 +10,7 @@ ) from torch.utils._ordered_set import OrderedSet -from ..utils import torch_dtype_to_cutlass_type, try_import_cutlass +from ..utils import cutlass_arch, torch_dtype_to_cutlass_type, try_import_cutlass EpilogueFunctor = Any # EpilogueFunctor local class defined in _trace @@ -34,6 +34,7 @@ dtype2ctype, ) from cutlass_cppgen.backend.evt import ( # type: ignore[import-not-found] + backend as evt_backend, EpilogueFunctorVisitor, ) from cutlass_cppgen.backend.evt.backend.emitter_base import ( # type: ignore[import-not-found] @@ -42,9 +43,6 @@ from cutlass_cppgen.backend.evt.backend.sm100_emitter import ( # type: ignore[import-not-found] Sm100CollectiveEpilogue, ) - from cutlass_cppgen.backend.evt.backend.sm90_emitter import ( # type: ignore[import-not-found] - CollectiveEpilogue, - ) from cutlass_cppgen.backend.evt.frontend import ( # type: ignore[import-not-found] PythonASTFrontend, ) @@ -58,7 +56,6 @@ TileDescription, ) - from torch._inductor.codegen.cuda import cuda_env from torch._inductor.utils import IndentedBuffer _CUTLASS_C_DTYPES = OrderedSet(dtype2ctype.values()) # type: ignore[var-annotated] @@ -125,14 +122,27 @@ def trace( name_to_buffer: dict[str, Buffer], size_hint_fn: Callable[[Expr | int], int], kernel_schedule: Any | None = None, + device_type: str = "cuda", **kwargs: dict[str, Any], ) -> tuple[str, str, str, EVTArgRenames]: - cuda_arch = int(cuda_env.get_cuda_arch()) # type: ignore[arg-type] - assert cuda_arch >= 90, "Only SM90+ is supported for EVT" - epilogue_functor = _trace(fn_src, example_tensors, cuda_arch, **kwargs) - visitor = EpilogueFunctorVisitor(cuda_arch, epilogue_functor) - fusion_callbacks = FusionCallbacks(visitor.graph, cuda_arch, emit_CD=False) - if cuda_arch < 100: + arch = int(cutlass_arch(device_type)) + assert device_type != "cuda" or arch >= 90, ( + "For CUDA, only SM90+ is supported for EVT" + ) + epilogue_functor = _trace(fn_src, example_tensors, arch, **kwargs) + visitor = EpilogueFunctorVisitor(arch, epilogue_functor) + fusion_callbacks = FusionCallbacks(visitor.graph, arch, emit_CD=False) + arch_prefix = "xe" if device_type == "xpu" else "sm" + + if device_type == "xpu" or arch < 100: + try: + evt_emitter = getattr(evt_backend, f"{arch_prefix}{arch}_emitter") + CollectiveEpilogue = evt_emitter.CollectiveEpilogue + except AttributeError as e: + raise NotImplementedError( + f"EVT backend is not supported on Arch {arch_prefix}{arch}." + ) from e + collective_epilogue = CollectiveEpilogue( tile_description, epilogue_schedule, @@ -179,7 +189,6 @@ def parse( # pyrefly: ignore [missing-attribute] self.visit(self.ast) - cc = int(cuda_env.get_cuda_arch()) epilogue_functor = EpilogueFunctor(cc=cc, **kwargs) epilogue_functor.trace(example_tensors) return epilogue_functor diff --git a/torch/_inductor/codegen/cutlass/lib_extensions/gemm_operation_extensions.py b/torch/_inductor/codegen/cutlass/lib_extensions/gemm_operation_extensions.py index d10669d40bea0..0cc47cbea5c3c 100644 --- a/torch/_inductor/codegen/cutlass/lib_extensions/gemm_operation_extensions.py +++ b/torch/_inductor/codegen/cutlass/lib_extensions/gemm_operation_extensions.py @@ -8,15 +8,15 @@ if try_import_cutlass(): import enum - from cutlass_library.gemm_operation import * # noqa: F401, F403 - from cutlass_library.library import * # noqa: F401, F403 + from cutlass_library.gemm_operation import * # noqa: F403 + from cutlass_library.library import * # noqa: F403 _LOGGER = logging.getLogger(__name__) class EmitGemmUniversal3xInstanceWithEVT: """Responsible for emitting a CUTLASS 3.x template definition""" - def __init__(self, operation_suffix="", evt_name=None): + def __init__(self, operation_suffix="", evt_name=None, device_type="cuda"): self.operation_suffix = operation_suffix self.includes = [ "cutlass/cutlass.h", @@ -32,8 +32,8 @@ def __init__(self, operation_suffix="", evt_name=None): ${element_c}, ${element_epilogue} >""" - self.evt_name = evt_name + self.device_type = device_type self.gemm_template = """ using ${operation_name}_epilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -175,6 +175,8 @@ def emit(self, operation): f"cutlass::gemm::collective::StageCountAutoCarveout(\ sizeof(typename {str(operation.procedural_name())}_epilogue::SharedStorage))>" ) + if self.device_type == "xpu": + stage_count_string = "cutlass::gemm::collective::StageCountAuto" epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto" @@ -350,6 +352,11 @@ def emit(self, operation): if self.evt_name: epilogue_functor = self.evt_name + if self.device_type == "xpu": + arch = f"cutlass::arch::Xe{operation.arch}" + else: + arch = f"cutlass::arch::Sm{operation.arch}" + values = { "operation_name": operation_name_str, "operation_suffix": self.operation_suffix, @@ -369,7 +376,7 @@ def emit(self, operation): "element_accumulator": DataTypeTag[operation.accumulator_type()], "opcode_class_main": OpcodeClassTag[opcode_class_main], "opcode_class_epi": OpcodeClassTag[opcode_class_epi], - "arch": f"cutlass::arch::Sm{operation.arch}", + "arch": arch, "tile_shape_m": str(tile_shape_m), "tile_shape_n": str(tile_shape_n), "tile_shape_k": str(tile_shape_k), diff --git a/torch/_inductor/codegen/cutlass/python_evt.py b/torch/_inductor/codegen/cutlass/python_evt.py index 5ad3880f0faa9..26e4a9b9cc062 100644 --- a/torch/_inductor/codegen/cutlass/python_evt.py +++ b/torch/_inductor/codegen/cutlass/python_evt.py @@ -169,6 +169,7 @@ def __init__(self, accumulator_node_name: str, removed_buffers: OrderedSet[str]) self.cur_node: ComputedBuffer | None = None self.name_to_buffer = V.graph.name_to_buffer | V.graph.graph_inputs for name in V.graph.constants: + # pyrefly: ignore [unsupported-operation] self.name_to_buffer[name] = V.graph.add_tensor_constant( V.graph.constants[name], name ) diff --git a/torch/_inductor/codegen/cutlass/scheduling.py b/torch/_inductor/codegen/cutlass/scheduling.py index 74f3cd9504482..74fb187dfe4c0 100644 --- a/torch/_inductor/codegen/cutlass/scheduling.py +++ b/torch/_inductor/codegen/cutlass/scheduling.py @@ -2,7 +2,7 @@ import hashlib import logging from collections.abc import Sequence -from typing import cast +from typing import cast, TypeGuard from torch._inductor.codegen.cutlass.python_evt import ( CutlassEVTCodegen, @@ -40,7 +40,7 @@ class CUTLASSScheduling(BaseScheduling): """ Partial Scheduling implementation for cutlass C++ Kernels. This class is intended to be used in combination with TritonScheduling, - and delegated to by CUDACombinedScheduling. + and delegated to by CUDACombinedScheduling/XPUCombinedScheduling. It handles fusion decisions and cutlass C++ specific template code generation. """ @@ -53,7 +53,7 @@ def group_fn(self, sizes): return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) @staticmethod - def is_cutlass_template(node: BaseSchedulerNode) -> bool: + def is_cutlass_template(node: BaseSchedulerNode) -> TypeGuard[SchedulerNode]: return isinstance(node, SchedulerNode) and isinstance( node.node, CUTLASSTemplateBuffer ) @@ -109,7 +109,7 @@ def define_kernel(self, src_code: str, node_schedule) -> str: _, _, kernel_path = get_path(code_hash(src_code), "py") compile_wrapper = IndentedBuffer() - compile_wrapper.writeline("async_compile.cuda(r'''") + compile_wrapper.writeline(f"async_compile.{V.graph.device_type}(r'''") compile_wrapper.splice(src_code, strip=True) compile_wrapper.writeline( f"''', 'so', aot_compile={str(V.graph.aot_mode)})" @@ -136,7 +136,6 @@ def codegen_template( assert self.is_cutlass_template(template_node), ( "Template node passed to CUTLASSScheduling.codegen_template must be a SchedulerNode that wraps a CUTLASSTemplateBuffer" ) - template_node = cast(SchedulerNode, template_node) _, (_numel, rnumel) = template_node.group assert rnumel == 1 ctb: CUTLASSTemplateBuffer = cast(CUTLASSTemplateBuffer, template_node.node) @@ -283,13 +282,13 @@ def _can_fuse_epilogue_impl( not_implemented_op = not_implemented_op[4:] why( f"Cannot fuse epilogue node {node_to_fuse} into {cutlass_template_buffer.name}, \ -likely due to unsupported operation: {not_implemented_op}" # noqa: G004, B950 +likely due to unsupported operation: {not_implemented_op}" ) return False else: # Likely due to unsupported dtype. why( f"Cannot fuse epilogue node {node_to_fuse} into {cutlass_template_buffer.name}. \ -Reason: {not_implemented_op}" # noqa: G004, B950 +Reason: {not_implemented_op}" ) return False diff --git a/torch/_inductor/codegen/cutlass/template.py b/torch/_inductor/codegen/cutlass/template.py index 10b40c688cd0c..ba0ee80727999 100644 --- a/torch/_inductor/codegen/cutlass/template.py +++ b/torch/_inductor/codegen/cutlass/template.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: - from ...scheduler import BaseSchedulerNode # noqa: TC004 + from ...scheduler import BaseSchedulerNode else: BaseSchedulerNode = Any @@ -56,7 +56,6 @@ def __init__( input_nodes: list[Buffer], layout: Layout, input_reorder: list[int] | None = None, - device_type: str = "cuda", ) -> None: """ Baseclass for CUTLASS C++ Templates, derived from KernelTemplate. @@ -74,7 +73,7 @@ def __init__( self.output_node: Buffer = Buffer(name="buf_out", layout=layout) self.input_reorder = input_reorder self.layout = layout - self.device_type = device_type + self.device_type = layout.device.type @classmethod @functools.lru_cache(None) @@ -83,7 +82,7 @@ def _template_from_string(cls, source: str) -> Any: return KernelTemplate._template_from_string(source) @staticmethod - def supports_epilogue_fusion(op: GemmOperation) -> bool: + def supports_epilogue_fusion(op: GemmOperation, device_type: str) -> bool: return False def make_key(self, name: str, input_key: str, layout_repr: str) -> str: @@ -133,6 +132,7 @@ def generate_code_and_args( kernel_name=kernel_name, runtime_arg_info=self.get_runtime_arg_info(), runtime_arg_values=self.get_runtime_arg_values(**kwargs), + device_type=self.device_type, ) with patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)): code = self.render(kernel=kernel, **kwargs) @@ -224,7 +224,9 @@ def generate( # type: ignore[override] supports_epilogue_fusion = False else: # epilogue fusion is only supported for TMA kernels - supports_epilogue_fusion = self.supports_epilogue_fusion(op) + supports_epilogue_fusion = self.supports_epilogue_fusion( + op, self.device_type + ) def make_kernel_render( template_node: CUTLASSTemplateBuffer, @@ -237,6 +239,7 @@ def make_kernel_render( kernel_name=str(Placeholder.KERNEL_NAME), runtime_arg_info=self.get_runtime_arg_info(), runtime_arg_values=self.get_runtime_arg_values(**kwargs), + device_type=self.device_type, ) render = functools.partial( self.render, diff --git a/torch/_inductor/codegen/cutlass/utils.py b/torch/_inductor/codegen/cutlass/utils.py index 95592f4a53959..2cc6d74ff4965 100644 --- a/torch/_inductor/codegen/cutlass/utils.py +++ b/torch/_inductor/codegen/cutlass/utils.py @@ -7,6 +7,7 @@ import sys import time from dataclasses import dataclass +from pathlib import Path from typing import Any from typing_extensions import TypeIs @@ -22,6 +23,7 @@ from ...runtime.runtime_utils import cache_dir from ...virtualized import V from ..cuda.cuda_env import get_cuda_arch, get_cuda_version +from ..xpu.xpu_env import get_xpu_arch, get_xpu_version log = logging.getLogger(__name__) @@ -77,12 +79,12 @@ def try_import_cutlass() -> bool: """ if config.is_fbcode(): try: - import cutlass_cppgen # type: ignore[import-not-found] # noqa: F401 + import cutlass_cppgen # type: ignore[import-not-found] import cutlass_library # type: ignore[import-not-found] except ImportError as e: - log.warning( # noqa: G200 + log.warning( "Failed to import CUTLASS packages in fbcode: %s, ignoring the CUTLASS backend.", - str(e), + e, ) return False @@ -99,7 +101,12 @@ def path_join(path0, path1): # contains both cutlass and cutlass_library # we need cutlass for eVT - cutlass_python_path = path_join(config.cutlass.cutlass_dir, "python") + cutlass_dir = ( + config.xpu.cutlass_dir + if torch.xpu._is_compiled() + else config.cutlass.cutlass_dir + ) + cutlass_python_path = path_join(cutlass_dir, "python") torch_root = os.path.abspath(os.path.dirname(torch.__file__)) mock_src_path = os.path.join( torch_root, @@ -157,17 +164,17 @@ def link_and_append(dst_link, src_path, parent_dir): ) try: - import cutlass_cppgen # type: ignore[import-not-found] # noqa: F401, F811 - import cutlass_library.generator # noqa: F401 - import cutlass_library.library # noqa: F401 + import cutlass_cppgen # type: ignore[import-not-found] # noqa: F401 + import cutlass_library.generator + import cutlass_library.library import cutlass_library.manifest # noqa: F401 import pycute # type: ignore[import-not-found] # noqa: F401 return True except ImportError as e: - log.debug( # noqa: G200 + log.debug( "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.", - str(e), + e, ) else: log.debug( @@ -177,7 +184,15 @@ def link_and_append(dst_link, src_path, parent_dir): return False -@functools.lru_cache(8) +def _normalize_xpu_arch(arch: str) -> str: + if arch.startswith("Xe"): + return arch[2:] + if 12 <= int(arch) and int(arch) <= 50: + return arch + else: + raise NotImplementedError(f"Unsupported xpu arch: {arch}") + + def _normalize_cuda_arch(arch: str) -> str: arch_num = arch if isinstance(arch, str): @@ -206,6 +221,24 @@ def _normalize_cuda_arch(arch: str) -> str: raise NotImplementedError(f"Unsupported cuda arch: {arch}") +@functools.lru_cache(8) +def cutlass_arch(device_type: str) -> str: + if device_type == "xpu": + arch = get_xpu_arch() + return _normalize_xpu_arch(arch) + else: + arch = get_cuda_arch() + return _normalize_cuda_arch(arch) + + +@functools.lru_cache(1) +def toolkit_version(device_type: str) -> str: + if device_type == "xpu": + return get_xpu_version() + else: + return get_cuda_version() + + @dataclass class CUTLASSArgs: """ @@ -213,7 +246,7 @@ class CUTLASSArgs: """ architectures: str | None = None - cuda_version: str | None = None + toolkit_version: str | None = None instantiation_level: str | None = None operations: str | None = None @@ -229,18 +262,18 @@ class CUTLASSArgs: interface_dir: None = None filter_by_cc = True disable_full_archs_compilation = False + device_type: str = "cuda" def __post_init__(self): - if self.architectures is None or self.cuda_version is None: + if self.architectures is None or self.toolkit_version is None: raise RuntimeError( - f"{self.architectures=} or {self.cuda_version=} is None!" + f"{self.architectures=} or {self.toolkit_version=} is None!" ) - self.architectures = _normalize_cuda_arch(self.architectures) @clear_on_fresh_cache @functools.cache -def _gen_ops_cached(arch, version) -> dict[Any, Any]: +def _gen_ops_cached(arch: str, version: str, device_type: str) -> dict[Any, Any]: # Note: Cache needs to be specific for cuda architecture and version # Import cutlass python scripts. @@ -250,35 +283,46 @@ def _gen_ops_cached(arch, version) -> dict[Any, Any]: if arch is None or version is None: log.error( - "Cannot detect cuda arch %s or cuda version %s. " + "Cannot detect cuda arch %s or version %s. " "Will discard all cutlass ops. " "Please consider setting _inductor.cuda.arch and _inductor.cuda.version configs.", arch, version, ) return {} - arch = _normalize_cuda_arch(arch) + gen_arch = ( "100" if arch == "103" else arch ) # CUTLASS SM103 generator only covers NVFB4; fallback to SM100 set instantiation_level: str = config.cutlass.cutlass_instantiation_level args = CUTLASSArgs( architectures=gen_arch, - cuda_version=version, + toolkit_version=version, instantiation_level=instantiation_level, operations=CUTLASS_OPERATION_KIND, + device_type=device_type, ) manifest = cutlass_manifest.Manifest(args) start_time = time.time() - if gen_arch == "100": + if device_type == "xpu": + if hasattr(cutlass_generator, "GenerateIntelXe"): + cutlass_generator.GenerateIntelXe( + manifest, args.toolkit_version, arch=int(arch) + ) + else: + raise NotImplementedError( + "Arch " + arch + " is not supported by current cutlass lib." + ) + + elif arch == "100": if hasattr(cutlass_generator, "GenerateSM100"): - cutlass_generator.GenerateSM100(manifest, args.cuda_version) - cutlass_generator.GenerateSM90(manifest, args.cuda_version) + cutlass_generator.GenerateSM100(manifest, args.toolkit_version) + cutlass_generator.GenerateSM90(manifest, args.toolkit_version) else: try: func = getattr(cutlass_generator, "GenerateSM" + gen_arch) - func(manifest, args.cuda_version) + func(manifest, args.toolkit_version) except AttributeError as e: raise NotImplementedError( "Arch " + gen_arch + " is not supported by current cutlass lib." @@ -292,26 +336,35 @@ def _gen_ops_cached(arch, version) -> dict[Any, Any]: return manifest.operations -def gen_ops() -> dict[Any, Any]: +def gen_ops(device_type: str) -> dict[Any, Any]: """ Generates all supported CUTLASS operations. """ with dynamo_timed("cutlass_utils.gen_ops"): - arch = get_cuda_arch() - version = get_cuda_version() - return _gen_ops_cached(arch, version) + arch = cutlass_arch(device_type) + version = toolkit_version(device_type) + return _gen_ops_cached(arch, version, device_type) from ..cpp_utils import DTYPE_TO_CPP -DTYPE_TO_CUTLASS_TYPE = { - **DTYPE_TO_CPP, - torch.float16: "__half", - torch.bfloat16: "__nv_bfloat16", - torch.float8_e4m3fn: "__nv_fp8_e4m3", - torch.float8_e5m2: "__nv_fp8_e5m2", -} +if torch.xpu._is_compiled(): + DTYPE_TO_CUTLASS_TYPE = { + **DTYPE_TO_CPP, + torch.float16: "uint16_t", + torch.bfloat16: "uint16_t", + torch.float8_e4m3fn: "uint8_t", + torch.float8_e5m2: "uint8_t", + } +else: + DTYPE_TO_CUTLASS_TYPE = { + **DTYPE_TO_CPP, + torch.float16: "__half", + torch.bfloat16: "__nv_bfloat16", + torch.float8_e4m3fn: "__nv_fp8_e4m3", + torch.float8_e5m2: "__nv_fp8_e5m2", + } @functools.lru_cache(32) @@ -328,6 +381,10 @@ def torch_dtype_to_cutlass_type( return cutlass_library.library.DataType.f16 elif torch_dtype == torch.bfloat16: return cutlass_library.library.DataType.bf16 + elif torch_dtype == torch.float8_e4m3fn: + return cutlass_library.library.DataType.e4m3 + elif torch_dtype == torch.float8_e5m2: + return cutlass_library.library.DataType.e5m2 else: raise NotImplementedError(f"Unsupported data type: {torch_dtype=}") @@ -478,3 +535,59 @@ def a_factor_of(x, alignment): ): return alignment return 1 + + +class CUTLASSCompileSourceCapturingContext: + # Helper class for Benchmarking and Testing CUTLASS Kernels in isolation. + # Can be used to capture the sourcecode passed to CUDACodeCache.compile + + def __init__(self, device_type: str): + self.sources = [] + self._compile_patch = None + self.device_type = device_type + + def __enter__(self, *args, **kwargs): + import unittest.mock as mock + + import torch._inductor.codecache + + codecache_cls = ( + torch._inductor.codecache.XPUCodeCache + if self.device_type == "xpu" + else torch._inductor.codecache.CUDACodeCache + ) + _compile_method_orig = codecache_cls.compile + + def my_compile(source_code, dst_file_ext, extra_args: list[str] | None = None): + self.sources.append(source_code) + return _compile_method_orig(source_code, dst_file_ext) + + # pyrefly: ignore [bad-assignment] + self._compile_patch = mock.patch( + f"torch._inductor.codecache.{codecache_cls.__name__}.compile", my_compile + ) + self._compile_patch.__enter__(*args, **kwargs) # type: ignore[union-attr] + return self + + def __exit__(self, *args, **kwargs): + self._compile_patch.__exit__(*args, **kwargs) # type: ignore[union-attr] + + +def cutlass_standalone_runner_compile_command( + device_type: str, srcpath: Path, exepath: Path +): + # returns command string to compile a (captured) CUDA GEMM Kernel source to a standalone executable that's ready to run + # Passes the correct preprocessor define to nvcc to ensure the standalone runner is enabled. + + extra_args = ["-DGENERATE_STANDALONE_RUNNER=1"] + if device_type != "xpu": + extra_args.append("-DCUTLASS_DEBUG_TRACE_LEVEL=1") + cutlass_compile_command = ( + torch._inductor.codegen.xpu.compile_utils.xpu_compile_command + if device_type == "xpu" + else torch._inductor.codegen.cuda.compile_utils.cuda_compile_command + ) + compile_command = cutlass_compile_command( + [str(srcpath)], str(exepath), "exe", extra_args=extra_args + ) + return compile_command diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 45018909f972e..f782dea55fd4e 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -289,20 +289,24 @@ def constant(cls, value, dtype): return cls.to_dtype(halide_constant(value), dtype) @staticmethod + # pyrefly: ignore [bad-override] def abs(x): return f"hl.abs({x})" @staticmethod + # pyrefly: ignore [bad-override] def exp(x): if not hasattr(x, "name"): return f"hl.exp({x})" return f"hl.fast_exp(hl.cast(hl.Float(32), {x})) if {x.name}.type().bits() <= 32 else hl.exp({x})" @staticmethod + # pyrefly: ignore [bad-override] def sqrt(x): return f"hl.sqrt({x})" @staticmethod + # pyrefly: ignore [bad-override] def minimum(a, b): # return f"hl.min({a}, {b})" <== handles nan wrong if not hasattr(a, "name"): @@ -311,6 +315,7 @@ def minimum(a, b): return f"hl.select(({a}<{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.min({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def maximum(a, b): # return f"hl.max({a}, {b})" <== handles nan wrong if not hasattr(a, "name"): @@ -319,80 +324,99 @@ def maximum(a, b): return f"hl.select(({a}>{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.max({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def where(a, b, c): if hasattr(b, "name"): c = f"hl.cast({b.name}.type(), {c})" return f"hl.select({a}, {b}, {c})" @staticmethod + # pyrefly: ignore [bad-override] def cos(x): return f"hl.cos({x})" @staticmethod + # pyrefly: ignore [bad-override] def sin(x): return f"hl.sin({x})" @staticmethod + # pyrefly: ignore [bad-override] def lgamma(x): raise Unsupported("lgamma") @staticmethod + # pyrefly: ignore [bad-override] def erf(x): return f"hl.erf({x})" @staticmethod + # pyrefly: ignore [bad-override] def cosh(x): return f"hl.cosh({x})" @staticmethod + # pyrefly: ignore [bad-override] def sinh(x): return f"hl.sinh({x})" @staticmethod + # pyrefly: ignore [bad-override] def acos(x): return f"hl.acos({x})" @staticmethod + # pyrefly: ignore [bad-override] def acosh(x): return f"hl.acosh({x})" @staticmethod + # pyrefly: ignore [bad-override] def asin(x): return f"hl.asin({x})" @staticmethod + # pyrefly: ignore [bad-override] def asinh(x): return f"hl.asinh({x})" @staticmethod + # pyrefly: ignore [bad-override] def atan2(x, y): return f"hl.atan2({x}, {y})" @staticmethod + # pyrefly: ignore [bad-override] def atan(x): return f"hl.atan({x})" @staticmethod + # pyrefly: ignore [bad-override] def atanh(x): return f"hl.atanh({x})" @staticmethod + # pyrefly: ignore [bad-override] def copysign(x, y): raise Unsupported("copysign") @staticmethod + # pyrefly: ignore [bad-override] def erfinv(x): raise Unsupported("erfinv") @staticmethod + # pyrefly: ignore [bad-override] def hypot(x, y): return f"hl.hypot({x}, {y})" @staticmethod + # pyrefly: ignore [bad-override] def nextafter(x, y): raise Unsupported("nextafter") @staticmethod + # pyrefly: ignore [bad-override] def logical_and(a, b): return f"{a} & {b}" @@ -401,10 +425,12 @@ def logical_not(a): return f"{a} == 0" @staticmethod + # pyrefly: ignore [bad-override] def logical_or(a, b): return f"{a} | {b}" @staticmethod + # pyrefly: ignore [bad-override] def logical_xor(a, b): return f"({a} ^ {b})" @@ -440,6 +466,10 @@ def rand(seed, offset): def randn(seed, offset): return f"halide_helpers.randn({seed}, {offset})" + @staticmethod + def rand_eager(seed, base_offset, threads_per_round, tid, vec): + return f"halide_helpers.rand_eager_kernel({seed}, {base_offset}, {threads_per_round}, {tid}, {vec})" + @staticmethod def randint64(seed, offset, low, high): return f"halide_helpers.randint64({seed}, {offset}, {low}, {high})" @@ -449,23 +479,28 @@ def load_seed(name, offset): return f"{ops.load(name, 0)} + {V.kernel.args.seed_offset('load_seed_offset', offset)}" @staticmethod + # pyrefly: ignore [bad-override] def rsqrt(x): # return f"hl.fast_inverse_sqrt({x})" <== accuracy issues return f"1./hl.sqrt({x})" @staticmethod + # pyrefly: ignore [bad-override] def tan(x): return f"hl.tan({x})" @staticmethod + # pyrefly: ignore [bad-override] def tanh(x): return f"hl.tanh({x})" @staticmethod + # pyrefly: ignore [bad-override] def signbit(x): return f"(hl.reinterpret(hl.UInt(32), hl.cast(hl.Float(32), {x})) >> 31) != 0" @staticmethod + # pyrefly: ignore [bad-override] def fmod(a, b): # TODO(jansel): find a better way to do this, builtin % has wrong sign return f"{a} - hl.trunc({a}/{b})*{b}" @@ -475,10 +510,12 @@ def pow(a, b): return f"hl.pow({a}, {b})" # hl.fast_pow fails accuracy @staticmethod + # pyrefly: ignore [bad-override] def ldexp(x, n): raise Unsupported("ldexp") @staticmethod + # pyrefly: ignore [bad-override] def log(x): return f"hl.log({x})" # hl.fast_log fails accuracy @@ -487,20 +524,24 @@ def log2(x): raise NotImplementedError("log2") @staticmethod + # pyrefly: ignore [bad-override] def isinf(x): # workaround https://github.com/halide/Halide/issues/8309 return f"hl.is_inf(hl.cast(hl.Float(32), {x}))" @staticmethod + # pyrefly: ignore [bad-override] def isnan(x): # workaround https://github.com/halide/Halide/issues/8309 return f"hl.is_nan(hl.cast(hl.Float(32), {x}))" @staticmethod + # pyrefly: ignore [bad-override] def round(x): return f"hl.round({x})" @staticmethod + # pyrefly: ignore [bad-override] def floor(x): return f"hl.floor({x})" @@ -523,10 +564,12 @@ def sign(cls, x): return f"hl.cast({x.name}.type(), {sub})" @staticmethod + # pyrefly: ignore [bad-override] def trunc(x): return f"hl.trunc({x})" @staticmethod + # pyrefly: ignore [bad-override] def truncdiv(a, b): # this causes crashes with floating point exception, see test_div_zero_dim_cpu # return f"hl.div_round_to_zero({a}, {b})" @@ -535,6 +578,7 @@ def truncdiv(a, b): ) @staticmethod + # pyrefly: ignore [bad-override] def ceil(x): return f"hl.ceil({x})" @@ -590,6 +634,7 @@ def masked(mask, body, other): return ops.where(new_mask, result, other) @staticmethod + # pyrefly: ignore [bad-override] def frexp(x): raise NotImplementedError("frexp") diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index 9d121330b2519..a6661976387b4 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -229,7 +229,8 @@ def masked(mask: CSEVariable, body: sympy.Expr, other: CSEVariable) -> str: # generates identical variable names. Without this reset, repeated calls to # body() would keep incrementing the counter, resulting in different cache key. V.kernel.cse.iter_buffer_ids = itertools.count() - V.kernel.cse.name_prefix = "tmp_scoped_" + # Append "_scoped" to the current prefix so each nesting level gets unique vars + V.kernel.cse.name_prefix += "_scoped" rc = body() # Compute cache key manually as variable name is needed to actually generate the code @@ -252,6 +253,7 @@ def masked(mask: CSEVariable, body: sympy.Expr, other: CSEVariable) -> str: return var @staticmethod + # pyrefly: ignore [bad-override] def where(a: OpVarT, b: OpVarT, c: OpVarT) -> str: return f"{a} ? {b} : static_cast({value_to_metal(c)})" @@ -260,50 +262,61 @@ def remainder(a: OpVarT, b: OpVarT) -> str: return f"c10::metal::remainder({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def maximum(a: CSEVariable, b: CSEVariable) -> str: typecast_a = f"static_cast({a})" typecast_b = f"static_cast({b})" return f"c10::metal::max({typecast_a}, {typecast_b})" @staticmethod + # pyrefly: ignore [bad-override] def minimum(a: CSEVariable, b: CSEVariable) -> str: typecast_a = f"static_cast({a})" typecast_b = f"static_cast({b})" return f"c10::metal::min({typecast_a}, {typecast_b})" @staticmethod + # pyrefly: ignore [bad-override] def logical_or(a: CSEVariable, b: CSEVariable) -> str: return f"{a} || {b}" @staticmethod + # pyrefly: ignore [bad-override] def logical_and(a: CSEVariable, b: CSEVariable) -> str: return f"{a} && {b}" @staticmethod + # pyrefly: ignore [bad-override] def isnan(x: CSEVariable) -> str: return f"metal::isnan({x})" @staticmethod + # pyrefly: ignore [bad-override] def isinf(x: CSEVariable) -> str: return f"metal::isinf({x})" @staticmethod + # pyrefly: ignore [bad-override] def log(x: CSEVariable) -> str: return f"metal::precise::log({x})" @staticmethod + # pyrefly: ignore [bad-override] def exp(x: CSEVariable) -> str: return f"metal::precise::exp({x})" @staticmethod + # pyrefly: ignore [bad-override] def abs(x: CSEVariable) -> str: return f"metal::abs({x})" @staticmethod + # pyrefly: ignore [bad-override] def signbit(x: CSEVariable) -> str: return f"metal::signbit({x})" @staticmethod + # pyrefly: ignore [bad-override] def sin(x: CSEVariable) -> str: return f"metal::precise::sin({x})" @@ -312,30 +325,37 @@ def sinc(x: CSEVariable) -> str: return f"c10::metal::sinc({x})" @staticmethod + # pyrefly: ignore [bad-override] def cos(x: CSEVariable) -> str: return f"metal::precise::cos({x})" @staticmethod + # pyrefly: ignore [bad-override] def tan(x: CSEVariable) -> str: return f"metal::precise::tan({x})" @staticmethod + # pyrefly: ignore [bad-override] def asin(x: CSEVariable) -> str: return f"metal::precise::asin({x})" @staticmethod + # pyrefly: ignore [bad-override] def acos(x: CSEVariable) -> str: return f"metal::precise::acos({x})" @staticmethod + # pyrefly: ignore [bad-override] def atan(x: CSEVariable) -> str: return f"metal::precise::atan({x})" @staticmethod + # pyrefly: ignore [bad-override] def atan2(x: CSEVariable, y: CSEVariable) -> str: return f"::metal::precise::atan2({x}, {y})" @staticmethod + # pyrefly: ignore [bad-override] def sqrt(x: CSEVariable) -> str: return f"metal::precise::sqrt({x})" @@ -346,14 +366,17 @@ def neg(x: CSEVariable) -> str: return f"static_cast(-{x})" @staticmethod + # pyrefly: ignore [bad-override] def rsqrt(x: CSEVariable) -> str: return f"metal::precise::rsqrt({x})" @staticmethod + # pyrefly: ignore [bad-override] def tanh(x: CSEVariable) -> str: return f"metal::precise::tanh({x})" @staticmethod + # pyrefly: ignore [bad-override] def atanh(x: CSEVariable) -> str: return f"metal::precise::atanh({x})" @@ -363,24 +386,29 @@ def floordiv(a: CSEVariable, b: CSEVariable) -> str: return f"c10::metal::floor_divide({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def floor(x: CSEVariable) -> str: return f"metal::floor({x})" @staticmethod + # pyrefly: ignore [bad-override] def sign(x: CSEVariable) -> str: return f"metal::sign({x})" @staticmethod + # pyrefly: ignore [bad-override] def fmod(a: CSEVariable, b: CSEVariable) -> str: typecast_a = f"static_cast({a})" typecast_b = f"static_cast({b})" return f"metal::fmod({typecast_a}, {typecast_b})" @staticmethod + # pyrefly: ignore [bad-override] def trunc(x: CSEVariable) -> str: return f"metal::trunc({x})" @staticmethod + # pyrefly: ignore [bad-override] def truncdiv(a: CSEVariable, b: CSEVariable) -> str: quot = f"{a} / {b}" if (a.dtype is not None and a.dtype.is_floating_point) or ( @@ -390,6 +418,7 @@ def truncdiv(a: CSEVariable, b: CSEVariable) -> str: return quot @staticmethod + # pyrefly: ignore [bad-override] def ceil(x: CSEVariable) -> str: return f"metal::ceil({x})" @@ -411,6 +440,7 @@ def randint64( return f"c10::metal::randint64({seed}, {offset}, {low}, {high})" @staticmethod + # pyrefly: ignore [bad-override] def round(x: CSEVariable) -> str: return f"metal::rint({x})" @@ -899,10 +929,11 @@ def codegen_body(self) -> None: self.compute.clear() self.stores.clear() - def codegen_kernel(self, name: str | None = None) -> str: + def codegen_kernel(self, name: str = "generated_kernel") -> str: """Called at the end to generate a final kernel string""" self.codegen_body() code = IndentedBuffer() + fn_name = name if V.graph.cpp_wrapper: code.writeline('(R"MTL(') @@ -939,7 +970,7 @@ def codegen_kernel(self, name: str | None = None) -> str: code.writeline( f"[[max_total_threads_per_threadgroup({threadgroup_size})]]" ) - code.writeline("kernel void generated_kernel(") + code.writeline(f"kernel void {fn_name}(") with code.indent(): for outer, inner in self.args.output_buffers.items(): if outer in self.removed_buffers: @@ -1097,6 +1128,20 @@ def format_threads(threads: list[str], kwarg: str) -> str: arg_types=arg_types, ) + def device_assert_async(self, cond: CSEVariable, msg: str) -> None: + if V.graph.cpp_wrapper: + self.cse.generate(self.compute, f"if (!{cond}) return", assignment=False) + else: + self.headers.add("error") + self.compute.writelines( + [ + f"if (!{cond}) {{", + f" TORCH_REPORT_ERROR(error_buf, {repr(msg)});", + " return;", + "}", + ] + ) + def check_bounds( self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool ) -> None: @@ -1136,36 +1181,55 @@ def check_bounds( class MetalScheduling(SIMDScheduling): kernel_type = MetalKernel # type: ignore[assignment] + _kernel_fn_counter: int = 0 def __init__(self, scheduler: Scheduler | None) -> None: super().__init__(scheduler) - wrapper = V.graph.wrapper_code - if wrapper is not None: - if not V.graph.cpp_wrapper: - wrapper.header.splice( - "from torch._inductor.runtime.runtime_utils import compile_mps_shader" - ) def define_kernel( self, src_code: str, node_schedule: list[SchedulerNode], kernel: MetalKernel ) -> str: wrapper = V.graph.wrapper_code if src_code in wrapper.src_to_kernel: - kernel_name = wrapper.src_to_kernel[src_code] - else: - # TODO: Merge multiple kernels into a single library - # Either using MultiKernel concept or overriding SIMDScheduling.codegen_node_scheduling - mps_lib_name = f"mps_lib_{wrapper.next_kernel_suffix()}" + return wrapper.src_to_kernel[src_code] - kernel_name = f"{mps_lib_name}" + if V.graph.cpp_wrapper: + # C++ path: one library per kernel (each has a single "generated_kernel" function) + mps_lib_name = f"mps_lib_{wrapper.next_kernel_suffix()}" + kernel_name = mps_lib_name wrapper.src_to_kernel[src_code] = kernel_name - - if V.graph.cpp_wrapper: - # For shimified version, generate source constant instead of direct instantiation - src_code = f"const char* {mps_lib_name}_source = " + src_code - + src_code = f"const char* {mps_lib_name}_source = " + src_code origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) - metadata_comment = f"{origins}\n{detailed_origins}" - wrapper.define_kernel(mps_lib_name, src_code, metadata_comment, gpu=False) + wrapper.define_kernel( + mps_lib_name, src_code, f"{origins}\n{detailed_origins}", gpu=False + ) + return kernel_name + + # Python path: register kernel with async_compile; wait() will compile all + # accumulated Metal kernels into a single library and replace each placeholder. + fn_name = f"generated_kernel_{self._kernel_fn_counter}" + self._kernel_fn_counter += 1 + wrapper.src_to_kernel[src_code] = fn_name + + # Extract Metal source from compile_mps_shader('''...''') call + metal_src_start = "compile_mps_shader('''" + start = src_code.index(metal_src_start) + len(metal_src_start) + end = src_code.rindex("''')") + metal_src = src_code[start:end] + + # Strip #include lines and rename the kernel function + body_lines = [] + for line in metal_src.split("\n"): + if line.strip().startswith("#include"): + continue + body_lines.append( + line.replace("kernel void generated_kernel(", f"kernel void {fn_name}(") + ) + + headers_repr = repr(sorted(kernel.headers)) + wrapper.header.writeline(f"{fn_name} = async_compile.metal({fn_name!r}, '''") + for line in body_lines: + wrapper.header.writeline(line) + wrapper.header.writeline(f"''', {headers_repr})") - return kernel_name + return fn_name diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index 627d4752ff26b..41f963de6fcf8 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -3,12 +3,13 @@ import dataclasses import hashlib import math +import re import typing_extensions from typing import Any, cast, TYPE_CHECKING -import sympy # noqa: TC002 +import sympy -import torch # noqa: TC001 +import torch from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import ModularIndexing @@ -24,7 +25,7 @@ OpOverrides, PythonPrinter, ) -from .simd import SIMDKernel, SIMDScheduling +from .simd import IterationRangesEntry, SIMDKernel, SIMDScheduling class PallasPrinter(PythonPrinter): @@ -71,17 +72,6 @@ def _print_Max(self, expr: sympy.Expr) -> str: # Main function suffix used in generated Pallas code MAIN_SUFFIX = "main" -# Mosaic GPU warpgroup size: 4 warps × 32 threads = 128 threads per warpgroup. -# This is a hardware constant for Hopper and Blackwell GPUs. -# See: jax/_src/pallas/mosaic_gpu/lowering.py -WARPGROUP_SIZE = 128 - - -def _align_to_warpgroup(size: int) -> int: - """Align size to WARPGROUP_SIZE (128) for Mosaic GPU compatibility.""" - return ((size + WARPGROUP_SIZE - 1) // WARPGROUP_SIZE) * WARPGROUP_SIZE - - # Logger for Pallas kernel code kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") @@ -122,42 +112,52 @@ class PallasKernelOverrides(OpOverrides): """ @staticmethod + # pyrefly: ignore [bad-override] def sin(x: str) -> str: return f"jnp.sin({x})" @staticmethod + # pyrefly: ignore [bad-override] def cos(x: str) -> str: return f"jnp.cos({x})" @staticmethod + # pyrefly: ignore [bad-override] def tan(x: str) -> str: return f"jnp.tan({x})" @staticmethod + # pyrefly: ignore [bad-override] def sinh(x: str) -> str: return f"jnp.sinh({x})" @staticmethod + # pyrefly: ignore [bad-override] def cosh(x: str) -> str: return f"jnp.cosh({x})" @staticmethod + # pyrefly: ignore [bad-override] def tanh(x: str) -> str: return f"jnp.tanh({x})" @staticmethod + # pyrefly: ignore [bad-override] def asin(x: str) -> str: return f"jnp.arcsin({x})" @staticmethod + # pyrefly: ignore [bad-override] def acos(x: str) -> str: return f"jnp.arccos({x})" @staticmethod + # pyrefly: ignore [bad-override] def atan(x: str) -> str: return f"jnp.arctan({x})" @staticmethod + # pyrefly: ignore [bad-override] def exp(x: str) -> str: return f"jnp.exp({x})" @@ -170,6 +170,7 @@ def expm1(x: str) -> str: return f"jnp.expm1({x})" @staticmethod + # pyrefly: ignore [bad-override] def log(x: str) -> str: return f"jnp.log({x})" @@ -186,14 +187,17 @@ def log1p(x: str) -> str: return f"jnp.log1p({x})" @staticmethod + # pyrefly: ignore [bad-override] def sqrt(x: str) -> str: return f"jnp.sqrt({x})" @staticmethod + # pyrefly: ignore [bad-override] def rsqrt(x: str) -> str: return f"jax.lax.rsqrt({x})" @staticmethod + # pyrefly: ignore [bad-override] def abs(x: str) -> str: return f"jnp.abs({x})" @@ -202,18 +206,22 @@ def neg(x: str) -> str: return f"(-{x})" @staticmethod + # pyrefly: ignore [bad-override] def floor(x: str) -> str: return f"jnp.floor({x})" @staticmethod + # pyrefly: ignore [bad-override] def ceil(x: str) -> str: return f"jnp.ceil({x})" @staticmethod + # pyrefly: ignore [bad-override] def trunc(x: str) -> str: return f"jnp.trunc({x})" @staticmethod + # pyrefly: ignore [bad-override] def round(x: str) -> str: return f"jnp.round({x})" @@ -230,14 +238,17 @@ def pow(a: str, b: str) -> str: return f"jnp.power({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def maximum(a: str, b: str) -> str: return f"jnp.maximum({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def minimum(a: str, b: str) -> str: return f"jnp.minimum({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def where(cond: str, a: str, b: str) -> str: return f"jnp.where({cond}, {a}, {b})" @@ -268,6 +279,9 @@ def to_dtype( src_dtype: torch.dtype | None = None, use_compute_types: bool = True, ) -> str: + # TPU doesn't support 64-bit types + if dtype == torch.int64 and V.graph.get_current_device_or_throw().type == "tpu": + dtype = torch.int32 jax_dtype = torch_dtype_to_jax(dtype) # Wrap in jnp.asarray to handle scalars from integer indexing return f"jnp.asarray({x}).astype({jax_dtype})" @@ -359,10 +373,12 @@ def gt(a: str, b: str) -> str: return f"({a} > {b})" @staticmethod + # pyrefly: ignore [bad-override] def isnan(x: str) -> str: return f"jnp.isnan({x})" @staticmethod + # pyrefly: ignore [bad-override] def isinf(x: str) -> str: return f"jnp.isinf({x})" @@ -376,10 +392,12 @@ def ge(a: str, b: str) -> str: # Logical operations @staticmethod + # pyrefly: ignore [bad-override] def logical_and(a: str, b: str) -> str: return f"jnp.logical_and({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def logical_or(a: str, b: str) -> str: return f"jnp.logical_or({a}, {b})" @@ -388,19 +406,23 @@ def logical_not(x: str) -> str: return f"jnp.logical_not({x})" @staticmethod + # pyrefly: ignore [bad-override] def logical_xor(a: str, b: str) -> str: return f"jnp.logical_xor({a}, {b})" # Math operations @staticmethod + # pyrefly: ignore [bad-override] def atan2(a: str, b: str) -> str: return f"jnp.arctan2({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def hypot(a: str, b: str) -> str: return f"jnp.hypot({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def fmod(a: str, b: str) -> str: return f"jnp.fmod({a}, {b})" @@ -409,6 +431,7 @@ def remainder(a: str, b: str) -> str: return f"jnp.remainder({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def truncdiv(a: str, b: str) -> str: # Truncated division (rounds toward zero) # For integers: sign(a)*sign(b) * (abs(a) // abs(b)) @@ -426,16 +449,19 @@ def clamp(x: str, min_val: str, max_val: str) -> str: # Sign operations @staticmethod + # pyrefly: ignore [bad-override] def sign(x: str) -> str: # PyTorch returns 0 for NaN, JAX returns NaN return f"jnp.where(jnp.isnan({x}), 0.0, jnp.sign({x}))" @staticmethod + # pyrefly: ignore [bad-override] def signbit(x: str) -> str: return f"jnp.signbit({x})" # Special math functions @staticmethod + # pyrefly: ignore [bad-override] def erf(x: str) -> str: return f"jax.scipy.special.erf({x})" @@ -444,10 +470,12 @@ def erfc(x: str) -> str: return f"jax.scipy.special.erfc({x})" @staticmethod + # pyrefly: ignore [bad-override] def erfinv(x: str) -> str: return f"jax.scipy.special.erfinv({x})" @staticmethod + # pyrefly: ignore [bad-override] def lgamma(x: str) -> str: return f"jax.scipy.special.gammaln({x})" @@ -553,6 +581,7 @@ def xlog1py(x: str, y: str) -> str: return f"jax.scipy.special.xlog1py({x}, {y})" @staticmethod + # pyrefly: ignore [bad-override] def chebyshev_polynomial_t(x: str, n: str) -> str: # Chebyshev polynomial of the first kind T_n(x) # For |x| <= 1: T_n(x) = cos(n * arccos(x)) @@ -567,6 +596,7 @@ def chebyshev_polynomial_t(x: str, n: str) -> str: ) @staticmethod + # pyrefly: ignore [bad-override] def chebyshev_polynomial_u(x: str, n: str) -> str: # Chebyshev polynomial of the second kind U_n(x) # For |x| < 1: U_n(x) = sin((n+1) * arccos(x)) / sqrt(1 - x^2) @@ -588,6 +618,7 @@ def chebyshev_polynomial_u(x: str, n: str) -> str: ) @staticmethod + # pyrefly: ignore [bad-override] def chebyshev_polynomial_v(x: str, n: str) -> str: # Chebyshev polynomial of the third kind V_n(x) # V_n(x) = (T_n(x) - T_{n+1}(x)) / (1 - x) for x != 1 @@ -604,6 +635,7 @@ def chebyshev_polynomial_v(x: str, n: str) -> str: ) @staticmethod + # pyrefly: ignore [bad-override] def chebyshev_polynomial_w(x: str, n: str) -> str: # Chebyshev polynomial of the fourth kind W_n(x) # W_n(x) = (T_n(x) + T_{n+1}(x)) / (1 + x) for x != -1 @@ -620,22 +652,27 @@ def chebyshev_polynomial_w(x: str, n: str) -> str: ) @staticmethod + # pyrefly: ignore [bad-override] def shifted_chebyshev_polynomial_t(x: str, n: str) -> str: return PallasKernelOverrides.chebyshev_polynomial_t(f"(2 * {x} - 1)", n) @staticmethod + # pyrefly: ignore [bad-override] def shifted_chebyshev_polynomial_u(x: str, n: str) -> str: return PallasKernelOverrides.chebyshev_polynomial_u(f"(2 * {x} - 1)", n) @staticmethod + # pyrefly: ignore [bad-override] def shifted_chebyshev_polynomial_v(x: str, n: str) -> str: return PallasKernelOverrides.chebyshev_polynomial_v(f"(2 * {x} - 1)", n) @staticmethod + # pyrefly: ignore [bad-override] def shifted_chebyshev_polynomial_w(x: str, n: str) -> str: return PallasKernelOverrides.chebyshev_polynomial_w(f"(2 * {x} - 1)", n) @staticmethod + # pyrefly: ignore [bad-override] def hermite_polynomial_h(x: str, n: str) -> str: # Physicist's Hermite polynomial H_n(x) # H_n(x) = 2^n * x^n - n*(n-1)/2 * 2^(n-2) * x^(n-2) + ... @@ -654,6 +691,7 @@ def hermite_polynomial_h(x: str, n: str) -> str: ) @staticmethod + # pyrefly: ignore [bad-override] def hermite_polynomial_he(x: str, n: str) -> str: # Probabilist's Hermite polynomial He_n(x) # He_0 = 1, He_1 = x, He_2 = x^2 - 1, He_3 = x^3 - 3x @@ -668,6 +706,7 @@ def hermite_polynomial_he(x: str, n: str) -> str: ) @staticmethod + # pyrefly: ignore [bad-override] def laguerre_polynomial_l(x: str, n: str) -> str: # Laguerre polynomial L_n(x) # L_0 = 1, L_1 = 1 - x, L_2 = (x^2 - 4x + 2)/2, L_3 = (-x^3 + 9x^2 - 18x + 6)/6 @@ -682,6 +721,7 @@ def laguerre_polynomial_l(x: str, n: str) -> str: ) @staticmethod + # pyrefly: ignore [bad-override] def legendre_polynomial_p(x: str, n: str) -> str: # Legendre polynomial P_n(x) # P_0 = 1, P_1 = x, P_2 = (3x^2 - 1)/2, P_3 = (5x^3 - 3x)/2 @@ -715,18 +755,22 @@ def fma(a: str, b: str, c: str) -> str: return f"(({a}) * ({b}) + ({c}))" @staticmethod + # pyrefly: ignore [bad-override] def copysign(a: str, b: str) -> str: return f"jnp.copysign({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def nextafter(a: str, b: str) -> str: return f"jnp.nextafter({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def ldexp(a: str, b: str) -> str: return f"jnp.ldexp({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def frexp(x: str) -> str: return f"jnp.frexp({x})" @@ -809,6 +853,37 @@ def randint64(seed: str, offset: str, low: str, high: str) -> str: ) +@dataclasses.dataclass +class _IndirectAccessInfo: + """Describes a detected indirect (data-dependent) buffer access.""" + + table_param: str + table_buf_name: str + table_shape: tuple + indirect_dim: int + indirect_var: str + indices_param: str + + +@dataclasses.dataclass +class _BufferIndexing: + """Encapsulates index string and flattening requirements for buffer access.""" + + index_str: str + needs_flatten: bool + + +@dataclasses.dataclass +class _BroadcastedIterVar: + """Encapsulates information needed to codegen a broadcasted iteration var""" + + # index of this var in `self.range_tree_nodes.items()`` + idx: int + var_sym: sympy.Symbol + entry: IterationRangesEntry + length_val: int | None + + @dataclasses.dataclass class _CodegenContext: """Bundles local state shared across codegen_kernel helper methods.""" @@ -834,19 +909,10 @@ class _CodegenContext: class PallasKernel(SIMDKernel): - """ - Pallas kernel for elementwise operations with support for strided/scatter access. - - Strategy: - - Convert index expressions to JAX-compatible array slicing - - Load/store using indexed access: "in_ptrX[slice]" or full-array "in_ptrX[...]" - - Compute expression with Python operators (compatible with jax.numpy broadcasting) - - Generate Python code that defines a Pallas kernel and a host entrypoint. - - Use async_compile.pallas path to compile and load Python code. - - For GPU (Mosaic backend): - - Use TMA (Tensor Memory Accelerator) for automatic OOB masking - - Falls back to legacy padding approach for reductions, broadcasting, non-contiguous tensors + """Pallas kernel codegen for TPU and GPU (Mosaic backend). + + Generates Python code that defines a Pallas kernel and a host entrypoint, + compiled and loaded via async_compile.pallas. """ overrides = PallasKernelOverrides # type: ignore[assignment] @@ -894,6 +960,11 @@ def __init__(self, *args, **kwargs): # Buffers that already use flatten+gather indexing; strided # decomposition must not reshape these (it would break flat offsets). self.flatten_indexed_buffers: OrderedSet[str] = OrderedSet() + # Indirect (data-dependent) access info for scalar prefetch + self.indirect_access: _IndirectAccessInfo | None = None + self._cse_to_param: dict[str, str] = {} + self._param_to_graph_name: dict[str, str] = {} + self.is_tpu = device.type == "tpu" def check_bounds( self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool @@ -1280,37 +1351,6 @@ def _codegen_strided_reshapes( slice_parts.append(":") code.writeline(f"{param} = {param}[{', '.join(slice_parts)}]") - @staticmethod - def _c_contiguous_strides(shape: list[int]) -> list[int]: - """Return C-contiguous strides for the given shape.""" - n = len(shape) - strides = [1] * n - for i in range(n - 2, -1, -1): - strides[i] = strides[i + 1] * shape[i + 1] - return strides - - @staticmethod - def _map_coeffs_to_dims(coeffs: list[int], strides: list[int]) -> list[int] | None: - """Map coefficient values to dimension indices via stride matching. - - Returns a list where entry k is the dimension whose stride equals - coeffs[k], or None if the mapping is ambiguous or incomplete. - """ - stride_to_dim: dict[int, int] = {} - for d, s in enumerate(strides): - if s in stride_to_dim: - return None # duplicate strides - stride_to_dim[s] = d - mapping: list[int] = [] - for c in coeffs: - d = stride_to_dim.get(c) - if d is None: - return None - mapping.append(d) - if len(OrderedSet(mapping)) != len(coeffs): - return None - return mapping - @staticmethod def _get_actual_out_strides(out_buf, n: int) -> list[int] | None: """Extract actual output buffer strides from its layout.""" @@ -1607,15 +1647,17 @@ def _group_dims_to_ranges(dims: list[int], ranges: list[int]) -> list[int] | Non groups.reverse() return groups - def _get_index_expr(self, index: sympy.Expr) -> tuple[str, bool]: + def _get_index_expr(self, index: sympy.Expr) -> _BufferIndexing: """Get the index expression string and whether it needs flattening.""" has_indirect = self._has_indirect_vars(index) has_iter_vars = self._has_iteration_vars(index) if has_indirect and has_iter_vars: - return self._handle_mixed_indexing(index), True + return _BufferIndexing( + index_str=self._handle_mixed_indexing(index), needs_flatten=True + ) elif has_indirect: - return self.kexpr(index), False + return _BufferIndexing(index_str=self.kexpr(index), needs_flatten=False) else: index_str = self._get_index_str(index) # Check if index contains ModularIndexing - this requires flattened access @@ -1627,7 +1669,7 @@ def _get_index_expr(self, index: sympy.Expr) -> tuple[str, bool]: # Check if it's a simple slice pattern (::N or M::N) if not ("::" in index_str or index_str.lstrip("-").isdigit()): needs_flatten = True - return index_str, needs_flatten + return _BufferIndexing(index_str=index_str, needs_flatten=needs_flatten) @staticmethod def _safe_int(val: Any) -> int | None: @@ -1637,6 +1679,37 @@ def _safe_int(val: Any) -> int | None: except (TypeError, ValueError): return None + @staticmethod + def _c_contiguous_strides(shape: list[int]) -> list[int]: + """Return C-contiguous strides for the given shape.""" + n = len(shape) + strides = [1] * n + for i in range(n - 2, -1, -1): + strides[i] = strides[i + 1] * shape[i + 1] + return strides + + @staticmethod + def _map_coeffs_to_dims(coeffs: list[int], strides: list[int]) -> list[int] | None: + """Map coefficient values to dimension indices via stride matching. + + Returns a list where entry k is the dimension whose stride equals + coeffs[k], or None if the mapping is ambiguous or incomplete. + """ + stride_to_dim: dict[int, int] = {} + for d, s in enumerate(strides): + if s in stride_to_dim: + return None # duplicate strides + stride_to_dim[s] = d + mapping: list[int] = [] + for c in coeffs: + d = stride_to_dim.get(c) + if d is None: + return None + mapping.append(d) + if len(OrderedSet(mapping)) != len(coeffs): + return None + return mapping + def _zero_dim_output_flags(self, ctx: _CodegenContext) -> tuple[bool, bool]: """Return whether an output has a zero or unknown dimension.""" has_unknown_dim = False @@ -1965,24 +2038,16 @@ def _needs_strided_indexing( self, name: str, index: sympy.Expr, - index_str: str, - needs_flatten: bool, - ) -> tuple[str, bool]: - """ - Check if buffer access needs strided indexing due to size mismatch or gather patterns. - - This handles cases like: - - Pooling operations where input/output have different sizes - - im2col-like gather patterns - - Transposed or strided buffer access - """ + indexing: _BufferIndexing, + ) -> _BufferIndexing: + """Check if buffer access needs strided indexing due to size mismatch or gather patterns.""" # Only applies when full array access is indicated - if index_str != "..." or needs_flatten: - return index_str, needs_flatten + if indexing.index_str != "..." or indexing.needs_flatten: + return indexing info = self._get_buffer_info(name) if info is None: - return index_str, needs_flatten + return indexing buf_obj, buf_size, buf_numel, actual_strides, is_contiguous = info output_numel, used_vars = self._compute_output_numel_from_index(index) @@ -2021,77 +2086,85 @@ def _needs_strided_indexing( and not skip_for_non_contiguous and not has_symbolic_coef ): - return self._generate_strided_index(index), True + return _BufferIndexing( + index_str=self._generate_strided_index(index), needs_flatten=True + ) - return index_str, needs_flatten + return indexing def _adjust_index_for_buffer_shape( self, name: str, index: sympy.Expr, - index_str: str, - needs_flatten: bool, - ) -> tuple[str, bool]: + indexing: _BufferIndexing, + ) -> _BufferIndexing: """ Adjust index expression based on buffer shape (0-dim scalar, multi-dim, etc.). """ - if needs_flatten or index_str == "...": - return index_str, needs_flatten + if indexing.needs_flatten or indexing.index_str == "...": + return indexing buf_obj = V.graph.get_buffer(name) if buf_obj is None: - return index_str, needs_flatten + return indexing buf_size = buf_obj.get_size() # 0-dimensional (scalar) buffer - use [...] to access it if len(buf_size) == 0: - return "...", needs_flatten + return _BufferIndexing( + index_str="...", needs_flatten=indexing.needs_flatten + ) # Multi-dimensional buffer with constant/scalar index if len(buf_size) > 1: has_iter_vars = self._has_iteration_vars(index) if not has_iter_vars: - return index_str, True # Use flattened access - elif "::" in index_str: + return _BufferIndexing( + index_str=indexing.index_str, needs_flatten=True + ) # Use flattened access + elif "::" in indexing.index_str: # Strided slice patterns need flattened indexing for multi-dim - return self._generate_strided_index(index), True + return _BufferIndexing( + index_str=self._generate_strided_index(index), needs_flatten=True + ) # GPU doesn't support gather from slice patterns on 1D buffers - if self.is_gpu and "::" in index_str: - return self._generate_strided_index(index), True + if self.is_gpu and "::" in indexing.index_str: + return _BufferIndexing( + index_str=self._generate_strided_index(index), needs_flatten=True + ) - return index_str, needs_flatten + return indexing def _try_multidim_slice( self, name: str, index: sympy.Expr, - index_str: str, - needs_flatten: bool, - ) -> tuple[str, bool]: + indexing: _BufferIndexing, + ) -> _BufferIndexing: """ Try to emit multi-dim slice notation instead of flatten + gather. For a buffer with shape (d0, ..., dk) and index `stride * var + offset`, emit `buf[:, ..., :, offset::stride]` when stride divides dk. """ - if not needs_flatten: - return index_str, needs_flatten + if not indexing.needs_flatten: + return indexing buf_obj = V.graph.get_buffer(name) if buf_obj is None: - return index_str, needs_flatten + return indexing buf_size = buf_obj.get_size() ndim = len(buf_size) if ndim < 2: - return index_str, needs_flatten + return indexing # Need a single iteration variable with an affine index used_vars = self._get_used_iter_vars(index) if len(used_vars) != 1: - return index_str, needs_flatten + return indexing var = next(iter(used_vars)) var_expr = BlockPatternMatcher.get_subexpr_involving_symbol(index, var) @@ -2099,20 +2172,20 @@ def _try_multidim_slice( BlockPatternMatcher.match_affine_block_expr(var_expr, var) ) if stride is None or stride <= 1: - return index_str, needs_flatten + return indexing offset = V.graph.sizevars.simplify(index - var_expr) try: offset_val = int(offset) except (TypeError, ValueError): - return index_str, needs_flatten + return indexing if offset_val < 0 or offset_val >= stride: - return index_str, needs_flatten + return indexing last_dim = self._safe_int(buf_size[-1]) if last_dim is None or last_dim % stride != 0: - return index_str, needs_flatten + return indexing # Verify the iteration variable covers all buffer elements at the # given stride: var_length * stride == buf_numel. This ensures @@ -2120,23 +2193,23 @@ def _try_multidim_slice( # to buf[:, ..., :, offset::stride]. entry = self.range_tree_nodes.get(var) if entry is None: - return index_str, needs_flatten + return indexing var_length = self._safe_int(entry.length) buf_numel = 1 for s in buf_size: d = self._safe_int(s) if d is None: - return index_str, needs_flatten + return indexing buf_numel *= d if var_length is None or var_length * stride != buf_numel: - return index_str, needs_flatten + return indexing prefix = ":, " * (ndim - 1) if offset_val == 0: slice_str = f"{prefix}::{stride}" else: slice_str = f"{prefix}{offset_val}::{stride}" - return slice_str, False + return _BufferIndexing(index_str=slice_str, needs_flatten=False) @staticmethod def _gather_permute_expr(load_expr: str, perm: tuple[int, ...]) -> str: @@ -2148,29 +2221,181 @@ def _gather_permute_expr(load_expr: str, perm: tuple[int, ...]) -> str: """ return f"pallas_permute({load_expr}, {perm})" + def _trace_to_load_source(self, var_name: str) -> str | None: + """Trace a tmp variable back to its source buffer's kernel param. + + Follows CSE assignments backward through bounds-checking (where/clamp) + until it finds a variable that was directly loaded from a buffer. + """ + if var_name in self._cse_to_param: + return self._cse_to_param[var_name] + for line in self.compute._lines: + line_str = str(line).lstrip() + if not line_str.startswith(f"{var_name} = "): + continue + for ref in re.findall(r"\btmp\d+\b", line_str.split(" = ", 1)[1]): + result = self._trace_to_load_source(ref) + if result is not None: + return result + return None + + def _detect_indirect_access( + self, buf: str, name: str, index: sympy.Expr + ) -> _IndirectAccessInfo | None: + """Detect a load with data-dependent indexing suitable for scalar prefetch. + + Matches exactly one indirect variable whose coefficient corresponds to + a C-contiguous stride dimension. Rejects 1-to-1 gather patterns where + the indices buffer covers the full iteration space. + """ + buf_info = self._get_buffer_info(name) + if buf_info is None: + return None + _, buf_size, _, _, _ = buf_info + buf_size_raw = [self._safe_int(s) for s in buf_size] + if len(buf_size_raw) < 2 or any(s is None for s in buf_size_raw): + return None + buf_size_ints: list[int] = cast(list[int], buf_size_raw) + + indirect_vars = self._get_indirect_vars(index) + if len(indirect_vars) != 1: + return None + indirect_var = indirect_vars[0] + + coeff = self._get_index_coefficient(index, indirect_var) + if coeff == 0 or not isinstance(coeff, int): + return None + + # Use existing stride mapping to find which dimension is indirected + strides = self._c_contiguous_strides(buf_size_ints) + mapping = self._map_coeffs_to_dims([coeff], strides) + if mapping is None: + return None + indirect_dim = mapping[0] + + ndim = len(buf_size_ints) + if indirect_dim >= max(1, ndim - 2): + return None + + indirect_var_name = str(indirect_var) + indices_param = self._trace_to_load_source(indirect_var_name) + if indices_param is None: + return None + + # Reject gather patterns: only 1-D static index tensors supported + indices_graph_name = self._param_to_graph_name.get(indices_param) + if indices_graph_name is not None: + indices_info = self._get_buffer_info(indices_graph_name) + if indices_info is not None: + _, indices_size, _, _, _ = indices_info + if len(indices_size) != 1: + return None + if self._safe_int(indices_size[0]) is None: + return None + indices_numel = math.prod( + v for s in indices_size if (v := self._safe_int(s)) is not None + ) + iter_product = math.prod( + length + for var in self._get_used_iter_vars(index) + if var in self.range_tree_nodes + if (length := self._safe_int(self.range_tree_nodes[var].length)) + is not None + ) + if indices_numel >= iter_product: + return None + + return _IndirectAccessInfo( + table_param=buf, + table_buf_name=name, + table_shape=tuple(buf_size_ints), + indirect_dim=indirect_dim, + indirect_var=indirect_var_name, + indices_param=indices_param, + ) + + def _eliminate_dead_indirect_code(self) -> None: + """Remove dead compute lines after scalar prefetch replaces indirect load. + + When the table load is simplified to buf[0] (scalar prefetch handles + indexing), the indices load and all derived bounds-checking code become + dead. This performs backward liveness analysis from the store variables + to identify and remove dead lines. + """ + # Collect variables used by stores (live roots) + live_vars: OrderedSet[str] = OrderedSet() + for _, store_line in self.store_with_output: + for m in re.finditer(r"\btmp\d+\b", store_line): + live_vars.add(m.group()) + + # Parse assignments from compute lines + assignments: list[tuple[str | None, str, Any]] = [] + for line in self.compute._lines: + line_str = str(line).lstrip() + m = re.match(r"^(tmp\d+)\s*=\s*(.*)", line_str, re.DOTALL) + if m: + assignments.append((m.group(1), m.group(2), line)) + else: + assignments.append((None, line_str, line)) + + # Propagate liveness backward + changed = True + while changed: + changed = False + for var_name, rhs, _ in reversed(assignments): + if var_name and var_name in live_vars: + for m in re.finditer(r"\btmp\d+\b", rhs): + if m.group() not in live_vars: + live_vars.add(m.group()) + changed = True + + # Keep only live assignments (and non-assignment lines) + self.compute._lines = [ + line + for var_name, _, line in assignments + if var_name is None or var_name in live_vars + ] + def _build_load_expr( self, buf: str, name: str, index: sympy.Expr, - index_str: str, - needs_flatten: bool, + indexing: _BufferIndexing, ) -> str: """ Build the load expression based on indexing mode. """ - if needs_flatten: + if indexing.needs_flatten: + # Detect indirect (data-dependent) access for scalar prefetch + indirect = self._detect_indirect_access(buf, name, index) + if indirect is not None: + if self.indirect_access is not None: + # Fused nodes may re-visit the same indirect load (e.g. + # a reduction + pointwise over the same embedding). + # Allow that, but reject truly different indirect accesses. + assert indirect == self.indirect_access, ( + "only one indirect access per kernel supported" + ) + self.indirect_access = indirect + return f"{buf}[0]" + self.has_flatten_indexing = True self.flatten_indexed_buffers.add(name) # Flatten then index for non-contiguous access (gather operation) has_minmax = index.has(sympy.Min) or index.has(sympy.Max) - idx = f"({index_str}).astype(jnp.int64)" if has_minmax else index_str + idx_dtype = "jnp.int32" if self.is_tpu else "jnp.int64" + idx = ( + f"({indexing.index_str}).astype({idx_dtype})" + if has_minmax + else indexing.index_str + ) return f"{buf}[...].flatten()[{idx}]" else: # Direct indexing for contiguous access - load_expr = f"{buf}[{index_str}]" + load_expr = f"{buf}[{indexing.index_str}]" - if index_str == "..." and not self.is_gpu: + if indexing.index_str == "..." and not self.is_gpu: perm = self._get_full_load_permutation(name, index) if perm is not None: load_expr = self._gather_permute_expr(load_expr, perm) @@ -2307,16 +2532,16 @@ def _maybe_broadcast_1d_buffer( return f"{load_expr}.reshape({', '.join(map(str, reshape_dims))})" def _check_im2col_pattern( - self, index: sympy.Expr, index_str: str, needs_flatten: bool - ) -> tuple[str, bool]: + self, index: sympy.Expr, indexing: _BufferIndexing + ) -> _BufferIndexing: """ Check for im2col-like patterns where store uses block variables but load doesn't. For cat/expand patterns, both load and store prepared indices share block vars. For im2col patterns, store compresses to block vars but load doesn't. """ - if index_str != "..." or needs_flatten: - return index_str, needs_flatten + if indexing.index_str != "..." or indexing.needs_flatten: + return indexing prepared_index = self.prepare_indexing(index) iter_vars = self._get_iter_vars() @@ -2330,7 +2555,7 @@ def _check_im2col_pattern( # Only trigger if store introduces new block vars if not new_vars or len(store_orig_vars) <= 1: - return index_str, needs_flatten + return indexing # Check if loads are compatible with broadcast or cat pattern has_im2col_pattern = False @@ -2363,9 +2588,12 @@ def _check_im2col_pattern( break if has_im2col_pattern: - return self._generate_strided_index(prepared_index), True + return _BufferIndexing( + index_str=self._generate_strided_index(prepared_index), + needs_flatten=True, + ) - return index_str, needs_flatten + return indexing def _check_load_is_strided_input( self, buf_name: str, load_index: sympy.Expr, load_orig_vars: OrderedSet @@ -2490,8 +2718,7 @@ def _build_store_expr( name: str, index: sympy.Expr, value: CSEVariable, - index_str: str, - needs_flatten: bool, + indexing: _BufferIndexing, mode: Any = None, ) -> list[str]: """ @@ -2499,17 +2726,17 @@ def _build_store_expr( mode can be None (set) or "atomic_add" (accumulate). Returns a list of lines to emit. """ - if index_str == "...": + if indexing.index_str == "...": # Full array store with shape matching needs_transpose = self._check_store_needs_transpose(name) return self._build_full_array_store_expr(out, value, needs_transpose) - if needs_flatten: + if indexing.needs_flatten: self.has_flatten_indexing = True # Block variable indexing (e.g., im2col) - use flattened scatter scatter_op = "add" if mode == "atomic_add" else "set" return [ - f"{out}[...] = {out}[...].flatten().at[({index_str}).flatten()].{scatter_op}(" + f"{out}[...] = {out}[...].flatten().at[({indexing.index_str}).flatten()].{scatter_op}(" f"jnp.asarray({value}).flatten()).reshape({out}.shape)" ] @@ -2527,22 +2754,20 @@ def _build_store_expr( # Indirect indexed store (scatter): use .add() for atomic_add, .set() otherwise scatter_op = "add" if mode == "atomic_add" else "set" lines = [f"_val = jnp.asarray({value})"] - value_expr = ( - f"(jnp.full({index_str}.shape, _val) if _val.ndim == 0 else {value})" - ) + value_expr = f"(jnp.full({indexing.index_str}.shape, _val) if _val.ndim == 0 else {value})" if mode == "atomic_add": # For atomic_add, mark output as needing to be readable (for aliasing) self.outputs_need_read.add(out) alias_param = f"{out}_alias" lines.append( - f"{out}[...] = {alias_param}[...].flatten().at[({index_str}).flatten()].{scatter_op}(" + f"{out}[...] = {alias_param}[...].flatten().at[({indexing.index_str}).flatten()].{scatter_op}(" f"{value_expr}.flatten()).reshape({out}.shape)" ) else: - lines.append(f"{out}[{index_str}] = {value_expr}") + lines.append(f"{out}[{indexing.index_str}] = {value_expr}") return lines - return [f"{out}[{index_str}] = {value}"] + return [f"{out}[{indexing.index_str}] = {value}"] def _build_scatter_store_expr( self, @@ -2633,12 +2858,10 @@ def load(self, name: str, index: sympy.Expr) -> CSEVariable: self.load_index_exprs[name] = index # Get base index expression - index_str, needs_flatten = self._get_index_expr(index) + indexing = self._get_index_expr(index) # Check for buffer size mismatch requiring strided indexing - index_str, needs_flatten = self._needs_strided_indexing( - name, index, index_str, needs_flatten - ) + indexing = self._needs_strided_indexing(name, index, indexing) # Try strided decomposition before multidim slice or flatten. # This generates reshape + static indexing which works on both @@ -2649,31 +2872,30 @@ def load(self, name: str, index: sympy.Expr) -> CSEVariable: load_expr = self._strided_load_expr(buf, decomp) else: # Adjust index for buffer shape (scalar, multi-dim, etc.) - index_str, needs_flatten = self._adjust_index_for_buffer_shape( - name, index, index_str, needs_flatten - ) + indexing = self._adjust_index_for_buffer_shape(name, index, indexing) # Try to emit multi-dim slice instead of flatten + gather - index_str, needs_flatten = self._try_multidim_slice( - name, index, index_str, needs_flatten - ) + indexing = self._try_multidim_slice(name, index, indexing) # Build the load expression - load_expr = self._build_load_expr( - buf, name, index, index_str, needs_flatten - ) + load_expr = self._build_load_expr(buf, name, index, indexing) # Handle intermediate buffer squeezing for correct broadcasting - if not needs_flatten and index_str == "...": + if not indexing.needs_flatten and indexing.index_str == "...": load_expr = self._maybe_squeeze_intermediate_buffer(name, load_expr) # Handle 1D buffer broadcasting for higher-dimensional kernels load_expr = self._maybe_broadcast_1d_buffer(name, index, load_expr) - return self.cse.generate( + cse_var = self.cse.generate( self.compute, load_expr, dtype=dtype, ) + # Track CSE var -> param -> graph name for indirect access detection + buf_param = self.args.input(name) + self._cse_to_param[str(cse_var)] = buf_param + self._param_to_graph_name[buf_param] = name + return cse_var def _handle_mixed_indexing(self, index: sympy.Expr) -> str: """ @@ -2910,16 +3132,14 @@ def store( ] else: # Get base index expression - index_str, needs_flatten = self._get_index_expr(index) + indexing = self._get_index_expr(index) # Check for im2col-like patterns - index_str, needs_flatten = self._check_im2col_pattern( - index, index_str, needs_flatten - ) + indexing = self._check_im2col_pattern(index, indexing) # Build the store expression store_lines = self._build_store_expr( - out, name, index, value, index_str, needs_flatten, mode + out, name, index, value, indexing, mode ) for line in store_lines: @@ -3479,6 +3699,21 @@ def codegen_kernel(self, name: str | None = None) -> str: # type: ignore[overri ctx.alias_pairs = self._compute_alias_pairs(ctx, aliasable_flags) + use_scalar_prefetch = bool(self.indirect_access) + + if use_scalar_prefetch: + self._eliminate_dead_indirect_code() + kernel_body_sp = IndentedBuffer() + with kernel_body_sp.indent(): + for line in self.compute._lines: + kernel_body_sp.writeline(str(line)) + self._codegen_scalar_prefetch_wrapper( + ctx, + kernel_name, + kernel_body_sp, + ) + return code.getvalue() + # Emit the kernel function with the correct signature kernel_signature = f"def {kernel_name}_kernel({', '.join(ctx.full_kernel_params)}{extra_kernel_params}):" code.writeline(kernel_signature) @@ -3562,15 +3797,14 @@ def codegen_kernel(self, name: str | None = None) -> str: # type: ignore[overri if cshape is not None: code.writeline(f"{param} = {param}.reshape({cshape})") - code.writeline("indexer = lambda n: lambda i: [jnp.int32(i)] * n") code.writeline("out_specs_pallas = tuple(") - code.writeline(" pl.BlockSpec(shape, indexer(len(shape)))") + code.writeline(" pallas_make_block_spec_non_tiled(shape)") code.writeline( " for shape, dtype in zip(_pallas_out_shapes, out_dtypes)" ) code.writeline(")") code.writeline("in_specs_pallas = tuple(") - code.writeline(" pl.BlockSpec(i.shape, indexer(len(i.shape)))") + code.writeline(" pallas_make_block_spec_non_tiled(i.shape)") code.writeline( " for i in [" + ", ".join(ctx.kernel_input_params) + "]" ) @@ -3647,6 +3881,180 @@ def _emit_kernel_body( if out_ptr in ctx.full_kernel_params: code.writeline(store_line) + def _codegen_scalar_prefetch_wrapper( + self, + ctx: _CodegenContext, + kernel_name: str, + kernel_body: IndentedBuffer, + ) -> None: + """Emit kernel, JIT wrapper, and main entry for scalar prefetch.""" + assert self.indirect_access is not None + indirect = self.indirect_access + code = ctx.code + + alias_set = OrderedSet(ctx.alias_params) + other_input_params = [ + p + for p in ctx.kernel_input_params + if p != indirect.indices_param + and p != indirect.table_param + and p not in alias_set + ] + + # Emit kernel function with params reordered for PrefetchScalarGridSpec: + # [scalar_prefetch] + [in_specs refs] + [out_specs refs] + prefetch_kernel_params = ( + [indirect.indices_param] + + [indirect.table_param] + + other_input_params + + list(ctx.alias_params) + + ctx.output_params + ) + code.writeline( + f"def {kernel_name}_kernel({', '.join(prefetch_kernel_params)}):" + ) + with code.indent(): + self._emit_kernel_body(code, kernel_body, ctx) + + # Emit JIT wrapper + code.writeline("") + jit_wrapper_name = f"{kernel_name}_jit_wrapper" + wrapper_params = ( + ["out_shapes", "out_dtypes"] + ctx.size_var_params + ctx.kernel_input_params + ) + static_argnums = list(range(2 + len(ctx.size_var_params))) + static_argnums_literal = "(" + ", ".join(str(x) for x in static_argnums) + ",)" + code.writeline( + f"@functools.partial(jax.jit, static_argnums={static_argnums_literal})" + ) + code.writeline(f"def {jit_wrapper_name}({', '.join(wrapper_params)}):") + + with code.indent(): + table = indirect.table_param + indices = indirect.indices_param + + ind_dim = indirect.indirect_dim + ndim = len(indirect.table_shape) + code.writeline("_D = 1") + for i in range(ndim): + if i != ind_dim: + code.writeline(f"_D = _D * {table}.shape[{i}]") + code.writeline(f"_seq = {indices}.shape[0]") + + if ind_dim == 0: + code.writeline(f"_table_3d = {table}.reshape({table}.shape[0], 1, _D)") + else: + perm = (ind_dim, *[d for d in range(ndim) if d != ind_dim]) + code.writeline( + f"_table_3d = {table}.transpose{perm}.reshape(" + f"{table}.shape[{ind_dim}], 1, _D)" + ) + + # Reshape other (non-table, non-indices) inputs to 3D to match the + # table's (seq, 1, D) layout. Currently handles: + # - 2D with leading dim == seq: row-aligned, reshape to (seq, 1, D) + # - 1D: broadcast scalar/vector, reshape to (1, 1, numel) + # - else: flatten to (1, 1, -1) — assumes broadcastable with + # (seq, 1, D). This may not work correctly for 3D+ inputs. + pallas_call_other_args = [] + for p in other_input_params: + p3d = f"_{p}_3d" + code.writeline(f"if {p}.ndim == 2 and {p}.shape[0] == _seq:") + code.writeline(f" {p3d} = {p}.reshape(_seq, 1, _D)") + code.writeline(f"elif {p}.ndim == 1:") + code.writeline(f" {p3d} = {p}.reshape(1, 1, {p}.shape[0])") + code.writeline("else:") + code.writeline(f" {p3d} = {p}.reshape(1, 1, -1)") + pallas_call_other_args.append(p3d) + + pallas_call_alias_args = [] + for p in ctx.alias_params: + p3d = f"_{p}_3d" + code.writeline(f"{p3d} = {p}.reshape(_seq, 1, _D)") + pallas_call_alias_args.append(p3d) + + partial_args = [f"{sv}={sv}" for sv in ctx.size_var_params] + if partial_args: + kernel_ref = ( + f"functools.partial({kernel_name}_kernel," + f" {', '.join(partial_args)})" + ) + else: + kernel_ref = f"{kernel_name}_kernel" + + # Reusable row-tiled BlockSpec (all i32 index_map for Mosaic compat) + code.writeline( + "_ROW_SPEC = pl.BlockSpec((1, 1, _D)," + " lambda i, _: (i, jnp.int32(0), jnp.int32(0)))" + ) + + num_non_alias_in_specs = 1 + len(pallas_call_other_args) + code.writeline("_in_specs = [") + with code.indent(): + code.writeline( + "pl.BlockSpec((1, 1, _D)," + " lambda gi, idx: (idx[gi], jnp.int32(0), jnp.int32(0)))," + ) + for p3d in pallas_call_other_args: + code.writeline( + f"_ROW_SPEC" + f" if {p3d}.shape[0] == _seq else" + f" pl.BlockSpec({p3d}.shape," + f" lambda i, _: (jnp.int32(0), jnp.int32(0), jnp.int32(0)))," + ) + for _ in ctx.alias_params: + code.writeline("_ROW_SPEC,") + code.writeline("]") + + num_outputs = len(ctx.output_params) + code.writeline( + "_out_specs = [" + ", ".join(["_ROW_SPEC"] * num_outputs) + "]" + ) + + # input_output_aliases: pallas_call arg index -> output index + # (offset by 1 for scalar prefetch arg) + alias_map_parts = [] + for out_idx, _ in enumerate(ctx.alias_params): + arg_idx = 1 + num_non_alias_in_specs + out_idx + alias_map_parts.append(f"{arg_idx}: {out_idx}") + alias_map_literal = ", ".join(alias_map_parts) + + out_shape_parts = [ + f"jax.ShapeDtypeStruct((_seq, 1, _D), out_dtypes[{i}])" + for i in range(num_outputs) + ] + out_shape_expr = "[" + ", ".join(out_shape_parts) + "]" + + code.writeline("_result = pl.pallas_call(") + with code.indent(): + code.writeline(f"{kernel_ref},") + code.writeline(f"out_shape={out_shape_expr},") + code.writeline("grid_spec=pltpu.PrefetchScalarGridSpec(") + with code.indent(): + code.writeline("num_scalar_prefetch=1,") + code.writeline("grid=(_seq,),") + code.writeline("in_specs=_in_specs,") + code.writeline("out_specs=_out_specs,") + code.writeline("),") + if alias_map_parts: + code.writeline(f"input_output_aliases={{ {alias_map_literal} }},") + if not self.is_tpu: + code.writeline(f"interpret={ctx.interpret_literal},") + + all_pallas_args = ( + [indices] + + ["_table_3d"] + + pallas_call_other_args + + pallas_call_alias_args + ) + code.writeline(f")({', '.join(all_pallas_args)})") + + code.writeline( + "return tuple(r.reshape(s) for r, s in zip(_result, out_shapes))" + ) + + self._codegen_main_entry(ctx, jit_wrapper_name) + def _codegen_imports(self, ctx: _CodegenContext) -> None: imports = """ import functools @@ -3660,14 +4068,22 @@ def _codegen_imports(self, ctx: _CodegenContext) -> None: pallas_compute_tiling, pallas_make_block_spec, pallas_permute, pallas_gpu_align_output_specs, pallas_gpu_pad_inputs, pallas_gpu_unpad_results, + pallas_ensure_nonzero_rank, + pallas_make_block_spec_non_tiled, torch_dtype_to_jax_runtime, ) """ if ctx.is_tpu: imports += "\nimport jax.export" + imports += "\nfrom jax.experimental.pallas import tpu as pltpu" imports += "\nfrom torch_tpu._internal.pallas import tpu_torch_pallas" elif not ctx.interpret_is_cpu: imports += "\nfrom jax.experimental.pallas import mosaic_gpu as plgpu" + if self.indirect_access and not ctx.is_tpu: + imports += ( + "\nimport os as _os; _os.environ.setdefault('JAX_PLATFORMS', 'cpu')" + ) + imports += "\nfrom jax.experimental.pallas import tpu as pltpu" ctx.code.splice(imports, strip=True) def _get_iter_var_axis(self, var_sym: sympy.Symbol) -> int | None: @@ -3689,16 +4105,9 @@ def _get_iter_var_axis(self, var_sym: sympy.Symbol) -> int | None: pw_idx += 1 return None - def _codegen_iteration_vars( - self, kernel_body: IndentedBuffer, ctx: _CodegenContext - ) -> None: - # Generate iteration variables as jnp.arange arrays - # Skip on GPU - jnp.arange is not supported by Pallas Mosaic backend - if not (self.range_tree_nodes and not self.is_gpu and self.used_iter_vars): - return - - kernel_body.writeline("# Define iteration variables as JAX arrays") - + def _get_reshape_target_shape_and_numel( + self, + ) -> tuple[tuple[int, ...] | None, int | None]: # Find reshape target: N-D shape whose numel matches an iteration # var. Try output first (repeat/upsample), then inputs (reductions). iter_lengths = OrderedSet( @@ -3720,11 +4129,7 @@ def _get_nd_shape_if_matches(buf_name): numel = math.prod(shape) return (shape, numel) if numel in iter_lengths else (None, None) - candidate_buf_names = [] - if ctx.output_params: - buf_name = ctx.output_buffer_lookup.get(ctx.output_params[0]) - if buf_name: - candidate_buf_names.append(buf_name) + candidate_buf_names = self._output_buffer_names.copy() candidate_buf_names.extend(self.args.input_buffers) reshape_target_shape, reshape_target_numel = None, None @@ -3734,6 +4139,41 @@ def _get_nd_shape_if_matches(buf_name): reshape_target_shape, reshape_target_numel = result break + return reshape_target_shape, reshape_target_numel + + def _make_broadcasted_iteration_var_expr( + self, broadcast_vars: list[_BroadcastedIterVar], broadcast_idx: int + ) -> str: + bv = broadcast_vars[broadcast_idx] + length = bv.entry.length + renamed_length = self.rename_indexing(length) + length_str = self.kexpr(renamed_length) + + num_broadcast_dims = len(broadcast_vars) + axis_idx = self._broadcast_axis_idx( + broadcast_vars, broadcast_idx, num_broadcast_dims + ) + shape_parts = ["1"] * num_broadcast_dims + shape_parts[axis_idx] = length_str + shape_str = ", ".join(shape_parts) + arange = f"jnp.arange({length_str})" + reshaped = f"{arange}.reshape({shape_str})" + return reshaped + + def _codegen_iteration_vars( + self, kernel_body: IndentedBuffer, ctx: _CodegenContext + ) -> None: + # Generate iteration variables as jnp.arange arrays + # Skip on GPU - jnp.arange is not supported by Pallas Mosaic backend + if not (self.range_tree_nodes and not self.is_gpu and self.used_iter_vars): + return + + kernel_body.writeline("# Define iteration variables as JAX arrays") + + reshape_target_shape, reshape_target_numel = ( + self._get_reshape_target_shape_and_numel() + ) + var_items = list(self.range_tree_nodes.items()) broadcast_vars = [] @@ -3743,7 +4183,9 @@ def _get_nd_shape_if_matches(buf_name): if length_val is not None and length_val == reshape_target_numel: total_var_idx = idx else: - broadcast_vars.append((idx, var_sym, entry, length_val)) + broadcast_vars.append( + _BroadcastedIterVar(idx, var_sym, entry, length_val) + ) num_broadcast_dims = len(broadcast_vars) @@ -3763,24 +4205,14 @@ def _get_nd_shape_if_matches(buf_name): and idx != total_var_idx ): broadcast_idx = next( - ( - i - for i, (vidx, _, _, _) in enumerate(broadcast_vars) - if vidx == idx - ), + (i for i, v in enumerate(broadcast_vars) if v.idx == idx), None, ) if broadcast_idx is not None: - axis_idx = self._broadcast_axis_idx( - broadcast_vars, broadcast_idx, num_broadcast_dims - ) - shape_parts = ["1"] * num_broadcast_dims - shape_parts[axis_idx] = length_str - shape_str = ", ".join(shape_parts) - arange = f"jnp.arange({length_str})" - kernel_body.writeline( - f"{var_name} = {arange}.reshape({shape_str})" + expr = self._make_broadcasted_iteration_var_expr( + broadcast_vars, broadcast_idx ) + kernel_body.writeline(f"{var_name} = {expr}") continue kernel_body.writeline(f"{var_name} = jnp.arange({length_str})") continue @@ -3795,16 +4227,12 @@ def _get_nd_shape_if_matches(buf_name): kernel_body.writeline(f"{var_name} = {arange}.reshape({shape_str})") elif num_broadcast_dims > 1 and idx != total_var_idx: broadcast_idx = next( - i for i, (vidx, _, _, _) in enumerate(broadcast_vars) if vidx == idx + i for i, v in enumerate(broadcast_vars) if v.idx == idx ) - axis_idx = self._broadcast_axis_idx( - broadcast_vars, broadcast_idx, num_broadcast_dims + expr = self._make_broadcasted_iteration_var_expr( + broadcast_vars, broadcast_idx ) - shape_parts = ["1"] * num_broadcast_dims - shape_parts[axis_idx] = length_str - shape_str = ", ".join(shape_parts) - arange = f"jnp.arange({length_str})" - kernel_body.writeline(f"{var_name} = {arange}.reshape({shape_str})") + kernel_body.writeline(f"{var_name} = {expr}") else: # Simple 1D arange — emit tile-relative form so tiling is safe. # When grid=(1,), _pallas_tile[ax] == full length and @@ -3837,7 +4265,7 @@ def _get_nd_shape_if_matches(buf_name): @staticmethod def _broadcast_axis_idx( - broadcast_vars: list[tuple[int, Any, Any, Any]], + broadcast_vars: list[_BroadcastedIterVar], broadcast_idx: int, num_broadcast_dims: int, ) -> int: @@ -3845,10 +4273,10 @@ def _broadcast_axis_idx( # - Mixed: pointwise first, reduction last for output reshape # - Same-type: reverse order, first var innermost has_reduction_vars = any( - str(v).startswith("r") for _, v, _, _ in broadcast_vars + str(bv.var_sym).startswith("r") for bv in broadcast_vars ) has_pointwise_vars = any( - not str(v).startswith("r") for _, v, _, _ in broadcast_vars + not str(bv.var_sym).startswith("r") for bv in broadcast_vars ) is_mixed = has_reduction_vars and has_pointwise_vars if is_mixed: @@ -4156,7 +4584,10 @@ def _codegen_jit_wrapper_cpu_tpu( ) code.writeline(")(") if ctx.kernel_input_params: - code.writeline(f" {', '.join(ctx.kernel_input_params)},") + kernel_input_params_nonzero_rank = [ + f"pallas_ensure_nonzero_rank({p})" for p in ctx.kernel_input_params + ] + code.writeline(f" {', '.join(kernel_input_params_nonzero_rank)},") code.writeline(")") # Reshape results back to original shapes (restores 0-d from promoted (1,)) code.writeline("if isinstance(_result, tuple):") @@ -4184,8 +4615,21 @@ def _codegen_main_entry_tpu( f"def {main_name}({', '.join(ctx.full_kernel_params)}, stream=None):" ) with code.indent(): + # `jax_enable_x64` is per-process. The CPU path sets it to True, + # so running both CPU and TPU tests in one process can cause + # x64-related TPU crashes if we do not explicitly set it to + # False here. + code.writeline("jax.config.update('jax_enable_x64', False)") code.writeline("jax.clear_caches()") + # Convert int64 inputs to int32 (TPU doesn't support int64) + all_input_params = list(ctx.alias_params) + list(ctx.pointer_tail) + for param_name in all_input_params: + code.writeline( + f"{param_name} = {param_name}.to(torch.int32) " + f"if {param_name}.dtype == torch.int64 else {param_name}" + ) + # Build JAX placeholders for all inputs code.writeline("# Build JAX placeholders for export tracing") all_jax_input_names = [] diff --git a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py index 57e5053acd607..c1f76f73af7dc 100644 --- a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py +++ b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py @@ -509,7 +509,7 @@ def _check_num_k_loops(self, op, kBatch): torch.cuda.get_device_properties(X_meta.device).warp_size, ) except Exception as e: - log.debug( # noqa: G200 + log.debug( "Failed to prefetch_stages for %s with exception %s", op.name, e ) # be conservative here and disable the op @@ -549,7 +549,7 @@ def _prefetch_stages(self, op, a_dtype_size, b_dtype_size, warp_size: int = 64): stages = version_to_stages.get(version) if stages is None: # This means we're at stage 2, and this requires computation - # See github.com/ROCm/composable_kernel/blob/d6a4605/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp#L143 # noqa: B950 + # See github.com/ROCm/composable_kernel/blob/d6a4605/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp#L143 wgp_per_cu = max(4 * warp_size // op.block_size, 1) full_mem_band_prefetch_stages = math.ceil( 32768 diff --git a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py index b07f33654ca27..a113493a8259e 100644 --- a/torch/_inductor/codegen/rocm/rocm_benchmark_request.py +++ b/torch/_inductor/codegen/rocm/rocm_benchmark_request.py @@ -117,7 +117,7 @@ def update_workspace_size(self) -> None: torch.cuda.synchronize() # shake out any CUDA errors self.workspace_size = c_workspace_size.value log.debug( - "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", # noqa: B950 + "update_workspace_size called: new workspace size=%d, self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s", self.workspace_size, self.kernel_name, self.source_file, diff --git a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py index ec58e458df6b1..eb8769de8db79 100644 --- a/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py +++ b/torch/_inductor/codegen/rocm/rocm_cpp_scheduling.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence -from typing import cast +from typing import cast, TypeGuard from ... import config from ...codecache import code_hash, get_path @@ -28,7 +28,7 @@ def group_fn(self, sizes): return tuple(V.graph.sizevars.simplify(sympy_product(s)) for s in sizes) @staticmethod - def is_rocm_cpp_template(node: BaseSchedulerNode) -> bool: + def is_rocm_cpp_template(node: BaseSchedulerNode) -> TypeGuard[SchedulerNode]: return isinstance(node, SchedulerNode) and isinstance( node.node, ROCmTemplateBuffer ) @@ -82,7 +82,6 @@ def codegen_template( assert self.is_rocm_cpp_template(template_node), ( "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer" ) - template_node = cast(SchedulerNode, template_node) _, (_numel, rnumel) = template_node.group assert rnumel == 1 ctb: ROCmTemplateBuffer = cast(ROCmTemplateBuffer, template_node.node) diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index 8a2b2fae44a58..406d59ae8ec2c 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -802,7 +802,10 @@ def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr: if not sv.statically_known_multiple_of( size, remaining[current_group] * remaining[current_group + 1] ): - raise CantSplit + raise CantSplit( + size, + remaining[current_group] * remaining[current_group + 1], + ) size1 = remaining[current_group] size2 = remaining[current_group + 1] @@ -856,11 +859,12 @@ def getter(flat_vars: list[sympy.Expr]) -> sympy.Expr: ) ) else: - if current_group < len(remaining): - return_getters.append( - # pyrefly: ignore [bad-argument-type] - operator.itemgetter(add_range(current_group, size)) - ) + if current_group >= len(remaining): + raise CantSplit(size, 0) + return_getters.append( + # pyrefly: ignore [bad-argument-type] + operator.itemgetter(add_range(current_group, size)) + ) return_getters_groups.append(return_getters) assert all( @@ -1647,6 +1651,7 @@ def _generate_kernel_code_for_mix_order_reduction( src_code = src_code.replace(str(Placeholder.KERNEL_NAME), "triton_") return kernel, ws_name, src_code + # pyrefly: ignore [bad-override] def benchmark_codegened_module( self, mod, n_spills_threshold=8, node_names: OrderedSet[str] | None = None ) -> tuple[float, str]: @@ -1790,7 +1795,13 @@ def _bench(candidate_split_size): partial_accum.reduction_type, partial_accum.reduction_type ) - final_reduce = f"{buffer_name} = {ws_name}[{start} : {end}].view({nsplit}, {rnumel}).{opname}(dim=0)" + # Check if the original reduction used keepdim=True by comparing dimensions. + # Without keepdim, reduction produces [rnumel]; with keepdim, [1, rnumel]. + buffer = V.graph.get_buffer(buffer_name) + keepdim = buffer is not None and len(buffer.get_layout().size) > 1 + + final_reduce = f"{buffer_name} = {ws_name}[{start} : {end}].view({nsplit}, {rnumel}).{opname}(dim=0, keepdim={keepdim})" + # The workspace tensor is in torch.float, need a cast if the buffer is # not. if (buffer_dtype := V.graph.get_dtype(buffer_name)) != torch.float: @@ -2554,10 +2565,16 @@ def create_tiling( """ Create a tiling dict from pointwise and reduction splits. """ - pw_prefixes = ["z", "y", "x"][-len(pw_tiling) :] - reduction_prefixes = ["r0_", "r1_"][: len(reduction_tiling)] + pw_prefixes = ("z", "y", "x") + reduction_prefixes = ("r0_", "r1_") + assert len(pw_tiling) <= len(pw_prefixes) + assert len(reduction_tiling) <= len(reduction_prefixes) + return immutable_dict( - [*zip(pw_prefixes, pw_tiling), *zip(reduction_prefixes, reduction_tiling)] + [ + *zip(pw_prefixes[-len(pw_tiling) :], pw_tiling, strict=False), + *zip(reduction_prefixes, reduction_tiling, strict=False), + ] ) @classmethod @@ -2615,6 +2632,12 @@ def collapse_dims( if not dims: return (fallback_numel,) max_tiles = get_max_tiles(2) + if V.graph.sizevars.statically_known_equals( + pointwise_numel, 1 + ) and V.graph.sizevars.statically_known_gt(reduction_numel, 1): + # We only have at most two dimensions to tile over when emitting a + # reduction-only kernel. + max_tiles = min(max_tiles, 2) num_leading_dims = max(0, len(dims) - max_tiles) first_trailing_dim = num_leading_dims + 1 collapsed_leading_dim = sympy_product(dims[:first_trailing_dim]) diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py index 47c0434e56137..167891ba04ded 100644 --- a/torch/_inductor/codegen/subgraph.py +++ b/torch/_inductor/codegen/subgraph.py @@ -11,6 +11,11 @@ import torch._inductor.config as config from torch._dynamo.utils import counters from torch._inductor import ir +from torch._inductor.autotune_process import ( + SubgraphCPUBenchmarkRequest, + SubgraphGPUBenchmarkRequest, + TensorMeta, +) from torch._inductor.codegen.common import KernelTemplate from torch._inductor.ir import ( Buffer, @@ -112,6 +117,17 @@ def __init__( self.config_patches: dict[str, Any] = {} # Cache compiled module to avoid recompiling on every benchmark call self._compiled_module: Any = None + # Cache benchmark request for async autotuning + self._bmreq: ( + SubgraphGPUBenchmarkRequest | SubgraphCPUBenchmarkRequest | None + ) = None + + # Pre-compile only if using async pipelined autotuning + # Must happen in __init__ because compilation requires virtualized context (V.graph, V.debug) + if config.pipeline_max_autotune_gemm: + with V.fake_mode: + self._compiled_module = self._compile_for_benchmarking() + self._bmreq = self._create_benchmark_request() def _compute_sym_input_values(self) -> list[int]: """Extract concrete dimension values for sym_inputs from benchmark_inputs. @@ -141,8 +157,8 @@ def _compute_sym_input_values(self) -> list[int]: if isinstance(sym_var, sympy.Symbol) and sym_var.name in sym_name_to_value: result.append(sym_name_to_value[sym_var.name]) else: - hint = V.graph.sizevars.shape_env.size_hint(sym_var) - result.append(int(hint) if hint is not None else 1) + hint = V.graph.sizevars.shape_env.optimization_hint(sym_var, fallback=1) + result.append(int(hint)) return result def cache_decomposition( @@ -193,17 +209,57 @@ def _compile_for_benchmarking(self) -> Any: "max_autotune": False, "max_autotune_gemm": False, "max_autotune_gemm_backends": "ATEN", + "benchmark_fusion": False, + "pipeline_max_autotune_gemm": False, **self.config_patches, } with config.patch(benchmark_config): bm_graph_lowering.run(*compile_inputs) return bm_graph_lowering.compile_to_module() - def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: - """Regular benchmarking: compile and use benchmarker with warmup/rep.""" + def _create_benchmark_request( + self, + ) -> SubgraphGPUBenchmarkRequest | SubgraphCPUBenchmarkRequest: + """Create a benchmark request for async autotuning.""" + assert self._compiled_module is not None, ( + "Module must be compiled before creating benchmark request" + ) + input_tensor_meta = TensorMeta.from_irnodes(self.input_nodes) + output_tensor_meta = TensorMeta.from_irnodes(self.layout) + + if self.layout.device.type == "cpu": + bmreq_cls = SubgraphCPUBenchmarkRequest + else: + bmreq_cls = SubgraphGPUBenchmarkRequest + + return bmreq_cls( + kernel_name=self.name, + input_tensor_meta=input_tensor_meta, + output_tensor_meta=output_tensor_meta, + extra_args=tuple(), + module_path=self._compiled_module.__file__, + module_cache_key=self._compiled_module.key, + sym_input_values=self.sym_input_values, + ) + + @property + def bmreq( + self, + ) -> SubgraphGPUBenchmarkRequest | SubgraphCPUBenchmarkRequest: + """Benchmark request for async autotuning. Pre-compiled when pipeline_max_autotune_gemm is enabled.""" + assert self._bmreq is not None, ( + "bmreq accessed but pipeline_max_autotune_gemm was not enabled during __init__" + ) + return self._bmreq + + def _ensure_compiled(self) -> None: + """Ensure the module is compiled. Used for lazy compilation in non-async path.""" if self._compiled_module is None: self._compiled_module = self._compile_for_benchmarking() + def benchmark(self, *args: list[Any], out: torch.Tensor) -> float: + """Regular benchmarking: compile if needed, then use benchmarker.""" + self._ensure_compiled() bm_func = self._compiled_module.call sym_inputs = self.sym_input_values @@ -222,9 +278,7 @@ def fn() -> Any: def benchmark_collective(self, *args: list[Any], out: torch.Tensor) -> None: """Run once for collective benchmarking (barrier sync handled by caller).""" - if self._compiled_module is None: - self._compiled_module = self._compile_for_benchmarking() - + self._ensure_compiled() self._compiled_module.call([*self.sym_input_values, *args]) def hash_key(self) -> str: diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 3a9f1394c2b03..34c0da3ea44ca 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -28,7 +28,13 @@ from torch._dynamo.utils import identity, preserve_rng_state from torch._prims_common import is_integer_dtype, type_to_dtype from torch.utils._ordered_set import OrderedSet -from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.functions import ( + CeilDiv, + FloorDiv, + ModularIndexing, + TruncToFloat, + TruncToInt, +) from torch.utils._triton import ( get_triton_version, has_triton_package, @@ -160,6 +166,47 @@ def is_sympy_integer_like(expr: object): ) +def _materialize_trunc_to_float_expr( + expr: sympy.Expr, dtype: torch.dtype +) -> sympy.Expr: + if not dtype.is_floating_point or not expr.has(TruncToInt): + return expr + + # Preserve float truncation semantics when materializing symbolic scalars + # into floating tensors. Casting to the kernel index dtype first can + # overflow before the requested floating-point conversion happens. Only + # rewrite truncations that are already participating in floating-point + # computation; integer subexpressions and predicates must keep exact + # integer semantics until the final materialization cast. + if expr.func is TruncToInt: + return TruncToFloat(*expr.args) + + def is_predicate_expr(node: sympy.Basic) -> bool: + return bool( + getattr(node, "is_Boolean", False) or getattr(node, "is_Relational", False) + ) + + def rewrite_float_subexpr(node: sympy.Expr) -> sympy.Expr: + if not node.has(TruncToInt): + return node + if node.func is TruncToInt: + return TruncToFloat(*node.args) + if is_predicate_expr(node) or node.is_integer: + return node + + new_args = tuple( + rewrite_float_subexpr(arg) + if isinstance(arg, sympy.Expr) and not is_predicate_expr(arg) + else arg + for arg in node.args + ) + if new_args == node.args: + return node + return node.func(*new_args) + + return rewrite_float_subexpr(expr) + + class OpDtypeSupport: """ Some Triton ops such as libdevice and tl.math only support float32 and float64. @@ -260,8 +307,15 @@ def get_block_shape(cls, expr: sympy.Expr) -> BlockShapeType: ] assert len(tree_match) == 1, "# of Match expected to 1" - shape[tree_match[0].tensor_dim] = str(cls.get_block_size(tree_match[0])) - var_shape = tuple(shape) + if tree_match[0].tensor_dim is None: + # tree has no tensor dimension (e.g. no_x_dim mode), + # treat as scalar + var_shape = () + else: + shape[tree_match[0].tensor_dim] = str( + cls.get_block_size(tree_match[0]) + ) + var_shape = tuple(shape) # Union current variable shape expr_shape = get_broadcasted_shape(expr_shape, var_shape) @@ -768,6 +822,15 @@ def _print_TruncToInt(self, expr: sympy.Expr) -> str: f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) + def _print_TruncToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + # pyrefly: ignore [missing-attribute] + value = self._print(expr.args[0]) + # Adding +0.0 preserves large floating results while canonicalizing + # libdevice.trunc(-0.0) back to Python's +0.0 materialization behavior. + # pyrefly: ignore [missing-attribute] + return f"(libdevice.trunc({value}) + tl.zeros_like({value}))" + def _print_Float(self, expr: sympy.Expr) -> str: if expr.is_integer: # sympy considers 0.0 to be integer, but triton doesn't. @@ -1231,6 +1294,7 @@ def _cast_libdevice_arg(cls, arg, dtype: torch.dtype) -> str: @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def abs(x): return f"tl_math.abs({x})" @@ -1293,6 +1357,7 @@ def mod(x, y): @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def exp(x): """ When use_fast_math, use the ftz (flushing to zero) variant @@ -1318,6 +1383,7 @@ def expm1(x): @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def sqrt(x): return f"tl.sqrt_rn({x})" @@ -1340,6 +1406,7 @@ def relu(x): ) @staticmethod + # pyrefly: ignore [bad-override] def minimum(a, b): if torch.version.hip: return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)" @@ -1347,6 +1414,7 @@ def minimum(a, b): return f"triton_helpers.minimum({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def maximum(a, b): if torch.version.hip: return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)" @@ -1354,10 +1422,12 @@ def maximum(a, b): return f"triton_helpers.maximum({a}, {b})" @staticmethod + # pyrefly: ignore [bad-override] def where(a, b, c): return f"tl.where({a}, {b}, {c})" @staticmethod + # pyrefly: ignore [bad-override] def dot(a, b): """ Triton code generation for lowering ops.dot to tl.dot. @@ -1546,21 +1616,83 @@ def reshape_transpose_broadcast_for_dot( @staticmethod def inline_asm_elementwise( - *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1 + *inputs, + asm, + constraints=None, + dtype=torch.float32, + is_pure=True, + pack=1, + input_dtypes=None, ): - triton_type = triton_compute_type(dtype) - input_refs = ", ".join([str(i) for i in inputs]) + # Use the actual dtype, not the compute type — the asm operates on + # specific register types and Triton needs to know the real output type. + asm_triton_type = triton_type(dtype) if constraints is None: constraints = ", ".join(["=r"] + ["r" for _ in inputs]) - return f"tl.inline_asm_elementwise('{asm}', '{constraints}', [{input_refs}], dtype={triton_type}, is_pure={is_pure}, pack={pack})" # noqa: B950 + + # Inductor computes bf16/fp16 in fp32. For "h" (16-bit register) + # constraints, cast back to the original dtype so the asm sees the + # right register type. + constraint_parts = [p.strip() for p in constraints.split(",")] + input_constraints = [p for p in constraint_parts if not p.startswith("=")] + cast_inputs = [] + for i, (inp, c) in enumerate(zip(inputs, input_constraints[: len(inputs)])): + if ( + c == "h" + and input_dtypes is not None + and isinstance(inp, CSEVariable) + and inp.dtype != input_dtypes[i] + ): + cast_inputs.append(f"{inp}.to({triton_type(input_dtypes[i])})") + else: + cast_inputs.append(str(inp)) + + if torch.version.hip: + # AMDGCN asm strings may contain real newlines (instructions are + # newline-separated, unlike PTX which uses semicolons). The + # generated code is nested inside two Python string layers: + # Layer 1 : the cached wrapper .py file + # Layer 2 : the Triton kernel source (a triple-quoted string + # inside that wrapper, exec'd / JIT-compiled) + # repr() escapes \n -> \\n, then we double the backslashes so + # they survive both layers: \\\\n -> (L1 parse) \\n -> (L2 parse) \n. + asm_literal = repr(asm).replace("\\", "\\\\") + constraints_literal = repr(constraints).replace("\\", "\\\\") + else: + asm_literal = f"'{asm}'" + constraints_literal = f"'{constraints}'" + + def asm_call(args): + return ( + f"tl.inline_asm_elementwise({asm_literal}, {constraints_literal}, " + f"[{args}], dtype={asm_triton_type}, is_pure={is_pure}, pack={pack})" + ) + + if pack <= 1: + return asm_call(", ".join(cast_inputs)) + + first_input = inputs[0] + compute = V.kernel.compute + cse = V.kernel.cse + result = cse.newvar(dtype=dtype, shape=first_input.shape) + packed_args = ", ".join( + f"triton_helpers.inline_asm_pack({inp}, {pack})" for inp in cast_inputs + ) + compute.writeline(f"{result} = {asm_call(packed_args)}") + compute.writeline( + f"{result} = triton_helpers.inline_asm_unpack({result}, {first_input}, {pack})" + ) + return result @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def cos(x): return f"tl_math.cos({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def sin(x): return f"tl_math.sin({x})" @@ -1574,61 +1706,73 @@ def masked(mask, body, other): @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def lgamma(x): return f"libdevice.lgamma({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def erf(x): return f"libdevice.erf({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def cosh(x): return f"libdevice.cosh({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def sinh(x): return f"libdevice.sinh({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def acos(x): return f"libdevice.acos({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def acosh(x): return f"libdevice.acosh({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def asin(x): return f"libdevice.asin({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def asinh(x): return f"libdevice.asinh({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def atan2(x, y): return f"libdevice.atan2({x}, {y})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def atan(x): return f"libdevice.atan({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def atanh(x): return f"libdevice.atanh({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def copysign(x, y): return f"libdevice.copysign({x}, {y})" @@ -1639,11 +1783,13 @@ def erfc(x): @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def erfinv(x): return f"libdevice.erfinv({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def hypot(x, y): return f"libdevice.hypot({x}, {y})" @@ -1658,15 +1804,18 @@ def log2(x): return f"libdevice.log2({x})" @staticmethod + # pyrefly: ignore [bad-override] def ldexp(x, n): return f"libdevice.ldexp({x}, {n}.to(tl.int32))" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def nextafter(x, y): return f"libdevice.nextafter({x}, {y})" @staticmethod + # pyrefly: ignore [bad-override] def logical_and(a, b): return f"{a} & {b}" @@ -1675,10 +1824,12 @@ def logical_not(a): return f"{a} == 0" @staticmethod + # pyrefly: ignore [bad-override] def logical_or(a, b): return f"{a} | {b}" @staticmethod + # pyrefly: ignore [bad-override] def logical_xor(a, b): return f"({a} ^ {b})" @@ -1711,6 +1862,16 @@ def rand(seed, offset): offset = f"({offset}).to(tl.uint32)" return f"tl.rand({seed}, {offset})" + @staticmethod + def rand_eager(seed, base_offset, threads_per_round, tid, vec): + # vec: 4 for fp32, 8 for fp16/bf16 + tid_u32 = f"({tid}).to(tl.uint32)" + denom = f"(({vec})*({threads_per_round}))" + r = f"(({tid_u32})//({denom})*({vec}//4))" + tid_trunc = f"(({tid_u32})%({denom}))" + + return f"triton_helpers.rand_eager_kernel({seed}, {base_offset}+{r}, {tid_trunc}, VEC={vec})" + @staticmethod def randn(seed, offset): offset = f"({offset}).to(tl.uint32)" @@ -1727,6 +1888,7 @@ def load_seed(name, offset): @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def rsqrt(x): if torch.version.hip: return f"tl.rsqrt({x})" @@ -1740,11 +1902,13 @@ def log1p(x): @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def tan(x): return f"libdevice.tan({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def tanh(x): cse_var = V.kernel.cse.varname_map.get(x) if cse_var and hasattr(cse_var, "dtype"): @@ -1770,6 +1934,7 @@ def sigmoid(x): return f"tl.sigmoid({x})" @staticmethod + # pyrefly: ignore [bad-override] def signbit(x): # XX: This is wrong for the value -0.0 in floating point return ( @@ -1778,6 +1943,7 @@ def signbit(x): @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def fmod(a, b): return f"libdevice.fmod({a}, {b})" @@ -1818,26 +1984,31 @@ def pow(cls, a, b): @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def log(x): return f"tl_math.log({x})" @staticmethod @maybe_upcast_float32(convert_output=False) + # pyrefly: ignore [bad-override] def isinf(x): return f"libdevice.isinf({x}).to(tl.int1)" @staticmethod @maybe_upcast_float32(convert_output=False) + # pyrefly: ignore [bad-override] def isnan(x): return f"libdevice.isnan({x}).to(tl.int1)" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def round(x): return f"libdevice.nearbyint({x})" @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def floor(x): return f"libdevice.floor({x})" @@ -1853,15 +2024,24 @@ def floordiv(a, b): # floor_div(a, b) = ~(~a // b) when a < 0, a // b when a >= 0 # For negative b we negate both operands first. zero = ops.constant(0, torch.int32) + one = ops.constant(1, torch.int32) + # Guard against integer division by zero before the division to + # avoid undefined behavior (LLVM may optimize away a post-division + # check assuming UB doesn't happen). Replace b with 1 when b is 0 + # so the division is safe, then select 0 as the final result. + b_zero = ops.eq(b, zero) + b = ops.where(b_zero, one, b) b_neg = ops.lt(b, zero) a = ops.where(b_neg, ops.sub(zero, a), a) b = ops.where(b_neg, ops.sub(zero, b), b) a_neg = ops.lt(a, zero) a = ops.where(a_neg, ops.bitwise_not(a), a) quot = ops.truncdiv(a, b) - return ops.where(a_neg, ops.bitwise_not(quot), quot) + quot = ops.where(a_neg, ops.bitwise_not(quot), quot) + return ops.where(b_zero, zero, quot) @staticmethod + # pyrefly: ignore [bad-override] def sign(x): z = ops.constant(0, torch.int32) left = ops.to_dtype((ops.lt(z, x)), torch.int8) @@ -1871,10 +2051,12 @@ def sign(x): @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def trunc(x): return f"libdevice.trunc({x})" @staticmethod + # pyrefly: ignore [bad-override] def truncdiv(a, b): # See the comment in lowering.div_mode. a and b are integer type. # Notice that // in triton behaves as truncdiv instead of floordiv @@ -1882,6 +2064,7 @@ def truncdiv(a, b): @staticmethod @maybe_upcast_float32() + # pyrefly: ignore [bad-override] def ceil(x): return f"libdevice.ceil({x})" @@ -1956,6 +2139,7 @@ def constant(cls, value, dtype): @classmethod def index_expr(cls, expr, dtype): + expr = _materialize_trunc_to_float_expr(expr, dtype) indexing = V.kernel.indexing( expr, block_ptr=False, tma_compatibility_checker=None ) @@ -2071,6 +2255,7 @@ def load_seed(name, offset): ) @staticmethod + # pyrefly: ignore [bad-override] def frexp(x): cache_key = f"frexp({x})" if cse_val := V.kernel.cse.try_get(cache_key): @@ -2709,7 +2894,9 @@ def dtype_to_str(self, dtype: torch.dtype) -> str: def should_use_cooperative_reduction(self) -> bool: return self.inside_reduction and V.choices.should_use_cooperative_reduction( - self.features + V.graph.get_current_device_or_throw(), + self.features.numel, + self.features.reduction_numel, ) def init_cooperative_reduction(self): @@ -2830,6 +3017,7 @@ def indexing( override_mask=None, block_ptr=False, tma_compatibility_checker: TMACompatibilityChecker | None = None, + mask_constant_index=False, ): """ Compute the index and mask to pass to tl.load() or tl.store() @@ -3205,7 +3393,7 @@ def _get_expand_str(): expand_shape = tuple([1] * len(self.dense_size_list())) index_str = f"tl.full({expand_str}, {index_str}, tl.int32)" - if self.fixed_config or self.is_combo_kernel: + if self.fixed_config or self.is_combo_kernel or mask_constant_index: mask_vars = OrderedSet( f"{tree.prefix}mask" for tree in self.range_trees @@ -3433,7 +3621,15 @@ def codegen_block_ptr_store_line(self, name, indexing, block_ptr, value, other=" ) # workaround https://github.com/triton-lang/triton/issues/2814 - value = f"{value}.to({triton_store_type(V.graph.get_dtype(name))})" + # For inplace-mutated buffers, the block pointer element type comes from + # the actual tensor (input buffer), which may differ from the graph dtype + # (e.g., _to_copy produces fp32 values stored into a bf16 gradient buffer). + store_dtype = V.graph.get_dtype(name) + if name in self.args.inplace_buffers: + buf = self.args.inplace_buffers[name] + if not isinstance(buf, RemovedArg): + store_dtype = V.graph.get_dtype(buf.other_names[0]) + value = f"{value}.to({triton_store_type(store_dtype)})" if isinstance(indexing, BlockPtrOptions): return f"tl.store({block_ptr}, {value}{other})" return f"{block_ptr}.store({V.kernel.index_to_str(indexing.offsets)}, {value})" @@ -3780,6 +3976,7 @@ def store( dense_indexing=True, block_ptr=mode is None, tma_compatibility_checker=tma_compatibility_checker, + mask_constant_index=mode == "atomic_add", ) if isinstance(indexing, IndexingOptions) and self._has_stride1_on_rdim( @@ -5215,7 +5412,7 @@ def codegen_kernel_benchmark(self, num_gb: float | None) -> IndentedBuffer: hint_override=self.hint_override, ) result.writeline( - f"{var_name} = rand_strided({size}, {stride}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long + f"{var_name} = rand_strided({size}, {stride}, device='{buf.get_device()}', dtype={buf.get_dtype()})" ) elif arg_name in V.graph.constants: # note that random seed is put in V.graph.constants @@ -5229,7 +5426,7 @@ def codegen_kernel_benchmark(self, num_gb: float | None) -> IndentedBuffer: hint_override=self.hint_override, ) result.writeline( - f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long + f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] ) elif isinstance(arg_sig, SizeArg): symval_hint = V.graph.sizevars.optimization_hint_with_override( @@ -5295,7 +5492,7 @@ def codegen_kernel_benchmark(self, num_gb: float | None) -> IndentedBuffer: result.writeline("args = get_args()") result.writeline( - f"ms = benchmarker.benchmark(lambda: call(args), device='{V.graph.get_current_device_or_throw().type}', rep=40)" # noqa: B950 line too long + f"ms = benchmarker.benchmark(lambda: call(args), device='{V.graph.get_current_device_or_throw().type}', rep=40)" ) result.writeline(f"num_gb = {num_gb}") result.writeline("gb_per_s = num_gb / (ms / 1e3)") @@ -5390,7 +5587,8 @@ def inductor_meta_common(cls): "min_split_scan_rblock": config.triton.min_split_scan_rblock, "spill_threshold": config.triton.spill_threshold, "store_cubin": config.triton.store_cubin, - "deterministic": config.deterministic, + "deterministic": config.deterministic or config.batch_invariant, + "batch_invariant": config.batch_invariant, "force_filter_reduction_configs": config.test_configs.force_filter_reduction_configs, "mix_order_reduction_allow_multi_stages": config.triton.mix_order_reduction_allow_multi_stages, } @@ -5650,8 +5848,6 @@ def add_constexpr_arg(arg_name): if flops is not None: inductor_meta["kernel_flop"] = flops - triton_meta["configs"] = [config_of(signature)] - # Triton compiler includes equal_to_1 args into constants even # when they are not constexpr. otherwise there may be a segfault # during launching the Inductor-compiled Triton kernel. @@ -5667,6 +5863,17 @@ def add_constexpr_arg(arg_name): self.codegen_body() self._filter_pdl(self.body) + # Compute configs after codegen_body() so we know if the kernel + # uses atomic ops. On HIP, buffer ops don't support atomics, so + # we must not tag any args with pointer_range_32 in that case. + # Also disable pointer_range_32 when the config flag is off. + if torch.version.hip is not None and ( + self.atomic_add_found or not config.triton.emit_pointer_range_32 + ): + triton_meta["configs"] = [config_of(signature, pointer_range_override=())] + else: + triton_meta["configs"] = [config_of(signature)] + for helper in self.helper_functions: code.writeline("") code.splice(helper) @@ -6176,7 +6383,11 @@ def load(self, name: str, index: sympy.Expr) -> TritonCSEVariable: ) return result_var else: - return super().load(name, index) + # The scheduler should prevent this. + raise AssertionError( + f"Epilogue attempted to load from '{name}'. " + "Inductor indexing variables are not defined in user kernel scope. " + ) def store( self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None @@ -6199,9 +6410,13 @@ def codegen(self) -> str: # Generate a new AST where the store value expr is replaced with the new value new_ast = copy.deepcopy(self.ir_node.kernel_ast) - from torch._higher_order_ops.triton_kernel_wrap import identify_triton_stores + from torch._higher_order_ops.triton_kernel_wrap import ( + identify_triton_stores, + identify_triton_stores_from_ast, + ) - kernel_stores = identify_triton_stores(new_ast) + # avoid redundant cache entry of new_ast + kernel_stores = identify_triton_stores_from_ast(new_ast) assert len(kernel_stores.stores) == 1 new_store_value_node = ast.Name(self.new_store_cse_var.name) @@ -6231,7 +6446,7 @@ def _replace_arg( src_lines = src_with_store_replaced.splitlines() # identify the store again, because the previous parse-modify-unparse could've change its location - kernel_stores = identify_triton_stores(ast.parse(src_with_store_replaced)) + kernel_stores = identify_triton_stores(src_with_store_replaced) # python ast lineno is 1-indexed store_line_index = kernel_stores.stores[0].store_node.lineno - 1 @@ -6370,6 +6585,8 @@ def define_kernel(self, src_code, node_schedule, kernel): if config.triton.descriptive_names else "" ) + if fused_name: + fused_name = V.choices.customize_fused_kernel_name(fused_name, src_code) kernel_category = get_kernel_category_by_source_code(src_code)[:3] kernel_name = "_".join( ["triton", kernel_category, fused_name, wrapper.next_kernel_suffix()] @@ -6472,7 +6689,7 @@ def load_cache(): except Exception as e: if config.triton.disallow_failing_autotune_kernels_TESTING_ONLY: raise - log.debug( # noqa: G200 + log.debug( "Exception (%s) in compiling fused nodes %s", e, node_names, @@ -6729,11 +6946,16 @@ def debug_triton_code(node: BaseSchedulerNode) -> list[str]: from torch._inductor.codegen.cuda_combined_scheduling import ( CUDACombinedScheduling, ) + from torch._inductor.codegen.xpu.xpu_combined_scheduling import ( + XPUCombinedScheduling, + ) device = node.get_device() assert device is not None backend = node.scheduler.get_backend(device) - assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), ( + assert isinstance( + backend, (SIMDScheduling, CUDACombinedScheduling, XPUCombinedScheduling) + ), ( f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}" ) diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 1cb2ef3f48971..702e0e3a8bdb3 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -40,6 +40,11 @@ from .triton_utils import config_of, equal_1_arg_indices, signature_to_meta +# Default block sizes used when combo kernel autotuning is disabled. +DEFAULT_COMBO_BLOCK_SIZE_1D = 1024 +DEFAULT_COMBO_BLOCK_SIZE_2D = 32 + + log = logging.getLogger(__name__) pexpr = PythonPrinter().doprint LARGE_NUMELS = 512e5 @@ -454,9 +459,9 @@ def __init__( | None ) = None self.block_args: list[str] = [] - # there following are used when autotuning is disabled - self.block_size_1d = 1024 # Try tuning this value - self.block_size_2d = 32 + # the following are used when autotuning is disabled + self.block_size_1d = DEFAULT_COMBO_BLOCK_SIZE_1D + self.block_size_2d = DEFAULT_COMBO_BLOCK_SIZE_2D self.num_warps = 8 self.block_size_reduce = 256 self.dynamic_shape_args: list[str] = [] @@ -762,6 +767,10 @@ def jit_line( "combo_grid_meta": self.combo_grid_meta(size_hints_list), "kernel_name": str(Placeholder.DESCRIPTIVE_NAME), "mutated_arg_names": mutated_args, + # Matches triton.py:codegen_kernel(): inference/backward graphs skip + # CPU-copy of mutated args during autotune retries; training-forward + # graphs must keep it to preserve benchmark inputs across retries. + "optimize_mem": V.graph.is_inference or V.graph.is_backward, **self.triton_kernel_cls.inductor_meta_common(), } if max_persistent_rblock > 0: @@ -1007,7 +1016,7 @@ def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer: size = V.graph.sizevars.optimization_hints(buf.get_size()) stride = V.graph.sizevars.optimization_hints(buf.get_stride()) result.writeline( - f"{var_name} = rand_strided({size}, {stride}, device='{buf.get_device()}', dtype={buf.get_dtype()})" # noqa: B950 line too long + f"{var_name} = rand_strided({size}, {stride}, device='{buf.get_device()}', dtype={buf.get_dtype()})" ) elif arg_name in V.graph.constants: # note that random seed is put in V.graph.constants @@ -1015,7 +1024,7 @@ def codegen_kernel_benchmark(self, num_gb: float) -> IndentedBuffer: size = V.graph.sizevars.optimization_hints(const_tensor.size()) stride = V.graph.sizevars.optimization_hints(const_tensor.stride()) result.writeline( - f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] # noqa: B950 line too long + f"{var_name} = rand_strided({size}, {stride}, device='{const_tensor.device}', dtype={const_tensor.dtype})" # type: ignore[arg-type] ) elif isinstance(arg_sig, SizeArg): symval_hint = V.graph.sizevars.optimization_hint(arg_sig.expr) @@ -1201,6 +1210,15 @@ def combo_grid_meta(self, size_hints_list: list[dict[str, int]]) -> dict[str, An meta[f"tile_hint_{num}"] = "TileHint.SQUARE" else: meta[f"tile_hint_{num}"] = "TileHint.DEFAULT" + if sub_kernel.tiling_scores: + meta[f"tiling_scores_{num}"] = { + dim: V.graph.sizevars.optimization_hint(score, fallback=1) + for dim, score in sub_kernel.tiling_scores.items() + } + else: + meta[f"reduction_hint_{num}"] = ( + sub_kernel.features.get_reduction_hint().name + ) for tree in sub_kernel.range_trees: if not tree.is_reduction: diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index d4adb2aaea473..f89c72ebfb1c4 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -34,18 +34,7 @@ def should_unwrap_unspec_arg(name: str): def signature_of(arg: KernelArgType, *, size_dtype: str | None) -> str: if isinstance(arg, TensorArg): - # TODO: Remove fp8 special handling when Triton supports PyTorch fp8 dtypes. - # Related PR: https://github.com/triton-lang/triton/pull/2279/ - if arg.dtype == torch.float8_e4m3fn: - typ = "*fp8e4nv" - elif arg.dtype == torch.float8_e5m2: - typ = "*fp8e5" - elif arg.dtype == torch.float8_e4m3fnuz: - typ = "*fp8e4b8" - elif arg.dtype == torch.float8_e5m2fnuz: - typ = "*fp8e5b16" - else: - typ = _type_of(arg.dtype) + typ = _type_of(arg.dtype) if should_unwrap_unspec_arg(arg.buffer): # had unwrapped 0d tensor as scalar new_typ = typ.lstrip("*") @@ -160,6 +149,21 @@ def _decide_tl_dtype(arg): } +def _get_buffer_layout(buf_name: str) -> "torch._inductor.ir.Layout": + """Get the layout for a buffer, handling both scheduler buffers and graph inputs.""" + if V.graph.scheduler: + layout = V.graph.scheduler.get_buffer_layout(buf_name) + else: + buffer = V.graph.try_get_buffer(buf_name) + # output arg + if not buffer: + assert buf_name == V.kernel.output_node.name + layout = V.kernel.output_node.layout + else: + layout = buffer.get_layout() + return layout + + def is_unaligned_buffer(arg: TensorArg): buf_name = arg.buffer if buf_name in V.graph.unaligned_buffers: @@ -175,17 +179,7 @@ def is_unaligned_buffer(arg: TensorArg): # all constants are assumed to be aligned return False - if V.graph.scheduler: - layout = V.graph.scheduler.get_buffer_layout(buf_name) - else: - buffer = V.graph.try_get_buffer(buf_name) - # output arg - if not buffer: - assert buf_name == V.kernel.output_node.name - layout = V.kernel.output_node.layout - else: - layout = buffer.get_layout() - + layout = _get_buffer_layout(buf_name) if isinstance(layout, torch._inductor.ir.NonOwningLayout): return not layout.maybe_guard_aligned() else: @@ -213,10 +207,36 @@ def equal_1_arg_indices( return equal_to_1 +def _is_tensor_within_2gb(arg: TensorArg) -> bool: + """Check if a tensor argument's storage is provably within 2GB. + + Mirrors HIPBackend.is_within_2gb() but uses compile-time symbolic analysis + instead of runtime tensor inspection. This enables canonicalize_pointers to + decompose pointer arithmetic into (splat(base), offset) form for buffer ops. + """ + MAX_BYTES = 2**31 - 1 + try: + # Graph inputs aren't tracked by the scheduler; get their layout + # from the graph_inputs dict to avoid KeyError in get_buffer_layout. + if arg.buffer in V.graph.graph_inputs: + inp = V.graph.graph_inputs[arg.buffer] + if hasattr(inp, "get_layout"): + layout = inp.get_layout() + else: + return False + else: + layout = _get_buffer_layout(arg.buffer) + storage_bytes = layout.storage_size() * arg.dtype.itemsize + return V.graph.sizevars.statically_known_true(storage_bytes <= MAX_BYTES) + except Exception: + return False + + def config_of( args: list[KernelArgType], *, indices: list[int] | None = None, + pointer_range_override: tuple[int, ...] | None = None, ) -> Any: if indices is None: indices = list(range(len(args))) @@ -263,5 +283,18 @@ def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool: equal_to_1 = equal_1_arg_indices(args, indices=indices) - # pyrefly: ignore [bad-argument-type] - return AttrsDescriptorWrapper(divisible_by_16, equal_to_1) + # On AMD/HIP, tag tensor args whose storage fits in 2GB so Triton + # can use 32-bit pointer offsets and emit buffer load/store ops. + if pointer_range_override is not None: + pointer_range_32 = pointer_range_override + elif torch.version.hip is not None: + pointer_range_32 = tuple( + i + for i, arg in zip(indices, args) + if isinstance(arg, TensorArg) and _is_tensor_within_2gb(arg) + ) + else: + pointer_range_32 = () + + # pyrefly: ignore [bad-argument-count, bad-argument-type] + return AttrsDescriptorWrapper(divisible_by_16, equal_to_1, pointer_range_32) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 06a54f278bb51..70a1d55716e95 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -47,7 +47,7 @@ from ..ir import IRNode, ReinterpretView from ..runtime import triton_heuristics from ..runtime.hints import DeviceProperties -from ..stream_constants import DEFAULT_STREAM, STREAM_NAME_TEMPLATE +from ..stream_constants import DEFAULT_STREAM, DEFAULT_STREAM_IDX, STREAM_NAME_TEMPLATE from ..stream_utils import get_stream_name from ..utils import ( cache_on_self, @@ -78,7 +78,7 @@ if TYPE_CHECKING: - from collections.abc import Iterator, Sequence + from collections.abc import Iterable, Iterator, Sequence import triton @@ -93,7 +93,7 @@ pexpr = PythonPrinter().doprint -ReuseKey = tuple[torch.device, torch.dtype, str, bool] +ReuseKey = tuple[torch.device, torch.dtype, str, bool, int] CommBufferReuseKey = tuple[torch.device, torch.dtype, str, "ir.CommBufferType", str] BufferLike = ir.Buffer | WorkspaceArg FxConversionFunc = Callable[["WrapperLine"], None] @@ -102,6 +102,7 @@ def buffer_reuse_key(node: BufferLike) -> ReuseKey: storage_size = V.graph.get_allocation_storage_size(node) alignment = node.get_name() not in V.graph.unaligned_buffers + stream = V.graph.scheduler.get_buf_stream(node.get_name()) return ( node.get_device_or_error(), node.get_dtype(), @@ -110,6 +111,7 @@ def buffer_reuse_key(node: BufferLike) -> ReuseKey: # size hint sympy_str(V.graph.sizevars.simplify(storage_size)), alignment, + stream, ) @@ -328,6 +330,7 @@ def user_defined_triton_kernel_transitive_closure_source_code( import triton from triton import JITFunction # type: ignore[name-defined, attr-defined] from triton.language import constexpr # type: ignore[name-defined] + from triton.language.core import dtype as triton_dtype # global constexpr vars handled above symbols_included = OrderedSet([kernel.__name__]) @@ -351,7 +354,6 @@ def traverse(cur_kernel): if isinstance(symbol, JITFunction): compile_wrapper.newline() compile_wrapper.writeline("@triton.jit") - compile_wrapper.splice(symbol.src, strip=True) symbols_included.add(symbol_name) traverse(symbol) @@ -359,9 +361,25 @@ def traverse(cur_kernel): symbol, triton.runtime.jit.ConstexprFunction, ): + # Import dtype class if used in type annotations + if "dtype" in symbol.src and "dtype" not in symbols_included: + dtype_symbol = symbol.fn.__globals__.get("dtype") + if ( + dtype_symbol + and hasattr(dtype_symbol, "__module__") + and dtype_symbol.__module__.startswith("triton") + ): + compile_wrapper.writeline( + f"from {dtype_symbol.__module__} import dtype as dtype" + ) + symbols_included.add("dtype") compile_wrapper.newline() compile_wrapper.writeline("@triton.constexpr_function") compile_wrapper.splice(symbol.src, strip=True) + if symbol_name != symbol.fn.__name__: + compile_wrapper.writeline( + f"{symbol_name} = {symbol.fn.__name__}" + ) symbols_included.add(symbol_name) traverse(symbol) elif isinstance(symbol, (int, str, bool, constexpr)): @@ -395,9 +413,14 @@ def traverse(cur_kernel): # a global symbol imported from triton is referenced # without module qualification (i.e., `store` instead # of `tl.store`): need to codegen an import - compile_wrapper.writeline( - f"from {symbol.__module__} import {symbol.__name__} as {symbol_name}" - ) + + # Triton dtype instances have .name instead of .__name__ + if isinstance(symbol, triton_dtype): + compile_wrapper.writeline(f"{symbol_name} = tl.{symbol.name}") + elif hasattr(symbol, "__name__"): + compile_wrapper.writeline( + f"from {symbol.__module__} import {symbol.__name__} as {symbol_name}" + ) symbols_included.add(symbol_name) traverse(kernel) @@ -675,6 +698,7 @@ class KernelCallLine(WrapperLine): device: torch.device graph_name: str original_fxnode_name: str + current_stream_idx: int | None = None def codegen(self, code: IndentedBuffer) -> None: self.wrapper._generate_kernel_call_helper( @@ -689,6 +713,7 @@ def codegen(self, code: IndentedBuffer) -> None: device=self.device, graph_name=self.graph_name, original_fxnode_name=self.original_fxnode_name, + current_stream_idx=self.current_stream_idx, ) def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: @@ -745,16 +770,19 @@ def __str__(self) -> str: @dataclasses.dataclass class EnterDeviceContextManagerWithStreamInfoLine(EnterDeviceContextManagerLine): - """Enter a CUDA device context and allocate required side streams. + """Enter a CUDA device context and retrieve user stream objects. Attributes: - num_streams: Number of streams to allocate (determined by user annotations on nodes). + num_streams: Number of streams (determined by user annotations on nodes). + stream_idx_to_user_obj_idx: Maps stream_idx → user_object_index for + retrieving user stream objects via get_external_object_by_index. """ num_streams: int = 1 + stream_idx_to_user_obj_idx: dict[int, int] = dataclasses.field(default_factory=dict) def codegen(self, code: IndentedBuffer) -> None: - """Generate context switching and stream allocation code.""" + """Generate context switching and stream retrieval code.""" if V.graph.cpp_wrapper: super().codegen(code) else: @@ -763,9 +791,10 @@ def codegen(self, code: IndentedBuffer) -> None: if self.num_streams > 1: for i in range(1, self.num_streams): + user_obj_idx = self.stream_idx_to_user_obj_idx[i] code.writeline( f"{STREAM_NAME_TEMPLATE.format(stream_idx=i)} " - f"= torch.cuda.Stream(device={self.device_idx})", + f"= get_external_object_by_index({user_obj_idx})", ) @@ -887,6 +916,7 @@ def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: return self # Regular buffer reuse + # Stream is part of the key, so cross-stream reuse is naturally prevented. key = buffer_reuse_key(self.node) if config.allow_buffer_reuse and key in state: free_line = state.pop(key) @@ -1172,6 +1202,20 @@ def codegen_fx(self, converter: FxConverter) -> FxConversionFunc: return converter._generate_unbacked_symbol_defs +@dataclasses.dataclass +class AssertSizeStrideLine(WrapperLine): + name: str + size: str + stride: str + + def codegen(self, code: IndentedBuffer) -> None: + code.writeline(f"assert_size_stride({self.name}, {self.size}, {self.stride})") + + @staticmethod + def codegen_fx(converter: FxConverter) -> FxConversionFunc: + return converter._generate_assert_size_stride + + BufferName = str Line = MemoryPlanningLine | LineContext @@ -1185,6 +1229,8 @@ class PythonWrapperCodegen(CodeGen): def __init__(self): super().__init__() + self._pending_input_asserts: dict[str, tuple[str, str]] = {} + self._pending_alignment_copies: OrderedSet[str] = OrderedSet() self._names_iter: Iterator[int] = count() self.args_to_buffers: dict[ str, None | ir.TensorBox | ir.Buffer | ir.TorchBindObject @@ -1422,6 +1468,7 @@ def write_kernel_autotune_defs_header(self) -> None: self.kernel_autotune_defs.splice( f""" import torch + from math import inf, nan from torch._dynamo.testing import rand_strided from torch._dynamo.utils import preserve_rng_state from torch._inductor.select_algorithm import AlgorithmSelectorCache @@ -1512,13 +1559,19 @@ def get_graph_outputs(self) -> list[IRNode]: def codegen_input_size_asserts(self) -> None: for name, buf in self.get_graph_inputs().items(): - if isinstance(buf, (sympy.Expr, ir.TorchBindObject)): + if isinstance( + buf, + ( + sympy.Basic, + ir.TorchBindObject, + ir.GeneratorState, + ir.OpaqueObjectState, + ), + ): continue # a graph partition may take an IRNode output from a previous partition - if name not in V.graph.graph_input_names or isinstance( - buf, ir.GeneratorState - ): + if name not in V.graph.graph_input_names: continue # comparing strides for 0 size tensor is tricky. Ignore them for now. @@ -1526,14 +1579,13 @@ def codegen_input_size_asserts(self) -> None: continue size = self.codegen_python_shape_tuple(buf.get_size()) stride = self.codegen_python_shape_tuple(buf.get_stride()) - self.prefix.writeline(f"assert_size_stride({name}, {size}, {stride})") + self._pending_input_asserts[name] = (size, stride) def codegen_input_nan_asserts(self) -> None: self.prefix.writeline("# make sure graph inputs are not nan/inf") for name, buf in self.get_graph_inputs().items(): - if isinstance(buf, (sympy.Expr, ir.TorchBindObject)): + if isinstance(buf, (sympy.Basic, ir.TorchBindObject)): continue - line = f"assert not {name}.isnan().any().item()" self.prefix.writeline(line) line = f"assert not {name}.isinf().any().item()" @@ -1619,6 +1671,47 @@ def codegen_input_size_and_nan_asserts(self) -> None: if config.nan_asserts: self.codegen_input_nan_asserts() + # Input size/stride assertions are deferred from the top of call() to just + # before the first kernel that uses each input. This avoids a block of N + # sequential assert calls (~1 us each) on the critical path before the first + # GPU kernel launch. Called from the scheduler codegen loop. + def codegen_deferred_input_asserts(self, input_names: Iterable[str]) -> None: + for name in input_names: + if name in self._pending_input_asserts: + size, stride = self._pending_input_asserts.pop(name) + self.writeline(AssertSizeStrideLine(name, size, stride)) + + def register_alignment_check_inputs(self) -> None: + """Populate pending alignment copies for non-mutated inputs. + Called from the scheduler after mutated_input_idxs is computed.""" + if V.graph.cpp_wrapper: + return + inputs_to_check = V.graph.inputs_to_check + if not inputs_to_check: + return + # Mutated inputs are handled separately by the runtime wrapper, + # which needs to copy back the mutation after the call. + mutated_idxs = OrderedSet(V.graph.mutated_input_idxs) + for idx in inputs_to_check: + if idx not in mutated_idxs: + name = V.graph.graph_input_names[idx] + self._pending_alignment_copies.add(name) + if self._pending_alignment_copies: + V.graph._defers_input_alignment = True + self.imports.writeline( + "from torch._C._dynamo.guards import copy_if_misaligned" + ) + + def codegen_deferred_alignment_copies(self, input_names: Iterable[str]) -> None: + """Emit alignment check + clone just before the first kernel + that reads each input, hiding the cost behind GPU execution.""" + if V.graph.cpp_wrapper: + return + for name in input_names: + if name in self._pending_alignment_copies: + self._pending_alignment_copies.discard(name) + self.writeline(f"{name} = copy_if_misaligned({name})") + # this function (and below) takes the graph name as input so # that stream caching happens per graph instance. this # is important for nested subgraph codegening. @@ -1655,13 +1748,26 @@ def pop_computed_sizes(self): def next_kernel_suffix(self) -> str: return f"{next(self._names_iter)}" - def codegen_device_guard_enter(self, device_idx: int, num_streams: int = 1) -> None: + def codegen_device_guard_enter( + self, + device_idx: int, + num_streams: int = 1, + stream_idx_to_user_obj_idx: dict[int, int] | None = None, + ) -> None: if num_streams > 1: + assert stream_idx_to_user_obj_idx is not None + import_line = ( + "from torch._dynamo.graph_bytecode_inputs import " + "get_external_object_by_index" + ) + if not self.imports.contains(import_line): + self.imports.writeline(import_line) self.writeline( EnterDeviceContextManagerWithStreamInfoLine( device_idx, self.last_seen_device_guard_index, num_streams, + stream_idx_to_user_obj_idx, ), ) else: @@ -2172,9 +2278,9 @@ def strideof(name): if isinstance(stride, sympy.Symbol) and stride not in bound_vars: code.writeline(f"{stride} = {strideof(name)}[{dim}]") bound_vars.add(stride) - elif isinstance(value, ir.TorchBindObject): - return - elif isinstance(value, ir.GeneratorState): + elif isinstance( + value, (ir.TorchBindObject, ir.GeneratorState, ir.OpaqueObjectState) + ): return else: if torch._inductor.config.graph_partition: @@ -2465,11 +2571,7 @@ def add_torchbind_input(name, value): # the subclass. continue if isinstance(value, ir.TorchBindObject): - if len(V.graph.torchbind_constants) == 0: - # otherwise we have already imported the pickle package - output.writeline("import pickle") - output.writeline(f"global {name}") - add_torchbind_input(name, value.get_real_obj()) + output.writeline(f"{name} = None") elif isinstance(value, sympy.Expr): # Don't need to add symbolic # TODO: this fallback and those below actually will generate possibly # invalid benchmark code, because it's not guaranteed 42 @@ -2478,11 +2580,18 @@ def add_torchbind_input(name, value): add_expr_input( name, V.graph.sizevars.optimization_hint(value, fallback=42) ) + elif isinstance(value, sympy.Basic): + # sympy.Boolean (e.g. StrictLessThan from torch.cond predicates) + # is not a sympy.Expr so optimization_hint cannot handle it. + # Use False as a fallback for benchmark harness purposes. + add_expr_input(name, False) elif isinstance(value, ir.GeneratorState): add_expr_input( name, f"torch.cuda.default_generators[{value.device.index}].graphsafe_get_state()", ) + elif isinstance(value, ir.OpaqueObjectState): + output.writeline(f"{name} = None") else: shape = V.graph.sizevars.optimization_hints( value.get_size(), fallback=42 @@ -2826,6 +2935,10 @@ def rename_sizes_for_launcher(expr: int | sympy.Expr) -> sympy.Expr: cache_key.append(arg) cache_key.append(str(triton_meta)) cache_key.extend(str(inductor_meta)) + + if epilogue_fusion is not None: + cache_key.append((epilogue_fusion[0].get_name(), epilogue_fusion[1])) + cache_key = tuple(cache_key) if cache_key in self.user_defined_kernel_cache: name, triton_meta, cached_inductor_meta = self.user_defined_kernel_cache[ @@ -3135,6 +3248,7 @@ def generate_kernel_call( ) device = device or V.graph.get_current_device_or_throw() + current_stream_idx = V.graph.scheduler.current_stream_idx self.writeline( KernelCallLine( self, @@ -3154,6 +3268,7 @@ def generate_kernel_call( graph_name=V.graph.name, # pyrefly: ignore [bad-argument-type] original_fxnode_name=original_fxnode_name, + current_stream_idx=current_stream_idx, ) ) @@ -3171,25 +3286,32 @@ def _generate_kernel_call_helper( inductor_meta=None, graph_name="", original_fxnode_name=None, + current_stream_idx=None, ): device = device or V.graph.get_current_device_or_throw() - if not triton and device.type != "cuda": + if not triton and device.type not in ("cuda", "xpu"): if device.type == "cpu": self.writeline(self.wrap_kernel_call(kernel_name, call_args)) elif device.type == "mps": # TODO: Fix me, MPS does not expose streams now - self.writeline( - self.wrap_kernel_call(f"{kernel_name}.generated_kernel", call_args) - ) + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) else: raise RuntimeError(f"device {device.type} nyi") return call_args_str = self.prepare_triton_kernel_call(call_args) call_args_str = ", ".join(call_args_str) - stream_name = PythonWrapperCodegen.write_get_raw_stream( - self, device.index, graph_name - ) + if current_stream_idx is not None and current_stream_idx != DEFAULT_STREAM_IDX: + # Inside a user stream context: emit a fresh get_raw_stream call so + # it picks up the active stream at runtime, rather than reusing the + # LRU-cached stream0 variable which captured the default stream. + self.write_get_raw_stream_header() + stream_name = "raw_stream" + self.writeline(f"{stream_name} = get_raw_stream({device.index})") + else: + stream_name = PythonWrapperCodegen.write_get_raw_stream( + self, device.index, graph_name + ) if not triton: stream_ptr = f"c_void_p({stream_name})" self.writeline( @@ -3274,7 +3396,7 @@ def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args): reused_args = {} for i, (arg, arg_type, raw_key, raw_arg) in enumerate( - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] zip(call_args, arg_types, raw_keys, raw_args) ): key = None @@ -3315,6 +3437,7 @@ def infer_arg_by_inputs(raw_keys, raw_args, idx, reused_args): self.kernel_autotune_example_args[arg] = (arg_str, kernel_name) else: arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg) + if isinstance(arg, str) and should_unwrap_unspec_arg(arg): arg_str += ".item()" all_args.append(arg_str if key is None else f"{key}={arg_str}") @@ -3383,7 +3506,7 @@ def __repr__(self): return s.codegen_reference() elif has_triton_package() and isinstance(s, triton.language.dtype): # type: ignore[possibly-undefined] return repr(s) - elif isinstance(s, ir.GeneratorState): + elif isinstance(s, (ir.GeneratorState, ir.OpaqueObjectState)): return s.codegen_reference() elif is_opaque_value_type(type(s)): obj_repr, opaque_types = get_opaque_obj_repr(s) diff --git a/torch/_inductor/codegen/wrapper_fxir.py b/torch/_inductor/codegen/wrapper_fxir.py index b03347cac7286..4a1cce8bdb964 100644 --- a/torch/_inductor/codegen/wrapper_fxir.py +++ b/torch/_inductor/codegen/wrapper_fxir.py @@ -4,7 +4,7 @@ import operator import textwrap from collections import Counter -from collections.abc import Callable, Sequence +from collections.abc import Callable, Iterable, Sequence from typing import Any import sympy @@ -232,6 +232,13 @@ def write_header(self) -> None: """ PythonWrapperCodegen.write_header(self) + def register_alignment_check_inputs(self) -> None: + """FXIR does not emit deferred alignment copies. + Alignment is handled by the runtime wrapper.""" + + def codegen_deferred_alignment_copies(self, input_names: Iterable[str]) -> None: + """FXIR does not emit deferred alignment copies.""" + @classmethod def create( cls: type["WrapperFxCodegen"], @@ -714,6 +721,9 @@ def generate_buffer(node: ir.IRNode | None) -> torch.fx.Node | None: ) self._record_allocation(ir_node, fx_node) + def _generate_assert_size_stride(self, line: WrapperLine) -> None: + pass + def _generate_comment(self, line: WrapperLine) -> None: assert isinstance(line, CommentLine) # We ignore comments in FX IR. diff --git a/torch/_inductor/codegen/xpu/compile_utils.py b/torch/_inductor/codegen/xpu/compile_utils.py new file mode 100644 index 0000000000000..9904043c66d59 --- /dev/null +++ b/torch/_inductor/codegen/xpu/compile_utils.py @@ -0,0 +1,136 @@ +# mypy: allow-untyped-defs +import logging +import os +import shutil +import subprocess + +from torch._inductor import config +from torch._inductor.codegen.xpu.xpu_env import get_xpu_arch +from torch._inductor.utils import is_linux + +from ..cuda.compile_utils import _cutlass_include_paths + + +log = logging.getLogger(__name__) + + +def _sycl_compiler() -> str: + # Search order: + # 0) which icpx + # 1) config.xpu.oneapi_root + # 2) ONEAPI_ROOT environment variable + # 3) default system search PATH. + if shutil.which("icpx"): + return "icpx" + + if os.path.exists(config.xpu.oneapi_root or ""): + oneapi_root = config.xpu.oneapi_root + elif os.path.exists(os.getenv("ONEAPI_ROOT") or ""): + oneapi_root = os.getenv("ONEAPI_ROOT") + else: + oneapi_root = None + + if oneapi_root: + oneapi_inclue = os.path.join(oneapi_root, "include") + if "CPLUS_INCLUDE_PATH" in os.environ: + os.environ["CPLUS_INCLUDE_PATH"] += ":" + oneapi_inclue + else: + os.environ["CPLUS_INCLUDE_PATH"] = oneapi_inclue + return os.path.realpath(os.path.join(oneapi_root, "bin/icpx")) + else: + raise RuntimeError("Can not find Intel compiler.") + + +def _sycl_lib_options() -> list[str]: + """ + Util function for CUTLASS backend to find the correct XPU libraries. + """ + # _set_gpu_runtime_env() # cpp_extension consults the env + from torch.utils import cpp_extension + + lpaths = cpp_extension.library_paths(device_type="xpu") + extra_ldflags: list[str] = [] + if is_linux(): + for path in lpaths: + if "torch/lib" in path: + # don't want to depend on pytorch + continue + # -rpath ensures the DLL can find its dependencies when loaded, even + # if the library path is non-standard. + extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"]) + + extra_ldflags.append("-lsycl") + else: + raise NotImplementedError( + "Unsupported env, failed to find xpu libs! Currently only Linux is supported." + ) + return extra_ldflags + + +def _sycl_arch_as_compile_option() -> str: + arc_option_map = {"Xe12": "intel_gpu_pvc", "Xe20": "intel_gpu_bmg_g21"} + arch = get_xpu_arch() + return arc_option_map.get(arch, "intel_gpu_pvc") + + +def _sycl_compiler_options() -> list[str]: + options = [ + "-DCUTLASS_ENABLE_SYCL", + "-DSYCL_INTEL_TARGET", + "-DCUTLASS_VERSIONS_GENERATED", + "-O3", + "-DNDEBUG", + "-std=c++20", + "-fPIC", + "-fsycl", + f"-fsycl-targets={_sycl_arch_as_compile_option()}", + "-Xspirv-translator", + "-spirv-ext=+SPV_INTEL_split_barrier,+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate", + "-fno-sycl-instrument-device-code", + "-DMKL_ILP64", + "-MD", + "-Xs", + ( + "-options \"-igc_opts 'VISAOptions=-perfmodel,VectorAliasBBThreshold=100000000000," + "ExtraOCLOptions=-cl-intel-256-GRF-per-thread'\" " + "-options -ze-opt-large-register-file" + ), + ] + if config.cutlass.enable_debug_info: + options.extend(["-lineinfo", "-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]) + return options + + +def xpu_compile_command( + src_files: list[str], + dst_file: str, + dst_file_ext: str, + extra_args: list[str] | None = None, +) -> str: + if extra_args is None: + extra_args = [] + include_paths = _cutlass_include_paths() + sycl_lib_options = _sycl_lib_options() + sycl_compiler_options = _sycl_compiler_options() + + # Build command as a list to preserve arguments with spaces + cmd_parts = ( + [_sycl_compiler()] + + extra_args + + ["-I" + path for path in include_paths] + + ["-isystem", "/include"] + + sycl_compiler_options + + sycl_lib_options + ) + if dst_file_ext == "o": + cmd_parts.extend(["-c", "-o", dst_file] + src_files) + elif dst_file_ext == "so": + cmd_parts.extend(["-shared", "-o", dst_file] + src_files) + elif dst_file_ext == "exe": + cmd_parts.extend(["-o", dst_file] + src_files) + else: + raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!") + + res = subprocess.list2cmdline(cmd_parts) + log.debug("XPU command: %s", res) + return res diff --git a/torch/_inductor/codegen/xpu/xpu_combined_scheduling.py b/torch/_inductor/codegen/xpu/xpu_combined_scheduling.py new file mode 100644 index 0000000000000..b7c1b5eb3933e --- /dev/null +++ b/torch/_inductor/codegen/xpu/xpu_combined_scheduling.py @@ -0,0 +1,135 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + +from torch._inductor.scheduler import ( + BaseSchedulerNode, + BaseScheduling, + FusedSchedulerNode, + Scheduler, + SchedulerNode, +) + +from ..cutlass.scheduling import CUTLASSScheduling +from ..triton import TritonScheduling + + +if TYPE_CHECKING: + from collections.abc import Sequence + from typing import TypeAlias + + from sympy import Expr + + import torch + from torch.utils._ordered_set import OrderedSet + + from ..common import BackendFeature + + _IntLike: TypeAlias = int | Expr + + +class XPUCombinedScheduling(BaseScheduling): + """ + Scheduler for XPU Kernels, which delegates calls as appropriate + to the SYCL-C++ and Triton Schedulers, which both work for XPU devices + and use a unified-wrapper for codegen. + + If Scheduling code needs to be specialized for the case of mixed Triton / SYCL C++ code, + this would also be the place to do it. + """ + + def __init__(self, scheduler: Scheduler | None) -> None: + super().__init__(scheduler) + self._triton_scheduling = TritonScheduling(scheduler) + self._cutlass_scheduling = CUTLASSScheduling(scheduler) + + def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]: + return self._triton_scheduling.get_backend_features(device) + + def choose_node_backend(self, node: BaseSchedulerNode) -> BaseScheduling: + if self._cutlass_scheduling.is_cutlass_template(node): + return self._cutlass_scheduling + return self._triton_scheduling + + def can_fuse_vertical( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + if self._cutlass_scheduling.can_fuse_vertical(node1, node2): + return True + elif self._cutlass_scheduling.is_cutlass_template( + node1 + ) or self._cutlass_scheduling.is_cutlass_template(node2): + return False + return self._triton_scheduling.can_fuse_vertical(node1, node2) + + def can_fuse_horizontal( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> bool: + for node in (node1, node2): + if self._cutlass_scheduling.is_cutlass_template(node): + return self._cutlass_scheduling.can_fuse_horizontal( + node1, node2 + ) # always False at the moment + return self._triton_scheduling.can_fuse_horizontal(node1, node2) + + def group_fn( + self, sizes: Sequence[Sequence[_IntLike]] + ) -> tuple[tuple[_IntLike, ...], ...]: + return self._triton_scheduling.group_fn(sizes) + + def codegen_template( + self, + template_node: BaseSchedulerNode, + epilogue_nodes: Sequence[BaseSchedulerNode], + prologue_nodes: Sequence[BaseSchedulerNode], + ) -> str | None: + if self._cutlass_scheduling.is_cutlass_template(template_node): + assert not prologue_nodes + return self._cutlass_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + else: + return self._triton_scheduling.codegen_template( + template_node, epilogue_nodes, prologue_nodes + ) + + def codegen_mix_order_reduction(self, node): + return self._triton_scheduling.codegen_mix_order_reduction(node) + + def codegen_node(self, node: FusedSchedulerNode | SchedulerNode) -> None: + return self._triton_scheduling.codegen_node(node) + + def codegen_sync(self) -> None: + return self._triton_scheduling.codegen_sync() + + def flush(self) -> None: + return self._triton_scheduling.flush() + + def codegen_combo_kernel(self, *args: Any, **kwargs: Any) -> None: + return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs) + + def benchmark_fused_nodes( + self, nodes: Sequence[BaseSchedulerNode] + ) -> tuple[float, str]: + return self._triton_scheduling.benchmark_fused_nodes(nodes) + + def benchmark_codegened_module(self, module): + return self._triton_scheduling.benchmark_codegened_module(module) + + def generate_kernel_code_from_nodes( + self, + nodes: Sequence[Any], + benchmark_kernel: bool = False, + hint_override: int | None = None, + ) -> str: + return self._triton_scheduling.generate_kernel_code_from_nodes( + nodes, benchmark_kernel, hint_override=hint_override + ) + + def benchmark_combo_kernel( + self, node_list: Sequence[BaseSchedulerNode], node_benchmark_results + ) -> tuple[float, float, list[str | None]]: + return self._triton_scheduling.benchmark_combo_kernel( + node_list, node_benchmark_results + ) diff --git a/torch/_inductor/codegen/xpu/xpu_env.py b/torch/_inductor/codegen/xpu/xpu_env.py new file mode 100644 index 0000000000000..a74b4cf2a24c7 --- /dev/null +++ b/torch/_inductor/codegen/xpu/xpu_env.py @@ -0,0 +1,38 @@ +import functools +import logging + +import torch +from torch._inductor.utils import clear_on_fresh_cache + + +log = logging.getLogger(__name__) + + +@clear_on_fresh_cache +@functools.lru_cache(1) +def get_xpu_arch() -> str | None: + from torch.testing._internal.common_xpu import get_xpu_codename, XPUCodename + + name2arch = { + XPUCodename.PVC: "Xe12", + XPUCodename.BMG: "Xe20", + } + + codename = get_xpu_codename() + if not codename or codename not in name2arch: + log.warning("Unknown XPU codename, cannot determine architecture") + return None + + return name2arch[codename] + + +@clear_on_fresh_cache +@functools.lru_cache(1) +def get_xpu_version() -> str | None: + # string of version, like 20250101 + try: + xpu_version = torch.version.xpu or "" + return xpu_version + except Exception: + log.exception("Error getting xpu version") + return None diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index cef14b8da37b2..6f277daf47d55 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -9,7 +9,7 @@ import torch import torch.utils._pytree as pytree -from torch.fx.experimental.symbolic_shapes import size_hint +from torch.fx.experimental.symbolic_shapes import optimization_hint from torch.fx.operator_schemas import normalize_function from . import ir @@ -86,7 +86,7 @@ def get_ir_node_size_numel(size: torch.Size, fallback: int = 4096 * 4096) -> int def get_fx_node_size_numel(size: torch.Size, fallback: int = 4096 * 4096) -> int: numel = functools.reduce(operator.mul, size, 1) - result = size_hint(numel, fallback=fallback) + result = optimization_hint(numel, fallback=fallback) return result @@ -399,10 +399,18 @@ def add_inp_bytes(inp: torch.fx.Node): output_val = fx_node.meta.get("val", None) - if input_bytes is None or not isinstance(output_val, torch.Tensor): + if input_bytes is None or output_val is None: return 0 - output_bytes = tensor_bytes(output_val) + # Coalesced collectives return a list of tensors + if isinstance(output_val, (list, tuple)): + output_bytes = sum( + tensor_bytes(t) for t in output_val if isinstance(t, torch.Tensor) + ) + elif isinstance(output_val, torch.Tensor): + output_bytes = tensor_bytes(output_val) + else: + return 0 return input_bytes + output_bytes @@ -467,19 +475,18 @@ def estimate_nccl_collective_runtime_from_fx_node( def _nccl_estimate() -> float | None: # TODO: Refactor with estimate_nccl_collective_runtime_nccl_estimator - from torch.distributed.distributed_c10d import ( - _get_pg_default_device, - _resolve_process_group, - Backend, - ) + from torch.distributed.distributed_c10d import _resolve_process_group, Backend pg = _resolve_process_group(group_name) if torch.distributed.distributed_c10d.get_backend(pg) == Backend.FAKE: # nccl estimator requires real process group return None - device = _get_pg_default_device(pg) - backend = pg._get_backend(device) + device = torch.device("cuda") + try: + backend = pg._get_backend(device) + except RuntimeError: + return None if not backend.supports_time_estimate: return None @@ -492,9 +499,6 @@ def _tensor(size, dtype, device) -> torch.Tensor: # type: ignore[no-untyped-def device=device, ) - def try_size_hint(s: sympy.Expr) -> int: - return V.graph.sizevars.optimization_hint(s, fallback=0) - def to_real_tensor(e: Any) -> Any: if isinstance(e, torch.fx.Node): return to_real_tensor(e.meta["val"]) @@ -507,9 +511,16 @@ def to_real_tensor(e: Any) -> Any: fn = fx_node.target assert isinstance(fn, torch._ops.OpOverload) - with torch.distributed._time_estimator(group=pg) as time_estimator: + with torch.distributed._time_estimator( + group=pg, device=device + ) as time_estimator: w = fn(*real_args, **real_kwargs) - torch.ops._c10d_functional.wait_tensor.default(w) + # Coalesced collectives return a list of tensors + if isinstance(w, (list, tuple)): + for t in w: + torch.ops._c10d_functional.wait_tensor.default(t) + else: + torch.ops._c10d_functional.wait_tensor.default(w) est_time_us = time_estimator.estimated_time # -1000 constant is NCCL return in case of error during estimations. # Observed it for all_to_all estimations. diff --git a/torch/_inductor/comm_lowering.py b/torch/_inductor/comm_lowering.py index 60c710ccb6884..be547aa09b561 100644 --- a/torch/_inductor/comm_lowering.py +++ b/torch/_inductor/comm_lowering.py @@ -183,6 +183,12 @@ def _one_shot_all_reduce(inp: ir.TensorBox, reduce_op, group_name): ) +def _create_out_of_place(kernel, inputs, *args) -> ir.IRNode: + node = ir._CollectiveKernel.create_out_of_place(kernel, inputs, *args) + assert isinstance(node, ir.IRNode) + return ir.TensorBox.create(node) + + def register_comm_lowerings(): """ Register lowerings for the comm subsystem. @@ -287,11 +293,6 @@ def _all_reduce_coalesced_(inputs, reduce_op, group_name): ) return inputs - def _create_out_of_place(kernel, inputs, *args) -> ir.IRNode: - node = ir._CollectiveKernel.create_out_of_place(kernel, inputs, *args) - assert isinstance(node, ir.IRNode) - return ir.TensorBox.create(node) - @register_comm_lowering(c10d.all_gather_into_tensor) def _all_gather_into_tensor(inp, group_size, group_name): return _create_out_of_place( @@ -483,21 +484,65 @@ def register_symm_mem_lowerings(): log.info("symm_mem ops not available, skipping symm_mem lowerings") return + from torch._library._out_variant import register_out_variant + + # Register manual out variant mappings for symm_mem ops. + register_out_variant( + symm_mem.one_shot_all_reduce.default, + symm_mem.one_shot_all_reduce_out.default, + ) + register_out_variant( + symm_mem.one_shot_all_reduce_copy.default, + symm_mem.one_shot_all_reduce_copy_out.default, + ) + from .lowering import register_lowering + def _copy_input_to_comm_buffer( + inp: ir.TensorBox, + comm_buffer_type: ir.CommBufferType, + group_name: "torch.distributed.distributed_c10d.GroupName", + ) -> ir.TensorBox: + """ + Fallback: insert a Pointwise identity copy allocated in P2P via + CommBufferLayout. Used when we don't control the input's allocation. + """ + inp.realize() + copy = ir.Pointwise.create( + device=inp.get_device(), + dtype=inp.get_dtype(), + inner_fn=inp.make_loader(), + ranges=inp.get_size(), + ) + realize_as_comm_buffer(copy, comm_buffer_type, group_name) + return copy + def _maybe_realize_symm_mem( inp: ir.TensorBox, group_name: str, # type: ignore[arg-type] - ) -> None: + ) -> ir.TensorBox: """ - Helper to realize an input as symmetric memory buffer if possible. + Ensure inp is in P2P memory for a symm_mem collective. + + If inductor controls the buffer's allocation (ComputedBuffer, + or any buffer with FlexibleLayout/FixedLayout), switch its + layout to CommBufferLayout in-place, zero-copy. + + If inductor does not control allocation (e.g. InputBuffer), + insert a Pointwise identity copy into a new CommBufferLayout buffer. + This adds an extra Triton kernel. Returns the possibly new TensorBox. + + TODO(tianrengao): eliminate the extra kernel for static-shape + InputBuffers by pre-allocating P2P memory in the wrapper and DMA .copy_() """ if can_realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM): realize_as_comm_buffer(inp, ir.CommBufferType.SYMM_MEM, group_name) # type: ignore[arg-type] + return inp else: - log.warning( - "Failed to realize the input as a symmetric memory buffer for symm_mem operation; " - "ensure the input is allocated as a symmetric memory buffer." + return _copy_input_to_comm_buffer( + inp, + ir.CommBufferType.SYMM_MEM, + group_name, # type: ignore[arg-type] ) @register_lowering(symm_mem.one_shot_all_reduce) @@ -506,7 +551,7 @@ def _symm_mem_one_shot_all_reduce( reduce_op: str, group_name: str, ): - _maybe_realize_symm_mem(inp, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) return pytree.tree_map( ir.TensorBox.create, ir.FallbackKernel.create( @@ -524,7 +569,7 @@ def _symm_mem_one_shot_all_reduce_out( group_name: str, out: ir.TensorBox, ): - _maybe_realize_symm_mem(inp, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) return pytree.tree_map( ir.TensorBox.create, ir.FallbackKernel.create( @@ -543,7 +588,7 @@ def _symm_mem_one_shot_all_reduce_copy( reduce_op: str, group_name: str, ): - _maybe_realize_symm_mem(symm_buffer, group_name) + symm_buffer = _maybe_realize_symm_mem(symm_buffer, group_name) return pytree.tree_map( ir.TensorBox.create, ir.FallbackKernel.create( @@ -563,7 +608,7 @@ def _symm_mem_one_shot_all_reduce_copy_out( group_name: str, out: ir.TensorBox, ): - _maybe_realize_symm_mem(symm_buffer, group_name) + symm_buffer = _maybe_realize_symm_mem(symm_buffer, group_name) return pytree.tree_map( ir.TensorBox.create, ir.FallbackKernel.create( @@ -582,7 +627,7 @@ def _symm_mem_two_shot_all_reduce_( reduce_op: str, group_name: str, ): - _maybe_realize_symm_mem(inp, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) ir.FallbackKernel.create( symm_mem.two_shot_all_reduce_.default, inp, @@ -598,7 +643,7 @@ def _symm_mem_two_shot_all_reduce_out( group_name: str, output: ir.TensorBox, ): - _maybe_realize_symm_mem(inp, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) return pytree.tree_map( ir.TensorBox.create, ir.FallbackKernel.create( @@ -616,7 +661,7 @@ def _symm_mem_multimem_all_reduce_( reduce_op: str, group_name: str, ): - _maybe_realize_symm_mem(inp, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) ir.FallbackKernel.create( symm_mem.multimem_all_reduce_.default, inp, @@ -631,7 +676,7 @@ def _symm_mem_multimem_one_shot_all_reduce( reduce_op: str, group_name: str, ): - _maybe_realize_symm_mem(inp, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) return pytree.tree_map( ir.TensorBox.create, ir.FallbackKernel.create( @@ -649,7 +694,7 @@ def _symm_mem_multimem_one_shot_all_reduce_out( group_name: str, out: ir.TensorBox, ): - _maybe_realize_symm_mem(inp, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) return pytree.tree_map( ir.TensorBox.create, ir.FallbackKernel.create( @@ -669,7 +714,7 @@ def _symm_mem_multimem_one_shot_reduce_out( group_name: str, out: ir.TensorBox, ): - _maybe_realize_symm_mem(inp, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) return pytree.tree_map( ir.TensorBox.create, ir.FallbackKernel.create( @@ -688,7 +733,7 @@ def _symm_mem_multimem_all_gather_out( group_name: str, out: ir.TensorBox, ): - _maybe_realize_symm_mem(inp, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) return pytree.tree_map( ir.TensorBox.create, ir.FallbackKernel.create( @@ -706,7 +751,7 @@ def _symm_mem_reduce_scatter_out( split_last_dim: bool, output: ir.TensorBox, ): - _maybe_realize_symm_mem(inp, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) return pytree.tree_map( ir.TensorBox.create, ir.FallbackKernel.create( @@ -726,8 +771,8 @@ def _symm_mem_all_to_all_vdev( out_splits_offsets: ir.TensorBox, group_name: str, ): - _maybe_realize_symm_mem(inp, group_name) - _maybe_realize_symm_mem(out, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) + out = _maybe_realize_symm_mem(out, group_name) ir.FallbackKernel.create( symm_mem.all_to_all_vdev.default, inp, @@ -747,8 +792,8 @@ def _symm_mem_all_to_all_vdev_2d( group_name: str, major_align=None, ): - _maybe_realize_symm_mem(inp, group_name) - _maybe_realize_symm_mem(out, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) + out = _maybe_realize_symm_mem(out, group_name) ir.FallbackKernel.create( symm_mem.all_to_all_vdev_2d.default, inp, @@ -768,8 +813,8 @@ def _symm_mem_all_to_all_vdev_2d_offset( out_splits_offsets: ir.TensorBox, group_name: str, ): - _maybe_realize_symm_mem(inp, group_name) - _maybe_realize_symm_mem(out, group_name) + inp = _maybe_realize_symm_mem(inp, group_name) + out = _maybe_realize_symm_mem(out, group_name) ir.FallbackKernel.create( symm_mem.all_to_all_vdev_2d_offset.default, inp, @@ -788,8 +833,8 @@ def _symm_mem_tile_reduce( group_name: str, reduce_op: str = "sum", ): - _maybe_realize_symm_mem(in_tile, group_name) - _maybe_realize_symm_mem(out_tile, group_name) + in_tile = _maybe_realize_symm_mem(in_tile, group_name) + out_tile = _maybe_realize_symm_mem(out_tile, group_name) ir.FallbackKernel.create( symm_mem.tile_reduce.default, in_tile, @@ -808,9 +853,9 @@ def _symm_mem_multi_root_tile_reduce( group_name: str, reduce_op: str = "sum", ): - for in_tile in in_tiles: - _maybe_realize_symm_mem(in_tile, group_name) - _maybe_realize_symm_mem(out_tile, group_name) + for i, in_tile in enumerate(in_tiles): + in_tiles[i] = _maybe_realize_symm_mem(in_tile, group_name) + out_tile = _maybe_realize_symm_mem(out_tile, group_name) ir.FallbackKernel.create( symm_mem.multi_root_tile_reduce.default, in_tiles, @@ -820,3 +865,35 @@ def _symm_mem_multi_root_tile_reduce( reduce_op, ) return None + + @register_lowering(symm_mem._low_contention_all_gather) + def _symm_mem_low_contention_all_gather( + inp: ir.TensorBox, + group_name: str, + ): + # Use _CollectiveKernel so that _WaitKernel.get_volatile_reads() + # can track the input's lifetime through wait_tensor, preventing + # the memory planner from reusing the input buffer while the + # backend stream is still reading it. + return _create_out_of_place( + symm_mem._low_contention_all_gather.default, + inp, + group_name, + ) + + @register_lowering(symm_mem._low_contention_reduce_scatter) + def _symm_mem_low_contention_reduce_scatter( + inp: ir.TensorBox, + reduce_op: str, + group_name: str, + ): + # Use _CollectiveKernel so that _WaitKernel.get_volatile_reads() + # can track the input's lifetime through wait_tensor, preventing + # the memory planner from reusing the input buffer while the + # backend stream is still reading it. + return _create_out_of_place( + symm_mem._low_contention_reduce_scatter.default, + inp, + reduce_op, + group_name, + ) diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index 34a1c2a693d8a..6453c4b506125 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -1751,7 +1751,7 @@ def _find_buffers_with_changed_last_use_sink_waits( for buf in candidate_bufs: snode_last_use = buf_to_snode_last_use[buf] - if snode_last_use != candidate: # noqa: E711 + if snode_last_use != candidate: continue # candidate is last use of buf @@ -2186,7 +2186,7 @@ def visualize_overlap(order): cur_comm_node = None def step_log(step, msg): - overlap_log.debug(f"{step:>6}: {msg}") # noqa: G004 + overlap_log.debug(f"{step:>6}: {msg}") for step, snode in enumerate(order): if cur_comm_node is None: @@ -2205,15 +2205,13 @@ def step_log(step, msg): if contains_collective(snode): total_est_runtime += estimate_op_runtime(snode) cur_comm_node = snode.node - step_log(step, f"{node_summary(snode)}") # noqa: G004 + step_log(step, f"{node_summary(snode)}") elif is_wait(snode.node): # end of this comm op step_log(step, f"{node_summary(snode)}") cur_comm_node = None else: # overlapped compute op step_log(step, f"| {node_summary(snode)}") - overlap_log.debug( - f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004 - ) + overlap_log.debug(f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}") def reorder_compute_and_comm_for_overlap( diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 6fbd56eec8912..ace206bc0e7ac 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -35,6 +35,7 @@ logging as dynamo_logging, utils as dynamo_utils, ) +from torch._dynamo.backends import common as dynamo_common from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.repro.after_aot import wrap_compiler_debug from torch._dynamo.utils import ( @@ -46,10 +47,11 @@ flatten_graph_inputs, get_inputs_devices, get_metrics_context, + GmWrapper, lazy_format_graph_code, set_feature_use, ) -from torch._functorch import config as functorch_config +from torch._functorch import aot_autograd, config as functorch_config from torch._functorch._aot_autograd.subclass_parametrization import ( unwrap_tensor_subclass_parameters, ) @@ -86,6 +88,7 @@ count_tangents, fresh_cache, get_all_devices, + get_static_bw_input_idxs, InputType, is_gpu, should_assume_input_aligned, @@ -102,7 +105,6 @@ from torch.monitor import _WaitCounter from torch.utils._ordered_set import OrderedSet -from .._dynamo.backends.common import aot_autograd from .._dynamo.exc import ShortenTraceback, SkipFrame from ..fx._lazy_graph_module import _use_lazy_graph_module from ..fx.graph import _PyTreeCodeGen @@ -158,8 +160,6 @@ def log_optimus_to_scuba(*args: object, **kwargs: object) -> None: from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log if TYPE_CHECKING: - import types - from torch._functorch._aot_autograd.schemas import ( FQN, GraphInputName, @@ -771,6 +771,7 @@ class _CompileFxKwargs(TypedDict, total=False): extern_node_serializer: Callable[[list[ExternKernelNode]], Any] | None boxed_forward_device_index: BoxedDeviceIndex | None fx_wrapper: bool + get_decomp_fn: Callable[..., dict[Any, Callable[..., Any]]] class _CompileFxCallable(Protocol): @@ -778,6 +779,7 @@ def __call__( self, gm: GraphModule, example_inputs: Sequence[InputType], + compile_region_name: str | None = None, **kwargs: Unpack[_CompileFxKwargs], ) -> OutputCode: ... @@ -785,6 +787,7 @@ def __call__( def compile_fx_inner( gm: GraphModule, example_inputs: Sequence[InputType], + compile_region_name: str | None = None, **kwargs: Unpack[_CompileFxKwargs], ) -> OutputCode: kwargs.setdefault("cudagraphs", None) @@ -803,6 +806,15 @@ def compile_fx_inner( # compile_fx return and we may want to use the _LazyGraphModule for compiling # the backward graph as well. with contextlib.ExitStack() as stack: + # When cpp_wrapper is enabled, ensure the required triton config + # (store_cubin, autotune_at_compile_time, etc.) is applied. This is + # needed because lazy backward compilation may run after the + # config.patch context from compile_fx has already exited. + # Suppress cudagraph skip logging here; compile_fx already logged it. + if kwargs["cpp_wrapper"]: + stack.enter_context( + config.patch(get_cpp_wrapper_config(log_cudagraph_skip=False)) + ) stack.enter_context(torch.utils._python_dispatch._disable_current_modes()) stack.enter_context(_use_lazy_graph_module(dynamo_config.use_lazy_graph_module)) stack.enter_context( @@ -824,6 +836,7 @@ def compile_fx_inner( return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( gm, example_inputs, + compile_region_name=compile_region_name, **kwargs, ) @@ -832,6 +845,7 @@ def compile_fx_inner( def _compile_fx_inner( gm: GraphModule, example_inputs: Sequence[InputType], + compile_region_name: str | None = None, **graph_kwargs: Unpack[_CompileFxKwargs], ) -> OutputCode: """ @@ -890,6 +904,7 @@ def _compile_fx_inner( save_args_for_compile_fx_inner( gm, example_inputs, + compile_region_name=compile_region_name, **graph_kwargs, ) @@ -986,7 +1001,11 @@ def _compile_fx_inner( TritonBundler.begin_compile() try: mb_compiled_graph = fx_codegen_and_compile( - gm, example_inputs, inputs_to_check, **graph_kwargs + gm, + example_inputs, + inputs_to_check, + compile_region_name=compile_region_name, + **graph_kwargs, ) assert mb_compiled_graph is not None ( @@ -1018,7 +1037,11 @@ def _compile_fx_inner( ) try: mb_compiled_graph = fx_codegen_and_compile( - gm, example_inputs, inputs_to_check, **graph_kwargs + gm, + example_inputs, + inputs_to_check, + compile_region_name=compile_region_name, + **graph_kwargs, ) except Exception as e: raise InductorError(e, currentframe()).with_traceback( @@ -1033,7 +1056,11 @@ def _compile_fx_inner( TritonBundler.begin_compile() try: mb_compiled_graph = fx_codegen_and_compile( - gm, example_inputs, inputs_to_check, **graph_kwargs + gm, + example_inputs, + inputs_to_check, + compile_region_name=compile_region_name, + **graph_kwargs, ) assert mb_compiled_graph is not None mb_compiled_graph._time_taken_ns = time.time_ns() - start_time @@ -1078,6 +1105,8 @@ def _compile_fx_inner( assert mb_compiled_graph is not None compiled_graph = mb_compiled_graph + if isinstance(compiled_graph, CompiledFxGraph): + compiled_graph.compile_region_name = compile_region_name # Logging and observability: we log a single chromium event # and a tlparse log for every cache action. @@ -1125,6 +1154,10 @@ def _compile_fx_inner( ) compiled_graph.post_compile(example_inputs, constants, graph_kwargs) + policy = config.cudagraph_policy + if policy is not None: + compiled_graph = policy.wrap_output(compiled_graph) + log.debug("FX codegen and compilation took %.3fs", time.time() - start) # This message is for printing overview information of inductor mm counts, shapes,etc after lowering @@ -1192,6 +1225,7 @@ class FxCompile(ABC): # Some stats for logging/debugging _compile_stats: dict[type[FxCompile], _FxCompileStat] = defaultdict(_FxCompileStat) + compile_region_name: str | None = None # TODO: We should probably eventually add some kind of async version of this # so we can kick off a compile and then go do other things - but we'll need @@ -1239,7 +1273,9 @@ def codegen_and_compile( extern_node_serializer: Callable[[list[ExternKernelNode]], Any] | None = ( graph_kwargs.get("extern_node_serializer", None) ) - + get_decomp_fn: Callable[..., dict[Any, Callable[..., Any]]] = graph_kwargs.get( + "get_decomp_fn", select_decomp_table + ) with ( _WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard(), dynamo_utils.preserve_rng_state(), @@ -1435,6 +1471,7 @@ def codegen_and_compile( is_backward=is_backward, is_const_graph=True, fx_wrapper=fx_wrapper, + get_decomp_fn=get_decomp_fn, ) with ( V.set_graph_handler(const_graph), @@ -1469,6 +1506,7 @@ def codegen_and_compile( const_module=const_graph, inputs_to_check=inputs_to_check, fx_wrapper=fx_wrapper, + get_decomp_fn=get_decomp_fn, ) metrics_helper = metrics.CachedMetricsHelper() @@ -1733,6 +1771,7 @@ def codegen_and_compile( cudagraphs, example_inputs, static_input_idxs, + self.compile_region_name, graph_kwargs, inputs_to_check, runnable_graph_str, @@ -1749,6 +1788,7 @@ def fx_codegen_and_compile( # This is derivable from the other inputs to this function, but we pass it # in explicitly because it's nontrivial to compute inputs_to_check: Sequence[int], + compile_region_name: str | None = None, **graph_kwargs: Unpack[_CompileFxKwargs], ) -> OutputCode: scheme: FxCompile @@ -1774,6 +1814,9 @@ def fx_codegen_and_compile( ) # pyrefly: ignore [unbound-name] scheme = _AsyncFxCompile(scheme) + scheme._compile.compile_region_name = ( + compile_region_name # pyrefly: ignore[attr-defined] + ) if fx_compile_progressive: from .compile_fx_async import _ProgressiveFxCompile @@ -1788,9 +1831,15 @@ def fx_codegen_and_compile( # Use in-process compile for the fast version fast_scheme = _InProcessFxCompile() + fast_scheme.compile_region_name = compile_region_name # pyrefly: ignore [unbound-name] scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs) + scheme._optimized_compile.compile_region_name = ( + compile_region_name # pyrefly: ignore[attr-defined] + ) + + scheme.compile_region_name = compile_region_name # pyrefly: ignore[unbound-name] # pyrefly: ignore [unbound-name] return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs) @@ -2067,6 +2116,7 @@ def fw_compiler_freezing( dynamo_model: GraphModule, num_example_inputs: int, inner_compile: Callable[..., Any], + # TODO: Take compiler_config_extra instead cudagraphs: BoxedBool, graph_id: int, forward_device: BoxedDeviceIndex, @@ -2076,7 +2126,8 @@ def fw_compiler_freezing( # partition_fn won't be called inputs_devices = get_inputs_devices(aot_example_inputs, aot_autograd_model) aot_autograd_model = _recursive_joint_graph_passes( - aot_autograd_model, input_device=next(iter(inputs_devices)) + aot_autograd_model, + input_device=next(iter(inputs_devices)), ) layout_opt = GraphLowering.decide_layout_opt(aot_autograd_model, is_inference=True) @@ -2166,21 +2217,28 @@ def wrapper(args: list[object]) -> Sequence[torch.Tensor]: return wrapper -def get_cpp_wrapper_config() -> dict[str, object]: - if config.triton.cudagraphs: +def get_cpp_wrapper_config(log_cudagraph_skip: bool = True) -> dict[str, object]: + if log_cudagraph_skip and config.triton.cudagraphs and config.graph_partition: log_cudagraph_skip_and_bump_counter( - format_default_skip_message("cpp wrapper enabled") + format_default_skip_message( + "cpp-wrapper does not support graph partition yet" + ) ) + autotune_at_compile_time = ( + config.triton.autotune_at_compile_time + if config.triton.autotune_at_compile_time is not None + # Default to True for AOTI. Subject to change in future. + else has_triton() and V.aot_compilation + ) return { - # Set autotune_at_compile_time to True as default if the option is not explicitly set - "triton.autotune_at_compile_time": ( - config.triton.autotune_at_compile_time - if config.triton.autotune_at_compile_time is not None - else has_triton() + "triton.autotune_at_compile_time": autotune_at_compile_time, + "triton.autotune_cublasLt": not autotune_at_compile_time, + "triton.cudagraphs": ( + config.triton.cudagraphs + and not V.aot_compilation + and not config.graph_partition ), - "triton.autotune_cublasLt": False, - "triton.cudagraphs": False, # TODO: to be removed "triton.store_cubin": True, } @@ -2215,7 +2273,9 @@ def partition_fn( # in partitioning. inputs_devices = get_inputs_devices(joint_inputs, gm) gm = _recursive_joint_graph_passes( - gm, skip_invoke_subgraph=True, input_device=next(iter(inputs_devices)) + gm, + skip_invoke_subgraph=True, + input_device=next(iter(inputs_devices)), ) static_lifetime_input_indices: list[int] | None = kwargs.pop( # type: ignore[assignment] @@ -2271,12 +2331,15 @@ class CompilerConfigExtra: cudagraphs: BoxedBool graph_id: int forward_device: BoxedDeviceIndex + forward_is_partitioned: BoxedBool cudagraphs_bwd_override: bool | None = None def create_compiler_config_extra( - config: types.ModuleType, gm_meta: dict[str, Any] | None = None + gm: GraphModule | GmWrapper, ) -> CompilerConfigExtra: + gm_meta = gm.meta if isinstance(gm, GraphModule) else None + # Although cudagraphs may have been enabled via config, various # conditions (which are tested within the bowels of Inductor) may # force cudagraphs to be disabled. This mutable box lets us retrieve @@ -2320,11 +2383,17 @@ def create_compiler_config_extra( # See [Backward Generation Handling] forward_device = BoxedDeviceIndex(None) + # Set by the forward compilation when it is partitioned for CUDA graphs. + # The backward reads this to decide whether saved tensors can be assumed + # to have fixed addresses. + forward_is_partitioned = BoxedBool(False) + return CompilerConfigExtra( cudagraphs=cudagraphs, graph_id=graph_id, forward_device=forward_device, cudagraphs_bwd_override=cudagraphs_bwd_override, + forward_is_partitioned=forward_is_partitioned, ) @@ -2363,6 +2432,18 @@ def compile_fx_forward( ), ) + # Snapshot stack traces on the output node before passes run, + # as later passes may strip stack_trace from individual nodes. + output = output_node(gm) + output.meta["output_stack_traces"] = [ + ( + arg.meta.get("stack_trace") + if isinstance(arg, torch.fx.node.Node) + else None + ) + for arg in output.args[0] # type: ignore[union-attr] + ] + inputs_devices = get_inputs_devices(example_inputs, gm) gm = _recursive_joint_graph_passes(gm, input_device=next(iter(inputs_devices))) @@ -2429,7 +2510,7 @@ def compile_fx_forward( _recursive_record_user_visible_output_idxs(gm) with cudagraph_annotation_context(compiler_config_extra.cudagraphs): - return inner_compile( + result = inner_compile( gm, example_inputs, static_input_idxs=get_static_input_idxs(fixed), @@ -2439,6 +2520,16 @@ def compile_fx_forward( boxed_forward_device_index=compiler_config_extra.forward_device, ) + if ( + not is_inference + and isinstance(result, CompiledFxGraph) + and result.partition_maps + and len(result.partition_maps) > 1 + ): + compiler_config_extra.forward_is_partitioned.value = True + + return result + def compile_fx_backward( gm: GraphModule, @@ -2476,6 +2567,13 @@ def compile_fx_backward( if compiler_config_extra.cudagraphs_bwd_override is not None: cudagraphs = BoxedBool(compiler_config_extra.cudagraphs_bwd_override) + # When the forward was partitioned, saved activations from inline + # code between partitions are NOT at fixed addresses. Only mark + # primals (params/buffers) as static. + if compiler_config_extra.forward_is_partitioned.value: + static_input_idxs: Sequence[int] = get_static_bw_input_idxs(gm) + else: + static_input_idxs = list(range(fixed)) with ( ( config.patch(get_cpp_wrapper_config()) @@ -2487,7 +2585,7 @@ def compile_fx_backward( return inner_compile( gm, example_inputs, - static_input_idxs=list(range(fixed)), + static_input_idxs=static_input_idxs, cudagraphs=cudagraphs, is_backward=True, graph_id=compiler_config_extra.graph_id, @@ -2552,6 +2650,7 @@ def compile_fx( config_patches: dict[str, Any] | None = None, decompositions: dict[OpOverload, Callable[..., Any]] | None = None, ignore_shape_env: bool = False, + compile_region_name: str | None = None, ) -> CompileFxOutput: """ Main entry point for compiling given FX graph. Despite the fact that this @@ -2564,6 +2663,13 @@ def compile_fx( NB: This function TAKES OWNERSHIP of the input ``model_`` and can potentially mutate it! Make a copy if you need to preserve the original GraphModule. """ + if decompositions is not None: + + def get_decomp_fn() -> dict[Any, Callable[..., Any]]: + return decompositions # pyrefly: ignore[bad-return] + else: + get_decomp_fn = select_decomp_table + # Some arguments trigger a recursive call to compile_fx. Handle these # short circuits first, before anything else @@ -2582,8 +2688,15 @@ def compile_fx( inner_compile=config.patch(config_patches)(inner_compile), decompositions=decompositions, ignore_shape_env=ignore_shape_env, + compile_region_name=compile_region_name, ) + # Keep region names out of graph_kwargs so they don't perturb FX cache keys. + inner_compile = functools.partial( + inner_compile, + compile_region_name=compile_region_name, + ) + # Wake up the AsyncCompile subproc pool as early as possible (if there's cuda). if any( isinstance(e, torch.Tensor) and e.device.type in ("cuda", "xpu") @@ -2622,16 +2735,18 @@ def compile_fx( cpp_wrapper=cpp_wrapper_config, fx_wrapper=fx_wrapper_config, ), - decompositions=decompositions, ignore_shape_env=ignore_shape_env, + get_decomp_fn=get_decomp_fn, + compile_region_name=compile_region_name, ) return _maybe_wrap_and_compile_fx_main( model_, example_inputs_, inner_compile, - decompositions, ignore_shape_env, + get_decomp_fn=get_decomp_fn, + compile_region_name=compile_region_name, ) @@ -2670,8 +2785,10 @@ def _maybe_wrap_and_compile_fx_main( model_: GraphModule, example_inputs_: Sequence[InputType], inner_compile: Callable[..., OutputCode], - decompositions: dict[OpOverload, Callable[..., Any]] | None, ignore_shape_env: bool, + *, + get_decomp_fn: Callable[..., dict[Any, Callable[..., Any]]] = select_decomp_table, + compile_region_name: str | None = None, ) -> CompileFxOutput: """ Part of compile_fx, called after patching configs. @@ -2685,8 +2802,9 @@ def _maybe_wrap_and_compile_fx_main( compile_gm = functools.partial( _maybe_wrap_and_compile_fx_main, inner_compile=inner_compile, - decompositions=decompositions, ignore_shape_env=ignore_shape_env, + get_decomp_fn=get_decomp_fn, + compile_region_name=compile_region_name, ) if not graph_returns_tuple(model_): return make_graph_return_tuple(model_, example_inputs_, compile_gm) @@ -2707,8 +2825,9 @@ def _maybe_wrap_and_compile_fx_main( model_, example_inputs_, inner_compile, - decompositions, ignore_shape_env, + get_decomp_fn=get_decomp_fn, + compile_region_name=compile_region_name, ) @@ -2716,8 +2835,10 @@ def _compile_fx_main( model_: GraphModule, example_inputs_: Sequence[InputType], inner_compile: Callable[..., OutputCode], - decompositions: dict[OpOverload, Callable[..., Any]] | None, ignore_shape_env: bool, + *, + get_decomp_fn: Callable[..., dict[Any, Callable[..., Any]]] = select_decomp_table, + compile_region_name: str | None = None, ) -> CompileFxOutput: """ Main part of compile_fx, called after wrapping is done. @@ -2746,12 +2867,10 @@ def _compile_fx_main( num_example_inputs = len(example_inputs_) - gm_meta = model_.meta if isinstance(model_, GraphModule) else None - compiler_config_extra = create_compiler_config_extra(config, gm_meta) + compiler_config_extra = create_compiler_config_extra(model_) - decompositions = ( - decompositions if decompositions is not None else select_decomp_table() - ) + decompositions = get_decomp_fn() + inner_compile = functools.partial(inner_compile, get_decomp_fn=get_decomp_fn) def fw_compiler_base( gm: GraphModule, @@ -2900,17 +3019,17 @@ def bw_compiler( ), ): try: - return aot_autograd( + return dynamo_common.aot_autograd( fw_compiler=fw_compiler, bw_compiler=bw_compiler, inference_compiler=inference_compiler, decompositions=decompositions, partition_fn=partition_fn, keep_inference_input_mutations=True, - cudagraphs=compiler_config_extra.cudagraphs, - boxed_forward_device_index=compiler_config_extra.forward_device, + compiler_config_extra=compiler_config_extra, ignore_shape_env=ignore_shape_env, pre_grad_passes=run_pre_grad_passes, + compile_region_name=compile_region_name, )(model_, example_inputs_) except ShortenTraceback as e: # We will also shorten the traceback inside dynamo. @@ -3079,7 +3198,7 @@ def _aoti_flatten_inputs( ] if in_spec is not None and received_spec != in_spec: - raise ValueError( # noqa: B904 + raise ValueError( "Trying to flatten user inputs with exported input tree spec: \n" f"{in_spec}\n" "but actually got inputs with tree spec of: \n" @@ -3099,3 +3218,73 @@ def _aoti_flatten_inputs( } ) return flat_example_inputs, options + + +def autograd_cache_key( + graph, + example_inputs, + ignore_shape_env: bool, + decompositions=None, +): + if config.cpp_wrapper or config.fx_wrapper: + raise RuntimeError( + "autograd_cache_key is not supported with cpp_wrapper or fx_wrapper" + ) + + decompositions = ( + decompositions if decompositions is not None else select_decomp_table() + ) + # compile_fx applies these graph transforms before reaching _compile_fx_main. + # Neither occurs on the torch.compile/Dynamo path (which always produces + # tuple-returning, pre-flattened graphs). Not supported by this API. + if isinstance(graph, GraphModule) and not graph_returns_tuple(graph): + raise NotImplementedError( + "autograd_cache_key does not support graphs that don't return a tuple" + ) + if any(isinstance(x, (list, tuple, dict)) for x in example_inputs): + raise NotImplementedError( + "autograd_cache_key does not support nested container inputs" + ) + + compiler_config_extra = create_compiler_config_extra(graph) + + # These context managers replicate the ones that _compile_fx_main sets up + # before calling aot_autograd, so that the config snapshot captured by + # autograd_cache_key is identical to a real compile_fx run: + # _compile_fx_main outer with-block: _use_lazy_graph_module, + # enable_python_dispatcher, preserve_node_meta, + # reset_provenance_globals + # _compile_fx_main aot_autograd with-block: V.set_fake_mode, + # torch._guards.tracing, compiled_autograd._disable, + # functorch_config.patch + + fake_mode = detect_fake_mode(example_inputs) or torch._subclasses.FakeTensorMode( + allow_non_fake_inputs=True + ) + tracing_context = ( + torch._guards.TracingContext.try_get() + or torch._guards.TracingContext(fake_mode) + ) + + with ( + functorch_config.patch( + unlift_effect_tokens=True, selective_decompose=config.selective_decompose + ), + _use_lazy_graph_module(dynamo_config.use_lazy_graph_module), + enable_python_dispatcher(), + torch.fx.traceback.preserve_node_meta( + config.trace.provenance_tracking_level == 1 + ), + torch._inductor.debug.reset_provenance_globals(), + V.set_fake_mode(fake_mode), + torch._guards.tracing(tracing_context), + compiled_autograd._disable(), + ): + return aot_autograd.autograd_cache_key( + graph, + example_inputs, + ignore_shape_env=ignore_shape_env, + decompositions=decompositions, + compiler_config_extra=compiler_config_extra, + keep_inference_input_mutations=True, + ) diff --git a/torch/_inductor/compile_fx_async.py b/torch/_inductor/compile_fx_async.py index b6968ebd8daed..732a2c0a52659 100644 --- a/torch/_inductor/compile_fx_async.py +++ b/torch/_inductor/compile_fx_async.py @@ -6,7 +6,11 @@ from typing_extensions import final, override import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools -from torch._inductor.output_code import CompiledFxGraphConstants, OutputCode +from torch._inductor.output_code import ( + CompiledFxGraph, + CompiledFxGraphConstants, + OutputCode, +) from .compile_fx import _CompileFxKwargs, _InProcessFxCompile, FxCompile from .output_code import complex_memory_overlap # noqa: F401 @@ -199,7 +203,9 @@ def codegen_and_compile( inputs_to_check: Sequence[int], graph_kwargs: _CompileFxKwargs, ) -> OutputCode: - eager_output_code = _InProcessFxCompile().codegen_and_compile( + eager_compile = _InProcessFxCompile() + eager_compile.compile_region_name = self.compile_region_name + eager_output_code = eager_compile.codegen_and_compile( gm, example_inputs, inputs_to_check, graph_kwargs ) @@ -222,6 +228,8 @@ def codegen_and_compile( def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode: _AsyncFxCompile._stat_bg_finished += 1 output = pickled_output.deserialize(constants) + if isinstance(output.graph, CompiledFxGraph): + output.graph.compile_region_name = self.compile_region_name self._compile._postprocess(output) return output.graph @@ -392,6 +400,8 @@ def codegen_and_compile( def callback(pickled_output: _WireProtocolPickledOutput) -> OutputCode: _ProgressiveFxCompile._stat_bg_finished += 1 output = pickled_output.deserialize(constants) + if isinstance(output.graph, CompiledFxGraph): + output.graph.compile_region_name = self.compile_region_name self._optimized_compile._postprocess(output) return output.graph diff --git a/torch/_inductor/compile_fx_ext.py b/torch/_inductor/compile_fx_ext.py index 493f3b6b28960..bb569908fd98e 100644 --- a/torch/_inductor/compile_fx_ext.py +++ b/torch/_inductor/compile_fx_ext.py @@ -14,7 +14,7 @@ from typing import Any, TYPE_CHECKING, TypeGuard from typing_extensions import final, override, Self -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile import torch.fx from torch._inductor.codecache import BypassFxGraphCache, FxGraphCache from torch._inductor.metrics import CachedMetricsDeltas, CachedMetricsHelper @@ -438,12 +438,16 @@ def codegen_and_compile( gm, example_inputs, inputs_to_check, graph_kwargs ) if not serialized: - return _InProcessFxCompile().codegen_and_compile( + eager_compile = _InProcessFxCompile() + eager_compile.compile_region_name = self.compile_region_name + return eager_compile.codegen_and_compile( gm, example_inputs, inputs_to_check, graph_kwargs ) inputs, constants = serialized output = self._send_to_child(inputs).deserialize(constants) + if isinstance(output.graph, CompiledFxGraph): + output.graph.compile_region_name = self.compile_region_name self._postprocess(output) self._compile_stats[type(self)].codegen_and_compile += 1 @@ -469,7 +473,7 @@ def serialize_compile( # we can't cache (or serialize) FxGraphCache._check_for_hop(gm) except BypassFxGraphCache as e: - log.debug("Skipping %s compile: %s", type(self), e) # noqa: G200 + log.debug("Skipping %s compile: %s", type(self), e) return None # Triton kernel wrapper nodes contain references to the kernel_side_table diff --git a/torch/_inductor/compile_fx_subproc.py b/torch/_inductor/compile_fx_subproc.py index 0dbf71cbb175c..f72bcfbfa4ba3 100644 --- a/torch/_inductor/compile_fx_subproc.py +++ b/torch/_inductor/compile_fx_subproc.py @@ -6,7 +6,7 @@ from typing import TYPE_CHECKING from typing_extensions import final, override -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile import torch.fx from torch._inductor.compile_worker.subproc_pool import ( AnyPool, diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 1e6753cb66d6d..34830133a61bb 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -16,6 +16,7 @@ if TYPE_CHECKING: from torch._inductor.choices import InductorChoices + from torch._inductor.cudagraph_utils import CUDAGraphPolicy inplace_padding = os.environ.get("TORCHINDUCTOR_INPLACE_PADDING", "1") == "1" can_inplace_pad_graph_input = False # ease testing @@ -208,7 +209,12 @@ def prologue_fusion_enabled() -> bool: # Controls automatic precompiling of common include files for codecache.CppCodeCache # (i.e. for cpp_wrapper mode and for cpp kernels on CPU). AOTI header precompiling is # controlled by a separate flag. -cpp_cache_precompile_headers: bool = not is_fbcode() +cpp_cache_precompile_headers: bool = ( + os.environ.get( + "TORCHINDUCTOR_CPP_CACHE_PRECOMPILE_HEADERS", "0" if is_fbcode() else "1" + ) + == "1" +) online_softmax = os.environ.get("TORCHINDUCTOR_ONLINE_SOFTMAX", "1") == "1" @@ -300,7 +306,7 @@ def prologue_fusion_enabled() -> bool: # Registers a custom pregrad pass. Note that the pre-grad IR is 1. # non-functional, 2. non-normalized, and 3. prone to change. Ideally we should # use post-grad passes. -pre_grad_custom_pass: Callable[[torch.fx.graph.Graph], None] | None = None +pre_grad_custom_pass: torch._inductor.custom_graph_pass.CustomGraphPassType = None # Registers a custom pass to be run right before fusion in Inductor scheduler. # WARNING: Inductor scheduler IR is at prototype stage and subject to change, @@ -373,6 +379,11 @@ def prologue_fusion_enabled() -> bool: # but the mul gets fused with other pointwise ops instead. force_fuse_int_mm_with_mul = False +# Prevent unfusing addmm into mm+add for bf16/fp16 to avoid precision loss +# from extra truncation at the mm output. Set to False to allow unfusing +# (may improve perf at the cost of accuracy for some models). +keep_addmm_fused_for_half_dtypes = True + # DEPRECATED. This setting is ignored. use_mixed_mm = True @@ -545,6 +556,15 @@ def _autotune_num_choices_displayed_default() -> int | None: == "1" ) +# Pluggable CUDAGraph wrapping policy. When set to a ``CUDAGraphPolicy`` +# instance, ``post_compile`` delegates cudagraph wrapping to the policy +# instead of the built-in ``cudagraphify`` pipeline. This allows custom +# cudagraph implementations, selective inner-vs-outer wrapping for +# regional compilation, and shared memory pool management. +# +# See ``torch._inductor.cudagraph_utils.CUDAGraphPolicy`` for the base class. +cudagraph_policy: "CUDAGraphPolicy | None" = None + # register ops upon which inductor should partition the graph. name format should be # "namespace::kernel_name" (e.g., aten::mm) for op overload packet, or # "namespace::kernel_name.overload" (e.g., aten::mm.default). @@ -655,6 +675,15 @@ def _nvgemm_max_profiling_configs_default() -> int | None: # Use fx graph passes use_pre_grad_passes: bool = True + +# "early": pre-grad passes run before cache lookup (every compile). +# "late": pre-grad passes run after cache lookup (only on cache miss); +# requires custom passes to implement uuid() for the cache key. +# "default": resolves to "late" when possible (no custom pass, or custom pass +# with uuid), falls back to "early" otherwise. +pre_grad_pass_timing: Literal["early", "late", "default"] = "default" + + use_joint_graph_passes: bool = True use_post_grad_passes: bool = True @@ -712,10 +741,35 @@ def _nvgemm_max_profiling_configs_default() -> int | None: # AutoHeuristic is a framework that allows one to collect data from autotuning, use the data to learn a heuristic, and # generate the learned heuristic to code which is shipped with the compiler -# Specify a list of comma separated optimizations to collect data for -autoheuristic_collect = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "") -# Specify a list of comma separated optimizations to use learned heuristics for -autoheuristic_use = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "mixed_mm") + + +def _parse_autoheuristic_collect_env(): + collect_env = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_COLLECT", "").split(",") + return collect_env + + +def _parse_autoheuristic_use_env(): + use_env = os.environ.get("TORCHINDUCTOR_AUTOHEURISTIC_USE", "mixed_mm").split(",") + return use_env + + +class autoheuristic_collect: + """ + Config for which autoheuristic optimizations should collect training data. + """ + + pad_mm = "pad_mm" in _parse_autoheuristic_collect_env() + mixed_mm = "mixed_mm" in _parse_autoheuristic_collect_env() + + +class autoheuristic_use: + """ + Config for which autoheuristic optimizations should use learned heuristics. + """ + + pad_mm = "pad_mm" in _parse_autoheuristic_use_env() + mixed_mm = "mixed_mm" in _parse_autoheuristic_collect_env() + # If set to 1, will run a JIT post compile hook if one is set. run_jit_post_compile_hook = ( @@ -728,11 +782,23 @@ def run_autoheuristic(name: str) -> bool: def collect_autoheuristic(name: str) -> bool: - return name in torch._inductor.config.autoheuristic_collect.split(",") + if name == "pad_mm": + return autoheuristic_collect.pad_mm + elif name == "mixed_mm": + return autoheuristic_collect.mixed_mm + else: + # For test compatibility with non-standard ops (e.g. "test", "foo" used in tests) + return name in _parse_autoheuristic_collect_env() def use_autoheuristic(name: str) -> bool: - return name in torch._inductor.config.autoheuristic_use.split(",") + if name == "pad_mm": + return autoheuristic_use.pad_mm + elif name == "mixed_mm": + return autoheuristic_use.mixed_mm + else: + # For test compatibility with non-standard ops (e.g. "test", "foo" used in tests) + return name in _parse_autoheuristic_use_env() # If set to "DEFAULT", this will use the default log path specified in autoheuristic.py. @@ -778,9 +844,22 @@ def use_autoheuristic(name: str) -> bool: None # TODO(xuanzh): harden this to make it non optional ) +# Defer early realization of cheap output nodes (0 buffer reads, small opcount) +# to prevent cascade materialization in fullgraph compilation. +# Shared constants/indices saved for backward get eagerly materialized because +# they are graph outputs with multiple users, which inflates downstream read +# counts and can trigger suboptimal Triton block size heuristics. +delay_realize_cheap_outputs: bool = Config( + env_name_force="TORCHINDUCTOR_DELAY_REALIZE_CHEAP_OUTPUTS", + default=True, +) + # fallback to eager for random/dropout, this is slow but useful for debugging fallback_random = False +# align random/dropout as eager mode(aten) behavior, maintaining fused possibility and faster gpu kernel +align_random_eager = False + # fallback embedding_bag_byte_unpack to eager fallback_embedding_bag_byte_unpack = False @@ -790,7 +869,17 @@ def use_autoheuristic(name: str) -> bool: os.environ.get("TORCHINDUCTOR_ASSUME_UNALIGNED_FALLBACK_OUTPUT") == "1" ) -# Custom InductorChoices callable to use (can be a class or functools.partial with kwargs) +# Factory callable that returns a custom InductorChoices instance. +# A callable (rather than a class) is used to defer imports and avoid circular +# dependencies between config and the choices module. Example: +# +# def _custom_choices_factory(): +# from my_package.choices import MyInductorChoices +# return MyInductorChoices() +# +# config.inductor_choices_class = _custom_choices_factory +# +# The returned instance must implement uuid() for cache key serialization. inductor_choices_class: Callable[[], "InductorChoices"] | None = None # fuse even in cases without common reads @@ -807,6 +896,9 @@ def use_autoheuristic(name: str) -> bool: ) == "1" ) +loop_reindexing_after_fusion: bool = ( + os.environ.get("TORCHINDUCTOR_LOOP_REINDEXING_AFTER_FUSION", "1") == "1" +) # When trying to fuse two nodes, one with: @@ -879,6 +971,9 @@ def use_autoheuristic(name: str) -> bool: # if we know they affect numerics. WARNING: Expect perf hit in this mode. deterministic = os.getenv("TORCHINDUCTOR_DETERMINISTIC") == "1" +# Batch-invariant mode: stable per-sample compiled kernel across batch sizes. Implies deterministic. +batch_invariant = os.getenv("TORCHINDUCTOR_BATCH_INVARIANT") == "1" + # When we do split reduction, this number control the minimum value for # num_split. Too small num_split make the split reduction less efficient. # It's a much bigger problem when we compile a dynamic shape kernel with @@ -897,6 +992,9 @@ def use_autoheuristic(name: str) -> bool: # assert that indirect indexing does not read / write out of bounds assert_indirect_indexing = True +# skip emitting runtime assertions for unbacked symbols in generated code +do_not_emit_runtime_assertions = False + # compute CSE bounds on variables that do not appear in the FX graph compute_all_bounds = False @@ -915,10 +1013,15 @@ def use_autoheuristic(name: str) -> bool: combo_kernel_foreach_dynamic_shapes = True # Maximum number of arguments (read/write buffers) allowed in a combo kernel combo_kernel_max_num_args = 250 +# Maximum number of sub-kernels allowed in a single combo kernel +combo_kernel_max_num_nodes = 8 # When True, each combo sub-kernel gets its own block sizes (XBLOCK_0, YBLOCK_0, etc.) # allowing different sub-kernels to use different tile sizes based on their heuristics. # When False, all sub-kernels share block sizes (XBLOCK, YBLOCK, etc.) combo_kernel_per_subkernel_blocks = False +# When True, combo-kernel autotuning groups sub-kernels that share the same +# candidate config set and kernel-analysis signature. Disabled by default. +combo_kernel_autotune_grouping = False # When True, only pointwise kernels are eligible for combo kernel fusion. combo_kernels_pointwise_only = False @@ -1072,6 +1175,9 @@ class aten_distributed_optimizations: # In deterministic mode, this setting is ignored and "analytical" is used. compute_estimator: Literal["analytical", "benchmark"] = "benchmark" + # Chrome Trace JSON path for profile-guided runtime estimation. + profile_guided_estimations_profile_path: str | None = None + # Maximum memory increase above baseline for prefetch operations # Uses minimum of absolute cap and ratio of baseline max_memory_increase_gb: float | None = None # Absolute cap in GB @@ -1096,9 +1202,44 @@ class aten_distributed_optimizations: # as atomic units with memory-bound runtime estimates. enable_fusion_regions: bool | None = None + # Default bucketing mode for auto and manual overlap scheduling + # "default": traced bucketing, fully lowered by inductor during compilation + # "custom_ops": temporary bucketing using custom ops to hide parts from inductor + # "custom_ops_multidtype": same as custom_ops but buckets multiple dtypes + # (e.g. bf16 and fp32) into one bucket + # "coalesced": zero-copy batching via reduce_scatter_tensor_coalesced + # (reduce_scatter only; all_gather falls back to default) + # None means "auto" — the compiler picks the best mode + bucket_mode: ( + Literal["default", "custom_ops", "custom_ops_multidtype", "coalesced"] | None + ) = None + # Prioritize bucketing during overlap scheduling by grouping candidates by bucket key prioritize_bucketing_during_scheduling: bool = True + # Verify FX graphs are identical across ranks before overlap scheduling. + # Detects non-SPMD graphs that would cause NCCL collective ordering + # mismatches and hangs. + spmd_check: bool = True + + # Action on SPMD graph mismatch: "warn" logs a warning, "error" raises + # RuntimeError. "error" fails fast instead of risking silent NCCL hang. + # TODO(ivankobzarev): change default to "error" after real-world testing. + spmd_mismatch: Literal["warn", "error"] = "warn" + + # When True, automatically remove extra deps that create cycles instead of + # raising an error. Set this to True as a workaround if overlap scheduling + # fails with a cycle error, and file a bug so the root cause can be fixed. + overlap_scheduling_autofix_cycles: bool = False + + # Replace NCCL collectives with low-contention variants that use + # copy engine instead of SMs, freeing SMs for overlapping compute. + enable_low_contention_collectives: bool = False + + # Minimum per-rank bytes for LC replacement. Below this, LC barrier + # overhead exceeds the benefit. Set to 0 to disable. + low_contention_min_bytes_per_rank: int = 16 * 1024 * 1024 + def parallel_compile_enabled_internally() -> bool: """ @@ -1138,11 +1279,7 @@ def decide_compile_threads() -> int: compile_threads = 1 log.info("compile_threads set to 1 in fbcode") else: - cpu_count = ( - len(os.sched_getaffinity(0)) - if hasattr(os, "sched_getaffinity") - else os.cpu_count() - ) + cpu_count = torch._utils.cpu_count() assert cpu_count compile_threads = min(32, cpu_count) log.info("compile_threads set to %d", compile_threads) @@ -1717,6 +1854,15 @@ class triton: os.environ.get("TORCHINDUCTOR_PERSISTENT_REDUCTIONS", "1") == "1" ) + # Decompose sort-based ops (sort, mode, median) to generate Triton + # kernels instead of falling back to ATen eager. When enabled, sort + # removes the default 512-element dimension limit and uses int32 + # indices (up to 2^31-1 elements), and mode/median decompose into + # sort + reduction / pointwise ops that Inductor can lower to Triton. + decompose_sort_ops: bool = ( + os.environ.get("TORCHINDUCTOR_DECOMPOSE_SORT_OPS", "0") == "1" + ) + # For small output size reductions uses cross thread-block synchronization to gain more parallelism cooperative_reductions = ( os.environ.get("TORCHINDUCTOR_COOPERATIVE_REDUCTIONS", "0") == "1" @@ -1737,6 +1883,13 @@ class triton: # hint to Triton when arguments are divisible by 16 divisible_by_16 = os.environ.get("TORCHINDUCTOR_DIVISIBLE_BY_16", "1") == "1" + # On AMD/HIP, annotate pointer args with tt.pointer_range=32 when the + # tensor storage provably fits in 2 GB. This lets Triton emit efficient + # buffer load/store ops. Disable if a Triton compiler bug is triggered. + emit_pointer_range_32 = ( + os.environ.get("TORCHINDUCTOR_EMIT_POINTER_RANGE_32", "1") == "1" + ) + # Minimum R0_BLOCK to be used for a TritonSplitScanKernel # NOTE: This also indirectly controls the size of workspace buffer required min_split_scan_rblock = 256 @@ -1786,8 +1939,9 @@ class triton: # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental) codegen_upcast_to_fp32 = True - # Whether persistent matmul kernels should be enabled this flag only has effect when on h100 - # with a version of triton new enough to support TMA + # Whether persistent matmul kernels should be enabled. On NVIDIA H100+ with TMA support, + # this enables TMA persistent kernels. On AMD GPUs without TMA, this enables + # non-TMA persistent kernels as a fallback. enable_persistent_tma_matmul = ( os.environ.get("ENABLE_PERSISTENT_TMA_MATMUL", "0") == "1" ) @@ -1837,6 +1991,10 @@ class triton: # this could be helpful to avoid recompilations in some cases mix_order_reduction_non_strict_mode = False + # Maximum external read buffers (loads) in a mix-order reduction + # kernel. Set to 0 to disable the check. + mix_order_reduction_max_reads = 10 + # Don't allow multi-stages by default to avoid out of shared memory mix_order_reduction_allow_multi_stages = ( os.environ.get("TORCHINDUCTOR_MIX_ORDER_REDUCTION_ALLOW_MULTI_STAGES") == "1" @@ -1851,14 +2009,6 @@ class triton: # This ensures the last N runs are saved, where N is this value max_kernel_dump_occurrences = 3 - # TLX template mode: "default", "allow", or "force" - tlx_mode: str = os.environ.get("TORCHINDUCTOR_TLX_MODE", "default") - - # TLX heuristic config: when True, use heuristic-based config selection for TLX templates - tlx_heuristic_config: bool = ( - os.environ.get("TORCHINDUCTOR_TLX_HEURISTIC_CONFIG", "1") == "1" - ) - proton_profiling: bool = ( os.environ.get("TORCHINDUCTOR_TRITON_PROTON_PROFILING", "0") == "1" ) @@ -1897,6 +2047,11 @@ class aot_inductor: debug_compile = os.environ.get("AOT_INDUCTOR_DEBUG_COMPILE", "0") == "1" debug_symbols = os.environ.get("AOT_INDUCTOR_DEBUG_SYMBOLS", "0") == "1" + # Enable frame pointers for profiling tools (e.g. strobelight) + enable_frame_pointer = ( + os.environ.get("AOT_INDUCTOR_ENABLE_FRAME_POINTER", "0") == "1" + ) + # Annotate generated main wrapper function, i.e. AOTInductorModel::run_impl, # to use which cpp compiler optimization level, default to O1 compile_wrapper_opt_level = os.environ.get( @@ -2241,6 +2396,11 @@ class xpu(cutlass): # e.g. "20250201". version: str | None = None + # Path to Intel OneAPI. + oneapi_root: str | None = None + + cutlass_dir = os.path.realpath(os.environ.get("TORCHINDUCTOR_CUTLASS_DIR", "")) + class rocm: # Offload arch list for device code compilation, e.g. ["gfx90a", "gfx942"]. @@ -2452,6 +2612,9 @@ class trace: "post_grad_custom_post_pass", "_fuse_ddp_communication_passes", "_pre_fusion_custom_pass", + # CUDAGraphPolicy objects are not picklable and only affect + # post_compile wrapping, not compiled code itself. + "cudagraph_policy", ] _cache_config_ignore_prefix: list[str] = [ @@ -2469,10 +2632,15 @@ class trace: "post_grad_custom_pre_pass", "joint_custom_pre_pass", "joint_custom_post_pass", + "pre_grad_custom_pass", "_fuse_ddp_communication_passes", "_pre_fusion_custom_pass", + # CUDAGraphPolicy only affects post_compile, not compiled output + "cudagraph_policy", # tests assume that changes here don't invalidate cache "always_complex_memory_overlap_TESTING_ONLY", + # timing affects cache structure, not cache content + "pre_grad_pass_timing", # cache related options are not relevant to cache results "fx_graph_cache", "fx_graph_remote_cache", @@ -2480,6 +2648,12 @@ class trace: "autotune_remote_cache", ] +# Config keys whose values are callable factories. save_config_portable will +# instantiate the factory and use .uuid() for serialization. +_cache_config_factory_keys: list[str] = [ + "inductor_choices_class", +] + # External callable for matmul tuning candidates external_matmul: list[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = [] @@ -2559,7 +2733,7 @@ class test_configs: if TYPE_CHECKING: - from torch.utils._config_typing import * # noqa: F401, F403 + from torch.utils._config_typing import * # noqa: F403 class eager_numerics: diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index f2d7696b4b57e..4da65df7adf93 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -437,6 +437,56 @@ def convert_cubin_to_obj( return obj_file +def batch_convert_cubins_to_obj( + cubins: list[tuple[str, str]], + output_dir: str, + cpp_compiler: str = "gcc", +) -> str: + """Convert multiple cubin files to a single .o using batched .incbin assembly. + + Instead of spawning 3 subprocesses per cubin (ld + 2x objcopy), generates + a single .S file with .incbin directives for all cubins and compiles it + with one compiler invocation. Produces bit-identical rodata and symbols + as the per-cubin convert_cubin_to_obj approach. + + Args: + cubins: list of (cubin_file_path, kernel_name) tuples. + output_dir: directory for the generated .S and .o files. + cpp_compiler: C compiler to use for assembling (default: gcc). + + Returns: + Path to the combined .o file. + """ + asm_path = os.path.join(output_dir, "cubins_combined.S") + obj_path = os.path.join(output_dir, "cubins_combined.o") + + with open(asm_path, "w") as f: + f.write(".section .rodata\n") + for cubin_file, kernel_name in cubins: + # Use absolute path to avoid issues with working directory + abs_cubin = os.path.abspath(cubin_file) + escaped_path = abs_cubin.replace("\\", "\\\\").replace('"', '\\"') + f.write( + f".balign 16\n" + f".global __{kernel_name}_start\n" + f".global __{kernel_name}_end\n" + f"__{kernel_name}_start:\n" + f'.incbin "{escaped_path}"\n' + f"__{kernel_name}_end:\n" + f".global __{kernel_name}_size\n" + f".set __{kernel_name}_size, " + f"__{kernel_name}_end - __{kernel_name}_start\n" + ) + + subprocess.run( + [cpp_compiler, "-c", asm_path, "-o", obj_path], + capture_output=True, + text=True, + check=True, + ) + return obj_path + + @functools.cache def _is_apple_clang(cpp_compiler: str) -> bool: version_string = subprocess.check_output([cpp_compiler, "--version"]).decode("utf8") @@ -688,9 +738,10 @@ def __init__( self._preprocessing: bool = preprocessing def _process_compile_only_options(self) -> None: - if self._compile_only: + if self._compile_only or self._precompiling or self._preprocessing: self._libraries_dirs = [] self._libraries = [] + self._ldflags = [] def _remove_duplicate_options(self) -> None: self._definitions = _remove_duplication_in_list(self._definitions) @@ -802,7 +853,7 @@ def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]: # For Intel oneAPI, ref: https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus?view=msvc-170 "Zc:__cplusplus", # Enable max compatible to msvc for oneAPI headers. - # ref: https://github.com/pytorch/pytorch/blob/db38c44ad639e7ada3e9df2ba026a2cb5e40feb0/cmake/public/utils.cmake#L352-L358 # noqa: B950 + # ref: https://github.com/pytorch/pytorch/blob/db38c44ad639e7ada3e9df2ba026a2cb5e40feb0/cmake/public/utils.cmake#L352-L358 "permissive-", ] else: @@ -912,6 +963,12 @@ def _get_optimization_cflags( cflags += debug_cflags ldflags += debug_ldflags + if config.aot_inductor.enable_frame_pointer: + if _IS_WINDOWS: + cflags.append("Oy-") + else: + cflags.append("fno-omit-frame-pointer") + cflags += _get_ffast_math_flags() if _IS_WINDOWS: @@ -1097,12 +1154,13 @@ def _setup_standard_sys_libs( cpp_compiler: str, aot_mode: bool, use_relative_path: bool, -) -> tuple[list[str], list[str], list[str]]: +) -> tuple[list[str], list[str], list[str], list[str]]: cflags: list[str] = [] include_dirs: list[str] = [] passthrough_args: list[str] = [] + ldflags: list[str] = [] if _IS_WINDOWS: - return cflags, include_dirs, passthrough_args + return cflags, include_dirs, passthrough_args, ldflags if config.is_fbcode(): # TODO(T203137008) Can we unify these flags with triton_cc_command? @@ -1127,12 +1185,12 @@ def _setup_standard_sys_libs( if _is_clang(cpp_compiler): passthrough_args.append(" --rtlib=compiler-rt") - passthrough_args.append(" -fuse-ld=lld") - passthrough_args.append(f" -Wl,--script={linker_script}") passthrough_args.append(" -B" + build_paths.glibc_lib) - passthrough_args.append(" -L" + build_paths.glibc_lib) + ldflags.append("fuse-ld=lld") + ldflags.append(f"Wl,--script={linker_script}") + ldflags.append("L" + build_paths.glibc_lib) - return cflags, include_dirs, passthrough_args + return cflags, include_dirs, passthrough_args, ldflags def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[list[str], list[str]]: @@ -1167,6 +1225,8 @@ def _get_torch_related_args( libraries_dirs = [TORCH_LIB_PATH] if sys.platform != "darwin" and not config.is_fbcode(): libraries.extend(["torch", "torch_cpu"]) + if _IS_WINDOWS: + libraries.append("c10") if not aot_mode: libraries.append("torch_python") else: @@ -1401,9 +1461,8 @@ def _get_openmp_args( if config.is_fbcode(): include_dir_paths.append(build_paths.openmp_include) - openmp_lib = build_paths.openmp_lib_so - fb_openmp_extra_flags = f"-Wp,-fopenmp {openmp_lib}" - passthrough_args.append(fb_openmp_extra_flags) + passthrough_args.append("-Wp,-fopenmp") + lib_dir_paths.append(os.path.dirname(build_paths.openmp_lib_so)) libs.append("omp") else: @@ -1491,6 +1550,7 @@ def get_cpp_torch_options( sys_libs_cflags, sys_libs_include_dirs, sys_libs_passthrough_args, + sys_libs_ldflags, ) = _setup_standard_sys_libs(cpp_compiler, aot_mode, use_relative_path) isa_macros, isa_ps_args_build_flags = _get_build_args_of_chosen_isa(vec_isa) @@ -1532,7 +1592,7 @@ def get_cpp_torch_options( + omp_include_dir_paths ) cflags = sys_libs_cflags + omp_cflags - ldflags = omp_ldflags + ldflags = sys_libs_ldflags + omp_ldflags libraries_dirs = python_libraries_dirs + torch_libraries_dirs + omp_lib_dir_paths libraries = torch_libraries + omp_lib passthrough_args = ( @@ -1640,6 +1700,183 @@ def _find_libcudart_static(path: str) -> Path | None: return None +def _gen_mingw_import_lib(dll_path: str, def_path: str, import_lib_path: str) -> None: + """Generate a MinGW import library (.a) from a DLL using gendef and dlltool.""" + dll_name = os.path.basename(dll_path) + with open(def_path, "w") as def_file: + subprocess.run( + ["gendef", "-", dll_path], + stdout=def_file, + stderr=subprocess.PIPE, + check=True, + ) + + subprocess.run( + [ + "x86_64-w64-mingw32-dlltool", + "-d", + def_path, + "-l", + import_lib_path, + "-D", + dll_name, + ], + stderr=subprocess.PIPE, + check=True, + ) + log.info("Generated MinGW import library %s from %s", import_lib_path, dll_name) + + +# MSVC /GS buffer security check stubs for MinGW cross-compilation. +# CUDA 13.0+ cudart.lib contains MSVC-compiled static objects that reference +# these symbols (__security_cookie, __security_check_cookie, __GSHandlerCheck). +# When cross-compiling with MinGW, we provide no-op stubs so the linker can +# resolve them. At runtime on Windows, the CUDA runtime DLL handles its own +# security checks internally; the static loader code that references these +# symbols is a thin shim whose /GS instrumentation is safe to stub out. +_MSVC_GS_STUBS_SOURCE = """\ +#include +uint64_t __security_cookie = 0x00002B992DDFA232ULL; +void __security_check_cookie(uint64_t cookie) { (void)cookie; } +void __GSHandlerCheck(void) {} +""" + + +def _create_msvc_gs_stubs_lib(output_dir: str) -> str | None: + """ + Create a static library with MSVC GS security symbol stubs for MinGW. + + Returns the library name (without lib prefix / .a suffix) if successful, + or None on failure. + """ + stubs_lib = os.path.join(output_dir, "libmsvc_gs_stubs.a") + if os.path.exists(stubs_lib): + return "msvc_gs_stubs" + + src_path = "" + obj_path = "" + try: + src_path = os.path.join(output_dir, "_msvc_gs_stubs.c") + obj_path = os.path.join(output_dir, "_msvc_gs_stubs.o") + with open(src_path, "w") as f: + f.write(_MSVC_GS_STUBS_SOURCE) + + mingw_gcc = MINGW_GXX.replace("g++", "gcc") + subprocess.run( + [mingw_gcc, "-c", src_path, "-o", obj_path], + stderr=subprocess.PIPE, + check=True, + ) + mingw_ar = MINGW_GXX.replace("g++", "ar") + subprocess.run( + [mingw_ar, "rcs", stubs_lib, obj_path], + stderr=subprocess.PIPE, + check=True, + ) + log.info("Created MSVC GS stubs library: %s", stubs_lib) + return "msvc_gs_stubs" + except (FileNotFoundError, subprocess.CalledProcessError): + log.warning( + "Failed to create MSVC GS stubs library.", + exc_info=True, + ) + if os.path.exists(stubs_lib): + os.remove(stubs_lib) + return None + finally: + for f in [src_path, obj_path]: + if f and os.path.exists(f): + os.remove(f) + + +def _ensure_mingw_cudart_import_lib(libraries_dirs: list[str]) -> list[str]: + """ + Auto-generate a MinGW-compatible import library (libcudart.a) + from the CUDA runtime DLL. This avoids linking against the hybrid cudart.lib + which contains MSVC-compiled static objects with /GS security symbols that + MinGW cannot resolve. + + Falls back to creating MSVC GS security stubs if the DLL is unavailable, + and falls back gracefully to the original cudart.lib if that also fails. + + Returns a list of extra library names to link (e.g. ["msvc_gs_stubs"]). + """ + import glob + + windows_cuda_home = os.environ.get("WINDOWS_CUDA_HOME") + if not windows_cuda_home: + log.debug( + "WINDOWS_CUDA_HOME not set, skipping MinGW cudart import lib generation" + ) + return [] + + for lib_dir in libraries_dirs: + if os.path.exists(os.path.join(lib_dir, "libcudart.a")): + log.debug("libcudart.a already exists in %s, skipping generation", lib_dir) + return [] + + # Find the CUDA runtime DLL for import lib generation + bin_dir = os.path.join(windows_cuda_home, "bin", "x64") + if not os.path.isdir(bin_dir): + bin_dir = os.path.join(windows_cuda_home, "bin") + dll_candidates = glob.glob(os.path.join(bin_dir, "cudart64_*.dll")) + + # Find a writable directory containing cudart.lib for output + output_dir = None + for lib_dir in libraries_dirs: + if os.path.isdir(lib_dir) and os.access(lib_dir, os.W_OK): + if os.path.exists(os.path.join(lib_dir, "cudart.lib")): + output_dir = lib_dir + break + + if not dll_candidates: + log.warning( + "No cudart64_*.dll found in %s. Cannot generate MinGW import library. " + "Will create MSVC GS security stubs as fallback.", + bin_dir, + ) + # Fallback: create GS stubs so the hybrid cudart.lib can link + if output_dir is not None: + stub_lib = _create_msvc_gs_stubs_lib(output_dir) + if stub_lib: + return [stub_lib] + return [] + + if output_dir is None: + log.warning( + "No writable directory containing cudart.lib found. " + "Cannot generate MinGW import library. " + "If linking fails with undefined references to __security_cookie, " + "ensure cudart.lib is present in one of: %s", + libraries_dirs, + ) + return [] + + dll_path = dll_candidates[0] + dll_name = os.path.basename(dll_path) + + def_path = os.path.join(output_dir, dll_name.replace(".dll", ".def")) + import_lib_path = os.path.join(output_dir, "libcudart.a") + + try: + _gen_mingw_import_lib(dll_path, def_path, import_lib_path) + return [] + except (FileNotFoundError, subprocess.CalledProcessError): + log.warning( + "Failed to generate MinGW cudart import library. " + "Falling back to MSVC GS stubs.", + exc_info=True, + ) + for f in [def_path, import_lib_path]: + if os.path.exists(f): + os.remove(f) + # Fallback: create GS stubs + stub_lib = _create_msvc_gs_stubs_lib(output_dir) + if stub_lib: + return [stub_lib] + return [] + + def _transform_cuda_paths(lpaths: list[str]) -> None: # This handles two cases: # 1. Cases where libs are in (e.g.) lib/cuda-12 and lib/cuda-12/stubs @@ -1708,7 +1945,8 @@ def get_cpp_torch_device_options( else: libraries += ["cuda", "torch_cuda"] if config.aot_inductor.cross_target_platform == "windows": - libraries += ["cudart"] + extra_libs = _ensure_mingw_cudart_import_lib(libraries_dirs) + libraries += ["cudart"] + extra_libs _transform_cuda_paths(libraries_dirs) if device_type == "xpu": @@ -1789,6 +2027,7 @@ def __init__( min_optimize: bool = False, precompiling: bool = False, preprocessing: bool = False, + compiler: str = "", ) -> None: super().__init__( vec_isa=vec_isa, @@ -1802,6 +2041,7 @@ def __init__( min_optimize=min_optimize, precompiling=precompiling, preprocessing=preprocessing, + compiler=compiler, ) device_definitions: list[str] = [] @@ -2002,6 +2242,11 @@ def __init__( assert len(sources) == 1 # See above; we can currently assume this is not on MSVC. self._sources_args = f"-x c++-header {sources[0]}" + if self._use_relative_path and _is_clang(BuildOption.get_compiler()): + # Store PCH paths relative to -isysroot so the .pch can + # be used from a different build directory. The matching + # -isysroot is injected by build_fbcode_re(). + self._cflags_args += " -relocatable-pch -Xclang -fno-pch-timestamp " else: self._sources_args = " ".join(sources) @@ -2025,6 +2270,13 @@ def __init__( ) else: self._include_dirs_args = f"-include {precompiled_header} " + if self._use_relative_path and _is_clang(BuildOption.get_compiler()): + # Skip clang's own PCH validation during consumption. + # _precompile_header() already handles cache invalidation + # via content hashing, and -fno-validate-pch allows the + # PCH to be used even when the original source file is at + # a different path (e.g. across Remote Execution workers). + self._cflags_args += " -Xclang -fno-validate-pch " for inc_dir in BuildOption.get_include_dirs(): if _IS_WINDOWS: @@ -2106,7 +2358,7 @@ def build_fbcode_re( self, ) -> None: with dynamo_timed("compile_file"): - command = self.get_command_line().split() + command = shlex.split(self.get_command_line()) try: output_path = self._target_file # When we build remotely, we need to make sure to carefully copy any files @@ -2120,6 +2372,35 @@ def build_fbcode_re( shutil.copy(src, os.path.join(tmp_dir, os.path.basename(src))) dest_include_path = os.path.join(tmp_dir, "include") shutil.copytree(torch_includes_path, dest_include_path) + + # Copy precompiled header (.h and .gch/.pch) into the + # build directory and rewrite the -include flag so the + # compiler can find it. + pch_header = self._build_option.precompiled_header + if pch_header and os.path.isfile(pch_header): + pch_ext = ".pch" if _IS_WINDOWS or not is_gcc() else ".gch" + pch_compiled = pch_header + pch_ext + pch_basename = os.path.basename(pch_header) + shutil.copy(pch_header, os.path.join(tmp_dir, pch_basename)) + if os.path.isfile(pch_compiled): + shutil.copy( + pch_compiled, + os.path.join(tmp_dir, pch_basename + pch_ext), + ) + command = [ + pch_basename if arg == pch_header else arg + for arg in command + ] + + # Relocatable PCH stores include paths relative to + # -isysroot. Set sysroot to the tmp build dir so + # paths resolve correctly in both precompilation and + # later kernel compilations that consume the PCH. + if self._precompiling or ( + pch_header and os.path.isfile(pch_header) + ): + command[1:1] = ["-isysroot", "."] + # Run the build, raising RuntimeError on failure instead of # SkipFrame so compilation errors propagate rather than # silently falling back to eager execution. @@ -2177,7 +2458,7 @@ def save_compile_cmd_to_cmake( f""" cmake_minimum_required(VERSION 3.27 FATAL_ERROR) project({self._target_name} LANGUAGES CXX) - set(CMAKE_CXX_STANDARD 17) + set(CMAKE_CXX_STANDARD 20) # Set a library target add_library({self._target_name} {target_library_type}) diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index e4aee2d6c0258..1cab625eabe2f 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -67,7 +67,7 @@ class VecISA: auto tmp1 = tmp0.exp(); tmp1.store(in_out_ptr0); } -""" # noqa: B950 +""" _avx_py_load = """ import torch diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 9602a3a9cee6a..df4d2b6030910 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -56,6 +56,10 @@ import torch.fx from torch import Tensor from torch._dynamo.callback import CallbackTrigger +from torch._dynamo.graph_bytecode_inputs import ( + CURRENT_STREAM_INDEX, + set_external_object_by_index, +) from torch._dynamo.mutation_guard import GenerationTracker from torch._dynamo.utils import counters, dynamo_timed, preserve_rng_state from torch._higher_order_ops.cudagraph_conditional_nodes import ( @@ -84,6 +88,8 @@ PlaceholderInfo, WrappedFunction, ) +from torch._library.opaque_object import is_opaque_value +from torch._opaque_base import OpaqueBase from torch.multiprocessing.reductions import StorageWeakRef from torch.storage import UntypedStorage from torch.utils import _pytree as pytree @@ -631,6 +637,19 @@ def _use_cuda_memory_pool_manager( torch.cuda.current_stream().wait_stream(stream) +@contextlib.contextmanager +def _update_current_stream_external_object() -> Generator[None, None, None]: + """Update the external object registry so custom ops see the capture stream. + + During cudagraph recording/warmup the current stream differs from the + trace-time default stream. The external object at CURRENT_STREAM_INDEX + must reflect the actual current stream so that custom ops (e.g. event + record/wait) executed during capture use the right stream. + """ + set_external_object_by_index(CURRENT_STREAM_INDEX, torch.cuda.current_stream()) + yield + + def map_to_ref(t: Tensor | None) -> StorageWeakRefWrapper | None: if not isinstance(t, torch.Tensor): assert t is None @@ -723,6 +742,8 @@ def get_non_cudagraph_inps() -> list[weakref.ReferenceType[UntypedStorage]]: _use_cuda_memory_pool_manager( self.device_index, self.cuda_graphs_pool, self.stream ), + # NB: must go after _use_cuda_memory_pool_manager which switches the stream + _update_current_stream_external_object(), ControlFlowOpWarmupDispatchMode(), get_history_recording(), ): @@ -878,10 +899,7 @@ def __init__( # Enable re-record a cudagraph when static tensor address changed. # if not we should error when it changed. - self.rerecord_if_static_inputs_change = ( - torch._dynamo.config.inline_inbuilt_nn_modules - or torch._inductor.config.triton.cudagraph_support_input_mutation - ) + self.rerecord_if_static_inputs_change = True # if this is a root parent will be None. use weakref to prevent reference cycle self._parent = weakref.ref(parent) if parent is not None else None @@ -935,9 +953,26 @@ def __init__( # and also aliases an output of the current CUDAGraphNode self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs) + # Opaque values (e.g. DeviceMesh, ProcessGroup) are non-tensor + # inputs that cannot be copied like tensors. We include them in + # static_input_idxs to keep them out of non_static_input_idx + # (which drives the tensor-copy path during replay). "Static" + # here just means "don't try to copy this as a tensor" — it + # does NOT mean the object is semantically immutable. + # + # Opaque indices must also be excluded from any list passed to + # _tensors_data_ptrs_at_indices_equal (the C++ data-pointer + # stability check), because opaque objects have no data_ptr. + # That is why tensor_static_input_idxs and + # non_managed_static_input_idxs filter them out below. + opaque_input_idxs = OrderedSet( + i for i, inp in enumerate(inputs) if is_opaque_value(inp) + ) + static_input_idxs = OrderedSet(wrapped_function.static_input_idxs) + cudagraph_managed_idxs = OrderedSet(self.cudagraph_managed_idxs) + self.static_input_idxs: list[int] = list( - OrderedSet(wrapped_function.static_input_idxs) - | OrderedSet(self.cudagraph_managed_idxs) + static_input_idxs | cudagraph_managed_idxs | opaque_input_idxs ) self.non_static_input_idx: LevelList[int] = [ @@ -948,11 +983,13 @@ def __init__( self.non_static_input_idx ) - self.non_managed_static_input_idxs: LevelList[int] = [ - i - for i in wrapped_function.static_input_idxs - if i not in self.cudagraph_managed_idxs - ] + self.non_managed_static_input_idxs: LevelList[int] = LevelList( + static_input_idxs - cudagraph_managed_idxs - opaque_input_idxs + ) + + self.tensor_static_input_idxs: list[int] = list( + static_input_idxs | cudagraph_managed_idxs + ) def maybe_get_static_data_ptr( idx: int, @@ -1316,6 +1353,8 @@ def static_input_iter() -> Generator[torch.Tensor, None, None]: pool=self.cuda_graphs_pool, capture_error_mode="thread_local", ), + # NB: must go after torch.cuda.graph which switches the stream + _update_current_stream_external_object(), CUDAGraphCaptureControlFlowOpDispatchMode(), get_history_recording(), ): @@ -1559,7 +1598,7 @@ def _check_liveness( return True def add_child(self, function_id: FunctionID, node: CUDAGraphNode) -> None: - "Adds node as a a child of self" + "Adds node as a child of self" self.children[function_id].append(node) @staticmethod @@ -1732,7 +1771,7 @@ def _allocate_and_copy_recording_inputs( ): for i, inp in enumerate(inputs): if not isinstance(inp, torch.Tensor): - assert isinstance(inp, (int, torch.Generator)) + assert isinstance(inp, (int, torch.Generator, OpaqueBase)) recording_inputs.append(inp) elif i not in self.static_input_idxs: @@ -1791,13 +1830,13 @@ def check_invariants( and not torch._C._tensors_data_ptrs_at_indices_equal( inputs, # type: ignore[arg-type] self.static_input_data_ptrs, - self.static_input_idxs, + self.tensor_static_input_idxs, ) ): status = CheckInvariantStatus.StaticInputIdxMismatch _logger = functools.partial( _logger, - self.static_input_idxs, + self.tensor_static_input_idxs, status, ) return status, _logger @@ -1862,7 +1901,7 @@ def check_memory_pool( live_storages_ptrs: list[StorageWeakRefWrapper], ) -> None: """Validate cudagraph pool allocations against tracked live storages and surface leaks.""" - assert all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs) # noqa: C419 + assert all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs) unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} # noqa: set_linter # check if there is a divergence first, then do the expensive snapshot call after @@ -2169,13 +2208,7 @@ def _get_node_id(self) -> GraphID | None: def exceed_rerecord_limit( self, node_id: GraphID | None, function_id: FunctionID ) -> bool: - if torch._dynamo.config.inline_inbuilt_nn_modules: - return False - - return ( - self.num_rerecord[node_id][function_id] - > torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit - ) + return False def _run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType: # we will try to end the current execution lazily, since @@ -2587,9 +2620,9 @@ def check_warn_on_unable_to_start_executing(self, function_id: FunctionID) -> No self.warned_functions.add(function_id) warnings.warn( - "Unable to hit fast path of CUDAGraphs because of pending, uninvoked backwards. " - "Consider running with torch.no_grad() or using torch.compiler.cudagraph_mark_step_begin() " - "before each model invocation" + "Unable to hit fast path of CUDAGraphs because outputs from a previous step " + "still require backward. Ensure backward() is invoked or detach outputs. " + "You may also call torch.compiler.cudagraph_mark_step_begin() before each model invocation." ) @staticmethod diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index 7daed59b9b7be..d2c443c50cec2 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -4,7 +4,7 @@ import dataclasses from collections.abc import Callable from enum import Enum -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, TypeVar import torch from torch._dynamo.utils import counters, get_metrics_context @@ -18,6 +18,10 @@ if TYPE_CHECKING: from collections.abc import Sequence, Set as AbstractSet + from torch._inductor.output_code import OutputCode + +_OC = TypeVar("_OC", bound="OutputCode") + cudagraphs_log = torch._logging.getArtifactLogger(__name__, "cudagraphs") static_inputs_log = torch._logging.getArtifactLogger( @@ -29,6 +33,93 @@ ModelType = Callable[[list[InputType]], OutputType] +class CUDAGraphPolicy: + """Pluggable policy controlling CUDA graph wrapping in Inductor's post_compile. + + Override methods to customize: + - HOW compiled functions are cudagraph-wrapped (cudagraphify) + - WHETHER inner CompiledFxGraphs should be wrapped (should_wrap) + - OUTER wrapping of compound outputs like RegionalOutputCode (wrap_output) + + Set via ``torch._inductor.config.cudagraph_policy``. When ``None`` + (the default), the existing built-in behaviour is used unchanged. + + Example usage:: + + class MyCUDAGraphPolicy(CUDAGraphPolicy): + def cudagraphify(self, model, example_inputs, static_input_idxs, **kwargs): + return my_custom_wrapper(model, example_inputs, static_input_idxs) + + + with torch._inductor.config.patch("cudagraph_policy", MyCUDAGraphPolicy()): + compiled_fn = deserialize_artifacts(...) + """ + + def cudagraphify( + self, + model: Callable[..., Any], + example_inputs: Sequence[InputType], + static_input_idxs: Sequence[int], + *, + device_index: int, + is_backward: bool, + is_inference: bool, + **kwargs: Any, + ) -> Callable[..., Any]: + """Wrap a single compiled callable with CUDA graph capture/replay. + + Called by ``cudagraph_post_compile`` for each ``CompiledFxGraph``. + The default delegates to ``compile_fx.cudagraphify`` (cudagraph_trees). + + ``example_inputs`` are the example inputs at post_compile time. + The default implementation does not forward them because + ``compile_fx.cudagraphify`` defers graph recording to the first + real call via an inner closure. Subclasses that need the + example inputs for warmup or static-input detection may use them. + + When ``config.graph_partition=True``, setting a CUDAGraphPolicy + bypasses ``cudagraph_partition_post_compile`` (which wraps each + partition individually) and routes through ``cudagraph_post_compile`` + instead, so this method wraps the *entire* callable, not individual + partitions. Subclasses that need per-partition control should + handle partitioning internally. + """ + from torch._inductor.compile_fx import cudagraphify + + return cudagraphify( + model, + static_input_idxs, + device_index=device_index, + is_backward=is_backward, + is_inference=is_inference, + **kwargs, + ) + + def should_wrap(self, compiled_graph: OutputCode) -> bool: + """Whether to apply cudagraph wrapping to this CompiledFxGraph. + + Called for each inner ``CompiledFxGraph`` during ``post_compile``. + Return ``False`` to skip wrapping (e.g. when wrapping at the outer + level via ``wrap_output`` instead). + + Default: ``True`` (wrap everything, same as current behaviour). + """ + return True + + def wrap_output(self, output_code: _OC) -> _OC: + """Optional outer-level wrapping after inner post_compile completes. + + Called by ``_compile_fx_inner``, ``BundledOutputCodeLoadable.post_compile``, + and ``FxGraphCacheLoadable.post_compile`` on the ``OutputCode`` returned + from ``post_compile``. Subclasses that only want to wrap specific + output types should check ``isinstance`` and return the input + unchanged for types they don't handle. + + Default: identity (no outer wrapping). + """ + return output_code + + @dataclasses.dataclass(frozen=True, slots=True) class FunctionID: "Unique counter of a function wrapped in cudagraphify_impl" diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py index 3187cbe8c8023..f3ca59e2504bb 100644 --- a/torch/_inductor/debug.py +++ b/torch/_inductor/debug.py @@ -34,7 +34,7 @@ from torch.utils._ordered_set import OrderedSet from torch.utils._pytree import tree_map -from . import config, ir # noqa: F811, this is needed +from . import config, ir from .ir import ExternKernel from .scheduler import ( BaseSchedulerNode, diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 0c57b606f08d6..9c17b631b23d7 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -100,6 +100,7 @@ aten.triu_indices, aten.unbind_copy.int, aten.upsample_bilinear2d.vec, + aten.hann_window, quantized.linear_dynamic_fp16_unpacked_weight, _quantized.wrapped_quantized_linear, ] @@ -350,7 +351,7 @@ def bmm( ) -> torch.Tensor: # Outer-product specialization: [B, M, 1] x [B, 1, N] -> [B, M, N]. # This avoids introducing a reduction and maps directly to broadcasted mul. - if statically_known_true(self.shape[2] == 1) or statically_known_true( + if statically_known_true(self.shape[2] == 1) and statically_known_true( batch2.shape[1] == 1 ): return (self * batch2).contiguous() @@ -775,22 +776,22 @@ def _rand_like( return result.permute(permutation).clone() -@register_decomposition(aten.rand_like) +@decomp.register_decomposition([aten.rand_like], extra_random_decomps) def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor: return _rand_like(torch.rand, self, **kwargs) -@register_decomposition(aten.randn_like) +@decomp.register_decomposition([aten.randn_like], extra_random_decomps) def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor: return _rand_like(torch.randn, self, **kwargs) -@register_decomposition(aten.randint_like.default) +@decomp.register_decomposition([aten.randint_like.default], extra_random_decomps) def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor: return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs) -@register_decomposition(aten.randint_like.low_dtype) +@decomp.register_decomposition([aten.randint_like.low_dtype], extra_random_decomps) def randint_like_low( self: torch.Tensor, low: int, high: int, **kwargs: Any ) -> torch.Tensor: diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py index d7e296ff68c64..312ab76341e3c 100644 --- a/torch/_inductor/dtype_propagation.py +++ b/torch/_inductor/dtype_propagation.py @@ -313,6 +313,10 @@ def randn(seed: int, offset: int) -> torch.dtype: def rand(seed: int, offset: int) -> torch.dtype: return torch.float + @staticmethod + def rand_eager(seed, offset, threads_per_round, tid, vec) -> torch.dtype: + return torch.float + @staticmethod def store_reduction(name: str, index, value: DTypeArg) -> None: return None @@ -444,7 +448,13 @@ def dot(x: DTypeArg, y: DTypeArg) -> torch.dtype: @staticmethod def inline_asm_elementwise( - *inputs, asm, constraints=None, dtype=torch.float32, is_pure=True, pack=1 + *inputs, + asm, + constraints=None, + dtype=torch.float32, + is_pure=True, + pack=1, + input_dtypes=None, ): return dtype diff --git a/torch/_inductor/fuzzer.py b/torch/_inductor/fuzzer.py index e487a6f8789a4..e9ba9f575e7b1 100644 --- a/torch/_inductor/fuzzer.py +++ b/torch/_inductor/fuzzer.py @@ -179,8 +179,18 @@ def failing(self) -> bool: "group_linear": {"require_fbgemm": True}, }, ], - "autoheuristic_collect": ["pad_mm", "mixed_mm"], - "autoheuristic_use": ["pad_mm", "mixed_mm"], + "autoheuristic_collect": [ + {"pad_mm": True, "mixed_mm": True}, + {"pad_mm": True, "mixed_mm": False}, + {"pad_mm": False, "mixed_mm": True}, + {"pad_mm": False, "mixed_mm": False}, + ], + "autoheuristic_use": [ + {"pad_mm": True, "mixed_mm": True}, + {"pad_mm": True, "mixed_mm": False}, + {"pad_mm": False, "mixed_mm": True}, + {"pad_mm": False, "mixed_mm": False}, + ], "traceable_tensor_subclasses": [OrderedSet()], "nontraceable_tensor_subclasses": [OrderedSet()], } @@ -509,6 +519,7 @@ def keys(self) -> KeysView[ComboType]: "pre_grad_custom_pass": DEFAULT, # Typing "custom_partitioner_fn": DEFAULT, # Typing "inductor_choices_class": DEFAULT, # Typing + "cudagraph_policy": DEFAULT, # Typing }, "torch._dynamo.config": { "traceable_tensor_subclasses": DEFAULT, # Typing diff --git a/torch/_inductor/fx_passes/auto_chunker/applier.py b/torch/_inductor/fx_passes/auto_chunker/applier.py index 97c675ad702c2..37582297d6d76 100644 --- a/torch/_inductor/fx_passes/auto_chunker/applier.py +++ b/torch/_inductor/fx_passes/auto_chunker/applier.py @@ -251,6 +251,7 @@ def _create_placeholder_node(input_node: Node) -> Node: and meta.chunk_dim is not None ): shape = list(original_node.args[0]) # type: ignore[arg-type] + # pyrefly: ignore [unsupported-operation] shape[meta.chunk_dim] = chunk_size env[original_node] = new_graph.call_function( aten.full.default, @@ -258,15 +259,15 @@ def _create_placeholder_node(input_node: Node) -> Node: original_node.kwargs, ) continue - # Chunk aten.expand a scalar + # Chunk aten.expand: adjust the target shape at the chunk dimension if ( original_node.target == aten.expand.default and isinstance(original_node.args[0], torch.fx.Node) - and original_node.args[0].meta["val"].numel() == 1 and (meta := get_chunking_meta(original_node)) is not None and meta.chunk_dim is not None ): shape = list(original_node.args[1]) # type: ignore[arg-type] + # pyrefly: ignore [unsupported-operation] shape[meta.chunk_dim] = chunk_size env[original_node] = new_graph.call_function( aten.expand.default, @@ -275,6 +276,23 @@ def _create_placeholder_node(input_node: Node) -> Node: ) continue + # Chunk aten.view: adjust the target shape at the chunk dimension + if ( + original_node.target == aten.view.default + and isinstance(original_node.args[0], torch.fx.Node) + and (meta := get_chunking_meta(original_node)) is not None + and meta.chunk_dim is not None + ): + shape = list(original_node.args[1]) # type: ignore[arg-type] + # pyrefly: ignore [unsupported-operation] + shape[meta.chunk_dim] = chunk_size + env[original_node] = new_graph.call_function( + aten.view.default, + (env[original_node.args[0]], shape), # type: ignore[arg-type] + original_node.kwargs, + ) + continue + # create the node with chunked inputs env[original_node] = new_graph.node_copy(original_node, lambda x: env[x]) diff --git a/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py b/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py index 5d8ba6a6b74a2..22848193e52a4 100644 --- a/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py +++ b/torch/_inductor/fx_passes/auto_chunker/propagate_scale_by.py @@ -136,14 +136,15 @@ def propagate_where(where_node: Node) -> bool: aten.exp.default, aten.log.default, aten.tanh.default, + aten.eq.Tensor, ] ) -def propagate_nonlinear_requires_no_scaling(out_node: Node) -> bool: +def propagate_requires_no_scaling(out_node: Node) -> bool: """ - For nonlinear ops like exp, log, tanh, scale_by cannot be propagated - through since f(S*x) != S*f(x). These ops typically appear in the chunking - subgraph when the final gradient is 1 (i.e. scale_by is None), - making scaling a no-op. + For nonlinear ops (exp, log, tanh) scale_by cannot be propagated + through since f(S*x) != S*f(x). For boolean-output ops (eq) scale_by + is meaningless. These ops only appear in the chunking subgraph when + scale_by is None (e.g. the final gradient is 1). """ args_node = get_args_of_node_type(out_node) args_meta = get_chunking_metas(args_node) @@ -165,9 +166,15 @@ def propagate_nonlinear_requires_no_scaling(out_node: Node) -> bool: aten.neg.default, aten.sum.dim_IntList, aten.sum.default, # sum to scalar + aten.amax.default, aten.mm.default, aten.permute.default, aten.expand.default, + aten.squeeze.dim, + aten.unsqueeze.default, + aten.gather.default, + aten.scatter_add.default, + aten.view.default, ] ) def propagate_general_copy(out_node: Node) -> bool: @@ -186,6 +193,16 @@ def propagate_general_copy(out_node: Node) -> bool: return True +@register_propagate_rule(aten.scatter.value) +def propagate_scatter_value(out_node: Node) -> bool: + # The backward of scatter.value always has value=0 (gradient of a constant), + # so S * scatter(x, idx, 0) = scatter(S*x, idx, 0) holds. + value = out_node.args[3] + if value != 0: + return False + return propagate_general_copy(out_node) + + @register_propagate_rule( [ aten.add.Tensor, diff --git a/torch/_inductor/fx_passes/auto_chunker/propagator.py b/torch/_inductor/fx_passes/auto_chunker/propagator.py index 57c89923cc9de..347539575ae15 100644 --- a/torch/_inductor/fx_passes/auto_chunker/propagator.py +++ b/torch/_inductor/fx_passes/auto_chunker/propagator.py @@ -1,5 +1,6 @@ import functools import logging +import math from collections.abc import Callable, Sequence from enum import Enum from queue import Queue @@ -396,6 +397,7 @@ def bwd() -> PropagateStatus: prims.fma.default, aten.where.self, aten.neg.default, + aten.eq.Tensor, ] ) def propagate_general_copy_metadata( @@ -476,7 +478,7 @@ def propagate_bwd() -> PropagateStatus: return PropagateStatus.FAIL # apply any to a list to avoid short-circuit - changed = any( # noqa: C419 + changed = any( [ # noqa: C419 copy_chunking_meta(node, meta) if not need_handle_broadcast or node_ndim[node] == out_ndim @@ -491,7 +493,7 @@ def propagate_bwd() -> PropagateStatus: # where we attach chunking metadata to tangents that need to be # included in the chunking subgraph. # This is different to having a None ChunkingMeta - changed |= any( # noqa: C419 + changed |= any( [ # noqa: C419 set_chunking_meta(node) for node in scalar_args @@ -509,6 +511,7 @@ def propagate_bwd() -> PropagateStatus: aten.squeeze.dim, aten.gather.default, aten.scatter.value, + aten.scatter_add.default, ] ) def propagate_general_copy_metadata_ignore_broadcast(out_node: Node) -> _HandlerRetType: @@ -636,6 +639,118 @@ def bwd() -> PropagateStatus: return fwd(), bwd() +@register_propagate_rule(aten.unsqueeze.default) +def propagate_unsqueeze(unsqueeze_node: Node) -> _HandlerRetType: + input_node, unsqueeze_dim = unsqueeze_node.args[:2] + assert isinstance(input_node, Node) + assert isinstance(unsqueeze_dim, int) + input_ndim = get_fake_tensor_from_node_arg(input_node).ndim # type: ignore[union-attr] + # Normalize negative dim: unsqueeze valid range is [-(ndim+1), ndim] + normalized_dim = ( + unsqueeze_dim + input_ndim + 1 if unsqueeze_dim < 0 else unsqueeze_dim + ) + + def fwd() -> PropagateStatus: + assert isinstance(input_node, Node) + input_meta = get_chunking_meta(input_node) + if input_meta is None: + return _bool_to_status(False) + if input_meta.chunk_dim is None: + return _bool_to_status(copy_chunking_meta(unsqueeze_node, input_meta)) + + # pyrefly: ignore[unsupported-operation] + new_dim = input_meta.chunk_dim + ( + 1 if input_meta.chunk_dim >= normalized_dim else 0 + ) + return _bool_to_status( + set_chunking_meta(unsqueeze_node, meta=input_meta, chunk_dim=new_dim) + ) + + def bwd() -> PropagateStatus: + assert isinstance(input_node, Node) + output_meta = get_chunking_meta(unsqueeze_node) + if output_meta is None: + return _bool_to_status(False) + if output_meta.chunk_dim is None: + return _bool_to_status(copy_chunking_meta(input_node, output_meta)) + # pyrefly: ignore[unsupported-operation] + new_dim = output_meta.chunk_dim - ( + 1 if output_meta.chunk_dim > normalized_dim else 0 + ) + return _bool_to_status( + set_chunking_meta(input_node, meta=output_meta, chunk_dim=new_dim) + ) + + return fwd(), bwd() + + +def _find_chunk_dim_after_reshape( + old_shape: Sequence[int], new_shape: Sequence[int], chunk_dim: int +) -> int | None: + """ + Find the equivalent chunk_dim position after a reshape by matching + the prefix product (number of elements before the dimension) and + the dimension size. Returns None if the chunk dimension is merged + or split by the reshape, making it unsafe to propagate. + + Examples: + [M, N] -> [M, N, 1], chunk_dim=0: returns 0 (trailing dim added) + [M] -> [M, 1], chunk_dim=0: returns 0 + [M, N] -> [M1, M2, N] where M1*M2=M, chunk_dim=0: returns None (split) + [M, N] -> [M*N], chunk_dim=0: returns None (merged) + """ + chunk_size = old_shape[chunk_dim] + old_offset = math.prod(old_shape[:chunk_dim]) + new_offset = 1 + for new_dim in range(len(new_shape)): + if new_offset == old_offset and new_shape[new_dim] == chunk_size: + return new_dim + new_offset *= new_shape[new_dim] + return None + + +@register_propagate_rule(aten.view.default) +def propagate_view(view_node: Node) -> _HandlerRetType: + input_node = view_node.args[0] + assert isinstance(input_node, Node) + input_shape = list(get_fake_tensor_from_node_arg(input_node).shape) # type: ignore[union-attr] + output_shape = list(get_fake_tensor_from_node_arg(view_node).shape) # type: ignore[union-attr] + + def fwd() -> PropagateStatus: + assert isinstance(input_node, Node) + input_meta = get_chunking_meta(input_node) + if input_meta is None: + return _bool_to_status(False) + if input_meta.chunk_dim is None: + return _bool_to_status(copy_chunking_meta(view_node, input_meta)) + new_dim = _find_chunk_dim_after_reshape( + input_shape, output_shape, input_meta.chunk_dim + ) + if new_dim is None: + return PropagateStatus.FAIL + return _bool_to_status( + set_chunking_meta(view_node, meta=input_meta, chunk_dim=new_dim) + ) + + def bwd() -> PropagateStatus: + assert isinstance(input_node, Node) + output_meta = get_chunking_meta(view_node) + if output_meta is None: + return _bool_to_status(False) + if output_meta.chunk_dim is None: + return _bool_to_status(copy_chunking_meta(input_node, output_meta)) + new_dim = _find_chunk_dim_after_reshape( + output_shape, input_shape, output_meta.chunk_dim + ) + if new_dim is None: + return PropagateStatus.FAIL + return _bool_to_status( + set_chunking_meta(input_node, meta=output_meta, chunk_dim=new_dim) + ) + + return fwd(), bwd() + + @register_propagate_rule( [ aten.expand.default, @@ -645,15 +760,50 @@ def propagate_expand(expand_node: Node) -> _HandlerRetType: input_node = expand_node.args[0] assert isinstance(input_node, Node) - if input_node.meta["val"].numel() != 1: - return PropagateStatus.FAIL + input_ft = get_fake_tensor_from_node_arg(input_node) + assert input_ft is not None + output_ft = get_fake_tensor_from_node_arg(expand_node) + assert output_ft is not None + input_shape = list(input_ft.shape) + output_shape = list(output_ft.shape) - # Combined fwd/bwd rule - output_meta = get_chunking_meta(expand_node) - if output_meta is None: - return _bool_to_status(False) + if input_ft.numel() == 1: + # Scalar input: combined fwd/bwd rule + output_meta = get_chunking_meta(expand_node) + if output_meta is None: + return _bool_to_status(False) + return _bool_to_status(set_chunking_meta(input_node)) + + # How many leading dims are added by expand + dim_offset = len(output_shape) - len(input_shape) + + def is_expand_dim(out_dim: int) -> bool: + """Check if out_dim is a broadcast dimension (newly added or size 1 in input).""" + return out_dim < dim_offset or input_shape[out_dim - dim_offset] == 1 + + def fwd() -> PropagateStatus: + assert isinstance(input_node, Node) + input_meta = get_chunking_meta(input_node) + if input_meta is None: + return _bool_to_status(False) + # Fail if chunk_dim is an expand dimension (input size 1 broadcast to larger size) + if input_meta.chunk_dim is not None and is_expand_dim( + input_meta.chunk_dim + dim_offset + ): + return PropagateStatus.FAIL + return _bool_to_status(copy_chunking_meta(expand_node, input_meta)) - return _bool_to_status(set_chunking_meta(input_node)) + def bwd() -> PropagateStatus: + assert isinstance(input_node, Node) + output_meta = get_chunking_meta(expand_node) + if output_meta is None: + return _bool_to_status(False) + # Fail if chunk_dim is an expand dimension (input size 1 broadcast to larger size) + if output_meta.chunk_dim is not None and is_expand_dim(output_meta.chunk_dim): + return PropagateStatus.FAIL + return _bool_to_status(copy_chunking_meta(input_node, output_meta)) + + return fwd(), bwd() @register_propagate_rule( diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index d79bef415c4f7..a98793040a194 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -26,7 +26,15 @@ overlap_log = torch._logging.getArtifactLogger(__name__, "overlap") -BucketMode: TypeAlias = Literal["default", "custom_ops", "custom_ops_multidtype"] +BucketMode: TypeAlias = Literal[ + "default", "custom_ops", "custom_ops_multidtype", "coalesced" +] + + +def _default_bucket_mode() -> BucketMode: + from torch._inductor import config + + return config.aten_distributed_optimizations.bucket_mode or "default" # Helper functions moved to top for better organization @@ -71,13 +79,13 @@ def _compute_foreach_groups( Returns a flat list with -1 as group delimiter, or None if only one group exists. For example, groups [[0, 2], [1]] would be encoded as [0, 2, -1, 1]. """ - from torch.fx.experimental.symbolic_shapes import size_hint + from torch.fx.experimental.symbolic_shapes import guarding_hint_or_throw groups: defaultdict[tuple[torch.dtype, torch.dtype, tuple[int, ...]], list[int]] = ( defaultdict(list) ) for i, (ag_in, out_dtype) in enumerate(zip(ag_ins, out_dtypes)): - shape = tuple(size_hint(s) for s in ag_in.shape) + shape = tuple(guarding_hint_or_throw(s) for s in ag_in.shape) key = (ag_in.dtype, out_dtype, shape) groups[key].append(i) @@ -94,23 +102,36 @@ def _compute_foreach_groups( return result -def _schedulable_wait_node(node: torch.fx.Node) -> bool: - """ - Add additional check on if the wait node is schedulable - We should not schedule a fx node that is: - 1. wait on a collective that is not callable - 2. wait on a non-NCCL communication node +def _get_collective_node_from_wait(node: torch.fx.Node) -> torch.fx.Node | None: + """Given a wait node, return the collective it waits on. + + Handles both standard (wait -> collective) and coalesced + (wait -> getitem -> coalesced_collective) patterns. + Returns None if the node is not a wait on a recognized NCCL collective. """ if not is_wait_tensor(node): - return False - assert isinstance(node.args[0], torch.fx.Node) - if not isinstance(node.args[0].target, Callable): - return False - is_callable: bool = node.args[0].op == "call_function" + return None + arg = node.args[0] + assert isinstance(arg, torch.fx.Node) + if arg.op != "call_function": + return None + if arg.target is operator.getitem: + assert isinstance(arg.args[0], torch.fx.Node) + arg = arg.args[0] + if arg.op != "call_function": + return None + if not isinstance(arg.target, Callable): + return None # pyrefly: ignore [missing-attribute] - coll: NCCL_COLL = get_collective_type_from_kernel_name(node.args[0].target.name()) - is_collective: bool = coll != NCCL_COLL.UNSUPPORTED - return is_callable and is_collective + coll: NCCL_COLL = get_collective_type_from_kernel_name(arg.target.name()) + if coll == NCCL_COLL.UNSUPPORTED: + return None + return arg + + +def _schedulable_wait_node(node: torch.fx.Node) -> bool: + """Check if this wait node is schedulable (waits on a recognized NCCL collective).""" + return _get_collective_node_from_wait(node) is not None def _populate_node_meta( @@ -195,8 +216,9 @@ def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float: def bucket_all_gather( gm: torch.fx.GraphModule, bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None, - mode: BucketMode = "default", + mode: BucketMode | None = None, ) -> None: + mode = mode or _default_bucket_mode() if bucket_cap_mb_by_bucket_idx is None: from torch._inductor.fx_passes.bucketing import ( bucket_cap_mb_by_bucket_idx_default, # pyrefly: ignore [missing-module-attribute] @@ -212,8 +234,9 @@ def bucket_all_gather( def bucket_reduce_scatter( gm: torch.fx.GraphModule, bucket_cap_mb_by_bucket_idx: Callable[[int], float] | None = None, - mode: BucketMode = "default", + mode: BucketMode | None = None, ) -> None: + mode = mode or _default_bucket_mode() if bucket_cap_mb_by_bucket_idx is None: from torch._inductor.fx_passes.bucketing import ( bucket_cap_mb_by_bucket_idx_default, # pyrefly: ignore [missing-module-attribute] @@ -275,7 +298,7 @@ def get_collective_type(node: torch.fx.Node) -> str: def get_full_bucket_key( - node: torch.fx.Node, bucket_mode: BucketMode + node: torch.fx.Node, bucket_mode: BucketMode | None ) -> tuple[str, Any]: """Get the full bucket key including collective type and bucket key.""" return (get_collective_type(node), bucket_key(node, mode=bucket_mode)) @@ -453,7 +476,7 @@ def bucket_all_gather_by_mb( gm: torch.fx.GraphModule, bucket_cap_mb_by_bucket_idx: Callable[[int], float], filter_wait_node: Callable[[torch.fx.Node], bool] | None = None, - mode: BucketMode = "default", + mode: BucketMode | None = None, ) -> list[list[torch.fx.Node]]: """ Identifies all all_gather nodes and groups them into buckets, @@ -472,6 +495,7 @@ def bucket_all_gather_by_mb( Returns: list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes. """ + mode = mode or _default_bucket_mode() group_key_fn = ( _ag_group_key_multidtype if mode and "multidtype" in mode else _ag_group_key @@ -490,7 +514,7 @@ def bucket_reduce_scatter_by_mb( gm: torch.fx.GraphModule, bucket_cap_mb_by_bucket_idx: Callable[[int], float], filter_wait_node: Callable[[torch.fx.Node], bool] | None = None, - mode: BucketMode = "default", + mode: BucketMode | None = None, ) -> list[list[torch.fx.Node]]: """ Identifies all reduce_scatter nodes and groups them into buckets, @@ -507,8 +531,9 @@ def bucket_reduce_scatter_by_mb( Returns: list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes. """ + mode = mode or _default_bucket_mode() - assert "multidtype" not in mode, ( + assert mode is None or "multidtype" not in mode, ( "reduce scatter bucketing does not support multidtype" ) @@ -624,6 +649,29 @@ def reduce_scatter_merge_fn_to_trace( return new_outs +def reduce_scatter_merge_fn_coalesced( + rs_ins: list[torch.Tensor], + group_size: int, + group_name: str, + reduce_op: str, + reduce_dtype: torch.dtype, + device: torch.device, +) -> list[torch.Tensor]: + """Bucketed RS via NCCL's coalesced API (ncclGroupStart/End). + + Avoids cat-ing inputs into one buffer; instead passes the tensor list + directly to reduce_scatter_tensor_coalesced for zero-copy batching. + """ + rs_ins_flat = [x.view(-1) for x in rs_ins] + new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins] + + rs_outs = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( + rs_ins_flat, reduce_op, group_size, group_name + ) + rs_outs = [torch.ops.c10d_functional.wait_tensor(o) for o in rs_outs] + return [o.view(s) for o, s in zip(rs_outs, new_out_sizes)] + + def all_reduce_merge_fn_to_trace( ar_ins: list[torch.Tensor], group_name: str, @@ -825,9 +873,11 @@ def all_gather_merge_fn_to_trace( device = ag_ins[0].device new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device) new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel) - foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes) ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins] - torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened) + # Inductor fuses copy_(cat(...)) into 1 Triton kernel with no allocation for cat. + # _foreach_copy_(..., ag_ins_flattened) emits separate kernel per item, + # resulting in large number of small triton kernels to launch. + new_ag_in.copy_(torch.cat(ag_ins_flattened)) wait_tensor = torch.ops.c10d_functional.wait_tensor( torch.ops._c10d_functional.all_gather_into_tensor_out.default( new_ag_in, group_size, group_name, out=new_ag_out @@ -1066,10 +1116,11 @@ def process_collective_bucket( def merge_reduce_scatter_bucket( g: torch.fx.Graph, rs_nodes: list[torch.fx.Node], - mode: BucketMode = "default", + mode: BucketMode | None = None, insert_before: torch.fx.Node | None = None, wait_insertion_point: torch.fx.Node | None = None, ) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]: + mode = mode or _default_bucket_mode() # Validate bucket consistency rs0 = rs_nodes[0] rs0_val = rs0.meta["val"] @@ -1089,7 +1140,9 @@ def merge_reduce_scatter_bucket( # Choose merge function based on mode rs_merge_fn = reduce_scatter_merge_fn_to_trace - if mode and "custom_ops" in mode: + if mode == "coalesced": + rs_merge_fn = reduce_scatter_merge_fn_coalesced + elif mode and "custom_ops" in mode: rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops # Process bucket with lazy input collection @@ -1159,10 +1212,11 @@ def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]: def merge_all_gather_bucket( g: torch.fx.Graph, ag_nodes: list[torch.fx.Node], - mode: BucketMode = "default", + mode: BucketMode | None = None, insert_before: torch.fx.Node | None = None, wait_insertion_point: torch.fx.Node | None = None, ) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]: + mode = mode or _default_bucket_mode() from torch.distributed.distributed_c10d import _resolve_process_group ag0 = ag_nodes[0] @@ -1178,7 +1232,9 @@ def merge_all_gather_bucket( # Choose merge function based on mode ag_merge_fn = all_gather_merge_fn_to_trace - if mode is not None and "custom_ops" in mode: + if mode == "coalesced": + logger.info("coalesced bucket_mode not supported for all_gather, using default") + elif mode and "custom_ops" in mode: ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops # type: ignore[assignment] # Process bucket with lazy input collection @@ -1207,11 +1263,12 @@ def create_trace_args(bucket_ins: list[torch.fx.Node]) -> tuple[Any, ...]: def merge_reduce_scatter( gm: torch.fx.GraphModule, rs_buckets: list[list[torch.fx.Node]], - mode: BucketMode = "default", + mode: BucketMode | None = None, ) -> None: """ Merges specified buckets of reduce_scatter to joint reduce_scatter. """ + mode = mode or _default_bucket_mode() with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True): trace_structured( "artifact", @@ -1231,11 +1288,12 @@ def merge_reduce_scatter( def merge_all_gather( gm: torch.fx.GraphModule, ag_buckets: list[list[torch.fx.Node]], - mode: BucketMode = "default", + mode: BucketMode | None = None, ) -> None: """ Merges specified buckets of all_gather to joint all_gather. """ + mode = mode or _default_bucket_mode() with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True): trace_structured( "artifact", diff --git a/torch/_inductor/fx_passes/control_dependencies.py b/torch/_inductor/fx_passes/control_dependencies.py index 681ded9438fc4..d50a034665b2b 100644 --- a/torch/_inductor/fx_passes/control_dependencies.py +++ b/torch/_inductor/fx_passes/control_dependencies.py @@ -57,6 +57,13 @@ def __call__(self, additional_deps, subgraph, *args, **kwargs): control_deps = ControlDeps() +# control_deps wraps side-effecting ops (e.g. record_event, wait_event) +# and must not be eliminated by DCE even when its outputs are unused. +from torch.fx.node import has_side_effect + + +has_side_effect(control_deps) + # Register fake implementation for tracing @register_fake(control_deps) diff --git a/torch/_inductor/fx_passes/freezing_patterns.py b/torch/_inductor/fx_passes/freezing_patterns.py index 3e93063e6e4a5..48a940ae625fe 100644 --- a/torch/_inductor/fx_passes/freezing_patterns.py +++ b/torch/_inductor/fx_passes/freezing_patterns.py @@ -187,7 +187,7 @@ def check_concat_weights(match): if not all( inp.op == "get_attr" - and inp.meta["val"].shape == inps[0].meta["val"].shape + and inp.meta["val"].shape[:-1] == inps[0].meta["val"].shape[:-1] for inp in inps ): return False @@ -221,9 +221,10 @@ def matmul_fuse_pattern(inp, w1, w2, w3): return (inp @ w1, inp @ w2, inp @ w3) def matmul_replacement(inp, w1, w2, w3): - cat_t = torch.cat((w1, w2, w3), dim=1) + weights = (w1, w2, w3) + cat_t = torch.cat(weights, dim=1) mm = inp @ cat_t - return mm.chunk(3, dim=1) + return mm.split([w.size(1) for w in weights], dim=1) register_replacement( # pyrefly: ignore [bad-argument-type] @@ -243,9 +244,10 @@ def matmul_fuse_pattern_two(inp, w1, w2): return (inp @ w1, inp @ w2) def matmul_replacement_two(inp, w1, w2): - cat_t = torch.cat((w1, w2), dim=1) + weights = (w1, w2) + cat_t = torch.cat(weights, dim=1) mm = inp @ cat_t - return mm.chunk(2, dim=1) + return mm.split([w.size(1) for w in weights], dim=1) register_replacement( # pyrefly: ignore [bad-argument-type] @@ -269,9 +271,10 @@ def addmm_fuse_pattern_second(inp, w1, w2, w3, b1, b2, b3): ) def addmm_fuse_replacement_second(inp, w1, w2, w3, b1, b2, b3): - cat_w = torch.cat((w1, w2, w3), dim=1) + weights = (w1, w2, w3) + cat_w = torch.cat(weights, dim=1) cat_b = torch.cat((b1, b2, b3)) - return aten.addmm(cat_b, inp, cat_w).chunk(3, dim=1) + return aten.addmm(cat_b, inp, cat_w).split([w.size(1) for w in weights], dim=1) register_replacement( # pyrefly: ignore [bad-argument-type] diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index cfa64e84afe0b..4bb2a99455572 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -2,7 +2,6 @@ import functools import inspect import logging -import math import warnings import torch @@ -108,15 +107,15 @@ def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p): ) -def _sfdp_pattern_5(query, key, value, attn_mask): +def _sfdp_pattern_5(query, key, value, attn_mask, inv_scale): attn_weight = torch.softmax( - (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1 + (query @ key.transpose(-2, -1) / (inv_scale)) + attn_mask, dim=-1 ) # attn_weight = torch.dropout(attn_weight, dropout_p) return attn_weight @ value -def _sfdp_replacement_5(query, key, value, attn_mask): +def _sfdp_replacement_5(query, key, value, attn_mask, inv_scale): counters["inductor"]["fuse_attention"] += 1 return _scaled_dot_product_attention( query, @@ -124,19 +123,20 @@ def _sfdp_replacement_5(query, key, value, attn_mask): value, attn_mask=attn_mask.to(dtype=query.dtype), dropout_p=0.0, + scale=1.0 / inv_scale, is_causal=False, ) -def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p): +def _sfdp_pattern_6(query, key, value, attn_mask, inv_scale, dropout_p): attn_weight = torch.softmax( - (query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))) + attn_mask, dim=-1 + (query @ key.transpose(-2, -1) / inv_scale) + attn_mask, dim=-1 ) attn_weight = torch.dropout(attn_weight, dropout_p, True) return attn_weight @ value -def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p): +def _sfdp_replacement_6(query, key, value, attn_mask, inv_scale, dropout_p): counters["inductor"]["fuse_attention"] += 1 return _scaled_dot_product_attention( query, @@ -144,26 +144,28 @@ def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p): value, attn_mask=attn_mask.to(dtype=query.dtype), dropout_p=dropout_p, + scale=1.0 / inv_scale, is_causal=False, ) -def _sfdp_pattern_7(query, key, value, dropout_p): +def _sfdp_pattern_7(query, key, value, inv_scale, dropout_p): # in real workloads inputs to matmul are permuted # causing matmul to expand to a series of expand and clone calls # we want the same to happen during pattern tracing q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) - div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + div = q @ k.transpose(-2, -1) / inv_scale div = div.to(torch.float32) attn_weight = torch.softmax(div, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, True) attn_weight = attn_weight.to(torch.float16) + v = v.to(attn_weight.dtype) return attn_weight @ v -def _sfdp_replacement_7(query, key, value, dropout_p): +def _sfdp_replacement_7(query, key, value, inv_scale, dropout_p): # sdpa prefers inputs in permuted format # it makes a copy to put them in this format # if they aren't already @@ -179,23 +181,25 @@ def _sfdp_replacement_7(query, key, value, dropout_p): v, attn_mask=None, # attn_mask, dropout_p=dropout_p, + scale=1.0 / inv_scale, is_causal=False, ) -def _sfdp_pattern_8(query, key, value): +def _sfdp_pattern_8(query, key, value, inv_scale): # no dropout version of pattern 7 q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) - div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + div = q @ k.transpose(-2, -1) / inv_scale div = div.to(torch.float32) attn_weight = torch.softmax(div, dim=-1) attn_weight = attn_weight.to(torch.float16) + v = v.to(attn_weight.dtype) return attn_weight @ v -def _sfdp_replacement_8(query, key, value): +def _sfdp_replacement_8(query, key, value, inv_scale): counters["inductor"]["fuse_attention"] += 1 q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) @@ -206,24 +210,26 @@ def _sfdp_replacement_8(query, key, value): v, attn_mask=None, # attn_mask, dropout_p=0.0, + scale=1.0 / inv_scale, is_causal=False, ) -def _sfdp_pattern_9(query, key, value, dropout_p): +def _sfdp_pattern_9(query, key, value, inv_scale, dropout_p): q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) - q = q / math.sqrt(q.size(-1)) + q = q / inv_scale div = q @ k.transpose(-2, -1) div = div.to(torch.float32) attn_weight = torch.softmax(div, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, True) attn_weight = attn_weight.to(torch.float16) + v = v.to(attn_weight.dtype) return attn_weight @ v -def _sfdp_replacement_9(query, key, value, dropout_p): +def _sfdp_replacement_9(query, key, value, inv_scale, dropout_p): counters["inductor"]["fuse_attention"] += 1 q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) @@ -234,24 +240,26 @@ def _sfdp_replacement_9(query, key, value, dropout_p): v, attn_mask=None, # attn_mask, dropout_p=dropout_p, + scale=1.0 / inv_scale, is_causal=False, ) -def _sfdp_pattern_10(query, key, value): +def _sfdp_pattern_10(query, key, value, inv_scale): # no dropout version of 9 q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) v = value.permute(0, 2, 1, 3) - q = q / math.sqrt(q.size(-1)) + q = q / inv_scale div = q @ k.transpose(-2, -1) div = div.to(torch.float32) attn_weight = torch.softmax(div, dim=-1) attn_weight = attn_weight.to(torch.float16) + v = v.to(attn_weight.dtype) return attn_weight @ v -def _sfdp_replacement_10(query, key, value): +def _sfdp_replacement_10(query, key, value, inv_scale): counters["inductor"]["fuse_attention"] += 1 q = query.permute(0, 2, 1, 3) k = key.permute(0, 2, 1, 3) @@ -262,6 +270,7 @@ def _sfdp_replacement_10(query, key, value): v, attn_mask=None, # attn_mask, dropout_p=0.0, + scale=1.0 / inv_scale, is_causal=False, ) @@ -460,7 +469,7 @@ def _sfdp_replacement_17(query, key, value, attn_mask, inv_scale, dropout_p): ) -def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p): +def _sfdp_pattern_18(query, key, value, causal_mask, inv_scale, dropout_p): # for hf_GPT2 with dropout (introduces clone node) for inference # it also returns permuted key & value query = query.permute([0, 2, 1, 3]) @@ -469,7 +478,7 @@ def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p): attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) inv_scale = torch.full( [], - value.size(-1) ** 0.5, + inv_scale, dtype=attn_weights.dtype, device=attn_weights.device, ) @@ -489,7 +498,7 @@ def _sfdp_pattern_18(query, key, value, causal_mask, dropout_p): ) -def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p): +def _sfdp_replacement_18(query, key, value, causal_mask, inv_scale, dropout_p): counters["inductor"]["fuse_attention"] += 1 permuted_key = key.transpose(1, 2) permuted_value = value.transpose(1, 2) @@ -501,19 +510,19 @@ def _sfdp_replacement_18(query, key, value, causal_mask, dropout_p): attn_mask=causal_mask, dropout_p=dropout_p, is_causal=False, - scale=1.0 / math.sqrt(value.size(-1)), + scale=1.0 / inv_scale, ), permuted_key, permuted_value, ) -def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, dropout_p): +def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, inv_scale, dropout_p): # for token-classification+gpt2 / text-generation+gpt2 attn_weights = torch.matmul(query, key.permute(0, 1, 3, 2)) inv_scale = torch.full( [], - value.size(-1) ** 0.5, + inv_scale, dtype=attn_weights.dtype, device=attn_weights.device, ) @@ -527,7 +536,9 @@ def _sfdp_pattern_19(query, key, value, causal_mask, attn_mask, dropout_p): return torch.nn.functional.dropout(attn_weights, dropout_p).matmul(value) -def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p): +def _sfdp_replacement_19( + query, key, value, causal_mask, attn_mask, inv_scale, dropout_p +): counters["inductor"]["fuse_attention"] += 1 fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) attn_mask = torch.where(causal_mask, attn_mask, fill_value) @@ -538,18 +549,18 @@ def _sfdp_replacement_19(query, key, value, causal_mask, attn_mask, dropout_p): attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, - scale=1.0 / math.sqrt(value.size(-1)), + scale=1.0 / inv_scale, ) -def _sfdp_pattern_20(query, key, value, attn_mask, dropout_p): +def _sfdp_pattern_20(query, key, value, attn_mask, inv_scale, dropout_p): # for DistilBert with dropout transformers==4.44.2 q = query.permute([0, 2, 1, 3]) k = key.permute([0, 2, 1, 3]) v = value.permute([0, 2, 1, 3]) bs = q.size(0) k_len = k.size(-2) - q = q.div(math.sqrt(q.size(-1))) + q = q.div(inv_scale) scores = q @ k.transpose(-2, -1) fill_value = torch.full((), -float("inf"), dtype=query.dtype, device=query.device) attn_mask = (attn_mask == 0).view((bs, 1, 1, k_len)).expand_as(scores) @@ -561,7 +572,7 @@ def _sfdp_pattern_20(query, key, value, attn_mask, dropout_p): ) -def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p): +def _sfdp_replacement_20(query, key, value, attn_mask, inv_scale, dropout_p): counters["inductor"]["fuse_attention"] += 1 bs = query.size(0) n_head = query.size(2) @@ -578,7 +589,7 @@ def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p): attn_mask=attn_mask.to(dtype=torch.bool), dropout_p=dropout_p, is_causal=False, - scale=1.0 / math.sqrt(query.size(-1)), + scale=1.0 / inv_scale, ) @@ -811,6 +822,25 @@ def _sfdp_replacement_27(query, key, value, dropout_p): ) +def _sfdp_pattern_28(query, key, value, scale_factor, dropout_p): + # Visformer pattern + # same as pattern 4 but non-contiguous q/k/v + return _sfdp_pattern_4(query, key, value, scale_factor, dropout_p) + + +def _sfdp_replacement_28(query, key, value, scale_factor, dropout_p): + counters["inductor"]["fuse_attention"] += 1 + return _scaled_dot_product_attention( + query.contiguous(), + key.contiguous(), + value.contiguous(), + attn_mask=None, + dropout_p=dropout_p, + is_causal=False, + scale=scale_factor, + ) + + @functools.lru_cache(None) def _warn_tf32_disabled() -> None: if ( @@ -935,6 +965,14 @@ def _get_sfdp_patterns(input_device: torch.device | None = None): g_inp = functools.partial( torch.empty, (2, 4, 8, 16), device=device, requires_grad=True ) + # non-contiguous input to cover more patterns. + gn_inp = functools.partial( + torch.empty_strided, + (2, 6, 16, 8), + (2304, 128, 1, 16), + device=device, + requires_grad=True, + ) # attn_mask b_inp = functools.partial(torch.empty, (1, 1, 8, 8), device=device) m_inp = functools.partial(torch.empty, (2, 1, 1, 4), device=device) @@ -945,6 +983,8 @@ def _get_sfdp_patterns(input_device: torch.device | None = None): # workaround https://github.com/pytorch/pytorch/issues/97894 # 0.113377 is a "magic" value that lets us recover the lost input arg relationship d = {"dropout_p": 0.113377} + s = {"inv_scale": 0.66666} + sd = {"inv_scale": 0.66666, "dropout_p": 0.113377} # we could also generate all these patterns in 3d.. TODO g_3d_inp = functools.partial( @@ -963,6 +1003,7 @@ def _get_sfdp_patterns(input_device: torch.device | None = None): # but will not in float, so we generate a pattern for both for dtype in [torch.float, torch.half]: g = functools.partial(g_inp, dtype=dtype) + gn = functools.partial(gn_inp, dtype=dtype) b = functools.partial(b_inp, dtype=dtype) b_float = functools.partial(b_inp, dtype=torch.float) b_bool = functools.partial(b_inp, dtype=torch.bool) @@ -1010,42 +1051,42 @@ def _get_sfdp_patterns(input_device: torch.device | None = None): _sfdp_pattern_5, _sfdp_replacement_5, [g(), g(), g(), b()], - {}, + s, _sfdp_params_check, ), ( _sfdp_pattern_6, _sfdp_replacement_6, [g(), g(), g(), b()], - d, + sd, _sfdp_params_check, ), ( _sfdp_pattern_7, _sfdp_replacement_7, [g(), g(), g()], - d, + sd, _sfdp_params_check, ), ( _sfdp_pattern_8, _sfdp_replacement_8, [g(), g(), g()], - {}, + s, _sfdp_params_check, ), ( _sfdp_pattern_9, _sfdp_replacement_9, [g(), g(), g()], - d, + sd, _sfdp_params_check, ), ( _sfdp_pattern_10, _sfdp_replacement_10, [g(), g(), g()], - {}, + s, _sfdp_params_check, ), ( @@ -1113,29 +1154,29 @@ def _get_sfdp_patterns(input_device: torch.device | None = None): _sfdp_pattern_18, _sfdp_replacement_18, [g(), g(), g(), m_bool()], - d, + sd, _sfdp_params_check, ), ( _sfdp_pattern_18, _sfdp_replacement_18, [g_bs1(), g_bs1(), g_bs1(), m_bs1_bool()], - d, + sd, _sfdp_params_check, ), ( _sfdp_pattern_19, _sfdp_replacement_19, [g(), g(), g(), b_bool(), b_float()], - d, + sd, _sfdp_params_check, ), ( _sfdp_pattern_20, _sfdp_replacement_20, [g(), g(), g(), m_2d()], - d, - _sfdp_extra_check(aten.div.Tensor), + sd, + _sfdp_params_check, ), ( _sfdp_pattern_21, @@ -1228,6 +1269,13 @@ def _get_sfdp_patterns(input_device: torch.device | None = None): d, _sfdp_extra_check(disable_cuda=True), ), + ( + _sfdp_pattern_28, + _sfdp_replacement_28, + [gn(), gn(), gn(), c()], + d, + _sfdp_extra_check(aten.mul.Tensor), + ), ] mask_fp32_patterns = ["pattern_16"] if dtype == torch.half: @@ -1287,15 +1335,17 @@ def _get_sfdp_patterns(input_device: torch.device | None = None): "skip_duplicates": True, }, ) - + inference_workaround = {} if workaround: - assert len(workaround) == 1 and "dropout_p" in workaround - # functools.partial insufficient because we look at signature downstream - pattern = partialize_and_update_signature(pattern, dropout_p=0.0) - replacement = partialize_and_update_signature( - replacement, dropout_p=0.0 - ) - workaround = {} + assert len(workaround) <= 2 + if "inv_scale" in workaround: + inference_workaround["inv_scale"] = workaround["inv_scale"] + if "dropout_p" in workaround: + # functools.partial insufficient because we look at signature downstream + pattern = partialize_and_update_signature(pattern, dropout_p=0.0) + replacement = partialize_and_update_signature( + replacement, dropout_p=0.0 + ) inference_name = name + "_inference" yield ( @@ -1307,7 +1357,7 @@ def _get_sfdp_patterns(input_device: torch.device | None = None): "trace_fn": fwd_only, "pass_dicts": patterns, "extra_check": extra_check, - "scalar_workaround": workaround, + "scalar_workaround": inference_workaround, # with dropout turned into clone, we end up with a number of # semantically identical graphs "skip_duplicates": True, diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 8782a5402538e..bd3d786d49957 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -15,7 +15,9 @@ from .. import config from ..pattern_matcher import ( CallFunctionVarArgs, + CallMethodVarArgs, get_arg_value, + MatchResult, stable_topological_sort, ) from ..utils import OPTIMUS_EXCLUDE_POST_GRAD @@ -621,7 +623,7 @@ def _getitem_args(self, getitem_node: torch.fx.Node): return getitem_node.args[0] def match(self, node: torch.fx.Node): - if CallFunctionVarArgs(torch.nn.functional.linear).match( + if CallFunctionVarArgs([torch.nn.functional.linear, torch._C._nn.linear]).match( node ) and is_linear_node_can_be_fused(node): input = get_arg_value(node, 0, "input") @@ -1057,7 +1059,7 @@ def __init__(self, op, **kwargs): def match(self, node: torch.fx.Node): input = get_arg_value(node, 0, "input") - if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + if self._match_op(node) and is_node_meta_valid(node): # check the input has the same shape and its users have the same target # check all clamp operators have the same min and max values, and # nan_to_num operators use the same default value. @@ -1071,6 +1073,9 @@ def match(self, node: torch.fx.Node): group_key = None return group_key + def _match_op(self, node: torch.fx.Node) -> MatchResult: + return CallFunctionVarArgs(self.op).match(node) + def fuse(self, graph: torch.fx.GraphModule, subset: list[torch.fx.Node]): batch_nodes = [] batch_inputs = [] @@ -1134,6 +1139,11 @@ class BatchDetachPreGradFusion(BatchMathOpsPreGradFusion): def __init__(self, **kwargs): super().__init__(torch.detach, **kwargs) + def _match_op(self, node: torch.fx.Node) -> MatchResult: + return CallFunctionVarArgs(torch.detach).match(node) or CallMethodVarArgs( + "detach" + ).match(node) + @register_fusion("batch_nan_to_num") class BatchNanToNumPreGradFusion(BatchMathOpsPreGradFusion): diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 017db2d471b8f..4e1d384f66795 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py @@ -55,6 +55,27 @@ ] +def _is_lossless_fp_widening_cast( + src_dtype: torch.dtype, dst_dtype: torch.dtype +) -> bool: + if src_dtype == dst_dtype: + return True + + if not (src_dtype.is_floating_point and dst_dtype.is_floating_point): + return False + + src_info = torch.finfo(src_dtype) + dst_info = torch.finfo(dst_dtype) + + # A floating-point cast is only pointless if the first conversion cannot + # discard precision or range from the source values. + return ( + dst_info.eps <= src_info.eps + and dst_info.max >= src_info.max + and dst_info.tiny <= src_info.tiny + ) + + @init_once_fakemode def lazy_init(input_device: torch.device | None = None): from .fuse_attention import _sfdp_init @@ -362,12 +383,10 @@ def _deduce_value(self, node: torch.fx.Node): # handle before view ops because this changes value if node.target is aten.view.dtype: (input_tensor, output_dtype), kwargs = self.fetch_args_kwargs_from_env(node) - # view.dtype fails on 0-d tensors when element size changes - # (e.g., 0-d complex tensors can't be viewed as float) - if ( - input_tensor.ndim == 0 - and input_tensor.element_size() != output_dtype.itemsize - ): + # view.dtype with different element sizes changes element count + # (e.g., complex64 [1+0j] viewed as float32 becomes [1.0, 0.0]), + # making uniform values non-uniform. Also crashes on 0-d tensors. + if input_tensor.element_size() != output_dtype.itemsize: return self.unknown_value return super(ConstantFolder, self).run_node(node) @@ -524,7 +543,7 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule): "dtype": fake_tensor.dtype, "layout": torch.strided, "device": fake_tensor.device, - "pin_memory": False, + "pin_memory": node.kwargs.get("pin_memory", False), }, ) @@ -619,7 +638,8 @@ def canonicalize_aten_ir_passes(gm: torch.fx.GraphModule): def joint_graph_passes( - graph: torch.fx.GraphModule, input_device: torch.device | None = None + graph: torch.fx.GraphModule, + input_device: torch.device | None = None, ): """ Run FX transformations on the joint forwards+backwards graph. @@ -629,7 +649,7 @@ def joint_graph_passes( subsystem="joint_graph_passes", ) - lazy_init(input_device) # type: ignore[call-arg] + lazy_init(input_device) count = 0 # must occur before other passes @@ -763,7 +783,15 @@ def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtyp graph = match.graph node = match.output_node() allowed = torch.float16, torch.bfloat16, torch.float32, torch.float64 - if dtype1 in allowed and dtype2 in allowed: + arg_val = arg.meta.get("val", None) + if not isinstance(arg_val, torch.Tensor): + return + + if arg_val.dtype in allowed and dtype1 in allowed and dtype2 in allowed: + if config.emulate_precision_casts and not _is_lossless_fp_widening_cast( + arg_val.dtype, dtype1 + ): + return repl = graph.call_function( torch.ops.prims.convert_element_type.default, (arg, dtype2) ) diff --git a/torch/_inductor/fx_passes/low_contention_collectives.py b/torch/_inductor/fx_passes/low_contention_collectives.py new file mode 100644 index 0000000000000..663f47d5f91f1 --- /dev/null +++ b/torch/_inductor/fx_passes/low_contention_collectives.py @@ -0,0 +1,252 @@ +from __future__ import annotations + +import logging +import warnings + +import torch +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) + + +def _get_collective_info(node): + """Return (is_ag, group_name) if node is an AG/RS collective, else None.""" + from torch._inductor.fx_passes.bucketing import ( + is_all_gather_into_tensor, + is_reduce_scatter_tensor, + ) + from torch._inductor.fx_passes.overlap_scheduling import get_group_name + + if is_all_gather_into_tensor(node): + return True, get_group_name(node) + if is_reduce_scatter_tensor(node): + return False, get_group_name(node) + return None + + +def replace_collectives_with_low_contention( + graph: torch.fx.Graph, +) -> None: + """Replace FSDP collectives with copy-engine symm_mem variants.""" + symm_mem = torch.ops.symm_mem + + collectives = [] + groups: OrderedSet[str] = OrderedSet() + for node in list(graph.nodes): + info = _get_collective_info(node) + if info is None: + continue + is_ag, group_name = info + collectives.append((node, is_ag, group_name)) + groups.add(group_name) + + if not collectives: + return + + # Some group names can't be resolved at compile time — skip them. + valid_groups: OrderedSet[str] = OrderedSet() + for group_name in groups: + if _enable_symm_mem(group_name): + valid_groups.add(group_name) + + # Filter to collectives whose groups we can actually resolve + collectives = [ + (node, is_ag, gn) for node, is_ag, gn in collectives if gn in valid_groups + ] + if not collectives: + return + + from torch._inductor import config + + min_bytes = config.aten_distributed_optimizations.low_contention_min_bytes_per_rank + + node_positions = {n: i for i, n in enumerate(graph.nodes)} + + replacements = 0 + skipped_small = 0 + skipped_no_overlap = 0 + skipped_nvlink_contention = 0 + for node, is_ag, group_name in collectives: + coll_type = "AG" if is_ag else "RS" + + # Size filter: LC barrier overhead dominates for small messages + if min_bytes > 0: + per_rank_bytes = _get_per_rank_bytes(node, is_ag) + if per_rank_bytes is not None and per_rank_bytes < min_bytes: + skipped_small += 1 + log.debug( + "LC skip %s %s: size %d < min_bytes %d", + coll_type, + node.name, + per_rank_bytes, + min_bytes, + ) + continue + + # Skip collectives with no compute to hide behind + if not _has_compute_bound_overlap(node, graph, node_positions): + skipped_no_overlap += 1 + log.debug("LC skip %s %s: no compute-bound overlap", coll_type, node.name) + continue + + # Skip if other groups' NCCL collectives overlap on NVLink + if _has_other_group_collectives(node, group_name, graph, node_positions): + skipped_nvlink_contention += 1 + log.debug( + "LC skip %s %s: overlaps other-group collectives (NVLink contention)", + coll_type, + node.name, + ) + continue + + _replace_collective(node, graph, symm_mem, is_ag, group_name) + replacements += 1 + + log.info( + "Replaced %d/%d FSDP collectives " + "(skipped_small=%d, skipped_no_overlap=%d, " + "skipped_nvlink_contention=%d, min_bytes=%d)", + replacements, + len(collectives), + skipped_small, + skipped_no_overlap, + skipped_nvlink_contention, + min_bytes, + ) + + +def _enable_symm_mem(group_name): + """Try to enable symmetric memory for a group. Returns True on success.""" + from torch.distributed._symmetric_memory import ( + enable_symm_mem_for_group, + is_symm_mem_enabled_for_group, + ) + + if is_symm_mem_enabled_for_group(group_name): + return True + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", FutureWarning) + enable_symm_mem_for_group(group_name) + return True + except (TypeError, RuntimeError, KeyError) as e: + log.debug("LC: cannot enable symm_mem for group %s: %s", group_name, e) + return False + + +def _replace_collective(node, graph, symm_mem, is_ag, group_name): + input_node = node.args[0] + if is_ag: + target = symm_mem._low_contention_all_gather.default + args = (input_node, group_name) + else: + reduce_op = node.args[1] + target = symm_mem._low_contention_reduce_scatter.default + args = (input_node, reduce_op, group_name) + + with graph.inserting_before(node): + new_node = graph.call_function(target, args=args) + new_node.meta.update(node.meta) + node.replace_all_uses_with(new_node) + graph.erase_node(node) + + +def _get_per_rank_bytes(node, is_ag): + """Return per-rank message bytes for a collective, or None if unknown.""" + input_val = node.args[0].meta.get("val") if node.args else None + if not isinstance(input_val, torch.Tensor): + return None + total_bytes = input_val.nelement() * input_val.element_size() + if is_ag: + return total_bytes + # For RS, input is the full tensor; per-rank = total / group_size + group_size = node.args[2] if len(node.args) > 2 else None + if not isinstance(group_size, int) or group_size <= 0: + return None + return total_bytes // group_size + + +def _has_compute_bound_overlap(start_node, graph, node_positions): + """Check if compute-bound ops exist between collective start and wait.""" + from torch._inductor.fx_passes.overlap_scheduling import is_compute_node + + wait_node = _find_wait_for_collective(start_node) + if wait_node is None: + return False + + start_pos = node_positions[start_node] + wait_pos = node_positions[wait_node] + + for node in graph.nodes: + pos = node_positions[node] + if pos <= start_pos or pos >= wait_pos: + continue + if is_compute_node(node): + return True + return False + + +def _has_other_group_collectives(start_node, group_name, graph, node_positions): + """Check if other groups' collectives overlap, competing for NVLink.""" + wait_node = _find_wait_for_collective(start_node) + if wait_node is None: + return False + + start_pos = node_positions[start_node] + wait_pos = node_positions[wait_node] + + for node in graph.nodes: + pos = node_positions[node] + if pos <= start_pos or pos >= wait_pos: + continue + info = _get_collective_info(node) + if info is not None: + _, other_group = info + if other_group != group_name: + log.debug( + "LC contention %s: found %s (group %s) between start/wait", + start_node.name, + node.name, + other_group, + ) + return True + return False + + +def _is_wait_tensor(node): + """Check if node is a wait_tensor op (direct or wrapped in ControlDeps).""" + if node.op != "call_function": + return False + if node.target is torch.ops._c10d_functional.wait_tensor.default: + return True + # Handles public namespace (c10d_functional.wait_tensor) and + # ControlDeps-wrapped wait_tensor (from TBB manual scheduling) + return "wait_tensor" in node.name + + +def _find_wait_for_collective(start_node): + """Find the wait_tensor node for a collective. + + Handles multiple graph patterns: + 1. Direct: start -> wait_tensor(start) + 2. _out variant: start(out=buf) -> wait_tensor(buf) + 3. ControlDeps-wrapped: start -> control_deps(wait_tensor_subgraph, start) + """ + for user in start_node.users: + if _is_wait_tensor(user): + return user + + # For _out variants, check users of the out-buffer keyword argument. + c10d = torch.ops._c10d_functional + if start_node.target in ( + c10d.all_gather_into_tensor_out.default, + c10d.reduce_scatter_tensor_out.default, + ): + out_buf = start_node.kwargs.get("out") + if isinstance(out_buf, torch.fx.Node): + for user in out_buf.users: + if _is_wait_tensor(user): + return user + + return None diff --git a/torch/_inductor/fx_passes/memory_estimator.py b/torch/_inductor/fx_passes/memory_estimator.py index ca723bb98bdd4..c939621679e09 100644 --- a/torch/_inductor/fx_passes/memory_estimator.py +++ b/torch/_inductor/fx_passes/memory_estimator.py @@ -6,7 +6,7 @@ import torch import torch.fx as fx -from torch.fx.experimental.symbolic_shapes import size_hint +from torch.fx.experimental.symbolic_shapes import optimization_hint from torch.utils._ordered_set import OrderedSet from torch.utils._pytree import tree_map_only @@ -144,9 +144,7 @@ def get_storages_last_used( def _size_of_default(num_bytes: int | torch.SymInt) -> int: - return size_hint( - num_bytes, fallback=torch._inductor.config.unbacked_symint_fallback - ) + return optimization_hint(num_bytes) def device_filter(device: torch.device) -> bool: @@ -387,9 +385,7 @@ def get_current_memory_bytes(self) -> int: def _get_storage_size(self, storage_key: StorageKey) -> int: """Get the size of a storage in bytes, handling symbolic shapes.""" size_bytes = storage_key.storage.nbytes() - return size_hint( - size_bytes, fallback=torch._inductor.config.unbacked_symint_fallback - ) + return optimization_hint(size_bytes) def _get_storages_freed_by_node(self, node: fx.Node) -> OrderedSet[StorageKey]: """Get storages that would be freed if we schedule this node.""" diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index d61ef27840fa5..c1f0a7284484e 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -6,7 +6,7 @@ import torch from torch._dynamo.utils import counters -from torch.fx.experimental.symbolic_shapes import has_free_symbols +from torch.fx.experimental.symbolic_shapes import has_free_symbols, optimization_hint from torch.utils._ordered_set import OrderedSet from .. import ir, mkldnn_ir @@ -104,7 +104,7 @@ def pack_linear_weight( # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance. packed_weight_inputs = ( transpose_weight_node, - batch_size.node.shape_env.size_hint(batch_size.node.expr) + optimization_hint(batch_size) if has_free_symbols(batch_size) else batch_size, ) diff --git a/torch/_inductor/fx_passes/numeric_utils.py b/torch/_inductor/fx_passes/numeric_utils.py index d1db82f21f7ec..486d047852b8f 100644 --- a/torch/_inductor/fx_passes/numeric_utils.py +++ b/torch/_inductor/fx_passes/numeric_utils.py @@ -207,7 +207,7 @@ def numeric_check_if_enabled( precision=precision, ) except Exception as e: - logger.warning( # noqa: G200 + logger.warning( "Runtime numeric check failed in pre grad fx passes with error: %s", e ) traceback.print_exc() diff --git a/torch/_inductor/fx_passes/overlap_manual_scheduling.py b/torch/_inductor/fx_passes/overlap_manual_scheduling.py index 1a5e3cd585de1..f0e2c855aaa9d 100644 --- a/torch/_inductor/fx_passes/overlap_manual_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_manual_scheduling.py @@ -4,11 +4,13 @@ from collections import Counter, defaultdict from typing import Any, TYPE_CHECKING -import torch -import torch.fx as fx +import torch # noqa: TC001 +import torch.fx as fx # noqa: TC001 from torch._dynamo.graph_deduplication import _stable_topological_sort from torch._inductor.fx_passes.bucketing import ( + _get_collective_node_from_wait, _schedulable_wait_node, + BucketMode, is_all_gather_into_tensor as is_all_gather, is_fsdp_all_gather, is_fsdp_reduce_scatter, @@ -53,6 +55,9 @@ def __init__( ): super().__init__(*args, **kwargs) self.node_to_wait_map: dict[fx.Node, fx.Node] = defaultdict() + # Maps bucketed nodes to their type string, scoped to this bucketer + # instance so metadata doesn't leak across separate invocations. + self.bucketed_node_types: dict[fx.Node, str] = {} def _bucket_group(self, coll_nodes: list[fx.Node]) -> None: assert len(coll_nodes) > 0, "bucketed coll_nodes should have nonzero node" @@ -72,7 +77,7 @@ def _bucket_group(self, coll_nodes: list[fx.Node]) -> None: coll_nodes, wait_insertion_point=first_wait, insert_before=next_node, - mode="custom_ops", + mode=self.bucket_mode, ) elif is_reduce_scatter(first): new_nodes, replacements = merge_reduce_scatter_bucket( @@ -80,7 +85,7 @@ def _bucket_group(self, coll_nodes: list[fx.Node]) -> None: coll_nodes, wait_insertion_point=first_wait, insert_before=next_node, - mode="custom_ops", + mode=self.bucket_mode, ) else: raise ValueError( @@ -89,25 +94,33 @@ def _bucket_group(self, coll_nodes: list[fx.Node]) -> None: logger.debug(f"bucketing nodes: {coll_nodes} into {new_nodes}") # noqa: G004 - # Identify the new wait and start - new_waits = [n for n in new_nodes if _schedulable_wait_node(n)] - assert len(new_waits) == 1, f"Expected exactly one new wait, got {new_waits}" - new_wait = new_waits[0] - new_start = new_wait.args[0] - assert isinstance(new_start, fx.Node) - - # Set manual bucketing-specific metadata - # Note: Generic metadata (nn_module_stack, fwd_nn_module_stack, custom, stack_trace) - # is now preserved automatically by the bucketing functions in bucketing.py + # Identify the new wait(s) and their collective start in a single pass + wait_to_start = { + n: start + for n in new_nodes + if (start := _get_collective_node_from_wait(n)) is not None + } + assert len(wait_to_start) >= 1, ( + f"Expected at least one new wait, got none in {new_nodes}" + ) + new_waits = list(wait_to_start) + new_start: fx.Node = wait_to_start[new_waits[0]] + # Use last wait as the canonical wait for scheduling (same node when len == 1) + new_wait = new_waits[-1] + + # Track bucketed node types on this bucketer instance so it doesn't leak + # when the same graph is processed by multiple ManualOverlapScheduler + # invocations (e.g. separate forward and backward passes). node_type = ( "bucketed_all_gather" if is_all_gather(first) else "bucketed_reduce_scatter" ) + wait_set = OrderedSet(new_waits) for n in new_nodes: - if n == new_wait: - node_type = node_type + "_wait" - n.meta["manual_bucket_node_type"] = node_type - if "wait" in node_type: + if n in wait_set: + self.bucketed_node_types[n] = node_type + "_wait" self.node_to_wait_map[n] = new_wait + elif n is new_start: + self.bucketed_node_types[n] = node_type def manual_bucket_collectives(self, nodes: list[fx.Node]) -> None: """ @@ -143,7 +156,10 @@ def __init__( module_bucket_plans: list[list[str] | str], insert_overlap_deps: bool, module_stack_fn: Callable[[fx.Node], list[tuple[str, type[Any]]]] | None = None, + bucket_mode: BucketMode | None = None, ): + # Manual overlap historically used "custom_ops" mode for bucketing + bucket_mode = bucket_mode or "custom_ops" super().__init__( gm, max_in_flight_gb=0.0, @@ -156,6 +172,7 @@ def __init__( collective_estimator="analytical", max_memory_increase_gb=None, max_memory_increase_ratio=None, + bucket_mode=bucket_mode, ) self.module_bucket_plans = module_bucket_plans self.nodes_in_subgraph: list[list[fx.Node]] = [] @@ -164,6 +181,7 @@ def __init__( graph=self.graph, collective_info=self.collective_info, scheduled=OrderedSet(self.graph.nodes), + bucket_mode=bucket_mode, ) self.insert_overlap_deps = insert_overlap_deps @@ -246,7 +264,7 @@ def _manual_reorder_graph(self) -> None: # schedule reduce scatter normally in self._schedule while self.on_path_ready: _, node = heapq.heappop(self.on_path_ready) - node_type = node.meta.get("manual_bucket_node_type", "") + node_type = self.bucketer.bucketed_node_types.get(node, "") if node in self.scheduled: continue @@ -273,7 +291,7 @@ def _manual_reorder_graph(self) -> None: last_compute: fx.Node | None = None for node in self.scheduled: - node_type = node.meta.get("manual_bucket_node_type", "") + node_type = self.bucketer.bucketed_node_types.get(node, "") if node_type == "bucketed_all_gather": picked_ag.append(node) continue @@ -352,6 +370,7 @@ def manual_overlap_bucketing( module_bucket_plans: list[list[str] | str], insert_overlap_deps: bool = False, module_stack_fn: Callable[[fx.Node], list[tuple[str, type[Any]]]] | None = None, + bucket_mode: BucketMode | None = None, ) -> torch.fx.GraphModule: """Schedule nodes based on user specifications in module_bucket_plans The manual overlapping consists of two steps: @@ -370,10 +389,15 @@ def manual_overlap_bucketing( See the `module_stack_fn` parameter in `make_graph_view` (graph_view.py) for detailed documentation on signature, return format, and usage examples. + bucket_mode: Bucket mode for collective bucketing. None uses default. """ # decode abbreviated FQNs to actual FQNs overlapped_gm = ManualOverlapScheduler( - gm, module_bucket_plans, insert_overlap_deps, module_stack_fn + gm, + module_bucket_plans, + insert_overlap_deps, + module_stack_fn, + bucket_mode=bucket_mode, ).run() overlapped_gm.recompile() return overlapped_gm diff --git a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py index 3186ea8a4477a..0a58be6849159 100644 --- a/torch/_inductor/fx_passes/overlap_preserving_bucketer.py +++ b/torch/_inductor/fx_passes/overlap_preserving_bucketer.py @@ -9,6 +9,8 @@ from torch._dynamo.utils import counters from torch._inductor.augmented_graph_helper import AugmentedGraphHelper from torch._inductor.fx_passes.bucketing import ( + _default_bucket_mode, + _get_collective_node_from_wait, _schedulable_wait_node, BucketMode, get_full_bucket_key, @@ -47,7 +49,7 @@ def __init__(self, node1: fx.Node, node2: fx.Node) -> None: def __call__(self, reason: str, *args: Any) -> None: if bucket_log.isEnabledFor(logging.DEBUG): bucket_log.debug( - "cannot bucket %s with %s: " + reason, # noqa: G003 + "cannot bucket %s with %s: " + reason, self.name1, self.name2, *args, @@ -139,7 +141,7 @@ def __init__( max_coll_distance: int = 1000, insert_overlap_deps: bool = False, collective_bucketing: bool = True, - bucket_mode: BucketMode = "custom_ops_multidtype", + bucket_mode: BucketMode | None = None, bucket_exposed_first: bool | None = None, region_of: dict[fx.Node, Any] | None = None, bucket_only_internode_comms: bool = False, @@ -153,7 +155,7 @@ def __init__( self.insert_overlap_deps = insert_overlap_deps self.bucket_exposed_first = bucket_exposed_first self.bucket_only_internode_comms = bucket_only_internode_comms - self.bucket_mode = bucket_mode + self.bucket_mode = bucket_mode or _default_bucket_mode() self.collective_bucketing = collective_bucketing self.region_of: dict[fx.Node, Any] = region_of or {} self.node_to_event: dict[fx.Node, PGEvent] = {} @@ -227,8 +229,8 @@ def build_timeline(self, pg: str) -> PGEvent | None: node_type = "starts" hiding_nodes |= self.collective_info[node].hiding_nodes elif _schedulable_wait_node(node): - wait_input = node.args[0] - if isinstance(wait_input, fx.Node) and get_group_name(wait_input) == pg: + wait_coll = _get_collective_node_from_wait(node) + if isinstance(wait_coll, fx.Node) and get_group_name(wait_coll) == pg: node_type = "waits" # Wait for a different PG but hiding a collective on this PG elif node in hiding_nodes: @@ -336,6 +338,10 @@ def _apply_deps_and_effect_tokens(self) -> None: """Apply topological sort and effect tokens to preserve overlap.""" from torch._dynamo.graph_deduplication import _stable_topological_sort + # Clean up any remaining erased node references and cycles + self.aug_graph.remove_erased_extra_deps() + autofix = torch._inductor.config.aten_distributed_optimizations.overlap_scheduling_autofix_cycles + self.aug_graph.check_and_maybe_autofix_cyclic_extra_deps(autofix=autofix) additional_deps = self.aug_graph.get_all_extra_deps() for n, deps in additional_deps.items(): @@ -583,10 +589,9 @@ def _get_intervals( coll = event.node # For wait events, look up the start node from the event's args elif event.is_wait: - wait_input = event.node.args[0] - if not isinstance(wait_input, fx.Node): + coll = _get_collective_node_from_wait(event.node) + if coll is None: return None, [] - coll = wait_input else: return None, [] @@ -998,13 +1003,13 @@ def _apply_bucket(self, bucket_info: CollBucket) -> None: self.graph, bucket, insert_before=next_node, - mode="custom_ops", + mode=self.bucket_mode, ) elif is_all_reduce_tensor(bucket[0]): new_nodes, replacements = merge_all_reduce_bucket( self.graph, bucket, - mode="custom_ops", + mode=self.bucket_mode, insert_before=next_node, ) else: @@ -1013,39 +1018,51 @@ def _apply_bucket(self, bucket_info: CollBucket) -> None: self.graph, bucket, insert_before=next_node, - mode="custom_ops", + mode=self.bucket_mode, ) - # Get new nodes - new_waits = [n for n in new_nodes if _schedulable_wait_node(n)] - assert len(new_waits) == 1 - - new_wait = new_waits[0] - new_start = new_wait.args[0] - assert isinstance(new_start, fx.Node) + # Identify the new wait(s) and their collective start in a single pass + wait_to_start = { + n: start + for n in new_nodes + if (start := _get_collective_node_from_wait(n)) is not None + } + new_waits = list(wait_to_start) # Create mapping of all erased nodes to their replacements erased_to_new: dict[fx.Node, fx.Node | None] = {} - for old_start in old_starts: - erased_to_new[old_start] = new_start - for old_wait in old_waits: - erased_to_new[old_wait] = new_wait + new_start = wait_to_start[new_waits[0]] + if len(new_waits) == 1: + # Standard bucketing: single start + single wait + new_wait = new_waits[0] + for old_start in old_starts: + erased_to_new[old_start] = new_start + for old_wait in old_waits: + erased_to_new[old_wait] = new_wait + else: + # Coalesced bucketing: single start + N waits (one per original tensor) + assert len(new_waits) == len(old_waits) + for old_start in old_starts: + erased_to_new[old_start] = new_start + erased_to_new.update(dict(zip(old_waits, new_waits))) # Handle convert_element_type nodes that were fused and erased # The bucketed operation may have a _pre_bucket op that handles dtype conversion if fused_convert_dtypes: - # all gather bucketing may fuse in dtype conversion into the bucketing - # if so, we need to transfer hiding deps from the old dtype conversion - # to the new bucketing node - new_convert_dtypes_node = new_start.kwargs["out"] - assert isinstance(new_convert_dtypes_node, fx.Node) - assert ( - new_convert_dtypes_node.target + # In custom_ops mode, the _pre_bucket_all_gather node handles dtype conversion + # In default mode, convert nodes are just erased — map them to new_start + new_convert_dtypes_node = new_start.kwargs.get("out") + if ( + isinstance(new_convert_dtypes_node, fx.Node) + and new_convert_dtypes_node.target == torch.ops.bucketing._pre_bucket_all_gather.default - ) + ): + replacement = new_convert_dtypes_node + else: + replacement = new_start for n in fused_convert_dtypes: - erased_to_new[n] = new_convert_dtypes_node + erased_to_new[n] = replacement # Transfer all dependencies from old nodes to new nodes self.aug_graph.transfer_erased_node_deps(erased_to_new) @@ -1063,6 +1080,7 @@ def finalize_overlap_scheduling( region_of: dict[fx.Node, Any] | None = None, bucket_exposed_first: bool | None = None, bucket_only_internode_comms: bool = False, + bucket_mode: BucketMode | None = None, ) -> None: """ Finalize overlap scheduling by applying deps, inlining fusions, and optionally bucketing. @@ -1094,5 +1112,6 @@ def finalize_overlap_scheduling( bucket_exposed_first=bucket_exposed_first, bucket_only_internode_comms=bucket_only_internode_comms, region_of=region_of, + bucket_mode=bucket_mode, ) bucketer.bucket_collectives() diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index b19affd28cc03..9b52b8c18e2a5 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -14,6 +14,8 @@ from torch._inductor import config from torch._inductor.comm_analysis import estimate_fx_collective_memory_footprint from torch._inductor.fx_passes.bucketing import ( + _default_bucket_mode, + _get_collective_node_from_wait, _schedulable_wait_node, bucket_key, BucketMode, @@ -136,6 +138,7 @@ def estimate_collective_time( override_size: int | None = None, custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] | None = None, + collective_estimator: Literal["analytical", "benchmark"] = "analytical", ) -> float: """Estimate the runtime of a collective operation, optionally with an overridden size.""" if ( @@ -143,7 +146,16 @@ def estimate_collective_time( ) is not None: return est - # Use analytical model (benchmarking is handled separately in alignment) + if collective_estimator == "benchmark": + from torch._inductor.fx_passes.node_runtime_estimation import ( + benchmark_collective_with_cuda_events, + ) + + cuda_val, _ = benchmark_collective_with_cuda_events(n, nruns=5) + if cuda_val is not None: + return cuda_val + + # Analytical model (also fallback when benchmark returns None) return torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node( n, override_size ) @@ -369,8 +381,9 @@ def __init__( insert_overlap_deps: bool, compute_overlap_multipler: float, max_coll_distance: int, - custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] | None, - collective_estimator: Literal["analytical", "benchmark"], + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, + collective_estimator: Literal["analytical", "benchmark"] = "analytical", compute_estimator: Literal["analytical", "benchmark"] = "benchmark", max_memory_increase_gb: float | None = 1.0, max_memory_increase_ratio: float | None = 0.05, @@ -378,15 +391,27 @@ def __init__( bucket_exposed_first: bool | None = None, enable_fusion_regions: bool = False, bucket_only_internode_comms: bool = False, - bucket_mode: BucketMode = "custom_ops_multidtype", + bucket_mode: BucketMode | None = None, max_off_bucket_gb: float | None = 0.5, prioritize_bucketing_during_scheduling: bool = True, + pge_profile_path: str | None = None, ): self.gm = gm self.graph = gm.graph self.compute_overlap_multipler = compute_overlap_multipler self.max_node_distance = max_coll_distance self.max_in_flight_bytes: int = gb_to_bytes(max_in_flight_gb) + + # Profile-guided estimation: create estimator from profile path + if pge_profile_path and custom_runtime_estimation is None: + from torch._inductor.fx_passes.profile_guided_estimation import ( + ProfileGuidedEstimator, + ) + + custom_runtime_estimation = ProfileGuidedEstimator( + pge_profile_path, diagnostics_gm=gm + ) + self.custom_runtime_estimation = custom_runtime_estimation self.collective_bucketing = collective_bucketing self.insert_overlap_deps = insert_overlap_deps @@ -401,7 +426,7 @@ def __init__( self.log_final_collectives_estimations = log_final_collectives_estimations self.bucket_exposed_first = bucket_exposed_first self.bucket_only_internode_comms = bucket_only_internode_comms - self.bucket_mode = bucket_mode + self.bucket_mode = bucket_mode or _default_bucket_mode() self.max_off_bucket_bytes: int | None = ( gb_to_bytes(max_off_bucket_gb) if max_off_bucket_gb is not None else None ) @@ -421,20 +446,18 @@ def __init__( num_device_put_converted, ) - # Build and collapse fusion regions FIRST so all subsequent operations - # work on the collapsed graph where fused ops are atomic units - self.region_of: dict[fx.Node, Any] = {} - if enable_fusion_regions: - from torch._inductor.fx_passes.fusion_regions import ( - build_fusion_regions, - collapse_fusion_regions, - ) - - self.region_of = build_fusion_regions(self.gm) - if self.region_of: - self.region_of = collapse_fusion_regions(self.gm, self.region_of) - # fuse_by_partitions replaces gm.graph, so we need to update our reference - self.graph = gm.graph + # Build fusion regions (mutates gm.graph) and compute initial node runtime + # estimates. Compute nodes use roofline model here; the alignment step in + # run() replaces them with benchmarked + cross-rank-aligned values. + self.node_estimations, self.region_of = gather_node_runtime_estimations( + gm, + custom_runtime_estimation, + enable_fusion_regions=enable_fusion_regions, + log_estimations=True, + ) + if self.region_of: + # fuse_by_partitions replaces gm.graph, so we need to update our reference + self.graph = gm.graph # Build structures stable_topological_sort(self.graph) @@ -643,10 +666,18 @@ def _identify_collectives(self) -> None: for node in self.nodes: if _schedulable_wait_node(node): - start = node.args[0] - coll_time_ms = estimate_collective_time( - start, custom_runtime_estimation=self.custom_runtime_estimation + start = _get_collective_node_from_wait(node) + assert start is not None + assert start in self.node_estimations, ( + f"Missing estimation for collective {start.name}. " + f"Ensure custom_runtime_estimation returns a value for this node." ) + self.wait_to_start[node] = start + # For coalesced collectives, multiple waits share the same + # start node. Only register the first wait as the representative. + if start in self.collective_info: + continue + coll_time_ms = self.node_estimations[start] info = CollectiveInfo( start_node=start, @@ -656,7 +687,6 @@ def _identify_collectives(self) -> None: exposed_time_ms=coll_time_ms, # Initially fully exposed ) self.collective_info[start] = info - self.wait_to_start[node] = start self.unscheduled_collectives.add(start) self.all_pgs.add(get_group_name(start)) @@ -799,6 +829,9 @@ def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks( continue if idx < compute_key_count: # Compute node + self.node_estimations[self.compute_nodes[idx]] = ( + median_runtime_estimation + ) set_cached_node_time(key, median_runtime_estimation) else: # Collective CUDA event benchmark @@ -814,6 +847,7 @@ def _align_compute_nodes_runtime_estimations_across_all_distributed_ranks( info = self.collective_info[coll_node] info.estimated_time_ms = median_runtime_estimation info.exposed_time_ms = median_runtime_estimation + self.node_estimations[coll_node] = median_runtime_estimation collective_keys.append(key) collective_medians.append(median_runtime_estimation) @@ -986,6 +1020,7 @@ def run(self) -> torch.fx.GraphModule: region_of=self.region_of, bucket_exposed_first=self.bucket_exposed_first, bucket_only_internode_comms=self.bucket_only_internode_comms, + bucket_mode=self.bucket_mode, ) if self.log_final_collectives_estimations: @@ -999,32 +1034,6 @@ def run(self) -> torch.fx.GraphModule: return self.gm - def get_non_collective_runtime_estimate(self, node: fx.Node) -> float | None: - """Get runtime estimation for a node in ms. Returns None if no estimation is available.""" - if is_compute_node(node): - if self.compute_estimator == "benchmark": - return benchmark_node(node, self.custom_runtime_estimation) - else: - return estimate_roofline_runtime_ms(node) - - # Use precomputed cost for fusion region call_module nodes - # This takes priority even over custom estimation since fusion regions - # have already computed their cost based on their contents - if node in self.region_of: - return self.region_of[node].cost_ms - - if self.custom_runtime_estimation is not None: - if (est := self.custom_runtime_estimation(node, None)) is not None: - return est - # Custom estimation provided but returned None - don't fall through to fusible estimation - return None - - # assume any node without flop counter is mem bound - if node.op == "call_function": - return estimate_roofline_runtime_ms(node) - - return None - def _reduce_exposed_time_of_in_flight_collectives( self, node: fx.Node, @@ -1063,7 +1072,7 @@ def _reduce_exposed_time_of_in_flight_collectives( def _handle_compute_or_other(self, node: fx.Node) -> None: """Handle scheduling compute or other nodes and attempt to overlap with collectives.""" - runtime_estimate = self.get_non_collective_runtime_estimate(node) + runtime_estimate = self.node_estimations.get(node) # TODO: we could consider skipping overlapping for overlapable, unary chains to collectives. # using these nodes for overlap prevents bucketing. potentially if chain time < latency @@ -1191,7 +1200,11 @@ def _handle_wait(self, node: fx.Node) -> None: """Handle scheduling a wait.""" assert node in self.wait_to_start coll_start = self.wait_to_start[node] - assert coll_start in self.in_flight + # For coalesced collectives, multiple waits share the same start node. + # The first wait completes the collective; subsequent waits just schedule. + if coll_start not in self.in_flight: + self._schedule(node) + return # Scheduling a wait of a collective also forces the wait # of every node enqueued prior to the collective on the @@ -1433,12 +1446,12 @@ def should_assume_bucketed(self, node: fx.Node) -> bool: if not torch._inductor.config.test_configs.assume_bucketing_reduces_latency: return False - key = bucket_key(node, mode="custom_ops_multidtype") + key = bucket_key(node, mode=self.bucket_mode) if key is None: return False for in_flight_coll in self.in_flight: - if bucket_key(in_flight_coll, mode="custom_ops_multidtype") == key: + if bucket_key(in_flight_coll, mode=self.bucket_mode) == key: return True return False @@ -1590,6 +1603,149 @@ def compute_potential_hidden_waits(self) -> dict[fx.Node, fx.Node]: return self.compute_potential_hidden_nodes(wait_nodes) +def gather_node_runtime_estimations( + gm: torch.fx.GraphModule, + custom_runtime_estimation: Callable[[fx.Node, int | None], float | None] + | None = None, + collective_estimator: Literal["analytical", "benchmark"] = "analytical", + enable_fusion_regions: bool = False, + log_estimations: bool = False, +) -> tuple[dict[fx.Node, float], dict[fx.Node, Any]]: + """Gather initial runtime estimations for all nodes without scheduling. + + Uses analytical models (roofline) for compute nodes — the alignment step + in OverlapScheduler.run() replaces these with benchmarked + cross-rank-aligned + values. Collectives use bandwidth formulas or CUDA events depending on + collective_estimator. + + When enable_fusion_regions is True, builds and collapses fusion regions + (mutating gm's graph), then includes their costs in the estimations. + + Args: + collective_estimator: "analytical" uses bandwidth formulas, + "benchmark" uses CUDA events for collectives. + log_estimations: When True, log compute and collective estimations + via trace_structured for tlparse. + + Returns (estimations, fusion_region_of) where estimations maps fx.Node to + runtime in ms, and fusion_region_of maps call_module nodes to FusionRegion + objects (empty dict if fusion regions are disabled). + """ + # Build and collapse fusion regions first (mutates gm) + fusion_region_of: dict[fx.Node, Any] = {} + if enable_fusion_regions: + from torch._inductor.fx_passes.fusion_regions import ( + build_fusion_regions, + collapse_fusion_regions, + ) + + fusion_region_of = build_fusion_regions(gm) + if fusion_region_of: + fusion_region_of = collapse_fusion_regions(gm, fusion_region_of) + + estimations: dict[fx.Node, float] = {} + nodes = list(gm.graph.nodes) + + # Collectives + collective_nodes: list[fx.Node] = [] + for node in nodes: + if _schedulable_wait_node(node): + start = _get_collective_node_from_wait(node) + assert start is not None + if start in estimations: + continue + estimations[start] = estimate_collective_time( + start, + custom_runtime_estimation=custom_runtime_estimation, + collective_estimator=collective_estimator, + ) + collective_nodes.append(start) + + # Compute nodes (matmul, bmm, etc.) — analytical estimates only. + # The alignment step in run() replaces these with benchmarked + aligned values. + compute_nodes: list[fx.Node] = [] + compute_analytical: list[float] = [] + + for node in nodes: + if is_compute_node(node): + est = estimate_roofline_runtime_ms(node) + if custom_runtime_estimation is not None: + custom_est = custom_runtime_estimation(node, None) + if custom_est is not None: + est = custom_est + estimations[node] = est + compute_nodes.append(node) + compute_analytical.append(est) + elif node.op == "call_function" and node not in estimations: + if custom_runtime_estimation is not None: + est = custom_runtime_estimation(node, None) + if est is not None: + estimations[node] = est + else: + est = estimate_roofline_runtime_ms(node) + if est > 0: + estimations[node] = est + + # Fusion region costs (call_module nodes from collapse_fusion_regions) + for node, region in fusion_region_of.items(): + estimations[node] = region.cost_ms # pyrefly: ignore[missing-attribute] + + # Logging + if log_estimations and compute_nodes: + from torch._inductor.fx_passes.node_runtime_estimation import ( + _log_compute_estimations, + ) + + _log_compute_estimations( + compute_nodes, + compute_analytical, + compute_analytical, + ) + + if log_estimations and collective_nodes: + from torch._inductor.fx_passes.node_runtime_estimation import ( + _log_collective_benchmarks, + ) + + _log_collective_benchmarks( + collective_nodes, + artifact_name="fx_collectives_analytical_estimation", + ) + + return estimations, fusion_region_of + + +def align_estimations_across_ranks( + estimations: dict[fx.Node, float], +) -> dict[fx.Node, float]: + """Align runtime estimations across distributed ranks using median. + + All ranks must make identical scheduling decisions, so we gather each + rank's values and take the median. All nodes in estimations are aligned. + + Returns a new estimations dict with aligned values. + """ + import torch.distributed as dist + from torch._subclasses.fake_tensor import unset_fake_temporarily + from torch.distributed.distributed_c10d import _get_default_group + + nodes = list(estimations.keys()) + if not nodes: + return {} + + local_values = [estimations[n] for n in nodes] + + world_size = dist.get_world_size() + pg = _get_default_group() + + with unset_fake_temporarily(): + gathered: list[list[float]] = [[] for _ in range(world_size)] + dist.all_gather_object(gathered, local_values, pg) + medians = torch.median(torch.tensor(gathered), dim=0).values.tolist() + + return dict(zip(nodes, medians)) + + def schedule_overlap_bucketing( gm: torch.fx.GraphModule, max_in_flight_gb: float = 5, @@ -1610,6 +1766,8 @@ def schedule_overlap_bucketing( bucket_only_internode_comms=False, prioritize_bucketing_during_scheduling: bool = True, max_off_bucket_gb: float | None = 0.5, + bucket_mode: BucketMode | None = None, + pge_profile_path: str | None = None, ) -> torch.fx.GraphModule: """Schedule nodes to maximize compute-collective overlap. @@ -1625,8 +1783,9 @@ def schedule_overlap_bucketing( compute_overlap_multipler: Scale factor for compute time used to hide collectives. This can be used to address over or under aggressive overlapping. max_coll_distance: Maximum pre fetch or bucketing candidates. Mainly intended for compile time - custom_runtime_estimation: Custom runtime estimation function that estimates runtime in ms for an fx node. - If None, uses default estimations. This is currently limited to collectives and compute nodes. + custom_runtime_estimation: Override runtime estimation for specific nodes. Called as + custom_runtime_estimation(node, override_size) -> float | None. To pass pre-computed + estimations, wrap a dict: lambda node, _: estimations.get(node). collective_estimator: Method for estimating collective runtime. "analytical" uses bandwidth formulas, "benchmark" uses CUDA events with power-of-2 rounding and interpolation. compute_estimator: Method for estimating compute (ATen op) runtime. "analytical" uses roofline model @@ -1635,6 +1794,7 @@ def schedule_overlap_bucketing( max_memory_increase_ratio: Maximum increase as ratio of baseline peak memory. If None, no ratio limit. Uses minimum of absolute and ratio limits when both are specified. enable_fusion_regions: Enable fusion region detection and cost estimation for fusible ops. + bucket_mode: Bucketing mode for grouping collectives. """ if not any(is_wait_tensor(n) for n in gm.graph.nodes): return gm @@ -1666,6 +1826,8 @@ def schedule_overlap_bucketing( bucket_only_internode_comms=bucket_only_internode_comms, prioritize_bucketing_during_scheduling=prioritize_bucketing_during_scheduling, max_off_bucket_gb=max_off_bucket_gb, + bucket_mode=bucket_mode, + pge_profile_path=pge_profile_path, ).run() trace_structured( "artifact", @@ -1675,6 +1837,7 @@ def schedule_overlap_bucketing( }, payload_fn=lambda: ret.print_readable(False), ) + return ret @@ -1712,9 +1875,15 @@ def schedule_overlap_bucketing_from_inductor_configs( "bucket_only_internode_comms", "enable_fusion_regions", "prioritize_bucketing_during_scheduling", + "bucket_mode", ) for key in config_keys: if (val := getattr(dist_opts, key, None)) is not None: kwargs[key] = val + # Profile-guided latency estimation + pge_path = dist_opts.profile_guided_estimations_profile_path + if pge_path and "custom_runtime_estimation" not in kwargs: + kwargs["pge_profile_path"] = pge_path + return schedule_overlap_bucketing(gm, **kwargs) # type: ignore[arg-type] diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index df270053d45f9..5a0deb68090f8 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -84,11 +84,13 @@ def check_dtype(a: Tensor, b: Tensor) -> bool: return a.is_floating_point() and b.is_floating_point() -def realize_symbols( - ds: torch.Size | tuple[torch.SymInt, ...], +def hint_symbols( + ds: Sequence[int | torch.SymInt], ) -> list[int]: """Helper to convert symbolic dimensions to their concrete hint values.""" - return [d if isinstance(d, int) else d.node.hint for d in ds] + from torch.fx.experimental.symbolic_shapes import optimization_hint + + return [optimization_hint(d) for d in ds] def can_pad( @@ -102,31 +104,15 @@ def can_pad( All logic related to whether it's safe to pad should be here. """ - # It's fine we have symbolic shapes or strides as long as they - # have hints. Later, we will make sure we only pad non-symbolic dimensions. - def valid_shape_and_stride(t: Tensor | None) -> bool: - if t is None: - return True - - symbolic_cnt = 0 + # Can't pad if there is no static dims, we pad static dims only. + def has_one_static_dim(t: Tensor) -> bool: + """Return False if all dimensions are symbolic — nothing concrete to pad.""" for x in t.size(): if isinstance(x, int): - continue - elif utils.is_symbolic(x): - # pyrefly: ignore [missing-attribute] - if not x.node.has_hint(): - return False - symbolic_cnt += 1 - else: - return False - # filter out cases where all dimensions are symbolic - if symbolic_cnt == len(t.size()): - return False - return all( - # pyrefly: ignore [missing-attribute] - isinstance(x, int) or (utils.is_symbolic(x) and x.node.has_hint()) - for x in t.stride() - ) + return True + elif not isinstance(x, torch.SymInt): + raise RuntimeError("not expected size") + return False # Basic safety checks if not torch._inductor.config.shape_padding: @@ -138,15 +124,16 @@ def valid_shape_and_stride(t: Tensor | None) -> bool: if not check_dtype(mat1, mat2): return False - if not all(valid_shape_and_stride(t) for t in (mat1, mat2, input)): + # For padding to be vaible each tensor should have at least one static dim. + tensors = [t for t in (mat1, mat2, input) if t is not None] + if not all(has_one_static_dim(t) for t in tensors): return False - # Check for zero dimensions - not safe to pad + # Skip zero-sized dimensions — padding would be wasteful (mm on empty tensors) + from torch.fx.experimental.symbolic_shapes import optimization_hint + if any( - dim == 0 - for dim in itertools.chain( - realize_symbols(mat1.shape), realize_symbols(mat2.shape) - ) + optimization_hint(dim) == 0 for dim in itertools.chain(mat1.shape, mat2.shape) ): return False @@ -172,10 +159,11 @@ def valid_shape_and_stride(t: Tensor | None) -> bool: return False # In deterministic mode, we can't safely benchmark - disallow padding - # Check this after other basic checks so force_shape_pad can override + # Check this after other basic checks so force_shape_pad/autoheuristic can override if ( torch._inductor.config.deterministic and not torch._inductor.config.force_shape_pad + and not torch._inductor.config.use_autoheuristic("pad_mm") ): return False @@ -388,11 +376,15 @@ def get_non_view_def(node: torch.fx.Node) -> torch.fx.Node: def should_exclude_padding_time(match: Match, arg_name: str) -> bool: + from torch._prims_common import is_contiguous_or_false + node_def = get_non_view_def(match.kwargs[arg_name]) # constant padding converts tensors to contiguous so even if the input tensor # can be planned layout transform is not free. TODO - way to pad and preserve layout ? - if not fetch_fake_tensors(match, (arg_name,))[0].is_contiguous(): + # Use is_contiguous_or_false to avoid guarding on data-dependent expressions + # with unbacked symints - returns False instead of raising an error. + if not is_contiguous_or_false(fetch_fake_tensors(match, (arg_name,))[0]): return False # TODO - see issue https://github.com/pytorch/pytorch/issues/128889 @@ -519,15 +511,19 @@ def _should_pad( if torch._inductor.config.force_shape_pad: return True + # Resolve symbolic dims to concrete hints for heuristic checks below. + # These are performance decisions, not correctness — optimization_hint is safe. + m_concrete, k_concrete, n_concrete = hint_symbols((m, k, n)) + # Performance heuristic for bf16 large K scenarios if ( "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options - and should_pad_mm_bf16(mat1.dtype, m, n, k) + and should_pad_mm_bf16(mat1.dtype, m_concrete, n_concrete, k_concrete) ): return True # Check if operation is compute bound (performance check) - if not is_mm_compute_bound(m, k, n, mat1.dtype): + if not is_mm_compute_bound(m_concrete, k_concrete, n_concrete, mat1.dtype): return False # We don't want to look up the cache for cases that are trivially false @@ -540,9 +536,9 @@ def _should_pad( def realize_tensor(t): if isinstance(t, FakeTensor): - size_hints = realize_symbols(t.size()) + size_hints = hint_symbols(t.size()) # pyrefly: ignore [bad-argument-type] - stride_hint = realize_symbols(t.stride()) + stride_hint = hint_symbols(t.stride()) real_size = ( sum((d - 1) * s for d, s in zip(size_hints, stride_hint)) + 1 ) @@ -677,6 +673,10 @@ def pad_bench_fn(): if ah_should_pad is not None: return ah_should_pad + # AH didn't make a decision, so if we're in deterministic mode, we should return false + if torch._inductor.config.deterministic: + return False + if ori_time is None: ori_time = do_bench(orig_bench_fn) set_cached_base_mm_benchmark_time(ori_time_key, ori_time) diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 0b94a2504d30b..f6b7557e5a81d 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -245,6 +245,15 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): pass_name = "custom_backend_passes_" + device GraphTransformObserver(gm, pass_name).apply_gm_pass(custom_backend_pass) + # SPMD verification — before collective reordering passes. + if ( + config.aten_distributed_optimizations.spmd_check + and _needs_spmd_graph_preservation() + ): + from torch._inductor.fx_passes.spmd_check import spmd_check + + spmd_check(gm) + collectives_bucketing: bool = False if config.bucket_reduce_scatters_fx != "none": @@ -356,6 +365,15 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): ) ) + if config.aten_distributed_optimizations.enable_low_contention_collectives: + from torch._inductor.fx_passes.low_contention_collectives import ( + replace_collectives_with_low_contention, + ) + + GraphTransformObserver( + gm, "replace_collectives_with_low_contention" + ).apply_graph_pass(replace_collectives_with_low_contention) + # Keep these last, since they introduce mutation. Look at # ./fx_passes/README.md for a discussion of mutation invariants. GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass( @@ -777,6 +795,13 @@ def check(): def check(): return True + def consumes_rng_state(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and isinstance(node.target, torch._ops.OpOverload) + and torch.Tag.nondeterministic_seeded in node.target.tags + ) + def visit(other_node): if ( other_node.op == "call_function" @@ -786,6 +811,11 @@ def visit(other_node): == get_mutation_region_id(graph, other_node) and check() ): + # Ops that consume RNG state are order-sensitive and must not be + # reordered during locality optimization. + if consumes_rng_state(other_node): + return + # move node's producers right before it node.prepend(other_node) @@ -1037,8 +1067,20 @@ def register_fun(cond): return register_fun +def _needs_spmd_graph_preservation() -> bool: + """Check if SPMD graph preservation is needed for distributed overlap.""" + return ( + config.aten_distributed_optimizations.enable_overlap_scheduling + or config.reorder_for_compute_comm_overlap + ) + + @register_noop_decomp(aten.slice) def slice_noop(self, dim=0, start=None, end=None, step=1): + if _needs_spmd_graph_preservation(): + # Keep no-op slices so all ranks produce identical FX graphs (SPMD) + # with matching op counts and runtime estimations. + return False if start is None or end is None: return False @@ -1082,6 +1124,10 @@ def repeat_noop(self, repeats): @register_noop_decomp(aten.constant_pad_nd) def constant_pad_nd(x, padding, fill_value=0): + if _needs_spmd_graph_preservation(): + # Keep no-op pads so all ranks produce identical FX graphs (SPMD) + # with matching op counts and runtime estimations. + return False return all(p == 0 for p in padding) @@ -1530,9 +1576,10 @@ def should_prefer_unfused_addmm(match): extra_check=should_prefer_unfused_addmm, ) def unfuse_bias_add_to_pointwise(match: Match, mat1, mat2, *, inp, alpha, beta): - # Unfusing addmm introduces an extra bf16/fp16 truncation at the mm output - # that compounds through deep models and causes accuracy failures. - if inp.meta["val"].dtype in (torch.bfloat16, torch.float16): + if config.keep_addmm_fused_for_half_dtypes and inp.meta["val"].dtype in ( + torch.bfloat16, + torch.float16, + ): return def repl(inp, x1, x2, alpha, beta): @@ -1657,7 +1704,7 @@ def is_index_put_and_requires_h2d_sync_for_gpu_value(node): ]: return False # Inductor falls back to aten.index_put_. - # index_put_ will will call nonzero() and perform a H2D sync if + # index_put_ will call nonzero() and perform a H2D sync if # any of its indices are bool/byte tensors # However, it will short-circuit this H2D sync and run mask_fill_ # if the value we are putting is a cpu scalar. @@ -1876,7 +1923,8 @@ def __call__(self, graph: fx.Graph) -> None: lambda x: x not in [cpu_concat, gpu_concat, gpu_split, gpu_node] + unsqueezed_nodes - and x.target != torch.ops.aten.copy_.default, + and x.target != torch.ops.aten.copy_.default + and x.target != "output", ) last_node = gpu_node diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index ef754d7ff4f0f..efc953f97068f 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -431,16 +431,20 @@ def remove_identity(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: """ Removes all identity layers from the module. """ - - class IdentityRemover(torch.fx.Transformer): - def call_module(self, target, args, kwargs): - if isinstance(self.submodules[target], nn.Identity): - assert len(args) == 1 - return args[0] - else: - return super().call_module(target, args, kwargs) - - return IdentityRemover(gm).transform() + graph = gm.graph + work_done = False + for module_name, module in gm.named_modules(): + if type(module) is nn.Identity: + for node in list(graph.find_nodes(op="call_module", target=module_name)): + assert len(node.args) == 1 + input_node = node.args[0] + node.replace_all_uses_with(input_node) + graph.erase_node(node) + work_done = True + if work_done: + graph.lint() + gm.recompile() + return gm def fuse_conv_bn(gm: torch.fx.GraphModule, inplace=False) -> torch.fx.GraphModule: diff --git a/torch/_inductor/fx_passes/profile_guided_estimation.py b/torch/_inductor/fx_passes/profile_guided_estimation.py new file mode 100644 index 0000000000000..d70d649eaa26b --- /dev/null +++ b/torch/_inductor/fx_passes/profile_guided_estimation.py @@ -0,0 +1,828 @@ +""" +Profile-Guided Estimation (PGE) for overlap scheduling. + +Parses a Chrome Trace JSON (from torch.profiler) and builds lookup tables +for kernel runtimes (collectives, matmuls, attention, custom ops, etc.). +Used as a custom_runtime_estimation hook in the overlap scheduler. + +When the same profile is loaded on all ranks, estimates are deterministic +and no cross-rank synchronization is needed. +""" + +from __future__ import annotations + +import functools +import json +import logging +import math +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any + +import torch +import torch.fx as fx +from torch._inductor.analysis.profile_analysis import ( + _create_extern_mapping, + _dtype_map, + _get_size_from_string, + ParseException, +) +from torch._inductor.fx_passes.bucketing import ( + is_all_gather_into_tensor, + is_all_reduce_tensor, + is_all_to_all_tensor, + is_reduce_scatter_tensor, +) +from torch._logging import trace_structured +from torch.utils._ordered_set import OrderedSet + + +log = logging.getLogger(__name__) + + +def _rank_stride(ranks: tuple[int, ...]) -> int | None: + """Compute the stride of a sorted rank tuple, or None if non-uniform. + + Examples: + (0, 2, 4, 6) → stride 2 + (0, 1) → stride 1 + (1, 3, 5, 7) → stride 2 + (0, 1, 4, 5) → None (non-uniform) + """ + if len(ranks) <= 1: + return None + stride = ranks[1] - ranks[0] + if stride <= 0: + return None + for i in range(2, len(ranks)): + if ranks[i] - ranks[i - 1] != stride: + return None + return stride + + +@dataclass +class CollectiveRecord: + """A single collective kernel observation from the profile.""" + + collective_name: str # "all_gather_into_tensor", "reduce_scatter_tensor", etc. + pg_ranks: tuple[int, ...] # sorted rank tuple + group_size: int + in_nelems: int # "In msg nelems" from profile + out_nelems: int # "Out msg nelems" from profile + dtype: str # "Float", "BFloat16", etc. + duration_us: float + + +@dataclass +class OpRecord: + """A single op observation from the profile (any CPU op with GPU kernels).""" + + op_name: str # normalized name, e.g. "aten::mm", "mylib::my_custom_op" + input_shapes: tuple[tuple[int, ...], ...] + input_strides: tuple[tuple[int, ...], ...] + dtype: torch.dtype | None + duration_us: float # sum of all GPU kernels for this CPU op + + +def _to_nested_tuple(x: Any) -> Any: + """Recursively convert nested lists to tuples for hashability.""" + if isinstance(x, (list, tuple)): + return tuple(_to_nested_tuple(i) for i in x) + return x + + +@dataclass +class ProfileData: + """Parse Chrome Trace JSON and build lookup tables for kernel runtimes.""" + + collectives: list[CollectiveRecord] = field(default_factory=list) + ops: list[OpRecord] = field(default_factory=list) + pg_configs: dict[str, tuple[int, ...]] = field(default_factory=dict) + + # Lookup indices built after loading + _collective_index: dict[ + tuple[str, tuple[int, ...], str], list[tuple[int, float]] + ] = field(default_factory=dict) + # Fallback index by mesh dimension (name, stride, group_size, dtype). + # Matches PGs belonging to the same mesh dimension regardless of specific ranks. + # E.g. (0,2,4,6) and (1,3,5,7) both have stride=2, size=4 → same mesh dim. + _collective_index_by_mesh_dim: dict[ + tuple[str, int, int, str], list[tuple[int, float]] + ] = field(default_factory=dict) + # Count of distinct PGs per mesh dimension (stride, group_size) — used for + # ambiguity check (skip fallback if multiple PGs share the same mesh dim). + _pg_count_by_mesh_dim: dict[tuple[int, int], int] = field(default_factory=dict) + # Generic op index: (op_name, input_shapes, input_strides, dtype) -> avg_dur_us + _op_index: dict[ + tuple[ + str, + tuple[tuple[int, ...], ...], + tuple[tuple[int, ...], ...], + torch.dtype | None, + ], + float, + ] = field(default_factory=dict) + # Peak observed bandwidth per PG (GB/s), computed from largest messages + _pg_peak_bw: dict[tuple[int, ...], float] = field(default_factory=dict) + # Mesh-dimension fallback: (stride, group_size) -> peak BW (GB/s) + _mesh_dim_peak_bw: dict[tuple[int, int], float] = field(default_factory=dict) + + def load(self, trace_path: str) -> None: + """Load and parse a Chrome Trace JSON file.""" + import os + + if not os.path.isfile(trace_path): + raise FileNotFoundError( + f"PGE trace file not found: {trace_path}. " + f"Check config.aten_distributed_optimizations.profile_guided_estimations_profile_path" + ) + with open(trace_path) as f: + data = json.load(f) + + self._parse_pg_configs(data) + self._parse_events(data) + self._build_indices() + + log.info( + "PGE loaded: %d collectives, %d op records (%d distinct shapes), %d PGs", + len(self.collectives), + len(self.ops), + len(self._op_index), + len(self.pg_configs), + ) + + def _parse_pg_configs(self, data: dict[str, Any]) -> None: + dist_info = data.get("distributedInfo", {}) + pg_config = dist_info.get("pg_config", {}) + # pg_config can be a list of dicts or a dict of dicts + if isinstance(pg_config, list): + for pg_info in pg_config: + pg_name = str(pg_info.get("pg_name", "")) + ranks = pg_info.get("ranks", []) + if ranks: + self.pg_configs[pg_name] = tuple(sorted(ranks)) + elif isinstance(pg_config, dict): + for pg_name, pg_info in pg_config.items(): + ranks = pg_info.get("ranks", []) + if ranks: + self.pg_configs[pg_name] = tuple(sorted(ranks)) + + def _parse_events(self, data: dict[str, Any]) -> None: + events = data.get("traceEvents", []) + # Reuse profile_analysis's External id -> CPU op mapping + try: + extern_mapping = _create_extern_mapping(data) + except (ParseException, KeyError): + # Malformed trace (e.g. duplicate External ids, missing traceEvents) + extern_mapping = defaultdict(list) + for ev in events: + if ( + isinstance(ev, dict) + and ev.get("cat") == "cpu_op" + and "args" in ev + and "External id" in ev["args"] + ): + extern_mapping[ev["args"]["External id"]].append(ev) + + # Build External id -> total GPU kernel duration + gpu_dur: dict[int, float] = defaultdict(float) + for ev in events: + if not isinstance(ev, dict) or ev.get("cat") != "kernel": + continue + args = ev.get("args", {}) + eid = args.get("External id") + dur = ev.get("dur", 0.0) + if eid is not None and dur > 0: + gpu_dur[eid] += dur + + # Parse collectives from GPU kernel events directly + # (NCCL kernels carry collective metadata in args) + for ev in events: + if not isinstance(ev, dict) or ev.get("cat") != "kernel": + continue + args = ev.get("args", {}) + coll_name = args.get("Collective name") + if coll_name is None: + continue + pg_name = args.get("Process Group Name", "") + pg_ranks_str = args.get("Process Group Ranks", "") + group_size = args.get("Group size", 0) + in_nelems = args.get("In msg nelems", 0) + out_nelems = args.get("Out msg nelems", 0) + dtype = args.get("dtype", "") + dur = ev.get("dur", 0.0) + if dur <= 0: + continue + + pg_ranks = self._parse_ranks(pg_ranks_str, pg_name) + + self.collectives.append( + CollectiveRecord( + collective_name=coll_name, + pg_ranks=pg_ranks, + group_size=group_size, + in_nelems=in_nelems, + out_nelems=out_nelems, + dtype=dtype, + duration_us=dur, + ) + ) + + # Parse all CPU ops that have associated GPU kernels + for eid, cpu_evs in extern_mapping.items(): + if not cpu_evs: + continue + total_dur = gpu_dur.get(eid, 0.0) + if total_dur <= 0: + continue + cpu_ev = cpu_evs[0] + self._parse_op(cpu_ev.get("name", ""), cpu_ev.get("args", {}), total_dur) + + def _parse_ranks(self, ranks_str: str, pg_name: str) -> tuple[int, ...]: + """Parse rank list from profile string or fall back to pg_configs.""" + if isinstance(ranks_str, str) and ranks_str.startswith("["): + try: + ranks = json.loads(ranks_str) + return tuple(sorted(ranks)) + except (json.JSONDecodeError, TypeError): + pass + # Fall back to pg_configs + if pg_name in self.pg_configs: + return self.pg_configs[pg_name] + return () + + def _parse_op(self, name: str, args: dict[str, Any], total_dur: float) -> None: + """Parse any CPU op into a generic OpRecord.""" + input_dims = args.get("Input Dims", []) + input_strides = args.get("Input Strides", []) + input_types = args.get("Input type", []) + if not input_dims: + return + dtype_str = input_types[0] if input_types else "" + dtype = _dtype_map.get(dtype_str) + # Skip empty entries (non-tensor args like scalars/None) so the + # tuples match what _get_node_input_shapes/strides extract from FX nodes. + shapes = tuple( + _to_nested_tuple(d) + for d in input_dims + if isinstance(d, (list, tuple)) and d + ) + strides = tuple( + _to_nested_tuple(d) + for d in input_strides + if isinstance(d, (list, tuple)) and d + ) + if not shapes: + return + self.ops.append( + OpRecord( + op_name=name, + input_shapes=shapes, + input_strides=strides, + dtype=dtype, + duration_us=total_dur, + ) + ) + + def _build_indices(self) -> None: + """Build lookup indices from parsed records.""" + coll_idx: dict[tuple[str, tuple[int, ...], str], list[tuple[int, float]]] = ( + defaultdict(list) + ) + coll_idx_by_mesh_dim: dict[ + tuple[str, int, int, str], list[tuple[int, float]] + ] = defaultdict(list) + # Track distinct PG rank sets per mesh dimension for ambiguity check + pg_sets_by_mesh_dim: dict[tuple[int, int], OrderedSet[tuple[int, ...]]] = ( + defaultdict(OrderedSet) + ) + for rec in self.collectives: + norm_name = self._normalize_collective_name(rec.collective_name) + gs = len(rec.pg_ranks) if rec.pg_ranks else rec.group_size + coll_idx[(norm_name, rec.pg_ranks, rec.dtype)].append( + (rec.out_nelems, rec.duration_us) + ) + stride = _rank_stride(rec.pg_ranks) + if stride is not None: + coll_idx_by_mesh_dim[(norm_name, stride, gs, rec.dtype)].append( + (rec.out_nelems, rec.duration_us) + ) + pg_sets_by_mesh_dim[(stride, gs)].add(rec.pg_ranks) + # Sort by nelems for interpolation + self._collective_index = { + k: sorted(v, key=lambda x: x[0]) for k, v in coll_idx.items() + } + self._collective_index_by_mesh_dim = { + k: sorted(v, key=lambda x: x[0]) for k, v in coll_idx_by_mesh_dim.items() + } + self._pg_count_by_mesh_dim = { + k: len(pgs) for k, pgs in pg_sets_by_mesh_dim.items() + } + + op_groups: defaultdict[ + tuple[ + str, + tuple[tuple[int, ...], ...], + tuple[tuple[int, ...], ...], + torch.dtype | None, + ], + list[float], + ] = defaultdict(list) + for rec in self.ops: + key = (rec.op_name, rec.input_shapes, rec.input_strides, rec.dtype) + op_groups[key].append(rec.duration_us) + self._op_index = {k: sum(v) / len(v) for k, v in op_groups.items()} + + # Per-PG peak bandwidth: compute bytes/us for each collective observation, + # then take the max from the top-N largest messages per PG (where bandwidth + # is most representative of hardware speed, not dominated by startup latency). + # Uses output-convention bytes (matching _estimate_with_pg_bandwidth). + _TOP_N = 5 # consider top N largest messages for peak BW + pg_bw_samples: dict[tuple[int, ...], list[tuple[int, float]]] = defaultdict( + list + ) + mesh_dim_bw_samples: dict[tuple[int, int], list[tuple[int, float]]] = ( + defaultdict(list) + ) + for rec in self.collectives: + if rec.out_nelems <= 0 or rec.duration_us <= 0: + continue + gs = len(rec.pg_ranks) if rec.pg_ranks else rec.group_size + elem_bytes = self._dtype_elem_bytes(rec.dtype) + total_bytes = rec.out_nelems * elem_bytes + bw_gbps = total_bytes / (rec.duration_us * 1e-6) / 1e9 # GB/s + pg_bw_samples[rec.pg_ranks].append((total_bytes, bw_gbps)) + stride = _rank_stride(rec.pg_ranks) + if stride is not None: + mesh_dim_bw_samples[(stride, gs)].append((total_bytes, bw_gbps)) + + def _peak_bw_from_samples( + samples: list[tuple[int, float]], + ) -> float: + """Get peak BW from the top-N largest messages.""" + # Sort by message size descending, take top N, return max BW + sorted_samples = sorted(samples, key=lambda x: x[0], reverse=True) + top = sorted_samples[:_TOP_N] + return max(bw for _, bw in top) if top else 0.0 + + self._pg_peak_bw = { + pg: _peak_bw_from_samples(samples) + for pg, samples in pg_bw_samples.items() + if samples + } + self._mesh_dim_peak_bw = { + key: _peak_bw_from_samples(samples) + for key, samples in mesh_dim_bw_samples.items() + if samples + } + + def get_collective_keys(self) -> list[tuple[str, tuple[int, ...], str]]: + """Return the collective index keys: (name, pg_ranks, dtype).""" + return list(self._collective_index.keys()) + + @property + def op_count(self) -> int: + """Number of distinct op shapes in the index.""" + return len(self._op_index) + + def get_op_names(self) -> list[str]: + """Return distinct op names in the op index.""" + return list(OrderedSet(name for name, _, _, _ in self._op_index)) + + @staticmethod + def _dtype_elem_bytes(dtype: str) -> int: + """Return bytes per element for a dtype string (NCCL CamelCase or TypeMeta).""" + return _get_size_from_string(dtype.lower()) + + @staticmethod + def _normalize_collective_name(name: str) -> str: + """Normalize collective name between profile and FX conventions. + + Profile uses: _allgather_base, allreduce, reduce_scatter_tensor_coalesced + FX uses: all_gather_into_tensor, all_reduce, reduce_scatter_tensor + """ + n = name.lower() + if "allgather" in n or "all_gather" in n: + return "all_gather" + if "reduce_scatter" in n: + return "reduce_scatter" + if "allreduce" in n or "all_reduce" in n: + return "all_reduce" + if "all_to_all" in n or "alltoall" in n: + return "all_to_all" + return name + + # Maximum ratio of target_nelems / max_observed before switching from + # log-log extrapolation to bandwidth-based estimation. + EXTRAPOLATION_CAP = 2.0 + + def _estimate_with_pg_bandwidth( + self, + pg_ranks: tuple[int, ...], + nelems: int, + dtype: str, + ) -> float | None: + """Estimate collective duration using peak observed bandwidth for this PG. + + Used when the target size exceeds the extrapolation cap. Returns ms or None. + """ + bw_gbps = self._pg_peak_bw.get(pg_ranks) + if bw_gbps is None or bw_gbps <= 0: + # Try mesh-dimension fallback + stride = _rank_stride(pg_ranks) + gs = len(pg_ranks) + if stride is not None: + bw_gbps = self._mesh_dim_peak_bw.get((stride, gs)) + if bw_gbps is None or bw_gbps <= 0: + return None # fall through to analytical + elem_bytes = self._dtype_elem_bytes(dtype) + total_bytes = nelems * elem_bytes + dur_ms = total_bytes / (bw_gbps * 1e6) # GB/s → bytes/ms = 1e6 + return dur_ms + + def lookup_collective( + self, + collective_name: str, + pg_ranks: tuple[int, ...], + nelems: int, + dtype: str, + ) -> tuple[float, str] | None: + """Look up collective duration in ms. Returns (duration_ms, source) or None. + + ``source`` is ``"profile"`` for exact/interpolated matches, or + ``"pg_bandwidth"`` when bandwidth-based extrapolation was used. + + Tries exact rank match first, then falls back to mesh-dimension match + (e.g. (0,2,4,6) and (1,3,5,7) both have stride=2, size=4 → same mesh dim). + + When the target size exceeds EXTRAPOLATION_CAP * max_observed, uses + bandwidth-based estimation from peak observed bandwidth instead of + linear extrapolation (which overestimates for large messages). + """ + norm_name = self._normalize_collective_name(collective_name) + # Try exact rank match first + key = (norm_name, pg_ranks, dtype) + entries = self._collective_index.get(key) + if not entries: + # Fall back to mesh-dimension match + gs = len(pg_ranks) + stride = _rank_stride(pg_ranks) + if ( + stride is not None + and self._pg_count_by_mesh_dim.get((stride, gs), 0) == 1 + ): + mesh_dim_key = (norm_name, stride, gs, dtype) + entries = self._collective_index_by_mesh_dim.get(mesh_dim_key) + if not entries: + return None + + # Exact match + for n, dur in entries: + if n == nelems: + return (dur / 1e3, "profile") # us -> ms + + # Check extrapolation distance: if target is far beyond observed range, + # use bandwidth-based model instead of log-log extrapolation + max_observed = max((n for n, _ in entries if n > 0), default=0) + if max_observed > 0 and nelems > max_observed * self.EXTRAPOLATION_CAP: + est = self._estimate_with_pg_bandwidth(pg_ranks, nelems, dtype) + if est is not None: + return (est, "pg_bandwidth") + # Fall through to log-log if no BW data available + + # Interpolation in log-log space + result = self._interpolate_log_log(entries, nelems) + if result is not None: + return (result, "profile") + return None + + def _interpolate_log_log( + self, entries: list[tuple[int, float]], target_nelems: int + ) -> float | None: + """Interpolate duration in log-log space (log(nelems) vs log(dur)).""" + if not entries or target_nelems <= 0: + return None + + log_target = math.log(target_nelems) + + # Find bracketing entries + lower: tuple[int, float] | None = None + upper: tuple[int, float] | None = None + for n, dur in entries: + if n <= 0 or dur <= 0: + continue + if n <= target_nelems: + lower = (n, dur) + if n >= target_nelems and upper is None: + upper = (n, dur) + + if lower is not None and upper is not None: + log_n0, log_d0 = math.log(lower[0]), math.log(lower[1]) + log_n1, log_d1 = math.log(upper[0]), math.log(upper[1]) + if log_n1 == log_n0: + return lower[1] / 1e3 + t = (log_target - log_n0) / (log_n1 - log_n0) + log_dur = log_d0 + t * (log_d1 - log_d0) + return math.exp(log_dur) / 1e3 # us -> ms + elif lower is not None: + # Linear extrapolation (not log-log) from nearest lower; + # EXTRAPOLATION_CAP in lookup_collective limits how far this reaches. + return (lower[1] * target_nelems / lower[0]) / 1e3 + elif upper is not None: + # Linear extrapolation from nearest upper + return (upper[1] * target_nelems / upper[0]) / 1e3 + + return None + + def lookup_op( + self, + op_name: str, + input_shapes: tuple[tuple[int, ...], ...], + input_strides: tuple[tuple[int, ...], ...], + dtype: torch.dtype | None, + ) -> float | None: + """Look up op duration in ms by exact shape+stride match. Returns None on miss.""" + key = (op_name, input_shapes, input_strides, dtype) + dur_us = self._op_index.get(key) + if dur_us is not None: + return dur_us / 1e3 # us -> ms + return None + + +@functools.cache +def _dtype_to_nccl_str(dtype: torch.dtype) -> str: + """Convert torch.dtype to NCCL/ScalarType name (for collective matching). + + Derives the name from torch.Tensor.type() which returns e.g. + "torch.BFloat16Tensor" -> "BFloat16". + """ + return ( + torch.tensor([], dtype=dtype) + .type() + .removeprefix("torch.") + .removesuffix("Tensor") + ) + + +def _get_node_dtype(node: fx.Node) -> torch.dtype | None: + """Extract dtype from FX node metadata.""" + val = node.meta.get("val") + if isinstance(val, torch.Tensor): + return val.dtype + if isinstance(val, (list, tuple)) and val: + first = val[0] + if isinstance(first, torch.Tensor): + return first.dtype + return None + + +def _fx_target_to_profile_name(node: fx.Node) -> str | None: + """Convert FX node target to the profile op name format. + + FX: torch.ops.aten.mm.default → "aten::mm" + FX: torch.ops.deepep.dispatch.default → "deepep::dispatch" + """ + target = node.target + if isinstance(target, torch._ops.OpOverload): + # e.g. "aten::mm" from torch.ops.aten.mm.default + ns = target.namespace + op_name = target._schema.name.split("::")[-1] + return f"{ns}::{op_name}" + if hasattr(target, "__name__"): + return target.__name__ + return None + + +def _get_node_input_shapes_and_strides( + node: fx.Node, +) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]] | None: + """Extract input shapes and strides from FX node tensor args. + + Returns (shapes, strides) or None if no tensor args or symbolic dims. + """ + from torch._inductor.fx_passes.node_runtime_estimation import get_hint + + shapes: list[tuple[int, ...]] = [] + strides: list[tuple[int, ...]] = [] + for arg in node.args: + if not isinstance(arg, fx.Node): + continue + val = arg.meta.get("val") + if isinstance(val, torch.Tensor): + resolved_shape = [] + for s in val.shape: + h = get_hint(s) + if h is None: + return None + resolved_shape.append(h) + resolved_stride = [] + for s in val.stride(): + h = get_hint(s) + if h is None: + return None + resolved_stride.append(h) + shapes.append(tuple(resolved_shape)) + strides.append(tuple(resolved_stride)) + if not shapes: + return None + return tuple(shapes), tuple(strides) + + +def _is_collective_node(node: fx.Node) -> bool: + """Check if node is a collective communication op.""" + return ( + is_all_gather_into_tensor(node) + or is_reduce_scatter_tensor(node) + or is_all_reduce_tensor(node) + or is_all_to_all_tensor(node) + ) + + +def _get_collective_info( + node: fx.Node, +) -> tuple[str, tuple[int, ...], int, str] | None: + """Extract (collective_name, pg_ranks, nelems, dtype) from collective node.""" + import torch.distributed as c10d + from torch.fx.operator_schemas import normalize_function + + if not c10d.is_initialized(): + return None + + target = node.target + if not isinstance(target, torch._ops.OpOverload): + return None + collective_name = target.name().split("::")[-1].split(".")[0] + + opt = normalize_function( + target, + args=node.args, + kwargs=node.kwargs, + normalize_to_only_use_kwargs=True, + ) + if opt is None: + return None + _, kwargs = opt + group_name = kwargs.get("group_name", "") + + try: + from torch.distributed.distributed_c10d import ( + _resolve_process_group, + get_process_group_ranks, + ) + + pg = _resolve_process_group(group_name) + pg_ranks = tuple(sorted(get_process_group_ranks(pg))) + except (RuntimeError, KeyError, ValueError): + log.debug( + "PGE: failed to resolve process group for %s", node.name, exc_info=True + ) + return None + + # Get nelems from input tensor + val = node.meta.get("val") + if isinstance(val, torch.Tensor): + nelems = 1 + for s in val.shape: + nelems *= int(s) + dtype = _dtype_to_nccl_str(val.dtype) + else: + # Try first arg + if node.args and isinstance(node.args[0], fx.Node): + inp_val = node.args[0].meta.get("val") + if isinstance(inp_val, torch.Tensor): + nelems = 1 + for s in inp_val.shape: + nelems *= int(s) + dtype = _dtype_to_nccl_str(inp_val.dtype) + else: + return None + else: + return None + + return (collective_name, pg_ranks, nelems, dtype) + + +class ProfileGuidedEstimator: + """Profile-guided runtime estimator for FX nodes. + + Implements the ``custom_runtime_estimation`` interface: + ``(fx.Node, int | None) -> float | None`` (returns ms or None for fallback). + + Handles collectives via interpolation (latency + bandwidth model) and all + other ops (matmul, SDPA, custom ops, etc.) via exact shape match from the + profile trace. + """ + + def __init__( + self, + trace_path: str, + diagnostics_gm: torch.fx.GraphModule | None = None, + ) -> None: + self.profile = ProfileData() + self.profile.load(trace_path) + self._log_profile_vs_analytical_comparison(diagnostics_gm) + + def _log_profile_vs_analytical_comparison( + self, diagnostics_gm: torch.fx.GraphModule | None + ) -> None: + """Log profile data and PGE vs analytical comparison to trace_structured. + + Logs all profile entries (collectives, ops with durations). + If diagnostics_gm is provided, walks the graph and compares PGE + estimates with analytical (roofline / NCCL) for each matched node. + """ + profile = self.profile + op_entries = [ + { + "op": op_name, + "shapes": [list(s) for s in shapes], + "strides": [list(s) for s in strides], + "dtype": str(dtype) if dtype is not None else "", + "profile_ms": dur_us / 1e3, + } + for (op_name, shapes, strides, dtype), dur_us in profile._op_index.items() + ] + + diagnostics: list[dict[str, Any]] = [] + if diagnostics_gm is not None: + from torch._inductor.fx_passes.overlap_scheduling import ( + estimate_roofline_runtime_ms, + ) + + for node in diagnostics_gm.graph.nodes: + pge_est = self(node) + if pge_est is None: + continue + entry: dict[str, Any] = { + "node": node.name, + "op": str(node.target), + "pge_ms": pge_est, + } + if _is_collective_node(node): + try: + entry["analytical_ms"] = ( + torch._inductor.comm_analysis.estimate_nccl_collective_runtime_from_fx_node( + node + ) + ) + except (RuntimeError, ValueError, TypeError): + pass + else: + analytical = estimate_roofline_runtime_ms(node) + if analytical is not None and analytical > 0: + entry["analytical_ms"] = analytical + diagnostics.append(entry) + + payload: dict[str, Any] = { + "collective_count": len(profile.collectives), + "op_count": profile.op_count, + "op_entries": op_entries, + } + if diagnostics: + payload["diagnostics"] = diagnostics + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "pge_profile_vs_analytical", + "encoding": "json", + }, + payload_fn=lambda: json.dumps(payload), + ) + + def __call__(self, node: fx.Node, override_size: int | None = None) -> float | None: + if _is_collective_node(node): + return self._estimate_collective(node, override_size) + return self._estimate_op(node) + + def _estimate_collective( + self, node: fx.Node, override_size: int | None + ) -> float | None: + info = _get_collective_info(node) + if info is None: + return None + coll_name, pg_ranks, nelems, dtype = info + val = node.meta.get("val") + if override_size is not None: + if override_size == 0: + return None + if isinstance(val, torch.Tensor): + elem_size = val.element_size() + if elem_size > 0: + nelems = override_size // elem_size + result = self.profile.lookup_collective(coll_name, pg_ranks, nelems, dtype) + if result is not None: + return result[0] + return None + + def _estimate_op(self, node: fx.Node) -> float | None: + """Estimate any non-collective op via exact shape+stride match in profile.""" + profile_name = _fx_target_to_profile_name(node) + if profile_name is None: + return None + result = _get_node_input_shapes_and_strides(node) + if result is None: + return None + input_shapes, input_strides = result + dtype = _get_node_dtype(node) + return self.profile.lookup_op(profile_name, input_shapes, input_strides, dtype) diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 7b4870fff5e02..dc19a3ac24cc0 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -1520,7 +1520,7 @@ def _find_first_node_in_dequant_pattern(_node): return _node else: assert len(_node.args) >= 1, ( - "In in dequant pattern, each node should have more than 1 arg." + "In dequant pattern, each node should have more than 1 arg." ) return _find_first_node_in_dequant_pattern(_node.args[0]) @@ -1539,7 +1539,7 @@ def _find_first_node_in_dequant_pattern(_node): for user_node in user_node_list[1:]: _source_node = dequant_pattern_end_node _user_node = user_node - # pyrefly: ignore [bad-assignment] + # pyrefly: ignore [bad-assignment, non-convergent-recursion] while _source_node != dequant_pattern_start_node.args[0]: _user_node = clone_to_new_node(graph, _source_node, _user_node) _source_node = _source_node.args[0] # type: ignore[assignment] diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index ceea6fb10d46c..978c2b4c662d2 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -693,19 +693,25 @@ def tensor_with_same_storage_already_reinplaced(arg): _mutable_op = node.args[0] kwargs = node.kwargs - all_bases = kwargs["_all_bases"] - bases_to_clone = range(len(all_bases)) - base_tensors_dct = dict(enumerate(all_bases)) - new_bases_to_clone: list[int] = reinplace_and_refine_tensors_to_clone( - bases_to_clone, - base_tensors_dct, - node.target, - ReInplaceTrigger.AUTO_FUNC_V2, - ) - # Stash the metadata. There is a pass later on where we decompose - # auto_functionalized into clones + a mutable op; this metadata - # tells the decomp to only clone the following inputs - node.meta["only_clone_these_tensors"] = new_bases_to_clone + if isinstance( + _mutable_op, torch._ops.OpOverload + ) and torch._library.utils.is_out(_mutable_op): + # Out args are write-only, always safe to reinplace (no clones needed) + node.meta["only_clone_these_tensors"] = [] + else: + all_bases = kwargs["_all_bases"] + bases_to_clone = range(len(all_bases)) + base_tensors_dct = dict(enumerate(all_bases)) + new_bases_to_clone: list[int] = reinplace_and_refine_tensors_to_clone( + bases_to_clone, + base_tensors_dct, + node.target, + ReInplaceTrigger.AUTO_FUNC_V2, + ) + # Stash the metadata. There is a pass later on where we decompose + # auto_functionalized into clones + a mutable op; this metadata + # tells the decomp to only clone the following inputs + node.meta["only_clone_these_tensors"] = new_bases_to_clone elif node.target is torch.ops.higher_order.auto_functionalized: _mutable_op = node.args[0] from torch._higher_order_ops.auto_functionalize import get_mutable_args @@ -761,15 +767,15 @@ def tensor_with_same_storage_already_reinplaced(arg): mutated_tensors_flat.append(arg) # Check if all mutated args can be inplaced - all_can_inplace = all_can_inplace(node, mutated_tensors_flat) + can_inplace_all = all_can_inplace(node, mutated_tensors_flat) log.debug( - "reinplace with_effects: mutated_tensors=%s, all_can_inplace=%s", + "reinplace with_effects: mutated_tensors=%s, can_inplace_all=%s", [str(a) for a in mutated_tensors_flat], - all_can_inplace, + can_inplace_all, ) - if all_can_inplace and inplaceable_op.extra_check(node): + if can_inplace_all and inplaceable_op.extra_check(node): log.debug( "reinplace with_effects: converting %s -> %s", inner_op, diff --git a/torch/_inductor/fx_passes/replace_random.py b/torch/_inductor/fx_passes/replace_random.py index 150ba5cde4a7c..7689b64ed5334 100644 --- a/torch/_inductor/fx_passes/replace_random.py +++ b/torch/_inductor/fx_passes/replace_random.py @@ -3,6 +3,7 @@ import logging import torch +from torch.fx.experimental.symbolic_shapes import guard_or_false, statically_known_true from torch.fx.passes.graph_transform_observer import GraphTransformObserver from torch.fx.passes.shape_prop import _extract_tensor_metadata @@ -21,6 +22,40 @@ aten = torch.ops.aten +def _shape_to_offset(shape, device: torch.device): + # Modified from torch/_prims/rng_prims.py:philox_rand_offset + nelem = 1 + for s in shape: + nelem *= s + + # Empty tensor: no random numbers are generated/consumed. + is_empty = nelem == 0 + if statically_known_true(is_empty) or guard_or_false(is_empty): + return 0 + + if device is None: + device = torch.device("cpu") + elif isinstance(device, str): + device = torch.device(device) + + if device.type != "cuda": + return 0 + + block_size = 256 + unroll = 4 + curand4_engine_calls = 4 + + device_property = torch.cuda.get_device_properties(device) + + blocks_per_sm = device_property.max_threads_per_multi_processor // block_size + max_grid = device_property.multi_processor_count * blocks_per_sm + grid_size = (nelem + block_size - 1) // block_size + grid_size = -torch.sym_min(-grid_size, -1) + grid_size = torch.sym_min(grid_size, max_grid) + + return ((nelem - 1) // (block_size * grid_size * unroll) + 1) * curand4_engine_calls + + def replace_random_passes(gm: torch.fx.GraphModule): """Modify the given FX graph to use backend-native random ops""" if config.fallback_random: @@ -29,10 +64,59 @@ def replace_random_passes(gm: torch.fx.GraphModule): count = patterns.apply(gm) with GraphTransformObserver(gm, "fuse_seed_creation_pass", "joint_graph_passes"): count += fuse_seed_creation_pass(gm.graph) + if config.align_random_eager: + with GraphTransformObserver(gm, "fuse_offset_creation_pass"): + count += fuse_offset_creation_pass(gm.graph) return count +def fuse_offset_creation_pass(graph: torch.fx.Graph) -> int: + """ + Here offset node means seed << 32 + offset, will unpacked in lowering.py:inductor_random() + Horizontally fuse all the seed generation on each device + a = inductor_prims.rand_eager_offset(offset, dev) + b = inductor_prims.rand_eager_offset(offset, dev) + Becomes: + offsets = inductor_prims.rand_eager_offsets([offset1, offset2...], dev) + a = torch.ops.aten.select.int(offsets, 0, 0) + b = torch.ops.aten.select.int(offsets, 0, 1) + We do this because seed creation is entirely launch overhead bound. + """ + device_offsets = collections.defaultdict(list) + for node in graph.nodes: + if CallFunctionVarArgs(inductor_prims.rand_eager_offset).match(node): + device_offsets[node.args[1]].append(node) + + if not device_offsets: + return 0 + + for device, offsets in device_offsets.items(): + with graph.inserting_before(offsets[0]): + offs = [n.args[0] for n in offsets] + combined = graph.call_function( + inductor_prims.rand_eager_offsets, (offs, device) + ) + with V.fake_mode: + combined.meta["val"] = torch.empty( + [len(offsets), 2], device=device, dtype=torch.int64 + ) + combined.meta["tensor_meta"] = _extract_tensor_metadata( + combined.meta["val"] + ) + + for idx, offset in enumerate(offsets): + with graph.inserting_before(offset): + new_state = graph.call_function( + torch.ops.aten.select.int, (combined, 0, idx) + ) + offset.replace_all_uses_with(new_state) + new_state.meta.update(offset.meta) + graph.erase_node(offset) + + return len(device_offsets) + + def fuse_seed_creation_pass(graph: torch.fx.Graph): """ Horizontally fuse all the seed generation on each device @@ -124,8 +208,32 @@ def replacement(size): match.output_node().target.overloadpacket # type: ignore[union-attr] ] # type: ignore[union-attr] device = get_device(device) + replacement_fn = replacement + + if mode == "rand" and config.align_random_eager and device.type == "cuda": + # Only enable when align_random_eager is on. + def replacement_align(size): + offset = _shape_to_offset(size, device) + + align_dtype = dtype + if isinstance(align_dtype, (tuple, list)): + align_dtype = align_dtype[0] if len(align_dtype) else None + + result = inductor_prims.random( + size, + inductor_prims.rand_eager_offset(offset, device), + mode, + **default_kwargs(device), + align_dtype=align_dtype, + ) + if dtype is not None: + result = result.to(dtype) + return result + + replacement_fn = replacement_align + # pyrefly: ignore [bad-argument-type] - match.replace_by_example(replacement, [size]) + match.replace_by_example(replacement_fn, [size]) # pyrefly: ignore [bad-argument-type] diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py index 567390838ede7..9e6e4a9ec945c 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_10.py @@ -32,7 +32,7 @@ _TargetExprVarArgs, ) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, KeywordArg('inv_scale')) expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) @@ -52,7 +52,8 @@ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) -expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, permute_default_3, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) @@ -61,8 +62,7 @@ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) -convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) -view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) @@ -71,7 +71,7 @@ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) -div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, KeywordArg('inv_scale')) permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) @@ -81,16 +81,18 @@ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) -permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, convert_element_type_default_3, Ignored()) _sfdp_pattern_10_training = MultiOutputPattern([view_default_5, permute_default_6, permute_default_9, - permute_default_11 + permute_default_11, + None ]) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, KeywordArg('inv_scale')) expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) view_default = CallFunction(aten.view.default, clone_default, Ignored()) @@ -110,7 +112,8 @@ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) -expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, permute_default_3, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) @@ -118,7 +121,7 @@ permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, KeywordArg('inv_scale')) expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) @@ -158,7 +161,7 @@ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) -div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, KeywordArg('inv_scale')) permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) @@ -172,12 +175,13 @@ _sfdp_pattern_10_half_training = MultiOutputPattern([view_default_5, permute_default_6, permute_default_9, - permute_default_11 + permute_default_11, + None ]) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, KeywordArg('inv_scale')) expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) view_default = CallFunction(aten.view.default, clone_default, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py index 567d898ed2042..252cdb3d147ef 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_18.py @@ -44,7 +44,7 @@ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +full_default = CallFunction(aten.full.default, [], KeywordArg('inv_scale'), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) @@ -98,6 +98,7 @@ permute_default_9, permute_default_11, None, + None, None ]) @@ -113,7 +114,7 @@ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +full_default = CallFunction(aten.full.default, [], KeywordArg('inv_scale'), dtype=Ignored(), device=Ignored(), pin_memory=False) div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) @@ -147,7 +148,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +full_default = CallFunction(aten.full.default, [], KeywordArg('inv_scale'), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) @@ -200,6 +201,7 @@ permute_default_9, permute_default_11, None, + None, None ]) @@ -213,7 +215,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +full_default = CallFunction(aten.full.default, [], KeywordArg('inv_scale'), dtype=Ignored(), device=Ignored(), pin_memory=False) div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1, _users=2) @@ -248,7 +250,7 @@ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +full_default = CallFunction(aten.full.default, [], KeywordArg('inv_scale'), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) @@ -307,6 +309,7 @@ permute_default_9, permute_default_11, None, + None, None ]) @@ -322,7 +325,7 @@ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +full_default = CallFunction(aten.full.default, [], KeywordArg('inv_scale'), dtype=Ignored(), device=Ignored(), pin_memory=False) div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) @@ -358,7 +361,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +full_default = CallFunction(aten.full.default, [], KeywordArg('inv_scale'), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) @@ -416,6 +419,7 @@ permute_default_9, permute_default_11, None, + None, None ]) @@ -429,7 +433,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +full_default = CallFunction(aten.full.default, [], KeywordArg('inv_scale'), dtype=Ignored(), device=Ignored(), pin_memory=False) div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py index 5c6d316351b85..66baf24230d19 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_19.py @@ -40,7 +40,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +full_default = CallFunction(aten.full.default, [], KeywordArg('inv_scale'), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) @@ -89,6 +89,7 @@ view_default_11, None, None, + None, None ]) @@ -100,7 +101,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +full_default = CallFunction(aten.full.default, [], KeywordArg('inv_scale'), dtype=Ignored(), device=Ignored(), pin_memory=False) div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) @@ -127,7 +128,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) +full_default = CallFunction(aten.full.default, [], KeywordArg('inv_scale'), dtype=Ignored(), device=Ignored(), pin_memory=False, _users=2) div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) @@ -179,6 +180,7 @@ view_default_11, None, None, + None, None ]) @@ -190,7 +192,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +full_default = CallFunction(aten.full.default, [], KeywordArg('inv_scale'), dtype=Ignored(), device=Ignored(), pin_memory=False) div_Tensor = CallFunction(aten.div.Tensor, view_default_2, full_default) full_default_1 = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) where_self = CallFunction(aten.where.self, KeywordArg('causal_mask'), div_Tensor, full_default_1) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py index 9185aa3b1e330..207cb76b49b92 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_20.py @@ -38,7 +38,7 @@ expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, KeywordArg('inv_scale')) expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) @@ -82,7 +82,7 @@ permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) -div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, KeywordArg('inv_scale')) permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) @@ -98,6 +98,7 @@ permute_default_9, permute_default_11, None, + None, None ]) @@ -107,7 +108,7 @@ expand_default = CallFunction(aten.expand.default, view_default, Ignored()) full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, KeywordArg('inv_scale')) expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) @@ -141,7 +142,7 @@ expand_default = CallFunction(aten.expand.default, view_default, Ignored(), _users=2) full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, KeywordArg('inv_scale')) expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) view_default_1 = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) @@ -190,7 +191,7 @@ permute_default_5 = CallFunction(aten.permute.default, view_default_2, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_9, permute_default_5) view_default_10 = CallFunction(aten.view.default, bmm_default_3, Ignored()) -div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_10, KeywordArg('inv_scale')) permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) permute_default_7 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_9) @@ -206,6 +207,7 @@ permute_default_9, permute_default_11, None, + None, None ]) @@ -215,7 +217,7 @@ expand_default = CallFunction(aten.expand.default, view_default, Ignored()) full_default = CallFunction(aten.full.default, [], Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, KeywordArg('inv_scale')) expand_default_1 = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) view_default_1 = CallFunction(aten.view.default, clone_default, Ignored()) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_28.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_28.py new file mode 100644 index 0000000000000..f414823adcebf --- /dev/null +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_28.py @@ -0,0 +1,202 @@ +# mypy: ignore-errors + +# noqa: F401, E501 +# This is an auto-generated file. Please do not modify it by hand. +# To re-generate, run: +# cd ~/pytorch && python torchgen/fuse/gen_patterns.py + +import torch +import torch._inductor +import operator + +aten = torch.ops.aten +prims = torch.ops.prims + +from torch._inductor.pattern_matcher import ( + Arg, + CallFunction, + CallFunctionVarArgs, + CallMethod, + CallMethodVarArgs, + CallModule, + CallModuleVarArgs, + ExclusiveKeywordArg, + Ignored, + KeywordArg, + ListOf, + MultiOutputPattern, + PatternExpr, + RepeatedExpr, + _TargetArgsExpr, + _TargetExpr, + _TargetExprVarArgs, +) +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList, _users=3) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, div_Tensor) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +neg_default = CallFunction(aten.neg.default, div_Tensor) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, mul_Tensor_4, div_Tensor, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, fma_default, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_28_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor'), _users=2) +amax_default = CallFunction(aten.amax.default, mul_Tensor, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, mul_Tensor, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +expand_default_2 = CallFunction(aten.expand.default, div_Tensor, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_28_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) + + +rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) +gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) +mul_Tensor_1 = CallFunction(aten.mul.Tensor, gt_Scalar, convert_element_type_default_1) +mul_Tensor_2 = CallFunction(aten.mul.Tensor, mul_Tensor_1, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, mul_Tensor_2, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +view_default_5 = CallFunction(aten.view.default, bmm_default_1, Ignored()) +convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, convert_element_type_default_1, Ignored(), _users=2) +neg_default = CallFunction(aten.neg.default, convert_element_type_default_2) +view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) +permute_default_1 = CallFunction(aten.permute.default, view_default_4, Ignored()) +bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_1) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) +mul_Tensor_3 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) +mul_Tensor_4 = CallFunction(aten.mul.Tensor, view_default_7, mul_Tensor_3) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, mul_Tensor_4, Ignored()) +mul_Tensor_5 = CallFunction(aten.mul.Tensor, convert_element_type_default_4, convert_element_type_default_2, _users=2) +sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_5, Ignored(), True) +fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_5) +convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) +mul_Tensor_6 = CallFunction(aten.mul.Tensor, convert_element_type_default_5, KeywordArg('scale_factor')) +view_default_8 = CallFunction(aten.view.default, mul_Tensor_6, Ignored(), _users=2) +permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) +bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) +view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) +permute_default_3 = CallFunction(aten.permute.default, view_default, Ignored()) +bmm_default_4 = CallFunction(aten.bmm.default, permute_default_3, view_default_8) +view_default_10 = CallFunction(aten.view.default, bmm_default_4, Ignored()) +permute_default_4 = CallFunction(aten.permute.default, view_default_10, Ignored()) +permute_default_5 = CallFunction(aten.permute.default, view_default_3, Ignored()) +bmm_default_5 = CallFunction(aten.bmm.default, permute_default_5, view_default_6) +view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) +_sfdp_pattern_28_half_training = MultiOutputPattern([view_default_5, + view_default_9, + permute_default_4, + view_default_11, + None, + None +]) + + +expand_default = CallFunction(aten.expand.default, KeywordArg('query'), Ignored()) +clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) +view_default = CallFunction(aten.view.default, clone_default, Ignored()) +permute_default = CallFunction(aten.permute.default, KeywordArg('key'), Ignored()) +expand_default_1 = CallFunction(aten.expand.default, permute_default, Ignored()) +clone_default_1 = CallFunction(aten.clone.default, expand_default_1, memory_format=torch.contiguous_format) +view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) +bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) +view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) +mul_Tensor = CallFunction(aten.mul.Tensor, view_default_2, KeywordArg('scale_factor')) +convert_element_type_default = CallFunction(prims.convert_element_type.default, mul_Tensor, Ignored(), _users=2) +amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) +sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) +exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) +sum_dim_IntList = CallFunction(aten.sum.dim_IntList, exp_default, Ignored(), True) +div_Tensor = CallFunction(aten.div.Tensor, exp_default, sum_dim_IntList) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored()) +expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) +view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, KeywordArg('value'), Ignored()) +clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) +view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) +bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) +_sfdp_pattern_28_half_inference = CallFunction(aten.view.default, bmm_default_1, Ignored(), _users=0) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py index f211e56b17a0a..a36ff3ae712fb 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_5.py @@ -38,7 +38,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) @@ -59,7 +59,7 @@ mul_Tensor = CallFunction(aten.mul.Tensor, view_default_7, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) -div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) @@ -75,6 +75,7 @@ view_default_9, permute_default_4, view_default_11, + None, None ]) @@ -86,7 +87,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) @@ -108,7 +109,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) @@ -134,7 +135,7 @@ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) -div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) @@ -150,6 +151,7 @@ view_default_9, permute_default_4, view_default_11, + None, None ]) @@ -161,7 +163,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py index 01304bf415163..48e3b33dfecc8 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_6.py @@ -40,7 +40,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) @@ -66,7 +66,7 @@ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) -div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) @@ -83,6 +83,7 @@ permute_default_4, view_default_11, None, + None, None ]) @@ -94,7 +95,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask'), _users=2) amax_default = CallFunction(aten.amax.default, add_Tensor, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, add_Tensor, amax_default) @@ -118,7 +119,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) @@ -149,7 +150,7 @@ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_5 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) -div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_5, KeywordArg('inv_scale')) view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) permute_default_2 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_2) @@ -166,6 +167,7 @@ permute_default_4, view_default_11, None, + None, None ]) @@ -177,7 +179,7 @@ view_default_1 = CallFunction(aten.view.default, expand_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) add_Tensor = CallFunction(aten.add.Tensor, div_Tensor, KeywordArg('attn_mask')) convert_element_type_default = CallFunction(prims.convert_element_type.default, add_Tensor, Ignored(), _users=2) amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py index b463c7e64a613..4f42972a0d383 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_7.py @@ -44,7 +44,7 @@ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) @@ -56,7 +56,8 @@ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) -expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, permute_default_3, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) @@ -65,8 +66,7 @@ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) -convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) -view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) @@ -74,7 +74,7 @@ mul_Tensor_4 = CallFunction(aten.mul.Tensor, mul_Tensor_3, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) -div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) @@ -88,11 +88,13 @@ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) -permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, convert_element_type_default_4, Ignored()) _sfdp_pattern_7_training = MultiOutputPattern([view_default_5, permute_default_6, permute_default_9, permute_default_11, + None, None ]) @@ -108,7 +110,7 @@ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) @@ -118,7 +120,8 @@ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) -expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, permute_default_3, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) @@ -138,7 +141,7 @@ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) @@ -169,7 +172,7 @@ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor_4, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor_4) convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) -div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_4, KeywordArg('inv_scale')) view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) @@ -188,6 +191,7 @@ permute_default_6, permute_default_9, permute_default_11, + None, None ]) @@ -203,7 +207,7 @@ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py index 3faff67089b17..711f68cc6189e 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_8.py @@ -42,7 +42,7 @@ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) @@ -52,7 +52,8 @@ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) -expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, permute_default_3, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) @@ -61,13 +62,12 @@ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) -convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) -view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) mul_Tensor = CallFunction(aten.mul.Tensor, convert_element_type_default_2, div_Tensor_1, _users=2) sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) -div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, fma_default, KeywordArg('inv_scale')) view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) @@ -81,11 +81,13 @@ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) -permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, convert_element_type_default_3, Ignored()) _sfdp_pattern_8_training = MultiOutputPattern([view_default_5, permute_default_6, permute_default_9, - permute_default_11 + permute_default_11, + None ]) @@ -100,7 +102,7 @@ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored(), _users=2) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale'), _users=2) amax_default = CallFunction(aten.amax.default, div_Tensor, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, div_Tensor, amax_default) exp_default = CallFunction(aten.exp.default, sub_Tensor, _users=2) @@ -110,7 +112,8 @@ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) -expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, permute_default_3, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) @@ -128,7 +131,7 @@ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored(), _users=2) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) @@ -154,7 +157,7 @@ sum_dim_IntList_1 = CallFunction(aten.sum.dim_IntList, mul_Tensor, Ignored(), True) fma_default = CallFunction(prims.fma.default, neg_default, sum_dim_IntList_1, mul_Tensor) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, fma_default, Ignored()) -div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, convert_element_type_default_3, KeywordArg('inv_scale')) view_default_8 = CallFunction(aten.view.default, div_Tensor_2, Ignored(), _users=2) permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) @@ -172,7 +175,8 @@ _sfdp_pattern_8_half_training = MultiOutputPattern([view_default_5, permute_default_6, permute_default_9, - permute_default_11 + permute_default_11, + None ]) @@ -187,7 +191,7 @@ view_default_1 = CallFunction(aten.view.default, clone_default_1, Ignored()) bmm_default = CallFunction(aten.bmm.default, view_default, view_default_1) view_default_2 = CallFunction(aten.view.default, bmm_default, Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, view_default_2, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, view_default_2, KeywordArg('inv_scale')) convert_element_type_default = CallFunction(prims.convert_element_type.default, div_Tensor, Ignored(), _users=2) amax_default = CallFunction(aten.amax.default, convert_element_type_default, Ignored(), True) sub_Tensor = CallFunction(aten.sub.Tensor, convert_element_type_default, amax_default) diff --git a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py index 3bf77120e836a..599851e7af012 100644 --- a/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py +++ b/torch/_inductor/fx_passes/serialized_patterns/_sfdp_pattern_9.py @@ -34,7 +34,7 @@ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, KeywordArg('inv_scale')) expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) @@ -56,7 +56,8 @@ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored(), _users=2) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) -expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, permute_default_3, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored(), _users=2) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) @@ -65,8 +66,7 @@ view_default_6 = CallFunction(aten.view.default, KeywordArg('tangents_1'), Ignored(), _users=2) permute_default_4 = CallFunction(aten.permute.default, view_default_4, Ignored()) bmm_default_2 = CallFunction(aten.bmm.default, view_default_6, permute_default_4) -convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, bmm_default_2, Ignored()) -view_default_7 = CallFunction(aten.view.default, convert_element_type_default_1, Ignored()) +view_default_7 = CallFunction(aten.view.default, bmm_default_2, Ignored()) convert_element_type_default_2 = CallFunction(prims.convert_element_type.default, view_default_7, Ignored()) convert_element_type_default_3 = CallFunction(prims.convert_element_type.default, gt_Scalar, Ignored()) mul_Tensor_2 = CallFunction(aten.mul.Tensor, convert_element_type_default_3, Ignored()) @@ -78,7 +78,7 @@ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) -div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, KeywordArg('inv_scale')) permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) @@ -88,17 +88,19 @@ permute_default_10 = CallFunction(aten.permute.default, view_default_3, Ignored()) bmm_default_5 = CallFunction(aten.bmm.default, permute_default_10, view_default_6) view_default_11 = CallFunction(aten.view.default, bmm_default_5, Ignored()) -permute_default_11 = CallFunction(aten.permute.default, view_default_11, Ignored()) +convert_element_type_default_4 = CallFunction(prims.convert_element_type.default, view_default_11, Ignored()) +permute_default_11 = CallFunction(aten.permute.default, convert_element_type_default_4, Ignored()) _sfdp_pattern_9_training = MultiOutputPattern([view_default_5, permute_default_6, permute_default_9, permute_default_11, + None, None ]) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, KeywordArg('inv_scale')) expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) view_default = CallFunction(aten.view.default, clone_default, Ignored()) @@ -118,7 +120,8 @@ expand_default_2 = CallFunction(aten.expand.default, convert_element_type_default, Ignored()) view_default_3 = CallFunction(aten.view.default, expand_default_2, Ignored()) permute_default_3 = CallFunction(aten.permute.default, KeywordArg('value'), Ignored()) -expand_default_3 = CallFunction(aten.expand.default, permute_default_3, Ignored()) +convert_element_type_default_1 = CallFunction(prims.convert_element_type.default, permute_default_3, Ignored()) +expand_default_3 = CallFunction(aten.expand.default, convert_element_type_default_1, Ignored()) clone_default_2 = CallFunction(aten.clone.default, expand_default_3, memory_format=torch.contiguous_format) view_default_4 = CallFunction(aten.view.default, clone_default_2, Ignored()) bmm_default_1 = CallFunction(aten.bmm.default, view_default_3, view_default_4) @@ -128,7 +131,7 @@ rand_default = CallFunction(aten.rand.default, Ignored(), dtype=Ignored(), device=Ignored(), pin_memory=False) gt_Scalar = CallFunction(aten.gt.Scalar, rand_default, KeywordArg('dropout_p'), _users=2) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, KeywordArg('inv_scale')) expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) view_default = CallFunction(aten.view.default, clone_default, Ignored(), _users=2) @@ -173,7 +176,7 @@ permute_default_5 = CallFunction(aten.permute.default, view_default_1, Ignored()) bmm_default_3 = CallFunction(aten.bmm.default, view_default_8, permute_default_5) view_default_9 = CallFunction(aten.view.default, bmm_default_3, Ignored()) -div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, Ignored()) +div_Tensor_2 = CallFunction(aten.div.Tensor, view_default_9, KeywordArg('inv_scale')) permute_default_6 = CallFunction(aten.permute.default, div_Tensor_2, Ignored()) permute_default_7 = CallFunction(aten.permute.default, view_default, Ignored()) bmm_default_4 = CallFunction(aten.bmm.default, permute_default_7, view_default_8) @@ -188,12 +191,13 @@ permute_default_6, permute_default_9, permute_default_11, + None, None ]) permute_default = CallFunction(aten.permute.default, KeywordArg('query'), Ignored()) -div_Tensor = CallFunction(aten.div.Tensor, permute_default, Ignored()) +div_Tensor = CallFunction(aten.div.Tensor, permute_default, KeywordArg('inv_scale')) expand_default = CallFunction(aten.expand.default, div_Tensor, Ignored()) clone_default = CallFunction(aten.clone.default, expand_default, memory_format=torch.contiguous_format) view_default = CallFunction(aten.view.default, clone_default, Ignored()) diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index b4906a2f2ca14..c9ad30cce941d 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -141,7 +141,6 @@ def _get_dim(node: Any): ) -# noqa: W605 # ############The pattern to be optimized is######### # unbind (dim=0) # / ... \ @@ -1430,7 +1429,6 @@ def simplify_split_cat(match: Match, split_sections: list[int], dim: int): SplitCatSimplifier().simplify(match.graph, split_node, split_sections) -# noqa: W605 # ############pattern to be optimized is######### # split_node(dim=1) @@ -2003,12 +2001,15 @@ def merge_unbind_stack_aten(match: Match, *args, **kwargs): cat_dim = get_arg_value(node, 1, "dim") # check the unsqueeze nodes come from the select nodes if not all( + # pyrefly: ignore [bad-argument-type] get_arg_value(unsqueeze_node, 0, "input").target is torch.ops.aten.select for unsqueeze_node in unsqueeze_nodes ): return select_nodes = [ - get_arg_value(unsqueeze_node, 0, "input") for unsqueeze_node in unsqueeze_nodes + # pyrefly: ignore [bad-argument-type] + get_arg_value(unsqueeze_node, 0, "input") + for unsqueeze_node in unsqueeze_nodes ] parent_of_select_node = get_arg_value(select_nodes[0], 0, "input") # check the target of select_nodes are the same @@ -2041,6 +2042,7 @@ def merge_unbind_stack_aten(match: Match, *args, **kwargs): node.replace_all_uses_with(parent_of_select_node) graph.erase_node(node) for unsqueeze_node in unsqueeze_nodes: + # pyrefly: ignore [bad-argument-type] graph.erase_node(unsqueeze_node) for select_node in select_nodes: if len(select_node.users) == 0: diff --git a/torch/_inductor/fx_passes/spmd_check.py b/torch/_inductor/fx_passes/spmd_check.py new file mode 100644 index 0000000000000..b51004107a107 --- /dev/null +++ b/torch/_inductor/fx_passes/spmd_check.py @@ -0,0 +1,228 @@ +"""SPMD graph verification for overlap scheduling. + +Verifies all ranks have identical FX graph structure before collective +reordering passes. Non-SPMD graphs cause NCCL collective ordering +mismatches and hangs. +""" + +import hashlib +import logging +from collections import Counter + +import torch +from torch._inductor import config +from torch._logging import trace_structured + + +log = logging.getLogger(__name__) + + +def _compute_hash(gm: torch.fx.GraphModule) -> int | None: + """Compute a structural hash of the graph including tensor metadata. + + Uses FxGraphCachePickler(device_id_agnostic=True) to serialize + (target, val) per call_function node, capturing op targets and + FakeTensor metadata (dtype, shape, stride, etc.) with device indices + normalized to 0. + + Returns None if the graph contains unpicklable objects. + """ + from torch._inductor.codecache import BypassFxGraphCache, FxGraphCachePickler + + try: + pickler = FxGraphCachePickler(gm, device_id_agnostic=True) + data = pickler.dumps( + tuple( + (str(n.target), n.meta.get("val")) + for n in gm.graph.nodes + if n.op == "call_function" + ) + ) + digest = hashlib.blake2b(data, digest_size=8).digest() + return int.from_bytes(digest, "big", signed=True) + except BypassFxGraphCache: + # FxGraphCachePickler can't serialize certain objects: + # mkldnn tensors, BackwardState, torchbind objects, or general + # pickle failures. Skip the SPMD check gracefully. + log.warning("SPMD check: skipping, unpicklable graph objects", exc_info=True) + return None + + +def _build_diag_fingerprint( + gm: torch.fx.GraphModule, +) -> tuple[tuple[str, str | None], ...]: + """Build human-readable fingerprint for mismatch diagnostics. + + Only called on the rare mismatch path. + """ + from torch._inductor.codecache import extract_tensor_metadata_for_cache_key + + entries: list[tuple[str, str | None]] = [] + for n in gm.graph.nodes: + if n.op != "call_function": + continue + target_str = str(n.target) + val = n.meta.get("val") + entries.append( + ( + target_str, + _format_val_metadata(val, extract_tensor_metadata_for_cache_key), + ) + ) + return tuple(entries) + + +def _format_val_metadata(val: object, extract_fn: object) -> str | None: + """Format node val metadata for human-readable diagnostics.""" + if val is None: + return None + if isinstance(val, torch.Tensor): + return str(extract_fn(val)) # type: ignore[operator] + if isinstance(val, (tuple, list)): + parts = [] + for v in val: + if isinstance(v, torch.Tensor): + parts.append(str(extract_fn(v))) # type: ignore[operator] + else: + parts.append(str(type(v).__name__)) + return f"({', '.join(parts)})" + return str(type(val).__name__) + + +def spmd_check(gm: torch.fx.GraphModule) -> bool: + """Verify all ranks have identical FX graph structure (SPMD). + + Computes a structural hash (op targets + tensor metadata including + shapes, dtypes, strides) and compares across ranks. + On mismatch, emits a diagnostic report to stdout, logging, and + trace_structured. + + Returns True if graphs match (SPMD), False on mismatch. + """ + import torch.distributed as dist + + if not dist.is_initialized() or dist.get_world_size() <= 1: + return True + + structure_hash = _compute_hash(gm) + if structure_hash is None: + return True + + from torch._subclasses.fake_tensor import unset_fake_temporarily + from torch.distributed.distributed_c10d import _get_default_group + + pg = _get_default_group() + world_size = dist.get_world_size() + rank = dist.get_rank() + + with unset_fake_temporarily(): + all_hashes: list[int] = [0] * world_size + dist.all_gather_object(all_hashes, structure_hash, pg) + + if all(h == all_hashes[0] for h in all_hashes): + return True + + # Mismatch detected — build and gather diagnostic fingerprints + fingerprint = _build_diag_fingerprint(gm) + with unset_fake_temporarily(): + all_fingerprints: list[tuple[object, ...]] = [() for _ in range(world_size)] + dist.all_gather_object(all_fingerprints, fingerprint, pg) + + report = _build_mismatch_report(all_fingerprints, rank, world_size) + + print(report, flush=True) + log.warning("\n%s", report) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "inductor_spmd_graph_mismatch", + "encoding": "string", + }, + payload_fn=lambda: report, + ) + + if config.aten_distributed_optimizations.spmd_mismatch == "error": + raise RuntimeError( + "SPMD graph verification failed. " + 'Set aten_distributed_optimizations.spmd_mismatch="warn" ' + "to warn instead of fail.\n" + report + ) + + return False + + +def _entry_target(entry: object) -> str: + """Extract the target string from a fingerprint entry.""" + if isinstance(entry, tuple): + return str(entry[0]) + return str(entry) + + +def _entry_metadata(entry: object) -> str: + """Format metadata from a fingerprint entry, if present.""" + if isinstance(entry, tuple) and len(entry) >= 2: + meta = entry[1] + if meta is not None: + return f" meta={meta}" + return "" + + +def _build_mismatch_report( + all_fingerprints: list[tuple[object, ...]], + rank: int, + world_size: int, +) -> str: + """Build diagnostic report for SPMD graph mismatch.""" + lines = [ + "=" * 80, + f"SPMD GRAPH MISMATCH — rank {rank}, world_size={world_size}", + "=" * 80, + ] + + # Node count per rank + counts = [len(t) for t in all_fingerprints] + lines.append("NODE COUNTS PER RANK:") + for r in range(world_size): + marker = " <--" if counts[r] != counts[0] else "" + lines.append(f" rank {r}: {counts[r]} call_function nodes{marker}") + lines.append("") + + # Find entries that differ + ref = all_fingerprints[0] + for r in range(1, world_size): + other = all_fingerprints[r] + if other == ref: + continue + lines.append(f"DIFFS rank 0 vs rank {r}:") + + # Show first few positional differences + max_diffs = 10 + shown = 0 + for i, (a, b) in enumerate(zip(ref, other)): + if a != b and shown < max_diffs: + lines.append(f" node {i}:") + lines.append(f" rank 0: {_entry_target(a)}{_entry_metadata(a)}") + lines.append(f" rank {r}: {_entry_target(b)}{_entry_metadata(b)}") + shown += 1 + + # Also show count-based diffs for op targets + ref_targets = [_entry_target(e) for e in ref] + other_targets = [_entry_target(e) for e in other] + + ref_counts = Counter(ref_targets) + other_counts = Counter(other_targets) + only_ref = ref_counts - other_counts + only_other = other_counts - ref_counts + if only_ref: + lines.append(" Only on rank 0:") + for op, cnt in only_ref.most_common(10): + lines.append(f" {op} (x{cnt})") + if only_other: + lines.append(f" Only on rank {r}:") + for op, cnt in only_other.most_common(10): + lines.append(f" {op} (x{cnt})") + lines.append("") + + lines.append("=" * 80) + return "\n".join(lines) diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 8c63b798576bd..29a26a6c29fcf 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -24,7 +24,11 @@ from torch._decomp import get_decompositions from torch._dynamo.utils import defake, dynamo_timed from torch._library.fake_class_registry import FakeScriptObject -from torch._library.opaque_object import is_opaque_type +from torch._library.opaque_object import ( + is_opaque_reference_type, + is_opaque_type, + is_opaque_value_type, +) from torch._library.utils import get_layout_constraint_tag from torch._logging import LazyString, trace_structured from torch._prims_common import ( @@ -145,12 +149,18 @@ _post_grad_graph_counter = itertools.count() if config.is_fbcode(): + from torch._inductor.fb.triton_kernel_metadata import ( + save_triton_kernel_perf_artifact, + ) from torch._inductor.fb.utils import log_module_code else: def log_module_code(*args: Any, **kwargs: Any) -> None: pass + def save_triton_kernel_perf_artifact(*args: Any, **kwargs: Any) -> None: + pass + def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> torch.dtype | None: assert isinstance( @@ -319,7 +329,15 @@ def _get_overload_packet( prior.meta["dislike_padding"] = True # We only want to mark output nodes. So, move it after the above prior nodes process. if not config.pad_outputs and cur in extended_user_visible_nodes: - cur.meta["dislike_padding"] = True + # Reductions (ops_like_padding) produce new output buffers with + # fresh strides, so their output stride constraint is already + # enforced by allow_padding=False in as_exact_strides. Setting + # dislike_padding here would suppress input padding during + # freeze, causing a stride mismatch when an earlier lowering + # step (e.g. is_contiguous_storage_and_layout) already mutated + # the input layout to padded strides. + if op not in ops_like_padding: + cur.meta["dislike_padding"] = True def is_mkldnn_conv(node: Node) -> bool: @@ -367,8 +385,10 @@ def __init__( name: str | None = None, inputs_to_check: Sequence[int] | None = None, fx_wrapper: bool = False, + get_decomp_fn: Callable[..., dict[Any, Callable[..., Any]]] | None = None, ) -> None: super().__init__(gm) + self.get_decomp_fn = get_decomp_fn self.example_inputs = example_inputs self.layout_opt = ( layout_opt @@ -383,6 +403,7 @@ def __init__( self.const_kernel_code = const_kernel_code self.const_module = const_module self.inputs_to_check = inputs_to_check + self._defers_input_alignment = False self.extra_traceback = False # we do our own error wrapping if shape_env is None: @@ -651,6 +672,12 @@ def get_dep_size_hint(self, dep: Dep, count_bytes: bool = True) -> int: """ if (dep, count_bytes) not in self.dep_size_hint_cache: res = 0 + # Non-tensor graph inputs (TorchBindObject, OpaqueObjectState) + # have no meaningful size — skip the size computation entirely. + inp = self.graph_inputs.get(dep.name) + if isinstance(inp, ir.NonTensorObj): + self.dep_size_hint_cache[(dep, count_bytes)] = 0 + return 0 try: if ( not dep.has_unbacked_symbols() @@ -1236,6 +1263,11 @@ def placeholder( self.graph_inputs[target] = gen # type: ignore[assignment] self.graph_input_names.append(target) return gen + elif is_opaque_reference_type(type(example)): + opaque_obj = ir.OpaqueObjectState(name=target, value=example) + self.graph_inputs[target] = opaque_obj # type: ignore[assignment] + self.graph_input_names.append(target) + return opaque_obj assert isinstance(example, torch.Tensor), example # todo(chilli): We can remove the last check once we turn buffers into @@ -1309,7 +1341,12 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> ) base_name = target.name().split(".")[0] if base_name in FALLBACK_ALLOW_LIST: - make_fallback(target, warn=False, override_decomp=True) + make_fallback( + target, + warn=False, + get_decomp_fn=self.get_decomp_fn, + override_decomp=True, + ) elif config.implicit_fallbacks: error = ( MissingOperatorWithDecomp @@ -1347,7 +1384,11 @@ def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> ) decided_constraint = tag_to_layout_constraint(default_tag) - make_fallback(target, layout_constraint=decided_constraint) + make_fallback( + target, + layout_constraint=decided_constraint, + get_decomp_fn=self.get_decomp_fn, + ) elif get_decompositions([target]): # There isn't a good way to dynamically patch this in @@ -1496,7 +1537,7 @@ def get_attr( value=value.item(), dtype=value.dtype, device=value.device ) if self.can_inline_constant(value): - log.debug("Inlining constant: %s ", str(target)) + log.debug("Inlining constant: %s ", target) # tensor lowering has constant inlining logic from .lowering import tensor @@ -1510,11 +1551,11 @@ def call_module(self, target: Any, args: Any, kwargs: Any) -> NoReturn: def call_method(self, target: Any, args: Any, kwargs: Any) -> NoReturn: raise AssertionError - # pyrefly: ignore [bad-override] + @typing_extensions.override def output( self, - target: str, # type: ignore[override] - args: tuple[object], # type: ignore[override] + target: torch.fx.node.Target, + args: tuple[torch.fx.node.Argument, ...], kwargs: dict[str, object], ) -> None: result = super().output(target, args, kwargs) # type: ignore[arg-type] @@ -1522,6 +1563,10 @@ def output( # nested subgraphs can have singleton outputs result = (result,) assert isinstance(result, (tuple, list)), type(result) + result = [ + ir.OpaqueValueTypeConstant(value=x) if is_opaque_value_type(type(x)) else x + for x in result + ] assert all( isinstance( x, @@ -1537,6 +1582,8 @@ def output( ir.ShapeAsConstantBuffer, TorchBindObject, ir.OpaqueMultiOutput, + ir.OpaqueValueTypeConstant, + ir.OpaqueObjectState, ), ) for x in result @@ -1575,13 +1622,19 @@ def output( self.graph_outputs = result_correct_strides value: ir.IRNode for name, value in self.graph_inputs.items(): - if isinstance(value, TorchBindObject): - continue - assert isinstance( - value, (TensorBox, sympy.Expr, torch._inductor.ir.GeneratorState) - ), f"Unsupported inductor graph input type: {type(value)}" - if not isinstance(value, TensorBox): + if isinstance( + value, + ( + TorchBindObject, + sympy.Basic, + torch._inductor.ir.GeneratorState, + torch._inductor.ir.OpaqueObjectState, + ), + ): continue + assert isinstance(value, TensorBox), ( + f"Unsupported inductor graph input type: {type(value)}" + ) value.realize() assert isinstance(value, TensorBox) value = value.data @@ -1773,6 +1826,7 @@ def maybe_apply_channels_last_stride_order( self._realize_inputs_at_stream_boundaries(n) with ( ir.IRNode.current_origins(origins), + ir.IRNode.current_stream_idx(self._get_node_stream(n)), self.set_current_node(n), V.set_current_node(n), ): @@ -2007,6 +2061,19 @@ def maybe_apply_channels_last_stride_order( if user.op == "output": # pyrefly: ignore [missing-attribute] if isinstance(result.data.data, (Pointwise, Reduction)): + # Cheap-to-recompute nodes (0 buffer reads, e.g. + # index arithmetic or constant fills) can be + # deferred to realize_input at output processing. + # This prevents cascade materialization where + # shared constants inflate downstream read counts. + if ( + config.delay_realize_cheap_outputs + # pyrefly: ignore [missing-attribute] + and result.data.num_reads() == 0 + # pyrefly: ignore [missing-attribute] + and not result.data.has_large_inner_fn() + ): + continue result.realize() _data = result.data # type: ignore[attr-defined] @@ -2130,6 +2197,8 @@ def format_new_defs() -> str: def create_deferred_runtime_asserts( self, n: torch.fx.Node, new_unbacked_defs: OrderedSet[sympy.Symbol] ) -> None: + if config.do_not_emit_runtime_assertions: + return # [NOTE] Codegen runtime asserts in Inductor # # We need to generate runtime asserts directly in Inductor instead @@ -2441,6 +2510,7 @@ def materialize( if user_defined_kernels: real_inputs = extract_real_inputs() self.extract_autotune_inputs(real_inputs) + save_triton_kernel_perf_artifact(self) return self.codegen() else: if not self.aot_mode: diff --git a/torch/_inductor/inductor_prims.py b/torch/_inductor/inductor_prims.py index 8b0b11ed3f9fa..623edd13fb42b 100644 --- a/torch/_inductor/inductor_prims.py +++ b/torch/_inductor/inductor_prims.py @@ -8,6 +8,7 @@ import torch from torch import _prims, Tensor +from torch._utils import _get_device_index if TYPE_CHECKING: @@ -85,8 +86,8 @@ def eager_prepare_softmax(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: # the dtype, so it always faithfully produces a float32 tensor during tracing, # even if the default dtype is set to something else. random = make_prim( - "inductor_random(SymInt[] size, Tensor seed, str mode) -> Tensor", - lambda size, seed, mode: getattr(torch, mode)( + "inductor_random(SymInt[] size, Tensor seed, str mode, *, ScalarType? align_dtype=None) -> Tensor", + lambda size, seed, mode, *, align_dtype=None: getattr(torch, mode)( size, device=seed.device, dtype=torch.float32 ), doc="torch.rand()/torch.randn() using backend-specific RNG that can be fused", @@ -96,6 +97,101 @@ def eager_prepare_softmax(x: Tensor, dim: int) -> tuple[Tensor, Tensor]: lambda low, high, size, seed: torch.randint(low, high, size, device=seed.device), doc="torch.randint() using backend-specific RNG that can be fused", ) + + +def _reserve_rng_state(device: torch.device, used_offset): + """ + Reserve `used_offset` 32-bit Philox samples on the given CUDA device and + return (seed, base), where base is in Philox-4x32 units. + + This mirrors how Inductor accounts for Philox consumption so compiled + dropout kernels can reconstruct eager RNG state. + """ + dev = device if isinstance(device, torch.device) else torch.device(device) + if dev.type != "cuda": + # Only CUDA devices have Philox-based CUDAGenerator. For non-CUDA + # devices this prim should be dead code and never actually run. + return 0, 0 + + dev_index = _get_device_index(dev, optional=True) + if dev_index is None: + dev_index = torch.cuda.current_device() + + gen = torch.cuda.default_generators[dev_index] + seed_t, off_t, intra_t = torch.ops.inductor_prims.inductor_reserve_rng_state( + gen, used_offset + ) + + # NOTE: for correctness in eager, intra_t should be 0. + # Keep everything as tensor math to avoid host sync. + if intra_t.device.type != off_t.device.type: + intra = int(intra_t.item()) + base = torch.div(off_t + intra, 4, rounding_mode="floor") + else: + base = torch.div(off_t + intra_t, 4, rounding_mode="floor") + return seed_t, base + + +def _rand_eager_offset_impl(offset, device: torch.device) -> Tensor: + """ + Reserve `offset` 32-bit Philox samples and return a 1-element int64 tensor + Place-holder: will be replaced by rand_eager_offsets + In fx_passes/replace_random.py + fuse_offset_creation_pass() + """ + return torch.empty(2, dtype=torch.int64, device=device) + + +def _rand_eager_offsets_impl(offsets, device: torch.device) -> Tensor: + """ + Batched version of _rand_eager_offset_impl. For each entry in `offsets`, + reserve that many 32-bit Philox samples and return a 1D int64 tensor + containing the packed (seed, base) values for each reservation. + """ + states = [_reserve_rng_state(device, int(off)) for off in offsets] + seeds = [s for s, _ in states] + bases = [b for _, b in states] + + def _to_i64(x): + if isinstance(x, torch.Tensor): + return x + return torch.as_tensor(x, device=device, dtype=torch.int64) + + seeds_tensor = torch.stack([_to_i64(x) for x in seeds]).view(-1) + bases_tensor = torch.stack([_to_i64(x) for x in bases]).view(-1) + packed = torch.stack([seeds_tensor, bases_tensor], dim=1) + return packed + + +def _rand_eager_offsets_meta(offsets, device: torch.device): + return torch.empty((len(offsets), 2), dtype=torch.int64, device=device) + + +rand_eager_offset = make_prim( + "inductor_rand_eager_offset(SymInt offset, Device device) -> Tensor", + _rand_eager_offset_impl, + doc=( + "Reserve `offset` 32-bit Philox samples on `device` and return a " + "1-element int64 tensor containing packed (seed, base)." + ), + tags=(torch.Tag.nondeterministic_seeded,), +) + + +rand_eager_offsets = _prims._make_prim( + schema="inductor_rand_eager_offsets(SymInt[] offsets, Device device) -> Tensor", + return_type=_prims.RETURN_TYPE.NEW, + meta=_rand_eager_offsets_meta, + impl_aten=_rand_eager_offsets_impl, + doc=( + "Batched version of inductor_rand_eager_offset. For each entry in " + "`offsets`, reserves that many 64-bit Philox samples and returns " + "packed (seed, base) values." + ), + tags=(torch.Tag.nondeterministic_seeded,), +) + + force_stride_order = make_prim( "inductor_force_stride_order(Tensor input, SymInt[] stride) -> Tensor", eager_force_stride, @@ -222,14 +318,14 @@ def _low_memory_max_pool_offsets_to_indices_aten( _low_memory_max_pool_with_offsets = make_prim( - "_low_memory_max_pool_with_offsets(Tensor self, SymInt[] kernel_size, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool ceil_mode) -> (Tensor, Tensor)", # noqa: B950 + "_low_memory_max_pool_with_offsets(Tensor self, SymInt[] kernel_size, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool ceil_mode) -> (Tensor, Tensor)", _low_memory_max_pool_with_offsets_aten, return_type=(_prims.RETURN_TYPE.NEW, _prims.RETURN_TYPE.NEW), doc="Instead of returning indices, returns indices offsets.", ) _low_memory_max_pool_offsets_to_indices = make_prim( - "_low_memory_max_pool_offsets_to_indices(Tensor self, SymInt[] kernel_size, SymInt[] input_size, SymInt[] stride, SymInt[] padding, SymInt[] dilation) -> Tensor", # noqa: B950 + "_low_memory_max_pool_offsets_to_indices(Tensor self, SymInt[] kernel_size, SymInt[] input_size, SymInt[] stride, SymInt[] padding, SymInt[] dilation) -> Tensor", _low_memory_max_pool_offsets_to_indices_aten, doc="Convert small int offsets to regular indices.", ) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 12204fc70cf9e..456bba988bf3b 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -43,7 +43,7 @@ from torch._inductor import metrics from torch._inductor.utils import get_free_symbols from torch._library.fake_class_registry import FakeScriptObject -from torch._library.opaque_object import is_opaque_value +from torch._library.opaque_object import get_opaque_obj_repr, is_opaque_value from torch._prims_common import ( compute_required_storage_length, is_boolean_dtype, @@ -121,6 +121,7 @@ if TYPE_CHECKING: from torch.fx.experimental.symbolic_shapes import SympyBoolean from torch.fx.node import Argument + from torch.types import IntLikeType from .codegen.cutlass.template import CUTLASSTemplate from .codegen.wrapper import PythonWrapperCodegen @@ -419,7 +420,7 @@ def is_triton(x: IRNode | torch.device | None | str) -> bool: # Special case cpu and cuda as using the method below # to determine if the scheduler is a triton scheduler subclass # requires instantiating a scheduler for them - if device in ["cpu", "cuda"]: + if device in ["cpu", "cuda", "xpu"]: if getattr(config, f"{device}_backend") == "triton": return True return False @@ -554,6 +555,7 @@ class IRNode: """ _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet() + _current_stream_idx: ClassVar[int | None] = None # NB: These are kinda weird, origins: OrderedSet[Any] = dataclasses.field(init=False) @@ -562,6 +564,8 @@ class IRNode: origin_node: torch.fx.Node | None = dataclasses.field(init=False) # Annotations dict for storing metadata (e.g., KernelTemplateChoice) annotations: dict[str, Any] = dataclasses.field(init=False) + # User-annotated stream index from FX node metadata (set during lowering) + stream_idx: int | None = dataclasses.field(init=False) @staticmethod @contextlib.contextmanager @@ -573,6 +577,18 @@ def current_origins(origins: OrderedSet[Node]) -> Generator[None, None, None]: finally: IRNode._current_origins = old + @staticmethod + @contextlib.contextmanager + def current_stream_idx( + stream_idx: int | None, + ) -> Generator[None, None, None]: + old = IRNode._current_stream_idx + IRNode._current_stream_idx = stream_idx + try: + yield + finally: + IRNode._current_stream_idx = old + @staticmethod def is_realized_node(node: IRNode) -> bool: return isinstance( @@ -604,6 +620,7 @@ def __post_init__(self) -> None: self._post_init_setattr("origin_node", None) # Annotations dict for storing metadata (e.g., KernelTemplateChoice) self._post_init_setattr("annotations", {}) + self._post_init_setattr("stream_idx", self._current_stream_idx) def get_read_names(self) -> OrderedSet[str]: return OrderedSet(dep.name for dep in self.get_reads()) @@ -888,6 +905,10 @@ def get_origins(self) -> OrderedSet[Any]: assert hasattr(self, "origins") return self.origins + def get_stream_idx(self) -> int | None: + assert hasattr(self, "stream_idx") + return self.stream_idx + def get_operation_name(self) -> str: assert self.operation_name is not None return self.operation_name @@ -1339,10 +1360,23 @@ def num_splits( if not V.graph.sizevars.all_unbacked_explicitly_hinted(exprs): return ReductionHint.DEFAULT, 1 reduction_numel_hint = V.graph.sizevars.optimization_hint(reduction_numel) - numel_hint = V.graph.sizevars.optimization_hint(sympy_product(ranges)) + numel = sympy_product(ranges) + numel_hint = V.graph.sizevars.optimization_hint(numel) + + # The Triton backend adds REDUCE_TO_SINGLE_ELEMENT unconditionally if the + # cooperative_reductions feature flag is enabled, but we should still use a + # split scan if we don't actually do a cooperative reduction. + should_reduce_to_single_element = V.graph.has_feature( + device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT + ) and ( + not is_triton(device) + or V.choices.should_use_cooperative_reduction( + device, numel, reduction_numel + ) + ) should_split = reduction_type == "scan" or ( - not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT) + not should_reduce_to_single_element and reduction_type not in ( "argmax", @@ -1855,7 +1889,7 @@ def _multilayer_wrap_loader( reduction_ranges, [reduction_numel], dense_index ) need_mask = not V.graph.sizevars.statically_known_true( - sympy.Eq(reduction_numel % split, 0) + sympy.Eq(Mod(reduction_numel, split), 0) ) def wrapper_fn( @@ -2326,7 +2360,7 @@ def create_multilayer( # type: ignore[override] """ reduction_numel = sympy_product(reduction_ranges) need_mask = not V.graph.sizevars.statically_known_true( - sympy.Eq(reduction_numel % split, 0) + sympy.Eq(Mod(reduction_numel, split), 0) ) if need_mask and reduction_type != "welford_combine": @@ -2708,13 +2742,18 @@ def create( # type: ignore[override] sizevars = V.graph.sizevars sort_numel = sizevars.simplify(sympy_product(sort_ranges)) - # Heuristic, smallest rblock where triton usually outperforms aten.sort + # Heuristic, smallest rblock where triton usually outperforms aten.sort. # It also isn't bandwidth bound so fusion is unlikely to help. - max_rblock = 512 - is_persistent_kernel = ( - config.triton.persistent_reductions - and sizevars.statically_known_true(sympy.Le(sort_numel, max_rblock)) - ) + # When decompose_sort_ops is enabled, skip the size limit to always + # attempt Triton sort (index dtype is widened to int32 in lowering). + if config.triton.decompose_sort_ops: + is_persistent_kernel = config.triton.persistent_reductions + else: + max_rblock = 512 + is_persistent_kernel = ( + config.triton.persistent_reductions + and sizevars.statically_known_true(sympy.Le(sort_numel, max_rblock)) + ) if not is_persistent_kernel: # We only support persistent triton kernels return [None] * len(dtypes) @@ -2780,12 +2819,32 @@ def is_contiguous_storage_and_layout(x: IRNode) -> bool: # pad the stride here so we will NOT claim an tensor as contiguous # if a padding is gonna happen. if layout.should_pad_strides(): - layout.pad_strides() + assert isinstance(layout, FlexibleLayout), type(layout) + layout = FixedLayout( + layout.device, + layout.dtype, + layout.size, + layout._pad_strides(layout.stride, layout.size, layout.dtype), + layout.offset, + layout.is_pinned, + ) return layout.is_contiguous() except NotImplementedError: return False +def is_dense_contiguous_storage_and_layout(x: IRNode) -> bool: + try: + _buffer, layout = as_storage_and_layout(x, freeze=False) + if not layout.is_contiguous(): + return False + return V.graph.sizevars.statically_known_equals( + layout.storage_size(), layout.offset + sympy_product(layout.size) + ) + except NotImplementedError: + return False + + def as_storage_and_layout( x: IRNode, freeze: bool = True, @@ -3252,12 +3311,15 @@ def create(cls, x: IRNode, new_size: Sequence[Expr]) -> IRNode: # type: ignore[ len(free_unbacked_symbols(old_size)) > 0 or len(free_unbacked_symbols(new_size)) > 0 ) - is_contiguous = is_contiguous_storage_and_layout(x) + is_contiguous = is_dense_contiguous_storage_and_layout(x) def create_reinterpret_view( inp: IRNode, new_size: Sequence[Expr], new_stride: Sequence[Expr] ) -> ReinterpretView: - storage, old_layout = as_storage_and_layout(inp, want_contiguous=True) + inp = ExternKernel.require_exact_strides( + inp, FlexibleLayout.contiguous_strides(inp.get_size()) + ) + storage, old_layout = as_storage_and_layout(inp) new_layout = FixedLayout( old_layout.device, old_layout.dtype, @@ -3639,6 +3701,11 @@ def loader(idx: Sequence[Expr]) -> OpsValue: class SliceView(View): + """View that represents a slice along a single dimension. + + Corresponds to tensor[..., start:end:step, ...]. + """ + @classmethod def normalize_start_end( cls, x: IRNode, dim: int, start: int, end: int @@ -3653,6 +3720,14 @@ def normalize_start_end( if any(free_unbacked_symbols(x) for x in (start, end, dim_size)): min_func = sympy.Min max_func = sympy.Max + elif any( + # Only needed when backed_size_oblivious is on. + x.has(sympy.Min, sympy.Max) + for x in (start, end, dim_size) + if isinstance(x, Expr) + ): + min_func = sympy.Min + max_func = sympy.Max else: min_func = sizevars.evaluate_min max_func = sizevars.evaluate_max @@ -5204,6 +5279,16 @@ def constant_to_device(self, device: torch.device) -> IRNode: return self.data.constant_to_device(device) +@dataclasses.dataclass(frozen=True) +class FinalizeCodegenResult: + """Structured result from TemplateBuffer._finalize_codegen for external backends.""" + + source: str + imports: list[str] + call_preamble: list[str] + call_args: list[str] + + class TemplateBuffer(OperationBuffer): """ Base class for template operators that support epilogue and prologue fusion. @@ -5218,6 +5303,7 @@ def __init__( make_kernel_render: Callable[..., Any] | None, mutated_inputs: Iterable[IRNode] | None = None, allowed_prologue_inps: OrderedSet[str] | None = None, + named_inputs: dict[str, IRNode] | None = None, ) -> None: super().__init__(name=None, layout=layout) self.inputs = InputsKernel.unwrap_storage(inputs) @@ -5227,6 +5313,19 @@ def __init__( # Annotations dict for storing metadata (e.g., KernelTemplateChoice) self.annotations: dict[str, Any] = {} + # Output buffer names eligible for epilogue fusion. + # Maps buffer name → kernel parameter name (e.g. "buf3" → "result"). + self.epilogue_fusable_outputs: dict[str, str] = {} + # For multi-output kernels: maps child buffer name → MultiOutput + # node. Used by call_kernel to emit tuple-unpacking lines. + self._multi_output_children: dict[str, MultiOutput] = {} + # Maps kernel parameter name → IRNode for each tensor input. + # Used by ExternalTritonTemplateKernel to set up prologue fusion and + # by HelionTemplateBuffer to resolve call arguments. + self._named_inputs: dict[str, IRNode] = ( + dict(named_inputs) if named_inputs else {} + ) + # Inputs that the kernel mutates in-place self.mutated_inputs = mutated_inputs self.mutation_outputs: list[MutationOutput] = [] @@ -5242,6 +5341,18 @@ def __init__( self.allowed_prologue_inps: OrderedSet[str] = ( allowed_prologue_inps or OrderedSet() ) + # Per-template fusion overrides. None means fall back to global + # config.epilogue_fusion / config.prologue_fusion. + self.allow_epilogue_fusion: bool | None = None + self.allow_prologue_fusion: bool | None = None + + @property + def dtype(self) -> torch.dtype: + if isinstance(self.layout, MultiOutputLayout): + raise NotImplementedError( + "Multi-output templates do not have a single dtype" + ) + return self.get_layout().dtype def get_read_writes(self) -> dependencies.ReadWrites: return self.extract_read_writes(normalize=True) @@ -5265,6 +5376,30 @@ def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any: return reads def extract_read_writes(self, normalize: bool = False) -> dependencies.ReadWrites: + """Extract read/write dependencies for this TemplateBuffer. + + When the layout is MultiOutputLayout (multi-output templates), the + buffer itself has no data layout, so we cannot build an indexer. + Instead, synthesize a trivial write dep and derive read deps from + the named tensor inputs (``_named_inputs``). For single-output + templates with a concrete layout, fall through to the standard path. + """ + if isinstance(self.layout, MultiOutputLayout): + writes: OrderedSet[dependencies.Dep] = OrderedSet( + [ + dependencies.MemoryDep( + self.get_name(), sympy.Integer(0), var_names=(), size=() + ), + ] + ) + return dependencies.ReadWrites( + reads=self._read_deps_from_inputs(normalize), + writes=writes, + index_exprs=OrderedSet(), + range_vars=None, + var_ranges=None, + ) + name = self.get_name() indexer = self.get_layout().make_indexer() @@ -5307,6 +5442,76 @@ def is_multi_outputs_template(self) -> bool: def get_allowed_prologue_inps(self) -> OrderedSet[str]: return self.allowed_prologue_inps + def _finalize_codegen( + self, hook_outputs: dict[str, str] + ) -> FinalizeCodegenResult | None: + """Called after epilogue/prologue subgraph codegen with rendered hook outputs. + + ``hook_outputs`` maps placeholder keys (e.g. ````, + ````) to the Triton code generated by Inductor for + each fused subgraph. + + Return a ``FinalizeCodegenResult`` to provide custom source code and + call metadata, or ``None`` to use the default codegen path. + """ + return None + + @classmethod + def realize_template_input(cls, tb: TensorBox) -> IRNode: + """Realize a TensorBox, preserving MultiOutput layout (unlike ExternKernel.realize_input).""" + if isinstance(tb, TensorBox) and isinstance(tb.data, MultiOutput): + return tb.data + result = ExternKernel.realize_input(tb) + if isinstance(result, StorageBox): + result = result.data + if isinstance(result.layout, FlexibleLayout): # type: ignore[union-attr] + result.freeze_layout() + return result + + @classmethod + def build_multi_outputs( + cls, + template_buf: TemplateBuffer, + structured: object, + *, + direct_alias_at_leaf: dict[int, IRNode] | None = None, + on_tensor_leaf: Callable[[str, MultiOutput, list[tuple[type, int]], int], None] + | None = None, + on_non_tensor_leaf: Callable[[int], None] | None = None, + ) -> tuple[TensorBox, ...]: + """Walk a structured output tree, creating MultiOutput nodes for tensor leaves.""" + seen_outputs: dict[int, TensorBox] = {} + leaf_counter = itertools.count() + + def walk(output: object, indices: list[tuple[type, int]]) -> list[TensorBox]: + if isinstance(output, (list, tuple)): + results: list[TensorBox] = [] + for i, item in enumerate(output): + results.extend(walk(item, [*indices, (type(output), i)])) + return results + leaf_idx = next(leaf_counter) + if isinstance(output, torch.Tensor): + if direct_alias_at_leaf and leaf_idx in direct_alias_at_leaf: + return [TensorBox.create(direct_alias_at_leaf[leaf_idx])] + tid = id(output) + if tid in seen_outputs: + return [seen_outputs[tid]] + mo = MultiOutput( + FallbackKernel.tensor_to_layout(output), template_buf, indices + ) + template_buf._multi_output_children[mo.get_name()] = mo + if on_tensor_leaf is not None: + on_tensor_leaf(mo.get_name(), mo, indices, leaf_idx) + tb = TensorBox(mo) + seen_outputs[tid] = tb + return [tb] + # Non-tensor leaf (int, SymInt, None, etc.) + if on_non_tensor_leaf is not None: + on_non_tensor_leaf(leaf_idx) + return [] + + return tuple(walk(structured, [])) + class TritonTemplateBuffer(TemplateBuffer): def __init__( @@ -5333,6 +5538,8 @@ def __init__( mutated_inputs=mutated_inputs, allowed_prologue_inps=allowed_prologue_inps, ) + assert self.name is not None + self.epilogue_fusable_outputs = {self.name: self.name} self.subgraph_inps: list[IRNode | sympy.Expr | None] | None = None self.subgraph_outs: list[IRNode | None] | None = None @@ -5878,18 +6085,24 @@ def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox: break any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs) fx_node_args = V.graph.current_node.args[0] - assert isinstance(fx_node_args, list), type(fx_node_args) # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output - if any_input_is_storage_and_layout is False and any( - # pyrefly: ignore [missing-attribute] - "val" in arg.meta - and ( - # pyrefly: ignore [missing-attribute] - arg.meta["val"].is_contiguous(memory_format=torch.channels_last) + # Skip this check when fx_node_args is not a list (e.g., called from _pad_as_cat). + if ( + any_input_is_storage_and_layout is False + and isinstance(fx_node_args, list) + and any( # pyrefly: ignore [missing-attribute] - or arg.meta["val"].is_contiguous(memory_format=torch.channels_last_3d) + "val" in arg.meta + and ( + # pyrefly: ignore [missing-attribute] + arg.meta["val"].is_contiguous(memory_format=torch.channels_last) + # pyrefly: ignore [missing-attribute] + or arg.meta["val"].is_contiguous( + memory_format=torch.channels_last_3d + ) + ) + for arg in fx_node_args ) - for arg in fx_node_args ): output_stride = make_channels_last_strides_for(new_size) @@ -6239,24 +6452,58 @@ def process_kernel( Callable[[Any, Any], Any], dict[sympy.Symbol, pytree.KeyPath] | None, ]: + """Partition kernel args into tensor and non-tensor, realize tensor inputs, + re-run fake tensor propagation with the realized strides, and return + (example_output, tensor_args, non_tensor_args, unflatten_args, unbacked_bindings). + + unflatten_args(new_tensor_args, new_non_tensor_args) reconstructs the + original (args, kwargs) tree from replacement lists. + """ binded_args = {"args": args, "kwargs": kwargs} args_flat, args_spec = pytree.tree_flatten(binded_args) - is_arg_tensor = [] + args_flat_is_tensor: list[bool] = [] # tensor_args can be either tensor or torchbind objects - tensor_args = [] - non_tensor_args: list[Any] = [] + tensor_args: list[IRNode] = [] + non_tensor_args: list[object] = [] + real_non_tensor_args: list[ + FakeScriptObject + | torch._C.Generator + | torch._C.ScriptObject + | torch.Tensor + | IntLikeType + ] = [] for arg in args_flat: - is_arg_tensor.append( - isinstance(arg, IRNode) and not isinstance(arg, GeneratorState) - ) - if is_arg_tensor[-1]: - tensor_args.append(arg) - else: - if isinstance(arg, Expr): - arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) - non_tensor_args.append(arg) + match arg: + case Expr(): + node = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) + args_flat_is_tensor.append(False) + non_tensor_args.append(node) + real_non_tensor_args.append(node) + + case GeneratorState(): + args_flat_is_tensor.append(False) + non_tensor_args.append(arg) + device_index = arg.device.index + assert arg.device.type == "cuda" and device_index is not None + real_non_tensor_args.append( + torch.cuda.default_generators[device_index].clone_state() + ) + + case OpaqueObjectState(): + args_flat_is_tensor.append(False) + non_tensor_args.append(arg) + real_non_tensor_args.append(arg.value) + + case IRNode(): + args_flat_is_tensor.append(True) + tensor_args.append(arg) + + case _: + args_flat_is_tensor.append(False) + non_tensor_args.append(arg) + real_non_tensor_args.append(arg) def unflatten_args( new_tensor_args: Sequence[_T], new_non_tensor_args: Sequence[_T] @@ -6264,7 +6511,7 @@ def unflatten_args( result = [] it_tensors = iter(new_tensor_args) it_non_tensors = iter(new_non_tensor_args) - for is_tensor in is_arg_tensor: + for is_tensor in args_flat_is_tensor: if is_tensor: result.append(next(it_tensors)) else: @@ -6313,7 +6560,7 @@ def unflatten_args( else: example_args.append(ir_node_to_tensor(x)) - new_args, new_kwargs = unflatten_args(example_args, non_tensor_args) + new_args, new_kwargs = unflatten_args(example_args, real_non_tensor_args) example_output = kernel(*new_args, **new_kwargs) unbacked_bindings: dict[sympy.Symbol, pytree.KeyPath] | None = None @@ -6482,6 +6729,7 @@ def require_strides( exact_strides: Sequence[_IntLike] | None = None, allow_padding: bool = False, ) -> IRNode: + """Ensure x has the requested stride order or exact strides, inserting a copy if needed.""" assert order is not None or exact_strides is not None # Layout generally doesn't matter, but some consuming external ops might have requirements if x.get_numel() in (0, 1) and not exact_strides: @@ -6530,7 +6778,17 @@ def require_strides( exact_strides=exact_strides, ) return x - elif isinstance(x.get_layout(), (FixedLayout, NonOwningLayout)) and ( + + # When padding is allowed, check if the buffer's existing strides + # match padded versions of the requested strides (e.g. concat graph + # outputs that were already padded by ConcatKernel). + padded_exact_strides = None + if allow_padding and exact_strides: + padded_exact_strides = list( + Layout._pad_strides(exact_strides, x.get_size(), x.get_dtype()) + ) + + if isinstance(x.get_layout(), (FixedLayout, NonOwningLayout)) and ( (order and x.get_layout().is_stride_ordered(order)) or ( exact_strides @@ -6544,6 +6802,15 @@ def require_strides( if exact_strides is not None else x ) + # Accept already-padded buffers when padding is allowed + elif ( + padded_exact_strides is not None + and isinstance(x.get_layout(), (FixedLayout, NonOwningLayout)) + and significant_strides_equal( + padded_exact_strides, x.get_layout().stride, x.get_size() + ) + ): + return try_match_insignificant_strides(x, padded_exact_strides) elif isinstance( (mutation_layout := x.get_layout()), MutationLayoutSHOULDREMOVE ): @@ -7545,7 +7812,7 @@ def __init__( # pyrefly: ignore [missing-attribute] self.kernel_src = kernel.src self.kernel_ast = ast.parse(self.kernel_src) - self.kernel_stores = identify_triton_stores(self.kernel_ast) + self.kernel_stores = identify_triton_stores(self.kernel_src) self.kernel_args = kernel_args # names in `arg_accesses.read_writes` are names of formal arguments in the kernel's prototype self.arg_accesses = identify_accessed_tensors( @@ -7554,8 +7821,12 @@ def __init__( tma_descriptor_metadata, ) + # Filter to only tensor args: with Triton 3.7+, ordered_arg_names + # includes scalars, so writes may reference non-tensor args like SymInts. self.mutable_args = [ - kernel_args[key.name] for key in self.arg_accesses.read_writes.writes + kernel_args[key.name] + for key in self.arg_accesses.read_writes.writes + if isinstance(kernel_args.get(key.name), TensorBox) ] self.mutation_outputs = [ @@ -8622,32 +8893,39 @@ def create(cls, kernel: _OpOverloads, *args: Any, **kwargs: Any) -> FallbackKern ) = cls.process_kernel(kernel, *args, **kwargs) # Try to lower single output functional custom ops to their out-variant. - if isinstance(kernel, torch._ops.OpOverload): + if ( + isinstance(kernel, torch._ops.OpOverload) + and not torch._library.utils.is_builtin(kernel) + and isinstance(example_output, torch.Tensor) + ): from torch._library._out_variant import ( _is_functional, get_out_arg_names, + lookup_manual_out_variant, to_out_variant, ) - if _is_functional(kernel._schema) and isinstance( - example_output, torch.Tensor - ): + out_op = None + if _is_functional(kernel._schema): out_op = to_out_variant(kernel) - if out_op is not None and len(get_out_arg_names(out_op)) == 1: - layout = FixedLayout( - device=example_output.device, - dtype=example_output.dtype, - size=[*example_output.shape], - stride=[*example_output.stride()], - ) - return ExternKernelOut( # type: ignore[return-value] - layout=layout, - inputs=list(tensor_args), - constant_args=list(non_tensor_args), - kwargs=kwargs, - python_kernel_name=_make_out_variant_kernel_name(out_op), - op_overload=out_op, - ) + if out_op is None: + out_op = lookup_manual_out_variant(kernel) + + if out_op is not None and len(get_out_arg_names(out_op)) == 1: + layout = FixedLayout( + device=example_output.device, + dtype=example_output.dtype, + size=[*example_output.shape], + stride=[*example_output.stride()], + ) + return ExternKernelOut( # type: ignore[return-value] + layout=layout, + inputs=list(tensor_args), + constant_args=list(non_tensor_args), + kwargs=kwargs, + python_kernel_name=_make_out_variant_kernel_name(out_op), + op_overload=out_op, + ) # We need this extra check for input alignment since the example # inputs we created are always aligned. @@ -8662,9 +8940,10 @@ def create(cls, kernel: _OpOverloads, *args: Any, **kwargs: Any) -> FallbackKern ): device = torch.device("cpu") - # Try multi-output .out() lowering for ops with out_variant tag. + # Try multi-output .out() lowering for custom ops with the out tag. if ( isinstance(kernel, torch._ops.OpOverload) + and not torch._library.utils.is_builtin(kernel) and not V.graph.cpp_wrapper and device ): @@ -8914,6 +9193,10 @@ def __init__( super().__init__(layout, input, indices, skip_size_stride_alignment_checks=True) self.opaque_example_value = opaque_value + @property # type: ignore[override] + def dtype(self) -> Never: + raise AttributeError("OpaqueMultiOutput has no dtype") + def wrap_for_lowering(self) -> OpaqueMultiOutput: return self @@ -9273,6 +9556,7 @@ def realize(self) -> str | None: self.data.origins = self.origins self.data.origin_node = origin_node self.data.traceback = traceback + self.data.stream_idx = self.data.data.stream_idx return self.data.name def realize_hint(self) -> None: @@ -9423,7 +9707,9 @@ def create( new_operands: list[IRNode] = [] for idx, operand in enumerate(operands): - if isinstance(operand, (ShapeAsConstantBuffer, GeneratorState)): + if isinstance( + operand, (ShapeAsConstantBuffer, GeneratorState, OpaqueObjectState) + ): new_operands.append(operand) else: new_operands.append( @@ -10151,6 +10437,27 @@ def get_buf_bytes(self) -> int: return functools.reduce(operator.add, flat_sizes, 0) +@ir_dataclass +class OpaqueValueTypeConstant(NonTensorObj): + """IR node for opaque value type constants that appear directly in graph outputs. + + Unlike TorchBindObject (which references named constants loaded at runtime), + this inlines the value's repr into the generated code since value types are + reconstructed from their repr. + """ + + value: Any + + def get_name(self) -> str: + return repr(self.value) + + def codegen_reference(self, writer: IndentedBuffer | None = None) -> str: + obj_repr, opaque_types = get_opaque_obj_repr(self.value) + for n, t in opaque_types.items(): + V.graph.opaque_value_type_classes[n] = t + return obj_repr + + @ir_dataclass class GeneratorState(NonTensorObj): name: str @@ -10163,6 +10470,24 @@ def codegen_reference(self, writer: IndentedBuffer | None = None) -> str: return self.name +@ir_dataclass +class OpaqueObjectState(NonTensorObj): + """ + Represents an opaque object (e.g., ProcessGroup) that is passed through + as a graph input. Similar to GeneratorState, this wraps the object with + its placeholder name so codegen can reference it properly. + """ + + name: str + value: Any # The actual opaque object (for reference, not used in codegen) + + def get_name(self) -> str: + return self.name + + def codegen_reference(self, writer: IndentedBuffer | None = None) -> str: + return self.name + + class _CollectiveKernel(FallbackKernel): def should_allocate(self) -> bool: return False @@ -10190,7 +10515,7 @@ def set_cpp_kernel_name(self, cpp_kernel_name: str | None = None) -> None: # Between the initiation and completion of an in-place collective, the # input buffers are subject to both volatile reads and volatile writes. # They must not be read, written to or reused by another kernel. To ensure - # the constraints, we model collective -> wait_tensor as as two-step + # the constraints, we model collective -> wait_tensor as a two-step # mutation of the input buffers. @classmethod def create_inplace( diff --git a/torch/_inductor/kernel/bmm.py b/torch/_inductor/kernel/bmm.py index 902ffad2fc16c..789086627bf45 100644 --- a/torch/_inductor/kernel/bmm.py +++ b/torch/_inductor/kernel/bmm.py @@ -10,6 +10,7 @@ from .. import config as inductor_config, ir, lowering as L from ..kernel_inputs import MMKernelInputs from ..lowering import lowerings, make_pointwise, make_reduction, transform_args +from ..runtime.runtime_utils import get_max_y_grid from ..select_algorithm import ( autotune_select_algorithm, ExternKernelChoice, @@ -43,8 +44,14 @@ @SymbolicGridFn -def bmm_grid(b, m, n, meta, *, cdiv): - return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), b, 1) +def bmm_grid(b, m, n, meta, *, cdiv, max): + tiles = cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]) + # Split batch across grid_y and grid_z to avoid exceeding CUDA grid_y limit. + # When b <= max_y_grid, grid_z = 1 and behavior is identical to the original. + max_y_grid = get_max_y_grid() + grid_z = max(cdiv(b, max_y_grid), 1) + grid_y = cdiv(b, grid_z) + return (tiles, grid_y, grid_z) # We define each template kernel in a separate file which is the name of the input to load_kernel_template @@ -69,6 +76,17 @@ def bmm_grid(b, m, n, meta, *, cdiv): ) +def _has_broadcast_batch_dim(mat1, mat2): + """Check if either input has a broadcast batch dimension (stride=0). + + The Triton bmm template can trigger CUDA IMA during autotuning with + stride-0 inputs; the aten bmm fallback handles broadcast correctly. + """ + return V.graph.sizevars.statically_known_equals( + mat1.get_stride()[0], 0 + ) or V.graph.sizevars.statically_known_equals(mat2.get_stride()[0], 0) + + @L.register_lowering(aten.bmm) def tuned_bmm(mat1, mat2, out_dtype=None, *, layout=None): """ @@ -179,11 +197,9 @@ def _to_dtype(x): templates_to_use.append(aten_handler) kwarg_overrides[aten_handler.uid] = aten_extra_kwargs - if use_triton_template(layout, check_max_autotune=False) and ( - out_dtype is None or out_dtype == mat1.get_dtype() - ): - # TODO: add out_dtype support for Triton Template - templates_to_use.append(bmm_template) + if use_triton_template(layout, check_max_autotune=False): + if not _has_broadcast_batch_dim(mat1, mat2): + templates_to_use.append(bmm_template) # Single unified call for all templates choices.extend( @@ -279,7 +295,8 @@ def tuned_baddbmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): templates_to_use.append(aten_baddbmm) if use_triton_template(layout, check_max_autotune=False): - templates_to_use.append(bmm_template) + if not _has_broadcast_batch_dim(mat1, mat2): + templates_to_use.append(bmm_template) # Single unified call for all templates choices.extend( diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index c2c2f85064103..dbc10b26af44d 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -497,7 +497,7 @@ def autotune_custom_op( choices = template.generate_custom_op_choices( name=name, decompositions=decompositions, - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] input_nodes=list(inputs), non_tensor_args=non_tensor_args, input_gen_fns=input_gen_fns if input_gen_fns else None, diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index fa69b51c87577..8c1523591c1ef 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -15,6 +15,7 @@ import torch from torch._inductor.virtualized import V from torch.nn.attention.flex_attention import _Backend +from torch.utils._sympy.functions import FloorDiv, Mod from ...ir import ComputedBuffer, ExternKernel, FixedLayout, TensorBox from ...lowering import empty, empty_strided, lowerings, register_lowering, to_dtype @@ -318,10 +319,16 @@ def flex_attention( B = Bq - if seq_len_q % 128 != 0 or seq_len_kv % 128 != 0: - kernel_options.setdefault("IS_DIVISIBLE", False) - else: + seq_q_divisible = V.graph.sizevars.statically_known_true( + sympy.Eq(Mod(seq_len_q, 128), 0) + ) + seq_kv_divisible = V.graph.sizevars.statically_known_true( + sympy.Eq(Mod(seq_len_kv, 128), 0) + ) + if seq_q_divisible and seq_kv_divisible: kernel_options.setdefault("IS_DIVISIBLE", True) + else: + kernel_options.setdefault("IS_DIVISIBLE", False) # NB it is okay that the v_head_dim is different # We are using these to match fill order of the output. @@ -353,7 +360,7 @@ def flex_attention( kernel_options.setdefault("SM_SCALE", scale) # Determine GQA broadcast factor. - gqa_shared_heads = Hq // Hkv + gqa_shared_heads = FloorDiv(Hq, Hkv) kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) # Inside of Triton kernel, only apply partial masking if partial blocks are computed. @@ -704,8 +711,16 @@ def flex_attention_backward(*args, **kwargs): for k, v in kernel_options.items() } kernel_options.setdefault("FLOAT32_PRECISION", get_float32_precision()) - seq_q_divisible = V.graph.sizevars.statically_known_true(seq_len_q % 128 == 0) - seq_kv_divisible = V.graph.sizevars.statically_known_true(seq_len_kv % 128 == 0) + kernel_options.setdefault("PRESCALE_QK", False) + kernel_options.setdefault("ROWS_GUARANTEED_SAFE", False) + kernel_options.setdefault("BLOCKS_ARE_CONTIGUOUS", False) + kernel_options.setdefault("WRITE_DQ", True) + seq_q_divisible = V.graph.sizevars.statically_known_true( + sympy.Eq(Mod(seq_len_q, 128), 0) + ) + seq_kv_divisible = V.graph.sizevars.statically_known_true( + sympy.Eq(Mod(seq_len_kv, 128), 0) + ) if seq_q_divisible and seq_kv_divisible: kernel_options.setdefault("IS_DIVISIBLE", True) else: @@ -865,7 +880,7 @@ def flex_attention_backward(*args, **kwargs): kernel_options.setdefault("SM_SCALE", scale) # Determine GQA factor - gqa_shared_heads = Hq // Hkv + gqa_shared_heads = FloorDiv(Hq, Hkv) kernel_options.setdefault("GQA_SHARED_HEADS", gqa_shared_heads) # Inside of Triton kernel, only apply partial masking if partial blocks are computed. diff --git a/torch/_inductor/kernel/flex/flex_cpu.py b/torch/_inductor/kernel/flex/flex_cpu.py index b16975f7113eb..820ace5b034e5 100644 --- a/torch/_inductor/kernel/flex/flex_cpu.py +++ b/torch/_inductor/kernel/flex/flex_cpu.py @@ -84,7 +84,7 @@ def lower_cpu( "torch.compile on current platform is not supported for CPU." ) - fake_buffers: list[Buffer] = [] # noqa: F821 + fake_buffers: list[Buffer] = [] # [Note] Handle the case where the split sizes are not statically known. # The value of cur_qSplitSize and cur_kvSplitSize are decided during runtime. diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py index 8401f24fcbac2..c069f2d674c9e 100644 --- a/torch/_inductor/kernel/flex/flex_decoding.py +++ b/torch/_inductor/kernel/flex/flex_decoding.py @@ -8,6 +8,7 @@ import torch from torch._inductor.virtualized import V +from torch.utils._sympy.functions import FloorDiv, Mod from ... import ir from ...ir import FixedLayout, FlexibleLayout @@ -69,7 +70,7 @@ def _use_flex_decoding(query, kv_indices, value, kernel_options, enable_gqa) -> Hq = query.get_size()[1] Hkv = value.get_size()[1] - ratio = Hq // Hkv + ratio = FloorDiv(Hq, Hkv) pw_of_two = V.graph.sizevars.guard_or_false( sympy.And(sympy.Gt(ratio, 0), sympy.Eq(ratio & (ratio - 1), 0)) @@ -86,7 +87,7 @@ def _use_flex_decoding(query, kv_indices, value, kernel_options, enable_gqa) -> and pw_of_two ) log.debug( - "Use flex decoding %s, force_flex_attention=%s, short_query_length=%s, static_batch=%s, static_num_heads=%s", # noqa: B950 + "Use flex decoding %s, force_flex_attention=%s, short_query_length=%s, static_batch=%s, static_num_heads=%s", out, force_flex, short_query_length, @@ -181,10 +182,10 @@ def create_flex_decoding_kernel(*args, **kwargs): } seq_q_divisible = V.graph.sizevars.statically_known_true( - sympy.Eq(seq_len_q % 128, 0) + sympy.Eq(Mod(seq_len_q, 128), 0) ) seq_kv_divisible = V.graph.sizevars.statically_known_true( - sympy.Eq(seq_len_kv % 128, 0) + sympy.Eq(Mod(seq_len_kv, 128), 0) ) if seq_q_divisible and seq_kv_divisible: kernel_options.setdefault("IS_DIVISIBLE", True) @@ -192,7 +193,7 @@ def create_flex_decoding_kernel(*args, **kwargs): kernel_options.setdefault("IS_DIVISIBLE", False) # Calculate GQA head sharing - gqa_shared_heads = Hq // Hkv + gqa_shared_heads = FloorDiv(Hq, Hkv) if not is_power_of_2(gqa_shared_heads): raise ValueError( "Number of shared query heads sharing the same KV head must be power of 2. " @@ -302,7 +303,7 @@ def create_flex_decoding_kernel(*args, **kwargs): kernel_options.setdefault( "SAFE_M_BOUNDARY", - ((seq_len_q * gqa_shared_heads) % kernel_options["BLOCK_M"]) == 0, + Mod(seq_len_q * gqa_shared_heads, kernel_options["BLOCK_M"]) == 0, ) # TODO: This feels sketchy kernel_options.setdefault("SAFE_N_BOUNDARY", True) diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 7f9a00547f5a5..13653d41e66d5 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -100,6 +100,14 @@ source=load_kernel_template("triton_persistent_tma_mm"), ) +# Non-TMA Triton template for persistent MM +# used on AMD +persistent_mm_template = TritonTemplate( + name="mm_persistent", + grid=persistent_mm_grid, + source=load_kernel_template("triton_persistent_mm"), +) + scaled_mm_device_tma_epilogue_scaling_template = TritonTemplate( name="scaled_mm_device_tma_epilogue_scaling", @@ -359,7 +367,6 @@ def _to_dtype(x): return ops.to_dtype(x, mat1.dtype, use_compute_types=False) args = [make_pointwise(_to_dtype)(x) for x in args] - mul_pointwise = make_pointwise(ops.dot)(*args) dot_reduction = make_reduction("dot")(mul_pointwise, 1) @@ -398,14 +405,7 @@ def _to_dtype(x): templates_to_use: list[ExternKernelChoice | KernelTemplate] = [] kwarg_overrides: dict[str, dict[str, Any]] = {} - - # Check if TLX force mode is enabled (fbcode only) - tlx_force_mode = ( - inductor_config.is_fbcode() and inductor_config.triton.tlx_mode == "force" - ) - - # Add ATEN kernels unless in TLX force mode (force mode uses only TLX) - if use_aten_gemm_kernels() and not tlx_force_mode: + if use_aten_gemm_kernels(): templates_to_use.append(aten_handler) if aten_extra_kwargs: kwarg_overrides[aten_handler.uid] = aten_extra_kwargs @@ -426,26 +426,20 @@ def _to_dtype(x): if is_exhaustive or not use_decompose_k_choice(m, n, k, threshold_multiple=2): templates_to_use.append(mm_template) - if use_triton_blackwell_tma_template(mat1, mat2, output_layout=layout): + if use_triton_blackwell_tma_template( + mat1, mat2, output_layout=layout, add_guards=True + ): templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template) - elif use_triton_tma_template(mat1, mat2, output_layout=layout): - templates_to_use.append(persistent_tma_mm_template) + elif use_triton_tma_template( + mat1, mat2, output_layout=layout, add_guards=True + ): + if torch.version.hip is None: + templates_to_use.append(persistent_tma_mm_template) + else: + templates_to_use.append(persistent_mm_template) templates_to_use.append(mm_contiguous_subgraph_template) - # TLX templates hook (fbcode only) - if inductor_config.is_fbcode(): - from torch._inductor.fb.tlx_templates.mm_templates import apply_tlx_templates - - templates_to_use = apply_tlx_templates( - templates_to_use, - m, - n, - k, - use_decompose_k_choice, - decompose_k_subgraph_template, - ) - choices.extend( V.choices.get_template_configs( kernel_inputs, @@ -639,6 +633,10 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): kernel_inputs = MMKernelInputs( [inp_expanded, mat1, mat2], scalars=dict(alpha=alpha, beta=beta) ) + kernel_inputs_aten = MMKernelInputs( + [inp, mat1, mat2], scalars=dict(alpha=alpha, beta=beta) + ) + choices: list[ChoiceCaller] = [] # below is for getting an overview logging info of inductor mms @@ -655,15 +653,9 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): if (not is_nonzero) or ( not (inductor_config.max_autotune or inductor_config.max_autotune_gemm) ): - # TODO(coconutruben): combine this with the main flow of addmm through - # a subgraph or something as inp vs inp_expanded causes some slight numeric - # differences - kernel_inputs = MMKernelInputs( - [inp, mat1, mat2], scalars=dict(alpha=alpha, beta=beta) - ) choices.extend( V.choices.get_template_configs( - kernel_inputs, + kernel_inputs_aten, [aten_addmm], name, ) @@ -673,23 +665,35 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): ) return node - # Collect all templates for unified call templates_to_use: list[ExternKernelChoice | KernelTemplate] = [] + if use_aten_gemm_kernels(): - templates_to_use.append(aten_addmm) + aten_templates: list[ExternKernelChoice | KernelTemplate] = [aten_addmm] if ( - inp_expanded.get_stride()[0] == 0 + inp.get_stride()[0] == 0 + and len(inp.get_size()) == 2 and inductor_config.triton.autotune_cublasLt + and not V.graph.cpp_wrapper # bias_addmm only has a Python implementation ): - templates_to_use.append(aten_bias_addmm) + aten_templates.append(aten_bias_addmm) + + # On ROCm, ATen choices use original bias input; non-ROCm keeps unified inputs. + choices.extend( + V.choices.get_template_configs(kernel_inputs_aten, aten_templates, name) + ) if is_nonzero and use_triton_template(layout, check_max_autotune=False): templates_to_use.append(mm_template) - if use_triton_blackwell_tma_template(mat1, mat2, output_layout=layout): + if use_triton_blackwell_tma_template( + mat1, mat2, output_layout=layout, add_guards=True + ): templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template) - elif use_triton_tma_template(mat1, mat2, output_layout=layout): - templates_to_use.append(persistent_tma_mm_template) + elif use_triton_tma_template(mat1, mat2, output_layout=layout, add_guards=True): + if torch.version.hip is None: + templates_to_use.append(persistent_tma_mm_template) + else: + templates_to_use.append(persistent_mm_template) templates_to_use.append(addmm_contiguous_subgraph_template) @@ -972,7 +976,10 @@ def tuned_scaled_mm( # TODO (paulzhan): There is no template that exists for bias and TMA # Don't run tma template currently if bias exist - if use_triton_tma_template(mat_a, mat_b, output_layout=layout) and not bias: + if ( + use_triton_tma_template(mat_a, mat_b, output_layout=layout, add_guards=True) + and not bias + ): overriders["SCALE_RECIPE_A"] = scale_option_a.value overriders["SCALE_RECIPE_B"] = scale_option_b.value @@ -1000,7 +1007,9 @@ def tuned_scaled_mm( ) if ( - use_triton_blackwell_tma_template(mat_a, mat_b, output_layout=layout) + use_triton_blackwell_tma_template( + mat_a, mat_b, output_layout=layout, add_guards=True + ) and not bias ): templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template) diff --git a/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja b/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja index b5d915129b580..f0a363725ae13 100644 --- a/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja +++ b/torch/_inductor/kernel/templates/triton_blackwell_ws_persistent_device_tma_mm.py.jinja @@ -43,17 +43,24 @@ ) offs_am = pid_m * BLOCK_M offs_bn = pid_n * BLOCK_N + offs_am_desc = offs_am.to(tl.int32) + offs_bn_desc = offs_bn.to(tl.int32) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for ki in range(k_tiles): offs_k = ki * BLOCK_K + offs_k_desc = offs_k.to(tl.int32) a = tl.load_tensor_descriptor( a_desc, - [offs_am, offs_k] if A_ROW_MAJOR else [offs_k, offs_am], + [offs_am_desc, offs_k_desc] + if A_ROW_MAJOR + else [offs_k_desc, offs_am_desc], ) b = tl.load_tensor_descriptor( b_desc, - [offs_k, offs_bn] if B_ROW_MAJOR else [offs_bn, offs_k], + [offs_k_desc, offs_bn_desc] + if B_ROW_MAJOR + else [offs_bn_desc, offs_k_desc], ) accumulator += tl.dot( a if A_ROW_MAJOR else a.T, diff --git a/torch/_inductor/kernel/templates/triton_bmm.py.jinja b/torch/_inductor/kernel/templates/triton_bmm.py.jinja index c7e3e574e8de7..94cec7f806301 100644 --- a/torch/_inductor/kernel/templates/triton_bmm.py.jinja +++ b/torch/_inductor/kernel/templates/triton_bmm.py.jinja @@ -2,6 +2,7 @@ M = {{size("A", -2)}} N = {{size("B", -1)}} K = {{size("A", -1)}} + BATCH = {{size("A", 0)}} stride_aq = {{stride("A", 0)}} stride_am = {{stride("A", 1)}} @@ -25,22 +26,26 @@ tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(INDEX_DTYPE) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(INDEX_DTYPE) + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) else: ram = rm % M - if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1): + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) else: rbn = rn % N - rk = tl.arange(0, BLOCK_K) + rk = tl.arange(0, BLOCK_K).to(INDEX_DTYPE) - idx_q = tl.program_id(1).to(INDEX_DTYPE) # batch dimension for BMM - A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq) - B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq) + # Reconstruct batch index from grid_y/grid_z split (handles batch > 65535) + idx_q = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)).to(INDEX_DTYPE) + # Clamp to valid range for safe pointer arithmetic; out-of-bounds CTAs are + # masked off at the store below. + idx_q_clamped = tl.minimum(idx_q, BATCH - 1) + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q_clamped*stride_aq) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q_clamped*stride_bq) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) for k in range(K, 0, -BLOCK_K): @@ -55,12 +60,14 @@ B += BLOCK_K * stride_bk # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - idx_q = tl.program_id(1).to(INDEX_DTYPE) # batch dimension for BMM + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(INDEX_DTYPE) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(INDEX_DTYPE) + idx_q = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)).to(INDEX_DTYPE) idx_m = rm[:, None] idx_n = rn[None, :] - mask = (idx_m < M) & (idx_n < N) + mask = (idx_m < M) & (idx_n < N) & (idx_q < BATCH) + # cast accumulator to output dtype + acc = acc.to(OUT_DTYPE) # inductor generates a suffix {{store_output(("idx_q", "idx_m", "idx_n"), "acc", "mask", val_shape=("BLOCK_M", "BLOCK_N"))}} diff --git a/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja b/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja index 56ef18b7a91e3..cd3a6bdd4913a 100644 --- a/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja +++ b/torch/_inductor/kernel/templates/triton_epilogue_scaled_mm.py.jinja @@ -98,17 +98,26 @@ offs_bn = pid_n * BLOCK_N offs_k = ki * BLOCK_K + offs_am_desc = offs_am.to(tl.int32) + offs_bn_desc = offs_bn.to(tl.int32) + offs_k_desc = offs_k.to(tl.int32) {%- if TMA_EXPERIMENTAL_API %} a = tl._experimental_descriptor_load( - a_desc_ptr, [offs_am, offs_k], [BLOCK_M, BLOCK_K], A.dtype.element_ty + a_desc_ptr, + [offs_am_desc, offs_k_desc], + [BLOCK_M, BLOCK_K], + A.dtype.element_ty, ) b = tl._experimental_descriptor_load( - b_desc_ptr, [offs_bn, offs_k], [BLOCK_N, BLOCK_K], B.dtype.element_ty + b_desc_ptr, + [offs_bn_desc, offs_k_desc], + [BLOCK_N, BLOCK_K], + B.dtype.element_ty, ) {%- else %} - a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) - b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + a = tl.load_tensor_descriptor(a_desc, [offs_am_desc, offs_k_desc]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn_desc, offs_k_desc]) {%- endif %} if USE_FAST_ACCUM: accumulator = tl.dot(a, b.T, accumulator) diff --git a/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja b/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja index 91f73e8742ac6..0c2e51a1ec2f2 100644 --- a/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja +++ b/torch/_inductor/kernel/templates/triton_main_loop_scaled_mm.py.jinja @@ -45,13 +45,16 @@ offs_am = pid_m * BLOCK_M offs_bn = pid_n * BLOCK_N + offs_am_desc = offs_am.to(tl.int32) + offs_bn_desc = offs_bn.to(tl.int32) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) for ki in range(0, k_tiles): offs_k = ki * BLOCK_K + offs_k_desc = offs_k.to(tl.int32) - a = tl.load_tensor_descriptor(a_desc, [offs_am, offs_k]) - b = tl.load_tensor_descriptor(b_desc, [offs_bn, offs_k]) + a = tl.load_tensor_descriptor(a_desc, [offs_am_desc, offs_k_desc]) + b = tl.load_tensor_descriptor(b_desc, [offs_bn_desc, offs_k_desc]) {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 scale_a_block = blockwise128x128_scaling( @@ -80,7 +83,7 @@ ) {%- endif %} - {%- if SCALE_RECIPE_A == 5 %} # ScalingType.Blockwise128x128 + {%- if SCALE_RECIPE_B == 5 %} # ScalingType.Blockwise128x128 scale_b_block = blockwise128x128_scaling( pid_n, B_inverse_scale, diff --git a/torch/_inductor/kernel/templates/triton_mm.py.jinja b/torch/_inductor/kernel/templates/triton_mm.py.jinja index 2da348f3e767c..fcb22c94eada1 100644 --- a/torch/_inductor/kernel/templates/triton_mm.py.jinja +++ b/torch/_inductor/kernel/templates/triton_mm.py.jinja @@ -24,8 +24,8 @@ tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(INDEX_DTYPE) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(INDEX_DTYPE) if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) else: @@ -34,7 +34,7 @@ offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) else: offs_b_n = rn % N - offs_k = tl.arange(0, BLOCK_K) + offs_k = tl.arange(0, BLOCK_K).to(INDEX_DTYPE) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) for k_idx in range(0, tl.cdiv(K, BLOCK_K)): @@ -62,8 +62,8 @@ {% endif %} # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(INDEX_DTYPE) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(INDEX_DTYPE) idx_m = rm[:, None] idx_n = rn[None, :] mask = (idx_m < M) & (idx_n < N) diff --git a/torch/_inductor/kernel/templates/triton_mm_grouped.py.jinja b/torch/_inductor/kernel/templates/triton_mm_grouped.py.jinja index 46c40d99aa502..52049582bf21d 100644 --- a/torch/_inductor/kernel/templates/triton_mm_grouped.py.jinja +++ b/torch/_inductor/kernel/templates/triton_mm_grouped.py.jinja @@ -294,10 +294,10 @@ def do_mma(a, b, accumulator): {%- endif %} accumulator = do_mma(a, b, accumulator) {%- else %} - offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) - offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) + offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M).to(INDEX_DTYPE) + offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N).to(INDEX_DTYPE) for k_block_offset in range(0, k_size, BLOCK_K): - block_offs_k = k_block_offset + tl.arange(0, BLOCK_K) + block_offs_k = k_block_offset + tl.arange(0, BLOCK_K).to(INDEX_DTYPE) offs_k = block_offs_k + k_start_offset a_ptrs = ( a_ptr @@ -328,8 +328,8 @@ def do_mma(a, b, accumulator): b_ptrs += BLOCK_K {%- endif %} - offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M) - offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N) + offs_am = tile_m_idx * BLOCK_M + tl.arange(0, BLOCK_M).to(INDEX_DTYPE) + offs_bn = tile_n_idx * BLOCK_N + tl.arange(0, BLOCK_N).to(INDEX_DTYPE) {%- if SCALED %} scale_a = tl.load( scale_a_ptr diff --git a/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja b/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja index 42b99c70d5cbd..6ad00438cf66d 100644 --- a/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja +++ b/torch/_inductor/kernel/templates/triton_mm_rocm.py.jinja @@ -24,8 +24,8 @@ tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(INDEX_DTYPE) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(INDEX_DTYPE) if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1): offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) else: @@ -34,7 +34,7 @@ offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) else: offs_b_n = rn % N - offs_k = tl.arange(0, BLOCK_K) + offs_k = tl.arange(0, BLOCK_K).to(INDEX_DTYPE) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) for k_idx in range(0, tl.cdiv(K, BLOCK_K)): @@ -61,8 +61,8 @@ {% endif %} # rematerialize rm and rn to save registers - rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M).to(INDEX_DTYPE) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(INDEX_DTYPE) idx_m = rm[:, None] idx_n = rn[None, :] mask = (idx_m < M) & (idx_n < N) diff --git a/torch/_inductor/kernel/templates/triton_persistent_mm.py.jinja b/torch/_inductor/kernel/templates/triton_persistent_mm.py.jinja new file mode 100644 index 0000000000000..7450533be5adf --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_persistent_mm.py.jinja @@ -0,0 +1,75 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # persistent kernel: each CTA processes multiple tiles + start_pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + num_tiles = grid_m * grid_n + width = GROUP_M * grid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS): + + # re-order program ID for better L2 performance + group_id = tile_id // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", + indent_width=12, index_shape=("BLOCK_M", "BLOCK_K"))}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", + indent_width=12, index_shape=("BLOCK_K", "BLOCK_N"))}} + + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=8, val_shape=("BLOCK_M", "BLOCK_N"))}} diff --git a/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja b/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja index b27ef29478d8b..cb6b4a1440812 100644 --- a/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja +++ b/torch/_inductor/kernel/templates/triton_persistent_tma_mm.py.jinja @@ -66,31 +66,34 @@ rm = pid_m * BLOCK_M rn = pid_n * BLOCK_N + rm_desc = rm.to(tl.int32) + rn_desc = rn.to(tl.int32) acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) for rk in tl.range(0, K, BLOCK_K): + rk_desc = rk.to(tl.int32) {%- if TMA_EXPERIMENTAL_API %} a = tl._experimental_descriptor_load( a_desc_ptr, - [rm, rk] if A_ROW_MAJOR else [rk, rm], + [rm_desc, rk_desc] if A_ROW_MAJOR else [rk_desc, rm_desc], [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M], A.dtype.element_ty, ) b = tl._experimental_descriptor_load( b_desc_ptr, - [rk, rn] if B_ROW_MAJOR else [rn, rk], + [rk_desc, rn_desc] if B_ROW_MAJOR else [rn_desc, rk_desc], [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K], B.dtype.element_ty, ) {%- else %} a = tl.load_tensor_descriptor( a_desc, - [rm, rk] if A_ROW_MAJOR else [rk, rm], + [rm_desc, rk_desc] if A_ROW_MAJOR else [rk_desc, rm_desc], ) b = tl.load_tensor_descriptor( b_desc, - [rk, rn] if B_ROW_MAJOR else [rn, rk], + [rk_desc, rn_desc] if B_ROW_MAJOR else [rn_desc, rk_desc], ) {%- endif %} if USE_FAST_ACCUM: diff --git a/torch/_inductor/kernel/vendored_templates/cutedsl/wrappers/__init__.py b/torch/_inductor/kernel/vendored_templates/cutedsl/wrappers/__init__.py index 583f1f04b7c97..3d56aeae8e586 100644 --- a/torch/_inductor/kernel/vendored_templates/cutedsl/wrappers/__init__.py +++ b/torch/_inductor/kernel/vendored_templates/cutedsl/wrappers/__init__.py @@ -1,2 +1,2 @@ # CuTeDSL Cutlass API registrations for PyTorch Inductor. -from . import dense_blockscaled_gemm_kernel # noqa: F401 +from . import dense_blockscaled_gemm_kernel diff --git a/torch/_inductor/lookup_table/choices.py b/torch/_inductor/lookup_table/choices.py index ca2bbc8bd2bf9..03fdab47e11e7 100644 --- a/torch/_inductor/lookup_table/choices.py +++ b/torch/_inductor/lookup_table/choices.py @@ -404,11 +404,14 @@ def _create_lookup_choices( continue # For each lookup config, create a KTC with the override kwargs + # Start from base params (which include derived options like OUT_DTYPE) + # and let the lookup config override them + base_kwargs = base_ktc.params.to_kwargs() for c in configs: + merged = {**base_kwargs, **c} lookup_ktc = KernelTemplateChoice( template=base_ktc.template, - # use the ones from the lookup table - params=DictKernelTemplateParams(c), + params=DictKernelTemplateParams(merged), extra_kwargs=base_ktc.extra_kwargs, layout=base_ktc.layout, inputs=base_ktc.inputs, diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 41f936661db9b..f65ab5e1f8ccc 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -13,6 +13,7 @@ import torch.fx from torch._dynamo.utils import identity from torch.fx.proxy import Scope, TracerBase +from torch.utils._sympy.functions import Mod from torch.utils._sympy.symbol import SymT from . import config, dependencies @@ -22,6 +23,7 @@ cache_on_self, reduction_num_outputs, sympy_index_symbol_with_prefix, + sympy_product, sympy_subs, ) from .virtualized import ops, V @@ -161,8 +163,10 @@ def _init_with_tracing(self, fn, args): self.memory_usage = {t: [] for t in MemoryUsageType} self.op_counts = collections.Counter() self.root_block = LoopBodyBlock(self, fn, args) # traces - self.has_partial_accumulate = self.root_block.graph.find_nodes( - op="call_method", target="partial_accumulate" + self.has_partial_accumulate = bool( + self.root_block.graph.find_nodes( + op="call_method", target="partial_accumulate" + ) ) del self.indexing_exprs_name # not used after _init_with_tracing @@ -268,7 +272,7 @@ def new_body(*indices: Sequence[sympy.Expr]) -> Any: reduce_idx = index[len(iter_size) :] new_iter_idx = list(iter_idx) - new_iter_idx[dimension] = iter_idx[dimension] % original_range + new_iter_idx[dimension] = Mod(iter_idx[dimension], original_range) return old_body(new_iter_idx, reduce_idx) @@ -286,6 +290,58 @@ def new_body(*indices: Sequence[sympy.Expr]) -> Any: ) return new_body + def reindex_iter_loops(self, new_iter_sizes: Sequence[sympy.Expr]) -> LoopBody: + """ + Reindex iteration loops into a different factorization of the same + total numel. For example, [1024, 8192] -> [65536, 128]. + + The old iteration vars are expressed as functions of the new vars via + FloorDiv and ModularIndexing on the flat index. + """ + from torch.utils._sympy.functions import ModularIndexing + + old_body = self + old_iter_sizes = self.sizes[0] + reduce_sizes = self.sizes[1] + + new_sizes = (list(new_iter_sizes), list(reduce_sizes)) + + (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( + *new_sizes, + prefix="t", # type: ignore[arg-type] + ) + + def new_body(*indices: Sequence[sympy.Expr]) -> Any: + index = [*itertools.chain.from_iterable(indices)] + new_iter_idx = index[: len(new_iter_sizes)] + reduce_idx = index[len(new_iter_sizes) :] + # Build flat index from new iter vars + flat = sympy.S.Zero + for v, s in zip(new_iter_idx, new_iter_sizes): + flat = flat * s + v + # Express old iter vars from flat index + old_iter_idx: list[sympy.Expr] = [] + for i, old_size in enumerate(old_iter_sizes): + tail = sympy_product(old_iter_sizes[i + 1 :]) + old_iter_idx.append(ModularIndexing(flat, tail, old_size)) + return old_body(old_iter_idx, list(reduce_idx)) + + loop_body = LoopBody( + new_body, (iter_vars, reduce_vars), var_ranges, iter_vars, reduce_vars + ) + + (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( + *new_sizes, + prefix="p", # type: ignore[arg-type] + ) + return LoopBody( + loop_body, + (iter_vars2, reduce_vars2), + var_ranges2, + iter_vars2, + reduce_vars2, + ) + def reorder_iter_loops(self, new_order) -> LoopBody: """ Reorder iteration loops and return a new LoopBody. @@ -532,7 +588,12 @@ def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]): from .index_propagation import IndexPropagation handler: Any = CountOps( - CaptureIndexing(proxy_ops, body, tracer), + CaptureIndexing( + # pyrefly: ignore[bad-argument-type] + proxy_ops, + body, + tracer, + ), body.op_counts, ) if config.constant_and_index_propagation: diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 449fc2f65db9b..6877f3c58f9d4 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -9,6 +9,7 @@ import math import operator import os +import sys import warnings from collections import defaultdict from collections.abc import Callable, Collection, Iterable, Sequence @@ -79,6 +80,7 @@ ) from .utils import ( ceildiv, + convert_symint_to_expr, decode_device, is_dynamic, is_gpu, @@ -957,6 +959,9 @@ def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False): def _get_primitive_bitwidth(dtype): if dtype.is_floating_point: return torch.finfo(dtype).bits + elif dtype == torch.bool: + # torch.iinfo doesn't support bool; bools are stored as uint8 (8 bits) + return 8 else: return torch.iinfo(dtype).bits @@ -1383,8 +1388,10 @@ def permute(x, dims): return TensorBox(PermuteView.create(x.data, tuple(dims))) +# Note: logic in this function need to be always synchronized with +# slice_forward in fake implementation. @register_lowering(aten.slice, type_promotion_kind=None) -def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True): +def slice_(x, dim=0, start=0, end=sys.maxsize, step=1, clamp=True): """ Lowers a slice call, creating ExternKernels for the output size & storage offset symbols, if the indices are unbacked and appropriate semantics aren't known. @@ -1421,30 +1428,78 @@ def compute_slice_index(index, size, default=None): fn = lambda x: V.graph.sizevars.guard_or_false(x) # noqa: E731 index = sympy.expand(index) size = sympy.expand(size) - if fn(sympy.Ge(index, 0)) and fn(sympy.Le(index, size)): + if fn(sympy.And(sympy.Ge(index, 0), sympy.Le(index, size))): return index - elif fn(sympy.Lt(index, 0)) and fn(sympy.Ge(index, -size)): + elif fn(sympy.And(sympy.Lt(index, 0), sympy.Ge(index, -size))): return index + size elif fn(sympy.Gt(index, size)): return size elif fn(sympy.Lt(index, -size)): return 0 + elif fn(sympy.Ge(index, 0)): + # If index >= 0, the resolved index is at most min(index, size). + return sympy.Min(index, size) + elif fn(sympy.Lt(index, 0)): + # If index < 0, wrap and clamp: the resolved index is at least 0. + return sympy.Max(index + size, 0) return None start_index, end_index = None, None + # ambiguous_slice=False means we know what semantics this slice call follows, + # and don't need to generate an extern kernel to represent the output size. + # This is assumed True for clamp=False + # (meant to follow standard indexing semantics: 0 <= index < size) ambiguous_slice = clamp if ambiguous_slice: start_index = compute_slice_index(start, size, 0) - end_index = compute_slice_index(end, size, size) + # Special case: if end is maxsize (unbounded), use size directly + # This matches the logic in fake_impls.py + if end is not None and V.graph.sizevars.statically_known_equals( + end, sys.maxsize + ): + end_index = size + else: + end_index = compute_slice_index(end, size, size) if start_index is not None and end_index is not None: start, end = start_index, end_index ambiguous_slice = False - # ambiguous_slice=False means we know what semantics this slice call follows, - # and don't need to generate an extern kernel to represent the output size. - # This is assumed True for clamp=False - # (meant to follow standard indexing semantics: 0 <= index < size) if not ambiguous_slice: + # Even though the bounds are resolvable now, the FX node may have + # allocated unbacked symbols for the slice output size because dynamo + # couldn't prove the bounds at trace time (constraints may have been + # learned after tracing the slice). We still need to define those + # symbols so the assertion new_unbacked_defs >= renamed_unbacked_bindings + # passes. Register a DynamicSliceSize operation to define the size symbol. + # Note: storage_offset bindings should not appear here because + # a resolved start_index means the offset is computable directly + # (base_offset + start * stride), so dynamo wouldn't allocate an + # unbacked symbol for it. + # Note: current_node may be None when slice_ is called from template + # rendering (e.g. cpp_template_kernel.slice_nd) rather than FX graph + # lowering, so we handle that. + current_node = V.graph.current_node + node_unbacked_bindings = resolve_unbacked_bindings( + V.graph.sizevars.shape_env, + current_node.meta.get("unbacked_bindings", {}) + if current_node is not None + else {}, + ) + if node_unbacked_bindings: + for sym, keypath in node_unbacked_bindings.items(): + if keypath == (CallMethodKey("size"), pytree.SequenceKey(dim)): + b_size = ir.DynamicSliceSize(sym, start, end, step, size) + b_size.name = V.graph.register_buffer(b_size) + V.graph.register_operation(b_size) + elif keypath == (CallMethodKey("storage_offset"),): + # Not handled yet — would require materializing the + # tensor layout. Unlikely to be hit because a resolved + # start_index means the offset is computable directly. + raise AssertionError( + "Unexpected storage_offset unbacked binding when both " + "start and end indices are resolved" + ) + return TensorBox( ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp) ) # go to SliceView/ReinterpretView @@ -1928,6 +1983,85 @@ def inner_fn(idx): ) +def _cat_inputs_recombine_reduction(inputs: list[TensorBox], dim: int) -> str | None: + """If all cat inputs share a common upstream reduction buffer whose + only consumers feed into this cat, return its name so it can be + excluded from the can_fuse_reduction check. + + Checks common reads for an IR reduction whose numel matches the cat + output, then verifies via FX origins that all of the reduction's + consumers feed into the cat inputs.""" + if len(inputs) < 2: + return None + + common_reads = inputs[0].get_read_names() + for inp in inputs[1:]: + common_reads = common_reads & inp.get_read_names() + if not common_reads: + return None + + # Find a common read that is an IR reduction buffer whose input + # numel matches the cat output numel. + cat_out_numel = convert_symint_to_expr(V.graph.current_node.meta["val"].numel()) + reduction_name = None + reduction_buf = None + for name in common_reads: + buf = V.graph.try_get_buffer(name) + if ( + buf is not None + and isinstance(buf, ir.ComputedBuffer) + and isinstance(buf.data, ir.Reduction) + ): + reduction_numel = sympy_product(buf.data.get_size()) * sympy_product( + buf.data.get_reduction_size() + ) + if V.graph.sizevars.statically_known_equals(cat_out_numel, reduction_numel): + reduction_name = name + reduction_buf = buf + break + + if reduction_name is None: + return None + + # Verify the reduction doesn't have consumers outside this cat's + # computation. Each IR node tracks which FX nodes produced it + # (origins). Collect the FX origins of all cat inputs, then check + # that every FX user of the reduction's origins feeds into one of + # the cat inputs. + # + # We also tried checking IR-level users via V.graph.name_to_users, + # but at lowering time the cat inputs are unrealized TensorBox + # wrappers (not named buffers), so name_to_users entries can't be + # correlated back to the cat's input chain. + # + # TODO: origins is a set of FX nodes attached to IR nodes during + # lowering — using it for correctness is fragile. A proper + # buffer→FX node mapping would be better. + origins = getattr(reduction_buf, "origins", None) + if not origins: + return None + + cat_input_origins: OrderedSet[torch.fx.Node] = OrderedSet() + for inp in inputs: + inp_origins = getattr(inp, "origins", None) + if inp_origins: + cat_input_origins.update(inp_origins) + + # Check that the reduction FX node's users all feed into the cat. + # origins may include non-reduction nodes (e.g. pow that feeds into + # mean), so filter to only reduction ops via torch.Tag.reduction. + for origin in origins: + if ( + origin.op == "call_function" + and isinstance(origin.target, torch._ops.OpOverload) + and torch.Tag.reduction in origin.target.tags + and not all(u in cat_input_origins for u in origin.users) + ): + return None + + return reduction_name + + @register_lowering(aten.cat) def cat(inputs, dim=0): """Lower aten.cat, choosing between pointwise_cat and ConcatKernel.""" @@ -1967,20 +2101,26 @@ def unwrap_tensor(x: TensorBox | ir.StorageBox) -> ir.IRNode: def is_reduction(t): return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction) - def can_fuse_reduction(t): + def can_fuse_reduction(t, exclude: OrderedSet[str] = OrderedSet()): if isinstance(t, (TensorBox, ir.StorageBox)): - return can_fuse_reduction(unwrap_tensor(t)) + return can_fuse_reduction(unwrap_tensor(t), exclude) return ( is_reduction(t) or isinstance(t, ir.Pointwise) and any( - can_fuse_reduction(V.graph.get_buffer(read)) + read not in exclude + and can_fuse_reduction(V.graph.get_buffer(read), exclude) for read in t.get_read_names() ) ) - # fusing reducutions into computed concat buffer can cause regressions. - fusable_reduction = any(can_fuse_reduction(t) for t in inputs) + # Pointwise cat evaluates every input's computation for each + # output element (masked), so fusing reductions in is wasteful. + # Exception: when inputs just recombine a reduction's output + # (e.g. qknorm → RoPE → cat), we do not duplicate computation + recombined = _cat_inputs_recombine_reduction(inputs, dim) + exclude: OrderedSet[str] = OrderedSet([recombined]) if recombined else OrderedSet() + fusable_reduction = any(can_fuse_reduction(t, exclude) for t in inputs) def should_lower_cat_input(x) -> bool: # Unrealized inputs will not be storage and layouts, and we dont want to realize @@ -2048,31 +2188,46 @@ def additional_pointwise_ops(op: torch._ops.OpOverload): # Skip pointwise_cat when any cat input has a fusible (pointwise) # multi-consumer — ConcatKernel + NonOwningLayout avoids redundant - # reads. Non-fusible users (e.g. matmul) don't benefit, so they - # should not prevent pointwise_cat. + # reads. Also skip when input is an unrealized Pointwise with + # multiple consumers to avoid recomputation (e.g. pad-as-cat). def any_input_has_multi_consumers() -> bool: - cat_node = V.current_node - if cat_node is None: + current_node = V.current_node + if current_node is None: return False - fx_args = cat_node.args[0] # aten.cat format: cat(input_list, dim) - if not isinstance(fx_args, (list, tuple)): + fx_args = current_node.args[0] + if isinstance(fx_args, (list, tuple)): + input_nodes = fx_args + elif isinstance(fx_args, torch.fx.Node): + input_nodes = [fx_args] + else: return False - def has_fusible_multi_consumer(arg): - if not hasattr(arg, "users") or len(arg.users) <= 1: - return False - return any(is_pointwise_use(u) for u in arg.users if u is not cat_node) + def is_unrealized_pointwise(x): + if isinstance(x, (TensorBox, ir.StorageBox)): + return is_unrealized_pointwise(unwrap_tensor(x)) + return isinstance(x, ir.Pointwise) - return any(has_fusible_multi_consumer(arg) for arg in fx_args) + for arg, ir_input in zip(input_nodes, inputs): + if not hasattr(arg, "users") or len(arg.users) <= 1: + continue + # input will be computed multiple times because other consumers + # (eg. pointwise) will also inline it. So we should realize-in-place via ConcatKernel + if any(is_pointwise_use(u) for u in arg.users if u is not current_node): + return True + # If input is an unrealized Pointwise with multiple consumers, pointwise_cat + # will inline input without realizing it to memory, causing separate + # realization cost for input. So we should realize-in-place via ConcatKernel + if is_unrealized_pointwise(ir_input): + return True + return False has_multi_consumers = any_input_has_multi_consumers() - horizontal_fuse_cat = all( - should_lower_cat_input(inp) for inp in inputs - ) and not any(can_fuse_reduction(t) for t in inputs) - if not has_multi_consumers and ( - fuse_pointwise_use or (horizontal_fuse_cat and not fusable_reduction) - ): + horizontal_fuse_cat = ( + all(should_lower_cat_input(inp) for inp in inputs) and not fusable_reduction + ) + + if not has_multi_consumers and (fuse_pointwise_use or horizontal_fuse_cat): return pointwise_cat(inputs, dim) return TensorBox(ir.ConcatKernel.create(inputs, dim)) @@ -2433,8 +2588,15 @@ def check_skip_condition(inp_out_node, is_output): return check_skip_condition(node, is_output=True) -def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False): - assert op not in decompositions or override_decomp, ( +def make_fallback( + op, + layout_constraint=None, + warn=True, + override_decomp=False, + get_decomp_fn=None, +): + check_decomps = get_decomp_fn() if get_decomp_fn is not None else decompositions + assert op not in check_decomps or override_decomp, ( f"both a fallback and a decomp for same op: {op}" ) if ( @@ -2589,10 +2751,15 @@ def warn_triton_random(): fallback_randn_default = fallback_handler(aten.randn.default) fallback_randn_generator = fallback_handler(aten.randn.generator) make_fallback(aten.randint) +make_fallback(aten.rand_like, override_decomp=True) +make_fallback(aten.randn_like, override_decomp=True) +make_fallback(aten.randint_like, override_decomp=True) # TODO: mlazos reevaluate if we want to codegen something different make_fallback(torch.ops.streams.record_event.default) make_fallback(torch.ops.streams.wait_event.default) +make_fallback(torch.ops.streams.synchronize_event.default) +make_fallback(torch.ops.streams.synchronize_device.default) @register_lowering(aten.rand) @@ -2645,8 +2812,35 @@ def inner_fn(_): ) +def get_threads_per_round(device: torch.device): + if not isinstance(device, torch.device): + device = torch.device(device) + + if device.type == "cuda": + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + + prop = torch.cuda.get_device_properties(idx) + threads_per_round = ( + prop.multi_processor_count * prop.max_threads_per_multi_processor + ) + else: + _CPU_GRAIN_SIZE = 32768 + threads_per_round = _CPU_GRAIN_SIZE + + return threads_per_round + + @register_lowering(inductor_prims.random, type_promotion_kind=None) -def inductor_random(size: list[int], seed: TensorBox, mode: str, *, offset: int = 0): +def inductor_random( + size: list[int], + seed: TensorBox, + mode: str, + *, + offset: int = 0, + align_dtype: torch.dtype = torch.float32, +): assert not config.fallback_random assert mode in ("rand", "randn") size = [*size] @@ -2657,11 +2851,33 @@ def inductor_random(size: list[int], seed: TensorBox, mode: str, *, offset: int ).make_indexer() seed_loader = seed.make_loader() - def inner_fn(index): - return getattr(ops, mode)( - seed_loader([]), - ops.index_expr(random_pos(index), torch.int32), - ) + if config.align_random_eager and device.type == "cuda": + threads_per_round = get_threads_per_round(device) + + def _vec_from_dtype(dt: torch.dtype) -> int: + if dt in (torch.float16, torch.bfloat16): + return 8 + return 4 + + vec = _vec_from_dtype(align_dtype) + + def inner_fn(index): + rng_seed = seed_loader([0]) + base_offset = seed_loader([1]) + return ops.rand_eager( + rng_seed, + base_offset, + threads_per_round, + ops.index_expr(random_pos(index), torch.int32), + vec=int(vec), + ) + else: + + def inner_fn(index): + return getattr(ops, mode)( + seed_loader([]), + ops.index_expr(random_pos(index), torch.int32), + ) result = Pointwise.create( device=device, @@ -2673,6 +2889,10 @@ def inner_fn(index): return result +make_fallback(inductor_prims.rand_eager_offset) +make_fallback(inductor_prims.rand_eager_offsets) + + @register_lowering(inductor_prims.randint, type_promotion_kind=None) def inductor_randint( low: int, high: int, size: list[int], seed: TensorBox, *, offset: int = 0 @@ -2881,16 +3101,24 @@ def inner_fn(index): return result +def _is_tensor_irnode(x): + return isinstance(x, ir.IRNode) and not isinstance( + x, (ir.NonTensorObj, ir.OpaqueMultiOutput) + ) + + def require_dense(_, *args, **kwargs): args, kwargs = pytree.tree_map_only( - ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs) + _is_tensor_irnode, ir.ExternKernel.require_stride1, (args, kwargs) ) return args, kwargs def require_contiguous(_, *args, **kwargs): args, kwargs = pytree.tree_map_only( - ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs) + _is_tensor_irnode, + ir.ExternKernel.require_contiguous, + (args, kwargs), ) return args, kwargs @@ -2899,14 +3127,18 @@ def require_contiguous_strides(_, *args, **kwargs): # TODO: combine this with require_contiguous after # https://github.com/pytorch/pytorch/pull/148235 lands. args, kwargs = pytree.tree_map_only( - ir.IRNode, ir.ExternKernel.require_contiguous_strides, (args, kwargs) + _is_tensor_irnode, + ir.ExternKernel.require_contiguous_strides, + (args, kwargs), ) return args, kwargs def require_channels_last(_, *args, **kwargs): args, kwargs = pytree.tree_map_only( - ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs) + _is_tensor_irnode, + ir.ExternKernel.require_channels_last, + (args, kwargs), ) return args, kwargs @@ -2938,9 +3170,12 @@ def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs): def constrain_to_fx_strides(fx_node, *args, **kwargs): def apply_constraint(arg, fx_arg): - if isinstance(arg, ir.IRNode): + if _is_tensor_irnode(arg): + fake_val = fx_arg.meta.get("val") + if not isinstance(fake_val, torch.Tensor): + return arg stride_order = ir.get_stride_order( - fx_arg.meta["val"].stride(), V.graph.sizevars.shape_env + fake_val.stride(), V.graph.sizevars.shape_env ) return ir.ExternKernel.require_stride_order(arg, stride_order) if isinstance(arg, dict): @@ -2964,7 +3199,7 @@ def sdpa_constraint(fx_node, *args, **kwargs): """Apply stride constraints to SDPA inputs, ensuring dense last dimension.""" def apply_constraint(idx, arg, fx_arg): - if not isinstance(arg, ir.IRNode): + if not _is_tensor_irnode(arg): return arg meta_val = fx_arg.meta["val"] @@ -3015,7 +3250,7 @@ def apply_constraint(idx, arg, fx_arg): return result def _apply_constraint_inner(idx, arg, meta_val, meta_stride_expr, stride_order): - if not meta_val.is_cuda: + if not (meta_val.is_cuda or meta_val.is_xpu): return ir.ExternKernel.require_stride_order(arg, stride_order) # This is the minimum alignment required by SDPA kernels for attention_bias. @@ -3074,12 +3309,14 @@ def _apply_constraint_inner(idx, arg, meta_val, meta_stride_expr, stride_order): # we can make them expanded by setting the stride equal to 0 if i in expanded_dims: if V.graph.sizevars.statically_known_equals( - out_strides[i + 1] % ALIGNMENT, 0 + Mod(out_strides[i + 1], ALIGNMENT), 0 ): out_strides[i] = 0 continue - if not V.graph.sizevars.statically_known_equals(stride % ALIGNMENT, 0): + if not V.graph.sizevars.statically_known_equals( + Mod(stride, ALIGNMENT), 0 + ): stride = ceildiv(stride, ALIGNMENT) * ALIGNMENT out_strides[i] = stride @@ -3208,12 +3445,6 @@ def is_aligned(x): # 5) Impossible (missing triton/CPU features) # Sorting / Sorting-like -make_fallback(aten.sort) -make_fallback(aten.sort.stable) -make_fallback(aten.kthvalue) -make_fallback(aten.topk) -make_fallback(aten.mode) -make_fallback(aten.median) make_fallback(aten.nanmedian) make_fallback(aten.randperm) # see: https://github.com/pytorch/pytorch/pull/121354 @@ -3422,6 +3653,30 @@ def fn(index): ) +@register_lowering(aten.arange.start_step, type_promotion_kind=None) +def arange_start_step( + start, + end, + step=1, + *, + dtype=None, + device=None, + layout=None, + pin_memory=None, + requires_grad=False, +): + assert dtype is not None + length = ceildiv(end - start, step) + return iota( + length, + start=start, + step=step, + dtype=dtype, + device=device if device is not None else "cpu", + requires_grad=requires_grad, + ) + + @register_lowering(aten.select_scatter, type_promotion_kind=None) def select_scatter(x, src, dim: int, index: int): src = to_dtype(src, x.get_dtype()) @@ -3533,7 +3788,7 @@ def _unwrap(x): return x -@register_lowering([torch.tensor, aten.scalar_tensor]) +@register_lowering([torch.tensor, aten.scalar_tensor, prims.scalar_tensor]) def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): assert_nyi(layout in (None, torch.strided), f"layout={layout}") assert_nyi(not pin_memory, "pin_memory") @@ -3715,7 +3970,7 @@ def inner( ): assert_nyi(names is None, "named tensors") assert_nyi(layout in (None, torch.strided), f"layout={layout}") - assert_nyi(not pin_memory, "pin_memory") + assert_nyi(not memory_format, "memory_format") device = decode_device(device) dtype = dtype or torch.get_default_dtype() if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)): @@ -3725,7 +3980,14 @@ def inner( for s in size: assert not isinstance(s, torch.SymInt) size = [sympy.expand(s) for s in size] - return _full(fill_value, device, dtype, size) + full_pointwise = _full(fill_value, decode_device(device), dtype, size) + + if pin_memory: + # Realize the buffer + full_pointwise.realize() + full_pointwise.data.data.get_layout().is_pinned = True + + return full_pointwise return inner @@ -4271,7 +4533,6 @@ def indice_slice_from_randperm(indice): None, check=check, ) - values = expand(values, expected_vals_size) # all guards are set above during broadcast_tensors and expand @@ -4761,6 +5022,59 @@ def _padding_can_be_fused(): return resized_x +def _pad_as_cat( + x: TensorBox, padding: Sequence[int], fill_value: float +) -> TensorBox | None: + """Decompose right-pad into cat([x, fill], dim) and delegate to cat lowering. + + The cat lowering already has heuristics for choosing between pointwise_cat + (fusion) and ConcatKernel (memory planning / zero-copy). By routing through + cat() we reuse those heuristics rather than duplicating them here. + """ + # Bail out for symbolic padding, dynamic shapes + if not all(isinstance(p, int) for p in padding): + return None + + sizes = x.get_size() + ndim = len(sizes) + pad_pairs = list(zip(padding[::2], padding[1::2])) + + # Only support single-dimension right-pad + pad_dim = None + pad_amount = None + for i, (left, right) in enumerate(pad_pairs): + if left != 0: + return None + if right > 0: + if pad_dim is not None: + return None # multi-dim pad + pad_dim = ndim - 1 - i # padding format is reversed dim order + pad_amount = right + elif right < 0: + return None # trim, not pad + + if pad_dim is None: + return None + + # CPU cat always uses ConcatKernel (no pointwise_cat), which adds + # extra kernel launches for the fill. Skip pad-as-cat on CPU. + device = x.get_device() + if device is not None and device.type == "cpu": + return None + + # Build the fill tensor for the padding region + pad_shape = list(sizes) + pad_shape[pad_dim] = pad_amount + dtype = x.get_dtype() + fill_value_typed = dtype_to_type(dtype)(fill_value) + pad_tensor = tensor_constructor(fill_value_typed)( + pad_shape, dtype=dtype, device=device + ) + + counters["inductor"]["pad_rewritten_as_cat"] += 1 + return cat([x, pad_tensor], pad_dim) + + @register_lowering(aten.constant_pad_nd, type_promotion_kind=None) def constant_pad_nd(x, padding, fill_value=0): assert (len(padding) % 2) == 0 @@ -4773,6 +5087,10 @@ def constant_pad_nd(x, padding, fill_value=0): return out # fall through if can not inplace the padding + out = _pad_as_cat(x, padding, fill_value) + if out is not None: + return out + sizes = x.get_size() bounds = list(reversed(list(zip(padding[::2], padding[1::2])))) @@ -4852,7 +5170,7 @@ def load(index): mask = functools.reduce( ops.and_, - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] [range_mask(ih[i], h[i] + padding_h[i], -padding_h[i]) for i in range(dim)], ) return ( @@ -5767,7 +6085,7 @@ def fn(idx): device=x.get_device(), dtype=x.get_dtype(), inner_fn=fn, - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] ranges=list(input_size), ) @@ -6765,6 +7083,11 @@ def truncdiv(a, b): return ops.truncdiv(a, b) +@make_pointwise +def _div_rn(a, b): + return ops.div_rn(a, b) + + @register_lowering(aten.div, broadcast=True) def div_mode(a, b, rounding_mode=None): both_integer = is_integer_type(a) and is_integer_type(b) @@ -6774,7 +7097,11 @@ def div_mode(a, b, rounding_mode=None): # see the discussion at https://github.com/triton-lang/triton/issues/605 if rounding_mode == "floor": assert not both_boolean, "floordiv operands can not be boolean at the same time" - return floordiv(a, b) if both_integer else floor(div(a, b)) + # Use div_rn (IEEE round-to-nearest) instead of truediv here because + # Triton's default division uses an approximate reciprocal, which can + # produce a result slightly below the true quotient and cause floor() + # to round down by one. + return floordiv(a, b) if both_integer else floor(_div_rn(a, b)) if rounding_mode == "trunc": assert not both_boolean, "truncdiv operands can not be boolean at the same time" return truncdiv(a, b) if both_integer else trunc(div(a, b)) @@ -7081,11 +7408,18 @@ def sort_stable(x, *, stable=None, dim=-1, descending=False): return clone(x), _full(0, device, torch.int64, shape) dim_size = shape[dim] if len(shape) else 1 - if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(torch.int16).max): + # Use int32 indices when decompose_sort_ops is enabled, allowing sort + # dimensions up to 2^31-1. Default int16 keeps register pressure low + # on GPU where the bitonic network holds all indices in-block. + if config.triton.decompose_sort_ops: + idx_dtype = torch.int32 + else: + idx_dtype = torch.int16 + if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(idx_dtype).max): return sort_fallback(x, stable=stable, dim=dim, descending=descending) indices = iota( - dim_size, start=0, step=1, dtype=torch.int16, device=device, requires_grad=False + dim_size, start=0, step=1, dtype=idx_dtype, device=device, requires_grad=False ) view_shape = [1] * len(shape) if len(shape): @@ -7114,6 +7448,199 @@ def sort(x, dim=-1, descending=False): return sort_stable(x, stable=False, dim=dim, descending=descending) +# Sort-based op lowerings +# When config.triton.decompose_sort_ops is enabled, decompose into sort-based +# ops so Inductor generates Triton kernels via ir.Sort. +# Otherwise, fall back to ATen eager. +topk_fallback = fallback_handler(aten.topk.default, add_to_fallback_set=False) +kthvalue_fallback = fallback_handler(aten.kthvalue.default, add_to_fallback_set=False) +median_fallback = fallback_handler(aten.median.default, add_to_fallback_set=False) +median_dim_fallback = fallback_handler(aten.median.dim, add_to_fallback_set=False) +mode_fallback = fallback_handler(aten.mode.default, add_to_fallback_set=False) + +# sort/sort.stable already have register_lowering above (sort_stable, sort). +# They use ir.Sort directly and fall back when the dimension is too large. +# When decompose_sort_ops is enabled, the size limit is lifted (int32 indices). + + +@register_lowering(aten.median.default, type_promotion_kind=None) +def median_default(self): + if not config.triton.decompose_sort_ops: + return median_fallback(self) + size = self.get_size() + numel = functools.reduce(operator.mul, size, sympy.Integer(1)) + flat = view(self, [numel]) + sorted_vals, _ = sort_stable(flat, dim=0) + k = (numel - 1) // 2 + return select(sorted_vals, 0, k) + + +@register_lowering(aten.median.dim, type_promotion_kind=None) +def median_dim(self, dim, keepdim=False): + if not config.triton.decompose_sort_ops: + return median_dim_fallback(self, dim, keepdim) + shape = self.get_size() + ndim = len(shape) + if ndim == 0: + return clone(self), _full(0, self.get_device(), torch.int64, shape) + dim = canonicalize_dim(ndim, dim) + sorted_vals, sorted_idxs = sort_stable(self, stable=True, dim=dim) + n = shape[dim] + k = (n - 1) // 2 + values = select(sorted_vals, dim, k) + indices = select(sorted_idxs, dim, k) + if keepdim: + values = unsqueeze(values, dim) + indices = unsqueeze(indices, dim) + return values, indices + + +@register_lowering(aten.mode.default, type_promotion_kind=None) +def mode_default(self, dim=-1, keepdim=False): + """Lower aten.mode via sort-based decomposition or fallback.""" + if not config.triton.decompose_sort_ops: + return mode_fallback(self, dim, keepdim) + shape = self.get_size() + ndim = len(shape) + device = self.get_device() + if ndim == 0: + return clone(self), _full(0, device, torch.int64, shape) + dim = canonicalize_dim(ndim, dim) + sorted_vals, sorted_idxs = sort_stable(self, stable=True, dim=dim) + n = shape[dim] + + # Position indices along dim: [0, 1, ..., n-1] + positions = iota( + n, start=0, step=1, dtype=torch.int64, device=device, requires_grad=False + ) + pos_view_shape = [sympy.Integer(1)] * ndim + pos_view_shape[dim] = n + positions = view(positions, pos_view_shape) + positions = expand(positions, shape) + + # Shift positions by -1, clamp to 0 for position 0 + positions_loader0 = positions.make_loader() + + def prev_pos_fn(idx): + return ops.maximum( + ops.sub(positions_loader0(idx), ops.constant(1, torch.int64)), + ops.constant(0, torch.int64), + ) + + prev_positions = Pointwise.create( + device=decode_device(device), + dtype=torch.int64, + inner_fn=prev_pos_fn, + ranges=shape, + ) + + # Gather shifted values and compare for run boundaries + shifted_vals = gather(sorted_vals, dim, prev_positions) + + sorted_loader = sorted_vals.make_loader() + shifted_loader = shifted_vals.make_loader() + positions_loader = positions.make_loader() + + # is_boundary = (sorted != shifted) | (position == 0) + def is_boundary_fn(idx): + return ops.or_( + ops.ne(sorted_loader(idx), shifted_loader(idx)), + ops.eq(positions_loader(idx), ops.constant(0, torch.int64)), + ) + + is_boundary = Pointwise.create( + device=decode_device(device), + dtype=torch.bool, + inner_fn=is_boundary_fn, + ranges=shape, + ) + + # boundary_pos = where(is_boundary, position, -1) + is_boundary_loader = is_boundary.make_loader() + positions_loader2 = positions.make_loader() + + def boundary_pos_fn(idx): + return ops.where( + is_boundary_loader(idx), + positions_loader2(idx), + ops.constant(-1, torch.int64), + ) + + boundary_pos = Pointwise.create( + device=decode_device(device), + dtype=torch.int64, + inner_fn=boundary_pos_fn, + ranges=shape, + ) + + # Propagate boundary positions forward with cummax + last_boundary, _ = cummax(boundary_pos, dim) + + # run_len = position - last_boundary + 1 + positions_loader3 = positions.make_loader() + last_boundary_loader = last_boundary.make_loader() + + def run_len_fn(idx): + return ops.add( + ops.sub(positions_loader3(idx), last_boundary_loader(idx)), + ops.constant(1, torch.int64), + ) + + run_len = Pointwise.create( + device=decode_device(device), + dtype=torch.int64, + inner_fn=run_len_fn, + ranges=shape, + ) + + # argmax returns first maximum -> end of leftmost longest run + max_pos = reduce_argmax(run_len, axis=dim, keepdims=True) + mode_vals = gather(sorted_vals, dim, max_pos) + mode_idxs = gather(sorted_idxs, dim, max_pos) + + if not keepdim: + mode_vals = squeeze(mode_vals, dim) + mode_idxs = squeeze(mode_idxs, dim) + + return mode_vals, mode_idxs + + +@register_lowering(aten.topk.default, type_promotion_kind=None) +def topk(self, k, dim=-1, largest=True, sorted=True): + if not config.triton.decompose_sort_ops: + return topk_fallback(self, k, dim, largest, sorted) + shape = self.get_size() + ndim = len(shape) + if ndim == 0: + return clone(self), _full(0, self.get_device(), torch.int64, shape) + dim = canonicalize_dim(ndim, dim) + sorted_vals, sorted_idxs = sort_stable( + self, stable=True, dim=dim, descending=largest + ) + values = slice_(sorted_vals, dim, 0, k) + indices = slice_(sorted_idxs, dim, 0, k) + return values, indices + + +@register_lowering(aten.kthvalue.default, type_promotion_kind=None) +def kthvalue(self, k, dim=-1, keepdim=False): + if not config.triton.decompose_sort_ops: + return kthvalue_fallback(self, k, dim, keepdim) + shape = self.get_size() + ndim = len(shape) + if ndim == 0: + return clone(self), _full(0, self.get_device(), torch.int64, shape) + dim = canonicalize_dim(ndim, dim) + sorted_vals, sorted_idxs = sort_stable(self, stable=True, dim=dim) + # k is 1-based + values = select(sorted_vals, dim, k - 1) + indices = select(sorted_idxs, dim, k - 1) + if keepdim: + values = unsqueeze(values, dim) + indices = unsqueeze(indices, dim) + return values, indices + + def register_pointwise_numeric(op, name=None, triton_fallback=None): return register_pointwise( op, @@ -7456,6 +7983,7 @@ def make_triton_fallback(op): register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum) register_foreach_pointwise(aten._foreach_reciprocal, reciprocal) register_foreach_pointwise(aten._foreach_sign, sign) +register_foreach_pointwise(aten._foreach_clone, clone) foreach_copy = register_foreach_pointwise(aten._foreach_copy, copy) @@ -7585,7 +8113,11 @@ def sym_numel(a): @register_lowering(torch.sym_sum) -def sym_sum(args): +def sym_sum(*args): + # sym_sum can be called as sym_sum([a, b]) or sym_sum(a, b). + # Normalize to a flat list before summing. + if len(args) == 1 and isinstance(args[0], (list, tuple)): + args = args[0] return sympy.Add(*args) @@ -7646,9 +8178,6 @@ def resize(x, size, *, memory_format=None): dtype = x.get_dtype() device = x.get_device_or_error() - if isinstance(x.data, ir.BaseView): - x.data = x.data.unwrap_view() - if ( torch.are_deterministic_algorithms_enabled() and torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined] @@ -7666,15 +8195,18 @@ def resize(x, size, *, memory_format=None): if V.graph.sizevars.statically_known_equals(old_numel, 0): # type: ignore[arg-type] return full(size, uninitialized_val, dtype=dtype, device=device) - x_flat = as_strided( - x, - [ - old_numel, - ], - [ - 1, - ], + strides = x.maybe_get_stride() + has_overlapping = strides is not None and any( + V.graph.sizevars.statically_known_equals(s, 0) for s in strides ) + if has_overlapping: + # overlapping: provide a contiguous logical view + x_flat = view(x, [old_numel]) + else: + # non-overlapping: keep storage order + if isinstance(x.data, ir.BaseView): + x.data = x.data.unwrap_view() + x_flat = as_strided(x, [old_numel], [1]) flat_loader = x_flat.make_loader() out_stride = ir.FlexibleLayout.stride_ordered_for_memory_format(size, memory_format) out_indexer = ir.FixedLayout(device, dtype, size, out_stride).make_indexer() @@ -7766,6 +8298,9 @@ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands): return list(map(TensorBox.create, result)) # type: ignore[call-overload] +_MISSING = object() + + def process_subgraph_nodes(graph_module: torch.fx.GraphModule, args: list[Any]): """Process nodes from a FX graph by executing them through V.graph. @@ -7775,7 +8310,7 @@ def process_subgraph_nodes(graph_module: torch.fx.GraphModule, args: list[Any]): - Other nodes are executed via V.graph.run_node """ - output = None + output = _MISSING for i, node in enumerate(graph_module.graph.nodes): if node.op == "placeholder": @@ -7795,7 +8330,7 @@ def process_subgraph_nodes(graph_module: torch.fx.GraphModule, args: list[Any]): finally: V.graph.current_node = saved_current_node - if output is None: + if output is _MISSING: raise RuntimeError("No output node found in graph") return output @@ -7834,7 +8369,7 @@ def control_deps_op_lowering(additional_deps, subgraph_fn, *args): # Process subgraph nodes using the shared helper output = process_subgraph_nodes(subgraph_fn.graph_module, list(args)) - assert output is not None and additional_deps + assert additional_deps # some operators, like wait_tensor, just return their input, # so its more robust to add dep to the operation itself, @@ -8126,6 +8661,43 @@ def cvt_e8m0_rceil_lowering(inp): return to_dtype(result, torch.uint8) +@register_lowering( + torch._higher_order_ops.inline_asm_elementwise, type_promotion_kind=None +) +def lower_inline_asm_elementwise( + *inputs, asm_str, constraints, dtype, is_pure=True, pack=1 +): + inputs = broadcast_tensors(*inputs) + + input_dtypes = tuple(inp.get_dtype() for inp in inputs) + loaders = [inp.make_loader() for inp in inputs] + + def inner_fn(idx): + vals = tuple(loader(idx) for loader in loaders) + result = ops.inline_asm_elementwise( + *vals, + asm=asm_str, + constraints=constraints, + dtype=dtype, + is_pure=is_pure, + pack=pack, + input_dtypes=input_dtypes, + ) + # Inductor computes in fp32 for bf16/fp16. Upcast so fused downstream + # ops (reductions, etc.) see fp32 values. The Pointwise's storage dtype + # handles the final downcast on store. + if dtype in (torch.float16, torch.bfloat16): + result = ops.to_dtype(result, torch.float32) + return result + + return ir.Pointwise.create( + device=inputs[0].get_device(), + dtype=dtype, + inner_fn=inner_fn, + ranges=list(inputs[0].get_size()), + ) + + # populate lowerings defined in kernel/* from . import kernel diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index 47c3eac2b7e58..52b3116321d1b 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -59,6 +59,7 @@ class CppOuterLoopFusedCount: parallel_reduction_count = 0 codegen_mix_order_reduction = 0 +rejected_mix_order_reduction_fusion = 0 # reset all counters @@ -74,6 +75,7 @@ def reset() -> None: global num_loop_reordering global parallel_reduction_count global codegen_mix_order_reduction + global rejected_mix_order_reduction_fusion global num_auto_chunking generated_kernel_count = 0 @@ -89,6 +91,7 @@ def reset() -> None: num_loop_reordering = 0 parallel_reduction_count = 0 codegen_mix_order_reduction = 0 + rejected_mix_order_reduction_fusion = 0 num_auto_chunking = 0 diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index e193fe425dfe6..564e659a479be 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -742,7 +742,7 @@ def qlinear_unary( # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer # Refer to - # https://github.com/pytorch/pytorch/blob/f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 # noqa: B950 + # https://github.com/pytorch/pytorch/blob/f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 if w_zp is None: # If w_zp is None, then it's a dummy tensor created to denote the # absence of a zero point, and thus w is int8 symmetrically quantized. @@ -1055,7 +1055,7 @@ def qlinear_binary( # When channels less than 8, w_scale/w_zp is Pointwise instead of ConstantBuffer # Refer to - # https://github.com/pytorch/pytorch/blob/f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 # noqa: B950 + # https://github.com/pytorch/pytorch/blob/f353d17755ed23b02924c962a86ff99a3405fe10/torch/_inductor/graph.py#L570-L577 w_scale.realize() w_zp.realize() if w_zp.get_dtype() != torch.int32 and isinstance( diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 2999c32eb916c..510924f214c37 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -98,6 +98,12 @@ def rand(self, seed: T, offset: T) -> T: """Computes inductor_prims.random with mode="rand". offset has dtype int32.""" raise NotImplementedError + def rand_eager( + self, seed: T, base_offset: T, threads_per_round: T, tid: T, vec: T + ) -> T: + """Computes inductor_prims.random with mode="rand_eager". offset has dtype int32.""" + raise NotImplementedError + def randn(self, seed: T, offset: T) -> T: """Computes inductor_prims.random with mode="randn". offset has dtype int32.""" raise NotImplementedError @@ -735,6 +741,7 @@ def inline_asm_elementwise( dtype: torch.dtype = torch.float32, is_pure: bool = True, pack: int = 1, + input_dtypes: tuple[torch.dtype, ...] | None = None, ) -> T: raise NotImplementedError @@ -846,6 +853,7 @@ def masked(mask, body, other) -> None: return None @staticmethod + # pyrefly: ignore [bad-override] def frexp(x) -> tuple[None, None]: return (None, None) @@ -858,6 +866,7 @@ def sort(dtypes, values, stable, descending) -> tuple[None, ...]: return (None,) * len(values) @staticmethod + # pyrefly: ignore [bad-override] def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol: return sympy.S.Zero @@ -955,6 +964,7 @@ def masked(mask, body, other) -> str: return f"ops.masked({mask}, {body()}, {other})" @staticmethod + # pyrefly: ignore [bad-override] def frexp(x): return (f"ops.frexp({x})[0]", f"ops.frexp({x})[1]") @@ -973,6 +983,7 @@ def sort(dtypes, values, stable, descending): ) @staticmethod + # pyrefly: ignore [bad-override] def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol: return sympy_index_symbol(str(index_var)) diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index c8331d98bbe89..6d1279d22883a 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -50,6 +50,7 @@ output_node, set_tracing_context_output_strides, ) +from torch._opaque_base import OpaqueBase from torch.fx._graph_pickler import _node_metadata_key_filter_safe, _ops_filter_safe from torch.utils._ordered_set import OrderedSet from torch.utils._python_dispatch import is_in_torch_dispatch_mode @@ -226,23 +227,24 @@ def cudagraph_post_compile( placeholders = cached_info.placeholders stack_traces = cached_info.stack_traces + assert stack_traces is not None, ( + "stack_traces should not be None in cudagraph_post_compile" + ) prepare_cudagraph_post_compile( compiled_graph, example_inputs, boxed_forward_device_index ) - from .compile_fx import cudagraphify - current_callable = compiled_graph.current_callable assert current_callable is not None # Filter to only tensor constants (exclude opaque value type classes) tensor_constants = { k: v for k, v in constants.items() if isinstance(v, torch.Tensor) } - compiled_graph.current_callable = cudagraphify( - current_callable, - static_input_idxs=static_input_idxs or (), - device_index=next(iter(compiled_graph.device_idxs)), + + device_index = next(iter(compiled_graph.device_idxs)) + cudagraphify_kwargs = dict( + device_index=device_index, stack_traces=stack_traces, is_backward=is_backward, is_inference=is_inference, @@ -251,6 +253,23 @@ def cudagraph_post_compile( mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs), ) + policy = config.cudagraph_policy + if policy is not None: + compiled_graph.current_callable = policy.cudagraphify( + current_callable, + example_inputs, + static_input_idxs or (), + **cudagraphify_kwargs, + ) + else: + from .compile_fx import cudagraphify + + compiled_graph.current_callable = cudagraphify( + current_callable, + static_input_idxs=static_input_idxs or (), + **cudagraphify_kwargs, + ) + else: BoxedBool.disable(cudagraphs) maybe_handle_backward_generation(compiled_graph, boxed_forward_device_index) @@ -303,8 +322,6 @@ def cudagraph_partition_post_compile( maybe_handle_backward_generation(compiled_graph, boxed_forward_device_index) return - from .compile_fx import cudagraphify - assert compiled_graph.current_callable is not None assert compiled_graph.recursively_apply_fns is not None is_inference = compiled_graph.fx_kwargs["is_inference"] @@ -318,6 +335,9 @@ def cudagraph_partition_post_compile( k: v for k, v in constants.items() if isinstance(v, torch.Tensor) } + assert compiled_graph.cudagraph_info.stack_traces is not None, ( + "stack_traces should not be None in cudagraph_partition_post_compile" + ) graph_metadata = CudagraphMetadata( compiled_graph.cudagraph_info.placeholders, static_input_idxs, @@ -330,6 +350,8 @@ def cudagraph_partition_post_compile( compiled_graph, example_inputs, boxed_forward_device_index ) + from .compile_fx import cudagraphify + # cudagraphify each partition function, assuming every graph partition function # is cudagraphable. Non-cudagraphable ops (e.g., cpu ops) are inlined into # `call` function and not included in partition functions. @@ -367,14 +389,28 @@ def maybe_realign_inputs( we didn't end up running cudagraphs. Mutates `compiled_graph.current_callable` if cudagraphs was run. Otherwise, does nothing. + + Non-mutated inputs are handled by deferred alignment copies + in the generated code. Only mutated inputs need the wrapper + for writeback. """ if not ran_cudagraphs: - assert compiled_graph.current_callable is not None - new_callable = align_inputs_from_check_idxs( - compiled_graph.current_callable, inputs_to_check, mutated_inputs_idxs - ) - if new_callable is not compiled_graph.current_callable: - compiled_graph.current_callable = new_callable + check_idxs = inputs_to_check + if compiled_graph._defers_input_alignment: + # Non-mutated inputs are handled by deferred alignment copies + # in the generated Python code. Only mutated inputs need the wrapper + # for writeback. Backends that don't emit deferred copies (cpp_wrapper, + # FXIR) need the full wrapper. + check_idxs = [i for i in inputs_to_check if i in mutated_inputs_idxs] + if check_idxs: + assert compiled_graph.current_callable is not None + new_callable = align_inputs_from_check_idxs( + compiled_graph.current_callable, + check_idxs, + mutated_inputs_idxs, + ) + if new_callable is not compiled_graph.current_callable: + compiled_graph.current_callable = new_callable class CompiledFxGraphConstants: @@ -462,12 +498,14 @@ class CompiledFxGraph(OutputCode): cudagraph_info: CudagraphCachedInfo | None partition_maps: list[GraphPartitionMap] | None + compile_region_name: str | None fx_kwargs: _CompileFxKwargs inputs_to_check: Sequence[int] _boxed_call: bool | None = None _triton_bundle: TritonBundle | None = None _wrap_compiled_regions: bool = False + _defers_input_alignment: bool = False # Metadata-stripped copy of the FX graph for fake tensor propagation. # Running this graph under FakeTensorMode re-derives output shapes # (including aliasing) from the input shapes. @@ -485,6 +523,7 @@ def __init__( cudagraphs: BoxedBool, example_inputs: Sequence[InputType], static_input_idxs: Sequence[int], + compile_region_name: str | None, fx_kwargs: _CompileFxKwargs, inputs_to_check: Sequence[int], runnable_graph_str: str, @@ -544,6 +583,7 @@ def __init__( self.extern_libs_key = None self.cudagraph_info = None self.partition_maps = graph.partition_maps + self._defers_input_alignment = getattr(graph, "_defers_input_alignment", False) self.fx_kwargs = {} self.inputs_to_check = () @@ -592,7 +632,15 @@ def __init__( (not complex_memory_overlap_inputs, "complex memory overlap"), ( all( - isinstance(t, (torch.Tensor, torch.SymInt, torch.Generator)) + isinstance( + t, + ( + torch.Tensor, + torch.SymInt, + torch.Generator, + OpaqueBase, + ), + ) for t in example_inputs ), "non-Tensor inputs", @@ -601,7 +649,10 @@ def __init__( output = output_node(gm) # output args are tuple of first argument assert len(output.args) == 1 - stack_traces = [ + # Use stack traces captured on the output node before + # post-grad passes, which may strip stack_trace from + # individual arg nodes. + stack_traces = output.meta.get("output_stack_traces") or [ (arg.stack_trace if isinstance(arg, torch.fx.node.Node) else None) for arg in output.args[0] # type: ignore[union-attr] ] @@ -612,6 +663,7 @@ def __init__( ) self.cudagraph_info = cudagraph_info + self.compile_region_name = compile_region_name self.inputs_to_check = inputs_to_check self.fx_kwargs = fx_kwargs @@ -712,6 +764,17 @@ def post_compile( assert graph_kwargs["is_backward"] is not None is_backward = graph_kwargs["is_backward"] cudagraphs: BoxedBool = graph_kwargs["cudagraphs"] + + # When a CUDAGraphPolicy is set and it says not to wrap this + # inner CompiledFxGraph (e.g. because wrapping happens at the + # outer level via policy.wrap_output), disable cudagraphs for + # this graph so the rest of post_compile (input realignment, + # _wrap_compiled_regions) still runs normally. + policy = config.cudagraph_policy + if policy is not None and not policy.should_wrap(self): + counters["inductor"]["cudagraph_skips"] += 1 + BoxedBool.disable(cudagraphs) + if cudagraphs: # It's possible that cudagraphs is enabled, but was disabled # during a previous compilation we're loading from the cache. @@ -737,9 +800,12 @@ def post_compile( "boxed_forward_device_index", None ) - if config.graph_partition: - # with graph_partition=True, we skip some cudagraph checks if it's supported - # with partition. So we have to use cudagraph_partition_post_compile. + if config.graph_partition and policy is None: + # With graph_partition=True, we skip some cudagraph checks + # if it's supported with partition, so we use + # cudagraph_partition_post_compile. When a CUDAGraphPolicy + # is active, we use cudagraph_post_compile instead so the + # policy controls wrapping via policy.cudagraphify(). cudagraph_partition_post_compile( example_inputs, self, @@ -773,12 +839,19 @@ def post_compile( original_callable = self.current_callable inductor_callable = InductorCompiledCallable( - original_callable, self._original_gm + original_callable, + self._original_gm, + compile_region_name=self.compile_region_name, ) def wrapped_callable(inputs): if is_in_torch_dispatch_mode(): - return inductor_compiled_code(inductor_callable, inputs) + kwargs = ( + {"name": self.compile_region_name} + if self.compile_region_name is not None + else {} + ) + return inductor_compiled_code(inductor_callable, inputs, **kwargs) else: return original_callable(inputs) @@ -893,6 +966,15 @@ def __post_init__(self): True, ).run # type: ignore[attr-defined] ) # type: ignore[attr-defined] + elif self.device_type.startswith("xpu"): + current_callable = ( + torch._C._aoti.AOTIModelContainerRunnerXpu( # type: ignore[call-arg] + current_callable, + 1, + self.device_type, + "", + ).run # type: ignore[attr-defined] + ) # type: ignore[attr-defined] elif self.device_type == "cpu": current_callable = ( torch._C._aoti.AOTIModelContainerRunnerCpu( # type: ignore[call-arg] diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 26af8b8bafe12..5b363cc1e50a2 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -90,6 +90,36 @@ backend = os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_BACKEND", "inductor") +_debug_nodes_cache: bool | OrderedSet[str] | None = None +_debug_nodes_env_value_cache: str | None = None + + +def _should_debug_node(node_name: str) -> bool: + def _get_debug_nodes() -> bool | OrderedSet[str]: + global _debug_nodes_cache, _debug_nodes_env_value_cache + + def parse_debug_env(env_value: str | None) -> bool | OrderedSet[str]: + if not env_value: + return False + if env_value == "all": + return True + return OrderedSet(env_value.split(",")) + + current_env = os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") + + # Recompute only if env changed + if current_env != _debug_nodes_env_value_cache or _debug_nodes_cache is None: + _debug_nodes_cache = parse_debug_env(current_env) + _debug_nodes_env_value_cache = current_env + + return _debug_nodes_cache + + debug_nodes = _get_debug_nodes() + if isinstance(debug_nodes, bool): + return debug_nodes + return node_name in debug_nodes + + class SearchFn(Protocol): __name__: str @@ -127,11 +157,9 @@ def _transfer_meta( ) -> None: from torch.fx.traceback import NodeSource, NodeSourceAction - # transfer metadata after pattern matching occurs. - # skip "val" and "tensor_meta" because this info is too specific; it's unlikely - # to remain accurate after pattern matching has occurred. + # Transfer metadata after pattern matching occurs. + # Copies _COPY_META_FIELDS, stack_trace, and (if missing) val/tensor_meta. if config.trace.provenance_tracking_level == 1: - # We handle "from_node" field of the node meta specially to record that the new node comes from the old_node. new_from_node = new_meta.get("from_node", []).copy() new_from_node.append(NodeSource(old_node, pass_name, NodeSourceAction.REPLACE)) new_meta.update( @@ -148,6 +176,13 @@ def _transfer_meta( ) if "stack_trace" in old_node.meta: new_meta["stack_trace"] = old_node.meta["stack_trace"] + # Copy val/tensor_meta only when the new node doesn't already have them + # (e.g. from tracing the replacement graph). Don't overwrite if present + # since the replacement's own val is more accurate. + if "val" not in new_meta and "val" in old_node.meta: + new_meta["val"] = old_node.meta["val"] + if "tensor_meta" not in new_meta and "tensor_meta" in old_node.meta: + new_meta["tensor_meta"] = old_node.meta["tensor_meta"] class Match: @@ -1124,7 +1159,7 @@ def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> Non handler = functools.wraps(self.handler)(functools.partial(self.handler, match)) with graph.inserting_before(node): replacement = graph.call_function(handler, tuple(match.args), match.kwargs) - replacement.meta.update(node.meta) + _transfer_meta(replacement.meta, node) node.replace_all_uses_with(replacement) assert match.nodes[-1] is node match.erase_nodes() @@ -1150,6 +1185,7 @@ class ReplacementPatternEntry(PatternEntry): """ normalize_args: Callable[..., list[Any]] + pattern_name: str | None = None # Unique identifier for per-pattern telemetry @staticmethod def replace_with_graph( @@ -1157,6 +1193,7 @@ def replace_with_graph( graph: torch.fx.Graph, replacement_graph: torch.fx.Graph | torch.fx.GraphModule, args: Sequence[torch.fx.Node], + pass_name: str | None = None, ) -> None: """ Inserts the replacement graph into the toplevel graph at the match @@ -1181,18 +1218,14 @@ def run_node(self, node: torch.fx.Node) -> Any: _transfer_meta( new_meta=result.meta, old_node=node, - pass_name="Interpreter_Replacer", + pass_name=pass_name or "", ) # This function copy-pastes the replacement graph into # the graph. If the replacement graph had any eager_input_vals, - # or val/tensor_meta, we propagate those over. + # we propagate those over (val/tensor_meta are handled by + # _transfer_meta above). if "eager_input_vals" in node.meta: result.meta["eager_input_vals"] = node.meta["eager_input_vals"] - if "val" in node.meta and "val" not in result.meta: - result.meta["val"] = node.meta["val"] - if isinstance(node.meta["val"], torch.Tensor): - assert "tensor_meta" in node.meta - result.meta["tensor_meta"] = node.meta["tensor_meta"] return result if node.op == "get_attr": # If the replacement graph contains a HOP, the subgraphs of the HOP are "get_attr" nodes. @@ -1298,8 +1331,7 @@ def filter_nodes_in_newly_added_nodes(node: torch.fx.Node) -> bool: graph.erase_node(old) return if isinstance(new, torch.fx.Node): - if "val" not in new.meta: - new.meta.update(old.meta) + _transfer_meta(new.meta, old, pass_name=pass_name or "") # Preserve the recompute tags in the replacement graph. We # look at the recompute tags of the original output node to @@ -1369,6 +1401,16 @@ def filter_nodes_in_newly_added_nodes(node: torch.fx.Node) -> bool: match.erase_nodes() + # Remove dead replacement nodes so they don't inflate user counts + # in later lowering heuristics (e.g. should_realize_on_reuse). + for node in reversed(added_replacement_nodes): + if ( + not node.users + and not node.is_impure() + and not isinstance(node.target, torch._ops.HigherOrderOperator) + ): + graph.erase_node(node) + def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> None: assert match.replacement_graph is not None self.replace_with_graph( @@ -1376,6 +1418,7 @@ def apply(self, match: Match, graph: torch.fx.Graph, node: torch.fx.Node) -> Non graph, match.replacement_graph, self.normalize_args(*match.args, **match.kwargs), + pass_name=self.pattern_name or "replace_with_graph", ) @@ -1436,7 +1479,7 @@ def check_and_add_duplicate_pattern( def register_replacement( search_fn: SearchFn, replace_fn: ReplaceFn, - example_inputs: Iterable[Any], + example_inputs: list[Any] | tuple[Any, ...], trace_fn: TraceFn, pass_dicts: _PassDictsType | Sequence[_PassDictsType], extra_check: Callable[[Match], bool] = _return_true, @@ -1444,6 +1487,8 @@ def register_replacement( exclusive_arg_names: Sequence[str] = (), search_fn_pattern: PatternExpr | None = None, skip_duplicates: bool = False, + pattern_name: str | None = None, + get_decomp_fn: Callable[..., dict[Any, Callable[..., Any]]] = select_decomp_table, ) -> bool: """ Create a replacement rule based on example functions that get traced @@ -1467,6 +1512,11 @@ def register_replacement( replace_argnames = [*inspect.signature(replace_fn).parameters.keys()] replace_fn = _wrap_bound_method(replace_fn, replace_argnames) + if not isinstance(example_inputs, (list, tuple)): + raise TypeError( + f"example_inputs must be a list or tuple, got {type(example_inputs)}" + ) + def check_fn(match: Match) -> bool: """ Often shapes get burned into the pattern, so our initial match ran with @@ -1530,12 +1580,15 @@ def check_fn(match: Match) -> bool: # Later, when we actually do the replacement, the symbolic shape # sizes will get re-traced and added to the graph. - def search_fn_new(*args_new: Any) -> Any: + def search_fn_new(*args_new: Any, **_: Any) -> Any: return search_fn(*args_new[len(args_new) - len(args) :]) try: - # pyrefly: ignore [bad-argument-type] - specific_graph = trace_fn(search_fn_new, sym_args + args) + specific_graph = trace_fn( + search_fn_new, + sym_args + args, + get_decomp_fn=get_decomp_fn, + ) except RuntimeError as e: log_trace_failure(search_fn, e) return False @@ -1561,7 +1614,9 @@ def search_fn_new(*args_new: Any) -> Any: argnames = sym_arg_names + argnames else: try: - specific_graph = trace_fn(search_fn, args) + specific_graph = trace_fn( + search_fn, args, get_decomp_fn=get_decomp_fn + ) except RuntimeError as e: log_trace_failure(search_fn, e) return False @@ -1577,7 +1632,7 @@ def search_fn_new(*args_new: Any) -> Any: assert node is not None specific_pattern_match = specific_pattern.match(node) - if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: + if _should_debug_node(node.name): log.warning( "Specific pattern match: %s%s %s %s", node, @@ -1588,7 +1643,9 @@ def search_fn_new(*args_new: Any) -> Any: if is_match(specific_pattern_match) and extra_check(specific_pattern_match): # trace the pattern using the shapes from the user program - match.replacement_graph = trace_fn(replace_fn, args) + match.replacement_graph = trace_fn( + replace_fn, args, get_decomp_fn=get_decomp_fn + ) if len(match.nodes) == 1: for n in match.replacement_graph.graph.nodes: _transfer_meta( @@ -1626,6 +1683,7 @@ def normalize_args(**kwargs: Any) -> list[Any]: trace_fn, scalar_workaround, exclusive_arg_names, + get_decomp_fn=get_decomp_fn, ) else: pattern = search_fn_pattern @@ -1647,6 +1705,7 @@ def normalize_args(**kwargs: Any) -> list[Any]: pattern=pattern, extra_check=check_fn, normalize_args=normalize_args, + pattern_name=pattern_name, ) pattern.register(pass_dicts) return pattern.pattern # type: ignore[return-value] @@ -1762,6 +1821,7 @@ def gen_register_replacement( scalar_workaround: dict[str, float | int] | None = None, exclusive_arg_names: Sequence[str] = (), skip_duplicates: bool = False, + get_decomp_fn: Callable[..., dict[Any, Callable[..., Any]]] = select_decomp_table, ) -> None: # Make sure the example_inputs is materialized. example_inputs = tuple(example_inputs) @@ -1804,6 +1864,8 @@ def gen_register_replacement( exclusive_arg_names, search_fn_pattern=pat, skip_duplicates=skip_duplicates, + pattern_name=unique_name, + get_decomp_fn=get_decomp_fn, ) @@ -1814,6 +1876,7 @@ def gen_pattern_and_search_gm( trace_fn: TraceFn, scalar_workaround: dict[str, float | int] | None = None, exclusive_arg_names: Sequence[str] = (), + get_decomp_fn: Callable[..., dict[Any, Callable[..., Any]]] = select_decomp_table, ) -> tuple[PatternExpr, torch.fx.GraphModule]: argnames = [*inspect.signature(search_fn).parameters.keys()] @@ -1829,7 +1892,7 @@ def gen_pattern_and_search_gm( flat_inputs.append(example_inputs[input_idx]) input_idx += 1 - search_gm = trace_fn(search_fn, flat_inputs) + search_gm = trace_fn(search_fn, flat_inputs, get_decomp_fn=get_decomp_fn) return ( fx_to_pattern( search_gm, @@ -1990,6 +2053,10 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: class PatternMatcherPass: + """ + Registry of patterns to match and replace in FX graphs. + """ + def __init__( self, pass_name: str | None = None, @@ -2012,6 +2079,7 @@ def __getitem__(self, item: tuple[str, torch.fx.node.Target]) -> list[PatternEnt return self.patterns[item] def apply(self, gm: torch.fx.GraphModule | torch.fx.Graph) -> int: + """Apply all registered patterns to the graph, returning the number of matches.""" if not self.patterns: return 0 if isinstance(gm, torch.fx.GraphModule): @@ -2066,7 +2134,19 @@ def apply(self, gm: torch.fx.GraphModule | torch.fx.Graph) -> int: != 1 ): continue - if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG") == node.name: + # pattern match crosses stream boundary - discard + if ( + is_match(m) + and len( + OrderedSet( + n.meta.get("custom", {}).get("stream", 0) + for n in m.nodes + ) + ) + != 1 + ): + continue + if _should_debug_node(node.name): log.warning("%s%s %s %s", node, node.args, m, entry.pattern) if is_match(m) and guard_or_false(entry.extra_check(m)): @@ -2074,6 +2154,19 @@ def apply(self, gm: torch.fx.GraphModule | torch.fx.Graph) -> int: entry.apply(m, graph, node) counters[backend]["pattern_matcher_count"] += 1 counters[backend]["pattern_matcher_nodes"] += len(m.nodes) + + # Track per-pattern counts when debug mode is active + if os.environ.get("TORCHINDUCTOR_PATTERN_MATCH_DEBUG"): + if getattr(entry, "pattern_name", None): + pattern_name = entry.pattern_name + else: + # Fallback: use pattern class name + operation target + pattern_class = entry.pattern.__class__.__name__ + target = str(node.target) + pattern_name = f"{pattern_class}_{target}" + + pattern_key = f"{backend}_pattern_matcher_per_pattern" + counters[pattern_key][pattern_name] += 1 return count def clear(self) -> None: @@ -2200,15 +2293,12 @@ def fwd_only( args: Sequence[Any], *, run_functional_passes: bool = True, - get_decomp_fn: Callable[..., Any] | None = None, + get_decomp_fn: Callable[..., Any] = select_decomp_table, ) -> torch.fx.GraphModule: """Build a normalized inference graph, for use with fx_to_pattern""" # TODO - look into using aot autograd, asserting no mutating ops here with enable_python_dispatcher(), preserve_node_meta(): - decompositions = ( - get_decomp_fn() if get_decomp_fn is not None else select_decomp_table() - ) - gm = make_fx(fn, decompositions, tracing_mode="real")(*args) + gm = make_fx(fn, get_decomp_fn(), tracing_mode="real")(*args) from .fx_passes.post_grad import remove_noop_ops @@ -2229,7 +2319,12 @@ def fwd_only( @torch.enable_grad() -def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.GraphModule: +def joint_fwd_bwd( + fn: Callable[..., Any], + args: Sequence[Any], + *, + get_decomp_fn: Callable[..., Any] = select_decomp_table, +) -> torch.fx.GraphModule: """Build a normalized training graph, for use with fx_to_pattern""" gm: torch.fx.GraphModule | None = None @@ -2247,7 +2342,7 @@ def record_joint_graph( # pyrefly: ignore[bad-argument-type] lambda gm, example_inputs: make_boxed_func(gm), partition_fn=record_joint_graph, - decompositions=select_decomp_table(), + decompositions=get_decomp_fn(), keep_inference_input_mutations=True, enable_log=False, )(*args) @@ -2310,16 +2405,29 @@ def stable_topological_sort(graph: torch.fx.Graph) -> None: assert not waiting and len(ready) == len(graph.nodes) -def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]: +def init_once_fakemode(fn: Callable[..., Any]) -> Callable[..., Any]: """Wrapper around lazy init functions in fx_passes/""" + _fn_params = inspect.signature(fn).parameters + @functools.cache @functools.wraps(fn) - def lazy_init(input_device: torch.device | None = None) -> Any: + def lazy_init( + input_device: Any | None = None, + get_decomp_fn: Callable[ + ..., dict[Any, Callable[..., Any]] + ] = select_decomp_table, + ) -> Any: counters_ref = counters[backend].copy() + kwargs: dict[str, Any] = {} + if "input_device" in _fn_params: + kwargs["input_device"] = input_device + if "get_decomp_fn" in _fn_params: + kwargs["get_decomp_fn"] = get_decomp_fn + with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode(): - result = fn(input_device) + result = fn(**kwargs) # clear view matches encountered during tracing counters[backend] = counters_ref diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index 5a162342faa52..a181d15b0baaa 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -26,7 +26,6 @@ from __future__ import annotations import dataclasses -import hashlib import logging import os import os.path @@ -43,6 +42,7 @@ ) from torch.utils._triton import has_triton +from ..cache_key import AUTOTUNE_CACHE_KEY_STRATEGY from ..remote_cache import ( create_cache, JsonDataTy, @@ -139,7 +139,7 @@ def _prepare_key(filename: str) -> str: # base of filename is already sha256 hash the source contents key = f"{os.path.basename(filename)}:{cconfig.cache_key_tag}" - return hashlib.sha256(key.encode("utf-8")).hexdigest() + return AUTOTUNE_CACHE_KEY_STRATEGY.key(key) # Read the best config options from the most local cache and return it. def _read(self) -> dict[str, JsonDataTy] | None: @@ -184,10 +184,7 @@ def _setup_local_cache( of changes to the best_config format or other code changes that are not backward compatible w.r.t. the cache. """ - hasher = hashlib.sha256() - hasher.update(cache_key.encode("utf-8")) - hasher.update(torch_key()) - updated_cache_key = hasher.hexdigest() + updated_cache_key = AUTOTUNE_CACHE_KEY_STRATEGY.key(cache_key, torch_key()) cache_filename = f"{dirname}/{updated_cache_key}.best_config" local_cache = LocalAutotuneCache() @@ -213,8 +210,9 @@ def _setup_remote_autotune_cache( salt = "autotune-best-config-v2" # re: torch_key - see [Note: torch_key in autotune cache key] - key = torch_key().hex() + backend_hash + self.configs_hash + salt - key = hashlib.sha256(key.encode("utf-8")).hexdigest() + key = AUTOTUNE_CACHE_KEY_STRATEGY.key( + torch_key().hex(), backend_hash, self.configs_hash, salt + ) remote_cache = create_cache( key, @@ -474,8 +472,7 @@ def begin_compile( # that info is basically present in the `code_hash` (since it's a # parameter to the pointwise decorator) - but is there other info we # need to include from inductor_meta? - key = code_hash + backend_hash + salt - key = hashlib.sha256(key.encode("utf-8")).hexdigest() + key = AUTOTUNE_CACHE_KEY_STRATEGY.key(code_hash, backend_hash, salt) bundler = _AutotuneCacheBundlerImpl(key, cache) if not bundler._load_cache(): diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index 27322454b5b19..2fa20649d32be 100644 --- a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -310,6 +310,7 @@ def __call__( ) -> Callable[_P, Callable[[_R], _EncodedR | DeferredRecording[_R, _EncodedR]]]: ... +# pyrefly: ignore [variance-mismatch] class ResultDecoderFactory(Protocol[_P, _R, _EncodedR]): """Protocol for custom result decoder factories. diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index 91736febd29f6..77dc4d4425461 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -75,6 +75,7 @@ def __init__( self.frozen_fields: OrderedSet[str] = ( OrderedSet(frozen_fields) if frozen_fields is not None else OrderedSet() ) + self._combo_tunable_fields: list[str] = [] def get_config_max(self, prefix: str) -> int: max_block = TRITON_MAX_BLOCK[prefix.upper()] @@ -136,9 +137,14 @@ def tunable_fields(self): # control the stage of pipelining of tl.range. out.append("NUM_STAGES") + out = self._combo_tunable_fields + out return [f for f in out if f not in self.frozen_fields] def value_too_large(self, name: str, val: int) -> bool: + field_limits = self.inductor_meta.get("combo_coordesc_field_limits") + if isinstance(field_limits, dict) and name in field_limits: + return val > field_limits[name] + block_suffix = "BLOCK" if name.endswith(block_suffix): prefix = name.strip(block_suffix).lower() @@ -279,7 +285,7 @@ def compare_config(self, func, candidate_config, best_config, best_timing): try: candidate_timing = self.call_func(func, candidate_config) except Exception as e: - log.debug("Got exception %s", e) # noqa: G200 + log.debug("Got exception %s", e) return False, float("inf") if self.has_improvement(best_timing, candidate_timing): @@ -302,6 +308,9 @@ def autotune( baseline_config: "triton.Config", baseline_timing: float | None = None, ) -> "triton.Config": # pyrefly: ignore # missing-attribute + """ + Perform coordinate descent autotuning starting from a baseline configuration. + """ if baseline_timing is None: baseline_timing = self.call_func(func, baseline_config) @@ -315,6 +324,11 @@ def autotune( improved = True best_config = baseline_config best_timing = baseline_timing + + self._combo_tunable_fields = self.inductor_meta.get( + "combo_coordesc_field_order", [] + ) + tunable_fields = self.tunable_fields while improved: diff --git a/torch/_inductor/runtime/halide_helpers.py b/torch/_inductor/runtime/halide_helpers.py index 3b02892eb2eb7..2b0161ead97ae 100644 --- a/torch/_inductor/runtime/halide_helpers.py +++ b/torch/_inductor/runtime/halide_helpers.py @@ -101,6 +101,40 @@ def rand(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT): return _uint_to_uniform_float(source) +def rand_eager_kernel(seed, offset_blocks, tid, VEC, n_rounds=PHILOX_N_ROUNDS_DEFAULT): + inv = hl.cast(hl.Float(32), 1.0 / 4294967296.0) # 2^-32 + half = hl.cast(hl.Float(32), 0.5) * inv + + tid_u64 = hl.cast(hl.UInt(64), tid) + VEC_u64 = hl.cast(hl.UInt(64), VEC) + subseq = tid_u64 // VEC_u64 + which4 = (tid_u64 % VEC_u64) // hl.cast(hl.UInt(64), 4) + lane = tid_u64 % hl.cast(hl.UInt(64), 4) + + offblk = hl.cast(hl.UInt(64), offset_blocks) + which4 + + c0 = hl.cast(hl.UInt(32), offblk & hl.cast(hl.UInt(64), 0xFFFFFFFF)) + c1 = hl.cast( + hl.UInt(32), + (offblk >> hl.cast(hl.UInt(64), 32)) & hl.cast(hl.UInt(64), 0xFFFFFFFF), + ) + c2 = hl.cast(hl.UInt(32), subseq & hl.cast(hl.UInt(64), 0xFFFFFFFF)) + c3 = hl.cast( + hl.UInt(32), + (subseq >> hl.cast(hl.UInt(64), 32)) & hl.cast(hl.UInt(64), 0xFFFFFFFF), + ) + + u0, u1, u2, u3 = halide_philox(seed, c0, c1, c2, c3, n_rounds) + + v01 = hl.select(lane == hl.cast(hl.UInt(64), 0), u0, u1) + v23 = hl.select(lane == hl.cast(hl.UInt(64), 2), u2, u3) + rand_int = hl.select( + (lane == hl.cast(hl.UInt(64), 0)) | (lane == hl.cast(hl.UInt(64), 1)), v01, v23 + ) + + return hl.cast(hl.Float(32), 1.0) - (hl.cast(hl.Float(32), rand_int) * inv + half) + + def randn(seed, offset): i1, i2, _, _ = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT) u1 = _uint_to_uniform_float(i1) diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index a9ddf91e9a59c..af91cdece3bd3 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -47,6 +47,7 @@ class TileHint(Enum): def AttrsDescriptorWrapper( divisible_by_16=None, equal_to_1=None, + pointer_range_32=None, ): # Prepare the arguments for AttrsDescriptor kwargs = { @@ -69,6 +70,7 @@ def AttrsDescriptorWrapper( def AttrsDescriptorWrapper( divisible_by_16=None, equal_to_1=None, + pointer_range_32=None, ): # Prepare the arguments for AttrsDescriptor kwargs = { @@ -88,17 +90,27 @@ def AttrsDescriptorWrapper( def AttrsDescriptorWrapper( divisible_by_16=None, equal_to_1=None, + pointer_range_32=None, ): # pyrefly: ignore [not-iterable] - return {(x,): [["tt.divisibility", 16]] for x in divisible_by_16} + # Build attr dict merging divisibility and pointer_range per arg index, + # since a single arg can carry both attributes. + result = {(x,): [["tt.divisibility", 16]] for x in (divisible_by_16 or ())} + for x in pointer_range_32 or (): + key = (x,) + if key in result: + result[key].append(["tt.pointer_range", 32]) + else: + result[key] = [["tt.pointer_range", 32]] + return result else: # Define a namedtuple as a fallback when AttrsDescriptor is not available AttrsDescriptorWrapper = collections.namedtuple( # type: ignore[no-redef, name-match] # pyrefly: ignore [invalid-argument] "AttrsDescriptor", - ["divisible_by_16", "equal_to_1"], - defaults=[(), ()], + ["divisible_by_16", "equal_to_1", "pointer_range_32"], + defaults=[(), (), ()], ) diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 61897910a59cd..632c39889f8e1 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -193,6 +193,27 @@ def compile_mps_shader(source: str) -> Any: raise SyntaxError(f"failed to compile {source} with {err.msg}") from err +def compile_mps_shaders( + kernels: list[tuple[str, str, list[str]]], +) -> dict[str, Any]: + """Compile a batch of Metal kernels into one library. + + Args: + kernels: list of (kernel_name, metal_source, headers) tuples. + headers are bare names resolved as . + + Returns: + dict mapping each kernel_name to its compiled function handle. + """ + from torch.utils._ordered_set import OrderedSet + + all_headers = sorted(OrderedSet(h for _, _, hs in kernels for h in hs)) + header_src = "\n".join(f"#include " for h in all_headers) + body_src = "\n".join(src for _, src, _ in kernels) + lib = compile_mps_shader(header_src + "\n" + body_src) + return {name: getattr(lib, name) for name, _, _ in kernels} + + def torch_dtype_to_jax_runtime(dtype: torch.dtype) -> Any: """ Map PyTorch dtype to actual JAX dtype object at runtime. @@ -385,6 +406,7 @@ def pallas_permute(x, perm): # Unrolled loop: extract slices with static indices, apply sub-perm, # then stack along the loop output dimension. slices = [] + # pyrefly: ignore [bad-argument-type] for i in range(loop_size): idx: list[Any] = [slice(None)] * ndim idx[loop_in_dim] = i @@ -697,7 +719,7 @@ def pallas_make_block_spec( if buf_nd == 0: # Scalar — untouched regardless of grid shape. - return pl.BlockSpec((), _make_index_map([], buf_nd, n_grid)) + return pl.BlockSpec((1,), _make_index_map([], 1, n_grid)) bs = list(buf_shape) tiled_pairs: list[tuple[int, int]] = [] @@ -786,3 +808,22 @@ def index_map(*grid_args): ) return index_map + + +def pallas_ensure_nonzero_rank(x: torch.Tensor) -> torch.Tensor: + if len(x.shape) == 0: + return x.reshape((1,)) + return x + + +def pallas_make_block_spec_non_tiled(shape: tuple[int, ...]) -> Any: + import jax.numpy as jnp # pyrefly: ignore [import-error, missing-import] + from jax.experimental import ( # pyrefly: ignore [import-error, missing-import] + pallas as pl, + ) + + nonzero_rank_shape = shape if len(shape) > 0 else (1,) + return pl.BlockSpec( + nonzero_rank_shape, + lambda i: [jnp.int32(i)] * len(nonzero_rank_shape), + ) diff --git a/torch/_inductor/runtime/triton_helpers.py b/torch/_inductor/runtime/triton_helpers.py index 7f14b089ab0b8..c38633b2842b9 100644 --- a/torch/_inductor/runtime/triton_helpers.py +++ b/torch/_inductor/runtime/triton_helpers.py @@ -5,7 +5,7 @@ from collections.abc import Callable from typing import Any, TypeVar -from .triton_compat import ( # noqa: F401 +from .triton_compat import ( _log2, builtins_use_semantic_kwarg, JITFunction, @@ -34,10 +34,27 @@ def set_driver_to_cpu(): ) +def _is_backend_active(name, backend): + if backend.driver.is_active(): + return True + # Triton may fail to detect the GPU in subprocess workers when using + # ctypes-based driver detection (triton-lang/triton#9578). Fall back + # to torch's own device checks which are more reliable in these environments. + if name == "nvidia": + import torch + + return torch.cuda.is_available() and torch.version.hip is None + if name == "amd": + import torch + + return torch.cuda.is_available() and torch.version.hip is not None + return False + + def set_driver_to_gpu(): driver = triton.runtime.driver for name, backend in triton.backends.backends.items(): - if backend.driver.is_active() and name != "cpu": + if _is_backend_active(name, backend) and name != "cpu": # After https://github.com/triton-lang/triton/commit/b844d519bc5e86edf00fe6b3c6c2d1badcd509a4, # `driver.active` can be of `LazyProxy` type and the sign of this - `_obj` attribute. if ( @@ -266,6 +283,34 @@ def device_assert_then(cond, msg, r): return r +@triton.jit +def rand_eager_kernel(seed, offset_blocks, tid: tl.tensor, VEC: tl.constexpr): + inv = 1.0 / 4294967296.0 + half = inv * 0.5 + + tid_u64 = tid.to(tl.uint64) + + subseq = tid_u64 // VEC + which4 = (tid_u64 % VEC) // 4 + lane = tid_u64 % 4 + + offblk = offset_blocks.to(tl.uint64) + which4 + + u0, u1, u2, u3 = tl.philox( + seed, + (offblk & 0xFFFFFFFF).to(tl.uint32), + ((offblk >> 32) & 0xFFFFFFFF).to(tl.uint32), + (subseq & 0xFFFFFFFF).to(tl.uint32), + ((subseq >> 32) & 0xFFFFFFFF).to(tl.uint32), + ) + + v01 = tl.where(lane == 0, u0, u1) + v23 = tl.where(lane == 2, u2, u3) + rand_int = tl.where((lane == 0) | (lane == 1), v01, v23) + + return 1.0 - (rand_int.to(tl.float32) * inv + half) + + @triton.jit def randint64(seed, offset, low, high): r0, r1, _r2, _r3 = tl.randint4x(seed, offset) @@ -779,3 +824,27 @@ def if_mask(mask: Any, val, *, _builder: object = None) -> tl.constexpr: if isinstance(mask, tl.constexpr) and mask.value is None: return tl.constexpr(None) return val + + +@triton.jit +def inline_asm_pack(x, pack: tl.constexpr): + """Ravel to 1D and pad (via join with zeros) so numel is divisible by pack.""" + result = x.ravel() + # Only pad when the block size is smaller than pack. When block >= pack + # the numel is already divisible by pack (both are powers of 2). + n_pad: tl.constexpr = _log2(pack) - _log2(result.numel) + for _ in tl.static_range(n_pad): + result = tl.reshape( + tl.join(result, tl.zeros_like(result)), (result.shape[0] * 2,) + ) + return result + + +@triton.jit +def inline_asm_unpack(x, orig, pack: tl.constexpr): + """Unpad and reshape back to orig's shape.""" + result = x + n_pad: tl.constexpr = _log2(pack) - _log2(orig.numel) + for _ in tl.static_range(n_pad): + result, _ = tl.split(tl.reshape(result, (result.shape[0] // 2, 2))) + return tl.reshape(result, orig.shape) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index bb601f233f5d6..5b82589c565a2 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -112,7 +112,9 @@ class NoTritonConfigsError(RuntimeError): if TYPE_CHECKING: from collections.abc import Callable, Container, Hashable + from torch._C._profiler import _RecordFunctionFast from torch._guards import CompileId + from torch.utils._debug_mode import _TritonKernelCall LauncherType = Any @@ -434,6 +436,9 @@ def __init__( self.triton_interpret = os.environ.get("TRITON_INTERPRET", "0") == "1" + self._debug_call: _TritonKernelCall | None = None + self._profiler_ctx: _RecordFunctionFast | None = None + # Compile-time info included in runtime logginging self.compile_id: CompileId | None = None self.is_backward = False @@ -759,12 +764,28 @@ def _create_compile_meta(self, cfg: Config) -> dict[str, Any]: compile_meta["num_warps"] = cfg.num_warps compile_meta["num_stages"] = cfg.num_stages - cfg_kwargs = cfg.kwargs + cfg_kwargs = {**cfg.kwargs} if self.device_props.type == "hip": - cfg_kwargs = {**cfg_kwargs} - for k in ("matrix_instr_nonkdim", "waves_per_eu", "kpack"): - if k in cfg_kwargs: - compile_meta[k] = cfg_kwargs.pop(k) + # `compile_meta["signature"]` contains the actual Triton kernel argument + # names, including constexprs such as XBLOCK_0/XBLOCK_1 for combo kernels. + # Any HIP config kwarg that is *not* in that signature is not a kernel + # argument at all; it is a backend compile option that should be forwarded + # to triton.compile via `options`, not materialized as a constexpr. + signature_arg_names = OrderedSet(compile_meta["signature"]) + backend_options = { + key: value + for key, value in cfg_kwargs.items() + if key not in signature_arg_names + } + cfg_kwargs = { + key: value + for key, value in cfg_kwargs.items() + if key in signature_arg_names + } + if backend_options: + # Stash backend-only options separately so they do not get mixed into + # `constants`, which are interpreted as signature-bound constexpr args. + compile_meta["backend_options"] = backend_options compile_meta["constants"].update(cfg_kwargs) for i in get_constexprs(self.fn): @@ -832,10 +853,9 @@ def _create_compile_options( if v := getattr(cfg, k, None): options[k] = v if self.device_props.type == "hip": - if "waves_per_eu" in compile_meta: - options["waves_per_eu"] = compile_meta["waves_per_eu"] - if "matrix_instr_nonkdim" in compile_meta: - options["matrix_instr_nonkdim"] = compile_meta["matrix_instr_nonkdim"] + # HIP backend options are consumed by Triton out-of-band from the kernel + # signature. They are intentionally *not* present in `constants`. + options.update(compile_meta.get("backend_options", {})) if self.device_props.type == "xpu" and XPU_KERNEL_FORMAT == "zebin": options["generate_native_code"] = True @@ -1057,13 +1077,22 @@ def copy_args_to_cpu_if_needed(self, *args, **kwargs): copies = {} try: - budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated() + if torch.accelerator.current_accelerator() is None: + # No initialized accelerator; skip memory-optimized path + return {} + budget = ( + torch.accelerator.max_memory_allocated() + - torch.accelerator.memory_allocated() + ) except RuntimeError: # Possibly a custom CUDA allocator, see https://github.com/pytorch/pytorch/issues/163257 return {} def maybe_copy(name, arg): - if name in self.mutated_arg_names and arg.is_cuda: + if name in self.mutated_arg_names and arg.device.type in ( + "cuda", + "xpu", + ): nonlocal budget assert isinstance(arg, torch.Tensor) required_storage_length = compute_required_storage_length( @@ -1275,6 +1304,8 @@ def autotune_to_one_config(self, *args, **kwargs): launcher.shared, ) + TritonBundler.put_winner(launcher.cache_hash) + if self.save_cache_hook: self.save_cache_hook( launcher.config, @@ -1285,6 +1316,157 @@ def autotune_to_one_config(self, *args, **kwargs): triton_cache_hash=launcher.cache_hash, ) + def _combo_sequential_autotune(self, launcher, *args, **kwargs): + """ + Chain block-size decisions for combo kernels: tune one group at a time, + each step building on the previous winner. + + Phase 1: Tune block sizes with warps/stages fixed from the base config. + Phase 2: Re-tune warps/stages with finalized block sizes. + """ + combo_tuning_groups = self.inductor_meta.get("combo_tuning_groups") + if not combo_tuning_groups: + return launcher + + if self.fn.fn is None: + assert hasattr(self, "_reload_kernel") + self.fn = self._reload_kernel().fn + + signature_keys = OrderedSet(self.triton_meta["signature"]) + best_config = launcher.config + current_kwargs = dict(best_config.kwargs) + base_num_warps = best_config.num_warps + base_num_stages = best_config.num_stages + + start_time = time.time_ns() + best_time = self.bench(launcher, *args, **kwargs) + counters["inductor"]["combo_autotune_bench"] += 1 + self.coordesc_tuner.cache_benchmark_result(launcher.config, best_time) + log.debug( + " Phase 1 baseline: %s warps=%d time=%f", + dict(current_kwargs), + base_num_warps, + best_time, + ) + + # Phase 1: Tune block sizes per sub-kernel (largest first). + # warps/stages stay fixed at base config values. + for gi, group in enumerate(combo_tuning_groups): + member_indices = group["member_indices"] + cfgs = group["configs"] + skip_rblock = group["skip_rblock"] + + if len(cfgs) <= 1: + log.debug(" Phase 1 group %d SK%s: 1 config, skip", gi, member_indices) + continue + + log.debug( + " Phase 1 group %d SK%s: trying %d configs, current_kwargs=%s", + gi, + member_indices, + len(cfgs), + dict(current_kwargs), + ) + for ci, cfg in enumerate(cfgs): + trial_kwargs = dict(current_kwargs) + for idx in member_indices: + _update_combo_kernel_kwargs( + trial_kwargs, cfg.kwargs, idx, skip_rblock, signature_keys + ) + + if trial_kwargs == current_kwargs: + log.debug(" cfg[%d] skip (same as current)", ci) + continue + + trial_config = triton.Config( + trial_kwargs, + num_warps=base_num_warps, + num_stages=base_num_stages, + ) + + with self.lock: + trial_launcher = self._precompile_config( + trial_config + ).make_launcher() + trial_time = self.bench(trial_launcher, *args, **kwargs) + counters["inductor"]["combo_autotune_bench"] += 1 + self.coordesc_tuner.cache_benchmark_result(trial_config, trial_time) + + improved = trial_time < best_time + log.debug( + " cfg[%d] trial=%s time=%f%s", + ci, + dict(trial_kwargs), + trial_time, + " (BETTER)" if improved else "", + ) + if improved: + best_time = trial_time + launcher = trial_launcher + current_kwargs = trial_kwargs + + log.debug( + " Phase 1 group %d winner: current_kwargs=%s", + gi, + dict(current_kwargs), + ) + + # Phase 2: Re-tune num_warps/num_stages with finalized block sizes. + # Block sizes are now optimal — find the best warp/stage pair for them. + warp_stage_candidates = self.inductor_meta.get("combo_warp_stage_candidates") + log.debug( + " Phase 2: blocks=%s, trying %d warp/stage pairs", + dict(current_kwargs), + len(warp_stage_candidates), + ) + best_warps = launcher.config.num_warps + best_stages = launcher.config.num_stages + for num_warps, num_stages in warp_stage_candidates: + if num_warps == best_warps and num_stages == best_stages: + log.debug( + " warps=%d stages=%d skip (same as current)", + num_warps, + num_stages, + ) + continue + + trial_config = triton.Config( + dict(current_kwargs), + num_warps=num_warps, + num_stages=num_stages, + ) + with self.lock: + trial_launcher = self._precompile_config(trial_config).make_launcher() + trial_time = self.bench(trial_launcher, *args, **kwargs) + counters["inductor"]["combo_autotune_bench"] += 1 + self.coordesc_tuner.cache_benchmark_result(trial_config, trial_time) + + improved = trial_time < best_time + log.debug( + " warps=%d stages=%d time=%f%s", + num_warps, + num_stages, + trial_time, + " (BETTER)" if improved else "", + ) + if improved: + best_time = trial_time + launcher = trial_launcher + best_warps = num_warps + best_stages = num_stages + + log.debug( + "Combo sequential autotune for %s: best config %s, time %f", + self.fn.__name__, + launcher.config, + best_time, + ) + launcher.config.found_by_combo_autotune = True + self.autotune_time_taken_ns += time.time_ns() - start_time + if self.save_cache_hook: + self.save_cache_hook(launcher.config, self.autotune_time_taken_ns) + return launcher + def save_gpu_kernel(self, stream, launcher): key = self.inductor_meta.get("kernel_name", None) # unique kernel name assert key is not None, "kernel_name can not be None" @@ -1314,14 +1496,6 @@ def save_gpu_kernel(self, stream, launcher): "global_scratch": launcher.global_scratch, "profile_scratch": launcher.profile_scratch, } - if self.device_props.type == "xpu": - # On the XPU backend, threads_per_warp is not always 32. - # For Intel GEMM Triton kernels, it can be 16. - # This information must be preserved so that the Cpp wrapper - # can launch the kernel with the correct configuration. - params["threads_per_warp"] = getattr( - launcher.bin.metadata, "threads_per_warp", 32 - ) from torch._inductor import config from torch._inductor.codecache import CudaKernelParamCache @@ -1462,11 +1636,14 @@ def benchmark_one_config(config): best_config ).make_launcher() + winner = config2launcher[best_config] + TritonBundler.put_winner(winner.cache_hash) + fn_hash = generate_lookup_hash_from_source_code( str(self.size_hints), self.fn.src ) log.debug("Function hash %s has best config %s", fn_hash, best_config) - return config2launcher[best_config] + return winner def get_profiler_kwargs(self, stream, launcher): kernel_kwargs_str = ",".join( @@ -1490,6 +1667,42 @@ def get_profiler_kwargs(self, stream, launcher): ret["kernel_num_gb"] = self.inductor_meta["kernel_num_gb"] return ret + def _pre_launch(self, launcher, *args, stream, **kwargs): + """Pre-launch instrumentation: param/tensor dumping and profiler context entry.""" + if self.dump_launch_params: + new_args, grid = self._interpret_args_grid(args, launcher.config) + _dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid) + + if self.dump_launch_tensors: + if not self.kernels_to_dump or any( + kernel_name in self.fn.__name__ for kernel_name in self.kernels_to_dump + ): + _dump_launch_tensors( + args, self.filename, self.kernel_hash, self.fn.__name__ + ) + + if autograd_profiler._is_profiler_enabled: + profiler_kwargs = self.get_profiler_kwargs(stream, launcher) + profiler_ctx = torch._C._profiler._RecordFunctionFast( + self.inductor_meta.get("kernel_name", "triton kernel"), + tuple(args), + profiler_kwargs, + ) + profiler_ctx.__enter__() + # set ctx after enter succeeds + self._profiler_ctx = profiler_ctx + else: + self._profiler_ctx = None + + def _post_launch(self) -> None: + """Post-launch instrumentation: profiler context exit and debug mode finalization.""" + if (profiler_ctx := self._profiler_ctx) is not None: + self._profiler_ctx = None + profiler_ctx.__exit__(None, None, None) + if (debug_call := self._debug_call) is not None: + self._debug_call = None + debug_call.finalize(self.get_device_interface()) + def run( self, *args, @@ -1499,12 +1712,11 @@ def run( ): # type:ignore[override] """Launch triton kernel call and return result.""" debug_mode = get_active_debug_mode() - debug_call = None if debug_mode: arg_names = list(self.triton_meta.get("signature", {}).keys()) kernel_kwargs = dict(zip(arg_names, args)) kernel_kwargs.update(kwargs) - debug_call = debug_mode.record_triton_kernel( + self._debug_call = debug_mode.record_triton_kernel( kernel_name=self.fn.__name__, kwargs=kernel_kwargs ) @@ -1533,6 +1745,23 @@ def alloc_fn(size: int, align: int, stream: int | None): if len(self.launchers) > 1: self.autotune_to_one_config(*args, **kwargs) + if self.inductor_meta.get("combo_tuning_groups") and not getattr( + self.launchers[0].config, "found_by_combo_autotune", False + ): + with dynamo_timed( + "CachingAutotuner.combo_sequential_autotune", + log_pt2_compile_event=False, + metadata={"kernel_name": self.inductor_meta.get("kernel_name")}, + dynamo_compile_column_us="runtime_triton_autotune_time_us", + compile_id=self.compile_id, + is_backward=self.is_backward, + log_waitcounter=True, + waitcounter_name_override="triton_autotuner", + ): + self.launchers = [ + self._combo_sequential_autotune(self.launchers[0], *args, **kwargs) + ] + if not getattr( self.launchers[0].config, "found_by_coordesc", False ) and self.inductor_meta.get("coordinate_descent_tuning", False): @@ -1541,51 +1770,19 @@ def alloc_fn(size: int, align: int, stream: int | None): ] (launcher,) = self.launchers + # Ensure the final launcher is marked as a winner for bundle filtering. + # For multi-config autotuning and coordesc, put_winner was already called + # (this is an idempotent set-add). For single-config kernels that skip + # autotuning entirely, this is the only call site that records the winner. + TritonBundler.put_winner(launcher.cache_hash) if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved): self.save_gpu_kernel(stream, launcher) - # PyTorch execution trace replay calls CachingAutotuner::run() instead of calls launcher - # so _RecordFunctionFast need to capture the args into CachingAutotuner::run() - # make a copy here to avoid mutating the original args - args_without_constexprs = tuple(args) - - if self.dump_launch_params: - new_args, grid = self._interpret_args_grid(args, launcher.config) - _dump_launch_params(new_args, kwargs, launcher, self.fn.__name__, grid) - - if self.dump_launch_tensors: - # Check the kernel name if the list was provided - if not self.kernels_to_dump or any( - kernel_name in self.fn.__name__ for kernel_name in self.kernels_to_dump - ): - _dump_launch_tensors( - args, self.filename, self.kernel_hash, self.fn.__name__ - ) - - # it is faster than entering and exiting a context manager, even if the context - # manager is a nullcontext. - if autograd_profiler._is_profiler_enabled: - profiler_kwargs = self.get_profiler_kwargs(stream, launcher) - - with torch._C._profiler._RecordFunctionFast( - self.inductor_meta.get("kernel_name", "triton kernel"), - args_without_constexprs, - profiler_kwargs, - ): - result = launcher( - *args, - **kwargs, - stream=stream, - ) - else: - result = launcher( - *args, - **kwargs, - stream=stream, - ) - - if debug_call: - debug_call.finalize(self.get_device_interface()) + try: + self._pre_launch(launcher, *args, stream=stream, **kwargs) + result = launcher(*args, **kwargs, stream=stream) + finally: + self._post_launch() return result def _interpret_args_grid( @@ -1839,7 +2036,7 @@ def check_can_launch() -> _KernelType: result = check_can_launch() return result except CannotStaticallyLaunchKernel as e: - log.info("Bypassing StaticallyLaunchedCudaKernel due to %s", str(e)) # noqa: G200 + log.info("Bypassing StaticallyLaunchedCudaKernel due to %s", e) if torch._inductor.config.strict_static_triton_launcher: raise e return None @@ -2489,6 +2686,7 @@ def triton_config( num_warps=None, matrix_instr=None, waves_per_eu=None, + kpack=None, ) -> Config: """ Construct a pointwise triton config with some adjustment heuristics @@ -2586,6 +2784,8 @@ def triton_config( config.kwargs["matrix_instr_nonkdim"] = matrix_instr if waves_per_eu is not None: config.kwargs["waves_per_eu"] = waves_per_eu + if kpack is not None: + config.kwargs["kpack"] = kpack return config @@ -2716,6 +2916,48 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]: return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()} +def _combo_tiling_signature( + tiling_scores: dict[str, Any] | None, +) -> tuple[tuple[str, float], ...] | None: + """ + Build a grouping signature from tiling scores. + + Normalize scores so proportional patterns (e.g. {x: 8, y: 1} vs {x: 16, y: 2}) + end up in the same group, while kernels with different coalescing preference do not. + """ + if not tiling_scores: + return None + + total = sum(float(score) for score in tiling_scores.values()) + if total == 0: + return tuple(sorted((dim, 0.0) for dim in tiling_scores)) + + return tuple( + sorted( + (dim, round(float(score) / total, 2)) + for dim, score in tiling_scores.items() + ) + ) + + +def _update_combo_kernel_kwargs( + kwargs: dict[str, Any], + cfg_kwargs: dict[str, Any], + subkernel_idx: int, + skip_rblock: bool, + signature_keys: OrderedSet[str], +) -> None: + for key, value in cfg_kwargs.items(): + if skip_rblock and key.startswith("R") and "BLOCK" in key: + continue + suffixed_key = f"{key}_{subkernel_idx}" + # Only suffix keys that actually exist in the combo kernel signature. + # Signature keys are real per-subkernel constexpr args such as XBLOCK_0. + # Everything else must stay unsuffixed so HIP-specific compile options like + # waves_per_eu continue to flow through the backend-options path above. + kwargs[suffixed_key if suffixed_key in signature_keys else key] = value + + def _handle_combo_kernel_per_subkernel_blocks( size_hints: dict[str, int], inductor_meta: dict[str, Any], @@ -2729,8 +2971,12 @@ def _handle_combo_kernel_per_subkernel_blocks( Handle per-subkernel config generation for combo kernels. Each sub-kernel gets its own block sizes (XBLOCK_0, XBLOCK_1, etc.) generated - using the same heuristics as standalone Triton kernels. The final config uses - the maximum num_warps and num_stages across all sub-kernels. + using the same heuristics as standalone Triton kernels. + + Returns base configs that vary (num_warps, num_stages) with all blocks at + heuristic defaults. Stores per-subkernel candidate configs in + inductor_meta["combo_tuning_groups"] for sequential chained autotuning + in CachingAutotuner._combo_sequential_autotune(). Returns: List of configs if combo kernel with combo_grid_meta and per-subkernel @@ -2749,13 +2995,22 @@ def _handle_combo_kernel_per_subkernel_blocks( all_num_warps: list[int] = [] all_num_stages: list[int] = [] unique_warp_stage_pairs: OrderedSet[tuple[int, int]] = OrderedSet() + combo_coordesc_field_limits: dict[str, int] = {} + signature_keys = OrderedSet(triton_meta.get("signature", ())) + + # Group sub-kernels with identical config kwargs to skip redundant tuning. + group_map: dict[tuple[Any, ...], dict[str, Any]] = {} for i in range(num_kernels): subkernel_heuristic = combo_meta[f"heuristic_{i}"] size_hints_i = combo_meta[f"size_hints_{i}"] + tiling_scores_i = combo_meta.get(f"tiling_scores_{i}") + inductor_meta_i = dict(inductor_meta_clean) + if tiling_scores_i is not None: + inductor_meta_i["tiling_scores"] = tiling_scores_i if subkernel_heuristic == "pointwise": - cfg = pointwise( + cfgs = pointwise( size_hints_i, triton_meta=triton_meta, tile_hint=TileHint.SQUARE @@ -2763,51 +3018,105 @@ def _handle_combo_kernel_per_subkernel_blocks( else TileHint.DEFAULT, filename=filename, min_elem_per_thread=min_elem_per_thread, - inductor_meta=inductor_meta_clean, + inductor_meta=inductor_meta_i, return_configs=True, - )[0] + ) skip_rblock = False elif subkernel_heuristic == "reduction": - cfg = reduction( + cfgs = reduction( size_hints_i, - reduction_hint=reduction_hint, + reduction_hint=ReductionHint[combo_meta[f"reduction_hint_{i}"]], triton_meta=triton_meta, filename=filename, - inductor_meta=inductor_meta_clean, + inductor_meta=inductor_meta_i, return_configs=True, - )[0] + ) skip_rblock = False elif subkernel_heuristic == "persistent_reduction": - cfg = persistent_reduction( + cfgs = persistent_reduction( size_hints_i, - reduction_hint=reduction_hint, + reduction_hint=ReductionHint[combo_meta[f"reduction_hint_{i}"]], triton_meta=triton_meta, filename=filename, - inductor_meta=inductor_meta_clean, + inductor_meta=inductor_meta_i, return_configs=True, - )[0] + ) skip_rblock = True # persistent reduction embeds RBLOCK in kernel body else: raise ValueError(f"Unknown heuristic: {subkernel_heuristic}") - for key, value in cfg.kwargs.items(): + group_coordesc_fields: OrderedSet[str] = OrderedSet() + cfg = cfgs[0] + _update_combo_kernel_kwargs( + combined_kwargs, cfg.kwargs, i, skip_rblock, signature_keys + ) + for key in cfg.kwargs: if skip_rblock and key.startswith("R") and "BLOCK" in key: continue - combined_kwargs[f"{key}_{i}"] = value + if not key.endswith("BLOCK"): + continue + combined_key = f"{key}_{i}" + group_coordesc_fields.add(combined_key) + prefix = key.removesuffix("BLOCK").lower() + if prefix in size_hints_i: + combo_coordesc_field_limits[combined_key] = min( + TRITON_MAX_BLOCK[prefix.upper()], + size_hints_i[prefix], + ) all_num_warps.append(cfg.num_warps) all_num_stages.append(cfg.num_stages) - unique_warp_stage_pairs.add((cfg.num_warps, cfg.num_stages)) + for c in cfgs: + unique_warp_stage_pairs.add((c.num_warps, c.num_stages)) + + cfg_key = tuple(item for c in cfgs for item in sorted(c.kwargs.items())) + group_key = ( + ( + subkernel_heuristic, + skip_rblock, + cfg_key, + _combo_tiling_signature(tiling_scores_i), + ) + if torch._inductor.config.combo_kernel_autotune_grouping + else (i,) + ) + if group_key in group_map: + group_map[group_key]["member_indices"].append(i) + else: + group_map[group_key] = { + "member_indices": [i], + "configs": cfgs, + "skip_rblock": skip_rblock, + "size_hints": size_hints_i, + "coordesc_fields": list(group_coordesc_fields), + } unique_warp_stage_pairs.add((max(all_num_warps), max(all_num_stages))) + combo_tuning_groups = list(group_map.values()) + # Largest sub-kernels tuned first — they dominate runtime and get most freedom + combo_tuning_groups.sort( + key=lambda g: -functools.reduce(operator.mul, g["size_hints"].values()) + ) + inductor_meta["combo_tuning_groups"] = combo_tuning_groups + inductor_meta["combo_coordesc_field_order"] = [ + field for group in combo_tuning_groups for field in group["coordesc_fields"] + ] + inductor_meta["combo_coordesc_field_limits"] = combo_coordesc_field_limits + # Candidates for num_warps/num_stages re-tuning after block sizes are finalized + inductor_meta["combo_warp_stage_candidates"] = list(unique_warp_stage_pairs) + + # Single base config: max warps/stages, all blocks at heuristic defaults. + # Block sizes are tuned in _combo_sequential_autotune, then num_warps/num_stages + # are re-tuned at the end with finalized block sizes. + base_num_warps = max(all_num_warps) + base_num_stages = max(all_num_stages) return [ triton.Config( combined_kwargs, - num_warps=num_warps, - num_stages=num_stages, + num_warps=base_num_warps, + num_stages=base_num_stages, ) - for num_warps, num_stages in unique_warp_stage_pairs ] @@ -2966,25 +3275,24 @@ def pointwise( ] # Additional configs appended for ROCm builds if torch.version.hip: - if inductor_meta.get("max_autotune_pointwise"): - configs.extend( - [ - triton_config_with_settings( - size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 - ), - triton_config_with_settings( - size_hints, - 4096, # wrt: better than the max_block for some kernel - ), - triton_config_with_settings( - size_hints, - 2048, - num_warps=8, - num_stages=2, - waves_per_eu=1, # 20% improvement - ), - ] - ) + configs.extend( + [ + triton_config_with_settings( + size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2 + ), + triton_config_with_settings( + size_hints, + 4096, # wrt: better than the max_block for some kernel + ), + triton_config_with_settings( + size_hints, + 2048, + num_warps=8, + num_stages=2, + waves_per_eu=1, # 20% improvement + ), + ] + ) if inductor_meta.get("atomic_add_found"): configs.extend( [ @@ -3056,7 +3364,9 @@ def pointwise( ) if len(size_hints) == 3: if not ( - inductor_meta.get("max_autotune_pointwise") or torch.xpu.is_available() + inductor_meta.get("max_autotune") + or inductor_meta.get("max_autotune_pointwise") + or torch.xpu.is_available() ): configs = [triton_config_with_settings(size_hints, 16, 16, 16)] else: @@ -3172,10 +3482,13 @@ def _reduction_configs( ) register_intensive = False - MAX_R0_BLOCK = 2048 loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get( "num_reduction", 0 ) + + device_major = triton_meta["device"].major + # Prefer smaller MAX_R0_BLOCK for Blackwell + MAX_R0_BLOCK = 1024 if device_major is not None and device_major >= 10 else 2048 if size_hints["x"] >= 1024 and loads_and_red >= 10: # A heuristics to reduce R0_BLOCK if a kernel potentially need many registers. # Consider load and reduction since load need move data into registers and @@ -3349,23 +3662,25 @@ def outer_config_opt(): ] if torch.version.hip: - # Skip large-XBLOCK HIP configs when a combo kernel has a persistent - # sub-kernel with a large hardcoded R0_BLOCK. The persistent tile size - # (XBLOCK * max_persistent_rblock) would otherwise cause pathological - # ROCm compilation times (e.g. 1024 * 1024 = 1M elements → 20+ min). - # Use the same 4096-element threshold as _persistent_reduction_configs. - max_persistent_rblock = inductor_meta.get("max_persistent_rblock", 0) hip_configs = [ make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2), make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1), ] + result_configs.extend(hip_configs) + + # Filter ALL configs (not just HIP-specific ones) when a combo kernel + # has a persistent sub-kernel with a large hardcoded R0_BLOCK. The + # persistent tile size (XBLOCK * max_persistent_rblock) causes + # pathological ROCm compilation times (e.g. 64 * 1024 = 64K elements + # → 60+ min triton.compile). Use the same 4096-element threshold as + # _persistent_reduction_configs. + max_persistent_rblock = inductor_meta.get("max_persistent_rblock", 0) if max_persistent_rblock > 0: - hip_configs = [ + result_configs = [ c - for c in hip_configs + for c in result_configs if c.kwargs.get("XBLOCK", 0) * max_persistent_rblock <= 4096 ] - result_configs.extend(hip_configs) return result_configs @@ -3649,9 +3964,7 @@ def cooperative_reduction( # the GPU, we want to create as many CTAs as possible, while keeping things # in powers of 2. target = last_power_of_2(triton_meta["device"].multi_processor_count) - split = max(1, min(target // xnumel, TRITON_MAX_RSPLIT)) - assert rnumel >= split - assert split <= TRITON_MAX_RSPLIT + split = max(1, min((rnumel, target // xnumel, TRITON_MAX_RSPLIT))) if inductor_meta["persistent_reduction"]: configs = _persistent_reduction_configs( {"x": xnumel, "r0_": rnumel // split}, @@ -3687,6 +4000,16 @@ def _persistent_reduction_configs( inductor_meta=None, triton_meta=None, ): + # Under deterministic mode, canonicalize the batch-dim hint so the + # candidate-config branching below (e.g. xnumel // 8 < 128) doesn't pick + # a different (XBLOCK, num_warps) for bs=N vs bs=N/2. Different picks + # change the bf16 reduction order and break batch invariance in + # persistent reductions like LayerNorm. + if inductor_meta and inductor_meta.get("batch_invariant"): + size_hints = dict(size_hints) + if "x" in size_hints: + size_hints["x"] = max(size_hints["x"], 4096) + xnumel = size_hints["x"] rnumel = get_total_reduction_numel(size_hints) @@ -3906,7 +4229,8 @@ def persistent_reduction( # more warps for larger rows new_configs.append(c) - if max_autotune_enabled and c.num_warps < 32: + max_warps_limit = 16 if torch.version.hip else 32 + if max_autotune_enabled and c.num_warps < max_warps_limit: newc = copy.deepcopy(c) newc.num_warps *= 2 new_configs.append(newc) @@ -4144,7 +4468,10 @@ def maximum(self, seq: list[int | str]) -> int | str: return items[0] if self.mode == "python": return f"max({', '.join(map(str, items))})" - return functools.reduce(lambda x, y: f"std::max({x}, {y})", items) + # Cast int constants to (long) to avoid type deduction errors with std::max + # when mixing long variables with int literals + cpp_items = [f"(long){x}" if isinstance(x, int) else str(x) for x in items] + return functools.reduce(lambda x, y: f"std::max({x}, {y})", cpp_items) def summation(self, seq: list[int | str]) -> int | str: """Codegen for sum function with constant folding, constants are represented as int""" @@ -4273,7 +4600,11 @@ def generate(self, meta: dict[str, int], is_lazy: bool = False) -> None: ), ] ) - self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_") + ceildiv_expr = self.ceildiv("y_grid_raw_", "y_grid_div_") + if self.mode == "python": + self.y_grid = f"(0 if y_grid_div_ == 0 else {ceildiv_expr})" + else: + self.y_grid = f"(y_grid_div_ == 0 ? 0 : {ceildiv_expr})" self.z_grid = "y_grid_div_" @@ -4360,7 +4691,11 @@ def generate(self, meta: dict[str, int], is_lazy: bool = False) -> None: ), ] ) - self.y_grid = self.ceildiv("y_grid_raw_", "y_grid_div_") + ceildiv_expr = self.ceildiv("y_grid_raw_", "y_grid_div_") + if self.mode == "python": + self.y_grid = f"(0 if y_grid_div_ == 0 else {ceildiv_expr})" + else: + self.y_grid = f"(y_grid_div_ == 0 ? 0 : {ceildiv_expr})" self.z_grid = "y_grid_div_" def combo_x_grid( diff --git a/torch/_inductor/runtime/triton_lazy_compile.py b/torch/_inductor/runtime/triton_lazy_compile.py index 3d6710e659b59..75c7baa808111 100644 --- a/torch/_inductor/runtime/triton_lazy_compile.py +++ b/torch/_inductor/runtime/triton_lazy_compile.py @@ -41,8 +41,6 @@ class TritonKernelCompileResult: profile_scratch: int | None -_pending_kernels: dict[str, Any] = {} - _async_compile: Any = None @@ -98,25 +96,31 @@ def _wrap_tma_args(args: list[Any], kernel_fn: CachingAutotuner) -> list[Any]: return wrapped -def start_kernel_compile(kernel_name: str, kernel_source: str) -> None: +def start_kernel_compile( + pending_kernels: dict[str, Any], kernel_name: str, kernel_source: str +) -> None: """ This function is called from C++ at model initialization time for each kernel. It starts the compilation in a background process but does NOT wait for it. The actual kernel execution happens later in run_triton_kernel_with_autotune(). + + The pending_kernels dict is per-module, created in C++ and passed through + to avoid global state collisions across compiled modules. """ - if kernel_name in _pending_kernels: + if kernel_name in pending_kernels: return async_compile = _get_async_compile() # noqa: F841 (used by eval below) # Evaluate the kernel source to get the Future or CachingAutotuner # The kernel_source is like: async_compile.triton('name', '''...''', ...) - kernel_obj = eval(kernel_source.strip()) # noqa: S307 + kernel_obj = eval(kernel_source.strip()) - _pending_kernels[kernel_name] = kernel_obj + pending_kernels[kernel_name] = kernel_obj def run_triton_kernel_with_autotune( + pending_kernels: dict[str, Any], kernel_name: str, stream: Any, args: list[Any], @@ -127,9 +131,9 @@ def run_triton_kernel_with_autotune( from torch._inductor.codecache import CodeCacheFuture, CudaKernelParamCache from torch._inductor.runtime.triton_heuristics import config_to_dict - if kernel_name not in _pending_kernels: - raise RuntimeError(f"Kernel {kernel_name} not found in pending kernels. ") - kernel_obj = _pending_kernels.pop(kernel_name) + if kernel_name not in pending_kernels: + raise RuntimeError(f"Kernel {kernel_name} not found in pending kernels.") + kernel_obj = pending_kernels[kernel_name] if isinstance(kernel_obj, CodeCacheFuture): kernel_fn = kernel_obj.result() @@ -160,15 +164,26 @@ def run_triton_kernel_with_autotune( from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name cubin_path_name = get_cpp_wrapper_cubin_path_name() - for key in (cubin_path_name, "mangled_name", "num_warps", "shared_mem"): - if key not in cached_params: - raise RuntimeError(f"{key} not found in cached params for {kernel_name}") + for key_name in (cubin_path_name, "mangled_name", "num_warps", "shared_mem"): + if key_name not in cached_params: + raise RuntimeError( + f"{key_name} not found in cached params for {kernel_name}" + ) cubin_path = cached_params[cubin_path_name] mangled_name = cached_params["mangled_name"] num_warps = cached_params["num_warps"] shared_mem = cached_params["shared_mem"] config = config_to_dict(launcher.config) if launcher.config else {} + + # For combo/foreach kernels, the autotuned config may have empty kwargs + # (e.g., the foreach heuristic only tunes num_warps, not XBLOCK). + # In that case, use the default_config from combo_grid_meta + combo_grid_meta = inductor_meta.get("combo_grid_meta") if inductor_meta else None + default_config = combo_grid_meta.get("default_config") if combo_grid_meta else None + if default_config: + config = {**default_config, **config} + xblock = config.get("XBLOCK", 128) yblock = config.get("YBLOCK", 1) zblock = config.get("ZBLOCK", 1) @@ -210,7 +225,7 @@ def run_triton_kernel_with_autotune( profile_scratch, ) - return TritonKernelCompileResult( + result = TritonKernelCompileResult( cubin_path=cubin_path, mangled_name=mangled_name, num_warps=num_warps, @@ -225,3 +240,4 @@ def run_triton_kernel_with_autotune( global_scratch=global_scratch, profile_scratch=profile_scratch, ) + return result diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 8b6e41ca3202d..76e5c483ddab9 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -16,7 +16,7 @@ import typing from collections import Counter, defaultdict from concurrent.futures import as_completed, Future -from typing import Any, Generic, TYPE_CHECKING, TypeAlias, TypeVar +from typing import Any, Generic, Literal, overload, TYPE_CHECKING, TypeAlias, TypeVar from typing_extensions import ParamSpec from torch.utils._ordered_set import OrderedSet @@ -35,7 +35,7 @@ import sympy import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile import torch.utils._pytree as pytree from torch._dynamo.utils import counters, dynamo_timed from torch._inductor.autotune_process import use_pipelined_autotuning @@ -361,6 +361,9 @@ def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: if not V.graph.sizevars.statically_known_leq(ncol, 1024 * 16): return False + if MixOrderReduction.is_split_reduction(contiguous_node): + return False + # Other reduction types like max/min is not supported yet. # There are no real use case as well. out = all( @@ -889,6 +892,9 @@ def single_index_in_fused_node(buf_to_be_inplaced: SchedulerBuffer) -> bool: or buf_node.get_inputs_that_alias_output() or buf_node.get_mutation_names() or buf.get_name() in V.graph.removed_buffers + # CommBufferLayout buffer must keep its P2P allocation. + # Do not allow in-place reuse into or from a P2P buffer. + or isinstance(buf_node.get_output_spec(), ir.CommBufferLayout) ): continue @@ -910,8 +916,13 @@ def single_index_in_fused_node(buf_to_be_inplaced: SchedulerBuffer) -> bool: for x in input_buf.users if x.node.get_name() not in inconsequential_nodes ] + has_cross_stream_hazard = self.scheduler.has_cross_stream_hazard( + read.name, self + ) + if ( - len(remaining_uses) == 1 + not has_cross_stream_hazard + and len(remaining_uses) == 1 and remaining_uses[0].can_inplace and remaining_uses[0].node is self and input_buf.node is not None @@ -921,6 +932,7 @@ def single_index_in_fused_node(buf_to_be_inplaced: SchedulerBuffer) -> bool: ir.NoneLayout, ir.MultiOutputLayout, ir.MutationLayoutSHOULDREMOVE, + ir.CommBufferLayout, ), ) and not ( @@ -1226,11 +1238,11 @@ def _get_estimated_runtime(self) -> float: except ValueError as e: # We don't know how to estimate runtime for this collective, # falling back to 0 - log.info(e) # noqa: G200 + log.info(e) return 0 except TypeError as e: # this happens when the collective is not of type ir._CollectiveKernel - log.info(e) # noqa: G200 + log.info(e) return 0 elif is_wait(self.node): @@ -1319,7 +1331,9 @@ def get_estimate_runtime_cache_key_from_snode(snode: BaseSchedulerNode) -> str: flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs)) def _is_tensor_ir(x) -> bool: # type: ignore[no-untyped-def] - return isinstance(x, ir.IRNode) and not isinstance(x, ir.GeneratorState) + return isinstance(x, ir.IRNode) and not isinstance( + x, (ir.GeneratorState, ir.OpaqueObjectState) + ) cache_key = str( (python_kernel_name,) @@ -1575,10 +1589,17 @@ def recompute_size_and_body( extra_indexing_constraints: tuple[dict[Any, Any], list[Any]] | None = None, recompute_sizes_body_func: Callable[..., Any] | None = None, ) -> None: + fake_deps: OrderedSet[Dep] = OrderedSet( + dep for dep in self.read_writes.reads if isinstance(dep, (WeakDep, StarDep)) + ) self._compute_attrs( extra_indexing_constraints=extra_indexing_constraints, recompute_sizes_body_func=recompute_sizes_body_func, ) + if fake_deps: + self.set_read_writes( + self.read_writes.with_read(fake_deps).rename(self.mutation_renames) + ) def refresh_dependencies( self, normalize: bool, need_clear_tiling_cache: bool @@ -1610,6 +1631,32 @@ def refresh_dependencies( # lru_cache. SIMDScheduling.candidate_tilings.cache_clear() + def snapshot_loop_state(self) -> tuple[Any, ...]: + """Snapshot mutable state modified by loop transformations + (apply_new_loop_order, apply_loop_reindexing). Must be kept + in sync with those methods and restore_loop_state.""" + return ( + self._body, + self._sizes, + self.group, + self.read_writes, + self.unmet_dependencies, + ) + + def restore_loop_state(self, state: tuple[Any, ...]) -> None: + """Restore state from snapshot_loop_state.""" + from .codegen.simd import SIMDScheduling + + ( + self._body, + self._sizes, + self.group, + self.read_writes, + self.unmet_dependencies, + ) = state + self.pointwise_read_writes.clear_cache(self) + SIMDScheduling.candidate_tilings.cache_clear() + def apply_new_loop_order(self, new_order: Sequence[int]) -> None: self._body = self._body.reorder_iter_loops( new_order, @@ -1618,6 +1665,18 @@ def apply_new_loop_order(self, new_order: Sequence[int]) -> None: self.refresh_dependencies(normalize=False, need_clear_tiling_cache=True) + def apply_loop_reindexing(self, new_iter_sizes: Sequence[sympy.Expr]) -> None: + assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) + + self._body = self._body.reindex_iter_loops(new_iter_sizes) + self._sizes = self._body.sizes + + device = self.node.get_device_or_error() + group_fn = self.scheduler.get_backend(device).group_fn + self.group = (device, group_fn(self._sizes)) + + self.refresh_dependencies(normalize=False, need_clear_tiling_cache=True) + def swap_pw_red_dimension(self) -> None: num_rdims = self._body.get_original_num_rdims() num_pwdims = len(self._body.iter_vars) - num_rdims @@ -2139,6 +2198,8 @@ def has_side_effects(self) -> bool: class FusedMixOrderReductions(FusedSchedulerNode): + """Fused node for two reductions with different iteration orders (inner + outer).""" + def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None: if not MixOrderReduction.is_contiguous_node(node1): assert MixOrderReduction.is_contiguous_node(node2) @@ -2197,13 +2258,24 @@ def _get_operation_names( return ( not node2.is_reduction() - or typing.cast( - int, self.scheduler.score_fusion_memory(node1, node2, count_bytes=False) - ) + or self.scheduler.score_fusion_memory(node1, node2, count_bytes=False) >= self.numel ) def can_fuse_with(self, other: BaseSchedulerNode): + # Limit tl.load() count in the fused RSPLIT loop to avoid register + # spills. See https://github.com/pytorch/pytorch/issues/179423 + max_reads = config.triton.mix_order_reduction_max_reads + if max_reads > 0: + all_reads: OrderedSet[str] = OrderedSet() + for sn in itertools.chain(self.get_nodes(), other.get_nodes()): + for dep in sn.read_writes.reads: + if isinstance(dep, MemoryDep): + all_reads.add(dep.name) + if len(all_reads) > max_reads: + # pyrefly: ignore [bad-assignment] + metrics.rejected_mix_order_reduction_fusion += 1 + return False if not isinstance(other, FusedMixOrderReductions): return self.sub_node_can_fuse( self.node1, other, (self.node2,) @@ -2253,16 +2325,18 @@ def epilogue_fuse( node1: ExternKernelSchedulerNode, node2: SchedulerNode, ) -> FusedSchedulerNode: + assert isinstance(node1.node, ir.UserDefinedTritonKernel) scheduler = node1.scheduler - # this unmet dependency is the buffer which is mutated - # after fusion, we don't need this buffer anymore, - # because the kernel directly writes to the output buffer of the epilogue - assert len(node1.unmet_dependencies) == 1 - original_mutated_buffer = scheduler.name_to_buf[ - next(iter(node1.unmet_dependencies)).name - ] - original_mutated_buffer.users.remove(NodeUser(node1)) - return FusedExternTritonKernelSchedulerNode(scheduler, node1, node2) + + assert len(node1.node.mutation_outputs) == 1 + # pyrefly: ignore[bad-assignment] + mutated_name: str = node1.node.mutation_outputs[0].name + # Node1's mutated tensor becomes an intermediary tensor. + # Thus, remove node1 from the respective allocated buffer's users + # for `Scheduler.dead_node_elimination` to remove. + real_name = scheduler.mutation_real_name.get(mutated_name, mutated_name) + scheduler.name_to_buf[real_name].users.remove(NodeUser(node1)) + return cls(scheduler, node1, node2) def codegen(self, wrapper: PythonWrapperCodegen) -> None: assert isinstance(self.fused_epilogue.node, ir.ComputedBuffer) @@ -2592,7 +2666,7 @@ def _default_group_nodes_for_combo_kernels( """ sorted_nodes = scheduler._topological_sort_nodes() grouped_nodes = [] - max_num_nodes = 8 + max_num_nodes = config.combo_kernel_max_num_nodes excluded_buffer_names: OrderedSet[str] = OrderedSet( [ @@ -2618,14 +2692,20 @@ def _default_group_nodes_for_combo_kernels( continue device_groups[device].append(node) - # Chunk each device group separately + # Sub-group by stream to avoid mixing nodes across stream + # boundaries. When multi-stream scheduling is inactive every + # node maps to DEFAULT_STREAM_IDX so this is a no-op. for device_nodes in device_groups.values(): - grouped_nodes.extend( - [ - device_nodes[i : i + max_num_nodes] - for i in range(0, len(device_nodes), max_num_nodes) - ] - ) + stream_groups: dict[int, list[BaseSchedulerNode]] = defaultdict(list) + for node in device_nodes: + stream_groups[scheduler.node_to_stream.get(node, 0)].append(node) + for stream_nodes in stream_groups.values(): + grouped_nodes.extend( + [ + stream_nodes[i : i + max_num_nodes] + for i in range(0, len(stream_nodes), max_num_nodes) + ] + ) return grouped_nodes group_algorithm_for_combo_kernels: Callable[ @@ -3000,12 +3080,36 @@ def get_scheduler_node_symbol_uses( return free_symbol_uses +def _is_epilogue_fusion_enabled(template_node: BaseSchedulerNode) -> bool: + """Check per-template flag, fall back to global config.""" + tb = template_node.get_template_node() + if tb is not None and tb.allow_epilogue_fusion is not None: + return tb.allow_epilogue_fusion + return config.epilogue_fusion + + +def _is_prologue_fusion_enabled(template_node: BaseSchedulerNode) -> bool: + """Check per-template flag, fall back to global config.""" + tb = template_node.get_template_node() + if tb is not None and tb.allow_prologue_fusion is not None: + return tb.allow_prologue_fusion + return config.prologue_fusion + + def is_epilogue_fusion(node1: BaseSchedulerNode, node2: BaseSchedulerNode): - return node1.is_template() and config.epilogue_fusion and not node2.is_template() + return ( + node1.is_template() + and not node2.is_template() + and _is_epilogue_fusion_enabled(node1) + ) def is_prologue_fusion(node1: BaseSchedulerNode, node2: BaseSchedulerNode): - return node2.is_template() and config.prologue_fusion and not node1.is_template() + return ( + node2.is_template() + and not node1.is_template() + and _is_prologue_fusion_enabled(node2) + ) def is_template_fusion(node1: BaseSchedulerNode, node2: BaseSchedulerNode): @@ -3123,6 +3227,8 @@ def _init(self, nodes: list[ir.Operation]) -> None: self.node_to_stream: dict[BaseSchedulerNode, int] = {} self.buff_to_stream: dict[str, int] = {} self._multi_stream_nodes: bool = False + # Maps stream_idx → user_object_index for retrieving user stream objects + self.stream_idx_to_user_obj_idx: dict[int, int] = {} self._populate_stream_assignments() self.nodes = self.fuse_nodes(self.nodes) @@ -3153,6 +3259,10 @@ def _init(self, nodes: list[ir.Operation]) -> None: ): self.create_combo_kernel_nodes(num_ck_nodes=None) + # torch.cond can contain arbitrary subgraphs, which can contain collectives + # reordering these can cause a nccl hang + self._enforce_conditional_ordering() + # Peak memory pass and overlap pass must run last, otherwise # other reordering passes could undo their effects. if config.reorder_for_peak_memory: @@ -3273,44 +3383,55 @@ def get_donated_buffers(self) -> dict[str, SchedulerDonatedBuffer]: return name_to_donated_buf def _populate_stream_assignments(self) -> None: - """Populate node_to_stream and buff_to_stream from FX node metadata. + """Populate node_to_stream and buff_to_stream from IR node stream_idx. - Reads the 'custom.stream' metadata from FX nodes to determine which - stream each scheduler node should run on. This metadata is set by - dynamo when tracing torch.cuda.stream() context managers. + Reads the stream_idx field set on IR nodes during lowering to determine + which stream each scheduler node should run on. This field is propagated + from 'custom.stream' FX node metadata via IRNode.current_stream_idx(). """ from .stream_constants import DEFAULT_STREAM_IDX # Map user_object_index to stream index (1-indexed for side streams) user_obj_to_stream_idx: dict[int, int] = {} - next_stream_idx = 1 # 0 is reserved for default stream + stream_idx_counter = itertools.count(1) # 0 is reserved for default stream for node in self.nodes: stream_idx = DEFAULT_STREAM_IDX - # Get the origin FX nodes to read metadata. - # Each scheduler node may have multiple origin FX nodes (via origins). if node.node is not None: - origins = node.node.get_origins() - for fx_node in origins: - if not hasattr(fx_node, "meta"): - continue - custom_meta = fx_node.meta.get("custom", {}) - if "stream" in custom_meta: - user_obj_idx = custom_meta["stream"] - if user_obj_idx not in user_obj_to_stream_idx: - user_obj_to_stream_idx[user_obj_idx] = next_stream_idx - next_stream_idx += 1 - stream_idx = user_obj_to_stream_idx[user_obj_idx] - # Use the first stream found - break + user_obj_idx = node.node.get_stream_idx() + if user_obj_idx is not None: + if user_obj_idx not in user_obj_to_stream_idx: + new_stream_idx = next(stream_idx_counter) + user_obj_to_stream_idx[user_obj_idx] = new_stream_idx + self.stream_idx_to_user_obj_idx[new_stream_idx] = user_obj_idx + stream_idx = user_obj_to_stream_idx[user_obj_idx] self.node_to_stream[node] = stream_idx - # Also populate buff_to_stream for all buffers produced by this node + # Also populate buff_to_stream for all buffers produced by this node. + # Mutation renames are resolved at lookup time via get_buf_stream. for buf in node.get_buffer_names(): self.buff_to_stream[buf] = stream_idx + # Propagate a device to device-less nodes (e.g. record_event, + # wait_event) so they naturally enter the device guard in the + # main codegen loop instead of requiring special-case handling. + if any(s != DEFAULT_STREAM_IDX for s in self.node_to_stream.values()): + device = next( + (n.get_device() for n in self.nodes if n.get_device() is not None), None + ) + if device is not None: + for node in self.nodes: + ir_node = node.node + if ( + node.get_device() is None + and isinstance(ir_node, ir.Buffer) + and isinstance(ir_node.layout, ir.NoneLayout) + ): + # pyrefly: ignore [bad-assignment] + ir_node.layout = ir.NoneLayout(device=device) + # Check if we have any nodes on non-default streams self._multi_stream_nodes = any( stream_idx != DEFAULT_STREAM_IDX @@ -3321,6 +3442,21 @@ def _has_multi_stream_nodes(self) -> bool: """Check if any nodes are assigned to non-default streams.""" return self._multi_stream_nodes + def get_buf_stream(self, buf_name: str) -> int: + """Return the stream index for a buffer, resolving mutation renames.""" + real = self.mutation_renames.get(buf_name, buf_name) + return self.buff_to_stream.get(real, self.buff_to_stream.get(buf_name, 0)) + + def has_cross_stream_hazard(self, buf_name: str, node: BaseSchedulerNode) -> bool: + """True if buf_name was produced on a different stream than node. + + Resolves mutation renames so that mutated buffers inherit the + stream of their original definition. + """ + if not self._has_multi_stream_nodes(): + return False + return self.get_buf_stream(buf_name) != self.node_to_stream.get(node, 0) + @property def current_device(self) -> torch.device | None: return V.graph.current_device @@ -3815,6 +3951,17 @@ def visit(n: BaseSchedulerNode) -> None: visit(node) return result + def _enforce_conditional_ordering(self) -> None: + conditional_nodes = [ + n for n in self.nodes if isinstance(n.node, ir.Conditional) + ] + for i in range(1, len(conditional_nodes)): + mutating_buf = next(iter(conditional_nodes[i].get_buffer_names())) + prev_buf = next(iter(conditional_nodes[i - 1].get_buffer_names())) + conditional_nodes[i].add_fake_dep( + WeakDep(prev_buf, mutating_buf=mutating_buf, is_fake=True) + ) + def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> list[BaseSchedulerNode]: unmet_deps: OrderedSet[str] = OrderedSet() if isinstance( @@ -4340,10 +4487,10 @@ def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None: future.result() except Exception as e: if fusion_log.isEnabledFor(logging.DEBUG): - fusion_log.debug( # noqa: G200 + fusion_log.debug( "Exception in compiling %s: %s", "prologue" if not epilogue_fusion else "epilogue", - str(e), + e, ) continue with multi_node.swap_as_triton_caller(choice): @@ -4475,10 +4622,10 @@ def benchmark_when_ready() -> bool: # triton will unpredictably error with valid prologue fusions except Exception as e: if fusion_log.isEnabledFor(logging.DEBUG): - fusion_log.debug( # noqa: G200 + fusion_log.debug( "Exception in compiling %s: %s", "prologue" if not epilogue_fusion else "epilogue", - str(e), + e, ) continue @@ -4910,11 +5057,7 @@ def fuse_nodes_once( is_reorder_round, ) - if ( - (config.max_autotune_gemm or config.max_autotune) - and config.prologue_fusion - and config.epilogue_fusion - ): + if config.max_autotune_gemm or config.max_autotune: possible_fusions = self._handle_template_overlap( possible_fusions, deferred_prologue_fusions ) @@ -4983,6 +5126,11 @@ def create_combo_kernel_nodes(self, num_ck_nodes: int | None = None) -> None: self.name_to_fused_node.update( {n.get_name(): group_snode for n in group_snode.get_nodes()} ) + # Propagate stream assignment so codegen can place the combo + # kernel in the correct stream context. + stream = self.node_to_stream.get(node_list[0]) + if stream is not None: + self.node_to_stream[group_snode] = stream self.nodes = sorted(fused_nodes, key=lambda x: x.min_order) self.nodes = self.topological_sort_schedule(self.nodes) log.info( @@ -5450,11 +5598,9 @@ def shared_data_after_reordering_loop( """ - # TODO Don't do loop reordering for CPU for now. + # TODO Don't do loop reordering/reindexing for CPU for now. # Should debug more why it does not work for CPU codegen - if not config.loop_ordering_after_fusion or any( - n.is_cpu() for n in [node1, node2] - ): + if any(n.is_cpu() for n in [node1, node2]): return -1 # in some rare case, a template can be passed in. @@ -5463,27 +5609,71 @@ def shared_data_after_reordering_loop( if node1.is_template() or node2.is_template(): return -1 - node1_buffer_names = node1.read_writes.buffer_names() - node2_buffer_names = node2.read_writes.buffer_names() - # Fast path: no common buffers. - common_buffer_names = node1_buffer_names & node2_buffer_names + common_buffer_names = ( + node1.read_writes.buffer_names() & node2.read_writes.buffer_names() + ) if not common_buffer_names: return -1 - node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} - node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} + if config.loop_ordering_after_fusion: + score = self._try_reorder_loops_for_candidates(node1, node2) + if score >= 0: + return score + + # No reordering candidates found (or loop ordering disabled). + # Try reindexing the pointwise to match the reduction's iteration + # domain (e.g., [1024, 8192] -> [65536, 128] for RMS norm with + # reshape), then retry loop reordering if enabled. The retry is + # needed because FusedSchedulerNodes may have more loop vars than + # the reindexed pointwise (e.g., 3 vs 2), and only the normalize() + # comparison in _try_reorder_loops_for_candidates handles that + # num_vars mismatch. + if ( + not config.loop_reindexing_after_fusion + or not self._try_reindex_pointwise_for_reduction(node1, node2) + ): + return -1 + + if config.loop_ordering_after_fusion: + score = self._try_reorder_loops_for_candidates(node1, node2) + if score >= 0: + return score + + return self.score_fusion_memory(node1, node2) + + def _try_reorder_loops_for_candidates( + self, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + ) -> int: + """ + Find common buffers with matching normalized stride order but different + loop orders, and try to reorder loops to align them. + """ + common_buffer_names = ( + node1.read_writes.buffer_names() & node2.read_writes.buffer_names() + ) + node1_reads = {dep.name: dep for dep in node1.read_writes.reads} + node1_writes = {dep.name: dep for dep in node1.read_writes.writes} + node2_reads = {dep.name: dep for dep in node2.read_writes.reads} + node2_writes = {dep.name: dep for dep in node2.read_writes.writes} - # Find the commons buffers that has different loop orders candidates = [] for buffer_name in common_buffer_names: - lhs_dep = node1_name2dep[buffer_name] - rhs_dep = node2_name2dep[buffer_name] + lhs_dep = node1_writes.get(buffer_name) or node1_reads[buffer_name] + rhs_dep = node2_writes.get(buffer_name) or node2_reads[buffer_name] + + is_write_read = ( + buffer_name in node1_writes and buffer_name in node2_reads + ) or (buffer_name in node2_writes and buffer_name in node1_reads) + if ( lhs_dep.normalize_with_stride_order() == rhs_dep.normalize_with_stride_order() ): candidates.append( ( + is_write_read, V.graph.sizevars.optimization_hint( lhs_dep.get_numel(), fallback=0 ), @@ -5491,12 +5681,33 @@ def shared_data_after_reordering_loop( rhs_dep, ) ) + elif is_write_read: + # A write→read dep failed normalize_with_stride_order. + # This could be a dimension order issue (reordering can + # fix it) or a factorization issue (only reindexing can). + # Distinguish by checking if the write dep's sizes are + # a subset of the read dep's — if so, reordering the + # read's loops could align them. + w = node1_writes.get(buffer_name) or node2_writes.get(buffer_name) + r = node2_reads.get(buffer_name) or node1_reads.get(buffer_name) + if isinstance(w, MemoryDep) and isinstance(r, MemoryDep): + sv = V.graph.sizevars + w_sizes = w.normalize().size + r_sizes = r.normalize().size + if not all( + any(sv.statically_known_equals(ws, rs) for rs in r_sizes) + for ws in w_sizes + ): + return -1 if len(candidates) == 0: return -1 - # Pick the largest buffer to guide the loop reordering - _numel, lhs_dep, rhs_dep = max(candidates, key=operator.itemgetter(0)) + # Prefer write→read deps over shared reads. Among same + # priority, pick the largest buffer. + _is_wr, _numel, lhs_dep, rhs_dep = max( + candidates, key=operator.itemgetter(0, 1) + ) if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep): return -1 @@ -5523,11 +5734,99 @@ def shared_data_after_reordering_loop( node2.get_name(), ) - return ( - typing.cast(int, self.score_fusion_memory(node1, node2)) - if reordered - else -1 + return self.score_fusion_memory(node1, node2) if reordered else -1 + + def _try_reindex_pointwise_for_reduction( + self, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + ) -> bool: + """ + Reindex a pointwise's iteration loops to match a reduction's + groups. After reindexing, the shared reads have identical index + expressions, enabling the codegen to CSE loads. + + Returns True if reindexing was applied. + """ + from .codegen.simd import SIMDKernel + + if node1.is_reduction() and not node2.is_reduction(): + reduction_node, pw_node = node1, node2 + elif node2.is_reduction() and not node1.is_reduction(): + reduction_node, pw_node = node2, node1 + else: + return False + + _, groups = reduction_node.group + red_numel = typing.cast(sympy.Expr, groups[0]) + red_rnumel = typing.cast(sympy.Expr, groups[1]) + target_numel = red_numel * red_rnumel + + if not all(isinstance(sn, SchedulerNode) for sn in pw_node.get_nodes()): + return False + snodes = typing.cast(list[SchedulerNode], pw_node.get_nodes()) + + # All snodes must have the same total iteration numel matching + # the reduction's numel * rnumel so they can be reindexed identically. + if not all( + V.graph.sizevars.statically_known_equals( + sympy_product(sn._sizes[0]), target_numel + ) + for sn in snodes + ): + return False + + if not all( + SIMDKernel.is_compatible((red_numel, red_rnumel), sn.get_ranges()) + for sn in snodes + ): + return False + + # Snapshot state before mutation so we can rollback if the + # reindexed deps don't actually improve the fusion score. + snapshots = [(sn, sn.snapshot_loop_state()) for sn in snodes] + old_pw_group = ( + pw_node.group if isinstance(pw_node, FusedSchedulerNode) else None + ) + + for sn in snodes: + sn.apply_loop_reindexing([red_numel, red_rnumel]) + + if isinstance(pw_node, FusedSchedulerNode): + pw_node.group = snodes[0].group + refresh_group_node_dependencies(pw_node) + + # Verify reindexing actually increases shared deps. + common_names = ( + node1.read_writes.buffer_names() & node2.read_writes.buffer_names() ) + n1_deps = {dep.name: dep for dep in node1.read_writes.reads_and_writes()} + n2_deps = {dep.name: dep for dep in node2.read_writes.reads_and_writes()} + has_benefit = any( + self.deps_match_normalized(n1_deps[name], n2_deps[name]) + for name in common_names + ) + if not has_benefit: + for sn, state in snapshots: + sn.restore_loop_state(state) + if isinstance(pw_node, FusedSchedulerNode): + assert old_pw_group is not None + pw_node.group = old_pw_group + refresh_group_node_dependencies(pw_node) + return False + + # When loop ordering is disabled, re-extract deps with + # normalize=True so variable names are canonical. This is + # safe because no further loop reordering will occur. + # Without this, reindexed deps use different var names + # (e.g. c0 vs d0) causing exact dep comparisons to fail. + if not config.loop_ordering_after_fusion: + for sn in snodes: + sn.refresh_dependencies(normalize=True, need_clear_tiling_cache=False) + if isinstance(pw_node, FusedSchedulerNode): + refresh_group_node_dependencies(pw_node) + + return True def unfusable_node(self, node: BaseSchedulerNode) -> bool: """ @@ -5593,8 +5892,10 @@ def check_prologue_fusion_heuristics_fusable( def low_prec_fp(dtype: torch.dtype) -> bool: return dtype.itemsize <= 2 and dtype.is_floating_point + template_buf = template_node.get_template_node_or_throw() if ( - low_prec_fp(template_node.get_template_node_or_throw().dtype) + not template_buf.is_multi_outputs_template() + and low_prec_fp(template_buf.dtype) and not prologue_node.can_codegen_in_low_precision() ): why( @@ -5765,6 +6066,15 @@ def can_fuse( why("node1 is extern but node2.node.data is not Pointwise") return False + assert len(node1.node.mutation_outputs) == 1 + written_buffer_name = node1.node.mutation_outputs[0].name + + # The epilogue can only read from the output buffer. + # Any other tensor/s would require additional load expressions. + if any(dep.name != written_buffer_name for dep in node2.read_writes.reads): + why("epilogue reads from buffers other than the mutated output") + return False + # the epilogue depends on expressions which may not available in the user triton kernel # (e.g. indexing exprs used not in a load) node2_inner_fn_free_symbols = node2.node.data.inner_fn_free_symbols() @@ -5779,9 +6089,6 @@ def can_fuse( why("node1 and node2 uses different buf layouts") return False - assert len(node1.node.mutation_outputs) == 1 - written_buffer_name = node1.node.mutation_outputs[0].name - def _is_other_node_that_references_mutation_buffer( other_node: BaseSchedulerNode, ): @@ -5809,7 +6116,7 @@ def _is_other_node_that_references_mutation_buffer( return False if node2.is_template(): - if not config.prologue_fusion: + if not _is_prologue_fusion_enabled(node2): why("prologue fusion turned off") return False @@ -5818,11 +6125,10 @@ def _is_other_node_that_references_mutation_buffer( return False template = node2.get_template_node_or_throw() - if not isinstance(template, ir.TritonTemplateBuffer): - why("prologue fusion only supported for TritonTemplates") - return False - allowed_prologue_inps = template.get_allowed_prologue_inps() + if not allowed_prologue_inps: + why("template has no allowed prologue inputs") + return False unsupported_prologue_args = ( OrderedSet(inp.get_name() for inp in template.inputs) # type: ignore[union-attr] @@ -5866,13 +6172,21 @@ def _is_other_node_that_references_mutation_buffer( if not self.check_prologue_fusion_heuristics_fusable(node1, node2, why): return False - if node1.is_template() and ( - node2.has_aliasing_or_mutation() - or node2.is_reduction() - or not config.epilogue_fusion - ): - why("template epilogue not satisfied") - return False + if node1.is_template(): + if ( + node2.has_aliasing_or_mutation() + or node2.is_reduction() + or not _is_epilogue_fusion_enabled(node1) + ): + why("template epilogue not satisfied") + return False + template_buf = node1.get_template_node() + assert template_buf is not None + if template_buf.is_multi_outputs_template() and not isinstance( + node2.node, ir.ComputedBuffer + ): + why("multi-output template epilogue requires ComputedBuffer") + return False if (node1.get_buffer_names() & V.graph.no_fuse_buffer_names) or ( node2.get_buffer_names() & V.graph.no_fuse_buffer_names @@ -5894,7 +6208,9 @@ def _is_other_node_that_references_mutation_buffer( if ( can_reorder and shared_data_score < config.score_fusion_memory_threshold - and config.loop_ordering_after_fusion + and ( + config.loop_ordering_after_fusion or config.loop_reindexing_after_fusion + ) ): new_shared_data_score = self.shared_data_after_reordering_loop(node1, node2) if new_shared_data_score >= 0: @@ -6114,9 +6430,47 @@ def fusable_stardep_write_and_read_on_empty_tensor( return True return False + @staticmethod + def deps_match_normalized(dep1: Dep, dep2: Dep) -> bool: + """Check if two deps refer to the same access pattern after normalization. + + Handles the case where FusedSchedulerNodes have more loop vars + than a single SchedulerNode (e.g., 3 vars vs 2) by falling back + to normalize() which merges loops before comparing. + """ + if not isinstance(dep1, MemoryDep) or not isinstance(dep2, MemoryDep): + return False + if dep1 == dep2: + return True + if dep1.num_vars == dep2.num_vars: + return ( + dep1.normalize_with_stride_order() == dep2.normalize_with_stride_order() + ) + return dep1.normalize() == dep2.normalize() + def dep_size_hint(self, dep: Dep, count_bytes: bool = True) -> int: return V.graph.get_dep_size_hint(dep, count_bytes) + @overload + def score_fusion_memory( + self, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + count_bytes: bool = ..., + return_is_mix_order_reduction: Literal[False] = ..., + allow_mix_order_reduction: bool = ..., + ) -> int: ... + + @overload + def score_fusion_memory( + self, + node1: BaseSchedulerNode, + node2: BaseSchedulerNode, + count_bytes: bool = ..., + return_is_mix_order_reduction: Literal[True] = ..., + allow_mix_order_reduction: bool = ..., + ) -> tuple[int, int, bool]: ... + def score_fusion_memory( self, node1: BaseSchedulerNode, @@ -6259,9 +6613,7 @@ def _can_use_buffer_overlap_scoring( if node1.is_template() or node2.is_template(): return False - if (config.max_autotune or config.max_autotune_gemm) and ( - config.prologue_fusion or config.epilogue_fusion - ): + if config.max_autotune or config.max_autotune_gemm: node1_outputs = node1.get_outputs() node2_outputs = node2.get_outputs() @@ -6282,6 +6634,7 @@ def _can_use_buffer_overlap_scoring( if ( isinstance(user.node, BaseSchedulerNode) and user.node.is_template() + and _is_prologue_fusion_enabled(user.node) ): # Check if this output is actually in the template's # allowed_prologue_inps. If not, fusing horizontally @@ -6319,12 +6672,15 @@ def _can_use_buffer_overlap_scoring( # Conservative: block fusion return False - if config.epilogue_fusion: - for node in (node1, node2): - for dep in node.read_writes.reads: - producer = self.name_to_fused_node.get(dep.name) - if producer is not None and producer.is_template(): - return False + for node in (node1, node2): + for dep in node.read_writes.reads: + producer = self.name_to_fused_node.get(dep.name) + if ( + producer is not None + and producer.is_template() + and _is_epilogue_fusion_enabled(producer) + ): + return False return True @@ -6478,7 +6834,7 @@ def free_buffers(self) -> None: inp = V.graph.graph_inputs[name] if isinstance(inp, ir.TorchBindObject): V.graph.wrapper_code.codegen_free(inp) - elif isinstance(inp, ir.GeneratorState): + elif isinstance(inp, (ir.GeneratorState, ir.OpaqueObjectState)): continue else: storage = inp.data @@ -7417,6 +7773,10 @@ def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: if self.default_device_context and config.triton.autotune_at_compile_time: V.graph.wrapper_code.write_get_raw_stream_header() + # Register non-mutated inputs that need alignment checks. + # Deferred to just before the first kernel that reads each input. + V.graph.wrapper_code.register_alignment_check_inputs() + for node in nodes: if log.isEnabledFor(logging.DEBUG): try: @@ -7433,6 +7793,12 @@ def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: self.enter_context(node) + # pyrefly: ignore [unbound-name] + if config.size_asserts: + V.graph.wrapper_code.codegen_deferred_input_asserts( + dep.name for dep in node.read_writes.reads + ) + if device := node.get_device(): if ( device != self.current_device @@ -7460,13 +7826,27 @@ def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: max(unique_streams) + 1 if unique_streams else 1 ) V.graph.wrapper_code.codegen_device_guard_enter( - device.index, num_streams + device.index, + num_streams, + self.stream_idx_to_user_obj_idx, ) - # Handle stream context switching for multi-stream scheduling - # Only do this for nodes with a device, inside the device guard - if self._has_multi_stream_nodes(): - self.generate_stream_ctx_switching(node) + # Handle stream context switching for multi-stream scheduling. + # This runs for all nodes (including device-less sync ops like + # record_event/wait_event) so they are placed inside the correct + # stream context. Only switch when inside a device guard (i.e. + # current_device is set), since stream variables are declared there. + if self._has_multi_stream_nodes() and self.current_device is not None: + self.generate_stream_ctx_switching(node) + + # Emit deferred alignment copies for inputs first used by this + # node. This runs *after* stream context switching so the copy + # executes on the same stream as the consuming kernel. + # TODO: inputs read on multiple streams should be copied in the + # prologue instead, to avoid cross-stream races. + V.graph.wrapper_code.codegen_deferred_alignment_copies( + dep.name for dep in node.read_writes.reads + ) self.current_node = node self.buffer_names_to_free.update(node.last_usage) @@ -7487,8 +7867,12 @@ def _codegen(self, nodes: list[BaseSchedulerNode]) -> None: backend_ = self.get_backend(device) from .codegen.cuda_combined_scheduling import CUDACombinedScheduling from .codegen.simd import SIMDScheduling + from .codegen.xpu.xpu_combined_scheduling import XPUCombinedScheduling - if isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)): + if isinstance( + backend_, + (SIMDScheduling, CUDACombinedScheduling, XPUCombinedScheduling), + ): backend = backend_ else: raise AssertionError(f"{type(self)=}") @@ -7757,6 +8141,14 @@ def can_fuse_multi_outputs_template( return False if not template_buf.is_multi_outputs_template(): return False + + if isinstance(node2.node, ir.MultiOutput): + return ( + len(node2.node.inputs) == 1 + and isinstance(node2.node.inputs[0], ir.IRNode) + and node2.node.inputs[0].get_name() == template_buf.get_name() + ) + return False def fuse( diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index a59bfeee94aa3..3bb5e4b9f1163 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -27,7 +27,7 @@ import sympy import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.testing import rand_strided from torch._dynamo.utils import ( @@ -365,6 +365,11 @@ class SubgraphInfo: range_tree_nodes: dict[sympy.Symbol, "IterationRangesEntry"] | None = None numels: dict[str, sympy.Expr] | None = None + # Mapping from original range-tree root variable names (e.g. "xindex") + # to renamed prologue variables (e.g. "_prologue_x_xindex"). Used by + # prologue hooks to apply text-level renames in a structured way. + root_var_renames: dict[str, str] = dataclasses.field(default_factory=dict) + def __post_init__(self): self.only_copy_if_non_none_fields = ( "range_trees", @@ -507,6 +512,7 @@ def __init__( hint_override: int | None = None, triton_meta: dict[str, object] | None = None, always_freeze_layout: bool = False, + index_dtype_override: str | None = None, ) -> None: if tma_store: pass @@ -577,6 +583,7 @@ def __init__( self.epilogue_fn = epilogue_fn self.render_hooks = {} # type: ignore[var-annotated] self.triton_meta: dict[str, object] | None = triton_meta + self._index_dtype_override = index_dtype_override # For Templated Attention this can be a list of ir.Subgraph self.subgraphs: list[ir.ComputedBuffer] | None = subgraphs @@ -611,6 +618,7 @@ def __init__( self.template_mask: str | None = None self.template_out_shape: str | tuple[str] | None = None self.ops_handler: V.WrapperHandler | None = None # type: ignore[name-defined] + self.root_var_renames: dict[str, str] = {} # When caching is enabled, the generated code is not dependent on the input nodes names, or # symbolic sizes names. @@ -640,6 +648,12 @@ def __init__( # Tracking for intermediate variables self.tmp_var_ctr = itertools.count() + @property + def index_dtype(self) -> str: + if self._index_dtype_override is not None: + return self._index_dtype_override + return super().index_dtype + def _gen_tmp_var(self) -> str: return f"_tmp_var{next(self.tmp_var_ctr)}" @@ -719,6 +733,28 @@ def create_subgraph_body(self, body_name: str, clear_cse: bool = False): with self.set_subgraph_body(body_name): yield + def _make_independent_subgraph(self, subgraph_name, numel, **extra_fields): + """Create a subgraph with fresh independent range trees. + + Used by external template backends for epilogue/prologue hooks + that need their own range tree state. + """ + groups = {"x": V.graph.sizevars.simplify(numel), "r0_": sympy.S.One} + self.subgraph_bodies[subgraph_name] = SubgraphInfo( + body=IndentedBuffer(), + cse=self.cse.clone(), + range_trees=self.construct_range_trees( + pid_cache=None, + inside_reduction=False, + is_reduction=False, + numels=groups, + no_x_dim=False, + ), + range_tree_nodes={}, + numels=groups, + **extra_fields, + ) + def _setup_contiguous_index_state( self, indices: list[str], @@ -963,6 +999,8 @@ def size(self, name: str | None, index: int): """ Hook called from template code to get the size of an arg. Will add needed args to pass it in if it is dynamic. + Automatically wraps with tl.full([], ..., dtype=INDEX_DTYPE) when + int64 indexing is needed to prevent overflow in size arithmetic. """ assert isinstance(index, int) if name is None: @@ -970,7 +1008,10 @@ def size(self, name: str | None, index: int): else: assert isinstance(name, str) val = self.named_input_nodes[name].get_size()[index] - return texpr(self.rename_indexing(val)) + result = texpr(self.rename_indexing(val)) + if self.index_dtype == "tl.int64": + return f"tl.full([], {result}, dtype=INDEX_DTYPE)" + return result def stride(self, name, index=None): """ @@ -1590,6 +1631,7 @@ def indexing( override_mask=None, block_ptr=False, tma_compatibility_checker: TMACompatibilityChecker | None = None, + mask_constant_index=False, ): """ Override the default indexing to use our custom mask and force @@ -1604,6 +1646,7 @@ def indexing( override_mask=self.template_mask, block_ptr=block_ptr, tma_compatibility_checker=tma_compatibility_checker, + mask_constant_index=mask_constant_index, ) def codegen_range_tree(self): @@ -1656,9 +1699,14 @@ def call_kernel( inductor_meta=inductor_meta, triton=True, ) + self._emit_post_kernel_code(wrapper, name) if self.workspace_arg is not None: wrapper.generate_workspace_deallocation(self.workspace_arg) + def _emit_post_kernel_code(self, wrapper, kernel_name: str) -> None: + """Hook for subclasses to emit code after kernel call, before workspace dealloc.""" + pass # noqa: PIE790 + def kernel_benchmark_extra_args(self) -> list[str]: # Grid args are only used for benchmarking, not correctness return [ @@ -1679,12 +1727,19 @@ def get_stride_and_maybe_freeze_layout(self, node) -> list[int]: Scheduler falls back to aten if layout constraint violated. If no aten, freeze right away. """ + # realizing for safety ir.ExternKernel.realize_input(node) layout = node.data.layout node_name = node.get_name() - if isinstance(layout, ir.FlexibleLayout): + # For ReinterpretView, the view's strides are already determined by its layout. + # We skip constraint tracking because node.get_name() returns the underlying + # buffer name, not the view's identity, so constraints would be incorrectly + # associated with the underlying buffer rather than the view. + if isinstance(layout, ir.FlexibleLayout) and not isinstance( + node, ir.ReinterpretView + ): if not use_aten_gemm_kernels() or self.always_freeze_layout: # No ExternKernel fallback available, or always_freeze_layout is set # (e.g., for FlexAttention templates), freeze immediately @@ -1753,6 +1808,8 @@ def codegen_template_body( buf_name_to_prologue_group, prologue_preserves_zero_mask_fn ) + partial_code = self._finalize_partial_render(partial_code) + # Template hooks must be finalised after kernel.remove_kernel_local_buffers # is called (this is called when the kernel context is exited above), and when # the kernel handler is set (as below). This is because the hooks may add @@ -1761,28 +1818,25 @@ def codegen_template_body( # finalize must be called after adding epilogue above with V.set_kernel_handler(self): - if not isinstance(partial_code, str): + if isinstance(partial_code, str): + src_code = partial_code + else: # This is used to calculate flops in TritonTemplateKernels with ir.IRNode.current_origins(template_node.node.origins): partial_code.finalize_hook("") partial_code.finalize_hook("", strict=False) - # TODO: Maybe unify CUTLASSTemplateKernel to also use PartialRender for flexible epilogue fusion. - - for input_name in self.named_input_nodes: - subgraph_name = f"" + # TODO: Maybe unify CUTLASSTemplateKernel to also use PartialRender for flexible epilogue fusion. - partial_code.finalize_hook(subgraph_name, strict=False) - - num_store_subgraphs = self.get_store_output_count() - for i in range(num_store_subgraphs): - subgraph_name = self._get_store_output_subgraph_name(i) + for input_name in self.named_input_nodes: + subgraph_name = f"" + partial_code.finalize_hook(subgraph_name, strict=False) - partial_code.finalize_hook(subgraph_name) + num_store_subgraphs = self.get_store_output_count() + for i in range(num_store_subgraphs): + subgraph_name = self._get_store_output_subgraph_name(i) + partial_code.finalize_hook(subgraph_name) - if isinstance(partial_code, str): - src_code = partial_code - else: # Ensure all hooks are finalized before the kernel is defined. # Note: some of these hooks may have been registered by a kernel subclass src_code = partial_code.finalize_remaining() @@ -1819,6 +1873,447 @@ def codegen_prologues_in_subgraphs( ) self.cse.invalidate(OrderedSet()) + def _finalize_partial_render( + self, partial_code: str | PartialRender + ) -> str | PartialRender: + """Hook to intercept or replace the PartialRender before hook finalization. + + Called after Inductor has populated subgraph buffers (epilogue stores, + prologue loads) but before ``finalize_hook`` resolves placeholders. + Subclasses may return a replacement ``PartialRender`` — e.g. to capture + the rendered hook outputs and supply entirely new source code (as + ``ExternalTritonTemplateKernel`` does for fusion-aware autotuning). + """ + return partial_code + + +class ExternalTritonTemplateKernel(TritonTemplateKernel): + """TritonTemplateKernel variant for external template backends (e.g. Helion). + + Orchestrates prologue/epilogue fusion by running Inductor's codegen ops + (store, load, indexing) on fused scheduler nodes, capturing the generated + code into subgraph buffers. The backend's ``render()`` callable handles + hook setup and AST generation, while the standard ``codegen_template_body`` + handles hook finalization and subgraph codegen. + + Also implements ``call_kernel`` and ``emit_kernel_override`` for + post-codegen emission of the fused kernel. + + Subclasses TritonTemplateKernel for subgraph body management and + template indexing support. + """ + + def __init__(self, template_buffer: "ir.TemplateBuffer") -> None: + class _RealOutputNode: + def get_size(self) -> list: + return list(template_buffer.get_size()) + + def get_layout(self): + return template_buffer.get_layout() + + def get_name(self) -> str: + return template_buffer.get_name() + + # Pass dummy values for TritonTemplateKernel params that are only + # relevant for standalone Triton kernel codegen (grid, warps, etc.). + super().__init__( + kernel_name="", + input_nodes=(), + output_node=_RealOutputNode(), + defines={}, + num_stages=0, + num_warps=1, + grid_fn=None, + meta={}, + call_sizes=[], + hint_override=None, + ) + self._template_buffer = template_buffer + # Extra inputs needed by fused ops beyond the template's own I/O + self._extra_inputs: dict[str, str] = {} + # Prologue primary source buffers, populated by load_input + self._prologue_source_buffers: dict[str, str | None] = {} + # Simplified epilogue interface: {output_param: epilogue_idx} + self._epilogue_idx_by_param: dict[str, int] = {} + # Output params that must keep their original tl.store + self._epilogue_keep_store: OrderedSet[str] = OrderedSet() + # Store target buffers: {buf_name: param_name} + self._extra_store_targets: dict[str, str] = {} + # Prologue variable names per input param + self._prologue_vars: dict[str, dict[str, str]] = {} + # Import lines for emit_kernel_override, populated by external render + self._kernel_imports: list[str] = [] + # Call emission state, populated by _setup_fusion_hooks / external render + self._call_preamble: list[str] = [] + self._call_args: list[str] = [] + # Epilogues that could not be fused into the kernel + self._unfused_epilogues: list[Any] = [] + # Reference to the scheduler, set by _compute_fusion_metadata; + # used in call_kernel() to codegen unfused epilogue nodes + self._scheduling_ref: Any = None + + def _finalize_partial_render( + self, partial_code: str | PartialRender + ) -> str | PartialRender: + # Capture hook outputs. + hook_outputs: dict[str, str] = {} + with V.set_kernel_handler(self): + for key, hook in self.render_hooks.items(): + hook_outputs[key] = hook() + + # Delegate to template buffer. + result = self._template_buffer._finalize_codegen(hook_outputs) + if result is None: + return partial_code + + # Apply structured result to kernel state. + self._kernel_imports = result.imports + self._call_preamble = result.call_preamble + self._call_args = result.call_args + + # Freeze hooks and clear subgraph buffers. + if hook_outputs: + self.render_hooks = { + k: (lambda captured=v: captured) for k, v in hook_outputs.items() + } + for info in self.subgraph_bodies.values(): + info.loads = IndentedBuffer() + info.compute = IndentedBuffer() + info.stores = IndentedBuffer() + info.indexing_code = IndentedBuffer() + + return PartialRender(result.source, self.render_hooks) + + def get_unfused_epilogues(self) -> list[Any]: + return self._unfused_epilogues + + def _compute_fusion_metadata( + self, scheduling, epilogue_nodes, prologue_nodes, buf_name_to_prologue_group + ): + """Compute fusion metadata for external backends. + + Determines eligible epilogues/prologues, builds epilogue specs, + and computes prologue sources — all before render(). + + Hook setup (_setup_epilogue_hook / _setup_prologue_hook) cannot + happen here because it requires V.kernel context, which is only + active during codegen_template_body → render(). + """ + self._scheduling_ref = scheduling + from torch._inductor.dependencies import MemoryDep + + tb = self._template_buffer + self._eligible_epilogues = self._find_eligible_epilogues( + epilogue_nodes, tb.epilogue_fusable_outputs + ) + self._epilogue_nodes_by_subgraph = defaultdict( + list, + {i: [sn] for i, (sn, _, _, _) in enumerate(self._eligible_epilogues)}, + ) + fused_ids = OrderedSet(id(sn) for sn, _, _, _ in self._eligible_epilogues) + self._unfused_epilogues = [ + n + for n in epilogue_nodes + if id(n) not in fused_ids and not isinstance(n.node, ir.MultiOutput) + ] + self._prologue_sources = { + buf_name: frozenset( + d.name for d in pro_node.read_writes.reads if isinstance(d, MemoryDep) + ) + for buf_name, pro_nodes in buf_name_to_prologue_group.items() + for pro_node in pro_nodes + } + + # Build simplified epilogue interface: _epilogue_idx_by_param, + # _epilogue_keep_store, and _extra_store_targets. + from torch._inductor.codegen.common import RemovedArg + + scheduler = V.graph.scheduler + epilogues = self._eligible_epilogues + + # Compute fused node names for buffer removability + fused_node_names = None + if scheduler is not None: + all_store_names = OrderedSet([tb.get_name()]) + all_store_names.update(tb._multi_output_children) + all_store_names.update(st for _, _, _, st in epilogues if st) + fused_node_names = OrderedSet( + scheduler.name_to_buf[n].defining_op_name() + for n in all_store_names + if n in scheduler.name_to_buf + ) + + # Pre-register store_target buffers so we know their param names + for _, _, _, store_target in epilogues: + if ( + store_target is not None + and store_target not in self.args.output_buffers + ): + self.args.output(store_target) + + # Build per-epilogue metadata + for i, (_, output_buf, output_param, store_target) in enumerate(epilogues): + self._epilogue_idx_by_param[output_param] = i + + can_remove = ( + store_target is not None + and fused_node_names is not None + and scheduler.can_buffer_be_removed_through_fusion( + output_buf, fused_node_names + ) + ) + + store_target_param_raw = ( + self.args.output_buffers.get(store_target) + if store_target is not None + else None + ) + store_target_param = ( + None + if isinstance(store_target_param_raw, RemovedArg) + else store_target_param_raw + ) + + if store_target_param is not None: + self._extra_store_targets[store_target] = store_target_param + if can_remove: + self.removed_buffers.add(output_buf) + else: + self._epilogue_keep_store.add(output_param) + + def _setup_fusion_hooks(self): + """Set up epilogue/prologue render hooks and mark prologue buffers. + + Must be called during render() (inside V.kernel context), after + _compute_fusion_metadata has run. + """ + tb = self._template_buffer + + # Mark prologue buffers on the kernel. + for pro_buf, source_bufs in self._prologue_sources.items(): + self.store_buffer_names.add(pro_buf) + if not source_bufs: + self.removed_buffers.add(pro_buf) + + # Set up epilogue hooks. + epilogues = self._eligible_epilogues + for epilogue_idx in range(len(tb.epilogue_fusable_outputs)): + epi = epilogues[epilogue_idx] if epilogue_idx < len(epilogues) else None + self._setup_epilogue_hook( + output_buf=epi[1] if epi else None, + output_param=epi[2] if epi else None, + ) + + # Set up prologue hooks. + for param_name in tb._named_inputs: + self._setup_prologue_hook( + param_name, prologue_sources=self._prologue_sources + ) + + # Register no-op hook (standard path requires it). + self.render_hooks[""] = lambda: "" + + def _find_eligible_epilogues(self, epilogue_nodes, output_param_mapping): + """Compute fusion eligibility and register extra inputs. + + Returns list of eligible epilogue tuples: + [(snode, output_buf, output_param, store_target), ...] + """ + from torch._inductor.dependencies import MemoryDep + + # Filter eligible epilogues + epilogues = [] + for epilogue_node in epilogue_nodes: + if isinstance(epilogue_node.node, ir.MultiOutput): + continue + dep_names = OrderedSet( + d.name + for d in epilogue_node.read_writes.reads + if isinstance(d, MemoryDep) and d.name in output_param_mapping + ) + if len(dep_names) != 1: + continue + output_buf = next(iter(dep_names)) + epilogue_writes = epilogue_node.read_writes.writes + raw_st = next(iter(epilogue_writes)).name if epilogue_writes else None + epilogues.append( + ( + epilogue_node, + output_buf, + output_param_mapping[output_buf], + raw_st if raw_st != output_buf else None, + ) + ) + + # Register extra inputs needed by fused epilogues + for snode, _, _, _ in epilogues: + for dep in snode.read_writes.reads: + if isinstance(dep, MemoryDep) and dep.name not in output_param_mapping: + if dep.name not in self._extra_inputs: + param = f"_extra_input_{len(self._extra_inputs)}" + self._extra_inputs[dep.name] = param + self.args.input_buffers[dep.name] = param + + return epilogues + + def _setup_epilogue_hook(self, output_buf=None, output_param=None): + store_idx = next(self.store_output_ctr) + subgraph_name = self._get_store_output_subgraph_name(store_idx) + if output_buf is None: + self.subgraph_bodies[subgraph_name] = SubgraphInfo(body=IndentedBuffer()) + self.render_hooks[subgraph_name] = lambda: "" + return + + n_dims = len(self._template_buffer.get_size()) + indices = [f"x_epilogue{store_idx}_{d}" for d in range(n_dims)] + val = f"_kernel_val_{store_idx}" + mask = f"_tile_mask_{store_idx}" + + buf = output_buf + node = V.graph.get_buffer(buf) if buf else None + output_size = ( + list(node.get_size()) + if node is not None + else list(self.output_node.get_size()) + ) + self._make_independent_subgraph(subgraph_name, sympy_product(output_size)) + with self.set_subgraph_body(subgraph_name): + indices = list(map(OpOverrides.paren, indices)) + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] + lengths = [V.graph.sizevars.simplify(s) for s in output_size] + assert len(indices) == len(lengths) + self.template_out = val + self._setup_contiguous_index_state( + indices, + index_symbols, + lengths, + mask, + xindex_name=f"x_epilogue{store_idx}_index", + ) + self.template_out_shape = val + + # Set up CSE state for epilogue codegen + block_shape = tuple( + f"{rt.prefix.upper()}BLOCK" + for rt in self.range_trees + if not rt.is_reduction + ) + if not block_shape: + block_shape = ("XBLOCK",) + self.cse.store_cache[buf] = self.cse.namedvar( + val, dtype=torch.float32, shape=block_shape + ) + assert output_param is not None + self.args.output_buffers[buf] = output_param + + self.render_hooks[subgraph_name] = self._make_codegen_hook(subgraph_name) + + def _setup_prologue_hook(self, input_name, prologue_sources=None): + ir_node = self._template_buffer._named_inputs.get(input_name) + if ir_node is None: + return + self.named_input_nodes[input_name] = ir_node + input_buf = ir_node.get_name() + source_bufs = ( + prologue_sources.get(input_buf) if prologue_sources is not None else None + ) + if source_bufs is None: + return + source_buffer = next(iter(source_bufs)) if source_bufs else None + self._prologue_source_buffers[input_name] = source_buffer + if source_buffer is not None: + self.args.input_buffers[source_buffer] = input_name + for src in source_bufs: + if src != source_buffer: + if src not in self._extra_inputs: + self._extra_inputs[src] = f"_extra_input_{len(self._extra_inputs)}" + self.args.input_buffers[src] = self._extra_inputs[src] + + subgraph_name = f"" + result_var = f"_prologue_{input_name}_result" + + # Compute prologue variable names once and store them. + renames = { + "xindex": f"_prologue_{input_name}_xindex", + "xmask": f"_prologue_{input_name}_xmask", + } + self._prologue_vars[input_name] = { + "xindex": renames["xindex"], + "xmask": renames["xmask"], + "result": result_var, + } + + class _CaptureStoreHandler(V.WrapperHandler): # type: ignore[name-defined] + def store(self, name, index, value, mode=None): + V.kernel.store_buffer_names.add(name) + V.kernel.cse.store_cache[name] = value + V.kernel.compute.writeline(f"{result_var} = {value}") + + self._make_independent_subgraph( + subgraph_name, + sympy_product(ir_node.get_size()), + ops_handler=_CaptureStoreHandler, + root_var_renames=renames, + ) + + def hook(_name=subgraph_name, _input=input_name, _self=self): + with _self.set_subgraph_body(_name): + _self.codegen_body() + _self.cse.invalidate(OrderedSet()) + body = _self.body.getvalue() + # Rename range-tree root variables (xindex/xmask) to avoid collisions + # across prologue subgraphs. tmp/x0 names are already unique (shared + # kernel-level counters). Read renames from the subgraph info. + subgraph = _self.subgraph_bodies[_name] + for orig, renamed in subgraph.root_var_renames.items(): + body = re.sub(rf"\b{orig}\b", renamed, body) + return body.rstrip() + + self.render_hooks[f""] = hook + + def call_kernel(self, name, node=None, deallocate_ws=True): + """Emit the kernel call, multi-output unpacking, and unfused epilogues.""" + tb = self._template_buffer + wrapper = V.graph.wrapper_code + for line in self._call_preamble: + wrapper.writeline(line) + output_name = tb.get_name() + wrapper.writeline(f"{output_name} = {name}({', '.join(self._call_args)})") + # Unpack multi-output children from the kernel result + for mo_name, mo in sorted(tb._multi_output_children.items()): + if mo_name not in self.removed_buffers: + idx_str = output_name + for _, idx in mo.indices: + idx_str = f"{idx_str}[{idx}]" + wrapper.writeline(f"{mo_name} = {idx_str}") + # Unfused epilogues are codegen'd separately after the kernel call + for epi_node in self._unfused_epilogues: + self._scheduling_ref.codegen_node(epi_node) + + def emit_kernel_override( + self, + wrapper, + src_code, + kernel_name, + node_schedule, + kernel_path, + get_kernel_metadata, + ): + origins, detailed = get_kernel_metadata(node_schedule, wrapper) + wrapper.header.writeline(f"# kernel path: {kernel_path}\n{origins}\n{detailed}") + for imp in self._kernel_imports: + wrapper.add_import_once(imp) + # Hooks may reference standard Inductor runtime modules + wrapper.add_import_once( + "from torch._inductor.runtime import triton_helpers, triton_heuristics" + ) + wrapper.add_import_once( + "from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math" + ) + wrapper.header.splice(src_code, strip=True) + wrapper.header.writeline("") + return True + @functools.cache def _jinja2_env(): @@ -2031,7 +2526,7 @@ def maybe_append_choice( choices.append(choice) return None except NotImplementedError as e: - log.info( # noqa: G200 + log.info( "Cannot Append Choice: %s. KernelTemplate type is %s", e, type(self), @@ -2126,6 +2621,7 @@ def generate_and_load( "subgraphs": subgraphs, "prologue_loads_all_inputs": self.prologue_loads_all_inputs, "always_freeze_layout": self.always_freeze_layout, + "index_dtype_override": index_dtype, } if HAS_WARP_SPEC: @@ -2360,7 +2856,8 @@ def generate( # type: ignore[override] workspace_zero_fill = False workspace_args = [] if workspace_arg is not None: - workspace_size_bytes = workspace_arg.count + ws_count = V.graph.sizevars.optimization_hint(workspace_arg.count) + workspace_size_bytes = ws_count * get_dtype_size(workspace_arg.dtype) workspace_zero_fill = ( workspace_arg.zero_mode != WorkspaceZeroMode.UNINITIALIZED ) @@ -2472,6 +2969,7 @@ def make_kernel_render(out_node, hint_override: int | None = None): "matrix_instr_nonkdim": kwargs.get("matrix_instr_nonkdim", 0), "waves_per_eu": kwargs.get("waves_per_eu", 0), "kpack": kwargs.get("kpack", 2), + "epilogue_subtile": kwargs.get("EPILOGUE_SUBTILE", 0), **{ k: kwargs[k] for k in AlgorithmSelectorCache.FLEX_ATTENTION_TUNABLE_KEYS @@ -2992,11 +3490,7 @@ def get_num_workers() -> int: if "TORCHINDUCTOR_COMPILE_THREADS" in os.environ: return int(os.environ["TORCHINDUCTOR_COMPILE_THREADS"]) - cpu_count = ( - len(os.sched_getaffinity(0)) - if hasattr(os, "sched_getaffinity") - else os.cpu_count() - ) + cpu_count = torch._utils.cpu_count() assert cpu_count # Divide the number of CPUs by the number of GPUs for distributed workloads @@ -3267,7 +3761,6 @@ def __call__( benchmark_with_cudagraphs: bool = False, # Use CUDA graphs for ExternKernelCaller benchmarking ): from .codegen.cutlass.kernel import CUTLASSTemplateCaller - from .codegen.subgraph import SubgraphChoiceCaller # Run preprocessing functions on choices for preprocessing_fn in self.preprocessing_fns: @@ -3290,7 +3783,7 @@ def __call__( if len(choices) == 0: raise self.create_no_valid_choices(name, "No choices exist for backend.") - log.debug("Max autotune selects from %s choices.", str(len(choices))) + log.debug("Max autotune selects from %s choices.", len(choices)) if len(choices) == 1: if not isinstance(choices[0], CUTLASSTemplateCaller): @@ -3322,9 +3815,6 @@ def __call__( assert not config.benchmark_epilogue_fusion, ( "Benchmarking epilogues will cause gpu contention with pipelined autotuning" ) - assert all(not isinstance(c, SubgraphChoiceCaller) for c in choices), ( - "Pipelined autotuning not compatible yet with subgraph choices" - ) extern_kernels = [ c for c in choices if AlgorithmSelectorCache._is_extern(c) ] @@ -3515,6 +4005,8 @@ def is_fallback(c: ChoiceCaller) -> bool: ) best_choice = min(fallback_choices, key=lambda c: timings[c]) + best_choice = V.choices.override_best_choice(best_choice, timings) + # Test-only: force choosing decomposition (non-fallback) if available if config.test_configs.force_custom_op_decomposition: @@ -3979,7 +4471,7 @@ def wait_on_futures() -> dict[ChoiceCaller, float]: "select_algorithm_num_precompilation_exceptions" ] += 1 exceptions.append((futures[future], e)) - log.exception( # noqa: G202 + log.exception( "Exception %s for benchmark choice %s", e, futures[future], @@ -4409,7 +4901,7 @@ def choice_priority(c: ChoiceCaller) -> int: from triton.runtime.autotuner import OutOfResources if isinstance(e, OutOfResources): - log.warning(e) # noqa: G200 + log.warning(e) timing = float("inf") else: raise e @@ -5020,6 +5512,9 @@ def benchmark_example_value(node, hint_override: int | None = None): if isinstance(node, ir.Layout): node = ir.Buffer(name="fake", layout=node) # triton templates want the base tensor. + # Preserve the original dtype before unwrapping, since dtype views + # (e.g. uint8 -> float4_e2m1fn_x2) would be lost by unwrap_view. + original_dtype = node.get_dtype() if isinstance(node, ir.BaseView): node = node.unwrap_view() @@ -5027,7 +5522,7 @@ def benchmark_example_value(node, hint_override: int | None = None): # stride is large enough. The V.graph.get_allocation_size takes this into account. # So we need call as_strided in the end to 'view' the tensor with the correct # sizes/strides - return AlgorithmSelectorCache.generate_example_value( + result = AlgorithmSelectorCache.generate_example_value( V.graph.sizevars.optimization_hints_with_override( node.get_size(), hint_override=hint_override, @@ -5047,6 +5542,10 @@ def benchmark_example_value(node, hint_override: int | None = None): hint_override=hint_override, ), ) + # Restore dtype if it was lost by unwrapping a dtype view + if result.dtype != original_dtype: + result = result.view(original_dtype) + return result @staticmethod def generate_example_value( diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index edd6b0c14d65e..df4f3632d6b5e 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -2,8 +2,6 @@ import functools import itertools import logging -import sys -from collections import defaultdict from collections.abc import Callable, Iterable, Sequence from typing import Any, cast @@ -11,17 +9,20 @@ from sympy import Expr from torch import SymInt +from torch.fx.experimental._size_hinting import ( + _guarding_hint_or_throw_base, + _maybe_realize_expr, + _optimization_hint_base, +) from torch.fx.experimental.symbolic_shapes import ( free_symbols, free_unbacked_symbols, - GuardOnDataDependentSymNode, - has_free_unbacked_symbols, IterateExprs, ShapeEnv, + SymNode, ) from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import FloorDiv, Mod, ModularIndexing -from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.symbol import symbol_is_type, SymT from torch.utils._sympy.value_ranges import IntInfinity, ValueRanges @@ -220,19 +221,29 @@ def remove_zero_terms(base, divisor): for v in base.free_symbols: if v in var_ranges: - # var smaller than divisor can be removed - # if the rest is guaranteed to be multiple of divisor rest = sympy.Wild("_rest", exclude=[v]) m = base.match(v + rest) if m and v not in m[rest].free_symbols: + # v can be removed if it doesn't affect the FloorDiv. + # rest is always a multiple of gcd(rest, divisor), so + # rest % divisor is also a multiple of that gcd. The + # worst case is rest % divisor == divisor - gcd, so + # adding v is safe when v < gcd. gcd = sympy.gcd(m[rest], divisor) - if gcd == divisor: - if statically_known(v < divisor): - base = m[rest] + if statically_known(v < gcd): + base = m[rest] return base def visit_indexing_div(base, divisor): - return FloorDiv(remove_zero_terms(base, divisor), divisor) + base = remove_zero_terms(base, divisor) + if statically_known(base >= 0) and statically_known(base < divisor): + return sympy.S.Zero + # FloorDiv(ModularIndexing(b, d1, m), d2) = ModularIndexing(b, d1*d2, m//d2) + if isinstance(base, ModularIndexing) and isinstance(divisor, sympy.Integer): + b, d1, m = base.args + if m % divisor == 0: + return ModularIndexing(b, d1 * divisor, FloorDiv(m, divisor)) + return FloorDiv(base, divisor) def visit_modular_indexing(base, divisor, modulus): base = remove_zero_terms(base, divisor) @@ -403,6 +414,54 @@ def statically_known_gt(self, left: Expr, right: Expr | int) -> bool: expr = left > right return self.statically_known_true(expr) + def _is_multiple_of(self, numerator: Expr, denominator: int) -> bool: + """ + Structural divisibility check: returns True only if numerator is + provably a multiple of denominator. Recurses over sympy expression + structure before falling back to statically_known_true. + """ + # Rule 1 — concrete value + if isinstance(numerator, (int, sympy.Integer)): + return int(numerator) % denominator == 0 + + # Rule 2 — product: any factor divisible → product divisible + if isinstance(numerator, sympy.Mul): + for factor in numerator.args: + if self._is_multiple_of(factor, denominator): + return True + # Also check if combined constant factors are divisible + const = 1 + for factor in numerator.args: + if isinstance(factor, (int, sympy.Integer)): + const *= int(factor) + if const != 1 and const % denominator == 0: + return True + + # Rule 3 — sum: all terms divisible → sum divisible + if isinstance(numerator, sympy.Add): + if all(self._is_multiple_of(term, denominator) for term in numerator.args): + return True + + # Rule 4 — FloorDiv(a, b): if a is multiple of b*n + if isinstance(numerator, FloorDiv): + a, b = numerator.args + if isinstance(b, (int, sympy.Integer)): + if self._is_multiple_of(a, int(b) * denominator): + return True + + # Rule 5 — Mod(a, b): Mod(a,b) = a - b*floor(a/b), so if both a and b + # are multiples of n, then Mod(a,b) is too. + if isinstance(numerator, (Mod, sympy.Mod)): + a, b = numerator.args + if self._is_multiple_of(a, denominator) and self._is_multiple_of( + b, denominator + ): + return True + + # Rule 6 — axiom fallback: ask ShapeEnv + expr = sympy.Eq(Mod(numerator, denominator), 0) + return self.statically_known_true(expr) + def statically_known_multiple_of( self, numerator: Expr, denominator: Expr | int ) -> bool: @@ -415,6 +474,10 @@ def statically_known_multiple_of( if len(free_symbols(numerator)) > 20: return False + if isinstance(denominator, (int, sympy.Integer)): + return self._is_multiple_of(numerator, int(denominator)) + + # For symbolic denominators, fall back to direct sympy check expr = sympy.Eq(Mod(numerator, denominator), 0) return self.statically_known_true(expr) # type: ignore[arg-type] @@ -548,6 +611,20 @@ def evaluate_min(self, left: Expr, right: Expr) -> Expr: if right == gcd: return right + # Min/Max fallback: we can prove Min(a, b) <= c when any arg <= c, but + # sympy doesn't simplify this yet. So, evaluate it here. Same for Max. + for lhs, rhs in [(left, right), (right, left)]: + + def le_rhs(a: Expr) -> bool: + return self.guard_or_false(sympy.Le(a, rhs)) + + # Min(Min(a, b), c) ==> Min(a, b) if (a <= c) or (b <= c). + if isinstance(lhs, sympy.Min) and any(le_rhs(a) for a in lhs.args): + return lhs + # Min(Max(a, b), c) ==> Max(a, b) if (a <= c) and (b <= c). + if isinstance(lhs, sympy.Max) and all(le_rhs(a) for a in lhs.args): + return lhs + raise TypeError( f"evaluate_min({left}, {right}) with unbacked symints" ) from None @@ -605,7 +682,8 @@ def replace_backed_symbols_with_hints( expr = self.remove_precomputed_replacements(expr) expr = sympy_subs(expr, self.backed_var_to_val) - expr = expr.expand(identity=True) + if isinstance(expr, Expr): + expr = expr.expand(identity=True) free_symbols = expr.free_symbols if not free_symbols: @@ -666,50 +744,14 @@ def guarding_hint_or_throw(self, expr: Expr | int) -> int: optimization_hint: For cases where fallback/heuristic values are acceptable for unbacked symbols. """ - expr = self.simplify(expr) - if isinstance(expr, sympy.Expr): - expr = expr.expand(identity=True) - - # Replace backed symbols with their hints, leaving unbacked symbols alone. - expr = self.replace_backed_symbols_with_hints(expr) - - if has_free_unbacked_symbols(expr): - raise GuardOnDataDependentSymNode(expr) - - result = self._maybe_realize_expr(expr, None) - assert result is not None, expr - return result - - def _maybe_realize_expr(self, expr: Expr, nan_fallback: int | None) -> int | None: - """ - Handle special sympy values in optimization hints. - - Returns: - - Raises ValueError for complex numbers - - sys.maxsize for positive infinity - - -sys.maxsize for negative infinity - - fallback for NaN - - None if no special handling needed - """ - try: - return int(expr) - except (TypeError, ValueError): - pass - - if isinstance(expr, Expr): - if expr.has(sympy.I): - raise ValueError( - f"_maybe_realize_expr received a complex expression: {expr}. " - "Tensor dimensions cannot be complex numbers." - ) - if expr in (int_oo, sympy.oo): - return sys.maxsize - if expr in (-int_oo, -sympy.oo): - return -sys.maxsize - if nan_fallback is not None and expr is sympy.nan or expr.has(sympy.nan): - return nan_fallback - - return None + if isinstance(expr, SymNode): + raise TypeError( + f"guarding_hint_or_throw expects a sympy Expr or int, not {type(expr)}. " + "Use expr.expr to extract the sympy expression from a SymNode." + ) + return _guarding_hint_or_throw_base( + self.shape_env, expr, self.inv_precomputed_replacements + ) def optimization_hint(self, expr: Expr | int, fallback: int | None = None) -> int: """ @@ -728,74 +770,9 @@ def optimization_hint(self, expr: Expr | int, fallback: int | None = None) -> in - Infinity (int_oo, sympy.oo): returns sys.maxsize. - NaN (sympy.nan): returns the fallback value. """ - # Read config at call time to respect runtime patches (e.g., in tests) - if fallback is None: - fallback = config.unbacked_symint_fallback - assert fallback is not None - - expr = self.simplify(expr) - result = self._maybe_realize_expr(expr, fallback) - if result is not None: - return result - - if isinstance(expr, sympy.Expr): - expr = expr.expand(identity=True) - - original = expr - - expr = self.replace_backed_symbols_with_hints(expr) - - result = self._maybe_realize_expr(expr, fallback) - if result is not None: - return result - - # replace unbacked with optimizations hints if exists. - expr = sympy_subs(expr, self.var_to_hint_override) - - result = self._maybe_realize_expr(expr, fallback) - if result is not None: - return result - - # If unbacked symbols remain, try to substitute them using heuristics - # that maximize consistency with the shape environment. - if has_free_unbacked_symbols(expr): - # Make sure to substitute with the factored version - # e.g. 10*(s0 + u0) instead of 10*s0 + 10*u0 - - # Limit sympy.factor() to expressions with <= 200 free symbols, - # as factoring polynomials with many variables is expensive. - if isinstance(original, sympy.Expr) and len(original.free_symbols) <= 200: - expr = self._sub_unbacked_exprs(sympy.factor(original)) - else: - # TODO optimize _sub_unbacked_exprs - expr = self._sub_unbacked_exprs(original) - - # For multiple expressions that depend on an unbacked symint, - # we want to compute them consistently for a size hint we have chosen. - # So, recursively compute expressions via size hints of contained symbols. - # For example: u1 * u2 - 10 ==> fallback * fallback - 10 - - assert isinstance(expr, Expr), type(expr) - free_symbols = expr.free_symbols - - # Constrain fallback per-symbol based on var_to_range bounds - size_dict = {} - for s in free_symbols: - sym_fallback = fallback - vr = self.shape_env.var_to_range.get(s, None) - if vr is not None: - if isinstance(vr.lower, (int, sympy.Integer)): - sym_fallback = max(sym_fallback, int(vr.lower)) - if isinstance(vr.upper, (int, sympy.Integer)): - sym_fallback = min(sym_fallback, int(vr.upper)) - size_dict[s] = sym_fallback - - final_result = expr.subs(size_dict) - - final_result = self._maybe_realize_expr(final_result, fallback) - assert final_result is not None, final_result - - return final_result + return _optimization_hint_base( + self.shape_env, expr, self.inv_precomputed_replacements, fallback + ) def optimization_hints( self, @@ -843,7 +820,7 @@ def optimization_hint_with_override( Returns: int: A concrete integer hint for the expression. """ - simplified = self._maybe_realize_expr(self.simplify(expr), None) + simplified = _maybe_realize_expr(self.simplify(expr), None) if simplified is not None: return simplified @@ -946,192 +923,6 @@ def _stride_vars( ) return strides - def _get_unbacked_replacements(self) -> dict[Expr, Expr]: - if self.unbacked_replacements is not None: - return self.unbacked_replacements - - class CanonicalExprFinder: - """ - Purpose: - A disjoint-set/union-find data structure that can return the - "canonical" expression for a group of equivalent expressions. - - The canonical expression must come from the input eq_graph. - - The heuristics used to choose a leader determines which - expression becomes the canonical expression. - - Problem: - Given any unbacked expression, we should be able to find a size_hint - for the unbacked expression, that adheres to the ShapeEnv's deferred - runtime assertions. Otherwise, we may generate conflicting size hints. - In other words, even though we know u0 + s0 == u2, we may generate - size hints, such that, size_hint(u0 + s0) != size_hint(u2). - NOTE: At this time, only deferred runtime asserts that are equalities - (i.e. Eq(lhs, rhs)) are considered in this data structure. - - Examples: - - u0 + u1 == 9000, then find_expr(u0 + u1) == find_expr(9000) - - u0 + u1 == s9, then find_expr(u0 + u1) == find_expr(s9) - - u0 + s0 == u10, then find_expr(u0 + s0) == find_expr(u10) - - Inputs: - - equality_graph: An adjacency set of expressions where the edge - connects two expressions that are found equal to each other. The - edges are sourced from ShapeEnv's deferred_runtime_asserts. - - Usage: - - Call union_expr(a, b) to merge a & b into a single set which - shares the same canonical expression. - - Call find_expr(x) to find the canonical expression for x. - """ - - def __init__(self, eq_graph: dict[Expr, OrderedSet[Expr]]): - self.eq_graph = eq_graph - self.expressions = list(eq_graph.keys()) - self.reverse_expressions = { - expr: i for i, expr in enumerate(self.expressions) - } - # Each node is its own leader/parent initially - self.leader = list(range(len(self.expressions))) - # Track size for union-by-size - self.size = [1] * len(self.expressions) - - # Takes each edge from the undirected graph and starts merging them. - self._build_canonical_expr_mapping() - - def _build_canonical_expr_mapping(self): - for expr, edges in self.eq_graph.items(): - for adj in edges: - self.union_expr(expr, adj) - - def union_expr(self, a: Expr, b: Expr): - return self.union( - self.reverse_expressions[a], self.reverse_expressions[b] - ) - - def union(self, a: int, b: int): - rootA = self.find(a) - rootB = self.find(b) - if rootA == rootB: - return False # already connected - leader, other = self.choose_leader(rootA, rootB) - self.leader[other] = leader - self.size[leader] += self.size[other] - return True - - def find_expr(self, expr: Expr): - parent = self.find(self.reverse_expressions[expr]) - return self.expressions[parent] - - def find(self, x: int): - # Path compression - if self.leader[x] != x: - self.leader[x] = self.find(self.leader[x]) - return self.leader[x] - - def choose_leader(self, a: int, b: int): - """ - The leader will become the canonical expression. - Returns a (leader, follower) tuple. - - Here are the heuristics used for choosing a leader: - 1. Backed expression or constants preferred over unbacked expr - 2. Simpler sub-expr when one contains the other - 3. Higher frequency across equalities from deferred runtime assertions - 4. Size of the set - 5. Fallback to sympy.Basic.compare - """ - - def _choose(x: int, y: int) -> bool: - lhs, rhs = self.expressions[x], self.expressions[y] - - # Prefer replacing unbacked exprs with backed expressions/constants. - # Examples: - # u0 + s3 ==> s0 + s1, then leader is s0 + s1 - # u2 ==> 300, then leader is 300 - any_unbacked_lhs = has_free_unbacked_symbols(lhs) - any_unbacked_rhs = has_free_unbacked_symbols(rhs) - if any_unbacked_lhs != any_unbacked_rhs: - return bool(any_unbacked_rhs) - - # Handles cases where LHS contains the RHS. In other words, - # RHS is a sub-expression of LHS. For example: - # s1 * Max(2, u0) ==> Max(2, u0), then leader is Max(2, u0) - if lhs.has(rhs): - return False - elif rhs.has(lhs): - return True - - # Prefer expressions that come up more often. - degrees_lhs = len(self.eq_graph[lhs]) - degrees_rhs = len(self.eq_graph[rhs]) - if degrees_lhs != degrees_rhs: - return degrees_lhs > degrees_rhs - - # Try to apply union-by-size optimization to flatten the - # leader trees. - if self.size[x] != self.size[y]: - return self.size[x] > self.size[y] - - # Fallback to sympy.Basic.compare for a deterministic ordering. - return lhs.compare(rhs) == -1 - - if _choose(a, b): - return a, b - return b, a - - # Build an undirected graph using ShapeEnv's deferred runtime assertions. - self.equality_graph: dict[Expr, OrderedSet[Expr]] = defaultdict(OrderedSet) - for assertions in self.shape_env.deferred_runtime_asserts.values(): - for assertion in assertions: - if not isinstance(assertion.expr, sympy.Equality): - # We're ignoring other relationals for now. If you need to - # account for relationals, then you may need a solver solution. - continue - lhs = sympy.sympify(assertion.expr.lhs) # sympify helps with ints - rhs = sympy.sympify(assertion.expr.rhs) - self.equality_graph[lhs].add(rhs) - self.equality_graph[rhs].add(lhs) - - # Use the undirected graph to create a DSU data structure, so we can - # query for a "canonical" expression. - uf = CanonicalExprFinder(self.equality_graph) - - # Start building the unbacked replacements mapping using CanonicalExprFinder - # The mapping is from Expr to its "canonical" Expr. - self.unbacked_replacements = {} - for expr in self.equality_graph: - canonical_expr = uf.find_expr(expr) - if expr != canonical_expr: - self.unbacked_replacements[expr] = canonical_expr - - return self.unbacked_replacements - - @functools.lru_cache # noqa: B019 - def _sub_unbacked_exprs(self, expr: Expr) -> Expr: - # it's fine to cache this fn since self is a singleton - replacements = self._get_unbacked_replacements() - - # consider making this threshold configurable - sub_cnt_limit = 30 - sub_cnt = 0 - while sub_cnt < sub_cnt_limit: - new_expr = expr.subs(replacements) - if new_expr == expr: - break - # Skip sympy.factor() for expressions with many free symbols, - # as factoring polynomials with many variables is expensive. - if len(new_expr.free_symbols) <= 200: - expr = sympy.factor(new_expr) - else: - expr = new_expr - sub_cnt += 1 - else: - log.warning("Substitution limit (%d) reached w/ %s", sub_cnt_limit, expr) - - expr = sympy_subs(expr, self.backed_var_to_val) - expr = sympy_subs(expr, self.var_to_hint_override) - return expr - def offset_var(self, index: Expr, vars: Sequence[sympy.Symbol]) -> Expr: """Extract offset part of an indexing expression""" index = self.simplify(index) @@ -1223,7 +1014,7 @@ def _check_args(x, div, mod, is_first): if not _check_args(x2, div2, mod2, False): return index - if mod2 % mod != 0: + if Mod(mod2, mod) != 0: return index return ModularIndexing(x2, 1, mod) diff --git a/torch/_inductor/standalone_compile.py b/torch/_inductor/standalone_compile.py index 1a195a85ef977..6324e68690d5f 100644 --- a/torch/_inductor/standalone_compile.py +++ b/torch/_inductor/standalone_compile.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import copy import logging import os @@ -9,6 +10,9 @@ from contextlib import AbstractContextManager, nullcontext from typing import Any, Literal, TYPE_CHECKING + +DynamicShapesType = Literal["from_example_inputs", "from_tracing_context", "from_graph"] + import torch.fx from torch._dynamo.aot_compile_types import BundledAOTAutogradSerializableCallable from torch._dynamo.utils import dynamo_timed @@ -371,35 +375,24 @@ def load( return AOTCompiledArtifact.deserialize(artifact) -def standalone_compile( - gm: GraphModule, - example_inputs: Sequence[InputType], - *, - dynamic_shapes: Any, - options: Any, - aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache -) -> CompiledArtifact: - """ - Implementation of torch.inductor.standalone_compile - """ - from torch.compiler._cache import CacheArtifactManager +def _resolve_ignore_shape_env(dynamic_shapes: DynamicShapesType): + # tells compile_fx to ignore the shape_envs on the ambient context + # and the graph_module. + return dynamic_shapes == "from_example_inputs" - from .compile_fx import compile_fx - ignore_shape_env = False +def _resolve_fake_mode( + gm: GraphModule, dynamic_shapes: DynamicShapesType +) -> FakeTensorMode: if dynamic_shapes == "from_example_inputs": - fake_mode = FakeTensorMode(shape_env=ShapeEnv()) - # tells compile_fx to ignore the shape_envs on the ambient context - # and the graph_module. - ignore_shape_env = True + return FakeTensorMode(shape_env=ShapeEnv()) elif dynamic_shapes == "from_tracing_context": # Reuse fake_mode from the TracingContext. # NB: The TracingContext only exists if we're currently in a torch.compile backend. context = torch._guards.TracingContext.get() assert context.fake_mode is not None - fake_mode = context.fake_mode + return context.fake_mode elif dynamic_shapes == "from_graph": - fake_mode = FakeTensorMode(shape_env=ShapeEnv()) # Strategy: find a FakeTensor in the graph output, grab its FakeTensorMode. # The graph passed to standalone_compile must be an Inductor-approved graph, # which means that there is at least one Tensor output and the output node @@ -408,39 +401,64 @@ def standalone_compile( assert last_node.op == "output" assert len(last_node.args) == 1 - def handle_node(node: torch.fx.Node) -> None: - nonlocal fake_mode - if "example_value" in node.meta: - maybe_tensor = node.meta["example_value"] - if isinstance(maybe_tensor, torch._subclasses.fake_tensor.FakeTensor): - fake_mode = maybe_tensor.fake_mode - # If gm came from Dynamo, then last_node.args[0] is always a list, # even in single-Tensor returns. # # It's possible to get into a situation where last_node.args[0] # is a Node (and not a list!). This happens if you call split_module # on the graph. We allow for this case since it is common. - if isinstance(last_node.args[0], torch.fx.Node): - handle_node(last_node.args[0]) - else: - for node in last_node.args[0]: - handle_node(node) + nodes = ( + [last_node.args[0]] + if isinstance(last_node.args[0], torch.fx.Node) + else last_node.args[0] + ) + for node in nodes: + if "example_value" in node.meta: + maybe_tensor = node.meta["example_value"] + if isinstance(maybe_tensor, torch._subclasses.fake_tensor.FakeTensor): + return maybe_tensor.fake_mode + return FakeTensorMode(shape_env=ShapeEnv()) else: raise ValueError( f"standalone_compile got unsupported `dynamic_shapes` value: dynamic_shapes={dynamic_shapes}." ) - context = torch._guards.TracingContext(fake_mode) + +@contextlib.contextmanager +def _standalone_context(gm: GraphModule, dynamic_shapes: DynamicShapesType, aot: bool): + from torch.compiler._cache import CacheArtifactManager + + fake_mode = _resolve_fake_mode(gm, dynamic_shapes) + tracing_context = torch._guards.TracingContext(fake_mode) with ( - torch._guards.tracing(context), + torch._guards.tracing(tracing_context), CacheArtifactManager.with_fresh_cache(), config.patch("triton.autotune_at_compile_time", True), torch._functorch.config.patch("bundled_autograd_cache", aot), ): - # compile_fx can mutate gm - gm = copy.deepcopy(gm) + yield + + +def standalone_compile( + gm: GraphModule, + example_inputs: Sequence[InputType], + *, + dynamic_shapes: DynamicShapesType, + options: Any, + aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache + donate_graph_module: bool = False, +) -> CompiledArtifact: + """ + Implementation of torch.inductor.standalone_compile + """ + from .compile_fx import compile_fx + + ignore_shape_env = _resolve_ignore_shape_env(dynamic_shapes) + with _standalone_context(gm, dynamic_shapes, aot): + # compile_fx takes ownership of gm and may mutate it on cache miss. + if not donate_graph_module: + gm = copy.deepcopy(gm) compiled_fn = compile_fx( gm, example_inputs, ignore_shape_env=ignore_shape_env, **options ) @@ -459,3 +477,20 @@ def handle_node(node: torch.fx.Node) -> None: ) return CacheCompiledArtifact(compiled_fn, artifacts) + + +def autograd_cache_key( + graph, + example_inputs, + dynamic_shapes: DynamicShapesType, + aot: bool = False, # AOT mode, which uses BundledAOTAutogradCache +): + from . import compile_fx + + ignore_shape_env = _resolve_ignore_shape_env(dynamic_shapes) + with _standalone_context(graph, dynamic_shapes, aot): + return compile_fx.autograd_cache_key( + graph, + example_inputs, + ignore_shape_env=ignore_shape_env, + ) diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index c46954752ba82..b8938e8c778af 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -5,7 +5,7 @@ from collections.abc import Callable, Generator from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, TypeVar +from typing import Any, cast, TypeVar from typing_extensions import ParamSpec import torch @@ -14,7 +14,7 @@ from . import ir from .exc import SubgraphLoweringException from .graph import GraphLowering -from .ops_handler import SimpleCSEHandler +from .ops_handler import OpsHandler, SimpleCSEHandler from .virtualized import ops, V, WrapperHandler @@ -137,7 +137,7 @@ class InputDescriptor: class TracingOpsHandler(WrapperHandler): def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None: parent = tracer.create_proxy("placeholder", "ops", (), {}) - super().__init__(parent) + super().__init__(cast(OpsHandler[Any], parent)) self.tracer = tracer self.placeholders = [ diff --git a/torch/_inductor/template_heuristics/aten.py b/torch/_inductor/template_heuristics/aten.py index 1781796c44cac..ad496c36a40e4 100644 --- a/torch/_inductor/template_heuristics/aten.py +++ b/torch/_inductor/template_heuristics/aten.py @@ -86,6 +86,9 @@ def _get_template_configs_impl( nodes = kernel_inputs.nodes() # for addmm, bias is the first input bias = nodes[0] - # Conditions should be checked in tuned_addmm before adding this template - assert bias.get_stride()[0] == 0 and inductor_config.triton.autotune_cublasLt + assert ( + len(bias.get_size()) == 2 + and bias.get_stride()[0] == 0 + and inductor_config.triton.autotune_cublasLt + ) yield from super()._get_template_configs_impl(kernel_inputs, op_name) diff --git a/torch/_inductor/template_heuristics/nv_universal_gemm.py b/torch/_inductor/template_heuristics/nv_universal_gemm.py index 8ae0415bccc01..adbdfae241df7 100644 --- a/torch/_inductor/template_heuristics/nv_universal_gemm.py +++ b/torch/_inductor/template_heuristics/nv_universal_gemm.py @@ -23,10 +23,10 @@ autotuning_log = getArtifactLogger(__name__, "autotuning") # Type alias for kernel config key tuple. -# Currently matches on (tile_m, tile_n, tile_k, cluster_m, cluster_n). -# #TODO(nikhilap) When cutlass_api adds support for stages/split_k, extend this tuple and -# update the _make_config_key_* helper functions below. -ConfigKey = tuple[int, int, int, int, int] +# Currently matches on (tile_m, tile_n, cluster_m, cluster_n). +# tile_k excluded because nvMatmulHeuristics and cutlass_api use it to mean different things. +# TODO(nikhilap): Extend config key for stages/split_k https://github.com/pytorch/pytorch/issues/177578 +ConfigKey = tuple[int, int, int, int] @dataclass @@ -48,21 +48,20 @@ class HeuristicConfig: def _make_config_key_from_heuristic(cfg: HeuristicConfig) -> ConfigKey: """Build config key from HeuristicConfig returned by nvMatmulHeuristics.""" - return (cfg.tile_m, cfg.tile_n, cfg.tile_k, cfg.cluster_m, cfg.cluster_n) + return (cfg.tile_m, cfg.tile_n, cfg.cluster_m, cfg.cluster_n) def _make_config_key_from_kernel_design(design) -> ConfigKey | None: """Build config key from cutlass_api kernel metadata.design.""" if ( hasattr(design, "tile_shape") - and len(design.tile_shape) >= 3 + and len(design.tile_shape) >= 2 and hasattr(design, "cluster_shape") and len(design.cluster_shape) >= 2 ): return ( design.tile_shape[0], design.tile_shape[1], - design.tile_shape[2], design.cluster_shape[0], design.cluster_shape[1], ) @@ -74,7 +73,6 @@ def _make_config_key_from_heuristics_kernel(kernel) -> ConfigKey: return ( kernel.cta[0], kernel.cta[1], - kernel.cta[2], kernel.cluster[0], kernel.cluster[1], ) diff --git a/torch/_inductor/template_heuristics/tlx.py b/torch/_inductor/template_heuristics/tlx.py index 83381687cd5de..e75457aceb1cf 100644 --- a/torch/_inductor/template_heuristics/tlx.py +++ b/torch/_inductor/template_heuristics/tlx.py @@ -3,5 +3,3 @@ if config.is_fbcode(): import torch._inductor.fb.tlx_templates.registry # noqa: F401 # type: ignore[import-not-used] - -# TODO. Move the registry to this file once the TLX template is more complete. diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index 89999fc2f3636..31d4a7403efe1 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -13,6 +13,7 @@ import torch from torch._inductor.template_heuristics.triton_addmm import AddMMConfigMixin from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.functions import Mod from torch.utils._triton import has_triton_stable_tma_api from .. import config, config as inductor_config @@ -22,6 +23,7 @@ get_scaling_options, get_tile_size, mm_template, + persistent_mm_template, persistent_tma_mm_template, scaled_mm_device_tma_epilogue_scaling_template, scaled_mm_device_tma_main_loop_scaling_template, @@ -30,6 +32,7 @@ from ..kernel_inputs import KernelInputs, MMKernelInputs from ..utils import ( get_backend_num_stages, + get_default_kpack, get_num_sms, get_tma_workspace_arg, TMA_DESCRIPTOR_SIZE, @@ -176,7 +179,7 @@ class ROCmGemmConfig(GemmConfig): matrix_instr_nonkdim: int = 16 waves_per_eu: int = 0 - kpack: int = 2 + kpack: int = 1 @dataclasses.dataclass @@ -187,7 +190,7 @@ class ROCmConvConfig(ConvConfig): matrix_instr_nonkdim: int = 16 waves_per_eu: int = 0 - kpack: int = 2 + kpack: int = 1 @dataclasses.dataclass @@ -198,7 +201,7 @@ class ROCmFlexConfig(FlexConfig): matrix_instr_nonkdim: int = 0 waves_per_eu: int = 0 - kpack: int = 2 + kpack: int = 1 @dataclasses.dataclass @@ -209,7 +212,7 @@ class ROCmFlexBwDConfig(FlexBwDConfig): matrix_instr_nonkdim: int = 0 waves_per_eu: int = 0 - kpack: int = 2 + kpack: int = 1 @dataclasses.dataclass @@ -220,7 +223,7 @@ class ROCmFlexDecodeConfig(FlexDecodeConfig): matrix_instr_nonkdim: int = 0 waves_per_eu: int = 0 - kpack: int = 2 + kpack: int = 1 class BaseHeuristicSingleton(type): @@ -779,6 +782,12 @@ def __init__(self) -> None: for num_warps in [2, 4, 8] ] + def _get_extra_config_key_and_kwargs( + self, conf: BaseConfig + ) -> tuple[tuple[int | None, ...], dict[str, Any]]: + """Hook for subclasses to extend config dedup key and kwargs.""" + return (), {} + def _finalize_mm_configs( self, configs: list[BaseConfig], @@ -813,16 +822,8 @@ def _finalize_mm_configs( if isinstance(conf, BlackwellGPUGemmConfig): key += (conf.epilogue_subtile, conf.warp_specialize, conf.flatten) - # Add TlxGemmConfig specific fields to key if present - if config.is_fbcode() and config.triton.tlx_mode in ("allow", "force"): - from torch._inductor.fb.tlx_templates.registry import ( - get_tlx_config_key_and_kwargs, - ) - - tlx_key_fields, tlx_kwargs = get_tlx_config_key_and_kwargs(conf) - key += tlx_key_fields - else: - tlx_kwargs = {} + extra_key, extra_kwargs = self._get_extra_config_key_and_kwargs(conf) + key += extra_key if key not in used and ( max_mm_configs is None or len(used) < max_mm_configs @@ -843,8 +844,7 @@ def _finalize_mm_configs( kwargs["WARP_SPECIALIZE"] = conf.warp_specialize kwargs["FLATTEN"] = conf.flatten - # Add TlxGemmConfig specific fields if present - kwargs.update(tlx_kwargs) + kwargs.update(extra_kwargs) yield self.triton_config(conf.num_stages, num_warps, **kwargs) @@ -1282,6 +1282,7 @@ def __init__(self) -> None: self.h100_default_flex_config = { (torch.float32, 64): FlexConfig(128, 32, 3, 4), (torch.float32, 128): FlexConfig(32, 64, 3, 4), + (torch.float32, 192): FlexConfig(32, 64, 1, 8), (torch.float32, 256): FlexConfig(32, 32, 3, 4), (torch.bfloat16, 64): FlexConfig(128, 128, 3, 4), (torch.bfloat16, 128): FlexConfig(128, 64, 3, 8), @@ -1547,23 +1548,26 @@ def __init__(self) -> None: for group_m in [4, 8, 16] for matrix_instr_nonkdim in [0, 16] for waves_per_eu in [0, 2] - for kpack in [2] + for kpack in [1, 2] ] + # Architecture-aware default kpack for flex configs + default_kpack = get_default_kpack() + self.default_flex_config = { - (torch.float32, 64): ROCmFlexConfig(128, 32, 1, 4), - (torch.float32, 128): ROCmFlexConfig(128, 32, 1, 4), - (torch.float32, 256): ROCmFlexConfig(64, 16, 1, 4), - (torch.bfloat16, 64): ROCmFlexConfig(128, 64, 2, 4), - (torch.bfloat16, 128): ROCmFlexConfig(128, 64, 2, 4), - (torch.bfloat16, 256): ROCmFlexConfig(32, 64, 2, 4), - (torch.float16, 64): ROCmFlexConfig(128, 64, 2, 8), - (torch.float16, 128): ROCmFlexConfig(128, 64, 2, 8), - (torch.float16, 256): ROCmFlexConfig(32, 64, 2, 4), + (torch.float32, 64): ROCmFlexConfig(128, 32, 1, 4, kpack=default_kpack), + (torch.float32, 128): ROCmFlexConfig(128, 32, 1, 4, kpack=default_kpack), + (torch.float32, 256): ROCmFlexConfig(64, 16, 1, 4, kpack=default_kpack), + (torch.bfloat16, 64): ROCmFlexConfig(128, 64, 2, 4, kpack=default_kpack), + (torch.bfloat16, 128): ROCmFlexConfig(128, 64, 2, 4, kpack=default_kpack), + (torch.bfloat16, 256): ROCmFlexConfig(32, 64, 2, 4, kpack=default_kpack), + (torch.float16, 64): ROCmFlexConfig(128, 64, 2, 4, kpack=default_kpack), + (torch.float16, 128): ROCmFlexConfig(128, 64, 2, 4, kpack=default_kpack), + (torch.float16, 256): ROCmFlexConfig(32, 64, 2, 4, kpack=default_kpack), } self.flex_attn_fwd_autotune_configs: list[FlexConfig] = [ - ROCmFlexConfig(BLOCK1, BLOCK2, 1, w) + ROCmFlexConfig(BLOCK1, BLOCK2, 1, w, kpack=default_kpack) for BLOCK1 in [16, 64, 128] for BLOCK2 in [16, 32, 64, 128] for w in [4, 8] @@ -1571,7 +1575,9 @@ def __init__(self) -> None: self.flex_attn_bwd_autotune_configs: list[FlexBwDConfig] = [ # See Note: flex bwd configs - ROCmFlexBwDConfig(BLOCK1, BLOCK2, BLOCK2, BLOCK1, 1, w, mfma) + ROCmFlexBwDConfig( + BLOCK1, BLOCK2, BLOCK2, BLOCK1, 1, w, mfma, kpack=default_kpack + ) for BLOCK1 in [16, 32, 64] for BLOCK2 in [32, 64, 128] for w in ([4, 8] if BLOCK1 >= 128 or BLOCK2 >= 128 else [4]) @@ -1580,22 +1586,23 @@ def __init__(self) -> None: ] self.flex_decode_autotune_configs: list[FlexDecodeConfig] = [ - ROCmFlexDecodeConfig(32, 1, 4), - ROCmFlexDecodeConfig(64, 1, 4), - ROCmFlexDecodeConfig(128, 1, 4), - ROCmFlexDecodeConfig(32, 1, 8), - ROCmFlexDecodeConfig(64, 1, 8), - ROCmFlexDecodeConfig(128, 1, 8), + ROCmFlexDecodeConfig(32, 1, 4, kpack=default_kpack), + ROCmFlexDecodeConfig(64, 1, 4, kpack=default_kpack), + ROCmFlexDecodeConfig(128, 1, 4, kpack=default_kpack), + ROCmFlexDecodeConfig(32, 1, 8, kpack=default_kpack), + ROCmFlexDecodeConfig(64, 1, 8, kpack=default_kpack), + ROCmFlexDecodeConfig(128, 1, 8, kpack=default_kpack), ] self.exhaustive_flex_attn_fwd_configs: list[FlexConfig] = [ - ROCmFlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps, mfma, wpeu) + ROCmFlexConfig(BLOCK_M, BLOCK_N, num_stages, num_warps, mfma, wpeu, kpack) for BLOCK_M in [16, 32, 64, 128] for BLOCK_N in [32, 64, 128] for num_stages in [1, 2] for num_warps in [2, 4, 8] for mfma in [0, 16] for wpeu in [0, int(8 // num_warps)] + for kpack in [1, 2] ] self.exhaustive_flex_attn_bwd_configs: list[FlexBwDConfig] = [ @@ -1609,6 +1616,7 @@ def __init__(self) -> None: num_warps, mfma, wpeu, + kpack, ) for BLOCK_M1 in [16, 32, 64, 128] for BLOCK_N1 in [16, 32, 64, 128] @@ -1618,17 +1626,21 @@ def __init__(self) -> None: for num_warps in [2, 4, 8] for mfma in [0, 16] for wpeu in [0, int(8 // num_warps)] + for kpack in [1, 2] if BLOCK_N1 % BLOCK_M1 == 0 and BLOCK_M2 % BLOCK_N2 == 0 # kernel static assertions ] self.exhaustive_flex_decode_configs: list[FlexDecodeConfig] = [ - ROCmFlexDecodeConfig(block_n, num_stages, num_warps, mfma, wpeu, kpack=2) + ROCmFlexDecodeConfig( + block_n, num_stages, num_warps, mfma, wpeu, kpack=kpack + ) for block_n in [16, 32, 64, 128] for num_stages in [1, 2] for num_warps in [2, 4, 8] for mfma in [0, 16] for wpeu in [0, int(8 // num_warps)] + for kpack in [1, 2] ] def _prune_exhaustive_configs( @@ -1641,8 +1653,11 @@ def _prune_exhaustive_configs( c for c in configs if not ( - getattr(c, "matrix_instr_nonkdim", 0) == 2 - and getattr(c, "kpack", 0) == 2 + ( + getattr(c, "matrix_instr_nonkdim", 0) == 2 + and getattr(c, "kpack", 0) == 2 + ) + or (c.block_k <= 16 and getattr(c, "kpack", 0) == 2) ) ] return pruned_configs @@ -1671,9 +1686,11 @@ def _finalize_mm_configs( conf.num_warps = min(conf.num_warps, conf.block_m * conf.block_n // 256) # Defaults for AMD triton backend kern args if not set - matrix_instr_nonkdim = getattr(conf, "matrix_instr_nonkdim", 16) - waves_per_eu = getattr(conf, "waves_per_eu", 0) - kpack = getattr(conf, "kpack", 2) + matrix_instr_nonkdim: int = getattr(conf, "matrix_instr_nonkdim", 16) + waves_per_eu: int = getattr(conf, "waves_per_eu", 0) + # Use explicit kpack if set, otherwise determine optimal value based on + # architecture and BLOCK_K + kpack: int = getattr(conf, "kpack", get_default_kpack(conf.block_k)) if matrix_instr_nonkdim != 0 and ( conf.block_m % matrix_instr_nonkdim != 0 @@ -1731,19 +1748,20 @@ def get_flex_attn_fwd_configs(self, head_dim: int, dtype: Any) -> list[FlexConfi return self.exhaustive_flex_attn_fwd_configs flex_attn_fwd_configs += self.flex_attn_fwd_autotune_configs + default_kpack = get_default_kpack() if head_dim <= 256: if dtype == torch.float32: - default_config = ROCmFlexConfig(64, 64, 1, 4) + default_config = ROCmFlexConfig(64, 64, 1, 4, kpack=default_kpack) else: - default_config = ROCmFlexConfig(128, 64, 2, 4) + default_config = ROCmFlexConfig(128, 64, 2, 4, kpack=default_kpack) default_config = self.default_flex_config.get( (dtype, head_dim), default_config ) else: if dtype == torch.float32: - default_config = ROCmFlexConfig(32, 16, 1, 4) + default_config = ROCmFlexConfig(32, 16, 1, 4, kpack=default_kpack) else: - default_config = ROCmFlexConfig(64, 32, 2, 4) + default_config = ROCmFlexConfig(64, 32, 2, 4, kpack=default_kpack) if default_config not in flex_attn_fwd_configs: flex_attn_fwd_configs.append(default_config) @@ -1760,17 +1778,28 @@ def get_flex_attn_bwd_configs( return self.exhaustive_flex_attn_bwd_configs flex_attn_bwd_configs += self.flex_attn_bwd_autotune_configs + default_kpack = get_default_kpack() if dtype == torch.float32: - default_config = ROCmFlexBwDConfig(16, 16, 16, 16, 1, 4) + default_config = ROCmFlexBwDConfig( + 16, 16, 16, 16, 1, 4, kpack=default_kpack + ) elif head_dim <= 256: if head_dim == 64: - default_config = ROCmFlexBwDConfig(64, 64, 64, 64, 1, 4) + default_config = ROCmFlexBwDConfig( + 64, 64, 64, 64, 1, 4, kpack=default_kpack + ) elif head_dim == 128: - default_config = ROCmFlexBwDConfig(64, 128, 128, 64, 1, 4) + default_config = ROCmFlexBwDConfig( + 64, 128, 128, 64, 1, 4, kpack=default_kpack + ) else: - default_config = ROCmFlexBwDConfig(64, 64, 64, 64, 1, 4) + default_config = ROCmFlexBwDConfig( + 64, 64, 64, 64, 1, 4, kpack=default_kpack + ) else: - default_config = ROCmFlexBwDConfig(16, 16, 16, 16, 1, 4) + default_config = ROCmFlexBwDConfig( + 16, 16, 16, 16, 1, 4, kpack=default_kpack + ) if default_config not in flex_attn_bwd_configs: flex_attn_bwd_configs.append(default_config) @@ -1787,7 +1816,8 @@ def get_flex_decode_configs( return self.exhaustive_flex_decode_configs flex_decode_configs += self.flex_decode_autotune_configs - default_config = ROCmFlexDecodeConfig(64, 1, 4) + default_kpack = get_default_kpack() + default_config = ROCmFlexDecodeConfig(64, 1, 4, kpack=default_kpack) if default_config not in flex_decode_configs: flex_decode_configs.append(default_config) @@ -1802,6 +1832,10 @@ class XPUConfigHeuristic(BaseConfigHeuristic): def __init__(self) -> None: super().__init__() + self.mm_configs = self.mm_configs + [ + GemmConfig(32, 64, 128, 2, 2), + GemmConfig(64, 64, 32, 2, 8), + ] self.xpu_default_flex_config = { (torch.float32, 64): FlexConfig(128, 32, 1, 16), (torch.float32, 128): FlexConfig(128, 32, 1, 16), @@ -2040,7 +2074,7 @@ def _convert_config_to_template_kwargs( Moved from mm_common.mm_options. """ # Calculate EVEN_K symbolic. (It isn't worth guarding on this) - even_k_symbolic = (k % triton_config.kwargs["BLOCK_K"]) == 0 + even_k_symbolic = sympy.Eq(Mod(k, triton_config.kwargs["BLOCK_K"]), 0) even_k_symbolic = V.graph.sizevars.statically_known_true(even_k_symbolic) # Build options dict @@ -2049,6 +2083,7 @@ def _convert_config_to_template_kwargs( EVEN_K=even_k_symbolic, USE_FAST_ACCUM=False, # Option for _scaled_mm ACC_TYPE=self._get_acc_type(out_dtype), + OUT_DTYPE=self._get_out_dtype(out_dtype), num_stages=triton_config.num_stages, num_warps=triton_config.num_warps, **triton_config.kwargs, @@ -2061,6 +2096,11 @@ def _convert_config_to_template_kwargs( return options_dict + @staticmethod + def _dtype_to_triton(dtype: torch.dtype) -> str: + """Convert a torch dtype to a triton type string.""" + return f"tl.{dtype}".replace("torch.", "") + def _get_acc_type(self, dtype: torch.dtype) -> str: """ Get accumulator type for the given dtype. @@ -2068,7 +2108,11 @@ def _get_acc_type(self, dtype: torch.dtype) -> str: """ if dtype in (torch.float16, torch.bfloat16): return "tl.float32" - return f"tl.{dtype}".replace("torch.", "") + return self._dtype_to_triton(dtype) + + def _get_out_dtype(self, dtype: torch.dtype) -> str: + """Get output dtype as a triton type string.""" + return self._dtype_to_triton(dtype) # INT8 specific mixin to filter correctly @@ -2291,7 +2335,10 @@ def adjust_kernel_inputs( if bias: nodes.append(bias) return MMKernelInputs( - nodes, mat1_idx=kernel_inputs._mat1_idx, mat2_idx=kernel_inputs._mat2_idx + nodes, + mat1_idx=kernel_inputs._mat1_idx, + mat2_idx=kernel_inputs._mat2_idx, + out_dtype=kernel_inputs._out_dtype, ) def _get_template_configs_impl( @@ -2533,6 +2580,33 @@ def __init__(self) -> None: self.mm_configs = self.persistent_mm_configs +@register_template_heuristic( + persistent_mm_template.uid, + "cuda", + register=torch.version.hip is not None, +) +class PersistentMMTemplateConfigHeuristic( + MMTemplateConfigMixin, + ROCmConfigHeuristic, # type: ignore[misc] +): + """Persistent MM template heuristic (no TMA, standard pointer loads)""" + + def __init__(self) -> None: + super().__init__() + self.mm_configs = self.persistent_mm_configs + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + **kwargs, + ) -> Generator[dict[str, Any], None, None]: + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, op_name, **kwargs + ): + yield {**template_kwargs, "NUM_SMS": get_num_sms()} + + @register_template_heuristic( blackwell_ws_persistent_device_tma_mm_template.uid, "cuda", @@ -2684,7 +2758,8 @@ def _get_template_configs_impl( @register_template_heuristic( - blackwell_ws_persistent_device_tma_mm_template.uid, # regular Blackwell MM template + scaling epilogue from ScaledMMConfigMixin + # regular Blackwell MM template + scaling epilogue from ScaledMMConfigMixin + blackwell_ws_persistent_device_tma_mm_template.uid, "cuda", register=torch.version.hip is None, op_name="scaled_mm", @@ -2770,6 +2845,18 @@ class ROCmAddMMTemplateConfigHeuristic(AddMMConfigMixin, ROCmMMTemplateConfigHeu """Addmm specific mixin for ROCm""" +@register_template_heuristic( + persistent_mm_template.uid, + "cuda", + register=torch.version.hip is not None, + op_name="addmm", +) +class ROCmAddMMPersistentTemplateConfigHeuristic( + AddMMConfigMixin, PersistentMMTemplateConfigHeuristic +): + """Addmm specific mixin for persistent MM on ROCm""" + + # TODO(coconutruben): deprecate once autoheuristic is deprecated @register_template_heuristic("mm-ah", "cuda", register=torch.version.hip is not None) class ROCmMMAHTemplateConfigHeuristic(MMTemplateConfigMixin, ROCmConfigHeuristic): @@ -2795,11 +2882,10 @@ def __init__(self) -> None: super().__init__() # Override mm_configs to use scaled_mm_configs self.mm_configs = self.scaled_mm_configs - # NOTE: overriding exhaustive configs here to be the same as mm_configs - # as we haven't validated exhaustive support here yet - # TODO(coconutruben): remove this once we have validated exhaustive support - # for scaled_mm - self.exhaustive_configs = self.scaled_mm_configs + + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + configs = [c for c in configs if c.block_k >= 32] + return super()._filter_configs(configs) @register_template_heuristic( diff --git a/torch/_inductor/tiling_utils.py b/torch/_inductor/tiling_utils.py index 5137797e84852..509c5694176cc 100644 --- a/torch/_inductor/tiling_utils.py +++ b/torch/_inductor/tiling_utils.py @@ -786,7 +786,7 @@ def analyze_memory_coalescing( continue # TODO - if a var is in the middle, such as [n0, n1, n2] - # n1 can can be split beyond range + # n1 can be split beyond range MIN_TILING_BLOCK = 8 if not all( diff --git a/torch/_inductor/triton_bundler.py b/torch/_inductor/triton_bundler.py index c2351d1506b18..ed095541d9a2e 100644 --- a/torch/_inductor/triton_bundler.py +++ b/torch/_inductor/triton_bundler.py @@ -9,6 +9,7 @@ from torch._dynamo.utils import counters, dynamo_timed, set_feature_use from torch._utils_internal import justknobs_check from torch.utils._filelock import FileLock +from torch.utils._ordered_set import OrderedSet from .runtime.runtime_utils import triton_cache_dir from .utils import _IS_WINDOWS, GPU_KERNEL_BIN_EXTS @@ -106,6 +107,7 @@ class TritonBundler: _entries: list[TritonBundleEntry] | None = None _static_autotuners: list[StaticallyLaunchedAutotuner] | None = None + _winners: OrderedSet[str] | None = None # __grp__kernel_name.json contains metadata with source code paths # we use this as sentinel value for search and replace @@ -140,6 +142,7 @@ def begin_compile(cls) -> None: assert cls._entries is None cls._entries = [] cls._static_autotuners = [] + cls._winners = OrderedSet() @classmethod def end_compile(cls) -> None: @@ -150,6 +153,7 @@ def end_compile(cls) -> None: log.debug("TritonBundler.end_compile is called") cls._entries = None cls._static_autotuners = None + cls._winners = None @classmethod def put(cls, kernel_hash: str, device: int) -> None: @@ -162,6 +166,17 @@ def put(cls, kernel_hash: str, device: int) -> None: TritonBundleEntry(kernel_hash, device, triton_cache_dir(device)) ) + @classmethod + def put_winner(cls, kernel_hash: str) -> None: + """ + Marks a kernel hash as a winning autotuning config. Only winning + kernels are included in the bundle by collect(). If no winners are + recorded (e.g. single-config kernels that skip autotuning), all + entries are bundled. + """ + if cls._winners is not None: + cls._winners.add(kernel_hash) + @classmethod def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None: # type: ignore[name-defined] # noqa: F821 from torch._inductor import config @@ -262,9 +277,17 @@ def collect( with dynamo_timed(key="TritonBundler.collect", log_pt2_compile_event=True): entries = cls._entries if entries is not None: + # Only bundle winning autotuning configs. If _winners is + # non-empty, skip entries whose kernel_hash is not a winner. + # When _winners is empty (single-config kernels, or no + # autotuning ran), bundle everything. + winners = cls._winners result: list[TritonKernelArtifacts] = [] kernel_names: list[str] = [] for entry in entries: + if winners and entry.kernel_hash not in winners: + log.debug("Skipping non-winning kernel %s", entry.kernel_hash) + continue artifacts: list[TritonKernelArtifact] = [] path = os.path.join(entry.directory, entry.kernel_hash) if not os.path.exists(path): diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 1a11a1572249a..ccf9134cbd686 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -23,6 +23,7 @@ import textwrap import time import unittest +import warnings from collections.abc import ( Callable, Collection, @@ -68,6 +69,7 @@ "inductor_autotune_lookup_table", ] +from torch.fx.experimental._size_hinting import _sympy_subs from torch.fx.experimental.symbolic_shapes import ( free_symbols, free_unbacked_symbols, @@ -420,7 +422,7 @@ def _do_bench_using_profiling( @functools.cache def has_torchvision_roi_align() -> bool: try: - from torchvision.ops import roi_align # noqa: F401 + from torchvision.ops import roi_align torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta") return roi_align is not None and hasattr( @@ -485,6 +487,8 @@ def _type_of(key: torch.dtype | None) -> str: "float8e4b15x4": "fp8e4b15x4", "float8_e4m3fn": "fp8e4nv", "float8_e5m2": "fp8e5", + "float8_e4m3fnuz": "fp8e4b8", + "float8_e5m2fnuz": "fp8e5b16", # TODO: remove when support is added in triton # https://github.com/triton-lang/triton/issues/6054 "float8_e8m0fnu": "u8", @@ -1015,9 +1019,17 @@ def stringfy_layout(layout: ir.Layout | None) -> str: all_writes.append("%" + output_name) for node in inductor_nodes: - detailed_metadata.append( - f"{wrapper.comment} {node.format_node(include_tensor_metadata=True)}" - ) + formatted_node = node.format_node(include_tensor_metadata=True) + if formatted_node is not None and torch.version.hip: + # AMDGCN asm strings can contain newlines, which propagate + # into format_node() output. Split so every line gets the + # comment prefix; otherwise bare newlines break the wrapper. + detailed_metadata.extend( + f"{wrapper.comment} {line}" + for line in formatted_node.splitlines() + ) + else: + detailed_metadata.append(f"{wrapper.comment} {formatted_node}") detailed_metadata.append(f"{wrapper.comment} return {','.join(all_writes)}") @@ -1171,22 +1183,7 @@ def sympy_subs(expr: sympy.Expr, replacements: dict[sympy.Expr, Any]) -> sympy.E When the passed replacement symbol v is a string, it is converted to a symbol with name v that have the same replaced expression integer and nonnegative properties. """ - - def to_symbol(replaced: sympy.Expr, replacement: sympy.Expr | str) -> sympy.Symbol: - assert isinstance(replaced, sympy.Expr) - if isinstance(replacement, str): - return sympy.Symbol( - replacement, - integer=replaced.is_integer, # type: ignore[attr-defined] - nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined] - ) - else: - return replacement - - # xreplace is faster than subs, but is way more picky - return sympy.sympify(expr).xreplace( - {k: to_symbol(k, v) for k, v in replacements.items()} - ) + return _sympy_subs(expr, replacements) def is_symbolic(a: Any) -> TypeGuard[torch.SymInt | torch.Tensor]: @@ -1771,6 +1768,14 @@ def get_tma_workspace_arg( ) +def get_default_kpack(block_k: int = 16) -> int: + if not torch.version.hip: + return 0 + if "gfx942" in torch.cuda.get_device_properties(0).gcnArchName and block_k <= 16: + return 1 + return 2 + + def _use_template_for_gpu( layout: Layout, allowed_layout_dtypes: list[torch.dtype] ) -> bool: @@ -1961,15 +1966,53 @@ def _is_tma_compatible_xpu( ) +def _descriptor_shape_fits_in_int32( + sizes: Sequence[sympy.Expr], add_guards: bool = False +) -> bool: + int32_max = torch.iinfo(torch.int32).max + conditions = [] + for size in sizes: + if isinstance(size, (int, sympy.Integer)): + if size > int32_max: + return False + else: + conditions.append(sympy.Le(size, int32_max)) + + if not conditions: + return True + + from .virtualized import V + + condition = conditions[0] if len(conditions) == 1 else sympy.And(*conditions) + return ( + V.graph.sizevars.guard_or_false(condition) + if add_guards + else V.graph.sizevars.statically_known_true(condition) + ) + + def use_triton_tma_template( *matrices: IRNode, output_layout: Layout, add_guards: bool = False ) -> bool: + if not config.triton.enable_persistent_tma_matmul: + return False + if not all(len(m.get_size()) == 2 for m in matrices): + return False + if not all( + _descriptor_shape_fits_in_int32(m.get_size(), add_guards=add_guards) + for m in matrices + ): + return False + if config.triton.enable_template_tma_store and not _descriptor_shape_fits_in_int32( + output_layout.size, add_guards=add_guards + ): + return False + # On AMD (HIP), TMA is not available but we still use non-TMA persistent + # kernels, so skip the TMA compatibility checks. + if torch.version.hip is not None: + return True layout = output_layout if config.triton.enable_template_tma_store else None - return ( - all(len(m.get_size()) == 2 for m in matrices) - and can_use_tma(*matrices, output_layout=layout, add_guards=add_guards) - and config.triton.enable_persistent_tma_matmul - ) + return can_use_tma(*matrices, output_layout=layout, add_guards=add_guards) def use_triton_blackwell_tma_template( @@ -2142,9 +2185,9 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: # for the compiled CUTLASS .so, similar to how the triton branch uses # static CUfunction + loadKernel for non-AOT mode. if V.graph.cpp_wrapper and not V.graph.aot_mode: - log.warning( + warnings.warn( "CUTLASS backend is not supported with non-AOT cpp_wrapper mode. " - "Skipping CUTLASS backend." + "Skipping CUTLASS backend.", ) return False @@ -2603,7 +2646,7 @@ def run_and_get_kernels( result, source_codes = run_and_get_code(fn, *args, **kwargs) kernels = [] for code in source_codes: - if config.cpp_wrapper and config.triton.autotune_at_compile_time is False: + if config.cpp_wrapper and config.triton.autotune_at_compile_time is not True: # With lazy Triton kernel compilation, kernel sources are embedded # inside C++ R"TRITON(...)TRITON" raw strings. kernels.extend(re.findall(r'R"TRITON\((.*?)\)TRITON"', code, re.DOTALL)) @@ -3200,6 +3243,22 @@ def is_saved_tensor(x: Node) -> bool: return len(static_arg_idxs) +def get_static_bw_input_idxs(fx_g: torch.fx.GraphModule) -> list[int]: + """ + Returns indices of backward graph inputs that are always at fixed + addresses: primals (parameters/buffers/user inputs saved for backward). + Excludes saved activations which may not be at fixed addresses when + the forward is partitioned for CUDA graphs. + """ + static_idxs = [] + for idx, n in enumerate(fx_g.graph.nodes): + if n.op != "placeholder": + break + if n.name.startswith("primals_"): + static_idxs.append(idx) + return static_idxs + + @dataclasses.dataclass class BoxedBool: value: bool @@ -3524,7 +3583,7 @@ def expr_fits_within_32bit(e: sympy.Expr) -> bool: int_max = torch.iinfo(torch.int32).max guarding_hint_or_throw = V.graph.sizevars.guarding_hint_or_throw - has_hint = V.graph.sizevars.shape_env.has_hint + has_guarding_hint = V.graph.sizevars.shape_env.has_guarding_hint if config.assume_32bit_indexing: V.graph.sizevars.check_leq(e, int_max) # type: ignore[arg-type] @@ -3555,7 +3614,7 @@ def expr_fits_within_32bit(e: sympy.Expr) -> bool: return False # Otherwise, the hint MUST exist and be in range - return has_hint(e) and guarding_hint_or_throw(e) <= int_max + return has_guarding_hint(e) and guarding_hint_or_throw(e) <= int_max def set_tracing_context_output_strides( @@ -4317,7 +4376,8 @@ def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, An def _is_tensor_ir(x) -> bool: # type: ignore[no-untyped-def] return isinstance(x, torch._inductor.ir.IRNode) and not isinstance( - x, torch._inductor.ir.GeneratorState + x, + (torch._inductor.ir.GeneratorState, torch._inductor.ir.OpaqueObjectState), ) flat_args = [ diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index e277bcb8b6c8a..c083821dba77f 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -351,6 +351,7 @@ def _wrap(x): return OpsValue(x) @staticmethod + # pyrefly: ignore [bad-override] def indirect_indexing(index, size, check=True, wrap_neg=True): # Returns a sympy value, not IR value index = OpsWrapper._unwrap(index) diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 449c46de09c1e..506adb7686853 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -10,6 +10,7 @@ import collections import contextlib import enum +import functools import inspect import io import pickle @@ -376,8 +377,7 @@ def get_closure(fn): # annotations on `eg``, but starting in Python 4.0, they will represented as # strings and no longer present. Furthermore, since the body of `eg` does # not reference those names, they do not appear in the list of closed over -# variables. In Python 2.x, type annotations are in comments, leading to a -# similar situation where their definitions are not available. We anticipate +# variables. We anticipate # that most users will not run into this issue because their modules and # functions will be defined at a global scope like MyGlobalClass. In cases # where they are not, it is possible to work around issues by declaring the @@ -409,6 +409,7 @@ def __getattr__(self, key: str) -> Any: return createResolutionCallbackFromEnv(closure_lookup()) +@functools.cache def can_compile_class(cls) -> bool: # If any of the functions on a type don't have a code object, this type can't # be compiled and is probably a builtin / bound from C @@ -1097,7 +1098,24 @@ def get_class_name_lineno(method) -> tuple[str, int]: def _overload_method(func): - _check_overload_body(func) + try: + _check_overload_body(func) + except IndentationError: + # CPython 3.13.8 has a bug (https://github.com/python/cpython/issues/139783) + # where inspect.getsourcelines() returns truncated source when a decorator + # is followed by a comment, causing ast.parse() to fail with IndentationError. + # Fixed in 3.13.9. Swallow the error on affected versions; re-raise otherwise. + if sys.version_info[:3] == (3, 13, 8): + import warnings + + warnings.warn( + "Skipping overload body check due to a known CPython 3.13.8 bug " + "(https://github.com/python/cpython/issues/139783). " + "Consider upgrading to Python 3.13.9+.", + stacklevel=2, + ) + else: + raise qual_name = _qualified_name(func) global _overloaded_methods class_name_map = _overloaded_methods.get(qual_name) @@ -1361,8 +1379,7 @@ def _disable_emit_hooks(): torch._C._jit_set_emit_hooks(hooks[0], hooks[1]) -def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None: # noqa: F811 - # noqa: F841 +def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None: def __enter__(self) -> None: self.hooks = torch._C._jit_get_emit_hooks() torch._C._jit_set_emit_hooks(None, None) diff --git a/torch/_library/_out_variant.py b/torch/_library/_out_variant.py index 693861daaa41a..df620fdda4f44 100644 --- a/torch/_library/_out_variant.py +++ b/torch/_library/_out_variant.py @@ -8,6 +8,26 @@ log = logging.getLogger(__name__) +# Manual registry for ops whose out variant is not discoverable via +# to_out_variant() (e.g. flat _out naming instead of .out overload). +_manual_out_variant_registry: dict[torch._ops.OpOverload, torch._ops.OpOverload] = {} + + +def register_out_variant( + functional_op: torch._ops.OpOverload, + out_op: torch._ops.OpOverload, +) -> None: + """Register a functional op -> out variant mapping.""" + _manual_out_variant_registry[functional_op] = out_op + + +def lookup_manual_out_variant( + op: torch._ops.OpOverload, +) -> torch._ops.OpOverload | None: + """Return the manually registered out variant for op, or None.""" + return _manual_out_variant_registry.get(op) + + def _is_functional(schema: torch._C.FunctionSchema) -> bool: """ A schema is functional if no argument is written to and the name doesn't @@ -90,7 +110,7 @@ def to_out_variant(op: torch._ops.OpOverload) -> torch._ops.OpOverload | None: candidate = getattr(torch_packet, overload_name) # pyrefly: ignore [missing-attribute] - if torch.Tag.out_variant not in candidate.tags: + if torch.Tag.out not in candidate.tags: continue candidate_schema = candidate._schema @@ -129,11 +149,11 @@ def check_out_variant( tagged_info = _get_out_variants_info(functional_op) raise AssertionError( f"We did not find an out variant for {functional_op}. Some common mistakes include:\n" - " 1. The out variant is missing the torch.Tag.out_variant tag.\n" + " 1. The out variant is missing the torch.Tag.out tag.\n" " 2. The out variant is not an overload of the original op (e.g., 'op.out' or 'op.overload_out') \n" " 3. The out variant's input arguments does not match the functional op's signature (excluding the mutable args).\n" " 4. The original operator is not functional.\n" - f"Overloads tagged with out_variant:\n" + f"Overloads tagged with out:\n" f"{tagged_info or ' (none)'}" ) if out_op != expected_out_op: @@ -145,7 +165,7 @@ def check_out_variant( def _get_out_variants_info(functional_op) -> str: - """Collect information about overloads tagged with out_variant for debugging.""" + """Collect information about overloads tagged with out for debugging.""" namespace = functional_op.namespace op_name = functional_op._schema.name.split("::")[1] torch_packet = getattr(getattr(torch.ops, namespace), op_name) @@ -154,7 +174,7 @@ def _get_out_variants_info(functional_op) -> str: for overload_name in torch_packet.overloads(): candidate = getattr(torch_packet, overload_name) # pyrefly: ignore [missing-attribute] - if torch.Tag.out_variant in candidate.tags: + if torch.Tag.out in candidate.tags: overloads_info.append(f" - {overload_name}: {candidate._schema}") return "\n".join(overloads_info) diff --git a/torch/_library/custom_ops.py b/torch/_library/custom_ops.py index dcbfc6ea18034..f09975b1872e7 100644 --- a/torch/_library/custom_ops.py +++ b/torch/_library/custom_ops.py @@ -654,7 +654,9 @@ def _register_to_dispatcher(self, tags: Sequence[_C.Tag]) -> None: def fake_impl(*args, **kwargs): if self._abstract_fn is None: if utils.can_generate_trivial_fake_impl(self._opoverload): - return None + return utils.generate_trivial_fake_impl( + self._opoverload, *args, **kwargs + ) raise RuntimeError( f"There was no fake impl registered for {self}. " f"This is necessary for torch.compile/export/fx tracing to work. " @@ -681,11 +683,12 @@ def fake_impl(*args, **kwargs): def adinplaceorview_impl(keyset, *args, **kwargs): # Handle the mutated idx the user gave us explicitly + all_args, all_kwargs = utils.fill_defaults(schema, args, kwargs) for idx in mutated_idxs: - increment_version(args[idx]) + increment_version(all_args[idx]) for key in mutated_keys: - increment_version(kwargs[key]) + increment_version(all_kwargs[key]) # Handle view + mutation that are in the schema return original_kernel.call_boxed(keyset, *args, **kwargs) @@ -714,6 +717,7 @@ def backend_select(keyset, *args, **kwargs): f"{self._name} does not have a kernel registered for {device}. " "Please use register_kernel to do so." ) + # pyrefly: ignore [bad-argument-type] dispatch_key = _C._dispatch_key_for_device(device) dispatch_key = getattr(_C.DispatchKey, dispatch_key) return self._opoverload.redispatch( diff --git a/torch/_library/fake_class_registry.py b/torch/_library/fake_class_registry.py index eb4c4ecc72736..5d4a6580af13c 100644 --- a/torch/_library/fake_class_registry.py +++ b/torch/_library/fake_class_registry.py @@ -30,13 +30,13 @@ def __init__( with _disable_current_modes(): real_obj = copy.deepcopy(x) except (RuntimeError, TypeError) as e: - log.warning( # noqa: G200 + log.warning( "Unable to deepcopy the custom object %s due to %s. " "Defaulting to the user given object. This might be " "dangerous as side effects may be directly applied " "to the object.", script_class_name, - str(e), + e, ) object.__setattr__(self, "real_obj", real_obj) @@ -69,12 +69,31 @@ def __getitem__(self, key): return self.real_obj[key] def __eq__(self, other): + if self is other: + return True + # Get real_obj without triggering custom __getattribute__ + self_real = object.__getattribute__(self, "real_obj") if isinstance(other, FakeScriptObject): - return self.real_obj == other.real_obj - return self.real_obj == other - - def __hash__(self) -> int: - return hash(self.real_obj) + other_real = object.__getattribute__(other, "real_obj") + # For reference types, identity check first + if self_real is other_real: + return True + # Fall back to equality check + return self_real == other_real + # Compare with the real object directly + return self_real == other + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + # Use real_obj's hash if available, otherwise use object id + real_obj = object.__getattribute__(self, "real_obj") + try: + return hash(real_obj) + except TypeError: + # Object is not hashable, use identity-based hash + return id(real_obj) def __deepcopy__(self, memo: dict[int, Any]) -> "FakeScriptObject": if id(self) in memo: @@ -85,21 +104,28 @@ def __deepcopy__(self, memo: dict[int, Any]) -> "FakeScriptObject": new_obj, "wrapped_obj", copy.deepcopy(self.wrapped_obj, memo) ) object.__setattr__(new_obj, "script_class_name", self.script_class_name) - new_real_obj = copy.deepcopy(self.real_obj, memo) - object.__setattr__(new_obj, "real_obj", new_real_obj) - for name, value in self.__dict__.items(): - if name not in ("wrapped_obj", "script_class_name", "real_obj"): - if isinstance(value, FakeScriptMethod): - object.__setattr__( - new_obj, - name, - FakeScriptMethod(new_obj, value.method_name, value.schema), - ) - else: - if hasattr(new_real_obj, name): - object.__setattr__(new_obj, name, getattr(new_real_obj, name)) + # Disable dispatch modes during deepcopy of real_obj and attribute + # access to prevent tensor operations (e.g. storage cloning, property + # access on DeviceMesh) from going through proxy tracing or + # functionalization. + with _disable_current_modes(): + new_real_obj = copy.deepcopy(self.real_obj, memo) + object.__setattr__(new_obj, "real_obj", new_real_obj) + for name, value in self.__dict__.items(): + if name not in ("wrapped_obj", "script_class_name", "real_obj"): + if isinstance(value, FakeScriptMethod): + object.__setattr__( + new_obj, + name, + FakeScriptMethod(new_obj, value.method_name, value.schema), + ) else: - object.__setattr__(new_obj, name, value) + if hasattr(new_real_obj, name): + object.__setattr__( + new_obj, name, getattr(new_real_obj, name) + ) + else: + object.__setattr__(new_obj, name, value) return new_obj @@ -220,7 +246,6 @@ def maybe_to_fake_obj( is_opaque_type, OpaqueTypeStr, ) - from torch._subclasses.fake_tensor import unset_fake_temporarily x_type = type(x) if is_opaque_type(x_type): @@ -232,7 +257,7 @@ def maybe_to_fake_obj( if opaque_info is None: raise AssertionError(f"opaque_info for type {x_type} must not be None") for attr_name in opaque_info.members: - with unset_fake_temporarily(): + with _disable_current_modes(): if not hasattr(x, attr_name): raise TypeError( f"Opaque object of type '{type_name}' was specified to have member " diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index 90a1d162a10a8..234a6f211ca82 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -8,7 +8,11 @@ from torch import device, dtype, Tensor, types from torch.utils._exposed_in import exposed_in -from .opaque_object import _OPAQUE_TYPES, is_opaque_reference_type, is_opaque_type +from .opaque_object import ( + _resolve_opaque_type_info, + is_opaque_reference_type, + is_opaque_type, +) # This is used as a negative test for @@ -127,7 +131,7 @@ def unstringify_type(ty: type[object] | str) -> tuple[typing.Any, bool]: schema_type = None if annotation_type not in SUPPORTED_PARAM_TYPES: if is_opaque_type(annotation_type): - schema_type = _OPAQUE_TYPES[annotation_type].class_name + schema_type = _resolve_opaque_type_info(annotation_type).class_name # type: ignore[union-attr] elif annotation_type == torch._C.ScriptObject: error_fn( f"Parameter {name}'s type cannot be inferred from the schema " @@ -230,7 +234,7 @@ def derived_types( def derived_seq_types(typ: type | typing._SpecialForm): return ( - typing.Sequence[typ], # type: ignore[valid-type] # noqa: UP006 + typing.Sequence[typ], # type: ignore[valid-type] typing.List[typ], # type: ignore[valid-type] # noqa: UP006 GenericAlias(collections.abc.Sequence, (typ,)), GenericAlias(list, (typ,)), @@ -300,7 +304,7 @@ def parse_return(annotation, error_fn): if origin is not tuple: if annotation not in SUPPORTED_RETURN_TYPES: if is_opaque_reference_type(annotation): - return _OPAQUE_TYPES[annotation].class_name + return _resolve_opaque_type_info(annotation).class_name # type: ignore[union-attr] error_fn( f"Return has unsupported type {annotation}. " f"The valid types are: {SUPPORTED_RETURN_TYPES}." @@ -319,7 +323,7 @@ def parse_return(annotation, error_fn): def _return_type_str(arg): if ty := SUPPORTED_RETURN_TYPES.get(arg): return ty - return _OPAQUE_TYPES[arg].class_name + return _resolve_opaque_type_info(arg).class_name # type: ignore[union-attr] output_ty = ", ".join(_return_type_str(arg) for arg in args) diff --git a/torch/_library/opaque_object.py b/torch/_library/opaque_object.py index cfbe0ae814a33..2513c31a336ad 100644 --- a/torch/_library/opaque_object.py +++ b/torch/_library/opaque_object.py @@ -39,12 +39,17 @@ from collections.abc import Callable from dataclasses import dataclass from enum import Enum -from typing import Any, Literal, NewType +from typing import Any, Literal, NewType, TYPE_CHECKING, TypeAlias from typing_extensions import TypeIs from weakref import WeakKeyDictionary import torch -from torch._opaque_base import OpaqueBase, OpaqueBaseMeta # noqa: F401 +from torch._opaque_base import OpaqueBase, OpaqueBaseMeta + + +if TYPE_CHECKING: + from torch.fx import Proxy + from torch.fx.experimental.proxy_tensor import PythonKeyTracer from .fake_class_registry import register_fake_class @@ -81,6 +86,15 @@ def __obj_unflatten__(cls, flattened_ctx: dict[str, Any]) -> None: OpaqueType = NewType("OpaqueType", torch._C.ScriptObject) +# Type for reconstruct_fn: called by PythonKeyTracer.create_arg when make_fx +# encounters an untracked opaque reference (e.g. a backward closure capture). +# Should derive the object from existing graph inputs or return None to fall +# back to get_attr. Args: (obj, get_tracked_proxy, tracer). +ReconstructFn: TypeAlias = Callable[ + [OpaqueBase, Callable[[OpaqueBase], "Proxy | None"], "PythonKeyTracer"], + "Proxy | None", +] + @dataclass class _OpaqueTypeInfo: @@ -91,6 +105,7 @@ class _OpaqueTypeInfo: ] # Callable that takes the object and returns list of values to guard on members: dict[str, MemberType] # Maps member name to how it should be handled hoist: bool + reconstruct_fn: ReconstructFn | None # Mapping of type -> (string name, reference/value type) @@ -141,6 +156,7 @@ def register_opaque_type( hoist=False, guard_fn: Any = None, members: dict[str, MemberType] | None = None, + reconstruct_fn: ReconstructFn | None = None, ) -> None: """ Registers the given type as an opaque type which allows this to be consumed @@ -187,7 +203,9 @@ def register_opaque_type( "registered as a pytree. Opaque objects must be pytree leaves." ) - if not isinstance(cls, OpaqueBaseMeta): + # Value types store the real object directly during tracing (no + # FakeScriptObject wrapper), so they don't need OpaqueBaseMeta. + if typ != "value" and not isinstance(cls, OpaqueBaseMeta): raise TypeError( f"Opaque type {cls} must subclass torch._opaque_base.OpaqueBase " "or 'metaclass=torch._opaque_base.OpaqueBaseMeta'. " @@ -202,7 +220,8 @@ def register_opaque_type( ) if typ == "value": - if cls.__eq__ is object.__eq__: # type: ignore[comparison-overlap] + # Enums use identity-based equality (singletons), which is fine for guarding. + if not issubclass(cls, Enum) and cls.__eq__ is object.__eq__: # type: ignore[comparison-overlap] raise TypeError( f"Value-type opaque object of type {cls} is " "expected to have a non-default `__eq__` " @@ -220,13 +239,14 @@ def register_opaque_type( "for FakeTensor caching." ) - if not hasattr(cls, "__fx_repr__"): + # Enums are special-cased in get_opaque_obj_repr. + if not issubclass(cls, Enum) and not hasattr(cls, "__fx_repr__"): raise TypeError( f"Value-type opaque object of type {cls} is " "expected to have a `__fx_repr__` method " "implementation as we will use this to reconstruct " "the object in the FX codegen. __fx_repr__ should return " - "a tuple of (repr_string, set_of_types)." + "a tuple of (repr_string, dict[str, type])." ) if guard_fn is not None: @@ -239,15 +259,27 @@ def register_opaque_type( # Generate a fully qualified name by combining module and qualname name = f"{cls.__module__}.{cls.__qualname__}" - type_info = _OpaqueTypeInfo(name, typ, guard_fn, members or {}, hoist) + type_info = _OpaqueTypeInfo( + name, typ, guard_fn, members or {}, hoist, reconstruct_fn + ) _OPAQUE_TYPES[cls] = type_info _OPAQUE_TYPES_BY_NAME[name] = type_info torch._C._register_opaque_type(name) +# Enums are always opaque value types. +register_opaque_type(Enum, typ="value") + + def is_opaque_value(value: object) -> TypeIs[OpaqueType]: - return is_opaque_type(type(value)) + if is_opaque_type(type(value)): + return True + from torch._library.fake_class_registry import FakeScriptObject + + if isinstance(value, FakeScriptObject): + return is_opaque_type(type(value.real_obj)) + return False def should_hoist(cls: Any) -> bool: @@ -257,6 +289,13 @@ def should_hoist(cls: Any) -> bool: return info.hoist +def get_reconstruct_fn(cls: type[OpaqueBase]) -> ReconstructFn | None: + info = _resolve_opaque_type_info(cls) + if info is None: + return None + return info.reconstruct_fn + + def has_members(cls: Any) -> bool: info = _resolve_opaque_type_info(cls) if info is None: @@ -330,6 +369,12 @@ def get_opaque_obj_repr(obj: Any) -> tuple[str, dict[str, type]]: For example, if repr_string is "Foo(bar=Bar(1))", the dict should be: {"Foo": Foo, "Bar": Bar} """ + + # Enums are special cased + if isinstance(obj, Enum): + cls = type(obj) + return f"{cls.__name__}.{obj.name}", {cls.__name__: cls} + if not hasattr(obj, "__fx_repr__"): raise TypeError( f"Value-type opaque object of type {obj} is " diff --git a/torch/_library/utils.py b/torch/_library/utils.py index 4620e5a88c4ab..b81d3afc4776f 100644 --- a/torch/_library/utils.py +++ b/torch/_library/utils.py @@ -76,6 +76,12 @@ def is_builtin(op: OpOverload) -> bool: return op.namespace in {"aten", "prim", "prims"} +def is_out(op: OpOverload) -> bool: + """Returns True if the operator has "out" semantics: its mutable arguments + are write-only output buffers that are not read from.""" + return torch.Tag.out in op.tags + + def is_functional_schema(schema: Any, *, allow_valid_view: bool = False) -> bool: """Check if the schema is functional. @@ -286,6 +292,9 @@ def can_generate_trivial_fake_impl(op: OpOverload) -> bool: # do input metadata mutation (which we have banned on custom ops) return False schema = op._schema + if is_out(op): + # Tag.out ops have a trivial fake impl: return the out= args in order. + return True # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution if not schema.is_mutable: return False @@ -295,6 +304,22 @@ def can_generate_trivial_fake_impl(op: OpOverload) -> bool: return True +def generate_trivial_fake_impl(op: OpOverload, *args, **kwargs): + """Generate the result of a trivial fake impl for the given op. + + For ops with no returns: returns None. + For Tag.out ops: returns the out= kwargs in declaration order. + """ + if is_out(op): + schema = op._schema + _, out_kwarg_names = mutated_args_kwargs(schema) + out_args = tuple(kwargs[name] for name in out_kwarg_names) + if len(out_args) == 1: + return out_args[0] + return out_args + return None + + def requires_set_python_module() -> bool: """If an op was defined in C++ and extended from Python using the torch.library APIs, returns if we require that there have been a @@ -634,6 +659,9 @@ def is_impure( if _get_effect(op) is not None: return True + if op in _side_effectful_functions: + return True + return False # Impure since it mutates RNG state diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index eff99aa1cfb1d..df9d54f11c12c 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -1024,11 +1024,7 @@ def format(self, record): if self._is_trace: if s != "": raise AssertionError(f"expected empty string for trace, got {s!r}") - try: - r = f"{prefix} {json.dumps(record.metadata)}" - except TypeError: - log.warning("failing metadata: %r", record.metadata) - raise + r = f"{prefix} {json.dumps(record.metadata, default=repr)}" if record.payload is not None: r += "".join(f"\n\t{l}" for l in record.payload.split("\n")) return r @@ -1200,6 +1196,7 @@ def __init__(self, root_dir: str | None) -> None: logging.Handler.__init__(self) self.stream = None self._builtin_open = open + self._pending_log_version = False # cloned from FileHandler in cpython def close(self) -> None: @@ -1271,12 +1268,36 @@ def emit(self, record) -> None: # TORCH_LOGS="inductor" is enabled inductor_log = logging.getLogger("torch._inductor") inductor_log.info("tlparse raw data: %s", self.stream.name) + self._pending_log_version = True else: # We go poof, remove and no-op trace_log.removeHandler(self) return if self.stream: super().emit(record) + if self._pending_log_version: + self._pending_log_version = False + _log_torch_version() + + +def _log_torch_version() -> None: + import torch + from torch._environment import is_fbcode + from torch._utils_internal import get_torch_source_version + + version_info: dict[str, object] = { + "pytorch_version": torch.__version__, + "commit": get_torch_source_version(), + "oss": not is_fbcode(), + } + + trace_structured( + "artifact", + metadata_fn=lambda: {"name": "torch_version", "encoding": "json"}, + payload_fn=lambda: version_info, + suppress_context=True, + expect_trace_id=False, + ) @functools.cache diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py index aa5671d67c359..9bd4aaba825b0 100644 --- a/torch/_logging/_registrations.py +++ b/torch/_logging/_registrations.py @@ -1,4 +1,3 @@ -# flake8: noqa: B950 from ._internal import register_artifact, register_log diff --git a/torch/_logging/scribe.py b/torch/_logging/scribe.py index a457b549e4f1e..b928b0f922b6b 100644 --- a/torch/_logging/scribe.py +++ b/torch/_logging/scribe.py @@ -59,5 +59,5 @@ def inner(**kwargs: TLazyField) -> None: # The weight of the record according to current sampling rate 18: optional i64 weight; } -""", # noqa: B950 +""", ) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index aed7b48c6cb7a..52649ee215974 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -314,9 +314,33 @@ def meta_fft_c2c(self, dim, normalization, forward): if device_hint(self) == "cpu" and not torch.backends.mkl.is_available(): return self.new_empty(self.size()) - sorted_dims = _sort_dims(self, dim) - out = self.new_empty(self.size()) - return _exec_fft(out, self, self.size(), sorted_dims, forward=forward) + out_sizes = self.size() + output = self.new_empty(out_sizes) + if device_hint(self) != "cuda": + sorted_dims = _sort_dims(self, dim) + return _exec_fft(output, self, out_sizes, sorted_dims, forward=forward) + + # Match _fft_c2c_cufft, which re-sorts the remaining dimensions after each + # staged transform because _exec_fft restrides the output in place. + sorted_dims = list(dim) + working_tensor = self + while True: + strides = working_tensor.stride() + sorted_dims.sort(key=lambda i: strides[i], reverse=True) + max_dims = min(cufft_max_ndim, len(sorted_dims)) + last_dims = sorted_dims[len(sorted_dims) - max_dims :] + + _exec_fft(output, working_tensor, out_sizes, last_dims, forward=forward) + sorted_dims = sorted_dims[: len(sorted_dims) - max_dims] + + if not sorted_dims: + return output + + if working_tensor is self: + working_tensor = output + output = self.new_empty(out_sizes) + else: + output, working_tensor = working_tensor, output cufft_max_ndim = 3 @@ -467,6 +491,89 @@ def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory= ) +@register_meta(aten._philox_key_split.default) +def meta_philox_key_split(key, num_splits): + torch._check( + key.dim() >= 1 and key.shape[-1] == 2, + lambda: f"_philox_key_split: key must have shape (*batch, 2), got shape {key.shape}", + ) + torch._check( + key.dtype == torch.uint64, + lambda: f"_philox_key_split: key must have dtype uint64, got {key.dtype}", + ) + torch._check( + num_splits > 0, + lambda: f"_philox_key_split: num_splits must be positive, got {num_splits}", + ) + batch_sizes = key.shape[:-1] + return key.new_empty((num_splits, *batch_sizes, 2)) + + +@register_meta(aten._philox_key_fold_in.default) +def meta_philox_key_fold_in(key, data): + torch._check( + key.dim() >= 1 and key.shape[-1] == 2, + lambda: f"_philox_key_fold_in: key must have shape (*batch, 2), got shape {key.shape}", + ) + torch._check( + key.dtype == torch.uint64, + lambda: f"_philox_key_fold_in: key must have dtype uint64, got {key.dtype}", + ) + return torch.empty_like(key) + + +def _check_philox_distribution_args(op_name, self, key): + torch._check( + self.dtype.is_floating_point, + lambda: f"{op_name}: self must be a floating point tensor, got {self.dtype}", + ) + torch._check( + key.dtype == torch.uint64, + lambda: f"{op_name}: key must have dtype uint64, got {key.dtype}", + ) + torch._check( + self.device == key.device, + lambda: ( + f"{op_name}: self and key must be on the same device, " + f"got {self.device} and {key.device}" + ), + ) + torch._check( + key.dim() >= 1 and key.shape[-1] == 2, + lambda: ( + f"{op_name}: key must have shape (2,) or (*batch, 2), got shape {key.shape}" + ), + ) + if key.dim() > 1: + torch._check( + key.dim() == self.dim() + 1, + lambda: ( + f"{op_name}: batched key must have ndim == output ndim + 1, " + f"got key shape {key.shape} with output shape {self.shape}" + ), + ) + key_batch = key.shape[: self.dim()] + torch._check( + all(ks == 1 or ks == ss for ks, ss in zip(key_batch, self.shape)), + lambda: ( + f"{op_name}: key batch shape {list(key_batch)} " + f"is not broadcastable with output shape {self.shape}" + ), + ) + + +@register_meta(aten._philox_normal_.default) +def meta_philox_normal_(self, key, mean=0.0, std=1.0): + _check_philox_distribution_args("_philox_normal_", self, key) + return self + + +@register_meta(aten._philox_uniform_.default) +def meta_philox_uniform_(self, key, low=0.0, high=1.0): + _check_philox_distribution_args("_philox_uniform_", self, key) + return self + + @register_meta([aten._fft_c2r.default, aten._fft_c2r.out]) @out_wrapper() def meta_fft_c2r(self: Tensor, dim: list[int], normalization: int, lastdim: int): @@ -490,7 +597,7 @@ def meta_fft_c2r(self: Tensor, dim: list[int], normalization: int, lastdim: int) else: # First complete any C2C transforms if len(dim) > 1: - temp = meta_fft_c2c(self, dim[:-1], 0, lastdim) # fft_norm_mode::none + temp = meta_fft_c2c(self, dim[:-1], 0, forward=False) else: temp = self.clone(memory_format=torch.contiguous_format) return _exec_fft(output, temp, out_sizes, [dim[-1]], forward=False) @@ -600,10 +707,11 @@ def meta_sparse_structured_linear( raise AssertionError( f"out_dtype is only supported for i8i8->i32 linear operator, got input.dtype={input.dtype}, out_dtype={out_dtype}" ) - output = input.new_empty( + output = input.new_empty_strided( output_sizes, + transposed_strides, dtype=input.dtype if out_dtype is None else out_dtype, - ).as_strided(output_sizes, transposed_strides) + ) return output @@ -894,7 +1002,7 @@ def make_dep_token( pin_memory=None, memory_format=None, ): - return torch.empty(0, device="meta") + return torch.empty((), device="meta") @register_meta(aten.sym_constrain_range.default) @@ -2593,6 +2701,11 @@ def meta_conv( if guard_or_false(input_tensor.size(input_channels_dim) == 0): shape_out[output_channels_dim] = 0 + # Memory format is left as contiguous: meta tensors have no device info, + # so _select_conv_backend returns Overrideable and the correct format + # cannot be determined here. The FakeTensor path (torch.compile, export) + # intercepts via a register_op_impl in fake_impls.py before reaching this + # kernel and uses FakeTensor.fake_device for an accurate answer. out = input_tensor.new_empty(shape_out) return out @@ -2885,6 +2998,13 @@ def check_dim_size(tensor, dim, dim_size, size): ) +@register_meta(aten.quantize_per_tensor) +def meta_quantize_per_tensor( + input: torch.Tensor, scale: float, zero_point: int, dtype: torch.dtype +) -> torch.Tensor: + return torch.empty_like(input) + + @register_meta(aten.avg_pool2d.default) def meta_avg_pool2d( input, @@ -3467,8 +3587,8 @@ def meta_complex(real, imag): @register_meta([aten.nonzero_static.default, aten.nonzero_static.out]) @out_wrapper() def nonzero_static(self, *, size, fill_value: int = -1): - # The impl of xpu nonzero_static is different with cuda but aligned with cpu - if device_hint(self) in ("cpu", "xpu"): + # The impl of nonzero_static on xpu and mps differs from cuda but aligned with cpu + if device_hint(self) in ("cpu", "mps", "xpu"): return self.new_empty((size, self.dim()), dtype=torch.long) else: return torch.empty_strided( @@ -3651,18 +3771,13 @@ def meta_convolution_backward( backend_grad_weight = None backend_grad_bias = None - # Backend layout expectation: GPU backends (CUDA via cudnn_conv_suggest_memory_format, - # MPS via mps_conv_use_channels_last) return channels_last outputs when either input - # tensor is channels_last. This must be matched here to avoid stride assertion failures - # in inductor when the predicted strides don't match actual backend output strides. + # All GPU backends compute output memory format via + # determine_backend_memory_format(input, weight, backend) — which calls + # cudnn_conv_suggest_memory_format(input, weight), mps_conv_use_channels_last(input, weight), + # etc. The format depends only on input and weight, NOT on grad_output. + # Both grad_input and grad_weight use this same backend_memory_format. # See: https://github.com/pytorch/pytorch/issues/171622 - # - # Memory format inference rules (matching backend behavior): - # - grad_input format: derived from grad_output and weight - # - grad_weight format: derived from input and grad_output def _conv_memory_format(t1, t2): - # Match the logic in cudnn_conv_suggest_memory_format and mps_conv_use_channels_last: - # Use channels_last if either tensor suggests it fmt1 = suggest_memory_format(t1) fmt2 = suggest_memory_format(t2) if fmt1 == torch.channels_last or fmt2 == torch.channels_last: @@ -3671,13 +3786,12 @@ def _conv_memory_format(t1, t2): return torch.channels_last_3d return torch.contiguous_format + memory_format = _conv_memory_format(input_, weight_) if output_mask[0]: - memory_format = _conv_memory_format(grad_output_, weight_) backend_grad_input = grad_output_.new_empty(input_.size()).to( memory_format=memory_format ) if output_mask[1]: - memory_format = _conv_memory_format(input_, grad_output_) backend_grad_weight = grad_output_.new_empty(weight_.size()).to( memory_format=memory_format ) @@ -4562,6 +4676,10 @@ def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor") torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor") + torch._check( + batch1.dtype == batch2.dtype, + lambda: f"expected scalar type {batch1.dtype} but found {batch2.dtype}", + ) batch1_sizes = batch1.size() batch2_sizes = batch2.size() @@ -5010,7 +5128,8 @@ def unpack(name, val): return nInputPlane, outputHeight, outputWidth -@register_meta(aten.max_pool2d_with_indices_backward.default) +@register_meta(aten.max_pool2d_with_indices_backward) +@out_wrapper("grad_input") def meta_max_pool2d_with_indices_backward( grad_output, self, @@ -5603,11 +5722,22 @@ def meta_zeros( @register_meta(aten.select_scatter.default) def meta_select_scatter(self, src, dim, index): - return utils.clone_preserve_strides(self) + return _scatter_meta_output(self) @register_meta(aten.slice_scatter.default) def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1): + return _scatter_meta_output(self) + + +def _scatter_meta_output(self): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + # Match clone_preserve_strides() in aten/native/TensorShape.cpp: overlapping + # bases cannot preserve their logical strides because the scatter writes would + # alias, so eager falls back to clone(). + if not free_unbacked_symbols(self) and torch._debug_has_internal_overlap(self) == 1: + return self.clone() return utils.clone_preserve_strides(self) @@ -5969,14 +6099,23 @@ def meta__scaled_dot_product_fused_attention_overrideable( return_debug_mask: bool = False, scale: float | None = None, ): - B = query.size(0) - H_Q = query.size(1) - S_Q = query.size(2) - S_KV = key.size(2) + # Explicitly handle 3D (H, S, D) and 4D (B, H, S, D) inputs, + # matching the C++ runtime in aten_mtia_ops.cpp. + B, H_Q, S_Q = 0, 0, 0 + if query.dim() == 4: + B, H_Q, S_Q, _ = query.size() + elif query.dim() == 3: + H_Q, S_Q, _ = query.size() + B = 1 + else: + raise RuntimeError("query must be 3D or 4D") + S_KV = key.size(-2) D_V = value.size(-1) - res_shape = (B, H_Q, S_Q, D_V) - res = alloc_with_matching_layout(query, res_shape) + # Preserve input dimensionality for the output shape + out_shape = list(query.shape) + out_shape[-1] = D_V + res = alloc_with_matching_layout(query, tuple(out_shape)) logsum_exp = torch.empty( (B, H_Q, S_Q), @@ -6001,6 +6140,34 @@ def meta__scaled_dot_product_fused_attention_overrideable( ) +@register_meta(aten._scaled_dot_product_fused_attention_overrideable_backward) +def meta__scaled_dot_product_fused_attention_overrideable_backward( + grad_out: Tensor, + query: Tensor, + key: Tensor, + value: Tensor, + attn_bias: Tensor, + grad_input_mask: list[bool], + out: Tensor, + logsumexp: Tensor, + cum_seq_q: Tensor, + cum_seq_k: Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + philox_seed: Tensor, + philox_offset: Tensor, + *, + scale: float | None = None, +): + grad_q = torch.empty_like(query) + grad_k = torch.empty_like(key) + grad_v = torch.empty_like(value) + grad_attn_bias = torch.empty_like(attn_bias) if attn_bias is not None else None + return grad_q, grad_k, grad_v, grad_attn_bias + + @register_meta( [ aten._scaled_dot_product_flash_attention_backward, @@ -6129,37 +6296,9 @@ def ensure_4d(x): q_, unsqueezed = ensure_4d(query) k_, _ = ensure_4d(key) v_, _ = ensure_4d(value) - mask_ = None - if attn_mask is not None: - mask_expanded_dims = list(query.shape) - mask_expanded_dims[-1] = k_.size(2) - mask_ = attn_mask.expand(mask_expanded_dims) - mask_, _ = ensure_4d(mask_) - - batch_size, num_head, q_size, query_head_size = q_.shape - _, k_size, max_seq_length, value_head_size = v_.shape - - def sdpa_vector_fast_mps(): - out = q_.new_empty(q_.shape) - if unsqueezed: - out = out.view_as(query) - attn = q_.new_empty((batch_size, num_head, q_size, max_seq_length)) - if unsqueezed: - if query.dim() == 3: - attn = attn.squeeze(0) - else: - shape = list(query.shape[:-3]) + attn.shape[1:4] - attn = attn.view(shape) - return out, attn - - def sdpa_vector_2pass_mps(): - blocks = 32 - out = q_.new_empty(q_.shape) - intermediate = q_.new_empty( - (batch_size, num_head, q_size, blocks, query_head_size) - ) - return out, intermediate + batch_size, num_head, q_size, _ = q_.shape + _, _, max_seq_length, value_head_size = v_.shape def sdpa_general_mps(): out = q_.new_empty((batch_size, num_head, q_size, value_head_size)) @@ -6175,26 +6314,9 @@ def sdpa_general_mps(): attn = attn.view(attn_shape) return out, attn - query_head_dim = q_.size(3) - value_head_dim = v_.size(3) - sdpa_vector_supported_head_dim = (query_head_dim == value_head_dim) and ( - query_head_dim == 64 or query_head_dim == 96 or query_head_dim == 128 - ) - query_seq_len = q_.size(2) - supports_sdpa_vector = ( - (query_seq_len <= 8) - and (query_seq_len <= k_.size(2)) - and ((mask_ is None) or (mask_.dtype == torch.bool)) - and sdpa_vector_supported_head_dim - ) - supports_fast_sdpa = (not is_causal) and supports_sdpa_vector - - if not supports_fast_sdpa: - return sdpa_general_mps() - elif (max_seq_length >= 1024) or (k_size < q_size and max_seq_length >= 4096): - return sdpa_vector_2pass_mps() - else: - return sdpa_vector_fast_mps() + # sdpa_vector_2pass_mps and sdpa_vector_fast_mps are intentionally left out. + # See https://github.com/pytorch/pytorch/issues/177603 for additional context. + return sdpa_general_mps() @register_meta([aten._scaled_dot_product_efficient_attention]) @@ -6672,7 +6794,11 @@ def is_fp8_or_fp4_type(dtype): lambda: f"Expected both inputs to be fp8 or fp4 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}", ) - if device_hint(self) == "cuda" or device_hint(self) == "xpu": + if ( + device_hint(self) == "cuda" + or device_hint(self) == "xpu" + or device_hint(self) == "cpu" + ): def is_row_major(stride): return stride[0] > stride[1] and stride[1] == 1 @@ -6683,22 +6809,23 @@ def is_col_major(stride): def has_zero_dim(tensor_2d): return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0 - torch._check( - is_row_major(self.stride()) or has_zero_dim(self), - lambda: f"self must be row_major, got stride {self.stride()}", - ) - torch._check( - is_col_major(mat2.stride()) or has_zero_dim(mat2), - lambda: f"mat2 must be col_major, got stride {mat2.stride()}", - ) - torch._check( - self.size(1) % 16 == 0, - lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}", - ) - torch._check( - mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0, - lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}", - ) + if device_hint(self) != "cpu": + torch._check( + is_row_major(self.stride()) or has_zero_dim(self), + lambda: f"self must be row_major, got stride {self.stride()}", + ) + torch._check( + is_col_major(mat2.stride()) or has_zero_dim(mat2), + lambda: f"mat2 must be col_major, got stride {mat2.stride()}", + ) + torch._check( + self.size(1) % 16 == 0, + lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}", + ) + torch._check( + mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0, + lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}", + ) # determine scaling type and check input dimensions (refer to Blas.cpp op) @@ -6913,7 +7040,11 @@ def is_fp4_type(dtype): SwizzleType.NO_SWIZZLE, ] - if device_hint(self) == "cuda" or device_hint(self) == "xpu": + if ( + device_hint(self) == "cuda" + or device_hint(self) == "xpu" + or device_hint(self) == "cpu" + ): def is_row_major(stride): return stride[0] > stride[1] and stride[1] == 1 @@ -6924,22 +7055,23 @@ def is_col_major(stride): def has_zero_dim(tensor_2d): return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0 - torch._check( - is_row_major(self.stride()) or has_zero_dim(self), - lambda: f"self must be row_major, got stride {self.stride()}", - ) - torch._check( - is_col_major(mat2.stride()) or has_zero_dim(mat2), - lambda: f"mat2 must be col_major, got stride {mat2.stride()}", - ) - torch._check( - self.size(1) % 16 == 0, - lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}", - ) - torch._check( - mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0, - lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}", - ) + if device_hint(self) != "cpu": + torch._check( + is_row_major(self.stride()) or has_zero_dim(self), + lambda: f"self must be row_major, got stride {self.stride()}", + ) + torch._check( + is_col_major(mat2.stride()) or has_zero_dim(mat2), + lambda: f"mat2 must be col_major, got stride {mat2.stride()}", + ) + torch._check( + self.size(1) % 16 == 0, + lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}", + ) + torch._check( + mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0, + lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}", + ) def is_tensorwise(recipe_a: list[ScalingType], recipe_b: list[ScalingType]): return ( @@ -7875,7 +8007,11 @@ def meta_histc(input, bins=100, min=0, max=0): @register_meta( - [aten._upsample_bilinear2d_aa.default, aten._upsample_bicubic2d_aa.default] + [ + aten._upsample_bilinear2d_aa.default, + aten._upsample_bicubic2d_aa.default, + aten._upsample_lanczos2d_aa.default, + ] ) def meta_upsample_bimode2d_aa( input, @@ -7896,7 +8032,12 @@ def meta_upsample_bimode2d_aa( ) -@register_meta([aten._upsample_bilinear2d_aa_backward.default]) +@register_meta( + [ + aten._upsample_bilinear2d_aa_backward.default, + aten._upsample_lanczos2d_aa_backward.default, + ] +) def meta_upsample_bimode2d_aa_backward( grad_output, output_size, @@ -8283,17 +8424,17 @@ def _meta_grouped_mm_common( fp8_dtype = torch.float8_e4m3fnuz torch._check( mat_a.dtype == fp8_dtype and mat_b.dtype == fp8_dtype, - lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", # noqa: B950 + lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", ) else: torch._check( mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16, - lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", # noqa: B950 + lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", ) torch._check( mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3], - lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}", # noqa: B950 + lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}", ) mat_a_is_2d = mat_a.dim() == 2 @@ -8317,11 +8458,11 @@ def is_col_major(mat): torch._check( is_row_major(mat_a), - lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}", # noqa: B950 + lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}", ) torch._check( is_col_major(mat_b), - lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}", # noqa: B950 + lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}", ) def check_valid_strides(mat_name, mat): @@ -8333,19 +8474,19 @@ def check_valid_strides(mat_name, mat): ): torch._check( mat_stride[end_dim] % alignment == 0, - lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.", # noqa: B950 + lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.", ) elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max( 1, mat.shape[end_dim] ): torch._check( mat_stride[end_dim - 1] % alignment == 0, - lambda: f"Expected {mat_name} stride along {end_dim - 1} dim to be multiple of 16 bytes, got {mat_stride[end_dim - 1]}.", # noqa: B950 + lambda: f"Expected {mat_name} stride along {end_dim - 1} dim to be multiple of 16 bytes, got {mat_stride[end_dim - 1]}.", ) else: torch._check( False, - lambda: f"Invalid strides/sizes, got {mat_stride} for strides and {mat.shape} for sizes.", # noqa: B950 + lambda: f"Invalid strides/sizes, got {mat_stride} for strides and {mat.shape} for sizes.", ) check_valid_strides("mat_a", mat_a) @@ -8358,7 +8499,7 @@ def check_valid_strides(mat_name, mat): scale_a.dtype == torch.float8_e8m0fnu and scale_b.dtype == torch.float8_e8m0fnu ), - lambda: f"For FP8 scales must both be float32, or for MXFP8 both scales must be float8_e8m0fnu. Got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.", # noqa: B950 + lambda: f"For FP8 scales must both be float32, or for MXFP8 both scales must be float8_e8m0fnu. Got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.", ) is_mxfp8 = ( scale_a.dtype == torch.float8_e8m0fnu @@ -8378,7 +8519,7 @@ def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): if is_mxfp8: torch._check( scale.dim() == mat.dim(), - lambda: f"For MXFP8, scale must have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}", # noqa: B950 + lambda: f"For MXFP8, scale must have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}", ) else: torch._check( @@ -8387,7 +8528,7 @@ def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): ) torch._check( scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier, - lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", # noqa: B950 + lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", ) else: torch._check( @@ -8403,7 +8544,7 @@ def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): if is_mxfp8: torch._check( scale.ndim == mat.ndim - 1, - lambda: f"For MXFP8, 3d tensor should have 2d scales, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}", # noqa: B950 + lambda: f"For MXFP8, 3d tensor should have 2d scales, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}", ) # TODO: This logic only holds for RHS tensor in 2d-3d case. # We'll need to update it to handle LHS 3d tensor in 3d-2d and 3d-3d cases. @@ -8413,7 +8554,7 @@ def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): blocked_N = round_up(N, 128) torch._check( scale.shape[0] == G and scale.shape[1] == blocked_K * blocked_N, - lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K * blocked_N}), but got {scale.shape}", # noqa: B950 + lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K * blocked_N}), but got {scale.shape}", ) else: torch._check( @@ -8422,7 +8563,7 @@ def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1): ) torch._check( scale.shape[1] == mat.shape[1 + scaled_dim], - lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.", # noqa: B950 + lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.", ) scale_multiplier = ( @@ -8620,6 +8761,13 @@ def embedding( ) -> Tensor: if weight.dim() != 2: raise AssertionError(f"'weight' must be 2-D, got {weight.dim()}-D") + torch._check( + indices.dtype in (torch.long, torch.int32), + lambda: ( + "Expected tensor for argument #1 'indices' to have one of the following " + f"scalar types: Long, Int; but got {indices.dtype} instead" + ), + ) weight_shape = weight.shape indices_shape = indices.shape @@ -8881,11 +9029,11 @@ def activate_meta(): in { "aten::empty_strided", # causing infinite recursion, test_meta.py "aten::clone", # causing infinite recursion - "aten::_to_copy", # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite # noqa: B950 - "aten::copy_", # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64 # noqa: B950 - "aten::constant_pad_nd", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32 # noqa: B950 - "aten::rot90", # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32 # noqa: B950 - "aten::as_strided_scatter", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32 # noqa: B950 + "aten::_to_copy", # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite + "aten::copy_", # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64 + "aten::constant_pad_nd", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32 + "aten::rot90", # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32 + "aten::as_strided_scatter", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32 } ): pass diff --git a/torch/_native/README.md b/torch/_native/README.md index e38ee2e13acf3..f746c94848265 100644 --- a/torch/_native/README.md +++ b/torch/_native/README.md @@ -173,9 +173,64 @@ register_op_override( implementation_fn: _OpOverrideFn, *, allow_multiple_override: bool = False, - unconditional_override: bool = False, -) -> None + unconditional_override: bool = False,) -> None ``` -Register a given implementation to a library - `lib_symbol = "aten"` for most cases, `op_symbol` refers to the library method you wish to override (ex. `"_scaled_grouped_mm_v2"` from above), and dispatch key will generally be one of `("CPU", "CUDA")` depending on what backend you're overriding. For all arguments, please see the comments for `_register_op_override` in [registry.py](registry.py). +Register a given implementation to a library - `lib_symbol = "aten"` for most cases, `op_symbol` refers to the library method you wish to override (ex. `"_scaled_grouped_mm_v2"` from above), and dispatch key will generally be one of `("CPU", "CUDA")` depending on what backend you're overriding. For all arguments, please see the comments for `register_op_override` in [registry.py](registry.py). + +`deregister_op_overrides() -> None` : De-register all operators that are currently registered by this DSL. Note that `torch._native.registry` has a `deregister_op_overrides` method to enable this in a centralized fashion. An example of an implementation of this spec can be found in [cutedsl_utils.py](cutedsl_utils.py), but please talk to us if you're planning on adding a new DSL. + +## Registration Orders and You + +Currently the registration order (both in general and per-op) is set by the order of imports in `torch/_native/ops/__init__.py`, noting that registration acts as a stack, in that **the last registered override for an op is the first that will be called**. If you wish to exercise control of the override ordering, please utilize one of the methods below. + +### User-Ordering Functions + +We allow for user-defined ordering functions of the form: + +``` +from torch._native.registry import _OverrideNode + +def ordering_fn( + op_symbol: str, + dispatch_key: str, + graph: list[_OverrideNode], +) -> list[_OverrideNode] +``` + +In other words, a function that takes some context and a graph describing the override order, and returning a modified graph. + +**NOTE**: Graphs are described as lists of the private class `_OverrideNode` -- while this graph re-ordering functionality is public, it is both experimental and intended for advanced users only. The `_OverrideNode` class is to be used very carefully, and may change in the future. + +This functionality can used by either setting the environment variable `TORCH_PYTHON_NATIVE_USER_GRAPH_ORDER_FN` to an importable python function with the above signature, or by adding the following to your top-level script, post `import torch`: + +``` +torch._native.reorder_graphs_from_user_function( + my_ordering_fn, + reregister_overrides=True, +) +``` + +Both methods are equivalent in functionality, but the environment-variable version is a little more efficient in that torch doesn't have to register **all** ops, before disabling/re-registering again based on the user-passed function. + +**NOTE**: The passed ordering function can be destructive in nature - one can disable an op completely by returning `[]` for a given graph, indicating that no overrides exist / are allowed. **There is currently no supported way to retrieve the original graphs - they are considered gone for the lifetime of the process**. + +An example user-ordering function is demonstrated below: + +``` +def example_ordering_fn(op_symbol, dispatch_key, nodes): + out_nodes = [] + + # disable overrides for these symbols completely + if op_symbol in ["_scaled_mm_v2", "add"]: + return [] + + # Only keep triton overrides otherwise + for node in nodes: + if node.dsl_name != 'triton': + continue + out_nodes.append(node) + + return out_nodes +``` diff --git a/torch/_native/__init__.py b/torch/_native/__init__.py index 819adfa2f4704..c8d41271989c9 100644 --- a/torch/_native/__init__.py +++ b/torch/_native/__init__.py @@ -1,2 +1,59 @@ +import os +import warnings +from functools import cache +from typing import cast + # This handles collecting registration of all native ops -from . import ops +# Also need to import DSL utils to make sure DSL registration is ok +from . import cutedsl_utils, dsl_registry, ops, registry, triton_utils + + +@cache +def get_user_ordering_fn() -> registry.UserOrderingFn | None: + """ + Get a user-supplied graph-ordering function if specified. + + Pass in a `package.submodule.fn` string to the env variable + `TORCH_PYTHON_NATIVE_USER_GRAPH_ORDER_FN` that implements the + calling API described in `torch/_native/README.md`. This function + must be part of an importable path. + + Return either the imported function or `None` + """ + env_var = os.getenv("TORCH_PYTHON_NATIVE_USER_GRAPH_ORDER_FN") + + if not env_var: + return None + + try: + import importlib + + # Split into "package.submodule.fn_name + module_name, fn_name = env_var.rsplit(".", 1) + + module = importlib.import_module(module_name) + fn = getattr(module, fn_name) + + if not callable(fn): + raise TypeError(f"{env_var} does not describe a callable function") + + # Cast needed: getattr returns object, but we've verified fn is callable with correct signature + return cast(registry.UserOrderingFn, fn) + except Exception as e: + raise ValueError( + f"Could not resolve {env_var} into an importable & callable function" + ) from e + + +user_order_fn = get_user_ordering_fn() +if user_order_fn: + registry.reorder_graphs_from_user_function(user_order_fn) + + +with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Warning only once for all operators, other operators may also be overridden\\.", + category=UserWarning, + ) + registry._register_all_overrides() diff --git a/torch/_native/common_utils.py b/torch/_native/common_utils.py index d2dc49c9eb622..40a07b253cb8d 100644 --- a/torch/_native/common_utils.py +++ b/torch/_native/common_utils.py @@ -3,7 +3,7 @@ import os from functools import cache -import packaging.version +from torch._vendor.packaging import version as _packaging_version @cache @@ -33,7 +33,7 @@ def _unavailable_reason(deps: list[tuple[str, str]]) -> None | str: return None -def _available_version(package: str) -> packaging.version.Version | None: +def _available_version(package: str) -> _packaging_version.Version | None: """ Get the installed version of a package as (major, minor, patch). @@ -47,8 +47,8 @@ def _available_version(package: str) -> packaging.version.Version | None: return None try: - v = packaging.version.parse(version) - except packaging.version.InvalidVersion: + v = _packaging_version.parse(version) + except _packaging_version.InvalidVersion: return None return v diff --git a/torch/_native/cutedsl_utils.py b/torch/_native/cutedsl_utils.py index d654b3558bddf..203c8812bc379 100644 --- a/torch/_native/cutedsl_utils.py +++ b/torch/_native/cutedsl_utils.py @@ -1,24 +1,34 @@ import functools import logging +import sys +from typing import cast -from packaging.version import Version +from torch._vendor.packaging.version import Version +from ..backends import cuda as _cuda from .common_utils import ( _available_version, _unavailable_reason, check_native_jit_disabled, check_native_version_skip, ) -from .registry import _OpFn, _register_op_override +from .dsl_registry import dsl_registry, DSLModuleProtocol +from .registry import ( + _OpFn, + deregister_op_overrides as _deregister_op_overrides_impl, + register_op_override as _register_op_override_impl, +) log = logging.getLogger(__name__) +_CUTEDSL_DSL_NAME = "cutedsl" _CUTEDSL_REQUIRED_VERSIONS: set[Version] = { # Current version - Note Version.from_part(release=(4.4.1)) is better # but > v26 of packaging. Version(f"{4}.{4}.{1}"), + Version(f"{4}.{4}.{2}"), } @@ -29,6 +39,10 @@ def _check_runtime_available() -> tuple[bool, Version | None]: NOTE: Doesn't import at this point """ + # Skip all checks if running on CPU-only binary + if not _cuda.is_built(): + return (False, None) + deps = [ ("nvidia_cutlass_dsl", "cutlass"), ("apache_tvm_ffi", "tvm_ffi"), @@ -38,7 +52,7 @@ def _check_runtime_available() -> tuple[bool, Version | None]: available = True version = _available_version("nvidia_cutlass_dsl") else: - log.info( + log.warning( "CuTeDSL operators require optional Python packages " "`nvidia-cutlass-dsl` and `apache-tvm-ffi`; " "%s", @@ -65,7 +79,7 @@ def _version_is_ok() -> bool: if check_native_version_skip() or (version in _CUTEDSL_REQUIRED_VERSIONS): return True - log.info( + log.warning( "cutedsl version %s is not known-good (ok: %s); " "set TORCH_NATIVE_SKIP_VERSION_CHECK=1 to override", version, @@ -74,6 +88,13 @@ def _version_is_ok() -> bool: return False +def deregister_op_overrides() -> None: + """ + Deregister all ops through cuteDSL + """ + _deregister_op_overrides_impl(disable_dsl_names=_CUTEDSL_DSL_NAME) + + def register_op_override( lib_symbol: str, op_symbol: str, @@ -86,7 +107,7 @@ def register_op_override( """ See torch/_native/registry.py for the underlying implementation and arguments. This is a thin, DSL-checking wrapper over - _register_op_override + _register_op_override_impl """ available, version = _check_runtime_available() if (not available) or check_native_jit_disabled(): @@ -95,7 +116,8 @@ def register_op_override( if not _version_is_ok(): return - _register_op_override( + _register_op_override_impl( + _CUTEDSL_DSL_NAME, lib_symbol, op_symbol, dispatch_key, @@ -103,3 +125,8 @@ def register_op_override( allow_multiple_override=allow_multiple_override, unconditional_override=unconditional_override, ) + + +# Register this DSL module with the registry +# Note: Import-time registration ensures DSL is available when module is loaded +dsl_registry.register_dsl("cutedsl", cast(DSLModuleProtocol, sys.modules[__name__])) diff --git a/torch/_native/dsl_registry.py b/torch/_native/dsl_registry.py new file mode 100644 index 0000000000000..b2fa082ee0ae6 --- /dev/null +++ b/torch/_native/dsl_registry.py @@ -0,0 +1,153 @@ +# Owner(s): ["module: dsl-native-ops"] + +import functools +import logging +from typing import Protocol + +from torch._vendor.packaging.version import Version + +from .registry import _OpFn + + +log = logging.getLogger(__name__) + + +class DSLModuleProtocol(Protocol): + """Complete interface for DSL utility modules""" + + def runtime_available(self) -> bool: ... + def runtime_version(self) -> Version | None: ... + + def deregister_op_overrides(self) -> None: ... + + def register_op_override( + self, + lib_symbol: str, + op_symbol: str, + dispatch_key: str, + impl: _OpFn, + *, + allow_multiple_override: bool = False, + unconditional_override: bool = False, + ) -> None: ... + + +class DSLRegistry: + """Registry for DSL modules - calls their existing API functions dynamically""" + + def __init__(self): + self._dsl_modules: dict[str, DSLModuleProtocol] = {} + + def _validate_dsl_name(self, name: str) -> None: + """Validate DSL name at runtime""" + if not isinstance(name, str): + raise TypeError(f"DSL name must be string, got {type(name).__name__}") + + if not name.strip(): + raise ValueError("DSL name cannot be empty or whitespace") + + def register_dsl(self, name: str, dsl_module: DSLModuleProtocol) -> None: + """Register a DSL module with required interface""" + # Runtime validation for name and module interface + self._validate_dsl_name(name) + + # Validate that module implements the protocol + required_methods = [ + "runtime_available", + "runtime_version", + "register_op_override", + "deregister_op_overrides", + ] + missing_methods = [ + method for method in required_methods if not hasattr(dsl_module, method) + ] + if missing_methods: + raise TypeError( + f"DSL module '{name}' missing required methods: {missing_methods}" + ) + + # Handle duplicate registration case + if name in self._dsl_modules: + existing_module = self._dsl_modules[name] + if existing_module is dsl_module: + # Same module re-registering - this is OK (import-time registration) + log.debug( + "DSL '%s' re-registered with same module", + name, + ) + return + else: + # Different module object but same name - warn and allow (for testing) + # This can happen when tests import modules directly + log.warning( + "DSL '%s' re-registered with different module object (possibly from test imports)", + name, + ) + # Continue to allow the registration + + # No cast needed - already properly typed + self._dsl_modules[name] = dsl_module + + # Clear caches to prevent stale results after registration + self.is_dsl_available.cache_clear() + self.get_dsl_version.cache_clear() + self.list_available_dsls.cache_clear() + self.list_all_dsls.cache_clear() + + log.info("Successfully registered DSL: %s", name) + + @functools.cache # noqa: B019 + def is_dsl_available(self, dsl_name: str) -> bool: + """Check if DSL is available by calling its runtime_available()""" + dsl_module = self._dsl_modules.get(dsl_name) + if dsl_module is None: + return False + try: + return dsl_module.runtime_available() + except ImportError: + log.debug("DSL %s import error", dsl_name, exc_info=True) + return False + except Exception: + log.exception("Error checking availability for DSL %s", dsl_name) + return False + + @functools.cache # noqa: B019 + def get_dsl_version(self, dsl_name: str) -> Version | None: + """Get DSL version by calling its runtime_version()""" + dsl_module = self._dsl_modules.get(dsl_name) + if dsl_module is None: + return None + try: + return dsl_module.runtime_version() + except Exception: + log.debug("Error getting version for DSL %s", dsl_name, exc_info=True) + return None + + @functools.cache # noqa: B019 + def list_available_dsls(self) -> tuple[str, ...]: + """Get names of currently available DSLs""" + available = [] + for name in self._dsl_modules: + if self.is_dsl_available(name): # Use cached method + available.append(name) + return tuple(available) + + @functools.cache # noqa: B019 + def list_all_dsls(self) -> tuple[str, ...]: + """Get all registered DSL names (available or not)""" + return tuple(self._dsl_modules.keys()) + + def get_dsl_module(self, name: str) -> DSLModuleProtocol | None: + """Get a registered DSL module by name. + + Args: + name: Name of the DSL to retrieve. + + Returns: + The DSL module if registered, None otherwise. + """ + return self._dsl_modules.get(name) + + +# Global registry instance +dsl_registry = DSLRegistry() diff --git a/torch/_native/ops/__init__.py b/torch/_native/ops/__init__.py index e69de29bb2d1d..d5fa810b57376 100644 --- a/torch/_native/ops/__init__.py +++ b/torch/_native/ops/__init__.py @@ -0,0 +1 @@ +from . import bmm_outer_product diff --git a/torch/_native/ops/bmm_outer_product/__init__.py b/torch/_native/ops/bmm_outer_product/__init__.py new file mode 100644 index 0000000000000..0b84de3406155 --- /dev/null +++ b/torch/_native/ops/bmm_outer_product/__init__.py @@ -0,0 +1,4 @@ +from .triton_impl import register_to_dispatch + + +register_to_dispatch() diff --git a/torch/_native/ops/bmm_outer_product/triton_impl.py b/torch/_native/ops/bmm_outer_product/triton_impl.py new file mode 100644 index 0000000000000..177fde2195d3f --- /dev/null +++ b/torch/_native/ops/bmm_outer_product/triton_impl.py @@ -0,0 +1,62 @@ +import functools +import importlib.util + +import torch + +from ... import triton_utils as tu + + +@functools.cache +def _has_triton() -> bool: + try: + return importlib.util.find_spec("triton") is not None + except ModuleNotFoundError: + return False + + +def _is_outer_product(a: torch.Tensor, b: torch.Tensor) -> bool: + return ( + a.ndim == 3 + and b.ndim == 3 + and a.shape[2] == 1 + and b.shape[1] == 1 + and a.numel() > 0 + and b.numel() > 0 + and not a.is_complex() + ) + + +def _bmm_outer_product_impl( + dispatch_keys: torch.DispatchKeySet, + a: torch.Tensor, + b: torch.Tensor, + *, + fallback_kernel, +) -> torch.Tensor: + a_is_cow = torch._C._is_cow_tensor(a) # pyrefly: ignore[missing-attribute] + b_is_cow = torch._C._is_cow_tensor(b) # pyrefly: ignore[missing-attribute] + if _has_triton() and _is_outer_product(a, b) and not (a_is_cow or b_is_cow): + from .triton_kernels import bmm_outer_product + + return bmm_outer_product(a, b) + return fallback_kernel.call_boxed(dispatch_keys, a, b) + + +def _register_for_dispatch_key(dispatch_key: str) -> None: + fallback_kernel = torch.library.get_kernel("aten::bmm", dispatch_key) + tu.register_op_override( + "aten", + "bmm", + dispatch_key, + functools.partial(_bmm_outer_product_impl, fallback_kernel=fallback_kernel), + allow_multiple_override=True, + ) + + +def register_to_dispatch() -> None: + if not _has_triton(): + return + + _register_for_dispatch_key("CUDA") + if torch.xpu._is_compiled(): + _register_for_dispatch_key("XPU") diff --git a/torch/_native/ops/bmm_outer_product/triton_kernels.py b/torch/_native/ops/bmm_outer_product/triton_kernels.py new file mode 100644 index 0000000000000..3477e04d4a2ee --- /dev/null +++ b/torch/_native/ops/bmm_outer_product/triton_kernels.py @@ -0,0 +1,93 @@ +import triton +import triton.language as tl + +import torch + + +@triton.jit +def _bmm_outer_product_kernel( + A_ptr, + B_ptr, + OUT_ptr, + B_dim, + M, + N, + stride_ab, + stride_am, + stride_bb, + stride_bn, + stride_ob, + stride_om, + stride_on, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + pid = tl.program_id(0) + + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + tiles_per_batch = grid_m * grid_n + + pid_b = pid // tiles_per_batch + pid_mn = pid % tiles_per_batch + pid_m = pid_mn // grid_n + pid_n = pid_mn % grid_n + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + mask_m = rm < M + mask_n = rn < N + + a = tl.load(A_ptr + pid_b * stride_ab + rm * stride_am, mask=mask_m, other=0.0) + b = tl.load(B_ptr + pid_b * stride_bb + rn * stride_bn, mask=mask_n, other=0.0) + + out = a[:, None] * b[None, :] + + mask = mask_m[:, None] & mask_n[None, :] # pyrefly: ignore[bad-index] + tl.store( + OUT_ptr + pid_b * stride_ob + rm[:, None] * stride_om + rn[None, :] * stride_on, + out, + mask=mask, + ) + + +def _pick_block_sizes(m: int, n: int) -> tuple[int, int]: + """I swept over some shapes and in the future we should figure out @autotune story""" + if m <= 32: + block_m = triton.next_power_of_2(m) + elif m <= 96: + block_m = 32 + elif m <= 192: + block_m = 64 + else: + block_m = 128 + return block_m, min(triton.next_power_of_2(n), 128) + + +def bmm_outer_product(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + B, M, _ = a.shape + N = b.shape[2] + + out = torch.empty(B, M, N, dtype=a.dtype, device=a.device) + + BLOCK_M, BLOCK_N = _pick_block_sizes(M, N) + + _bmm_outer_product_kernel[(B * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),)]( + a, + b, + out, + B, + M, + N, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + return out diff --git a/torch/_native/registry.py b/torch/_native/registry.py index b1d32cb46ec26..0924e51f6d4a3 100644 --- a/torch/_native/registry.py +++ b/torch/_native/registry.py @@ -1,9 +1,22 @@ -from collections.abc import Callable +import logging +from collections.abc import Callable, Iterable +from dataclasses import dataclass, field from typing import Concatenate, ParamSpec, TypeVar import torch.library +__all__ = [ + "UserOrderingFn", + "register_op_override", + "reorder_graphs_from_user_function", + "reenable_op_overrides", + "deregister_op_overrides", + "get_dsl_operations", +] + +log = logging.getLogger(__name__) + P = ParamSpec("P") R = TypeVar("R") @@ -13,51 +26,782 @@ _OpFn = _OpOverrideFn | _OpReplaceFn -libs = {} +@dataclass +class _OverrideNode: + """Track function override data.""" + + dsl_name: str + op_symbol: str + dispatch_key: str + override_fn: _OpFn + unconditional_override: bool = False + active: bool = True + + +UserOrderingFn = Callable[[str, str, list[_OverrideNode]], list[_OverrideNode]] + + +@dataclass +class _FilterState: + """Manages filtering state for override nodes.""" + + _dsl_names: set[str] = field(default_factory=set) + _op_symbols: set[str] = field(default_factory=set) + _dispatch_keys: set[str] = field(default_factory=set) + + def check_enabled(self, node: _OverrideNode) -> bool: + """ + Check if a node is enabled based on current filter state. + + Args: + node: The override node to check + + Returns: + bool: True if the node should be enabled, False if filtered out + """ + if node.dsl_name in self._dsl_names: + return False + + if node.op_symbol in self._op_symbols: + return False + + if node.dispatch_key in self._dispatch_keys: + return False + + return True + + def update( + self, + dsl_names: str | Iterable[str] | None, + op_symbols: str | Iterable[str] | None, + dispatch_keys: str | Iterable[str] | None, + remove_keys: bool = False, + ) -> None: + """ + Update filter sets as (current | new) or (current ~ new). + + Args: + dsl_names: DSL names to add/remove from filter + op_symbols: Operation symbols to add/remove from filter + dispatch_keys: Dispatch keys to add/remove from filter + remove_keys: If True, remove keys from filter; if False, add them + + Note: + Uses set.discard as it doesn't raise an exception if the element + wasn't in the set to begin with. + """ + if remove_keys: + self._dsl_names -= set(_resolve_iterable(dsl_names)) + self._op_symbols -= set(_resolve_iterable(op_symbols)) + self._dispatch_keys -= set(_resolve_iterable(dispatch_keys)) + else: + self._dsl_names |= set(_resolve_iterable(dsl_names)) + self._op_symbols |= set(_resolve_iterable(op_symbols)) + self._dispatch_keys |= set(_resolve_iterable(dispatch_keys)) + + def build_disable_key_set(self) -> set[tuple[str, str]]: + """ + Build a set of dictionary keys based on the current filter state. + + Returns: + set[tuple[str, str]]: Set of (op_symbol, dispatch_key) tuples + """ + return _build_key_set( + self._dsl_names, + self._op_symbols, + self._dispatch_keys, + ) + + def __str__(self) -> str: + """Return string representation of filter state.""" + s = "" + s += "Filter State:\n" + s += " === DSL: ===\n" + for i, dsl in enumerate(self._dsl_names): + s += f" {i}: {dsl}\n" + s += " === OP SYMBOL: ===\n" + for i, op in enumerate(self._op_symbols): + s += f" {i}: {op}\n" + s += " === DISPATCH KEYS: ===\n" + for i, key in enumerate(self._dispatch_keys): + s += f" {i}: {key}\n" + + return s + + +# Store the global override filtering state +_filter_state: _FilterState = _FilterState() + +# Store torch.library.Library instances +_libs: dict[tuple[str, str], torch.library.Library] = {} + +# store graph structures +_GraphsType = dict[tuple[str, str], list[_OverrideNode]] +_graphs: _GraphsType = {} + +_MappingType = dict[str, list[tuple[str, str]]] +# map a {dsl, op, dispatch_key} to keys to all graphs that contain it +_dsl_name_to_lib_graph: _MappingType = {} +_dispatch_key_to_lib_graph: _MappingType = {} +_op_symbol_to_lib_graph: _MappingType = {} -def _get_library(lib_symbol: str, dispatch_key: str) -> torch.library.Library: + +def _build_key_set( + dsl_names: str | Iterable[str] | None, + op_symbols: str | Iterable[str] | None, + dispatch_keys: str | Iterable[str] | None, +) -> set[tuple[str, str]]: """ - Return a `torch.library.Library` instance unique to the passed - (lib_symbol, dispatch_key) pair. Create a new instance if necessary. + Build a set of dictionary keys based on filter criteria. + + Args: + dsl_names: DSL names to include in key set + op_symbols: Operation symbols to include in key set + dispatch_keys: Dispatch keys to include in key set + + Returns: + set[tuple[str, str]]: Set of (op_symbol, dispatch_key) tuples """ - global libs + key_set: set[tuple[str, str]] = set() - if (lib_symbol, dispatch_key) not in libs: - libs[(lib_symbol, dispatch_key)] = torch.library.Library( - lib_symbol, "IMPL", dispatch_key - ) + def _append_to_set( + entries: str | Iterable[str] | None, graph_lib_dict: _MappingType + ) -> None: + """Helper to add matching keys from graph_lib_dict to key_set.""" + resolved_entries = _resolve_iterable(entries) + + for entry in resolved_entries: + if entry in graph_lib_dict: + for key in graph_lib_dict[entry]: + key_set.add(key) + + _append_to_set(dsl_names, _dsl_name_to_lib_graph) + _append_to_set(op_symbols, _op_symbol_to_lib_graph) + _append_to_set(dispatch_keys, _dispatch_key_to_lib_graph) + + return key_set + + +def _print_override_graphs(*, print_inactive: bool = False) -> None: + """ + Print all override graphs for debugging purposes. + + Args: + print_inactive: Whether to print inactive nodes + """ + for (op, key), node_list in _graphs.items(): + print(f"{op=}, {key=}") + + for i, node in enumerate(node_list): + if node.active or print_inactive: + s: str = f" {i}: {node.dsl_name=}, {node.unconditional_override=}" + if print_inactive: + s += f" {node.active=}" + + print(s) + + +def _get_or_create_library(op_symbol: str, dispatch_key: str) -> torch.library.Library: + """ + Get or create a torch.library.Library instance for the given key. + + Args: + op_symbol: The operation symbol + dispatch_key: The dispatch key + + Returns: + torch.library.Library: The library instance + """ + global _libs + + key = (op_symbol, dispatch_key) + if key not in _libs: + _libs[key] = torch.library.Library("aten", "IMPL", dispatch_key) + + return _libs[key] + + +def _register_node_impl( + lib: torch.library.Library, node: _OverrideNode, dispatch_key: str +) -> None: + """ + Register a single node implementation with the library. + + Args: + lib: The torch.library.Library instance + node: The override node to register + dispatch_key: The dispatch key for registration + """ + lib.impl( + node.op_symbol, + node.override_fn, + dispatch_key, + with_keyset=not node.unconditional_override, + allow_override=True, + ) + + +def _resolve_iterable(iterable: str | Iterable[str] | None) -> Iterable[str]: + """ + Resolve various input types to a consistent iterable of strings. + + Args: + iterable: String, iterable of strings, or None + + Returns: + Iterable[str]: Consistent iterable output + """ + if iterable is None: + return [] + + if not isinstance(iterable, Iterable) or isinstance(iterable, str): + return (iterable,) + + return iterable + + +def reenable_op_overrides( + *, + enable_dsl_names: str | list[str] | None = None, + enable_op_symbols: str | list[str] | None = None, + enable_dispatch_keys: str | list[str] | None = None, +) -> None: + """ + Re-enable overrides by removing them from filter state and reregistering. + + Args: + enable_dsl_names: DSL names to re-enable + enable_op_symbols: Operation symbols to re-enable + enable_dispatch_keys: Dispatch keys to re-enable + + Note: + This function uses reverse filter state management (removing from + filters to enable). + """ + log.info( + "Re-registering ops by dsl: %s, op_symbol: %s, dispatch_key: %s", + enable_dsl_names, + enable_op_symbols, + enable_dispatch_keys, + ) + + # Update the filters - note `remove_keys=True` because + # we are removing keys from the filters (vs. adding them) + _filter_state.update( + enable_dsl_names, + enable_op_symbols, + enable_dispatch_keys, + remove_keys=True, + ) + + # Get the set of keys that need to be reprocessed + key_set: set[tuple[str, str]] = _build_key_set( + enable_dsl_names, + enable_op_symbols, + enable_dispatch_keys, + ) + + # Process each affected graph with updated filter state + for key in key_set: + op_symbol, dispatch_key = key + + if key in _graphs: + # Note: We don't need to cleanup and recreate the library here + # since we're just updating the registration with new filter state + _register_overrides_from_graph( + op_symbol, dispatch_key, _graphs[key], filter_state=_filter_state + ) + + +def deregister_op_overrides( + *, + disable_dsl_names: str | list[str] | None = None, + disable_op_symbols: str | list[str] | None = None, + disable_dispatch_keys: str | list[str] | None = None, +) -> None: + """ + De-register overrides by updating filter state and reregistering graphs. + + Args: + disable_dsl_names: DSL names to disable + disable_op_symbols: Operation symbols to disable + disable_dispatch_keys: Dispatch keys to disable + + Note: + This function uses filter state management to selectively disable + operations. + """ + log.info( + "De-registering ops by dsl: %s, op_symbol: %s, dispatch_key: %s", + disable_dsl_names, + disable_op_symbols, + disable_dispatch_keys, + ) + + # Update filter state to disable specified entries + _filter_state.update(disable_dsl_names, disable_op_symbols, disable_dispatch_keys) + + # Get the set of keys that need to be reprocessed + key_set: set[tuple[str, str]] = _filter_state.build_disable_key_set() + + # Process each affected graph with filter state + for key in key_set: + op_symbol, dispatch_key = key + + if key in _graphs: + _cleanup_and_reregister_graph( + op_symbol, + dispatch_key, + _graphs[key], + filter_state=_filter_state, + ) + + +def get_dsl_operations(dsl_name: str) -> list[str]: + """Get list of operations registered by a specific DSL. + + Args: + dsl_name: Name of the DSL to query. + + Returns: + Sorted list of operation names registered by the DSL. + """ + operations = set() + for (op_symbol, _), nodes in _graphs.items(): + for node in nodes: + if node.dsl_name == dsl_name: + operations.add(op_symbol) + break + return sorted(operations) + + +def _update_registration_maps( + dsl_name: str, + op_symbol: str, + dispatch_key: str, + key: tuple[str, str], +) -> None: + """ + Update the registration mapping dictionaries. - return libs[(lib_symbol, dispatch_key)] + Args: + dsl_name: The DSL name + op_symbol: The operation symbol + dispatch_key: The dispatch key + key: The dictionary key tuple + """ + global _dsl_name_to_lib_graph + global _op_symbol_to_lib_graph + global _dispatch_key_to_lib_graph + + def _get_new_entry_or_append( + registration: dict[str, list[tuple[str, str]]], + symbol: str, + key: tuple[str, str], + ) -> None: + """Helper to add key to registration list or create new entry.""" + entry_list = registration.get(symbol) + + if entry_list is None: + entry_list = [key] + registration[symbol] = entry_list + else: + entry_list.append(key) + + _get_new_entry_or_append(_dsl_name_to_lib_graph, dsl_name, key) + _get_new_entry_or_append(_op_symbol_to_lib_graph, op_symbol, key) + _get_new_entry_or_append(_dispatch_key_to_lib_graph, dispatch_key, key) -def _register_op_override( +def register_op_override( + backend: str, lib_symbol: str, op_symbol: str, dispatch_key: str, impl: _OpOverrideFn | _OpReplaceFn, *, - allow_multiple_override=False, - unconditional_override=False, + allow_multiple_override: bool = False, + unconditional_override: bool = False, ) -> None: """ - Register a passed override function to the dispatcher, based on the - passed lib and op symbols, and the dispatch key. + Register a passed override function to the dispatcher. - lib_symbol: str - library yourve overriding symbols in (generally "aten") - op_symbol: str - name of the op you're overriding - dispatch_key: str - dispatch key to override - impl: Fn - implementation for the override - allow_multiple_override: bool - allow overriding an existing override - unconditional_override: bool - Impl doesn't have a fallback, and doesn't require - torch.DispatchKeySet as the first argument. + Actually a graph-building operation; real registration happens later. + + Args: + backend: The backend name (DSL name) + lib_symbol: Library you're overriding symbols in (must be "aten") + op_symbol: Name of the operation you're overriding + dispatch_key: Dispatch key to override + impl: Implementation function for the override + allow_multiple_override: Allow overriding an existing override + unconditional_override: Implementation doesn't have a fallback and + doesn't require torch.DispatchKeySet as the first argument + + Raises: + ValueError: If lib_symbol is not "aten" """ - lib = _get_library(lib_symbol, dispatch_key) + if lib_symbol != "aten": + raise ValueError(f'Unsupported lib_symbol (must be "aten", got: "{lib_symbol}"') - lib.impl( - op_symbol, - impl, - dispatch_key, - with_keyset=(not unconditional_override), - allow_override=allow_multiple_override, + key = (op_symbol, dispatch_key) + + global _graphs + op_graph = _graphs.get(key, []) + + op_graph.append( + _OverrideNode( + dsl_name=backend, + op_symbol=op_symbol, + dispatch_key=dispatch_key, + override_fn=impl, + unconditional_override=unconditional_override, + ) + ) + _graphs[key] = op_graph + # Build additional maps helpful for de-registration + _update_registration_maps(backend, op_symbol, dispatch_key, key=key) + + +def _should_reregister_graph( + original_graph: list[_OverrideNode], + new_graph: list[_OverrideNode], + *, + force_reregister: bool = False, +) -> bool: + """ + Determine if a graph needs reregistration based on changes. + + Args: + original_graph: The original graph before modification + new_graph: The graph after modification + force_reregister: If True, always reregister regardless of changes + + Returns: + bool: True if reregistration is needed + """ + if force_reregister: + return True + + # Check if the graph structure has changed + return original_graph != new_graph + + +def _cleanup_and_reregister_graph( + op_symbol: str, + dispatch_key: str, + graph: list[_OverrideNode], + *, + filter_state: _FilterState | None = None, +) -> None: + """ + Clean up existing library and reregister a graph. + + This is the common pattern used across reorder, deregister, and reenable operations. + + Args: + op_symbol: The operation symbol + dispatch_key: The dispatch key + graph: The graph to register + filter_state: Optional filter state for conditional registration + """ + key = (op_symbol, dispatch_key) + + # Remove existing library if it exists + if key in _libs: + del _libs[key] + + # Only create a library if the graph has nodes + # Empty graphs (disabled operations) shouldn't get libraries + if graph: + _register_overrides_from_graph( + op_symbol, + dispatch_key, + graph, + filter_state=filter_state, + ) + + +def _apply_graph_transformation( + transformation_fn: UserOrderingFn, + *, + keys_to_process: set[tuple[str, str]] | None = None, + reregister_overrides: bool = False, + filter_state: _FilterState | None = None, +) -> None: + """ + Apply a transformation function to graphs and optionally reregister. + + This is the core pattern used by reorder_graphs_from_user_function and + can be reused for other graph transformation operations. + + Args: + transformation_fn: Function to transform each graph + keys_to_process: Keys to process, or None for all graphs + reregister_overrides: Whether to reregister changed graphs + filter_state: Optional filter state for conditional registration + + Note: + If transformation_fn raises an exception for a specific graph, that graph + will be skipped and processing will continue with remaining graphs. + """ + global _graphs + + # Determine which graphs to process + target_keys = ( + keys_to_process if keys_to_process is not None else set(_graphs.keys()) + ) + + # Process each graph + for op_symbol, dispatch_key in list(target_keys): + if (op_symbol, dispatch_key) not in _graphs: + continue # Skip if graph doesn't exist + + original_graph = list(_graphs[(op_symbol, dispatch_key)]) + + # Apply the transformation with error handling + try: + new_graph = transformation_fn(op_symbol, dispatch_key, original_graph) + except (TypeError, ValueError, AttributeError, RuntimeError): + log.warning( + "Graph transformation failed for %s/%s. Preserving original graph.", + op_symbol, + dispatch_key, + exc_info=True, + ) + continue + except Exception: + log.exception( + "Unexpected error in graph transformation for %s/%s. Preserving original graph.", + op_symbol, + dispatch_key, + ) + continue + + # Validate that the transformation returned a valid result + if not isinstance(new_graph, list): + log.warning( + "Graph transformation returned invalid type %s for %s/%s. Expected list. Preserving original graph.", + type(new_graph).__name__, + op_symbol, + dispatch_key, + ) + continue + + # Update the graph + _graphs[(op_symbol, dispatch_key)] = new_graph + + # Reregister if needed + if reregister_overrides and _should_reregister_graph( + original_graph, new_graph, force_reregister=False + ): + _cleanup_and_reregister_graph( + op_symbol, + dispatch_key, + new_graph, + filter_state=filter_state, + ) + + +def _register_overrides_from_graph( + op_symbol: str, + dispatch_key: str, + graph: list[_OverrideNode], + *, + filter_state: _FilterState | None = None, +) -> None: + """ + Register all overrides in a single graph. + + Args: + op_symbol: The operation symbol + dispatch_key: The dispatch key + graph: List of override nodes to register + filter_state: Optional filter state for conditional registration + """ + key = (op_symbol, dispatch_key) + lib = _get_or_create_library(*key) + + for node in graph: + enable = True + if filter_state: + enable = filter_state.check_enabled(node) + + if enable: + _register_node_impl(lib, node, dispatch_key) + node.active = True + else: + node.active = False + + +def _register_all_overrides() -> None: + """ + Perform all registration calls from previously-built override graphs. + """ + for key, graph in _graphs.items(): + op_symbol, dispatch_key = key + + _register_overrides_from_graph( + op_symbol, + dispatch_key, + graph, + ) + + +def reorder_graphs_from_user_function( + fn: UserOrderingFn, + *, + reregister_overrides: bool = False, +) -> None: + """ + Reorder override graphs using a user-provided ordering function. + + Args: + fn: User-provided function that takes (op_symbol, dispatch_key, graph) + and returns a reordered graph + reregister_overrides: Whether to reregister graphs that have changed + + Note: + This function uses the common graph transformation pattern and can serve + as an example for other graph manipulation operations. + """ + _apply_graph_transformation( + transformation_fn=fn, + reregister_overrides=reregister_overrides, + ) + + +def _apply_graph_filter( + filter_fn: Callable[[str, str, _OverrideNode], bool], + *, + reregister_overrides: bool = False, +) -> None: + """ + Apply a filter function to remove nodes from graphs. + + This is a convenience function that uses the graph transformation pattern + to filter out unwanted nodes. + + Args: + filter_fn: Function that takes (op_symbol, dispatch_key, node) and + returns True to keep the node, False to remove it + reregister_overrides: Whether to reregister modified graphs + + Example: + # Remove all nodes with "deprecated" in the DSL name + _apply_graph_filter( + lambda op, dk, node: "deprecated" not in node.dsl_name, + reregister_overrides=True + ) + + Note: + If filter_fn raises an exception for a specific graph, the original + graph will be preserved and processing will continue. + """ + + def filtering_transformation( + op_symbol: str, dispatch_key: str, graph: list[_OverrideNode] + ) -> list[_OverrideNode]: + """Apply filter_fn to graph with error handling.""" + try: + return [node for node in graph if filter_fn(op_symbol, dispatch_key, node)] + except (TypeError, ValueError, AttributeError, RuntimeError): + log.warning( + "Graph transformation failed for %s/%s. Preserving original graph.", + op_symbol, + dispatch_key, + exc_info=True, + ) + return graph + except Exception: + log.exception( + "Unexpected error in graph transformation for %s/%s. Preserving original graph.", + op_symbol, + dispatch_key, + ) + return graph + + _apply_graph_transformation( + transformation_fn=filtering_transformation, + reregister_overrides=reregister_overrides, + ) + + +def _apply_selective_reordering( + condition_fn: Callable[[str, str], bool], + ordering_fn: UserOrderingFn, + *, + reregister_overrides: bool = False, +) -> None: + """ + Apply reordering only to graphs that match a condition. + + This allows for more targeted reordering operations. + + Args: + condition_fn: Function that takes (op_symbol, dispatch_key) and + returns True if the graph should be reordered + ordering_fn: Ordering function to apply to matching graphs + reregister_overrides: Whether to reregister modified graphs + + Example: + # Only reorder CUDA operations + _apply_selective_reordering( + condition_fn=lambda op, dk: dk == "CUDA", + ordering_fn=lambda op, dk, g: sorted(g, key=lambda n: n.dsl_name), + reregister_overrides=True + ) + + Note: + If condition_fn or ordering_fn raises an exception for a specific graph, + the original graph will be preserved and processing will continue. + """ + + def conditional_transformation( + op_symbol: str, dispatch_key: str, graph: list[_OverrideNode] + ) -> list[_OverrideNode]: + """Apply ordering_fn conditionally based on condition_fn result.""" + try: + should_reorder = condition_fn(op_symbol, dispatch_key) + except (TypeError, ValueError, AttributeError, RuntimeError): + log.warning( + "Graph transformation failed for %s/%s. Preserving original graph.", + op_symbol, + dispatch_key, + exc_info=True, + ) + return graph + except Exception: + log.exception( + "Unexpected error in graph transformation for %s/%s. Preserving original graph.", + op_symbol, + dispatch_key, + ) + return graph + + if should_reorder: + try: + return ordering_fn(op_symbol, dispatch_key, graph) + except (TypeError, ValueError, AttributeError, RuntimeError): + log.warning( + "Graph transformation failed for %s/%s. Preserving original graph.", + op_symbol, + dispatch_key, + exc_info=True, + ) + return graph + except Exception: + log.exception( + "Unexpected error in graph transformation for %s/%s. Preserving original graph.", + op_symbol, + dispatch_key, + ) + return graph + + return graph # Return unchanged if condition doesn't match + + _apply_graph_transformation( + transformation_fn=conditional_transformation, + reregister_overrides=reregister_overrides, ) diff --git a/torch/_native/triton_utils.py b/torch/_native/triton_utils.py index d8bc431152189..f875e9aacc303 100644 --- a/torch/_native/triton_utils.py +++ b/torch/_native/triton_utils.py @@ -1,20 +1,29 @@ import functools import logging +import sys +from typing import cast -from packaging.version import Version +from torch._vendor.packaging.version import Version +from ..backends import cuda as _cuda from .common_utils import ( _available_version, _unavailable_reason, check_native_jit_disabled, check_native_version_skip, ) -from .registry import _OpFn, _register_op_override +from .dsl_registry import dsl_registry, DSLModuleProtocol +from .registry import ( + _OpFn, + deregister_op_overrides as _deregister_op_overrides_impl, + register_op_override as _register_op_override_impl, +) log = logging.getLogger(__name__) +_TRITON_DSL_NAME = "triton" _TRITON_REQUIRED_VERSION_MAJOR = 3 _TRITON_MINIMUM_VERSION_MINOR = 6 @@ -26,6 +35,9 @@ def _check_runtime_available() -> tuple[bool, Version | None]: NOTE: must not import at this point """ + # Skip all checks if running on CPU-only binary + if not _cuda.is_built(): + return (False, None) deps = [ ("triton", "triton"), @@ -35,7 +47,7 @@ def _check_runtime_available() -> tuple[bool, Version | None]: available = True version = _available_version("triton") else: - log.info("triton native DSL ops require: `triton` %s", reason) + log.warning("triton native DSL ops require: `triton` %s", reason) available = False version = None return available, version @@ -65,7 +77,7 @@ def _version_is_sufficient() -> bool: if (major_ok and minor_ok) or check_native_version_skip(): return True - log.info( + log.warning( "triton version %s is not sufficient (>= (%s.%s.*)); " "set TORCH_NATIVE_SKIP_VERSION_CHECK=1 to override", version, @@ -75,6 +87,13 @@ def _version_is_sufficient() -> bool: return False +def deregister_op_overrides() -> None: + """ + Deregister all ops through triton + """ + _deregister_op_overrides_impl(disable_dsl_names=_TRITON_DSL_NAME) + + def register_op_override( lib_symbol: str, op_symbol: str, @@ -87,7 +106,7 @@ def register_op_override( """ See torch/_native/registry.py for the underlying implementation and arguments. This is a thin, DSL-checking wrapper over - _register_op_override + _register_op_override_impl """ available, version = _check_runtime_available() if (not available) or check_native_jit_disabled(): @@ -96,7 +115,8 @@ def register_op_override( if not _version_is_sufficient(): return - _register_op_override( + _register_op_override_impl( + _TRITON_DSL_NAME, lib_symbol, op_symbol, dispatch_key, @@ -104,3 +124,8 @@ def register_op_override( allow_multiple_override=allow_multiple_override, unconditional_override=unconditional_override, ) + + +# Register this DSL module with the registry +# Note: Import-time registration ensures DSL is available when module is loaded +dsl_registry.register_dsl("triton", cast(DSLModuleProtocol, sys.modules[__name__])) diff --git a/torch/_numpy/_dtypes.py b/torch/_numpy/_dtypes.py index 134f7617b758a..19f5f991d7e6c 100644 --- a/torch/_numpy/_dtypes.py +++ b/torch/_numpy/_dtypes.py @@ -447,6 +447,6 @@ def str_to_abstract(t): __all__ = ["dtype", "DType", "typecodes", "issubdtype", "set_default_dtype", "sctypes"] -__all__ += list(_names.keys()) # noqa: PLE0605 -__all__ += list(_name_aliases.keys()) # noqa: PLE0605 +__all__ += list(_names.keys()) +__all__ += list(_name_aliases.keys()) __all__ += _abstract_dtypes # noqa: PLE0605 diff --git a/torch/_numpy/_ndarray.py b/torch/_numpy/_ndarray.py index d024f98588d2d..e55c119373c86 100644 --- a/torch/_numpy/_ndarray.py +++ b/torch/_numpy/_ndarray.py @@ -415,6 +415,10 @@ def imag(self): def imag(self, value): self.tensor.imag = asarray(value).tensor + @property + def flat(self): + return self.ravel() + # ctors def astype(self, dtype, order="K", casting="unsafe", subok=True, copy=True): if order != "K": diff --git a/torch/_opaque_base.py b/torch/_opaque_base.py index d08048713e3d0..dfb7949c0bc30 100644 --- a/torch/_opaque_base.py +++ b/torch/_opaque_base.py @@ -1,13 +1,35 @@ +# Cached lazily on first __instancecheck__ miss to avoid an import cycle at +# module load (FakeScriptObject's module imports torch, which imports us). +_FakeScriptObject_cls: type | None = None + + class OpaqueBaseMeta(type): def __instancecheck__(cls, instance): + # When checking against OpaqueBase itself (not a concrete subclass), + # delegate to the registration system which correctly covers all + # opaque types (value types, metaclass-only reference types, and + # FakeScriptObject wrappers). + if cls is OpaqueBase: + from torch._library.opaque_object import is_opaque_value + + return is_opaque_value(instance) + if super().__instancecheck__(instance): return True - if hasattr(instance, "real_obj"): + # Check FakeScriptObject before hasattr to avoid triggering custom + # __getattr__ on arbitrary user objects (e.g. dict-like objects that + # raise KeyError on unknown attributes). + # e.g. test/dynamo/test_dynamic_shapes.py -k test_user_getattr1_dynamic_shapes + global _FakeScriptObject_cls + if _FakeScriptObject_cls is None: from torch._library.fake_class_registry import FakeScriptObject - if isinstance(instance, FakeScriptObject): - return super().__instancecheck__(instance.real_obj) + _FakeScriptObject_cls = FakeScriptObject + if isinstance(instance, _FakeScriptObject_cls) and hasattr( + instance, "real_obj" + ): + return super().__instancecheck__(instance.real_obj) return False diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py index f2831a56ecf73..6db0499074565 100644 --- a/torch/_prims/__init__.py +++ b/torch/_prims/__init__.py @@ -2467,7 +2467,7 @@ def _iota_aten( iota = _make_prim( - schema="iota(SymInt length, *, SymInt start, SymInt step, ScalarType dtype, Device device, bool requires_grad) -> Tensor", # noqa: B950 + schema="iota(SymInt length, *, SymInt start, SymInt step, ScalarType dtype, Device device, bool requires_grad) -> Tensor", return_type=RETURN_TYPE.NEW, meta=_iota_meta, impl_aten=_iota_aten, @@ -2575,7 +2575,7 @@ def _empty_permuted_meta( # TODO: add layout, pin_memory empty_permuted = _make_prim( - schema="empty_permuted(SymInt[] shape, int[] physical_layout, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", # noqa: B950 + schema="empty_permuted(SymInt[] shape, int[] physical_layout, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor", return_type=RETURN_TYPE.NEW, meta=_empty_permuted_meta, impl_aten=torch.empty_permuted, @@ -2832,7 +2832,7 @@ def _normal_aten( normal = _make_prim( schema=( - "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad, Generator? generator=None) -> Tensor" # noqa: B950 + "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad, Generator? generator=None) -> Tensor" ), return_type=RETURN_TYPE.NEW, meta=_normal_meta, diff --git a/torch/_prims/rng_prims.py b/torch/_prims/rng_prims.py index 60039060769fb..525607c842c1c 100644 --- a/torch/_prims/rng_prims.py +++ b/torch/_prims/rng_prims.py @@ -78,7 +78,7 @@ def philox_rand_offset( def register_philox_rand(): name = "philox_rand" - schema = "(SymInt[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> (Tensor, Tensor)" # noqa: B950 + schema = "(SymInt[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> (Tensor, Tensor)" def _philox_rand_meta( shape: torch.Size, @@ -394,6 +394,18 @@ def impl_functional(ctx, op, *args, rng_state=None, **kwargs): graphsafe_run_with_rng_state = register_graphsafe_run_with_rng_state_op() +torch._library.opaque_object.register_opaque_type( + torch._C.Generator, + typ="reference", + guard_fn=lambda gen: [gen.device], + members={ + "device": torch._library.opaque_object.MemberType.USE_REAL, + "__eq__": torch._library.opaque_object.MemberType.USE_REAL, + "__ne__": torch._library.opaque_object.MemberType.USE_REAL, + }, +) + + def register_run_dtensor_rng_op(): """ Register a higher-order operator for DTensor distributed random operations. diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index 8caa7fb32056b..17083fb30a9f4 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -20,6 +20,7 @@ TensorLikeType, ) from torch.utils import _pytree as pytree +from torch.utils._inspect import _fast_bind from torch.utils._pytree import tree_flatten, tree_unflatten @@ -129,7 +130,7 @@ def __call__(self, fn: Callable) -> Callable: @torch._disable_dynamo @wraps(fn) def _fn(*args, **kwargs): - bound = sig.bind(*args, **kwargs) + bound = _fast_bind(sig, *args, **kwargs) type_promoting_args = tuple( bound.arguments[x] for x in self.type_promoting_arg_names # type: ignore[union-attr] diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index d3a18043b0d49..7c089a68a57ec 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -50,6 +50,7 @@ elementwise_unary_scalar_wrapper, out_wrapper, ) +from torch.testing._internal.common_dtype import highest_precision_float # Experimental module containing prototype Python references for existing @@ -387,9 +388,9 @@ def handle_noncontiguous_outputs(input_tlist, output): def _broadcast_shapes(*_shapes): from torch.fx.experimental.symbolic_shapes import ( guard_or_false, - has_hint, + guarding_hint_or_throw, + has_guarding_hint, is_nested_int, - size_hint, ) backed_so = torch.fx.experimental._config.backed_size_oblivious @@ -433,9 +434,13 @@ def _broadcast_shapes(*_shapes): # specialize(s0) to be 1. # s0:4, s1:1 ==> # specialize(s1) to be 1. - if backed_so and has_hint(shape[idx]) and has_hint(common_shape[idx]): - a = size_hint(shape[idx]) - b = size_hint(common_shape[idx]) + if ( + backed_so + and has_guarding_hint(shape[idx]) + and has_guarding_hint(common_shape[idx]) + ): + a = guarding_hint_or_throw(shape[idx]) + b = guarding_hint_or_throw(common_shape[idx]) if a == 1 and b != 1: torch._check(shape[idx] == 1) if b == 1 and a != 1: @@ -1184,6 +1189,16 @@ def _ref( return inner +def _binary_op_dtype( + a: TensorLikeType | NumberType, b: TensorLikeType | NumberType +) -> torch.dtype: + if isinstance(a, TensorLike): + return a.dtype + if isinstance(b, TensorLike): + return b.dtype + return utils.type_to_dtype(type(a)) + + # Add has its own implementation because it has an alpha argument @register_decomposition(aten.add) @out_wrapper() @@ -1204,7 +1219,7 @@ def add( a, b = _maybe_broadcast(a, b) if alpha is not None: - dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr] + dtype = _binary_op_dtype(a, b) python_type = utils.dtype_to_type(dtype) if python_type is not bool and not utils.is_weakly_lesser_type( type(alpha), python_type @@ -1854,7 +1869,7 @@ def sub( ) if alpha != 1: - dtype = a.dtype if isinstance(a, TensorLike) else b.dtype # type: ignore[union-attr] + dtype = _binary_op_dtype(a, b) python_type = utils.dtype_to_type(dtype) if not utils.is_weakly_lesser_type(type(alpha), python_type): msg = f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!" @@ -2039,7 +2054,11 @@ def clamp_max( # https://pytorch.org/docs/stable/generated/torch.where.html -# TODO: implement where.default +@register_decomposition(aten.where.default) +def _where_default(pred: Tensor) -> tuple[Tensor, ...]: + return torch.nonzero(pred, as_tuple=True) + + @register_decomposition(aten.where.self) @register_decomposition(aten.where.ScalarSelf) @register_decomposition(aten.where.ScalarOther) @@ -3116,8 +3135,8 @@ def dstack(tensors: TensorSequenceType) -> TensorLikeType: def expand(a: Tensor, *shape, implicit: bool = False) -> Tensor: from torch.fx.experimental.symbolic_shapes import ( guard_or_false, - has_hint, - size_hint, + guarding_hint_or_throw, + has_guarding_hint, sym_or, ) @@ -3161,9 +3180,13 @@ def expand(a: Tensor, *shape, implicit: bool = False) -> Tensor: # The non-broadcast path is picked # x:1, requested_length:4 ==> # specialize(x) to be 1. - if backed_so and has_hint(x) and has_hint(requested_length): - x_hint = size_hint(x) - requested_hint = size_hint(requested_length) + if ( + backed_so + and has_guarding_hint(x) + and has_guarding_hint(requested_length) + ): + x_hint = guarding_hint_or_throw(x) + requested_hint = guarding_hint_or_throw(requested_length) if x_hint == 1 and requested_hint != 1: torch._check(x == 1) @@ -3453,6 +3476,10 @@ def native_layer_norm( + ", but got input of size " + str(input.shape), ) + torch._check( + not input.is_complex(), + lambda: "native_layer_norm does not support complex inputs", + ) input = contiguous(input) if weight is not None: @@ -5443,13 +5470,13 @@ def linspace( start.dim() == 0, lambda: "linspace only supports 0-dimensional start and end tensors", ) - start = _maybe_convert_to_dtype(start, torch.float64) + start = _maybe_convert_to_dtype(start, highest_precision_float(device)) if isinstance(end, TensorLikeType): torch._check( end.dim() == 0, lambda: "linspace only supports 0-dimensional start and end tensors", ) - end = _maybe_convert_to_dtype(end, torch.float64) + end = _maybe_convert_to_dtype(end, highest_precision_float(device)) if builtins.any(isinstance(arg, complex) for arg in (start, end, steps)): default_complex_dtype = utils.corresponding_complex_dtype( diff --git a/torch/_refs/fft.py b/torch/_refs/fft.py index d4dc72d3d47d2..8de041aa56b22 100644 --- a/torch/_refs/fft.py +++ b/torch/_refs/fft.py @@ -67,7 +67,7 @@ def _promote_type_fft( dtype = torch.get_default_dtype() allowed_types = [torch.float32, torch.float64] - maybe_support_half = device.type in ["cuda", "meta"] + maybe_support_half = device.type in ["cuda", "meta", "xpu"] if maybe_support_half: allowed_types.append(torch.float16) diff --git a/torch/_subclasses/_fake_tensor_utils.py b/torch/_subclasses/_fake_tensor_utils.py index a20daa1a311a8..61a71a5a1d3c7 100644 --- a/torch/_subclasses/_fake_tensor_utils.py +++ b/torch/_subclasses/_fake_tensor_utils.py @@ -35,9 +35,10 @@ def from_node(node: SymNode) -> _DeconstructedSymNode: return _DeconstructedSymNode( node._expr, node.pytype, + # pyrefly: ignore[bad-argument-type] node._hint, node.constant, - # pyrefly: ignore [bad-argument-type] + # pyrefly: ignore[bad-argument-type] node.fx_node, ) diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 64445cf2e9f57..e7e66bdd57acd 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -426,9 +426,9 @@ def _unique( if dim is None: if unique_consecutive: - arg.unique_consecutive_memo = nnz + arg.unique_consecutive_memo = nnz # pyrefly: ignore[bad-assignment] else: - arg.unique_memo = nnz + arg.unique_memo = nnz # pyrefly: ignore[bad-assignment] if dim is None: # pyrefly: ignore[no-matching-overload] @@ -439,15 +439,19 @@ def _unique( return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu") if return_inverse or return_if_dim_and_cpu: - inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],)) + inverse = arg.new_empty( + arg.shape if dim is None else (arg.shape[dim],), dtype=torch.int64 + ) else: - inverse = arg.new_empty(0) + inverse = arg.new_empty(0, dtype=torch.int64) ret.append(inverse) if return_counts or return_if_dim_and_cpu: - counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],)) + counts = arg.new_empty( + ret[0].shape if dim is None else (ret[0].shape[dim],), dtype=torch.int64 + ) else: - counts = arg.new_empty(0) + counts = arg.new_empty(0, dtype=torch.int64) ret.append(counts) return tuple(ret) @@ -641,14 +645,14 @@ def maybe_guard_or_true(x: Any) -> Any: def _view_has_unbacked_input( a: torch.Tensor, shape: ShapeType | tuple[ShapeType] ) -> bool: - from torch.fx.experimental.symbolic_shapes import has_hint + from torch.fx.experimental.symbolic_shapes import has_guarding_hint shape = utils.extract_shape_from_varargs(shape, validate=False) return ( - any(not has_hint(s) for s in a.size()) - or any(not has_hint(s) for s in a.stride()) - or any(not has_hint(s) for s in shape) + any(not has_guarding_hint(s) for s in a.size()) + or any(not has_guarding_hint(s) for s in a.stride()) + or any(not has_guarding_hint(s) for s in shape) ) @@ -905,7 +909,7 @@ def nonzero(fake_mode: FakeTensorMode, func: OpOverload, arg: FakeTensor) -> Fak _constrain_range_for_size(nnz, max=maxval) - arg.nonzero_memo = nnz + arg.nonzero_memo = nnz # pyrefly: ignore[bad-assignment] return arg.new_empty_strided((nnz, arg.dim()), (1, nnz), dtype=torch.int64) # type: ignore[return] @@ -961,6 +965,11 @@ def _compute_slice_index(size: IntLikeType, index: IntLikeType) -> IntLikeType | return 0 elif guard_or_false(index > size): return size + elif guard_or_false(index >= 0): + return torch.sym_min(index, size) + elif guard_or_false(index < 0): + return torch.sym_max(index + size, 0) + return None @@ -1005,6 +1014,12 @@ def slice_forward( new_size = (end_index - start_index + step - 1) // step elif guard_or_false(start_index >= end_index): new_size = 0 + else: + # Both indices are resolved but we can't statically determine their + # ordering (e.g., when they involve Min/Max). Compute the size via + # max(end - start, 0) to avoid creating an unbacked symint. + diff = torch.sym_max(end_index - start_index, 0) + new_size = (diff + step - 1) // step # create unbacked if case unknown if new_size is None: @@ -1329,17 +1344,19 @@ def conv( _, new_kwargs = _normalize_function_or_error( func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True ) - device = new_kwargs["input"].fake_device + input_ = new_kwargs["input"] + weight = new_kwargs["weight"] + device = input_.fake_device # need to re-enable mode so the tensors report fake device with fake_mode: - # if the input is unsqueezed is done in Convolution.cpp we get segfault - k = new_kwargs["weight"].ndim + # if the input is unsqueezed in Convolution.cpp we get segfault + k = weight.ndim # Avoid importing sympy at a module level - from torch.fx.experimental.symbolic_shapes import has_hint + from torch.fx.experimental.symbolic_shapes import has_guarding_hint - all_hinted = all(has_hint(s) for s in new_kwargs["input"].shape) and all( - has_hint(s) for s in new_kwargs["weight"].shape + all_hinted = all(has_guarding_hint(s) for s in input_.shape) and all( + has_guarding_hint(s) for s in weight.shape ) if not all_hinted: @@ -1347,53 +1364,33 @@ def conv( # channels last detection (but only if it's statically obvious!) mem_fmt = None else: - if func is aten.convolution.default: - conv_backend = torch._C._select_conv_backend(**new_kwargs) - else: - conv_backend = torch._C._select_conv_backend( - new_kwargs["input"], - new_kwargs["weight"], - bias=None, - stride=new_kwargs["stride"], - padding=new_kwargs["padding"], - dilation=new_kwargs["dilation"], - transposed=new_kwargs["transposed"], - output_padding=new_kwargs["output_padding"], - groups=new_kwargs["groups"], - bias_sizes=new_kwargs["bias_sizes"], - ) + # convolution has "bias" but not "bias_sizes"; convolution_backward + # has "bias_sizes" but not "bias". .get() handles both with one call. + bias = new_kwargs.get("bias") + select_kwargs: dict[str, object] = dict( + stride=new_kwargs["stride"], + padding=new_kwargs["padding"], + dilation=new_kwargs["dilation"], + transposed=new_kwargs["transposed"], + output_padding=new_kwargs["output_padding"], + groups=new_kwargs["groups"], + bias=bias, + ) + if bias is None: + select_kwargs["bias_sizes"] = new_kwargs.get("bias_sizes") + conv_backend = torch._C._select_conv_backend( + input_, weight, **select_kwargs + ) # Expand 1d -> 2d. # Note: Avoid expanding before calling _select_conv_backend, # as the function handles 2D expansion internally. - if ( - k == 3 - and not new_kwargs["input"].is_mkldnn - and not new_kwargs["input"].is_xpu - ): + if k == 3 and not input_.is_mkldnn and not input_.is_xpu: # Note: Using input.to(memory_format=contiguous) does not work. - new_kwargs["input"] = new_kwargs["input"].contiguous().unsqueeze(2) - new_kwargs["weight"] = new_kwargs["weight"].unsqueeze(2) - if len(new_kwargs["stride"]) == 1: - new_kwargs["stride"].insert(0, 1) - new_kwargs["padding"].insert(0, 0) - new_kwargs["dilation"].insert(0, 1) - new_kwargs["output_padding"].insert(0, 0) + input_ = input_.contiguous().unsqueeze(2) + weight = weight.unsqueeze(2) mem_fmt = torch._C._conv_determine_backend_memory_format( - new_kwargs["input"], new_kwargs["weight"], conv_backend + input_, weight, conv_backend ) - # revert 2d -> 1d - if ( - k == 3 - and not new_kwargs["input"].is_mkldnn - and not new_kwargs["input"].is_xpu - ): - new_kwargs["input"] = new_kwargs["input"].squeeze(2) - new_kwargs["weight"] = new_kwargs["weight"].squeeze(2) - if len(new_kwargs["stride"]) == 2: - new_kwargs["stride"].pop(0) - new_kwargs["padding"].pop(0) - new_kwargs["dilation"].pop(0) - new_kwargs["output_padding"].pop(0) def convert( t: torch.Tensor | None, mem_fmt: torch.memory_format | None diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 48306f93eae26..a630a72aa6d8c 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -34,7 +34,7 @@ is_sparse_compressed, MetaConverter, ) -from torch._utils import render_call +from torch._utils import _is_privateuse1_backend_available, render_call from torch.fx.immutable_collections import immutable_dict from torch.fx.operator_schemas import normalize_function from torch.multiprocessing.reductions import StorageWeakRef @@ -44,6 +44,7 @@ from torch.utils._python_dispatch import ( is_traceable_wrapper_subclass, TorchDispatchMode, + TraceableWrapperSubclass, ) from torch.utils._pytree import KeyPath, keystr, PyTree, tree_map, tree_map_, TreeSpec from torch.utils._stats import count @@ -57,7 +58,6 @@ from types import TracebackType from torch._guards import Source - from torch._library.opaque_object import OpaqueType from torch._ops import OpOverload from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext @@ -183,8 +183,10 @@ def disable_fake_tensor_cache(fake_mode: FakeTensorMode) -> Generator[None, None def get_plain_tensors( - subclass: Tensor, *, out: list[Tensor | int | SymInt | OpaqueType] -) -> list[Tensor | int | SymInt | OpaqueType]: + subclass: Tensor | TraceableWrapperSubclass, + *, + out: list[Tensor | int | SymInt | OpaqueBase], +) -> list[Tensor | int | SymInt | OpaqueBase]: # This function is used in Runtime, do not add redundant asserts todo = [subclass] while todo: @@ -457,6 +459,25 @@ def mk_fake_tensor( if out is NotImplemented: raise UnsupportedFakeTensorException("meta converter nyi") + # Propagate grad_dtype here rather than in meta_converter because + # meta tensors don't carry autograd metadata. + # Unwrap FunctionalTensor because accessing is_leaf/grad_fn on a + # FunctionalTensor view whose base was mutated (e.g. via set_()) + # triggers lazy view replay through __torch_dispatch__, which + # errors without an active FunctionalTensorMode. + inner_t = ( + torch._from_functional_tensor(t.elem) + if isinstance(t, torch._subclasses.functional_tensor.FunctionalTensor) + else t + ) + if ( + inner_t.requires_grad + and inner_t.is_leaf + and inner_t.grad_dtype != inner_t.dtype + and out.is_leaf + ): + out.grad_dtype = inner_t.grad_dtype + from torch._dynamo.source import RandomValueSource value = None @@ -665,7 +686,9 @@ def __get__( return r def __set__( - self, obj: FakeTensor, value: torch.SymInt | torch.SymFloat | None + self, + obj: FakeTensor, + value: torch.SymInt | torch.SymFloat | torch.SymBool | int | float | None, ) -> None: if value is None: setattr(obj, self._memo(obj), None) @@ -965,16 +988,25 @@ def _find_common_device( aten._foreach_copy.default, ) + # These in-place ops keep the destination tensor's device even if the + # rhs was explicitly constructed on meta. + meta_rhs_mixed_device_fns = ordered_set( + aten.add_.Tensor, + ) + # list of ops not using zero dim cpu tensor logic to align with the eager mode. bypass_zero_dim_cpu_tensor_check_ops = ordered_set( aten.nextafter.default, ) - def check_cpu_device(device: torch.device) -> bool: + def is_device_cpu(device: torch.device) -> bool: return device.type == "cpu" + def is_device_meta(device: torch.device) -> bool: + return device.type == "meta" + def cpu_zero_dim(t: Tensor) -> bool: - return check_cpu_device(t.device) and t.dim() == 0 + return is_device_cpu(t.device) and t.dim() == 0 def merge_devices(t: object) -> None: nonlocal common_device @@ -1013,7 +1045,11 @@ def merge_devices(t: object) -> None: # device must be cpu in this case we will return from here without # throwing an error if func in mixed_device_fns: - if any(map(check_cpu_device, (common_device, t.device))): + if any(map(is_device_cpu, (common_device, t.device))): + return + + if func in meta_rhs_mixed_device_fns: + if any(map(is_device_meta, (common_device, t.device))): return # if prefer_device_type is set, prefer that device type over others @@ -1424,6 +1460,7 @@ def avoid_device_init(self) -> bool: return not ( torch.cuda.is_available() or (hasattr(torch, "hpu") and torch.hpu.is_available()) + or _is_privateuse1_backend_available() ) @property @@ -2321,7 +2358,7 @@ def _check_fake_real_vals(fake: Any, real: Any) -> None: "mismatched_fake_kernel", metadata_fn=lambda: { "op": str(func), - "reason": f"mismatch between fake value {fake} and real value {real}", # noqa: F821 + "reason": f"mismatch between fake value {fake} and real value {real}", }, ) return _infer_fake_from_real_tensor(self, func, real), True # type: ignore[arg-type] @@ -2665,7 +2702,7 @@ def maybe_to_real_tensor( # we shouldn't broadly catch all errors here; # some come from real-kernel mutation/aliasing checks we want to run. # add more exception types as needed. - log.debug( # noqa: G200 + log.debug( "real-tensor fallback failed for %s: %s; silently ignoring", func, exc, @@ -2728,9 +2765,7 @@ def go(t: object, real_t: Tensor) -> None: "self.shape_env must not be None for symbolic Eq" ) - self.shape_env.set_real_tensor_prop_unbacked_vals( - s, int(real_t) - ) + self.shape_env.set_real_tensor_prop_unbacked_vals(s, real_t) if real_out is not nil: # cross check fake/real outputs, and optionally override fake kernel mismatches @@ -2891,18 +2926,23 @@ def go(t: object, real_t: Tensor) -> None: # and then afterwards wrapping them to a FakeTensor for run_impl_check, op_impl in op_implementations_checks: if run_impl_check(func): + # pyrefly: ignore [bad-argument-count] op_impl_out = op_impl(self, func, *args, **kwargs) if op_impl_out is not NotImplemented: + # pyrefly: ignore [bad-return] return maybe_propagate_real_tensors(op_impl_out) def maybe_run_unsafe_fallback( error: RuntimeError | None = None, ) -> FakeTensor | None: - # We infer the meta of a custom ops that return None to just - # return None. custom ops are not allowed to mutate metadata - # of their inputs, so this is safe. + # We infer the meta of custom ops that return None to just + # return None, and Tag.out ops to return their out= args. + # Custom ops are not allowed to mutate metadata of their + # inputs, so this is safe. if torch._library.utils.can_generate_trivial_fake_impl(func): - return None + return torch._library.utils.generate_trivial_fake_impl( + func, *args, **kwargs + ) # no meta kernel registered, fallback to kernel for the device if has_symbolic_sizes or not self.can_run_unsafe_fallback(func): raise UnsupportedOperatorException(func) @@ -3063,7 +3103,7 @@ def wrap(e: T) -> T | FakeTensor: def create_symbolic_nested_int( self, *, nt_tensor_id: int | None = None - ) -> torch.SymInt: + ) -> IntLikeType: # See Note: [Creating symbolic nested int] # Returned nested int always has coeff=1; multiply the result by coeff if needed import torch.nested._internal.nested_tensor diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index ffa2b4194caa5..da0d815b65c2e 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +import functools import warnings import weakref from abc import ABC, abstractmethod @@ -35,6 +36,66 @@ not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") +def _has_unrecognized_tensor_types(types: Sequence[type]) -> bool: + unrecognized_types = [ + t + for t in types + if t not in (torch.Tensor, torch._subclasses.FakeTensor, FunctionalTensor) + ] + if unrecognized_types: + not_implemented_log.debug( + "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types + ) + return bool(unrecognized_types) + + +@functools.lru_cache(maxsize=512) +def _can_decompose_fast( + func: OpOverload, export: bool, pre_dispatch: bool +) -> bool | None: + """Fast path for _can_decompose that depends only on (func, export, pre_dispatch). + + Returns True/False for a definitive answer, or None to fall through + to the slow path (autograd_would_have_decomposed). + """ + if export and func is torch.ops.aten.dropout.default: + return False + + from torch._decomp import _should_decompose_because_unsafe_op + + if _should_decompose_because_unsafe_op(func): + return True + + alias_info_present = any(arg.alias_info for arg in func._schema.arguments) + if alias_info_present or func._schema.is_mutable: + return True + + if export: + if pre_dispatch: + if func.namespace not in ("aten", "prim") and func._can_decompose(): + warnings.warn( + f"At pre-dispatch tracing, we assume that any custom op marked with " + f"CompositeImplicitAutograd and have functional schema are safe to not decompose. " + f"Found {func} to be one such op.", + stacklevel=3, + ) + return False + return True + + return None + + +def _assert_functionalize_not_active(msg: str) -> None: + is_included = torch._C._dispatch_tls_is_dispatch_key_included( + torch._C.DispatchKey.Functionalize + ) + is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded( + torch._C.DispatchKey.Functionalize + ) + if not is_excluded and is_included: + raise AssertionError(msg) + + # NOTE Some special handling for tensor conversion during export is needed. # Normally, when tracing through the model with tensor.to(), the maybe-aliasing # relationship between input and output tensors will be baked into the graph. @@ -66,7 +127,7 @@ class FunctionalTensor(torch.Tensor): This class is a lightweight python shim around the C++ functionalization logic. FunctionalTensor is required to be used with a corresponding - FunctionalTensormode active, because it relies + FunctionalTensorMode active, because it relies on using the mode for dispatch (which can properly handle factory functions). """ @@ -88,22 +149,24 @@ class FunctionalTensor(torch.Tensor): # These are all aten ops that correspond to metadata queries. # We want FunctionalTensor to be able to handle them directly. - metadata_fns = [ - torch.ops.aten.is_contiguous.default, - torch.ops.aten.is_contiguous.memory_format, - torch.ops.aten.is_strides_like_format.default, - torch.ops.aten.is_non_overlapping_and_dense.default, - torch.ops.aten.size.default, - torch.ops.aten.sym_size.default, - torch.ops.aten.stride.default, - torch.ops.aten.sym_stride.default, - torch.ops.aten.storage_offset.default, - torch.ops.aten.sym_storage_offset.default, - torch.ops.aten.numel.default, - torch.ops.aten.sym_numel.default, - torch.ops.aten.dim.default, - torch.ops.prim.device.default, - ] + metadata_fns = frozenset( + { + torch.ops.aten.is_contiguous.default, + torch.ops.aten.is_contiguous.memory_format, + torch.ops.aten.is_strides_like_format.default, + torch.ops.aten.is_non_overlapping_and_dense.default, + torch.ops.aten.size.default, + torch.ops.aten.sym_size.default, + torch.ops.aten.stride.default, + torch.ops.aten.sym_stride.default, + torch.ops.aten.storage_offset.default, + torch.ops.aten.sym_storage_offset.default, + torch.ops.aten.numel.default, + torch.ops.aten.sym_numel.default, + torch.ops.aten.dim.default, + torch.ops.prim.device.default, + } + ) # Used by auto_functionalize to determine base of tensors during inference mode. _inference_mode_base: FunctionalTensor | None = None @@ -182,15 +245,7 @@ def __torch_dispatch__( # type: ignore[override] args: tuple[Any, ...] = (), kwargs: dict[str, Any] | None = None, ) -> Any: - unrecognized_types = [ - t - for t in types - if t not in [torch.Tensor, torch._subclasses.FakeTensor, FunctionalTensor] - ] - if unrecognized_types: - not_implemented_log.debug( - "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types - ) + if _has_unrecognized_tensor_types(types): return NotImplemented if kwargs is None: @@ -407,69 +462,12 @@ def __torch_dispatch__( if kwargs is None: kwargs = {} - unrecognized_types = [ - t - for t in types - if not issubclass(t, torch._subclasses.FakeTensor) - and t not in [torch.Tensor, FunctionalTensor] - ] - - if unrecognized_types: - not_implemented_log.debug( - "FunctionalTensor unrecognized subclass(es): %s", unrecognized_types - ) + if _has_unrecognized_tensor_types(types): return NotImplemented - def _can_decompose(func: OpOverload) -> bool: - # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832 - # Never decompose dropout in export - if self.export and func is torch.ops.aten.dropout.default: - return False - - # We unconditionally decompose ops that are maybe aliasing or mutating ops - from torch._decomp import _should_decompose_because_unsafe_op - - if _should_decompose_because_unsafe_op(func): - return True - - # (1) we unconditionally decompose maybe-aliasing or maybe-mutating ops, - # because we must know statically of an op mutates or aliasing in order to functionalize it properly - # (2) for mutating ops that have CompositeImplicit decomps, we choose to decompose them today. - # In theory, we could walk this back and avoid decomposing them later if we need to. - alias_info_present = any(arg.alias_info for arg in func._schema.arguments) - if alias_info_present or func._schema.is_mutable: - return True - - # If we are here, it means we are seeing functional composite op. - # For pre-dispatch IR, we don't want to decompose this op - # For post-dispatch IR, we do want to decompose this op. it is fine - # to decompose here even if you want to preserve a CIA in post-dispatch export - # because we already override decompose behaviour so it will do the - # right thing. - if self.export: - if self.pre_dispatch: - # If it is CIA custom op, we warn that we are assuming this op is indeed functional. - if func.namespace not in ["aten", "prim"] and func._can_decompose(): - warnings.warn( - f"At pre-dispatch tracing, we assume that any custom op marked with " - f"CompositeImplicitAutograd and have functional schema are safe to not decompose. " - f"Found {func} to be one such op.", - stacklevel=2, - ) - return False - return True - - # in normal torch.compile IR, we only decompose an op if autograd - # would have decomposed it (NB: autograd may have been skipped if - # we are in inference mode) - # TODO: the flatten here can potentially be deduped with the - # unwrapping pytree_map later - flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs)) - return autograd_would_have_decomposed(func, flat_args_kwargs) - if ( func not in FunctionalTensor.metadata_fns - and _can_decompose(func) + and self._can_decompose(func, args, kwargs) # Not all funcs from __torch_dispatch__ are actual dispatcher ops, # e.g. prim.device and torch._C._dispatch_has_kernel(func.name()) @@ -507,12 +505,16 @@ def unwrap(x: FunctionalTensor) -> torch.Tensor: import torch._inductor.config as inductor_config if torch.compiler.is_exporting(): + # NB: out= ops are not yet handled here; they only go through v2 below. if export_config.enable_auto_functionalized_v2_for_export: return do_auto_functionalize_v2(self, func, args, kwargs) return do_auto_functionalize(self, func, args, kwargs) - if inductor_config.enable_auto_functionalized_v2: + if inductor_config.enable_auto_functionalized_v2 or ( + isinstance(func, torch._ops.OpOverload) + and torch._library.utils.is_out(func) + ): return do_auto_functionalize_v2(self, func, args, kwargs) return do_auto_functionalize(self, func, args, kwargs) @@ -536,16 +538,9 @@ def unwrap(x: FunctionalTensor) -> torch.Tensor: # Expectation: functionalization should not **already** be enabled above our mode. # Why would that be bad? when we return a FunctionalTensor here, we don't want functionalization # to run above this mode and further wrap that output in **another** C++ FunctionalTensorWrapper. - is_included = torch._C._dispatch_tls_is_dispatch_key_included( - torch._C.DispatchKey.Functionalize - ) - is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded( - torch._C.DispatchKey.Functionalize + _assert_functionalize_not_active( + "Functionalization should not already be enabled above this mode" ) - if not is_excluded and is_included: - raise AssertionError( - "Functionalization should not already be enabled above this mode" - ) include_to_set = ( torch._C._dispatch_tls_local_include_set() | torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) @@ -588,34 +583,7 @@ def unwrap(x: FunctionalTensor) -> torch.Tensor: torch.Tensor, wrap, outs_unwrapped ) else: - # Note: [Functionalization View Replay Annotation] - # When functionalization encounters a mutation, it handles aliases by lazily regenerating the aliases - # at the first time they are next used. - # This is a problem when plumbing user annotations during tracing. We want the view ops from view replay - # to have the same annotation that the user specified on the original views. But view replay in - # functionalization happens the next time the alias is used (e.g. second_op(alias_with_pending_mutation)), - # so when we regenerate views before calling into second_op, those views will end up getting the metadata - # for second_op! - # - # Instead, we need to remember the node metadata from the original views, and ensure that this node metadata - # is globally set when we lazily perform view replay. - # The globally set metadata will be used to populate the fx node created for the replayed operation. - if m := torch._C._get_dispatch_mode( - torch._C._TorchDispatchModeKey.PROXY - ): - for a in pytree.tree_leaves([args, kwargs]): - if not isinstance(a, FunctionalTensor): - continue - unwrapped = torch._from_functional_tensor(a.elem) - try: - tracker_entry = m.tracer.tensor_tracker[unwrapped] - except KeyError: - raise RuntimeError( - f"cannot find {unwrapped} in tensor_tracker" - ) from None - curr_node = tracker_entry.proxy.node - with fx_traceback.set_current_replay_node(curr_node): - torch._sync(a) + self._sync_view_replay_annotations(args, kwargs) # When we dispatch to the C++ functionalization kernel, we might need to jump back to the # PreDispatch mode stack afterwards, to handle any other PreDispatch modes underneath @@ -638,16 +606,9 @@ def unwrap(x: FunctionalTensor) -> torch.Tensor: torch._disable_functionalization() torch._functionalize_enable_reapply_views(old_apply_views) # type: ignore[attr-defined] - is_included = torch._C._dispatch_tls_is_dispatch_key_included( - torch._C.DispatchKey.Functionalize - ) - is_excluded = torch._C._dispatch_tls_is_dispatch_key_excluded( - torch._C.DispatchKey.Functionalize + _assert_functionalize_not_active( + "Functionalization should not already be enabled above this mode after dispatch" ) - if not is_excluded and is_included: - raise AssertionError( - "Functionalization should not already be enabled above this mode after dispatch" - ) if ( # If no outputs are our functional subclass, then don't try to fix up aliasing @@ -675,6 +636,59 @@ def unwrap(x: FunctionalTensor) -> torch.Tensor: # If none of our inputs were wrapped, then we have no FunctionalTensor outputs that we need to fix up storages for. return return_and_correct_aliasing(func, args, kwargs, outs_wrapped) + def _sync_view_replay_annotations( + self, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> None: + """Sync FunctionalTensor args so view replay uses correct fx node metadata. + + When functionalization encounters a mutation, it handles aliases by lazily + regenerating them at the first time they are next used. This is a problem when + plumbing user annotations during tracing: we want view ops from view replay to + have the same annotation the user specified on the original views. But view + replay happens the next time the alias is used (e.g. + second_op(alias_with_pending_mutation)), so the regenerated views would get the + metadata for second_op instead. + + To fix this, we remember the node metadata from the original views and globally + set it when we lazily perform view replay. The globally set metadata will be + used to populate the fx node created for the replayed operation. + """ + m = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) + if m is not None: + for a in pytree.tree_leaves([args, kwargs]): + if not isinstance(a, FunctionalTensor): + continue + unwrapped = torch._from_functional_tensor(a.elem) + try: + tracker_entry = m.tracer.tensor_tracker[unwrapped] + except KeyError: + raise RuntimeError( + f"cannot find {unwrapped} in tensor_tracker" + ) from None + curr_node = tracker_entry.proxy.node + with fx_traceback.set_current_replay_node(curr_node): + torch._sync(a) + + def _can_decompose( + self, + func: OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> bool: + result = _can_decompose_fast(func, self.export, self.pre_dispatch) + if result is not None: + return result + + # in normal torch.compile IR, we only decompose an op if autograd + # would have decomposed it (NB: autograd may have been skipped if + # we are in inference mode) + # TODO: the flatten here can potentially be deduped with the + # unwrapping pytree_map later + flat_args_kwargs, _ = pytree.tree_flatten((args, kwargs)) + return autograd_would_have_decomposed(func, flat_args_kwargs) + @classmethod def is_infra_mode(cls) -> bool: return True @@ -746,21 +760,17 @@ def functionalize(self, inner_f: Callable[..., Any]) -> Callable[..., Any]: def redispatch_to_next(self) -> AbstractContextManager[None]: pass - @abstractmethod def replace(self, input_tensor: torch.Tensor, output_tensor: torch.Tensor) -> None: - pass + torch._functionalize_replace(input_tensor, output_tensor) - @abstractmethod def commit_update(self, tensor: torch.Tensor) -> None: - pass + torch._functionalize_commit_update(tensor) - @abstractmethod def sync(self, tensor: torch.Tensor) -> None: - pass + torch._functionalize_sync(tensor) - @abstractmethod def mark_mutation_hidden_from_autograd(self, tensor: torch.Tensor) -> None: - pass + torch._functionalize_mark_mutation_hidden_from_autograd(tensor) class PythonFunctionalizeAPI(BaseFunctionalizeAPI): @@ -796,35 +806,30 @@ def redispatch_to_next(self) -> AbstractContextManager[None]: # directly instead of globally setting it. return contextlib.nullcontext() - def replace(self, input_tensor: torch.Tensor, output_tensor: torch.Tensor) -> None: - if not isinstance(input_tensor, FunctionalTensor): + @staticmethod + def _check_cast_functional(tensor: torch.Tensor, name: str) -> FunctionalTensor: + if not isinstance(tensor, FunctionalTensor): raise AssertionError( - f"input_tensor must be a FunctionalTensor, got {type(input_tensor)}" + f"{name} must be a FunctionalTensor, got {type(tensor)}" ) + return tensor + + def replace(self, input_tensor: torch.Tensor, output_tensor: torch.Tensor) -> None: + ft = self._check_cast_functional(input_tensor, "input_tensor") if isinstance(output_tensor, FunctionalTensor): raise AssertionError("output_tensor must not be a FunctionalTensor") - input_tensor.replace_(output_tensor) + ft.replace_(output_tensor) def commit_update(self, tensor: torch.Tensor) -> None: - if not isinstance(tensor, FunctionalTensor): - raise AssertionError( - f"tensor must be a FunctionalTensor, got {type(tensor)}" - ) - tensor.commit_update() + self._check_cast_functional(tensor, "tensor").commit_update() def sync(self, tensor: torch.Tensor) -> None: - if not isinstance(tensor, FunctionalTensor): - raise AssertionError( - f"tensor must be a FunctionalTensor, got {type(tensor)}" - ) - tensor.sync() + self._check_cast_functional(tensor, "tensor").sync() def mark_mutation_hidden_from_autograd(self, tensor: torch.Tensor) -> None: - if not isinstance(tensor, FunctionalTensor): - raise AssertionError( - f"tensor must be a FunctionalTensor, got {type(tensor)}" - ) - tensor.mark_mutation_hidden_from_autograd() + self._check_cast_functional( + tensor, "tensor" + ).mark_mutation_hidden_from_autograd() class CppFunctionalizeAPI(BaseFunctionalizeAPI): @@ -851,18 +856,6 @@ def redispatch_to_next(self) -> AbstractContextManager[None]: torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) ) - def replace(self, input_tensor: torch.Tensor, output_tensor: torch.Tensor) -> None: - torch._functionalize_replace(input_tensor, output_tensor) - - def commit_update(self, tensor: torch.Tensor) -> None: - torch._functionalize_commit_update(tensor) - - def sync(self, tensor: torch.Tensor) -> None: - torch._functionalize_sync(tensor) - - def mark_mutation_hidden_from_autograd(self, tensor: torch.Tensor) -> None: - torch._functionalize_mark_mutation_hidden_from_autograd(tensor) - class FunctorchFunctionalizeAPI(BaseFunctionalizeAPI): def __init__(self, interpreter: FunctionalizeInterpreter) -> None: @@ -898,18 +891,6 @@ def functionalize(self, inner_f: Callable) -> Callable: def redispatch_to_next(self) -> AbstractContextManager[None]: return self.interpreter.lower() - def replace(self, input_tensor: torch.Tensor, output_tensor: torch.Tensor) -> None: - torch._functionalize_replace(input_tensor, output_tensor) - - def commit_update(self, tensor: torch.Tensor) -> None: - torch._functionalize_commit_update(tensor) - - def sync(self, tensor: torch.Tensor) -> None: - torch._functionalize_sync(tensor) - - def mark_mutation_hidden_from_autograd(self, tensor: torch.Tensor) -> None: - torch._functionalize_mark_mutation_hidden_from_autograd(tensor) - def mb_unwrap_functional_tensor(tensor: torch.Tensor) -> torch.Tensor: if isinstance(tensor, FunctionalTensor): diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py index c57b0a7412a7d..3d2490d5e7600 100644 --- a/torch/_subclasses/meta_utils.py +++ b/torch/_subclasses/meta_utils.py @@ -55,6 +55,7 @@ # Import the following modules during type checking to enable code intelligence features, # Do not import unconditionally, as they import sympy and importing sympy is very slow from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext + from torch.types import IntLikeType def _is_fake_tensor(t: object) -> TypeIs[FakeTensor]: @@ -803,6 +804,65 @@ def _safe_clone(src: torch.Tensor) -> torch.Tensor | None: return src.clone() +def _grad_context_compatible( + symbolic_context: torch.fx.experimental.symbolic_shapes.SymbolicContext, + grad_desc: MetaTensorDesc[torch.Tensor], +) -> bool: + """Check if a symbolic_context is compatible with a grad tensor. + + Returns False when the view base structure in symbolic_context doesn't + match the grad, which means we need a fresh symbolic context. This + happens in FSDP2 where param._local_tensor is a view of an N-D padded + base while grad._local_tensor is a view of a 1-D flat gradient buffer. + + We check at both the outer level and the inner (subclass attr) level. + """ + from torch.fx.experimental.symbolic_shapes import ( + StatelessSymbolicContext, + SubclassSymbolicContext, + ) + + def _view_base_compatible( + ctx: StatelessSymbolicContext[Any, Any], + grad_t: MetaTensorDesc[torch.Tensor], + ) -> bool: + vbc = ctx.view_base_context + if grad_t.is_view and vbc is None: + return False + if not grad_t.is_view and vbc is not None: + return False + if ( + grad_t.is_view + and vbc is not None + and isinstance(vbc, StatelessSymbolicContext) + and grad_t.base is not None + and len(vbc.dynamic_sizes) != grad_t.base.ndim + ): + return False + return True + + if not isinstance(symbolic_context, StatelessSymbolicContext): + return True + + # Check outer level + if not _view_base_compatible(symbolic_context, grad_desc): + return False + + # Check inner (subclass) level + if isinstance(symbolic_context, SubclassSymbolicContext): + if grad_desc.attrs is None: + return False + for attr, inner_ctx in symbolic_context.inner_contexts.items(): + if attr not in grad_desc.attrs: + return False + if isinstance( + inner_ctx, StatelessSymbolicContext + ) and not _view_base_compatible(inner_ctx, grad_desc.attrs[attr]): + return False + + return True + + # This is a class for converting multiple tensors into meta tensors which # share the same view/storage structure. The operation model is you allocate # one of these, and then call it repeatedly on all the tensors you want to @@ -1035,7 +1095,7 @@ def sym_sizes_strides_storage_offset( src: torch._guards.Source, symbolic_context: torch.fx.experimental.symbolic_shapes.SymbolicContext | None = symbolic_context, - ) -> tuple[tuple[int, ...], tuple[int, ...], int]: + ) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType]: # local import to prevent circular import from torch.fx.experimental.symbolic_shapes import is_symbolic @@ -1106,8 +1166,8 @@ def empty_create( # symbolic context. def empty_create_subclass( t: MetaTensorDesc[Any], - outer_size: tuple[int, ...], - outer_stride: tuple[int, ...], + outer_size: tuple[IntLikeType, ...], + outer_stride: tuple[IntLikeType, ...], symbolic_context: torch.fx.experimental.symbolic_shapes.SymbolicContext | None = symbolic_context, source: torch._guards.Source | None = source, @@ -1145,7 +1205,9 @@ def empty_create_subclass( raise AssertionError("source must not be None") sub = self._empty_create_subclass( t, + # pyrefly: ignore[bad-argument-type] outer_size, + # pyrefly: ignore[bad-argument-type] outer_stride, shape_env, symbolic_context, @@ -1286,7 +1348,7 @@ def view_from_base( sym_eq, ) - def symint_visitor_fn(s: int) -> int: + def symint_visitor_fn(s: int) -> IntLikeType: nonlocal symbolic_context from torch.fx.experimental.symbolic_shapes import DimDynamic @@ -1393,7 +1455,11 @@ def tensor_visitor_fn( # NB: we do NOT suppress guards here, we need to remove ephemeral # sources fake_t = t.view_func.apply( - t, base, symint_visitor_fn, tensor_visitor_fn + t, + base, + # pyrefly: ignore[bad-argument-type] + symint_visitor_fn, + tensor_visitor_fn, ) # Ensure the output has symbolic shapes according to the outer symbolic context. @@ -2024,15 +2090,28 @@ def is_c_of_r( if t.grad is not None: from torch._dynamo.source import AttrSource - # TODO: Use a valid grad-specific symbolic context instead of recycling - # the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view(). + grad_source = AttrSource(source, "grad") + grad_symbolic_context = symbolic_context + + # The param's symbolic_context may be incompatible with the + # grad when they have different view base dimensionalities. + # This happens in FSDP2 where param._local_tensor is a view + # of an N-D padded base but grad._local_tensor is a view of + # a 1-D flat gradient buffer. Build a fresh context only in + # that case to preserve FX graph cache consistency. + if shape_env is not None and symbolic_context is not None: + if not _grad_context_compatible(symbolic_context, t.grad): + grad_symbolic_context = all_dynamic_symbolic_context( + t.grad, grad_source, shape_env, callback + ) + # pyrefly: ignore [unbound-name] r.grad = self.meta_tensor( t.grad, shape_env, callback, - AttrSource(source, "grad"), - symbolic_context, + grad_source, + grad_symbolic_context, ) # pyrefly: ignore [unbound-name] torch._C._set_conj(r, t.is_conj) diff --git a/torch/_tensor.py b/torch/_tensor.py index 9619263850659..77436cc17d887 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -900,7 +900,7 @@ def norm( keepdim=False, dtype=None, ): - r"""See :func:`torch.norm`""" + r"""See :func:`torch.linalg.norm`""" if has_torch_function_unary(self): return handle_torch_function( Tensor.norm, (self,), self, p=p, dim=dim, keepdim=keepdim, dtype=dtype @@ -1201,9 +1201,9 @@ def __len__(self): return self.shape[0] def __iter__(self): - # NB: we use 'imap' and not 'map' here, so that in Python 2 we get a - # generator and don't eagerly perform all the indexes. This could - # save us work, and also helps keep trace ordering deterministic + # NB: we use 'imap' and not 'map' here, so that we get a generator + # and don't eagerly perform all the indexes. This could save us + # work, and also helps keep trace ordering deterministic # (e.g., if you zip(*hiddens), the eager map will force all the # indexes of hiddens[0] before hiddens[1], while the generator # map will interleave them.) @@ -1364,8 +1364,6 @@ def refine_names(self, *names): # pyrefly: ignore # bad-override :attr:`names` to the same length as ``self.dim()`` using names from the corresponding indices of ``self.names``. - Python 2 does not support Ellipsis but one may use a string literal - instead (``'...'``). Args: names (iterable of str): The desired names of the output tensor. May @@ -1408,8 +1406,6 @@ def align_to(self, *names): # pyrefly: ignore # bad-override that are not mentioned in :attr:`names`, in the order that they appear in :attr:`self`. - Python 2 does not support Ellipsis but one may use a string literal - instead (``'...'``). Args: names (iterable of str): The desired dimension ordering of the diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 5f1268f8481c2..f2870993566f4 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1515,12 +1515,37 @@ def add_docstr_all(method: str, docstr: str) -> None: """, ) +add_docstr_all( + "const_data_ptr", + r""" +const_data_ptr() -> int + +Returns the address of the first element of :attr:`self` tensor. + +Unlike :meth:`data_ptr`, this is guaranteed to be a read-only access +that will not trigger copy-on-write materialization. For regular +(non-COW) tensors, the return value is identical to :meth:`data_ptr`. + +.. warning:: + + The returned pointer must not be used to mutate the tensor data. + Use :meth:`data_ptr` when write access is needed. +""", +) + add_docstr_all( "data_ptr", r""" data_ptr() -> int Returns the address of the first element of :attr:`self` tensor. + +.. note:: + + If the tensor is a copy-on-write tensor (e.g. created via + :meth:`_lazy_clone`), calling this method will materialize the + copy. Use :meth:`const_data_ptr` if you only need read-only access + to the data pointer. """, ) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index c3302c577151a..9d7932d3295fa 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -1235,7 +1235,7 @@ def merge_dicts(*dicts): add_docstr( torch.asarray, r""" -asarray(obj: Any, *, dtype: Optional[dtype], device: Optional[DeviceLikeType], copy: Optional[bool] = None, requires_grad: bool = False) -> Tensor # noqa: B950 +asarray(obj: Any, *, dtype: Optional[dtype], device: Optional[DeviceLikeType], copy: Optional[bool] = None, requires_grad: Optional[bool] = None) -> Tensor # noqa: B950 Converts :attr:`obj` to a tensor. @@ -1249,13 +1249,15 @@ def merge_dicts(*dicts): 6. a sequence of scalars When :attr:`obj` is a tensor, NumPy array, or DLPack capsule the returned tensor will, -by default, not require a gradient, have the same datatype as :attr:`obj`, be on the -same device, and share memory with it. These properties can be controlled with the -:attr:`dtype`, :attr:`device`, :attr:`copy`, and :attr:`requires_grad` keyword arguments. -If the returned tensor is of a different datatype, on a different device, or a copy is -requested then it will not share its memory with :attr:`obj`. If :attr:`requires_grad` -is ``True`` then the returned tensor will require a gradient, and if :attr:`obj` is -also a tensor with an autograd history then the returned tensor will have the same history. +by default, have the same requires_grad as :attr:`obj` (defaulting to False), have the +same datatype, be on the same device, and share memory with it. These properties can be +controlled with the :attr:`dtype`, :attr:`device`, :attr:`copy`, and +:attr:`requires_grad` keyword arguments. If the returned tensor is of a different +datatype, on a different device, or a copy is requested then it will not share its +memory with :attr:`obj`. If :attr:`requires_grad` is ``True`` (or ``None``, and +:attr:`obj` was a tensor with requires_grad set), then the returned tensor will require +a gradient, and if :attr:`obj` is also a tensor with an autograd history then the +returned tensor will have the same history. When :attr:`obj` is not a tensor, NumPy array, or DLPack capsule but implements Python's buffer protocol then the buffer is interpreted as an array of bytes grouped according to @@ -1298,10 +1300,10 @@ def merge_dicts(*dicts): Default: ``None``, which causes the device of :attr:`obj` to be used. Or, if :attr:`obj` is a Python sequence, the current default device will be used. requires_grad (bool, optional): whether the returned tensor requires grad. - Default: ``False``, which causes the returned tensor not to require a gradient. - If ``True``, then the returned tensor will require a gradient, and if :attr:`obj` - is also a tensor with an autograd history then the returned tensor will have - the same history. + Default: ``None``, which causes requires_grad for the returned tensor to be + inferred from :attr:`obj`. If ``True``, then the returned tensor will require + a gradient, and if :attr:`obj` is also a tensor with an autograd history then + the returned tensor will have the same history. Example:: @@ -1320,13 +1322,17 @@ def merge_dicts(*dicts): >>> b tensor([3., 4., 5.], grad_fn=) >>> # Shares memory with tensor 'b', with no grad - >>> c = torch.asarray(b) + >>> c = torch.asarray(b, requires_grad=False) >>> c tensor([3., 4., 5.]) >>> # Shares memory with tensor 'b', retaining autograd history >>> d = torch.asarray(b, requires_grad=True) >>> d tensor([3., 4., 5.], grad_fn=) + >>> # Shares memory with tensor 'b', retaining autograd history + >>> e = torch.asarray(b) + >>> e + tensor([3., 4., 5.], grad_fn=) >>> array = numpy.array([1, 2, 3]) >>> # Shares memory with array 'array' @@ -11344,7 +11350,7 @@ def merge_dicts(*dicts): Args: {input} indices (LongTensor): the indices into :attr:`input`. Must have long dtype. - dim (int, optional): dimension to select along. Default: 0 + dim (int, optional): dimension to select along. Default: `None`. Keyword args: {out} @@ -13123,7 +13129,7 @@ def merge_dicts(*dicts): add_docstr( torch.trapz, r""" -trapz(y, x, *, dim=-1) -> Tensor +trapz(y, x=None, *, dim=-1) -> Tensor Alias for :func:`torch.trapezoid`. """, @@ -13634,6 +13640,24 @@ def merge_dicts(*dicts): """, ) +add_docstr( + torch.Stream.is_capturing, + r""" +Stream.is_capturing() -> bool + +Return true if this stream is currently recording work for graph capture. + +Returns: + bool: A boolean indicating if the stream is capturing. + +Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> s_cuda = torch.Stream(device='cuda') + >>> s_cuda.is_capturing() +""", +) + add_docstr( torch.Stream.wait_event, diff --git a/torch/_utils.py b/torch/_utils.py index a883c24b598f6..0c823d52f2f1f 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -703,17 +703,6 @@ def _take_tensors(tensors, size_limit): yield buf -# annotation decorator to get annotations in a way that is compatible -# with both Python 2 and 3 -def annotate(ret, **kwargs): - def dec(fun): - fun.__annotations__ = dict(kwargs) - fun.__annotations__["return"] = ret - return fun - - return dec - - def render_call(fn, args, kwargs): str_fn = torch.overrides.resolve_name(fn) if str_fn is None: @@ -785,6 +774,19 @@ def reraise(self): raise exception +def cpu_count() -> int | None: + """Return the number of CPUs available to the current process. + + Prefers ``os.sched_getaffinity`` (respects cgroups / taskset) and + falls back to ``os.cpu_count``. + """ + # os.process_cpu_count was added in CPython 3.13, see + # https://docs.python.org/3/library/os.html#os.process_cpu_count + if hasattr(os, "sched_getaffinity"): + return len(os.sched_getaffinity(0)) + return os.cpu_count() + + def _get_available_device_type(): if torch.cuda.is_available(): return "cuda" @@ -1379,3 +1381,17 @@ def _augment_memory_snapshot_stack_traces( _augment_frames(trace_entry["frames"]) return snapshot_dict + + +def _is_privateuse1_backend_available(): + """ + Determines whether the privateuse1 backend is registered and available. + + Returns: + Return True if the privateuse1 backend is registered and available. + """ + privateuse1_backend_name = torch._C._get_privateuse1_backend_name() + privateuse1_backend_module = getattr(torch, privateuse1_backend_name, None) + return ( + is_available := getattr(privateuse1_backend_module, "is_available", None) + ) and is_available() diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index ec7c0a32f52b5..75ed0a754caed 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -375,3 +375,10 @@ def find_compile_subproc_binary() -> str | None: Allows overriding the binary used for subprocesses """ return None + + +def get_torch_source_version() -> str: + """Return the source commit hash for the current PyTorch build.""" + import torch.version as torch_version + + return getattr(torch_version, "git_version", "") diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py index 4b21450006e2e..0e7528e71a527 100644 --- a/torch/accelerator/__init__.py +++ b/torch/accelerator/__init__.py @@ -9,8 +9,10 @@ import torch from ._utils import _device_t, _get_device_index +from .graphs import Graph from .memory import ( empty_cache, + empty_host_cache, get_memory_info, max_memory_allocated, max_memory_reserved, @@ -23,6 +25,7 @@ __all__ = [ + "Graph", "current_accelerator", "current_device_idx", # deprecated "current_device_index", @@ -31,6 +34,7 @@ "device_count", "device_index", "empty_cache", + "empty_host_cache", "get_memory_info", "is_available", "max_memory_allocated", @@ -165,7 +169,10 @@ def get_device_capability(device: _device_t = None, /) -> dict[str, Any]: Returns: dict[str, Any]: A dictionary containing device capability information. The dictionary includes: - - ``supported_dtypes`` (set(torch.dtype)): Set of PyTorch data types supported by the device + - ``supported_dtypes`` (set(torch.dtype)): Set of PyTorch data types for which + tensors can be allocated on the accelerator and type conversion across + supported dtypes are supported. Any operator support outside of that + is not guaranteed Examples: >>> # xdoctest: +SKIP("requires cuda") diff --git a/torch/accelerator/graphs.py b/torch/accelerator/graphs.py new file mode 100644 index 0000000000000..f02e91f253a45 --- /dev/null +++ b/torch/accelerator/graphs.py @@ -0,0 +1,177 @@ +import gc +from typing import Literal +from typing_extensions import Self + +import torch +from torch._C import _acceleratorGraph + + +class Graph(_acceleratorGraph): + r""" + Wrapper around an :ref:`accelerator` graph that supports capture and replay. + + A graph captures a sequence of operations and their dependencies, allowing them to be + replayed efficiently with reduced overhead. This class can be used as a context manager + to automatically capture operations on the current stream. + + Arguments: + keep_graph (bool, optional): If ``False``, the underlying graph is destroyed and the + executable graph is instantiated on the GPU at the end of ``capture_end``. + If ``True``, the underlying graph is preserved after ``capture_end``. In this case, + the executable graph is not instantiated automatically; it must be explicitly created + by calling ``instantiate``, or it will be instantiated on the first call to ``replay``. + Defaults to ``False``. + pool (tuple[int, int], optional): Memory pool identifier for this graph. Multiple graphs + can share the same pool by passing the same identifier, which can reduce memory overhead. + Defaults to ``None``. + capture_error_mode (Literal["default", "global", "thread_local", "relaxed"], optional): + Specifies the behavior of graph capture. The exact semantics are backend-specific. + ``"default"``: backend-defined default capture behavior. + ``"global"``: potentially unsafe API calls are prohibited. Errors may occur if capture + in the current thread affects other threads. + ``"thread_local"``: potentially unsafe API calls are prohibited. Errors occur only if + capture in the current thread affects itself. + ``"relaxed"``: the current thread is allowed to make potentially unsafe API calls, except + for calls that inherently conflict with stream capture. + Default: ``"default"``. + + Example:: + + >>> # xdoctest: +SKIP + >>> x = torch.zeros([2000], device=0) + + >>> stream = torch.Stream() + >>> graph = torch.accelerator.Graph() + >>> with stream, graph: + ... x += 1 + + >>> graph.replay() + """ + + def __new__( + cls, + keep_graph: bool = False, + *, + pool: tuple[int, int] | None = None, + capture_error_mode: Literal[ + "default", "global", "thread_local", "relaxed" + ] = "default", + ) -> Self: + return super().__new__(cls, keep_graph) + + def __init__( + self, + keep_graph: bool = False, + *, + pool: tuple[int, int] | None = None, + capture_error_mode: Literal[ + "default", "global", "thread_local", "relaxed" + ] = "default", + ) -> None: + super().__init__(keep_graph) + self.graph_pool = pool + self.capture_error_mode = capture_error_mode + + # pyrefly: ignore [bad-override] + def capture_begin(self) -> None: + r""" + Begin graph capture on the current stream. + + All operations on the current stream after this call will be recorded into the graph until + ``capture_end`` is called, using the memory pool and capture error mode provided at construction time. + """ + super().capture_begin( + pool=self.graph_pool, capture_error_mode=self.capture_error_mode + ) + + def capture_end(self) -> None: + r""" + End graph capture on the current stream of the current device. + + After this call, the graph can be replayed via ``replay``. + """ + super().capture_end() + + def instantiate(self) -> None: + r""" + Instantiate the underlying graph. Will be called by ``capture_end`` + if ``keep_graph=False``, or by ``replay`` if ``keep_graph=True`` and + ``instantiate`` has not already been explicitly called. + """ + super().instantiate() + + def replay(self) -> None: + r"""Replay the work captured by this graph.""" + super().replay() + + def reset(self) -> None: + r""" + Delete the graph currently held by this instance. + + After this call, the graph can be recaptured. Set :attr:`graph_pool` or + :attr:`capture_error_mode` beforehand to use different settings on the next capture. + """ + super().reset() + + def pool(self) -> tuple[int, int]: + r""" + Return an opaque token representing the id of this graph's memory pool. + + This id can optionally be passed to another graph's ``capture_begin``, + which hints the other graph may share the same memory pool. + + Example:: + >>> # xdoctest: +SKIP + >>> g1 = torch.accelerator.Graph() + >>> g1.capture_begin() + >>> # ... operations ... + >>> g1.capture_end() + + >>> # Share g1's memory pool with a new graph + >>> pool_id = g1.pool() + >>> g2 = torch.accelerator.Graph(pool=pool_id) + """ + return super().pool() + + def enable_debug_mode(self) -> None: + r"""Enable debugging mode for ``debug_dump``.""" + return super().enable_debug_mode() + + def debug_dump(self, path: str) -> None: + r""" + Dump the captured graph to a file for debugging purposes if the debugging is + enabled via ``enable_debug_mode``. + + Arguments: + path (str): Path to dump the graph to. + + Example:: + >>> # xdoctest: +SKIP + >>> s = torch.Stream() + >>> g = torch.accelerator.Graph() + >>> g.enable_debug_mode() + + >>> with s, g: + >>> # ... operations ... + + >>> # Dump captured graph to a file "graph_dump.dot" + >>> g.debug_dump("graph_dump.dot") + """ + return super().debug_dump(path) + + def __enter__(self) -> None: + torch.accelerator.synchronize() + if torch.compiler.config.force_cudagraph_gc: + # We previously always ran garbage collection here. While this can help + # reclaim accelerator device memory held by dead Python cycles, it is + # very expensive, especially when performing multiple graph captures in sequence. + gc.collect() + torch.accelerator.empty_cache() + torch.accelerator.empty_host_cache() + self.capture_begin() + + def __exit__(self, *exc_info: object) -> None: + self.capture_end() + + +__all__ = ["Graph"] diff --git a/torch/accelerator/memory.py b/torch/accelerator/memory.py index a326af6dbcc1b..08998d221a58b 100644 --- a/torch/accelerator/memory.py +++ b/torch/accelerator/memory.py @@ -8,6 +8,7 @@ __all__ = [ "empty_cache", + "empty_host_cache", "get_memory_info", "max_memory_allocated", "max_memory_reserved", @@ -31,6 +32,16 @@ def empty_cache() -> None: torch._C._accelerator_emptyCache() +def empty_host_cache() -> None: + r"""Release all unoccupied cached host (pinned) memory currently held by the host caching + allocator so that it can be used by other applications. + + .. note:: This function is a no-op if the memory allocator for the current + :ref:`accelerator ` has not been initialized. + """ + torch._C._accelerator_emptyHostCache() + + def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]: r"""Return a dictionary of accelerator device memory allocator statistics for a given device index. diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index 529e47a382b7b..892b9eeef646c 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -392,6 +392,11 @@ def __call__(self, func): return autocast_decorator(self, func) +# Subclass to distinguish autocast variables created by _enter_autocast (and not managed by a with statement) +class _UnmanagedAutocast(autocast): + pass + + # These functions aren't meant for public usage. # They are what we trace into a graph during pre_dispatch tracing # when we encounter an autocast context manager. @@ -401,7 +406,7 @@ def _enter_autocast(*vals): return torch.overrides.handle_torch_function( torch.amp._enter_autocast, [], *vals ) - mode = torch.amp.autocast(*vals) + mode = _UnmanagedAutocast(*vals) mode.__enter__() return mode diff --git a/torch/amp/grad_scaler.py b/torch/amp/grad_scaler.py index b17d84e72ae87..c1592efe9f737 100644 --- a/torch/amp/grad_scaler.py +++ b/torch/amp/grad_scaler.py @@ -449,7 +449,7 @@ def step( found_inf = cast( torch.Tensor, sum( - [ # noqa: C419 + [ t.to(scaler.device, non_blocking=True) for t in optimizer_state["found_inf_per_device"].values() ] diff --git a/torch/ao/__init__.py b/torch/ao/__init__.py index ac866b5073deb..0919208c4201f 100644 --- a/torch/ao/__init__.py +++ b/torch/ao/__init__.py @@ -7,7 +7,7 @@ if _TYPE_CHECKING: from types import ModuleType - from torch.ao import ( # noqa: TC004 + from torch.ao import ( nn as nn, ns as ns, pruning as pruning, diff --git a/torch/ao/nn/__init__.py b/torch/ao/nn/__init__.py index 7439c22d66882..c3f48fa2bb342 100644 --- a/torch/ao/nn/__init__.py +++ b/torch/ao/nn/__init__.py @@ -9,7 +9,7 @@ if _TYPE_CHECKING: from types import ModuleType - from torch.ao.nn import ( # noqa: TC004 + from torch.ao.nn import ( intrinsic as intrinsic, qat as qat, quantizable as quantizable, diff --git a/torch/ao/nn/intrinsic/__init__.py b/torch/ao/nn/intrinsic/__init__.py index 80ba84a84251d..5236cb7ba7e1d 100644 --- a/torch/ao/nn/intrinsic/__init__.py +++ b/torch/ao/nn/intrinsic/__init__.py @@ -1,12 +1,12 @@ import types from .modules import * # noqa: F403 -from .modules.fused import _FusedModule # noqa: F403 +from .modules.fused import _FusedModule # # Subpackages -# from . import qat # noqa: F403 -# from . import quantized # noqa: F403 +# from . import qat +# from . import quantized __all__ = [ "ConvBn1d", diff --git a/torch/ao/nn/intrinsic/modules/__init__.py b/torch/ao/nn/intrinsic/modules/__init__.py index 132137b735737..038d4b567f902 100644 --- a/torch/ao/nn/intrinsic/modules/__init__.py +++ b/torch/ao/nn/intrinsic/modules/__init__.py @@ -1,4 +1,4 @@ -from .fused import ( # noqa: F401 +from .fused import ( _FusedModule, BNReLU2d, BNReLU3d, diff --git a/torch/ao/nn/quantized/dynamic/modules/conv.py b/torch/ao/nn/quantized/dynamic/modules/conv.py index 532964ad5e99f..3f63756d37b31 100644 --- a/torch/ao/nn/quantized/dynamic/modules/conv.py +++ b/torch/ao/nn/quantized/dynamic/modules/conv.py @@ -68,7 +68,7 @@ def __init__( reduce_range=True, ): warnings.warn( - f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", # noqa: B950 + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", stacklevel=2, ) factory_kwargs = {"device": device, "dtype": dtype} @@ -241,7 +241,7 @@ def __init__( dtype=None, ): warnings.warn( - f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", # noqa: B950 + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", stacklevel=2, ) if padding_mode == "reflect": @@ -334,7 +334,7 @@ def __init__( dtype=None, ): warnings.warn( - f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", # noqa: B950 + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", stacklevel=2, ) factory_kwargs = {"device": device, "dtype": dtype} @@ -417,7 +417,7 @@ def __init__( dtype=None, ): warnings.warn( - f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", # noqa: B950 + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", stacklevel=2, ) factory_kwargs = {"device": device, "dtype": dtype} @@ -500,7 +500,7 @@ def __init__( dtype=None, ): warnings.warn( - f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", # noqa: B950 + f"The current implementation of the {self._get_name()} module has poor numerical accuracy and its use is not recommended", stacklevel=2, ) factory_kwargs = {"device": device, "dtype": dtype} diff --git a/torch/ao/nn/quantized/dynamic/modules/rnn.py b/torch/ao/nn/quantized/dynamic/modules/rnn.py index b1e186ef130fb..29cca89663a6b 100644 --- a/torch/ao/nn/quantized/dynamic/modules/rnn.py +++ b/torch/ao/nn/quantized/dynamic/modules/rnn.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn -from torch import Tensor # noqa: F401 +from torch import Tensor from torch._jit_internal import Dict, List, Optional, Tuple, Union # noqa: F401 from torch.ao.nn.quantized.modules.utils import _quantize_weight from torch.nn.utils.rnn import PackedSequence diff --git a/torch/ao/nn/quantized/functional.py b/torch/ao/nn/quantized/functional.py index 5d0aed8600265..80e124b9a21f0 100644 --- a/torch/ao/nn/quantized/functional.py +++ b/torch/ao/nn/quantized/functional.py @@ -221,7 +221,7 @@ def conv1d( >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters) >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs) >>> qF.conv1d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point) - """ # noqa: E501 + """ if padding_mode != "zeros": raise NotImplementedError("Only zero-padding is supported!") if input.dtype != torch.quint8: @@ -293,7 +293,7 @@ def conv2d( >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters) >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs) >>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point) - """ # noqa: E501 + """ if padding_mode != "zeros": raise NotImplementedError("Only zero-padding is supported!") if input.dtype != torch.quint8: @@ -369,7 +369,7 @@ def conv3d( >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters) >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs) >>> qF.conv3d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point) - """ # noqa: E501 + """ if padding_mode != "zeros": raise NotImplementedError("Only zero-padding is supported!") if input.dtype != torch.quint8: diff --git a/torch/ao/nn/quantized/modules/embedding_ops.py b/torch/ao/nn/quantized/modules/embedding_ops.py index 66826fd6fe143..401e44e9a78de 100644 --- a/torch/ao/nn/quantized/modules/embedding_ops.py +++ b/torch/ao/nn/quantized/modules/embedding_ops.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import torch import torch.nn as nn -from torch import Tensor # noqa: F401 +from torch import Tensor from torch._jit_internal import List, Optional # noqa: F401 from .utils import _hide_packed_params_repr, _quantize_weight diff --git a/torch/ao/nn/quantized/reference/modules/rnn.py b/torch/ao/nn/quantized/reference/modules/rnn.py index fd5179e7f3ad0..10b2b0e334cb9 100644 --- a/torch/ao/nn/quantized/reference/modules/rnn.py +++ b/torch/ao/nn/quantized/reference/modules/rnn.py @@ -573,7 +573,7 @@ def get_flat_weights(self): flat_weights.append(weight) return flat_weights - def forward(self, input, hx=None): # noqa: F811 + def forward(self, input, hx=None): orig_input = input # xxx: isinstance check needs to be in conditional for TorchScript to compile batch_sizes = None @@ -742,7 +742,7 @@ def get_flat_weights(self): flat_weights.append(weight) return flat_weights - def forward(self, input, hx=None): # noqa: F811 + def forward(self, input, hx=None): # Note: this is copied from the forward of GRU in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # only changed self._flat_weights to self.get_flat_weights() # TODO: maybe we can try inheriting from that class and define get_flat_weights diff --git a/torch/ao/ns/fx/graph_matcher.py b/torch/ao/ns/fx/graph_matcher.py index 0e85cf9a91ba2..e8dbd98e31375 100644 --- a/torch/ao/ns/fx/graph_matcher.py +++ b/torch/ao/ns/fx/graph_matcher.py @@ -83,7 +83,7 @@ def __next__(self) -> NSSubgraph: ) if is_match: # navigate to the base node - # pyrefly: ignore [bad-assignment] + # pyrefly: ignore [bad-assignment, non-convergent-recursion] for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1): # pyrefly: ignore [bad-argument-type] self.seen_nodes.add(cur_start_node) diff --git a/torch/ao/ns/fx/n_shadows_utils.py b/torch/ao/ns/fx/n_shadows_utils.py index 64e37673ec1de..41df14e962674 100644 --- a/torch/ao/ns/fx/n_shadows_utils.py +++ b/torch/ao/ns/fx/n_shadows_utils.py @@ -405,6 +405,7 @@ def _add_placeholder( new_arg_placeholder = gm.placeholder(mod_name) # type: ignore[operator] cur_args_copy.append(new_arg_placeholder) elif isinstance(arg, (float, int, torch.dtype)): + # pyrefly: ignore [bad-argument-type] cur_args_copy.append(arg) else: raise AssertionError(f"arg of type {type(arg)} not handled yet") diff --git a/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py b/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py index b727635d08151..097337aa8b42c 100644 --- a/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py +++ b/torch/ao/pruning/_experimental/data_sparsifier/quantization_utils.py @@ -92,6 +92,7 @@ def post_training_sparse_quantize( # quantize for _, emb_module in embedding_modules: + # pyrefly: ignore [bad-argument-type] emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig torch.ao.quantization.prepare(model, inplace=True) @@ -100,6 +101,7 @@ def post_training_sparse_quantize( else: # quantize for _, emb_module in embedding_modules: + # pyrefly: ignore [bad-argument-type] emb_module.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig torch.ao.quantization.prepare(model, inplace=True) diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py index 357a0bfece028..b95f026dcf2f7 100644 --- a/torch/ao/quantization/__init__.py +++ b/torch/ao/quantization/__init__.py @@ -8,7 +8,7 @@ from torch import Tensor from .fake_quantize import * # noqa: F403 -from .fuse_modules import fuse_modules, fuse_modules_qat # noqa: F403 +from .fuse_modules import fuse_modules, fuse_modules_qat from .fuser_method_mappings import * # noqa: F403 from .observer import * # noqa: F403 from .qconfig import * # noqa: F403 diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index 9cb322fd85d2c..db6852e915d75 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -157,7 +157,7 @@ def _get_binary_op_configs( ] binary_op_configs.extend( BackendPatternConfig(bop_pattern) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) ._set_num_tensor_args_to_observation_type( num_tensor_args_to_observation_type_mapping ) @@ -165,7 +165,7 @@ def _get_binary_op_configs( ) # matmul binary_op_configs.append( - BackendPatternConfig(torch.matmul).set_dtype_configs(dtype_configs) # noqa: E131 + BackendPatternConfig(torch.matmul).set_dtype_configs(dtype_configs) ) return binary_op_configs @@ -182,7 +182,7 @@ def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPattern # linear module linear_configs.append( BackendPatternConfig(torch.nn.Linear) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(nnqr.Linear) @@ -191,7 +191,7 @@ def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPattern # linear qat module linear_configs.append( BackendPatternConfig(nnqat.Linear) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(nnqr.Linear) @@ -199,7 +199,7 @@ def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPattern # functional linear linear_configs.append( BackendPatternConfig(torch.nn.functional.linear) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 1, "bias": 2}) ) @@ -210,14 +210,14 @@ def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPattern # linear relu, linear module + relu module linear_configs.append( BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(_sequential_wrapper2(nni.LinearReLU)) .set_fused_module(nni.LinearReLU) ) # linear relu, linear module + functional relu linear_configs.append( BackendPatternConfig((torch.nn.Linear, torch.nn.functional.relu)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(_sequential_wrapper2(nni.LinearReLU)) .set_fused_module(nni.LinearReLU) ) @@ -226,7 +226,7 @@ def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPattern # linear relu, fused module linear_configs.append( BackendPatternConfig(nni.LinearReLU) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(nnqr.Linear) @@ -235,7 +235,7 @@ def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPattern # linear relu, qat fused module linear_configs.append( BackendPatternConfig(nniqat.LinearReLU) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(nnqr.Linear) @@ -244,13 +244,13 @@ def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPattern # linear relu, functional linear + relu module linear_configs.append( BackendPatternConfig((F.linear, torch.nn.ReLU)) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ) # linear relu, functional linear + functional relu linear_configs.append( BackendPatternConfig((F.linear, F.relu)) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ) @@ -259,7 +259,7 @@ def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPattern # 3.1 linear bn fusion linear_configs.append( BackendPatternConfig((nn.Linear, nn.BatchNorm1d)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(fuse_linear_bn) .set_fused_module(nni.LinearBn1d) ) @@ -268,7 +268,7 @@ def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPattern # linear bn, fused module linear_configs.append( BackendPatternConfig(nni.LinearBn1d) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(nnqr.Linear) @@ -277,7 +277,7 @@ def _get_linear_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPattern # linear bn, qat fused module linear_configs.append( BackendPatternConfig(nniqat.LinearBn1d) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(nnqr.Linear) @@ -297,7 +297,7 @@ def _get_conv_configs(dtype_configs): # conv module conv_configs.append( BackendPatternConfig(convs.root) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) @@ -306,7 +306,7 @@ def _get_conv_configs(dtype_configs): # conv qat module conv_configs.append( BackendPatternConfig(convs.qat) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) @@ -314,7 +314,7 @@ def _get_conv_configs(dtype_configs): # functional conv conv_configs.append( BackendPatternConfig(convs.func) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 1, "bias": 2}) ) @@ -325,14 +325,14 @@ def _get_conv_configs(dtype_configs): # conv relu fusion, conv module + relu module conv_configs.append( BackendPatternConfig((convs.root, torch.nn.ReLU)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) .set_fused_module(convs.fused_conv_relu) ) # conv relu fusion, conv module + functional relu conv_configs.append( BackendPatternConfig((convs.root, F.relu)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) .set_fused_module(convs.fused_conv_relu) ) @@ -340,7 +340,7 @@ def _get_conv_configs(dtype_configs): # conv relu, fused module conv_configs.append( BackendPatternConfig(convs.fused_conv_relu) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) @@ -349,7 +349,7 @@ def _get_conv_configs(dtype_configs): # conv relu, qat fused module conv_configs.append( BackendPatternConfig(convs.relu_qat) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) @@ -358,26 +358,26 @@ def _get_conv_configs(dtype_configs): # conv relu, functional conv + relu module conv_configs.append( BackendPatternConfig((convs.func, torch.nn.ReLU)) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ) # conv relu, functional conv + functional relu conv_configs.append( BackendPatternConfig((convs.func, F.relu)) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ) # fused conv relu conv_configs.append( BackendPatternConfig(convs.fused_conv_relu) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_qat_module(convs.relu_qat) ) conv_configs.append( BackendPatternConfig(convs.relu_qat) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) ) @@ -388,21 +388,21 @@ def _get_conv_configs(dtype_configs): # conv + bn fusion conv_configs.append( BackendPatternConfig((convs.root, convs.bn)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(fuse_conv_bn) .set_fused_module(convs.fused_conv_bn) ) # conv + bn + relu module fusion conv_configs.append( BackendPatternConfig((convs.root, convs.bn, nn.ReLU)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(fuse_conv_bn_relu) .set_fused_module(convs.fused_conv_bn_relu) ) # conv + bn + relu functional fusion conv_configs.append( BackendPatternConfig((convs.root, convs.bn, F.relu)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_fuser_method(fuse_conv_bn_relu) .set_fused_module(convs.fused_conv_bn_relu) @@ -413,21 +413,21 @@ def _get_conv_configs(dtype_configs): # fused conv bn conv_configs.append( BackendPatternConfig(convs.fused_conv_bn) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_qat_module(convs.bn_qat) ) # fused conv bn relu conv_configs.append( BackendPatternConfig(convs.fused_conv_bn_relu) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_qat_module(convs.bn_relu_qat) ) # conv bn, qat fused module conv_configs.append( BackendPatternConfig(convs.bn_qat) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) @@ -435,7 +435,7 @@ def _get_conv_configs(dtype_configs): # conv bn relu, qat fused module conv_configs.append( BackendPatternConfig(convs.bn_relu_qat) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) @@ -445,7 +445,7 @@ def _get_conv_configs(dtype_configs): # 4.1 conv transpose config conv_configs.append( BackendPatternConfig(convs.transpose) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_root_module(convs.transpose) .set_reference_quantized_module(convs.transpose_reference) ) @@ -453,7 +453,7 @@ def _get_conv_configs(dtype_configs): # 4.2 conv transpose + bn fusion conv_configs.append( BackendPatternConfig((convs.transpose, convs.bn)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(fuse_convtranspose_bn) .set_root_module(convs.transpose) .set_reference_quantized_module(convs.transpose_reference) @@ -462,7 +462,7 @@ def _get_conv_configs(dtype_configs): # 4.3 functional conv transpose conv_configs.append( BackendPatternConfig(convs.func_transpose) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 1, "bias": 2}) ) @@ -481,12 +481,12 @@ def _get_ln_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConf ln_configs = [] ln_configs.append( BackendPatternConfig(torch.nn.LayerNorm) - .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) .set_dtype_configs(dtype_configs) ) ln_configs.append( BackendPatternConfig(torch.nn.functional.layer_norm) - .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 2, "bias": 3}) ) @@ -512,21 +512,21 @@ def _get_default_op_configs( ] configs = [ BackendPatternConfig(op) - .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) .set_dtype_configs(dtype_configs) for op in default_ops ] configs.append( BackendPatternConfig(torch.nn.functional.group_norm) - .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 2, "bias": 3}) ) configs.append( BackendPatternConfig(torch.nn.functional.instance_norm) - .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) # noqa: E131 + .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 3, "bias": 4}) ) @@ -584,7 +584,7 @@ def _get_fixed_qparams_op_configs( BackendPatternConfig(fixed_qparam_op) .set_observation_type( ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + ) .set_dtype_configs(new_dtype_configs) ) return fixed_qparams_op_configs @@ -682,14 +682,14 @@ def _get_bn_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConf # bn module + relu module fusion config bn_configs.append( BackendPatternConfig((bn, nn.ReLU)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(_sequential_wrapper2(fused_bn)) .set_fused_module(fused_bn) ) # bn module + F.relu fusion config bn_configs.append( BackendPatternConfig((bn, F.relu)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(_sequential_wrapper2(fused_bn)) .set_fused_module(fused_bn) ) @@ -697,7 +697,7 @@ def _get_bn_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConf BackendPatternConfig(bn) .set_observation_type( ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + ) .set_dtype_configs(dtype_configs) ) @@ -707,7 +707,7 @@ def _get_bn_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConf BackendPatternConfig(fused_bn) .set_observation_type( ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + ) .set_dtype_configs(dtype_configs) ) return bn_configs @@ -726,7 +726,7 @@ def _get_rnn_op_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPattern BackendPatternConfig(rnn_op) .set_observation_type( ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + ) .set_dtype_configs(dtype_configs) .set_root_module(rnn_op) .set_reference_quantized_module(ref_rnn_op) @@ -746,7 +746,7 @@ def _get_embedding_op_configs( BackendPatternConfig(embedding_op) .set_observation_type( ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + ) .set_dtype_configs(dtype_configs) .set_qat_module(qat_embedding_op) .set_root_module(embedding_op) @@ -758,7 +758,7 @@ def _get_embedding_op_configs( BackendPatternConfig(qat_embedding_op) .set_observation_type( ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + ) .set_dtype_configs(dtype_configs) .set_root_module(embedding_op) .set_reference_quantized_module(ref_embedding_op) diff --git a/torch/ao/quantization/backend_config/_qnnpack_pt2e.py b/torch/ao/quantization/backend_config/_qnnpack_pt2e.py index d4e67b79c3702..5dc79a853ad2c 100644 --- a/torch/ao/quantization/backend_config/_qnnpack_pt2e.py +++ b/torch/ao/quantization/backend_config/_qnnpack_pt2e.py @@ -53,20 +53,20 @@ def get_linear_configs(): # linear_configs.append( # BackendPatternConfig((torch.ops.aten.addmm.default, MatchAllNode, MatchAllNode, torch.ops.aten.t.default)) - # .set_observation_type(observation_type) # noqa: E131 + # .set_observation_type(observation_type) # .set_dtype_configs(dtype_configs) # ._set_root_node_getter(root_node_getter)) linear_configs.append( BackendPatternConfig(torch.ops.aten.addmm.default) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 2, "bias": 0}) ) # linear is decomposed to `t - mm` if bias is not present linear_configs.append( BackendPatternConfig(torch.ops.aten.mm.default) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 1}) ) @@ -79,7 +79,7 @@ def get_conv_configs(): dtype_configs = [weighted_op_quint8_dtype_config] conv_configs.append( BackendPatternConfig(torch.ops.aten.convolution.default) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 1, "bias": 2}) ) @@ -87,7 +87,7 @@ def get_conv_configs(): BackendPatternConfig( (torch.ops.aten.convolution.default, torch.ops.aten.relu.default) ) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 1, "bias": 2}) ) @@ -96,7 +96,7 @@ def get_conv_configs(): BackendPatternConfig( (torch.ops.aten.convolution.default, torch.ops.aten.relu_.default) ) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 1, "bias": 2}) ) @@ -117,7 +117,7 @@ def root_node_getter(node_pattern): ._set_pattern_complex_format( (operator.getitem, torch.ops.aten.max_pool2d_with_indices.default, 0) ) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ._set_root_node_getter(root_node_getter) ) @@ -131,7 +131,7 @@ def get_relu_configs(): dtype_configs = [weighted_op_quint8_dtype_config] backend_pattern_configs.append( BackendPatternConfig(torch.ops.aten.relu.default) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ) return backend_pattern_configs @@ -160,7 +160,7 @@ def get_binary_op_configs(): ] binary_op_configs.extend( BackendPatternConfig(bop_pattern) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) ._set_num_tensor_args_to_observation_type( num_tensor_args_to_observation_type_mapping ) diff --git a/torch/ao/quantization/backend_config/executorch.py b/torch/ao/quantization/backend_config/executorch.py index 2b9b16492821b..a7d8f585072ae 100644 --- a/torch/ao/quantization/backend_config/executorch.py +++ b/torch/ao/quantization/backend_config/executorch.py @@ -113,7 +113,7 @@ def _get_linear_configs() -> list[BackendPatternConfig]: # linear module linear_configs.append( BackendPatternConfig(torch.nn.Linear) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(nnqr.Linear) @@ -122,7 +122,7 @@ def _get_linear_configs() -> list[BackendPatternConfig]: # linear qat module linear_configs.append( BackendPatternConfig(nnqat.Linear) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(torch.nn.Linear) .set_reference_quantized_module(nnqr.Linear) @@ -130,7 +130,7 @@ def _get_linear_configs() -> list[BackendPatternConfig]: # functional linear linear_configs.append( BackendPatternConfig(torch.nn.functional.linear) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 1, "bias": 2}) ) @@ -153,7 +153,7 @@ def _get_conv_configs() -> list[BackendPatternConfig]: # conv module conv_configs.append( BackendPatternConfig(convs.root) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) @@ -162,7 +162,7 @@ def _get_conv_configs() -> list[BackendPatternConfig]: # conv qat module conv_configs.append( BackendPatternConfig(convs.qat) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) @@ -170,7 +170,7 @@ def _get_conv_configs() -> list[BackendPatternConfig]: # functional conv conv_configs.append( BackendPatternConfig(convs.func) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 1, "bias": 2}) ) @@ -180,21 +180,21 @@ def _get_conv_configs() -> list[BackendPatternConfig]: # conv module + relu module conv_configs.append( BackendPatternConfig((convs.root, nn.ReLU)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) .set_fused_module(convs.fused_conv_relu) ) # conv module + functional relu conv_configs.append( BackendPatternConfig((convs.root, F.relu)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(_sequential_wrapper2(convs.fused_conv_relu)) .set_fused_module(convs.fused_conv_relu) ) # fused conv relu module conv_configs.append( BackendPatternConfig(convs.fused_conv_relu) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) @@ -203,7 +203,7 @@ def _get_conv_configs() -> list[BackendPatternConfig]: # conv relu, qat fused module conv_configs.append( BackendPatternConfig(convs.relu_qat) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) @@ -211,25 +211,25 @@ def _get_conv_configs() -> list[BackendPatternConfig]: # functional conv + relu module conv_configs.append( BackendPatternConfig((convs.func, nn.ReLU)) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ) # functional conv + functional relu conv_configs.append( BackendPatternConfig((convs.func, F.relu)) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ) # fused conv relu conv_configs.append( BackendPatternConfig(convs.fused_conv_relu) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_qat_module(convs.relu_qat) ) conv_configs.append( BackendPatternConfig(convs.relu_qat) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) ) @@ -239,21 +239,21 @@ def _get_conv_configs() -> list[BackendPatternConfig]: # conv + batchnorm (+ relu) conv_configs.append( BackendPatternConfig((convs.root, convs.bn)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(fuse_conv_bn) .set_fused_module(convs.fused_conv_bn) ) # conv + bn + relu module fusion conv_configs.append( BackendPatternConfig((convs.root, convs.bn, nn.ReLU)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(fuse_conv_bn_relu) .set_fused_module(convs.fused_conv_bn_relu) ) # conv + bn + relu functional fusion conv_configs.append( BackendPatternConfig((convs.root, convs.bn, F.relu)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_fuser_method(fuse_conv_bn_relu) .set_fused_module(convs.fused_conv_bn_relu) @@ -263,21 +263,21 @@ def _get_conv_configs() -> list[BackendPatternConfig]: # fused conv bn conv_configs.append( BackendPatternConfig(convs.fused_conv_bn) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_qat_module(convs.bn_qat) ) # fused conv bn relu conv_configs.append( BackendPatternConfig(convs.fused_conv_bn_relu) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_qat_module(convs.bn_relu_qat) ) # conv bn, qat fused module conv_configs.append( BackendPatternConfig(convs.bn_qat) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) @@ -285,7 +285,7 @@ def _get_conv_configs() -> list[BackendPatternConfig]: # conv bn relu, qat fused module conv_configs.append( BackendPatternConfig(convs.bn_relu_qat) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(convs.root) .set_reference_quantized_module(convs.reference) @@ -326,7 +326,7 @@ def _get_binary_ops_configs() -> list[BackendPatternConfig]: ] binary_op_configs.extend( BackendPatternConfig(bop_pattern) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) ._set_num_tensor_args_to_observation_type( num_tensor_args_to_observation_type_mapping ) @@ -386,7 +386,7 @@ def _get_share_qparams_ops_configs() -> list[BackendPatternConfig]: ] share_qparams_op_configs: list[BackendPatternConfig] = [ BackendPatternConfig(op) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) for op in share_qparams_ops ] @@ -405,7 +405,7 @@ def _get_bn_configs() -> list[BackendPatternConfig]: bn_configs = [] bn_configs.append( BackendPatternConfig(nn.BatchNorm2d) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ) return bn_configs @@ -448,7 +448,7 @@ def _get_embedding_op_configs() -> list[BackendPatternConfig]: BackendPatternConfig(embedding_op) .set_observation_type( ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + ) .set_dtype_configs(dtype_configs) .set_qat_module(qat_embedding_op) .set_root_module(embedding_op) @@ -459,7 +459,7 @@ def _get_embedding_op_configs() -> list[BackendPatternConfig]: BackendPatternConfig(qat_embedding_op) .set_observation_type( ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + ) .set_dtype_configs(dtype_configs) .set_root_module(embedding_op) .set_reference_quantized_module(ref_embedding_op) @@ -470,7 +470,7 @@ def _get_embedding_op_configs() -> list[BackendPatternConfig]: BackendPatternConfig(torch.nn.functional.embedding) .set_observation_type( ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT - ) # noqa: E131 + ) .set_dtype_configs(dtype_configs) ._set_input_type_to_index({"weight": 1}) ) diff --git a/torch/ao/quantization/backend_config/onednn.py b/torch/ao/quantization/backend_config/onednn.py index 11a6b06581a75..e25932f4d5b37 100644 --- a/torch/ao/quantization/backend_config/onednn.py +++ b/torch/ao/quantization/backend_config/onednn.py @@ -190,7 +190,7 @@ def _conv_bn_add_extra_inputs_getter_left(add_pattern): BackendPatternConfig() ._set_pattern_complex_format( (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode) - ) # noqa: E131 + ) .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_fuser_method(_fuse_conv_bn_add_left) @@ -201,7 +201,7 @@ def _conv_bn_add_extra_inputs_getter_left(add_pattern): else: conv_configs.append( BackendPatternConfig() - ._set_pattern_complex_format((add_op, nn.Conv2d, MatchAllNode)) # noqa: E131 + ._set_pattern_complex_format((add_op, nn.Conv2d, MatchAllNode)) .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_fuser_method(_fuse_conv_add_left) @@ -273,7 +273,7 @@ def _conv_bn_add_extra_inputs_getter_right(pattern): BackendPatternConfig() ._set_pattern_complex_format( (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d)) - ) # noqa: E131 + ) .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_fuser_method(_fuse_conv_bn_add_right) @@ -284,7 +284,7 @@ def _conv_bn_add_extra_inputs_getter_right(pattern): else: conv_configs.append( BackendPatternConfig() - ._set_pattern_complex_format((add_op, MatchAllNode, nn.Conv2d)) # noqa: E131 + ._set_pattern_complex_format((add_op, MatchAllNode, nn.Conv2d)) .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_fuser_method(_fuse_conv_add_right) @@ -295,7 +295,7 @@ def _conv_bn_add_extra_inputs_getter_right(pattern): conv_configs.append( BackendPatternConfig(nni.ConvAdd2d) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_root_module(nn.Conv2d) .set_reference_quantized_module(nnqr.Conv2d) @@ -376,7 +376,7 @@ def _conv_bn_add_relu_extra_inputs_getter_left(pattern): BackendPatternConfig() ._set_pattern_complex_format( (nn.ReLU, (add_op, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode)) - ) # noqa: E131 + ) .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_fuser_method(_fuse_conv_bn_add_relu_left) @@ -387,7 +387,7 @@ def _conv_bn_add_relu_extra_inputs_getter_left(pattern): else: conv_configs.append( BackendPatternConfig() - ._set_pattern_complex_format((nn.ReLU, (add_op, nn.Conv2d, MatchAllNode))) # noqa: E131 + ._set_pattern_complex_format((nn.ReLU, (add_op, nn.Conv2d, MatchAllNode))) .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_fuser_method(_fuse_conv_add_relu_left) @@ -469,7 +469,7 @@ def _conv_bn_add_relu_extra_inputs_getter_right(pattern): BackendPatternConfig() ._set_pattern_complex_format( (nn.ReLU, (add_op, MatchAllNode, (nn.BatchNorm2d, nn.Conv2d))) - ) # noqa: E131 + ) .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_fuser_method(_fuse_conv_bn_add_relu_right) @@ -480,7 +480,7 @@ def _conv_bn_add_relu_extra_inputs_getter_right(pattern): else: conv_configs.append( BackendPatternConfig() - ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, nn.Conv2d))) # noqa: E131 + ._set_pattern_complex_format((nn.ReLU, (add_op, MatchAllNode, nn.Conv2d))) .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_fuser_method(_fuse_conv_add_relu_right) @@ -491,7 +491,7 @@ def _conv_bn_add_relu_extra_inputs_getter_right(pattern): conv_configs.append( BackendPatternConfig(nni.ConvAddReLU2d) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(conv_dtype_configs) .set_root_module(nn.Conv2d) .set_reference_quantized_module(nnqr.Conv2d) @@ -523,14 +523,14 @@ def _add_eltwise_fusion_configs( # 1 base module + op module fusion config configs.append( BackendPatternConfig((root_module, post_module)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(fuser_method) .set_fused_module(fused_module) ) # base module + functional post op configs.append( BackendPatternConfig((root_module, post_op)) - .set_dtype_configs(dtype_configs) # noqa: E131 + .set_dtype_configs(dtype_configs) .set_fuser_method(fuser_method) .set_fused_module(fused_module) ) @@ -538,7 +538,7 @@ def _add_eltwise_fusion_configs( # 2 fused module configs configs.append( BackendPatternConfig(fused_module) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) .set_root_module(root_module) .set_reference_quantized_module(ref_quant_module) @@ -547,12 +547,12 @@ def _add_eltwise_fusion_configs( # 3 functional base op + post op configs configs.append( BackendPatternConfig((root_op, post_module)) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ) configs.append( BackendPatternConfig((root_op, post_op)) - .set_observation_type(observation_type) # noqa: E131 + .set_observation_type(observation_type) .set_dtype_configs(dtype_configs) ) @@ -574,7 +574,7 @@ def _add_eltwise_fusion_configs( # Configs for linear module + batchnorm + leaky_relu linear_configs.append( BackendPatternConfig((nn.Linear, nn.BatchNorm1d, nn.LeakyReLU)) - .set_dtype_configs(linear_dtype_configs) # noqa: E131 + .set_dtype_configs(linear_dtype_configs) .set_fuser_method(_fuse_linear_bn_leaky_relu) .set_fused_module(nni.LinearLeakyReLU) ) diff --git a/torch/ao/quantization/experimental/adaround_optimization.py b/torch/ao/quantization/experimental/adaround_optimization.py index 94aa8aa31009d..4fc0aa289b93c 100644 --- a/torch/ao/quantization/experimental/adaround_optimization.py +++ b/torch/ao/quantization/experimental/adaround_optimization.py @@ -76,12 +76,10 @@ def run_adaround(self) -> torch.nn.Module: # Knowing activation ahead-of-time would be helpful for asymmetric formulation # But this is challenging in eager mode, but graph module. layer_list.append((name, module, q_module)) - print(f"Total number of layers : {len(layer_list)}") # noqa: G004 + print(f"Total number of layers : {len(layer_list)}") for name, module, q_module in layer_list: - print( - f"Kick start adaptive rounding on {name} module {module}" # noqa: G004 - ) + print(f"Kick start adaptive rounding on {name} module {module}") self.optimize_adaptive_rounding( module, q_module, @@ -162,7 +160,7 @@ def _compute_and_display_local_losses( soft_quant_loss = F.mse_loss(out_soft_quant, fp_out) hard_quant_loss = F.mse_loss(out_hard_quant, fp_out) print( - f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}" # noqa: G004 + f"soft quant loss: {soft_quant_loss.item()} hard quant loss: {hard_quant_loss.item()}" ) def optimize_adaptive_rounding( @@ -240,8 +238,8 @@ def optimize_adaptive_rounding( break if iteration % 30 == 0: print( - f"glob iter {global_idx} regularization_loss {regularization_loss.item()} " # noqa: G004 - f"reconstruction_loss {reconstruction_loss.item()}" # noqa: G004 + f"glob iter {global_idx} regularization_loss {regularization_loss.item()} " + f"reconstruction_loss {reconstruction_loss.item()}" ) print("==================== After adaround ====================") self._compute_and_display_local_losses(ada_quantizer, q_module, inp[0], out[0]) diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index a7346daa283a5..95fc679d79ae4 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -3,6 +3,7 @@ import warnings from collections import namedtuple from typing import Any +from typing_extensions import TypeIs import torch import torch.ao.nn.intrinsic as nni @@ -327,7 +328,9 @@ def node_supports_equalization(node: Node, modules) -> bool: return False -def is_equalization_observer(observer: nn.Module) -> bool: +def is_equalization_observer( + observer: nn.Module, +) -> TypeIs[_InputEqualizationObserver | _WeightEqualizationObserver]: return isinstance( observer, (_InputEqualizationObserver, _WeightEqualizationObserver) ) diff --git a/torch/ao/quantization/fx/_lower_to_native_backend.py b/torch/ao/quantization/fx/_lower_to_native_backend.py index 03df30af556f9..8dce0d924b418 100644 --- a/torch/ao/quantization/fx/_lower_to_native_backend.py +++ b/torch/ao/quantization/fx/_lower_to_native_backend.py @@ -438,7 +438,7 @@ def _load_packed_weight( for attr_name in state_dict: if attr_name.startswith("_packed_weight") and isinstance( state_dict[attr_name], torch._C.ScriptObject - ): # type: ignore[attr-defined] # noqa: B950 + ): # type: ignore[attr-defined] setattr(self, attr_name, state_dict[attr_name]) attrs_to_pop.append(attr_name) @@ -534,9 +534,7 @@ def load_arg(a): quantized_model.register_load_state_dict_pre_hook(_load_packed_weight) if keep_original_weights: - setattr( # noqa: B010 - quantized_model, ORIGINAL_WEIGHTS_LOOKUP, original_weights_lookup - ) + setattr(quantized_model, ORIGINAL_WEIGHTS_LOOKUP, original_weights_lookup) return quantized_model diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index e4f42fc43700f..dde8bbaab07b6 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -210,7 +210,7 @@ def add_dequantize_op_kwargs(dequantize_op, input_node): # TODO: we can add the information of whether a value needs to # be registered as an attribute in qparams dict itself if key in ["_scale_", "_zero_point_"] and ( - not isinstance(value_or_node, (float, int)) # noqa: UP038 + not isinstance(value_or_node, (float, int)) ): # For scale and zero_point values we register them as buffers in the root module. # However, note that when the values are not tensors, as in the case of @@ -596,7 +596,7 @@ def _maybe_recursive_remove_dequantize(arg: Any, node: Node, graph: Graph) -> No # we only replace the specific use since dequantize could be used by other nodes # as well node.replace_input_with(arg, quantize_node) - elif isinstance(arg, (list, tuple)): # noqa: UP038 + elif isinstance(arg, (list, tuple)): for arg_element in arg: _maybe_recursive_remove_dequantize(arg_element, node, graph) elif isinstance(arg, dict): @@ -838,7 +838,7 @@ def convert_weighted_module( "weight_hh": weight_qparams_hh, } ) - elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)): # noqa: UP038 + elif isinstance(float_module, (torch.nn.LSTM, torch.nn.GRU)): # format for wq_or_wq_dict (flattened attributes): # {"weight_ih_l0_scale": ..., "weight_ih_l0_qscheme": ..., ...} for wn in float_module._flat_weights_names: @@ -1217,12 +1217,12 @@ def convert( return_node = node output = node.args[0] # outputs can be Node, list, tuple, dict, other cases are not supported yet - if isinstance(output, (list, tuple)): # noqa: UP038 + if isinstance(output, (list, tuple)): for idx in output_quantized_idxs: _maybe_recursive_remove_dequantize( output[idx], return_node, model.graph ) - elif isinstance(output, (Node, dict)): # noqa: UP038 + elif isinstance(output, (Node, dict)): # we treat dict as a single argument currently, but it can be extended # to support {"key": dtype} after we change output_quantized_idxs to # dict diff --git a/torch/ao/quantization/fx/graph_module.py b/torch/ao/quantization/fx/graph_module.py index 87ec3179a68ee..eaeebd64f3002 100644 --- a/torch/ao/quantization/fx/graph_module.py +++ b/torch/ao/quantization/fx/graph_module.py @@ -177,7 +177,7 @@ def _load_from_state_dict( for attr_name in state_dict: if attr_name.startswith("_packed_weight") and isinstance( state_dict[attr_name], torch._C.ScriptObject - ): # type: ignore[attr-defined] # noqa: B950 + ): # type: ignore[attr-defined] setattr(self, attr_name, state_dict[attr_name]) attrs_to_pop.append(attr_name) diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index 71793a026fa9b..e3756af513737 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -182,7 +182,7 @@ def _is_input_arg_dtype_supported_by_backend( """Check if the configured qconfig for the argument is supported by the backend or not """ - if isinstance(arg, (list, tuple)): # noqa: UP038 + if isinstance(arg, (list, tuple)): return all( _is_input_arg_dtype_supported_by_backend( a, @@ -396,7 +396,7 @@ def _qat_swap_modules( def _add_matched_node_name_to_set(matched_node_pattern: NodePattern, s: set[str]): if isinstance(matched_node_pattern, Node): s.add(matched_node_pattern.name) - elif isinstance(matched_node_pattern, (list, tuple)): # noqa: UP038 + elif isinstance(matched_node_pattern, (list, tuple)): for maybe_node in matched_node_pattern: _add_matched_node_name_to_set(maybe_node, s) @@ -446,7 +446,7 @@ def _set_target_dtype_info_for_matched_node_pattern( """Sets the target_dtype_info for each node in matched_node_pattern Note: processed_nodes is used to ensure we only process each node once """ - if isinstance(matched_node_pattern, (list, tuple)): # noqa: UP038 + if isinstance(matched_node_pattern, (list, tuple)): for node_pattern in matched_node_pattern: _set_target_dtype_info_for_matched_node_pattern( node_pattern, @@ -734,7 +734,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( """ # for ops such as torch.cat([x0, x1]), # traverse through the list - if isinstance(arg, (list, tuple)): # noqa: UP038 + if isinstance(arg, (list, tuple)): new_arg_to_return = [] for inner_arg in arg: new_inner_arg = _maybe_insert_input_observer_for_arg_or_kwarg( @@ -1155,7 +1155,7 @@ def _recursive_maybe_replace_node_with_obs( return observer_node else: return maybe_node - elif isinstance(maybe_node, (list, tuple)): # noqa: UP038 + elif isinstance(maybe_node, (list, tuple)): results = [ _recursive_maybe_replace_node_with_obs( inner_node, @@ -1247,7 +1247,7 @@ def propagate_dtypes_for_known_nodes( # when an argument is a tuple, it does not show up as another node so we need to go through # all elements of the tuple manually - if isinstance(arg, (tuple, list)): # noqa: UP038 + if isinstance(arg, (tuple, list)): arg_list = list(arg) else: arg_list = [arg] @@ -1282,7 +1282,7 @@ def _maybe_make_input_output_share_observers( first_arg = None # find the first non-Tensor arg for i in range(len(node.args)): - if isinstance(node.args[i], (Node, list, tuple)): # noqa: UP038 + if isinstance(node.args[i], (Node, list, tuple)): first_arg = node.args[i] break @@ -1290,7 +1290,7 @@ def _maybe_make_input_output_share_observers( if first_arg is None: return False - if isinstance(first_arg, (list, tuple)): # noqa: UP038 + if isinstance(first_arg, (list, tuple)): first_arg_arg = first_arg[0] elif isinstance(first_arg, Node): first_arg_arg = first_arg @@ -1329,7 +1329,7 @@ def _maybe_make_input_output_share_observers( raise AssertionError("target_to_use must be a string") obs_mod_to_use = named_modules[target_to_use] - if isinstance(first_arg, (list, tuple)): # noqa: UP038 + if isinstance(first_arg, (list, tuple)): # set all other input observer nodes to use that module for input_idx, input_arg in enumerate(first_arg): if input_idx == 0: diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 0521af9d1123c..858f44d8708cd 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -144,7 +144,7 @@ def _with_callable_args(cls_or_self, **kwargs): return r.with_callable_args(**kwargs) -ABC: Any = ABCMeta("ABC", (object,), {}) # compatible with Python 2 *and* 3: +ABC: Any = ABCMeta("ABC", (object,), {}) class ObserverBase(ABC, nn.Module): diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py index ba6ab86aaa048..30d29e31ee188 100644 --- a/torch/ao/quantization/quantize_fx.py +++ b/torch/ao/quantization/quantize_fx.py @@ -10,9 +10,9 @@ from .backend_config import BackendConfig, get_tensorrt_backend_config # noqa: F401 from .fx.convert import convert from .fx.custom_config import ConvertCustomConfig, FuseCustomConfig, PrepareCustomConfig -from .fx.fuse import fuse # noqa: F401 +from .fx.fuse import fuse from .fx.graph_module import ObservedGraphModule # noqa: F401 -from .fx.prepare import prepare # noqa: F401 +from .fx.prepare import prepare from .fx.tracer import QuantizationTracer, Scope, ScopeContextManager # noqa: F401 from .fx.utils import ( # noqa: F401 get_custom_module_class_keys, diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py index d6453402e8d59..75ee857c4af0f 100644 --- a/torch/autograd/__init__.py +++ b/torch/autograd/__init__.py @@ -547,6 +547,7 @@ def vjp(gO): result = tuple( output if output is not None + # pyrefly: ignore [bad-argument-type] else torch.zeros_like(input, requires_grad=create_graph) for (output, input) in zip(result, inputs) ) @@ -570,7 +571,7 @@ def _is_checkpoint_valid(): return Variable._execution_engine.is_checkpoint_valid() -def variable(*args, **kwargs): # noqa: D103 +def variable(*args, **kwargs): raise RuntimeError( "torch.autograd.variable(...) is deprecated, use torch.tensor(...) instead" ) diff --git a/torch/autograd/anomaly_mode.py b/torch/autograd/anomaly_mode.py index 0277f1b75541f..337ab539a92e6 100644 --- a/torch/autograd/anomaly_mode.py +++ b/torch/autograd/anomaly_mode.py @@ -76,7 +76,7 @@ class detect_anomaly: """ - def __init__(self, check_nan=True) -> None: # noqa: D107 + def __init__(self, check_nan=True) -> None: self.prev = torch.is_anomaly_enabled() self.check_nan = check_nan self.prev_check_nan = torch.is_anomaly_check_nan_enabled() @@ -87,10 +87,10 @@ def __init__(self, check_nan=True) -> None: # noqa: D107 stacklevel=2, ) - def __enter__(self) -> None: # noqa: D105 + def __enter__(self) -> None: torch.set_anomaly_enabled(True, self.check_nan) - def __exit__(self, *args: object) -> None: # noqa: D105 + def __exit__(self, *args: object) -> None: torch.set_anomaly_enabled(self.prev, self.prev_check_nan) @@ -111,13 +111,13 @@ class set_detect_anomaly: """ - def __init__(self, mode: bool, check_nan: bool = True) -> None: # noqa: D107 + def __init__(self, mode: bool, check_nan: bool = True) -> None: self.prev = torch.is_anomaly_enabled() self.prev_check_nan = torch.is_anomaly_check_nan_enabled() torch.set_anomaly_enabled(mode, check_nan) - def __enter__(self) -> None: # noqa: D105 + def __enter__(self) -> None: pass - def __exit__(self, *args: object) -> None: # noqa: D105 + def __exit__(self, *args: object) -> None: torch.set_anomaly_enabled(self.prev, self.prev_check_nan) diff --git a/torch/autograd/forward_ad.py b/torch/autograd/forward_ad.py index ce3eb94266271..355538c54d311 100644 --- a/torch/autograd/forward_ad.py +++ b/torch/autograd/forward_ad.py @@ -29,8 +29,12 @@ def enter_dual_level(): This function also updates the current level that is used by default by the other functions in this API. """ + from torch._functorch.predispatch import ( + _enter_dual_level as _predispatch_enter_dual_level, + ) + global _current_level - new_level = torch._C._enter_dual_level() + new_level = _predispatch_enter_dual_level() if new_level != _current_level + 1: raise RuntimeError( "Entering a new forward AD level but the current level " @@ -57,7 +61,11 @@ def exit_dual_level(*, level=None): "Trying to exit a forward AD level that was not the last one " "that was created. This is not supported." ) - torch._C._exit_dual_level(level=level) + from torch._functorch.predispatch import ( + _exit_dual_level as _predispatch_exit_dual_level, + ) + + _predispatch_exit_dual_level(level=level) _current_level = level - 1 @@ -125,7 +133,9 @@ def make_dual(tensor, tangent, *, level=None): f"Expected tangent to be floating point or complex, but got: {tangent.dtype}" ) - return torch._VF._make_dual(tensor, tangent, level=level) + from torch._functorch.predispatch import _make_dual as _predispatch_make_dual + + return _predispatch_make_dual(tensor, tangent, level=level) class UnpackedDualTensor(NamedTuple): @@ -165,7 +175,9 @@ def unpack_dual(tensor, *, level=None): if level < 0: return UnpackedDualTensor(tensor, None) - primal, dual = torch._VF._unpack_dual(tensor, level=level) + from torch._functorch.predispatch import _unpack_dual as _predispatch_unpack_dual + + primal, dual = _predispatch_unpack_dual(tensor, level=level) return UnpackedDualTensor(primal, dual) diff --git a/torch/autograd/function.py b/torch/autograd/function.py index 4c639fd79b280..75c649335928d 100644 --- a/torch/autograd/function.py +++ b/torch/autograd/function.py @@ -299,12 +299,7 @@ class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin): This class is used for internal autograd work. Do not use. """ - def apply(self, *args): - r""" - Apply method used when executing this Node during the backward - """ - # _forward_cls is defined by derived class - # The user should define either backward or vjp but never both. + def _get_user_fn(self): backward_fn = self._forward_cls.backward # type: ignore[attr-defined] vjp_fn = self._forward_cls.vjp # type: ignore[attr-defined] if backward_fn is not Function.backward and vjp_fn is not Function.vjp: @@ -313,9 +308,30 @@ def apply(self, *args): "Function is not allowed. You should only implement one " "of them." ) - user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn + return vjp_fn if vjp_fn is not Function.vjp else backward_fn + + def apply(self, *args): + r""" + Apply method used when executing this Node during the backward. + + Called by the autograd engine (non-boxed path) and by direct + grad_fn.apply() calls. When boxed_grads_call is True, boxes + grads into a mutable list before calling user's backward. + """ + user_fn = self._get_user_fn() + fwd_cls = self._forward_cls # type: ignore[attr-defined] # pyrefly: ignore[missing-attribute] + if getattr(fwd_cls, "boxed_grads_call", False): + args = (list(args),) return user_fn(self, *args) + def apply_boxed(self, *args): + r""" + Apply method called by the autograd engine when boxed_grads_call + is True. Grads arrive as a single mutable list argument, allowing + backward to free individual grads mid-execution. + """ + return self._get_user_fn()(self, *args) + def apply_jvp(self, *args): r""" Apply method used when executing forward mode AD during the forward @@ -458,6 +474,19 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any: """ clear_saved_tensors_on_access = False + """ + Bool that specifies if backward should receive grads as a single mutable + list argument instead of individual args in an immutable tuple. This allows + backward to free individual grads mid-execution by removing them from the + list, reducing peak memory. + + When True, ``backward(ctx, grads)`` receives a single list instead of + ``backward(ctx, *grads)``. + + Default is False. + """ + boxed_grads_call = False + @staticmethod def jvp(ctx: Any, *grad_inputs: Any) -> Any: r"""Define a formula for differentiating the operation with forward mode automatic differentiation. diff --git a/torch/autograd/functional.py b/torch/autograd/functional.py index 18e7716c6de10..b39e69cd1ca7f 100644 --- a/torch/autograd/functional.py +++ b/torch/autograd/functional.py @@ -568,7 +568,7 @@ def jvp(tangents): # batch dimension represents that of the inputs jacobian_input_i_output_j = jac.permute(*range(1, jac.ndim), 0).reshape( (*output_i.shape, *input_j.shape) - ) # noqa: C409 + ) jacobian_output_i_output.append(jacobian_input_i_output_j) jacobian_input_output.append(jacobian_output_i_output) diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index 3d1c2a7269338..0775d43994283 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -376,11 +376,15 @@ class _force_original_view_tracking(_DecoratorContextManager): def __init__(self, mode: bool) -> None: self.prev = torch._C._is_view_replay_enabled() - torch._C._set_view_replay_enabled(mode) self.mode = mode + torch._C._set_view_replay_enabled(mode) + + def __call__(self, orig_func: F) -> F: + torch._C._set_view_replay_enabled(self.prev) + return super().__call__(orig_func) def __enter__(self) -> None: - pass + torch._C._set_view_replay_enabled(self.mode) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: torch._C._set_view_replay_enabled(self.prev) diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index 4d5a4585b03bd..15a00af6f1b4a 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -2013,7 +2013,7 @@ def gradcheck( check_backward_ad: bool = True, fast_mode: bool = False, masked: bool | None = None, -) -> bool: # noqa: D400,D205 +) -> bool: r"""Check gradients computed via small finite differences against analytical gradients wrt tensors in :attr:`inputs` that are of floating point or complex type and with ``requires_grad=True``. @@ -2182,7 +2182,7 @@ def gradgradcheck( check_rev_over_rev: bool = True, fast_mode: bool = False, masked: bool = False, -) -> bool: # noqa: D400,D205 +) -> bool: r"""Check gradients of gradients computed via small finite differences against analytical gradients wrt tensors in :attr:`inputs` and :attr:`grad_outputs` that are of floating point or complex type and with diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py index 55e9a9974c3fa..0c285fbe7b7c4 100644 --- a/torch/autograd/graph.py +++ b/torch/autograd/graph.py @@ -1,5 +1,6 @@ import abc import contextlib +import contextvars import functools import logging import threading @@ -873,6 +874,10 @@ def _engine_run_backward( attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG if attach_logging_hooks: unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs) + + # Need to save the context so compiler config will be visible in device threads + torch._C._stash_obj_in_tls("context", contextvars.copy_context()) + try: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass t_outputs, *args, **kwargs @@ -880,3 +885,4 @@ def _engine_run_backward( finally: if attach_logging_hooks: unregister_hooks() # type: ignore[possibly-undefined] + torch._C._stash_obj_in_tls("context", None) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 37a4aaa1d8ead..d067c8fc4d0a4 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -485,9 +485,6 @@ def _ensure_function_events(self): self._function_events.append(evt) self._old_function_events = None - if self._function_events is None: - raise RuntimeError("Profiler didn't finish running") - @property def function_events(self): if self._function_events is None or self._needs_processing: @@ -588,6 +585,7 @@ def _parse_kineto_results( # result.events() has most of the events - PyTorch op-level and device-level events timeout_ns = int(timeout_s * 1e9) if timeout_s is not None else None + result_events = result.events() if timeout_ns is not None and timeout_ns < 0: raise ValueError("timeout_s must be non-negative") start_time_ns = perf_counter_ns() @@ -604,10 +602,10 @@ def _check_timeout() -> bool: trace_start_ns = result.trace_start_ns() mem_records = [ - [evt, False] for evt in result.events() if evt.name() == MEMORY_EVENT_NAME + [evt, False] for evt in result_events if evt.name() == MEMORY_EVENT_NAME ] oom_records = [ - evt for evt in result.events() if evt.name() == OUT_OF_MEMORY_EVENT_NAME + evt for evt in result_events if evt.name() == OUT_OF_MEMORY_EVENT_NAME ] mem_records_acc = MemRecordsAcc(mem_records) @@ -641,7 +639,7 @@ def _device_memory_usage(mem_record): frontend_function_events = [] device_corr_map: dict[int, list[FunctionEvent]] = {} max_evt_id = 0 - for kineto_event in result.events(): + for kineto_event in result_events: if _check_timeout(): break @@ -652,14 +650,13 @@ def _device_memory_usage(mem_record): continue rel_start_ns = kineto_event.start_ns() - trace_start_ns rel_end_ns = kineto_event.end_ns() - trace_start_ns - abs_end_ns = kineto_event.end_ns() cpu_memory_usage = 0 device_memory_usage = 0 if kineto_event.device_type() == DeviceType.CPU: # find the corresponding memory allocation events for mem_record in mem_records_acc.in_interval( - kineto_event.start_ns(), abs_end_ns + kineto_event.start_ns(), kineto_event.end_ns() ): cpu_memory_usage += _cpu_memory_usage(mem_record[0]) device_memory_usage += _device_memory_usage(mem_record[0]) @@ -697,7 +694,21 @@ def _device_memory_usage(mem_record): device_resource_id=kineto_event.device_resource_id(), flops=kineto_event.flops(), is_user_annotation=kineto_event.is_user_annotation(), + is_python_function=kineto_event.is_python_function(), + activity_type=kineto_event.activity_type(), metadata_json=kineto_event.metadata_json(), + extra_meta=kineto_event.extra_meta() or None, + flow_id=kineto_event.flow_id(), + flow_type=kineto_event.flow_type(), + flow_start=kineto_event.flow_start(), + external_id=kineto_event.external_id(), + linked_correlation_id=kineto_event.linked_correlation_id(), + structured_input_shapes=kineto_event.structured_input_shapes(), + structured_input_strides=kineto_event.structured_input_strides(), + input_dtypes=kineto_event.dtypes(), + python_id=kineto_event.python_id(), + python_parent_id=kineto_event.python_parent_id(), + python_module_id=kineto_event.python_module_id(), ) max_evt_id = max(max_evt_id, fe.id) if fe.device_type == DeviceType.CPU and not fe.is_async: @@ -749,7 +760,7 @@ def _device_memory_usage(mem_record): # parents and children f_evt.thread = fe.thread - def createFunctionEventForMemoryEvents(evt): + def _create_function_event_for_memory_events(evt): rel_start_ns = evt.start_ns() - trace_start_ns fe = FunctionEvent( id=max_evt_id, @@ -780,7 +791,7 @@ def createFunctionEventForMemoryEvents(evt): if not mem_record[1]: max_evt_id += 1 - fe = createFunctionEventForMemoryEvents(mem_record[0]) + fe = _create_function_event_for_memory_events(mem_record[0]) all_function_events.append(fe) for oom_record in oom_records: @@ -788,7 +799,7 @@ def createFunctionEventForMemoryEvents(evt): break max_evt_id += 1 - fe = createFunctionEventForMemoryEvents(oom_record) + fe = _create_function_event_for_memory_events(oom_record) all_function_events.append(fe) if timed_out: diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index c51ac870aa37b..15815e8b19285 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1,10 +1,12 @@ # mypy: allow-untyped-defs import bisect import itertools +import json import math from collections import defaultdict, namedtuple +from collections.abc import Callable from operator import attrgetter -from typing import Any +from typing import Any, NamedTuple from typing_extensions import deprecated import torch @@ -13,6 +15,7 @@ __all__ = [ "EventList", + "EventMetadata", "FormattedTimesMixin", "Interval", "Kernel", @@ -549,6 +552,104 @@ def elapsed_us(self): Kernel = namedtuple("Kernel", ["name", "device", "duration"]) +class EventMetadata(NamedTuple): + # Kernel fields + registers_per_thread: int | None + shared_memory: int | None + grid: list[int] | None + block: list[int] | None + priority: int | None + blocks_per_sm: float | None + warps_per_sm: float | None + occupancy: dict[str, Any] | None + est_occupancy_pct: float | None + queued: int | None + graph_id: int | None + graph_node_id: int | None + stream: int | None + context: int | None + # Memory fields + bytes: int | None + bandwidth_gb_s: float | None + # NCCL fields + collective_name: str | None + dtype: str | None + in_msg_nelems: int | None + out_msg_nelems: int | None + in_split_size: str | None + out_split_size: str | None + global_rank_start: int | None + global_rank_stride: int | None + group_size: int | None + process_group_name: str | None + process_group_desc: str | None + group_ranks: str | None + rank: int | None + src_rank: int | None + dst_rank: int | None + seq: int | None + is_async: bool | None + + +def _to_str(v: str) -> str: + return v.strip('"') + + +def _to_bool(v: str) -> bool: + return v in ("1", "true") + + +# Kineto key → (EventMetadata field name, converter from string) +_EVENT_METADATA_KEYS: dict[str, tuple[str, Callable[[str], Any]]] = { + "registers per thread": ("registers_per_thread", int), + "shared memory": ("shared_memory", int), + "grid": ("grid", json.loads), # list[int] + "block": ("block", json.loads), # list[int] + "priority": ("priority", int), + "blocks per SM": ("blocks_per_sm", float), + "warps per SM": ("warps_per_sm", float), + "est. achieved occupancy %": ("est_occupancy_pct", float), + "occupancy": ("occupancy", json.loads), # dict[str, Any] + "queued": ("queued", int), + "graph id": ("graph_id", int), + "graph node id": ("graph_node_id", int), + "stream": ("stream", int), + "context": ("context", int), + "bytes": ("bytes", int), + "memory bandwidth (GB/s)": ("bandwidth_gb_s", float), + "Collective name": ("collective_name", _to_str), + "dtype": ("dtype", _to_str), + "In msg nelems": ("in_msg_nelems", int), + "Out msg nelems": ("out_msg_nelems", int), + "In split size": ("in_split_size", _to_str), + "Out split size": ("out_split_size", _to_str), + "Global rank start": ("global_rank_start", int), + "Global rank stride": ("global_rank_stride", int), + "Group size": ("group_size", int), + "Process Group Name": ("process_group_name", _to_str), + "Process Group Description": ("process_group_desc", _to_str), + "Process Group Ranks": ("group_ranks", _to_str), + "Rank": ("rank", int), + "Src Rank": ("src_rank", int), + "Dst Rank": ("dst_rank", int), + "Seq": ("seq", int), + "Is asynchronized op": ("is_async", _to_bool), +} + + +def _build_metadata(extra_meta): + fields: dict[str, Any] = {} + any_populated = False + for kineto_key, (field_name, convert) in _EVENT_METADATA_KEYS.items(): + v = extra_meta.get(kineto_key) + if v is not None: + fields[field_name] = convert(v) + any_populated = True + else: + fields[field_name] = None + return EventMetadata(**fields) if any_populated else None + + class FunctionEvent(FormattedTimesMixin): """Profiling information about a single function. @@ -573,7 +674,10 @@ class FunctionEvent(FormattedTimesMixin): count (int): Number of times this event was called (usually 1). cpu_children (List[FunctionEvent]): Direct CPU child operations. cpu_parent (FunctionEvent): Direct CPU parent operation. - input_shapes (Tuple[int, ...]): Shapes of input tensors (requires record_shapes=true). + input_shapes (List[List[int]]): Shapes of input tensors (requires record_shapes=True). + For plain tensor inputs, each entry is a list of dimensions (e.g. ``[16, 16]``). + TensorList inputs are represented as an empty list ``[]``; use + ``structured_input_shapes`` to get per-element shapes for TensorList inputs. concrete_inputs (List[Any]): Concrete input values (requires record_shapes=true). kwinputs (Dict[str, Any]): Keyword arguments (requires record_shapes=true). stack (List[str]): Python stack trace where the operation was called (requires with_stack=true). @@ -590,7 +694,15 @@ class FunctionEvent(FormattedTimesMixin): is_legacy (bool): Whether this is from the legacy profiler. flops (int): Estimated floating point operations. is_user_annotation (bool): Whether this is a user-annotated region. - metadata_json (str): Additional metadata in JSON format. + metadata_json (str): Deprecated. Use event_metadata instead. + event_metadata (EventMetadata): Additional metadata in structured format. + structured_input_shapes (List[List[int] | List[List[int]]]): Like ``input_shapes`` + but distinguishes TensorList inputs. Plain tensor inputs are ``List[int]``; + TensorList inputs are ``List[List[int]]`` containing one shape per tensor in the list. + Matches the ``"Input Dims"`` field in the Chrome trace JSON. + structured_input_strides (List[List[int] | List[List[int]]]): Strides of input + tensors in the same format as ``structured_input_shapes`` (requires + record_shapes=True). Properties: cpu_time_total (float): Total CPU time in microseconds. @@ -637,7 +749,21 @@ def __init__( concrete_inputs=None, kwinputs=None, is_user_annotation=False, + is_python_function=False, + activity_type=None, metadata_json=None, + flow_id=None, + flow_type=None, + flow_start=None, + external_id=0, + linked_correlation_id=0, + extra_meta=None, + structured_input_shapes=None, + structured_input_strides=None, + input_dtypes=None, + python_id=-1, + python_parent_id=-1, + python_module_id=-1, ): self.id: int = id self.node_id: int = node_id @@ -676,10 +802,29 @@ def __init__( self.is_legacy: bool = is_legacy self.flops: int | None = flops self.is_user_annotation: bool | None = is_user_annotation + self.is_python_function: bool = is_python_function + self.activity_type: str | None = activity_type self.self_cpu_percent = -1 self.total_cpu_percent = -1 self.total_device_percent = -1 - self.metadata_json = metadata_json + self._metadata_json = metadata_json + self.flow_id: int | None = flow_id + self.flow_type: int | None = flow_type + self.flow_start: bool | None = flow_start + self.external_id: int = external_id + self.linked_correlation_id: int = linked_correlation_id + self.event_metadata: EventMetadata | None = ( + _build_metadata(extra_meta) if extra_meta else None + ) + # pyrefly: ignore [bad-assignment] + self.structured_input_shapes: list = structured_input_shapes + # pyrefly: ignore [bad-assignment] + self.structured_input_strides: list = structured_input_strides + # pyrefly: ignore [bad-assignment] + self.input_dtypes: list[str] = input_dtypes + self.python_id: int = python_id + self.python_parent_id: int = python_parent_id + self.python_module_id: int = python_module_id def append_kernel(self, name, device, duration): if self.device_type != DeviceType.CPU: @@ -733,6 +878,14 @@ def self_device_memory_usage(self): child.device_memory_usage for child in self.cpu_children ) + @property + @deprecated( + "`metadata_json` is deprecated. Use `event_metadata` instead.", + category=FutureWarning, + ) + def metadata_json(self): + return self._metadata_json + @property @deprecated( "`self_cuda_memory_usage` is deprecated. Use `self_device_memory_usage` instead.", @@ -1160,7 +1313,6 @@ def _build_table( if append_node_id: headers.append("Node ID") - # Have to use a list because nonlocal is Py3 only... SPACING_SIZE = 2 row_format_lst = [""] header_sep_lst = [""] @@ -1221,7 +1373,6 @@ def auto_scale_flops(flops): line_length = line_length_lst[0] add_column = None # type: ignore[assignment] - # Have to use a list because nonlocal is Py3 only... result = [] def append(s): diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index f54a3fd6820c7..f07a4797c64da 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -135,5 +135,6 @@ class GenericModule(PropModule): nnpack as nnpack, openmp as openmp, opt_einsum as opt_einsum, + python_native as python_native, quantized as quantized, ) diff --git a/torch/backends/_nnapi/serializer.py b/torch/backends/_nnapi/serializer.py index 0be813c7610d1..386cb664cbb4d 100644 --- a/torch/backends/_nnapi/serializer.py +++ b/torch/backends/_nnapi/serializer.py @@ -246,12 +246,12 @@ def broadcast_shapes(shape1, shape2): # s2 = [1] * (len(s1) - len(s2)) + s2 raise Exception( # noqa: TRY002 "Non-equal-rank broadcast is not supported yet." - ) # noqa: TRY002 + ) if len(s2) > len(s1): # s3 = [1] * (len(s2) - len(s1)) + s1 raise Exception( # noqa: TRY002 "Non-equal-rank broadcast is not supported yet." - ) # noqa: TRY002 + ) ret = [] for d1, d2 in zip(s1, s2): if d1 == 1: @@ -263,7 +263,7 @@ def broadcast_shapes(shape1, shape2): else: raise Exception( # noqa: TRY002 f"Cannot broadcast shapes: {shape1} and {shape2}" - ) # noqa: TRY002 + ) return tuple(ret) @@ -420,7 +420,7 @@ def torch_tensor_to_operand(self, tensor, dim_order): else: raise Exception( # noqa: TRY002 f"Can't handle input with dtype '{tensor.dtype}'" - ) # noqa: TRY002 + ) return Operand( shape=tuple(tensor.shape), # pyrefly: ignore [bad-argument-type] @@ -512,7 +512,7 @@ def get_tensor_operand_by_jitval_fixed_size(self, jitval): # many callsites to support flexible size. raise Exception( # noqa: TRY002 "Flexible size is not supported for this operand." - ) # noqa: TRY002 + ) if s < 0: # runtime flex LOG.warning("Operand %s has runtime flex shape", oper) @@ -551,7 +551,7 @@ def get_constant_value(self, jitval, typekind=None): if record is None: raise Exception( # noqa: TRY002 f"Could not find constant value for '{jitval!r}'." - ) # noqa: TRY002 + ) ctype, _ = record if typekind is not None and ctype.kind() != typekind: raise Exception( # noqa: TRY002 @@ -583,7 +583,7 @@ def operand_to_template_torchscript(self, op_id, oper, shape=None): else: raise Exception( # noqa: TRY002 "Unknown dim value, dimensions should be >= -1" - ) # noqa: TRY002 + ) shape_parts.append(",") shape_parts.append(")") shape_code = "".join(shape_parts) @@ -611,7 +611,7 @@ def operand_to_template_torchscript(self, op_id, oper, shape=None): raise Exception( # noqa: TRY002 f"Unsupported output operand type: {oper.op_type}" - ) # noqa: TRY002 + ) def forward_operand_shape(self, out_op_id, out_dim, in_op_id, in_dim): self.compute_operand_shape(out_op_id, out_dim, flex_name(in_op_id, in_dim)) @@ -625,7 +625,7 @@ def transpose_to_nhwc(self, in_id, oper): if oper.shape[2:] != (1, 1): raise Exception( # noqa: TRY002 "Automatic transpose only supported for H,W == 1,1" - ) # noqa: TRY002 + ) out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST) @@ -666,7 +666,7 @@ def get_size_arg(self, jitval): return value raise Exception( # noqa: TRY002 f"Can't handle size arg of type '{ctype!r}' for '{jitval!r}'" - ) # noqa: TRY002 + ) def get_conv_pool_args_2d_from_pack(self, kernel_size, packed_config): pc = [i.item() for i in packed_config] @@ -770,7 +770,7 @@ def serialize_model(self, model, inputs, return_shapes=None): else: raise Exception( # noqa: TRY002 f"Unsupported return type: {retn_input.type()}" - ) # noqa: TRY002 + ) if return_shapes is not None: if len(return_shapes) != len(return_values): @@ -955,7 +955,7 @@ def add_node(self, node): if not adder: raise Exception( # noqa: TRY002 f"Unsupported node kind ({node.kind()!r}) in node {node!r}" - ) # noqa: TRY002 + ) adder(self, node) def _identity(self, node): @@ -1155,7 +1155,7 @@ def add_flatten(self, node): if any(dim == 0 for dim in in_oper.shape[start_dim : end_dim + 1]): raise Exception( # noqa: TRY002 "Flattening flexible dims is not supported yet" - ) # noqa: TRY002 + ) non_flattened_dims = in_oper.shape[:start_dim] + in_oper.shape[end_dim + 1 :] if non_flattened_dims.count(0) > 1: raise Exception("Only 1 dim can be flexible") # noqa: TRY002 @@ -1220,7 +1220,7 @@ def add_slice(self, node): if start_value >= stop_value: raise Exception( # noqa: TRY002 "Slice start value should be less than stop value" - ) # noqa: TRY002 + ) out_len = (stop_value - start_value) // step_value out_shape = tuple( @@ -1513,7 +1513,7 @@ def add_pointwise_simple_unary_op(self, node, opcode): self.add_operation(opcode, inputs, outputs) - def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None): # noqa: D401 + def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None): """Helper for pointwise binary broadcast ops with superfluous extra args.""" if node.outputsSize() != 1: raise AssertionError( @@ -1542,7 +1542,7 @@ def _do_add_binary(self, node, opcode, fuse_code, *, qparams=None): # noqa: D40 else: raise Exception( # noqa: TRY002 f"Can't do a NNAPI binary op: {opcode} on two constants" - ) # noqa: TRY002 + ) if in0_oper.op_type != in1_oper.op_type: raise AssertionError( @@ -1597,7 +1597,7 @@ def add_add_sub_op(self, node, opcode, fuse_code): if alpha != 1: raise Exception( # noqa: TRY002 "NNAPI does not support add/sub with alpha." - ) # noqa: TRY002 + ) self._do_add_binary(node, opcode, fuse_code) @@ -1661,7 +1661,7 @@ def add_hardtanh(self, node): if opcode is None: raise Exception( # noqa: TRY002 "NNAPI only supports hardtanh with args (-1, 1) or (0, 6)." - ) # noqa: TRY002 + ) inputs = [None] * 1 inputs[0] = in_id @@ -1712,7 +1712,7 @@ def add_prelu_op(self, node): elif dim <= 1: raise Exception( # noqa: TRY002 "PReLU requires fixed size for dim 0 and dim 1." - ) # noqa: TRY002 + ) else: self.forward_operand_shape(out_id, dim, in_id, dim) @@ -2033,7 +2033,7 @@ def add_upsample_nearest2d(self, node): else: raise Exception( # noqa: TRY002 "Size and scale cannot both be None." - ) # noqa: TRY002 + ) inputs = [None] * 4 inputs[0] = image_id @@ -2507,7 +2507,7 @@ def add_conv2d_common( else: raise Exception( # noqa: TRY002 f"Unsupported input type for conv2d: {image_oper.op_type}" - ) # noqa: TRY002 + ) if len(image_oper.shape) != 4: raise AssertionError( diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index 536cd4fa68391..8094607542e49 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -14,7 +14,11 @@ "cuBLASModule", "preferred_linalg_library", "preferred_blas_library", + "cublas_workspace_size", + "cublaslt_workspace_size", + "blas_workspace_size", "preferred_rocm_fa_library", + "is_ck_sdpa_available", "cufft_plan_cache", "matmul", "SDPAParams", @@ -329,6 +333,109 @@ def preferred_blas_library( return torch._C._get_blas_preferred_backend() +def cublas_workspace_size(size: None | int = None) -> int: + r"""Query or set the cuBLAS workspace size in bytes. + + When called with no arguments, returns the current workspace size. + When called with a size argument, sets the workspace size and returns the new value. + Setting the workspace size will take precedence over the CUBLAS_WORKSPACE_CONFIG environment variable. + Changes take effect lazily: only handles used after the change get new workspaces. + + Args: + size (int, optional): workspace size in bytes. Must be non-negative. + + Returns: + int: the current (or newly set) workspace size in bytes. + """ + if size is not None: + torch._C._cuda_setCublasWorkspaceSize(size) + return torch._C._cuda_getCublasWorkspaceSize() + + +def cublaslt_workspace_size(size: None | int = None) -> int: + r"""Query or set the cuBLASLt workspace size in bytes. + + When called with no arguments, returns the current workspace size. + When called with a size argument, sets the workspace size and returns the new value. + Setting the workspace size will take precedence over the CUBLASLT_WORKSPACE_SIZE environment variable. + Changes take effect lazily: only handles used after the change get new workspaces. + + Args: + size (int, optional): workspace size in bytes. Must be non-negative. + + Returns: + int: the current (or newly set) workspace size in bytes. + """ + if size is not None: + torch._C._cuda_setCublasLtWorkspaceSize(size) + return torch._C._cuda_getCublasLtWorkspaceSize() + + +def blas_workspace_size( + size: None | int = None, + backend: None | str | torch._C._BlasBackend = None, +) -> int: + r"""Query or set the BLAS workspace size for a given backend. + + Convenience wrapper that dispatches to :func:`cublas_workspace_size` or + :func:`cublaslt_workspace_size` depending on the backend. + + When *backend* is ``None`` the current :func:`preferred_blas_library` is + used. ``Default`` is resolved to the platform's default backend (cuBLAS + on NVIDIA, potentially hipBLASLt on supported ROCm architectures). + + .. note:: + + When ``TORCH_CUBLASLT_UNIFIED_WORKSPACE`` is enabled (the default on + open-source CUDA builds), the cuBLASLt workspace is capped at the + cuBLAS workspace size and physically reuses the same allocation. + Setting a large cuBLASLt workspace via this function will therefore + *not* increase memory beyond the cuBLAS workspace size. + + .. note:: + + Setting the workspace size for the cublas backend will take precedence + over the CUBLAS_WORKSPACE_CONFIG environment variable, and setting the + workspace size for the cublaslt backend will take precedence over the + CUBLASLT_WORKSPACE_SIZE environment variable. + + Args: + size (int, optional): workspace size in bytes. Must be non-negative. + When omitted the current size is returned without modification. + backend (str | torch._C._BlasBackend, optional): which backend's + workspace to query/set. Accepts the same strings as + :func:`preferred_blas_library` (e.g. ``"cublas"``, ``"cublaslt"``). + + Returns: + int: the current (or newly set) workspace size in bytes. + + Raises: + RuntimeError: if the resolved backend is CK (no workspace concept). + """ + if backend is None: + resolved = preferred_blas_library() + elif isinstance(backend, str): + if backend not in _BlasBackends: + raise RuntimeError( + f"Unknown backend string. Choose from: {_BlasBackends_str}." + ) + resolved = _BlasBackends[backend] + elif isinstance(backend, torch._C._BlasBackend): + resolved = backend + else: + raise RuntimeError("Unknown backend type.") + + if resolved == torch._C._BlasBackend.Default: + resolved = torch._C._get_blas_default_backend() + + if resolved == torch._C._BlasBackend.Ck: + raise RuntimeError("CK backend does not use a workspace.") + + if resolved == torch._C._BlasBackend.Cublaslt: + return cublaslt_workspace_size(size) + return cublas_workspace_size(size) + + _ROCmFABackends = { "default": torch._C._ROCmFABackend.Default, "aotriton": torch._C._ROCmFABackend.AOTriton, @@ -385,6 +492,17 @@ def preferred_rocm_fa_library( SDPAParams.__name__ = "SDPAParams" +def is_ck_sdpa_available() -> bool: + r""" + .. warning:: This flag is beta and subject to change. + + Returns whether composable_kernel may be used as the backend for + scaled-dot-product-attention. + """ + # pyrefly: ignore [missing-attribute] + return torch._C._is_ck_sdpa_available() + + def flash_sdp_enabled(): r""" .. warning:: This flag is beta and subject to change. diff --git a/torch/backends/cudnn/__init__.py b/torch/backends/cudnn/__init__.py index e2494aa5ff8ef..c98153dcfa697 100644 --- a/torch/backends/cudnn/__init__.py +++ b/torch/backends/cudnn/__init__.py @@ -158,6 +158,7 @@ def set_flags( _deterministic=None, _allow_tf32=None, _fp32_precision="none", + _depthwise_kernel=None, ): orig_flags = ( torch._C._get_cudnn_enabled(), @@ -166,6 +167,7 @@ def set_flags( torch._C._get_cudnn_deterministic(), torch._C._get_cudnn_allow_tf32(), torch._C._get_fp32_precision_getter("cuda", "all"), + torch._C._get_cudnn_depthwise_kernel(), ) if _enabled is not None: torch._C._set_cudnn_enabled(_enabled) @@ -179,6 +181,8 @@ def set_flags( torch._C._set_cudnn_allow_tf32(_allow_tf32) if _fp32_precision is not None: torch._C._set_fp32_precision_setter("cuda", "all", _fp32_precision) + if _depthwise_kernel is not None: + torch._C._set_cudnn_depthwise_kernel(_depthwise_kernel) return orig_flags @@ -190,6 +194,7 @@ def flags( deterministic=False, allow_tf32=True, fp32_precision="none", + depthwise_kernel="auto", ): with __allow_nonbracketed_mutation(): orig_flags = set_flags( @@ -199,6 +204,7 @@ def flags( deterministic, allow_tf32, fp32_precision, + depthwise_kernel, ) try: yield @@ -235,6 +241,10 @@ class CudnnModule(PropModule): _get_fp32_precision_getter("cuda", "all"), _set_fp32_precision_setter("cuda", "all"), ) + depthwise_kernel = ContextProp( + torch._C._get_cudnn_depthwise_kernel, + torch._C._set_cudnn_depthwise_kernel, + ) # This is the sys.modules replacement trick, see @@ -248,3 +258,4 @@ class CudnnModule(PropModule): allow_tf32: bool fp32_precision: str benchmark_limit: int +depthwise_kernel: str diff --git a/torch/backends/cudnn/rnn.py b/torch/backends/cudnn/rnn.py index a17dfbdef59e8..177b664aa3b0a 100644 --- a/torch/backends/cudnn/rnn.py +++ b/torch/backends/cudnn/rnn.py @@ -32,7 +32,7 @@ def get_cudnn_mode(mode): # pyrefly: ignore [missing-attribute] return int(_cudnn.RNNMode.gru) else: - raise ValueError(f"Unknown mode: {mode}") # noqa: TRY002 + raise ValueError(f"Unknown mode: {mode}") # NB: We don't actually need this class anymore (in fact, we could serialize the @@ -46,8 +46,6 @@ def get(self): return self.inner def __getstate__(self): - # Note: can't return {}, because python2 won't call __setstate__ - # if the value evaluates to False return "" def __setstate__(self, state): diff --git a/torch/backends/python_native/__init__.py b/torch/backends/python_native/__init__.py new file mode 100644 index 0000000000000..9e67571e148a1 --- /dev/null +++ b/torch/backends/python_native/__init__.py @@ -0,0 +1,387 @@ +""" +User-facing API for controlling DSL operation overrides. + +The torch.backends.python_native module provides control over DSL (Domain Specific Language) +operation overrides defined in torch._native. This allows users to selectively enable or disable +high-performance implementations from various DSLs like Triton and CuteDSL. + +The module supports both coarse-grained control (entire DSLs) and fine-grained control +(individual operations or dispatch keys). All control operations support context managers +for temporary state changes. + +Example usage:: + + import torch.backends.python_native as pn + + # DSL-level control + pn.triton.enabled = False # Disable all triton ops + pn.cutedsl.enabled = True # Enable all cutedsl ops + + # Individual operation control + pn.disable_operations("scaled_mm") # Disable specific op across all DSLs + pn.enable_operations("scaled_mm") # Re-enable specific op + + # Context manager support + with pn.triton.disabled(): + result = some_computation() # Triton ops disabled here + + # Query capabilities + print(pn.available_dsls) # ['triton', 'cutedsl'] + print(pn.get_dsl_operations("triton")) # Operations for triton +""" + +import functools +import sys +import types +from contextlib import contextmanager + +from torch.backends import ContextProp, flags_frozen, PropModule + + +@contextmanager +def _preserve_filter_state(): + """Context manager to save and restore registry filter state.""" + filter_state = _get_filter_state() + + # Save original state + original_state = ( + set(filter_state._dsl_names), + set(filter_state._op_symbols), + set(filter_state._dispatch_keys), + ) + + try: + yield filter_state + finally: + # Restore original state + filter_state._dsl_names.clear() + filter_state._op_symbols.clear() + filter_state._dispatch_keys.clear() + + filter_state._dsl_names.update(original_state[0]) + filter_state._op_symbols.update(original_state[1]) + filter_state._dispatch_keys.update(original_state[2]) + + +def _get_dsl_registry(): + """Lazy import to avoid circular imports.""" + from torch._native.dsl_registry import dsl_registry + + return dsl_registry + + +def _get_registry_functions(): + """Lazy import of registry functions.""" + from torch._native.registry import ( + _filter_state, + _graphs, + deregister_op_overrides, + reenable_op_overrides, + ) + + return deregister_op_overrides, reenable_op_overrides, _graphs, _filter_state + + +def _get_filter_state(): + """Direct access to filter state.""" + return _get_registry_functions()[3] + + +def _get_dsl_module(dsl_name: str): + """Get the registered DSL module for direct control. + + Uses the DSL registry to dynamically look up DSL modules instead of + hard-coding the mapping. This makes the function automatically extensible + for new DSLs without code changes. + + Args: + dsl_name (str): Name of the DSL to retrieve. + + Returns: + DSLModuleProtocol: The registered DSL module. + + Raises: + ValueError: If the DSL is not registered. + """ + registry = _get_dsl_registry() + + # Use the public API to get the DSL module + dsl_module = registry.get_dsl_module(dsl_name) + if dsl_module is not None: + return dsl_module + else: + raise ValueError( + f"Unknown DSL: {dsl_name}. Available DSLs: {registry.list_all_dsls()}" + ) + + +class DSLController: + """Controller for a specific DSL.""" + + def __init__(self, dsl_name: str): + self._dsl_name = dsl_name + + @property + def name(self) -> str: + return self._dsl_name + + @property + def available(self) -> bool: + """Check if DSL runtime is available.""" + registry = _get_dsl_registry() + return registry.is_dsl_available(self._dsl_name) + + @property + def version(self): + """Get DSL version.""" + registry = _get_dsl_registry() + return registry.get_dsl_version(self._dsl_name) + + @property + def enabled(self) -> bool: + """Check if DSL is currently enabled.""" + filter_state = _get_filter_state() + return self._dsl_name not in filter_state._dsl_names + + @enabled.setter + def enabled(self, value: bool): + """Enable or disable the DSL.""" + if flags_frozen(): + raise RuntimeError( + f"not allowed to set {self._dsl_name} DSL flags " + "after disable_global_flags; please use flags() context manager instead" + ) + if value: + self.enable() + else: + self.disable() + + def disable(self): + """Disable all operations for this DSL.""" + dsl_module = _get_dsl_module(self._dsl_name) + dsl_module.deregister_op_overrides() + + def enable(self): + """Re-enable all operations for this DSL.""" + reenable_op_overrides = _get_registry_functions()[1] + reenable_op_overrides(enable_dsl_names=self._dsl_name) + + @contextmanager + def disabled(self): + """Context manager to temporarily disable DSL.""" + original_state = self.enabled + try: + self.disable() + yield + finally: + if original_state: + self.enable() + + def __repr__(self): + status = "available" if self.available else "unavailable" + enabled_status = "enabled" if self.enabled else "disabled" + return f"DSLController({self._dsl_name}, {status}, {enabled_status})" + + +class PythonNativeModule(PropModule): + """Main module for python_native DSL control.""" + + def __init__(self, original_module): + super().__init__(original_module, original_module.__name__) + + @property + def available_dsls(self) -> list[str]: + """Get list of available DSLs.""" + registry = _get_dsl_registry() + result = registry.list_available_dsls() + return list(result) if not isinstance(result, list) else result + + @property + def all_dsls(self) -> list[str]: + """Get list of all registered DSLs.""" + registry = _get_dsl_registry() + result = registry.list_all_dsls() + return list(result) if not isinstance(result, list) else result + + def get_dsl_operations(self, dsl_name: str) -> list[str]: + """Get list of operations registered by a specific DSL. + + Args: + dsl_name (str): Name of the DSL to query (e.g., 'triton', 'cutedsl'). + + Returns: + list[str]: Sorted list of operation names registered by the DSL. + + Example:: + + ops = torch.backends.python_native.get_dsl_operations("triton") + print(ops) # ['triton_to_mxfp8_dim0', ...] + """ + from torch._native.registry import get_dsl_operations + + return get_dsl_operations(dsl_name) + + def disable_operations(self, *op_symbols: str): + """Disable specific operations across all DSLs. + + Args: + *op_symbols (str): Names of operations to disable. + + Example:: + + # Disable scaled matrix multiply across all DSLs + torch.backends.python_native.disable_operations("scaled_mm") + + # Disable multiple operations + torch.backends.python_native.disable_operations( + "scaled_mm", "flash_attention" + ) + """ + deregister_op_overrides = _get_registry_functions()[0] + deregister_op_overrides(disable_op_symbols=list(op_symbols)) + + def enable_operations(self, *op_symbols: str): + """Re-enable specific operations across all DSLs. + + Args: + *op_symbols (str): Names of operations to re-enable. + + Example:: + + # Re-enable previously disabled operations + torch.backends.python_native.enable_operations( + "scaled_mm", "flash_attention" + ) + """ + reenable_op_overrides = _get_registry_functions()[1] + reenable_op_overrides(enable_op_symbols=list(op_symbols)) + + def disable_dispatch_keys(self, *dispatch_keys: str): + """Disable operations at specific dispatch keys. + + Args: + *dispatch_keys (str): Dispatch keys to disable (e.g., 'CUDA', 'CPU'). + + Example:: + + # Disable all native operations on CUDA + torch.backends.python_native.disable_dispatch_keys("CUDA") + """ + deregister_op_overrides = _get_registry_functions()[0] + deregister_op_overrides(disable_dispatch_keys=list(dispatch_keys)) + + def enable_dispatch_keys(self, *dispatch_keys: str): + """Re-enable operations at specific dispatch keys. + + Args: + *dispatch_keys (str): Dispatch keys to re-enable (e.g., 'CUDA', 'CPU'). + + Example:: + + # Re-enable native operations on CUDA + torch.backends.python_native.enable_dispatch_keys("CUDA") + """ + reenable_op_overrides = _get_registry_functions()[1] + reenable_op_overrides(enable_dispatch_keys=list(dispatch_keys)) + + @contextmanager + def operations_disabled(self, *op_symbols: str): + """Context manager to temporarily disable operations. + + Args: + *op_symbols (str): Names of operations to temporarily disable. + + Example:: + + with torch.backends.python_native.operations_disabled("scaled_mm"): + # scaled_mm is disabled across all DSLs + result = model(input) + # scaled_mm is automatically re-enabled here + """ + filter_state = _get_filter_state() + previously_disabled_ops = { + op for op in op_symbols if op in filter_state._op_symbols + } + + self.disable_operations(*op_symbols) + try: + yield + finally: + # Only re-enable operations that weren't already disabled + ops_to_reenable = [ + op for op in op_symbols if op not in previously_disabled_ops + ] + if ops_to_reenable: + self.enable_operations(*ops_to_reenable) + + @functools.lru_cache(maxsize=16) # noqa: B019 + def _get_dsl_controller(self, name: str) -> "DSLController": + """Get or create a DSL controller (cached).""" + return DSLController(name) + + def _get_registry_functions(self): + """Expose registry functions for testing.""" + return _get_registry_functions() + + def is_operation_disabled(self, op_symbol: str) -> bool: + """Check if an operation is currently disabled.""" + filter_state = _get_filter_state() + return op_symbol in filter_state._op_symbols + + def is_dsl_disabled(self, dsl_name: str) -> bool: + """Check if a DSL is currently disabled.""" + filter_state = _get_filter_state() + return dsl_name in filter_state._dsl_names + + def __getattr__(self, name: str): + """Dynamic attribute access for DSL controllers.""" + # Skip dunder attributes to avoid triggering DSL registry lookups + # during torch initialization. inspect.getmodule() calls + # hasattr(module, '__file__') which would otherwise cause a circular + # import through _get_dsl_registry() while torch is still loading. + if name.startswith("__") and name.endswith("__"): + raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'") + + if name in self.all_dsls: + return self._get_dsl_controller(name) + + # Expose private functions for testing + if name == "_get_dsl_module": + return _get_dsl_module + if name == "_get_registry_functions": + return self._get_registry_functions + if name == "_get_filter_state": + return _get_filter_state + + raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'") + + def __dir__(self): + """Return available attributes including DSL names.""" + attrs = set(super().__dir__()) + attrs.update( + { + "available_dsls", + "all_dsls", + "get_dsl_operations", + "disable_operations", + "enable_operations", + "disable_dispatch_keys", + "enable_dispatch_keys", + "operations_disabled", + "is_operation_disabled", + "is_dsl_disabled", + } + ) + + # Add DSL names + try: + attrs.update(self.all_dsls) + except Exception: + # If registry not available yet, skip DSL names + pass + + return sorted(attrs) + + +# Replace the current module with our enhanced version +sys.modules[__name__] = PythonNativeModule(sys.modules[__name__]) diff --git a/torch/backends/xeon/run_cpu.py b/torch/backends/xeon/run_cpu.py index 0f9aec7152425..e731e74bebb8b 100644 --- a/torch/backends/xeon/run_cpu.py +++ b/torch/backends/xeon/run_cpu.py @@ -249,8 +249,8 @@ def numa_aware_check(self, core_list): "Numa Aware: cores:%s on different NUMA nodes:%s. To avoid \ this behavior, please use --ncores-per-instance knob to make sure number of cores is divisible by --ncores-per-\ instance. Alternatively, please use --skip-cross-node-cores knob.", - str(core_list), - str(numa_ids), + core_list, + numa_ids, ) if len(numa_ids) == 0: raise RuntimeError( @@ -345,7 +345,7 @@ def set_memory_allocator( find_tc = self.add_lib_preload(lib_type="tcmalloc") if not find_tc: msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge gperftools" to install {{0}}' - logger.warning(msg.format("TCmalloc", "tcmalloc")) # noqa: G001 + logger.warning(msg.format("TCmalloc", "tcmalloc")) else: logger.info("Use TCMalloc memory allocator") @@ -353,7 +353,7 @@ def set_memory_allocator( find_je = self.add_lib_preload(lib_type="jemalloc") if not find_je: msg = f'{self.msg_lib_notfound} you can use "conda install -c conda-forge jemalloc" to install {{0}}' - logger.warning(msg.format("Jemalloc", "jemalloc")) # noqa: G001 + logger.warning(msg.format("Jemalloc", "jemalloc")) else: logger.info("Use JeMalloc memory allocator") self.set_env( @@ -426,7 +426,7 @@ def set_multi_thread_and_allocator( find_iomp = self.add_lib_preload(lib_type="iomp5") if not find_iomp: msg = f'{self.msg_lib_notfound} you can use "conda install mkl" to install {{0}}' - logger.warning(msg.format("iomp", "iomp5")) # noqa: G001 + logger.warning(msg.format("iomp", "iomp5")) else: logger.info("Using Intel OpenMP") if set_kmp_affinity: diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index b7963e6a1d85c..0dc3edddabda6 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import contextlib import io from collections.abc import Callable from typing import Any, TypeVar @@ -20,6 +21,8 @@ "substitute_in_graph", "list_backends", "disable", + "set_default_backend", + "get_default_backend", "set_stance", "set_enable_guard_collectives", "cudagraph_mark_step_begin", @@ -257,6 +260,39 @@ def disable(fn=None, recursive=True, *, reason=None): return torch._dynamo.disable(fn, recursive, reason=reason) +def set_default_backend(backend: str | Callable[..., Any] | None) -> None: + """Set the default backend for ``torch.compile`` when no ``backend`` argument is specified. + + Passing ``None`` resets the default back to ``"inductor"``. + + Args: + backend: A backend name (string), a callable backend, or ``None``. + + Example:: + + >>> torch.compiler.set_default_backend("eager") + >>> torch.compiler.get_default_backend() + 'eager' + >>> torch.compiler.set_default_backend(None) # reset + >>> torch.compiler.get_default_backend() + 'inductor' + """ + from torch._dynamo.backends.registry import set_default_backend + + set_default_backend(backend) + + +def get_default_backend() -> str | Callable[..., Any]: + """Return the current default backend for ``torch.compile``. + + Returns: + The current default backend (string or callable). Initially ``"inductor"``. + """ + from torch._dynamo.backends.registry import get_default_backend + + return get_default_backend() + + def set_stance( stance: str = "default", *, @@ -360,7 +396,7 @@ def set_enable_guard_collectives(enabled: bool): Returns the previous setting of enabled. """ - from torch._C._dynamo.eval_frame import set_guard_complete_hook # noqa: F401 + from torch._C._dynamo.eval_frame import set_guard_complete_hook from torch._dynamo.eval_frame import guard_collectives_hook if enabled: @@ -433,6 +469,7 @@ def wrap_numpy(fn): _is_compiling_flag: bool = False _is_exporting_flag: bool = False +_is_non_strict_tracing_flag: bool = False def is_compiling() -> bool: @@ -457,6 +494,105 @@ def is_compiling() -> bool: return _is_compiling_flag +def _is_non_strict_tracing() -> bool: + """ + Indicates whether we are inside a non-strict make_fx-based tracing session. + """ + return _is_non_strict_tracing_flag + + +@contextlib.contextmanager +def _non_strict_tracing_context(): + """Context manager that sets the non-strict tracing flag.""" + global _is_non_strict_tracing_flag + old = _is_non_strict_tracing_flag + try: + _is_non_strict_tracing_flag = True + yield + finally: + _is_non_strict_tracing_flag = old + + +@contextlib.contextmanager +def _patch_autograd_grad(): + """Patch autograd.grad for non-strict make_fx tracing. + + This patch installs autograd hooks so traced backward nodes preserve + stack trace, seq_nr, and autograd_backward metadata before delegating to + the real torch.autograd.grad. + """ + import functools + + import torch.autograd + from torch._dynamo.utils import warn_once + from torch._functorch._aot_autograd.logging_utils import ( + setup_stacktrace_preservation_hooks_from_tensors, + ) + + warn_once( + "torch.compiler._patch_autograd_grad() is deprecated; " + "use torch.compiler._patch_engine_backward() instead." + ) + # TODO: Remove this helper once Titan no longer depends on it. + + _orig_grad = torch.autograd.grad + + @functools.wraps(_orig_grad) + def _patched_grad(outputs, inputs, *args, **kwargs): + if not _is_non_strict_tracing(): + raise AssertionError( + "_patch_autograd_grad() must be used under " + "_non_strict_tracing_context()" + ) + + setup_stacktrace_preservation_hooks_from_tensors(outputs) + return _orig_grad(outputs, inputs, *args, **kwargs) + + torch.autograd.grad = _patched_grad + try: + yield + finally: + torch.autograd.grad = _orig_grad + + +@contextlib.contextmanager +def _patch_engine_backward(): + """Patch _engine_run_backward for non-strict make_fx tracing. + + This patch installs autograd hooks so traced backward nodes preserve + stack trace, seq_nr, and autograd_backward metadata before delegating to + the real autograd engine entrypoint used by backward(). + """ + import functools + + import torch.autograd + import torch.autograd.graph + from torch._functorch._aot_autograd.logging_utils import ( + setup_stacktrace_preservation_hooks_from_tensors, + ) + + _orig_engine_run_backward = torch.autograd.graph._engine_run_backward + + @functools.wraps(_orig_engine_run_backward) + def _patched_engine_backward(outputs, *args, **kwargs): + if not _is_non_strict_tracing(): + raise AssertionError( + "_patch_engine_backward() must be used under " + "_non_strict_tracing_context()" + ) + + setup_stacktrace_preservation_hooks_from_tensors(outputs) + return _orig_engine_run_backward(outputs, *args, **kwargs) + + torch.autograd.graph._engine_run_backward = _patched_engine_backward + torch.autograd._engine_run_backward = _patched_engine_backward + try: + yield + finally: + torch.autograd.graph._engine_run_backward = _orig_engine_run_backward + torch.autograd._engine_run_backward = _orig_engine_run_backward + + def is_dynamo_compiling() -> bool: """ Indicates whether a graph is traced via TorchDynamo. @@ -658,7 +794,13 @@ def skip_all_guards_unsafe(guard_entries): return [False for entry in guard_entries] -def nested_compile_region(fn=None, options: NestedCompileRegionOptions | None = None): +def nested_compile_region( + fn=None, + *, + options: NestedCompileRegionOptions | None = None, + max_reuse_entries: int = 8, + reuse_hash_fn=None, +): """ Tells **``torch.compile``** that the marked set of operations forms a nested compile region (which is often repeated in the full model) whose code can be @@ -688,6 +830,17 @@ def nested_compile_region(fn=None, options: NestedCompileRegionOptions | None = options: Optional backend to use for compiling the subgraph. Warning: this is an experimental feature under development and not ready for use yet. + max_reuse_entries: Maximum number of reuse cache entries per function + before raising an error. If this limit is hit, guards keep failing + across invocations and hierarchical compilation is not effective. + reuse_hash_fn: Optional callable that takes the same ``*args, **kwargs`` + as the wrapped function and returns an integer hash key. When + provided, Dynamo traces this function to obtain a constant integer + and uses it as the cache key for subgraph reuse, bypassing the + automatic fingerprint/guard machinery. Two calls that produce the + same hash key reuse the same cached subgraph. The hash function + must be fully traceable (no graph breaks) and must return a + constant integer. """ if options is not None: @@ -702,7 +855,12 @@ def nested_compile_region(fn=None, options: NestedCompileRegionOptions | None = mark_compile_region as _mark_compile_region, ) - return _mark_compile_region(fn, options=options) + return _mark_compile_region( + fn, + options=options, + max_reuse_entries=max_reuse_entries, + reuse_hash_fn=reuse_hash_fn, + ) def load_compiled_function( diff --git a/torch/compiler/_cache.py b/torch/compiler/_cache.py index 4d759ce21c992..6252148dae33e 100644 --- a/torch/compiler/_cache.py +++ b/torch/compiler/_cache.py @@ -242,7 +242,7 @@ def record_artifact( artifact = CacheArtifactFactory.encode_create(artifact_type, key, content) if artifact in cls._seen_artifacts: return - log.debug("Recording %s", str(artifact)) + log.debug("Recording %s", artifact) cls._new_cache_artifacts[artifact_type].append(artifact) cls._seen_artifacts.add(artifact) diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index c75643e2fa129..c3caf43e805b4 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -1,5 +1,6 @@ #include #include +#include #include namespace torch::accelerator { @@ -100,6 +101,15 @@ void initModule(PyObject* module) { m.def("_accelerator_emptyCache", []() { at::accelerator::emptyCache(); }); + m.def("_accelerator_emptyHostCache", []() { + const auto device_type = at::accelerator::getAccelerator(true).value(); + if (torch::utils::is_device_lazy_init_supported(device_type) && + !torch::utils::is_device_initialized(device_type)) { + return; + } + at::accelerator::emptyHostCache(); + }); + m.def("_accelerator_getDeviceStats", [](c10::DeviceIndex device_index) { using c10::CachingAllocator::Stat; using c10::CachingAllocator::StatArray; @@ -171,6 +181,65 @@ void initModule(PyObject* module) { m.def("_accelerator_setAllocatorSettings", [](std::string env) { c10::CachingAllocator::setAllocatorSettings(env); }); + + // Accelerator Graph class binding + py::class_>( + m, "_acceleratorGraph") + .def(py::init(), py::arg("keep_graph") = false) + .def( + "capture_begin", + [](at::accelerator::Graph& self, + std::optional pool_opt, + const std::string& capture_error_mode) { + c10::MempoolId_t pool = pool_opt.has_value() + ? pool_opt.value() + : c10::MempoolId_t{0, 0}; + at::GraphCaptureMode capture_mode = at::GraphCaptureMode::Default; + if (capture_error_mode == "default") { + capture_mode = at::GraphCaptureMode::Default; + } else if (capture_error_mode == "global") { + capture_mode = at::GraphCaptureMode::Global; + } else if (capture_error_mode == "thread_local") { + capture_mode = at::GraphCaptureMode::ThreadLocal; + } else if (capture_error_mode == "relaxed") { + capture_mode = at::GraphCaptureMode::Relaxed; + } else { + TORCH_CHECK( + false, + "Unknown capture error mode. Expected `default`, `global`, `thread_local`, or `relaxed`, got ", + capture_error_mode); + } + return self.capture_begin(pool, capture_mode); + }, + py::arg("pool") = std::nullopt, + py::arg("capture_error_mode") = "default", + py::call_guard()) + .def( + "capture_end", + torch::wrap_pybind_function_no_gil( + &at::accelerator::Graph::capture_end)) + .def( + "instantiate", + torch::wrap_pybind_function_no_gil( + &at::accelerator::Graph::instantiate)) + .def( + "replay", + torch::wrap_pybind_function_no_gil(&at::accelerator::Graph::replay)) + .def( + "reset", + torch::wrap_pybind_function_no_gil(&at::accelerator::Graph::reset)) + .def( + "pool", + torch::wrap_pybind_function_no_gil(&at::accelerator::Graph::pool)) + .def( + "enable_debug_mode", + torch::wrap_pybind_function_no_gil( + &::at::accelerator::Graph::enable_debug_mode)) + .def( + "debug_dump", + torch::wrap_pybind_function_no_gil( + &::at::accelerator::Graph::debug_dump), + py::arg("path")); } } // namespace torch::accelerator diff --git a/torch/csrc/Dtype.cpp b/torch/csrc/Dtype.cpp index bff17ca0cbc79..5f1c064586e33 100644 --- a/torch/csrc/Dtype.cpp +++ b/torch/csrc/Dtype.cpp @@ -9,7 +9,6 @@ #include PyObject* THPDtype_New(at::ScalarType scalar_type, const std::string& name) { - HANDLE_TH_ERRORS AT_ASSERT(name.length() < DTYPE_NAME_LEN); auto type = &THPDtypeType; auto self = THPObjectPtr{type->tp_alloc(type, 0)}; @@ -19,7 +18,6 @@ PyObject* THPDtype_New(at::ScalarType scalar_type, const std::string& name) { self_->scalar_type = scalar_type; std::strncpy(self_->name, name.c_str(), DTYPE_NAME_LEN); return self.release(); - END_HANDLE_TH_ERRORS } static PyObject* THPDtype_is_floating_point(THPDtype* self, PyObject* noargs) { diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index 9db1903eec33a..d5d211d5b8fc5 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -113,7 +113,7 @@ std::tuple createStorageGetType( "'"); auto storage = THPStorage_Unpack(untyped_storage_obj); - return std::make_tuple(storage, scalar_type, is_typed_storage); + return std::make_tuple(std::move(storage), scalar_type, is_typed_storage); } at::Storage createStorage(PyObject* obj) { diff --git a/torch/csrc/Event.cpp b/torch/csrc/Event.cpp index ff3e87ef1506d..e3bdc5ee32062 100644 --- a/torch/csrc/Event.cpp +++ b/torch/csrc/Event.cpp @@ -109,7 +109,7 @@ static PyObject* THPEvent_record( TORCH_WARN("Parsing THPEvent_record arg fails"); return nullptr; } - if (_stream != Py_None) { + if (!Py_IsNone(_stream)) { auto stream = reinterpret_cast(_stream); self->event.record(c10::Stream::unpack3( stream->stream_id, @@ -187,7 +187,7 @@ static PyObject* THPEvent_wait( TORCH_WARN("Parsing THPEvent_wait arg fails"); return nullptr; } - if (_stream != Py_None) { + if (!Py_IsNone(_stream)) { auto stream = reinterpret_cast(_stream); self->event.block(c10::Stream::unpack3( stream->stream_id, diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index 058335921209e..160e95c9f5799 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -24,14 +24,16 @@ PyObject* THPGenerator_initDefaultGenerator(const at::Generator& cdata) { auto type = reinterpret_cast(THPGeneratorClass); auto self = THPObjectPtr{type->tp_alloc(type, 0)}; if (!self) - throw python_error(); + throw python_error(); // @allow-raw-throw auto self_ = reinterpret_cast(self.get()); self_->cdata = cdata; + self_->weakreflist = nullptr; return self.release(); } static void THPGenerator_dealloc(PyObject* _self) { auto self = reinterpret_cast(_self); + PyObject_ClearWeakRefs(_self); if (self->cdata.defined()) { self->cdata.set_pyobj(nullptr); self->cdata.~Generator(); @@ -116,7 +118,7 @@ static uint64_t unpack_uint64(PyObject* pyobj) { unsigned_obj = *(reinterpret_cast(&obj)); } else { // If any other type of exception happened, rethrow it - throw; + throw; // @allow-raw-throw } } return unsigned_obj; @@ -235,7 +237,7 @@ static PyObject* THPGenerator_reduce(PyObject* _self, PyObject* noargs) { auto ret = THPObjectPtr{PyTuple_New(3)}; if (!ret) - throw python_error(); + throw python_error(); // @allow-raw-throw py::object torch_module = py::module::import("torch"); py::object torch_generator = torch_module.attr("Generator"); @@ -243,14 +245,14 @@ static PyObject* THPGenerator_reduce(PyObject* _self, PyObject* noargs) { auto args = THPObjectPtr{PyTuple_New(1)}; if (!args) - throw python_error(); + throw python_error(); // @allow-raw-throw PyTuple_SET_ITEM(args.get(), 0, THPGenerator_get_device(self, nullptr)); PyTuple_SET_ITEM(ret.get(), 1, args.release()); auto state = THPObjectPtr{PyTuple_New(3)}; if (!state) - throw python_error(); + throw python_error(); // @allow-raw-throw c10::DeviceType device_type = gen.device().type(); PyTuple_SET_ITEM(state.get(), 0, THPGenerator_initialSeed(_self, nullptr)); @@ -270,7 +272,7 @@ static PyObject* THPGenerator_pickleSetState(PyObject* _self, PyObject* state) { HANDLE_TH_ERRORS THPGenerator_manualSeed(_self, PyTuple_GET_ITEM(state, 0)); auto& offset = PyTuple_GET_ITEM(state, 1); - if (offset != Py_None) { + if (!Py_IsNone(offset)) { THPGenerator_setOffset(_self, offset); } THPGenerator_setState(_self, PyTuple_GET_ITEM(state, 2)); @@ -337,7 +339,7 @@ static PyTypeObject THPGeneratorType = { nullptr, /* tp_traverse */ nullptr, /* tp_clear */ nullptr, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ + offsetof(THPGenerator, weakreflist), /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ THPGenerator_methods, /* tp_methods */ @@ -354,9 +356,28 @@ static PyTypeObject THPGeneratorType = { }; bool THPGenerator_init(PyObject* module) { + // Set OpaqueBaseMeta as the metaclass so Generator can be registered as an + // opaque type for FX tracing (same pattern as ProcessGroup). + auto opaque_module = + THPObjectPtr(PyImport_ImportModule("torch._opaque_base")); + TORCH_CHECK(opaque_module, "Failed to import torch._opaque_base"); + auto opaque_base_meta = + THPObjectPtr(PyObject_GetAttrString(opaque_module, "OpaqueBaseMeta")); + TORCH_CHECK(opaque_base_meta, "Failed to get OpaqueBaseMeta"); + Py_SET_TYPE(&THPGeneratorType, (PyTypeObject*)opaque_base_meta.release()); + THPGeneratorClass = reinterpret_cast(&THPGeneratorType); if (PyType_Ready(&THPGeneratorType) < 0) return false; + // PyType_Ready inherits __module__ from the metaclass (OpaqueBaseMeta lives + // in torch._opaque_base). Override it so pickle can find the class at + // torch._C.Generator. + auto module_name = THPObjectPtr(PyUnicode_FromString("torch._C")); + if (!module_name) + return false; + if (PyDict_SetItemString( + THPGeneratorType.tp_dict, "__module__", module_name) < 0) + return false; Py_INCREF(&THPGeneratorType); PyModule_AddObject( module, "Generator", reinterpret_cast(&THPGeneratorType)); diff --git a/torch/csrc/Generator.h b/torch/csrc/Generator.h index b5f72cb47b762..f0f2e5c2cbb04 100644 --- a/torch/csrc/Generator.h +++ b/torch/csrc/Generator.h @@ -8,6 +8,7 @@ struct THPGenerator { PyObject_HEAD at::Generator cdata; + PyObject* weakreflist; }; // Creates a new Python object wrapping the default at::Generator. The reference diff --git a/torch/csrc/Layout.cpp b/torch/csrc/Layout.cpp index af7dfc74379de..8e1d2935c6c82 100644 --- a/torch/csrc/Layout.cpp +++ b/torch/csrc/Layout.cpp @@ -65,12 +65,7 @@ PyTypeObject THPLayoutType = { }; void THPLayout_init(PyObject* module) { - if (PyType_Ready(&THPLayoutType) < 0) { - throw python_error(); - } - Py_INCREF(&THPLayoutType); - if (PyModule_AddObject( - module, "layout", reinterpret_cast(&THPLayoutType)) != 0) { + if (PyModule_AddType(module, &THPLayoutType) < 0) { throw python_error(); } } diff --git a/torch/csrc/MemoryFormat.cpp b/torch/csrc/MemoryFormat.cpp index 0a8e212500cf1..3aadc4d5ac9e7 100644 --- a/torch/csrc/MemoryFormat.cpp +++ b/torch/csrc/MemoryFormat.cpp @@ -80,14 +80,7 @@ PyTypeObject THPMemoryFormatType = { }; void THPMemoryFormat_init(PyObject* module) { - if (PyType_Ready(&THPMemoryFormatType) < 0) { - throw python_error(); - } - Py_INCREF(&THPMemoryFormatType); - if (PyModule_AddObject( - module, - "memory_format", - reinterpret_cast(&THPMemoryFormatType)) != 0) { + if (PyModule_AddType(module, &THPMemoryFormatType) < 0) { throw python_error(); } } diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 56bce2ba966ed..28f61755eecca 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -19,7 +20,6 @@ #include #include -#include #include #include #include @@ -241,7 +241,7 @@ static PyObject* THPModule_initExtension( auto module = THPObjectPtr(PyImport_ImportModule("torch")); if (!module) - throw python_error(); + throw python_error(); // @allow-raw-throw THPStorage_postInit(module); THPAutograd_initFunctions(); @@ -473,7 +473,7 @@ static PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) { m->d_getset->name); } m->d_getset->doc = doc_str; - } else if (Py_TYPE(obj) == &PyType_Type) { + } else if (PyType_Check(obj)) { PyTypeObject* t = reinterpret_cast(obj); if (t->tp_doc) { return PyErr_Format( @@ -516,7 +516,7 @@ static PyObject* THPModule_setBackcompatBroadcastWarn( "set_backcompat_broadcast_warn expects a bool, " "but got ", THPUtils_typename(arg)); - setBackCompatBroadcastWarn(arg == Py_True); + setBackCompatBroadcastWarn(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -539,7 +539,7 @@ static PyObject* THPModule_setBackcompatKeepdimWarn( "set_backcompat_keepdim_warn expects a bool, " "but got ", THPUtils_typename(arg)); - setBackCompatKeepdimWarn(arg == Py_True); + setBackCompatKeepdimWarn(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -682,7 +682,7 @@ static PyObject* THPModule_torchDeviceToDLDevice( auto tuple = PyTuple_New(2); if (!tuple) { - throw python_error(); + throw python_error(); // @allow-raw-throw } PyTuple_SET_ITEM(tuple, 0, THPUtils_packInt64(dl_device.device_type)); @@ -852,7 +852,7 @@ static PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) { "set_allow_tf32_cublas expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setAllowTF32CuDNN(arg == Py_True); + at::globalContext().setAllowTF32CuDNN(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -925,7 +925,7 @@ static PyObject* THPModule_setSDPUseFlash(PyObject* _unused, PyObject* arg) { "set_sdp_use_math expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setSDPUseFlash(arg == Py_True); + at::globalContext().setSDPUseFlash(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -944,7 +944,7 @@ static PyObject* THPModule_setSDPUseFA3(PyObject* _unused, PyObject* arg) { "set_sdp_use_fa3 expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setSDPUseFA3(arg == Py_True); + at::globalContext().setSDPUseFA3(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -965,7 +965,7 @@ static PyObject* THPModule_setSDPUseMemEfficient( "set_sdp_use_math expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setSDPUseMemEfficient(arg == Py_True); + at::globalContext().setSDPUseMemEfficient(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -984,7 +984,7 @@ static PyObject* THPModule_setSDPUseMath(PyObject* _unused, PyObject* arg) { "set_sdp_use_math expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setSDPUseMath(arg == Py_True); + at::globalContext().setSDPUseMath(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1005,7 +1005,7 @@ static PyObject* THPModule_setAllowFP16BF16ReductionMathSDP( "set_sdp_use_math expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setAllowFP16BF16ReductionMathSDP(arg == Py_True); + at::globalContext().setAllowFP16BF16ReductionMathSDP(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1026,7 +1026,7 @@ static PyObject* THPModule_setSDPUseOverrideable( "set_sdp_use_overrideable expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setSDPUseOverrideable(arg == Py_True); + at::globalContext().setSDPUseOverrideable(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1045,7 +1045,7 @@ static PyObject* THPModule_setSDPUseCuDNN(PyObject* _unused, PyObject* arg) { "set_sdp_use_cudnn expects a bool, " "but got %s", THPUtils_typename(arg)); - at::globalContext().setSDPUseCuDNN(arg == Py_True); + at::globalContext().setSDPUseCuDNN(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1067,7 +1067,7 @@ static PyObject* THPModule_setUserEnabledCuDNN( "set_enabled_cudnn expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setUserEnabledCuDNN(arg == Py_True); + at::globalContext().setUserEnabledCuDNN(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1090,7 +1090,7 @@ static PyObject* THPModule_setUserEnabledMkldnn( "set_enabled_mkldnn expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setUserEnabledMkldnn(arg == Py_True); + at::globalContext().setUserEnabledMkldnn(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1113,7 +1113,7 @@ static PyObject* THPModule_setDeterministicCuDNN( "set_deterministic_cudnn expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setDeterministicCuDNN(arg == Py_True); + at::globalContext().setDeterministicCuDNN(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1136,7 +1136,7 @@ static PyObject* THPModule_setDeterministicMkldnn( "set_deterministic_mkldnn expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setDeterministicMkldnn(arg == Py_True); + at::globalContext().setDeterministicMkldnn(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1175,7 +1175,7 @@ static PyObject* THPModule_setAllowTF32OneDNN( "_set_onednn_allow_tf32 expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setAllowTF32OneDNN(arg == Py_True); + at::globalContext().setAllowTF32OneDNN(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1217,7 +1217,7 @@ static PyObject* THPModule_setDeterministicFillUninitializedMemory( HANDLE_TH_ERRORS TORCH_CHECK( PyBool_Check(arg), "expected a bool, but got ", THPUtils_typename(arg)); - at::globalContext().setDeterministicFillUninitializedMemory(arg == Py_True); + at::globalContext().setDeterministicFillUninitializedMemory(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1240,7 +1240,7 @@ static PyObject* THPModule_setUserEnabledNNPACK( "set_enabled_NNPACK expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setUserEnabledNNPACK(arg == Py_True); + at::globalContext().setUserEnabledNNPACK(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1261,7 +1261,7 @@ static PyObject* THPModule_setWarnAlways(PyObject* _unused, PyObject* arg) { "setWarnOnlyOnce expects a bool, " "but got ", THPUtils_typename(arg)); - c10::WarningUtils::set_warnAlways(arg == Py_True); + c10::WarningUtils::set_warnAlways(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1298,7 +1298,7 @@ static PyObject* THPModule_setBenchmarkCuDNN(PyObject* _unused, PyObject* arg) { "set_benchmark_cudnn expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setBenchmarkCuDNN(arg == Py_True); + at::globalContext().setBenchmarkCuDNN(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1310,6 +1310,31 @@ static PyObject* THPModule_benchmarkCuDNN(PyObject* _unused, PyObject* noargs) { Py_RETURN_FALSE; } +static PyObject* THPModule_setCuDNNDepthwiseKernel( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkString(arg), + "set_cudnn_depthwise_kernel expects a string, " + "but got ", + THPUtils_typename(arg)); + std::string mode = THPUtils_unpackString(arg); + at::globalContext().setCuDNNDepthwiseKernel(at::str2cudnn_depthwise(mode)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* THPModule_getCuDNNDepthwiseKernel( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + auto mode = + at::cudnn_depthwise2str(at::globalContext().cudnnDepthwiseKernel()); + return THPUtils_packString(mode); + END_HANDLE_TH_ERRORS +} + static PyObject* THPModule_setImmediateMiopen( PyObject* _unused, PyObject* arg) { @@ -1319,7 +1344,7 @@ static PyObject* THPModule_setImmediateMiopen( "set_immediate_miopen expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setImmediateMiopen(arg == Py_True); + at::globalContext().setImmediateMiopen(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1342,7 +1367,7 @@ static PyObject* THPModule_setAllowTF32CuBLAS( "set_allow_tf32_cublas expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setAllowTF32CuBLAS(arg == Py_True); + at::globalContext().setAllowTF32CuBLAS(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1372,15 +1397,15 @@ static PyObject* THPModule_setAllowFP16ReductionCuBLAS( "set_allow_fp16_reduction_cublas expects a bool for allow_reduced_precision, " "but got ", THPUtils_typename(allow_reduction_obj)); - bool allow_reduction = allow_reduction_obj == Py_True; + bool allow_reduction = Py_IsTrue(allow_reduction_obj); bool allow_splitk = true; - if (allow_splitk_obj != Py_None) { + if (!Py_IsNone(allow_splitk_obj)) { TORCH_CHECK( PyBool_Check(allow_splitk_obj), "set_allow_fp16_reduction_cublas expects a bool for allow_splitk, " "but got ", THPUtils_typename(allow_splitk_obj)); - allow_splitk = allow_splitk_obj == Py_True; + allow_splitk = Py_IsTrue(allow_splitk_obj); } at::globalContext().setAllowFP16ReductionCuBLAS( allow_reduction, allow_splitk); @@ -1416,15 +1441,15 @@ static PyObject* THPModule_setAllowBF16ReductionCuBLAS( "set_allow_bf16_reduction_cublas expects a bool for allow_reduced_precision, " "but got ", THPUtils_typename(allow_reduction_obj)); - bool allow_reduction = allow_reduction_obj == Py_True; + bool allow_reduction = Py_IsTrue(allow_reduction_obj); bool allow_splitk = true; - if (allow_splitk_obj != Py_None) { + if (!Py_IsNone(allow_splitk_obj)) { TORCH_CHECK( PyBool_Check(allow_splitk_obj), "set_allow_bf16_reduction_cublas expects a bool for allow_splitk, " "but got ", THPUtils_typename(allow_splitk_obj)); - allow_splitk = allow_splitk_obj == Py_True; + allow_splitk = Py_IsTrue(allow_splitk_obj); } at::globalContext().setAllowBF16ReductionCuBLAS( allow_reduction, allow_splitk); @@ -1455,7 +1480,7 @@ static PyObject* THPModule_setAllowFP16AccumulationCuBLAS( "set_allow_fp16_accumulation_cublas expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setAllowFP16AccumulationCuBLAS(arg == Py_True); + at::globalContext().setAllowFP16AccumulationCuBLAS(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1478,7 +1503,7 @@ static PyObject* THPModule_setAllowFP16ReductionCPU( "set_allow_fp16_reduction_cpu expects a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setAllowFP16ReductionCPU(arg == Py_True); + at::globalContext().setAllowFP16ReductionCPU(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1499,7 +1524,7 @@ static PyObject* THPModule_setFlushDenormal(PyObject* _unused, PyObject* arg) { "flush_denormal expects a bool, " "but got ", THPUtils_typename(arg)); - if (!at::globalContext().setFlushDenormal(arg == Py_True)) { + if (!at::globalContext().setFlushDenormal(Py_IsTrue(arg))) { Py_RETURN_FALSE; }; Py_RETURN_TRUE; @@ -1570,7 +1595,7 @@ static PyObject* THPModule_setCheckSparseTensorInvariants( PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - if (arg == Py_None) { + if (Py_IsNone(arg)) { at::globalContext().setCheckSparseTensorInvariants(std::nullopt); } else { TORCH_CHECK( @@ -1578,7 +1603,7 @@ static PyObject* THPModule_setCheckSparseTensorInvariants( "set_check_sparse_tensor_invariants expects a bool or None, " "but got ", THPUtils_typename(arg)); - at::globalContext().setCheckSparseTensorInvariants(arg == Py_True); + at::globalContext().setCheckSparseTensorInvariants(Py_IsTrue(arg)); } Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -1727,7 +1752,7 @@ static PyObject* THPModule_set_display_vmap_fallback_warnings_mode( "enabled must be a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setDisplayVmapFallbackWarnings(arg == Py_True); + at::globalContext().setDisplayVmapFallbackWarnings(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1753,7 +1778,7 @@ static PyObject* THPModule_set_warn_on_accumulate_grad_stream_mismatch( "enabled must be a bool, " "but got ", THPUtils_typename(arg)); - at::globalContext().setWarnOnAccumulateGradStreamMismatch(arg == Py_True); + at::globalContext().setWarnOnAccumulateGradStreamMismatch(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1779,7 +1804,56 @@ static PyObject* THCPModule_ensureCUDADeviceGuardSet( END_HANDLE_TH_ERRORS } +struct TorchModuleState { + PyObject* log_api_usage_seen; // dict used by _log_api_usage_once +}; + +static int torchmodule_traverse(PyObject* mod, visitproc visit, void* arg) { + auto* state = static_cast(PyModule_GetState(mod)); + Py_VISIT(state->log_api_usage_seen); + return 0; +} + +static int torchmodule_clear(PyObject* mod) { + auto* state = static_cast(PyModule_GetState(mod)); + Py_CLEAR(state->log_api_usage_seen); + return 0; +} + +static void torchmodule_free(void* mod) { + torchmodule_clear((PyObject*)mod); +} + +// Thread-safe in free-threaded Python. +static PyObject* LogAPIUsageOnceFromPython(PyObject* self, PyObject* event) { + auto* state = static_cast(PyModule_GetState(self)); + PyObject* api_usage_seen = state->log_api_usage_seen; + + int found = PyDict_Contains(api_usage_seen, event); + if (found < 0) { + return nullptr; + } else if (found != 0) { + Py_RETURN_NONE; + } + + const char* event_str = PyUnicode_AsUTF8(event); + if (!event_str) { + return nullptr; + } + + // Returns 0 if we inserted, 1 if already present, -1 on error. + int rc = PyDict_SetDefaultRef(api_usage_seen, event, Py_None, nullptr); + if (rc < 0) { + return nullptr; + } + if (rc == 0) { + c10::LogAPIUsage(event_str); + } + Py_RETURN_NONE; +} + static std::initializer_list TorchMethods = { + {"_log_api_usage_once", LogAPIUsageOnceFromPython, METH_O, nullptr}, {"_initExtension", THPModule_initExtension, METH_O, nullptr}, {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr}, {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr}, @@ -1887,6 +1961,14 @@ static std::initializer_list TorchMethods = { {"_set_onednn_allow_tf32", THPModule_setAllowTF32OneDNN, METH_O, nullptr}, {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr}, {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr}, + {"_get_cudnn_depthwise_kernel", + THPModule_getCuDNNDepthwiseKernel, + METH_NOARGS, + nullptr}, + {"_set_cudnn_depthwise_kernel", + THPModule_setCuDNNDepthwiseKernel, + METH_O, + nullptr}, {"_get_miopen_immediate", THPModule_immediateMiopen, METH_NOARGS, nullptr}, {"_set_miopen_immediate", THPModule_setImmediateMiopen, METH_O, nullptr}, {"_get_cudnn_deterministic", @@ -2121,16 +2203,6 @@ void initModule(PyObject* module); static std::vector methods; -// In Python we can't use the trick of C10_LOG_API_USAGE_ONCE -// Guaranteed to be invoked from Python under GIL, no locking on map needed -static void LogAPIUsageOnceFromPython(const std::string& event) { - static std::unordered_set seen; - if (!seen.count(event)) { - seen.insert(event); - c10::LogAPIUsage(event); - } -} - static void LogAPIUsageMetadataFromPython( const std::string& event, const std::map& metadata_map) { @@ -2249,9 +2321,22 @@ PyObject* initModule() { #endif static struct PyModuleDef torchmodule = { - PyModuleDef_HEAD_INIT, "torch._C", nullptr, -1, methods.data()}; + PyModuleDef_HEAD_INIT, + "torch._C", + nullptr, + sizeof(TorchModuleState), + methods.data(), + nullptr, // m_slots + torchmodule_traverse, + torchmodule_clear, + torchmodule_free}; module = PyModule_Create(&torchmodule); ASSERT_TRUE(module); + + auto* mod_state = static_cast(PyModule_GetState(module)); + mod_state->log_api_usage_seen = PyDict_New(); + ASSERT_TRUE(mod_state->log_api_usage_seen); + #ifdef Py_GIL_DISABLED PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED); #endif @@ -2367,7 +2452,7 @@ PyObject* initModule() { #endif ASSERT_TRUE(set_module_attr("_has_cudnn", has_cudnn)); -#if defined(USE_CUSPARSELT) +#if defined(USE_CUSPARSELT) || defined(USE_ROCM) PyObject* has_cusparselt = Py_True; #else PyObject* has_cusparselt = Py_False; @@ -2399,20 +2484,8 @@ PyObject* initModule() { auto py_module = py::reinterpret_borrow(module); py_module.def("_initCrashHandler", &_initCrashHandler); py_module.def("_demangle", &c10::demangle); - py_module.def("_log_api_usage_once", &LogAPIUsageOnceFromPython); py_module.def("_log_api_usage_metadata", &LogAPIUsageMetadataFromPython); - py_module.def("vitals_enabled", &at::vitals::torchVitalEnabled); - py_module.def( - "set_vital", - [](const std::string& vital, - const std::string& attr, - const std::string& value) { - return at::vitals::VitalsAPI.setVital(vital, attr, value); - }); - py_module.def( - "read_vitals", []() { return at::vitals::VitalsAPI.readVitals(); }); - py_module.def( "init_num_threads", torch::wrap_pybind_function(at::init_num_threads), @@ -2693,6 +2766,9 @@ Call this whenever a new thread is created in order to propagate values from py_module.def("_get_blas_preferred_backend", []() { return at::globalContext().blasPreferredBackend(); }); + py_module.def("_get_blas_default_backend", []() { + return at::globalContext().blasDefaultBackend(); + }); py::enum_( py_module, "_ScalingType", "Supported Tensor scaling types") @@ -2739,6 +2815,14 @@ Call this whenever a new thread is created in order to propagate values from return at::globalContext().getROCmFAPreferredBackend(); }); + py_module.def("_is_ck_sdpa_available", []() { +#ifdef USE_ROCM + return at::globalContext().ckSupported() && at::globalContext().hasCKSDPA(); +#else + return false; +#endif + }); + py_module.def( "_set_sm_carveout_experimental", [](std::optional val) { at::globalContext()._setSMCarveout_EXPERIMENTAL(val); @@ -2778,6 +2862,7 @@ Call this whenever a new thread is created in order to propagate values from py_module.def( "_stash_obj_in_tls", [](const std::string& key, py::handle arg) { + Py_INCREF(arg.ptr()); at::impl::ThreadLocalPythonObjects::get_state().set( key, std::make_shared(arg.ptr(), getPyInterpreter())); diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp index 7f36d88bdaa32..ee0b310efcd31 100644 --- a/torch/csrc/PyInterpreter.cpp +++ b/torch/csrc/PyInterpreter.cpp @@ -967,7 +967,7 @@ PyInterpreterHolder self_interpreter; } // anonymous namespace py::handle getTorchApiFunction(const c10::OperatorHandle& op) { - return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* { + return op.getPythonOp([&]() -> PyObject* { // Parse the name into namespace and name (no overload_name) // TODO: put this into the library const auto& schema = op.schema(); diff --git a/torch/csrc/QScheme.cpp b/torch/csrc/QScheme.cpp index e178ec9247ea5..d9d0fbcb41787 100644 --- a/torch/csrc/QScheme.cpp +++ b/torch/csrc/QScheme.cpp @@ -77,13 +77,7 @@ PyTypeObject THPQSchemeType = { }; void THPQScheme_init(PyObject* module) { - if (PyType_Ready(&THPQSchemeType) < 0) { - throw python_error(); - } - Py_INCREF(&THPQSchemeType); - if (PyModule_AddObject( - module, "qscheme", reinterpret_cast(&THPQSchemeType)) != - 0) { + if (PyModule_AddType(module, &THPQSchemeType) < 0) { throw python_error(); } } diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index 58e37d2c236b0..887eb32ba96cc 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -39,11 +39,6 @@ PyTypeObject* THPStorageClass = nullptr; static PyObject* THPStorage_New(PyTypeObject* type, c10::Storage _storage) { PyObject* obj = type->tp_alloc(type, 0); TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object"); - - // Ensure that PyUnstable_TryIncref calls don't fail spuriously in - // free-threaded Python. - PyUnstable_EnableTryIncRef(obj); - auto s = (THPStorage*)obj; new (&s->cdata) c10::Storage(std::move(_storage)); return obj; @@ -60,8 +55,7 @@ PyObject* THPStorage_NewWithStorage(PyTypeObject* type, c10::Storage _storage) { c10::StorageImpl* storage_impl = _storage.unsafeGetStorageImpl(); PyObject* obj = THPStorage_New(type, std::move(_storage)); - PyObjectPreservation::init_fresh_nonatomic( - storage_impl, storage_impl->pyobj_slot(), obj); + PyObjectPreservation::init_fresh_nonatomic(*storage_impl, obj); return obj; } @@ -73,22 +67,9 @@ PyObject* THPStorage_Wrap(c10::Storage storage) { } c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); - c10::impl::PyObjectSlot* pyobj_slot = storage_impl->pyobj_slot(); - - PyObject* obj = pyobj_slot->load_pyobj(); - if (obj) { - return Py_NewRef(obj); - } - - obj = THPStorage_New(THPStorageClass, std::move(storage)); - PyObject* wrapper = - PyObjectPreservation::init_once(storage_impl, pyobj_slot, obj); - if (wrapper != obj) { - // Another thread beat us to it - Py_DECREF(obj); - return Py_NewRef(wrapper); - } - return obj; + return PyObjectPreservation::get_or_init(*storage_impl, [&]() { + return THPStorage_New(THPStorageClass, std::move(storage)); + }); } static void THPStorage_dealloc(PyObject* self) { @@ -254,7 +235,6 @@ static PyObject* THPStorage_pynew( "but one of the items was of type ", THPUtils_typename(item.get()), " instead of int"); - return nullptr; } } return self; @@ -302,7 +282,6 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) { step, ", but only a step of " "1 is supported"); - return nullptr; } const auto& storage = THPStorage_Unpack(self); @@ -351,7 +330,6 @@ static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) { "can only set storage content with a int types, but got ", THPUtils_typename(value), " instead"); - return -1; } uint8_t rvalue = THPByteUtils_unpackReal(value); @@ -374,7 +352,6 @@ static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) { step, ", but only a step of " "1 is supported"); - return 0; } // TODO: check the bounds only once // TODO: fill? @@ -384,7 +361,6 @@ static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) { } TORCH_CHECK( false, "can't index a " THPStorageStr " with ", THPUtils_typename(index)); - return -1; END_HANDLE_TH_ERRORS_RET(-1) } diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp index edabf0efabf2b..813e3703da35f 100644 --- a/torch/csrc/StorageMethods.cpp +++ b/torch/csrc/StorageMethods.cpp @@ -405,12 +405,13 @@ static PyObject* THPStorage_writeFile(PyObject* self, PyObject* args) { TORCH_CHECK( !invalid, "Attempted to call _write_file() on an invalid python storage.") PyObject* file = PyTuple_GetItem(args, 0); - bool is_real_file = PyTuple_GetItem(args, 1) == Py_True; - bool save_size = PyTuple_GetItem(args, 2) == Py_True; + bool is_real_file = Py_IsTrue(PyTuple_GetItem(args, 1)); + bool save_size = Py_IsTrue(PyTuple_GetItem(args, 2)); PyObject* element_size_obj = PyTuple_GET_ITEM(args, 3); TORCH_CHECK( - element_size_obj != Py_None, "_write_file: need to specify element size"); + !Py_IsNone(element_size_obj), + "_write_file: need to specify element size"); uint64_t element_size = THPUtils_unpackUInt64(element_size_obj); if (!is_real_file) { @@ -441,7 +442,7 @@ static PyObject* THPStorage_newWithFile(PyObject* _unused, PyObject* args) { "descriptor from given object"); PyObject* element_size_obj = PyTuple_GET_ITEM(args, 1); TORCH_CHECK( - element_size_obj != Py_None, + !Py_IsNone(element_size_obj), "_new_with_file: need to specify element size"); uint64_t element_size = THPUtils_unpackUInt64(element_size_obj); @@ -458,12 +459,12 @@ static PyObject* THPStorage_setFromFile(PyObject* self, PyObject* args) { const auto& storage = THPStorage_Unpack(self); PyObject* file = PyTuple_GET_ITEM(args, 0); PyObject* offset = PyTuple_GET_ITEM(args, 1); - bool is_real_file = PyTuple_GET_ITEM(args, 2) == Py_True; + bool is_real_file = Py_IsTrue(PyTuple_GET_ITEM(args, 2)); PyObject* element_size_obj = PyTuple_GET_ITEM(args, 3); TORCH_CHECK( - element_size_obj != Py_None, + !Py_IsNone(element_size_obj), "_set_from_file: need to specify element size"); uint64_t element_size = THPUtils_unpackUInt64(element_size_obj); @@ -471,7 +472,7 @@ static PyObject* THPStorage_setFromFile(PyObject* self, PyObject* args) { // offset can be implemented with a call to the Python object's seek() // but it is currently unnecessary to support this. TORCH_CHECK( - offset == Py_None, + Py_IsNone(offset), "_set_from_file: offset is NYI for filelike objects"); auto self_storage_impl = c10::intrusive_ptr::reclaim_copy( @@ -488,7 +489,7 @@ static PyObject* THPStorage_setFromFile(PyObject* self, PyObject* args) { // file is backed by a fd const int fd = PyObject_AsFileDescriptor(file); const auto fd_original_pos = LSEEK(fd, 0, SEEK_CUR); - if (offset != Py_None) { + if (!Py_IsNone(offset)) { LSEEK(fd, THPUtils_unpackLong(offset), SEEK_SET); } TORCH_CHECK( diff --git a/torch/csrc/StorageSharing.cpp b/torch/csrc/StorageSharing.cpp index a97b6ba9f0199..6853e386d8e86 100644 --- a/torch/csrc/StorageSharing.cpp +++ b/torch/csrc/StorageSharing.cpp @@ -588,7 +588,7 @@ static PyObject* THPStorage_newWithWeakPtr(PyObject* _unused, PyObject* arg) { static PyObject* THPStorage_freeWeakRef(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - if (arg == Py_None) { + if (Py_IsNone(arg)) { Py_RETURN_NONE; } TORCH_CHECK( diff --git a/torch/csrc/Stream.cpp b/torch/csrc/Stream.cpp index c5a8e343e6b27..c6d6ee5a2723d 100644 --- a/torch/csrc/Stream.cpp +++ b/torch/csrc/Stream.cpp @@ -169,6 +169,19 @@ static PyObject* THPStream_synchronize(PyObject* _self, PyObject* noargs) { END_HANDLE_TH_ERRORS } +static PyObject* THPStream_is_capturing(PyObject* _self, PyObject* noargs) { + HANDLE_TH_ERRORS + auto self = reinterpret_cast(_self); + + return PyBool_FromLong(c10::Stream::unpack3( + self->stream_id, + static_cast(self->device_index), + static_cast(self->device_type)) + .is_capturing()); + + END_HANDLE_TH_ERRORS +} + static PyObject* THPStream_wait_event(PyObject* _self, PyObject* _event) { HANDLE_TH_ERRORS { auto self = reinterpret_cast(_self); @@ -224,7 +237,7 @@ static PyObject* THPStream_record_event( &_event)) { TORCH_CHECK(false, "parse record_event arg fails"); } - if (_event != Py_None) { + if (!Py_IsNone(_event)) { // We expect it to be an explicit torch.Event instance. TORCH_CHECK( Py_TYPE(_event) == THPEventClass, @@ -366,7 +379,7 @@ static PyObject* THPStream_exit(PyObject* _self, PyObject* unused) { PyObject* top = PyList_GET_ITEM(self->context, stack_size - 1); // Sentinel: this __enter__ was a no-op, nothing to restore. - if (top == Py_None) { + if (Py_IsNone(top)) { if (PyList_SetSlice(self->context, stack_size - 1, stack_size, nullptr) < 0) { throw python_error(); @@ -421,7 +434,7 @@ static PyObject* THPStream_richcompare( PyObject* other, int op) { PyObject* result = nullptr; - if (other == Py_None) { + if (Py_IsNone(other)) { result = Py_False; } else { switch (op) { @@ -478,6 +491,7 @@ static const std::initializer_list THPStream_properties = { static const std::initializer_list THPStream_methods = { {"query", THPStream_query, METH_NOARGS, nullptr}, {"synchronize", THPStream_synchronize, METH_NOARGS, nullptr}, + {"is_capturing", THPStream_is_capturing, METH_NOARGS, nullptr}, {"wait_event", THPStream_wait_event, METH_O, nullptr}, {"wait_stream", THPStream_wait_stream, METH_O, nullptr}, {"record_event", diff --git a/torch/csrc/TypeInfo.cpp b/torch/csrc/TypeInfo.cpp index 355202c7e40f9..8df8702a851fb 100644 --- a/torch/csrc/TypeInfo.cpp +++ b/torch/csrc/TypeInfo.cpp @@ -399,20 +399,10 @@ PyTypeObject THPIInfoType = { }; void THPDTypeInfo_init(PyObject* module) { - if (PyType_Ready(&THPFInfoType) < 0) { + if (PyModule_AddType(module, &THPFInfoType) < 0) { throw python_error(); } - Py_INCREF(&THPFInfoType); - if (PyModule_AddObject( - module, "finfo", reinterpret_cast(&THPFInfoType)) != 0) { - throw python_error(); - } - if (PyType_Ready(&THPIInfoType) < 0) { - throw python_error(); - } - Py_INCREF(&THPIInfoType); - if (PyModule_AddObject( - module, "iinfo", reinterpret_cast(&THPIInfoType)) != 0) { + if (PyModule_AddType(module, &THPIInfoType) < 0) { throw python_error(); } } diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h index 49de1c8af63f3..350ec4f20d2b1 100644 --- a/torch/csrc/api/include/torch/nn/functional/activation.h +++ b/torch/csrc/api/include/torch/nn/functional/activation.h @@ -918,9 +918,10 @@ inline std::tuple multi_head_attention_forward( // average attention weights over heads attn_output_weights = attn_output_weights.sum(/*dim=*/1) / num_heads; } - return std::make_tuple(attn_output, attn_output_weights); + return std::make_tuple( + std::move(attn_output), std::move(attn_output_weights)); } else { - return std::make_tuple(attn_output, Tensor()); + return std::make_tuple(std::move(attn_output), Tensor()); } } } // namespace detail diff --git a/torch/csrc/api/include/torch/nn/module.h b/torch/csrc/api/include/torch/nn/module.h index 4eff40199ff43..66ec084f1f429 100644 --- a/torch/csrc/api/include/torch/nn/module.h +++ b/torch/csrc/api/include/torch/nn/module.h @@ -122,10 +122,12 @@ class TORCH_API Module : public std::enable_shared_from_this { /// /// \rst /// .. code-block:: cpp - /// MyModule module; - /// module->apply([](nn::Module& module) { - /// std::cout << module.name() << std::endl; - /// }); + /// + /// MyModule module; + /// module->apply([](nn::Module& module) { + /// std::cout << module.name() << std::endl; + /// }); + /// /// \endrst void apply(const ModuleApplyFunction& function); @@ -134,10 +136,12 @@ class TORCH_API Module : public std::enable_shared_from_this { /// /// \rst /// .. code-block:: cpp - /// MyModule module; - /// module->apply([](const nn::Module& module) { - /// std::cout << module.name() << std::endl; - /// }); + /// + /// MyModule module; + /// module->apply([](const nn::Module& module) { + /// std::cout << module.name() << std::endl; + /// }); + /// /// \endrst void apply(const ConstModuleApplyFunction& function) const; @@ -149,10 +153,12 @@ class TORCH_API Module : public std::enable_shared_from_this { /// /// \rst /// .. code-block:: cpp - /// MyModule module; - /// module->apply([](const std::string& key, nn::Module& module) { - /// std::cout << key << ": " << module.name() << std::endl; - /// }); + /// + /// MyModule module; + /// module->apply([](const std::string& key, nn::Module& module) { + /// std::cout << key << ": " << module.name() << std::endl; + /// }); + /// /// \endrst void apply( const NamedModuleApplyFunction& function, @@ -166,10 +172,12 @@ class TORCH_API Module : public std::enable_shared_from_this { /// /// \rst /// .. code-block:: cpp - /// MyModule module; - /// module->apply([](const std::string& key, const nn::Module& module) { - /// std::cout << key << ": " << module.name() << std::endl; - /// }); + /// + /// MyModule module; + /// module->apply([](const std::string& key, const nn::Module& module) { + /// std::cout << key << ": " << module.name() << std::endl; + /// }); + /// /// \endrst void apply( const ConstNamedModuleApplyFunction& function, @@ -180,10 +188,12 @@ class TORCH_API Module : public std::enable_shared_from_this { /// /// \rst /// .. code-block:: cpp - /// MyModule module; - /// module->apply([](const std::shared_ptr& module) { - /// std::cout << module->name() << std::endl; - /// }); + /// + /// MyModule module; + /// module->apply([](const std::shared_ptr& module) { + /// std::cout << module->name() << std::endl; + /// }); + /// /// \endrst void apply(const ModulePointerApplyFunction& function) const; @@ -196,11 +206,13 @@ class TORCH_API Module : public std::enable_shared_from_this { /// /// \rst /// .. code-block:: cpp - /// MyModule module; - /// module->apply([](const std::string& key, - /// const std::shared_ptr& module) { - /// std::cout << key << ": " << module->name() << std::endl; - /// }); + /// + /// MyModule module; + /// module->apply([](const std::string& key, + /// const std::shared_ptr& module) { + /// std::cout << key << ": " << module->name() << std::endl; + /// }); + /// /// \endrst void apply( const NamedModulePointerApplyFunction& function, @@ -310,18 +322,20 @@ class TORCH_API Module : public std::enable_shared_from_this { /// Attempts to cast this `Module` to the given `ModuleType`. /// /// This method is useful when calling `apply()`. + /// /// \rst /// .. code-block:: cpp /// - /// void initialize_weights(nn::Module& module) { - /// torch::NoGradGuard no_grad; - /// if (auto* linear = module.as()) { - /// linear->weight.normal_(0.0, 0.02); - /// } - /// } + /// void initialize_weights(nn::Module& module) { + /// torch::NoGradGuard no_grad; + /// if (auto* linear = module.as()) { + /// linear->weight.normal_(0.0, 0.02); + /// } + /// } + /// + /// MyModule module; + /// module->apply(initialize_weights); /// - /// MyModule module; - /// module->apply(initialize_weights); /// \endrst template typename ModuleType::ContainedType* as() noexcept; @@ -329,17 +343,20 @@ class TORCH_API Module : public std::enable_shared_from_this { /// Attempts to cast this `Module` to the given `ModuleType`. /// /// This method is useful when calling `apply()`. + /// /// \rst /// .. code-block:: cpp - /// void initialize_weights(nn::Module& module) { - /// torch::NoGradGuard no_grad; - /// if (auto* linear = module.as()) { - /// linear->weight.normal_(0.0, 0.02); - /// } - /// } /// - /// MyModule module; - /// module->apply(initialize_weights); + /// void initialize_weights(nn::Module& module) { + /// torch::NoGradGuard no_grad; + /// if (auto* linear = module.as()) { + /// linear->weight.normal_(0.0, 0.02); + /// } + /// } + /// + /// MyModule module; + /// module->apply(initialize_weights); + /// /// \endrst template const typename ModuleType::ContainedType* as() const noexcept; @@ -347,18 +364,20 @@ class TORCH_API Module : public std::enable_shared_from_this { /// Attempts to cast this `Module` to the given `ModuleType`. /// /// This method is useful when calling `apply()`. + /// /// \rst /// .. code-block:: cpp /// - /// void initialize_weights(nn::Module& module) { - /// torch::NoGradGuard no_grad; - /// if (auto* linear = module.as()) { - /// linear->weight.normal_(0.0, 0.02); - /// } - /// } + /// void initialize_weights(nn::Module& module) { + /// torch::NoGradGuard no_grad; + /// if (auto* linear = module.as()) { + /// linear->weight.normal_(0.0, 0.02); + /// } + /// } + /// + /// MyModule module; + /// module.apply(initialize_weights); /// - /// MyModule module; - /// module.apply(initialize_weights); /// \endrst template < typename ModuleType, @@ -368,18 +387,20 @@ class TORCH_API Module : public std::enable_shared_from_this { /// Attempts to cast this `Module` to the given `ModuleType`. /// /// This method is useful when calling `apply()`. + /// /// \rst /// .. code-block:: cpp /// - /// void initialize_weights(nn::Module& module) { - /// torch::NoGradGuard no_grad; - /// if (auto* linear = module.as()) { - /// linear->weight.normal_(0.0, 0.02); - /// } - /// } + /// void initialize_weights(nn::Module& module) { + /// torch::NoGradGuard no_grad; + /// if (auto* linear = module.as()) { + /// linear->weight.normal_(0.0, 0.02); + /// } + /// } + /// + /// MyModule module; + /// module.apply(initialize_weights); /// - /// MyModule module; - /// module.apply(initialize_weights); /// \endrst template < typename ModuleType, @@ -437,7 +458,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// /// A buffer is intended to be state in your module that does not record /// gradients, such as running statistics. Registering it makes it available - /// to methods such as `buffers()`, `clone()` or `to(). + /// to methods such as `buffers()`, `clone()` or `to()`. /// /// \rst /// .. code-block:: cpp @@ -487,8 +508,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// Replaces a registered submodule with this `Module`. /// /// This takes care of the registration, if you used submodule members, you - /// should - // assign the submodule as well, i.e. use as + /// should assign the submodule as well, i.e. use as /// module->submodule_ = module->replace_module("linear", /// torch::nn::Linear(3, 4)); /// It only works when a module of the name is already registered. @@ -504,8 +524,7 @@ class TORCH_API Module : public std::enable_shared_from_this { /// This method deals with `ModuleHolder`s. /// /// This takes care of the registration, if you used submodule members, you - /// should - // assign the submodule as well, i.e. use as + /// should assign the submodule as well, i.e. use as /// module->submodule_ = module->replace_module("linear", linear_holder); /// It only works when a module of the name is already registered. /// @@ -683,15 +702,15 @@ std::shared_ptr Module::replace_module( template void Module::to_impl(Ts&&... ts) { - // First call `to()` on every child module. + /// First call `to()` on every child module. for (auto& child : children_) { child.value()->to(ts...); } - // Then move every parameter to the new dtype/device. + /// Then move every parameter to the new dtype/device. for (auto& parameter : named_parameters(/*recurse=*/false)) { parameter->set_data(parameter->to(ts...)); } - // Then move every buffer to the new dtype/device. + /// Then move every buffer to the new dtype/device. for (auto& buffer : named_buffers(/*recurse=*/false)) { buffer->set_data(buffer->to(ts...)); } diff --git a/torch/csrc/api/include/torch/optim/optimizer.h b/torch/csrc/api/include/torch/optim/optimizer.h index e2af5b74ee599..32f3e36f46d57 100644 --- a/torch/csrc/api/include/torch/optim/optimizer.h +++ b/torch/csrc/api/include/torch/optim/optimizer.h @@ -12,6 +12,7 @@ #include #include #include +#include #include // Forward declarations confuse Doxygen @@ -66,12 +67,209 @@ class TORCH_API OptimizerOptions { virtual void set_lr(const double lr); }; +// Forward declarations for optimizer option types +struct SGDOptions; +struct AdamOptions; +struct AdamWOptions; +struct AdagradOptions; +struct RMSpropOptions; +struct LBFGSOptions; + +/** + * OptimizerCloneableOptions provides parameter group inheritance functionality + * for PyTorch C++ optimizer options. When creating parameter groups with + * partial options (e.g., AdamOptions().weight_decay(0.1)), fields not + * explicitly set by the user inherit from the optimizer's default values, + * while explicitly set fields are preserved. + * + * This enables Python-like behavior in C++: + * ```cpp + * // Python equivalent: + * // optimizer = Adam([{'params': params1, 'weight_decay': 0.1}], lr=0.01) + * // Result: weight_decay=0.1 preserved, lr=0.01 inherited + * + * AdamOptions defaults; + * defaults.lr(0.01).weight_decay(0.05); + * + * std::vector groups; + * groups.emplace_back(params1, std::make_unique( + * AdamOptions().weight_decay(0.1))); // Only weight_decay specified + * + * Adam optimizer(groups, defaults); + * // Result: group inherits lr=0.01, preserves weight_decay=0.1 + * ``` + * + * **Implementation**: Uses SFINAE-based field detection and constructor-default + * comparison to distinguish explicitly set fields from default values. + * Fields that match constructor defaults are inherited; others are preserved. + */ template class OptimizerCloneableOptions : public OptimizerOptions { private: std::unique_ptr clone() const override { return std::make_unique(static_cast(*this)); } + + // SFINAE field detection - detects optimizer fields using public accessor + // methods + template + struct _has_lr : std::false_type {}; + template + struct _has_lr().lr())>> + : std::true_type {}; + + template + struct _has_momentum : std::false_type {}; + template + struct _has_momentum< + T, + std::void_t().momentum())>> + : std::true_type {}; + + template + struct _has_weight_decay : std::false_type {}; + template + struct _has_weight_decay< + T, + std::void_t().weight_decay())>> + : std::true_type {}; + + template + struct _has_dampening : std::false_type {}; + template + struct _has_dampening< + T, + std::void_t().dampening())>> + : std::true_type {}; + + template + struct _has_nesterov : std::false_type {}; + template + struct _has_nesterov< + T, + std::void_t().nesterov())>> + : std::true_type {}; + + template + struct _has_betas : std::false_type {}; + template + struct _has_betas().betas())>> + : std::true_type {}; + + template + struct _has_eps : std::false_type {}; + template + struct _has_eps().eps())>> + : std::true_type {}; + + template + struct _has_amsgrad : std::false_type {}; + template + struct _has_amsgrad< + T, + std::void_t().amsgrad())>> + : std::true_type {}; + + // Optimizer-specific field detection + template + struct _has_lr_decay : std::false_type {}; + template + struct _has_lr_decay< + T, + std::void_t().lr_decay())>> + : std::true_type {}; + + template + struct _has_alpha : std::false_type {}; + template + struct _has_alpha().alpha())>> + : std::true_type {}; + + template + struct _has_centered : std::false_type {}; + template + struct _has_centered< + T, + std::void_t().centered())>> + : std::true_type {}; + + template + struct _has_initial_accumulator_value : std::false_type {}; + template + struct _has_initial_accumulator_value< + T, + std::void_t< + decltype(std::declval().initial_accumulator_value())>> + : std::true_type {}; + + // LBFGS-specific fields with appropriate types + template + struct _has_max_iter : std::false_type {}; + template + struct _has_max_iter< + T, + std::void_t().max_iter())>> + : std::true_type {}; + + template + struct _has_max_eval : std::false_type {}; + template + struct _has_max_eval< + T, + std::void_t().max_eval())>> + : std::true_type {}; + + template + struct _has_tolerance_grad : std::false_type {}; + template + struct _has_tolerance_grad< + T, + std::void_t().tolerance_grad())>> + : std::true_type {}; + + template + struct _has_tolerance_change : std::false_type {}; + template + struct _has_tolerance_change< + T, + std::void_t().tolerance_change())>> + : std::true_type {}; + + template + struct _has_history_size : std::false_type {}; + template + struct _has_history_size< + T, + std::void_t().history_size())>> + : std::true_type {}; + + template + struct _has_line_search_fn : std::false_type {}; + template + struct _has_line_search_fn< + T, + std::void_t().line_search_fn())>> + : std::true_type {}; + + /** + * Merges user-specified options with optimizer defaults using + * constructor-default comparison to detect explicitly set fields. + * + * Algorithm: + * 1. Start with optimizer defaults as base + * 2. Create fresh constructor instance for comparison + * 3. If user_value != constructor_default → user explicitly set it → preserve + * 4. If user_value == constructor_default → user didn't set it → inherit from + * defaults + * + * Implementation is in optimizer.cpp to anchor vtable/typeinfo. + */ + void _merge_by_comparison( + const Derived& defaults, + const Derived& user_options); + + // Friend class for controlled access to private _merge_by_comparison method + friend class Optimizer; }; /// Stores parameters in the param_group and stores a pointer to the @@ -186,6 +384,43 @@ class TORCH_API Optimizer { /// Deserializes the optimizer state from the given `archive`. virtual void load(serialize::InputArchive& archive); + private: + /// Helper function to try merging for a specific optimizer type + template + static bool _try_merge_optimizer_type( + std::unique_ptr& final_options, + const OptimizerOptions& user_options, + const OptimizerOptions& defaults) { + auto* typed_final = dynamic_cast(final_options.get()); + auto* typed_user = dynamic_cast(&user_options); + auto* typed_defaults = dynamic_cast(&defaults); + + if (typed_final && typed_user && typed_defaults) { + typed_final->_merge_by_comparison(*typed_defaults, *typed_user); + return true; + } + return false; + } + + /// Simple variadic dispatch helper - try all optimizer types in one call + template + static void _try_merge_all_optimizer_types( + std::unique_ptr& final_options, + const OptimizerOptions& user_options, + const OptimizerOptions& defaults) { + // Try each optimizer type until one succeeds - much cleaner than manual + // chain + (void)(_try_merge_optimizer_type( + final_options, user_options, defaults) || + ...); + } + + /// Convenience function with all known PyTorch optimizers + static void _try_merge_all_optimizers( + std::unique_ptr& final_options, + const OptimizerOptions& user_options, + const OptimizerOptions& defaults); + protected: std::vector param_groups_; ska::flat_hash_map> state_; diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp index 7ee864bc8ea94..c03d46e5b99e2 100644 --- a/torch/csrc/api/src/nn/modules/rnn.cpp +++ b/torch/csrc/api/src/nn/modules/rnn.cpp @@ -619,7 +619,7 @@ std::tuple> LSTMImpl::forward_helper( max_batch_size, options.hidden_size()}, torch::dtype(input.dtype()).device(input.device())); - hx = std::make_tuple(h_zeros, c_zeros); + hx = std::make_tuple(std::move(h_zeros), std::move(c_zeros)); } else { hx = hx_opt.value(); // Each batch of the hidden state should match the input sequence that @@ -632,7 +632,7 @@ std::tuple> LSTMImpl::forward_helper( if (!batch_sizes.defined()) { result = torch::lstm( input, - {std::get<0>(hx), std::get<1>(hx)}, + {std::move(std::get<0>(hx)), std::move(std::get<1>(hx))}, flat_weights_, options.bias(), options.num_layers(), @@ -644,7 +644,7 @@ std::tuple> LSTMImpl::forward_helper( result = torch::lstm( input, batch_sizes, - {std::get<0>(hx), std::get<1>(hx)}, + {std::move(std::get<0>(hx)), std::move(std::get<1>(hx))}, flat_weights_, options.bias(), options.num_layers(), @@ -652,10 +652,11 @@ std::tuple> LSTMImpl::forward_helper( this->is_training(), options.bidirectional()); } - auto output = std::get<0>(result); - auto hidden = std::make_tuple(std::get<1>(result), std::get<2>(result)); + auto output = std::move(std::get<0>(result)); + auto hidden = std::make_tuple( + std::move(std::get<1>(result)), std::move(std::get<2>(result))); - return std::make_tuple(output, hidden); + return std::make_tuple(std::move(output), std::move(hidden)); } std::tuple> LSTMImpl::forward( diff --git a/torch/csrc/api/src/optim/lbfgs.cpp b/torch/csrc/api/src/optim/lbfgs.cpp index 2241d8e0964b6..4da31ee903377 100644 --- a/torch/csrc/api/src/optim/lbfgs.cpp +++ b/torch/csrc/api/src/optim/lbfgs.cpp @@ -179,7 +179,7 @@ std::tuple LBFGS::_directional_evaluate( } auto flat_grad = _gather_flat_grad(); _set_param(x); - return std::make_tuple(loss, flat_grad); + return std::make_tuple(loss, std::move(flat_grad)); } static double _cubic_interpolate( diff --git a/torch/csrc/api/src/optim/optimizer.cpp b/torch/csrc/api/src/optim/optimizer.cpp index 24c40089bb5ea..674e07307d958 100644 --- a/torch/csrc/api/src/optim/optimizer.cpp +++ b/torch/csrc/api/src/optim/optimizer.cpp @@ -1,10 +1,195 @@ #include +#include +#include + +// Include complete type definitions for all optimizers to enable dynamic_cast +#include +#include +#include +#include +#include +#include + +#include #include #include namespace torch::optim { +// Implementation of OptimizerCloneableOptions::_merge_by_comparison +// Moved here to anchor vtable/typeinfo for template instantiations +template +void OptimizerCloneableOptions::_merge_by_comparison( + const Derived& defaults, + const Derived& user_options) { + auto* result = static_cast(this); + *result = defaults; // Start with optimizer defaults + + // Create constructor defaults instance for comparison + Derived constructor_defaults = []() { + if constexpr (std::is_default_constructible_v) { + return Derived{}; + } else { + // Handle optimizers requiring constructor parameters + if constexpr (std::is_same_v) { + return Derived(1e-3); + } else if constexpr (std::is_same_v) { + return Derived(1e-2); + } else if constexpr (std::is_same_v) { + return Derived(1e-2); + } else if constexpr (std::is_same_v) { + return Derived(1); + } else { + return Derived{}; + } + } + }(); + + // Merge fields: preserve user-set values, inherit defaults for unset values + + if constexpr (OptimizerCloneableOptions::_has_lr::value) { + if (user_options.lr() != constructor_defaults.lr()) { + result->lr(user_options.lr()); + } + } + if constexpr (OptimizerCloneableOptions::_has_momentum< + Derived>::value) { + if (user_options.momentum() != constructor_defaults.momentum()) { + result->momentum(user_options.momentum()); + } + } + if constexpr (OptimizerCloneableOptions::_has_weight_decay< + Derived>::value) { + if (user_options.weight_decay() != constructor_defaults.weight_decay()) { + result->weight_decay(user_options.weight_decay()); + } + } + if constexpr (OptimizerCloneableOptions::_has_dampening< + Derived>::value) { + if (user_options.dampening() != constructor_defaults.dampening()) { + result->dampening(user_options.dampening()); + } + } + if constexpr (OptimizerCloneableOptions::_has_nesterov< + Derived>::value) { + if (user_options.nesterov() != constructor_defaults.nesterov()) { + result->nesterov(user_options.nesterov()); + } + } + if constexpr (OptimizerCloneableOptions::_has_betas< + Derived>::value) { + if (user_options.betas() != constructor_defaults.betas()) { + result->betas(user_options.betas()); + } + } + if constexpr (OptimizerCloneableOptions::_has_eps::value) { + if (user_options.eps() != constructor_defaults.eps()) { + result->eps(user_options.eps()); + } + } + if constexpr (OptimizerCloneableOptions::_has_amsgrad< + Derived>::value) { + if (user_options.amsgrad() != constructor_defaults.amsgrad()) { + result->amsgrad(user_options.amsgrad()); + } + } + + // Optimizer-specific fields - automatically detected and handled + if constexpr (OptimizerCloneableOptions::_has_lr_decay< + Derived>::value) { + if (user_options.lr_decay() != constructor_defaults.lr_decay()) { + result->lr_decay(user_options.lr_decay()); + } + } + if constexpr (OptimizerCloneableOptions::_has_alpha< + Derived>::value) { + if (user_options.alpha() != constructor_defaults.alpha()) { + result->alpha(user_options.alpha()); + } + } + if constexpr (OptimizerCloneableOptions::_has_centered< + Derived>::value) { + if (user_options.centered() != constructor_defaults.centered()) { + result->centered(user_options.centered()); + } + } + if constexpr (OptimizerCloneableOptions< + Derived>::_has_initial_accumulator_value::value) { + if (user_options.initial_accumulator_value() != + constructor_defaults.initial_accumulator_value()) { + result->initial_accumulator_value( + user_options.initial_accumulator_value()); + } + } + + // LBFGS-specific fields with appropriate types + if constexpr (OptimizerCloneableOptions::_has_max_iter< + Derived>::value) { + if (user_options.max_iter() != constructor_defaults.max_iter()) { + result->max_iter(user_options.max_iter()); + } + } + if constexpr (OptimizerCloneableOptions::_has_max_eval< + Derived>::value) { + if (user_options.max_eval() != constructor_defaults.max_eval()) { + result->max_eval(user_options.max_eval()); + } + } + if constexpr (OptimizerCloneableOptions::_has_tolerance_grad< + Derived>::value) { + if (user_options.tolerance_grad() != + constructor_defaults.tolerance_grad()) { + result->tolerance_grad(user_options.tolerance_grad()); + } + } + if constexpr (OptimizerCloneableOptions::_has_tolerance_change< + Derived>::value) { + if (user_options.tolerance_change() != + constructor_defaults.tolerance_change()) { + result->tolerance_change(user_options.tolerance_change()); + } + } + if constexpr (OptimizerCloneableOptions::_has_history_size< + Derived>::value) { + if (user_options.history_size() != constructor_defaults.history_size()) { + result->history_size(user_options.history_size()); + } + } + if constexpr (OptimizerCloneableOptions::_has_line_search_fn< + Derived>::value) { + if (user_options.line_search_fn() != + constructor_defaults.line_search_fn()) { + result->line_search_fn(user_options.line_search_fn()); + } + } +} + +// Explicit template instantiations to anchor vtable/typeinfo +// These instantiations ensure the compiler generates the full class definition +// and vtable for each OptimizerCloneableOptions specialization +template class OptimizerCloneableOptions; +template class OptimizerCloneableOptions; +template class OptimizerCloneableOptions; +template class OptimizerCloneableOptions; +template class OptimizerCloneableOptions; +template class OptimizerCloneableOptions; + +// Simple implementation using variadic template helper +void Optimizer::_try_merge_all_optimizers( + std::unique_ptr& final_options, + const OptimizerOptions& user_options, + const OptimizerOptions& defaults) { + // Clean one-liner replaces the entire repetitive dispatch chain + _try_merge_all_optimizer_types< + SGDOptions, + AdamOptions, + AdamWOptions, + AdagradOptions, + RMSpropOptions, + LBFGSOptions>(final_options, user_options, defaults); +} + bool OptimizerParamGroup::has_options() const { return options_ != nullptr; } @@ -97,9 +282,20 @@ void Optimizer::add_param_group(const OptimizerParamGroup& param_group) { TORCH_INTERNAL_ASSERT(defaults_ != nullptr); OptimizerParamGroup param_group_(param_group.params()); if (!param_group.has_options()) { + // No options provided - use defaults directly param_group_.set_options(defaults_->clone()); } else { - param_group_.set_options(param_group.options().clone()); + // Options provided - merge user's explicit settings with defaults for + // parameter group inheritance This enables Python-C++ API parity by + // honoring user intent while inheriting missing parameters + auto final_options = defaults_->clone(); + + // Simple variadic dispatch - try all known optimizer types + _try_merge_all_optimizers(final_options, param_group.options(), *defaults_); + + // If no merging was done (custom optimizer), final_options already contains + // defaults + param_group_.set_options(std::move(final_options)); } for (const auto& p : param_group_.params()) { TORCH_CHECK( diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index f9468e929a329..e1761b000d8c2 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -43,6 +44,8 @@ using at::OptionalIntArrayRef; using at::Scalar; using at::Tensor; using at::TensorList; +using at::native::detail::GridSamplerInterpolation; +using at::native::detail::GridSamplerPadding; const char* kCudnnDoubleBackwardMsg = "Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: \nwith torch.backends.cudnn.flags(enabled=False):\n output = model(inputs)"; @@ -223,6 +226,67 @@ Tensor amaxamin_jvp( return at::where(mask, dx, 0.).sum(dim, keepdim) / mask.sum(dim, keepdim); } +// Builds the dims vector for aminmax: all dims when dim is nullopt, +// or the single wrapped dim otherwise. +static std::vector _aminmax_dims( + const Tensor& self, + std::optional dim_opt) { + if (dim_opt.has_value()) { + return {at::maybe_wrap_dim(*dim_opt, self.dim())}; + } + std::vector dims(self.dim()); + std::iota(dims.begin(), dims.end(), 0); + return dims; +} + +Tensor aminmax_backward( + const Tensor& self, + std::optional dim, + bool keepdim, + const Tensor& grad_min, + const Tensor& grad_max, + const Tensor& min, + const Tensor& max) { + auto dims_vec = _aminmax_dims(self, dim); + IntArrayRef dims(dims_vec); + Tensor result; + + if (grad_min.defined()) { + auto grad_min_expanded = restore_reduced_dims(grad_min, dims, keepdim); + auto min_mask = (self == restore_reduced_dims(min, dims, keepdim)); + result = scale_grad_by_count(grad_min_expanded, min_mask, dims); + } + + if (grad_max.defined()) { + auto grad_max_expanded = restore_reduced_dims(grad_max, dims, keepdim); + auto max_mask = (self == restore_reduced_dims(max, dims, keepdim)); + auto grad_max_result = + scale_grad_by_count(grad_max_expanded, max_mask, dims); + + if (result.defined()) { + if (!areAnyTensorSubclassLike({result, grad_max_result})) { + result.add_(grad_max_result); + } else { + result = result + grad_max_result; + } + } else { + result = grad_max_result; + } + } + + return result; +} + +Tensor aminmax_jvp( + const Tensor& self_p, + const Tensor& self_t, + const Tensor& result, + std::optional dim, + bool keepdim) { + auto dims_vec = _aminmax_dims(self_p, dim); + return amaxamin_jvp(self_p, self_t, result, IntArrayRef(dims_vec), keepdim); +} + std::tuple _euclidean_dist_backward( const Tensor& grad, const Tensor& x1, @@ -989,6 +1053,7 @@ Tensor unbind_backward_nested( int64_t dim, const at::TensorOptions& options) { std::vector grads_tensors; + grads_tensors.reserve(grads.size()); for (int64_t i : c10::irange(static_cast(grads.size()))) { if (grads[i].defined()) { grads_tensors.push_back(static_cast(grads[i])); @@ -2185,6 +2250,7 @@ Tensor _nested_split_with_sizes_backward( // it's possible some of the grads are not defined (represents tensors of all // 0s). Since at::cat can't handle those, let's define them std::vector grads_all_defined; + grads_all_defined.reserve(grads.size()); for (int64_t i : c10::irange(static_cast(grads.size()))) { if (grads[i].defined()) { grads_all_defined.push_back(static_cast(grads[i])); @@ -2251,7 +2317,6 @@ Tensor error_for_max_pool2d_double_backward() { // This is mps-only. "max_pool2d with `return_indices=False` is not infinitely differentiable.", " If you want to calculate higher order derivatives, e.g. second order,", " set `return_indices=True`."); - return Tensor(); } Tensor glu_double_backward( @@ -4312,10 +4377,12 @@ Tensor linalg_det_backward( const Tensor& LU, const Tensor& pivots) { at::NoTF32Guard disable_tf32; - // A.numel() == 0 necessary for the singular case - if (!grad.defined() || A.sym_numel() == 0) { + if (!grad.defined()) { return {}; } + if (A.sym_numel() == 0) { + return at::zeros_like(A); + } // Special case handling for 1 x 1 matrix, to ensure mathematically correct. // d(det)/dA = 1, so gradient = grad * ones_like(A) @@ -7208,7 +7275,7 @@ std::tuple scatter_reduce_backward( grad_self = grad_self.scatter(dim, index, 0); } - return std::make_tuple(grad_self, grad_src); + return std::make_tuple(std::move(grad_self), std::move(grad_src)); } Tensor _to_copy_backward( @@ -7304,7 +7371,7 @@ std::tuple index_reduce_backward( grad_self = grad_self.index_fill(dim, index, 0); } - return std::make_tuple(grad_self, grad_src); + return std::make_tuple(std::move(grad_self), std::move(grad_src)); } Tensor take_backward( @@ -7501,4 +7568,697 @@ Tensor values_backward(const Tensor& grad, const Tensor& self) { return grad_self; } +// ==================== grid_sampler double backward ==================== + +namespace { + +// Bound integer tap coordinate per padding mode (for bicubic taps). +// Returns a kLong tensor in [0, size-1]. +static Tensor gs_bound_coord( + const Tensor& idx, + int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { + if (padding_mode != GridSamplerPadding::Reflection) { + return idx.clamp(0, size - 1); + } + // Reflection: mirrors the PyTorch reflect_coordinates convention. + if (size <= 1) { + return at::zeros_like(idx); + } + auto x = idx.to(at::kDouble); + double span, min_v; + if (align_corners) { + span = size - 1; + min_v = 0.0; + } else { + span = size; + min_v = -0.5; + } + // Native formula: in = idx - min_v (= idx + 0.5 for no-align_corners). + // For integer idx this is always a half-integer, so fold result + min_v is + // an exact integer — no rounding artifacts. + auto in = x - min_v; + auto in_abs = in.abs(); + auto even = at::fmod(at::floor(in_abs / span), 2.0) < 0.5; + auto rem = at::fmod(in_abs, span); + auto r = at::where(even, rem, span - rem); + return (r + min_v).round().clamp(0, size - 1).to(at::kLong); +} + +// Compute unnormalized source coordinate and per-position chain-rule multiplier +// d(source_coord)/d(grid_coord), accounting for unnormalization and padding. +static std::pair gs_compute_coords( + const Tensor& coord, + int64_t size, + GridSamplerPadding padding_mode, + bool align_corners) { + double unnorm_scale = align_corners ? (size - 1) / 2.0 : size / 2.0; + Tensor ix = align_corners ? (coord + 1) * unnorm_scale + : (coord + 1) * unnorm_scale - 0.5; + Tensor padding_grad; + if (padding_mode == GridSamplerPadding::Zeros) { + padding_grad = at::ones_like(ix); + } else if (padding_mode == GridSamplerPadding::Border) { + // clip_coordinates_set_grad: grad = 1 iff strictly inside (0, size-1) + padding_grad = + ((ix > 0) & (ix < static_cast(size - 1))).to(ix.dtype()); + ix = ix.clamp(0, size - 1); + } else { + TORCH_CHECK_NOT_IMPLEMENTED( + padding_mode == GridSamplerPadding::Reflection, + "Unknown padding mode: ", + static_cast(padding_mode)); + double twice_low = align_corners ? 0.0 : -1.0; + double twice_high = align_corners ? 2.0 * (size - 1) : 2.0 * size - 1.0; + if (twice_high <= twice_low) { + auto z = at::zeros_like(ix); + return {z, z}; + } + double min_v = twice_low / 2.0; + double span = (twice_high - twice_low) / 2.0; + auto in = ix - min_v; + auto in_abs = in.abs(); + auto reflect_sign = at::where(in >= 0, 1.0, -1.0); + auto even = at::fmod(at::floor(in_abs / span), 2.0) < 0.5; + auto flip_sign = at::where(even, 1.0, -1.0); + auto rem = in_abs % span; + auto ix_refl = at::where(even, rem + min_v, span - rem + min_v); + auto clip_grad = ((ix_refl > 0) & (ix_refl < static_cast(size - 1))) + .to(ix.dtype()); + padding_grad = reflect_sign * flip_sign * clip_grad; + ix = ix_refl.clamp(0, size - 1); + } + return {std::move(ix), padding_grad * unnorm_scale}; +} + +static Tensor gs_accum_sumprod_k(const Tensor& values, const Tensor& basis) { + return (values * basis.unsqueeze(1)).sum(-1); +} + +static Tensor gs_outer_last(const Tensor& a, const Tensor& b) { + auto shape = a.sizes().vec(); + shape.back() *= b.size(-1); + return (b.unsqueeze(-1) * a.unsqueeze(-2)).reshape(shape); +} + +// Multi-tap gather for 2D: h_idx/w_idx [N, Ho, Wo, K] → [N, C, Ho, Wo, K] +static Tensor gs_gather2d_multi( + const Tensor& input, + const Tensor& h_idx, + const Tensor& w_idx, + bool zeros_oob) { + auto N = input.size(0), C = input.size(1); + auto H = input.size(2), W = input.size(3); + auto out_H = h_idx.size(1), out_W = h_idx.size(2), K = h_idx.size(3); + auto flat = (h_idx.clamp(0, H - 1) * W + w_idx.clamp(0, W - 1)) + .reshape({N, 1, out_H * out_W * K}) + .expand({N, C, out_H * out_W * K}); + auto result = input.reshape({N, C, H * W}) + .gather(2, flat) + .reshape({N, C, out_H, out_W, K}); + if (zeros_oob) { + auto mask = (h_idx >= 0) & (h_idx < H) & (w_idx >= 0) & (w_idx < W); + result = result * mask.unsqueeze(1).to(result.dtype()); + } + return result; +} + +// Multi-tap scatter for 2D: values [N,C,Ho,Wo], weights [N,Ho,Wo,K] → [N,C,H,W] +static Tensor gs_scatter2d_multi( + const Tensor& values, + const Tensor& weights, + const Tensor& h_idx, + const Tensor& w_idx, + int64_t H, + int64_t W, + bool zeros_oob) { + auto N = values.size(0), C = values.size(1); + auto out_H = values.size(2), out_W = values.size(3); + auto K = h_idx.size(3); + auto flat = (h_idx.clamp(0, H - 1) * W + w_idx.clamp(0, W - 1)) + .reshape({N, 1, out_H * out_W * K}) + .expand({N, C, out_H * out_W * K}); + auto weighted = values.unsqueeze(-1) * weights.unsqueeze(1); + if (zeros_oob) { + auto mask = (h_idx >= 0) & (h_idx < H) & (w_idx >= 0) & (w_idx < W); + weighted = weighted * mask.unsqueeze(1).to(weighted.dtype()); + } + return at::zeros({N, C, H * W}, values.options()) + .scatter_add(2, flat, weighted.reshape({N, C, out_H * out_W * K})) + .reshape({N, C, H, W}); +} + +// Multi-tap bounded gather for bicubic 2D: h_idx/w_idx [N, Ho, Wo, K] → [N, C, +// Ho, Wo, K] +static Tensor gs_gather2d_bc_multi( + const Tensor& input, + const Tensor& h_idx, + const Tensor& w_idx, + GridSamplerPadding padding_mode, + bool align_corners) { + auto N = input.size(0), C = input.size(1); + auto H = input.size(2), W = input.size(3); + auto out_H = h_idx.size(1), out_W = h_idx.size(2), K = h_idx.size(3); + auto flat = (gs_bound_coord(h_idx, H, padding_mode, align_corners) * W + + gs_bound_coord(w_idx, W, padding_mode, align_corners)) + .reshape({N, 1, out_H * out_W * K}) + .expand({N, C, out_H * out_W * K}); + auto result = input.reshape({N, C, H * W}) + .gather(2, flat) + .reshape({N, C, out_H, out_W, K}); + if (padding_mode == GridSamplerPadding::Zeros) { + auto mask = (h_idx >= 0) & (h_idx < H) & (w_idx >= 0) & (w_idx < W); + result = result * mask.unsqueeze(1).to(result.dtype()); + } + return result; +} + +// Multi-tap bounded scatter for bicubic 2D: values [N,C,Ho,Wo], weights +// [N,Ho,Wo,K] → [N,C,H,W] +static Tensor gs_scatter2d_bc_multi( + const Tensor& values, + const Tensor& weights, + const Tensor& h_idx, + const Tensor& w_idx, + int64_t H, + int64_t W, + GridSamplerPadding padding_mode, + bool align_corners) { + auto N = values.size(0), C = values.size(1); + auto out_H = values.size(2), out_W = values.size(3); + auto K = h_idx.size(3); + auto flat = (gs_bound_coord(h_idx, H, padding_mode, align_corners) * W + + gs_bound_coord(w_idx, W, padding_mode, align_corners)) + .reshape({N, 1, out_H * out_W * K}) + .expand({N, C, out_H * out_W * K}); + auto weighted = values.unsqueeze(-1) * weights.unsqueeze(1); + if (padding_mode == GridSamplerPadding::Zeros) { + auto mask = (h_idx >= 0) & (h_idx < H) & (w_idx >= 0) & (w_idx < W); + weighted = weighted * mask.unsqueeze(1).to(weighted.dtype()); + } + return at::zeros({N, C, H * W}, values.options()) + .scatter_add(2, flat, weighted.reshape({N, C, out_H * out_W * K})) + .reshape({N, C, H, W}); +} + +// Multi-tap gather for 3D: d_idx/h_idx/w_idx [N, Do, Ho, Wo, K] → [N, C, Do, +// Ho, Wo, K] +static Tensor gs_gather3d_multi( + const Tensor& input, + const Tensor& d_idx, + const Tensor& h_idx, + const Tensor& w_idx, + bool zeros_oob) { + auto N = input.size(0), C = input.size(1); + auto D = input.size(2), H = input.size(3), W = input.size(4); + auto out_D = d_idx.size(1), out_H = d_idx.size(2), out_W = d_idx.size(3), + K = d_idx.size(4); + auto flat = ((d_idx.clamp(0, D - 1) * H + h_idx.clamp(0, H - 1)) * W + + w_idx.clamp(0, W - 1)) + .reshape({N, 1, out_D * out_H * out_W * K}) + .expand({N, C, out_D * out_H * out_W * K}); + auto result = input.reshape({N, C, D * H * W}) + .gather(2, flat) + .reshape({N, C, out_D, out_H, out_W, K}); + if (zeros_oob) { + auto mask = (d_idx >= 0) & (d_idx < D) & (h_idx >= 0) & (h_idx < H) & + (w_idx >= 0) & (w_idx < W); + result = result * mask.unsqueeze(1).to(result.dtype()); + } + return result; +} + +// Multi-tap scatter for 3D: values [N,C,Do,Ho,Wo], weights [N,Do,Ho,Wo,K] → +// [N,C,D,H,W] +static Tensor gs_scatter3d_multi( + const Tensor& values, + const Tensor& weights, + const Tensor& d_idx, + const Tensor& h_idx, + const Tensor& w_idx, + int64_t D, + int64_t H, + int64_t W, + bool zeros_oob) { + auto N = values.size(0), C = values.size(1); + auto out_D = values.size(2), out_H = values.size(3), out_W = values.size(4); + auto K = d_idx.size(4); + auto flat = ((d_idx.clamp(0, D - 1) * H + h_idx.clamp(0, H - 1)) * W + + w_idx.clamp(0, W - 1)) + .reshape({N, 1, out_D * out_H * out_W * K}) + .expand({N, C, out_D * out_H * out_W * K}); + auto weighted = values.unsqueeze(-1) * weights.unsqueeze(1); + if (zeros_oob) { + auto mask = (d_idx >= 0) & (d_idx < D) & (h_idx >= 0) & (h_idx < H) & + (w_idx >= 0) & (w_idx < W); + weighted = weighted * mask.unsqueeze(1).to(weighted.dtype()); + } + return at::zeros({N, C, D * H * W}, values.options()) + .scatter_add(2, flat, weighted.reshape({N, C, out_D * out_H * out_W * K})) + .reshape({N, C, D, H, W}); +} + +} // anonymous namespace + +std::tuple grid_sampler_2d_double_backward( + const Tensor& ggI, + const Tensor& ggGrid, + const Tensor& grad_output, + const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + std::array output_mask) { + Tensor d_grad_output, d_input, d_grid; + const auto interpolation = + static_cast(interpolation_mode); + const auto padding_mode_enum = static_cast(padding_mode); + + // ggI -> d_grad_output: gather ggI at grid positions = grid_sampler_2d(ggI, + // grid) + if (output_mask[0] && ggI.defined()) { + d_grad_output = at::grid_sampler_2d( + ggI, grid, interpolation_mode, padding_mode, align_corners); + } + // ggI -> d_grid: same structure as grad_grid but with ggI as "input" + if (output_mask[2] && ggI.defined()) { + d_grid = std::get<1>(at::grid_sampler_2d_backward( + grad_output, + ggI, + grid, + interpolation_mode, + padding_mode, + align_corners, + {false, true})); + } + // d_input from ggI is 0: grad_input has no dependence on input. + // For nearest, grad_grid = 0, so all ggGrid contributions vanish. + if (!ggGrid.defined() || interpolation == GridSamplerInterpolation::Nearest) { + return {std::move(d_grad_output), std::move(d_input), std::move(d_grid)}; + } + auto H = input.size(2), W = input.size(3); + + auto [ix, gix_mult] = gs_compute_coords( + grid.select(-1, 0), W, padding_mode_enum, align_corners); + auto [iy, giy_mult] = gs_compute_coords( + grid.select(-1, 1), H, padding_mode_enum, align_corners); + + auto x0 = at::floor(ix).to(at::kLong); + auto y0 = at::floor(iy).to(at::kLong); + auto fx = ix - x0.to(ix.dtype()); + auto fy = iy - y0.to(iy.dtype()); + + bool zeros_oob = (padding_mode_enum == GridSamplerPadding::Zeros); + auto ggG_x = ggGrid.select(-1, 0) * gix_mult; // (N, out_H, out_W) + auto ggG_y = ggGrid.select(-1, 1) * giy_mult; + + if (interpolation == GridSamplerInterpolation::Bilinear) { + auto x1 = x0 + 1; + auto y1 = y0 + 1; + // Stack all 4 tap indices: [N, Ho, Wo, 4] order: nw, ne, sw, se + auto h_stk = at::stack({y0, y0, y1, y1}, -1); + auto w_stk = at::stack({x0, x1, x0, x1}, -1); + + Tensor I; // [N, C, Ho, Wo, 4] + if (output_mask[0] || output_mask[1] || output_mask[2]) { + I = gs_gather2d_multi(input, h_stk, w_stk, zeros_oob); + } + + Tensor dw_dx, dw_dy; + if (output_mask[0] || output_mask[1]) { + // Derivative weights [N, Ho, Wo, 4]: d(w_k)/d(ix) and d(w_k)/d(iy) + auto omx = 1.0 - fx, omy = 1.0 - fy; + dw_dx = at::stack({-omy, omy, -fy, fy}, -1); + dw_dy = at::stack({-omx, -fx, omx, fx}, -1); + } + + if (output_mask[0]) { + auto sum_dx = gs_accum_sumprod_k(I, dw_dx); + auto sum_dy = gs_accum_sumprod_k(I, dw_dy); + auto contrib = sum_dx * ggG_x.unsqueeze(1); + contrib.addcmul_(sum_dy, ggG_y.unsqueeze(1)); + d_grad_output = d_grad_output.defined() ? d_grad_output + contrib + : std::move(contrib); + } + + if (output_mask[1]) { + // Combined per-tap weight [N, Ho, Wo, 4] + auto w_in = ggG_x.unsqueeze(-1) * dw_dx; + w_in.addcmul_(ggG_y.unsqueeze(-1), dw_dy); + d_input = + gs_scatter2d_multi(grad_output, w_in, h_stk, w_stk, H, W, zeros_oob); + } + + if (output_mask[2]) { + // cross second derivative: coeffs [nw,ne,sw,se] = [+1,-1,-1,+1] + auto I_nw = I.select(-1, 0); + auto I_ne = I.select(-1, 1); + auto I_sw = I.select(-1, 2); + auto I_se = I.select(-1, 3); + auto cross = (grad_output * (I_nw - I_ne - I_sw + I_se)).sum(1); + auto d_grid_x = gix_mult * ggG_y * cross; + auto d_grid_y = giy_mult * ggG_x * cross; + auto contrib = at::stack({std::move(d_grid_x), std::move(d_grid_y)}, -1); + d_grid = d_grid.defined() ? d_grid + contrib : std::move(contrib); + } + } else if (interpolation == GridSamplerInterpolation::Bicubic) { + // Native bicubic backward uses raw unnormalized coordinates with a constant + // gx_mult/gy_mult (the unnormalize scale only), applying padding + // exclusively when sampling each tap value. Using gs_compute_coords here + // would be wrong: for border/reflection modes it zeroes gix_mult at + // boundaries, making ggG_x/y = 0 where native still has nonzero + // sensitivity. + double x_scale = align_corners ? (W - 1) / 2.0 : W / 2.0; + double y_scale = align_corners ? (H - 1) / 2.0 : H / 2.0; + auto x_raw = align_corners ? (grid.select(-1, 0) + 1) * x_scale + : (grid.select(-1, 0) + 1) * x_scale - 0.5; + auto y_raw = align_corners ? (grid.select(-1, 1) + 1) * y_scale + : (grid.select(-1, 1) + 1) * y_scale - 0.5; + auto x0_bc = at::floor(x_raw).to(at::kLong); + auto y0_bc = at::floor(y_raw).to(at::kLong); + auto ggG_x_bc = ggGrid.select(-1, 0) * x_scale; + auto ggG_y_bc = ggGrid.select(-1, 1) * y_scale; + auto fx = x_raw - x0_bc.to(x_raw.dtype()); + auto fy = y_raw - y0_bc.to(y_raw.dtype()); + + // Cubic interpolation coefficients and derivatives w.r.t. fractional + // offset. For t in [0,1], the four corners are at offsets {-1, 0, 1, 2} + // from base. + constexpr double A = -0.75; + auto fx1 = fx + 1.0, fx2 = 1.0 - fx, fx3 = 2.0 - fx; + auto fy1 = fy + 1.0, fy2 = 1.0 - fy, fy3 = 2.0 - fy; + + auto cx_t = at::stack( + {(((A * fx1) - (5 * A)) * fx1 + (8 * A)) * fx1 - (4 * A), + (((A + 2) * fx) - (A + 3)) * fx.square() + 1, + (((A + 2) * fx2) - (A + 3)) * fx2.square() + 1, + (((A * fx3) - (5 * A)) * fx3 + (8 * A)) * fx3 - (4 * A)}, + -1); + auto dcx_t = at::stack( + {(((3 * A) * fx1) - (10 * A)) * fx1 + (8 * A), + (((3 * (A + 2)) * fx) - (2 * (A + 3))) * fx, + -((((3 * (A + 2)) * fx2) - (2 * (A + 3))) * fx2), + (((-3 * A) * fx3) + (10 * A)) * fx3 - (8 * A)}, + -1); + auto cy_t = at::stack( + {(((A * fy1) - (5 * A)) * fy1 + (8 * A)) * fy1 - (4 * A), + (((A + 2) * fy) - (A + 3)) * fy.square() + 1, + (((A + 2) * fy2) - (A + 3)) * fy2.square() + 1, + (((A * fy3) - (5 * A)) * fy3 + (8 * A)) * fy3 - (4 * A)}, + -1); + auto dcy_t = at::stack( + {(((3 * A) * fy1) - (10 * A)) * fy1 + (8 * A), + (((3 * (A + 2)) * fy) - (2 * (A + 3))) * fy, + -((((3 * (A + 2)) * fy2) - (2 * (A + 3))) * fy2), + (((-3 * A) * fy3) + (10 * A)) * fy3 - (8 * A)}, + -1); + + // Stack coefficients into [N, Ho, Wo, 4] tensors for vectorized outer + // products. Tap indices [N, Ho, Wo, 16]: flat index k=j*4+i → h=y0+(j-1), + // w=x0+(i-1) + auto offs = at::arange(-1, 3, x0_bc.options()); + auto xt = x0_bc.unsqueeze(-1) + offs; // [N, Ho, Wo, 4] + auto yt = y0_bc.unsqueeze(-1) + offs; + // Outer product: h_idx[...,j,i]=yt[j], w_idx[...,j,i]=xt[i] + auto w_idx_bc = xt.unsqueeze(-2) + .expand({-1, -1, -1, 4, 4}) + .reshape({-1, xt.size(1), xt.size(2), 16}); + auto h_idx_bc = yt.unsqueeze(-1) + .expand({-1, -1, -1, 4, 4}) + .reshape({-1, yt.size(1), yt.size(2), 16}); + + // Gather all 16 taps at once: [N, C, Ho, Wo, 16] + auto I_all = gs_gather2d_bc_multi( + input, h_idx_bc, w_idx_bc, padding_mode_enum, align_corners); + + Tensor B_dx, B_dy; + if (output_mask[0] || output_mask[1]) { + // 1D weight outer products → [N, Ho, Wo, 16] where k=j*4+i + // B_dx[k] = dcx[i]*cy[j], B_dy[k] = cx[i]*dcy[j] + B_dx = gs_outer_last(dcx_t, cy_t); + B_dy = gs_outer_last(cx_t, dcy_t); + } + + if (output_mask[0]) { + auto sum_dx = gs_accum_sumprod_k(I_all, B_dx); + auto sum_dy = gs_accum_sumprod_k(I_all, B_dy); + auto contrib = sum_dx * ggG_x_bc.unsqueeze(1); + contrib.addcmul_(sum_dy, ggG_y_bc.unsqueeze(1)); + d_grad_output = d_grad_output.defined() ? d_grad_output + contrib + : std::move(contrib); + } + + if (output_mask[1]) { + auto w_in = ggG_x_bc.unsqueeze(-1) * B_dx; + w_in.addcmul_(ggG_y_bc.unsqueeze(-1), B_dy); + d_input = gs_scatter2d_bc_multi( + grad_output, + w_in, + h_idx_bc, + w_idx_bc, + H, + W, + padding_mode_enum, + align_corners); + } + + // d_grid from ggGrid: non-zero for bicubic (second derivative of cubic + // weights). + if (output_mask[2]) { + auto d2cx_t = at::stack( + {((6 * A) * fx1) - (10 * A), + ((6 * (A + 2)) * fx) - (2 * (A + 3)), + ((6 * (A + 2)) * fx2) - (2 * (A + 3)), + ((6 * A) * fx3) - (10 * A)}, + -1); + auto d2cy_t = at::stack( + {((6 * A) * fy1) - (10 * A), + ((6 * (A + 2)) * fy) - (2 * (A + 3)), + ((6 * (A + 2)) * fy2) - (2 * (A + 3)), + ((6 * A) * fy3) - (10 * A)}, + -1); + Tensor dot_dx2, dot_dy2, dot_dxdy; + for (const auto j : c10::irange(4)) { + auto cy_j = cy_t.select(-1, j); + auto dcy_j = dcy_t.select(-1, j); + auto d2cy_j = d2cy_t.select(-1, j); + for (const auto i : c10::irange(4)) { + auto tap_dot = (grad_output * I_all.select(-1, j * 4 + i)).sum(1); + auto dx2_term = tap_dot * d2cx_t.select(-1, i) * cy_j; + auto dy2_term = tap_dot * cx_t.select(-1, i) * d2cy_j; + auto dxdy_term = tap_dot * dcx_t.select(-1, i) * dcy_j; + dot_dx2 = + dot_dx2.defined() ? dot_dx2.add_(dx2_term) : std::move(dx2_term); + dot_dy2 = + dot_dy2.defined() ? dot_dy2.add_(dy2_term) : std::move(dy2_term); + dot_dxdy = dot_dxdy.defined() ? dot_dxdy.add_(dxdy_term) + : std::move(dxdy_term); + } + } + auto ggrid_d_grid_x = ggG_x_bc * dot_dx2; + ggrid_d_grid_x.addcmul_(ggG_y_bc, dot_dxdy); + ggrid_d_grid_x.mul_(x_scale); + auto ggrid_d_grid_y = ggG_x_bc * dot_dxdy; + ggrid_d_grid_y.addcmul_(ggG_y_bc, dot_dy2); + ggrid_d_grid_y.mul_(y_scale); + auto ggrid_d_grid = + at::stack({std::move(ggrid_d_grid_x), std::move(ggrid_d_grid_y)}, -1); + d_grid = + d_grid.defined() ? d_grid + ggrid_d_grid : std::move(ggrid_d_grid); + } + } else { + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "grid_sampler_2d_double_backward does not support interpolation mode ", + interpolation_mode); + } + return {std::move(d_grad_output), std::move(d_input), std::move(d_grid)}; +} + +std::tuple grid_sampler_3d_double_backward( + const Tensor& ggI, + const Tensor& ggGrid, + const Tensor& grad_output, + const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + std::array output_mask) { + Tensor d_grad_output, d_input, d_grid; + const auto interpolation = + static_cast(interpolation_mode); + const auto padding_mode_enum = static_cast(padding_mode); + + if (output_mask[0] && ggI.defined()) { + d_grad_output = at::grid_sampler_3d( + ggI, grid, interpolation_mode, padding_mode, align_corners); + } + if (output_mask[2] && ggI.defined()) { + d_grid = std::get<1>(at::grid_sampler_3d_backward( + grad_output, + ggI, + grid, + interpolation_mode, + padding_mode, + align_corners, + {false, true})); + } + if (!ggGrid.defined() || interpolation == GridSamplerInterpolation::Nearest) { + return {std::move(d_grad_output), std::move(d_input), std::move(d_grid)}; + } + TORCH_CHECK_NOT_IMPLEMENTED( + interpolation == GridSamplerInterpolation::Bilinear, + "grid_sampler_3d double backward not implemented for interpolation_mode=", + interpolation_mode); + + // Bilinear 3D: ggGrid contributions. + auto D = input.size(2), H = input.size(3), W = input.size(4); + + auto [ix, gix_mult] = gs_compute_coords( + grid.select(-1, 0), W, padding_mode_enum, align_corners); + auto [iy, giy_mult] = gs_compute_coords( + grid.select(-1, 1), H, padding_mode_enum, align_corners); + auto [iz, giz_mult] = gs_compute_coords( + grid.select(-1, 2), D, padding_mode_enum, align_corners); + + auto x0 = at::floor(ix).to(at::kLong); + auto y0 = at::floor(iy).to(at::kLong); + auto z0 = at::floor(iz).to(at::kLong); + auto x1 = x0 + 1, y1 = y0 + 1, z1 = z0 + 1; + auto fx = ix - x0.to(ix.dtype()); + auto fy = iy - y0.to(iy.dtype()); + auto fz = iz - z0.to(iz.dtype()); + + bool zeros_oob = (padding_mode_enum == GridSamplerPadding::Zeros); + auto ggG_x = ggGrid.select(-1, 0) * gix_mult; + auto ggG_y = ggGrid.select(-1, 1) * giy_mult; + auto ggG_z = ggGrid.select(-1, 2) * giz_mult; + + // Stack all 8 tap indices [N, Do, Ho, Wo, 8]: tnw,tne,tsw,tse,bnw,bne,bsw,bse + auto d_stk = at::stack({z0, z0, z0, z0, z1, z1, z1, z1}, -1); + auto h_stk = at::stack({y0, y0, y1, y1, y0, y0, y1, y1}, -1); + auto w_stk = at::stack({x0, x1, x0, x1, x0, x1, x0, x1}, -1); + + Tensor I; // [N, C, Do, Ho, Wo, 8] + if (output_mask[0] || output_mask[1] || output_mask[2]) { + I = gs_gather3d_multi(input, d_stk, h_stk, w_stk, zeros_oob); + } + + Tensor dw_dx, dw_dy, dw_dz; + if (output_mask[0] || output_mask[1]) { + // Trilinear derivative weights [N, Do, Ho, Wo, 8] + auto omx = 1.0 - fx, omy = 1.0 - fy, omz = 1.0 - fz; + dw_dx = at::stack( + {-omy * omz, + omy * omz, + -fy * omz, + fy * omz, + -omy * fz, + omy * fz, + -fy * fz, + fy * fz}, + -1); + dw_dy = at::stack( + {-omx * omz, + -fx * omz, + omx * omz, + fx * omz, + -omx * fz, + -fx * fz, + omx * fz, + fx * fz}, + -1); + dw_dz = at::stack( + {-omx * omy, + -fx * omy, + -omx * fy, + -fx * fy, + omx * omy, + fx * omy, + omx * fy, + fx * fy}, + -1); + } + + if (output_mask[0]) { + auto sum_dx = gs_accum_sumprod_k(I, dw_dx); + auto sum_dy = gs_accum_sumprod_k(I, dw_dy); + auto sum_dz = gs_accum_sumprod_k(I, dw_dz); + auto d_gO = sum_dx * ggG_x.unsqueeze(1); + d_gO.addcmul_(sum_dy, ggG_y.unsqueeze(1)); + d_gO.addcmul_(sum_dz, ggG_z.unsqueeze(1)); + d_grad_output = + d_grad_output.defined() ? d_grad_output + d_gO : std::move(d_gO); + } + + if (output_mask[1]) { + auto w_in = ggG_x.unsqueeze(-1) * dw_dx; + w_in.addcmul_(ggG_y.unsqueeze(-1), dw_dy); + w_in.addcmul_(ggG_z.unsqueeze(-1), dw_dz); + d_input = gs_scatter3d_multi( + grad_output, w_in, d_stk, h_stk, w_stk, D, H, W, zeros_oob); + } + + // ggGrid -> d_grid: cross second derivatives for 3D bilinear. + // d²output/(dix diy) = (1-fz)*(I_tnw-I_tne-I_tsw+I_tse) + fz*(I_bnw-...) + // d²output/(dix diz) = (1-fy)*(I_tnw-I_tne-I_bnw+I_bne) + fy*(I_tsw-...) + // d²output/(diy diz) = (1-fx)*(I_tnw-I_tsw-I_bnw+I_bsw) + fx*(I_tne-...) + if (output_mask[2]) { + auto fx_ = fx.unsqueeze(1), fy_ = fy.unsqueeze(1), fz_ = fz.unsqueeze(1); + auto I_tnw = I.select(-1, 0); + auto I_tne = I.select(-1, 1); + auto I_tsw = I.select(-1, 2); + auto I_tse = I.select(-1, 3); + auto I_bnw = I.select(-1, 4); + auto I_bne = I.select(-1, 5); + auto I_bsw = I.select(-1, 6); + auto I_bse = I.select(-1, 7); + auto top_xy = I_tnw - I_tne; + top_xy.sub_(I_tsw); + top_xy.add_(I_tse); + auto bottom_xy = I_bnw - I_bne; + bottom_xy.sub_(I_bsw); + bottom_xy.add_(I_bse); + auto d2_xy = (1 - fz_) * top_xy; + d2_xy.add_(fz_ * bottom_xy); + + auto top_xz = I_tnw - I_tne; + top_xz.sub_(I_bnw); + top_xz.add_(I_bne); + auto bottom_xz = I_tsw - I_tse; + bottom_xz.sub_(I_bsw); + bottom_xz.add_(I_bse); + auto d2_xz = (1 - fy_) * top_xz; + d2_xz.addcmul_(fy_, bottom_xz); + + auto top_yz = I_tnw - I_tsw; + top_yz.sub_(I_bnw); + top_yz.add_(I_bsw); + auto bottom_yz = I_tne - I_tse; + bottom_yz.sub_(I_bne); + bottom_yz.add_(I_bse); + auto d2_yz = (1 - fx_) * top_yz; + d2_yz.addcmul_(fx_, bottom_yz); + auto dot_xy = (grad_output * d2_xy).sum(1); + auto dot_xz = (grad_output * d2_xz).sum(1); + auto dot_yz = (grad_output * d2_yz).sum(1); + auto d_grid_x = ggG_y * dot_xy; + d_grid_x.addcmul_(ggG_z, dot_xz); + d_grid_x.mul_(gix_mult); + auto d_grid_y = ggG_x * dot_xy; + d_grid_y.addcmul_(ggG_z, dot_yz); + d_grid_y.mul_(giy_mult); + auto d_grid_z = ggG_x * dot_xz; + d_grid_z.addcmul_(ggG_y, dot_yz); + d_grid_z.mul_(giz_mult); + auto contrib = at::stack( + {std::move(d_grid_x), std::move(d_grid_y), std::move(d_grid_z)}, -1); + d_grid = d_grid.defined() ? d_grid + contrib : std::move(contrib); + } + return {std::move(d_grad_output), std::move(d_input), std::move(d_grid)}; +} + } // namespace torch::autograd::generated::details diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 2df0a9af28e50..ecc6bc8bdb5dd 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -736,6 +736,26 @@ std::tuple batchnorm_double_backward( const std::optional& save_mean, const std::optional& save_invstd, std::array output_mask); +std::tuple grid_sampler_2d_double_backward( + const Tensor& ggI, + const Tensor& ggGrid, + const Tensor& grad_output, + const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + std::array output_mask); +std::tuple grid_sampler_3d_double_backward( + const Tensor& ggI, + const Tensor& ggGrid, + const Tensor& grad_output, + const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + std::array output_mask); std::tuple _euclidean_dist_backward( const Tensor& grad, const Tensor& x1, @@ -820,6 +840,20 @@ Tensor amaxamin_jvp( const Tensor& result, IntArrayRef dim, bool keepdim); +Tensor aminmax_backward( + const at::Tensor& self, + std::optional dim, + bool keepdim, + const at::Tensor& grad_min, + const at::Tensor& grad_max, + const at::Tensor& min, + const at::Tensor& max); +Tensor aminmax_jvp( + const Tensor& self_p, + const Tensor& self_t, + const Tensor& result, + std::optional dim, + bool keepdim); std::tuple layer_norm_double_backward( const Tensor& input, const std::optional& gamma, diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index c2c4dffee66eb..76dda79d4185f 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -479,11 +479,10 @@ static Tensor _fw_primal( std::function rev_func = nullptr; if (!self.unsafeGetTensorImpl()->support_as_strided()) { func = std::make_unique(self.sym_sizes()); - rev_func = [=](const at::Tensor& input_view) { + rev_func = [=](const at::Tensor& input_view) -> at::Tensor { TORCH_INTERNAL_ASSERT( false, "Reverse view_func for _fw_primal() is not currently supported"); - return Tensor(); }; } auto result = as_view( @@ -512,11 +511,10 @@ static Tensor _make_dual( std::function rev_func = nullptr; if (!primal.unsafeGetTensorImpl()->support_as_strided()) { func = std::make_unique(primal.sym_sizes()); - rev_func = [=](const at::Tensor& input_view) { + rev_func = [=](const at::Tensor& input_view) -> at::Tensor { TORCH_INTERNAL_ASSERT( false, "Reverse view_func for _make_dual() is not currently supported"); - return Tensor(); }; } auto result = as_view( diff --git a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp index 386a8a9df534d..efe91535a5a36 100644 --- a/torch/csrc/autograd/autograd_not_implemented_fallback.cpp +++ b/torch/csrc/autograd/autograd_not_implemented_fallback.cpp @@ -618,7 +618,7 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl( "input and the first output (the output can be a vector of tensors). Please change the " "order of your operator's parameters so that this is the case."); const bool is_view = aliased_input_idx.has_value(); - size_t aliased_input_idx_val; + size_t aliased_input_idx_val = 0; // Save inputs before we redispatch down torch::jit::Stack non_tensor_stack; @@ -657,13 +657,13 @@ static void autogradNotImplementedInplaceOrViewFallbackImpl( "which does not have a derivative implemented is forbidden."); auto erroring_view_func = std::make_unique(error_msg); - const auto erroring_rev_view_func = [op_name = op_name](const at::Tensor&) { + const auto erroring_rev_view_func = + [op_name = op_name](const at::Tensor&) -> at::Tensor { TORCH_CHECK( false, "Accessing the reverse view for ", op_name, " which does not have a derivative implemented is forbidden."); - return at::Tensor(); }; if (aliased_output_iv.isTensorList()) { diff --git a/torch/csrc/autograd/function.cpp b/torch/csrc/autograd/function.cpp index 5e294327a0ac4..5743cf48faefd 100644 --- a/torch/csrc/autograd/function.cpp +++ b/torch/csrc/autograd/function.cpp @@ -46,6 +46,44 @@ auto Node::name() const -> std::string { return c10::demangle(typeid(*this).name()); } +auto Node::forward_op_name() const -> std::string { + auto n = name(); + // Strip "Backward" suffix to get the forward op name. + auto pos = n.rfind("Backward"); + if (pos == std::string::npos) { + return n; + } + // Verify everything after "Backward" is digits (e.g., "Backward0"). + auto suffix_start = pos + 8; + for (size_t i = suffix_start; i < n.size(); ++i) { + if (!std::isdigit(static_cast(n[i]))) { + return n; + } + } + // Keep the numeric suffix if it is not "0" (e.g., "AddBackward1" → "Add1"). + auto suffix = n.substr(suffix_start); + if (suffix == "0" || suffix.empty()) { + return n.substr(0, pos); + } + return n.substr(0, pos) + suffix; +} + +bool Node::task_should_compute_output(size_t output_edge_index) const { + TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range"); + const auto& next = next_edges_[output_edge_index]; + if (next.is_valid()) { + const auto exec_info = get_current_graph_task_exec_info(); + if (exec_info && !exec_info->empty()) { + auto it = exec_info->find(next.function.get()); + if (it == exec_info->end() || !it->second.should_execute()) { + return false; + } + } + return true; + } + return false; +} + AnomalyMetadata* Node::metadata() noexcept { if (!anomaly_metadata_) { anomaly_metadata_ = Engine::get_default_engine().make_anomaly_metadata(); diff --git a/torch/csrc/autograd/function.h b/torch/csrc/autograd/function.h index ca97c43ca726e..2b09e67311019 100644 --- a/torch/csrc/autograd/function.h +++ b/torch/csrc/autograd/function.h @@ -1,719 +1,15 @@ #pragma once -#include -#include -#include #include -#include +#include #include #include -#include #include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - namespace torch::autograd { -struct Edge; -struct FunctionPostHook; -struct FunctionPreHook; - -using tensor_list = std::vector; -using variable_list = std::vector; -using edge_list = std::vector; -using saved_variable_list = std::vector; -using ivalue_list = std::vector; -using functional_apply_t = std::function< - variable_list(const variable_list&, const std::vector&)>; -using IndexRange = std::pair; -using torch::dynamo::autograd::CompiledNodeArgs; -using torch::dynamo::autograd::PackedArgs; -using torch::dynamo::autograd::SwapSavedVariables; - -// Custom deleter to prevent stack overflows. -TORCH_API void deleteNode(Node* function); - -// Guard that sets and restores the evaluating node -class NodeGuard { - public: - explicit NodeGuard(std::shared_ptr node); - ~NodeGuard(); - - private: - std::shared_ptr last_evaluating_node_; -}; - -// Return the Node currently being evaluated (if any) -// This is only set during the backward pass while a Node is being -// executed. -TORCH_API std::shared_ptr get_current_node(); - -//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Node -//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// A `Node` is an abstract class that represents an operation taking zero -// or more input `Variable`s and producing zero or more output `Variable`s. All -// functions in PyTorch's autograd machinery derive from this class and -// override its `apply` method. Instances of such subclasses will then be -// invocable via the call operator. -// -// Nodes in the Autograd Graph -//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// When viewing the autograd system as a graph, `Node`s are the vertices or -// nodes, connected to each other via (directed) `Edge`s, which themselves are -// represented via (`Node`, input_nr) pairs. `Variable`s are the outputs to -// and inputs of `Node`s, and travel between these edges during execution -// of the graph. When two or more `Edge`s (from different sources) point at the -// same input to a `Node`, the values produced along all of these edges are -// implicitly summed prior to being forwarded to the target `Node`. -// -// Hierarchy -//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Subclasses usually represent differentiable functions as well as their -// gradient operators. Note, however, that due to the very general definition -// of a `Node` taking *zero* or more inputs and producing *zero* or more -// outputs, uses of `Node`s are flexible and extend beyond purely -// mathematical operations. For example, the `AccumulateGrad` function is a -// *sink*: it takes one input, but produces no outputs, instead accumulating -// the input as a side effect. At the other extreme, the `GraphRoot` function -// receives no inputs from other functions, but produces multiple outputs. -// -// Interface -//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// The most important method on `Node` is the call operator, which takes in -// a list of variables and produces a list of variables. The precise size of -// these lists can be determined with `num_inputs()` and `num_outputs()`. -// `Node`s are stitched together via their `next_edge` interface, which let -// you manipulate the set of outgoing edges of a `Node`. You can add an -// edge with `add_next_edge()`, retrieve an edge with `next_edge(index)` and -// iterate over them via the `next_edges()` method. Other methods exist for -// integration with the JIT and other parts of PyTorch. Every `Node` has a -// *sequence number* that increases monotonically in the order of `Node` -// construction. It can be retrieved via the `sequence_nr()` method. Note that -// this sequence number is *thread local*. This means that when `Node`s -// `A`, `B` and `C` are created consecutively in the same thread, their -// sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B` -// are created in one thread and `C` is created in a new thread, there are *no -// guarantees* w.r.t. the ordering of `C` relative to `A` or `B`. -// See NOTE [ Sequence Number] for more details on the usages of sequence -// number. -//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -struct TORCH_API Node : std::enable_shared_from_this { - public: - /// Construct a new `Node` with the given `next_edges` - explicit Node(uint64_t sequence_nr, edge_list&& next_edges = edge_list()) - : sequence_nr_(sequence_nr), next_edges_(std::move(next_edges)) { - for (const Edge& edge : next_edges_) { - update_topological_nr(edge); - } - - if (AnomalyMode::is_enabled()) { - metadata()->store_stack(); - - // If anomaly mode is enabled and graph is constructed, then assign the - // currently evaluating node as the parent of this node. - // A parent is a Node where this Node is created. - // We are tracking the parents to track multiple backward operations. - assign_parent(); - } - - // Store the thread_id of the forward operator. - // See NOTE [ Sequence Numbers ] - thread_id_ = at::RecordFunction::currentThreadId(); - } - - explicit Node(edge_list&& next_edges = edge_list()) - : Node( - /*sequence_nr=*/at::sequence_number::get_and_increment(), - std::move(next_edges)) {} - - /// Nodes are neither copyable nor moveable. - Node(const Node& other) = delete; - Node(Node&& other) = delete; - Node& operator=(const Node& other) = delete; - Node& operator=(Node&& other) = delete; - virtual ~Node() = default; - - std::shared_ptr getptr() { - return shared_from_this(); - } - /// Evaluates the function on the given inputs and returns the result of the - /// function call. - variable_list operator()(variable_list&& inputs) { - // In the first iteration of named tensors, autograd ignores names and - // operates on unnamed tensors. In the long term, autograd should - // probably operate with names. - at::NoNamesGuard no_names_guard; - -#ifdef USE_ROCM - // Keep track of backward pass for rocblas. - at::ROCmBackwardPassGuard in_backward; -#endif - - auto step_callbacks = - at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION); - if (C10_UNLIKELY(step_callbacks.has_value())) { - at::RecordFunction guard(std::move(*step_callbacks)); - // Using sequence number and thread id to correlate with - // the forward pass function - guard.setForwardThreadId(thread_id_); - if (guard.needsInputs()) { - std::vector inputs_vec(inputs.begin(), inputs.end()); - guard.before( - name(), - c10::ArrayRef( - inputs_vec.data(), inputs_vec.size()), - static_cast(sequence_nr())); - } else { - guard.before(name(), static_cast(sequence_nr())); - } - return apply(std::move(inputs)); - } else { - return apply(std::move(inputs)); - } - } - - // Graph Connectivity API - //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the - // forward function. - - // Marker for expected undefined input - struct undefined_input {}; - - /// Adds the type and shape metadata for a new input. Returns the index of - /// of the new input. - uint32_t add_input_metadata( - const at::TensorOptions& options, - c10::SymIntArrayRef shape, - bool is_tensor_subclass, - bool is_nested, - std::optional grad_dtype) noexcept { - uint32_t input_nr = input_metadata_.size(); - auto meta_shape = MetadataShape{std::in_place_type, shape}; - input_metadata_.emplace_back( - options, meta_shape, is_tensor_subclass, is_nested, grad_dtype); - return input_nr; - } - - uint32_t add_input_metadata(const at::Tensor& t) noexcept { - uint32_t input_nr = input_metadata_.size(); - input_metadata_.emplace_back(t); - return input_nr; - } - - /// Adds a placeholder for an input that will not be used. - uint32_t add_input_metadata(undefined_input u) noexcept { - uint32_t input_nr = input_metadata_.size(); - input_metadata_.emplace_back(); - return input_nr; - } - - uint32_t num_inputs() const noexcept { - return input_metadata_.size(); - } - - const InputMetadata& input_metadata(size_t index) const { - return input_metadata_[index]; - } - - // Danger: not thread safe, caller must protect with lock - InputMetadata& mutable_input_metadata(size_t index) { - return input_metadata_[index]; - } - - /** - * Note: Function Streams - * A function's stream (for a given device type) is the stream of the first - * element of its input buffer on a device of that type. - * - * If all elements are on the same device they MUST share a stream. If - * elements are on different devices (across multiple GPUs, for example) - * they may have different streams. - */ - std::optional stream() { - auto opt_device_type = at::getAccelerator(); - if (!opt_device_type.has_value()) { - return std::nullopt; - } - for (const auto& metadata : input_metadata_) { - if (metadata.device().type() == opt_device_type.value()) - return metadata.stream(); - } - - return std::nullopt; - } - - // Used by the engine to determine what device thread to run on - at::Device device() { - // Since we pick the first non-CPU tensor, this won't work with - // mixed device-type operations (e.g., an op that is both CUDA - // and XLA). This is *incredibly* unlikely, so we don't worry - // about it. - for (const auto& metadata : input_metadata_) { - auto device = metadata.device(); - if (device.type() != at::kCPU) { - return device; - } - } - // Only report to the CPU thread if there really were no tensors - // from other devices. - return at::kCPU; - } - - void clear_input_metadata() { - input_metadata_.clear(); - } - - // Outputs ("Next Edges") - - void update_topological_nr(const Edge& edge) { - TORCH_INTERNAL_ASSERT( - !has_parent_, - "Cannot update a node's topological_nr after it already has a parent." - " If we allow this, we can no longer guarantee that a parent's" - " topo_nr is always greater than those of all its children") - Node* node = edge.function.get(); - if (node) { - auto topo_nr = node->topological_nr(); - if (topological_nr_ <= topo_nr) { - topological_nr_ = topo_nr + 1; - } - } - } - - void set_next_edge(size_t index, Edge edge) { - update_topological_nr(edge); - next_edges_[index] = std::move(edge); - } - - void add_next_edge(Edge edge) { - update_topological_nr(edge); - next_edges_.emplace_back(std::move(edge)); - } - - void set_next_edges(edge_list&& next_edges) { - next_edges_ = std::move(next_edges); - for (const auto& next_edge : next_edges_) { - update_topological_nr(next_edge); - } - } - - const Edge& next_edge(size_t index) const noexcept { - return next_edges_[index]; - } - - const edge_list& next_edges() const noexcept { - return next_edges_; - } - - edge_list& next_edges() noexcept { - return next_edges_; - } - - uint32_t num_outputs() const noexcept { - return next_edges_.size(); - } - - // Miscellaneous Methods - //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - /// NOTE [ Sequence Number] - /// - /// The sequence_nr has two main usages in autograd: - /// - /// 1) Helps determine the node's execution priority in the engine. - /// All else being equal, nodes with higher priority numbers are executed - /// first. Thus, nodes corresponding to ops executed later are the first to - /// be executed in the backward pass. One caveat is that we prioritize - /// AccumulateGrad nodes by explicitly setting its sequence_nr to be - /// UINT64_MAX. - /// 2) The sequence number of this `Node` is paired with with thread_id it was - /// created in - /// as a unique identifier by the profiler to annotate recorded events. - /// The purpose of this is to help users (and possibly programs) - /// interpreting the profiler's output to correlate backward nodes with its - /// forward ops. We need both sequence_nr and thread_id to identify a node - /// because sequence_nr is thread_local, i.e., starts counting up from zero - /// in a new thread - uint64_t sequence_nr() const noexcept { - return sequence_nr_; - } - - void set_sequence_nr(uint64_t sequence_nr) { - sequence_nr_ = sequence_nr; - } - - // NOTE [ Topological Number ] - // - // topological_nr is used to prune branches in the DAG during autograd - // discovery as maintaining topological_nr helps us check in O(1) if there - // does NOT exist a directed path between two nodes. - // - // The topological order number of this `Node` representing the length of the - // longest possible path from this Node to any leaf node. If you are leaf - // node, aka AccumulateGrad, this will be zero. This value has the property - // that For every pair of nodes X, Y in G, existence of a directed path from X - // to Y implies topo_nr(X) > topo_nr(Y). The converse is not true, however, so - // we cannot prove existence of a path from X to Y, only non-existence. - // - // One assumption we make when using topo_nr is that once a node - // has been used, i.e., has a parent node, its own topo_nr does not change - // we have added some checks with the `has_parent_` field to enforce this. - // - // What NOT to do: - // - // 1) 2 -> 1 -> 0 In this diagram we label nodes with their - // topo_nr. - // 2 -> 1 -> 0 We have two simple graphs that can each - // arise from - // `t.exp().exp()`, for example. - // 2) 2 -> 1 -> 0 - // / - // 2 -> 1 -> 0 We add 2 as a next edge to 1 even though 1 - // already - // has a parent. - // 3) 2 -> 1 -> 0 - // / - // 2 -> 3 -> 0 2 < 3, yet there exists a path from 2 to 3! - // - uint64_t topological_nr() const noexcept { - has_parent_ = true; - return topological_nr_; - } - - // assigning a node as a parent to this node - void assign_parent(); - - /// Id of the thread that created Node - uint64_t thread_id() const noexcept { - return thread_id_; - } - - /// Returns the name of the dynamic type of the function, for debugging. - virtual std::string name() const; - - /// The difference between functions `should_compute_output` and - /// `task_should_compute_output`: - /// - `should_compute_output` should only be used during graph construction - /// and takes into account only requires_grad information - /// - `task_should_compute_output` should only be called during the backward - /// pass (unless called directly through grad_fn) and takes into account the - /// current graph task. Specifically, the autograd engine trims unnecessary - /// edges when `inputs` are specified, and during backward untrimmed nodes - /// left on the graph can/should check `task_should_compute_output` to see if - /// any outgoing edges have been trimmed by the engine. If that is the case, - /// gradient computation wrt those edges can be omitted. - /// - /// Returns true if the particular output edge is active, and that particular - /// output of this function should be computed. - bool should_compute_output(size_t output_edge_index) const { - TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range"); - return next_edges_[output_edge_index].is_valid(); - } - - /// Returns true if any of the output edges in any of the ranges are active. - bool should_compute_output(std::initializer_list idxs) const { - return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { - for (const auto i : c10::irange(range.first, range.second)) { - if (should_compute_output(i)) - return true; - } - return false; - }); - } - - /// Same as the above `should_compute_output` function but will also - /// check whether this edge is needed within the current graph task. - bool task_should_compute_output(size_t output_edge_index) const { - TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range"); - const auto& next = next_edges_[output_edge_index]; - if (next.is_valid()) { - const auto exec_info = get_current_graph_task_exec_info(); - if (exec_info && !exec_info->empty()) { - auto it = exec_info->find(next.function.get()); - if (it == exec_info->end() || !it->second.should_execute()) { - return false; // this edge is not needed for the current graph_task - } - } - return true; - } - return false; - } - - /// Returns true if any of the output edges in any of the ranges are active - /// and should be computed in the current graph task. - bool task_should_compute_output( - std::initializer_list idxs) const { - return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { - for (const auto i : c10::irange(range.first, range.second)) { - if (task_should_compute_output(i)) - return true; - } - return false; - }); - } - - /// Returns the `PyObject` stored for this `Node` (for Python - /// interaction). - PyObject* pyobj() const noexcept { - return pyobj_; - } - - /// Sets the `PyObject` stored for this `Node` (for Python interaction). - void set_pyobj(PyObject* pyobj) noexcept { - pyobj_ = pyobj; - } - - /// Returns the anomaly metadata stored for this `Node`. - /// If none exist, creates a new empty one. - AnomalyMetadata* metadata() noexcept; - - // Hook API - //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - uintptr_t add_post_hook(std::unique_ptr&& post_hook) { - post_hooks_.emplace_back(std::move(post_hook)); - // Use the raw pointer as the unique key to identify this hook. This key - // can then be used in del_post_hook(key) to remove this hook. - return reinterpret_cast(post_hooks_.back().get()); - } - - const std::vector>& post_hooks() - const noexcept { - return post_hooks_; - } - - // delete a post hook matching the key - bool del_post_hook(const uintptr_t& key) { - for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) { - if (key == reinterpret_cast(it->get())) { - post_hooks_.erase(it); - return true; - } - } - return false; - } - - std::vector>& post_hooks() noexcept { - return post_hooks_; - } - - void add_pre_hook(std::unique_ptr&& pre_hook) { - pre_hooks_.emplace_back(std::move(pre_hook)); - } - - void add_tensor_pre_hook(std::unique_ptr&& pre_hook) { - tensor_pre_hooks_.emplace_back(std::move(pre_hook)); - } - - void add_retains_grad_hook( - std::unique_ptr&& pre_hook, - size_t output_idx) { - retains_grad_hooks_[output_idx] = std::move(pre_hook); - } - - std::unique_ptr pop_retains_grad_hook(size_t output_idx) { - auto ret = std::move(retains_grad_hooks_[output_idx]); - retains_grad_hooks_.erase(output_idx); - return ret; - } - - const std::vector>& pre_hooks() - const noexcept { - return pre_hooks_; - } - - std::vector>& pre_hooks() noexcept { - return pre_hooks_; - } - - virtual std::vector>& - tensor_pre_hooks() noexcept { - return tensor_pre_hooks_; - } - - virtual std::unique_ptr& tensor_post_acc_grad_hooks() - const noexcept { - static std::unique_ptr empty = nullptr; - return empty; - } - - std::unordered_map>& - retains_grad_hooks() noexcept { - return retains_grad_hooks_; - } - - // Customization Points for Subclasses - //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - /// Releases saved variables if the operation won't be reused. - virtual void release_variables() {} - - /// Called before an apply if `release_variables()` is going to be called. - /// Allows larger ops like `InterpreterAutogradFunction` to incrementally - /// release variables as they run. - virtual void will_release_variables() {} - - /// Returns true if this function is traceable. An op is traceable if all - /// operations happening within `apply()` are performed on autograd - /// `Variables` (i.e. apply mostly instantiates and applies other functions). - virtual bool is_traceable() { - return false; - } - - /// A `Node` is said to pass state transparently to backward, if the - /// state consists only of (Saved)Variables and only non-variable objects - /// that parameterize the operation in some way that defines the graph - /// structure AND the backward function is traceable. In particular, - /// parametrization MUST NOT depend on the data of any `Variable`. - /// TODO: it might be possible to handle cases where backward is - /// non-traceable but state passing could be considered transparent. This - /// will probably depend on saved_variable_list being mutable. - /// NOTE: this value matters only if is_traceable() returns false. - virtual bool passes_state_transparently() { - return false; - } - - // see [Note: Compiled Autograd] - // Used by compiled autograd to - // 1) Extract tensors/symint args - // 2) Collect node information for specialization and caching - // Implementations in subclasses should call args.collect() with all node - // attrs. These functions are only called during backward. - virtual void compiled_args(CompiledNodeArgs& args) const { - TORCH_CHECK_NOT_IMPLEMENTED( - false, std::string("compiled_args not implemented: ") + name()); - } - - // Used by compiled autograd to call apply() with different saved tensors - // Implementations should call saved.before() on all attrs, then apply(), then - // saved.after() on all attrs in the same order. - virtual variable_list apply_with_saved( - const variable_list& inputs, - SwapSavedVariables& saved) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, std::string("apply_with_saved not implemented: ") + name()); - } - - // If this node is the AOTBackward node produced by torch.compile. - // Compiled Autograd special-cases on this information. - virtual bool is_aot_backward() const { - return false; - } - - protected: - /// Performs the `Node`'s actual operation. - virtual variable_list apply(variable_list&& inputs) = 0; - - /// Calls `apply()`, but instruments it with tracing machinery. - variable_list traced_apply(variable_list inputs); - - // Sequence number used to correlate backward nodes with forward ops in the - // profiler and provide determinism in the engine. - // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) - uint64_t sequence_nr_; - - // See NOTE [ Topological Number ] - uint64_t topological_nr_ = 0; - - // Tracks whether this node has been added as the next_edge of another node - // via set_next_edge(s), which always calls topological_nr() of all its - // children See NOTE [ Topological Number ] for why we need this. - mutable bool has_parent_ = false; - - // Id of the thread that created the instance - uint64_t thread_id_ = 0; - - // Note [Thread Safety on Autograd Node] - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // Autograd Engine let the owning thread which calls Engine::execute to drive - // the GraphTask execution, there might be cases that part of the GraphTask is - // shared across different `backward()` or `grad()` calls, i.e. fork new - // threads in the middle of the forward and call `backward()` separately from - // different threads. We need to protect the thread safety on NodeTask to - // prevent data racing on shared variables read/write. - // - // NB: This is only needed for Autograd Nodes that runs on CPU, technically - // "CUDA", "XLA" nodes don't need locking because device threads are always - // single threaded. - // - // Here we add a thread mutex to help protect the Node's thread safety, so - // that different threads cannot race the shared data when executing the same - // NodeTask from multiple CPU threads. It IS the user/developer responsibility - // to take advantage of this mutex to protect the thread safety of their - // autograd Node. The general strategy of thread safety on autograd Node: - // - // 1. User should lock the mutex during Node::release_variables() if the Node - // needs - // to release the variables on the fly, this serve the purpose that when we - // release saved_variables from one thread, no other threads can release - // the saved variables concurrently. call the Node::apply(), - // 2. User should lock the mutex during Node::apply(), this is to ensure Node - // that - // writing to the shared variable are not racing across threads (i.e. - // AccumulateGrad and custom C++ Autograd Node if writing to shared - // variables ) - // 3. item 2 and item 3 should work together so that when we release saved - // variables - // from one thread, no other threads can call Node::apply(), this ensures - // the variable references from other threads aren't dangling. - // 4. if the Node don't release any variables and no shared data read/write in - // the Node - // i.e. purely functional, user don't need to lock the mutex - // - // This way we could protect the thread safety on Autograd Node, but we could - // still not protect the thread safety on Node pre/post C++ hooks (python - // hooks are automatically thread safe), we rely on the user to write thread - // safe C++ hooks if they want the hook to be correctly applied in - // multithreading environment. - std::mutex mutex_; - - edge_list next_edges_; - PyObject* pyobj_ = nullptr; // weak reference - std::unique_ptr anomaly_metadata_ = nullptr; - - // NOTE [Hooks ordering] - // We have 3 separate fields for pre hooks registered to the autograd nodes - // because the conditions under which they execute are different, and we - // want more fine-grained control over the order in which different types - // of hooks are executed. - // - pre_hooks are only executed when the node itself is executed - // - tensor_pre_hook is executed as long as the engine traverses over it - // even if that node won't be executed. - // - retains_grad_hook are like tensor_pre_hooks except they are always - // ordered after all other tensor pre hooks - std::vector> pre_hooks_; - std::vector> tensor_pre_hooks_; - std::unordered_map> - retains_grad_hooks_; - std::vector> post_hooks_; - at::SmallVector input_metadata_; -}; - -/// See Node::is_traceable() for definition. -struct TraceableFunction : public Node { - using Node::Node; - bool is_traceable() final { - return true; - } -}; - //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// Associated Free Nodes +// Associated Free Functions //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ namespace detail { diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp index 3e26190dd9ec6..fa41af91fa196 100644 --- a/torch/csrc/autograd/init.cpp +++ b/torch/csrc/autograd/init.cpp @@ -293,6 +293,10 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { .def( "linked_correlation_id", [](const KinetoEvent& e) { return e.linkedCorrelationId(); }) + .def("flow_id", [](const KinetoEvent& e) { return e.flowId(); }) + .def("flow_type", [](const KinetoEvent& e) { return e.flowType(); }) + .def("flow_start", [](const KinetoEvent& e) { return e.flowStart(); }) + .def("external_id", [](const KinetoEvent& e) { return e.externalId(); }) // compute flops .def("flops", [](const KinetoEvent& e) { return e.flops(); }) // Whether this is async event or not @@ -307,14 +311,58 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { e.activityType() == (uint8_t)libkineto::ActivityType::GPU_USER_ANNOTATION; }) + .def( + "is_python_function", + [](const KinetoEvent& e) { return e.isPythonFunction(); }) .def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); }) // whether the event is hidden .def( "is_hidden_event", [](const KinetoEvent& e) { return e.isHiddenEvent(); }) // KinetoEvent metadata - .def("metadata_json", [](const KinetoEvent& e) { - return e.metadataJson(); + .def( + "metadata_json", + [](const KinetoEvent& e) { return e.metadataJson(); }) + .def( + "activity_type", + [](const KinetoEvent& e) { + return libkineto::toString( + static_cast(e.activityType())); + }) + .def("extra_meta", [](const KinetoEvent& e) { return e.extraMeta(); }) + // Like shapes/strides, but also contains TensorList input shapes. + .def( + "structured_input_shapes", + [](const KinetoEvent& e) { + py::list result; + for (const auto& s : e.structuredInputShapes()) { + if (std::holds_alternative>(s)) { + result.append(std::get>(s)); + } else { + result.append(std::get>>(s)); + } + } + return result; + }) + .def( + "structured_input_strides", + [](const KinetoEvent& e) { + py::list result; + for (const auto& s : e.structuredInputStrides()) { + if (std::holds_alternative>(s)) { + result.append(std::get>(s)); + } else { + result.append(std::get>>(s)); + } + } + return result; + }) + .def("python_id", [](const KinetoEvent& e) { return e.pythonId(); }) + .def( + "python_parent_id", + [](const KinetoEvent& e) { return e.pythonParentId(); }) + .def("python_module_id", [](const KinetoEvent& e) { + return e.pythonModuleId(); }); m.def("_soft_assert_raises", &setSoftAssertRaises); @@ -767,7 +815,7 @@ static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) { ")"); TORCH_WARN_DEPRECATION( "torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead.") - at::autocast::set_autocast_enabled(at::kCPU, arg == Py_True); + at::autocast::set_autocast_enabled(at::kCPU, Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -793,7 +841,7 @@ static PyObject* set_autocast_ipu_enabled(PyObject* _unused, PyObject* arg) { ")"); TORCH_WARN_DEPRECATION( "torch.set_autocast_ipu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('ipu', enabled) instead.") - at::autocast::set_autocast_enabled(at::kIPU, arg == Py_True); + at::autocast::set_autocast_enabled(at::kIPU, Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -819,7 +867,7 @@ static PyObject* set_autocast_xla_enabled(PyObject* _unused, PyObject* arg) { ")"); TORCH_WARN_DEPRECATION( "torch.set_autocast_xla_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('xla', enabled) instead.") - at::autocast::set_autocast_enabled(at::kXLA, arg == Py_True); + at::autocast::set_autocast_enabled(at::kXLA, Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -970,7 +1018,7 @@ static PyObject* set_autocast_cache_enabled(PyObject* _unused, PyObject* arg) { "enabled must be a bool (got ", Py_TYPE(arg)->tp_name, ")"); - at::autocast::set_autocast_cache_enabled(arg == Py_True); + at::autocast::set_autocast_cache_enabled(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1014,7 +1062,7 @@ static PyObject* set_fwd_grad_enabled(PyObject* _unused, PyObject* arg) { "enabled must be a bool (got ", Py_TYPE(arg)->tp_name, ")"); - c10::AutogradState::get_tls_state().set_fw_grad_mode(arg == Py_True); + c10::AutogradState::get_tls_state().set_fw_grad_mode(Py_IsTrue(arg)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS } @@ -1234,7 +1282,7 @@ static PyObject* is_view_replay_enabled(PyObject* self, PyObject* args) { static PyObject* set_graph_exec_group(PyObject* self, PyObject* obj) { HANDLE_TH_ERRORS - if (obj == Py_None) { + if (Py_IsNone(obj)) { c10::AutogradState::get_tls_state().set_graph_exec_group(std::nullopt); } else { Py_INCREF(obj); @@ -1348,7 +1396,7 @@ static PyObject* push_on_torch_function_stack( PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - if (arg != Py_None) { + if (!Py_IsNone(arg)) { Py_INCREF(arg); at::impl::PythonTorchFunctionTLS::push_onto_stack( std::make_shared(arg, getPyInterpreter())); @@ -1399,7 +1447,7 @@ static PyObject* push_on_torch_dispatch_stack( PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - if (arg != Py_None) { + if (!Py_IsNone(arg)) { using c10::impl::TorchDispatchModeKey; // When we push a mode onto the mode stack, we need to // check if it's an "infra" mode, by checking its _mode_key attribute. @@ -1433,7 +1481,7 @@ static PyObject* pop_torch_dispatch_stack( // When the shared_ptr is destroyed, ~SafePyObject will Py_DECREF, so we must // Py_INCREF first to give the caller a valid reference. std::shared_ptr mode; - if (maybe_mode_key != Py_None) { + if (!Py_IsNone(maybe_mode_key)) { mode_key = py::cast(maybe_mode_key); auto maybe_mode = c10::impl::TorchDispatchModeTLS::unset_mode(mode_key.value()); @@ -1477,7 +1525,7 @@ static PyObject* get_dispatch_stack_at( static PyObject* set_dispatch_mode(PyObject* _unused, PyObject* mode) { HANDLE_TH_ERRORS - TORCH_CHECK(mode != Py_None); + TORCH_CHECK(!Py_IsNone(mode)); py::object maybe_mode_key_obj = PyObject_FastGetAttrString(mode, "_mode_key"); TORCH_CHECK( @@ -1497,7 +1545,7 @@ static PyObject* set_dispatch_mode(PyObject* _unused, PyObject* mode) { static PyObject* get_dispatch_mode(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - TORCH_CHECK(arg != Py_None); + TORCH_CHECK(!Py_IsNone(arg)); auto mode_key = py::cast(arg); auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode(mode_key); @@ -1512,7 +1560,7 @@ static PyObject* get_dispatch_mode(PyObject* _unused, PyObject* arg) { static PyObject* unset_dispatch_mode(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS - TORCH_CHECK(arg != Py_None); + TORCH_CHECK(!Py_IsNone(arg)); auto mode_key = py::cast(arg); const auto maybe_mode = c10::impl::TorchDispatchModeTLS::unset_mode(mode_key); diff --git a/torch/csrc/autograd/node.h b/torch/csrc/autograd/node.h new file mode 100644 index 0000000000000..cee1f6dc63254 --- /dev/null +++ b/torch/csrc/autograd/node.h @@ -0,0 +1,706 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch::autograd { + +struct Edge; +struct FunctionPostHook; +struct FunctionPreHook; +struct Node; +class SavedVariable; + +using Variable = at::Tensor; +using tensor_list = std::vector; +using variable_list = std::vector; +using edge_list = std::vector; +using saved_variable_list = std::vector; +using ivalue_list = std::vector; +using functional_apply_t = std::function< + variable_list(const variable_list&, const std::vector&)>; +using IndexRange = std::pair; +using torch::dynamo::autograd::CompiledNodeArgs; +using torch::dynamo::autograd::PackedArgs; +using torch::dynamo::autograd::SwapSavedVariables; + +// Custom deleter to prevent stack overflows. +TORCH_API void deleteNode(Node* function); + +// Guard that sets and restores the evaluating node +class NodeGuard { + public: + explicit NodeGuard(std::shared_ptr node); + ~NodeGuard(); + + private: + std::shared_ptr last_evaluating_node_; +}; + +// Return the Node currently being evaluated (if any) +// This is only set during the backward pass while a Node is being +// executed. +TORCH_API std::shared_ptr get_current_node(); + +//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Node +//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// A `Node` is an abstract class that represents an operation taking zero +// or more input `Variable`s and producing zero or more output `Variable`s. All +// functions in PyTorch's autograd machinery derive from this class and +// override its `apply` method. Instances of such subclasses will then be +// invocable via the call operator. +// +// Nodes in the Autograd Graph +//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// When viewing the autograd system as a graph, `Node`s are the vertices or +// nodes, connected to each other via (directed) `Edge`s, which themselves are +// represented via (`Node`, input_nr) pairs. `Variable`s are the outputs to +// and inputs of `Node`s, and travel between these edges during execution +// of the graph. When two or more `Edge`s (from different sources) point at the +// same input to a `Node`, the values produced along all of these edges are +// implicitly summed prior to being forwarded to the target `Node`. +// +// Hierarchy +//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Subclasses usually represent differentiable functions as well as their +// gradient operators. Note, however, that due to the very general definition +// of a `Node` taking *zero* or more inputs and producing *zero* or more +// outputs, uses of `Node`s are flexible and extend beyond purely +// mathematical operations. For example, the `AccumulateGrad` function is a +// *sink*: it takes one input, but produces no outputs, instead accumulating +// the input as a side effect. At the other extreme, the `GraphRoot` function +// receives no inputs from other functions, but produces multiple outputs. +// +// Interface +//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// The most important method on `Node` is the call operator, which takes in +// a list of variables and produces a list of variables. The precise size of +// these lists can be determined with `num_inputs()` and `num_outputs()`. +// `Node`s are stitched together via their `next_edge` interface, which let +// you manipulate the set of outgoing edges of a `Node`. You can add an +// edge with `add_next_edge()`, retrieve an edge with `next_edge(index)` and +// iterate over them via the `next_edges()` method. Other methods exist for +// integration with the JIT and other parts of PyTorch. Every `Node` has a +// *sequence number* that increases monotonically in the order of `Node` +// construction. It can be retrieved via the `sequence_nr()` method. Note that +// this sequence number is *thread local*. This means that when `Node`s +// `A`, `B` and `C` are created consecutively in the same thread, their +// sequence numbers will be ordered `A` < `B` < `C`. If, however, `A` and `B` +// are created in one thread and `C` is created in a new thread, there are *no +// guarantees* w.r.t. the ordering of `C` relative to `A` or `B`. +// See NOTE [ Sequence Number] for more details on the usages of sequence +// number. +//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +struct TORCH_API Node : std::enable_shared_from_this { + public: + /// Construct a new `Node` with the given `next_edges` + explicit Node(uint64_t sequence_nr, edge_list&& next_edges = edge_list()) + : sequence_nr_(sequence_nr), next_edges_(std::move(next_edges)) { + for (const Edge& edge : next_edges_) { + update_topological_nr(edge); + } + + if (AnomalyMode::is_enabled()) { + metadata()->store_stack(); + + // If anomaly mode is enabled and graph is constructed, then assign the + // currently evaluating node as the parent of this node. + // A parent is a Node where this Node is created. + // We are tracking the parents to track multiple backward operations. + assign_parent(); + } + + // Store the thread_id of the forward operator. + // See NOTE [ Sequence Numbers ] + thread_id_ = at::RecordFunction::currentThreadId(); + } + + explicit Node(edge_list&& next_edges = edge_list()) + : Node( + /*sequence_nr=*/at::sequence_number::get_and_increment(), + std::move(next_edges)) {} + + /// Nodes are neither copyable nor moveable. + Node(const Node& other) = delete; + Node(Node&& other) = delete; + Node& operator=(const Node& other) = delete; + Node& operator=(Node&& other) = delete; + virtual ~Node() = default; + + std::shared_ptr getptr() { + return shared_from_this(); + } + /// Evaluates the function on the given inputs and returns the result of the + /// function call. + variable_list operator()(variable_list&& inputs) { + // In the first iteration of named tensors, autograd ignores names and + // operates on unnamed tensors. In the long term, autograd should + // probably operate with names. + at::NoNamesGuard no_names_guard; + +#ifdef USE_ROCM + // Keep track of backward pass for rocblas. + at::ROCmBackwardPassGuard in_backward; +#endif + + auto step_callbacks = + at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION); + if (C10_UNLIKELY(step_callbacks.has_value())) { + at::RecordFunction guard(std::move(*step_callbacks)); + // Using sequence number and thread id to correlate with + // the forward pass function + guard.setForwardThreadId(thread_id_); + if (guard.needsInputs()) { + std::vector inputs_vec(inputs.begin(), inputs.end()); + guard.before( + name(), + c10::ArrayRef( + inputs_vec.data(), inputs_vec.size()), + static_cast(sequence_nr())); + } else { + guard.before(name(), static_cast(sequence_nr())); + } + return apply(std::move(inputs)); + } else { + return apply(std::move(inputs)); + } + } + + // Graph Connectivity API + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + // Inputs. NOTE: inputs of the grad_fn correspond to Tensor outputs of the + // forward function. + + // Marker for expected undefined input + struct undefined_input {}; + + /// Adds the type and shape metadata for a new input. Returns the index of + /// of the new input. + uint32_t add_input_metadata( + const at::TensorOptions& options, + c10::SymIntArrayRef shape, + bool is_tensor_subclass, + bool is_nested, + std::optional grad_dtype) noexcept { + uint32_t input_nr = input_metadata_.size(); + auto meta_shape = MetadataShape{std::in_place_type, shape}; + input_metadata_.emplace_back( + options, meta_shape, is_tensor_subclass, is_nested, grad_dtype); + return input_nr; + } + + uint32_t add_input_metadata(const at::Tensor& t) noexcept { + uint32_t input_nr = input_metadata_.size(); + input_metadata_.emplace_back(t); + return input_nr; + } + + /// Adds a placeholder for an input that will not be used. + uint32_t add_input_metadata(undefined_input u) noexcept { + uint32_t input_nr = input_metadata_.size(); + input_metadata_.emplace_back(); + return input_nr; + } + + uint32_t num_inputs() const noexcept { + return input_metadata_.size(); + } + + const InputMetadata& input_metadata(size_t index) const { + return input_metadata_[index]; + } + + // Danger: not thread safe, caller must protect with lock + InputMetadata& mutable_input_metadata(size_t index) { + return input_metadata_[index]; + } + + /** + * Note: Function Streams + * A function's stream (for a given device type) is the stream of the first + * element of its input buffer on a device of that type. + * + * If all elements are on the same device they MUST share a stream. If + * elements are on different devices (across multiple GPUs, for example) + * they may have different streams. + */ + std::optional stream() { + auto opt_device_type = at::getAccelerator(); + if (!opt_device_type.has_value()) { + return std::nullopt; + } + for (const auto& metadata : input_metadata_) { + if (metadata.device().type() == opt_device_type.value()) + return metadata.stream(); + } + + return std::nullopt; + } + + // Used by the engine to determine what device thread to run on + at::Device device() { + // Since we pick the first non-CPU tensor, this won't work with + // mixed device-type operations (e.g., an op that is both CUDA + // and XLA). This is *incredibly* unlikely, so we don't worry + // about it. + for (const auto& metadata : input_metadata_) { + auto device = metadata.device(); + if (device.type() != at::kCPU) { + return device; + } + } + // Only report to the CPU thread if there really were no tensors + // from other devices. + return at::kCPU; + } + + void clear_input_metadata() { + input_metadata_.clear(); + } + + // Outputs ("Next Edges") + + void update_topological_nr(const Edge& edge) { + TORCH_INTERNAL_ASSERT( + !has_parent_, + "Cannot update a node's topological_nr after it already has a parent." + " If we allow this, we can no longer guarantee that a parent's" + " topo_nr is always greater than those of all its children") + Node* node = edge.function.get(); + if (node) { + auto topo_nr = node->topological_nr(); + if (topological_nr_ <= topo_nr) { + topological_nr_ = topo_nr + 1; + } + } + } + + void set_next_edge(size_t index, Edge edge) { + update_topological_nr(edge); + next_edges_[index] = std::move(edge); + } + + void add_next_edge(Edge edge) { + update_topological_nr(edge); + next_edges_.emplace_back(std::move(edge)); + } + + void set_next_edges(edge_list&& next_edges) { + next_edges_ = std::move(next_edges); + for (const auto& next_edge : next_edges_) { + update_topological_nr(next_edge); + } + } + + const Edge& next_edge(size_t index) const noexcept { + return next_edges_[index]; + } + + const edge_list& next_edges() const noexcept { + return next_edges_; + } + + edge_list& next_edges() noexcept { + return next_edges_; + } + + uint32_t num_outputs() const noexcept { + return next_edges_.size(); + } + + // Miscellaneous Methods + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + /// NOTE [ Sequence Number] + /// + /// The sequence_nr has two main usages in autograd: + /// + /// 1) Helps determine the node's execution priority in the engine. + /// All else being equal, nodes with higher priority numbers are executed + /// first. Thus, nodes corresponding to ops executed later are the first to + /// be executed in the backward pass. One caveat is that we prioritize + /// AccumulateGrad nodes by explicitly setting its sequence_nr to be + /// UINT64_MAX. + /// 2) The sequence number of this `Node` is paired with with thread_id it was + /// created in + /// as a unique identifier by the profiler to annotate recorded events. + /// The purpose of this is to help users (and possibly programs) + /// interpreting the profiler's output to correlate backward nodes with its + /// forward ops. We need both sequence_nr and thread_id to identify a node + /// because sequence_nr is thread_local, i.e., starts counting up from zero + /// in a new thread + uint64_t sequence_nr() const noexcept { + return sequence_nr_; + } + + void set_sequence_nr(uint64_t sequence_nr) { + sequence_nr_ = sequence_nr; + } + + // NOTE [ Topological Number ] + // + // topological_nr is used to prune branches in the DAG during autograd + // discovery as maintaining topological_nr helps us check in O(1) if there + // does NOT exist a directed path between two nodes. + // + // The topological order number of this `Node` representing the length of the + // longest possible path from this Node to any leaf node. If you are leaf + // node, aka AccumulateGrad, this will be zero. This value has the property + // that For every pair of nodes X, Y in G, existence of a directed path from X + // to Y implies topo_nr(X) > topo_nr(Y). The converse is not true, however, so + // we cannot prove existence of a path from X to Y, only non-existence. + // + // One assumption we make when using topo_nr is that once a node + // has been used, i.e., has a parent node, its own topo_nr does not change + // we have added some checks with the `has_parent_` field to enforce this. + // + // What NOT to do: + // + // 1) 2 -> 1 -> 0 In this diagram we label nodes with their + // topo_nr. + // 2 -> 1 -> 0 We have two simple graphs that can each + // arise from + // `t.exp().exp()`, for example. + // 2) 2 -> 1 -> 0 + // / + // 2 -> 1 -> 0 We add 2 as a next edge to 1 even though 1 + // already + // has a parent. + // 3) 2 -> 1 -> 0 + // / + // 2 -> 3 -> 0 2 < 3, yet there exists a path from 2 to 3! + // + uint64_t topological_nr() const noexcept { + has_parent_ = true; + return topological_nr_; + } + + // assigning a node as a parent to this node + void assign_parent(); + + /// Id of the thread that created Node + uint64_t thread_id() const noexcept { + return thread_id_; + } + + /// Returns the name of the dynamic type of the function, for debugging. + virtual std::string name() const; + + /// Returns the name of the corresponding forward op by stripping the + /// "Backward" suffix from name(), if present. + std::string forward_op_name() const; + + /// The difference between functions `should_compute_output` and + /// `task_should_compute_output`: + /// - `should_compute_output` should only be used during graph construction + /// and takes into account only requires_grad information + /// - `task_should_compute_output` should only be called during the backward + /// pass (unless called directly through grad_fn) and takes into account the + /// current graph task. Specifically, the autograd engine trims unnecessary + /// edges when `inputs` are specified, and during backward untrimmed nodes + /// left on the graph can/should check `task_should_compute_output` to see if + /// any outgoing edges have been trimmed by the engine. If that is the case, + /// gradient computation wrt those edges can be omitted. + /// + /// Returns true if the particular output edge is active, and that particular + /// output of this function should be computed. + bool should_compute_output(size_t output_edge_index) const { + TORCH_CHECK(output_edge_index < num_outputs(), "Index out of range"); + return next_edges_[output_edge_index].is_valid(); + } + + /// Returns true if any of the output edges in any of the ranges are active. + bool should_compute_output(std::initializer_list idxs) const { + return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { + for (const auto i : c10::irange(range.first, range.second)) { + if (should_compute_output(i)) + return true; + } + return false; + }); + } + + /// Same as the above `should_compute_output` function but will also + /// check whether this edge is needed within the current graph task. + /// Implemented out of line to avoid including graph_task.h. + bool task_should_compute_output(size_t output_edge_index) const; + + /// Returns true if any of the output edges in any of the ranges are active + /// and should be computed in the current graph task. + bool task_should_compute_output( + std::initializer_list idxs) const { + return std::any_of(idxs.begin(), idxs.end(), [this](IndexRange range) { + for (const auto i : c10::irange(range.first, range.second)) { + if (task_should_compute_output(i)) + return true; + } + return false; + }); + } + + /// Returns the `PyObject` stored for this `Node` (for Python + /// interaction). + PyObject* pyobj() const noexcept { + return pyobj_; + } + + /// Sets the `PyObject` stored for this `Node` (for Python interaction). + void set_pyobj(PyObject* pyobj) noexcept { + pyobj_ = pyobj; + } + + /// Returns the anomaly metadata stored for this `Node`. + /// If none exist, creates a new empty one. + AnomalyMetadata* metadata() noexcept; + + // Hook API + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + uintptr_t add_post_hook(std::unique_ptr&& post_hook) { + post_hooks_.emplace_back(std::move(post_hook)); + // Use the raw pointer as the unique key to identify this hook. This key + // can then be used in del_post_hook(key) to remove this hook. + return reinterpret_cast(post_hooks_.back().get()); + } + + const std::vector>& post_hooks() + const noexcept { + return post_hooks_; + } + + // delete a post hook matching the key + bool del_post_hook(const uintptr_t& key) { + for (auto it = post_hooks_.begin(); it != post_hooks_.end(); ++it) { + if (key == reinterpret_cast(it->get())) { + post_hooks_.erase(it); + return true; + } + } + return false; + } + + std::vector>& post_hooks() noexcept { + return post_hooks_; + } + + void add_pre_hook(std::unique_ptr&& pre_hook) { + pre_hooks_.emplace_back(std::move(pre_hook)); + } + + void add_tensor_pre_hook(std::unique_ptr&& pre_hook) { + tensor_pre_hooks_.emplace_back(std::move(pre_hook)); + } + + void add_retains_grad_hook( + std::unique_ptr&& pre_hook, + size_t output_idx) { + retains_grad_hooks_[output_idx] = std::move(pre_hook); + } + + std::unique_ptr pop_retains_grad_hook(size_t output_idx) { + auto ret = std::move(retains_grad_hooks_[output_idx]); + retains_grad_hooks_.erase(output_idx); + return ret; + } + + const std::vector>& pre_hooks() + const noexcept { + return pre_hooks_; + } + + std::vector>& pre_hooks() noexcept { + return pre_hooks_; + } + + virtual std::vector>& + tensor_pre_hooks() noexcept { + return tensor_pre_hooks_; + } + + virtual std::unique_ptr& tensor_post_acc_grad_hooks() + const noexcept { + static std::unique_ptr empty = nullptr; + return empty; + } + + std::unordered_map>& + retains_grad_hooks() noexcept { + return retains_grad_hooks_; + } + + // Customization Points for Subclasses + //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + /// Releases saved variables if the operation won't be reused. + virtual void release_variables() {} + + /// Called before an apply if `release_variables()` is going to be called. + /// Allows larger ops like `InterpreterAutogradFunction` to incrementally + /// release variables as they run. + virtual void will_release_variables() {} + + /// Returns true if this function is traceable. An op is traceable if all + /// operations happening within `apply()` are performed on autograd + /// `Variables` (i.e. apply mostly instantiates and applies other functions). + virtual bool is_traceable() { + return false; + } + + /// A `Node` is said to pass state transparently to backward, if the + /// state consists only of (Saved)Variables and only non-variable objects + /// that parameterize the operation in some way that defines the graph + /// structure AND the backward function is traceable. In particular, + /// parametrization MUST NOT depend on the data of any `Variable`. + /// TODO: it might be possible to handle cases where backward is + /// non-traceable but state passing could be considered transparent. This + /// will probably depend on saved_variable_list being mutable. + /// NOTE: this value matters only if is_traceable() returns false. + virtual bool passes_state_transparently() { + return false; + } + + // see [Note: Compiled Autograd] + // Used by compiled autograd to + // 1) Extract tensors/symint args + // 2) Collect node information for specialization and caching + // Implementations in subclasses should call args.collect() with all node + // attrs. These functions are only called during backward. + virtual void compiled_args(CompiledNodeArgs& args) const { + TORCH_CHECK_NOT_IMPLEMENTED( + false, std::string("compiled_args not implemented: ") + name()); + } + + // Used by compiled autograd to call apply() with different saved tensors + // Implementations should call saved.before() on all attrs, then apply(), then + // saved.after() on all attrs in the same order. + virtual variable_list apply_with_saved( + const variable_list& inputs, + SwapSavedVariables& saved) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, std::string("apply_with_saved not implemented: ") + name()); + } + + // If this node is the AOTBackward node produced by torch.compile. + // Compiled Autograd special-cases on this information. + virtual bool is_aot_backward() const { + return false; + } + + protected: + /// Performs the `Node`'s actual operation. + virtual variable_list apply(variable_list&& inputs) = 0; + + /// Calls `apply()`, but instruments it with tracing machinery. + variable_list traced_apply(variable_list inputs); + + // Sequence number used to correlate backward nodes with forward ops in the + // profiler and provide determinism in the engine. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + uint64_t sequence_nr_; + + // See NOTE [ Topological Number ] + uint64_t topological_nr_ = 0; + + // Tracks whether this node has been added as the next_edge of another node + // via set_next_edge(s), which always calls topological_nr() of all its + // children See NOTE [ Topological Number ] for why we need this. + mutable bool has_parent_ = false; + + // Id of the thread that created the instance + uint64_t thread_id_ = 0; + + // Note [Thread Safety on Autograd Node] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Autograd Engine let the owning thread which calls Engine::execute to drive + // the GraphTask execution, there might be cases that part of the GraphTask is + // shared across different `backward()` or `grad()` calls, i.e. fork new + // threads in the middle of the forward and call `backward()` separately from + // different threads. We need to protect the thread safety on NodeTask to + // prevent data racing on shared variables read/write. + // + // NB: This is only needed for Autograd Nodes that runs on CPU, technically + // "CUDA", "XLA" nodes don't need locking because device threads are always + // single threaded. + // + // Here we add a thread mutex to help protect the Node's thread safety, so + // that different threads cannot race the shared data when executing the same + // NodeTask from multiple CPU threads. It IS the user/developer responsibility + // to take advantage of this mutex to protect the thread safety of their + // autograd Node. The general strategy of thread safety on autograd Node: + // + // 1. User should lock the mutex during Node::release_variables() if the Node + // needs + // to release the variables on the fly, this serve the purpose that when we + // release saved_variables from one thread, no other threads can release + // the saved variables concurrently. call the Node::apply(), + // 2. User should lock the mutex during Node::apply(), this is to ensure Node + // that + // writing to the shared variable are not racing across threads (i.e. + // AccumulateGrad and custom C++ Autograd Node if writing to shared + // variables ) + // 3. item 2 and item 3 should work together so that when we release saved + // variables + // from one thread, no other threads can call Node::apply(), this ensures + // the variable references from other threads aren't dangling. + // 4. if the Node don't release any variables and no shared data read/write in + // the Node + // i.e. purely functional, user don't need to lock the mutex + // + // This way we could protect the thread safety on Autograd Node, but we could + // still not protect the thread safety on Node pre/post C++ hooks (python + // hooks are automatically thread safe), we rely on the user to write thread + // safe C++ hooks if they want the hook to be correctly applied in + // multithreading environment. + std::mutex mutex_; + + edge_list next_edges_; + PyObject* pyobj_ = nullptr; // weak reference + std::unique_ptr anomaly_metadata_ = nullptr; + + // NOTE [Hooks ordering] + // We have 3 separate fields for pre hooks registered to the autograd nodes + // because the conditions under which they execute are different, and we + // want more fine-grained control over the order in which different types + // of hooks are executed. + // - pre_hooks are only executed when the node itself is executed + // - tensor_pre_hook is executed as long as the engine traverses over it + // even if that node won't be executed. + // - retains_grad_hook are like tensor_pre_hooks except they are always + // ordered after all other tensor pre hooks + std::vector> pre_hooks_; + std::vector> tensor_pre_hooks_; + std::unordered_map> + retains_grad_hooks_; + std::vector> post_hooks_; + at::SmallVector input_metadata_; +}; + +/// See Node::is_traceable() for definition. +struct TraceableFunction : public Node { + using Node::Node; + bool is_traceable() final { + return true; + } +}; + +} // namespace torch::autograd diff --git a/torch/csrc/autograd/profiler_kineto.cpp b/torch/csrc/autograd/profiler_kineto.cpp index 76a07b5d85ad8..2c952a9005c21 100644 --- a/torch/csrc/autograd/profiler_kineto.cpp +++ b/torch/csrc/autograd/profiler_kineto.cpp @@ -269,10 +269,10 @@ struct AddGenericMetadata : public MetadataBase { if (arg_data.hasData) { if (get_record_concrete_inputs_enabled()) { addMetadata("Input Dims", variantShapesToStr(arg_data.shapes)); - addMetadata("Input Strides", variantShapesToStr(arg_data.strides)); } else { addMetadata("Input Dims", shapesToStr(arg_data.shapesForKinetoEvent)); } + addMetadata("Input Strides", variantShapesToStr(arg_data.strides)); addMetadata("Input type", strListToStr(arg_data.dtypes)); if (!arg_data.concreteInputs.empty()) { addMetadata( @@ -463,16 +463,7 @@ struct KinetoThreadLocalState : public ProfilerStateBase { auto records_and_trace = recordQueue.getRecords(std::move(converter), startTime, end_time); - materializeOpEvents(records_and_trace.first); - - // `kinetoEvents` does not include Python events. Instead it exposes them - // via the `stacks` property. - kinetoEvents.erase( - std::remove_if( - kinetoEvents.begin(), - kinetoEvents.end(), - [](const auto& i) { return i.isPythonFunction(); }), - kinetoEvents.end()); + materializeOpEvents(records_and_trace.first, end_time); return std::move(records_and_trace.second); } @@ -484,25 +475,34 @@ struct KinetoThreadLocalState : public ProfilerStateBase { } } - void materializeOpEvents(std::vector>& events) { + void materializeOpEvents( + std::vector>& events, + int64_t trace_end_ns) { for (auto& e : events) { if (e->parent_.expired() && e->deviceType() == c10::DeviceType::CPU) { eventTree.push_back(e); } - if (e->finished_) { + // Unfinished events automatically have end time set to trace end time + if (!e->finished_) { e->visit(c10::overloaded( - [this](ExtraFields& i) { invokeCallback(i); }, - [this](ExtraFields& i) { invokeCallback(i); }, + [trace_end_ns](ExtraFields& i) { + i.end_time_ns_ = trace_end_ns; + }, [](auto&) {})); + } - kinetoEvents.emplace_back(e, config_.experimental_config.verbose); - AddTensorboardFields add_tb(e, kinetoEvents.back()); - AddGenericMetadata add_generic(e, &config_); + e->visit(c10::overloaded( + [this](ExtraFields& i) { invokeCallback(i); }, + [this](ExtraFields& i) { invokeCallback(i); }, + [](auto&) {})); - // It is not safe to use the activity after post processing. - e->kineto_activity_ = nullptr; - } + kinetoEvents.emplace_back(e, config_.experimental_config.verbose); + AddTensorboardFields add_tb(e, kinetoEvents.back()); + AddGenericMetadata add_generic(e, &config_); + + // It is not safe to use the activity after post processing. + e->kineto_activity_ = nullptr; } } @@ -876,10 +876,16 @@ std::unique_ptr disableProfiler() { profiler_state_info_ptr = nullptr; auto state_ptr = ProfilerStateBase::pop(); + if (!state_ptr) { + LOG(WARNING) + << "disableProfiler called but no active profiling session found. " + << "This can happen if profiling was cancelled during warmup."; + return std::make_unique(); + } const auto& config = state_ptr->config(); TORCH_CHECK( - state_ptr && isValidDisableState(config.state), - "Can't disable Kineto profiler when it's not running"); + isValidDisableState(config.state), + "Can't disable Kineto profiler: config is not in a valid disable state"); state_ptr->removeCallback(); @@ -947,6 +953,8 @@ KinetoEvent::KinetoEvent( result->visit_if_base>([&](const auto& op) { auto arg_data = parseArgData(op.inputs_, op.concrete_inputs_); shapes_ = std::move(arg_data.shapesForKinetoEvent); + structured_input_shapes_ = std::move(arg_data.shapes); + structured_input_strides_ = std::move(arg_data.strides); dtypes_ = std::move(arg_data.dtypes); concrete_inputs_ = std::move(arg_data.concreteInputs); kwinputs_ = std::move(op.kwinputs_); @@ -959,6 +967,45 @@ bool KinetoEvent::isPythonFunction() const { return out; } +int64_t KinetoEvent::pythonId() const { + int64_t out{-1}; + result_->visit_if_base( + [&](const auto& i) { out = static_cast(i.id_); }); + return out; +} + +int64_t KinetoEvent::pythonParentId() const { + int64_t out{-1}; + // Walk the python parent pointers up to find the next event of type + // PyExtraFieldsBase + result_->visit_if_base([&](const auto&) { + auto parent = result_->parent_.lock(); + while (parent) { + parent->visit_if_base( + [&](const auto& j) { out = static_cast(j.id_); }); + if (out >= 0) { + break; + } + parent = parent->parent_.lock(); + } + }); + return out; +} + +int64_t KinetoEvent::pythonModuleId() const { + int64_t out{-1}; + // Returns the module id for PyCall events (python function calls to + // nn.Module) + result_->visit(c10::overloaded( + [&](const ExtraFields& py_call) { + if (py_call.module_.has_value()) { + out = static_cast(py_call.module_->id_); + } + }, + [](const auto&) {})); + return out; +} + bool KinetoEvent::hasShapes() const { return !shapes_.empty(); } @@ -967,6 +1014,16 @@ const c10::ArrayRef> KinetoEvent::shapes() const { return shapes_; } +const c10::ArrayRef KinetoEvent:: + structuredInputShapes() const { + return structured_input_shapes_; +} + +const c10::ArrayRef KinetoEvent:: + structuredInputStrides() const { + return structured_input_strides_; +} + bool KinetoEvent::hasTypes() const { return !dtypes_.empty(); } @@ -1078,7 +1135,6 @@ int64_t KinetoEvent::privateuse1ElapsedUs() const { } return (int64_t)torch::profiler::impl::privateuse1Stubs()->elapsed( &privateuse1_event_start, &privateuse1_event_end); - return -1; } void KinetoEvent::getPerfEventCounters(std::vector& in) const { @@ -1107,6 +1163,40 @@ std::string KinetoEvent::metadataJson() const { [](const auto&) -> std::string { return std::string(""); })); } +int64_t KinetoEvent::externalId() const { + // Mirrors libkineto::ChromeTraceLogger::handleActivity() "External id" logic. + // libkineto::ChromeTraceLogger checks op.linkedActivity() != nullptr; here we + // check linkedCorrelationId() > 0, which is equivalent because PyTorch + // correlation IDs are monotonically increasing from 1 (a valid linked + // activity always has a non-zero correlation ID). + uint64_t linked = linkedCorrelationId(); + if (linked > 0) { + return static_cast(linked); + } + + // Orphaned GPU activities (no linked CPU op) in these types should not get + // an External id, to avoid incorrect cross-linking in trace viewers. + auto type = static_cast(activityType()); + if (type != libkineto::ActivityType::GPU_MEMCPY && + type != libkineto::ActivityType::GPU_MEMSET && + type != libkineto::ActivityType::CONCURRENT_KERNEL && + type != libkineto::ActivityType::CUDA_RUNTIME && + type != libkineto::ActivityType::CUDA_DRIVER && + type != libkineto::ActivityType::PRIVATEUSE1_RUNTIME && + type != libkineto::ActivityType::PRIVATEUSE1_DRIVER) { + return static_cast(result_->visit(c10::overloaded( + [](const ExtraFields& e) -> uint64_t { + return e.correlation_id_; + }, + [](const ExtraFields& e) -> uint64_t { + return e.correlation_id_; + }, + [](const auto&) -> uint64_t { return 0; }))); + } + + return 0; +} + #define FORWARD_FROM_RESULT(method_name, result_expr) \ decltype(std::declval().method_name()) \ KinetoEvent::method_name() const { \ @@ -1148,7 +1238,16 @@ TYPED_ATTR(TorchOp, fwdThreadId, e.sequence_number_ >= 0 ? e.forward_tid_ : 0) TYPED_ATTR(TorchOp, scope, static_cast(e.scope_)) TYPED_ATTR(TorchOp, hasModuleHierarchy, !e.jit_modules_.empty()) TYPED_ATTR(TorchOp, isAsync, e.is_async_) -TYPED_ATTR(TorchOp, extraMeta, e.extra_meta_) + +extra_meta_t KinetoEvent::extraMeta() const { + extra_meta_t out; + result_->visit(c10::overloaded( + [&](const ExtraFields& e) { out = e.extra_meta_; }, + [&](const ExtraFields& e) { out = e.extra_meta_; }, + [](const auto&) {})); + return out; +} + TYPED_ATTR(TorchOp, fallbackStart, e.device_fallback_.device_event_start_) TYPED_ATTR(TorchOp, fallbackEnd, e.device_fallback_.device_event_end_) TYPED_ATTR( @@ -1166,6 +1265,30 @@ TYPED_ATTR(Kineto, linkedCorrelationId, [&]() { #undef TYPED_ATTR #undef TYPED_ATTR_WITH_DEFAULT +// Flow fields exist on both TorchOp and Kineto event types. +uint32_t KinetoEvent::flowId() const { + return result_->visit(c10::overloaded( + [](const ExtraFields& e) { return e.flow.id; }, + [](const ExtraFields& e) { return e.flow.id; }, + [](const auto&) -> uint32_t { return 0; })); +} +uint32_t KinetoEvent::flowType() const { + return result_->visit(c10::overloaded( + [](const ExtraFields& e) { return e.flow.type; }, + [](const ExtraFields& e) { return e.flow.type; }, + [](const auto&) -> uint32_t { return 0; })); +} +bool KinetoEvent::flowStart() const { + return result_->visit(c10::overloaded( + [](const ExtraFields& e) { + return static_cast(e.flow.start); + }, + [](const ExtraFields& e) { + return static_cast(e.flow.start); + }, + [](const auto&) { return false; })); +} + ProfilerResult::ProfilerResult( uint64_t start_time, std::vector events, diff --git a/torch/csrc/autograd/profiler_kineto.h b/torch/csrc/autograd/profiler_kineto.h index ab0b792716eeb..4c7db2cae8aa7 100644 --- a/torch/csrc/autograd/profiler_kineto.h +++ b/torch/csrc/autograd/profiler_kineto.h @@ -58,6 +58,10 @@ struct TORCH_API KinetoEvent { bool isAsync() const; uint64_t correlationId() const; uint64_t linkedCorrelationId() const; + uint32_t flowId() const; + uint32_t flowType() const; + bool flowStart() const; + int64_t externalId() const; int64_t deviceResourceId() const; std::string backend() const; bool isPythonFunction() const; @@ -67,6 +71,14 @@ struct TORCH_API KinetoEvent { extra_meta_t extraMeta() const; std::string metadataJson() const; + const c10::ArrayRef structuredInputShapes() + const; + const c10::ArrayRef structuredInputStrides() + const; + int64_t pythonId() const; + int64_t pythonParentId() const; + int64_t pythonModuleId() const; + private: torch::profiler::impl::ProfilerVoidEventStub fallbackStart() const; torch::profiler::impl::ProfilerVoidEventStub fallbackEnd() const; @@ -79,6 +91,8 @@ struct TORCH_API KinetoEvent { std::vector dtypes_; std::vector concrete_inputs_; std::unordered_map kwinputs_; + std::vector structured_input_shapes_; + std::vector structured_input_strides_; }; // Consolidating events returned directly from Kineto diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index a45935ecb2995..9e1c2e8e18118 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -18,6 +19,8 @@ #include #include #include +#include +#include #include #include #include @@ -551,13 +554,13 @@ struct TraceKeyCacheState { // ============================================================================ // == Core CPython data types ================================================= // ============================================================================ -// PyObject that allows different threads to record events without colliding. -// It is passed as the second argument when enabling tracing via -// `PyEval_SetProfile`. -struct ThreadLocalResults; +// PyObject passed as the second argument when enabling tracing via +// `PyEval_SetProfile`. A single shared instance is used for all threads; +// the callback resolves per-thread state via PyThreadState_Get(). +class PythonTracer; struct TraceContext { PyObject_HEAD - ThreadLocalResults* thread_local_results_; + PythonTracer* tracer_; }; // CPython boilerplate to define `TraceContext` as a proper python object. @@ -604,43 +607,39 @@ static PyTypeObject TraceContextType = { nullptr /* tp_free */ }; -class gil_and_restore_thread { +#if IS_PYTHON_3_14_PLUS +extern "C" void _PyEval_StopTheWorld(PyInterpreterState*); +extern "C" void _PyEval_StartTheWorld(PyInterpreterState*); + +class StopTheWorldGuard { public: - gil_and_restore_thread() : initial_thread_state_{PyThreadState_Get()} {} - ~gil_and_restore_thread() { - PyThreadState_Swap(initial_thread_state_); - - // `gil_scoped_acquire` is a bit fragile in on-demand mode: - // https://github.com/pytorch/pytorch/pull/91684#issuecomment-1413154458 - if (!Py_IsInitialized()) { - gil_.disarm(); - } + explicit StopTheWorldGuard(PyInterpreterState* interp) : interp_(interp) { + _PyEval_StopTheWorld(interp_); } - - PyThreadState* initial_thread_state() const { - return initial_thread_state_; + ~StopTheWorldGuard() { + _PyEval_StartTheWorld(interp_); } + StopTheWorldGuard(const StopTheWorldGuard&) = delete; + StopTheWorldGuard& operator=(const StopTheWorldGuard&) = delete; private: - pybind11::gil_scoped_acquire gil_; - PyThreadState* initial_thread_state_; + PyInterpreterState* interp_; }; +#else +class StopTheWorldGuard { + public: + explicit StopTheWorldGuard(PyInterpreterState*) {} + StopTheWorldGuard(const StopTheWorldGuard&) = delete; + StopTheWorldGuard& operator=(const StopTheWorldGuard&) = delete; +}; +#endif // ============================================================================ // == Thread local cache ====================================================== // ============================================================================ -class PythonTracer; struct ThreadLocalResults { - ThreadLocalResults( - PyThreadState* thread_state, - ValueCache* value_cache, - PythonTracer* active_tracer) - : thread_state_{thread_state}, - ctx_{(TraceContext*)TraceContextType.tp_alloc(&TraceContextType, 0)}, - value_cache_{value_cache}, - active_tracer_{active_tracer} { - ctx_->thread_local_results_ = this; - } + ThreadLocalResults(PythonTracer* active_tracer) + : active_tracer_{active_tracer} {} ThreadLocalResults() = delete; ThreadLocalResults(const ThreadLocalResults&) = delete; @@ -648,30 +647,18 @@ struct ThreadLocalResults { ThreadLocalResults& operator=(const ThreadLocalResults&) = delete; ThreadLocalResults& operator=(const ThreadLocalResults&&) = delete; - ~ThreadLocalResults() { - // Currently, there is a bug in Profiler when using Python 3.12 that causes - // a segfault when decrementing the refcount of a TraceContext during - // on-demand. We are purposefully allowing for a small leak in this - // situation to avoid the segfault. This should be fixed in the future. -#if PY_MAJOR_VERSION < 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 12) - Py_DECREF((PyObject*)ctx_); -#endif - } - template TraceKey intern(Ephemeral ephemeral, Args... args) { static_assert( Config::event_type == E, "ThreadLocalResults.intern called from the wrong typed context."); auto callsite = Callsite(std::forward(args)...); - return std::get(trace_keys_).intern(callsite, ephemeral, *value_cache_); + return std::get(trace_keys_).intern(callsite, ephemeral, value_cache_); } static constexpr size_t BLOCK_SIZE = 1024; - PyThreadState* thread_state_; - TraceContext* ctx_; - ValueCache* value_cache_; + ValueCache value_cache_; PythonTracer* active_tracer_; CallTypeHelper::tuple_type trace_keys_; AppendOnlyList exit_times_; @@ -679,6 +666,12 @@ struct ThreadLocalResults { int active_frames_{0}; int remaining_start_frames_{0}; + + // Guards against teardown racing with in-flight callbacks. + // pyProfileFn acquires this on entry and releases on exit. + // PythonTracer::stop() acquires each thread's semaphore after + // clearing the profiling callback to ensure all callbacks have finished. + c10::Semaphore profile_sem{1}; }; // ============================================================================ @@ -733,7 +726,10 @@ class PythonTracer final : public python_tracer::PythonTracerBase { PyObject* arg, bool start_frame = false); + ThreadLocalResults* findThreadLocalResults(PyThreadState* tstate) const; + const std::vector interpreterThreads() const; + void setprofileAllThreads(Py_tracefunc func, PyObject* arg) const; std::atomic active_lock_{false}; bool active_{false}; @@ -743,10 +739,12 @@ class PythonTracer final : public python_tracer::PythonTracerBase { PyInterpreterState* interpreter_{nullptr}; PyCodeObject* module_call_code_; PyCodeObject* optimizer_hook_; + TraceContext* shared_ctx_{nullptr}; std::vector start_frames_; std::deque thread_local_results_; - ValueCache value_cache_; + std::unordered_map + thread_local_results_map_; #if IS_PYTHON_3_12 friend PyObject* c_call_callback( @@ -833,12 +831,14 @@ static PyObject* c_call_callback( PyExc_SystemError, "Missing frame when calling profile function."); return NULL; } - Py_INCREF(frame); - auto& local_results = - *reinterpret_cast(tstate->c_profileobj) - ->thread_local_results_; - local_results.active_tracer_->recordCCall(local_results, frame, func); - Py_DECREF(frame); + auto* tracer = + reinterpret_cast(tstate->c_profileobj)->tracer_; + auto* local_results = tracer->findThreadLocalResults(tstate); + if (local_results) { + Py_INCREF(frame); + local_results->active_tracer_->recordCCall(*local_results, frame, func); + Py_DECREF(frame); + } } } Py_RETURN_NONE; @@ -965,6 +965,12 @@ static void unregisterMonitoringCallback() { } #endif +ThreadLocalResults* PythonTracer::findThreadLocalResults( + PyThreadState* tstate) const { + auto it = thread_local_results_map_.find(tstate); + return it != thread_local_results_map_.end() ? it->second : nullptr; +} + const std::vector PythonTracer::interpreterThreads() const { pybind11::gil_scoped_acquire gil; std::vector out; @@ -978,6 +984,19 @@ const std::vector PythonTracer::interpreterThreads() const { return out; } +void PythonTracer::setprofileAllThreads(Py_tracefunc func, PyObject* arg) + const { +#if IS_PYTHON_3_13_PLUS + PyEval_SetProfileAllThreads(func, arg); +#else + for (const auto thread_state : interpreterThreads()) { + if (_PyEval_SetProfile(thread_state, func, arg) < 0) { + PyErr_WriteUnraisable(nullptr); + } + } +#endif +} + // we are only registering on main thread while holding GIL so this should be // safe static PyObject* py_gc_callback = nullptr; @@ -1015,56 +1034,65 @@ PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue) return; } - gil_and_restore_thread gil; - interpreter_ = PyInterpreterState_Get(); - - if (!gil.initial_thread_state()) { - TORCH_WARN("PyThreadState_Get returned NULL"); - return; - } - - // Register the tracer in each thread. - for (const auto thread_state : interpreterThreads()) { - PyThreadState_Swap(thread_state); - - thread_local_results_.emplace_back(thread_state, &value_cache_, this); - auto& tls = thread_local_results_.back(); - auto* ctx = tls.ctx_; +#if defined(Py_GIL_DISABLED) && !defined(IS_PYTHON_3_14_PLUS) + TORCH_WARN( + "The PyTorch profiler is not thread-safe on Python 3.13t. " + "Please use Python 3.14t or later."); +#endif - // When we begin profiling there are already frames on the Python - // interpreter stack. To ensure a complete trace, we must push calls - // to all the prior frames onto our event stack. (We stop at depth=128) + pybind11::gil_scoped_acquire gil; + interpreter_ = PyInterpreterState_Get(); - std::vector current_stack; - auto frame = PyEval_GetFrame(); - Py_XINCREF(frame); + // Shared context passed as the profile arg to all threads. + shared_ctx_ = (TraceContext*)TraceContextType.tp_alloc(&TraceContextType, 0); + shared_ctx_->tracer_ = this; + + // Enable profiling on all threads. setprofileAllThreads handles its own + // synchronization (stop-the-world on free-threaded builds). The callback + // returns early because findThreadLocalResults returns nullptr for threads + // we haven't set up yet. + // Note: This profile will not compose with other CPython profilers, and + // cannot be round tripped via `sys.settrace(sys.gettrace())` + setprofileAllThreads(PythonTracer::pyProfileFn, (PyObject*)shared_ctx_); + + // Capture existing frames on each thread's stack. + { + StopTheWorldGuard stw(interpreter_); + for (const auto thread_state : interpreterThreads()) { + thread_local_results_.emplace_back(this); + auto& tls = thread_local_results_.back(); + thread_local_results_map_[thread_state] = &tls; + + // When we begin profiling there are already frames on the Python + // interpreter stack. To ensure a complete trace, we must push calls + // to all the prior frames onto our event stack. (We stop at depth=128) + + // NB: `PyThreadState_GetFrame` returns a strong reference. + std::vector current_stack; + auto frame = PyThreadState_GetFrame(thread_state); + + size_t depth = 0; // Make sure we can't infinite loop. + while (frame != nullptr) { + current_stack.emplace_back(frame); + if (++depth == 128) { + break; + } - size_t depth = 0; // Make sure we can't infinite loop. - while (frame != nullptr) { - current_stack.emplace_back(frame); - if (++depth == 128) { - break; + // NB: `PyFrame_GetBack` returns a strong reference. + frame = PyFrame_GetBack(frame); } - // NB: `PyFrame_GetBack` returns a strong reference. - frame = PyFrame_GetBack(frame); - } + for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) { + recordPyCall(tls, it->get(), true); + auto frame_refcount = Py_REFCNT(it->get()); - for (auto it = current_stack.rbegin(); it != current_stack.rend(); it++) { - recordPyCall(tls, it->get(), true); - auto frame_refcount = Py_REFCNT(it->get()); + // We hold one reference in `current_stack`, and the interpreter holds + // another. + TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount); + } - // We hold one reference in `current_stack`, and the interpreter holds - // another. - TORCH_INTERNAL_ASSERT(frame_refcount >= 2, frame_refcount); + tls.remaining_start_frames_ = tls.active_frames_; } - - tls.remaining_start_frames_ = tls.active_frames_; - - // Note: - // This profile will not compose with other CPython profilers, and - // cannot be round tripped via `sys.settrace(sys.gettrace())` - PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx); } #if IS_PYTHON_3_12 registerMonitoringCallback(); @@ -1092,6 +1120,7 @@ void unregister_gc_callback() { PySequence_DelItem(callbacks, idx); } else { // Not found, maybe already removed + PyErr_Clear(); } Py_DECREF(callbacks); Py_DECREF(gc_module); @@ -1134,16 +1163,23 @@ void PythonTracer::register_gc_callback() { } void PythonTracer::stop() { - gil_and_restore_thread gil; + pybind11::gil_scoped_acquire gil; if (gc_callback_registered_) { unregister_gc_callback(); gc_callback_registered_ = false; } if (active_) { - for (const auto thread_state : interpreterThreads()) { - if (thread_state->c_profilefunc == &PythonTracer::pyProfileFn) { - PyThreadState_Swap(thread_state); - PyEval_SetProfile(nullptr, nullptr); + setprofileAllThreads(nullptr, nullptr); + + // Wait for any in-flight pyProfileFn callbacks to finish. Threads inside + // pyProfileFn hold their thread's profile_sem. They may have temporarily + // released the GIL or parked mid-callback due to a stop-the-world event. + // Acquiring each semaphore here blocks until those callbacks complete. + { + pybind11::gil_scoped_release release; + for (auto& tls : thread_local_results_) { + tls.profile_sem.acquire(); + tls.profile_sem.release(); } } @@ -1158,7 +1194,7 @@ void PythonTracer::stop() { } void PythonTracer::restart() { - gil_and_restore_thread gil; + pybind11::gil_scoped_acquire gil; active_ = active_lock_.compare_exchange_strong(active_, true); if (!active_) { TORCH_WARN( @@ -1166,14 +1202,7 @@ void PythonTracer::restart() { "Refusing to register profile functions."); return; } - int cur_thread = 0; - for (const auto thread_state : interpreterThreads()) { - if (thread_state->c_profilefunc == nullptr) { - auto* ctx = thread_local_results_[cur_thread].ctx_; - PyThreadState_Swap(thread_state); - PyEval_SetProfile(PythonTracer::pyProfileFn, (PyObject*)ctx); - } - } + setprofileAllThreads(PythonTracer::pyProfileFn, (PyObject*)shared_ctx_); #if IS_PYTHON_3_12 registerMonitoringCallback(); #endif @@ -1185,6 +1214,10 @@ PythonTracer::~PythonTracer() { TORCH_WARN("`PythonTracer::stop()` was not called."); stop(); } + if (Py_IsInitialized() && !Py_IsFinalizing()) { + pybind11::gil_scoped_acquire gil; + Py_XDECREF((PyObject*)shared_ctx_); + } } void PythonTracer::recordPyCall( @@ -1207,13 +1240,13 @@ void PythonTracer::recordPyCall( auto locals = THPObjectPtr(PyFrame_GetLocals(frame)); #if PY_MAJOR_VERSION < 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 13) - auto self = THPObjectPtr(PyDict_GetItemString(locals, "self")); + auto self = + THPObjectPtr(Py_XNewRef(PyDict_GetItemString(locals, "self"))); #else // In Python-3.13+ `PyFrame_GetLocals()` returns instance of // PyFrameLocalsProxy_Type See PEP 667 for more info auto self = THPObjectPtr(PyMapping_GetItemString(locals, "self")); #endif - Py_INCREF(self.get()); auto back = THPFrameObjectPtr(PyFrame_GetBack(frame)); TORCH_INTERNAL_ASSERT(back != nullptr); return tls.intern( @@ -1221,11 +1254,11 @@ void PythonTracer::recordPyCall( } else if (code.get() == optimizer_hook_) { auto locals = THPObjectPtr(PyFrame_GetLocals(frame)); #if PY_MAJOR_VERSION < 3 || (PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 13) - auto self = THPObjectPtr(PyDict_GetItemString(locals, "self")); + auto self = + THPObjectPtr(Py_XNewRef(PyDict_GetItemString(locals, "self"))); #else auto self = THPObjectPtr(PyMapping_GetItemString(locals, "self")); #endif - Py_INCREF(self.get()); auto back = THPFrameObjectPtr(PyFrame_GetBack(frame)); TORCH_INTERNAL_ASSERT(back != nullptr); return tls.intern( @@ -1279,12 +1312,14 @@ class PostProcess { PostProcess( std::function time_converter, std::deque& tls, - const ValueCache& value_cache, c10::time_t end_time_ns) : end_time_{end_time_ns}, time_converter_{std::move(time_converter)} { for (size_t python_tid : c10::irange(tls.size())) { CallTypeHelper::map( - tls[python_tid].trace_keys_, *this, value_cache, python_tid); + tls[python_tid].trace_keys_, + *this, + tls[python_tid].value_cache_, + python_tid); addExits(tls[python_tid].exit_times_, python_tid); addExits(tls[python_tid].c_exit_times_, python_tid); @@ -1445,12 +1480,11 @@ std::vector> PythonTracer::getEvents( std::function time_converter, std::vector& enters, c10::time_t end_time_ns) { - value_cache_.trimPrefixes(); + for (auto& tls : thread_local_results_) { + tls.value_cache_.trimPrefixes(); + } PostProcess post_process( - std::move(time_converter), - thread_local_results_, - value_cache_, - end_time_ns); + std::move(time_converter), thread_local_results_, end_time_ns); post_process.set_start_frames(start_frames_, enters); auto out = post_process.run(enters); @@ -1493,17 +1527,15 @@ static void toggle_memory_tracing(bool enable) { } // Call the function with arguments PyObject* args = PyTuple_New(6); - PyTuple_SetItem(args, 0, enable ? PyUnicode_FromString("all") : Py_None); + PyTuple_SetItem( + args, 0, enable ? PyUnicode_FromString("all") : Py_NewRef(Py_None)); PyTuple_SetItem(args, 1, PyUnicode_FromString("all")); // context PyTuple_SetItem(args, 2, PyUnicode_FromString("all")); // stacks PyTuple_SetItem(args, 3, THPUtils_packInt64(100000)); // max_entries - PyTuple_SetItem(args, 4, Py_None); // device (None) + PyTuple_SetItem(args, 4, Py_NewRef(Py_None)); // device (None) PyTuple_SetItem(args, 5, PyBool_FromLong(0)); // clear_history (False) - PyObject* result = PyObject_Call(snapshot_func.get(), args, nullptr); + THPObjectPtr result(PyObject_Call(snapshot_func.get(), args, nullptr)); Py_DECREF(args); - if (result == nullptr) { - return; - } } void PythonMemoryTracer::start() { @@ -1522,14 +1554,9 @@ void PythonMemoryTracer::export_memory_history(const std::string& path) { if (!snapshot_func) { return; } - PyObject* py_filename = PyUnicode_FromString(path.c_str()); - // Call the function with arguments (e.g., a file path) - PyObject* args = PyTuple_Pack(1, py_filename); - PyObject* result = PyObject_Call(snapshot_func.get(), args, nullptr); - Py_DECREF(args); - if (result == nullptr) { - return; - } + THPObjectPtr py_filename(PyUnicode_FromString(path.c_str())); + THPObjectPtr result( + PyObject_CallOneArg(snapshot_func.get(), py_filename.get())); } void PythonMemoryTracer::stop() { @@ -1544,36 +1571,47 @@ int PythonTracer::pyProfileFn( PyFrameObject* frame, int what, PyObject* arg) { - auto& local_results = - *reinterpret_cast(obj)->thread_local_results_; + HANDLE_TH_ERRORS + auto* tracer = reinterpret_cast(obj)->tracer_; + auto* local_results = tracer->findThreadLocalResults(PyThreadState_Get()); + if (C10_UNLIKELY(!local_results)) { + return 0; + } + bool acquired = local_results->profile_sem.tryAcquire(); + TORCH_INTERNAL_ASSERT(acquired, "pyProfileFn: profile_sem unexpectedly held"); + // RAII release: ensures the semaphore is released on both normal + // return and C++ exception paths (e.g. from pybind11 in ValueCache::store). + auto release_sem = + c10::make_scope_exit([&]() { local_results->profile_sem.release(); }); switch (what) { case PyTrace_CALL: - local_results.active_tracer_->recordPyCall(local_results, frame, false); + local_results->active_tracer_->recordPyCall(*local_results, frame, false); break; case PyTrace_C_CALL: - local_results.active_tracer_->recordCCall(local_results, frame, arg); + local_results->active_tracer_->recordCCall(*local_results, frame, arg); break; case PyTrace_RETURN: - local_results.exit_times_.emplace_back(c10::getApproximateTime()); - local_results.active_frames_--; - if (local_results.active_frames_ < - local_results.remaining_start_frames_) { - local_results.remaining_start_frames_ = local_results.active_frames_; + local_results->exit_times_.emplace_back(c10::getApproximateTime()); + local_results->active_frames_--; + if (local_results->active_frames_ < + local_results->remaining_start_frames_) { + local_results->remaining_start_frames_ = local_results->active_frames_; } break; case PyTrace_C_EXCEPTION: case PyTrace_C_RETURN: - if (local_results.active_frames_ > - local_results.remaining_start_frames_) { - local_results.c_exit_times_.emplace_back(c10::getApproximateTime()); - local_results.active_frames_--; + if (local_results->active_frames_ > + local_results->remaining_start_frames_) { + local_results->c_exit_times_.emplace_back(c10::getApproximateTime()); + local_results->active_frames_--; } break; } return 0; + END_HANDLE_TH_ERRORS_RET(-1) } std::unique_ptr getTracer( diff --git a/torch/csrc/autograd/python_cpp_function.cpp b/torch/csrc/autograd/python_cpp_function.cpp index 6787df7080ad8..3dc9647251395 100644 --- a/torch/csrc/autograd/python_cpp_function.cpp +++ b/torch/csrc/autograd/python_cpp_function.cpp @@ -44,7 +44,7 @@ PyObject* THPCppFunction_call( variable_list vars(num_inputs); for (int i = 0; i != num_inputs; ++i) { PyObject* arg = PyTuple_GET_ITEM(args, i); - if (arg == Py_None) { + if (Py_IsNone(arg)) { continue; } if (!THPVariable_Check(arg)) { @@ -326,16 +326,11 @@ void registerCppFunction(const std::type_info& type, PyTypeObject* pytype) { } bool THPCppFunction_Check(PyObject* obj) { - THPObjectPtr type = THPObjectPtr(PyObject_Type(obj)); - if ((PyTypeObject*)type.get() == get_default_type()) { - return true; - } - if (cpp_function_types_set.find((PyTypeObject*)type.get()) == - cpp_function_types_set.end()) { - return false; - } else { + PyTypeObject* type = Py_TYPE(obj); + if (type == get_default_type()) { return true; } + return cpp_function_types_set.contains(type); } static PyObject* callRegisterFn(PyObject* dict, PyObject* hook) { @@ -364,7 +359,7 @@ PyObject* registerFunctionHook(Node& fn, PyObject* hook) { if (!res) { return nullptr; } - if (dict == Py_None) { + if (Py_IsNone(dict)) { dict = PyTuple_GET_ITEM(res.get(), 0); fn.add_post_hook(std::make_unique(dict)); } @@ -387,7 +382,7 @@ PyObject* registerFunctionPreHook(Node& fn, PyObject* hook) { if (!res) { return nullptr; } - if (dict == Py_None) { + if (Py_IsNone(dict)) { dict = PyTuple_GET_ITEM(res.get(), 0); fn.add_pre_hook(std::make_unique(dict)); } diff --git a/torch/csrc/autograd/python_engine.cpp b/torch/csrc/autograd/python_engine.cpp index 8a52306e91830..13e29c534d15f 100644 --- a/torch/csrc/autograd/python_engine.cpp +++ b/torch/csrc/autograd/python_engine.cpp @@ -285,7 +285,7 @@ static PyObject* THPEngine_run_backward( grads.push_back(grad_var); } else { TORCH_CHECK( - grad == Py_None, + Py_IsNone(grad), "element ", i, " of gradients tuple is not a Tensor or None"); diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 2b97fd593cfe4..828abc3d8dcf6 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -1,5 +1,7 @@ #include +#include + #include #include #include @@ -160,11 +162,42 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list { // Massage a C++ variable_list into a Python arguments tuple THPObjectPtr pyInputs(to_py_args(inputs, &_device_guard)); + inputs.clear(); + + THPObjectPtr r; + if (py_fn->boxed_grads_call) { + // Move grad tensors from the immutable args tuple into a plain list + // and call apply_boxed instead of apply. This lets backward pop/clear + // individual grads to free memory mid-execution, because the mutable + // list (not the C++ tuple) is the only container holding grad refs. + auto num_inputs = PyTuple_GET_SIZE(pyInputs.get()); + THPObjectPtr gradsList(PyList_New(num_inputs)); + if (!gradsList) + throw_python_error(); + for (Py_ssize_t i = 0; i < num_inputs; i++) { + PyObject* item = PyTuple_GET_ITEM(pyInputs.get(), i); + Py_INCREF(item); + PyList_SET_ITEM(gradsList.get(), i, item); + } + // Release the tuple so its refs to individual grads are dropped + pyInputs = nullptr; - THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply")); - if (!apply_fn) - throw_python_error(); - THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get())); + THPObjectPtr boxedArgs(PyTuple_New(1)); + if (!boxedArgs) + throw_python_error(); + PyTuple_SET_ITEM(boxedArgs.get(), 0, gradsList.release()); + + THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply_boxed")); + if (!apply_fn) + throw_python_error(); + r = THPObjectPtr(PyObject_CallObject(apply_fn, boxedArgs.get())); + } else { + THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply")); + if (!apply_fn) + throw_python_error(); + r = THPObjectPtr(PyObject_CallObject(apply_fn, pyInputs.get())); + } + pyInputs = nullptr; if (!r) throw_python_error(); ensure_tuple(r); @@ -177,7 +210,7 @@ auto PyNode::apply(variable_list&& inputs) -> variable_list { if (num_outputs > num_forward_inputs) { bool all_none = true; for (const auto i : c10::irange(num_forward_inputs, num_outputs)) { - all_none &= PyTuple_GET_ITEM(r.get(), i) == Py_None; + all_none &= Py_IsNone(PyTuple_GET_ITEM(r.get(), i)); } if (all_none) { num_outputs = num_forward_inputs; @@ -304,7 +337,7 @@ auto PyNode::is_traceable() -> bool { PyObject_GetAttrString(forward_class, "is_traceable")}; if (!traceable_py_bool) throw_python_error(); - return traceable_py_bool == Py_True; + return Py_IsTrue(traceable_py_bool); } auto PyNode::release_variables() -> void { @@ -474,7 +507,7 @@ variable_list PyNode::to_variable_list( bool was_variable = is_variable_input[i]; if (!was_variable) { TORCH_CHECK( - output == Py_None, + Py_IsNone(output), "function ", name(), " returned a gradient different than None at position ", @@ -482,7 +515,7 @@ variable_list PyNode::to_variable_list( ", but the corresponding forward input was not a Variable"); continue; } - if (output == Py_None) { + if (Py_IsNone(output)) { results.emplace_back(); } else { TORCH_CHECK( @@ -714,7 +747,7 @@ static void _wrap_outputs( results.reserve(num_outputs); for (const auto i : c10::irange(num_outputs)) { PyObject* output = PyTuple_GET_ITEM(r.get(), i); - if (output == Py_None) { + if (Py_IsNone(output)) { results.emplace_back(); } else { TORCH_CHECK( @@ -820,7 +853,7 @@ static void _get_tensors_to_save( Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save); for (const auto i : c10::irange(num_saved)) { PyObject* obj = PyTuple_GET_ITEM(self->to_save, i); - if (obj == Py_None) { + if (Py_IsNone(obj)) { tensors_to_save.emplace_back(std::nullopt); continue; } else if (THPVariable_Check(obj)) { @@ -1097,6 +1130,7 @@ void _trace_post_record( } std::vector trace_outputs; + trace_outputs.reserve(static_cast(std::max(0, num_outputs))); for (const auto i : c10::irange(num_outputs)) { PyObject* obj = PyTuple_GET_ITEM(output_objects, i); if (THPVariable_Check(obj)) { @@ -1303,11 +1337,14 @@ THPObjectPtr make_ctx_input_output_tuple( return result; } -static PyObject* THPFunction_setup_context = nullptr; - static PyObject* get_base_setup_context() { - if (THPFunction_setup_context != nullptr) { - return THPFunction_setup_context; + // NOTE: THPFunction_setup_context is intentionally leaked and never freed. + static std::atomic THPFunction_setup_context = nullptr; + + PyObject* setup_context = + THPFunction_setup_context.load(std::memory_order_acquire); + if (setup_context != nullptr) { + return setup_context; } auto module = THPObjectPtr(PyImport_ImportModule("torch.autograd.function")); @@ -1321,11 +1358,17 @@ static PyObject* get_base_setup_context() { // setup_context gets "leaked" - we return a new reference and hold onto it // forever. - auto setup_context = PyObject_GetAttrString(function, "setup_context"); + setup_context = PyObject_GetAttrString(function, "setup_context"); if (!setup_context) return nullptr; - THPFunction_setup_context = setup_context; - return THPFunction_setup_context; + + PyObject* expected = nullptr; + if (!THPFunction_setup_context.compare_exchange_strong( + expected, setup_context, std::memory_order_acq_rel)) { + Py_DECREF(setup_context); + return expected; + } + return setup_context; } PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) { @@ -1386,7 +1429,17 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) { PyBool_Check(clear_attr.get()), "clear_saved_tensors_on_access must be a bool, got ", Py_TYPE(clear_attr.get())->tp_name); - ctx->clear_saved_tensors_on_access = clear_attr.get() == Py_True; + ctx->clear_saved_tensors_on_access = Py_IsTrue(clear_attr.get()); + + // Get boxed_grads_call from the Function class + THPObjectPtr boxed_attr(PyObject_GetAttrString(cls, "boxed_grads_call")); + TORCH_CHECK( + boxed_attr, "autograd.Function is missing boxed_grads_call attribute"); + TORCH_CHECK( + PyBool_Check(boxed_attr.get()), + "boxed_grads_call must be a bool, got ", + Py_TYPE(boxed_attr.get())->tp_name); + ctx->boxed_grads_call = Py_IsTrue(boxed_attr.get()); // autograd.Function may optionally override a setup_context staticmethod. // In this case, autograd.Function.forward does NOT accept a ctx object. @@ -1504,7 +1557,7 @@ int THPFunction_set_materialize_grads( value, nullptr, "set_materialize_grads", 1, "(bool)"); return -1; } - self->materialize_grads = (value == Py_True); + self->materialize_grads = (Py_IsTrue(value)); return 0; END_HANDLE_TH_ERRORS_RET(-1) } @@ -1518,7 +1571,7 @@ int THPFunction_set_pure_view( THPUtils_invalidArguments(value, nullptr, "set_pure_view", 1, "(bool)"); return -1; } - self->pure_view = (value == Py_True); + self->pure_view = (Py_IsTrue(value)); return 0; END_HANDLE_TH_ERRORS_RET(-1) } @@ -1545,7 +1598,7 @@ int THPFunction_set_materialize_non_diff_grads( value, nullptr, "set_materialize_non_diff_grads", 1, "(bool)"); return -1; } - self->materialize_non_diff_grads = (value == Py_True); + self->materialize_non_diff_grads = (Py_IsTrue(value)); return 0; END_HANDLE_TH_ERRORS_RET(-1) } @@ -1743,7 +1796,7 @@ PyObject* getObject(PyObject* obj, void* _unused) { template int setObject(PyObject* obj, PyObject* value, void* _unused) { auto self = (THPFunction*)obj; - if (value == Py_None) { + if (Py_IsNone(value)) { value = nullptr; } Py_XDECREF((self->*ptr)); @@ -1912,9 +1965,7 @@ PyTypeObject THPFunctionType = { }; bool THPFunction_initModule(PyObject* module) { - if (PyType_Ready(&THPFunctionType) < 0) + if (PyModule_AddType(module, &THPFunctionType) < 0) return false; - Py_INCREF(&THPFunctionType); - PyModule_AddObject(module, "_FunctionBase", (PyObject*)&THPFunctionType); return true; } diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index 69d6bfd4fff3c..28c1c58fdf127 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -120,6 +120,14 @@ struct THPFunction { // https://github.com/pytorch/pytorch/pull/98659#pullrequestreview-1376822560 bool materialize_non_diff_grads; + // When true, PyNode::apply passes grads as a single mutable list argument + // instead of individual args in an immutable tuple, allowing backward to + // free individual grads mid-execution and reduce peak memory. + // Used by pt2 compiled AutogradFunctions: the standard calling convention + // keeps a reference to all grads (via the immutable args tuple) for the + // entire backward, preventing deallocation after last use. + bool boxed_grads_call = false; + PyObject* compiled_autograd_backward_state; std::vector compiled_autograd_symints; diff --git a/torch/csrc/autograd/python_hook.cpp b/torch/csrc/autograd/python_hook.cpp index 8d2fd0b996708..0b9af04bc0a05 100644 --- a/torch/csrc/autograd/python_hook.cpp +++ b/torch/csrc/autograd/python_hook.cpp @@ -74,7 +74,7 @@ bool _call_hooks(PyObject* dict, PyObject* args) { THPObjectPtr res(PyObject_CallObject(hook, args)); if (!res) throw python_error(); - if (res == Py_None) + if (Py_IsNone(res)) continue; PyObject* args0 = PyTuple_GetItem(args, 0); @@ -283,7 +283,7 @@ static variable_list unwrap_variables(PyObject* py_variables) { variable_list results(PyTuple_GET_SIZE(py_variables)); for (const auto i : c10::irange(results.size())) { PyObject* item = PyTuple_GET_ITEM(py_variables, i); - if (item == Py_None) { + if (Py_IsNone(item)) { continue; } else if (THPVariable_Check(item)) { results[i] = THPVariable_Unpack(item); @@ -327,11 +327,11 @@ static void check_single_result( PyObject* _original, PyObject* _result, PyObject* hook) { - if (_result == Py_None) + if (Py_IsNone(_result)) return; TORCH_CHECK( - _original != Py_None, + !Py_IsNone(_original), "can't replace a None gradient with a non-None value"); if (!PyObject_IsInstance(_result, THPVariableClass)) { diff --git a/torch/csrc/autograd/python_legacy_variable.cpp b/torch/csrc/autograd/python_legacy_variable.cpp index ee00008c94bb9..0a9dbd0726be1 100644 --- a/torch/csrc/autograd/python_legacy_variable.cpp +++ b/torch/csrc/autograd/python_legacy_variable.cpp @@ -41,7 +41,7 @@ static PyObject* THPVariable_pynew( &name)) return nullptr; - if (grad_fn == Py_None) + if (Py_IsNone(grad_fn)) grad_fn = nullptr; if (is_volatile) { @@ -64,7 +64,7 @@ static PyObject* THPVariable_pynew( Py_TYPE(grad_fn)->tp_name); } Variable var; - if (!data || data == Py_None) { + if (!data || Py_IsNone(data)) { // For legacy serialization code, create an empty tensor. This is also used // by nn.Parameter() with no arguments. auto dispatch_key = torch::tensors::get_default_dispatch_key(); @@ -101,7 +101,7 @@ static PyObject* THPVariable_pynew( impl::set_name(var, name); } - if (jit::tracer::isTracing() && data && data != Py_None && + if (jit::tracer::isTracing() && data && !Py_IsNone(data) && THPVariable_Check(data)) { if (auto* v = jit::tracer::getValueTrace(THPVariable_Unpack(data))) { jit::tracer::setValueTrace(var, v); @@ -155,12 +155,7 @@ static PyTypeObject THPLegacyVariableType = { }; void init_legacy_variable(PyObject* module) { - if (PyType_Ready(&THPLegacyVariableType) < 0) { - throw python_error(); - } - auto obj = (PyObject*)&THPLegacyVariableType; - Py_INCREF(obj); - if (PyModule_AddObject(module, "_LegacyVariableBase", obj) < 0) { + if (PyModule_AddType(module, &THPLegacyVariableType) < 0) { throw python_error(); } } diff --git a/torch/csrc/autograd/python_torch_functions_manual.cpp b/torch/csrc/autograd/python_torch_functions_manual.cpp index 79739b6e459d2..51a9fa41de5f2 100644 --- a/torch/csrc/autograd/python_torch_functions_manual.cpp +++ b/torch/csrc/autograd/python_torch_functions_manual.cpp @@ -330,7 +330,7 @@ static PyObject* THPVariable_asarray( HANDLE_TH_ERRORS static PythonArgParser parser( { - "asarray(PyObject* obj, *, ScalarType? dtype=None, Device? device=None, bool? copy=None, bool requires_grad=False)", + "asarray(PyObject* obj, *, ScalarType? dtype=None, Device? device=None, bool? copy=None, bool? requires_grad=None)", }, /*traceable=*/false); @@ -347,7 +347,7 @@ static PyObject* THPVariable_asarray( auto dtype = r.scalartypeOptional(1); auto device = r.deviceOptional(2); auto copy = r.toBoolOptional(3); - auto requires_grad = r.toBool(4); + auto requires_grad = r.toBoolOptional(4); return wrap(torch::utils::asarray(obj, dtype, device, copy, requires_grad)); } diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index d220c6dec130e..77b235de60b78 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -227,6 +227,14 @@ std::pair parseIValuesToPyArgsKwargs( } return false; }; + auto matchList = [&](c10::TypeKind kind) { + const auto& t = arg.real_type(); + if (auto list_t = t->cast()) { + if (list_t->getElementType()->kind() == kind) + return true; + } + return false; + }; if (argument.isNone()) { return py::none(); } else if (match(c10::ScalarTypeType::Kind)) { @@ -239,6 +247,31 @@ std::pair parseIValuesToPyArgsKwargs( reinterpret_cast(obj)); } else if (match(c10::MemoryFormatType::Kind)) { return py::cast(static_cast(argument.toInt())); + } else if (matchList(c10::ScalarTypeType::Kind)) { + const auto& list = argument.toListRef(); + py::list result(list.size()); + for (const auto i : c10::irange(list.size())) { + auto* obj = getTHPDtype(static_cast(list[i].toInt())); + result[i] = py::reinterpret_borrow( + reinterpret_cast(obj)); + } + return result; + } else if (matchList(c10::LayoutType::Kind)) { + const auto& list = argument.toListRef(); + py::list result(list.size()); + for (const auto i : c10::irange(list.size())) { + auto* obj = getTHPLayout(static_cast(list[i].toInt())); + result[i] = py::reinterpret_borrow( + reinterpret_cast(obj)); + } + return result; + } else if (matchList(c10::MemoryFormatType::Kind)) { + const auto& list = argument.toListRef(); + py::list result(list.size()); + for (const auto i : c10::irange(list.size())) { + result[i] = py::cast(static_cast(list[i].toInt())); + } + return result; } else { return torch::jit::toPyObject(argument); } @@ -388,53 +421,32 @@ static PyObject* THPVariable_WrapWithType( } c10::TensorImpl* tensor_impl = var.unsafeGetTensorImpl(); - c10::impl::PyObjectSlot* pyobj_slot = tensor_impl->pyobj_slot(); - - PyObject* obj = pyobj_slot->load_pyobj(); - if (obj) { + THPObjectPtr obj(PyObjectPreservation::get_or_init(*tensor_impl, [&]() { + PyTypeObject* type = reinterpret_cast(THPVariableClass); if (desired_type) { - check_tensor_subclass(obj, *desired_type); - } - return Py_NewRef(obj); - } - - PyTypeObject* type = reinterpret_cast(THPVariableClass); - if (desired_type) { - type = *desired_type; - } else if (C10_UNLIKELY(var.device().type() == c10::kXLA)) { - if (auto clazz = getPythonTensorClass(var.device())) { - type = reinterpret_cast(clazz); + type = *desired_type; + } else if (C10_UNLIKELY(var.device().type() == c10::kXLA)) { + if (auto clazz = getPythonTensorClass(var.device())) { + type = reinterpret_cast(clazz); + } } - } - - obj = type->tp_alloc(type, 0); - TORCH_CHECK(obj, "Failed to allocate a ", type->tp_name, " object"); - // Ensure that PyUnstable_TryIncref calls don't fail spuriously in - // free-threaded Python. - PyUnstable_EnableTryIncRef(obj); + PyObject* wrapper = type->tp_alloc(type, 0); + TORCH_CHECK_WITH( + OutOfMemoryError, + wrapper, + "Failed to allocate a ", + type->tp_name, + " object"); + auto v = reinterpret_cast(wrapper); + new (&v->cdata) Tensor(std::forward(var)); + return wrapper; + })); - auto v = reinterpret_cast(obj); - new (&v->cdata) Tensor(std::forward(var)); - - if (THPVariable_Unpack(obj).is_uniquely_owned()) { - // We can use a faster non-atomic code path if we have the only reference to - // a fresh Tensor. - PyObjectPreservation::init_fresh_nonatomic(tensor_impl, pyobj_slot, obj); - return obj; - } - - PyObject* wrapper = - PyObjectPreservation::init_once(tensor_impl, pyobj_slot, obj); - if (wrapper != obj) { - // Another thread beat us to it - Py_DECREF(obj); - if (desired_type) { - check_tensor_subclass(wrapper, *desired_type); - } - return Py_NewRef(wrapper); + if (desired_type) { + check_tensor_subclass(obj.get(), *desired_type); } - return obj; + return obj.release(); } PyObject* THPVariable_Wrap(at::TensorBase&& var) { @@ -524,14 +536,14 @@ static PyObject* view_func_impl( // Determine new SymInt / tensor state as needed. std::optional> new_symints = std::nullopt; - if (symint_visitor_fn != Py_None) { + if (!Py_IsNone(symint_visitor_fn)) { new_symints = map_py_func( py::cast(symint_visitor_fn), view_func.get_symints()); } std::optional> new_tensors = std::nullopt; - if (tensor_visitor_fn != Py_None) { + if (!Py_IsNone(tensor_visitor_fn)) { new_tensors = map_py_func( py::cast(tensor_visitor_fn), view_func.get_tensors()); @@ -1087,9 +1099,10 @@ struct IValueOrDTensorSpec { py::object dtensor_spec; bool operator==(const IValueOrDTensorSpec& rhs) const { - return dtensor_spec - ? (rhs.dtensor_spec && dtensor_spec.equal(rhs.dtensor_spec)) - : (iv == rhs.iv); + if (dtensor_spec) { + return rhs.dtensor_spec && dtensor_spec.equal(rhs.dtensor_spec); + } + return !rhs.dtensor_spec && iv == rhs.iv; } }; @@ -1162,18 +1175,6 @@ class NativeOpSchema { // have no guarantees about its lifetime. This class is cheap anyway. c10::OperatorHandle op_; std::size_t hash_; - // Subtle point: consider clamp.Tensor(Tensor self, Tensor? - // min=None, Tensor? max=None). The invocations clamp(t1, None, t2) - // and clamp(t1, t2, None) have the same comparison key (t1, t2) - // because we drop non-static non-tensor args from comparison. The - // only way we happen to be able to tell them apart is that we omit - // trailing defaulted arguments from the args tuple passed to - // __torch_dispatch__ (and hence to DTensor dispatch as well), so - // they have different args_schema_len_. - // - // I am preserving this existing behavior, but I suspect we should - // make an algorithm change to be less brittle, such as including - // None defaults for Tensor arguments in the comparison. std::size_t args_schema_len_; // There is no particular justification for the choice of 8 // here. Feel free to change it. @@ -1842,7 +1843,7 @@ static bool DTensor_OpSchema_recompute_comparison_key_impl( size_t idx = 0; for (const auto& e : args_schema) { if (idx >= native_info.static_argnum || - arg_type_tensor_or_tensor_list_like(e)) { + arg_type_tensor_or_tensor_list_like(e) || e.is_none()) { if (PyList_Check(e.ptr())) { args_to_hash.push_back( py::reinterpret_steal(PyList_AsTuple(e.ptr()))); @@ -2304,7 +2305,9 @@ create_native_op_schema( const auto handle_non_dtensor_arg = [&comparison_key, &comparison_key_hash, &native_info]( size_t idx, c10::IValue arg) { - if (idx >= native_info.static_argnum) { + bool is_none_or_undefined = + arg.isNone() || (arg.isTensor() && !arg.toTensor().defined()); + if (idx >= native_info.static_argnum || is_none_or_undefined) { if (arg.isList()) { const auto& list = arg.toList(); if (list.empty()) { @@ -2420,7 +2423,13 @@ create_native_op_schema( item_flavor == TensorFlavor::NON_DTENSOR_TENSOR_SUBCLASS) { handle_exactly_tensor(item_py_tensor); } else { // non-tensor - handle_non_tensor_or_undefined(item); + // Use handle_non_dtensor_arg to respect static_argnum. + // Non-tensor items in lists (e.g., ScalarList args to + // foreach ops) should only be included in the cache key + // if the list's argument index is >= static_argnum. + // Otherwise, step-varying scalars (like AdamW bias + // corrections) cause unbounded cache growth. + handle_non_dtensor_arg(idx, item); } } } else { @@ -2454,6 +2463,20 @@ create_native_op_schema( } if (native_info.static_kwargkey && !native_info.static_kwargkey.is_none()) { + // Only kwargs named in static_kwargkey affect sharding propagation and + // belong in the cache key. The Python comparison key + // (DTensor_OpSchema_recompute_comparison_key_impl) already filters this + // way; the C++ fast path must match. Without this filter, step-varying + // scalar kwargs (e.g. the `value` arg of addcdiv_ used by AdamW bias + // corrections) cause unbounded cache growth. + py::list static_kwargkey = + py::reinterpret_borrow(native_info.static_kwargkey); + c10::SmallVector static_kwarg_names; + static_kwarg_names.reserve(static_kwargkey.size()); + for (const auto& key : static_kwargkey) { + static_kwarg_names.push_back(py::cast(key)); + } + // Separator to disambiguate kwargs from args in comparison and hashing. static constexpr int64_t kwargs_separator = 0x0011223344556677LL; comparison_key.emplace_back(static_cast(kwargs_separator)); @@ -2462,14 +2485,26 @@ create_native_op_schema( for (auto argument_it = args_kwargs.kwargs_begin(); argument_it != args_kwargs.kwargs_end(); ++argument_it) { + const auto underlying_index = argument_it.underlying_index(); + const auto [tensor_flavor, py_tensor] = + check_for_dtensor_or_tensor(*argument_it); + // Skip non-tensor kwargs not listed in static_kwargkey. + if (tensor_flavor == TensorFlavor::NON_TENSOR) { + const auto& kwarg_name = + op.schema().arguments()[underlying_index].name(); + if (std::find( + static_kwarg_names.begin(), + static_kwarg_names.end(), + kwarg_name) == static_kwarg_names.end()) { + continue; + } + } + // Rather than hash/compare the string key, we can just use the // index of the kwarg in the schema! - const auto underlying_index = argument_it.underlying_index(); comparison_key.emplace_back(c10::IValue(underlying_index)); comparison_key_hash = hash_combine( comparison_key_hash, c10::IValue::hash(comparison_key.back().iv)); - const auto [tensor_flavor, py_tensor] = - check_for_dtensor_or_tensor(*argument_it); switch (tensor_flavor) { case TensorFlavor::EXACTLY_DTENSOR: case TensorFlavor::DTENSOR_SUBCLASS: { @@ -2583,7 +2618,7 @@ template struct GetterBase { static PyObject* getter(THPVariable* self, void* /*unused*/) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, T::name); } return THPVariable_Wrap(T::fn(THPVariable_Unpack(self))); @@ -2649,7 +2684,7 @@ struct PropertyImag : GetterBase { static PyObject* THPVariable_get_cdata(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "_cdata"); } const auto& var = THPVariable_Unpack(self); @@ -2659,7 +2694,7 @@ static PyObject* THPVariable_get_cdata(THPVariable* self, void* unused) { static PyObject* THPVariable_get_version(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "_version"); } const auto& var = THPVariable_Unpack(self); @@ -2669,7 +2704,7 @@ static PyObject* THPVariable_get_version(THPVariable* self, void* unused) { static PyObject* THPVariable_get_grad_fn(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "grad_fn"); } const auto& var = THPVariable_Unpack(self); @@ -2685,11 +2720,11 @@ static int THPVariable_set_grad_fn( PyObject* obj, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "_grad_fn", obj); } TORCH_CHECK(obj, "Deletion of _grad_fn not allowed. Detach tensor instead!"); - TORCH_CHECK(obj == Py_None, "_grad_fn can be only set to None"); + TORCH_CHECK(Py_IsNone(obj), "_grad_fn can be only set to None"); THPVariable_Unpack(self).detach_(); return 0; END_HANDLE_TH_ERRORS_RET(-1) @@ -2697,7 +2732,7 @@ static int THPVariable_set_grad_fn( static PyObject* THPVariable_is_leaf(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_leaf"); } return PyBool_FromLong(!THPVariable_Unpack(self).grad_fn()); @@ -2709,7 +2744,7 @@ static int THPVariable_set_data( PyObject* data, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "data", data); } TORCH_CHECK( @@ -2729,11 +2764,11 @@ static int THPVariable_set_grad( PyObject* py_grad, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "grad", py_grad); } const auto& var = THPVariable_Unpack(self); - if (!py_grad || py_grad == Py_None) { + if (!py_grad || Py_IsNone(py_grad)) { var.mutable_grad().reset(); return 0; } @@ -2795,7 +2830,7 @@ static int THPVariable_set_grad( static PyObject* THPVariable_get_volatile(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "volatile"); } const char* msg = "volatile was removed (Variable.volatile is always False)"; @@ -2811,7 +2846,7 @@ static int THPVariable_set_volatile( PyObject* obj, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "volatile", obj); } auto r = PyErr_WarnEx(PyExc_UserWarning, VOLATILE_WARNING, 1); @@ -2823,7 +2858,7 @@ static int THPVariable_set_volatile( static PyObject* THPVariable_get_output_nr(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "output_nr"); } const auto output_nr = THPVariable_Unpack(self).output_nr(); @@ -2835,7 +2870,7 @@ static PyObject* THPVariable_get_requires_grad( THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "requires_grad"); } if (THPVariable_Unpack(self).requires_grad()) { @@ -2848,7 +2883,7 @@ static PyObject* THPVariable_get_requires_grad( static PyObject* THPVariable_retains_grad(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "retains_grad"); } if (THPVariable_Unpack(self).retains_grad()) { @@ -2861,7 +2896,7 @@ static PyObject* THPVariable_retains_grad(THPVariable* self, void* unused) { static PyObject* THPVariable_get_ndim(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "ndim"); } return THPUtils_packInt64(THPVariable_Unpack(self).dim()); @@ -2870,7 +2905,7 @@ static PyObject* THPVariable_get_ndim(THPVariable* self, void* unused) { static PyObject* THPVariable_get_names(PyObject* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function_getter((THPVariable*)self, "names"); } // The long-term plan is to return a list of (python) torch.Dimname. @@ -2911,11 +2946,11 @@ static int THPVariable_set_names( PyObject* names, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { return handle_torch_function_setter((THPVariable*)self, "names", names); } const auto& var = THPVariable_Unpack(self); - if (names == Py_None) { + if (Py_IsNone(names)) { at::internal_set_names_inplace(var, std::nullopt); } else { TORCH_CHECK( @@ -2932,15 +2967,15 @@ static int THPVariable_set_requires_grad( PyObject* obj, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "requires_grad", obj); } TORCH_CHECK(obj && PyBool_Check(obj), "requires_grad must be a bool"); const auto& var = THPVariable_Unpack(self); - auto requires_grad = (obj == Py_True); + auto requires_grad = (Py_IsTrue(obj)); if (!var.is_leaf()) { THPUtils_setError( - autograd::utils::requires_grad_leaf_error(obj == Py_True).c_str()); + autograd::utils::requires_grad_leaf_error(Py_IsTrue(obj)).c_str()); return -1; } if (requires_grad && @@ -2955,7 +2990,7 @@ static int THPVariable_set_requires_grad( } static PyObject* THPVariable_get_name(THPVariable* self, void* unused) { - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { HANDLE_TH_ERRORS return handle_torch_function_getter(self, "name"); END_HANDLE_TH_ERRORS @@ -2970,7 +3005,7 @@ static PyObject* THPVariable_get_backwards_hooks( THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "_backward_hooks"); } if (self->backward_hooks) { @@ -2986,11 +3021,11 @@ static int THPVariable_set_backwards_hooks( PyObject* obj, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "_backward_hooks", obj); } TORCH_CHECK(obj, "Deletion of _backwards_hooks not allowed!"); - if (obj == Py_None) { + if (Py_IsNone(obj)) { obj = nullptr; } Py_XINCREF(obj); @@ -3010,7 +3045,7 @@ static PyObject* THPVariable_get_post_accumulate_grad_hooks( THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "_post_accumulate_grad_hooks"); } if (self->post_accumulate_grad_hooks) { @@ -3026,12 +3061,12 @@ static int THPVariable_set_post_accumulate_grad_hooks( PyObject* obj, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_setter( self, "_post_accumulate_grad_hooks", obj); } TORCH_CHECK(obj, "Deletion of _post_accumulate_grad_hooks not allowed!"); - if (obj == Py_None) { + if (Py_IsNone(obj)) { obj = nullptr; } Py_XINCREF(obj); @@ -3048,7 +3083,7 @@ static int THPVariable_set_post_accumulate_grad_hooks( static PyObject* THPVariable_get_base(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "_base"); } const auto& tensor = THPVariable_Unpack(self); @@ -3061,7 +3096,7 @@ static PyObject* THPVariable_get_base(THPVariable* self, void* unused) { static PyObject* THPVariable_get_shape(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "shape"); } return THPSize_NewFromSymSizes(THPVariable_Unpack(self)); @@ -3070,7 +3105,7 @@ static PyObject* THPVariable_get_shape(THPVariable* self, void* unused) { static PyObject* THPVariable_is_cpu(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_cpu"); } auto& self_ = THPVariable_Unpack(self); @@ -3080,7 +3115,7 @@ static PyObject* THPVariable_is_cpu(THPVariable* self, void* unused) { static PyObject* THPVariable_is_cuda(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_cuda"); } auto& self_ = THPVariable_Unpack(self); @@ -3090,7 +3125,7 @@ static PyObject* THPVariable_is_cuda(THPVariable* self, void* unused) { static PyObject* THPVariable_is_mtia(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_mtia"); } auto& self_ = THPVariable_Unpack(self); @@ -3100,7 +3135,7 @@ static PyObject* THPVariable_is_mtia(THPVariable* self, void* unused) { static PyObject* THPVariable_is_xla(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_xla"); } auto& self_ = THPVariable_Unpack(self); @@ -3110,7 +3145,7 @@ static PyObject* THPVariable_is_xla(THPVariable* self, void* unused) { static PyObject* THPVariable_is_ipu(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_ipu"); } auto& self_ = THPVariable_Unpack(self); @@ -3120,7 +3155,7 @@ static PyObject* THPVariable_is_ipu(THPVariable* self, void* unused) { static PyObject* THPVariable_is_xpu(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_xpu"); } auto& self_ = THPVariable_Unpack(self); @@ -3130,7 +3165,7 @@ static PyObject* THPVariable_is_xpu(THPVariable* self, void* unused) { static PyObject* THPVariable_is_sparse(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_sparse"); } auto& self_ = THPVariable_Unpack(self); @@ -3140,7 +3175,7 @@ static PyObject* THPVariable_is_sparse(THPVariable* self, void* unused) { static PyObject* THPVariable_is_sparse_csr(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_sparse_csr"); } auto& self_ = THPVariable_Unpack(self); @@ -3150,7 +3185,7 @@ static PyObject* THPVariable_is_sparse_csr(THPVariable* self, void* unused) { static PyObject* THPVariable_is_mkldnn(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_mkldnn"); } auto& self_ = THPVariable_Unpack(self); @@ -3160,7 +3195,7 @@ static PyObject* THPVariable_is_mkldnn(THPVariable* self, void* unused) { static PyObject* THPVariable_is_mps(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_mps"); } auto& self_ = THPVariable_Unpack(self); @@ -3170,7 +3205,7 @@ static PyObject* THPVariable_is_mps(THPVariable* self, void* unused) { static PyObject* THPVariable_is_maia(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_maia"); } auto& self_ = THPVariable_Unpack(self); @@ -3180,7 +3215,7 @@ static PyObject* THPVariable_is_maia(THPVariable* self, void* unused) { static PyObject* THPVariable_is_vulkan(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_vulkan"); } auto& self_ = THPVariable_Unpack(self); @@ -3190,7 +3225,7 @@ static PyObject* THPVariable_is_vulkan(THPVariable* self, void* unused) { static PyObject* THPVariable_is_quantized(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_quantized"); } auto& self_ = THPVariable_Unpack(self); @@ -3200,7 +3235,7 @@ static PyObject* THPVariable_is_quantized(THPVariable* self, void* unused) { static PyObject* THPVariable_is_meta(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_meta"); } auto& self_ = THPVariable_Unpack(self); @@ -3210,7 +3245,7 @@ static PyObject* THPVariable_is_meta(THPVariable* self, void* unused) { static PyObject* THPVariable_is_complex(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_complex"); } auto& self_ = THPVariable_Unpack(self); @@ -3220,7 +3255,7 @@ static PyObject* THPVariable_is_complex(THPVariable* self, void* unused) { static PyObject* THPVariable_is_nested(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "is_nested"); } auto& self_ = THPVariable_Unpack(self); @@ -3240,7 +3275,7 @@ static PyObject* THPVariable_has_symbolic_sizes_strides( static PyObject* THPVariable_dtype(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "dtype"); } auto& self_ = THPVariable_Unpack(self); @@ -3250,7 +3285,7 @@ static PyObject* THPVariable_dtype(THPVariable* self, void* unused) { static PyObject* THPVariable_layout(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "layout"); } auto& self_ = THPVariable_Unpack(self); @@ -3260,7 +3295,7 @@ static PyObject* THPVariable_layout(THPVariable* self, void* unused) { static PyObject* THPVariable_device(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "device"); } return THPDevice_New(THPVariable_Unpack(self).device()); @@ -3269,7 +3304,7 @@ static PyObject* THPVariable_device(THPVariable* self, void* unused) { static PyObject* THPVariable_get_nbytes(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "nbytes"); } return PyLong_FromSize_t(THPVariable_Unpack(self).nbytes()); @@ -3278,7 +3313,7 @@ static PyObject* THPVariable_get_nbytes(THPVariable* self, void* unused) { static PyObject* THPVariable_get_grad_dtype(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "grad_dtype"); } const auto& var = THPVariable_Unpack(self); @@ -3297,15 +3332,15 @@ static int THPVariable_set_grad_dtype( PyObject* obj, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_setter(self, "grad_dtype", obj); } const auto& var = THPVariable_Unpack(self); TORCH_CHECK( - THPDtype_Check(obj) || obj == Py_None, + THPDtype_Check(obj) || Py_IsNone(obj), "grad_dtype must be a torch.dtype or None, but got ", Py_TYPE(obj)->tp_name); - if (var.grad().defined() && obj != Py_None) { + if (var.grad().defined() && !Py_IsNone(obj)) { auto new_dtype = reinterpret_cast(obj); TORCH_CHECK( var.grad().dtype() == new_dtype->scalar_type, @@ -3317,7 +3352,7 @@ static int THPVariable_set_grad_dtype( "or ensure the new grad_dtype matches the existing gradient's dtype."); } std::optional new_dtype; - if (obj != Py_None) { + if (!Py_IsNone(obj)) { auto* dtype = reinterpret_cast(obj); new_dtype = dtype->scalar_type; } @@ -3328,7 +3363,7 @@ static int THPVariable_set_grad_dtype( static PyObject* THPVariable_get_itemsize(THPVariable* self, void* unused) { HANDLE_TH_ERRORS - if (check_has_torch_function((PyObject*)self)) { + if (has_torch_function((PyObject*)self)) { return handle_torch_function_getter(self, "itemsize"); } return PyLong_FromSize_t(THPVariable_Unpack(self).itemsize()); diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index 1a1a12ec20a72..7cfe7fd74bd30 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -38,7 +38,7 @@ namespace torch::autograd { Py_ssize_t THPVariable_length(PyObject* self) { HANDLE_TH_ERRORS - if (check_has_torch_function(self)) { + if (has_torch_function(self)) { py::object ret = py::reinterpret_steal( handle_torch_function(self, "__len__")); Py_ssize_t length = PyLong_AsSsize_t(ret.ptr()); @@ -116,8 +116,8 @@ static int64_t count_specified_dimensions(PyObject* index) { return -1; // Signal torch function handling needed } } - if (obj != Py_None && obj != Py_Ellipsis && obj != Py_True && - obj != Py_False) { + if (!Py_IsNone(obj) && obj != Py_Ellipsis && !Py_IsTrue(obj) && + !Py_IsFalse(obj)) { count++; } } @@ -257,10 +257,10 @@ static Variable applySlicing( at::indexing::Slice(val.start, val.stop, val.step)); } else if (obj == Py_Ellipsis) { return at::indexing::TensorIndex(at::indexing::Ellipsis); - } else if (obj == Py_None) { + } else if (Py_IsNone(obj)) { return at::indexing::TensorIndex(at::indexing::None); } else if (PyBool_Check(obj)) { - return at::indexing::TensorIndex(obj == Py_True); + return at::indexing::TensorIndex(Py_IsTrue(obj)); } else if (THPVariable_Check(obj)) { Tensor tensor = THPVariable_Unpack(obj); if (is_tracing) { @@ -353,7 +353,7 @@ static bool treatSequenceAsTuple(PyObject* index) { "different result"); return true; } - if (obj.get() == Py_Ellipsis || obj.get() == Py_None) { + if (obj.get() == Py_Ellipsis || Py_IsNone(obj.get())) { TORCH_WARN( "Using a non-tuple sequence for " "multidimensional indexing is deprecated and will be changed in " @@ -396,7 +396,7 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { OptionalDeviceGuard device_guard(device_of(self_)); // handle simple types: none, ellipsis - if (index == Py_None) { + if (Py_IsNone(index)) { return THPVariable_Wrap(at::indexing::get_item( self_, {at::indexing::TensorIndex(at::indexing::None)})); } else if (index == Py_Ellipsis) { @@ -422,11 +422,11 @@ PyObject* THPVariable_getitem(PyObject* self, PyObject* index) { self_, {at::indexing::TensorIndex( at::indexing::Slice(val.start, val.stop, val.step))})); - } else if (index == Py_False || index == Py_True) { + } else if (Py_IsFalse(index) || Py_IsTrue(index)) { return THPVariable_Wrap(([&]() { pybind11::gil_scoped_release no_gil; return at::indexing::get_item( - self_, {at::indexing::TensorIndex(index == Py_True)}); + self_, {at::indexing::TensorIndex(Py_IsTrue(index))}); })()); } @@ -513,7 +513,7 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { } // handle simple types: ellipsis, none, bool - if (index == Py_False) { + if (Py_IsFalse(index)) { // do nothing for false (technically we should check the size, but we don't // have real 0-sized shapes. return 0; @@ -521,11 +521,11 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) { dispatch_set_item( self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)}, value); return 0; - } else if (index == Py_None) { + } else if (Py_IsNone(index)) { dispatch_set_item( self_, {at::indexing::TensorIndex(at::indexing::None)}, value); return 0; - } else if (index == Py_True) { + } else if (Py_IsTrue(index)) { dispatch_set_item(self_, {at::indexing::TensorIndex(true)}, value); return 0; } diff --git a/torch/csrc/autograd/python_variable_indexing.h b/torch/csrc/autograd/python_variable_indexing.h index 7efab1dcf2229..c6c5b5d3ebd18 100644 --- a/torch/csrc/autograd/python_variable_indexing.h +++ b/torch/csrc/autograd/python_variable_indexing.h @@ -37,7 +37,7 @@ inline UnpackedSlice __PySlice_Unpack(PyObject* _r) { return val; }; - if (r->step == Py_None) { + if (Py_IsNone(r->step)) { step_sym = c10::SymInt(1); } else { if (torch::is_symint(r->step)) { @@ -58,7 +58,7 @@ inline UnpackedSlice __PySlice_Unpack(PyObject* _r) { if (torch::is_symint(r->start)) { start_sym = py::handle(r->start).cast(); - } else if (r->start == Py_None) { + } else if (Py_IsNone(r->start)) { start_sym = c10::SymInt(step_sym < 0 ? PY_SSIZE_T_MAX : 0); } else { Py_ssize_t start = 0; @@ -71,7 +71,7 @@ inline UnpackedSlice __PySlice_Unpack(PyObject* _r) { if (torch::is_symint(r->stop)) { stop_sym = py::handle(r->stop).cast(); - } else if (r->stop == Py_None) { + } else if (Py_IsNone(r->stop)) { stop_sym = c10::SymInt( step_sym < 0 ? c10::SymInt::min_representable_int() : PY_SSIZE_T_MAX); } else { diff --git a/torch/csrc/autograd/saved_variable.cpp b/torch/csrc/autograd/saved_variable.cpp index 9624baa8b6b7a..fe8449537dec1 100644 --- a/torch/csrc/autograd/saved_variable.cpp +++ b/torch/csrc/autograd/saved_variable.cpp @@ -180,7 +180,7 @@ Variable SavedVariable::unpack(std::shared_ptr saved_for) const { } if (grad_fn) { message << ", which is output " << output_nr_ << " of " - << grad_fn->name() << ','; + << grad_fn->forward_op_name() << ','; } message << " is at version " << current_version << "; expected version " << saved_version_ << " instead."; diff --git a/torch/csrc/autograd/variable.cpp b/torch/csrc/autograd/variable.cpp index 89135f9aa9a22..188e6b5e67d2f 100644 --- a/torch/csrc/autograd/variable.cpp +++ b/torch/csrc/autograd/variable.cpp @@ -118,9 +118,8 @@ ViewInfo ViewInfo::chain( "Attempted to chain views when the parent view has no view_func() and " "does not support as_strided(). This is not supported."; view_func = std::make_unique(error_msg); - rev_view_func = [=](const at::Tensor& root_view) { + rev_view_func = [=](const at::Tensor& root_view) -> at::Tensor { TORCH_CHECK(false, error_msg); - return root_view; }; } } @@ -788,7 +787,7 @@ void handle_view_on_rebase( "Output ", diff_view_meta->output_nr_, " of ", - grad_fn->name(), + grad_fn->forward_op_name(), " is a view of a view which was created in"); } else { prefix = "A view was created in"; @@ -815,7 +814,7 @@ void handle_view_on_rebase( "Output ", diff_view_meta->output_nr_, " of ", - grad_fn->name(), + grad_fn->forward_op_name(), " is a view and ", modified_obj, " modified inplace."); diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp index bb08bcaf9bedd..1fb6795eae6fd 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp +++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp @@ -297,6 +297,14 @@ void CUDAPluggableAllocator::attachOutOfMemoryObserver( "If you need it, please file an issue describing your use case."); } +void CUDAPluggableAllocator::attachOomRejectionObserver( + c10::cuda::CUDACachingAllocator::OomRejectionObserver observer) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "CUDAPluggableAllocator does not yet support attachOomRejectionObserver. " + "If you need it, please file an issue describing your use case."); +} + void CUDAPluggableAllocator::attachAllocatorTraceTracker( c10::cuda::CUDACachingAllocator::AllocatorTraceTracker tracker) { TORCH_CHECK( diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.h b/torch/csrc/cuda/CUDAPluggableAllocator.h index 5434dad237608..91b2179038a19 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.h +++ b/torch/csrc/cuda/CUDAPluggableAllocator.h @@ -123,6 +123,8 @@ struct TORCH_CUDA_CPP_API CUDAPluggableAllocator const std::vector& skip_actions) override; void attachOutOfMemoryObserver( c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) override; + void attachOomRejectionObserver( + c10::cuda::CUDACachingAllocator::OomRejectionObserver observer) override; void attachAllocatorTraceTracker( c10::cuda::CUDACachingAllocator::AllocatorTraceTracker tracker) override; std::shared_ptr diff --git a/torch/csrc/cuda/GreenContext.cpp b/torch/csrc/cuda/GreenContext.cpp index 26dc19b8b291e..d420600b0968f 100644 --- a/torch/csrc/cuda/GreenContext.cpp +++ b/torch/csrc/cuda/GreenContext.cpp @@ -6,8 +6,45 @@ void THCPGreenContext_init(PyObject* module) { auto m = py::handle(module).cast(); + + py::enum_(m, "_WorkqueueScope") + .value("device_ctx", at::cuda::WorkqueueScope::DeviceCtx) + .value("balanced", at::cuda::WorkqueueScope::Balanced); + py::class_(m, "_CUDAGreenContext") - .def_static("create", &::at::cuda::GreenContext::create) + .def_static( + "create", + [](std::optional device_id, + std::optional num_sms, + std::optional workqueue_scope, + std::optional workqueue_concurrency_limit) { + std::optional scope; + if (workqueue_scope.has_value()) { + const auto& s = *workqueue_scope; + if (s == "device_ctx") { + scope = + static_cast(at::cuda::WorkqueueScope::DeviceCtx); + } else if (s == "balanced") { + scope = + static_cast(at::cuda::WorkqueueScope::Balanced); + } else { + throw std::invalid_argument( + "workqueue_scope must be 'device_ctx' or 'balanced', got '" + + s + "'"); + } + } + return at::cuda::GreenContext::create( + device_id, num_sms, scope, workqueue_concurrency_limit); + }, + py::kw_only(), + py::arg("device_id") = py::none(), + py::arg("num_sms") = py::none(), + py::arg("workqueue_scope") = py::none(), + py::arg("workqueue_concurrency_limit") = py::none()) + .def_static( + "max_workqueue_concurrency", + &at::cuda::GreenContext::max_workqueue_concurrency, + py::arg("device_id") = py::none()) .def("set_context", &::at::cuda::GreenContext::setContext) .def("pop_context", &::at::cuda::GreenContext::popContext) .def("Stream", [](at::cuda::GreenContext& self) { diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 43b8c7cefe292..9cc8a2f65a56d 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -297,6 +297,11 @@ PyObject* THCPModule_cudaCachingAllocator_raw_alloc( return nullptr; } auto size = PyLong_AsSsize_t(size_o); + TORCH_CHECK_VALUE( + size >= 0, + "Invalid memory size: ", + size, + ". caching_allocator_alloc requires a non-negative size."); cudaStream_t stream = static_cast(PyLong_AsVoidPtr(stream_o)); void* mem = nullptr; { @@ -367,6 +372,9 @@ PyObject* THCPModule_cudaJiteratorCompileAndLaunchKernel( Py_ssize_t num_tensors = PyTuple_GET_SIZE(tensors_o); c10::SmallVector tensors; + if (num_tensors > 0) { + tensors.reserve(static_cast(num_tensors)); + } for (const auto i : c10::irange(num_tensors)) { PyObject* _tensor = PyTuple_GET_ITEM(tensors_o, i); TORCH_CHECK( @@ -378,6 +386,10 @@ PyObject* THCPModule_cudaJiteratorCompileAndLaunchKernel( } c10::SmallVector extra_args; + const Py_ssize_t num_extra_args = kwargs_o ? PyDict_Size(kwargs_o) : 0; + if (num_extra_args > 0) { + extra_args.reserve(static_cast(num_extra_args)); + } PyObject* key = nullptr; PyObject* value = nullptr; Py_ssize_t pos = 0; @@ -451,6 +463,14 @@ PyObject* THCPModule_cudaCachingAllocator_enable( END_HANDLE_TH_ERRORS } +PyObject* THCPModule_cudaCachingAllocator_is_enabled( + PyObject* _unused, + PyObject* noargs) { + HANDLE_TH_ERRORS + return PyBool_FromLong(c10::cuda::CUDACachingAllocator::isEnabled()); + END_HANDLE_TH_ERRORS +} + PyObject* THCPModule_getAllocatorBackend(PyObject* _unused, PyObject* noargs) { HANDLE_TH_ERRORS return THPUtils_packString(c10::cuda::CUDACachingAllocator::name()); @@ -486,41 +506,6 @@ PyObject* THCPModule_cudaSleep(PyObject* _unused, PyObject* cycles) { END_HANDLE_TH_ERRORS } -// We need to ensure that as long as a thread will NEVER loose the GIL as long -// as it holds the CUDA mutex. Otherwise another thread might be scheduled and -// try to e.g. allocate a new tensor which will cause a deadlock. It's enough to -// have a single global, because it can be only set once (cudaMutex is not -// recursive) by the thread that owns the mutex (obviously there can be only one -// such thread). -static PyGILState_STATE cudaMutexGILState; - -PyObject* THCPModule_cudaLockMutex(PyObject* module, PyObject* noargs) { - auto mutex = c10::cuda::getFreeMutex(); - // This has to be a busy loop because we **absolutely need to** hold the GIL - // or it's a recipe for a deadlock otherwise (if we let other Python threads - // run while we have the cudaMutex, but not the GIL, they might try to e.g. - // free a CUDA tensor and acquire the cudaMutex without giving up the GIL, - // because it happens deep within THC). - while (true) { - if (mutex->try_lock()) - break; - { - pybind11::gil_scoped_release no_gil; - std::this_thread::sleep_for(std::chrono::microseconds(10)); - } - } - - cudaMutexGILState = PyGILState_Ensure(); - Py_RETURN_NONE; -} - -PyObject* THCPModule_cudaUnlockMutex(PyObject* module, PyObject* noargs) { - auto mutex = c10::cuda::getFreeMutex(); - PyGILState_Release(cudaMutexGILState); - mutex->unlock(); - Py_RETURN_NONE; -} - PyObject* THCPModule_hasPrimaryContext(PyObject* _unused, PyObject* arg) { HANDLE_TH_ERRORS TORCH_CHECK( @@ -627,6 +612,7 @@ PyObject* THCPModule_memoryStats(PyObject* _unused, PyObject* arg) { result["num_sync_all_streams"] = stats.num_sync_all_streams; result["num_device_alloc"] = stats.num_device_alloc; result["num_device_free"] = stats.num_device_free; + result["num_oom_rejections"] = stats.num_oom_rejections; result["allocation"] = statArrayToDict(stats.allocation); result["segment"] = statArrayToDict(stats.segment); result["active"] = statArrayToDict(stats.active); @@ -735,7 +721,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) { c10::cuda::MempoolId_t mempool_id = {0, 0}; bool include_traces = true; - if (arg && arg != Py_None) { + if (arg && !Py_IsNone(arg)) { TORCH_CHECK(PyTuple_Check(arg), "Expected tuple or None"); Py_ssize_t size = PyTuple_Size(arg); @@ -760,7 +746,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) { PyBool_Check(traces.get()), "include_traces must be a boolean"); mempool_id = c10::cuda::MempoolId_t( THPUtils_unpackLong(id1), THPUtils_unpackLong(id2)); - include_traces = (traces.get() == Py_True); + include_traces = (Py_IsTrue(traces.get())); } else { TORCH_CHECK(false, "Expected tuple of size 2 or 3"); } @@ -793,6 +779,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) { py::str time_us_s = "time_us"; py::str compile_context_s = "compile_context"; py::str user_metadata_s = "user_metadata"; + py::str pool_id_s = "pool_id"; py::list empty_frames; std::vector to_gather_frames; @@ -912,6 +899,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* arg) { trace_entry[time_us_s] = te.time_.t_; trace_entry[compile_context_s] = te.compile_context_; trace_entry[user_metadata_s] = te.user_metadata_; + trace_entry[pool_id_s] = te.mempool_; trace.append(trace_entry); } traces.append(trace); @@ -1404,6 +1392,13 @@ static void registerCudaPluggableAllocator(PyObject* module) { storage_impl->release_data_and_set_meta_custom_data_ptr_error_msg_(s); }); + m.def( + "_clear_storage_data_ptr_access_error_msg", [](size_t storage_impl_ptr) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr; + storage_impl->clear_data_ptr_access_error_msg_(); + }); + m.def("_has_Standard_Deleter", [](size_t storage_impl_ptr) { // NOLINTNEXTLINE(performance-no-int-to-ptr) c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr; @@ -1624,6 +1619,72 @@ static PyObject* THCPModule_clearBlasWorkspaces_wrap( END_HANDLE_TH_ERRORS } +static PyObject* THCPModule_getCublasWorkspaceSize( + PyObject* self, + PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packUInt64(at::cuda::getChosenWorkspaceSize()); + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPModule_setCublasWorkspaceSize( + PyObject* self, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "set cublas workspace size expects an int, but got ", + THPUtils_typename(arg)); + int64_t size = THPUtils_unpackLong(arg); + TORCH_CHECK( + size >= 0, "cublas workspace size must be non-negative, got ", size); + at::cuda::setChosenWorkspaceSize(static_cast(size)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPModule_getCublasLtWorkspaceSize( + PyObject* self, + PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packUInt64(at::cuda::getCUDABlasLtWorkspaceSize()); + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPModule_setCublasLtWorkspaceSize( + PyObject* self, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "set cublaslt workspace size expects an int, but got ", + THPUtils_typename(arg)); + int64_t size = THPUtils_unpackLong(arg); + TORCH_CHECK( + size >= 0, "cublaslt workspace size must be non-negative, got ", size); + at::cuda::setCUDABlasLtWorkspaceSize(static_cast(size)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPModule_resetCublasWorkspaceSize( + PyObject* self, + PyObject* noargs) { + HANDLE_TH_ERRORS + at::cuda::resetChosenWorkspaceSize(); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* THCPModule_resetCublasLtWorkspaceSize( + PyObject* self, + PyObject* noargs) { + HANDLE_TH_ERRORS + at::cuda::resetCUDABlasLtWorkspaceSize(); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + PyObject* THCPModule_rocm_is_backward_pass( PyObject* _unused, PyObject* noargs) { @@ -2067,6 +2128,30 @@ static struct PyMethodDef _THCPModule_methods[] = { THCPModule_clearBlasWorkspaces_wrap, METH_NOARGS, nullptr}, + {"_cuda_getCublasWorkspaceSize", + THCPModule_getCublasWorkspaceSize, + METH_NOARGS, + nullptr}, + {"_cuda_setCublasWorkspaceSize", + THCPModule_setCublasWorkspaceSize, + METH_O, + nullptr}, + {"_cuda_getCublasLtWorkspaceSize", + THCPModule_getCublasLtWorkspaceSize, + METH_NOARGS, + nullptr}, + {"_cuda_setCublasLtWorkspaceSize", + THCPModule_setCublasLtWorkspaceSize, + METH_O, + nullptr}, + {"_cuda_resetCublasWorkspaceSize", + THCPModule_resetCublasWorkspaceSize, + METH_NOARGS, + nullptr}, + {"_cuda_resetCublasLtWorkspaceSize", + THCPModule_resetCublasLtWorkspaceSize, + METH_NOARGS, + nullptr}, {"_cuda_isCurrentStreamCapturing", THCPModule_isCurrentStreamCapturing_wrap, METH_NOARGS, @@ -2137,6 +2222,10 @@ static struct PyMethodDef _THCPModule_methods[] = { THCPModule_cudaCachingAllocator_enable, METH_O, nullptr}, + {"_cuda_cudaCachingAllocator_is_enabled", + THCPModule_cudaCachingAllocator_is_enabled, + METH_NOARGS, + nullptr}, {"_cuda_getAllocatorBackend", THCPModule_getAllocatorBackend, METH_NOARGS, @@ -2144,8 +2233,6 @@ static struct PyMethodDef _THCPModule_methods[] = { {"_cuda_synchronize", THCPModule_cudaSynchronize, METH_NOARGS, nullptr}, {"_cuda_ipc_collect", THCPModule_cudaIPCCollect, METH_NOARGS, nullptr}, {"_cuda_sleep", THCPModule_cudaSleep, METH_O, nullptr}, - {"_cuda_lock_mutex", THCPModule_cudaLockMutex, METH_NOARGS, nullptr}, - {"_cuda_unlock_mutex", THCPModule_cudaUnlockMutex, METH_NOARGS, nullptr}, {"_cuda_set_sync_debug_mode", THCPModule_cudaSetSyncDebugMode, METH_O, diff --git a/torch/csrc/cuda/memory_snapshot.cpp b/torch/csrc/cuda/memory_snapshot.cpp index 685e2fa38a9a7..9ce2b9bf383c6 100644 --- a/torch/csrc/cuda/memory_snapshot.cpp +++ b/torch/csrc/cuda/memory_snapshot.cpp @@ -315,6 +315,7 @@ std::string _memory_snapshot_pickled() { IValue time_us_s = "time_us"; IValue compile_contexts_s = "compile_context"; IValue user_metadata_s = "user_metadata"; + IValue pool_id_s = "pool_id"; auto empty_frames = new_list(); @@ -439,6 +440,7 @@ std::string _memory_snapshot_pickled() { frame_dict.push_back(trace_entry); } trace_entry.insert(time_us_s, te.time_.t_); + trace_entry.insert(pool_id_s, std::tuple(te.mempool_)); trace.push_back(trace_entry); } traces.push_back(trace); diff --git a/torch/csrc/cuda/python_nccl.cpp b/torch/csrc/cuda/python_nccl.cpp index 55af32792018a..2bb6977f5f68d 100644 --- a/torch/csrc/cuda/python_nccl.cpp +++ b/torch/csrc/cuda/python_nccl.cpp @@ -60,7 +60,7 @@ static void destroy_nccl_comm(PyObject* capsule) { static std::vector> unpack_streams( PyObject* obj, size_t size) { - if (obj == Py_None) { + if (Py_IsNone(obj)) { return std::vector>(size, std::nullopt); } auto streams = THPUtils_PySequence_to_CUDAStreamList(obj); @@ -74,7 +74,7 @@ static at::Tensor extract_tensor(PyObject* obj); static std::vector extract_tensors(PyObject* obj); static std::vector unpack_comms(PyObject* obj, size_t size) { - if (obj == Py_None) { + if (Py_IsNone(obj)) { return std::vector(); } std::vector comms; diff --git a/torch/csrc/cuda/utils.cpp b/torch/csrc/cuda/utils.cpp index 23112a8a06b8e..0024f563aab45 100644 --- a/torch/csrc/cuda/utils.cpp +++ b/torch/csrc/cuda/utils.cpp @@ -31,7 +31,7 @@ THPUtils_PySequence_to_CUDAStreamList(PyObject* obj) { reinterpret_cast(stream)->device_index), static_cast( (reinterpret_cast(stream))->device_type))); - } else if (stream == Py_None) { + } else if (Py_IsNone(stream)) { streams.emplace_back(); } else { TORCH_CHECK( diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 8ce9dc1d207dd..bb23e13a2d8e7 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -1,6 +1,8 @@ #pragma once #include +#include +#include #include #include @@ -509,6 +511,26 @@ class TORCH_API Backend : public torch::CustomClassHolder { // normal shutdown. virtual void shutdown() {} + // APIs related to memory offload + virtual void suspend() { + TORCH_CHECK( + false, + c10::str("Backend ", getBackendName(), " does not support suspend")); + } + + virtual void resume() { + TORCH_CHECK( + false, + c10::str("Backend ", getBackendName(), " does not support resume")); + } + + virtual std::unordered_map getMemoryStats() { + TORCH_CHECK( + false, + c10::str( + "Backend ", getBackendName(), " does not support getMemoryStats")); + } + protected: // Implementations of this interface need to call this to setup // appropriate logging etc. diff --git a/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp b/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp index 28647b8c50f5a..50aa63ed748d1 100644 --- a/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp +++ b/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp @@ -89,13 +89,13 @@ typename FlightRecorder::TraceIdentifier FlightRecorder:: if (!enabled_) { return TraceIdentifier{std::nullopt, std::nullopt}; } + auto traceback = + torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); + std::lock_guard guard(mutex_); if (all_pg_status_.find(pg_id) == all_pg_status_.end()) { // Current pg_status is not in FR. all_pg_status_[pg_id] = std::move(pg_status); } - auto traceback = - torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); - std::lock_guard guard(mutex_); TORCH_CHECK( reset_epoch_start_idx_.find(reset_epoch_) != diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp index b4e11b643837c..fd8112d2b9dba 100644 --- a/torch/csrc/distributed/c10d/Functional.cpp +++ b/torch/csrc/distributed/c10d/Functional.cpp @@ -4,6 +4,8 @@ #include #include #include +#include +#include #include namespace { @@ -31,7 +33,6 @@ c10d::ReduceOp to_reduce_op(const std::string& reduce_op) { at::Tensor allocate_all_gather_output( const at::Tensor& input, int64_t group_size) { - TORCH_CHECK(input.is_contiguous()); auto output_size = input.sizes().vec(); if (output_size.empty()) { output_size.push_back(group_size); @@ -46,7 +47,6 @@ at::Tensor allocate_all_gather_output( at::Tensor allocate_reduce_scatter_output( const at::Tensor& input, const int64_t group_size) { - TORCH_CHECK(input.is_contiguous()); auto output_size = input.sizes().vec(); if (output_size[0] % group_size != 0) { LOG(WARNING) << "The first dimension of the reduce_scatter input (" @@ -67,13 +67,11 @@ at::Tensor& all_reduce_( at::Tensor& input, // NOLINTNEXTLINE(performance-unnecessary-value-param) std::string reduce_op, - // NOLINTNEXTLINE(performance-unnecessary-value-param) - std::string group_name) { + c10::intrusive_ptr group) { c10d::AllreduceOptions opts; opts.reduceOp = to_reduce_op(reduce_op); std::vector inputs{input}; - auto group = c10d::resolve_process_group(group_name); auto work = group->allreduce(inputs, opts); c10d::register_work(input, work); return input; @@ -82,7 +80,7 @@ at::Tensor& all_reduce_( at::Tensor all_reduce( const at::Tensor& input, std::string reduce_op, - std::string group_name) { + c10::intrusive_ptr group) { if (input.is_complex()) { TORCH_CHECK( // TODO - ideally use 'to_reduce_op' helper but it currently errors on @@ -95,21 +93,44 @@ at::Tensor all_reduce( } auto input_real = input.is_complex() ? at::view_as_real(input) : input; auto output = input_real.clone(at::MemoryFormat::Contiguous); - auto output_ret = - all_reduce_(output, std::move(reduce_op), std::move(group_name)); + auto output_ret = all_reduce_(output, std::move(reduce_op), std::move(group)); return input.is_complex() ? at::view_as_complex(output_ret) : output_ret; } +at::Tensor& all_reduce_( + at::Tensor& input, + std::string reduce_op, + std::string group_name) { + auto group = c10d::resolve_process_group(std::move(group_name)); + return all_reduce_(input, std::move(reduce_op), std::move(group)); +} + +at::Tensor all_reduce( + const at::Tensor& input, + std::string reduce_op, + std::string group_name) { + auto group = c10d::resolve_process_group(std::move(group_name)); + return all_reduce(input, std::move(reduce_op), std::move(group)); +} + std::vector all_reduce_coalesced_( std::vector inputs, // NOLINTNEXTLINE(performance-unnecessary-value-param) std::string reduce_op, // NOLINTNEXTLINE(performance-unnecessary-value-param) std::string group_name) { + auto group = c10d::resolve_process_group(std::move(group_name)); + return all_reduce_coalesced_(inputs, std::move(reduce_op), std::move(group)); +} + +std::vector all_reduce_coalesced_( + std::vector inputs, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string reduce_op, + c10::intrusive_ptr group) { c10d::AllreduceCoalescedOptions opts; opts.reduceOp = to_reduce_op(reduce_op); - auto group = c10d::resolve_process_group(group_name); auto work = group->allreduce_coalesced(inputs, opts); for (const auto& tensor : inputs) { c10d::register_work(tensor, work); @@ -122,28 +143,33 @@ std::vector all_reduce_coalesced( std::vector inputs, std::string reduce_op, std::string group_name) { + auto group = c10d::resolve_process_group(std::move(group_name)); + return all_reduce_coalesced(inputs, std::move(reduce_op), std::move(group)); +} + +std::vector all_reduce_coalesced( + std::vector inputs, + std::string reduce_op, + c10::intrusive_ptr group) { std::vector outputs; outputs.reserve(inputs.size()); for (const auto& tensor : inputs) { outputs.push_back(tensor.clone(at::MemoryFormat::Contiguous)); } - return all_reduce_coalesced_( - outputs, std::move(reduce_op), std::move(group_name)); + return all_reduce_coalesced_(outputs, std::move(reduce_op), std::move(group)); } std::vector all_gather_into_tensor_coalesced( std::vector inputs, int64_t group_size, - // NOLINTNEXTLINE(performance-unnecessary-value-param) - std::string group_name) { + c10::intrusive_ptr group) { std::vector outputs; outputs.reserve(inputs.size()); - for (const auto& tensor : inputs) { - TORCH_CHECK(tensor.is_contiguous()); + for (auto& tensor : inputs) { + tensor = tensor.contiguous(); outputs.push_back(allocate_all_gather_output(tensor, group_size)); } - auto group = c10d::resolve_process_group(group_name); auto work = group->allgather_into_tensor_coalesced(outputs, inputs); for (const auto& tensor : outputs) { c10d::register_work(tensor, work); @@ -154,29 +180,53 @@ std::vector all_gather_into_tensor_coalesced( at::Tensor all_gather_into_tensor( const at::Tensor& input, int64_t group_size, - std::string group_name) { - TORCH_CHECK(input.is_contiguous()); + c10::intrusive_ptr group) { auto real_input = input.is_complex() ? at::view_as_real(input) : input; std::vector inputs{real_input}; - auto output = all_gather_into_tensor_coalesced( - inputs, group_size, std::move(group_name))[0]; + auto output = + all_gather_into_tensor_coalesced(inputs, group_size, std::move(group))[0]; return input.is_complex() ? at::view_as_complex(output) : output; } at::Tensor& all_gather_into_tensor_out( at::Tensor& input, int64_t group_size, - const std::string& group_name, + c10::intrusive_ptr group, at::Tensor& output) { - TORCH_CHECK(input.is_contiguous()); + auto contig_input = input.contiguous(); c10d::AllgatherOptions opts; - auto group = c10d::resolve_process_group(group_name); - auto work = group->_allgather_base(output, input, opts); + auto work = group->_allgather_base(output, contig_input, opts); c10d::register_work(output, work); return output; } +std::vector all_gather_into_tensor_coalesced( + std::vector inputs, + int64_t group_size, + std::string group_name) { + auto group = c10d::resolve_process_group(std::move(group_name)); + return all_gather_into_tensor_coalesced(inputs, group_size, std::move(group)); +} + +at::Tensor all_gather_into_tensor( + const at::Tensor& input, + int64_t group_size, + std::string group_name) { + auto group = c10d::resolve_process_group(std::move(group_name)); + return all_gather_into_tensor(input, group_size, std::move(group)); +} + +at::Tensor& all_gather_into_tensor_out( + at::Tensor& input, + int64_t group_size, + const std::string& group_name, + at::Tensor& output) { + auto group = c10d::resolve_process_group(std::move(group_name)); + return all_gather_into_tensor_out( + input, group_size, std::move(group), output); +} + std::vector reduce_scatter_tensor_coalesced( std::vector inputs, // NOLINTNEXTLINE(performance-unnecessary-value-param) @@ -184,16 +234,26 @@ std::vector reduce_scatter_tensor_coalesced( int64_t group_size, // NOLINTNEXTLINE(performance-unnecessary-value-param) std::string group_name) { + auto group = c10d::resolve_process_group(std::move(group_name)); + return reduce_scatter_tensor_coalesced( + inputs, std::move(reduce_op), group_size, std::move(group)); +} + +std::vector reduce_scatter_tensor_coalesced( + std::vector inputs, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string reduce_op, + int64_t group_size, + c10::intrusive_ptr group) { c10d::ReduceScatterOptions opts; opts.reduceOp = to_reduce_op(reduce_op); std::vector outputs; outputs.reserve(inputs.size()); - for (const auto& tensor : inputs) { - TORCH_CHECK(tensor.is_contiguous()); + for (auto& tensor : inputs) { + tensor = tensor.contiguous(); outputs.push_back(allocate_reduce_scatter_output(tensor, group_size)); } - auto group = c10d::resolve_process_group(group_name); auto work = group->reduce_scatter_tensor_coalesced(outputs, inputs, opts); for (const auto& tensor : outputs) { c10d::register_work(tensor, work); @@ -206,13 +266,11 @@ static std::vector reduce_scatter_tensor_coalesced_out( // NOLINTNEXTLINE(performance-unnecessary-value-param) std::string reduce_op, int64_t group_size, - // NOLINTNEXTLINE(performance-unnecessary-value-param) - std::string group_name, + c10::intrusive_ptr group, std::vector& outputs) { c10d::ReduceScatterOptions opts; opts.reduceOp = to_reduce_op(reduce_op); - auto group = c10d::resolve_process_group(std::move(group_name)); auto work = group->reduce_scatter_tensor_coalesced(outputs, inputs, opts); for (const auto& tensor : outputs) { c10d::register_work(tensor, work); @@ -225,16 +283,25 @@ at::Tensor reduce_scatter_tensor( std::string reduce_op, int64_t group_size, std::string group_name) { - TORCH_CHECK(input.is_contiguous()); + auto group = c10d::resolve_process_group(std::move(group_name)); + return reduce_scatter_tensor( + input, std::move(reduce_op), group_size, std::move(group)); +} + +at::Tensor reduce_scatter_tensor( + const at::Tensor& input, + std::string reduce_op, + int64_t group_size, + c10::intrusive_ptr group) { if (input.is_complex()) { auto real_input = at::view_as_real(input); - std::vector inputs{real_input}; + std::vector inputs{std::move(real_input)}; return at::view_as_complex(reduce_scatter_tensor_coalesced( - inputs, std::move(reduce_op), group_size, std::move(group_name))[0]); + inputs, std::move(reduce_op), group_size, std::move(group))[0]); } std::vector inputs{input}; return reduce_scatter_tensor_coalesced( - inputs, std::move(reduce_op), group_size, std::move(group_name))[0]; + inputs, std::move(reduce_op), group_size, std::move(group))[0]; } at::Tensor reduce_scatter_tensor_out( @@ -243,10 +310,21 @@ at::Tensor reduce_scatter_tensor_out( int64_t group_size, std::string group_name, at::Tensor& output) { - TORCH_CHECK(input.is_contiguous()); - if (input.is_complex()) { + auto group = c10d::resolve_process_group(std::move(group_name)); + return reduce_scatter_tensor_out( + input, std::move(reduce_op), group_size, std::move(group), output); +} + +at::Tensor reduce_scatter_tensor_out( + const at::Tensor& input, + std::string reduce_op, + int64_t group_size, + c10::intrusive_ptr group, + at::Tensor& output) { + auto contig_input = input.contiguous(); + if (contig_input.is_complex()) { TORCH_CHECK(output.is_complex()) - auto real_input = at::view_as_real(input); + auto real_input = at::view_as_real(contig_input); std::vector inputs{std::move(real_input)}; auto real_output = at::view_as_real(output); std::vector outputs{std::move(real_output)}; @@ -254,17 +332,13 @@ at::Tensor reduce_scatter_tensor_out( inputs, std::move(reduce_op), group_size, - std::move(group_name), + std::move(group), outputs)[0]); } - std::vector inputs{std::move(input)}; + std::vector inputs{std::move(contig_input)}; std::vector outputs{std::move(output)}; return reduce_scatter_tensor_coalesced_out( - inputs, - std::move(reduce_op), - group_size, - std::move(group_name), - outputs)[0]; + inputs, std::move(reduce_op), group_size, std::move(group), outputs)[0]; } at::Tensor all_to_all_single( @@ -273,6 +347,16 @@ at::Tensor all_to_all_single( c10::SymIntArrayRef _input_split_sizes, // NOLINTNEXTLINE(performance-unnecessary-value-param) std::string group_name) { + auto group = c10d::resolve_process_group(std::move(group_name)); + return all_to_all_single( + input, _output_split_sizes, _input_split_sizes, std::move(group)); +} + +at::Tensor all_to_all_single( + const at::Tensor& input, + c10::SymIntArrayRef _output_split_sizes, + c10::SymIntArrayRef _input_split_sizes, + c10::intrusive_ptr group) { std::vector output_split_sizes; std::vector input_split_sizes; output_split_sizes.reserve(_output_split_sizes.size()); @@ -284,17 +368,16 @@ at::Tensor all_to_all_single( input_split_sizes.emplace_back(size.expect_int()); } - TORCH_CHECK(input.is_contiguous()); - std::vector output_sizes = input.sizes().vec(); + auto contig_input = input.contiguous(); + std::vector output_sizes = contig_input.sizes().vec(); output_sizes[0] = std::accumulate( output_split_sizes.begin(), output_split_sizes.end(), int64_t(0)); - auto output = input.new_empty(output_sizes); + auto output = contig_input.new_empty(output_sizes); - auto group = c10d::resolve_process_group(std::move(group_name)); auto work = group->alltoall_base( output, // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) - const_cast(input), + const_cast(contig_input), output_split_sizes, input_split_sizes); c10d::register_work(output, work); @@ -303,12 +386,19 @@ at::Tensor all_to_all_single( // NOLINTNEXTLINE(performance-unnecessary-value-param) at::Tensor& broadcast_(at::Tensor& input, int64_t src, std::string group_name) { + auto group = c10d::resolve_process_group(std::move(group_name)); + return broadcast_(input, src, std::move(group)); +} + +at::Tensor& broadcast_( + at::Tensor& input, + int64_t src, + c10::intrusive_ptr group) { c10d::BroadcastOptions opts; opts.rootRank = src; auto input_real = input.is_complex() ? at::view_as_real(input) : input; std::vector inputs{input_real}; - auto group = c10d::resolve_process_group(group_name); auto work = group->broadcast(inputs, opts); c10d::register_work(input, work); return input; @@ -318,8 +408,16 @@ at::Tensor broadcast( const at::Tensor& input, int64_t src, std::string group_name) { + auto group = c10d::resolve_process_group(std::move(group_name)); + return broadcast(input, src, std::move(group)); +} + +at::Tensor broadcast( + const at::Tensor& input, + int64_t src, + c10::intrusive_ptr group) { auto output = input.clone(at::MemoryFormat::Contiguous); - return broadcast_(output, src, std::move(group_name)); + return broadcast_(output, src, std::move(group)); } at::Tensor isend( @@ -416,73 +514,190 @@ std::vector batch_p2p_ops( } // namespace c10d +namespace { + +c10::intrusive_ptr get_process_group( + const c10::IValue& group_name, + const char* func_name) { + if (group_name.isString()) { + return c10d::resolve_process_group(group_name.toStringRef()); + } else if (group_name.isCapsule()) { + return c10::static_intrusive_pointer_cast( + group_name.toCapsule()); + } else { + TORCH_CHECK( + false, + func_name, + "(): argument 'group_name' must be either a string (group name) " + "or a ProcessGroup object, but got ", + group_name.type()->str()); + } +} + +// all_to_all_single_dispatch is kept as a named function because it is +// referenced via decltype inside the AllToAllSingle autograd class. +at::Tensor all_to_all_single_dispatch( + const at::Tensor& input, + c10::SymIntArrayRef output_split_sizes, + c10::SymIntArrayRef input_split_sizes, + const c10::IValue& group_name) { + return c10d::all_to_all_single( + input, + output_split_sizes, + input_split_sizes, + get_process_group(group_name, "all_to_all_single")); +} + +} // namespace + TORCH_LIBRARY(_c10d_functional, m) { m.def( - "all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor", + "all_reduce(Tensor input, str reduce_op, Any group_name) -> Tensor", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, c10d::all_reduce), + c10::DispatchKey::CompositeExplicitAutograd, + [](const at::Tensor& input, + std::string reduce_op, + const c10::IValue& group) { + return c10d::all_reduce( + input, + std::move(reduce_op), + get_process_group(group, "all_reduce")); + }), {at::Tag::pt2_compliant_tag}); m.def( - "all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)", + "all_reduce_(Tensor(a!) input, str reduce_op, Any group_name) -> Tensor(a!)", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, c10d::all_reduce_), + c10::DispatchKey::CompositeExplicitAutograd, + [](at::Tensor& input, std::string reduce_op, const c10::IValue& group) + -> at::Tensor& { + return c10d::all_reduce_( + input, + std::move(reduce_op), + get_process_group(group, "all_reduce_")); + }), {at::Tag::pt2_compliant_tag}); m.def( - "all_reduce_coalesced(Tensor[] inputs, str reduce_op, str group_name) -> Tensor[]", + "all_reduce_coalesced(Tensor[] inputs, str reduce_op, Any group_name) -> Tensor[]", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, - c10d::all_reduce_coalesced), + [](std::vector inputs, + std::string reduce_op, + const c10::IValue& group) { + return c10d::all_reduce_coalesced( + inputs, + std::move(reduce_op), + get_process_group(group, "all_reduce_coalesced")); + }), {at::Tag::pt2_compliant_tag}); m.def( - "all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, str group_name) -> Tensor[](a!)", + "all_reduce_coalesced_(Tensor[](a!) inputs, str reduce_op, Any group_name) -> Tensor[](a!)", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, - c10d::all_reduce_coalesced_), + [](std::vector inputs, + std::string reduce_op, + const c10::IValue& group) { + return c10d::all_reduce_coalesced_( + inputs, + std::move(reduce_op), + get_process_group(group, "all_reduce_coalesced_")); + }), {at::Tag::pt2_compliant_tag}); m.def( - "all_gather_into_tensor_out(Tensor input, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)", + "all_gather_into_tensor_out(Tensor input, int group_size, Any group_name, *, Tensor(a!) out) -> Tensor(a!)", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, - c10d::all_gather_into_tensor_out), + [](at::Tensor& input, + int64_t group_size, + const c10::IValue& group, + at::Tensor& output) -> at::Tensor& { + return c10d::all_gather_into_tensor_out( + input, + group_size, + get_process_group(group, "all_gather_into_tensor_out"), + output); + }), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( - "all_gather_into_tensor(Tensor input, int group_size, str group_name) -> Tensor", + "all_gather_into_tensor(Tensor input, int group_size, Any group_name) -> Tensor", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, - c10d::all_gather_into_tensor), + [](const at::Tensor& input, + int64_t group_size, + const c10::IValue& group) { + return c10d::all_gather_into_tensor( + input, + group_size, + get_process_group(group, "all_gather_into_tensor")); + }), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( - "all_gather_into_tensor_coalesced(Tensor[] inputs, int group_size, str group_name) -> Tensor[]", + "all_gather_into_tensor_coalesced(Tensor[] inputs, int group_size, Any group_name) -> Tensor[]", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, - c10d::all_gather_into_tensor_coalesced), + [](std::vector inputs, + int64_t group_size, + const c10::IValue& group) { + return c10d::all_gather_into_tensor_coalesced( + inputs, + group_size, + get_process_group(group, "all_gather_into_tensor_coalesced")); + }), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( - "reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, str group_name) -> Tensor", + "reduce_scatter_tensor(Tensor input, str reduce_op, int group_size, Any group_name) -> Tensor", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, - c10d::reduce_scatter_tensor), + [](const at::Tensor& input, + std::string reduce_op, + int64_t group_size, + const c10::IValue& group) { + return c10d::reduce_scatter_tensor( + input, + std::move(reduce_op), + group_size, + get_process_group(group, "reduce_scatter_tensor")); + }), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( - "reduce_scatter_tensor_out(Tensor input, str reduce_op, int group_size, str group_name, *, Tensor(a!) out) -> Tensor(a!)", + "reduce_scatter_tensor_out(Tensor input, str reduce_op, int group_size, Any group_name, *, Tensor(a!) out) -> Tensor(a!)", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, - c10d::reduce_scatter_tensor_out), + [](const at::Tensor& input, + std::string reduce_op, + int64_t group_size, + const c10::IValue& group, + at::Tensor& output) { + return c10d::reduce_scatter_tensor_out( + input, + std::move(reduce_op), + group_size, + get_process_group(group, "reduce_scatter_tensor_out"), + output); + }), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( - "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, str group_name) -> Tensor[]", + "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduce_op, int group_size, Any group_name) -> Tensor[]", torch::dispatch( c10::DispatchKey::CompositeExplicitAutograd, - c10d::reduce_scatter_tensor_coalesced), + [](std::vector inputs, + std::string reduce_op, + int64_t group_size, + const c10::IValue& group) { + return c10d::reduce_scatter_tensor_coalesced( + inputs, + std::move(reduce_op), + group_size, + get_process_group(group, "reduce_scatter_tensor_coalesced")); + }), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( @@ -490,21 +705,31 @@ TORCH_LIBRARY(_c10d_functional, m) { "Tensor input, " "SymInt[] output_split_sizes, " "SymInt[] input_split_sizes, " - "str group_name) -> Tensor", + "Any group_name) -> Tensor", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, c10d::all_to_all_single), + c10::DispatchKey::CompositeExplicitAutograd, + all_to_all_single_dispatch), {at::Tag::pt2_compliant_tag, at::Tag::needs_contiguous_strides}); m.def( - "broadcast(Tensor input, int src, str group_name) -> Tensor", + "broadcast(Tensor input, int src, Any group_name) -> Tensor", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, c10d::broadcast), + c10::DispatchKey::CompositeExplicitAutograd, + [](const at::Tensor& input, int64_t src, const c10::IValue& group) { + return c10d::broadcast( + input, src, get_process_group(group, "broadcast")); + }), {at::Tag::pt2_compliant_tag}); m.def( - "broadcast_(Tensor(a!) input, int src, str group_name) -> Tensor(a!)", + "broadcast_(Tensor(a!) input, int src, Any group_name) -> Tensor(a!)", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, c10d::broadcast_), + c10::DispatchKey::CompositeExplicitAutograd, + [](at::Tensor& input, int64_t src, const c10::IValue& group) + -> at::Tensor& { + return c10d::broadcast_( + input, src, get_process_group(group, "broadcast_")); + }), {at::Tag::pt2_compliant_tag}); m.def( @@ -646,8 +871,7 @@ class AllToAllSingle : public torch::autograd::Function { at::SymIntArrayRef output_split_sizes, // NOLINTNEXTLINE(performance-unnecessary-value-param) at::SymIntArrayRef input_split_sizes, - // NOLINTNEXTLINE(performance-unnecessary-value-param) - std::string group_name) { + const c10::IValue& group_name) { // swap sizes for backwards pass ctx->saved_data["output_split_sizes"] = input_split_sizes.vec(); ctx->saved_data["input_split_sizes"] = output_split_sizes.vec(); @@ -655,7 +879,7 @@ class AllToAllSingle : public torch::autograd::Function { return c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::all_to_all_single", "") - .typed() + .typed() .call(input, output_split_sizes, input_split_sizes, group_name); } @@ -666,7 +890,7 @@ class AllToAllSingle : public torch::autograd::Function { ctx->saved_data["output_split_sizes"].toSymIntVector(); std::vector input_split_sizes = ctx->saved_data["input_split_sizes"].toSymIntVector(); - const std::string& group_name = ctx->saved_data["group_name"].toStringRef(); + auto group_name = ctx->saved_data["group_name"]; DCHECK(grad_out_list.size() == 1); auto grad_out = grad_out_list[0].contiguous(); @@ -674,7 +898,7 @@ class AllToAllSingle : public torch::autograd::Function { auto out = c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::all_to_all_single", "") - .typed() + .typed() .call(grad_out, output_split_sizes, input_split_sizes, group_name); // do an explicit wait to avoid cuda stream issues @@ -692,7 +916,7 @@ at::Tensor all_to_all_single_autograd( const at::Tensor& input, at::SymIntArrayRef output_split_sizes, at::SymIntArrayRef input_split_sizes, - const std::string& group_name) { + const c10::IValue& group_name) { return AllToAllSingle::apply( input, output_split_sizes, input_split_sizes, group_name); } @@ -705,7 +929,7 @@ class ReduceScatterTensor const at::Tensor& input, const std::string& reduce_op, int64_t group_size, - const std::string& group_name) { + const c10::IValue& group_name) { TORCH_CHECK(reduce_op == "sum", "Only sum reduce op is supported"); ctx->saved_data["group_size"] = group_size; @@ -713,7 +937,8 @@ class ReduceScatterTensor return c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "") - .typed() + .typed() .call(input, reduce_op, group_size, group_name); } @@ -721,7 +946,7 @@ class ReduceScatterTensor torch::autograd::AutogradContext* ctx, const torch::autograd::variable_list& grad_out_list) { const int64_t group_size = ctx->saved_data["group_size"].toInt(); - const std::string& group_name = ctx->saved_data["group_name"].toStringRef(); + auto group_name = ctx->saved_data["group_name"]; DCHECK(grad_out_list.size() == 1); const auto& grad_out = grad_out_list[0]; @@ -729,7 +954,7 @@ class ReduceScatterTensor auto out = c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "") - .typed() + .typed() .call(grad_out, group_size, group_name); // do an explicit wait to avoid cuda stream issues @@ -752,7 +977,7 @@ at::Tensor reduce_scatter_tensor_autograd( const at::Tensor& input, const std::string& reduce_op, int64_t group_size, - const std::string& group_name) { + const c10::IValue& group_name) { return ReduceScatterTensor::apply(input, reduce_op, group_size, group_name); } @@ -763,13 +988,13 @@ class AllGatherIntoTensor torch::autograd::AutogradContext* ctx, const at::Tensor& input, int64_t group_size, - const std::string& group_name) { + const c10::IValue& group_name) { ctx->saved_data["group_size"] = group_size; ctx->saved_data["group_name"] = group_name; return c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::all_gather_into_tensor", "") - .typed() + .typed() .call(input, group_size, group_name); } @@ -777,7 +1002,7 @@ class AllGatherIntoTensor torch::autograd::AutogradContext* ctx, const torch::autograd::variable_list& grad_out_list) { const int64_t group_size = ctx->saved_data["group_size"].toInt(); - const std::string& group_name = ctx->saved_data["group_name"].toStringRef(); + auto group_name = ctx->saved_data["group_name"]; DCHECK(grad_out_list.size() == 1); const auto& grad_out = grad_out_list[0]; @@ -785,7 +1010,8 @@ class AllGatherIntoTensor auto out = c10::Dispatcher::singleton() .findSchemaOrThrow("_c10d_functional::reduce_scatter_tensor", "") - .typed() + .typed() .call(grad_out, "sum", group_size, group_name); // do an explicit wait to avoid cuda stream issues @@ -806,7 +1032,7 @@ class AllGatherIntoTensor at::Tensor all_gather_into_tensor_autograd( const at::Tensor& input, int64_t group_size, - const std::string& group_name) { + const c10::IValue& group_name) { return AllGatherIntoTensor::apply(input, group_size, group_name); } @@ -818,15 +1044,23 @@ TORCH_LIBRARY(_c10d_functional_autograd, m) { "Tensor input, " "SymInt[] output_split_sizes, " "SymInt[] input_split_sizes, " - "str group_name) -> Tensor", - torch::dispatch(c10::DispatchKey::Autograd, ::all_to_all_single_autograd), + "Any group_name) -> Tensor", + torch::dispatch( + c10::DispatchKey::Autograd, + [](const at::Tensor& input, + c10::SymIntArrayRef output_split_sizes, + c10::SymIntArrayRef input_split_sizes, + const c10::IValue& group) { + return all_to_all_single_autograd( + input, output_split_sizes, input_split_sizes, group); + }), {at::Tag::pt2_compliant_tag}); m.def( "reduce_scatter_tensor(" "Tensor input, " "str reduce_op, " "int group_size, " - "str group_name) -> Tensor", + "Any group_name) -> Tensor", torch::dispatch( c10::DispatchKey::Autograd, ::reduce_scatter_tensor_autograd), {at::Tag::pt2_compliant_tag}); @@ -834,7 +1068,7 @@ TORCH_LIBRARY(_c10d_functional_autograd, m) { "all_gather_into_tensor(" "Tensor input, " "int group_size, " - "str group_name) -> Tensor", + "Any group_name) -> Tensor", torch::dispatch( c10::DispatchKey::Autograd, ::all_gather_into_tensor_autograd), {at::Tag::pt2_compliant_tag}); @@ -847,8 +1081,7 @@ at::Tensor shard_dim_alltoall( const at::Tensor& input, int64_t gather_dim, int64_t shard_dim, - const std::string& group_name) { - auto group = c10d::resolve_process_group(group_name); + c10::intrusive_ptr group) { auto group_size = group->getSize(); std::vector input_sizes = input.sizes().vec(); std::vector output_sizes = input.sizes().vec(); @@ -890,13 +1123,24 @@ at::Tensor shard_dim_alltoall( return input.is_complex() ? at::view_as_complex(output).view(output_sizes) : output.view(output_sizes); } + } // namespace // DTensor comm op registry TORCH_LIBRARY(_dtensor, m) { m.def( - "shard_dim_alltoall(Tensor input, int gather_dim, int shard_dim, str group_name) -> Tensor", + "shard_dim_alltoall(Tensor input, int gather_dim, int shard_dim, Any group_name) -> Tensor", torch::dispatch( - c10::DispatchKey::CompositeExplicitAutograd, ::shard_dim_alltoall), + c10::DispatchKey::CompositeExplicitAutograd, + [](const at::Tensor& input, + int64_t gather_dim, + int64_t shard_dim, + const c10::IValue& group) { + return shard_dim_alltoall( + input, + gather_dim, + shard_dim, + get_process_group(group, "shard_dim_alltoall")); + }), {at::Tag::pt2_compliant_tag}); } diff --git a/torch/csrc/distributed/c10d/Functional.hpp b/torch/csrc/distributed/c10d/Functional.hpp index fe4245f7f6885..72bac44406b1b 100644 --- a/torch/csrc/distributed/c10d/Functional.hpp +++ b/torch/csrc/distributed/c10d/Functional.hpp @@ -9,11 +9,27 @@ C10_EXPORT at::Tensor& all_reduce_( std::string reduce_op, std::string group_name); +C10_EXPORT at::Tensor& all_reduce_( + at::Tensor& input, + std::string reduce_op, + c10::intrusive_ptr group); + C10_EXPORT at::Tensor all_reduce( const at::Tensor& input, std::string reduce_op, std::string group_name); +C10_EXPORT at::Tensor all_reduce( + const at::Tensor& input, + std::string reduce_op, + c10::intrusive_ptr group); + +C10_EXPORT std::vector all_reduce_coalesced_( + std::vector inputs, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string reduce_op, + c10::intrusive_ptr group); + C10_EXPORT std::vector all_reduce_coalesced_( std::vector inputs, // NOLINTNEXTLINE(performance-unnecessary-value-param) @@ -21,6 +37,11 @@ C10_EXPORT std::vector all_reduce_coalesced_( // NOLINTNEXTLINE(performance-unnecessary-value-param) std::string group_name); +C10_EXPORT std::vector all_reduce_coalesced( + std::vector inputs, + std::string reduce_op, + c10::intrusive_ptr group); + C10_EXPORT std::vector all_reduce_coalesced( // NOLINTNEXTLINE(performance-unnecessary-value-param) std::vector inputs, @@ -33,17 +54,40 @@ C10_EXPORT std::vector all_gather_into_tensor_coalesced( // NOLINTNEXTLINE(performance-unnecessary-value-param) std::string group_name); +C10_EXPORT std::vector all_gather_into_tensor_coalesced( + std::vector inputs, + int64_t group_size, + c10::intrusive_ptr group); + C10_EXPORT at::Tensor all_gather_into_tensor( const at::Tensor& input, int64_t group_size, std::string group_name); +C10_EXPORT at::Tensor all_gather_into_tensor( + const at::Tensor& input, + int64_t group_size, + c10::intrusive_ptr group); + C10_EXPORT at::Tensor& all_gather_into_tensor_out( at::Tensor& input, int64_t group_size, const std::string& group_name, at::Tensor& output); +C10_EXPORT at::Tensor& all_gather_into_tensor_out( + at::Tensor& input, + int64_t group_size, + c10::intrusive_ptr group, + at::Tensor& output); + +C10_EXPORT std::vector reduce_scatter_tensor_coalesced( + std::vector inputs, + // NOLINTNEXTLINE(performance-unnecessary-value-param) + std::string reduce_op, + int64_t group_size, + c10::intrusive_ptr group); + C10_EXPORT std::vector reduce_scatter_tensor_coalesced( std::vector inputs, // NOLINTNEXTLINE(performance-unnecessary-value-param) @@ -52,12 +96,25 @@ C10_EXPORT std::vector reduce_scatter_tensor_coalesced( // NOLINTNEXTLINE(performance-unnecessary-value-param) std::string group_name); +C10_EXPORT at::Tensor reduce_scatter_tensor( + const at::Tensor& input, + std::string reduce_op, + int64_t group_size, + c10::intrusive_ptr group); + C10_EXPORT at::Tensor reduce_scatter_tensor( const at::Tensor& input, std::string reduce_op, int64_t group_size, std::string group_name); +C10_EXPORT at::Tensor reduce_scatter_tensor_out( + const at::Tensor& input, + std::string reduce_op, + int64_t group_size, + c10::intrusive_ptr group, + at::Tensor& output); + C10_EXPORT at::Tensor reduce_scatter_tensor_out( const at::Tensor& input, std::string reduce_op, @@ -72,11 +129,27 @@ C10_EXPORT at::Tensor all_to_all_single( // NOLINTNEXTLINE(performance-unnecessary-value-param) std::string group_name); +C10_EXPORT at::Tensor all_to_all_single( + const at::Tensor& input, + at::SymIntArrayRef output_split_sizes, + at::SymIntArrayRef input_split_sizes, + c10::intrusive_ptr group); + +C10_EXPORT at::Tensor& broadcast_( + at::Tensor& input, + int64_t src, + c10::intrusive_ptr group); + C10_EXPORT at::Tensor& broadcast_( at::Tensor& input, int64_t src, std::string group_name); +C10_EXPORT at::Tensor broadcast( + const at::Tensor& input, + int64_t src, + c10::intrusive_ptr group); + C10_EXPORT at::Tensor broadcast( const at::Tensor& input, int64_t src, diff --git a/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp b/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp index 25448dbc9f690..aefb79cfa3932 100644 --- a/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp +++ b/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp @@ -195,22 +195,25 @@ std::shared_ptr<::gloo::transport::Device> makeGlooDevice( transportName.value(), interfaceName, hostName, lazyInit); } -#ifdef __linux__ +#if defined(__linux__) + return GlooDeviceRegistry()->Create( "LINUX", interfaceName, hostName, lazyInit); -#endif -#ifdef __APPLE__ +#elif defined(__APPLE__) + return GlooDeviceRegistry()->Create( "APPLE", interfaceName, hostName, lazyInit); -#endif -#ifdef _WIN32 +#elif defined(_WIN32) + return GlooDeviceRegistry()->Create( "WIN32", interfaceName, hostName, lazyInit); -#endif +#else return nullptr; + +#endif } } // anonymous namespace diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index e1996ad376b16..182aa1975e891 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -595,6 +595,54 @@ std::string NCCLComm::repr() const { return c10::str((void*)ncclComm_); } +void NCCLComm::suspend() { +#ifdef NCCL_HAS_COMM_OFFLOAD + LockType lock(mutex_); + at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_); + auto comm = getNcclComm(); + C10D_NCCL_CHECK(ncclCommSuspend(comm, NCCL_SUSPEND_MEM), std::nullopt); +#else + TORCH_CHECK(false, "suspend() requires NCCL 2.29.7 or later"); +#endif +} + +void NCCLComm::resume() { +#ifdef NCCL_HAS_COMM_OFFLOAD + LockType lock(mutex_); + at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_); + auto comm = getNcclComm(); + C10D_NCCL_CHECK(ncclCommResume(comm), std::nullopt); +#else + TORCH_CHECK(false, "resume() requires NCCL 2.29.7 or later"); +#endif +} + +std::unordered_map NCCLComm::getMemoryStats() { +#ifdef NCCL_HAS_COMM_OFFLOAD + LockType lock(mutex_); + at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_); + auto comm = getNcclComm(); + uint64_t suspend, suspended, persist, total; + C10D_NCCL_CHECK( + ncclCommMemStats(comm, ncclStatGpuMemSuspend, &suspend), std::nullopt); + C10D_NCCL_CHECK( + ncclCommMemStats(comm, ncclStatGpuMemSuspended, &suspended), + std::nullopt); + C10D_NCCL_CHECK( + ncclCommMemStats(comm, ncclStatGpuMemPersist, &persist), std::nullopt); + C10D_NCCL_CHECK( + ncclCommMemStats(comm, ncclStatGpuMemTotal, &total), std::nullopt); + return { + {"suspend", suspend}, + {"suspended", suspended}, + {"persist", persist}, + {"total", total}, + }; +#else + TORCH_CHECK(false, "getMemoryStats() requires NCCL 2.29.7 or later"); +#endif +} + #if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP) std::unordered_map NCCLComm::ncclCommDump() { std::unordered_map dump; diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 4b0f38ce70ff1..f6bd900dc6dc2 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -94,6 +94,10 @@ static_assert( #define NCCL_HAS_COMM_SHRINK #endif +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 29, 7) +#define NCCL_HAS_COMM_OFFLOAD +#endif + // Macro to throw on a non-successful NCCL return value. #define C10D_NCCL_CHECK(cmd, failureReason) \ do { \ @@ -376,6 +380,13 @@ class NCCLComm { std::string repr() const; + // APIs related to memory offload (require NCCL 2.29.7+ at runtime) + void suspend(); + + void resume(); + + std::unordered_map getMemoryStats(); + friend class ProcessGroupNCCL; protected: diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index 80e3bd83e2569..bdd7a0d442706 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -56,7 +56,6 @@ std::string opTypeToString(OpType opType) { default: TORCH_INTERNAL_ASSERT(false, "Unknown op type!"); } - return "UNKNOWN"; } bool isP2POp(OpType opType, bool batchP2P /*= false*/) { diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 8c4a657fd7eed..95204e583caed 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -793,7 +793,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { virtual c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) { - static at::Tensor tensor; + at::Tensor tensor; // TODO: if nccl was specified then use it auto device = opts.device; if (device.has_value()) { diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 9f60a1f0cbf26..c81516b10bad7 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -2468,6 +2468,242 @@ c10::intrusive_ptr ProcessGroupGloo::alltoall_base( return work; } +namespace { + +class AsyncAlltoallListWork : public ProcessGroupGloo::AsyncWork { + public: + AsyncAlltoallListWork( + std::shared_ptr context, + std::vector& outputTensors, + std::vector& inputTensors, + uint32_t tag, + uint64_t seq, + std::chrono::milliseconds timeout) + : ProcessGroupGloo::AsyncWork( + std::move(context), + {outputTensors}, + OpType::ALLTOALL, + seq, + timeout, + "gloo:all_to_all", + inputTensors), + outputTensors(outputTensors), + inputTensors(inputTensors), + tag(tag) {} + + std::vector outputTensors; + std::vector inputTensors; + const uint32_t tag; + + void alltoall( + std::vector& outputTensors, + std::vector& inputTensors) { + const auto scalarType = inputTensors[0].scalar_type(); + gloo::AlltoallOptions opts(context_); + opts.setTag(tag); + opts.setTimeout(getTimeout()); + + // Flatten input tensors into a single buffer + at::Tensor flatInputTensor = flattenDenseTensors(inputTensors); + GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor); + + // Allocate flat output tensor with same total size + at::Tensor flatOutputTensor = newLikeFlat(outputTensors); + GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor); + + // Perform the all-to-all operation + gloo::alltoall(opts); + + // Unflatten output into individual tensors + for (const auto i : c10::irange(outputTensors.size())) { + outputTensors[i].copy_(flatOutputTensor[static_cast(i)]); + } + } + + const std::vector getInputTensors() override { + return inputTensors; + } + + const std::vector getOutputTensors() override { + return outputTensors; + } + + void run() override { + alltoall(outputTensors, inputTensors); + } +}; + +class AsyncAlltoallListCUDAWork : public AsyncAlltoallListWork { + public: + AsyncAlltoallListCUDAWork( + const std::shared_ptr& context, + std::vector& outputTensors, + std::vector& inputTensors, + uint32_t tag, + uint64_t seq, + std::chrono::milliseconds timeout) + : AsyncAlltoallListWork( + context, + outputTensors, + inputTensors, + tag, + seq, + timeout) { + initializeStreamsEvents(inputTensors, inputStreams, inputEvents); + initializeStreamsEvents(outputTensors, outputStreams, outputEvents); + + // Kick off copy from CUDA tensors to pinned CPU tensors. + tmpInputs.reserve(inputTensors.size()); + c10::OptionalStreamGuard guard; + for (const auto i : c10::irange(inputTensors.size())) { + guard.reset_stream(inputStreams[i]); + tmpInputs.push_back( + pinnedLike(inputTensors[i]).copy_(inputTensors[i], true)); + } + + tmpOutputs.reserve(outputTensors.size()); + for (const auto i : c10::irange(outputTensors.size())) { + guard.reset_stream(outputStreams[i]); + tmpOutputs.push_back(pinnedLike(outputTensors[i])); + } + } + + void run() override { + // Synchronize with copy operations. + for (const auto i : c10::irange(inputTensors.size())) { + inputStreams[i].synchronize(); + } + for (const auto i : c10::irange(outputTensors.size())) { + outputStreams[i].synchronize(); + } + + // Run alltoall on host side tensors. + alltoall(tmpOutputs, tmpInputs); + + // Kick off copy back to the CUDA tensors. + c10::OptionalStreamGuard guard; + for (const auto i : c10::irange(outputTensors.size())) { + guard.reset_stream(outputStreams[i]); + outputTensors[i].copy_(tmpOutputs[i], /* non_blocking */ true); + outputEvents[i].record(outputStreams[i]); + } + } + + void synchronize() override { + // Synchronize with the copy back to CUDA tensors. + for (const auto i : c10::irange(outputTensors.size())) { + c10::Device device = outputTensors[i].device(); + outputEvents[i].block( + c10::impl::VirtualGuardImpl(device.type()).getStream(device)); + } + } + + std::vector tmpInputs; + std::vector inputStreams; + std::vector inputEvents; + + std::vector tmpOutputs; + std::vector outputStreams; + std::vector outputEvents; +}; + +} // namespace + +c10::intrusive_ptr ProcessGroupGloo::alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& opts) { + static auto invalidArgument = [](const std::string& msg) { + TORCH_CHECK(false, "ProcessGroupGloo::alltoall: " + msg); + }; + + // Validate input and output tensor lists + if (inputTensors.size() != static_cast(getSize())) { + invalidArgument( + "input tensor list size " + std::to_string(inputTensors.size()) + + " does not match world size " + std::to_string(getSize())); + } + + if (outputTensors.size() != static_cast(getSize())) { + invalidArgument( + "output tensor list size " + std::to_string(outputTensors.size()) + + " does not match world size " + std::to_string(getSize())); + } + + assertDense(invalidArgument, inputTensors); + assertDense(invalidArgument, outputTensors); + + // Check that all tensors are on the same device + assertSameDevice(invalidArgument, inputTensors); + assertSameDevice(invalidArgument, outputTensors); + + // Check that all input tensors have the same type and size + const auto& options = inputTensors[0].options(); + const auto& sizes = inputTensors[0].sizes(); + assertTypeAndSizesMatch(invalidArgument, inputTensors, options, sizes); + + // Check that all output tensors have the same type and size + const auto& outputOptions = outputTensors[0].options(); + const auto& outputSizes = outputTensors[0].sizes(); + assertTypeAndSizesMatch( + invalidArgument, outputTensors, outputOptions, outputSizes); + + // Check input and output tensors have compatible types + if (!options.type_equal(outputOptions)) { + invalidArgument("input and output tensors must have the same type"); + } + + // Check input and output tensors have compatible sizes + if (sizes != outputSizes) { + invalidArgument("input and output tensors must have the same size"); + } + + // Check device type + const auto& device = inputTensors[0].device(); + TORCH_CHECK( + outputTensors[0].device() == device, + "input and output tensors must be on the same device"); + + switch (device.type()) { + case at::kCPU: + break; + case at::kCUDA: + // If the user gave us a CUDA tensor then CUDA must be loaded. + TORCH_INTERNAL_ASSERT(at::hasCUDA()); + break; + default: + invalidArgument(c10::str("unsupported device type ", device.type())); + } + + c10::intrusive_ptr work; + auto tag = nextTag(); + auto context = getContext(tag); + ++seq_; + + if (device.type() == at::kCPU) { + work = c10::make_intrusive( + std::move(context), + outputTensors, + inputTensors, + tag, + seq_, + opts.timeout); + } else if (device.type() == at::kCUDA) { + work = c10::make_intrusive( + std::move(context), + outputTensors, + inputTensors, + tag, + seq_, + opts.timeout); + } else { + TORCH_CHECK(false, "Invalid backend"); + } + + enqueue(work); + return work; +} + static at::Tensor& checkSingleTensor(std::vector& tensors) { if (tensors.size() != 1) { TORCH_CHECK(false, "ProcessGroupGloo::send takes a single tensor"); diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index f0b37ffa5ca40..7dba096df1bbd 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -401,6 +401,11 @@ class TORCH_API ProcessGroupGloo : public Backend { std::vector& inputCounts, const AllToAllOptions& opts = AllToAllOptions()) override; + c10::intrusive_ptr alltoall( + std::vector& outputTensors, + std::vector& inputTensors, + const AllToAllOptions& opts = AllToAllOptions()) override; + c10::intrusive_ptr send( std::vector& tensors, int dstRank, diff --git a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp index f7a5f181f99f5..c51066f2a0b0a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupMPI.cpp @@ -67,7 +67,15 @@ bool cudaAwareMpiCheck() { } else { return false; } -#else // !defined(MPIX_CUDA_AWARE_SUPPORT) +// Recognize that Cray MPICH is CUDA-aware (used on Cray/HPE supercomputers) +#elif defined(MPIX_GPU_SUPPORT_CUDA) + const char* cray_gpu_support = std::getenv("MPICH_GPU_SUPPORT_ENABLED"); + if (cray_gpu_support != nullptr && std::string(cray_gpu_support) == "1") { + return true; + } else { + return false; + } +#else // !defined(MPIX_CUDA_AWARE_SUPPORT) && !defined(MPIX_GPU_SUPPORT_CUDA) return false; #endif // MPIX_CUDA_AWARE_SUPPORT } diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 25901dfe10661..b906166adcf31 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -11,8 +11,10 @@ #include #include +#include #include #include +#include #include #include #include @@ -75,7 +77,7 @@ ncclRedOpRAII unpackPreMulSum( const ReduceOp& reduceOp, const ncclComm_t& comm) { const auto* preMulSupplement = - reinterpret_cast(reduceOp.supplement_.get()); + reinterpret_cast(reduceOp.supplement_.get()); ncclRedOp_t preMulSum{}; bool has_tensor = preMulSupplement->tensor_factor.defined(); auto residence = has_tensor ? ncclScalarDevice : ncclScalarHostImmediate; @@ -131,7 +133,6 @@ ncclRedOpRAII getNcclReduceOp( C10_THROW_ERROR( TypeError, "PreMulSum Data type must be half, float, bfloat16 or double"); - return ncclRedOp_t{}; } #else C10_THROW_ERROR(ValueError, "PreMulSum requires NCCL>=2.11.1"); @@ -221,6 +222,65 @@ std::string getExceptionMsgFromExceptionPtr( } } +#ifdef USE_ROCM +// Indicates that we're in the watchdog's event-query phase. This allows ROCm +// workaround behavior to be applied only to watchdog-side queries, while +// preserving existing behavior for user/main-thread `WorkNCCL::isCompleted()` +// and `wait()` calls. +thread_local bool g_in_rocm_watchdog_event_query_context = false; + +struct RocmWatchdogEventQueryContextGuard { + RocmWatchdogEventQueryContextGuard() + : previous_(g_in_rocm_watchdog_event_query_context) { + g_in_rocm_watchdog_event_query_context = true; + } + ~RocmWatchdogEventQueryContextGuard() { + g_in_rocm_watchdog_event_query_context = previous_; + } + + private: + bool previous_; +}; +#endif // USE_ROCM + +#ifdef USE_ROCM +// Watchdog-side cudaEventQuery workaround for HIP runtimes without the +// capture-mode fix. +// TODO: Remove once all supported runtimes include +// https://github.com/ROCm/rocm-systems/pull/3176 +bool queryEventWithRocmWatchdogCaptureWorkaround( + const std::shared_ptr& event) { + if (!event->isCreated()) { + return true; + } + + // Must unconditionally return false here during watchdog + active capture: + // on affected HIP runtimes, even calling cudaEventQuery from the watchdog + // thread while another thread has GLOBAL capture active can invalidate that + // capture and cause downstream failures. Skip the query entirely and report + // "not complete yet"; the watchdog will re-poll once capture ends. Timeout + // enforcement is also deferred during this window (see the + // is_graph_capture_active() gate in the watchdog loop). + if (g_in_rocm_watchdog_event_query_context && + at::cuda::is_graph_capture_active()) { + return false; + } + + const cudaError_t err = + C10_CUDA_ERROR_HANDLED(cudaEventQuery(event->event())); + if (err == cudaSuccess) { + return true; + } else if (err != cudaErrorNotReady) { + C10_CUDA_CHECK(err); + } else { + // ignore and clear the error if not ready + (void)cudaGetLastError(); + } + + return false; +} +#endif // USE_ROCM + inline void errorIfCapturingNonCapturableNCCL(c10::cuda::CaptureStatus status) { // parentheses avoid some compiler warnings static const uint64_t min_version = @@ -644,7 +704,11 @@ bool ProcessGroupNCCL::WorkNCCL::startedGPUExecutionInternal() const { return false; } // Checking the work's corresponding CUDA event's status +#ifdef USE_ROCM + if (!queryEventWithRocmWatchdogCaptureWorkaround(ncclStartEvent_)) { +#else if (!ncclStartEvent_->query()) { +#endif return false; } return true; @@ -657,7 +721,11 @@ bool ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const { // hang if another thread is holding the CUDA global context lock. For // example, when doing a `cudaDeviceSynchronize` or even // `cudaStreamSynchronize`. +#ifdef USE_ROCM + if (!queryEventWithRocmWatchdogCaptureWorkaround(ncclEndEvent_)) { +#else if (!ncclEndEvent_->query()) { +#endif return false; } return true; @@ -892,7 +960,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( options_(std::move(options)), terminateProcessGroup_(false), local_id_(process_group_id++), - intraNodeComm_(initIntraNodeComm()) { + intraNodeComm_(nullptr) { TORCH_CHECK_WITH( ValueError, at::cuda::getNumGPUs() != 0, @@ -1198,7 +1266,10 @@ c10::intrusive_ptr ProcessGroupNCCL:: return nullptr; } auto prefixStore = c10::make_intrusive("IntraNodeComm", store_); - auto comm = c10::make_intrusive(prefixStore, rank_, size_); + const std::string groupName = + options_->group_name.empty() ? "0" : options_->group_name; + auto comm = c10::make_intrusive( + prefixStore, rank_, size_, std::nullopt, groupName); if (comm->rendezvous()) { return comm; } else { @@ -2291,9 +2362,23 @@ void ProcessGroupNCCL::Watchdog::runLoop() { } } - // Then check if work has timed out - // Skip if work has encountered an error - bool timedout = !work.exception() && work.checkTimeout(); + // Then check if work has timed out. + // Skip if work has encountered an error. + + bool timedout = false; +#ifdef USE_ROCM + // On ROCm, watchdog event queries may be intentionally skipped during + // active graph capture to avoid HIP runtime capture invalidation. + // In that window, timeout checks can report false positives for + // otherwise-complete work, so we defer timeout enforcement. + // TODO: Remove once all supported HIP runtimes include: + // https://github.com/ROCm/clr/pull/3176 + if (!at::cuda::is_graph_capture_active()) { + timedout = !work.exception() && work.checkTimeout(); + } +#else + timedout = !work.exception() && work.checkTimeout(); +#endif // Report desync state in case of timeout (if TORCH_NCCL_DESYNC_DEBUG is // turned on; otherwise, run() is no-op) @@ -2358,7 +2443,11 @@ void ProcessGroupNCCL::Watchdog::runLoop() { // allow watchdog to do an event query on a side thread at::cuda::CUDAGuard device_guard(work.ncclEndEvent_->device_index()); at::cuda::CUDAStreamCaptureModeGuard g{cudaStreamCaptureModeThreadLocal}; - +#ifdef USE_ROCM + // Mark this thread/scope as watchdog event-query context so the ROCm + // workaround applies only here (not to main-thread wait()/isCompleted()). + RocmWatchdogEventQueryContextGuard watchdog_event_query_context_guard; +#endif // a work could be started but not completed, so we should not update // lastStartedSeq and lastStartedOpName if the work state is checked // multiple times after the start @@ -3277,6 +3366,30 @@ uint64_t ProcessGroupNCCL::getCommSplitCounter() const { return ret; } +void ProcessGroupNCCL::suspend() { + auto device = at::Device(at::kCUDA, guessDeviceId()); + std::string deviceKey = getKeyFromDevice(device); + auto ncclComm = getNCCLComm(deviceKey); + TORCH_CHECK(ncclComm != nullptr, "NCCL communicator not initialized."); + ncclComm->suspend(); +} + +void ProcessGroupNCCL::resume() { + auto device = at::Device(at::kCUDA, guessDeviceId()); + std::string deviceKey = getKeyFromDevice(device); + auto ncclComm = getNCCLComm(deviceKey); + TORCH_CHECK(ncclComm != nullptr, "NCCL communicator not initialized."); + ncclComm->resume(); +} + +std::unordered_map ProcessGroupNCCL::getMemoryStats() { + auto device = at::Device(at::kCUDA, guessDeviceId()); + std::string deviceKey = getKeyFromDevice(device); + auto ncclComm = getNCCLComm(deviceKey); + TORCH_CHECK(ncclComm != nullptr, "NCCL communicator not initialized."); + return ncclComm->getMemoryStats(); +} + namespace { // Check validity of tensor @@ -4501,12 +4614,17 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce( } check_gpu_single_tensor(tensor); - if (intraNodeComm_ != nullptr && opts.reduceOp == ReduceOp::SUM) { + if (opts.reduceOp == ReduceOp::SUM) { using namespace intra_node_comm; - auto algo = intraNodeComm_->selectAllReduceAlgo(tensor); - if (algo != intra_node_comm::AllReduceAlgo::NONE) { - intraNodeComm_->allReduce(tensor, algo); - return c10::make_intrusive(); + if (intraNodeComm_ == nullptr && IntraNodeComm::isEnabled()) { + intraNodeComm_ = initIntraNodeComm(); + } + if (intraNodeComm_ != nullptr) { + auto algo = intraNodeComm_->selectAllReduceAlgo(tensor); + if (algo != intra_node_comm::AllReduceAlgo::NONE) { + intraNodeComm_->allReduce(tensor, algo); + return c10::make_intrusive(); + } } } TORCH_CHECK( diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index c5d3eec1a03db..0069c556ca038 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -1051,6 +1051,13 @@ class TORCH_API ProcessGroupNCCL : public Backend { void setEnableNanCheck(bool enableNanCheck); + // APIs related to memory offload (require NCCL 2.29.7+ at runtime) + void suspend() override; + + void resume() override; + + std::unordered_map getMemoryStats() override; + protected: uint64_t getWatchdogHeartbt() const; diff --git a/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp b/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp index 77974b49d6003..b19750fb146ba 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp @@ -416,6 +416,13 @@ c10::intrusive_ptr ProcessGroupWrapper::allreduce_coalesced( return backend_->allreduce_coalesced(tensors, opts); } +c10::intrusive_ptr ProcessGroupWrapper::allreduce_sparse( + std::vector& tensors, + const AllreduceOptions& opts) { + runCollectiveChecks(OpType::_ALLREDUCE_SPARSE, tensors); + return backend_->allreduce_sparse(tensors, opts); +} + c10::intrusive_ptr ProcessGroupWrapper::reduce( std::vector& tensors, const ReduceOptions& opts) { @@ -456,6 +463,13 @@ c10::intrusive_ptr ProcessGroupWrapper::allgather_coalesced( return backend_->allgather_coalesced(outputTensorLists, inputTensors, opts); } +c10::intrusive_ptr ProcessGroupWrapper::allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts) { + return backend_->allgather_into_tensor_coalesced(outputs, inputs, opts); +} + c10::intrusive_ptr ProcessGroupWrapper::gather( std::vector>& outputTensors, std::vector& inputTensors, @@ -590,6 +604,21 @@ bool ProcessGroupWrapper::supportsTimeEstimation() const { return backend_->supportsTimeEstimation(); } +bool ProcessGroupWrapper::supportsShrinking() const { + return backend_->supportsShrinking(); +} + +c10::intrusive_ptr ProcessGroupWrapper::shrink( + const std::vector& ranks_to_exclude, + int shrink_flags, + const c10::intrusive_ptr& opts_override) { + return backend_->shrink(ranks_to_exclude, shrink_flags, opts_override); +} + +void ProcessGroupWrapper::setTimeout(std::chrono::milliseconds timeout) { + backend_->setTimeout(timeout); +} + c10::intrusive_ptr ProcessGroupWrapper::getBackendOptions() { return backend_->getBackendOptions(); } @@ -608,6 +637,27 @@ bool ProcessGroupWrapper::supportsTensorAlloc(c10::DeviceIndex deviceIdx) { return backend_->supportsTensorAlloc(deviceIdx); } +void ProcessGroupWrapper::abort() { + backend_->abort(); +} + +void ProcessGroupWrapper::shutdown() { + backend_->shutdown(); +} + +void ProcessGroupWrapper::suspend() { + backend_->suspend(); +} + +void ProcessGroupWrapper::resume() { + backend_->resume(); +} + +std::unordered_map ProcessGroupWrapper:: + getMemoryStats() { + return backend_->getMemoryStats(); +} + ErrorType ProcessGroupWrapper::getError() { return backend_->getError(); } @@ -616,6 +666,34 @@ void ProcessGroupWrapper::eagerConnectSingleDevice(at::Device device) { backend_->eagerConnectSingleDevice(device); } +void ProcessGroupWrapper::registerOnCompletionHook( + std::function)>&& hook) { + backend_->registerOnCompletionHook(std::move(hook)); +} + +void ProcessGroupWrapper::waitForPendingWorks() { + backend_->waitForPendingWorks(); +} + +void ProcessGroupWrapper::enableCollectivesTiming() { + backend_->enableCollectivesTiming(); +} + +c10::intrusive_ptr ProcessGroupWrapper::split( + const c10::intrusive_ptr& store, + const std::vector& ranks, + const c10::intrusive_ptr& opts) { + return backend_->split(store, ranks, opts); +} + +c10::intrusive_ptr ProcessGroupWrapper::merge( + const c10::intrusive_ptr& store, + const c10::intrusive_ptr& opts, + const int& rank, + const int& size) { + return backend_->merge(store, opts, rank, size); +} + c10::intrusive_ptr ProcessGroupWrapper::getWrappedPg() const { return backend_; } diff --git a/torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp b/torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp index cc5cb9f695b7d..b41166eed5535 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupWrapper.hpp @@ -23,8 +23,6 @@ class TORCH_API ProcessGroupWrapper : public Backend { const c10::intrusive_ptr& backend, c10::intrusive_ptr glooBackend); - const std::string getBackendName() const override; - c10::intrusive_ptr broadcast( std::vector& data, const BroadcastOptions& opts = BroadcastOptions()) override; @@ -33,6 +31,10 @@ class TORCH_API ProcessGroupWrapper : public Backend { std::vector& data, const AllreduceOptions& opts = AllreduceOptions()) override; + c10::intrusive_ptr allreduce_sparse( + std::vector& tensors, + const AllreduceOptions& opts = AllreduceOptions()) override; + c10::intrusive_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = @@ -61,6 +63,11 @@ class TORCH_API ProcessGroupWrapper : public Backend { std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; + c10::intrusive_ptr allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts = AllgatherOptions()) override; + c10::intrusive_ptr gather( std::vector>& outputTensors, std::vector& inputTensors, @@ -76,6 +83,16 @@ class TORCH_API ProcessGroupWrapper : public Backend { std::vector>& inputTensors, const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& inputBuffer, + at::Tensor& outputBuffer, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + c10::intrusive_ptr alltoall_base( at::Tensor& outputTensor, at::Tensor& inputTensor, @@ -118,29 +135,46 @@ class TORCH_API ProcessGroupWrapper : public Backend { c10::intrusive_ptr barrier( const BarrierOptions& opts = BarrierOptions()) override; + void registerOnCompletionHook( + std::function)>&& hook) override; - c10::intrusive_ptr _reduce_scatter_base( - at::Tensor& outputBuffer, - at::Tensor& inputBuffer, - const ReduceScatterOptions& opts) override; - - c10::intrusive_ptr reduce_scatter_tensor_coalesced( - std::vector& outputs, - std::vector& inputs, - const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + void waitForPendingWorks() override; + void enableCollectivesTiming() override; - void startCoalescing() override; + c10::intrusive_ptr split( + const c10::intrusive_ptr& store, + const std::vector& ranks, + const c10::intrusive_ptr& opts) override; - c10::intrusive_ptr endCoalescing() override; + c10::intrusive_ptr merge( + const c10::intrusive_ptr& store, + const c10::intrusive_ptr& opts, + const int& rank, + const int& size) override; // Forward methods to wrapped backend bool supportsSplitting() const override; bool supportsCoalescing() const override; bool supportsTimeEstimation() const override; + bool supportsShrinking() const override; + c10::intrusive_ptr shrink( + const std::vector& ranks_to_exclude, + int shrink_flags = 0, + const c10::intrusive_ptr& opts_override = nullptr) override; + void setTimeout(std::chrono::milliseconds timeout) override; + void startCoalescing() override; + c10::intrusive_ptr endCoalescing() override; + const std::string getBackendName() const override; c10::intrusive_ptr getBackendOptions() override; std::shared_ptr getMemAllocator() override; at::Tensor allocateTensor(long size, at::TensorOptions options = {}) override; bool supportsTensorAlloc(c10::DeviceIndex deviceIdx) override; + void abort() override; + void shutdown() override; + void suspend() override; + void resume() override; + std::unordered_map getMemoryStats() override; + ErrorType getError() override; void eagerConnectSingleDevice(at::Device device) override; diff --git a/torch/csrc/distributed/c10d/TCPStoreBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreBackend.cpp index e15e5e06df9b6..f718886445f07 100644 --- a/torch/csrc/distributed/c10d/TCPStoreBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreBackend.cpp @@ -221,27 +221,11 @@ void TCPStoreMasterDaemon::queryFds(std::vector& fds) { void TCPStoreMasterDaemon::clearSocketWaitState(int socket) { // Remove all the tracking state of the close FD - for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) { - for (auto vecIt = it->second.begin(); vecIt != it->second.end();) { - if (*vecIt == socket) { - vecIt = it->second.erase(vecIt); - } else { - ++vecIt; - } - } - if (it->second.empty()) { - it = waitingSockets_.erase(it); - } else { - ++it; - } - } - for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) { - if (it->first == socket) { - it = keysAwaited_.erase(it); - } else { - ++it; - } - } + std::erase_if(waitingSockets_, [&](auto& entry) { + std::erase(entry.second, socket); + return entry.second.empty(); + }); + keysAwaited_.erase(socket); } // query communicates with the worker. The format diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index 9cab99cf18556..f39f829953fd6 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -1405,20 +1405,10 @@ void LibUVStoreDaemon::clearClientWaitState( return; } keysAwaited_.erase(client); - for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) { - for (auto vecIt = it->second.begin(); vecIt != it->second.end();) { - if (*vecIt == client) { - vecIt = it->second.erase(vecIt); - } else { - ++vecIt; - } - } - if (it->second.empty()) { - it = waitingSockets_.erase(it); - } else { - ++it; - } - } + std::erase_if(waitingSockets_, [&](auto& entry) { + std::erase(entry.second, client); + return entry.second.empty(); + }); } void LibUVStoreDaemon::set( diff --git a/torch/csrc/distributed/c10d/Types.cpp b/torch/csrc/distributed/c10d/Types.cpp index 300d21780bdb0..bf718a2fb9d61 100644 --- a/torch/csrc/distributed/c10d/Types.cpp +++ b/torch/csrc/distributed/c10d/Types.cpp @@ -16,7 +16,6 @@ bool isComplexViewAsRealAllowed(const ReduceOp& reduceOp) { default: return false; } - return false; } } // namespace c10d diff --git a/torch/csrc/distributed/c10d/Types.hpp b/torch/csrc/distributed/c10d/Types.hpp index 18db14f5cef04..125528e52af27 100644 --- a/torch/csrc/distributed/c10d/Types.hpp +++ b/torch/csrc/distributed/c10d/Types.hpp @@ -20,14 +20,16 @@ struct TORCH_API _SupplementBase : torch::CustomClassHolder { // Supplementary data specific to NCCL PREMUL_SUM // The point of use in ProcessGroupNCCL knows how to unpack it. -struct NCCLPreMulSumSupplement : _SupplementBase { +struct PreMulSumSupplement : _SupplementBase { double double_factor{0.0}; at::Tensor tensor_factor; - NCCLPreMulSumSupplement(double f) : double_factor{f} {} - NCCLPreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} { + PreMulSumSupplement(double f) : double_factor{f} {} + PreMulSumSupplement(at::Tensor t) : tensor_factor{std::move(t)} { TORCH_CHECK_EQ(tensor_factor.numel(), 1); } }; +// Keep for BC only +using NCCLPreMulSumSupplement = PreMulSumSupplement; // Other ReduceOps that need different supplementary data can also // derive from _SupplementBase. @@ -103,10 +105,10 @@ struct TORCH_API ReduceOp : torch::CustomClassHolder { }; template -ReduceOp makeNCCLPreMulSum(const T& factor) { +ReduceOp makePreMulSum(const T& factor) { ReduceOp rop; rop.op_ = ReduceOp::PREMUL_SUM; - rop.supplement_ = c10::make_intrusive(factor); + rop.supplement_ = c10::make_intrusive(factor); return rop; } diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp index cdec9185ce537..86adb9770faed 100644 --- a/torch/csrc/distributed/c10d/Work.cpp +++ b/torch/csrc/distributed/c10d/Work.cpp @@ -7,13 +7,44 @@ namespace c10d { +namespace { +// Raw pointer avoids thread_local destructor issues in forked +// processes and dynamically-loaded libraries. +thread_local std::string* comm_profiling_name = nullptr; +static_assert( + std::is_trivially_destructible_v, + "comm_profiling_name must be trivially destructible — a non-trivial " + "destructor (e.g. std::string) causes deadlocks after fork() in " + "dlopen'd libraries via __cxa_thread_atexit."); +} // namespace + +void set_comm_profiling_name(const std::string& name) { + if (!comm_profiling_name) { + comm_profiling_name = new std::string(name); + } else { + *comm_profiling_name = name; + } +} + +const std::string& get_comm_profiling_name() { + if (comm_profiling_name) { + return *comm_profiling_name; + } + static const std::string empty; + return empty; +} + Work::Work( int rank, OpType opType, const char* profilingTitle, const std::optional>& inputTensors) : rank_(rank), opType_(opType) { - if (profilingTitle != nullptr) { + // comm_profiling_name is thread-local; take a local copy so the + // RecordFunction owns the string (the TLS can be mutated after we return). + const bool use_tls_name = + comm_profiling_name != nullptr && !comm_profiling_name->empty(); + if (use_tls_name || profilingTitle != nullptr) { auto recordingFunction = std::make_shared(at::RecordScope::USER_SCOPE); if (recordingFunction->isActive()) { @@ -29,9 +60,17 @@ Work::Work( inputs.emplace_back(tensor); } } - recordingFunction->before( - profilingTitle, - c10::ArrayRef(inputs.data(), inputs.size())); + if (use_tls_name) { + recordingFunction->before( + std::string(*comm_profiling_name), + c10::ArrayRef(inputs.data(), inputs.size())); + } else { + // const char* overload — pointer is a string literal with static + // lifetime + recordingFunction->before( + profilingTitle, + c10::ArrayRef(inputs.data(), inputs.size())); + } std::function end_handler = [recordingFunction]() { recordingFunction->end(); }; diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp index 8ab9ebf1c08ee..f7e4317c2ca29 100644 --- a/torch/csrc/distributed/c10d/Work.hpp +++ b/torch/csrc/distributed/c10d/Work.hpp @@ -183,4 +183,7 @@ struct TORCH_API WorkInfo { std::chrono::duration activeDuration; }; +TORCH_API void set_comm_profiling_name(const std::string& name); +TORCH_API const std::string& get_comm_profiling_name(); + } // namespace c10d diff --git a/torch/csrc/distributed/c10d/cuda/AsyncMM.cu b/torch/csrc/distributed/c10d/cuda/AsyncMM.cu index a7e9ddaa17b1e..2bec14bad98dd 100644 --- a/torch/csrc/distributed/c10d/cuda/AsyncMM.cu +++ b/torch/csrc/distributed/c10d/cuda/AsyncMM.cu @@ -79,7 +79,7 @@ at::Tensor async_input_mm_impl( cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, - ElementC, + void, // Indicate there is no beta scaling to save register LayoutC, AlignmentC, ElementC, @@ -151,7 +151,7 @@ at::Tensor async_input_mm_impl( stride_B, }, {{}, - reinterpret_cast(out.data_ptr()), + nullptr, stride_C, reinterpret_cast(out.data_ptr()), stride_C}, diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 5df75a7e81b88..27fd17f4e1d11 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -209,7 +209,7 @@ class PythonStore : public ::c10d::Store { using ::c10d::Store::Store; // Note: this function manually calls the Python-side overload - // for this function instead of using the PYBIND11_OVERLOAD_XYZ + // for this function instead of using the PYBIND11_OVERRIDE_XYZ // macros. This is done so that we can call the Python-side // function with a std::string instead of a std::vector. void set(const std::string& key, const std::vector& value) override { @@ -222,7 +222,7 @@ class PythonStore : public ::c10d::Store { } // Note: this function manually calls the Python-side overload - // for this function instead of using the PYBIND11_OVERLOAD_XYZ + // for this function instead of using the PYBIND11_OVERRIDE_XYZ // macros. This is done so that the Python-side function can // return a py::bytes instead of a std::vector. std::vector get(const std::string& key) override { @@ -239,7 +239,7 @@ class PythonStore : public ::c10d::Store { } // Note: this function manually calls the Python-side overload - // for this function instead of using the PYBIND11_OVERLOAD_XYZ + // for this function instead of using the PYBIND11_OVERRIDE_XYZ // macros. This is done so that the Python-side function can // return a py::bytes instead of a std::vector. std::vector compareSet( @@ -260,37 +260,37 @@ class PythonStore : public ::c10d::Store { } int64_t add(const std::string& key, int64_t value) override { - PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, add, key, value); + PYBIND11_OVERRIDE_PURE(int64_t, ::c10d::Store, add, key, value); } int64_t getNumKeys() override { - PYBIND11_OVERLOAD_PURE(int64_t, ::c10d::Store, getNumKeys); + PYBIND11_OVERRIDE_PURE(int64_t, ::c10d::Store, getNumKeys); } bool deleteKey(const std::string& key) override { - PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, deleteKey, key); + PYBIND11_OVERRIDE_PURE(bool, ::c10d::Store, deleteKey, key); } bool check(const std::vector& keys) override { - PYBIND11_OVERLOAD_PURE(bool, ::c10d::Store, check, keys); + PYBIND11_OVERRIDE_PURE(bool, ::c10d::Store, check, keys); } void wait(const std::vector& keys) override { - PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys); + PYBIND11_OVERRIDE_PURE(void, ::c10d::Store, wait, keys); } void wait( const std::vector& keys, const std::chrono::milliseconds& timeout) override { - PYBIND11_OVERLOAD_PURE(void, ::c10d::Store, wait, keys, timeout); + PYBIND11_OVERRIDE_PURE(void, ::c10d::Store, wait, keys, timeout); } c10::intrusive_ptr clone() override { - PYBIND11_OVERLOAD_PURE(c10::intrusive_ptr, ::c10d::Store, clone); + PYBIND11_OVERRIDE_PURE(c10::intrusive_ptr, ::c10d::Store, clone); } // Note: this function manually calls the Python-side overload - // for this function instead of using the PYBIND11_OVERLOAD_XYZ + // for this function instead of using the PYBIND11_OVERRIDE_XYZ // macros. This is done so that we can call the Python-side // function with a std::string instead of a std::vector. void append(const std::string& key, const std::vector& value) @@ -339,7 +339,7 @@ class PythonStore : public ::c10d::Store { } bool hasExtendedApi() const override { - PYBIND11_OVERLOAD_NAME( + PYBIND11_OVERRIDE_NAME( bool, ::c10d::Store, "has_extended_api", hasExtendedApi); } }; @@ -574,7 +574,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO int64_t first_bucket_bytes_cap, bool skip_all_reduce_unused_params, bool use_python_reducer, - std::vector bucket_bytes_cap_list) { + std::vector bucket_bytes_cap_list, + bool batched_grad_copy) { // gil_scoped_release is not safe as a call_guard in init. // https://github.com/pybind/pybind11/issues/5473 py::gil_scoped_release nogil{}; @@ -591,7 +592,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO first_bucket_bytes_cap, skip_all_reduce_unused_params, use_python_reducer, - std::move(bucket_bytes_cap_list)); + std::move(bucket_bytes_cap_list), + batched_grad_copy); }), py::arg("params"), py::arg("bucket_indices"), @@ -606,7 +608,8 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO py::arg("first_bucket_bytes_cap") = ::c10d::kDefaultFirstBucketBytes, py::arg("skip_all_reduce_unused_params") = false, py::arg("use_python_reducer") = false, - py::arg("bucket_bytes_cap_list") = std::vector()) + py::arg("bucket_bytes_cap_list") = std::vector(), + py::arg("batched_grad_copy") = false) .def( "prepare_for_forward", &::c10d::Reducer::prepare_for_forward, @@ -812,11 +815,6 @@ An enum-like class for built-in communication hooks: ``ALLREDUCE`` and ``FP16_CO R"(Sets the debug level of the torch.distributed package from the ``TORCH_DISTRIBUTED_DEBUG`` environment variable.)"); - // TODO(crcrpar): Hardening `ReduceOp`. - // While keeping most op types as enum value, - // making `PREMUL_SUM` callable, i.e., allowing for - // `ReduceOp.PREMUL_SUM(scale)` might be better as per @wanchaol. - // https://pybind11.readthedocs.io/en/stable/classes.html#enumerations-and-internal-types py::class_<::c10d::ReduceOp> reduce_op( module, "ReduceOp", @@ -833,9 +831,9 @@ using the ``NCCL`` backend. and only for NCCL versions 2.10 or later. ``PREMUL_SUM`` multiplies inputs by a given scalar locally before reduction. -``PREMUL_SUM`` is only available with the ``NCCL`` backend, -and only available for NCCL versions 2.11 or later. Users are supposed to -use ``torch.distributed._make_nccl_premul_sum``. +``PREMUL_SUM`` is available with the ``NCCL`` backend (NCCL versions 2.11 or later) +and the ``XCCL`` backend. It can be used by calling ``ReduceOp.PREMUL_SUM(factor)`` +where factor is a float or a single-element Tensor. Additionally, ``MAX``, ``MIN`` and ``PRODUCT`` are not supported for complex tensors. @@ -887,14 +885,14 @@ This class does not support ``__members__`` property.)"); return ::c10d::ReduceOp(self); }) .def(py::pickle( - [](const ::c10d::ReduceOp& r) { + [](const ::c10d::ReduceOp& r) -> py::tuple { // __getstate__ if (r.op_ != ::c10d::ReduceOp::RedOpType::PREMUL_SUM) { return py::make_tuple(r.op_, py::none()); } TORCH_CHECK(r.supplement_.defined(), "Invalid PREMUL_SUM ReduceOp"); const auto* preMulSupplement = - reinterpret_cast<::c10d::NCCLPreMulSumSupplement*>( + reinterpret_cast<::c10d::PreMulSumSupplement*>( r.supplement_.get()); if (!preMulSupplement->tensor_factor.defined()) { return py::make_tuple(r.op_, preMulSupplement->double_factor); @@ -912,9 +910,9 @@ This class does not support ``__members__`` property.)"); } const auto preMulSupplement_factor = t[1]; if (py::isinstance(preMulSupplement_factor)) { - return ::c10d::makeNCCLPreMulSum(t[1].cast()); + return ::c10d::makePreMulSum(t[1].cast()); } else { - return ::c10d::makeNCCLPreMulSum(t[1].cast()); + return ::c10d::makePreMulSum(t[1].cast()); } })); @@ -928,7 +926,39 @@ This class does not support ``__members__`` property.)"); .value("BOR", ::c10d::ReduceOp::RedOpType::BOR) .value("BXOR", ::c10d::ReduceOp::RedOpType::BXOR) .value("PREMUL_SUM", ::c10d::ReduceOp::RedOpType::PREMUL_SUM) - .export_values(); + .export_values() + .def( + "__call__", + [](const ::c10d::ReduceOp::RedOpType& self, + const py::object& factor) -> ::c10d::ReduceOp { + TORCH_CHECK( + self == ::c10d::ReduceOp::RedOpType::PREMUL_SUM, + "Only PREMUL_SUM supports calling with a factor, got ", + py::str(py::cast(self)).cast()); + if (py::isinstance(factor) || + py::isinstance(factor)) { + return ::c10d::makePreMulSum(factor.cast()); + } else { + return ::c10d::makePreMulSum(factor.cast()); + } + }, + py::arg("factor"), + R"(Create a PREMUL_SUM ReduceOp with the given factor. + +Only ``PREMUL_SUM`` supports this callable interface. Other reduction +operations will raise an error if called. + +Args: + factor: A scalar (float, int) or a single-element Tensor to multiply + inputs by before reduction. + +Returns: + A ReduceOp configured for PREMUL_SUM with the specified factor. + +Example: + >>> op = ReduceOp.PREMUL_SUM(2.0) + >>> dist.all_reduce(tensor, op) +)"); // note(crcrpar): This could be removed because users will not pass // `RedOpType` to reduce collective ops Ref: [Implicit @@ -940,13 +970,13 @@ This class does not support ``__members__`` property.)"); module .def( "_make_nccl_premul_sum", - &::c10d::makeNCCLPreMulSum, + &::c10d::makePreMulSum, py::arg("factor").noconvert(), py::return_value_policy::copy, // seems safest py::call_guard()) .def( "_make_nccl_premul_sum", - &::c10d::makeNCCLPreMulSum, + &::c10d::makePreMulSum, py::arg("factor").noconvert(), py::return_value_policy::copy, // seems safest py::call_guard()); @@ -1140,6 +1170,10 @@ This class does not support ``__members__`` property.)"); &::c10d::symmetric_memory::has_multicast_support) .def_static("set_backend", &::c10d::symmetric_memory::set_backend) .def_static("get_backend", &::c10d::symmetric_memory::get_backend) + .def_static( + "is_symm_mem_tensor", + &::c10d::symmetric_memory::is_symm_mem_tensor, + py::arg("tensor")) .def_property_static( "signal_pad_size", [](py::object /* self */) { @@ -2758,6 +2792,13 @@ The hook must have the following signature: module.def("_set_process_group", &::c10d::setProcessGroup); module.def("_current_process_group", &::c10d::currentProcessGroup); + // Thread local comm profiling name + module.def( + "_set_comm_profiling_name", + &::c10d::set_comm_profiling_name, + py::arg("name")); + module.def("_get_comm_profiling_name", &::c10d::get_comm_profiling_name); + py::enum_<::c10d::ProcessGroup::BackendType>( processGroup, "BackendType", @@ -3156,7 +3197,15 @@ The hook must have the following signature: py::arg("device"), py::call_guard()) .def_property_readonly( - "mem_allocator", &::c10d::Backend::getMemAllocator); + "mem_allocator", &::c10d::Backend::getMemAllocator) + .def("suspend", &::c10d::Backend::suspend) + .def("resume", &::c10d::Backend::resume) + .def("memory_stats", &::c10d::Backend::getMemoryStats, R"( + Get the memory statistics of the backend. + + Returns: + A dictionary containing the memory statistics. + )"); // base Backend::Options binding // TODO: Maybe we can consider how to merge this with diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 4007e08227e96..6b56e09639214 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -95,7 +95,8 @@ Reducer::Reducer( int64_t first_bucket_bytes_cap, bool skip_all_reduce_unused_params, bool use_python_reducer, - std::vector bucket_bytes_cap_list) + std::vector bucket_bytes_cap_list, + bool batched_grad_copy) : params_(std::move(params)), process_group_(std::move(process_group)), expect_sparse_gradients_(std::move(expect_sparse_gradients)), @@ -105,6 +106,7 @@ Reducer::Reducer( has_marked_unused_parameters_(false), find_unused_parameters_(find_unused_parameters), gradient_as_bucket_view_(gradient_as_bucket_view), + batched_grad_copy_(batched_grad_copy), local_used_map_reduced_(false), num_iterations_(0), num_bwd_calls_(0), @@ -372,51 +374,66 @@ void Reducer::mark_variable_ready_dense(size_t variable_index) { // to bucket_view. If grad has already been set as views of buckets in // previous iterations, no copy is needed. if (!grad.is_alias_of(bucket_view)) { - if (comm_hook_ == nullptr) { - auto wrapped = at::native::wrapped_scalar_tensor(1. / div_factor_); - if (!grad.requires_grad()) { - // Divides while copying into the bucket view to save one scan over - // all the input parameters. - RECORD_FUNCTION( - "torch::distributed::reducer::mul_out", - std::vector({bucket_view})) - at::mul_out(bucket_view, grad, wrapped); + if (batched_grad_copy_ && !grad.requires_grad()) { + // Defer the copy — will be batched with _foreach_copy_ + flat div_ + // when bucket.pending == 0. + bucket.deferred_copy_indices.push_back( + bucket_index.intra_bucket_index); + } else { + if (comm_hook_ == nullptr) { + auto wrapped = at::native::wrapped_scalar_tensor(1. / div_factor_); + if (!grad.requires_grad()) { + // Divides while copying into the bucket view to save one scan + // over all the input parameters. + RECORD_FUNCTION( + "torch::distributed::reducer::mul_out", + std::vector({bucket_view})) + at::mul_out(bucket_view, grad, wrapped); + } else { + // If DDP is running with create_graph=True, gradients + // require_grad themselves in order to compute higher order + // derivatives. However, DDP will not sync up these gradients + // currently (see + // https://github.com/pytorch/pytorch/issues/63812). + C10_LOG_EVERY_N(WARNING, 1000) + << "Using DistributedDataParallel with create_graph=True " + << " is not well-supported. The higher-order gradient will " + << " not be synchronized across ranks, and backpropagation " + << " through all_reduce operations will not occur. If you require " + << " DDP to work with higher-order gradients for your use case, " + << " please ping https://github.com/pytorch/pytorch/issues/63929"; + if (batched_grad_copy_) { + C10_LOG_EVERY_N(WARNING, 1000) + << "batched_grad_copy is incompatible with " + << "create_graph=True and has been bypassed."; + } + auto div_result = at::mul(grad, wrapped); + RECORD_FUNCTION( + "torch::distributed::reducer::copy_", + std::vector({bucket_view})) + bucket_view.copy_(div_result); + } } else { - // If DDP is running with create_graph=True, gradients require_grad - // themselves in order to compute higher order derivatives. However, - // DDP will not sync up these gradients currently (see - // https://github.com/pytorch/pytorch/issues/63812). - C10_LOG_EVERY_N(WARNING, 1000) - << "Using DistributedDataParallel with create_graph=True " - << " is not well-supported. The higher-order gradient will " - << " not be synchronized across ranks, and backpropagation " - << " through all_reduce operations will not occur. If you require " - << " DDP to work with higher-order gradients for your use case, " - << " please ping https://github.com/pytorch/pytorch/issues/63929"; - auto div_result = at::mul(grad, wrapped); RECORD_FUNCTION( "torch::distributed::reducer::copy_", std::vector({bucket_view})) - bucket_view.copy_(div_result); + bucket_view.copy_(grad); } - } else { - RECORD_FUNCTION( - "torch::distributed::reducer::copy_", - std::vector({bucket_view})) - bucket_view.copy_(grad); - } - if (gradient_as_bucket_view_) { - // Let grad point to bucket_view buffer. - grad = bucket_view; - // The grad is modified and need to be written back. - return true; + if (gradient_as_bucket_view_) { + grad = bucket_view; + return true; + } } } else { // If grad and bucket view point to the same storage, no need to copy. - if (comm_hook_ == nullptr) { - bucket_view.div_(div_factor_); + if (!batched_grad_copy_) { + if (comm_hook_ == nullptr) { + bucket_view.div_(div_factor_); + } } + // When batched_grad_copy_ is enabled, div_ is deferred to a single + // flat bucket div_ in flush_deferred_copies. } } else { // Gradient is undefined. When find_unused_parameters=True, ensure it is @@ -433,7 +450,6 @@ void Reducer::mark_variable_ready_dense(size_t variable_index) { } bucket_view.zero_(); } - // The grad is not modified and doesn't need to be written back. return false; }); } @@ -630,7 +646,11 @@ void Reducer::delay_all_reduce() { } // launch all reduces for all buckets - for (auto& bucket : buckets_) { + for (const auto bucket_index : c10::irange(buckets_.size())) { + auto& bucket = buckets_[bucket_index]; + if (batched_grad_copy_) { + flush_deferred_copies(bucket, bucket_index); + } all_reduce_bucket(bucket); } @@ -907,6 +927,11 @@ void Reducer::mark_variable_ready(size_t variable_index) { // Check if this was the final gradient for this bucket. if (--bucket.pending == 0) { + // When batched_grad_copy_ is enabled, flush deferred copies and perform + // a single div on the flat bucket tensor instead of per-variable ops. + if (batched_grad_copy_) { + flush_deferred_copies(bucket, bucket_index.bucket_index); + } mark_bucket_ready(bucket_index.bucket_index); } @@ -1431,6 +1456,7 @@ void Reducer::reset_bucket_counting() { for (auto& bucket : buckets_) { bucket.pending = bucket.variables.size(); + bucket.deferred_copy_indices.clear(); } if (static_graph_) { @@ -1812,6 +1838,53 @@ void Reducer::runGradCallbackForVariable( #endif } +void Reducer::flush_deferred_copies(Bucket& bucket, size_t bucket_index) { + // Sparse gradients are already divided in mark_variable_ready_sparse and + // communicated independently — skip to avoid double division. + if (bucket.expect_sparse_gradient) { + return; + } + if (!bucket.deferred_copy_indices.empty()) { + std::vector dsts; + std::vector srcs; + dsts.reserve(bucket.deferred_copy_indices.size()); + srcs.reserve(bucket.deferred_copy_indices.size()); + for (auto idx : bucket.deferred_copy_indices) { + auto grad = bucket.variables[idx].grad(); + TORCH_INTERNAL_ASSERT( + grad.defined(), + "Gradient became undefined between defer and flush for variable ", + idx, + " in bucket ", + bucket_index, + ". This indicates a bug — gradients should not be modified during backward."); + dsts.push_back(bucket.bucket_views_in[idx]); + srcs.push_back(grad); + } + at::_foreach_copy_(dsts, srcs); + + // Re-alias grads to bucket views if gradient_as_bucket_view + if (gradient_as_bucket_view_) { + for (auto idx : bucket.deferred_copy_indices) { + auto& variable = bucket.variables[idx]; + auto& bucket_view = bucket.bucket_views_in[idx]; + runGradCallbackForVariable(variable, [&](auto& grad) { + grad = bucket_view; + return true; + }); + } + } + bucket.deferred_copy_indices.clear(); + } + // Single div on the entire flat bucket tensor. + // This also divides regions zeroed for undefined gradients, which is a no-op + // (0 / div_factor_ == 0) but avoids the complexity of tracking whether any + // variable in the bucket had a defined grad. + if (comm_hook_ == nullptr) { + bucket.gradients.div_(div_factor_); + } +} + #ifndef _WIN32 void Reducer::RpcContext::set(ContextPtr&& new_context_ptr) { // We should set 'new_context_ptr' even if it's nullptr. That means the diff --git a/torch/csrc/distributed/c10d/reducer.hpp b/torch/csrc/distributed/c10d/reducer.hpp index 3badec59594fa..73730ead6de03 100644 --- a/torch/csrc/distributed/c10d/reducer.hpp +++ b/torch/csrc/distributed/c10d/reducer.hpp @@ -60,7 +60,8 @@ class TORCH_API Reducer { int64_t first_bucket_bytes_cap, bool skip_all_reduce_unused_params, bool use_python_reducer, - std::vector bucket_bytes_cap_list); + std::vector bucket_bytes_cap_list, + bool batched_grad_copy = false); ~Reducer() noexcept(false); @@ -236,6 +237,8 @@ class TORCH_API Reducer { // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) const bool gradient_as_bucket_view_; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + const bool batched_grad_copy_; + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::vector unused_parameters_; // Previous iteration's unused params, used for checking if unused parameters // change between iterations. Only filled during the first backwards call. @@ -312,6 +315,12 @@ class TORCH_API Reducer { #endif void runGradCallbackForVariable(at::Tensor& variable, const GradCallback& cb); + // Flushes deferred grad-to-bucket copies for a single bucket when + // batched_grad_copy_ is enabled. Called from mark_variable_ready (when + // bucket.pending == 0) and from delay_all_reduce (after all variables + // are marked ready). + void flush_deferred_copies(Bucket& bucket, size_t bucket_index); + // This function is called inside `initialize_buckets()`. It initializes both // `bucket_views_in` and `bucket_views_out` with views for each variable's // gradient into the bucket's flattened `gradients` tensor. Views serve as @@ -402,6 +411,11 @@ class TORCH_API Reducer { // done on different CUDA streams. We record an event for every copy // so that we can synchronize with them prior to kicking off the reduction. // std::vector events; + + // Intra-bucket indices of variables whose grad-to-bucket copies are + // deferred for batching. Flushed as _foreach_copy_ + flat div_ when + // pending == 0. Only used when batched_grad_copy is enabled. + std::vector deferred_copy_indices; }; std::vector buckets_; diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu index ccdd00310dafe..576e59dad2f85 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -9,6 +10,7 @@ #include #include #include +#include #include #include @@ -79,7 +81,8 @@ CUDAPeerAllocInfo::CUDAPeerAllocInfo( size_t buffer_size, int local_device_idx, int rank, - int world_size) + int world_size, + std::string group_name) : alloc_refs_(std::move(alloc_refs)), buffers_(std::move(buffers)), signal_pads_(std::move(signal_pads)), @@ -88,7 +91,8 @@ CUDAPeerAllocInfo::CUDAPeerAllocInfo( buffer_size_(buffer_size), local_device_idx_(local_device_idx), rank_(rank), - world_size_(world_size) { + world_size_(world_size), + group_name_(std::move(group_name)) { const size_t arr_size = sizeof(void*) * world_size_; buffers_dev_ = reinterpret_cast( c10::cuda::CUDACachingAllocator::raw_alloc(arr_size)); @@ -210,7 +214,21 @@ static __global__ void barrier_kernel( void CUDASymmetricMemory::barrier(int channel, size_t timeout_ms) { check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); + auto pg = c10d::resolve_process_group(pai_->group_name_); + RECORD_PARAM_COMMS( + static_cast(0), + std::make_tuple(pg->getGroupName(), pg->getGroupDesc()), + rank_, + "symm_mem::barrier", + 0, + 0, + at::kByte, + std::vector(), + std::vector(), + -1, + -1, + world_size_); + c10::cuda::CUDAGuard device_guard(local_device_idx_); barrier_kernel<<< 1, max(at::cuda::warp_size(), world_size_), @@ -252,7 +270,21 @@ void CUDASymmetricMemory::put_signal( int channel, size_t timeout_ms) { check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); + auto pg = c10d::resolve_process_group(pai_->group_name_); + RECORD_PARAM_COMMS( + static_cast(0), + std::make_tuple(pg->getGroupName(), pg->getGroupDesc()), + rank_, + "symm_mem::put_signal", + 0, + 0, + at::kByte, + std::vector(), + std::vector(), + -1, + -1, + world_size_); + c10::cuda::CUDAGuard device_guard(local_device_idx_); put_signal_kernel<<< 1, at::cuda::warp_size(), @@ -300,7 +332,21 @@ void CUDASymmetricMemory::wait_signal( int channel, size_t timeout_ms) { check_channel(channel, world_size_); - c10::cuda::CUDAGuard guard(local_device_idx_); + auto pg = c10d::resolve_process_group(pai_->group_name_); + RECORD_PARAM_COMMS( + static_cast(0), + std::make_tuple(pg->getGroupName(), pg->getGroupDesc()), + rank_, + "symm_mem::wait_signal", + 0, + 0, + at::kByte, + std::vector(), + std::vector(), + -1, + -1, + world_size_); + c10::cuda::CUDAGuard device_guard(local_device_idx_); wait_signal_kernel<<< 1, at::cuda::warp_size(), @@ -453,9 +499,25 @@ struct RendezvousRequest { size_t buffer_size; size_t signal_pad_offset; bool has_multicast_support; + int clique_id; char hostname[HOST_NAME_MAX + 1]; }; +static std::string import_err_msg( + int rank, + int peer, + const std::vector& reqs) { + std::ostringstream oss; + oss << ". Rank " << rank << " (host: " << reqs[rank].hostname + << ", device: " << reqs[rank].device_idx << ", fabric_info: {" + << at::cuda::get_nvml_fabric_info(reqs[rank].device_idx) + << "}) failed to import memory from rank " << peer + << " (host: " << reqs[peer].hostname + << ", device: " << reqs[peer].device_idx << ", NCCL_MNNVL_CLIQUE_ID: " + << c10::utils::get_env("NCCL_MNNVL_CLIQUE_ID").value_or("unset") << ")."; + return oss.str(); +} + void validate_rendezvous_requests( const std::vector& reqs, int world_size) { @@ -466,7 +528,8 @@ void validate_rendezvous_requests( // Use (hostname, device_idx) pair to uniquely identify each allocation. std::set> device_host_pairs; for (auto req : reqs) { - device_host_pairs.insert(std::make_pair(std::string(req.hostname), req.device_idx)); + device_host_pairs.insert( + std::make_pair(std::string(req.hostname), req.device_idx)); } if (!allow_overlapping_devices() && device_host_pairs.size() < (size_t)world_size) { @@ -484,6 +547,35 @@ void validate_rendezvous_requests( } } +// All ranks must be in the same NVLink domain (same clique_id). Detect +// mismatches early before the import fails with an opaque CUDA error. +static void validate_nvlink_fabric_support( + const std::vector& reqs, + int world_size) { + std::unordered_set clique_ids; + for (const auto& req : reqs) { + if (req.clique_id >= 0) { + clique_ids.insert(req.clique_id); + } + } + if (clique_ids.size() > 1) { + std::ostringstream oss; + oss << "CUDASymmetricMemory::rendezvous: " + << "ranks have mismatched NVLink clique_ids. " + << "All ranks using fabric handles must be in the same NVLink domain. " + << "Per-rank info: "; + for (int r = 0; r < world_size; ++r) { + if (r > 0) { + oss << ", "; + } + oss << "rank " << r << " (host: " << reqs[r].hostname + << ", device: " << reqs[r].device_idx + << ", clique_id: " << reqs[r].clique_id << ")"; + } + TORCH_CHECK(false, oss.str()); + } +} + static bool check_group_multicast_support( const std::vector& reqs) { std::vector ranks_with_multicast_support; @@ -678,11 +770,13 @@ c10::intrusive_ptr make_peer_alloc_info( .block_size = block->block_size, .buffer_size = block->buffer_size, .signal_pad_offset = block->signal_pad_offset, - .has_multicast_support = device_has_multicast_support(block->device_idx)}; + .has_multicast_support = device_has_multicast_support(block->device_idx), + .clique_id = at::cuda::get_fabric_clique_id(block->device_idx)}; // Populate hostname field for host identification gethostname(local_req.hostname, sizeof(local_req.hostname)); auto reqs = storeExchange.all_gather(store, rank, world_size, local_req); + validate_nvlink_fabric_support(reqs, world_size); validate_rendezvous_requests(reqs, world_size); std::vector pids(world_size); @@ -715,15 +809,19 @@ c10::intrusive_ptr make_peer_alloc_info( // note how in one case it's directly imported_handles[r] and in another // &(imported_handles[r]) so can't do with just type definitions if constexpr (!use_fabric_handle) { - C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( - &handles[r], - (void*)(uintptr_t)imported_handles[r], - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + C10_CUDA_DRIVER_CHECK_MSG( + driver_api->cuMemImportFromShareableHandle_( + &handles[r], + (void*)(uintptr_t)imported_handles[r], + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), + import_err_msg(rank, r, reqs)); } else { - C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_( - &handles[r], - (void*)&(imported_handles[r]), - CU_MEM_HANDLE_TYPE_FABRIC)); + C10_CUDA_DRIVER_CHECK_MSG( + driver_api->cuMemImportFromShareableHandle_( + &handles[r], + (void*)&(imported_handles[r]), + CU_MEM_HANDLE_TYPE_FABRIC), + import_err_msg(rank, r, reqs)); } #elif defined(USE_ROCM) C10_CUDA_CHECK(hipMemImportFromShareableHandle( @@ -785,7 +883,8 @@ c10::intrusive_ptr make_peer_alloc_info( block->buffer_size, block->device_idx, rank, - world_size); + world_size, + group_name); return pai; } @@ -888,6 +987,10 @@ c10::intrusive_ptr CUDASymmetricMemoryAllocator::find_block_covering(void return alloc_it->second; } +bool CUDASymmetricMemoryAllocator::has_allocation(void* ptr) { + return find_block(ptr) != nullptr; +} + struct RegisterCUDASymmetricMemoryAllocator { RegisterCUDASymmetricMemoryAllocator() { auto allocator = c10::make_intrusive(); diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp index 0ed09c6cc5dbf..5636bca810f6a 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.hpp @@ -84,7 +84,8 @@ class CUDAPeerAllocInfo : public c10::intrusive_ptr_target { size_t buffer_size, int local_device_idx, int rank, - int world_size); + int world_size, + std::string group_name); private: std::vector> alloc_refs_; @@ -98,6 +99,7 @@ class CUDAPeerAllocInfo : public c10::intrusive_ptr_target { int world_size_; void** buffers_dev_; void** signal_pads_dev_; + std::string group_name_; friend class CUDASymmetricMemory; }; @@ -135,6 +137,7 @@ class CUDASymmetricMemoryAllocator : public SymmetricMemoryAllocator { void* ptr, const std::optional& group_name) override; bool has_multicast_support(int device_idx) override; + bool has_allocation(void* ptr) override; c10::DeviceType supported_device_type() override; std::string name() override; diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu index 0755f75df00df..da8c1611f13a0 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu @@ -17,6 +17,7 @@ #include #include +#include #include #include @@ -163,6 +164,20 @@ at::Tensor multimem_all_reduce_( const at::Tensor& input, std::string reduce_op, std::string group_name) { + auto pg = c10d::resolve_process_group(group_name); + RECORD_PARAM_COMMS( + static_cast(0), + std::make_tuple(pg->getGroupName(), pg->getGroupDesc()), + pg->getRank(), + "symm_mem::multimem_all_reduce", + input.numel(), + input.numel(), + input.scalar_type(), + std::vector(), + std::vector(), + -1, + -1, + pg->getSize()); TORCH_CHECK( input.is_contiguous(), "multimem_all_reduce_: input must be contiguous."); TORCH_CHECK( @@ -247,6 +262,20 @@ at::Tensor multimem_one_shot_reduce_out( int64_t root, std::string group_name, at::Tensor out) { + auto pg = c10d::resolve_process_group(group_name); + RECORD_PARAM_COMMS( + static_cast(0), + std::make_tuple(pg->getGroupName(), pg->getGroupDesc()), + pg->getRank(), + "symm_mem::multimem_one_shot_reduce", + input.numel(), + out.numel(), + input.scalar_type(), + std::vector(), + std::vector(), + -1, + -1, + pg->getSize()); TORCH_CHECK( input.is_contiguous(), "multimem_one_shot_reduce: input must be contiguous."); @@ -361,6 +390,20 @@ at::Tensor multimem_all_gather_out( const at::Tensor& input, std::string group_name, at::Tensor out) { + auto pg = c10d::resolve_process_group(group_name); + RECORD_PARAM_COMMS( + static_cast(0), + std::make_tuple(pg->getGroupName(), pg->getGroupDesc()), + pg->getRank(), + "symm_mem::multimem_all_gather", + input.numel(), + out.numel(), + input.scalar_type(), + std::vector(), + std::vector(), + -1, + -1, + pg->getSize()); auto symm_mem = c10d::symmetric_memory::rendezvous(out, group_name); TORCH_CHECK( symm_mem != nullptr, @@ -475,6 +518,20 @@ at::Tensor one_shot_all_reduce_out_impl( std::string reduce_op, std::string group_name, at::Tensor out) { + auto pg = c10d::resolve_process_group(group_name); + RECORD_PARAM_COMMS( + static_cast(0), + std::make_tuple(pg->getGroupName(), pg->getGroupDesc()), + pg->getRank(), + "symm_mem::one_shot_all_reduce", + input.numel(), + out.numel(), + input.scalar_type(), + std::vector(), + std::vector(), + -1, + -1, + pg->getSize()); TORCH_CHECK( input.is_contiguous(), "one_shot_all_reduce: input must be contiguous."); TORCH_CHECK( @@ -732,6 +789,20 @@ at::Tensor two_shot_all_reduce_impl( std::optional output, std::string reduce_op, std::string group_name) { + auto pg = c10d::resolve_process_group(group_name); + RECORD_PARAM_COMMS( + static_cast(0), + std::make_tuple(pg->getGroupName(), pg->getGroupDesc()), + pg->getRank(), + "symm_mem::two_shot_all_reduce", + input.numel(), + input.numel(), + input.scalar_type(), + std::vector(), + std::vector(), + -1, + -1, + pg->getSize()); TORCH_CHECK( input.is_contiguous(), "two_shot_all_reduce: input must be contiguous."); TORCH_CHECK( @@ -856,6 +927,20 @@ at::Tensor reduce_scatter_out( std::string group_name, bool split_last_dim, at::Tensor output) { + auto pg = c10d::resolve_process_group(group_name); + RECORD_PARAM_COMMS( + static_cast(0), + std::make_tuple(pg->getGroupName(), pg->getGroupDesc()), + pg->getRank(), + "symm_mem::reduce_scatter", + input.numel(), + output.numel(), + input.scalar_type(), + std::vector(), + std::vector(), + -1, + -1, + pg->getSize()); TORCH_CHECK( input.is_contiguous(), "reduce_scatter: input must be contiguous."); TORCH_CHECK( diff --git a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp index 7c255fa283ec9..cec3bfc8d05d6 100644 --- a/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryTypes.hpp @@ -2,6 +2,10 @@ #include #include +#include +#include + +#include #if defined(USE_ROCM) #include @@ -9,6 +13,14 @@ namespace c10d::symmetric_memory { +// Key type for the symmetric memory map. `void*` for tensor storage ptr, +// `std::string` for group name. +using SymmMemKey = std::pair; +// Hash function for the symmetric memory map. c10::hash has a std::pair +// specialization (line 323-329 of hash.h) that delegates to the tuple hasher +// which combines hashes of each element. +using SymmMemKeyHash = c10::hash; + // Covers NVL72 constexpr int max_cuda_p2p_domain_size = 72; // Maximum number of channels diff --git a/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu b/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu index 675b0d096ee71..02d3e35d19924 100644 --- a/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -61,32 +62,16 @@ struct NCCLAllocation { namespace { -struct NCCLSymmMemKey { - void* ptr; - std::string group_name; - - bool operator==(const NCCLSymmMemKey& other) const noexcept { - return ptr == other.ptr && group_name == other.group_name; - } -}; - -struct NCCLSymmMemKeyHash { - size_t operator()(const NCCLSymmMemKey& key) const { - auto seed = c10::hash_combine(0, std::hash{}(key.ptr)); - return c10::hash_combine(seed, std::hash{}(key.group_name)); - } -}; - // Base allocation ptr -> owning NCCL allocation metadata. using NCCLAllocMap = ska::flat_hash_map>; // (Tensor storage/data ptr, group name) -> cached SymmetricMemory handle. using NCCLSymmMemMap = ska::flat_hash_map< - NCCLSymmMemKey, + SymmMemKey, c10::intrusive_ptr, - NCCLSymmMemKeyHash>; + SymmMemKeyHash>; // Base allocation ptr -> cached `(tensor ptr, group)` keys derived from it. using NCCLSymmMemKeysByAlloc = - ska::flat_hash_map>; + ska::flat_hash_map>; bool pointer_in_allocation(void* ptr, const NCCLAllocation& allocation) { auto ptr_int = reinterpret_cast(ptr); @@ -507,7 +492,7 @@ class NCCLSymmetricMemoryAllocator : public SymmetricMemoryAllocator { const std::optional& group_name) override { TORCH_CHECK(group_name.has_value(), "group_name must be provided"); NCCLAllocation* allocation; - NCCLSymmMemKey key{ptr, *group_name}; + SymmMemKey key{ptr, *group_name}; { std::lock_guard lock(mutex_); auto it = symm_mems_.find(key); @@ -563,6 +548,11 @@ class NCCLSymmetricMemoryAllocator : public SymmetricMemoryAllocator { return device_has_multicast_support(device_idx); } + bool has_allocation(void* ptr) override { + std::lock_guard lock(mutex_); + return find_allocation_covering(ptr, allocations_) != allocations_.end(); + } + c10::DeviceType supported_device_type() override { return c10::DeviceType::CUDA; } diff --git a/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cpp b/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cpp index fa9b509baa754..7a3e853197030 100644 --- a/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -10,6 +11,7 @@ #include #include #include +#include #include @@ -17,8 +19,10 @@ // include only the nvshmem host library headers: // #include // It translates into the following two lines: +#if !defined(USE_ROCM) #include #include +#endif // For maximum compatibility, we use the "host/" style for now. namespace c10d { @@ -146,12 +150,14 @@ class NVSHMEMPeerAllocInfo : public c10::intrusive_ptr_target { arr_size, cudaMemcpyHostToDevice)); +#if !defined(USE_ROCM) // Multi-cast is not supported on ROCm yet // Initialize multicast address // On unsupported platforms, this API returns a nullptr. auto device = c10::Device(c10::DeviceType::CUDA, allocation->device_idx); auto& team_manager = c10d::nvshmem_extension::TeamManager::get(device); auto team = team_manager.get_team(group_name, rank_to_global_rank); mc_addr_ = nvshmemx_mc_ptr(team, base_ptr_); +#endif } private: @@ -347,9 +353,11 @@ static void initialize_nvshmem_with_store( is_initialized = true; // Print version +#if !defined(USE_ROCM) int major, minor; ::nvshmem_info_get_version(&major, &minor); LOG(INFO) << "NVSHMEM is available, version: " << major << '.' << minor; +#endif } class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { @@ -401,7 +409,7 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { TORCH_CHECK(group_name.has_value()); std::lock_guard lock(mutex_); { - auto it = symm_mems_.find(std::make_tuple(ptr, *group_name)); + auto it = symm_mems_.find(SymmMemKey{ptr, *group_name}); if (it != symm_mems_.end()) { return it->second; } @@ -426,7 +434,7 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { // Search again using allocation base ptr (which is the key we use for // caching, see below) - auto it = symm_mems_.find(std::make_tuple(allocation->ptr, *group_name)); + auto it = symm_mems_.find(SymmMemKey{allocation->ptr, *group_name}); c10::intrusive_ptr symm_mem; if (it != symm_mems_.end()) { // Base allocation has been rendezvoused @@ -438,7 +446,7 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { } // Cache rendezvous using allocation's base address as key - symm_mems_[std::make_tuple(allocation->ptr, *group_name)] = symm_mem; + symm_mems_[SymmMemKey{allocation->ptr, *group_name}] = symm_mem; // TODO: change the `ptr` below to `tensor.data_ptr()` when adding support // for user slice/view operations. For MemPool support, @@ -459,6 +467,18 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { return device_has_multicast_support(device_idx); } + bool has_allocation(void* ptr) override { + std::lock_guard lock(mutex_); + auto alloc_it = std::find_if( + allocations_.begin(), allocations_.end(), [&](const auto& pair) { + auto ptr_int = reinterpret_cast(ptr); + auto base_ptr = reinterpret_cast(pair.second->ptr); + return ptr_int >= base_ptr && + ptr_int < base_ptr + pair.second->buffer_size; + }); + return alloc_it != allocations_.end(); + } + c10::DeviceType supported_device_type() override { return c10::DeviceType::CUDA; } @@ -470,9 +490,10 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator { private: std::mutex mutex_; std::unordered_map> allocations_; - std::map< - std::tuple, - c10::intrusive_ptr> + ska::flat_hash_map< + SymmMemKey, + c10::intrusive_ptr, + SymmMemKeyHash> symm_mems_; }; diff --git a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp index 860b230cb04d4..279c578547e15 100644 --- a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.cpp @@ -295,6 +295,14 @@ TORCH_API c10::intrusive_ptr rendezvous( return allocator->rendezvous(tensor.storage().data_ptr().get(), group_name); } +TORCH_API bool is_symm_mem_tensor(const at::Tensor& tensor) { + if (!has_allocator(tensor.device().type())) { + return false; + } + auto allocator = get_allocator(tensor.device().type()); + return allocator->has_allocation(tensor.storage().data_ptr().get()); +} + TORCH_API bool has_multicast_support( c10::DeviceType device_type, int device_idx) { @@ -551,6 +559,8 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) { m.def("nccl_get(Tensor(a!) tensor, int peer) -> ()"); m.def("nccl_wait_for_signal(Tensor sigpad, int signal) -> ()"); m.def("nccl_put_with_signal(Tensor(a) tensor, int signal, int peer) -> ()"); + m.def( + "nccl_reduce_scatter_offset(Tensor input, Tensor(a!)[] out, str group_name, int dim, int[]? offsets=None, int[]? dst_ranks=None, str red_op='sum') -> ()"); m.def( "nvshmem_all_to_all(Tensor input, Tensor(a!) out, str group_name) -> Tensor(a!)"); m.def( diff --git a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp index 4156e24be1f26..a509568381e33 100644 --- a/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp @@ -113,6 +113,9 @@ class SymmetricMemoryAllocator : public c10::intrusive_ptr_target { virtual bool has_multicast_support(int device_idx) = 0; virtual c10::DeviceType supported_device_type() = 0; virtual std::string name() = 0; + virtual bool has_allocation(void* ptr) { + return false; + } }; C10_EXPORT bool is_finalizing(); @@ -191,6 +194,8 @@ TORCH_API bool has_multicast_support( c10::DeviceType device_type, int device_idx); +TORCH_API bool is_symm_mem_tensor(const at::Tensor& tensor); + TORCH_API void set_backend(const std::string& name); TORCH_API std::optional get_backend(c10::Device device); diff --git a/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp b/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp index 1acc7b831043a..a84c0dcebebfe 100644 --- a/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp +++ b/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cpp @@ -3,7 +3,10 @@ #include #if defined(USE_ROCM) -#include +#include +#include +#include +#include #endif namespace c10d::intra_node_comm { @@ -35,23 +38,126 @@ static NvlMesh getNvlMesh(const std::vector& rankToDeviceIdx) { } return nvlMesh; #else + // Load libamd_smi at runtime to avoid linking it into torch_hip (double-load + // with Python amdsmi causes bus errors). Types/constants from amdsmi.h only. + struct AmdsmiApi { + amdsmi_status_t (*init)(uint64_t); + amdsmi_status_t (*get_socket_handles)(uint32_t*, amdsmi_socket_handle*); + amdsmi_status_t (*get_processor_handles)( + amdsmi_socket_handle, + uint32_t*, + amdsmi_processor_handle*); + amdsmi_status_t (*is_P2P_accessible)( + amdsmi_processor_handle, + amdsmi_processor_handle, + bool*); + }; + static void* amdsmi_handle = nullptr; + static AmdsmiApi amdsmi = {}; + static bool amdsmi_resolved = false; + + if (!amdsmi_resolved) { + amdsmi_resolved = true; + const char* rocm = std::getenv("ROCM_PATH"); + std::string path = + rocm ? std::string(rocm) + "/lib/libamd_smi.so" : "libamd_smi.so"; + amdsmi_handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL); + if (!amdsmi_handle) { + amdsmi_handle = dlopen("libamd_smi.so", RTLD_NOW | RTLD_LOCAL); + } + if (!amdsmi_handle) { + LOG(ERROR) << "IntraNodeComm:: getNvlMesh: dlopen libamd_smi.so failed: " + << dlerror(); + return {}; + } + amdsmi.init = reinterpret_cast( + dlsym(amdsmi_handle, "amdsmi_init")); + amdsmi.get_socket_handles = + reinterpret_cast( + dlsym(amdsmi_handle, "amdsmi_get_socket_handles")); + amdsmi.get_processor_handles = + reinterpret_cast( + dlsym(amdsmi_handle, "amdsmi_get_processor_handles")); + amdsmi.is_P2P_accessible = + reinterpret_cast( + dlsym(amdsmi_handle, "amdsmi_is_P2P_accessible")); + if (!amdsmi.init || !amdsmi.get_socket_handles || + !amdsmi.get_processor_handles || !amdsmi.is_P2P_accessible) { + LOG(ERROR) << "IntraNodeComm:: getNvlMesh: dlsym amdsmi failed"; + return {}; + } + } + NvlMesh nvlMesh = {}; const auto worldSize = rankToDeviceIdx.size(); - // For each device, loop over devices connected to it + + uint32_t socket_count = 0; + amdsmi_status_t ret = amdsmi.get_socket_handles(&socket_count, nullptr); + if (ret == AMDSMI_STATUS_NOT_INIT) { + ret = amdsmi.init(AMDSMI_INIT_AMD_GPUS); + if (ret != AMDSMI_STATUS_SUCCESS) { + LOG(ERROR) << "IntraNodeComm:: getNvlMesh: amdsmi_init failed, ret=" + << static_cast(ret); + return {}; + } + socket_count = 0; + ret = amdsmi.get_socket_handles(&socket_count, nullptr); + } + if (ret != AMDSMI_STATUS_SUCCESS) { + LOG(ERROR) + << "IntraNodeComm:: getNvlMesh: amdsmi_get_socket_handles failed, ret=" + << static_cast(ret); + return {}; + } + + std::vector socket_handles(socket_count); + ret = amdsmi.get_socket_handles(&socket_count, &socket_handles[0]); + if (ret != AMDSMI_STATUS_SUCCESS) { + LOG(ERROR) + << "IntraNodeComm:: getNvlMesh: amdsmi_get_socket_handles (buffer) failed, ret=" + << static_cast(ret); + return {}; + } + + std::vector processor_handles; + for (size_t i = 0; i < socket_count; ++i) { + uint32_t device_count = 0; + ret = + amdsmi.get_processor_handles(socket_handles[i], &device_count, nullptr); + if (ret != AMDSMI_STATUS_SUCCESS) { + LOG(ERROR) + << "IntraNodeComm:: getNvlMesh: amdsmi_get_processor_handles (count) failed, ret=" + << static_cast(ret); + return {}; + } + std::vector _processor_handles(device_count); + ret = amdsmi.get_processor_handles( + socket_handles[i], &device_count, &_processor_handles[0]); + if (ret != AMDSMI_STATUS_SUCCESS) { + LOG(ERROR) + << "IntraNodeComm:: getNvlMesh: amdsmi_get_processor_handles (buffer) failed, ret=" + << static_cast(ret); + return {}; + } + processor_handles.insert( + processor_handles.end(), + _processor_handles.begin(), + _processor_handles.end()); + } + for (size_t idx = 0; idx < worldSize; ++idx) { for (size_t link = 0; link < kMaxDevices; ++link) { if (idx == link) continue; - bool conn = false; - auto ret = rsmi_is_P2P_accessible(idx, link, &conn); - if (ret != RSMI_STATUS_SUCCESS) { + ret = amdsmi.is_P2P_accessible( + processor_handles[idx], processor_handles[link], &conn); + if (ret != AMDSMI_STATUS_SUCCESS) { LOG(ERROR) - << "IntraNodeComm: getNvlMesh: rsmi_is_P2P_accessible returned error ret=" - << ret; + << "IntraNodeComm: getNvlMesh: amdsmi_is_P2P_accessible failed, ret=" + << static_cast(ret); return {}; } - if (conn) { nvlMesh[idx][link] += 1; } @@ -88,11 +194,13 @@ IntraNodeComm::IntraNodeComm( c10::intrusive_ptr store, size_t rank, size_t worldSize, - std::optional bufferSize) + std::optional bufferSize, + std::string groupName) : store_(std::move(store)), rank_(rank), worldSize_(worldSize), - bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize) {} + bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize), + groupName_(std::move(groupName)) {} IntraNodeComm::~IntraNodeComm() { if (!isInitialized_) { @@ -171,14 +279,6 @@ bool IntraNodeComm::rendezvous() { gethostname(devInfo.hostname, sizeof(devInfo.hostname)); devInfo.deviceIdx = deviceIdx_; -#if defined(USE_ROCM) - auto ret = rsmi_init(0); - if (ret != RSMI_STATUS_SUCCESS) { - LOG(ERROR) << "IntraNodeComm:: rendezvous failed in rsmi_init, ret=" << ret; - return false; - } -#endif - auto peerDevInfos = storeAllGather(store_, "handshake-0", rank_, worldSize_, devInfo); @@ -214,11 +314,13 @@ bool IntraNodeComm::rendezvous() { return false; } - auto groupName = "IntraNodeComm" + std::to_string(intraNodeCommIdx++); + const std::string name = groupName_.empty() + ? "IntraNodeComm" + std::to_string(intraNodeCommIdx++) + : groupName_; set_group_info( - groupName, static_cast(rank_), static_cast(worldSize_), store_); + name, static_cast(rank_), static_cast(worldSize_), store_); auto allocator = get_allocator(c10::DeviceType::CUDA); - symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx_, groupName); + symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx_, name); symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_, std::nullopt); isInitialized_ = true; return true; diff --git a/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu b/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu index 7e5afa425335c..1a3fb6c5c0a25 100644 --- a/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu +++ b/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu @@ -23,7 +23,7 @@ static void checkInput(const at::Tensor& input, int deviceIdx) { } bool isIntraNodeCommSupported() { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) return false; #else return true; @@ -46,7 +46,7 @@ at::Tensor IntraNodeComm::oneShotAllReduce( at::TensorOptions().dtype(input.dtype()).device(input.device())); symmMemTensor.copy_(input); - op.call(symmMemTensor, "sum", "", input); + op.call(symmMemTensor, "sum", groupName_, input); return input; } @@ -65,7 +65,7 @@ at::Tensor IntraNodeComm::twoShotAllReduce( at::TensorOptions().dtype(input.dtype()).device(input.device())); symmMemTensor.copy_(input); - op.call(symmMemTensor, "sum", ""); + op.call(symmMemTensor, "sum", groupName_); input.copy_(symmMemTensor); return input; } diff --git a/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.hpp b/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.hpp index 7b5e8ff999c5d..8b2a425c47ce2 100644 --- a/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/intra_node_comm.hpp @@ -33,7 +33,8 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { c10::intrusive_ptr store, size_t rank, size_t worldSize, - std::optional bufferSize = std::nullopt); + std::optional bufferSize = std::nullopt, + std::string groupName = ""); ~IntraNodeComm() override; @@ -67,6 +68,7 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { size_t rank_; size_t worldSize_; size_t bufferSize_; + std::string groupName_; /** * Members initialized after rendezvous diff --git a/torch/csrc/distributed/c10d/symm_mem/macros.hpp b/torch/csrc/distributed/c10d/symm_mem/macros.hpp new file mode 100644 index 0000000000000..18bcc0559b8c4 --- /dev/null +++ b/torch/csrc/distributed/c10d/symm_mem/macros.hpp @@ -0,0 +1,24 @@ +// Macros for type dispatch and common utilities for symmetric memory +#pragma once + +#include + +// Convert ATen floating point types to NV floating point types +// at::kBFloat16 -> __nv_bfloat16 +// at::kHalf -> __half +// Float is the same. + +#define AT_DISPATCH_CASE_CONVERT(enum_type, scalar_type, ...) \ + case enum_type: { \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ + using scalar_t = scalar_type; \ + return __VA_ARGS__(); \ + } + +#define AT_DISPATCH_NV_FLOATS(scalar_type, name, ...) \ + AT_DISPATCH_SWITCH( \ + scalar_type, \ + name, \ + AT_DISPATCH_CASE_CONVERT(at::kBFloat16, __nv_bfloat16, __VA_ARGS__); \ + AT_DISPATCH_CASE_CONVERT(at::kHalf, __half, __VA_ARGS__); \ + AT_DISPATCH_CASE(at::kFloat, __VA_ARGS__)); diff --git a/torch/csrc/distributed/c10d/symm_mem/nccl_dev_cap.hpp b/torch/csrc/distributed/c10d/symm_mem/nccl_dev_cap.hpp index fbf0cf6b50c3e..5c2c4d235c5d7 100644 --- a/torch/csrc/distributed/c10d/symm_mem/nccl_dev_cap.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/nccl_dev_cap.hpp @@ -19,4 +19,8 @@ #if NCCL_VERSION_CODE >= NCCL_VERSION(2, 29, 0) #define NCCL_HAS_ONE_SIDED_API #endif + +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 29, 7) +#define NCCL_DEVICE_HAS_REDUCE_COPY +#endif #endif // USE_NCCL diff --git a/torch/csrc/distributed/c10d/symm_mem/nccl_extension.cu b/torch/csrc/distributed/c10d/symm_mem/nccl_extension.cu index 3ad399a919990..2f37255bc24fb 100644 --- a/torch/csrc/distributed/c10d/symm_mem/nccl_extension.cu +++ b/torch/csrc/distributed/c10d/symm_mem/nccl_extension.cu @@ -408,6 +408,7 @@ TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) { // API that uses internal signal mechanism and accepts handle m.impl("nccl_put_signal", torch::CppFunction::makeFromBoxedFunction<&nccl_put_signal_boxed>()); + m.impl("nccl_reduce_scatter_offset", c10d::nccl_extension::nccl_reduce_scatter_offset); } // Use CompositeExplicitAutograd as key since ops do not accept tensor as input diff --git a/torch/csrc/distributed/c10d/symm_mem/nccl_extension.hpp b/torch/csrc/distributed/c10d/symm_mem/nccl_extension.hpp index d0152e34964ac..5d779f1564726 100644 --- a/torch/csrc/distributed/c10d/symm_mem/nccl_extension.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/nccl_extension.hpp @@ -18,4 +18,16 @@ TORCH_API void nccl_put_with_signal( int64_t signal, int64_t peer); +// Simultaneously reduce N blocks of a 2-D input tensor from a shared symmetric +// memory buffer, routing each to a specific destination rank. Blocks are +// described by inclusive-prefix-sum offsets along `dim` (0 or 1); all blocks +// must have equal size. +TORCH_API void nccl_reduce_scatter_offset( + const at::Tensor& input, + at::TensorList out, + const std::string& group_name, + int64_t dim, + std::optional offsets, + std::optional dst_ranks, + const std::string& red_op); } // namespace c10d::nccl_extension diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu index ead533f0da3d4..b766443612578 100644 --- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu +++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -912,20 +913,6 @@ __global__ void tile_reduce_kernel( #endif } -#define AT_DISPATCH_CASE_CONVERT(enum_type, scalar_type, ...) \ - case enum_type: { \ - AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ - using scalar_t = scalar_type; \ - return __VA_ARGS__(); \ - } - -#define AT_DISPATCH_NVSHMEM_FLOATS(scalar_type, name, ...) \ - AT_DISPATCH_SWITCH( \ - scalar_type, name, \ - AT_DISPATCH_CASE_CONVERT(at::kBFloat16, __nv_bfloat16, __VA_ARGS__); \ - AT_DISPATCH_CASE_CONVERT(at::kHalf, __half, __VA_ARGS__); \ - AT_DISPATCH_CASE(at::kFloat, __VA_ARGS__)); - void tile_reduce( at::Tensor& in_tile, at::Tensor& out_tile, @@ -974,7 +961,7 @@ void tile_reduce( &root, &teams_dev}; - AT_DISPATCH_NVSHMEM_FLOATS(in_tile.scalar_type(), "tile_reduce", [&]() { + AT_DISPATCH_NV_FLOATS(in_tile.scalar_type(), "tile_reduce", [&]() { nvshmemx_collective_launch( (const void*)tile_reduce_kernel, dim3(nblocks), @@ -1057,7 +1044,7 @@ void multi_root_tile_reduce( &root, &teams_dev}; - AT_DISPATCH_NVSHMEM_FLOATS(out_tile.scalar_type(), "multi_root_tile_reduce", [&]() { + AT_DISPATCH_NV_FLOATS(out_tile.scalar_type(), "multi_root_tile_reduce", [&]() { nvshmemx_collective_launch( (const void*)tile_reduce_kernel, dim3(nblocks), diff --git a/torch/csrc/distributed/c10d/symm_mem/nvshmem_team_manager.hpp b/torch/csrc/distributed/c10d/symm_mem/nvshmem_team_manager.hpp index d31358a4002ea..947795c9350f3 100644 --- a/torch/csrc/distributed/c10d/symm_mem/nvshmem_team_manager.hpp +++ b/torch/csrc/distributed/c10d/symm_mem/nvshmem_team_manager.hpp @@ -12,8 +12,12 @@ // include only the nvshmem host library headers: // #include // It translates into the following two lines: +#if !defined(USE_ROCM) #include #include +#else +#include +#endif // For maximum compatibility, we use the "host/" style for now. namespace c10d::nvshmem_extension { diff --git a/torch/csrc/distributed/c10d/symm_mem/ops/nccl_reduce_scatter_offset.cu b/torch/csrc/distributed/c10d/symm_mem/ops/nccl_reduce_scatter_offset.cu new file mode 100644 index 0000000000000..6cd13189c36bb --- /dev/null +++ b/torch/csrc/distributed/c10d/symm_mem/ops/nccl_reduce_scatter_offset.cu @@ -0,0 +1,380 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Simultaneously reduce N blocks of a 2-D input tensor from a symmetric memory +// buffer, routing each block to a specific destination rank (dst_ranks[i]). +// Only the destination rank writes the reduced value to a contiguous output +// tensor, with the same shape as the owned block. +// +// The `dim` argument controls which dimension is sharded (0 or 1): +// dim=1 (column sharding): each block spans input[:, offsets[i-1]:offsets[i]] +// dim=0 (row sharding): each block spans input[offsets[i-1]:offsets[i], :] +// +// Blocks are described by inclusive-prefix-sum offsets along `dim`. +// For each j, out[j] must have the same shape across all ranks (i.e. the j-th +// owned block on every rank must have equal size); different j's may differ. +// +// If offsets is nullopt, input.size(dim) is divided equally into group_size blocks. +// If dst_ranks is nullopt, blocks are distributed round-robin across ranks. +// +// Ownership must be balanced: every rank must own the same number of blocks +// (N % group_size == 0 and dst_ranks distributes evenly). + +namespace c10d::nccl_extension { + +using namespace c10d::symmetric_memory; + +// Kernel requires device-side API: ncclLsaReduceSum. +#ifdef NCCL_DEVICE_HAS_REDUCE_COPY + +// Naming conventions in this file: +// "BLOCK" means tensor block (as opposed to CUDA block); +// "CTA" means CUDA block; +// "RS" means Reduce Scatter; +// "slot" means which tensor block a CTA is assigned to. + +constexpr int RS_MAX_BLOCKS = 64; // max total blocks being scattered (N) +constexpr int RS_MAX_BLOCKS_PER_RANK = 16; // max blocks owned by a single rank +constexpr int RS_MAX_CTAS_PER_BLOCK = 16; // max CTAs assigned to one block +// Threads per CTA; defaults to a medium value to fit medium-width blocks. +constexpr int RS_THREADS_PER_CTA = 128; +// Total LSA barrier slots needed: one per CTA across all owned blocks. +constexpr int RS_MAX_CTA_COUNT = (RS_MAX_BLOCKS_PER_RANK * RS_MAX_CTAS_PER_BLOCK); + +// Per-slot data passed to the kernel in a single struct to avoid multiple +// kernel arguments. Indexed by owned slot (0..n_owned-1). +struct ReduceScatterOffsetsInfo { + size_t byte_offsets[RS_MAX_BLOCKS_PER_RANK]; // byte offset into the NCCL window + void* dst_ptrs[RS_MAX_BLOCKS_PER_RANK]; // output pointer (contiguous) + uint16_t dst_block_size[RS_MAX_BLOCKS_PER_RANK]; // per-slot size along the sharding dim + uint16_t ctas_offset[RS_MAX_BLOCKS_PER_RANK]; // inclusive prefix sum of per-slot CTA counts + uint8_t cta_slot[RS_MAX_CTA_COUNT]; // slot index for each flat CTA + int n_owned; +}; + +// Grid: 1D, total_ctas = sum of per-slot CTA counts (info.ctas_offset[n_owned]). +// Each CTA belongs to one slot; blockIdx.x is the flat CTA index used as the +// LSA barrier index, ensuring all ranks assign the same index to each logical +// (slot, local_block) pair (because owned_sizes[j] is consistent across ranks). +// +// UseMultimem=true: uses ncclMultimemReduceSum for hardware reduction via +// NVLink multicast; requires devcomm created with lsaMultimem=true. +// UseMultimem=false: uses ncclLsaReduceSum (software reduce via LSA reads). +template +__global__ void reduce_scatter_offset_kernel( + ncclWindow_t window, + ReduceScatterOffsetsInfo info, + int fixed_dim_size, // input.size(1-dim): constant across all slots + bool col_sharded, // true when dim==1 + int64_t outer_stride, // row stride of the input buffer (in elements) + ncclDevComm devComm) { + // cta_slot maps the flat CTA index to its owned slot. + const int slot = info.cta_slot[blockIdx.x]; + // ctas_offset is an inclusive prefix sum, so slot_start is the flat index + // of the first CTA assigned to this slot. + const int slot_start = slot > 0 ? info.ctas_offset[slot - 1] : 0; + // local_block is this CTA's position within its slot (0-based row tile index). + const int local_block = static_cast(blockIdx.x) - slot_start; + // Number of CTAs sharing this slot; used as the row-loop stride. + const int ctas_for_slot = info.ctas_offset[slot] - slot_start; + const ncclCoopCta coop{}; + + // One LSA barrier per CTA; all ranks must call both syncs unconditionally. + ncclLsaBarrierSession bar{ + coop, + devComm, + ncclTeamLsa(devComm), + devComm.lsaBarrier, + blockIdx.x}; + // Acquire: wait until all peers have written their data into the window. + bar.sync(coop, cuda::memory_order_acquire); + + const size_t base_byte_offset = info.byte_offsets[slot]; // start of this block in the window + T* dst_base = reinterpret_cast(info.dst_ptrs[slot]); // start of out[slot] + const int block_size = info.dst_block_size[slot]; // size along the sharding dim + const int rows = col_sharded ? fixed_dim_size : block_size; + const int cols = col_sharded ? block_size : fixed_dim_size; + + // Each CTA handles a strided subset of rows; the reduce reads from all peers + // and writes cols elements starting at dst_row. + for (int row = local_block; row < rows; row += ctas_for_slot) { + const size_t row_offset = + base_byte_offset + + static_cast(row * outer_stride) * sizeof(T); + T* dst_row = dst_base + row * cols; + if constexpr (UseMultimem) { + ncclMultimemReduceSum( + coop, window, row_offset, dst_row, cols, devComm.lsaMultimem); + } else { + ncclLsaReduceSum(coop, window, row_offset, dst_row, cols, devComm); + } + } + + // Release: signal peers that we are done reading window memory. + bar.sync(coop, cuda::memory_order_release); +} + +#endif // NCCL_DEVICE_HAS_REDUCE_COPY + +// Host entry point. Validates arguments, resolves defaults, builds the +// per-slot ReduceScatterOffsetsInfo, and launches the kernel. +// See file-level comment for semantics. +void nccl_reduce_scatter_offset( + const at::Tensor& input, + at::TensorList out, + const std::string& group_name, + int64_t dim, + std::optional offsets, + std::optional dst_ranks, + const std::string& red_op) { +#ifdef NCCL_DEVICE_HAS_REDUCE_COPY + TORCH_CHECK( + red_op == "sum", + "nccl_reduce_scatter_offset: only red_op='sum' is supported, got '", red_op, "'"); + + TORCH_CHECK( + input.dim() == 2, + "nccl_reduce_scatter_offset: input must be 2-D"); + TORCH_CHECK( + dim == 0 || dim == 1, + "nccl_reduce_scatter_offset: dim must be 0 or 1, got ", dim); + TORCH_CHECK( + input.stride(-1) == 1, + "nccl_reduce_scatter_offset: innermost dimension must be contiguous " + "(stride[-1] == 1)"); + + // rendezvous retrieves the symmetric memory handle; the tensor must have + // been allocated via empty_strided_p2p with the NCCL backend. + auto symm_mem = c10d::symmetric_memory::rendezvous(input, group_name); + TORCH_CHECK( + symm_mem != nullptr, + "nccl_reduce_scatter_offset: input must be allocated via NCCL symmetric " + "memory (use empty_strided_p2p with NCCL backend)"); + + auto* nccl_hdl = dynamic_cast(symm_mem.get()); + TORCH_CHECK( + nccl_hdl != nullptr, + "nccl_reduce_scatter_offset: requires NCCL symmetric memory backend"); + + c10::cuda::CUDAGuard guard(input.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + auto device = input.device(); + + auto& manager = c10d::symmetric_memory::NCCLDevCommManager::get(device); + // Get the host-side communicator. + ncclComm_t comm = manager.get_comm(group_name); + + const bool use_multimem = nccl_hdl->has_multicast_support(); + + // The devcomm is cached per (group, key); create it on first use. + // lsaBarrierCount must cover the maximum number of concurrent CTAs. + // lsaMultimem is set when the allocation has multicast support, so that + // devComm.lsaMultimem is valid for ncclMultimemReduceSum in the kernel. + static constexpr char const kDevcommKey[] = "nccl_reduce_scatter_offset"; + auto devcomm_opt = manager.get_devcomm(group_name, kDevcommKey); + if (!devcomm_opt) { + ncclDevCommRequirements reqs = NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER; + reqs.lsaBarrierCount = RS_MAX_CTA_COUNT; + reqs.lsaMultimem = use_multimem; + ncclDevComm devcomm; + C10D_NCCL_CHECK( + ncclDevCommCreate(comm, &reqs, &devcomm), + "ncclDevCommCreate failed in nccl_reduce_scatter_offset"); + // Cache the device communicator. + devcomm_opt = manager.register_devcomm(group_name, devcomm, kDevcommKey); + } + ncclDevComm& devcomm = devcomm_opt->get(); + + const int my_rank = devcomm.rank; + const int group_size = devcomm.nRanks; + + // Determine n_blocks: from offsets if given, else group_size (equal-size default). + const int n_blocks = offsets.has_value() + ? static_cast(offsets->size()) + : group_size; + TORCH_CHECK( + n_blocks > 0, + "nccl_reduce_scatter_offset: must have at least one block"); + + // Fill dst_ranks default: round-robin across ranks. + std::vector dst_ranks_vec; + at::IntArrayRef effective_dst_ranks; + if (dst_ranks.has_value()) { + effective_dst_ranks = *dst_ranks; + } else { + dst_ranks_vec.resize(n_blocks); + for (int i = 0; i < n_blocks; i++) { + dst_ranks_vec[i] = i % group_size; + } + effective_dst_ranks = at::IntArrayRef(dst_ranks_vec); + } + + // Fill offsets default: divide input.size(dim) equally among group_size blocks. + std::vector offsets_vec; + at::IntArrayRef effective_offsets; + if (offsets.has_value()) { + effective_offsets = *offsets; + TORCH_CHECK( + effective_offsets[n_blocks - 1] <= input.size(dim), + "nccl_reduce_scatter_offset: offsets exceed input size along dim ", dim); + } else { + const int64_t total = input.size(dim); + TORCH_CHECK( + total % group_size == 0, + "nccl_reduce_scatter_offset: input.size(", dim, ")=", total, + " must be divisible by group size (", group_size, ")"); + const int64_t block_size = total / group_size; + offsets_vec.resize(n_blocks); + for (int i = 0; i < n_blocks; i++) { + offsets_vec[i] = (i + 1) * block_size; + } + effective_offsets = at::IntArrayRef(offsets_vec); + } + + TORCH_CHECK( + n_blocks <= RS_MAX_BLOCKS, + "nccl_reduce_scatter_offset: too many blocks: ", n_blocks, + " (max ", RS_MAX_BLOCKS, ")"); + TORCH_CHECK( + static_cast(effective_dst_ranks.size()) == n_blocks, + "nccl_reduce_scatter_offset: dst_ranks.size() must match offsets.size()"); + + const int64_t outer_stride = input.stride(0); + + // Collect owned blocks (in order). + std::vector owned_indices; + for (int i = 0; i < n_blocks; i++) { + if (static_cast(effective_dst_ranks[i]) == my_rank) { + owned_indices.push_back(i); + } + } + const int n_owned = static_cast(owned_indices.size()); + TORCH_CHECK( + n_owned * group_size == n_blocks, + "nccl_reduce_scatter_offset: dst_ranks must distribute blocks evenly " + "(rank owns ", n_owned, "/", n_blocks, ", group_size=", group_size, ")"); + TORCH_CHECK( + n_owned <= RS_MAX_BLOCKS_PER_RANK, + "nccl_reduce_scatter_offset: too many owned blocks: ", n_owned, + " (max ", RS_MAX_BLOCKS_PER_RANK, ")"); + // Balance is guaranteed above (n_owned * group_size == n_blocks), so + // rank_counter[r] never exceeds n_owned during the owned_sizes loop. + + // For each j, out[j] must have the same shape across all ranks. That means + // all blocks that are the j-th owned block on their respective rank must have + // equal size. Different j's may differ in size. + // + // Compute the size for each j by iterating all blocks in order, tracking + // how many blocks each rank has seen so far (= the j-index for that block). + std::vector owned_sizes(n_owned, -1); + { + std::vector rank_counter(group_size, 0); + for (int i = 0; i < n_blocks; i++) { + const int r = static_cast(effective_dst_ranks[i]); + const int j = rank_counter[r]++; + const int64_t sz = + effective_offsets[i] - (i > 0 ? effective_offsets[i - 1] : 0); + if (owned_sizes[j] < 0) { + owned_sizes[j] = sz; + } else { + TORCH_CHECK( + sz == owned_sizes[j], + "nccl_reduce_scatter_offset: all output at position j=", j, + " must have equal size across all ranks"); + } + } + } + + TORCH_CHECK( + static_cast(out.size()) == n_owned, + "nccl_reduce_scatter_offset: out.size() must be ", n_owned); + for (int j = 0; j < n_owned; j++) { + // dim=1: out[j] shape is (input.size(0), owned_sizes[j]) + // dim=0: out[j] shape is (owned_sizes[j], input.size(1)) + const int64_t exp0 = dim == 1 ? input.size(0) : owned_sizes[j]; + const int64_t exp1 = dim == 1 ? owned_sizes[j] : input.size(1); + TORCH_CHECK( + out[j].size(0) == exp0 && out[j].size(1) == exp1, + "nccl_reduce_scatter_offset: out[", j, "] must have shape (", + exp0, ", ", exp1, ")"); + TORCH_CHECK( + out[j].is_contiguous(), + "nccl_reduce_scatter_offset: out[", j, "] must be contiguous"); + TORCH_CHECK( + out[j].scalar_type() == input.scalar_type(), + "nccl_reduce_scatter_offset: out[", j, "] must have the same dtype as input"); + } + + // Per-slot CTA count: sized for each slot independently. owned_sizes[j] is + // consistent across ranks, so ctas_offset is identical on every rank, which + // guarantees all ranks launch the same total CTA count and agree on the + // flat barrier index for each (slot, local_block) pair. + const bool col_sharded = (dim == 1); + const int fixed_dim_size = static_cast(col_sharded ? input.size(0) : input.size(1)); + const int unroll = 4 * 16 / static_cast(input.element_size()); + const int elems_per_cta = RS_THREADS_PER_CTA * unroll; + const size_t window_base_offset = nccl_hdl->get_offset(); + + // Build the per-slot info struct. + // For dim=1: byte_offsets encodes the column-block start within the window. + // For dim=0: byte_offsets encodes the row-block start within the window. + ReduceScatterOffsetsInfo info; + info.n_owned = n_owned; + for (int j = 0; j < n_owned; j++) { + const int i = owned_indices[j]; + const int64_t block_start = (i > 0 ? effective_offsets[i - 1] : 0); + const size_t elem_offset = col_sharded + ? static_cast(input.storage_offset() + block_start) + : static_cast(input.storage_offset()) + + static_cast(block_start) * outer_stride; + info.byte_offsets[j] = window_base_offset + elem_offset * input.element_size(); + info.dst_ptrs[j] = out[j].data_ptr(); + info.dst_block_size[j] = static_cast(owned_sizes[j]); + const int numel_j = static_cast(owned_sizes[j]) * fixed_dim_size; + const int ctas_j = std::max(1, std::min( + (numel_j + elems_per_cta - 1) / elems_per_cta, RS_MAX_CTAS_PER_BLOCK)); + info.ctas_offset[j] = static_cast((j > 0 ? info.ctas_offset[j - 1] : 0) + ctas_j); + const int slot_start = j > 0 ? info.ctas_offset[j - 1] : 0; + for (int k = slot_start; k < info.ctas_offset[j]; ++k) { + info.cta_slot[k] = static_cast(j); + } + } + const int total_ctas = info.ctas_offset[n_owned - 1]; + + auto window = nccl_hdl->get_window(); + TORCH_CHECK(window != nullptr, "nccl_reduce_scatter_offset: NCCL window is null"); + + // Each owned (slot, local_block) pair gets one CTA; the flat CTA index is + // the LSA barrier index. All ranks launch the same total_ctas because + // owned_sizes[j] is consistent, so every rank's ctas_offset is identical. + AT_DISPATCH_NV_FLOATS( + input.scalar_type(), + "nccl_reduce_scatter_offset", + [&]() { + if (use_multimem) { + reduce_scatter_offset_kernel + <<>>( + window, info, fixed_dim_size, col_sharded, outer_stride, devcomm); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + reduce_scatter_offset_kernel + <<>>( + window, info, fixed_dim_size, col_sharded, outer_stride, devcomm); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); +#else + TORCH_CHECK( + false, + "nccl_reduce_scatter_offset requires NCCL >= 2.29.7 with reduce copy support"); +#endif // NCCL_DEVICE_HAS_REDUCE_COPY +} + +} // namespace c10d::nccl_extension diff --git a/torch/csrc/distributed/c10d/symm_mem/rocshmem_extension.cu b/torch/csrc/distributed/c10d/symm_mem/rocshmem_extension.cu new file mode 100644 index 0000000000000..4fc6b70863fbc --- /dev/null +++ b/torch/csrc/distributed/c10d/symm_mem/rocshmem_extension.cu @@ -0,0 +1,904 @@ +// ROCm implementation of the NVSHMEM symmetric memory extension ops. +// +// This is a separate file from nvshmem_extension.cu (rather than a hipified +// copy) for the following reasons: +// +// 1. API differences: NVSHMEM and rocSHMEM device APIs diverge enough that +// #ifdef'ing would be more noise than signal. Key differences include: +// - nvshmemx_collective_launch (grid-wide sync) has no rocSHMEM equivalent; +// ROCm uses regular hip kernel launches with host-side barriers instead. +// - nvshmemx_getmem_nbi_block → rocshmem_getmem_nbi_wg (workgroup scope). +// +// 2. Missing features: rocSHMEM does not yet support tiled communication +// (nvshmemx::Tensor, nvshmemx::tile_sum_reduce_block, etc.), so the +// tile_reduce and multi_root_tile_reduce ops are not included here. +// +// 3. Offset writeback: without grid-wide sync, multi-block kernels cannot +// safely write output offsets in-kernel (race with blocks still reading +// source_offsets). A separate writeOutputOffsets kernel runs after the data +// exchange completes. + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +using namespace rocshmem; +namespace c10d::nvshmem_extension { + +#define THREADS_PER_BLOCK 512 +#define WARP_SIZE 64 + +namespace { + +bool parse_rocshmem_version_ge( + const char* version, + unsigned min_major, + unsigned min_minor, + unsigned min_patch) { + if (version == nullptr) { + return false; + } + char* end = nullptr; + unsigned long major = std::strtoul(version, &end, 10); + if (end == version || *end != '.') { + return false; + } + version = end + 1; + unsigned long minor = std::strtoul(version, &end, 10); + if (end == version || *end != '.') { + return false; + } + version = end + 1; + unsigned long patch = std::strtoul(version, &end, 10); + if (end == version) { + return false; + } + if (major > min_major) { + return true; + } + if (major < min_major) { + return false; + } + if (minor > min_minor) { + return true; + } + if (minor < min_minor) { + return false; + } + return patch >= min_patch; +} + +} // namespace + +extern "C" void rocshmem_init() __attribute__((weak)); + +bool is_nvshmem_available() { + static const bool ok = + parse_rocshmem_version_ge(rocshmem::VERSION, 3, 3, 0); + return ok; +} + +void nvshmemx_cumodule_init(uintptr_t module) { + auto hipmodule = reinterpret_cast(module); + NVSHMEM_CHECK( + rocshmem_hipmodule_init(hipmodule), + "rocshmem_hipmodule_init failed"); +} + +at::Tensor nvshmem_broadcast(at::Tensor& input, const int64_t root, const std::string& group_name) { + auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); + int rank = input_hdl->get_rank(); + void* buffer_ptr = input.mutable_data_ptr(); + auto buffer_size = input.numel() * input.element_size(); + auto& team_manager = TeamManager::get(input.device()); + auto team = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank()); + int team_size = rocshmem_team_n_pes(team); + TORCH_CHECK(root < team_size, "root must be smaller than group size"); + + auto stream = at::cuda::getCurrentCUDAStream(); + rocshmem_broadcastmem_on_stream(team, buffer_ptr, buffer_ptr, buffer_size, root, stream); + return input; +} + +void nvshmem_put(at::Tensor& tensor, const int64_t peer) { + // TODO: support non-contiguous tensors + TORCH_CHECK(tensor.is_contiguous(), + "put op currently supports contiguous tensors only"); + // TODO: rendezvous should remember the group name + auto hdl = c10d::symmetric_memory::rendezvous(tensor, "0"); + auto rank = hdl->get_rank(); + void* buffer_ptr = hdl->get_buffer_ptrs()[rank]; + auto buffer_size = tensor.numel() * tensor.element_size(); + TORCH_CHECK(peer < hdl->get_world_size(), "peer must be smaller than world size"); + + c10::cuda::CUDAGuard guard(tensor.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + rocshmem_putmem_on_stream(buffer_ptr, tensor.data_ptr(), buffer_size, peer, stream); +} + +void nvshmem_wait_for_signal(at::Tensor& sigpad, int64_t signal, int64_t peer) { + c10::cuda::CUDAGuard guard(sigpad.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + rocshmem_signal_wait_until_on_stream(static_cast(sigpad.data_ptr()), ROCSHMEM_CMP_EQ, signal, stream); +} + +void nvshmem_put_with_signal(at::Tensor& tensor, at::Tensor& sigpad, int64_t signal, int64_t peer) { + auto buffer_size = tensor.numel() * tensor.element_size(); + + c10::cuda::CUDAGuard guard(tensor.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + rocshmem_putmem_signal_on_stream( + tensor.mutable_data_ptr(), + tensor.mutable_data_ptr(), + buffer_size, + static_cast(sigpad.mutable_data_ptr()), + signal, + ROCSHMEM_SIGNAL_SET, + peer, + stream); +} + +void nvshmem_get(at::Tensor& tensor, const int64_t peer) { + // TODO: support non-contiguous tensors + TORCH_CHECK(tensor.is_contiguous(), + "get op currently supports contiguous tensors only"); + // TODO: rendezvous should remember the group name + auto hdl = c10d::symmetric_memory::rendezvous(tensor, "0"); + auto rank = hdl->get_rank(); + void* buffer_ptr = hdl->get_buffer_ptrs()[rank]; + auto buffer_size = tensor.numel() * tensor.element_size(); + TORCH_CHECK(peer < hdl->get_world_size(), "peer must be smaller than world size"); + + c10::cuda::CUDAGuard guard(tensor.device()); + auto stream = at::cuda::getCurrentCUDAStream(); + rocshmem_getmem_on_stream(tensor.mutable_data_ptr(), buffer_ptr, buffer_size, peer, stream); +} + +at::Tensor nvshmem_all_to_all( + at::Tensor& input, + at::Tensor& out, + std::string group_name) { + auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); + auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); + int rank = input_hdl->get_rank(); + int world_size = input_hdl->get_world_size(); + auto& team_manager = TeamManager::get(input.device()); + auto team = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank()); + + void* input_ptr = input.data_ptr(); + void* output_ptr = out.mutable_data_ptr(); + TORCH_CHECK(input.is_contiguous() && out.is_contiguous()); + TORCH_CHECK_EQ(input.numel(), out.numel()); + TORCH_CHECK_EQ(input.dtype(), out.dtype()); + TORCH_CHECK_EQ(input.numel() % world_size, 0); + auto buffer_size = input.numel() * input.element_size(); + size_t bytes_per_rank = buffer_size / world_size; + + auto stream = at::cuda::getCurrentCUDAStream(input.device().index()); + rocshmem_alltoallmem_on_stream(team, output_ptr, input_ptr, bytes_per_rank, stream); + return out; +} + +// This is an exclusive prefix sum function that calculates read (or write) offsets for each peer. +__device__ int64_t prefixSum(int64_t *odata, int64_t *idata, int n) { + // Specialize BlockScan for a 1D block of threads, of type int64_t. + // - `BLOCK_SCAN_WARP_SCANS` is a low-latency scan algorithm (instead of high + // throughput which we don't need here). + // - `at_cuda_detail::cub` is torch's cub wrapper, see #55292. + using BlockScanT = ROCM_HIPCUB(at_cuda_detail::cub)::BlockScan; + // Allocate shared memory for BlockScan + __shared__ typename BlockScanT::TempStorage temp_storage; + + // TODO: currently it is assumed that the number of PE's is smaller than + // `THREADS_PER_BLOCK` + CUDA_KERNEL_ASSERT(n <= THREADS_PER_BLOCK); + + // Obtain input item for each thread + int tid = threadIdx.x; + int64_t thread_data = (tid < n) ? idata[tid] : 0; + + // Collectively compute the block-wide exclusive prefix sum + int64_t block_aggregate; + BlockScanT(temp_storage).ExclusiveSum(thread_data, thread_data, block_aggregate); + + // Store the result + odata[tid] = thread_data; + return block_aggregate; +} + +static int get_a2a_nblocks(size_t size, int world_size, bool intra_node) { + // Check user setting first + int num_blocks = c10d::symmetric_memory::getenv_nblocks(); + if (num_blocks > 0) { // set by user + return num_blocks; + } + // 16B per thread, 8 loops + constexpr size_t chunk_size = 16 * THREADS_PER_BLOCK * 8; + num_blocks = at::ceil_div(size, chunk_size); + // Allow kernel to target even number of blocks per peer + num_blocks = at::round_up(num_blocks, world_size); + const int max_blocks = intra_node ? 64 : 16; + return ::min(num_blocks, max_blocks); +} + +// ROCm-only offset writeback kernel. +// +// On ROCm, allToAllV is a regular multi-block kernel with no grid-wide barrier. +// Writing source_offsets in-kernel can race with other blocks still reading it for +// remote gets. We therefore compute output offsets in a separate kernel after +// allToAllV has completed on the stream. +__global__ void writeOutputOffsets1d(int64_t* out_splits_offsets, int world_size) { + auto output_splits = out_splits_offsets; + auto output_offsets = out_splits_offsets + world_size; + int tid = threadIdx.x; + + CUDA_KERNEL_ASSERT(world_size <= THREADS_PER_BLOCK); + __shared__ int64_t peer_offsets[THREADS_PER_BLOCK]; + prefixSum(peer_offsets, output_splits, world_size); + __syncthreads(); + + if (tid < world_size) { + output_offsets[tid] = peer_offsets[tid]; + } +} +// This kernel is used to exchange output splits and source offsets between peers. +// `in_out_splits` is of size (3, npes) and contains: +// - input splits (IN) +// - output splits (OUT) and +// - source offsets (OUT). +__global__ void exchangeSplitAndOffset(int64_t* input_splits, int64_t* out_splits_offsets, rocshmem_team_t team) { + CUDA_KERNEL_ASSERT(team != ROCSHMEM_TEAM_INVALID); + int mype = rocshmem_team_my_pe(team); + int npes = rocshmem_team_n_pes(team); + auto output_splits = out_splits_offsets; + auto source_offsets = out_splits_offsets + npes; + int tid = threadIdx.x; + + CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK); + __shared__ int64_t peer_offsets[THREADS_PER_BLOCK]; + + // Scan input splits to get the source offsets + prefixSum(peer_offsets, input_splits, npes); + __syncthreads();; + + // Use 1 block to do the exchange + if (tid < npes) { + // tid is peer index within team, but put calls require global rank + int peer_global = rocshmem_team_translate_pe(team, tid, ROCSHMEM_TEAM_WORLD); + rocshmem_int64_p(source_offsets + mype, peer_offsets[tid], peer_global); + rocshmem_int64_p(output_splits + mype, input_splits[tid], peer_global); + } + rocshmem_barrier_wg(); +} + +// This kernel is used to do the actual data exchange. +// `in_out_splits` has the same definition as in `exchangeSplitAndOffset`. +// `stride` is the stride at dim 0, unit in byte. +__global__ void allToAllV(void *send_data, void *recv_data, int64_t* out_splits_offsets, size_t stride, rocshmem_team_t team) { + CUDA_KERNEL_ASSERT(team != ROCSHMEM_TEAM_INVALID); + int mype = rocshmem_team_my_pe(team); + int npes = rocshmem_team_n_pes(team); + auto output_splits = out_splits_offsets; + auto source_offsets = out_splits_offsets + npes; + int bid = blockIdx.x; + int tid = threadIdx.x; + int blocks_per_peer = max(gridDim.x / npes, 1); + + // Calculate the output offsets + CUDA_KERNEL_ASSERT(npes <= THREADS_PER_BLOCK); + __shared__ int64_t peer_offsets[THREADS_PER_BLOCK]; + prefixSum(peer_offsets, output_splits, npes); + __syncthreads(); + + // Target a different peer based on bid + for (int i = bid / blocks_per_peer; i < npes; i += gridDim.x / blocks_per_peer) { + int peer = (mype + i) % npes; + auto peer_global = rocshmem_team_translate_pe(team, peer, ROCSHMEM_TEAM_WORLD); + // Total amount from `peer` + auto peer_size = output_splits[peer] * stride; + // Amount to get from `peer` in this block + auto block_size = peer_size / blocks_per_peer; + // Being lazy here, we should handle the residual if the division is not exact + CUDA_KERNEL_ASSERT(block_size * blocks_per_peer == peer_size); + // This block's offset in the data from `peer` + auto block_offset = block_size * (bid % blocks_per_peer); + auto source_offset = source_offsets[peer] * stride + block_offset; + auto write_offset = peer_offsets[peer] * stride + block_offset; + rocshmem_getmem_nbi_wg( + (char*)recv_data + write_offset, + (char*)send_data + source_offset, + block_size, + peer_global); + } + rocshmem_quiet(); +} + +void all_to_all_vdev( + at::Tensor& input, + at::Tensor& out, + at::Tensor& in_splits, + at::Tensor& out_splits_offsets, + std::string group_name) { + /* Perform AllToAllv operation using NVSHMEM, with split information provided on device. + * Step 1: Rendezvous tensors so all ranks have symmetric (device) pointers. + * Step 2: Launch exchangeSplitAndOffset kernel to exchange per-rank split counts + * and compute source offsets (prefix sum); uses team barrier. + * Step 3: Launch allToAllV kernel to copy data between peers according to + * the exchanged splits/offsets. + * Arguments: + * - `input` is the send buffer + * - `out` is the receive buffer + * - `in_splits`: 1D[npes] num of elements this rank sends to each peer + * - `out_splits_offsets`:2D (2, npes). row0 = output splits, row1 = output offsets + */ + auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); + auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); + auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name); + auto out_splits_offsets_hdl = c10d::symmetric_memory::rendezvous(out_splits_offsets, group_name); + int world_size = input_hdl->get_world_size(); + + void* input_ptr = input.data_ptr(); + void* output_ptr = out.mutable_data_ptr(); + int64_t* in_splits_ptr = (int64_t*)(in_splits.const_data_ptr()); + int64_t* out_splits_offsets_ptr = (int64_t*)(out_splits_offsets.mutable_data_ptr()); + + TORCH_CHECK_EQ(input.device(), out.device()); + auto device = input.device(); + c10::cuda::CUDAGuard guard(device); + auto& team_manager = TeamManager::get(device); + auto team = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank()); + auto stream = at::cuda::getCurrentCUDAStream(device.index()); + + exchangeSplitAndOffset<<>>( + in_splits_ptr, out_splits_offsets_ptr, team); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + C10_CUDA_CHECK(hipStreamSynchronize(stream)); + rocshmem::rocshmem_barrier_all(); + // CTA Tuning + auto input_size = input.numel() * input.element_size(); + int num_blocks = get_a2a_nblocks( + input_size, + input_hdl->get_world_size(), + input_hdl->world_within_direct_access()); + + // Stride at dim 0 (assuming input is contiguous, TODO) + size_t stride_bytes = input.stride(0) * input.element_size(); + + allToAllV<<>>( + input_ptr, output_ptr, out_splits_offsets_ptr, + stride_bytes, team); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + C10_CUDA_CHECK(hipStreamSynchronize(stream)); + // `allToAllV` reads source_offsets while fetching remote shards. Since ROCm has + // no grid-wide sync here, writing output offsets in the same kernel can race + // with those reads. Write output offsets in a follow-up kernel instead. + writeOutputOffsets1d<<>>( + out_splits_offsets_ptr, world_size); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +// Start of `all_to_all_vdev_2d` + +// This is an warp-scope, exclusive prefix sum. When called by a block of +// threads, each warp will perform an independent prefix sum, concurrently. +// Returns the sum of all elements in the warp. +// `NUM_WARPS` is the number of warps participating the concurrent prefix sum. +template +__device__ int64_t prefixSum_warp(int64_t *odata, int64_t *idata, int n) { + CUDA_KERNEL_ASSERT(n <= WARP_SIZE); + + // Specialize WarpScan for type int + using WarpScan = ROCM_HIPCUB(at_cuda_detail::cub)::WarpScan; + // Allocate WarpScan shared memory for N warps + __shared__ typename WarpScan::TempStorage temp_storage[NUM_WARPS]; + + int warp_id = threadIdx.x / WARP_SIZE; + if (warp_id >= NUM_WARPS) { + return 0; + } + + // Obtain input item for each thread + int tid = threadIdx.x % WARP_SIZE; + int64_t thread_data = (tid < n) ? idata[tid] : 0; + + // Total sum of all elements in the warp + int64_t warp_aggregate; + // Compute the warp-wide exclusive prefix sum + WarpScan(temp_storage[warp_id]).ExclusiveSum(thread_data, thread_data, warp_aggregate); + + // Store only valid lanes to avoid out-of-bounds writes when n < WARP_SIZE. + if (tid < n) { + odata[tid] = thread_data; + } + return warp_aggregate; +} + +// This is for abstracting a thread-group-scope, exclusive prefix sum. +// Since we use warp-scope prefix sum, the thread group size is limited to warp size. +#define A2AV_TILE_SIZE WARP_SIZE + + +__global__ void writeOutputOffsets_2d( + int64_t* out_splits_offsets, + int minor_size, + int major_size, + int64_t major_align) { + int nsplits = minor_size * major_size; + auto output_splits = out_splits_offsets; + auto source_offsets = out_splits_offsets + nsplits; + int tid = threadIdx.x; + + constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE; + int tileId = tid / A2AV_TILE_SIZE; + int laneId = tid % A2AV_TILE_SIZE; + __shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE]; + int nsplits_per_tile = min(minor_size, nsplits - tileId * minor_size); + + __shared__ int64_t len_per_tile[NUM_TILES]; + if (nsplits_per_tile > 0) { + int64_t my_tile_len = prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * minor_size, nsplits_per_tile); + if (laneId == A2AV_TILE_SIZE - 1) { + if (major_align != 0) { + auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align; + len_per_tile[tileId] = max(aligned_len, major_align); + } else { + len_per_tile[tileId] = my_tile_len; + } + } + } + __syncthreads(); + + __shared__ int64_t start_offset_per_tile[WARP_SIZE]; + prefixSum_warp<1>(start_offset_per_tile, len_per_tile, NUM_TILES); + __syncthreads(); + + tile_prefix_sums[tileId][laneId] += start_offset_per_tile[tileId]; + __syncthreads(); + + if (tid < nsplits) { + source_offsets[tid] = tile_prefix_sums[tid / minor_size][tid % minor_size]; + } +} + +// `exchangeSplitAndOffset_2d` is used to exchange output splits and source +// offsets between peers. + +/* Arguments: + * `in_splits_offsets`: input splits and offsets (optional), of size (2, nsplits), or (1, nsplits) if no offsets are provided. + * `out_splits_offsets`: output splits and offsets, of size (2, nsplits). + * `mype`: the rank of the current PE. + * `npes`: the number of PEs. + * `ne`: the number of experts. + * `input_dim0`: the size of dim 0 of the input tensor. + * `rank_is_row_in` is a boolean flag indicating whether the input has ranks as row or experts as row. +*/ + +/* Template parameters: + * `HAS_IN_OFFSETS` is a boolean flag indicating whether `in_splits_offsets` has offsets (2nd row) or not. +*/ + +template +__global__ void exchangeSplitAndOffset_2d(int64_t* in_splits_offsets, int64_t* out_splits_offsets, rocshmem_team_t team, int ne, size_t input_dim0, bool rank_is_row_in) { + CUDA_KERNEL_ASSERT(team != ROCSHMEM_TEAM_INVALID); + int mype = rocshmem_team_my_pe(team); + int npes = rocshmem_team_n_pes(team); + int nsplits = npes * ne; + auto input_splits = in_splits_offsets; + auto output_splits = out_splits_offsets; + // Borrowing the space below as a temporary exchange pad. + auto source_offsets = out_splits_offsets + nsplits; + int tid = threadIdx.x; + + int64_t* input_offsets = nullptr; + if (HAS_IN_OFFSETS) { + // input offset are provided, so we can use them directly + input_offsets = in_splits_offsets + nsplits; + } else { + // input offset are not provided, so we need to calculate them. + // Scan input splits to get the source offsets + __shared__ int64_t peer_offsets[THREADS_PER_BLOCK]; + auto sum_of_splits = prefixSum(peer_offsets, input_splits, nsplits); + __syncthreads();; + CUDA_KERNEL_ASSERT(sum_of_splits <= input_dim0 && "sum of splits is larger than input dim\n"); + // Redirect the input splits to the calculated result + input_offsets = peer_offsets; + } + + // Use 1 block to do the exchange + if (tid < nsplits) { + int peer, e, dst_offset; + if (rank_is_row_in) { + peer = tid / ne; + e = tid % ne; + dst_offset = e * npes + mype; + } else { // expert is row in input + peer = tid % npes; + e = tid / npes; + dst_offset = mype * ne + e; + } + // This does a transpose from rank-major order to expert-major order + // (or vice versa). + auto split_val = input_splits[tid]; + CUDA_KERNEL_ASSERT(split_val >= 0 && "split value is negative\n"); + auto peer_global = rocshmem_team_translate_pe(team, peer, ROCSHMEM_TEAM_WORLD); + rocshmem_int64_p(source_offsets + dst_offset, input_offsets[tid], peer_global); + rocshmem_int64_p(output_splits + dst_offset, split_val, peer_global); + } + rocshmem_barrier_wg(); +} + + +// This kernel is used to do the actual data exchange. +// `in_out_splits` has the same definition as in `exchangeSplitAndOffset`. +// `stride` is the stride at dim 0, unit in byte. +// For meaning of `mype` and `npes`, see the docstring of `all_to_all_vdev_2d`. +// `major_align` is the alignment at dim 0, unit in element. If 0, no alignment is needed. + +// `rank_is_row_out` is a boolean flag indicating whether the output has ranks as rows or experts as rows. +// In dispatch case, rank_is_row_out = false, major_size = ne, minor_size = npes. +// In combine case, rank_is_row_out = true, major_size = npes, minor_size = ne. + +__global__ void allToAllV_2d(void *send_data, void *recv_data, int64_t* in_splits, int64_t* out_splits_offsets, size_t stride, int minor_size, int major_size, int64_t major_align, bool rank_is_row_out, rocshmem_team_t team) { + int nsplits = minor_size * major_size; + auto output_splits = out_splits_offsets; + auto source_offsets = out_splits_offsets + nsplits; + int bid = blockIdx.x; + int tid = threadIdx.x; + + // Split the thread block into tiles + constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE; + int tileId = tid / A2AV_TILE_SIZE; + int laneId = tid % A2AV_TILE_SIZE; + // Each tile calculates its own prefix sum + __shared__ int64_t tile_prefix_sums[NUM_TILES][A2AV_TILE_SIZE]; + // A tile takes care of minor_size worth of splits + int nsplits_per_tile = min(minor_size, nsplits - tileId * minor_size); + // TODO: currently it is assumed that the number of PE's is smaller than + // `A2AV_TILE_SIZE` bc the warp-scope prefix sum can only handle up to + // WARP_SIZE elements + CUDA_KERNEL_ASSERT(minor_size <= A2AV_TILE_SIZE && "minor_size is too large\n"); + // Similarly, the number of experts per rank is also assumed to be smaller + // than `NUM_TILES` + CUDA_KERNEL_ASSERT(major_size <= NUM_TILES && "major_size is too large\n"); + + // Total length of each tile + __shared__ int64_t len_per_tile[NUM_TILES]; + // When `nsplits` is small, not every tile gets data to sum. They can skip + // this local prefix sum. + if (nsplits_per_tile > 0) { + // Each tile calculates its own prefix sum, return value is the sum of all elements in the tile. + int64_t my_tile_len = prefixSum_warp(tile_prefix_sums[tileId], output_splits + tileId * minor_size, nsplits_per_tile); + // Last thread in each tile does the up aligning. + if (laneId == A2AV_TILE_SIZE - 1) { + if (major_align != 0) { // Needs alignment + auto aligned_len = (my_tile_len + major_align - 1) / major_align * major_align; + // In case `aligned_len` is 0, we set it to `major_align` to avoid an + // empty bin, bc cutlass currently does not support it. See + // https://github.com/pytorch/pytorch/issues/152668. + len_per_tile[tileId] = max(aligned_len, major_align); + } else { // 0 means alignment not needed + len_per_tile[tileId] = my_tile_len; + } + } + } + __syncthreads(); + + // Starting offset of each tile + __shared__ int64_t start_offset_per_tile[NUM_TILES]; + // Prefix sum again to get the tiles' start offsets. + // `NUM_TILES` is typically not greater than 32, because 32 tiles * 32 threads + // = 1024 threads, and this kernel is launched within 1024 threads. Thus, we + // can use warp-scope prefix sum. + static_assert(NUM_TILES <= WARP_SIZE); + // Only 1 warp is needed + prefixSum_warp<1>(start_offset_per_tile, len_per_tile, NUM_TILES); + __syncthreads(); + + // Add tile offset to every element in the tile + tile_prefix_sums[tileId][laneId] += start_offset_per_tile[tileId]; + __syncthreads(); + + // Target a different e based on bid + for (int eid = bid; eid < nsplits; eid += gridDim.x) { + int row = eid / minor_size; + int col = eid % minor_size; + // Amount from `peer` for `e` + auto peer_size = output_splits[eid] * stride; + auto source_offset = source_offsets[eid] * stride; + auto e_offset = tile_prefix_sums[row][col]; + auto write_offset = e_offset * stride; + auto peer_global = rocshmem_team_translate_pe(team, rank_is_row_out ? row : col, ROCSHMEM_TEAM_WORLD); + rocshmem_getmem_nbi_wg( + (char*)recv_data + write_offset, + (char*)send_data + source_offset, + peer_size, + peer_global); // peer's global index + } + rocshmem_quiet(); +} + +void all_to_all_vdev_2d( + at::Tensor& input, + at::Tensor& out, + at::Tensor& in_splits, + at::Tensor& out_splits_offsets, + std::string group_name, + std::optional major_align) { + /* Perform a 2D AllToAllv shuffle operation using NVSHMEM, with split information provided on device. + * Arguments: + * - `input` is the input tensor + * - `out` is the output tensor + * - `in_out_splits` is a 2D tensor of size (3, `world_size` * `ne`). In the + scenario of Mixture-of-Experts models, `ne` is the number of experts per + rank. The rows of `in_out_splits` are (in order): + input splits (IN) + output splits (OUT) and + output offsets (OUT). + * - `group_name` is the name of the group to use for the collective operation. + * - `major_align` is the alignment of the "major dimension" of the output + sequence. See below for details. + + * A 2D AllToAllv shuffle is illustrated below: + (world_size = 2, ne = 2, total number of experts = 4) + Source: | Rank 0 | Rank 1 | + | c0 | c1 | c2 | c3 | d0 | d1 | d2 | d3 | + + Dest : | Rank 0 | Rank 1 | + | c0 | d0 | c1 | d1 | c2 | d2 | c3 | d3 | + where each `c_i` / `d_i` are slices of the `input` tensor, targeting + expert `i`, with length indicated by input splits (in + `in_out_splits[0]`). That is, the 2D AllToAllv shuffle achieves a + transpose from rank-major order at input to expert-major order at + output. + + * If `major_align` is not 1, the output offsets of c1, c2, c3 will be + up-aligned to this value. For example, if c0 has length 5 and d0 has + length 7 (making a total of 12), and if the `major_align` is set to 16, + the output offset of c1 will be 16. Similar for c2 and c3. This value has + no effect on the offset of the minor dimension, i.e. d0, d1, d2 and d3. + Note: since cutlass does not support empty bins, we set the aligned length + to `major_align` if it is 0. See + https://github.com/pytorch/pytorch/issues/152668. + */ + auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); + auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); + auto in_splits_hdl = c10d::symmetric_memory::rendezvous(in_splits, group_name); + auto out_splits_offsets_hdl = c10d::symmetric_memory::rendezvous(out_splits_offsets, group_name); + int world_size = input_hdl->get_world_size(); + // TODO: world_size is currently limited by the number of elements in a WarpScan. + TORCH_CHECK(world_size <= A2AV_TILE_SIZE, "world_size must be smaller than A2AV_TILE_SIZE", A2AV_TILE_SIZE); + + // If `major_align` is not provided, use 1 as the default value. + int64_t major_align_val = major_align.value_or(1); + TORCH_CHECK(major_align_val > 0, "major_align must be positive"); + + void* input_ptr = input.data_ptr(); + void* output_ptr = out.mutable_data_ptr(); + int64_t* in_splits_ptr = (int64_t*)(in_splits.data_ptr()); + int64_t* out_splits_offsets_ptr = (int64_t*)(out_splits_offsets.mutable_data_ptr()); + + // Shape checks + TORCH_CHECK(in_splits.is_contiguous() + && out_splits_offsets.is_contiguous() + && input.is_contiguous() + && out.is_contiguous(), + "input, out, in_splits and out_splits_offsets must be contiguous"); + auto in_split_shape = in_splits.sizes(); + auto out_split_shape = out_splits_offsets.sizes(); + TORCH_CHECK(out_split_shape.size() == 2 + && out_split_shape[0] == 2 + && out_split_shape[1] == in_split_shape[0] + && in_split_shape[0] % world_size == 0, + "out_splits_offsets must be 2D with 2 rows, " + "each row must be a multiple of world_size"); + + // Consistency checks + TORCH_CHECK(input.dtype() == out.dtype() + && input.stride(0) == out.stride(0), + "input and out must have the same dtype and same stride at dim 0"); + TORCH_CHECK(in_splits.scalar_type() == at::kLong + && out_splits_offsets.scalar_type() == at::kLong, + "splits and offsets must be int64"); + + // Number of experts per rank + int ne = in_split_shape[0] / world_size; + constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE; + TORCH_CHECK(ne <= NUM_TILES, "Number of experts must be smaller than NUM_TILES", NUM_TILES); + + // Set device context for getting the stream and launching kernels below + auto device = input.device(); + TORCH_CHECK(device.type() == at::DeviceType::CUDA && + out.device() == device && + in_splits.device() == device && + out_splits_offsets.device() == device, + "all tensor arguments must be on the same CUDA device"); + c10::cuda::CUDAGuard guard(device); + auto stream = at::cuda::getCurrentCUDAStream(); + auto& team_manager = TeamManager::get(device); + auto team = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank()); + + // Exchange output splits and source offsets + auto input_dim0 = input.size(0); + bool rank_is_row_in = true; + exchangeSplitAndOffset_2d<<>>( + in_splits_ptr, out_splits_offsets_ptr, team, + ne, input_dim0, rank_is_row_in); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + C10_CUDA_CHECK(hipStreamSynchronize(stream)); + rocshmem::rocshmem_barrier_all(); + // CTA Tuning + // Naive for now, use 1 block per expert. + // Total number of blocks is limited to 64 (intra-node) or 8 (inter-node). + int num_blocks = ::min(world_size * ne, world_size > 8 ? 8 : 64); + + // Stride at dim 0 + size_t stride_bytes = input.stride(0) * input.element_size(); + bool rank_is_row_out = !rank_is_row_in; + + allToAllV_2d<<>>( + input_ptr, output_ptr, + in_splits_ptr, out_splits_offsets_ptr, + stride_bytes, world_size, + ne, major_align_val, rank_is_row_out, team); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // allToAllV_2d uses a regular multi-block launch with no grid-wide sync, so + // in-kernel writeback can race with other blocks still reading source_offsets. + writeOutputOffsets_2d<<>>( + out_splits_offsets_ptr, world_size, ne, major_align_val); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + C10_CUDA_CHECK(hipStreamSynchronize(stream)); +} + +void all_to_all_vdev_2d_offset( + at::Tensor& input, + at::Tensor& out, + at::Tensor& in_splits_offsets, + at::Tensor& out_splits_offsets, + std::string group_name) { + /* Perform a 2D AllToAllv shuffle operation, with input split and offset + * information provided on device. The input offsets are not required to be + * exact prefix sum of the input splits, i.e. paddings are allowed between the + * split chunks. The paddings, however, will not be transferred to peer + * ranks. + + * In Mixture of Experts models, this operation can be used to combine tokens + * processed by experts on parallel ranks. This operation can be viewed as an + * "reverse" operation to the `all_to_all_vdev_2d` operation (which shuffles + * tokens to experts). + + * Arguments: + * - `input` is the input tensor + * - `out` is the output tensor + * - `in_splits_offsets` is a 2D tensor of size (2, `ne` * `world_size`). In the + scenario of Mixture-of-Experts models, `ne` is the number of experts per + rank. The rows of `in_splits_offsets` are (in order): + input splits (IN) and + input offsets (IN) + * - `out_splits_offsets` is a 2D tensor of size (2, `world_size` * `ne`). The + rows are (in order): + output splits (OUT) and + output offsets (OUT). + * - `group_name` is the name of the group to use for the collective operation. + */ + auto input_hdl = c10d::symmetric_memory::rendezvous(input, group_name); + auto out_hdl = c10d::symmetric_memory::rendezvous(out, group_name); + auto out_splits_offsets_hdl = c10d::symmetric_memory::rendezvous(out_splits_offsets, group_name); + auto in_splits_offsets_hdl = c10d::symmetric_memory::rendezvous(in_splits_offsets, group_name); + int rank = input_hdl->get_rank(); + int world_size = input_hdl->get_world_size(); + constexpr int NUM_TILES = THREADS_PER_BLOCK / A2AV_TILE_SIZE; + TORCH_CHECK(world_size <= NUM_TILES, "world_size must be smaller than NUM_TILES", NUM_TILES); + + int64_t major_align_val = 0; + + void* input_ptr = input.data_ptr(); + void* output_ptr = out.mutable_data_ptr(); + int64_t* out_splits_offsets_ptr = (int64_t*)(out_splits_offsets.mutable_data_ptr()); + int64_t* in_splits_offsets_ptr = (int64_t*)(in_splits_offsets.data_ptr()); + + // Shape checks + TORCH_CHECK(out_splits_offsets.is_contiguous() + && in_splits_offsets.is_contiguous() + && input.is_contiguous() + && out.is_contiguous(), + "input, out, in_splits_offsets and out_splits_offsets must be contiguous"); + auto out_split_shape = out_splits_offsets.sizes(); + auto in_split_shape = in_splits_offsets.sizes(); + TORCH_CHECK(in_split_shape.size() == 2 + && in_split_shape[0] == 2 + && in_split_shape[1] % world_size == 0, + "in_splits_offsets must be 2D with 2 rows, " + "each row must be a multiple of world_size"); + + // Consistency checks + TORCH_CHECK(input.dtype() == out.dtype() + && input.stride(0) == out.stride(0), + "input and out must have the same dtype and same stride at dim 0"); + TORCH_CHECK(out_splits_offsets.scalar_type() == at::kLong + && in_splits_offsets.scalar_type() == at::kLong, + "splits and offsets must be int64"); + + // Number of experts per rank + int ne = in_split_shape[1] / world_size; + // TODO: number of experts is currently limited by the number of elements in a WarpScan. + TORCH_CHECK(ne <= A2AV_TILE_SIZE, "Number of experts must be smaller than A2AV_TILE_SIZE", A2AV_TILE_SIZE); + + // Set device context for getting the stream and launching kernels below + auto device = input.device(); + TORCH_CHECK(device.type() == at::DeviceType::CUDA && + out.device() == device && + in_splits_offsets.device() == device && + out_splits_offsets.device() == device, + "all tensor arguments must be on the same CUDA device"); + c10::cuda::CUDAGuard guard(device); + auto stream = at::cuda::getCurrentCUDAStream(); + auto& team_manager = TeamManager::get(device); + auto team = team_manager.get_team(group_name, input_hdl->get_rank_to_global_rank()); + + // Exchange output splits and source offsets + auto input_dim0 = input.size(0); + bool rank_is_row_in = false; + exchangeSplitAndOffset_2d<<>>( + in_splits_offsets_ptr, + out_splits_offsets_ptr, + team, + ne, input_dim0, rank_is_row_in); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + C10_CUDA_CHECK(hipStreamSynchronize(stream)); + rocshmem::rocshmem_barrier_all(); + // CTA Tuning + // Naive for now, use 1 block per expert. + // Total number of blocks is limited to 64 (intra-node) or 8 (inter-node). + int num_blocks = ::min(world_size * ne, world_size > 8 ? 8 : 64); + + // Stride at dim 0 + size_t stride_bytes = input.stride(0) * input.element_size(); + bool rank_is_row_out = !rank_is_row_in; + + allToAllV_2d<<>>( + input_ptr, + output_ptr, + in_splits_offsets_ptr, + out_splits_offsets_ptr, + stride_bytes, + ne, + world_size, + major_align_val, + rank_is_row_out, + team); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + writeOutputOffsets_2d<<>>( + out_splits_offsets_ptr, ne, world_size, major_align_val); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + C10_CUDA_CHECK(hipStreamSynchronize(stream)); +} + +} // namespace c10d::nvshmem_extension + +TORCH_LIBRARY_IMPL(symm_mem, CUDA, m) { + m.impl("nvshmem_broadcast", c10d::nvshmem_extension::nvshmem_broadcast); + m.impl("nvshmem_put", c10d::nvshmem_extension::nvshmem_put); + m.impl("nvshmem_get", c10d::nvshmem_extension::nvshmem_get); + m.impl("nvshmem_wait_for_signal", c10d::nvshmem_extension::nvshmem_wait_for_signal); + m.impl("nvshmem_put_with_signal", c10d::nvshmem_extension::nvshmem_put_with_signal); + m.impl("nvshmem_all_to_all", c10d::nvshmem_extension::nvshmem_all_to_all); + m.impl("all_to_all_vdev", c10d::nvshmem_extension::all_to_all_vdev); + m.impl("all_to_all_vdev_2d", c10d::nvshmem_extension::all_to_all_vdev_2d); + m.impl("all_to_all_vdev_2d_offset", c10d::nvshmem_extension::all_to_all_vdev_2d_offset); +} diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp index ce85ee4f5c4ba..09a8ca0555b55 100644 --- a/torch/csrc/distributed/rpc/init.cpp +++ b/torch/csrc/distributed/rpc/init.cpp @@ -393,28 +393,18 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) { .def( py::pickle( /* __getstate__ */ - [](const PyRRef& /* unused */) { + [](const PyRRef& /* unused */) -> py::tuple { TORCH_CHECK( false, "Can not pickle rref in python pickler, rref can only be " "pickled when using RPC"); - // Note that this return has no meaning since we always - // throw, it's only here to satisfy Pybind API's - // requirement. - return py::make_tuple(); }, /* __setstate__ */ - [](py::tuple /* unused */) { // NOLINT + [](py::tuple /* unused */) -> std::nullptr_t { // NOLINT TORCH_CHECK( false, "Can not unpickle rref in python pickler, rref can only be " "unpickled when using RPC"); - // Note that this return has no meaning since we always - // throw, it's only here to satisfy PyBind's API - // requirement. - return PyRRef( - py::cast(Py_None), - py::cast(Py_None)); }), py::call_guard()) .def( diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index e20f8730b6ec0..84358a8d2cb13 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -1265,23 +1265,14 @@ void TensorPipeAgent::updateGroupMembership( workerNameToURL_.erase(name); // remove reverse device maps that are no longer used - for (auto it = reverseDeviceMaps_.begin(); - it != reverseDeviceMaps_.end();) { - if (reverseDeviceMaps.find(it->first) == reverseDeviceMaps.end()) { - it = reverseDeviceMaps_.erase(it); - } else { - it++; - } - } + std::erase_if(reverseDeviceMaps_, [&reverseDeviceMaps](const auto& kv) { + return !reverseDeviceMaps.contains(kv.first); + }); // remove devices that are no longer used - for (auto it = devices_.begin(); it != devices_.end();) { - if (std::find(devices.begin(), devices.end(), *it) == devices.end()) { - it = devices_.erase(it); - } else { - it++; - } - } + std::erase_if(devices_, [&](const auto& d) { + return std::find(devices.begin(), devices.end(), d) == devices.end(); + }); } } std::unordered_map TensorPipeAgent::getMetrics() { diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index 58cb48de664d5..ee11e9b46519b 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -488,7 +488,7 @@ static PyObject* dynamo__custom_eval_frame_shim( // - Python callable(): enables TorchDynamo PyObject* callback = eval_frame_callback_get(); - if (callback == Py_None) { + if (Py_IsNone(callback)) { return dynamo_eval_frame_default(tstate, frame, throw_flag); } @@ -587,7 +587,7 @@ static PyObject* set_eval_frame(PyObject* new_callback, PyObject* module) { // None. Skip messing with threading, thread-local storage, and // reference counts. if (old_callback != new_callback) { - if (new_callback == Py_None) { + if (Py_IsNone(new_callback)) { decrement_working_threads(PyThreadState_GET(), module); } else { increment_working_threads(PyThreadState_GET(), module); @@ -612,7 +612,7 @@ static PyObject* set_eval_frame(PyObject* new_callback, PyObject* module) { } static PyObject* set_eval_frame_py(PyObject* module, PyObject* callback) { - if (callback != Py_None && callback != Py_False && + if (!Py_IsNone(callback) && !Py_IsFalse(callback) && !PyCallable_Check(callback)) { DEBUG_TRACE0("arg error"); PyErr_SetString(PyExc_TypeError, "expected a callable"); @@ -620,21 +620,21 @@ static PyObject* set_eval_frame_py(PyObject* module, PyObject* callback) { } DEBUG_TRACE( "python enabled=%d and is run_only=%d", - callback != Py_None, - callback == Py_False); + !Py_IsNone(callback), + Py_IsFalse(callback)); return set_eval_frame(callback, module); } static PyObject* set_skip_guard_eval_unsafe( PyObject* dummy, PyObject* skip_guard_unsafe_flag) { - if (skip_guard_unsafe_flag != Py_False && skip_guard_unsafe_flag != Py_True) { + if (!Py_IsFalse(skip_guard_unsafe_flag) && !Py_IsTrue(skip_guard_unsafe_flag)) { DEBUG_TRACE0("arg error"); PyErr_SetString(PyExc_TypeError, "expected True/False"); return NULL; } bool old_skip_guard_eval_unsafe = is_skip_guard_eval_unsafe; - is_skip_guard_eval_unsafe = skip_guard_unsafe_flag == Py_True; + is_skip_guard_eval_unsafe = Py_IsTrue(skip_guard_unsafe_flag); if (old_skip_guard_eval_unsafe) { Py_RETURN_TRUE; } @@ -672,7 +672,7 @@ static PyObject* unsupported(PyObject* dummy, PyObject* args) { } static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) { - if (obj == Py_None) { + if (Py_IsNone(obj)) { obj = NULL; } Py_XSETREF(guard_error_hook, Py_XNewRef(obj)); @@ -682,7 +682,7 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) { static PyObject* set_guard_complete_hook(PyObject* dummy, PyObject* obj) { PyObject* old_hook = guard_complete_hook; - if (obj == Py_None) { + if (Py_IsNone(obj)) { obj = NULL; } @@ -727,6 +727,48 @@ static int clear_state(PyObject* module) { bool is_skip_guard_eval_unsafe = false; +// -1 means inactive, >= 0 means active with that many compiled frames. +int fullgraph_compiled_frame_count = -1; + +// When true and fullgraph_compiled_frame_count > 0, sub-frames under fullgraph +// compilation will error (via get_fail_callback) instead of being silently +// skipped. +bool fullgraph_error_on_nested_compile = false; + +// Set the fullgraph compiled frame counter and return the old value. +// If setting to >= 0 (activating) and already active, no-op. +static PyObject* set_fullgraph_compiled_frame_count_py( + PyObject* dummy, + PyObject* arg) { + long val = PyLong_AsLong(arg); + if (val == -1 && PyErr_Occurred()) { + return NULL; + } + int old = fullgraph_compiled_frame_count; + if (val >= 0 && old >= 0) { + // Already active, no-op. + } else { + fullgraph_compiled_frame_count = (int)val; + } + return PyLong_FromLong(old); +} + +// Set fullgraph_error_on_nested_compile and return the old value. +static PyObject* set_fullgraph_error_on_nested_compile_py( + PyObject* dummy, + PyObject* arg) { + if (!Py_IsFalse(arg) && !Py_IsTrue(arg)) { + PyErr_SetString(PyExc_TypeError, "expected True/False"); + return NULL; + } + bool old = fullgraph_error_on_nested_compile; + fullgraph_error_on_nested_compile = Py_IsTrue(arg); + if (old) { + Py_RETURN_TRUE; + } + Py_RETURN_FALSE; +} + static PyMethodDef _methods[] = { {"set_eval_frame", set_eval_frame_py, METH_O, NULL}, {"set_skip_guard_eval_unsafe", set_skip_guard_eval_unsafe, METH_O, NULL}, @@ -740,6 +782,14 @@ static PyMethodDef _methods[] = { {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL}, {"set_guard_complete_hook", set_guard_complete_hook, METH_O, NULL}, {"raise_sigtrap", raise_sigtrap, METH_NOARGS, NULL}, + {"set_fullgraph_compiled_frame_count", + set_fullgraph_compiled_frame_count_py, + METH_O, + NULL}, + {"set_fullgraph_error_on_nested_compile", + set_fullgraph_error_on_nested_compile_py, + METH_O, + NULL}, {NULL, NULL, 0, NULL}}; static struct PyModuleDef _module = { diff --git a/torch/csrc/dynamo/eval_frame.h b/torch/csrc/dynamo/eval_frame.h index 870603262ddb6..703b9fa306e1f 100644 --- a/torch/csrc/dynamo/eval_frame.h +++ b/torch/csrc/dynamo/eval_frame.h @@ -32,6 +32,8 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New( THP_EVAL_API_FRAME_OBJECT* frame); extern bool is_skip_guard_eval_unsafe; +extern int fullgraph_compiled_frame_count; +extern bool fullgraph_error_on_nested_compile; void clear_old_frame_if_python_312_plus( PyThreadState* tstate, diff --git a/torch/csrc/dynamo/eval_frame_cpp.cpp b/torch/csrc/dynamo/eval_frame_cpp.cpp index 5eed2371af625..36f072eef7138 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.cpp +++ b/torch/csrc/dynamo/eval_frame_cpp.cpp @@ -96,9 +96,7 @@ py::object get_null_stack_value() { py::list _get_frame_value_stack_with_depth( const py::handle& frame_obj, int depth) { - if (!PyFrame_Check(frame_obj.ptr())) { - throw py::type_error("expected a frame object!"); - } + TORCH_CHECK_TYPE(PyFrame_Check(frame_obj.ptr()), "expected a frame object!"); py::list result; if (depth <= 0) { @@ -294,9 +292,9 @@ static py::handle _callback_from_action( static int32_t c_recursion_limit = -1; void dynamo_set_c_recursion_limit(int32_t limit) { - if (limit < 1 && limit != -1) { - throw std::range_error("recursion limit must be >= 1, or -1 to reset"); - } + TORCH_CHECK_VALUE( + limit >= 1 || limit == -1, + "recursion limit must be >= 1, or -1 to reset"); c_recursion_limit = limit; } @@ -339,18 +337,6 @@ struct CRecursionLimitRAII { #endif -EvalFrameOverride eval_frame_override = EvalFrameOverride::NONE; - -EvalFrameOverride get_eval_frame_override() { - return eval_frame_override; -} - -EvalFrameOverride set_eval_frame_override(EvalFrameOverride override) { - EvalFrameOverride prev = eval_frame_override; - eval_frame_override = override; - return prev; -} - // frame and callback are borrowed references. // Returns new reference. PyObject* dynamo__custom_eval_frame( @@ -396,7 +382,7 @@ PyObject* dynamo__custom_eval_frame( // immediately skip the frame, and (2) even if it did, this would only // be profitable if there was tensor code in the unwinding code. Seems // unlikely. - DEBUG_TRACE("throw %s", get_frame_name(frame)); + DEBUG_TRACE("throw %s", get_frame_name(frame)); // @allow-raw-throw return dynamo_eval_frame_default(tstate, frame, throw_flag); } @@ -430,28 +416,28 @@ PyObject* dynamo__custom_eval_frame( // original frame, we are responsible for clearing it - via // clear_old_frame_if_python_312_plus. auto eval_custom = [&]() { - // If we're attempting to run dynamo-generated code and eval frame override - // is set to SKIP, then we should set the callback to None to skip. - // If the override is set to ERROR, then we call - // torch._dynamo.convert_frame.get_fail_callback, which patches - // convert_frame.compile_frame with a function that errors unconditionally. - // This means Dynamo will error if it attempts to trace into the frame - // (Python-level skips pre-trace are permissible). - if (!recursive_callback.is_none() && - !recursive_callback.is(py::bool_(false))) { - if (eval_frame_override == EvalFrameOverride::SKIP) { - recursive_callback = py::none(); - } else if (eval_frame_override == EvalFrameOverride::ERROR) { - if (!convert_frame_get_fail_callback) { - convert_frame_get_fail_callback = - py::module_::import("torch._dynamo.convert_frame") - .attr("get_fail_callback"); - auto atexit = py::module_::import("atexit"); - atexit.attr("register")(py::cpp_function( - []() { convert_frame_get_fail_callback = std::nullopt; })); + if (fullgraph_compiled_frame_count >= 0) { + fullgraph_compiled_frame_count++; + // Under fullgraph, disable or error Dynamo for sub-frames of compiled + // code. If fullgraph_error_on_nested_compile is set, wrap the callback + // with get_fail_callback so compilation attempts error. Otherwise, set + // callback to None to skip sub-frames entirely. + if (!recursive_callback.is_none() && + !recursive_callback.is(py::bool_(false))) { + if (fullgraph_error_on_nested_compile) { + if (!convert_frame_get_fail_callback) { + convert_frame_get_fail_callback = + py::module_::import("torch._dynamo.convert_frame") + .attr("get_fail_callback"); + auto atexit = py::module_::import("atexit"); + atexit.attr("register")(py::cpp_function( + []() { convert_frame_get_fail_callback = std::nullopt; })); + } + recursive_callback = + convert_frame_get_fail_callback.value()(recursive_callback); + } else { + recursive_callback = py::none(); } - recursive_callback = - convert_frame_get_fail_callback.value()(recursive_callback); } } eval_frame_callback_set(recursive_callback.ptr()); @@ -564,13 +550,13 @@ PyObject* dynamo__custom_eval_frame( if (guard_complete_hook != nullptr && !extra->cache_entry_list.empty()) { py::handle guard_complete_hook_handle(guard_complete_hook); // False means force compilation (someone cache missed) - py::object res = guard_complete_hook_handle(maybe_cached_code != Py_None); + py::object res = guard_complete_hook_handle(!Py_IsNone(maybe_cached_code)); if (!py::cast(res)) { maybe_cached_code = Py_None; // NB: non-owning } } - if (maybe_cached_code != Py_None) { + if (!Py_IsNone(maybe_cached_code)) { cached_code = (PyCodeObject*)maybe_cached_code; // used cached version DEBUG_TRACE("cache hit %s", get_frame_name(frame)); @@ -649,7 +635,7 @@ PyObject* dynamo__custom_eval_frame( extra_state_set_exec_strategy(extra, new_strategy); } - if (guarded_code != Py_None) { + if (!Py_IsNone(guarded_code)) { DEBUG_TRACE("create cache %s", get_frame_name(frame)); // NB: We could use extract_cache_entry to get the cache_entry, but diff --git a/torch/csrc/dynamo/eval_frame_cpp.h b/torch/csrc/dynamo/eval_frame_cpp.h index f0d04ed85c388..34ab9676185e1 100644 --- a/torch/csrc/dynamo/eval_frame_cpp.h +++ b/torch/csrc/dynamo/eval_frame_cpp.h @@ -29,15 +29,6 @@ int32_t dynamo_get_c_recursion_limit(); } // extern "C" -// Used to override the Dynamo callback for fullgraph=True'd compiled objects -enum class EvalFrameOverride { - NONE, // Run regular set callback - SKIP, // skip frames recursively - ERROR, // error if Dynamo attempts to trace code -}; - -EvalFrameOverride set_eval_frame_override(EvalFrameOverride override); - // Bytecode debugger callback functions void set_bytecode_debugger_callback(py::object callback); py::object get_bytecode_debugger_callback(); diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index b890c2848011b..42c17fed1393a 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ b/torch/csrc/dynamo/extra_state.cpp @@ -157,7 +157,7 @@ void lookup( for (CacheEntry& cache_entry : extra_state->cache_entry_list) { // Check backend. Py_False means run only mode. - bool valid = backend == Py_False || + bool valid = Py_IsFalse(backend) || backend_match(cache_entry.backend.ptr(), backend); if (valid) { @@ -229,9 +229,9 @@ CacheEntry* create_cache_entry( } py::list _debug_get_cache_entry_list(const py::handle& code_obj) { - if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { - throw py::type_error("expected a code object!"); - } + TORCH_CHECK_TYPE( + py::isinstance(code_obj, py::module::import("types").attr("CodeType")), + "expected a code object!"); PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); ExtraState* extra = get_extra_state(code); py::list result; @@ -252,9 +252,9 @@ PrecompileEntry::PrecompileEntry(py::object gm, py::object c) } void _reset_precompile_entries(const py::handle& code_obj) { - if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { - throw py::type_error("expected a code object!"); - } + TORCH_CHECK_TYPE( + py::isinstance(code_obj, py::module::import("types").attr("CodeType")), + "expected a code object!"); PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); ExtraState* extra = get_extra_state(code); py::list result; @@ -267,9 +267,9 @@ void _load_precompile_entry( const py::handle& code_obj, py::object guard_manager, py::object dynamo_code) { - if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { - throw py::type_error("expected a code object!"); - } + TORCH_CHECK_TYPE( + py::isinstance(code_obj, py::module::import("types").attr("CodeType")), + "expected a code object!"); PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); ExtraState* extra = get_extra_state(code); py::list result; @@ -290,9 +290,9 @@ void _set_lru_cache(py::object boolean) { } py::list _debug_get_precompile_entries(const py::handle& code_obj) { - if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { - throw py::type_error("expected a code object!"); - } + TORCH_CHECK_TYPE( + py::isinstance(code_obj, py::module::import("types").attr("CodeType")), + "expected a code object!"); PyCodeObject* code = (PyCodeObject*)code_obj.ptr(); ExtraState* extra = get_extra_state(code); py::list result; diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 737c230388c1d..f9899719e5c1e 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -6,6 +6,7 @@ #define PY_SSIZE_T_CLEAN #include #include +#include #include #include #include @@ -345,7 +346,7 @@ static std::vector> pyListToVecOptInt( for (Py_ssize_t i = 0; i < size; i++) { PyObject* item = PyList_GetItem(pyList, i); auto handle = py::handle(item); - if (item == Py_None) { + if (Py_IsNone(item)) { vec.emplace_back(std::nullopt); } else if (torch::is_symint(handle)) { vec.emplace_back(py::cast(handle)); @@ -366,7 +367,7 @@ static std::vector> pyListToVecOptInt( static std::vector>> get_dynamic_dims( PyObject* dynamic_dims_py) { std::vector>> per_tensor_dynamic_dims; - if (dynamic_dims_py != Py_None) { + if (!Py_IsNone(dynamic_dims_py)) { Py_ssize_t size = PyList_Size(dynamic_dims_py); for (Py_ssize_t i = 0; i < size; i++) { PyObject* py_list = PyList_GetItem(dynamic_dims_py, i); @@ -537,13 +538,15 @@ PyObject* TensorGuards_check_verbose( PyObject* item = PyTuple_GET_ITEM(args, i); if (Py_TYPE(item) != checks[i].pytype) { std::stringstream fail_reason; - PyObject* type_str = PyObject_Str(PyObject_Type(item)); + PyObject* type_str = + PyObject_Str(reinterpret_cast(Py_TYPE(item))); fail_reason << "expected type of '" << tensor_check_names[i] << "' to be a tensor type, "; if (!type_str) { fail_reason << "but found a different type"; } else { fail_reason << "' but found " << PyUnicode_AsUTF8(type_str); + Py_DECREF(type_str); } return Py_BuildValue("s", fail_reason.str().c_str()); } @@ -874,20 +877,26 @@ static PyObject* check_obj_id(PyObject* dummy, PyObject* args) { #if IS_PYTHON_3_12_PLUS -static std::unordered_map dict_version_map; +struct DictVersionState { + std::unordered_map map; + uint64_t next_id = 1; +}; + +static c10::Synchronized dict_version_state; static int dict_version_watcher_id; static int dict_recursive_tag_watcher_id; -static uint64_t global_dict_version_id = 1; static int dict_version_watch_callback( PyDict_WatchEvent event, PyObject* dict, PyObject* key, PyObject* new_value) noexcept { - if (event == PyDict_EVENT_DEALLOCATED) { - dict_version_map.erase(dict); - } else if (event != PyDict_EVENT_CLONED) { - dict_version_map[dict] = global_dict_version_id++; - } + dict_version_state.withLock([&](DictVersionState& state) { + if (event == PyDict_EVENT_DEALLOCATED) { + state.map.erase(dict); + } else if (event != PyDict_EVENT_CLONED) { + state.map[dict] = state.next_id++; + } + }); return 0; } @@ -899,10 +908,13 @@ static uint64_t get_dict_version_unchecked(PyObject* dict) { TORCH_CHECK( !PyDict_Watch(dict_version_watcher_id, dict), "failed to add version watcher to dict!"); - if (!dict_version_map.count(dict)) { - dict_version_map[dict] = global_dict_version_id++; - } - return dict_version_map[dict]; + return dict_version_state.withLock([&](DictVersionState& state) -> uint64_t { + auto [it, inserted] = state.map.try_emplace(dict, state.next_id); + if (inserted) { + state.next_id++; + } + return it->second; + }); #else @@ -911,16 +923,11 @@ static uint64_t get_dict_version_unchecked(PyObject* dict) { #endif } -static PyObject* dict_version(PyObject* dummy, PyObject* args) { - // Retrieves the version of a dictionary. - PyObject* obj = nullptr; - if (!PyArg_ParseTuple(args, "O", &obj)) { - return nullptr; - } - if (!PyDict_Check(obj)) { - return nullptr; - } +static PyObject* dict_version(PyObject* dummy, PyObject* obj) { + HANDLE_TH_ERRORS + TORCH_CHECK(PyDict_Check(obj), "dict_version expects a dict"); return THPUtils_packUInt64(get_dict_version_unchecked(obj)); + END_HANDLE_TH_ERRORS } static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) { @@ -1054,6 +1061,48 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) { Py_RETURN_TRUE; } +static PyObject* copy_if_misaligned(PyObject* dummy, PyObject* item) { + /* + * If the tensor's data pointer is not 16-byte aligned, return a + * clone that preserves strides. Otherwise return the original + * tensor (new reference). Implemented in C++ so the aligned + * fast-path is just a pointer check with minimal Python overhead. + * + * NOTE: kAlignment is hardcoded to match torch._inductor.utils.ALIGNMENT. + * If alignment requirements ever change or become per-platform, this + * constant must be updated (or turned into a parameter). + */ + constexpr size_t kAlignment = 16; + + if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) { + PyErr_SetString(PyExc_TypeError, "expected Tensor()"); + return nullptr; + } + + at::Tensor tensor = THPVariable_Unpack(item); + + if (reinterpret_cast(tensor.data_ptr()) % kAlignment == 0) { + // Already aligned – return the original tensor. + Py_INCREF(item); + return item; + } + + // Misaligned – clone while preserving strides. + // Same logic as torch._inductor.utils.clone_preserve_strides. + int64_t needed_size = 0; + if (tensor.numel() > 0) { + auto sizes = tensor.sizes(); + auto strides = tensor.strides(); + for (int64_t i = 0; i < tensor.dim(); ++i) { + needed_size += (sizes[i] - 1) * strides[i]; + } + needed_size += 1; + } + at::Tensor flat = at::as_strided(tensor, {needed_size}, {1}).clone(); + at::Tensor result = at::as_strided(flat, tensor.sizes(), tensor.strides()); + return THPVariable_Wrap(std::move(result)); +} + template static void unwrap_size_tuple(PyObject* obj, T& output) { TORCH_CHECK(PyTuple_CheckExact(obj)); @@ -1176,7 +1225,8 @@ static PyMethodDef _methods[] = { {"check_obj_id", check_obj_id, METH_VARARGS, nullptr}, {"assert_size_stride", assert_size_stride, METH_VARARGS, nullptr}, {"assert_alignment", assert_alignment, METH_VARARGS, nullptr}, - {"dict_version", dict_version, METH_VARARGS, nullptr}, + {"copy_if_misaligned", copy_if_misaligned, METH_O, nullptr}, + {"dict_version", dict_version, METH_O, nullptr}, {"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr}, {"_empty_strided_cpu_pinned", _empty_strided_cpu_pinned, @@ -1227,7 +1277,7 @@ bool is_immutable_object(py::handle example_value) { return true; } - return (example_value.ptr() == Py_None) || + return (Py_IsNone(example_value.ptr())) || PyLong_Check(example_value.ptr()) || PyFloat_Check(example_value.ptr()) || PyBool_Check(example_value.ptr()) || PyUnicode_Check(example_value.ptr()) || @@ -1637,7 +1687,9 @@ class DictGuardManager; // the container can grow large over the lifetime of the process. That’s // acceptable: lookup is by pointer (hash/equals = identity) and each entry // stores only lightweight pointers. -std::unordered_map> dict_to_guard_managers; +using DictToGuardManagersMap = + std::unordered_map>; +c10::Synchronized dict_to_guard_managers; /** * Base class for the leaf guard in the GuardManager hierarchy. @@ -1683,7 +1735,6 @@ class LeafGuard { // is not exposed to Python and can only be called from C++. virtual bool check_nopybind(PyObject* value) = 0; virtual bool check_nopybind(FrameLocalsMapping* map) { - // throw std::runtime_error("fallback to python"); // Could fallback to running check on the Python dict (lazily constructed) return check_nopybind((PyObject*)map->to_dict()); } @@ -1726,7 +1777,8 @@ class LAMBDA_GUARD : public LeafGuard { if (py::isinstance(guard_check_fn)) { _guard_check_fn = py::cast(std::move(guard_check_fn)); } else { - throw py::type_error("LAMBDA_GUARD expects (callable, str)"); + throw py::type_error( + "LAMBDA_GUARD expects (callable, str)"); // @allow-raw-throw } } @@ -1824,7 +1876,7 @@ class NONE_MATCH : public LeafGuard { std::move(user_stack)) {} bool check_nopybind(PyObject* value) override { // borrowed ref - return value == Py_None; + return Py_IsNone(value); } }; @@ -1840,7 +1892,7 @@ class TRUE_MATCH : public LeafGuard { std::move(user_stack)) {} bool check_nopybind(PyObject* value) override { // borrowed ref - return value == Py_True; + return Py_IsTrue(value); } }; @@ -1856,7 +1908,7 @@ class FALSE_MATCH : public LeafGuard { std::move(user_stack)) {} bool check_nopybind(PyObject* value) override { // borrowed ref - return value == Py_False; + return Py_IsFalse(value); } }; @@ -2047,7 +2099,7 @@ class NOT_NONE : public LeafGuard { std::move(user_stack)) {} bool check_nopybind(PyObject* value) override { // borrowed ref - return value != Py_None; + return !Py_IsNone(value); } }; @@ -2154,7 +2206,8 @@ class GLOBAL_STATE : public LeafGuard { owner_(std::move(initial_state)), _guard((GlobalStateGuard*)owner_.ptr()) { if (!PyObject_TypeCheck(owner_.ptr(), &GlobalStateGuardType)) { - throw py::type_error("GLOBAL_STATE expects a GlobalStateGuard"); + throw py::type_error( + "GLOBAL_STATE expects a GlobalStateGuard"); // @allow-raw-throw } } @@ -2536,11 +2589,13 @@ class SYMBOLIC_SHAPE_GUARD : public RelationalGuard { _nargs_float = PyLong_AsSize_t(nargs_float.ptr()); _nargs = _nargs_int + _nargs_float; if (PyErr_Occurred()) { + // @allow-raw-throw throw py::value_error( "SYMBOLIC_SHAPE_GUARD expected a non-negative number of arguments."); } uintptr_t addr = PyLong_AsUnsignedLongLong(py_addr.ptr()); if (PyErr_Occurred()) { + // @allow-raw-throw throw py::value_error( "SYMBOLIC_SHAPE_GUARD expected an address to a C function."); } @@ -2687,7 +2742,7 @@ class DICT_VERSION : public LeafGuard { std::move(verbose_code_parts), std::move(user_stack)) { if (!PyDict_Check(value.ptr())) { - throw py::type_error("DICT_VERSION expects a dict"); + throw py::type_error("DICT_VERSION expects a dict"); // @allow-raw-throw } _tag = get_dict_version_unchecked(value.ptr()); } @@ -2724,9 +2779,7 @@ void stop_recording_dict_pointers( bool is_recording_dict_pointers(RootGuardManager* root); void record_dict_pointer(RootGuardManager* root, PyObject* dict_pointer); void record_tensor_pointer(RootGuardManager* root, PyObject* tensor_pointer); -void record_tensor_requires_grad( - RootGuardManager* root, - PyObject* tensor_pointer); +void record_tensor_metadata(RootGuardManager* root, PyObject* tensor_pointer); GuardManager* clone_guard_manager( GuardManager* from, @@ -2737,12 +2790,53 @@ void add_relational_guard_resetter_to_cloned_root( std::shared_ptr guard); std::shared_ptr get_no_tensor_aliasing_guard( RootGuardManager* _root); +const LocalState& get_local_state(RootGuardManager* root); // std::string get_compile_id(RootGuardManager* root); struct WeakEntry { PyObject* wr; // weakref PyObject* cap; // capsule whose m_self is used by the callback }; + +// Convert concrete sizes/strides to the optional vectors that +// TensorCheck expects. All dimensions are treated as static (no nullopt). +inline std::vector> to_opt_symint( + c10::IntArrayRef vals) { + std::vector> out; + out.reserve(vals.size()); + for (auto v : vals) { + out.emplace_back(c10::SymInt(v)); + } + return out; +} + +// Build a TensorCheck that validates all concrete metadata (dispatch key, +// dtype, device, requires_grad, sizes, strides) for the dict-tag fast path. +inline TensorCheck make_tensor_check( + const LocalState& state, + const at::Tensor& tensor) { + auto layout = tensor.layout(); + bool sparse = layout == c10::kSparseCsr || layout == c10::kSparseCsc || + layout == c10::kSparseBsc || layout == c10::kSparseBsr; + // Sparse layouts don't support strides; use nullopt per dim so + // TensorCheck skips stride comparison for each dimension. + auto strides = sparse + ? std::vector>(tensor.dim(), std::nullopt) + : to_opt_symint(tensor.strides()); + return TensorCheck( + state, + /*pt=*/nullptr, + tensor, + tensor.key_set(), + to_opt_symint(tensor.sizes()), + std::move(strides)); +} + +struct RecordedTensorMetadata { + PyObject* tensor_ptr; + TensorCheck check; +}; + /** * Base class representing a pair of accessor and the associated guard * manager. The accessor defines how to access the child value from the @@ -2785,7 +2879,6 @@ class GuardAccessor { // subtree on immutable dict getitems. virtual bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) = 0; virtual bool check_nopybind(FrameLocalsMapping* map, bool matches_dict_tag) { - // throw std::runtime_error("fallback to python"); // Could fallback to running check on the Python dict (lazily constructed) return check_nopybind((PyObject*)map->to_dict(), matches_dict_tag); } @@ -2922,6 +3015,7 @@ class GuardManager { if (PyCapsule_IsValid(e.cap, "GuardManager*")) { PyCapsule_SetName(e.cap, "DeadGuardManager"); } + Py_DECREF(e.cap); Py_CLEAR(e.wr); // kills weakref (may remove callback) } _tag_safe_entries.clear(); @@ -3012,14 +3106,21 @@ class GuardManager { _tensor_pointers[value] = tensor_pointers; } - void stash_tensor_requires_grad( + void stash_tensor_metadata( PyObject* value, - std::vector>&& tensor_requires_grad) { - _tensor_requires_grad_pointers[value] = std::move(tensor_requires_grad); + std::vector&& tensor_metadata) { + _tensor_metadata_pointers[value] = std::move(tensor_metadata); } void disable_recursive_dict_tag_optimization() { - unwatch_all_saved_dict_pointers(); + dict_to_guard_managers.withLock([&](DictToGuardManagersMap& map) { + disable_recursive_dict_tag_optimization(map); + }); + } + + // Caller must hold dict_to_guard_managers lock. + void disable_recursive_dict_tag_optimization(DictToGuardManagersMap& map) { + unwatch_all_saved_dict_pointers(map); _disable_dict_tag_matching = true; } @@ -3144,15 +3245,16 @@ class GuardManager { return true; } - bool check_tensor_requires_grad_fast(PyObject* value) const { - auto it = _tensor_requires_grad_pointers.find(value); - if (it == _tensor_requires_grad_pointers.end()) { + bool check_tensor_metadata_fast(PyObject* value) { + auto it = _tensor_metadata_pointers.find(value); + if (it == _tensor_metadata_pointers.end()) { return true; } - for (const auto& [tensor_ptr, expected_requires_grad] : it->second) { - if (THPVariable_Check(tensor_ptr) && - THPVariable_Unpack(tensor_ptr).requires_grad() != - expected_requires_grad) { + for (auto& recorded_tensor : it->second) { + if (!THPVariable_Check(recorded_tensor.tensor_ptr) || + !recorded_tensor.check.check( + get_local_state(_root), + THPVariable_Unpack(recorded_tensor.tensor_ptr))) { return false; } } @@ -3185,15 +3287,17 @@ class GuardManager { // For a `tag_safe_root`, the input pointer called `value`, the object the // guard is inspecting, serves as a proxy for the entire nested dictionary // structure beneath that node. If this `value` pointer is one we have - // already recorded, then verifying each dictionary’s tag is sufficient to - // prove that nothing inside the subtree has changed. + // already recorded, then verifying each dictionary’s tag plus the cached + // tensor metadata is sufficient to prove that nothing inside the subtree + // has changed. // // Runtime flow // ------------- // 1) Previously‑seen `value` pointer // • Look up the current `value` pointer in our cache. - // • If found, perform a recursive tag comparison on the cached subtree. - // All tags match means guard passes with no further traversal. + // • If found, perform a recursive tag comparison on the cached subtree + // and revalidate recorded tensor metadata. + // All checks passing means guard passes with no further traversal. // // 2) First‑time `value` pointer // • Enter recording mode; walk the subtree, each tag safe root collects @@ -3236,7 +3340,7 @@ class GuardManager { // Check for fast path // if (is_weakref_valid(value) && check_dict_pointer_tags(value)) { if (check_dict_pointer_tags(value) && - check_tensor_requires_grad_fast(value)) { + check_tensor_metadata_fast(value)) { if (check_no_tensor_aliasing_guards_fast(value)) { return true; } else { @@ -3266,9 +3370,9 @@ class GuardManager { } else if (_has_no_tensor_aliasing_guard) { record_tensor_pointer(_root, value); } - // Record tensor requires_grad for all tensors in the subtree. + // Tensor metadata can mutate in-place without changing dict tags. if (_is_immutable && THPVariable_Check(value)) { - record_tensor_requires_grad(_root, value); + record_tensor_metadata(_root, value); } } } @@ -3352,6 +3456,7 @@ class GuardManager { return false; } // These will be decrefed in destructor + Py_INCREF(capsule); _tag_safe_entries.push_back({wr, capsule}); return true; } @@ -3382,26 +3487,16 @@ class GuardManager { PyErr_Clear(); return false; } - dict_to_guard_managers[dict_pointer].push_back(this); + dict_to_guard_managers.withLock([&](DictToGuardManagersMap& map) { + map[dict_pointer].push_back(this); + }); } #endif return true; } - void unwatch_all_saved_dict_pointers() { - /* - We may have recorded hundreds/thousands of dict pointers for the recursive - dict-tag optimisation. If any of those dicts mutates, we want to disable the - optimisation and then unwatch as many dict pointers as we can. - - Be careful: the same dict pointer can be recorded by multiple GuardManagers. - So the flow is: - - 1) Remove *this* GuardManager from dict_to_guard_managers[dict_pointer]. - 2) If the list for that dict becomes empty, then: - - PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer) - - erase the dict_pointer entry from dict_to_guard_managers. - */ + // Caller must hold dict_to_guard_managers lock. + void unwatch_all_saved_dict_pointers(DictToGuardManagersMap& map) { #if IS_PYTHON_3_12_PLUS if (!_disable_dict_tag_matching) { for (auto& value_stashed_pointers : _dict_pointers) { @@ -3410,20 +3505,15 @@ class GuardManager { for (auto& stashed_pointer : stashed_pointers) { PyObject* dict_pointer = stashed_pointer.first; - // Delete the guard manager from the dict_to_guard_managers auto it = std::find( - dict_to_guard_managers[dict_pointer].begin(), - dict_to_guard_managers[dict_pointer].end(), - this); - if (it != dict_to_guard_managers[dict_pointer].end()) { - dict_to_guard_managers[dict_pointer].erase(it); + map[dict_pointer].begin(), map[dict_pointer].end(), this); + if (it != map[dict_pointer].end()) { + map[dict_pointer].erase(it); } - // Unwatch the dict pointer if this was the last guard manager - // watching it. - if (dict_to_guard_managers[dict_pointer].empty()) { + if (map[dict_pointer].empty()) { PyDict_Unwatch(dict_recursive_tag_watcher_id, dict_pointer); - dict_to_guard_managers.erase(dict_pointer); + map.erase(dict_pointer); } } } @@ -3618,7 +3708,7 @@ class GuardManager { void add_permitted_leaf_guard(std::shared_ptr leaf_guard) { // Selectively called for permitted guards. This is used by DictGuardManager - // which overrides the add_leaf_guard manager to throw runtime error. + // which overrides the add_leaf_guard manager to raise a runtime error. GuardManager::add_leaf_guard(std::move(leaf_guard)); } @@ -3674,8 +3764,8 @@ class GuardManager { std::unordered_map>> _dict_pointers; std::unordered_map> _tensor_pointers; - std::unordered_map>> - _tensor_requires_grad_pointers; + std::unordered_map> + _tensor_metadata_pointers; std::vector _tag_safe_entries; // 3.12+ related helper @@ -3944,7 +4034,7 @@ class RootGuardManager : public GuardManager { _current_tag_safe_root = nullptr; _recorded_dict_pointers.clear(); _recorded_tensor_pointers.clear(); - _recorded_tensor_requires_grad.clear(); + _recorded_tensor_metadata.clear(); } void stop_recording_dict_pointers(PyObject* value, bool result) { @@ -3954,8 +4044,8 @@ class RootGuardManager : public GuardManager { value, _recorded_dict_pointers); _current_tag_safe_root->stash_tensor_pointers( value, _recorded_tensor_pointers); - _current_tag_safe_root->stash_tensor_requires_grad( - value, std::move(_recorded_tensor_requires_grad)); + _current_tag_safe_root->stash_tensor_metadata( + value, std::move(_recorded_tensor_metadata)); } reset_dict_tag_recording_variables(); } @@ -3973,9 +4063,11 @@ class RootGuardManager : public GuardManager { _recorded_tensor_pointers.push_back(tensor_pointer); } - void record_tensor_requires_grad(PyObject* tensor_pointer) { - bool rg = THPVariable_Unpack(tensor_pointer).requires_grad(); - _recorded_tensor_requires_grad.emplace_back(tensor_pointer, rg); + void record_tensor_metadata(PyObject* tensor_pointer) { + _recorded_tensor_metadata.push_back(RecordedTensorMetadata{ + tensor_pointer, + make_tensor_check(_local_state, THPVariable_Unpack(tensor_pointer)), + }); } public: @@ -4034,7 +4126,7 @@ class RootGuardManager : public GuardManager { GuardManager* _current_tag_safe_root{nullptr}; std::vector> _recorded_dict_pointers; std::vector _recorded_tensor_pointers; - std::vector> _recorded_tensor_requires_grad; + std::vector _recorded_tensor_metadata; }; /* @@ -4396,15 +4488,18 @@ static int dict_recursive_tag_watch_callback( PyObject* key, PyObject* new_value) noexcept { if (event != PyDict_EVENT_CLONED) { - auto it = dict_to_guard_managers.find(dict); - if (it != dict_to_guard_managers.end()) { - auto guard_managers = it->second; - for (auto& guard_manager : guard_managers) { - if (guard_manager) { - guard_manager->disable_recursive_dict_tag_optimization(); + dict_to_guard_managers.withLock([&](DictToGuardManagersMap& map) { + auto it = map.find(dict); + if (it != map.end()) { + // Copy the list — unwatch_all_saved_dict_pointers may mutate it. + auto guard_managers = it->second; + for (auto& guard_manager : guard_managers) { + if (guard_manager) { + guard_manager->disable_recursive_dict_tag_optimization(map); + } } } - } + }); } return 0; // keep watching } @@ -4450,7 +4545,7 @@ std::unique_ptr make_guard_manager( return std::make_unique( root, std::move(source), example_value); } else { - throw py::type_error("Invalid guard manager enum"); + throw py::type_error("Invalid guard manager enum"); // @allow-raw-throw } } return std::make_unique(root, std::move(source), example_value); @@ -4481,10 +4576,8 @@ void record_tensor_pointer(RootGuardManager* root, PyObject* tensor_pointer) { root->record_tensor_pointer(tensor_pointer); } -void record_tensor_requires_grad( - RootGuardManager* root, - PyObject* tensor_pointer) { - root->record_tensor_requires_grad(tensor_pointer); +void record_tensor_metadata(RootGuardManager* root, PyObject* tensor_pointer) { + root->record_tensor_metadata(tensor_pointer); } std::shared_ptr get_no_tensor_aliasing_guard( @@ -4492,6 +4585,10 @@ std::shared_ptr get_no_tensor_aliasing_guard( return _root->get_no_tensor_aliasing_guard(); } +const LocalState& get_local_state(RootGuardManager* root) { + return root->_local_state; +} + // std::string get_compile_id(RootGuardManager* root) { // return root->get_compile_id(); // } @@ -4640,13 +4737,15 @@ class TENSOR_MATCH : public LeafGuard { if (Py_TYPE(value) != _tensor_check->pytype) { std::stringstream fail_reason; - PyObject* type_str = PyObject_Str(PyObject_Type(value)); + PyObject* type_str = + PyObject_Str(reinterpret_cast(Py_TYPE(value))); fail_reason << "expected type of '" << _tensor_name << "' to be a tensor type, "; if (!type_str) { fail_reason << "but found a different type"; } else { fail_reason << "' but found " << PyUnicode_AsUTF8(type_str); + Py_DECREF(type_str); } return GuardDebugInfo(false, fail_reason.str(), 0); } @@ -5003,10 +5102,7 @@ class FrameLocalsGuardAccessor : public GuardAccessor { _key(key[0].ptr()), _framelocals_idx(key[1].cast()), _is_immutable_object(is_immutable_object(example_value)), - _is_tensor(THPVariable_Check(example_value.ptr())), - _tensor_requires_grad( - _is_tensor ? THPVariable_Unpack(example_value.ptr()).requires_grad() - : false) {} + _is_tensor(THPVariable_Check(example_value.ptr())) {} // Run as a result of calling run_root_guard_manager/check_nopybind // NB: Intentional duplication between check_nopybind and @@ -5014,18 +5110,8 @@ class FrameLocalsGuardAccessor : public GuardAccessor { bool check_nopybind( FrameLocalsMapping* obj, bool matches_dict_tag = false) override { // borrowed ref - if (matches_dict_tag && _is_immutable_object) { - // Tensors are treated as immutable for the dict-tag optimization, but - // their metadata (e.g. requires_grad) can be mutated in-place without - // changing the parent dict's version tag. For now we only check - // requires_grad since it is the most common mutation; other metadata - // changes (dtype, device, etc.) are possible but rare in practice. - if (!_is_tensor) { - return true; - } - if (!tensor_requires_grad_changed(obj->get(_framelocals_idx))) { - return true; - } + if (matches_dict_tag && _is_immutable_object && !_is_tensor) { + return true; } PyObject* x = obj->get(_framelocals_idx); @@ -5047,19 +5133,8 @@ class FrameLocalsGuardAccessor : public GuardAccessor { PyDict_Check(obj), "FrameLocalsGuardAccessor check expected dict() input"); - if (matches_dict_tag && _is_immutable_object) { - // Tensors are treated as immutable for the dict-tag optimization, but - // their metadata (e.g. requires_grad) can be mutated in-place without - // changing the parent dict's version tag. For now we only check - // requires_grad since it is the most common mutation; other metadata - // changes (dtype, device, etc.) are possible but rare in practice. - if (!_is_tensor) { - return true; - } - PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref - if (!tensor_requires_grad_changed(x)) { - return true; - } + if (matches_dict_tag && _is_immutable_object && !_is_tensor) { + return true; } PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref @@ -5116,15 +5191,9 @@ class FrameLocalsGuardAccessor : public GuardAccessor { to->_framelocals_idx = _framelocals_idx; to->_is_immutable_object = _is_immutable_object; to->_is_tensor = _is_tensor; - to->_tensor_requires_grad = _tensor_requires_grad; } private: - bool tensor_requires_grad_changed(PyObject* x) const { - return x != nullptr && THPVariable_Check(x) && - THPVariable_Unpack(x).requires_grad() != _tensor_requires_grad; - } - PyObject* _key{nullptr}; int _framelocals_idx{-1}; @@ -5132,7 +5201,6 @@ class FrameLocalsGuardAccessor : public GuardAccessor { // return true. bool _is_immutable_object{false}; bool _is_tensor{false}; - bool _tensor_requires_grad{false}; }; /** @@ -5156,30 +5224,15 @@ class DictGetItemGuardAccessor : public GuardAccessor { guard_manager_enum), _key(key.ptr()), _is_immutable_object(is_immutable_object(example_value)), - _is_tensor(THPVariable_Check(example_value.ptr())), - _tensor_requires_grad( - _is_tensor ? THPVariable_Unpack(example_value.ptr()).requires_grad() - : false) {} + _is_tensor(THPVariable_Check(example_value.ptr())) {} // NB: Intentional duplication between check_nopybind and // check_verbose_nopybind. bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) override { - if (matches_dict_tag && _is_immutable_object && + if (matches_dict_tag && _is_immutable_object && !_is_tensor && !is_recording_dict_pointers(get_guard_manager()->get_root()) && _guard_manager->has_no_accessors()) { - // Tensors are treated as immutable for the dict-tag optimization, but - // their metadata (e.g. requires_grad) can be mutated in-place without - // changing the parent dict's version tag. For now we only check - // requires_grad since it is the most common mutation; other metadata - // changes (dtype, device, etc.) are possible but rare in practice. - if (!_is_tensor) { - return true; - } - PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref - if (!tensor_requires_grad_changed(x)) { - return true; - } - // Fall through to full check - requires_grad changed. + return true; } PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref @@ -5226,22 +5279,15 @@ class DictGetItemGuardAccessor : public GuardAccessor { to->_key = _key; to->_is_immutable_object = _is_immutable_object; to->_is_tensor = _is_tensor; - to->_tensor_requires_grad = _tensor_requires_grad; } private: - bool tensor_requires_grad_changed(PyObject* x) const { - return x != nullptr && THPVariable_Check(x) && - THPVariable_Unpack(x).requires_grad() != _tensor_requires_grad; - } - PyObject* _key{nullptr}; // If immutable object and dict tag matches, we can skip the guard subtree and // return true. bool _is_immutable_object{false}; bool _is_tensor{false}; - bool _tensor_requires_grad{false}; }; /** @@ -6032,22 +6078,47 @@ class TypeDictGuardAccessor : public GuardAccessor { // NB: Intentional duplication between check_nopybind and // check_verbose_nopybind. + // + // In CPython 3.12+, types with Py_TPFLAGS_MANAGED_DICT (e.g. enum.Enum) + // can have tp_dict=NULL because the dict is lazily materialized. Use + // PyType_GetDict() which handles this transparently. It returns a new + // reference, so we must decref after use. bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) override { // borrowed ref +#if PY_VERSION_HEX >= 0x030C0000 + PyObject* x = PyType_GetDict((PyTypeObject*)obj); // new ref + if (x == nullptr) { + return false; + } + bool result = _guard_manager->check_nopybind(x); + Py_DECREF(x); + return result; +#else PyObject* x = ((PyTypeObject*)obj)->tp_dict; // borrowed ref if (x == nullptr) { return false; } return _guard_manager->check_nopybind(x); +#endif } GuardDebugInfo check_verbose_nopybind( PyObject* obj) override { // borrowed ref +#if PY_VERSION_HEX >= 0x030C0000 + PyObject* x = PyType_GetDict((PyTypeObject*)obj); // new ref + if (x == nullptr) { + return GuardDebugInfo(false, "null type dict on " + repr(), 0); + } + auto result = _guard_manager->check_verbose_nopybind(x); + Py_DECREF(x); + return result; +#else PyObject* x = ((PyTypeObject*)obj)->tp_dict; // borrowed ref if (x == nullptr) { return GuardDebugInfo(false, "null type dict on " + repr(), 0); } return _guard_manager->check_verbose_nopybind(x); +#endif } std::string repr() const override { @@ -6976,6 +7047,30 @@ PyObject* torch_c_dynamo_guards_init() { return nullptr; } + // Expose THPVariable_Wrap for cpp_wrapper inductor, for fbcode + { + using WrapFn = PyObject* (*)(const at::TensorBase&); + WrapFn fn = &THPVariable_Wrap; + if (PyModule_AddObject( + m, + "_torchinductor_thp_variable_wrap", + PyLong_FromVoidPtr(reinterpret_cast(fn))) < 0) { + return nullptr; + } + } + + // Expose THPUtils_unpackInt for cpp_wrapper inductor, for fbcode + { + using UnpackFn = int32_t (*)(PyObject*); + UnpackFn fn = &THPUtils_unpackInt; + if (PyModule_AddObject( + m, + "_torchinductor_thputils_unpack_int", + PyLong_FromVoidPtr(reinterpret_cast(fn))) < 0) { + return nullptr; + } + } + auto py_m = py::handle(m).cast(); py::class_>( py_m, "GuardDebugInfo") @@ -8206,7 +8301,7 @@ PyObject* torch_c_dynamo_guards_init() { py_m.def("install_symbolic_shape_guard", install_symbolic_shape_guard); py_m.def("profile_guard_manager", profile_guard_manager); -// initialize dict_version_map watcher for 3.12 +// initialize dict version watcher for 3.12 #if IS_PYTHON_3_12_PLUS dict_version_watcher_id = PyDict_AddWatcher(dict_version_watch_callback); diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index ec096b1145e90..92ad88c771f01 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -158,10 +158,238 @@ PyObject* _is_valid_var_name( return result.release(); } +// Slot bit position definitions (each int64_t has independent bit positions) + +enum class PySequenceSlotBit : int64_t { + SQ_LENGTH = 0, + SQ_CONCAT = 1, + SQ_REPEAT = 2, + SQ_ITEM = 3, + SQ_CONTAINS = 4, + SQ_ASS_ITEM = 5, + SQ_INPLACE_CONCAT = 6, + SQ_INPLACE_REPEAT = 7, +}; + +enum class PyMappingSlotBit : int64_t { + MP_LENGTH = 0, + MP_SUBSCRIPT = 1, + MP_ASS_SUBSCRIPT = 2, +}; + +enum class PyNumberSlotBit : int64_t { + NB_ADD = 0, + NB_SUBTRACT = 1, + NB_MULTIPLY = 2, + NB_REMAINDER = 3, + NB_POWER = 4, + NB_NEGATIVE = 5, + NB_POSITIVE = 6, + NB_ABSOLUTE = 7, + NB_BOOL = 8, + NB_INVERT = 9, + NB_LSHIFT = 10, + NB_RSHIFT = 11, + NB_AND = 12, + NB_XOR = 13, + NB_OR = 14, + NB_INT = 15, + NB_FLOAT = 16, + NB_INPLACE_ADD = 17, + NB_INPLACE_SUBTRACT = 18, + NB_INPLACE_MULTIPLY = 19, + NB_INPLACE_REMAINDER = 20, + NB_INPLACE_POWER = 21, + NB_INPLACE_LSHIFT = 22, + NB_INPLACE_RSHIFT = 23, + NB_INPLACE_AND = 24, + NB_INPLACE_XOR = 25, + NB_INPLACE_OR = 26, + NB_FLOOR_DIVIDE = 27, + NB_TRUE_DIVIDE = 28, + NB_INPLACE_FLOOR_DIVIDE = 29, + NB_INPLACE_TRUE_DIVIDE = 30, + NB_INDEX = 31, + NB_MATRIX_MULTIPLY = 32, + NB_INPLACE_MATRIX_MULTIPLY = 33, +}; + +enum class PyTypeSlotBit : int64_t { + TP_HASH = 0, + TP_ITER = 1, + TP_ITERNEXT = 2, + TP_CALL = 3, + TP_REPR = 4, + TP_RICHCOMPARE = 5, + TP_GETATTRO = 6, + TP_SETATTRO = 7, + TP_DESCR_GET = 8, + TP_DESCR_SET = 9, +}; + +int64_t get_pysequence_slots(PyTypeObject* type) { + int64_t slots = 0; + if (PyType_GetSlot(type, Py_sq_length) != nullptr) + slots |= (1LL << static_cast(PySequenceSlotBit::SQ_LENGTH)); + if (PyType_GetSlot(type, Py_sq_concat) != nullptr) + slots |= (1LL << static_cast(PySequenceSlotBit::SQ_CONCAT)); + if (PyType_GetSlot(type, Py_sq_repeat) != nullptr) + slots |= (1LL << static_cast(PySequenceSlotBit::SQ_REPEAT)); + if (PyType_GetSlot(type, Py_sq_item) != nullptr) + slots |= (1LL << static_cast(PySequenceSlotBit::SQ_ITEM)); + if (PyType_GetSlot(type, Py_sq_contains) != nullptr) + slots |= (1LL << static_cast(PySequenceSlotBit::SQ_CONTAINS)); + if (PyType_GetSlot(type, Py_sq_ass_item) != nullptr) + slots |= (1LL << static_cast(PySequenceSlotBit::SQ_ASS_ITEM)); + if (PyType_GetSlot(type, Py_sq_inplace_concat) != nullptr) + slots |= (1LL << static_cast(PySequenceSlotBit::SQ_INPLACE_CONCAT)); + if (PyType_GetSlot(type, Py_sq_inplace_repeat) != nullptr) + slots |= (1LL << static_cast(PySequenceSlotBit::SQ_INPLACE_REPEAT)); + return slots; +} + +int64_t get_pymapping_slots(PyTypeObject* type) { + int64_t slots = 0; + if (PyType_GetSlot(type, Py_mp_length) != nullptr) + slots |= (1LL << static_cast(PyMappingSlotBit::MP_LENGTH)); + if (PyType_GetSlot(type, Py_mp_subscript) != nullptr) + slots |= (1LL << static_cast(PyMappingSlotBit::MP_SUBSCRIPT)); + if (PyType_GetSlot(type, Py_mp_ass_subscript) != nullptr) + slots |= (1LL << static_cast(PyMappingSlotBit::MP_ASS_SUBSCRIPT)); + return slots; +} + +int64_t get_pynumber_slots(PyTypeObject* type) { + int64_t slots = 0; + if (PyType_GetSlot(type, Py_nb_add) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_ADD)); + if (PyType_GetSlot(type, Py_nb_subtract) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_SUBTRACT)); + if (PyType_GetSlot(type, Py_nb_multiply) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_MULTIPLY)); + if (PyType_GetSlot(type, Py_nb_remainder) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_REMAINDER)); + if (PyType_GetSlot(type, Py_nb_power) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_POWER)); + if (PyType_GetSlot(type, Py_nb_negative) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_NEGATIVE)); + if (PyType_GetSlot(type, Py_nb_positive) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_POSITIVE)); + if (PyType_GetSlot(type, Py_nb_absolute) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_ABSOLUTE)); + if (PyType_GetSlot(type, Py_nb_bool) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_BOOL)); + if (PyType_GetSlot(type, Py_nb_invert) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INVERT)); + if (PyType_GetSlot(type, Py_nb_lshift) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_LSHIFT)); + if (PyType_GetSlot(type, Py_nb_rshift) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_RSHIFT)); + if (PyType_GetSlot(type, Py_nb_and) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_AND)); + if (PyType_GetSlot(type, Py_nb_xor) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_XOR)); + if (PyType_GetSlot(type, Py_nb_or) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_OR)); + if (PyType_GetSlot(type, Py_nb_int) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INT)); + if (PyType_GetSlot(type, Py_nb_float) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_FLOAT)); + if (PyType_GetSlot(type, Py_nb_inplace_add) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_ADD)); + if (PyType_GetSlot(type, Py_nb_inplace_subtract) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_SUBTRACT)); + if (PyType_GetSlot(type, Py_nb_inplace_multiply) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_MULTIPLY)); + if (PyType_GetSlot(type, Py_nb_inplace_remainder) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_REMAINDER)); + if (PyType_GetSlot(type, Py_nb_inplace_power) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_POWER)); + if (PyType_GetSlot(type, Py_nb_inplace_lshift) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_LSHIFT)); + if (PyType_GetSlot(type, Py_nb_inplace_rshift) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_RSHIFT)); + if (PyType_GetSlot(type, Py_nb_inplace_and) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_AND)); + if (PyType_GetSlot(type, Py_nb_inplace_xor) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_XOR)); + if (PyType_GetSlot(type, Py_nb_inplace_or) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_OR)); + if (PyType_GetSlot(type, Py_nb_floor_divide) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_FLOOR_DIVIDE)); + if (PyType_GetSlot(type, Py_nb_true_divide) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_TRUE_DIVIDE)); + if (PyType_GetSlot(type, Py_nb_inplace_floor_divide) != nullptr) + slots |= + (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_FLOOR_DIVIDE)); + if (PyType_GetSlot(type, Py_nb_inplace_true_divide) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_TRUE_DIVIDE)); + if (PyType_GetSlot(type, Py_nb_index) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_INDEX)); + if (PyType_GetSlot(type, Py_nb_matrix_multiply) != nullptr) + slots |= (1LL << static_cast(PyNumberSlotBit::NB_MATRIX_MULTIPLY)); + if (PyType_GetSlot(type, Py_nb_inplace_matrix_multiply) != nullptr) + slots |= + (1LL << static_cast(PyNumberSlotBit::NB_INPLACE_MATRIX_MULTIPLY)); + return slots; +} + +int64_t get_pytype_slots(PyTypeObject* type) { + int64_t slots = 0; + if (PyType_GetSlot(type, Py_tp_hash) != nullptr) + slots |= (1LL << static_cast(PyTypeSlotBit::TP_HASH)); + if (PyType_GetSlot(type, Py_tp_iter) != nullptr) + slots |= (1LL << static_cast(PyTypeSlotBit::TP_ITER)); + if (PyType_GetSlot(type, Py_tp_iternext) != nullptr) + slots |= (1LL << static_cast(PyTypeSlotBit::TP_ITERNEXT)); + if (PyType_GetSlot(type, Py_tp_call) != nullptr) + slots |= (1LL << static_cast(PyTypeSlotBit::TP_CALL)); + if (PyType_GetSlot(type, Py_tp_repr) != nullptr) + slots |= (1LL << static_cast(PyTypeSlotBit::TP_REPR)); + if (PyType_GetSlot(type, Py_tp_richcompare) != nullptr) + slots |= (1LL << static_cast(PyTypeSlotBit::TP_RICHCOMPARE)); + if (PyType_GetSlot(type, Py_tp_getattro) != nullptr) + slots |= (1LL << static_cast(PyTypeSlotBit::TP_GETATTRO)); + if (PyType_GetSlot(type, Py_tp_setattro) != nullptr) + slots |= (1LL << static_cast(PyTypeSlotBit::TP_SETATTRO)); + if (PyType_GetSlot(type, Py_tp_descr_get) != nullptr) + slots |= (1LL << static_cast(PyTypeSlotBit::TP_DESCR_GET)); + if (PyType_GetSlot(type, Py_tp_descr_set) != nullptr) + slots |= (1LL << static_cast(PyTypeSlotBit::TP_DESCR_SET)); + return slots; +} + +PyObject* _get_type_slots( + PyObject* self, + PyObject* const* args, + Py_ssize_t nargs) { + if (!_checkParamCount(nargs, 1)) { + return nullptr; + } + + PyObject* arg = args[0]; + PyTypeObject* type = PyType_Check(arg) ? (PyTypeObject*)arg : Py_TYPE(arg); + + int64_t seq_slots = get_pysequence_slots(type); + int64_t map_slots = get_pymapping_slots(type); + int64_t num_slots = get_pynumber_slots(type); + int64_t type_slots = get_pytype_slots(type); + + PyObject* tuple = PyTuple_New(4); + if (tuple == nullptr) { + return nullptr; + } + PyTuple_SetItem(tuple, 0, PyLong_FromLongLong(seq_slots)); + PyTuple_SetItem(tuple, 1, PyLong_FromLongLong(map_slots)); + PyTuple_SetItem(tuple, 2, PyLong_FromLongLong(num_slots)); + PyTuple_SetItem(tuple, 3, PyLong_FromLongLong(type_slots)); + return tuple; +} + #define PYC_FN(x) ((PyCFunction)(void (*)()) & x) void _register_functions(PyObject* mod) { - static std::array fns = { + static std::array fns = { PyMethodDef{ "strip_function_call", PYC_FN(_strip_function_call), @@ -172,6 +400,8 @@ void _register_functions(PyObject* mod) { PYC_FN(_is_valid_var_name), METH_FASTCALL, nullptr}, + PyMethodDef{ + "get_type_slots", PYC_FN(_get_type_slots), METH_FASTCALL, nullptr}, PyMethodDef{nullptr, nullptr, 0, nullptr}, }; PyModule_AddFunctions(mod, fns.data()); @@ -182,7 +412,7 @@ void _register_functions(PyObject* mod) { void initDynamoBindings(PyObject* torch) { PyObject* dynamo = PyModule_Create(&_module); if (dynamo == nullptr || PyModule_AddObject(torch, "_dynamo", dynamo) != 0) { - throw python_error(); + throw python_error(); // @allow-raw-throw } #ifdef Py_GIL_DISABLED PyUnstable_Module_SetGIL(dynamo, Py_MOD_GIL_NOT_USED); @@ -191,23 +421,23 @@ void initDynamoBindings(PyObject* torch) { PyObject* eval_frame = torch_c_dynamo_eval_frame_init(); if (eval_frame == nullptr || PyModule_AddObject(dynamo, "eval_frame", eval_frame) != 0) { - throw python_error(); + throw python_error(); // @allow-raw-throw } PyObject* utils = torch_c_dynamo_utils_init(); if (utils == nullptr || PyModule_AddObject(dynamo, "utils", utils) != 0) { - throw python_error(); + throw python_error(); // @allow-raw-throw } PyObject* guards = torch_c_dynamo_guards_init(); if (guards == nullptr || PyModule_AddObject(dynamo, "guards", guards) != 0) { - throw python_error(); + throw python_error(); // @allow-raw-throw } PyObject* compiled_autograd = torch_c_dynamo_compiled_autograd_init(); if (compiled_autograd == nullptr || PyModule_AddObject(dynamo, "compiled_autograd", compiled_autograd) != 0) { - throw python_error(); + throw python_error(); // @allow-raw-throw } auto m = py::handle(eval_frame).cast(); @@ -275,12 +505,77 @@ void initDynamoBindings(PyObject* torch) { m.def("code_framelocals_names", &code_framelocals_names); _register_functions(dynamo); - py::enum_(m, "_EvalFrameOverride") - .value("NONE", EvalFrameOverride::NONE) - .value("SKIP", EvalFrameOverride::SKIP) - .value("ERROR", EvalFrameOverride::ERROR); - - m.def("set_eval_frame_override", &set_eval_frame_override); + auto dynamo_module = py::handle(dynamo).cast(); + dynamo_module.def("has_slot", [](int64_t slots, py::object slot_bit_obj) { + // Convert slot_bit to int - handle both int and pybind11 enums + int64_t slot_bit = py::cast(slot_bit_obj.attr("__index__")()); + return (slots & (1LL << slot_bit)) != 0; + }); + py::enum_(dynamo_module, "PySequenceSlots") + .value("SQ_LENGTH", PySequenceSlotBit::SQ_LENGTH) + .value("SQ_CONCAT", PySequenceSlotBit::SQ_CONCAT) + .value("SQ_REPEAT", PySequenceSlotBit::SQ_REPEAT) + .value("SQ_ITEM", PySequenceSlotBit::SQ_ITEM) + .value("SQ_CONTAINS", PySequenceSlotBit::SQ_CONTAINS) + .value("SQ_ASS_ITEM", PySequenceSlotBit::SQ_ASS_ITEM) + .value("SQ_INPLACE_CONCAT", PySequenceSlotBit::SQ_INPLACE_CONCAT) + .value("SQ_INPLACE_REPEAT", PySequenceSlotBit::SQ_INPLACE_REPEAT); + + py::enum_(dynamo_module, "PyMappingSlots") + .value("MP_LENGTH", PyMappingSlotBit::MP_LENGTH) + .value("MP_SUBSCRIPT", PyMappingSlotBit::MP_SUBSCRIPT) + .value("MP_ASS_SUBSCRIPT", PyMappingSlotBit::MP_ASS_SUBSCRIPT); + + py::enum_(dynamo_module, "PyNumberSlots") + .value("NB_ADD", PyNumberSlotBit::NB_ADD) + .value("NB_SUBTRACT", PyNumberSlotBit::NB_SUBTRACT) + .value("NB_MULTIPLY", PyNumberSlotBit::NB_MULTIPLY) + .value("NB_REMAINDER", PyNumberSlotBit::NB_REMAINDER) + .value("NB_POWER", PyNumberSlotBit::NB_POWER) + .value("NB_NEGATIVE", PyNumberSlotBit::NB_NEGATIVE) + .value("NB_POSITIVE", PyNumberSlotBit::NB_POSITIVE) + .value("NB_ABSOLUTE", PyNumberSlotBit::NB_ABSOLUTE) + .value("NB_BOOL", PyNumberSlotBit::NB_BOOL) + .value("NB_INVERT", PyNumberSlotBit::NB_INVERT) + .value("NB_LSHIFT", PyNumberSlotBit::NB_LSHIFT) + .value("NB_RSHIFT", PyNumberSlotBit::NB_RSHIFT) + .value("NB_AND", PyNumberSlotBit::NB_AND) + .value("NB_XOR", PyNumberSlotBit::NB_XOR) + .value("NB_OR", PyNumberSlotBit::NB_OR) + .value("NB_INT", PyNumberSlotBit::NB_INT) + .value("NB_FLOAT", PyNumberSlotBit::NB_FLOAT) + .value("NB_INPLACE_ADD", PyNumberSlotBit::NB_INPLACE_ADD) + .value("NB_INPLACE_SUBTRACT", PyNumberSlotBit::NB_INPLACE_SUBTRACT) + .value("NB_INPLACE_MULTIPLY", PyNumberSlotBit::NB_INPLACE_MULTIPLY) + .value("NB_INPLACE_REMAINDER", PyNumberSlotBit::NB_INPLACE_REMAINDER) + .value("NB_INPLACE_POWER", PyNumberSlotBit::NB_INPLACE_POWER) + .value("NB_INPLACE_LSHIFT", PyNumberSlotBit::NB_INPLACE_LSHIFT) + .value("NB_INPLACE_RSHIFT", PyNumberSlotBit::NB_INPLACE_RSHIFT) + .value("NB_INPLACE_AND", PyNumberSlotBit::NB_INPLACE_AND) + .value("NB_INPLACE_XOR", PyNumberSlotBit::NB_INPLACE_XOR) + .value("NB_INPLACE_OR", PyNumberSlotBit::NB_INPLACE_OR) + .value("NB_FLOOR_DIVIDE", PyNumberSlotBit::NB_FLOOR_DIVIDE) + .value("NB_TRUE_DIVIDE", PyNumberSlotBit::NB_TRUE_DIVIDE) + .value( + "NB_INPLACE_FLOOR_DIVIDE", PyNumberSlotBit::NB_INPLACE_FLOOR_DIVIDE) + .value("NB_INPLACE_TRUE_DIVIDE", PyNumberSlotBit::NB_INPLACE_TRUE_DIVIDE) + .value("NB_INDEX", PyNumberSlotBit::NB_INDEX) + .value("NB_MATRIX_MULTIPLY", PyNumberSlotBit::NB_MATRIX_MULTIPLY) + .value( + "NB_INPLACE_MATRIX_MULTIPLY", + PyNumberSlotBit::NB_INPLACE_MATRIX_MULTIPLY); + + py::enum_(dynamo_module, "PyTypeSlots") + .value("TP_HASH", PyTypeSlotBit::TP_HASH) + .value("TP_ITER", PyTypeSlotBit::TP_ITER) + .value("TP_ITERNEXT", PyTypeSlotBit::TP_ITERNEXT) + .value("TP_CALL", PyTypeSlotBit::TP_CALL) + .value("TP_REPR", PyTypeSlotBit::TP_REPR) + .value("TP_RICHCOMPARE", PyTypeSlotBit::TP_RICHCOMPARE) + .value("TP_GETATTRO", PyTypeSlotBit::TP_GETATTRO) + .value("TP_SETATTRO", PyTypeSlotBit::TP_SETATTRO) + .value("TP_DESCR_GET", PyTypeSlotBit::TP_DESCR_GET) + .value("TP_DESCR_SET", PyTypeSlotBit::TP_DESCR_SET); } } // namespace torch::dynamo diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index 463eb7de0c222..fc14ec8ca2888 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -656,7 +656,7 @@ static PyObject* set_verbose_logger(PyObject* dummy, PyObject* args) { throw_python_error(); } - if (logger == Py_None) { + if (Py_IsNone(logger)) { python_verbose_logger = nullptr; } else { python_verbose_logger = logger; @@ -1261,7 +1261,7 @@ static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args) { PyObject* prior_compiler = the_autograd_compiler; PyObject* prior_dynamic = default_dyn_type_int == 0 ? Py_False : Py_True; default_dyn_type_int = b; - if (obj == Py_None) { // disable + if (Py_IsNone(obj)) { // disable the_autograd_compiler = nullptr; // decref not needed due to `prior` Engine::set_compiled_autograd(nullptr); } else { // enable diff --git a/torch/csrc/functorch/init.cpp b/torch/csrc/functorch/init.cpp index 32e781ce43056..57669b84bcc3f 100644 --- a/torch/csrc/functorch/init.cpp +++ b/torch/csrc/functorch/init.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include @@ -246,13 +247,36 @@ static RandomnessType get_randomness_enum(const std::string& randomness) { static int64_t _grad_increment_nesting() { // See NOTE [grad and vjp interaction with no_grad] bool prev_grad_mode = c10::GradMode::is_enabled(); + // When inference_mode is on, new tensors lack autograd dispatch keys + // (TensorImpl strips them in its constructor). Toggle the flag off so + // tensors created inside the transform can participate in autograd. + // Uses AutogradState::set_inference_mode — not the InferenceMode RAII + // guard, which would clobber grad_mode and fw_grad_mode. + bool prev_inference_mode = c10::InferenceMode::is_enabled(); + if (prev_inference_mode) { + auto state = c10::AutogradState::get_tls_state(); + state.set_inference_mode(false); + c10::AutogradState::set_tls_state(state); + } return initAndPushDynamicLayer( - TransformType::Grad, std::nullopt, std::nullopt, prev_grad_mode); + TransformType::Grad, + std::nullopt, + std::nullopt, + prev_grad_mode, + std::nullopt, + std::nullopt, + prev_inference_mode); } static int64_t _grad_decrement_nesting() { auto layer = popDynamicLayerAndDeleteMetadata(); TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Grad); + auto& meta = std::get(layer.interpreter().meta()); + if (meta.prevInferenceMode_) { + auto state = c10::AutogradState::get_tls_state(); + state.set_inference_mode(true); + c10::AutogradState::set_tls_state(state); + } return layer.layerId(); } @@ -260,17 +284,31 @@ static int64_t _jvp_increment_nesting() { // See NOTE [grad and vjp interaction with no_grad] bool prev_fwd_grad_mode = c10::AutogradState::get_tls_state().get_fw_grad_mode(); + bool prev_inference_mode = c10::InferenceMode::is_enabled(); + if (prev_inference_mode) { + auto state = c10::AutogradState::get_tls_state(); + state.set_inference_mode(false); + c10::AutogradState::set_tls_state(state); + } return initAndPushDynamicLayer( TransformType::Jvp, std::nullopt, std::nullopt, std::nullopt, - prev_fwd_grad_mode); + prev_fwd_grad_mode, + std::nullopt, + prev_inference_mode); } static int64_t _jvp_decrement_nesting() { auto layer = popDynamicLayerAndDeleteMetadata(); TORCH_INTERNAL_ASSERT(layer.key() == TransformType::Jvp); + auto& meta = std::get(layer.interpreter().meta()); + if (meta.prevInferenceMode_) { + auto state = c10::AutogradState::get_tls_state(); + state.set_inference_mode(true); + c10::AutogradState::set_tls_state(state); + } return layer.layerId(); } @@ -416,9 +454,15 @@ static void dump_local_tls() { namespace { // Pop the DynamicLayer stack until it's at the given depth. +// Used by Dynamo for error-recovery cleanup of the transform stack. +// +// NB: we peek at .back() to determine the type, then call the +// type-specific decrement helper which does the actual pop. +// Do NOT pop before the switch — the helpers call +// popDynamicLayerAndDeleteMetadata() internally. void popDynamicLayerStackToDepth(size_t depth) { while (at::functorch::getDynamicLayerStack().size() > depth) { - const auto top = popDynamicLayer(); + const auto& top = at::functorch::getDynamicLayerStack().back(); switch (top.key()) { case at::functorch::TransformType::Vmap: _vmap_decrement_nesting(); diff --git a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp index 6e8d40f31e7df..4ad6f83d1a07b 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_holder.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_holder.cpp @@ -255,7 +255,7 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { py::str(c10::DeviceTypeName(device_.type(), true)).ptr(), nullptr)); TORCH_INTERNAL_ASSERT( - result.ptr() != nullptr && result.ptr() != Py_None, + result.ptr() != nullptr && !Py_IsNone(result.ptr()), "Failed to load AOTI kernel. Operator Name is ", op_name_with_overload_); @@ -327,14 +327,15 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { py::isinstance(metadata["tensor_list"])); auto tensor_list = metadata["tensor_list"].cast(); std::vector test_list_metadata; - for (auto item_tensor : tensor_list) { + test_list_metadata.reserve(tensor_list.size()); + for (const auto& item_tensor : tensor_list) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( py::isinstance(item_tensor)); auto metadata = item_tensor.cast(); - auto tensor_metadata = build_tensor_metadata(metadata); - test_list_metadata.push_back(tensor_metadata); + test_list_metadata.push_back(build_tensor_metadata(metadata)); } - parameter_metadata_list.emplace_back(test_list_metadata, arg_idx); + parameter_metadata_list.emplace_back( + std::move(test_list_metadata), arg_idx); } else if (is_scalar) { // Scalar auto metadata = item_metadata.cast(); @@ -362,7 +363,7 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { // String auto metadata = item_metadata.cast(); auto str_value = metadata["string_value"].cast(); - parameter_metadata_list.emplace_back(str_value, arg_idx); + parameter_metadata_list.emplace_back(std::move(str_value), arg_idx); } else if (is_dtype) { // Dtype auto metadata = item_metadata.cast(); @@ -378,7 +379,7 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { auto device_type_value = metadata["device_type_value"].cast(); auto device = c10::Device(device_type_value); - if (metadata["device_index_value"].ptr() != Py_None) { + if (!Py_IsNone(metadata["device_index_value"].ptr())) { auto device_index_value = metadata["device_index_value"].cast(); device.set_index(device_index_value); @@ -404,7 +405,7 @@ void AOTIPythonKernelHolder::init_aoti_kernel_cache() { aoti_kernel_metadata.parameter_metadata_list_ = std::move(parameter_metadata_list); aoti_kernel_metadata.kernel_runner_ = load_aoti_model_runner(kernel_path); - aoti_kernel_cache_.push_back(aoti_kernel_metadata); + aoti_kernel_cache_.push_back(std::move(aoti_kernel_metadata)); } } @@ -501,7 +502,7 @@ std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib( qualified_name.end()); py::gil_scoped_acquire gil; - py::handle op_py_func = op.getPythonOp(pyinterpreter_, [&]() -> PyObject* { + py::handle op_py_func = op.getPythonOp([&]() -> PyObject* { py::handle torch_api_function = py::module::import("torch") .attr("ops") .attr(ns_str.c_str()) @@ -514,7 +515,7 @@ std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib( }); TORCH_INTERNAL_ASSERT( - op_py_func.ptr() != nullptr && op_py_func.ptr() != Py_None, + op_py_func.ptr() != nullptr && !Py_IsNone(op_py_func.ptr()), "Failed to get python operation. Operator Name is ", op.operator_name().name, ", Overload Name is ", @@ -525,7 +526,7 @@ std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib( .attr("aoti_compile_with_persistent_cache"); TORCH_INTERNAL_ASSERT( aot_compile_function.ptr() != nullptr && - aot_compile_function.ptr() != Py_None, + !Py_IsNone(aot_compile_function.ptr()), "Failed to import - torch._inductor.aoti_eager.aoti_compile_with_persistent_cache"); // Pass the python operation to the AOT Inductor to generate the kernel @@ -541,7 +542,7 @@ std::string AOTIPythonKernelHolder::produce_aoti_kernel_lib( args_kwargs.first.ptr(), args_kwargs.second.ptr(), nullptr)); - TORCH_INTERNAL_ASSERT(result.ptr() != nullptr && result.ptr() != Py_None); + TORCH_INTERNAL_ASSERT(result.ptr() != nullptr && !Py_IsNone(result.ptr())); auto kernel_lib_path = py::cast(result); TORCH_CHECK( diff --git a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp index 911665f89f676..c16aac57cef5b 100644 --- a/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp +++ b/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp @@ -151,7 +151,7 @@ ParameterMetadata::ParameterMetadata( for (const auto& tensor : tensor_list) { tensor_metadata_list.emplace_back(tensor); } - value_ = tensor_metadata_list; + value_ = std::move(tensor_metadata_list); } ParameterMetadata::ParameterMetadata( diff --git a/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h b/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h index 6ca660093c67d..9fadafe66d2e5 100644 --- a/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h +++ b/torch/csrc/inductor/aoti_runtime/arrayref_tensor.h @@ -39,6 +39,8 @@ inline bool is_contiguous_strides_for_shape( template class ArrayRefTensor { public: + using value_type = T; + ArrayRefTensor() = default; explicit ArrayRefTensor( @@ -147,6 +149,16 @@ static_assert( (alignof(ArrayRefTensor) > 4 ? sizeof(int32_t) : 0), "changing the size of ArrayRefTensor breaks ABI compatibility!"); +// Type trait to detect ArrayRefTensor at compile time. +// Used by codegen_subgraph_prefix to conditionally borrow arrayref inputs. +template +struct is_arrayref_tensor_type : std::false_type {}; +template +struct is_arrayref_tensor_type> : std::true_type {}; +template +inline constexpr bool is_arrayref_tensor_type_v = + is_arrayref_tensor_type::value; + template inline ArrayRefTensor reinterpret_tensor_wrapper( const ArrayRefTensor& self, diff --git a/torch/csrc/inductor/aoti_runtime/arrayref_tensor_conversion.h b/torch/csrc/inductor/aoti_runtime/arrayref_tensor_conversion.h new file mode 100644 index 0000000000000..9a45c4171575b --- /dev/null +++ b/torch/csrc/inductor/aoti_runtime/arrayref_tensor_conversion.h @@ -0,0 +1,85 @@ +#pragma once + +// Zero-copy conversion utilities between ArrayRefTensor (C++ template) and +// AOTInductorArrayRefTensor (plain C struct). +// +// These helpers allow the host process to marshal ArrayRefTensor objects into +// the C-compatible AOTInductorArrayRefTensor descriptors before calling into a +// DSO, and to unmarshal the descriptors back after the call. Because only +// C types cross the DSO boundary, the host and DSO can be linked against +// different C++ standard libraries (e.g. libc++ vs libstdc++) without ABI +// conflicts. +// +// IMPORTANT: Both sides share the same underlying data buffers -- no copies +// are made. The caller must ensure the data remains valid for the lifetime +// of the descriptor. + +#include +#include + +#include +#include +#include +#include + +namespace torch::aot_inductor { + +inline void validate_arrayref_tensor_ndim(int32_t ndim) { + if (ndim < 0 || ndim > AOTI_ARRAYREF_TENSOR_MAX_DIMS) { + throw std::runtime_error( + "AOTInductorArrayRefTensor ndim exceeds AOTI_ARRAYREF_TENSOR_MAX_DIMS"); + } +} + +// ------------------------------------------------------------------------- +// ArrayRefTensor --> AOTInductorArrayRefTensor (zero-copy) +// ------------------------------------------------------------------------- +template +inline void arrayref_tensor_to_c( + const ArrayRefTensor& src, + AOTInductorArrayRefTensor& dst) { + const auto sizes = src.sizes(); + const auto strides = src.strides(); + dst.data = const_cast(static_cast(src.data())); + dst.numel = static_cast(src.numel()); + dst.ndim = static_cast(sizes.size()); + dst.dtype = aoti_torch_dtype>(); + dst.device_type = src.device_type(); + dst.device_idx = src.device_idx(); + + validate_arrayref_tensor_ndim(dst.ndim); + assert(dst.ndim <= AOTI_ARRAYREF_TENSOR_MAX_DIMS); + std::memcpy(dst.sizes, sizes.data(), dst.ndim * sizeof(int64_t)); + std::memcpy(dst.strides, strides.data(), dst.ndim * sizeof(int64_t)); + const int32_t remaining = AOTI_ARRAYREF_TENSOR_MAX_DIMS - dst.ndim; + std::memset(dst.sizes + dst.ndim, 0, remaining * sizeof(int64_t)); + std::memset(dst.strides + dst.ndim, 0, remaining * sizeof(int64_t)); + std::memset(dst.reserved, 0, sizeof(dst.reserved)); +} + +template +inline AOTInductorArrayRefTensor arrayref_tensor_to_c( + const ArrayRefTensor& src) { + AOTInductorArrayRefTensor dst; + arrayref_tensor_to_c(src, dst); + return dst; +} + +// ------------------------------------------------------------------------- +// AOTInductorArrayRefTensor --> ArrayRefTensor (zero-copy) +// ------------------------------------------------------------------------- +template +inline ArrayRefTensor c_to_arrayref_tensor( + const AOTInductorArrayRefTensor& src) { + validate_arrayref_tensor_ndim(src.ndim); + return ArrayRefTensor( + MiniArrayRef( + static_cast(const_cast(src.data)), + static_cast(src.numel)), + MiniArrayRef(src.sizes, static_cast(src.ndim)), + MiniArrayRef(src.strides, static_cast(src.ndim)), + src.device_type, + src.device_idx); +} + +} // namespace torch::aot_inductor diff --git a/torch/csrc/inductor/aoti_runtime/interface.h b/torch/csrc/inductor/aoti_runtime/interface.h index ffccdd94e5be2..a47bb3e4c67ff 100644 --- a/torch/csrc/inductor/aoti_runtime/interface.h +++ b/torch/csrc/inductor/aoti_runtime/interface.h @@ -35,6 +35,52 @@ struct AOTInductorConstantMapEntry { AtenTensorHandle handle; }; +// --------------------------------------------------------------------------- +// C-compatible tensor descriptor for crossing the DSO boundary. +// +// This struct carries the same information as ArrayRefTensor but uses only +// C-compatible types so the host process and DSO can be built with different +// C++ standard libraries (e.g. libc++ vs libstdc++). All pointer fields +// reference memory owned by the caller; no copies are made. +// +// Maximum supported number of dimensions. 8 covers all practical AOTI +// models; tensors with more dims should fall back to the AtenTensorHandle +// interface. +// --------------------------------------------------------------------------- +#define AOTI_ARRAYREF_TENSOR_MAX_DIMS 8 + +struct AOTInductorArrayRefTensor { + // Pointer to the raw data buffer. Not owned. + void* data; + + // Number of elements in the data buffer (product of sizes for contiguous + // tensors). + int64_t numel; + + // Static-size arrays for shape metadata. Only the first `ndim` entries + // are meaningful. + int64_t sizes[AOTI_ARRAYREF_TENSOR_MAX_DIMS]; + int64_t strides[AOTI_ARRAYREF_TENSOR_MAX_DIMS]; + + // Number of dimensions (0 <= ndim <= AOTI_ARRAYREF_TENSOR_MAX_DIMS). + int32_t ndim; + + // Torch dtype encoded as int32_t (same encoding as aoti_torch_dtype_*()). + int32_t dtype; + + // Device information. + int32_t device_type; + int32_t device_idx; + + // Reserved for future extension. Zero-initialize and do not read — a + // newer reader must tolerate zeros, and an older reader must ignore them. + int64_t reserved[4]; +}; + +static_assert( + sizeof(AOTInductorArrayRefTensor) == 192, + "changing the size of AOTInductorArrayRefTensor breaks ABI compatibility!"); + // TODO: Deprecate this API. This was kept for BC compatibility. // Please use AOTInductorModelContainerCreateWithDevice instead. AOTI_API AOTIRuntimeError AOTInductorModelContainerCreate( @@ -265,4 +311,25 @@ AOTI_API AOTIRuntimeError AOTInductorModelContainerGetCallSpec( const char** in_spec, const char** out_spec); +// --------------------------------------------------------------------------- +// C-ABI-safe variant of AOTInductorModelRunMinimalArrayrefInterface. +// +// Instead of passing std::tuple...>& (which encodes C++ +// standard library types into the ABI), this function accepts flat C arrays +// of AOTInductorArrayRefTensor descriptors. The descriptors reference the +// same underlying data buffers -- no copies are made. +// +// The host process marshals its ArrayRefTensor objects into +// AOTInductorArrayRefTensor descriptors, calls into the DSO through this +// pure-C interface, and then unmarshals the output descriptors back. +// Because only C types cross the DSO boundary, the host and DSO can be +// built with different C++ standard libraries (e.g. libc++ vs libstdc++). +// --------------------------------------------------------------------------- +AOTI_API AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterfaceV2( + AOTInductorModelHandle model_handle, + int32_t num_inputs, + const AOTInductorArrayRefTensor* inputs, + int32_t num_outputs, + AOTInductorArrayRefTensor* outputs); + } // extern "C" diff --git a/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h b/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h index fc6e61abb185d..2099e129f24b9 100644 --- a/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h +++ b/torch/csrc/inductor/aoti_runtime/kernel_context_tls.h @@ -97,7 +97,7 @@ struct KernelContext { } res_stack += std::string{function} + '[' + std::string{p} + ']'; res_stack += '\n'; - res_stack += fs::path{filename}.filename(); + res_stack += fs::path{filename}.filename().string(); res_stack += '\n'; res_stack += std::to_string(fileline); res_stack += '\n'; diff --git a/torch/csrc/inductor/aoti_runtime/model.h b/torch/csrc/inductor/aoti_runtime/model.h index 253c5e917e76b..41edb7c1e8e42 100644 --- a/torch/csrc/inductor/aoti_runtime/model.h +++ b/torch/csrc/inductor/aoti_runtime/model.h @@ -6,6 +6,8 @@ // applies to other files under torch/csrc/inductor/aoti_runtime/. #include +struct AOTInductorArrayRefTensor; + namespace torch::aot_inductor { class AOTInductorModel : public AOTInductorModelBase { @@ -43,6 +45,12 @@ class AOTInductorModel : public AOTInductorModelBase { DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor); + void run_impl_minimal_arrayref_interface_v2_raw( + const AOTInductorArrayRefTensor* c_inputs, + AOTInductorArrayRefTensor* c_outputs, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor); + static std::unique_ptr Create( std::shared_ptr constants_map, std::shared_ptr> constants_array, diff --git a/torch/csrc/inductor/aoti_runtime/model_container.h b/torch/csrc/inductor/aoti_runtime/model_container.h index b594a812237bf..fda77cf43c572 100644 --- a/torch/csrc/inductor/aoti_runtime/model_container.h +++ b/torch/csrc/inductor/aoti_runtime/model_container.h @@ -481,13 +481,19 @@ class AOTInductorModelContainer { auto constants_map_to_update = get_constants_map(use_inactive); auto num_constants = models_[0]->num_constants(); + size_t non_folded_idx = 0; for (size_t idx = 0; idx < num_constants; idx++) { + bool from_folded = models_[0]->constant_from_folded(idx); + if (from_folded) { + continue; + } auto constant_name = std::string(models_[0]->constant_name(static_cast(idx))); auto it = constants_map.find(constant_name); if (it == constants_map.end() && !(use_inactive && _is_tensor_constant_or_buffer_type_or_empty_parameter(idx))) { + non_folded_idx++; continue; } @@ -504,6 +510,7 @@ class AOTInductorModelContainer { constants_map_to_update->insert_or_assign( constant_name, MaybeOwningAtenTensorHandle(tensor, /* user_managed = */ true)); + non_folded_idx++; continue; } @@ -512,7 +519,7 @@ class AOTInductorModelContainer { // Move the data to container handled blob. uint8_t* internal_constants_ptr = - constants_blob_ptr + constants_internal_offset_[idx]; + constants_blob_ptr + constants_internal_offset_[non_folded_idx]; void* user_constant_ptr; int64_t constant_size; int64_t* stride; @@ -537,11 +544,12 @@ class AOTInductorModelContainer { constants_blob_ptr, constant_size, offset, - constants_internal_offset_[idx]); + constants_internal_offset_[non_folded_idx]); // For mps tensors, all constants are stored in one buffer, with the // offset being where the constant starts. So we want to change the - // constant tensor's offset to point to constants_internal_offset_[idx] - offset = constants_internal_offset_[idx] / + // constant tensor's offset to point to + // constants_internal_offset_[non_folded_idx] + offset = constants_internal_offset_[non_folded_idx] / aoti_torch_dtype_element_size(dtype); #elif USE_CUDA AOTI_RUNTIME_CUDA_CHECK(cudaMemcpy( @@ -573,6 +581,7 @@ class AOTInductorModelContainer { // ownership of the tensor_handle will be taken over. constants_map_to_update->insert_or_assign( constant_name, RAIIAtenTensorHandle(tensor_handle)); + non_folded_idx++; } // Update the inactive constant array. update_array_from_map( diff --git a/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h b/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h index 56bb478396004..b85b69449edc2 100644 --- a/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h +++ b/torch/csrc/inductor/aoti_runtime/sycl_runtime_wrappers.h @@ -134,8 +134,13 @@ static std::unique_ptr _createKernel( uint32_t numWarps, uint32_t sharedMemory, void** params, - sycl::queue* queuePtr, - uint32_t threadsPerWarp) { + sycl::queue* queuePtr) { + uint32_t threadsPerWarp = kernelPtr->get_info< + sycl::info::kernel_device_specific::compile_sub_group_size>( + queuePtr->get_device()); + if (threadsPerWarp == 0) { + threadsPerWarp = 32; // default to 32 if not set + } std::string kernelName = kernelPtr->get_info(); uint32_t numParams = kernelPtr->get_info(); diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index a53e5e609d9f7..802791b26a266 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -3,6 +3,7 @@ #include #include +#include // This header defines a stable C API for certain ATen functionality in // libtorch. The AOTInductor compiled model.so will only refer to this header @@ -39,6 +40,10 @@ // The following files are implemented in a header-only way and are guarded by // test/cpp/aoti_abi_check #include +#include +#include +#include +#include #include #include @@ -633,24 +638,24 @@ AOTI_TORCH_EXPORT void aoti_torch_check( const char* msg); #ifdef STRIP_ERROR_MESSAGES -#define AOTI_TORCH_CHECK(cond, ...) \ - if (!(cond)) { \ - aoti_torch_check( \ - false, \ - __func__, \ - __FILE__, \ - static_cast(__LINE__), \ - TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \ +#define AOTI_TORCH_CHECK(cond, ...) \ + if (!(cond)) { \ + aoti_torch_check( \ + false, \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + STD_TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \ } #else -#define AOTI_TORCH_CHECK(cond, ...) \ - if (!(cond)) { \ - aoti_torch_check( \ - false, \ - __func__, \ - __FILE__, \ - static_cast(__LINE__), \ - TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \ +#define AOTI_TORCH_CHECK(cond, ...) \ + if (!(cond)) { \ + aoti_torch_check( \ + false, \ + __func__, \ + __FILE__, \ + static_cast(__LINE__), \ + STD_TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \ } #endif @@ -682,6 +687,10 @@ int32_t aoti_torch_dtype() = delete; DEFINE_DTYPE_SPECIALIZATION(c10::BFloat16, bfloat16) DEFINE_DTYPE_SPECIALIZATION(c10::Half, float16) +DEFINE_DTYPE_SPECIALIZATION(c10::Float8_e5m2, float8_e5m2) +DEFINE_DTYPE_SPECIALIZATION(c10::Float8_e4m3fn, float8_e4m3fn) +DEFINE_DTYPE_SPECIALIZATION(c10::Float8_e5m2fnuz, float8_e5m2fnuz) +DEFINE_DTYPE_SPECIALIZATION(c10::Float8_e4m3fnuz, float8_e4m3fnuz) DEFINE_DTYPE_SPECIALIZATION(c10::complex, complex64) DEFINE_DTYPE_SPECIALIZATION(float, float32) DEFINE_DTYPE_SPECIALIZATION(double, float64) diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 77072a4e68920..21ee7438165dc 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -121,12 +121,18 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Scalar(AtenTensorHand AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand_like(AtenTensorHandle self, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_rand_like_generator(AtenTensorHandle self, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_like(AtenTensorHandle self, int64_t high, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randint_like_low_dtype(AtenTensorHandle self, int64_t low, int64_t high, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn_like(AtenTensorHandle self, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randn_like_generator(AtenTensorHandle self, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h index f047184bb8ca3..413a393f6fa7d 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h @@ -135,12 +135,18 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Scalar(AtenTensorHan AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand_like(AtenTensorHandle self, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_rand_like_generator(AtenTensorHandle self, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_like(AtenTensorHandle self, int64_t high, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randint_like_low_dtype(AtenTensorHandle self, int64_t low, int64_t high, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn_like(AtenTensorHandle self, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randn_like_generator(AtenTensorHandle self, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h index 43b2570d1c69b..e71cf771a441d 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_mps.h @@ -25,7 +25,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_moving_avg_obs_fq_helper_ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__fused_rms_norm(AtenTensorHandle input, const int64_t* normalized_shape, int64_t normalized_shape_len_, AtenTensorHandle* weight, double* eps, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__grouped_mm(AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle* offs, AtenTensorHandle* bias, int32_t* out_dtype, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__histogramdd_from_bin_cts(AtenTensorHandle self, const int64_t* bins, int64_t bins_len_, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_attention_math_for_mps(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_mask, double dropout_p, int32_t is_causal, AtenTensorHandle* dropout_mask, double* scale, int32_t enable_gqa, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* attn_bias, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, int64_t* ret4, int64_t* ret5, AtenTensorHandle* ret6, AtenTensorHandle* ret7, AtenTensorHandle* ret8); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__scaled_dot_product_fused_attention_overrideable_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle attn_bias, const int32_t* grad_input_mask, int64_t grad_input_mask_len_, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps__to_sparse(AtenTensorHandle self, int32_t* layout, const int64_t** blocksize, int64_t blocksize_len_, int64_t* dense_dim, AtenTensorHandle* ret0); @@ -62,6 +62,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_cumsum(AtenTensorHandle self, in AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_exponential(AtenTensorHandle self, double lambd, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_fill__Scalar(AtenTensorHandle self, double value); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_gcd(AtenTensorHandle self, AtenTensorHandle other, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_grid_sampler_2d_backward(AtenTensorHandle grad_output, AtenTensorHandle input, AtenTensorHandle grid, int64_t interpolation_mode, int64_t padding_mode, int32_t align_corners, const int32_t* output_mask, int64_t output_mask_len_, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_hann_window(int64_t window_length, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histc(AtenTensorHandle self, int64_t bins, double min, double max, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_histogram_bin_ct(AtenTensorHandle self, int64_t bins, const double** range, int64_t range_len_, AtenTensorHandle* weight, int32_t density, AtenTensorHandle* ret0, AtenTensorHandle* ret1); @@ -88,6 +89,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nanmedian(AtenTensorHandle self, AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_native_dropout(AtenTensorHandle input, double p, int32_t* train, AtenTensorHandle* ret0, AtenTensorHandle* ret1); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nonzero(AtenTensorHandle self, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_nonzero_static(AtenTensorHandle self, int64_t size, int64_t fill_value, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_normal_functional(AtenTensorHandle self, double mean, double std, AtenGeneratorHandle* generator, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_permute(AtenTensorHandle self, const int64_t* dims, int64_t dims_len_, AtenTensorHandle* ret0); @@ -97,12 +99,18 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pow_Tensor_Scalar(AtenTensorHand AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_pow_Tensor_Tensor(AtenTensorHandle self, AtenTensorHandle exponent, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_rand_like(AtenTensorHandle self, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_rand_like_generator(AtenTensorHandle self, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint_like(AtenTensorHandle self, int64_t high, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randint_like_low_dtype(AtenTensorHandle self, int64_t low, int64_t high, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randn_like(AtenTensorHandle self, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randn_like_generator(AtenTensorHandle self, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mps_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h index 307ada8eb8203..ea0cbfdc5659d 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_xpu.h @@ -54,12 +54,18 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_permute(AtenTensorHandle self, c AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_polar(AtenTensorHandle abs, AtenTensorHandle angle, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand_like(AtenTensorHandle self, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_rand_like_generator(AtenTensorHandle self, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint(int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_generator(int64_t high, const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_low(int64_t low, int64_t high, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_low_out(AtenTensorHandle out, int64_t low, int64_t high, const int64_t* size, int64_t size_len_); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_like(AtenTensorHandle self, int64_t high, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randint_like_low_dtype(AtenTensorHandle self, int64_t low, int64_t high, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randn(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randn_generator(const int64_t* size, int64_t size_len_, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randn_like(AtenTensorHandle self, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randn_like_generator(AtenTensorHandle self, AtenGeneratorHandle* generator, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, int32_t* memory_format, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_randperm(int64_t n, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_reshape(AtenTensorHandle self, const int64_t* shape, int64_t shape_len_, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_xpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format); diff --git a/torch/csrc/inductor/aoti_torch/generated_enum_converters.h b/torch/csrc/inductor/aoti_torch/generated_enum_converters.h index 820023c1031e6..14b97a7f4bc26 100644 --- a/torch/csrc/inductor/aoti_torch/generated_enum_converters.h +++ b/torch/csrc/inductor/aoti_torch/generated_enum_converters.h @@ -50,6 +50,8 @@ inline c10::ScalarType convertSerializedScalarType(int serialized_value) { static_cast(c10::ScalarType::Float8_e4m3fnuz), // 31 static_cast(c10::ScalarType::Float8_e5m2fnuz), // 32 static_cast(c10::ScalarType::Float8_e8m0fnu), // 33 + static_cast(c10::ScalarType::UInt32), // 34 + static_cast(c10::ScalarType::UInt64), // 35 }; constexpr int kMapSize = sizeof(kScalarTypeMap) / sizeof(kScalarTypeMap[0]); diff --git a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp index 8a2a6056ba58a..a4ce975d68438 100644 --- a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp +++ b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp @@ -273,11 +273,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments( index, " but got ", serialized_arg_type); - std::vector ret; - for (const auto& arg : serialized_arg_val) { - ret.push_back(arg.get()); - } - stack.at(index) = std::move(ret); + stack.at(index) = serialized_arg_val.get>(); } else if (schema_arg_type->isSubtypeOf(at::ListType::ofBools())) { TORCH_CHECK( serialized_arg_type == "as_bools", @@ -287,27 +283,15 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments( index, " but got ", serialized_arg_type); - std::vector ret; - for (const auto& arg : serialized_arg_val) { - ret.push_back(arg.get()); - } - stack.at(index) = std::move(ret); + stack.at(index) = serialized_arg_val.get>(); } else if (schema_arg_type->isSubtypeOf(at::ListType::ofNumbers())) { if (serialized_arg_type == "as_ints") { dynamic_args.emplace_back( index, DynamicArgType::ListIntType, serialized_arg_val.size()); } else if (serialized_arg_type == "as_floats") { - std::vector ret; - for (const auto& arg : serialized_arg_val) { - ret.push_back(arg); - } - stack.at(index) = std::move(ret); + stack.at(index) = serialized_arg_val.get>(); } else if (serialized_arg_type == "as_bools") { - std::vector ret; - for (const auto& arg : serialized_arg_val) { - ret.push_back(arg); - } - stack.at(index) = std::move(ret); + stack.at(index) = serialized_arg_val.get>(); } else { TORCH_CHECK( false, @@ -322,6 +306,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments( at::ListType::ofOptionalTensors())) { if (serialized_arg_type == "as_optional_tensors") { std::vector list_item_types; + list_item_types.reserve(serialized_arg_val.size()); for (const auto& arg : serialized_arg_val) { list_item_types.push_back(arg.begin().key()); } @@ -352,11 +337,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments( index, " but got ", serialized_arg_type); - std::vector ret; - for (const auto& arg : serialized_arg_val) { - ret.push_back(arg.get()); - } - stack.at(index) = std::move(ret); + stack.at(index) = serialized_arg_val.get>(); } else { TORCH_CHECK( false, @@ -400,6 +381,32 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments( } break; } + case c10::TypeKind::AnyType: { + // For Any type, dispatch based on the serialized type + if (serialized_arg_type == "as_string") { + stack.at(index) = serialized_arg_val.get(); + } else if (serialized_arg_type == "as_int") { + dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); + } else if (serialized_arg_type == "as_float") { + stack.at(index) = serialized_arg_val.get(); + } else if (serialized_arg_type == "as_bool") { + stack.at(index) = serialized_arg_val.get(); + } else if (serialized_arg_type == "as_tensor") { + dynamic_args.emplace_back(index, DynamicArgType::TensorType, 1); + } else if (serialized_arg_type == "as_none") { + stack.at(index) = c10::IValue{}; + } else { + TORCH_CHECK( + false, + "Unsupported serialized type ", + serialized_arg_type, + " for Any type argument ", + index, + " in extern kernel ", + op_kernel->target_); + } + break; + } // TODO: handle the other input types default: TORCH_CHECK( @@ -756,21 +763,23 @@ void OSSProxyExecutor::call_function( } case DynamicArgType::ListTensorType: { std::vector tensor_list; + tensor_list.reserve(length); for (int j = 0; j < length; j++) { at::Tensor* tensor = tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); tensor_list.push_back(*tensor); } - stack[arg_index] = tensor_list; + stack[arg_index] = std::move(tensor_list); break; } case DynamicArgType::ListOptionalTensorType: { - std::vector> optional_tensor_list; auto& list_item_types = dynamic_arg.list_item_types; TORCH_CHECK( list_item_types.has_value(), "Could not find list of item types for optional tensor list input"); + std::vector> optional_tensor_list; + optional_tensor_list.reserve(list_item_types.value().size()); for (const std::string& item_type : list_item_types.value()) { if (item_type == "as_tensor") { at::Tensor* tensor = tensor_handle_to_tensor_pointer( @@ -780,7 +789,7 @@ void OSSProxyExecutor::call_function( optional_tensor_list.emplace_back(std::nullopt); } } - stack[arg_index] = optional_tensor_list; + stack[arg_index] = std::move(optional_tensor_list); break; } case DynamicArgType::ListIntType: { @@ -789,7 +798,7 @@ void OSSProxyExecutor::call_function( for (int j = 0; j < length; j++) { vals.push_back(flatten_int_args[int_id++]); } - stack[arg_index] = vals; + stack[arg_index] = std::move(vals); break; } default: diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index d6c22b096effe..3e3d82f349087 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -252,10 +252,11 @@ AOTITorchError aoti_torch_strlist_to_ivalue( C10IValueHandle* ivalue) { AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ c10::List vec; + vec.reserve(len); for (int64_t i = 0; i < len; i++) { - vec.push_back(std::string(val[i])); + vec.emplace_back(val[i]); } - c10::IValue* t = new c10::IValue(vec); + c10::IValue* t = new c10::IValue(std::move(vec)); *ivalue = reinterpret_cast(t); }); } diff --git a/torch/csrc/inductor/array_ref_impl.h b/torch/csrc/inductor/array_ref_impl.h index 8cfbc12fb2c3d..e36a3653009ad 100644 --- a/torch/csrc/inductor/array_ref_impl.h +++ b/torch/csrc/inductor/array_ref_impl.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include diff --git a/torch/csrc/inductor/cpp_prefix.h b/torch/csrc/inductor/cpp_prefix.h index a51bd74496fe8..18765cea13f66 100644 --- a/torch/csrc/inductor/cpp_prefix.h +++ b/torch/csrc/inductor/cpp_prefix.h @@ -9,6 +9,7 @@ #include #include #include +#include // WARNING: be extra careful when including more ATen/c10 header files here! // Because AOTInductor generated code will copy-paste this cpp_prefix.h for @@ -485,7 +486,7 @@ struct IndexValueVec { index = at::vec::VectorizedN(0); }; - IndexValueVec() {}; + IndexValueVec() = default; }; template < @@ -776,9 +777,39 @@ Welford welford_vec_reduce_all( } #endif +inline std::atomic* inductor_cpu_integer_div_error_flag = nullptr; + +inline void inductor_cpu_note_integer_div_by_zero() { + if (inductor_cpu_integer_div_error_flag != nullptr) { + inductor_cpu_integer_div_error_flag->store(1, std::memory_order_relaxed); + } else { + TORCH_CHECK(false, "ZeroDivisionError"); + } +} + +inline void inductor_cpu_throw_if_integer_div_error(std::atomic& err) { + if (err.load(std::memory_order_acquire)) { + TORCH_CHECK(false, "ZeroDivisionError"); + } +} + template -inline typename std::common_type_t mod(T a, U b) { - return a % b; +inline std::common_type_t mod(T a, U b) { + using C = std::common_type_t; + static_assert( + std::is_integral_v, + "inductor template mod(T a, U b) is only for integral types; use the float/double specializations " + "for floating-point operands."); + if (C10_UNLIKELY_OR_CONST(b == 0)) { + inductor_cpu_note_integer_div_by_zero(); + return C(0); + } + const C a_c = static_cast(a); + const C b_c = static_cast(b); + if (a_c == std::numeric_limits::min() && b_c == C(-1)) { + return C(0); + } + return a_c % b_c; } template <> inline float mod(float a, float b) { @@ -789,6 +820,61 @@ inline double mod(double a, double b) { return std::fmod(a, b); } +template +inline T remainder_integral(T a, T b) { + static_assert( + std::is_integral_v, "remainder_integral expects integral scalar T"); + if (C10_UNLIKELY_OR_CONST(b == 0)) { + inductor_cpu_note_integer_div_by_zero(); + return T(0); + } + if (a == std::numeric_limits::min() && b == T(-1)) { + return T(0); + } + T r = a % b; + if ((r != 0) && (c10::is_negative(r) != c10::is_negative(b))) { + r += b; + } + return r; +} + +#if INDUCTOR_USE_VECTOR_TYPES() +template +inline at::vec::Vectorized remainder_integral( + const at::vec::Vectorized& a, + const at::vec::Vectorized& b) { + static_assert( + std::is_integral_v, + "remainder_integral expects integral underlying type"); + // Some Vectorized (e.g. Vectorized8) deletes operator[]; + // use store/load like + using Vec = at::vec::Vectorized; + constexpr int kLen = Vec::size(); + alignas(alignof(Vec)) T out_buf[kLen]; + alignas(alignof(Vec)) T b_buf[kLen]; + a.store(out_buf); + b.store(b_buf); + for (int i = 0; i < kLen; ++i) { + out_buf[i] = remainder_integral(out_buf[i], b_buf[i]); + } + return Vec::loadu(out_buf); +} + +template +inline at::vec::VectorizedN remainder_integral( + const at::vec::VectorizedN& a, + const at::vec::VectorizedN& b) { + static_assert( + std::is_integral_v, + "remainder_integral expects integral underlying type"); + at::vec::VectorizedN out; + for (int i = 0; i < N; ++i) { + out[i] = remainder_integral(a[i], b[i]); + } + return out; +} +#endif + template inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) { if (at::_isnan(a)) { diff --git a/torch/csrc/inductor/cpp_wrapper/common.h b/torch/csrc/inductor/cpp_wrapper/common.h index 55f76761946bb..39a6fbbd2cf24 100644 --- a/torch/csrc/inductor/cpp_wrapper/common.h +++ b/torch/csrc/inductor/cpp_wrapper/common.h @@ -10,9 +10,6 @@ #include #else // pybind11 < 3.0: gil_simple.h does not exist yet. -// Use simple GIL management to avoid pulling in -// pybind11::detail::get_internals() which requires linking against pybind11 -// symbols unavailable in JIT builds. #define PYBIND11_SIMPLE_GIL_MANAGEMENT #include // Provide the _simple aliases so generated code works with either version. diff --git a/torch/csrc/inductor/cpp_wrapper/cuda.h b/torch/csrc/inductor/cpp_wrapper/cuda.h index 782a2b677276a..eb9888356c6ff 100644 --- a/torch/csrc/inductor/cpp_wrapper/cuda.h +++ b/torch/csrc/inductor/cpp_wrapper/cuda.h @@ -1,4 +1,10 @@ #pragma once - #include #include +#include + +#ifdef TORCH_INDUCTOR_PRECOMPILE_HEADERS +#include +#include +#include +#endif diff --git a/torch/csrc/inductor/cpp_wrapper/lazy_triton_compile.h b/torch/csrc/inductor/cpp_wrapper/lazy_triton_compile.h new file mode 100644 index 0000000000000..6fa59145adfd9 --- /dev/null +++ b/torch/csrc/inductor/cpp_wrapper/lazy_triton_compile.h @@ -0,0 +1,201 @@ +#pragma once + +#include + +#include +#include +#if defined(USE_XPU) +#include +#else +#include +#endif + +struct LazyKernelCompileResult { + std::string cubin_path; + std::string mangled_name; + int num_warps; + int shared_mem; + int xblock; + int yblock; + int zblock; + int r0block; + int rsplit; + int rsplit_size; + int config_index; + int global_scratch; + int profile_scratch; +}; + +static PyObject* (*_THPVariable_Wrap)(const at::TensorBase&) = nullptr; +static int32_t (*_THPUtils_unpackInt)(PyObject*) = nullptr; + +// Cached module and function references +static PyObject* triton_lazy_compile_module = nullptr; +static PyObject* start_kernel_compile = nullptr; +static PyObject* run_triton_kernel_with_autotune = nullptr; + +// Per-module dict for pending kernel compile results (avoids global state +// collisions when multiple compiled modules produce kernels with the same +// name). +static PyObject* _module_pending_kernels = nullptr; + +static inline void loadLazyCompileFuncs() { + if (triton_lazy_compile_module == nullptr) { + triton_lazy_compile_module = + PyImport_ImportModule("torch._inductor.runtime.triton_lazy_compile"); + AOTI_TORCH_CHECK( + triton_lazy_compile_module, "Failed to import triton_lazy_compile"); + + start_kernel_compile = PyObject_GetAttrString( + triton_lazy_compile_module, "start_kernel_compile"); + AOTI_TORCH_CHECK( + start_kernel_compile, "Failed to get start_kernel_compile"); + + run_triton_kernel_with_autotune = PyObject_GetAttrString( + triton_lazy_compile_module, "run_triton_kernel_with_autotune"); + AOTI_TORCH_CHECK( + run_triton_kernel_with_autotune, + "Failed to get run_triton_kernel_with_autotune"); + + RAIIPyObject guards_mod = PyImport_ImportModule("torch._C._dynamo.guards"); + AOTI_TORCH_CHECK(guards_mod, "Failed to import torch._C._dynamo.guards"); + + RAIIPyObject wrap_addr = + PyObject_GetAttrString(guards_mod, "_torchinductor_thp_variable_wrap"); + AOTI_TORCH_CHECK( + wrap_addr, "Failed to get _torchinductor_thp_variable_wrap"); + _THPVariable_Wrap = reinterpret_cast( + PyLong_AsVoidPtr(wrap_addr)); + AOTI_TORCH_CHECK(_THPVariable_Wrap, "THPVariable_Wrap not resolved"); + + RAIIPyObject unpack_addr = PyObject_GetAttrString( + guards_mod, "_torchinductor_thputils_unpack_int"); + AOTI_TORCH_CHECK( + unpack_addr, "Failed to get _torchinductor_thputils_unpack_int"); + _THPUtils_unpackInt = reinterpret_cast( + PyLong_AsVoidPtr(unpack_addr)); + AOTI_TORCH_CHECK(_THPUtils_unpackInt, "THPUtils_unpackInt not resolved"); + } +} + +static inline std::string getStringAttr(PyObject* obj, const char* attr) { + RAIIPyObject val = PyObject_GetAttrString(obj, attr); + AOTI_TORCH_CHECK(val, "Failed to get attribute"); + return PyUnicode_AsUTF8(val); +} + +static inline int getIntAttr(PyObject* obj, const char* attr) { + RAIIPyObject val = PyObject_GetAttrString(obj, attr); + AOTI_TORCH_CHECK(val, "Failed to get attribute"); + return _THPUtils_unpackInt(val); +} + +static inline int getOptionalIntAttr( + PyObject* obj, + const char* attr, + int sentinel = -1) { + RAIIPyObject val = PyObject_GetAttrString(obj, attr); + AOTI_TORCH_CHECK(val, "Failed to get attribute"); + return (!Py_IsNone(val.get())) ? _THPUtils_unpackInt(val) : sentinel; +} + +static inline LazyKernelCompileResult extractCompileResult(PyObject* result) { + LazyKernelCompileResult compile_result; + compile_result.cubin_path = getStringAttr(result, "cubin_path"); + compile_result.mangled_name = getStringAttr(result, "mangled_name"); + compile_result.num_warps = getIntAttr(result, "num_warps"); + compile_result.shared_mem = getIntAttr(result, "shared_mem"); + compile_result.xblock = getIntAttr(result, "xblock"); + compile_result.yblock = getIntAttr(result, "yblock"); + compile_result.zblock = getIntAttr(result, "zblock"); + compile_result.r0block = getIntAttr(result, "r0block"); + compile_result.rsplit = getIntAttr(result, "rsplit"); + compile_result.rsplit_size = getIntAttr(result, "rsplit_size"); + compile_result.config_index = getOptionalIntAttr(result, "config_index"); + compile_result.global_scratch = getOptionalIntAttr(result, "global_scratch"); + compile_result.profile_scratch = + getOptionalIntAttr(result, "profile_scratch"); + return compile_result; +} + +template +static inline PyObject* convertArgToPython(const T& arg) { + using DecayedT = std::decay_t; + if constexpr (std::is_same_v) { + at::Tensor* tensor_ptr = + torch::aot_inductor::tensor_handle_to_tensor_pointer(arg); + return _THPVariable_Wrap(*tensor_ptr); + } else if constexpr (std::is_same_v< + DecayedT, + torch::aot_inductor::RAIIAtenTensorHandle>) { + at::Tensor* tensor_ptr = + torch::aot_inductor::tensor_handle_to_tensor_pointer(arg.get()); + return _THPVariable_Wrap(*tensor_ptr); + } else if constexpr (std::is_same_v) { + PyObject* py_arg = arg ? Py_True : Py_False; + Py_INCREF(py_arg); + return py_arg; + } else if constexpr (std::is_integral_v) { + return PyLong_FromLongLong(static_cast(arg)); + } else if constexpr (std::is_floating_point_v) { + return PyFloat_FromDouble(static_cast(arg)); + } else { + AOTI_TORCH_CHECK(false, "Invalid input type to convertArgToPython"); + } +} + +template +static inline LazyKernelCompileResult runTritonKernelWithAutotune( + PyObject* pending_kernels, + const std::string& kernel_name, + void* stream, + const Args&... kernel_args) { + py::gil_scoped_acquire_simple acquire; + + constexpr size_t num_args = sizeof...(Args); + RAIIPyObject py_args_list = PyList_New(num_args); + AOTI_TORCH_CHECK(py_args_list, "Failed to create args list"); + + size_t idx = 0; + auto add_arg = [&py_args_list, &idx](PyObject* py_arg) { + AOTI_TORCH_CHECK(py_arg, "Failed to convert argument"); + PyList_SetItem(py_args_list, idx++, py_arg); + }; + // Use array pack-expansion instead of a fold expression to avoid + // hitting the compiler's expression-nesting limit when there are + // hundreds of kernel arguments (e.g. combo kernels). + int dummy[] = {0, (add_arg(convertArgToPython(kernel_args)), 0)...}; + (void)dummy; + + RAIIPyObject call_args = PyTuple_Pack( + 4, + pending_kernels, + PyUnicode_FromString(kernel_name.c_str()), + PyLong_FromVoidPtr(stream), + py_args_list.get()); + AOTI_TORCH_CHECK(call_args, "Failed to create call args"); + + RAIIPyObject result = + PyObject_CallObject(run_triton_kernel_with_autotune, call_args); + AOTI_TORCH_CHECK(result, "Failed to run kernel with autotuning"); + + return extractCompileResult(result); +} + +static inline void startKernelCompile( + PyObject* pending_kernels, + const std::string& kernel_name, + const std::string& kernel_source) { + py::gil_scoped_acquire_simple acquire; + + RAIIPyObject py_name = PyUnicode_FromString(kernel_name.c_str()); + RAIIPyObject py_source = PyUnicode_FromString(kernel_source.c_str()); + AOTI_TORCH_CHECK(py_name && py_source, "Failed to create Python args"); + + RAIIPyObject call_args = + PyTuple_Pack(3, pending_kernels, py_name.get(), py_source.get()); + AOTI_TORCH_CHECK(call_args, "Failed to create call args"); + + RAIIPyObject result = PyObject_CallObject(start_kernel_compile, call_args); + AOTI_TORCH_CHECK(result, "Failed to start kernel compilation"); +} diff --git a/torch/csrc/inductor/cpp_wrapper/xpu.h b/torch/csrc/inductor/cpp_wrapper/xpu.h index e26dea0f3b6e2..160788dfa1bab 100644 --- a/torch/csrc/inductor/cpp_wrapper/xpu.h +++ b/torch/csrc/inductor/cpp_wrapper/xpu.h @@ -2,3 +2,4 @@ #include #include +#include diff --git a/torch/csrc/inductor/inductor_ops.cpp b/torch/csrc/inductor/inductor_ops.cpp index 9723e27e6ba8a..96a41a78c00bc 100644 --- a/torch/csrc/inductor/inductor_ops.cpp +++ b/torch/csrc/inductor/inductor_ops.cpp @@ -108,4 +108,11 @@ TORCH_LIBRARY_FRAGMENT(inductor, m) { {at::Tag::pt2_compliant_tag}); } +TORCH_LIBRARY_FRAGMENT(inductor_prims, m) { + m.def( + "inductor_reserve_rng_state(Generator? generator, SymInt increment) " + "-> (Tensor, Tensor, Tensor)", + {at::Tag::pt2_compliant_tag}); +} + } // namespace torch::inductor diff --git a/torch/csrc/inductor/inductor_ops_gpu.cpp b/torch/csrc/inductor/inductor_ops_gpu.cpp new file mode 100644 index 0000000000000..2133dbcb4f796 --- /dev/null +++ b/torch/csrc/inductor/inductor_ops_gpu.cpp @@ -0,0 +1,89 @@ +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#include +#endif + +#if defined(USE_CUDA) || defined(USE_ROCM) +#include +#endif + +namespace torch::inductor { +using namespace at; + +#if defined(USE_CUDA) || defined(USE_ROCM) + +// Reserves RNG state for Inductor with CUDA Graph support. +// +// This function allows Inductor to reserve a specific amount of RNG offset +// (increment) for a kernel. It is designed to be safe for CUDA Graph capture +// by explicitly handling the internal generator state via public APIs. +// +// Behavior: +// - Graph Mode: Advances the generator state and returns pointers (wrapped as +// tensors) to the extragraph state. These tensors effectively point to the +// GPU memory that will be updated by `replay_prologue`. +// - Eager Mode: Advances the generator state and returns concrete values +// wrapped in 1D tensors to maintain shape consistency. +// +// -param gen The CUDA generator to use. +// -param increment The number of RNG values to reserve. +// -return A tuple of (Seed Tensor, Offset Tensor, Intragraph Offset CPU +// Tensor). +static std::tuple inductor_reserve_rng_state_impl( + const Generator& generator, + c10::SymInt increment) { + auto* gen_impl = at::check_generator(generator); + + const auto dev_opts = + at::TensorOptions().dtype(at::kLong).device(generator.device()); + const auto cpu_opts = at::TensorOptions().dtype(at::kLong).device(at::kCPU); + + int64_t inc = increment.expect_int(); + const at::PhiloxCudaState st = + gen_impl->philox_cuda_state(static_cast(inc)); + + if (st.captured_) { + auto seed_t = at::from_blob( + static_cast(st.seed_.ptr), {1}, [](void*) {}, dev_opts); + auto off_t = at::from_blob( + static_cast(st.offset_.ptr), {1}, [](void*) {}, dev_opts); + auto intra_t = + at::scalar_tensor(static_cast(st.offset_intragraph_), cpu_opts) + .unsqueeze(0); + return {seed_t, off_t, intra_t}; + } + + auto seed_t = at::scalar_tensor(static_cast(st.seed_.val), dev_opts) + .unsqueeze(0); + auto off_t = at::scalar_tensor(static_cast(st.offset_.val), dev_opts) + .unsqueeze(0); + auto intra_t = at::zeros({1}, cpu_opts); + return {seed_t, off_t, intra_t}; +} + +TORCH_LIBRARY_IMPL(inductor_prims, BackendSelect, m) { + m.impl( + "inductor_reserve_rng_state", TORCH_FN(inductor_reserve_rng_state_impl)); +} + +TORCH_LIBRARY_IMPL(inductor_prims, CUDA, m) { + m.impl( + "inductor_reserve_rng_state", TORCH_FN(inductor_reserve_rng_state_impl)); +} + +TORCH_LIBRARY_IMPL(inductor_prims, HIP, m) { + m.impl( + "inductor_reserve_rng_state", TORCH_FN(inductor_reserve_rng_state_impl)); +} + +#endif + +} // namespace torch::inductor diff --git a/torch/csrc/inductor/static_launcher/cuda.cpp b/torch/csrc/inductor/static_launcher/cuda.cpp index a2378b7c1a248..6b1d09863aa4c 100644 --- a/torch/csrc/inductor/static_launcher/cuda.cpp +++ b/torch/csrc/inductor/static_launcher/cuda.cpp @@ -64,7 +64,7 @@ CUdeviceptr getPointer(PyObject* obj) { return data_ptr; } - if (obj == Py_None) { + if (Py_IsNone(obj)) { // valid nullptr return data_ptr; } diff --git a/torch/csrc/inductor/static_launcher/xpu.cpp b/torch/csrc/inductor/static_launcher/xpu.cpp index d3a5b11b00eea..a8c90b36549ea 100644 --- a/torch/csrc/inductor/static_launcher/xpu.cpp +++ b/torch/csrc/inductor/static_launcher/xpu.cpp @@ -61,7 +61,7 @@ syclDevicePtr_t getPointer( return data_ptr; } - if (obj == Py_None) { + if (Py_IsNone(obj)) { // valid nullptr return data_ptr; } diff --git a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm index 7823d066bafc2..96828ab39e5d4 100644 --- a/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm +++ b/torch/csrc/jit/backends/coreml/objc/PTMCoreMLBackend.mm @@ -213,8 +213,9 @@ bool is_available() override { #elif TARGET_OS_MAC NSOperatingSystemVersion supportedVer = {10, 13, 0}; return [[NSProcessInfo processInfo] isOperatingSystemAtLeastVersion:supportedVer]; -#endif +#else return false; +#endif } }; diff --git a/torch/csrc/jit/frontend/schema_type_parser.cpp b/torch/csrc/jit/frontend/schema_type_parser.cpp index 1884ac61f8c96..960fc9c260666 100644 --- a/torch/csrc/jit/frontend/schema_type_parser.cpp +++ b/torch/csrc/jit/frontend/schema_type_parser.cpp @@ -108,6 +108,9 @@ TypePtr SchemaTypeParser::parseBaseType() { {"Any", c10::TypeFactory::get()}, {"AnyClassType", c10::TypeFactory::get()}, {"AnyEnumType", c10::TypeFactory::get()}, + // PyObjectType::get() used directly because PyObjectType is excluded + // from FORALL_DYNAMIC_TYPES (not supported on xplat/mobile) + {"PyObject", c10::PyObjectType::get()}, }; auto tok = L.cur(); if (!L.nextIf(TK_NONE) && !L.nextIf(TK_NONE_TYPE)) { diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 9b85b0909826f..902d8ffad9278 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -929,12 +929,7 @@ void Value::replaceAllUsesAfterNodeWith(const Node* node, Value* newValue) { } }); - uses_.erase( - std::remove_if( - uses_.begin(), - uses_.end(), - [&node](const Use& u) { return u.user->isAfter(node); }), - uses_.end()); + std::erase_if(uses_, [&node](const Use& u) { return u.user->isAfter(node); }); } void Value::replaceAllUsesDominatedByNodeWith( @@ -947,12 +942,8 @@ void Value::replaceAllUsesDominatedByNodeWith( } }); - uses_.erase( - std::remove_if( - uses_.begin(), - uses_.end(), - [&node](const Use& u) { return u.user->isDominatedBy(node); }), - uses_.end()); + std::erase_if( + uses_, [&node](const Use& u) { return u.user->isDominatedBy(node); }); } static size_t findArgument( @@ -1218,7 +1209,6 @@ bool Node::hasSideEffects() const { return true; } TORCH_INTERNAL_ASSERT(false, "Unhandled AliasAnalysisKind case"); - return false; // silence compiler warning } // Assign this node a topological position, to facilitate fast isBefore() and diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp index a0e0959d6033d..9a0d976df8042 100644 --- a/torch/csrc/jit/mobile/interpreter.cpp +++ b/torch/csrc/jit/mobile/interpreter.cpp @@ -387,7 +387,6 @@ bool InterpreterState::run(Stack& stack) { // } // } } - return false; } IValue& InterpreterState::reg(size_t reg) { diff --git a/torch/csrc/jit/mobile/type_parser.cpp b/torch/csrc/jit/mobile/type_parser.cpp index f9287a5eb7040..a4d476070c12d 100644 --- a/torch/csrc/jit/mobile/type_parser.cpp +++ b/torch/csrc/jit/mobile/type_parser.cpp @@ -158,7 +158,6 @@ TypePtr TypeParser::parse() { " is not supported in the parser, ", "or the token is in wrong format."); } - return nullptr; } // NamedTuple custom type will be following structure: @@ -243,7 +242,6 @@ TypePtr TypeParser::parseCustomType() { TORCH_CHECK( false, "Can't find definition for the type: ", qualified_name); } - return nullptr; } } diff --git a/torch/csrc/jit/passes/autocast.cpp b/torch/csrc/jit/passes/autocast.cpp index 79ead2c5ee6c3..7b638b85f9da1 100644 --- a/torch/csrc/jit/passes/autocast.cpp +++ b/torch/csrc/jit/passes/autocast.cpp @@ -133,8 +133,6 @@ std::optional parseAutocast( // TORCH_CHECK(false, "Unsupported autocast syntax"); } - - return std::nullopt; } void castTensorInputs( diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp index 3df14c2c75cdd..98de08357a4f2 100644 --- a/torch/csrc/jit/passes/onnx.cpp +++ b/torch/csrc/jit/passes/onnx.cpp @@ -431,7 +431,7 @@ void NodeToONNX( auto processSymbolicOutput = [&](const std::string& op_name, Node* n, const py::object& raw_output) { - if (raw_output.ptr() == Py_None) { + if (Py_IsNone(raw_output.ptr())) { cloneNode(n); return; } diff --git a/torch/csrc/jit/passes/onnx/constant_map.cpp b/torch/csrc/jit/passes/onnx/constant_map.cpp index 60699a1e75ef4..c332ea4204eb9 100644 --- a/torch/csrc/jit/passes/onnx/constant_map.cpp +++ b/torch/csrc/jit/passes/onnx/constant_map.cpp @@ -120,28 +120,25 @@ std::optional> ConstantValueMap:: GetShapeInto1DInt64VectorWithOneUnknown(const std::string& value_name) { if (ConstantValueMap::HasShape(value_name)) { auto shape_size = ConstantValueMap::GetShape(value_name).value(); - std::vector shape_value; if (shape_size.isComplete()) { - shape_value = - ConstantValueMap::GetCompleteShapeInto1DInt64Vector(shape_size); - return shape_value; - } else { - size_t count_unknown = 0; - auto shape_size_sizes = shape_size.sizes(); - if (shape_size_sizes.has_value()) { - auto shape_symbol_list = shape_size_sizes.value(); - for (const auto& v : shape_symbol_list) { - if (v.is_static()) { - shape_value.emplace_back(v.static_size()); - } else { - shape_value.emplace_back(-1); - count_unknown += 1; - } - } - if (count_unknown == 1) { - return shape_value; + return ConstantValueMap::GetCompleteShapeInto1DInt64Vector(shape_size); + } + size_t count_unknown = 0; + auto shape_size_sizes = shape_size.sizes(); + if (shape_size_sizes.has_value()) { + std::vector shape_value; + auto shape_symbol_list = shape_size_sizes.value(); + for (const auto& v : shape_symbol_list) { + if (v.is_static()) { + shape_value.emplace_back(v.static_size()); + } else { + shape_value.emplace_back(-1); + count_unknown += 1; } } + if (count_unknown == 1) { + return shape_value; + } } } return std::nullopt; diff --git a/torch/csrc/jit/passes/onnx/function_extraction.cpp b/torch/csrc/jit/passes/onnx/function_extraction.cpp index c82fd3d63c9af..eee0251f4153b 100644 --- a/torch/csrc/jit/passes/onnx/function_extraction.cpp +++ b/torch/csrc/jit/passes/onnx/function_extraction.cpp @@ -710,14 +710,8 @@ void FunctionExtractor::ConvertScopeToFunction( ctx_nlist.insert(last_n_it, func_n); // remove replaced nodes from list - ctx_nlist.erase( - std::remove_if( - ctx_nlist.begin(), - ctx_nlist.end(), - [&old_nodes](Node* n) { - return old_nodes.find(n) != old_nodes.end(); - }), - ctx_nlist.end()); + std::erase_if( + ctx_nlist, [&old_nodes](Node* n) { return old_nodes.contains(n); }); GRAPH_DEBUG("Parent total nodes after remove: ", ctx_nlist.size()); diff --git a/torch/csrc/jit/passes/onnx/helper.cpp b/torch/csrc/jit/passes/onnx/helper.cpp index 3897a8d5cae5e..c2ee1d7dc285f 100644 --- a/torch/csrc/jit/passes/onnx/helper.cpp +++ b/torch/csrc/jit/passes/onnx/helper.cpp @@ -39,14 +39,8 @@ void eraseUnusedBlockInputs(Block* b) { } void eraseUnusedValuesFromMap(ValueToParamPairMap& valsToParamsMap) { - auto it = valsToParamsMap.begin(); - while (it != valsToParamsMap.end()) { - if (!it->first->hasUses()) { - it = valsToParamsMap.erase(it); - } else { - ++it; - } - } + std::erase_if( + valsToParamsMap, [](const auto& pr) { return !pr.first->hasUses(); }); } void buildParamsMapFromValueToParamsMap( @@ -101,7 +95,6 @@ std::optional ONNXTypeToATenType(int32_t onnx_type) { onnx_type, " is an unexpected tensor scalar type"); } - return std::optional{}; } Node* addNodeToBlock(Block* block, Symbol kind, ArrayRef inputs) { diff --git a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp index 686f6e660dba7..79791e0219a4b 100644 --- a/torch/csrc/jit/passes/onnx/shape_type_inference.cpp +++ b/torch/csrc/jit/passes/onnx/shape_type_inference.cpp @@ -23,7 +23,7 @@ namespace torch::jit { static inline bool PyNone_Check(PyObject* o) { - return o == Py_None; + return Py_IsNone(o); } std::pair MergeInferredType( diff --git a/torch/csrc/jit/passes/symbolic_shape_cache.cpp b/torch/csrc/jit/passes/symbolic_shape_cache.cpp index 0cca03d6f74d0..11c79dedcd0e4 100644 --- a/torch/csrc/jit/passes/symbolic_shape_cache.cpp +++ b/torch/csrc/jit/passes/symbolic_shape_cache.cpp @@ -92,7 +92,7 @@ ShapeCacheKey get_cache_key( std::unordered_map& ss_map, bool deep_copy = true) { CanonicalArgVec canonical_args = cannonicalizeVec(arg_vec, ss_map, deep_copy); - return std::make_tuple(schema->operator_name(), canonical_args); + return std::make_tuple(schema->operator_name(), std::move(canonical_args)); } } // namespace diff --git a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp index c8c6953f3447f..3d858bb39d9dc 100644 --- a/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp +++ b/torch/csrc/jit/passes/symbolic_shape_runtime_fusion.cpp @@ -215,12 +215,11 @@ static void moveConstantTensorsOutOfSubgraph( const std::shared_ptr& tensorexpr_graph) { auto parent = tensorexpr_graph_node->owningGraph(); - auto env = [&](Value* v) { + auto env = [&](Value* v) -> Value* { TORCH_INTERNAL_ASSERT( false, "this should never happen since constant nodes do not have any inputs", v->debugName()); - return v; }; WithInsertPoint wip(tensorexpr_graph_node); diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp index 6f92e821e5b44..f1b6a1f875134 100644 --- a/torch/csrc/jit/passes/utils/subgraph_utils.cpp +++ b/torch/csrc/jit/passes/utils/subgraph_utils.cpp @@ -516,7 +516,6 @@ void unmergeNode(Node* n, Node* subgraphNode) { false, "all inputs should've been mapped. Couldn't map %", v->debugName()); - return v; }; for (auto i : c10::irange(subgraph->outputs().size())) { diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 931d80ffd6a68..67c6349f64efd 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1796,7 +1796,7 @@ void initJITBindings(PyObject* module) { m.def( "_jit_get_operation", - [](const std::string& op_name) { + [](const std::string& op_name) -> py::tuple { try { auto symbol = Symbol::fromQualString(op_name); const auto sortedOps = getAllSortedOperatorsFor(symbol); @@ -1843,7 +1843,7 @@ void initJITBindings(PyObject* module) { "_maybe_call_torch_function_for_op_packet", [](py::handle op_overload_packet, const py::args& args, - const py::kwargs& kwargs) { + const py::kwargs& kwargs) -> py::tuple { py::list ns_method = op_overload_packet.attr("_qualified_op_name").attr("split")("::"); auto res = _maybe_handle_torch_function( @@ -2183,20 +2183,12 @@ void initJITBindings(PyObject* module) { .def( py::pickle( /* __getstate__ */ - [](const PythonFutureWrapper& /* unused */) { + [](const PythonFutureWrapper& /* unused */) -> py::tuple { TORCH_CHECK(false, "Can not pickle torch.futures.Future"); - // Note that this return has no meaning since we always - // throw, it's only here to satisfy Pybind API's - // requirement. - return py::make_tuple(); }, /* __setstate__ */ - [](const py::tuple& /* unused */) { + [](const py::tuple& /* unused */) -> std::nullptr_t { TORCH_CHECK(false, "Can not unpickle torch.futures.Future"); - // Note that this return has no meaning since we always - // throw, it's only here to satisfy PyBind's API - // requirement. - return nullptr; }), py::call_guard()); @@ -2220,20 +2212,12 @@ void initJITBindings(PyObject* module) { .def( py::pickle( /* __getstate__ */ - [](const PythonAwaitWrapper& /* unused */) { + [](const PythonAwaitWrapper& /* unused */) -> py::tuple { TORCH_CHECK(false, "Can not pickle torch.jit._Await"); - // Note that this return has no meaning since we always - // throw, it's only here to satisfy Pybind API's - // requirement. - return py::make_tuple(); }, /* __setstate__ */ - [](const py::tuple& /* unused */) { + [](const py::tuple& /* unused */) -> std::nullptr_t { TORCH_CHECK(false, "Can not unpickle torch.jit._Await"); - // Note that this return has no meaning since we always - // throw, it's only here to satisfy PyBind's API - // requirement. - return nullptr; }), py::call_guard()); diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index cf422102f7dec..83d56ffdc03a0 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -74,7 +74,7 @@ static IValue listToIValue(py::handle obj) { IValue toIValue(py::handle obj, const TypePtr& type, std::optional N) { switch (type->kind()) { case TypeKind::TensorType: { - if (obj.ptr() == Py_None) { + if (Py_IsNone(obj.ptr())) { // None gets converted to undefined Tensors return autograd::Variable(); } @@ -543,9 +543,15 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional N) { } case TypeKind::CapsuleType: { #ifdef USE_DISTRIBUTED - // Handle ProcessGroup custom class as a capsule + // Handle ProcessGroup custom class as a capsule. FakeScriptObject + // (used during Dynamo tracing with CooR) passes py::isinstance via + // OpaqueBaseMeta but cannot be cast directly; unwrap real_obj first. if (py::isinstance(obj)) { - auto cpp_obj = obj.cast>(); + py::handle target = obj; + if (py::hasattr(obj, "real_obj")) { + target = obj.attr("real_obj"); + } + auto cpp_obj = target.cast>(); return IValue::make_capsule(cpp_obj); } #endif @@ -1006,6 +1012,18 @@ std::optional detail::_tryToInferTypeImpl(py::handle input) { if (py::isinstance(input)) { return InferredType(CapsuleType::get()); } + // During Dynamo tracing with compile-on-one-rank (CooR), opaque reference + // types like ProcessGroup are wrapped in FakeScriptObject. Python-level + // isinstance() sees through the wrapper (via OpaqueBaseMeta), but the C++ + // py::isinstance above does too — yet the subsequent pybind11 cast would + // fail because FakeScriptObject is not a C++ bound object. Detect this + // case by checking for the wrapped real_obj attribute. + if (py::hasattr(input, "real_obj")) { + py::object real = input.attr("real_obj"); + if (py::isinstance(real)) { + return InferredType(CapsuleType::get()); + } + } #endif return std::nullopt; diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index 378bf4636fd2e..37e49c2952620 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -23,7 +23,6 @@ #include #include #include -#include #ifdef USE_DISTRIBUTED #include #include @@ -429,6 +428,14 @@ inline InferredType tryToInferType(py::handle input) { return InferredType(IntType::get()); } + // Check for types registered in _tryToInferTypeImpl (e.g. ProcessGroup) + // before falling through to the expensive inspect.isclass / JIT compilation + // path below. + auto ty = detail::_tryToInferTypeImpl(input); + if (ty.has_value()) { + return ty.value(); + } + auto enum_type = py::module::import("enum").attr("Enum"); py::bool_ isEnumValue = py::isinstance(input, enum_type); if (py::cast(isEnumValue)) { @@ -518,11 +525,6 @@ inline InferredType tryToInferType(py::handle input) { return InferredType("Cannot infer concrete type of torch.nn.Module"); } - auto ty = detail::_tryToInferTypeImpl(input); - if (ty.has_value()) { - return ty.value(); - } - // Try container types return tryToInferContainerType(input, false); } @@ -554,7 +556,7 @@ inline InferredType tryToInferPrimitiveType(py::handle input) { inline InferredType tryToInferContainerType( py::handle input, bool primitiveTypeOnly = false) { - if (six::isTuple(input)) { + if (PyTuple_Check(input.ptr())) { py::tuple tuple = py::cast(input); std::vector element_types; element_types.reserve(tuple.size()); diff --git a/torch/csrc/jit/python/python_arg_flatten.cpp b/torch/csrc/jit/python/python_arg_flatten.cpp index b71f21d043a31..6615f475a1f83 100644 --- a/torch/csrc/jit/python/python_arg_flatten.cpp +++ b/torch/csrc/jit/python/python_arg_flatten.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include @@ -29,7 +29,7 @@ static constexpr char NoneType = 'n'; namespace { inline bool PyNone_Check(PyObject* o) { - return o == Py_None; + return Py_IsNone(o); } template @@ -44,7 +44,7 @@ py::object cast_handle_sequence(std::vector objs) { void flatten_rec(PyObject* obj, ParsedArgs& args) { auto& structure = args.desc.structure; - if (six::isTuple(obj)) { + if (PyTuple_Check(obj)) { structure.push_back(D::TupleOpen); for (auto item : py::reinterpret_borrow(obj)) flatten_rec(item.ptr(), args); diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 5cf3bd900f351..4856404750af5 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -113,7 +113,7 @@ std::pair, Stack> createGraphByTracingWithDict( // We just leave the inputs_dict as it was and pass it to forward // method. auto out = func(**inputs_dict); - if (out.ptr() == Py_None) { + if (Py_IsNone(out.ptr())) { TORCH_CHECK( false, "The traced function didn't return any values! Side-effects are not " @@ -154,7 +154,7 @@ std::pair, Stack> createGraphByTracing( py_inputs[i] = py::cast(inputs[i]); } auto out = func(*py_inputs); - if (out.ptr() == Py_None) { + if (Py_IsNone(out.ptr())) { TORCH_CHECK( false, "The traced function didn't return any values! Side-effects are not " diff --git a/torch/csrc/jit/runtime/graph_executor.cpp b/torch/csrc/jit/runtime/graph_executor.cpp index 4bdab8c5dcb22..fa37ccab96d06 100644 --- a/torch/csrc/jit/runtime/graph_executor.cpp +++ b/torch/csrc/jit/runtime/graph_executor.cpp @@ -635,6 +635,9 @@ struct GraphExecutorImpl : public GraphExecutorImplBase { const ExecutionPlan& getPlanFor( Stack& stack, std::optional remaining_bailout_depth) override { + if (FLAGS_torch_jit_input_independent_optimization) { + return getInputIndependentPlan(); + } return getGraphExecutorOptimize() ? getOrCompile(stack) : getOrCompileFallback(); } @@ -778,6 +781,38 @@ struct GraphExecutorImpl : public GraphExecutorImplBase { return ExecutionPlan(opt_graph, function_name_); } + const ExecutionPlan& getInputIndependentPlan() override { + std::lock_guard lock(compile_mutex); + if (!input_independent_plan_) { + auto opt_graph = graph->copy(); + Inline(*opt_graph); + LowerGradOf(*opt_graph); + specializeAutogradZero(opt_graph); + LowerSimpleTuples(opt_graph); + ConstantPooling(opt_graph); + runRequiredPasses(opt_graph); + ConstantPropagation(opt_graph); + // Skip PropagateInputShapes and PropagateRequiresGrad since they need + // actual input data. + runOptimization(opt_graph); + + // Input-independent passes from runNondiffOptimization. Skipped: + // FuseTensorExprs/FuseGraph (need specialized tensor types). + for (const auto& passPair : getCustomPrePasses()) { + passPair.first(opt_graph); + } + DecomposeOps(opt_graph); + BatchMM(opt_graph); + for (const auto& passPair : getCustomPostPasses()) { + passPair.first(opt_graph); + } + + EliminateDeadCode(opt_graph); + input_independent_plan_ = ExecutionPlan(opt_graph, function_name_); + } + return *input_independent_plan_; + } + ~GraphExecutorImpl() override = default; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) @@ -787,6 +822,10 @@ struct GraphExecutorImpl : public GraphExecutorImplBase { // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) ExecutionPlan fallback; + // Cached plan from getOptimizedPlan() -- uses only input-independent passes. + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + std::optional input_independent_plan_; + // Mapping from argument configurations to optimized versions of the graph // that are specialized to the spec. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) @@ -842,6 +881,10 @@ const ExecutionPlan& GraphExecutor::getPlanFor( return pImpl->getPlanFor(inputs, remaining_bailout_depth); } +const ExecutionPlan& GraphExecutor::getInputIndependentPlan() { + return pImpl->getInputIndependentPlan(); +} + GraphExecutorState GraphExecutor::getDebugState() { return pImpl->getDebugState(); } diff --git a/torch/csrc/jit/runtime/graph_executor.h b/torch/csrc/jit/runtime/graph_executor.h index d1039216de3ea..94ed97f1784e8 100644 --- a/torch/csrc/jit/runtime/graph_executor.h +++ b/torch/csrc/jit/runtime/graph_executor.h @@ -87,6 +87,11 @@ struct TORCH_API GraphExecutor { const ExecutionPlan& getPlanFor( Stack& inputs, std::optional remaining_bailout_depth = std::nullopt); + // Returns an optimized execution plan without requiring input arguments. + // Runs input-independent optimization passes (e.g. inlining, constant + // propagation, peephole, CSE) but skips profiling-based specializations + // that require runtime type/shape information. + const ExecutionPlan& getInputIndependentPlan(); GraphExecutorState getDebugState(); void debugFlushCompilationCache(); diff --git a/torch/csrc/jit/runtime/graph_executor_impl.h b/torch/csrc/jit/runtime/graph_executor_impl.h index 70069ac1907b0..78c7a7b79d982 100644 --- a/torch/csrc/jit/runtime/graph_executor_impl.h +++ b/torch/csrc/jit/runtime/graph_executor_impl.h @@ -79,6 +79,11 @@ struct GraphExecutorImplBase { virtual const ExecutionPlan& getPlanFor( Stack& stack, std::optional remaining_bailout_depth = std::nullopt) = 0; + // Returns an optimized execution plan without requiring input arguments. + // Runs input-independent optimization passes (e.g. inlining, constant + // propagation, peephole, CSE) but skips profiling-based specializations + // that require runtime type/shape information. + virtual const ExecutionPlan& getInputIndependentPlan() = 0; virtual GraphExecutorState getDebugState() = 0; virtual ~GraphExecutorImplBase() = default; diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index 680244b363c36..5ae1ec45050a0 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -56,6 +56,13 @@ C10_DEFINE_bool( false, "fuse on 12 dynamic compilations") +C10_DEFINE_bool( + torch_jit_input_independent_optimization, + false, + "If set, getPlanFor will use input-independent optimization passes only, " + "skipping profiling-based specializations that require runtime type/shape " + "information. Useful for predictor nets and AOT compilation scenarios.") + C10_DEFINE_bool( torch_jit_release_profiling_graph_after_optimization, false, @@ -111,8 +118,9 @@ static FusionStrategy getInitialStrategy() { // TODO remove ifdef #ifdef FBCODE_CAFFE2 return {{FusionBehavior::STATIC, 20}}; -#endif +#else return mixed; +#endif } // defer initial value so that we can load in gflags @@ -714,10 +722,59 @@ const ExecutionPlan& ProfilingGraphExecutorImpl::getPlanFor( } return *optimized_plan_; } + if (FLAGS_torch_jit_input_independent_optimization) { + return getInputIndependentPlanImpl(); + } // if depth is not set, use return getOptimizedPlanFor(stack, remaining_bailout_depth); } +const ExecutionPlan& ProfilingGraphExecutorImpl::getInputIndependentPlan() { + std::lock_guard lock(compile_mutex); + return getInputIndependentPlanImpl(); +} + +const ExecutionPlan& +ProfilingGraphExecutorImpl::getInputIndependentPlanImpl() { + if (optimized_plan_) { + return *optimized_plan_; + } + + auto copy = graph->copy(); + if (!getGraphExecutorOptimize() || !getProfilingMode()) { + LowerGradOf(*copy); + RemoveExpands(copy); + } else { + // Run all input-independent optimizations. This includes + // runProfilingInsensitiveOptimizations (inlining, constant propagation, + // CSE, peephole, etc.) followed by runPreAutodiffPassPipeline which adds + // loop unrolling, list mutation removal, and additional rounds of + // peephole + constant propagation. Profiling-dependent passes (type + // specialization, fusion, autodiff guards) are skipped since they + // require runtime type/shape information from actual inputs. + runProfilingInsensitiveOptimizations(copy); + runPreAutodiffPassPipeline(copy); + + // Run additional input-independent passes from runNoGradOptimizations + // and runFinalOptimizations. These operate on graph structure (node + // patterns, alias analysis) and do not require profiled type/shape info. + // Skipped: RemoveProfileNodesAndSpecializeTypes, FuseTensorExprs, + // FuseGraph (these need profiled tensor types for specialization/fusion). + for (const auto& passPair : getCustomPrePasses()) { + passPair.first(copy); + } + BatchMM(copy); + for (const auto& passPair : getCustomPostPasses()) { + passPair.first(copy); + } + AddIfThenElseOp(copy); + EliminateDeadCode(copy); + } + optimized_plan_ = ExecutionPlan(copy, function_name_); + time_optimized_plan_created_ = getNowInSecs(); + return *optimized_plan_; +} + GraphExecutorState ProfilingGraphExecutorImpl::getDebugState() { GraphExecutorState state; TORCH_INTERNAL_ASSERT(optimized_plan_); diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.h b/torch/csrc/jit/runtime/profiling_graph_executor_impl.h index da437bb456e92..422a7646cef8a 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.h +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.h @@ -7,6 +7,7 @@ TORCH_DECLARE_bool(torch_jit_static_then_dynamic); TORCH_DECLARE_bool(torch_jit_always_dynamic); +C10_DECLARE_bool(torch_jit_input_independent_optimization); C10_DECLARE_bool(torch_jit_release_profiling_graph_after_optimization); C10_DECLARE_int32(torch_jit_release_profiling_graph_delay_in_seconds); C10_DECLARE_int64(torch_jit_num_profiled_runs); @@ -24,6 +25,7 @@ struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase { const ExecutionPlan& getPlanFor( Stack& stack, std::optional remaining_bailout_depth) override; + const ExecutionPlan& getInputIndependentPlan() override; GraphExecutorState getDebugState() override; ~ProfilingGraphExecutorImpl() override = default; @@ -37,6 +39,8 @@ struct TORCH_API ProfilingGraphExecutorImpl : public GraphExecutorImplBase { const ExecutionPlan& getOptimizedPlanFor( Stack& stack, std::optional remaining_bailout_depth); + // Input-independent optimization, assumes compile_mutex is held. + const ExecutionPlan& getInputIndependentPlanImpl(); void runProfilingInsensitiveOptimizations(std::shared_ptr& graph); void runProfilingOptimizations( std::shared_ptr& graph, diff --git a/torch/csrc/jit/runtime/register_ops_utils.cpp b/torch/csrc/jit/runtime/register_ops_utils.cpp index adbfdb46f0932..dfd6fba02de4f 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.cpp +++ b/torch/csrc/jit/runtime/register_ops_utils.cpp @@ -132,7 +132,6 @@ void checkDoubleInRange(double a) { a < double(std::numeric_limits::min())) { throw c10::Error( "Cannot convert float " + std::to_string(a) + " to integer"); - return; } } diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index 1f168d24e8adf..e512853a57567 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -583,56 +583,6 @@ static const std::vector opGenArgs{ push(stack, arg.element_size()); }, aliasAnalysisFromSchema()), - OperatorGeneratorArgs( - TORCH_SELECTIVE_SCHEMA("aten::numel(Tensor self) -> int"), - [](Stack& stack) { - at::Tensor arg = pop(stack).toTensor(); - push(stack, arg.numel()); - }, - aliasAnalysisFromSchema()), - OperatorGeneratorArgs( - TORCH_SELECTIVE_SCHEMA("aten::dim(Tensor self) -> int"), - dim, - aliasAnalysisFromSchema()), - OperatorGeneratorArgs( - TORCH_SELECTIVE_SCHEMA("aten::get_device(Tensor self) -> int"), - [](Stack& stack) { - RECORD_FUNCTION("get_device", c10::ArrayRef{}); - auto result = - at::get_device((std::move(peek(stack, 0, 1))).toTensor()); - drop(stack, 1); - pack(stack, result); - }, - aliasAnalysisFromSchema()), - OperatorGeneratorArgs( - TORCH_SELECTIVE_SCHEMA("aten::storage_offset(Tensor self) -> int"), - [](Stack& stack) { - RECORD_FUNCTION("storage_offset", c10::ArrayRef{}); - auto result = - ((std::move(peek(stack, 0, 1))).toTensor()).storage_offset(); - drop(stack, 1); - pack(stack, result); - }, - aliasAnalysisFromSchema()), - OperatorGeneratorArgs( - TORCH_SELECTIVE_SCHEMA("aten::is_contiguous(Tensor self) -> bool"), - [](Stack& stack) { - RECORD_FUNCTION("is_contiguous", c10::ArrayRef{}); - auto result = - ((std::move(peek(stack, 0, 1))).toTensor()).is_contiguous(); - drop(stack, 1); - pack(stack, result); - }, - aliasAnalysisFromSchema()), - OperatorGeneratorArgs( - TORCH_SELECTIVE_SCHEMA( - "aten::is_contiguous.memory_format(Tensor self, MemoryFormat memory_format) -> bool"), - [](Stack& stack) { - auto memory_format = pop(stack).toMemoryFormat(); - auto t = pop(stack).toTensor(); - push(stack, t.is_contiguous(memory_format)); - }, - aliasAnalysisFromSchema()), OperatorGeneratorArgs( // NB: intentionally suffixed with extra _format to prevent tests for // "_like" suffix from triggering on this diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index c7343914cb639..a9fd6dc1e33c2 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -324,6 +324,16 @@ RegisterOperators reg({ push(stack, at::infer_size(a.toDimVector(), b.toDimVector())); }, aliasAnalysisFromSchema()), + // Not gated by TORCH_SELECTIVE_SCHEMA because selective build + // allowlists won't include this op until models are re-traced. + Operator( + "aten::broadcast_shapes(int[] a, int[] b) -> int[]", + [](Stack& stack) { + auto a = pop(stack); + auto b = pop(stack); + push(stack, at::infer_size(a.toDimVector(), b.toDimVector())); + }, + aliasAnalysisFromSchema()), OperatorGenerator( TORCH_SELECTIVE_SCHEMA( "aten::_no_grad_embedding_renorm_(Tensor weight, Tensor input, float max_norm, float norm_type) -> Tensor"), diff --git a/torch/csrc/jit/runtime/simple_graph_executor_impl.cpp b/torch/csrc/jit/runtime/simple_graph_executor_impl.cpp index fd908b48ee043..4671ab0fdcaf7 100644 --- a/torch/csrc/jit/runtime/simple_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/simple_graph_executor_impl.cpp @@ -28,6 +28,22 @@ const ExecutionPlan& SimpleGraphExecutorImpl::getPlanFor( return *execution_plan_; } +const ExecutionPlan& SimpleGraphExecutorImpl::getInputIndependentPlan() { + std::lock_guard lock(compile_mutex); + return getInputIndependentPlanImpl(); +} + +const ExecutionPlan& SimpleGraphExecutorImpl::getInputIndependentPlanImpl() { + if (execution_plan_) { + return *execution_plan_; + } + auto copy = graph->copy(); + runNooptPassPipeline(copy); + execution_plan_ = ExecutionPlan(copy, function_name_); + + return *execution_plan_; +} + GraphExecutorState SimpleGraphExecutorImpl::getDebugState() { GraphExecutorState state; TORCH_INTERNAL_ASSERT(execution_plan_); diff --git a/torch/csrc/jit/runtime/simple_graph_executor_impl.h b/torch/csrc/jit/runtime/simple_graph_executor_impl.h index e1ebed46ede80..d972432b6883e 100644 --- a/torch/csrc/jit/runtime/simple_graph_executor_impl.h +++ b/torch/csrc/jit/runtime/simple_graph_executor_impl.h @@ -13,10 +13,13 @@ struct TORCH_API SimpleGraphExecutorImpl : public GraphExecutorImplBase { const ExecutionPlan& getPlanFor( Stack& stack, std::optional remaining_bailout_depth) override; + const ExecutionPlan& getInputIndependentPlan() override; GraphExecutorState getDebugState() override; ~SimpleGraphExecutorImpl() override = default; private: + const ExecutionPlan& getInputIndependentPlanImpl(); + std::optional execution_plan_; }; diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index 59ed5281db6bc..c1d35d6069d9f 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -1446,7 +1446,6 @@ void check_onnx_proto(const std::string& proto_string) { onnx::ModelProto model; if (!ParseProtoFromBytes(&model, proto_string.c_str(), proto_string.size())) { throw std::runtime_error("Invalid ONNX proto string."); - return; } // 1. baseline check // These two checks prevent broken graph being generated diff --git a/torch/csrc/jit/tensorexpr/codegen.cpp b/torch/csrc/jit/tensorexpr/codegen.cpp index b2a0d5e6d73dd..967cfea3a55cc 100644 --- a/torch/csrc/jit/tensorexpr/codegen.cpp +++ b/torch/csrc/jit/tensorexpr/codegen.cpp @@ -103,7 +103,6 @@ void* CodeGen::argToPtr(const BufferArg& bufferArg, const CallArg& callArg) { default: throw unsupported_dtype(); } - return nullptr; } void CodeGen::call_with_numel(void** args, int64_t numel) { diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp index 1709ef2bbff5a..9c8bd175a50e6 100644 --- a/torch/csrc/jit/tensorexpr/eval.cpp +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -20,7 +20,6 @@ int64_t InterpValue::intValue() const { AT_FORALL_INT_TYPES(TYPE_CASE); #undef TYPE_CASE throw unsupported_dtype(); - return 0; } template @@ -662,7 +661,6 @@ class SimpleIREvaluatorImpl : public IRVisitor { default: throw unsupported_dtype(); } - return {}; } void check_bounds_throw(int64_t idx, int64_t bound, const BufPtr& buf) { diff --git a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp index 53eca369cc5ac..7d918009e2421 100644 --- a/torch/csrc/jit/tensorexpr/ir_simplifier.cpp +++ b/torch/csrc/jit/tensorexpr/ir_simplifier.cpp @@ -37,8 +37,8 @@ static inline ExprPtr newBinaryOpOfType( case IRNodeType::kRshift: return alloc(lhs, rhs); default: - LOG(FATAL) << "unsupported expr_type: " << static_cast(expr_type); - return nullptr; + TORCH_INTERNAL_ASSERT( + false, "unsupported expr_type: ", static_cast(expr_type)); } } diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index d696d29bf733e..84d78a665e62d 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -330,8 +330,9 @@ bool mkldnnPrepackedConvIsSupportedJit(const torch::jit::Node* node) { _pair_int(*pad), _pair_int(*dilation), groups->toInt()); -#endif +#else return false; +#endif } bool isConv2d(const Node* node) { diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index 0107a3a8ed077..4e374b9186a2c 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -186,7 +186,7 @@ using FunctionCallee = llvm::FunctionCallee; #elif LLVM_VERSION_MAJOR == 8 && LLVM_VERSION_PATCH == 20181009 struct FunctionCallee { - FunctionCallee() {} + FunctionCallee() = default; FunctionCallee(llvm::Constant* fn) : v_(fn), ft_(cast(v_)->getFunctionType()) {} diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.cpp b/torch/csrc/jit/tensorexpr/llvm_jit.cpp index 80d919a5674e6..9690e23124785 100644 --- a/torch/csrc/jit/tensorexpr/llvm_jit.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_jit.cpp @@ -180,11 +180,19 @@ class TORCH_API PytorchLLVMJITImpl { #if LLVM_VERSION_MAJOR < 21 , const Triple& TT +#elif LLVM_VERSION_MAJOR >= 23 + , + jitlink::JITLinkMemoryManager& + JLMM #endif ) { +#if LLVM_VERSION_MAJOR >= 23 + return std::make_unique(ES, JLMM); +#else return std::make_unique( ES, assertSuccess(jitlink::InProcessMemoryManager::Create())); +#endif }) #endif .create())) { diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp index 1bdae4ca7ae90..e0a7a64396c04 100644 --- a/torch/csrc/jit/tensorexpr/loopnest.cpp +++ b/torch/csrc/jit/tensorexpr/loopnest.cpp @@ -2830,7 +2830,6 @@ LoopNest::AccessResult LoopNest::cacheAccesses( if (reduceOp) { throw std::runtime_error( "can only cache accesses used by at most a single reduceOp"); - return {nullptr, nullptr}; } reduceOp = ro; @@ -2842,7 +2841,6 @@ LoopNest::AccessResult LoopNest::cacheAccesses( auto bounds_it = consumer_bounds_info.find(producer); if (bounds_it == consumer_bounds_info.end()) { throw std::runtime_error("consumer does not use the Tensor produced"); - return {nullptr, nullptr}; } TORCH_INTERNAL_ASSERT( diff --git a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp index 2dd46335e09f8..3b0b21f94699c 100644 --- a/torch/csrc/jit/tensorexpr/operators/conv2d.cpp +++ b/torch/csrc/jit/tensorexpr/operators/conv2d.cpp @@ -341,8 +341,9 @@ bool mkldnnPrepackedConvIsSupported( input.dims[0] * input.dims[1] * input.dims[2] * input.dims[3] > 20480; GRAPH_DEBUG("mkldnnPrepackedConvIsSupported: ", use_mkldnn); return use_mkldnn; -#endif +#else return false; +#endif } Tensor computeConv2d( diff --git a/torch/csrc/jit/tensorexpr/operators/misc.cpp b/torch/csrc/jit/tensorexpr/operators/misc.cpp index f633923723535..266199bd7216e 100644 --- a/torch/csrc/jit/tensorexpr/operators/misc.cpp +++ b/torch/csrc/jit/tensorexpr/operators/misc.cpp @@ -152,8 +152,6 @@ ExprHandle demoteOutput( default: throw unsupported_dtype(); } - - return e; } std::optional getTensorInfo(const BufHandle& b) { diff --git a/torch/csrc/jit/tensorexpr/registerizer.cpp b/torch/csrc/jit/tensorexpr/registerizer.cpp index 9ad44e31a3873..acf216c51b530 100644 --- a/torch/csrc/jit/tensorexpr/registerizer.cpp +++ b/torch/csrc/jit/tensorexpr/registerizer.cpp @@ -161,17 +161,12 @@ AccessHashMap& Scope::getAccessMapByBuf(const BufPtr& b) { } void Scope::filterClosed() { - closedAccesses_.erase( - std::remove_if( - closedAccesses_.begin(), - closedAccesses_.end(), - [](auto info) { - return info->store_cost()->isConstant() && - immediateAs(info->store_cost()) <= 1 && - info->load_cost()->isConstant() && - immediateAs(info->load_cost()) <= 1; - }), - closedAccesses_.end()); + std::erase_if(closedAccesses_, [](auto info) { + return info->store_cost()->isConstant() && + immediateAs(info->store_cost()) <= 1 && + info->load_cost()->isConstant() && + immediateAs(info->load_cost()) <= 1; + }); } // RegisterizerAnalysis diff --git a/torch/csrc/jit/tensorexpr/types.cpp b/torch/csrc/jit/tensorexpr/types.cpp index 57f6c1c9ec342..1873035014fc0 100644 --- a/torch/csrc/jit/tensorexpr/types.cpp +++ b/torch/csrc/jit/tensorexpr/types.cpp @@ -115,7 +115,6 @@ std::string Dtype::ToCppString() const { default: throw unsupported_dtype(); } - return "invalid"; } } // namespace torch::jit::tensorexpr diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index 0e792934472d1..27ef7557837e1 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -144,8 +144,6 @@ size_t assertFindRegex( extra_msg(ss); } throw std::runtime_error(ss.str()); - - return std::string::npos; } return pos; } diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp index 4807aa6a4c7d1..07281f7e444c7 100644 --- a/torch/csrc/lazy/python/init.cpp +++ b/torch/csrc/lazy/python/init.cpp @@ -272,8 +272,6 @@ void initLazyBindings(PyObject* module) { #else TORCH_CHECK( false, "TorchScript backend not yet supported in FBCODE builds"); - return std::make_pair( - std::vector(), std::vector()); #endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) }); // TODO(shunting) revisit this part for XLA @@ -303,13 +301,13 @@ void initLazyBindings(PyObject* module) { for (torch::jit::IValue elem : stack) { result.push_back(elem.toTensor()); } + return result; #else TORCH_CHECK( false, "TorchScript backend not yet supported in FBCODE builds"); #endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) - return result; }); - lazy_ts_backend.def("_get_latest_computation_graph", []() { + lazy_ts_backend.def("_get_latest_computation_graph", []() -> std::string { #if !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) auto computation = LazyGraphExecutor::Get() ->GetComputationCache() @@ -321,7 +319,6 @@ void initLazyBindings(PyObject* module) { #else TORCH_CHECK( false, "TorchScript backend not yet supported in FBCODE builds"); - return ""; #endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE)) }); diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp index 9c33a9988d7f2..20a700d9f3122 100644 --- a/torch/csrc/mps/Module.cpp +++ b/torch/csrc/mps/Module.cpp @@ -495,6 +495,26 @@ void initModule(PyObject* module) { .def_property_readonly( "static_thread_group_memory_length", &MetalKernelFunction::getStaticThreadGroupMemoryLength); + py::class_< + PrecompiledMetalShaderLibrary, + std::shared_ptr>( + m, "_mps_PrecompiledShaderLibrary") + .def( + "__getattr__", + [](PrecompiledMetalShaderLibrary& self, const std::string& name) { + return self.getKernelFunction(name); + }) + .def("__dir__", [](PrecompiledMetalShaderLibrary& self) { + return self.getFunctionNames(); + }); + m.def("_mps_loadMetalllib", [](const py::bytes& data) { + auto sv = static_cast(data); + std::vector bytes(sv.begin(), sv.end()); + return std::make_shared(std::move(bytes)); + }); + m.def("_mps_loadMetallibFromPath", [](const std::string& path) { + return std::make_shared(path); + }); m.def("_mps_compileShader", [](const std::string& source) { return std::make_shared(source); }); diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index 617316617fc67..1be2d80310910 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -10,6 +10,7 @@ #include #include +#include #ifdef USE_KINETO #include @@ -1096,6 +1097,18 @@ class TransferEvents { for (const auto* activity : trace_activities_) { auto e = toResult(activity); if (e) { + // Flow data for Kineto events is already set during + // resultFromActivity(). TorchOp events need it copied here because + // their Result is created during RecordFunction callbacks, before + // flow data exists on the GenericTraceActivity. + e->visit(c10::overloaded( + [&](ExtraFields& i) { + i.flow = { + /*id=*/static_cast(activity->flowId()), + /*type=*/static_cast(activity->flowType()), + /*start=*/activity->flowStart()}; + }, + [](auto&) {})); if (config_.experimental_config.expose_kineto_event_metadata) { e->visit(c10::overloaded( [&](ExtraFields& i) { @@ -1105,6 +1118,27 @@ class TransferEvents { i.metadata_json_ = activity->metadataJson(); }, [](auto&) { return; })); + // Parse metadataJson() into extra_meta_ so events() exposes + // Kineto metadata as typed fields without export_chrome_trace(). + // Python schemas (profiler_util.py) are the single SOT for + // which keys to expose and how to type-convert them. + e->visit(c10::overloaded( + [&](ExtraFields& i) { + auto json_str = activity->metadataJson(); + if (!json_str.empty()) { + auto j = nlohmann::json::parse( + "{" + json_str + "}", nullptr, false); + if (!j.is_discarded()) { + for (auto& [key, val] : j.items()) { + i.extra_meta_.emplace( + key, + val.is_string() ? val.get() + : val.dump()); + } + } + } + }, + [](auto&) {})); } const auto* linked_activity = activity->linkedActivity(); if (linked_activity) { diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h index d66eb630a47d9..bd46814f8be9b 100644 --- a/torch/csrc/profiler/collection.h +++ b/torch/csrc/profiler/collection.h @@ -132,6 +132,14 @@ using extra_args_t = std::unordered_map; using extra_meta_t = std::unordered_map; using kwinputs_t = std::unordered_map; +// Mirrors `libkineto::GenericTraceActivity::Flow`. Used during post processing +// to embed Kineto events into the broader profiler tree structure. +struct Flow { + uint32_t id{0}; + uint32_t type{0}; + uint32_t start{0}; +}; + struct FallbackPair { ProfilerVoidEventStub device_event_start_ = nullptr; ProfilerVoidEventStub device_event_end_ = nullptr; @@ -179,6 +187,7 @@ struct ExtraFields : TorchOpBasicFields { bool allow_tf32_cublas_; std::unique_ptr perf_event_counters_; std::string metadata_json_; + Flow flow; }; template <> @@ -354,16 +363,6 @@ struct ExtraFields : public PyExtraFieldsBase { template <> struct ExtraFields { - // Mirrors `libkineto::GenericTraceActivity::Flow`. This information is used - // during post processing to properly embed Kineto events into the broader - // profiler tree structure. End users are not generally expected to use these - // fields directly, but they are available for debugging. - struct Flow { - uint32_t id{0}; - uint32_t type{0}; - uint32_t start{0}; - }; - std::string name_; int64_t duration_ns_{0}; uint64_t correlation_id_{0}; @@ -371,6 +370,7 @@ struct ExtraFields { Flow flow; std::weak_ptr linked_activity_; std::string metadata_json_; + extra_meta_t extra_meta_; }; struct TORCH_API Result : public std::enable_shared_from_this { diff --git a/torch/csrc/profiler/data_flow.cpp b/torch/csrc/profiler/data_flow.cpp index a9f98930f8c66..e18d88960fbba 100644 --- a/torch/csrc/profiler/data_flow.cpp +++ b/torch/csrc/profiler/data_flow.cpp @@ -131,15 +131,10 @@ void calculateUniqueTensorIDs( tensor_set.insert(t.allocation_id_ref_.get().value()); } } - tensors.erase( - std::remove_if( - tensors.begin(), - tensors.end(), - [&tensor_set](const auto& i) { - auto it = tensor_set.find(i.allocation_id_ref_.get().value()); - return it == tensor_set.end(); - }), - tensors.end()); + std::erase_if(tensors, [&tensor_set](const auto& i) { + auto it = tensor_set.find(i.allocation_id_ref_.get().value()); + return it == tensor_set.end(); + }); } // Handle the case that the storage of a TensorImpl changed. diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp index 0d1e7e0604222..fa232e1a01016 100644 --- a/torch/csrc/profiler/kineto_shim.cpp +++ b/torch/csrc/profiler/kineto_shim.cpp @@ -61,6 +61,7 @@ const ActivityTypeMap kMtiaTypes{ {libkineto::ActivityType::MTIA_CCP_EVENTS, "MTIA_CCP_EVENTS"}, {libkineto::ActivityType::MTIA_RUNTIME, "MTIA_RUNTIME"}, {libkineto::ActivityType::MTIA_INSIGHT, "MTIA_INSIGHT"}, + {libkineto::ActivityType::MTIA_COUNTERS, "MTIA_COUNTERS"}, }; const ActivityTypeMap kHpuTypes{ @@ -356,6 +357,12 @@ void prepareTrace( } else { LOG(INFO) << "Disabling MTIA insight events"; } + if (config.custom_profiler_config.find("disable_counter_events") == + std::string::npos) { + k_activities.insert(libkineto::ActivityType::MTIA_COUNTERS); + } else { + LOG(INFO) << "Disabling MTIA counter events"; + } } } if (activities.count(torch::autograd::profiler::ActivityType::HPU)) { @@ -490,6 +497,7 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) { // TODO: T151322015 case libkineto::ActivityType::MTIA_CCP_EVENTS: case libkineto::ActivityType::MTIA_INSIGHT: + case libkineto::ActivityType::MTIA_COUNTERS: return device_type_privateuse1_or(c10::DeviceType::MTIA); case libkineto::ActivityType::HPU_OP: return c10::DeviceType::HPU; diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index 2f3d1fce740ce..379c065b22b07 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -486,7 +486,10 @@ convertIValue( itemsize, device_str); return std::make_tuple( - tensor_shape, tensor_stride, tensor_type, tensor_value); + std::move(tensor_shape), + std::move(tensor_stride), + std::move(tensor_type), + std::move(tensor_value)); } else if (val.isTuple()) { const auto& val_tuple = val.toTupleRef().elements(); size_t tuple_size = val_tuple.size(); @@ -494,6 +497,10 @@ convertIValue( std::vector stride_array; std::vector type_array; std::vector value_array; + shape_array.reserve(tuple_size); + stride_array.reserve(tuple_size); + type_array.reserve(tuple_size); + value_array.reserve(tuple_size); for (const auto j : c10::irange(tuple_size)) { auto tuple = convertIValue( ob, @@ -505,17 +512,17 @@ convertIValue( val_tuple[j], false, maxArrayLen); - shape_array.push_back(std::get<0>(tuple)); - stride_array.push_back(std::get<1>(tuple)); - type_array.push_back(std::get<2>(tuple)); - value_array.push_back(std::get<3>(tuple)); + shape_array.push_back(std::move(std::get<0>(tuple))); + stride_array.push_back(std::move(std::get<1>(tuple))); + type_array.push_back(std::move(std::get<2>(tuple))); + value_array.push_back(std::move(std::get<3>(tuple))); } type = type + vectorToString(type_array); std::string tensor_type = baseType ? fmt::format("\"{}\"", type) : type; return std::make_tuple( vectorToString(shape_array), vectorToString(stride_array), - tensor_type, + std::move(tensor_type), vectorToString(value_array)); } else if (val.isList()) { const auto& val_list = val.toList(); @@ -524,6 +531,11 @@ convertIValue( std::vector stride_array; std::vector type_array; std::vector value_array; + const size_t effective_list_size = std::min(list_size, maxArrayLen + 1); + shape_array.reserve(effective_list_size); + stride_array.reserve(effective_list_size); + type_array.reserve(effective_list_size); + value_array.reserve(effective_list_size); for (const auto j : c10::irange(list_size)) { auto tuple = convertIValue( ob, @@ -535,10 +547,10 @@ convertIValue( val_list.get(j), false, maxArrayLen); - shape_array.push_back(std::get<0>(tuple)); - stride_array.push_back(std::get<1>(tuple)); - type_array.push_back(std::get<2>(tuple)); - value_array.push_back(std::get<3>(tuple)); + shape_array.push_back(std::move(std::get<0>(tuple))); + stride_array.push_back(std::move(std::get<1>(tuple))); + type_array.push_back(std::move(std::get<2>(tuple))); + value_array.push_back(std::move(std::get<3>(tuple))); if (j >= maxArrayLen) { LOG(WARNING) << "list size=" << val_list.size() << " exceeded maxArrayLen=" << maxArrayLen; @@ -550,7 +562,7 @@ convertIValue( return std::make_tuple( vectorToString(shape_array), vectorToString(stride_array), - tensor_type, + std::move(tensor_type), vectorToString(value_array)); } else { std::string tensor_shape = "[]"; @@ -559,7 +571,10 @@ convertIValue( std::string tensor_value = getScalarValue(val); return std::make_tuple( - tensor_shape, tensor_stride, tensor_type, tensor_value); + std::move(tensor_shape), + std::move(tensor_stride), + std::move(tensor_type), + std::move(tensor_value)); } } diff --git a/torch/csrc/profiler/standalone/privateuse1_profiler.h b/torch/csrc/profiler/standalone/privateuse1_profiler.h index 63e2d9bf982c2..0cf5d4b2357fa 100644 --- a/torch/csrc/profiler/standalone/privateuse1_profiler.h +++ b/torch/csrc/profiler/standalone/privateuse1_profiler.h @@ -63,11 +63,6 @@ class TORCH_API PrivateUse1ProfilerRegistry { // Useful for testing to verify the registration logic. bool isRegisteredWithKineto() const; - // Register the factory with Kineto's activity profiler. - // This is called internally when Kineto is ready. - // Safe to call multiple times - will only register once. - void registerWithKineto(); - // Mark that Kineto has been initialized. // If a factory was registered before Kineto init, it will be forwarded. void onKinetoInit(); @@ -75,6 +70,10 @@ class TORCH_API PrivateUse1ProfilerRegistry { private: PrivateUse1ProfilerRegistry() = default; + // Register the factory with Kineto's activity profiler. + // Caller must hold mutex_. + void registerWithKineto(); + mutable std::mutex mutex_; PrivateUse1ProfilerFactory factory_; bool registered_with_kineto_ = false; diff --git a/torch/csrc/profiler/unwind/action.h b/torch/csrc/profiler/unwind/action.h index 5a982cfd046a0..7dc6c6c0b8166 100644 --- a/torch/csrc/profiler/unwind/action.h +++ b/torch/csrc/profiler/unwind/action.h @@ -11,7 +11,8 @@ enum { A_REG_PLUS_DATA_DEREF = 0x3 // exp = *(REG[reg] + data0) }; -// register numbers in dwarf info +// DWARF register numbers — architecture-specific +#if defined(__x86_64__) enum { D_UNDEFINED = -1, D_RBP = 6, @@ -19,6 +20,28 @@ enum { D_RIP = 16, D_REG_SIZE = 17, }; +static constexpr int D_FRAME_PTR = D_RBP; +static constexpr int D_STACK_PTR = D_RSP; +static constexpr int D_RET_ADDR = D_RIP; +static constexpr int D_EXPECTED_RA_REG = 16; +#elif defined(__aarch64__) +enum { + D_UNDEFINED = -1, + D_FP = 29, + D_LR = 30, + D_SP = 31, + D_REG_SIZE = 32, +}; +static constexpr int D_FRAME_PTR = D_FP; +static constexpr int D_STACK_PTR = D_SP; +static constexpr int D_RET_ADDR = D_LR; +static constexpr int D_EXPECTED_RA_REG = 30; +#else +enum { + D_UNDEFINED = -1, + D_REG_SIZE = 1, +}; +#endif struct Action { uint8_t kind = A_UNDEFINED; diff --git a/torch/csrc/profiler/unwind/dwarf_enums.h b/torch/csrc/profiler/unwind/dwarf_enums.h index 91af24b34e1f9..3effefd1f0083 100644 --- a/torch/csrc/profiler/unwind/dwarf_enums.h +++ b/torch/csrc/profiler/unwind/dwarf_enums.h @@ -30,6 +30,7 @@ enum { DW_CFA_advance_loc1 = 0x02, DW_CFA_advance_loc2 = 0x03, DW_CFA_advance_loc4 = 0x04, + DW_CFA_offset_extended = 0x05, DW_CFA_restore_extended = 0x06, DW_CFA_undefined = 0x07, DW_CFA_register = 0x09, diff --git a/torch/csrc/profiler/unwind/fde.h b/torch/csrc/profiler/unwind/fde.h index ffb06b5ab1f46..70b829a0f8023 100644 --- a/torch/csrc/profiler/unwind/fde.h +++ b/torch/csrc/profiler/unwind/fde.h @@ -63,8 +63,10 @@ struct FDE { } else { ra_register_ = static_cast(LC.readULEB128()); } - // we assume this in the state - TORCH_INTERNAL_ASSERT(ra_register_ == 16, "unexpected number of registers"); + TORCH_INTERNAL_ASSERT( + ra_register_ == D_EXPECTED_RA_REG, + "unexpected ra register: ", + ra_register_); if (augmentation_string_ && *augmentation_string_ == 'z') { augmentation_length_ = static_cast(LC.readULEB128()); Lexer A(LC.loc()); @@ -271,6 +273,11 @@ struct FDE { auto delta = L.read(); return advance_loc(delta); } + case DW_CFA_offset_extended: { + auto reg = L.readULEB128(); + auto off = L.readULEB128(); + return offset(reg, off); + } case DW_CFA_restore_extended: { auto reg = L.readULEB128(); return restore(reg); diff --git a/torch/csrc/profiler/unwind/unwind.cpp b/torch/csrc/profiler/unwind/unwind.cpp index bdb610627c897..b1696e34f24ce 100644 --- a/torch/csrc/profiler/unwind/unwind.cpp +++ b/torch/csrc/profiler/unwind/unwind.cpp @@ -3,32 +3,32 @@ #include #include -#if !defined(__linux__) || !defined(__x86_64__) || !defined(__has_include) || \ - !__has_include("ext/stdio_filebuf.h") +#if !defined(__linux__) || !(defined(__x86_64__) || defined(__aarch64__)) || \ + !defined(__has_include) || !__has_include("ext/stdio_filebuf.h") namespace torch::unwind { std::vector unwind() { TORCH_WARN_ONCE( - "record_context_cpp is not support on non-linux non-x86_64 platforms"); + "record_context_cpp is not supported on this platform (requires linux x86_64 or aarch64)"); return {}; } std::optional> libraryFor(void* addr) { TORCH_WARN_ONCE( - "record_context_cpp is not support on non-linux non-x86_64 platforms"); + "record_context_cpp is not supported on this platform (requires linux x86_64 or aarch64)"); return {}; } #ifndef FBCODE_CAFFE2 std::vector symbolize(const std::vector& frames, Mode mode) { TORCH_WARN_ONCE( - "record_context_cpp is not support on non-linux non-x86_64 platforms"); + "record_context_cpp is not supported on this platform (requires linux x86_64 or aarch64)"); return {}; } #endif Stats stats() { TORCH_WARN_ONCE( - "record_context_cpp is not support on non-linux non-x86_64 platforms"); + "record_context_cpp is not supported on this platform (requires linux x86_64 or aarch64)"); return {}; } @@ -41,8 +41,10 @@ Stats stats() { #include #include #include +#include #include #include +#include #include #include @@ -53,7 +55,14 @@ Stats stats() { #include #include +#if defined(__aarch64__) +extern "C" void unwind_c( + std::vector* result, + uintptr_t fp, + uintptr_t lr); +#else extern "C" void unwind_c(std::vector* result, int64_t rsp, int64_t rbp); +#endif extern "C" void unwind_entry(std::vector* result); namespace torch::unwind { @@ -98,7 +107,8 @@ struct LibraryInfo { void* fde_data = eh_frame_hdr_.entryForAddr(addr); FDE fde(fde_data, name().c_str(), load_bias()); TableState state = fde.readUpTo(addr); - return Unwinder(state.cfa, state.registers[D_RIP], state.registers[D_RBP]); + return Unwinder( + state.cfa, state.registers[D_RET_ADDR], state.registers[D_FRAME_PTR]); } const std::string& name() const { return name_; @@ -501,6 +511,111 @@ Stats stats() { } // namespace torch::unwind +#if defined(__aarch64__) +// aarch64 uses frame-pointer chain walking instead of DWARF unwinding. +// Each frame has: *(FP) = caller's FP, *(FP+8) = saved LR (return address). +// This is simpler and avoids issues with tail calls producing stale x30 +// values in DWARF-based unwinding. GCC/Clang on aarch64 emit frame +// pointers by default even at -O2. +// +// External libraries (CPython, libc, CUDA runtime) may be built without +// frame pointers, making x29 an arbitrary callee-saved value. We obtain +// the current thread's stack bounds and reject any fp outside that range +// to avoid dereferencing garbage pointers. +// +// No cache_mutex_ needed: frame-pointer walking reads only the stack, +// unlike the x86 path which queries the DWARF FDE cache. + +// Stack bounds are essentially immutable over a thread's lifetime (base and +// size are fixed at pthread_create for spawned threads). On glibc aarch64 +// the main thread is the pathological case: pthread_getattr_np() parses +// /proc/self/maps on every call because the main-thread stack isn't recorded +// in TLS at pthread_create. Cache once per thread to avoid that parse on +// every unwind. This mirrors the process-global caches the x86 unwinder +// uses for /proc-derived data (library list, exe path); here the data is +// per-thread, so thread_local is the right scope. +// +// Edge case: main thread's stack can grow up to RLIMIT_STACK, and this cache +// freezes the bounds at first observation. If the stack later grows past +// cached lo, unwind_c will terminate early on those frames rather than +// follow them - truncated backtrace, not incorrect. In practice main-thread +// stack reaches steady-state depth during init, long before heavy unwinding. +namespace { +struct StackBounds { + uintptr_t lo = 0; + uintptr_t hi = 0; + bool initialized = false; +}; +thread_local StackBounds tls_stack_bounds; +} // namespace + +static bool get_stack_bounds(uintptr_t& lo, uintptr_t& hi) { + auto& b = tls_stack_bounds; + if (!b.initialized) { + pthread_attr_t attr; + if (pthread_getattr_np(pthread_self(), &attr) != 0) { + return false; + } + void* base = nullptr; + size_t size = 0; + int rc = pthread_attr_getstack(&attr, &base, &size); + pthread_attr_destroy(&attr); + if (rc != 0) { + return false; + } + b.lo = reinterpret_cast(base); + b.hi = b.lo + size; + b.initialized = true; + } + lo = b.lo; + hi = b.hi; + return true; +} + +extern "C" C10_USED void unwind_c( + std::vector* result, + uintptr_t fp, + uintptr_t lr) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) + result->push_back((void*)lr); + + uintptr_t stack_lo = 0, stack_hi = 0; + if (!get_stack_bounds(stack_lo, stack_hi)) { + return; + } + + constexpr int kMaxFrames = 4096; + int depth = 0; + while (fp != 0 && (fp & 0xF) == 0 && depth++ < kMaxFrames) { + if (fp < stack_lo || fp + 16 > stack_hi) { + break; + } + uintptr_t saved_lr; + std::memcpy( + &saved_lr, reinterpret_cast(fp + 8), sizeof(saved_lr)); + if (saved_lr == 0) { + break; + } + // NOLINTNEXTLINE(performance-no-int-to-ptr) + result->push_back((void*)saved_lr); + uintptr_t next_fp; + std::memcpy(&next_fp, reinterpret_cast(fp), sizeof(next_fp)); + if (next_fp <= fp) { + break; + } + fp = next_fp; + } +} + +// x0 already holds the result pointer. +// Pass FP (x29) and LR (x30), then tail-call unwind_c. +__asm__( + ".global unwind_entry\n" + "unwind_entry:\n" + "mov x1, x29\n" + "mov x2, x30\n" + "b unwind_c\n"); +#else extern "C" C10_USED void unwind_c( std::vector* result, int64_t rsp, @@ -508,17 +623,17 @@ extern "C" C10_USED void unwind_c( std::shared_lock lock(torch::unwind::cache_mutex_); torch::unwind::UnwindState state{}; // NOLINTNEXTLINE(performance-no-int-to-ptr) - state.rip = *(int64_t*)rsp; + state.pc = *(int64_t*)rsp; // +8 because we saved rsp after the return address was already pushed // to the stack - state.rsp = rsp + 8; - state.rbp = rbp; + state.sp = rsp + 8; + state.fp = rbp; torch::unwind::unwind_cache.checkRefresh(lock); - while (true) { // unwind for _start sets rip as being undefined + while (true) { // NOLINTNEXTLINE(performance-no-int-to-ptr) - result->push_back((void*)state.rip); + result->push_back((void*)state.pc); const torch::unwind::Unwinder& uw = - torch::unwind::unwind_cache.unwinderFor(state.rip, lock); + torch::unwind::unwind_cache.unwinderFor(state.pc, lock); if (uw.terminator()) { if (uw.isUnknown()) { result->push_back(nullptr); @@ -529,8 +644,7 @@ extern "C" C10_USED void unwind_c( } } -// calling convention puts the first three pointer/int64_t arguments in -// rdi rsi rdx (all caller-saved) +// x86-64 calling convention: rdi rsi rdx (all caller-saved) // rdi already holds the pointer to the result vector // we add arguments for current rsp and rbp and then tail call // into unwind_c @@ -540,5 +654,6 @@ __asm__( "mov %rsp, %rsi;\n" "mov %rbp, %rdx;\n" "jmp unwind_c;\n"); +#endif #endif diff --git a/torch/csrc/profiler/unwind/unwinder.h b/torch/csrc/profiler/unwind/unwinder.h index d673f47af8db2..3836203688530 100644 --- a/torch/csrc/profiler/unwind/unwinder.h +++ b/torch/csrc/profiler/unwind/unwinder.h @@ -6,30 +6,32 @@ namespace torch::unwind { +// Architecture-neutral names: pc (program counter / return address), +// fp (frame pointer: x86 RBP, aarch64 x29), sp (stack pointer). struct UnwindState { - int64_t rip, rbp, rsp; + int64_t pc, fp, sp; }; struct Unwinder { - Unwinder(Action rsp, Action rip, Action rbp) - : kind_(rip.kind == A_UNDEFINED ? END : STANDARD), - reg_(rsp.reg), - off_(rsp.data), - rip_off_(rip.data), - rbp_off_( - rbp.kind == A_UNDEFINED ? std::numeric_limits::max() - : rbp.data), - deref_(rsp.kind == A_REG_PLUS_DATA_DEREF) { - check(rsp.reg == D_RSP || rsp.reg == D_RBP); - check(rip.kind == A_UNDEFINED || rip.kind == A_LOAD_CFA_OFFSET); - if (rsp.kind == A_REG_PLUS_DATA) { - check(rbp.kind == A_LOAD_CFA_OFFSET || rbp.kind == A_UNDEFINED); - } else if (rsp.kind == A_REG_PLUS_DATA_DEREF) { - if (rbp.kind == A_REG_PLUS_DATA_DEREF) { - check(rbp.reg == rsp.reg); - rbp_off_ -= rsp.data; + Unwinder(Action cfa, Action ret, Action fp) + : kind_(ret.kind == A_UNDEFINED ? END : STANDARD), + reg_(cfa.reg), + off_(cfa.data), + ret_off_(ret.data), + fp_off_( + fp.kind == A_UNDEFINED ? std::numeric_limits::max() + : fp.data), + deref_(cfa.kind == A_REG_PLUS_DATA_DEREF) { + check(cfa.reg == D_STACK_PTR || cfa.reg == D_FRAME_PTR); + check(ret.kind == A_UNDEFINED || ret.kind == A_LOAD_CFA_OFFSET); + if (cfa.kind == A_REG_PLUS_DATA) { + check(fp.kind == A_LOAD_CFA_OFFSET || fp.kind == A_UNDEFINED); + } else if (cfa.kind == A_REG_PLUS_DATA_DEREF) { + if (fp.kind == A_REG_PLUS_DATA_DEREF) { + check(fp.reg == cfa.reg); + fp_off_ -= cfa.data; } else { - check(rbp.kind == A_UNDEFINED); + check(fp.kind == A_UNDEFINED); } } else { check(false); @@ -53,28 +55,28 @@ struct Unwinder { } UnwindState run(const UnwindState& cur) const { UnwindState r = cur; - r.rsp = (reg_ == D_RSP ? cur.rsp : cur.rbp) + off_; - r.rbp = rbp_off_ == std::numeric_limits::max() - ? cur.rbp + r.sp = (reg_ == D_STACK_PTR ? cur.sp : cur.fp) + off_; + r.fp = fp_off_ == std::numeric_limits::max() + ? cur.fp // NOLINTNEXTLINE(performance-no-int-to-ptr) - : *(int64_t*)(r.rsp + rbp_off_); + : *(int64_t*)(r.sp + fp_off_); if (deref_) { // NOLINTNEXTLINE(performance-no-int-to-ptr) - r.rsp = *(int64_t*)r.rsp; + r.sp = *(int64_t*)r.sp; } // NOLINTNEXTLINE(performance-no-int-to-ptr) - r.rip = *(int64_t*)(r.rsp + rip_off_); + r.pc = *(int64_t*)(r.sp + ret_off_); return r; } private: - Unwinder() : kind_(UNKNOWN), reg_(0), off_(0), rip_off_(0), rbp_off_(0) {} + Unwinder() : kind_(UNKNOWN), reg_(0), off_(0), ret_off_(0), fp_off_(0) {} enum Kind { STANDARD, END, UNKNOWN } kind_; uint32_t reg_; int64_t off_; - int64_t rip_off_; - int64_t rbp_off_; + int64_t ret_off_; + int64_t fp_off_; bool deref_{false}; }; diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp index 541b6b79c56ce..e9a77287121f4 100644 --- a/torch/csrc/profiler/util.cpp +++ b/torch/csrc/profiler/util.cpp @@ -10,6 +10,7 @@ #include #endif #ifdef USE_DISTRIBUTED +#include #include #endif // USE_DISTRIBUTED @@ -513,6 +514,15 @@ std::unordered_map saveNcclMeta( auto seqNum = debugInfo->getSequenceNumber(); if (seqNum >= 0) { map.emplace(kSeqNum, std::to_string(seqNum)); + + size_t comms_id = c10::get_hash( + debugInfo->getProcessGroupName(), + seqNum, + debugInfo->getIsP2P(), + globalRankStart, + globalRankStride, + debugInfo->getWorldSize()); + map.emplace(kCommsId, std::to_string(comms_id)); } } diff --git a/torch/csrc/profiler/util.h b/torch/csrc/profiler/util.h index eed931cbf5a9c..a6b8248132800 100644 --- a/torch/csrc/profiler/util.h +++ b/torch/csrc/profiler/util.h @@ -205,6 +205,7 @@ constexpr auto kSeqNum = "Seq"; constexpr auto kInTensorsStart = "Input Tensors start"; constexpr auto kOutTensorsStart = "Output Tensors start"; constexpr auto kIsAsynchronizedOp = "Is asynchronized op"; +constexpr auto kCommsId = "Comms Id"; #endif // USE_DISTRIBUTED } // namespace torch::profiler::impl diff --git a/torch/csrc/python_dimname.cpp b/torch/csrc/python_dimname.cpp index 07f604600b22b..a2de89811f120 100644 --- a/torch/csrc/python_dimname.cpp +++ b/torch/csrc/python_dimname.cpp @@ -57,7 +57,7 @@ void InternedStringsTable::addMapping(PyObject* obj, at::Dimname dimname) { } // namespace torch bool THPUtils_checkDimname(PyObject* obj) { - return obj == Py_None || THPUtils_checkString(obj); + return Py_IsNone(obj) || THPUtils_checkString(obj); } // To avoid ambiguity with IntArrayRef, we parse obj as a DimnameList if @@ -78,7 +78,7 @@ bool THPUtils_checkDimnameList(PyObject* obj) { } at::Dimname THPDimname_parse(PyObject* obj) { - if (obj == Py_None) { + if (Py_IsNone(obj)) { return at::Dimname::wildcard(); } diff --git a/torch/csrc/utils/device_lazy_init.cpp b/torch/csrc/utils/device_lazy_init.cpp index b1338fc43e913..fc5c7a47b8b82 100644 --- a/torch/csrc/utils/device_lazy_init.cpp +++ b/torch/csrc/utils/device_lazy_init.cpp @@ -12,24 +12,21 @@ namespace torch::utils { namespace { -std::array is_initialized{}; -std::array is_in_bad_fork{}; -std::array at_fork_registered{}; +std::array, at::COMPILE_TIME_MAX_DEVICE_TYPES> + is_initialized{}; +std::array, at::COMPILE_TIME_MAX_DEVICE_TYPES> + is_in_bad_fork{}; +std::array, at::COMPILE_TIME_MAX_DEVICE_TYPES> + at_fork_registered{}; c10::once_flag at_fork_register_once{}; } // anonymous namespace bool is_device_initialized(at::DeviceType device_type) { - pybind11::gil_scoped_acquire g; return is_initialized[static_cast(device_type)]; } void device_lazy_init(at::DeviceType device_type) { - pybind11::gil_scoped_acquire g; - // Protected by the GIL. We don't use call_once because under ASAN it - // has a buggy implementation that deadlocks if an instance throws an - // exception. In any case, call_once isn't necessary, because we - // have taken a lock. if (is_device_initialized(device_type)) { return; } @@ -40,6 +37,10 @@ void device_lazy_init(at::DeviceType device_type) { return; } + // Don't use call_once because under ASAN it has a buggy implementation that + // deadlocks if an instance throws an exception and Python _lazy_init() + // functions are idempotent. + pybind11::gil_scoped_acquire g; std::string module_name = "torch." + at::DeviceTypeName(device_type, true); auto module = THPObjectPtr(PyImport_ImportModule(module_name.c_str())); if (!module) { @@ -50,6 +51,7 @@ void device_lazy_init(at::DeviceType device_type) { auto has_lazy_init_method = PyObject_HasAttrString(module.get(), "_lazy_init") == 1; if (!has_lazy_init_method) { + is_initialized[static_cast(device_type)] = true; return; } } diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index d75c0351fb6c4..0394cfce0b424 100644 --- a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -323,6 +323,15 @@ auto check_has_torch_function(PyObject* obj, bool ignore_mode) -> bool { !THPVariable_CheckTypeExact(tp) && !is_basic_python_type(tp) && torch::torch_function_enabled() && has_torch_function_attr(obj)); } + +bool has_torch_function(c10::ArrayRef args) { + for (const auto obj : args) { + if (has_torch_function(obj)) { + return true; + } + } + return false; +} } // namespace torch inline static bool sequence_has_torch_function(PyObject* args) { diff --git a/torch/csrc/utils/disable_torch_function.h b/torch/csrc/utils/disable_torch_function.h index b52173c252a88..c735fbbc560ed 100644 --- a/torch/csrc/utils/disable_torch_function.h +++ b/torch/csrc/utils/disable_torch_function.h @@ -1,6 +1,7 @@ #pragma once #include #include +#include #include namespace torch { @@ -18,6 +19,11 @@ void set_disabled_torch_dispatch_impl(PyObject* value); // overloaded list even if they don't actually have __torch_function__ bool check_has_torch_function(PyObject* obj, bool ignore_mode = false); +inline bool has_torch_function(PyObject* obj) { + return check_has_torch_function(obj); +} +bool has_torch_function(c10::ArrayRef args); + struct DisableTorchDispatch { DisableTorchDispatch() : guard_(c10::DispatchKeySet( diff --git a/torch/csrc/utils/generated_serialization_types.h b/torch/csrc/utils/generated_serialization_types.h index f4614ac0c191a..995e2405cb4da 100644 --- a/torch/csrc/utils/generated_serialization_types.h +++ b/torch/csrc/utils/generated_serialization_types.h @@ -1,5 +1,5 @@ // @generated by update_schema.py -// checksum<> +// checksum<> // clang-format off #pragma once @@ -289,6 +289,8 @@ enum class ScalarType { FLOAT8E4M3FNUZ = 31, FLOAT8E5M2FNUZ = 32, FLOAT8E8M0FNU = 33, + UINT32 = 34, + UINT64 = 35, }; inline std::string_view printEnum(const ScalarType& e) { @@ -313,6 +315,8 @@ inline std::string_view printEnum(const ScalarType& e) { case ScalarType::FLOAT8E4M3FNUZ: return "FLOAT8E4M3FNUZ"; case ScalarType::FLOAT8E5M2FNUZ: return "FLOAT8E5M2FNUZ"; case ScalarType::FLOAT8E8M0FNU: return "FLOAT8E8M0FNU"; + case ScalarType::UINT32: return "UINT32"; + case ScalarType::UINT64: return "UINT64"; default: throw std::runtime_error("Unknown enum value"); } @@ -339,6 +343,8 @@ inline void parseEnum(std::string_view s, ScalarType& t) { if (s == "FLOAT8E4M3FNUZ") { t = ScalarType::FLOAT8E4M3FNUZ; return; } if (s == "FLOAT8E5M2FNUZ") { t = ScalarType::FLOAT8E5M2FNUZ; return; } if (s == "FLOAT8E8M0FNU") { t = ScalarType::FLOAT8E8M0FNU; return; } + if (s == "UINT32") { t = ScalarType::UINT32; return; } + if (s == "UINT64") { t = ScalarType::UINT64; return; } throw std::runtime_error("Unknown enum value: " + std::string{s}); } @@ -1235,11 +1241,11 @@ class Argument { public: enum class Tag { - AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR, AS_SYM_FLOAT, AS_SYM_FLOATS, AS_OPTIONAL_TENSOR, AS_COMPLEX, AS_NESTED_TENSORS, AS_INT_LISTS, AS_STRING_TO_ARGUMENT + AS_NONE, AS_TENSOR, AS_TENSORS, AS_INT, AS_INTS, AS_FLOAT, AS_FLOATS, AS_STRING, AS_STRINGS, AS_SYM_INT, AS_SYM_INTS, AS_SCALAR_TYPE, AS_MEMORY_FORMAT, AS_LAYOUT, AS_DEVICE, AS_BOOL, AS_BOOLS, AS_SYM_BOOL, AS_SYM_BOOLS, AS_GRAPH, AS_OPTIONAL_TENSORS, AS_CUSTOM_OBJ, AS_OPERATOR, AS_SYM_FLOAT, AS_SYM_FLOATS, AS_OPTIONAL_TENSOR, AS_COMPLEX, AS_NESTED_TENSORS, AS_INT_LISTS, AS_STRING_TO_ARGUMENT, AS_FLOAT_LISTS }; private: - std::variant, int64_t, std::vector, F64, std::vector, std::string, std::vector, SymIntArgument, std::vector, ScalarType, MemoryFormat, Layout, Device, bool, std::vector, SymBoolArgument, std::vector, GraphArgument, std::vector, CustomObjArgument, std::string, SymFloatArgument, std::vector, OptionalTensorArgument, ComplexValue, std::vector>, std::vector>, std::unordered_map>> variant_; + std::variant, int64_t, std::vector, F64, std::vector, std::string, std::vector, SymIntArgument, std::vector, ScalarType, MemoryFormat, Layout, Device, bool, std::vector, SymBoolArgument, std::vector, GraphArgument, std::vector, CustomObjArgument, std::string, SymFloatArgument, std::vector, OptionalTensorArgument, ComplexValue, std::vector>, std::vector>, std::unordered_map>, std::vector>> variant_; Tag tag_; public: @@ -1517,6 +1523,15 @@ class Argument { tag_ = Tag::AS_STRING_TO_ARGUMENT; } + const std::vector>& get_as_float_lists() const { + return std::get<31>(variant_); + } + + void set_as_float_lists(std::vector> def) { + variant_.emplace<31>(std::move(def)); + tag_ = Tag::AS_FLOAT_LISTS; + } + friend void to_json(nlohmann::json& nlohmann_json_j, const Argument& nlohmann_json_t) { if (nlohmann_json_t.tag_ == Tag::AS_NONE) { @@ -1639,6 +1654,10 @@ class Argument { nlohmann_json_j["as_string_to_argument"] = nlohmann_json_t.get_as_string_to_argument(); return; } + if (nlohmann_json_t.tag_ == Tag::AS_FLOAT_LISTS) { + nlohmann_json_j["as_float_lists"] = nlohmann_json_t.get_as_float_lists(); + return; + } } friend void from_json(const nlohmann::json& nlohmann_json_j, Argument& nlohmann_json_t) { @@ -1793,6 +1812,11 @@ class Argument { nlohmann_json_t.tag_ = Tag::AS_STRING_TO_ARGUMENT; return; } + if (nlohmann_json_j.contains("as_float_lists")) { + nlohmann_json_t.variant_.emplace<31>(nlohmann_json_j.at("as_float_lists").template get>>()); + nlohmann_json_t.tag_ = Tag::AS_FLOAT_LISTS; + return; + } } }; @@ -1828,6 +1852,7 @@ inline std::string_view printEnum(const Argument::Tag& e) { case Argument::Tag::AS_NESTED_TENSORS: return "AS_NESTED_TENSORS"; case Argument::Tag::AS_INT_LISTS: return "AS_INT_LISTS"; case Argument::Tag::AS_STRING_TO_ARGUMENT: return "AS_STRING_TO_ARGUMENT"; + case Argument::Tag::AS_FLOAT_LISTS: return "AS_FLOAT_LISTS"; default: throw std::runtime_error("Unknown enum value"); } @@ -1864,6 +1889,7 @@ inline void parseEnum(std::string_view s, Argument::Tag& t) { if (s == "AS_NESTED_TENSORS") { t = Argument::Tag::AS_NESTED_TENSORS; return; } if (s == "AS_INT_LISTS") { t = Argument::Tag::AS_INT_LISTS; return; } if (s == "AS_STRING_TO_ARGUMENT") { t = Argument::Tag::AS_STRING_TO_ARGUMENT; return; } + if (s == "AS_FLOAT_LISTS") { t = Argument::Tag::AS_FLOAT_LISTS; return; } throw std::runtime_error("Unknown enum value: " + std::string{s}); } diff --git a/torch/csrc/utils/invalid_arguments.cpp b/torch/csrc/utils/invalid_arguments.cpp index a0a85956442d5..fa7c5494851f4 100644 --- a/torch/csrc/utils/invalid_arguments.cpp +++ b/torch/csrc/utils/invalid_arguments.cpp @@ -52,7 +52,7 @@ struct NullableType : public Type { NullableType(std::unique_ptr type) : type(std::move(type)) {} bool is_matching(PyObject* object) override { - return object == Py_None || type->is_matching(object); + return Py_IsNone(object) || type->is_matching(object); } std::unique_ptr type; diff --git a/torch/csrc/utils/pyobject_preservation.cpp b/torch/csrc/utils/pyobject_preservation.cpp index a652cbdb7aefd..e61b835630174 100644 --- a/torch/csrc/utils/pyobject_preservation.cpp +++ b/torch/csrc/utils/pyobject_preservation.cpp @@ -1,67 +1 @@ #include - -#include -#include - -namespace torch::utils { - -using c10::intrusive_ptr_target; -using c10::impl::PyObjectSlot; - -void PyObjectPreservation::init_fresh_nonatomic( - intrusive_ptr_target* target, - PyObjectSlot* slot, - PyObject* pyobj) { - TORCH_INTERNAL_ASSERT(slot->load_pyobj() == nullptr); - TORCH_INTERNAL_ASSERT( - target->combined_refcount_.load(std::memory_order_relaxed) == - c10::detail::kUniqueRef); - - slot->pyobj_.store(pyobj, std::memory_order_relaxed); - slot->pyobj_interpreter_.store( - c10::impl::getGlobalPyInterpreter(), std::memory_order_relaxed); - target->combined_refcount_.store( - c10::detail::kHasPyObject | c10::detail::kUniqueRef, - std::memory_order_relaxed); -} - -PyObject* PyObjectPreservation::init_once( - intrusive_ptr_target* target, - PyObjectSlot* slot, - PyObject* pyobj) { - PyObject* expected = nullptr; - if (!slot->pyobj_.compare_exchange_strong( - expected, pyobj, std::memory_order_acq_rel)) { - TORCH_INTERNAL_ASSERT(expected != nullptr); - return expected; - } - - slot->pyobj_interpreter_.store( - c10::impl::getGlobalPyInterpreter(), std::memory_order_release); - - bool increfed = false; - auto combined = target->combined_refcount_.load(std::memory_order_relaxed); - do { - TORCH_INTERNAL_ASSERT(!c10::detail::has_pyobject(combined)); - if (c10::detail::refcount(combined) > 1 && !increfed) { - // We need to incref the object to preserve the invariant that - // if refcount > 1, the c10 object holds a reference to the PyObject. - // This must happen before we set the kHasPyObject bit. - Py_INCREF(pyobj); - increfed = true; - } - } while (!target->combined_refcount_.compare_exchange_weak( - combined, - combined | c10::detail::kHasPyObject, - std::memory_order_acq_rel, - std::memory_order_relaxed)); - - if (increfed && c10::detail::refcount(combined) == 1) { - // Fix up if refcount if we did the incref in a failed compare-exchange - Py_DECREF(pyobj); - } - - return pyobj; -} - -} // namespace torch::utils diff --git a/torch/csrc/utils/pyobject_preservation.h b/torch/csrc/utils/pyobject_preservation.h index b060bc034b2c3..fdbc48b8d25a6 100644 --- a/torch/csrc/utils/pyobject_preservation.h +++ b/torch/csrc/utils/pyobject_preservation.h @@ -1,15 +1,12 @@ #pragma once #include +#include -// This file contains utilities used for handling PyObject preservation +#include +#include -namespace c10 { -class intrusive_ptr_target; -namespace impl { -struct PyObjectSlot; -} // namespace impl -} // namespace c10 +// This file contains utilities used for handling PyObject preservation namespace torch::utils { @@ -17,15 +14,98 @@ class PyObjectPreservation { public: // Store a PyObject wrapper on a fresh c10 wrapper. The caller must hold // a unique reference to `target`. - static void init_fresh_nonatomic( - c10::intrusive_ptr_target* target, - c10::impl::PyObjectSlot* slot, - PyObject* pyobj); - - static PyObject* init_once( - c10::intrusive_ptr_target* target, - c10::impl::PyObjectSlot* slot, - PyObject* pyobj); + template + requires requires(T& t) { + t.pyobj_slot(); + } + static void init_fresh_nonatomic(T& target, PyObject* pyobj) { + auto* slot = target.pyobj_slot(); + TORCH_INTERNAL_ASSERT(slot->load_pyobj() == nullptr); + TORCH_INTERNAL_ASSERT( + target.combined_refcount_.load(std::memory_order_relaxed) == + c10::detail::kUniqueRef); + + // Ensure that PyUnstable_TryIncref calls don't fail spuriously in + // free-threaded Python. + PyUnstable_EnableTryIncRef(pyobj); + + slot->pyobj_.store(pyobj, std::memory_order_relaxed); + slot->pyobj_interpreter_.store( + c10::impl::getGlobalPyInterpreter(), std::memory_order_relaxed); + target.combined_refcount_.store( + c10::detail::kHasPyObject | c10::detail::kUniqueRef, + std::memory_order_relaxed); + } + + // Thread-safe get-or-create for the PyObject wrapper. Returns a new + // reference. The factory is called at most once if no wrapper exists yet; + // if another thread races and wins, the factory's result is destroyed and + // the winner's wrapper is returned instead. + template + requires requires(T& t) { + t.pyobj_slot(); + } + static PyObject* get_or_init(T& target, Factory&& pyobj_factory) { + auto* slot = target.pyobj_slot(); + PyObject* obj = slot->load_pyobj(); + if (obj) { + return Py_NewRef(obj); + } + + obj = pyobj_factory(); + + // Ensure that PyUnstable_TryIncref calls don't fail spuriously in + // free-threaded Python. + PyUnstable_EnableTryIncRef(obj); + + // Fast path: if we're the only owner, no other thread can see this + // object, so we can skip the atomic CAS. + auto combined = target.combined_refcount_.load(std::memory_order_relaxed); + if (combined == c10::detail::kUniqueRef) { + slot->pyobj_.store(obj, std::memory_order_relaxed); + slot->pyobj_interpreter_.store( + c10::impl::getGlobalPyInterpreter(), std::memory_order_relaxed); + target.combined_refcount_.store( + c10::detail::kHasPyObject | c10::detail::kUniqueRef, + std::memory_order_relaxed); + return obj; + } + + // Slow path: atomically store our new wrapper into the slot. + slot->pyobj_interpreter_.store( + c10::impl::getGlobalPyInterpreter(), std::memory_order_release); + PyObject* expected = nullptr; + if (!slot->pyobj_.compare_exchange_strong( + expected, obj, std::memory_order_acq_rel)) { + // Another thread won the race — discard ours, use theirs. + Py_DECREF(obj); + return Py_NewRef(expected); + } + + // We won. Set the kHasPyObject bit in the combined refcount. + bool increfed = false; + do { + if (c10::detail::refcount(combined) > 1 && !increfed) { + // Preserve the invariant that if refcount > 1, the c10 object + // holds a reference to the PyObject. This must happen before we + // set the kHasPyObject bit. + Py_INCREF(obj); + increfed = true; + } + } while (!target.combined_refcount_.compare_exchange_weak( + combined, + combined | c10::detail::kHasPyObject, + std::memory_order_acq_rel, + std::memory_order_relaxed)); + + if (increfed && c10::detail::refcount(combined) == 1) { + // We incref'd because refcount was > 1 during an earlier CAS attempt, + // but by the time we succeeded, refcount had dropped to 1. Undo. + Py_DECREF(obj); + } + + return obj; + } }; } // namespace torch::utils diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 999bd00b3bcd6..131b4d58f71a8 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -302,8 +302,22 @@ static py::object maybe_get_registered_torch_dispatch_rule( static bool is_dtensor(PyObject* obj) { #ifdef USE_DISTRIBUTED const py::handle dtensor = get_dtensor_class(); - return (PyObject*)Py_TYPE(obj) == dtensor.ptr() || - py::isinstance(py::handle(obj), dtensor); + if ((PyObject*)Py_TYPE(obj) == dtensor.ptr()) { + return true; + } + if (!py::isinstance(py::handle(obj), dtensor)) { + return false; + } + // DTensor subclass: only use the C++ fast path if it does not override + // __torch_dispatch__. Subclasses with a custom override should fall + // through to the normal Python dispatch path. + // Compare via __func__ because @classmethod descriptors create new bound + // method objects on each attr access, making direct identity checks fail. + static py::object base_td = + dtensor.attr("__torch_dispatch__").attr("__func__"); + py::object sub_td = + py::type::handle_of(obj).attr("__torch_dispatch__").attr("__func__"); + return sub_td.is(base_td); #else return false; #endif @@ -755,9 +769,12 @@ auto handle_torch_function_indexing( } py::object func = PyObject_FastGetAttrString(THPVariableClass, (char*)func_name); - py::object args = (val == nullptr) - ? py::make_tuple(py::handle(self), py::handle(index)) - : py::make_tuple(py::handle(self), py::handle(index), py::handle(val)); + py::tuple args; + if (val == nullptr) { + args = py::make_tuple(py::handle(self), py::handle(index)); + } else { + args = py::make_tuple(py::handle(self), py::handle(index), py::handle(val)); + } return handle_torch_function_no_python_arg_parser( overridable_args, args.ptr(), @@ -878,7 +895,7 @@ bool is_tensor_and_append_overloaded( static bool is_scalar_list( PyObject* obj, std::vector* overloaded_args = nullptr) { - auto tuple = six::isTuple(obj); + auto tuple = PyTuple_Check(obj); if (!(tuple || PyList_Check(obj))) { return false; } @@ -909,7 +926,7 @@ bool is_tensor_list_and_append_overloaded( std::vector* overloaded_args, size_t argnum, bool throw_error) { - auto tuple = six::isTuple(obj); + auto tuple = PyTuple_Check(obj); if (!(tuple || PyList_Check(obj))) { return false; } @@ -950,7 +967,7 @@ static bool is_float_or_symfloat(PyObject* obj) { static bool is_float_or_complex_list( PyObject* obj, std::vector* overloaded_args = nullptr) { - auto tuple = six::isTuple(obj); + auto tuple = PyTuple_Check(obj); if (!(tuple || PyList_Check(obj))) { return false; } @@ -1722,7 +1739,7 @@ bool FunctionSignature::parse( int64_t failed_idx = -1; bool varargs_eligible = allow_varargs_intlist && arg_pos == 0 && !is_kwd; - if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) { + if ((!obj && param.optional) || (Py_IsNone(obj) && param.allow_none)) { dst[i++] = nullptr; } else if (!obj) { if (raise_exception) { diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index 12b023c764749..1e040c0c5a9b5 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -65,7 +65,7 @@ #include #include #include -#include +#include #include #include @@ -420,14 +420,14 @@ inline at::Scalar PythonArgs::scalar(int i) { inline std::vector PythonArgs::scalarlist(int i) { if (!args[i]) return std::vector(); - auto tuple = six::isTuple(args[i]); - THPObjectPtr arg = six::maybeAsTuple(args[i]); + auto tuple = PyTuple_Check(args[i]); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tuple || PyList_Check(args[i])); // NOLINTNEXTLINE(bugprone-branch-clone) - auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + auto size = tuple ? PyTuple_GET_SIZE(args[i]) : PyList_GET_SIZE(args[i]); std::vector res(size); for (const auto idx : c10::irange(size)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) - : PyList_GET_ITEM(arg.get(), idx); + PyObject* obj = + tuple ? PyTuple_GET_ITEM(args[i], idx) : PyList_GET_ITEM(args[i], idx); res[idx] = scalar_slow(obj); } return res; @@ -450,14 +450,14 @@ inline std::optional PythonArgs::scalarOptional(int i) { inline std::vector PythonArgs::tensorlist(int i) { if (!args[i]) return std::vector(); - auto tuple = six::isTuple(args[i]); - THPObjectPtr arg = six::maybeAsTuple(args[i]); + auto tuple = PyTuple_Check(args[i]); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tuple || PyList_Check(args[i])); // NOLINTNEXTLINE(bugprone-branch-clone) - auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + auto size = tuple ? PyTuple_GET_SIZE(args[i]) : PyList_GET_SIZE(args[i]); std::vector res(size); for (const auto idx : c10::irange(size)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) - : PyList_GET_ITEM(arg.get(), idx); + PyObject* obj = + tuple ? PyTuple_GET_ITEM(args[i], idx) : PyList_GET_ITEM(args[i], idx); // This is checked by the argument parser so it's safe to cast without // checking if this is a tensor first res[idx] = THPVariable_Unpack(obj); @@ -469,15 +469,15 @@ inline torch::List> PythonArgs:: list_of_optional_tensors(int i) { if (!args[i]) return torch::List>(); - auto tuple = six::isTuple(args[i]); - THPObjectPtr arg = six::maybeAsTuple(args[i]); + auto tuple = PyTuple_Check(args[i]); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tuple || PyList_Check(args[i])); // NOLINTNEXTLINE(bugprone-branch-clone) - auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + auto size = tuple ? PyTuple_GET_SIZE(args[i]) : PyList_GET_SIZE(args[i]); torch::List> res; res.reserve(size); for (const auto idx : c10::irange(size)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) - : PyList_GET_ITEM(arg.get(), idx); + PyObject* obj = + tuple ? PyTuple_GET_ITEM(args[i], idx) : PyList_GET_ITEM(args[i], idx); // This is checked by the argument parser so it's safe to cast without // checking if this is a tensor first res.push_back(THPVariable_Unpack(obj)); @@ -490,18 +490,18 @@ inline std::array PythonArgs::tensorlist_n(int i) { auto res = std::array(); if (!args[i]) return res; - auto tuple = six::isTuple(args[i]); - THPObjectPtr arg = six::maybeAsTuple(args[i]); + auto tuple = PyTuple_Check(args[i]); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(tuple || PyList_Check(args[i])); // NOLINTNEXTLINE(bugprone-branch-clone) - auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get()); + auto size = tuple ? PyTuple_GET_SIZE(args[i]) : PyList_GET_SIZE(args[i]); if (size != N) { TORCH_CHECK_TYPE( false, fmt::format("expected tuple of {} elements but got {}", N, size)); } for (const auto idx : c10::irange(size)) { - PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) - : PyList_GET_ITEM(arg.get(), idx); + PyObject* obj = + tuple ? PyTuple_GET_ITEM(args[i], idx) : PyList_GET_ITEM(args[i], idx); // This is checked by the argument parser so it's safe to cast without // checking if this is a tensor first res[idx] = THPVariable_Unpack(obj); @@ -1097,10 +1097,10 @@ inline bool PythonArgs::toBool(int i) { if (!args[i]) { return signature.params[i].default_bool; } - if (args[i] == Py_True) { + if (Py_IsTrue(args[i])) { return true; } - if (args[i] == Py_False) { + if (Py_IsFalse(args[i])) { return false; } if (torch::is_symbool(py::handle(args[i]))) { diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index 0be7c4a3970af..c141aa928307a 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -30,15 +31,18 @@ namespace py = pybind11; namespace torch::impl::dispatch { // Global storage for leaked Python filenames to ensure they remain valid -// for the lifetime of Library objects -static std::vector leaked_python_filenames_; +// for the lifetime of Library objects. We use unique_ptr rather than +// plain string so that c_str() pointers handed to Library objects remain valid +// when the vector reallocates. +static c10::Synchronized>> + leaked_python_filenames_; // NB: I'd like to index this on OperatorHandle, but I can't, as I can't // guarantee that the main interpreter has finish doing all registrations before // the other interpreters start banging on it -static ska::flat_hash_map< +static c10::Synchronized>> + ska::flat_hash_map>>> python_registrations_; static torch::Library::Kind parseKind(const std::string& k) { @@ -424,10 +428,12 @@ void initDispatchBindings(PyObject* module) { std::make_unique( func, dispatch, with_keyset))), register_or_verify()); - python_registrations_[lib._resolve(name)].insert_or_assign( - dispatch, - std::make_shared( - func.release().ptr(), getPyInterpreter())); + python_registrations_.withLock([&](auto& regs) { + regs[lib._resolve(name)].insert_or_assign( + dispatch, + std::make_shared( + func.release().ptr(), getPyInterpreter())); + }); } END_HANDLE_TH_ERRORS_PYBIND }, @@ -512,8 +518,11 @@ void initDispatchBindings(PyObject* module) { HANDLE_TH_ERRORS // Store the file string in global storage to ensure it remains valid // for the lifetime of the Library object - leaked_python_filenames_.emplace_back(file); - const char* leaked_file = leaked_python_filenames_.back().c_str(); + const char* leaked_file = + leaked_python_filenames_.withLock([&](auto& filenames) { + filenames.push_back(std::make_unique(file)); + return filenames.back()->c_str(); + }); return std::make_unique( parseKind(kind), @@ -534,7 +543,7 @@ void initDispatchBindings(PyObject* module) { m.def( "_dispatch_clear_leaked_python_filenames", - []() { leaked_python_filenames_.clear(); }, + []() { leaked_python_filenames_.withLock([](auto& f) { f.clear(); }); }, "Clear the global storage of leaked Python filenames. " "WARNING: Only call this if you're sure no Library objects are still using the filenames."); @@ -1077,7 +1086,8 @@ void python_op_registration_trampoline_impl( auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); py::gil_scoped_acquire g; auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); - const auto& func = python_registrations_[op.operator_name()][key]; + auto func = python_registrations_.withLock( + [&](auto& regs) { return regs[op.operator_name()][key]; }); TORCH_INTERNAL_ASSERT(func != nullptr); auto* pyobj = func->ptr(getPyInterpreter()); TORCH_INTERNAL_ASSERT(pyobj != nullptr); diff --git a/torch/csrc/utils/python_numbers.h b/torch/csrc/utils/python_numbers.h index c7e6cc29bf783..48d7b12b16c4a 100644 --- a/torch/csrc/utils/python_numbers.h +++ b/torch/csrc/utils/python_numbers.h @@ -115,9 +115,9 @@ inline int64_t THPUtils_unpackIndex(PyObject* obj) { } inline bool THPUtils_unpackBool(PyObject* obj) { - if (obj == Py_True) { + if (Py_IsTrue(obj)) { return true; - } else if (obj == Py_False) { + } else if (Py_IsFalse(obj)) { return false; } else { TORCH_CHECK(false, "couldn't convert python object to boolean"); diff --git a/torch/csrc/utils/six.h b/torch/csrc/utils/six.h deleted file mode 100644 index 9671e5156b9d4..0000000000000 --- a/torch/csrc/utils/six.h +++ /dev/null @@ -1,52 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -namespace six { - -// Usually instances of PyStructSequence is also an instance of tuple -// but in some py2 environment it is not, so we have to manually check -// the name of the type to determine if it is a namedtupled returned -// by a pytorch operator. - -inline bool isStructSeq(pybind11::handle input) { - return pybind11::cast(pybind11::type::handle_of(input).attr( - "__module__")) == "torch.return_types"; -} - -inline bool isStructSeq(PyObject* obj) { - return isStructSeq(pybind11::handle(obj)); -} - -inline bool isTuple(pybind11::handle input) { - if (PyTuple_Check(input.ptr())) { - return true; - } - return false; -} - -inline bool isTuple(PyObject* obj) { - return isTuple(pybind11::handle(obj)); -} - -// maybeAsTuple: if the input is a structseq, then convert it to a tuple -// -// On Python 3, structseq is a subtype of tuple, so these APIs could be used -// directly. But on Python 2, structseq is not a subtype of tuple, so we need to -// manually create a new tuple object from structseq. -inline THPObjectPtr maybeAsTuple(PyStructSequence* obj) { - Py_INCREF(obj); - return THPObjectPtr((PyObject*)obj); -} - -inline THPObjectPtr maybeAsTuple(PyObject* obj) { - if (isStructSeq(obj)) - return maybeAsTuple((PyStructSequence*)obj); - Py_INCREF(obj); - return THPObjectPtr(obj); -} - -} // namespace six diff --git a/torch/csrc/utils/structseq.cpp b/torch/csrc/utils/structseq.cpp index 2e804aa44bad9..ded49128cf284 100644 --- a/torch/csrc/utils/structseq.cpp +++ b/torch/csrc/utils/structseq.cpp @@ -12,7 +12,7 @@ * https://github.com/python/cpython#copyright-and-license-information */ -#include +#include #include #include @@ -20,19 +20,12 @@ namespace torch::utils { -// NOTE: The built-in repr method from PyStructSequence was updated in -// https://github.com/python/cpython/commit/c70ab02df2894c34da2223fc3798c0404b41fd79 -// so this function might not be required in Python 3.8+. PyObject* returned_structseq_repr(PyStructSequence* obj) { PyTypeObject* typ = Py_TYPE(obj); - THPObjectPtr tup = six::maybeAsTuple(obj); - if (tup == nullptr) { - return nullptr; - } + Py_ssize_t num_elements = PyTuple_GET_SIZE(obj); std::stringstream ss; ss << typ->tp_name << "(\n"; - Py_ssize_t num_elements = Py_SIZE(obj); for (Py_ssize_t i = 0; i < num_elements; i++) { const char* cname = typ->tp_members[i].name; @@ -46,11 +39,7 @@ PyObject* returned_structseq_repr(PyStructSequence* obj) { return nullptr; } - PyObject* val = PyTuple_GetItem(tup.get(), i); - if (val == nullptr) { - return nullptr; - } - + PyObject* val = PyTuple_GET_ITEM(obj, i); auto repr = THPObjectPtr(PyObject_Repr(val)); if (repr == nullptr) { return nullptr; diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp index 39df9be68868a..938c7bc7558d0 100644 --- a/torch/csrc/utils/tensor_dtypes.cpp +++ b/torch/csrc/utils/tensor_dtypes.cpp @@ -20,17 +20,15 @@ void initializeDtypes() { for (at::ScalarType scalarType : all_scalar_types) { auto [primary_name, legacy_name] = c10::getDtypeNames(scalarType); - PyObject* dtype = THPDtype_New(scalarType, primary_name); - torch::registerDtypeObject((THPDtype*)dtype, scalarType); - Py_INCREF(dtype); - if (PyModule_AddObject(torch_module.get(), primary_name.c_str(), dtype) != - 0) { + THPObjectPtr dtype(THPDtype_New(scalarType, primary_name)); + torch::registerDtypeObject((THPDtype*)dtype.get(), scalarType); + if (PyModule_AddObjectRef( + torch_module.get(), primary_name.c_str(), dtype.get()) != 0) { throw python_error(); } if (!legacy_name.empty()) { - Py_INCREF(dtype); - if (PyModule_AddObject(torch_module.get(), legacy_name.c_str(), dtype) != - 0) { + if (PyModule_AddObjectRef( + torch_module.get(), legacy_name.c_str(), dtype.get()) != 0) { throw python_error(); } } diff --git a/torch/csrc/utils/tensor_layouts.cpp b/torch/csrc/utils/tensor_layouts.cpp index d0bccbcf9106f..3b1e04cb08a30 100644 --- a/torch/csrc/utils/tensor_layouts.cpp +++ b/torch/csrc/utils/tensor_layouts.cpp @@ -6,49 +6,36 @@ namespace torch::utils { -#define REGISTER_LAYOUT(layout, LAYOUT) \ - PyObject* layout##_layout = \ - THPLayout_New(at::Layout::LAYOUT, "torch." #layout); \ - Py_INCREF(layout##_layout); \ - if (PyModule_AddObject(torch_module, "" #layout, layout##_layout) != 0) { \ - throw python_error(); \ - } \ - registerLayoutObject((THPLayout*)layout##_layout, at::Layout::LAYOUT); +static void registerLayout( + PyObject* torch_module, + at::Layout layout, + const char* name, + const char* qualified_name) { + THPObjectPtr obj(THPLayout_New(layout, qualified_name)); + if (PyModule_AddObjectRef(torch_module, name, obj.get()) != 0) { + throw python_error(); + } + registerLayoutObject((THPLayout*)obj.get(), layout); +} void initializeLayouts() { auto torch_module = THPObjectPtr(PyImport_ImportModule("torch")); if (!torch_module) throw python_error(); - PyObject* strided_layout = - THPLayout_New(at::Layout::Strided, "torch.strided"); - Py_INCREF(strided_layout); - if (PyModule_AddObject(torch_module, "strided", strided_layout) != 0) { - throw python_error(); - } - registerLayoutObject((THPLayout*)strided_layout, at::Layout::Strided); - - PyObject* sparse_coo_layout = - THPLayout_New(at::Layout::Sparse, "torch.sparse_coo"); - Py_INCREF(sparse_coo_layout); - if (PyModule_AddObject(torch_module, "sparse_coo", sparse_coo_layout) != 0) { - throw python_error(); - } - registerLayoutObject((THPLayout*)sparse_coo_layout, at::Layout::Sparse); - - REGISTER_LAYOUT(sparse_csr, SparseCsr) - REGISTER_LAYOUT(sparse_csc, SparseCsc) - REGISTER_LAYOUT(sparse_bsr, SparseBsr) - REGISTER_LAYOUT(sparse_bsc, SparseBsc) - - PyObject* mkldnn_layout = THPLayout_New(at::Layout::Mkldnn, "torch._mkldnn"); - Py_INCREF(mkldnn_layout); - if (PyModule_AddObject(torch_module, "_mkldnn", mkldnn_layout) != 0) { - throw python_error(); - } - registerLayoutObject((THPLayout*)mkldnn_layout, at::Layout::Mkldnn); - - REGISTER_LAYOUT(jagged, Jagged); + registerLayout(torch_module, at::Layout::Strided, "strided", "torch.strided"); + registerLayout( + torch_module, at::Layout::Sparse, "sparse_coo", "torch.sparse_coo"); + registerLayout( + torch_module, at::Layout::SparseCsr, "sparse_csr", "torch.sparse_csr"); + registerLayout( + torch_module, at::Layout::SparseCsc, "sparse_csc", "torch.sparse_csc"); + registerLayout( + torch_module, at::Layout::SparseBsr, "sparse_bsr", "torch.sparse_bsr"); + registerLayout( + torch_module, at::Layout::SparseBsc, "sparse_bsc", "torch.sparse_bsc"); + registerLayout(torch_module, at::Layout::Mkldnn, "_mkldnn", "torch._mkldnn"); + registerLayout(torch_module, at::Layout::Jagged, "jagged", "torch.jagged"); } } // namespace torch::utils diff --git a/torch/csrc/utils/tensor_memoryformats.cpp b/torch/csrc/utils/tensor_memoryformats.cpp index c1a3ff326493a..0ba5c8bd842c5 100644 --- a/torch/csrc/utils/tensor_memoryformats.cpp +++ b/torch/csrc/utils/tensor_memoryformats.cpp @@ -31,14 +31,12 @@ void initializeMemoryFormats() { auto add_memory_format = [&](at::MemoryFormat format, const char* name) { std::string module_name = "torch."; - PyObject* memory_format = THPMemoryFormat_New(format, module_name + name); - Py_INCREF(memory_format); - if (PyModule_AddObject(torch_module, name, memory_format) != 0) { - Py_DECREF(memory_format); + THPObjectPtr memory_format(THPMemoryFormat_New(format, module_name + name)); + if (PyModule_AddObjectRef(torch_module, name, memory_format.get()) != 0) { throw python_error(); } - Py_INCREF(memory_format); - memory_format_registry[static_cast(format)] = memory_format; + memory_format_registry[static_cast(format)] = + memory_format.release(); }; add_memory_format(at::MemoryFormat::Preserve, "preserve_format"); diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 8b758df7224a9..97ae280dfe332 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -1382,26 +1382,16 @@ static void _validate_sparse_compressed_tensor_args_template( ARG_SIZE, ARGS_COUNT }; - static std::string sig; - switch (required_layout) { - case c10::Layout::SparseCsr: - sig = - "_validate_sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size)"; - break; - case c10::Layout::SparseCsc: - sig = - "_validate_sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size)"; - break; - case c10::Layout::SparseBsr: - sig = - "_validate_sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size)"; - break; - case c10::Layout::SparseBsc: - sig = - "_validate_sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size)"; - break; - default:; - } + constexpr const char* sig = [] { + if constexpr (required_layout == c10::Layout::SparseCsr) + return "_validate_sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size)"; + else if constexpr (required_layout == c10::Layout::SparseCsc) + return "_validate_sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size)"; + else if constexpr (required_layout == c10::Layout::SparseBsr) + return "_validate_sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size)"; + else if constexpr (required_layout == c10::Layout::SparseBsc) + return "_validate_sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size)"; + }(); static PythonArgParser parser({sig}); ParsedArgs parsed_args; @@ -1744,7 +1734,7 @@ Tensor asarray( std::optional dtype, std::optional device, std::optional copy, - bool requires_grad) { + std::optional requires_grad) { Tensor tensor; bool force_copy = copy.value_or(false); @@ -1761,9 +1751,17 @@ Tensor asarray( if (THPVariable_Check(obj)) { tensor = THPVariable_Unpack(obj); } + bool return_requires_grad = + requires_grad.value_or(tensor.defined() ? tensor.requires_grad() : false); + if (return_requires_grad && !requires_grad) { + TORCH_WARN_ONCE( + "torch.asarray: unspecified requires_grad now defaults to obj.requires_grad " + "instead of False. Pass requires_grad=False explicitly to get the old behavior " + "and silence this warning.") + } #ifdef USE_NUMPY - if (is_numpy_available()) { + if (!tensor.defined() && is_numpy_available()) { // Check whether 'obj' is a NumPy Array or Scalar. bool is_numpy_array = PyArray_Check(obj); bool is_numpy_scalar = PyArray_CheckScalar(obj); @@ -1809,7 +1807,8 @@ Tensor asarray( // Check whether 'obj' implements the buffer protocol if (!tensor.defined() && PyObject_CheckBuffer(obj) != 0) { - tensor = tensor_frombuffer(obj, dtype_unwrapped, -1, 0, requires_grad); + tensor = + tensor_frombuffer(obj, dtype_unwrapped, -1, 0, return_requires_grad); } if (tensor.defined()) { @@ -1857,10 +1856,10 @@ Tensor asarray( // Setting 'requires_grad' when the tensor is not a leaf does not work. // Whenever that happens, we have to use 'detach'. - if (!tensor.is_leaf() && !requires_grad) { + if (!tensor.is_leaf() && !return_requires_grad) { tensor = tensor.detach(); } else { - tensor.set_requires_grad(requires_grad); + tensor.set_requires_grad(return_requires_grad); } } else { // Undefined tensor means it does not implement neither DLPack nor @@ -1881,7 +1880,7 @@ Tensor asarray( /* copy_variables = */ false, /* copy_numpy = */ false, /* type_inference = */ !dtype.has_value()); - tensor.set_requires_grad(requires_grad); + tensor.set_requires_grad(return_requires_grad); } return tensor; diff --git a/torch/csrc/utils/tensor_new.h b/torch/csrc/utils/tensor_new.h index 8ae71fcde4cfb..903645c40dc1d 100644 --- a/torch/csrc/utils/tensor_new.h +++ b/torch/csrc/utils/tensor_new.h @@ -132,5 +132,5 @@ at::Tensor asarray( std::optional dtype, std::optional device, std::optional copy, - bool requires_grad); + std::optional requires_grad); } // namespace torch::utils diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 1813d623af5e9..d4cdbb8d9125d 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -469,7 +469,7 @@ at::Tensor tensor_from_cuda_array_interface( if (PyDict_GetItemStringRef(cuda_dict, "strides", &py_strides) < 0) { throw python_error(); } - if (py_strides != nullptr && py_strides != Py_None) { + if (py_strides != nullptr && !Py_IsNone(py_strides)) { if (PySequence_Length(py_strides) == -1 || static_cast(PySequence_Length(py_strides)) != sizes.size()) { TORCH_CHECK_TYPE( diff --git a/torch/csrc/utils/tensor_qschemes.cpp b/torch/csrc/utils/tensor_qschemes.cpp index f85d091bd57a0..6c86be1ff7f8a 100644 --- a/torch/csrc/utils/tensor_qschemes.cpp +++ b/torch/csrc/utils/tensor_qschemes.cpp @@ -20,13 +20,12 @@ void initializeQSchemes() { for (const auto i : c10::irange(at::COMPILE_TIME_NUM_QSCHEMES)) { auto qscheme = static_cast(i); - PyObject* qscheme_obj = THPQScheme_New(qscheme, toString(qscheme)); - thp_qscheme_array[static_cast(qscheme)] = qscheme_obj; - Py_INCREF(qscheme_obj); - if (PyModule_AddObject( - torch_module, toString(qscheme).c_str(), qscheme_obj) != 0) { + THPObjectPtr qscheme_obj(THPQScheme_New(qscheme, toString(qscheme))); + if (PyModule_AddObjectRef( + torch_module, toString(qscheme).c_str(), qscheme_obj.get()) != 0) { throw python_error(); } + thp_qscheme_array[static_cast(qscheme)] = qscheme_obj.release(); } } diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index 620086f9ad50d..02414c8494c58 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -6,7 +6,6 @@ #include #include -#include #include #include @@ -76,20 +75,21 @@ std::string type_to_string(const at::DeprecatedTypeProperties& type) { return ss.str(); } +using TypeMap = std::unordered_map; + +static TypeMap build_type_map( + const std::vector& types) { + TypeMap m; + m.reserve(types.size()); + for (auto type : types) + m.emplace(type_to_string(*type), type); + return m; +} + at::TensorOptions options_from_string(const std::string& str) { - static std::string cuda_prefix("torch.cuda."); - static std::string xpu_prefix("torch.xpu."); - static std::string privateUser_prefix( - std::string(parse_privateuseone_backend()) + "."); - static std::unordered_map cpu_map; - static std::unordered_map xpu_map; - static std::unordered_map - cuda_map; - static std::unordered_map - privateUser1_map; - - const std::unordered_map* map = - nullptr; + static const std::string privateUser_prefix = + std::string(parse_privateuseone_backend()) + "."; + const TypeMap* map = nullptr; if (str == "torch.Tensor") { auto backend = @@ -98,46 +98,21 @@ at::TensorOptions options_from_string(const std::string& str) { return getDeprecatedTypeProperties(backend, scalar_type).options(); } - if (std::mismatch(cuda_prefix.begin(), cuda_prefix.end(), str.begin()) - .first == cuda_prefix.end()) { - // torch.cuda. is prefix of str - static bool cuda_once [[maybe_unused]] = []() { - for (auto type : autograd::VariableType::allCUDATypes()) { - cuda_map.emplace(type_to_string(*type), type); - } - return true; - }(); + if (str.starts_with("torch.cuda.")) { + static const auto cuda_map = + build_type_map(autograd::VariableType::allCUDATypes()); map = &cuda_map; - } else if ( - std::mismatch(xpu_prefix.begin(), xpu_prefix.end(), str.begin()).first == - xpu_prefix.end()) { - // torch.xpu. is prefix of str - static bool xpu_once [[maybe_unused]] = []() { - for (auto type : autograd::VariableType::allXPUTypes()) { - xpu_map.emplace(type_to_string(*type), type); - } - return true; - }(); + } else if (str.starts_with("torch.xpu.")) { + static const auto xpu_map = + build_type_map(autograd::VariableType::allXPUTypes()); map = &xpu_map; - } else if ( - std::mismatch( - privateUser_prefix.begin(), privateUser_prefix.end(), str.begin()) - .first == privateUser_prefix.end()) { - // torch.foo. foo is privateUser1 name - static bool privateUser1_once [[maybe_unused]] = []() { - for (auto type : autograd::VariableType::allPrivateUser1Types()) { - privateUser1_map.emplace(type_to_string(*type), type); - } - return true; - }(); + } else if (str.starts_with(privateUser_prefix)) { + static const auto privateUser1_map = + build_type_map(autograd::VariableType::allPrivateUser1Types()); map = &privateUser1_map; } else { - static bool cpu_once [[maybe_unused]] = []() { - for (auto type : autograd::VariableType::allCPUTypes()) { - cpu_map.emplace(type_to_string(*type), type); - } - return true; - }(); + static const auto cpu_map = + build_type_map(autograd::VariableType::allCPUTypes()); map = &cpu_map; } diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index 0834b7d3e4f7c..759b82cc06fcb 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -336,6 +336,8 @@ static void registerXpuDeviceProperties(PyObject* module) { ._(gpu_eu_count) \ ._(max_work_group_size) \ ._(max_num_sub_groups) \ + ._(memory_clock_rate) \ + ._(memory_bus_width) \ ._(sub_group_sizes) \ ._(local_mem_size) \ ._(has_fp16) \ @@ -371,7 +373,9 @@ static void registerXpuDeviceProperties(PyObject* module) { << prop.global_mem_size / (1024ull * 1024) << "MB, local_mem_size=" << prop.local_mem_size / 1024ull << "KB, max_compute_units=" << prop.max_compute_units - << ", gpu_eu_count=" << prop.gpu_eu_count + << ", memory_clock_rate=" << prop.memory_clock_rate + << "MHz, memory_bus_width=" << prop.memory_bus_width + << "-bit, gpu_eu_count=" << prop.gpu_eu_count << ", gpu_subslice_count=" << gpu_subslice_count(prop) << ", max_work_group_size=" << prop.max_work_group_size << ", max_num_sub_groups=" << prop.max_num_sub_groups @@ -391,6 +395,14 @@ static void registerXpuPluggableAllocator(PyObject* module) { std::shared_ptr>( m, "_xpu_XPUAllocator"); + // Register concrete XPUPluggableAllocator type with inheritance + py::class_< + torch::xpu::XPUPluggableAllocator::XPUPluggableAllocator, + c10::xpu::XPUCachingAllocator::XPUAllocator, + std::shared_ptr< + torch::xpu::XPUPluggableAllocator::XPUPluggableAllocator>>( + m, "_XPUPluggableAllocator"); + m.def("_xpu_getAllocator", []() { return py::cast(torch::xpu::XPUPluggableAllocator::getCurrentAllocator()); }); diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 08ac01b4093ec..9b7a6be38941f 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -52,7 +52,9 @@ _is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False) _HAS_PYNVML = False +_HAS_AMDSMI = False _PYNVML_ERR = None +_AMDSMI_ERR = None try: from torch import version as _version @@ -104,8 +106,13 @@ def __enter__(self) -> None: def __exit__(self, type: Any, value: Any, traceback: Any) -> None: ctypes.CDLL = self.original_CDLL # type: ignore[misc] - with _amdsmi_cdll_hook(): - import amdsmi # type: ignore[import] + try: + with _amdsmi_cdll_hook(): + import amdsmi # type: ignore[import] + _HAS_AMDSMI = True + except ModuleNotFoundError as err: + _AMDSMI_ERR = err + raise _HAS_PYNVML = True except ModuleNotFoundError: @@ -883,7 +890,7 @@ def parse_list_with_prefix(lst: str, prefix: str) -> list[str]: def _raw_device_count_amdsmi() -> int: - if not _HAS_PYNVML: # If amdsmi is not available + if not _HAS_AMDSMI: return -1 try: amdsmi.amdsmi_init() @@ -917,7 +924,7 @@ def _raw_device_count_nvml() -> int: def _raw_device_uuid_amdsmi() -> list[str] | None: from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer - if not _HAS_PYNVML: # If amdsmi is not available + if not _HAS_AMDSMI: return None try: amdsmi.amdsmi_init() @@ -1306,11 +1313,10 @@ def _get_pynvml_handler(device: Device = None): def _get_amdsmi_handler(device: Device = None): - if not _HAS_PYNVML: + if not _HAS_AMDSMI: raise ModuleNotFoundError( "amdsmi does not seem to be installed or it can't be imported." - # pyrefly: ignore [invalid-inheritance] - ) from _PYNVML_ERR + ) from _AMDSMI_ERR try: amdsmi.amdsmi_init() except amdsmi.AmdSmiException as e: @@ -1322,11 +1328,49 @@ def _get_amdsmi_handler(device: Device = None): return handle +_cached_hip_to_amdsmi: dict[int, int] | None = None + + +def _get_amdsmi_device_index_from_hip_index(device: int) -> int: + r"""Return amdsmi index from HIP device index. They are not always the same. + + Assume amdsmi_init() already completes successfully.""" + global _cached_hip_to_amdsmi + if _cached_hip_to_amdsmi is None: + amdsmi_handles = amdsmi.amdsmi_get_processor_handles() + + def gen(): + for amdsmi_idx, handle in enumerate(amdsmi_handles): + info = amdsmi.amdsmi_get_gpu_enumeration_info(handle) + if "hip_id" in info: + yield info["hip_id"], amdsmi_idx + + _cached_hip_to_amdsmi = dict(gen()) + if not _cached_hip_to_amdsmi and len(amdsmi_handles) > 1: + warnings.warn( + "Cannot translate HIP ID to AMD SMI ID due to" + " lack of translation information prior to ROCm 6.4." + " Functions that rely on amdsmi" + " (e.g. temperature()) may operate on wrong devices." + ) + if device not in _cached_hip_to_amdsmi: + warnings.warn( + f"Cannot translate HIP ID {device} to AMD SMI ID due to" + " undetected HIP ID from amdsmi." + " amdsmi_get_gpu_enumeration_info() only report these HIP IDs" + f" {list(_cached_hip_to_amdsmi.keys())}." + " Functions that rely on amdsmi" + " (e.g. temperature()) may operate on wrong devices." + ) + return _cached_hip_to_amdsmi.get(device, device) + + def _get_amdsmi_device_index(device: Device) -> int: r"""Return the amdsmi index of the device, taking visible_devices into account.""" idx = _get_device_index(device, optional=True) visible_devices = _parse_visible_devices() - if type(visible_devices[0]) is str: + visible_device_is_str = type(visible_devices[0]) is str + if visible_device_is_str: uuids = _raw_device_uuid_amdsmi() if uuids is None: raise RuntimeError("Can't get device UUIDs") @@ -1339,7 +1383,10 @@ def _get_amdsmi_device_index(device: Device) -> int: raise RuntimeError( f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})" ) - return idx_map[idx] + if visible_device_is_str: + return idx_map[idx] + else: + return _get_amdsmi_device_index_from_hip_index(idx_map[idx]) def _get_amdsmi_device_memory_used(device: Device = None) -> int: @@ -1937,6 +1984,7 @@ def _compile_kernel( "amp", "caching_allocator_alloc", "caching_allocator_delete", + "caching_allocator_disabled", "caching_allocator_enable", "can_device_access_peer", "check_error", diff --git a/torch/cuda/_annotate_cuda_graph_trace.py b/torch/cuda/_annotate_cuda_graph_trace.py new file mode 100644 index 0000000000000..92a0c665a027c --- /dev/null +++ b/torch/cuda/_annotate_cuda_graph_trace.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 +"""Post-process a profiler trace to add CUDA graph kernel annotations. + +Reads a profiler trace (gzipped or plain JSON) and a kernel annotations +pickle, matches kernel events by their graph node id, and writes an +annotated trace with the annotation fields added to each kernel event's +args (displayed alongside grid/block size in trace viewers). + +The annotations pickle is auto-discovered from the trace file's parent +directory (one level up, matching the rank from the trace filename). + +Usage: + python -m torch.cuda._annotate_cuda_graph_trace [-a ] [-o ] + +Examples: + # Auto-discover annotations pickle from trace location + python -m torch.cuda._annotate_cuda_graph_trace \\ + traces/step_000000000014/000000.*.pt.trace.json.gz + + # Explicit annotations pickle + python -m torch.cuda._annotate_cuda_graph_trace trace.json.gz -a annotations.pkl +""" + +import argparse +import gzip +import json +import pickle +import re +import sys +from collections import defaultdict +from pathlib import Path +from typing import Any + + +_WORK_CATEGORIES = {"kernel", "gpu_memcpy", "gpu_memset"} + + +def _move_overlapping_to_stream( + trace: dict, default_stream: int = 7, overlap_stream: int = 8 +) -> int: + """Move graphed kernels that overlap with their predecessor to a separate stream. + + Perfetto cannot display overlapping (non-nested) events on the same + stream -- they get hidden. This pass detects graphed kernel events on + *default_stream* whose start timestamp falls before the previous + kernel's end, and moves them to *overlap_stream* so they're visible. + + Returns the number of events moved. + """ + graphed_on_default = [ + e + for e in trace["traceEvents"] + if e.get("cat") == "kernel" + and e.get("tid") == default_stream + and e.get("args", {}).get("graph node id", 0) != 0 + ] + graphed_on_default.sort(key=lambda e: e["ts"]) + + moved = 0 + prev_end = 0.0 + for event in graphed_on_default: + ts = event["ts"] + dur = event.get("dur", 0) + if ts < prev_end: + event["tid"] = overlap_stream + event.get("args", {})["stream"] = overlap_stream + moved += 1 + else: + prev_end = ts + dur + + return moved + + +def _fix_overlapping_timestamps(trace: dict, max_adjust_us: float = 1.0) -> int: + """Clamp graphed kernel/memcpy timestamps so they don't overlap on the same stream. + + CUPTI can produce slightly overlapping timestamps for consecutive graphed + events, causing Perfetto to hide events that sit entirely "under" their + neighbours. This pass sorts graphed work events per stream and ensures + each event starts at or after the previous event's end. + + Overlaps larger than *max_adjust_us* are flagged as warnings and left + unchanged, since they likely indicate a real issue rather than CUPTI + timestamp jitter. + + Returns the number of events adjusted. + """ + per_stream: dict[int, list[dict]] = defaultdict(list) + for event in trace["traceEvents"]: + if ( + event.get("cat") in _WORK_CATEGORIES + and event.get("args", {}).get("graph node id", 0) != 0 + ): + per_stream[event.get("tid")].append(event) + + adjusted = 0 + for tid, events in per_stream.items(): + events.sort(key=lambda e: e["ts"]) + prev_end = 0.0 + for event in events: + ts = event["ts"] + dur = event.get("dur", 0) + if ts < prev_end: + overlap = prev_end - ts + if overlap > max_adjust_us: + print( + f"WARNING: large overlap {overlap:.3f}us on stream {tid} " + f"for {event.get('name', '?')[:60]}, skipping adjustment" + ) + else: + event["ts"] = prev_end + adjusted += 1 + prev_end = event["ts"] + dur + + return adjusted + + +def annotate_trace( + trace: dict, + annotations: dict[int, list[Any]], + default_stream: int = 7, +) -> int: + """Add annotation fields to kernel events matching the annotations dict. + + Each annotation entry is a list (from nested ``mark_kernels`` scopes). + Fields from all annotations are merged into the event args; if multiple + annotations define the same key, later entries in the list win. + + For graphed events (graph_node_id != 0), reassigns ``tid`` and + ``args["stream"]`` to the stream recorded in annotations, or to + *default_stream* if there is no annotation. Also moves the + corresponding ``ac2g`` flow-finish events to the new tid so that + CPU-to-GPU correlation arrows are preserved. + + Removes ``gpu_user_annotation`` events and orphaned ``ac2g`` events + from streams that have no kernel or memcpy events after reassignment, + since CUPTI replicates these onto every stream during graph replay. + + Returns the number of events annotated. + """ + # Build an index of ac2g 'f' events keyed by (tid, ts) so we can + # move them together with the kernel events they correspond to. + ac2g_f_index: dict[tuple, list] = {} + for event in trace["traceEvents"]: + if event.get("cat") == "ac2g" and event.get("ph") == "f": + key = (event.get("tid"), event.get("ts")) + ac2g_f_index.setdefault(key, []).append(event) + + annotated = 0 + for event in trace.get("traceEvents", []): + args = event.get("args", {}) + graph_node_id = args.get("graph node id") + if graph_node_id is None or graph_node_id == 0: + continue + stream_id = None + if graph_node_id in annotations: + for ann in annotations[graph_node_id]: + if isinstance(ann, dict): + for key, value in ann.items(): + args[key] = str(value) + if "stream" in ann: + stream_id = int(ann["stream"]) + else: + args["annotation"] = str(ann) + annotated += 1 + + # Reassign stream: use annotated stream if available, else default + if stream_id is None: + stream_id = default_stream + old_key = (event.get("tid"), event.get("ts")) + event["tid"] = stream_id + args["stream"] = stream_id + + # Move the matching ac2g 'f' event(s) to the same new tid + for ac2g_event in ac2g_f_index.get(old_key, ()): + ac2g_event["tid"] = stream_id + + # Remove gpu_user_annotation events and ac2g flow-finish events from + # streams that have no real kernel/memcpy/memset work -- these are + # noise replicated by CUPTI onto every stream during graph replay. + tids_with_work = set() + for event in trace["traceEvents"]: + if event.get("cat") in _WORK_CATEGORIES: + tids_with_work.add(event.get("tid")) + + def _is_noise(event: dict) -> bool: + cat = event.get("cat") + if cat == "gpu_user_annotation": + return event.get("tid") not in tids_with_work + if cat == "ac2g" and event.get("ph") == "f": + return event.get("tid") not in tids_with_work + return False + + original_count = len(trace["traceEvents"]) + trace["traceEvents"] = [ + event for event in trace["traceEvents"] if not _is_noise(event) + ] + removed = original_count - len(trace["traceEvents"]) + if removed: + print(f"Removed {removed} noise events from empty streams") + + # Clean up metadata: remove thread_name / thread_sort_index entries + # for noise streams that have no real (non-metadata) events, and add + # thread_name entries for our new annotation streams. + all_tids_in_trace = { + e.get("tid") for e in trace["traceEvents"] if e.get("ph") != "M" + } + # Find the GPU process pid from existing thread_name metadata + gpu_pid = 0 + for event in trace["traceEvents"]: + if ( + event.get("ph") == "M" + and event.get("name") == "thread_name" + and str(event.get("args", {}).get("name", "")).startswith("stream ") + ): + gpu_pid = event.get("pid", 0) + break + + # Remove metadata entries for tids with no non-metadata events + trace["traceEvents"] = [ + event + for event in trace["traceEvents"] + if event.get("ph") != "M" or event.get("tid") in all_tids_in_trace + ] + + # Add thread_name metadata for new annotation tids that lack one + existing_thread_names = { + e.get("tid") + for e in trace["traceEvents"] + if e.get("ph") == "M" and e.get("name") == "thread_name" + } + for tid in sorted(tids_with_work - existing_thread_names): + trace["traceEvents"].append( + { + "ph": "M", + "pid": gpu_pid, + "tid": tid, + "name": "thread_name", + "args": {"name": f"stream {tid}"}, + } + ) + + return annotated + + +def load_trace(path: Path) -> dict: + if path.suffix == ".gz" or path.name.endswith(".json.gz"): + with gzip.open(path, "rt") as f: + return json.load(f) + else: + with open(path) as f: + return json.load(f) + + +def save_trace(trace: dict, path: Path) -> None: + if path.suffix == ".gz" or path.name.endswith(".json.gz"): + with gzip.open(path, "wt") as f: + json.dump(trace, f) + else: + with open(path, "w") as f: + json.dump(trace, f) + + +def _find_annotations_pkl(trace_file: Path) -> Path | None: + """Auto-discover the annotations pickle from the trace file location. + + Trace files live in e.g. ``traces/step_000000000014/000000..pt.trace.json.gz`` + where the leading digits are the rank. The pickle lives one level up: + ``traces/kernel_annotations_rank0_*.pkl``. + """ + match = re.match(r"^(\d+)", trace_file.name) + if not match: + return None + rank = int(match.group(1)) + + traces_dir = trace_file.parent.parent + candidates = sorted(traces_dir.glob(f"kernel_annotations_rank{rank}_*.pkl")) + if candidates: + return candidates[0] + return None + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Annotate a profiler trace with CUDA graph kernel annotations." + ) + parser.add_argument( + "trace_file", type=Path, help="Input trace file (.json or .json.gz)" + ) + parser.add_argument( + "-a", + "--annotations", + type=Path, + default=None, + help="Kernel annotations pickle file. Auto-discovered if omitted.", + ) + parser.add_argument( + "-o", + "--output", + type=Path, + default=None, + help="Output file path. Defaults to .annotated.", + ) + parser.add_argument( + "--default-stream", + type=int, + default=7, + help="Stream ID to assign to unannotated graphed events (default: 7).", + ) + args = parser.parse_args() + + annotations_pkl = args.annotations + if annotations_pkl is None: + annotations_pkl = _find_annotations_pkl(args.trace_file) + if annotations_pkl is None: + print( + f"Could not auto-discover annotations pickle for {args.trace_file}. " + f"Use -a to specify it explicitly.", + file=sys.stderr, + ) + sys.exit(1) + print(f"Auto-discovered annotations: {annotations_pkl}") + + with open(annotations_pkl, "rb") as f: + annotations = pickle.load(f) + print(f"Loaded {len(annotations)} kernel annotations") + + trace = load_trace(args.trace_file) + total_events = len(trace.get("traceEvents", [])) + print(f"Loaded trace with {total_events} events") + + count = annotate_trace(trace, annotations, default_stream=args.default_stream) + print(f"Annotated {count} kernel events") + + overlap_moved = _move_overlapping_to_stream( + trace, default_stream=args.default_stream + ) + if overlap_moved: + print(f"Moved {overlap_moved} overlapping events to stream 8") + + ts_fixed = _fix_overlapping_timestamps(trace) + if ts_fixed: + print(f"Fixed {ts_fixed} overlapping graphed event timestamps") + + output = args.output + if output is None: + name = args.trace_file.name + if name.endswith(".json.gz"): + output = args.trace_file.with_name( + name.replace(".json.gz", ".annotated.json.gz") + ) + elif name.endswith(".json"): + output = args.trace_file.with_suffix(".annotated.json") + else: + output = args.trace_file.with_suffix(args.trace_file.suffix + ".annotated") + + save_trace(trace, output) + print(f"Saved annotated trace to {output}") + + +if __name__ == "__main__": + main() diff --git a/torch/cuda/_graph_annotations.py b/torch/cuda/_graph_annotations.py new file mode 100644 index 0000000000000..16a895686cf41 --- /dev/null +++ b/torch/cuda/_graph_annotations.py @@ -0,0 +1,434 @@ +"""Annotate CUDA graph kernel nodes during capture. + +During CUDA graph capture, ``mark_kernels`` uses ``cudaGraphGetNodes`` +to count nodes before and after the wrapped region. Nodes at indices +``[before, after)`` are the ones added within the scope. Each kernel +or memcpy node found is annotated by its ``toolsId`` so it can later +be matched to profiler trace events. + +The annotations can be pickled and later merged into a Chrome profiler +trace using ``torch.cuda._annotate_cuda_graph_trace``. + +Requires ``cuda.bindings`` package and a CUDA driver that supports +``cudaGraphNodeGetToolsId`` (CUDA >= 13.1 or appropriate cuda-compat). +When unavailable, ``mark_kernels`` silently becomes a no-op. + +Usage during capture:: + + from torch.cuda._graph_annotations import ( + enable_annotations, + mark_kernels, + resolve_pending_annotations, + remap_to_exec_graph, + ) + + enable_annotations() + + with torch.cuda.graph(graph): + with mark_kernels("phase_A"): + y = workload_a(x) + with mark_kernels("phase_B"): + z = workload_b(y) + resolve_pending_annotations() + + remap_to_exec_graph(graph) +""" + +from collections import defaultdict +from contextlib import contextmanager +from logging import getLogger +from typing import Any + +import torch +from torch.cuda._utils import _check_cuda_bindings, _HAS_CUDA_BINDINGS + + +try: + from cuda.bindings import ( # pyrefly: ignore[missing-import] + runtime as _cuda_runtime, + ) +except ImportError: + _cuda_runtime = None # type: ignore[assignment] + + +logger = getLogger(__name__) + + +# Tri-state: None = not probed, True = available, False = unavailable. +# Deferred to first use to avoid premature CUDA initialization. +_tools_id_available: bool | None = None + +# Global kill switch. When False, mark_kernels and mark_stream are no-ops. +_annotations_enabled: bool = False + + +def enable_annotations() -> None: + """Enable kernel annotation recording.""" + global _annotations_enabled + _annotations_enabled = True + + +def disable_annotations() -> None: + """Disable kernel annotation recording.""" + global _annotations_enabled + _annotations_enabled = False + + +def _is_tools_id_unavailable() -> bool: + """Return True if we already know cudaGraphNodeGetToolsId is missing.""" + if not _HAS_CUDA_BINDINGS: + return True + if _tools_id_available is False: + return True + if not hasattr(_cuda_runtime, "cudaGraphNodeGetToolsId"): + return True + return False + + +def _get_tools_id(node: Any) -> int | None: + """Return the toolsId for a graph node, or None if unavailable.""" + global _tools_id_available + if _tools_id_available is None: + try: + tools_id = _check_cuda_bindings( + _cuda_runtime.cudaGraphNodeGetToolsId( # pyrefly: ignore[missing-attribute] + node + ) + ) + except Exception: + _tools_id_available = False + logger.info( + "cudaGraphNodeGetToolsId not available; " + "CUDA graph kernel annotations will be disabled" + ) + return None + _tools_id_available = True + return tools_id + return _check_cuda_bindings( + _cuda_runtime.cudaGraphNodeGetToolsId( # pyrefly: ignore[missing-attribute] + node + ) + ) + + +def _get_capture_graph(stream: Any) -> Any: + """Return the graph handle for the active capture, or None.""" + status, _id, graph, _deps, _edge_data, _num_deps = _check_cuda_bindings( + _cuda_runtime.cudaStreamGetCaptureInfo( # pyrefly: ignore[missing-attribute] + stream + ) + ) + if ( + status + != _cuda_runtime.cudaStreamCaptureStatus.cudaStreamCaptureStatusActive # pyrefly: ignore[missing-attribute] + ): + return None + return graph + + +def _get_node_count(graph: Any) -> int: + """Return the number of nodes currently in the graph.""" + _, num = _check_cuda_bindings( + _cuda_runtime.cudaGraphGetNodes( # pyrefly: ignore[missing-attribute] + graph, numNodes=0 + ) + ) + return num + + +# toolsId -> list of annotation objects. +_kernel_annotations: defaultdict[int, list[Any]] = defaultdict(list) + +# Node types we annotate. Initialized lazily to avoid touching cuda.bindings +# at import time. +_ANNOTATABLE_TYPES: set[Any] | None = None + + +def _get_annotatable_types() -> set[Any]: + global _ANNOTATABLE_TYPES + if _ANNOTATABLE_TYPES is None: + _ANNOTATABLE_TYPES = { + _cuda_runtime.cudaGraphNodeType.cudaGraphNodeTypeKernel, # pyrefly: ignore[missing-attribute] + _cuda_runtime.cudaGraphNodeType.cudaGraphNodeTypeMemcpy, # pyrefly: ignore[missing-attribute] + } + return _ANNOTATABLE_TYPES + + +# Pending scopes: (annotation, start_node_index, end_node_index). +_pending_scopes: list[tuple[Any, int, int]] = [] + +# Graph handle saved during capture for post-capture resolution. +_capture_graph: Any = None + +# Capture graph ID saved by resolve_pending_annotations for remap_to_exec_graph. +_last_capture_graph_id: int | None = None + + +@contextmanager # type: ignore[arg-type] +def mark_kernels(annotation: str | dict[str, Any]): + """Context manager that records node index ranges for later annotation. + + During capture, calls ``cudaGraphGetNodes`` to count graph nodes before + and after the scope. Nodes at indices ``[before, after)`` were added + inside the scope. After capture, ``resolve_pending_annotations`` + enumerates all nodes and annotates kernel/memcpy nodes in those ranges. + + Must be called inside an active ``torch.cuda.graph()`` capture. If the + current stream is not capturing, or if ``cudaGraphNodeGetToolsId`` is not + available, the context manager is a no-op. + + Args: + annotation: Arbitrary object appended to the annotation list for + every kernel/memcpy node whose index falls within this scope. + """ + if not _annotations_enabled or _is_tools_id_unavailable(): + yield + return + + if isinstance(annotation, str): + annotation = {"str": annotation} + + stream = _cuda_runtime.cudaStream_t( # pyrefly: ignore[missing-attribute] + init_value=torch.cuda.current_stream().cuda_stream + ) + graph = _get_capture_graph(stream) + if graph is None: + yield + return + + global _capture_graph + _capture_graph = graph + + start_count = _get_node_count(graph) + + yield + + end_count = _get_node_count(graph) + + if end_count > start_count: + _pending_scopes.append((annotation, start_count, end_count)) + + +def resolve_pending_annotations() -> None: + """Resolve pending scope index ranges into kernel annotations. + + Enumerates all graph nodes and annotates kernel/memcpy nodes whose + indices fall within recorded scope ranges. Must be called while still + inside the ``torch.cuda.graph()`` capture context. + """ + global _capture_graph + if not _pending_scopes: + _capture_graph = None + return + + # Get a fresh graph handle from the active capture. + stream = _cuda_runtime.cudaStream_t( # pyrefly: ignore[missing-attribute] + init_value=torch.cuda.current_stream().cuda_stream + ) + graph = _get_capture_graph(stream) + if graph is None: + graph = _capture_graph + if graph is None: + logger.warning("resolve_pending_annotations: no graph handle available") + _pending_scopes.clear() + return + + try: + num = _get_node_count(graph) + if num == 0: + _pending_scopes.clear() + _capture_graph = None + return + + nodes, num = _check_cuda_bindings( + _cuda_runtime.cudaGraphGetNodes( # pyrefly: ignore[missing-attribute] + graph, numNodes=num + ) + ) + + # Save capture graph ID for remap_to_exec_graph. + global _last_capture_graph_id + if num > 0: + first_tid = _get_tools_id(nodes[0]) + _last_capture_graph_id = (first_tid >> 32) if first_tid else None + + annotatable = _get_annotatable_types() + + # Sort by (start, -end, -append_index). The append index encodes + # nesting depth: inner context managers exit first, so they are + # appended to _pending_scopes first (smaller index). Using + # -append_index as tiebreaker ensures that for same-range scopes + # the outer scope (larger index) sorts first and is pushed onto + # the stack first, leaving the inner scope on top. + sorted_scopes = sorted( + ( + (ann, start, end, i) + for i, (ann, start, end) in enumerate(_pending_scopes) + ), + key=lambda s: (s[1], -s[2], -s[3]), + ) + scope_ptr = 0 + active_stack: list[tuple[int, Any]] = [] # (end_idx, annotation) + + for i in range(num): + # Pop scopes whose range ended. + while active_stack and active_stack[-1][0] <= i: + active_stack.pop() + + # Push scopes that start at or before this index. + while scope_ptr < len(sorted_scopes) and sorted_scopes[scope_ptr][1] <= i: + ann, _start_idx, end_idx, _idx = sorted_scopes[scope_ptr] + if end_idx > i: + active_stack.append((end_idx, ann)) + scope_ptr += 1 + + if not active_stack: + continue + + node = nodes[i] + node_type = _check_cuda_bindings( + _cuda_runtime.cudaGraphNodeGetType( # pyrefly: ignore[missing-attribute] + node + ) + ) + if node_type not in annotatable: + continue + + tools_id = _get_tools_id(node) + if tools_id is None: + logger.warning( + "resolve_pending_annotations: toolsId unavailable, aborting" + ) + _pending_scopes.clear() + _capture_graph = None + return + + if len(active_stack) == 1: + _kernel_annotations[tools_id].append(active_stack[0][1]) + else: + # Merge all active scopes into one dict. Inner scopes sit + # on top of the stack. Iterating reversed (inner first) + # with setdefault lets the inner scope's values win for + # overlapping keys (e.g. name, stream) while outer scopes + # fill in any missing keys. + merged: dict[str, Any] = {} + for _, ann in reversed(active_stack): + if isinstance(ann, dict): + for ak, av in ann.items(): + merged.setdefault(ak, av) + else: + merged.setdefault("name", ann) + _kernel_annotations[tools_id].append(merged) + except Exception: + logger.exception("resolve_pending_annotations failed") + finally: + _pending_scopes.clear() + _capture_graph = None + + +def remap_to_exec_graph(torch_cuda_graph: torch.cuda.CUDAGraph) -> None: + """Remap annotation keys from capture graph ID to exec graph ID. + + During capture, toolsId encodes the capture graph's ID in the upper + 32 bits. After instantiation, the profiler uses the exec graph's ID. + This function rewrites the keys so annotations match the trace. + + Must be called after the ``torch.cuda.graph()`` context exits. + """ + if not _kernel_annotations: + return + + exec_handle = _cuda_runtime.cudaGraphExec_t( # pyrefly: ignore[missing-attribute] + init_value=torch_cuda_graph.raw_cuda_graph_exec() + ) + exec_graph_id = _check_cuda_bindings( + _cuda_runtime.cudaGraphExecGetId( # pyrefly: ignore[missing-attribute] + exec_handle + ) + ) + + # Only remap annotations from the most recent capture graph. + # Previously remapped annotations (from earlier captures) keep their + # correct exec graph IDs. + capture_graph_id = _last_capture_graph_id + remapped: dict[int, list[Any]] = {} + for tools_id, ann_list in _kernel_annotations.items(): + graph_id = tools_id >> 32 + if capture_graph_id is not None and graph_id != capture_graph_id: + # Belongs to a different graph — keep as-is. + remapped[tools_id] = ann_list + continue + node_id = tools_id & 0xFFFFFFFF + new_tools_id = (exec_graph_id << 32) | node_id + if new_tools_id in remapped: + remapped[new_tools_id].extend(ann_list) + else: + remapped[new_tools_id] = list(ann_list) + + _kernel_annotations.clear() + _kernel_annotations.update(remapped) + + +def get_kernel_annotations() -> dict[int, list[Any]]: + """Return the current kernel annotation map (toolsId -> annotations).""" + return _kernel_annotations + + +def clear_kernel_annotations() -> None: + """Clear all recorded kernel annotations and pending scopes.""" + global _capture_graph + _kernel_annotations.clear() + _pending_scopes.clear() + _capture_graph = None + + +# Counter-based stream ID registry. IDs start at 60 (above the highest +# observed non-graphed CUDA stream ID) so every assigned lane is visually +# distinct in Perfetto and doesn't collide with real streams. +_stream_id_counter: int = 60 +_stream_id_map: dict[int, int] = {} + + +def _get_stream_id(stream: torch.cuda.Stream) -> int: + """Return a small, stable stream ID for the given CUDA stream.""" + global _stream_id_counter + key = stream.cuda_stream + if key not in _stream_id_map: + _stream_id_map[key] = _stream_id_counter + _stream_id_counter += 1 + return _stream_id_map[key] + + +def get_stream_for_pg(pg_key: str) -> int: + """Return a unique stream ID for the given process group key.""" + global _stream_id_counter + if pg_key not in _stream_id_map: + _stream_id_map[pg_key] = _stream_id_counter # type: ignore[assignment] + _stream_id_counter += 1 + return _stream_id_map[pg_key] # type: ignore[return-value] + + +@contextmanager # type: ignore[arg-type] +def mark_stream(stream: torch.cuda.Stream, annotation: str | dict[str, Any]): + """Switch to stream, inject its ID into annotation, and mark kernels. + + If *stream* is already the current stream, no stream switch or stream ID + injection happens — the kernels stay on whatever stream is active (which + keeps the trace faithful when e.g. FSDP uses the current stream for + copy-in instead of a separate one). + """ + if not _annotations_enabled: + with torch.cuda.stream(stream): + yield + return + if stream.cuda_stream == torch.cuda.current_stream().cuda_stream: + with mark_kernels(annotation): + yield + else: + if isinstance(annotation, str): + annotation = {"str": annotation} + if isinstance(annotation, dict): + annotation["stream"] = _get_stream_id(stream) + with torch.cuda.stream(stream): + with mark_kernels(annotation): + yield diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index f58dd988851bd..10dc9efa11422 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -105,7 +105,7 @@ def format_flamegraph(flamegraph_lines, flamegraph_script=None): try: os.chmod(f.name, 0o755) os.rename(f.name, flamegraph_script) - except OSError: # noqa: B001,E722 + except OSError: # Ok to skip, the file will be removed by tempfile pass args = [flamegraph_script, "--countname", "bytes"] diff --git a/torch/cuda/_utils.py b/torch/cuda/_utils.py index d842e8b56ef41..8d9501afcbff3 100644 --- a/torch/cuda/_utils.py +++ b/torch/cuda/_utils.py @@ -4,6 +4,17 @@ import torch + +try: + from cuda.bindings import ( # pyrefly: ignore[missing-import] + runtime as _cuda_bindings_runtime, + ) + + _HAS_CUDA_BINDINGS = True +except ImportError: + _cuda_bindings_runtime = None # type: ignore[assignment] + _HAS_CUDA_BINDINGS = False + # The _get_device_index has been moved to torch.utils._get_device_index from torch._utils import _get_device_index as _torch_get_device_index @@ -59,6 +70,35 @@ def _check_cuda(result: int) -> None: raise RuntimeError(f"CUDA error: {error_message}") +def _check_cuda_bindings(result: Any) -> Any: + """Check a cuda.bindings (cuda-python) call result for errors. + + All cuda.bindings runtime calls return ``(error, *outputs)``. This + helper unpacks the tuple, raises on non-success, and returns the + outputs (``None`` for zero outputs, scalar for one, tuple otherwise). + """ + if not _HAS_CUDA_BINDINGS: + raise RuntimeError("cuda.bindings is not available") + err, *out = result + if ( + err + != _cuda_bindings_runtime.cudaError_t.cudaSuccess # pyrefly: ignore[missing-attribute] + ): + _, err_str = ( + _cuda_bindings_runtime.cudaGetErrorString( # pyrefly: ignore[missing-attribute] + err + ) + ) + if isinstance(err_str, bytes): + err_str = err_str.decode() + raise RuntimeError(f"CUDA error: {err} ({err_str})") + if len(out) == 0: + return None + if len(out) == 1: + return out[0] + return out + + def _get_hiprtc_library() -> ctypes.CDLL: try: # pyrefly: ignore [import-error, missing-import] diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index 43d559ddc2440..ae896ef0aa51f 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -111,7 +111,7 @@ def capture_begin( may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for actions in the current thread, and "relaxed" will not error on these actions. Do NOT change this setting unless you're familiar with `cudaStreamCaptureMode `_ - """ # noqa: B950 + """ super().capture_begin(pool=pool, capture_error_mode=capture_error_mode) def capture_end(self) -> None: @@ -168,14 +168,14 @@ def raw_cuda_graph(self) -> int: r"""Returns the underlying cudaGraph_t. ``keep_graph`` must be True. See the following for APIs for how to manipulate this object: `Graph Managmement `_ and `cuda-python Graph Management bindings `_ - """ # noqa: B950 + """ return super().raw_cuda_graph() def raw_cuda_graph_exec(self) -> int: r"""Returns the underlying cudaGraphExec_t. ``instantiate`` must have been called if ``keep_graph`` is True, or ``capture_end`` must have been called if ``keep_graph`` is False. If you call ``instantiate()`` after ``raw_cuda_graph_exec()``, the previously returned cudaGraphExec_t will be destroyed. It is your responsibility not to use this object after destruction. See the following for APIs for how to manipulate this object: `Graph Execution `_ and `cuda-python Graph Execution bindings `_ - """ # noqa: B950 + """ return super().raw_cuda_graph_exec() @@ -197,6 +197,12 @@ class graph: may be unsafe. "global" will error on actions in other threads, "thread_local" will only error for actions in the current thread, and "relaxed" will not error on actions. Do NOT change this setting unless you're familiar with `cudaStreamCaptureMode `_ + enable_annotations (bool, optional): If ``True``, enables kernel annotation + recording on entry and automatically calls + :func:`~torch.cuda._graph_annotations.resolve_pending_annotations` before + the capture ends. Annotations are **not** cleared on exit so that multiple + graphs in the same workload can accumulate annotations. + Requires ``cuda.bindings`` package and cuda-compat >= 13.1 or CUDA driver >= 13.1. .. note:: For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture @@ -207,7 +213,7 @@ class graph: .. _cudaStreamCaptureMode: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 - """ # noqa: B950 + """ default_capture_stream: torch.cuda.Stream | None = None @@ -217,6 +223,7 @@ def __init__( pool: _POOL_HANDLE | None = None, stream: torch.cuda.Stream | None = None, capture_error_mode: str = "global", + enable_annotations: bool = False, ): # Lazy-init of default_capture_stream helps avoid circular-import errors. # Not thread safe, but graphs already have the general (explicitly documented) @@ -233,6 +240,7 @@ def __init__( self.stream_ctx = torch.cuda.stream(self.capture_stream) self.cuda_graph = cuda_graph self.capture_error_mode = capture_error_mode + self._enable_annotations = enable_annotations def __enter__(self) -> None: # Free as much memory as we can for the graph @@ -250,6 +258,11 @@ def __enter__(self) -> None: # pyrefly: ignore [missing-attribute] torch._C._host_emptyCache() + if self._enable_annotations: + from torch.cuda._graph_annotations import enable_annotations as _enable_ann + + _enable_ann() + # Stackoverflow seems comfortable with this pattern # https://stackoverflow.com/questions/26635684/calling-enter-and-exit-manually#39172487 self.stream_ctx.__enter__() @@ -262,8 +275,18 @@ def __enter__(self) -> None: ) def __exit__(self, *args: object) -> None: + if self._enable_annotations: + from torch.cuda._graph_annotations import resolve_pending_annotations + + resolve_pending_annotations() + self.cuda_graph.capture_end() self.stream_ctx.__exit__(*args) + + if self._enable_annotations: + from torch.cuda._graph_annotations import remap_to_exec_graph + + remap_to_exec_graph(self.cuda_graph) # returning None should propagate exceptions from either capture_end or stream_ctx.__exit__() diff --git a/torch/cuda/green_contexts.py b/torch/cuda/green_contexts.py index c08d2c061e404..c7aab1b749db1 100644 --- a/torch/cuda/green_contexts.py +++ b/torch/cuda/green_contexts.py @@ -1,6 +1,10 @@ import torch +__all__ = [ + "GreenContext", +] + _GreenContext = object SUPPORTED = False @@ -19,18 +23,59 @@ class GreenContext(_GreenContext): """ @staticmethod - def create(num_sms: int, device_id: int = 0) -> _GreenContext: + def create( + *, + num_sms: int | None = None, + workqueue_scope: str | None = None, + workqueue_concurrency_limit: int | None = None, + device_id: int | None = None, + ) -> _GreenContext: r"""Create a CUDA green context. + At least one of ``num_sms`` or ``workqueue_scope`` must be specified. + Both can be combined to partition SMs and configure workqueues in the + same green context. + Arguments: - num_sms (int): The number of SMs to use in the green context. + num_sms (int, optional): The number of SMs to use in the green + context. When ``None``, SMs are not partitioned. + workqueue_scope (str, optional): Workqueue sharing scope. One of + ``"device_ctx"`` (shared across all contexts, default driver + behaviour) or ``"balanced"`` (non-overlapping workqueues with + other balanced green contexts). When ``None``, no workqueue + configuration is applied. + workqueue_concurrency_limit (int, optional): Maximum number of + concurrent stream-ordered workloads for the workqueue. Requires + ``workqueue_scope`` to be set. device_id (int, optional): The device index of green context. + When ``None``, the current device is used. + """ + if not SUPPORTED: + raise RuntimeError("PyTorch was not built with Green Context support!") + return _GreenContext.create( # type: ignore[attr-defined] + device_id=device_id, + num_sms=num_sms, + workqueue_scope=workqueue_scope, + workqueue_concurrency_limit=workqueue_concurrency_limit, + ) + + @staticmethod + def max_workqueue_concurrency(device_id: int | None = None) -> int: + r"""Return the maximum workqueue concurrency limit for the device. + + This queries the device for the default number of concurrent + stream-ordered workloads supported by workqueue configuration + resources. + + Arguments: + device_id (int, optional): The device index to query. When + ``None``, the current device is used. """ if not SUPPORTED: raise RuntimeError("PyTorch was not built with Green Context support!") - return _GreenContext.create(num_sms, device_id) # type: ignore[attr-defined] + return _GreenContext.max_workqueue_concurrency(device_id=device_id) # type: ignore[attr-defined] - # Note that these functions are bypassed by we define them here + # Note that these functions are bypassed but we define them here # for Sphinx documentation purposes def set_context(self) -> None: # pylint: disable=useless-parent-delegation r"""Make the green context the current context.""" @@ -42,6 +87,6 @@ def pop_context(self) -> None: # pylint: disable=useless-parent-delegation """ return super().pop_context() # type: ignore[misc] - def Stream(self) -> torch.Stream: + def Stream(self) -> "torch.cuda.Stream": r"""Return the CUDA Stream used by the green context.""" return super().Stream() diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 77915b04b209d..46e9412338d01 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -32,6 +32,7 @@ __all__ = [ "caching_allocator_alloc", "caching_allocator_delete", + "caching_allocator_disabled", "caching_allocator_enable", "get_per_process_memory_fraction", "set_per_process_memory_fraction", @@ -98,15 +99,6 @@ def _host_allocator(): return torch._C._cuda_cudaHostAllocator() -@contextlib.contextmanager -def _free_mutex(): - torch._C._cuda_lock_mutex() - try: - yield - finally: - torch._C._cuda_unlock_mutex() - - def caching_allocator_alloc(size, device: "Device" = None, stream=None): r"""Perform a memory allocation using the CUDA memory allocator. @@ -166,6 +158,18 @@ def caching_allocator_enable(value: bool = True) -> None: torch._C._cuda_cudaCachingAllocator_enable(value) +@contextlib.contextmanager +def caching_allocator_disabled(): + r"""Context manager that temporarily disables the CUDA caching allocator.""" + # pyrefly: ignore [missing-attribute] + prev = torch._C._cuda_cudaCachingAllocator_is_enabled() + caching_allocator_enable(False) + try: + yield + finally: + caching_allocator_enable(prev) + + def set_per_process_memory_fraction(fraction, device: "Device" = None) -> None: r"""Set memory fraction for a process. @@ -277,6 +281,8 @@ def memory_stats(device: "Device" = None) -> dict[str, Any]: cuMemMap and cudaMalloc. - ``"num_device_free"``: number of CUDA free calls. This includes both cuMemUnmap and cudaFree. + - ``"num_oom_rejections"``: number of allocations preemptively rejected by the + throw_on_cudamalloc_oom + per_process_memory_fraction policy. The caching allocator can be configured via ENV to not split blocks larger than a defined size (see Memory Management section of the Cuda Semantics documentation). @@ -1027,6 +1033,9 @@ class Segment(TypedDict): total_size: int # cudaMalloc'd size of segment stream: int segment_type: Literal["small", "large"] # 'large' (>1MB) + segment_pool_id: Tuple[ + int, int + ] # id of the memory pool owning this segment allocated_size: int # size of memory in use active_size: int # size of memory in use or in active_awaiting_free state blocks: List[Block] @@ -1085,6 +1094,7 @@ class TraceEntry(TypedDict): stream: int device_free: int # only present for OOM, the amount of # memory cuda still reports to be free + pool_id: Tuple[int, int] # id of the memory pool for this entry Args: device: Device to capture snapshot for. If None, captures for current device. diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index 9c5a1af4a2c23..77981e465257a 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -178,7 +178,7 @@ class Event(torch._C._CudaEventBase): .. _CUDA Event Documentation: https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__EVENT.html - """ # noqa: B950 + """ def __new__( cls, enable_timing=False, blocking=False, interprocess=False, external=False diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index 24add6a4d8aeb..4f24b45494e6e 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -169,7 +169,8 @@ def breakpoint(rank: int = 0, skip: int = 0, timeout_s=3600): # stubs as necessary. # We cannot define stubs directly because they confuse pyre - class _ProcessGroupStub: + class _Stub: pass - sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined] + sys.modules["torch.distributed"].GroupName = _Stub # type: ignore[attr-defined] + sys.modules["torch.distributed"].ProcessGroup = _Stub # type: ignore[attr-defined] diff --git a/torch/distributed/_composable/replicate_with_fsdp.py b/torch/distributed/_composable/replicate_with_fsdp.py index ba8a54452058f..e0c9125fca52d 100644 --- a/torch/distributed/_composable/replicate_with_fsdp.py +++ b/torch/distributed/_composable/replicate_with_fsdp.py @@ -15,9 +15,11 @@ from torch.distributed.fsdp._fully_shard._fsdp_init import ( _apply_to_module, _get_device_from_mesh, + _get_mesh_info, _get_modules_and_states, _init_default_mesh, _init_param_group, + _validate_mesh as _validate_mesh_common, _validate_module as _validate_module_common, ) from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState, FSDPStateContext @@ -30,6 +32,7 @@ if TYPE_CHECKING: + from torch.distributed.fsdp._fully_shard._fsdp_api import DataParallelMeshDims from torch.distributed.tensor import DeviceMesh @@ -85,6 +88,7 @@ def replicate( mp_policy: MixedPrecisionPolicy = ..., offload_policy: OffloadPolicy = ..., ignored_params: set[nn.Parameter] | None = ..., + dp_mesh_dims: DataParallelMeshDims | None = ..., ) -> ReplicateModule: ... @@ -97,6 +101,7 @@ def replicate( mp_policy: MixedPrecisionPolicy = ..., offload_policy: OffloadPolicy = ..., ignored_params: set[nn.Parameter] | None = ..., + dp_mesh_dims: DataParallelMeshDims | None = ..., ) -> list[ReplicateModule]: ... @@ -108,6 +113,7 @@ def replicate( mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), offload_policy: OffloadPolicy = OffloadPolicy(), ignored_params: set[nn.Parameter] | None = None, + dp_mesh_dims: DataParallelMeshDims | None = None, ): r"""Replicates a module @@ -122,8 +128,17 @@ def replicate( torch._C._log_api_usage_once("torch.distributed._composable.replicate_with_fsdp") _validate_module(module) mesh = mesh or _init_default_mesh(mesh_dim_names=("replicate",)) - _validate_mesh(mesh) - mesh_info = DDPMeshInfo(mesh, replicate_mesh_dim=0) + if dp_mesh_dims is not None: + _validate_mesh_common(mesh, dp_mesh_dims) + mesh_info = _get_mesh_info(mesh, dp_mesh_dims) + if not isinstance(mesh_info, DDPMeshInfo): + raise ValueError( + "replicate() with dp_mesh_dims requires replicate-only " + "dims (no shard dims). Use fully_shard() for sharding." + ) + else: + _validate_mesh(mesh) + mesh_info = DDPMeshInfo(mesh, replicate_mesh_dim=0) device = _get_device_from_mesh(mesh) # managed_modules (3rd return) and buffers (5th return) are unused: # - managed_modules: FSDP uses this to set Dynamo-specific attributes diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 71b6cd2c2b1db..6748163c0533e 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -101,7 +101,7 @@ def is_torchdynamo_compiling(): # type: ignore[misc] ) -from torch._utils import _chunk_or_narrow_cat # noqa: F401 +from torch._utils import _chunk_or_narrow_cat """ @@ -151,8 +151,10 @@ def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""): group (ProcessGroup or List[int]): The process group to work on. tag (str, optional): A unique identifier for the collective. Default: empty string """ - group_name = _resolve_group_name(group, tag) - tensor = torch.ops._c10d_functional.broadcast(self, src, group_name) + group = _resolve_group(group, tag) + tensor = torch.ops._c10d_functional.broadcast( + self, src, _group_or_group_name(group) + ) return _maybe_wrap_tensor(tensor) @@ -173,8 +175,10 @@ def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover that information and perform collective algebraic optimization. Use other forms of input for that. """ - group_name = _resolve_group_name(group, tag) - tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name) + group = _resolve_group(group, tag) + tensor = torch.ops._c10d_functional.all_reduce( + self, reduceOp.lower(), _group_or_group_name(group) + ) return _maybe_wrap_tensor(tensor) @@ -200,12 +204,10 @@ def all_gather_tensor( :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover that information and perform collective algebraic optimization. Use other forms of input for that. """ - if not self.is_contiguous(): - raise AssertionError("Tensor must be contiguous for all_gather_tensor") - group_name = _resolve_group_name(group, tag) - group_size = c10d._get_group_size_by_name(group_name) + group = _resolve_group(group, tag) + group_size = c10d._get_group_size_by_name(group) tensor = torch.ops._c10d_functional.all_gather_into_tensor( - self, group_size, group_name + self, group_size, _group_or_group_name(group) ) res = _maybe_wrap_tensor(tensor) if gather_dim != 0: @@ -238,11 +240,11 @@ def all_gather_tensor_autograd( See all_gather_tensor for more details on usage. """ - group_name = _resolve_group_name(group, tag) - group_size = c10d._get_group_size_by_name(group_name) + group = _resolve_group(group, tag) + group_size = c10d._get_group_size_by_name(group) tensor = torch.ops._c10d_functional_autograd.all_gather_into_tensor( - self, group_size, group_name + self, group_size, _group_or_group_name(group) ) res = _FromTorchTensor.apply(tensor) if gather_dim != 0: @@ -281,8 +283,8 @@ def reduce_scatter_tensor( :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover that information and perform collective algebraic optimization. Use other forms of input for that. """ - group_name = _resolve_group_name(group, tag) - group_size = c10d._get_group_size_by_name(group_name) + group = _resolve_group(group, tag) + group_size = c10d._get_group_size_by_name(group) if self.size(scatter_dim) % group_size != 0: raise AssertionError( @@ -295,7 +297,7 @@ def reduce_scatter_tensor( self, reduceOp.lower(), group_size, - group_name, # type: ignore[possibly-undefined] + _group_or_group_name(group), ) res = _maybe_wrap_tensor(tensor) return res @@ -320,8 +322,8 @@ def reduce_scatter_tensor_autograd( See reduce_scatter_tensor for more details on usage. """ - group_name = _resolve_group_name(group, tag) - group_size = c10d._get_group_size_by_name(group_name) + group = _resolve_group(group, tag) + group_size = c10d._get_group_size_by_name(group) if self.size(scatter_dim) % group_size != 0: raise AssertionError( @@ -334,7 +336,7 @@ def reduce_scatter_tensor_autograd( self, reduceOp.lower(), group_size, - group_name, # type: ignore[possibly-undefined] + _group_or_group_name(group), ) res = _FromTorchTensor.apply(tensor) return res @@ -359,11 +361,11 @@ def all_reduce_coalesced( :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover that information and perform collective algebraic optimization. Use other forms of input for that. """ - group_name = _resolve_group_name(group, tag) + group = _resolve_group(group, tag) tensor_list = torch.ops._c10d_functional.all_reduce_coalesced( # type: ignore[attr-defined] self, reduceOp.lower(), - group_name, + _group_or_group_name(group), ) return list(map(_maybe_wrap_tensor, tensor_list)) @@ -387,12 +389,12 @@ def all_gather_into_tensor_coalesced( :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover that information and perform collective algebraic optimization. Use other forms of input for that. """ - group_name = _resolve_group_name(group, tag) - group_size = c10d._get_group_size_by_name(group_name) + group = _resolve_group(group, tag) + group_size = c10d._get_group_size_by_name(group) tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced( # type: ignore[attr-defined] self, group_size, - group_name, + _group_or_group_name(group), ) return list(map(_maybe_wrap_tensor, tensor_list)) @@ -419,8 +421,8 @@ def reduce_scatter_tensor_coalesced( :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover that information and perform collective algebraic optimization. Use other forms of input for that. """ - group_name = _resolve_group_name(group, tag) - group_size = c10d._get_group_size_by_name(group_name) + group = _resolve_group(group, tag) + group_size = c10d._get_group_size_by_name(group) if len(scatter_dim) != len(inputs): raise AssertionError( @@ -439,7 +441,7 @@ def reduce_scatter_tensor_coalesced( inputs, reduceOp.lower(), group_size, - group_name, # type: ignore[possibly-undefined] + _group_or_group_name(group), ) return list(map(_maybe_wrap_tensor, tensor_list)) @@ -497,8 +499,8 @@ def all_to_all_single( raise AssertionError( f"All input_split_sizes must be int or SymInt, got {input_split_sizes}" ) - group_name = _resolve_group_name(group, tag) - group_size = c10d._get_group_size_by_name(group_name) + group = _resolve_group(group, tag) + group_size = c10d._get_group_size_by_name(group) if output_split_sizes is None or input_split_sizes is None: if not (output_split_sizes is None and input_split_sizes is None): raise AssertionError( @@ -511,7 +513,7 @@ def all_to_all_single( self, output_split_sizes, input_split_sizes, - group_name, + _group_or_group_name(group), ) return _maybe_wrap_tensor(tensor) @@ -539,8 +541,8 @@ def all_to_all_single_autograd( f"All input_split_sizes must be int or SymInt, got {input_split_sizes}" ) - group_name = _resolve_group_name(group, tag) - group_size = c10d._get_group_size_by_name(group_name) + group = _resolve_group(group, tag) + group_size = c10d._get_group_size_by_name(group) if output_split_sizes is None or input_split_sizes is None: if not (output_split_sizes is None and input_split_sizes is None): raise AssertionError( @@ -553,7 +555,7 @@ def all_to_all_single_autograd( self, output_split_sizes, input_split_sizes, - group_name, + _group_or_group_name(group), ) return _FromTorchTensor.apply(tensor) @@ -1187,26 +1189,31 @@ def cast_listint(x): return (tag, rankset, group_size) -def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> c10d.GroupName: +def _resolve_group( + group: RANK_TYPES, tag: str = "" +) -> dist.ProcessGroup | c10d.GroupName: """ - Given group in RANK_TYPES, return the group name. + Given group in RANK_TYPES, return a ProcessGroup or group name. """ # `tag` will be deprecated. See details in: # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208 if isinstance(group, dist.ProcessGroup): - return group.group_name + return group elif isinstance(group, str): # In some cases Dynamo doesn't like tracing through NewType constructors # - so use a cast instead (the actual newtype representation is # literally the underlying type so this is fine). I haven't been able to # reproduce it in isolation (see T247631668). # pyrefly: ignore [redundant-cast] - return cast(c10d.GroupName, group) # c10d.GroupName(group) + group_name = cast(c10d.GroupName, group) # c10d.GroupName(group) + return group_name elif isinstance(group, DeviceMesh): if group.ndim != 1: raise AssertionError( "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D" ) + if dist.config.compile_on_one_rank: + return torch.ops._dtensor.mesh_get_process_group(group, 0) return group._dim_group_names[0] elif isinstance(group, tuple): if ( @@ -1216,6 +1223,8 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> c10d.GroupName: ): dmesh = group[0] dim = group[1] + if dist.config.compile_on_one_rank: + return torch.ops._dtensor.mesh_get_process_group(dmesh, dim) return dmesh._dim_group_names[dim] else: raise ValueError( @@ -1230,11 +1239,26 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> c10d.GroupName: FutureWarning, stacklevel=3, ) - return c10d._resolve_group_name_by_ranks_and_tag(cast(list[int], group), tag) + return c10d._resolve_group_name_by_ranks_and_tag( + # pyrefly: ignore [redundant-cast] + cast(list[int], group), + tag, + ) else: raise ValueError(f"Unsupported group type: {type(group)}, {group}") +def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> c10d.GroupName: + """ + Given group in RANK_TYPES, return the group name. + """ + group = _resolve_group(group, tag) + if isinstance(group, str): + return c10d.GroupName(group) + else: + return group.group_name + + class _FromTorchTensor(torch.autograd.Function): """ _FromTorchTensor allows autograd to propagate from a normal Tensor to an @@ -1570,7 +1594,7 @@ def _reduce_scatter_tensor_coalesced_native_meta( "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]", "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor", "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]", - "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", # noqa: B950 + "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor", ] my_module = sys.modules[__name__] @@ -1792,7 +1816,8 @@ def batch_p2p_ops_inplace( raise AssertionError("torch.distributed must be initialized") if group_name is None or group_name == "": group_name = c10d._get_default_group() - group_name = _resolve_group_name(group_name) + resolved = _resolve_group(group_name) + group_name = resolved if isinstance(resolved, str) else resolved.group_name tensors = torch.ops._c10d_functional.batch_p2p_ops( op_list, peer_list, tag_list, tensors, group_name ) @@ -1804,6 +1829,17 @@ def batch_p2p_ops_inplace( return list(map(_maybe_wrap_tensor, tensors)) +def _group_or_group_name( + group: dist.ProcessGroup | c10d.GroupName, +) -> dist.ProcessGroup | c10d.GroupName: + if isinstance(group, str): + return group + elif dist.config.compile_on_one_rank: + return group + else: + return group.group_name + + from torch.distributed.distributed_c10d import ( # pyrefly: ignore # deprecated; pyrefly: ignore [deprecated] _all_gather_base as legacy_all_gather_base, _reduce_scatter_base as legacy_reduce_scatter_base, diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index a960829ac8882..7326ed539348b 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -1891,7 +1891,7 @@ def __init__( def run(self): try: super().run() - except BaseException as e: # noqa: B036 + except BaseException as e: self.exception = e def join(self, timeout=None): diff --git a/torch/distributed/_ops/device_mesh.py b/torch/distributed/_ops/device_mesh.py index fb431cad79efc..61d4312c3bb04 100644 --- a/torch/distributed/_ops/device_mesh.py +++ b/torch/distributed/_ops/device_mesh.py @@ -62,3 +62,44 @@ def _runtime_compute_coordinate_on_dim_impl(full_mesh: torch.Tensor, index: int) if mesh_coords is None: raise AssertionError return mesh_coords[index] + + +def _get_flattened_submesh_impl(mesh: DeviceMesh, mesh_dims: list[int]) -> DeviceMesh: + from torch.distributed.tensor._redistribute import ( + _get_flattened_mesh_by_layout_impl, + ) + + result = _get_flattened_mesh_by_layout_impl(mesh, tuple(mesh_dims)) + if result is None: + raise ValueError(f"No flattened mesh found for mesh_dims={mesh_dims} on {mesh}") + return result + + +@torch.library.custom_op("device_mesh::_get_flattened_submesh", mutates_args=()) +def _get_flattened_submesh(mesh: DeviceMesh, mesh_dims: list[int]) -> DeviceMesh: + return _get_flattened_submesh_impl(mesh, mesh_dims) + + +@_get_flattened_submesh.register_fake +def _get_flattened_submesh_fake(mesh: DeviceMesh, mesh_dims: list[int]) -> DeviceMesh: + return _get_flattened_submesh_impl(mesh, mesh_dims) + + +def _get_submesh_impl(mesh: DeviceMesh, mesh_dims: list[int]) -> DeviceMesh: + all_dim_names = mesh._mesh_dim_names + if all_dim_names is None: + raise ValueError(f"Cannot slice mesh without dim names: {mesh}") + dim_names = tuple(all_dim_names[i] for i in mesh_dims) + if len(dim_names) == 1: + return mesh[dim_names[0]] + return mesh[dim_names] + + +@torch.library.custom_op("device_mesh::_get_submesh", mutates_args=()) +def _get_submesh(mesh: DeviceMesh, mesh_dims: list[int]) -> DeviceMesh: + return _get_submesh_impl(mesh, mesh_dims) + + +@_get_submesh.register_fake +def _get_submesh_fake(mesh: DeviceMesh, mesh_dims: list[int]) -> DeviceMesh: + return _get_submesh_impl(mesh, mesh_dims) diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py index 3d3af3ed35953..b89c4826957e2 100644 --- a/torch/distributed/_shard/sharded_tensor/__init__.py +++ b/torch/distributed/_shard/sharded_tensor/__init__.py @@ -14,7 +14,7 @@ ShardedTensorMetadata, TensorProperties, ) -from .metadata import ShardMetadata # noqa: F401 +from .metadata import ShardMetadata if TYPE_CHECKING: diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index 7052519061130..1a6c19b85385e 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -2008,7 +2008,9 @@ def is_nvshmem_available() -> bool: r""" is_nvshmem_available() -> bool - Check if NVSHMEM is available in current build and on current system. + Check if NVSHMEM (CUDA) or rocSHMEM (ROCm) is available in the current + build and usable at runtime. On ROCm, rocSHMEM ``VERSION`` must be at + least 3.3.0 (see ``rocshmem/rocshmem.hpp``). """ try: from torch._C._distributed_c10d import _is_nvshmem_available @@ -2195,8 +2197,100 @@ def wait_signal(hdl: _SymmetricMemory, peer: int) -> None: raise ValueError(f"wait_signal: unsupported backend: {backend}") +def reduce_scatter_offset( + input: torch.Tensor, + out: list[torch.Tensor], + group: str, + *, + dim: int, + offsets: list[int] | None = None, + dst_ranks: list[int] | None = None, + red_op: str = "sum", +) -> None: + r""" + reduce_scatter_offset(input, out, group, *, dim, offsets, dst_ranks, red_op='sum') -> None + + Simultaneously reduce N blocks of a 2-D ``input`` tensor from a symmetric + memory buffer, routing each block to a specific destination rank. Only + ``dst_ranks[i]`` writes the reduced result for block ``i``; the result is + written to a contiguous output tensor, with the same shape as block ``i``. + + The ``dim`` argument controls which dimension is sharded: + + - ``dim=0`` (row sharding): block ``i`` spans + ``input[offsets[i-1] : offsets[i], :]``. Each ``out[j]`` has shape + ``(size_j, input.size(1))``. + - ``dim=1`` (column sharding): block ``i`` spans + ``input[:, offsets[i-1] : offsets[i]]``. Each ``out[j]`` has shape + ``(input.size(0), size_j)``. + + Blocks are described by ``offsets``, an inclusive prefix-sum of block sizes + along ``dim`` (first block starts at index 0 by convention). Block offsets + can be even or uneven; when uneven, the following condition must be met: for + each ``j``, the ``j``-th owned block must have the same size across all + ranks (so that ``out[j]`` has a uniform shape); different ``j``'s may + differ. + + Args: + input (Tensor): 2-D tensor allocated via symmetric memory (innermost + dimension must be contiguous). + out (list[Tensor]): Output tensors for this rank's owned blocks. Must + have length equal to the number of blocks owned by this rank (i.e. + the count of ``i`` where ``dst_ranks[i] == my_rank``). Each + ``out[j]`` must be contiguous with the same dtype as ``input``. + group (str): The name of the ``ProcessGroup`` to perform the operation on. + dim (int): Dimension along which blocks are defined (0 or 1). + offsets (list[int] | None): Inclusive prefix-sum of block sizes along + ``dim``, length N. If not provided, ``input.size(dim)`` is divided + into equal-size blocks based on the size of the ``group``. + dst_ranks (list[int] | None): Destination rank for each block. If not + provided, blocks are distributed round-robin across ranks. + red_op (str): Reduction operation; currently only ``'sum'`` is supported. + + Example:: + + >>> # doctest: +SKIP + >>> # Each rank holds a Grouped GEMM gradient buffer in symmetric memory. + >>> # The buffer has W experts laid out as equal column blocks; each expert + >>> # is reduced to a specific rank (dst_ranks[i] == i % world_size). + >>> buf = symm_mem.empty(H, W * C, dtype=torch.bfloat16, device="cuda") + >>> symm_mem.rendezvous(buf, group=group_name) + >>> offsets = [i * C for i in range(1, W + 1)] # inclusive prefix-sum + >>> dst_ranks = [i % world_size for i in range(W)] + >>> n_owned = sum(r == rank for r in dst_ranks) + >>> out = [torch.empty(H, C, dtype=torch.bfloat16, device="cuda") for _ in range(n_owned)] + >>> symm_mem.reduce_scatter_offset(buf, out, group_name, dim=1, offsets=offsets, dst_ranks=dst_ranks) + """ + backend = get_backend(input.device) + if backend == "NCCL": + torch.ops.symm_mem.nccl_reduce_scatter_offset( + input, out, group, dim, offsets, dst_ranks, red_op + ) + else: + raise NotImplementedError( + f"reduce_scatter_offset: unsupported backend: {backend}" + ) + + +def is_symm_mem_tensor(tensor: torch.Tensor) -> bool: + r""" + is_symm_mem_tensor(tensor) -> bool + + Returns ``True`` if ``tensor`` was allocated via symmetric memory + (i.e. via :func:`torch.distributed._symmetric_memory.empty` or + :meth:`_SymmetricMemory.empty_strided_p2p`). + + This is a non-collective, O(1) check. + + Args: + tensor (:class:`torch.Tensor`): the tensor to check. + """ + return _SymmetricMemory.is_symm_mem_tensor(tensor) + + __all__ = [ "empty", + "is_symm_mem_tensor", "rendezvous", "is_nvshmem_available", "set_backend", @@ -2204,4 +2298,5 @@ def wait_signal(hdl: _SymmetricMemory, peer: int) -> None: "set_signal_pad_size", "get_signal_pad_size", "get_mem_pool", + "reduce_scatter_offset", ] diff --git a/torch/distributed/_tensor/__init__.py b/torch/distributed/_tensor/__init__.py index c5559cc10fabd..272a696d69fdb 100644 --- a/torch/distributed/_tensor/__init__.py +++ b/torch/distributed/_tensor/__init__.py @@ -26,7 +26,7 @@ full_module_name ) -from torch.distributed.tensor import ( # noqa: F401 +from torch.distributed.tensor import ( DeviceMesh, distribute_module, distribute_tensor, diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py index 9e5742156a86c..621c700e77ff9 100644 --- a/torch/distributed/_tensor/api.py +++ b/torch/distributed/_tensor/api.py @@ -6,4 +6,4 @@ TODO: throw warnings when this module imported """ -from torch.distributed.tensor._api import * # noqa: F401, F403 +from torch.distributed.tensor._api import * # noqa: F403 diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py index 6a4e70dbba455..e7dac1b76b52b 100644 --- a/torch/distributed/_tensor/placement_types.py +++ b/torch/distributed/_tensor/placement_types.py @@ -6,5 +6,5 @@ TODO: throw warnings when this module imported """ -from torch.distributed.tensor._dtensor_spec import * # noqa: F401, F403 -from torch.distributed.tensor.placement_types import * # noqa: F401, F403 +from torch.distributed.tensor._dtensor_spec import * # noqa: F403 +from torch.distributed.tensor.placement_types import * # noqa: F403 diff --git a/torch/distributed/_tools/fsdp2_mem_tracker.py b/torch/distributed/_tools/fsdp2_mem_tracker.py index 83d07940a266a..1a10223f03ae4 100644 --- a/torch/distributed/_tools/fsdp2_mem_tracker.py +++ b/torch/distributed/_tools/fsdp2_mem_tracker.py @@ -257,6 +257,7 @@ def inner( state = _FSDPModState.AFT_PRE_FW mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) self._fsdp_state = _FSDPState.FW + # pyrefly: ignore [bad-return] return args, kwargs return inner diff --git a/torch/distributed/_tools/runtime_estimator.py b/torch/distributed/_tools/runtime_estimator.py index 806bdc6911c67..cd9effbab558d 100644 --- a/torch/distributed/_tools/runtime_estimator.py +++ b/torch/distributed/_tools/runtime_estimator.py @@ -87,7 +87,7 @@ def __init__(self) -> None: self.mod_bw_post_order: list[str] = [] self.total_runtime: float = 0.0 - # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # noqa: PGH004,B950 + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # NB: returns fake tensors @classmethod def _maybe_run_and_benchmark_fallback_kernel( # type: ignore[no-untyped-def] @@ -217,7 +217,7 @@ def _benchmark_estimate(cls, func, args, kwargs) -> tuple[Any, float]: # type: res = func(*args, **kwargs or {}) return (res, mean_op_time) - # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 # noqa: PGH004,B950 + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 @classmethod def _roofline_estimate(cls, func, args, kwargs) -> tuple[Any, float]: # type: ignore[no-untyped-def] """ diff --git a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py index 2e3514edef063..5c5d4cd63be50 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py @@ -94,7 +94,7 @@ def _should_compress( uncompressed_el_count is the uncompressed element count, i.e. ``num_rows`` * ``num_cols``; and, compress_el_count is the element count after compression, i.e. (``num_rows`` + ``num_cols``) * ``matrix_approximation_rank``. - """ # noqa: B950 + """ uncompressed_size = num_rows * num_cols compressed_size = (num_rows + num_cols) * matrix_approximation_rank return ( @@ -150,7 +150,7 @@ class PowerSGDState: If error feedback or warm-up is enabled, the minimum value of ``start_powerSGD_iter`` allowed in DDP is 2. This is because there is another internal optimization that rebuilds buckets at iteration 1 in DDP, and this can conflict with any tensor memorized before the rebuild process. - """ # noqa: B950 + """ __slots__ = [ "process_group", @@ -324,7 +324,7 @@ def compression_stats(self): numel_before_compression is the total number of elements before compression was applied; and, numel_after_compression is the total number of elements after compression was applied. - """ # noqa: B950 + """ compress_rate = ( self.total_numel_before_compression / self.total_numel_after_compression if self.total_numel_after_compression > 0 @@ -397,7 +397,7 @@ def powerSGD_hook( >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1, start_powerSGD_iter=10, min_compression_rate=0.5) >>> ddp_model.register_comm_hook(state, powerSGD_hook) - """ # noqa: B950 + """ process_group = state.process_group group_to_use = ( process_group if process_group is not None else not_none(dist.group.WORLD) @@ -710,7 +710,7 @@ def batched_powerSGD_hook( >>> # xdoctest: +SKIP >>> state = PowerSGDState(process_group=process_group, matrix_approximation_rank=1) >>> ddp_model.register_comm_hook(state, batched_powerSGD_hook) - """ # noqa: B950 + """ process_group = state.process_group group_to_use = ( process_group if process_group is not None else not_none(dist.group.WORLD) diff --git a/torch/distributed/checkpoint/_async_process_executor.py b/torch/distributed/checkpoint/_async_process_executor.py index 110ea622ca4de..fd90399d8a5f1 100644 --- a/torch/distributed/checkpoint/_async_process_executor.py +++ b/torch/distributed/checkpoint/_async_process_executor.py @@ -275,7 +275,7 @@ def _checkpointing_subprocess( # GC can optionally be called manually after each checkpoint gc.disable() logger.info("Disabled automatic garbage collection") - except BaseException as e: # noqa: B036 + except BaseException as e: logger.error( f"Checkpoint background process failed during initialization: {e}" # noqa: G004 ) @@ -337,7 +337,7 @@ def _checkpointing_subprocess( ) gc.freeze() first_request = False - except BaseException as e: # noqa: B036 + except BaseException as e: logger.error( f"Checkpoint save failed for checkpoint_id={obj.checkpoint_request_id.checkpoint_id}: {e}" # noqa: G004 ) diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py index e9bf297618299..2014a97375b90 100644 --- a/torch/distributed/checkpoint/_fsspec_filesystem.py +++ b/torch/distributed/checkpoint/_fsspec_filesystem.py @@ -47,11 +47,11 @@ def create_stream( with self.fs.open(path, mode) as stream: try: yield stream - except: # noqa: B001,E722 + except: if any(ch in mode for ch in "w+a"): # cleanup file if not read-only try: self.rm_file(path) - except: # noqa: B001,E722 + except: # noqa: E722 pass raise diff --git a/torch/distributed/checkpoint/_traverse.py b/torch/distributed/checkpoint/_traverse.py index 8b5807fcca8b7..cf0c00a48dda9 100644 --- a/torch/distributed/checkpoint/_traverse.py +++ b/torch/distributed/checkpoint/_traverse.py @@ -3,6 +3,7 @@ # flake8: noqa: F821 from collections.abc import Callable, Collection, Mapping, MutableMapping from typing import cast, TypeVar +from typing_extensions import TypeIs import torch from torch.distributed._shard.sharded_tensor.api import ShardedTensor @@ -20,7 +21,7 @@ __all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"] -def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool: +def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> TypeIs[torch.Tensor]: return isinstance(value, torch.Tensor) diff --git a/torch/distributed/checkpoint/api.py b/torch/distributed/checkpoint/api.py index 4aa4854db2358..4d9817f235ca4 100644 --- a/torch/distributed/checkpoint/api.py +++ b/torch/distributed/checkpoint/api.py @@ -8,7 +8,15 @@ def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION: - return (exc, tb.extract_tb(exc.__traceback__)) + summary = tb.extract_tb(exc.__traceback__) + # Python 3.13+ stores bytecode objects in FrameSummary._code, + # which cannot be pickled. Clear them so gather_object succeeds + # and the real exception is reported instead of a misleading + # "cannot pickle code objects" TypeError. + for frame in summary: + if hasattr(frame, "_code"): + object.__setattr__(frame, "_code", None) + return (exc, summary) def _is_wrapped_exception(obj: Any) -> bool: diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py index d693790356a9d..831fa16394563 100644 --- a/torch/distributed/checkpoint/filesystem.py +++ b/torch/distributed/checkpoint/filesystem.py @@ -756,7 +756,7 @@ def _write_data( return fut def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: - metadata = dataclasses.replace(metadata, version=CURRENT_DCP_VERSION) + metadata.version = CURRENT_DCP_VERSION storage_md = {} for wr_list in results: diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index ed3d4622eace4..7d4428afa25db 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -844,6 +844,7 @@ def _reconstruct_nested_dict( continue # Reconstruct state for this parameter + # pyrefly: ignore [unsupported-operation] state[fqn] = {} for state_name in optim.state[param]: flattened_state_key = f"{_STATE}.{fqn}.{state_name}" @@ -853,11 +854,13 @@ def _reconstruct_nested_dict( reconstructed_value = _reconstruct_nested_dict( flattened_state_key, state_dict ) + # pyrefly: ignore [bad-index] cast(DictValueType, state[fqn])[state_name] = ( reconstructed_value ) else: # Existing keys mean no nesting, directly use the value. + # pyrefly: ignore [bad-index] cast(DictValueType, state[fqn])[state_name] = state_dict[ flattened_state_key ] diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py index 62acd1bfef359..a3e6586ecceb3 100644 --- a/torch/distributed/checkpoint/state_dict_loader.py +++ b/torch/distributed/checkpoint/state_dict_loader.py @@ -228,11 +228,15 @@ def local_step(): nonlocal metadata # Use global metadata if available, otherwise fallback to rank local metadata + global_metadata_exc: Exception | None = None + rank_metadata_exc: Exception | None = None try: metadata = storage_reader.read_metadata() - except Exception: - logger.info( - "Global metadata is not found. Falling back to rank local metadata." + except Exception as e: + global_metadata_exc = e + logger.warning( + "Global metadata is not found. Falling back to rank local metadata.", + exc_info=True, ) if ( @@ -240,15 +244,25 @@ def local_step(): and "kwargs" in inspect.signature(storage_reader.read_metadata).parameters ): try: - metadata = storage_reader.read_metadata(rank=distW.rank) # noqa: F841 + metadata = storage_reader.read_metadata(rank=distW.rank) use_collectives = False - except Exception: - logger.info("Rank local metadata is not found.") + except Exception as e: + rank_metadata_exc = e + logger.warning("Rank local metadata is not found.", exc_info=True) if planner is None: raise AssertionError("planner is None") if metadata is None: - raise AssertionError("metadata is None") + error_parts = ["metadata is None"] + if global_metadata_exc is not None: + error_parts.append( + f"global metadata read failed: {global_metadata_exc}" + ) + if rank_metadata_exc is not None: + error_parts.append( + f"rank local metadata read failed: {rank_metadata_exc}" + ) + raise AssertionError("; ".join(error_parts)) planner.set_up_planner(state_dict, metadata, distW.is_coordinator) if ( diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index d5e19eee111a0..8ac7097bf7501 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -191,7 +191,7 @@ def reduce_scatter( local_data: WRAPPED_EXCEPTION | T try: local_data = map_fun() - except BaseException as e: # noqa: B036 + except BaseException as e: local_data = _wrap_exception(e) all_data = self.gather_object(local_data) @@ -208,7 +208,7 @@ def reduce_scatter( list[R | CheckpointException], reduce_fun(cast(list[T], all_data)), ) - except BaseException as e: # noqa: B036 + except BaseException as e: node_failures[self.rank] = _wrap_exception(e) if len(node_failures) > 0: @@ -239,7 +239,7 @@ def all_reduce( local_data: T | WRAPPED_EXCEPTION try: local_data = map_fun() - except BaseException as e: # noqa: B036 + except BaseException as e: local_data = _wrap_exception(e) all_data = self.gather_object(local_data) @@ -251,7 +251,7 @@ def all_reduce( if len(node_failures) == 0: try: result = reduce_fun(cast(list[T], all_data)) - except BaseException as e: # noqa: B036 + except BaseException as e: node_failures[self.rank] = _wrap_exception(e) if len(node_failures) > 0: @@ -261,6 +261,7 @@ def all_reduce( final_result = self.broadcast_object(result) if isinstance(final_result, CheckpointException): raise final_result + # pyrefly: ignore [redundant-cast] return cast(R, final_result) def all_gather( @@ -278,7 +279,7 @@ def all_gather( result: T | WRAPPED_EXCEPTION try: result = map_fun() - except BaseException as e: # noqa: B036 + except BaseException as e: result = _wrap_exception(e) all_results = self.all_gather_object(result) @@ -304,12 +305,13 @@ def broadcast( if self.is_coordinator: try: result = map_fun() - except BaseException as e: # noqa: B036 + except BaseException as e: result = CheckpointException(step, {self.rank: _wrap_exception(e)}) # pyrefly: ignore [bad-argument-type] final_result = self.broadcast_object(result) if isinstance(final_result, CheckpointException): raise final_result + # pyrefly: ignore [redundant-cast] return cast(T, final_result) def barrier(self) -> None: diff --git a/torch/distributed/config.py b/torch/distributed/config.py index 50587c4162880..e92fd0d9577c6 100644 --- a/torch/distributed/config.py +++ b/torch/distributed/config.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: - from torch.utils._config_typing import * # noqa: F401, F403 + from torch.utils._config_typing import * # noqa: F403 # adds patch, save_config, invalid config checks, etc diff --git a/torch/distributed/debug/__init__.py b/torch/distributed/debug/__init__.py index a5a66db1ef90a..dc8e051b3f652 100644 --- a/torch/distributed/debug/__init__.py +++ b/torch/distributed/debug/__init__.py @@ -4,7 +4,7 @@ from typing import Literal, TYPE_CHECKING # import for registration side effect -import torch.distributed.debug._handlers # noqa: F401 +import torch.distributed.debug._handlers from torch._C._distributed_c10d import _WorkerServer from torch.distributed.debug._store import get_rank, tcpstore_client diff --git a/torch/distributed/debug/_debug_handlers.py b/torch/distributed/debug/_debug_handlers.py index f63c0848d7e15..902f78e2e863f 100644 --- a/torch/distributed/debug/_debug_handlers.py +++ b/torch/distributed/debug/_debug_handlers.py @@ -228,9 +228,10 @@ def templates(self) -> dict[str, str]: return {"pyspy_dump.html": PYSPY_DUMP_TEMPLATE} def _handle(self, req: HTTPRequestHandler) -> bytes: - addrs, resps = fetch_all( - "pyspy_dump", req.get_raw_query(), timeout=self.fetch_timeout - ) + query = req.get_raw_query() + if "nonblocking" not in query: + query = f"nonblocking=1&{query}" if query else "nonblocking=1" + addrs, resps = fetch_all("pyspy_dump", query, timeout=self.fetch_timeout) return req.frontend.render_template( "pyspy_dump.html", addrs=addrs, @@ -238,7 +239,9 @@ def _handle(self, req: HTTPRequestHandler) -> bytes: ) def dump(self) -> str | None: - addrs, resps = fetch_all("pyspy_dump", timeout=self.fetch_timeout) + addrs, resps = fetch_all( + "pyspy_dump", "nonblocking=1", timeout=self.fetch_timeout + ) parts: list[str] = [] summary = format_fetch_summary(addrs, resps) if summary: diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 895ba0e17bb31..99b41eb2ad529 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -4,14 +4,16 @@ import os import threading import warnings -from collections.abc import Iterator +from collections.abc import Callable, Iterator from itertools import zip_longest -from typing import TYPE_CHECKING +from typing import Any, TYPE_CHECKING import torch +from torch._opaque_base import OpaqueBase from torch.distributed import is_available from torch.distributed._mesh_layout import _MeshLayout from torch.distributed._pycute import IntTuple, is_int, suffix_product +from torch.types import IntLikeType from torch.utils._typing_utils import not_none @@ -148,7 +150,7 @@ def _get_device_handle(device_type: str = "cuda"): """ return getattr(torch, device_type, None) - class DeviceMesh(torch._opaque_base.OpaqueBase): + class DeviceMesh(OpaqueBase): """ DeviceMesh represents a mesh of devices, where layout of devices could be represented as a n-d dimension array, and each value of the n-d dimensional @@ -202,6 +204,7 @@ class DeviceMesh(torch._opaque_base.OpaqueBase): _mesh_dim_names: tuple[str, ...] | None _layout: _MeshLayout _root_mesh: "DeviceMesh | None" = None + _thread_id: int | None # Record flatten mesh name to its flattened mesh in root mesh. _flatten_mapping: dict[str, "DeviceMesh"] # Registry mapping group names to ProcessGroup objects (to avoid C++ lookup) @@ -546,10 +549,12 @@ def _init_one_process_group( getattr(default_group, "bound_device_id", None) is not None or dist_config.use_torchcomms ) - and torch.cuda.is_available() + and torch.accelerator.is_available() and ( backend is None - or default_group._get_backend(torch.device("cuda")).name() + or default_group._get_backend( + torch.accelerator.current_accelerator() # pyrefly: ignore[bad-argument-type] + ).name() == backend ) ): @@ -643,19 +648,21 @@ def __repr__(self) -> str: device_mesh_repr += f", Mesh: {self.mesh.tolist()}" return f"{device_mesh_repr})" + def _hash_key(self) -> tuple[Any, ...]: + """Return the tuple used for hashing. Used by both __hash__ and _stable_hash.""" + return ( + self._flatten_rank_map, + self._layout, + self._device_type, + self._mesh_dim_names, + self._thread_id, + ) + def __hash__(self): # lazily compute hash self._hash = getattr(self, "_hash", None) if not self._hash: - self._hash = hash( - ( - self._flatten_rank_map, - self._layout, - self._device_type, - self._mesh_dim_names, - self._thread_id, - ) - ) + self._hash = hash(self._hash_key()) return self._hash def __eq__(self, other: object) -> bool: @@ -671,6 +678,17 @@ def __eq__(self, other: object) -> bool: and self._thread_id == other._thread_id ) + def _stable_hash(self) -> str: + """ + Return a stable hash for AOT autograd caching. + [See note: Tensor subclass stable hashing for AOT autograd cache] + """ + import hashlib + + return hashlib.blake2b( + repr(self._hash_key()).encode(), digest_size=16 + ).hexdigest() + def __getitem__(self, mesh_dim_names: str | tuple[str, ...]) -> "DeviceMesh": """ Slice the current DeviceMesh based on the mesh_dim_names given to create a submesh. @@ -1220,11 +1238,15 @@ def get_coordinate(self) -> tuple[int, ...] | None: """ return self._coordinate_on_dim - def _sym_get_coordinate(self, index: int) -> int: + def _sym_get_coordinate(self, index: int) -> IntLikeType: import torch.distributed.config as config from torch._guards import detect_fake_mode - if not detect_fake_mode() or not config.compile_on_one_rank: + if ( + not config.compile_on_one_rank + or not (fake_mode := detect_fake_mode()) + or not fake_mode.shape_env + ): # This is only valid when the current rank is part of the mesh. if self._coordinate_on_dim is None: raise AssertionError @@ -1574,12 +1596,89 @@ def init_device_mesh( return device_mesh +_distributed_opaque_types_registered = False + + +def _device_mesh_reconstruct_fn( + mesh: "OpaqueBase", + get_tracked_proxy: Callable[["OpaqueBase"], "torch.fx.Proxy | None"], + tracer: Any, +) -> "torch.fx.Proxy | None": + """Reconstruct a DeviceMesh submesh from a tracked ancestor mesh. + + Called by PythonKeyTracer when make_fx encounters a DeviceMesh that isn't + tracked (e.g. a submesh captured by a backward closure). Looks for any + tracked mesh that shares the same root and contains the target dim names, + then emits a call_function node that derives the submesh via _get_submesh. + """ + if not isinstance(mesh, DeviceMesh): + raise AssertionError("DeviceMesh expected") + + root_mesh = mesh._get_root_mesh() + + # Only submeshes can be reconstructed; root meshes must already be tracked. + if mesh is root_mesh: + return None + + dim_names = mesh._mesh_dim_names + if dim_names is None: + return None + + # Ensure the custom ops are registered + from torch.distributed._ops import device_mesh as _dm_ops # noqa: F401 + + # Try the root mesh first (original path). + ancestor_proxy = get_tracked_proxy(root_mesh) + ancestor_dim_names = root_mesh._mesh_dim_names + + # If root isn't tracked, search for any tracked DeviceMesh that shares + # the same root AND contains all our dim names. This handles the case + # where e.g. a concatenated (fsdp, tp) mesh is a graph input (from + # DTensor.__tensor_flatten__) but neither root nor the individual + # submeshes are tracked directly. + if ancestor_proxy is None: + from torch._library.fake_class_registry import FakeScriptObject + + for tracked_obj, proxy in tracer.opaque_tracker.items(): + real_obj = ( + tracked_obj.real_obj + if isinstance(tracked_obj, FakeScriptObject) + else tracked_obj + ) + if not isinstance(real_obj, DeviceMesh) or real_obj is mesh: + continue + if real_obj._get_root_mesh() is not root_mesh: + continue + tracked_dim_names = real_obj._mesh_dim_names + if tracked_dim_names is None: + continue + if all(n in tracked_dim_names for n in dim_names): + ancestor_proxy = proxy + ancestor_dim_names = tracked_dim_names + break + + if ancestor_proxy is None or ancestor_dim_names is None: + return None + + # Convert our dim names to indices into the ancestor mesh's dim names + mesh_dims = [ancestor_dim_names.index(n) for n in dim_names] + + # Dispatch through the custom op with proxy mode active so that + # meta["val"] is set and the result is tracked in opaque_tracker. + return torch.ops.device_mesh._get_submesh(ancestor_proxy, mesh_dims) + + def _register_distributed_opaque_types(): """ Register DeviceMesh as an opaque type for torch.compile. This must happen before any custom ops that use DeviceMesh in their schema. Called lazily to avoid circular import issues. """ + global _distributed_opaque_types_registered + if _distributed_opaque_types_registered: + return + _distributed_opaque_types_registered = True + from torch._library.opaque_object import MemberType, register_opaque_type register_opaque_type( @@ -1590,13 +1689,16 @@ def _register_distributed_opaque_types(): "rank": MemberType.USE_REAL, "_get_backend_name": MemberType.USE_REAL, "group_name": MemberType.USE_REAL, + "group_desc": MemberType.USE_REAL, "__eq__": MemberType.USE_REAL, + "__ne__": MemberType.USE_REAL, }, ) register_opaque_type( DeviceMesh, typ="reference", + reconstruct_fn=_device_mesh_reconstruct_fn, guard_fn=lambda obj: [ obj._flatten_rank_map, obj._layout, @@ -1617,6 +1719,7 @@ def _register_distributed_opaque_types(): "get_coordinate": MemberType.USE_REAL, "get_local_rank": MemberType.USE_REAL, "__eq__": MemberType.USE_REAL, + "__ne__": MemberType.USE_REAL, "ndim": MemberType.USE_REAL, "shape": MemberType.USE_REAL, "mesh_dim_names": MemberType.USE_REAL, diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 75569c0c912db..0e536d434cea4 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -134,6 +134,7 @@ "get_node_local_rank", "split_group", "shrink_group", + "record_comm", ] _MPI_AVAILABLE = True @@ -143,11 +144,15 @@ _XCCL_AVAILABLE = True try: - # pyrefly: ignore [missing-import] - from torchcomms._backend_wrapper import _BackendWrapper + try: + # pyrefly: ignore [missing-import] + from torchcomms._comms import _BackendWrapper + except ImportError: + # pyrefly: ignore [missing-import] + from torchcomms._backend_wrapper import _BackendWrapper # pyrefly: ignore [missing-import] - from torchcomms._comms import new_comm + from torchcomms import new_comm # pyrefly: ignore [missing-import] from torchcomms.hooks import FlightRecorderHook @@ -1016,7 +1021,7 @@ def _store_based_barrier( except RuntimeError as e: worker_count = store.add(store_key, 0) # Print status periodically to keep track. - logger.debug( # noqa: G200 + logger.debug( "Waiting in store based barrier to initialize process group for %s seconds" "rank: %s, key: %s (world_size=%s, num_workers_joined=%s, timeout=%s error=%s)", time.time() - start, @@ -1147,9 +1152,11 @@ def _get_group_size(group: ProcessGroup | None) -> int: return group.size() -def _get_group_size_by_name(group_name: GroupName) -> int: - group = _resolve_process_group(group_name) - return group.size() +def _get_group_size_by_name(group_name: GroupName | ProcessGroup) -> int: + if isinstance(group_name, str): + # pyrefly: ignore[bad-argument-type] # pyrefly bug + group_name = _resolve_process_group(group_name) + return group_name.size() def _resolve_group_name_by_ranks_and_tag(ranks: list[int], tag: str) -> GroupName: @@ -5159,6 +5166,7 @@ def barrier( group: ProcessGroup | None = GroupMember.WORLD, async_op: bool = False, device_ids=None, + timeout: timedelta | None = None, ): """ Synchronize all processes. @@ -5171,6 +5179,8 @@ def barrier( the default process group will be used. async_op (bool, optional): Whether this op should be an async op device_ids ([int], optional): List of device/GPU ids. Only one id is expected. + timeout (datetime.timedelta, optional): Timeout for barrier. + If ``None``, the default process group timeout will be used. Returns: Async work handle, if async_op is set to True. @@ -5191,6 +5201,8 @@ def barrier( opts = BarrierOptions() opts.asyncOp = async_op + if timeout is not None: + opts.timeout = timeout # Detect the accelerator on the machine. If no accelerator is available, it # returns CPU. device = torch._C._get_accelerator() @@ -5451,10 +5463,18 @@ def split_group( ) parent_group_rank = parent_global_to_group_ranks[global_rank] - parent_backend = parent_pg._get_backend(torch.device("cuda")) + + if torch.accelerator.is_available(): + parent_backend = parent_pg._get_backend( + torch.accelerator.current_accelerator() # pyrefly: ignore[bad-argument-type] + ) + else: + raise RuntimeError( + "No backend for the parent process group or its backend does not support splitting" + ) # if the parent backend does not support splitting, raise error - # currently this API only support NCCL backend + # currently this API only support NCCL and XCCL backend if ( not parent_backend or not parent_backend.supports_splitting ) and not _use_torchcomms_enabled(): @@ -5520,7 +5540,16 @@ def split_group( global_ranks_in_my_group = [parent_group_to_global_ranks[rank] for rank in my_group] split_pg.bound_device_id = device_id # type: ignore[union-attr] - split_backend_class = split_pg._get_backend(torch.device("cuda")) + + if torch.accelerator.is_available(): + split_backend_class = split_pg._get_backend( + torch.accelerator.current_accelerator() # pyrefly: ignore[bad-argument-type] + ) + else: + raise RuntimeError( + "No backend for the parent process group or its backend does not support splitting" + ) + if not _use_torchcomms_enabled(): split_backend_class._set_sequence_number_for_group() if split_pg.group_name != group_name: @@ -6553,3 +6582,28 @@ def _update_process_group_global_state( # Standard process group tag _world.tags_to_pg.setdefault(pg_tag, []).append(pg) _world.pg_to_tag[pg] = pg_tag + + +@contextlib.contextmanager +def record_comm(name: str): + """Context manager to set a custom profiling name for communication collectives. + + When active, all c10d collectives issued within this context will use ``name`` + as their profiling title in the Work base class, overriding the default + backend-specific name (e.g. ``nccl:all_reduce``). This works across all + backends without per-backend or per-collective changes. + + Args: + name (str): The profiling name to associate with collectives. + + Example:: + >>> # xdoctest: +SKIP("undefined vars") + >>> with dist.record_comm("FSDP::all_gather (layer1)"): + ... dist.all_gather_into_tensor(output, input, group=pg) + """ + prev = torch._C._distributed_c10d._get_comm_profiling_name() + torch._C._distributed_c10d._set_comm_profiling_name(name) + try: + yield + finally: + torch._C._distributed_c10d._set_comm_profiling_name(prev) diff --git a/torch/distributed/elastic/agent/server/__init__.py b/torch/distributed/elastic/agent/server/__init__.py index 7c0d76131fe40..7758a8b0ff59e 100644 --- a/torch/distributed/elastic/agent/server/__init__.py +++ b/torch/distributed/elastic/agent/server/__init__.py @@ -29,7 +29,7 @@ in the same job) to make a collective decision. """ -from .api import ( # noqa: F401 +from .api import ( ElasticAgent, RunResult, SimpleElasticAgent, diff --git a/torch/distributed/elastic/agent/server/api.py b/torch/distributed/elastic/agent/server/api.py index dd0905eb3f761..396b62c661266 100644 --- a/torch/distributed/elastic/agent/server/api.py +++ b/torch/distributed/elastic/agent/server/api.py @@ -471,6 +471,7 @@ def __init__( self._exit_barrier_timeout = exit_barrier_timeout self._shutdown_timeout = shutdown_timeout self._total_execution_time = 0 + self._in_exit_barrier: bool = False def get_worker_group(self, role: str = DEFAULT_ROLE) -> WorkerGroup: return self._worker_group @@ -747,7 +748,7 @@ def run(self, role: str = DEFAULT_ROLE) -> RunResult: self._record_worker_events(result) return result except RendezvousGracefulExitError as e: - logger.info("Rendezvous gracefully exited: %s", e) # noqa: G200 + logger.info("Rendezvous gracefully exited: %s", e) except SignalException as e: logger.warning("Received %s death signal, shutting down workers", e.sigval) self._shutdown(e.sigval, timeout=self._shutdown_timeout) @@ -902,19 +903,9 @@ def _record_flakiness_metric(self, is_failed: bool = False): put_metric(f"workers.{spec.role}.flakiness", int(flakiness)) - def _pre_invoke_run(self) -> None: - """Hook called before the worker lifecycle loop in ``_invoke_run``. - - Subclasses can override this to perform setup that must happen - before rendezvous and worker initialization (e.g. starting a - health check server). The default implementation is a no-op. - """ - def _invoke_run(self, role: str = DEFAULT_ROLE) -> RunResult: # NOTE: currently only works for a single role - self._pre_invoke_run() - spec = self._worker_group.spec role = spec.role @@ -998,6 +989,7 @@ def _exit_barrier(self): self._exit_barrier_timeout, ) start = time.time() + self._in_exit_barrier = True try: store_util.barrier( store=self._store, @@ -1017,3 +1009,5 @@ def _exit_barrier(self): "Error waiting on exit barrier. Elapsed: %s seconds", time.time() - start, ) + finally: + self._in_exit_barrier = False diff --git a/torch/distributed/elastic/agent/server/health_check_server.py b/torch/distributed/elastic/agent/server/health_check_server.py index 4815d86aa289c..99528e0981c8b 100644 --- a/torch/distributed/elastic/agent/server/health_check_server.py +++ b/torch/distributed/elastic/agent/server/health_check_server.py @@ -53,6 +53,10 @@ def stop(self) -> None: """ log.info("Stopping noop health check server.") + @property + def alive_callback(self) -> Callable[[], int]: + return self._alive_callback + def create_healthcheck_server( alive_callback: Callable[[], int], diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py index b9d2ed64cd0cf..fda632bb3aaba 100644 --- a/torch/distributed/elastic/agent/server/local_elastic_agent.py +++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -7,6 +7,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import json import os @@ -43,6 +44,8 @@ if TYPE_CHECKING: + from collections.abc import Callable + from torch.distributed.elastic.events.api import EventMetadataValue logger = get_logger(__name__) @@ -59,6 +62,29 @@ TORCHELASTIC_TIMER_FILE = "TORCHELASTIC_TIMER_FILE" +class _AliveCallbackProxy: + """Mutable callback wrapper for the health check server. + + The C++ pybind ``HealthCheckThriftServer`` binds its ``alive_callback`` + at construction time and cannot update it afterward. This proxy is + created *before* the health check server so it can be passed as the + callback. Initially it returns ``time.time()`` (signalling "alive"). + After the agent is constructed, :meth:`set_delegate` wires it to + ``agent._get_alive_time`` for real liveness tracking. + """ + + def __init__(self) -> None: + self._delegate: Callable[[], int] | None = None + + def __call__(self) -> int: + if self._delegate is not None: + return self._delegate() + return int(time.time()) + + def set_delegate(self, delegate: Callable[[], int]) -> None: + self._delegate = delegate + + class LocalElasticAgent(SimpleElasticAgent): """An implementation of :py:class:`torchelastic.agent.server.ElasticAgent` that handles host-local workers. @@ -156,6 +182,7 @@ def __init__( exit_barrier_timeout: float = 300, log_line_prefix_template: str | None = None, shutdown_timeout: int = 30, + health_check_server: HealthCheckServer | None = None, ): super().__init__(spec, exit_barrier_timeout, shutdown_timeout) self._start_method = start_method @@ -164,7 +191,7 @@ def __init__( self._log_line_prefix_template = log_line_prefix_template self._worker_watchdog: timer.FileTimerServer | None = None self._logs_specs = logs_specs - self._health_check_server: HealthCheckServer | None = None + self._health_check_server = health_check_server def _setup_local_watchdog(self, envs: dict[int, dict[str, str]]) -> None: enable_watchdog_env_name = TORCHELASTIC_ENABLE_FILE_TIMER @@ -217,7 +244,13 @@ def _get_alive_time(self) -> int: Once workers are running and the watchdog is active, we delegate to the watchdog's ``get_last_progress_time`` for real liveness tracking. + + During the exit barrier wait, workers have finished and the watchdog + progress time is stale. We return the current time to prevent TW + from killing the task while agents coordinate shutdown. """ + if self._in_exit_barrier: + return int(time.time()) if self._worker_watchdog is not None: return self._worker_watchdog.get_last_progress_time() return int(time.time()) @@ -440,22 +473,6 @@ def _set_local_rank_env( if "CUDA_VISIBLE_DEVICES" in os.environ: worker_env["CUDA_VISIBLE_DEVICES"] = os.environ["CUDA_VISIBLE_DEVICES"] - def _pre_invoke_run(self) -> None: - # Start the health check server immediately, before rendezvous, - # so that TW sees a healthy thrift port while the agent is still - # initializing (package fetch, rendezvous, model load, etc.). - # The _get_alive_time callback dynamically checks self._worker_watchdog: - # returns current time during init, real watchdog time once workers run. - if justknobs_check( - "ai_infra/pytorch_distributed:torchelastic_enable_healthcheck_before_rendezvous", - default=False, - ): - logger.info( - "Starting health check server before rendezvous " - "(torchelastic_enable_healthcheck_before_rendezvous=True)" - ) - self._setup_healthcheck() - def _shutdown( self, death_sig: signal.Signals = signal.SIGTERM, timeout: int = 30 ) -> None: diff --git a/torch/distributed/elastic/events/__init__.py b/torch/distributed/elastic/events/__init__.py index deea40f3899ae..0b8437b2efec5 100644 --- a/torch/distributed/elastic/events/__init__.py +++ b/torch/distributed/elastic/events/__init__.py @@ -31,13 +31,7 @@ from torch.distributed.elastic.events.handlers import get_logging_handler -from .api import ( # noqa: F401 - Event, - EventMetadataValue, - EventSource, - NodeState, - RdzvEvent, -) +from .api import Event, EventMetadataValue, EventSource, NodeState, RdzvEvent _events_loggers: dict[str, logging.Logger] = {} diff --git a/torch/distributed/elastic/metrics/__init__.py b/torch/distributed/elastic/metrics/__init__.py index b2c2330924879..75edb816e2a1f 100644 --- a/torch/distributed/elastic/metrics/__init__.py +++ b/torch/distributed/elastic/metrics/__init__.py @@ -142,7 +142,7 @@ def emit(self, metric_data): from typing import Optional -from .api import ( # noqa: F401 +from .api import ( configure, ConsoleMetricHandler, get_elapsed_time_ms, @@ -163,6 +163,6 @@ def initialize_metrics(cfg: MetricsConfig | None = None): try: - from torch.distributed.elastic.metrics.static_init import * # type: ignore[import] # noqa: F401 F403 + from torch.distributed.elastic.metrics.static_init import * # type: ignore[import] # noqa: F403 except ModuleNotFoundError: pass diff --git a/torch/distributed/elastic/multiprocessing/__init__.py b/torch/distributed/elastic/multiprocessing/__init__.py index 60b7cd32fd253..a387ba40a4955 100644 --- a/torch/distributed/elastic/multiprocessing/__init__.py +++ b/torch/distributed/elastic/multiprocessing/__init__.py @@ -66,7 +66,7 @@ def trainer(a, b, c): from collections.abc import Callable from typing import Optional, Union -from torch.distributed.elastic.multiprocessing.api import ( # noqa: F401 +from torch.distributed.elastic.multiprocessing.api import ( _validate_full_rank, DefaultLogsSpecs, LogsDest, diff --git a/torch/distributed/elastic/multiprocessing/api.py b/torch/distributed/elastic/multiprocessing/api.py index 1ce0d78ed25d7..aee0e9cd8d860 100644 --- a/torch/distributed/elastic/multiprocessing/api.py +++ b/torch/distributed/elastic/multiprocessing/api.py @@ -89,7 +89,7 @@ def _terminate_process_handler(signum: int, frame: FrameType | None) -> None: def _get_kill_signal() -> signal.Signals: """Get the kill signal. SIGKILL for unix, CTRL_C_EVENT for windows.""" if IS_WINDOWS: - return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + return signal.CTRL_C_EVENT # type: ignore[attr-defined] else: return signal.SIGKILL @@ -97,7 +97,7 @@ def _get_kill_signal() -> signal.Signals: def _get_default_signal() -> signal.Signals: """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.""" if IS_WINDOWS: - return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + return signal.CTRL_C_EVENT # type: ignore[attr-defined] else: return signal.SIGTERM diff --git a/torch/distributed/elastic/multiprocessing/errors/__init__.py b/torch/distributed/elastic/multiprocessing/errors/__init__.py index 25a60fce6b9c7..3361300483a3c 100644 --- a/torch/distributed/elastic/multiprocessing/errors/__init__.py +++ b/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -64,8 +64,8 @@ from torch.distributed.elastic.utils.logging import get_logger -from .error_handler import ErrorHandler # noqa: F401 -from .handlers import get_error_handler # noqa: F401 +from .error_handler import ErrorHandler +from .handlers import get_error_handler __all__ = [ diff --git a/torch/distributed/elastic/multiprocessing/redirects.py b/torch/distributed/elastic/multiprocessing/redirects.py index 057013fbb9e5b..cecdec12ffee2 100644 --- a/torch/distributed/elastic/multiprocessing/redirects.py +++ b/torch/distributed/elastic/multiprocessing/redirects.py @@ -24,12 +24,28 @@ logger = logging.getLogger(__name__) +_WIN32_STD_HANDLE = { + "stdout": -11, # STD_OUTPUT_HANDLE + "stderr": -12, # STD_ERROR_HANDLE +} + + def get_libc(): - if IS_WINDOWS or IS_MACOS: - logger.warning( - "NOTE: Redirects are currently not supported in Windows or MacOs." - ) + if IS_MACOS: + logger.warning("NOTE: Redirects are currently not supported in MacOs.") return None + elif IS_WINDOWS: + for lib_name in ("ucrtbase", "msvcrt", "msvcr110", "msvcr100"): + try: + lib = ctypes.CDLL(lib_name) + logger.debug("Loaded Windows C runtime: %s", lib_name) + return lib + except OSError: + continue + raise RuntimeError( + "Could not load a C runtime DLL on Windows (tried: ucrtbase, msvcrt, " + "msvcr110, msvcr100). Redirects cannot function without a CRT." + ) else: return ctypes.CDLL("libc.so.6") @@ -38,6 +54,24 @@ def get_libc(): def _c_std(stream: str): + if IS_WINDOWS: + stream_index = 2 if stream == "stderr" else 1 + try: + iob_func = libc.__acrt_iob_func + iob_func.restype = ctypes.POINTER(ctypes.c_void_p) + iob_func.argtypes = [ctypes.c_uint] + return iob_func(stream_index) + except AttributeError: + pass + try: + legacy_index = 2 if stream == "stderr" else 1 + iob = (ctypes.POINTER(ctypes.c_void_p) * 3).in_dll(libc, "_iob") + return iob[legacy_index] + except (AttributeError, OSError) as err: + raise RuntimeError( + f"Could not resolve C-runtime FILE* for '{stream}'. " + "Neither __acrt_iob_func nor _iob are available in the loaded CRT." + ) from err return ctypes.c_void_p.in_dll(libc, stream) @@ -48,56 +82,147 @@ def _python_std(stream: str): _VALID_STD = {"stdout", "stderr"} -@contextmanager -def redirect(std: str, to_file: str): - """ - Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``. - - This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``). - See usage for details. - - Directory of ``dst_filename`` is assumed to exist and the destination file - is overwritten if it already exists. - - .. note:: Due to buffering cross source writes are not guaranteed to - appear in wall-clock order. For instance in the example below - it is possible for the C-outputs to appear before the python - outputs in the log file. - - Usage: - - :: - - # syntactic-sugar for redirect("stdout", "tmp/stdout.log") - with redirect_stdout("/tmp/stdout.log"): - print("python stdouts are redirected") - libc = ctypes.CDLL("libc.so.6") - libc.printf(b"c stdouts are also redirected" - os.system("echo system stdouts are also redirected") - - print("stdout restored") - - """ - if std not in _VALID_STD: - raise ValueError( - f"unknown standard stream <{std}>, must be one of {_VALID_STD}" - ) - - c_std = _c_std(std) - python_std = _python_std(std) - std_fd = python_std.fileno() - - def _redirect(dst): - libc.fflush(c_std) - python_std.flush() - os.dup2(dst.fileno(), std_fd) - - with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst: - _redirect(dst) - try: - yield - finally: - _redirect(orig_std) +if IS_WINDOWS: # libc is None on macOS; all of the below is Windows-only + import io as _io + import msvcrt as _msvcrt + + _kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined] + + _crt_dup = libc._dup + _crt_dup2 = libc._dup2 + _crt_dup.restype = ctypes.c_int + _crt_dup.argtypes = [ctypes.c_int] + _crt_dup2.restype = ctypes.c_int + _crt_dup2.argtypes = [ctypes.c_int, ctypes.c_int] + + @contextmanager + def redirect(std: str, to_file: str): + """ + Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file at ``to_file``. + + On Windows this performs a four-layer redirect: + + 1. ``sys.stdout``/``sys.stderr`` -- rewired to a new TextIOWrapper so + Python's ``print()`` writes to the destination file. + 2. CRT fd (``_dup2``) -- captures C ``printf`` and UCRT ``FILE*`` writers. + 3. Win32 ``SetStdHandle`` -- captures native code using ``WriteFile``/ + ``WriteConsole`` directly, including HIP/ROCm. + 4. ``fflush`` before each switch -- prevents lost output from CRT buffering. + + .. note:: If ROCm/HIP caches the Win32 HANDLE before this redirect runs + (e.g. at ``import torch`` time), set up the redirect *before* + importing torch/ROCm to capture all output. + + Directory of ``to_file`` is assumed to exist. The destination file is + overwritten if it already exists. + """ + if std not in _VALID_STD: + raise ValueError( + f"unknown standard stream <{std}>, must be one of {_VALID_STD}" + ) + + std_fd = 1 if std == "stdout" else 2 + win32_handle_id = _WIN32_STD_HANDLE[std] + orig_sys_std = getattr(sys, std) + orig_fd_dup = _crt_dup(std_fd) + if orig_fd_dup == -1: + raise OSError(f"CRT _dup failed for {std} (fd={std_fd})") + orig_win32_handle = _kernel32.GetStdHandle(win32_handle_id) + + with open(to_file, mode="w+b") as dst: + dst_fd = dst.fileno() + + try: + libc.fflush(_c_std(std)) + except Exception: + pass + try: + orig_sys_std.flush() + except Exception: + pass + + _kernel32.SetStdHandle( + win32_handle_id, + _msvcrt.get_osfhandle(dst_fd), # pyrefly: ignore [missing-attribute] + ) + + if _crt_dup2(dst_fd, std_fd) == -1: + raise OSError(f"CRT _dup2 failed redirecting {std}") + + new_sys_std = _io.TextIOWrapper( + open(dst_fd, mode="wb", closefd=False), # noqa: SIM115 + encoding=orig_sys_std.encoding or "utf-8", + errors="replace", + line_buffering=True, + ) + setattr(sys, std, new_sys_std) + + try: + yield + finally: + try: + new_sys_std.flush() + except Exception: + pass + try: + libc.fflush(_c_std(std)) + except Exception: + pass + + setattr(sys, std, orig_sys_std) + _crt_dup2(orig_fd_dup, std_fd) + os.close(orig_fd_dup) + _kernel32.SetStdHandle(win32_handle_id, orig_win32_handle) + +else: + + @contextmanager + def redirect(std: str, to_file: str): + """ + Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``. + + This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``). + See usage for details. + + Directory of ``dst_filename`` is assumed to exist and the destination file + is overwritten if it already exists. + + .. note:: Due to buffering cross source writes are not guaranteed to + appear in wall-clock order. For instance in the example below + it is possible for the C-outputs to appear before the python + outputs in the log file. + + Usage:: + + # syntactic-sugar for redirect("stdout", "tmp/stdout.log") + with redirect_stdout("/tmp/stdout.log"): + print("python stdouts are redirected") + libc = ctypes.CDLL("libc.so.6") + libc.printf(b"c stdouts are also redirected") + os.system("echo system stdouts are also redirected") + + print("stdout restored") + """ + if std not in _VALID_STD: + raise ValueError( + f"unknown standard stream <{std}>, must be one of {_VALID_STD}" + ) + + c_std = _c_std(std) + python_std = _python_std(std) + std_fd = python_std.fileno() + + def _redirect(dst): + libc.fflush(c_std) + python_std.flush() + os.dup2(dst.fileno(), std_fd) + + with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst: + _redirect(dst) + try: + yield + finally: + _redirect(orig_std) redirect_stdout = partial(redirect, "stdout") diff --git a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py index 5bc0db1fd1afb..5dd79c65fbcfd 100644 --- a/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py +++ b/torch/distributed/elastic/multiprocessing/subprocess_handler/subprocess_handler.py @@ -22,7 +22,7 @@ def _get_default_signal() -> signal.Signals: """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows.""" if IS_WINDOWS: - return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821 + return signal.CTRL_C_EVENT # type: ignore[attr-defined] else: return signal.SIGTERM diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py index 96d82e03f9cf0..9c0e50e7eba65 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous.py @@ -207,7 +207,7 @@ def shutdown(self) -> bool: try: self.set_closed() return True - except BaseException: # noqa: B036 + except BaseException: logger.warning("Shutdown failed", exc_info=True) return False @@ -332,7 +332,7 @@ def rendezvous_barrier(self): # to avoid spamming etcd # FIXME: there are a few things that fall under this like # etcd.EtcdKeyNotFound, etc, which could be handled more explicitly. - logger.info("Rendezvous attempt failed, will retry. Reason: %s", e) # noqa: G200 + logger.info("Rendezvous attempt failed, will retry. Reason: %s", e) time.sleep(1) def init_phase(self): diff --git a/torch/distributed/elastic/rendezvous/etcd_server.py b/torch/distributed/elastic/rendezvous/etcd_server.py index 347e7339d9a46..37f30fa160ec8 100644 --- a/torch/distributed/elastic/rendezvous/etcd_server.py +++ b/torch/distributed/elastic/rendezvous/etcd_server.py @@ -176,8 +176,8 @@ def start( except Exception as e: curr_retries += 1 stop_etcd(self._etcd_proc) - logger.warning( # noqa: G200 - "Failed to start etcd server, got error: %s, retrying", str(e) + logger.warning( + "Failed to start etcd server, got error: %s, retrying", e ) if curr_retries >= num_retries: shutil.rmtree(self._base_data_dir, ignore_errors=True) diff --git a/torch/distributed/elastic/rendezvous/etcd_store.py b/torch/distributed/elastic/rendezvous/etcd_store.py index faaf77587bc9d..09e588ff5b709 100644 --- a/torch/distributed/elastic/rendezvous/etcd_store.py +++ b/torch/distributed/elastic/rendezvous/etcd_store.py @@ -120,6 +120,7 @@ def add(self, key, num: int) -> int: except etcd.EtcdCompareFailed: cas_delay() + # pyrefly: ignore [bad-override] def wait(self, keys, override_timeout: datetime.timedelta | None = None): """ Wait until all of the keys are published, or until timeout. diff --git a/torch/distributed/elastic/timer/__init__.py b/torch/distributed/elastic/timer/__init__.py index b9c2ea349cc67..1b8eb72b520bb 100644 --- a/torch/distributed/elastic/timer/__init__.py +++ b/torch/distributed/elastic/timer/__init__.py @@ -39,16 +39,6 @@ def trainer_func(message_queue): complete, then the worker process is killed and the agent retries the worker group. """ -from .api import ( # noqa: F401 - configure, - expires, - TimerClient, - TimerRequest, - TimerServer, -) -from .file_based_local_timer import ( # noqa: F401 - FileTimerClient, - FileTimerRequest, - FileTimerServer, -) -from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401 +from .api import configure, expires, TimerClient, TimerRequest, TimerServer +from .file_based_local_timer import FileTimerClient, FileTimerRequest, FileTimerServer +from .local_timer import LocalTimerClient, LocalTimerServer diff --git a/torch/distributed/elastic/utils/__init__.py b/torch/distributed/elastic/utils/__init__.py index ce2bbf5bbe234..8fe3ff97ad4ff 100644 --- a/torch/distributed/elastic/utils/__init__.py +++ b/torch/distributed/elastic/utils/__init__.py @@ -6,4 +6,4 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .api import get_env_variable_or_raise, get_socket_with_port, macros # noqa: F401 +from .api import get_env_variable_or_raise, get_socket_with_port, macros diff --git a/torch/distributed/elastic/utils/data/__init__.py b/torch/distributed/elastic/utils/data/__init__.py index 6c39bca6f3c8a..cc2a7a0978d9a 100644 --- a/torch/distributed/elastic/utils/data/__init__.py +++ b/torch/distributed/elastic/utils/data/__init__.py @@ -6,5 +6,5 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .cycling_iterator import CyclingIterator # noqa: F401 -from .elastic_distributed_sampler import ElasticDistributedSampler # noqa: F401 +from .cycling_iterator import CyclingIterator +from .elastic_distributed_sampler import ElasticDistributedSampler diff --git a/torch/distributed/flight_recorder/components/builder.py b/torch/distributed/flight_recorder/components/builder.py index 87c52d6d4b090..663363e718257 100644 --- a/torch/distributed/flight_recorder/components/builder.py +++ b/torch/distributed/flight_recorder/components/builder.py @@ -311,7 +311,12 @@ def build_collectives( # This extra cleanup is needed because we need to pop all collectives within a coalesced collective. for i, k in idx_map.items(): for _ in range(1, num_coalesced_entries): - all_entries[i].pop(k) + try: + all_entries[i].pop(k) + except IndexError: + # In the case of a missing rank symptom that a rank didn't schedule the coalesced collective, + # we should not fail the analysis script here. + pass else: # Iterate through all the ranks and check if there is a mismatch for the current entry. check_current_entry_match( diff --git a/torch/distributed/flight_recorder/components/types.py b/torch/distributed/flight_recorder/components/types.py index d0b8f067acfb4..46ebda50dd506 100644 --- a/torch/distributed/flight_recorder/components/types.py +++ b/torch/distributed/flight_recorder/components/types.py @@ -209,16 +209,25 @@ class Database(NamedTuple): "reduce", "_reduce_oop", "all_gather", + "all_gather_single", + "all_gather_v", "all_reduce", "_all_gather_base", "all_gather_into_tensor_coalesced", "reduce_scatter", + "reduce_scatter_single", + "reduce_scatter_v", "reduce_scatter_tensor_coalesced", "_reduce_scatter_base", "gather", "scatter", "all_to_all", + "all_to_all_single", + "all_to_all_v_single", "all_reduce_barrier", + "barrier", + "split", + "new_window", "allreduce_coalesced", "ALLGATHER_coalesced", "REDUCE_SCATTER_coalesced", @@ -283,7 +292,7 @@ def log( logger.info("input sizes: %s", self.input_sizes) logger.info("output sizes: %s", self.output_sizes) logger.info("world size: %d", len(self.expected_ranks)) - logger.info("expected ranks: %s", str(self.expected_ranks)) + logger.info("expected ranks: %s", self.expected_ranks) logger.info("collective state: %s", self.collective_state) if errors: self.errors = errors @@ -409,9 +418,9 @@ def __init__( ): self.profiling_name = event["profiling_name"] comm_lib_backend, name = self.profiling_name.split(":") - if comm_lib_backend not in ["nccl", "xccl"]: + if comm_lib_backend not in ["nccl", "ncclx", "gloo", "xccl"]: raise AssertionError( - f"name formatting error? {comm_lib_backend} != 'nccl' or 'xccl'" + f"name formatting error? {comm_lib_backend} not in supported backends" ) parts = name.split(" ") type = parts[0] diff --git a/torch/distributed/fsdp/__init__.py b/torch/distributed/fsdp/__init__.py index 1e4219250c39d..7c0d55e37d688 100644 --- a/torch/distributed/fsdp/__init__.py +++ b/torch/distributed/fsdp/__init__.py @@ -1,6 +1,7 @@ from ._flat_param import FlatParameter as FlatParameter from ._fully_shard import ( CPUOffloadPolicy, + DataParallelMeshDims, FSDPModule, fully_shard, MixedPrecisionPolicy, @@ -49,6 +50,7 @@ "StateDictType", # FSDP2 "CPUOffloadPolicy", + "DataParallelMeshDims", "FSDPModule", "fully_shard", "MixedPrecisionPolicy", @@ -60,6 +62,7 @@ # Set namespace for exposed private names CPUOffloadPolicy.__module__ = "torch.distributed.fsdp" +DataParallelMeshDims.__module__ = "torch.distributed.fsdp" FSDPModule.__module__ = "torch.distributed.fsdp" fully_shard.__module__ = "torch.distributed.fsdp" MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp" diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index 190e2a0a88a34..c7e226c09457f 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -29,6 +29,7 @@ ) from torch.nn.parameter import _ParameterMeta # type: ignore[attr-defined] from torch.testing._internal.distributed.fake_pg import FakeProcessGroup +from torch.utils._typing_utils import not_none from ._fsdp_extensions import ( _ext_post_unflatten_transform, @@ -600,6 +601,7 @@ def __init__( self._needs_pre_backward_unshard = False # Was the handle prefetched? Set on successful _prefetch_handle and unshard self._prefetched = False + self._compute_stream: torch.Stream | None = None # Optimistically assume a valid input `params` and set dtype attributes # before `_init_flat_param()`, which performs the actual validation self._orig_param_dtype = params[0].dtype @@ -1322,6 +1324,11 @@ def pre_unshard(self) -> bool: self._use_sharded_views() ret = False if self._use_orig_params and not self._skip_writeback_check: + # Wait for the compute stream since _writeback_orig_params reads + # original parameters that may still be in use during prefetch. + self._device_handle.current_stream().wait_stream( + not_none(self._compute_stream) + ) ret = self._writeback_orig_params() if ( self.uses_sharded_strategy diff --git a/torch/distributed/fsdp/_fully_shard/__init__.py b/torch/distributed/fsdp/_fully_shard/__init__.py index d4d0b341a3f82..f6aab00fc32b7 100644 --- a/torch/distributed/fsdp/_fully_shard/__init__.py +++ b/torch/distributed/fsdp/_fully_shard/__init__.py @@ -1,4 +1,9 @@ -from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_api import ( + CPUOffloadPolicy, + DataParallelMeshDims, + MixedPrecisionPolicy, + OffloadPolicy, +) from ._fully_shard import ( FSDPModule, fully_shard, @@ -10,6 +15,7 @@ __all__ = [ "CPUOffloadPolicy", + "DataParallelMeshDims", "FSDPModule", "fully_shard", "MixedPrecisionPolicy", diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_api.py b/torch/distributed/fsdp/_fully_shard/_fsdp_api.py index d495bb953cac3..3f03d5707a9ba 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_api.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_api.py @@ -126,6 +126,49 @@ def __call__( ) -> dist.Work | None: ... +@dataclass +class DataParallelMeshDims: + """ + Specifies which dimensions of a full SPMD :class:`DeviceMesh` correspond to + data parallelism when using :func:`fully_shard` whose parameters are already + DTensors on that mesh. + + Attributes: + shard (Optional[Union[str, tuple[str, ...]]]): Mesh dimension name(s) + that FSDP shards parameters on. If a tuple of names, those dims + are flattened into a single shard dimension. At least one of + ``shard`` and ``replicate`` must be set. + replicate (Optional[Union[str, tuple[str, ...]]]): Mesh dimension + name(s) for HSDP or DDP replication. If a tuple of names, those + dims are flattened into a single replicate dimension. + """ + + shard: str | tuple[str, ...] | None = None + replicate: str | tuple[str, ...] | None = None + + def __post_init__(self): + if self.shard is None and self.replicate is None: + raise ValueError( + "At least one of shard or replicate must be set in DataParallelMeshDims" + ) + + @property + def shard_names(self) -> tuple[str, ...]: + if self.shard is None: + return () + if isinstance(self.shard, str): + return (self.shard,) + return tuple(self.shard) + + @property + def replicate_names(self) -> tuple[str, ...]: + if self.replicate is None: + return () + if isinstance(self.replicate, str): + return (self.replicate,) + return tuple(self.replicate) + + @dataclass class OffloadPolicy: """ diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py index 3c8aa312c7187..0f15723ae6942 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py @@ -20,6 +20,12 @@ from ._fsdp_param import FSDPParam, ShardedState +def _label_with_suffix(label: str, suffix: str) -> str: + if suffix: + return f"{label} {suffix}" + return label + + class AllGatherResult(NamedTuple): all_gather_output: torch.Tensor all_gather_event: torch.Event | None @@ -330,43 +336,49 @@ def foreach_all_gather( all_gather_stream: torch.Stream, device: torch.device, all_gather_comm: AllGather, + label_suffix: str = "", ) -> AllGatherResult | None: world_size, rank = group.size(), group.rank() device_handle = _get_device_handle(device.type) + with device_handle.stream(all_gather_copy_in_stream): - param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params) - ( - param_all_gather_input_dtypes, - param_all_gather_input_numels, - dtype, - ) = _get_all_gather_input_metadatas(param_all_gather_inputs) - if dtype == torch.uint8: - all_gather_inputs = [ - t.view(torch.uint8) for ts in param_all_gather_inputs for t in ts - ] - else: - all_gather_inputs = [*chain.from_iterable(param_all_gather_inputs)] - inp_split_sizes = [t.numel() for t in all_gather_inputs] - all_gather_input_numel = sum(inp_split_sizes) - all_gather_output = all_gather_comm.allocate( - (all_gather_input_numel * world_size,), dtype=dtype, device=device - ) - all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in( - all_gather_inputs, - all_gather_output, - inp_split_sizes, - all_gather_input_numel, - rank, - ) - del param_all_gather_inputs + with torch.profiler.record_function( + _label_with_suffix("FSDP::all_gather_copy_in", label_suffix) + ): + param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params) + ( + param_all_gather_input_dtypes, + param_all_gather_input_numels, + dtype, + ) = _get_all_gather_input_metadatas(param_all_gather_inputs) + if dtype == torch.uint8: + all_gather_inputs = [ + t.view(torch.uint8) for ts in param_all_gather_inputs for t in ts + ] + else: + all_gather_inputs = [*chain.from_iterable(param_all_gather_inputs)] + inp_split_sizes = [t.numel() for t in all_gather_inputs] + all_gather_input_numel = sum(inp_split_sizes) + all_gather_output = all_gather_comm.allocate( + (all_gather_input_numel * world_size,), dtype=dtype, device=device + ) + all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in( + all_gather_inputs, + all_gather_output, + inp_split_sizes, + all_gather_input_numel, + rank, + ) + del param_all_gather_inputs all_gather_stream.wait_stream(all_gather_copy_in_stream) with device_handle.stream(all_gather_stream): - all_gather_work = all_gather_comm( - output_tensor=all_gather_output, - input_tensor=all_gather_input, - group=group, - async_op=async_op, - ) + with dist.record_comm(_label_with_suffix("FSDP::all_gather", label_suffix)): + all_gather_work = all_gather_comm( + output_tensor=all_gather_output, + input_tensor=all_gather_input, + group=group, + async_op=async_op, + ) all_gather_event = all_gather_stream.record_event() return AllGatherResult( all_gather_output, @@ -537,6 +549,7 @@ def foreach_reduce( partial_reduce_output: torch.Tensor | None, # only used for HSDP all_reduce_hook: Callable[[torch.Tensor], None] | None, force_sum_reduction_for_comms: bool = False, + label_suffix: str = "", ) -> tuple[ torch.Tensor, torch.Event, @@ -617,12 +630,15 @@ def foreach_reduce( ) _div_if_needed(reduce_scatter_input, predivide_factor) if world_size > 1: - reduce_scatter_comm( - output_tensor=reduce_output, - input_tensor=reduce_scatter_input, - group=reduce_scatter_group, - op=reduce_scatter_op, - ) + with dist.record_comm( + _label_with_suffix("FSDP::reduce_scatter", label_suffix) + ): + reduce_scatter_comm( + output_tensor=reduce_output, + input_tensor=reduce_scatter_input, + group=reduce_scatter_group, + op=reduce_scatter_op, + ) else: # For single GPU, just copy the input to output (no actual reduce-scatter needed), and # account for a possible gradient_divide_factor. @@ -655,11 +671,22 @@ def foreach_reduce( else: all_reduce_stream.wait_stream(current_stream) with device_handle.stream(all_reduce_stream): - dist.all_reduce( - reduce_output, - group=all_reduce_group, - op=all_reduce_op, - ) + with dist.record_comm( + _label_with_suffix("FSDP::all_reduce", label_suffix) + ): + dist.all_reduce( + reduce_output, + group=all_reduce_group, + op=all_reduce_op, + ) + # Keep refs to the reduce-dtype AR buffer + completion + # event so FSDPParamGroup._all_reduce_state can hold them + # across layers. This keeps the buffer off the caching + # allocator's free list; otherwise the next layer's + # reduce-scatter can reuse the same physical block while + # this layer's AR is still in flight, causing cross-layer + # gradient aliasing under slow AR. See PR #140044, + # regression test PR #180900. all_reduce_input = reduce_output all_reduce_event = all_reduce_stream.record_event() # -- END: ops in reduce_scatter stream @@ -676,6 +703,15 @@ def foreach_reduce( with device_handle.stream(post_reduce_stream): _div_if_needed(reduce_output, postdivide_factor) + # Rebinds to a new orig_dtype tensor when reduce_dtype != + # orig_dtype. Do NOT rely on this stream-scoped rebind to manage + # the old reduce-dtype buffer's lifetime: the rebind orders the + # cast before the free-event on AR stream, but the freed block + # lands on the caching allocator's free list and the next layer's + # RS on RS stream can reuse it without waiting for this layer's + # AR to finish. The reduce-dtype buffer is held across layers by + # FSDPParamGroup._all_reduce_state (captured above) to prevent + # this. See PR #140044, regression test PR #180900. reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype) # View out and accumulate sharded gradients flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] @@ -692,10 +728,28 @@ def foreach_reduce( ) to_accumulate_grad = fsdp_param.sharded_param.grad is not None if fsdp_param.offload_to_cpu: - # Only overlap the D2H copy (copying to pinned memory) if not - # accumulating gradients since the CPU add kernel depends on - # the copy result and we cannot run the add as a callback - non_blocking = fsdp_param.pin_memory and not to_accumulate_grad + # Only overlap the D2H copy (copying to pinned memory) when no + # in-backward CPU consumer of the grad exists. Two such + # consumers suppress the overlap: + # - Accumulating grads: the CPU add kernel depends on the + # copy result and we cannot run the add as a callback. + # - Post-accumulate-grad hooks: user code (e.g. + # optimizer-in-backward) reads ``param.grad`` on CPU + # synchronously. With ``non_blocking=True`` the hook would + # observe in-flight pinned memory — silently wrong + # optimizer updates. + has_post_acc_grad_hook = bool( + getattr( + fsdp_param.sharded_param, + "_post_accumulate_grad_hooks", + None, + ) + ) + non_blocking = ( + fsdp_param.pin_memory + and not to_accumulate_grad + and not has_post_acc_grad_hook + ) # Since the GPU sharded gradient is allocated in the RS stream, # we can free it here by not keeping a ref without waiting for # the D2H copy since future RS-stream ops run after the copy @@ -811,7 +865,7 @@ def _get_gradient_divide_factors( if reduce_scatter_group is not None and factor == reduce_scatter_group.size(): reduce_scatter_op = ReduceOp.AVG else: - reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor) + reduce_scatter_op = ReduceOp.PREMUL_SUM(1 / factor) return None, None, reduce_scatter_op, ReduceOp.SUM if factor is None: diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py index 2f76336332e85..f96f634b86f51 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py @@ -2,7 +2,7 @@ import functools import math import traceback -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import auto, Enum from typing import Any @@ -13,6 +13,8 @@ from torch.distributed.tensor import DeviceMesh, DTensor, Shard from torch.distributed.tensor._dtensor_spec import DTensorSpec +from ._fsdp_api import DataParallelMeshDims + def _dynamo_disable(func): """Disable dynamo tracing for FSDP hooks.""" @@ -31,12 +33,19 @@ class DataParallelMeshInfo: mesh: DeviceMesh shard_mesh_dim: int | None = None replicate_mesh_dim: int | None = None + dp_mesh_dims: DataParallelMeshDims | None = None + # The full SPMD mesh (excluding PP dims) that params are distributed on. + # Must include all non-PP SPMD dims (e.g. DP + TP); passing a submesh + # that omits dims like TP will lead to incorrect behavior. + spmd_mesh: DeviceMesh | None = field(default=None, repr=False) + is_spmd_mesh: bool = field(default=False, init=False, repr=False) def __post_init__(self): if self.shard_mesh_dim is None and self.replicate_mesh_dim is None: raise AssertionError( "At least one of shard_mesh_dim and replicate_mesh_dim must not be None" ) + self.is_spmd_mesh = self.dp_mesh_dims is not None @dataclass diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_init.py b/torch/distributed/fsdp/_fully_shard/_fsdp_init.py index 0b1d652422852..42f5293705cda 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_init.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_init.py @@ -13,6 +13,7 @@ from ._fsdp_common import ( _is_composable_with_fsdp, DataParallelMeshInfo, + DDPMeshInfo, FSDPMeshInfo, HSDPMeshInfo, ) @@ -23,7 +24,7 @@ from collections.abc import Callable from typing import Any - from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy + from ._fsdp_api import DataParallelMeshDims, MixedPrecisionPolicy, OffloadPolicy from ._fsdp_common import ShardPlacementFnResult from ._fsdp_state import FSDPState @@ -46,13 +47,35 @@ def _validate_module(module: nn.Module, func_name: str) -> None: ) -def _validate_mesh(mesh: "DeviceMesh") -> None: +def _validate_mesh( + mesh: "DeviceMesh", + dp_mesh_dims: "DataParallelMeshDims | None" = None, +) -> None: """ Validate that the mesh can be used with fully_shard. - Raises ValueError if the mesh is not 1D or 2D. - Raises AssertionError if the mesh is 2D but mesh_dim_names is not specified. + When ``dp_mesh_dims`` is provided, validates that the named dims + exist in the mesh and at least one of shard/replicate is set. + Otherwise raises ValueError if the mesh is not 1D or 2D. """ + if dp_mesh_dims is not None: + if dp_mesh_dims.shard is None and dp_mesh_dims.replicate is None: + raise ValueError( + "At least one of shard or replicate must be set in dp_mesh_dims" + ) + if mesh.mesh_dim_names is None: + raise ValueError( + "mesh must have mesh_dim_names when dp_mesh_dims is provided" + ) + names_to_check: list[str] = list(dp_mesh_dims.shard_names) + names_to_check.extend(dp_mesh_dims.replicate_names) + for name in names_to_check: + if name not in mesh.mesh_dim_names: + raise ValueError( + f"Mesh dim name '{name}' not found in mesh.mesh_dim_names " + f"{mesh.mesh_dim_names}" + ) + return if mesh.ndim not in (1, 2): raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}") if mesh.ndim == 2 and mesh.mesh_dim_names is None: @@ -61,18 +84,71 @@ def _validate_mesh(mesh: "DeviceMesh") -> None: ) -def _get_mesh_info(mesh: "DeviceMesh") -> "FSDPMeshInfo": +def _get_mesh_info( + mesh: "DeviceMesh", + dp_mesh_dims: "DataParallelMeshDims | None" = None, +) -> "DataParallelMeshInfo": """ Get the appropriate mesh info for the given mesh. + When ``dp_mesh_dims`` is provided, extracts the DP submesh from the + full SPMD mesh and returns FSDPMeshInfo, HSDPMeshInfo, or DDPMeshInfo + with ``dp_mesh_dims`` set and ``is_spmd_mesh`` as True. + Returns FSDPMeshInfo for 1D mesh, HSDPMeshInfo for 2D mesh. """ + if dp_mesh_dims is not None: + return _get_mesh_info_from_named_dims(mesh, dp_mesh_dims) if mesh.ndim == 1: return FSDPMeshInfo(mesh, shard_mesh_dim=0) else: return HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0) +def _get_mesh_info_from_named_dims( + mesh: "DeviceMesh", + dp_mesh_dims: "DataParallelMeshDims", +) -> "DataParallelMeshInfo": + shard_names = dp_mesh_dims.shard_names + replicate_names = dp_mesh_dims.replicate_names + + def _get_submesh(names: tuple[str, ...]) -> "DeviceMesh": + if len(names) == 1: + return mesh[names[0]] + # Flatten multi-dim submesh into a single dim so FSDP's internal + # logic (which expects one shard and/or one replicate dim) works + # unchanged. This creates a new 1D DeviceMesh and ProcessGroup. + return mesh[names]._flatten("_".join(names)) + + if len(shard_names) == 0: # DDP + dp_mesh = _get_submesh(replicate_names) + return DDPMeshInfo( + dp_mesh, + replicate_mesh_dim=0, + dp_mesh_dims=dp_mesh_dims, + spmd_mesh=mesh, + ) + if len(replicate_names) == 0: # FSDP + dp_mesh = _get_submesh(shard_names) + return FSDPMeshInfo( + dp_mesh, + shard_mesh_dim=0, + dp_mesh_dims=dp_mesh_dims, + spmd_mesh=mesh, + ) + # HSDP + shard_mesh = _get_submesh(shard_names) + replicate_mesh = _get_submesh(replicate_names) + dp_mesh = DeviceMesh._concatenate([replicate_mesh, shard_mesh]) + return HSDPMeshInfo( + dp_mesh, + shard_mesh_dim=1, + replicate_mesh_dim=0, + dp_mesh_dims=dp_mesh_dims, + spmd_mesh=mesh, + ) + + def _get_post_forward_mesh_info( reshard_after_forward: bool | int, mesh_info: FSDPMeshInfo ) -> FSDPMeshInfo | None: diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index b39e8c8f1f3cb..f09e7b8dcd00a 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -31,6 +31,16 @@ ) +_orig_param_uid_counter = itertools.count() + + +def _get_orig_param_uid(param: nn.Parameter) -> int: + if not hasattr(param, "_fsdp_orig_uid"): + uid = next(_orig_param_uid_counter) + param._fsdp_orig_uid = uid # pyrefly: ignore[missing-attribute] + return param._fsdp_orig_uid # pyrefly: ignore[missing-attribute] + + """ [Note: FSDP tensors] FSDP considers the following tensors: @@ -166,12 +176,14 @@ class FSDPParam: _unsharded_param: nn.Parameter # ND unsharded_accumulated_grad: torch.Tensor | None # ND _sharding_spec: DTensorSpec - # DTensor attributes (only defined for DTensor `param`): - _tp_spec: DTensorSpec + _unsharded_dtensor_spec: ( + DTensorSpec | None + ) # set for DTensor params (SPMD or TP/EP) all_gather_outputs: list[torch.Tensor] # 1D # All-gather extension attributes _extensions_data: ExtensionsData _unsharded_inner_tensors: list[torch.Tensor] + _orig_param_uid: int def __init__( self, @@ -226,6 +238,7 @@ def _init_sharded_param( else: self.mesh_info = mesh_info # pyrefly: ignore[bad-assignment] fsdp_placement = None + self._shard_mesh = self._init_shard_mesh() if param.device != device and param.device.type != "meta": raise AssertionError( f"Expects the parameter to already be moved to device {device} but got {param.device}" @@ -249,69 +262,8 @@ def _init_sharded_param( # TODO: Simplify the following sharded parameter padding logic after # https://github.com/pytorch/pytorch/issues/113045 self.is_dtensor = isinstance(param, DTensor) - if self.is_dtensor: - self._tp_spec = cast(DTensor, param)._spec - dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh) - if dp_mesh is None or tp_mesh is None: - raise AssertionError( - "FSDP requires the DP and model parallel TP/EP mesh to be not None but got: \n" - f"DP's mesh: {dp_mesh}\nTP/EP's mesh: {tp_mesh}" - ) - self._spmd_mesh = DeviceMesh._concatenate([dp_mesh, tp_mesh]) - if len(self._tp_spec.placements) > 2: - raise NotImplementedError( - f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._tp_spec.placements}" - ) - split_factor = self._tp_spec.num_shards_map[shard_dim] - if not (2 <= self._spmd_mesh.ndim <= 4): - raise AssertionError( - "_spmd_mesh.ndim can only be 2 (FSDP+TP/EP), 3 (FSDP+EP+TP, HSDP+TP/EP), " - f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}." - ) - self._spmd_placements: tuple[Placement, ...] - if isinstance(self.mesh_info, FSDPMeshInfo): # FSDP or HSDP - dp_shard_tp_placement = ( - ( - _StridedShard(shard_dim, split_factor=split_factor) - if split_factor > 1 - else fsdp_placement - ), - *self._tp_spec.placements, - ) - else: # DDP - dp_shard_tp_placement = ( - (Replicate()), - *self._tp_spec.placements, - ) - if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP - if self.mesh_info.replicate_mesh_dim != 0: - raise AssertionError( - f"Expected replicate_mesh_dim to be 0, got {self.mesh_info.replicate_mesh_dim}" - ) - self._spmd_placements = (Replicate(),) + dp_shard_tp_placement - else: # FSDP or DDP - self._spmd_placements = dp_shard_tp_placement - - self._sharding_spec = DTensorSpec( - self._spmd_mesh, - self._spmd_placements, - tensor_meta=self._tp_spec.tensor_meta, - ) - param_data = cast(DTensor, param)._local_tensor - else: - self._spmd_mesh = self.mesh_info.mesh - if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP - self._spmd_placements = (Replicate(), fsdp_placement) - elif isinstance(self.mesh_info, FSDPMeshInfo): # FSDP - self._spmd_placements = (fsdp_placement,) - elif isinstance(self.mesh_info, DDPMeshInfo): # DDP - self._spmd_placements = (Replicate(),) - self._sharding_spec = DTensorSpec( - self._spmd_mesh, - self._spmd_placements, - tensor_meta=TensorMeta(param.size(), param.stride(), param.dtype), - ) - param_data = param + self._orig_param_uid = _get_orig_param_uid(param) + param_data = self._init_sharding_spec(param, fsdp_placement, shard_dim) if not param_data.is_contiguous(): raise AssertionError( f"Expected contiguous tensor, got {param_data.shape=} {param_data.stride()=}" @@ -366,13 +318,189 @@ def _init_sharded_param( raise AssertionError( f"Expected contiguous tensor with {self.fsdp_placement=}" ) - self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param)) - self.sharded_param.requires_grad_(param.requires_grad) + self.sharded_param = nn.Parameter( + self.to_sharded_dtensor(sharded_param), + requires_grad=param.requires_grad, + ) # Let `param_data` be freed normally when its ref count reaches 0 when # the `fully_shard` call returns to allow provided parameters to alias self._setattr_on_modules(self.sharded_param) self.sharded_state = ShardedState.SHARDED + def _init_sharding_spec( + self, + param: nn.Parameter, + fsdp_placement: Shard, + shard_dim: int, + ) -> torch.Tensor: + """ + Build ``_sharding_spec``, ``_spmd_mesh``, and ``_spmd_placements`` and + return the local tensor data to be sharded. + """ + self._unsharded_dtensor_spec = None + if self.mesh_info.is_spmd_mesh and not self.is_dtensor: + raise ValueError( + "When dp_mesh_dims is provided, all parameters must be " + "DTensors on the full SPMD mesh (e.g. via distribute_module). " + f"Got plain tensor for parameter '{self._module_info.param_name}'." + ) + if self.is_dtensor and self.mesh_info.is_spmd_mesh: + return self._init_sharding_spec_spmd(param, fsdp_placement, shard_dim) + if self.is_dtensor: + return self._init_sharding_spec_tp(param, fsdp_placement, shard_dim) + return self._init_sharding_spec_plain(param, fsdp_placement) + + def _init_sharding_spec_spmd( + self, + param: nn.Parameter, + fsdp_placement: Shard, + shard_dim: int, + ) -> torch.Tensor: + """SPMD path: param is a DTensor on the full SPMD mesh.""" + self._unsharded_dtensor_spec = cast(DTensor, param)._spec + spmd_mesh = self._unsharded_dtensor_spec.mesh + dp_dim_names = self.mesh_info.dp_mesh_dims + if dp_dim_names is None: + raise AssertionError("dp_dim_names must not be None for SPMD mesh") + if spmd_mesh.mesh_dim_names is None: + raise AssertionError("spmd_mesh.mesh_dim_names must not be None") + if ( + self.mesh_info.spmd_mesh is not None + and spmd_mesh is not self.mesh_info.spmd_mesh + ): + raise ValueError( + "Expected param's DTensor mesh to be the same mesh passed " + "to fully_shard, but got different mesh objects" + ) + + dp_shard_indices = [ + spmd_mesh.mesh_dim_names.index(n) for n in dp_dim_names.shard_names + ] + + orig_placements = self._unsharded_dtensor_spec.placements + for idx in dp_shard_indices: + if not isinstance(orig_placements[idx], Replicate): + raise ValueError( + f"Expected Replicate() on DP shard dim " + f"'{spmd_mesh.mesh_dim_names[idx]}' (index {idx}) " + f"but got {orig_placements[idx]}" + ) + dp_replicate_indices = [] + for rep_name in dp_dim_names.replicate_names: + rep_idx = spmd_mesh.mesh_dim_names.index(rep_name) + dp_replicate_indices.append(rep_idx) + if not isinstance(orig_placements[rep_idx], Replicate): + raise ValueError( + f"Expected Replicate() on DP replicate dim " + f"'{spmd_mesh.mesh_dim_names[rep_idx]}' (index {rep_idx}) " + f"but got {orig_placements[rep_idx]}" + ) + + # Cache DP dim indices so _get_grad_inner_tensor can skip + # redistribution on DP dims and let FSDP's reduce-scatter handle them. + self._dp_dim_indices: frozenset[int] = frozenset( + dp_shard_indices + dp_replicate_indices + ) + + new_placements = list(orig_placements) + for dp_idx in dp_shard_indices: + # split_factor = number of non-DP shards on shard_dim from + # mesh dims with higher index (the "right-side" dims that + # _StridedShard needs to interleave with) + sf = 1 + for j in range(dp_idx + 1, spmd_mesh.ndim): + p = orig_placements[j] + if isinstance(p, (Shard, _StridedShard)) and p.dim == shard_dim: + sf *= spmd_mesh.size(j) + new_placements[dp_idx] = ( + _StridedShard(shard_dim, split_factor=sf) if sf > 1 else fsdp_placement + ) + + self._spmd_mesh = spmd_mesh + self._spmd_placements: tuple[Placement, ...] = tuple(new_placements) + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=self._unsharded_dtensor_spec.tensor_meta, + ) + return cast(DTensor, param)._local_tensor + + def _init_sharding_spec_tp( + self, + param: nn.Parameter, + fsdp_placement: Shard, + shard_dim: int, + ) -> torch.Tensor: + """TP/EP path: param is a DTensor, DP mesh is separate from TP mesh.""" + self._unsharded_dtensor_spec = cast(DTensor, param)._spec + dp_mesh, tp_mesh = (self.mesh_info.mesh, self._unsharded_dtensor_spec.mesh) + if dp_mesh is None or tp_mesh is None: + raise AssertionError( + "FSDP requires the DP and model parallel TP/EP mesh to be not None but got: \n" + f"DP's mesh: {dp_mesh}\nTP/EP's mesh: {tp_mesh}" + ) + self._spmd_mesh = DeviceMesh._concatenate([dp_mesh, tp_mesh]) + if len(self._unsharded_dtensor_spec.placements) > 2: + raise NotImplementedError( + f"FSDP only supports 1D TP/EP or 2D EP+TP, not {self._unsharded_dtensor_spec.placements}" + ) + split_factor = self._unsharded_dtensor_spec.num_shards_map[shard_dim] + if not (2 <= self._spmd_mesh.ndim <= 4): + raise AssertionError( + "_spmd_mesh.ndim can only be 2 (FSDP+TP/EP), 3 (FSDP+EP+TP, HSDP+TP/EP), " + f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}." + ) + if isinstance(self.mesh_info, FSDPMeshInfo): + dp_shard_tp_placement = ( + ( + _StridedShard(shard_dim, split_factor=split_factor) + if split_factor > 1 + else fsdp_placement + ), + *self._unsharded_dtensor_spec.placements, + ) + else: # DDP + dp_shard_tp_placement = ( + Replicate(), + *self._unsharded_dtensor_spec.placements, + ) + self._spmd_placements: tuple[Placement, ...] + if isinstance(self.mesh_info, HSDPMeshInfo): + if self.mesh_info.replicate_mesh_dim != 0: + raise AssertionError( + f"Expected replicate_mesh_dim to be 0, got {self.mesh_info.replicate_mesh_dim}" + ) + self._spmd_placements = (Replicate(),) + dp_shard_tp_placement + else: + self._spmd_placements = dp_shard_tp_placement + + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=self._unsharded_dtensor_spec.tensor_meta, + ) + return cast(DTensor, param)._local_tensor + + def _init_sharding_spec_plain( + self, + param: nn.Parameter, + fsdp_placement: Shard, + ) -> torch.Tensor: + """Plain tensor path: param is not a DTensor.""" + self._spmd_mesh = self.mesh_info.mesh + if isinstance(self.mesh_info, HSDPMeshInfo): + self._spmd_placements = (Replicate(), fsdp_placement) + elif isinstance(self.mesh_info, FSDPMeshInfo): + self._spmd_placements = (fsdp_placement,) + elif isinstance(self.mesh_info, DDPMeshInfo): + self._spmd_placements = (Replicate(),) + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=TensorMeta(param.size(), param.stride(), param.dtype), + ) + return param + def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None: mesh_info = self.post_forward_mesh_info if mesh_info is None: @@ -404,8 +532,10 @@ def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy): # then we do not need extra casting if reduce_dtype == param_dtype: reduce_dtype = None - # Clamp `param_dtype` to `None` if no casting is required - if param_dtype == self.orig_dtype: + # Clamp `param_dtype` to `None` if no casting is required or if the + # parameter is non-floating-point (mixed precision is only meaningful + # for floating-point parameters) + if param_dtype == self.orig_dtype or not self.orig_dtype.is_floating_point: param_dtype = None self.param_dtype = param_dtype self.reduce_dtype = reduce_dtype @@ -481,8 +611,10 @@ def init_unsharded_param(self): self._contiguous_orig_stride, storage_offset=0, ) - if self.is_dtensor: - unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec) + if self._unsharded_dtensor_spec is not None: + unsharded_param = _from_local_no_grad( + unsharded_param, self._unsharded_dtensor_spec + ) self._unsharded_param = nn.Parameter( unsharded_param, requires_grad=self.sharded_param.requires_grad ) @@ -503,7 +635,7 @@ def to_sharded(self) -> None: def to_sharded_post_forward(self) -> None: if self.is_dtensor: raise NotImplementedError( - "Resharding to smaller mesh with TP is not supported yet" + "Resharding to smaller mesh is not supported for DTensor parameters yet" ) self._assert_in_states(ShardedState.UNSHARDED) if self.post_forward_mesh_info is None: @@ -532,7 +664,8 @@ def to_sharded_post_forward(self) -> None: storage_offset=0, ) self._sharded_post_forward_param = nn.Parameter( - self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor) + self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor), + requires_grad=self.sharded_param.requires_grad, ) self._setattr_on_modules(self._sharded_post_forward_param) self.free_unsharded_param() @@ -726,13 +859,32 @@ def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor: grad = grad.wait() if not isinstance(grad, DTensor): raise AssertionError(f"Expected DTensor, got {type(grad)}") - placements = self._tp_spec.placements - if placements != grad.placements: - if len(self._tp_spec.placements) != len(grad.placements): - raise AssertionError( - f"Expected same placement length: {self._tp_spec=} {grad.placements=}" - ) - grad = grad.redistribute(placements=placements) + if self._unsharded_dtensor_spec is None: + raise AssertionError( + "Expected _unsharded_dtensor_spec for DTensor param" + ) + placements = self._unsharded_dtensor_spec.placements + if self.mesh_info.is_spmd_mesh: + # Only redistribute non-DP dims; keep Partial on DP dims + # so FSDP's reduce-scatter handles them directly, avoiding + # a redundant all-reduce on the DP dimensions. + target_placements = tuple( + grad.placements[i] if i in self._dp_dim_indices else placements[i] + for i in range(len(placements)) + ) + if target_placements != grad.placements: + if len(placements) != len(grad.placements): + raise AssertionError( + f"Expected same placement length: {placements=} {grad.placements=}" + ) + grad = grad.redistribute(placements=target_placements) + else: + if placements != grad.placements: + if len(placements) != len(grad.placements): + raise AssertionError( + f"Expected same placement length: {placements=} {grad.placements=}" + ) + grad = grad.redistribute(placements=placements) grad = grad._local_tensor return grad @@ -740,8 +892,7 @@ def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor: def _sharded_local_tensor(self) -> torch.Tensor: return cast(DTensor, self.sharded_param)._local_tensor - @property - def shard_mesh(self): + def _init_shard_mesh(self) -> DeviceMesh: mesh = self.mesh_info.mesh if mesh.ndim == 1: return mesh @@ -749,6 +900,10 @@ def shard_mesh(self): raise AssertionError("Expected mesh_dim_names to not be None") return mesh[mesh.mesh_dim_names[-1]] + @property + def shard_mesh(self): + return self._shard_mesh + @property def shard_mesh_from_root(self): return self.shard_mesh diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py index 0b71c63720e05..5015f7acd2ec9 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -95,7 +95,7 @@ def lazy_init(self, device: torch.device): # tensors produced in one stream and used in another and accompanying # CUDA events for synchronization self.all_gather_state: AllGatherState | None = None - self.reduce_scatter_state: ReduceScatterState | None = None + self.reduce_scatter_states: list[ReduceScatterState] = [] # Post-forward order for explicit backward prefetching self.post_forward_order: list[FSDPParamGroup] = [] # will cause ref cycles @@ -124,6 +124,11 @@ class ReduceScatterState(NamedTuple): class AllReduceState(NamedTuple): + # Holding all_reduce_input (the reduce-dtype AR buffer) keeps the + # caching allocator from reusing the block across layers. This is a + # structural invariant, not bookkeeping: without it, the next layer's + # RS can reuse the same physical block before this layer's AR finishes + # under slow AR, causing gradient aliasing. See PR #140044, PR #180900. all_reduce_input: torch.Tensor event: torch.Event | None # all-reduce event @@ -189,6 +194,8 @@ def __init__( # - Communication and communication/computation overlap self.comm_ctx = FSDPCommContext() + self._param_group_index: int = 0 + self._num_param_groups: int = 1 # Group's indices in the shared post-forward order self._post_forward_indices: list[int] = [] # Whether to reduce gradients at all (whether for FSDP or HSDP) @@ -228,10 +235,13 @@ def __init__( # Only for HSDP, if accumulating gradients without all-reduce, save the # partial reduce output (only reduce-scattered but not all-reduced) self._partial_reduce_output: torch.Tensor | None = None - # Holds the all-reduce input and all-reduce event to keep it alive - # until the end of backward (critical when doing bf16 reduction with - # fp32 parameters since the all-reduce input is allocated in the RS - # stream and will have no refs to it after being upcast to fp32) + # Holds the reduce-dtype AR buffer + completion event across + # layers in HSDP+AR with reduce_dtype != orig_dtype (e.g., bf16 + # reduce + fp32 params). Structural invariant: the live Python + # ref keeps the buffer off the caching allocator's free list, + # preventing the next layer's RS from reusing the same physical + # block while this layer's AR is still in flight. See + # AllReduceState docstring and regression test PR #180900. self._all_reduce_state: AllReduceState | None = None # Initialization # @@ -374,6 +384,7 @@ def unshard(self, async_op: bool = False): *self.comm_ctx.get_all_gather_streams(async_op, self._training_state), self.device, self._all_gather_comm, + self._label_suffix, ) def wait_for_unshard(self): @@ -546,17 +557,20 @@ def post_backward(self, *unused: Any): fsdp_param.unsharded_param.grad = None if self.reshard_after_backward: self.reshard() + # Wait on prior module's RS states (assumes backward fires groups + # N-1 first; if not, overlap degrades but correctness is preserved). + if ( + self._param_group_index == self._num_param_groups - 1 + and self.comm_ctx.reduce_scatter_states + ): + with record_function(f"FSDP::post_backward_rs_wait ({self._module_fqn})"): + for rs_state in self.comm_ctx.reduce_scatter_states: + if rs_state.event is not None: + self.device_handle.current_stream().wait_event(rs_state.event) + self.comm_ctx.reduce_scatter_states.clear() if len(fsdp_params_with_grad) == 0: return with record_function(self._with_fqn("FSDP::post_backward_reduce")): - if ( - self.comm_ctx.reduce_scatter_state is not None - and self.comm_ctx.reduce_scatter_state.event is not None - ): - self.device_handle.current_stream().wait_event( - self.comm_ctx.reduce_scatter_state.event - ) - self.comm_ctx.reduce_scatter_state = None all_reduce_pg = ( self._all_reduce_process_group if isinstance(self.mesh_info, DDPMeshInfo) @@ -607,9 +621,10 @@ def post_backward(self, *unused: Any): self._partial_reduce_output, self._all_reduce_hook, self.force_sum_reduction_for_comms, + self._label_suffix, ) - self.comm_ctx.reduce_scatter_state = ReduceScatterState( - reduce_scatter_input, reduce_scatter_event + self.comm_ctx.reduce_scatter_states.append( + ReduceScatterState(reduce_scatter_input, reduce_scatter_event) ) if all_reduce_input is not None: if self.device.type != "cpu": @@ -655,14 +670,36 @@ def _backward_prefetch(self) -> None: # Can be cleared if running multiple `backward`s return curr_index = self._post_forward_indices.pop() - if (target_index := curr_index - 1) < 0: - return - # Prefetch naively using the reverse post-forward order, which may - # have mistargeted prefetches if not all modules used in forward - # are used in this backward - # pyrefly: ignore [unbound-name] - target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index] - self._prefetch_unshard(target_fsdp_param_group, "backward") + if self._num_param_groups > 1: + # Backward fires groups in reverse forward order: + # N-1, N-2, ..., 1, 0. Index 1 is always the + # penultimate group regardless of N. Prefetching here + # lets the next module's AG overlap with group 0's RS + # without holding unsharded params too long (as would + # happen if we prefetched from N-1). + if self._param_group_index != 1: + return + # E.g. fully_shard(block, shard_placement_fn=...) creates two + # param groups per block (dense + moe), giving + # post_forward_order = [block0, block0.moe, block1, block1.moe]. + # block1.moe walks back past block1 to prefetch block0.moe then block0. + curr_modules = self.modules + target_modules: tuple[nn.Module, ...] | None = None + for step in range(1, curr_index + 1): + target = self.comm_ctx.post_forward_order[curr_index - step] + if target.modules is curr_modules: + continue + if target_modules is None: + target_modules = target.modules + elif target.modules is not target_modules: + break + # Prefetch all groups of the target module in + # reverse forward order (highest index first), + # matching the explicit path in _pre_backward. + self._prefetch_unshard(target, "backward") + elif curr_index > 0: + target = self.comm_ctx.post_forward_order[curr_index - 1] + self._prefetch_unshard(target, "backward") @staticmethod def _prefetch_unshard( @@ -810,9 +847,17 @@ def _all_reduce_process_group(self) -> dist.ProcessGroup: ) return self.mesh_info.replicate_process_group + @property + def _label_suffix(self) -> str: + suffix = f"({self._module_fqn})" if self._module_fqn else "" + if self._num_param_groups > 1 and isinstance(self.mesh_info, FSDPMeshInfo): + suffix = f"{suffix} [pg={self.mesh_info.shard_mesh_size}]".lstrip() + return suffix + def _with_fqn(self, label: str) -> str: - if self._module_fqn: - return f"{label} ({self._module_fqn})" + suffix = self._label_suffix + if suffix: + return f"{label} {suffix}" return label def __repr__(self): diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py index 7481efc688663..a1bbcbf29c79f 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py @@ -200,19 +200,38 @@ def _lazy_init(self) -> None: fsdp_param_group.post_forward_mesh_info = None self._init_fqns() self._init_shared_state() + self._validate_no_duplicate_params() # Run parameter group lazy inits after initializing FQNs for improved # error messages for state in self._state_ctx.all_states: for fsdp_param_group in state._fsdp_param_groups: fsdp_param_group.lazy_init() + def _validate_no_duplicate_params(self) -> None: + seen: set[int] = set() + for state in self._state_ctx.all_states: + for fsdp_param_group in state._fsdp_param_groups: + for fsdp_param in fsdp_param_group.fsdp_params: + if fsdp_param._orig_param_uid in seen: + raise ValueError( + f"Parameter '{fsdp_param._param_fqn}' is shared with a " + f"parameter already managed by another FSDP group. " + f"For shared/tied parameters, use " + f"fully_shard([module_a, module_b]) to place them in " + f"the same FSDP group." + ) + seen.add(fsdp_param._orig_param_uid) + def _init_shared_state(self) -> None: self._comm_ctx.lazy_init(self._device) for state in self._state_ctx.all_states: state._state_ctx = self._state_ctx state._comm_ctx = self._comm_ctx - for fsdp_param_group in state._fsdp_param_groups: + num_groups = len(state._fsdp_param_groups) + for i, fsdp_param_group in enumerate(state._fsdp_param_groups): fsdp_param_group.comm_ctx = self._comm_ctx + fsdp_param_group._param_group_index = i + fsdp_param_group._num_param_groups = num_groups def _init_fqns(self) -> None: """Sets module and parameter FQN attributes for debugging.""" @@ -277,6 +296,8 @@ def _pre_forward( for fsdp_param_group in self._fsdp_param_groups: args, kwargs = fsdp_param_group.pre_forward(module, args, kwargs) for fsdp_state in self._states_to_forward_prefetch: + # Forward order (not reversed) to match forward execution order; + # contrast with reversed() in _pre_backward for backward order. for target_param_group in fsdp_state._fsdp_param_groups: FSDPParamGroup._prefetch_unshard(target_param_group, "forward") return args, kwargs @@ -317,7 +338,9 @@ def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor: for fsdp_param_group in self._fsdp_param_groups: fsdp_param_group.pre_backward(default_prefetch) for fsdp_state in self._states_to_backward_prefetch: - for target_param_group in fsdp_state._fsdp_param_groups: + # Reverse so higher-indexed groups are prefetched first, + # matching backward execution order (reverse of forward). + for target_param_group in reversed(fsdp_state._fsdp_param_groups): FSDPParamGroup._prefetch_unshard(target_param_group, "backward") return grad @@ -326,7 +349,11 @@ def _root_post_backward_final_callback(self) -> None: logger.debug("FSDP::root_post_backward") with torch.profiler.record_function("FSDP::root_post_backward_callback"): for state in self._state_ctx.all_states: - for fsdp_param_group in state._fsdp_param_groups: + # Reverse so that the last param group (which gates the + # reduce-scatter wait/clear) fires first, matching the + # autograd backward order and preserving RS overlap for + # per-param-mesh modules whose inputs lack gradients. + for fsdp_param_group in reversed(state._fsdp_param_groups): if fsdp_param_group._training_state != TrainingState.POST_BACKWARD: # Run post-backward in case forward inputs did not require # gradient so the autograd backward did not run @@ -337,11 +364,12 @@ def _root_post_backward_final_callback(self) -> None: state._finalize_backward() if self._state_ctx.is_last_backward: self._comm_ctx.post_forward_order.clear() - if self._comm_ctx.reduce_scatter_state is not None: - self._device_handle.current_stream().wait_event( - self._comm_ctx.reduce_scatter_state.event - ) - self._comm_ctx.reduce_scatter_state = None + # Catch the last module's RS states that no subsequent + # module's group N-1 wait will clear. + for rs_state in self._comm_ctx.reduce_scatter_states: + if rs_state.event is not None: + self._device_handle.current_stream().wait_event(rs_state.event) + self._comm_ctx.reduce_scatter_states.clear() self._state_ctx.post_backward_final_callback_queued = False def _finalize_backward(self) -> None: diff --git a/torch/distributed/fsdp/_fully_shard/_fully_shard.py b/torch/distributed/fsdp/_fully_shard/_fully_shard.py index d0c80371731e2..cc48a5f9d90d3 100644 --- a/torch/distributed/fsdp/_fully_shard/_fully_shard.py +++ b/torch/distributed/fsdp/_fully_shard/_fully_shard.py @@ -12,7 +12,13 @@ import torch.nn as nn from torch.distributed._composable import contract -from ._fsdp_api import AllGather, MixedPrecisionPolicy, OffloadPolicy, ReduceScatter +from ._fsdp_api import ( + AllGather, + DataParallelMeshDims, + MixedPrecisionPolicy, + OffloadPolicy, + ReduceScatter, +) from ._fsdp_common import FSDPMeshInfo, ShardPlacementFnResult from ._fsdp_init import ( _apply_to_module, @@ -64,6 +70,7 @@ def fully_shard( mp_policy: MixedPrecisionPolicy = ..., offload_policy: OffloadPolicy = ..., ignored_params: set[nn.Parameter] | None = ..., + dp_mesh_dims: DataParallelMeshDims | None = ..., ) -> FSDPModule: ... @@ -78,6 +85,7 @@ def fully_shard( mp_policy: MixedPrecisionPolicy = ..., offload_policy: OffloadPolicy = ..., ignored_params: set[nn.Parameter] | None = ..., + dp_mesh_dims: DataParallelMeshDims | None = ..., ) -> list[FSDPModule]: ... @@ -96,6 +104,7 @@ def fully_shard( mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), offload_policy: OffloadPolicy = OffloadPolicy(), ignored_params: set[nn.Parameter] | None = None, + dp_mesh_dims: DataParallelMeshDims | None = None, ): """ Apply fully sharded data parallelism (FSDP) to ``module``, where FSDP @@ -194,6 +203,12 @@ def fully_shard( ignored_params: Optional(Set[nn.Parameter]): The set of parameters to be ignored by FSDP. They will not be sharded, nor moved to the device during init, nor have their gradients reduced in backward. + dp_mesh_dims (Optional[DataParallelMeshDims]): When provided, + ``mesh`` is treated as the full SPMD mesh, and parameters should be + DTensors on this mesh with ``Replicate()`` on all DP dimensions. + The ``shard`` field names which dim(s) FSDP shards on (multiple + dims are flattened). The ``replicate`` field names the HSDP + replication dim(s) (multiple dims are flattened). Returns: FSDPModule: The module with FSDP applied (in-place). @@ -201,16 +216,29 @@ def fully_shard( torch._C._log_api_usage_once("torch.distributed.fsdp.fully_shard") _validate_module(module, "fully_shard") mesh = mesh or _init_default_mesh() - _validate_mesh(mesh) - mesh_info = _get_mesh_info(mesh) + _validate_mesh(mesh, dp_mesh_dims) + mesh_info = _get_mesh_info(mesh, dp_mesh_dims) device = _get_device_from_mesh(mesh) auto_reshard_after_forward = reshard_after_forward is None # If the user does not provide ``reshard_after_forward``, we set it to True. # During lazy_init, we identify which module is the root and override its value to False - post_forward_mesh_info = _get_post_forward_mesh_info( - reshard_after_forward if not auto_reshard_after_forward else True, # type: ignore[arg-type] - mesh_info, - ) + if isinstance(mesh_info, FSDPMeshInfo): + if ( + mesh_info.is_spmd_mesh + and not isinstance(reshard_after_forward, bool) + and isinstance(reshard_after_forward, int) + ): + raise NotImplementedError( + "reshard_after_forward as int is not yet supported with " + "SPMD mesh (dp_mesh_dims)" + ) + post_forward_mesh_info = _get_post_forward_mesh_info( + reshard_after_forward if not auto_reshard_after_forward else True, # type: ignore[arg-type] + mesh_info, + ) + else: + # DDPMeshInfo: no sharding, so no post-forward resharding needed + post_forward_mesh_info = None arg_module, modules, managed_modules, params, buffers = _get_modules_and_states( module, device, ignored_params ) diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index 633e743e7d5c8..349ccfd767207 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -617,13 +617,26 @@ def _init_param_handle_from_module( managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params)) _verify_managed_params(fully_sharded_module, managed_params) if sync_module_states: - _sync_module_params_and_buffers( - fully_sharded_module, managed_params, state.process_group - ) if state.sharding_strategy in HYBRID_SHARDING_STRATEGIES: + # Broadcast inter-node first, then intra-node. The inter-node + # broadcast propagates rank 0's states to each node's local + # rank 0, so the subsequent intra-node broadcast has the + # correct source values on every node. Reversing this order + # causes local rank 0 on non-source nodes to broadcast + # uninitialized states (e.g. from meta-device materialization). _sync_module_params_and_buffers( fully_sharded_module, managed_params, state._inter_node_pg ) + # _sync_module_params_and_buffers marks each buffer with + # FSDP_SYNCED=True to avoid redundant syncs in nested + # wrapping. Reset the flag here so the intra-node broadcast + # below also includes buffers. + for buffer in fully_sharded_module.buffers(): + if hasattr(buffer, FSDP_SYNCED): + setattr(buffer, FSDP_SYNCED, False) + _sync_module_params_and_buffers( + fully_sharded_module, managed_params, state.process_group + ) _init_param_handle_from_params(state, managed_params, fully_sharded_module) return state diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py index 31c23b06705ac..4a5bd5f47211f 100644 --- a/torch/distributed/fsdp/_runtime_utils.py +++ b/torch/distributed/fsdp/_runtime_utils.py @@ -290,6 +290,7 @@ def _unshard( """ if not handle: return + handle._compute_stream = state._device_handle.current_stream() with state._device_handle.stream(pre_unshard_stream): ran_pre_unshard = handle.pre_unshard() if ran_pre_unshard: diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index 9d9d5fec572f5..fb7b25ee570f7 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -331,7 +331,7 @@ def param_hook( try: state_dict[fqn] = state_dict[fqn].detach().clone() state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined] - except BaseException as e: # noqa: B036 + except BaseException as e: warnings.warn( f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. " "This may mean that this state_dict entry could point to invalid " diff --git a/torch/distributed/launcher/__init__.py b/torch/distributed/launcher/__init__.py index fb744a2b93615..966c0385c5080 100644 --- a/torch/distributed/launcher/__init__.py +++ b/torch/distributed/launcher/__init__.py @@ -7,8 +7,4 @@ # LICENSE file in the root directory of this source tree. -from torch.distributed.launcher.api import ( # noqa: F401 - elastic_launch, - launch_agent, - LaunchConfig, -) +from torch.distributed.launcher.api import elastic_launch, launch_agent, LaunchConfig diff --git a/torch/distributed/launcher/api.py b/torch/distributed/launcher/api.py index 71ce57f62bce3..e51f4cb56c977 100644 --- a/torch/distributed/launcher/api.py +++ b/torch/distributed/launcher/api.py @@ -15,10 +15,18 @@ import torch import torch.distributed.elastic.rendezvous.registry as rdzv_registry -from torch._utils_internal import get_default_numa_options +from torch._utils_internal import get_default_numa_options, justknobs_check from torch.distributed.elastic import events, metrics from torch.distributed.elastic.agent.server.api import WorkerSpec -from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent +from torch.distributed.elastic.agent.server.health_check_server import ( + create_healthcheck_server, + HealthCheckServer, +) +from torch.distributed.elastic.agent.server.local_elastic_agent import ( + _AliveCallbackProxy, + LocalElasticAgent, + TORCHELASTIC_HEALTH_CHECK_PORT, +) from torch.distributed.elastic.multiprocessing import ( DefaultLogsSpecs, LogsSpecs, @@ -176,12 +184,19 @@ def __init__( self, config: LaunchConfig, entrypoint: Callable | str | None, + health_check_server: HealthCheckServer | None = None, ): self._config = config self._entrypoint = entrypoint + self._health_check_server = health_check_server def __call__(self, *args): - return launch_agent(self._config, self._entrypoint, list(args)) + return launch_agent( + self._config, + self._entrypoint, + list(args), + health_check_server=self._health_check_server, + ) def _get_entrypoint_name(entrypoint: Callable | str | None, args: list[Any]) -> str: @@ -227,6 +242,7 @@ def launch_agent( config: LaunchConfig, entrypoint: Callable | str | None, args: list[Any], + health_check_server: HealthCheckServer | None = None, ) -> dict[int, Any]: if not config.run_id: run_id = str(uuid.uuid4().int) @@ -290,6 +306,35 @@ def launch_agent( # Set the signals to handle in the environment variable os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = config.signals_to_handle + # Start health check server before rendezvous so TW sees a healthy + # thrift port during the potentially long MAST rendezvous store barrier + # (10-22+ min for large jobs). The _AliveCallbackProxy returns + # time.time() until wired to the agent after construction. + # Skip if a server was already provided by the caller (e.g. started + # before remote_pre_launch in the APF executor). + if health_check_server is None: + healthcheck_port = os.getenv(TORCHELASTIC_HEALTH_CHECK_PORT) + if healthcheck_port is not None and justknobs_check( + "ai_infra/pytorch_distributed:torchelastic_enable_healthcheck_before_rendezvous", + default=False, + ): + try: + health_check_server = create_healthcheck_server( + alive_callback=_AliveCallbackProxy(), + port=int(healthcheck_port), + timeout=60, + ) + health_check_server.start() + logger.info( + "Started early health check server on port %s before rendezvous", + healthcheck_port, + ) + except Exception: + logger.warning( + "Failed to start early health check server", exc_info=True + ) + health_check_server = None + spec = WorkerSpec( role=config.role, local_world_size=config.nproc_per_node, @@ -314,8 +359,14 @@ def launch_agent( start_method=config.start_method, log_line_prefix_template=config.log_line_prefix_template, shutdown_timeout=config.shutdown_timeout, # type: ignore[arg-type] + health_check_server=health_check_server, ) + if health_check_server is not None: + cb = health_check_server.alive_callback + if isinstance(cb, _AliveCallbackProxy): + cb.set_delegate(agent._get_alive_time) + shutdown_rdzv = True try: metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg)) diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index bf8c4613b0849..a5cd964770e58 100644 --- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -326,7 +326,7 @@ def _pipe_split(): @torch.library.register_fake("pippy::_pipe_split") # type: ignore[no-redef] -def _pipe_split(): # noqa: F811 +def _pipe_split(): return None diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index c4e6256389aa3..7e8f418a02d4a 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -2,7 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import collections import logging -from collections.abc import Iterator +from collections.abc import Iterator, Sequence from typing import Any import torch @@ -140,6 +140,40 @@ def get_param_groups( return unique_param_groups +def _autograd_grad_for_inputs( + outputs: Sequence[torch.Tensor], + inputs: Sequence[torch.Tensor], + grad_outputs: Sequence[torch.Tensor | None] | None = None, + retain_graph: bool = False, + allow_unused: bool = False, +) -> tuple[torch.Tensor | None, ...]: + """Compute input gradients, returning ``None`` for non-grad inputs.""" + # Some inputs may not be used or may not require gradients, so we filter them out + # before calling autograd.grad and place None for those positions in the result. + grad_indices: list[int] = [] + inputs_requiring_grad: list[torch.Tensor] = [] + for i, inp in enumerate(inputs): + if isinstance(inp, torch.Tensor) and inp.requires_grad: + grad_indices.append(i) + inputs_requiring_grad.append(inp) + + if not inputs_requiring_grad: + return tuple(None for _ in inputs) + + grads = torch.autograd.grad( + outputs=outputs, + inputs=inputs_requiring_grad, + grad_outputs=grad_outputs, + retain_graph=retain_graph, + allow_unused=allow_unused, + ) + + result: list[torch.Tensor | None] = [None] * len(inputs) + for idx, g in zip(grad_indices, grads, strict=True): + result[idx] = g + return tuple(result) + + def stage_backward_input( stage_outputs_or_loss: list[torch.Tensor], output_grads: list[torch.Tensor] | None, @@ -196,20 +230,20 @@ def hook(grad_inputs): torch.ones_like(stage_output) for stage_output in stage_outputs_or_loss ] - # Some inputs may not be used or may not require gradients, so we filter them out - input_values = [inp for inp in input_values if inp.requires_grad] - dinputs = torch.autograd.grad( + dinputs = _autograd_grad_for_inputs( stage_outputs_or_loss, - inputs=input_values, - grad_outputs=output_grads, + input_values, + output_grads, retain_graph=True, ) - # Update the gradients for inputs + + # Accumulate into .grad for inp, dinput in zip(input_values, dinputs): - if inp.grad is None: - inp.grad = dinput - else: - inp.grad += dinput + if isinstance(inp, torch.Tensor) and dinput is not None: + if inp.grad is None: + inp.grad = dinput + else: + inp.grad += dinput # stage_outputs_or_loss are not used in backwards after this point, so we can safely remove it from the autograd graph # this allows autograd to clear up the graph dedicated for this tensor and free up significant memory diff --git a/torch/distributed/pipelining/_utils.py b/torch/distributed/pipelining/_utils.py index 79b74be406814..3da67258ecbd9 100644 --- a/torch/distributed/pipelining/_utils.py +++ b/torch/distributed/pipelining/_utils.py @@ -1,91 +1,506 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates +from __future__ import annotations + import logging -from dataclasses import dataclass +import warnings +from dataclasses import dataclass, field +from enum import Enum +from typing import cast, Literal, overload, Protocol, TYPE_CHECKING, TypeAlias import torch from torch import fx +from torch.distributed._mesh_layout import _MeshLayout +from torch.distributed.tensor import DTensor +from torch.utils._pytree import tree_flatten, tree_unflatten + + +if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh + from torch.distributed.tensor.placement_types import Placement logger = logging.getLogger(__name__) -def flatten_args_detach(args): +class GetMeshCallback(Protocol): + """Callback to create/retrieve a DeviceMesh from its cache key components.""" + + def __call__( + self, + mesh_dim_names: tuple[str, ...], + mesh_layout: _MeshLayout | None, + ) -> DeviceMesh: ... + + +# Key for mesh cache: (mesh_dim_names, mesh_layout) +# mesh_layout is the _MeshLayout object containing shape and stride (not actual ranks). +# This uniquely identifies a mesh within the same "universe" where all stages share +# the same rank tensor. +MeshCacheKey: TypeAlias = tuple[tuple[str, ...], _MeshLayout | None] + + +class PipeliningMetadataError(RuntimeError): + """Raised on metadata mismatches during pipeline communication.""" + + +@dataclass(frozen=True, slots=True) +class _TensorMeta: + """Tensor metadata for recv buffer allocation and validation. + + For plain tensors, these are the tensor's actual attributes. + For DTensors, these are LOCAL shard attributes; global attributes + are stored in :class:`_DTensorMeta`. """ - Flatten the args into a list form and detach the tensors from computational graph. + + shape: torch.Size + stride: tuple[int, ...] + dtype: torch.dtype + requires_grad: bool + + @staticmethod + def from_tensor(tensor: torch.Tensor) -> _TensorMeta: + """Create metadata from a plain tensor. + + Args: + tensor: A plain ``torch.Tensor`` (not DTensor). + + Returns: + Metadata capturing shape, stride, dtype, and requires_grad. + + Raises: + TypeError: If ``tensor`` is a DTensor. + """ + if isinstance(tensor, DTensor): + raise PipeliningMetadataError( + "Expected plain tensor, got DTensor. Use _DTensorMeta.from_dtensor instead." + ) + return _TensorMeta( + shape=tensor.shape, + stride=tensor.stride(), + dtype=tensor.dtype, + requires_grad=tensor.requires_grad, + ) + + def to_tensor(self, device: torch.device | str) -> torch.Tensor: + """Reconstruct a tensor on ``device`` from this metadata. + + Args: + device: Target device for the tensor. + + Returns: + An empty strided tensor on ``device``. + """ + t = _make_tensor_from_meta(self, device) + t.requires_grad_(self.requires_grad) + return t + + def get_diff(self, other: _TensorMeta) -> list[str]: + """Return field-by-field differences with ``other``. + + Args: + other: Metadata to compare against. + + Returns: + List of human-readable difference strings (empty if equal). + """ + if self == other: + return [] + + diffs = [] + if self.shape != other.shape: + diffs.append(f"shape mismatch: {self.shape} vs {other.shape}") + if self.stride != other.stride: + diffs.append(f"stride mismatch: {self.stride} vs {other.stride}") + if self.dtype != other.dtype: + diffs.append(f"dtype mismatch: {self.dtype} vs {other.dtype}") + # requires_grad is intentionally excluded: it is a runtime concern + # determined by has_backward and grad context, not a metadata invariant. + return diffs + + +@dataclass(frozen=True, slots=True) +class _DTensorMeta(_TensorMeta): + """DTensor metadata extending :class:`_TensorMeta` with distribution info. + + Inherited fields (shape, stride, etc.) are LOCAL shard attributes. + Additional fields capture global shape and placement information + needed to reconstruct a :class:`DTensor` via ``DTensor.from_local()``. + + The :class:`DeviceMesh` is **not** stored (not serializable for P2P); + it is looked up from :class:`_MeshCache` using + ``(mesh_dim_names, mesh_layout)`` as the key. """ - flat_detached_args = [] - def extract_tensor_args(a): - nonlocal flat_detached_args - if isinstance(a, torch.Tensor): - val = a.detach().requires_grad_(a.requires_grad) - flat_detached_args.append(val) - return val + # Global DTensor properties (for reconstruction) + global_shape: torch.Size = field(default_factory=lambda: torch.Size([])) + global_stride: tuple[int, ...] = field(default=()) + + # DTensor distribution properties + placements: tuple[Placement, ...] = field( + default=() + ) # e.g., (Shard(0), Replicate()) + + # Mesh identification - used to look up the correct DeviceMesh from cache + mesh_dim_names: tuple[str, ...] = field(default=()) # e.g., ("tp",) or ("dp", "tp") + mesh_layout: _MeshLayout | None = field( + default=None + ) # _MeshLayout with shape/stride - uniquely identifies mesh within the same universe + + @staticmethod + def from_dtensor(dtensor: DTensor) -> _DTensorMeta: + """Create metadata from a DTensor. + + Args: + dtensor: The DTensor to extract metadata from. + + Returns: + Metadata capturing both local and global attributes. + """ + device_mesh = dtensor.device_mesh + + return _DTensorMeta( + # Local tensor attributes (for recv buffer allocation) + shape=dtensor._local_tensor.shape, + stride=dtensor._local_tensor.stride(), + dtype=dtensor.dtype, + requires_grad=dtensor.requires_grad, + # Global DTensor attributes (for reconstruction) + global_shape=dtensor.shape, + global_stride=dtensor.stride(), + # Distribution info + placements=dtensor._spec.placements, + mesh_dim_names=( + tuple(device_mesh.mesh_dim_names) if device_mesh.mesh_dim_names else () + ), + mesh_layout=device_mesh._layout, + ) + + @property + def mesh_cache_key(self) -> MeshCacheKey: + """Cache key ``(mesh_dim_names, mesh_layout)`` for mesh lookup.""" + return (self.mesh_dim_names, self.mesh_layout) + + def to_dtensor(self, device: torch.device | str, mesh: DeviceMesh) -> DTensor: + """Reconstruct a DTensor on ``device`` with placements. + + Args: + device: Target device for the local tensor. + mesh: The ``DeviceMesh`` to attach. + + Returns: + A DTensor on ``device``. + """ + local_tensor = _make_tensor_from_meta(self, device) + # Set requires_grad after from_local() so that the from_local + # operation itself is not recorded in the autograd graph. + return cast( + DTensor, + DTensor.from_local( + local_tensor, + device_mesh=mesh, + placements=self.placements, + shape=self.global_shape, + stride=self.global_stride, + run_check=False, + ).requires_grad_(self.requires_grad), + ) + + def get_diff(self, other: _TensorMeta) -> list[str]: + """Return field-by-field differences, including DTensor-specific fields. + + Args: + other: Metadata to compare against. + + Returns: + List of human-readable difference strings (empty if equal). + """ + if self == other: + return [] + + # Get base class differences (compares local shape/stride/dtype/requires_grad) + # NOTE: Use explicit class call instead of super() because + # @dataclass(slots=True) on both parent and child can break super(). + diffs = _TensorMeta.get_diff(self, other) + + # Add DTensor-specific comparisons if other is also _DTensorMeta + if isinstance(other, _DTensorMeta): + if self.global_shape != other.global_shape: + diffs.append( + f"global_shape mismatch: {self.global_shape} vs {other.global_shape}" + ) + if self.global_stride != other.global_stride: + diffs.append( + f"global_stride mismatch: {self.global_stride} vs {other.global_stride}" + ) + if self.placements != other.placements: + diffs.append( + f"placements mismatch: {self.placements} vs {other.placements}" + ) + if self.mesh_dim_names != other.mesh_dim_names: + diffs.append( + f"mesh_dim_names mismatch: {self.mesh_dim_names} vs {other.mesh_dim_names}" + ) + if self.mesh_layout != other.mesh_layout: + diffs.append( + f"mesh_layout mismatch: {self.mesh_layout} vs {other.mesh_layout}" + ) else: - flat_detached_args.append(a) - return a + diffs.append("type: _DTensorMeta vs _TensorMeta") - new_args = fx.node.map_aggregate( - args, - extract_tensor_args, - ) + return diffs - return new_args, flat_detached_args +# Type alias for union of tensor metadata types +TensorMeta: TypeAlias = _TensorMeta | _DTensorMeta + + +# Not frozen: fields are populated incrementally during forward and +# backward metadata inference or from user provided static metadata +@dataclass(slots=True) +class _StageMeta: + """Consolidated tensor metadata for a pipeline stage's forward and backward passes.""" + + inputs: tuple[TensorMeta, ...] | None = None + outputs: tuple[TensorMeta, ...] | None = None + input_grads: tuple[TensorMeta | None, ...] | None = None + output_grads: tuple[TensorMeta | None, ...] | None = None + + def has_any(self) -> bool: + """Check if any metadata field is populated.""" + return any( + v is not None + for v in [self.inputs, self.outputs, self.input_grads, self.output_grads] + ) -def flatten_args(args): + def has_dtensors(self) -> bool: + """Check if any input/output metadata is DTensor type.""" + for metas in [self.inputs, self.outputs]: + if metas and any(isinstance(m, _DTensorMeta) for m in metas if m): + return True + return False + + def is_complete_for_forward(self) -> bool: + """Check if forward metadata is fully populated.""" + return self.inputs is not None and self.outputs is not None + + +@dataclass(frozen=True, slots=True) +class _StageForwardMeta: + """Forward metadata transmitted from stage *i* to stage *i+1* during inference.""" + + forward_metas: tuple[TensorMeta, ...] # Stage i's outputs → Stage i+1's inputs + + +@dataclass(frozen=True, slots=True) +class _StageBackwardMeta: + """Backward metadata transmitted from stage *i* to stage *i-1* during inference. + + Gradient placements may differ from forward activations + (e.g., ``Replicate`` → ``Partial``). """ - Flatten the args into a list form. + + backward_metas: tuple[ + TensorMeta | None, ... + ] # Stage i's input_grads → Stage i-1's output_grads + + +def _make_tensor_from_meta( + meta: _TensorMeta, + device: torch.device | str, +) -> torch.Tensor: + """Create a tensor from metadata. + + Args: + meta: Metadata with shape, stride, and dtype. + device: Target device for the tensor. + + Returns: + Empty tensor preserving the exact memory layout. """ - flat_args = [] + return torch.empty_strided( + size=meta.shape, + stride=meta.stride, + dtype=meta.dtype, + device=device, + ) + - def extract_tensor_args(a): - nonlocal flat_args - flat_args.append(a) - return a +def _derive_grad_metas( + tensor_metas: tuple[TensorMeta, ...], +) -> tuple[_TensorMeta | None, ...]: + """Derive gradient metadata from tensor metadata. - fx.node.map_aggregate( - args, - extract_tensor_args, + Returns metadata with the same shape/stride/dtype but ``requires_grad=False``. + Entries where the source has ``requires_grad=False`` become ``None``. + """ + return tuple( + _TensorMeta(shape=m.shape, stride=m.stride, dtype=m.dtype, requires_grad=False) + if m.requires_grad + else None + for m in tensor_metas ) - return flat_args +class _MeshCache: + """Cache for :class:`DeviceMesh` objects keyed by ``(mesh_dim_names, mesh_layout)``. -class PipeliningShapeError(RuntimeError): - """Shape mismatch between configured and runtime values.""" + Assumes all pipeline stages share the same rank tensor (true for + TorchTitan-style frameworks where meshes derive from a common world). + """ + def __init__(self, get_mesh_cb: GetMeshCallback | None = None) -> None: + self._cache: dict[MeshCacheKey, DeviceMesh] = {} + self._get_mesh_cb = get_mesh_cb -def validate_tensor_metadata(desc, expected, given): - if not expected.shape == given.shape: - raise PipeliningShapeError( - f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}" - ) - if not expected.dtype == given.dtype: - raise PipeliningShapeError( - f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}" - ) - if not expected.stride() == given.stride(): - raise PipeliningShapeError( - f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}" - ) + def get_mesh(self, key: MeshCacheKey) -> DeviceMesh: + """Return a cached mesh, or create one via the callback. + Args: + key: Cache key ``(mesh_dim_names, mesh_layout)``. -def validate_tensors_metadata( - desc, - expected_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...], - actual_tensors: list[torch.Tensor] | tuple[torch.Tensor, ...], -): - if len(expected_tensors) != len(actual_tensors): - raise PipeliningShapeError( - f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})" - ) - for i in range(len(expected_tensors)): - validate_tensor_metadata( - f"{desc}: value {i}", expected_tensors[i], actual_tensors[i] - ) + Returns: + The ``DeviceMesh``. + + Raises: + PipeliningMetadataError: If not cached and no callback provided. + """ + if key in self._cache: + return self._cache[key] + + mesh_dim_names, mesh_layout = key + + if self._get_mesh_cb is None: + raise PipeliningMetadataError( + f"Mesh not found in cache for mesh_dim_names={mesh_dim_names}, " + f"mesh_layout={mesh_layout}, and no get_mesh callback provided. " + f"Provide a get_mesh callback or use DTensors in static mode." + ) + + mesh = self._get_mesh_cb(mesh_dim_names, mesh_layout) + if mesh is None: + raise PipeliningMetadataError( + f"Mesh lookup failed: callback returned None for " + f"mesh_dim_names={mesh_dim_names}, mesh_layout={mesh_layout}. " + f"Ensure all stages use meshes from the same universe." + ) + self._cache[key] = mesh + return mesh + + def put(self, key: MeshCacheKey, mesh: DeviceMesh) -> None: + """Add a mesh to the cache.""" + self._cache[key] = mesh + + def update_from_tensors(self, tensors: tuple[torch.Tensor | None, ...]) -> None: + """Extract and cache meshes from any :class:`DTensor` instances in *tensors*.""" + for tensor in tensors: + if isinstance(tensor, DTensor): + mesh = tensor.device_mesh + dim_names = tuple(mesh.mesh_dim_names) if mesh.mesh_dim_names else () + mesh_layout = mesh._layout + key = (dim_names, mesh_layout) + if key not in self._cache: + self._cache[key] = mesh + + def __contains__(self, key: MeshCacheKey) -> bool: + return key in self._cache + + def __len__(self) -> int: + return len(self._cache) + + +# ============================================================================ +# Inference mode enum +# ============================================================================ + + +class InferenceMode(Enum): + """Pipeline-level metadata inference mode, determined collectively across all PP ranks. + + The mode is set by the schedule (not individual stages) because + ``has_backward`` is only known at schedule creation time and all + stages must agree to avoid P2P hangs. + + .. attribute:: STATIC + + All stages have sufficient metadata; runtime inference is skipped. + + .. attribute:: DYNAMIC + + At least one stage requires runtime metadata inference. + """ + + STATIC = "static" + DYNAMIC = "dynamic" + + @classmethod + def needs_dynamic(cls, meta: _StageMeta, stage_has_backward: bool) -> bool: + """Determine whether dynamic metadata inference is needed for a stage. + + Args: + meta: Stage metadata from user-provided args. + stage_has_backward: Whether a backward pass will be performed. + + Returns: + ``True`` if dynamic inference is needed. + """ + # Case 1: Forward metadata incomplete → needs DYNAMIC + if not meta.is_complete_for_forward(): + return True + + # Case 2: No DTensors → STATIC is fine (bwd metadata derivable from fwd metadata) + if not meta.has_dtensors(): + return False + + # Case 3: No backward needed → STATIC is fine (don't need grad metadata) + if not stage_has_backward: + return False + + # Case 4: DTensors with backward but missing ANY grad metadata → needs DYNAMIC + # Both input_grads AND output_grads are required for static mode with DTensors + if meta.input_grads is None or meta.output_grads is None: + return True + + # Case 5: DTensors with complete grads → STATIC is fine + return False + + +# ============================================================================ +# Utility functions +# ============================================================================ + + +def flatten_args(args, *, detach: bool = False): + """Flatten ``args`` into a list, optionally detaching tensors. + + Args: + args: Nested arguments to flatten. + detach: If ``True``, detach tensors while preserving ``requires_grad``. + + Returns: + ``(new_args, flat_detached_args)`` when ``detach=True``; + ``flat_args`` list otherwise. + """ + flat_args, treespec = tree_flatten(args) + + if detach: + flat_detached = [ + a.detach().requires_grad_(a.requires_grad) + if isinstance(a, torch.Tensor) + else a + for a in flat_args + ] + new_args = tree_unflatten(flat_detached, treespec) + return new_args, flat_detached + + return flat_args + + +# Backward compatibility alias +def flatten_args_detach(args): + """Flatten and detach. Deprecated: use ``flatten_args(args, detach=True)``.""" + return flatten_args(args, detach=True) def generate_stage_to_rank_mapping( @@ -148,7 +563,7 @@ def generate_rank_to_stage_mapping( return rank_to_stages -@dataclass +@dataclass(slots=True) class PipeInfo: """ Captures information for a pipeline (`Pipe` object). @@ -157,3 +572,367 @@ class PipeInfo: graph: fx.Graph num_stages: int has_loss_and_backward: bool + + +# ============================================================================ +# Metadata extraction helpers +# ============================================================================ + + +def extract_tensor_meta(tensor: torch.Tensor) -> TensorMeta: + """Extract metadata from a tensor. + + Handles both plain Tensor and DTensor correctly: DTensors are + dispatched to ``_DTensorMeta.from_dtensor`` which captures local + shard attributes plus global shape/placement info, while plain + tensors use ``_TensorMeta.from_tensor``. + + Args: + tensor: A plain tensor or DTensor. + + Returns: + ``_TensorMeta`` for plain tensors, ``_DTensorMeta`` for DTensors. + """ + if isinstance(tensor, DTensor): + return _DTensorMeta.from_dtensor(tensor) + else: + return _TensorMeta.from_tensor(tensor) + + +@overload +def extract_tensor_metas( + tensors: tuple[torch.Tensor, ...] | None, + *, + allow_none: Literal[False] = ..., +) -> tuple[TensorMeta, ...] | None: ... + + +@overload +def extract_tensor_metas( + tensors: tuple[torch.Tensor | None, ...] | None, + *, + allow_none: Literal[True], +) -> tuple[TensorMeta | None, ...] | None: ... + + +def extract_tensor_metas( + tensors: tuple[torch.Tensor | None, ...] | tuple[torch.Tensor, ...] | None, + *, + allow_none: bool = False, +) -> tuple[TensorMeta | None, ...] | None: + """Extract metadata from a tuple of tensors. + + Args: + tensors: Tuple of tensors (may include ``None`` when ``allow_none=True``). + allow_none: If ``True``, preserve ``None`` elements (for gradients). + + Returns: + Tuple of ``TensorMeta``, or ``None`` if ``tensors`` is ``None``. + + Raises: + PipeliningMetadataError: If ``None`` found and ``allow_none=False``. + """ + if tensors is None: + return None + + metas_with_none: list[TensorMeta | None] = [] + has_none = False + for t in tensors: + if isinstance(t, torch.Tensor): + metas_with_none.append(extract_tensor_meta(t)) + else: + has_none = True + metas_with_none.append(None) + if not allow_none and has_none: + raise PipeliningMetadataError( + "None values are not allowed in tensor metadata tuples. " + "Use allow_none=True for optional values." + ) + return tuple(metas_with_none) + + +def to_local_if_dtensor(tensor: torch.Tensor, detach: bool = False) -> torch.Tensor: + """Convert a DTensor to its local shard, or return a plain tensor as-is. + + When ``detach=True``, the tensor is detached before conversion — + this applies to both DTensors and plain tensors. + + Args: + tensor: A tensor that may be a DTensor. + detach: If ``True``, detach before ``to_local()`` to avoid + redistribution during backward. + + Returns: + The local tensor component. + """ + maybe_detached_tensor = tensor.detach() if detach else tensor + if isinstance(maybe_detached_tensor, DTensor): + return maybe_detached_tensor.to_local() + return maybe_detached_tensor + + +@overload +def validate_and_normalize_to_tuple( + args: torch.Tensor | tuple[torch.Tensor, ...] | list[torch.Tensor] | None, + allow_none: Literal[False] = ..., +) -> tuple[torch.Tensor, ...] | None: ... + + +@overload +def validate_and_normalize_to_tuple( + args: torch.Tensor + | tuple[torch.Tensor | None, ...] + | list[torch.Tensor | None] + | None, + allow_none: Literal[True] = ..., +) -> tuple[torch.Tensor | None, ...] | None: ... + + +def validate_and_normalize_to_tuple( + args: torch.Tensor + | tuple[torch.Tensor, ...] + | tuple[torch.Tensor | None, ...] + | list[torch.Tensor] + | list[torch.Tensor | None] + | None, + allow_none: bool = False, +) -> tuple[torch.Tensor | None, ...] | tuple[torch.Tensor, ...] | None: + """Normalize ``args`` to a tuple and validate that all elements are tensors. + + Args: + args: A single tensor, tuple/list of tensors, or ``None``. + allow_none: If ``True``, permit ``None`` elements (for gradients). + + Returns: + Tuple of tensors, or ``None`` if ``args`` is ``None``. + + Raises: + PipeliningMetadataError: On non-tensor values + (or ``None`` when ``allow_none=False``). + """ + if args is None: + return None + elif isinstance(args, torch.Tensor): + return (args,) + elif isinstance(args, (tuple, list)): + for i, arg in enumerate(args): + if arg is None: + if not allow_none: + raise PipeliningMetadataError( + f"Stage arg[{i}] is None. " + f"Stage args must be tensors. Use kwargs for optional values." + ) + continue + if not isinstance(arg, torch.Tensor): + raise PipeliningMetadataError( + f"Stage arg[{i}] has type {type(arg).__name__}. " + f"All stage args must be tensors. Use kwargs for non-tensor inputs." + ) + # Normalize list to tuple + return tuple(args) if isinstance(args, list) else args + else: + raise PipeliningMetadataError( + f"Stage args must be a tensor, tuple, or list of tensors, got {type(args).__name__}." + ) + + +# ============================================================================ +# Validation functions +# ============================================================================ + + +def validate_metadata( + desc: str, + expected: TensorMeta, + actual: torch.Tensor | TensorMeta, + *, + raise_on_mismatch: bool = False, + warn_on_mismatch: bool = False, +) -> list[str]: + """ + Compare expected metadata against actual tensor or metadata. + + This is the unified validation/comparison function that uses get_diff() from + metadata classes. Works with both plain tensors and DTensors. + + For plain tensors: compares shape/stride/dtype/requires_grad. + For DTensors: compares all properties including global shape and placements. + + Args: + desc: Description for error/warning messages. + expected: Expected tensor metadata (_TensorMeta or _DTensorMeta). + actual: Actual tensor or metadata to compare against. + raise_on_mismatch: If True, raise PipeliningMetadataError on mismatch. + warn_on_mismatch: If True, issue a warning on mismatch. + + Returns: + List of differences (empty if metadata matches). + + Raises: + PipeliningMetadataError: If raise_on_mismatch=True and differences exist. + """ + # Extract metadata if actual is a tensor + if isinstance(actual, torch.Tensor): + actual_meta = extract_tensor_meta(actual) + else: + actual_meta = actual + + # Type check: ensure both are same type for meaningful comparison + if type(expected) is not type(actual_meta): + type_diff = [ + f"type: expected {type(expected).__name__}, got {type(actual_meta).__name__}" + ] + if raise_on_mismatch: + raise PipeliningMetadataError(f"{desc}: {type_diff[0]}") + if warn_on_mismatch: + warnings.warn( + f"{desc}: Metadata type mismatch. {type_diff[0]}. " + f"Using dynamically inferred metadata instead.", + UserWarning, + stacklevel=2, + ) + return type_diff + + # Use get_diff() from the metadata class + diffs = expected.get_diff(actual_meta) + + if diffs: + if raise_on_mismatch: + raise PipeliningMetadataError(f"{desc}: {'; '.join(diffs)}") + if warn_on_mismatch: + warnings.warn( + f"{desc}: Metadata mismatch. {'; '.join(diffs)}. " + f"Using dynamically inferred metadata instead.", + UserWarning, + stacklevel=2, + ) + + return diffs + + +def validate_tensors_metadata( + desc: str, + expected: tuple[TensorMeta | None, ...], + actual: tuple[torch.Tensor | TensorMeta | None, ...], + *, + raise_on_mismatch: bool = True, + warn_on_mismatch: bool = False, +) -> list[str]: + """Validate metadata for a tuple of tensors element-wise. + + Args: + desc: Description prefix for error/warning messages. + expected: Tuple of expected metadata (may include ``None`` for grads). + actual: Tuple of actual tensors or metadata to compare against. + raise_on_mismatch: If ``True``, raise on the first mismatch. + warn_on_mismatch: If ``True``, issue warnings for mismatches. + + Returns: + Aggregated list of difference strings. + + Raises: + PipeliningMetadataError: If lengths differ or on mismatch. + """ + if len(expected) != len(actual): + msg = f"{desc}: expected {len(expected)} tensors, got {len(actual)}" + if raise_on_mismatch: + raise PipeliningMetadataError(msg) + if warn_on_mismatch: + warnings.warn(msg, UserWarning, stacklevel=2) + return [msg] + + all_diffs: list[str] = [] + for i, (exp, act) in enumerate(zip(expected, actual, strict=True)): + if exp is None and act is None: + continue + if exp is None or act is None: + msg = ( + f"{desc}[{i}]: expected {'None' if exp is None else 'metadata'}, " + f"got {'None' if act is None else 'metadata'}" + ) + if raise_on_mismatch: + raise PipeliningMetadataError(msg) + if warn_on_mismatch: + warnings.warn(msg, UserWarning, stacklevel=2) + all_diffs.append(msg) + continue + diffs = validate_metadata( + f"{desc}[{i}]", + exp, + act, + raise_on_mismatch=raise_on_mismatch, + warn_on_mismatch=warn_on_mismatch, + ) + all_diffs.extend(diffs) + return all_diffs + + +def validate_static_arg_grad_correspondence( + stage_index: int, + args: tuple[torch.Tensor, ...], + grads: tuple[torch.Tensor | None, ...], + is_input: bool, +) -> None: + """ + Validate the args↔grads contract for static mode. + + Enforces four rules for each (arg, grad) pair: + 1. len(args) must equal len(grads). + 2. If arg.requires_grad is False, grad must be None. + 3. If arg.requires_grad is True and grad is None, emit a warning + (this is legal at pipeline boundaries but may indicate a bug). + 4. If arg is a DTensor with requires_grad=True and grad is not None, + grad must also be a DTensor. + + Args: + stage_index: The stage index for error messages. + args: Tuple of forward tensors. + grads: Tuple of gradient tensors (can include None). + is_input: True for input_args/input_grads, False for output_args/output_grads. + + Raises: + PipeliningMetadataError: If any hard rule (1, 2, or 4) is violated. + """ + kind = "input" if is_input else "output" + args_name = f"{kind}_args" + grads_name = f"{kind}_grads" + + # Rule 1: lengths must match + if len(args) != len(grads): + raise PipeliningMetadataError( + f"Stage {stage_index}: {grads_name} length ({len(grads)}) does not match " + f"{args_name} length ({len(args)}). Each forward tensor must have a " + f"corresponding gradient entry (use None for tensors that don't require grad)." + ) + + for i, (arg, grad) in enumerate(zip(args, grads, strict=True)): + # Rule 2: no grad for a non-differentiable arg + if not arg.requires_grad and grad is not None: + raise PipeliningMetadataError( + f"Stage {stage_index}: {args_name}[{i}] has requires_grad=False, " + f"but {grads_name}[{i}] is not None ({type(grad).__name__}). " + f"Non-differentiable tensors must have None as their gradient entry." + ) + + # Rule 3: missing grad for a differentiable arg (warn, don't raise) + if arg.requires_grad and grad is None: + warnings.warn( + f"Stage {stage_index}: {args_name}[{i}] has requires_grad=True, " + f"but {grads_name}[{i}] is None. This is legal at pipeline boundaries " + f"but may indicate a missing gradient.", + UserWarning, + stacklevel=2, + ) + + # Rule 4: DTensor arg must have DTensor grad + if ( + isinstance(arg, DTensor) + and arg.requires_grad + and grad is not None + and not isinstance(grad, DTensor) + ): + raise PipeliningMetadataError( + f"Stage {stage_index}: {args_name}[{i}] is a DTensor with requires_grad=True, " + f"but {grads_name}[{i}] is {type(grad).__name__}, expected DTensor or None. " + f"DTensor gradients may have different placements than forward tensors." + ) diff --git a/torch/distributed/pipelining/microbatch.py b/torch/distributed/pipelining/microbatch.py index c1c353b191bb5..5b28c1942a59b 100644 --- a/torch/distributed/pipelining/microbatch.py +++ b/torch/distributed/pipelining/microbatch.py @@ -6,6 +6,8 @@ from typing import Any import torch +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.experimental import local_map from torch.fx.node import map_aggregate from torch.nn.attention.flex_attention import BlockMask from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -202,26 +204,60 @@ def _split_tensor( raise AssertionError( f"Tensor size {tensor.size(spec.split_dim)} is smaller than num_chunks" ) - chunk_tensors = torch.tensor_split(tensor, num_chunks, spec.split_dim) - if not _debug_mask_minibatches: - return chunk_tensors - - expanded_chunks = [] - split_dim_idx = 0 - for chunk_tensor in chunk_tensors: - new_val = torch.zeros_like(tensor) - upper_idx = split_dim_idx + chunk_tensor.size(spec.split_dim) - - slice_indices = [slice(None, None, None)] * new_val.ndim - slice_indices[spec.split_dim] = slice(split_dim_idx, upper_idx) - new_val[slice_indices] = chunk_tensor + _is_dtensor = isinstance(tensor, DTensor) + + if _is_dtensor: + # Use local_map to split locally and preserve placements. + # Going through DTensor dispatch would convert Shard(split_dim) to + # Replicate() via an implicit all-gather, which is both wasteful and + # semantically wrong for PP microbatch splitting. + placements = tensor.placements + split_fn = local_map( + lambda t: torch.tensor_split(t, num_chunks, spec.split_dim), + out_placements=(placements,) * num_chunks, + in_placements=(placements,), + ) + chunk_tensors: Sequence[torch.Tensor] = split_fn(tensor) # type: ignore[assignment] + else: + chunk_tensors = torch.tensor_split(tensor, num_chunks, spec.split_dim) - expanded_chunks.append(new_val) + # tensor_split on a leaf tensor produces non-leaf views that won't + # accumulate .grad during torch.autograd.backward(). Call retain_grad() + # on those views so that stage_backward() can read .grad from them. + if tensor.requires_grad and tensor.is_leaf: + for chunk in chunk_tensors: + chunk.retain_grad() - split_dim_idx += chunk_tensor.size(spec.split_dim) + if not _debug_mask_minibatches: + return chunk_tensors - return expanded_chunks + def _expand_chunks( + orig: torch.Tensor, *chunks: torch.Tensor + ) -> tuple[torch.Tensor, ...]: + expanded = [] + idx = 0 + for chunk in chunks: + new_val = torch.zeros_like(orig) + upper = idx + chunk.size(spec.split_dim) + slices: list[slice] = [slice(None)] * new_val.ndim + slices[spec.split_dim] = slice(idx, upper) + new_val[slices] = chunk + expanded.append(new_val) + idx += chunk.size(spec.split_dim) + return tuple(expanded) + + if _is_dtensor: + placements = tensor.placements + n = len(chunk_tensors) + expand_fn = local_map( + _expand_chunks, + out_placements=(placements,) * n, + in_placements=(placements,) + (placements,) * n, + ) + return list(expand_fn(tensor, *chunk_tensors)) # type: ignore[arg-type] + else: + return list(_expand_chunks(tensor, *chunk_tensors)) def _shard_dict_of_args( @@ -538,7 +574,31 @@ def merge_chunks( else: values_to_cat = partial_values - args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim)) + # Validate DTensor consistency: either all values are DTensors + # or none are. A mix indicates a bug in the pipeline stage. + dtensor_flags = [isinstance(v, DTensor) for v in values_to_cat] + if any(dtensor_flags): + if not all(dtensor_flags): + raise AssertionError( + "merge_chunks: expected all values to be DTensors or " + "none to be DTensors, got a mix" + ) + # All DTensors must have matching placements. + placements = values_to_cat[0].placements + for i, v in enumerate(values_to_cat[1:], 1): + if v.placements != placements: + raise AssertionError( + f"merge_chunks: placement mismatch at chunk {i}: " + f"expected {placements}, got {v.placements}" + ) + cat_fn = local_map( + lambda *chunks: torch.cat(chunks, dim=arg.split_dim), + out_placements=(placements,), + in_placements=tuple(placements for _ in range(len(values_to_cat))), + ) + args_flattened.append(cat_fn(*values_to_cat)) + else: + args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim)) elif isinstance(arg, _CustomReducer): reduced_val = arg.init_value diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index 27789f74ba7a7..9ce05ffb7eb48 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -21,9 +21,13 @@ from torch.nn.modules.loss import _Loss from torch.profiler import record_function -from ._utils import generate_rank_to_stage_mapping, generate_stage_to_rank_mapping +from ._utils import ( + generate_rank_to_stage_mapping, + generate_stage_to_rank_mapping, + InferenceMode, +) from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec -from .stage import _PipelineStageBase +from .stage import _PipelineStageBase, PipelineStage __all__ = [ @@ -320,6 +324,135 @@ def _update_losses(self, stages, losses): self._internal_losses.clear() + def _warmup_p2p( + self, + stages: list[_PipelineStageBase], + has_backward: bool, + p2p_done: bool, + ) -> None: + """Run the P2P warm-up protocol for the given stages. + + For ``PipelineStage`` instances this executes the forward/backward vote + protocol (which warms up 2-rank sub-communicators) and sets each + stage's ``_inference_mode``. For other stage types it falls back to + the legacy ``_get_init_p2p_neighbors_ops`` + ``_batch_p2p`` path. + + Args: + stages: The pipeline stages owned by this rank. + has_backward: Whether the schedule includes a backward pass. + p2p_done: ``True`` if P2P neighbours have already been initialised + (avoids redundant init on eval↔train mode switches). + """ + if all(isinstance(stage, PipelineStage) for stage in stages): + acc: torch.Tensor | None = None + for stage in cast(list[PipelineStage], stages): + acc = stage._warmup_forward_vote(has_backward, received_acc=acc) + result: torch.Tensor | None = acc + determined_mode: InferenceMode | None = None + for stage in reversed(cast(list[PipelineStage], stages)): + result = stage._warmup_backward_result(received_result=result) + if result is None: + raise RuntimeError("P2P warm-up voting failed") + determined_mode = ( + InferenceMode.STATIC + if result.item() == 1 + else InferenceMode.DYNAMIC + ) + stage._inference_mode = determined_mode + logger.debug( + "Rank determined inference_mode=%s for %d stage(s)", + determined_mode.value if determined_mode else "None", + len(stages), + ) + elif not p2p_done: + all_ops: list[dist.P2POp] = [] + for stage in stages: + all_ops.extend(stage._get_init_p2p_neighbors_ops()) + _wait_batch_p2p(_batch_p2p(all_ops)) + + # TODO: STATIC mode group communicator warm-up gap + # The vote protocol above warms up 2-rank sub-communicators + # (used by `_batch_p2p` homogeneous fast-path). In DYNAMIC mode, + # `_send_meta`/`_recv_meta` (called during `_prepare_forward_infra` → + # `_forward_metadata_inference`) also warm up the *group* communicator + # (used by `_batch_p2p` mixed-op path). In STATIC mode, metadata + # inference is skipped, so the group communicator is NOT warmed up — + # it will be lazily created on the first mixed `_batch_p2p` call + # (e.g., 1F1B steady-state with both sends and recvs). + # Fix: run `_get_init_p2p_neighbors_ops` + `_batch_p2p` after the + # vote, gated by `not p2p_done`. + + def _initialize_pp_stages( + self, + stages: list[_PipelineStageBase], + args: tuple[Any, ...] | Any, + kwargs: dict[str, Any] | None, + target: Any, + fwd_initialized: bool, + bwd_initialized: bool, + ) -> tuple[bool, bool]: + """Common stage initialization shared by Single and Multi schedules. + + Handles mode-change detection (eval↔train), P2P warm-up, RNG forking, + forward / backward metadata inference, and FSDP cleanup. + + Returns the updated ``(fwd_initialized, bwd_initialized)`` flags. + """ + # Detect eval↔train mode switch: if has_backward changed since last + # init, re-initialize both fwd (recv buffers need different + # requires_grad) and bwd. p2p_done avoids redundant P2P warm-up. + p2p_done = fwd_initialized + if fwd_initialized and (self._has_backward != bwd_initialized): + fwd_initialized = False + bwd_initialized = False + + needs_fwd = not fwd_initialized + needs_bwd = self._has_backward and not bwd_initialized + + if not needs_fwd and not needs_bwd: + return fwd_initialized, bwd_initialized + + if needs_fwd: + self._warmup_p2p(stages, self._has_backward, p2p_done) + + # Fork RNG so metadata inference doesn't perturb training RNG. + devices = list( + { + torch.device(stage.device) + for stage in stages + if torch.device(stage.device).type != "cpu" + } + ) + with torch.random.fork_rng(devices=devices): + if needs_fwd: + next_stage_args: Any = None + for stage in stages: + stage_args = args if stage.is_first else next_stage_args + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, + stage_args, + kwargs, + has_backward=self._has_backward, + ) + fwd_initialized = True + + if needs_bwd: + prev_stage_grad_meta: Any = None + for stage in reversed(stages): + prev_stage_grad_meta = stage._prepare_backward_infra( + self._n_microbatches, + loss_fn=self._loss_fn, + target=target, + received_grad_meta=prev_stage_grad_meta, + ) + bwd_initialized = True + + for stage in stages: + if isinstance(stage, PipelineStage): + stage._post_metadata_inference_cleanup() + + return fwd_initialized, bwd_initialized + @abstractmethod def _step_microbatches( self, @@ -560,21 +693,18 @@ def __init__( self._get_pipeline_order() ) - def _initialize_stage(self, args, kwargs): - if not self._stage_forward_initialized: - # Prepare the communication needed for the pipeline schedule execution - # This is needed because during execution we always perform a series of batch P2P ops - # The first call of the batched P2P needs to involve the global group - all_ops: list[dist.P2POp] = [] - all_ops.extend(self._stage._get_init_p2p_neighbors_ops()) - _wait_batch_p2p(_batch_p2p(all_ops)) - - self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs) - self._stage_forward_initialized = True - - if self._has_backward and not self._stage_backward_initialized: - self._stage._prepare_backward_infra(self._n_microbatches) - self._stage_backward_initialized = True + def _initialize_stage(self, args, kwargs, target=None): + ( + self._stage_forward_initialized, + self._stage_backward_initialized, + ) = self._initialize_pp_stages( + [self._stage], + args, + kwargs, + target, + self._stage_forward_initialized, + self._stage_backward_initialized, + ) def step( self, @@ -670,7 +800,8 @@ def _step_microbatches( ) arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) - self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + maybe_first_target = target_mbs[0] if target_mbs is not None else None + self._initialize_stage(arg_mbs[0], kwarg_mbs[0], maybe_first_target) # Delay send waits fwd_sends_to_wait: list[list[dist.Work]] = [] @@ -721,7 +852,8 @@ def _step_microbatches( return_outputs: whether to return the outputs from the last stage. """ arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) - self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + maybe_first_target = target_mbs[0] if target_mbs is not None else None + self._initialize_stage(arg_mbs[0], kwarg_mbs[0], maybe_first_target) # Delay send waits fwd_sends_to_wait: list[list[dist.Work]] = [] @@ -865,7 +997,8 @@ def _step_microbatches( return_outputs: whether to return the outputs from the last stage. """ arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) - self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + maybe_first_target = target_mbs[0] if target_mbs is not None else None + self._initialize_stage(arg_mbs[0], kwarg_mbs[0], maybe_first_target) # Last stage has 1 warmup, second-to-last 2 warmups, ... # first stage `num_stages` warmups @@ -1532,34 +1665,18 @@ def __init__( "Simply stop passing it, and everything should still work fine." ) - def _initialize_stages(self, args: tuple[Any, ...], kwargs): - if not self._stages_forward_initialized: - # Prepare the communication needed for the pipeline schedule execution - # This is needed because during execution we always perform a series of batch P2P ops - # The first call of the batched P2P needs to involve the global group - all_ops: list[dist.P2POp] = [] - for stage in self._stages: - all_ops.extend(stage._get_init_p2p_neighbors_ops()) - _wait_batch_p2p(_batch_p2p(all_ops)) - - # may be 'none' value (if this stage sends its output shapes to the next stage via P2P) - # or real value (if this stage and next stage are on the same device) - next_stage_args: tuple[Any, ...] = tuple() - for stage in self._stages: - if stage.is_first: - next_stage_args = stage._prepare_forward_infra( - self._n_microbatches, args, kwargs - ) - else: - next_stage_args = stage._prepare_forward_infra( - self._n_microbatches, next_stage_args, kwargs - ) - self._stages_forward_initialized = True - - if self._has_backward and not self._stages_backward_initialized: - for stage in self._stages: - stage._prepare_backward_infra(self._n_microbatches) - self._stages_backward_initialized = True + def _initialize_stages(self, args: tuple[Any, ...], kwargs, target=None): + ( + self._stages_forward_initialized, + self._stages_backward_initialized, + ) = self._initialize_pp_stages( + self._stages, + args, + kwargs, + target, + self._stages_forward_initialized, + self._stages_backward_initialized, + ) def _validate_and_set_stage_mapping( self, actions: dict[int, list[_Action | None]] @@ -1674,8 +1791,8 @@ def _step_microbatches( not support models with skip connections. """ arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) - - self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) + maybe_first_target = target_mbs[0] if target_mbs is not None else None + self._initialize_stages(arg_mbs[0], kwarg_mbs[0], maybe_first_target) # Based on the plan in Step 1 created in __init__: # 2. Perform communication based on the pipeline_order @@ -1834,12 +1951,12 @@ def _step_microbatches( # do the communication _wait_batch_p2p(_batch_p2p(ops)) except Exception as e: - logger.error( # noqa: G200 + logger.error( "[Rank %s] pipeline schedule %s caught the following exception '%s' \ at time_step %s when running action %s", self.rank, self.__class__.__name__, - str(e), + e, time_step, action, ) @@ -2060,7 +2177,8 @@ def _step_microbatches( not support models with skip connections. """ arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) - self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) + maybe_first_target = target_mbs[0] if target_mbs is not None else None + self._initialize_stages(arg_mbs[0], kwarg_mbs[0], maybe_first_target) # Based on the plan in Step 1 created in __init__: # 2. Perform communication based on the pipeline_order diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index cd56efbbf4420..0fbe780714869 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -2,6 +2,7 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import logging import operator +import warnings from abc import ABC, abstractmethod from collections.abc import Callable from typing import Any, cast @@ -13,13 +14,39 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.distributed._composable.replicate_with_fsdp import replicate, ReplicateModule from torch.distributed.fsdp import FSDPModule, fully_shard +from torch.distributed.pipelining._utils import ( + _derive_grad_metas, + _DTensorMeta, + _make_tensor_from_meta, + _MeshCache, + _StageBackwardMeta, + _StageForwardMeta, + _StageMeta, + _TensorMeta, + extract_tensor_meta, + extract_tensor_metas, + flatten_args, + GetMeshCallback, + InferenceMode, + PipeInfo, + PipeliningMetadataError, + TensorMeta, + to_local_if_dtensor, + validate_and_normalize_to_tuple, + validate_static_arg_grad_correspondence, + validate_tensors_metadata, +) +from torch.distributed.tensor import DTensor from torch.fx.node import Argument, map_aggregate from torch.nn.parallel import DistributedDataParallel -from torch.utils._pytree import tree_map_only -from ._backward import stage_backward, stage_backward_input, stage_backward_weight +from ._backward import ( + _autograd_grad_for_inputs, + stage_backward, + stage_backward_input, + stage_backward_weight, +) from ._debug import map_debug_info -from ._utils import flatten_args, PipeInfo, validate_tensors_metadata __all__ = [ @@ -63,61 +90,47 @@ def _normalize_model_output_as_tuple(output: Any) -> tuple[Any]: return output_tuple -class _RootArgPlaceholder: - """ - Placeholder for model-level inputs. - """ - - def __init__(self, tensor): - self.meta = tensor.to("meta") - - class _RecvInfo: - """ - Represents a stage input. + """Input tensor descriptor for a pipeline stage. + + Handles both received activations from a previous stage + (``is_root_arg=False``) and root-level model inputs provided + by the user (``is_root_arg=True``). """ def __init__( self, input_name: str, - source: int, - buffer: torch.Tensor, + source: int | None, + buffer: torch.Tensor | None, + tensor_meta: TensorMeta | None, + *, + is_root_arg: bool = False, ): # Name of this input self.input_name = input_name - # Stage index of the source of this input + # Stage index of the source of this input (None for root args) self.source = source - # Buffer to receive the input into. + # Buffer to receive the input into (None for root args) self.buffer = buffer + # Tensor metadata for validation and DTensor reconstruction + self.tensor_meta = tensor_meta + # Whether this is a root-level model input (no recv needed) + self.is_root_arg = is_root_arg def __repr__(self): - return f"_RecvInfo(input={self.input_name}, source={self.source}, shape={self.buffer.size()})" - - -# An input can be either a received activation or a model input -InputInfo = _RecvInfo | _RootArgPlaceholder - - -def _make_tensor_from_meta( - example: torch.Tensor | FakeTensor, - device: torch.device, -) -> torch.Tensor: - """ - Create a real tensor from a tensor. - """ - return torch.empty( - example.size(), - dtype=example.dtype, - layout=example.layout, - device=device, - ) + if self.is_root_arg: + return f"_RecvInfo(input={self.input_name}, root_arg=True)" + meta_type = type(self.tensor_meta).__name__ if self.tensor_meta else "None" + buffer_shape = self.buffer.size() if self.buffer is not None else "None" + return f"_RecvInfo(input={self.input_name}, source={self.source}, shape={buffer_shape}, meta={meta_type})" class _PipelineStageBase(ABC): - """ - Base class for pipeline stages. - Defines or implements common methods used by the `_PipelineStage` used by - the tracing frontend and `PipelineStage` used by manual frontend. + """Base class for pipeline stages. + + Defines common methods used by ``_PipelineStage`` (tracing frontend) + and ``PipelineStage`` (manual frontend). """ def __init__( @@ -131,20 +144,16 @@ def __init__( ): """ Args: - submodule (torch.nn.Module): The module to be executed in this stage. - stage_index (int): The index of this stage. - num_stages (int): The total number of stages in this pipeline. - device (torch.device): The device to run this stage on. - group (Optional[dist.ProcessGroup]): The process group to use for communication. - If `None`, the default process group will be used. - Default: `None`. - dw_builder (Optional[Callable[[], Callable[..., None]]): If provided, dw_builder is a builder function - that will build a new dw_runner function that will run parts of module backward that were intentionally - skipped during the module's actual backward pass. The builder must be invoked by stage after stage runs - model backwards, and stage should save the latest dw_runner to run during weight pas (W). - If not provided, a dw_runner will be generated automatically by traversing the autograd graph. - When used with schedules that only have F and B steps, the fresh dw_runner function will be called as - part of I (input backwards). When used with F,I,W schedules, the dw_runner function implements 'W'. + submodule: The module to be executed in this stage. + stage_index: The index of this stage. + num_stages: The total number of stages in this pipeline. + device: The device to run this stage on. + group: Process group for communication. Defaults to the + default process group if ``None``. + dw_builder: Builder function that produces a ``dw_runner`` + for deferred weight updates in F/I/W zero-bubble + schedules. If ``None``, a runner is generated + automatically via autograd graph traversal. """ super().__init__() if stage_index >= num_stages: @@ -175,7 +184,6 @@ def __init__( ) # Run time states - self._outputs_meta: tuple[torch.Tensor, ...] | None = None # map microbatch ID to list of forward tensor args self.fwd_cache: dict[int, tuple[Any, list[torch.Tensor]]] = {} # map microbatch ID to list of backward grad tensor args @@ -190,7 +198,7 @@ def __init__( self.log_prefix = f"[Stage {self.stage_index}]" # Forward infra - self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {} + self.args_recv_info: dict[int, tuple[_RecvInfo, ...]] = {} self.act_send_info: dict[int, list] = {} # Backward infra will created lazily @@ -203,6 +211,17 @@ def __init__( i: i % self.group_size for i in range(self.num_stages) } + # DTensor support: mesh cache for looking up DeviceMesh by (dim_names, layout) + self._mesh_cache = _MeshCache() + + # Per-chunk runtime validation is expensive; only enable under + # TORCH_DISTRIBUTED_DEBUG=DETAIL for debugging shape/dtype mismatches. + self._runtime_validate = dist.get_debug_level() == dist.DebugLevel.DETAIL + + # DTensor support: consolidated stage metadata container + # Contains inputs, outputs, input_grads, output_grads metadata + self._stage_meta = _StageMeta() + @property def has_backward(self) -> bool: """ @@ -228,6 +247,21 @@ def is_last(self): """ return self.stage_index == self.num_stages - 1 + def _validate_stage_tensors( + self, + desc: str, + expected: tuple[TensorMeta | None, ...] | None, + actual: tuple[torch.Tensor | None, ...], + ) -> None: + """Validate actual tensors against expected metadata. + + Raises: + PipeliningMetadataError: If metadata is missing or mismatched. + """ + if expected is None: + raise PipeliningMetadataError(f"{desc}: no metadata available") + validate_tensors_metadata(desc, expected, actual) + def _check_chunk_id(self, chunk_id: int): if self.chunks is None: raise RuntimeError( @@ -238,27 +272,6 @@ def _check_chunk_id(self, chunk_id: int): f"Chunk id {chunk_id} is out of range [0, {self.chunks})" ) - def _configure_outputs_meta(self, outputs_meta: tuple[torch.Tensor, ...]): - """ - Track the output shapes/dtype of this stage since they determine the send operation(s) which must match - recv operations of the next stage. The next stage _will_ be freezing its recv buffers based on its initial - configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches - which could show up as hangs, silent corruption, or other errors. - """ - if self._outputs_meta is not None: - raise AssertionError( - "Attempting to reconfigure output_meta, which is not supported" - ) - self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment] - - def get_outputs_meta(self) -> tuple[torch.Tensor, ...]: - """Get the output metadata (meta tensors) reprensenting the outputs of this stage""" - if self._outputs_meta is None: - raise AssertionError( - "Attempted to get_outputs_meta() without configuring output meta" - ) - return self._outputs_meta - def _create_grad_send_info( self, args_recv_info: tuple, @@ -272,12 +285,13 @@ def map_recv_to_send(a): # Note: we send gradients back to previous stage as long as in # forward it is a received input, regardless of whether it requires # grad. It is up to the previous stage to discard this gradient. - if isinstance(a, _RecvInfo): - grad_send_info.append(a.source) - return a.source - else: + if a.is_root_arg: + # Root args don't have a source stage to send gradients to grad_send_info.append(None) return None + else: + grad_send_info.append(a.source) + return a.source map_aggregate(args_recv_info, map_recv_to_send) @@ -288,17 +302,30 @@ def map_recv_to_send(a): def _prepare_forward_infra( self, num_microbatches: int, - args: tuple[Any, ...], + args: tuple[Any, ...] | _StageForwardMeta | None, kwargs: dict[str, Any] | None = None, - ) -> tuple[Any, ...]: + has_backward: bool = False, + ) -> _StageForwardMeta | None: + raise NotImplementedError + + @abstractmethod + def _prepare_backward_infra( + self, + num_microbatches: int, + loss_fn: Callable[..., torch.Tensor] | None = None, + target: torch.Tensor | None = None, + received_grad_meta: _StageBackwardMeta | None = None, + ) -> _StageBackwardMeta | None: raise NotImplementedError - def _prepare_backward_infra(self, num_microbatches: int): + def _setup_backward_recv_info(self, num_microbatches: int): # TODO: this is needed for backward_maybe_with_nosync self.chunks = num_microbatches + # IMPORTANT: _create_grad_recv_info reads self._stage_meta.output_grads + # to attach DTensor metadata to _RecvInfo objects. The clear below MUST + # happen after all _create_grad_recv_info calls complete. for mb_index in range(num_microbatches): - # `grad_recv_info` is a mirror of `act_send_info` self.grad_recv_info[mb_index] = self._create_grad_recv_info( self.act_send_info ) @@ -310,9 +337,17 @@ def _create_grad_recv_info( ) -> tuple[_RecvInfo, ...]: raise NotImplementedError + def _resolve_peer_global_rank(self, stage_idx: int) -> int: + """Map a pipeline stage index to the corresponding global rank for P2P communication.""" + peer_rank = self.stage_index_to_group_rank[stage_idx] + return dist.get_global_rank( + self.group or dist.distributed_c10d._get_default_group(), + peer_rank, + ) + def _get_recv_ops( self, - recv_infos: tuple[InputInfo, ...], + recv_infos: tuple[_RecvInfo, ...], ) -> list[dist.P2POp]: """ Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`. @@ -320,15 +355,16 @@ def _get_recv_ops( """ ops: list[dist.P2POp] = [] for info in recv_infos: - if not isinstance(info, _RecvInfo): + if info.is_root_arg: + # Root args don't need recv operations continue - - peer_rank = self.stage_index_to_group_rank[info.source] - peer_global_rank = ( - peer_rank - if self.group is None - else dist.get_global_rank(self.group, peer_rank) - ) + # Skip entries with None buffer (None gradients) + if info.buffer is None: + assert info.tensor_meta is None # noqa: S101 + continue + # At this point, source and buffer are guaranteed non-None + assert info.source is not None # noqa: S101 + peer_global_rank = self._resolve_peer_global_rank(info.source) ops.append( dist.P2POp(dist.irecv, info.buffer, peer_global_rank, self.group) ) @@ -337,46 +373,42 @@ def _get_recv_ops( """[Note: V-schedule special case] - V-Schedules have a special case where 2 stages with adjacent stage_id are on the same rank. + V-Schedules have a special case where 2 stages with adjacent stage_id + are on the same rank. - ex: 2 ranks, 4 stages forms a simple V: - rank0: stage 0 stage 3 - rank1: stage 1 stage 2 + Example: 2 ranks, 4 stages forms a simple V:: - stage 0,1 and 2,3 communicate activations using send/recv as usual, but stage 1,2 do not need to - use communication ops. Instead, they should pass tensor data directly via function call. + rank0: stage 0 stage 3 + rank1: stage 1 stage 2 - set_local_fwd_input and (get_local_bwd_output + set_local_bwd_input) facilitate this optimization, and - should be called at the appropriate time during the pipeline schedule (after forward or backward execution). + Stages 0/1 and 2/3 communicate via send/recv, but stages 1/2 pass + tensors directly via function call, avoiding communication ops. """ def set_local_fwd_input(self, prev_stage_outputs: Any, mb_index: int) -> None: + """Pass outputs from a same-rank stage as forward inputs (V-schedule). + + Detaches tensors and sets ``requires_grad`` so they serve as autograd + leaves. Handles DTensor activations transparently. """ - Moves 'prev_stage_outputs' from another stage on the same rank into place as inputs for this stage. Avoids - copying tensor data or using send/recv op. Detaches original tensor and sets requires_grad so the - tensor can serve as a leaf for autograd and gradients can be collected from it during backward. - """ - recv_infos: tuple[InputInfo, ...] = self.args_recv_info[mb_index] + recv_infos: tuple[_RecvInfo, ...] = self.args_recv_info[mb_index] # See [Note: pipeline model output type] prev_stage_outputs = _normalize_model_output_as_tuple(prev_stage_outputs) - for info, tensor in zip(recv_infos, prev_stage_outputs): + for info, tensor in zip(recv_infos, prev_stage_outputs, strict=True): if not isinstance(tensor, torch.Tensor): raise AssertionError( f"expected tensor values as outputs from prev stage, got {type(tensor)}" ) - if not isinstance(info, _RecvInfo): + if info.is_root_arg: raise AssertionError( - "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo" + "set_local_fwd_input should only be called on non-first stage, which should always have non-root RecvInfo" ) - # We don't need to do a data copy here, since we can directly pass the activation tensor reference from - # one stage to the next. However, we do need to mark the activation as a leaf tensor since it will serve - # as the input tensor for a fresh autograd graph, not part of the previous stage's autograd graph. - # TODO: confirm, do we use this activation as the root of the backward call for the previous stage? does - # detach have any affect on that? - info.buffer = tensor.detach().requires_grad_(True) + # Pass the activation tensor directly (same rank for local execution). + # Detach to create a new autograd leaf for the fresh autograd graph. + info.buffer = to_local_if_dtensor(tensor, detach=True) def get_local_bwd_output(self, mb_index): """ @@ -398,6 +430,7 @@ def set_local_bwd_input( """ Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv. Does not detach or set '_requires_grad'. + Handles DTensor gradients for V-schedule local passing. """ if not isinstance(next_stage_bwd_outputs, tuple): raise AssertionError(f"Expected tuple, got {type(next_stage_bwd_outputs)}") @@ -409,21 +442,27 @@ def set_local_bwd_input( if self.is_last: raise AssertionError("can't set bwd input if this stage is last") recv_infos = self.grad_recv_info[mb_index] - for info, tensor in zip(recv_infos, next_stage_bwd_outputs): + for info, tensor in zip(recv_infos, next_stage_bwd_outputs, strict=True): + if tensor is None: + continue if not isinstance(tensor, torch.Tensor): raise AssertionError( f"expected tensor values as outputs from prev stage, got {type(tensor)}" ) - if not isinstance(info, _RecvInfo): - raise AssertionError(f"Expected a recv info, got {type(info)}") - info.buffer = tensor + if info.is_root_arg: + raise AssertionError( + "set_local_bwd_input should only be called with non-root RecvInfo" + ) + + # Extract local tensor for the buffer (handles DTensor or plain tensor) + info.buffer = to_local_if_dtensor(tensor) def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: """ Returns a list of ops that are needed to receive the input arguments for this stage. """ - recv_infos: tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id] + recv_infos: tuple[_RecvInfo, ...] = self.args_recv_info[fwd_chunk_id] return self._get_recv_ops(recv_infos) @@ -441,6 +480,7 @@ def get_bwd_recv_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: """ Get the activation send ops for current stage's forward. + Handles DTensor outputs by extracting local tensors. """ output_tuple, _ = self.fwd_cache[fwd_chunk_id] @@ -451,25 +491,25 @@ def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: for dst in dst_stages: if dst is None: continue + # Extract local tensor if DTensor + send_tensor = to_local_if_dtensor(out, detach=True) logger.debug( "%s Sending tensor to Stage %s: %s", self.log_prefix, dst, - out.size(), + send_tensor.size(), ) - peer_rank = self.stage_index_to_group_rank[dst] - peer_global_rank = ( - peer_rank - if self.group is None - else dist.get_global_rank(self.group, peer_rank) + peer_global_rank = self._resolve_peer_global_rank(dst) + ops.append( + dist.P2POp(dist.isend, send_tensor, peer_global_rank, self.group) ) - ops.append(dist.P2POp(dist.isend, out, peer_global_rank, self.group)) return ops def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: """ Get the gradient send ops for current stage's backward. + Handles DTensor gradients by extracting local tensors. """ if not self.has_backward or self.is_first: return [] @@ -485,24 +525,24 @@ def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: ops: list[dist.P2POp] = [] grads_input = self.bwd_cache.pop(bwd_chunk_id) - for grad, grad_recv_stage in zip(grads_input, self.grad_send_info): + + for grad, grad_recv_stage in zip(grads_input, self.grad_send_info, strict=True): if isinstance(grad, torch.Tensor) and grad_recv_stage is not None: + # Extract local tensor if DTensor + send_tensor = to_local_if_dtensor(grad) logger.debug( "%s Sending gradient to Stage %s: %s", self.log_prefix, grad_recv_stage, - grad.size(), + send_tensor.size(), ) - peer_rank = self.stage_index_to_group_rank[grad_recv_stage] - peer_global_rank = ( - peer_rank - if self.group is None - else dist.get_global_rank(self.group, peer_rank) + peer_global_rank = self._resolve_peer_global_rank(grad_recv_stage) + ops.append( + dist.P2POp(dist.isend, send_tensor, peer_global_rank, self.group) ) - ops.append(dist.P2POp(dist.isend, grad, peer_global_rank, self.group)) else: if grad is not None or grad_recv_stage is not None: - raise RuntimeError( + raise PipeliningMetadataError( f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} " f"and is expecting to send gradients to stage {grad_recv_stage}" ) @@ -523,34 +563,90 @@ def clear_runtime_states(self) -> None: # don't want such accumulation. for recv_tuple in self.args_recv_info.values(): # iterate over all chunks for a in recv_tuple: # iterate over all input args - if isinstance(a, _RecvInfo): + if not a.is_root_arg and a.buffer is not None: # Set to None is the newer and recommended way to clear grads, compared to `zero_()`. # See https://github.com/pytorch/pytorch/pull/92731 a.buffer.grad = None def _map_tensor_from_recv_info( self, - recv_infos: tuple[InputInfo, ...], + recv_infos: tuple[_RecvInfo, ...], ): """ Map tensors from recv infos to a list. """ def get_recv_tensor(info): - if isinstance(info, _RecvInfo): - return info.buffer - else: - raise AssertionError(f"Expected _RecvInfo but got {type(info)}") + if info.is_root_arg: + raise PipeliningMetadataError("Cannot get recv tensor from root arg") + return info.buffer return map_aggregate(cast(Argument, recv_infos), get_recv_tensor) - def _retrieve_recv_activations(self, fwd_chunk_id: int): + def _retrieve_recv_activations( + self, + fwd_chunk_id: int, + ): """ Retrieve the activations received for the current stage during forward. + Reconstructs DTensors if the inputs were DTensors. + Also validates DTensor metadata against expected values. """ recv_infos = self.args_recv_info[fwd_chunk_id] - activations = self._map_tensor_from_recv_info(recv_infos) - return activations + + activations = [] + for i, info in enumerate(recv_infos): + if not info.is_root_arg: + # Non-root args have valid buffer and tensor_meta + if info.buffer is None or info.tensor_meta is None: + raise PipeliningMetadataError( + f"Non-root arg '{info.input_name}' has None buffer or tensor_meta" + ) + # Effective requires_grad: metadata captures what the model + # produced, but the runtime context (has_backward, grad mode) + # determines whether we actually need gradients. + effective_requires_grad = ( + info.tensor_meta.requires_grad + and self.has_backward + and torch.is_grad_enabled() + ) + if isinstance(info.tensor_meta, _DTensorMeta): + # Buffer must not require grad so from_local stays out + # of the autograd graph (no grad_placements needed). + if info.buffer.requires_grad: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: recv buffer " + f"'{info.input_name}' unexpectedly requires grad " + f"before DTensor reconstruction" + ) + mesh = self._mesh_cache.get_mesh(info.tensor_meta.mesh_cache_key) + activation = DTensor.from_local( + info.buffer, + device_mesh=mesh, + placements=info.tensor_meta.placements, + shape=info.tensor_meta.global_shape, + stride=info.tensor_meta.global_stride, + run_check=False, + ).requires_grad_(effective_requires_grad) + else: + activation = info.buffer.requires_grad_(effective_requires_grad) + # Activation must be a leaf so backward terminates here. + if effective_requires_grad and not activation.is_leaf: + warnings.warn( + f"Stage {self.stage_index}: activation " + f"'{info.input_name}' is not a leaf " + f"(grad_fn={activation.grad_fn}); using " + f"retain_grad() as fallback", + stacklevel=2, + ) + activation.retain_grad() + activations.append(activation) + else: + raise PipeliningMetadataError( + f"_retrieve_recv_activations expected non-root _RecvInfo but got root arg at index {i}" + ) + + return tuple(activations) def _retrieve_recv_grads( self, @@ -558,10 +654,50 @@ def _retrieve_recv_grads( ): """ Retrieve the gradients received for the current stage during backward. + + Handles None gradients gracefully (for inputs that don't require grad). """ recv_infos = self.grad_recv_info[bwd_chunk_id] - grads = self._map_tensor_from_recv_info(recv_infos) - return grads + + grads: list[torch.Tensor | None] = [] + for i, info in enumerate(recv_infos): + if not isinstance(info, _RecvInfo): + raise PipeliningMetadataError( + f"Expected _RecvInfo but got {type(info)}" + ) + if not info.is_root_arg: + # Gradients can be None for non-differentiable outputs + if info.buffer is None: + if info.tensor_meta is not None: + raise PipeliningMetadataError( + f"Grad recv '{info.input_name}': buffer is None but tensor_meta is not None" + ) + grads.append(None) + continue + if info.tensor_meta is None: + raise PipeliningMetadataError( + f"Grad recv '{info.input_name}': buffer is not None but tensor_meta is None" + ) + if isinstance(info.tensor_meta, _DTensorMeta): + # Reconstruct DTensor gradient from local tensor + metadata + mesh = self._mesh_cache.get_mesh(info.tensor_meta.mesh_cache_key) + grad = DTensor.from_local( + info.buffer, + device_mesh=mesh, + placements=info.tensor_meta.placements, + shape=info.tensor_meta.global_shape, + stride=info.tensor_meta.global_stride, + run_check=False, + ) + else: + grad = info.buffer + grads.append(grad) + else: + raise PipeliningMetadataError( + f"grad_recv_info should not contain root args, but found one at index {i}" + ) + + return tuple(grads) def forward_maybe_with_nosync(self, *args, **kwargs): # If submod is wrapped with DDP, we use the `no_sync` context manager to @@ -691,7 +827,12 @@ def forward_one_chunk( composite_kwargs = kwargs or {} - self._validate_fwd_input(args, kwargs) + if self._runtime_validate: + self._validate_stage_tensors( + f"Stage {self.stage_index} forward inputs", + self._stage_meta.inputs, + composite_args, + ) # Compute forward try: @@ -727,7 +868,14 @@ def forward_one_chunk( fwd_chunk_id, map_debug_info(output), ) - self._validate_fwd_outputs(output_tuple) + # Validate outputs before P2P send; skipped for last stage (outputs + # go to loss/user, not via send/recv). + if self._runtime_validate and not self.is_last: + self._validate_stage_tensors( + f"Stage {self.stage_index} forward outputs", + self._stage_meta.outputs, + output_tuple, + ) # We return the original user-provided output, not normalized to tuple. # See [Note: pipeline model output type] @@ -776,6 +924,13 @@ def backward_one_chunk( else: # Otherwise, receive gradients from next stage grads_output = self._retrieve_recv_grads(bwd_chunk_id) + if self._runtime_validate: + # Validate backward input (output gradients) for DTensor metadata + self._validate_stage_tensors( + f"Stage {self.stage_index} backward input (output_grads)", + self._stage_meta.output_grads, + grads_output, + ) # If an input to the pipeline requires gradient, # `torch.autograd.backward` will accumulate the gradient into the # `.grad` field of such input @@ -828,8 +983,18 @@ def backward_one_chunk( ) # Save a placeholder for the dw_runner self.dw_runner[bwd_chunk_id] = lambda: None - - self.bwd_cache[bwd_chunk_id] = grads_input + # Note: grads_input may contain gradients for both args and kwargs (from fwd_cache), + # Kwargs are local to each stage and don't need gradient transmission. + # Validate backward output (input gradients) for DTensor metadata + assert self._stage_meta.inputs is not None # noqa: S101 + num_fwd_args = len(self._stage_meta.inputs) + if self._runtime_validate and not self.is_first: + self._validate_stage_tensors( + f"Stage {self.stage_index} backward output (input_grads)", + self._stage_meta.input_grads, + grads_input[:num_fwd_args], + ) + self.bwd_cache[bwd_chunk_id] = grads_input[:num_fwd_args] if self.is_last and not self.is_first: # Autograd dependencies: @@ -887,48 +1052,6 @@ def backward_weight_one_chunk(self, bwd_chunk_id: int, last_backward=False): "full", bwd_kwargs, last_backward=last_backward ) - def _validate_fwd_input(self, args, kwargs): - """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage.""" - - if self.is_first: - # TODO why is there a separate recv_info for each pipeline chunk? - # kwen2501: to avoid passing a `fwd_chunk_id` to this function, we - # check all chunks against args_recv_info[0] - expected_args = self.args_recv_info[0] - else: - # We don't check inputs for non-0 stages assuming they don't accept - # user inputs in canonical pipeline scenarios - return - - if len(kwargs): - # TODO- need a mapping of kwarg to position in self.args_recv_info - # Without it, we are not 100% sure how to match the args and - # expected_args. - return - - # TODO- need a mapping of kwarg to position in self.args_recv_info - # maybe it's impossible to tell whether the len mismatches because - # (a) the user passed an extra arg or missed an arg - # (b) the user did not pass a kwarg, which has a default value baked into expected_args - expected_tensors_meta = [ - e.meta if isinstance(e, _RootArgPlaceholder) else e.buffer - for e in expected_args - ] - validate_tensors_metadata( - f"Stage {self.stage_index} forward inputs", expected_tensors_meta, args - ) - - def _validate_fwd_outputs(self, outputs: tuple[torch.Tensor, ...]): - """Raises a RuntimeError if this stage produces an output of unexpected shape/dtype. - Most likely, this could be cause either by incorrect user specification of output shapes, or because - shape inference was done on the original model but then at runtime the model is wrapped with something like - mixed precision which changes output dtype. - """ - expected_tensors_meta = self.get_outputs_meta() - validate_tensors_metadata( - f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs - ) - def _get_init_p2p_neighbors_ops(self) -> list[dist.P2POp]: """ Get the operations to initialize the p2p communicators between previous and next stages. @@ -1054,7 +1177,7 @@ def __init__( node for node in pipe_info.graph.nodes if node.op == "call_module" ] if len(submod_nodes) != self.num_stages: - raise AssertionError( + raise PipeliningMetadataError( f"Number of submodules in pipe graph {len(submod_nodes)} does not match number of stages {self.num_stages}" ) @@ -1092,21 +1215,85 @@ def _move_submod_to_device(self): def _prepare_forward_infra( self, num_microbatches: int, - args: tuple[Any, ...], + args: tuple[Any, ...] | _StageForwardMeta | None, kwargs: dict[str, Any] | None = None, - ) -> tuple[Any, ...]: + has_backward: bool = False, + ) -> _StageForwardMeta | None: """ - Create send/recv infrastructures for activations (during forward) + Prepare forward infrastructure for traced pipeline. + + Metadata is created directly from graph placeholders with correct + ``requires_grad`` — received activations get ``requires_grad=True`` + when ``has_backward`` is set, fixing the fact that ``torch.export`` + traces under ``no_grad()``. + + ``_stage_meta.inputs`` is derived from recv infos and aligned with + ``forward_one_chunk``'s ``composite_args``: positional root inputs + on the first stage, received activations only on subsequent stages. """ - # TODO(whc) - # this method should be deleted once lazy buffer allocation is implemented - # for now, it ignores args/kwargs because it should not need to do shape inference + # Step 1: Create recv info for each microbatch. + # _create_act_recv_info is self-contained: it creates _TensorMeta + # directly from graph placeholder values with correct requires_grad. for chunk in range(num_microbatches): self.args_recv_info[chunk] = self._create_act_recv_info() - # Send info during forward for each activation + # Step 2: Derive _stage_meta.inputs from recv infos. + # forward_one_chunk builds composite_args as: + # - First stage: args (positional root inputs, excludes kwargs) + # - Non-first stages: received activations only (no root kwargs) + # _stage_meta.inputs must match composite_args for validation. + recv_infos = self.args_recv_info[0] + if self.is_first: + # All placeholders are root args. Only the first len(args) + # correspond to positional inputs (composite_args); the rest + # are kwargs passed separately via composite_kwargs. + # First stage always receives real tensor args, never _StageForwardMeta. + if not isinstance(args, tuple): + raise AssertionError("First stage requires real tensor args") + n_positional = len(args) + self._stage_meta.inputs = tuple( + info.tensor_meta # type: ignore[misc] + for info in recv_infos[:n_positional] + ) + else: + self._stage_meta.inputs = tuple( + info.tensor_meta # type: ignore[misc] + for info in recv_infos + if not info.is_root_arg + ) + + # Step 3: Create send info and output metadata. self.act_send_info = self._create_act_send_info() - return tuple() + + return None + + def _prepare_backward_infra( + self, + num_microbatches: int, + loss_fn: Callable[..., torch.Tensor] | None = None, + target: torch.Tensor | None = None, + received_grad_meta: _StageBackwardMeta | None = None, + ) -> _StageBackwardMeta | None: + """ + Prepare backward infrastructure for traced pipeline. + Derives input_grads metadata from inputs (plain tensors only). + + Note: DTensors are NOT supported in the traced frontend. + """ + # Derive input_grads from inputs (for plain tensors, grad shape == input shape) + if self._stage_meta.inputs is None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: inputs metadata required for backward inference." + ) + + self._stage_meta.input_grads = _derive_grad_metas(self._stage_meta.inputs) + + # Setup backward recv info (calls _create_grad_recv_info which sets output_grads). + # Note: grad_send_info is created lazily in get_bwd_send_ops() since + # it mirrors args_recv_info (already populated during forward). + self._setup_backward_recv_info(num_microbatches) + + return None def get_stage_index_of_submod( self, @@ -1116,7 +1303,7 @@ def get_stage_index_of_submod( Given a submodule name, return the stage index of the submodule. """ if submod_name not in self.submod_to_stage_index: - raise AssertionError(f"Stage id of {submod_name} not found") + raise PipeliningMetadataError(f"Stage id of {submod_name} not found") return self.submod_to_stage_index[submod_name] @@ -1125,38 +1312,65 @@ def _create_act_recv_info( ): """ Create a tuple of `_RecvInfo` for inputs to the stage. + + Self-contained: creates ``_TensorMeta`` directly from graph + placeholder values with correct ``requires_grad``. + ``torch.export`` traces under ``no_grad()`` so traced metadata + always has ``requires_grad=False``; for received activations we + set ``requires_grad=True`` when ``has_backward`` is set. + + Note: DTensors are NOT supported in the traced frontend. """ def create_recv_tensor(placeholder, arg_node): - """ - Create a receive buffer for a placeholder. - """ example_value = placeholder.meta["val"] + + # Reject DTensors in traced frontend + if isinstance(example_value, DTensor): + raise PipeliningMetadataError( + f"{self.log_prefix} DTensor detected in traced pipeline input " + f"'{placeholder.name}'. DTensor metadata propagation is NOT " + f"supported for the traced frontend (_PipelineStage). " + f"Use the manual PipelineStage frontend for full DTensor support." + ) + if arg_node.op == "placeholder": - # This is a root level placeholder, thus an input argument to the entire model. - # We are likely at stage 0, hence no need to create a receive buffer. - return _RootArgPlaceholder(example_value) + # Root-level placeholder: an input argument to the entire + # model. Keep original metadata from the trace. + return _RecvInfo( + input_name=f"root_input_{placeholder.name}", + source=None, + buffer=None, + tensor_meta=_TensorMeta.from_tensor(example_value), + is_root_arg=True, + ) - # Figure out the source stage of this input + # Received activation from a previous stage. while arg_node.target is operator.getitem: - # If the input is a getitem, we need to go deeper arg_node = arg_node.args[0] if arg_node.op != "call_module": - raise AssertionError(f"Expecting call_module, got {arg_node.op}") + raise PipeliningMetadataError( + f"Expecting call_module, got {arg_node.op}" + ) src_stage = self.get_stage_index_of_submod(arg_node.name) - # Create a receive buffer for this placeholder + # Create metadata directly with correct requires_grad. + tensor_meta = _TensorMeta( + shape=example_value.shape, + stride=example_value.stride(), + dtype=example_value.dtype, + requires_grad=self.has_backward, + ) + logger.debug( "%s Creating recv buffer for input '%s' : %s, %s", self.log_prefix, placeholder.name, - example_value.shape, - example_value.dtype, + tensor_meta.shape, + tensor_meta.dtype, ) - buffer = _make_tensor_from_meta(example_value, self.device) - # In case there is backward pass, set requires_grad for receive buffers - # before first forward + buffer = _make_tensor_from_meta(tensor_meta, self.device) if self.has_backward: buffer.requires_grad_(True) @@ -1164,10 +1378,10 @@ def create_recv_tensor(placeholder, arg_node): arg_node.name, src_stage, buffer, + tensor_meta, ) - args_recv_info: list[InputInfo] = [] - # Filter out placeholder nodes from `self.submod` (a GraphModule) + args_recv_info: list[_RecvInfo] = [] placeholders = filter( # type: ignore[var-annotated] lambda node: node.op == "placeholder", # type: ignore[arg-type] self.submod.graph.nodes, # type: ignore[arg-type,union-attr] @@ -1175,15 +1389,12 @@ def create_recv_tensor(placeholder, arg_node): # `placeholders` are nodes internal to submod. # `self.node.args` are dependency nodes in the outer graph. # The two are 1:1. - for placeholder, arg_node in zip(placeholders, self.node.args): - # Create a receive buffer for this placeholder - recv_info = create_recv_tensor(placeholder, arg_node) - args_recv_info.append(recv_info) + for placeholder, arg_node in zip(placeholders, self.node.args, strict=True): + args_recv_info.append(create_recv_tensor(placeholder, arg_node)) logger.debug( "%s Activation recv / args info: %s", self.log_prefix, args_recv_info ) - # `args` is a Tuple, hence we will return a Tuple[InputInfo] return tuple(args_recv_info) def find_dst_rank( @@ -1207,14 +1418,12 @@ def find_dst_rank( def _create_act_send_info(self): """ - Create a dict of send info for activations. - The dict is of the form: - { - output_index: [dst_rank_0, dst_rank_1, ...], - ... - } - where the list of `dst_rank`s covers the case where an output value may - be consumed by multiple stages. + Create a dict of send info for activations and output metadata. + + Output metadata is created directly with correct ``requires_grad`` + (``torch.export`` traces under ``no_grad()``, so traced values + always have ``requires_grad=False``; at runtime, stage outputs + carry ``requires_grad=True`` when training). """ # Output index: List of receiver ranks act_send_info: dict[int, list] = {} @@ -1222,16 +1431,13 @@ def _create_act_send_info(self): for user in self.node.users: if user.target is operator.getitem: - # Recursively find the real destination gi_dsts = act_send_info.setdefault(out_idx, []) for gi_user in user.users: dst_rank = self.find_dst_rank(gi_user) if dst_rank is not None: gi_dsts.append(dst_rank) - # Next `getitem` will point to the next output index out_idx += 1 else: - # In case of single output value, `out_idx` will not increase dsts = act_send_info.setdefault(out_idx, []) dst_rank = self.find_dst_rank(user) if dst_rank is not None: @@ -1241,7 +1447,25 @@ def _create_act_send_info(self): output_vals: tuple[torch.Tensor] = tuple( v.meta["val"] for v in flatten_args(output_node.args) ) - self._configure_outputs_meta(output_vals) + # Reject DTensors and create output metadata directly with + # correct requires_grad. + output_metas: list[_TensorMeta] = [] + for i, val in enumerate(output_vals): + if isinstance(val, DTensor): + raise PipeliningMetadataError( + f"{self.log_prefix} DTensor detected in traced pipeline output index {i}. " + f"DTensor metadata propagation is NOT supported for the traced frontend " + f"(_PipelineStage). Use the manual PipelineStage frontend for full DTensor support." + ) + output_metas.append( + _TensorMeta( + shape=val.shape, + stride=val.stride(), + dtype=val.dtype, + requires_grad=self.has_backward, + ) + ) + self._stage_meta.outputs = tuple(output_metas) logger.debug("%s Send info: %s", self.log_prefix, act_send_info) return act_send_info @@ -1249,50 +1473,79 @@ def _create_act_send_info(self): def _get_output_node(self): output_nodes = [node for node in self.submod.graph.nodes if node.op == "output"] # type: ignore[union-attr] if len(output_nodes) != 1: - raise AssertionError(f"Expected 1 output node, got {len(output_nodes)}") + raise PipeliningMetadataError( + f"Expected 1 output node, got {len(output_nodes)}" + ) output_node = output_nodes[0] return output_node - def _create_grad_recv_info( - self, - act_send_info: dict, - ) -> tuple[_RecvInfo, ...]: + def _create_grad_recv_info(self, act_send_info: dict) -> tuple[_RecvInfo, ...]: """ Create a tuple of `_RecvInfo` for gradients. + Reuses output metadata from _stage_meta.outputs (populated by _create_act_send_info). """ - # Dict[output_index, _RecvInfo] - grad_recv_info: dict[int, _RecvInfo] = {} - output_node = self._get_output_node() + if self._stage_meta.outputs is None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: outputs metadata required for grad recv info. " + f"Ensure _create_act_send_info is called first." + ) - # The output node may take multiple args, meaning the submod having multiple output values. - output_vals = flatten_args(output_node.args) + outputs_meta = self._stage_meta.outputs + output_grads_metas: list[TensorMeta | None] = [] + grad_recv_infos: list[_RecvInfo] = [] - for out_idx, dst_list in act_send_info.items(): - if not dst_list: - # No actual receiver for activation so no grad coming back - continue + for out_idx, out_meta in enumerate(outputs_meta): + dst_list = act_send_info.get(out_idx, []) - output = output_vals[out_idx] - example_value = output.meta["val"] - logger.debug( - f"{self.log_prefix} Creating grad recv buffer for output {output.name} " # noqa: G004 - f": {example_value.shape}, {example_value.dtype}" - ) + # Determine the source stage for gradients + grad_src = dst_list[0] if dst_list else self.stage_index + 1 - # TODO: otherwise needs grad accumulation - if len(dst_list) != 1: - raise AssertionError("Backward of skip connections not supported yet") - grad_src = dst_list[0] - grad_recv_info[out_idx] = _RecvInfo( - f"{grad_src}", # noqa: G004 - grad_src, - _make_tensor_from_meta(example_value, self.device), - ) + # Check if this output needs gradients + if not dst_list or not out_meta.requires_grad: + output_grads_metas.append(None) + grad_recv_infos.append( + _RecvInfo( + input_name=f"recv_grad_for_{self.stage_index}_none_{out_idx}", + source=grad_src, + buffer=None, + tensor_meta=None, + ) + ) + else: + # Derive grad metadata from output metadata (same shape, requires_grad=False) + grad_meta = _TensorMeta( + shape=out_meta.shape, + stride=out_meta.stride, + dtype=out_meta.dtype, + requires_grad=False, + ) + output_grads_metas.append(grad_meta) - # Convert to tuple for convenience in get_ops and retrieve tensor - grad_recv_info_tuple = tuple(grad_recv_info.values()) - logger.debug("%s Grad recv info: %s", self.log_prefix, grad_recv_info_tuple) - return grad_recv_info_tuple + if len(dst_list) != 1: + raise PipeliningMetadataError( + "Backward of skip connections not supported yet" + ) + + logger.debug( + "%s Creating grad recv buffer for output %s : %s, %s", + self.log_prefix, + out_idx, + grad_meta.shape, + grad_meta.dtype, + ) + + grad_recv_infos.append( + _RecvInfo( + input_name=f"recv_grad_for_{self.stage_index}_from_{grad_src}", + source=grad_src, + buffer=_make_tensor_from_meta(grad_meta, self.device), + tensor_meta=grad_meta, + ) + ) + + self._stage_meta.output_grads = tuple(output_grads_metas) + logger.debug("%s Grad recv info: %s", self.log_prefix, grad_recv_infos) + return tuple(grad_recv_infos) # A helper function to create a pipeline stage based on traced pipeline information @@ -1327,26 +1580,33 @@ def build_stage( class PipelineStage(_PipelineStageBase): - """ - A class representing a pipeline stage in a pipeline parallelism setup. + """A pipeline stage for pipeline parallelism with sequential model partitioning. + + Supports both **static** and **dynamic** metadata inference: - PipelineStage assumes sequential partitioning of the model, i.e. the model is split into chunks where outputs from - one chunk feed into inputs of the next chunk, with no skip connections. + Static mode: + All of ``input_args``, ``output_args`` (and ``input_grads``/``output_grads`` + when DTensors are present) are provided at construction time. - PipelineStage performs runtime shape/dtype inference automatically by propagating the outputs from stage0 to - stage1 and so forth, in linear order. To bypass shape inference, pass the `input_args` and `output_args` to each - PipelineStage instance. + Dynamic mode: + Metadata is inferred from the first microbatch at runtime; any + statically provided args are used for validation only. Args: - submodule (nn.Module): The PyTorch module wrapped by this stage. - stage_index (int): The ID of this stage. - num_stages (int): The total number of stages. - device (torch.device): The device where this stage is located. - input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The input arguments for the submodule. - output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The output arguments for the submodule. - group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group. - dw_builder (Optional[Callable[[], Callable[..., None]]): If provided, dw_builder will build a new dw_runner function - that will the W action (input weights) for F, I, W (Fwd, Input, Weight) zero bubble schedules. + submodule: The ``nn.Module`` wrapped by this stage. + stage_index: Zero-based stage ID. + num_stages: Total number of stages in the pipeline. + device: Device this stage runs on. + input_args: Example input tensors (single tensor or tuple). Optional. + output_args: Example output tensors. Optional. + output_grads: Example output gradients (received from next stage). Optional. + input_grads: Example input gradients (sent to previous stage). Optional. + group: Process group for P2P communication. Defaults to the + world process group. + dw_builder: Builder for deferred weight-update runners used by + zero-bubble (F/I/W) schedules. + get_mesh: `GetMeshCallback` used during + dynamic DTensor inference. Ignored in fully static DTensor mode. """ def __init__( @@ -1357,247 +1617,633 @@ def __init__( device: torch.device, input_args: torch.Tensor | tuple[torch.Tensor, ...] | None = None, output_args: torch.Tensor | tuple[torch.Tensor, ...] | None = None, + output_grads: torch.Tensor | tuple[torch.Tensor | None, ...] | None = None, + input_grads: torch.Tensor | tuple[torch.Tensor | None, ...] | None = None, group: dist.ProcessGroup | None = None, dw_builder: Callable[[], Callable[..., None]] | None = None, + get_mesh: GetMeshCallback | None = None, ): super().__init__(submodule, stage_index, num_stages, device, group, dw_builder) - self.inputs: list[torch.Tensor] | None = None - self.inputs_meta: tuple[torch.Tensor, ...] | None = None - # Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) because it - # might be breaking for existing users. - if input_args is None: - if output_args is not None: - raise AssertionError( - "If specifying output_args, input_args must also be specified. " - "Otherwise, shape inference will be performed at runtime" - ) - else: - self.inputs_meta = ( - (input_args,) if isinstance(input_args, torch.Tensor) else input_args - ) - if output_args is None: - logger.warning( - "Deprecation warning: passing input_args and performing init-time shape inference is deprecated. " - "PipelineStage now supports runtime shape inference using the real inputs provided to schedule step(). " - "Either delete `input_args` arg to `PipelineStage` to opt-into runtime shape inference, " - "or additionally pass `output_args` to `PipelineStage` to fully override shape inference. " + + self._mesh_cache = _MeshCache(get_mesh_cb=get_mesh) + self._inference_mode: InferenceMode | None = None + self._fwd_outputs_for_bwd_meta: tuple[torch.Tensor, ...] | None = None + self._fwd_inputs_for_bwd_meta: tuple[torch.Tensor, ...] | None = None + self._fwd_kwargs_tensors_for_bwd_meta: tuple[torch.Tensor, ...] | None = None + + # Validate and normalize args to tuples + inputs = validate_and_normalize_to_tuple(input_args) + outputs = validate_and_normalize_to_tuple(output_args) + in_grads = validate_and_normalize_to_tuple(input_grads, allow_none=True) + out_grads = validate_and_normalize_to_tuple(output_grads, allow_none=True) + + self._user_meta = _StageMeta( + inputs=extract_tensor_metas(inputs), + outputs=extract_tensor_metas(outputs), + input_grads=extract_tensor_metas(in_grads, allow_none=True), + output_grads=extract_tensor_metas(out_grads, allow_none=True), + ) + + # Cache meshes from user-provided DTensors + for args in (inputs, outputs, in_grads, out_grads): + if args is not None: + self._mesh_cache.update_from_tensors(args) + + # Validate DTensor↔grad correspondence independently for inputs and outputs + if self._user_meta.has_dtensors(): + if inputs and in_grads: + validate_static_arg_grad_correspondence( + self.stage_index, inputs, in_grads, is_input=True ) - try: - with torch.no_grad(): - output_args = submodule(*self.inputs_meta) - output_args = tree_map_only( - torch.Tensor, lambda x: x.to("meta"), output_args - ) - except Exception as e: - raise RuntimeError( - "Failed to perform pipeline shape inference- are your inputs on the same device as your module?" - ) from e - if output_args is None: - raise AssertionError( - "If passing input_args, also pass output_args to override shape inference" + if outputs and out_grads: + validate_static_arg_grad_correspondence( + self.stage_index, outputs, out_grads, is_input=False ) - self._configure_outputs_meta( - (output_args,) if isinstance(output_args, torch.Tensor) else output_args + + def _recv_meta(self, src_stage: int) -> Any: + """Receive metadata object from a stage on a different rank via P2P.""" + objects: list[Any] = [None] + dist.recv_object_list( + objects, + src=self._resolve_peer_global_rank(src_stage), + group=self.group, + device=self.device, + use_batch=True, + ) + if len(objects) != 1: + raise PipeliningMetadataError( + f"Expected exactly one object to be received but got: {len(objects)}" ) + return objects[0] + + def _send_meta(self, meta: Any, dst_stage: int) -> None: + """Send metadata object to a stage on a different rank via P2P.""" + dist.send_object_list( + [meta], + dst=self._resolve_peer_global_rank(dst_stage), + group=self.group, + device=self.device, + use_batch=True, + ) + + def _is_same_rank(self, other_stage: int) -> bool: + """Check if another stage is on the same rank as this stage.""" + return self.stage_index_to_group_rank[other_stage] == self.group_rank - # these are the buffers used in backwards send/recv, they are allocated later - self.outputs_grad: list[torch.Tensor] = [] + def _warmup_forward_vote( + self, has_backward: bool, received_acc: torch.Tensor | None = None + ) -> torch.Tensor: + """Forward phase of the warm-up vote protocol (stage 0 → N−1). + + Each stage computes a vote (1 = STATIC, 0 = DYNAMIC) based on + ``InferenceMode.needs_dynamic``, multiplies it with the accumulated + product from the previous stage, and forwards the result to the next + stage. The final product at stage N−1 is 1 iff *every* stage voted + STATIC. + + Args: + has_backward: Whether the schedule includes a backward pass. + received_acc: Accumulated product tensor from the previous + same-rank stage (V-schedule), or ``None`` for the first + stage / cross-rank. + + Returns: + The accumulated product tensor after this stage's vote. + """ + my_vote = 0 if InferenceMode.needs_dynamic(self._user_meta, has_backward) else 1 + + my_vote_t = torch.tensor([my_vote], dtype=torch.int32, device=self.device) + + if self.is_first: + acc = my_vote_t + elif self._is_same_rank(self.stage_index - 1): + assert received_acc is not None # noqa: S101 + acc = received_acc * my_vote_t + else: + peer_global = self._resolve_peer_global_rank(self.stage_index - 1) + acc = torch.zeros(1, dtype=torch.int32, device=self.device) + dist.recv(acc, src=peer_global, group=self.group) + acc = acc * my_vote_t - dbg_str = ( - f"Finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004 - f"{self.is_last=}, {self.num_stages=}, " + if not self.is_last and not self._is_same_rank(self.stage_index + 1): + peer_global = self._resolve_peer_global_rank(self.stage_index + 1) + dist.send(acc, dst=peer_global, group=self.group) + + return acc + + def _warmup_backward_result( + self, received_result: torch.Tensor | None = None + ) -> torch.Tensor: + """Backward phase of the warm-up vote protocol (stage N−1 → 0). + + Propagates the final accumulated product (computed in the forward + phase) back through the pipeline so every stage learns the global + inference mode. + + Args: + received_result: Result tensor from the next same-rank stage + (V-schedule), or ``None`` for the last stage / cross-rank. + + Returns: + The global vote result tensor for this stage. + """ + if self.is_last or self._is_same_rank(self.stage_index + 1): + assert received_result is not None # noqa: S101 + result = received_result + else: + peer_global = self._resolve_peer_global_rank(self.stage_index + 1) + result = torch.zeros(1, dtype=torch.int32, device=self.device) + dist.recv(result, src=peer_global, group=self.group) + + if not self.is_first and not self._is_same_rank(self.stage_index - 1): + peer_global = self._resolve_peer_global_rank(self.stage_index - 1) + dist.send(result, dst=peer_global, group=self.group) + + return result + + def _compute_outputs( + self, + *args: torch.Tensor, + module: torch.nn.Module, + **kwargs: Any, + ) -> torch.Tensor | tuple[torch.Tensor, ...] | list[torch.Tensor] | None: + """Compute outputs of the submodule.""" + return module(*args, **kwargs) + + def _compute_input_grads( + self, + outputs: list[torch.Tensor], + all_fwd_inputs: list[torch.Tensor], + grad_outputs: list[torch.Tensor | None] | None = None, + ) -> tuple[torch.Tensor | None, ...]: + """Compute input gradients via :func:`_autograd_grad_for_inputs`.""" + return _autograd_grad_for_inputs( + outputs, + all_fwd_inputs, + grad_outputs, ) - if self.inputs_meta is not None: - dbg_str += ( - f"inputs: {[inp.shape for inp in self.inputs_meta]}, " - f"output: {[output.shape for output in self.get_outputs_meta()]}" - ) + + def _to_tensor(self, arg: torch.Tensor | TensorMeta) -> torch.Tensor: + """Convert a tensor or metadata to a real tensor on ``self.device``. + + Real tensors are detached and re-set requires_grad to create a fresh + autograd leaf, isolating metadata inference from the user's graph. + TensorMeta is materialized as an empty tensor (or DTensor via mesh cache). + """ + if isinstance(arg, torch.Tensor): + return arg.detach().requires_grad_(arg.requires_grad) + elif isinstance(arg, TensorMeta): + if isinstance(arg, _DTensorMeta): + mesh = self._mesh_cache.get_mesh(arg.mesh_cache_key) + return arg.to_dtensor(self.device, mesh) + else: + return arg.to_tensor(self.device) else: - dbg_str += " running shape-inference at runtime" + raise PipeliningMetadataError( + f"Unsupported type {type(arg)} for _to_tensor: {arg}" + ) - logger.debug(dbg_str) + def _ones_from_metadata(self, meta: TensorMeta) -> torch.Tensor: + """Create a ones tensor from metadata for backward inference grad_outputs.""" + local_ones = torch.ones( + meta.shape, + dtype=meta.dtype, + device=self.device, + ) + if isinstance(meta, _DTensorMeta): + mesh = self._mesh_cache.get_mesh(meta.mesh_cache_key) + return DTensor.from_local( + local_ones, + device_mesh=mesh, + placements=meta.placements, + shape=meta.global_shape, + stride=meta.global_stride, + run_check=False, + ) + return local_ones - def _shape_inference( + def _forward_metadata_inference( self, - args: tuple[Any, ...], + args: tuple[torch.Tensor, ...] | _StageForwardMeta | None, kwargs: dict[str, Any] | None = None, - ): - if kwargs is None: - kwargs = {} - if args is None: - raise AssertionError("Args may be an empty tuple but not None") - - # We skip recv communication if we're the first stage, but also if the previous stage is on the same rank - # and can pass its output shapes in as args instead of using send/recv. - if ( - self.is_first - # if not first stage, then check if prev stage is on the same rank - or self.stage_index_to_group_rank[self.stage_index - 1] == self.group_rank - ): - logger.debug( - "Shape inference: stage %s skipping recv, because shape info passed in via `args`", - self.stage_index, - ) - args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), args) + has_backward: bool = False, + ) -> _StageForwardMeta | None: + """Run forward metadata inference (Stage 0 → N). + + Args: + args: Real tensors (first stage), ``_StageForwardMeta`` + (same-rank), or ``None`` (cross-rank P2P). + kwargs: Keyword arguments forwarded to the submodule. + has_backward: Whether backward inference follows. + + Returns: + ``_StageForwardMeta`` for the next stage, or ``None`` if sent via P2P. + """ + kwargs = kwargs or {} + + # === RECEIVE: Get input metadata and create meta tensors === + if self.is_first: + # First stage: extract metadata from real tensors + if args is None or isinstance(args, _StageForwardMeta): + raise PipeliningMetadataError( + f"Stage {self.stage_index}: First stage requires real tensors, " + f"got {type(args).__name__}." + ) + tensor_args = validate_and_normalize_to_tuple(args) + assert tensor_args is not None # noqa: S101 + self._stage_meta.inputs = extract_tensor_metas(tensor_args) + inference_args = tuple(self._to_tensor(a) for a in tensor_args) + elif self._is_same_rank(self.stage_index - 1): + # Same-rank: _StageForwardMeta passed via argument + if not isinstance(args, _StageForwardMeta): + raise PipeliningMetadataError( + f"Stage {self.stage_index}: Expected _StageForwardMeta from same-rank " + f"previous stage, got {type(args).__name__}." + ) + self._stage_meta.inputs = args.forward_metas + inference_args = tuple(self._to_tensor(m) for m in args.forward_metas) else: - if len(args) != 0: - raise AssertionError( - "Can't supply input args for shape inference on non-first stage" + # Cross-rank: receive _StageForwardMeta via P2P + recv_meta = self._recv_meta(self.stage_index - 1) + if not isinstance(recv_meta, _StageForwardMeta): + raise PipeliningMetadataError( + f"Stage {self.stage_index}: Expected _StageForwardMeta from P2P, " + f"got {type(recv_meta).__name__}." ) - objects = [None] - logger.debug( - "Shape inference: stage %s receiving from stage %s", - self.stage_index, - self.stage_index - 1, - ) - dist.recv_object_list( - objects, - src=dist.get_global_rank( - self.group or dist.distributed_c10d._get_default_group(), - self.stage_index_to_group_rank[self.stage_index - 1], - ), - group=self.group, - device=self.device, - use_batch=True, + self._stage_meta.inputs = recv_meta.forward_metas + inference_args = tuple(self._to_tensor(m) for m in recv_meta.forward_metas) + + inference_kwargs = { + k: self._to_tensor(v) if isinstance(v, torch.Tensor) else v + for k, v in kwargs.items() + } + + # Isolate metadata inference from user's grad context. + # has_backward → enable_grad() so backward tracing sees grad_fn; + # no backward → no_grad() for cross-rank consistency. + ctx = torch.enable_grad() if has_backward else torch.no_grad() + with ctx: + outputs = self._compute_outputs( + *inference_args, module=self.submod, **inference_kwargs ) - recv_args = objects[0] - if not isinstance(recv_args, tuple): - raise AssertionError(f"Expected tuple, got {type(recv_args)}") - args = recv_args - - # cache input shapes for use during recv buffer allocation - self.inputs_meta = args - args = tree_map_only( - torch.Tensor, lambda x: torch.zeros_like(x, device=self.device), args - ) - # set attributes needed for forward - with torch.no_grad(): - outputs = self.submod(*args, **kwargs) + # Normalize outputs to tuple + outputs = validate_and_normalize_to_tuple(outputs) - # if single tensor, convert so it is always a list - if isinstance(outputs, torch.Tensor): - outputs = [outputs] + self._stage_meta.outputs = extract_tensor_metas(outputs) - # communicate meta outputs not real outputs for two reasons - # 1 - its faster (esp. since obj coll pickles tensor data!) - # 2 - avoid activating a cuda context for the src rank when unpickling on the recv end! - outputs_meta = tuple( - tree_map_only(torch.Tensor, lambda x: x.to("meta"), outputs) - ) - logger.debug( - "Shape inference: stage %s inputs %s, outputs %s", - self.stage_index, - self.inputs_meta, - outputs_meta, + # Store for backward metadata inference (always, even during eval) + fwd_kwargs_tensors = tuple( + v for v in flatten_args(inference_kwargs) if isinstance(v, torch.Tensor) ) - self._configure_outputs_meta(outputs_meta) - - # Passing outputs to the next stage: - # two cases- - # 1. Usually: use send/recv communication to pass the output - # 2. Special case: for V-schedules, 2 'adjacent' stages (e.g. stage 3, 4 in an 8-stage 4-rank V) - # pass their shape info via return value and function args rather than send/recv. - if ( - self.is_last - # if not last stage, then check if next stage is on the same rank - or self.stage_index_to_group_rank[self.stage_index + 1] == self.group_rank - ): - # Case (2) above: pass shape info via return value and caller passes it as args to next stage's - # _shape_inference call - logger.debug( - "Shape inference: stage %s skipping send to next stage", - self.stage_index, + self._fwd_outputs_for_bwd_meta = outputs + self._fwd_inputs_for_bwd_meta = inference_args + self._fwd_kwargs_tensors_for_bwd_meta = fwd_kwargs_tensors + + # === SEND: Pass output metadata to next stage === + if self._stage_meta.outputs is None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: output metadata is required for forward inference." ) + fwd_meta = _StageForwardMeta(forward_metas=self._stage_meta.outputs) + if self.is_last or self._is_same_rank(self.stage_index + 1): + # Same-rank or last: return for caller to pass + return fwd_meta else: - # Case (1): send shapes via send operation, and ensure not to return it to the caller - logger.debug( - "Shape inference: stage %s sending to stage %s", - self.stage_index, - self.stage_index + 1, + # Cross-rank: send via P2P + self._send_meta(fwd_meta, self.stage_index + 1) + return None + + def _backward_metadata_inference( + self, + loss_fn: Callable[..., torch.Tensor] | None = None, + target: torch.Tensor | None = None, + received_grad_meta: _StageBackwardMeta | None = None, + ) -> _StageBackwardMeta | None: + """Run backward metadata inference (Stage N → 0). + + Args: + loss_fn: Loss function (required for the last stage). + target: Target tensor (required for the last stage). + received_grad_meta: Grad metadata from next same-rank stage + (V-schedule only). + + Returns: + ``_StageBackwardMeta`` for the previous stage, or ``None`` if sent via P2P. + """ + fwd_outputs = self._fwd_outputs_for_bwd_meta + fwd_inputs = self._fwd_inputs_for_bwd_meta + if fwd_outputs is None or fwd_inputs is None: + raise PipeliningMetadataError( + "Backward metadata inference requires forward metadata inference to run first" ) - dist.send_object_list( - [outputs_meta], - dst=dist.get_global_rank( - self.group or dist.distributed_c10d._get_default_group(), - self.stage_index_to_group_rank[self.stage_index + 1], - ), - group=self.group, - device=self.device, - use_batch=True, + kwargs_tensors = self._fwd_kwargs_tensors_for_bwd_meta or () + all_fwd_inputs = list(fwd_inputs) + list(kwargs_tensors) + # Clear temporary storage early — local refs are sufficient from here + self._fwd_outputs_for_bwd_meta = None + self._fwd_inputs_for_bwd_meta = None + self._fwd_kwargs_tensors_for_bwd_meta = None + # === RECEIVE: Get output grad metadata (except last stage) === + if self.is_last: + if loss_fn is None or target is None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: loss_fn and target required for last stage" + ) + inference_target = self._to_tensor(target) + loss = loss_fn( + fwd_outputs[0] if len(fwd_outputs) == 1 else fwd_outputs, + inference_target, + ) + self._stage_meta.output_grads = None + all_input_grads = self._compute_input_grads( + [loss], + all_fwd_inputs, ) - outputs_meta = tuple() + else: + # Non-last stage: receive grad metadata from next stage + if self._is_same_rank(self.stage_index + 1): + # Same-rank: _StageBackwardMeta passed via argument + if not isinstance(received_grad_meta, _StageBackwardMeta): + raise PipeliningMetadataError( + f"Stage {self.stage_index}: Expected _StageBackwardMeta from same-rank " + f"next stage, got {type(received_grad_meta).__name__}." + ) + self._stage_meta.output_grads = received_grad_meta.backward_metas + else: + # Cross-rank: receive _StageBackwardMeta via P2P + recv_meta = self._recv_meta(self.stage_index + 1) + if not isinstance(recv_meta, _StageBackwardMeta): + raise PipeliningMetadataError( + f"Stage {self.stage_index}: Expected _StageBackwardMeta from P2P, " + f"got {type(recv_meta).__name__}." + ) + self._stage_meta.output_grads = recv_meta.backward_metas + + # === COMPUTE: Build grad_outputs and compute input grads === + # Extract output tensors and corresponding grad_outputs from metadata + # Must iterate together to maintain alignment + if self._stage_meta.output_grads is None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: output_grads metadata is required for backward inference." + ) + stage_output_grad_metas = self._stage_meta.output_grads + + filtered_fwd_outputs: list[torch.Tensor] = [] + filtered_output_grads: list[torch.Tensor | None] = [] + + for idx, (fwd_out, grad_meta) in enumerate( + zip(fwd_outputs, stage_output_grad_metas, strict=True) + ): + # Match _backward.py behavior: skip if output doesn't require grad AND has no grad_fn + if not fwd_out.requires_grad: + if grad_meta is not None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: output {idx} requires_grad=False, " + f"but output_grads metadata is provided: {grad_meta}." + ) + continue + filtered_fwd_outputs.append(fwd_out) + # For outputs that require grad, include them even if grad_meta is None + # (runtime passes None grad_outputs to autograd.backward in this case) + filtered_output_grads.append( + self._ones_from_metadata(grad_meta) if grad_meta else None + ) - return outputs_meta + if filtered_fwd_outputs: + all_input_grads = self._compute_input_grads( + filtered_fwd_outputs, all_fwd_inputs, filtered_output_grads + ) + # Free intermediate references early + filtered_fwd_outputs.clear() + filtered_output_grads.clear() + all_fwd_inputs.clear() + # Only positional input grads flow to previous stage + else: + all_input_grads = tuple(None for _ in range(len(all_fwd_inputs))) + + input_grads = all_input_grads[: len(fwd_inputs)] + self._stage_meta.input_grads = tuple( + extract_tensor_meta(g) if isinstance(g, torch.Tensor) else None + for g in input_grads + ) + + # === SEND: Pass input grad metadata to previous stage === + bwd_meta = _StageBackwardMeta(backward_metas=self._stage_meta.input_grads) + + if self.is_first or self._is_same_rank(self.stage_index - 1): + # First rank or Same-rank: return for caller to pass + return bwd_meta + else: + # Cross-rank: send via P2P + self._send_meta(bwd_meta, self.stage_index - 1) + return None + + def _post_metadata_inference_cleanup(self) -> None: + """Clean up FSDP side effects (unsharded params, stale grads, stored + tensors) after metadata inference with real tensors. + """ + # Clear stored inference tensors (frees autograd graph + activations) + self._fwd_outputs_for_bwd_meta = None + self._fwd_inputs_for_bwd_meta = None + self._fwd_kwargs_tensors_for_bwd_meta = None + + # Metadata inference runs real fwd/bwd, which unshards FSDP params and + # accumulates grads. Reshard to free memory and clear stale grads. + for module in self.submod.modules(): + if isinstance(module, FSDPModule): + module.reshard() + for param in module.parameters(): + param.grad = None + + def _prepare_backward_infra( + self, + num_microbatches: int, + loss_fn: Callable[..., torch.Tensor] | None = None, + target: torch.Tensor | None = None, + received_grad_meta: "_StageBackwardMeta | None" = None, + ) -> "_StageBackwardMeta | None": + """Run backward metadata inference and prepare backward infrastructure. + + Returns: + ``_StageBackwardMeta`` for the previous same-rank stage, or ``None``. + """ + grad_meta_result: _StageBackwardMeta | None = None + if self._inference_mode == InferenceMode.DYNAMIC: + # DYNAMIC mode: run backward metadata inference + # received_grad_meta is used for same-rank V-schedule stages + grad_meta_result = self._backward_metadata_inference( + loss_fn=loss_fn, + target=target, + received_grad_meta=received_grad_meta, + ) + # Validate dynamically inferred metadata against user-provided metadata + self._validate_inferred_metadata() + else: + # STATIC mode: metadata comes from user inputs, no validation needed + self._stage_meta.input_grads = self._user_meta.input_grads + self._stage_meta.output_grads = self._user_meta.output_grads + # For STATIC mode with plain tensors, if output_grads is not set but + # we have outputs, derive output_grads from outputs. + # (gradient shape == output shape, but requires_grad=False for gradients) + if self._stage_meta.output_grads is None: + if self._stage_meta.outputs is None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: output metadata is required for backward inference." + ) + self._stage_meta.output_grads = _derive_grad_metas( + self._stage_meta.outputs + ) + # Similarly, derive input_grads from inputs if not provided + if self._stage_meta.input_grads is None: + if self._stage_meta.inputs is None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: input metadata is required for backward inference." + ) + self._stage_meta.input_grads = _derive_grad_metas( + self._stage_meta.inputs + ) + + # Note: grad_send_info is created lazily in get_bwd_send_ops() since + # it mirrors args_recv_info (already populated during forward). + self._setup_backward_recv_info(num_microbatches) + return grad_meta_result + + def _validate_inferred_metadata(self) -> None: + """Validate dynamically inferred metadata against user-provided metadata.""" + pairs = [ + (self._user_meta.inputs, self._stage_meta.inputs, "input"), + (self._user_meta.outputs, self._stage_meta.outputs, "output"), + (self._user_meta.input_grads, self._stage_meta.input_grads, "input_grad"), + ( + self._user_meta.output_grads, + self._stage_meta.output_grads, + "output_grad", + ), + ] + for user_val, stage_val, label in pairs: + if user_val and stage_val: + validate_tensors_metadata( + f"Stage {self.stage_index} {label}", + user_val, + stage_val, + warn_on_mismatch=True, + ) def _prepare_forward_infra( self, num_microbatches: int, - args: tuple[Any, ...], + args: tuple[Any, ...] | _StageForwardMeta | None, kwargs: dict[str, Any] | None = None, - ) -> tuple[Any, ...]: - # TODO move self.device to an argument from step API (from its input tensors)? - if num_microbatches is None: - raise AssertionError("TODO fix num_microbatches") - - outputs: tuple[Any, ...] = tuple() - if self.inputs_meta is None: - outputs = self._shape_inference(args, kwargs) - - if self.inputs_meta is None: - raise AssertionError("Expected inputs_meta to be set after shape inference") - # Receive info during forward - # TODO: create args_recv_info lazily? (same needed for PipelineStage) + has_backward: bool = False, + ) -> _StageForwardMeta | None: + """Prepare the stage infrastructure for forward pass. + + Returns: + ``_StageForwardMeta`` for next stage (same-rank), or ``None`` if sent via P2P. + """ + if self._inference_mode is None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: inference mode not set. " + f"Run warmup vote protocol first." + ) + + fwd_meta_output: _StageForwardMeta | None = None + + if self._inference_mode == InferenceMode.DYNAMIC: + # DYNAMIC mode: run forward metadata inference + # args may be _StageForwardMeta for same-rank V-schedule stages + fwd_meta_output = self._forward_metadata_inference( + args, kwargs, has_backward + ) + # Validate dynamically inferred metadata against user-provided metadata + self._validate_inferred_metadata() + # STATIC mode: metadata comes from user inputs, no validation needed + else: + self._stage_meta.inputs = self._user_meta.inputs + self._stage_meta.outputs = self._user_meta.outputs + + # Setup recv and send info + self._setup_forward_recv_info(num_microbatches, has_backward) + self._setup_forward_send_info() + + return fwd_meta_output + + def _setup_forward_recv_info( + self, num_microbatches: int, has_backward: bool + ) -> None: + """Setup receive info for forward pass.""" + if self._stage_meta.inputs is None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: inputs metadata required for recv info." + ) for chunk_id in range(num_microbatches): - if not self.is_first: - # We assume that we always receive from stage - 1 - recv_infos = tuple( + if self.is_first: + # First stage: all inputs are root arguments (no recv needed) + self.args_recv_info[chunk_id] = tuple( _RecvInfo( - f"recv_for_{self.stage_index}_from_{self.stage_index - 1}", - self.stage_index - 1, - _make_tensor_from_meta(inp, self.device), + input_name=f"root_input_{idx}", + source=None, + buffer=None, + tensor_meta=meta, + is_root_arg=True, ) - for inp in self.inputs_meta + for idx, meta in enumerate(self._stage_meta.inputs) ) - # In case there is backward pass, set requires_grad for receive buffers - if self.has_backward: - for r in recv_infos: - r.buffer.requires_grad_(True) - - self.args_recv_info[chunk_id] = recv_infos else: + # Non-first stages: receive from previous stage self.args_recv_info[chunk_id] = tuple( - _RootArgPlaceholder(i) for i in self.inputs_meta + _RecvInfo( + input_name=f"recv_for_{self.stage_index}_from_{self.stage_index - 1}", + source=self.stage_index - 1, + buffer=_make_tensor_from_meta(meta, self.device), + tensor_meta=meta, + ) + for meta in self._stage_meta.inputs ) - # Send info during forward for each activation - # only need the rank that is being sent to + def _setup_forward_send_info(self) -> None: + """Setup send info for forward pass.""" self.act_send_info: dict[int, list] = {} - - for idx in range(len(self.get_outputs_meta())): - # We assume we always send to stage + 1 - if not self.is_last: - self.act_send_info[idx] = [self.stage_index + 1] - else: - self.act_send_info[idx] = [] - - return outputs + if self._stage_meta.outputs is None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: outputs metadata required for recv info." + ) + for idx in range(len(self._stage_meta.outputs)): + self.act_send_info[idx] = [self.stage_index + 1] if not self.is_last else [] def _create_grad_recv_info( self, act_send_info: dict, ) -> tuple[_RecvInfo, ...]: - grad_recv_info: tuple[_RecvInfo, ...] = () + grad_recv_infos: list[_RecvInfo] = [] if not self.is_last: + # Ensure output_grads metadata is available + if self._stage_meta.output_grads is None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: output_grads metadata is required for " + f"creating grad recv info. Ensure backward metadata is populated." + ) + # Receiving gradients from multiple sources is not supported # hence we only take the first destination - grad_recv_info = tuple( - _RecvInfo( - f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}", - dst_list[0], - _make_tensor_from_meta(self.get_outputs_meta()[idx], self.device), + # Use a helper function to safely extract the metadata + output_grads = self._stage_meta.output_grads + for idx, dst_list in act_send_info.items(): + if dst_list is None: + raise PipeliningMetadataError( + f"Stage {self.stage_index}: output {idx} is not sent to any stage." + ) + src = dst_list[0] + grad_meta = output_grads[idx] + grad_recv_infos.append( + _RecvInfo( + input_name=f"recv_grad_for_{self.stage_index}_from_{src}", + source=src, + buffer=_make_tensor_from_meta(grad_meta, self.device) + if grad_meta + else None, + tensor_meta=grad_meta, + ) ) - for idx, dst_list in act_send_info.items() - ) - return grad_recv_info + return tuple(grad_recv_infos) diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py index adf901d6b6e3e..2b075f81b7d50 100644 --- a/torch/distributed/rpc/__init__.py +++ b/torch/distributed/rpc/__init__.py @@ -38,7 +38,7 @@ def is_available() -> bool: import torch.distributed.autograd as dist_autograd from torch._C._distributed_c10d import Store - from torch._C._distributed_rpc import ( # noqa: F401 + from torch._C._distributed_rpc import ( _cleanup_python_rpc_handler, _DEFAULT_INIT_METHOD, _DEFAULT_RPC_TIMEOUT_SEC, @@ -72,22 +72,23 @@ def is_available() -> bool: ) if _is_tensorpipe_available: - from torch._C._distributed_rpc import ( # noqa: F401 + from torch._C._distributed_rpc import ( _DEFAULT_NUM_WORKER_THREADS, _TensorPipeRpcBackendOptionsBase, TensorPipeAgent, ) from . import api, backend_registry, functions - from .api import * # noqa: F401,F403 + from .api import * # noqa: F403 from .backend_registry import BackendType - from .options import TensorPipeRpcBackendOptions # noqa: F401 + from .options import TensorPipeRpcBackendOptions from .server_process_global_profiler import _server_process_global_profile rendezvous_iterator: Generator[tuple[Store, int, int], None, None] __all__ += ["init_rpc", "BackendType", "TensorPipeRpcBackendOptions"] - __all__ = __all__ + api.__all__ + backend_registry.__all__ # noqa: PLE0605 + # pyrefly: ignore [unresolvable-dunder-all] + __all__ = __all__ + api.__all__ + backend_registry.__all__ def init_rpc( name, diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py index 0cde066b2466b..1b5666ec6b005 100644 --- a/torch/distributed/rpc/internal.py +++ b/torch/distributed/rpc/internal.py @@ -226,8 +226,8 @@ def _handle_exception(result): exc = None try: exc = result.exception_type(exception_msg) - except BaseException as e: # noqa: B036 - raise RuntimeError( # noqa: B904 + except BaseException as e: + raise RuntimeError( f"Failed to create original exception type. Error msg was {str(e)}" f" Original exception on remote side was {exception_msg}" ) from e diff --git a/torch/distributed/rpc/rref_proxy.py b/torch/distributed/rpc/rref_proxy.py index 46eecf19e22c9..81b1d43b07dcf 100644 --- a/torch/distributed/rpc/rref_proxy.py +++ b/torch/distributed/rpc/rref_proxy.py @@ -53,13 +53,13 @@ def _rref_type_cont(rref_fut): def _wrap_rref_type_cont(fut): try: _rref_type_cont(fut).then(_complete_op) - except BaseException as ex: # noqa: B036 + except BaseException as ex: result.set_exception(ex) def _complete_op(fut): try: result.set_result(fut.value()) - except BaseException as ex: # noqa: B036 + except BaseException as ex: result.set_exception(ex) rref_fut.then(_wrap_rref_type_cont) diff --git a/torch/distributed/run.py b/torch/distributed/run.py index 16ef02495e205..d8efbe30f8490 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -380,7 +380,7 @@ def main(): if __name__ == "__main__": main() -""" # noqa: E501 +""" import os import sys @@ -747,7 +747,7 @@ def determine_local_world_size(nproc_per_node: str): return int(nproc_per_node) except ValueError as e: if nproc_per_node == "cpu": - num_proc = os.cpu_count() + num_proc = torch._utils.cpu_count() device_type = "cpu" elif nproc_per_node == "gpu": if not torch.cuda.is_available(): @@ -769,7 +769,7 @@ def determine_local_world_size(nproc_per_node: str): num_proc = torch.accelerator.device_count() device_type = torch.accelerator.current_accelerator().type # type: ignore[union-attr] else: - num_proc = os.cpu_count() + num_proc = torch._utils.cpu_count() device_type = "cpu" else: raise ValueError( @@ -828,7 +828,7 @@ def _get_logs_specs_class(logs_specs_name: str | None) -> type[LogsSpecs]: ) logger.info( - "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls) + "Using logs_spec '%s' mapped to %s", logs_specs_name, logs_specs_cls ) else: logs_specs_cls = DefaultLogsSpecs diff --git a/torch/distributed/tensor/__init__.py b/torch/distributed/tensor/__init__.py index 8e1be826be58f..32ba1aa8810cd 100644 --- a/torch/distributed/tensor/__init__.py +++ b/torch/distributed/tensor/__init__.py @@ -2,7 +2,7 @@ import torch import torch.distributed.tensor._ops # force import all built-in dtensor ops -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401 +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from torch.distributed.tensor._api import ( distribute_module, distribute_tensor, @@ -15,6 +15,7 @@ zeros, ) from torch.distributed.tensor.placement_types import ( + _StridedShard, Partial, Placement, Replicate, @@ -63,6 +64,7 @@ Partial, Replicate, Shard, + _StridedShard, ] ) diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 370bccb9a0bcd..00b984422926d 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -2,6 +2,7 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates import copy +import hashlib import inspect import warnings from collections.abc import Callable, Sequence @@ -149,13 +150,12 @@ def backward(ctx, grad_output: torch.Tensor | None): # type: ignore[override] ), ) return ( - # pyrefly: ignore [bad-argument-type] - DTensor( - # pyrefly: ignore [bad-argument-count] + DTensor.from_local( grad_output, - grad_spec, - # pyrefly: ignore [unexpected-keyword] - requires_grad=grad_output.requires_grad, + grad_spec.device_mesh, + grad_spec.placements, + shape=grad_spec.shape, + stride=grad_spec.stride, ), None, ) @@ -360,7 +360,12 @@ def __tensor_flatten__(self): protocol to inform how to flatten a DTensor to local tensor for PT2 tracing """ - return ["_local_tensor"], (self._spec, self.requires_grad) + return ["_local_tensor", "device_mesh"], ( + self._spec.placements, + self._spec.tensor_meta, + self._spec.shard_order, + self.requires_grad, + ) @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): @@ -369,16 +374,18 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): "Expecting spec to be not None from `__tensor_flatten__` return value!" ) local_tensor = inner_tensors["_local_tensor"] - spec, requires_grad = flatten_spec + mesh = inner_tensors["device_mesh"] + placements, old_tensor_meta, shard_order, requires_grad = flatten_spec unflatten_tensor_meta = TensorMeta( shape=outer_size, stride=outer_stride, - dtype=spec.tensor_meta.dtype, + dtype=old_tensor_meta.dtype, ) unflatten_spec = DTensorSpec( - spec.mesh, - spec.placements, + mesh, + placements, tensor_meta=unflatten_tensor_meta, + shard_order=shard_order, ) # pyrefly: ignore [bad-argument-type] return DTensor( @@ -389,6 +396,15 @@ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): requires_grad=requires_grad, ) + def _stable_hash_for_caching(self) -> str: + """ + Return a stable hash for AOT autograd caching. + [See note: Tensor subclass stable hashing for AOT autograd cache] + """ + # Combine spec's stable hash with requires_grad + cache_data = self._spec._stable_hash() + str(self.requires_grad) + return hashlib.blake2b(cache_data.encode(), digest_size=16).hexdigest() + def __coerce_tangent_metadata__(self): if not any(isinstance(p, Partial) for p in self.placements): return self @@ -401,22 +417,34 @@ def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None): if expected_type is not None: return None - (spec, _) = flatten_spec # Result of tensor_flatten() + (placements, _, _, _) = flatten_spec return self.redistribute( device_mesh=self.device_mesh, - placements=spec.placements, + placements=placements, ) @classmethod def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] - # We just need to have an implementation here; the __torch_dispatch__ machinery - # calls into a specific C++ fast path that doesn't call here. - # See #167051 for details - # python_arg_parser.cpp: dispatch_on_subclass() - # -> python_variable.cpp: dispatchDTensorOp() - raise NotImplementedError( - "DTensor.__torch_dispatch__ should not actually get called" - ) + # Base DTensor is normally dispatched via a C++ fast path (see #167051) + # and never reaches here. This implementation exists so that DTensor + # subclasses can delegate back via super().__torch_dispatch__(). + # It unwraps subclass instances to base DTensor and re-calls the op, + # which re-enters dispatch and hits the C++ fast path. + def unwrap(t): + if isinstance(t, DTensor) and type(t) is not DTensor: + # pyrefly: ignore [bad-argument-type] + return DTensor( + # pyrefly: ignore [bad-argument-count] + t._local_tensor, + t._spec, + # pyrefly: ignore [unexpected-keyword] + requires_grad=t.requires_grad, + ) + return t + + args = torch.utils._pytree.tree_map(unwrap, args) + kwargs = torch.utils._pytree.tree_map(unwrap, kwargs or {}) + return func(*args, **kwargs) @staticmethod def from_local( @@ -632,8 +660,9 @@ def redistribute( Returns: A :class:`DTensor` object - .. note:: ``redistribute`` is differentiable, which means user do not need to worry about - the backward formula of the redistribute operation. + .. note:: ``redistribute`` is twice-differentiable, which means user do not need to worry about + the backward formula of the redistribute operation, or its compatibility with autograd for + second-order gradients. Higher-order differentiation has not been tested (but may work). .. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh, Please file an issue if you need to redistribute DTensor to different DeviceMesh. @@ -777,19 +806,6 @@ def __get_tensor_shard__(self, index): else: raise RuntimeError("Unsupported tensor type!") - @classmethod - def __metadata_guard__( - cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool] - ) -> bool: - # TODO - delete this - This is now unused after the PR - - # https://github.com/pytorch/pytorch/pull/165824 - orig_spec, orig_requires_grad = orig - other_spec, other_requires_grad = other - return ( - orig_spec._check_equals(other_spec, skip_shapes=True) - and orig_requires_grad == other_requires_grad - ) - def distribute_tensor( tensor: torch.Tensor, diff --git a/torch/distributed/tensor/_collective_utils.py b/torch/distributed/tensor/_collective_utils.py index a779190f25cfa..e26d7d7abc041 100644 --- a/torch/distributed/tensor/_collective_utils.py +++ b/torch/distributed/tensor/_collective_utils.py @@ -10,31 +10,67 @@ import torch.distributed.tensor._dtensor_spec as dtensor_spec from torch._C._distributed_c10d import _resolve_process_group from torch._logging import warning_once +from torch.distributed._functional_collectives import _are_we_tracing from torch.distributed._local_tensor import ( local_tensor_mode, maybe_run_for_local_tensor, ) from torch.distributed.device_mesh import _mesh_resources, DeviceMesh from torch.distributed.distributed_c10d import ( - _get_group_size_by_name, broadcast, get_group_rank, get_rank, + GroupName, ProcessGroup, scatter, Work, ) +from torch.fx.experimental.symbolic_shapes import guard_or_false +from torch.types import IntLikeType logger = logging.getLogger(__name__) +# Opaque types must be registered before defining schemas that reference them, +# so the schema parser recognizes the type names and uses PyObjectType (which +# wraps Python objects as ConcretePyObjectHolder) instead of AnyType (which +# calls toTypeInferredIValue and fails for Python-only opaque types). +from torch.distributed.device_mesh import _register_distributed_opaque_types + + +_register_distributed_opaque_types() + +_dtensor_lib = torch.library.Library("_dtensor", "FRAGMENT") +_dtensor_lib.define( + "mesh_get_process_group(" + "torch.distributed.device_mesh.DeviceMesh mesh, int dim" + ") -> torch.distributed.distributed_c10d.ProcessGroup" +) + + +@torch.library.impl("_dtensor::mesh_get_process_group", "CompositeExplicitAutograd") +def _mesh_get_process_group_impl(mesh, dim): + return mesh.get_group(dim) + + +@torch.library.register_fake("_dtensor::mesh_get_process_group") +def _mesh_get_process_group_fake(mesh, dim): + from torch._library.fake_class_registry import maybe_unwrap_fake_script_object + + real_mesh = maybe_unwrap_fake_script_object(mesh) + return real_mesh.get_group(dim) + @torch.library.register_fake("_dtensor::shard_dim_alltoall") -def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): - group_size = _get_group_size_by_name(group_name) +def _shard_dim_alltoall_meta( + input, gather_dim, shard_dim, group_name: GroupName | ProcessGroup +): + if isinstance(group_name, str): + # pyrefly: ignore[bad-argument-type] # pyrefly bug + group_name = _resolve_process_group(group_name) + group_size = group_name.size() stacked_list = [torch.empty_like(input) for _ in range(group_size)] - group = _resolve_process_group(group_name) - group_rank = get_group_rank(group, get_rank()) + group_rank = get_group_rank(group_name, get_rank()) cat_tensor = torch.cat(stacked_list, dim=gather_dim) # pyrefly: ignore [unsupported-operation] @@ -61,10 +97,10 @@ def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim): ] return out.contiguous() - group_name = funcol._resolve_group_name((mesh, mesh_dim)) + group = funcol._resolve_group((mesh, mesh_dim)) # TODO: enable async op for shard_dim_alltoall return torch.ops._dtensor.shard_dim_alltoall( - input, gather_dim, shard_dim, group_name + input, gather_dim, shard_dim, funcol._group_or_group_name(group) ) @@ -176,21 +212,33 @@ def mesh_broadcast( @maybe_run_for_local_tensor -def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: - from torch.fx.experimental.symbolic_shapes import guard_or_false - - if guard_or_false(pad_size == 0): +def pad_tensor( + tensor: torch.Tensor, pad_dim: int, pad_size: IntLikeType +) -> torch.Tensor: + # During tracing, always emit the pad op even when pad_size=0 so all + # ranks produce identical FX graph structure (SPMD). + # In eager with concrete pad_size=0, guard_or_false returns True and we + # skip the no-op pad. Check _are_we_tracing() first to avoid + # guard_or_false creating a guard that concretizes symbolic pad sizes + # during make_fx tracing. + if not _are_we_tracing() and guard_or_false(pad_size == 0): return tensor pad = [0, 0] * (tensor.ndim - pad_dim) - pad[-1] = pad_size + pad[-1] = pad_size # pyrefly: ignore[unsupported-operation] return torch.nn.functional.pad(tensor, pad) @maybe_run_for_local_tensor -def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: - from torch.fx.experimental.symbolic_shapes import guard_or_false - - if guard_or_false(pad_size == 0): +def unpad_tensor( + tensor: torch.Tensor, pad_dim: int, pad_size: IntLikeType +) -> torch.Tensor: + # During tracing, always emit the narrow op even when pad_size=0 so all + # ranks produce identical FX graph structure (SPMD). + # In eager with concrete pad_size=0, guard_or_false returns True and we + # skip the no-op narrow. Check _are_we_tracing() first to avoid + # guard_or_false creating a guard that concretizes symbolic pad sizes + # during make_fx tracing. + if not _are_we_tracing() and guard_or_false(pad_size == 0): return tensor return tensor.narrow( pad_dim, @@ -349,6 +397,8 @@ def _compute_placement_transition_cost( num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + # NOTE: is_shard() does not match _StridedShard; see _is_shard_like(). + # Safe today: redistribute_cost bails with inf when shard_order is None. if current_placement.is_shard() and target_placement.is_replicate(): # allgather gives larger comm bytes comm_bytes_gb *= num_devices_on_mesh_dim diff --git a/torch/distributed/tensor/_decompositions.py b/torch/distributed/tensor/_decompositions.py index b1f28c82a7359..eefffa6df5a93 100644 --- a/torch/distributed/tensor/_decompositions.py +++ b/torch/distributed/tensor/_decompositions.py @@ -236,9 +236,7 @@ def propagate_strategy( n_outputs = len(output_placements) strategy_schema = self.sharding_prop._wrap_with_op_strategy(op_schema) # Import here to avoid circular import at module load time - from torch.distributed.tensor._ops.utils import ( # noqa: F811 - expand_to_full_mesh_op_strategy, - ) + from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy return expand_to_full_mesh_op_strategy( mesh, strategy_schema, single_dim_strategies, input_index=n_outputs @@ -326,5 +324,5 @@ def _get_candidate_placements( options |= {Shard(i) for i in range(spec.ndim)} candidates.append(list(options)) - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] return list(itertools.product(*candidates)) diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 27272d2238b40..3dc53af1c5b29 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -7,7 +7,6 @@ import torch import torch.distributed as dist -import torch.distributed.config as dist_config import torch.distributed.tensor._api as dtensor import torch.distributed.tensor._random as random from torch._library.utils import fill_defaults @@ -372,10 +371,6 @@ def _dispatch_get_local_results_slow_path( ) else: # CUDA device without user generator, use HOP for traceability - if dist_config.compile_on_one_rank: - raise NotImplementedError( - "run_dtensor_rng_op is not yet compatible with compile_on_one_rank" - ) if not isinstance( random._rng_tracker, random.OffsetBasedRNGTracker ): @@ -395,7 +390,17 @@ def _dispatch_get_local_results_slow_path( local_results = op_call(*local_tensor_args, **op_info.local_kwargs) else: # normal case, run local sharded op computation - local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + if ( + output_sharding.needs_redistribute + and output_sharding.redistribute_schema is not None + and output_sharding.redistribute_schema.op != op_call + ): + # Op was rewritten (e.g., squeeze.default → squeeze.dims) + local_results = output_sharding.redistribute_schema.op( + *local_tensor_args, **op_info.local_kwargs + ) + else: + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) else: # For a non-participating device (happens on rank that does not belong to @@ -486,31 +491,17 @@ def _dispatch_fast_path_python_tail( if not isinstance(args[0], dtensor.DTensor): raise AssertionError - # NOTE: aten.squeeze_.dim is an inplace op but it also may change - # the inplace argument's tensor meta. Here we choose to special case - # this op because as far as I know this is the only inplace op that - # has such as behavior. We can extend this special case if necessary. - if op_call == aten.squeeze_.dim: - # update the spec to handle tensor meta changes - args[0]._spec = output_spec - # use return_and_correct_aliasing to match the outer and the inner - # aliasing. See https://github.com/pytorch/pytorch/pull/158954 - return return_and_correct_aliasing(op_call, args, kwargs, args[0]) - else: - # For all other inplace ops, check if placement changes are required - # Inplace operations that change placement are not supported because - # they would require redistribution, which breaks aliasing semantics. - # If there are views into the tensor, the views would not be updated. - if args[0]._spec.placements != output_spec.placements: - raise RuntimeError( - f"{op_call}: in-place operations that require placement changes " - f"are not supported. The operation would change placement from " - f"{args[0]._spec.placements} to {output_spec.placements}, " - f"which requires redistribution and breaks aliasing semantics. " - f"Please use the out-of-place version of this operation instead." - ) - # Most inplace ops don't change tensor meta, so no spec update needed + # Fast path: placements unchanged (common case: add_, mul_, etc.) + if args[0]._spec.placements == output_spec.placements: return args[0] + + # Placement reindexed (e.g. squeeze_ removing a non-sharded + # dim: Shard(1) → Shard(0)). No redistribution — the local + # tensor data is unchanged, only dim indices shift. + # strict_view=True in sharding prop prevents the illegal + # case (squeezing a sharded dim) from reaching here. + args[0]._spec = output_spec + return return_and_correct_aliasing(op_call, args, kwargs, args[0]) else: return None elif is_out_variant_op: @@ -598,6 +589,13 @@ def redistribute_local_args( else: new_local_args.append(arg_spec) + # Append extra non-tensor args from rewritten schema (e.g., dims tuple). + if use_val_from_redistribute_schema: + for i in range( + len(op_info.flat_args_schema), len(flatten_args_schema_to_reshard) + ): + new_local_args.append(flatten_args_schema_to_reshard[i]) + op_info.local_args = tuple(new_local_args) def unwrap_to_op_info( diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index d1f7125bebb25..d2456be1c7862 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -1,3 +1,4 @@ +import hashlib import itertools import math from collections import defaultdict @@ -7,6 +8,7 @@ import torch from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor.placement_types import ( + _is_shard_like, _MaskPartial, _StridedShard, Partial, @@ -18,6 +20,13 @@ from torch.utils._dtype_abbrs import dtype_abbrs +# Defined here (not in placement_types.py) because decoding split_factor into +# a shard order is a DTensorSpec concern — placement_types doesn't know about +# shard orders. +class _StridedShardNotDecodableError(ValueError): + """Raised when _StridedShard split_factor cannot be decoded into a shard order.""" + + class ShardOrderEntry(NamedTuple): """ Represents how a single tensor dimension is sharded across mesh dimensions. @@ -129,10 +138,9 @@ def _normalize_placements_into_shard_order( placements, mesh ) if shard_order is None: - raise ValueError( - "use_strided_shard_as_shard_order is True, but placements: " - f"{placements} is unable to be interpreted into a corresponding " - "shard_order" + raise _StridedShardNotDecodableError( + f"_StridedShard placements {placements} cannot be decoded " + "into a corresponding shard_order" ) normalized_placements = tuple( [ @@ -162,7 +170,7 @@ def compute_default_shard_order( tensor_dim_to_mesh_dims: defaultdict[int, list[int]] = defaultdict(list) mesh_ndim = len(placements) for mesh_dim in range(mesh_ndim): - if isinstance(placements[mesh_dim], Shard | _StridedShard): + if _is_shard_like(placements[mesh_dim]): placement = placements[mesh_dim] shard_dim = placement.dim # pyrefly: ignore [missing-attribute] if shard_dim < 0: @@ -313,9 +321,7 @@ def _maybe_convert_StridedShard_to_shard_order( """ if not any(isinstance(p, _StridedShard) for p in placements): return DTensorSpec.compute_default_shard_order(placements) - max_tensor_dim = ( - max([i.dim for i in placements if isinstance(i, Shard | _StridedShard)]) + 1 - ) + max_tensor_dim = max([i.dim for i in placements if _is_shard_like(i)]) + 1 shard_order = [] tensor_dim_to_mesh_dims_order: list[list[int]] = [ @@ -323,8 +329,7 @@ def _maybe_convert_StridedShard_to_shard_order( ] for mesh_dim in reversed(range(len(placements))): cur_placement = placements[mesh_dim] - # _StridedShard may not be a subclass of Shard in the future, so write in this way: - if isinstance(cur_placement, Shard | _StridedShard): + if _is_shard_like(cur_placement): tensor_dim = cur_placement.dim mesh_dims_order = tensor_dim_to_mesh_dims_order[tensor_dim] cur_sf = 1 @@ -362,6 +367,8 @@ def _verify_shard_order(self, shard_order: ShardOrder) -> None: """Verify that the shard_order is valid and matches the placements.""" total_shard = 0 if any(isinstance(p, _StridedShard) for p in self.placements): + # _StridedShard shard_order validation not yet supported; + # the Shard-only checks below (line 390, 394) would fail. return prev_tensor_dim = -1 for entry in shard_order: @@ -415,24 +422,26 @@ def __setattr__(self, attr: str, value: Any) -> None: if not isinstance(value, TensorMeta | TensorMetadata): raise AssertionError(repr(value)) + def _hash_key(self) -> tuple[Any, ...]: + """Return the tuple used for hashing. Used by both __hash__ and _stable_hash.""" + if self.tensor_meta is not None: + return ( + self.mesh, + self.placements, + self.shard_order, + self.tensor_meta.shape, + self.tensor_meta.stride, + self.tensor_meta.dtype, + ) + return (self.mesh, self.placements, self.shard_order) + def _hash_impl(self) -> int: # hashing and equality check for DTensorSpec are used to cache the sharding # propagation results. We only need to consider the mesh, placements, shape # dtype and stride. # Caveat: we need to keep this in mind and sync hash and eq if we add more # fields to them. - if self.tensor_meta is not None: - return hash( - ( - self.mesh, - self.placements, - self.shard_order, - self.tensor_meta.shape, - self.tensor_meta.stride, - self.tensor_meta.dtype, - ) - ) - return hash((self.mesh, self.placements, self.shard_order)) + return hash(self._hash_key()) def __hash__(self) -> int: # We lazily cache the spec to avoid recomputing the hash upon each @@ -443,6 +452,17 @@ def __hash__(self) -> int: self._hash = self._hash_impl() return self._hash + def _stable_hash(self) -> str: + """ + Return a stable hash for AOT autograd caching. + [See note: Tensor subclass stable hashing for AOT autograd cache] + """ + # Get hash key, but replace mesh with its stable hash + key = self._hash_key() + # First element is mesh, replace with its stable hash + stable_key = (self.mesh._stable_hash(),) + key[1:] + return hashlib.blake2b(repr(stable_key).encode(), digest_size=16).hexdigest() + def _check_equals(self, other: object, skip_shapes: bool = False) -> bool: if not ( isinstance(other, DTensorSpec) @@ -544,7 +564,7 @@ def format_shard_order_str( # native dtensor-style sharding representation: map from mesh # dim to tensor dim for mesh_dim, placement in enumerate(placements): - if isinstance(placement, (Shard, _StridedShard)): + if _is_shard_like(placement): if shard_order is not None: for entry in shard_order: tensor_dim = entry.tensor_dim @@ -588,7 +608,7 @@ def ndim(self) -> int: def num_shards(self) -> int: num_shards = 1 for i, placement in enumerate(self.placements): - if placement.is_shard(): + if _is_shard_like(placement): num_shards *= self.mesh.size(i) return num_shards @@ -624,8 +644,8 @@ def dim_map(self) -> list[int]: # and int >=0 represent shard on that device mesh dim r = [-1] * self.ndim for i, placement in enumerate(self.placements): - if placement.is_shard(): - shard_dim = cast(Shard, placement).dim + if _is_shard_like(placement): + shard_dim = placement.dim if r[shard_dim] > -1: raise ValueError( f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," @@ -652,9 +672,8 @@ def num_shards_map(self) -> list[int]: """ r = [1] * self.ndim for i, placement in enumerate(self.placements): - if placement.is_shard(): - shard_dim = cast(Shard, placement).dim - r[shard_dim] *= self.mesh.size(i) + if _is_shard_like(placement): + r[placement.dim] *= self.mesh.size(i) return r @@ -703,7 +722,7 @@ def from_dim_map( for i, m in enumerate(dim_map): if m >= 0: placement = placements[m] - if placement.is_shard(): + if placement.is_shard(): # dim_map only produces Shard placements placement = cast(Shard, placement) raise RuntimeError( f"DeviceMesh dimension can't be mapped to two dimension of the same tensor: {i} and {placement.dim}" @@ -724,9 +743,9 @@ def is_replicated(self) -> bool: def is_sharded(self) -> bool: """ - return True if the current DTensorSpec uses Shard() placement on any mesh dims (devices) + return True if the current DTensorSpec uses Shard() or _StridedShard() placement on any mesh dims (devices) """ - return any(placement.is_shard() for placement in self.placements) + return any(_is_shard_like(placement) for placement in self.placements) def shallow_copy_with_tensor_meta( self, tensor_meta: TensorMeta | None @@ -740,4 +759,5 @@ def shallow_copy_with_tensor_meta( self.mesh, self.placements, tensor_meta=tensor_meta, + use_strided_shard_as_shard_order=self.use_strided_shard_as_shard_order, ) diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py index b1aca77f41323..d413c67169171 100644 --- a/torch/distributed/tensor/_op_schema.py +++ b/torch/distributed/tensor/_op_schema.py @@ -436,6 +436,9 @@ def convert_to_meta(item): return item.tensor_meta elif isinstance(item, TupleStrategy): return tuple(convert_to_meta(child) for child in item.children) + elif isinstance(item, (list, tuple)): + converted = [convert_to_meta(child) for child in item] + return type(item)(converted) else: return item @@ -450,6 +453,9 @@ def convert_to_meta(item): return item.tensor_meta elif isinstance(item, TupleStrategy): return tuple(convert_to_meta(child) for child in item.children) + elif isinstance(item, (list, tuple)): + converted = [convert_to_meta(child) for child in item] + return type(item)(converted) else: return item diff --git a/torch/distributed/tensor/_ops/_conv_ops.py b/torch/distributed/tensor/_ops/_conv_ops.py index b0bce583b4ecd..7e47bd552313d 100644 --- a/torch/distributed/tensor/_ops/_conv_ops.py +++ b/torch/distributed/tensor/_ops/_conv_ops.py @@ -1,10 +1,21 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # implement matrix related ops for distributed tensor +from typing import Any + import torch from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta -from torch.distributed.tensor._op_schema import OpSchema, OutputSharding +from torch.distributed.tensor._op_schema import ( + OpSchema, + OutputSharding, + RuntimeSchemaInfo, +) +from torch.distributed.tensor._ops.single_dim_strategy import ( + _ShardingPlaceholder, + register_single_dim_strategy, +) from torch.distributed.tensor._ops.utils import register_prop_rule +from torch.distributed.tensor.placement_types import Partial, Placement, Replicate aten = torch.ops.aten @@ -138,3 +149,54 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding: # for a certain output Tensor. This also applies to the conv handler # in torch/distributed/tensor/_tp_conv.py return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec]) + + +# Single-dim strategies for autoparallel optimizer support. +# These coexist with the prop_rules above — strategies take precedence +# in the propagation path, while the prop_rules + custom handlers in +# _tp_conv.py continue to handle runtime dispatch. + + +@register_single_dim_strategy( + [aten.convolution.default], + schema_info=RuntimeSchemaInfo(2), +) +def convolution_single_dim_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + bias_meta = args_schema[2] + # [output, input, weight, (bias)] + rule: list[Placement | _ShardingPlaceholder] = [ + _ShardingPlaceholder(0), # output + _ShardingPlaceholder(0), # input + Replicate(), # weight + ] + if bias_meta is not None: + rule.append(Replicate()) # bias + return [rule] + + +@register_single_dim_strategy( + [aten.convolution_backward.default], + schema_info=RuntimeSchemaInfo(3), +) +def convolution_backward_single_dim_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder | None]]: + bias_sizes = args_schema[3] + has_bias = bias_sizes is not None + # outputs: [grad_input, grad_weight, grad_bias] + # inputs: [grad_output, input, weight] + rule: list[Placement | _ShardingPlaceholder | None] = [ + _ShardingPlaceholder(0), # grad_input + Partial("sum"), # grad_weight + Partial("sum") if has_bias else None, # grad_bias + _ShardingPlaceholder(0), # grad_output + _ShardingPlaceholder(0), # input + Replicate(), # weight + ] + return [rule] diff --git a/torch/distributed/tensor/_ops/_embedding_ops.py b/torch/distributed/tensor/_ops/_embedding_ops.py index 6e701caf95235..a0051c0b77a3a 100644 --- a/torch/distributed/tensor/_ops/_embedding_ops.py +++ b/torch/distributed/tensor/_ops/_embedding_ops.py @@ -1,23 +1,15 @@ -# mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates -# implement matrix related ops for distributed tensor -from typing import cast - import torch -from torch.distributed.tensor._op_schema import ( - OpSchema, - OpStrategy, - PlacementList, - RuntimeSchemaInfo, - StrategyType, -) -from torch.distributed.tensor._ops.utils import ( - expand_to_full_mesh_op_strategy, - register_op_strategy, +from torch._ops import OpOverload +from torch.distributed.tensor._dtensor_spec import TensorMeta +from torch.distributed.tensor._op_schema import ArgsType, KwargsType, RuntimeSchemaInfo +from torch.distributed.tensor._ops.single_dim_strategy import ( + register_single_dim_strategy, ) from torch.distributed.tensor.placement_types import ( _MaskPartial, Partial, + Placement, Replicate, Shard, ) @@ -26,92 +18,80 @@ aten = torch.ops.aten -@register_op_strategy(aten.embedding.default) -def embedding_strategy(op_schema: OpSchema) -> StrategyType: - """ - This strategy handles embedding op. We have two possible embedding shardings: - rowwise and colwise +@register_single_dim_strategy(aten.embedding.default) +def embedding_strategy( + op: OpOverload, + args_schema: ArgsType, + kwargs_schema: KwargsType, +) -> list[list[Placement]]: + """Single-dim strategy for embedding: rowwise, colwise, and batch-dim sharding. + + Placement order: [output, weight, indices] """ - weight_strategy = cast(OpStrategy, op_schema.args_schema[0]) - indices_strategy = cast(OpStrategy, op_schema.args_schema[1]) - mesh = op_schema.get_mesh_from_args() - - weight_shape = weight_strategy.shape - indices_shape = indices_strategy.shape - output_emd_dim = len(indices_shape) - - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, weight, input_indices] - # first we always have replicate all for inputs and output - all_replicate: PlacementList = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) - - # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate - colwise_sharding: PlacementList = [Shard(output_emd_dim), Shard(1), Replicate()] - single_mesh_dim_strategies.append(colwise_sharding) - - # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial - embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0) - - # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates - # from the input indices and use it for output reduction - rowwise_sharding: PlacementList = [ - embedding_partial_placement, - Shard(0), - embedding_partial_placement, - ] - single_mesh_dim_strategies.append(rowwise_sharding) - - # batch dim sharding, weight replicated, input can shard on any dim, output follows input - for input_dim in range(len(indices_shape)): - batch_sharding: PlacementList = [ - Shard(input_dim), - Replicate(), - Shard(input_dim), - ] - single_mesh_dim_strategies.append(batch_sharding) - - return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) - - -@register_op_strategy( + weight_meta = args_schema[0] + indices_meta = args_schema[1] + if not isinstance(weight_meta, TensorMeta) or not isinstance( + indices_meta, TensorMeta + ): + raise AssertionError + + # _MaskPartial hashes offset_shape, but torch.Size with SymInt (from + # dynamo tracing) is unhashable. Concretize to int, which adds a + # standard dynamo guard (recompiles if the shape changes at runtime). + weight_shape = torch.Size(int(s) for s in weight_meta.shape) + indices_shape = indices_meta.shape + output_emb_dim = len(indices_shape) + + strategies: list[list[Placement]] = [] + + # colwise: output shard on last dim, weight shard on dim 1, indices replicate + strategies.append([Shard(output_emb_dim), Shard(1), Replicate()]) + + # rowwise: output is MaskPartial, weight shard on dim 0, indices MaskPartial + # NOTE: same object for output & indices so the mask buffer is shared + embedding_partial = _MaskPartial(offset_shape=weight_shape, offset_dim=0) + strategies.append([embedding_partial, Shard(0), embedding_partial]) + + # batch dim sharding: weight replicated, indices shard on any dim, output follows + for i in range(len(indices_shape)): + strategies.append([Shard(i), Replicate(), Shard(i)]) + + return strategies + + +@register_single_dim_strategy( aten.embedding_dense_backward.default, schema_info=RuntimeSchemaInfo(static_argnum=2), ) -def embedding_dense_backward_strategy(op_schema: OpSchema) -> StrategyType: - """ - This strategy handles embedding op. We have two possible embedding shardings: - rowwise and colwise +def embedding_dense_backward_strategy( + op: OpOverload, + args_schema: ArgsType, + kwargs_schema: KwargsType, +) -> list[list[Placement]]: + """Single-dim strategy for embedding backward. + + Placement order: [output(weight_grad), grad_out, indices] """ - grad_out_strategy = cast(OpStrategy, op_schema.args_schema[0]) - indices_strategy = cast(OpStrategy, op_schema.args_schema[1]) - mesh = op_schema.get_mesh_from_args() - - grad_out_shape = grad_out_strategy.shape - indices_shape = indices_strategy.shape - grad_out_ndim = len(grad_out_shape) + grad_out_meta = args_schema[0] + indices_meta = args_schema[1] + if not isinstance(grad_out_meta, TensorMeta) or not isinstance( + indices_meta, TensorMeta + ): + raise AssertionError - single_mesh_dim_strategies = [] + grad_out_ndim = len(grad_out_meta.shape) + indices_shape = indices_meta.shape - # placement list stores placements of [output, weight, input_indices] - # first we always have replicate all for inputs and output - all_replicate: PlacementList = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) + strategies: list[list[Placement]] = [] - # colwise sharding backward, grad_out shard on last dim, input replicate, - # weight grad shard colwise - colwise_sharding: PlacementList = [Shard(1), Shard(grad_out_ndim - 1), Replicate()] - single_mesh_dim_strategies.append(colwise_sharding) + # colwise backward: weight grad shard on dim 1, grad_out shard on last dim, indices replicate + strategies.append([Shard(1), Shard(grad_out_ndim - 1), Replicate()]) - # batch dim sharding, weight replicated, grad_out/input have same sharding - # that can shard on any dim, weight grad partial - for input_dim in range(len(indices_shape)): - batch_sharding: PlacementList = [Partial(), Shard(input_dim), Shard(input_dim)] - single_mesh_dim_strategies.append(batch_sharding) + # batch dim sharding: weight grad partial, grad_out/indices shard on same dim + for i in range(len(indices_shape)): + strategies.append([Partial(), Shard(i), Shard(i)]) - # grad_out partial, input replicate, weight grad keep partial - partial_sharding: PlacementList = [Partial(), Partial(), Replicate()] - single_mesh_dim_strategies.append(partial_sharding) + # grad_out partial, indices replicate, weight grad partial + strategies.append([Partial(), Partial(), Replicate()]) - return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) + return strategies diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 020961b9d74f6..d6250cf89fef5 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -31,8 +31,12 @@ normalize_dims, register_op_strategy, ) -from torch.distributed.tensor._utils import normalize_to_torch_size +from torch.distributed.tensor._utils import ( + compute_local_shape_and_global_offset, + normalize_to_torch_size, +) from torch.distributed.tensor.placement_types import ( + _is_shard_like, _StridedShard, Partial, Placement, @@ -159,7 +163,7 @@ def _replicate_dims_start_at( ) -> tuple[Placement, ...]: new_placements: list[Placement] = [] for p in placements: - if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + if p.is_partial() or (_is_shard_like(p) and p.dim >= start_dim): new_placements.append(Replicate()) # make it replicate else: new_placements.append(p) # keep the placement @@ -167,12 +171,16 @@ def _replicate_dims_start_at( # return new_placements which align with placements but skip the skipped_dim +# Precondition: no shard-like placement on skipped_dim (callers must +# replicate it first via replicate_reduction_dims). def _skip_dim( placements: tuple[Placement, ...], skipped_dim: int ) -> tuple[Placement, ...]: new_placements: list[Placement] = [] for p in placements: - if isinstance(p, Shard) and p.dim >= skipped_dim: + if isinstance(p, _StridedShard) and p.dim >= skipped_dim: + new_placements.append(_StridedShard(p.dim - 1, split_factor=p.split_factor)) + elif isinstance(p, Shard) and p.dim >= skipped_dim: new_placements.append(Shard(p.dim - 1)) else: new_placements.append(p) @@ -188,7 +196,7 @@ def replicate_reduction_dims( for p in placements: if p.is_partial(): new_placements.append(Replicate()) - elif isinstance(p, Shard) and p.dim in reduction_dims: + elif _is_shard_like(p) and p.dim in reduction_dims: new_placements.append(Replicate()) else: new_placements.append(p) @@ -210,7 +218,7 @@ def map_placements_after_reduction( if isinstance(placement, (Replicate, Partial)): new_placements.append(placement) else: - if not isinstance(placement, Shard | _StridedShard): + if not _is_shard_like(placement): raise AssertionError( f"Expected Shard/_StridedShard, got {type(placement)}" ) @@ -221,14 +229,14 @@ def map_placements_after_reduction( # (i.e. for the case where keepdims=True), we generate partial new_placements.append(get_placement_from_reduction_op(reduction_op)) else: - if isinstance(placement, Shard): - new_placements.append(Shard(new_shard_dim)) - else: + if isinstance(placement, _StridedShard): new_placements.append( _StridedShard( new_shard_dim, split_factor=placement.split_factor ) ) + elif isinstance(placement, Shard): + new_placements.append(Shard(new_shard_dim)) return tuple(new_placements) @@ -735,12 +743,6 @@ def foreach_max_strategy(op_schema: OpSchema) -> TupleStrategy: aten.tril.default, aten.triu.default, aten._linalg_eigh.default, - aten.upsample_bicubic2d.default, - aten.upsample_bilinear2d.default, - aten.upsample_linear1d.default, - aten.upsample_nearest2d.default, - aten.upsample_trilinear3d.default, - # TODO: support the full F.interpolate set of options. ], schema_info=RuntimeSchemaInfo(1), ) @@ -1160,13 +1162,17 @@ def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy: generate_redistribute_costs(weight_strategy, weight_expected_spec) ) - # total_weight should always be replicated + # total_weight is only used by the backward kernel for reduction='mean'. + # For reduction='sum' or 'none', it is unused, so no redistribution needed. total_weight_src_spec = total_weight_strategy.strategies[idx].output_spec - total_weight_expected_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(total_weight_src_spec.placements), - tensor_meta=total_weight_src_spec.tensor_meta, - ) + if reduction == Reduction.MEAN.value: + total_weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(total_weight_src_spec.placements), + tensor_meta=total_weight_src_spec.tensor_meta, + ) + else: + total_weight_expected_spec = total_weight_src_spec op_args_target_specs.append(total_weight_expected_spec) redistribute_costs.append( generate_redistribute_costs( @@ -1186,393 +1192,145 @@ def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy: return grad_in_strategy -def _common_norm_forward_strategy( - op_schema: OpSchema, - rms_norm: bool = False, -) -> OpStrategy: - """Common forward strategy logic for layer_norm and rms_norm.""" - mesh = op_schema.get_mesh_from_args() - - if not rms_norm: - # layer_norm args: input, normalized_shape, weight, bias, eps - # for None weight and bias, their corresponding objects will - # be None as well. layer_norm_strategy returns one OpStrategy - # for the triple return values (out, mean, rstd). - if not len(op_schema.args_schema) == 5: - raise AssertionError(f"Expected 5 args, got {len(op_schema.args_schema)}") - ( - input_strategy, - normalized_shape, - weight_strategy, - bias_strategy, - _, - ) = op_schema.args_schema - else: - # rms_norm args: input, normalized_shape, weight, eps - if not len(op_schema.args_schema) == 4: - raise AssertionError(f"Expected 4 args, got {len(op_schema.args_schema)}") - ( - input_strategy, - normalized_shape, - weight_strategy, - _, - ) = op_schema.args_schema - bias_strategy = None - - # the current norm implementation requires that all - # input DTensor's sharding must be in form of OpStrategy - if not isinstance(input_strategy, OpStrategy): - raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}") - if not isinstance(normalized_shape, (int, Sequence, torch.Size)): - raise AssertionError( - f"Expected int, Sequence, or torch.Size, got {type(normalized_shape)}" - ) - normalized_size = normalize_to_torch_size(normalized_shape) - - input_ndim = input_strategy.ndim - axis = input_ndim - len(normalized_size) - - # we use OpStrategy because the output values (out, mean, rstd) - # should have the same placements - output_strategy = OpStrategy([]) - for idx, input_placement_strategy in enumerate(input_strategy.strategies): - op_args_target_specs = [] - redistribute_costs = [] - input_src_spec = input_placement_strategy.output_spec - - # for the input tensor, we replicate it on the inner dims if necessary - # TODO: we can avoid forcing the redistribution once we figure out - # how to decompose layer norm - input_target_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(input_src_spec.placements, axis), - tensor_meta=input_src_spec.tensor_meta, - ) - op_args_target_specs.append(input_target_spec) - redistribute_costs.append( - generate_redistribute_costs(input_strategy, input_target_spec) - ) - - if weight_strategy is not None: - if not isinstance(weight_strategy, OpStrategy): - raise AssertionError( - f"Expected OpStrategy, got {type(weight_strategy)}" - ) - weight_src_spec = weight_strategy.strategies[idx].output_spec - - # for the weight tensor, we replicate it on all dims if necessary - # TODO: we can avoid forcing the redistribution once we figure out - # how to decompose layer norm - weight_target_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(weight_src_spec.placements), - tensor_meta=weight_src_spec.tensor_meta, - ) - op_args_target_specs.append(weight_target_spec) - redistribute_costs.append( - generate_redistribute_costs(weight_strategy, weight_target_spec) - ) - - if bias_strategy is not None: - if not isinstance(bias_strategy, OpStrategy): - raise AssertionError(f"Expected OpStrategy, got {type(bias_strategy)}") - bias_src_spec = bias_strategy.strategies[idx].output_spec - - # for the bias tensor, we replicate it on all dims if necessary - # TODO: we can avoid forcing the redistribution once we figure out - # how to decompose layer norm - bias_target_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(bias_src_spec.placements), - tensor_meta=bias_src_spec.tensor_meta, - ) - op_args_target_specs.append(bias_target_spec) - redistribute_costs.append( - generate_redistribute_costs(bias_strategy, bias_target_spec) - ) - - # Build per-output specs with correct tensor_meta. - # out: same shape as input, contiguous strides - # mean/rstd: shape = input_shape[:axis], contiguous strides - input_tm = input_src_spec.tensor_meta - if input_tm is None: - raise AssertionError("input_src_spec.tensor_meta is None") - input_shape = input_tm.shape - out_placements = input_target_spec.placements - - out_strides = torch._prims_common.make_contiguous_strides_for(input_shape) - out_spec = DTensorSpec( - mesh=mesh, - placements=out_placements, - tensor_meta=TensorMeta( - shape=input_shape, - stride=out_strides, - dtype=input_tm.dtype, - ), - ) - - stat_shape = torch.Size(input_shape[:axis]) - stat_strides = torch._prims_common.make_contiguous_strides_for(stat_shape) - stat_spec = DTensorSpec( - mesh=mesh, - placements=out_placements, - tensor_meta=TensorMeta( - shape=stat_shape, - stride=stat_strides, - dtype=input_tm.dtype, - ), - ) - - if rms_norm: - output_specs = (out_spec, stat_spec) - else: - output_specs = (out_spec, stat_spec, stat_spec) - - output_strategy.strategies.append( - OpSpec( - output_specs=output_specs, - input_specs=op_args_target_specs, - redistribute_cost=redistribute_costs, - ) - ) - - return output_strategy - - -@register_op_strategy( +@register_single_dim_strategy( [aten.native_layer_norm.default], schema_info=RuntimeSchemaInfo(1), ) -def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: - return _common_norm_forward_strategy(op_schema) +def layer_norm_single_dim_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + input_meta = args_schema[0] + normalized_shape = args_schema[1] + weight_meta = args_schema[2] + bias_meta = args_schema[3] + axis = len(input_meta.shape) - len(normalize_to_torch_size(normalized_shape)) -@register_op_strategy( + strategies: list[list[Placement | _ShardingPlaceholder]] = [] + for dim in range(axis): + # [out, mean, rstd, input, weight?, bias?] + rule: list[Placement | _ShardingPlaceholder] = [ + _ShardingPlaceholder(dim), # out + _ShardingPlaceholder(dim), # mean + _ShardingPlaceholder(dim), # rstd + _ShardingPlaceholder(dim), # input + ] + if weight_meta is not None: + rule.append(Replicate()) + if bias_meta is not None: + rule.append(Replicate()) + strategies.append(rule) + return strategies + + +@register_single_dim_strategy( [aten._fused_rms_norm.default], schema_info=RuntimeSchemaInfo(1), ) -def fused_rms_norm_strategy(op_schema: OpSchema) -> OpStrategy: - return _common_norm_forward_strategy(op_schema, rms_norm=True) - - -def _common_norm_backward_strategy( - op_schema: OpSchema, - rms_norm: bool = False, -) -> OpStrategy: - """Common backward strategy logic for layer_norm and rms_norm.""" - # backward op does not need to validate the mesh since forward op has already done it - mesh = op_schema.get_mesh_from_args(validate=False) - - if not rms_norm: - # layer_norm args: grad_out, input, normalized_shape, mean, rstd, - # weight, bias, output_mask. For None weight and bias, their - # corresponding objects will be None as well. - if not len(op_schema.args_schema) == 8: - raise AssertionError(f"Expected 8 args, got {len(op_schema.args_schema)}") - ( - grad_out_strategy, - input_strategy, - normalized_shape, - mean_strategy, - rstd_strategy, - weight_strategy, - bias_strategy, - output_mask, - ) = op_schema.args_schema - else: - # rms_norm args: grad_out, input, normalized_shape, rstd, - if not len(op_schema.args_schema) == 6: - raise AssertionError(f"Expected 6 args, got {len(op_schema.args_schema)}") - ( - grad_out_strategy, - input_strategy, - normalized_shape, - rstd_strategy, - weight_strategy, - output_mask, - ) = op_schema.args_schema - mean_strategy = None - bias_strategy = None - - if not isinstance(grad_out_strategy, OpStrategy): - raise AssertionError(f"Expected OpStrategy, got {type(grad_out_strategy)}") - if not isinstance(input_strategy, OpStrategy): - raise AssertionError(f"Expected OpStrategy, got {type(input_strategy)}") - if not isinstance(rstd_strategy, OpStrategy): - raise AssertionError(f"Expected OpStrategy, got {type(rstd_strategy)}") - if mean_strategy is not None: - if not isinstance(mean_strategy, OpStrategy): - raise AssertionError(f"Expected OpStrategy, got {type(mean_strategy)}") - - if not isinstance(normalized_shape, (int, Sequence, torch.Size)): - raise AssertionError( - f"Expected int, Sequence, or torch.Size, got {type(normalized_shape)}" - ) - normalized_size = normalize_to_torch_size(normalized_shape) - input_ndim = input_strategy.ndim - axis = input_ndim - len(normalized_size) - outer_dims = list(range(axis)) - - if not rms_norm: - if not (isinstance(output_mask, list) and len(output_mask) == 3): - raise AssertionError( - f"Expected output_mask to be list of length 3, got {type(output_mask)} " - f"of length {len(output_mask) if isinstance(output_mask, list) else 'N/A'}" - ) - else: - if not (isinstance(output_mask, list) and len(output_mask) == 2): - raise AssertionError( - f"Expected output_mask to be list of length 2, got {type(output_mask)} " - f"of length {len(output_mask) if isinstance(output_mask, list) else 'N/A'}" - ) - - # output tuple: (d_input, d_weight[, d_bias]) - out_tuple_strategy = OpStrategy([]) - for idx, input_placement_strategy in enumerate(input_strategy.strategies): - # args for OpSpec - output_specs_list: list[DTensorSpec | None] = [] - input_specs_list: list[DTensorSpec] = [] - redistribute_costs = [] - - input_src_spec = input_placement_strategy.output_spec - # arg: grad_out - # TODO: change the strategy to the following rule. - # d_input is basically a product of element-wise mul of - # grad_out, rstd, and normalized input, among which rstd - # and normalized input (x_hat) should have the same sharding - # placements, and grad_out's sharding is determined by the - # pointwise result of x_hat and weight/bias. - # TODO: now grad_out spec follows input spec. we may need - # to change it to apply a pointwise rule over grad_out, - # input, and weight. - grad_out_target_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(input_src_spec.placements, axis), - tensor_meta=input_src_spec.tensor_meta, - ) - input_specs_list.append(grad_out_target_spec) - redistribute_costs.append( - generate_redistribute_costs(grad_out_strategy, grad_out_target_spec) - ) - output_specs_list.append(grad_out_target_spec if output_mask[0] else None) - - # arg: input - input_target_spec = DTensorSpec( - mesh=mesh, - placements=_replicate_dims_start_at(input_src_spec.placements, axis), - tensor_meta=input_src_spec.tensor_meta, - ) - input_specs_list.append(input_target_spec) - redistribute_costs.append( - generate_redistribute_costs(input_strategy, input_target_spec) - ) - - # arg: mean - if not rms_norm: - if mean_strategy is None: - raise AssertionError("Expected mean_strategy to not be None") - mean_src_spec = mean_strategy.strategies[idx].output_spec - input_specs_list.append(mean_src_spec) - redistribute_costs.append([0.0 for _ in mean_strategy.strategies]) - - # arg: rstd - rstd_src_spec = rstd_strategy.strategies[idx].output_spec - input_specs_list.append(rstd_src_spec) - redistribute_costs.append([0.0 for _ in rstd_strategy.strategies]) - - def _add_target_input_spec(strategy) -> DTensorSpec: - # shared logic for setting the weight and bias target input specs - if not isinstance(strategy, OpStrategy): - raise AssertionError(f"Expected OpStrategy, got {type(strategy)}") - src_spec = strategy.strategies[idx].output_spec - # no need to redistribute since they should be replicated in forward pass - input_specs_list.append(src_spec) - redistribute_costs.append([0.0 for _ in strategy.strategies]) - return src_spec - - # arg: weight - # d_weight = sum(grad_out * (input - mean) / rstd, outer_dim, keepdim=False) - # For RMS norm, mean is 0, so it's just: sum(grad_out * input / rstd, outer_dim, keepdim=False) - if weight_strategy is not None: - weight_src_spec = _add_target_input_spec(weight_strategy) - # TODO: now d_weight spec follows input spec w/ a reduction. - # we may need to change to a pointwise rule over grad_out and - # input, then apply a reduction. - inp_placements = _replicate_dims_start_at(input_src_spec.placements, axis) - reduce_dims_map = _infer_reduce_dims_map( - outer_dims, input_src_spec.ndim, False - ) - out_placements = map_placements_after_reduction( - inp_placements, outer_dims, reduce_dims_map, "sum" - ) - weight_out_spec = DTensorSpec( - mesh=mesh, - placements=out_placements, - tensor_meta=weight_src_spec.tensor_meta, - ) - output_specs_list.append(weight_out_spec if output_mask[1] else None) - else: - if not rms_norm: - error_msg = "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward." - else: - error_msg = "output_mask[1] should not be `True` while weight argument is `None` in _fused_rms_norm_backward." - if output_mask[1] is not False: - raise AssertionError(error_msg) - output_specs_list.append(None) - - # arg: bias - # d_bias = sum(grad_out, outer_dim, keepdim=False) - if not rms_norm: - if bias_strategy is not None: - bias_src_spec = _add_target_input_spec(bias_strategy) - # d_bias spec follows a reduction over grad_out - inp_placements = _replicate_dims_start_at( - grad_out_target_spec.placements, axis - ) - reduce_dims_map = _infer_reduce_dims_map( - outer_dims, grad_out_target_spec.ndim, False - ) - out_placements = map_placements_after_reduction( - inp_placements, outer_dims, reduce_dims_map, "sum" - ) - bias_out_spec = DTensorSpec( - mesh=mesh, - placements=out_placements, - tensor_meta=bias_src_spec.tensor_meta, - ) - output_specs_list.append(bias_out_spec if output_mask[2] else None) - else: - if output_mask[2] is not False: - raise AssertionError( - "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward." - ) - output_specs_list.append(None) +def rms_norm_single_dim_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + input_meta = args_schema[0] + normalized_shape = args_schema[1] + weight_meta = args_schema[2] - out_tuple_strategy.strategies.append( - OpSpec( - output_specs=tuple(output_specs_list), - input_specs=input_specs_list, - redistribute_cost=redistribute_costs, - ) - ) + axis = len(input_meta.shape) - len(normalize_to_torch_size(normalized_shape)) - return out_tuple_strategy + strategies: list[list[Placement | _ShardingPlaceholder]] = [] + for dim in range(axis): + # [out, rrms, input, weight?] + rule: list[Placement | _ShardingPlaceholder] = [ + _ShardingPlaceholder(dim), # out + _ShardingPlaceholder(dim), # rrms + _ShardingPlaceholder(dim), # input + ] + if weight_meta is not None: + rule.append(Replicate()) + strategies.append(rule) + return strategies -@register_op_strategy( +@register_single_dim_strategy( [aten.native_layer_norm_backward.default], schema_info=RuntimeSchemaInfo(2), ) -def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: - return _common_norm_backward_strategy(op_schema) +def layer_norm_bwd_single_dim_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder | None]]: + input_meta = args_schema[1] + normalized_shape = args_schema[2] + # mean = args_schema[3], rstd = args_schema[4] + weight_meta = args_schema[5] + bias_meta = args_schema[6] + + axis = len(input_meta.shape) - len(normalize_to_torch_size(normalized_shape)) + + strategies: list[list[Placement | _ShardingPlaceholder | None]] = [] + for dim in range(axis): + # outputs: [d_input, d_weight, d_bias] — always 3 per schema + # d_weight/d_bias use None when weight/bias are None + rule: list[Placement | _ShardingPlaceholder | None] = [ + _ShardingPlaceholder(dim), # d_input + Partial("sum") if weight_meta is not None else None, # d_weight + Partial("sum") if bias_meta is not None else None, # d_bias + ] + # inputs: [grad_out, input, mean, rstd, weight?, bias?] + rule.extend( + [ + _ShardingPlaceholder(dim), # grad_out + _ShardingPlaceholder(dim), # input + _ShardingPlaceholder(dim), # mean + _ShardingPlaceholder(dim), # rstd + ] + ) + if weight_meta is not None: + rule.append(Replicate()) + if bias_meta is not None: + rule.append(Replicate()) + strategies.append(rule) + return strategies -@register_op_strategy( + +@register_single_dim_strategy( [aten._fused_rms_norm_backward.default], schema_info=RuntimeSchemaInfo(2), ) -def fused_rms_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: - return _common_norm_backward_strategy(op_schema, rms_norm=True) +def rms_norm_bwd_single_dim_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder | None]]: + input_meta = args_schema[1] + normalized_shape = args_schema[2] + # rstd = args_schema[3] + weight_meta = args_schema[4] + + axis = len(input_meta.shape) - len(normalize_to_torch_size(normalized_shape)) + + strategies: list[list[Placement | _ShardingPlaceholder | None]] = [] + for dim in range(axis): + # outputs: [d_input, d_weight] — always 2 per schema + # d_weight uses None when weight is None + # inputs: [grad_out, input, rstd, weight?] + rule: list[Placement | _ShardingPlaceholder | None] = [ + _ShardingPlaceholder(dim), # d_input + Partial("sum") if weight_meta is not None else None, # d_weight + _ShardingPlaceholder(dim), # grad_out + _ShardingPlaceholder(dim), # input + _ShardingPlaceholder(dim), # rstd + ] + if weight_meta is not None: + rule.append(Replicate()) + strategies.append(rule) + + return strategies def sort_strategy(op_schema: OpSchema, sort_dim: int) -> OpStrategy: @@ -1696,3 +1454,424 @@ def logsumexp_strategy(op_schema: OpSchema) -> OpStrategy: keep_dim=keep_dim, reduction_linear=False, ) + + +_LINALG_NUM_PLACEMENTS = { + # 1 in 1 out + aten.cholesky.default: 2, + aten.cholesky_inverse.default: 2, + aten.linalg_matrix_exp.default: 2, + # 2 in 1 out + aten.cholesky_solve.default: 3, + aten.linalg_householder_product.default: 3, + aten.linalg_solve_triangular.default: 3, + # 3 in 1 out + aten.linalg_ldl_solve.default: 4, + aten.linalg_lu_solve.default: 4, + aten.ormqr.default: 4, + # 1 in 2 out + aten.geqrf.default: 3, + aten.linalg_cholesky_ex.default: 3, + aten.linalg_eig.default: 3, + aten.linalg_inv_ex.default: 3, + # 2 in 2 out + aten.triangular_solve.default: 4, + # 1 in 3 out + aten._linalg_det.default: 4, + aten.linalg_ldl_factor_ex.default: 4, + aten.linalg_lu.default: 4, + aten.linalg_lu_factor_ex.default: 4, + # 2 in 3 out + aten.lu_unpack.default: 5, + # 1 in 4 out + aten._linalg_slogdet.default: 5, + # 2 in 4 out + aten._linalg_solve_ex.default: 6, + # 1 in + aten._linalg_check_errors.default: 1, +} + + +def _linalg_batch_dim_strategies( + ndim: int, n_placements: int +) -> list[list[Placement | _ShardingPlaceholder]]: + """Build single-dim strategies for linalg ops that operate on the last 1-2 dims. + + Returns sharding on each batch dim (all dims except the last 2), with all + outputs and inputs sharded on the same dim. + """ + strategies: list[list[Placement | _ShardingPlaceholder]] = [] + for dim in range(ndim - 2): + strategies.append([_ShardingPlaceholder(dim)] * n_placements) + return strategies + + +def _get_ndim(tensor_meta: Any) -> int: + if not isinstance(tensor_meta, TensorMeta): + raise AssertionError(f"Expected TensorMeta, got {type(tensor_meta)}") + return len(tensor_meta.shape) + + +@register_single_dim_strategy( + [ + aten.cholesky.default, + aten.cholesky_inverse.default, + aten.linalg_matrix_exp.default, + aten.cholesky_solve.default, + aten.linalg_householder_product.default, + aten.linalg_solve_triangular.default, + aten.linalg_ldl_solve.default, + aten.linalg_lu_solve.default, + aten.ormqr.default, + aten.geqrf.default, + aten.linalg_cholesky_ex.default, + aten.linalg_eig.default, + aten.linalg_inv_ex.default, + aten.triangular_solve.default, + aten._linalg_det.default, + aten.linalg_ldl_factor_ex.default, + aten.linalg_lu.default, + aten.linalg_lu_factor_ex.default, + aten.lu_unpack.default, + aten._linalg_slogdet.default, + aten._linalg_solve_ex.default, + aten._linalg_check_errors.default, + ], + schema_info=RuntimeSchemaInfo(1), +) +def linalg_batch_dim_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + ndim = _get_ndim(args_schema[0]) + if op not in _LINALG_NUM_PLACEMENTS: + raise AssertionError(f"Expected op in _LINALG_NUM_PLACEMENTS, got {op}") + + n_placements = _LINALG_NUM_PLACEMENTS[op] + strategies = _linalg_batch_dim_strategies(ndim, n_placements=n_placements) + + if op == aten.linalg_solve_triangular.default: + # solve_triangular(A, B) -> result: linear in B + strategies.append([Partial(), Replicate(), Partial()]) + strategies.append([Partial("avg"), Replicate(), Partial("avg")]) + # A replicated, B sharded on batch dims (B may have more batch dims than A) + ndim_b = _get_ndim(args_schema[1]) + for dim in range(ndim_b - 2): + strategies.append( + [_ShardingPlaceholder(dim), Replicate(), _ShardingPlaceholder(dim)] + ) + elif op == aten.cholesky_solve.default: + # cholesky_solve(B, A) -> result (B is arg0) + strategies.append([Partial(), Partial(), Replicate()]) + elif op == aten.linalg_lu_solve.default: + # linalg_lu_solve(LU, pivots, B) -> result + strategies.append([Partial(), Replicate(), Replicate(), Partial()]) + elif op == aten.linalg_ldl_solve.default: + # linalg_ldl_solve(LD, pivots, B) -> result + strategies.append([Partial(), Replicate(), Replicate(), Partial()]) + elif op == aten.ormqr.default: + # ormqr(a, tau, C) -> result (linear in C) + strategies.append([Partial(), Replicate(), Replicate(), Partial()]) + elif op == aten._linalg_solve_ex.default: + # _linalg_solve_ex(A, B) -> (result, LU, pivots, info) + strategies.append( + [Partial(), Replicate(), Replicate(), Replicate(), Replicate(), Partial()] + ) + + return strategies + + +# linalg_pinv has optional tensor kwargs atol, rtol (scalar tensors when present). +# Schema: (Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian) -> Tensor +# When atol/rtol are None, num_inputs=1; when present, they add to num_inputs. +@register_single_dim_strategy( + [aten.linalg_pinv.atol_rtol_tensor], + schema_info=RuntimeSchemaInfo(1), +) +def linalg_pinv_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + ndim = _get_ndim(args_schema[0]) + # Count optional tensor kwargs that are actually present + extra_tensors = sum( + isinstance(kwargs_schema.get(k), TensorMeta) for k in ("atol", "rtol") + ) + strategies: list[list[Placement | _ShardingPlaceholder]] = [] + for dim in range(ndim - 2): + s: list[Placement | _ShardingPlaceholder] = [ + _ShardingPlaceholder(dim), + _ShardingPlaceholder(dim), + ] + # atol, rtol are scalar tensors — always Replicate + s.extend([Replicate()] * extra_tensors) + strategies.append(s) + return strategies + + +# linalg_cross is pointwise on every dim except the cross-product dim (which +# must be size 3). Shard on any other dim. +@register_single_dim_strategy( + [aten.linalg_cross.default], + schema_info=RuntimeSchemaInfo(1), +) +def linalg_cross_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + ndim = _get_ndim(args_schema[0]) + cross_dim = kwargs_schema.get("dim", -1) % ndim + strategies: list[list[Placement | _ShardingPlaceholder]] = [] + for dim in range(ndim): + if dim == cross_dim: + continue + strategies.append([_ShardingPlaceholder(dim)] * 3) + return strategies + + +# --------------------------------------------------------------------------- +# Interpolation / upsample / pooling ops +# +# These ops operate on spatial dims and are safely shardable on batch (dim 0) +# and channel (dim 1). grid_sampler is batch-only because the grid tensor has +# no channel dimension. +# --------------------------------------------------------------------------- + + +@register_single_dim_strategy( + [ + # Forward ops + aten.upsample_nearest1d.default, + aten.upsample_nearest2d.default, + aten.upsample_nearest3d.default, + aten._upsample_nearest_exact1d.default, + aten._upsample_nearest_exact2d.default, + aten._upsample_nearest_exact3d.default, + aten._upsample_bilinear2d_aa.default, + aten.upsample_bicubic2d.default, + aten.upsample_bilinear2d.default, + aten.upsample_linear1d.default, + aten.upsample_trilinear3d.default, + # Backward ops + aten.upsample_nearest1d_backward.default, + aten.upsample_nearest2d_backward.default, + aten.upsample_nearest3d_backward.default, + aten._upsample_nearest_exact1d_backward.default, + aten._upsample_nearest_exact2d_backward.default, + aten._upsample_nearest_exact3d_backward.default, + aten._upsample_bilinear2d_aa_backward.default, + aten.upsample_bicubic2d_backward.default, + aten.upsample_bilinear2d_backward.default, + aten.upsample_linear1d_backward.default, + aten.upsample_trilinear3d_backward.default, + ], + schema_info=RuntimeSchemaInfo(1), +) +def interp_upsample_1out_1in_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + # 1 output + 1 input = 2 placements; shard on batch (0) and channel (1) + # Upsample is a linear transformation so Partial(sum/avg) is valid. + return [ + [_ShardingPlaceholder(0)] * 2, + [_ShardingPlaceholder(1)] * 2, + [Partial("sum"), Partial("sum")], + [Partial("avg"), Partial("avg")], + ] + + +@register_single_dim_strategy( + [ + aten.max_unpool2d.default, + aten.max_unpool3d.default, + aten._adaptive_avg_pool2d_backward.default, + ], + schema_info=RuntimeSchemaInfo(1), +) +def interp_pool_1out_2in_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + # 1 output + 2 inputs = 3 placements; shard on batch (0) and channel (1) + return [ + [_ShardingPlaceholder(0)] * 3, + [_ShardingPlaceholder(1)] * 3, + ] + + +@register_single_dim_strategy( + [aten.max_pool2d_with_indices_backward.default], + schema_info=RuntimeSchemaInfo(1), +) +def pool_backward_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + # max_pool2d_with_indices_backward(grad_output, self, ..., indices) -> grad_input + # 1 output + 3 tensor inputs = 4 placements + # Order: [output, grad_output, self, indices] + input_meta = cast(TensorMeta, args_schema[0]) + strategies: list[list[Placement | _ShardingPlaceholder]] = [ + [_ShardingPlaceholder(0)] * 4, + ] + if len(input_meta.shape) >= 4: # batched: (N, C, H, W) + strategies.append([_ShardingPlaceholder(1)] * 4) + # The backward is linear in grad_output, so P(sum/avg) pass through. + # indices must be replicated (integer positions, not reducible). + # self is only used for shape, so replicate it too. + r = Replicate() + for reduce_op in ("sum", "avg"): + p = Partial(reduce_op) + strategies.append([p, p, r, r]) + return strategies + + +@register_single_dim_strategy( + [aten.grid_sampler_2d.default, aten.grid_sampler_3d.default], + schema_info=RuntimeSchemaInfo(1), +) +def grid_sampler_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + # grid_sampler_{2,3}d(input[N,C,...], grid[N,...,{2,3}]) -> output[N,C,...] + # grid has no channel dim, so only batch sharding applies to both inputs. + # Linear in input: P(sum/avg) on input with replicated grid is valid. + return [ + [_ShardingPlaceholder(0)] * 3, + [Partial("sum"), Partial("sum"), Replicate()], + [Partial("avg"), Partial("avg"), Replicate()], + ] + + +@register_single_dim_strategy( + [aten.grid_sampler_2d_backward.default, aten.grid_sampler_3d_backward.default], + schema_info=RuntimeSchemaInfo(1), +) +def grid_sampler_backward_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + # grid_sampler_{2,3}d_backward: 2 outputs (grad_input, grad_grid) + 3 inputs = 5 placements, batch-only + return [[_ShardingPlaceholder(0)] * 5] + + +def _adjust_group_norm_scalars( + input_specs: list[DTensorSpec], schema: OpSchema +) -> OpSchema: + """Adjust N, C, HxW scalar args in native_group_norm to local values. + + native_group_norm(input, weight?, bias?, N, C, HxW, group, eps) + The scalar args are derived from the global input shape by the Python frontend. + When the input is sharded, we recompute them from the local input shape. + """ + input_spec = input_specs[0] + if input_spec.tensor_meta is None: + raise AssertionError("input_spec must have tensor_meta") + local_shape, _ = compute_local_shape_and_global_offset( + input_spec.tensor_meta.shape, + input_spec.mesh, + input_spec.placements, + skip_offset=True, + ) + # N = local_shape[0], C = local_shape[1], HxW = product of remaining dims + n_local = local_shape[0] + c_local = local_shape[1] + hxw_local = 1 + for d in local_shape[2:]: + hxw_local *= d + args = list(schema.args_schema) + # Find scalar arg positions: first 1-3 args are tensors (input, weight?, bias?), + # then N, C, HxW, group, eps. Count tensor args to find the offset. + num_tensor_args = sum(isinstance(a, DTensorSpec) for a in args) + args[num_tensor_args] = n_local + args[num_tensor_args + 1] = c_local + args[num_tensor_args + 2] = hxw_local + return OpSchema(schema.op, tuple(args), schema.kwargs_schema) + + +# --------------------------------------------------------------------------- +# Normalization ops +# +# Batch norm reduces over batch (dim 0) + spatial dims (2+), keeping only +# channel (dim 1). Neither batch nor channel sharding is safe, so we fall +# back to replicate-only. +# +# Group norm reduces over (C/groups, spatial) within each group per sample. +# Batch dim (0) is safe to shard — each sample is independent. +# --------------------------------------------------------------------------- + +BATCH_NORM_3OUT_OPS = [ + aten.native_batch_norm.default, + aten._native_batch_norm_legit.default, + aten._native_batch_norm_legit.no_stats, + aten._native_batch_norm_legit_no_training.default, +] + +BATCH_NORM_4OUT_OPS = [ + aten._batch_norm_with_update.default, +] + + +@register_single_dim_strategy( + BATCH_NORM_3OUT_OPS + BATCH_NORM_4OUT_OPS, + schema_info=RuntimeSchemaInfo(1), +) +def batch_norm_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + # Batch norm normalizes per-channel (reduces over batch + spatial dims), + # so channel-dim sharding is valid: each shard processes independent channels. + # Unlike group_norm, batch_norm infers shapes from tensors (no scalar N/C/HxW). + num_outputs = 4 if op in BATCH_NORM_4OUT_OPS else 3 + num_tensor_inputs = sum(isinstance(a, TensorMeta) for a in args_schema) + # output [N,C,*] shards on dim 1; save_mean, save_invstd [C] shard on dim 0 + rule: list[Placement | _ShardingPlaceholder] = [_ShardingPlaceholder(1)] + rule.extend([_ShardingPlaceholder(0)] * 2) # save_mean, save_invstd + if num_outputs == 4: + rule.append(Replicate()) # reserve: opaque cuDNN workspace + # input [N,C,*] shards on dim 1; weight, bias, running_mean, running_var [C] on dim 0 + rule.append(_ShardingPlaceholder(1)) # input + rule.extend([_ShardingPlaceholder(0)] * (num_tensor_inputs - 1)) + return [rule] + + +@register_single_dim_strategy( + [aten.native_group_norm.default], + schema_info=RuntimeSchemaInfo(1), +) +def group_norm_strategy( + op: torch._ops.OpOverload, + args_schema: tuple[Any, ...], + kwargs_schema: dict[str, Any], +) -> list[list[Placement | _ShardingPlaceholder]]: + # native_group_norm(input, weight?, bias?, N, C, HxW, group, eps) -> (out, mean, rstd) + # Batch dim (0) is independent. The scalar N/C/HxW args are adjusted to local + # values by _adjust_group_norm_scalars in the sharding propagation layer. + num_tensor_inputs = sum(isinstance(a, TensorMeta) for a in args_schema) + # 3 outputs + input all shard on batch dim + placements: list[Placement | _ShardingPlaceholder] = [_ShardingPlaceholder(0)] * 4 + # weight and bias (if present) must be Replicate + placements.extend([Replicate()] * (num_tensor_inputs - 1)) + return [placements] + + +# Register scalar shape adjuster for group_norm so the sharding propagator +# rewrites the N/C/HxW args to local values when the input is sharded. +from torch.distributed.tensor._api import DTensor + + +DTensor._op_dispatcher.sharding_propagator.op_to_scalar_shape_adjuster[ + aten.native_group_norm.default +] = _adjust_group_norm_scalars diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index b4ce70938ec94..b1bda0d15fdc2 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -36,6 +36,7 @@ compute_local_stride, ) from torch.distributed.tensor.placement_types import ( + _StridedShard, Partial, Placement, Replicate, @@ -62,10 +63,16 @@ def transpose_strategy(op_schema: OpSchema) -> OpStrategy: if ndim <= 1: output_placements = list(input_spec.placements) else: - output_placements = [ - Shard(1 - p.dim) if isinstance(p, Shard) else p - for p in input_spec.placements - ] + output_placements: list[Placement] = [] + for p in input_spec.placements: + if isinstance(p, _StridedShard): + output_placements.append( + _StridedShard(1 - p.dim, split_factor=p.split_factor) + ) + elif isinstance(p, Shard): + output_placements.append(Shard(1 - p.dim)) + else: + output_placements.append(p) transpose_strategy = OpSpec( output_specs=DTensorSpec( mesh=input_strategy.mesh, @@ -205,6 +212,7 @@ def _scaled_mm_scale_placement( if data_placement.dim == contracting_dim: return None return _ShardingPlaceholder(0) + # NOTE: isinstance(_, Shard) does not match _StridedShard; see _is_shard_like(). elif isinstance(data_placement, Shard): if data_placement.dim == contracting_dim: return None diff --git a/torch/distributed/tensor/_ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py index 1e8c5b28a56e0..dfc253eff8e86 100644 --- a/torch/distributed/tensor/_ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -1,61 +1,20 @@ # Copyright (c) Meta Platforms, Inc. and affiliates -import functools -from collections.abc import Callable, Sequence -from typing import cast +from collections.abc import Callable import torch from torch._ops import OpOverload -from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta -from torch.distributed.tensor._op_schema import ( - ArgsType, - KwargsType, - OpSchema, - OpSpec, - OpStrategy, - RuntimeSchemaInfo, - StrategyType, - TupleStrategy, -) -from torch.distributed.tensor._ops._math_ops import _NormPartial +from torch.distributed.tensor._dtensor_spec import TensorMeta +from torch.distributed.tensor._op_schema import ArgsType, KwargsType, RuntimeSchemaInfo from torch.distributed.tensor._ops.single_dim_strategy import ( _ShardingPlaceholder, register_single_dim_strategy, ) -from torch.distributed.tensor._ops.utils import ( - generate_redistribute_costs, - infer_broadcast_dims_map, - map_placements_after_broadcast, - normalize_dim, - register_op_strategy, -) -from torch.distributed.tensor.placement_types import ( - _StridedShard, - Partial, - Placement, - Replicate, - Shard, -) -from torch.types import _Number -from torch.utils._typing_utils import not_none +from torch.distributed.tensor._ops.utils import infer_broadcast_dims_map +from torch.distributed.tensor.placement_types import Partial, Placement, Replicate aten = torch.ops.aten prims = torch.ops.prims -# leave the remaining pointwise_ops list here for convenience, -# Below ops are some pointwise ops that are yet to be supported, -# they might not be a complete list. -# pointwise_ops = [ -# "fake_quantize_per_channel_affine", -# "fake_quantize_per_tensor_affine", -# "floor_divide", # floor_divide is deprecated -# "frexp", # multiple output pointwise op, need to add support -# "gradient", # need investigation on this op -# "imag", # complex data type only -# "quantized_batch_norm", -# "quantized_max_pool1d", -# "quantized_max_pool2d", -# "real", # complex data type only -# ] # Linear pointwise ops, split by linearity type. unary_linear_ops = [aten.to.dtype] @@ -66,7 +25,12 @@ def _common_pointwise_single_dim_strategy( ) -> Callable[ [OpOverload, ArgsType, KwargsType], list[list[Placement | _ShardingPlaceholder]] ]: - """Factory for single-dim strategies that add partial placement rules.""" + """Factory for single-dim strategies that add partial placement rules. + + Returns strategies shaped [output, *args] only. Tensor kwarg placements + (e.g. ``out``, ``lr``) are appended by the wrapper in + ``_register_single_dim_pointwise``. + """ def strategy( op: OpOverload, @@ -116,46 +80,67 @@ def strategy( return strategy +def _is_list_op(op: OpOverload) -> bool: + """Returns True if op is a foreach, amp_foreach, or fused op.""" + name = op.name() + return name.startswith(("aten::_foreach_", "aten::_amp_foreach_", "aten::_fused_")) + + +# The state_steps arg of fused adam / adamw is a Replicate scalar tensor, which will be put on +# the compute_mesh of an op across all parameter groups, even when not all parameter groups +# are on the same device mesh. This idx will help avoid hitting exceptions or unnecessary +# redistribute during sharding propagation. +_FUSED_OP_SCALAR_IDX = 5 + +# Ops registered with extra Partial rules; populated by _register_single_dim_pointwise +# when partial_extra_rules is not None, to avoid double-registration from tag discovery. +_specially_registered_ops: set[OpOverload] = set() + + def _register_single_dim_pointwise( op: OpOverload, partial_extra_rules: list[list[Placement]] | None = None, static_argnum: int = 0, ) -> None: - strategy_fn = _common_pointwise_single_dim_strategy( + if partial_extra_rules is not None: + _specially_registered_ops.add(op) + inner_fn = _common_pointwise_single_dim_strategy( partial_extra_rules=partial_extra_rules # pyrefly: ignore[bad-argument-type] ) - # For .out ops, append output placement as the out kwarg placement. - # Strategy functions author [output, *args] without kwargs. The out tensor - # must match the output placement, so we duplicate strategy[0] (output). - # This makes strategies [output, *args, out_kwarg] so _get_num_tensor_inputs - # (which counts the out kwarg) computes num_outputs correctly. - if "out" in op._schema.overload_name: - inner_fn = strategy_fn - - def _out_wrapper( - op: OpOverload, - args: ArgsType, - kwargs: KwargsType, - _fn: Callable = inner_fn, - ) -> list[list[Placement | _ShardingPlaceholder]]: - strategies = _fn(op, args, kwargs) - n_tensor_args = sum(1 for a in args if isinstance(a, TensorMeta)) - n_outputs = sum(1 for r in op._schema.returns if "Tensor" in str(r.type)) - for s in strategies: - if len(s) != n_outputs + n_tensor_args: - raise AssertionError( - f"Strategy length {len(s)} != expected {n_outputs + n_tensor_args} " - f"({n_outputs} output(s) + {n_tensor_args} args) for {op}. " - f"out kwarg will be appended by infra." - ) - return [s + [s[0]] for s in strategies] - strategy_fn = _out_wrapper + # Wrap to append tensor kwarg placements in schema declaration order. + # out = output placement (s[0]); everything else (e.g. lr) = Replicate. + # TODO: move kwargs handling upstream if this works + def strategy_fn( + op: OpOverload, + args: ArgsType, + kwargs: KwargsType, + _fn: Callable = inner_fn, + ) -> list[list[Placement | _ShardingPlaceholder]]: + strategies = _fn(op, args, kwargs) + kw_names = [k for k, v in kwargs.items() if isinstance(v, TensorMeta)] + if not kw_names: + return strategies + return [ + s + [s[0] if name == "out" else Replicate() for name in kw_names] + for s in strategies + ] + + if _is_list_op(op): + schema_info = RuntimeSchemaInfo(needs_pytree=True) + else: + schema_info = RuntimeSchemaInfo(static_argnum, static_kwargkey=["out"]) + # Fused ops (e.g. _fused_adam_) have state_steps on a potentially different + # mesh; see the note in expand_to_full_mesh_op_strategy for details. + different_mesh_args: list[int] | None = None + if op.name().startswith("aten::_fused_"): + different_mesh_args = [_FUSED_OP_SCALAR_IDX] register_single_dim_strategy( op, - schema_info=RuntimeSchemaInfo(static_argnum, static_kwargkey=["out"]), + schema_info=schema_info, allow_uneven_sharding=True, allow_unbacked_sharding=True, + different_mesh_args=different_mesh_args, )(strategy_fn) @@ -171,6 +156,11 @@ def _out_wrapper( aten.sub.Tensor, aten.sub_.Tensor, aten.sub.out, + # foreach variants + aten._foreach_add.List, + aten._foreach_add_.List, + aten._foreach_sub.List, + aten._foreach_sub_.List, ] _BINARY_ADDITIVE_RULES: list[list[Placement]] = [ @@ -191,8 +181,26 @@ def _out_wrapper( _register_single_dim_pointwise(op, _BINARY_ADDITIVE_RULES) # mul: partials propagate through either arg. div: only through numerator. -binary_mul_ops = [aten.mul.Tensor, aten.mul_.Tensor, aten.mul.out] -binary_div_ops = [aten.div.Tensor, aten.div_.Tensor, aten.div.out] +binary_mul_ops = [ + aten.mul.Tensor, + aten.mul_.Tensor, + aten.mul.out, + # foreach variants + aten._foreach_mul.List, + aten._foreach_mul_.List, + aten._foreach_mul.Tensor, + aten._foreach_mul_.Tensor, +] +binary_div_ops = [ + aten.div.Tensor, + aten.div_.Tensor, + aten.div.out, + # foreach variants + aten._foreach_div.List, + aten._foreach_div_.List, + aten._foreach_div.Tensor, + aten._foreach_div_.Tensor, +] # _UNARY_LINEAR_RULES handles the scalar promotion case: Python's __mul__/__truediv__ # promote scalars to 0-dim tensors, so aten.mul.Scalar dispatches as aten.mul.Tensor @@ -220,12 +228,20 @@ def _out_wrapper( aten.div_.Scalar, aten.mul.Scalar, aten.mul_.Scalar, + # foreach variants + aten._foreach_div.Scalar, + aten._foreach_div_.Scalar, + aten._foreach_mul.Scalar, + aten._foreach_mul_.Scalar, + aten._foreach_div.ScalarList, + aten._foreach_div_.ScalarList, + aten._foreach_mul.ScalarList, + aten._foreach_mul_.ScalarList, ] for op in scalar_linear_ops: _register_single_dim_pointwise(op, _UNARY_LINEAR_RULES, static_argnum=1) -neg_ops = [aten.neg.default, aten.neg_.default] # Non-decreasing unary ops: f(max(a,b)) = max(f(a),f(b)). # Only ops that are non-decreasing on their ENTIRE domain belong here. @@ -293,6 +309,18 @@ def _out_wrapper( aten.nan_to_num.default, aten.nan_to_num_.default, aten.nan_to_num.out, + # hardshrink: x if |x|>lambd else 0. Non-decreasing on entire domain. + aten.hardshrink.default, + # I1(x) is monotonically non-decreasing for all real x. + aten.special_modified_bessel_i1.default, + # threshold(x, t, v): x if x > t else v. Non-decreasing for v <= t (the + # common case, including the default v=0, t=0). + aten.threshold.default, + # foreach variants + aten._foreach_exp.default, + aten._foreach_exp_.default, + aten._foreach_clamp_max_.Scalar, + aten._foreach_clamp_min_.Scalar, ] _NON_DECREASING_RULES: list[list[Placement]] = [ @@ -321,14 +349,47 @@ def _out_wrapper( for op in non_increasing_unary_ops: _register_single_dim_pointwise(op, _NON_INCREASING_RULES) +# Bessel K functions are strictly decreasing for x > 0 but undefined at x <= 0. +# Only P(min)->P(max) is safe: P(min) offsets add positive values to the +# non-holding rank, keeping all inputs positive. P(max) offsets subtract, +# which can push inputs to x <= 0 producing NaN. +_POSITIVE_DOMAIN_NON_INCREASING_RULES: list[list[Placement]] = [ + [Partial("max"), Partial("min")], +] + +for op in [ + aten.special_modified_bessel_k0.default, + aten.special_modified_bessel_k1.default, + aten.special_scaled_modified_bessel_k0.default, + aten.special_scaled_modified_bessel_k1.default, +]: + _register_single_dim_pointwise(op, _POSITIVE_DOMAIN_NON_INCREASING_RULES) + # neg is linear: -(A1 + A2) = -A1 + -A2 -neg_ops = [aten.neg.default, aten.neg_.default, aten.neg.out] +neg_ops = [ + aten.neg.default, + aten.neg_.default, + aten.neg.out, + # foreach variants + aten._foreach_neg.default, + aten._foreach_neg_.default, +] _NEG_RULES: list[list[Placement]] = _UNARY_LINEAR_RULES + _NON_INCREASING_RULES for op in neg_ops: _register_single_dim_pointwise(op, _NEG_RULES) +# xlog1py(x, y) = x * log1p(y). Linear in x with y replicated: +# (a+b)*log1p(y) = a*log1p(y) + b*log1p(y). +_XLOG1PY_RULES: list[list[Placement]] = [ + [Partial("sum"), Partial("sum"), Replicate()], + [Partial("avg"), Partial("avg"), Replicate()], +] + +for op in [aten.special_xlog1py.default, aten.special_xlog1py.other_scalar]: + _register_single_dim_pointwise(op, _XLOG1PY_RULES) + # All-partial-preserving unary ops: P(x)->P(x) for all x. # TODO: positive should be removed once CIA (Copy Is All) optimizes it away. @@ -383,6 +444,8 @@ def _out_wrapper( aten.maximum.default, aten.maximum.out, prims.fmax.default, + # foreach variants + aten._foreach_maximum_.List, ] _MONOTONE_MAX_PRESERVING_BINARY_BASE_RULES: list[list[Placement]] = [ @@ -412,302 +475,29 @@ def _out_wrapper( _register_single_dim_pointwise(op, _MONOTONE_MIN_PRESERVING_BINARY_BASE_RULES) -# The linear pointwise ops map, key is op, value is the type of linearity. -# Reconstructed from category lists for the existing registration path. -linear_pointwise_ops: dict[OpOverload, int] = { - aten.to.dtype: 0, - **dict.fromkeys(binary_additive_ops, 1), - **dict.fromkeys(binary_mul_ops, 2), - **dict.fromkeys(binary_div_ops, 2), - **dict.fromkeys(scalar_linear_ops, 0), - **dict.fromkeys(neg_ops, 0), -} - -pointwise_ops = [ - # please keep the entries below alphabetically sorted - aten.__ilshift__.Scalar, - aten.__ilshift__.Tensor, +# Ops that are pointwise for DTensor purposes but lack torch.Tag.pointwise. +# TODO(pianpwk): add torch.Tag.pointwise to these ops in native_functions.yaml +# so this list can be removed. +_extra_pointwise_ops: list[OpOverload] = [ aten.__irshift__.Scalar, aten.__irshift__.Tensor, - aten.__lshift__.Scalar, - aten.__lshift__.Tensor, - aten.__rshift__.Scalar, - aten.__rshift__.Tensor, aten._conj.default, - aten.abs.default, - aten.abs.out, aten.abs_.default, - aten.acos.default, - aten.acos.out, - aten.acos_.default, - aten.acosh.default, - aten.acosh.out, - aten.acosh_.default, - aten.add.Scalar, - aten.add_.Scalar, - aten.addcdiv.default, - aten.addcdiv.out, - aten.addcdiv_.default, - aten.addcmul.default, - aten.addcmul.out, - aten.addcmul_.default, - aten.angle.default, - aten.angle.out, - aten.asin.default, - aten.asin.out, - aten.asin_.default, - aten.atan2.default, - aten.atan2.out, - aten.atan2_.default, - aten.atanh.default, - aten.atanh.out, - aten.atanh_.default, - aten.bitwise_and.Scalar, - aten.bitwise_and.Scalar_Tensor, - aten.bitwise_and.Scalar_out, - aten.bitwise_and.Tensor, - aten.bitwise_and.Tensor_out, - aten.bitwise_and_.Scalar, - aten.bitwise_and_.Tensor, - aten.bitwise_left_shift.Scalar_Tensor, - aten.bitwise_left_shift.Tensor, - aten.bitwise_left_shift.Tensor_Scalar, - aten.bitwise_left_shift.Tensor_Scalar_out, - aten.bitwise_left_shift.Tensor_out, - aten.bitwise_left_shift_.Tensor, - aten.bitwise_left_shift_.Tensor_Scalar, - aten.bitwise_not.default, - aten.bitwise_not.out, - aten.bitwise_not_.default, - aten.bitwise_or.Scalar, - aten.bitwise_or.Scalar_Tensor, - aten.bitwise_or.Scalar_out, - aten.bitwise_or.Tensor, - aten.bitwise_or.Tensor_out, - aten.bitwise_or_.Scalar, - aten.bitwise_or_.Tensor, - aten.bitwise_right_shift.Scalar_Tensor, - aten.bitwise_right_shift.Tensor, - aten.bitwise_right_shift.Tensor_Scalar, - aten.bitwise_right_shift.Tensor_Scalar_out, - aten.bitwise_right_shift.Tensor_out, - aten.bitwise_right_shift_.Tensor, - aten.bitwise_right_shift_.Tensor_Scalar, - aten.bitwise_xor.Scalar, - aten.bitwise_xor.Scalar_Tensor, - aten.bitwise_xor.Scalar_out, - aten.bitwise_xor.Tensor, - aten.bitwise_xor.Tensor_out, - aten.bitwise_xor_.Scalar, - aten.bitwise_xor_.Tensor, - aten.clamp.default, - aten.clamp.Tensor, - aten.clamp.out, - aten.clamp_.default, - aten.clamp_.Tensor, - aten.clamp_min.default, - aten.clamp_max.default, - aten.clip.default, - aten.clip.out, - aten.clip_.default, - aten.conj_physical.default, - aten.conj_physical.out, - aten.conj_physical_.default, - aten.copysign.Scalar, - aten.copysign.Scalar_out, - aten.copysign.Tensor, - aten.copysign.out, aten.copysign_.Scalar, aten.copysign_.Tensor, - aten.cos.default, - aten.cos.out, - aten.cos_.default, - aten.cosh.default, - aten.cosh.out, - aten.cosh_.default, - aten.digamma.default, - aten.digamma.out, - aten.digamma_.default, - aten.div.Tensor_mode, - aten.div.out_mode, - aten.div_.Tensor_mode, - aten.eq.Tensor, - aten.eq.Tensor_out, - aten.eq.Scalar, - aten.eq.Scalar_out, - aten.erfinv.default, - aten.erfinv.out, - aten.erfinv_.default, + aten.exponential_.default, aten.float_power.Scalar, aten.float_power.Scalar_out, aten.float_power.Tensor_Scalar, aten.float_power.Tensor_Scalar_out, aten.float_power.Tensor_Tensor, aten.float_power.Tensor_Tensor_out, - aten.float_power_.Scalar, - aten.float_power_.Tensor, - aten.fmod.Scalar, - aten.fmod.Scalar_out, - aten.fmod.Tensor, - aten.fmod.Tensor_out, - aten.fmod_.Scalar, - aten.fmod_.Tensor, - aten.frac.default, - aten.frac.out, - aten.frac_.default, - aten.gcd.default, - aten.gcd.out, - aten.ge.Scalar, - aten.ge.Tensor, - aten.gelu.default, - aten.gt.Tensor, - aten.gt.Tensor_out, - aten.gt.Scalar, - aten.gt.Scalar_out, - aten.gt.Scalar, - aten.gt.Tensor, - aten.heaviside.default, - aten.heaviside.out, - aten.hypot.default, - aten.hypot.out, - aten.hypot_.default, - aten.i0.default, - aten.i0.out, - aten.i0_.default, - aten.igamma.default, - aten.igamma.out, - aten.igamma_.default, - aten.igammac.default, - aten.igammac.out, - aten.igammac_.default, - aten.isinf.default, - aten.isnan.default, - aten.isneginf.default, - aten.isneginf.out, - aten.isposinf.default, - aten.isposinf.out, - aten.ldexp.Tensor, - aten.ldexp.out, - aten.ldexp_.default, - aten.lt.Tensor, - aten.lt.Tensor_out, - aten.lt.Scalar, - aten.lt.Scalar_out, - aten.le.Scalar, - aten.le.Tensor, - aten.lerp.Scalar, - aten.lerp.Scalar_out, - aten.lerp.Tensor, - aten.lerp.Tensor_out, - aten.lerp_.Scalar, - aten.lerp_.Tensor, - aten.lgamma.default, - aten.lgamma.out, - aten.lgamma_.default, - aten.log.default, - aten.log.out, - aten.log10.default, - aten.log10.out, - aten.log10_.default, - aten.log1p.default, - aten.log1p.out, - aten.log1p_.default, - aten.log2.default, - aten.log2.out, - aten.log2_.default, - aten.log_.default, - aten.logical_and.default, - aten.logical_and.out, - aten.logical_and_.default, - aten.logical_not.default, - aten.logical_not.out, - aten.logical_not_.default, - aten.logical_or.default, - aten.logical_or.out, - aten.logical_or_.default, - aten.logical_xor.default, - aten.logical_xor.out, - aten.logical_xor_.default, - aten.logit.default, - aten.logit.out, - aten.logit_.default, - aten.masked_fill.Scalar, aten.masked_fill_.Scalar, - aten.mvlgamma.default, - aten.mvlgamma.out, - aten.mvlgamma_.default, - aten.native_dropout_backward.default, aten.native_dropout_backward.out, - aten.ne.Scalar, - aten.nextafter.default, - aten.nextafter.out, - aten.nextafter_.default, - aten.polygamma.default, - aten.polygamma.out, aten.polygamma_.default, - aten.pow.Scalar, - aten.pow.Scalar_out, - aten.pow.Tensor_Scalar, - aten.pow.Tensor_Scalar_out, - aten.pow.Tensor_Tensor, - aten.pow.Tensor_Tensor_out, - aten.pow_.Scalar, - aten.pow_.Tensor, - aten.reciprocal.default, - aten.reciprocal.out, - aten.reciprocal_.default, - aten.remainder.Scalar, - aten.remainder.Scalar_Tensor, - aten.remainder.Scalar_out, - aten.remainder.Tensor, - aten.remainder.Tensor_out, - aten.remainder_.Scalar, - aten.remainder_.Tensor, - aten.rsqrt.default, - aten.rsqrt.out, - aten.rsqrt_.default, - aten.rsub.Scalar, - aten.signbit.default, - aten.signbit.out, - aten.silu.default, - aten.silu.out, - aten.sin.default, - aten.sin.out, - aten.sin_.default, - aten.sinc.default, - aten.sinc.out, - aten.sinc_.default, - aten.sqrt.default, - aten.sqrt.out, - aten.sqrt_.default, - aten.square.default, - aten.square.out, - aten.square_.default, - aten.sub.Scalar, - aten.sub_.Scalar, - aten.tan.default, - aten.tan.out, - aten.tan_.default, - aten.true_divide.Tensor, - aten.where.self, + aten.rrelu_with_noise.default, aten.where.self_out, - aten.xlogy.OutScalar_Self, - aten.xlogy.OutScalar_Other, - aten.xlogy.OutTensor, - aten.xlogy.Scalar_Other, - aten.xlogy.Scalar_Self, - aten.xlogy.Tensor, aten.xlogy_.Scalar_Other, - aten.xlogy_.Tensor, - # backward point-wise ops - # please keep the entries below alphabetically sorted - aten.gelu_backward.default, - aten.sigmoid_backward.default, - aten.silu_backward.default, - aten.tanh_backward.default, - aten.threshold_backward.default, - # prims ops - # please keep the entries below alphabetically sorted prims.bessel_i0e.default, prims.bessel_i1.default, prims.bessel_i1e.default, @@ -715,379 +505,13 @@ def _out_wrapper( prims.bessel_j1.default, prims.div.default, prims.erfcx.default, - prims.gcd.default, prims.frexp.default, + prims.gcd.default, prims.ndtri.default, prims.ne.default, prims.spherical_bessel_j0.default, prims.zeta.default, -] - - -# Reconstruct the original linear_pointwise_ops dict for the existing registration path. -linear_pointwise_ops: dict[OpOverload, int] = { - **dict.fromkeys(unary_linear_ops, 0), - **dict.fromkeys(binary_additive_ops, 1), - **dict.fromkeys(binary_mul_ops, 2), - **dict.fromkeys(binary_div_ops, 2), - **dict.fromkeys(scalar_linear_ops, 0), - **dict.fromkeys(neg_ops, 0), -} - - -def pointwise_strategy( - op_schema: OpSchema, - linearity: int = -1, - preserve_partial: str | None = None, -) -> OpStrategy: - """Strategy for pointwise ops on the old registration path.""" - followed_strategy_index = -1 - max_shards = -1 - max_ndim = -1 - - if op_schema.is_inplace_op(): - # inplace op should follow the first arg strategy - followed_strategy = op_schema.args_schema[0] - followed_strategy_index = 0 - elif op_schema.is_out_variant_op(): - # out variant op should follow the out kwarg strategy - followed_strategy = op_schema.kwargs_schema["out"] - # out variant is technically a kwarg for the strategy to follow so it does not - # have an "index", we set it to a reasonably large number just to indicate it's - # not a valid index - followed_strategy_index = 100 - else: - # normal pointwise op, we choose to follow the arg with - # the max shards in case operands needs reshard - # in case of multiple operands with max shard, we take - # the one with the max number of dimensions - for idx, arg_strategy in enumerate(op_schema.args_schema): - if not isinstance(arg_strategy, OpStrategy): - continue - arg_max_shards = arg_strategy.max_num_shards() - arg_max_ndim = arg_strategy.ndim - if (arg_max_shards > max_shards) or ( - arg_max_shards == max_shards and arg_max_ndim > max_ndim - ): - followed_strategy_index = idx - max_shards = arg_max_shards - max_ndim = arg_max_ndim - followed_strategy = op_schema.args_schema[followed_strategy_index] - - if not isinstance(followed_strategy, OpStrategy): - raise AssertionError(f"no strategy to follow for {op_schema}!") - return common_pointwise_strategy( - op_schema.op, - op_schema.args_schema, - followed_strategy, - followed_strategy_index, - linearity, - preserve_partial=preserve_partial, - ) - - -def linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType: - """ - Linear pointwise operators can propagate pending reductions. - For example, c = add(a, b); if a is pending sum, then c will be - pending sum as well without any communication overhead. - - Note that: - 1. Only unary and binary operations are supported, out variant - ops are not supported. - 2. There're multiple types of linearity, refer to the doc of - common_pointwise_strategy for more details. - """ - linearity_type = linear_pointwise_ops.get(op_schema.op, -1) - return pointwise_strategy(op_schema, linearity=linearity_type) - - -def single_mesh_dim_pointwise_strategy( - op: OpOverload, - args_schema: ArgsType, - kwargs_schema: KwargsType, - linearity: int = -1, -) -> list[list[Placement | _ShardingPlaceholder]]: - return single_mesh_dim_common_pointwise_strategy(args_schema, linearity) - - -def single_mesh_dim_linear_pointwise_strategy( - linearity: int = -1, -) -> Callable[ - [OpOverload, ArgsType, KwargsType], list[list[Placement | _ShardingPlaceholder]] -]: - return functools.partial(single_mesh_dim_pointwise_strategy, linearity=linearity) - - -def single_mesh_dim_common_pointwise_strategy( - args_schema: ArgsType, - linearity: int = -1, - scalar_tensor_idx: int | None = None, -) -> list[list[Placement | _ShardingPlaceholder]]: - # TODO rename - tensor_arg_strategies: list[TensorMeta] = [ - arg for arg in args_schema if isinstance(arg, TensorMeta) - ] - common_shape = torch.broadcast_shapes( - *[arg.shape for arg in args_schema if isinstance(arg, TensorMeta)] - ) - placements_list: list[list[Placement | _ShardingPlaceholder]] = [] - for i in range(len(common_shape)): - # Shard output dim i, and then shard the corresponding arguments if they have a corresponding (non broadcast) dim - shard_placements: list[Placement | _ShardingPlaceholder] = [ - _ShardingPlaceholder(i) - ] - for arg in tensor_arg_strategies: - common_dim_to_arg_dim = infer_broadcast_dims_map(common_shape, arg.shape) - if common_dim_to_arg_dim[i] >= 0: - shard_placements.append(_ShardingPlaceholder(common_dim_to_arg_dim[i])) - else: - shard_placements.append(Replicate()) - - placements_list.append(shard_placements) - - if linearity == 0: - # unary op (e.g. to_copy), and also binary ops like mul.scalar - # input, output can be partial - if len(tensor_arg_strategies) != 1: - raise AssertionError("expected single tensor input for linearity==0 op") - placements_list.append([Partial("sum"), Partial("sum")]) - # TODO: do i need to check scalar_tensor_index and assign a replicate to that one, or do i omit a placement for it - # TODO: can mul.scalar work with avg or only sum? i think only sum works. common_pointwise_strategy seems - # to support both. - # TODO: also, i'll be replacing 'Partial(sum)' here with some kind of 'PartialPlaceholder', not yet designed - placements_list.append([Partial("avg"), Partial("avg")]) - - elif linearity == 1: - # binary add ops - # (A1 + B1) + (A2 + B2) == (A1 + A2) + (B1 + B2) - if len(tensor_arg_strategies) != 2: - raise AssertionError("expected two tensor inputs for linearity==1 op") - placements_list.append([Partial("sum"), Partial("sum"), Partial("sum")]) - elif linearity == 2: - # binary mul ops (2 tensor inputs) - # (A * B1) + (A * B2) == A * (B1 + B2) - if len(tensor_arg_strategies) != 2: - raise AssertionError("expected two tensor inputs for linearity==2 op") - placements_list.append([Partial("sum"), Partial("sum"), Replicate()]) - placements_list.append([Partial("sum"), Replicate(), Partial("sum")]) - - # TODO: handle scalar_tensor_idx - return placements_list - - -def copy_strategy(op_schema: OpSchema) -> StrategyType: - """ - Strategy for copy_ that preserves any Partial placement. - - copy_ simply copies data and should preserve whatever Partial placement - the destination has, regardless of the reduce_op type (sum, avg, max, min, etc.). - """ - return pointwise_strategy(op_schema, preserve_partial="all") - - -def common_pointwise_strategy( - op, - args_schema: Sequence[object], - followed_strategy: OpStrategy, - followed_strategy_index: int, - linearity: int = -1, - scalar_tensor_idx: int | None = None, - preserve_partial: str | None = None, -) -> OpStrategy: - """ - Common strategy for pointwise operations. - - Args: - args_schema: Input arguments schema - followed_strategy: Strategy to follow for output placement - followed_strategy_index: Index of the strategy being followed - linearity: depending on the operator, we support different types of linearity - -1: the operation does not support linearity - 0: the unary operation that supports linearity, output propagates partial. - 1: the binary operation supports add linearity, where it requires every operand - to be partial, output propagates partial. - 2: the binary operation supports multiplicative linearity, where it requires - the primary operand to be partial, and the other operands to be replicate, - output propagates partial. - scalar_tensor_idx: Index of the Replicate scalar tensor for which we allow the mesh - to be different from the mesh of followed_strategy - preserve_partial: If set, Partial placements with this reduce_op will be preserved - through the operation (e.g., "max" for torch.maximum, "min" for torch.minimum). - """ - # handle broadcasting - common_shape = torch.broadcast_shapes( - *[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)] - ) - pointwise_strategy = OpStrategy([]) - - for op_spec in followed_strategy.strategies: - spec_to_follow = op_spec.output_spec - - out_placements: list[Placement] = [] - for placement in spec_to_follow.placements: - if isinstance(placement, Shard | _StridedShard): - shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape)) - common_ndim = len(common_shape) - new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim - if isinstance(placement, _StridedShard): - out_placements.append( - _StridedShard( - new_shard_dim, split_factor=placement.split_factor - ) - ) - else: - out_placements.append(Shard(new_shard_dim)) - elif isinstance(placement, Partial): - is_scalar_arg = any(isinstance(arg, _Number) for arg in args_schema) - propagate_partial = False - - # ordering matters here since NormPartial is a subclass of Partial - if isinstance(placement, _NormPartial): - # explanation for args_schema[1] >= 0 can be found in summary - # https://github.com/pytorch/pytorch/pull/170035 - propagate_partial = ( - op in norm_partial_avoidable_redistribute_ops - and args_schema[1] >= 0 # pyre-ignore[unsupported-operation] - ) - - elif isinstance(placement, Partial): - propagate_partial = not ( - op in p_sum_scalar_redistribute_ops and is_scalar_arg - ) - - # Check if this partial type should be preserved - # preserve_partial="all" preserves any Partial type (used for copy_) - if preserve_partial == "all": - out_placements.append(placement) - elif preserve_partial is not None and placement.is_partial( - preserve_partial - ): - out_placements.append(placement) - # note that only partial-sum and partial-avg are supported for linearity - elif ( - linearity >= 0 - and (placement.is_partial("sum") or placement.is_partial("avg")) - and propagate_partial - ): - # propagate the partial placement - out_placements.append(placement) - else: - # clear the partial placement if op does not support linearity - # by default we just replicate the partial, need to see if this - # is optimal for all cases - out_placements.append(Replicate()) - else: - out_placements.append(placement) - - input_specs: list[DTensorSpec] = [] - redistribute_costs: list[list[float]] = [] - for input_idx, input_arg in enumerate(args_schema): - if isinstance(input_arg, OpStrategy): - input_arg_spec = input_arg.strategies[0].output_spec - - # sanity check that all args that follow the same strategy - # are on the same DeviceMesh - if input_arg.mesh != followed_strategy.mesh: - # For the scalar tensor arg in fused ops, do not follow followed_strategy; - # instead, let the input mesh and the Replicate placements propagate through. - if input_idx == scalar_tensor_idx: - if not all(p == Replicate() for p in input_arg_spec.placements): - raise AssertionError - input_arg_target_spec = DTensorSpec( - mesh=input_arg.mesh, - placements=input_arg_spec.placements, - tensor_meta=input_arg_spec.tensor_meta, - ) - input_specs.append(input_arg_target_spec) - redistribute_costs.append( - generate_redistribute_costs( - input_arg, input_arg_target_spec - ) - ) - continue - else: - raise ValueError( - f"Could not run pointwise computation across different mesh: " - f"Found {input_arg.mesh} and {followed_strategy.mesh}!" - ) - - # every arg follow the out_placements, but need to handle broadcasting - input_arg_dims_map = infer_broadcast_dims_map( - common_shape, input_arg_spec.shape - ) - - # Determine if this input should convert Partial to Replicate based on linearity - should_convert_partial = ( - linearity == 2 - and input_idx - != followed_strategy_index # Don't convert the "followed" strategy - ) - - # For preserve_partial ops, check if non-followed input has incompatible - # Partial type. If so, it must be redistributed to Replicate first. - if ( - preserve_partial is not None - and input_idx != followed_strategy_index - ): - for out_p, in_p in zip(out_placements, input_arg_spec.placements): - if ( - isinstance(out_p, Partial) - and isinstance(in_p, Partial) - and out_p != in_p - ): - should_convert_partial = True - break - - input_target_placements = map_placements_after_broadcast( - tuple(out_placements), - common_shape, - input_arg_dims_map, - partial_to_replicate=should_convert_partial, - ) - - input_arg_target_spec = DTensorSpec( - mesh=followed_strategy.mesh, - placements=input_target_placements, - tensor_meta=input_arg_spec.tensor_meta, - ) - input_specs.append(input_arg_target_spec) - redistribute_costs.append( - generate_redistribute_costs(input_arg, input_arg_target_spec) - ) - - pointwise_strategy.strategies.append( - OpSpec( - output_specs=DTensorSpec( - mesh=followed_strategy.mesh, - placements=tuple(out_placements), - ), - input_specs=input_specs, - redistribute_cost=redistribute_costs, - ) - ) - return pointwise_strategy - - -p_sum_scalar_redistribute_ops = { - aten.add.Tensor, - aten.add_.Tensor, - aten.sub.Tensor, - aten.sub_.Tensor, -} - -norm_partial_avoidable_redistribute_ops = { - aten.div.Scalar, - aten.div_.Scalar, - aten.mul.Scalar, - aten.mul_.Scalar, -} - -for op in pointwise_ops: - _register_single_dim_pointwise(op) - -# TODO: add all for_each ops -for_each_ops = [ + # foreach variants aten._foreach_abs.default, aten._foreach_abs_.default, aten._foreach_addcdiv_.Scalar, @@ -1097,147 +521,27 @@ def common_pointwise_strategy( aten._foreach_addcmul_.Scalar, aten._foreach_addcmul_.ScalarList, aten._foreach_addcmul_.Tensor, - aten._foreach_clamp_max_.Scalar, - aten._foreach_clamp_min_.Scalar, - aten._foreach_div_.List, - aten._foreach_div_.Scalar, - aten._foreach_div_.ScalarList, - aten._foreach_div_.Tensor, - aten._foreach_div.List, - aten._foreach_div.Scalar, - aten._foreach_div.ScalarList, - aten._foreach_div.Tensor, aten._foreach_lerp_.Scalar, - aten._foreach_maximum_.List, - aten._foreach_mul.Scalar, - aten._foreach_mul.ScalarList, - aten._foreach_mul.Tensor, - aten._foreach_mul.List, - aten._foreach_mul_.Scalar, - aten._foreach_mul_.ScalarList, - aten._foreach_mul_.Tensor, - aten._foreach_mul_.List, aten._foreach_pow.List, aten._foreach_pow.ScalarList, - aten._foreach_neg.default, - aten._foreach_neg_.default, aten._foreach_reciprocal_.default, aten._foreach_sub.Scalar, aten._foreach_sub_.Scalar, - aten._foreach_sub.List, - aten._foreach_sub_.List, aten._foreach_sub.ScalarList, aten._foreach_sub_.ScalarList, aten._foreach_sqrt.default, aten._foreach_sqrt_.default, aten._foreach_zero_.default, - aten._foreach_exp.default, - aten._foreach_exp_.default, aten._foreach_cos.default, aten._foreach_cos_.default, aten._foreach_log.default, aten._foreach_log_.default, aten._amp_foreach_non_finite_check_and_unscale_.default, -] - -for_each_linearity_ops = [ + # foreach linearity variants aten._foreach_add.Scalar, aten._foreach_add_.Scalar, aten._foreach_add_.ScalarList, - aten._foreach_add.List, - aten._foreach_add_.List, -] - - -def list_pointwise_strategy( - op_schema: OpSchema, linearity: bool = False -) -> StrategyType: - """ - Apply the pointwise strategy to the zipped arguments. For example, if we - run a foreach add of two lists l1 and l2, then we apply the pointwise - strategy on each pair (l1[i], l2[i]). If the first argument is a list but - the second (or later) one is a tensor, then we broadcast the tensor by - replicating it into a list with the length of the first argument. - - Args: - mesh (DeviceMesh): device mesh for pointwise ops - op_schema (OpSchema): schema of the operator to generate strategy for - linearity (bool): specify whether op(a) + op(b) = op(a + b) - - Returns: - OpStrategy: generated strategy - """ - - def args_tuple_strategies( - args_schema: tuple[object, ...], - ) -> list[TupleStrategy | None]: - first_arg = args_schema[0] - if not isinstance(first_arg, TupleStrategy): - raise AssertionError - strategy_len = len(first_arg.children) - tuple_strategies: list[TupleStrategy | None] = [] - for arg_idx, arg in enumerate(args_schema): - if isinstance(arg, TupleStrategy): - # every tuple strategy should have the same length - if len(arg.children) != strategy_len: - raise AssertionError - tuple_strategies.append(arg) - elif isinstance(arg, OpStrategy): - if arg_idx > 0: # implicitly broadcast - tuple_strategies.append( - TupleStrategy([arg for _ in range(strategy_len)]) - ) - else: - raise RuntimeError( - f"list op only supports tuple strategy! {op_schema}" - ) - else: - # insert None as placeholder so that the idx of arg is kept - tuple_strategies.append(None) - return tuple_strategies - - args_strategies = args_tuple_strategies(op_schema.args_schema) - follow_strategy: TupleStrategy = not_none(args_strategies[0]) - list_strategy: list[OpStrategy] = [] - - for child_idx, child_strtgy in enumerate(follow_strategy.children): - if not isinstance(child_strtgy, OpStrategy): - raise AssertionError - args_schema: list[OpStrategy | None] = [ - cast(OpStrategy, arg_strategy.children[child_idx]) if arg_strategy else None - for arg_strategy in args_strategies - ] - pointwise_strategy: OpStrategy = common_pointwise_strategy( - op_schema.op, - args_schema, - child_strtgy, - linearity, - scalar_tensor_idx=( - _FUSED_OP_SCALAR_IDX if op_schema.op in fused_ops else None - ), - ) - list_strategy.append(pointwise_strategy) - return TupleStrategy(list_strategy) - - -def list_linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType: - """ - for each list op stratgy that supports linearity - """ - return list_pointwise_strategy(op_schema, linearity=True) - - -for op in for_each_ops: - register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( - list_pointwise_strategy - ) - -for op in for_each_linearity_ops: - register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( - list_linear_pointwise_strategy - ) - -fused_ops = [ + # fused optimizer ops aten._fused_adam_.default, aten._fused_adam.default, aten._fused_adam.tensor_lr, @@ -1249,16 +553,31 @@ def list_linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType: ] -# The state_steps arg of fused adam / adamw is a Replicate scalar tensor, which will be put on -# the compute_mesh of an op across all parameter groups, even when not all parameter groups -# are on the same device mesh. This idx will help avoid hitting exceptions or unnecessary -# redistribute during sharding propagation. -_FUSED_OP_SCALAR_IDX = 5 +def _get_pointwise_ops_from_tag() -> list[OpOverload]: + """ + Auto-discover pointwise ops via torch.Tag.pointwise, from ops.aten, ops.prims. + """ + ops = [] + for ns in [torch.ops.aten, torch.ops.prims]: + for attr_name in dir(ns): + attr = getattr(ns, attr_name) + if isinstance(attr, torch._ops.OpOverloadPacket): + for overload_name in attr.overloads(): + op = getattr(attr, overload_name) + if torch.Tag.pointwise in op.tags: + ops.append(op) + return ops -for op in fused_ops: - register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( - list_pointwise_strategy - ) + +pointwise_ops = [ + op + for op in _get_pointwise_ops_from_tag() + _extra_pointwise_ops + if op not in _specially_registered_ops +] + + +for op in pointwise_ops: + _register_single_dim_pointwise(op) def register_inductor_prims() -> None: diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index 6da1e17db361b..a15cf411a44f4 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates -from collections.abc import Sequence, Sized +from collections.abc import Callable, Sequence, Sized from typing import cast import torch @@ -13,14 +13,12 @@ OpSchema, OpSpec, OpStrategy, - OutputSharding, PlacementList, RuntimeSchemaInfo, StrategyType, TensorMeta, TupleStrategy, ) -from torch.distributed.tensor._ops._common_rules import pointwise_rule from torch.distributed.tensor._ops.single_dim_strategy import ( _ShardingPlaceholder, register_single_dim_strategy, @@ -32,11 +30,11 @@ is_tensor_partial, normalize_dim, register_op_strategy, - register_prop_rule, shift_shard_dims_after_insert, shift_shard_dims_after_remove, ) from torch.distributed.tensor.placement_types import ( + _is_shard_like, _MaskPartial, Partial, Placement, @@ -102,9 +100,41 @@ def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType: )(propagate_single_input_strategy) -register_op_strategy( - aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) -)(propagate_single_input_strategy) +def _partial_needs_reduce_for_dtype_cast( + reduce_op: str, + src_dtype: torch.dtype, + target_dtype: torch.dtype | None, +) -> bool: + """Return True when reduce_op does not commute with the dtype cast.""" + if target_dtype is None or src_dtype == target_dtype: + return False + if target_dtype == torch.bool: + return True + if reduce_op in ("max", "min"): + return False + return src_dtype.is_floating_point and not target_dtype.is_floating_point + + +@register_single_dim_strategy( + aten._to_copy.default, + schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]), + allow_unbacked_sharding=True, + allow_uneven_sharding=True, +) +def _to_copy_single_dim_strategy( + op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType +) -> list[list[Placement | _ShardingPlaceholder]]: + input_meta = cast(TensorMeta, args_schema[0]) + src_dtype = input_meta.dtype + target_dtype = cast(torch.dtype | None, kwargs_schema.get("dtype", None)) + + strategies: list[list[Placement | _ShardingPlaceholder]] = [] + for dim in range(len(input_meta.shape)): + strategies.append([_ShardingPlaceholder(dim), _ShardingPlaceholder(dim)]) + for reduce_op in Partial.ALL_REDUCE_OPS: + if not _partial_needs_reduce_for_dtype_cast(reduce_op, src_dtype, target_dtype): + strategies.append([Partial(reduce_op), Partial(reduce_op)]) + return strategies @register_op_strategy( @@ -252,7 +282,18 @@ def new_factory_strategy(op_schema: OpSchema) -> StrategyType: ) ) - if tuple(input_shape) == tuple(output_shape) and input_spec.is_sharded(): + # Sharded inputs always propagate. Uninitialized factories (new_empty*) + # also propagate Partial — the memory is about to be overwritten, so the + # placement just needs to match the source of the subsequent write + # (e.g., autograd's clone_obey_contract: new_empty_strided + copy_). + # Initialized factories (new_zeros/ones/full) keep Replicate to avoid + # incorrect values after Partial reduction (e.g. ones * world_size). + is_uninitialized_factory = op_schema.op in ( + aten.new_empty.default, + aten.new_empty_strided.default, + ) + can_propagate_placement = input_spec.is_sharded() or is_uninitialized_factory + if tuple(input_shape) == tuple(output_shape) and can_propagate_placement: new_factory_strategy.strategies.append( OpSpec( output_specs=input_spec, @@ -469,7 +510,7 @@ def slice_backward_rules(op_schema: OpSchema) -> OpStrategy: new_placements: list[Placement] = [] for placement in output_spec.placements: # Redistribute to replicate only if the dim is sharded and matches the slice dim - if isinstance(placement, Shard) and placement.dim == dim: + if _is_shard_like(placement) and placement.dim == dim: new_placements.append(Replicate()) else: new_placements.append(placement) @@ -487,7 +528,7 @@ def unshard_tensor_dim( ) -> tuple[Placement, ...]: """Disallow the given tensor dimension to be sharded.""" return tuple( - p if (not isinstance(p, Shard) or p.dim != dim) else Replicate() + p if (not _is_shard_like(p) or p.dim != dim) else Replicate() for p in placements ) @@ -496,10 +537,8 @@ def replicate_tensor_dim( placements: Sequence[Placement], dim: int ) -> tuple[Placement, ...]: """Force the given tensor dimension to be replicated.""" - # Not using p.is_shard() to avoid mypy complain about Placement not having - # attribute dim. return tuple( - Replicate() if p.is_partial() or isinstance(p, Shard) and p.dim == dim else p + Replicate() if p.is_partial() or (_is_shard_like(p) and p.dim == dim) else p for p in placements ) @@ -570,6 +609,70 @@ def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType: return slice_scatter_strategy +@register_single_dim_strategy( + [aten.select_scatter.default], + schema_info=RuntimeSchemaInfo(1), +) +def select_scatter_single_dim_strategy( + op: OpOverload, + args_schema: ArgsType, + kwargs_schema: KwargsType, +) -> list[list[Placement | _ShardingPlaceholder]]: + input_meta = args_schema[0] + if not isinstance(input_meta, TensorMeta): + raise AssertionError(f"Expected TensorMeta, got {type(input_meta)}") + ndim = len(input_meta.shape) + dim = normalize_dim(cast(int, args_schema[2]), ndim) + # [output, self, src] — src has the select dim removed + strategies: list[list[Placement | _ShardingPlaceholder]] = [] + for d in range(ndim): + if d == dim: + continue + strategies.append( + [ + _ShardingPlaceholder(d), + _ShardingPlaceholder(d), + _ShardingPlaceholder(d if d < dim else d - 1), + ] + ) + return strategies + + +@register_single_dim_strategy( + [aten.diagonal_scatter.default], + schema_info=RuntimeSchemaInfo(1), +) +def diagonal_scatter_single_dim_strategy( + op: OpOverload, + args_schema: ArgsType, + kwargs_schema: KwargsType, +) -> list[list[Placement | _ShardingPlaceholder]]: + input_meta = args_schema[0] + if not isinstance(input_meta, TensorMeta): + raise AssertionError(f"Expected TensorMeta, got {type(input_meta)}") + ndim = len(input_meta.shape) + # schema: (self, src, offset=0, dim1=0, dim2=1) + dim1 = cast(int, args_schema[3]) if len(args_schema) > 3 else 0 + dim2 = cast(int, args_schema[4]) if len(args_schema) > 4 else 1 + dim1 = normalize_dim(dim1, ndim) + dim2 = normalize_dim(dim2, ndim) + min_d, max_d = min(dim1, dim2), max(dim1, dim2) + # [output, self, src] — src has dim1/dim2 removed and diagonal appended + strategies: list[list[Placement | _ShardingPlaceholder]] = [] + for d in range(ndim): + if d in (dim1, dim2): + continue + removed = (1 if d > min_d else 0) + (1 if d > max_d else 0) + strategies.append( + [ + _ShardingPlaceholder(d), + _ShardingPlaceholder(d), + _ShardingPlaceholder(d - removed), + ] + ) + return strategies + + @register_op_strategy(aten._local_scalar_dense.default) def replica_only_strategy(op_schema: OpSchema) -> StrategyType: """Only allow replication on the input/output.""" @@ -711,11 +814,13 @@ def merge_placement( # check each placement for the current arg placement # to see if we want to merge/adjust the placement to follow # the priority: Partial -> Shard -> Replicate + # _StridedShard.__eq__ compares both dim and split_factor, + # so two _StridedShard with different split_factor won't match here. if cur_placement == new_placement: return cur_placement if cur_placement.is_partial(): - if new_placement.is_shard(): + if _is_shard_like(new_placement): # follow new placement return new_placement elif new_placement.is_partial(): @@ -724,8 +829,8 @@ def merge_placement( else: # follow partial return cur_placement - elif cur_placement.is_shard(): - if new_placement.is_shard(): + elif _is_shard_like(cur_placement): + if _is_shard_like(new_placement): # cur/new placement are different sharding (i.e. different shard dim) # currently fallback to replicate all args return Replicate() @@ -963,7 +1068,88 @@ def index_select_single_dim_strategy( @register_single_dim_strategy( - [aten.index_put.default, aten._index_put_impl_.default], + aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True) +) +def index_single_dim_strategy( + op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType +) -> list[list[Placement | _ShardingPlaceholder]]: + values_meta, multi_indices_meta = args_schema + if not isinstance(values_meta, TensorMeta): + raise AssertionError(f"Expected TensorMeta, got {type(values_meta)}") + if not isinstance(multi_indices_meta, (list, tuple)): + raise AssertionError(f"Expected list or tuple, got {type(multi_indices_meta)}") + + indexed_dims = [i for i, idx in enumerate(multi_indices_meta) if idx is not None] + non_indexed_dims = [ + i for i in range(len(values_meta.shape)) if i not in set(indexed_dims) + ] + + index_metas = [idx for idx in multi_indices_meta if idx is not None] + if not all(isinstance(m, TensorMeta) for m in index_metas): + raise AssertionError("Expected all index metas to be TensorMeta") + broadcast_ndim = max(len(m.shape) for m in index_metas) + num_indices = len(indexed_dims) + + # Determine where index output dims are inserted in the result + all_consecutive = all( + indexed_dims[i + 1] - indexed_dims[i] == 1 for i in range(len(indexed_dims) - 1) + ) + insert_dim = indexed_dims[0] if all_consecutive else 0 + + def values_dim_to_output_dim(d: int) -> int: + if d < insert_dim: + return d + return d + broadcast_ndim - sum(1 for idx_dim in indexed_dims if d > idx_dim) + + strategies: list[list[Placement | _ShardingPlaceholder]] = [] + + # Shard values on a non-indexed dim, all indices replicated + for d in non_indexed_dims: + out_dim = values_dim_to_output_dim(d) + rule: list[Placement | _ShardingPlaceholder] = [_ShardingPlaceholder(out_dim)] + rule.append(_ShardingPlaceholder(d)) + rule.extend([Replicate()] * num_indices) + strategies.append(rule) + + # Shard indices on the same broadcast dim. Each index tensor may + # have a different ndim, so we map broadcast dim → tensor dim via + # left-padding. Tensors with size 1 on that dim are replicated + # (broadcast semantics). + for bd in range(broadcast_ndim): + per_tensor: list[tuple[int, int]] = [] # (tensor_dim, size) + for m in index_metas: + offset = broadcast_ndim - len(m.shape) + if bd < offset: + per_tensor.append((-1, 1)) # implicit broadcast + else: + td = bd - offset + per_tensor.append((td, m.shape[td])) + if all(s == 1 for _, s in per_tensor): + continue # all broadcast-only, skip + out_dim = bd + insert_dim + rule: list[Placement | _ShardingPlaceholder] = [_ShardingPlaceholder(out_dim)] + rule.append(Replicate()) + for td, s in per_tensor: + if s > 1: + rule.append(_ShardingPlaceholder(td)) + else: + rule.append(Replicate()) + strategies.append(rule) + + # Partial passthrough from values + for reduce_op in Partial.LINEAR_REDUCE_OPS: + rule: list[Placement | _ShardingPlaceholder] = [ + Partial(reduce_op), + Partial(reduce_op), + ] + rule.extend([Replicate()] * num_indices) + strategies.append(rule) + + return strategies + + +@register_single_dim_strategy( + [aten.index_put.default, aten.index_put_.default, aten._index_put_impl_.default], schema_info=RuntimeSchemaInfo(needs_pytree=True), ) def index_put_single_dim_strategy( @@ -986,9 +1172,10 @@ def index_put_single_dim_strategy( serves as an indexing coordinate into self. Each coordinate selects a tensor element, or a slice (if non-indexed dims exist). - values is a tensor that's broadcastable to (*broadcasted_shape, *non_indexed_dim_sizes). - Each position in broadcasted_shape selects a slice of values, - and writes it into the corresponding slice of self. + values is a tensor broadcastable to the indexing output shape. + When indexed dims are consecutive starting at dim k, this shape is + (*self[:k], *broadcast_shape, *self[k+n_indexed:]). When indexed + dims are non-consecutive, it is (*broadcast_shape, *non_indexed_dims). Sharding rules (possibly conservative and incomplete): - Index tensors: always Replicate (every rank needs all coordinates). @@ -997,10 +1184,6 @@ def index_put_single_dim_strategy( The exception is broadcasted value dimensions (size 1) - we require Replicate, but can shard self. - Additionally, we allow the full Partial rule on non-indexing tensors. - TODO(pianpwk): support non-contiguous indexed dims (None gaps in indices tuple, - e.g. (idx, None, idx)). Currently blocked by a single_dim_strategy infra bug: - _get_num_tensor_inputs counts None TupleStrategy children but args_strategy - drops them, causing a length mismatch in expand_to_full_mesh_op_strategy. """ self_meta = cast(TensorMeta, args[0]) indices_meta = cast(tuple[TensorMeta | None, ...], args[1]) @@ -1013,25 +1196,55 @@ def index_put_single_dim_strategy( values_ndim = len(values_meta.shape) # Explicitly compute the broadcast shape of the index tensors. - # We could probably derive it in a smarter way, but this is more explicit. index_shapes = [idx.shape for idx in indices_meta if idx is not None] broadcast_ndim = len(torch.broadcast_shapes(*index_shapes)) if index_shapes else 0 - # values shape = (*broadcast_shape, *non_indexed_dim_sizes) # Strategy format: [output, input, *indices, value] # The infra flattens the indices list and drops None entries, so only # non-None index tensors get a placement slot (all Replicate). + # + # Values dim mapping depends on whether indexed dims are contiguous: + # Contiguous (e.g., (None, idx0, idx1)): broadcast replaces indexed block in-place. + # values shape = (*non_indexed_before, *broadcast_shape, *non_indexed_after) + # Non-contiguous (e.g., (idx0, None, idx1)): broadcast goes to front. + # values shape = (*broadcast_shape, *non_indexed_dim_sizes) + indexed_dims_sorted = sorted(indexed_dims) + contiguous_indexed = len(indexed_dims_sorted) <= 1 or ( + indexed_dims_sorted[-1] - indexed_dims_sorted[0] + 1 == len(indexed_dims_sorted) + ) + strategies: list[list[Placement | _ShardingPlaceholder]] = [] - for values_dim in range(broadcast_ndim, values_ndim): - self_dim = non_indexed_dims[values_dim - broadcast_ndim] + for i, self_dim in enumerate(non_indexed_dims): + if contiguous_indexed and indexed_dims_sorted: + # Broadcast replaces the indexed block in-place. + first_indexed = indexed_dims_sorted[0] + if self_dim < first_indexed: + values_dim = self_dim + else: + values_dim = self_dim - n_indexed + broadcast_ndim + else: + # Broadcast goes to front (non-contiguous or no indexed dims). + values_dim = broadcast_ndim + i + + # values_dim is the position in the result tensor, but values may + # have fewer dims (right-aligned broadcasting). Convert to the + # actual values tensor dimension. + result_ndim = broadcast_ndim + len(non_indexed_dims) + values_tensor_dim = values_dim - (result_ndim - values_ndim) + + if values_tensor_dim < 0: + values_placement: Placement | _ShardingPlaceholder = Replicate() + elif values_meta.shape[values_tensor_dim] == 1: + values_placement = Replicate() + else: + values_placement = _ShardingPlaceholder(values_tensor_dim) + strategies.append( [ _ShardingPlaceholder(self_dim), _ShardingPlaceholder(self_dim), *([Replicate()] * n_indexed), - Replicate() - if values_meta.shape[values_dim] == 1 - else _ShardingPlaceholder(values_dim), + values_placement, ] ) @@ -1047,133 +1260,96 @@ def index_put_single_dim_strategy( return strategies -@register_prop_rule(aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True)) -def prop_index(op_schema: OpSchema) -> OutputSharding: - """ - Expect replicated on the first input; _mostly_ pointwise on the second input. +def _index_dim_strategy( + args_schema: ArgsType, + shard_row: Callable[[int], list[Placement | _ShardingPlaceholder]], + partial_rules: list[list[Placement | _ShardingPlaceholder]] | None = None, +) -> list[list[Placement | _ShardingPlaceholder]]: + """Common strategy for index ops that shard on all dims except the indexed dim. - TODO: exception: when the dtype of second input is "bool", then a torch.nonzero needs to be triggered first. + Args: + shard_row: given a dim d, returns the strategy row for sharding on that dim. + partial_rules: additional Partial passthrough strategies. """ - # Current sharding constraints: - # For values: - # 1. We currently require that the dimension of values_spec be replicated or partial - # if they are being indexed on. - # 2. Other dimensions of values_spec can remain sharded if they are so. - # For indices: - # Indices can be either sharded or replicated. All index tensors need to be sharded - # in a compatible way, following the pointwise rule (including resolving Partial - # into either sharded or replicated) - - values_spec, multi_indices_spec = op_schema.args_schema - if not isinstance(values_spec, DTensorSpec): - raise AssertionError(f"Expected DTensorSpec, got {type(values_spec)}") - if not isinstance(multi_indices_spec, list): - raise AssertionError(f"Expected list, got {type(multi_indices_spec)}") - multi_indices_spec = cast(list[DTensorSpec | None], multi_indices_spec) - valid_indices_spec: list[tuple[int, DTensorSpec]] = [ - (i, a) for i, a in enumerate(multi_indices_spec) if a is not None - ] + self_meta = cast(TensorMeta, args_schema[0]) + ndim = len(self_meta.shape) + dim = normalize_dim(cast(int, args_schema[1]), ndim) + strategies: list[list[Placement | _ShardingPlaceholder]] = [] + for d in range(ndim): + if d != dim: + strategies.append(shard_row(d)) + if partial_rules: + strategies.extend(partial_rules) + return strategies - # 1. All indices have to be sharded equally. Moreover, indices can be broadcast. - # Here, we piggyback on the pointwise sharding rule for indices. - indices_out = pointwise_rule( - OpSchema( - op=op_schema.op, - args_schema=tuple(v[1] for v in valid_indices_spec), - kwargs_schema={}, - ) - ) - need_reshard_on_indices = indices_out.output_spec is None - if not need_reshard_on_indices: - # this means that our inputs are already sharded properly and we will use that as our indices_spec - if not isinstance(indices_out.output_spec, DTensorSpec): - raise AssertionError( - f"Expected DTensorSpec, got {type(indices_out.output_spec)}" - ) - indices_spec: DTensorSpec = indices_out.output_spec - else: - if indices_out.redistribute_schema is None: - raise AssertionError("redistribute_schema should not be None") - valid_indices_suggestion = indices_out.redistribute_schema - for i, v in enumerate(valid_indices_suggestion.args_spec): - multi_indices_spec[valid_indices_spec[i][0]] = v - # we'll need to call pointwise_rule again to see what's our ideal indices_spec and then - # use that to compute our ideal values_spec - indices_output_spec = pointwise_rule(valid_indices_suggestion).output_spec - if not isinstance(indices_output_spec, DTensorSpec): - raise AssertionError( - f"Expected DTensorSpec, got {type(indices_output_spec)}" - ) - indices_spec = indices_output_spec +@register_single_dim_strategy( + [aten.index_fill.int_Scalar, aten.index_fill_.int_Scalar], + schema_info=RuntimeSchemaInfo(1), +) +def index_fill_scalar_single_dim_strategy( + op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType +) -> list[list[Placement | _ShardingPlaceholder]]: + # index_fill(self, dim, index, value) — fills self[..., index, ...] with scalar value. + # Partial rules: each rank fills with the same scalar v, then reduces. + # Only idempotent reduces work: avg(v,v,...,v)=v, max(v,v,...,v)=v, min(v,v,...,v)=v. + # sum and product fail: sum(v,v,...,v)=nv, product(v,v,...,v)=v^n. + return _index_dim_strategy( + args_schema, + lambda d: [ + _ShardingPlaceholder(d), # result + _ShardingPlaceholder(d), # self + Replicate(), # value (scalar, same on all ranks) + ], + [[Partial(op), Partial(op), Replicate()] for op in ("avg", "max", "min")], + ) - lookup_dims = {v[0] for v in valid_indices_spec} - need_reshard_on_values = tuple( - (isinstance(vp, Shard) and (vp.dim in lookup_dims or isinstance(ip, Shard))) - for vp, ip in zip(values_spec.placements, indices_spec.placements) +@register_single_dim_strategy( + [aten.index_fill.int_Tensor, aten.index_fill_.int_Tensor], + schema_info=RuntimeSchemaInfo(1), +) +def index_fill_tensor_single_dim_strategy( + op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType +) -> list[list[Placement | _ShardingPlaceholder]]: + # index_fill(self, dim, index, value) — fills self[..., index, ...] with 0-d tensor value. + # Partial rules: each rank fills with its partial value v_i, then reduces. + # All reduce ops work because reduce(v_0, ..., v_{n-1}) = V (the global value) + # regardless of op, since fill is a pure replacement (no mixing with self). + return _index_dim_strategy( + args_schema, + lambda d: [ + _ShardingPlaceholder(d), # result + _ShardingPlaceholder(d), # self + Replicate(), # index + Replicate(), # value + ], + [ + [Partial(op), Partial(op), Replicate(), Partial(op)] + for op in Partial.ALL_REDUCE_OPS + ], ) - if not need_reshard_on_indices and not any(need_reshard_on_values): - value_placements = values_spec.placements - all_dims_consecutive = all( - b[0] - a[0] == 1 - for b, a in zip(valid_indices_spec[1:], valid_indices_spec[:-1]) - ) - if all_dims_consecutive: - # if all index vectors are consecutives, insert at the dimension of the first index - insert_dim: int = valid_indices_spec[0][0] - else: - # else, insert on the first dimension - insert_dim = 0 - - def place(vp: Placement, ip: Placement) -> Placement: - if isinstance(vp, Shard): - return Shard( - vp.dim - if vp.dim < insert_dim - # accounts for the offset in output dimensions - else vp.dim - + indices_spec.ndim - - sum(1 if vp.dim > v[0] else 0 for v in valid_indices_spec) - ) - if isinstance(ip, Shard): - return Shard(ip.dim + insert_dim) - # Partial or Replicated - return vp - - value_placements = tuple( - place(vp, ip) - for vp, ip in zip(values_spec.placements, indices_spec.placements) - ) - result = OutputSharding( - output_spec=DTensorSpec( - mesh=values_spec.mesh, - placements=value_placements, - ) - ) - return result - else: - result = OutputSharding( - output_spec=None, - redistribute_schema=OpSchema( - op=op_schema.op, - args_schema=( - DTensorSpec( - mesh=values_spec.mesh, - placements=tuple( - Replicate() if need_reshard_on_values[i] else v - for i, v in enumerate(values_spec.placements) - ), - tensor_meta=values_spec.tensor_meta, - ), - multi_indices_spec, - ), - kwargs_schema=op_schema.kwargs_schema, - ), - ) - return result +@register_single_dim_strategy( + [aten.index_reduce.default, aten.index_reduce_.default], + schema_info=RuntimeSchemaInfo(1), +) +def index_reduce_single_dim_strategy( + op: OpOverload, args_schema: ArgsType, kwargs_schema: KwargsType +) -> list[list[Placement | _ShardingPlaceholder]]: + # index_reduce(self, dim, index, source, reduce) — reduces source into self at index positions. + # No partial rules: reduce ops are "mean"/"amax"/"amin"/"prod", which don't match + # any Partial reduce op names ("avg"/"max"/"min"/"product"/"sum"). + return _index_dim_strategy( + args_schema, + lambda d: [ + _ShardingPlaceholder(d), # result + _ShardingPlaceholder(d), # self + Replicate(), # index + _ShardingPlaceholder(d), # source + ], + ) @register_op_strategy( diff --git a/torch/distributed/tensor/_ops/_view_ops.py b/torch/distributed/tensor/_ops/_view_ops.py index e6e1478308e42..34adf22386878 100644 --- a/torch/distributed/tensor/_ops/_view_ops.py +++ b/torch/distributed/tensor/_ops/_view_ops.py @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates +import math from collections.abc import Callable, Iterable, Sequence from dataclasses import dataclass -from typing import cast +from typing import cast, NamedTuple import torch from torch import Tensor @@ -40,6 +41,13 @@ Shape = tuple[int, ...] +class ClaimedDim(NamedTuple): + """An (input_dim, output_dim) pair claimed by a mesh dim's _StridedShard rewrite.""" + + input_dim: int + output_dim: int + + @dataclass class DimSpec: """Specifies how an output dimension maps to an input dimension.""" @@ -57,12 +65,27 @@ class Singleton(DimSpec): """Output dimension is a singleton.""" -@dataclass +@dataclass(eq=False) class InputDim(DimSpec): """Output dimension maps directly to an input dimension.""" input_dim: int + def __eq__(self, other: object) -> bool: + """Raises TypeError for non-DimSpec comparisons to catch accidental + ``shard.dim == input_dim`` bugs where ``.input_dim`` was intended.""" + if isinstance(other, InputDim): + return self.input_dim == other.input_dim + if not isinstance(other, DimSpec): + raise TypeError( + f"Cannot compare InputDim with {type(other).__name__}. " + f"Did you mean to use .input_dim?" + ) + return NotImplemented + + def __hash__(self) -> int: + return hash((InputDim, self.input_dim)) + @dataclass class Broadcast(DimSpec): @@ -465,18 +488,22 @@ def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap: return tuple(dimmap) -def dim_squeeze(shape: Shape, dim: int | None = None) -> DimMap: - # FIXME: this is wrong when dim=None and one of the dimensions - # equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could - # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to - # removal of a dimension that is not actually a singleton. +def dim_squeeze(shape: Shape, dim: DimsType | None = None) -> DimMap: + # Operates on local shape; sharding_prop rewrites squeeze ops to squeeze.dims + # with only globally-singleton dims before this is called. from torch.fx.experimental.symbolic_shapes import guard_or_true + ndim = len(shape) + if dim is None: + target_dims = set(range(ndim)) + elif isinstance(dim, int): + target_dims = {normalize_dim(dim, ndim)} + else: + target_dims = set(normalize_dims(dim, ndim)) return tuple( InputDim(i) for i, s in enumerate(shape) - if guard_or_true(s > 1) - or (dim is not None and i != normalize_dim(dim, len(shape))) + if guard_or_true(s > 1) or i not in target_dims ) @@ -554,186 +581,673 @@ def propagate_shape_and_sharding( Sharding propagation follows mapped dimensions: - An output dimension that maps directly to an input dimension is sharded equally - - An output dimension that is a flattened set of input dimensions can only be - sharded if only the leftmost flattened dimension is sharded. + - An output dimension that is a flattened set of input dimensions can be sharded: + the first sharded dim stays as Shard, non-first sharded dims become _StridedShard - An output dimension that is a split of the input dimension can only be sharded if the leftmost split size is divisible by the mesh dimension """ - if not len(input_src_placements) == len(mesh_sizes): - raise AssertionError(f"{input_src_placements} != {mesh_sizes}") - # for each input dim, for each mesh dim, provides a list of possible shardable dimensions - mesh_ndim = len(mesh_sizes) - shardable_dims: dict[int, list[bool]] = {} - - # in case an input dimension disappears (e.g. collapsing, reduction) - # we cannot shard in that dimension (we need a replication fall-back rule) - seen_input_dims: set[int] = set() + propagator = _ViewShardingPropagator( + input_src_placements, global_input_shape, rule, mesh_sizes, strict_view + ) + input_tgt_placements, input_to_output_tensor_dims = propagator.analyze() + output_placements = propagator.rewrite_output_placements( + input_tgt_placements, input_to_output_tensor_dims + ) + return input_tgt_placements, output_placements - def collect_used_inputs(cmd: DimSpec) -> None: - if isinstance(cmd, InputDim): - seen_input_dims.add(cmd.input_dim) - for inp in cmd.inputs(): - collect_used_inputs(inp) - for cmd in rule: - collect_used_inputs(cmd) - for dim in range(len(global_input_shape)): - shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim +class _ViewShardingPropagator: + """Two-phase sharding propagator for view ops. + + Phase 1 — ``analyze()``: + Walks the DimMap rule and returns: + - ``input_tgt_placements``: input placements with unshardable dims + demoted to Replicate. + - ``input_to_output_tensor_dims``: maps each input tensor dim to its + output dim(s). Cardinality encodes the op type: 1→1 for InputDim, + N→1 for Flatten, 1→N for Split/unflatten. + + Phase 2 — ``rewrite_output_placements()``: + Consumes both Phase 1 outputs. Iterates mesh dims 0..n-1, maintaining: + - ``strided_shard_claimed_dims``: (input_dim, output_dim) pairs already assigned + to a mesh dim by _StridedShard rewriting. + - ``local_tensor_shapes``: global shape progressively divided by each + mesh dim's shard size. + For each surviving Shard/_StridedShard, looks up the output dim(s) and + produces the final output placement. + """ - def maybe_get_shard_mesh_dim_and_placement( - input_dim: InputDim, - ) -> tuple[int | None, Shard | _StridedShard | None]: - # if input_dim is sharded, return the mesh_dim and shard placement - for i, placement in enumerate(input_src_placements): + def __init__( + self, + input_src_placements: Sequence[Placement], + global_input_shape: Shape, + rule: DimMap, + mesh_sizes: Shape, + strict_view: bool, + ) -> None: + self.input_src_placements = input_src_placements + self.global_input_shape = global_input_shape + self.rule = rule + self.mesh_sizes = mesh_sizes + self.strict_view = strict_view + self.mesh_ndim = len(mesh_sizes) + + # shard_allowed[input_dim][mesh_dim]: whether input_dim can stay + # sharded on mesh_dim. Populated by _analyze_dim and its helpers. + self.shard_allowed: dict[int, list[bool]] = {} + # Mesh dims whose _StridedShard has already been matched to an output dim. + # Populated by _analyze_split. + self.matched_strided_mesh_dims: set[int] = set() + + # ------------------------------------------------------------------ + # Public API: analyze → rewrite_output_placements + # ------------------------------------------------------------------ + + def analyze( + self, + ) -> tuple[Sequence[Placement], dict[int, list[int]]]: + """Phase 1: walk the DimMap rule, return (input_tgt_placements, input_to_output_tensor_dims).""" + input_dims_in_rule = self._input_dims_in_rule(self.rule) + + # Default: shardable if the dim appears in the rule. Refined by _analyze_*. + for dim in range(len(self.global_input_shape)): + self.shard_allowed[dim] = [dim in input_dims_in_rule] * self.mesh_ndim + + # Walk the rule to refine shard_allowed and build input_to_output_tensor_dims. + # + # Flatten example: view([2, 3, 4], [6, 4]) + # rule = (Flatten(InputDim(0), InputDim(1)), InputDim(2)) + # output_dim=0 (Flatten): hits the isinstance(cmd, Flatten) branch. + # Maps input dims 0 and 1 to output dim 0. Result: {0: [0], 1: [0]} + # output_dim=1 (InputDim(2)): hits the len(in_dims) > 0 branch. + # Maps input dim 2 to output dim 1. Result: {0: [0], 1: [0], 2: [1]} + # + # Split example: view([6], [2, 3]) + # rule = (Split(InputDim(0), (2,3), 0), Split(InputDim(0), (2,3), 1)) + # output_dim=0 (split_id=0): hits the len(in_dims) > 0 branch. + # Maps input dim 0 to output dim 0. Result: {0: [0]} + # output_dim=1 (split_id=1): hits the isinstance(cmd, Split) branch + # because _analyze_split returns [] for split_id>0. Chases root + # InputDim(0) and appends output dim 1. Result: {0: [0, 1]} + input_to_output_tensor_dims: dict[int, list[int]] = {} + for output_dim, cmd in enumerate(self.rule): + in_dims = self._analyze_dim(cmd) + if isinstance(cmd, Flatten): + for in_dim in in_dims: + if in_dim.input_dim in input_to_output_tensor_dims: + raise AssertionError( + f"Input dim {in_dim.input_dim} already mapped to output dims " + f"{input_to_output_tensor_dims[in_dim.input_dim]}" + ) + input_to_output_tensor_dims[in_dim.input_dim] = [output_dim] + elif len(in_dims) > 0: + # InputDim (identity) or Split(split_id=0). + in_dim = in_dims[0] + if in_dim.input_dim not in input_to_output_tensor_dims: + input_to_output_tensor_dims[in_dim.input_dim] = [output_dim] + else: + input_to_output_tensor_dims[in_dim.input_dim].append(output_dim) + elif isinstance(cmd, Split): + # Split(split_id>0): _analyze_split returned [], so chase the + # root input dim and append this output dim to its existing entry. + # + # Flatten+Split example: view([2, 3], [3, 2]) + # rule = (Split(Flatten(InputDim(0), InputDim(1)), (3,2), 0), + # Split(Flatten(InputDim(0), InputDim(1)), (3,2), 1)) + # output_dim=0 (split_id=0): same as Split example above. + # Result: {0: [0]} + # output_dim=1 (split_id=1): same as Split example, but + # the chase unwraps the inner Flatten to find InputDim(0). + # Result: {0: [0, 1]} + root_spec = cmd.input_dim + while isinstance(root_spec, (Flatten, Split)): + if isinstance(root_spec, Flatten): + # _analyze_flatten always returns input_dims[0] as + # the first element (either as the only shardable dim + # in non-strict mode, or as the fallback when nothing + # is sharded), so split_id=0 uses it as the key in + # input_to_output_tensor_dims. Use [0] here to match. + root_spec = root_spec.input_dims[0] + else: + root_spec = root_spec.input_dim + root = root_spec if isinstance(root_spec, InputDim) else None + if root is not None and root.input_dim in input_to_output_tensor_dims: + input_to_output_tensor_dims[root.input_dim].append(output_dim) + + input_tgt_placements: list[Placement] = [] + for mesh_dim, p in enumerate(self.input_src_placements): if ( - isinstance(placement, Shard | _StridedShard) - and placement.dim == input_dim.input_dim + isinstance(p, Shard | _StridedShard) + and not self.shard_allowed[p.dim][mesh_dim] ): - return i, placement - return None, None + if self.strict_view: + raise RuntimeError( + f"This operation would remove or reshape sharded " + f"dimension {p.dim}, which requires redistribution. " + f"Please redistribute the input first." + ) + input_tgt_placements.append(Replicate()) + else: + input_tgt_placements.append(p) + return input_tgt_placements, input_to_output_tensor_dims + + def rewrite_output_placements( + self, + input_tgt_placements: Sequence[Placement], + input_to_output_tensor_dims: dict[int, list[int]], + ) -> list[Placement]: + """Phase 2: consume analyze() outputs, return final output placements.""" + # (input_dim, output_dim) pairs claimed by earlier mesh dims + # (via _rewrite_strided_shard), to avoid double-assignment. + strided_shard_claimed_dims: set[ClaimedDim] = set() + # Starts as global_input_shape; each mesh dim divides its sharded dim. + local_tensor_shapes: list[int] = list(self.global_input_shape) + + output_placements: list[Placement] = [] + # Process mesh dims in order; _rewrite_*_shard relies on this for + # truncating division safety in local_tensor_shapes. + for mesh_dim, p in enumerate(input_tgt_placements): + if isinstance(p, Shard): + placement, local_tensor_shapes = self._rewrite_plain_shard( + p, + mesh_dim, + input_tgt_placements, + strided_shard_claimed_dims, + local_tensor_shapes, + input_to_output_tensor_dims, + ) + output_placements.append(placement) + elif isinstance(p, _StridedShard): + placement, local_tensor_shapes = self._rewrite_strided_shard( + p, + mesh_dim, + input_tgt_placements, + strided_shard_claimed_dims, + local_tensor_shapes, + input_to_output_tensor_dims, + ) + output_placements.append(placement) + else: + output_placements.append(p) + return output_placements + + # ------------------------------------------------------------------ + # Analysis phase helpers + # ------------------------------------------------------------------ + + @staticmethod + def _input_dims_in_rule(rule: DimMap) -> set[int]: + """Walk the DimMap rule tree and return all input dim indices that appear in it.""" + seen: set[int] = set() + + def _walk(cmd: DimSpec) -> None: + if isinstance(cmd, InputDim): + seen.add(cmd.input_dim) + for inp in cmd.inputs(): + _walk(inp) + + for cmd in rule: + _walk(cmd) + return seen + + def _find_plain_shard( + self, input_dim: InputDim + ) -> tuple[int | None, Shard | _StridedShard | None]: + """Find the mesh dim with a plain Shard on ``input_dim``. - # NOTE: This function has three responsibilities: - # 1. determine "theoretically" if an output dimension can be sharded, i.e. fill the shardable_dims map - # 2. determine "theoretically" the corresponding input dimension to shard on, via return value - # 3. throw an error when strict_view is enabled and we cannot shard an output dimension - # 1 and 2 doesn't require the info of whether current input is sharded. - # 3 requires that info, to decide whether we can error out. Maybe we can refactor - # to make this function purely "theoretical". - def get_in_dim_to_shard(cmd: DimSpec) -> InputDim | None: - from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true + Only matches Shard, not _StridedShard. Used by both _analyze_flatten + and _analyze_split. _find_shard_for_split is the counterpart that + also matches _StridedShard with split_factor validation. + """ + for mesh_dim, placement in enumerate(self.input_src_placements): + if isinstance(placement, Shard) and placement.dim == input_dim.input_dim: + return mesh_dim, placement + return None, None - if isinstance(cmd, InputDim): - return cmd - elif isinstance(cmd, Flatten): - for i, dim in enumerate(cmd.input_dims): - # so far all Flatten is always composed of InputDims; revisit this if needed - if not isinstance(dim, InputDim): - raise AssertionError(f"Expected InputDim, got {type(dim)}") - can_shard_dim = True - shard_mesh_dim, shard_placement = ( - maybe_get_shard_mesh_dim_and_placement(dim) + def _find_shard_for_split( + self, + current_dim: int, + cmd: Split, + placements: Sequence[Placement], + ) -> tuple[int | None, Shard | _StridedShard | None]: + """Find the mesh dim and placement for an input dim in Split ops. + + Matches both Shard and _StridedShard: + - Shard: plain unflatten, e.g. [6] Shard(0) → [2, 3]. + - _StridedShard: unflatten after a prior flatten that produced + _StridedShard, e.g. [2,3,4] Shard(1) → flatten → [6,4] + _StridedShard(0,sf=2) → unflatten → [2,3,4]. Validates that + the split_factor matches the expected value for this split_id. + """ + for mesh_dim, placement in enumerate(placements): + if not isinstance(placement, Shard | _StridedShard): + continue + if placement.dim != current_dim: + continue + if mesh_dim in self.matched_strided_mesh_dims: + continue + + if isinstance(placement, _StridedShard): + expected_sf = self._expected_split_factor( + cmd, current_dim, mesh_dim, placements ) - input_sharded = shard_mesh_dim is not None - if i > 0: - can_shard_dim = False - if strict_view and input_sharded: - raise RuntimeError( - f"Attempted to flatten multiple dimensions, with dimension {dim.input_dim} being sharded. ", - "It cannot be performed without redistribution, which is disallowed by the current operator.", - ) - elif input_sharded: - if not (shard_placement is not None and shard_mesh_dim is not None): - raise AssertionError( - "Expected shard_placement and shard_mesh_dim to be not None" - ) - tensor_dim_size = global_input_shape[shard_placement.dim] - mesh_dim_size = mesh_sizes[shard_mesh_dim] + if expected_sf == placement.split_factor: + return mesh_dim, placement + else: + return mesh_dim, placement + return None, None + + def _analyze_flatten(self, cmd: Flatten) -> list[InputDim]: + """Fill self.shard_allowed for Flatten; return sharded input dims.""" + from torch.fx.experimental.symbolic_shapes import guard_or_true + + sharded_dims: list[InputDim] = [] + num_input_dims = len(cmd.input_dims) + for i, dim in enumerate(cmd.input_dims): + if not isinstance(dim, InputDim): + raise AssertionError(f"Expected InputDim, got {type(dim)}") + shard_mesh_dim, shard_placement = self._find_plain_shard(dim) + if shard_mesh_dim is None or shard_placement is None: + continue # default from analyze() already covers this + tensor_dim_size = self.global_input_shape[shard_placement.dim] + mesh_dim_size = self.mesh_sizes[shard_mesh_dim] + can_shard_dim = True + if self.strict_view: + is_last_input_dim = i == num_input_dims - 1 + if not is_last_input_dim and guard_or_true( + tensor_dim_size % mesh_dim_size != 0 + ): + raise RuntimeError( + f"Cannot flatten unevenly sharded tensor: " + f"dimension {dim.input_dim} (size {tensor_dim_size}) " + f"is not evenly divisible by mesh dimension " + f"{shard_mesh_dim} (size {mesh_dim_size}). " + f"Please redistribute the tensor before this operation." + ) + sharded_dims.append(dim) + else: + # TODO: non-strict (reshape) should allow can_shard_dim = True + # for non-first flatten dims, since strict_view already does. + # Currently forces redistribution because the rewrite phase + # wasn't originally implemented for this case. + if i == 0: + sharded_dims.append(dim) if guard_or_true(tensor_dim_size % mesh_dim_size != 0): can_shard_dim = False - if strict_view: - raise RuntimeError( - f"Attempted to flatten unevenly sharded dimension {i}, " - "which would require resharding the input. " - "Please explicitly redistribute the tensor instead." - ) - shardable_dims[dim.input_dim] = [can_shard_dim] * mesh_ndim - - if not isinstance(cmd.input_dims[0], InputDim): - raise AssertionError( - f"Expected InputDim, got {type(cmd.input_dims[0])}" - ) - return cmd.input_dims[0] - elif isinstance(cmd, Split): - in_dim = get_in_dim_to_shard(cmd.input_dim) - out_size = cmd.group_shape[cmd.split_id] - if cmd.split_id == 0 and in_dim is not None: - # we need to check that the input dimension is divisible - # by the size of the submesh we're sharding it on - # NOTE: it would be possible to shard the same input dimension - # on more than one mesh dimension. In that case, the dimension - # needs to be divisible by the product of mesh sizes. - # In order to keep the problem more tractable, we will not consider - # double resharding as a suggestion (e.g. [Shard(0), Shard(0) ]) - # but we will allow it if that's the input and it's compatible - - # 1. is this dimension shardable on each individual mesh dim? - shardable_dims[in_dim.input_dim] = [ - guard_or_false(out_size % mesh_dim_size == 0) - for mesh_dim_size in mesh_sizes - ] - - shard_mesh_dim, _ = maybe_get_shard_mesh_dim_and_placement(in_dim) - if strict_view and shard_mesh_dim is not None: - if not shardable_dims[in_dim.input_dim][shard_mesh_dim]: - raise RuntimeError( - f"Attempted to split the sharded dimension {in_dim.input_dim} into multiple subdimensions. ", - "It cannot be performed without redistribution, which is disallowed by the current operator.", - ) + else: + can_shard_dim = False + self.shard_allowed[dim.input_dim] = [can_shard_dim] * self.mesh_ndim + + if len(sharded_dims) > 0: + return sharded_dims + # No sharded dims: e.g. Flatten([InputDim(0), InputDim(1)]) where + # neither dim is sharded. Return the first input dim so that + # input_to_output_tensor_dims is populated for identity rewrites. + if not isinstance(cmd.input_dims[0], InputDim): + raise AssertionError(f"Expected InputDim, got {type(cmd.input_dims[0])}") + return [cmd.input_dims[0]] + + def _analyze_split(self, cmd: Split) -> list[InputDim]: + """Fill self.shard_allowed for Split; return shardable input dims.""" + from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true - # 2. here we special case things like [Shard(0), Shard(0)] - submesh_size = 1 - for size, shard in zip(mesh_sizes, input_src_placements): - if isinstance(shard, Shard | _StridedShard) and shard.dim == in_dim: - submesh_size *= size - if guard_or_true(out_size % submesh_size != 0): - raise AssertionError( - f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." + in_dims = self._analyze_dim(cmd.input_dim) + if len(in_dims) == 0: + return [] + in_dim = in_dims[0] + out_size = cmd.group_shape[cmd.split_id] + shard_mesh_dim, input_src_placement = self._find_shard_for_split( + in_dim.input_dim, cmd, self.input_src_placements + ) + # split_id == 0 sets the base shard_allowed for this input dim. + # Later split_ids (processed in subsequent rule iterations) refine + # individual mesh_dim entries via the _StridedShard branch below. + if cmd.split_id == 0: + self.shard_allowed[in_dim.input_dim] = [ + guard_or_false(out_size % mesh_dim_size == 0) + for mesh_dim_size in self.mesh_sizes + ] + plain_mesh_dim, _ = self._find_plain_shard(in_dim) + # Non-strict silently redistributes via shard_allowed=False above; + # strict raises so the user knows to redistribute before view(). + if self.strict_view and plain_mesh_dim is not None: + if not self.shard_allowed[in_dim.input_dim][plain_mesh_dim]: + raise RuntimeError( + f"Cannot unflatten unevenly sharded tensor: " + f"output dimension {cmd.split_id} (size {out_size}) " + f"is not evenly divisible by mesh dimension " + f"{plain_mesh_dim} (size {self.mesh_sizes[plain_mesh_dim]}). " + f"Please redistribute the tensor before this operation." ) + if shard_mesh_dim is not None and isinstance( + input_src_placement, _StridedShard + ): + # The last split dim doesn't require even divisibility because + # its local size is inferred: local_last = local_flat / product + # of earlier dims, and DTensor handles uneven local sizes. + # Non-last dims must be evenly divisible because they appear as + # fixed sizes in the local reshape — uneven division would make + # the stride pattern inconsistent across devices. + # E.g. [12] → [3, 4], _StridedShard targeting dim 1 (last), + # mesh=3: 4%3≠0, but local shapes [3,2],[3,1],[3,1] are valid. + is_last_split_dim = cmd.split_id == len(cmd.group_shape) - 1 + if ( + self.strict_view + and not is_last_split_dim + and guard_or_true(out_size % self.mesh_sizes[shard_mesh_dim] != 0) + ): + raise RuntimeError( + f"Cannot unflatten unevenly sharded tensor: " + f"output dimension {cmd.split_id} (size {out_size}) " + f"is not evenly divisible by mesh dimension {shard_mesh_dim} " + f"(size {self.mesh_sizes[shard_mesh_dim]}). " + f"Please redistribute the tensor before this operation." + ) + # Prevents _find_shard_for_split from matching this mesh dim + # again for a later split_id of the same Split group. + self.matched_strided_mesh_dims.add(shard_mesh_dim) + if in_dim.input_dim in self.shard_allowed: + self.shard_allowed[in_dim.input_dim][shard_mesh_dim] = ( + guard_or_false(out_size % self.mesh_sizes[shard_mesh_dim] == 0) + or is_last_split_dim + ) + # Only split_id==0 returns the input dim for input_to_output_tensor_dims. + # Later split_ids refine shard_allowed above but return [] — their + # output dims are linked via the root-input-dim chase in analyze(). + return [in_dim] if cmd.split_id == 0 else [] - # we will only shard our first component of the split - return in_dim if cmd.split_id == 0 else None + def _analyze_dim(self, cmd: DimSpec) -> list[InputDim]: + """Dispatch one DimSpec: update self.shard_allowed, return input dim(s) to shard on.""" + if isinstance(cmd, InputDim): + return [cmd] + elif isinstance(cmd, Flatten): + return self._analyze_flatten(cmd) + elif isinstance(cmd, Split): + return self._analyze_split(cmd) elif isinstance(cmd, Repeat): - in_dim = get_in_dim_to_shard(cmd.input_dim) - if in_dim is not None: - shardable_dims[in_dim.input_dim] = [False] * mesh_ndim - return None + in_dims = self._analyze_dim(cmd.input_dim) + for d in in_dims: + self.shard_allowed[d.input_dim] = [False] * self.mesh_ndim + return [] else: - return None - - # for each output dim, find the corresponding input dim in terms of sharding prop - shard_dim_map = {} - for dim, cmd in enumerate(rule): - in_dim = get_in_dim_to_shard(cmd) - if in_dim is not None: - shard_dim_map[in_dim.input_dim] = dim - - input_tgt_placements = [ - ( - Replicate() - if isinstance(p, Shard | _StridedShard) - and not shardable_dims[p.dim][mesh_dim] - else p + return [] + + # ------------------------------------------------------------------ + # Rewrite phase helpers + # ------------------------------------------------------------------ + + @staticmethod + def _is_last_shard_in_flatten_range( + mesh_dim: int, + placements: Sequence[Placement], + flatten_start: int, + flatten_end: int, + ) -> bool: + """Check if no later mesh dim shards a dim within the flatten range at or above this one. + + Uneven sharding on dim d breaks stride computation for all earlier dims + that flatten together with d. Only dims within [flatten_start, flatten_end) + matter; shards on dims outside the flatten range are independent. + + Requires: placements[mesh_dim] must be Shard or _StridedShard. + """ + p = placements[mesh_dim] + if not isinstance(p, (Shard, _StridedShard)): + raise AssertionError( + f"Expected Shard or _StridedShard at mesh_dim {mesh_dim}, got {type(p)}" + ) + tensor_dim = p.dim + return not any( + isinstance(other_p, (Shard, _StridedShard)) + and flatten_start <= other_p.dim < flatten_end + and other_p.dim >= tensor_dim + for other_p in placements[mesh_dim + 1 :] ) - for mesh_dim, p in enumerate(input_src_placements) - ] - def _rewrite_shard_dim(p: Shard | _StridedShard): + def _expected_split_factor( + self, + cmd: Split, + sharded_dim: int, + mesh_dim: int, + placements: Sequence[Placement], + ) -> int | None: + """Compute the residual split factor for ``cmd`` after earlier mesh dims. + + Starts from ``math.prod(cmd.group_shape[:cmd.split_id])`` and divides + out each earlier mesh dim that shards the same input dim. Returns + ``None`` if any earlier mesh size doesn't divide evenly. + """ + sf = math.prod(cmd.group_shape[: cmd.split_id]) + for m in range(mesh_dim): + other_p = placements[m] + if ( + isinstance(other_p, (_StridedShard, Shard)) + and other_p.dim == sharded_dim + ): + if sf % self.mesh_sizes[m] != 0: + return None + sf //= self.mesh_sizes[m] + return sf + + def _find_keep_ss_dim( + self, + tgt_shard_dims: list[int], + p: _StridedShard, + mesh_dim: int, + ) -> int | None: + """Find an output dim where SS stays as SS. + + Returns the first output dim whose Split can accommodate the combined + sharding (mesh_size * split_factor), or ``None`` if no dim fits. """ - Rewrite the shard dim to the corresponding tensor dim in output. - For ``_StridedShard``, we can safely keep the placement type and - ``split_factor`` unchanged and only rewrite the ``dim`` because: - 1. ``_StridedShard`` has no impact on sharding (i.e. how - tensor is partitioned) compared to ``Shard``. It only changes - how shards permute across the devices. - 2. ``view()`` op on DTensor strictly forbids shard redistribution - which means if ``view()`` may cause shard permutation across - devices, it should be rejected. This is enforced in today's - sharding prop for ``view()``. - 3. Since DTensor ``view()`` won't introduce any redistribution, - it's certain that ``placements`` won't change except the - inner ``dim`` attribute of ``Shard`` or ``_StridedShard``. + total_shard = self.mesh_sizes[mesh_dim] * p.split_factor + if self.global_input_shape[p.dim] % total_shard != 0: + return None + shard_size = self.global_input_shape[p.dim] // total_shard + for candidate_dim in tgt_shard_dims: + cmd = self.rule[candidate_dim] + if isinstance(cmd, Split): + inner_size = math.prod(cmd.group_shape[cmd.split_id + 1 :]) + # When a Split wraps a Flatten, the per-shard chunk covers + # the sharded dim plus trailing dims flattened together. + trailing_size = 1 + if isinstance(cmd.input_dim, Flatten): + found = False + for flat_dim in cmd.input_dim.input_dims: + if not isinstance(flat_dim, InputDim): + raise AssertionError( + f"Expected InputDim, got {type(flat_dim)}" + ) + if flat_dim.input_dim == p.dim: + found = True + elif found: + trailing_size *= self.global_input_shape[flat_dim.input_dim] + flattened_shard_size = shard_size * trailing_size + if ( + flattened_shard_size >= inner_size + and flattened_shard_size % inner_size == 0 + ): + return candidate_dim + return None + + def _rewrite_plain_shard( + self, + p: Shard, + mesh_dim: int, + placements: Sequence[Placement], + strided_shard_claimed_dims: set[ClaimedDim], + local_tensor_shapes: list[int], + input_to_output_tensor_dims: dict[int, list[int]], + ) -> tuple[Placement, list[int]]: + """Given a plain Shard(dim=X) input placement on a specific mesh dim, + determine what output placement it maps to after the view op. + + For identity and unflatten, produces Shard on the output dim. For + flatten, Shard on the first flattened dim stays Shard, while Shard on + a non-first dim produces _StridedShard (consumed later by + _rewrite_strided_shard). + + Returns the output placement and a new local_tensor_shapes with this + mesh dim's division applied. """ - if isinstance(p, _StridedShard): - return _StridedShard(shard_dim_map[p.dim], split_factor=p.split_factor) + # Output dims that input dim p.dim maps to, filtering out any + # already claimed by _StridedShard rewriting on earlier mesh dims. + tgt_shard_dims = [ + d + for d in input_to_output_tensor_dims[p.dim] + if ClaimedDim(p.dim, d) not in strided_shard_claimed_dims + ] + if len(tgt_shard_dims) == 0: + raise AssertionError( + f"No output dim available for Shard(dim={p.dim}) on mesh dim " + f"{mesh_dim}. All output dims already claimed by earlier mesh dims." + ) + if len(tgt_shard_dims) == 1: + tgt_shard_dim = tgt_shard_dims[0] else: - return Shard(shard_dim_map[p.dim]) - - output_placements = [ - _rewrite_shard_dim(p) if isinstance(p, Shard | _StridedShard) else p - for p in input_tgt_placements - ] - - return input_tgt_placements, output_placements + # Unflatten: one input dim maps to multiple output dims + # (e.g. (24,) → (2, 3, 4) gives 3 splits). Plain Shard + # always targets the split_id=0 output dim. + tgt_shard_dim = next( + ( + d + for d in tgt_shard_dims + if isinstance(self.rule[d], Split) + and cast(Split, self.rule[d]).split_id == 0 + ), + None, + ) + if tgt_shard_dim is None: + raise AssertionError( + f"No Split(split_id=0) found among unclaimed output dims " + f"{tgt_shard_dims} for Shard(dim={p.dim}) on mesh dim {mesh_dim}." + ) + cmd = self.rule[tgt_shard_dim] + if isinstance(cmd, Split) and isinstance(cmd.input_dim, Flatten): + first_dim = cmd.input_dim.input_dims[0] + if isinstance(first_dim, InputDim) and p.dim != first_dim.input_dim: + raise RuntimeError( + f"Shard(dim={p.dim}) through Split(Flatten(...), {cmd.group_shape}) " + f"is not supported yet for non-first flatten dims." + ) + if isinstance(cmd, (Split, InputDim)): + # Split/InputDim: 1:1 dim mapping, sharding transfers directly. + # Flatten needs stride computation below (multiple dims merge). + new_shapes = list(local_tensor_shapes) + new_shapes[p.dim] //= self.mesh_sizes[mesh_dim] + return Shard(tgt_shard_dim), new_shapes + if not isinstance(cmd, Flatten): + raise AssertionError(f"Expected Flatten, got {type(cmd)}") + first_dim = cmd.input_dims[0] + last_dim = cmd.input_dims[-1] + if not isinstance(first_dim, InputDim): + raise AssertionError(f"Expected InputDim, got {type(first_dim)}") + if not isinstance(last_dim, InputDim): + raise AssertionError(f"Expected InputDim, got {type(last_dim)}") + input_start_idx = first_dim.input_dim + if p.dim == input_start_idx: + output_placement: Placement = Shard(tgt_shard_dim) + else: + split_factor = math.prod(local_tensor_shapes[input_start_idx : p.dim]) + output_placement = _StridedShard(tgt_shard_dim, split_factor=split_factor) + # Uneven sharding on a non-last flatten dim breaks _StridedShard: + # split_factor (number of groups) must be the same on all devices, + # but uneven division of a non-last dim makes group count vary. + # E.g. [3,4]→[12] Shard(0) mesh=2: device 0 has 2 groups of 4, + # device 1 has 1 group of 4 — no consistent split_factor. + # The last dim is exempt: only group *size* varies, not count. + flatten_end = last_dim.input_dim + 1 + if local_tensor_shapes[p.dim] % self.mesh_sizes[ + mesh_dim + ] != 0 and not self._is_last_shard_in_flatten_range( + mesh_dim, placements, input_start_idx, flatten_end + ): + raise RuntimeError( + f"Cannot shard unevenly distributed tensor: " + f"dimension {p.dim} (size {local_tensor_shapes[p.dim]}) " + f"is not evenly divisible by mesh dimension " + f"{mesh_dim} (size {self.mesh_sizes[mesh_dim]}). " + f"Please redistribute the tensor before this operation." + ) + new_shapes = list(local_tensor_shapes) + new_shapes[p.dim] //= self.mesh_sizes[mesh_dim] + return output_placement, new_shapes + + def _rewrite_strided_shard( + self, + p: _StridedShard, + mesh_dim: int, + placements: Sequence[Placement], + strided_shard_claimed_dims: set[ClaimedDim], + local_tensor_shapes: list[int], + input_to_output_tensor_dims: dict[int, list[int]], + ) -> tuple[Placement, list[int]]: + """Rewrite _StridedShard placement to target the correct output dim. + + _StridedShard inputs arise from a prior flatten on a non-first dim + (produced by _rewrite_plain_shard above). The interesting case is + unflatten (Split rule): the split_factor may resolve to contiguous + sharding (producing Shard) or stay as _StridedShard. For + identity/flatten rules, falls through to the fallback and keeps the + placement as-is. + + Returns the output placement and a new local_tensor_shapes with this + mesh dim's division applied. + """ + tgt_shard_dims = [ + d + for d in input_to_output_tensor_dims[p.dim] + if ClaimedDim(p.dim, d) not in strided_shard_claimed_dims + ] + # Phase 1: resolve SS → Shard. If an output dim's Split has a + # group_shape prefix matching the split_factor, the strided pattern + # is fully captured by the Split, so SS simplifies to Shard. + # E.g. unflatten (6, 4) → (2, 3, 4) with SS(0, sf=2) on mesh (3): + # sf=2 means 2 groups of contiguous data in dim 0. Split into + # (2, 3, 4) gives group_shape=(2, 3); prod(group_shape[:1])=2==sf, + # so the strided pattern lands exactly on output dim 1 → Shard(1). + for candidate_dim in tgt_shard_dims: + cmd = self.rule[candidate_dim] + if isinstance(cmd, Split): + expected_sf = self._expected_split_factor( + cmd, p.dim, mesh_dim, placements + ) + if expected_sf != p.split_factor: + continue + strided_shard_claimed_dims.add(ClaimedDim(p.dim, candidate_dim)) + new_shapes = list(local_tensor_shapes) + new_shapes[p.dim] //= self.mesh_sizes[mesh_dim] + return Shard(candidate_dim), new_shapes + + # Phase 2: keep SS as SS. Phase 1 is tried first because we prefer + # resolving to the simpler Shard when possible. + tgt_shard_dim = self._find_keep_ss_dim(tgt_shard_dims, p, mesh_dim) + + if tgt_shard_dim is None: + if self.strict_view and any( + isinstance(self.rule[d], Split) for d in tgt_shard_dims + ): + raise RuntimeError( + f"Cannot unflatten tensor with _StridedShard placement: " + f"split_factor={p.split_factor} does not match any output " + f"dimension. This typically means the _StridedShard placement " + f"was constructed with a split_factor that is incompatible " + f"with the unflatten shape. Please redistribute the tensor " + f"before this operation." + ) + if len(tgt_shard_dims) == 0: + raise AssertionError( + f"No unclaimed output dims for _StridedShard(dim={p.dim}) " + f"on mesh dim {mesh_dim}." + ) + # Fallback for identity/flatten: tgt_shard_dims has exactly one + # element, so [0] is correct. For Split rules this is unreachable + # in practice — the analysis phase rejects mismatched split_factors + # via shard_allowed, forcing redistribution before we get here. + tgt_shard_dim = tgt_shard_dims[0] + new_shapes = list(local_tensor_shapes) + new_shapes[p.dim] //= self.mesh_sizes[mesh_dim] + return _StridedShard(tgt_shard_dim, split_factor=p.split_factor), new_shapes def register_op_strategy_map( @@ -783,12 +1297,17 @@ def reshape_strategy(op_schema: OpSchema) -> StrategyType: placements=tuple(input_tgt_placements), mesh=mesh, tensor_meta=input_src_spec.tensor_meta, + use_strided_shard_as_shard_order=False, ) redistribute_costs: list[list[float]] = [ generate_redistribute_costs(input_strategy, input_tgt_spec) ] - output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements)) + output_spec = DTensorSpec( + mesh=mesh, + placements=tuple(output_placements), + use_strided_shard_as_shard_order=False, + ) output_strategy.strategies.append( OpSpec( output_specs=output_spec, @@ -800,18 +1319,31 @@ def reshape_strategy(op_schema: OpSchema) -> StrategyType: return output_strategy -register_op_strategy_map(aten.squeeze.default, torch.squeeze) +register_op_strategy_map(aten.squeeze.default, torch.squeeze, strict_view=True) +register_op_strategy_map(aten.squeeze_.default, torch.squeeze, strict_view=True) register_op_strategy_map( - aten.squeeze_.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1) + aten.squeeze_.dim, + torch.squeeze, + schema_info=RuntimeSchemaInfo(1), + strict_view=True, ) register_op_strategy_map( - aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1) + aten.squeeze.dim, + torch.squeeze, + schema_info=RuntimeSchemaInfo(1), + strict_view=True, ) register_op_strategy_map( - aten.squeeze.dims, torch.squeeze, schema_info=RuntimeSchemaInfo(1) + aten.squeeze.dims, + torch.squeeze, + schema_info=RuntimeSchemaInfo(1), + strict_view=True, ) register_op_strategy_map( - aten.squeeze_.dims, torch.squeeze, schema_info=RuntimeSchemaInfo(1) + aten.squeeze_.dims, + torch.squeeze, + schema_info=RuntimeSchemaInfo(1), + strict_view=True, ) register_op_strategy_map( aten.view.default, diff --git a/torch/distributed/tensor/_ops/single_dim_strategy.py b/torch/distributed/tensor/_ops/single_dim_strategy.py index e2127c29950c4..722cb8c24a8a2 100644 --- a/torch/distributed/tensor/_ops/single_dim_strategy.py +++ b/torch/distributed/tensor/_ops/single_dim_strategy.py @@ -23,7 +23,6 @@ OpSchema, OpSpec, OpStrategy, - PlacementList, RuntimeSchemaInfo, StrategyType, TupleStrategy, @@ -76,6 +75,10 @@ class _SingleDimStrategyInfo: func: _SingleDimStrategyFunc allow_unbacked_sharding: bool | None = field(default=None) allow_uneven_sharding: bool = field(default=False) + # Positions (in args_schema) of args that may live on a different mesh + # than the op's compute mesh. These args must be Replicate. + # See Note [Multi-mesh args] in expand_to_full_mesh_op_strategy. + different_mesh_args: list[int] | None = field(default=None) # Delegate to func so this can be used interchangeably with a raw # _SingleDimStrategyFunc (e.g. in tests that call strategy functions directly). @@ -85,29 +88,39 @@ def __call__(self, *args, **kwargs): def _insert_single_dim_replication_strategy( single_dim_strategies_with_placeholders: list[ - list[Placement | _ShardingPlaceholder] + list[Placement | _ShardingPlaceholder | None] ], num_outputs: int, num_input_tensors: int, -) -> list[list[Placement | _ShardingPlaceholder]]: + output_tensor_meta: TensorMeta | Sequence[TensorMeta | None] | None = None, +) -> list[list[Placement | _ShardingPlaceholder | None]]: """ Inserts the [Replicate(), Replicate(), ...] strategy after asserting that such strategy does not yet exist. + For ops with masked-off outputs (e.g. backward ops with output_mask), output positions + where output_tensor_meta is None are set to None in the all-Replicate rule. """ for strategy in single_dim_strategies_with_placeholders: - if all(isinstance(p, Replicate) for p in strategy): - raise AssertionError - single_dim_strategies_with_placeholders.insert( - 0, [Replicate()] * (num_outputs + num_input_tensors) - ) + if all(isinstance(p, Replicate) or p is None for p in strategy): + return single_dim_strategies_with_placeholders + total_len = num_outputs + num_input_tensors + replicate_rule: list[Placement | _ShardingPlaceholder | None] = [ + Replicate() + ] * total_len + # Set None for masked-off output positions based on output_tensor_meta + if isinstance(output_tensor_meta, Sequence): + for i, meta in enumerate(output_tensor_meta): + if meta is None and i < num_outputs: + replicate_rule[i] = None + single_dim_strategies_with_placeholders.insert(0, replicate_rule) return single_dim_strategies_with_placeholders def _fill_single_dim_strategy_placeholders( unique_input_placements: set[Placement], single_dim_strategies_with_placeholders: list[ - list[Placement | _ShardingPlaceholder] + list[Placement | _ShardingPlaceholder | None] ], -) -> list[list[Placement]]: +) -> list[list[Placement | None]]: """ Replace any _ShardingPlaceholder with the specific Sharding types used by the inputs in op_schema. Supports implicit replication. @@ -137,25 +150,29 @@ def _fill_single_dim_strategy_placeholders( # if any of the placements is a placeholder, we need to expand the strategy # to all possible combinations of placements - expanded_strategies_over_one_mesh_dim: list[list[Placement]] = [] + expanded_strategies_over_one_mesh_dim: list[list[Placement | None]] = [] for s in single_dim_strategies_with_placeholders: if any(isinstance(p, _ShardingPlaceholder) for p in s): for shard_builder in shard_builders.values(): - expanded_strategy: list[Placement] = [] + expanded_strategy: list[Placement | None] = [] for maybe_placeholder in s: if isinstance(maybe_placeholder, _ShardingPlaceholder): # we combine the tensor dim to shard from the placeholder # with other metadata (e.g. split_factor) from the sharding class expanded_strategy.append(shard_builder(maybe_placeholder.dim)) + elif maybe_placeholder is None: + expanded_strategy.append(None) else: if not isinstance(maybe_placeholder, Placement): raise AssertionError expanded_strategy.append(maybe_placeholder) expanded_strategies_over_one_mesh_dim.append(expanded_strategy) else: - if not all(isinstance(p, Placement) for p in s): + if not all(isinstance(p, Placement) or p is None for p in s): raise AssertionError - expanded_strategies_over_one_mesh_dim.append(cast(list[Placement], (s))) + expanded_strategies_over_one_mesh_dim.append( + cast(list[Placement | None], (s)) + ) return expanded_strategies_over_one_mesh_dim @@ -173,6 +190,9 @@ def _update_placements(obj: Any): elif isinstance(obj, TupleStrategy): for child in obj.children: _update_placements(child) + elif isinstance(obj, (list, tuple)): + for child in obj: + _update_placements(child) for obj in op_schema.args_schema: _update_placements(obj) @@ -185,30 +205,35 @@ def _update_placements(obj: Any): def _get_num_tensor_inputs(op_schema: OpSchema) -> int: num_inputs = 0 - for obj in op_schema.args_schema: + + def _count(obj: Any) -> int: if isinstance(obj, OpStrategy): - num_inputs += 1 + return 1 elif isinstance(obj, TupleStrategy): - num_inputs += len(obj.children) + return sum(1 for child in obj.children if child is not None) + elif isinstance(obj, (list, tuple)): + return sum(_count(child) for child in obj) + return 0 + + for obj in op_schema.args_schema: + num_inputs += _count(obj) # Also count tensor kwargs (e.g., "out" for out-variant ops) for obj in op_schema.kwargs_schema.values(): - if isinstance(obj, OpStrategy): - num_inputs += 1 - elif isinstance(obj, TupleStrategy): - num_inputs += len(obj.children) + num_inputs += _count(obj) return num_inputs def _build_output_specs( mesh: DeviceMesh, - per_mesh_dim_placements: list[tuple[Placement, ...]], + per_mesh_dim_placements: list[tuple[Placement | None, ...]], num_outputs: int, output_metas: tuple[TensorMeta | None, ...], ) -> DTensorSpec | tuple[DTensorSpec | None, ...]: """Build output spec(s) by transposing per-mesh-dim placements to per-output. per_mesh_dim_placements is indexed [mesh_dim][output_idx]. output_metas must - have exactly num_outputs elements. + have exactly num_outputs elements. Outputs where output_metas[i] is None + (masked-off outputs) produce None specs. """ if num_outputs <= 0: raise AssertionError(f"Expected num_outputs > 0, got {num_outputs}") @@ -217,16 +242,21 @@ def _build_output_specs( f"Expected {num_outputs} output_metas, got {len(output_metas)}" ) - def _placements_for_output(out_idx: int) -> tuple[Placement, ...]: - return tuple(out[out_idx] for out in per_mesh_dim_placements) + def _spec_for_output(out_idx: int) -> DTensorSpec | None: + if output_metas[out_idx] is None: + return None + placements = tuple( + cast(Placement, out[out_idx]) for out in per_mesh_dim_placements + ) + return DTensorSpec(mesh, placements, tensor_meta=output_metas[out_idx]) if num_outputs > 1: - return tuple( - DTensorSpec(mesh, _placements_for_output(i), tensor_meta=output_metas[i]) - for i in range(num_outputs) - ) + return tuple(_spec_for_output(i) for i in range(num_outputs)) else: - return DTensorSpec(mesh, _placements_for_output(0), tensor_meta=output_metas[0]) + spec = _spec_for_output(0) + if spec is None: + raise AssertionError("Single-output op cannot have None output meta") + return spec class _PreparedSingleDimStrategy: @@ -242,8 +272,8 @@ class _PreparedSingleDimStrategy: allowed_partial_per_input for graph search neighbor generation. """ - strategy_lookup: dict[tuple[Placement, ...], tuple[Placement, ...]] - expanded_strategies: list[list[Placement]] + strategy_lookup: dict[tuple[Placement | None, ...], tuple[Placement | None, ...]] + expanded_strategies: list[list[Placement | None]] num_outputs: int num_inputs: int output_metas: tuple[TensorMeta | None, ...] @@ -275,18 +305,63 @@ def __init__( if isinstance(strategy_fn, _SingleDimStrategyInfo): self.allow_unbacked_sharding = strategy_fn.allow_unbacked_sharding self.allow_uneven_sharding = strategy_fn.allow_uneven_sharding + different_mesh_args = strategy_fn.different_mesh_args func = strategy_fn.func else: self.allow_unbacked_sharding = None self.allow_uneven_sharding = False + different_mesh_args = None func = strategy_fn + # Determine element_mesh from the first OpStrategy arg. For foreach + # per-element schemas the element's inputs may live on a smaller + # sub-mesh than the global compute_mesh. + self.element_mesh: DeviceMesh | None = None + for arg in op_schema.args_schema: + if isinstance(arg, OpStrategy): + self.element_mesh = arg.strategies[0].output_spec.mesh + break + + # Validate that all inputs are on the same mesh (except + # different_mesh_args which are explicitly allowed to differ). + if self.element_mesh is not None: + allowed = set(different_mesh_args or []) + for i, arg in enumerate(op_schema.args_schema): + if isinstance(arg, OpStrategy) and i not in allowed: + arg_mesh = arg.strategies[0].output_spec.mesh + if arg_mesh != self.element_mesh: + raise ValueError( + f"Cannot run {op_schema.op} on inputs with different " + f"meshes: got {self.element_mesh} and {arg_mesh}" + ) + + # Remap different_mesh_args from args_schema positions to + # OpStrategy-only positions. Non-OpStrategy args (e.g. empty lists) + # are filtered out by expand_to_full_mesh_op_strategy, shifting later + # indices. + self.remapped_different_mesh_args: list[int] | None = None + if different_mesh_args is not None: + schema_to_strategy: dict[int, int] = {} + strategy_pos = 0 + for schema_pos, arg in enumerate(op_schema.args_schema): + if isinstance(arg, OpStrategy): + schema_to_strategy[schema_pos] = strategy_pos + strategy_pos += 1 + self.remapped_different_mesh_args = [ + schema_to_strategy[i] + for i in different_mesh_args + if i in schema_to_strategy + ] + if num_inputs is None: num_inputs = _get_num_tensor_inputs(op_schema) self.num_inputs = num_inputs - strategies_with_placeholders = func( - op_schema.op, op_schema.args_meta, op_schema.kwargs_meta + # Strategy functions may return None in output positions for masked-off + # outputs (e.g. backward ops with output_mask). Widen the type here. + strategies_with_placeholders = cast( + list[list[Placement | _ShardingPlaceholder | None]], + func(op_schema.op, op_schema.args_meta, op_schema.kwargs_meta), ) # Validate strategy length against the op schema. The schema is the @@ -320,7 +395,10 @@ def __init__( self.num_outputs = num_outputs strategies_with_placeholders = _insert_single_dim_replication_strategy( - strategies_with_placeholders, num_outputs, num_inputs + strategies_with_placeholders, + num_outputs, + num_inputs, + output_tensor_meta, ) unique_input_placements = _get_unique_placements(op_schema) @@ -343,6 +421,8 @@ def __init__( for strategy in self.expanded_strategies: for input_idx in range(num_inputs): p = strategy[num_outputs + input_idx] + if p is None: + continue if _is_sharding(p): self.allowed_sharding_per_input[input_idx].add(p) elif isinstance(p, Partial): @@ -371,7 +451,7 @@ def try_propagate( """ from torch.distributed.tensor._ops.utils import is_tensor_shardable - selected_output_placements: list[tuple[Placement, ...]] = [] + selected_output_placements: list[tuple[Placement | None, ...]] = [] for mesh_dim in range(mesh.ndim): input_placements_for_dim = tuple( placements[mesh_dim] for placements in input_placements @@ -457,15 +537,18 @@ def expanded_strategy( base_name = op_name.split("::")[1].split(".")[0] is_inplace = base_name.endswith("_") + element_mesh = prepared_strategy.element_mesh or mesh + return expand_to_full_mesh_op_strategy( - mesh, + element_mesh, op_schema, - cast(list[PlacementList], prepared_strategy.expanded_strategies), + prepared_strategy.expanded_strategies, output_tensor_meta=output_tensor_meta, inplace_op=is_inplace, input_index=prepared_strategy.num_outputs, allow_unbacked_sharding=prepared_strategy.allow_unbacked_sharding, allow_uneven_sharding=prepared_strategy.allow_uneven_sharding, + different_mesh_args=prepared_strategy.remapped_different_mesh_args, ) return expanded_strategy @@ -487,12 +570,14 @@ def _create_expanded_strategy( # Unhashable types (SymInts), skip caching return _create_expanded_strategy_impl(op_schema, output_tensor_meta) - def _translate_foreach_op_schema( - op_schema: OpSchema, output_tensor_meta: Sequence[TensorMeta], index: int - ) -> tuple[OpSchema, TensorMeta]: - """Translate foreach op to per-element version of schema.""" + def _translate_list_op_schema( + op_schema: OpSchema, + output_tensor_meta: Sequence[TensorMeta] | None, + index: int, + ) -> tuple[OpSchema, TensorMeta | None]: + """Translate foreach/fused op to per-element version of schema.""" op_parts = str(op_schema.op).split(".") - base_op_name = op_parts[-2].replace("_foreach_", "") + op_name = op_parts[-2] foreach_variant = op_parts[-1] # select per-element inputs, outputs @@ -502,7 +587,30 @@ def _translate_foreach_op_schema( (op_schema.args_schema, op_schema.kwargs_schema), is_leaf=lambda x: isinstance(x, TupleStrategy), ) - target_output_meta = output_tensor_meta[index] + # For inplace ops, output_tensor_meta is None + target_output_meta = ( + output_tensor_meta[index] if output_tensor_meta is not None else None + ) + + # Strip the prefix to get the base op name and find the per-element op. + # Fused ops (e.g. _fused_adam) have no per-element ATen equivalent, + # so we keep the original op unchanged. + if op_name.startswith("_foreach_"): + base_op_name = op_name.replace("_foreach_", "", 1) + elif op_name.startswith("_amp_foreach_"): + base_op_name = op_name.replace("_amp_foreach_", "", 1) + else: + # Fused ops or unknown: keep original op, no translation + target_op = op_schema.op + op_schema = OpSchema( + target_op, # type: ignore[arg-type] + args_schema=tuple(target_args), + kwargs_schema=op_schema.kwargs_schema, + ) + return op_schema, target_output_meta + + # Strip trailing underscore for inplace ops + base_op_name = base_op_name.removesuffix("_") # figure out target op variant variant_map = { @@ -551,7 +659,7 @@ def expanded_foreach_strategy( child_strategies: list[StrategyType] = [] for tensorlist_i in range(tensorlist_len): - per_index_schema, per_index_output_meta = _translate_foreach_op_schema( + per_index_schema, per_index_output_meta = _translate_list_op_schema( op_schema, output_tensor_meta, # type: ignore[arg-type] tensorlist_i, @@ -568,7 +676,16 @@ def expanded_foreach_strategy( return TupleStrategy(children=child_strategies) # TODO maybe this could be helped by adding a new 'tag' to the OpOverload? - if op_schema.op.name().startswith("aten::_foreach_"): + # Only use the foreach path if the op has TupleStrategy inputs (i.e., actual + # list-of-tensor args). The name prefix alone is insufficient because ops like + # _fused_rms_norm share the "_fused_" prefix but are not foreach/fused-optimizer ops. + op_name = op_schema.op.name() + has_tuple_strategy = any( + isinstance(arg, TupleStrategy) for arg in op_schema.args_schema + ) + if has_tuple_strategy and op_name.startswith( + ("aten::_foreach_", "aten::_amp_foreach_", "aten::_fused_") + ): return expanded_foreach_strategy return _create_expanded_strategy(op_schema, output_tensor_meta) @@ -579,6 +696,7 @@ def register_single_dim_strategy( schema_info: RuntimeSchemaInfo | None = None, allow_unbacked_sharding: bool | None = None, allow_uneven_sharding: bool = False, + different_mesh_args: list[int] | None = None, ) -> Callable[[_SingleDimStrategyFunc], _SingleDimStrategyFunc]: """ Registers a single_dim_strategy function for the given op. @@ -627,6 +745,7 @@ def wrapper(impl): func=impl, allow_unbacked_sharding=allow_unbacked_sharding, allow_uneven_sharding=allow_uneven_sharding, + different_mesh_args=different_mesh_args, ) registration_wrapper(info) return impl @@ -810,6 +929,10 @@ def _dijkstra_expand_single_dim_strategy_to_mesh( total_bytes = spec.tensor_meta.dtype.itemsize * math.prod( spec.tensor_meta.shape ) + # TODO: is_shard() misses _StridedShard, use spec.num_shards instead. + # Not fixing yet: the overestimate biases Dijkstra toward redistributing + # away from _StridedShard, which is the safer default until _StridedShard + # is fully validated. num_shards = 1 for i, p in enumerate(spec.placements): if p.is_shard(): diff --git a/torch/distributed/tensor/_ops/strategy_validation.py b/torch/distributed/tensor/_ops/strategy_validation.py index c48f425706e47..3d94de2d1455d 100644 --- a/torch/distributed/tensor/_ops/strategy_validation.py +++ b/torch/distributed/tensor/_ops/strategy_validation.py @@ -62,6 +62,7 @@ SKIP_OPS: dict[str, str] = { "bernoulli": "non-deterministic (random sampling)", "empty_like": "uninitialized memory", + "exponential": "non-deterministic (random sampling)", "new_empty": "uninitialized memory", "new_empty_strided": "uninitialized memory", "nn.functional.dropout": "non-deterministic (random masking)", @@ -158,6 +159,7 @@ def is_fully_replicated(placements: tuple[Placement, ...]) -> bool: def is_trivial_shard(p: Placement, tensor_shape: tuple[int, ...]) -> bool: """Check if placement is a Shard on a size-1 dimension.""" + # NOTE: isinstance(_, Shard) does not match _StridedShard; see _is_shard_like(). return ( isinstance(p, Shard) and p.dim < len(tensor_shape) and tensor_shape[p.dim] == 1 ) @@ -429,6 +431,121 @@ def _create_partial_input( return LocalTensor(local_tensors) +def _shard_tensors( + tensors: list[tuple[str, torch.Tensor]], + input_placements: tuple[Placement, ...], + world_size: int, + mesh: DeviceMesh, + mask_shift: int = 0, +) -> list[LocalTensor | torch.Tensor]: + """Create sharded LocalTensors from tensors according to placements.""" + local_tensors: list[LocalTensor | torch.Tensor] = [] + for tensor_idx, ((name, tensor), placement) in enumerate( + zip(tensors, input_placements) + ): + if isinstance(placement, Partial): + local_tensor = _create_partial_input( + tensor, placement, world_size, tensor_idx, mask_shift + ) + elif isinstance(placement, Replicate): + _tmp = {r: tensor.clone() for r in range(world_size)} + # pyrefly: ignore [bad-argument-type, bad-argument-count] + local_tensor = LocalTensor(_tmp) + elif isinstance(placement, Shard): + shard_dim = placement.dim + chunks = tensor.tensor_split(world_size, dim=shard_dim) + _tmp = { + r: chunks[r].clone(memory_format=torch.contiguous_format) + for r in range(world_size) + } + # pyrefly: ignore [bad-argument-type, bad-argument-count] + local_tensor = LocalTensor(_tmp) + else: + dt = distribute_tensor(tensor.clone(), mesh, (placement,)) + local_tensor = dt.to_local() + local_tensors.append(local_tensor) + return local_tensors + + +def _compare_outputs( + local_output: Any, + ground_truth: torch.Tensor | list[torch.Tensor], + output_placements: tuple[Placement, ...], + mesh: DeviceMesh, + world_size: int, +) -> tuple[bool, str]: + """Compare op output (wrapped as DTensor) against ground truth.""" + if isinstance(local_output, (list, tuple)): + local_outputs = list(local_output) + else: + local_outputs = [local_output] + + if isinstance(ground_truth, list): + ground_truths = ground_truth + else: + ground_truths = [ground_truth] + + if len(local_outputs) != len(ground_truths): + return ( + False, + f"Output count mismatch: got {len(local_outputs)}, " + f"expected {len(ground_truths)}", + ) + + if len(local_outputs) != len(output_placements): + return ( + False, + f"Output count mismatch with placements: " + f"got {len(local_outputs)}, expected {len(output_placements)}", + ) + + for i, (local_out, gt, out_plc) in enumerate( + zip(local_outputs, ground_truths, output_placements) + ): + if not isinstance(local_out, torch.Tensor): + return False, f"Local output[{i}] is not a tensor: {type(local_out)}" + + if not isinstance(local_out, LocalTensor): + return False, f"LocalTensor inputs produced non-LocalTensor output[{i}]" + + output_dt = DTensor.from_local( + local_out, + mesh, + (out_plc,), + shape=gt.shape, + stride=gt.stride(), + ) + + if isinstance(out_plc, Replicate): + local_values = [local_out._local_tensors[r] for r in range(world_size)] + all_same = all( + torch.allclose(local_values[0], lv, atol=1e-5, rtol=1e-5) + for lv in local_values[1:] + ) + if not all_same: + return ( + False, + f"Replicate output[{i}] but local values differ across ranks", + ) + + full_output = output_dt.redistribute(mesh, (Replicate(),)).to_local() + + if isinstance(full_output, LocalTensor): + full_output = full_output._local_tensors[0] + + if gt.shape != full_output.shape: + return ( + False, + f"Shape mismatch[{i}]: expected {gt.shape}, got {full_output.shape}", + ) + + if not torch.allclose(gt, full_output, atol=1e-5, rtol=1e-5, equal_nan=True): + max_diff = (gt - full_output).abs().max().item() + return False, f"Value mismatch[{i}]: max_diff={max_diff:.6f}" + + return True, "" + + def validate_combination( op: Callable[..., Any], sample_input: SampleInput, @@ -438,10 +555,13 @@ def validate_combination( world_size: int = 2, mesh: DeviceMesh | None = None, mask_shift: int = 0, -) -> tuple[bool, str]: +) -> tuple[bool | None, str]: """ Validate a single placement combination against ground truth. + Returns (True, "") if valid, (False, error_msg) if invalid, or + (None, reason) if the combination cannot be tested (e.g. uneven shards). + The validation logic: 1. Shard inputs according to input placements to get local tensors 2. Run the raw op on local tensors (bypassing DTensor dispatch) @@ -467,33 +587,17 @@ def validate_combination( device = tensors[0][1].device.type if tensors else "cpu" mesh = init_device_mesh(device, (world_size,)) - local_tensors = [] - for tensor_idx, ((name, tensor), placement) in enumerate( - zip(tensors, combination[0]) - ): - if isinstance(placement, Partial): - local_tensor = _create_partial_input( - tensor, placement, world_size, tensor_idx, mask_shift - ) - elif isinstance(placement, Replicate): - _tmp = {r: tensor.clone() for r in range(world_size)} - # pyrefly: ignore [bad-argument-type, bad-argument-count] - local_tensor = LocalTensor(_tmp) - elif isinstance(placement, Shard): - # Create sharded LocalTensor directly to work in LocalTensorMode - shard_dim = placement.dim - chunks = tensor.tensor_split(world_size, dim=shard_dim) - _tmp = { - r: chunks[r].clone(memory_format=torch.contiguous_format) - for r in range(world_size) - } - # pyrefly: ignore [bad-argument-type, bad-argument-count] - local_tensor = LocalTensor(_tmp) - else: - # Fallback for other placement types - dt = distribute_tensor(tensor.clone(), mesh, (placement,)) - local_tensor = dt.to_local() - local_tensors.append(local_tensor) + # Uneven shards produce SymInt in LocalTensor's wrapper shape, + # which breaks C++ overload resolution before __torch_dispatch__ + # can intercept. Return None to signal "untestable". + for (name, tensor), placement in zip(tensors, combination[0]): + if isinstance(placement, Shard): + if tensor.size(placement.dim) % world_size != 0: + return None, "uneven shard" + + local_tensors = _shard_tensors( + tensors, combination[0], world_size, mesh, mask_shift + ) local_idx = 0 @@ -515,84 +619,95 @@ def _replace_with_local(a): local_output = op(local_input, *local_args, **local_kwargs) - # Normalize to list for uniform handling of single/multi-output ops - if isinstance(local_output, (list, tuple)): - local_outputs = list(local_output) - else: - local_outputs = [local_output] + return _compare_outputs( + local_output, ground_truth, combination[1], mesh, world_size + ) - if isinstance(ground_truth, list): - ground_truths = ground_truth - else: - ground_truths = [ground_truth] + except Exception as e: + # TODO: This is too broad. Consider: (1) explicit checks for shard dim + # validity and shape compatibility before calling tensor_split/from_local, + # (2) scoped try/except around op() and redistribute() that raise specific + # exceptions (e.g., UnsupportedRedistribute, OpError), and (3) only + # catching those here, letting real bugs propagate. + return False, f"Exception: {type(e).__name__}: {e}" - if len(local_outputs) != len(ground_truths): - return ( - False, - f"Output count mismatch: got {len(local_outputs)}, expected {len(ground_truths)}", - ) - if len(local_outputs) != len(combination[1]): - return ( - False, - f"Output count mismatch with placements: " - f"got {len(local_outputs)}, expected {len(combination[1])}", - ) +def extract_tensors_from_args( + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> list[tuple[str, torch.Tensor]]: + """Extract tensor arguments from captured aten args/kwargs. - for i, (local_out, gt, out_plc) in enumerate( - zip(local_outputs, ground_truths, combination[1]) - ): - if not isinstance(local_out, torch.Tensor): - return False, f"Local output[{i}] is not a tensor: {type(local_out)}" + Unlike extract_tensors_from_sample which walks SampleInput pytrees, + this walks the flat aten-level args and kwargs directly. + """ + tensors: list[tuple[str, torch.Tensor]] = [] + idx = 0 - if not isinstance(local_out, LocalTensor): - return False, f"LocalTensor inputs produced non-LocalTensor output[{i}]" + def _collect(x: Any) -> Any: + nonlocal idx + if isinstance(x, torch.Tensor): + tensors.append((f"tensor_{idx}", x)) + idx += 1 + return x - output_dt = DTensor.from_local( - local_out, - mesh, - (out_plc,), - shape=gt.shape, - stride=gt.stride(), - ) + pytree.tree_map(_collect, args) + pytree.tree_map(_collect, kwargs) + return tensors - if isinstance(out_plc, Replicate): - local_values = [local_out._local_tensors[r] for r in range(world_size)] - all_same = all( - torch.allclose(local_values[0], lv, atol=1e-5, rtol=1e-5) - for lv in local_values[1:] - ) - if not all_same: - return ( - False, - f"Replicate output[{i}] but local values differ across ranks", - ) - full_output = output_dt.redistribute(mesh, (Replicate(),)).to_local() +def validate_aten_combination( + aten_op: OpOverload, + captured_args: tuple[Any, ...], + captured_kwargs: dict[str, Any], + ground_truth: torch.Tensor | list[torch.Tensor], + combination: PlacementCombination, + world_size: int, + mesh: DeviceMesh, + mask_shift: int = 0, +) -> tuple[bool | None, str]: + """Validate a placement combination using aten-level captured args. + + Works directly with aten op args/kwargs instead of SampleInput pytrees. + Replaces tensors in the flat args/kwargs with sharded LocalTensors, + calls the aten op, and compares output. - if isinstance(full_output, LocalTensor): - full_output = full_output._local_tensors[0] + Returns (True, ""), (False, error_msg), or (None, reason) if untestable. + """ + try: + tensors = extract_tensors_from_args(captured_args, captured_kwargs) + if not tensors: + return False, "No tensor args in captured aten call" - if gt.shape != full_output.shape: - return ( - False, - f"Shape mismatch[{i}]: expected {gt.shape}, got {full_output.shape}", - ) + for (name, tensor), placement in zip(tensors, combination[0]): + if isinstance(placement, Shard): + if tensor.size(placement.dim) % world_size != 0: + return None, "uneven shard" + + local_tensors = _shard_tensors( + tensors, combination[0], world_size, mesh, mask_shift + ) + + local_idx = 0 + + def _replace_with_local(a: Any) -> Any: + nonlocal local_idx + if isinstance(a, torch.Tensor): + local = local_tensors[local_idx] + local_idx += 1 + return local + return a - if not torch.allclose( - gt, full_output, atol=1e-5, rtol=1e-5, equal_nan=True - ): - max_diff = (gt - full_output).abs().max().item() - return False, f"Value mismatch[{i}]: max_diff={max_diff:.6f}" + local_args = pytree.tree_map(_replace_with_local, captured_args) + local_kwargs = pytree.tree_map(_replace_with_local, captured_kwargs) - return True, "" + local_output = aten_op(*local_args, **local_kwargs) + + return _compare_outputs( + local_output, ground_truth, combination[1], mesh, world_size + ) except Exception as e: - # TODO: This is too broad. Consider: (1) explicit checks for shard dim - # validity and shape compatibility before calling tensor_split/from_local, - # (2) scoped try/except around op() and redistribute() that raise specific - # exceptions (e.g., UnsupportedRedistribute, OpError), and (3) only - # catching those here, letting real bugs propagate. return False, f"Exception: {type(e).__name__}: {e}" @@ -695,34 +810,41 @@ def _extract_rules_from_op_strategy( class _CaptureAtenOp(torch.utils._python_dispatch.TorchDispatchMode): - """Dispatch mode that captures aten ops called and their args.""" + """Dispatch mode that captures aten ops called, their args, and return values.""" def __init__(self, target_op_name: str = ""): self.target_op_name = target_op_name.lower() - self.all_ops: list[tuple[OpOverload, tuple[Any, ...], dict[str, Any]]] = [] + self.all_ops: list[tuple[OpOverload, tuple[Any, ...], dict[str, Any], Any]] = [] self.best_match: OpOverload | None = None self.best_match_args: tuple[Any, ...] | None = None self.best_match_kwargs: dict[str, Any] | None = None + self.best_match_result: Any = None def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} + result = func(*args, **kwargs) if func.namespace == "aten": - self.all_ops.append((func, args, kwargs)) + self.all_ops.append((func, args, kwargs, result)) op_name = func.name().split("::")[1].split(".")[0].lower() if self.target_op_name and self.target_op_name in op_name: if self.best_match is None: self.best_match = func self.best_match_args = args self.best_match_kwargs = kwargs - return func(*args, **kwargs) + self.best_match_result = result + return result def get_aten_op_for_sample( op: Callable[..., Any], sample: SampleInput, op_name: str = "" -) -> tuple[OpOverload | None, tuple[Any, ...], dict[str, Any]]: +) -> _CaptureAtenOp: """ - Determine the actual aten op that will be dispatched for a given sample. + Capture aten ops dispatched for a given sample. + + Returns the _CaptureAtenOp object containing all captured ops with their + args, kwargs, and return values. Use best_match for the primary op or + all_ops for exhaustive iteration. """ with _CaptureAtenOp(op_name) as capture: try: @@ -733,16 +855,15 @@ def get_aten_op_for_sample( except Exception: pass - if capture.best_match is not None: - captured_op = capture.best_match - captured_args = capture.best_match_args - captured_kwargs = capture.best_match_kwargs - elif capture.all_ops: - captured_op, captured_args, captured_kwargs = capture.all_ops[0] - else: - return None, (), {} + # Populate best_match from first op if target match wasn't found + if capture.best_match is None and capture.all_ops: + first_op, first_args, first_kwargs, first_result = capture.all_ops[0] + capture.best_match = first_op + capture.best_match_args = first_args + capture.best_match_kwargs = first_kwargs + capture.best_match_result = first_result - return captured_op, captured_args, captured_kwargs + return capture def query_single_dim_strategy( @@ -1091,12 +1212,17 @@ def _validate_with_mitigations( world_size: int, mesh: DeviceMesh, mitigations: _FalsePositiveMitigations, -) -> bool: - """Validate a combination, including false positive mitigation re-checks.""" +) -> bool | None: + """Validate a combination, including false positive mitigation re-checks. + + Returns True (valid), False (invalid), or None (untestable). + """ combo: PlacementCombination = (input_placements, output_placements) is_valid, _ = validate_combination( op, sample, tensors, combo, ground_truth, world_size, mesh ) + if is_valid is None: + return None # Flipped-mask mitigation: the checkerboard mask that controls offset # signs (for P(sum)/P(avg)) or rank ownership (for P(min)/P(max)) is @@ -1174,6 +1300,115 @@ def _validate_with_mitigations( return is_valid +@dataclass +class _AtenFalsePositiveMitigations: + """Bundle of negated variants for aten-level false positive detection. + + Unlike _FalsePositiveMitigations, this works with captured aten args/kwargs + directly, without SampleInput. The rounding_mode mitigation is skipped + since it's an OpInfo-level concept that doesn't appear in captured aten kwargs. + """ + + negated_args: tuple[Any, ...] | None = None + negated_kwargs: dict[str, Any] | None = None + negated_ground_truth: torch.Tensor | list[torch.Tensor] | None = None + + +def _negate_tensors_in_tree(tree: Any) -> Any: + """Negate all tensors in a pytree structure.""" + + def _negate(x: Any) -> Any: + if isinstance(x, torch.Tensor): + return -x + return x + + return pytree.tree_map(_negate, tree) + + +def _prepare_aten_mitigations( + aten_op: OpOverload, + captured_args: tuple[Any, ...], + captured_kwargs: dict[str, Any], +) -> _AtenFalsePositiveMitigations: + """Create negated variants for aten-level false positive detection.""" + m = _AtenFalsePositiveMitigations() + try: + m.negated_args = _negate_tensors_in_tree(captured_args) + m.negated_kwargs = _negate_tensors_in_tree(captured_kwargs) + result = aten_op(*m.negated_args, **m.negated_kwargs) + if _is_tensor_output(result): + m.negated_ground_truth = _to_ground_truth(result) + else: + m.negated_args = None + m.negated_kwargs = None + except Exception: + m.negated_args = None + m.negated_kwargs = None + return m + + +def _validate_aten_with_mitigations( + aten_op: OpOverload, + captured_args: tuple[Any, ...], + captured_kwargs: dict[str, Any], + input_placements: tuple[Placement, ...], + output_placements: tuple[Placement, ...], + ground_truth: torch.Tensor | list[torch.Tensor], + world_size: int, + mesh: DeviceMesh, + mitigations: _AtenFalsePositiveMitigations, +) -> bool | None: + """Validate an aten-level combination with false positive mitigations. + + Returns True (valid), False (invalid), or None (untestable). + """ + combo: PlacementCombination = (input_placements, output_placements) + is_valid, _ = validate_aten_combination( + aten_op, + captured_args, + captured_kwargs, + ground_truth, + combo, + world_size, + mesh, + ) + if is_valid is None: + return None + + if is_valid and has_any_partial(input_placements, output_placements): + is_valid, _ = validate_aten_combination( + aten_op, + captured_args, + captured_kwargs, + ground_truth, + combo, + world_size, + mesh, + mask_shift=1, + ) + + if ( + is_valid + and mitigations.negated_args is not None + and has_pmin_pmax(input_placements, output_placements) + ): + if mitigations.negated_kwargs is None: + raise AssertionError("negated_kwargs must not be None") + if mitigations.negated_ground_truth is None: + raise AssertionError("negated_ground_truth must not be None") + is_valid, _ = validate_aten_combination( + aten_op, + mitigations.negated_args, + mitigations.negated_kwargs, + mitigations.negated_ground_truth, + combo, + world_size, + mesh, + ) + + return is_valid + + def _assert_keys_normalized( keys: set[ComboKey], input_shapes: tuple[tuple[int, ...], ...], @@ -1200,10 +1435,13 @@ def _compare_rules( variant: str, stats: ComparisonStats, sample: SampleInput | None = None, + untestable: set[ComboKey] | None = None, ) -> None: """Compare ground truth valid rules against DTensor claimed rules, updating stats.""" if not dtensor_rules: return + if untestable is None: + untestable = set() _assert_keys_normalized(ground_truth_valid, input_shapes, output_shapes) _assert_keys_normalized(dtensor_rules, input_shapes, output_shapes) @@ -1215,7 +1453,7 @@ def _compare_rules( stats.true_positives_by_op[op_str] = ( stats.true_positives_by_op.get(op_str, 0) + 1 ) - else: + elif combo_key not in untestable: stats.false_negatives.append( Discrepancy( input_placements=combo_key[0], @@ -1232,7 +1470,7 @@ def _compare_rules( ) for combo_key in dtensor_rules: - if combo_key not in ground_truth_valid: + if combo_key not in ground_truth_valid and combo_key not in untestable: stats.false_positives.append( Discrepancy( input_placements=combo_key[0], @@ -1374,12 +1612,201 @@ def _discover_aten_op( tensors = extract_tensors_from_sample(sample) if not tensors or any(0 in t.shape for _, t in tensors): continue - aten_op, _, _ = get_aten_op_for_sample(opinfo.op, sample, opinfo.name) + capture = get_aten_op_for_sample(opinfo.op, sample, opinfo.name) + aten_op = capture.best_match if aten_op is not None: return aten_op return None +def _check_ground_truth( + result: Any, +) -> torch.Tensor | list[torch.Tensor] | None: + """Validate an op result is suitable as ground truth. + + Returns the ground truth tensor(s) or None if the result should be skipped. + """ + if isinstance(result, (list, tuple)): + if not all(isinstance(t, torch.Tensor) for t in result): + return None + gt = list(result) + elif isinstance(result, torch.Tensor): + gt = result + else: + return None + + first_gt = gt[0] if isinstance(gt, list) else gt + if first_gt.numel() == 0: + return None + if (first_gt == 0).all(): + return None + if first_gt.isnan().all(): + return None + return gt + + +def _validate_aten_op_for_sample( + aten_op: OpOverload, + captured_args: tuple[Any, ...], + captured_kwargs: dict[str, Any], + ground_truth: torch.Tensor | list[torch.Tensor], + world_size: int, + incorrect_only: bool, + verbose: bool, + sample_idx: int, + variant: str, + stats: ComparisonStats, + sample: SampleInput | None = None, +) -> tuple[int, int]: + """Validate a single aten op with captured args against ground truth. + + Shared logic used by both default (1:1) and allow_composite modes in + compare_operator. Returns (samples_counted, combinations_counted). + """ + tensors = extract_tensors_from_args(captured_args, captured_kwargs) + if not tensors: + return 0, 0 + if any(0 in t.shape for _, t in tensors): + return 0, 0 + + input_shapes = tuple(t.shape for _, t in tensors) + gt_list = ground_truth if isinstance(ground_truth, list) else [ground_truth] + output_shapes = tuple(tuple(gt.shape) for gt in gt_list) + n_outputs = len(gt_list) + first_gt = gt_list[0] + + scalar_args = tuple(a for a in captured_args if not isinstance(a, torch.Tensor)) + scalar_kwargs = { + k: v for k, v in captured_kwargs.items() if not isinstance(v, torch.Tensor) + } + + mitigations = _prepare_aten_mitigations(aten_op, captured_args, captured_kwargs) + + input_placement_options = [ + get_1d_input_placements_for_tensor(t, include_partial=True) for _, t in tensors + ] + output_placement_options = get_1d_output_placements_for_tensor(first_gt) + + dtensor_rules = _query_dtensor_rules( + aten_op, + tensors, + captured_args, + captured_kwargs, + input_shapes, + output_shapes, + world_size, + verbose, + ) + + ground_truth_valid: set[ComboKey] = set() + total_combinations = 0 + + tensor_device = tensors[0][1].device.type if tensors else "cpu" + with LocalTensorMode(frozenset(range(world_size))): + mesh = init_device_mesh(tensor_device, (world_size,)) + + if incorrect_only: + combinations_to_test = [] + for combo_key in dtensor_rules: + input_plc_strs, output_plc_strs = combo_key + input_plcs_list: list[Placement] = [] + all_valid = True + for s in input_plc_strs: + p = parse_placement(s) + if p is None: + all_valid = False + break + input_plcs_list.append(p) + output_plcs_list: list[Placement] = [] + for s in output_plc_strs: + p = parse_placement(s) + if p is None: + all_valid = False + break + output_plcs_list.append(p) + if not all_valid: + continue + combinations_to_test.append( + ( + tuple(input_plcs_list), + tuple(output_plcs_list), + combo_key, + ) + ) + else: + combinations_to_test = [] + for input_placements in itertools.product(*input_placement_options): + if is_fully_replicated(input_placements): + continue + for output_placement in output_placement_options: + output_placements = tuple( + output_placement for _ in range(n_outputs) + ) + combo_key = ( + tuple(str(p) for p in input_placements), + tuple(str(p) for p in output_placements), + ) + combinations_to_test.append( + (input_placements, output_placements, combo_key) + ) + + untestable: set[ComboKey] = set() + + for ( + input_placements, + output_placements, + combo_key, + ) in combinations_to_test: + total_combinations += 1 + is_valid = _validate_aten_with_mitigations( + aten_op, + captured_args, + captured_kwargs, + input_placements, + output_placements, + ground_truth, + world_size, + mesh, + mitigations, + ) + + if is_valid is None: + normalized_key = normalize_combo_key( + combo_key, input_shapes, output_shapes + ) + untestable.add(normalized_key) + elif is_valid: + normalized_key = normalize_combo_key( + combo_key, input_shapes, output_shapes + ) + if not is_fully_replicated( + tuple(parse_placement(p) or Replicate() for p in normalized_key[0]) + ): + ground_truth_valid.add(normalized_key) + + _compare_rules( + ground_truth_valid, + dtensor_rules, + input_shapes, + output_shapes, + sample_idx, + scalar_args, + scalar_kwargs, + aten_op, + variant, + stats, + sample, + untestable, + ) + + if verbose: + print(f" Sample {sample_idx} [{aten_op}]: shapes={input_shapes}") + print(f" Ground truth valid: {len(ground_truth_valid)}") + print(f" DTensor rules: {len(dtensor_rules)}") + + return 1, total_combinations + + def compare_operator( op_name: str, device: str = "cpu", @@ -1388,6 +1815,7 @@ def compare_operator( max_samples: int | None = None, verbose: bool = False, incorrect_only: bool = False, + allow_composite: bool = False, ) -> ComparisonStats: """ Compare DTensor's sharding rules against ground truth for an operator. @@ -1400,7 +1828,10 @@ def compare_operator( max_samples: Maximum number of samples to test per OpInfo verbose: Print detailed output incorrect_only: If True, only test DTensor's claimed rules for correctness. - Skips exhaustive search for missing rules (much faster). + Skips search for missing rules (much faster). + allow_composite: If True, validate each supported aten op individually for + samples that decompose into multiple aten calls. Default (False) + skips samples where the OpInfo doesn't map 1:1 to a single aten op. """ if op_name in SKIP_OPS: return ComparisonStats() @@ -1409,10 +1840,15 @@ def compare_operator( stats = ComparisonStats() - aten_op = _discover_aten_op(opinfos, device, dtype) - if aten_op is None or not _has_dtensor_support(aten_op): - stats.no_dtensor_support = True - return stats + if not allow_composite: + aten_op = _discover_aten_op(opinfos, device, dtype) + if aten_op is None or not _has_dtensor_support(aten_op): + if verbose: + print(f" ATEN_OP_MAP: {op_name} -> {aten_op} [no_support]") + stats.no_dtensor_support = True + return stats + if verbose: + print(f" ATEN_OP_MAP: {op_name} -> {aten_op} [supported]") total_samples = 0 total_combinations = 0 @@ -1420,7 +1856,7 @@ def compare_operator( for opinfo in opinfos: variant = opinfo.variant_test_name - if variant: + if variant and verbose: print(f"\n OpInfo variant: {variant}") op = opinfo.op @@ -1428,205 +1864,99 @@ def compare_operator( try: samples = list(opinfo.sample_inputs(device, dtype)) except Exception as e: - print(f" Error generating samples: {e}") + if verbose: + print(f" Error generating samples: {e}") continue if max_samples: samples = samples[:max_samples] for sample_idx, sample in enumerate(samples): - tensors = extract_tensors_from_sample(sample) - - if len(tensors) == 0: + # Check that SampleInput has tensor inputs and no zero-sized tensors + sample_tensors = extract_tensors_from_sample(sample) + if len(sample_tensors) == 0: skip_reasons["no tensor inputs"] += 1 continue - - if any(0 in t.shape for _, t in tensors): + if any(0 in t.shape for _, t in sample_tensors): skip_reasons["zero-sized tensor"] += 1 continue - total_samples += 1 - - try: - ground_truth_raw = _run_op_on_sample(op, sample) - if isinstance(ground_truth_raw, (list, tuple)): - if not all(isinstance(t, torch.Tensor) for t in ground_truth_raw): - total_samples -= 1 - skip_reasons["non-tensor output"] += 1 - continue - ground_truth = list(ground_truth_raw) - elif isinstance(ground_truth_raw, torch.Tensor): - ground_truth = ground_truth_raw - else: - total_samples -= 1 - skip_reasons["non-tensor output"] += 1 - continue - - # For skip checks, use the first tensor (or the only tensor) - first_gt = ( - ground_truth[0] if isinstance(ground_truth, list) else ground_truth - ) - if first_gt.numel() == 0: - total_samples -= 1 - skip_reasons["zero-element output"] += 1 - continue - if (first_gt == 0).all(): - total_samples -= 1 - skip_reasons["all-zero output"] += 1 - continue - if first_gt.isnan().all(): - total_samples -= 1 - skip_reasons["all-NaN output"] += 1 - continue - except Exception: - skip_reasons["op raised exception"] += 1 + # Capture all aten ops dispatched for this sample + capture = get_aten_op_for_sample(op, sample, opinfo.name) + if capture.best_match is None: + skip_reasons["no aten op captured"] += 1 continue - input_shapes = tuple(t.shape for _, t in tensors) - gt_list = ground_truth if isinstance(ground_truth, list) else [ground_truth] - output_shapes = tuple(tuple(gt.shape) for gt in gt_list) - n_outputs = len(gt_list) - - scalar_args = tuple( - a for a in sample.args if not isinstance(a, torch.Tensor) - ) - scalar_kwargs = { - k: v - for k, v in sample.kwargs.items() - if not isinstance(v, torch.Tensor) - } - - mitigations = _prepare_false_positive_mitigations(op, sample, tensors) - - input_placement_options = [ - get_1d_input_placements_for_tensor(t, include_partial=True) - for _, t in tensors + # Count supported aten ops in the capture + supported_ops = [ + (func, args, kwargs, result) + for func, args, kwargs, result in capture.all_ops + if _has_dtensor_support(func) ] - # Use first output for enumerating placement options (DTensor applies - # the same placement to all outputs of multi-output ops) - output_placement_options = get_1d_output_placements_for_tensor(first_gt) - - aten_op, captured_args, captured_kwargs = get_aten_op_for_sample( - op, sample, opinfo.name - ) + num_supported = len(supported_ops) - dtensor_rules = _query_dtensor_rules( - aten_op, - tensors, - captured_args, - captured_kwargs, - input_shapes, - output_shapes, - world_size, - verbose, - ) + if allow_composite: + # Validate each supported aten op individually + if num_supported == 0: + skip_reasons["no supported aten ops"] += 1 + continue - ground_truth_valid: set[ComboKey] = set() - - tensor_device = tensors[0][1].device.type if tensors else "cpu" - with LocalTensorMode(frozenset(range(world_size))): - mesh = init_device_mesh(tensor_device, (world_size,)) - - if incorrect_only: - combinations_to_test = [] - for combo_key in dtensor_rules: - input_plc_strs, output_plc_strs = combo_key - input_plcs_list: list[Placement] = [] - all_valid = True - for s in input_plc_strs: - p = parse_placement(s) - if p is None: - all_valid = False - break - input_plcs_list.append(p) - output_plcs_list: list[Placement] = [] - for s in output_plc_strs: - p = parse_placement(s) - if p is None: - all_valid = False - break - output_plcs_list.append(p) - if not all_valid: - continue - combinations_to_test.append( - ( - tuple(input_plcs_list), - tuple(output_plcs_list), - combo_key, - ) - ) - else: - combinations_to_test = [] - for input_placements in itertools.product(*input_placement_options): - if is_fully_replicated(input_placements): - continue - for output_placement in output_placement_options: - # Apply same placement to all outputs (matches - # DTensor propagator behavior for multi-output ops) - output_placements = tuple( - output_placement for _ in range(n_outputs) - ) - combo_key = ( - tuple(str(p) for p in input_placements), - tuple(str(p) for p in output_placements), - ) - combinations_to_test.append( - (input_placements, output_placements, combo_key) - ) - - for ( - input_placements, - output_placements, - combo_key, - ) in combinations_to_test: - total_combinations += 1 - is_valid = _validate_with_mitigations( - op, - sample, - tensors, - input_placements, - output_placements, - ground_truth, + for func, args, kwargs, result in supported_ops: + gt = _check_ground_truth(result) + if gt is None: + skip_reasons["non-tensor/degenerate aten output"] += 1 + continue + n_samples, n_combos = _validate_aten_op_for_sample( + func, + args, + kwargs, + gt, world_size, - mesh, - mitigations, + incorrect_only, + verbose, + sample_idx, + variant, + stats, + sample, ) + total_samples += n_samples + total_combinations += n_combos + else: + # Default: only validate samples with a single supported aten op + if num_supported != 1: + skip_reasons["non-1:1 aten mapping"] += 1 + continue - if is_valid: - normalized_key = normalize_combo_key( - combo_key, input_shapes, output_shapes - ) - if not is_fully_replicated( - tuple( - parse_placement(p) or Replicate() - for p in normalized_key[0] - ) - ): - ground_truth_valid.add(normalized_key) - - _compare_rules( - ground_truth_valid, - dtensor_rules, - input_shapes, - output_shapes, - sample_idx, - scalar_args, - scalar_kwargs, - aten_op, - variant, - stats, - sample, - ) + func, args, kwargs, result = supported_ops[0] + gt = _check_ground_truth(result) + if gt is None: + skip_reasons["non-tensor/degenerate aten output"] += 1 + continue - if verbose: - print(f" Sample {sample_idx}: shapes={input_shapes}") - print(f" Ground truth valid: {len(ground_truth_valid)}") - print(f" DTensor rules: {len(dtensor_rules)}") + n_samples, n_combos = _validate_aten_op_for_sample( + func, + args, + kwargs, + gt, + world_size, + incorrect_only, + verbose, + sample_idx, + variant, + stats, + sample, + ) + total_samples += n_samples + total_combinations += n_combos stats.total_samples = total_samples stats.total_combinations = total_combinations stats.skip_reasons = dict(skip_reasons) + # In allow_composite mode, check DTensor support after processing + if allow_composite and total_samples == 0 and not skip_reasons: + stats.no_dtensor_support = True + return stats @@ -1772,6 +2102,12 @@ def _print_ops(label: str, ops: list) -> None: action="store_true", help="Only test DTensor's claimed rules (faster, skips missing detection)", ) + parser.add_argument( + "--allow-composite", + action="store_true", + help="Validate each supported aten op individually for decomposed ops " + "(default skips non-1:1 aten mappings)", + ) parser.add_argument("--device", default="cuda", help="Device to use") parser.add_argument("--dtype", default="float32", help="Dtype to use") parser.add_argument( @@ -1830,6 +2166,10 @@ def _print_ops(label: str, ops: list) -> None: # Preamble display_names = [_format_op_name(n) for n in op_names] print(f"Testing ops: {', '.join(display_names)}") + if args.allow_composite: + print( + "Mode: allow-composite (validates each aten op in decomposed samples)" + ) if args.incorrect_only: print("Mode: incorrect-only (fast)") print(f"Device: {args.device}, Dtype: {dtype}, World size: {args.world_size}") @@ -1854,7 +2194,9 @@ def _print_ops(label: str, ops: list) -> None: dtype, args.world_size, args.max_samples, + verbose=True, incorrect_only=args.incorrect_only, + allow_composite=args.allow_composite, ) elapsed = time.time() - op_start diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index 6bb6309feb1a6..1381a9abba6f8 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -4,7 +4,7 @@ import itertools import operator from collections.abc import Callable, Iterable, Sequence -from typing import cast, TypeAlias, TypeVar +from typing import TypeAlias, TypeVar import torch from torch._prims_common import DimsSequenceType, DimsType @@ -22,6 +22,7 @@ ) from torch.distributed.tensor.device_mesh import DeviceMesh from torch.distributed.tensor.placement_types import ( + _is_shard_like, _StridedShard, Partial, Placement, @@ -205,7 +206,7 @@ def is_tensor_shardable( # number of shards in each tensor dimension num_shards = [1] * len(shape) for i, placement in enumerate(spec.placements): - if isinstance(placement, Shard | _StridedShard): + if _is_shard_like(placement): shard_dim = placement.dim if shard_dim >= len(shape): return False @@ -229,7 +230,7 @@ def is_tensor_evenly_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: # number of shards in each tensor dimension num_shards = [1] * len(shape) for i, placement in enumerate(spec.placements): - if isinstance(placement, Shard | _StridedShard): + if _is_shard_like(placement): shard_dim = placement.dim if shard_dim >= len(shape): return False @@ -254,17 +255,24 @@ def is_tensor_evenly_shardable_on_dim( num_shards = 1 for i, placement in enumerate(spec.placements): - if placement.is_shard(): - shard_dim = cast(Shard, placement).dim - if shard_dim == dim: - num_shards *= spec.mesh.size(i) + if _is_shard_like(placement) and placement.dim == dim: + num_shards *= spec.mesh.size(i) + if isinstance(placement, _StridedShard): + # _StridedShard._split_tensor first chunks into split_factor + # groups, then into num_shards within each group, so the dim + # must be divisible by the product of both. This is stricter + # than the final num_shards check and implies it. Note: + # num_shards already includes spec.mesh.size(i) from this + # iteration, so the check covers the full shard count. + if shape[dim] % (placement.split_factor * num_shards) != 0: + return False return shape[dim] % num_shards == 0 def is_tensor_dim_sharded(spec: DTensorSpec, dim: int) -> bool: """Return True if tensor dim is sharded.""" - return any(p.is_shard(dim) for p in spec.placements) + return any(_is_shard_like(p) and p.dim == dim for p in spec.placements) def is_tensor_partial(spec: DTensorSpec) -> bool: @@ -309,7 +317,7 @@ def map_placements_after_broadcast( elif isinstance(placement, Replicate): new_placements.append(placement) else: - if not isinstance(placement, Shard | _StridedShard): + if not _is_shard_like(placement): raise AssertionError shard_dim = normalize_dim(placement.dim, len(shape)) new_shard_dim = broadcast_dims_map[shard_dim] @@ -367,6 +375,7 @@ def expand_to_full_mesh_op_strategy( [list[DTensorSpec], DTensorSpec | tuple[DTensorSpec | None, ...]], bool ] | None = None, + different_mesh_args: list[int] | None = None, ) -> OpStrategy: """ Convenience function to allow writing a sharding strategy considering only a single mesh dimension, @@ -403,6 +412,18 @@ def expand_to_full_mesh_op_strategy( args_strategy = op_schema.args_strategy kwargs_strategy = op_schema.kwargs_strategy input_args_strategy = args_strategy + kwargs_strategy + + # Propagate use_strided_shard_as_shard_order from inputs so that + # strategy specs with _StridedShard get the correct flag (and thus + # correct shard_order) at construction time, avoiding shard_order + # mismatches in redistribute_cost computation. + _input_use_strided: bool | None = None + for input_strat in input_args_strategy: + input_spec = input_strat.strategies[0].output_spec + if any(isinstance(p, _StridedShard) for p in input_spec.placements): + _input_use_strided = input_spec.use_strided_shard_as_shard_order + break + all_strategies = [] # Track input placements if we skip strategies due to inplace placement mismatch blocking_inplace_input_placements: tuple[Placement, ...] | None = None @@ -441,7 +462,20 @@ def expand_to_full_mesh_op_strategy( input_strategy_counter += 1 # pyrefly: ignore [bad-argument-type] - spec_list.append(DTensorSpec(mesh, specs, tensor_meta=tensor_meta)) + use_strided = ( + _input_use_strided + if _input_use_strided is not None + and any(isinstance(p, _StridedShard) for p in specs) + else None + ) + spec_list.append( + DTensorSpec( + mesh, + specs, + tensor_meta=tensor_meta, + use_strided_shard_as_shard_order=use_strided, + ) + ) else: spec_list.append(None) @@ -476,6 +510,52 @@ def expand_to_full_mesh_op_strategy( f"input_specs({len(input_specs)}) != strategies({len(input_args_strategy)}: " f"{len(args_strategy)} args + {len(kwargs_strategy)} kwargs)" ) + + # Note [Multi-mesh args] + # + # Some ops accept args whose DTensor lives on a different DeviceMesh + # than the op's primary compute mesh. We call these "multi-mesh + # args". They arise in fused optimizer ops (e.g. _fused_adam_) + # where *state_steps* is a per-rank scalar counter allocated on a + # smaller sub-mesh (e.g. 1-D DP) while params and grads live on a + # larger mesh (e.g. 2-D DP × TP). + # + # Why must these args be Replicate? + # Sharding implies a specific partitioning of a tensor's data + # across the ranks of a mesh. If a tensor doesn't even *exist* + # on the compute mesh, there is no meaningful way to interpret a + # Shard placement for it. Replicate, on the other hand, is + # mesh-agnostic: every rank already holds the full data, so the + # op can simply read the value regardless of which mesh owns it. + # + # What we do here: + # We preserve the original mesh and Replicate placement for these + # args so the propagator does not try to redistribute them onto + # the compute mesh (which would fail or produce wrong results). + # + # This is distinct from the *element_mesh* handling in + # single_dim_strategy.py, which deals with foreach ops where + # different *elements* in a tensor list may live on different + # sub-meshes (e.g. param group A on 2-D mesh, param group B on + # 1-D mesh). + # TODO: refactor fused_ops handling so that there are no longer + # args on different meshes + if different_mesh_args is not None: + for idx in different_mesh_args: + if idx < len(input_args_strategy): + cross_mesh_input = input_args_strategy[idx] + original_spec = cross_mesh_input.strategies[0].output_spec + if original_spec.mesh != mesh: + if not all(p == Replicate() for p in original_spec.placements): + raise RuntimeError( + f"Cross-mesh input at index {idx} must be Replicate, " + f"but got {original_spec.placements}" + ) + input_specs[idx] = DTensorSpec( + mesh=original_spec.mesh, + placements=original_spec.placements, + tensor_meta=original_spec.tensor_meta, + ) self_spec = input_args_strategy[0].strategies[0].output_spec redistribute_input = self_spec.placements != input_specs[0].placements @@ -564,7 +644,11 @@ def shift_shard_dims_after_insert( ) -> Sequence[Placement]: normalized_placements: list[Placement] = [] for placement in placements: - if isinstance(placement, Shard) and placement.dim >= insert_dim: + if isinstance(placement, _StridedShard) and placement.dim >= insert_dim: + normalized_placements.append( + _StridedShard(placement.dim + 1, split_factor=placement.split_factor) + ) + elif isinstance(placement, Shard) and placement.dim >= insert_dim: normalized_placements.append(Shard(placement.dim + 1)) else: normalized_placements.append(placement) @@ -576,7 +660,11 @@ def shift_shard_dims_after_remove( ) -> Sequence[Placement]: normalized_placements: list[Placement] = [] for placement in placements: - if isinstance(placement, Shard) and placement.dim > remove_dim: + if isinstance(placement, _StridedShard) and placement.dim > remove_dim: + normalized_placements.append( + _StridedShard(placement.dim - 1, split_factor=placement.split_factor) + ) + elif isinstance(placement, Shard) and placement.dim > remove_dim: normalized_placements.append(Shard(placement.dim - 1)) else: normalized_placements.append(placement) diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index 4d016e299197d..ebbdd7a8d97ad 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -11,6 +11,7 @@ from torch.distributed.device_mesh import _get_device_handle, DeviceMesh from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor.placement_types import _StridedShard, Shard +from torch.types import IntLikeType logger = getLogger(__name__) @@ -372,9 +373,7 @@ def _compute_rng_offsets(self, spec: DTensorSpec) -> tuple[int, int]: from torch.distributed.tensor._ops.utils import prod mesh = spec.mesh - mesh_coordinate = mesh.get_coordinate() - if mesh_coordinate is None: - raise AssertionError + mesh_coordinate = [mesh._sym_get_coordinate(i) for i in range(mesh.ndim)] shard_idx_by_dim, total_num_shards_by_dim = _calc_shard_info( mesh_coordinate, spec @@ -391,8 +390,8 @@ def _compute_rng_offsets(self, spec: DTensorSpec) -> tuple[int, int]: return start_offset_incr, end_offset_incr def _calc_shard_linear_idx( - self, shard_coord: list[int], shard_size: list[int] - ) -> int: + self, shard_coord: Sequence[IntLikeType], shard_size: Sequence[IntLikeType] + ) -> IntLikeType: return _calc_shard_linear_idx(shard_coord, shard_size) @@ -411,8 +410,8 @@ def _calc_first_shard_size(spec: DTensorSpec) -> list[int]: def _calc_shard_info( - mesh_coordinate: Sequence[int], spec: DTensorSpec -) -> tuple[list[int], list[int]]: + mesh_coordinate: Sequence[IntLikeType], spec: DTensorSpec +) -> tuple[list[IntLikeType], list[IntLikeType]]: mesh = spec.mesh # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP # case. Replace the custom logic with dim_map once we support it. @@ -432,14 +431,14 @@ def _calc_shard_info( # The coordinate on each tensor dim is a tuple (idx, range) # If a DTensor is partitioned on its dim i into n shards, and the current rank # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i - if mesh_coordinate is None: - raise AssertionError mesh_size = mesh.shape shard_idx_by_dim = [] - total_num_shards_by_dim = [] # total number of shards on each tensor dim + total_num_shards_by_dim: list[ + IntLikeType + ] = [] # total number of shards on each tensor dim for mesh_dim in dim_map: - shard_idx = 0 - total_num_shards = 1 + shard_idx: IntLikeType = 0 + total_num_shards: IntLikeType = 1 # the tensor dim is sharded on more than 1 mesh dim if isinstance(mesh_dim, list): rank_coord = [mesh_coordinate[d] for d in mesh_dim] @@ -454,10 +453,12 @@ def _calc_shard_info( return shard_idx_by_dim, total_num_shards_by_dim -def _calc_shard_linear_idx(shard_coord: list[int], shard_size: list[int]) -> int: +def _calc_shard_linear_idx( + shard_coord: Sequence[IntLikeType], shard_size: Sequence[IntLikeType] +) -> IntLikeType: # compute shard linear index - shard_linear_idx = 0 - shard_coord_stride = 1 + shard_linear_idx: IntLikeType = 0 + shard_coord_stride: IntLikeType = 1 for idx, size in zip(reversed(shard_coord), reversed(shard_size)): shard_linear_idx += idx * shard_coord_stride shard_coord_stride *= size diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index 1c7cb2a485c78..3608e3edcbc2c 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -17,6 +17,7 @@ from torch.distributed._functional_collectives import _are_we_tracing from torch.distributed.tensor._collective_utils import one_step_redistribute_cost from torch.distributed.tensor._dtensor_spec import ( + _StridedShardNotDecodableError, DTensorSpec, ShardOrder, ShardOrderEntry, @@ -25,12 +26,14 @@ from torch.distributed.tensor._utils import assert_no_mixed_partial_types from torch.distributed.tensor.device_mesh import DeviceMesh from torch.distributed.tensor.placement_types import ( + _is_shard_like, _StridedShard, Partial, Placement, Replicate, Shard, ) +from torch.types import IntLikeType from torch.utils._debug_mode import get_active_debug_mode @@ -144,7 +147,7 @@ class _TransformInfo: mesh_dim: int src_dst_placements: tuple[Placement, Placement] # logical_shape on this mesh dimension - logical_shape: list[int] + logical_shape: Sequence[IntLikeType] def __post_init__(self): if self.mesh_dim < 0: @@ -164,11 +167,11 @@ def _comm_type_key(self) -> str | None: src, dst = self.src_dst_placements if src.is_partial() and dst.is_replicate(): return "all_reduce" - elif src.is_partial() and dst.is_shard(): + elif src.is_partial() and _is_shard_like(dst): return "reduce_scatter" - elif src.is_shard() and dst.is_replicate(): + elif _is_shard_like(src) and dst.is_replicate(): return "all_gather" - elif src.is_shard() and dst.is_shard(): + elif _is_shard_like(src) and _is_shard_like(dst): return "all_to_all" else: # Local ops (Replicate->Shard, Replicate->Partial, noop, etc.) @@ -251,18 +254,14 @@ def _update_shard_order_and_placements( current_placements[mesh_dim] = dst_placement -def _get_flattened_mesh_by_layout( +def _get_flattened_mesh_by_layout_impl( mesh: DeviceMesh, mesh_dims: tuple[int, ...] ) -> DeviceMesh | None: """ Query for an explicitly created flattened mesh using layout comparison. - Args: - mesh: The DeviceMesh to query - mesh_dims: Tuple of mesh dimension indices to look for - - Returns: - The flattened DeviceMesh if it was explicitly created, None otherwise. + Searches root_mesh._flatten_mapping for a mesh whose layout matches + the expected flattened layout for the given dims. Pure Python layout math. """ root_mesh = mesh._get_root_mesh() mesh_dim_names = mesh.mesh_dim_names @@ -288,6 +287,29 @@ def _get_flattened_mesh_by_layout( return None +def _get_flattened_mesh_by_layout( + mesh: DeviceMesh, mesh_dims: tuple[int, ...] +) -> DeviceMesh | None: + """ + Query for an explicitly created flattened mesh using layout comparison. + + When tracing with compile_on_one_rank, delegates to a custom op so the + flattened mesh appears as a call_function node derived from mesh (a graph + input) rather than as a get_attr constant holding an unpicklable + ProcessGroup. + """ + if _are_we_tracing() and torch.distributed.config.compile_on_one_rank: + # Pre-check: the custom op can't return None (torch.library doesn't + # support Optional opaque return types), so guard here first. + if _get_flattened_mesh_by_layout_impl(mesh, mesh_dims) is None: + return None + from torch.distributed._ops import device_mesh as _ # noqa: F401 + + return torch.ops.device_mesh._get_flattened_submesh(mesh, list(mesh_dims)) + + return _get_flattened_mesh_by_layout_impl(mesh, mesh_dims) + + # Track (mesh_hash, mesh_dims, reason) we've already warned about to avoid repeated warnings _warned_flatten_issues: set[tuple[int, tuple[int, ...], str]] = set() @@ -1176,8 +1198,8 @@ def get_logical_shape( src_state: "DTensorRedistributePlanner.DistState", mesh_dim: int, full_tensor_shape: tuple[int, ...], - ) -> list[int]: - new_logical_shape = list(full_tensor_shape) + ) -> list[IntLikeType]: + new_logical_shape: list[IntLikeType] = list(full_tensor_shape) for entry in src_state.tensor_dim_to_mesh_dim: tensor_dim = entry.tensor_dim mesh_dims = entry.mesh_dims @@ -1351,6 +1373,11 @@ def generate_greedy_transform_infos( target = target_placements[mesh_dim] # If target is not Shard, we can directly redistribute since we # are traversing from inner to outer placements here + # TODO: extend nested sharding detection to _StridedShard + # (isinstance check and is_shard() below miss it). + # Safe today: strategies convert _StridedShard to Replicate + # on ALL mesh dims for a given reduction dim, so misaligned + # nested _StridedShard targets can't arise. if isinstance(target, Shard): # If target is Shard, check for nested sharding on the # tensor dim BEFORE the current mesh_dim @@ -1437,9 +1464,22 @@ def _gen_transform_infos_non_cached( src_spec.tensor_meta, ) if use_graph_based_transform: - transform_infos = drp.generate_graph_based_transform_infos( - src_spec, dst_spec, src_spec.shape - ) + # TODO(zpcore): Temporary workaround for the case where _StridedShard + # cannot be decoded into shard order. This happens when + # use_strided_shard_as_shard_order defaults to True (e.g. in + # Redistribute.forward where the target DTensorSpec is constructed from + # raw placements without the flag), but the split_factor doesn't + # correspond to any valid product of mesh dimension sizes (e.g. sf=2 + # on a 1D mesh). A proper fix is to either pass + # use_strided_shard_as_shard_order through the Redistribute API, or + # migrate to explicit shard_order so _StridedShard is no longer + # overloaded for two purposes. + try: + transform_infos = drp.generate_graph_based_transform_infos( + src_spec, dst_spec, src_spec.shape + ) + except _StridedShardNotDecodableError: + transform_infos = drp.generate_greedy_transform_infos(src_spec, dst_spec) else: transform_infos = drp.generate_greedy_transform_infos(src_spec, dst_spec) return transform_infos @@ -1541,6 +1581,15 @@ def redistribute_local_tensor( mesh_to_use = device_mesh i = transform_info.mesh_dim current, target = transform_info.src_dst_placements + + # _StridedShard methods use device_mesh directly, not mesh_to_use. + # This is safe because _StridedShard.is_shard() returns False, so + # _comm_type_key() returns None and flattening is never attempted. + if isinstance(current, _StridedShard) or isinstance(target, _StridedShard): + assert mesh_to_use is device_mesh, ( # noqa: S101 + "_StridedShard redistribute assumes no flattened transforms" + ) + num_chunks = mesh_to_use.size(mesh_dim=i) if current == target: @@ -1611,8 +1660,15 @@ def redistribute_local_tensor( target_placement.dim, ) elif isinstance(current, _StridedShard): - raise NotImplementedError( - "Redistribute from _StridedShard to Shard is not implemented yet" + # _StridedShard -> Shard: go via Replicate as intermediate + replicated = current._to_replicate_tensor( + local_tensor, device_mesh, i, transform_info.logical_shape + ) + new_local_tensor = target_placement._replicate_to_shard( + replicated, + mesh_to_use, + i, + mesh_to_use._sym_get_coordinate(i), ) else: raise ValueError( @@ -1624,7 +1680,7 @@ def redistribute_local_tensor( new_local_tensor = partial_spec._partition_value( local_tensor, mesh_to_use, i ) - elif current.is_shard() or isinstance(current, _StridedShard): + elif _is_shard_like(current): raise RuntimeError( f"redistribute from {current} to {target} not supported yet" ) @@ -1638,8 +1694,13 @@ def redistribute_local_tensor( elif isinstance(target, _StridedShard): # Case 4: target is _StridedShard if current.is_partial(): - raise NotImplementedError( - "Redistribute from Partial to _StridedShard is not implemented yet" + # Partial -> _StridedShard: reduce to Replicate, then strided shard + partial_spec = cast(Partial, current) + replicated = partial_spec._reduce_value( + local_tensor, mesh_to_use, i + ) + new_local_tensor = target._replicate_to_strided_shard( + replicated, device_mesh, i, device_mesh._sym_get_coordinate(i) ) elif current.is_replicate(): # split the tensor and return the corresponding local strided shard @@ -1647,9 +1708,13 @@ def redistribute_local_tensor( local_tensor, device_mesh, i, device_mesh._sym_get_coordinate(i) ) elif current.is_shard(): - # Shard -> _StridedShard on potentially different dimensions - raise NotImplementedError( - "Redistribute from Shard to _StridedShard is not implemented yet" + # Shard -> _StridedShard: all-gather to Replicate, then strided shard + current_placement = cast(Shard, current) + replicated = current_placement._to_replicate_tensor( + local_tensor, mesh_to_use, i, transform_info.logical_shape + ) + new_local_tensor = target._replicate_to_strided_shard( + replicated, device_mesh, i, device_mesh._sym_get_coordinate(i) ) elif isinstance(current, _StridedShard): # _StridedShard -> _StridedShard: go through Replicate @@ -1673,6 +1738,99 @@ def redistribute_local_tensor( return new_local_tensor +def _redistribute_backward( + grad_output: "dtensor.DTensor", + previous_spec: DTensorSpec, + original_dtype: torch.dtype | None = None, + backward_dtype: torch.dtype | None = None, + async_op: bool = False, +): + """ + Common function for redistributing a distributed tensor during backward + and twice-backward backpropagation steps. + + Args: + grad_output: The output gradient tensor. + previous_spec: DTensorSpec prior to redistribution. + original_dtype: Original output tensor dtype from forward pass (for type checking) + backward_dtype: Desired data type for backwards output. + async_op: whether to perform the DTensor redistribute operation + asynchronously or not. Default: False + + Returns: + A :class:`torch.Tensor` object. + A :class:`DTensorSpec` object. + """ + if backward_dtype is not None and backward_dtype != grad_output._local_tensor.dtype: + local_tensor = grad_output._local_tensor.to(dtype=backward_dtype) + current_spec = DTensorSpec( + mesh=grad_output._spec.device_mesh, + placements=grad_output._spec.placements, + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + # pyrefly: ignore [bad-argument-type] + dtype=backward_dtype, + ), + use_strided_shard_as_shard_order=grad_output._spec.use_strided_shard_as_shard_order, + ) + previous_spec = DTensorSpec( + mesh=previous_spec.device_mesh, + placements=previous_spec.placements, + tensor_meta=current_spec.tensor_meta, + use_strided_shard_as_shard_order=previous_spec.use_strided_shard_as_shard_order, + ) + else: + local_tensor = grad_output._local_tensor + current_spec = grad_output._spec + # skip the replicate to partial transformation when we are in backward pass + # In this case we keep the grad as replicate, this is because we don't + # want to convert the replicated gradients back to partial, although + # that's logically conform with the same layout, converting the gradients + # back to partial is actually useless as you would have to do reduce later + # which would be more expensive than keeping it replicate! + + # for backward shard -> partial, we just do shard -> replicate + # for backward replicate -> partial, we skip the transformation + # NOTE: _is_shard_like covers _StridedShard defensively; currently + # unreachable because Partial -> _StridedShard is not implemented. + normalized_placements: list[Placement] = [] + for current, target in zip(current_spec.placements, previous_spec.placements): + if (_is_shard_like(current) or current.is_replicate()) and target.is_partial(): + normalized_placements.append(Replicate()) + else: + normalized_placements.append(target) + + previous_spec = DTensorSpec( + previous_spec.device_mesh, + placements=tuple(normalized_placements), + tensor_meta=previous_spec.tensor_meta, + use_strided_shard_as_shard_order=previous_spec.use_strided_shard_as_shard_order, + ) + + output = redistribute_local_tensor( + local_tensor, + current_spec, + previous_spec, + async_op=async_op, + ) + + if output.dtype != original_dtype: + output = output.to(original_dtype) + + spec = DTensorSpec( + previous_spec.device_mesh, + tuple(normalized_placements), + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=output.dtype, + ), + use_strided_shard_as_shard_order=previous_spec.use_strided_shard_as_shard_order, + ) + return output, spec + + class Redistribute(torch.autograd.Function): @staticmethod def forward( # type: ignore[override] @@ -1699,6 +1857,7 @@ def forward( # type: ignore[override] stride=input.stride(), dtype=forward_dtype, ), + use_strided_shard_as_shard_order=input._spec.use_strided_shard_as_shard_order, ) else: local_tensor = input._local_tensor @@ -1735,72 +1894,60 @@ def forward( # type: ignore[override] @staticmethod def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] previous_spec = ctx.current_spec - async_op = ctx.async_op - backward_dtype = ctx.backward_dtype or ctx.original_dtype + output_dtensor = NestedRedistribute.apply( + grad_output, + previous_spec, + ctx.async_op, + ctx.backward_dtype, + ctx.original_dtype, + ) + return ( + output_dtensor, + None, + None, + None, + None, + None, + ) - if backward_dtype != grad_output._local_tensor.dtype: - local_tensor = grad_output._local_tensor.to(dtype=backward_dtype) - current_spec = DTensorSpec( - mesh=grad_output._spec.device_mesh, - placements=grad_output._spec.placements, - tensor_meta=TensorMeta( - shape=grad_output.shape, - stride=grad_output.stride(), - dtype=backward_dtype, - ), - ) - previous_spec = DTensorSpec( - mesh=previous_spec.device_mesh, - placements=previous_spec.placements, - tensor_meta=current_spec.tensor_meta, - ) - else: - local_tensor = grad_output._local_tensor - current_spec = grad_output._spec - # skip the replicate to partial transformation when we are in backward pass - # In this case we keep the grad as replicate, this is because we don't - # want to convert the replicated gradients back to partial, although - # that's logically conform with the same layout, converting the gradients - # back to partial is actually useless as you would have to do reduce later - # which would be more expensive than keeping it replicate! - - # for backward shard -> partial, we just do shard -> replicate - # for backward replicate -> partial, we skip the transformation - normalized_placements: list[Placement] = [] - for current, target in zip(current_spec.placements, previous_spec.placements): - if (current.is_shard() or current.is_replicate()) and target.is_partial(): - normalized_placements.append(Replicate()) - else: - normalized_placements.append(target) - previous_spec = DTensorSpec( - previous_spec.device_mesh, - placements=tuple(normalized_placements), - tensor_meta=previous_spec.tensor_meta, - ) +class NestedRedistribute(torch.autograd.Function): + """ + This class is used to make the redistribution of a DTensor twice-differentiable. + This is called during the `Redistribute.forward`. + Therefore, `NestedRedistribute.forward` is called during the first backward pass, + and `NestedRedistribute.backward` is called during the second backward pass. - output = redistribute_local_tensor( - local_tensor, - current_spec, + Note: `NestedRedistribute.backward` is not differentiable, and therefore triple + backward is not yet supported. + """ + + @staticmethod + def forward( # type: ignore[override] + # pyre-fixme[2]: Parameter must be annotated. + ctx, + grad_output: "dtensor.DTensor", + previous_spec: DTensorSpec, + async_op: bool = False, + forward_dtype: torch.dtype | None = None, + backward_dtype: torch.dtype | None = None, + ): + ctx.async_op = async_op + ctx.original_dtype = grad_output._local_tensor.dtype + ctx.backward_dtype = backward_dtype or ctx.original_dtype + + output, spec = _redistribute_backward( + grad_output, previous_spec, - async_op=async_op, - is_explicit=True, + ctx.backward_dtype, + backward_dtype, + async_op, ) - if output.dtype != ctx.original_dtype: - output = output.to(ctx.original_dtype) + ctx.current_spec = spec - spec = DTensorSpec( - previous_spec.device_mesh, - tuple(normalized_placements), - tensor_meta=TensorMeta( - shape=grad_output.shape, - stride=grad_output.stride(), - dtype=output.dtype, - ), - ) # pyrefly: ignore [bad-argument-type] - output_dtensor = dtensor.DTensor( + return dtensor.DTensor( # pyrefly: ignore [bad-argument-count] output, spec, @@ -1808,6 +1955,20 @@ def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] requires_grad=grad_output.requires_grad, ) + @staticmethod + def backward(ctx, grad2_output: "dtensor.DTensor"): # type: ignore[override] + previous_spec = ctx.current_spec + async_op = ctx.async_op + backward_dtype = ctx.backward_dtype or ctx.original_dtype + + output_dtensor = NestedRedistribute.apply( + grad2_output, + previous_spec, + async_op, + backward_dtype, + ctx.original_dtype, + ) + return ( output_dtensor, None, diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py index de0b7ba08b717..66584473c952c 100644 --- a/torch/distributed/tensor/_sharding_prop.py +++ b/torch/distributed/tensor/_sharding_prop.py @@ -45,6 +45,56 @@ log = logging.getLogger(__name__) +def _propagate_use_strided_shard_flag( + op_strategy: OpStrategy, + op_schema: OpSchema, +) -> None: + """Propagate use_strided_shard_as_shard_order from input specs to output specs. + + When inputs carry _StridedShard with an explicit flag, all output (and input) + DTensorSpecs in the strategy that also contain _StridedShard must agree. + Strategy functions may forget to propagate the flag; this function fixes + them up centrally after the strategy is produced. + """ + _use_strided: bool | None = None + for spec in op_schema.args_spec: + if any(isinstance(p, _StridedShard) for p in spec.placements): + val = spec.use_strided_shard_as_shard_order + if _use_strided is not None and _use_strided != val: + raise ValueError( + "Conflicting use_strided_shard_as_shard_order across " + f"input specs: got both {_use_strided} and {val}" + ) + _use_strided = val + + if _use_strided is None: + return + + def _fixup(spec: DTensorSpec) -> None: + if not any(isinstance(p, _StridedShard) for p in spec.placements): + return + if spec.use_strided_shard_as_shard_order == _use_strided: + return + spec.use_strided_shard_as_shard_order = _use_strided + if _use_strided: + spec.shard_order = None # pyrefly: ignore[bad-assignment] + else: + spec.shard_order = DTensorSpec.compute_default_shard_order(spec.placements) + + for op_spec in op_strategy.strategies: + out = op_spec.output_specs + if out is not None: + if isinstance(out, DTensorSpec): + _fixup(out) + else: + for s in out: + if s is not None: + _fixup(s) + if op_spec.input_specs is not None: + for s in op_spec.input_specs: + _fixup(s) + + def _length(obj) -> int: if obj is None: return 0 @@ -53,14 +103,15 @@ def _length(obj) -> int: return len(obj) -def _get_expected_num_tensor_outputs(op: OpOverload) -> int: +def _get_expected_num_tensor_outputs(op: OpOverload) -> int | None: """ Get the expected number of tensor outputs for an operator based on its schema. Returns: The number of tensor outputs expected. Returns 0 for ops that don't return tensors (e.g., _linalg_check_errors). Returns 1 for single tensor return, and >1 for - tuple returns where each element is a tensor. + tuple returns where each element is a tensor. Returns None for List[Tensor] + returns where the length is unknown at schema time. """ return_types = op._schema.returns if len(return_types) == 0: @@ -71,8 +122,8 @@ def _get_expected_num_tensor_outputs(op: OpOverload) -> int: # Could be single tensor or tuple of tensors return len(return_types) elif isinstance(first_return.type, torch.ListType): - # List[Tensor] - we don't know the length at schema time, treat as 1 - return 1 + # List[Tensor] - we don't know the length at schema time + return None else: # Not a tensor return type return 0 @@ -101,6 +152,16 @@ def _validate_tensor_meta_count( else: actual_outputs = len(tensor_meta) + if expected_outputs is None: + # List[Tensor] return type: length unknown at schema time, but + # tensor_meta must be a list of TensorMeta. + if not isinstance(tensor_meta, list): + raise AssertionError( + f"Tensor meta for {op_schema.op} should be a list[TensorMeta] " + f"(op returns List[Tensor]), but got {type(tensor_meta).__name__}" + ) + return + if actual_outputs != expected_outputs: raise AssertionError( f"Tensor meta count mismatch for {op_schema.op}: " @@ -346,6 +407,22 @@ def __init__(self) -> None: aten.select_backward.default: 1, aten.slice_backward.default: 1, } + # ops with individual scalar shape args that need local adjustment + # maps op -> callable(input_specs, schema) -> adjusted schema + # populated by op modules (e.g. _math_ops.py) at registration time + self.op_to_scalar_shape_adjuster: dict[ + OpOverload, + Callable[[list[DTensorSpec], OpSchema], OpSchema], + ] = {} + # squeeze ops that need dim arg rewritten to only globally-singleton dims + self.squeeze_op_to_dims_variant: dict[OpOverload, OpOverload] = { + aten.squeeze.default: aten.squeeze.dims, + aten.squeeze.dim: aten.squeeze.dims, + aten.squeeze.dims: aten.squeeze.dims, + aten.squeeze_.default: aten.squeeze_.dims, + aten.squeeze_.dim: aten.squeeze_.dims, + aten.squeeze_.dims: aten.squeeze_.dims, + } def register_sharding_prop_rule( self, @@ -541,21 +618,11 @@ def _create_output_spec_with_new_tensor_meta( if isinstance(spec, DTensorSpec): output_tensor_meta_i = output_tensor_meta[i] if not isinstance(output_tensor_meta_i, TensorMeta): - # NOTE: aten.convolution_backward.default is an exception and it - # needs extra handling because any Tensor in the output tuple - # can be `None` depending on the output_mask parameter. This can - # occur during double backpropagation or when certain gradients - # are not needed (e.g., grad_input when input has requires_grad=False, - # grad_weight/grad_bias when weight/bias have requires_grad=False, - # or grad_bias when bias is None). We explicitly allow the - # corresponding TensorMeta to be `None`. - if ( - op == aten.convolution_backward.default - and i in (0, 1, 2) - and output_tensor_meta_i is None - ): - if not isinstance(output_specs, list): - raise AssertionError + # Some ops (e.g. convolution_backward, native_layer_norm_backward, + # _fused_rms_norm_backward) have an output_mask parameter that + # controls which outputs are computed. When output_mask[i] is + # False, the output at position i is None and has no TensorMeta. + if output_tensor_meta_i is None: new_specs.append(None) continue else: @@ -586,17 +653,23 @@ def _wrap_with_op_strategy(self, op_schema: OpSchema) -> OpSchema: def spec_to_strategy(spec: object) -> object: if isinstance(spec, DTensorSpec): return OpStrategy([OpSpec(spec)]) - elif ( - isinstance(spec, (list, tuple)) - and len(spec) > 0 - and isinstance(spec[0], DTensorSpec) - ): - # tensor list create tuple strategy - tuple_strategy = [spec_to_strategy(s) for s in spec] - tuple_strategy = cast(Sequence[StrategyType], tuple_strategy) - return TupleStrategy( - tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy - ) + elif isinstance(spec, (list, tuple)) and len(spec) > 0: + if all(isinstance(s, DTensorSpec) for s in spec): + # tensor list create tuple strategy + tuple_strategy = [spec_to_strategy(s) for s in spec] + tuple_strategy = cast(Sequence[StrategyType], tuple_strategy) + return TupleStrategy( + tuple(tuple_strategy) + if isinstance(spec, tuple) + else tuple_strategy + ) + elif any(isinstance(s, DTensorSpec) for s in spec): + # mixed list (e.g. [DTensorSpec, None, DTensorSpec]) for + # ops like aten.index.Tensor; keep as list so pytree + # flattening can extract OpStrategy items + return [spec_to_strategy(s) for s in spec] + else: + return spec else: return spec @@ -707,6 +780,7 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin if op_strategy is not None: if isinstance(op_strategy, OpStrategy): + _propagate_use_strided_shard_flag(op_strategy, op_schema) # single Op strategy output_strategy = _select_min_cost_strategy(op_strategy, op_schema) @@ -764,6 +838,28 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin needs_redistribute = True use_val_from_redistribute_schema = True + # adjust individual scalar shape args (e.g. N, C, HxW in group_norm) + if op_schema.op in self.op_to_scalar_shape_adjuster: + if any( + isinstance(p, Shard | _StridedShard) + for spec in expected_input_specs + for p in spec.placements + ): + schema = suggestion_schema or op_schema + adjuster = self.op_to_scalar_shape_adjuster[op_schema.op] + suggestion_schema = adjuster(expected_input_specs, schema) + needs_redistribute = True + use_val_from_redistribute_schema = True + + # rewrite squeeze to use only globally-singleton dims + if op_schema.op in self.squeeze_op_to_dims_variant: + schema = suggestion_schema or op_schema + adjusted = self._adjust_squeeze_to_global_singletons(schema) + if adjusted is not None: + suggestion_schema = adjusted + needs_redistribute = True + use_val_from_redistribute_schema = True + # construct output spec for the op if op_schema.return_type_tuple_tensor_like(): # for ops that return multiple tensors and the output_specs is not @@ -778,6 +874,7 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin mesh=output_specs.mesh, placements=output_specs.placements, tensor_meta=output_specs.tensor_meta, + use_strided_shard_as_shard_order=output_specs.use_strided_shard_as_shard_order, ) for _ in range(len(op_schema.op._schema.returns)) ) @@ -803,9 +900,11 @@ def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputShardin for strategy in op_strategy.children: if not isinstance(strategy, OpStrategy): raise AssertionError + _propagate_use_strided_shard_flag(strategy, op_schema) selected_strategy = _select_min_cost_strategy(strategy) selected_strategies.append(selected_strategy) - out_spec_list.append(selected_strategy.output_spec) + if selected_strategy.output_specs is not None: + out_spec_list.append(selected_strategy.output_spec) needs_redistribute = False suggestion_args: list[object] = [] @@ -951,3 +1050,47 @@ def _adjust_shape_and_stride_args( ) return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema) + + def _adjust_squeeze_to_global_singletons(self, schema: OpSchema) -> OpSchema | None: + """ + Rewrite squeeze ops to squeeze.dims with only globally-singleton dims. + Fixes bug where sharded dims with local size 1 get incorrectly squeezed. + Returns None if no rewrite is needed (already squeeze.dims with correct args). + """ + from torch.fx.experimental.symbolic_shapes import guard_or_false + + input_spec = cast(DTensorSpec, schema.args_schema[0]) + tensor_meta = input_spec.tensor_meta + if tensor_meta is None: + raise RuntimeError("squeeze requires tensor metadata") + global_shape = tensor_meta.shape + ndim = len(global_shape) + + def normalize(d: int) -> int: + return d if d >= 0 else d + ndim + + def is_singleton(d: int) -> bool: + nd = normalize(d) + return 0 <= nd < ndim and guard_or_false(global_shape[nd] == 1) + + # guard_or_false: conservatively keep dims when size is symbolic/unknown + if schema.op in (aten.squeeze.default, aten.squeeze_.default): + target_dims = tuple( + i for i, s in enumerate(global_shape) if guard_or_false(s == 1) + ) + elif schema.op in (aten.squeeze.dim, aten.squeeze_.dim): + dim = normalize(schema.args_schema[1]) # type: ignore[arg-type] + target_dims = (dim,) if is_singleton(dim) else () + else: + dims = cast(Sequence[int], schema.args_schema[1]) + target_dims = tuple( # type: ignore[union-attr] + normalize(d) for d in dims if is_singleton(d) + ) + + dims_variant = self.squeeze_op_to_dims_variant[schema.op] + # Skip rewrite if already targeting the right op with the same dims + if schema.op == dims_variant and len(schema.args_schema) > 1: + existing_dims = schema.args_schema[1] + if existing_dims == target_dims: + return None + return OpSchema(dims_variant, (input_spec, target_dims), {}) diff --git a/torch/distributed/tensor/_utils.py b/torch/distributed/tensor/_utils.py index e326cd827974e..89aede42e53f5 100644 --- a/torch/distributed/tensor/_utils.py +++ b/torch/distributed/tensor/_utils.py @@ -1,7 +1,7 @@ import logging import threading from collections.abc import Callable, Sequence -from typing import Any, cast +from typing import Any import torch import torch.distributed._functional_collectives as funcol @@ -15,6 +15,7 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import OpSchema from torch.distributed.tensor.placement_types import ( + _is_shard_like, _StridedShard, Partial, Placement, @@ -314,14 +315,13 @@ def compute_local_tensor_info( for idx, placement in enumerate(placements): mesh_dim_size = mesh.size(idx) - if placement.is_shard(): - shard_placement = cast(Shard, placement) - if shard_placement.dim < 0: + if _is_shard_like(placement): + if placement.dim < 0: raise AssertionError( "Shard placements should have negative dims normalized in " - f"the user-facing APIs: {shard_placement}" + f"the user-facing APIs: {placement}" ) - shard_dim = shard_placement.dim + shard_dim = placement.dim if shard_dim >= len(local_shape): raise AssertionError( f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)} " @@ -385,6 +385,7 @@ def compute_global_tensor_shape( if isinstance(placements[0], Replicate): return shape + # NOTE: isinstance(_, Shard) does not match _StridedShard; see _is_shard_like(). elif isinstance(placements[0], Shard): @maybe_run_for_local_tensor diff --git a/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py b/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py index fb01a8c20d1da..f5d982d5a6561 100644 --- a/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py +++ b/torch/distributed/tensor/experimental/_context_parallel/_load_balancer.py @@ -148,27 +148,16 @@ def _generate_indices(self, restore: bool = False) -> Tensor: if seq_length % (world_size * 2) != 0: raise AssertionError chunk_size = seq_length // (world_size * 2) - all_indices = [] - for rank in range(world_size): - # Generate indices for first chunk of the cp rank - first_chunk_start = rank * chunk_size - first_chunk_indices = list( - range(first_chunk_start, first_chunk_start + chunk_size) - ) - - # Second chunk: positions from the complementary chunk - second_chunk_idx = world_size * 2 - rank - 1 - second_chunk_start = second_chunk_idx * chunk_size - second_chunk_indices = list( - range(second_chunk_start, second_chunk_start + chunk_size) - ) - # combine the indices for this rank - all_indices.extend(first_chunk_indices + second_chunk_indices) + # Split sequence into 2*world_size chunks, then pair chunk r with + # chunk (2*world_size - 1 - r) for each rank. + indices = torch.arange(seq_length, dtype=torch.int, device=self.device) + chunks = indices.view(world_size * 2, chunk_size) + head_idx = torch.arange(world_size, device=self.device) + tail_idx = 2 * world_size - 1 - head_idx + paired = torch.stack([chunks[head_idx], chunks[tail_idx]], dim=1) + all_indices_tensor = paired.reshape(-1) - all_indices_tensor = torch.tensor( - all_indices, dtype=torch.int, device=self.device - ) if restore: all_indices_tensor = torch.argsort(all_indices_tensor) diff --git a/torch/distributed/tensor/experimental/_tp_transform.py b/torch/distributed/tensor/experimental/_tp_transform.py index e78222845339c..adde891c4b68c 100644 --- a/torch/distributed/tensor/experimental/_tp_transform.py +++ b/torch/distributed/tensor/experimental/_tp_transform.py @@ -430,6 +430,7 @@ def _partition_val(val: Any, spec: DTensorSpec) -> Any: return local_shard for idx, placement in enumerate(spec.placements): + # NOTE: is_shard() does not match _StridedShard; see _is_shard_like(). if placement.is_shard(): placement = cast(Shard, placement) num_chunks = spec.mesh.size(mesh_dim=idx) diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index 99e7113afefb6..5f6facf74542a 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -37,6 +37,7 @@ def _get_box(tensor: DTensor) -> tuple[torch.Size, torch.Size]: offsets = [0] * len(tensor.size()) num_chunks = device_mesh.size(mesh_dim=0) + # NOTE: is_shard() does not match _StridedShard; see _is_shard_like(). if tensor.placements[0].is_shard(): shard_dim = cast(DShard, placement).dim chunk_size = tensor.size(shard_dim) // num_chunks @@ -81,6 +82,7 @@ def _create_sharded_tensor_md_from_dt( my_rank = dist.get_rank(dt_pg) scapegoat_rank = 0 if my_rank > 0 else 1 + # NOTE: is_shard() does not match _StridedShard; see _is_shard_like(). if dt.placements[0].is_shard(): shard_count = dt_pg.size() else: diff --git a/torch/distributed/tensor/parallel/input_reshard.py b/torch/distributed/tensor/parallel/input_reshard.py index 81e25621e040a..b3c47ddae8460 100644 --- a/torch/distributed/tensor/parallel/input_reshard.py +++ b/torch/distributed/tensor/parallel/input_reshard.py @@ -4,6 +4,7 @@ import torch from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard +from torch.distributed.tensor.placement_types import _is_shard_like __all__ = [ @@ -64,7 +65,7 @@ def input_reshard_backward_hook( return module -def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any: # noqa: D401 +def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any: """Hook function called after FWD to shard input.""" if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements): return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)]) @@ -82,12 +83,12 @@ def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> return x -def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor: # noqa: D401 +def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor: """Hook function called before activation recomputing in BWD to restore input.""" if ( isinstance(x, DTensor) and len(x._spec.placements) == 1 - and x._spec.placements[0].is_shard() + and _is_shard_like(x._spec.placements[0]) ): return x.redistribute(device_mesh=mesh, placements=[Replicate()]) elif ( diff --git a/torch/distributed/tensor/placement_types.py b/torch/distributed/tensor/placement_types.py index 1e717569986e9..6009ff046b7eb 100644 --- a/torch/distributed/tensor/placement_types.py +++ b/torch/distributed/tensor/placement_types.py @@ -2,8 +2,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates import functools +from collections.abc import Sequence from dataclasses import dataclass, field -from typing import cast, TypeVar +from typing import cast, TypeGuard, TypeVar import torch import torch._C @@ -21,6 +22,7 @@ unpad_tensor, ) from torch.distributed.tensor._ops._mask_buffer import MaskBuffer +from torch.types import IntLikeType __all__ = ["Placement", "Shard", "Replicate", "Partial"] @@ -60,6 +62,10 @@ class Shard(torch._C._distributed.Shard): .. warning:: sharding on a tensor dimension where the tensor dimension size is not evenly divisible on a DeviceMesh dimension is currently experimental and subject to change. + + .. note:: When checking whether a placement is shard-like, use + :func:`_is_shard_like` instead of ``isinstance(p, Shard)`` to also + match :class:`_StridedShard`. """ def _split_tensor( @@ -211,7 +217,7 @@ def _custom_chunk( @staticmethod @maybe_run_for_local_tensor def local_shard_size_and_offset( - curr_local_size: int, + curr_local_size: IntLikeType, num_chunks: int, rank: _RankTypeT, ) -> tuple[_RankTypeT, _RankTypeT]: @@ -253,7 +259,7 @@ def _local_shard_size_and_offset( num_chunks: int, rank: RankType, ) -> tuple[int, RankType]: - # pyrefly: ignore[bad-argument-type] # pyrefly bug + # pyrefly: ignore [bad-argument-type, bad-return] return Shard.local_shard_size_and_offset(curr_local_size, num_chunks, rank) @staticmethod @@ -392,7 +398,7 @@ def _reduce_shard_tensor( def _maybe_pad_tensor( self, local_tensor: torch.Tensor, - logical_dim_size: int, + logical_dim_size: IntLikeType, num_chunks: int, ) -> torch.Tensor: from torch.fx.experimental.symbolic_shapes import guard_or_true @@ -414,10 +420,10 @@ def _maybe_pad_tensor( def _maybe_unpad_tensor( self, local_tensor: torch.Tensor, - logical_dim_size: int, + logical_dim_size: IntLikeType, num_chunks: int, ) -> torch.Tensor: - from torch.fx.experimental.symbolic_shapes import guard_or_true + from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true # Assume padding (uneven sharding) as general case for unbacked sizes. is_padded = guard_or_true(logical_dim_size % num_chunks != 0) @@ -427,6 +433,22 @@ def _maybe_unpad_tensor( unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined] local_tensor = unpad_tensor(local_tensor, self.dim, unpad_size) + # Bind derived symbolic sizes (e.g. 2*(s//2)) back to the original + # symbol - needed for correct shape propagation and dynamo generation + if local_tensor.size(self.dim) is not logical_dim_size: + orig_size = local_tensor.size(self.dim) + torch._check(orig_size >= logical_dim_size) + local_tensor = local_tensor.narrow(self.dim, 0, logical_dim_size) + + # Safety check: the narrow should never change the concrete size. + # Use guard_or_false so we don't trigger data-dependent guards + # on unbacked symints. + if guard_or_false(local_tensor.size(self.dim) != orig_size): + raise RuntimeError( + f"narrow unexpectedly changed concrete size on dim {self.dim}: " + f"{orig_size} -> {local_tensor.size(self.dim)}" + ) + return local_tensor def _to_replicate_tensor( @@ -434,7 +456,7 @@ def _to_replicate_tensor( local_tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - current_logical_shape: list[int], + current_logical_shape: Sequence[IntLikeType], ) -> torch.Tensor: """ This function all_gather all shards and return a tensor that @@ -462,7 +484,7 @@ def _replicate_to_shard( local_tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - shard_index: int, + shard_index: IntLikeType, ) -> torch.Tensor: """ transform from replicated tensor to a sharded tensor on @@ -489,11 +511,11 @@ def _get_shard_pad_size( @staticmethod def _compute_padding_info( - current_logical_shape: list[int], + current_logical_shape: Sequence[IntLikeType], num_chunks: int, old_shard_dim: int, new_shard_dim: int, - ) -> tuple[bool, int, int, bool, int, int]: + ) -> tuple[bool, IntLikeType, int, bool, IntLikeType, int]: from torch.fx.experimental.symbolic_shapes import guard_or_true results = [] @@ -508,7 +530,7 @@ def _compute_padding_info( @staticmethod @maybe_run_for_local_tensor def _pad_for_new_shard_dim( - current_logical_shape: list[int], + current_logical_shape: Sequence[IntLikeType], local_tensor: torch.Tensor, num_chunks: int, old_shard_dim: int, @@ -543,7 +565,7 @@ def _pad_for_new_shard_dim( @staticmethod @maybe_run_for_local_tensor def _unpad_for_new_shard_dim( - current_logical_shape: list[int], + current_logical_shape: Sequence[IntLikeType], local_tensor: torch.Tensor, num_chunks: int, old_shard_dim: int, @@ -582,7 +604,7 @@ def _to_new_shard_dim( local_tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - current_logical_shape: list[int], + current_logical_shape: Sequence[IntLikeType], new_shard_dim: int, ) -> torch.Tensor: """ @@ -724,7 +746,7 @@ def __fx_repr__(self): Needed for passing this type as an opaque object input to a custom op. """ return ( - f"torch.distributed.tensor.placement_types._StridedShard(dim={self.dim}, sf={self.split_factor})", # noqa: B950 + f"torch.distributed.tensor.placement_types._StridedShard(dim={self.dim}, sf={self.split_factor})", {}, ) @@ -857,7 +879,7 @@ def _select_split_tensor( self, tensor: torch.Tensor, num_chunks: int, - index: int, + index: IntLikeType, *, with_padding: bool = True, contiguous: bool = True, @@ -891,7 +913,7 @@ def _to_replicate_tensor( local_tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - current_logical_shape: list[int], + current_logical_shape: Sequence[IntLikeType], ) -> torch.Tensor: """ Replay the replicate-to-shard process to understand how to stitch shards back. @@ -986,6 +1008,7 @@ def _to_replicate_tensor( # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed # so that we can reuse self._split_tensor which splits on self.dim shape = [1] * self.dim + [logical_dim_size] + # pyrefly: ignore [no-matching-overload] indices_tensor = torch.arange( logical_dim_size, device=local_tensor.device ).view(shape) @@ -1050,7 +1073,7 @@ def _replicate_to_strided_shard( local_tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int, - shard_index: int, + shard_index: IntLikeType, ) -> torch.Tensor: """ Transform from replicated tensor to a strided-sharded tensor on the current rank. @@ -1097,7 +1120,7 @@ def _local_shard_size_and_offset( @maybe_run_for_local_tensor def local_shard_size_and_offset( self, - curr_local_size: int, + curr_local_size: IntLikeType, num_chunks: int, rank: RankType, return_first_offset: bool = True, @@ -1127,6 +1150,7 @@ def local_shard_size_and_offset( # indices_tensor is 1D torch.arange(logical_dim_size) unsqueezed # so that we can reuse self._split_tensor which splits on self.dim shape = [1] * self.dim + [curr_local_size] + # pyrefly: ignore [no-matching-overload] indices_tensor = torch.arange( curr_local_size, ).view(shape) @@ -1154,6 +1178,17 @@ def local_shard_size_and_offset( return local_shard_size, offsets +def _is_shard_like(p: "Placement") -> TypeGuard[Shard | _StridedShard]: + """Check if a placement is Shard or _StridedShard. + + Use this instead of ``isinstance(p, Shard)`` to avoid silently missing + ``_StridedShard``. When ``_StridedShard`` is unified with ``Shard`` + (see TODO on the class), this helper can be collapsed to a single + ``isinstance`` check. + """ + return isinstance(p, Shard | _StridedShard) + + class Replicate(torch._C._distributed.Replicate): """ The ``Replicate()`` placement describes the DTensor replicating on a corresponding @@ -1384,7 +1419,9 @@ def __init__( @staticmethod @maybe_run_for_local_tensor def _mask_tensor( - tensor: torch.Tensor, local_offset_on_dim: int, local_shard_size: int + tensor: torch.Tensor, + local_offset_on_dim: IntLikeType, + local_shard_size: IntLikeType, ) -> tuple[torch.Tensor, torch.Tensor]: # Build the input mask and save it for the current partial placement # this is so that the output of embedding op can reuse the same partial @@ -1393,8 +1430,10 @@ def _mask_tensor( tensor >= local_offset_on_dim + local_shard_size ) # mask the input tensor + # pyrefly: ignore [unsupported-operation] masked_tensor = tensor.clone() - local_offset_on_dim masked_tensor[mask] = 0 + # pyrefly: ignore [bad-return] return mask, masked_tensor def _partition_value( @@ -1497,7 +1536,7 @@ def __fx_repr__(self): Needed for passing this type as an input to a custom op. """ return ( - f"torch.distributed.tensor.placement_types.MaskPartial(reduce_op={self.reduce_op}, offset_shape={self.offset_shape}, offset_dim={self.offset_dim})", # noqa: B950 + f"torch.distributed.tensor.placement_types.MaskPartial(reduce_op={self.reduce_op}, offset_shape={self.offset_shape}, offset_dim={self.offset_dim})", {}, ) diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index 28a18428c673d..d5651926b102a 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -134,16 +134,16 @@ def to_map(obj): from torch.nn.parallel.scatter_gather import _is_namedtuple if _is_namedtuple(obj): - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] return [type(obj)(*args) for args in zip(*map(to_map, obj))] if isinstance(obj, tuple) and len(obj) > 0: - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] return list(zip(*map(to_map, obj))) if isinstance(obj, list) and len(obj) > 0: - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] return [list(i) for i in zip(*map(to_map, obj))] if isinstance(obj, dict) and len(obj) > 0: - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] return [type(obj)(i) for i in zip(*map(to_map, obj.items()))] return [obj] diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py index 85932828d21af..43991b0ac265a 100644 --- a/torch/distributions/kl.py +++ b/torch/distributions/kl.py @@ -970,4 +970,5 @@ def _add_kl_info(): ) kl_info = "\n\t".join(rows) if kl_divergence.__doc__: + # pyrefly: ignore [missing-attribute] kl_divergence.__doc__ += kl_info diff --git a/torch/export/_draft_export.py b/torch/export/_draft_export.py index 6e79c7287dcf6..2e29d418ecc87 100644 --- a/torch/export/_draft_export.py +++ b/torch/export/_draft_export.py @@ -101,7 +101,7 @@ def print(self, str_to_filename: dict[int, str]) -> str: torch.ops.{op} is missing a fake kernel implementation. Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a meta implementation. -""" # noqa: B950 +""" elif self.failure_type == FailureType.GUARD_ADDED: locals_info = ( @@ -140,7 +140,7 @@ def print(self, str_to_filename: dict[int, str]) -> str: Please add `torch._check(...)` to the original code to assert this data-dependent assumption. Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details. -""" # noqa: B950 +""" elif self.failure_type == FailureType.MISMATCHED_FAKE_KERNEL: op = self.data["op"] @@ -150,7 +150,7 @@ def print(self, str_to_filename: dict[int, str]) -> str: The reason for the mismatch is: {reason}. Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a fake implementation. -""" # noqa: B950 +""" else: raise ValueError(f"Unknown failure type: {self.failure_type}") diff --git a/torch/export/_leakage_detection_utils.py b/torch/export/_leakage_detection_utils.py index fe211e1dc079c..722a756431ea1 100644 --- a/torch/export/_leakage_detection_utils.py +++ b/torch/export/_leakage_detection_utils.py @@ -2,8 +2,9 @@ import types import typing import weakref +from typing_extensions import TypeIs -import torch +from torch.fx.experimental.symbolic_shapes import TrackedFake """ @@ -37,8 +38,8 @@ def _is_globals_or_locals(obj: typing.Any) -> bool: return obj is globals() or obj is locals() -def _is_tracked_fake(obj: typing.Any) -> bool: - return isinstance(obj, torch.fx.experimental.symbolic_shapes.TrackedFake) +def _is_tracked_fake(obj: typing.Any) -> TypeIs[TrackedFake]: + return isinstance(obj, TrackedFake) def _is_gm_meta_like_dict(d: dict, o: typing.Any) -> bool: diff --git a/torch/export/_swap.py b/torch/export/_swap.py index bdf3c441d69be..778a4ba68d3a8 100644 --- a/torch/export/_swap.py +++ b/torch/export/_swap.py @@ -58,7 +58,7 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None: Currently this optimization only works for the case where all of the outputs of `foo` go directly into `bar`, and `bar` has no other inputs. - """ # noqa: B950 + """ log.debug("Trying to remove pytrees for module call %s", curr_module_node) @@ -328,7 +328,7 @@ def _swap_module_helper( The `call_module` node should now reference the swapped torch.nn.Module. The `tree_flatten_spec` call will deconstruct the eager outputs of the swapped module into tensors. - """ # noqa: B950 + """ submod_name = name.replace(".", "_") sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule( diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index e8cbaacb5222b..d89c7c2272c49 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -74,7 +74,7 @@ def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list: ) if not eq_spec(received_spec, in_spec): - raise ValueError( # noqa: B904 + raise ValueError( "Trying to flatten user inputs with exported input tree spec: \n" f"{in_spec}\n" "but actually got inputs with tree spec of: \n" @@ -205,7 +205,11 @@ def _convert_guards_code_to_fn( code_str += " return\n" # populate namespace with sympy globals, materialize function (named `_`) - namespace = {**SYMPY_INTERP} + namespace = { + **SYMPY_INTERP, + "math": math, + "inf": float("inf"), + } exec(code_str, namespace) # create and return a module whose forward is the materialized function @@ -675,7 +679,9 @@ def handle_symint(expr, src): if isinstance(meta, int): new_guards_code.append(f"{src} == {meta}") if isinstance(meta, float): - if meta == math.inf: + if math.isnan(meta): + new_guards_code.append(f"math.isnan({src})") + elif meta == math.inf: new_guards_code.append(f"{src} == math.inf") elif meta == -math.inf: new_guards_code.append(f"{src} == -math.inf") diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 270379d12b668..0109e3f685cc4 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1118,7 +1118,7 @@ def root_value(): if solution is not None: return int(solution[1]) else: - raise UserError( # noqa: B904 + raise UserError( UserErrorType.CONSTRAINT_VIOLATION, f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be " f"of the form {expr}, where {symbol} is an integer", diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index 15da1b9f02c26..6a930f537353c 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -325,7 +325,7 @@ def _exporter_context(*args, **kwargs): # type: ignore[no-untyped-def] return ep.module()(*args, **kwargs) if isinstance(fn, torch.nn.Module): - _exporter_context = torch._dynamo.eval_frame.OptimizedModule( # type: ignore[assignment] # noqa: F811 + _exporter_context = torch._dynamo.eval_frame.OptimizedModule( # type: ignore[assignment] fn, lambda _: _exporter_context, # type: ignore[arg-type] ) diff --git a/torch/export/experimental/_utils.py b/torch/export/experimental/_utils.py index 9684e3530f89f..45c26f7de517d 100644 --- a/torch/export/experimental/_utils.py +++ b/torch/export/experimental/_utils.py @@ -196,7 +196,7 @@ def _get_make_file(package_name: str, model_names: list[str], device_type: str) "cmake_minimum_required(VERSION 3.10)", "project(TestProject)", "", - "set(CMAKE_CXX_STANDARD 17)", + "set(CMAKE_CXX_STANDARD 20)", "", ] ) diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py index 8d7cb8972cec6..2de31f40b20cb 100644 --- a/torch/export/pt2_archive/_package.py +++ b/torch/export/pt2_archive/_package.py @@ -7,6 +7,7 @@ import zipfile from dataclasses import dataclass from typing import Any, IO, TYPE_CHECKING, TypeAlias +from typing_extensions import TypeIs import torch import torch.utils._pytree as pytree @@ -331,7 +332,7 @@ def _package_aoti_files( logger.debug(weights_config) -def _is_fake_tensor(t: torch.Tensor) -> bool: +def _is_fake_tensor(t: torch.Tensor) -> TypeIs[FakeTensor]: return isinstance(t, FakeTensor) @@ -602,6 +603,8 @@ def _package_exported_programs( ep, opset_version, pickle_protocol, + serialize_state_dict=False, + serialize_constants=False, ) archive_writer.write_bytes( diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index c76c52c37e152..526010abb0ac5 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -1283,7 +1283,7 @@ def add_placeholder(self, x): if x.graph is not self.flat_graph: raise AssertionError( "expected x.graph to be flat_graph, got different graph" - ) # noqa: F541 + ) # x is not in subgraph, create a new placeholder for subgraph with self.graph.inserting_before(None): placeholder_node = self.graph.placeholder(x.name, type_expr=x.type) @@ -1310,7 +1310,7 @@ def remap_input(self, x): if x.graph is not self.flat_graph: raise AssertionError( "expected x.graph to be flat_graph, got different graph" - ) # noqa: F541 + ) if x in self.node_map: return self.node_map[x] self.print(f"remap_input({x})") @@ -1650,7 +1650,10 @@ def _reorder_submodules( fqn = prefix + name _reorder_submodules(child, fqn_order, prefix=fqn.split("@")[0] + ".") delattr(parent, name) - children.append((fqn_order[fqn], name, child)) + base_fqn = fqn.split("@")[0] + children.append( + (fqn_order.get(fqn, fqn_order.get(base_fqn, len(fqn_order))), name, child) + ) children.sort(key=operator.itemgetter(0)) for _, name, child in children: parent.register_module(name, child) diff --git a/torch/func/_random.py b/torch/func/_random.py new file mode 100644 index 0000000000000..a0b6fe6b5ab1f --- /dev/null +++ b/torch/func/_random.py @@ -0,0 +1,263 @@ +"""Stateless PRNG APIs. + +These are experimental and subject to change without notice. +Access via ``torch.func._random``. +""" + +from collections.abc import Sequence + +import torch + + +def key( + seed: int, impl: str = "philox4x32-10", device: torch.device | None = None +) -> torch.Tensor: + r"""Create a PRNG key from a seed. + + A key is a tensor that encodes the state needed to deterministically + produce random values. Keys are consumed by generation functions to produce + reproducible random tensors without any global state. The internal + representation of the key depends on the chosen PRNG algorithm. + + Args: + seed (int): The seed value for the PRNG. + impl (str): PRNG algorithm. Currently only ``"philox4x32-10"`` is + supported. + device (:class:`torch.device`, optional): The desired device for the + returned key. Default: ``cpu``. + + Returns: + A tensor representing the PRNG key. + + .. note:: + + For the ``"philox4x32-10"`` algorithm, the key is a uint64 tensor of + shape ``(2,)`` encoding a ``(seed, offset)`` pair. The offset determines + the starting position in the Philox output stream. + + Example:: + + >>> key = torch.func._random.key(42, device="cuda") # doctest: +SKIP + """ + if impl != "philox4x32-10": + raise NotImplementedError(f"key() does not support PRNG impl '{impl}'") + + # (seed, offset) + return torch.tensor([seed, 0], dtype=torch.uint64, device=device) + + +def split(key: torch.Tensor, num: int = 2) -> torch.Tensor: + r"""Split a PRNG key into ``num`` new independent keys. + + Each returned key produces a different, deterministic random sequence. + This is the primary mechanism for deriving multiple independent keys from + a single parent key without mutating any state. + + Supports batched keys: if ``key`` has shape ``(*batch, K)``, each key in the + batch is split independently and the result has shape ``(num, *batch, K)``. + + Args: + key (Tensor): A PRNG key returned by :func:`key`, :func:`split`, or + :func:`fold_in`. + num (int): Number of keys to produce. Default: ``2``. + + Returns: + A tensor of shape ``(num, *key.shape)`` containing the derived keys. + + Example:: + + >>> key = torch.func._random.key(42, device="cuda") # doctest: +SKIP + >>> k1, k2 = torch.func._random.split(key) # doctest: +SKIP + """ + return torch.ops.aten._philox_key_split(key, num) + + +def fold_in(key: torch.Tensor, data: int) -> torch.Tensor: + r"""Deterministically derive a new key by folding in an integer. + + Equivalent to ``split(key, data + 1)[data]``, but more efficient when + only a single derived key is needed. Useful for associating a key with + a loop iteration, layer index, or other integer identifier. + + Supports batched keys: if ``key`` has shape ``(*batch, K)``, each key in + the batch is folded independently. + + Args: + key (Tensor): A PRNG key returned by :func:`key`, :func:`split`, or + :func:`fold_in`. + data (int): An integer to fold into the key, interpreted as uint64. + + Returns: + A new key tensor with the same shape as ``key``. + + Example:: + + >>> key = torch.func._random.key(42, device="cuda") # doctest: +SKIP + >>> k0 = torch.func._random.fold_in(key, 0) # doctest: +SKIP + >>> k1 = torch.func._random.fold_in(key, 1) # doctest: +SKIP + >>> # Equivalent to split: + >>> keys = torch.func._random.split(key, 2) # doctest: +SKIP + >>> assert torch.equal(k0, keys[0]) # doctest: +SKIP + >>> assert torch.equal(k1, keys[1]) # doctest: +SKIP + """ + return torch.ops.aten._philox_key_fold_in(key, data) + + +def normal_( + key: torch.Tensor, + result: torch.Tensor, + *, + mean: float = 0.0, + std: float = 1.0, +) -> torch.Tensor: + r"""Fill ``result`` in-place with normal random values from a PRNG key. + + The values are drawn from a normal distribution with the specified ``mean`` + and ``std``. The output is fully determined by the key, so calling with the + same key always produces the same result. + + Supports batched keys: if ``key`` has shape ``(*batch, K)``, the leading + dimensions of ``result`` must be broadcastable with ``*batch`` and each key + independently generates its slice of the output. + + Args: + key (Tensor): A PRNG key returned by :func:`key`, :func:`split`, or + :func:`fold_in`. + result (Tensor): The output tensor to fill in-place. + mean (float): Mean of the normal distribution. Default: ``0.0``. + std (float): Standard deviation of the normal distribution. Default: ``1.0``. + + Returns: + ``result``, filled with normal random values. + + Example:: + + >>> key = torch.func._random.key(42, device="cuda") # doctest: +SKIP + >>> result = torch.empty(1000, device="cuda") # doctest: +SKIP + >>> torch.func._random.normal_(key, result) # doctest: +SKIP + """ + return torch.ops.aten._philox_normal_(result, key, mean, std) + + +def normal( + key: torch.Tensor, + *shape: tuple[int, ...], + mean: float = 0.0, + std: float = 1.0, + dtype: torch.dtype | None = None, +) -> torch.Tensor: + r"""Generate normally distributed random values from a PRNG key. + + Produces a tensor of the given shape filled with values drawn from a normal + distribution with the specified ``mean`` and ``std``. The output is fully + determined by the key, so calling with the same key always returns the same + result. The output is placed on the same device as ``key``. + + Supports batched keys: if ``key`` has shape ``(*batch, K)``, the leading + dimensions of ``shape`` must be broadcastable with ``*batch`` and each key + independently generates its slice of the output. + + Args: + key (Tensor): A PRNG key returned by :func:`key`, :func:`split`, or + :func:`fold_in`. + *shape (int): The desired output shape. + mean (float): Mean of the normal distribution. Default: ``0.0``. + std (float): Standard deviation of the normal distribution. Default: ``1.0``. + dtype (:class:`torch.dtype`, optional): The desired dtype. Default: ``torch.float32``. + + Returns: + A tensor of the given shape filled with normal random values. + + Example:: + + >>> key = torch.func._random.key(42, device="cuda") # doctest: +SKIP + >>> torch.func._random.normal(key, (1000,)) # doctest: +SKIP + """ + if len(shape) == 1 and isinstance(shape[0], Sequence): + # pyrefly: ignore [bad-argument-type] + shape = tuple(shape[0]) + if dtype is None: + dtype = torch.float32 + # pyrefly: ignore [no-matching-overload] + result = torch.empty(shape, dtype=dtype, device=key.device) + return normal_(key, result, mean=mean, std=std) + + +def uniform_( + key: torch.Tensor, + result: torch.Tensor, + *, + low: float = 0.0, + high: float = 1.0, +) -> torch.Tensor: + r"""Fill ``result`` in-place with uniform random values from a PRNG key. + + The values are drawn uniformly from the interval ``[low, high)``. The output + is fully determined by the key, so calling with the same key always produces + the same result. + + Supports batched keys: if ``key`` has shape ``(*batch, K)``, the leading + dimensions of ``result`` must be broadcastable with ``*batch`` and each key + independently generates its slice of the output. + + Args: + key (Tensor): A PRNG key returned by :func:`key`, :func:`split`, or + :func:`fold_in`. + result (Tensor): The output tensor to fill in-place. + low (float): Lower bound (inclusive) of the uniform distribution. Default: ``0.0``. + high (float): Upper bound (exclusive) of the uniform distribution. Default: ``1.0``. + + Returns: + ``result``, filled with uniform random values. + + Example:: + + >>> key = torch.func._random.key(42, device="cuda") # doctest: +SKIP + >>> result = torch.empty(1000, device="cuda") # doctest: +SKIP + >>> torch.func._random.uniform_(key, result) # doctest: +SKIP + """ + return torch.ops.aten._philox_uniform_(result, key, low, high) + + +def uniform( + key: torch.Tensor, + *shape: tuple[int, ...], + low: float = 0.0, + high: float = 1.0, + dtype: torch.dtype | None = None, +) -> torch.Tensor: + r"""Generate uniformly distributed random values from a PRNG key. + + Produces a tensor of the given shape filled with values drawn uniformly + from the interval ``[low, high)``. The output is fully determined by the + key, so calling with the same key always returns the same result. The output + is placed on the same device as ``key``. + + Supports batched keys: if ``key`` has shape ``(*batch, K)``, the leading + dimensions of ``shape`` must be broadcastable with ``*batch`` and each key + independently generates its slice of the output. + + Args: + key (Tensor): A PRNG key returned by :func:`key`, :func:`split`, or + :func:`fold_in`. + *shape (int): The desired output shape. + low (float): Lower bound (inclusive) of the uniform distribution. Default: ``0.0``. + high (float): Upper bound (exclusive) of the uniform distribution. Default: ``1.0``. + dtype (:class:`torch.dtype`, optional): The desired dtype. Default: ``torch.float32``. + + Returns: + A tensor of the given shape filled with uniform random values. + + Example:: + + >>> key = torch.func._random.key(42, device="cuda") # doctest: +SKIP + >>> torch.func._random.uniform(key, (1000,)) # doctest: +SKIP + """ + if len(shape) == 1 and isinstance(shape[0], Sequence): + # pyrefly: ignore [bad-argument-type] + shape = tuple(shape[0]) + if dtype is None: + dtype = torch.float32 + # pyrefly: ignore [no-matching-overload] + result = torch.empty(shape, dtype=dtype, device=key.device) + return uniform_(key, result, low=low, high=high) diff --git a/torch/functional.py b/torch/functional.py index 1594a7b3a289f..c71878055b9a1 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1254,7 +1254,7 @@ def tensordot( # noqa: F811 pass -def tensordot( # noqa: F811 +def tensordot( a, b, dims=2, @@ -1682,7 +1682,7 @@ def norm( # noqa: F811 pass -def norm( # noqa: F811 +def norm( input, p: float | str | None = "fro", dim=None, diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py index c048b4fdd8f89..6d2ce95d8a3fd 100644 --- a/torch/fx/__init__.py +++ b/torch/fx/__init__.py @@ -85,14 +85,14 @@ def forward(self, x): ''' from torch.fx import immutable_collections -from torch.fx._symbolic_trace import ( # noqa: F401 +from torch.fx._symbolic_trace import ( PH, ProxyableClassMeta, symbolic_trace, Tracer, wrap, ) -from torch.fx.graph import CodeGen, Graph # noqa: F401 +from torch.fx.graph import CodeGen, Graph from torch.fx.graph_module import GraphModule from torch.fx.interpreter import Interpreter, Transformer from torch.fx.node import has_side_effect, map_arg, Node diff --git a/torch/fx/_compatibility.py b/torch/fx/_compatibility.py index c07dd1b51bc05..265369cb84c21 100644 --- a/torch/fx/_compatibility.py +++ b/torch/fx/_compatibility.py @@ -16,6 +16,7 @@ def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]: def mark_back_compat(fn: _T) -> _T: docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "") docstring += """ + .. note:: Backwards-compatibility for this API is guaranteed. """ @@ -30,6 +31,7 @@ def mark_back_compat(fn: _T) -> _T: def mark_not_back_compat(fn: _T) -> _T: docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "") docstring += """ + .. warning:: This API is experimental and is *NOT* backward-compatible. """ diff --git a/torch/fx/_graph_pickler.py b/torch/fx/_graph_pickler.py index b91bc33b06b7f..6c21f84f59a82 100644 --- a/torch/fx/_graph_pickler.py +++ b/torch/fx/_graph_pickler.py @@ -4,8 +4,9 @@ import io import itertools import pickle +import weakref from abc import abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Generator from typing import Any, NewType, TypeVar from typing_extensions import override, Self @@ -14,12 +15,13 @@ dill = import_dill() if dill is not None: - pickle = dill # noqa: F811 + pickle = dill import torch import torch.utils._pytree as pytree from torch._guards import TracingContext from torch._inductor.standalone_compile import AOTCompiledArtifact +from torch._library.fake_class_registry import FakeScriptObject from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor from torch._subclasses.meta_utils import ( MetaConverter, @@ -64,10 +66,25 @@ class Options: node_metadata_key_filter: Callable[[str], bool] | None = ( _node_metadata_key_filter_safe ) + # If True, raw torch.fx.Node objects encountered during pickling will be + # silently replaced with None instead of raising an AssertionError. + ignore_raw_node: bool = False + + +def _unpickle_as_none() -> None: + return None + + +def _unpickle_as_weakref(referent: object) -> weakref.ref[object]: + return weakref.ref(referent) + + +def _unpickle_as_dead_weakref() -> Callable[[], None]: + return lambda: None @contextlib.contextmanager -def patch_pytree_map_over_slice(): +def patch_pytree_map_over_slice() -> Generator[None]: if slice in pytree.SUPPORTED_NODES: yield return @@ -140,9 +157,25 @@ def reducer_override( return _SymNodePickleData.reduce_helper(self, obj) elif isinstance(obj, torch._guards.TracingContext): return _TracingContextPickleData.reduce_helper(self, obj) + elif isinstance(obj, FakeScriptObject): + # FakeScriptObjects wrap opaque traced objects (e.g. DeviceMesh, + # ProcessGroup) that can't be default-pickled. Reduce to None + # since they aren't meaningful after deserialization. + return (_unpickle_as_none, ()) + elif isinstance(obj, weakref.ref): + # Serialize weakrefs properly: if the referent is alive, + # serialize it and reconstruct the weakref on unpickle. + # If the referent is dead, unpickle as a dead-weakref-like callable. + referent = obj() + if referent is not None: + return (_unpickle_as_weakref, (referent,)) + else: + return (_unpickle_as_dead_weakref, ()) else: # We should never get a raw Node! if isinstance(obj, torch.fx.Node): + if self.options.ignore_raw_node: + return (_unpickle_as_none, ()) raise AssertionError("Unexpected raw Node during pickling") if reduce := _TorchNumpyPickleData.reduce_helper(self, obj): return reduce diff --git a/torch/fx/_lazy_graph_module.py b/torch/fx/_lazy_graph_module.py index 7808dec3cee13..20a7837a66618 100644 --- a/torch/fx/_lazy_graph_module.py +++ b/torch/fx/_lazy_graph_module.py @@ -1,5 +1,6 @@ -# mypy: allow-untyped-defs +from collections.abc import Iterator from contextlib import contextmanager +from typing import Any, TYPE_CHECKING from torch.fx.graph_module import ( _format_import_block, @@ -12,13 +13,17 @@ from ._compatibility import compatibility +if TYPE_CHECKING: + from torch.fx.graph import PythonCode + + _use_lazy_graph_module_flag = False _force_skip_lazy_graph_module_flag = False @compatibility(is_backward_compatible=False) @contextmanager -def _force_skip_lazy_graph_module(): +def _force_skip_lazy_graph_module() -> Iterator[None]: """ Skip using lazy graph module disregarding the setting of _use_lazy_graph_module. Use to skip _LazyGraphModule when testing inductor torchscript related backend. @@ -37,7 +42,7 @@ def _force_skip_lazy_graph_module(): @compatibility(is_backward_compatible=False) @contextmanager -def _use_lazy_graph_module(should_use: bool): +def _use_lazy_graph_module(should_use: bool) -> Iterator[None]: try: global _use_lazy_graph_module_flag prior = _use_lazy_graph_module_flag @@ -50,11 +55,13 @@ def _use_lazy_graph_module(should_use: bool): @compatibility(is_backward_compatible=False) -def _get_graph_module_cls(): +def _get_graph_module_cls() -> type[GraphModule]: return _LazyGraphModule if _use_lazy_graph_module_flag else GraphModule -def _make_graph_module(*args, graph_module_cls=None, **kwargs): +def _make_graph_module( + *args: Any, graph_module_cls: type[GraphModule] | None = None, **kwargs: Any +) -> GraphModule: if graph_module_cls is None: graph_module_cls = _get_graph_module_cls() @@ -88,14 +95,14 @@ class _LazyGraphModule(GraphModule): """ @classmethod - def from_graphmodule(cls, gm: GraphModule): + def from_graphmodule(cls, gm: GraphModule) -> GraphModule: if isinstance(gm, _LazyGraphModule): return gm else: return _LazyGraphModule(gm, gm.graph) @staticmethod - def force_recompile(gm): + def force_recompile(gm: GraphModule) -> None: """ Sometimes we need force a recompile as a workaround - we want to do the real recompilation before symbolic_trace to avoid error: @@ -104,15 +111,15 @@ def force_recompile(gm): if isinstance(gm, _LazyGraphModule): gm.real_recompile() - def real_recompile(self): + def real_recompile(self) -> None: if self._needs_recompile(): self._real_recompile() @classmethod - def _needs_recompile(cls): + def _needs_recompile(cls) -> bool: return cls.forward is cls._lazy_forward - def _lazy_forward(self, *args, **kwargs): + def _lazy_forward(self, *args: Any, **kwargs: Any) -> Any: # Call self.real_recompile() rather than self._real_recompile() here. # The _lazy_forward method may be saved and call repeatedly. # Calling self.real_recompile can make sure we skip recompilation if @@ -128,7 +135,9 @@ def _lazy_forward(self, *args, **kwargs): forward = _lazy_forward - def __reduce_package__(self, exporter: PackageExporter): + def __reduce_package__( + self, exporter: PackageExporter + ) -> tuple[Any, tuple[Any, str]]: """ Follow GraphModule.__reduce__ but call 'self._real_recompile' rather than 'self.recompile' since for a _LazyGraphModule, self.recompile just @@ -148,7 +157,7 @@ def __reduce_package__(self, exporter: PackageExporter): (dict_without_graph, generated_module_name), ) - def __reduce__(self): + def __reduce__(self) -> tuple[Any, tuple[Any, str]]: """ Follow GraphModule.__reduce__ but call 'self._real_recompile' rather than 'self.recompile' since for a _LazyGraphModule, self.recompile just @@ -160,11 +169,11 @@ def __reduce__(self): del dict_without_graph["_graph"] return (reduce_graph_module, (dict_without_graph, import_block)) - def _real_recompile(self): + def _real_recompile(self) -> "PythonCode": return super().recompile() @classmethod - def recompile(cls): + def recompile(cls) -> None: # pyrefly: ignore[bad-override] cls.forward = cls._lazy_forward @property diff --git a/torch/fx/_pytree.py b/torch/fx/_pytree.py index 5b18c815ca312..4d923d528e710 100644 --- a/torch/fx/_pytree.py +++ b/torch/fx/_pytree.py @@ -7,7 +7,7 @@ from torch.utils._pytree import PyTree, tree_flatten, TreeSpec -FlattenFnSpec = Callable[[PyTree, TreeSpec], list] +FlattenFnSpec = Callable[[PyTree, TreeSpec], list[Any]] FlattenFnExactMatchSpec = Callable[[PyTree, TreeSpec], bool] # Keep deprecated alias for backward compatibility @@ -52,7 +52,7 @@ def tree_flatten_spec( if spec.type in SUPPORTED_NODES: flatten_fn_spec = SUPPORTED_NODES[spec.type] child_pytrees = flatten_fn_spec(pytree, spec) - result = [] + result: list[Any] = [] for child, child_spec in zip(child_pytrees, spec.children()): flat = tree_flatten_spec(child, child_spec) result += flat diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 893225b3836ef..3a5541811cb6e 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import builtins import collections import contextlib @@ -9,10 +8,10 @@ import math import os import warnings -from collections.abc import Callable +from collections.abc import Callable, Iterable, Iterator from itertools import chain -from types import CodeType, FunctionType, ModuleType -from typing import Any, get_args, NamedTuple, TypeAlias +from types import CodeType, FunctionType, ModuleType, TracebackType +from typing import Any, get_args, NamedTuple, overload, ParamSpec, TypeAlias, TypeVar import torch import torch.utils._pytree as pytree @@ -32,9 +31,13 @@ HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS +_F = TypeVar("_F", bound=Callable[..., Any]) +_P = ParamSpec("_P") +_T = TypeVar("_T") + # These need to run in global scope to handle nested calls correctly -_orig_module_call: Callable = torch.nn.Module.__call__ -_orig_module_getattr: Callable = torch.nn.Module.__getattr__ +_orig_module_call: Callable[..., Any] = torch.nn.Module.__call__ +_orig_module_getattr: Callable[..., Any] = torch.nn.Module.__getattr__ _proxyable_classes: dict[type, None] = {} @@ -49,7 +52,7 @@ # We only want to print this once to avoid flooding logs @functools.lru_cache -def is_fx_tracing_warning(): +def is_fx_tracing_warning() -> None: log.warning( "is_fx_tracing will return true for both fx.symbolic_trace and " "torch.export. Please use " @@ -58,12 +61,12 @@ def is_fx_tracing_warning(): ) -def is_fx_tracing(): +def is_fx_tracing() -> bool: is_fx_tracing_warning() return _is_fx_tracing_flag -def is_fx_symbolic_tracing(): +def is_fx_symbolic_tracing() -> bool: return _is_fx_tracing_flag and not torch.compiler.is_compiling() @@ -116,20 +119,22 @@ def forward(self, x : __main___TensorPair, y : torch.Tensor): tracing. """ - def __init__(cls, name, bases, attrs): + def __init__( + cls, name: str, bases: tuple[type, ...], attrs: dict[str, Any] + ) -> None: _proxyable_classes.setdefault(cls) super().__init__(name, bases, attrs) - def __call__(cls, *args, **kwargs): + def __call__(cls, *args: Any, **kwargs: Any) -> Any: instance = cls.__new__(cls) # type: ignore[call-overload] if not is_fx_tracing(): cls.__init__(instance, *args, **kwargs) # type: ignore[misc] return instance - found_proxies = [] + found_proxies: list[Proxy] = [] - def check_proxy(a): + def check_proxy(a: object) -> None: if isinstance(a, Proxy): found_proxies.append(a) @@ -147,7 +152,7 @@ def check_proxy(a): def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: co = fn.__code__ co_flags = co.co_flags & ~HAS_VARSTUFF - co_args: tuple + co_args: tuple[Any, ...] if hasattr(co, "co_qualname"): # Python-3.11+ code signature co_args = ( @@ -223,7 +228,7 @@ class PHBase: Object representing an input placeholder to `concrete_args` """ - def __repr__(self): + def __repr__(self) -> str: return "PH" @@ -236,14 +241,14 @@ class PHWithMeta(PHBase): Object representing an input placeholder to `concrete_args` """ - def __init__(self, ph_key: str | None = None): + def __init__(self, ph_key: str | None = None) -> None: super().__init__() # Provide a hey for user to identify placeholder node during analysis self.ph_key = ph_key -def _transfer_attrs(fr, to): +def _transfer_attrs(fr: object, to: object) -> None: for attr_name in dir(fr): attr_val = getattr(fr, attr_name) if ( @@ -279,7 +284,7 @@ class Tracer(TracerBase): def __init__( self, autowrap_modules: tuple[ModuleType] = (math,), - autowrap_functions: tuple[Callable, ...] = (), + autowrap_functions: tuple[Callable[..., Any], ...] = (), param_shapes_constant: bool = False, ) -> None: # This method's signature is overridden by the first line of this class' @@ -581,7 +586,9 @@ def call_module( return ret_val @compatibility(is_backward_compatible=False) - def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]): + def getattr( + self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Proxy] + ) -> Any: """ Method that specifies the behavior of this ``Tracer`` when we call getattr on a call to an ``nn.Module`` instance. @@ -605,8 +612,10 @@ def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any """ def maybe_get_proxy_for_attr( - attr_val, collection_to_search, parameter_proxy_cache - ): + attr_val: Any, + collection_to_search: Iterable[tuple[str, Any]], + parameter_proxy_cache: dict[str, Proxy], + ) -> Proxy | None: for n, p in collection_to_search: if attr_val is p: if n not in parameter_proxy_cache: @@ -646,7 +655,12 @@ def maybe_get_proxy_for_attr( # This method will be refactored @compatibility(is_backward_compatible=False) - def create_args_for_root(self, root_fn, is_module, concrete_args=None): + def create_args_for_root( + self, + root_fn: Callable[..., Any], + is_module: bool, + concrete_args: dict[str, Any] | tuple[Any, ...] | None = None, + ) -> tuple[Any, list[Any]]: """ Create ``placeholder`` nodes corresponding to the signature of the ``root`` Module. This method introspects root's signature and emits those @@ -708,7 +722,7 @@ def create_args_for_root(self, root_fn, is_module, concrete_args=None): ) concrete_args = dict(zip(arg_names, concrete_args)) - def proxy_placeholder(name): + def proxy_placeholder(name: str) -> Any: return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis) args.extend(proxy_placeholder(names) for names in arg_names) @@ -730,7 +744,9 @@ def proxy_placeholder(name): _PyTreeInfo(orig_args[:total_args], in_spec, None) ) - def flatten_fn(*args): + # TODO: annotate return type. inspect.get_annotations(flatten_fn) + # leaks the return annotation into the generated forward() code. + def flatten_fn(*args: Any): # pyrefly: ignore[unannotated-parameter] tree_args = pytree.tree_unflatten(list(args), in_spec) tree_out = root_fn(*tree_args) out_args, out_spec = pytree.tree_flatten(tree_out) @@ -787,7 +803,9 @@ def trace( # without this. from torch.fx._lazy_graph_module import _LazyGraphModule - _LazyGraphModule.force_recompile(root) + _LazyGraphModule.force_recompile( + root # pyrefly: ignore[bad-argument-type] + ) self.root = root @@ -823,7 +841,9 @@ def trace( str, ] = {} - def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: list[str]): + def collect_tensor_attrs( + m: torch.nn.Module, prefix_atoms: list[str] + ) -> None: for k, v in m.__dict__.items(): if isinstance(v, _constant_attribute_types): self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) @@ -847,13 +867,15 @@ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: list[str]): # Method dispatch on parameters is not recorded unless it's directly used. # Thus, we need to insert a proxy when __getattr__ requests a parameter. @functools.wraps(_orig_module_getattr) - def module_getattr_wrapper(mod, attr): + def module_getattr_wrapper(mod: torch.nn.Module, attr: str) -> Any: attr_val = _orig_module_getattr(mod, attr) return self.getattr(attr, attr_val, parameter_proxy_cache) @functools.wraps(_orig_module_call) - def module_call_wrapper(mod, *args, **kwargs): - def forward(*args, **kwargs): + def module_call_wrapper( + mod: torch.nn.Module, *args: Any, **kwargs: Any + ) -> Any: + def forward(*args: Any, **kwargs: Any) -> Any: return _orig_module_call(mod, *args, **kwargs) _autowrap_check( @@ -907,7 +929,7 @@ def forward(*args, **kwargs): _is_fx_tracing_flag = old_is_fx_tracing_flag return self.graph - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: dict[int, Any]) -> "Tracer": # _autowrap_search contains modules, which cannot be deepcopied. new_tracer = Tracer.__new__(Tracer) @@ -921,11 +943,17 @@ def __deepcopy__(self, memo): return new_tracer - def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis): + def _proxy_placeholder( + self, + name: str, + concrete_args: dict[str, Any] | None, + sig: inspect.Signature, + fn_for_analysis: Callable[..., Any], + ) -> Any: if concrete_args is not None and name in concrete_args: cnt = 0 - def replace_ph(x): + def replace_ph(x: object) -> object: nonlocal cnt cnt += 1 param = sig.parameters[name] @@ -993,7 +1021,7 @@ def replace_ph(x): # the purposes of the wrap() API. # We key by the globals dict id and function name to ensure we're wrapping a given # function only once. -_wrapped_fns_to_patch: dict[tuple[int, str], dict] = {} +_wrapped_fns_to_patch: dict[tuple[int, str], dict[str, Any]] = {} # List of methods on classes to wrap (class type, function name) # this currently only works for Tensor.* methods that aren't traced properly @@ -1008,25 +1036,26 @@ def replace_ph(x): _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) -def _find_proxy(*objects_to_search): +def _find_proxy(*objects_to_search: object) -> Proxy | None: """ Recursively search a data structure for a Proxy() and return it, return None if not found. """ proxy = None - def find_proxy(x): + def find_proxy(x: object) -> None: nonlocal proxy if isinstance(x, Proxy): proxy = x + # pyrefly: ignore[bad-specialization] map_aggregate(objects_to_search, find_proxy) return proxy -def _create_wrapped_func(orig_fn): +def _create_wrapped_func(orig_fn: Callable[_P, _T]) -> Callable[_P, _T]: @functools.wraps(orig_fn) - def wrapped(*args, **kwargs): + def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> Any: """ Given an closed-over ``orig_function`` to invoke, search the args and kwargs for a Proxy object. If there is one, emit a ``call_function`` node to preserve the @@ -1045,11 +1074,11 @@ def wrapped(*args, **kwargs): return wrapped -def _create_wrapped_method(cls, name): +def _create_wrapped_method(cls: type, name: str) -> Callable[..., Any]: orig_fn = getattr(cls, name) @functools.wraps(orig_fn) - def wrapped(*args, **kwargs): + def wrapped(*args: Any, **kwargs: Any) -> Any: """ Search the args and kwargs for a Proxy object. If there is one, emit a ``call_method`` node to preserve the call to this method @@ -1065,39 +1094,39 @@ def wrapped(*args, **kwargs): class _PatchedFn(NamedTuple): - frame_dict: Any + frame_dict: Any # dict[str, Any] for SetItem/Del, type for SetAttr fn_name: str - orig_fn: Any - new_fn: Any + orig_fn: Callable[..., Any] | None + new_fn: Callable[..., Any] - def revert(self): + def revert(self) -> None: raise NotImplementedError - def patch(self): + def patch(self) -> None: raise NotImplementedError class _PatchedFnSetItem(_PatchedFn): - def revert(self): + def revert(self) -> None: self.frame_dict[self.fn_name] = self.orig_fn - def patch(self): + def patch(self) -> None: self.frame_dict[self.fn_name] = self.new_fn class _PatchedFnDel(_PatchedFn): - def revert(self): + def revert(self) -> None: del self.frame_dict[self.fn_name] - def patch(self): + def patch(self) -> None: self.frame_dict[self.fn_name] = self.new_fn class _PatchedFnSetAttr(_PatchedFn): - def revert(self): + def revert(self) -> None: setattr(self.frame_dict, self.fn_name, self.orig_fn) - def patch(self): + def patch(self) -> None: setattr(self.frame_dict, self.fn_name, self.new_fn) @@ -1111,9 +1140,9 @@ def patch( self, frame_dict: dict[str, Any], name: str, - new_fn: Callable, + new_fn: Callable[..., Any], deduplicate: bool = True, - ): + ) -> None: """ Replace frame_dict[name] with new_fn until we exit the context manager. """ @@ -1130,8 +1159,8 @@ def patch( self.patches_made[-1].patch() def patch_method( - self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True - ): + self, cls: type, name: str, new_fn: Callable[..., Any], deduplicate: bool = True + ) -> None: """ Replace object_or_dict.name with new_fn until we exit the context manager. """ @@ -1142,7 +1171,7 @@ def patch_method( self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn)) self.patches_made[-1].patch() - def visit_once(self, thing: Any): + def visit_once(self, thing: object) -> bool: """Return True on the first call to with thing, otherwise false""" idx = id(thing) if idx in self.visited: @@ -1150,7 +1179,7 @@ def visit_once(self, thing: Any): self.visited.add(idx) return True - def revert_all_patches(self): + def revert_all_patches(self) -> list[_PatchedFn]: """ Remove all the stored patcheds. It doesn't modify patches_made. """ @@ -1158,7 +1187,7 @@ def revert_all_patches(self): patch.revert() return self.patches_made - def reapply_all_patches(self): + def reapply_all_patches(self) -> list[_PatchedFn]: """ Patch all the stored patcheds. It doesn't modify patches_made. """ @@ -1166,10 +1195,15 @@ def reapply_all_patches(self): patch.patch() return self.patches_made - def __enter__(self): + def __enter__(self) -> "_Patcher": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: """ Undo all the changes made via self.patch() and self.patch_method() """ @@ -1183,7 +1217,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): @contextlib.contextmanager -def _new_patcher(): +def _new_patcher() -> Iterator[_Patcher]: global CURRENT_PATCHER prior_patcher = CURRENT_PATCHER try: @@ -1198,7 +1232,7 @@ def _new_patcher(): @contextlib.contextmanager -def _maybe_revert_all_patches(): +def _maybe_revert_all_patches() -> Iterator[None]: current_patcher = CURRENT_PATCHER patches_made = None patches_removed = None @@ -1215,7 +1249,7 @@ def _maybe_revert_all_patches(): ) -def _patch_wrapped_functions(patcher: _Patcher): +def _patch_wrapped_functions(patcher: _Patcher) -> None: """ Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap the listed global functions in the `_create_wrapped_func` wrapper. @@ -1233,7 +1267,7 @@ def _patch_wrapped_functions(patcher: _Patcher): def _autowrap_check( patcher: _Patcher, frame_dict: dict[str, Any], function_ids: set[int] -): +) -> None: """ Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. This method searches a scope for them and patches them if found. @@ -1248,8 +1282,12 @@ def _autowrap_check( patcher.patch(frame_dict, name, _create_wrapped_func(value)) +@overload +def wrap(fn_or_name: _F) -> _F: ... +@overload +def wrap(fn_or_name: str) -> str: ... @compatibility(is_backward_compatible=True) -def wrap(fn_or_name: str | Callable): +def wrap(fn_or_name: str | Callable[..., Any]) -> str | Callable[..., Any]: """ This function can be called at module-level scope to register fn_or_name as a "leaf function". A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being @@ -1382,6 +1420,6 @@ def f(x): @wrap -def _assert_is_none(value, msg): +def _assert_is_none(value: object, msg: str) -> None: if value is not None: raise AssertionError(msg) diff --git a/torch/fx/_utils.py b/torch/fx/_utils.py index c27d7d113ffd8..c58c36502b68d 100644 --- a/torch/fx/_utils.py +++ b/torch/fx/_utils.py @@ -1,16 +1,18 @@ -# mypy: allow-untyped-defs import sys +from typing import Any import torch from torch._logging import LazyString -def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): +def lazy_format_graph_code( + name: str, gm: torch.fx.GraphModule, maybe_id: int | None = None, **kwargs: Any +) -> LazyString: """ Returns a LazyString that formats the graph code. """ - def format_name(): + def format_name() -> str: if maybe_id is not None: return f"{name} {maybe_id}" else: @@ -35,14 +37,14 @@ def format_name(): ) -def _format_graph_code(name, filename, graph_str): +def _format_graph_code(name: str, filename: str, graph_str: str) -> str: """ Returns a string that formats the graph code. """ return f"TRACED GRAPH\n {name} {filename} {graph_str}\n" -def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> dict | None: +def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> dict[str, Any] | None: """ Returns the nn_module_stack of the first call_function node. """ @@ -52,14 +54,15 @@ def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> dict | None: return None -def get_node_context(node, num_nodes=2) -> str: +def get_node_context(node: torch.fx.Node, num_nodes: int = 2) -> str: """ Returns a string of the last num_nodes nodes in the graph. """ node_contexts = [] cur = node for _ in range(num_nodes): - node_contexts.append(cur.format_node()) + # cast to str to handle None return value + node_contexts.append(str(cur.format_node())) if cur.op == "root": break cur = cur.prev diff --git a/torch/fx/annotate.py b/torch/fx/annotate.py index b3c5056066251..fd90605d8e3b5 100644 --- a/torch/fx/annotate.py +++ b/torch/fx/annotate.py @@ -1,11 +1,15 @@ -# mypy: allow-untyped-defs +from typing import Any + from torch.fx.proxy import Proxy from ._compatibility import compatibility +__all__ = ["annotate"] + + @compatibility(is_backward_compatible=False) -def annotate(val, type): +def annotate(val: Any, type: type) -> Any: """ Annotates a Proxy object with a given type. diff --git a/torch/fx/experimental/_constant_symnode.py b/torch/fx/experimental/_constant_symnode.py index b3b40bda324c8..ae467419f91c3 100644 --- a/torch/fx/experimental/_constant_symnode.py +++ b/torch/fx/experimental/_constant_symnode.py @@ -5,7 +5,7 @@ # This needs to exist because the Python version of nested int is not compatible # with the C++ version of constant symnode. class ConstantIntNode: - def __init__(self, val: int): + def __init__(self, val: int) -> None: self.val = val def is_constant(self) -> bool: diff --git a/torch/fx/experimental/_dynamism.py b/torch/fx/experimental/_dynamism.py index 186e8de3898c0..be09162e2deb4 100644 --- a/torch/fx/experimental/_dynamism.py +++ b/torch/fx/experimental/_dynamism.py @@ -26,8 +26,10 @@ def module_to_nested_dict(module: torch.nn.Module) -> dict[str, Any]: """Recursively converts an nn.Module into a nested dictionary with explicit 'parameters' and 'modules' keys.""" self_dict: dict[str, Any] = {} - self_dict["_parameters"] = {} - self_dict["_modules"] = {} + parameters: dict[str, torch.Tensor] = {} + modules: dict[str, dict[str, Any]] = {} + self_dict["_parameters"] = parameters + self_dict["_modules"] = modules for attr_name in dir(module): try: diff --git a/torch/fx/experimental/_size_hinting.py b/torch/fx/experimental/_size_hinting.py new file mode 100644 index 0000000000000..89e55dbedbe3a --- /dev/null +++ b/torch/fx/experimental/_size_hinting.py @@ -0,0 +1,444 @@ +""" +Size hinting utilities for symbolic shape expressions. + +This module contains the core logic for resolving symbolic expressions to +concrete integer hints. Two strategies are provided: + +- _guarding_hint_or_throw_base: strict, only uses backed symbol hints, throws on + unbacked symbols. Use for correctness-critical guarding decisions. +- _optimization_hint_base: permissive, uses heuristics and fallbacks for unbacked + symbols. Use for performance optimization decisions. +""" + +from __future__ import annotations + +import logging +import sys +from typing import Any, TYPE_CHECKING + +import sympy + +from torch.utils._sympy.numbers import int_oo + + +log = logging.getLogger(__name__) + +# Maximum number of free symbols in an expression before we skip +# sympy.factor() in optimization_hint process for unbacked. +# Factoring polynomials with many variables is expensive. +SYMPY_FACTOR_MAX_FREE_SYMBOLS = 50 + +if TYPE_CHECKING: + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + +def _sympy_subs(expr: sympy.Basic, replacements: dict[sympy.Expr, Any]) -> sympy.Basic: + """ + When the passed replacement symbol v is a string, it is converted to a symbol with name v that + have the same replaced expression integer and nonnegative properties. + """ + + def to_symbol(replaced: sympy.Expr, replacement: sympy.Expr | str) -> sympy.Symbol: + if not isinstance(replaced, sympy.Expr): + raise AssertionError( + f"Expected sympy.Expr key, got {type(replaced)}: {replaced}" + ) + if isinstance(replacement, str): + return sympy.Symbol( + replacement, + integer=replaced.is_integer, # type: ignore[attr-defined] + nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined] + ) + else: + return replacement + + # xreplace is faster than subs, but is way more picky + return sympy.sympify(expr).xreplace( + {k: to_symbol(k, v) for k, v in replacements.items()} + ) + + +def _maybe_realize_expr( + expr: sympy.Basic, nan_fallback: int | None +) -> int | bool | None: + """ + Handle special sympy values in hinting APIs. + + Returns: + - True/False for sympy.true/sympy.false (preserves bool type) + - Raises ValueError for complex numbers + - sys.maxsize for positive infinity + - -sys.maxsize for negative infinity + - fallback for NaN + - None if no special handling needed + """ + if expr is sympy.true: + return True + if expr is sympy.false: + return False + + try: + return int(expr) + except (TypeError, ValueError): + pass + + if isinstance(expr, sympy.Expr): + if expr.has(sympy.I): + raise ValueError( + f"_maybe_realize_expr received a complex expression: {expr}. " + "Tensor dimensions cannot be complex numbers." + ) + if expr in (int_oo, sympy.oo): + return sys.maxsize + if expr in (-int_oo, -sympy.oo): + return -sys.maxsize + if nan_fallback is not None and (expr is sympy.nan or expr.has(sympy.nan)): + return nan_fallback + + return None + + +def _guarding_hint_or_throw_base( + shape_env: ShapeEnv, + expr: sympy.Expr | sympy.Basic | int | bool, + precomputed_replacements: dict[sympy.Expr, sympy.Symbol], +) -> int | bool: + """ + Return a concrete integer hint for an expression that is safe to use for guarding. + + This function evaluates the expression using only backed-symbols hints. Unlike + _optimization_hint_base(), this function does NOT use heuristics or fallback values + for unbacked symbols. + + Use this when you need a hint value that will be used for a guarding decision. + + Args: + shape_env: The ShapeEnv instance. + expr: A sympy expression or integer to evaluate. + precomputed_replacements: Precomputed replacements for PRECOMPUTED_SIZE symbols. + + Returns: + The concrete integer value of the expression based on backed symbol hints. + + Raises: + GuardOnDataDependentSymNode: If the expression contains unbacked symbols + (data-dependent values) that cannot be resolved to concrete values. + + See Also: + _optimization_hint_base: For cases where fallback/heuristic values are acceptable + for unbacked symbols. + """ + from torch.fx.experimental.symbolic_shapes import ( + has_free_unbacked_symbols, + symbol_is_type, + SymT, + ) + + # sympy.expand() doesn't work with boolean expressions like Or/And + if isinstance(expr, sympy.Expr): + expr = sympy.expand(expr).xreplace(shape_env.replacements) + else: + expr = sympy.sympify(expr).xreplace(shape_env.replacements) + + if isinstance(expr, sympy.Expr): + expr = expr.expand(identity=True) + + result = _maybe_realize_expr(expr, None) + if result is not None: + return result + + if not isinstance(expr, sympy.Basic): + raise RuntimeError("isinstance(expr, sympy.Basic)", expr, type(expr)) + + if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols): # type: ignore[attr-defined] + expr = _sympy_subs(expr, precomputed_replacements) + + # TODO do we need sympy_subs, or just xreplace + expr = _sympy_subs(expr, shape_env.backed_var_to_val) + if isinstance(expr, sympy.Expr): + expr = expr.expand(identity=True) + + if has_free_unbacked_symbols(expr): + # Note: we could do better here and call + # _maybe_evaluate_static(orig_expr, compute_hint=True) + # but is it worth the overhead? probably not. + raise shape_env._make_data_dependent_error(expr, expr) + + result = _maybe_realize_expr(expr, None) + if result is None: + raise RuntimeError("unexpected None!", expr) + return result + + +def _get_unbacked_replacements(shape_env: ShapeEnv) -> dict[sympy.Expr, sympy.Expr]: + """Builds a mapping from unbacked expressions to canonical equivalents + using a union-find algorithm over deferred runtime asserts. + Used by optimization_hint to resolve unbacked symbols to consistent values.""" + from collections import defaultdict + + from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols + from torch.utils._ordered_set import OrderedSet + + if shape_env._unbacked_replacements is not None: + return shape_env._unbacked_replacements + + class CanonicalExprFinder: + """ + A disjoint-set/union-find data structure that can return the + "canonical" expression for a group of equivalent expressions. + - The canonical expression must come from the input eq_graph. + - The heuristics used to choose a leader determines which + expression becomes the canonical expression. + """ + + def __init__(self, eq_graph: dict[sympy.Expr, OrderedSet[sympy.Expr]]) -> None: + self.eq_graph = eq_graph + self.expressions = list(eq_graph.keys()) + self.reverse_expressions = { + expr: i for i, expr in enumerate(self.expressions) + } + self.leader = list(range(len(self.expressions))) + self.size = [1] * len(self.expressions) + self._build_canonical_expr_mapping() + + def _build_canonical_expr_mapping(self) -> None: + for expr, edges in self.eq_graph.items(): + for adj in edges: + self.union_expr(expr, adj) + + def union_expr(self, a: sympy.Expr, b: sympy.Expr) -> bool: + return self.union(self.reverse_expressions[a], self.reverse_expressions[b]) + + def union(self, a: int, b: int) -> bool: + rootA = self.find(a) + rootB = self.find(b) + if rootA == rootB: + return False + leader, other = self.choose_leader(rootA, rootB) + self.leader[other] = leader + self.size[leader] += self.size[other] + return True + + def find_expr(self, expr: sympy.Expr) -> sympy.Expr: + parent = self.find(self.reverse_expressions[expr]) + return self.expressions[parent] + + def find(self, x: int) -> int: + if self.leader[x] != x: + self.leader[x] = self.find(self.leader[x]) + return self.leader[x] + + def choose_leader(self, a: int, b: int) -> tuple[int, int]: + """ + The leader will become the canonical expression. + Returns a (leader, follower) tuple. + + Heuristics: + 1. Backed expression or constants preferred over unbacked expr + 2. Simpler sub-expr when one contains the other + 3. Higher frequency across equalities from deferred runtime assertions + 4. Size of the set + 5. Fallback to sympy.Basic.compare + """ + + def _choose(x: int, y: int) -> bool: + lhs, rhs = self.expressions[x], self.expressions[y] + + any_unbacked_lhs = has_free_unbacked_symbols(lhs) + any_unbacked_rhs = has_free_unbacked_symbols(rhs) + if any_unbacked_lhs != any_unbacked_rhs: + return bool(any_unbacked_rhs) + + if lhs.has(rhs): + return False + elif rhs.has(lhs): + return True + + degrees_lhs = len(self.eq_graph[lhs]) + degrees_rhs = len(self.eq_graph[rhs]) + if degrees_lhs != degrees_rhs: + return degrees_lhs > degrees_rhs + + if self.size[x] != self.size[y]: + return self.size[x] > self.size[y] + + return lhs.compare(rhs) == -1 + + if _choose(a, b): + return a, b + return b, a + + # Build an undirected graph using ShapeEnv's deferred runtime assertions. + shape_env._equality_graph = defaultdict(OrderedSet) + for assertions in shape_env.deferred_runtime_asserts.values(): + for assertion in assertions: + if not isinstance(assertion.expr, sympy.Equality): + continue + lhs = sympy.sympify(assertion.expr.lhs) + rhs = sympy.sympify(assertion.expr.rhs) + shape_env._equality_graph[lhs].add(rhs) + shape_env._equality_graph[rhs].add(lhs) + + uf = CanonicalExprFinder(shape_env._equality_graph) + + shape_env._unbacked_replacements = {} + for expr in shape_env._equality_graph: + canonical_expr = uf.find_expr(expr) + if expr != canonical_expr: + shape_env._unbacked_replacements[expr] = canonical_expr + + return shape_env._unbacked_replacements + + +def _sub_unbacked_exprs(shape_env: ShapeEnv, expr: sympy.Expr) -> sympy.Expr: + """Substitute unbacked expressions with canonical equivalents. + Used by optimization_hint to maximize consistency when hinting unbacked symbols.""" + replacements = _get_unbacked_replacements(shape_env) + + # consider making this threshold configurable + sub_cnt_limit = 30 + sub_cnt = 0 + while sub_cnt < sub_cnt_limit: + new_expr = expr.subs(replacements) + if new_expr == expr: + break + if len(new_expr.free_symbols) <= SYMPY_FACTOR_MAX_FREE_SYMBOLS: + expr = sympy.factor(new_expr) + else: + expr = new_expr + sub_cnt += 1 + else: + log.warning("Substitution limit (%d) reached w/ %s", sub_cnt_limit, expr) + + expr = _sympy_subs(expr, shape_env.backed_var_to_val) + expr = _sympy_subs(expr, shape_env.var_to_hint_override) + return expr + + +def _optimization_hint_base( + shape_env: ShapeEnv, + expr: sympy.Expr | int, + precomputed_replacements: dict[sympy.Expr, sympy.Symbol], + fallback: int | None = None, +) -> int: + """ + Return a concrete integer hint for an expression using heuristics. + + This function should be used for non-guarding based optimizations. + It will hint unbacked symbols using user provided optimization hints. + If not provided, fallback will be used along with some heuristics + that try to maximize consistency with the shape environment. + + Args: + shape_env: The ShapeEnv instance. + expr: A sympy expression or integer to evaluate. + precomputed_replacements: Precomputed replacements for PRECOMPUTED_SIZE symbols. + fallback: Fallback value for unbacked symbols. If None, reads from config. + + Returns: + A concrete integer hint for the expression. + """ + from torch.fx.experimental.symbolic_shapes import ( + has_free_unbacked_symbols, + symbol_is_type, + SymT, + ) + + # Read config at call time to respect runtime patches (e.g., in tests) + if fallback is None: + from torch._inductor.config import unbacked_symint_fallback + + fallback = unbacked_symint_fallback + + # to have expanded (Identity free) expr stored in original + if isinstance(expr, sympy.Expr): + expr = expr.expand(identity=True) + + original = expr + # sympy.expand() doesn't work with boolean expressions like Or/And + if isinstance(expr, sympy.Expr): + expr = expr.xreplace(shape_env.replacements) + else: + expr = sympy.sympify(expr).xreplace(shape_env.replacements) + + result = _maybe_realize_expr(expr, fallback) + if result is not None: + return result + + if isinstance(expr, sympy.Expr): + expr = expr.expand(identity=True) + + # Replace backed symbols with their hints, leaving unbacked symbols alone. + result = _maybe_realize_expr(expr, None) + if result is not None: + return result + + if not isinstance(expr, sympy.Expr): + raise RuntimeError("isinstance(expr, sympy.Expr)", expr) + + if any(symbol_is_type(s, SymT.PRECOMPUTED_SIZE) for s in expr.free_symbols): # type: ignore[attr-defined] + expr = _sympy_subs(expr, precomputed_replacements) + + expr = _sympy_subs(expr, shape_env.backed_var_to_val) + if isinstance(expr, sympy.Expr): + expr = expr.expand(identity=True) + + result = _maybe_realize_expr(expr, fallback) + if result is not None: + return result + + expr = _sympy_subs(expr, shape_env.var_to_hint_override) + + result = _maybe_realize_expr(expr, fallback) + if result is not None: + return result + + # If unbacked symbols remain, try to substitute them using heuristics + # that maximize consistency with the shape environment. + if has_free_unbacked_symbols(expr): + # Make sure to substitute with the factored version + # e.g. 10*(s0 + u0) instead of 10*s0 + 10*u0 + if ( + isinstance(original, sympy.Expr) + and len(original.free_symbols) <= SYMPY_FACTOR_MAX_FREE_SYMBOLS + ): + original = sympy.factor(original) + expr = _sub_unbacked_exprs(shape_env, original) + + # For multiple expressions that depend on an unbacked symint, + # we want to compute them consistently for a size hint we have chosen. + # So, recursively compute expressions via size hints of contained symbols. + # For example: u1 * u2 - 10 ==> fallback * fallback - 10 + + if not isinstance(expr, sympy.Expr): + raise RuntimeError(f"Expected sympy Expr, got {type(expr)}: {expr}") + free_symbols = expr.free_symbols + + # Constrain fallback per-symbol based on var_to_range bounds + size_dict = {} + for s in free_symbols: + sym_fallback = fallback + vr = shape_env.var_to_range.get(s, None) + if vr is not None: + if isinstance(vr.lower, (int, sympy.Integer)): + sym_fallback = max(sym_fallback, int(vr.lower)) + if isinstance(vr.upper, (int, sympy.Integer)): + sym_fallback = min(sym_fallback, int(vr.upper)) + size_dict[s] = sym_fallback + + try: + final_result = expr.subs(size_dict) + except ZeroDivisionError: + # Expressions like ModularIndexing(x, u1, 4) crash during subs() + # when u1 is substituted with 0, because sympy eagerly evaluates + # (x // 0) % 4. This can happen when an unbacked symbol with + # var_to_range lower=0 is used as a divisor (e.g. from + # _dynamic_reshape_indexer) and the fallback also maps to 0. + # Return fallback in that case. + return fallback if fallback is not None else 0 + + final_result = _maybe_realize_expr(final_result, fallback) + if final_result is None: + raise RuntimeError(f"Failed to realize expression to int: {expr}") + + return final_result diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index 09fc60be54b80..91cab81e9c65c 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import operator from collections import deque from typing import NamedTuple @@ -73,7 +72,7 @@ class PartitionResult(NamedTuple): """Followings are some helper functions for partition manipulation""" -def reset_partition_device(partitions): +def reset_partition_device(partitions: list[Partition]) -> None: for partition in partitions: partition.logical_device_ids = [] @@ -214,14 +213,14 @@ def get_device_partition_stats( def get_device_to_partitions_mapping( partitions: list[Partition], devices: list[Device] -): +) -> bool: """Given a list of partitions and a list of devices, map each partition into a device. """ def calculate_extra_mem_bytes_needed_for( partition: Partition, partitions: list[Partition] - ): + ) -> int: all_nodes: set[Node] = set() for p in partitions: all_nodes = all_nodes.union(p.nodes) @@ -233,7 +232,7 @@ def calculate_extra_mem_bytes_needed_for( extra_size_needed += get_extra_size_of(node, all_nodes) return extra_size_needed - def find_device_for(partition: Partition): + def find_device_for(partition: Partition) -> bool: """Given a partition, find a logical device for the partition The algorithm is to put the partition on the device that has just enough mem left for that partition. @@ -269,7 +268,7 @@ def find_device_for(partition: Partition): return found_device -def check_dependency(partition): +def check_dependency(partition: Partition) -> bool: """Given a partition,check if there is a circular dependency on this partition using bfs """ @@ -385,7 +384,7 @@ def partition_graph( return ret def find_single_partition( - self, total_size_of_graph, logical_device_id: int = 0 + self, total_size_of_graph: int, logical_device_id: int = 0 ) -> None: """Fit the whole fx module into one device""" partition_0 = self.create_partition() @@ -416,7 +415,7 @@ def size_based_partition(self) -> None: and then try to map those partitions into logical devices with enough mem left. """ - def find_device_based_on_size(node) -> Device: + def find_device_based_on_size(node: Node) -> Device: """Given a node, this function is to find a logical device that could fit the node. """ @@ -610,7 +609,7 @@ def create_partition(self) -> Partition: self.partitions.append(partition) return partition - def create_single_node_partition(self, node): + def create_single_node_partition(self, node: Node) -> None: """Create a partition for a single node""" partition = self.create_partition() partition.add_node(node) @@ -664,7 +663,7 @@ def combine_partitions_based_on_size( ) return - def calculate_mem_bytes_needed(p1, p2): + def calculate_mem_bytes_needed(p1: Partition, p2: Partition) -> int: """Given two partitions, calculate how many mem bytes are needed if two partitions are combined """ @@ -695,20 +694,23 @@ def find_partition_to_combine_based_on_size( break return find_combination, partitions - def reset_partition_in_sparse_nn(partition, new_partition=True): - """If crossing the boundary between non-embedding nodes and - embedding nodes, create a new partition - """ + def reset_partition_in_sparse_nn(partition: Partition) -> Partition: + """Finalize current partition and create a new one.""" + if in_embedding_region: + embedding_partitions.append(partition) + else: + non_embedding_partitions.append(partition) + partition = self.create_partition() + # pyrefly: ignore [missing-attribute] + partition.left_mem_bytes = available_mem_bytes + return partition + + def finalize_partition(partition: Partition) -> None: + """Finalize current partition without creating a new one.""" if in_embedding_region: embedding_partitions.append(partition) else: non_embedding_partitions.append(partition) - if new_partition: - partition = self.create_partition() - # pyrefly: ignore [missing-attribute] - partition.left_mem_bytes = available_mem_bytes - return partition - return None def is_embedding_node(node: Node) -> bool: """Check if a node is an embedding node""" @@ -752,7 +754,7 @@ def is_embedding_node(node: Node) -> bool: node.target + "is too large to fit into a device" ) partition.add_node(node) - reset_partition_in_sparse_nn(partition, new_partition=False) + finalize_partition(partition) # Set parents and children for partitions set_parents_and_children(self.partitions) # Combining non-embedding partitions @@ -816,7 +818,9 @@ def cost_aware_partition( #3. Repeat #2 until the cost cannot be reduced. """ - def try_combining_partitions(p0_index, p1_index, partitions) -> float: + def try_combining_partitions( + p0_index: int, p1_index: int, partitions: list[Partition] + ) -> float: """Given two partitions and a list of partitions, combine these two partitions and see what is the cost of the modified partition list """ @@ -855,7 +859,8 @@ def try_combining_partitions(p0_index, p1_index, partitions) -> float: return float("inf") def search_combination( - transfer_rate_bytes_per_sec, node_to_latency_mapping + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: dict[Node, NodeLatency], ) -> bool: """Given transfer rate between partitions and each node's latency, find two partitions to combine so the cost of the partitions can @@ -940,7 +945,9 @@ def kl_based_partition( are tried. """ - def swap_nodes(n0, n1, p0, p1): + def swap_nodes( + n0: Node | None, n1: Node | None, p0: Partition, p1: Partition + ) -> None: # Either n0 or n1 could be None # That means we simply move the node # to another partition @@ -952,8 +959,13 @@ def swap_nodes(n0, n1, p0, p1): p1.remove_node(n1) def try_swap_nodes( - n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec - ): + n0: Node | None, + n1: Node | None, + p0: Partition, + p1: Partition, + node_to_latency_mapping: dict[Node, NodeLatency], + transfer_rate_per_sec: float, + ) -> float: cost = float("inf") swap_nodes(n0, n1, p0, p1) # Reorganize partitions after swapping @@ -984,8 +996,12 @@ def try_swap_nodes( return cost def swap_node_to_partition( - node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec - ): + node: Node, + p0: Partition, + p1: Partition, + node_to_latency_mapping: dict[Node, NodeLatency], + transfer_rate_per_sec: float, + ) -> tuple[float, list[Node]]: """This function helps to swap one node from partition p0 with all the nodes in another partition p1 """ @@ -1060,8 +1076,10 @@ def swap_node_to_partition( return def aot_based_partition( - self, node_to_partition_mapping, partition_to_logical_device_mapping - ): + self, + node_to_partition_mapping: dict[Node, int], + partition_to_logical_device_mapping: dict[int, list[int]], + ) -> None: """This function helps to rebuild the partitions given the nodes and its corresponding partition id """ diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index f44293e2242a7..a9ceb336376e8 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -1,6 +1,6 @@ -# mypy: allow-untyped-defs import re from collections.abc import Callable +from typing import Any import torch.fx from torch.fx.node import map_arg @@ -31,7 +31,7 @@ def __init__( const_subgraph: torch.fx.Graph | None = None, fx_const_folded_attrs_name: str | None = None, device_for_folded_attrs: str = "cuda", - ): + ) -> None: super().__init__(root, graph) self.const_subgraph_module = ( None @@ -42,12 +42,12 @@ def __init__( self.fx_const_folded_attrs_name = fx_const_folded_attrs_name self.device_for_folded_attrs = device_for_folded_attrs - def __call__(self, *args, **kwargs): + def __call__(self, *args: object, **kwargs: object) -> Any: if not self.has_folding_been_run: self.run_folding() return super().__call__(*args) - def run_folding(self): + def run_folding(self) -> None: # If there's no const subgraph module or attr output names to use, return # early as there is no const folding to perform. if ( @@ -65,7 +65,7 @@ def run_folding(self): # Tuple[Tensor,]. folded_attrs = self.const_subgraph_module() - def _create_param(i): + def _create_param(i: torch.Tensor | int) -> torch.nn.Parameter: return torch.nn.Parameter( i.detach().clone() if not isinstance(i, int) @@ -110,7 +110,7 @@ def _inline_module( replacement_mapping: dict[torch.fx.Node, torch.fx.Node] = {} ph_count = 0 - def replacement_fn(node): + def replacement_fn(node: torch.fx.Node) -> torch.fx.Node: new_node = replacement_mapping[node] new_node.meta = node.meta.copy() return new_node @@ -281,21 +281,34 @@ def _subgraph_has_impure_ops(module: torch.fx.GraphModule) -> bool: # Partition the module into two: submod_0 for constant folding subgraph, and # submod_1 for the rest. - def mod_partition(node: torch.fx.Node): + def mod_partition(node: torch.fx.Node) -> int: return 0 if node in const_nodes else 1 split = split_module(mod_traced, module, mod_partition) const_mod_name, non_const_mod_name = "submod_0", "submod_1" # Safely get submod_1 in case there are no non-const nodes - const_gm, non_const_gm = split.submod_0, getattr(split, non_const_mod_name, None) + const_gm = getattr(split, const_mod_name) + if not isinstance(const_gm, torch.fx.GraphModule): + raise AssertionError( + f"Expected GraphModule for {const_mod_name}, got {type(const_gm)}" + ) + non_const_mod = getattr(split, non_const_mod_name, None) + non_const_gm: torch.fx.GraphModule | None = None + if non_const_mod is not None: + if not isinstance(non_const_mod, torch.fx.GraphModule): + raise AssertionError( + f"Expected GraphModule for {non_const_mod_name}, got {type(non_const_mod)}" + ) + non_const_gm = non_const_mod # The module that a call_module node refers to gets copied to submodules during split. # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to # attach inlined modules to `split` as it's the owning module now. - for node in non_const_gm.graph.nodes if non_const_gm else []: - if node.op == "call_module": - setattr(split, node.target, getattr(non_const_gm, node.target)) + if non_const_gm is not None: + for node in non_const_gm.graph.nodes: + if node.op == "call_module": + setattr(split, node.target, getattr(non_const_gm, node.target)) for node in const_gm.graph.nodes: if node.op == "call_module": setattr(split, node.target, getattr(const_gm, node.target)) diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index b100910b3cc55..9ef96c78537e7 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -1,9 +1,8 @@ -# mypy: allow-untyped-defs import itertools import operator from collections.abc import Callable from functools import reduce -from typing import TypeVar +from typing import Any, TypeVar from typing_extensions import ParamSpec import sympy @@ -20,9 +19,9 @@ _T = TypeVar("_T") _P = ParamSpec("_P") -_INFERENCE_RULES: dict[Target, Callable] = {} -_REFINEMENT_RULES: dict[Target, Callable] = {} -_RULES: dict[Target, Callable] = {} +_INFERENCE_RULES: dict[Target, Callable[..., Any]] = {} +_REFINEMENT_RULES: dict[Target, Callable[..., Any]] = {} +_RULES: dict[Target, Callable[..., Any]] = {} __all__ = [ "GraphTypeChecker", @@ -60,7 +59,8 @@ ] -def expand_to_tensor_dim(t, n): +# TODO: narrow t to TensorType | _DynType once Node.type is narrowed +def expand_to_tensor_dim(t: Any, n: int) -> TensorType: """ Expand a type to the desired tensor dimension if possible Raise an error otherwise. @@ -80,7 +80,7 @@ def expand_to_tensor_dim(t, n): raise TypeError(f"Cannot match the type {t}") -def broadcast_types(t1, t2): +def broadcast_types(t1: Any, t2: Any) -> tuple[Any, Any]: """ Applies broadcasting to both given types such that they become consistent with each other and returns two new @@ -163,7 +163,7 @@ def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: @register_inference_rule(torch.add) @register_inference_rule(operator.add) -def add_inference_rule(n: Node): +def add_inference_rule(n: Node) -> Any: """ Apply the addition inference rule. This includes: - scalar addition @@ -228,7 +228,7 @@ def add_inference_rule(n: Node): @register_inference_rule(getattr) -def get_attr_inference_rule(n: Node, traced): +def get_attr_inference_rule(n: Node, traced: Any) -> Any: """ The current getattr rule only handles the shape attribute Can be extended to other attributes @@ -247,7 +247,7 @@ def get_attr_inference_rule(n: Node, traced): @register_inference_rule(torch.transpose) -def transpose_inference_rule(n: Node): +def transpose_inference_rule(n: Node) -> Any: """ We check that dimensions for the transpose operations are within range of the tensor type of the node @@ -285,7 +285,7 @@ def transpose_inference_rule(n: Node): @register_inference_rule(torch.reshape) -def reshape_inference_rule(n: Node): +def reshape_inference_rule(n: Node) -> TensorType: """ Without dynamism, the rule checks that the product of the elements of the argument tensor @@ -328,7 +328,7 @@ def reshape_inference_rule(n: Node): @register_inference_rule(BatchNorm2d) -def bn2d_inference_rule(n: Node, module_instance): +def bn2d_inference_rule(n: Node, module_instance: Any) -> Any: """ Given a BatchNorm2D instance and a node check the following conditions: - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, x_3, x_4) @@ -364,7 +364,7 @@ def bn2d_inference_rule(n: Node, module_instance): ) -def calculate_out_dimension(d_in, module_instance, index): +def calculate_out_dimension(d_in: Any, module_instance: Any, index: int) -> Any: """ For calculating h_in and w_out according to the conv2D documentation """ @@ -405,7 +405,8 @@ def calculate_out_dimension(d_in, module_instance, index): ) -def get_greatest_upper_bound(type1, type2): +# TODO: narrow params/return to TensorType | _DynType once Node.type is narrowed +def get_greatest_upper_bound(type1: Any, type2: Any) -> Any: """ Get the most precise type that's consistent with the given types """ @@ -424,7 +425,7 @@ def get_greatest_upper_bound(type1, type2): @register_inference_rule(Conv2d) -def conv2d_inference_rule(n: Node, module_instance): +def conv2d_inference_rule(n: Node, module_instance: Any) -> Any: """ Given a Conv2D instance and a node check the following conditions: - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, H, W) @@ -457,7 +458,7 @@ def conv2d_inference_rule(n: Node, module_instance): @register_inference_rule(torch.nn.ReLU) -def relu_inference_rule(n: Node, module_instance): +def relu_inference_rule(n: Node, module_instance: Any) -> Any: """ Input and output shapes should be equal. """ @@ -472,7 +473,7 @@ def relu_inference_rule(n: Node, module_instance): return n.type -def maxpool2d_check(typ, module_instance): +def maxpool2d_check(typ: Any, module_instance: Any) -> TensorType: """ Applies the maxpool2d shape information to the input this affects the last two dimensions @@ -494,7 +495,7 @@ def maxpool2d_check(typ, module_instance): @register_inference_rule(torch.nn.MaxPool2d) -def maxpool2d_inference_rule(n: Node, module_instance): +def maxpool2d_inference_rule(n: Node, module_instance: Any) -> Any: """ Given a MaxPool2D instance and a node check the following conditions: - Input size matches size 3 or 4 @@ -515,7 +516,7 @@ def maxpool2d_inference_rule(n: Node, module_instance): return n.type -def linear_check(tensor_type, module_instance): +def linear_check(tensor_type: Any, module_instance: Any) -> TensorType: """ Checks that an input tensor type satisfies the conditions for linear operation and returns the output type based on in and out features given by module_instance @@ -534,7 +535,7 @@ def linear_check(tensor_type, module_instance): @register_inference_rule(torch.nn.Linear) -def linear_inference_rule(n: Node, module_instance): +def linear_inference_rule(n: Node, module_instance: Any) -> Any: """ Applies the shape information to the input then gets the greatest upper bound of the resulting type and the existing type @@ -549,7 +550,7 @@ def linear_inference_rule(n: Node, module_instance): return n.type -def adaptiveavgpool2d_check(tensor_type, module_instance): +def adaptiveavgpool2d_check(tensor_type: Any, module_instance: Any) -> TensorType: output_size = module_instance.output_size if isinstance(output_size, int): output_size = [output_size, output_size] @@ -573,7 +574,7 @@ def adaptiveavgpool2d_check(tensor_type, module_instance): @register_inference_rule(torch.nn.AdaptiveAvgPool2d) -def adaptiveavgpool2d_inference_rule(n: Node, module_instance): +def adaptiveavgpool2d_inference_rule(n: Node, module_instance: Any) -> Any: """ The input and output sizes should be the same except for the last two dimensions taken from the input, which represent width and height @@ -588,7 +589,7 @@ def adaptiveavgpool2d_inference_rule(n: Node, module_instance): return n.type -def flatten_check(tensor_type, start_dim, end_dim): +def flatten_check(tensor_type: Any, start_dim: int, end_dim: int) -> TensorType: l = len(tensor_type.__args__) start_dim = l if start_dim == -1 else abs(start_dim) @@ -612,7 +613,7 @@ def flatten_check(tensor_type, start_dim, end_dim): @register_inference_rule(torch.flatten) -def flatten_inference_rule(n: Node): +def flatten_inference_rule(n: Node) -> Any: """ Applies the flatten shape information to the input then gets the greatest upper bound of the resulting type and the existing type @@ -645,11 +646,11 @@ def flatten_inference_rule(n: Node): class GraphTypeChecker: - def __init__(self, env, traced): + def __init__(self, env: dict[str, Any], traced: torch.fx.GraphModule) -> None: self.env = env self.traced = traced - def type_check(self): + def type_check(self) -> bool: """ A gradual type checker for graphs Effect: every node's field type will be @@ -663,7 +664,7 @@ def type_check(self): self.type_check_node(n) return True - def type_check_node(self, n: Node): + def type_check_node(self, n: Node) -> Any: """ Type check a given fx node. Current operations: @@ -704,6 +705,7 @@ def type_check_node(self, n: Node): ) elif n.op == "call_module": + # pyrefly: ignore[bad-argument-type] module_instance = self.traced.get_submodule(n.target) if type(module_instance) in _INFERENCE_RULES: return _INFERENCE_RULES[type(module_instance)](n, module_instance) @@ -714,7 +716,7 @@ def type_check_node(self, n: Node): elif n.op == "output": - def get_node_type(a): + def get_node_type(a: Any) -> Any: return a.type n.type = torch.fx.node.map_arg(n.args[0], get_node_type) @@ -725,12 +727,12 @@ def get_node_type(a): @register_refinement_rule(Conv2d) -def conv_refinement_rule(n: Node): +def conv_refinement_rule(n: Node) -> list[Any] | None: """ The equality constraints are between the first dimension of the input and output """ - res = [] + res: list[Any] = [] if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") arg_type = n.args[0].type @@ -740,12 +742,12 @@ def conv_refinement_rule(n: Node): @register_refinement_rule(torch.nn.Linear) -def linear_refinement_rule(n: Node): +def linear_refinement_rule(n: Node) -> list[Any]: """ The equality constraints are between the first dimension of the input and output """ - res = [] + res: list[Any] = [] if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") arg_type = n.args[0].type @@ -756,11 +758,11 @@ def linear_refinement_rule(n: Node): @register_refinement_rule(BatchNorm2d) @register_refinement_rule(torch.nn.ReLU) -def all_eq(n: Node): +def all_eq(n: Node) -> list[Any]: """ For operations where the input shape is equal to the output shape """ - res = [] + res: list[Any] = [] if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") arg_type = n.args[0].type @@ -773,12 +775,12 @@ def all_eq(n: Node): @register_refinement_rule(torch.nn.AdaptiveAvgPool2d) @register_refinement_rule(torch.nn.MaxPool2d) -def first_two_eq(n: Node): +def first_two_eq(n: Node) -> list[Any]: """ For operations where the first two dimensions of the input and output shape are equal """ - res = [] + res: list[Any] = [] if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") arg_type = n.args[0].type @@ -791,7 +793,7 @@ def first_two_eq(n: Node): @register_refinement_rule(torch.add) @register_refinement_rule(operator.add) -def element_wise_eq(n: Node): +def element_wise_eq(n: Node) -> list[Any]: """ For element-wise operations and handles broadcasting. Note that after applying broadcasting to the arguments @@ -806,7 +808,7 @@ def element_wise_eq(n: Node): including unification) and another iteration to establish equality between the operands and the resulting type, requiring another round of constraint generation and unificaiton. """ - res = [] + res: list[Any] = [] if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): arg_type1 = n.args[0].type arg_type2 = n.args[1].type @@ -832,7 +834,7 @@ def element_wise_eq(n: Node): @register_refinement_rule(torch.flatten) -def flatten_refinement_rule(n: Node): +def flatten_refinement_rule(n: Node) -> list[Any]: """ Generates equality constraints between the dimensions of the input and output that will not be involved in the flatten operation @@ -870,7 +872,7 @@ def flatten_refinement_rule(n: Node): @register_algebraic_expressions_inference_rule(Conv2d) -def conv_rule(n: Node, module_instance): +def conv_rule(n: Node, module_instance: Any) -> TensorType | None: """ Represents the output in terms of an algrbraic expression w.r.t the input when possible @@ -895,12 +897,12 @@ class Refine: Currently all constraints are equality constraints. """ - def __init__(self, traced): + def __init__(self, traced: Any) -> None: self.constraints = [] self.traced = traced self.symbol_iter = itertools.count(start=0, step=1) - def refine(self): + def refine(self) -> bool: """ Generates constraints for every node in the graph based on @@ -911,7 +913,7 @@ def refine(self): self.refine_node(n) return True - def symbolic_relations(self): + def symbolic_relations(self) -> bool: """ Infers algebraic relations """ @@ -920,7 +922,7 @@ def symbolic_relations(self): self.infer_symbolic_relations(n) return True - def replace_dyn_with_fresh_var(self, typ): + def replace_dyn_with_fresh_var(self, typ: Any) -> Any: """ Replace all unknown types with fresh type variables. """ @@ -937,7 +939,7 @@ def replace_dyn_with_fresh_var(self, typ): else: return typ - def convert_to_sympy_symbols(self, typ): + def convert_to_sympy_symbols(self, typ: Any) -> Any: """ Replace all unknown types with fresh type variables. """ @@ -953,7 +955,7 @@ def convert_to_sympy_symbols(self, typ): else: return typ - def refine_node(self, n: Node): + def refine_node(self, n: Node) -> Any: """ Returns a list of equality constraints for call_module and call_function nodes. @@ -977,13 +979,13 @@ def refine_node(self, n: Node): if n.op == "output": - def get_node_type(a): + def get_node_type(a: Any) -> Any: return a.type n.type = torch.fx.node.map_arg(n.args[0], get_node_type) return n.type - def infer_symbolic_relations(self, n: Node): + def infer_symbolic_relations(self, n: Node) -> Any: n.type = self.convert_to_sympy_symbols(n.type) if n.op == "call_function": if n.target in _RULES: @@ -996,14 +998,14 @@ def infer_symbolic_relations(self, n: Node): if n.op == "output": - def get_node_type(a): + def get_node_type(a: Any) -> Any: return a.type n.type = torch.fx.node.map_arg(n.args[0], get_node_type) return n.type -def get_parameter(traced, target: str): +def get_parameter(traced: Any, target: str) -> torch.nn.Parameter: """ Returns the parameter given by ``target`` if it exists, otherwise throws an error. diff --git a/torch/fx/experimental/merge_matmul.py b/torch/fx/experimental/merge_matmul.py index 3c200a2a206f9..94c63a917783c 100644 --- a/torch/fx/experimental/merge_matmul.py +++ b/torch/fx/experimental/merge_matmul.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import itertools import operator @@ -33,7 +32,7 @@ def split_result_tensors( return torch.split(result, splits) -def may_depend_on(a: Node, b: Node, search_depth: int = 6): +def may_depend_on(a: Node, b: Node, search_depth: int = 6) -> bool: """ Determine if one node depends on another in a torch.fx.Graph. @@ -70,7 +69,7 @@ def may_depend_on(a: Node, b: Node, search_depth: int = 6): return False -def are_nodes_independent(nodes: list[Node]): +def are_nodes_independent(nodes: list[Node]) -> bool: """ Check if all of the given nodes are pairwise-data independent. @@ -88,7 +87,7 @@ def are_nodes_independent(nodes: list[Node]): return True -def merge_matmul(in_mod: torch.nn.Module): +def merge_matmul(in_mod: torch.nn.Module) -> torch.fx.GraphModule: """ A graph transformation that merges matrix multiplication operations that share the same right-hand side operand into one large matrix multiplication. diff --git a/torch/fx/experimental/meta_tracer.py b/torch/fx/experimental/meta_tracer.py index 472ed0860b3de..42672357eb6c5 100644 --- a/torch/fx/experimental/meta_tracer.py +++ b/torch/fx/experimental/meta_tracer.py @@ -1,31 +1,55 @@ -# mypy: allow-untyped-defs import builtins import functools import warnings from collections.abc import Callable -from typing import Any +from typing import Any, TypeVar import torch import torch.fx - - -def embedding_override(self, input): +from torch.fx.node import Node +from torch.fx.proxy import Proxy + + +_C = TypeVar("_C", bound=Callable[..., Any]) + +__all__ = [ + "embedding_override", + "functional_relu_override", + "gen_constructor_wrapper", + "manual_meta_overrides", + "MetaAttribute", + "MetaDeviceAttribute", + "MetaProxy", + "MetaTracer", + "nn_layernorm_override", + "proxys_to_metas", + "symbolic_trace", + "torch_abs_override", + "torch_nn_relu_override", + "torch_relu_override", + "torch_where_override", +] + + +def embedding_override(self: torch.nn.Embedding, input: torch.Tensor) -> torch.Tensor: return torch.empty(*input.shape, self.weight.shape[-1], device="meta") -def nn_layernorm_override(self, input): +def nn_layernorm_override( + self: torch.nn.LayerNorm, input: torch.Tensor +) -> torch.Tensor: return input -def torch_relu_override(x): +def torch_relu_override(x: torch.Tensor) -> torch.Tensor: return x -def torch_nn_relu_override(self, x): +def torch_nn_relu_override(self: torch.nn.ReLU, x: torch.Tensor) -> torch.Tensor: return x -def functional_relu_override(x, inplace=False): +def functional_relu_override(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: if inplace: raise AssertionError( "dont support inplace functional.relu for metatensor analysis" @@ -33,19 +57,23 @@ def functional_relu_override(x, inplace=False): return x -def torch_where_override(condition, x, y): +def torch_where_override( + condition: torch.Tensor, x: torch.Tensor, y: torch.Tensor +) -> torch.Tensor: # torch.where returns the broadcasted tensor of condition, x, and y, # so hack it by using addition return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta") -def torch_abs_override(input, *, out=None): +def torch_abs_override( + input: torch.Tensor, *, out: torch.Tensor | None = None +) -> torch.Tensor: if out is not None: raise AssertionError("Dont support in-place abs for MetaTensor analysis") return input -manual_meta_overrides: dict[Callable, Callable] = { +manual_meta_overrides: dict[Callable[..., Any], Callable[..., Any]] = { torch.nn.Embedding: embedding_override, torch.nn.LayerNorm: nn_layernorm_override, torch.relu: torch_relu_override, @@ -56,12 +84,14 @@ def torch_abs_override(input, *, out=None): } -def gen_constructor_wrapper(target): +def gen_constructor_wrapper( + target: _C, +) -> tuple[Callable[..., Any], _C]: @functools.wraps(target) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: proxy = None - def check_has_proxy(v): + def check_has_proxy(v: Any) -> None: if isinstance(v, torch.fx.Proxy): nonlocal proxy proxy = v @@ -78,23 +108,23 @@ def check_has_proxy(v): class MetaProxy(torch.fx.Proxy): - def install_tensor_meta(self, tensor_meta): + def install_tensor_meta(self, tensor_meta: torch.Tensor) -> None: self._tensor_meta = tensor_meta - def size(self, dim=None): + def size(self, dim: int | None = None) -> Any: if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: return self._tensor_meta.size(*[dim] if dim else []) return self.tracer.create_proxy( "call_method", "size", (self, dim) if dim else (self,), {} ) - def dim(self): + def dim(self) -> Any: if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: return self._tensor_meta.dim() return self.tracer.create_proxy("call_method", "dim", (self,), {}) @property - def shape(self): + def shape(self) -> Any: if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: return self._tensor_meta.shape return self.tracer.create_proxy( @@ -102,7 +132,7 @@ def shape(self): ) @property - def dtype(self): + def dtype(self) -> Any: if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: return self._tensor_meta.dtype return self.tracer.create_proxy( @@ -110,12 +140,12 @@ def dtype(self): ) @property - def device(self): + def device(self) -> "MetaDeviceAttribute": # Hack so we can track when devices are used. During meta-tensor propagation, # replace these values with a constant 'meta' return MetaDeviceAttribute(self, "device") - def __getattr__(self, k): + def __getattr__(self, k: str) -> Any: if k == "_tensor_meta": return self.__getattribute__(k) # note: not added to the graph yet, if this is a method call @@ -124,7 +154,7 @@ def __getattr__(self, k): class MetaAttribute(MetaProxy): - def __init__(self, root, attr: str): + def __init__(self, root: MetaProxy, attr: str) -> None: self.root = root self.attr = attr self.tracer = root.tracer @@ -140,7 +170,7 @@ def node(self): # type: ignore[override] ).node return self._node - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.tracer.create_proxy( "call_method", self.attr, (self.root,) + args, kwargs ) @@ -150,7 +180,7 @@ class MetaDeviceAttribute(MetaAttribute): pass -def proxys_to_metas(v): +def proxys_to_metas(v: Any) -> Any: if isinstance(v, MetaDeviceAttribute): return "meta" if isinstance(v, torch.fx.Proxy): @@ -169,14 +199,14 @@ class MetaTracer(torch.fx.Tracer): def create_proxy( self, - kind, - target, - args, - kwargs, - name=None, - type_expr=None, - proxy_factory_fn=None, - ): + kind: str, + target: torch.fx.node.Target, + args: tuple[Any, ...], + kwargs: dict[str, Any], + name: str | None = None, + type_expr: Any = None, + proxy_factory_fn: Callable[[Node], Proxy] | None = None, + ) -> MetaProxy: rv = super().create_proxy( kind, target, @@ -190,7 +220,7 @@ def create_proxy( if kind == "placeholder" and target in self.meta_args: rv.install_tensor_meta(self.meta_args[target]) - return rv + return rv # pyrefly: ignore [bad-return] if target in self.orig_fns: # NOTE: tensor constructors in PyTorch define the `device` argument as @@ -206,6 +236,7 @@ def create_proxy( kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas) if kind == "call_function": + # pyrefly: ignore [no-matching-overload] meta_target = manual_meta_overrides.get(target, target) meta_out = meta_target(*args_metas, **kwargs_metas) @@ -217,6 +248,7 @@ def create_proxy( raise AssertionError("orig_forward not set for call_module") self._disable_module_getattr = True try: + # pyrefly: ignore [bad-argument-type] mod = self.root.get_submodule(target) mod_type = type(mod) if mod_type in manual_meta_overrides: @@ -231,7 +263,7 @@ def create_proxy( self._disable_module_getattr = True try: attr_itr = self.root - atoms = target.split(".") + atoms = target.split(".") # pyrefly: ignore [missing-attribute] for atom in atoms: attr_itr = getattr(attr_itr, atom) if not isinstance(attr_itr, torch.Tensor): @@ -240,7 +272,7 @@ def create_proxy( finally: self._disable_module_getattr = False else: - return rv + return rv # pyrefly: ignore [bad-return] # TODO if not isinstance(rv, torch.fx.Proxy): @@ -249,15 +281,23 @@ def create_proxy( except Exception as e: warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}") - return rv + return rv # pyrefly: ignore [bad-return] - def getattr(self, attr, attr_val, parameter_proxy_cache): + def getattr( + self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Proxy] + ) -> Any: if getattr(self, "_disable_module_getattr", False): return attr_val else: return super().getattr(attr, attr_val, parameter_proxy_cache) - def call_module(self, m, forward, args, kwargs): + def call_module( + self, + m: torch.nn.Module, + forward: Callable[..., Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: self.orig_forward = forward return super().call_module(m, forward, args, kwargs) @@ -289,7 +329,7 @@ def path_of_module(self, mod: torch.nn.Module) -> str: return path raise - def proxy(self, node): + def proxy(self, node: torch.fx.Node) -> MetaProxy: return MetaProxy(node, self) def trace(self, root, meta_args: dict[str, torch.Tensor], concrete_args=None): # type: ignore[override] diff --git a/torch/fx/experimental/migrate_gradual_types/constraint.py b/torch/fx/experimental/migrate_gradual_types/constraint.py index ce724d39f228f..2a296203c52c0 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -1,4 +1,40 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeAlias + + +if TYPE_CHECKING: + from collections.abc import Sequence + +__all__ = [ + "ApplyBroadcasting", + "BinConstraintD", + "BinConstraintT", + "BinaryConstraint", + "BVar", + "CalcConv", + "CalcMaxPool", + "CalcProduct", + "CanReshape", + "Conj", + "Constraint", + "DGreatestUpperBound", + "Disj", + "DVar", + "F", + "GetItem", + "GetItemTensor", + "IndexSelect", + "Prod", + "T", + "TGreatestUpperBound", + "Transpose", + "TVar", + "is_algebraic_expression", + "is_bool_expr", + "is_dim", +] + from torch.fx.experimental.migrate_gradual_types.operation import ( op_add, op_div, @@ -10,7 +46,7 @@ op_neq, op_sub, ) -from torch.fx.tensor_type import Dyn, TensorType +from torch.fx.tensor_type import _DynType, Dyn, TensorType class Constraint: @@ -18,55 +54,53 @@ class Constraint: class Conj(Constraint): - def __init__(self, conjuncts): + def __init__(self, conjuncts: Sequence[Constraint]) -> None: """ :param conjuncts: Conjunction of constraints """ - self.conjucts = conjuncts + self.conjucts = list(conjuncts) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Conj): - return self.conjucts == other.conjucts and self.conjucts == other.conjucts + return self.conjucts == other.conjucts else: return False - def __repr__(self): + def __repr__(self) -> str: return f"And({self.conjucts})" class Disj(Constraint): - def __init__(self, disjuncts): + def __init__(self, disjuncts: Sequence[Constraint]) -> None: """ :param disjuncts: Disjunction of constraints """ - self.disjuncts = disjuncts + self.disjuncts = list(disjuncts) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Disj): - return ( - self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts - ) + return self.disjuncts == other.disjuncts else: return False - def __repr__(self): + def __repr__(self) -> str: return f"Or({self.disjuncts})" class Prod(Constraint): - def __init__(self, products): + def __init__(self, products: Sequence[DVar | int | _DynType]) -> None: """ :param products: lists of dimensions to multiply """ - self.products = products + self.products = list(products) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Prod): - return self.products == other.products and self.products == other.products + return self.products == other.products else: return False - def __repr__(self): + def __repr__(self) -> str: return f"Product({self.products})" @@ -78,10 +112,10 @@ class T(Constraint): def __init__(self) -> None: pass - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, T) - def __repr__(self): + def __repr__(self) -> str: return "True" @@ -93,10 +127,10 @@ class F(Constraint): def __init__(self) -> None: pass - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, F) - def __repr__(self): + def __repr__(self) -> str: return "False" @@ -105,7 +139,7 @@ class BinaryConstraint(Constraint): Represents all binary operations """ - def __init__(self, lhs, rhs, op): + def __init__(self, lhs: _Operand, rhs: _Operand, op: str | None) -> None: """ :param lhs: lhs of the constraint :param rhs: rhs of the constraint @@ -115,7 +149,7 @@ def __init__(self, lhs, rhs, op): self.rhs = rhs self.op = op - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, BinaryConstraint): return ( self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op @@ -123,7 +157,7 @@ def __eq__(self, other): else: return False - def __repr__(self): + def __repr__(self) -> str: return f"({self.lhs} {self.op} {self.rhs})" @@ -132,7 +166,7 @@ class BinConstraintT(BinaryConstraint): Binary constraints about tensors """ - def __init__(self, lhs, rhs, op): + def __init__(self, lhs: _Operand, rhs: _Operand, op: str | None) -> None: if not ( (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and (isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn) @@ -146,7 +180,7 @@ class BinConstraintD(BinaryConstraint): Binary constraints about dimensions """ - def __init__(self, lhs, rhs, op): + def __init__(self, lhs: _Operand, rhs: _Operand, op: str | None) -> None: if not (is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs)): raise AssertionError(f"Invalid lhs type: {type(lhs)}") if not (is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs)): @@ -160,7 +194,7 @@ class TGreatestUpperBound(Constraint): Greatest Upper bound for tensors with dynamic type """ - def __init__(self, res, rhs1, rhs2): + def __init__(self, res: TVar, rhs1: TVar, rhs2: TVar) -> None: """ :param res: tensor variable that stores the result of the output :param rhs1: tensor or tensor variable @@ -170,10 +204,10 @@ def __init__(self, res, rhs1, rhs2): self.rhs1 = rhs1 self.rhs2 = rhs2 - def __repr__(self): + def __repr__(self) -> str: return f"{self.res} = {self.rhs1}\u2294*{self.rhs2}" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, TGreatestUpperBound): return ( self.res == other.res @@ -189,7 +223,12 @@ class DGreatestUpperBound(Constraint): Greatest Upper bound for dimensions """ - def __init__(self, res, rhs1, rhs2): + def __init__( + self, + res: DVar | int | _DynType, + rhs1: DVar | int | _DynType, + rhs2: DVar | int | _DynType, + ) -> None: """ :param res: Dimension variable to store the result :param rhs1: dimension variable 1 @@ -206,10 +245,10 @@ def __init__(self, res, rhs1, rhs2): self.rhs1 = rhs1 self.rhs2 = rhs2 - def __repr__(self): + def __repr__(self) -> str: return f"{self.res} = {self.rhs1}\u2294{self.rhs2}" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, DGreatestUpperBound): return ( self.res == other.res @@ -225,7 +264,7 @@ class CanReshape(Constraint): can_reshape constraint """ - def __init__(self, src, target): + def __init__(self, src: TVar, target: TensorType) -> None: """ :param src: tensor variable :param target: tensor @@ -233,10 +272,10 @@ def __init__(self, src, target): self.src = src self.target = target - def __repr__(self): + def __repr__(self) -> str: return f"can-reshape({self.src}, {self.target})" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, CanReshape): return self.src == other.src and self.target == other.target else: @@ -244,7 +283,14 @@ def __eq__(self, other): class IndexSelect(Constraint): - def __init__(self, tensor_size, input_var, dim_replace, index, output): + def __init__( + self, + tensor_size: int, + input_var: TVar, + dim_replace: DVar | _DynType, + index: int, + output: TVar, + ) -> None: """ Args: input_var: input to index_select @@ -268,7 +314,7 @@ def __init__(self, tensor_size, input_var, dim_replace, index, output): self.index = index self.output = output - def __repr__(self): + def __repr__(self) -> str: return ( f" {self.output} = " f"IndexSelect({self.input_var}, " @@ -277,7 +323,7 @@ def __repr__(self): f"{self.index})" ) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, IndexSelect): return ( self.tensor_size == other.tensor_size @@ -291,7 +337,9 @@ def __eq__(self, other): class Transpose(Constraint): - def __init__(self, tensor_size, input_var, index1, index2, output): + def __init__( + self, tensor_size: int, input_var: TVar, index1: int, index2: int, output: TVar + ) -> None: """ Args: tensor_size: current tensor size @@ -315,7 +363,7 @@ def __init__(self, tensor_size, input_var, index1, index2, output): self.index2 = index2 self.output = output - def __repr__(self): + def __repr__(self) -> str: return ( f" {self.output} = " f"Transpose({self.input_var}, " @@ -324,7 +372,7 @@ def __repr__(self): f"{self.index2})" ) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Transpose): return ( self.tensor_size == other.tensor_size @@ -338,7 +386,9 @@ def __eq__(self, other): class GetItem(Constraint): - def __init__(self, tensor_size, index, res, input_var): + def __init__( + self, tensor_size: int, index: int, res: DVar, input_var: TVar + ) -> None: """ Constraint for getting item given a tensor size :param tensor_size: actual number @@ -354,10 +404,10 @@ def __init__(self, tensor_size, index, res, input_var): self.index = index self.input_var = input_var - def __repr__(self): + def __repr__(self) -> str: return f" {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, GetItem): return ( self.res == other.res @@ -370,7 +420,13 @@ def __eq__(self, other): class GetItemTensor(Constraint): - def __init__(self, tensor_size, index_tuple, res, input_var): + def __init__( + self, + tensor_size: int, + index_tuple: tuple[None | slice, ...], + res: TVar, + input_var: TVar, + ) -> None: """ Constraint for getting item given a tensor size However, when the argument is a tuple, we will @@ -388,10 +444,10 @@ def __init__(self, tensor_size, index_tuple, res, input_var): self.index_tuple = index_tuple self.input_var = input_var - def __repr__(self): + def __repr__(self) -> str: return f" {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, GetItemTensor): return ( self.res == other.res @@ -406,15 +462,15 @@ def __eq__(self, other): class CalcConv(Constraint): def __init__( self, - conv_result, - input_var, - c_out, - kernel, - padding, - stride, - dilation, - matching_constraint_vars, - ): + conv_result: TVar, + input_var: TVar, + c_out: int, + kernel: int | tuple[int, int], + padding: int | tuple[int, int], + stride: int | tuple[int, int], + dilation: int | tuple[int, int], + matching_constraint_vars: list[DVar], + ) -> None: """ :param conv_result: the convolution result :param input_var: input to convolution @@ -430,7 +486,7 @@ def __init__( self.dilation = dilation self.matching_constraint = matching_constraint_vars - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.conv_result} =" f" calc-conv({self.input_var}," @@ -439,7 +495,7 @@ def __repr__(self): f" {self.dilation})" ) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, CalcConv): return ( self.conv_result == other.conv_result @@ -458,14 +514,14 @@ def __eq__(self, other): class CalcMaxPool(Constraint): def __init__( self, - maxpool_result, - input_var, - kernel, - padding, - stride, - dilation, - matching_constraint_vars, - ): + maxpool_result: TVar, + input_var: TVar, + kernel: int | tuple[int, int], + padding: int | tuple[int, int], + stride: int | tuple[int, int], + dilation: int | tuple[int, int], + matching_constraint_vars: list[DVar], + ) -> None: """ :param maxpool_result: the result of maxpool :param input_var: input to convolution @@ -479,7 +535,7 @@ def __init__( self.dilation = dilation self.matching_constraint = matching_constraint_vars - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.maxpool_result} =" f" calc-maxpool({self.input_var}," @@ -488,7 +544,7 @@ def __repr__(self): f" {self.dilation})" ) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, CalcMaxPool): return ( self.maxpool_result == other.maxpool_result @@ -504,7 +560,7 @@ def __eq__(self, other): class ApplyBroadcasting(Constraint): - def __init__(self, res1, res2, input1, input2): + def __init__(self, res1: TVar, res2: TVar, input1: TVar, input2: TVar) -> None: """ :param res1: resulting tensor 1 :param res2: resulting tensor 2 @@ -516,7 +572,7 @@ def __init__(self, res1, res2, input1, input2): self.input1 = input1 self.input2 = input2 - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, ApplyBroadcasting): return ( self.res1 == other.res1 @@ -527,7 +583,7 @@ def __eq__(self, other): else: return False - def __repr__(self): + def __repr__(self) -> str: return ( f"{self.res1}, {self.res2} =" f" apply-broadcasting({self.input1}," @@ -540,7 +596,9 @@ class CalcProduct(Constraint): Given correct dimensions, calculate the product for flatten accounting for Dyn """ - def __init__(self, start, end, flattened, dims_to_flatten): + def __init__( + self, start: int, end: int, flattened: TVar, dims_to_flatten: list[DVar] + ) -> None: """ :param start: start index :param end: end index @@ -561,7 +619,7 @@ def __init__(self, start, end, flattened, dims_to_flatten): self.dims_to_flatten = dims_to_flatten self.flattened = flattened - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, CalcProduct): return ( self.start == other.start @@ -573,7 +631,7 @@ def __eq__(self, other): else: return False - def __repr__(self): + def __repr__(self) -> str: return f"{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})" @@ -582,16 +640,16 @@ class TVar: Tensor variable with no tensor constructor """ - def __init__(self, tvar): + def __init__(self, tvar: int) -> None: """ :param tvar: tensor variable """ self.tvar = tvar - def __repr__(self): + def __repr__(self) -> str: return f"TV({self.tvar})" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, TVar): return self.tvar == other.tvar else: @@ -603,16 +661,16 @@ class DVar: Dimension variable """ - def __init__(self, c): + def __init__(self, c: int) -> None: """ :param c: character or number """ self.c = c - def __repr__(self): + def __repr__(self) -> str: return f"DV({self.c})" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, DVar): return self.c == other.c else: @@ -624,35 +682,52 @@ class BVar: Boolean variable """ - def __init__(self, c): + def __init__(self, c: int) -> None: """ :param c: character or number """ self.c = c - def __repr__(self): + def __repr__(self) -> str: return f"BV({self.c})" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, BVar): return self.c == other.c else: return False -def is_algebraic_expression(constraint): +_Operand: TypeAlias = ( + TVar + | TensorType + | DVar + | int + | float + | bool + | _DynType + | BinConstraintD + | Prod + | BVar + | Conj + | Disj + | None +) + + +def is_algebraic_expression(constraint: object) -> bool: if isinstance(constraint, BinConstraintD): return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod] else: return isinstance(constraint, Prod) -def is_bool_expr(constraint): +def is_bool_expr(constraint: object) -> bool: if isinstance(constraint, BinConstraintD): return constraint.op in [op_gt, op_lt, op_neq, op_eq] else: return isinstance(constraint, (BVar, Conj, Disj)) -def is_dim(d): +def is_dim(d: object) -> bool: return isinstance(d, (DVar, int)) or d == Dyn diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py index 9f5a7e8064a2b..2bc3e128ea83c 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -1,8 +1,7 @@ -# mypy: allow-untyped-defs import operator import warnings -from collections.abc import Callable, Iterable -from typing import TypeVar +from collections.abc import Callable, Iterable, Sequence +from typing import TypeAlias, TypeVar from typing_extensions import ParamSpec import torch @@ -11,11 +10,13 @@ ApplyBroadcasting, BinConstraintD, BinConstraintT, + BVar, CalcConv, CalcMaxPool, CalcProduct, CanReshape, Conj, + Constraint, DGreatestUpperBound, Disj, DVar, @@ -49,6 +50,7 @@ gen_tensor_dims, gen_tvar, ) +from torch.fx.graph import Graph from torch.fx.node import Node, Target from torch.fx.tensor_type import Dyn, TensorType from torch.nn.modules.batchnorm import BatchNorm2d @@ -58,7 +60,9 @@ _T = TypeVar("_T") _P = ParamSpec("_P") -_INFERENCE_RULES: dict[Target, Callable] = {} +_SymbolDict: TypeAlias = dict[Node, TVar | DVar | BVar] + +_INFERENCE_RULES: dict[Target, Callable[..., tuple[list[Constraint], int]]] = {} MAX_TENSOR_RANK = 4 @@ -117,13 +121,15 @@ def register_inference_rule( def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: if call_target in _INFERENCE_RULES: raise RuntimeError(f"Inference rule already registered for {call_target}!") - _INFERENCE_RULES[call_target] = fn + _INFERENCE_RULES[call_target] = fn # pyrefly: ignore[unsupported-operation] return fn return register -def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter): +def generate_flatten_constraints( + start_dim: int, end_dim: int, input: TVar, flattened: TVar, n: int, counter: int +) -> tuple[Conj, int]: d, counter = gen_tensor_dims(n, counter) c1 = BinConstraintT(input, TensorType(d), op_eq) start_dim = n if start_dim == -1 else abs(start_dim) @@ -134,7 +140,9 @@ def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counte @register_inference_rule(getattr) -def get_attr_inference_rule(n: Node, symbols, constraints, counter): +def get_attr_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ If the attribute is "device" then the tensor shape is preserved """ @@ -155,7 +163,9 @@ def get_attr_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule(torch.bmm) -def bmm_inference_rule(n: Node, symbols, constraints, counter): +def bmm_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ Constraints that match the input to a size 3 tensor and switch the dimensions according to the rules @@ -227,7 +237,9 @@ def bmm_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule("index_select") -def index_select_inference_rule(n: Node, symbols, constraints, counter): +def index_select_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ We constrain the second argument to a vector or Dyn. The output replaces the input with the shape of the vector @@ -256,7 +268,13 @@ def index_select_inference_rule(n: Node, symbols, constraints, counter): Disj( [ IndexSelect( - i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select + i + 1, + symbols[ # pyrefly: ignore[bad-argument-type, bad-index] + n.args[0] + ], + dims[0], + n.args[1], + index_select, ) for i in range(MAX_TENSOR_RANK) ] @@ -268,7 +286,15 @@ def index_select_inference_rule(n: Node, symbols, constraints, counter): is_dyn, Disj( [ - IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) + IndexSelect( + i + 1, + symbols[ # pyrefly: ignore[bad-argument-type, bad-index] + n.args[0] + ], + Dyn, + n.args[1], + index_select, + ) for i in range(MAX_TENSOR_RANK) ] ), @@ -279,7 +305,9 @@ def index_select_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule("expand") -def expand_inference_rule(n: Node, symbols, constraints, counter): +def expand_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ We generate the exact constraints as we do for tensor additions but we constraint the rank of this expression to be equal to len(n.args[1:]) so that only @@ -308,13 +336,22 @@ def expand_inference_rule(n: Node, symbols, constraints, counter): e2_constraint = BinConstraintT( e2, TensorType( - [arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]] + [ + arg + if isinstance(arg, int) + else symbols[arg] # pyrefly: ignore[bad-index] + for arg in n.args[1:] + ] ), op_eq, ) constraints, counter = gen_broadcasting_constraints( - e1, e2, symbols, counter, expand + e1, # pyrefly: ignore[bad-argument-type] + e2, + symbols, + counter, + expand, ) # constraint the output size @@ -341,7 +378,9 @@ def expand_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule("contiguous") @register_inference_rule(torch.ones) @register_inference_rule(torch.zeros) -def equality_inference_rule(n: Node, symbols, constraints, counter): +def equality_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ We generate the constraint: input = output """ @@ -356,23 +395,27 @@ def equality_inference_rule(n: Node, symbols, constraints, counter): # then we have dimension variables else: for arg in n.args: - if not isinstance(symbols[arg], DVar): - raise AssertionError(f"Expected DVar, got {type(symbols[arg])}") - my_size = [symbols[arg] for arg in n.args] + if not isinstance(symbols[arg], DVar): # pyrefly: ignore[bad-index] + raise AssertionError( + f"Expected DVar, got {type(symbols[arg])}" # pyrefly: ignore[bad-index] + ) + my_size = [symbols[arg] for arg in n.args] # pyrefly: ignore[bad-index] return [BinConstraintT(output, TensorType(my_size), op_eq)], counter elif isinstance(n.args[0], tuple): # then the tuple is the size if len(n.args[0]) > 4: raise AssertionError(f"Expected len <= 4, got {len(n.args[0])}") - my_size = [symbols[arg] for arg in n.args[0]] + my_size = [symbols[arg] for arg in n.args[0]] # pyrefly: ignore[bad-index] return [BinConstraintT(output, TensorType(my_size), op_eq)], counter else: raise NotImplementedError("Method not yet implemented") @register_inference_rule("transpose") -def transpose_inference_rule(n: Node, symbols, constraints, counter): +def transpose_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ Can be considered as a sequence of two index selects, so we generate constraints accordingly """ @@ -407,7 +450,9 @@ def transpose_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule("type_as") -def type_inference_rule(n: Node, symbols, constraints, counter): +def type_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ We generate the constraint: input = output """ @@ -434,7 +479,9 @@ def type_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule("masked_fill_") -def masked_fill_inference_rule(n: Node, symbols, constraints, counter): +def masked_fill_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ Similar to addition. For now we implement the constraints when the argument is a boolean tensor. There is also a case for when @@ -463,11 +510,13 @@ def masked_fill_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule(torch.nn.functional.embedding) -def embedding_inference_rule_functional(n: Node, symbols, constraints, counter): +def embedding_inference_rule_functional( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") - embedding_dim_weights = symbols[n.args[1]] + embedding_dim_weights = symbols[n.args[1]] # pyrefly: ignore[bad-index] # will treat this as a static shape. So we will not use matching. weight_dims, counter = gen_tensor_dims(2, counter) @@ -480,7 +529,13 @@ def embedding_inference_rule_functional(n: Node, symbols, constraints, counter): @register_inference_rule(torch.nn.modules.sparse.Embedding) -def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter): +def embedding_inference_rule( + n: Node, + module_instance: torch.nn.Embedding, + symbols: _SymbolDict, + constraints: list[Constraint], + counter: int, +) -> tuple[list[Constraint], int]: """ The output shape differs from the input shape in the last dimension """ @@ -489,10 +544,12 @@ def embedding_inference_rule(n: Node, module_instance, symbols, constraints, cou return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter) -def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): +def gen_embedding_rules( + n: Node, symbols: _SymbolDict, embedding_dim: int | DVar, counter: int +) -> tuple[list[Constraint], int]: embedding_output, counter = gen_tvar(counter) symbols[n] = embedding_output - embedding_input = symbols[n.args[0]] + embedding_input = symbols[n.args[0]] # pyrefly: ignore[bad-index] input_dyn = BinConstraintT(embedding_input, Dyn, op_eq) output_dyn = BinConstraintT(embedding_output, Dyn, op_eq) @@ -520,18 +577,22 @@ def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): @register_inference_rule(torch.tensor) -def tensor_inference_rule(n: Node, symbols, constraints, counter): +def tensor_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ If the tensor is a scalar, we will skip it since we do not support scalars yet. We will add support in the future if it's needed. For our examples so far, scalars are not needed. """ - return [], counter + return [], counter # pyrefly: ignore[implicit-any] @register_inference_rule("reshape") @register_inference_rule("view") -def view_inference_rule(n: Node, symbols, constraints, counter): +def view_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ Similar to reshape but with an extra condition on the strides """ @@ -564,7 +625,7 @@ def view_inference_rule(n: Node, symbols, constraints, counter): t2_type = TensorType(t2_type) # type: ignore[assignment] c1 = BinConstraintT(my_view, t2_type, op_eq) - c2 = CanReshape(src_var, t2_type) + c2 = CanReshape(src_var, t2_type) # pyrefly: ignore[bad-argument-type] # TODO: add the extra check mentioned here: # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view @@ -573,7 +634,9 @@ def view_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule("size") -def size_inference_rule(n: Node, symbols, constraints, counter): +def size_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ The constraint is just lhs = rhs. Ex: size = input_ids.size() @@ -583,7 +646,7 @@ def size_inference_rule(n: Node, symbols, constraints, counter): # generate the new variable size, counter = gen_tvar(counter) symbols[n] = size - input = symbols[n.args[0]] + input = symbols[n.args[0]] # pyrefly: ignore[bad-index] c = BinConstraintT(input, size, op_eq) return [c], counter @@ -593,9 +656,14 @@ def size_inference_rule(n: Node, symbols, constraints, counter): # generate the new variable size_index, counter = gen_dvar(counter) symbols[n] = size_index - input = symbols[n.args[0]] + input = symbols[n.args[0]] # pyrefly: ignore[bad-index] c2 = [ - GetItem(i + 1, n.args[1], size_index, input) + GetItem( + i + 1, + n.args[1], + size_index, + input, # pyrefly: ignore[bad-argument-type] + ) for i in range(MAX_TENSOR_RANK) ] c3 = BinConstraintD(0, size_index, op_leq) @@ -613,7 +681,7 @@ def size_inference_rule(n: Node, symbols, constraints, counter): raise NotImplementedError -def range_check(i, n): +def range_check(i: int, n: int) -> T | F: """ Checks if an index i is within range of a size n list Args: @@ -629,7 +697,9 @@ def range_check(i, n): @register_inference_rule(torch.cumsum) -def cumsum_inference_rule(n: Node, symbols, constraints, counter): +def cumsum_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ Input and output shapes should be equal We should verify that the index is valid @@ -668,14 +738,18 @@ def cumsum_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule(_assert_is_none) -def assert_inference_rule(n: Node, symbols, constraints, counter): +def assert_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: if len(n.users) != 0: raise AssertionError(f"Expected no users, got {len(n.users)}") - return [], counter + return [], counter # pyrefly: ignore[implicit-any] @register_inference_rule(operator.getitem) -def getitem_inference_rule(n: Node, symbols, constraints, counter): +def getitem_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") @@ -732,7 +806,7 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter): ] else: # TODO: we should figure out why there is a key-error here. - return [], counter + return [], counter # pyrefly: ignore[implicit-any] return [Disj([c1, *c2])], counter @@ -741,7 +815,9 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule(operator.gt) -def gt_inference_rule(n: Node, symbols, constraints, counter): +def gt_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], (Node, int)): raise AssertionError(f"Expected Node or int, got {type(n.args[0])}") if not isinstance(n.args[1], (Node, int)): @@ -804,7 +880,9 @@ def gt_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule(operator.eq) -def eq_inference_rule(n: Node, symbols, constraints, counter): +def eq_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], (Node, int)): raise AssertionError(f"Expected Node or int, got {type(n.args[0])}") if not isinstance(n.args[1], (Node, int)): @@ -845,7 +923,9 @@ def eq_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule(operator.ne) -def neq_inference_rule(n: Node, symbols, constraints, counter): +def neq_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ Translates to inconsistent in gradual types. To prove inequality, we should prove that @@ -973,7 +1053,9 @@ def neq_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule(operator.lt) -def lt_inference_rule(n: Node, symbols, constraints, counter): +def lt_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], (Node, int)): raise AssertionError(f"Expected Node or int, got {type(n.args[0])}") if not isinstance(n.args[1], (Node, int)): @@ -1018,7 +1100,9 @@ def lt_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule(torch.full) -def full_inference_rule(n: Node, symbols, constraints, counter): +def full_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: full, counter = gen_tvar(counter) symbols[n] = full res = [] @@ -1026,7 +1110,9 @@ def full_inference_rule(n: Node, symbols, constraints, counter): if not isinstance(n.args[0], Iterable): raise AssertionError(f"Expected Iterable, got {type(n.args[0])}") for arg in n.args[0]: - dim = arg if isinstance(arg, int) else symbols[arg] + dim = ( + arg if isinstance(arg, int) else symbols[arg] # pyrefly: ignore[bad-index] + ) res.append(dim) c = BinConstraintT(full, TensorType(list(res)), op_eq) # type: ignore[arg-type] return [c], counter @@ -1034,12 +1120,14 @@ def full_inference_rule(n: Node, symbols, constraints, counter): # TODO normalize index @register_inference_rule(torch.arange) -def arange_inference_rule(n: Node, symbols, constraints, counter): +def arange_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: start = 0 step = 1 if len(n.args) == 1: - end = symbols[n.args[0]] + end = symbols[n.args[0]] # pyrefly: ignore[bad-index] else: raise NotImplementedError("Not yet implemented") @@ -1078,7 +1166,9 @@ def arange_inference_rule(n: Node, symbols, constraints, counter): ], counter -def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): +def gen_broadcasting_constraints( + e1: TVar, e2: TVar, symbols: _SymbolDict, counter: int, output_var: TVar +) -> tuple[list[Constraint], int]: # additional vars that don't correspond to expressions e11, counter = gen_tvar(counter) e22, counter = gen_tvar(counter) @@ -1095,7 +1185,9 @@ def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): @register_inference_rule("ne") @register_inference_rule(torch.add) @register_inference_rule(operator.add) -def broadcasting_inference_rule(n: Node, symbols, constraints, counter): +def broadcasting_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: # pyrefly: ignore[bad-return] op_code = None if n.target is operator.add or n.target is torch.add: op_code = op_add @@ -1111,7 +1203,13 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): e1 = symbols[n.args[0]] e2 = symbols[n.args[1]] - return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output) + return gen_broadcasting_constraints( + e1, # pyrefly: ignore[bad-argument-type] + e2, # pyrefly: ignore[bad-argument-type] + symbols, + counter, + my_output, + ) else: raise NotImplementedError("Method not yet implemented") @@ -1168,7 +1266,9 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule(torch.flatten) -def flatten_inference_rule(n: Node, symbols, constraints, counter): +def flatten_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") @@ -1199,7 +1299,12 @@ def flatten_inference_rule(n: Node, symbols, constraints, counter): const = [] for i in range(1, MAX_TENSOR_RANK + 1): c, counter = generate_flatten_constraints( - start_dim, end_dim, input, flattened, i, counter + start_dim, + end_dim, + input, # pyrefly: ignore[bad-argument-type] + flattened, + i, + counter, ) const.append(c) @@ -1207,17 +1312,30 @@ def flatten_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule(torch.nn.functional.layer_norm) -def layer_norm_functional(n: Node, symbols, constraints, counter): +def layer_norm_functional( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: """ We generate the constraint: input = output """ if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") - return gen_layer_norm_constraints(n, n.args[1], symbols, counter) + return gen_layer_norm_constraints( + n, + n.args[1], # pyrefly: ignore[bad-argument-type] + symbols, + counter, + ) @register_inference_rule(torch.nn.LayerNorm) -def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter): +def layer_norm_inference_rule( + n: Node, + module_instance: torch.nn.LayerNorm, + symbols: _SymbolDict, + constraints: list[Constraint], + counter: int, +) -> tuple[list[Constraint], int]: """ Input and output shapes should be equal. Input should be consistent with the normalized_shape @@ -1229,10 +1347,12 @@ def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, co ) -def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): +def gen_layer_norm_constraints( + n: Node, normalized_shape: Sequence[int], symbols: _SymbolDict, counter: int +) -> tuple[list[Constraint], int]: output, counter = gen_tvar(counter) symbols[n] = output - input = symbols[n.args[0]] + input = symbols[n.args[0]] # pyrefly: ignore[bad-index] input_dyn = BinConstraintT(input, Dyn, op_eq) output_dyn = BinConstraintT(output, Dyn, op_eq) @@ -1258,7 +1378,13 @@ def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): @register_inference_rule(torch.nn.Dropout) @register_inference_rule(torch.nn.ReLU) -def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter): +def relu_inference_rule( + n: Node, + module_instance: torch.nn.Module, + symbols: _SymbolDict, + constraints: list[Constraint], + counter: int, +) -> tuple[list[Constraint], int]: """ Input and output shapes should be equal. """ @@ -1273,7 +1399,13 @@ def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter) @register_inference_rule(torch.nn.Linear) -def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter): +def linear_inference_rule( + n: Node, + module_instance: torch.nn.Linear, + symbols: _SymbolDict, + constraints: list[Constraint], + counter: int, +) -> tuple[list[Constraint], int]: """ Input and output sizes should be the same except for the last dimension If the input is Dyn, then so should the output @@ -1286,7 +1418,9 @@ def linear_inference_rule(n: Node, module_instance, symbols, constraints, counte @register_inference_rule("dim") -def torch_dim_inference_rule(n: Node, symbols, constraints, counter): +def torch_dim_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") my_dim, counter = gen_dvar(counter) @@ -1313,12 +1447,16 @@ def torch_dim_inference_rule(n: Node, symbols, constraints, counter): @register_inference_rule(torch._C._nn.linear) -def torch_linear_inference_rule(n: Node, symbols, constraints, counter): +def torch_linear_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") weight_dims, counter = gen_tensor_dims(2, counter) equality_constraint = BinConstraintT( - symbols[n.args[1]], TensorType(weight_dims), op_eq + symbols[n.args[1]], # pyrefly: ignore[bad-index] + TensorType(weight_dims), + op_eq, ) constraints, counter = linear_constraints( n, weight_dims[1], weight_dims[0], symbols, counter @@ -1326,10 +1464,16 @@ def torch_linear_inference_rule(n: Node, symbols, constraints, counter): return [equality_constraint] + constraints, counter -def linear_constraints(n: Node, in_features, out_features, symbols, counter): +def linear_constraints( + n: Node, + in_features: int | DVar, + out_features: int | DVar, + symbols: _SymbolDict, + counter: int, +) -> tuple[list[Constraint], int]: linear_output, counter = gen_tvar(counter) symbols[n] = linear_output - linear_input = symbols[n.args[0]] + linear_input = symbols[n.args[0]] # pyrefly: ignore[bad-index] input_dyn = BinConstraintT(linear_input, Dyn, op_eq) output_dyn = BinConstraintT(linear_output, Dyn, op_eq) @@ -1357,7 +1501,9 @@ def linear_constraints(n: Node, in_features, out_features, symbols, counter): return [Disj([c1, Disj(c2)])], counter -def add_layer_norm_constraints(input_dim, normalized_dim): +def add_layer_norm_constraints( + input_dim: list[DVar], normalized_dim: list[int] +) -> list[Constraint]: """ The constraints say that the type has te form: [*, 1024, 1024] while the normalized_dim have the form [1024, 1024] @@ -1372,16 +1518,21 @@ def add_layer_norm_constraints(input_dim, normalized_dim): return [F()] else: - constraints = [] + constraints: list[Constraint] = [] for i, n in zip(reversed(input_dim), reversed(normalized_dim)): constraints.append(BinConstraintD(i, n, op_consistency)) return constraints -def add_linear_constraints(dims1, dims2, in_features, out_features): +def add_linear_constraints( + dims1: list[DVar], + dims2: list[DVar], + in_features: int | DVar, + out_features: int | DVar, +) -> list[Constraint]: if len(dims1) != len(dims2): raise AssertionError(f"Expected same length, got {len(dims1)} vs {len(dims2)}") - constraints = [] + constraints: list[Constraint] = [] for i in range(len(dims1)): if i == len(dims1) - 1: constraints.append(BinConstraintD(dims1[i], in_features, op_consistency)) @@ -1393,7 +1544,9 @@ def add_linear_constraints(dims1, dims2, in_features, out_features): @register_inference_rule(torch.reshape) -def reshape_inference_rule(n: Node, symbols, constraints, counter): +def reshape_inference_rule( + n: Node, symbols: _SymbolDict, constraints: list[Constraint], counter: int +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") @@ -1405,13 +1558,19 @@ def reshape_inference_rule(n: Node, symbols, constraints, counter): t2 = n.args[1] t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) # type: ignore[union-attr] c1 = BinConstraintT(my_reshape, t2_type, op_eq) # type: ignore[union-attr] - c2 = CanReshape(src_var, t2_type) + c2 = CanReshape(src_var, t2_type) # pyrefly: ignore[bad-argument-type] return [c1, c2], counter @register_inference_rule(BatchNorm2d) -def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter): +def batchnorm_inference_rule( + n: Node, + module_instance: BatchNorm2d, + symbols: _SymbolDict, + constraints: list[Constraint], + counter: int, +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") @@ -1434,7 +1593,13 @@ def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, cou @register_inference_rule(torch.nn.AdaptiveAvgPool2d) -def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter): +def adaptive_inference_rule( + n: Node, + module_instance: torch.nn.AdaptiveAvgPool2d, + symbols: _SymbolDict, + constraints: list[Constraint], + counter: int, +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") @@ -1453,7 +1618,16 @@ def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, coun c2 = BinConstraintT( avg_pool, TensorType( - [d1, d2, module_instance.output_size[0], module_instance.output_size[1]] + [ + d1, + d2, + module_instance.output_size[ # pyrefly: ignore[bad-index, unsupported-operation] + 0 + ], + module_instance.output_size[ # pyrefly: ignore[bad-index, unsupported-operation] + 1 + ], + ] ), op_eq, ) @@ -1462,7 +1636,13 @@ def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, coun @register_inference_rule(Conv2d) -def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter): +def conv2d_inference_rule( + n: Node, + module_instance: Conv2d, + symbols: _SymbolDict, + constraints: list[Constraint], + counter: int, +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") @@ -1481,12 +1661,12 @@ def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counte c3 = CalcConv( my_conv, - input_var, + input_var, # pyrefly: ignore[bad-argument-type] module_instance.out_channels, - module_instance.kernel_size, - module_instance.padding, - module_instance.stride, - module_instance.dilation, + module_instance.kernel_size, # pyrefly: ignore[bad-argument-type] + module_instance.padding, # pyrefly: ignore[bad-argument-type] + module_instance.stride, # pyrefly: ignore[bad-argument-type] + module_instance.dilation, # pyrefly: ignore[bad-argument-type] [d1, d2, d3, d4], ) @@ -1496,7 +1676,13 @@ def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counte @register_inference_rule(torch.nn.MaxPool2d) -def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter): +def maxpool_inference_rule( + n: Node, + module_instance: torch.nn.MaxPool2d, + symbols: _SymbolDict, + constraints: list[Constraint], + counter: int, +) -> tuple[list[Constraint], int]: if not isinstance(n.args[0], Node): raise AssertionError(f"Expected Node, got {type(n.args[0])}") maxpool, counter = gen_tvar(counter) @@ -1510,7 +1696,7 @@ def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, count c2 = CalcMaxPool( maxpool, - input_var, + input_var, # pyrefly: ignore[bad-argument-type] module_instance.kernel_size, module_instance.padding, module_instance.stride, @@ -1524,21 +1710,21 @@ def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, count class ConstraintGenerator: - def __init__(self, traced, graph=None): + def __init__(self, traced: torch.nn.Module, graph: Graph | None = None) -> None: self.traced = traced # traced or tracer.root self.traced_params = dict(self.traced.named_parameters()) self.constraints = [] self.symbol_dict = {} self.graph = traced.graph if hasattr(traced, "graph") else graph - def generate_constraints(self, counter=0): + def generate_constraints(self, counter: int = 0) -> tuple[Conj, int]: """ Iterate through every node and generate constraints Effect: self.constraints will be populated with the final constraints """ graph = self.graph - all_constraints = [] + all_constraints: list[Constraint] = [] # pyrefly: ignore [missing-attribute] for n in graph.nodes: @@ -1547,7 +1733,9 @@ def generate_constraints(self, counter=0): return Conj(all_constraints), counter - def generate_constraints_node(self, n: Node, counter): + def generate_constraints_node( + self, n: Node, counter: int + ) -> tuple[list[Constraint], int]: """ Generate constraints the given node: Currently supported operations: @@ -1586,7 +1774,9 @@ def generate_constraints_node(self, n: Node, counter): ) elif n.op == "call_module": - module_instance = self.traced.get_submodule(n.target) + module_instance = self.traced.get_submodule( + n.target # pyrefly: ignore[bad-argument-type] + ) if type(module_instance) in _INFERENCE_RULES: return _INFERENCE_RULES[type(module_instance)]( n, module_instance, self.symbol_dict, self.constraints, counter @@ -1607,7 +1797,9 @@ def generate_constraints_node(self, n: Node, counter): ) elif n.op == "get_attr": - t = self.traced_params.get(n.target, None) + t = self.traced_params.get( # pyrefly: ignore[no-matching-overload] + n.target, None + ) if isinstance(t, torch.Tensor): if len(t.shape) > 0: @@ -1618,12 +1810,12 @@ def generate_constraints_node(self, n: Node, counter): return [BinConstraintT(output, attr_type, op_eq)], counter else: # scalar? - return [], counter + return [], counter # pyrefly: ignore[implicit-any] else: - return [], counter + return [], counter # pyrefly: ignore[implicit-any] elif n.op == "output": - return [], counter + return [], counter # pyrefly: ignore[implicit-any] else: raise NotImplementedError(f"Method {n.op} not yet implemented") diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py index 9ec1006e7fedf..20060b786c0d4 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py @@ -1,7 +1,6 @@ -# mypy: ignore-errors import copy import itertools -from collections.abc import Callable +from collections.abc import Callable, Sequence from torch.fx.experimental.migrate_gradual_types.constraint import ( ApplyBroadcasting, @@ -47,14 +46,54 @@ gen_nat_constraints, gen_tensor_dims, ) -from torch.fx.tensor_type import Dyn, TensorType - - -_TRANSFORMATION_RULES: dict[Constraint, Callable] = {} - - -def register_transformation_rule(call_target): - def register(fn): +from torch.fx.tensor_type import _DynType, Dyn, TensorType + + +__all__ = [ + "apply_padding", + "broadcast_dim", + "calc_last_two_dims", + "create_equality_constraints_for_broadcasting", + "gen_all_reshape_possibilities", + "gen_broadcasting_constraints", + "gen_consistency_constraints", + "gen_greatest_upper_bound", + "gen_lists_of_dims", + "generate_all_broadcasting_possibilities_no_padding", + "generate_all_int_dyn_dim_possibilities", + "generate_binconstraint_d", + "generate_binconstraint_t", + "generate_broadcasting", + "generate_calc_conv", + "generate_calc_maxpool", + "generate_calc_product", + "generate_conj", + "generate_d_gub", + "generate_disj", + "generate_gub", + "generate_reshape", + "is_dim_div_by_target", + "is_target_div_by_dim", + "no_broadcast_dim_with_index", + "register_transformation_rule", + "transform_constraint", + "transform_get_item", + "transform_get_item_tensor", + "transform_index_select", + "transform_transpose", + "valid_index", + "valid_index_tensor", +] + + +_TransformFn = Callable[[Constraint, int], tuple[Constraint, int]] +_TRANSFORMATION_RULES: dict[type, _TransformFn] = {} + + +def register_transformation_rule( + call_target: type[Constraint], +) -> Callable[[_TransformFn], _TransformFn]: + def register(fn: _TransformFn) -> _TransformFn: if call_target in _TRANSFORMATION_RULES: raise RuntimeError( f"Transformation rule already registered for {call_target}!" @@ -65,7 +104,7 @@ def register(fn): return register -def valid_index(index, dims): +def valid_index(index: int, dims: list[DVar]) -> Constraint: """ Given a list of dimensions, checks if an index is valid in the list """ @@ -77,10 +116,12 @@ def valid_index(index, dims): @register_transformation_rule(Transpose) -def transform_transpose(constraint, counter): +def transform_transpose(constraint: Constraint, counter: int) -> tuple[Constraint, int]: """ Similar to a sequence of two index-selects """ + if not isinstance(constraint, Transpose): + raise TypeError(type(constraint)) dims, counter = gen_tensor_dims(constraint.tensor_size, counter) is_valid_index1 = valid_index(constraint.index1, dims) is_valid_index2 = valid_index(constraint.index2, dims) @@ -104,21 +145,25 @@ def transform_transpose(constraint, counter): @register_transformation_rule(IndexSelect) -def transform_index_select(constraint, counter): +def transform_index_select( + constraint: Constraint, counter: int +) -> tuple[Constraint, int]: """ The constraints consider the given tensor size, checks if the index is valid and if so, generates a constraint for replacing the input dimension with the required dimension """ + if not isinstance(constraint, IndexSelect): + raise TypeError(type(constraint)) dims, counter = gen_tensor_dims(constraint.tensor_size, counter) is_valid_index = valid_index(constraint.index, dims) nat_constraints = gen_nat_constraints(dims) # if the index is valid then replace the input dimension with the new dimension # otherwise the dimension will not be replaced and the clause will contain False + new_dims = copy.deepcopy(dims) if is_valid_index == T(): - new_dims = copy.deepcopy(dims) - new_dims[constraint.index] = constraint.dim_replace + new_dims[constraint.index] = constraint.dim_replace # type: ignore[unsupported-operation] transformed_constraint = Conj( [ @@ -134,7 +179,7 @@ def transform_index_select(constraint, counter): @register_transformation_rule(GetItem) -def transform_get_item(constraint, counter): +def transform_get_item(constraint: Constraint, counter: int) -> tuple[Constraint, int]: """ generate an equality of the form: t = [a1, ..., an] @@ -149,6 +194,8 @@ def transform_get_item(constraint, counter): Returns: simplified constraints for GetItem """ + if not isinstance(constraint, GetItem): + raise TypeError(type(constraint)) dims, counter = gen_tensor_dims(constraint.tensor_size, counter) nat_constraints = gen_nat_constraints(dims) @@ -170,7 +217,7 @@ def transform_get_item(constraint, counter): return Conj(all_constraints), counter -def valid_index_tensor(index, dims): +def valid_index_tensor(index: tuple[None | slice, ...], dims: list[DVar]) -> Constraint: """ if the slice instances exceed the length of the dimensions then this is a type error so we return False @@ -186,7 +233,9 @@ def valid_index_tensor(index, dims): @register_transformation_rule(GetItemTensor) -def transform_get_item_tensor(constraint, counter): +def transform_get_item_tensor( + constraint: Constraint, counter: int +) -> tuple[Constraint, int]: """ When the index is a tuple, then the output will be a tensor TODO: we have to check if this is the case for all HF models @@ -200,6 +249,8 @@ def transform_get_item_tensor(constraint, counter): slice with default arguments does not change the rank """ + if not isinstance(constraint, GetItemTensor): + raise TypeError(type(constraint)) if not isinstance(constraint.index_tuple, tuple): raise AssertionError( f"Expected tuple for index_tuple, got {type(constraint.index_tuple)}" @@ -212,7 +263,8 @@ def transform_get_item_tensor(constraint, counter): # generate a place-holder list of the right rank # where "slice" does not contribute to the rank and "None" does none_c = constraint.index_tuple.count(None) - resulting_tensor_dims = (none_c + len(dims)) * [None] + # list invariance: [None] * n types as list[None], but elements are reassigned to int/DVar + resulting_tensor_dims: list[int | DVar | None] = [None] * (none_c + len(dims)) # type: ignore[assignment] dim_index = 0 for i in range(len(constraint.index_tuple)): @@ -251,10 +303,14 @@ def transform_get_item_tensor(constraint, counter): @register_transformation_rule(BinConstraintT) -def generate_binconstraint_t(constraint, counter): +def generate_binconstraint_t( + constraint: Constraint, counter: int +) -> tuple[Constraint, int]: """ Transform binary constraints for tensors """ + if not isinstance(constraint, BinConstraintT): + raise TypeError(type(constraint)) # precision constraints if constraint.op == op_precision: @@ -280,6 +336,8 @@ def generate_binconstraint_t(constraint, counter): + [BinConstraintD(1, new_dim, op_leq) for new_dim in new_dims] ) return Conj(new_dim_constraints), counter + else: + return constraint, counter # matching elif constraint.op == op_matching: @@ -342,15 +400,21 @@ def generate_binconstraint_t(constraint, counter): @register_transformation_rule(BinConstraintD) -def generate_binconstraint_d(constraint, counter): +def generate_binconstraint_d( + constraint: Constraint, counter: int +) -> tuple[Constraint, int]: """ Transform binary constraints for dimensions """ + if not isinstance(constraint, BinConstraintD): + raise TypeError(type(constraint)) if constraint.op == op_precision: if isinstance(constraint.lhs, int): return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter elif constraint.lhs == Dyn: return T(), counter + else: + return constraint, counter elif constraint.op == op_consistency: return ( @@ -369,10 +433,12 @@ def generate_binconstraint_d(constraint, counter): @register_transformation_rule(Conj) -def generate_conj(constraint, counter): +def generate_conj(constraint: Constraint, counter: int) -> tuple[Constraint, int]: """ Transform conjunctions """ + if not isinstance(constraint, Conj): + raise TypeError(type(constraint)) new = [] for c in constraint.conjucts: new_c, counter = transform_constraint(c, counter) @@ -381,10 +447,12 @@ def generate_conj(constraint, counter): @register_transformation_rule(Disj) -def generate_disj(constraint, counter): +def generate_disj(constraint: Constraint, counter: int) -> tuple[Constraint, int]: """ Transform disjunctions """ + if not isinstance(constraint, Disj): + raise TypeError(type(constraint)) new = [] for c in constraint.disjuncts: new_c, counter = transform_constraint(c, counter) @@ -393,11 +461,13 @@ def generate_disj(constraint, counter): @register_transformation_rule(TGreatestUpperBound) -def generate_gub(constraint, counter): +def generate_gub(constraint: Constraint, counter: int) -> tuple[Constraint, int]: """ Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound on dimensions """ + if not isinstance(constraint, TGreatestUpperBound): + raise TypeError(type(constraint)) c1 = Conj( [ Disj( @@ -416,10 +486,12 @@ def generate_gub(constraint, counter): @register_transformation_rule(DGreatestUpperBound) -def generate_d_gub(constraint, counter): +def generate_d_gub(constraint: Constraint, counter: int) -> tuple[Constraint, int]: """ Transform greatest upper bound for dimensions into equality constraints """ + if not isinstance(constraint, DGreatestUpperBound): + raise TypeError(type(constraint)) c1 = Conj( [ BinConstraintD(constraint.rhs1, Dyn, op_eq), @@ -442,7 +514,9 @@ def generate_d_gub(constraint, counter): @register_transformation_rule(CalcConv) -def generate_calc_conv(constraint, counter): +def generate_calc_conv(constraint: Constraint, counter: int) -> tuple[Constraint, int]: + if not isinstance(constraint, CalcConv): + raise TypeError(type(constraint)) d, counter = gen_tensor_dims(4, counter) conv_result = TensorType([d[0], d[1], d[2], d[3]]) @@ -475,10 +549,14 @@ def generate_calc_conv(constraint, counter): @register_transformation_rule(CalcMaxPool) -def generate_calc_maxpool(constraint, counter): +def generate_calc_maxpool( + constraint: Constraint, counter: int +) -> tuple[Constraint, int]: """ Transform maxpool constraints """ + if not isinstance(constraint, CalcMaxPool): + raise TypeError(type(constraint)) d, counter = gen_tensor_dims(4, counter) maxpool_result = TensorType([d[0], d[1], d[2], d[3]]) @@ -503,10 +581,14 @@ def generate_calc_maxpool(constraint, counter): @register_transformation_rule(CalcProduct) -def generate_calc_product(constraint, counter): +def generate_calc_product( + constraint: Constraint, counter: int +) -> tuple[Constraint, int]: """ Transform flatten constraints """ + if not isinstance(constraint, CalcProduct): + raise TypeError(type(constraint)) start = constraint.start end = constraint.end dims = constraint.dims_to_flatten @@ -524,7 +606,7 @@ def generate_calc_product(constraint, counter): all_possibilities = generate_all_int_dyn_dim_possibilities(mid) - all_constraints = [] + all_constraints: list[Constraint] = [] for p in all_possibilities: p = list(p) @@ -575,10 +657,12 @@ def generate_calc_product(constraint, counter): @register_transformation_rule(CanReshape) -def generate_reshape(constraint, counter): +def generate_reshape(constraint: Constraint, counter: int) -> tuple[Constraint, int]: """ Transform reshape constraints """ + if not isinstance(constraint, CanReshape): + raise TypeError(type(constraint)) d, counter = gen_tensor_dims(4, counter) d1 = d[0] @@ -710,10 +794,14 @@ def generate_reshape(constraint, counter): @register_transformation_rule(ApplyBroadcasting) -def generate_broadcasting(constraint, counter): +def generate_broadcasting( + constraint: Constraint, counter: int +) -> tuple[Constraint, int]: """ Transform broadcasting constraints """ + if not isinstance(constraint, ApplyBroadcasting): + raise TypeError(type(constraint)) e11, e12 = constraint.res1, constraint.res2 e1, e2 = constraint.input1, constraint.input2 @@ -784,7 +872,9 @@ def generate_broadcasting(constraint, counter): ) -def transform_constraint(constraint: Constraint, counter: int): +def transform_constraint( + constraint: Constraint, counter: int +) -> tuple[Constraint, int]: """ Transforms a constraint into a simpler constraint. Ex: precision and consistency are transformed to equality @@ -802,7 +892,9 @@ def transform_constraint(constraint: Constraint, counter: int): return constraint, counter -def calc_last_two_dims(constraint, d: list[DVar]): +def calc_last_two_dims( + constraint: CalcConv | CalcMaxPool, d: list[DVar] +) -> tuple[Constraint, Constraint]: """ Generates constraints for the last two dimensions of a convolution or a maxpool output Args: @@ -874,7 +966,9 @@ def calc_last_two_dims(constraint, d: list[DVar]): return c4, c5 -def generate_all_int_dyn_dim_possibilities(my_list: list[DVar]): +def generate_all_int_dyn_dim_possibilities( + my_list: list[DVar], +) -> list[tuple[BinConstraintD, ...]]: """ Generate all possibilities of being equal or not equal to dyn for my_list Args: @@ -896,7 +990,9 @@ def generate_all_int_dyn_dim_possibilities(my_list: list[DVar]): return all_possibilities -def is_target_div_by_dim(target: list[int], dim: list[DVar]): +def is_target_div_by_dim( + target: Sequence[DVar | int | _DynType], dim: DVar | Prod +) -> BinConstraintD: """ Generate constraints to check if the target dimensions are divisible by the input dimensions Args: @@ -909,7 +1005,9 @@ def is_target_div_by_dim(target: list[int], dim: list[DVar]): return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq) -def is_dim_div_by_target(target: list[int], dim: list[DVar]): +def is_dim_div_by_target( + target: Sequence[DVar | int | _DynType], dim: DVar | Prod +) -> BinConstraintD: """ Generate constraints to check if the input dimensions is divisible by the target dimensions Args: @@ -922,7 +1020,9 @@ def is_dim_div_by_target(target: list[int], dim: list[DVar]): return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq) -def gen_all_reshape_possibilities(list_of_dims, target): +def gen_all_reshape_possibilities( + list_of_dims: list[DVar], target: Sequence[DVar | int | _DynType] +) -> Constraint: """ Consider all possibilities what the input dimensions could be (number or dynamic) Then generate the appropriate constraints using multiplication or mod depending on the possibility @@ -942,7 +1042,7 @@ def gen_all_reshape_possibilities(list_of_dims, target): all_constraints = [] for p in all_possibilities: - to_multiply = [] + to_multiply: list[DVar] = [] p = list(p) @@ -950,7 +1050,7 @@ def gen_all_reshape_possibilities(list_of_dims, target): if not isinstance(constraint, BinConstraintD): raise AssertionError(f"Expected BinConstraintD, got {type(constraint)}") if constraint.op == op_neq: - to_multiply.append(constraint.lhs) + to_multiply.append(constraint.lhs) # type: ignore[arg-type] if not to_multiply: all_constraints.append(Conj(p)) @@ -967,7 +1067,14 @@ def gen_all_reshape_possibilities(list_of_dims, target): return Disj(all_constraints) -def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False): +def broadcast_dim( + tensor_input1: Sequence[DVar | None], + tensor_input2: Sequence[DVar], + res1: Sequence[DVar], + res2: Sequence[DVar], + index: int, + padding: bool = False, +) -> Constraint: """ Apply broadcasting to the 'index' dimension of tensor_input1. Args: @@ -989,7 +1096,7 @@ def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False # then the inputs are the same length so they all have dimensions at "index" return Conj( [ - BinConstraintD(tensor_input1[index], 1, op_eq), + BinConstraintD(tensor_input1[index], 1, op_eq), # type: ignore[arg-type] BinConstraintD(res1[index], res2[index], op_eq), BinConstraintD(res2[index], tensor_input2[index], op_eq), ] @@ -1014,7 +1121,7 @@ def apply_padding( d11: list[DVar], d12: list[DVar], counter: int, -): +) -> tuple[Constraint, int]: """ We are considering the possibility where one input has less dimensions than another input, so we apply padding to the broadcasted results @@ -1080,7 +1187,7 @@ def apply_padding( def no_broadcast_dim_with_index( d1: list[DVar], d2: list[DVar], d3: list[DVar], d4: list[DVar], i: int -): +) -> Constraint: """ Args: d1: input 1 @@ -1115,7 +1222,9 @@ def no_broadcast_dim_with_index( ) -def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int): +def gen_lists_of_dims( + num_tensors: int, dim_size: int, counter: int +) -> tuple[list[list[DVar]], int]: """ Generate lists of DVar to represent tensor dimensions Args: @@ -1144,7 +1253,7 @@ def create_equality_constraints_for_broadcasting( d2: list[DVar], d11: list[DVar], d12: list[DVar], -): +) -> list[BinConstraintT]: """ Create equality constraints for when no broadcasting occurs Args: @@ -1168,7 +1277,9 @@ def create_equality_constraints_for_broadcasting( return [e1_tensor, e11_tensor, e2_tensor, e12_tensor] -def gen_consistency_constraints(constraint: Constraint, counter: int): +def gen_consistency_constraints( + constraint: BinConstraintT, counter: int +) -> tuple[list[Constraint], int]: """ Args: constraint: Consistency constraint on tensors @@ -1178,7 +1289,7 @@ def gen_consistency_constraints(constraint: Constraint, counter: int): """ - all_constraints = [] + all_constraints: list[Constraint] = [] for i in range(1, MAX_TENSOR_RANK + 1): new_dims_rhs_1, counter = gen_tensor_dims(i, counter) @@ -1203,7 +1314,9 @@ def gen_consistency_constraints(constraint: Constraint, counter: int): return all_constraints, counter -def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): +def gen_greatest_upper_bound( + constraint: TGreatestUpperBound, counter: int +) -> tuple[list[Constraint], int]: """ Args: constraint: Greatest upper bound on tensors @@ -1213,10 +1326,10 @@ def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): """ - all_constraints = [] + all_constraints: list[Constraint] = [] for i in range(1, MAX_TENSOR_RANK + 1): - c = [] + c: list[Constraint] = [] dims1, counter = gen_tensor_dims(i, counter) c1tensor = TensorType(dims1) @@ -1249,7 +1362,7 @@ def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): def generate_all_broadcasting_possibilities_no_padding( d1: list[DVar], d2: list[DVar], d11: list[DVar], d12: list[DVar] -): +) -> Constraint: """ Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension. We look at all combinations for all dimensions in d1 and d2 @@ -1279,7 +1392,7 @@ def generate_all_broadcasting_possibilities_no_padding( def gen_broadcasting_constraints( e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int -): +) -> tuple[Constraint, Constraint, Constraint, list[BinConstraintD], int]: """ Simulates broadcasting on e1 and e2 and returns the results respectively in e11 and e12. Because of gradual types, diff --git a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py index 6feccb3483fd7..73414d628bb77 100644 --- a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py +++ b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py @@ -1,9 +1,29 @@ -# mypy: allow-untyped-defs +from typing import Any, TypeAlias + +import torch + + +__all__ = [ + "evaluate_conditional_with_constraints", + "iterate_till_fixed_point", + "transform_algebraic_expression", + "transform_all_constraints", + "transform_all_constraints_trace_time", + "transform_dimension", + "transform_to_z3", + "transform_var", +] + + +# z3 is an optional dependency with no type stubs, so we use aliases for its types. +_Z3Expr: TypeAlias = Any +_Z3Result: TypeAlias = Any from torch.fx.experimental.migrate_gradual_types.constraint import ( BinConstraintD, BinConstraintT, BVar, Conj, + Constraint, Disj, DVar, F, @@ -32,7 +52,9 @@ op_neq, op_sub, ) -from torch.fx.tensor_type import Dyn, TensorType +from torch.fx.graph import Graph +from torch.fx.node import Node +from torch.fx.tensor_type import _DynType, Dyn, TensorType try: @@ -46,7 +68,9 @@ HAS_Z3 = True - def transform_to_z3(constraint, counter, dimension_dict): + def transform_to_z3( + constraint: Constraint, counter: int, dimension_dict: dict[int, int] + ) -> tuple[_Z3Expr, int]: if isinstance(constraint, Conj): conjuncts = [] for c in constraint.conjucts: @@ -69,8 +93,16 @@ def transform_to_z3(constraint, counter, dimension_dict): elif isinstance(constraint, BinConstraintT): if constraint.op == op_eq: - lhs, counter = transform_var(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_var(constraint.rhs, counter, dimension_dict) + lhs, counter = transform_var( + constraint.lhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, + ) + rhs, counter = transform_var( + constraint.rhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, + ) return (lhs == rhs), counter else: @@ -80,7 +112,9 @@ def transform_to_z3(constraint, counter, dimension_dict): if constraint.op == op_eq: if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs): transformed_rhs, counter = transform_to_z3( - constraint.rhs, counter, dimension_dict + constraint.rhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) transformed_lhs = z3.Bool(constraint.lhs.c) return transformed_lhs == transformed_rhs, counter @@ -88,10 +122,14 @@ def transform_to_z3(constraint, counter, dimension_dict): elif is_dim(constraint.lhs) and is_dim(constraint.rhs): # with dimension transformations we consider the encoding lhs, counter = transform_dimension( - constraint.lhs, counter, dimension_dict + constraint.lhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) rhs, counter = transform_dimension( - constraint.rhs, counter, dimension_dict + constraint.rhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) return lhs == rhs, counter @@ -99,10 +137,14 @@ def transform_to_z3(constraint, counter, dimension_dict): # then we have an algebraic expression which means that we disregard the # first element of the encoding lhs, counter = transform_algebraic_expression( - constraint.lhs, counter, dimension_dict + constraint.lhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) rhs, counter = transform_algebraic_expression( - constraint.rhs, counter, dimension_dict + constraint.rhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) return lhs == rhs, counter @@ -113,15 +155,19 @@ def transform_to_z3(constraint, counter, dimension_dict): if not is_dim(constraint.rhs): raise AssertionError("Expected rhs to be a dimension") lhs, counter = transform_dimension( - constraint.lhs, counter, dimension_dict + constraint.lhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) rhs, counter = transform_dimension( - constraint.rhs, counter, dimension_dict + constraint.rhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) if constraint.rhs == Dyn or constraint.lhs == Dyn: if constraint.rhs == Dyn: return lhs.arg(0) == 1, counter - elif constraint.lhs == Dyn: + else: return rhs.arg(0) == 1, counter # if one of the instances is a number @@ -137,7 +183,7 @@ def transform_to_z3(constraint, counter, dimension_dict): counter, ) - elif isinstance(constraint.rhs, int): + else: return ( z3.Or( [ @@ -173,10 +219,14 @@ def transform_to_z3(constraint, counter, dimension_dict): if not (is_dim(constraint.lhs) and is_dim(constraint.rhs)): raise AssertionError("Expected both lhs and rhs to be dimensions") lhs, counter = transform_algebraic_expression( - constraint.lhs, counter, dimension_dict + constraint.lhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) rhs, counter = transform_algebraic_expression( - constraint.rhs, counter, dimension_dict + constraint.rhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) return lhs <= rhs, counter @@ -184,10 +234,14 @@ def transform_to_z3(constraint, counter, dimension_dict): if not (is_dim(constraint.lhs) and is_dim(constraint.rhs)): raise AssertionError("Expected both lhs and rhs to be dimensions") lhs, counter = transform_algebraic_expression( - constraint.lhs, counter, dimension_dict + constraint.lhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) rhs, counter = transform_algebraic_expression( - constraint.rhs, counter, dimension_dict + constraint.rhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) return lhs > rhs, counter @@ -195,10 +249,14 @@ def transform_to_z3(constraint, counter, dimension_dict): if not (is_dim(constraint.lhs) and is_dim(constraint.rhs)): raise AssertionError("Expected both lhs and rhs to be dimensions") lhs, counter = transform_algebraic_expression( - constraint.lhs, counter, dimension_dict + constraint.lhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) rhs, counter = transform_algebraic_expression( - constraint.rhs, counter, dimension_dict + constraint.rhs, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, ) return lhs < rhs, counter @@ -208,7 +266,11 @@ def transform_to_z3(constraint, counter, dimension_dict): else: raise NotImplementedError("Operation not yet implemented") - def transform_var(tensor, counter, dimension_dict): + def transform_var( + tensor: TVar | TensorType | _DynType, + counter: int, + dimension_dict: dict[int, int], + ) -> tuple[_Z3Expr, int]: """ Transforms tensor variables to a format understood by z3 Args: @@ -217,7 +279,7 @@ def transform_var(tensor, counter, dimension_dict): """ if isinstance(tensor, TensorType): - res = [] + res: list[_Z3Expr] = [] for t in tensor.__args__: transformed, counter = transform_dimension(t, counter, dimension_dict) res.append(transformed) @@ -232,6 +294,10 @@ def transform_var(tensor, counter, dimension_dict): return tensor_type.tensor3(res[0], res[1], res[2]), counter elif len(tensor.__args__) == 4: return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter + else: + raise AssertionError( + f"Unexpected tensor args length: {len(tensor.__args__)}" + ) elif tensor == Dyn: return z3_dyn, counter @@ -239,7 +305,12 @@ def transform_var(tensor, counter, dimension_dict): elif isinstance(tensor, TVar): return z3.Const(tensor.tvar, tensor_type), counter - def transform_dimension(dimension, counter, dimension_dict): + else: + raise NotImplementedError(f"Unsupported tensor type: {type(tensor)}") + + def transform_dimension( + dimension: DVar | int | _DynType, counter: int, dimension_dict: dict[int, int] + ) -> tuple[_Z3Expr, int]: """ Takes a dimension variable or a number and transforms it to a tuple according to our scheme @@ -266,7 +337,14 @@ def transform_dimension(dimension, counter, dimension_dict): dimension_dict[dimension.c] = counter return D(z3.Int(counter), z3.Int(dimension.c)), counter - def transform_algebraic_expression(expr, counter, dimension_dict): + else: + raise NotImplementedError(f"Unsupported dimension type: {type(dimension)}") + + def transform_algebraic_expression( + expr: DVar | int | _DynType | Prod | BinConstraintD, + counter: int, + dimension_dict: dict[int, int], + ) -> tuple[_Z3Expr, int]: """ Transforms an algebraic expression to z3 format Args: @@ -280,7 +358,11 @@ def transform_algebraic_expression(expr, counter, dimension_dict): raise AssertionError("Expected algebraic expression or dimension") if is_dim(expr): - transformed, counter = transform_dimension(expr, counter, dimension_dict) + transformed, counter = transform_dimension( + expr, # pyrefly: ignore[bad-argument-type] + counter, + dimension_dict, + ) return transformed.arg(1), counter elif isinstance(expr, Prod): @@ -294,13 +376,17 @@ def transform_algebraic_expression(expr, counter, dimension_dict): elif is_algebraic_expression(expr): lhs, counter = transform_algebraic_expression( - expr.lhs, counter, dimension_dict + expr.lhs, # pyrefly: ignore[missing-attribute] + counter, + dimension_dict, ) rhs, counter = transform_algebraic_expression( - expr.rhs, counter, dimension_dict + expr.rhs, # pyrefly: ignore[missing-attribute] + counter, + dimension_dict, ) - if expr.op == op_sub: + if expr.op == op_sub: # pyrefly: ignore[missing-attribute] c = lhs - rhs elif expr.op == op_add: @@ -323,31 +409,25 @@ def transform_algebraic_expression(expr, counter, dimension_dict): else: raise RuntimeError - def transform_all_constraints(traced, counter=0): + def transform_all_constraints(traced: torch.nn.Module, counter: int = 0) -> _Z3Expr: """ Given a trace, generates constraints and transforms them to z3 format """ - dimension_dict = {} # type: ignore[var-annotated] + dimension_dict: dict[int, int] = {} generator = ConstraintGenerator(traced) new_constraints, counter = generator.generate_constraints(counter) - # print(new_constraints.conjucts[0]) - # print(*new_constraints.conjucts, sep='\n') - - # transform precision, matching, consistency till obtaining a fixed point new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) - # print(new_constraints) - # print(new_constraints.conjucts) - # new_constraints.conjucts = new_constraints.conjucts[:-1] - # print(*new_constraints.conjucts, sep='\n') transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) # print(transformed) return transformed - def iterate_till_fixed_point(constraints, counter): + def iterate_till_fixed_point( + constraints: Constraint, counter: int + ) -> tuple[Constraint, int]: """ Transform constraints till reaching a fixed point """ @@ -357,7 +437,9 @@ def iterate_till_fixed_point(constraints, counter): constraints, counter = transform_constraint(constraints, counter) return constraints, counter - def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): + def transform_all_constraints_trace_time( + tracer_root: torch.nn.Module, graph: Graph, node: Node, counter: int = 0 + ) -> tuple[_Z3Expr, _Z3Expr]: """ Takes a node and a graph and generates two sets of constraints. One set constraints the node's constraints and another set @@ -373,7 +455,7 @@ def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): its negation. """ - dimension_dict = {} # type: ignore[var-annotated] + dimension_dict: dict[int, int] = {} generator = ConstraintGenerator(tracer_root, graph) new_constraints, counter = generator.generate_constraints(counter) @@ -394,10 +476,14 @@ def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): # we make sure the constraint is of the form: # c = b where b is a boolean expression # and we consider b (constraint.rhs) for transformation + if not isinstance(condition_constraint, BinConstraintD): + raise TypeError(type(condition_constraint)) if not isinstance(condition_constraint.lhs, BVar): raise AssertionError(f"Expected BVar, got {type(condition_constraint.lhs)}") if not is_bool_expr(condition_constraint.rhs): raise AssertionError("Expected bool expression for rhs") + if not isinstance(condition_constraint.rhs, Constraint): + raise TypeError(type(condition_constraint.rhs)) condition_constraint_rhs = condition_constraint.rhs # transform the condition constraint @@ -420,8 +506,12 @@ def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): ) def evaluate_conditional_with_constraints( - tracer_root, graph, node, counter=0, user_constraints=None - ): + tracer_root: torch.nn.Module, + graph: Graph, + node: Node, + counter: int = 0, + user_constraints: _Z3Expr | None = None, + ) -> tuple[_Z3Result, _Z3Result]: """ Given an IR and a node representing a conditional, evaluate the conditional and its negation diff --git a/torch/fx/experimental/normalize.py b/torch/fx/experimental/normalize.py index e490919b3eca1..c7fd16b0443fb 100644 --- a/torch/fx/experimental/normalize.py +++ b/torch/fx/experimental/normalize.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import operator from collections.abc import Callable from typing import Any @@ -37,7 +36,7 @@ class NormalizeArgs(Transformer): def __init__( self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True - ): + ) -> None: super().__init__(module) self.node_map: dict[Proxy, Node] = {} self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs @@ -45,7 +44,7 @@ def __init__( def run_node(self, n: Node) -> Any: args, kwargs = self.fetch_args_kwargs_from_env(n) - def get_type(arg): + def get_type(arg: object) -> Any: if isinstance(arg, fx.Node): return n.meta.get("type") return type(arg) @@ -72,7 +71,7 @@ def call_function( kwargs: dict[str, Any], arg_types: tuple[Any, ...] | None = None, kwarg_types: dict[str, Any] | None = None, - ): + ) -> Proxy: if not callable(target): raise AssertionError(f"Expected callable target, got {type(target)}") new_args_and_kwargs = normalize_function( @@ -93,7 +92,7 @@ def call_function( def call_module( self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] - ): + ) -> Proxy: if not isinstance(target, str): raise AssertionError(f"Expected str target, got {type(target)}") new_args_and_kwargs = normalize_module( @@ -147,9 +146,9 @@ class NormalizeOperators(AnnotateTypesWithSchema): def call_function( self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] - ): + ) -> Proxy: # Normalize operators according to the magic methods implemented on tensors here: - # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950 + # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 if not callable(target): raise AssertionError(f"Expected callable target, got {type(target)}") diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py index 105528da26639..b5a68616c9c49 100644 --- a/torch/fx/experimental/optimization.py +++ b/torch/fx/experimental/optimization.py @@ -1,10 +1,9 @@ -# mypy: allow-untyped-defs import copy import logging import operator import time from collections import defaultdict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from enum import Enum from typing import Any, cast @@ -45,11 +44,11 @@ def _parent_name(target: str) -> tuple[str, str]: # Works for length 2 patterns with 2 modules def matches_module_pattern( - pattern: Iterable[type], node: fx.Node, modules: dict[str, Any] -): + pattern: Iterable[type], node: fx.Node, modules: dict[str, torch.nn.Module] +) -> bool: if len(node.args) == 0: return False - nodes: tuple[Any, fx.Node] = (node.args[0], node) + nodes: tuple[Argument, fx.Node] = (node.args[0], node) for expected_type, current_node in zip(pattern, nodes): if not isinstance(current_node, fx.Node): return False @@ -65,8 +64,8 @@ def matches_module_pattern( def replace_node_module( - node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module -): + node: fx.Node, modules: dict[str, torch.nn.Module], new_module: torch.nn.Module +) -> None: if not isinstance(node.target, str): raise AssertionError(f"Expected str target, got {type(node.target)}") parent_name, name = _parent_name(node.target) @@ -74,7 +73,9 @@ def replace_node_module( setattr(modules[parent_name], name, new_module) -def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module: +def fuse( + model: torch.nn.Module, inplace: bool = False, no_trace: bool = False +) -> torch.nn.Module: """ Fuses convolution/BN and linear/BN layers for inference purposes. Will deepcopy your model by default, but can modify the model inplace as well. @@ -139,7 +140,7 @@ def extract_subgraph( nodes: list[fx.Node], inputs: list[fx.Node], outputs: list[fx.Node], -): +) -> fx.GraphModule: """ Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. """ @@ -183,7 +184,9 @@ def extract_subgraph( } -def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]): +def modules_to_mkldnn( + nodes: list[fx.Node], modules: dict[str, nn.Module] +) -> dict[nn.Module, nn.Module]: """ For each node, if it's a module that can be preconverted into MKLDNN, then we do so and create a mapping to allow us to convert from the MKLDNN @@ -206,10 +209,10 @@ def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]): def reset_modules( - nodes: list[fx.Node], + nodes: Iterable[fx.Node], modules: dict[str, nn.Module], old_modules: dict[nn.Module, nn.Module], -): +) -> None: """ Maps each module that's been changed with `modules_to_mkldnn` back to its original. @@ -224,14 +227,16 @@ def reset_modules( class MklSubgraph: - def __init__(self, fx_graph: fx.Graph): + def __init__(self, fx_graph: fx.Graph) -> None: self.fx_graph = fx_graph self.nodes: list[fx.Node] = [] self.start_nodes: list[fx.Node] = [] self.end_nodes: list[fx.Node] = [] -def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): +def gen_mkl_autotuner( + example_inputs: list[torch.Tensor], iters: int = 10, warmup: int = 1 +) -> Callable[[MklSubgraph], bool]: """ This generates a heuristic that can be passed into `optimize_for_inference` that determines whether a subgraph should be run in MKL by running it with the example_inputs. @@ -254,7 +259,7 @@ def use_mkl_heuristic(graph: MklSubgraph) -> bool: output_args = cast(list[fx.Node], [node.args[0] for node in graph.end_nodes]) submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args) - def benchmark(f): + def benchmark(f: Callable[[], object]) -> float: for _ in range(warmup): f() begin = time.time() @@ -271,7 +276,7 @@ def benchmark(f): reset_modules( submodule.graph.nodes, dict(submodule.named_modules()), - # pyrefly: ignore [bad-argument-type] + # pyrefly: ignore [bad-argument-type] # old_modules is set before this point old_modules, ) no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) @@ -290,11 +295,11 @@ def use_mkl_length(graph: MklSubgraph) -> bool: class UnionFind: - def __init__(self, n): + def __init__(self, n: int) -> None: self.parent: list[int | None] = [None] * n self.size: list[int] = [0] * n - def make_set(self, v: int): + def make_set(self, v: int) -> None: self.parent[v] = v self.size[v] = 1 @@ -307,7 +312,7 @@ def find(self, v: int) -> int: self.parent[v] = self.find(par) return cast(int, self.parent[v]) - def join(self, a: int, b: int): + def join(self, a: int, b: int) -> int | None: a, b = self.find(a), self.find(b) if a == b: return a @@ -425,7 +430,7 @@ class MklSupport(Enum): num_nodes = len(fx_graph.nodes) uf = UnionFind(num_nodes) - def get_color(n): + def get_color(n: fx.Node) -> int | None: if hasattr(n, "color"): # Current node is part of a MKL subgraph return uf.find(n.color) if hasattr(n, "start_color"): # Current node is input to MKL subgraph @@ -463,10 +468,10 @@ def get_color(n): continue if any(i is None for i in cur_colors): raise AssertionError("Found None in cur_colors") - cur_colors = sorted(cur_colors) - node.color = cur_colors[0] - for other_color in cur_colors[1:]: - uf.join(cur_colors[0], other_color) + sorted_colors: list[int] = sorted(cur_colors) # type: ignore[arg-type] + node.color = sorted_colors[0] + for other_color in sorted_colors[1:]: + uf.join(sorted_colors[0], other_color) mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) for node in fx_graph.nodes: diff --git a/torch/fx/experimental/partitioner_utils.py b/torch/fx/experimental/partitioner_utils.py index 3658dd1a9ce96..729938828b9da 100644 --- a/torch/fx/experimental/partitioner_utils.py +++ b/torch/fx/experimental/partitioner_utils.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs from enum import Enum from typing import NamedTuple @@ -19,15 +18,15 @@ def __init__(self, partition_id: int) -> None: self.used_mem_bytes: int = 0 self.logical_device_ids: list[int] = [] - def __str__(self): + def __str__(self) -> str: return str(self.partition_id) - def recalculate_mem_size(self): + def recalculate_mem_size(self) -> None: self.used_mem_bytes = 0 for node in self.nodes: self.used_mem_bytes += get_extra_size_of(node, self.nodes) - def add_node(self, node): + def add_node(self, node: Node) -> None: input_nodes: dict[Node, None] = {} map_arg(node.args, input_nodes.setdefault) map_arg(node.kwargs, input_nodes.setdefault) @@ -38,7 +37,7 @@ def add_node(self, node): self.nodes.add(node) self.recalculate_mem_size() - def remove_node(self, node): + def remove_node(self, node: Node) -> None: # Remove a node only if the node is in the partition if node in self.nodes: self.nodes.remove(node) @@ -151,7 +150,7 @@ def get_top_nodes(partition: Partition) -> list[Node]: top_nodes.append(node) return top_nodes - def dfs_helper(node: Node, partition_latency) -> PartitionLatency: + def dfs_helper(node: Node, partition_latency: PartitionLatency) -> PartitionLatency: """Given a top node of a partition, this function returns the latency of the critical path in the partition """ @@ -235,7 +234,7 @@ def get_comm_latency_between( parent_partition: Partition, child_partition: Partition, transfer_rate_bytes_per_sec: float, -): +) -> float: """Given two partitions (parent and child), calculate the communication latency between the two. """ @@ -271,7 +270,7 @@ def get_latency_of_partitioned_graph( partitions: list[Partition], partition_to_latency_mapping: dict[Partition, PartitionLatency], transfer_rate_bytes_per_sec: float, -): +) -> float: """Given all partitions in a graph, find the critical path among all partitions and return its latency as the latency of the whole graph """ diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index b74e82a79e5e8..2da681922bbf5 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-decorators # Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # @@ -7,6 +6,7 @@ from __future__ import annotations +import contextvars import functools import inspect import logging @@ -14,7 +14,6 @@ import threading import typing import typing_extensions -import weakref from collections import defaultdict, OrderedDict from collections.abc import Callable, Generator, Mapping, Sequence from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext @@ -22,6 +21,7 @@ from typing import ( Any, Concatenate, + Literal, overload, Protocol, TYPE_CHECKING, @@ -40,10 +40,16 @@ from torch import SymBool, SymInt, Tensor from torch._dispatch.python import enable_python_dispatcher from torch._library.fake_class_registry import FakeScriptObject -from torch._library.opaque_object import is_opaque_value, OpaqueType +from torch._library.opaque_object import ( + get_reconstruct_fn, + is_opaque_reference_type, + is_opaque_value, + is_opaque_value_type, + OpaqueBase, + should_hoist, +) from torch._logging import trace_structured -from torch._opaque_base import OpaqueBase -from torch._ops import HigherOrderOperator +from torch._ops import HigherOrderOperator, OpOverload from torch._subclasses.fake_impls import fast_detach from torch._subclasses.fake_tensor import ( FakeTensor, @@ -70,6 +76,7 @@ _unset_infra_mode, autograd_would_have_decomposed, TorchDispatchMode, + TraceableWrapperSubclass, ) from torch.utils._stats import count from torch.utils._thunk import Thunk @@ -86,8 +93,6 @@ import sympy from torch._higher_order_ops.utils import FunctionalizeCtxWrapper - from torch._ops import OpOverload - from torch.fx._symbolic_trace import PHBase from torch.types import BoolLikeType, FloatLikeType, IntLikeType __all__ = [ @@ -105,6 +110,7 @@ ] _ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"] +_TracingMode: TypeAlias = Literal["real", "fake", "symbolic"] _AnyScriptObject = (torch.ScriptObject, FakeScriptObject) _AnyScriptObjectType = torch.ScriptObject | FakeScriptObject @@ -115,7 +121,9 @@ log = logging.getLogger(__name__) not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") -CURRENT_DECOMPOSITION_TABLE: Mapping[OpOverload, Callable] = {} +CURRENT_DECOMPOSITION_TABLE: contextvars.ContextVar[ + Mapping[OpOverload, Callable[..., Any]] +] = contextvars.ContextVar("CURRENT_DECOMPOSITION_TABLE") CONSTANT_NUMEL_LIMIT = 1 @@ -125,7 +133,6 @@ R = TypeVar("R") _Ts = TypeVarTuple("_Ts") -null_ctx_type = type(nullcontext) # We currently convert all SymInt to proxies before we use them. # This could plausibly be handled at the Dynamo level. pytree.register_pytree_node( @@ -156,15 +163,14 @@ def fake_signature(fn: Callable[_P, R], nargs: int) -> Callable[_P, R]: @contextmanager def decompose( - decomposition_table: Mapping[OpOverload, Callable] | None, -) -> Generator[Mapping[OpOverload, Callable], None, None]: - global CURRENT_DECOMPOSITION_TABLE - old_decomposition_table = CURRENT_DECOMPOSITION_TABLE - CURRENT_DECOMPOSITION_TABLE = decomposition_table or {} + decomposition_table: Mapping[OpOverload, Callable[..., Any]] | None, +) -> Generator[Mapping[OpOverload, Callable[..., Any]], None, None]: + table = decomposition_table or {} + token = CURRENT_DECOMPOSITION_TABLE.set(table) try: - yield CURRENT_DECOMPOSITION_TABLE + yield table finally: - CURRENT_DECOMPOSITION_TABLE = old_decomposition_table + CURRENT_DECOMPOSITION_TABLE.reset(token) # ensure we cannot collide with other properties @@ -258,7 +264,7 @@ def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> No @overload def set_proxy_slot( - obj: _AnyScriptObjectType | OpaqueType, tracer: _ProxyTracer, proxy: Proxy + obj: _AnyScriptObjectType | OpaqueBase, tracer: _ProxyTracer, proxy: Proxy ) -> None: ... @@ -269,7 +275,7 @@ def set_proxy_slot( def set_proxy_slot( - obj: PySymType | _AnyScriptObjectType | Tensor | OpaqueType, + obj: PySymType | _AnyScriptObjectType | Tensor | OpaqueBase, tracer: _ProxyTracer, proxy: object, ) -> None: @@ -286,7 +292,6 @@ def set_proxy_slot( ): tracer.tensor_tracker[obj] = proxy elif isinstance(obj, (_AnyScriptObject)) or is_opaque_value(obj): - # We DO want to clobber proxies, with a similar rationale as for tensors. if not isinstance(proxy, Proxy): raise AssertionError(f"Expected Proxy, got {type(proxy)}") # ScriptObject (actual C++ torchbind) uses _WeakHashRef-keyed tracker @@ -296,7 +301,30 @@ def set_proxy_slot( if isinstance(obj, torch.ScriptObject): tracer.script_object_tracker[obj] = proxy else: - tracer.opaque_tracker[obj] = proxy + # NB: Never clobber a pre-existing proxy for the same + # underlying real object. Multiple FakeScriptObject wrappers + # can share the same real_obj (e.g. primal vs tangent + # placeholders during joint graph tracing). We always keep the + # first proxy registered, with the same rationale as the + # symnode_tracker first-one-wins policy below: primals are + # registered first, so this avoids spurious tangent dependencies + # in forward outputs (which would break the partitioner). + real_obj = None + if isinstance(obj, FakeScriptObject): + try: + real_obj = object.__getattribute__(obj, "real_obj") + except AttributeError: + pass + + if real_obj is not None: + existing = tracer._opaque_real_obj_proxy.get(id(real_obj)) + if existing is not None: + tracer.opaque_tracker[obj] = existing + else: + tracer.opaque_tracker[obj] = proxy + tracer._opaque_real_obj_proxy[id(real_obj)] = proxy + else: + tracer.opaque_tracker[obj] = proxy else: # NB: Never clobber pre-existing proxy. Although the proxies # are in principle equivalent, when we do graph partitioning @@ -338,7 +366,7 @@ def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool: _PySymProxyType: TypeAlias = Thunk[Proxy] -_OpaqueObjectProxyType: TypeAlias = Thunk[Proxy] +_OpaqueObjectProxyType: TypeAlias = Proxy @overload @@ -406,7 +434,7 @@ def get_proxy_slot( @overload def get_proxy_slot( - obj: OpaqueType, + obj: OpaqueBase, tracer: _ProxyTracer, default: T, ) -> T | _OpaqueObjectProxyType: ... @@ -425,10 +453,10 @@ def get_proxy_slot( # the transform argument is handy if you need to extract a subfield from # the successfully looked up result (but NOT the default.) def get_proxy_slot( - obj: Tensor | _AnyScriptObjectType | PySymType | OpaqueType, + obj: Tensor | _AnyScriptObjectType | PySymType | OpaqueBase, tracer: _ProxyTracer, default: object = no_default, - transform: Callable = lambda x: x, + transform: Callable[..., Any] = lambda x: x, ) -> object: tracker: Any if isinstance(obj, Tensor): @@ -457,9 +485,16 @@ def get_proxy_slot( else: # Attempt to build it from first principles. _build_proxy_for_sym_expr(tracer, obj.node.expr, obj) - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] value = tracker.get(obj) + if value is None and isinstance(obj, FakeScriptObject): + # A new FakeScriptObject wrapping the same real_obj may have been + # created (e.g. output flattening in unwrap_tensor_subclasses calls + # maybe_to_fake_obj which always mints a fresh wrapper). Fall back + # to the real-object dedup map that set_proxy_slot maintains. + value = tracer._opaque_real_obj_proxy.get(id(obj.real_obj)) + if value is None: # We don't know this value - return the default. if isinstance(default, _NoDefault): @@ -472,11 +507,11 @@ def get_proxy_slot( return res -# Recursively traverses tensor subclasses, +# Recursively traverses traceable wrapper subclasses, # returnining an (unordered) list of Proxy objects that are tracked # for all inner tensors, given the current extant proxy mode. # Returns an empty list if no proxy mode is active. -def _get_proxies(t: torch.Tensor) -> list[Proxy]: +def _get_proxies(t: torch.Tensor | TraceableWrapperSubclass) -> list[Proxy]: proxies = [] mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) if mode is None: @@ -634,7 +669,7 @@ def snapshot_fake(val: Tensor, include_real: bool = False) -> Tensor | None: _ExtractValType: TypeAlias = ( None | PySymType - | OpaqueType + | OpaqueBase | _AnyScriptObjectType | BackwardState | list["_ExtractValType"] @@ -652,7 +687,7 @@ def extract_val(val: _ExtractValType, include_real: bool = False) -> _ExtractVal return snapshot_fake(val, include_real=include_real) elif isinstance(val, py_sym_types): return val - elif isinstance(val, _AnyScriptObject): + elif isinstance(val, (_AnyScriptObject, OpaqueBase)): return val elif isinstance(val, BackwardState): return val @@ -900,6 +935,15 @@ def wrap_with_proxy( elif isinstance(e, _AnyScriptObject) or is_opaque_value(e): if not isinstance(proxy, Proxy): raise AssertionError(f"Expected Proxy, got {type(proxy)}") + # Non-hoisted opaque value types should be baked as constants + # in the graph, not tracked as proxy references. This matches + # dynamo's behavior where non-hoisted values are not graph inputs. + if ( + is_opaque_value_type(type(e)) # pyrefly: ignore[bad-argument-type] + and not should_hoist(type(e)) + ): + set_meta(proxy, e) + return set_proxy_slot(e, tracer, proxy) set_meta(proxy, e) elif isinstance(e, (tuple, list)): @@ -1000,13 +1044,13 @@ def fetch_object_proxy( @overload def fetch_object_proxy( - tracer: _ProxyTracer, t: OpaqueType + tracer: _ProxyTracer, t: OpaqueBase ) -> _OpaqueObjectProxyType | PySymType: ... def fetch_object_proxy( tracer: _ProxyTracer, - t: Tensor | _AnyScriptObjectType | PySymType | OpaqueType, + t: Tensor | _AnyScriptObjectType | PySymType | OpaqueBase, ) -> object: return get_proxy_slot(t, tracer, t) @@ -1286,7 +1330,13 @@ def tensor_numel_in_limit(t: Tensor) -> bool: else: constant = None - track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer) + track_tensor_tree( + out, + proxy_out, + # pyrefly: ignore[bad-argument-type] + constant=constant, + tracer=tracer, + ) _maybe_record_pointwise_barrier(func, proxy_mode) return out @@ -1328,35 +1378,42 @@ class _SympyExprTrackerValue: value: PySymType +def _init_proxy_trackers(tracer: PythonKeyTracer | _GraphAppendingTracerEx) -> None: + """Initialize the tracker dictionaries shared by PythonKeyTracer and _GraphAppendingTracerEx.""" + tracer.tensor_tracker = WeakTensorKeyDictionary() + tracer.symnode_tracker = _SymNodeDict() + tracer.script_object_tracker = WeakIdKeyDictionary(dict=None, ref_type=_WeakHashRef) + tracer.opaque_tracker = WeakIdKeyDictionary() + tracer._opaque_real_obj_proxy = {} + tracer.sympy_expr_tracker = {} + # Stores the torch function that was called during tracing + tracer.torch_fn_metadata = None + # Stores the counts for every torch function called. This is to help + # distinguish between different calls to the same torch function. + tracer.torch_fn_counts = {} + tracer.enable_thunkify = False + + class PythonKeyTracer(Tracer): # ScriptObject uses _WeakHashRef because the same C++ IValue can produce # different Python wrapper objects, so Python id() won't match. script_object_tracker: MutableMapping[torch.ScriptObject, Proxy] # FakeScriptObject/OpaqueBase uses WeakIdRef because distinct objects that # are value-equal (e.g. primal vs tangent opaques) must be tracked separately. - opaque_tracker: MutableMapping[FakeScriptObject | OpaqueBase | OpaqueType, Proxy] + opaque_tracker: MutableMapping[FakeScriptObject | OpaqueBase, Proxy] + # Maps id(real_obj) -> proxy for opaque FSOs, so that multiple FSO wrappers + # of the same real object (e.g. primal vs tangent) resolve to one proxy. + _opaque_real_obj_proxy: dict[int, Proxy] symnode_tracker: _SymNodeDict sympy_expr_tracker: dict[sympy.Symbol, _SympyExprTrackerValue] tensor_tracker: MutableMapping[Tensor, _ProxyTensor] + torch_fn_metadata: OpOverload | None torch_fn_counts: dict[OpOverload, int] enable_thunkify: bool = False def __init__(self) -> None: super().__init__(autowrap_modules=()) # type: ignore[arg-type] - self.tensor_tracker = WeakTensorKeyDictionary() - self.symnode_tracker = _SymNodeDict() - self.script_object_tracker = WeakIdKeyDictionary( - dict=None, ref_type=_WeakHashRef - ) - self.opaque_tracker = WeakIdKeyDictionary() - self.sympy_expr_tracker = {} - - # Stores the torch function that was called during tracing - self.torch_fn_metadata = None - # Stores the counts for every torch function called. This is to help - # distinguish between different calls to the same torch function. - self.torch_fn_counts = {} - self.enable_thunkify = False + _init_proxy_trackers(self) # In general, we don't want to make modules leaves. In principle, users of # this tracer might want to override this in order to turn a couple specific @@ -1390,8 +1447,75 @@ def create_arg(self, a: object) -> fx.node.Node: if a.node.constant is None: raise AssertionError("a.node.constant should not be None") return a.node.constant + + # Try reconstructing untracked opaque reference types from existing + # graph inputs (e.g. derive a DeviceMesh submesh from its root mesh). + if isinstance(a, (FakeScriptObject, OpaqueBase)): + node = self._try_reconstruct_opaque(a) + if node is not None: + return node + return super().create_arg(a) # type: ignore[return-value] + def _try_reconstruct_opaque( + self, a: FakeScriptObject | OpaqueBase + ) -> fx.node.Node | None: + """Try to reconstruct an opaque object from existing graph inputs. + + When make_fx encounters an untracked opaque reference type (e.g. a + DeviceMesh submesh captured by a backward closure), this method checks + if the type has a registered reconstruct_fn that can derive the object + from inputs already in the graph. Returns an FX Node on success, + None on failure (falls back to get_attr constant). + """ + real_obj: OpaqueBase = a.real_obj if isinstance(a, FakeScriptObject) else a + + if not is_opaque_reference_type(type(real_obj)): + return None + + reconstruct_fn = get_reconstruct_fn(type(real_obj)) + if reconstruct_fn is None: + return None + + def get_tracked_proxy(obj: OpaqueBase) -> Proxy | None: + proxy = self._opaque_real_obj_proxy.get(id(obj)) + if proxy is not None: + return proxy + proxy = get_proxy_slot(obj, self, None) + if proxy is not None: + return proxy + # The object may be identity-different but equal to a tracked + # FSO's real_obj. This happens because maybe_to_fake_obj creates + # a new DeviceMesh wrapper, so the FSO's real_obj is a distinct + # object from the one held by submesh._root_mesh. Equality + # comparison is needed, not identity. + for tracked_obj, p in self.opaque_tracker.items(): + if not isinstance(tracked_obj, FakeScriptObject): + continue + if tracked_obj.real_obj == obj: + return p + return None + + result = reconstruct_fn(real_obj, get_tracked_proxy, self) + if result is None: + return None + + # The reconstruct_fn dispatches through a custom op with a Proxy + # argument, which goes through Proxy.__torch_function__. That path + # creates a graph node but does NOT populate meta["val"] (unlike the + # ProxyTorchDispatchMode.__torch_dispatch__ path which calls + # track_tensor_tree). Set it here so downstream consumers (e.g. the + # min-cut partitioner) can classify the node correctly. + if "val" not in result.node.meta: + set_meta(result, a) + + # Also register for the *input* object so dedup works for re-encounters. + set_proxy_slot(a, self, result) + if id(real_obj) not in self._opaque_real_obj_proxy: + self._opaque_real_obj_proxy[id(real_obj)] = result + + return result.node + @overload def unwrap_proxy(self, e: Tensor) -> Proxy | Tensor: ... @@ -1543,9 +1667,9 @@ def context_manager_fn() -> Generator[TorchFunctionMode | None, None, None]: @torch._disable_dynamo def dispatch_trace( - root: Module | Callable, + root: Module | Callable[..., Any], tracer: Tracer, - concrete_args: tuple[Any, ...] | None = None, + concrete_args: tuple[object, ...] | None = None, ) -> GraphModule: graph = tracer.trace(root, concrete_args) # type: ignore[arg-type] @@ -1636,21 +1760,22 @@ def get_sym_proxy_slot(t: PySymType) -> Proxy: # TODO: Make downstream users of this work with OperatorBase -ORIGINAL_ATEN: object | None = None +ORIGINAL_ATEN: contextvars.ContextVar[object | None] = contextvars.ContextVar( + "ORIGINAL_ATEN", default=None +) @contextmanager def set_original_aten_op( func: OpOverload | torch._ops.HigherOrderOperator, ) -> Generator[None, None, None]: - global ORIGINAL_ATEN - if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta(): - ORIGINAL_ATEN = func + if ORIGINAL_ATEN.get() is None and fx_traceback.has_preserved_node_meta(): + token = ORIGINAL_ATEN.set(func) fx_traceback.current_meta["original_aten"] = func try: yield finally: - ORIGINAL_ATEN = None + ORIGINAL_ATEN.reset(token) fx_traceback.current_meta["original_aten"] = None else: yield @@ -1692,7 +1817,7 @@ def __init__(self, tracer: _ProxyTracer) -> None: def __torch_function__( self, - func: OpOverload | Callable, + func: OpOverload | Callable[..., Any], types: tuple[torch._C._TensorMeta, ...], args: tuple[object, ...] = (), kwargs: dict[str, object] | None = None, @@ -1731,16 +1856,32 @@ def __torch_function__( torch._functorch.predispatch._vmap_increment_nesting, torch._functorch.predispatch._vmap_decrement_nesting, torch._functorch.vmap.lazy_load_decompositions, + torch._functorch.predispatch._make_dual, + torch._functorch.predispatch._unpack_dual, + torch._functorch.predispatch._jvp_increment_nesting, + torch._functorch.predispatch._jvp_decrement_nesting, + torch._functorch.predispatch._unwrap_for_grad, + torch._functorch.predispatch._enter_dual_level, + torch._functorch.predispatch._exit_dual_level, ]: _, proxies, _ = _fetch_proxies_and_all_constant_flag(args, self.tracer) out_proxy = self.tracer.create_proxy( "call_function", func, proxies, - {}, + kwargs, ) res = func(*args, **kwargs) - track_tensor_tree(res, out_proxy, constant=None, tracer=self.tracer) + # When JVP transforms are active, snapshot_fake calls detach + # which goes through C++ functorch dispatch keys, potentially + # re-wrapping the result as a grad tracking tensor. + # Temporarily disable functorch transforms during tracking to + # prevent this corruption of meta["val"]. + if func in _jvp_predispatch_functions: + with torch._C._DisableFuncTorch(): + track_tensor_tree(res, out_proxy, constant=None, tracer=self.tracer) + else: + track_tensor_tree(res, out_proxy, constant=None, tracer=self.tracer) return res return func(*args, **kwargs) @@ -1750,6 +1891,37 @@ def __torch_function__( ) +# JVP predispatch functions need special handling during tracing: +# _DisableFuncTorch prevents grad tracking tensor corruption in snapshot_fake. +_jvp_predispatch_functions = frozenset( + { + torch._functorch.predispatch._make_dual, + torch._functorch.predispatch._unpack_dual, + torch._functorch.predispatch._jvp_increment_nesting, + torch._functorch.predispatch._jvp_decrement_nesting, + torch._functorch.predispatch._unwrap_for_grad, + torch._functorch.predispatch._enter_dual_level, + torch._functorch.predispatch._exit_dual_level, + } +) + +# These JVP predispatch wrappers manage transform nesting/level state +# and must survive FX dead code elimination even with zero output users. +# Registered here rather than in predispatch.py to avoid circular imports +# (predispatch.py loads before torch.fx during torch.autograd init). +from torch.fx.node import _side_effectful_functions + + +_side_effectful_functions.update( + { + torch._functorch.predispatch._jvp_increment_nesting, + torch._functorch.predispatch._jvp_decrement_nesting, + torch._functorch.predispatch._enter_dual_level, + torch._functorch.predispatch._exit_dual_level, + } +) + + class ProxyTorchDispatchMode(TorchDispatchMode): # Ensure this is read-only; this exists only for legacy reasons @property @@ -1759,7 +1931,7 @@ def enable_tracing(self) -> bool: def __init__( self, tracer: _ProxyTracer, - tracing_mode: str, + tracing_mode: _TracingMode, pre_dispatch: bool = False, _allow_fake_constant: bool = False, _error_on_data_dependent_ops: bool = True, @@ -1904,8 +2076,11 @@ def _compute_proxy( class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer): script_object_tracker: MutableMapping[torch.ScriptObject, Proxy] - opaque_tracker: MutableMapping[FakeScriptObject | OpaqueBase | OpaqueType, Proxy] - symnode_tracker: MutableMapping[PySymType, _PySymProxyType] + opaque_tracker: MutableMapping[FakeScriptObject | OpaqueBase, Proxy] + # Maps id(real_obj) -> proxy for opaque FSOs, so that multiple FSO wrappers + # of the same real object (e.g. primal vs tangent) resolve to one proxy. + _opaque_real_obj_proxy: dict[int, Proxy] + symnode_tracker: _SymNodeDict tensor_tracker: MutableMapping[Tensor, _ProxyTensor] sympy_expr_tracker: dict[sympy.Symbol, _SympyExprTrackerValue] torch_fn_metadata: OpOverload | None @@ -1914,18 +2089,7 @@ class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer): def __init__(self, graph: fx.graph.Graph) -> None: super().__init__(graph) - self.symnode_tracker = weakref.WeakKeyDictionary() - self.tensor_tracker = WeakTensorKeyDictionary() - self.sympy_expr_tracker = {} - self.script_object_tracker = WeakIdKeyDictionary( - dict=None, ref_type=_WeakHashRef - ) - self.opaque_tracker = WeakIdKeyDictionary() - # Stores the torch function that was called during tracing - self.torch_fn_metadata = None - # Stores the counts for every torch function called. This is to help - # distinguish between different calls to the same torch function. - self.torch_fn_counts = {} + _init_proxy_trackers(self) # TODO: I'm not sure what the point of this class is; you can just @@ -1935,7 +2099,7 @@ def __init__( self, module: fx.GraphModule, new_graph: fx.Graph, - decomposition_table: Mapping[OpOverload, Callable] | None = None, + decomposition_table: Mapping[OpOverload, Callable[..., Any]] | None = None, **kwargs: object, ) -> None: super().__init__(module, **kwargs) # type: ignore[arg-type] @@ -2002,7 +2166,7 @@ def __init__( self, module: fx.GraphModule, should_decompose: Callable[[fx.Node], bool], - decomposition_table: Mapping[OpOverload, Callable], + decomposition_table: Mapping[OpOverload, Callable[..., Any]] | None, **kwargs: object, ) -> None: """ @@ -2017,7 +2181,7 @@ def __init__( def recursive_wrap( gm: fx.GraphModule, should_decompose: Callable[[fx.Node], bool], - decomposition_table: Mapping[OpOverload, Callable], + decomposition_table: Mapping[OpOverload, Callable[..., Any]] | None, **kwargs: object, ) -> _SelectiveDecomposeInterpreter: """ @@ -2048,7 +2212,7 @@ def recursive_wrap( gm, should_decompose, decomposition_table, **kwargs ) - def run_node(self, n): + def run_node(self, n: fx.Node) -> Any: if self.should_decompose(n): with decompose(self.decomposition_table): result = super().run_node(n) @@ -2059,9 +2223,9 @@ def run_node(self, n): def selective_decompose( joint_gm: fx.GraphModule, - *args, - decomposition, - should_decompose, + *args: object, + decomposition: Mapping[OpOverload, Callable[..., Any]] | None, + should_decompose: Callable[..., bool], trace_joint_graph: bool, ) -> fx.GraphModule: """Retrace a joint graph module and selectively apply decomposition.""" @@ -2070,13 +2234,13 @@ def selective_decompose( # the arg name, primals and tangents, are important. # make_fx keeps the name in the traced graph and partitioner later relies # on the name to partition joint graph correctly. - def wrap_fn(primals: list[Any], tangents: list[Any]): + def wrap_fn(primals: list[Any], tangents: list[Any]) -> Any: return _SelectiveDecomposeInterpreter.recursive_wrap( joint_gm, should_decompose, decomposition ).run(*args) else: - def wrap_fn(*args): + def wrap_fn(*args: Any) -> Any: return _SelectiveDecomposeInterpreter.recursive_wrap( joint_gm, should_decompose, decomposition ).run(*args) @@ -2281,7 +2445,7 @@ def getattr( return self.attr_proxy_map[attr_val] def trace( # type: ignore[override] - self, root: Module | Callable, concrete_args: dict[str, object] | None + self, root: Module | Callable[..., Any], concrete_args: dict[str, object] | None ) -> fx.Graph: res = super().trace(root, concrete_args) @@ -2352,7 +2516,7 @@ def _delete_proxy_attr(obj: Module, target: str) -> bool: def call_module( self, m: Module, - forward: Callable, + forward: Callable[..., Any], args: tuple[object, ...], kwargs: dict[str, object], ) -> None: @@ -2379,7 +2543,7 @@ def call_module( "This might be because the module was not properly registered " "as a submodule, which is not good practice. We will trace " "through the module without recording stack information.", - str(m), + m, ) return forward(*args, **kwargs) @@ -2411,12 +2575,14 @@ def create_node(self, *args: object, **kwargs: object) -> fx.node.Node: # torch_fn if ( node.op == "call_function" - and self.torch_fn_metadata is not None + and (torch_fn := self.torch_fn_metadata) is not None and "torch_fn" not in node.meta ): node.meta["torch_fn"] = ( - f"{self.torch_fn_metadata.__name__}_{self.torch_fn_counts[self.torch_fn_metadata]}", - f"{self.torch_fn_metadata.__class__.__name__}.{self.torch_fn_metadata.__name__}", + # pyrefly: ignore[missing-attribute,bad-index] + f"{torch_fn.__name__}_{self.torch_fn_counts[torch_fn]}", + # pyrefly: ignore[missing-attribute] + f"{torch_fn.__class__.__name__}.{torch_fn.__name__}", ) return node @@ -2425,8 +2591,8 @@ def create_node(self, *args: object, **kwargs: object) -> fx.node.Node: class _MakefxTracer: def __init__( self, - decomposition_table: Mapping[OpOverload, Callable] | None, - tracing_mode: str, + decomposition_table: Mapping[OpOverload, Callable[..., Any]] | None, + tracing_mode: _TracingMode, _allow_non_fake_inputs: bool, pre_dispatch: bool, record_module_stack: bool, @@ -2439,13 +2605,13 @@ def __init__( ) -> None: # Configurations that are used to initialize the context managers and their states. # Should not modify them during tracing. - self.decomposition_table: dict[OpOverload, Callable] = dict( + self.decomposition_table: dict[OpOverload, Callable[..., Any]] = dict( decomposition_table or {} ) self.decomposition_table.setdefault( torch.ops.aten.sym_numel.default, torch._decomp.decompositions.sym_numel ) - self.tracing_mode: str = tracing_mode + self.tracing_mode: _TracingMode = tracing_mode self._allow_non_fake_inputs: bool = _allow_non_fake_inputs self.pre_dispatch: bool = pre_dispatch self.record_module_stack: bool = record_module_stack @@ -2457,13 +2623,13 @@ def __init__( # Remember to specify how to initialize it from user inputs and from parent tracer whenever # adding new modes in _MakefxTracer. self.fake_tensor_mode: FakeTensorMode | None = None - self.proxy_mode: nullcontext | ProxyTorchDispatchMode = nullcontext() - self.proxy_function_mode: nullcontext | PreDispatchTorchFunctionMode = ( + self.proxy_mode: nullcontext[None] | ProxyTorchDispatchMode = nullcontext() + self.proxy_function_mode: nullcontext[None] | PreDispatchTorchFunctionMode = ( nullcontext() ) self.fx_tracer: PythonKeyTracer | None = None - self.python_dispatcher_mode: nullcontext | Any = nullcontext() - self.torch_fn_metadata_mode: nullcontext | TorchFunctionMetadataMode = ( + self.python_dispatcher_mode: nullcontext[None] | Any = nullcontext() + self.torch_fn_metadata_mode: nullcontext[None] | TorchFunctionMetadataMode = ( nullcontext() ) self.record_stack_traces = record_stack_traces @@ -2484,11 +2650,11 @@ def _checkpoint_modes(self) -> list[Any]: def _restore_modes( self, prev_fake_tensor_mode: FakeTensorMode | None, - prev_proxy_mode: nullcontext | ProxyTorchDispatchMode, - prev_proxy_function_mode: nullcontext | PreDispatchTorchFunctionMode, + prev_proxy_mode: nullcontext[None] | ProxyTorchDispatchMode, + prev_proxy_function_mode: nullcontext[None] | PreDispatchTorchFunctionMode, prev_fx_tracer: PythonKeyTracer | None, - prev_python_dispatcher_mode: nullcontext | Any, - prev_torch_fn_metadata_mode: nullcontext | TorchFunctionMetadataMode, + prev_python_dispatcher_mode: nullcontext[None] | Any, + prev_torch_fn_metadata_mode: nullcontext[None] | TorchFunctionMetadataMode, ) -> None: self.fake_tensor_mode = prev_fake_tensor_mode self.proxy_mode = prev_proxy_mode @@ -2499,7 +2665,7 @@ def _restore_modes( @contextmanager def _init_modes_from_inputs( - self, f: Callable, args: tuple[object, ...] + self, f: Callable[..., Any], args: tuple[object, ...] ) -> Generator[None, None, None]: prev_modes = self._checkpoint_modes() try: @@ -2614,7 +2780,57 @@ def _create_sub_fx_tracer(parent_tracer: _ProxyTracer) -> PythonKeyTracer: finally: self._restore_modes(*prev_modes) - def _trace_inner(self, f: Callable, *args: object) -> GraphModule: + def _convert_args_to_fake(self, args: T) -> T: + if self.tracing_mode == "real": + return args + + arg_count = 0 + + def inner_wrap_fake(x: object) -> object: + nonlocal arg_count + # TODO: it would be nice to line these up with the names + # FX will choose for the placeholders, but we don't + # actually know what the names will be at this point yet + # NB: the Source here is actually meaningless + from torch._dynamo.source import ConstantSource + + if self.fake_tensor_mode is None: + raise AssertionError("fake_tensor_mode should not be None") + source = ConstantSource(f"input{arg_count}") + if isinstance(x, Tensor): + arg_count += 1 + return self.fake_tensor_mode.from_tensor(x, source=source) + # NB: don't match on bools + elif type(x) is int and self.tracing_mode == "symbolic": + if self.fake_tensor_mode.shape_env is None: + raise AssertionError( + "shape_env should be set if tracing with 'symbolic'" + ) + return self.fake_tensor_mode.shape_env.create_symintnode( + self.fake_tensor_mode.shape_env.create_symbol( + x, source, positive=None + ), + hint=x, + source=source, + ) + elif isinstance(x, torch.ScriptObject) or is_opaque_value(x): + if is_opaque_value_type( + type(x) # pyrefly: ignore[bad-argument-type] + ): + return x + return torch._library.fake_class_registry.maybe_to_fake_obj( + self.fake_tensor_mode, x + ) + + if isinstance(x, FakeScriptObject): + raise AssertionError( + f"ScriptObject {x} has been fakified. Cannot wrap_fake it again." + ) + return x + + return pytree.tree_map(inner_wrap_fake, args) + + def _trace_inner(self, f: Callable[..., Any], *args: object) -> GraphModule: # TODO: We need to explicitly import torch._dynamo before calling dispatch_trace, # because dispatch_trace will introduce the lazy import of torch._dynamo, # and some contexts set before calling dispatch_trace will cause problems with the import of torch._dynamo, @@ -2624,66 +2840,16 @@ def _trace_inner(self, f: Callable, *args: object) -> GraphModule: phs = pytree.tree_map(lambda _: torch.fx._symbolic_trace.PH, args) - def _wrap_fake(args: T) -> T: - arg_count = 0 - - def inner_wrap_fake(x: object) -> object: - nonlocal arg_count - # TODO: it would be nice to line these up with the names - # FX will choose for the placeholders, but we don't - # actually know what the names will be at this point yet - # NB: the Source here is actually meaningless - from torch._dynamo.source import ConstantSource - - if self.fake_tensor_mode is None: - raise AssertionError("fake_tensor_mode should not be None") - source = ConstantSource(f"input{arg_count}") - if isinstance(x, Tensor): - arg_count += 1 - return self.fake_tensor_mode.from_tensor(x, source=source) - # NB: don't match on bools - elif type(x) is int and self.tracing_mode == "symbolic": - if self.fake_tensor_mode.shape_env is None: - raise AssertionError( - "shape_env should be set if tracing with 'symbolic'" - ) - return self.fake_tensor_mode.shape_env.create_symintnode( - self.fake_tensor_mode.shape_env.create_symbol( - x, source, positive=None - ), - hint=x, - source=source, - ) - elif isinstance(x, torch.ScriptObject) or is_opaque_value(x): - return torch._library.fake_class_registry.maybe_to_fake_obj( - self.fake_tensor_mode, x - ) + args = self._convert_args_to_fake(args) - if isinstance(x, FakeScriptObject): - raise AssertionError( - f"ScriptObject {x} has been fakified. Cannot wrap_fake it again." - ) - return x - - wrap_fn_map = { - "real": lambda x: x, - "fake": inner_wrap_fake, - "symbolic": inner_wrap_fake, - } - return pytree.tree_map(wrap_fn_map[self.tracing_mode], args) - - def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]: - if ( - not hasattr(inspect.unwrap(f), "__code__") - or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS - ): - # FX doesn't support varargs, so we gotta fake up a wrapper - # TODO: Would be nice to fix this at the source... - return fake_signature(f, len(phs)) - return f - - args = _wrap_fake(args) - func = _wrap_func(f, phs) + # FX doesn't support varargs, so we gotta fake up a wrapper + # TODO: Would be nice to fix this at the source... + func: Callable[..., Any] = f + if ( + not hasattr(inspect.unwrap(f), "__code__") + or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS + ): + func = fake_signature(f, len(phs)) # We disable the autocast cache as the autocast cache causes type conversions on parameters to # check a cache, which introduces untracked tensors into the graph # @@ -2744,18 +2910,21 @@ def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]: t.shape_env = self.fake_tensor_mode.shape_env # type: ignore[assignment] return t - def trace(self, f: Callable, *args: object) -> fx.GraphModule: + def trace(self, f: Callable[..., Any], *args: object) -> fx.GraphModule: with self._init_modes_from_inputs(f, args): return self._trace_inner(f, *args) def is_hop_subgraph_tracer(self) -> bool: return self.parent_tracer is not None - def trace_subgraph(self, f: Callable, *args: object) -> GraphModule: - # Create a new tracer based on parent's config - sub_tracer = _MakefxTracer( - self.decomposition_table, - "real", + def _make_sub_tracer( + self, + decomp_table: Mapping[OpOverload, Callable[..., Any]] | None = None, + tracing_mode: _TracingMode = "real", + ) -> _MakefxTracer: + return _MakefxTracer( + decomp_table if decomp_table is not None else self.decomposition_table, + tracing_mode, self._allow_non_fake_inputs, self.pre_dispatch, self.record_module_stack, @@ -2763,25 +2932,21 @@ def trace_subgraph(self, f: Callable, *args: object) -> GraphModule: self._error_on_data_dependent_ops, parent_tracer=self, ) + + def trace_subgraph(self, f: Callable[..., Any], *args: object) -> GraphModule: + sub_tracer = self._make_sub_tracer() with sub_tracer._init_modes_from_parent(self): return sub_tracer._trace_inner(f, *args) def trace_subgraph_custom_decomp( - self, f: Callable, decomp_table: Mapping[OpOverload, Callable], *args + self, + f: Callable[..., Any], + decomp_table: Mapping[OpOverload, Callable[..., Any]], + *args: object, ) -> GraphModule: if not isinstance(decomp_table, Mapping): raise AssertionError(f"Expected Mapping, got {type(decomp_table)}") - # Create a new tracer based on parent's config, but use a different decomposition table - sub_tracer = _MakefxTracer( - decomp_table, - "real", - self._allow_non_fake_inputs, - self.pre_dispatch, - self.record_module_stack, - self._allow_fake_constant, - self._error_on_data_dependent_ops, - parent_tracer=self, - ) + sub_tracer = self._make_sub_tracer(decomp_table=decomp_table) with sub_tracer._init_modes_from_parent(self): return sub_tracer._trace_inner(f, *args) @@ -2801,9 +2966,9 @@ def _set_make_fx_tracer(tracer: _MakefxTracer) -> Generator[None, None, None]: def make_fx( - f: Callable, - decomposition_table: Mapping[OpOverload, Callable] | None = None, - tracing_mode: str = "real", + f: Callable[..., Any], + decomposition_table: Mapping[OpOverload, Callable[..., Any]] | None = None, + tracing_mode: _TracingMode = "real", _allow_non_fake_inputs: bool = False, *, pre_dispatch: bool = False, @@ -2909,7 +3074,8 @@ def maybe_handle_decomp( ) -> object: from torch._inductor.compiler_bisector import CompilerBisector - if op in CURRENT_DECOMPOSITION_TABLE: + decomp_table = CURRENT_DECOMPOSITION_TABLE.get({}) + if op in decomp_table: if CompilerBisector.disable_subsystem( "aot_eager_decomp_partition", "decomposition", lambda: repr(op) ): @@ -2917,7 +3083,7 @@ def maybe_handle_decomp( with proxy_mode: proxy_mode.decomp_layers += 1 - out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs) + out = decomp_table[op](*args, **kwargs) proxy_mode.decomp_layers -= 1 return out @@ -2925,11 +3091,11 @@ def maybe_handle_decomp( def get_isolated_graphmodule( - func: Callable, + func: Callable[..., Any], args: tuple[object, ...], kwargs: dict[str, object], - tracing_mode: str = "real", - decomposition_table: Mapping[OpOverload, Callable] | None = None, + tracing_mode: _TracingMode = "real", + decomposition_table: Mapping[OpOverload, Callable[..., Any]] | None = None, ) -> GraphModule: """A helper function used to get the GraphModule for the given func. diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index 28ff27ab3f2d1..5cc859a25b30c 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -1,16 +1,26 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + import functools import inspect import itertools import logging -from collections.abc import Callable from dataclasses import dataclass -from typing import Any +from typing import Any, ParamSpec, TYPE_CHECKING, TypeVar + + +_P = ParamSpec("_P") +_R = TypeVar("_R") import torch import torch.utils._pytree as pytree +if TYPE_CHECKING: + from collections.abc import Callable + + from torch.fx.experimental.symbolic_shapes import ShapeEnv, TrackedFake + + log = logging.getLogger(__name__) trace_shape_events_log = torch._logging.getArtifactLogger( __name__, "trace_shape_events" @@ -81,21 +91,21 @@ @dataclass class ShapeEnvEvent: # ShapeEnv method. - f: Callable + f: Callable[..., Any] # Arguments and keyword arguments called with. - args: list[Any] | None = None + args: list[object] | None = None kwargs: dict[str, Any] | None = None # List of tracked_fakes at the time the method was called. - tracked_fakes: list[Any] | None = None + tracked_fakes: list[TrackedFake] | None = None # Name of the captured event. # Used for special handling of particular methods. name: str | None = None # Replay itself, but using shape_env as self. - def run(self, shape_env=None) -> Any: + def run(self, shape_env: ShapeEnv | None = None) -> Any: from torch.fx.experimental.symbolic_shapes import ( is_symbolic, ShapeEnv, @@ -147,7 +157,7 @@ def maybe_convert_node(x: Any) -> Any: return name_to_node[x.name] # Replaces the value of an specific argument by the result of fn. - def replacearg(index: int, key: str, fn: Callable): + def replacearg(index: int, key: str, fn: Callable[..., Any]) -> None: if index < len(args): args[index] = fn(args[index]) if key in kwargs: @@ -196,7 +206,9 @@ def is_defer_runtime_assert(self) -> bool: # 2. SymInt, SymFloat, or SymBool arguments # If we find more than one object of any of the above types, we # also check that the ShapeEnv instance is the same for all of them. -def _extract_shape_env_and_assert_equal(args, kwargs): +def _extract_shape_env_and_assert_equal( + args: tuple[object, ...] | list[object], kwargs: dict[str, object] +) -> ShapeEnv | None: from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes def assert_equal(old: ShapeEnv | None, new: ShapeEnv) -> ShapeEnv: @@ -241,8 +253,8 @@ def assert_equal(old: ShapeEnv | None, new: ShapeEnv) -> ShapeEnv: # - ShapeEnv.guard_or_defer_runtime_assert def record_shapeenv_event( *, save_tracked_fakes: bool = False, name: str | None = None -) -> Callable: - def decorator(fn: Callable) -> Callable: +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: + def decorator(fn: Callable[_P, _R]) -> Callable[_P, _R]: if not callable(fn): raise AssertionError(f"Expected callable, got {type(fn)}") args = inspect.getfullargspec(fn).args @@ -256,7 +268,7 @@ def decorator(fn: Callable) -> Callable: name = fn.__name__ @functools.wraps(fn) - def wrapper(*args, **kwargs): + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: from torch.fx.experimental.symbolic_shapes import ShapeEnv if not isinstance(args[0], ShapeEnv): @@ -269,7 +281,7 @@ def wrapper(*args, **kwargs): ) NEST += 1 - def retlog(r): + def retlog(r: _R) -> _R: trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r) return r @@ -325,7 +337,7 @@ def retlog(r): if not shape_env.should_record_events or shape_env.is_recording: # If ShapeEnv is disabled or already recording an event, re-raise the exception without logging. raise - log.error( # noqa: G201 + log.error( "failed while running %s(*%s, **%s)", name, args[1:], @@ -346,7 +358,7 @@ def retlog(r): # It assumes the first event is the constructor call. # # fn: transforms an old FX node into one corresponding to the newly created ShapeEnv. -def replay_shape_env_events(events): +def replay_shape_env_events(events: list[ShapeEnvEvent]) -> ShapeEnv: from torch.fx.experimental.symbolic_shapes import ShapeEnv constructor_event = events[0] @@ -394,7 +406,7 @@ def dim(self) -> int: return len(self.tensor_size) @staticmethod - def from_fake(fake) -> "FakeTensorMeta": + def from_fake(fake: torch.Tensor) -> FakeTensorMeta: return FakeTensorMeta( fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested ) @@ -446,7 +458,12 @@ def from_fake(fake) -> "FakeTensorMeta": # Checks whether the state of two ShapeEnv are equal w.r.t. the guards # returned by ShapeEnv.produce_guards. -def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value): +def shape_env_check_state_equal( + env1: ShapeEnv, + env2: ShapeEnv, + non_state_variable_names: tuple[str, ...], + map_value: Callable[[str, object], object], +) -> None: # Collect and remove variables that don't necessarily represent the state # of a ShapeEnv. Note: we copy the dictionary so that we don't modify the # instance itself. @@ -476,7 +493,7 @@ def value_to_str(value: Any) -> str: # Here, we allow the value of each field to be mapped, so that we appropriately # compare the two values. def compare_vars( - map_value: Callable[[str, Any], Any], + map_value: Callable[[str, object], object], ) -> list[tuple[str, str, str]]: env1_set, env2_set = set(env1_vars), set(env2_vars) diff --git a/torch/fx/experimental/refinement_types.py b/torch/fx/experimental/refinement_types.py index 8e92163a2139c..f2f5a57db6e34 100644 --- a/torch/fx/experimental/refinement_types.py +++ b/torch/fx/experimental/refinement_types.py @@ -1,5 +1,5 @@ class Equality: - def __init__(self, lhs: object, rhs: object): + def __init__(self, lhs: object, rhs: object) -> None: self.lhs = lhs self.rhs = rhs diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py index 73e9abaa8e30f..6505bb82aa1aa 100644 --- a/torch/fx/experimental/rewriter.py +++ b/torch/fx/experimental/rewriter.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-decorators -# mypy: allow-untyped-defs import ast import copy import functools @@ -31,7 +29,7 @@ class AST_Rewriter(ast.NodeTransformer): # a disable here. This function is an optimization pass and not really # suitable for dynamo tracing anyways. @torch._dynamo.disable - def rewrite(self, fn: FunctionType): + def rewrite(self, fn: FunctionType) -> FunctionType: # Normalize the source lines sourcelines, _ = inspect.getsourcelines(fn) sourcelines = normalize_source_lines(sourcelines) @@ -53,7 +51,9 @@ def rewrite(self, fn: FunctionType): fn_compiled = globals_dict[new_keys[0]] # return the compiled function with the original globals - def change_func_globals(f, globals): + def change_func_globals( + f: FunctionType, globals: dict[str, object] + ) -> FunctionType: """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" # __globals__ is a private member of the function class # so we have to copy the function, f, all of its member, except f.__globals__ @@ -66,12 +66,12 @@ def change_func_globals(f, globals): ) g = functools.update_wrapper(g, f) g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined] - return g + return g # pyrefly: ignore [bad-return] # Return the correct FunctionType object return change_func_globals(fn_compiled, globals=fn.__globals__) - def visit_Assert(self, node): + def visit_Assert(self, node: ast.Assert) -> ast.Expr: """ Swap out the Assert node (Python's `assert`) with a callsite to the symbolically-traceable torch._assert function @@ -93,7 +93,7 @@ def visit_Assert(self, node): # a replacement for the original _assert node return ast.copy_location(expr_wrapper, node) - def visit_AnnAssign(self, node): + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.Assign: """ Swap out Python's AnnAssign with an Assign node where the annotation function is called. Example: @@ -106,6 +106,7 @@ def visit_AnnAssign(self, node): targets=[node.target], value=ast.Call( func=ast.Name(id="annotate", ctx=ast.Load()), + # pyrefly: ignore [bad-argument-type] args=[node.value, node.annotation], keywords=[], ), @@ -115,20 +116,22 @@ def visit_AnnAssign(self, node): class RewritingTracer(Tracer): def trace( self, - root: torch.nn.Module | Callable, + root: torch.nn.Module | Callable[..., Any], concrete_args: dict[str, Any] | None = None, ) -> Graph: return super().trace(_rewrite(root), concrete_args) -def _rewrite(fn: torch.nn.Module | Callable) -> torch.nn.Module | Callable: +def _rewrite( + fn: torch.nn.Module | Callable[..., Any], +) -> torch.nn.Module | Callable[..., Any]: if isinstance(fn, torch.nn.Module): # Rewrite this module's `forward` as well as the `forward`s of # all of this module's recursive descendents. Return the new, # rewritten module hierarchy. - def rewrite_module(m: torch.nn.Module): + def rewrite_module(m: torch.nn.Module) -> torch.nn.Module: class RewrittenModule(torch.nn.Module): - def __init__(self, orig): + def __init__(self, orig: torch.nn.Module) -> None: super().__init__() for k, v in orig.__dict__.items(): if isinstance(v, torch.nn.Module): diff --git a/torch/fx/experimental/schema_type_annotation.py b/torch/fx/experimental/schema_type_annotation.py index b507c17822116..ac31752402612 100644 --- a/torch/fx/experimental/schema_type_annotation.py +++ b/torch/fx/experimental/schema_type_annotation.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import inspect from typing import Any @@ -6,8 +5,13 @@ import torch.fx from torch._jit_internal import boolean_dispatched from torch.fx import Transformer +from torch.fx.graph_module import GraphModule from torch.fx.node import Argument, Target from torch.fx.operator_schemas import _torchscript_type_to_python_type +from torch.fx.proxy import Proxy + + +__all__ = ["AnnotateTypesWithSchema"] class AnnotateTypesWithSchema(Transformer): @@ -31,11 +35,11 @@ class AnnotateTypesWithSchema(Transformer): def __init__( self, - module: torch.nn.Module, + module: GraphModule, annotate_functionals: bool = True, annotate_modules: bool = True, annotate_get_attrs: bool = True, - ): + ) -> None: super().__init__(module) self.annotate_functionals = annotate_functionals self.annotate_modules = annotate_modules @@ -43,7 +47,7 @@ def __init__( def call_function( self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] - ): + ) -> Proxy: python_ret_type = None if self.annotate_functionals and target.__module__ == "torch.nn.functional": target_for_analysis = target @@ -75,7 +79,7 @@ def call_function( def call_module( self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] - ): + ) -> Proxy: python_ret_type = None if not isinstance(target, str): raise AssertionError(f"Expected str target, got {type(target)}") @@ -92,10 +96,10 @@ def call_module( def get_attr( self, - target: torch.fx.node.Target, + target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any], - ): + ) -> Proxy: attr_proxy = super().get_attr(target, args, kwargs) if self.annotate_get_attrs: diff --git a/torch/fx/experimental/shape_inference/infer_shape.py b/torch/fx/experimental/shape_inference/infer_shape.py index 10f5d53712aeb..24965bff4053b 100644 --- a/torch/fx/experimental/shape_inference/infer_shape.py +++ b/torch/fx/experimental/shape_inference/infer_shape.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import copy from collections import defaultdict @@ -10,6 +9,7 @@ infer_symbol_values, ) from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv +from torch.types import IntLikeType from torch.utils import _pytree @@ -18,7 +18,11 @@ """ -def infer_shape(gm, input_tensors): +def infer_shape( + gm: torch.fx.GraphModule, input_tensors: list[torch.Tensor] +) -> ( + tuple[torch.fx.GraphModule, list[torch.Tensor], FakeTensorMode, IntLikeType] | None +): # Prepare environments shape_env = ShapeEnv() fake_mode = FakeTensorMode(shape_env=shape_env, allow_non_fake_inputs=True) @@ -29,11 +33,11 @@ def infer_shape(gm, input_tensors): dim_count += input_tensor.dim() - 1 sample = {f"s{i}": 2 for i in range(dim_count)} - init_symints = [ + init_symints: list[IntLikeType] = [ mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) for k, v in sample.items() ] - symints = copy.deepcopy(init_symints) + symints: list[IntLikeType] = copy.deepcopy(init_symints) symbol_to_idx_dict = {f"s{i}": i for i in range(dim_count)} padding_constraints = defaultdict(list) # type: ignore[var-annotated] @@ -87,7 +91,9 @@ def infer_shape(gm, input_tensors): allowed_try_times -= 1 -def mksym(shape_env, value, source, dynamic_dim): +def mksym( + shape_env: ShapeEnv, value: int, source: LocalSource, dynamic_dim: DimDynamic +) -> IntLikeType: return shape_env.create_symintnode( shape_env.create_symbol( value, diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 5e52592b314a3..5867ec0dbc8c3 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - from __future__ import annotations @@ -23,13 +21,13 @@ import operator import sys from functools import lru_cache, update_wrapper -from typing import TYPE_CHECKING +from typing import Any, overload, TYPE_CHECKING import torch import torch._logging.structured as structured # NB: The sym_* functions are used via getattr() and must be imported here. -from torch import ( # noqa: F401 +from torch import ( sym_float, sym_ite, sym_max, @@ -43,6 +41,11 @@ if TYPE_CHECKING: + from collections.abc import Callable + from typing import Self + + import sympy + from torch.fx.experimental.symbolic_shapes import ShapeEnv log = logging.getLogger(__name__) @@ -62,7 +65,15 @@ from torch.types import py_sym_types as SymTypes -def _to_symtype(t): +@overload +def _to_symtype(t: type[bool]) -> type[SymBool]: ... +@overload +def _to_symtype(t: type[int]) -> type[SymInt]: ... +@overload +def _to_symtype(t: type[float]) -> type[SymFloat]: ... +@overload +def _to_symtype(t: type) -> type: ... +def _to_symtype(t: type) -> type: if t is bool: return SymBool if t is int: @@ -92,14 +103,14 @@ class SymNode: def __init__( self, - expr, - shape_env, - pytype, - hint: int | float | bool | None, - constant=None, - fx_node=None, - optimized_summation=False, - ): + expr: object, + shape_env: ShapeEnv | None, + pytype: type, + hint: HintType | object, + constant: int | float | bool | None = None, + fx_node: object = None, + optimized_summation: bool = False, + ) -> None: self._expr = expr self.shape_env = shape_env self.pytype = pytype @@ -134,7 +145,7 @@ def __init__( # potential refinements to unbacked symints this got harder to keep # in sync, so we've deleted it for now.) - def compute_hint(): + def compute_hint() -> HintType | SymInt | SymFloat | SymBool: from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols # This occasionally gets exercised by, e.g., @@ -144,6 +155,8 @@ def compute_hint(): # expensive. if has_free_unbacked_symbols(self.expr): return None + if self.shape_env is None: + raise RuntimeError("shape_env is required to compute hint") hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True) if hint is not None: hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint @@ -197,9 +210,14 @@ def _value_hash(self) -> int: return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node)) @property - def expr(self): - if isinstance(self._expr, int) or self._expr.is_number: + def expr(self) -> sympy.Basic: + if ( + isinstance(self._expr, int) + or self._expr.is_number # pyrefly: ignore[missing-attribute] + ): return self._expr + if self.shape_env is None: + raise AssertionError("shape_env is required to access expr") ver = self.shape_env._replacements_version_counter if ver == 0: return self._expr @@ -211,46 +229,20 @@ def expr(self): return result @property - def hint(self): + def hint(self) -> HintType | SymInt | SymFloat | SymBool: return self._hint - def has_hint(self): + def has_hint(self) -> bool: return self._hint is not None - def require_hint(self, fallback=None): - from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols - - if self._hint is None: - if fallback is not None: - # Say we have some expr like 2*u0 + s0 - # The hint will be None, since the expr contains at least 1 unbacked. - # We will: - # - replace every backed free symbol with its corresponding hint - # - replace every unbacked free symbol with the fallback - # - regenerate the expression with those symbol replacements - # Note: this is not really complete either, since right now - # this logic does not take into account any value ranges - # for the unbacked symints, we may need to beef it up at some point. - unbacked_symbols = free_unbacked_symbols(self.expr) - replacements = { - s: fallback - if s in unbacked_symbols - else self.shape_env.backed_var_to_val[s] - for s in self.expr.free_symbols - } - return int(self.expr.xreplace(replacements)) - # NB: we expect this to raise - return self.shape_env.size_hint(self.expr) - return self._hint - - def maybe_as_int(self): + def maybe_as_int(self) -> int | None: if self.expr.is_number: return int(self.expr) else: return None # NB: This does conversions, not sure if this is good or not - def maybe_as_float(self): + def maybe_as_float(self) -> float | None: import sympy if isinstance(self.expr, sympy.Float): @@ -258,7 +250,7 @@ def maybe_as_float(self): else: return None - def maybe_as_bool(self): + def maybe_as_bool(self) -> bool | None: import sympy if self.expr is sympy.true: @@ -268,16 +260,16 @@ def maybe_as_bool(self): else: return None - def is_int(self): + def is_int(self) -> bool: return self.pytype is int - def is_float(self): + def is_float(self) -> bool: return self.pytype is float - def is_bool(self): + def is_bool(self) -> bool: return self.pytype is bool - def is_nested_int(self): + def is_nested_int(self) -> bool: # Unbacked SymInts cannot be nested int today return ( self._hint is not None @@ -285,7 +277,7 @@ def is_nested_int(self): and self._hint.node.is_nested_int() ) - def wrap_int(self, num): + def wrap_int(self, num: int) -> SymNode: if type(num) is not int: raise AssertionError(f"Expected int, got {type(num)}") import sympy @@ -294,7 +286,7 @@ def wrap_int(self, num): sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num ) - def wrap_float(self, num): + def wrap_float(self, num: float) -> SymNode: if type(num) is not float: raise AssertionError(f"Expected float, got {type(num)}") import sympy @@ -303,7 +295,7 @@ def wrap_float(self, num): sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num ) - def wrap_bool(self, num): + def wrap_bool(self, num: bool) -> SymNode: if type(num) is not bool: raise AssertionError(f"Expected bool, got {type(num)}") import sympy @@ -317,16 +309,16 @@ def wrap_bool(self, num): fx_node=num, ) - def clone(self): + def clone(self) -> SymNode: return self - def str(self): + def str(self) -> builtins.str: return f"{self.expr}" - def __str__(self): + def __str__(self) -> builtins.str: return self.str() - def __repr__(self): + def __repr__(self) -> builtins.str: rep = [ f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}", ] @@ -350,70 +342,70 @@ def abs(self) -> SymNode: def pos(self) -> SymNode: return self._pos() # type: ignore[attr-defined] - def round(self, ndigits=None) -> SymNode: + def round(self, ndigits: int | None = None) -> SymNode: return self._round(ndigits) # type: ignore[attr-defined] def trunc(self) -> SymNode: return self._trunc() # type: ignore[attr-defined] - def add(self, other) -> SymNode: + def add(self, other: SymNode) -> SymNode: return self._add(other) # type: ignore[attr-defined] - def sub(self, other) -> SymNode: + def sub(self, other: SymNode) -> SymNode: return self._sub(other) # type: ignore[attr-defined] - def mul(self, other) -> SymNode: + def mul(self, other: SymNode) -> SymNode: return self._mul(other) # type: ignore[attr-defined] - def mod(self, other) -> SymNode: + def mod(self, other: SymNode) -> SymNode: return self._mod(other) # type: ignore[attr-defined] - def float_pow(self, other) -> SymNode: + def float_pow(self, other: SymNode) -> SymNode: return self._float_pow(other) # type: ignore[attr-defined] - def pow_by_natural(self, other) -> SymNode: + def pow_by_natural(self, other: SymNode) -> SymNode: return self._pow_by_natural(other) # type: ignore[attr-defined] - def and_(self, other) -> SymNode: + def and_(self, other: SymNode) -> SymNode: return self._and_(other) # type: ignore[attr-defined] - def or_(self, other) -> SymNode: + def or_(self, other: SymNode) -> SymNode: return self._or_(other) # type: ignore[attr-defined] - def float_truediv(self, other) -> SymNode: + def float_truediv(self, other: SymNode) -> SymNode: return self._float_truediv(other) # type: ignore[attr-defined] - def int_truediv(self, other) -> SymNode: + def int_truediv(self, other: SymNode) -> SymNode: return self._int_truediv(other) # type: ignore[attr-defined] - def int_floordiv(self, other) -> SymNode: + def int_floordiv(self, other: SymNode) -> SymNode: return self._int_floordiv(other) # type: ignore[attr-defined] - def lshift(self, other) -> SymNode: + def lshift(self, other: SymNode) -> SymNode: return self._lshift(other) # type: ignore[attr-defined] - def rshift(self, other) -> SymNode: + def rshift(self, other: SymNode) -> SymNode: return self._rshift(other) # type: ignore[attr-defined] - def sym_not(self) -> SymNode: # noqa: F811 + def sym_not(self) -> SymNode: return self._sym_not() # type: ignore[attr-defined] - def eq(self, other) -> SymNode: + def eq(self, other: SymNode) -> SymNode: return self._eq(other) # type: ignore[attr-defined] - def ne(self, other) -> SymNode: + def ne(self, other: SymNode) -> SymNode: return self._ne(other) # type: ignore[attr-defined] - def gt(self, other) -> SymNode: + def gt(self, other: SymNode) -> SymNode: return self._gt(other) # type: ignore[attr-defined] - def lt(self, other) -> SymNode: + def lt(self, other: SymNode) -> SymNode: return self._lt(other) # type: ignore[attr-defined] - def le(self, other) -> SymNode: + def le(self, other: SymNode) -> SymNode: return self._le(other) # type: ignore[attr-defined] - def ge(self, other) -> SymNode: + def ge(self, other: SymNode) -> SymNode: return self._ge(other) # type: ignore[attr-defined] def floor(self) -> SymNode: @@ -422,7 +414,7 @@ def floor(self) -> SymNode: def is_integer(self) -> SymNode: return self._is_integer() # type: ignore[attr-defined] - def sym_float(self) -> SymNode: # noqa: F811 + def sym_float(self) -> SymNode: return self._sym_float() # type: ignore[attr-defined] def sym_int(self) -> SymNode: @@ -434,74 +426,86 @@ def ceil(self) -> SymNode: def neg(self) -> SymNode: return self._neg() # type: ignore[attr-defined] - def sym_min(self, other) -> SymNode: # noqa: F811 + def sym_min(self, other: SymNode) -> SymNode: return self._sym_min(other) # type: ignore[attr-defined] - def sym_max(self, other) -> SymNode: # noqa: F811 + def sym_max(self, other: SymNode) -> SymNode: return self._sym_max(other) # type: ignore[attr-defined] - def sym_ite(self, then_val, else_val) -> SymNode: + def sym_ite(self, then_val: SymNode, else_val: SymNode) -> SymNode: return self._sym_ite(then_val, else_val) # type: ignore[attr-defined] - def is_contiguous(self, sizes, strides) -> SymNode: + def is_contiguous(self, sizes: list[SymNode], strides: list[SymNode]) -> SymNode: return self._is_contiguous(sizes, strides) # type: ignore[attr-defined] - def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode: + def is_channels_last_contiguous_2d( + self, sizes: list[SymNode], strides: list[SymNode] + ) -> SymNode: return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined] - def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode: + def is_channels_last_contiguous_3d( + self, sizes: list[SymNode], strides: list[SymNode] + ) -> SymNode: return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined] - def is_channels_last_strides_2d(self, sizes, strides) -> SymNode: + def is_channels_last_strides_2d( + self, sizes: list[SymNode], strides: list[SymNode] + ) -> SymNode: return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined] - def is_channels_last_strides_3d(self, sizes, strides) -> SymNode: + def is_channels_last_strides_3d( + self, sizes: list[SymNode], strides: list[SymNode] + ) -> SymNode: return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined] - def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode: + def is_non_overlapping_and_dense_indicator( + self, sizes: list[SymNode], strides: list[SymNode] + ) -> SymNode: return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined] # Make C++ happy - def sym_or(self, other): + def sym_or(self, other: SymNode) -> SymNode: return self.or_(other) - def sym_and(self, other): + def sym_and(self, other: SymNode) -> SymNode: return self.and_(other) # Integer bitwise ops - def bitwise_and(self, other): + def bitwise_and(self, other: SymNode) -> SymNode: return self._bitwise_and(other) # type: ignore[attr-defined] - def bitwise_or(self, other): + def bitwise_or(self, other: SymNode) -> SymNode: return self._bitwise_or(other) # type: ignore[attr-defined] - def bitwise_xor(self, other): + def bitwise_xor(self, other: SymNode) -> SymNode: return self._bitwise_xor(other) # type: ignore[attr-defined] # There is no int_truediv available from C++ - def truediv(self, other): + def truediv(self, other: SymNode) -> SymNode: return self.float_truediv(other) - def floordiv(self, other) -> SymNode: + def floordiv(self, other: SymNode) -> SymNode: return self.int_floordiv(other) # We didn't bind integer pow in C++ - def pow(self, other): + def pow(self, other: SymNode) -> SymNode: return self.float_pow(other) - def is_non_overlapping_and_dense(self, sizes, strides): + def is_non_overlapping_and_dense( + self, sizes: list[SymNode], strides: list[SymNode] + ) -> SymNode: return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq( to_node(self, 1) ) # type: ignore[attr-defined] - def int_(self): + def int_(self) -> int: return self.guard_int("", 0) # NB: uses Python backtrace # This one is currently done by hand, but if we add other variadic # functions consider factoring it out to be metaprogrammed too. Note that # some load bearing logic is directly in torch.sym_sum - def sym_sum(self, args) -> SymNode: + def sym_sum(self, args: list[SymNode]) -> SymNode: import sympy # Inner impl @@ -523,14 +527,16 @@ def sym_sum(self, args) -> SymNode: out = sympy.Add(*exprs) size_hints = [] - out_hint = None + out_hint: object = _NO_HINT for a in args: if a.hint is None: break size_hints.append(a.hint) else: - out_hint = sum(size_hints) + out_hint = sum(size_hints) # pyrefly: ignore[no-matching-overload] + if self.shape_env is None: + raise RuntimeError("shape_env is required for sym_sum") fx_node, _ = self.shape_env._create_fx_call_function( torch.sym_sum, (tuple(a.fx_node for a in args),) ) @@ -538,11 +544,13 @@ def sym_sum(self, args) -> SymNode: # NB: Only for integers! return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node) - def evaluate(self, size_oblivious=False): + def evaluate(self, size_oblivious: bool = False) -> bool | int | float: + if self.shape_env is None: + raise RuntimeError("shape_env is required to evaluate") return self.shape_env.evaluate_sym_node(self, size_oblivious) # You can manually trigger a guard with this function - def guard_int(self, file, line): + def guard_int(self, file: builtins.str, line: int) -> int: # TODO: use the file/line for some useful diagnostic on why a # guard occurred r = self.evaluate() @@ -552,7 +560,7 @@ def guard_int(self, file, line): log.warning("Failed to convert to int: %s", r) raise - def guard_float(self, file, line): + def guard_float(self, file: builtins.str, line: int) -> float: # TODO: use the file/line for some useful diagnostic on why a # guard occurred r = self.evaluate() @@ -562,7 +570,7 @@ def guard_float(self, file, line): log.warning("Failed to convert to float: %s", r) raise - def guard_bool(self, file, line): + def guard_bool(self, file: builtins.str, line: int) -> bool: # TODO: use the file/line for some useful diagnostic on why a # guard occurred r = self.evaluate() @@ -572,9 +580,11 @@ def guard_bool(self, file, line): log.warning("Failed to convert to bool: %s", r) raise - def expect_true(self, file, line): + def expect_true(self, file: builtins.str, line: int) -> bool: from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + if self.shape_env is None: + raise RuntimeError("shape_env is required for expect_true") if ( self.has_hint() and not free_unbacked_symbols(self.expr) @@ -590,14 +600,14 @@ def expect_true(self, file, line): self.expr, f"{file}:{line}", fx_node=self.fx_node ) - def statically_known_true(self, file, line): + def statically_known_true(self, file: builtins.str, line: int) -> bool: from torch.fx.experimental.symbolic_shapes import statically_known_true if not self.is_bool(): raise AssertionError("Expected bool type") return statically_known_true(SymBool(self)) - def guard_size_oblivious(self, file, line): + def guard_size_oblivious(self, file: builtins.str, line: int) -> bool: """ Like guard_bool, but if we encounter unbacked symbols, if those symbols are size-like, we will treat them as >= 2 for the purposes of the analysis. @@ -617,35 +627,35 @@ def guard_size_oblivious(self, file, line): log.warning("Failed to convert to bool: %s", r) raise - def guard_or_false(self, file, line): + def guard_or_false(self, file: builtins.str, line: int) -> bool: from torch.fx.experimental.symbolic_shapes import guard_or_false if not self.is_bool(): raise AssertionError("Expected bool type") return guard_or_false(SymBool(self)) - def guard_or_true(self, file, line): + def guard_or_true(self, file: builtins.str, line: int) -> bool: from torch.fx.experimental.symbolic_shapes import guard_or_true if not self.is_bool(): raise AssertionError("Expected bool type") return guard_or_true(SymBool(self)) - def bool_(self): + def bool_(self) -> bool: return self.guard_bool("", 0) - def is_symbolic(self): + def is_symbolic(self) -> bool: return True - def nested_int(self): + def nested_int(self) -> None: return None - def is_constant(self): + def is_constant(self) -> bool: return False class _DynamicScalar: - def __new__(cls, *args): + def __new__(cls, *args: object) -> Self: if cls is _DynamicScalar: raise TypeError("_DynamicScalar is an abstract base class, use DynamicInt.") return super().__new__(cls, *args) @@ -663,21 +673,44 @@ class DynamicInt(_DynamicScalar, int): fn(x) # compiles x as a dynamic integer input; returns f(4) """ - def __new__(cls, val): + def __new__(cls, val: int) -> Self: if not isinstance(val, int): raise AssertionError(f"Expected int, got {type(val)}") obj = super().__new__(cls, int(val)) return obj - def __repr__(self): + def __repr__(self) -> str: return f"DynamicInt({self.real})" - def __floordiv__(self, other): # // was casting to int without these overrides? + def __floordiv__( + self, other: int + ) -> DynamicInt: # // was casting to int without these overrides? return DynamicInt(self.real // other) - def __rfloordiv__(self, other): + def __rfloordiv__(self, other: int) -> DynamicInt: return DynamicInt(other // self.real) + def __pow__(self, other, modulo=None): + if modulo is not None: + result = pow(self.real, other, modulo) + else: + result = self.real**other + # Only create DynamicInt if result is int, otherwise return plain value + # (e.g., negative exponent produces float) + if isinstance(result, int): + return DynamicInt(result) + return result + + def __rpow__(self, other, modulo=None): + if modulo is not None: + result = pow(other, self.real, modulo) + else: + result = other**self.real + # Only create DynamicInt if result is int, otherwise return plain value + if isinstance(result, int): + return DynamicInt(result) + return result + # TODO: this probably needs the sizes-strides eval functions METHOD_TO_OPERATOR = { @@ -732,8 +765,8 @@ def __rfloordiv__(self, other): # Adding math ops: sqrt, cos, sin, ... -def _get_sym_node_fn(name): - def fn(self): +def _get_sym_node_fn(name: str) -> Callable[[SymNode], SymNode]: + def fn(self: SymNode) -> SymNode: return getattr(self, f"_sym_{name}")() return fn @@ -758,6 +791,7 @@ def fn(self): setattr(SymNode, sym_name, _get_sym_node_fn(name)) METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name) unary_magic_methods.add(sym_name) + # pyrefly: ignore [unresolvable-dunder-all] __all__.append(sym_name) @@ -811,25 +845,25 @@ def fn(self): # Methods that have a `__foo__` as well as `__rfoo__` -def _sympy_float_truediv(a, b): +def _sympy_float_truediv(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import FloatTrueDiv return FloatTrueDiv(a, b) -def _sympy_int_truediv(a, b): +def _sympy_int_truediv(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import IntTrueDiv return IntTrueDiv(a, b) -def _sympy_floordiv(a, b): +def _sympy_floordiv(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import FloorDiv return FloorDiv(a, b) -def _sympy_mod(a, b): +def _sympy_mod(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import Mod, PythonMod if a.is_nonnegative and b.is_nonnegative: @@ -838,43 +872,45 @@ def _sympy_mod(a, b): return PythonMod(a, b) -def _sympy_pow_by_natural(a, b): +def _sympy_pow_by_natural(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import PowByNatural return PowByNatural(a, b) -def _sympy_float_pow(a, b): +def _sympy_float_pow(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import FloatPow return FloatPow(a, b) -def _sympy_and(a, b): +def _sympy_and(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: import sympy return sympy.And(a, b) -def _sympy_or(a, b): +def _sympy_or(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: import sympy return sympy.Or(a, b) -def _sympy_lshift(a, b): +def _sympy_lshift(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import LShift return LShift(a, b) -def _sympy_rshift(a, b): +def _sympy_rshift(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import RShift return RShift(a, b) -def _binary_search_insert_arg(ordered_args, new_arg): +def _binary_search_insert_arg( + ordered_args: list[sympy.Basic], new_arg: sympy.Basic +) -> list[sympy.Basic] | None: """ If new_arg is found in ordered_args None is returned, else the new ordered_args with new_arg inserted @@ -909,8 +945,11 @@ def _binary_search_insert_arg(ordered_args, new_arg): def _optimized_add( - lhs, rhs, lhs_is_optimized_summation=False, rhs_is_optimized_summation=False -): + lhs: sympy.Basic, + rhs: sympy.Basic, + lhs_is_optimized_summation: bool = False, + rhs_is_optimized_summation: bool = False, +) -> tuple[bool, sympy.Basic]: """ Custom optimization for Add used to optimize incremental binary summations of certain properties. The idea is when we know the expression is a summation of unique symbols all we need to know is the correct order of symbols, @@ -924,7 +963,7 @@ def _optimized_add( import sympy from sympy.core.basic import _args_sortkey as sortkey - def make_optimized(ordered_args): + def make_optimized(ordered_args: list[sympy.Basic]) -> tuple[bool, sympy.Basic]: if ordered_args is None: raise AssertionError("ordered_args is None") # Use _from_args directly to bypass _exec_constructor_postprocessors @@ -976,19 +1015,19 @@ def make_optimized(ordered_args): return (_is_symbols_binary_summation(result), result) -def _bitwise_and(a, b): +def _bitwise_and(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import BitwiseFn_bitwise_and return BitwiseFn_bitwise_and(a, b) -def _bitwise_or(a, b): +def _bitwise_or(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import BitwiseFn_bitwise_or return BitwiseFn_bitwise_or(a, b) -def _bitwise_xor(a, b): +def _bitwise_xor(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import BitwiseFn_bitwise_xor return BitwiseFn_bitwise_xor(a, b) @@ -1014,7 +1053,7 @@ def _bitwise_xor(a, b): } -def _floor_ceil_helper(a, fn): +def _floor_ceil_helper(a: sympy.Basic, fn: Callable[..., sympy.Basic]) -> sympy.Basic: import sympy if isinstance(a, sympy.Mul): @@ -1032,7 +1071,7 @@ def _floor_ceil_helper(a, fn): return fn(a) -def _sympy_floor(a): +def _sympy_floor(a: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import FloorToInt return FloorToInt(a) @@ -1040,67 +1079,67 @@ def _sympy_floor(a): # NB: this is Python trunc semantics which returns an int. Do NOT use this to # represent torch.trunc (which is float to float) -def _sympy_trunc(a): +def _sympy_trunc(a: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import TruncToInt return TruncToInt(a) -def _sympy_ceil(a): +def _sympy_ceil(a: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import CeilToInt return CeilToInt(a) -def _sympy_eq(a, b): +def _sympy_eq(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: import sympy return sympy.Eq(a, b) -def _sympy_ne(a, b): +def _sympy_ne(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: import sympy return sympy.Ne(a, b) -def _sympy_gt(a, b): +def _sympy_gt(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: import sympy return sympy.Gt(a, b) -def _sympy_lt(a, b): +def _sympy_lt(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: import sympy return sympy.Lt(a, b) -def _sympy_le(a, b): +def _sympy_le(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: import sympy return sympy.Le(a, b) -def _sympy_ge(a, b): +def _sympy_ge(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: import sympy return sympy.Ge(a, b) -def _sympy_min(a, b): +def _sympy_min(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import Min return Min(a, b) -def _sympy_max(a, b): +def _sympy_max(a: sympy.Basic, b: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import Max return Max(a, b) -def _sympy_ite(a, t, f): +def _sympy_ite(a: sympy.Basic, t: sympy.Basic, f: sympy.Basic) -> sympy.Basic: import sympy return sympy.Piecewise((t, a), (f, True)) @@ -1109,8 +1148,8 @@ def _sympy_ite(a, t, f): current_module = sys.modules[__name__] -def _get_sym_math_fn(name): - def fn(a): +def _get_sym_math_fn(name: str) -> Callable[[sympy.Basic], sympy.Basic]: + def fn(a: sympy.Basic) -> sympy.Basic: import torch.utils._sympy.functions return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a) @@ -1127,13 +1166,15 @@ def fn(a): del fn, name, priv_sympy_name # type: ignore[possibly-undefined] -def _sympy_abs(a): +def _sympy_abs(a: sympy.Basic) -> sympy.Basic: import sympy return sympy.Abs(a) -def _sympy_round(number, ndigits=None): +def _sympy_round( + number: sympy.Basic, ndigits: sympy.Basic | None = None +) -> sympy.Basic: from torch.utils._sympy.functions import RoundDecimal, RoundToInt if ndigits is None: @@ -1142,7 +1183,7 @@ def _sympy_round(number, ndigits=None): return RoundDecimal(number, ndigits) -def _sympy_sym_float(a): +def _sympy_sym_float(a: sympy.Basic) -> sympy.Basic: from torch.utils._sympy.functions import ToFloat # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly @@ -1150,7 +1191,7 @@ def _sympy_sym_float(a): return ToFloat(a) -def _sympy_is_integer(a): +def _sympy_is_integer(a: sympy.Basic) -> sympy.Basic: import sympy from torch.utils._sympy.functions import ToFloat @@ -1189,12 +1230,16 @@ def _sympy_is_integer(a): del name, sym_name, math_op_names, current_module # type: ignore[possibly-undefined] -def sympy_is_contiguous(sizes, strides): +def sympy_is_contiguous( + sizes: list[sympy.Basic], strides: list[sympy.Basic] +) -> sympy.Basic: dim = len(sizes) return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1))) -def sympy_is_contiguous_generic(sizes, strides, dim_order): +def sympy_is_contiguous_generic( + sizes: list[sympy.Basic], strides: list[sympy.Basic], dim_order: list[int] +) -> sympy.Basic: import sympy dim = len(sizes) @@ -1218,15 +1263,21 @@ def sympy_is_contiguous_generic(sizes, strides, dim_order): # happens you will need to refactor this -def sympy_is_channels_last_contiguous_2d(sizes, strides): +def sympy_is_channels_last_contiguous_2d( + sizes: list[sympy.Basic], strides: list[sympy.Basic] +) -> sympy.Basic: return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0]) -def sympy_is_channels_last_contiguous_3d(sizes, strides): +def sympy_is_channels_last_contiguous_3d( + sizes: list[sympy.Basic], strides: list[sympy.Basic] +) -> sympy.Basic: return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0]) -def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): +def sympy_is_channels_last_strides_generic( + sizes: list[sympy.Basic], strides: list[sympy.Basic], dim_order: list[int] +) -> sympy.Basic: import sympy from torch.utils._sympy.functions import Max @@ -1266,15 +1317,21 @@ def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): return r -def sympy_is_channels_last_strides_2d(sizes, strides): +def sympy_is_channels_last_strides_2d( + sizes: list[sympy.Basic], strides: list[sympy.Basic] +) -> sympy.Basic: return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0]) -def sympy_is_channels_last_strides_3d(sizes, strides): +def sympy_is_channels_last_strides_3d( + sizes: list[sympy.Basic], strides: list[sympy.Basic] +) -> sympy.Basic: return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0]) -def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides): +def _sympy_is_non_overlapping_and_dense_indicator( + sizes: list[sympy.Basic], strides: list[sympy.Basic] +) -> sympy.Basic: from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator return IsNonOverlappingAndDenseIndicator(*sizes, *strides) @@ -1292,7 +1349,7 @@ def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides): } -def to_node(self, num): +def to_node(self: SymNode, num: object) -> SymNode: if isinstance(num, SymTypes): return num.node elif type(num) is bool: @@ -1307,7 +1364,7 @@ def to_node(self, num): return NotImplemented -def wrap_node(x): +def wrap_node(x: SymNode) -> SymInt | SymFloat | SymBool | int | float | bool: # TODO: let C++ also take advantage of this if isinstance(x, SymNode) and x.constant is not None: return x.constant @@ -1321,11 +1378,11 @@ def wrap_node(x): raise AssertionError(f"unrecognized return type {x}") -def method_to_operator(method): +def method_to_operator(method: str) -> Callable[..., object]: return METHOD_TO_OPERATOR[method] -def _make_node_magic(method, func): +def _make_node_magic(method: str, func: Callable[..., sympy.Basic]) -> None: func = lru_cache(256)(func) if method in magic_methods_on_operator_with_trailing_underscore: @@ -1350,9 +1407,9 @@ def uninteresting_files() -> set[str]: | {""} ) - def capture_provenance(fn): + def capture_provenance(fn: Callable[..., SymNode]) -> Callable[..., SymNode]: @functools.wraps(fn) - def wrapper(self, other=None): + def wrapper(self: SymNode, other: SymNode | None = None) -> SymNode: if other is None: result = fn(self) else: @@ -1363,7 +1420,7 @@ def wrapper(self, other=None): else: arguments = [self] - def get_id(sym_node) -> int | None: + def get_id(sym_node: SymNode) -> int | None: # We don't want to return an ID if the input is a constant import sympy @@ -1397,7 +1454,7 @@ def get_id(sym_node) -> int | None: return wrapper @capture_provenance - def binary_magic_impl(self, other): + def binary_magic_impl(self: SymNode, other: SymNode) -> SymNode: from torch.fx.experimental.proxy_tensor import ( get_proxy_mode, handle_sym_dispatch, @@ -1423,6 +1480,8 @@ def binary_magic_impl(self, other): # Special handling for mod that requires access to the value # ranges shape_env = self.shape_env + if shape_env is None: + raise AssertionError("shape_env is required for mod") if ( self.expr.is_nonnegative or shape_env.bound_sympy(self.expr).lower >= 0 @@ -1506,6 +1565,8 @@ def binary_magic_impl(self, other): # Create a FX node that corresponds to the operation being applied to # this node. + if self.shape_env is None: + raise RuntimeError("shape_env is required for binary op") fx_node, _ = self.shape_env._create_fx_call_function( op, (self.fx_node, other.fx_node) ) @@ -1521,7 +1582,7 @@ def binary_magic_impl(self, other): return result @capture_provenance - def unary_magic_impl(self): + def unary_magic_impl(self: SymNode) -> SymNode: from torch.fx.experimental.proxy_tensor import ( get_proxy_mode, handle_sym_dispatch, @@ -1532,6 +1593,8 @@ def unary_magic_impl(self): return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) # TODO: consider constant prop here expr = self.expr + if self.shape_env is None: + raise RuntimeError("shape_env is required for unary op") if method == "floor" or method == "ceiling": expr = self.shape_env._simplify_floor_div(expr) @@ -1561,13 +1624,20 @@ def unary_magic_impl(self): setattr(SymNode, f"_{method_attr}", unary_magic_impl) elif method == "sym_ite": - def sym_ite_impl(pred_node, then_node, else_node): + def sym_ite_impl( + pred_node: SymNode, then_node: SymNode, else_node: SymNode + ) -> SymNode: from torch.fx.experimental.proxy_tensor import ( get_proxy_mode, handle_sym_dispatch, ) - out_hint = then_node.hint if pred_node.hint else else_node.hint + if pred_node.hint is None: + out_hint = None + elif pred_node.hint: + out_hint = then_node.hint + else: + out_hint = else_node.hint if get_proxy_mode(): return to_node( pred_node, @@ -1594,6 +1664,8 @@ def sym_ite_impl(pred_node, then_node, else_node): ) raise + if pred_node.shape_env is None: + raise RuntimeError("shape_env is required for sym_ite") fx_node, _ = pred_node.shape_env._create_fx_call_function( sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node) ) @@ -1604,7 +1676,7 @@ def sym_ite_impl(pred_node, then_node, else_node): setattr(SymNode, f"_{method_attr}", sym_ite_impl) elif method == "round": - def round_impl(self, ndigits=None): + def round_impl(self: SymNode, ndigits: int | None = None) -> SymNode: from torch.fx.experimental.proxy_tensor import ( get_proxy_mode, handle_sym_dispatch, @@ -1630,7 +1702,9 @@ def round_impl(self, ndigits=None): out_hint = None if self.hint is not None: - out_hint = op(self.hint, ndigits) + out_hint = op( # pyrefly: ignore[no-matching-overload] + self.hint, ndigits + ) # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here @@ -1642,6 +1716,8 @@ def round_impl(self, ndigits=None): args = [self.fx_node] if ndigits is not None: args.append(ndigits) + if self.shape_env is None: + raise RuntimeError("shape_env is required for round") fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args)) return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) @@ -1650,10 +1726,12 @@ def round_impl(self, ndigits=None): setattr(SymNode, f"_{method_attr}", binary_magic_impl) -def _make_node_sizes_strides(method, func): +def _make_node_sizes_strides(method: str, func: Callable[..., sympy.Basic]) -> None: # NB: don't LRU cache, lots of arguments - def sizes_strides_impl(self, sizes, strides): + def sizes_strides_impl( + self: SymNode, sizes: list[SymNode], strides: list[SymNode] + ) -> SymNode: from torch.fx.experimental.proxy_tensor import ( get_proxy_mode, handle_sym_dispatch, @@ -1706,7 +1784,9 @@ def sizes_strides_impl(self, sizes, strides): # TODO: This is technically hotpath, but in the ideal end state # guards on this will resolve at a higher level so you never # spend time in this code - def sizes_strides_user(sizes, strides): + def sizes_strides_user( + sizes: list[object], strides: list[object] + ) -> SymInt | SymFloat | SymBool | int | float | bool: import sympy from torch.fx.experimental.symbolic_shapes import ( @@ -1722,7 +1802,10 @@ def sizes_strides_user(sizes, strides): ) ) if method == "is_non_overlapping_and_dense_indicator": - return eval_is_non_overlapping_and_dense(sizes, strides) + return eval_is_non_overlapping_and_dense( + sizes, # pyrefly: ignore[bad-argument-type] + strides, # pyrefly: ignore[bad-argument-type] + ) else: # TODO: this is an awful implementation return bool( @@ -1744,7 +1827,7 @@ def sizes_strides_user(sizes, strides): _make_node_sizes_strides(method, func) -def _make_user_magic(method, user_type): +def _make_user_magic(method: str, user_type: type) -> None: # User magic takes care of wrapping the other operand into a node, # so that our internal logic can assume everything is nodes if method in magic_methods_on_operator_with_trailing_underscore: @@ -1752,7 +1835,9 @@ def _make_user_magic(method, user_type): else: method_attr = method - def get_constant(x: SymInt | int | SymFloat | float | SymBool | bool): + def get_constant( + x: SymInt | int | SymFloat | float | SymBool | bool, + ) -> int | float | bool: if isinstance(x, (int, float, bool)): return x if isinstance(x, SymInt): @@ -1761,7 +1846,7 @@ def get_constant(x: SymInt | int | SymFloat | float | SymBool | bool): return x.node.guard_bool("", 0) raise AssertionError("expect to be called with constant SymBools") - def is_constant(x): + def is_constant(x: SymInt | int | SymFloat | float | SymBool | bool) -> bool: if isinstance(x, (int, float, bool)): return True if isinstance(x, (SymInt, SymFloat, SymBool)): @@ -1796,7 +1881,7 @@ def is_constant(x): if method in bool_becomes_int_magic_methods: - def promote(x): + def promote(x: object) -> Any: """Implements True+True=2, which works in python but not sympy""" if isinstance(x, SymBool): return SymInt(x.node.wrap_int(int(x))) @@ -1804,10 +1889,10 @@ def promote(x): else: - def promote(x): + def promote(x: object) -> Any: return x - def promote2(self, other): + def promote2(self: object, other: object) -> tuple[Any, Any]: # TODO: Remove eq and other relations from this list. # CPython has fancy implementations for these to get as much precision # as possible instead of just promoting to float64 and praying, so we @@ -1849,13 +1934,13 @@ def promote2(self, other): # Alternatively, we could also rewrap into constant Symbool (i.e. by # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that # today for no particular reason. - def unary_magic_impl(self): + def unary_magic_impl(self: object) -> Any: self = promote(self) if is_constant(self): return (method_to_operator(method))(get_constant(self)) return wrap_node(getattr(self.node, method_attr)()) - def binary_magic_impl(self, other): + def binary_magic_impl(self: object, other: object) -> Any: if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): return NotImplemented sym_node_log.debug("MAGIC %s %s %s", method, self, other) @@ -1872,7 +1957,7 @@ def binary_magic_impl(self, other): ret = wrap_node(getattr(self.node, method_attr)(other_node)) return get_constant(ret) if is_constant(ret) else ret - def rbinary_magic_impl(self, other): + def rbinary_magic_impl(self: object, other: object) -> Any: if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): return NotImplemented self = promote(self) @@ -1888,7 +1973,7 @@ def rbinary_magic_impl(self, other): ret = wrap_node(getattr(other_node, method_attr)(self.node)) return get_constant(ret) if is_constant(ret) else ret - def setattrs(user_type, attr, symnode_impl): + def setattrs(user_type: type, attr: str, symnode_impl: object) -> None: """ Registers the SymNode magic method on SymInt/Float/Bool, and optionally registers a corresponding wrapped method on DynamicInt. @@ -1898,8 +1983,10 @@ def setattrs(user_type, attr, symnode_impl): setattr(user_type, attr, symnode_impl) # DynamicInt impl - def dynamic_int_impl(*args): - args = [x.real if isinstance(x, DynamicInt) else x for x in args] + def dynamic_int_impl(*args: object) -> Any: + args = [ # pyrefly: ignore[bad-assignment] + x.real if isinstance(x, DynamicInt) else x for x in args + ] out = getattr(int, attr)(*args) if isinstance(out, int) and not isinstance(out, bool): return DynamicInt(out) @@ -1915,7 +2002,9 @@ def dynamic_int_impl(*args): setattrs(user_type, method, update_wrapper(unary_magic_impl, orig)) elif method == "sym_ite": - def sym_ite_magic_impl(pred, then_val, else_val): + def sym_ite_magic_impl( + pred: SymBool, then_val: object, else_val: object + ) -> Any: pred_node = pred.node then_node = to_node(pred_node, then_val) else_node = to_node(pred_node, else_val) @@ -1930,12 +2019,17 @@ def sym_ite_magic_impl(pred, then_val, else_val): "then_node and else_node must be SymNodes with same pytype" ) ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node)) - return get_constant(ret) if ret.node.is_constant() else ret + return ( + get_constant(ret) + # pyrefly: ignore[missing-attribute] + if ret.node.is_constant() + else ret + ) setattrs(user_type, f"__{method}__", sym_ite_magic_impl) elif method == "round": - def round_magic_impl(self, ndigits=None): + def round_magic_impl(self: SymFloat, ndigits: int | None = None) -> Any: if is_constant(self): return builtins.round(get_constant(self), ndigits) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 3cca936f6172b..6fc6ed685b3a3 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -56,7 +56,7 @@ # NB: The sym_* functions are used via getattr() and must be imported here. from torch import SymBool, SymFloat, SymInt -from torch._C._functorch import get_unwrapped, is_batchedtensor +from torch._C._functorch import get_unwrapped, is_batchedtensor, is_gradtrackingtensor from torch._guards import ShapeGuard, SLoc, Source, TracingContext from torch._library.fake_class_registry import FakeScriptObject from torch._library.opaque_object import is_opaque_value @@ -119,6 +119,54 @@ log = logging.getLogger(__name__) +from torch.fx.experimental._size_hinting import ( + _guarding_hint_or_throw_base, + _optimization_hint_base, +) + + +def guarding_hint_or_throw( + a: torch.SymInt | torch.SymBool | int | bool | SymNode, +) -> int | bool: + """ + Return a concrete hint for a symbolic value, for use in guarding decisions. + + Returns Python bool (True/False) for boolean inputs (SymBool, bool), + and Python int for integer inputs (SymInt, int). + """ + if isinstance(a, SymNode): + if a._hint is not None: + return a._hint # pyrefly: ignore[bad-return] + if a.shape_env is None: + raise AssertionError("shape_env is required for guarding_hint_or_throw") + hint = a.shape_env.guarding_hint_or_throw(a.expr) + a._hint = hint + return hint + if isinstance(a, (torch.SymInt, torch.SymBool)): + return guarding_hint_or_throw(a.node) + if isinstance(a, bool): + return a + if type(a) is not int: + raise AssertionError(f"Expected int, got {type(a)}") + return a + + +def optimization_hint(a: torch.SymInt | int, fallback: int | None = None) -> int: + """ + Return a concrete hint for a symbolic integer, for use in optimization decisions. + + Unlike guarding_hint_or_throw, this function does not add guards and is intended + for optimization purposes only (e.g., memory estimation). + """ + if isinstance(a, torch.SymInt): + if a.node._hint is not None: + return a.node._hint + return a.node.shape_env.optimization_hint(a.node.expr, fallback=fallback) + if type(a) is not int: + raise AssertionError(f"Expected int, got {type(a)}") + return a + + class GuardOnDataDependentSymNode(RuntimeError): cond: sympy.Basic @@ -138,7 +186,8 @@ class _ShapeEnvGuardError(RuntimeError): aten = torch._ops.ops.aten # type: ignore[has-type] __all__ = [ - "size_hint", + "optimization_hint", + "guarding_hint_or_throw", "guard_or_false", "guard_or_true", "has_symbolic_sizes_strides", @@ -152,7 +201,6 @@ class _ShapeEnvGuardError(RuntimeError): "guard_float", "guard_scalar", "canonicalize_bool_expr", - "hint_int", "SYMPY_INTERP", "free_symbols", "is_symbol_binding_fx_node", @@ -370,29 +418,16 @@ def create_contiguous(shape: Sequence[Int]) -> list[Int]: return list(reversed(strides)) -@deprecated("used size_hint instead of hint_int", category=FutureWarning) -def hint_int(a: torch.SymInt | int, fallback: int | None = None) -> int: - return size_hint(a, fallback) +Scalar: TypeAlias = torch.SymInt | torch.SymFloat | torch.SymBool | int | float | bool -def size_hint(a: torch.SymInt | int, fallback: int | None = None) -> int: - """ - Retrieve the hint for an int (based on the underlying real values as observed - at runtime). If no hint is available (e.g., because data dependent shapes), - if fallback is not None, use that instead to hint each unbacked symbol individually - (otherwise raise an error). +def has_guarding_hint(a: Scalar) -> bool: """ - if isinstance(a, torch.SymInt): - return a.node.require_hint(fallback) - if type(a) is not int: - raise AssertionError(f"Expected int, got {type(a)}") - return a - - -Scalar: TypeAlias = torch.SymInt | torch.SymFloat | torch.SymBool | int | float | bool + Check if a symbolic value has a hint available for guarding. - -def has_hint(a: Scalar) -> bool: + Returns True if the value is concrete or if the symbolic node has a hint, + False otherwise. + """ if isinstance(a, SymTypes): return a.node.has_hint() return True @@ -910,7 +945,7 @@ def div_by_factor(x: sympy.Expr, factor: int) -> sympy.Expr: return expr -def is_nested_int(s: IntLikeType) -> TypeGuard[SymInt]: +def is_nested_int(s: IntLikeType | FloatLikeType) -> TypeGuard[SymInt]: return isinstance(s, torch.SymInt) and s.node.is_nested_int() @@ -951,11 +986,14 @@ def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]: yield val.expr elif isinstance(val, sympy.Basic): yield val - elif isinstance(val, (int, float, bool)): + elif isinstance(val, (int, float, bool, str)): pass elif isinstance(val, (tuple, list)): for s in val: yield from _iterate_exprs(s) + elif isinstance(val, dict): + for s in itertools.chain(val.keys(), val.values()): + yield from _iterate_exprs(s) elif is_sparse_any(val): yield from _iterate_exprs(val.size()) elif isinstance(val, torch.Tensor): @@ -1109,7 +1147,7 @@ def find_symbol_binding_fx_nodes( Returns: A dictionary mapping from sympy Symbols to their binding FX nodes """ - r = {} + r: dict[sympy.Symbol, torch.fx.Node] = {} # NB: Prefer first occurrence of symbol for node in graph.nodes: if (s := is_symbol_binding_fx_node(node)) is not None and s not in r: @@ -1128,7 +1166,7 @@ class Specialization: """ source: TensorPropertySource - check_fn: Callable + check_fn: Callable[[int], bool] # Analogous to ConvertIntSource @@ -1223,7 +1261,7 @@ def expr(s: SymInt | SymFloat | SymBool) -> sympy.Expr: pending = set() r = {} - def match_tensor(a: torch.Tensor, real_tensor: torch.Tensor | None = None): + def match_tensor(a: torch.Tensor, real_tensor: torch.Tensor | None = None) -> None: r.update( go( a.size(), @@ -1277,10 +1315,16 @@ def match_tensor(a: torch.Tensor, real_tensor: torch.Tensor | None = None): a, torch.distributed.tensor.DTensor ): match_tensor(a) - elif isinstance(a, torch.Tensor) and is_batchedtensor(a): + elif isinstance(a, torch.Tensor) and ( + is_batchedtensor(a) or is_gradtrackingtensor(a) + ): unwrapped_tensor = get_unwrapped(a) r.update(go(unwrapped_tensor, path)) - elif isinstance(a, torch.Tensor) and not is_batchedtensor(a): + elif ( + isinstance(a, torch.Tensor) + and not is_batchedtensor(a) + and not is_gradtrackingtensor(a) + ): from torch._subclasses.fake_tensor import FakeTensor if not isinstance(a, FakeTensor): @@ -1700,6 +1744,7 @@ def _advise_is_size(a: SymInt) -> None: isinstance(a, SymInt) and isinstance(a.node, SymNode) and isinstance(a.node.expr, sympy.Symbol) + and a.node.shape_env is not None and a.node.shape_env.is_unbacked_symint(a.node.expr) ): _constrain_range_for_size(a) @@ -1710,6 +1755,7 @@ def _advise_is_bounded(a: SymInt, upper_bound: IntLikeType) -> None: isinstance(a, SymInt) and isinstance(a.node, SymNode) and isinstance(a.node.expr, sympy.Symbol) + and a.node.shape_env is not None and a.node.shape_env.is_unbacked_symint(a.node.expr) and isinstance(upper_bound, int) # TODO: relax ): @@ -1975,7 +2021,7 @@ class StrictMinMaxConstraint(Constraint): for N=0/1 too. """ - vr: ValueRanges + vr: ValueRanges[sympy.Expr] def render(self, source: Source) -> str: """Format the constrain equation""" @@ -2258,7 +2304,7 @@ def __post_init__(self) -> None: # is used. # TODO(voz): Shape env validation @dataclass(frozen=True, slots=True, kw_only=True) -class StatefulSymbolicContext(StatelessSymbolicContext): +class StatefulSymbolicContext(StatelessSymbolicContext[..., Any]): """ Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via a symbolic_context determination as given by a cache of Source:Symbol. A cache hit @@ -2285,6 +2331,7 @@ class StatefulSymbolicContext(StatelessSymbolicContext): shape_env_to_source_to_symbol_cache: dict[int, dict[str, sympy.Expr]] = field( default_factory=dict ) + excluded_sizes: tuple[int | None, ...] | None = None @dataclass(frozen=True, slots=True) @@ -2346,7 +2393,8 @@ def _expandsums(args: list[sympy.Expr]) -> tuple[sympy.Expr, bool]: - A boolean indicating whether expansion occurred (True if multiple additive expressions were present or if there was at least one additive and one other expression) """ - adds, other = [], [] + adds: list[sympy.Expr] = [] + other: list[sympy.Expr] = [] for arg in args: if arg.is_Add: adds.append(arg) @@ -2440,7 +2488,7 @@ def safe_expand(r: _SympyT) -> _SympyT: class _SymbolInfo(NamedTuple): k: sympy.Symbol - vr: ValueRanges | None + vr: ValueRanges[sympy.Expr] | None val: sympy.Integer | None is_size_like: bool @@ -2625,12 +2673,21 @@ def cast_symbool_to_symint_guardless( return 1 if symbool else 0 int_sym = _sympy_cast_symbool_to_symint_guardless(symbool.node.expr) return symbool.node.shape_env.create_symintnode( - int_sym, hint=int(symbool.node.require_hint()) if has_hint(symbool) else None + int_sym, + hint=guarding_hint_or_throw(symbool) if has_guarding_hint(symbool) else None, ) +def _eval_is_non_overlapping_and_dense_flat(*args: int) -> int: + # Guard code strings print IsNonOverlappingAndDenseIndicator with flat args + # (s0, s1, ..., stride0, stride1, ...) but eval_is_non_overlapping_and_dense + # expects two sequences (sizes, strides). This wrapper bridges the gap. + dim = len(args) // 2 + return eval_is_non_overlapping_and_dense(list(args[:dim]), list(args[dim:])) + + SYMPY_INTERP = { - "IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense, + "IsNonOverlappingAndDenseIndicator": _eval_is_non_overlapping_and_dense_flat, "cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless, "math": math, "torch": torch, @@ -2909,7 +2966,7 @@ class _CppShapeGuardsHelper(_ShapeGuardsHelper): class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter): - def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]): + def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]) -> None: super().__init__(var_to_sources, lambda n: n.name, var_to_sources) @@ -2926,7 +2983,7 @@ def __init__( self, symbol_to_source: dict[sympy.Symbol, list[Source]], source_name_to_debug_name: Mapping[str, str], - ): + ) -> None: super().__init__() self.symbol_to_source = symbol_to_source self.source_name_to_debug_name = source_name_to_debug_name @@ -3167,7 +3224,7 @@ def add_equality(self, source: Source, expr: sympy.Expr) -> None: def _reduce_congruences(self) -> dict[sympy.Symbol, set[sympy.Expr]]: reduced_congruences: dict[sympy.Symbol, set[sympy.Expr]] = {} for s, congruences in self._congruences.items(): - remainder_modulus_pairs = [] + remainder_modulus_pairs: list[tuple[sympy.Integer, sympy.Integer]] = [] congruences_to_check = set() for congruence in congruences: base, divisor = congruence.args @@ -3644,8 +3701,8 @@ def relation_with_digit(expr: str, op: str, digit: int) -> None: self._process_derived_dim_roots(results, name_to_dim) - dims = [] - others = [] + dims: list[str] = [] + others: list[str] = [] # order results by source name results2 = { @@ -3716,7 +3773,7 @@ class ValueRangesSLoc: @contextmanager -def _suppress_guards(shape_env: ShapeEnv) -> Iterator[None]: +def _suppress_guards(shape_env: ShapeEnv) -> Generator[None, None, None]: shape_env._suppress_guards_enter() try: yield @@ -3879,7 +3936,7 @@ def _init( # are conservative: the int MUST fall in the range, but the # range may contain ints which may not actually appear in # practice - self.var_to_range: dict[sympy.Symbol, ValueRanges] = {} + self.var_to_range: dict[sympy.Symbol, ValueRanges[sympy.Expr]] = {} self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {} self.source_name_to_debug_name: dict[str, str] = {} self.var_to_sources: dict[sympy.Symbol, list[Source]] = {} @@ -3902,12 +3959,17 @@ def _init( self.unbacked_renamings: dict[sympy.Symbol, sympy.Symbol] = {} # Set holds a % b expressions that evaluate to 0. self.divisible: set[sympy.Expr] = set() + # Exclusion constraints from automatic_dynamic transitions. + # Each (symbol, excluded_value) pair represents one dim/scalar that + # transitioned static → dynamic. All pairs are combined into a single + # Or(Ne(...), ...) guard in produce_guards_verbose. + self.exclusion_constraints: list[tuple[sympy.Symbol, int]] = [] # Set that holds "size-like" symbols. When we perform # "size-oblivious" tests, these can be assumed to be >= 2. self.size_like: set[sympy.Symbol] = set() # Duck-shaping says that if two input tensors have the same size, # they get assigned the same symbolic variable - self.val_to_var: dict[int, sympy.Symbol] = {} + self.val_to_var: dict[IntLikeType | FloatLikeType, sympy.Symbol] = {} self.unbacked_symfloat_counter = 0 self.unbacked_symint_counter = 0 # Similar to guards, but these MUST evaluate to true and can @@ -3998,7 +4060,9 @@ def _init( # 2. list of arguments # This drastically reduces the size of the FX graph, avoiding # duplicated nodes. - self.fx_node_cache: dict[tuple[Callable, tuple[Any, ...]], torch.fx.Node] = {} + self.fx_node_cache: dict[ + tuple[Callable[..., Any], tuple[Any, ...]], torch.fx.Node + ] = {} self.source_to_symbol: dict[str, sympy.Symbol] = {} # Suppose you want to replace an unbacked symbol with another @@ -4030,6 +4094,11 @@ def _init( self.specialization_stacks: dict[Source, traceback.StackSummary] = {} + # Used by _get_unbacked_replacements / _sub_unbacked_exprs for + # optimization_hint canonicalization of unbacked expressions. + self._equality_graph: dict[sympy.Expr, OrderedSet[sympy.Expr]] | None = None + self._unbacked_replacements: dict[sympy.Expr, sympy.Expr] | None = None + self.trace_asserts = trace_asserts self.specializations: OrderedSet[Specialization] = OrderedSet() @@ -4090,7 +4159,7 @@ def prefer_deferred_runtime_asserts_over_guards(self) -> bool: @contextmanager def patch_source_specialization( self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr] - ) -> Iterator[None]: + ) -> Generator[None, None, None]: """ Temporarily add symbol-level axioms to the ShapeEnv. This is useful when you want to "fork" and have parallel universes of ShapeEnvs. For example, we use this when doing multi-graph @@ -4163,6 +4232,9 @@ def check_equal(self, other: ShapeEnv) -> None: "_resimplify_floor_div_axioms", "_expr_sym_node_id", "specialization_stacks", + # Cached state for optimization_hint unbacked canonicalization + "_equality_graph", + "_unbacked_replacements", ) # Mapping of the value of each to-be-compared field into the values that @@ -4217,7 +4289,7 @@ def _last_event_index(self) -> int: return len(self.events) - 1 @contextmanager - def _recording(self) -> Iterator[None]: + def _recording(self) -> Generator[None, None, None]: self.is_recording = True try: yield @@ -4229,7 +4301,9 @@ def _eliminate_unbacked(self, orig_s: sympy.Symbol, new_s: sympy.Expr) -> None: self._set_replacement(orig_s, new_s, "eliminate_unbacked") @record_shapeenv_event() - def set_real_tensor_prop_unbacked_vals(self, k: sympy.Symbol, v: int) -> None: + def set_real_tensor_prop_unbacked_vals( + self, k: sympy.Symbol, v: int | float | torch.Tensor + ) -> None: """Used only when propagate_real_tensors; registers a value for an unbacked symbol, which can be used last resort to resolve hints.""" log.info("set_real_tensor_prop_unbacked_vals %s = %s", k, v) @@ -4355,7 +4429,7 @@ def _ignore_fresh_unbacked_symbols_set(self, b: bool) -> bool: return prev @contextmanager - def ignore_fresh_unbacked_symbols(self) -> Iterator[None]: + def ignore_fresh_unbacked_symbols(self) -> Generator[None, None, None]: """ Indicates that the newly allocated unbacked SymInts are being discarded @@ -4414,8 +4488,8 @@ def _check_translation_validate(self) -> None: @record_shapeenv_event() def _create_fx_call_function( self, - op: Callable, - args: tuple, + op: Callable[..., object], + args: tuple[Any, ...], ) -> tuple[torch.fx.Node | None, bool]: # Cache this tuple in order to avoid duplicated nodes. node_key = (op, args) @@ -4510,7 +4584,7 @@ def suppress_guards(self) -> _GeneratorContextManager[None]: return _suppress_guards(self) @contextmanager - def error_on_new_guards(self) -> Iterator[None]: + def error_on_new_guards(self) -> Generator[None, None, None]: """Context manager that raises _ShapeEnvGuardError if a guard is attempted. Temporarily freezes the ShapeEnv and makes _check_frozen raise @@ -4659,7 +4733,7 @@ def create_symbolic_sizes_strides_storage_offset( # Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape. # produce_guards will trigger specializations on the outer stuff - # Case 2: when the SymInt is unbacked, we will throw an data dependent error in require_hint(). + # Case 2: when the SymInt is unbacked, we will throw a data dependent error in guarding_hint_or_throw(). # # It's probably good for now but it's important to note that this approach has implications for # the original shape_env when checking guards in different order. @@ -4697,7 +4771,7 @@ def _maybe_specialize_sym_int_with_hint( raise AssertionError( "expect the symbol is created from an shape env other than current one." ) - return maybe_sym.node.require_hint() + return guarding_hint_or_throw(maybe_sym.node) return maybe_sym @record_shapeenv_event() @@ -4783,6 +4857,27 @@ def _create_symbolic_sizes_strides_storage_offset( size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple( ex_size, source, symbolic_context, hint_overrides=hint_overrides ) + # Record tensor exclusion constraints for stable graph selection. + # The ndim check guards against stale excluded_sizes from graph + # breaks where the resumed tensor may have different dimensionality. + # Skip dims with hint overrides: the overridden hint in + # backed_var_to_val would mismatch the excluded value, causing the + # not-all check in produce_guards_verbose to emit a guard that + # immediately fails. + excluded_sizes = getattr(symbolic_context, "excluded_sizes", None) + if ( + excluded_sizes + and len(excluded_sizes) == dim + and any(v is not None for v in excluded_sizes) + ): + for i in range(dim): + ev = excluded_sizes[i] + if ( + ev is not None + and isinstance(size[i], sympy.Symbol) + and i not in (hint_overrides or {}) + ): + self._record_exclusion_constraint(size[i], ev) stride = self._compute_symbolic_stride( source, size, @@ -4904,7 +4999,7 @@ def create_symintnode( self, sym: sympy.Expr, *, - hint: int | None, + hint: int | float | bool | torch.SymInt | None, source: Source | None = None, ) -> IntLikeType: """Create a SymInt value from a symbolic expression @@ -4950,7 +5045,7 @@ def create_symfloatnode( self, sym: sympy.Expr, *, - hint: int | float | bool | None, + hint: int | float | bool | torch.SymInt | None, source: Source | None = None, ) -> FloatLikeType: """Create a SymFloat value from a symbolic expression""" @@ -4987,17 +5082,28 @@ def create_symfloatnode( out = SymFloat(SymNode(sym, self, float, hint, fx_node=fx_node)) return out + @record_shapeenv_event() + def _record_exclusion_constraint(self, sym: sympy.Symbol, val: int) -> None: + self.exclusion_constraints.append((sym, val)) + @record_shapeenv_event() def create_unspecified_symint_and_symbol( - self, value: int, source: Source, dynamic_dim: DimDynamic + self, + value: int, + source: Source, + dynamic_dim: DimDynamic, + excluded_value: int | None = None, ) -> IntLikeType: """Create a SymInt wrapping a new unspecified symbol""" + sym = self.create_unspecified_symbol( + value, + source=source, + dynamic_dim=dynamic_dim, + ) + if excluded_value is not None: + self._record_exclusion_constraint(sym, excluded_value) return self.create_symintnode( - self.create_unspecified_symbol( - value, - source=source, - dynamic_dim=dynamic_dim, - ), + sym, hint=value, source=source, ) @@ -5012,7 +5118,7 @@ def _log_create_unbacked_symbol( self, prefix: str, symbol: sympy.Symbol, - vr: ValueRanges, + vr: ValueRanges[sympy.Expr], source: Source | None = None, sym_node: SymNode | None = None, ) -> None: @@ -5136,7 +5242,7 @@ def create_unspecified_symbol( source: Source, dynamic_dim: DimDynamic = DimDynamic.DUCK, constraint_dim: DimConstraint = None, # NB: includes None - symbolic_context: StatelessSymbolicContext | None = None, + symbolic_context: SymbolicContext | None = None, ) -> sympy.Expr: """ Create a symbol with an unspecified value @@ -5162,13 +5268,13 @@ def create_unspecified_symbol( @record_shapeenv_event() def create_symbol( self, - val: int, + val: IntLikeType | FloatLikeType, source: Source, dynamic_dim: DimDynamic = DimDynamic.DUCK, constraint_dim: DimConstraint = None, # NB: includes None positive: bool | None = True, do_not_specialize_zero_one: bool = False, - symbolic_context: StatelessSymbolicContext | None = None, + symbolic_context: SymbolicContext | None = None, ) -> sympy.Expr: """Create a new symbol which is tracked by this ShapeEnv""" # check if constraint_dim is actually static integer @@ -5182,7 +5288,7 @@ def create_symbol( f"Static shape constraint of {constraint_dim.vr.lower} does not match input size of {val}, " f"for {source.name}" ) - if symbolic_context: + if isinstance(symbolic_context, StatelessSymbolicContext): from torch._dynamo.source import TensorPropertySource if not isinstance(source, TensorPropertySource): @@ -5589,7 +5695,9 @@ def produce_guards_verbose( raise AssertionError(f"len({placeholders}) != len({sources})") Tensorlike = (torch.Tensor, FakeTensorMeta) - def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext: + def _create_no_constraints_context( + t: Tensor, + ) -> StatelessSymbolicContext[..., Any]: return StatelessSymbolicContext( # Ignored; only the constraints part is relevant below. dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(), @@ -5913,7 +6021,7 @@ def track_symfloat(source: Source, val: FloatLikeType) -> None: s = sympy.Float(val) input_guards.append((source, s)) - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] for t, source, context in zip(placeholders, sources, input_contexts): if isinstance(source, str): from torch._dynamo.source import LocalSource @@ -6209,7 +6317,7 @@ def issue_guard(guard: ShapeGuard) -> None: if not sources: raise AssertionError(f"sources must not be empty for symbol {symbol}") - bounds = [] + bounds: list[sympy.Basic] = [] rf = source_ref(sources[0]) verbose_expr = "" if r.lower not in (-sympy.oo, -int_oo): @@ -6279,6 +6387,93 @@ def issue_guard(guard: ShapeGuard) -> None: else: raise NotImplementedError(f"Unimplemented for lang: {lang}") + # Exclusion guard for stable graph selection with automatic dynamic. + # + # When automatic_dynamic promotes a static dim to dynamic, the new + # (more general) graph is inserted *before* the old (specialized) graph + # in the guard cache. Without an exclusion guard, inputs that exactly + # match the old graph's static sizes would be captured by the new + # dynamic graph instead, violating the invariant "once an input is + # served by graph X it is always served by graph X". This condition + # is true iff there is no branching on dynamic shapes. + # + # Soundness argument (cache-flip / LIFO order): + # Graph_new sits before Graph_old in the cache. Graph_old accepts + # only inputs whose sizes match its static constraints exactly. + # Graph_new must therefore reject exactly that set of inputs so they + # fall through to Graph_old. The excluded values are the static + # sizes from Graph_old, so the guard + # Or(Ne(s0, v0), Ne(s1, v1), ...) + # passes iff at least one dim differs from the old sizes — i.e. the + # input does NOT fully match Graph_old. Conversely, when every dim + # matches the old sizes the guard fails and the input falls through + # to Graph_old, which is guaranteed to accept it. + # + # Theorem: For graphs G0, ..., Gn compiled via progressive dynamism + # (one dim per step), each input is accepted by at most one graph. + # + # Setup: G0 is all-static with shape S. Gk is created by making + # dim d_k dynamic, with exclusion guard d_k != S[d_k]. + # + # Proof by induction on n: + # + # Base case (n=0): Only G0, all-static. Trivially unique. + # + # Inductive step: Assume the property holds for G0, ..., G_{n-1}. + # We add Gn with newly-dynamic dim d_n and exclusion d_n != S[d_n]. + # + # For any input X that passes Gn's shape guards, exactly one of: + # + # Case A — exclusion passes (X[d_n] != S[d_n]): + # Dim d_n is static in all G0, ..., G_{n-1} with value S[d_n], + # so X fails all prior graphs on that dim. Only Gn accepts X. + # + # Case B — exclusion rejects (X[d_n] == S[d_n]): + # X matches Gn's shape guards on all other dims, and matches + # the static value for d_n. So X satisfies G_{n-1}'s shape + # guards. By the inductive hypothesis, exactly one of + # G0, ..., G_{n-1} accepts X. Gn rejects X. + # + # Corollary: Evaluation order does not affect correctness. + # + # All exclusion pairs across all tensors and scalars are flattened + # into a single list — each pair is just (symbol, excluded_int), + # and the multi-tensor case is the same logic as multi-dim within + # one tensor. The combined Or rejects only when ALL pairs match + # simultaneously, which is the exact condition for Graph_old to + # accept. If the current concrete values already match every + # excluded value the guard is skipped (it would fail on creation). + import torch._dynamo.config as dynamo_config + + if ( + dynamo_config.automatic_dynamic_exclusion_guard + and not dynamo_config.enable_compiler_collectives + and self.exclusion_constraints + ): + all_pairs = [ + (sym, val) + for sym, val in self.exclusion_constraints + if symbol_to_source.get(sym) + ] + if all_pairs and not all( + self.backed_var_to_val.get(sym) == val for sym, val in all_pairs + ): + if len(all_pairs) == 1: + excl_expr = sympy.Ne( + all_pairs[0][0], all_pairs[0][1], evaluate=False + ) + else: + excl_expr = sympy.Or( + *[sympy.Ne(sym, val, evaluate=False) for sym, val in all_pairs] + ) + for exprs, printer, lang in zip(all_exprs, printers, langs): + guard_expr = printer.doprint(excl_expr) + if lang == "verbose_python": + guard_expr = ( + f"{guard_expr} # exclusion guard for automatic dynamic" + ) + exprs.append(guard_expr) + if constraint_violations: warn_msgs: list[str] = [] error_msgs: list[str] = [] @@ -6513,7 +6708,7 @@ def format_guards(self, verbose: bool = False) -> str: def bound_sympy( self, expr: sympy.Expr, size_oblivious: bool = False - ) -> ValueRanges: + ) -> ValueRanges[sympy.Expr]: """Given a sympy expression, computes a ValueRanges bound for what values it can be""" # TODO: maybe it's guaranteed x in is var_to_range? var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols} @@ -6676,7 +6871,7 @@ def _maybe_evaluate_static( compute_hint: bool = False, size_oblivious: bool = False, axioms: tuple[SympyBoolean] | None = None, - var_to_range: tuple[tuple[sympy.Symbol, ValueRanges]] | None = None, + var_to_range: tuple[tuple[sympy.Symbol, ValueRanges[sympy.Expr]]] | None = None, ) -> sympy.Basic | None: """ Tries to evaluate expr without introducing guards @@ -6769,7 +6964,7 @@ def replace(self, expr: _SympyT) -> _SympyT: tracks only replacement changes) to cache calls to this method, so depending on other state would cause stale cache results. """ - replacements = {} + replacements: dict[sympy.Basic, sympy.Basic] = {} # pyrefly: ignore [missing-attribute] for s in expr.free_symbols: r = self._find(s) @@ -6805,7 +7000,7 @@ def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT: # Simplify max(0/1, x) to x when x >= 0/1. max(1, x) is a commonly introduced # expression when creating contiguous strides. if not size_oblivious: - min_max_replacements = {} + min_max_replacements: dict[sympy.Basic, sympy.Basic] = {} for atom in expr.atoms(Max): # type: ignore[has-type] if len(atom.args) > 2: continue @@ -6821,11 +7016,11 @@ def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT: expr = expr.xreplace(min_max_replacements) if expr.has(TruncToInt): - trunc_replacements = {} + trunc_replacements: dict[sympy.Basic, sympy.Basic] = {} for atom in expr.atoms(TruncToInt): if isinstance(atom.args[0], IntTrueDiv): base, divisor = atom.args[0].args - if base % divisor == 0: + if Mod(base, divisor) == 0: trunc_replacements[atom] = CleanDiv(base, divisor) else: # TruncToInt(IntTrueDiv(a,b)) == FloorDiv(a, b) @@ -6841,7 +7036,7 @@ def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT: # for now just do a separate pass to catch common nested case if expr.has(FloorDiv): self._update_divisible() - div_replacements = {} + div_replacements: dict[sympy.Basic, sympy.Basic] = {} for atom in expr.atoms(FloorDiv): base, divisor = atom.args if isinstance(divisor, FloorDiv): @@ -6856,7 +7051,7 @@ def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT: expr = expr.xreplace(div_replacements) expr = safe_expand(expr) if expr.has(FloorDiv): - div_replacements = {} + div_replacements: dict[sympy.Basic, sympy.Basic] = {} pows = expr.atoms(sympy.Pow) rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer)) for fd in expr.atoms(FloorDiv): @@ -6876,6 +7071,10 @@ def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT: return expr # TODO: overload for allow_none literal + @deprecated( + "use guarding_hint_or_throw or optimization_hint instead", + category=FutureWarning, + ) @lru_cache(256) def size_hint( self, expr: sympy.Basic, *, allow_none: bool = False @@ -6922,13 +7121,46 @@ def size_hint( raise self._make_data_dependent_error(result_expr, expr) return result_expr - # NB: keep in sync with size_hint @lru_cache(256) - def has_hint(self, expr: sympy.Expr) -> bool: - result_expr = safe_expand(expr).xreplace(self.backed_var_to_val) - return ( - result_expr.is_number - or self._maybe_evaluate_static(result_expr) is not None + def guarding_hint_or_throw(self, expr: sympy.Expr | int) -> int | bool: + """ + Return a concrete hint for an expression. + + Returns Python bool (True/False) for boolean expressions (e.g. Eq, Ne), + and Python int for integer expressions. + """ + return _guarding_hint_or_throw_base(self, expr, {}) + + @lru_cache(256) + def has_guarding_hint(self, expr: sympy.Expr) -> bool: + try: + self.guarding_hint_or_throw(expr) + except GuardOnDataDependentSymNode: + return False + return True + + def optimization_hint( + self, expr: sympy.Expr | int, fallback: int | None = None + ) -> int: + """ + Return a concrete integer hint for an expression. + + This function should be used for non-guarding based optimizations. If you + want a hint that you can guard on, use the guarding_hint API instead. + + This function will hint unbacked symbols using user provided optimization + hints. If not provided, fallback will be used along with some heuristics + that try to maximize consistency with the shape environment. + + Special cases: + + - Complex numbers (containing sympy.I): raises an error since tensor + dimensions cannot be complex. + - Infinity (int_oo, sympy.oo): returns sys.maxsize. + - NaN (sympy.nan): returns the fallback value. + """ + return _optimization_hint_base( + self, expr, precomputed_replacements={}, fallback=fallback ) def _make_data_dependent_error( @@ -6993,7 +7225,7 @@ def _make_data_dependent_error( def _update_var_to_range( self, symbol: sympy.Symbol, - vr: ValueRanges, + vr: ValueRanges[sympy.Expr], vr_sloc: ValueRangesSLoc | None = None, *, is_constraint: bool = False, @@ -7381,11 +7613,11 @@ def trivial_solve(lhs: sympy.Expr, rhs: sympy.Expr) -> bool: # See: Note - On 0/1 specialization def _default_value_range( self, do_not_specialize_zero_one: bool = False - ) -> ValueRanges: + ) -> ValueRanges[sympy.Expr]: lower = 0 if (do_not_specialize_zero_one or not self.specialize_zero_one) else 2 return ValueRanges(lower, int_oo) - def _default_unspecified_value_range(self) -> ValueRanges: + def _default_unspecified_value_range(self) -> ValueRanges[sympy.Expr]: return ValueRanges.unknown_int() @_lru_cache @@ -7636,8 +7868,8 @@ def evaluate_sym_node( self._expr_sym_node_id = id(sym_node) return self.evaluate_expr( sym_node.expr, - sym_node.hint, - sym_node.fx_node, + sym_node.hint, # pyrefly: ignore[bad-argument-type] + sym_node.fx_node, # pyrefly: ignore[bad-argument-type] size_oblivious, fallback_value=fallback_value, ) @@ -7807,11 +8039,21 @@ def _evaluate_expr( def compute_concrete_val() -> sympy.Basic: if hint is None: # This is only ever called for expressions WITHOUT unbacked - # symbols - r = self.size_hint(orig_expr) - if r is None: - raise AssertionError("r must not be None") - return r + # symbols. guarding_hint_or_throw returns Python bool for + # boolean expressions and int for integer expressions; + # sympify converts them to the proper sympy types + # (True -> sympy.true, 5 -> Integer(5)). + try: + return sympy.sympify(self.guarding_hint_or_throw(orig_expr)) + except GuardOnDataDependentSymNode: + # guarding_hint_or_throw only does backed-symbol replacement. + # For expressions with unbacked symbols resolvable via axioms + # (e.g. Eq(x, 0) when torch._check(Ne(x, 0)) was previously + # asserted), fall back to static evaluation with compute_hint. + r = self._maybe_evaluate_static(orig_expr, compute_hint=True) + if r is not None: + return r + raise else: return sympy.sympify(hint) @@ -8474,6 +8716,8 @@ def _remove_effect_token_unbacked_bindings( # that we apply is unbacked renaming. def _get_placeholder_expr(sym_node: SymNode) -> sympy.Expr: shape_env = sym_node.shape_env + if shape_env is None: + raise AssertionError("shape_env is required for _get_placeholder_expr") result = sym_node._expr if result in shape_env.unbacked_renamings: return shape_env.unbacked_renamings[result] diff --git a/torch/fx/experimental/unification/__init__.py b/torch/fx/experimental/unification/__init__.py index 7db0e29d1d4f7..017bf69352738 100644 --- a/torch/fx/experimental/unification/__init__.py +++ b/torch/fx/experimental/unification/__init__.py @@ -1,4 +1,4 @@ # mypy: disable-error-code=attr-defined -from .core import reify, unify # noqa: F403 -from .more import unifiable # noqa: F403 -from .variable import isvar, Var, var, variables, vars # noqa: F403 +from .core import reify, unify +from .more import unifiable +from .variable import isvar, Var, var, variables, vars diff --git a/torch/fx/experimental/unification/core.py b/torch/fx/experimental/unification/core.py index 3d8071c847ae5..0ec0974da181a 100644 --- a/torch/fx/experimental/unification/core.py +++ b/torch/fx/experimental/unification/core.py @@ -1,6 +1,9 @@ -# mypy: allow-untyped-defs -from collections.abc import Iterator # type: ignore[import] +from __future__ import annotations + +from collections.abc import Iterator, Sequence from functools import partial +from typing import TYPE_CHECKING +from typing_extensions import TypeVarTuple, Unpack from .dispatch import dispatch from .unification_tools import assoc # type: ignore[import] @@ -8,6 +11,12 @@ from .variable import isvar +if TYPE_CHECKING: + from .variable import Var + + +_Ts = TypeVarTuple("_Ts") + __all__ = ["reify", "unify"] ############### @@ -16,7 +25,7 @@ @dispatch(Iterator, dict) -def _reify(t, s): +def _reify(t: Iterator[object], s: dict[Var, object]) -> Iterator[object]: return map(partial(reify, s=s), t) # return (reify(arg, s) for arg in t) @@ -25,23 +34,23 @@ def _reify(t, s): @dispatch(tuple, dict) # type: ignore[no-redef] -def _reify(t, s): - return tuple(reify(iter(t), s)) +def _reify(t: tuple[Unpack[_Ts]], s: dict[Var, object]) -> tuple[Unpack[_Ts]]: + return tuple(reify(iter(t), s)) # pyrefly: ignore[bad-argument-type, bad-return] _reify @dispatch(list, dict) # type: ignore[no-redef] -def _reify(t, s): - return list(reify(iter(t), s)) +def _reify(t: list[object], s: dict[Var, object]) -> list[object]: + return list(reify(iter(t), s)) # pyrefly: ignore[bad-argument-type] _reify @dispatch(dict, dict) # type: ignore[no-redef] -def _reify(d, s): +def _reify(d: dict[object, object], s: dict[Var, object]) -> dict[object, object]: return {k: reify(v, s) for k, v in d.items()} @@ -49,11 +58,11 @@ def _reify(d, s): @dispatch(object, dict) # type: ignore[no-redef] -def _reify(o, s): +def _reify(o: object, s: dict[Var, object]) -> object: return o # catch all, just return the object -def reify(e, s): +def reify(e: object, s: dict[Var, object]) -> object: """Replace variables of expression with substitution >>> # xdoctest: +SKIP >>> x, y = var(), var() @@ -78,11 +87,13 @@ def reify(e, s): @dispatch(seq, seq, dict) # type: ignore[arg-type] -def _unify(u, v, s): +def _unify( + u: Sequence[object], v: Sequence[object], s: dict[Var, object] +) -> dict[Var, object] | bool: if len(u) != len(v): return False for uu, vv in zip(u, v): # avoiding recursion - s = unify(uu, vv, s) + s = unify(uu, vv, s) # pyrefly: ignore[bad-assignment] if s is False: return False return s @@ -116,7 +127,9 @@ def _unify(u, v, s): @dispatch(object, object, dict) -def unify(u, v, s): # no check at the moment +def unify( + u: object, v: object, s: dict[Var, object] +) -> dict[Var, object] | bool: # no check at the moment """Find substitution so that u == v while satisfying s >>> x = var("x") >>> unify((1, x), (1, 2), {}) @@ -137,5 +150,5 @@ def unify(u, v, s): # no check at the moment @dispatch(object, object) # type: ignore[no-redef] -def unify(u, v): +def unify(u: object, v: object) -> dict[Var, object] | bool: return unify(u, v, {}) diff --git a/torch/fx/experimental/unification/match.py b/torch/fx/experimental/unification/match.py index 01861a086f64b..19fc2805b628e 100644 --- a/torch/fx/experimental/unification/match.py +++ b/torch/fx/experimental/unification/match.py @@ -1,33 +1,45 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import TYPE_CHECKING + from .core import reify, unify # type: ignore[attr-defined] + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + from .variable import Var + from .unification_tools import first, groupby # type: ignore[import] from .utils import _toposort, freeze from .variable import isvar class Dispatcher: - def __init__(self, name): + def __init__(self, name: str) -> None: self.name = name - self.funcs = {} - self.ordering = [] + self.funcs: dict[object, Callable[..., object]] = {} + self.ordering: list[object] = [] - def add(self, signature, func): + def add(self, signature: tuple[object, ...], func: Callable[..., object]) -> None: self.funcs[freeze(signature)] = func self.ordering = ordering(self.funcs) - def __call__(self, *args, **kwargs): + def __call__(self, *args: object, **kwargs: object) -> object: func, _ = self.resolve(args) return func(*args, **kwargs) - def resolve(self, args): + def resolve( + self, args: tuple[object, ...] + ) -> tuple[Callable[..., object], dict[Var, object]]: n = len(args) for signature in self.ordering: - if len(signature) != n: + if len(signature) != n: # pyrefly: ignore[bad-argument-type] continue s = unify(freeze(args), signature) if s is not False: result = self.funcs[signature] - return result, s + return result, s # pyrefly: ignore[bad-return] raise NotImplementedError( "No match found. \nKnown matches: " + str(self.ordering) @@ -35,8 +47,8 @@ def resolve(self, args): + str(args) ) - def register(self, *signature): - def _(func): + def register(self, *signature: object) -> Callable[..., object]: + def _(func: Callable[..., object]) -> Dispatcher: self.add(signature, func) return self @@ -60,24 +72,28 @@ class VarDispatcher(Dispatcher): 20 """ - def __call__(self, *args, **kwargs): + def __call__(self, *args: object, **kwargs: object) -> object: func, s = self.resolve(args) - d = {k.token: v for k, v in s.items()} + d = {k.token: v for k, v in s.items()} # pyrefly: ignore[missing-attribute] return func(**d) -global_namespace = {} # type: ignore[var-annotated] +global_namespace: dict[str, Dispatcher] = {} -def match(*signature, **kwargs): - namespace = kwargs.get("namespace", global_namespace) - dispatcher = kwargs.get("Dispatcher", Dispatcher) +def match(*signature: object, **kwargs: object) -> Callable[..., object]: + namespace: dict[str, Dispatcher] = kwargs.get( # type: ignore[assignment] + "namespace", global_namespace + ) + dispatcher_cls: type[Dispatcher] = kwargs.get( # type: ignore[assignment] + "Dispatcher", Dispatcher + ) - def _(func): + def _(func: Callable[..., object]) -> Dispatcher: name = func.__name__ if name not in namespace: - namespace[name] = dispatcher(name) + namespace[name] = dispatcher_cls(name) d = namespace[name] d.add(signature, func) @@ -87,22 +103,27 @@ def _(func): return _ -def supercedes(a, b): +def supercedes(a: object, b: object) -> bool: """``a`` is a more specific match than ``b``""" if isvar(b) and not isvar(a): return True s = unify(a, b) if s is False: return False - s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)} + s = { + k: v + for k, v in s.items() # pyrefly: ignore[missing-attribute] + if not isvar(k) or not isvar(v) + } if reify(a, s) == a: return True if reify(b, s) == b: return False + return False # Taken from multipledispatch -def edge(a, b, tie_breaker=hash): +def edge(a: object, b: object, tie_breaker: Callable[[object], int] = hash) -> bool: """A should be checked before B Tie broken by tie_breaker, defaults to ``hash`` """ @@ -115,15 +136,15 @@ def edge(a, b, tie_breaker=hash): # Taken from multipledispatch -def ordering(signatures): +def ordering(signatures: Iterable[object]) -> list[object]: """A sane ordering of signatures to check, first to last Topological sort of edges as given by ``edge`` and ``supercedes`` """ - signatures = list(map(tuple, signatures)) + signatures = list(map(tuple, signatures)) # pyrefly: ignore[bad-argument-type] edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] edges = groupby(first, edges) for s in signatures: if s not in edges: edges[s] = [] edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment] - return _toposort(edges) + return _toposort(edges) # pyrefly: ignore[bad-argument-type] diff --git a/torch/fx/experimental/unification/more.py b/torch/fx/experimental/unification/more.py index a5d46f3f256e5..41cceef3e2b3a 100644 --- a/torch/fx/experimental/unification/more.py +++ b/torch/fx/experimental/unification/more.py @@ -1,4 +1,7 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import TYPE_CHECKING + from .core import ( # type: ignore[attr-defined] _reify as core_reify, _unify as core_unify, @@ -8,10 +11,14 @@ from .dispatch import dispatch +if TYPE_CHECKING: + from .variable import Var + + __all__ = ["unifiable", "reify_object", "unify_object"] -def unifiable(cls): +def unifiable(cls: type) -> type: """Register standard unify and reify operations on class This uses the type and __dict__ or __slots__ attributes to define the nature of the term @@ -40,7 +47,7 @@ def unifiable(cls): ######### -def reify_object(o, s): +def reify_object(o: object, s: dict[Var, object]) -> object: """Reify a Python object with a substitution >>> # xdoctest: +SKIP >>> class Foo(object): @@ -63,32 +70,38 @@ def reify_object(o, s): return _reify_object_dict(o, s) -def _reify_object_dict(o, s): +def _reify_object_dict(o: object, s: dict[Var, object]) -> object: obj = object.__new__(type(o)) - d = reify(o.__dict__, s) - if d == o.__dict__: + d = reify(o.__dict__, s) # pyrefly: ignore[missing-attribute] + if d == o.__dict__: # pyrefly: ignore[missing-attribute] return o - obj.__dict__.update(d) + obj.__dict__.update(d) # pyrefly: ignore[missing-attribute, no-matching-overload] return obj -def _reify_object_slots(o, s): - attrs = [getattr(o, attr) for attr in o.__slots__] +def _reify_object_slots(o: object, s: dict[Var, object]) -> object: + attrs = [ + getattr(o, attr) + for attr in o.__slots__ # pyrefly: ignore[missing-attribute] + ] new_attrs = reify(attrs, s) if attrs == new_attrs: return o else: newobj = object.__new__(type(o)) - for slot, attr in zip(o.__slots__, new_attrs): + for slot, attr in zip( + o.__slots__, # pyrefly: ignore[missing-attribute] + new_attrs, # pyrefly: ignore[bad-argument-type] + ): setattr(newobj, slot, attr) return newobj @dispatch(slice, dict) -def _reify(o, s): +def _reify(o: slice, s: dict[Var, object]) -> slice: """Reify a Python ``slice`` object""" - return slice(*reify((o.start, o.stop, o.step), s)) + return slice(*reify((o.start, o.stop, o.step), s)) # pyrefly: ignore[not-iterable] ######### @@ -96,7 +109,9 @@ def _reify(o, s): ######### -def unify_object(u, v, s): +def unify_object( + u: object, v: object, s: dict[Var, object] +) -> dict[Var, object] | bool: """Unify two Python objects Unifies their type and ``__dict__`` attributes >>> # xdoctest: +SKIP @@ -117,15 +132,25 @@ def unify_object(u, v, s): return False if hasattr(u, "__slots__"): return unify( - [getattr(u, slot) for slot in u.__slots__], - [getattr(v, slot) for slot in v.__slots__], + [ + getattr(u, slot) + for slot in u.__slots__ # pyrefly: ignore[missing-attribute] + ], + [ + getattr(v, slot) + for slot in v.__slots__ # pyrefly: ignore[missing-attribute] + ], s, ) else: - return unify(u.__dict__, v.__dict__, s) + return unify( + u.__dict__, # pyrefly: ignore[missing-attribute] + v.__dict__, # pyrefly: ignore[missing-attribute] + s, + ) @dispatch(slice, slice, dict) -def _unify(u, v, s): +def _unify(u: slice, v: slice, s: dict[Var, object]) -> dict[Var, object] | bool: """Unify a Python ``slice`` object""" return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s) diff --git a/torch/fx/experimental/unification/multipledispatch/conflict.py b/torch/fx/experimental/unification/multipledispatch/conflict.py index 9b6e33805a4c7..9e249e2d59250 100644 --- a/torch/fx/experimental/unification/multipledispatch/conflict.py +++ b/torch/fx/experimental/unification/multipledispatch/conflict.py @@ -1,5 +1,11 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + import operator +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Sequence from .utils import _toposort, groupby from .variadic import isvariadic @@ -21,7 +27,7 @@ class AmbiguityWarning(Warning): pass -def supercedes(a, b): +def supercedes(a: tuple[type, ...], b: tuple[type, ...]) -> bool: """A is consistent and strictly more specific than B""" if len(a) < len(b): # only case is if a is empty and b is variadic @@ -57,7 +63,7 @@ def supercedes(a, b): return p2 == len(b) - 1 and p1 == len(a) -def consistent(a, b): +def consistent(a: tuple[type, ...], b: tuple[type, ...]) -> bool: """It is possible for an argument list to satisfy both A and B""" # Need to check for empty args @@ -94,12 +100,14 @@ def consistent(a, b): ) -def ambiguous(a, b): +def ambiguous(a: tuple[type, ...], b: tuple[type, ...]) -> bool: """A is consistent with B but neither is strictly more specific""" return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) -def ambiguities(signatures): +def ambiguities( + signatures: Iterable[tuple[type, ...]], +) -> set[tuple[tuple[type, ...], tuple[type, ...]]]: """All signature pairs such that A is ambiguous with B""" signatures = list(map(tuple, signatures)) return { @@ -112,7 +120,7 @@ def ambiguities(signatures): } -def super_signature(signatures): +def super_signature(signatures: Sequence[tuple[type, ...]]) -> list[type]: """A signature that would break ambiguities""" n = len(signatures[0]) if not all(len(s) == n for s in signatures): @@ -121,7 +129,11 @@ def super_signature(signatures): return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] for i in range(n)] -def edge(a, b, tie_breaker=hash): +def edge( + a: tuple[type, ...], + b: tuple[type, ...], + tie_breaker: Callable[[tuple[type, ...]], int] = hash, +) -> bool: """A should be checked before B Tie broken by tie_breaker, defaults to ``hash`` """ @@ -132,7 +144,7 @@ def edge(a, b, tie_breaker=hash): ) -def ordering(signatures): +def ordering(signatures: Iterable[tuple[type, ...]]) -> list[tuple[type, ...]]: """A sane ordering of signatures to check, first to last Topological sort of edges as given by ``edge`` and ``supercedes`` """ @@ -142,5 +154,10 @@ def ordering(signatures): for s in signatures: if s not in edges: edges[s] = [] - edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[assignment, attr-defined] - return _toposort(edges) + topo_edges: dict[ + tuple[type, ...], list[tuple[type, ...]] + ] = { # pyrefly: ignore[bad-assignment] + k: [b for a, b in v] + for k, v in edges.items() # type: ignore[attr-defined] + } + return _toposort(topo_edges) diff --git a/torch/fx/experimental/unification/multipledispatch/core.py b/torch/fx/experimental/unification/multipledispatch/core.py index 69b9f3b2b5a2c..d44ef661c3feb 100644 --- a/torch/fx/experimental/unification/multipledispatch/core.py +++ b/torch/fx/experimental/unification/multipledispatch/core.py @@ -1,23 +1,27 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + import inspect -from collections.abc import Callable -from typing import Any, TypeVar +from typing import Any, TYPE_CHECKING, TypeVar from typing_extensions import TypeVarTuple, Unpack + +if TYPE_CHECKING: + from collections.abc import Callable + from .dispatcher import Dispatcher, MethodDispatcher -global_namespace = {} # type: ignore[var-annotated] +global_namespace: dict[str, Dispatcher] = {} __all__ = ["dispatch", "ismethod"] -T = TypeVar("T") -Ts = TypeVarTuple("Ts") +_T = TypeVar("_T") +_Ts = TypeVarTuple("_Ts") def dispatch( - *types: Unpack[Ts], **kwargs: Any -) -> Callable[[Callable[..., T]], Callable[..., T]]: + *types: Unpack[_Ts], **kwargs: Any +) -> Callable[[Callable[..., _T]], Callable[..., _T]]: """Dispatch function on the types of the inputs Supports dispatch on all non-keyword arguments. Collects implementations based on the function name. Ignores namespaces. @@ -60,7 +64,7 @@ def dispatch( types_tuple: tuple[type, ...] = tuple(types) # type: ignore[arg-type] - def _df(func): + def _df(func: Callable[..., _T]) -> Callable[..., _T]: name = func.__name__ if ismethod(func): @@ -76,10 +80,10 @@ def _df(func): dispatcher.add(types_tuple, func) return dispatcher - return _df + return _df # type: ignore[return-value] -def ismethod(func): +def ismethod(func: Callable[..., object]) -> bool: """Is func a method? Note that this has to work as the method is defined but before the class is defined. At this stage methods look like functions. @@ -89,4 +93,4 @@ def ismethod(func): return signature.parameters.get("self", None) is not None else: spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment] - return spec and spec.args and spec.args[0] == "self" + return bool(spec and spec.args and spec.args[0] == "self") diff --git a/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/torch/fx/experimental/unification/multipledispatch/dispatcher.py index 6ecb4502b978e..e5d44cf4248e2 100644 --- a/torch/fx/experimental/unification/multipledispatch/dispatcher.py +++ b/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -1,14 +1,23 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + import inspect import itertools as itl +from typing import Any, TYPE_CHECKING, TypeVar from typing_extensions import deprecated from warnings import warn + +if TYPE_CHECKING: + from collections.abc import Callable, Generator, Iterable, Iterator + from .conflict import ambiguities, AmbiguityWarning, ordering, super_signature from .utils import expand_tuples from .variadic import isvariadic, Variadic +_T = TypeVar("_T") + + __all__ = [ "MDNotImplementedError", "ambiguity_warn", @@ -28,7 +37,9 @@ class MDNotImplementedError(NotImplementedError): """A NotImplementedError for multiple dispatch""" -def ambiguity_warn(dispatcher, ambiguities): +def ambiguity_warn( + dispatcher: Dispatcher, ambiguities: set[tuple[tuple[type, ...], tuple[type, ...]]] +) -> None: """Raise warning when ambiguity is detected. Parameters @@ -50,7 +61,7 @@ def ambiguity_warn(dispatcher, ambiguities): "`halt_ordering` is deprecated, you can safely remove this call.", category=FutureWarning, ) -def halt_ordering(): +def halt_ordering() -> None: """Deprecated interface to temporarily disable ordering.""" @@ -59,11 +70,13 @@ def halt_ordering(): "you should call the `reorder()` method on each dispatcher.", category=FutureWarning, ) -def restart_ordering(on_ambiguity=ambiguity_warn): +def restart_ordering(on_ambiguity: Callable[..., None] = ambiguity_warn) -> None: """Deprecated interface to temporarily resume ordering.""" -def variadic_signature_matches_iter(types, full_signature): +def variadic_signature_matches_iter( + types: tuple[type, ...], full_signature: tuple[type, ...] +) -> Generator[bool, None, None]: """Check if a set of input types matches a variadic signature. Notes @@ -105,7 +118,9 @@ def variadic_signature_matches_iter(types, full_signature): yield False -def variadic_signature_matches(types, full_signature): +def variadic_signature_matches( + types: tuple[type, ...], full_signature: tuple[type, ...] +) -> bool: # No arguments always matches a variadic signature if not full_signature: raise AssertionError("full_signature is empty") @@ -133,14 +148,16 @@ class Dispatcher: __slots__ = "__name__", "name", "funcs", "_ordering", "_cache", "doc" - def __init__(self, name, doc=None): + def __init__(self, name: str, doc: str | None = None) -> None: self.name = self.__name__ = name - self.funcs = {} + self.funcs: dict[tuple[type, ...], Callable[..., object]] = {} self.doc = doc - self._cache = {} + self._cache: dict[tuple[type, ...], Callable[..., object]] = {} - def register(self, *types, **kwargs): + def register( + self, *types: type, **kwargs: object + ) -> Callable[[Callable[..., _T]], Callable[..., _T]]: """register dispatcher with new implementation >>> # xdoctest: +SKIP >>> f = Dispatcher("f") @@ -162,20 +179,24 @@ def register(self, *types, **kwargs): [3, 2, 1] """ - def _df(func): + def _df(func: Callable[..., _T]) -> Callable[..., _T]: self.add(types, func, **kwargs) # type: ignore[call-arg] return func return _df @classmethod - def get_func_params(cls, func): + def get_func_params( + cls, func: Callable[..., object] + ) -> Iterable[inspect.Parameter] | None: if hasattr(inspect, "signature"): sig = inspect.signature(func) return sig.parameters.values() @classmethod - def get_func_annotations(cls, func): + def get_func_annotations( + cls, func: Callable[..., object] + ) -> tuple[type, ...] | None: """get annotations of function positional parameters""" params = cls.get_func_params(func) if params: @@ -193,7 +214,7 @@ def get_func_annotations(cls, func): if all(ann is not Parameter.empty for ann in annotations): return annotations - def add(self, signature, func): + def add(self, signature: tuple[type, ...], func: Callable[..., object]) -> None: """Add new types/method pair to dispatcher >>> # xdoctest: +SKIP >>> D = Dispatcher("add") @@ -248,7 +269,7 @@ def add(self, signature, func): # pyrefly: ignore [bad-specialization] new_signature.append(Variadic[typ[0]]) else: - new_signature.append(typ) + new_signature.append(typ) # pyrefly: ignore[bad-argument-type] self.funcs[tuple(new_signature)] = func self._cache.clear() @@ -259,20 +280,22 @@ def add(self, signature, func): pass @property - def ordering(self): + def ordering(self) -> list[tuple[type, ...]]: try: return self._ordering except AttributeError: return self.reorder() - def reorder(self, on_ambiguity=ambiguity_warn): + def reorder( + self, on_ambiguity: Callable[..., None] = ambiguity_warn + ) -> list[tuple[type, ...]]: self._ordering = od = ordering(self.funcs) amb = ambiguities(self.funcs) if amb: on_ambiguity(self, amb) return od - def __call__(self, *args, **kwargs): + def __call__(self, *args: object, **kwargs: object) -> object: types = tuple(type(arg) for arg in args) try: func = self._cache[types] @@ -300,12 +323,12 @@ def __call__(self, *args, **kwargs): f"{self.name}: <{str_signature(types)}> found, but none completed successfully", ) from e - def __str__(self): + def __str__(self) -> str: return f"" __repr__ = __str__ - def dispatch(self, *types): + def dispatch(self, *types: type) -> Callable[..., object] | None: """Determine appropriate implementation for this type signature This method is internal. Users should call this object as a function. Implementation resolution occurs within the ``__call__`` method. @@ -331,7 +354,9 @@ def dispatch(self, *types): except StopIteration: return None - def dispatch_iter(self, *types): + def dispatch_iter( + self, *types: type + ) -> Generator[Callable[..., object], None, None]: n = len(types) for signature in self.ordering: if len(signature) == n and all(map(issubclass, types, signature)): @@ -345,17 +370,17 @@ def dispatch_iter(self, *types): @deprecated( "`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning ) - def resolve(self, types): + def resolve(self, types: tuple[type, ...]) -> Callable[..., object] | None: """Determine appropriate implementation for this type signature .. deprecated:: 0.4.4 Use ``dispatch(*types)`` instead """ return self.dispatch(*types) - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: return {"name": self.name, "funcs": self.funcs} - def __setstate__(self, d): + def __setstate__(self, d: dict[str, Any]) -> None: self.name = d["name"] self.funcs = d["funcs"] self._ordering = ordering(self.funcs) @@ -368,7 +393,7 @@ def __doc__(self): # type: ignore[override] if self.doc: docs.append(self.doc) - other = [] + other: list[str] = [] for sig in self.ordering[::-1]: func = self.funcs[sig] if func.__doc__: @@ -384,25 +409,25 @@ def __doc__(self): # type: ignore[override] return "\n\n".join(docs) - def _help(self, *args): + def _help(self, *args: object) -> str | None: return self.dispatch(*map(type, args)).__doc__ - def help(self, *args, **kwargs): + def help(self, *args: object, **kwargs: object) -> None: """Print docstring for the function corresponding to inputs""" print(self._help(*args)) - def _source(self, *args): + def _source(self, *args: object) -> str: func = self.dispatch(*map(type, args)) if not func: raise TypeError("No function found") return source(func) - def source(self, *args, **kwargs): + def source(self, *args: object, **kwargs: object) -> None: """Print source code for the function corresponding to inputs""" print(self._source(*args)) -def source(func): +def source(func: Callable[..., object]) -> str: s = f"File: {inspect.getsourcefile(func)}\n\n" s = s + inspect.getsource(func) return s @@ -417,17 +442,19 @@ class MethodDispatcher(Dispatcher): __slots__ = ("obj", "cls") @classmethod - def get_func_params(cls, func): + def get_func_params( + cls, func: Callable[..., object] + ) -> Iterator[inspect.Parameter] | None: if hasattr(inspect, "signature"): sig = inspect.signature(func) return itl.islice(sig.parameters.values(), 1, None) - def __get__(self, instance, owner): + def __get__(self, instance: object | None, owner: type) -> MethodDispatcher: self.obj = instance self.cls = owner return self - def __call__(self, *args, **kwargs): + def __call__(self, *args: object, **kwargs: object) -> object: types = tuple(type(arg) for arg in args) func = self.dispatch(*types) if not func: @@ -437,7 +464,7 @@ def __call__(self, *args, **kwargs): return func(self.obj, *args, **kwargs) -def str_signature(sig): +def str_signature(sig: Iterable[type]) -> str: """String representation of type signature >>> str_signature((int, float)) 'int, float' @@ -445,7 +472,7 @@ def str_signature(sig): return ", ".join(cls.__name__ for cls in sig) -def warning_text(name, amb): +def warning_text(name: str, amb: set[tuple[tuple[type, ...], tuple[type, ...]]]) -> str: """The text for ambiguity warnings""" text = f"\nAmbiguities exist in dispatched function {name}\n\n" text += "The following signatures may result in ambiguous behavior:\n" diff --git a/torch/fx/experimental/unification/multipledispatch/utils.py b/torch/fx/experimental/unification/multipledispatch/utils.py index f89d31aaef25a..befa822735e90 100644 --- a/torch/fx/experimental/unification/multipledispatch/utils.py +++ b/torch/fx/experimental/unification/multipledispatch/utils.py @@ -1,11 +1,22 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + from collections import OrderedDict +from typing import TYPE_CHECKING, TypeVar + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Mapping, Sequence + +_T = TypeVar("_T") __all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] -def raises(err, lamda): # codespell:ignore lamda +def raises( + err: type[BaseException], + lamda: Callable[[], object], # codespell:ignore lamda +) -> bool: try: lamda() # codespell:ignore lamda return False @@ -13,7 +24,7 @@ def raises(err, lamda): # codespell:ignore lamda return True -def expand_tuples(L): +def expand_tuples(L: Sequence[type | tuple[type, ...]]) -> list[tuple[type, ...]]: """ >>> expand_tuples([1, (2, 3)]) [(1, 2), (1, 3)] @@ -32,7 +43,7 @@ def expand_tuples(L): # Taken from theano/theano/gof/sched.py # Avoids licensing issues because this was written by Matthew Rocklin -def _toposort(edges): +def _toposort(edges: Mapping[_T, Iterable[_T]]) -> list[_T]: """Topological sort algorithm by Kahn [1] - O(nodes + vertices) inputs: edges - a dict of the form {a: {b, c}} where b and c depend on a @@ -64,7 +75,9 @@ def _toposort(edges): return L -def reverse_dict(d): +def reverse_dict( + d: Mapping[_T, Iterable[_T]], +) -> OrderedDict[_T, tuple[_T, ...]]: """Reverses direction of dependence dict. >>> d = {"a": (1, 2), "b": (2, 3), "c": ()} @@ -82,12 +95,14 @@ def reverse_dict(d): for val in d[key]: # pyrefly: ignore [unsupported-operation] result[val] = result.get(val, ()) + (key,) - return result + return result # pyrefly: ignore[bad-return] # Taken from toolz # Avoids licensing issues because this version was authored by Matthew Rocklin -def groupby(func, seq): +def groupby( + func: Callable[[_T], object], seq: Iterable[_T] +) -> OrderedDict[object, list[_T]]: """Group a collection by a key function >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"] >>> groupby(len, names) # doctest: +SKIP @@ -108,7 +123,7 @@ def groupby(func, seq): return d -def typename(type): +def typename(type: type | tuple[type, ...]) -> str: """Get the name of `type`. Parameters ---------- @@ -125,7 +140,7 @@ def typename(type): '(int, float)' """ try: - return type.__name__ + return type.__name__ # pyrefly: ignore[missing-attribute] except AttributeError: if len(type) == 1: return typename(*type) diff --git a/torch/fx/experimental/unification/multipledispatch/variadic.py b/torch/fx/experimental/unification/multipledispatch/variadic.py index 1b5604a152480..211a18b44c0b5 100644 --- a/torch/fx/experimental/unification/multipledispatch/variadic.py +++ b/torch/fx/experimental/unification/multipledispatch/variadic.py @@ -1,4 +1,5 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + from .utils import typename @@ -7,14 +8,14 @@ class VariadicSignatureType(type): # checking if subclass is a subclass of self - def __subclasscheck__(cls, subclass): + def __subclasscheck__(cls, subclass: type) -> bool: other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,) return subclass is cls or all( issubclass(other, cls.variadic_type) # type: ignore[attr-defined] for other in other_type ) - def __eq__(cls, other): + def __eq__(cls, other: object) -> bool: """ Return True if other has the same variadic type Parameters @@ -28,11 +29,11 @@ def __eq__(cls, other): """ return isvariadic(other) and set(cls.variadic_type) == set(other.variadic_type) # type: ignore[attr-defined] - def __hash__(cls): + def __hash__(cls) -> int: return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined] -def isvariadic(obj): +def isvariadic(obj: type) -> bool: """Check whether the type `obj` is variadic. Parameters ---------- @@ -59,7 +60,9 @@ class VariadicSignatureMeta(type): examples of how this behaves. """ - def __getitem__(cls, variadic_type): + def __getitem__( + cls, variadic_type: type | tuple[type, ...] + ) -> VariadicSignatureType: if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): raise ValueError( "Variadic types must be type or tuple of types" diff --git a/torch/fx/experimental/unification/unification_tools.py b/torch/fx/experimental/unification/unification_tools.py index a47d900273f5e..10791a96f34b8 100644 --- a/torch/fx/experimental/unification/unification_tools.py +++ b/torch/fx/experimental/unification/unification_tools.py @@ -1,8 +1,20 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + import collections import operator from collections.abc import Mapping from functools import reduce +from typing import TYPE_CHECKING, TypeVar + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + +_K = TypeVar("_K") +_V = TypeVar("_V") +_K2 = TypeVar("_K2") +_V2 = TypeVar("_V2") +_T = TypeVar("_T") __all__ = [ @@ -22,8 +34,8 @@ ] -def _get_factory(f, kwargs): - factory = kwargs.pop("factory", dict) +def _get_factory(f: Callable[..., object], kwargs: dict[str, object]) -> type: + factory: type = kwargs.pop("factory", dict) # type: ignore[assignment] if kwargs: raise TypeError( f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'" @@ -31,7 +43,7 @@ def _get_factory(f, kwargs): return factory -def merge(*dicts, **kwargs): +def merge(*dicts: Mapping[object, object], **kwargs: object) -> object: """Merge a collection of dictionaries >>> merge({1: "one"}, {2: "two"}) @@ -55,7 +67,9 @@ def merge(*dicts, **kwargs): return rv -def merge_with(func, *dicts, **kwargs): +def merge_with( + func: Callable[..., object], *dicts: Mapping[object, object], **kwargs: object +) -> object: """Merge dictionaries and apply function to combined values A key may occur in more than one dict, and all values mapped from the key @@ -84,7 +98,9 @@ def merge_with(func, *dicts, **kwargs): return valmap(func, result, factory) -def valmap(func, d, factory=dict): +def valmap( + func: Callable[[_V], _V2], d: Mapping[_K, _V], factory: type = dict +) -> dict[_K, _V2]: """Apply function to values of dictionary >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} @@ -100,7 +116,9 @@ def valmap(func, d, factory=dict): return rv -def keymap(func, d, factory=dict): +def keymap( + func: Callable[[_K], _K2], d: Mapping[_K, _V], factory: type = dict +) -> dict[_K2, _V]: """Apply function to keys of dictionary >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} @@ -116,7 +134,9 @@ def keymap(func, d, factory=dict): return rv -def itemmap(func, d, factory=dict): +def itemmap( + func: Callable[[tuple[_K, _V]], object], d: Mapping[_K, _V], factory: type = dict +) -> dict[object, object]: """Apply function to items of dictionary >>> accountids = {"Alice": 10, "Bob": 20} @@ -132,7 +152,9 @@ def itemmap(func, d, factory=dict): return rv -def valfilter(predicate, d, factory=dict): +def valfilter( + predicate: Callable[[_V], bool], d: Mapping[_K, _V], factory: type = dict +) -> dict[_K, _V]: """Filter items in dictionary by value >>> iseven = lambda x: x % 2 == 0 @@ -152,7 +174,9 @@ def valfilter(predicate, d, factory=dict): return rv -def keyfilter(predicate, d, factory=dict): +def keyfilter( + predicate: Callable[[_K], bool], d: Mapping[_K, _V], factory: type = dict +) -> dict[_K, _V]: """Filter items in dictionary by key >>> iseven = lambda x: x % 2 == 0 @@ -172,7 +196,9 @@ def keyfilter(predicate, d, factory=dict): return rv -def itemfilter(predicate, d, factory=dict): +def itemfilter( + predicate: Callable[[tuple[_K, _V]], bool], d: Mapping[_K, _V], factory: type = dict +) -> dict[_K, _V]: """Filter items in dictionary by item >>> def isvalid(item): @@ -196,7 +222,9 @@ def itemfilter(predicate, d, factory=dict): return rv -def assoc(d, key, value, factory=dict): +def assoc( + d: Mapping[_K, _V], key: object, value: object, factory: type = dict +) -> dict[_K, _V]: """Return a new dict with new key value pair New dict has d[key] set to value. Does not modify the initial dictionary. @@ -212,7 +240,7 @@ def assoc(d, key, value, factory=dict): return d2 -def dissoc(d, *keys, **kwargs): +def dissoc(d: Mapping[object, object], *keys: object, **kwargs: object) -> object: """Return a new dict with the given key(s) removed. New dict has d[key] deleted for each supplied key. @@ -241,7 +269,12 @@ def dissoc(d, *keys, **kwargs): return d2 -def assoc_in(d, keys, value, factory=dict): +def assoc_in( + d: Mapping[object, object], + keys: Iterable[object], + value: object, + factory: type = dict, +) -> object: """Return a new dict with new, potentially nested, key value pair >>> purchase = { @@ -257,7 +290,13 @@ def assoc_in(d, keys, value, factory=dict): return update_in(d, keys, lambda x: value, value, factory) -def update_in(d, keys, func, default=None, factory=dict): +def update_in( + d: Mapping[object, object], + keys: Iterable[object], + func: Callable[..., object], + default: object = None, + factory: type = dict, +) -> object: """Update value in a (potentially) nested dictionary inputs: @@ -300,7 +339,7 @@ def update_in(d, keys, func, default=None, factory=dict): for key in ks: if k in d: - d = d[k] + d = d[k] # pyrefly: ignore[bad-assignment] dtemp = factory() dtemp.update(d) else: @@ -316,7 +355,12 @@ def update_in(d, keys, func, default=None, factory=dict): return rv -def get_in(keys, coll, default=None, no_default=False): +def get_in( + keys: Iterable[object], + coll: object, + default: object = None, + no_default: bool = False, +) -> object: """Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless @@ -349,14 +393,18 @@ def get_in(keys, coll, default=None, no_default=False): operator.getitem """ try: - return reduce(operator.getitem, keys, coll) + return reduce( + operator.getitem, + keys, # pyrefly: ignore[bad-argument-type] + coll, # pyrefly: ignore[bad-argument-type] + ) except (KeyError, IndexError, TypeError): if no_default: raise return default -def getter(index): +def getter(index: object) -> Callable[..., object]: if isinstance(index, list): if len(index) == 1: index = index[0] @@ -369,7 +417,7 @@ def getter(index): return operator.itemgetter(index) -def groupby(key, seq): +def groupby(key: object, seq: Iterable[object]) -> dict[object, list[object]]: """Group a collection by a key function >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"] @@ -404,13 +452,13 @@ def groupby(key, seq): d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated] for item in seq: d[key(item)](item) - rv = {} + rv: dict[object, list[object]] = {} for k, v in d.items(): - rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined] + rv[k] = v.__self__ # type: ignore[attr-defined] return rv -def first(seq): +def first(seq: Iterable[_T]) -> _T: """The first element in a sequence >>> first("ABC") diff --git a/torch/fx/experimental/unification/utils.py b/torch/fx/experimental/unification/utils.py index 8f38ab7f23d06..9e414c38a40af 100644 --- a/torch/fx/experimental/unification/utils.py +++ b/torch/fx/experimental/unification/utils.py @@ -1,8 +1,20 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar + + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + from .variable import Var + +_T = TypeVar("_T") + + __all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"] -def hashable(x): +def hashable(x: object) -> bool: try: hash(x) return True @@ -10,7 +22,7 @@ def hashable(x): return False -def transitive_get(key, d): +def transitive_get(key: object, d: dict[Var, object]) -> object: """Transitive dict.get >>> d = {1: 2, 2: 3, 3: 4} >>> d.get(1) @@ -23,7 +35,10 @@ def transitive_get(key, d): return key -def raises(err, lamda): # codespell:ignore lamda +def raises( + err: type[BaseException], + lamda: Callable[[], object], # codespell:ignore lamda +) -> bool: try: lamda() # codespell:ignore lamda return False @@ -33,7 +48,7 @@ def raises(err, lamda): # codespell:ignore lamda # Taken from theano/theano/gof/sched.py # Avoids licensing issues because this was written by Matthew Rocklin -def _toposort(edges): +def _toposort(edges: dict[_T, Iterable[_T]]) -> list[_T]: """Topological sort algorithm by Kahn [1] - O(nodes + vertices) inputs: edges - a dict of the form {a: {b, c}} where b and c depend on a @@ -66,7 +81,7 @@ def _toposort(edges): return L -def reverse_dict(d): +def reverse_dict(d: dict[_T, Iterable[_T]]) -> dict[_T, tuple[_T, ...]]: """Reverses direction of dependence dict. >>> d = {"a": (1, 2), "b": (2, 3), "c": ()} @@ -84,10 +99,10 @@ def reverse_dict(d): for val in d[key]: # pyrefly: ignore [unsupported-operation] result[val] = result.get(val, ()) + (key,) - return result + return result # pyrefly: ignore[bad-return] -def xfail(func): +def xfail(func: Callable[[], object]) -> None: try: func() raise Exception("XFailed test passed") # pragma:nocover # noqa: TRY002 @@ -95,7 +110,7 @@ def xfail(func): pass -def freeze(d): +def freeze(d: object) -> object: """Freeze container to hashable form >>> freeze(1) 1 diff --git a/torch/fx/experimental/unification/variable.py b/torch/fx/experimental/unification/variable.py index 1b5b51aaf99a5..9082fc7e2a294 100644 --- a/torch/fx/experimental/unification/variable.py +++ b/torch/fx/experimental/unification/variable.py @@ -1,11 +1,18 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + from contextlib import contextmanager +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Callable, Generator, Hashable + from typing import Literal from .dispatch import dispatch from .utils import hashable -_global_logic_variables = set() # type: ignore[var-annotated] +_global_logic_variables: set[Hashable] = set() _glv = _global_logic_variables @@ -14,39 +21,39 @@ class Var: _id = 1 - def __new__(cls, *token): + def __new__(cls, *token: Hashable) -> Var: # noqa: PYI034 if len(token) == 0: token = f"_{Var._id}" # type: ignore[assignment] Var._id += 1 elif len(token) == 1: - token = token[0] + token = token[0] # pyrefly: ignore[bad-assignment] obj = object.__new__(cls) obj.token = token # type: ignore[attr-defined] return obj - def __str__(self): + def __str__(self) -> str: return "~" + str(self.token) # type: ignore[attr-defined] __repr__ = __str__ - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return type(self) is type(other) and self.token == other.token # type: ignore[attr-defined] - def __hash__(self): + def __hash__(self) -> int: return hash((type(self), self.token)) # type: ignore[attr-defined] -def var(): +def var() -> Callable[..., Var]: return lambda *args: Var(*args) -def vars(): +def vars() -> Callable[[int], list[Callable[..., Var]]]: return lambda n: [var() for i in range(n)] @dispatch(Var) -def isvar(v): +def isvar(v: Var) -> Literal[True]: return True @@ -54,12 +61,12 @@ def isvar(v): @dispatch(object) # type: ignore[no-redef] -def isvar(o): - return _glv and hashable(o) and o in _glv +def isvar(o: object) -> bool: + return bool(_glv and hashable(o) and o in _glv) @contextmanager -def variables(*variables): +def variables(*variables: Hashable) -> Generator[None, None, None]: """ Context manager for logic variables diff --git a/torch/fx/experimental/unify_refinements.py b/torch/fx/experimental/unify_refinements.py index fa7208f7545cb..b240a64d9729f 100644 --- a/torch/fx/experimental/unify_refinements.py +++ b/torch/fx/experimental/unify_refinements.py @@ -1,10 +1,25 @@ -# mypy: allow-untyped-defs +from typing import Any + +import torch +import torch.fx from torch.fx.experimental.graph_gradual_typechecker import Refine +from torch.fx.experimental.refinement_types import Equality from torch.fx.experimental.unification import unify, Var # type: ignore[attr-defined] from torch.fx.tensor_type import TensorType -def infer_symbolic_types_single_pass(traced): +__all__ = [ + "check_for_type_equality", + "convert_eq", + "infer_symbolic_types", + "infer_symbolic_types_single_pass", + "substitute_all_types", + "substitute_solution_one_type", + "unify_eq", +] + + +def infer_symbolic_types_single_pass(traced: torch.fx.GraphModule) -> None: """ Calls our symbolic inferencer once. """ @@ -14,7 +29,7 @@ def infer_symbolic_types_single_pass(traced): substitute_all_types(traced.graph, mgu) -def infer_symbolic_types(traced): +def infer_symbolic_types(traced: torch.fx.GraphModule) -> None: """ Calls our symbolic inferencer twice. This is useful when one pass is not enough @@ -34,7 +49,7 @@ def infer_symbolic_types(traced): r.symbolic_relations() -def convert_eq(list_of_eq): +def convert_eq(list_of_eq: list[Equality]) -> tuple[tuple[Any, ...], tuple[Any, ...]]: """ Convert equality constraints in the right format to be used by unification library. @@ -47,7 +62,7 @@ def convert_eq(list_of_eq): return tuple(lhs), tuple(rhs) -def unify_eq(list_of_eq): +def unify_eq(list_of_eq: list[Equality]) -> Any: """ Apply unification to a set of equality constraints @@ -56,7 +71,7 @@ def unify_eq(list_of_eq): return unify(lhs, rhs) -def substitute_solution_one_type(mapping, t): +def substitute_solution_one_type(mapping: dict[object, object], t: object) -> Any: """ Apply the most general unifier to a type """ @@ -91,7 +106,7 @@ def substitute_solution_one_type(mapping, t): return t -def substitute_all_types(graph, mapping): +def substitute_all_types(graph: torch.fx.Graph, mapping: dict[object, object]) -> None: """ Apply the most general unifier to all types in a graph till reaching a fixed point. If the input and output graph @@ -112,7 +127,7 @@ def substitute_all_types(graph, mapping): n.type = substitute_solution_one_type(mapping, n.type) -def check_for_type_equality(g1, g2): +def check_for_type_equality(g1: torch.fx.Graph, g2: torch.fx.Graph) -> bool: """ A check equality to be used in fixed points. We do not use graph equality but instead type diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 643e2d2537ef8..b500d7a8f2fd2 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import builtins import functools import logging @@ -6,7 +5,7 @@ import operator from collections.abc import Callable from dataclasses import dataclass -from typing import Any +from typing import Any, overload, TYPE_CHECKING, TypeVar import sympy @@ -19,8 +18,14 @@ from torch.utils._sympy.interp import sympy_interp +if TYPE_CHECKING: + from torch.fx.experimental.symbolic_shapes import ShapeEnv, TrackedFake + + log = logging.getLogger(__name__) +_R = TypeVar("_R") + try: import z3 # type: ignore[import] @@ -89,7 +94,7 @@ def get_args_str(e: z3.ExprRef) -> list[str]: # Collect the arguments of chains of ADD and MUL. # This is safe, since they are associative. - def collect_str_args(e): + def collect_str_args(e: z3.ExprRef) -> list[str]: if not (z3.is_app(e) and e.decl().kind() == kind): return [z3str(e)] else: @@ -148,9 +153,11 @@ def collect_str_args(e): # We need to convert to/from BitVec in order to use z3 bitwise ops. # We assume that integers are 64 bit. # If all args are boolean, then use the boolean bitwise op implementation instead, if provided. - def _bitwise_op(bitwise_func, bool_func): + def _bitwise_op( + bitwise_func: Callable[..., Any], bool_func: Callable[..., Any] | None + ) -> Callable[..., Any]: @functools.wraps(bitwise_func) - def wrapper(self, *args): + def wrapper(self: "_Z3Ops", *args: z3.ExprRef) -> Any: if bool_func is not None and all( isinstance(arg, z3.BoolRef) for arg in args ): @@ -270,7 +277,9 @@ def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: # # 2. Calls an operation that corresponds to 'op', but works with Z3 # inhabitants (left as is if it works as is) - def z3op(op: Callable, validator: "TranslationValidator") -> Callable: + def z3op( + op: Callable[..., Any], validator: "TranslationValidator" + ) -> Callable[..., Any]: # Operations that have booleans as their argument. # This is needed because the argument of some FX nodes were # literal integers, instead of booleans. So, whenever this flag @@ -279,8 +288,8 @@ def z3op(op: Callable, validator: "TranslationValidator") -> Callable: as_bool = op in boolean_ops # Lifts the function into 'z3.ExprRef' domain. - def lift(func): - def wrap(a) -> z3.ExprRef: + def lift(func: Callable[..., _R]) -> Callable[..., _R]: + def wrap(a: object) -> z3.ExprRef: if isinstance(a, (z3.ArithRef, z3.BoolRef)): return a # Convert it into a Z3 value, if it is some of the supported @@ -294,7 +303,7 @@ def wrap(a) -> z3.ExprRef: raise ValueError(f"can't lift type: {type(a)}") @functools.wraps(func) - def wrapper(*args): + def wrapper(*args: object) -> Any: # Lifts the arguments into a list of Z3 inhabitants. if len(args) == 1 and isinstance(args[0], (list, tuple)): wrapped_args = (tuple(wrap(a) for a in args[0]),) @@ -327,7 +336,7 @@ def wrapper(*args): torch.sym_max: lift(ops.max), torch.sym_min: lift(ops.min), torch.sym_sum: lift(ops.sym_sum), - torch.sym_ite: lift(lambda b, t, f: t if b else f), + torch.sym_ite: lift(lambda b, t, f: z3.If(b, t, f)), torch._sym_sqrt: lift(ops.sqrt), # type: ignore[attr-defined] # Not lifted because we only use this function as a # marker for adding the expression as validator input. @@ -345,7 +354,9 @@ def wrapper(*args): # it adds the Z3 expression corresponding to the argument as validator # input. class PopulateValidator(torch.fx.Interpreter): - def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"): + def __init__( + self, graph: torch.fx.Graph, validator: "TranslationValidator" + ) -> None: # Reference to the translation validator. self.validator = validator @@ -389,7 +400,7 @@ def __init__( self._validator = validator self._ops = _Z3Ops(self._validator) - def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: + def constant(self, value: int | float | bool, dtype: torch.dtype) -> z3.ExprRef: # TODO: Probably OK to relax this and allow lower precision if dtype is torch.int64: return z3.IntVal(int(value)) @@ -674,7 +685,7 @@ def translation_validation_timeout() -> int: return config.translation_validation_timeout -def _assert_z3_installed_if_tv_set(): +def _assert_z3_installed_if_tv_set() -> None: if not (_HAS_Z3 or not config.translation_validation): raise AssertionError( "translation validation requires Z3 package. Please, either install " @@ -683,14 +694,16 @@ def _assert_z3_installed_if_tv_set(): class ValidationException(TorchDynamoException): - def __init__(self, model, assertions, target_exprs, failed_source_exprs): + def __init__( + self, model: Any, assertions: Any, target_exprs: Any, failed_source_exprs: Any + ) -> None: if not _HAS_Z3: raise AssertionError("Z3 is required") - def symbolstr(sym) -> str: + def symbolstr(sym: Any) -> str: return f"{sym}: {model[sym]}" - def joinlines(xs) -> str: + def joinlines(xs: Any) -> str: return "\n".join(f" ==> {x}" for x in xs) model_str = joinlines(sorted(map(symbolstr, model))) @@ -712,12 +725,18 @@ def joinlines(xs) -> str: Failed Source Expressions: {failed_source_exprs_str}""" - def __str__(self): + def __str__(self) -> str: return f"{self.msg}\n\n{self.details}" class BisectValidationException(TorchDynamoException): - def __init__(self, validation_exc, expr, failed_action, traced_node): + def __init__( + self, + validation_exc: ValidationException, + expr: sympy.Basic, + failed_action: str, + traced_node: torch.fx.Node, + ) -> None: self.msg = f"translation validation failed when {failed_action}: {expr}" self.details = f"""\ Failure occurred while running node: @@ -725,7 +744,7 @@ def __init__(self, validation_exc, expr, failed_action, traced_node): {validation_exc.details}""" - def __str__(self): + def __str__(self) -> str: return f"{self.msg}\n\n{self.details}" @@ -741,7 +760,7 @@ def __str__(self): # As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors # might be silently happening. This function tries to nail down exactly at which # point things went wrong from a validation perspective. -def bisect(shape_env): +def bisect(shape_env: "ShapeEnv") -> None: from torch.fx.experimental.recording import ( FakeTensorMeta, replay_shape_env_events, @@ -766,7 +785,25 @@ def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent: # # This is needed so as not to simplify a symbolic expression using a ShapeEnv # "from the future", where it may have a different set of replacements. - def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any: + @overload + def new_with_shape_env(shape_env: ShapeEnv, fake: int) -> int: ... + + @overload + def new_with_shape_env(shape_env: ShapeEnv, fake: torch.SymInt) -> torch.SymInt: ... + + @overload + def new_with_shape_env( + shape_env: ShapeEnv, fake: torch.SymFloat + ) -> torch.SymFloat: ... + + @overload + def new_with_shape_env( + shape_env: ShapeEnv, fake: FakeTensorMeta + ) -> FakeTensorMeta: ... + + def new_with_shape_env( + shape_env: ShapeEnv, fake: int | torch.SymInt | torch.SymFloat | FakeTensorMeta + ) -> int | torch.SymInt | torch.SymFloat | FakeTensorMeta: if isinstance(fake, int): return fake if isinstance(fake, torch.SymInt): @@ -784,7 +821,7 @@ def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any: # Checks whether the given shape_env fails when produce_guards is called. def check_shapeenv_fails( - shape_env: ShapeEnv, tracked_fakes: list[Any] | None + shape_env: ShapeEnv, tracked_fakes: list["TrackedFake"] | None ) -> ValidationException | None: if tracked_fakes is None: raise AssertionError("tracked_fakes is None") @@ -793,6 +830,7 @@ def check_shapeenv_fails( # don't populate EqualityConstraint list. Reason: we would also have # to save OutputGraph.tracked_fakes_id_to_source. shape_env.produce_guards( + # pyrefly: ignore [no-matching-overload] # TrackedFake.fake includes FakeTensor [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes], [a.source for a in tracked_fakes], input_contexts=[a.symbolic_context for a in tracked_fakes], @@ -859,6 +897,7 @@ def check_node_fails(node: torch.fx.Node) -> ValidationException | None: if not (left in exception and isinstance(exception[left], ValidationException)): raise AssertionError("Expected ValidationException at bisect result") + left_exception: ValidationException = exception[left] # type: ignore[assignment] node = assert_nodes[left] event = get_node_event(node) @@ -885,7 +924,7 @@ def check_node_fails(node: torch.fx.Node) -> ValidationException | None: ) raise BisectValidationException( - exception[left], + left_exception, expr=args[1], failed_action=failed_action, traced_node=node.meta[CURRENT_NODE_KEY], diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 38afbb0b4692c..523553acd051c 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1,4 +1,5 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + import builtins import contextlib import copy @@ -15,10 +16,10 @@ import typing import warnings from collections import defaultdict -from collections.abc import Callable, Iterable, Iterator +from collections.abc import Callable, Generator, Iterable, Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Literal, NamedTuple, Optional, TYPE_CHECKING +from typing import Any, Literal, NamedTuple, TYPE_CHECKING import torch import torch.utils._pytree as pytree @@ -37,8 +38,8 @@ __all__ = ["PythonCode", "CodeGen", "Graph"] if TYPE_CHECKING: - from ._symbolic_trace import Tracer # noqa: F401 - from .graph_module import GraphModule # noqa: F401 + from ._symbolic_trace import Tracer + from .graph_module import GraphModule # Mapping of builtins to their `typing` equivalent. @@ -72,7 +73,7 @@ class _CustomBuiltin(NamedTuple): # How to import this object from the standard library. import_str: str # The actual object, produced from that import string. - obj: Any + obj: object # Combined dict of disallowed variable names so we can check with one lookup @@ -82,7 +83,7 @@ class _CustomBuiltin(NamedTuple): _custom_builtins: dict[str, _CustomBuiltin] = {} -def _register_custom_builtin(name: str, import_str: str, obj: Any): +def _register_custom_builtin(name: str, import_str: str, obj: object) -> None: _custom_builtins[name] = _CustomBuiltin(import_str, obj) _illegal_names[name] = obj @@ -131,7 +132,7 @@ def _snake_case(s: str) -> str: ).fullmatch -def _is_from_torch(obj: Any) -> bool: +def _is_from_torch(obj: object) -> bool: module_name = getattr(obj, "__module__", None) if module_name is not None: return _torch_but_not_dynamo(module_name) is not None @@ -155,12 +156,12 @@ class _Namespace: - Names generated do not shadow builtins, unless the object is indeed that builtin. """ - def __init__(self): - self._obj_to_name: dict[Any, str] = {} + def __init__(self) -> None: + self._obj_to_name: dict[object, str] = {} self._used_names: set[str] = set() self._base_count: dict[str, int] = {} - def create_name(self, candidate: str, obj: Any | None) -> str: + def create_name(self, candidate: str, obj: object | None) -> str: """Create a unique name. Arguments: @@ -211,7 +212,7 @@ def create_name(self, candidate: str, obj: Any | None) -> str: self._obj_to_name[obj] = candidate return candidate - def associate_name_with_obj(self, name: str, obj: Any): + def associate_name_with_obj(self, name: str, obj: object) -> None: """Associate a unique name with an object. Neither `name` nor `obj` should be associated already. @@ -220,7 +221,7 @@ def associate_name_with_obj(self, name: str, obj: Any): if maybe_existing is not name: raise AssertionError("obj is already associated") - def _rename_object(self, obj: Any, name: str): + def _rename_object(self, obj: object, name: str) -> None: if obj not in self._obj_to_name: raise AssertionError(f"Object {obj} is not in _obj_to_name") self._obj_to_name[obj] = name @@ -257,19 +258,26 @@ def _format_target(base: str, target: str) -> str: class _InsertPoint: - def __init__(self, graph, new_insert): + def __init__(self, graph: Graph, new_insert: Callable[..., None]) -> None: self.graph = graph self.orig_insert, graph._insert = graph._insert, new_insert - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, type, value, tb): + def __exit__( + self, + type: type[BaseException] | None, + value: BaseException | None, + tb: types.TracebackType | None, + ) -> None: self.graph._insert = self.orig_insert class _node_list: - def __init__(self, graph: "Graph", direction: Literal["_prev", "_next"] = "_next"): + def __init__( + self, graph: Graph, direction: Literal["_prev", "_next"] = "_next" + ) -> None: if direction not in ("_next", "_prev"): raise AssertionError( f"direction must be '_next' or '_prev', got {direction}" @@ -277,14 +285,19 @@ def __init__(self, graph: "Graph", direction: Literal["_prev", "_next"] = "_next self.graph = graph self.direction = direction - def __len__(self): + def __len__(self) -> int: return self.graph._len - def __iter__(self): + # TODO: These should return Iterator[Node], but doing so causes ~350 + # downstream pyrefly errors because Node.target is typed as + # Callable[..., Any] | str and pyrefly can't narrow it based on + # node.op checks (e.g. `if node.op == "call_module": node.target` + # should be str but pyrefly doesn't support that narrowing). + def __iter__(self) -> Iterator[Any]: return _NodeIter(self.graph._root, self.direction == "_prev") - def __reversed__(self): - return _node_list(self.graph, "_next" if self.direction == "_prev" else "_prev") + def __reversed__(self) -> Iterator[Any]: + return _NodeIter(self.graph._root, self.direction == "_next") class _PyTreeInfo(NamedTuple): @@ -308,14 +321,14 @@ class _ParsedStackTrace: name: str code: str - def get_summary_str(self): + def get_summary_str(self) -> str: return f"File: {self.file}:{self.lineno} in {self.name}, code: {self.code}" # get File:lineno code from stack_trace def _parse_stack_trace( stack_trace: str, filter_fn: Callable[[str, str, str], bool] | None = None -): +) -> _ParsedStackTrace | None: if stack_trace is None: return None pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") @@ -341,9 +354,9 @@ def _parse_stack_trace( @compatibility(is_backward_compatible=False) class CodeGen: # This is an override hook so we can customize the SymNode printer. - _sym_repr: Callable[["torch.types.PySymType"], str] = lambda x: repr(x) + _sym_repr: Callable[[torch.types.PySymType], str] = lambda x: repr(x) - def __init__(self): + def __init__(self) -> None: self._body_transformer: TransformCodeFunc | None = None self._func_name: str = "forward" @@ -359,30 +372,43 @@ def _format_single_arg(self, arg: str) -> str: else: return f" {arg},\n" - def _get_delimiters(self, container) -> tuple[str, str]: + def _get_delimiters(self, container: Sequence[object]) -> tuple[str, str]: """Helper to get opening and closing delimiters for containers.""" return ("(", ")") if isinstance(container, tuple) else ("[", "]") - def _format_multiline_container(self, items, descs=None, prefix="") -> str: + def _format_multiline_container( + self, + items: Sequence[object], + descs: Sequence[str] | None = None, + prefix: str = "", + repr_fn: Callable[[object], str] | None = None, + ) -> str: """Helper to format containers (lists/tuples) in multiline format.""" ldelim, rdelim = self._get_delimiters(items) desc_trailers = self._get_desc_trailers(items, descs) + if repr_fn is None: + repr_fn = repr return ( f"{prefix}{ldelim}\n" + "".join( - f" {item},{trailer}\n" for item, trailer in zip(items, desc_trailers) + f" {repr_fn(item)},{trailer}\n" + for item, trailer in zip(items, desc_trailers) ) + f"{rdelim}" ) - def _get_desc_trailers(self, items, descs): + def _get_desc_trailers( + self, items: Sequence[object], descs: Sequence[str] | None + ) -> list[str]: """Helper to generate description trailers for items.""" if descs is None: return [""] * len(items) return [f" # {desc}" for desc in descs] - def _call_method_with_signature_check(self, method, *args, **kwargs): + def _call_method_with_signature_check( + self, method: Callable[..., Any], *args: Any, **kwargs: Any + ) -> Any: """Helper to call a method with optional parameters based on signature.""" sig = inspect.signature(method) # Filter kwargs to only include parameters that exist in the method signature @@ -414,16 +440,24 @@ def gen_fn_def( return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" def generate_output( - self, output_args: Argument, *, descs: Any | None = None + self, + output_args: Argument, + *, + descs: Sequence[str] | None = None, + repr_fn: Callable[[object], str] | None = None, ) -> str: """ Given the output arguments, generates the return statement of the FX function. Note: The returned statement should not be indented. """ + if repr_fn is None: + repr_fn = repr if descs is not None and isinstance(output_args, (list, tuple)): - return self._format_multiline_container(output_args, descs, "return ") + return self._format_multiline_container( + output_args, descs, "return ", repr_fn=repr_fn + ) else: - return f"return {repr(output_args)}" + return f"return {repr_fn(output_args)}" def process_inputs(self, *args: Any) -> Any: """ @@ -453,7 +487,7 @@ def additional_globals(self) -> list[tuple[str, Any]]: def _gen_python_code( self, - nodes, + nodes: _node_list, root_module: str, namespace: _Namespace, *, @@ -481,7 +515,7 @@ def _gen_python_code( ) include_meta = os.environ.get("FX_GRAPH_SHOW_META", "0") == "1" - def add_global(name_hint: str, obj: Any): + def add_global(name_hint: str, obj: Any) -> str: """Add an obj to be tracked as a global. We call this for names that reference objects external to the @@ -513,7 +547,7 @@ def add_global(name_hint: str, obj: Any): for name, (_, obj) in _custom_builtins.items(): add_global(name, obj) - def type_repr(o: Any): + def type_repr(o: object) -> str: if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] return "()" @@ -555,7 +589,7 @@ def type_repr(o: Any): dim_blue = _identity blue = _identity - def _get_repr(arg: Any) -> str: + def _get_repr(arg: object) -> str: if isinstance(arg, Node): # first because common return repr(arg) elif isinstance(arg, tuple) and hasattr(arg, "_fields"): @@ -608,7 +642,7 @@ def _format_args( node_to_last_use: dict[Node, Node] = {} user_to_last_uses: dict[Node, list[Node]] = {} - def register_last_uses(n: Node, user: Node): + def register_last_uses(n: Node, user: Node) -> None: if n not in node_to_last_use: node_to_last_use[n] = user user_to_last_uses.setdefault(user, []).append(n) @@ -617,7 +651,7 @@ def register_last_uses(n: Node, user: Node): for input_node in node._input_nodes: register_last_uses(input_node, node) - def delete_unused_values(user: Node): + def delete_unused_values(user: Node) -> None: """ Delete values after their last use. This ensures that values that are not used in the remainder of the code are freed and the memory usage @@ -646,7 +680,7 @@ def delete_unused_values(user: Node): prev_summary_str = None - def append_stacktrace_summary(node: Node): + def append_stacktrace_summary(node: Node) -> None: """ Append a summary of the stacktrace to the generated code. This is useful for debugging. @@ -656,7 +690,7 @@ def append_stacktrace_summary(node: Node): if node.op not in {"placeholder", "output"}: additional_meta_str = "" if additional_meta: - parts = [] + parts: list[str] = [] for key in additional_meta: if key in node.meta: parts.append(f"{key}: {node.meta[key]}") @@ -703,10 +737,10 @@ def append_stacktrace_summary(node: Node): prev_summary_str = summary_str body.append(summary_str) - def stringify_shape(shape: Iterable) -> str: + def stringify_shape(shape: Iterable[object]) -> str: return f"[{', '.join([str(x) for x in shape])}]" - def emit_node(node: Node): + def emit_node(node: Node) -> None: maybe_type_annotation = ( "" if node.type is None else f" : {type_repr(node.type)}" ) @@ -900,6 +934,7 @@ def _tensor_annotation(t: torch.Tensor) -> str: self.generate_output, node.args[0], descs=desc if expanded_def else None, + repr_fn=_get_repr, ) ) return @@ -1007,8 +1042,12 @@ class _BoxedCodeGen(CodeGen): """ def gen_fn_def( - self, free_vars, maybe_return_annotation, *, expanded_def: bool = False - ): + self, + free_vars: list[str], + maybe_return_annotation: str, + *, + expanded_def: bool = False, + ) -> str: """ Generate function definition for boxed calling convention. @@ -1036,7 +1075,7 @@ def gen_fn_def( class _PyTreeCodeGen(CodeGen): - def __init__(self, pytree_info: _PyTreeInfo): + def __init__(self, pytree_info: _PyTreeInfo) -> None: super().__init__() self.pytree_info: _PyTreeInfo = pytree_info @@ -1067,7 +1106,9 @@ def _format_annotations(self, free_vars: list[str], expanded_def: bool) -> str: else: return "\n " + "".join(x + "; " for x in has_annotation) + "\n" - def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str: + def gen_var_bindings( + self, fn_args: list[str], free_vars: list[str], expanded_def: bool + ) -> str: in_spec = self.pytree_info.in_spec # when kwargs is present, in_spec is tuple(args, kwargs) has_args_kwargs_tuple = ( @@ -1105,8 +1146,12 @@ def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str: return bindings def gen_fn_def( - self, free_vars, maybe_return_annotation, *, expanded_def: bool = False - ): + self, + free_vars: list[str], + maybe_return_annotation: str, + *, + expanded_def: bool = False, + ) -> str: # Given a user function/model: # myargs = [myargs0, myargs1] # mykwargs = {'mykwargs0': ..., 'mykwargs1': ...} @@ -1139,32 +1184,41 @@ def gen_fn_def( fn_definition += self.gen_var_bindings(fn_args, free_vars, expanded_def) return fn_definition - def generate_output(self, output_args, *, descs: Any | None = None): + def generate_output( + self, + output_args: Argument, + *, + descs: Sequence[str] | None = None, + repr_fn: Callable[[object], str] | None = None, + ) -> str: + if repr_fn is None: + repr_fn = repr if self.pytree_info and self.pytree_info.out_spec: if descs is not None and isinstance(output_args, (list, tuple)): return ( self._format_multiline_container( - output_args, descs, "return pytree.tree_unflatten(" + output_args, + descs, + "return pytree.tree_unflatten(", + repr_fn=repr_fn, ) + ", self._out_spec)" ) else: - return ( - f"return pytree.tree_unflatten({repr(output_args)}, self._out_spec)" - ) + return f"return pytree.tree_unflatten({repr_fn(output_args)}, self._out_spec)" else: - return super().generate_output(output_args, descs=descs) + return super().generate_output(output_args, descs=descs, repr_fn=repr_fn) class _ExportCodeGen(_PyTreeCodeGen): def __init__( self, pytree_info: _PyTreeInfo, - in_shuffle_graph: "GraphModule", - out_shuffle_graph: "GraphModule", + in_shuffle_graph: GraphModule, + out_shuffle_graph: GraphModule, tree_leaf_names: list[str], root: torch.nn.Module | None, - ): + ) -> None: super().__init__(pytree_info) self.in_shuffle_graph = in_shuffle_graph self.out_shuffle_graph = out_shuffle_graph @@ -1184,11 +1238,13 @@ def process_outputs(self, out: Any) -> Any: ret = super().process_outputs(flat_outs) return ret - def gen_fn_def(self, *args, **kwargs) -> str: + def gen_fn_def(self, *args: Any, **kwargs: Any) -> str: fn_def = super().gen_fn_def(*args, **kwargs) return fn_def - def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str: + def gen_var_bindings( + self, fn_args: list[str], free_vars: list[str], expanded_def: bool + ) -> str: without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars] fn_signature: str = f"{', '.join(fn_args)}" if self.root is not None: @@ -1197,7 +1253,11 @@ def gen_var_bindings(self, fn_args, free_vars, expanded_def) -> str: {", ".join(self.tree_leaf_names)}, = pytree.tree_leaves(({fn_signature},)) {", ".join(without_annotation)}, = self._in_shuffle_graph({", ".join(self.tree_leaf_names)})""" - def generate_output(self, output_args, *args, **kwargs) -> str: + def generate_output(self, output_args: Argument, *args: Any, **kwargs: Any) -> str: + if not isinstance(output_args, (list, tuple)): + raise TypeError( + f"Expected list or tuple for output_args, got {type(output_args)}" + ) output = f"self._out_shuffle_graph({', '.join(self.tree_leaf_names)}, {', '.join([str(a) for a in output_args])})" return f"return pytree.tree_unflatten({output}, self._out_spec)" @@ -1207,15 +1267,15 @@ class _FindNodesLookupTable: Side table for the graph for the purpose of doing fast queries """ - def __init__(self): + def __init__(self) -> None: self.table: dict[tuple[str, Target | None], dict[Node, None]] = defaultdict( dict ) - def _key(self, node) -> tuple[str, Target | None]: + def _key(self, node: Node) -> tuple[str, Target | None]: return (node.op, node.target if node.op == "call_function" else None) - def __contains__(self, node) -> bool: + def __contains__(self, node: Node) -> bool: return node in self.table[self._key(node)] def insert(self, node: Node) -> None: @@ -1224,7 +1284,8 @@ def insert(self, node: Node) -> None: def remove(self, node: Node) -> None: self.table[self._key(node)].pop(node) - def find_nodes(self, *, op: str, target: Optional["Target"] = None): + # TODO: should return list[Node], see _node_list.__iter__ comment + def find_nodes(self, *, op: str, target: Target | None = None) -> list[Any]: if op == "call_function": if target is None: raise AssertionError("target must not be None for call_function op") @@ -1289,10 +1350,10 @@ def forward(self, x): @compatibility(is_backward_compatible=True) def __init__( self, - owning_module: Optional["GraphModule"] = None, - tracer_cls: type["Tracer"] | None = None, + owning_module: GraphModule | None = None, + tracer_cls: type[Tracer] | None = None, tracer_extras: dict[str, Any] | None = None, - ): + ) -> None: """ Construct an empty Graph. """ @@ -1309,11 +1370,13 @@ def __init__( self._find_nodes_lookup_table = _FindNodesLookupTable() @property - def owning_module(self): + # TODO: should return GraphModule | None, but causes downstream errors + # where callers pass it to functions expecting non-optional GraphModule + def owning_module(self): # pyrefly: ignore[unannotated-return] return self._owning_module @owning_module.setter - def owning_module(self, mod: Optional["GraphModule"]): + def owning_module(self, mod: GraphModule | None) -> None: self._owning_module = mod @property @@ -1339,9 +1402,10 @@ def output_node(self) -> Node: return output_node @compatibility(is_backward_compatible=False) + # TODO: should return list[Node], see _node_list.__iter__ comment def find_nodes( - self, *, op: str, target: Optional["Target"] = None, sort: bool = True - ): + self, *, op: str, target: Target | None = None, sort: bool = True + ) -> list[Any]: """ Allows for fast query of nodes @@ -1366,8 +1430,8 @@ def find_nodes( @compatibility(is_backward_compatible=True) def graph_copy( - self, g: "Graph", val_map: dict[Node, Node], return_output_node=False - ) -> "Argument | None": + self, g: Graph, val_map: dict[Node, Node], return_output_node: bool = False + ) -> Argument | None: """ Copy all nodes from a given graph into ``self``. @@ -1393,7 +1457,7 @@ def graph_copy( val_map[node] = self.node_copy(node, lambda n: val_map[n]) return None - def __deepcopy__(self, memo=None) -> "Graph": + def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Graph: """ Explicitly implement __deepcopy__ to prevent excessive recursion depth from the default implementation. This uses graph_copy to copy the nodes @@ -1403,7 +1467,11 @@ def __deepcopy__(self, memo=None) -> "Graph": """ memo = memo if memo else {} g = Graph(tracer_cls=self._tracer_cls) - output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) + output_vals = g.graph_copy( + self, + val_map=memo, # pyrefly: ignore[bad-argument-type] + return_output_node=True, + ) g._codegen = copy.deepcopy(self._codegen) if output_vals is not None: if not isinstance(output_vals, tuple): @@ -1424,9 +1492,9 @@ def __deepcopy__(self, memo=None) -> "Graph": def create_node( self, op: str, - target: "Target", - args: tuple["Argument", ...] | None = None, - kwargs: dict[str, "Argument"] | None = None, + target: Target, + args: tuple[Argument, ...] | None = None, + kwargs: dict[str, Argument] | None = None, name: str | None = None, type_expr: Any | None = None, ) -> Node: @@ -1486,14 +1554,14 @@ def create_node( return n @compatibility(is_backward_compatible=False) - def process_inputs(self, *args): + def process_inputs(self, *args: Any) -> Any: """ Processes args so that they can be passed to the FX graph. """ return self._codegen.process_inputs(*args) @compatibility(is_backward_compatible=False) - def process_outputs(self, out): + def process_outputs(self, out: Any) -> Any: return self._codegen.process_outputs(out) @compatibility(is_backward_compatible=True) @@ -1538,7 +1606,7 @@ def erase_node(self, to_erase: Node) -> None: ) @compatibility(is_backward_compatible=True) - def inserting_before(self, n: Node | None = None): + def inserting_before(self, n: Node | None = None) -> _InsertPoint: """Set the point at which create_node and companion methods will insert into the graph. When used within a 'with' statement, this will temporary set the insert point and then restore it when the with statement exits:: @@ -1563,7 +1631,7 @@ def inserting_before(self, n: Node | None = None): return _InsertPoint(self, n.prepend) @compatibility(is_backward_compatible=True) - def inserting_after(self, n: Node | None = None): + def inserting_after(self, n: Node | None = None) -> _InsertPoint: """Set the point at which create_node and companion methods will insert into the graph. When used within a 'with' statement, this will temporary set the insert point and then restore it when the with statement exits:: @@ -1671,7 +1739,7 @@ def _get_attr_reference_exists( return True - if self.owning_module and not _get_attr_reference_exists( + if self.owning_module is not None and not _get_attr_reference_exists( self.owning_module, qualified_name ): warnings.warn( @@ -1692,8 +1760,8 @@ def _get_attr_reference_exists( def call_module( self, module_name: str, - args: tuple["Argument", ...] | None = None, - kwargs: dict[str, "Argument"] | None = None, + args: tuple[Argument, ...] | None = None, + kwargs: dict[str, Argument] | None = None, type_expr: Any | None = None, ) -> Node: """ @@ -1726,7 +1794,10 @@ def call_module( The same insertion point and type expression rules apply for this method as :meth:`Graph.create_node`. """ - if self.owning_module and self.owning_module.get_submodule(module_name) is None: + if ( + self.owning_module is not None + and self.owning_module.get_submodule(module_name) is None + ): warnings.warn( "Attempted to insert a call_module Node with " "no underlying reference in the owning " @@ -1742,8 +1813,8 @@ def call_module( def call_method( self, method_name: str, - args: tuple["Argument", ...] | None = None, - kwargs: dict[str, "Argument"] | None = None, + args: tuple[Argument, ...] | None = None, + kwargs: dict[str, Argument] | None = None, type_expr: Any | None = None, ) -> Node: """ @@ -1781,8 +1852,8 @@ def call_method( def call_function( self, the_function: Callable[..., Any], - args: tuple["Argument", ...] | None = None, - kwargs: dict[str, "Argument"] | None = None, + args: tuple[Argument, ...] | None = None, + kwargs: dict[str, Argument] | None = None, type_expr: Any | None = None, name: str | None = None, ) -> Node: @@ -1821,7 +1892,7 @@ def call_function( @compatibility(is_backward_compatible=True) def node_copy( - self, node: Node, arg_transform: Callable[[Node], "Argument"] = lambda x: x + self, node: Node, arg_transform: Callable[[Node], Argument] = lambda x: x ) -> Node: """ Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from @@ -1857,7 +1928,10 @@ def node_copy( return result_node @compatibility(is_backward_compatible=True) - def output(self, result: "Argument", type_expr: Any | None = None): + # TODO: should return Node, see _node_list.__iter__ comment + def output( # pyrefly: ignore[unannotated-return] + self, result: Argument, type_expr: Any | None = None + ): """ Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents a ``return`` statement in Python code. ``result`` is the value that should @@ -1945,11 +2019,11 @@ def python_code( # makes sense to reuse it. This way, it's easy to print something like # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is # implemented cooperatively to allow this. - def node_repr(n: Node): + def node_repr(n: Node) -> str: return namespace.create_name(n.name, n) @contextmanager - def override_node_repr(graph: Graph): + def override_node_repr(graph: Graph) -> Generator[None, None, None]: orig_repr_fns = {} for node in graph.nodes: orig_repr_fns[node] = node._repr_fn @@ -2019,14 +2093,14 @@ def __str__(self) -> str: return s @compatibility(is_backward_compatible=True) - def print_tabular(self): + def print_tabular(self) -> None: """ Prints the intermediate representation of the graph in tabular format. Note that this API requires the ``tabulate`` module to be installed. """ try: - from tabulate import tabulate + from tabulate import tabulate # pyrefly: ignore[missing-import] except ImportError: print( "`print_tabular` relies on the library `tabulate`, " @@ -2041,7 +2115,7 @@ def print_tabular(self): ) @compatibility(is_backward_compatible=True) - def lint(self): + def lint(self) -> None: """ Runs various checks on this Graph to make sure it is well-formed. In particular: @@ -2175,7 +2249,7 @@ def forward(self, x): if torch._guards.TracingContext.try_get(): impure_random = torch._inductor.config.fallback_random - def has_side_effect(node): + def has_side_effect(node: Node) -> bool: if is_impure_node is not None: return is_impure_node(node) return node.is_impure(impure_random) @@ -2210,14 +2284,14 @@ def has_side_effect(node): return changed @compatibility(is_backward_compatible=False) - def set_codegen(self, codegen: CodeGen): + def set_codegen(self, codegen: CodeGen) -> None: self._codegen = codegen @compatibility(is_backward_compatible=False) def on_generate_code( self, make_transformer: Callable[[TransformCodeFunc | None], TransformCodeFunc], - ): + ) -> contextlib.AbstractContextManager[None]: """Register a transformer function when python code is generated Args: @@ -2290,7 +2364,7 @@ def insert_pdb(body): self._codegen._body_transformer = make_transformer(on_gen_code_old) @contextlib.contextmanager - def on_generate_code_context_manager(): + def on_generate_code_context_manager() -> Generator[None, None, None]: try: yield finally: @@ -2306,8 +2380,8 @@ def _clear_nodes(self) -> None: @contextmanager def _override_sym_repr( - override: Callable[["torch.types.PySymType"], str], -) -> Iterator[None]: + override: Callable[[torch.types.PySymType], str], +) -> Generator[None, None, None]: tmp = CodeGen._sym_repr try: CodeGen._sym_repr = override @@ -2316,12 +2390,12 @@ def _override_sym_repr( CodeGen._sym_repr = tmp -def _identity(x): +def _identity(x: str) -> str: return x -def _make_color_fn(code): - def f(s): +def _make_color_fn(code: str) -> Callable[[str], str]: + def f(s: str) -> str: reset = "\033[0m" return f"{code}{s}{reset}" diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 9ca2984fb9b85..4bf58aa9643e0 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -1,17 +1,24 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + import base64 import contextlib import copy import hashlib import itertools import linecache -import os import sys import traceback import warnings -from collections.abc import Callable from pathlib import Path -from typing import Any +from typing import Any, cast, TYPE_CHECKING + + +if TYPE_CHECKING: + import os + from collections.abc import Callable, Generator + from typing import Self + + from .node import Node import torch import torch.nn as nn @@ -47,11 +54,13 @@ # Using _exec_with_source will add it to our local cache # and then tools like TorchScript will be able to get source info. class _EvalCacheLoader: - def __init__(self): + def __init__(self) -> None: self.eval_cache = {} self.next_id = 0 - def cache(self, src: str, globals: dict[str, Any], co_fields=None): + def cache( + self, src: str, globals: dict[str, Any], co_fields: dict[str, Any] | None = None + ) -> str: """Store the source in a private cache, and add a lazy entry in linecache that allows the source to be retrieved by 'filename'. @@ -87,12 +96,12 @@ def cache(self, src: str, globals: dict[str, Any], co_fields=None): # Part of the loader protocol (PEP 302) # linecache will use this method when trying to find source code - def get_source(self, module_name) -> str | None: + def get_source(self, module_name: str) -> str | None: if module_name in self.eval_cache: return self.eval_cache[module_name] return None - def _get_key(self): + def _get_key(self) -> str: key = f".{self.next_id}" self.next_id += 1 return key @@ -101,20 +110,32 @@ def _get_key(self): _loader = _EvalCacheLoader() -def _exec_with_source(src: str, globals: dict[str, Any], co_fields=None): +def _exec_with_source( + src: str, globals: dict[str, Any], co_fields: dict[str, Any] | None = None +) -> None: key = _loader.cache(src, globals, co_fields) - exec(compile(src, key, "exec"), globals) + # dont_inherit=True prevents this module's `from __future__ import + # annotations` from leaking into the generated code, which would turn + # type annotations into strings and break downstream consumers like + # TorchScript that expect real type objects. + # TODO: Fix TorchScript BC to avoid breakages like these + exec(compile(src, key, "exec", dont_inherit=True), globals) -def _forward_from_src(src: str, globals: dict[str, Any], co_fields=None): +def _forward_from_src( + src: str, globals: dict[str, Any], co_fields: dict[str, Any] | None = None +) -> Callable[..., Any]: return _method_from_src( method_name="forward", src=src, globals=globals, co_fields=co_fields ) def _method_from_src( - method_name: str, src: str, globals: dict[str, Any], co_fields=None -) -> Callable: + method_name: str, + src: str, + globals: dict[str, Any], + co_fields: dict[str, Any] | None = None, +) -> Callable[..., Any]: # avoid mutating the passed in dict globals_copy = globals.copy() _exec_with_source(src, globals_copy, co_fields) @@ -123,7 +144,7 @@ def _method_from_src( return fn -def _format_import_statement(name: str, obj: Any, importer: Importer) -> str: +def _format_import_statement(name: str, obj: object, importer: Importer) -> str: if name in _custom_builtins: return _custom_builtins[name].import_str if _is_from_torch(name): @@ -132,7 +153,7 @@ def _format_import_statement(name: str, obj: Any, importer: Importer) -> str: return f"from {module_name} import {attr_name} as {name}" -def _format_import_block(globals: dict[str, Any], importer: Importer): +def _format_import_block(globals: dict[str, Any], importer: Importer) -> str: import_strs: set[str] = { _format_import_statement(name, obj, importer) for name, obj in globals.items() } @@ -142,7 +163,7 @@ def _format_import_block(globals: dict[str, Any], importer: Importer): @compatibility(is_backward_compatible=True) -def reduce_graph_module(body: dict[Any, Any], import_block: str) -> torch.nn.Module: +def reduce_graph_module(body: dict[str, Any], import_block: str) -> torch.nn.Module: # BC: attribute name was changed from `code` to `_code` to facilitate # making `code` into a property and adding a docstring to it fn_src = body.get("_code") or body["code"] @@ -152,7 +173,7 @@ def reduce_graph_module(body: dict[Any, Any], import_block: str) -> torch.nn.Mod @compatibility(is_backward_compatible=True) def reduce_package_graph_module( - importer: PackageImporter, body: dict[Any, Any], generated_module_name: str + importer: PackageImporter, body: dict[str, Any], generated_module_name: str ) -> torch.nn.Module: forward = importer.import_module(generated_module_name).forward return _deserialize_graph_module(forward, body) @@ -162,13 +183,15 @@ def reduce_package_graph_module( # function off of the class, rather than the instance. This class is used # in _deserialize_graph_module() below. class _CodeOnlyModule(torch.nn.Module): - def __init__(self, body): + def __init__(self, body: dict[str, Any]) -> None: super().__init__() self.__dict__ = body def _deserialize_graph_module( - forward, body: dict[Any, Any], graph_module_cls=None + forward: Callable[..., Any], + body: dict[str, Any], + graph_module_cls: type | None = None, ) -> torch.nn.Module: """ Deserialize a GraphModule given the dictionary of the original module, @@ -232,7 +255,9 @@ def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: # copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' # This installs empty Modules where none exist yet if they are subpaths of target -def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): +def _copy_attr( + from_module: torch.nn.Module, to_module: torch.nn.Module, target: str +) -> None: *prefix, field = target.split(".") for item in prefix: f = getattr(from_module, item) @@ -260,7 +285,7 @@ def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: # Assign attribute 'from_obj' to the qualified name 'target' on 'to_module # This installs empty Modules where none exist yet if they are subpaths of target -def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): +def _assign_attr(from_obj: object, to_module: torch.nn.Module, target: str) -> None: *prefix, field = target.split(".") for item in prefix: t = getattr(to_module, item, None) @@ -281,17 +306,17 @@ def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): # Recursively look up target from a graph module. -def _get_attr(model: torch.nn.Module, attr_name: str): +def _get_attr(model: torch.nn.Module, attr_name: str) -> Any: return _get_attr_via_attr_list(model, attr_name.split(".")) -def _del_attr(model: torch.nn.Module, attr_name: str): +def _del_attr(model: torch.nn.Module, attr_name: str) -> None: attr_names = attr_name.split(".") t = _get_attr_via_attr_list(model, attr_names[:-1]) return delattr(t, attr_names[-1]) -def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: list[str]): +def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: list[str]) -> Any: if len(attr_list) == 0: return model *prefix, field = attr_list @@ -304,7 +329,7 @@ def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: list[str]): return getattr(t, field) -def _has_attr(model: torch.nn.Module, attr_name: str): +def _has_attr(model: torch.nn.Module, attr_name: str) -> bool: *prefix, field = attr_name.split(".") t = model for item in prefix: @@ -316,15 +341,15 @@ def _has_attr(model: torch.nn.Module, attr_name: str): def _print_readable( - module, - module_name, - print_output=True, - include_stride=False, - include_device=False, - colored=False, - expanded_def=False, - additional_meta=None, -): + module: torch.nn.Module, + module_name: str, + print_output: bool = True, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, + expanded_def: bool = False, + additional_meta: list[str] | None = None, +) -> str: graph = module.graph if graph is None or not isinstance(graph, torch.fx.Graph): raise AssertionError("print_readable must be used on a module with a graph") @@ -345,7 +370,7 @@ def _print_readable( submodule_code_list = [""] for submodule_name, submodule in module.named_children(): - if hasattr(submodule, "graph"): + if isinstance(submodule, GraphModule): submodule_code_list.append( _print_readable( submodule, @@ -366,7 +391,7 @@ def _print_readable( return output -def _metadata_hash(code: str, node_metadata: dict) -> str: +def _metadata_hash(code: str, node_metadata: dict[int, dict[str, Any]]) -> str: """ Create a content-addressed hash from code and metadata. @@ -397,7 +422,7 @@ def _metadata_hash(code: str, node_metadata: dict) -> str: class _WrappedCall: - def __init__(self, cls, cls_call): + def __init__(self, cls: type, cls_call: Callable[..., Any] | None) -> None: self.cls = cls self.cls_call = cls_call @@ -439,7 +464,7 @@ def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: # joined message return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) - def __call__(self, obj, *args, **kwargs): + def __call__(self, obj: Any, *args: Any, **kwargs: Any) -> Any: try: if self.cls_call is not None: return self.cls_call(obj, *args, **kwargs) @@ -476,7 +501,7 @@ class GraphModule(torch.nn.Module): code. """ - def __new__(cls: "type[GraphModule]", *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> Self: # each instance of a graph module needs its own forward method # so create a new singleton class for each instance. # it is a subclass of the user-defined class, the only difference @@ -501,7 +526,7 @@ def __init__( root: torch.nn.Module | dict[str, Any], graph: Graph, class_name: str = "GraphModule", - ): + ) -> None: """ Construct a GraphModule. @@ -596,11 +621,11 @@ def __init__( # Dictionary to store metadata self.meta: dict[str, Any] = {} - self._replace_hooks: list[Callable] = [] - self._create_node_hooks: list[Callable] = [] - self._erase_node_hooks: list[Callable] = [] + self._replace_hooks: list[Callable[[Node, str, Node], object]] = [] + self._create_node_hooks: list[Callable[[Node], object]] = [] + self._erase_node_hooks: list[Callable[[Node], object]] = [] # Used to remove hooks from deepcopied graph modules within a context manager. - self._deepcopy_hooks: list[Callable] = [] + self._deepcopy_hooks: list[Callable[[GraphModule], object]] = [] self.shape_env = None # optional not always set even when dynamic shapes exist. # TorchScript breaks trying to compile the graph setter because of the @@ -634,7 +659,9 @@ def graph(self, g: Graph) -> None: self.recompile() @compatibility(is_backward_compatible=False) - def to_folder(self, folder: str | os.PathLike, module_name: str = "FxModule"): + def to_folder( + self, folder: str | os.PathLike[str], module_name: str = "FxModule" + ) -> None: """Dumps out module to ``folder`` with ``module_name`` so that it can be imported with ``from import `` @@ -675,7 +702,7 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> str | None: else: return None - blobified_modules = [] + blobified_modules: list[str] = [] for module_name, module in self.named_children(): module_str = _gen_model_repr(module_name, module) if module_str is None: @@ -692,12 +719,12 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> str | None: for buffer_name, buffer in self._buffers.items(): if buffer is None: continue - model_str += f"{tab * 2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" # noqa: B950 + model_str += f"{tab * 2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" for param_name, param in self._parameters.items(): if param is None: continue - model_str += f"{tab * 2}setattr(self, '{param_name}', torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype})))\n" # noqa: B950 + model_str += f"{tab * 2}setattr(self, '{param_name}', torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype})))\n" model_str += ( f"{tab * 2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" @@ -816,7 +843,7 @@ def delete_all_unused_submodules(self) -> None: used: list[str] = [] for node in self.graph.nodes: - if node.op == "call_module" or node.op == "get_attr": + if node.op in ("call_module", "get_attr") and isinstance(node.target, str): # A list of strings representing the different parts # of the path. For example, `foo.bar.baz` gives us # ["foo", "bar", "baz"] @@ -838,11 +865,12 @@ def join_fn(x: str, y: str) -> str: # as used if node.op == "call_module": try: - submod = self.get_submodule(node.target) + str_target = cast(str, node.target) + submod = self.get_submodule(str_target) for submod_name, _ in submod.named_modules(): if submod_name != "": - used.append(".".join([node.target, submod_name])) + used.append(".".join([str_target, submod_name])) except AttributeError: # Node referenced nonexistent submodule, don't need to # worry about GCing anything @@ -887,7 +915,9 @@ def recompile(self) -> PythonCode: self._prologue_start = python_code._prologue_start cls = type(self) - co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + co_fields: dict[str, Any] = ( + self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + ) if fx_experimental_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation @@ -945,7 +975,7 @@ def recompile(self) -> PythonCode: self._recompile_submodules() - def call_wrapped(self, *args, **kwargs): + def call_wrapped(self, *args: Any, **kwargs: Any) -> Any: return self._wrapped_call(self, *args, **kwargs) cls.__call__ = call_wrapped # type: ignore[method-assign] @@ -966,7 +996,7 @@ def _recompile_submodules(self) -> list[tuple[str, PythonCode]]: # Passing Tracer as argument allows subclasses extending fx.GraphModule # define their own Tracer (extending fx.Tracer). - def __reduce_package__(self, exporter: PackageExporter): + def __reduce_package__(self, exporter: PackageExporter) -> tuple[Any, ...]: dict_without_graph = self.__dict__.copy() dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ del dict_without_graph["_graph"] @@ -991,7 +1021,7 @@ def __reduce_package__(self, exporter: PackageExporter): (dict_without_graph, generated_module_name), ) - def __reduce__(self): + def __reduce__(self) -> tuple[Any, ...]: """ Serialization of GraphModule. We serialize only the generated code, not the underlying ``Graph``. This is because ``Graph`` does not have on-disk @@ -1006,13 +1036,13 @@ def __reduce__(self): del dict_without_graph["_graph"] return (reduce_graph_module, (dict_without_graph, import_block)) - def _deepcopy_init(self): + def _deepcopy_init(self) -> Callable[..., None]: return GraphModule.__init__ # because __reduce__ is defined for serialization, # we need to define deepcopy otherwise it will call __reduce__ # and cause symbolic tracing to occur every time we try to copy the object - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: dict[int, Any]) -> GraphModule: res = type(self).__new__(type(self)) memo[id(self)] = res fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo)) @@ -1042,7 +1072,7 @@ def __deepcopy__(self, memo): hook(res) return res - def __copy__(self): + def __copy__(self) -> GraphModule: from ._lazy_graph_module import _make_graph_module res = _make_graph_module(self, self.graph) @@ -1052,17 +1082,17 @@ def __copy__(self): @compatibility(is_backward_compatible=False) def print_readable( self, - print_output=True, - include_stride=False, - include_device=False, - colored=False, + print_output: bool = True, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, *, # If `fast_sympy_print` is True then we use a sympy printer which is faster # but may result in less-readable output. fast_sympy_print: bool = False, expanded_def: bool = False, additional_meta: list[str] | None = None, - ): + ) -> str: """ Return the Python code generated for current GraphModule and its children GraphModules. @@ -1101,13 +1131,15 @@ def __str__(self) -> str: ) return "\n".join([orig_str, self._code, print_readable_reminder]) - def _replicate_for_data_parallel(self): + def _replicate_for_data_parallel(self) -> GraphModule: new_gm = self.__copy__() - new_gm._is_replica = True + object.__setattr__(new_gm, "_is_replica", True) return new_gm @contextlib.contextmanager - def _set_replace_hook(self, f): + def _set_replace_hook( + self, f: Callable[[Node, str, Node], object] + ) -> Generator[None, None, None]: """ Takes a callable which will be called every time when we replace a node to a new node, or change the node's name. Callable takes three arguments: @@ -1122,7 +1154,9 @@ def _set_replace_hook(self, f): finally: self._unregister_replace_node_hook(f) - def _register_replace_node_hook(self, f): + def _register_replace_node_hook( + self, f: Callable[[Node, str, Node], object] + ) -> None: """ Takes a callable which will be called every time when we replace a node to a new node, or change the node's name. Callable takes three arguments: @@ -1133,7 +1167,9 @@ def _register_replace_node_hook(self, f): raise AssertionError("create_node hook must be a callable.") self._replace_hooks.append(f) - def _unregister_replace_node_hook(self, f): + def _unregister_replace_node_hook( + self, f: Callable[[Node, str, Node], object] + ) -> None: """ Takes a callable which was previously registered to be called every time when we replace a node. This function will unregister that callable so it is no longer invoked on node replacement. @@ -1142,7 +1178,7 @@ def _unregister_replace_node_hook(self, f): raise AssertionError("create_node hook must be a callable.") self._replace_hooks.remove(f) - def _register_create_node_hook(self, f): + def _register_create_node_hook(self, f: Callable[[Node], object]) -> None: """ Takes a callable which will be called after we create a new node. The callable takes the newly created node as input and returns None. @@ -1151,7 +1187,7 @@ def _register_create_node_hook(self, f): raise AssertionError("create_node hook must be a callable.") self._create_node_hooks.append(f) - def _unregister_create_node_hook(self, f): + def _unregister_create_node_hook(self, f: Callable[[Node], object]) -> None: """ Takes a callable which was previously registered to be called after we create a node. This function will unregister that callable so it is no longer invoked on node creation. @@ -1160,7 +1196,7 @@ def _unregister_create_node_hook(self, f): raise AssertionError("create_node hook must be a callable.") self._create_node_hooks.remove(f) - def _register_erase_node_hook(self, f): + def _register_erase_node_hook(self, f: Callable[[Node], object]) -> None: """ Takes a callable which will be called after we erase a node. The callable takes the node that is being erased as input and returns None. @@ -1169,7 +1205,7 @@ def _register_erase_node_hook(self, f): raise AssertionError("erase_node hook must be a callable.") self._erase_node_hooks.append(f) - def _unregister_erase_node_hook(self, f): + def _unregister_erase_node_hook(self, f: Callable[[Node], object]) -> None: """ Takes a callable which was previously registered to be called after we erase a node. This function will unregister that callable so it is no longer invoked on node erasure. @@ -1178,7 +1214,7 @@ def _unregister_erase_node_hook(self, f): raise AssertionError("erase_node hook must be a callable.") self._erase_node_hooks.remove(f) - def _register_deepcopy_hook(self, f): + def _register_deepcopy_hook(self, f: Callable[[GraphModule], object]) -> None: """ Takes a callable which will be called when we deepcopy this graph module. The callable takes the resulting deepcopied graph module. @@ -1187,7 +1223,7 @@ def _register_deepcopy_hook(self, f): raise AssertionError("deepcopy hook must be a callable.") self._deepcopy_hooks.append(f) - def _unregister_deepcopy_hook(self, f): + def _unregister_deepcopy_hook(self, f: Callable[[GraphModule], object]) -> None: """ Takes a callable which was previously registered to be called after deepcopy. This function will unregister that callable so it is no longer invoked on deepcopy. @@ -1195,19 +1231,3 @@ def _unregister_deepcopy_hook(self, f): if not callable(f): raise AssertionError("deepcopy hook must be a callable.") self._deepcopy_hooks.remove(f) - - -# workarounds for issues in __torch_function__ - -# WAR for __torch_function__ not handling tensor lists, -# fix is in https://github.com/pytorch/pytorch/pull/34725 -# orig_cat = torch.cat -# def patched_cat(*args, **kwargs): -# tensors = args[0] -# for t in tensors: -# if isinstance(t, Proxy): -# return t.__torch_function__(patched_cat, (), args, kwargs) -# return orig_cat(*args, **kwargs) -# patched_cat.__module__ = 'torch' -# patched_cat.__name__ = 'cat' -# torch.cat = patched_cat diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index a69e47e8d511c..7cd4ba1776567 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -1,8 +1,8 @@ -# mypy: allow-untyped-defs import inspect import logging +from collections.abc import Iterator from contextlib import contextmanager -from typing import Any, TYPE_CHECKING +from typing import Any import torch import torch.fx.traceback as fx_traceback @@ -19,15 +19,12 @@ from .proxy import Proxy -if TYPE_CHECKING: - from collections.abc import Iterator - log = logging.getLogger(__name__) __all__ = ["Interpreter", "Transformer"] -def _format_fx_node(n): +def _format_fx_node(n: Node) -> str: """ Format a torch.fx.Node into a human-readable string for debug logging. @@ -118,7 +115,7 @@ def __init__( module: torch.nn.Module, garbage_collect_values: bool = True, graph: Graph | None = None, - ): + ) -> None: self.module = module self.submodules = dict(self.module.named_modules()) if graph is not None: @@ -138,7 +135,7 @@ def __init__( node_to_last_use: dict[Node, Node] = {} self.user_to_last_uses: dict[Node, list[Node]] = {} - def register_last_uses(n: Node, user: Node): + def register_last_uses(n: Node, user: Node) -> None: if n not in node_to_last_use: node_to_last_use[n] = user self.user_to_last_uses.setdefault(user, []).append(n) @@ -150,7 +147,7 @@ def register_last_uses(n: Node, user: Node): @compatibility(is_backward_compatible=True) def run( self, - *args, + *args: Any, initial_env: dict[Node, Any] | None = None, enable_io_processing: bool = True, ) -> Any: @@ -240,7 +237,7 @@ def run( ) @compatibility(is_backward_compatible=True) - def boxed_run(self, args_list): + def boxed_run(self, args_list: list[Any]) -> Any: """ Run `module` via interpretation and return the result. This uses the "boxed" calling convention, where you pass a list of arguments, which will be cleared @@ -267,7 +264,7 @@ def boxed_run(self, args_list): return self.run(initial_env=env) @contextmanager - def _set_current_node(self, node): + def _set_current_node(self, node: Node) -> Iterator[None]: with fx_traceback.set_current_meta( node, f"Interpreter_{self.__class__.__name__}" ): @@ -452,7 +449,7 @@ def output( # Helper methods @compatibility(is_backward_compatible=True) - def fetch_attr(self, target: str): + def fetch_attr(self, target: str) -> Any: """ Fetch an attribute from the ``Module`` hierarchy of ``self.module``. @@ -473,7 +470,9 @@ def fetch_attr(self, target: str): return attr_itr @compatibility(is_backward_compatible=True) - def fetch_args_kwargs_from_env(self, n: Node) -> tuple[tuple, dict]: + def fetch_args_kwargs_from_env( + self, n: Node + ) -> tuple[tuple[Any, ...], dict[str, Any]]: """ Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` from the current execution environment. @@ -568,18 +567,18 @@ def fn(x): """ @compatibility(is_backward_compatible=True) - def __init__(self, module): + def __init__(self, module: GraphModule) -> None: super().__init__(module) self.new_graph = Graph() self.new_graph.set_codegen(module.graph._codegen) class TransformerTracer(Tracer): - def __init__(self, graph: Graph): + def __init__(self, graph: Graph) -> None: super().__init__() self.graph = graph self.tensor_attrs: dict[torch.Tensor, str] = {} # type: ignore[assignment] - def is_leaf_module(self, _, __) -> bool: + def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: return True self.tracer = TransformerTracer(self.new_graph) diff --git a/torch/fx/node.py b/torch/fx/node.py index 98ff8892243f2..001618fff6f02 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -49,7 +49,7 @@ Target: TypeAlias = Callable[..., Any] | str -Argument = Optional[ # noqa: UP007, UP045 +Argument = Optional[ # noqa: UP045 Union[ tuple["Argument", ...], Sequence["Argument"], @@ -292,6 +292,7 @@ class Node(_NodeBase): # generated function return type. (Note this is a special case. ``return`` # does not produce a value, it's more of a notation. Thus, this value # describes the type of args[0] in the ``return`` node. + # TODO: narrow this to TensorType | _DynType | None type: Any | None _sort_key: Any # If set, use this fn to print this node @@ -655,7 +656,7 @@ def format_node( return f"return {self.args[0]}" else: - def stringify_shape(shape: Iterable) -> str: + def stringify_shape(shape: Iterable[Any]) -> str: return f"[{', '.join([str(x) for x in shape])}]" meta_val = self.meta.get( diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index e2c76b6310260..532c8543ba888 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import enum import inspect import numbers @@ -6,11 +5,12 @@ import typing import warnings from collections.abc import Callable -from typing import Any, cast, NamedTuple, TYPE_CHECKING +from typing import Any, cast, Literal, NamedTuple, overload, TYPE_CHECKING import torch from torch._jit_internal import boolean_dispatched from torch._ops import OpOverload, OpOverloadPacket +from torch.utils._inspect import _fast_bind from ._compatibility import compatibility @@ -39,18 +39,18 @@ class ArgsKwargsPair(NamedTuple): kwargs: dict[str, Any] -_manual_overrides: dict[Callable, list[inspect.Signature]] = {} +_manual_overrides: dict[Callable[..., Any], list[inspect.Signature]] = {} -def _nonzero_schemas(): +def _nonzero_schemas() -> list[inspect.Signature]: signatures = [] - def nonzero(self): + def nonzero(self: torch.Tensor) -> None: pass signatures.append(inspect.signature(nonzero)) - def nonzero(self, *, as_tuple: bool): # type: ignore[no-redef] + def nonzero(self: torch.Tensor, *, as_tuple: bool) -> None: # type: ignore[no-redef] pass signatures.append(inspect.signature(nonzero)) @@ -62,7 +62,7 @@ def nonzero(self, *, as_tuple: bool): # type: ignore[no-redef] class _FakeGlobalNamespace: - def __getattr__(self, name): + def __getattr__(self, name: str) -> types.ModuleType: if name == "torch": return torch raise RuntimeError("Expected a torch namespace lookup") @@ -171,24 +171,26 @@ def _torchscript_schema_to_signature( @compatibility(is_backward_compatible=False) def check_for_mutable_operation( - target: Callable, args: tuple["Argument", ...], kwargs: dict[str, "Argument"] -): + target: Callable[..., Any], + args: tuple["Argument", ...], + kwargs: dict[str, "Argument"], +) -> None: signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) if signatures and schemas: - matched_schemas = [] + matched_schemas: list[tuple[inspect.Signature, torch._C.FunctionSchema]] = [] # Iterate through all of the schema until we find one that matches # If one matches, populate `new_args_and_kwargs` with the new args/kwargs # values. If none matches, `new_args_and_kwargs` will be None for candidate_signature, schema in zip(signatures, schemas): try: - candidate_signature.bind(*args, **kwargs) + _fast_bind(candidate_signature, *args, **kwargs) matched_schemas.append((candidate_signature, schema)) except TypeError: continue - def throw_if_mutable(schema): + def throw_if_mutable(schema: torch._C.FunctionSchema) -> None: if schema.is_mutable: raise RuntimeError( f"Tried to trace mutable operation {schema}. FX only supports functional " @@ -209,8 +211,26 @@ def throw_if_mutable(schema): pass +@overload +def get_signature_for_torch_op( + op: Callable[..., Any], return_schemas: Literal[True] +) -> tuple[list[inspect.Signature] | None, list[torch._C.FunctionSchema] | None]: ... + + +@overload +def get_signature_for_torch_op( + op: Callable[..., Any], return_schemas: Literal[False] = ... +) -> list[inspect.Signature] | None: ... + + @compatibility(is_backward_compatible=False) -def get_signature_for_torch_op(op: Callable, return_schemas: bool = False): +def get_signature_for_torch_op( + op: Callable[..., Any], return_schemas: bool = False +) -> ( + list[inspect.Signature] + | tuple[list[inspect.Signature] | None, list[torch._C.FunctionSchema] | None] + | None +): """ Given an operator on the `torch` namespace, return a list of `inspect.Signature` objects corresponding to the overloads of that op.. May return `None` if a signature @@ -245,7 +265,7 @@ def get_signature_for_torch_op(op: Callable, return_schemas: bool = False): @compatibility(is_backward_compatible=False) -def create_type_hint(x): +def create_type_hint(x: object) -> object: """ Produces a type hint for the given argument. @@ -262,12 +282,12 @@ def create_type_hint(x): # todo(chilli): Figure out the right way for mypy to handle this if isinstance(x, list): - def ret_type(x): + def ret_type(x: Any) -> Any: return list[x] # type: ignore[valid-type] else: - def ret_type(x): + def ret_type(x: Any) -> Any: return tuple[x, ...] # type: ignore[valid-type] if len(x) == 0: @@ -290,7 +310,7 @@ def ret_type(x): @compatibility(is_backward_compatible=False) -def type_matches(signature_type: Any, argument_type: Any): +def type_matches(signature_type: Any, argument_type: Any) -> bool: sig_origin_type = getattr(signature_type, "__origin__", signature_type) if signature_type is argument_type: @@ -317,11 +337,11 @@ def type_matches(signature_type: Any, argument_type: Any): if getattr(argument_type, "__origin__", None) is list: return issubclass(argument_type.__args__[0], sig_el_type) - def is_homogeneous_tuple(t): - if getattr(t, "__origin__", None) is not tuple: + def is_homogeneous_tuple(t: object) -> bool: + if typing.get_origin(t) is not tuple: return False - contained = t.__args__ - if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason + contained = typing.get_args(t) + if contained == ((),): # Tuple[()].__args__ == ((),) for some reason return True return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained) @@ -342,7 +362,7 @@ def is_homogeneous_tuple(t): @compatibility(is_backward_compatible=False) def _normalize_function_or_error( - target: Callable, + target: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any] | None = None, arg_types: tuple[Any] | None = None, @@ -366,7 +386,7 @@ def _normalize_function_or_error( @compatibility(is_backward_compatible=False) def normalize_function( - target: Callable, + target: Callable[..., Any], args: tuple[Any, ...], kwargs: dict[str, Any] | None = None, arg_types: tuple[Any] | None = None, @@ -442,14 +462,14 @@ def normalize_function( if not callable(target): raise AssertionError(f"target must be callable, got {type(target)}") torch_op_schemas = get_signature_for_torch_op(target) - matched_schemas = [] + matched_schemas: list[inspect.Signature] = [] if torch_op_schemas: # Iterate through all of the schema until we find one that matches # If one matches, populate `new_args_and_kwargs` with the new args/kwargs # values. If none matches, `new_args_and_kwargs` will be None for candidate_signature in torch_op_schemas: try: - candidate_signature.bind(*args, **kwargs) + _fast_bind(candidate_signature, *args, **kwargs) matched_schemas.append(candidate_signature) except TypeError: continue @@ -469,8 +489,8 @@ def normalize_function( for candidate_signature in torch_op_schemas: sig_matches = True try: - bound_types = candidate_signature.bind( - *arg_types, **kwarg_types + bound_types = _fast_bind( + candidate_signature, *arg_types, **kwarg_types ) for arg_name, arg_type in bound_types.arguments.items(): param = candidate_signature.parameters[arg_name] @@ -589,7 +609,7 @@ def _args_kwargs_to_normalized_args_kwargs( if list(sig.parameters.keys()) != ["input", "from", "to", "generator"]: return None - bound_args = sig.bind(*args, **kwargs) + bound_args = _fast_bind(sig, *args, **kwargs) bound_args.apply_defaults() new_kwargs: dict[str, Any] = {} diff --git a/torch/fx/passes/_tensorify_python_scalars.py b/torch/fx/passes/_tensorify_python_scalars.py index 61c87548fa146..43dfcda045fc6 100644 --- a/torch/fx/passes/_tensorify_python_scalars.py +++ b/torch/fx/passes/_tensorify_python_scalars.py @@ -209,7 +209,8 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: and node.op == "call_function" and node.target is torch.ops.aten._local_scalar_dense.default ): - dtype = node.args[0].meta["val"].dtype + source_tensor = node.args[0].meta["val"] + dtype = source_tensor.dtype if not isinstance(node.args[0], fx.Node): raise AssertionError(f"Expected fx.Node, got {node.args[0]}") @@ -227,6 +228,15 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: expr_to_tensor_proxy[s] = MetaProxy( node.args[0], tracer=tracer, fake_mode=fake_mode ) + if len(source_tensor.shape) != 0: + # .item() always produces a scalar value, even when it is + # called on a size-1 tensor with rank > 0. Preserve that 0-d + # semantics before tensorifying the scalar expression so + # later tensor math and autograd tangents do not keep an + # accidental length-1 dimension. + expr_to_tensor_proxy[s] = torch.ops.aten.reshape.default( + expr_to_tensor_proxy[s], [] + ) # Upcast the float tensor to torch.float64 to avoid precision problem expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default( expr_to_tensor_proxy[s], torch.float64 @@ -246,7 +256,7 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: # Specialize all dimensions that contain symfloats. Here's # an example test that requires this: - # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 # noqa: B950 + # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 val = node.meta.get("val") if isinstance(val, FakeTensor): @@ -337,7 +347,7 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy: ): failed_tensorify_ops.update(str(node.target)) - log.info("Failed to tensorify %s", str(node.target)) + log.info("Failed to tensorify %s", node.target) # Now do one more pass that specializes all symfloats we didn't manage # to tensorify away. diff --git a/torch/fx/passes/backends/cudagraphs.py b/torch/fx/passes/backends/cudagraphs.py index 97496fbc9b2a2..ec9cb02c43ff9 100644 --- a/torch/fx/passes/backends/cudagraphs.py +++ b/torch/fx/passes/backends/cudagraphs.py @@ -1,7 +1,11 @@ -# mypy: allow-untyped-defs import operator +from collections.abc import Mapping, Sequence +from typing import Any import torch + + +__all__ = ["CudaGraphsSupport", "partition_cudagraphs"] from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport @@ -11,7 +15,9 @@ class CudaGraphsSupport(OperatorSupport): # TODO: why is submodules passed here - def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + def is_node_supported( + self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: if node.op not in CALLABLE_NODE_OPS: return False @@ -23,10 +29,10 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: found_not_cuda = False - def meta_fk(meta): + def meta_fk(meta: dict[str, Any]) -> torch.Tensor: return meta["val"] if "val" in meta else meta["fake_result"] - def find_not_cuda(t): + def find_not_cuda(t: object) -> None: nonlocal found_not_cuda if isinstance(t, torch.Tensor) and t.device.type != "cuda": found_not_cuda = True @@ -42,7 +48,9 @@ def find_not_cuda(t): return not found_not_cuda -def partition_cudagraphs(gm, inputs): +def partition_cudagraphs( + gm: torch.fx.GraphModule, inputs: Sequence[object] +) -> torch.fx.GraphModule: """ Partition an FX graph into sub-GraphModules that can be validly run under CUDA graphs. For a subgraph to be runnable under CUDA, all of the operations diff --git a/torch/fx/passes/dialect/common/cse_pass.py b/torch/fx/passes/dialect/common/cse_pass.py index e5889375bb07a..5526c8be3e2d3 100644 --- a/torch/fx/passes/dialect/common/cse_pass.py +++ b/torch/fx/passes/dialect/common/cse_pass.py @@ -1,8 +1,8 @@ -# mypy: allow-untyped-defs from typing import Any import torch from torch.fx import Graph, GraphModule, Node +from torch.fx.node import Target from torch.fx.passes.infra.pass_base import PassBase, PassResult from torch.utils._pytree import tree_flatten @@ -27,7 +27,7 @@ aten.randint, aten.randn, aten.randperm, -} # noqa: E501,B950 +} inplace_ops = { aten.add_, @@ -39,17 +39,19 @@ aten.relu_, aten.sigmoid_, aten.tanh_, -} # noqa: E501 +} @torch.fx._compatibility.compatibility(is_backward_compatible=False) -def get_CSE_banned_ops(): +def get_CSE_banned_ops() -> set[torch._ops.OpOverloadPacket]: return rand_ops.union(inplace_ops) @torch.fx._compatibility.compatibility(is_backward_compatible=False) class CSEPass(PassBase): - def __init__(self, banned_ops=None): + def __init__( + self, banned_ops: set[torch._ops.OpOverloadPacket] | None = None + ) -> None: """ This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node. @@ -83,7 +85,7 @@ def f(a): print(result.graph_module) """ - def get_aten_target(node): + def get_aten_target(node: Node) -> Target: if hasattr(node.target, "overloadpacket"): return node.target.overloadpacket return node.target @@ -94,10 +96,10 @@ def get_aten_target(node): Node, Node ] = {} # map from node in the old graph to node in the new graph hash_env: dict[ - tuple[torch._ops.OpOverload, int], Node + tuple[Target, int], Node ] = {} # map from hash to a node in the new graph token_map: dict[ - tuple[torch._ops.OpOverload, int], dict[str, Any] + tuple[Target, int], dict[str, Any] ] = {} # map from hash to token for n in graph_module.graph.nodes: # The placeholder, output, and get_attr nodes are copied to the new graph without change @@ -113,7 +115,7 @@ def get_aten_target(node): else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' # substitute args and kwargs members to their mapping in env if exists # specs can be used to reconstruct nested list/dictionaries - def substitute(arg_list): + def substitute(arg_list: Any) -> tuple[tuple[Any, ...], Any]: arg_list, spec = tree_flatten(arg_list) for i in range(len(arg_list)): v = arg_list[i] diff --git a/torch/fx/passes/fake_tensor_prop.py b/torch/fx/passes/fake_tensor_prop.py index 45507632228ef..3b4a35e7f246e 100644 --- a/torch/fx/passes/fake_tensor_prop.py +++ b/torch/fx/passes/fake_tensor_prop.py @@ -1,4 +1,4 @@ -# mypy: allow-untyped-defs +from typing import Any import torch.fx from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode @@ -29,7 +29,7 @@ class FakeTensorProp(torch.fx.Interpreter): def __init__( self, module: torch.fx.GraphModule, mode: FakeTensorMode | None = None - ): + ) -> None: super().__init__(module) if mode is None: mode = FakeTensorMode() @@ -38,7 +38,7 @@ def __init__( mode.reset_nt_tensor_id_counter() self.seen_subgraphs: OrderedSet[str] = OrderedSet() - def run_node(self, n: Node): + def run_node(self, n: Node) -> Any: from torch.fx.experimental.symbolic_shapes import ( compute_unbacked_bindings, rebind_unbacked, @@ -79,7 +79,7 @@ def run_node(self, n: Node): result = super().run_node(n) rebind_unbacked(self._mode.shape_env, n, result) - def extract_val(obj): + def extract_val(obj: Any) -> Any: if isinstance(obj, FakeTensor): return snapshot_fake(obj) elif isinstance(obj, torch.Tensor): @@ -101,13 +101,13 @@ def extract_val(obj): return result - def propagate(self, *args): + def propagate(self, *args: object) -> Any: fake_args = [ self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args ] return self.propagate_dont_convert_inputs(*fake_args) - def propagate_dont_convert_inputs(self, *args): + def propagate_dont_convert_inputs(self, *args: object) -> Any: with self._mode: return super().run(*args) diff --git a/torch/fx/passes/graph_drawer.py b/torch/fx/passes/graph_drawer.py index f94d4cd174f53..7671b98fba6d9 100644 --- a/torch/fx/passes/graph_drawer.py +++ b/torch/fx/passes/graph_drawer.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - import hashlib from itertools import chain from types import ModuleType @@ -85,7 +83,7 @@ def __init__( parse_stack_trace: bool = False, dot_graph_shape: str | None = None, normalize_args: bool = False, - ): + ) -> None: self._name = name self.dot_graph_shape = ( dot_graph_shape if dot_graph_shape is not None else "record" @@ -122,7 +120,7 @@ def __init__( parse_stack_trace, ) - def get_dot_graph(self, submod_name=None) -> pydot.Dot: + def get_dot_graph(self, submod_name: str | None = None) -> pydot.Dot: """ Visualize a torch.fx.Graph with graphviz Example: @@ -154,7 +152,7 @@ def get_dot_graph(self, submod_name=None) -> pydot.Dot: def get_main_dot_graph(self) -> pydot.Dot: return self._dot_graphs[self._name] - def get_submod_dot_graph(self, submod_name) -> pydot.Dot: + def get_submod_dot_graph(self, submod_name: str) -> pydot.Dot: return self._dot_graphs[f"{self._name}_{submod_name}"] def get_all_dot_graphs(self) -> dict[str, pydot.Dot]: @@ -198,7 +196,7 @@ def _get_leaf_node( py_obj = getattr(py_obj, atom) return py_obj - def _typename(self, target: Any) -> str: + def _typename(self, target: torch.fx.node.Target | torch.nn.Module) -> str: if isinstance(target, torch.nn.Module): ret = torch.typename(target) elif isinstance(target, str): @@ -218,7 +216,7 @@ def _shorten_file_name( self, full_file_name: str, truncate_to_last_n: int = 2, - ): + ) -> str: splits = full_file_name.split("/") if len(splits) >= truncate_to_last_n: return "/".join(splits[-truncate_to_last_n:]) @@ -231,7 +229,7 @@ def _get_node_label( skip_node_names_in_args: bool, parse_stack_trace: bool, ) -> str: - def _get_str_for_args_kwargs(arg): + def _get_str_for_args_kwargs(arg: tuple[Any, ...] | dict[str, Any]) -> str: if isinstance(arg, tuple): prefix, suffix = r"|args=(\l", r",\n)\l" arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg] @@ -303,15 +301,16 @@ def _get_str_for_args_kwargs(arg): # print file:lineno code if parse_stack_trace and node.stack_trace is not None: parsed_stack_trace = _parse_stack_trace(node.stack_trace) - fname = self._shorten_file_name(parsed_stack_trace.file) - label += ( - f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" - + r"\n" - ) + if parsed_stack_trace is not None: + fname = self._shorten_file_name(parsed_stack_trace.file) + label += ( + f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + + r"\n" + ) return label + "}" - def _tensor_meta_to_label(self, tm) -> str: + def _tensor_meta_to_label(self, tm: object) -> str: if tm is None: return "" elif isinstance(tm, TensorMetadata): @@ -414,7 +413,7 @@ def _to_dot( # "TB" means top-to-bottom rank direction in layout dot_graph = pydot.Dot(name, rankdir="TB") - buf_name_to_subgraph = {} + buf_name_to_subgraph: dict[str, pydot.Cluster] = {} for node in graph_module.graph.nodes: if ignore_getattr and node.op == "get_attr": @@ -443,7 +442,7 @@ def _to_dot( # pyrefly: ignore [missing-attribute] current_graph.add_node(dot_node) - def get_module_params_or_buffers(): + def get_module_params_or_buffers() -> None: for pname, ptensor in chain( leaf_module.named_parameters(), # pyrefly: ignore [bad-argument-type] diff --git a/torch/fx/passes/graph_manipulation.py b/torch/fx/passes/graph_manipulation.py index 64e1481915e00..222245718b9c4 100644 --- a/torch/fx/passes/graph_manipulation.py +++ b/torch/fx/passes/graph_manipulation.py @@ -1,12 +1,11 @@ -# mypy: allow-untyped-defs -from typing import Any, NamedTuple +from typing import NamedTuple import torch from torch.fx._compatibility import compatibility from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule from torch.fx.node import map_arg, Node, Target -from torch.fx.passes.shape_prop import ShapeProp +from torch.fx.passes.shape_prop import ShapeProp, TensorMetadata __all__ = [ @@ -25,9 +24,11 @@ def replace_target_nodes_with( old_target: Target, new_op: str, new_target: Target, -): - """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target, - and updates them to match the new op code and target""" +) -> None: + """ + Modifies all nodes in fx_module.graph.nodes which match the specified op code + and target, and updates them to match the new op code and target. + """ new_graph = Graph() val_map: dict[Node, Node] = {} for node in fx_module.graph.nodes: @@ -71,7 +72,7 @@ def get_size_of_all_nodes( @compatibility(is_backward_compatible=False) -def get_tensor_meta(node: Node) -> Any: +def get_tensor_meta(node: Node) -> TensorMetadata: tensor_meta = node.meta.get("tensor_meta") if not tensor_meta: diff --git a/torch/fx/passes/graph_transform_observer.py b/torch/fx/passes/graph_transform_observer.py index f91877b00a057..e5e5e5c3aaf5b 100644 --- a/torch/fx/passes/graph_transform_observer.py +++ b/torch/fx/passes/graph_transform_observer.py @@ -1,6 +1,6 @@ -# mypy: allow-untyped-defs import os from collections.abc import Callable +from types import TracebackType from typing import TypeVar from torch.fx import Graph, Node @@ -28,7 +28,7 @@ def __init__( passname: str, subsystem: str | None = None, log_url: str | None = None, - ): + ) -> None: """ log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified """ @@ -73,7 +73,7 @@ def __init__( ).get_dot_graph() @classmethod - def get_current_pass_count(cls): + def get_current_pass_count(cls) -> int: return cls.__pass_count def apply_gm_pass(self, pass_fn: Callable[[GraphModule], T]) -> T | None: @@ -102,7 +102,7 @@ def apply_graph_pass(self, pass_fn: Callable[[Graph], T]) -> T | None: ): return pass_fn(self.gm.graph) - def _check_disable_pass(self): + def _check_disable_pass(self) -> bool: from torch._inductor import config as inductor_config if self.passname.upper() in inductor_config.disabled_passes.upper(): @@ -118,7 +118,7 @@ def _check_disable_pass(self): "inductor", self.subsystem, debug_info ) - def __enter__(self): + def __enter__(self) -> "GraphTransformObserver": if not self.active: return self self.gm._register_create_node_hook(self._node_creation_hook) @@ -136,7 +136,12 @@ def __enter__(self): return self - def __exit__(self, type, value, tb): + def __exit__( + self, + type: type[BaseException] | None, + value: BaseException | None, + tb: TracebackType | None, + ) -> None: if not self.active: return for gm in self.copied_gms + [self.gm]: @@ -181,10 +186,10 @@ def __exit__(self, type, value, tb): ) ) - def get_node_creation_hook(self): + def get_node_creation_hook(self) -> Callable[[Node], None]: # We have to return a function instead of using a class method directly # to avoid max recursion issue when deepcopy a graph module within the context manager. - def on_node_creation(node): + def on_node_creation(node: Node) -> None: self.created_nodes.add(node.name) self.name_to_node[node.name] = node source = NodeSource(None, self.passname, NodeSourceAction.CREATE) @@ -195,15 +200,15 @@ def on_node_creation(node): return on_node_creation - def get_node_erase_hook(self): - def on_node_erase(node): + def get_node_erase_hook(self) -> Callable[[Node], None]: + def on_node_erase(node: Node) -> None: self.erased_nodes.add(node.name) self.name_to_node.pop(node.name, None) return on_node_erase - def get_node_replace_hook(self): - def on_node_replace(old: Node, new: str, user: Node): + def get_node_replace_hook(self) -> Callable[[Node, str, Node], None]: + def on_node_replace(old: Node, new: str, user: Node) -> None: # Update node meta when replacing old node with new node new_node = self.name_to_node.get(new, None) @@ -223,7 +228,7 @@ def on_node_replace(old: Node, new: str, user: Node): if new_node.name in self.created_nodes: action.append(NodeSourceAction.CREATE) - def created_this_pass(source): + def created_this_pass(source: NodeSource) -> bool: return source.pass_name == self.passname and source.action == [ NodeSourceAction.CREATE ] @@ -241,8 +246,8 @@ def created_this_pass(source): return on_node_replace - def get_deepcopy_hook(self): - def on_deepcopy(gm): + def get_deepcopy_hook(self) -> Callable[[GraphModule], None]: + def on_deepcopy(gm: GraphModule) -> None: self.copied_gms.append(gm) return on_deepcopy diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 550c6d57be865..be8b4575a4c4e 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import collections import itertools import logging @@ -21,7 +20,7 @@ def __init__( id: int | None = None, nodes: Iterable[Node] | None = None, node_orders: Iterable[int] | None = None, - ): + ) -> None: self.id = id self.nodes: dict[Node, int | None] = {} if nodes is not None: @@ -39,18 +38,18 @@ def __init__( def __repr__(self) -> str: return str(self.nodes) - def add_node(self, node: Node, node_order: int | None = None): + def add_node(self, node: Node, node_order: int | None = None) -> None: self.nodes.update({node: node_order}) - def remove_node(self, node: Node): + def remove_node(self, node: Node) -> None: del self.nodes[node] - def size(self): + def size(self) -> int: return len(self.nodes) class _DependencyViewer: - def __init__(self, graph_module: GraphModule): + def __init__(self, graph_module: GraphModule) -> None: self.downstreams = collections.defaultdict(set) for node in reversed(graph_module.graph.nodes): @@ -92,7 +91,7 @@ def propose_partitions(self) -> list[Partition]: # partition_map is a mapping from partition id to a set of partition id's. # The value set contains all the partition ids that can be reached by doing a # DFS starting from the partition id in the key. - partition_map: dict[int, set] = collections.defaultdict(set) + partition_map: dict[int, set[int]] = collections.defaultdict(set) # assumptions: nodes in candidate list is sorted in topological order assignment: dict[Node, int] = {} # mapping from node to partition_id @@ -106,19 +105,19 @@ def propose_partitions(self) -> list[Partition]: int, int ] = {} # mapping from partition_id to minimum topo order of nodes in partition partition_users: dict[ - int, set + int, set[Node] ] = {} # mapping from partition_id to partition users new_partition_id = itertools.count() # try to merge partition other_id into partition self_id # merge only happens if the end graph doesn't contain cyclic dependency # returns `True` when merge happens, `False` otherwise. - def maybe_merge_partition(self_id: int, other_id: int): + def maybe_merge_partition(self_id: int, other_id: int) -> tuple[int, bool]: # merged_nodes is the union of nodes in two partition to-be-merged self_nodes = partitions_by_id[self_id].nodes other_nodes = partitions_by_id[other_id].nodes - def dfs_iter_find_cycle(all_user_nodes: set[Node]): + def dfs_iter_find_cycle(all_user_nodes: set[Node]) -> bool: for user_node in all_user_nodes: visited_partition_ids = set() @@ -185,8 +184,10 @@ def dfs_iter_find_cycle(all_user_nodes: set[Node]): return merge_id, True - def merge_single_node(node: Node, node_order: int | None, id: int | None): - def _update_partition_map(node: Node, id: int): + def merge_single_node( + node: Node, node_order: int | None, id: int | None + ) -> None: + def _update_partition_map(node: Node, id: int) -> None: # Iterate through all the users of this node and update the partition map to indicate # that there is a path from the partition id of this node to the target partition id. for user_node in node.users: @@ -331,10 +332,10 @@ def fuse_partitions( ) # remove non-compute-ops that sits at the boundary of a partition. - def remove_bookend_non_compute_ops(self, partitions: list[Partition]): + def remove_bookend_non_compute_ops(self, partitions: list[Partition]) -> None: non_compute_ops = set(self.non_compute_ops) - def is_non_compute_node(node: Node): + def is_non_compute_node(node: Node) -> bool: return ( node.op == "call_function" and _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type] @@ -346,7 +347,7 @@ def is_non_compute_node(node: Node): def is_transparent_input_node( node: Node, partition: set[Node], removed_nodes: set[Node] - ): + ) -> bool: if ( node.op == "placeholder" or (node not in partition) @@ -367,7 +368,7 @@ def is_transparent_input_node( def is_transparent_output_node( node: Node, partition: set[Node], removed_nodes: set[Node] - ): + ) -> bool: if ( node.op == "placeholder" or (node not in partition) diff --git a/torch/fx/passes/infra/pass_base.py b/torch/fx/passes/infra/pass_base.py index 109f62f1488bb..24571d82cd4a1 100644 --- a/torch/fx/passes/infra/pass_base.py +++ b/torch/fx/passes/infra/pass_base.py @@ -1,7 +1,7 @@ -# mypy: allow-untyped-defs import abc from collections import namedtuple +import torch.nn as nn from torch.fx._compatibility import compatibility from torch.fx.graph_module import GraphModule @@ -20,7 +20,7 @@ class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): __slots__ = () - def __new__(cls, graph_module, modified): + def __new__(cls, graph_module: nn.Module, modified: bool) -> "PassResult": return super().__new__(cls, graph_module, modified) diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py index 87fb6e70037f9..a474d692627c0 100644 --- a/torch/fx/passes/infra/pass_manager.py +++ b/torch/fx/passes/infra/pass_manager.py @@ -1,9 +1,9 @@ -# mypy: allow-untyped-defs import inspect import logging from collections.abc import Callable from functools import wraps from queue import Queue +from typing import Any import torch.nn as nn from torch.fx._compatibility import compatibility @@ -18,7 +18,7 @@ @compatibility(is_backward_compatible=False) -def pass_result_wrapper(fn: Callable) -> Callable: +def pass_result_wrapper(fn: Callable[..., Any]) -> Callable[..., PassResult | None]: """ Wrapper for passes which currently do not return a PassResult. This wrapper makes them return a PassResult containing the modified object @@ -35,7 +35,7 @@ def pass_result_wrapper(fn: Callable) -> Callable: return None @wraps(fn) - def wrapped_fn(gm): + def wrapped_fn(gm: nn.Module) -> PassResult | None: res = fn(gm) if res is None: return PassResult(gm, True) @@ -51,7 +51,8 @@ def wrapped_fn(gm): def _validate_pass_schedule_constraint( - constraint: Callable[[Callable, Callable], bool], passes: list[Callable] + constraint: Callable[[Callable[..., Any], Callable[..., Any]], bool], + passes: list[Callable[..., Any]], ) -> None: for i, a in enumerate(passes): for j, b in enumerate(passes[i + 1 :]): @@ -65,8 +66,8 @@ def _validate_pass_schedule_constraint( def _topological_sort_passes( - passes: list[Callable], constraints: list[Callable] -) -> list[Callable]: + passes: list[Callable[..., Any]], constraints: list[Callable[..., Any]] +) -> list[Callable[..., Any]]: """ Args passes: Passes that we are ordering @@ -80,9 +81,9 @@ def _topological_sort_passes( return passes # Construct a graph mapping nodes to a list of their users - graph: dict[Callable, list[Callable]] = {p: [] for p in passes} - indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0) - candidates: Queue = Queue() + graph: dict[Callable[..., Any], list[Callable[..., Any]]] = {p: [] for p in passes} + indegree_map: dict[Callable[..., Any], int] = dict.fromkeys(passes, 0) + candidates: Queue[Callable[..., Any]] = Queue() for a in passes: for b in passes: if a == b: @@ -96,8 +97,8 @@ def _topological_sort_passes( if indegree_map[a] == 0: candidates.put(a) - visited: dict[Callable, bool] = dict.fromkeys(passes, False) - sorted_passes: list[Callable] = [] + visited: dict[Callable[..., Any], bool] = dict.fromkeys(passes, False) + sorted_passes: list[Callable[..., Any]] = [] while not candidates.empty(): p = candidates.get() @@ -122,27 +123,28 @@ def _topological_sort_passes( @compatibility(is_backward_compatible=False) -def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable: +def this_before_that_pass_constraint( + this: Callable[..., Any], that: Callable[..., Any] +) -> Callable[[Callable[..., Any], Callable[..., Any]], bool]: """ - Defines a partial order ('depends on' function) where `this` must occur - before `that`. + Defines a partial order ('depends on' function) where ``this`` must occur + before ``that``. - For example, the following pass list and constraint list would be invalid. - ``` - passes = [pass_b, pass_a] + For example, the following pass list and constraint list would be invalid:: - constraints = [this_before_that_pass_constraint(pass_a, pass_b)] - ``` + passes = [pass_b, pass_a] + + constraints = [this_before_that_pass_constraint(pass_a, pass_b)] Args: this (Callable): pass which should occur first that (Callable): pass which should occur later Returns: - depends_on (Callable[[Object, Object], bool] + depends_on (Callable[[Object, Object], bool]) """ - def depends_on(a: Callable, b: Callable): + def depends_on(a: Callable[..., Any], b: Callable[..., Any]) -> bool: return a != that or b != this return depends_on @@ -170,19 +172,20 @@ class PassManager: checks """ - passes: list[Callable[[nn.Module], PassResult]] - constraints: list[Callable[[Callable, Callable], bool]] + passes: list[Callable[..., PassResult | None]] + constraints: list[Callable[[Callable[..., Any], Callable[..., Any]], bool]] _validated: bool = False steps: int = 1 def __init__( self, - passes=None, - constraints=None, - steps=None, + passes: list[Callable[..., PassResult | None]] | None = None, + constraints: list[Callable[[Callable[..., Any], Callable[..., Any]], bool]] + | None = None, + steps: int | None = None, run_checks_after_each_pass: bool = False, suppress_check_failures: bool = False, - ): + ) -> None: self.passes = passes or [] self.constraints = constraints or [] if steps: @@ -191,21 +194,23 @@ def __init__( self.run_checks_after_each_pass = run_checks_after_each_pass self.suppress_check_failures = suppress_check_failures - def add_pass(self, _pass: Callable): + def add_pass(self, _pass: Callable[..., PassResult | None]) -> None: """ Adds a pass into the current list of passes. """ self.passes.append(_pass) self._validated = False - def add_constraint(self, constraint: Callable): + def add_constraint( + self, constraint: Callable[[Callable[..., Any], Callable[..., Any]], bool] + ) -> None: """ Adds a constraint into the current list of constraints. """ self.constraints.append(constraint) self._validated = False - def validate_constraints(self): + def validate_constraints(self) -> None: """ Validates that current pass schedule defined by `self.passes` is valid according to all constraints in `self.constraints` @@ -216,7 +221,7 @@ def validate_constraints(self): _validate_pass_schedule_constraint(constraint, self.passes) self._validated = True - def solve_constraints(self): + def solve_constraints(self) -> None: """ Finds a valid traversal order based on the given constraints and orders the passes based on this order. @@ -228,7 +233,7 @@ def solve_constraints(self): self.passes = _topological_sort_passes(self.passes, self.constraints) self._validated = True - def add_checks(self, check: Callable) -> None: + def add_checks(self, check: Callable[[nn.Module], None]) -> None: """ Adds a function which takes runs various checks on a given graph module. This function is run before and after each pass if the @@ -283,7 +288,9 @@ def __call__(self, module: nn.Module) -> PassResult: f"The result of the pass {fn_name} should be type PassResult." + "Please wrap it with pass_result_wrapper()" ) + # pyrefly: ignore[missing-attribute] module = res.graph_module + # pyrefly: ignore[missing-attribute] modified = modified or res.modified if isinstance(module, GraphModule): diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 34636c6f3664a..7335497a670ef 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import logging from collections.abc import Callable from dataclasses import dataclass @@ -78,7 +77,7 @@ class _MinimizerSettingBase: return_intermediate: bool = False all_outputs: bool = False - def __str__(self): + def __str__(self) -> str: settings_str = "FX Minimizer Settings:\n" for k, v in vars(self).items(): @@ -113,7 +112,7 @@ def __init__( module_exporter: Callable[[Tensors, torch.fx.GraphModule, str], None] | None = None, exclusion_fn: Callable[[NodeList, int, int], None] | None = None, - ): + ) -> None: if not isinstance(module, torch.fx.GraphModule): raise AssertionError(f"Expected GraphModule, got {type(module)}") @@ -190,7 +189,7 @@ def _store_outputs( a_result: TensorOrTensors, b_result: TensorOrTensors, submodule: torch.fx.GraphModule, - ): + ) -> None: """ Store the outputs of self.run_a() and self.run_b() into self.a_outputs and self.b_outputs, so that we can use them when execute preceding nodes that @@ -251,7 +250,7 @@ def _get_submod_inputs( if self.settings.accumulate_error: print(f"Can't find previous stored outputs named {placeholders}!") - def get_inputs(self: torch.nn.Module, inputs: Any): + def get_inputs(self: torch.nn.Module, inputs: tuple[Any, ...]) -> None: nonlocal a_input a_input = inputs @@ -267,7 +266,7 @@ def get_inputs(self: torch.nn.Module, inputs: Any): return a_input, b_input - def _tag_nodes(self, selected_nodes: NodeSet): + def _tag_nodes(self, selected_nodes: NodeSet) -> None: """ Tag selected nodes with tag "minimize". Nodes with the same tags will be split to the same submodule afterwards. @@ -336,7 +335,7 @@ def _run_and_compare( submod_name: str, output_names: Names, report_idx: int = -1, - ): + ) -> None: """ Run the submodule in `split_module` that has name `submod_name` using `self.run_a` and `self.run_b` and compare their results. @@ -837,7 +836,7 @@ def _skip_traverse_impl( self.print_report(report) return set() - def _skip_traverse(self, all_nodes: NodeList, skip_nodes: list) -> NodeSet: + def _skip_traverse(self, all_nodes: NodeList, skip_nodes: list[str]) -> NodeSet: """ Skip certain nodes in graph based on settings """ @@ -880,7 +879,7 @@ def _collect_nodes(self, start: str | None, end: str | None) -> NodeList: return nodes - def run_nodes(self, start: str | None = None, end: str | None = None): + def run_nodes(self, start: str | None = None, end: str | None = None) -> None: """ Run part of the model from `start` node to `end` node. If `start` is None then we start from the beginning of the model. If `end` is None then we @@ -901,7 +900,7 @@ def run_nodes(self, start: str | None = None, end: str | None = None): if node in self.fusions: cur_nodes.update(self.fusions[node]) - output_names = [] + output_names: list[str] = [] if self.settings.return_intermediate: output_names = [node.name for node in nodes] @@ -914,14 +913,14 @@ def run_nodes(self, start: str | None = None, end: str | None = None): ) as e: print(e) - def print_report(self, report: list[str]): + def print_report(self, report: list[str]) -> None: for i in range(len(report)): if i > 0: print(" . " + report[i]) else: print(report[i]) - def print_reports(self): + def print_reports(self) -> None: for report in self.reports: self.print_report(report) @@ -929,7 +928,7 @@ def minimize( self, start: str | None = None, end: str | None = None, - skip_nodes: list | None = None, + skip_nodes: list[str] | None = None, find_last_node: bool | None = None, ) -> NodeSet: """ diff --git a/torch/fx/passes/operator_support.py b/torch/fx/passes/operator_support.py index 813e22090ba7a..64d40737dc275 100644 --- a/torch/fx/passes/operator_support.py +++ b/torch/fx/passes/operator_support.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import abc import typing as t @@ -67,7 +66,7 @@ class OperatorSupport(OperatorSupportBase): _support_dict: SupportDict - def __init__(self, support_dict: SupportDict | None = None): + def __init__(self, support_dict: SupportDict | None = None) -> None: self._support_dict = support_dict or {} def is_node_supported( @@ -163,7 +162,7 @@ def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: any of it reports False. """ - def _chain(submods, node) -> bool: + def _chain(submods: t.Mapping[str, torch.nn.Module], node: torch.fx.Node) -> bool: return all(x.is_node_supported(submods, node) for x in op_support) return create_op_support(_chain) @@ -176,7 +175,9 @@ def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: any of it reports True. """ - def _any_chain(submods, node) -> bool: + def _any_chain( + submods: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: return any(x.is_node_supported(submods, node) for x in op_support) return create_op_support(_any_chain) @@ -219,7 +220,7 @@ def _decline_if_node_in_names( return create_op_support(_decline_if_node_in_names) -def _get_arg_dtype(arg: torch.fx.Node) -> t.Any: +def _get_arg_dtype(arg: torch.fx.Node) -> torch.dtype: if not isinstance(arg, torch.fx.Node): raise AssertionError(f"Expected torch.fx.Node, got {type(arg)}") tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr] diff --git a/torch/fx/passes/pass_manager.py b/torch/fx/passes/pass_manager.py index 918f628d73eff..b413a83744beb 100644 --- a/torch/fx/passes/pass_manager.py +++ b/torch/fx/passes/pass_manager.py @@ -1,8 +1,13 @@ -# mypy: allow-untyped-defs import logging from collections.abc import Callable from functools import wraps from inspect import unwrap +from typing import Any, Concatenate, ParamSpec, TypeVar + + +_P = ParamSpec("_P") +_R = TypeVar("_R") +_T = TypeVar("_T") logger = logging.getLogger(__name__) @@ -19,7 +24,9 @@ # for callables which modify object inplace and return something other than # the object on which they act -def inplace_wrapper(fn: Callable) -> Callable: +def inplace_wrapper( + fn: Callable[Concatenate[_T, _P], Any], +) -> Callable[Concatenate[_T, _P], _T]: """ Convenience wrapper for passes which modify an object inplace. This wrapper makes them return the modified object instead. @@ -32,33 +39,32 @@ def inplace_wrapper(fn: Callable) -> Callable: """ @wraps(fn) - def wrapped_fn(gm): - fn(gm) + def wrapped_fn(gm: _T, *args: _P.args, **kwargs: _P.kwargs) -> _T: + fn(gm, *args, **kwargs) return gm return wrapped_fn -def log_hook(fn: Callable, level=logging.INFO) -> Callable: +def log_hook(fn: Callable[_P, _R], level: int = logging.INFO) -> Callable[_P, _R]: """ Logs callable output. - This is useful for logging output of passes. Note inplace_wrapper replaces + This is useful for logging output of passes. Note ``inplace_wrapper`` replaces the pass output with the modified object. If we want to log the original - output, apply this wrapper before inplace_wrapper. + output, apply this wrapper before ``inplace_wrapper``. + Example:: - ``` - def my_pass(d: Dict) -> bool: - changed = False - if "foo" in d: - d["foo"] = "bar" - changed = True - return changed + def my_pass(d: Dict) -> bool: + changed = False + if "foo" in d: + d["foo"] = "bar" + changed = True + return changed - pm = PassManager(passes=[inplace_wrapper(log_hook(my_pass))]) - ``` + pm = PassManager(passes=[inplace_wrapper(log_hook(my_pass))]) Args: fn (Callable[Type1, Type2]) @@ -69,8 +75,8 @@ def my_pass(d: Dict) -> bool: """ @wraps(fn) - def wrapped_fn(gm): - val = fn(gm) + def wrapped_fn(*args: _P.args, **kwargs: _P.kwargs) -> _R: + val = fn(*args, **kwargs) logger.log(level, "Ran pass %s\t Return value: %s", fn, val) return val @@ -78,10 +84,10 @@ def wrapped_fn(gm): def loop_pass( - base_pass: Callable, + base_pass: Callable[[_T], _T], n_iter: int | None = None, - predicate: Callable | None = None, -): + predicate: Callable[[_T], bool] | None = None, +) -> Callable[[_T], _T]: """ Convenience wrapper for passes which need to be applied multiple times. @@ -97,7 +103,7 @@ def loop_pass( raise AssertionError("Exactly one of `n_iter`or `predicate` must be specified.") @wraps(base_pass) - def new_pass(source): + def new_pass(source: _T) -> _T: output = source if n_iter is not None and n_iter > 0: for _ in range(n_iter): @@ -120,8 +126,9 @@ def new_pass(source): # Implemented as 'depends on' operators. A constraint is satisfied iff a list # has a valid partial ordering according to this comparison operator. def _validate_pass_schedule_constraint( - constraint: Callable[[Callable, Callable], bool], passes: list[Callable] -): + constraint: Callable[[Callable[..., Any], Callable[..., Any]], bool], + passes: list[Callable[..., Any]], +) -> None: for i, a in enumerate(passes): for j, b in enumerate(passes[i + 1 :]): if constraint(a, b): @@ -133,42 +140,45 @@ def _validate_pass_schedule_constraint( ) -def this_before_that_pass_constraint(this: Callable, that: Callable): +def this_before_that_pass_constraint( + this: Callable[..., Any], that: Callable[..., Any] +) -> Callable[[Callable[..., Any], Callable[..., Any]], bool]: """ Defines a partial order ('depends on' function) where `this` must occur before `that`. """ - def depends_on(a: Callable, b: Callable): + def depends_on(a: Callable[..., Any], b: Callable[..., Any]) -> bool: return a != that or b != this return depends_on -def these_before_those_pass_constraint(these: Callable, those: Callable): +def these_before_those_pass_constraint( + these: Callable[..., Any], those: Callable[..., Any] +) -> Callable[[Callable[..., Any], Callable[..., Any]], bool]: """ - Defines a partial order ('depends on' function) where `these` must occur - before `those`. Where the inputs are 'unwrapped' before comparison. + Defines a partial order ('depends on' function) where ``these`` must occur + before ``those``. Where the inputs are 'unwrapped' before comparison. + + For example, the following pass list and constraint list would be invalid:: - For example, the following pass list and constraint list would be invalid. - ``` - passes = [ - loop_pass(pass_b, 3), - loop_pass(pass_a, 5), - ] + passes = [ + loop_pass(pass_b, 3), + loop_pass(pass_a, 5), + ] - constraints = [these_before_those_pass_constraint(pass_a, pass_b)] - ``` + constraints = [these_before_those_pass_constraint(pass_a, pass_b)] Args: these (Callable): pass which should occur first those (Callable): pass which should occur later Returns: - depends_on (Callable[[Object, Object], bool] + depends_on (Callable[[Object, Object], bool]) """ - def depends_on(a: Callable, b: Callable): + def depends_on(a: Callable[..., Any], b: Callable[..., Any]) -> bool: return unwrap(a) != those or unwrap(b) != these return depends_on @@ -190,40 +200,42 @@ class PassManager: `this_before_that_pass_constraint` for example. """ - passes: list[Callable] - constraints: list[Callable] + passes: list[Callable[..., Any]] + constraints: list[Callable[..., Any]] _validated: bool = False def __init__( self, - passes=None, - constraints=None, - ): + passes: list[Callable[..., Any]] | None = None, + constraints: list[Callable[..., Any]] | None = None, + ) -> None: self.passes = passes or [] self.constraints = constraints or [] @classmethod - def build_from_passlist(cls, passes): + def build_from_passlist(cls, passes: list[Callable[..., Any]]) -> "PassManager": pm = PassManager(passes) # TODO(alexbeloi): add constraint management/validation return pm - def add_pass(self, _pass: Callable): + def add_pass(self, _pass: Callable[..., Any]) -> None: self.passes.append(_pass) self._validated = False - def add_constraint(self, constraint): + def add_constraint(self, constraint: Callable[..., Any]) -> None: self.constraints.append(constraint) self._validated = False - def remove_pass(self, _passes: list[str]): + def remove_pass(self, _passes: list[str]) -> None: if _passes is None: return passes_left = [ps for ps in self.passes if ps.__name__ not in _passes] self.passes = passes_left self._validated = False - def replace_pass(self, _target, _replacement): + def replace_pass( + self, _target: Callable[..., Any], _replacement: Callable[..., Any] + ) -> None: passes_left = [] for ps in self.passes: if ps.__name__ == _target.__name__: @@ -233,7 +245,7 @@ def replace_pass(self, _target, _replacement): self.passes = passes_left self._validated = False - def validate(self): + def validate(self) -> None: """ Validates that current pass schedule defined by `self.passes` is valid according to all constraints in `self.constraints` @@ -244,7 +256,7 @@ def validate(self): _validate_pass_schedule_constraint(constraint, self.passes) self._validated = True - def __call__(self, source): + def __call__(self, source: Any) -> Any: self.validate() out = source for _pass in self.passes: diff --git a/torch/fx/passes/regional_inductor.py b/torch/fx/passes/regional_inductor.py index aa830c705ef72..72561dd5cea1d 100644 --- a/torch/fx/passes/regional_inductor.py +++ b/torch/fx/passes/regional_inductor.py @@ -1,7 +1,12 @@ -# mypy: allow-untyped-defs - +import contextlib import functools import logging +from collections.abc import Callable, Iterator, Mapping +from typing import Any, ParamSpec, TypeVar + + +_P = ParamSpec("_P") +_R = TypeVar("_R") import torch from torch.fx._compatibility import compatibility @@ -15,34 +20,26 @@ # standalone_inductor returns a callable class object - this does not sit well # with Fx graph node op call_function which expects a function. So this is just # a wrapper function to make Fx graph codegen happy. -def _dummy_wrapper(fn): +def _dummy_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]: @functools.wraps(fn) - def inner(*args, **kwargs): + def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: return fn(*args, **kwargs) return inner -def _partition_by_supported_nodes(gm, supported_ops, prefix): - from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner - from torch.fx.passes.utils.fuser_utils import fuse_by_partitions - - partitioner = CapabilityBasedPartitioner( - gm, supported_ops, allows_single_node_partition=True - ) - - candidate_partitions = partitioner.propose_partitions() - partitioned_gm = fuse_by_partitions( - partitioner.graph_module, - [partition.nodes for partition in candidate_partitions], - prefix=prefix, - always_return_tuple=True, - ) - - return partitioned_gm +@contextlib.contextmanager +def _disable_remat_for_regional_subcompile() -> Iterator[None]: + # In torch.compile, regional_inductor subcompiles run after the enclosing + # non-strict full graph has already been partitioned, so any graph-SAC + # remat pass has already run before we reach this nested compile. + # Rerunning remat here can see stage-2-reordered backward nodes that + # violate remat's contiguous-backward-region assumption. + with torch._functorch.config.patch(remat_using_tags_for_fwd_loss_bwd_graph=False): + yield -def _compile_submod(gm, prefix): +def _compile_submod(gm: torch.fx.GraphModule, prefix: str) -> torch.fx.GraphModule: from torch._inductor.standalone_compile import AOTCompiledArtifact for node in gm.graph.nodes: @@ -61,7 +58,7 @@ def _compile_submod(gm, prefix): # Get inductor configs from annotation # TODO we should change partition when there are multiple differently # annotated regions. - inductor_options = {} + inductor_options: dict[str, Any] = {} for sub_node in submod.graph.nodes: if hasattr(sub_node, "meta") and sub_node.meta.get("custom", None): custom = sub_node.meta["custom"] @@ -92,7 +89,10 @@ def _compile_submod(gm, prefix): f"Available config keys can be found in torch._inductor.config" ) - with inductor_config.patch(inductor_options): + with ( + inductor_config.patch(inductor_options), + _disable_remat_for_regional_subcompile(), + ): compiled_fn = torch._inductor.standalone_compile( submod, fake_inputs, @@ -118,8 +118,8 @@ def _compile_submod(gm, prefix): return gm -def _needs_inductor_compile(node: torch.fx.Node): - return ( +def _needs_inductor_compile(node: torch.fx.Node) -> bool: + return bool( node.op not in ("placeholder", "output") and hasattr(node, "meta") and node.meta.get("custom", None) @@ -133,30 +133,66 @@ class _RegionScooper: """ @staticmethod - def scoop_regions(gm): - from torch.fx.passes.operator_support import OperatorSupport - - found_marked_node = False + def scoop_regions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner + from torch.fx.passes.operator_support import create_op_support + from torch.fx.passes.utils.fuser_utils import fuse_by_partitions + + # Group tagged nodes by region ID. The region ID comes from the + # optional "inductor_region" key inside the compile_with_inductor + # annotation. When absent, all tagged nodes share a single default region + _DEFAULT_REGION = object() + regions: dict[object, set[torch.fx.Node]] = {} for node in gm.graph.nodes: if _needs_inductor_compile(node): - found_marked_node = True - break + compile_value = node.meta["custom"]["compile_with_inductor"] + if ( + isinstance(compile_value, dict) + and "inductor_region" in compile_value + ): + rid = compile_value["inductor_region"] + else: + rid = _DEFAULT_REGION + regions.setdefault(rid, set()).add(node) - if not found_marked_node: + if not regions: logger.info("No inductor marked nodes found") return gm - class InductorMarkedNodes(OperatorSupport): - def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: - return _needs_inductor_compile(node) - - marked_nodes = InductorMarkedNodes() - return _partition_by_supported_nodes( - gm, marked_nodes, "__marked_inductor_submod" + # Run CapabilityBasedPartitioner per region to get cycle-safe partitions + # without merging across region boundaries. + def _is_in_region( + region_nodes: set[torch.fx.Node], + ) -> Callable[[Mapping[str, torch.nn.Module], torch.fx.Node], bool]: + def is_node_supported( + _submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + return node in region_nodes + + return is_node_supported + + all_partitions: list[dict[torch.fx.Node, int | None]] = [] + for region_nodes in regions.values(): + support = create_op_support(_is_in_region(region_nodes)) + partitioner = CapabilityBasedPartitioner( + gm, support, allows_single_node_partition=True + ) + for partition in partitioner.propose_partitions(): + all_partitions.append(partition.nodes) + + return fuse_by_partitions( + gm, + all_partitions, + prefix="__marked_inductor_submod", + always_return_tuple=True, ) @staticmethod - def recursively_scoop_regions(gm): + def recursively_scoop_regions( + gm: torch.fx.GraphModule, _processed: set[int] | None = None + ) -> torch.fx.GraphModule: + if _processed is None: + _processed = set() for node in gm.graph.find_nodes(op="get_attr"): if _needs_inductor_compile(node): # If the get_attr itself is marked for compile, the outer graph will @@ -164,12 +200,17 @@ def recursively_scoop_regions(gm): # regional inductor compiles that do not work well. continue submod = getattr(gm, node.target) - if isinstance(submod, torch.fx.GraphModule): - _RegionScooper.recursively_scoop_regions(submod) + # Track by id: multiple get_attr nodes may reference the same GraphModule + if ( + isinstance(submod, torch.fx.GraphModule) + and id(submod) not in _processed + ): + _processed.add(id(submod)) + _RegionScooper.recursively_scoop_regions(submod, _processed) return _RegionScooper.scoop_regions(gm) - def __call__(self, gm): + def __call__(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: with torch.fx.traceback.preserve_node_meta(enable=False): return _RegionScooper.recursively_scoop_regions(gm) @@ -180,7 +221,7 @@ class _RegionCompiler: """ @staticmethod - def compile_region(gm): + def compile_region(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: from torch.fx.graph import _BoxedCodeGen gm = _compile_submod(gm, "__marked_inductor_submod") @@ -189,7 +230,7 @@ def compile_region(gm): return gm @staticmethod - def recursively_compile_regions(gm): + def recursively_compile_regions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: # Find if the graph module has a scooped out region found_region = False for node in gm.graph.find_nodes(op="call_module"): @@ -208,35 +249,41 @@ def recursively_compile_regions(gm): return _RegionCompiler.compile_region(gm) return gm - def __call__(self, gm): + def __call__(self, gm: torch.fx.GraphModule) -> torch.fx.GraphModule: with torch.fx.traceback.preserve_node_meta(enable=False): return _RegionCompiler.recursively_compile_regions(gm) -def _create_inductor_marked_regions(gm): +def _create_inductor_marked_regions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: with torch.fx.traceback.preserve_node_meta(enable=False): return _RegionScooper()(gm) -def _compile_inductor_marked_regions(gm): +def _compile_inductor_marked_regions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: with torch.fx.traceback.preserve_node_meta(enable=False): return _RegionCompiler()(gm) @compatibility(is_backward_compatible=False) -def regional_inductor(gm, *example_args): +def regional_inductor( + gm: torch.fx.GraphModule, *example_args: object +) -> torch.fx.GraphModule: """ Scoops out inductor marked regions and compiles them with inductor. - Inductor options should be provided via the annotation API: - with fx_traceback.annotate({ - "compile_with_inductor": { - "inductor_configs": { - "max_autotune": True, - "triton.cudagraphs": False + Inductor options should be provided via the annotation API:: + + with fx_traceback.annotate( + { + "compile_with_inductor": { + "inductor_configs": { + "max_autotune": True, + "triton.cudagraphs": False, + } + } } - } - }): + ): + ... """ # fuser utils create new nodes using create_proxy which retains the seq_nr @@ -248,5 +295,5 @@ def regional_inductor(gm, *example_args): if torch._functorch.config.force_autograd_cache: from torch._inductor.output_code import RegionalOutputCode - gm = RegionalOutputCode(gm) + return RegionalOutputCode(gm) # type: ignore[return-value] return gm diff --git a/torch/fx/passes/regional_inductor_invoke_subgraph.py b/torch/fx/passes/regional_inductor_invoke_subgraph.py index b2a6e5c3a9b28..28d614d8a8425 100644 --- a/torch/fx/passes/regional_inductor_invoke_subgraph.py +++ b/torch/fx/passes/regional_inductor_invoke_subgraph.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - import copy import logging from collections import defaultdict @@ -8,7 +6,10 @@ from torch._inductor.standalone_compile import AOTCompiledArtifact from torch.compiler._cache import CacheArtifactManager from torch.fx._compatibility import compatibility -from torch.fx.passes.regional_inductor import _dummy_wrapper +from torch.fx.passes.regional_inductor import ( + _disable_remat_for_regional_subcompile, + _dummy_wrapper, +) logger = logging.getLogger(__name__) @@ -18,7 +19,7 @@ def _compile_submod( gm: torch.fx.GraphModule, subgraph: str, subgraph_users: list[torch.fx.Node] -): +) -> torch.fx.GraphModule: """ Compiles subgraph submodule in gm. subgraph is used by subgraph_users. subgraph_users must all be torch.ops.higher_order.invoke_subgraph HOP. @@ -56,7 +57,7 @@ def _compile_submod( compile_config, ) - def get_compiled_fn(): + def get_compiled_fn() -> AOTCompiledArtifact: context = torch._guards.TracingContext.get() if context.fake_mode is None: raise AssertionError("context.fake_mode is None") @@ -67,6 +68,7 @@ def get_compiled_fn(): torch._guards.tracing(context), CacheArtifactManager.with_fresh_cache(), torch._functorch.config.patch("bundled_autograd_cache", True), + _disable_remat_for_regional_subcompile(), ): # compile_fx can mutate gm gm = copy.deepcopy(submod) @@ -96,28 +98,33 @@ def get_compiled_fn(): return gm -def _needs_inductor_compile(node: torch.fx.Node): +def _needs_inductor_compile(node: torch.fx.Node) -> bool: # TODO: maybe we could change to check # node.meta.get("partitioner_tag") != "is_forward" # if the tag is relibable - return ( - node.op not in ("placeholder", "output") - and hasattr(node, "meta") - and node.meta.get("custom", None) - and node.meta["custom"].get("nested_region_config", None) - and node.meta["custom"]["nested_region_config"].fw_compiler - and node.meta.get("partitioner_tag") != "is_backward" - ) or ( - node.op not in ("placeholder", "output") - and hasattr(node, "meta") - and node.meta.get("custom", None) - and node.meta["custom"].get("nested_region_config", None) - and node.meta["custom"]["nested_region_config"].bw_compiler - and node.meta.get("partitioner_tag") == "is_backward" + return bool( + ( + node.op not in ("placeholder", "output") + and hasattr(node, "meta") + and node.meta.get("custom", None) + and node.meta["custom"].get("nested_region_config", None) + and node.meta["custom"]["nested_region_config"].fw_compiler + and node.meta.get("partitioner_tag") != "is_backward" + ) + or ( + node.op not in ("placeholder", "output") + and hasattr(node, "meta") + and node.meta.get("custom", None) + and node.meta["custom"].get("nested_region_config", None) + and node.meta["custom"]["nested_region_config"].bw_compiler + and node.meta.get("partitioner_tag") == "is_backward" + ) ) -def _compile_invoke_subgraph_nodes_with_inductor(gm): +def _compile_invoke_subgraph_nodes_with_inductor( + gm: torch.fx.GraphModule, +) -> torch.fx.GraphModule: map_subgraph_to_nodes = defaultdict(list) subgraphs: set[str] = set() @@ -140,7 +147,9 @@ def _compile_invoke_subgraph_nodes_with_inductor(gm): return gm -def _recursive_compile_invoke_subgraph_nodes(gm): +def _recursive_compile_invoke_subgraph_nodes( + gm: torch.fx.GraphModule, +) -> torch.fx.GraphModule: for node in gm.graph.find_nodes(op="get_attr"): if _needs_inductor_compile(node): # If the get_attr itself is marked for compile, the outer graph will @@ -155,7 +164,9 @@ def _recursive_compile_invoke_subgraph_nodes(gm): @compatibility(is_backward_compatible=False) -def regional_inductor_invoke_subgraph(gm, *example_args): +def regional_inductor_invoke_subgraph( + gm: torch.fx.GraphModule, *example_args: object +) -> torch.fx.GraphModule: """ Compile invoke_subgraph nodes if they have custom compiler specified in node.meta["nested_region_config"].bw_compiler or fw_compiler @@ -165,6 +176,7 @@ def regional_inductor_invoke_subgraph(gm, *example_args): with torch.fx.traceback.preserve_node_meta(enable=False): compiled_gm = _recursive_compile_invoke_subgraph_nodes(gm) # TODO: might not need this boxed_nop after we switch to _RegionCompiler + # pyrefly: ignore [bad-return] return torch._dynamo.backends.debugging.boxed_nop( compiled_gm, example_inputs=[] ) diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py index 17dde9256f5fa..04e98ef3ed549 100644 --- a/torch/fx/passes/reinplace.py +++ b/torch/fx/passes/reinplace.py @@ -1,8 +1,6 @@ -# mypy: allow-untyped-defs import _operator import itertools from collections import defaultdict -from collections.abc import Callable from enum import Enum from typing import Any @@ -24,7 +22,7 @@ class _ViewType(Enum): MultiOutputView = 2 -def _is_view_op(tgt): +def _is_view_op(tgt: object) -> bool | None: if tgt is not None and isinstance(tgt, torch._ops.OpOverload): schema = tgt._schema if len(schema.arguments) > 0: @@ -35,7 +33,7 @@ def _is_view_op(tgt): ) -def _get_view_type(tgt) -> _ViewType: +def _get_view_type(tgt: object) -> _ViewType: if tgt is not None and isinstance(tgt, torch._ops.OpOverload): schema = tgt._schema if len(schema.arguments) > 0: @@ -61,7 +59,7 @@ def _get_view_type(tgt) -> _ViewType: # to sanity check that our aliasing information is correct. @compatibility(is_backward_compatible=False) class _FunctionalizationMetadataProp(torch.fx.Interpreter): - def run_node(self, node: Node): + def run_node(self, node: Node) -> Any: self.node_counter += 1 result = super().run_node(node) node.meta["fake_result"] = result @@ -122,7 +120,7 @@ def run_node(self, node: Node): raise AssertionError("view_storage != base_storage") return result - def propagate(self, *args): + def propagate(self, *args: object) -> Any: self.multi_output_view_nodes = {} self.node_counter = -1 @@ -133,7 +131,9 @@ def propagate(self, *args): return super().run(*fake_args) -def _schemas_match(functional_schema, inplace_schema): +def _schemas_match( + functional_schema: torch._C.FunctionSchema, inplace_schema: torch._C.FunctionSchema +) -> bool: names_match = ( inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name @@ -160,7 +160,7 @@ def _schemas_match(functional_schema, inplace_schema): # - mutating ops (e.g. _fused_moving_avg_obs_fq_helper) # - out= ops (e.g. angle -> angle.out) # TODO: we should also figure this info out using torchgen. -def _maybe_get_inplace_op(op): +def _maybe_get_inplace_op(op: object) -> torch._ops.OpOverload | None: # __module__ seems broken; it returns torch._ops.aten which doesn't exist if not isinstance(op, torch._ops.OpOverload): return None @@ -202,7 +202,7 @@ def _maybe_get_inplace_op(op): return inplace_op -_VIEW_INVERSE_MAP: dict[Callable[..., Any], Callable[..., Any]] = { +_VIEW_INVERSE_MAP: dict[torch._ops.OpOverload, torch._ops.OpOverload] = { torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, @@ -213,8 +213,8 @@ def _maybe_get_inplace_op(op): # This function, given a set of set of (aliased) tensor nodes, # Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index # in the node ordering. -def _get_all_later_node_usages(tensor_aliases: set[Node], op_index: int): - def _add_if_tensor(x, set_): +def _get_all_later_node_usages(tensor_aliases: set[Node], op_index: int) -> set[Node]: + def _add_if_tensor(x: object, set_: set[StorageWeakRef]) -> None: if isinstance(x, FakeTensor): set_.add(StorageWeakRef(x._typed_storage())) @@ -249,7 +249,7 @@ def _add_if_tensor(x, set_): def _get_view_inverse_node_usages( later_node_usages: set[Node], self_aliases: set[Node] ) -> set[Node]: - def matching_view_metadata(a, b): + def matching_view_metadata(a: FakeTensor, b: FakeTensor) -> bool: return ( a.size() == b.size() and a.stride() == b.stride() @@ -308,105 +308,111 @@ def matching_view_metadata(a, b): @compatibility(is_backward_compatible=True) -def reinplace(gm, *sample_args): - """ +def reinplace( + gm, *sample_args +): # pyrefly: ignore[unannotated-parameter, unannotated-return] + r""" Given an fx.GraphModule, modifies it to perform "reinplacing", mutating the nodes of the graph. - We look for out-of-place op call sites like `b = a.add(...)`, - and convert them to be inplace (`b = a.add_(...)`), + We look for out-of-place op call sites like ``b = a.add(...)``, + and convert them to be inplace (``b = a.add_(...)``), as long as the input to the current operator ("a") isn't reused anywhere later in the graph. This pass currently expects to operate on a **functional, ATen** graph. - This can be obtained by running `make_fx(functionalize(f))`. + This can be obtained by running ``make_fx(functionalize(f))``. Sample inputs are needed to determine aliasing relationships of the inputs. - In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the + In general, we can't reinplace node ``b = a.add(...)`` if "a" aliases any of the inputs to the program. - Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows: + Given a node ``b = foo(a, args...)`` the algorithm for re-inplacing is as follows: + + **(1)** Perform some initial checks on the metadata of "a" and "args..." + that can disqualify them from being reinplaced. + + - **(1a)** Check that the self argument we're attempting to reinplace + has acceptable dtype/size metadata to reinplace with. + + For example, if we have:: + + a = torch.ones(1) + b = torch.ones(10) + out = torch.add(a, b) + + We can't turn that into ``a.add_(b)`` because that would require resizing "a". - (1) Perform some initial checks on the metadata of "a" and "args..." - that can disqualify them from being reinplaced. + Similarly, we can't convert ``torch.ge(a, b)`` into ``a.ge_(b)``, + because that would require changing a's dtype (from e.g. float32 to bool). + Note that in this specific example, we could technically do better.. - (1a) Check that the self argument we're attempting to reinplace - has acceptable dtype/size metadata to reinplace with. + If we see the pattern:: - For example, if we have: - a = torch.ones(1) - b = torch.ones(10) - out = torch.add(a, b) - We can't turn that into - a.add_(b) - Because that would require resizing "a". + a_1 = a.ge(b) + a_2 = aten._to_copy(a_1, a.dtype) - Similarly, we can't convert torch.ge(a, b) into a.ge_(b), - because that would require changing a's dtype (from e.g. float32 to bool). - Note that in this specific example, we could technically do better.. + Then this should be valid to completely re-inplace + (this is exactly what functionalization will emit when it sees ``a.ge_(b)``). - If we see the pattern: - a_1 = a.ge(b) - a_2 = aten._to_copy(a_1, a.dtype) - Then we this should be valid to completely re-inplace - (this is exactly what functionalization will emit when it sees a.ge_(b)). + This optimization is only really important for user programs + that directly use inplace comparison ops though. - This optimization is only really important for user programs - that directly use inplace comparison ops though. + We also cannot re-inplace on tensors that have overlapping memory, + e.g. ``torch.ones(1).expand(4, 4).add_(1)``. - We also cannot re-inplace on tensors that have overlapping memory, - e.g. torch.ones(1).expand(4, 4).add_(1) + - **(1b)** Check if "a" is an alias of any of the program inputs. - (1b) Check if "a" is an alias of any of the program inputs. + If it is, skip and move to the next node. + Inplace'ing an op that would cause it to mutate a program is not sound, + because that would be a side effect visible to the user. - If it is, skip and move to the next node. - Inplace'ing an op that would cause it to mutate a program is not sound, - because that would be a side effect visible to the user. + NOTE: there's a future optimization that we should make: + if "a" is a (alias of a) program input, but later in the program + there is a node that looks like ``a.copy_(...)``, + then re-inplacing is ok to do - we are temporarily reusing a's buffer, + which will later be overwritten by the ``copy_()`` call. - NOTE: there's a future optimization that we should make: - if "a" is a (alias of a) program input, but later in the program - there is a node that looks like "a.copy_(...)", - Then re-inplacing is ok to do - we are temporarily reusing a's buffer, - which will later be overwritten by the copy_() call. + This will be an important optimization to have for programs that mutate + their inputs. It currently isn't implemented though. - This will be an important optimization to have for programs that mutate - their inputs. It currently isn't implemented though. + - **(1c)** Check if "a" and "args..." alias. - (1c) Check if "a" and "args..." alias + For example, re-inplacing to create code like the below + isn't guaranteed to be sound:: - For example, re-inplacing to create code like the below - isn't guaranteed to be sound: + aten.mul_(a, a) - aten.mul_(a, a) + **(2)** Check that "a" and all of its outstanding aliases are not used anywhere + later in the graph. If this is the case, then it's safe to re-inplace + to ``b = foo_(a)``. - (2) Check that "a" and all of its outstanding aliases are not used anywhere - later in the graph. If this is the case, then it's safe to re-inplace - to "b = foo_(a)". + There are a few caveats to this, explained in more detail below: - There are a few caveats to this, explained in more detail below: - (a) If "a" is used later as an argument to a view op, that is okay. - It's only a problem if "a" (or that view) is later passed - into a normal operator, or if it is returned as the program output. - (b) If "a" is a repeat argument in `foo()`, then don't reinplace. - Most ATen kernels don't make any guarantees that this is sound, - e.g. if you do aten.mul_(a, a). - So we'll just ban re-inplacing in this case. - It's only a problem if "a" (or that view) is later passed - (c) If "a" is used as an input into a view "inverse" / "scatter" - operator, it is potentially fine to re-inplace - (and remove that scatter operator from the graph). - See below for a more detailed example. + - (a) If "a" is used later as an argument to a view op, that is okay. + It's only a problem if "a" (or that view) is later passed + into a normal operator, or if it is returned as the program output. + - (b) If "a" is a repeat argument in ``foo()``, then don't reinplace. + Most ATen kernels don't make any guarantees that this is sound, + e.g. if you do ``aten.mul_(a, a)``. + So we'll just ban re-inplacing in this case. + - (c) If "a" is used as an input into a view "inverse" / "scatter" + operator, it is potentially fine to re-inplace + (and remove that scatter operator from the graph). + See below for a more detailed example. - NOTE: there is an optimization in this step that is crucial - to fully recovering performance from functionalization. + NOTE: there is an optimization in this step that is crucial + to fully recovering performance from functionalization. + + Given this program:: - Given this program: def f(x): a = torch.ops.aten.add(x, x) b = torch.ops.aten.diagonal(a) torch.ops.aten.fill_(b, 0) return d - Functionalization will emit the following: + Functionalization will emit the following:: + def f(x): a = torch.ops.aten.add(x, x) b = torch.ops.aten.diagonal(a, 0, 1) @@ -414,99 +420,115 @@ def f(x): a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1) return a_updated - Ordinarily, we would not be able to reinplace the fill, - because "b" aliases with "a" which is used by the diagonal_scatter call. - - "re-inplacing" is on the hook for figuring out that it is ok to - completely, the expensive diagonal_scatter call, if we re-inplace the add(). - - So, for every `alias in alias_set(a)`, instead of checking - that "alias" is not used anywhere later in the graph, - we check that - EITHER: - (a) alias is not used anywhere later in the graph - OR: - (b) alias is used exactly once later on in the graph, - in the following op: - - out = foo_scatter(alias, x, args...) - - where the following must hold: - (i) "foo_scatter" is the "inverse" operator for foo. - This only applies to "foo" ops that are view operators, - which view into a subset of the original tensor's memory. - In practice, there are ~4 operators where this applies: - diagonal -> diagonal_scatter - slice -> slice_scatter - select -> select_scatter - as_strided -> as_strided_scatter - (ii) "args..." are the same between the foo() and foo_scatter() calls. - - (3) Perform the actual re-inplacing on foo! - - (3b) is the common case, but special care is needed for {view}_scatter (3a) - - (3a) {view}_scatter ops. - - Consider this program: - a = torch.zeros(2, 2) - b = torch.ones(2) - a[0] = b - - Post functionalization, that will look like: - a = torch.zeros(2) - b = torch.ones(1) - a_updated = torch.select_scatter(a, b, 0, 0) - - In this case though, there is no "functional" op to re-inplace! - Instead, we'd like to directly remove toe select_scatter call. - We already know from (3) that this is valid, - because "a" has no later usages in the graph. - - We perform the re-inplacing on the {view}_scatter op like so - Before: - a_updated = torch.select_scatter(a, b, args...) - After: - a_slice = a.select(a, args...) - a_slice.copy_(b) - - (3b) Otherwise, replace the functional op with its inplace variant. - Before: - b = foo(a, args...) - After: - a.foo_(args...) - - (4) Finally, after converting either: - Before: - b = foo(a) - After: - foo_(a) - or - Before: - b = {slice}_scatter(a, mutated_slice, args...) - After: - slice = {slice}(a, args...) - slice.copy_(mutated_slice) - - We now need to find all later nodes that use "b" as an argument - and update them to take in "a" instead. - - Note that for the majority of inplace ops, this isn't actually necessary - (because most inplace ops return "self" as their output). - This isn't generally true for all mutable ops though, which is why - we need to actually replace all of the arguments. - - We also need to update our metadata of Dict[StorageWeakRef, Set[Node]], - That maps a given tensor storage to the set of all nodes that take in that storage - as an input. - Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused - together. - - (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them" - during step (3) get manually deleted from the graph. - Their outputs are no longer used, so technically standard DCE would be able - to do this, but we can no longer run FX's DCE pass now that we have mutable - ops in the graph. + Ordinarily, we would not be able to reinplace the fill, + because "b" aliases with "a" which is used by the diagonal_scatter call. + + "re-inplacing" is on the hook for figuring out that it is ok to + completely remove the expensive diagonal_scatter call, if we re-inplace + the add(). + + So, for every ``alias in alias_set(a)``, instead of checking + that "alias" is not used anywhere later in the graph, + we check that EITHER: + + - (a) alias is not used anywhere later in the graph, OR + - (b) alias is used exactly once later on in the graph, + in the following op:: + + out = foo_scatter(alias, x, args...) + + where the following must hold: + + - (i) ``foo_scatter`` is the "inverse" operator for foo. + This only applies to "foo" ops that are view operators, + which view into a subset of the original tensor's memory. + In practice, there are ~4 operators where this applies:: + + diagonal -> diagonal_scatter + slice -> slice_scatter + select -> select_scatter + as_strided -> as_strided_scatter + + - (ii) "args..." are the same between the ``foo()`` and + ``foo_scatter()`` calls. + + **(3)** Perform the actual re-inplacing on foo! + + (3b) is the common case, but special care is needed for + ``{view}_scatter`` (3a). + + - **(3a)** ``{view}_scatter`` ops. + + Consider this program:: + + a = torch.zeros(2, 2) + b = torch.ones(2) + a[0] = b + + Post functionalization, that will look like:: + + a = torch.zeros(2) + b = torch.ones(1) + a_updated = torch.select_scatter(a, b, 0, 0) + + In this case though, there is no "functional" op to re-inplace! + Instead, we'd like to directly remove the select_scatter call. + We already know from (3) that this is valid, + because "a" has no later usages in the graph. + + We perform the re-inplacing on the ``{view}_scatter`` op like so. + + Before:: + + a_updated = torch.select_scatter(a, b, args...) + + After:: + + a_slice = a.select(a, args...) + a_slice.copy_(b) + + - **(3b)** Otherwise, replace the functional op with its inplace variant. + + Before:: + + b = foo(a, args...) + + After:: + + a.foo_(args...) + + **(4)** Finally, after converting either:: + + # Before: # After: + b = foo(a) foo_(a) + + or:: + + # Before: + b = {slice}_scatter(a, mutated_slice, args...) + # After: + slice = {slice}(a, args...) + slice.copy_(mutated_slice) + + We now need to find all later nodes that use "b" as an argument + and update them to take in "a" instead. + + Note that for the majority of inplace ops, this isn't actually necessary + (because most inplace ops return "self" as their output). + This isn't generally true for all mutable ops though, which is why + we need to actually replace all of the arguments. + + We also need to update our metadata of ``Dict[StorageWeakRef, Set[Node]]``, + that maps a given tensor storage to the set of all nodes that take in that + storage as an input. + Specifically, re-inplacing ``b = foo(a)`` causes "a" and "b"'s sets to get + fused together. + + **(5)** Any ``view_inverse/scatter`` nodes that were identified as + "it's ok to ignore them" during step (3) get manually deleted from the graph. + Their outputs are no longer used, so technically standard DCE would be able + to do this, but we can no longer run FX's DCE pass now that we have mutable + ops in the graph. """ _FunctionalizationMetadataProp(gm).propagate(*sample_args) @@ -543,7 +565,7 @@ def f(x): for n in gm.graph.nodes: if "fake_result" in n.meta: # Tree-mapping because some ops can return lists of tensors. - def _add_to_map(x): + def _add_to_map(x: object) -> None: if isinstance(x, FakeTensor): storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n) @@ -701,12 +723,14 @@ def _add_to_map(x): # We need to replace any later usages of "b" with "a" for old in itertools.chain([node], later_view_inverse_node_usages): new = old.args[0] + if not isinstance(new, Node): + raise AssertionError(f"Expected Node, got {type(new)}") nodes_to_update = [ n for n in old.users if n.meta["node_idx"] > node.meta["node_idx"] ] for node_to_update in nodes_to_update: - def replace_arg(a): + def replace_arg(a: Node) -> Node: if a == old: return new return a @@ -751,6 +775,7 @@ def replace_arg(a): and len(node_res_storage) == 1 and old_res_storage == node_res_storage ): + # pyrefly: ignore [missing-attribute] new_flattened_res = pytree.tree_leaves(new.meta["fake_result"]) new_res_storage = { StorageWeakRef(x._typed_storage()) diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 99665d407e65d..9a6e3881e8afa 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -1,8 +1,8 @@ -# mypy: allow-untyped-defs import functools import logging import operator import sys +from collections.abc import Callable from typing import Any, Optional, TYPE_CHECKING @@ -98,12 +98,12 @@ def insert_deferred_runtime_asserts( _get_placeholder_expr, _has_uninterpretable_sympy_function, CallMethodKey, - cast_symbool_to_symint_guardless, ConvertIntKey, DivideByKey, free_symbols, InnerTensorKey, resolve_unbacked_bindings, + RuntimeAssert, ) from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.reference import ( @@ -207,7 +207,9 @@ def _node_metadata_hook( Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis - def _sympy_interp(expr_to_proxy, expr): + def _sympy_interp( + expr_to_proxy: dict[sympy.Expr, fx.Proxy], expr: sympy.Expr + ) -> fx.Proxy: # sympy_interp() with hash consing from sympy import Integer, Number, Symbol from sympy.logic.boolalg import BooleanAtom @@ -240,7 +242,7 @@ def _is_bound_expr_for_symbol(expr: "sympy.Expr") -> bool: isinstance(rhs, sympy.Symbol) and isinstance(lhs, sympy.Number) ) - def add_runtime_asserts(ras): + def add_runtime_asserts(ras: list[RuntimeAssert]) -> None: for ra in ras: if ( # redundant @@ -310,7 +312,7 @@ def add_runtime_asserts(ras): and (example_value := _get_example_value(node)) is not None ): - def match_symbol(symint, cb): + def match_symbol(symint: object, cb: Callable[[], fx.Node]) -> None: if ( isinstance(symint, torch.SymInt) and isinstance(symint.node, SymNode) @@ -391,7 +393,7 @@ def match_symbol(symint, cb): and (sym_expr := _get_sym_val(node)) is not None ): # this guards against deleting calls like item() that produce new untracked symbols - def has_new_untracked_symbols(): + def has_new_untracked_symbols() -> bool: # pyrefly: ignore [missing-attribute] for symbol in sym_expr.free_symbols: if symbol not in expr_to_proxy: @@ -405,7 +407,7 @@ def has_new_untracked_symbols(): shape_env, node.meta.get("unbacked_bindings", {}) ) - def has_new_unbacked_bindings(): + def has_new_unbacked_bindings() -> bool: if resolved_unbacked_bindings is None: raise AssertionError("resolved_unbacked_bindings is None") for key in resolved_unbacked_bindings: @@ -482,7 +484,7 @@ def has_new_unbacked_bindings(): # TODO: some CSE when generating these nodes can probably # help reduce graph size and improve compile time - def go(node, keypath): + def go(node: fx.Node, keypath: tuple[object, ...]) -> fx.Node: if keypath == (): return node if ( @@ -535,9 +537,7 @@ def go(node, keypath): ) elif isinstance(keypath[0], ConvertIntKey): return go( - graph.call_function( - cast_symbool_to_symint_guardless, (node,) - ), + graph.call_function(torch.sym_ite, (node, 1, 0)), keypath[1:], ) elif isinstance(keypath[0], DivideByKey): @@ -629,7 +629,7 @@ def go(node, keypath): # assert and also explicitly refine the range # (refinement should not be necessary once runtime # asserts cause refinement, but that's NYI) - def convert(s): + def convert(s: Any) -> int | None: if s in (int_oo, -int_oo): return None try: @@ -637,44 +637,37 @@ def convert(s): except TypeError: return None - if ( - expr_to_proxy[i0].node.target - is not cast_symbool_to_symint_guardless + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + # nodes added in `apply_runtime_assertion_pass` will have the same annotation + # as the input node to the assertion + custom=node.meta.get("custom"), + ), ): - # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts - # raises AOTAutograd errors on cast_symbool_to_symint_guardless - - with _set_node_metadata_hook( - gm, - functools.partial( - _node_metadata_hook, - stack_trace=node.meta.get("stack_trace"), - nn_module_stack=node.meta.get("nn_module_stack"), - # nodes added in `apply_runtime_assertion_pass` will have the same annotation - # as the input node to the assertion - custom=node.meta.get("custom"), - ), - ): - if (min_val := convert(vr.lower)) is not None: - ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node - graph.call_function( - torch.ops.aten._assert_scalar.default, - ( - ge, - f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", - ), - ) - added_asserts.add(i0 >= min_val) - if (max_val := convert(vr.upper)) is not None: - le = _sympy_interp(expr_to_proxy, i0 <= max_val).node - graph.call_function( - torch.ops.aten._assert_scalar.default, - ( - le, - f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", - ), - ) - added_asserts.add(i0 <= max_val) + if (min_val := convert(vr.lower)) is not None: + ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + ge, + f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", + ), + ) + added_asserts.add(i0 >= min_val) + if (max_val := convert(vr.upper)) is not None: + le = _sympy_interp(expr_to_proxy, i0 <= max_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + le, + f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", + ), + ) + added_asserts.add(i0 <= max_val) constrained_unbacked_symbols.add(i0) add_runtime_asserts(ras) diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py index 63e35cae175cd..f6887e2d67c5c 100644 --- a/torch/fx/passes/shape_prop.py +++ b/torch/fx/passes/shape_prop.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - import traceback from typing import Any, NamedTuple diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 099184f96408b..924f310e62812 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import inspect import logging from collections import OrderedDict @@ -7,6 +6,7 @@ import torch from torch.fx._compatibility import compatibility +from torch.fx._lazy_graph_module import _make_graph_module from torch.fx._utils import lazy_format_graph_code from torch.fx.graph_module import GraphModule from torch.fx.node import Node @@ -18,7 +18,7 @@ @compatibility(is_backward_compatible=True) class Partition: - def __init__(self, name: str): + def __init__(self, name: str) -> None: self.name: str = name self.submod_name = f"submod_{name}" self.node_names: list[str] = [] @@ -62,7 +62,8 @@ def split_module( keep_original_input_name: bool = True, *, partition_affix: str | None = None, -): + tuple_return: bool = False, +) -> GraphModule: """ Creates subgraphs out of main graph @@ -86,6 +87,9 @@ def split_module( have the same input names as the original graph. partition_affix: Optional[str]: If specified, the submodules' names will contain the affix, e.g. "submod__". + tuple_return: bool: If True, submodule outputs are always wrapped in a tuple, + even when there is only a single output value. This makes all subgraphs + conform to the convention expected by ``torch._inductor.compile_fx``. Returns: GraphModule: the module after split. @@ -172,7 +176,7 @@ def construct_graph( node: Node, base_mod_env: dict[str, Node], base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule], - ): + ) -> tuple[dict[str, Node], dict[str, torch.fx.graph_module.GraphModule]]: if node.op == "placeholder": default_value = ( node.args[0] if len(node.args) > 0 else inspect.Signature.empty @@ -209,7 +213,7 @@ def construct_graph( orig_nodes: dict[str, Node] = {} symbol_to_node: dict[sympy.Symbol, Node] = {} - def record_cross_partition_use(def_node: Node, use_node: Node | None): + def record_cross_partition_use(def_node: Node, use_node: Node | None) -> None: from torch.fx.experimental.symbolic_shapes import free_symbols defined = getattr(def_node, "_fx_partition", None) @@ -258,7 +262,7 @@ def record_cross_partition_use(def_node: Node, use_node: Node | None): if defined is not None: use_partition.dependencies.setdefault(defined) - def instantiate_node_partition_mapping(node): + def instantiate_node_partition_mapping(node: Node) -> None: partition_idx = split_callback(node) partition_name = str(partition_idx) if partition_affix is not None: @@ -406,7 +410,7 @@ def instantiate_node_partition_mapping(node): ) torch.fx.graph.map_arg( node.kwargs, lambda def_node: record_cross_partition_use(def_node, node) - ) # noqa: B950 + ) original_partition_order = list(partitions.keys()) # find partitions with no dependencies @@ -462,7 +466,7 @@ def instantiate_node_partition_mapping(node): # We don't pass in get_attr nodes as inputs to the partition, but # instead set them as targets and use getattr within the module - def add_placeholder(): + def add_placeholder() -> Node: if keep_original_input_name: name = inp else: @@ -603,15 +607,10 @@ def add_placeholder(): partition.environment[orig_nodes[name]] for name in partition.outputs ) - # skip output node generation if there are no output values - num_output_vals = len(output_vals) - if num_output_vals == 1: + if len(output_vals) == 1 and not tuple_return: partition.graph.output(output_vals[0]) - elif num_output_vals > 1: - partition.graph.output(output_vals) else: - # Invariant - Graph should always have an output node. - partition.graph.output(()) + partition.graph.output(output_vals) if keep_original_order: # first get the attr nodes required by this partition @@ -638,9 +637,9 @@ def add_placeholder(): ) already_constructed_attr_nodes.add(node) - base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule( + base_mod_attrs[partition.submod_name] = _make_graph_module( partition.targets, partition.graph - ) # noqa: B950 + ) # Emit call in base graph to this submodule output_val = base_mod_graph.call_module( @@ -649,8 +648,8 @@ def add_placeholder(): ) num_outputs = len(partition.outputs) - if num_outputs > 1: - # Unpack multiple return values from submodule + if num_outputs > 1 or (num_outputs == 1 and tuple_return): + # Unpack return values from submodule output_val_proxy = torch.fx.proxy.Proxy(output_val) for i, output_name in enumerate(partition.outputs): base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] @@ -672,9 +671,9 @@ def add_placeholder(): if node.op == "output": base_mod_graph.output( torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]) - ) # noqa: B950 + ) - ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) + ret = _make_graph_module(base_mod_attrs, base_mod_graph) log.debug( "%s", lazy_format_graph_code("post split_module", ret, colored=True), diff --git a/torch/fx/passes/split_utils.py b/torch/fx/passes/split_utils.py index 08a4c5a685ec7..e11ae3cb2a95f 100644 --- a/torch/fx/passes/split_utils.py +++ b/torch/fx/passes/split_utils.py @@ -1,6 +1,6 @@ -# mypy: allow-untyped-defs import copy from dataclasses import dataclass, field +from typing import Any, TYPE_CHECKING import torch.fx from torch.fx._compatibility import compatibility @@ -10,6 +10,10 @@ from .tools_common import CALLABLE_NODE_OPS, is_node_output_tensor, NodeList +if TYPE_CHECKING: + from .splitter_base import Subgraph + + __all__ = [ "getattr_recursive", "setattr_recursive", @@ -20,7 +24,7 @@ @compatibility(is_backward_compatible=False) -def getattr_recursive(obj, name): +def getattr_recursive(obj: object, name: str) -> Any: for layer in name.split("."): if isinstance(obj, torch.nn.ModuleList): if hasattr(obj, "_modules") and layer in obj._modules: @@ -35,7 +39,7 @@ def getattr_recursive(obj, name): @compatibility(is_backward_compatible=False) -def setattr_recursive(obj, attr, value): +def setattr_recursive(obj: object, attr: str, value: object) -> None: if "." not in attr: setattr(obj, attr, value) else: @@ -55,13 +59,13 @@ class Component: name: str # Stores the placeholder nodes in `graph`. - input_placeholders: list = field(default_factory=list) + input_placeholders: list[torch.fx.Node] = field(default_factory=list) # Store the nodes in original graph that are placeholder in `graph`. - orig_inputs: list = field(default_factory=list) + orig_inputs: list[torch.fx.Node] = field(default_factory=list) # Store the nodes in original graph that are outputs in `graph`. - orig_outputs: list = field(default_factory=list) + orig_outputs: list[torch.fx.Node] = field(default_factory=list) # Mapping from get_attr node in original graph to get_attr node in `graph`. getattr_maps: dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict) @@ -217,10 +221,14 @@ def flatten(x: torch.fx.node.Argument) -> NodeList: ) # Map a input of `node` to nodes in the component's graph. - def remap_func(x): + def remap_func(x: torch.fx.Node) -> torch.fx.Node: # If input is a get_attr node, copy it to current component's graph. # Returns the get_attr node in current component's graph. if x.op == "get_attr": + if not isinstance(x.target, str): + raise RuntimeError( + f"Expected get_attr node target to be a str, got {type(x.target)}" + ) if x not in comp.getattr_maps: comp.getattr_maps[x] = comp.graph.get_attr( x.target, type_expr=x.type @@ -321,7 +329,7 @@ def remap_func(x): @compatibility(is_backward_compatible=False) -def move_non_tensor_nodes_on_boundary(subgraphs) -> None: +def move_non_tensor_nodes_on_boundary(subgraphs: list["Subgraph"]) -> None: """ Move non-tensor nodes on the boundary between subgraphs. @@ -391,7 +399,7 @@ def can_move_node_and_dependencies( visited = set() can_move = True - def dfs(current_node): + def dfs(current_node: torch.fx.Node) -> None: nonlocal can_move, nodes_to_move if current_node in visited: diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index 3de6753228a82..0536c67ca9f98 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -1,12 +1,11 @@ -# mypy: allow-untyped-defs import argparse import copy import json import os from collections import defaultdict -from collections.abc import Iterable, Sequence +from collections.abc import Callable, Iterable, Sequence from dataclasses import dataclass -from typing import Any, Literal, NamedTuple +from typing import Any, IO, Literal, NamedTuple import torch from torch._logging import trace_structured @@ -76,12 +75,12 @@ class _SplitterSettingBase: def __init__( self, - min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE, - skip_fusion=DEFAULT_SKIP_FUSION, - allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR, + min_acc_module_size: int = DEFAULT_MIN_ACC_MODULE_SIZE, + skip_fusion: bool = DEFAULT_SKIP_FUSION, + allow_non_tensor: bool = DEFAULT_ALLOW_NON_TENSOR, max_acc_splits: int = -1, move_non_tensor_nodes_on_boundary: bool = False, - ): + ) -> None: parser = argparse.ArgumentParser() parser.add_argument( "--min-acc-module-size", @@ -159,12 +158,12 @@ class NodeEvent: def __init__( self, source: torch.fx.Node, desc: str, dep: torch.fx.Node | None = None - ): + ) -> None: self.source = source self.desc = desc self.dep = dep - def to_str(self): + def to_str(self) -> str: # source: The name of the subject of the event. # desc: description of the event, in the format of | # dep: The name of the cause of this event, which is another node, or # @@ -178,7 +177,7 @@ class NodeEventTracker: Tracks node events during the splitter execution. """ - def __init__(self, tracker_mode, dump_prefix): + def __init__(self, tracker_mode: int, dump_prefix: str) -> None: self.tracker_mode = tracker_mode self.dump_prefix = dump_prefix # list of events @@ -187,7 +186,9 @@ def __init__(self, tracker_mode, dump_prefix): self.node_events = {} self.writer = print - def add(self, node: torch.fx.Node, desc: str, dep: torch.fx.Node | None = None): + def add( + self, node: torch.fx.Node, desc: str, dep: torch.fx.Node | None = None + ) -> None: """ Add a new event to the tracker. """ @@ -197,7 +198,13 @@ def add(self, node: torch.fx.Node, desc: str, dep: torch.fx.Node | None = None): self.node_events[node.name] = [] self.node_events[node.name].append(len(self.events) - 1) - def print_node(self, node_name, recursive=False, tab="", writer=None): + def print_node( + self, + node_name: str, + recursive: bool = False, + tab: str = "", + writer: Callable[[str], object] | None = None, + ) -> None: """ Print a node and its events. @param recursive: if True, print nodes that caused the events on this current node. @@ -214,7 +221,7 @@ def print_node(self, node_name, recursive=False, tab="", writer=None): event.dep.name, recursive=True, tab="| " + tab, writer=writer ) - def to_dict(self): + def to_dict(self) -> dict[str, list[str]]: """ Create dict dump on all events. """ @@ -226,7 +233,7 @@ def to_dict(self): ret[name].append(event.to_str()) return ret - def print_all(self, writer=None): + def print_all(self, writer: Callable[[str], object] | None = None) -> None: """ Print all nodes in a list. @param writer: function to write to file. If None, use print. @@ -237,7 +244,7 @@ def print_all(self, writer=None): writer(f"Node: {name}:") self.print_node(name, recursive=False, tab=" ", writer=writer) - def dump(self): + def dump(self) -> None: """ Function to be invoked at the end of the finder execution to printout tracked events specified by the mode. """ @@ -251,8 +258,8 @@ def dump(self): payload_fn=lambda: json.dumps(self.to_dict()), ) - def writeln(f): - def fn(x): + def make_writer(f: IO[str]) -> Callable[[str], int]: + def fn(x: str) -> int: return f.write(x + "\n") return fn @@ -261,16 +268,15 @@ def fn(x): # Mode >=1: Dump all events to file if self.tracker_mode >= 1: with open(self.dump_prefix + ALL_SUFFIX, "w") as f: - self.print_all(writeln(f)) + self.print_all(make_writer(f)) - def dump_selected_nodes(nodes): + def dump_selected_nodes(nodes: list[str]) -> None: with open(self.dump_prefix + NODES_SUFFIX, "w") as f: + writer = make_writer(f) for node_name in nodes: - writeln(f"===== Tracking node {node_name} =====") - self.print_node( - node_name, recursive=True, tab="|-", writer=writeln(f) - ) - writeln(f"===== End of tracking node {node_name} =====") + writer(f"===== Tracking node {node_name} =====") + self.print_node(node_name, recursive=True, tab="|-", writer=writer) + writer(f"===== End of tracking node {node_name} =====") # Mode 2: Dump specific nodes in recursive manner. # Mode 3: Dump all nodes with more than 1 event in recursive manner. @@ -307,7 +313,7 @@ def __init__( module: torch.fx.GraphModule, operator_support: OperatorSupportBase, allow_non_tensor: bool, - ): + ) -> None: self.module = module self.operator_support = operator_support self.allow_non_tensor = allow_non_tensor @@ -315,7 +321,7 @@ def __init__( self.tracker = NodeEventTracker(int(TRACKER_MODE), DUMP_PREFIX) - def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): + def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList) -> None: """ Transitively excludes nodes from ACC supported set. For every node in the worklist: @@ -334,7 +340,7 @@ def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): self.tracker.add(user, "new_cpu_node|non_tensor_output") cpu_worklist.append(user) - def reduce_acc_nodes_non_tensor_input(self): + def reduce_acc_nodes_non_tensor_input(self) -> None: """ Excludes nodes from ACC supported set that have direct upstream CPU nodes that produce non-tensor outputs. @@ -353,7 +359,7 @@ def reduce_acc_nodes_non_tensor_input(self): self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes) - def reduce_acc_nodes_non_tensor_output(self): + def reduce_acc_nodes_non_tensor_output(self) -> None: """ Excludes nodes from ACC supported set that produce non-tensor outputs and have downstream CPU nodes. @@ -452,10 +458,10 @@ def generate_inputs_for_submodules( """ handles = [] - results = {} + results: dict[str, Any] = {} submodule_to_names = {mod: name for name, mod in model.named_modules()} - def pre_forward(module, module_inputs): + def pre_forward(module: torch.nn.Module, module_inputs: tuple[Any, ...]) -> None: results[submodule_to_names[module]] = ( copy.deepcopy(module_inputs) if deepcopy else module_inputs ) @@ -465,7 +471,7 @@ def pre_forward(module, module_inputs): if not isinstance(mod, torch.jit.ScriptModule): handles.append(mod.register_forward_pre_hook(pre_forward)) - def clean_up_handles(): + def clean_up_handles() -> None: for h in handles: h.remove() @@ -535,7 +541,7 @@ def __init__( non_acc_submodule_name: str = "_run_on_cpu_", return_tuple: bool = False, nodes_finder: FxNetAccNodesFinder | None = None, - ): + ) -> None: """ Preprocesses graph before splitting: - finds nodes supported by ACC, @@ -606,7 +612,7 @@ def find_deps(self) -> dict[torch.fx.Node, NodeSet]: deps[user].add(node) return deps - def update_deps_for_fusions(self): + def update_deps_for_fusions(self) -> None: """ Updates graph of dependencies so that: - nodes from the same fusion depend on the same set of outer nodes, @@ -621,7 +627,7 @@ def update_deps_for_fusions(self): if user not in fusion: self.deps[user].add(node) - def _merge_overlapping_fusions(self): + def _merge_overlapping_fusions(self) -> None: """ Merge fusion groups that share nodes. @@ -709,7 +715,7 @@ def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str: def _draw_graph_based_on_node_support( self, mod: torch.fx.GraphModule, supported_nodes: NodeList - ): + ) -> None: color_map = { "default": "AliceBlue", "supported": "chartreuse1", @@ -717,7 +723,7 @@ def _draw_graph_based_on_node_support( } class CustomDrawer(FxGraphDrawer): - def _get_node_style(self, node): + def _get_node_style(self, node: torch.fx.Node) -> dict[str, str]: template = super()._get_node_style(node) if node in supported_nodes: template["fillcolor"] = color_map["supported"] @@ -733,14 +739,14 @@ def _get_node_style(self, node): # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. dot_graph.write_raw("node_support.dot") # type: ignore[attr-defined] - def node_support_preview(self, dump_graph: bool = False): + def node_support_preview(self, dump_graph: bool = False) -> str: submodules = dict(self.module.named_modules()) supported_nodes: NodeList = [] supported_node_types = defaultdict(set) unsupported_node_types = defaultdict(set) - def get_dtype(arg): + def get_dtype(arg: torch.fx.Node) -> torch.dtype | None: tensor_meta = arg.meta.get("tensor_meta") return getattr(tensor_meta, "dtype", None) @@ -800,7 +806,7 @@ def get_dtype(arg): # Return reports for testing purpose return reports - def split_preview(self, dump_graph: bool = False): + def split_preview(self, dump_graph: bool = False) -> str: reports = "" subgraphs = self.put_nodes_into_subgraphs() acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) @@ -842,16 +848,24 @@ def split_preview(self, dump_graph: bool = False): submod = getattr(split_mod, node.target) - def get_submod_inputs(main_mod, submod, example_inputs): - sub_inputs = None + def get_submod_inputs( + main_mod: torch.fx.GraphModule, + submod: torch.fx.GraphModule, + example_inputs: Sequence[Any], + ) -> tuple[Any, ...]: + sub_inputs: tuple[Any, ...] | None = None - def get_inputs(self, inputs): + def get_inputs( + self: torch.nn.Module, inputs: tuple[Any, ...] + ) -> None: nonlocal sub_inputs sub_inputs = inputs handle = submod.register_forward_pre_hook(get_inputs) main_mod(*example_inputs) handle.remove() + if sub_inputs is None: + raise AssertionError("Forward pre-hook did not capture inputs") return sub_inputs submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input) @@ -872,7 +886,7 @@ def get_inputs(self, inputs): reports += "Checking outputs...\n" - def get_bytes(node: torch.fx.Node): + def get_bytes(node: torch.fx.Node) -> None: nonlocal total_output_bytes nonlocal reports if not is_node_output_tensor(node): @@ -937,7 +951,9 @@ def find_reverse_deps( return result - def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]): + def update_reverse_deps_for_fusions( + self, deps: dict[torch.fx.Node, NodeSet] + ) -> None: processed_node = set() for node, fusion in self.fusions.items(): @@ -981,7 +997,7 @@ def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet: return parent_nodes - def extend_acc_subgraph(self, tag: str): + def extend_acc_subgraph(self, tag: str) -> None: """ Extend the acc subgraph with `tag` going the reversed topological direction. """ @@ -1145,7 +1161,7 @@ def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph result.append(subgraph) return result - def tag(self, subgraphs: list[Subgraph]): + def tag(self, subgraphs: list[Subgraph]) -> None: self.tags = [] for subgraph in subgraphs: tag = ( @@ -1186,7 +1202,7 @@ def __call__(self) -> torch.fx.GraphModule: def generate_split_results(self) -> SplitResult: split_module = self() - submodule_names = [] + submodule_names: list[str] = [] for name, _mod in split_module.named_children(): submodule_names.append(name) if ( diff --git a/torch/fx/passes/tests/_test_split_utils.py b/torch/fx/passes/tests/_test_split_utils.py index d841685668624..e1751d7f88686 100644 --- a/torch/fx/passes/tests/_test_split_utils.py +++ b/torch/fx/passes/tests/_test_split_utils.py @@ -42,10 +42,12 @@ def _create_mock_node( node.meta = {"type": int} # Non-tensor type # Mock users dict (Node.users is dict[Node, None]) - node.users = {} + users: dict[torch.fx.Node, None] = {} + node.users = users # Initialize the _input_nodes dict (Node._input_nodes is dict[Node, None]) - node._input_nodes = {} + input_nodes: dict[torch.fx.Node, None] = {} + node._input_nodes = input_nodes return node diff --git a/torch/fx/passes/tools_common.py b/torch/fx/passes/tools_common.py index 75f76a6d0da18..751a615698cbd 100644 --- a/torch/fx/passes/tools_common.py +++ b/torch/fx/passes/tools_common.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import collections import heapq import operator @@ -30,7 +29,7 @@ @compatibility(is_backward_compatible=False) -def get_acc_ops_name(k): +def get_acc_ops_name(k: str | type) -> str: if isinstance(k, str): return k elif k.__module__ and "acc_ops" in k.__module__: @@ -105,7 +104,7 @@ class FxNetAccFusionsFinder: Such groups are called fusion groups. """ - def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet): + def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet) -> None: self.module = module self.nodes = list(module.graph.nodes) self.acc_nodes = acc_nodes @@ -125,7 +124,7 @@ class FusionGroup: # Nodes that in the fusion group that haven't been processed yet. nodes_need_process: NodeSet - def add_node(self, node): + def add_node(self, node: torch.fx.Node) -> None: """ Add a node to fusion group. """ @@ -148,7 +147,7 @@ def recursive_add_node( fusion_group: "FxNetAccFusionsFinder.FusionGroup", inputs: NodeSet | NodeList, visited: NodeSet | None = None, - ): + ) -> bool: """ Start from inputs and going reverse topological order. If any upstream node is in the fusion group, add all the nodes in this path to fusion group. @@ -302,7 +301,7 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: for node in gm.graph.nodes: for user in node.users: indeg[user] += 1 - queue: collections.deque = collections.deque() + queue: collections.deque[torch.fx.Node] = collections.deque() # Add all nodes with no dependencies to the queue for node in gm.graph.nodes: if indeg[node] == 0: diff --git a/torch/fx/passes/utils/common.py b/torch/fx/passes/utils/common.py index a0a375b96dac2..87f9fa4326845 100644 --- a/torch/fx/passes/utils/common.py +++ b/torch/fx/passes/utils/common.py @@ -1,5 +1,3 @@ -# mypy: allow-untyped-defs - from torch.fx._compatibility import compatibility from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule @@ -17,7 +15,7 @@ class HolderModule(Module): that uses the attributes """ - def __init__(self, d): + def __init__(self, d: dict[str, Module | None]) -> None: super().__init__() for k, v in d.items(): self.add_module(k, v) diff --git a/torch/fx/passes/utils/matcher_utils.py b/torch/fx/passes/utils/matcher_utils.py index e0ebe687e95f5..c5c59e40e3007 100644 --- a/torch/fx/passes/utils/matcher_utils.py +++ b/torch/fx/passes/utils/matcher_utils.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import copy import logging import os @@ -15,7 +14,7 @@ # Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs -def _init_logger(): +def _init_logger() -> logging.Logger: logger = logging.getLogger(__name__) level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper() @@ -51,7 +50,7 @@ class InternalMatch: # only available if the matcher is `SubgraphMatcherWithNameNodesMap` name_node_map: dict[str, Node] = field(default_factory=dict) - def __copy__(self): + def __copy__(self) -> "InternalMatch": return InternalMatch( anchors=self.anchors, nodes_map=self.nodes_map.copy(), @@ -251,7 +250,9 @@ def _match_nodes( # match for `gn` match_found = True - def _match_args(args1: list | tuple, args2: list | tuple) -> bool: + def _match_args( + args1: list[Any] | tuple[Any, ...], args2: list[Any] | tuple[Any, ...] + ) -> bool: if len(args1) != len(args2): return False @@ -271,7 +272,8 @@ def _match_args(args1: list | tuple, args2: list | tuple) -> bool: return True # Flatten all args/kwargs into 1 list of args - pn_args, gn_args = None, None + pn_args: list[Any] | None = None + gn_args: list[Any] | None = None if ( ( len(pn.args) != len(gn.args) @@ -282,7 +284,9 @@ def _match_args(args1: list | tuple, args2: list | tuple) -> bool: ): args_schema = pn.target._schema.arguments - def get_all_arguments(orig_args, orig_kwargs): + def get_all_arguments( + orig_args: tuple[Any, ...], orig_kwargs: dict[str, Any] + ) -> list[Any]: all_args = [] for i, schema in enumerate(args_schema): if schema.name in orig_kwargs: @@ -370,7 +374,7 @@ def match(self, graph: Graph, node_name_match: str = "") -> list[InternalMatch]: matches: list[InternalMatch] = [] - def backtracking(anchor_index, match): + def backtracking(anchor_index: int, match: InternalMatch) -> None: if anchor_index == len(match_candidates_list): match.placeholder_nodes = [ match.nodes_map[pn] for pn in self.pattern_placeholder_nodes @@ -418,7 +422,7 @@ def backtracking(anchor_index, match): ) # filter out the matches that form a cycle if the subgraph is fused - valid_matches = [] + valid_matches: list[InternalMatch] = [] for match in matches: matched_compute_nodes = [ gn diff --git a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py index 8b2f0f2e89e22..43726f79d1ed7 100644 --- a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py +++ b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py @@ -13,7 +13,7 @@ def _split_to_graph_and_name_node_map( from torch.fx.graph import _PyTreeInfo from torch.utils._pytree import tree_flatten, tree_unflatten - name_node_map = {} + name_node_map: dict[str, Node] = {} for n in gm.graph.nodes: if n.op == "output": if gm._out_spec is None: diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 0d0e13707c9d1..811f58e5317d3 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - import collections import copy import dis @@ -9,10 +7,12 @@ import operator import sys import traceback +import types from collections import OrderedDict from collections.abc import Callable, Iterator from dataclasses import fields, is_dataclass -from typing import Any, TypeVar +from typing import Any, cast, TypeVar +from typing_extensions import Never import torch import torch.fx.traceback as fx_traceback @@ -117,7 +117,7 @@ def forward(self, x): """ - def __init__(self, module_path: str, module_type: Any): + def __init__(self, module_path: str, module_type: Any) -> None: super().__init__() self.module_path = module_path self.module_type = module_type @@ -134,7 +134,7 @@ def __init__( self, scope: Scope, current_scope: Scope, - ): + ) -> None: super().__init__() # Keep a copy of prev scope to restore on exit self._prev_scope = copy.copy(scope) @@ -144,10 +144,10 @@ def __init__( # Save a reference so we can restore it self._scope = scope - def __enter__(self): + def __enter__(self) -> Scope: return self._scope - def __exit__(self, *args): + def __exit__(self, *args: object) -> None: self._scope.module_path = self._prev_scope.module_path self._scope.module_type = self._prev_scope.module_type return @@ -217,7 +217,8 @@ def create_node( """ if kind == "call_function" and self.check_mutable_operations: - check_for_mutable_operation(target, args, kwargs) + target_fn = cast(Callable[..., Any], target) + check_for_mutable_operation(target_fn, args, kwargs) node = self.graph.create_node(kind, target, args, kwargs, name, type_expr) # TODO node_name_to_scope will be depreciated in favor of @@ -237,6 +238,7 @@ def create_node( if fx_traceback.GRADIENT_ACC_SPECIAL_STACK in stack_trace: node.meta["is_gradient_acc"] = True + node.meta["autograd_backward"] = True # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta # If other meta fields are needed, they can be added here @@ -254,7 +256,7 @@ def create_node( # See Note [Functionalization View Replay Annotation] # Overriding some node meta with the original node meta of the # regenerated node. - replay_node: Node = fx_traceback.get_current_replay_node() + replay_node: Node | None = fx_traceback.get_current_replay_node() if replay_node is not None: node.meta["is_functional_regenerated"] = True if "custom" in replay_node.meta: @@ -262,6 +264,9 @@ def create_node( if "stack_trace" in replay_node.meta: node.stack_trace = replay_node.meta.get("stack_trace") + if current_meta.get("autograd_backward", False): + node.meta["autograd_backward"] = True + elif self.module_stack: node.meta["nn_module_stack"] = copy.copy(self.module_stack) @@ -280,7 +285,7 @@ def _filter_traceback_frames( ) -> traceback.StackSummary: # This method can be overridden to customize the frame filtering logic # for the recorded stack trace - user_frames = [] + user_frames: list[traceback.FrameSummary] = [] if self._record_forward_stack_traces_only: user_frames = [ frame @@ -301,7 +306,7 @@ def _filter_traceback_frames( # Not having a "forward" call in the stacktrace implies the # stacktrace will probably be irrelevant if first_forward == -1: - user_frames = [] + user_frames: list[traceback.FrameSummary] = [] from torch.fx.experimental.symbolic_shapes import uninteresting_files @@ -326,9 +331,8 @@ def create_proxy( kwargs: dict[str, Any], name: str | None = None, type_expr: Any | None = None, - # fix noqa when updating bc tests - proxy_factory_fn: Callable[[Node], "Proxy"] = None, # noqa: RUF013 - ): + proxy_factory_fn: Callable[[Node], "Proxy"] | None = None, + ) -> "Proxy": """ Create a Node from the given arguments, then return the Node wrapped in a Proxy object. @@ -355,7 +359,7 @@ def create_proxy( return proxy - def _find_user_frame(self): + def _find_user_frame(self) -> types.FrameType | None: """ Find the Python stack frame executing the user code during symbolic tracing. @@ -440,14 +444,14 @@ def create_arg(self, a: Any) -> Argument: ) elif isinstance(a, range): - return range( - self.create_arg(a.start), - self.create_arg(a.stop), - self.create_arg(a.step), + return range( # pyrefly: ignore[no-matching-overload] + self.create_arg(a.start), # pyrefly: ignore[bad-argument-type] + self.create_arg(a.stop), # pyrefly: ignore[bad-argument-type] + self.create_arg(a.step), # pyrefly: ignore[bad-argument-type] ) elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): - return a + return a # pyrefly: ignore[bad-return] elif is_opaque_value_type(type(a)): return a @@ -460,7 +464,7 @@ def create_arg(self, a: Any) -> Argument: return self.create_node("call_function", a.__class__, (), kwargs) elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...: - return a + return a # pyrefly: ignore[bad-return] raise NotImplementedError(f"argument of type: {type(a)}") @@ -494,7 +498,7 @@ def iter(self, obj: "Proxy") -> Iterator: ) @compatibility(is_backward_compatible=True) - def keys(self, obj: "Proxy") -> Any: + def keys(self, obj: "Proxy") -> "Proxy": """Called when a proxy object is has the keys() method called. This is what happens when ** is called on a proxy. This should return an iterator it ** is suppose to work in your custom tracer. @@ -502,7 +506,7 @@ def keys(self, obj: "Proxy") -> Any: return Attribute(obj, "keys")() -def _get_seq_nr(node_name: str = ""): +def _get_seq_nr(node_name: str = "") -> int | None: """ Returns the seq_nr node meta for the current proxy node that we're creating. The seq_nr number in node meta is related to but not the same as the "sequence number" @@ -547,7 +551,7 @@ def _get_seq_nr(node_name: str = ""): # See Note [Functionalization View Replay Annotation] # Overriding some node meta with the original node meta of the # regenerated node. - replay_node: Node = fx_traceback.get_current_replay_node() + replay_node: Node | None = fx_traceback.get_current_replay_node() if replay_node is not None: if "seq_nr" in replay_node.meta: annotation_log.debug("%s: seq_nr from replay_node", node_name) @@ -559,16 +563,16 @@ def _get_seq_nr(node_name: str = ""): # used in Proxy object when just appending to the graph while not tracing. @compatibility(is_backward_compatible=True) class GraphAppendingTracer(TracerBase): - def __init__(self, graph: Graph): + def __init__(self, graph: Graph) -> None: super().__init__() self.graph = graph self.scope = Scope("", None) - self.module_stack = collections.OrderedDict() - self.node_name_to_scope = {} + self.module_stack: OrderedDict[str, tuple[str, Any]] = collections.OrderedDict() + self.node_name_to_scope: dict[str, tuple[str, type]] = {} @compatibility(is_backward_compatible=False) -def assert_fn(x): +def assert_fn(x: object) -> None: if not x: raise AssertionError("Assertion failed") @@ -609,7 +613,7 @@ class Proxy: """ @compatibility(is_backward_compatible=True) - def __init__(self, node: Node, tracer: "TracerBase | None" = None): + def __init__(self, node: Node, tracer: "TracerBase | None" = None) -> None: if tracer is None: # This allows you to create a Proxy object around a raw Node tracer = GraphAppendingTracer(node.graph) @@ -619,21 +623,21 @@ def __init__(self, node: Node, tracer: "TracerBase | None" = None): def __repr__(self) -> str: return f"Proxy({self.node.name})" - def __getattr__(self, k) -> "Attribute": + def __getattr__(self, k: str) -> "Attribute": # note: not added to the graph yet, if this is a method call # we peephole optimize to the method invocation return Attribute(self, k) - def __getstate__(self) -> dict: + def __getstate__(self) -> dict[str, Any]: return self.__dict__ - def __deepcopy__(self, memo) -> dict: + def __deepcopy__(self, memo: dict[int, Any]) -> "Proxy": # We have to explicitly override this method, because otherwise deepcopy # will go to __getattr__(self, "__deepcopy__") and return a # Attribute(__deepcopy__), and may go into an infinite loop in some cases. import copy - new_dict = {} + new_dict: dict[str, Any] = {} for k, v in self.__dict__.items(): try: new_obj = copy.deepcopy(v, memo) @@ -655,11 +659,11 @@ def __deepcopy__(self, memo) -> dict: new_proxy.__dict__[k] = v return new_proxy - def __setstate__(self, d): + def __setstate__(self, d: dict[str, Any]) -> None: # This is called when being unpickled/loaded. self.__dict__ = d - def __call__(self, *args, **kwargs) -> "Proxy": + def __call__(self, *args: Any, **kwargs: Any) -> "Proxy": return self.tracer.create_proxy( "call_method", "__call__", (self,) + args, kwargs ) @@ -686,7 +690,7 @@ def __iter__(self) -> Iterator["Proxy"]: return self.tracer.iter(self) - def __abs__(self): + def __abs__(self) -> "Proxy": return self.tracer.create_proxy("call_function", operator.abs, (self,), {}) def __bool__(self) -> bool: @@ -725,10 +729,10 @@ def __bool__(self) -> bool: return self.tracer.to_bool(self) @compatibility(is_backward_compatible=True) - def keys(self): + def keys(self) -> "Proxy": return self.tracer.keys(self) - def __len__(self): + def __len__(self) -> int: raise RuntimeError( "'len' is not supported in symbolic tracing by default. If you want " "this call to be recorded, please call torch.fx.wrap('len') at " @@ -736,13 +740,19 @@ def __len__(self): ) @classmethod - def __torch_function__(cls, orig_method, types, args=None, kwargs=None): + def __torch_function__( + cls, + orig_method: Callable[..., Any], + types: tuple[type, ...], + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, + ) -> "Proxy": args = args if args else () kwargs = kwargs if kwargs else {} - tracers: dict[Any, None] = {} + tracers: dict[TracerBase, None] = {} - def find_tracer(a): + def find_tracer(a: Any) -> None: if isinstance(a, cls): tracers[a.tracer] = None @@ -788,7 +798,9 @@ class MetaProxy(Proxy): A Proxy subclass that propagates metadata (meta['val']) during graph tracing. """ - def __init__(self, node: Node, tracer: "TracerBase | None" = None, fake_mode=None): + def __init__( + self, node: Node, tracer: "TracerBase | None" = None, fake_mode: Any = None + ) -> None: super().__init__(node, tracer) self.fake_mode = fake_mode @@ -796,7 +808,13 @@ def __repr__(self) -> str: return f"MetaProxy({self.node.name})" @classmethod - def __torch_function__(cls, orig_method, types, args=None, kwargs=None): + def __torch_function__( + cls, + orig_method: Callable[..., Any], + types: tuple[type, ...], + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, + ) -> "MetaProxy": args = args if args else () kwargs = kwargs if kwargs else {} @@ -823,14 +841,14 @@ def __torch_function__(cls, orig_method, types, args=None, kwargs=None): @compatibility(is_backward_compatible=True) class Attribute(Proxy): @compatibility(is_backward_compatible=True) - def __init__(self, root: Proxy, attr: str): + def __init__(self, root: Proxy, attr: str) -> None: self.root = root self.attr = attr self.tracer = root.tracer self._node: Node | None = None @property - def node(self): + def node(self) -> Node: # pyrefly: ignore[bad-override] # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: @@ -839,7 +857,7 @@ def node(self): ).node return self._node - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> "Proxy": return self.tracer.create_proxy( "call_method", self.attr, (self.root,) + args, kwargs ) @@ -853,7 +871,9 @@ class ParameterProxy(Proxy): so that conditional tests on these attributes will not throw exception during tracing """ - def __init__(self, tracer: TracerBase, node: Node, name, param): + def __init__( + self, tracer: TracerBase, node: Node, name: str, param: torch.nn.Parameter + ) -> None: super().__init__(node, tracer) if not isinstance(param, torch.nn.Parameter): raise AssertionError(f"Expected Parameter, got {type(param)}") @@ -864,30 +884,30 @@ def __repr__(self) -> str: return f"ParameterProxy({self.name})" @property - def shape(self): + def shape(self) -> torch.Size: return self.param.shape - def size(self): + def size(self) -> torch.Size: return self.param.size() - def dim(self): + def dim(self) -> int: return self.param.dim() @property - def ndim(self): + def ndim(self) -> int: return self.param.ndim - def numel(self): + def numel(self) -> int: return self.param.numel() - def nelement(self): + def nelement(self) -> int: return self.param.nelement() for method in magic_methods: - def _scope(method): - def impl(*args, **kwargs): + def _scope(method: str) -> None: + def impl(*args: Any, **kwargs: Any) -> "Proxy": tracer = args[0].tracer target = getattr(operator, method) return tracer.create_proxy("call_function", target, args, kwargs) @@ -899,10 +919,10 @@ def impl(*args, **kwargs): _scope(method) -def _define_reflectable(orig_method_name): +def _define_reflectable(orig_method_name: str) -> None: method_name = f"__r{orig_method_name.strip('_')}__" - def impl(self, rhs): + def impl(self: "Proxy", rhs: Any) -> "Proxy": target = getattr(operator, orig_method_name) return self.tracer.create_proxy("call_function", target, (rhs, self), {}) @@ -915,15 +935,15 @@ def impl(self, rhs): _define_reflectable(orig_method_name) -def _no_nodes_error(arg): +def _no_nodes_error(arg: Argument) -> Never: raise RuntimeError( "Keys for dictionaries used as an argument cannot contain a " f"Node. Got key: {arg}" ) -def _create_arg_dict(self, a): - r = {} +def _create_arg_dict(self: TracerBase, a: dict[Any, Any]) -> dict[Any, Argument]: + r: dict[Any, Argument] = {} for k, v in a.items(): if not isinstance(k, str): # Check for invalid dict keys. We do not want a Proxy to appear diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index 564c0aeadf8d6..71d884cc86e70 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -95,8 +95,8 @@ def try_get_attr(gm: torch.nn.Module, target: str) -> Any | None: @compatibility(is_backward_compatible=True) def replace_pattern( gm: GraphModule, - pattern: Callable | GraphModule, - replacement: Callable | GraphModule, + pattern: Callable[..., Any] | GraphModule, + replacement: Callable[..., Any] | GraphModule, ) -> list[Match]: """ Matches all possible non-overlapping sets of operators and their @@ -225,8 +225,8 @@ def forward(self, x, w1, w2): @compatibility(is_backward_compatible=False) def replace_pattern_with_filters( gm: GraphModule, - pattern: Callable | Graph | GraphModule, - replacement: Callable | Graph | GraphModule | None = None, + pattern: Callable[..., Any] | Graph | GraphModule, + replacement: Callable[..., Any] | Graph | GraphModule | None = None, match_filters: list[Callable[["InternalMatch", Graph, Graph], bool]] | None = None, ignore_literals: bool = False, # Placed at the end to avoid breaking backward compatibility @@ -261,8 +261,8 @@ def replace_pattern_with_filters( def _replace_pattern( gm: GraphModule, - pattern: Callable | Graph | GraphModule, - replacement: Callable | Graph | GraphModule | None = None, + pattern: Callable[..., Any] | Graph | GraphModule, + replacement: Callable[..., Any] | Graph | GraphModule | None = None, match_filters: list[Callable[["InternalMatch", Graph, Graph], bool]] | None = None, ignore_literals: bool = False, # Placed at the end to avoid breaking backward compatibility @@ -417,7 +417,9 @@ def _replace_pattern( f"{len(copied_returning_nodes)}" # pyrefly: ignore [bad-argument-type] ) for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): # type: ignore[arg-type] + # pyrefly: ignore [bad-argument-type] gn.replace_all_uses_with(copied_node) + # pyrefly: ignore [unsupported-operation] match_changed_node[gn] = copied_node # Remove the original nodes for node in reversed(pattern_graph.nodes): diff --git a/torch/fx/tensor_type.py b/torch/fx/tensor_type.py index 4f375e461ef28..6e04c7b1cc858 100644 --- a/torch/fx/tensor_type.py +++ b/torch/fx/tensor_type.py @@ -1,9 +1,21 @@ -# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Any, TYPE_CHECKING + from torch.fx.experimental.unification import Var # type: ignore[attr-defined] from ._compatibility import compatibility +if TYPE_CHECKING: + from collections.abc import Sequence + + from torch.fx.experimental.migrate_gradual_types.constraint import DVar + + +__all__ = ["Dyn", "TensorType", "is_consistent", "is_more_precise"] + + @compatibility(is_backward_compatible=False) class TensorType: """ @@ -14,21 +26,23 @@ def forward(self, x:TensorType((1,2,3, Dyn)), y:TensorType((1,2,3, Dyn))): return torch.add(x, y) """ - def __init__(self, dim): + __args__: Sequence[DVar | int | _DynType] + + def __init__(self, dim: Sequence[Any]) -> None: self.__origin__ = TensorType self.__args__ = dim - def __repr__(self): + def __repr__(self) -> str: return f"TensorType[{self.__args__}]" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): return list(self.__args__) == list(other.__args__) else: return False @staticmethod - def __class_getitem__(*args): + def __class_getitem__(*args: object) -> TensorType: if len(args) == 1 and isinstance(args[0], tuple): args = args[0] return TensorType(tuple(args)) @@ -42,13 +56,13 @@ class _DynType: def __init__(self) -> None: self.__name__ = "_DynType" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, self.__class__) - def __str__(self): + def __str__(self) -> str: return "Dyn" - def __repr__(self): + def __repr__(self) -> str: return "Dyn" @@ -56,7 +70,7 @@ def __repr__(self): @compatibility(is_backward_compatible=False) -def is_consistent(t1, t2): +def is_consistent(t1: object, t2: object) -> bool: """ A binary relation denoted by ~ that determines if t1 is consistent with t2. The relation is reflexive, symmetric but not transitive. @@ -84,7 +98,7 @@ def is_consistent(t1, t2): @compatibility(is_backward_compatible=False) -def is_more_precise(t1, t2): +def is_more_precise(t1: object, t2: object) -> bool: """ A binary relation denoted by <= that determines if t1 is more precise than t2. The relation is reflexive and transitive. diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index f13c2837b4344..485d2fbfa6a00 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -1,11 +1,11 @@ -# mypy: allow-untyped-defs import copy import logging import traceback from collections import defaultdict +from collections.abc import Callable, Iterator from contextlib import contextmanager from enum import Enum -from typing import Any, Optional, Union +from typing import Any, Optional, ParamSpec, TypeVar, Union from torch._utils_internal import signpost_event @@ -15,6 +15,9 @@ from .node import Node +_P = ParamSpec("_P") +_R = TypeVar("_R") + log = logging.getLogger(__name__) __all__ = [ @@ -82,7 +85,7 @@ class NodeSource: """ class NodeInfo: - def __init__(self, name: str, target: str, graph_id: int): + def __init__(self, name: str, target: str, graph_id: int) -> None: self.name = name self.target = target self.graph_id = graph_id @@ -99,7 +102,7 @@ def __init__( node: Node | None, pass_name: str = "", action: Union["NodeSourceAction", list["NodeSourceAction"]] | None = None, - ): + ) -> None: self.pass_name = pass_name if action is None: @@ -139,15 +142,15 @@ def target(self) -> str: def graph_id(self) -> int: return self.node_info.graph_id if self.node_info else -1 - def __repr__(self): + def __repr__(self) -> str: return self.print_readable() - def _get_action_string(self): + def _get_action_string(self) -> str: if self._action_string is None: self._action_string = "+".join([a.name.lower() for a in self.action]) return self._action_string - def print_readable(self, indent=0): + def print_readable(self, indent: int = 0) -> str: if indent > 9: return "" result = "" @@ -160,7 +163,7 @@ def print_readable(self, indent=0): result += item.print_readable(indent + 1) return result - def to_dict(self) -> dict: + def to_dict(self) -> dict[str, Any]: if self._dict is None: # Convert the object to a dictionary action_string = self._get_action_string() @@ -177,15 +180,15 @@ def to_dict(self) -> dict: raise AssertionError("_dict is None after initialization") return self._dict - def __eq__(self, other: object): + def __eq__(self, other: object) -> bool: if not isinstance(other, NodeSource): return False return self.to_dict() == other.to_dict() - def __hash__(self): + def __hash__(self) -> int: # Create a hash based on the dictionary representation # We need to convert the dict to a hashable form - def _make_hashable(obj): + def _make_hashable(obj: Any) -> Any: if isinstance(obj, dict): return tuple(sorted((k, _make_hashable(v)) for k, v in obj.items())) elif isinstance(obj, list): @@ -196,7 +199,7 @@ def _make_hashable(obj): return hash(_make_hashable(self.to_dict())) @classmethod - def _from_dict(cls, d: dict | None) -> Optional["NodeSource"]: + def _from_dict(cls, d: dict[str, Any] | None) -> Optional["NodeSource"]: """ Recursively deserialize from_node metadata from dictionary data. It is used to deserialize the from_node field from serialized metadata. @@ -253,7 +256,7 @@ def _from_dict(cls, d: dict | None) -> Optional["NodeSource"]: @compatibility(is_backward_compatible=False) @contextmanager -def preserve_node_meta(enable=True): +def preserve_node_meta(enable: bool = True) -> Iterator[None]: global should_preserve_node_meta global current_meta saved_should_preserve_node_meta = should_preserve_node_meta @@ -268,7 +271,7 @@ def preserve_node_meta(enable=True): @contextmanager -def _preserve_node_seq_nr(preserve_seq_nr=True): +def _preserve_node_seq_nr(preserve_seq_nr: bool = True) -> Iterator[None]: """ Temporarily enables or disables the preservation of node.meta["seq_nr"] in the tracing context. @@ -284,7 +287,7 @@ def _preserve_node_seq_nr(preserve_seq_nr=True): @compatibility(is_backward_compatible=False) -def set_stack_trace(stack: list[str]): +def set_stack_trace(stack: list[str]) -> None: global current_meta if should_preserve_node_meta: @@ -298,7 +301,7 @@ def set_stack_trace(stack: list[str]): @compatibility(is_backward_compatible=False) @contextmanager -def annotate(annotation_dict: dict): +def annotate(annotation_dict: dict[str, Any]) -> Iterator[None]: """ Temporarily adds custom annotations to the current tracing context. The fx_node produced from this tracing context will have the @@ -338,7 +341,7 @@ def annotate(annotation_dict: dict): try: if not has_custom: - current_meta["custom"] = {} + current_meta["custom"] = dict[str, Any]() # Update with all key-value pairs from the input dict current_meta["custom"].update(annotation_dict) @@ -352,7 +355,9 @@ def annotate(annotation_dict: dict): @compatibility(is_backward_compatible=False) -def annotate_fn(annotation_dict: dict): +def annotate_fn( + annotation_dict: dict[str, Any], +) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]: """ A decorator that wraps a function with the annotate context manager. Use this when you want to annotate an entire function instead of a specific code block. @@ -377,19 +382,53 @@ def annotate_fn(annotation_dict: dict): """ from functools import wraps - def decorator(func): + def decorator(func: Callable[_P, _R]) -> Callable[_P, _R]: @wraps(func) - def wrapper(*args, **kwargs): + # NB: Do not annotate with _P.args/_P.kwargs here. Dynamo guards on + # the identity of ParamSpec annotation objects, causing guard failures. + def wrapper(*args: Any, **kwargs: Any) -> Any: with annotate(annotation_dict): return func(*args, **kwargs) - return wrapper + return wrapper # type: ignore[return-value] return decorator +@contextmanager +def _set_autograd_backward(enable: bool = True) -> Iterator[None]: + global current_meta + + had_autograd_backward = "autograd_backward" in current_meta + old_autograd_backward = current_meta.get("autograd_backward", False) + + if enable: + _mark_autograd_backward() + try: + yield + finally: + if had_autograd_backward: + current_meta["autograd_backward"] = old_autograd_backward + else: + _reset_autograd_backward() + + @compatibility(is_backward_compatible=False) -def set_grad_fn_seq_nr(seq_nr): +def _mark_autograd_backward() -> None: + global current_meta + + current_meta["autograd_backward"] = True + + +@compatibility(is_backward_compatible=False) +def _reset_autograd_backward() -> None: + global current_meta + + current_meta.pop("autograd_backward", None) + + +@compatibility(is_backward_compatible=False) +def set_grad_fn_seq_nr(seq_nr: int) -> None: global current_meta if should_preserve_node_meta: @@ -401,7 +440,7 @@ def set_grad_fn_seq_nr(seq_nr): @compatibility(is_backward_compatible=False) -def reset_grad_fn_seq_nr(): +def reset_grad_fn_seq_nr() -> None: # NB: reset state properly, this would be helpful towards supporting # reentrant autograd if we actually wanted to do that. global current_meta @@ -437,7 +476,7 @@ def _is_preserving_node_seq_nr() -> bool: @compatibility(is_backward_compatible=False) @contextmanager -def set_current_meta(node, pass_name=""): +def set_current_meta(node: Node, pass_name: str = "") -> Iterator[None]: global current_meta if should_preserve_node_meta and node.meta: saved_meta = current_meta @@ -465,7 +504,7 @@ def get_current_meta() -> dict[str, Any]: @compatibility(is_backward_compatible=False) @contextmanager -def set_current_replay_node(node): +def set_current_replay_node(node: Node | None) -> Iterator[None]: """ Set the currently replay node. If `current_replay_node` is not None, then we're re-generating the `current_replay_node` in FunctionalTensorMode. @@ -481,7 +520,7 @@ def set_current_replay_node(node): @compatibility(is_backward_compatible=False) -def get_current_replay_node(): +def get_current_replay_node() -> Node | None: """ Get the currently replay node """ @@ -522,7 +561,7 @@ def _get_custom_metadata(gm: GraphModule) -> str: if not isinstance(gm, GraphModule): raise AssertionError(f"Expected GraphModule, got {type(gm)}") - def helper(gm: GraphModule): + def helper(gm: GraphModule) -> list[Any]: custom_metadata = [] for node in gm.graph.nodes: if hasattr(node, "meta") and node.meta.get("custom", None): @@ -530,7 +569,10 @@ def helper(gm: GraphModule): if node.op == "get_attr" and isinstance( getattr(gm, node.target), GraphModule ): - custom_metadata.append(helper(getattr(gm, node.target))) + custom_metadata.append( + # pyrefly: ignore[bad-argument-type] + helper(getattr(gm, node.target)) + ) return custom_metadata return "\n".join(str(x) for x in helper(gm)) diff --git a/torch/headeronly/CMakeLists.txt b/torch/headeronly/CMakeLists.txt index 93d2d7802b528..e6065f96f12e5 100644 --- a/torch/headeronly/CMakeLists.txt +++ b/torch/headeronly/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.27 FATAL_ERROR) project(headeronly CXX) -set(CMAKE_CXX_STANDARD 17 CACHE STRING "The C++ standard whose features are requested to build this target.") +set(CMAKE_CXX_STANDARD 20 CACHE STRING "The C++ standard whose features are requested to build this target.") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # Main build file for torch/headeronly, except there's no build cuz this lib is header-only! diff --git a/torch/headeronly/core/Dispatch_v2.h b/torch/headeronly/core/Dispatch_v2.h index 47ca18ce79c54..aa1bc1d9e9f72 100644 --- a/torch/headeronly/core/Dispatch_v2.h +++ b/torch/headeronly/core/Dispatch_v2.h @@ -54,6 +54,11 @@ torch::headeronly::ScalarType::UInt16, \ torch::headeronly::ScalarType::UInt32, \ torch::headeronly::ScalarType::UInt64 +#define AT_OPAQUE_TYPES \ + torch::headeronly::ScalarType::Byte, torch::headeronly::ScalarType::UInt16, \ + torch::headeronly::ScalarType::UInt32, \ + torch::headeronly::ScalarType::UInt64, \ + torch::headeronly::ScalarType::ComplexDouble #define AT_INTEGRAL_TYPES_V2 \ AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) #define AT_COMPLEX_TYPES \ diff --git a/torch/headeronly/macros/Macros.h b/torch/headeronly/macros/Macros.h index 880e741abf62b..cef99df3f566f 100644 --- a/torch/headeronly/macros/Macros.h +++ b/torch/headeronly/macros/Macros.h @@ -325,41 +325,88 @@ constexpr uint32_t CUDA_THREADS_PER_BLOCK_FALLBACK = 256; #define C10_HIP_HOST_DEVICE #endif -#if defined(USE_ROCM) // C10_WARP_SIZE is only allowed for device code. -// Host code _must_ use at::cuda::warp_size() +// Host code dynamically-sized launch configs _must_ use at::cuda::warp_size(). +// Host or device statically-sized arrays _must_ use either +// C10_WARP_SIZE_UPPER_BOUND or C10_WARP_SIZE_LOWER_BOUND, as needed. +// // HIP header used to define warpSize as a constexpr that was either 32 or 64 // depending on the target device, and then always set it to 64 for host code. -// Host pass of HIP compiler needs C10_WARP_SIZE defined to _something_ so we -// set it to something unreasonable to trigger obvious host code errors. - +// For a time, that allowed C10_WARP_SIZE to be defined like so: +// +// #ifdef USE_ROCM +// #define C10_WARP_SIZE warpSize +// #else +// #define C10_WARP_SIZE 32 +// #endif +// +// In ROCm 7, warpSize is no longer constexpr, matching CUDA behavior. +// We can now only use warpSize for C10_WARP_SIZE in device code and this is +// enforced by using __device__ in its definition. In host code where +// C10_WARP_SIZE was previously used as a compile-time constant, this will now +// cause a compile-time error. +// +// If an array was previously expected to be sized at compile-time using +// C10_WARP_SIZE, users must now use either C10_WARP_SIZE_UPPER_BOUND or +// C10_WARP_SIZE_LOWER_BOUND depending on the situation. +// +// If C10_WARP_SIZE was previously used to determine kernel launch sizes, users +// must now use at::cuda::warp_size() for the dynamic runtime query. +// +// Unfortunately, C10_WARP_SIZE has been public and available for both host and +// device since approximately 2019, so forcing it to be device-only would break +// existing code in the wild. +#if defined(USE_ROCM) namespace at::cuda { TORCH_CUDA_CPP_API int warp_size(); } -#ifdef __HIPCC__ -static inline int __host__ C10_WARP_SIZE_INTERNAL() { +#if defined(__HIPCC__) +static __host__ inline int C10_WARP_SIZE_INTERNAL() { return at::cuda::warp_size(); } - -static inline constexpr int __device__ C10_WARP_SIZE_INTERNAL() { +// NOTE: __device__ C10_WARP_SIZE_INTERNAL +// For __SPIRV__, we must use dynamic warpSize. When not targeting __SPIRV__, +// we can use constexpr. This matches prior behavior. We preserve this for +// backward compatibility instead of forcing old code to use dynamic warpSize +// and losing constexpr. However, compiling for --offload-arch=amdgcnspirv +// could expose where C10_WARP_SIZE was used incorrectly where the dynamic +// warpSize is not allowed. +#if defined(__SPIRV__) +static __device__ inline int C10_WARP_SIZE_INTERNAL() { + return warpSize; +} +#else // __SPIRV__ +static __device__ inline constexpr int C10_WARP_SIZE_INTERNAL() { #if defined(__GFX9__) return 64; #else // __GFX9__ return 32; #endif // __GFX9__ } -#else // __HIPCC__ +#endif // __SPIRV__ +#if defined(__SPIRV__) +#define C10_WARP_SIZE_LOWER_BOUND 32 +#define C10_WARP_SIZE_UPPER_BOUND 64 +#elif defined(__GFX9__) +#define C10_WARP_SIZE_LOWER_BOUND 64 +#define C10_WARP_SIZE_UPPER_BOUND 64 +#else +#define C10_WARP_SIZE_LOWER_BOUND 32 +#define C10_WARP_SIZE_UPPER_BOUND 32 +#endif +#else // !__HIPCC__ static inline int C10_WARP_SIZE_INTERNAL() { return at::cuda::warp_size(); } +#define C10_WARP_SIZE_LOWER_BOUND 32 +#define C10_WARP_SIZE_UPPER_BOUND 64 #endif // __HIPCC__ - #define C10_WARP_SIZE (C10_WARP_SIZE_INTERNAL()) -#define C10_WARP_SIZE_STATIC 64 - -#else // defined(USE_ROCM) +#else // !USE_ROCM #define C10_WARP_SIZE 32 -#endif +#define C10_WARP_SIZE_LOWER_BOUND 32 +#define C10_WARP_SIZE_UPPER_BOUND 32 +#endif // USE_ROCM #if defined(_MSC_VER) && _MSC_VER <= 1900 #define __func__ __FUNCTION__ diff --git a/torch/headeronly/util/BFloat16.h b/torch/headeronly/util/BFloat16.h index 64479ba36f125..9aa08c265bd2c 100644 --- a/torch/headeronly/util/BFloat16.h +++ b/torch/headeronly/util/BFloat16.h @@ -12,7 +12,7 @@ #include #include -#if defined(__CUDACC__) && !defined(USE_ROCM) +#if defined(__CUDACC__) && (!defined(USE_ROCM) || (TORCH_HIP_VERSION >= 702)) #include #endif @@ -46,7 +46,7 @@ struct alignas(2) BFloat16 { /* implicit */ inline C10_HOST_DEVICE BFloat16(float value); inline C10_HOST_DEVICE operator float() const; -#if defined(__CUDACC__) && !defined(USE_ROCM) +#if defined(__CUDACC__) && (!defined(USE_ROCM) || (TORCH_HIP_VERSION >= 702)) inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; #endif @@ -124,8 +124,9 @@ C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") /// Constructors inline C10_HOST_DEVICE BFloat16::BFloat16(float value) : -#if defined(__CUDACC__) && !defined(USE_ROCM) && defined(__CUDA_ARCH__) && \ - __CUDA_ARCH__ >= 800 +#if defined(__CUDACC__) && \ + (!defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 || \ + defined(USE_ROCM) && (TORCH_HIP_VERSION >= 702)) x(__bfloat16_as_ushort(__float2bfloat16(value))) #elif defined(__SYCL_DEVICE_ONLY__) && \ defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) @@ -139,7 +140,7 @@ inline C10_HOST_DEVICE BFloat16::BFloat16(float value) /// Implicit conversions inline C10_HOST_DEVICE BFloat16::operator float() const { -#if defined(__CUDACC__) && !defined(USE_ROCM) +#if defined(__CUDACC__) && (!defined(USE_ROCM) || (TORCH_HIP_VERSION >= 702)) return __bfloat162float(*reinterpret_cast(&x)); #elif defined(__SYCL_DEVICE_ONLY__) && \ defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) @@ -149,7 +150,7 @@ inline C10_HOST_DEVICE BFloat16::operator float() const { #endif } -#if defined(__CUDACC__) && !defined(USE_ROCM) +#if defined(__CUDACC__) && (!defined(USE_ROCM) || (TORCH_HIP_VERSION >= 702)) inline C10_HOST_DEVICE BFloat16::BFloat16(const __nv_bfloat16& value) { x = *reinterpret_cast(&value); } diff --git a/torch/headeronly/util/Float8_e4m3fn.h b/torch/headeronly/util/Float8_e4m3fn.h index b35fb1a7aa85d..1ba2938c367a1 100644 --- a/torch/headeronly/util/Float8_e4m3fn.h +++ b/torch/headeronly/util/Float8_e4m3fn.h @@ -49,6 +49,7 @@ struct alignas(1) Float8_e4m3fn { inline C10_HOST_DEVICE Float8_e4m3fn(float value); inline C10_HOST_DEVICE operator float() const; inline C10_HOST_DEVICE bool isnan() const; + inline C10_HOST_DEVICE bool isinf() const; }; inline std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value) { @@ -203,8 +204,13 @@ inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) { f_bits ^= sign; if (f_bits >= fp8_max) { - // NaN - all exponent and mantissa bits set to 1 - result = 0x7f; + if (f_bits > UINT32_C(0x7F800000)) { + // NaN input → NaN output + result = 0x7f; + } else { + // Finite overflow or inf → saturate to max finite value + result = 0x7e; + } } else { if (f_bits < (UINT32_C(121) << 23)) { // Input number is smaller than 2^(-6), which is the smallest @@ -224,6 +230,11 @@ inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) { // take the bits! result = static_cast(f_bits >> 20); + + // Rounding may carry into the NaN bit pattern (0x7f); saturate to max + if (result == 0x7f) { + result = 0x7e; + } } } @@ -256,6 +267,11 @@ inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const { return (x & 0b01111111) == 0b01111111; } +inline C10_HOST_DEVICE bool Float8_e4m3fn::isinf() const { + // Note: fp8e4m3fn does not have infinity, so this always returns false. + return false; +} + /// Arithmetic inline C10_HOST_DEVICE Float8_e4m3fn diff --git a/torch/headeronly/util/Float8_e4m3fnuz.h b/torch/headeronly/util/Float8_e4m3fnuz.h index e361a2f92a2a5..4c11b3be05593 100644 --- a/torch/headeronly/util/Float8_e4m3fnuz.h +++ b/torch/headeronly/util/Float8_e4m3fnuz.h @@ -52,6 +52,7 @@ struct alignas(1) Float8_e4m3fnuz { inline C10_HOST_DEVICE Float8_e4m3fnuz(float value); inline C10_HOST_DEVICE operator float() const; inline C10_HOST_DEVICE bool isnan() const; + inline C10_HOST_DEVICE bool isinf() const; }; inline std::ostream& operator<<( @@ -160,6 +161,11 @@ inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const { return x == 0b10000000; } +inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isinf() const { + // Note: fp8e4m3fnuz does not have infinity, so this always returns false. + return false; +} + /// Arithmetic inline C10_HOST_DEVICE Float8_e4m3fnuz diff --git a/torch/hub.py b/torch/hub.py index cc43e1c2450b4..72694e4b253b2 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -15,7 +15,7 @@ from typing import Any from typing_extensions import deprecated from urllib.error import HTTPError, URLError -from urllib.parse import urlparse # noqa: F401 +from urllib.parse import urlparse from urllib.request import Request, urlopen import torch diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py index 34caf60599a94..8590b73a5808a 100644 --- a/torch/jit/_builtins.py +++ b/torch/jit/_builtins.py @@ -101,6 +101,7 @@ (torch.autograd.grad, "aten::grad"), (torch.autograd.backward, "aten::backward"), (torch._C._infer_size, "aten::_infer_size"), + (torch.broadcast_shapes, "aten::broadcast_shapes"), ( torch.nn.functional._no_grad_embedding_renorm_, # type: ignore[attr-defined] "aten::_no_grad_embedding_renorm_", diff --git a/torch/jit/_monkeytype_config.py b/torch/jit/_monkeytype_config.py index c4cf1089494ce..cea7bd2f91cf5 100644 --- a/torch/jit/_monkeytype_config.py +++ b/torch/jit/_monkeytype_config.py @@ -171,7 +171,7 @@ class JitTypeTraceConfig: # type: ignore[no-redef] def __init__(self) -> None: pass - monkeytype_trace = None # type: ignore[assignment] # noqa: F811 + monkeytype_trace = None # type: ignore[assignment] def jit_code_filter(code: CodeType) -> bool: diff --git a/torch/jit/_script.py b/torch/jit/_script.py index fb723c62ff658..502fbce14d215 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -284,7 +284,7 @@ def __getitem__(self, k): # parameters are initialized _before_ the script compiler resolve references to # `self.param` or `self.module`. class ScriptMeta(type): - def __init__(cls, name, bases, attrs): # noqa: B902 + def __init__(cls, name, bases, attrs): # Aggregate all the ScriptMethods and constants from superclasses cls._methods: dict[str, Any] = {} cls._constants_set = set(getattr(cls, "__constants__", ())) diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index c09de1971a4c4..cd870b3b2dbcf 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -145,7 +145,7 @@ def is_function_or_method(the_callable): def is_vararg(the_callable): - if not is_function_or_method(the_callable) and callable(the_callable): # noqa: B004 + if not is_function_or_method(the_callable) and callable(the_callable): # If `the_callable` is a class, de-sugar the call so we can still get # the signature the_callable = the_callable.__call__ @@ -164,7 +164,7 @@ def get_param_names(fn, n_args): not is_function_or_method(fn) and callable(fn) and is_function_or_method(fn.__call__) - ): # noqa: B004 + ): # De-sugar calls to classes fn = fn.__call__ @@ -271,7 +271,7 @@ def get_type_line(source): "The annotation prefix in line " + str(wrong_type_lines[0][0]) + " is probably invalid.\nIt must be '# type:'" - + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" # noqa: B950 + + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" + "\nfor examples" ) return None diff --git a/torch/jit/supported_ops.py b/torch/jit/supported_ops.py index 8a258280ea352..ee4f74945f5b0 100644 --- a/torch/jit/supported_ops.py +++ b/torch/jit/supported_ops.py @@ -104,7 +104,7 @@ def _get_nn_functional_ops(): scripted = torch.jit.script(attr) scripted_schema = scripted.schema functions.append(_emit_schema(name, elem, scripted_schema)) - except: # noqa: B001,E722 + except: # noqa: E722 # Skip interpolate / boolean dispatched things pass diff --git a/torch/lib/libshm/CMakeLists.txt b/torch/lib/libshm/CMakeLists.txt index c3cd26fea7bf3..688e7e7929147 100644 --- a/torch/lib/libshm/CMakeLists.txt +++ b/torch/lib/libshm/CMakeLists.txt @@ -21,7 +21,7 @@ target_include_directories(shm PUBLIC set_target_properties(shm PROPERTIES PREFIX "lib" IMPORT_PREFIX "lib" - CXX_STANDARD 17) + CXX_STANDARD 20) target_link_libraries(shm PRIVATE ${TORCH_CPU_LIB}) if(UNIX AND NOT APPLE) diff --git a/torch/lib/libshm/manager.cpp b/torch/lib/libshm/manager.cpp index 5d2c318c25142..3084f794154a9 100644 --- a/torch/lib/libshm/manager.cpp +++ b/torch/lib/libshm/manager.cpp @@ -46,12 +46,7 @@ static void register_fd(int fd) { } static void unregister_fd(int fd) { - pollfds.erase( - std::remove_if( - pollfds.begin(), - pollfds.end(), - [fd](const struct pollfd& pfd) { return pfd.fd == fd; }), - pollfds.end()); + std::erase_if(pollfds, [fd](const struct pollfd& pfd) { return pfd.fd == fd; }); client_sessions.erase(fd); } diff --git a/torch/library.py b/torch/library.py index fe02a55ffd72e..fd73bef35bd90 100644 --- a/torch/library.py +++ b/torch/library.py @@ -4,7 +4,7 @@ import inspect import re import sys -import traceback +import warnings import weakref from collections.abc import Callable, Sequence from typing import Any, overload, TYPE_CHECKING, TypeVar, Union @@ -20,7 +20,7 @@ device_types_t, ) from torch._library.effects import EffectType -from torch._library.infer_schema import infer_schema # noqa: F401 +from torch._library.infer_schema import infer_schema from torch._library.triton import triton_op, wrap_triton from torch._ops import OpOverload from torch.types import _dtype @@ -65,6 +65,79 @@ def fallthrough_kernel(): raise NotImplementedError("fallthrough_kernel() should never be called.") +def _validate_out_schema(schema: "str | torch._C.FunctionSchema") -> None: + """Validate that a schema has valid out semantics, i.e., it can be tagged with torch.Tag.out. + + Requirements: + - Must have at least one mutable argument + - All returns must alias the mutable args in declaration order + + torchgen has equivalent checks (torchgen/model.py), but we reimplement them here + because (1) it's simple and (2) torchgen uses a different schema object + (torchgen.model.FunctionSchema vs torch._C.FunctionSchema) so it's difficult to + share the function. + """ + if isinstance(schema, str): + schema = torch._C.parse_schema(schema) + mutable_args = [ + arg + for arg in schema.arguments + if arg.alias_info is not None and arg.alias_info.is_write + ] + if not mutable_args: + raise ValueError( + f"Schema tagged with torch.Tag.out must have at least one mutable argument. " + f"Got: {schema}" + ) + positional_mutable = [arg for arg in mutable_args if not arg.kwarg_only] + if positional_mutable: + names = [a.name for a in positional_mutable] + raise ValueError( + f"Schema tagged with torch.Tag.out requires all mutable arguments to be " + f"keyword-only (after the *). Found mutable positional args: {names}. " + f"Got: {schema}" + ) + unsupported_mutable = [ + arg + for arg in mutable_args + if isinstance(arg.type, (torch.OptionalType, torch.ListType)) + ] + if unsupported_mutable: + names = [a.name for a in unsupported_mutable] + raise ValueError( + f"Schema tagged with torch.Tag.out only supports Tensor mutable arguments. " + f"Found unsupported mutable args: {names}. Got: {schema}" + ) + returns = schema.returns + if len(returns) != len(mutable_args): + raise ValueError( + f"Schema tagged with torch.Tag.out must return all mutable arguments " + f"(got {len(mutable_args)} mutable args but {len(returns)} returns). " + f"Got: {schema}" + ) + for i, (ret, arg) in enumerate(zip(returns, mutable_args, strict=True)): + arg_alias = arg.alias_info + ret_alias = ret.alias_info + if ret_alias is None: + raise ValueError( + f"Return {i} of schema tagged with torch.Tag.out must alias mutable arg '{arg.name}'. " + f"Got: {schema}" + ) + if not ret_alias.is_write: + raise ValueError( + f"Return {i} of schema tagged with torch.Tag.out must be a mutable alias " + f"(e.g., Tensor(a!), not Tensor(a)) of arg '{arg.name}'. " + f"Got: {schema}" + ) + # arg_alias is guaranteed non-None by the mutable_args filter above + if ret_alias.before_set != arg_alias.before_set: # type: ignore[union-attr] + raise ValueError( + f"Return {i} of schema tagged with torch.Tag.out must alias mutable arg '{arg.name}' " + f"(return aliases {ret_alias.before_set} but arg aliases {arg_alias.before_set}). " # type: ignore[union-attr] + f"Got: {schema}" + ) + + class Library: """ A class to create libraries that can be used to register new operators or @@ -95,8 +168,8 @@ def __init__(self, ns, kind, dispatch_key=""): f"{ns} is a reserved namespace. Please try creating a library with another name." ) - frame = traceback.extract_stack(limit=2)[0] - filename, lineno = frame.filename, frame.lineno + f = sys._getframe(1) + filename, lineno = f.f_code.co_filename, f.f_lineno self.m: Any | None = torch._C._dispatch_library( kind, ns, dispatch_key, filename, lineno ) @@ -161,6 +234,9 @@ def define(self, schema, alias_analysis="", *, tags=()): getattr(torch.ops, self.ns), packet_name ) + if torch.Tag.out in tags: + _validate_out_schema(schema) + result = self.m.define(schema, alias_analysis, tuple(tags)) name = schema.split("(")[0] qualname = self.ns + "::" + name @@ -343,6 +419,19 @@ def impl( if "::" not in dispatcher_op_name: dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}" + op = torch._library.utils.lookup_op(dispatcher_op_name) + if torch._library.utils.is_out(op) and not torch._library.utils.is_builtin( + op + ): + warnings.warn( + f"Registering a Meta kernel for operator '{dispatcher_op_name}' " + f"which has torch.Tag.out. Operators with Tag.out automatically " + f"get a fake kernel that returns the out= arguments. We " + f"recommend not registering a fake/meta kernel manually " + f"because it is easy to get wrong.", + stacklevel=2, + ) + # Internally, we shouldn't be registering meta kernels for any operators that # have CompositeImplicitAutograd kernels. # Instead, we should be letting those decompositions run, and writing meta kernels diff --git a/torch/mps/__init__.py b/torch/mps/__init__.py index c7125f1ef2eec..ab09f17507a08 100644 --- a/torch/mps/__init__.py +++ b/torch/mps/__init__.py @@ -163,6 +163,38 @@ def compile_shader(source: str): return torch._C._mps_compileShader(source) +def load_metallib(source): + r"""Loads a precompiled Metal library (.metallib) and returns a shader + library object that allows invoking kernels defined in it. + + Args: + source: Either raw metallib bytes (``bytes``/``bytearray``) or a + filesystem path (``str``/``os.PathLike``) to a ``.metallib`` file. + + This is useful for loading Metal libraries compiled ahead of time or + generated by external tools (e.g. Triton, MetalASM). + + Example:: + + >>> # xdoctest: +SKIP("requires external .metallib file") + >>> lib = torch.mps.load_metallib("kernels.metallib") + >>> x = torch.ones(16, device="mps") + >>> lib.square(x) + """ + import os + + if isinstance(source, (bytes, bytearray)): + if not hasattr(torch._C, "_mps_loadMetalllib"): + raise RuntimeError("MPS is not available") + return torch._C._mps_loadMetalllib(bytes(source)) + elif isinstance(source, (str, os.PathLike)): + if not hasattr(torch._C, "_mps_loadMetallibFromPath"): + raise RuntimeError("MPS is not available") + return torch._C._mps_loadMetallibFromPath(str(source)) + else: + raise TypeError(f"expected bytes or path, got {type(source).__name__}") + + def is_available() -> bool: return device_count() > 0 @@ -173,6 +205,7 @@ def is_available() -> bool: __all__ = [ "compile_shader", + "load_metallib", "device_count", "get_rng_state", "manual_seed", diff --git a/torch/mtia/_utils.py b/torch/mtia/_utils.py index 16710c244c61c..8720f72b8620b 100644 --- a/torch/mtia/_utils.py +++ b/torch/mtia/_utils.py @@ -21,7 +21,7 @@ def _get_device_index( """ if device is None and optional: - # If device is None (frequent), then we can can short-circuit the logic + # If device is None (frequent), then we can short-circuit the logic return torch._C._mtia_getDevice() if isinstance(device, int): return device diff --git a/torch/multiprocessing/__init__.py b/torch/multiprocessing/__init__.py index d73f3d028820b..25da01502a7e1 100644 --- a/torch/multiprocessing/__init__.py +++ b/torch/multiprocessing/__init__.py @@ -29,7 +29,7 @@ from multiprocessing import * # noqa: F403 -__all__ += multiprocessing.__all__ # noqa: PLE0605 type: ignore[attr-defined] +__all__ += multiprocessing.__all__ # This call adds a Linux specific prctl(2) wrapper function to this module. diff --git a/torch/nativert/executor/ConstantFolder.cpp b/torch/nativert/executor/ConstantFolder.cpp index bbae8d1d426ac..0cad2841e2391 100644 --- a/torch/nativert/executor/ConstantFolder.cpp +++ b/torch/nativert/executor/ConstantFolder.cpp @@ -102,7 +102,7 @@ void ConstantFolder::unlinkConstants( if (user == output) { continue; } - if (foldable.find(user) == foldable.end()) { + if (!foldable.contains(user)) { constFoldableCandidates.push(user); } } @@ -115,7 +115,7 @@ void ConstantFolder::unlinkConstants( // we only store folded values if there is a non-foldable user if (const auto& users = value->users(); std::any_of(users.begin(), users.end(), [&](const auto* u) { - return foldable.find(u) == foldable.end(); + return !foldable.contains(u); })) { foldedOutputValueIds_.insert(value->id()); } @@ -132,12 +132,7 @@ void ConstantFolder::unlinkConstants( // remove moved (i.e., associated w/ const-folded nodes) kernels // from the input kernel vector - kernels.erase( - std::remove_if( - kernels.begin(), - kernels.end(), - [](const auto& k) { return k == nullptr; }), - kernels.end()); + std::erase(kernels, nullptr); graph_.renumberValues(); graph_.finalize(); diff --git a/torch/nativert/executor/GraphExecutorBase.cpp b/torch/nativert/executor/GraphExecutorBase.cpp index a623d5873ea56..5cec883376e1f 100644 --- a/torch/nativert/executor/GraphExecutorBase.cpp +++ b/torch/nativert/executor/GraphExecutorBase.cpp @@ -79,6 +79,21 @@ ProfileMetrics GraphExecutorBase::benchmarkIndividualNodes( } } + // Capture input element counts per node (frame is populated after warmup). + results.inputElementsPerNode.resize(numNodes, 0); + for (const auto i : c10::irange(numNodes)) { + int64_t total_elements = 0; + for (const auto& input : nodeKernels_[i]->node()->inputs()) { + if (input.value && input.value->type().kind() == Type::Kind::Tensor) { + const auto& iv = executionFrame.getIValue(input.value->id()); + if (iv.isTensor()) { + total_elements += iv.toTensor().numel(); + } + } + } + results.inputElementsPerNode[i] = total_elements; + } + // Execute kernels caffe2::Timer timer; executionFrame.withManagedMemory([&](auto) { diff --git a/torch/nativert/executor/GraphExecutorBase.h b/torch/nativert/executor/GraphExecutorBase.h index dfe020ebae29e..d78d28253e201 100644 --- a/torch/nativert/executor/GraphExecutorBase.h +++ b/torch/nativert/executor/GraphExecutorBase.h @@ -21,6 +21,7 @@ struct ProfileMetrics { std::unordered_map instancesPerNodeType; std::unordered_set staticDispatchNodes; std::unordered_set primNodes; + std::vector inputElementsPerNode; float totalTime{0}; std::string name; }; diff --git a/torch/nativert/executor/PlacementUtils.cpp b/torch/nativert/executor/PlacementUtils.cpp index 7dc6e761a0c94..97a2642323363 100644 --- a/torch/nativert/executor/PlacementUtils.cpp +++ b/torch/nativert/executor/PlacementUtils.cpp @@ -22,6 +22,5 @@ bool isSameDevice(const c10::Device& a, const c10::Device& b) { return b.is_mtia(); } TORCH_CHECK(false, "isSameDevice: Unsupported device type ", a, " and ", b); - return false; } } // namespace torch::nativert diff --git a/torch/nativert/graph/Graph.cpp b/torch/nativert/graph/Graph.cpp index bb524d7d39bcf..3531f12c6a685 100644 --- a/torch/nativert/graph/Graph.cpp +++ b/torch/nativert/graph/Graph.cpp @@ -964,10 +964,7 @@ void Value::addUser(Node* node) { } void Value::eraseUser(Node* node) { - users_.erase( - std::remove_if( - users_.begin(), users_.end(), [&](Node* el) { return el == node; }), - users_.end()); + std::erase(users_, node); } std::vector Value::getListElements() const { @@ -1083,6 +1080,15 @@ std::ostream& operator<<(std::ostream& out, const Constant& constant) { out << fmt::format("{}", fmt::streamed(inner_list)); } out << ']'; + } else if constexpr (is_same_v>>) { + out << '['; + for (const auto& [idx, inner_list] : c10::enumerate(arg)) { + if (idx > 0) { + out << ", "; + } + out << fmt::format("{}", fmt::streamed(inner_list)); + } + out << ']'; } else if constexpr (is_same_v>) { out << fmt::format(""); VLOG(0) << "Subgraph pretty print is not implemented"; diff --git a/torch/nativert/graph/Graph.h b/torch/nativert/graph/Graph.h index ce5dd8e412e3d..6d8c449d15c14 100644 --- a/torch/nativert/graph/Graph.h +++ b/torch/nativert/graph/Graph.h @@ -100,6 +100,7 @@ using Constant = std::variant< std::vector, std::vector, std::vector>, + std::vector>, std::unique_ptr>; c10::IValue constantToIValue(const Constant& constant); diff --git a/torch/nativert/graph/GraphSignature.cpp b/torch/nativert/graph/GraphSignature.cpp index 7adff3b6a895a..84b4c10192564 100644 --- a/torch/nativert/graph/GraphSignature.cpp +++ b/torch/nativert/graph/GraphSignature.cpp @@ -258,6 +258,8 @@ GraphSignature::GraphSignature(const torch::_export::GraphSignature& storage) { userInputArg.tag() == torch::_export::Argument::Tag::AS_SYM_FLOATS || userInputArg.tag() == torch::_export::Argument::Tag::AS_INT_LISTS || + userInputArg.tag() == + torch::_export::Argument::Tag::AS_FLOAT_LISTS || userInputArg.tag() == torch::_export::Argument::Tag::AS_SCALAR_TYPE || userInputArg.tag() == diff --git a/torch/nativert/graph/Serialization.cpp b/torch/nativert/graph/Serialization.cpp index 98bcc34a0cbfd..65444b0d84df6 100644 --- a/torch/nativert/graph/Serialization.cpp +++ b/torch/nativert/graph/Serialization.cpp @@ -354,7 +354,7 @@ std::unique_ptr jsonToSubgraph( } else if (arg.tag() == torch::_export::Argument::Tag::AS_NONE) { node->addInput(NamedArgument{ input.get_name(), - graph->addValue(std::nullopt, Type::Kind::None, node)}); + graph->addValue(std::nullopt, Type::Kind::None, nullptr)}); } else { node->addAttribute(Attribute{ input.get_name(), @@ -731,6 +731,17 @@ Constant constantToValue( } return ret; } + case torch::_export::Argument::Tag::AS_FLOAT_LISTS: { + std::vector> ret; + for (const auto& inner_list : jsonArg.get_as_float_lists()) { + std::vector inner_ret; + for (const auto& val : inner_list) { + inner_ret.push_back(val.get()); + } + ret.push_back(inner_ret); + } + return ret; + } case torch::_export::Argument::Tag::AS_STRING_TO_ARGUMENT: return None(); default: diff --git a/torch/nativert/python/Bindings.cpp b/torch/nativert/python/Bindings.cpp index d83c0bbc6ed43..de86306adea0a 100644 --- a/torch/nativert/python/Bindings.cpp +++ b/torch/nativert/python/Bindings.cpp @@ -3,6 +3,7 @@ #include #include +#ifdef FBCODE_CAFFE2 #include namespace py = pybind11; @@ -81,3 +82,22 @@ void initModelRunnerPybind(py::module& m) { } } // namespace torch::nativert + +#else // !FBCODE_CAFFE2 + +namespace py = pybind11; + +namespace torch::nativert { + +class StubModelRunner {}; + +// PyModelRunner is referenced from +// https://github.com/pytorch/benchmark/blob/b8d35ba51a3149b7212888b4010ddee97f19947f/userbenchmark/dynamo/dynamobench/common.py#L45 +void initModelRunnerPybind(py::module& m) { + py::class_>( + m, "PyModelRunner"); +} + +} // namespace torch::nativert + +#endif // FBCODE_CAFFE2 diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 2f3c0e56e0b74..2d5d35316cb2c 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -194,4 +194,4 @@ def _get_flash_version() -> str: restore_flash_attention_impl.__module__ = __name__ # Import built-in implementations to trigger self-registration -from . import _fa3, _fa4 # noqa: F401 # noqa: F401 +from . import _fa3, _fa4 diff --git a/torch/nn/attention/_fa4.py b/torch/nn/attention/_fa4.py index f0ea99a463532..c2786d9dceefe 100644 --- a/torch/nn/attention/_fa4.py +++ b/torch/nn/attention/_fa4.py @@ -67,6 +67,11 @@ def _fa4_import_module(module_path: str) -> ModuleType: def _fa4_register_kernels() -> Library: lib = Library("aten", "IMPL", "CUDA") # noqa: TOR901 lib.impl("_flash_attention_forward", _fa4_flash_attention_forward_impl, "CUDA") + lib.impl( + "_flash_attention_forward_no_dropout_inplace", + _fa4_flash_attention_forward_no_dropout_inplace_impl, + "CUDA", + ) lib.impl("_flash_attention_backward", _fa4_flash_attention_backward_impl, "CUDA") lib.impl( "_scaled_dot_product_flash_attention", @@ -116,6 +121,8 @@ def _fa4_forward_support_error( alibi_slopes: torch.Tensor | None, seqused_k: torch.Tensor | None, cum_seq_q: torch.Tensor | None, + block_table: torch.Tensor | None = None, + num_splits: int | None = None, ) -> str | None: if dropout_p != 0.0: return "dropout_p must be 0" @@ -128,6 +135,11 @@ def _fa4_forward_support_error( return "seqused_k must be int32" if not seqused_k.is_cuda: return "seqused_k must be CUDA" + major = _get_device_major(query.device) + if block_table is not None and major != 10: + return f"paged KV (block_table) not supported on SM {major}0" + if num_splits is not None and num_splits > 1 and major != 10: + return f"SplitKV (num_splits > 1) not supported on SM {major}0" error = _fa4_common_support_error( query, (query, key, value), @@ -149,13 +161,9 @@ def _fa4_backward_support_error( logsumexp: torch.Tensor, dropout_p: float, cum_seq_q: torch.Tensor | None, - window_size_left: int | None, - window_size_right: int | None, ) -> str | None: if dropout_p != 0.0: return "dropout_p must be 0" - if window_size_left is not None or window_size_right is not None: - return "windowed attention not supported" error = _fa4_common_support_error( query, (grad_out, query, key, value, out, logsumexp), @@ -167,6 +175,11 @@ def _fa4_backward_support_error( return None +def _aten_to_fa4_window_size(val: int | None) -> int | None: + """need to convert -1 to None for FA4""" + return None if val == -1 else val + + Ts = TypeVarTuple("Ts") @@ -180,12 +193,16 @@ def _fa4_run_forward( value: torch.Tensor, cu_seq_q: torch.Tensor | None, cu_seq_k: torch.Tensor | None, + max_q: int | None, + max_k: int | None, scale: float | None, is_causal: bool, window_size_left: int | None, window_size_right: int | None, seqused_k: torch.Tensor | None, out: torch.Tensor | None = None, + block_table: torch.Tensor | None = None, + num_splits: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if _FA4_MODULE_PATH is None: raise RuntimeError("FA4 not registered") @@ -194,15 +211,18 @@ def _fa4_run_forward( kwargs: dict[str, Any] = { "softmax_scale": scale, "causal": is_causal, - "window_size_left": window_size_left, - "window_size_right": window_size_right, + "window_size_left": _aten_to_fa4_window_size(window_size_left), + "window_size_right": _aten_to_fa4_window_size(window_size_right), "return_lse": True, "cu_seqlens_q": cu_seq_q, "cu_seqlens_k": cu_seq_k, + "max_seqlen_q": max_q, + "max_seqlen_k": max_k, "seqused_k": seqused_k.contiguous() if seqused_k is not None else None, + "page_table": block_table, + "num_splits": num_splits or 1, + "out": out, } - if out is not None: - kwargs["out"] = out out, lse = module._flash_attn_fwd(query, key, value, **kwargs) return out, lse.contiguous() @@ -218,6 +238,8 @@ def _fa4_run_backward( cu_seq_k: torch.Tensor | None, scale: float | None, is_causal: bool, + window_size_left: int | None, + window_size_right: int | None, deterministic: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if _FA4_MODULE_PATH is None: @@ -232,6 +254,8 @@ def _fa4_run_backward( logsumexp.contiguous(), softmax_scale=scale, causal=is_causal, + window_size_left=_aten_to_fa4_window_size(window_size_left), + window_size_right=_aten_to_fa4_window_size(window_size_right), cu_seqlens_q=cu_seq_q, cu_seqlens_k=cu_seq_k, deterministic=deterministic, @@ -257,6 +281,9 @@ def _fa4_flash_attention_forward_impl( seqused_k: torch.Tensor | None = None, alibi_slopes: torch.Tensor | None = None, out: torch.Tensor | None = None, + block_table: torch.Tensor | None = None, + compute_auxiliary: bool = True, + num_splits: int | None = None, ): error = _fa4_forward_support_error( query, @@ -267,6 +294,8 @@ def _fa4_flash_attention_forward_impl( alibi_slopes, seqused_k, cum_seq_q, + block_table, + num_splits, ) if error is not None: raise RuntimeError(f"FA4 flash_attention forward unsupported: {error}") @@ -276,19 +305,73 @@ def _fa4_flash_attention_forward_impl( value, cum_seq_q, cum_seq_k, + max_q, + max_k, scale, is_causal, window_size_left, window_size_right, seqused_k, out, + block_table, + num_splits, ) - rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device) - philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device) - debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) + if compute_auxiliary: + rng_state = torch.zeros((2,), dtype=torch.uint64, device=query.device) + philox_offset = torch.zeros((), dtype=torch.uint64, device=query.device) + debug_mask = torch.empty(0, dtype=query.dtype, device=query.device) + else: + rng_state = None + philox_offset = None + debug_mask = None return out, lse, rng_state, philox_offset, debug_mask +def _fa4_flash_attention_forward_no_dropout_inplace_impl( + out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cum_seq_q: torch.Tensor | None, + cum_seq_k: torch.Tensor | None, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + return_debug_mask: bool, + *, + scale: float | None = None, + window_size_left: int | None = None, + window_size_right: int | None = None, + seqused_k: torch.Tensor | None = None, + alibi_slopes: torch.Tensor | None = None, + block_table: torch.Tensor | None = None, + num_splits: int | None = None, +): + _, lse, _, _, _ = _fa4_flash_attention_forward_impl( + query, + key, + value, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + return_debug_mask, + scale=scale, + window_size_left=window_size_left, + window_size_right=window_size_right, + seqused_k=seqused_k, + alibi_slopes=alibi_slopes, + out=out, + block_table=block_table, + compute_auxiliary=False, + num_splits=num_splits, + ) + return lse + + def _fa4_flash_attention_backward_impl( grad_out: torch.Tensor, query: torch.Tensor, @@ -318,8 +401,6 @@ def _fa4_flash_attention_backward_impl( logsumexp, dropout_p, cum_seq_q, - window_size_left, - window_size_right, ) if error is not None: raise RuntimeError(f"FA4 flash_attention backward unsupported: {error}") @@ -335,6 +416,8 @@ def _fa4_flash_attention_backward_impl( cum_seq_k, scale, is_causal, + window_size_left, + window_size_right, deterministic, ) return dq, dk, dv @@ -428,8 +511,6 @@ def _fa4_scaled_dot_product_flash_attention_backward_impl( logsumexp, dropout_p, None, - None, - None, ) if error is not None: raise RuntimeError(f"FA4 SDPA backward unsupported: {error}") diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index 7365e726b055c..2e41ad96fa931 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -1,5 +1,4 @@ # mypy: allow-untyped-defs -# flake8: noqa: B950 """This module implements the user facing API for flex_attention in PyTorch.""" from __future__ import annotations @@ -9,6 +8,7 @@ import itertools import math import operator +import types import typing import warnings from collections.abc import Callable @@ -21,11 +21,18 @@ from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop from torch._higher_order_ops.utils import setup_compilation_env from torch.nn.attention._utils import _validate_sdpa_input -from torch.utils._pytree import GetAttrKey, tree_map_only +from torch.utils._pytree import ( + GetAttrKey, + tree_flatten, + tree_map_only, + tree_unflatten, + TreeSpec, +) if typing.TYPE_CHECKING: from torch._prims_common import DeviceLikeType + from torch.fx.node import BaseArgumentTypes # Private debug flag to disable internal compilation wrapping for debugging purposes. @@ -442,49 +449,228 @@ def _adjust_num_blocks_and_indices( return num_blocks, indices -def _closure_contents(fn: object) -> tuple[object, ...]: - """Extract closure cell contents for comparison.""" - closure = getattr(fn, "__closure__", None) - if closure is None: - return () - return tuple(cell.cell_contents for cell in closure) +# TreeSpec for an empty tuple — used as the sentinel when there are no closure leaves. +_EMPTY_CLOSURE_SPEC = tree_flatten(())[1] + + +class _ExtractedLeaf: + """Sentinel in _StrippedClosure.leaf_entries marking a position that is + filled from the extracted pytree leaves list during reconstruction.""" + + __slots__ = () + + def __repr__(self) -> str: + return "_EXTRACTED_LEAF" + + +_EXTRACTED_LEAF = _ExtractedLeaf() + + +class _FunctionLeaf(typing.NamedTuple): + """Entry in _StrippedClosure.leaf_entries for a recursively processed + function. Stores enough information to reconstruct the function from + the extracted leaves during unflattening.""" + + stripped: _StrippedClosure | Callable[..., Any] + closure_spec: TreeSpec + n_extracted: int # number of extracted leaves this function contributes + + +class _StrippedClosure(typing.NamedTuple): + """Data container holding the parts of a function needed for reconstruction. + + Created by _extract_closure_pytree when closure tensors are lifted into + pytree leaves. Unlike a FunctionType with None-filled cells, this is not + callable — it is pure data stored in the pytree context. + """ + + code: types.CodeType + globals_dict: dict[str, Any] + name: str + qualname: str + defaults: tuple[Any, ...] | None + kwdefaults: dict[str, Any] | None + extra_dict: dict[str, Any] + # Per-position info for the closure's flattened leaves. + # _EXTRACTED_LEAF → position filled from the extracted leaves list. + # _FunctionLeaf → recursively processed function (not a valid pytree leaf). + leaf_entries: tuple[_ExtractedLeaf | _FunctionLeaf, ...] + + +def _extract_closure_pytree( + fn, _seen: set[int] | None = None +) -> tuple[ + tuple[BaseArgumentTypes, ...], TreeSpec, _StrippedClosure | Callable[..., Any] +]: + """Extract closure contents as a flattened sub-pytree. + + Returns (extracted_leaves, closure_spec, fn_or_stripped) where: + - extracted_leaves: flattened non-function contents from the closure, + plus any tensors/scalars recursively extracted from nested function + closures + - closure_spec: TreeSpec describing how to reconstruct the closure contents + - fn_or_stripped: either the original fn (no extraction) or a + _StrippedClosure carrying the function parts needed for reconstruction + + Functions found among the closure leaves are recursively processed: their + own closure tensors are extracted into the leaves list, and their skeleton + is stored in _StrippedClosure.leaf_entries as a _FunctionLeaf. All other + values (tensors, scalars, None, etc.) remain as extracted leaves. + + If fn is not a plain function, has no closure, or has empty cells, returns + the original function unchanged with no closure leaves. + + Skipped under Dynamo tracing (torch.compiler.is_compiling) because Dynamo + can't trace through closure cell introspection and handles freevars via its + own lifting mechanism. + """ + if not inspect.isfunction(fn) or torch.compiler.is_compiling(): + return (), _EMPTY_CLOSURE_SPEC, fn + + # Cycle detection for self-referencing closures. + if _seen is None: + _seen = set() + if id(fn) in _seen: + return (), _EMPTY_CLOSURE_SPEC, fn + _seen.add(id(fn)) + + closure = fn.__closure__ + if not closure: + return (), _EMPTY_CLOSURE_SPEC, fn + + try: + contents = tuple(cell.cell_contents for cell in closure) + except ValueError: + # Empty cell (created but not yet assigned) — can't extract + return (), _EMPTY_CLOSURE_SPEC, fn + + closure_leaves, closure_spec = tree_flatten(contents) + + extracted: list[BaseArgumentTypes] = [] + leaf_entries: list[_ExtractedLeaf | _FunctionLeaf] = [] + for leaf in closure_leaves: + if inspect.isfunction(leaf): + child_extracted, child_spec, child_stripped = _extract_closure_pytree( + leaf, _seen + ) + extracted.extend(child_extracted) + leaf_entries.append( + _FunctionLeaf(child_stripped, child_spec, len(child_extracted)) + ) + else: + extracted.append(leaf) + leaf_entries.append(_EXTRACTED_LEAF) + + stripped = _StrippedClosure( + code=fn.__code__, + globals_dict=fn.__globals__, + name=fn.__name__, + qualname=fn.__qualname__, + defaults=fn.__defaults__, + kwdefaults=fn.__kwdefaults__, + extra_dict=dict(fn.__dict__) if fn.__dict__ else {}, + leaf_entries=tuple(leaf_entries), + ) + + return tuple(extracted), closure_spec, stripped + + +def _reconstruct_closure_fn(stripped, extracted_leaves, closure_spec): + """Rebuild a function from a _StrippedClosure and flattened extracted leaves.""" + if not isinstance(stripped, _StrippedClosure): + return stripped + + all_leaves: list[BaseArgumentTypes | Callable[..., Any]] = [] + idx = 0 + for entry in stripped.leaf_entries: + if isinstance(entry, _FunctionLeaf): + child_fn = _reconstruct_closure_fn( + entry.stripped, + extracted_leaves[idx : idx + entry.n_extracted], + entry.closure_spec, + ) + all_leaves.append(child_fn) + idx += entry.n_extracted + else: + # _EXTRACTED_LEAF — take from extracted leaves + all_leaves.append(extracted_leaves[idx]) + idx += 1 + + contents = tree_unflatten(all_leaves, closure_spec) + new_cells = tuple(types.CellType(v) for v in contents) + + restored = types.FunctionType( + stripped.code, + stripped.globals_dict, + stripped.name, + stripped.defaults, + new_cells, + ) + restored.__qualname__ = stripped.qualname + if stripped.kwdefaults: + restored.__kwdefaults__ = stripped.kwdefaults + if stripped.extra_dict: + restored.__dict__.update(stripped.extra_dict) + + return restored class _MaskModWrapper: - """Wraps a mask_mod function with value-based equality. + """Wraps a mask_mod or _StrippedClosure with value-based equality. BlockMask stores an arbitrary callable (mask_mod) in its pytree context. The default __eq__ for functions uses identity comparison, which is too strict when the same closure is recreated (e.g., defined inside forward()). - This wrapper compares functions by their code object and closure contents. + + When closure tensors have been extracted (by _extract_closure_pytree), fn + is a _StrippedClosure (pure data, not callable). Equality compares the + code objects + closure_spec — no tensor dispatch is triggered. + + When extraction is skipped (e.g., under Dynamo), fn is the original + callable and equality compares code objects + closure contents (for plain + functions) or delegates to __eq__ (for callable objects). """ - __slots__ = ("fn",) + __slots__ = ("fn", "closure_spec") - def __init__(self, fn: _mask_mod_signature) -> None: + def __init__(self, fn, closure_spec=None) -> None: self.fn = fn + self.closure_spec = closure_spec def __call__(self, b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor: + if isinstance(self.fn, _StrippedClosure): + raise RuntimeError( + "_MaskModWrapper with _StrippedClosure is not callable — " + "use _reconstruct_closure_fn to rebuild the function first" + ) return self.fn(b, h, q_idx, kv_idx) def __eq__(self, other: object) -> bool: if not isinstance(other, _MaskModWrapper): return False - if self.fn is other.fn: + if self.fn is other.fn and self.closure_spec is other.closure_spec: return True - if ( - inspect.isfunction(self.fn) - and inspect.isfunction(other.fn) - and self.fn.__code__ == other.fn.__code__ - and _closure_contents(self.fn) == _closure_contents(other.fn) + # Extracted case: _StrippedClosure — compare code + closure_spec + if isinstance(self.fn, _StrippedClosure) and isinstance( + other.fn, _StrippedClosure + ): + return ( + self.fn.code == other.fn.code + and self.closure_spec == other.closure_spec + ) + # Non-extracted plain functions: compare code + closure contents + if inspect.isfunction(self.fn) and inspect.isfunction(other.fn): + return self.fn.__code__ == other.fn.__code__ + # Callable objects: delegate to their __eq__ + if not isinstance(self.fn, _StrippedClosure) and not isinstance( + other.fn, _StrippedClosure ): - return True - # For callable objects (not plain functions), delegate to their __eq__ - if not inspect.isfunction(self.fn) and not inspect.isfunction(other.fn): return self.fn == other.fn return False def __hash__(self) -> int: + if isinstance(self.fn, _StrippedClosure): + return hash(self.fn.code) if inspect.isfunction(self.fn): return hash(self.fn.__code__) return hash(self.fn) @@ -1043,30 +1229,43 @@ def _unwrap_context_value(attr: str, value: Any) -> Any: def _flatten( self, - ) -> tuple[tuple[Tensor | None, ...], tuple[Any, ...]]: + ) -> tuple[tuple[BaseArgumentTypes | None, ...], tuple[Any, ...]]: """Flatten BlockMask into a list of tensors and context. - Wraps mask_mod in _MaskModWrapper for value-based comparison in TreeSpec. + Closure tensors from mask_mod are extracted into the leaves via + _extract_closure_pytree so they are visible to the tracing + infrastructure (instead of being hidden in the pytree context). """ tensors = tuple(getattr(self, attr) for attr in self._TENSOR_ATTRS) + closure_leaves, closure_spec, stripped = _extract_closure_pytree(self.mask_mod) + all_leaves = tensors + closure_leaves context = tuple( self._wrap_context_value(attr, getattr(self, attr)) + if attr != "mask_mod" + else _MaskModWrapper(stripped, closure_spec) for attr in self._CONTEXT_ATTRS ) - return tensors, context + return all_leaves, context @classmethod def _unflatten( cls, - tensors: tuple[Tensor | None, ...], + leaves: tuple[Any, ...], context: tuple[Any, ...], ) -> Self: - """Unflatten tensors and context back into a BlockMask.""" - kwargs = { - attr: cls._unwrap_context_value(attr, val) - for attr, val in zip(cls._CONTEXT_ATTRS, context) - } - kwargs.update(zip(cls._TENSOR_ATTRS, tensors)) + """Unflatten leaves and context back into a BlockMask.""" + n_regular = len(cls._TENSOR_ATTRS) + regular_leaves = leaves[:n_regular] + closure_leaves = leaves[n_regular:] + kwargs = {} + for attr, val in zip(cls._CONTEXT_ATTRS, context): + if attr == "mask_mod" and isinstance(val, _MaskModWrapper): + kwargs[attr] = _reconstruct_closure_fn( + val.fn, closure_leaves, val.closure_spec + ) + else: + kwargs[attr] = cls._unwrap_context_value(attr, val) + kwargs.update(zip(cls._TENSOR_ATTRS, regular_leaves)) return cls(**kwargs) def _flatten_with_keys( @@ -1074,16 +1273,25 @@ def _flatten_with_keys( ) -> tuple[tuple[tuple[GetAttrKey, Any], ...], tuple[tuple[GetAttrKey, Any], ...]]: """Flatten BlockMask with keys for better tracing. - Wraps mask_mod in _MaskModWrapper for value-based comparison in TreeSpec. + Closure tensors from mask_mod are extracted into the leaves via + _extract_closure_pytree so they are visible to the tracing + infrastructure (instead of being hidden in the pytree context). """ tensors = tuple( (GetAttrKey(attr), getattr(self, attr)) for attr in self._TENSOR_ATTRS ) + closure_leaves, closure_spec, stripped = _extract_closure_pytree(self.mask_mod) + closure_with_keys = tuple( + (GetAttrKey(f"_closure_{i}"), leaf) for i, leaf in enumerate(closure_leaves) + ) + all_leaves = tensors + closure_with_keys context = tuple( (GetAttrKey(attr), self._wrap_context_value(attr, getattr(self, attr))) + if attr != "mask_mod" + else (GetAttrKey(attr), _MaskModWrapper(stripped, closure_spec)) for attr in self._CONTEXT_ATTRS ) - return tensors, context + return all_leaves, context def _broadcast_to_dim(x: Tensor, dim: int) -> Tensor: diff --git a/torch/nn/attention/varlen.py b/torch/nn/attention/varlen.py index 4401754e9395b..96cfc332b8699 100644 --- a/torch/nn/attention/varlen.py +++ b/torch/nn/attention/varlen.py @@ -54,6 +54,7 @@ def _varlen_attn( is_causal: bool = False, scale: float | None = None, window_size: list[int] | None = None, + enable_gqa: bool = False, seqused_k: torch.Tensor | None = None, block_table: torch.Tensor | None = None, num_splits: int | None = None, @@ -70,6 +71,9 @@ def _varlen_attn( if use_cudnn: log.info("Using cuDNN backend for varlen_attn") + if enable_gqa: + # TODO: check this + raise RuntimeError("GQA is not supported with the cuDNN backend.") if num_splits is not None: # TODO: check this raise RuntimeError("num_splits is not supported with the cuDNN backend.") @@ -140,6 +144,7 @@ def _varlen_attn_fake( is_causal: bool = False, scale: float | None = None, window_size: list[int] | None = None, + enable_gqa: bool = False, seqused_k: torch.Tensor | None = None, block_table: torch.Tensor | None = None, num_splits: int | None = None, @@ -159,16 +164,18 @@ def _varlen_attn_fake( # For varlen path: logsumexp shape is (num_heads, total_q) total_q = query.size(0) num_heads = query.size(1) + logsumexp = torch.empty( + (num_heads, total_q), dtype=torch.float, device=query.device + ) + if torch.version.hip: - # ROCm uses batched format: [batch_size, num_heads, max_q] - batch_size = cu_seq_q.size(0) - 1 - logsumexp = torch.empty( - (batch_size, num_heads, max_q), dtype=torch.float, device=query.device - ) - else: - logsumexp = torch.empty( - (num_heads, total_q), dtype=torch.float, device=query.device - ) + preferred = torch._C._get_rocm_fa_preferred_backend() + if preferred == torch._C._ROCmFABackend.AOTriton: + # AOTriton ROCm path uses batched 3D + batch_size = cu_seq_q.size(0) - 1 + logsumexp = torch.empty( + (batch_size, num_heads, max_q), dtype=torch.float, device=query.device + ) rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device) @@ -187,6 +194,7 @@ def varlen_attn( return_aux: AuxRequest | None = None, scale: float | None = None, window_size: tuple[int, int] = (-1, -1), + enable_gqa: bool = False, seqused_k: torch.Tensor | None = None, block_table: torch.Tensor | None = None, num_splits: int | None = None, @@ -197,11 +205,11 @@ def varlen_attn( variable-length sequences using cumulative sequence position tensors. Args: - query (Tensor): Query tensor; shape :math:`(T_q, H, D)` - key (Tensor): Key tensor; shape :math:`(T_k, H, D)`, or - :math:`(\text{total\_pages}, \text{page\_size}, H, D)` when ``block_table`` is provided. - value (Tensor): Value tensor; shape :math:`(T_k, H, D)`, or - :math:`(\text{total\_pages}, \text{page\_size}, H, D)` when ``block_table`` is provided. + query (Tensor): Query tensor; shape :math:`(T_q, H_q, D)` + key (Tensor): Key tensor; shape :math:`(T_k, H_{kv}, D)`, or + :math:`(\text{total\_pages}, \text{page\_size}, H_{kv}, D)` when ``block_table`` is provided. + value (Tensor): Value tensor; shape :math:`(T_k, H_{kv}, D)`, or + :math:`(\text{total\_pages}, \text{page\_size}, H_{kv}, D)` when ``block_table`` is provided. cu_seq_q (Tensor): Cumulative sequence positions for queries; shape :math:`(N+1,)` cu_seq_k (Tensor): Cumulative sequence positions for keys/values; shape :math:`(N+1,)` max_q (int): Maximum query sequence length in the batch. @@ -211,6 +219,11 @@ def varlen_attn( window_size (tuple[int, int], optional): Window size for sliding window attention as (left, right). Use (-1, -1) for full attention (default), (-1, 0) for causal attention, or (W, 0) for causal attention with sliding window of size W. + enable_gqa (bool): If set to True, enables Grouped Query Attention (GQA) + and allows key/value to have fewer heads than query. + Each KV head is shared by a group of :math:`H_q / H_{kv}` query heads, + so :math:`H_q` must be divisible by :math:`H_{kv}`. + Default is False. seqused_k (Tensor, optional): Number of valid KV tokens per batch element; shape :math:`(N,)`. When set, only the first ``seqused_k[i]`` tokens in the key/value sequence for batch element *i* participate in attention. Useful for KV-cache decoding where the cache slot @@ -239,16 +252,17 @@ def varlen_attn( ``None`` (default), the kernel chooses automatically. Returns: - output (Tensor): Output tensor from attention computation; shape :math:`(T_q, H, D)`. + output (Tensor): Output tensor from attention computation; shape :math:`(T_q, H_q, D)`. If ``return_aux`` is not None and ``return_aux.lse`` is True: - lse (Tensor): Log-sum-exp of attention scores; shape :math:`(T_q, H)`. + lse (Tensor): Log-sum-exp of attention scores; shape :math:`(T_q, H_q)`. Shape legend: - :math:`N`: Batch size - :math:`T_q`: Total number of query tokens in the batch (sum of all query sequence lengths) - :math:`T_k`: Total number of key/value tokens in the batch (sum of all key/value sequence lengths) - - :math:`H`: Number of attention heads + - :math:`H_q`: Number of query attention heads + - :math:`H_{kv}`: Number of key/value attention heads (equal to :math:`H_q` unless GQA is enabled) - :math:`D`: Head dimension Example:: @@ -285,6 +299,20 @@ def varlen_attn( ... ) """ + num_heads_q = query.size(1) + num_heads_k = key.size(2) if block_table is not None else key.size(1) + if not enable_gqa and num_heads_q != num_heads_k: + raise ValueError( + f"Expect query and key/value to have the same number of heads " + f"but got Hq={num_heads_q} and Hkv={num_heads_k}. " + f"Try setting enable_gqa=True for GQA." + ) + if enable_gqa and num_heads_q % num_heads_k != 0: + raise ValueError( + f"Expect number of query heads to be a multiple of kv heads for GQA " + f"but got Hq={num_heads_q} and Hkv={num_heads_k}." + ) + is_causal = window_size == (-1, 0) out, lse, _ = torch.ops.torch_attn._varlen_attn( query, @@ -297,6 +325,7 @@ def varlen_attn( is_causal, scale, list(window_size), + enable_gqa, seqused_k, block_table, num_splits, @@ -319,6 +348,7 @@ def _varlen_attn_out( is_causal: bool = False, scale: float | None = None, window_size: list[int] | None = None, + enable_gqa: bool = False, seqused_k: torch.Tensor | None = None, block_table: torch.Tensor | None = None, num_splits: int | None = None, @@ -372,6 +402,7 @@ def _varlen_attn_out_fake( is_causal: bool = False, scale: float | None = None, window_size: list[int] | None = None, + enable_gqa: bool = False, seqused_k: torch.Tensor | None = None, block_table: torch.Tensor | None = None, num_splits: int | None = None, @@ -409,6 +440,7 @@ def varlen_attn_out( return_aux: AuxRequest | None = None, scale: float | None = None, window_size: tuple[int, int] = (-1, -1), + enable_gqa: bool = False, seqused_k: torch.Tensor | None = None, block_table: torch.Tensor | None = None, num_splits: int | None = None, @@ -419,6 +451,20 @@ def varlen_attn_out( instead of allocating a new one. """ + num_heads_q = query.size(1) + num_heads_k = key.size(2) if block_table is not None else key.size(1) + if not enable_gqa and num_heads_q != num_heads_k: + raise ValueError( + f"Expect query and key/value to have the same number of heads " + f"but got Hq={num_heads_q} and Hkv={num_heads_k}. " + f"Try setting enable_gqa=True for GQA." + ) + if enable_gqa and num_heads_q % num_heads_k != 0: + raise ValueError( + f"Expect number of query heads to be a multiple of kv heads for GQA " + f"but got Hq={num_heads_q} and Hkv={num_heads_k}." + ) + is_causal = window_size == (-1, 0) lse = torch.ops.torch_attn._varlen_attn_out( out, @@ -432,6 +478,7 @@ def varlen_attn_out( is_causal, scale, list(window_size), + enable_gqa, seqused_k, block_table, num_splits, @@ -453,6 +500,7 @@ def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None: is_causal, scale, window_size, + enable_gqa, seqused_k, block_table, num_splits, @@ -598,7 +646,9 @@ def _backward( scale, window_size, ) - num_params = 10 # cu_seq_q, cu_seq_k, max_q, max_k, is_causal, scale, window_size, seqused_k, block_table, num_splits + # cu_seq_q, cu_seq_k, max_q, max_k, is_causal, scale, window_size, \ + # enable_gqa, seqused_k, block_table, num_splits + num_params = 11 return (dq, dk, dv, *((None,) * num_params)) @@ -607,3 +657,15 @@ def _backward( torch._dynamo.disallow_in_graph( torch.ops.aten._flash_attention_forward_no_dropout_inplace ) + +from torch.utils.flop_counter import ( + _varlen_attn_backward_flop, + _varlen_attn_forward_flop, + _varlen_attn_out_flop, + flop_registry, +) + + +flop_registry[torch.ops.torch_attn._varlen_attn] = _varlen_attn_forward_flop +flop_registry[torch.ops.torch_attn._varlen_attn_out] = _varlen_attn_out_flop +flop_registry[torch.ops.torch_attn._varlen_attn_backward] = _varlen_attn_backward_flop diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 3f4c51890524a..3db2df066619b 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -10,7 +10,6 @@ from torch import _VF, sym_int as _sym_int, Tensor from torch._C import ( _add_docstr, - _infer_size, _ScalingType as ScalingType, # pyrefly: ignore [missing-module-attribute] _SwizzleType as SwizzleType, # pyrefly: ignore [missing-module-attribute] ) @@ -146,7 +145,7 @@ >>> inputs = torch.randn(1, 4, 5, 5) >>> F.conv2d(inputs, filters, padding=1) """, -) # noqa: E501 +) conv3d = _add_docstr( torch.conv3d, @@ -196,7 +195,7 @@ >>> inputs = torch.randn(20, 16, 50, 10, 20) >>> F.conv3d(inputs, filters) """, -) # noqa: E501 +) conv_transpose1d = _add_docstr( torch.conv_transpose1d, @@ -280,7 +279,7 @@ >>> weights = torch.randn(4, 8, 3, 3) >>> F.conv_transpose2d(inputs, weights, padding=1) """, -) # noqa: E501 +) conv_transpose3d = _add_docstr( torch.conv_transpose3d, @@ -322,7 +321,7 @@ >>> weights = torch.randn(16, 33, 3, 3, 3) >>> F.conv_transpose3d(inputs, weights) """, -) # noqa: E501 +) conv_tbc = _add_docstr( torch.conv_tbc, @@ -443,7 +442,7 @@ def fractional_max_pool2d_with_indices( output_ratio: Optional[BroadcastingList2[float]] = None, # noqa: UP045 return_indices: bool = False, _random_samples: Tensor | None = None, -) -> tuple[Tensor, Tensor]: # noqa: D400 +) -> tuple[Tensor, Tensor]: r""" fractional_max_pool2d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) @@ -556,7 +555,7 @@ def fractional_max_pool3d_with_indices( output_ratio: Optional[BroadcastingList3[float]] = None, # noqa: UP045 return_indices: bool = False, _random_samples: Tensor | None = None, -) -> tuple[Tensor, Tensor]: # noqa: D400 +) -> tuple[Tensor, Tensor]: r""" fractional_max_pool3d(input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None) @@ -674,7 +673,7 @@ def max_pool1d_with_indices( dilation: BroadcastingList1[int] = 1, ceil_mode: bool = False, return_indices: bool = False, -) -> tuple[Tensor, Tensor]: # noqa: D400 +) -> tuple[Tensor, Tensor]: r""" max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) @@ -764,7 +763,7 @@ def max_pool2d_with_indices( dilation: BroadcastingList2[int] = 1, ceil_mode: bool = False, return_indices: bool = False, -) -> tuple[Tensor, Tensor]: # noqa: D400 +) -> tuple[Tensor, Tensor]: r""" max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) @@ -854,7 +853,7 @@ def max_pool3d_with_indices( dilation: BroadcastingList3[int] = 1, ceil_mode: bool = False, return_indices: bool = False, -) -> tuple[Tensor, Tensor]: # noqa: D400 +) -> tuple[Tensor, Tensor]: r""" max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False) @@ -1201,7 +1200,7 @@ def adaptive_max_pool1d_with_indices( input: Tensor, output_size: BroadcastingList1[int], return_indices: bool = False, -) -> tuple[Tensor, Tensor]: # noqa: D400 +) -> tuple[Tensor, Tensor]: r""" adaptive_max_pool1d(input, output_size, return_indices=False) @@ -1256,7 +1255,7 @@ def adaptive_max_pool2d_with_indices( input: Tensor, output_size: BroadcastingList2[int], return_indices: bool = False, -) -> tuple[Tensor, Tensor]: # noqa: D400 +) -> tuple[Tensor, Tensor]: r"""adaptive_max_pool2d(input, output_size, return_indices=False) Applies a 2D adaptive max pooling over an input signal composed of @@ -1313,7 +1312,7 @@ def adaptive_max_pool3d_with_indices( input: Tensor, output_size: BroadcastingList3[int], return_indices: bool = False, -) -> tuple[Tensor, Tensor]: # noqa: D400 +) -> tuple[Tensor, Tensor]: r""" adaptive_max_pool3d(input, output_size, return_indices=False) @@ -1709,7 +1708,7 @@ def _threshold( ) -def relu(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 +def relu(input: Tensor, inplace: bool = False) -> Tensor: r"""relu(input, inplace=False) -> Tensor Applies the rectified linear unit function element-wise. See @@ -1734,7 +1733,7 @@ def relu(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 ) -def glu(input: Tensor, dim: int = -1) -> Tensor: # noqa: D400,D402 +def glu(input: Tensor, dim: int = -1) -> Tensor: r""" glu(input, dim=-1) -> Tensor @@ -1766,7 +1765,7 @@ def hardtanh( min_val: float = -1.0, max_val: float = 1.0, inplace: bool = False, -) -> Tensor: # noqa: D400,D402 +) -> Tensor: r""" hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor @@ -1796,7 +1795,7 @@ def hardtanh( ) -def relu6(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 +def relu6(input: Tensor, inplace: bool = False) -> Tensor: r"""relu6(input, inplace=False) -> Tensor Applies the element-wise function :math:`\text{ReLU6}(x) = \min(\max(0,x), 6)`. @@ -1836,7 +1835,7 @@ def elu(input: Tensor, alpha: float = 1.0, inplace: bool = False) -> Tensor: ) -def selu(input: Tensor, inplace: bool = False) -> Tensor: # noqa: D400,D402 +def selu(input: Tensor, inplace: bool = False) -> Tensor: r"""selu(input, inplace=False) -> Tensor Applies element-wise, @@ -1869,7 +1868,7 @@ def celu( input: Tensor, alpha: float = 1.0, inplace: bool = False, -) -> Tensor: # noqa: D400,D402 +) -> Tensor: r"""celu(input, alpha=1., inplace=False) -> Tensor Applies element-wise, @@ -1902,7 +1901,7 @@ def leaky_relu( input: Tensor, negative_slope: float = 0.01, inplace: bool = False, -) -> Tensor: # noqa: D400,D402 +) -> Tensor: r""" leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor @@ -1959,7 +1958,7 @@ def rrelu( upper: float = 1.0 / 3, training: bool = False, inplace: bool = False, -) -> Tensor: # noqa: D400,D402 +) -> Tensor: r"""rrelu(input, lower=1./8, upper=1./3, training=False, inplace=False) -> Tensor Randomized leaky ReLU. @@ -2034,7 +2033,7 @@ def rrelu( ) -def tanhshrink(input): # noqa: D400,D402 +def tanhshrink(input): r"""tanhshrink(input) -> Tensor Applies element-wise, :math:`\text{Tanhshrink}(x) = x - \text{Tanh}(x)` @@ -2046,7 +2045,7 @@ def tanhshrink(input): # noqa: D400,D402 return input - input.tanh() -def softsign(input): # noqa: D400,D402 +def softsign(input): r"""softsign(input) -> Tensor Applies element-wise, the function :math:`\text{SoftSign}(x) = \frac{x}{1 + |x|}` @@ -2282,7 +2281,7 @@ def log_softmax( ) -def tanh(input): # noqa: D400,D402 +def tanh(input): r"""tanh(input) -> Tensor Applies element-wise, @@ -2293,7 +2292,7 @@ def tanh(input): # noqa: D400,D402 return input.tanh() -def sigmoid(input): # noqa: D400,D402 +def sigmoid(input): r"""sigmoid(input) -> Tensor Applies the element-wise function :math:`\text{Sigmoid}(x) = \frac{1}{1 + \exp(-x)}` @@ -3570,7 +3569,7 @@ def binary_cross_entropy( ) if weight is not None: - new_size = _infer_size(target.size(), weight.size()) + new_size = torch.broadcast_shapes(target.size(), weight.size()) weight = weight.expand(new_size) # pyrefly: ignore [bad-argument-type] @@ -3813,7 +3812,7 @@ def l1_loss( reduce: bool | None = None, reduction: str = "mean", weight: Tensor | None = None, -) -> Tensor: # noqa: D400,D402 +) -> Tensor: r"""Compute the L1 loss, with optional weighting. Function that takes the mean element-wise absolute value difference. @@ -3968,7 +3967,7 @@ def margin_ranking_loss( size_average: bool | None = None, reduce: bool | None = None, reduction: str = "mean", -) -> Tensor: # noqa: D400,D402 +) -> Tensor: r"""Compute the margin ranking loss. See :class:`~torch.nn.MarginRankingLoss` for details. @@ -4018,7 +4017,7 @@ def hinge_embedding_loss( size_average: bool | None = None, reduce: bool | None = None, reduction: str = "mean", -) -> Tensor: # noqa: D400,D402 +) -> Tensor: r"""Compute the hinge embedding loss. See :class:`~torch.nn.HingeEmbeddingLoss` for details. @@ -4061,7 +4060,7 @@ def multilabel_margin_loss( size_average: bool | None = None, reduce: bool | None = None, reduction: str = "mean", -) -> Tensor: # noqa: D400,D402 +) -> Tensor: r"""Compute the multilabel margin loss. See :class:`~torch.nn.MultiLabelMarginLoss` for details. @@ -4103,7 +4102,7 @@ def soft_margin_loss( size_average: bool | None = None, reduce: bool | None = None, reduction: str = "mean", -) -> Tensor: # noqa: D400,D402 +) -> Tensor: r"""Compute the soft margin loss. See :class:`~torch.nn.SoftMarginLoss` for details. @@ -4146,7 +4145,7 @@ def multilabel_soft_margin_loss( size_average: bool | None = None, reduce: bool | None = None, reduction: str = "mean", -) -> Tensor: # noqa: D400,D402 +) -> Tensor: r"""Compute the multilabel soft margin loss. See :class:`~torch.nn.MultiLabelSoftMarginLoss` for details. @@ -4207,7 +4206,7 @@ def cosine_embedding_loss( size_average: bool | None = None, reduce: bool | None = None, reduction: str = "mean", -) -> Tensor: # noqa: D400,D402 +) -> Tensor: r"""Compute the cosine embedding loss. See :class:`~torch.nn.CosineEmbeddingLoss` for details. @@ -4255,7 +4254,7 @@ def multi_margin_loss( size_average: bool | None = None, reduce: bool | None = None, reduction: str = "mean", -) -> Tensor: # noqa: D400,D402 +) -> Tensor: r"""Compute the multi margin loss, with optional weighting. See :class:`~torch.nn.MultiMarginLoss` for details. @@ -4446,14 +4445,14 @@ def multi_margin_loss( @_overload -def upsample( # noqa: F811 +def upsample( input: Tensor, size: int | None = None, scale_factor: float | None = None, mode: str = "nearest", align_corners: bool | None = None, # pyrefly: ignore [bad-return] -) -> Tensor: # noqa: B950 +) -> Tensor: pass @@ -4465,7 +4464,7 @@ def upsample( # noqa: F811 mode: str = "nearest", align_corners: bool | None = None, # pyrefly: ignore [bad-return] -) -> Tensor: # noqa: B950 +) -> Tensor: pass @@ -4497,7 +4496,7 @@ def upsample( # noqa: F811 `mini-batch x channels x [optional depth] x [optional height] x width`. The modes available for upsampling are: `nearest`, `linear` (3D-only), - `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only) + `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `lanczos` (4D-only) Args: input (Tensor): the input tensor @@ -4506,7 +4505,7 @@ def upsample( # noqa: F811 scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple. mode (str): algorithm used for upsampling: ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | - ``'trilinear'``. Default: ``'nearest'`` + ``'trilinear'`` | ``'lanczos'``. Default: ``'nearest'`` align_corners (bool, optional): Geometrically, we consider the pixels of the input and output as squares rather than points. If set to ``True``, the input and output tensors are aligned by the @@ -4519,7 +4518,7 @@ def upsample( # noqa: F811 Default: ``False`` .. note:: - With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce + With ``mode='bicubic'`` or ``mode='lanczos'``, it's possible to cause overshoot, in other words it can produce negative values or values greater than 255 for images. Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot when displaying the image. @@ -4559,7 +4558,7 @@ def _is_integer(x) -> bool: @_overload -def interpolate( # noqa: F811 +def interpolate( input: Tensor, size: int | None = None, scale_factor: list[float] | None = None, @@ -4568,7 +4567,7 @@ def interpolate( # noqa: F811 recompute_scale_factor: bool | None = None, antialias: bool = False, # pyrefly: ignore [bad-return] -) -> Tensor: # noqa: B950 +) -> Tensor: pass @@ -4582,7 +4581,7 @@ def interpolate( # noqa: F811 recompute_scale_factor: bool | None = None, antialias: bool = False, # pyrefly: ignore [bad-return] -) -> Tensor: # noqa: B950 +) -> Tensor: pass @@ -4596,7 +4595,7 @@ def interpolate( # noqa: F811 recompute_scale_factor: bool | None = None, antialias: bool = False, # pyrefly: ignore [bad-return] -) -> Tensor: # noqa: B950 +) -> Tensor: pass @@ -4622,7 +4621,7 @@ def interpolate( # noqa: F811 align_corners: bool | None = None, recompute_scale_factor: bool | None = None, antialias: bool = False, -) -> Tensor: # noqa: B950 +) -> Tensor: r"""Down/up samples the input. Tensor interpolated to either the given :attr:`size` or the given @@ -4637,7 +4636,7 @@ def interpolate( # noqa: F811 `mini-batch x channels x [optional depth] x [optional height] x width`. The modes available for resizing are: `nearest`, `linear` (3D-only), - `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area`, `nearest-exact` + `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `lanczos` (4D-only, CPU only), `area`, `nearest-exact` Args: input (Tensor): the input tensor @@ -4647,7 +4646,7 @@ def interpolate( # noqa: F811 its length has to match the number of spatial dimensions; `input.dim() - 2`. mode (str): algorithm used for upsampling: ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | - ``'trilinear'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'`` + ``'trilinear'`` | ``'lanczos'`` | ``'area'`` | ``'nearest-exact'``. Default: ``'nearest'`` align_corners (bool, optional): Geometrically, we consider the pixels of the input and output as squares rather than points. If set to ``True``, the input and output tensors are aligned by the @@ -4669,14 +4668,19 @@ def interpolate( # noqa: F811 be used directly for interpolation. Default: ``None``. antialias (bool, optional): flag to apply anti-aliasing. Default: ``False``. Using anti-alias option together with ``align_corners=False``, interpolation result would match Pillow - result for downsampling operation. Supported modes: ``'bilinear'``, ``'bicubic'``. + result for downsampling operation. Supported modes: ``'bilinear'``, ``'bicubic'``, ``'lanczos'``. .. note:: - With ``mode='bicubic'``, it's possible to cause overshoot. For some dtypes, it can produce + With ``mode='bicubic'`` or ``mode='lanczos'``, it's possible to cause overshoot. For some dtypes, it can produce negative values or values greater than 255 for images. Explicitly call ``result.clamp(min=0,max=255)`` if you want to reduce the overshoot when displaying the image. For ``uint8`` inputs, it already performs saturating cast operation. So, no manual `clamp` operation is needed. + .. note:: + Mode ``mode='lanczos'`` uses a Lanczos-3 windowed sinc filter (6 taps) and requires + ``antialias=True``. It only supports 4-D input (i.e. 2D spatial) and CPU. With ``antialias=True`` + and ``align_corners=False``, the result matches PIL's ``Image.LANCZOS`` resampling filter. + .. note:: Mode ``mode='nearest-exact'`` matches Scikit-Image and PIL nearest neighbours interpolation algorithms and fixes known issues with ``mode='nearest'``. This mode is introduced to keep @@ -4807,9 +4811,11 @@ def interpolate( # noqa: F811 ] scale_factors = None - if antialias and not (mode in ("bilinear", "bicubic") and input.ndim == 4): + if antialias and not ( + mode in ("bilinear", "bicubic", "lanczos") and input.ndim == 4 + ): raise ValueError( - "Anti-alias option is restricted to bilinear and bicubic modes and requires a 4-D tensor as input" + "Anti-alias option is restricted to bilinear, bicubic, and lanczos modes and requires a 4-D tensor as input" ) if input.dim() == 3 and mode == "nearest": @@ -4923,6 +4929,21 @@ def interpolate( # noqa: F811 scale_factors, ) + if input.dim() == 4 and mode == "lanczos": + if align_corners is None: + raise AssertionError("align_corners is unexpectedly None") + if align_corners: + raise ValueError("Lanczos mode does not support align_corners=True") + if not antialias: + raise ValueError("Lanczos mode requires antialias=True") + return torch._C._nn._upsample_lanczos2d_aa( + input, + # pyrefly: ignore [bad-argument-type] + output_size, + align_corners, + scale_factors, + ) + if input.dim() == 3 and mode == "bilinear": raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") if input.dim() == 3 and mode == "trilinear": @@ -4938,7 +4959,7 @@ def interpolate( # noqa: F811 raise NotImplementedError( "Input Error: Only 3D, 4D and 5D input Tensors supported" - f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact" + f" (got {input.dim()}D) for the modes: nearest | linear | bilinear | bicubic | trilinear | lanczos | area | nearest-exact" f" (got {mode})" ) @@ -4948,7 +4969,7 @@ def interpolate( # noqa: F811 @_overload -def upsample_nearest( # noqa: F811 +def upsample_nearest( input: Tensor, size: int | None = None, scale_factor: float | None = None, @@ -5000,7 +5021,7 @@ def upsample_nearest(input, size=None, scale_factor=None): # noqa: F811 @_overload -def upsample_bilinear( # noqa: F811 +def upsample_bilinear( input: Tensor, size: int | None = None, scale_factor: float | None = None, @@ -5425,6 +5446,13 @@ def pad( ): if mode == "replicate": # Use slow decomp whose backward will be in terms of index_put. + if torch.compiler.is_compiling(): + # nonstrict_trace makes Dynamo skip the function body + # (which contains Dynamo-untraceable code) while + # AOTAutograd still traces into it for the backward. + return torch._dynamo.decorators.nonstrict_trace( + torch._decomp.decompositions._replication_pad + )(input, pad) # importlib is required because the import cannot be top level # (cycle) and cannot be nested (TS doesn't support) return importlib.import_module( @@ -5967,7 +5995,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0. attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) if is_causal: assert attn_mask is None - temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) if attn_mask is not None: @@ -6071,7 +6099,7 @@ def forward(self, ...): key (Tensor): Key tensor; shape :math:`(N, ..., H, S, E)`. value (Tensor): Value tensor; shape :math:`(N, ..., H, S, Ev)`. attn_mask (optional Tensor): Attention mask; shape must be broadcastable to the shape of attention weights, - which is :math:`(N,..., L, S)`. Two types of masks are supported. + which is :math:`(N,..., Hq, L, S)`. Two types of masks are supported. A boolean mask where a value of True indicates that the element *should* take part in attention. A float mask of the same type as query, key, value that is added to the attention score. dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied @@ -6553,10 +6581,10 @@ def multi_head_attention_forward( # # reshape q, k, v for multihead attention and make them batch first # - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) if static_k is None: - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed @@ -6570,7 +6598,7 @@ def multi_head_attention_forward( ) k = static_k if static_v is None: - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) else: # TODO finish disentangling control flow so we don't do in-projections when statics are passed @@ -6643,11 +6671,15 @@ def multi_head_attention_forward( ) else: attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1)) + if not torch.jit.is_scripting(): + del q_scaled, k attn_output_weights = softmax(attn_output_weights, dim=-1) if dropout_p > 0.0: attn_output_weights = dropout(attn_output_weights, p=dropout_p) attn_output = torch.bmm(attn_output_weights, v) + if not torch.jit.is_scripting(): + del v attn_output = ( attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim) @@ -6675,15 +6707,22 @@ def multi_head_attention_forward( else: attn_mask = attn_mask.view(bsz, num_heads, -1, src_len) + # pyrefly: ignore [bad-argument-type] q = q.view(bsz, num_heads, tgt_len, head_dim) - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type] k = k.view(bsz, num_heads, src_len, head_dim) - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type] v = v.view(bsz, num_heads, src_len, head_dim) attn_output = scaled_dot_product_attention( q, k, v, attn_mask, dropout_p, is_causal ) + # Free q, k, v and their backing projection storage before the + # .contiguous() call below allocates. In self-attention the three + # tensors are views of a single packed projection, so releasing all + # references here lets the allocator reclaim that memory immediately. + if not torch.jit.is_scripting(): + del q, k, v attn_output = ( attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim) ) diff --git a/torch/nn/intrinsic/__init__.py b/torch/nn/intrinsic/__init__.py index fe9a09aa31464..d3423c5e3a4e7 100644 --- a/torch/nn/intrinsic/__init__.py +++ b/torch/nn/intrinsic/__init__.py @@ -13,10 +13,10 @@ LinearBn1d, LinearReLU, ) -from torch.ao.nn.intrinsic.modules.fused import _FusedModule # noqa: F401 +from torch.ao.nn.intrinsic.modules.fused import _FusedModule # Include the subpackages in case user imports from it directly -from torch.nn.intrinsic import modules, qat, quantized # noqa: F401 +from torch.nn.intrinsic import modules, qat, quantized __all__ = [ diff --git a/torch/nn/intrinsic/quantized/__init__.py b/torch/nn/intrinsic/quantized/__init__.py index 4c09fddf6e75f..d8a4290e763e7 100644 --- a/torch/nn/intrinsic/quantized/__init__.py +++ b/torch/nn/intrinsic/quantized/__init__.py @@ -1,6 +1,6 @@ # to ensure customers can use the module below # without importing it directly -from torch.nn.intrinsic.quantized import dynamic, modules # noqa: F401 +from torch.nn.intrinsic.quantized import dynamic, modules from torch.nn.intrinsic.quantized.modules import * # noqa: F403 diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index 408e6ef42f128..59352264b3216 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -124,6 +124,7 @@ def forward( return torch.empty_like(input) @staticmethod + # pyrefly: ignore [bad-override] def backward(self, grad_output): if not ( grad_output.is_contiguous(memory_format=torch.channels_last) diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py index 658dd0a4d4fa3..4df4ff38dbecf 100644 --- a/torch/nn/modules/activation.py +++ b/torch/nn/modules/activation.py @@ -1322,7 +1322,7 @@ def forward( .. note:: `batch_first` argument is ignored for unbatched inputs. - """ # noqa: B950 + """ why_not_fast_path = "" if ( (attn_mask is not None and torch.is_floating_point(attn_mask)) @@ -1394,6 +1394,7 @@ def forward( elif torch.is_autocast_enabled(): why_not_fast_path = "autocast is enabled" + fast_path_blocked_by_tracing = False if not why_not_fast_path: tensor_args = ( query, @@ -1408,8 +1409,6 @@ def forward( # generator expressions. if torch.overrides.has_torch_function(tensor_args): why_not_fast_path = "some Tensor argument has_torch_function" - elif _is_make_fx_tracing(): - why_not_fast_path = "we are running make_fx tracing" elif not all(_check_arg_device(x) for x in tensor_args): why_not_fast_path = ( "some Tensor argument's device is neither one of " @@ -1422,6 +1421,9 @@ def forward( "grad is enabled and at least one of query or the " "input/output projection weights or biases requires_grad" ) + elif _is_make_fx_tracing(): + why_not_fast_path = "we are running make_fx tracing" + fast_path_blocked_by_tracing = True if not why_not_fast_path: merged_mask, mask_type = self.merge_masks( attn_mask, key_padding_mask, query @@ -1511,7 +1513,11 @@ def forward( is_causal=is_causal, ) if self.batch_first and is_batched: - return attn_output.transpose(1, 0), attn_output_weights + attn_output = attn_output.transpose(1, 0) + if fast_path_blocked_by_tracing: + # Keep the traced slowpath layout aligned with eager fastpath. + attn_output = attn_output.contiguous() + return attn_output, attn_output_weights else: return attn_output, attn_output_weights diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 91e8de79855a3..e56a747e14190 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -44,6 +44,8 @@ def __init__( track_running_stats: bool = True, device=None, dtype=None, + *, + bias: bool = True, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -54,7 +56,10 @@ def __init__( self.track_running_stats = track_running_stats if self.affine: self.weight = Parameter(torch.empty(num_features, **factory_kwargs)) - self.bias = Parameter(torch.empty(num_features, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(num_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) else: self.register_parameter("weight", None) self.register_parameter("bias", None) @@ -95,7 +100,8 @@ def reset_parameters(self) -> None: self.reset_running_stats() if self.affine: init.ones_(self.weight) - init.zeros_(self.bias) + if self.bias is not None: + init.zeros_(self.bias) def _check_input_dim(self, input): raise NotImplementedError @@ -103,7 +109,9 @@ def _check_input_dim(self, input): def extra_repr(self): return ( "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " - "track_running_stats={track_running_stats}".format(**self.__dict__) + "bias={use_bias}, track_running_stats={track_running_stats}".format( + **self.__dict__, use_bias=self.bias is not None + ) ) def _load_from_state_dict( @@ -151,10 +159,18 @@ def __init__( track_running_stats: bool = True, device=None, dtype=None, + *, + bias: bool = True, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__( - num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + num_features, + eps, + momentum, + affine, + track_running_stats, + **factory_kwargs, + bias=bias, ) def forward(self, input: Tensor) -> Tensor: @@ -220,11 +236,13 @@ def __init__( track_running_stats=True, device=None, dtype=None, + *, + bias=True, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} # pyrefly: ignore [bad-argument-type] super().__init__( - # affine and track_running_stats are hardcoded to False to + # affine, bias and track_running_stats are hardcoded to False to # avoid creating tensors that will soon be overwritten. 0, eps, @@ -232,14 +250,16 @@ def __init__( False, False, **factory_kwargs, + bias=False, ) self.affine = affine self.track_running_stats = track_running_stats if self.affine: # pyrefly: ignore [unexpected-keyword] self.weight = UninitializedParameter(**factory_kwargs) - # pyrefly: ignore [unexpected-keyword] - self.bias = UninitializedParameter(**factory_kwargs) + if bias: + # pyrefly: ignore # bad-argument-type + self.bias = UninitializedParameter(**factory_kwargs) if self.track_running_stats: # pyrefly: ignore [unexpected-keyword] self.running_mean = UninitializedBuffer(**factory_kwargs) @@ -266,10 +286,13 @@ def initialize_parameters(self, input) -> None: # type: ignore[override] raise AssertionError( "self.weight must be an UninitializedParameter" ) - if not isinstance(self.bias, UninitializedParameter): - raise AssertionError("self.bias must be an UninitializedParameter") self.weight.materialize((self.num_features,)) - self.bias.materialize((self.num_features,)) + if self.bias is not None: + if not isinstance(self.bias, UninitializedParameter): + raise AssertionError( + "self.bias must be an UninitializedParameter" + ) + self.bias.materialize((self.num_features,)) if self.track_running_stats: self.running_mean.materialize( # type:ignore[union-attr] (self.num_features,) @@ -335,6 +358,8 @@ class BatchNorm1d(_BatchNorm): buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size, @@ -381,6 +406,8 @@ class LazyBatchNorm1d(_LazyNormBase, _BatchNorm): buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` """ cls_to_become = BatchNorm1d # type: ignore[assignment] @@ -447,6 +474,8 @@ class BatchNorm2d(_BatchNorm): buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` @@ -492,6 +521,8 @@ class LazyBatchNorm2d(_LazyNormBase, _BatchNorm): buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` """ cls_to_become = BatchNorm2d # type: ignore[assignment] @@ -558,6 +589,8 @@ class BatchNorm3d(_BatchNorm): buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, D, H, W)` @@ -603,6 +636,8 @@ class LazyBatchNorm3d(_LazyNormBase, _BatchNorm): buffers :attr:`running_mean` and :attr:`running_var` as ``None``. When these buffers are ``None``, this module always uses batch statistics. in both training and eval modes. Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` """ cls_to_become = BatchNorm3d # type: ignore[assignment] @@ -677,6 +712,8 @@ class SyncBatchNorm(_BatchNorm): process_group: synchronization of stats happen within each process group individually. Default behavior is synchronization across the whole world + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, +)` @@ -725,10 +762,18 @@ def __init__( process_group: Any | None = None, device=None, dtype=None, + *, + bias: bool = True, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__( - num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + num_features, + eps, + momentum, + affine, + track_running_stats, + **factory_kwargs, + bias=bias, ) self.process_group = process_group @@ -886,6 +931,7 @@ def convert_sync_batchnorm(cls, module, process_group=None): module.affine, module.track_running_stats, process_group, + bias=module.bias is not None, ) if module.affine: with torch.no_grad(): diff --git a/torch/nn/modules/conv.py b/torch/nn/modules/conv.py index 88b23084442a2..41f8ae4cb3d46 100644 --- a/torch/nn/modules/conv.py +++ b/torch/nn/modules/conv.py @@ -49,7 +49,7 @@ In other words, for an input of size :math:`(N, C_{in}, L_{in})`, a depthwise convolution with a depthwise multiplier `K` can be performed with the arguments :math:`(C_\text{in}=C_\text{in}, C_\text{out}=C_\text{in} \times \text{K}, ..., \text{groups}=C_\text{in})`.""", -} # noqa: B950 +} class _ConvNd(Module): @@ -400,9 +400,11 @@ class Conv2d(_ConvNd): where :math:`\star` is the valid 2D `cross-correlation`_ operator, - :math:`N` is a batch size, :math:`C` denotes a number of channels, - :math:`H` is a height of input planes in pixels, and :math:`W` is - width in pixels. + :math:`N` is a batch size, :math:`C_{\text{in}}` and :math:`C_{\text{out}}` correspond to + :attr:`in_channels` and :attr:`out_channels` respectively, + :math:`H` and :math:`W` are the input height and width in pixels. + See the Shape section below for how :math:`H_{\text{out}}` and :math:`W_{\text{out}}` + are derived from :attr:`kernel_size`, :attr:`stride`, :attr:`padding`, and :attr:`dilation`. """ + r""" diff --git a/torch/nn/modules/instancenorm.py b/torch/nn/modules/instancenorm.py index 058ffb3ed9aa9..76a343afbd0dd 100644 --- a/torch/nn/modules/instancenorm.py +++ b/torch/nn/modules/instancenorm.py @@ -28,10 +28,18 @@ def __init__( track_running_stats: bool = False, device=None, dtype=None, + *, + bias: bool = True, # for backward compatibility ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__( - num_features, eps, momentum, affine, track_running_stats, **factory_kwargs + num_features, + eps, + momentum, + affine, + track_running_stats, + **factory_kwargs, + bias=bias, ) def _check_input_dim(self, input): @@ -174,11 +182,13 @@ class InstanceNorm1d(_InstanceNorm): momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. - Default: ``False``. + Default: ``False`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, L)` or :math:`(C, L)` @@ -218,11 +228,13 @@ class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm): momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. - Default: ``False``. + Default: ``False`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, L)` or :math:`(C, L)` @@ -290,11 +302,13 @@ class InstanceNorm2d(_InstanceNorm): momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. - Default: ``False``. + Default: ``False`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)` @@ -335,11 +349,13 @@ class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm): momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. - Default: ``False``. + Default: ``False`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)` @@ -406,11 +422,13 @@ class InstanceNorm3d(_InstanceNorm): momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. - Default: ``False``. + Default: ``False`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` @@ -451,11 +469,13 @@ class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm): momentum: the value used for the running_mean and running_var computation. Default: 0.1 affine: a boolean value that when set to ``True``, this module has learnable affine parameters, initialized the same way as done for batch normalization. - Default: ``False``. + Default: ``False`` track_running_stats: a boolean value that when set to ``True``, this module tracks the running mean and variance, and when set to ``False``, this module does not track such statistics and always uses batch statistics in both training and eval modes. Default: ``False`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index e004ef8d14f13..965ca30f24c0a 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -72,7 +72,7 @@ def _addindent(s_, numSpaces): class _WrappedHook: - def __init__(self, hook: Callable, module: Optional["Module"] = None) -> None: # noqa: UP045 + def __init__(self, hook: Callable, module: Optional["Module"] = None) -> None: self.hook: Callable = hook functools.update_wrapper(self, hook) @@ -475,7 +475,7 @@ def forward(self, x): _load_state_dict_pre_hooks: dict[int, Callable] _state_dict_pre_hooks: dict[int, Callable] _load_state_dict_post_hooks: dict[int, Callable] - _modules: dict[str, Optional["Module"]] # noqa: UP045 + _modules: dict[str, Optional["Module"]] call_super_init: bool = False _compiled_call_impl: Callable | None = None @@ -639,7 +639,7 @@ def register_parameter(self, name: str, param: Parameter | None) -> None: param = output self._parameters[name] = param - def add_module(self, name: str, module: Optional["Module"]) -> None: # noqa: UP045 + def add_module(self, name: str, module: Optional["Module"]) -> None: r"""Add a child module to the current module. The module can be accessed as an attribute using the given name. @@ -667,7 +667,7 @@ def add_module(self, name: str, module: Optional["Module"]) -> None: # noqa: UP module = output self._modules[name] = module - def register_module(self, name: str, module: Optional["Module"]) -> None: # noqa: UP045 + def register_module(self, name: str, module: Optional["Module"]) -> None: r"""Alias for :func:`add_module`.""" self.add_module(name, module) @@ -2835,7 +2835,7 @@ def modules(self, remove_duplicate: bool = True) -> Iterator["Module"]: def named_modules( self, - memo: set["Module"] | None = None, # noqa: UP007 + memo: set["Module"] | None = None, prefix: str = "", remove_duplicate: bool = True, ): diff --git a/torch/nn/modules/normalization.py b/torch/nn/modules/normalization.py index c32178af0b82e..c4df21528cae2 100644 --- a/torch/nn/modules/normalization.py +++ b/torch/nn/modules/normalization.py @@ -142,9 +142,9 @@ class LayerNorm(Module): eps: a value added to the denominator for numerical stability. Default: 1e-5 elementwise_affine: a boolean value that when set to ``True``, this module has learnable per-element affine parameters initialized to ones (for weights) - and zeros (for biases). Default: ``True``. + and zeros (for biases). Default: ``True`` bias: If set to ``False``, the layer will not learn an additive bias (only relevant if - :attr:`elementwise_affine` is ``True``). Default: ``True``. + :attr:`elementwise_affine` is ``True``). Default: ``True`` Attributes: weight: the learnable weights of the module of shape @@ -231,8 +231,8 @@ def forward(self, input: Tensor) -> Tensor: def extra_repr(self) -> str: return ( - "{normalized_shape}, eps={eps}, " - "elementwise_affine={elementwise_affine}".format(**self.__dict__) + "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}, " + "bias={use_bias}".format(**self.__dict__, use_bias=self.bias is not None) ) @@ -263,7 +263,9 @@ class GroupNorm(Module): eps: a value added to the denominator for numerical stability. Default: 1e-5 affine: a boolean value that when set to ``True``, this module has learnable per-channel affine parameters initialized to ones (for weights) - and zeros (for biases). Default: ``True``. + and zeros (for biases). Default: ``True`` + bias: If set to ``False``, the layer will not learn an additive bias (only relevant if + :attr:`affine` is ``True``). Default: ``True`` Shape: - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}` @@ -296,6 +298,8 @@ def __init__( affine: bool = True, device=None, dtype=None, + *, + bias: bool = True, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -310,7 +314,10 @@ def __init__( self.affine = affine if self.affine: self.weight = Parameter(torch.empty(num_channels, **factory_kwargs)) - self.bias = Parameter(torch.empty(num_channels, **factory_kwargs)) + if bias: + self.bias = Parameter(torch.empty(num_channels, **factory_kwargs)) + else: + self.register_parameter("bias", None) else: self.register_parameter("weight", None) self.register_parameter("bias", None) @@ -320,14 +327,16 @@ def __init__( def reset_parameters(self) -> None: if self.affine: init.ones_(self.weight) - init.zeros_(self.bias) + if self.bias is not None: + init.zeros_(self.bias) def forward(self, input: Tensor) -> Tensor: return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps) def extra_repr(self) -> str: - return "{num_groups}, {num_channels}, eps={eps}, affine={affine}".format( - **self.__dict__ + return ( + "{num_groups}, {num_channels}, eps={eps}, affine={affine}, " + "bias={use_bias}".format(**self.__dict__, use_bias=self.bias is not None) ) @@ -356,10 +365,10 @@ class RMSNorm(Module): If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. - eps: a value added to the denominator for numerical stability. If not specified, - uses the machine epsilon of the computation (opmath) type: fp16/bf16 and - fp32 inputs use ``torch.finfo(torch.float32).eps``, while fp64 inputs use - ``torch.finfo(torch.float64).eps``. + eps (float, optional): a value added to the denominator for numerical stability. + If not specified, uses the machine epsilon of the computation (opmath) type: + fp16/bf16 and fp32 inputs use ``torch.finfo(torch.float32).eps``, while fp64 + inputs use ``torch.finfo(torch.float64).eps``. Default: ``None`` elementwise_affine: a boolean value that when set to ``True``, this module has learnable per-element affine parameters initialized to ones (for weights). Default: ``True``. diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index 6f6e3b6154725..a313c5e9c1ef5 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -654,7 +654,7 @@ def __init__(self, *args, **kwargs): super().__init__(mode, *args, **kwargs) @overload - @torch._jit_internal._overload_method # noqa: F811 + @torch._jit_internal._overload_method def forward( self, input: Tensor, @@ -663,7 +663,7 @@ def forward( pass @overload - @torch._jit_internal._overload_method # noqa: F811 + @torch._jit_internal._overload_method def forward( self, input: PackedSequence, @@ -671,7 +671,7 @@ def forward( ) -> tuple[PackedSequence, Tensor]: pass - def forward(self, input, hx=None): # noqa: F811 + def forward(self, input, hx=None): """ Runs the forward pass. """ @@ -1058,25 +1058,25 @@ def permute_hidden( # type: ignore[override] # Same as above, see torch/nn/modules/module.py::_forward_unimplemented @overload # type: ignore[override] - @torch._jit_internal._overload_method # noqa: F811 + @torch._jit_internal._overload_method def forward( self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None, - ) -> tuple[Tensor, tuple[Tensor, Tensor]]: # noqa: F811 + ) -> tuple[Tensor, tuple[Tensor, Tensor]]: pass # Same as above, see torch/nn/modules/module.py::_forward_unimplemented @overload - @torch._jit_internal._overload_method # noqa: F811 + @torch._jit_internal._overload_method def forward( self, input: PackedSequence, hx: tuple[Tensor, Tensor] | None = None, - ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: # noqa: F811 + ) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: pass - def forward(self, input, hx=None): # noqa: F811 + def forward(self, input, hx=None): self._update_flat_weights() orig_input = input @@ -1356,24 +1356,24 @@ def __init__(self, *args, **kwargs): super().__init__("GRU", *args, **kwargs) @overload # type: ignore[override] - @torch._jit_internal._overload_method # noqa: F811 + @torch._jit_internal._overload_method def forward( self, input: Tensor, hx: Tensor | None = None, - ) -> tuple[Tensor, Tensor]: # noqa: F811 + ) -> tuple[Tensor, Tensor]: pass @overload - @torch._jit_internal._overload_method # noqa: F811 + @torch._jit_internal._overload_method def forward( self, input: PackedSequence, hx: Tensor | None = None, - ) -> tuple[PackedSequence, Tensor]: # noqa: F811 + ) -> tuple[PackedSequence, Tensor]: pass - def forward(self, input, hx=None): # noqa: F811 + def forward(self, input, hx=None): self._update_flat_weights() orig_input = input diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py index 65e78dfe3180a..16537624663b1 100644 --- a/torch/nn/parallel/data_parallel.py +++ b/torch/nn/parallel/data_parallel.py @@ -259,7 +259,7 @@ def data_parallel( device_ids = [_get_device_index(x, True) for x in device_ids] output_device = _get_device_index(output_device, True) - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] src_device_obj = torch.device(device_type, device_ids[0]) # pyrefly: ignore [bad-argument-type] diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 84c9313d2ed07..50f230b995fba 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -654,7 +654,7 @@ class DistributedDataParallel(Module, Joinable): If you plan on using this module with a ``nccl`` backend or a ``gloo`` backend (that uses Infiniband), together with a DataLoader that uses multiple workers, please change the multiprocessing start method to - ``forkserver`` (Python 3 only) or ``spawn``. Unfortunately + ``forkserver`` or ``spawn``. Unfortunately Gloo (that uses Infiniband) and NCCL2 are not fork safe, and you will likely experience deadlocks if you don't change this setting. @@ -773,6 +773,17 @@ class DistributedDataParallel(Module, Joinable): This requires that unused parameters remain the same across all ranks throughout the entire training process. If this condition is not met, it may cause desynchronization and result in training hang. + batched_grad_copy (bool): When set to ``True``, individual per-parameter + gradient-to-bucket copy and division operations are deferred + and flushed as a single ``_foreach_copy_`` plus one flat + ``div_`` when a bucket becomes ready. This reduces per-parameter + kernel launches down to 2 kernels per bucket, which can improve + throughput for models with many small parameters. The + optimization is most effective with + ``optimizer.zero_grad(set_to_none=True)`` (the default), where + ``gradient_as_bucket_view`` alone cannot avoid copies because + the bucket view alias is destroyed every iteration. + (default: ``False``) Attributes: @@ -808,6 +819,7 @@ def __init__( device_mesh=None, skip_all_reduce_unused_params=False, bucket_cap_mb_list: list[int] | None = None, + batched_grad_copy=False, ): super().__init__() Joinable.__init__(self) @@ -936,6 +948,7 @@ def __init__( self.require_backward_grad_sync = True self.require_forward_param_sync = True self.gradient_as_bucket_view = gradient_as_bucket_view + self.batched_grad_copy = batched_grad_copy self.mixed_precision = mixed_precision if self.mixed_precision is not None: logger.warning("Received mixed precision config %s", self.mixed_precision) @@ -1390,6 +1403,7 @@ def _ddp_init_helper( self.skip_all_reduce_unused_params, self._use_python_reducer, bucket_size_limits_for_rebuilding, + self.batched_grad_copy, ) self.logger = dist.Logger(self.reducer) @@ -1623,11 +1637,17 @@ def _get_active_ddp_module(cls): @contextmanager @torch._disable_dynamo(recursive=False) def _inside_ddp_forward(self): + # Save and restore the previous _active_ddp_module to handle nested + # DDP correctly (e.g., TorchRec wraps embeddings in an inner DDP inside + # an outer DDP). Without this, the inner DDP's exit would clear the + # flag to None, causing DDPOptimizer to miss compiled regions that run + # after the inner forward. + old = DistributedDataParallel._active_ddp_module DistributedDataParallel._active_ddp_module = self try: yield finally: - DistributedDataParallel._active_ddp_module = None + DistributedDataParallel._active_ddp_module = old def _run_ddp_forward(self, *inputs, **kwargs): if self._use_python_reducer: diff --git a/torch/nn/parallel/scatter_gather.py b/torch/nn/parallel/scatter_gather.py index 8954302f2a44e..bcd975a4fe2f7 100644 --- a/torch/nn/parallel/scatter_gather.py +++ b/torch/nn/parallel/scatter_gather.py @@ -58,19 +58,19 @@ def scatter_map(obj): if _is_namedtuple(obj): return [ type(obj)(*args) - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] for args in zip(*map(scatter_map, obj), strict=False) ] if isinstance(obj, tuple) and len(obj) > 0: - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] return list(zip(*map(scatter_map, obj), strict=False)) if isinstance(obj, list) and len(obj) > 0: - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] return [list(i) for i in zip(*map(scatter_map, obj), strict=False)] if isinstance(obj, dict) and len(obj) > 0: return [ type(obj)(i) - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] for i in zip(*map(scatter_map, obj.items()), strict=False) ] return [obj for _ in target_gpus] @@ -136,9 +136,9 @@ def gather_map(outputs): # pyrefly: ignore [not-callable] return type(out)((k, gather_map([d[k] for d in outputs])) for k in out) if _is_namedtuple(out): - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type] return type(out)._make(map(gather_map, zip(*outputs, strict=True))) - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type] return type(out)(map(gather_map, zip(*outputs, strict=True))) # Recursive function calls like this create reference cycles. diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index 824df5c8d317c..f2b0896c6b509 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -101,6 +101,7 @@ def __reduce_ex__(self, proto): (self.data, self.requires_grad, hooks, state), ) + # pyrefly: ignore [bad-override] __torch_function__ = _disabled_torch_function_impl @@ -273,13 +274,14 @@ def __new__(cls, data=None, *, persistent=True): t._is_buffer = True return t + # pyrefly: ignore [bad-override] __torch_function__ = _disabled_torch_function_impl class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor): r"""A buffer that is not initialized. - Uninitialized Buffer is a a special case of :class:`torch.Tensor` + Uninitialized Buffer is a special case of :class:`torch.Tensor` where the shape of the data is still unknown. Unlike a :class:`torch.Tensor`, uninitialized parameters diff --git a/torch/nn/qat/__init__.py b/torch/nn/qat/__init__.py index 766b09382aa78..665800e803871 100644 --- a/torch/nn/qat/__init__.py +++ b/torch/nn/qat/__init__.py @@ -1,11 +1,10 @@ -# flake8: noqa: F401 r"""QAT Dynamic Modules. This package is in the process of being deprecated. Please, use `torch.ao.nn.qat.dynamic` instead. """ -from torch.nn.qat import dynamic, modules # noqa: F403 +from torch.nn.qat import dynamic, modules from torch.nn.qat.modules import * # noqa: F403 diff --git a/torch/nn/qat/dynamic/__init__.py b/torch/nn/qat/dynamic/__init__.py index 56838a1cfcae7..ebe97a04a8751 100644 --- a/torch/nn/qat/dynamic/__init__.py +++ b/torch/nn/qat/dynamic/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r"""QAT Dynamic Modules. This package is in the process of being deprecated. diff --git a/torch/nn/qat/modules/__init__.py b/torch/nn/qat/modules/__init__.py index f7f55fbdf789a..3ef4652322fa9 100644 --- a/torch/nn/qat/modules/__init__.py +++ b/torch/nn/qat/modules/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r"""QAT Modules. This package is in the process of being deprecated. diff --git a/torch/nn/qat/modules/embedding_ops.py b/torch/nn/qat/modules/embedding_ops.py index 9a0964739f9e6..5ac3f9867314b 100644 --- a/torch/nn/qat/modules/embedding_ops.py +++ b/torch/nn/qat/modules/embedding_ops.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r"""QAT Modules. This file is in the process of migration to `torch/ao/nn/qat`, and diff --git a/torch/nn/quantized/__init__.py b/torch/nn/quantized/__init__.py index 5e2bbbc13202d..278c7218c1a51 100644 --- a/torch/nn/quantized/__init__.py +++ b/torch/nn/quantized/__init__.py @@ -1,4 +1,4 @@ -from torch.nn.quantized import dynamic, functional, modules # noqa: F403 +from torch.nn.quantized import dynamic, functional, modules from torch.nn.quantized.modules import * # noqa: F403 from torch.nn.quantized.modules import MaxPool2d diff --git a/torch/nn/quantized/_reference/modules/__init__.py b/torch/nn/quantized/_reference/modules/__init__.py index c9caa8e58f193..d736318ec5e89 100644 --- a/torch/nn/quantized/_reference/modules/__init__.py +++ b/torch/nn/quantized/_reference/modules/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r"""Quantized Reference Modules. This module is in the process of migration to diff --git a/torch/nn/quantized/dynamic/modules/__init__.py b/torch/nn/quantized/dynamic/modules/__init__.py index 2ae09e82c3bb8..e3c68eed0c514 100644 --- a/torch/nn/quantized/dynamic/modules/__init__.py +++ b/torch/nn/quantized/dynamic/modules/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r"""Quantized Dynamic Modules. This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, diff --git a/torch/nn/quantized/dynamic/modules/conv.py b/torch/nn/quantized/dynamic/modules/conv.py index b5b81a68a8891..df5817451f518 100644 --- a/torch/nn/quantized/dynamic/modules/conv.py +++ b/torch/nn/quantized/dynamic/modules/conv.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r"""Quantized Dynamic Modules. This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, diff --git a/torch/nn/quantized/dynamic/modules/rnn.py b/torch/nn/quantized/dynamic/modules/rnn.py index d5ca396a2d440..2195fc73d444d 100644 --- a/torch/nn/quantized/dynamic/modules/rnn.py +++ b/torch/nn/quantized/dynamic/modules/rnn.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r"""Quantized Dynamic Modules. This file is in the process of migration to `torch/ao/nn/quantized/dynamic`, diff --git a/torch/nn/quantized/functional.py b/torch/nn/quantized/functional.py index d763e171fdb43..24ad921b8d46a 100644 --- a/torch/nn/quantized/functional.py +++ b/torch/nn/quantized/functional.py @@ -7,4 +7,4 @@ Please, use the `torch.ao.nn.quantized.functional` instead. """ -from torch.ao.nn.quantized.functional import * # noqa: F401,F403 +from torch.ao.nn.quantized.functional import * # noqa: F403 diff --git a/torch/nn/quantized/modules/dropout.py b/torch/nn/quantized/modules/dropout.py index 32a7a22d55867..536b31ddbe81e 100644 --- a/torch/nn/quantized/modules/dropout.py +++ b/torch/nn/quantized/modules/dropout.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r"""Quantized Modules. This file is in the process of migration to `torch/ao/nn/quantized`, and diff --git a/torch/nn/quantized/modules/embedding_ops.py b/torch/nn/quantized/modules/embedding_ops.py index d25f8bea7e378..93f017daeccd3 100644 --- a/torch/nn/quantized/modules/embedding_ops.py +++ b/torch/nn/quantized/modules/embedding_ops.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r"""Quantized Modules. This file is in the process of migration to `torch/ao/nn/quantized`, and diff --git a/torch/nn/quantized/modules/functional_modules.py b/torch/nn/quantized/modules/functional_modules.py index efe1b38ce3ea4..e7ae21f35a466 100644 --- a/torch/nn/quantized/modules/functional_modules.py +++ b/torch/nn/quantized/modules/functional_modules.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r"""Quantized Modules. This file is in the process of migration to `torch/ao/nn/quantized`, and diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py index e9ba5a5c12f82..a70c47071395c 100644 --- a/torch/nn/quantized/modules/linear.py +++ b/torch/nn/quantized/modules/linear.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r"""Quantized Modules. This file is in the process of migration to `torch/ao/nn/quantized`, and diff --git a/torch/nn/quantized/modules/normalization.py b/torch/nn/quantized/modules/normalization.py index 85462cc365344..11a51f087dff7 100644 --- a/torch/nn/quantized/modules/normalization.py +++ b/torch/nn/quantized/modules/normalization.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r"""Quantized Modules. This file is in the process of migration to `torch/ao/nn/quantized`, and diff --git a/torch/nn/utils/spectral_norm.py b/torch/nn/utils/spectral_norm.py index b785e080b79e6..79fea5a50b793 100644 --- a/torch/nn/utils/spectral_norm.py +++ b/torch/nn/utils/spectral_norm.py @@ -185,8 +185,6 @@ def apply( return fn -# This is a top level class because Py2 pickle doesn't like inner class nor an -# instancemethod. class SpectralNormLoadStateDictPreHook: # See docstring of SpectralNorm._version on the changes to spectral_norm. def __init__(self, fn) -> None: @@ -244,8 +242,6 @@ def __call__( state_dict[weight_key + "_v"] = v -# This is a top level class because Py2 pickle doesn't like inner class nor an -# instancemethod. class SpectralNormStateDictHook: # See docstring of SpectralNorm._version on the changes to spectral_norm. def __init__(self, fn) -> None: diff --git a/torch/onnx/_internal/exporter/_analysis.py b/torch/onnx/_internal/exporter/_analysis.py index 53860413526ee..e6c6155b410a7 100644 --- a/torch/onnx/_internal/exporter/_analysis.py +++ b/torch/onnx/_internal/exporter/_analysis.py @@ -1,7 +1,7 @@ """Compatibility analyzer for PyTorch models.""" # mypy: allow-untyped-defs -# flake8: noqa: B950 We do not need flake8 as it complains line length + from __future__ import annotations import dataclasses diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index 3b1e1dfc21620..1489ad2814f1b 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -# flake8: noqa: B950 We do not need flake8 as it complains line length + from __future__ import annotations import ctypes diff --git a/torch/onnx/_internal/exporter/_onnx_program.py b/torch/onnx/_internal/exporter/_onnx_program.py index c47ad25b31bf9..5d7484d85bed5 100644 --- a/torch/onnx/_internal/exporter/_onnx_program.py +++ b/torch/onnx/_internal/exporter/_onnx_program.py @@ -157,7 +157,7 @@ def _to_ort_value(input: torch.Tensor | int | float | str | bool) -> ort.OrtValu int: np.int64, float: np.float32, } - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] dtype = dtype_mapping.get(type(input)) return ort.OrtValue.ortvalue_from_numpy(np.array(input, dtype=dtype)) diff --git a/torch/onnx/_internal/exporter/_registration.py b/torch/onnx/_internal/exporter/_registration.py index d22186944e1fe..c8fb10a134e29 100644 --- a/torch/onnx/_internal/exporter/_registration.py +++ b/torch/onnx/_internal/exporter/_registration.py @@ -80,7 +80,7 @@ def __post_init__(self) -> None: # When the function is targeting an HOP, for example, it will accept # functions as arguments and fail to generate an ONNX signature. # In this case we set signature to None and dispatch to this function always. - logger.warning( # noqa: G200 + logger.warning( "Failed to infer the signature for function '%s' because '%s'" "All nodes targeting `%s` will be dispatched to this function", self.onnx_function, diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/core.py b/torch/onnx/_internal/exporter/_torchlib/ops/core.py index 36d53b113edc2..19107bb4b61f0 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/core.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/core.py @@ -1,7 +1,7 @@ """torch.ops.aten operators under the `core` module.""" # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" # pyrefly: ignore-errors -# ruff: noqa: TCH001,TCH002 +# ruff: noqa: TCH001 from __future__ import annotations diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py index 14c3ad6d2a3ab..7208e09ec33f5 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/nn.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/nn.py @@ -1,8 +1,7 @@ """torch.ops.aten operators under the `core` module.""" # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" # pyrefly: ignore-errors -# ruff: noqa: TC001,TC002 -# flake8: noqa: B950 +# ruff: noqa: TC001 from __future__ import annotations diff --git a/torch/onnx/_internal/exporter/_torchlib/ops/symops.py b/torch/onnx/_internal/exporter/_torchlib/ops/symops.py index cdaee46802768..3d2ae6b2fe024 100644 --- a/torch/onnx/_internal/exporter/_torchlib/ops/symops.py +++ b/torch/onnx/_internal/exporter/_torchlib/ops/symops.py @@ -2,7 +2,7 @@ # mypy: disable-error-code="misc,arg-type,type-arg,valid-type,assignment,return-value,type-var,operator,no-untyped-def,index" # pyrefly: ignore-errors -# ruff: noqa: TCH001,TCH002,TC003 +# ruff: noqa: TC003 from __future__ import annotations diff --git a/torch/onnx/_internal/fx/type_utils.py b/torch/onnx/_internal/fx/type_utils.py index bd32799bc4e2b..ca434fa213f79 100644 --- a/torch/onnx/_internal/fx/type_utils.py +++ b/torch/onnx/_internal/fx/type_utils.py @@ -4,11 +4,14 @@ from __future__ import annotations from typing import Any +from typing_extensions import TypeIs import torch -def is_torch_symbolic_type(value: Any) -> bool: +def is_torch_symbolic_type( + value: Any, +) -> TypeIs[torch.SymBool | torch.SymInt | torch.SymFloat]: return isinstance(value, (torch.SymBool, torch.SymInt, torch.SymFloat)) diff --git a/torch/onnx/_internal/torchscript_exporter/_type_utils.py b/torch/onnx/_internal/torchscript_exporter/_type_utils.py index acd13f914ae2d..c0ffa56a38e3b 100644 --- a/torch/onnx/_internal/torchscript_exporter/_type_utils.py +++ b/torch/onnx/_internal/torchscript_exporter/_type_utils.py @@ -173,6 +173,7 @@ def from_onnx_type( """ if onnx_type not in _ONNX_TO_SCALAR_TYPE: raise errors.OnnxExporterError(f"Unknown onnx_type: {onnx_type}") + # pyrefly: ignore [redundant-cast] return _ONNX_TO_SCALAR_TYPE[typing.cast(_C_onnx.TensorProtoDataType, onnx_type)] @classmethod diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py b/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py index d52e2a6ee9249..0e00a7a5cd263 100644 --- a/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_helper.py @@ -110,7 +110,7 @@ NoReturn, TypeVar as _TypeVar, ) -from typing_extensions import ParamSpec as _ParamSpec +from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs import torch import torch._C._onnx as _C_onnx @@ -561,7 +561,7 @@ def _is_none(x: Any) -> bool: return x is None or (x.node().mustBeNone() if isinstance(x, _C.Value) else False) -def _is_value(x: Any) -> bool: +def _is_value(x: Any) -> _TypeIs[_C.Value]: return isinstance(x, _C.Value) diff --git a/torch/onnx/_internal/torchscript_exporter/symbolic_opset14.py b/torch/onnx/_internal/torchscript_exporter/symbolic_opset14.py index 2fcbdad947082..7beb3de13f272 100644 --- a/torch/onnx/_internal/torchscript_exporter/symbolic_opset14.py +++ b/torch/onnx/_internal/torchscript_exporter/symbolic_opset14.py @@ -167,7 +167,7 @@ def scaled_dot_product_attention( # NOTE: onnx-script has different logic here, because the attribute perms in # transpose needs list of ints key_shape_builtin = symbolic_helper._get_tensor_rank(key) - # pyrefly: ignore [no-matching-overload] + # pyrefly: ignore [bad-argument-type, no-matching-overload] key_transposed_axes = list(range(key_shape_builtin)) key_transposed_axes[-1], key_transposed_axes[-2] = ( key_transposed_axes[-2], diff --git a/torch/onnx/ops/__init__.py b/torch/onnx/ops/__init__.py index 8da3fc8e58723..6a75faed2a385 100644 --- a/torch/onnx/ops/__init__.py +++ b/torch/onnx/ops/__init__.py @@ -4,7 +4,6 @@ which are exportable to ONNX. """ -# flake8: noqa: B950 from __future__ import annotations diff --git a/torch/onnx/ops/_impl.py b/torch/onnx/ops/_impl.py index be50d953b619b..18e1275942833 100644 --- a/torch/onnx/ops/_impl.py +++ b/torch/onnx/ops/_impl.py @@ -5,7 +5,6 @@ for more details on how to create fake kernels. """ -# flake8: noqa: B950 import math from collections.abc import Callable from typing import TypeVar diff --git a/torch/onnx/ops/_symbolic_impl.py b/torch/onnx/ops/_symbolic_impl.py index 85963dd85b1d2..e7fa44355d1e4 100644 --- a/torch/onnx/ops/_symbolic_impl.py +++ b/torch/onnx/ops/_symbolic_impl.py @@ -11,7 +11,6 @@ or less the same thing but is required by the `torch.library.custom_op` interface. """ -# flake8: noqa: B950 import dataclasses from collections.abc import Sequence diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 76b50a8eb3f77..d996d93358bff 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_helper import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_helper import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset10.py b/torch/onnx/symbolic_opset10.py index 9bda69b81ab60..cf3fdee3b86c5 100644 --- a/torch/onnx/symbolic_opset10.py +++ b/torch/onnx/symbolic_opset10.py @@ -5,7 +5,7 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset10 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset10 import * # noqa: F403 from torch.onnx._internal.torchscript_exporter.symbolic_opset10 import ( # noqa: F401 _slice, ) diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 276ef7209bf69..d0d2d634f3465 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset11 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset11 import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 63e137734e8a7..54230c0fa6968 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset12 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset12 import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index 18aff9295be8c..b701224d4148e 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset13 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset13 import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset14.py b/torch/onnx/symbolic_opset14.py index 367aa9eb0832a..6d56f1c6e114c 100644 --- a/torch/onnx/symbolic_opset14.py +++ b/torch/onnx/symbolic_opset14.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset15.py b/torch/onnx/symbolic_opset15.py index e04e3b0452127..08d7c4636c15b 100644 --- a/torch/onnx/symbolic_opset15.py +++ b/torch/onnx/symbolic_opset15.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset15 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset15 import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset16.py b/torch/onnx/symbolic_opset16.py index 9a248bb0f26c5..6d11387b90c62 100644 --- a/torch/onnx/symbolic_opset16.py +++ b/torch/onnx/symbolic_opset16.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset16 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset16 import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset17.py b/torch/onnx/symbolic_opset17.py index 800acd446b5dc..ef42344d066ee 100644 --- a/torch/onnx/symbolic_opset17.py +++ b/torch/onnx/symbolic_opset17.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset17 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset17 import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset18.py b/torch/onnx/symbolic_opset18.py index cc07a60f018d8..6c7f88e77c059 100644 --- a/torch/onnx/symbolic_opset18.py +++ b/torch/onnx/symbolic_opset18.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset18 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset18 import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset19.py b/torch/onnx/symbolic_opset19.py index 4f7a54fc1dd38..7bde244c7cc4d 100644 --- a/torch/onnx/symbolic_opset19.py +++ b/torch/onnx/symbolic_opset19.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset19 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset19 import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset20.py b/torch/onnx/symbolic_opset20.py index 56635a7811611..2df2767bd26f9 100644 --- a/torch/onnx/symbolic_opset20.py +++ b/torch/onnx/symbolic_opset20.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset20 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset20 import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset7.py b/torch/onnx/symbolic_opset7.py index c11e769677ec4..6e818bdcac07b 100644 --- a/torch/onnx/symbolic_opset7.py +++ b/torch/onnx/symbolic_opset7.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset7 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset7 import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset8.py b/torch/onnx/symbolic_opset8.py index 0e4411649f3e0..85b7876f98d6b 100644 --- a/torch/onnx/symbolic_opset8.py +++ b/torch/onnx/symbolic_opset8.py @@ -5,4 +5,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset8 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset8 import * # noqa: F403 diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index bd0f4795340ae..d0ec459757beb 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -5,7 +5,7 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import * # noqa: F403 from torch.onnx._internal.torchscript_exporter.symbolic_opset9 import ( # noqa: F401 _prepare_onnx_paddings, _reshape_from_tensor, diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index d387069b0b0c9..1d55cd5a76169 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -6,4 +6,4 @@ __all__: list[str] = [] -from torch.onnx._internal.torchscript_exporter.utils import * # noqa: F401,F403 +from torch.onnx._internal.torchscript_exporter.utils import * # noqa: F403 diff --git a/torch/optim/_multi_tensor/__init__.py b/torch/optim/_multi_tensor/__init__.py index b6818e5a50f3b..560ce3135f325 100644 --- a/torch/optim/_multi_tensor/__init__.py +++ b/torch/optim/_multi_tensor/__init__.py @@ -11,7 +11,7 @@ from torch import optim -def partialclass(cls, *args, **kwargs): # noqa: D103 +def partialclass(cls, *args, **kwargs): class NewCls(cls): __init__ = partialmethod(cls.__init__, *args, **kwargs) diff --git a/torch/optim/_muon.py b/torch/optim/_muon.py index e441c8b911b2f..2e45e07c4a596 100644 --- a/torch/optim/_muon.py +++ b/torch/optim/_muon.py @@ -273,6 +273,29 @@ def step(self, closure=None): adjust_lr_fn (str, optional): function to adjust learning rate. One of "original" and "match_rms_adamw". If not specified, we will default to use "original". (default: None) + Example: + >>> # xdoctest: +SKIP + >>> # Muon only supports 2D params; use a standard optimizer + >>> # such as AdamW for biases, embeddings, and other non-2D + >>> # parameters. + >>> muon_params = [ + ... p for p in model.parameters() if p.ndim == 2 + ... ] + >>> other_params = [ + ... p for p in model.parameters() if p.ndim != 2 + ... ] + >>> optim_muon = torch.optim.Muon( + ... muon_params, lr=0.02, momentum=0.95 + ... ) + >>> optim_adamw = torch.optim.AdamW( + ... other_params, lr=3e-4, weight_decay=0.01 + ... ) + >>> optim_muon.zero_grad() + >>> optim_adamw.zero_grad() + >>> loss_fn(model(input), target).backward() + >>> optim_muon.step() + >>> optim_adamw.step() + .. _Muon\: An optimizer for hidden layers in neural networks: https://kellerjordan.github.io/posts/muon/ .. _Muon is Scalable for LLM Training: diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index a6a57fb61b8ba..d18e32fe17a98 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -16,6 +16,7 @@ _to_scalar, _use_grad_for_differentiable, _view_as_real, + DeviceDict, Optimizer, ParamsT, ) @@ -73,6 +74,7 @@ def __init__( if foreach: raise RuntimeError("`fused` and `foreach` cannot be `True` together.") self._need_device_dtype_check_for_fused = True + self._step_supports_amp_scaling = True for group in self.param_groups: for p in group["params"]: @@ -106,6 +108,20 @@ def __setstate__(self, state): group.setdefault("differentiable", False) fused = group.setdefault("fused", None) + for p in group["params"]: + p_state = self.state.get(p, []) + if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): + step_val = float(p_state["step"]) + p_state["step"] = ( + torch.tensor( + step_val, + dtype=_get_scalar_dtype(is_fused=fused), + device=p.device, + ) + if group["fused"] + else torch.tensor(step_val, dtype=_get_scalar_dtype()) + ) + state_values = list(self.state.values()) step_is_tensor = (len(state_values) != 0) and torch.is_tensor( state_values[0]["step"] @@ -132,13 +148,38 @@ def _init_group(self, group, params_with_grad, grads, state_sums, state_steps): "_need_device_dtype_check_for_fused", True, ): - _device_dtype_check_for_fused(p, cuda_unsupported=True) + _device_dtype_check_for_fused(p) self._need_device_dtype_check_for_fused = False has_sparse_grad |= p.grad.is_sparse has_complex |= torch.is_complex(p) params_with_grad.append(p) grads.append(p.grad) state = self.state[p] + if len(state) == 0: + if group["fused"]: + _device_dtype_check_for_fused(p) + + state["step"] = ( + torch.zeros( + (), + dtype=_get_scalar_dtype(is_fused=group["fused"]), + device=p.device, + ) + if group["fused"] + else torch.tensor(0.0, dtype=_get_scalar_dtype()) + ) + + initial_accumulator_value = self.defaults[ + "initial_accumulator_value" + ] + init_value = ( + complex(initial_accumulator_value, initial_accumulator_value) + if torch.is_complex(p) + else initial_accumulator_value + ) + state["sum"] = torch.full_like( + p, init_value, memory_format=torch.preserve_format + ) state_sums.append(state["sum"]) state_steps.append(state["step"]) @@ -230,7 +271,7 @@ def step(self, closure=None): {_foreach_doc} {_maximize_doc} {_differentiable_doc} - fused (bool, optional): whether the fused implementation (CPU only) is used. + fused (bool, optional): whether the fused implementation (CPU and CUDA only) is used. Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16` are supported. (default: None). Please note that the fused implementations does not support sparse or complex gradients. @@ -507,7 +548,7 @@ def _fused_adagrad( grad_scale: Tensor | None, found_inf: Tensor | None, *, - lr: float, + lr: float | Tensor, weight_decay: float, lr_decay: float, eps: float, @@ -526,12 +567,15 @@ def _fused_adagrad( "adagrad with fused=True does not support differentiable=True" ) - lr = _to_scalar(lr) - - grad_scale_dict = ( - {grad_scale.device: grad_scale} if grad_scale is not None else None + grad_scale_dict: DeviceDict = ( + {grad_scale.device: grad_scale} if grad_scale is not None else {} + ) + found_inf_dict: DeviceDict = ( + {found_inf.device: found_inf} if found_inf is not None else {} + ) + lr_dict: DeviceDict | None = ( + {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None ) - found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( [params, grads, state_sums, state_steps] # type: ignore[list-item] @@ -551,14 +595,17 @@ def _fused_adagrad( device_state_steps = cast(list[Tensor], device_state_steps_) device_grad_scale, device_found_inf = None, None - if grad_scale is not None and grad_scale_dict is not None: - if device not in grad_scale_dict: - grad_scale_dict[device] = grad_scale.to(device, non_blocking=True) # type: ignore[index] - device_grad_scale = grad_scale_dict[device] # type: ignore[index] - if found_inf is not None and found_inf_dict is not None: - if found_inf not in found_inf_dict: - found_inf_dict[device] = found_inf.to(device, non_blocking=True) # type: ignore[index] - device_found_inf = found_inf_dict[device] # type: ignore[index] + if grad_scale is not None: + device_grad_scale = grad_scale_dict.setdefault( + device, grad_scale.to(device, non_blocking=True) + ) + if found_inf is not None: + device_found_inf = found_inf_dict.setdefault( + device, found_inf.to(device, non_blocking=True) + ) + if lr_dict is not None and device not in lr_dict: + lr_dict[device] = lr.to(device=device, non_blocking=True) # type: ignore[union-attr] + lr = lr_dict[device] torch._foreach_add_(device_state_steps, 1) torch._fused_adagrad_( device_params, diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index b2049c73fc472..b67fd81a1099c 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -120,7 +120,7 @@ def __init__( self, optimizer: Optimizer, last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: # Attach optimizer if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") @@ -380,7 +380,7 @@ def __init__( optimizer: Optimizer, lr_lambda: Callable[[int], float] | list[Callable[[int], float]], last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: self.optimizer = optimizer self.lr_lambdas: list[Callable[[int], float]] @@ -495,7 +495,7 @@ def __init__( optimizer: Optimizer, lr_lambda: Callable[[int], float] | list[Callable[[int], float]], last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: self.optimizer = optimizer self.lr_lambdas: list[Callable[[int], float]] @@ -624,7 +624,7 @@ def __init__( step_size: int, gamma: float = 0.1, last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: self.step_size = step_size self.gamma = gamma super().__init__(optimizer, last_epoch) @@ -710,7 +710,7 @@ def __init__( milestones: Iterable[int], gamma: float = 0.1, last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: self.milestones = Counter(milestones) self.gamma = gamma super().__init__(optimizer, last_epoch) @@ -809,7 +809,7 @@ def __init__( factor: float = 1.0 / 3, total_iters: int = 5, last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: if factor > 1.0 or factor < 0: raise ValueError( "Constant multiplicative factor expected to be between 0 and 1." @@ -917,7 +917,7 @@ def __init__( end_factor: float = 1.0, total_iters: int = 5, last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: if start_factor > 1.0 or start_factor <= 0: raise ValueError( "Starting multiplicative factor expected to be greater than 0 and less or equal to 1." @@ -1030,7 +1030,7 @@ def __init__( optimizer: Optimizer, gamma: float, last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: self.gamma = gamma super().__init__(optimizer, last_epoch) @@ -1122,7 +1122,7 @@ def __init__( schedulers: list[LRScheduler], milestones: list[int], last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: if len(schedulers) < 1: raise ValueError( f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler." @@ -1266,7 +1266,7 @@ def __init__( total_iters: int = 5, power: float = 1.0, last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: self.total_iters = total_iters self.power = power super().__init__(optimizer, last_epoch) @@ -1390,7 +1390,7 @@ def __init__( T_max: int, eta_min: float = 0.0, last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: self.T_max = T_max self.eta_min = eta_min super().__init__(optimizer, last_epoch) @@ -1507,7 +1507,7 @@ class ChainedScheduler(LRScheduler): def __init__( self, schedulers: Sequence[LRScheduler], optimizer: Optimizer | None = None - ) -> None: # noqa: D107 + ) -> None: if len(schedulers) < 1: raise ValueError( f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler." @@ -1648,7 +1648,7 @@ def __init__( cooldown: int = 0, min_lr: list[float] | float = 0, eps: float = 1e-8, - ) -> None: # noqa: D107 + ) -> None: if factor >= 1.0: raise ValueError("Factor should be < 1.0.") self.factor = factor @@ -1734,10 +1734,10 @@ def _reduce_lr(self, epoch) -> None: _update_param_group_val(param_group, "lr", new_lr) @property - def in_cooldown(self): # noqa: D102 + def in_cooldown(self): return self.cooldown_counter > 0 - def _is_better(self, a, best): # noqa: D102 + def _is_better(self, a, best): if self.mode == "min" and self.threshold_mode == "rel": rel_epsilon = 1.0 - self.threshold return a < best * rel_epsilon @@ -1891,7 +1891,7 @@ def __init__( base_momentum: float = 0.8, max_momentum: float = 0.9, last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: # Attach optimizer if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") @@ -2061,7 +2061,7 @@ def get_lr(self) -> list[float | Tensor]: return lrs @override - def state_dict(self) -> dict[str, Any]: # noqa: D102 + def state_dict(self) -> dict[str, Any]: """Return the state of the scheduler as a :class:`dict`. It contains an entry for every variable in ``self.__dict__`` which @@ -2142,7 +2142,7 @@ def __init__( T_mult: int = 1, eta_min: float = 0.0, last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: if T_0 <= 0 or not isinstance(T_0, int): raise ValueError(f"Expected positive integer T_0, but got {T_0}") if T_mult < 1 or not isinstance(T_mult, int): @@ -2394,7 +2394,7 @@ def __init__( final_div_factor: float = 1e4, three_phase: bool = False, last_epoch: int = -1, - ) -> None: # noqa: D107 + ) -> None: # Validate optimizer if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index d338333d60934..0a3593a567dfb 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -29,7 +29,7 @@ __all__ = ["NAdam", "nadam"] -class NAdam(Optimizer): # noqa: D101 +class NAdam(Optimizer): def __init__( self, params: ParamsT, @@ -44,7 +44,7 @@ def __init__( maximize: bool = False, capturable: bool = False, differentiable: bool = False, - ) -> None: # noqa: D107 + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -73,7 +73,7 @@ def __init__( } super().__init__(params, defaults) - def __setstate__(self, state): # noqa: D105 + def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: group.setdefault("maximize", False) diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index cf9a3c53bf8e8..a529605c1ec45 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -135,12 +135,12 @@ def maybe_fallback(*args: _P.args, **kwargs: _P.kwargs): and has_state_steps and (arg := args[state_steps_ind]) and isinstance(arg, Sequence) - and arg[0].is_cuda + and arg[0].device.type in {"cuda", "xpu"} or ( "state_steps" in kwargs and (kwarg := kwargs["state_steps"]) and isinstance(kwarg, Sequence) - and kwarg[0].is_cuda + and kwarg[0].device.type in {"cuda", "xpu"} ) ): return disabled_func(*args, **kwargs) @@ -374,7 +374,7 @@ class Optimizer: 'OrderedDict[int, Callable[["Optimizer"], None]]' ) - def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None: # noqa: D107 + def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None: torch._C._log_api_usage_once("python.optimizer") self.defaults = defaults self._optimizer_step_pre_hooks = OrderedDict() @@ -409,14 +409,14 @@ def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None: # noqa: # https://github.com/pytorch/pytorch/issues/72948 self._warned_capturable_if_run_uncaptured = True - def __getstate__(self) -> dict[str, Any]: # noqa: D105 + def __getstate__(self) -> dict[str, Any]: return { "defaults": self.defaults, "state": self.state, "param_groups": self.param_groups, } - def __setstate__(self, state: dict[str, Any]) -> None: # noqa: D105 + def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__.update(state) if "_optimizer_step_pre_hooks" not in self.__dict__: self._optimizer_step_pre_hooks = OrderedDict() @@ -433,7 +433,7 @@ def __setstate__(self, state: dict[str, Any]) -> None: # noqa: D105 self._patch_step_function() # To support multiprocessing pickle/unpickle self.defaults.setdefault("differentiable", False) - def __repr__(self) -> str: # noqa: D105 + def __repr__(self) -> str: format_string = self.__class__.__name__ + " (" for i, group in enumerate(self.param_groups): format_string += "\n" @@ -460,21 +460,16 @@ def _accelerator_graph_capture_health_check(self) -> None: return # Determine available accelerator device - accelerator = None - if torch.cuda.is_available(): - accelerator = (torch.cuda, "CUDA") - elif torch.xpu.is_available(): - accelerator = (torch.xpu, "XPU") + accelerator = torch.accelerator.current_accelerator(check_available=True) - if accelerator: - device_module, device_name = accelerator - capturing = device_module.is_current_stream_capturing() + if accelerator and accelerator.type in {"cuda", "xpu"}: + capturing = torch.accelerator.current_stream().is_capturing() if capturing and not all( group["capturable"] for group in self.param_groups ): raise RuntimeError( - f"Attempting {device_name} graph capture of step() for an instance of " + f"Attempting {accelerator.type.upper()} graph capture of step() for an instance of " + self.__class__.__name__ + " but param_groups' capturable is False." ) @@ -486,7 +481,7 @@ def _accelerator_graph_capture_health_check(self) -> None: ): warnings.warn( "This instance was constructed with capturable=True or some of all the param_groups came with capturable=True, " - f"but step() is running without {device_name} graph capture. If you never intend to graph-capture this " + f"but step() is running without {accelerator.type.upper()} graph capture. If you never intend to graph-capture this " "instance, capturable=True can impair performance, and you should set capturable=False.", stacklevel=2, ) @@ -510,7 +505,7 @@ def _optimizer_step_code(self) -> None: """ @staticmethod - def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]: # noqa: D102 + def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]: @functools.wraps(func) def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R: self, *_ = args @@ -618,7 +613,7 @@ def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle: def register_state_dict_pre_hook( self, hook: Callable[["Optimizer"], None], prepend: bool = False - ) -> RemovableHandle: # noqa: D101 + ) -> RemovableHandle: r"""Register a state dict pre-hook which will be called before :meth:`~torch.optim.Optimizer.state_dict` is called. It should have the following signature:: @@ -811,7 +806,7 @@ def register_load_state_dict_pre_hook( self, hook: Callable[["Optimizer", StateDict], StateDict | None], prepend: bool = False, - ) -> RemovableHandle: # noqa: D205 D400 + ) -> RemovableHandle: r"""Register a load_state_dict pre-hook which will be called before :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the following signature:: @@ -848,7 +843,7 @@ def register_load_state_dict_pre_hook( def register_load_state_dict_post_hook( self, hook: Callable[["Optimizer"], None], prepend: bool = False - ) -> RemovableHandle: # noqa: D205 D400 + ) -> RemovableHandle: r"""Register a load_state_dict post-hook which will be called after :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the following signature:: diff --git a/torch/optim/radam.py b/torch/optim/radam.py index d414b9c7edbee..26c164b2b603a 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -28,7 +28,7 @@ __all__ = ["RAdam", "radam"] -class RAdam(Optimizer): # noqa: D101 +class RAdam(Optimizer): def __init__( self, params: ParamsT, @@ -42,7 +42,7 @@ def __init__( maximize: bool = False, capturable: bool = False, differentiable: bool = False, - ) -> None: # noqa: D107 + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -69,7 +69,7 @@ def __init__( } super().__init__(params, defaults) - def __setstate__(self, state): # noqa: D105 + def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: group.setdefault("foreach", None) diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index eefce02fac6fe..771f72d01621b 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -27,7 +27,7 @@ __all__ = ["RMSprop", "rmsprop"] -class RMSprop(Optimizer): # noqa: D101 +class RMSprop(Optimizer): def __init__( self, params: ParamsT, @@ -41,7 +41,7 @@ def __init__( foreach: bool | None = None, maximize: bool = False, differentiable: bool = False, - ) -> None: # noqa: D107 + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -69,7 +69,7 @@ def __init__( } super().__init__(params, defaults) - def __setstate__(self, state): # noqa: D105 + def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: group.setdefault("momentum", 0) diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index ea9da85141286..f65d2717c0547 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -27,7 +27,7 @@ __all__ = ["Rprop", "rprop"] -class Rprop(Optimizer): # noqa: D101 +class Rprop(Optimizer): def __init__( self, params: ParamsT, @@ -39,7 +39,7 @@ def __init__( foreach: bool | None = None, maximize: bool = False, differentiable: bool = False, - ) -> None: # noqa: D107 + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -58,7 +58,7 @@ def __init__( } super().__init__(params, defaults) - def __setstate__(self, state): # noqa: D105 + def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: group.setdefault("foreach", None) diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 7c8783fd23182..cf7a8a708365a 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -25,7 +25,7 @@ __all__ = ["SGD", "sgd"] -class SGD(Optimizer): # noqa: D101 +class SGD(Optimizer): def __init__( self, params: ParamsT, @@ -39,7 +39,7 @@ def __init__( foreach: bool | None = None, differentiable: bool = False, fused: bool | None = None, - ) -> None: # noqa: D107 + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if lr < 0.0: @@ -72,7 +72,7 @@ def __init__( if foreach: raise RuntimeError("`fused` and `foreach` cannot be `True` together.") - def __setstate__(self, state): # noqa: D105 + def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: group.setdefault("nesterov", False) diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index ec60e81aff9b0..dc49ba7fc010a 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -271,7 +271,7 @@ def __init__( multi_avg_fn: Callable[[PARAM_LIST, PARAM_LIST, Tensor | int], None] | None = None, use_buffers=False, - ) -> None: # noqa: D107 + ) -> None: super().__init__() if avg_fn is not None and multi_avg_fn is not None: raise AssertionError( @@ -478,7 +478,7 @@ def __init__( anneal_epochs=10, anneal_strategy: Literal["cos", "linear"] = "cos", last_epoch=-1, - ) -> None: # noqa: D107 + ) -> None: swa_lrs = _format_param("swa_lr", optimizer, swa_lr) for swa_lr, group in zip(swa_lrs, optimizer.param_groups, strict=True): group["swa_lr"] = swa_lr diff --git a/torch/overrides.py b/torch/overrides.py index db945dd73322d..a2823f902a8b9 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -20,7 +20,7 @@ instructions in the ``README.md`` in that directory. """ -import __future__ # noqa: F404 +import __future__ import collections import contextlib @@ -316,9 +316,6 @@ def get_ignored_functions() -> set[Callable]: torch.unify_type_list, torch.is_warn_always_enabled, torch.set_warn_always, - torch.vitals_enabled, - torch.set_vital, - torch.read_vitals, torch.vmap, torch.cond, torch.frombuffer, @@ -385,6 +382,8 @@ def get_ignored_functions() -> set[Callable]: Tensor._addmm_activation, Tensor.to_padded_tensor, Tensor._use_count, + Tensor._philox_normal_, + Tensor._philox_uniform_, } if sys.version_info >= (3, 14): @@ -605,10 +604,10 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1, torch.einsum: lambda equation, *operands: -1, torch.embedding: ( - lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950 + lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 ), torch.embedding_bag: ( - lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1 # noqa: B950 + lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1 ), torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, torch.eq: lambda input, other, out=None: -1, @@ -622,11 +621,11 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1, torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1, torch.fused_moving_avg_obs_fake_quant: ( - lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1 # noqa: B950 + lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1 ), torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias, output: -1, torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias, output: -1, - torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1, # noqa: B950 + torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1, torch.fbgemm_linear_int8_weight_fp32_activation: ( lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1 ), @@ -667,7 +666,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.fmod: lambda input, other, out=None: -1, torch.frac: lambda input, out=None: -1, torch.frexp: lambda input, out=None: -1, - torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, # noqa: B950 + torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, torch._functional_assert_async: lambda input, msg, dep_token: -1, torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1, torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1, @@ -692,7 +691,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.hardshrink: lambda input, lambd=0.5: -1, torch.hash_tensor: lambda input, dim=None, keepdim=False, mode=0, out=None: -1, torch.heaviside: lambda input, values, out=None: -1, - torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950 + torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1, torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1, torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1, torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1, @@ -735,7 +734,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1, torch.isnan: lambda input: -1, torch.istft: ( - lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1 # noqa: B950 + lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1 ), torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1, torch.kron: lambda input, other: -1, @@ -750,7 +749,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.less_equal: lambda input, other, out=None: -1, torch.lerp: lambda input, end, weight, out=None: -1, torch.lgamma: lambda input, out=None: -1, - torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1, # noqa: B950 + torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1, torch.log: lambda input, out=None: -1, torch.log_softmax: lambda input, dim, dtype=None: -1, torch.log10: lambda input, out=None: -1, @@ -772,7 +771,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.less: lambda input, other, out=None: -1, torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1, torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1, - torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1, # type: ignore[attr-defined] # noqa: B950 + torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1, # type: ignore[attr-defined] torch.masked_fill: lambda input, mask, value: -1, torch.masked_scatter: lambda input, mask, source: -1, torch.masked_select: lambda input, mask, out=None: -1, @@ -808,7 +807,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.miopen_batch_norm: ( lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1 ), - torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1, # noqa: B950 + torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1, torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1, torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1, torch.miopen_convolution_transpose: ( @@ -818,7 +817,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1 ), torch.miopen_rnn: ( - lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1 # noqa: B950 + lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1 ), torch.mm: lambda input, mat2, out_dtype=None, out=None: -1, torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1, @@ -856,10 +855,10 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1, torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1, torch.nn.functional.avg_pool2d: ( - lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950 + lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 ), torch.nn.functional.avg_pool3d: ( - lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950 + lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 ), torch.nn.functional.batch_norm: ( lambda input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: -1 @@ -876,7 +875,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1 ), torch.nn.functional.cross_entropy: ( - lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1 # noqa: B950 + lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1 ), torch.nn.functional.ctc_loss: ( lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1 @@ -887,29 +886,29 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1, torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1, torch.nn.functional.embedding: ( - lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950 + lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 ), torch.nn.functional.embedding_bag: ( - lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1 # noqa: B950 + lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1 ), torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1, torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1, torch.nn.functional.fractional_max_pool2d: ( - lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950 + lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 ), torch.nn.functional.fractional_max_pool2d_with_indices: ( - lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950 + lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 ), torch.nn.functional.fractional_max_pool3d: ( - lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950 + lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 ), torch.nn.functional.fractional_max_pool3d_with_indices: ( - lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950 + lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 ), torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction="mean": -1, torch.nn.functional.gelu: lambda input, approximate="none": -1, torch.nn.functional.glu: lambda input, dim=-1: -1, - torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1, # noqa: B950 + torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1, torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1, torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1, torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1, @@ -918,12 +917,12 @@ def get_testing_overrides() -> dict[Callable, Callable]: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1 ), torch.nn.functional.instance_norm: ( - lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1 # noqa: B950 + lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1 ), torch.nn.functional.interpolate: ( - lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1 # noqa: B950 + lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1 ), - torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1, # noqa: B950 + torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1, torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", weight=None: -1, torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1, torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1, @@ -955,12 +954,12 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.nn.functional.max_pool3d_with_indices: ( lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1 ), - torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950 - torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950 - torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950 + torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, + torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, + torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", weight=None: -1, torch.nn.functional.multi_head_attention_forward: ( - lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1 # noqa: B950 + lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1 ), torch.nn.functional.multi_margin_loss: ( lambda input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction="mean": -1 @@ -979,20 +978,20 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.nn.functional.pad: lambda input, pad, mode="constant", value=0: -1, torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1, torch.nn.functional.poisson_nll_loss: ( - lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1 # noqa: B950 + lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1 ), torch.nn.functional.prelu: lambda input, weight: -1, torch.nn.functional.relu: lambda input, inplace=False: -1, torch.nn.functional.relu6: lambda input, inplace=False: -1, torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1, - torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, # noqa: B950 + torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, torch.nn.functional.selu: lambda input, inplace=False: -1, torch.nn.functional.silu: lambda input, inplace=False: -1, torch.nn.functional.mish: lambda input, inplace=False: -1, torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1, - torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1, # noqa: B950 + torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1, torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0, weight=None: -1, - torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950 + torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1, torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1, torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1, torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1, @@ -1001,7 +1000,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.nn.functional.tanhshrink: lambda input: -1, torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1, torch.nn.functional.triplet_margin_loss: ( - lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950 + lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 ), torch.nn.functional.triplet_margin_with_distance_loss: ( lambda anchor, positive, negative, *, distance_function=None, margin=1.0, swap=False, reduction="mean": -1 @@ -1010,7 +1009,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.nn.init.uniform_: lambda tensor, a=0.0, b=1.0, generator=None: -1, torch.nn.init.normal_: lambda tensor, mean=0.0, std=1.0, generator=None: -1, torch.nn.init.constant_: lambda tensor, val: -1, - torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1, # noqa: B950 + torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1, torch.nonzero: lambda input, as_tuple=False: -1, torch.nonzero_static: lambda input, *, size, fill_value=-1: -1, torch.argwhere: lambda input: -1, @@ -1057,10 +1056,10 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1, torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1, torch.quantized_gru_cell: ( - lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 + lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 ), torch.quantized_lstm_cell: ( - lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 + lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 ), torch.quantized_max_pool1d: ( lambda input, kernel_size, stride=(), padding=(0,), dilation=( @@ -1081,10 +1080,10 @@ def get_testing_overrides() -> dict[Callable, Callable]: ), ceil_mode=False: -1 ), torch.quantized_rnn_relu_cell: ( - lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 + lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 ), torch.quantized_rnn_tanh_cell: ( - lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950 + lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 ), torch.rad2deg: lambda input, out=None: -1, torch.ravel: lambda input: -1, @@ -1100,9 +1099,9 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.repeat_interleave: lambda input, dim=None: -1, torch.reshape: lambda input, shape: -1, torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1, - torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950 + torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, - torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950 + torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1, torch.roll: lambda input, shifts, dims=None: -1, torch.rot90: lambda input, k=1, dims=(0, 1): -1, @@ -1117,7 +1116,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.scatter_add: lambda input, dim, index, src: -1, torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1, torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1, - torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1, # noqa: B950 + torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1, torch.select: lambda input, dim, index: -1, torch.select_scatter: lambda input, src, dim, index: -1, torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1, @@ -1148,7 +1147,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.std: lambda input, dim=None: -1, torch.std_mean: lambda input, dim=None: -1, torch.stft: ( - lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None, align_to_window=None: -1 # noqa: B950 + lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None, align_to_window=None: -1 ), torch.sub: lambda input, other, out=None: -1, torch.subtract: lambda input, other, out=None: -1, @@ -1253,7 +1252,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1, torch.tril: lambda input, diagonal=0, out=None: -1, torch.triplet_margin_loss: ( - lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950 + lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 ), torch.triu: lambda input, diagonal=0, out=None: -1, torch.true_divide: lambda input, other: -1, @@ -1275,7 +1274,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: torch.where: lambda condition, x=None, y=None: -1, torch._wrapped_linear_prepack: lambda weight, weight_scale, weight_zero_point, bias : -1, torch._wrapped_quantized_linear_prepacked: ( - lambda input, input_scale, input_zero_point, prepacked, out_scale, out_zero_point, out_channel : -1 # noqa: B950 + lambda input, input_scale, input_zero_point, prepacked, out_scale, out_zero_point, out_channel : -1 ), torch.zeros_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1, torch._fw_primal_copy: lambda self, level: -1, @@ -1433,6 +1432,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: Tensor.mtia: lambda self, memory_format=torch.preserve_format: -1, Tensor.xpu: lambda self, memory_format=torch.preserve_format: -1, Tensor.ipu: lambda self, memory_format=torch.preserve_format: -1, + Tensor.const_data_ptr: lambda self: -1, Tensor.data_ptr: lambda self: -1, Tensor.dense_dim: lambda self: -1, Tensor.diagonal_scatter: lambda self, src, offset=0, dim1=0, dim2=1: -1, @@ -1588,7 +1588,7 @@ def get_testing_overrides() -> dict[Callable, Callable]: dist.scatter: lambda tensor, scatter_list=None, src=None, group=None, async_op=False, group_src=None: -1, dist.reduce_scatter: lambda output, input_list, op=None, group=None, async_op=False: -1, dist.reduce_scatter_tensor: lambda output, input, op=None, group=None, async_op=False: -1, - dist.all_to_all_single: lambda output, input, output_split_sizes=None, input_split_sizes=None, group=None, async_op=False: -1, # noqa: B950 + dist.all_to_all_single: lambda output, input, output_split_sizes=None, input_split_sizes=None, group=None, async_op=False: -1, dist.all_to_all: lambda output_tensor_list, input_tensor_list, group=None, async_op=False: -1, dist.isend: lambda tensor, dst=None, group=None, tag=0, group_dst=None: -1, dist.irecv: lambda tensor, src=None, group=None, tag=0, group_src=None: -1, diff --git a/torch/package/_importlib.py b/torch/package/_importlib.py index 609efd294c4c9..b24031bce8b8e 100644 --- a/torch/package/_importlib.py +++ b/torch/package/_importlib.py @@ -62,8 +62,8 @@ def _calc___package__(globals): spec = globals.get("__spec__") if package is not None: if spec is not None and package != spec.parent: - _warnings.warn( # noqa: G010 - f"__package__ != __spec__.parent ({package!r} != {spec.parent!r})", # noqa: G004 + _warnings.warn( + f"__package__ != __spec__.parent ({package!r} != {spec.parent!r})", ImportWarning, stacklevel=3, ) @@ -71,7 +71,7 @@ def _calc___package__(globals): elif spec is not None: return spec.parent else: - _warnings.warn( # noqa: G010 + _warnings.warn( "can't resolve package from __spec__ or __package__, " "falling back on __name__ and __path__", ImportWarning, diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index 1cb473e0aed68..a42af10794fa5 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -154,6 +154,7 @@ def __init__( # used for torch.serialization._load self.Unpickler = lambda *args, **kwargs: PackageUnpickler(self, *args, **kwargs) + # pyrefly: ignore [bad-override] def import_module(self, name: str, package=None): """Load a module from the package if it hasn't already been loaded, and then return the module. Modules are loaded locally diff --git a/torch/profiler/_trace_validator.py b/torch/profiler/_trace_validator.py new file mode 100644 index 0000000000000..6fcd8f032d402 --- /dev/null +++ b/torch/profiler/_trace_validator.py @@ -0,0 +1,333 @@ +# mypy: allow-untyped-defs +""" +Validates Chrome traces emitted by ``torch.profiler`` against rules derived +from production issues. + +Usage:: + + from torch.profiler._trace_validator import validate_trace + + passed, violations = validate_trace("trace.pt.trace.json") + for v in violations: + print(v) +""" + +from __future__ import annotations + +import dataclasses +import gzip +import json +from collections import defaultdict +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Callable + + +@dataclasses.dataclass +class Violation: + """A single rule violation found in a trace.""" + + rule_name: str + message: str + + def __str__(self) -> str: + return f"{self.rule_name}: {self.message}" + + +def _load_events(path: str) -> list[dict]: + opener = gzip.open if path.endswith(".gz") else open + with opener(path, "rt", encoding="utf-8") as fh: + data = json.load(fh) + events = data if isinstance(data, list) else data.get("traceEvents", []) + return [e for e in events if isinstance(e, dict)] + + +def _check_gpu_kernel_causality(events: list[dict]) -> list[Violation]: + """For each (cudaLaunchKernel, GPU kernel) pair matched by External id, + the GPU kernel must start at or after its cudaLaunchKernel.""" + cpu_launches: dict[int, dict] = {} + gpu_kernels: dict[int, dict] = {} + + for ev in events: + if ev.get("ph") != "X": + continue + args = ev.get("args", {}) + ext_id = args.get("External id") + if ext_id is None: + continue + ext_id = int(ext_id) + cat, name, ts = ev.get("cat", ""), ev.get("name", ""), float(ev.get("ts", 0)) + corr = args.get("correlation") + + if cat == "cuda_runtime" and name == "cudaLaunchKernel": + if ext_id not in cpu_launches or ts < cpu_launches[ext_id]["ts"]: + cpu_launches[ext_id] = {"ts": ts, "name": name, "corr": corr} + elif cat == "kernel": + if ext_id not in gpu_kernels or ts < gpu_kernels[ext_id]["ts"]: + gpu_kernels[ext_id] = {"ts": ts, "name": name, "corr": corr} + + violations = [] + for ext_id, gpu in gpu_kernels.items(): + launch = cpu_launches.get(ext_id) + if launch is None: + continue + if gpu["ts"] < launch["ts"]: + skew = launch["ts"] - gpu["ts"] + violations.append( + Violation( + rule_name="_check_gpu_kernel_causality", + message=( + f"GPU kernel '{gpu['name']}' (External id={ext_id}, " + f"correlation={gpu['corr']}) starts {skew:.1f}us before " + f"its cudaLaunchKernel (External id={ext_id}, " + f"correlation={launch['corr']}), " + f"gpu_ts={gpu['ts']:.1f}, cpu_ts={launch['ts']:.1f}" + ), + ) + ) + return violations + + +def _check_stream_wait_corr_id_populated(events: list[dict]) -> list[Violation]: + """Stream Wait Events and Event Synchronize must have + wait_on_cuda_event_record_corr_id >= 0.""" + TARGET_KINDS = {"Stream Wait Event", "Event Sync"} + + violations = [] + for ev in events: + if ev.get("ph") != "X": + continue + args = ev.get("args", {}) + sync_kind = args.get("cuda_sync_kind") + if sync_kind not in TARGET_KINDS: + continue + raw_corr = args.get("wait_on_cuda_event_record_corr_id") + if raw_corr is None or int(raw_corr) < 0: + ts = float(ev.get("ts", 0)) + violations.append( + Violation( + rule_name="_check_stream_wait_corr_id_populated", + message=( + f"'{sync_kind}' event at ts={ts:.1f}us on " + f"device={args.get('device')} stream={args.get('stream')} " + f"has invalid wait_on_cuda_event_record_corr_id={raw_corr!r}" + ), + ) + ) + return violations + + +def _check_stream_sync_overlap(events: list[dict]) -> list[Violation]: + """For each Stream Synchronize on (device, stream), no kernel on that + stream should still be running when the sync starts.""" + stream_syncs = [] + for ev in events: + if ( + ev.get("ph") == "X" + and ev.get("cat") == "cuda_sync" + and ev.get("args", {}).get("cuda_sync_kind") == "Stream Sync" + ): + args = ev.get("args", {}) + stream_syncs.append( + { + "ts": float(ev.get("ts", 0)), + "dur": float(ev.get("dur", 0)), + "stream": args.get("stream"), + "device": args.get("device"), + } + ) + if not stream_syncs: + return [] + + kernels_by_stream: dict[tuple, list[dict]] = defaultdict(list) + for ev in events: + if ev.get("ph") == "X" and ev.get("cat") == "kernel": + args = ev.get("args", {}) + key = (args.get("device"), args.get("stream")) + ts = float(ev.get("ts", 0)) + kernels_by_stream[key].append( + { + "ts": ts, + "end": ts + float(ev.get("dur", 0)), + "name": ev.get("name", ""), + } + ) + + violations = [] + for sync in stream_syncs: + key = (sync["device"], sync["stream"]) + for k in kernels_by_stream.get(key, []): + if k["ts"] < sync["ts"] < k["end"]: + overlap = k["end"] - sync["ts"] + violations.append( + Violation( + rule_name="_check_stream_sync_overlap", + message=( + f"StreamSynchronize on device={sync['device']} " + f"stream={sync['stream']} at ts={sync['ts']:.1f}us " + f"but kernel '{k['name']}' (ends {k['end']:.1f}us) " + f"is still running ({overlap:.1f}us overlap)" + ), + ) + ) + return violations + + +_CUDA_EVENT_RECORD_NAMES = { + "cudaEventRecord", + "cudaEventRecord_ptsz", + "cudaEventRecordWithFlags", + "cudaEventRecordWithFlags_ptsz", +} + + +def _check_stream_wait_corr_id_in_past(events: list[dict]) -> list[Violation]: + """wait_on_cuda_event_record_corr_id must point to a cudaEventRecord + with cudaEventRecord.ts <= stream_wait.ts.""" + event_record_ts: dict[int, float] = {} + for ev in events: + if ( + ev.get("ph") == "X" + and ev.get("cat") in ("cuda_runtime", "cuda_driver") + and ev.get("name") in _CUDA_EVENT_RECORD_NAMES + ): + args = ev.get("args", {}) + ts = float(ev.get("ts", 0)) + for field in ("External id", "correlation"): + cid = args.get(field) + if cid is not None: + cid = int(cid) + if cid not in event_record_ts or ts < event_record_ts[cid]: + event_record_ts[cid] = ts + + violations = [] + for ev in events: + if ev.get("ph") != "X": + continue + args = ev.get("args", {}) + if args.get("cuda_sync_kind") != "Stream Wait Event": + continue + ref = args.get("wait_on_cuda_event_record_corr_id") + if ref is None or int(ref) < 0: + continue + ref = int(ref) + sw_ts = float(ev.get("ts", 0)) + record_ts = event_record_ts.get(ref) + + if record_ts is None: + violations.append( + Violation( + rule_name="_check_stream_wait_corr_id_in_past", + message=( + f"Stream Wait Event at ts={sw_ts:.1f}us references " + f"corr_id={ref} but no matching cudaEventRecord in trace" + ), + ) + ) + elif record_ts > sw_ts: + lag = record_ts - sw_ts + violations.append( + Violation( + rule_name="_check_stream_wait_corr_id_in_past", + message=( + f"Stream Wait Event at ts={sw_ts:.1f}us references " + f"cudaEventRecord (corr_id={ref}) {lag:.1f}us in the future " + f"(event_record_ts={record_ts:.1f})" + ), + ) + ) + return violations + + +_NCCL_REQUIRED_FIELDS = { + "Collective name", + "dtype", + "In msg nelems", + "Out msg nelems", + "Group size", +} + + +def _check_nccl_metadata(events: list[dict]) -> list[Violation]: + """record_param_comms events must carry: Collective name, dtype, + In msg nelems, Out msg nelems, Group size.""" + violations = [] + for ev in events: + if ev.get("ph") != "X" or ev.get("name") != "record_param_comms": + continue + args = ev.get("args", {}) + missing = _NCCL_REQUIRED_FIELDS - set(args.keys()) + if missing: + violations.append( + Violation( + rule_name="_check_nccl_metadata", + message=( + f"'record_param_comms' at ts={float(ev.get('ts', 0)):.1f}us " + f"missing metadata: {sorted(missing)}" + ), + ) + ) + return violations + + +def _check_backward_seq_id_uniqueness(events: list[dict]) -> list[Violation]: + """Per Sequence number, at most one distinct backward op name.""" + seq_to_ops: dict[int, list[str]] = defaultdict(list) + for ev in events: + if ev.get("ph") != "X": + continue + name = ev.get("name", "") + if "autograd::engine::evaluate_function:" not in name: + continue + args = ev.get("args", {}) + seq = args.get("Sequence number") or args.get("seq_num") + if seq is None: + continue + seq = int(seq) + op = name.split(":", 1)[-1].strip() if ":" in name else name + if op not in seq_to_ops[seq]: + seq_to_ops[seq].append(op) + + violations = [] + for seq, ops in seq_to_ops.items(): + if len(ops) > 1: + violations.append( + Violation( + rule_name="_check_backward_seq_id_uniqueness", + message=( + f"Sequence number {seq} shared by {len(ops)} backward " + f"ops: {ops}" + ), + ) + ) + return violations + + +_RULES: list[Callable[[list[dict]], list[Violation]]] = [ + _check_gpu_kernel_causality, + _check_stream_wait_corr_id_populated, + _check_stream_sync_overlap, + _check_stream_wait_corr_id_in_past, + _check_nccl_metadata, + _check_backward_seq_id_uniqueness, +] + + +def validate_trace(trace_path: str) -> tuple[bool, list[Violation]]: + """ + Run all validation rules against a Chrome trace JSON file. + + Args: + trace_path: Path to ``.pt.trace.json`` or ``.pt.trace.json.gz``. + + Returns: + A ``(passed, violations)`` tuple. ``passed`` is ``True`` when no + violations were found. + """ + events = _load_events(trace_path) + all_violations: list[Violation] = [] + for rule in _RULES: + all_violations.extend(rule(events)) + return len(all_violations) == 0, all_violations diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 2dc5a63cb33f4..6cda46034a218 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -401,6 +401,8 @@ def key_averages( """Averages events, grouping them by operator name and (optionally) input shapes, stack and overload name. + Returns an :class:`~torch.autograd.profiler_util.EventList` of the aggregated events. + .. note:: To use shape/stack functionality make sure to set record_shapes/with_stack when creating profiler context manager. @@ -415,8 +417,8 @@ def key_averages( def events(self): """ - Returns the list of unaggregated profiler events, - to be used in the trace callback or after the profiling is finished + Return the list of unaggregated :class:`~torch.autograd.profiler_util.FunctionEvent` + objects, for use in the trace callback or after profiling has finished. """ if self.profiler is None: raise AssertionError("Profiler must be initialized before accessing events") @@ -665,8 +667,10 @@ class profile(_KinetoProfile): The same activity group must not appear more than once. schedule (Callable): callable that takes step (int) as a single parameter and returns ``ProfilerAction`` value that specifies the profiler action to perform at each step. - on_trace_ready (Callable): callable that is called at each step when ``schedule`` - returns ``ProfilerAction.RECORD_AND_SAVE`` during the profiling. + on_trace_ready (Callable): callable invoked at the end of each profiling cycle + (when ``schedule`` returns ``ProfilerAction.RECORD_AND_SAVE``). Receives the + :class:`profile` instance as its only argument, typically used to export the + trace (e.g. via :meth:`export_chrome_trace`) or print a summary. record_shapes (bool): save information about operator's input shapes. profile_memory (bool): track tensor memory allocation/deallocation. with_stack (bool): record source information (file and line number) for the ops. @@ -689,6 +693,9 @@ class profile(_KinetoProfile): post_processing_timeout_s (float): Optional timeout in seconds for post-processing profiler results. If specified, event parsing will stop after this duration and return partial results. Useful for handling large traces that may take too long to process. + custom_trace_id_callback (Callable[[], str], optional): User-supplied trace ID generator, + invoked once per profiling cycle. Defaults to a random UUID; retrieve via + :meth:`get_trace_id`. use_cuda (bool): .. deprecated:: 1.8.1 use ``activities`` instead. @@ -970,7 +977,8 @@ def step(self) -> None: def set_custom_trace_id_callback(self, callback) -> None: """ - Sets a callback to be called when a new trace ID is generated. + Set the trace ID generator. Called at the start of each cycle, so updating + it between cycles yields distinct IDs per cycle. """ self.custom_trace_id_callback = callback diff --git a/torch/quantization/fx/__init__.py b/torch/quantization/fx/__init__.py index c01cbd457374c..d7ca02d00ee17 100644 --- a/torch/quantization/fx/__init__.py +++ b/torch/quantization/fx/__init__.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r""" This file is in the process of migration to `torch/ao/quantization`, and is kept here for compatibility while the migration process is ongoing. diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 89f8d4406e912..cf82b18776a09 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -1,4 +1,3 @@ -# flake8: noqa: F401 r""" This file is in the process of migration to `torch/ao/quantization`, and is kept here for compatibility while the migration process is ongoing. diff --git a/torch/return_types.py b/torch/return_types.py index d456742be4b88..4c74fe01537e7 100644 --- a/torch/return_types.py +++ b/torch/return_types.py @@ -39,6 +39,7 @@ def structseq_unflatten(values, context): globals()[name] = _attr if not name.startswith("_"): + # pyrefly: ignore [unresolvable-dunder-all] __all__.append(name) all_return_types.append(_attr) diff --git a/torch/serialization.py b/torch/serialization.py index f6ebe0fce7e4d..32314ae0723e3 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1384,7 +1384,7 @@ def load( second step is a no-op if the final location is CPU. When the ``mmap`` flag is set, instead of copying the tensor storages from disk to CPU memory in the first step, ``f`` is mapped, which means tensor storages will be lazily loaded when their data is accessed. - pickle_load_args: (Python 3 only) optional keyword arguments passed over to + pickle_load_args: optional keyword arguments passed over to :func:`pickle_module.load` and :func:`pickle_module.Unpickler`, only works if :attr:`weights_only=False`, e.g., :attr:`errors=...`. diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py index 1e71d8fc4d5c2..78846265e9669 100644 --- a/torch/sparse/__init__.py +++ b/torch/sparse/__init__.py @@ -642,20 +642,27 @@ def convert_to_strided_representation(args): if obj.layout is torch.sparse_coo: # pyrefly: ignore [no-matching-overload] d.update( - indices=obj._indices(), is_coalesced=obj.is_coalesced() + # pyrefly: ignore [bad-argument-type] + indices=obj._indices(), + # pyrefly: ignore [bad-argument-type] + is_coalesced=obj.is_coalesced(), ) values = obj._values() elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}: # pyrefly: ignore [no-matching-overload] d.update( + # pyrefly: ignore [bad-argument-type] compressed_indices=obj.crow_indices(), + # pyrefly: ignore [bad-argument-type] plain_indices=obj.col_indices(), ) values = obj.values() else: # pyrefly: ignore [no-matching-overload] d.update( + # pyrefly: ignore [bad-argument-type] compressed_indices=obj.ccol_indices(), + # pyrefly: ignore [bad-argument-type] plain_indices=obj.row_indices(), ) values = obj.values() diff --git a/torch/sparse/_semi_structured_ops.py b/torch/sparse/_semi_structured_ops.py index 6e12d76fc8abc..0254a33a4bc62 100644 --- a/torch/sparse/_semi_structured_ops.py +++ b/torch/sparse/_semi_structured_ops.py @@ -16,6 +16,8 @@ "semi_sparse_linear", "semi_sparse_scaled_mm", "semi_sparse_clone", + "semi_sparse_to", + "semi_sparse_to_copy", ] @@ -137,20 +139,14 @@ def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor: "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented" ) if isinstance(A, torch.sparse.SparseSemiStructuredTensor): - row, col = B.shape - B_padded = A._pad_dense_input(B) - res = A._mm(B_padded) - return res[:, :col] + return A._mm(B) else: B_t = B.t() if not isinstance(B_t, torch.sparse.SparseSemiStructuredTensor): raise AssertionError( f"expected SparseSemiStructuredTensor, got {type(B_t).__name__}" ) - row, col = A.shape - A_padded = B._pad_dense_input(A) - res = B_t._mm(A_padded.t()).t() - return res[:row, :] + return B_t._mm(A, should_transpose_dense=True).t() def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor: @@ -175,9 +171,7 @@ def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor: f"expected SparseSemiStructuredTensor, got {type(B_t).__name__}" ) row, _col = A.shape - A_padded = B_t._pad_dense_input(A) - result = B_t._mm(A_padded.t(), bias=bias).t() - return result[:row, :] + return B_t._mm(A, bias=bias, should_transpose_dense=True).t() def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor: @@ -188,7 +182,6 @@ def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor: shape = A.shape A_2d = A.view(-1, shape[-1]) - if bias is None: res = A_2d @ B.t() else: @@ -197,7 +190,6 @@ def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor: types=None, args=[bias, A_2d, B.t()], ) - return res.view(*shape[:-1], -1) @@ -263,3 +255,46 @@ def semi_sparse_clone(func, types, args=(), kwargs=None) -> torch.Tensor: alg_id_cusparselt=self.alg_id_cusparselt, requires_grad=self.requires_grad, ) + + +def semi_sparse_to_copy(func, types, args, kwargs=None) -> torch.Tensor: + self = args[0] + kwargs = kwargs or {} + + device = kwargs.get("device", None) + + if device is not None and torch.device(device).type == "cpu": + dense = self.to_dense() + return func(dense, **kwargs) + + raise NotImplementedError( + f"`_to_copy()` with kwargs={kwargs} is not implemented " + "for SparseSemiStructuredTensor. Only converting to CPU is supported currently." + ) + + +def semi_sparse_to(func, types, args, kwargs=None) -> torch.Tensor: + self = args[0] + remaining_args = args[1:] + kwargs = kwargs or {} + + # Determine the target device from args/kwargs + device = None + if remaining_args: + first_arg = remaining_args[0] + if isinstance(first_arg, (torch.device, str)): + try: + device = torch.device(first_arg) + except RuntimeError: + pass + if "device" in kwargs: + device = torch.device(kwargs["device"]) + + if device is not None and device.type == "cpu": + dense = self.to_dense() + return func(dense, *remaining_args, **kwargs) + + raise NotImplementedError( + f"`to()` with args={remaining_args}, kwargs={kwargs} is not implemented " + "for SparseSemiStructuredTensor. Only `to('cpu')` is supported currently." + ) diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index 106306c6e4cde..c67ea08fa3490 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -9,7 +9,7 @@ from torch._dynamo.utils import warn_once from torch.utils._triton import has_triton -from ._triton_ops_meta import get_meta +from ._triton_ops_meta import _get_device_name, get_meta TORCH_SPARSE_BSR_SCATTER_MM_LRU_CACHE_SIZE = int( @@ -31,7 +31,7 @@ def check_bsr_layout(f_name, t): def check_device(f_name, t, device): check( - t.device == device and t.device.type == "cuda", + t.device == device and t.device.type in ("cuda", "xpu"), f"{f_name}(): all inputs are expected to be on the same GPU device.", ) @@ -529,7 +529,7 @@ def scatter_mm_meta( **extra, ): if {TILE_M, TILE_N, SPLIT_N, num_warps, num_stages, GROUP_SIZE} == {None}: - device_name = torch.cuda.get_device_name() + device_name = _get_device_name() meta = get_meta( "scatter_mm", (M, K, N, Ms, Ks), @@ -552,28 +552,28 @@ def scatter_mm_meta( TILE_N = 16 GROUP_SIZE = 4 num_stages = 1 - num_warps = 4 # noqa: E225,E231,E702 + num_warps = 4 # noqa: E225, E231 elif (Ms, Ks) == (32, 32): SPLIT_N = 2 TILE_M = 32 TILE_N = 16 GROUP_SIZE = 4 num_stages = 1 - num_warps = 4 # noqa: E225,E231,E702 + num_warps = 4 # noqa: E225, E231 elif (Ms, Ks) == (64, 64): SPLIT_N = 1 TILE_M = 32 TILE_N = 32 GROUP_SIZE = 4 num_stages = 1 - num_warps = 4 # noqa: E225,E231,E702 + num_warps = 4 # noqa: E225, E231 elif (Ms, Ks) == (128, 128): SPLIT_N = 1 TILE_M = 32 TILE_N = 32 GROUP_SIZE = 2 num_stages = 1 - num_warps = 4 # noqa: E225,E231,E702 + num_warps = 4 # noqa: E225, E231 elif (M, K, N) == (512,) * 3: if (Ms, Ks) == (16, 16): SPLIT_N = 8 @@ -581,28 +581,28 @@ def scatter_mm_meta( TILE_N = 64 GROUP_SIZE = 2 num_stages = 1 - num_warps = 2 # noqa: E225,E231,E702 + num_warps = 2 # noqa: E225, E231 elif (Ms, Ks) == (32, 32): SPLIT_N = 8 TILE_M = 32 TILE_N = 64 GROUP_SIZE = 4 num_stages = 1 - num_warps = 2 # noqa: E225,E231,E702 + num_warps = 2 # noqa: E225, E231 elif (Ms, Ks) == (64, 64): SPLIT_N = 4 TILE_M = 32 TILE_N = 128 GROUP_SIZE = 4 num_stages = 1 - num_warps = 4 # noqa: E225,E231,E702 + num_warps = 4 # noqa: E225, E231 elif (Ms, Ks) == (128, 128): SPLIT_N = 8 TILE_M = 64 TILE_N = 64 GROUP_SIZE = 4 num_stages = 1 - num_warps = 4 # noqa: E225,E231,E702 + num_warps = 4 # noqa: E225, E231 elif (M, K, N) == (1024,) * 3: if (Ms, Ks) == (16, 16): SPLIT_N = 4 @@ -610,35 +610,35 @@ def scatter_mm_meta( TILE_N = 128 GROUP_SIZE = 2 num_stages = 1 - num_warps = 1 # noqa: E225,E231,E702 + num_warps = 1 # noqa: E225, E231 elif (Ms, Ks) == (32, 32): SPLIT_N = 8 TILE_M = 32 TILE_N = 64 GROUP_SIZE = 2 num_stages = 1 - num_warps = 1 # noqa: E225,E231,E702 + num_warps = 1 # noqa: E225, E231 elif (Ms, Ks) == (64, 64): SPLIT_N = 16 TILE_M = 64 TILE_N = 64 GROUP_SIZE = 4 num_stages = 1 - num_warps = 2 # noqa: E225,E231,E702 + num_warps = 2 # noqa: E225, E231 elif (Ms, Ks) == (128, 128): SPLIT_N = 16 TILE_M = 64 TILE_N = 64 GROUP_SIZE = 4 num_stages = 1 - num_warps = 4 # noqa: E225,E231,E702 + num_warps = 4 # noqa: E225, E231 elif (Ms, Ks) == (256, 256): SPLIT_N = 16 TILE_M = 64 TILE_N = 64 GROUP_SIZE = 2 num_stages = 1 - num_warps = 4 # noqa: E225,E231,E702 + num_warps = 4 # noqa: E225, E231 elif (M, K, N) == (2048,) * 3: if (Ms, Ks) == (16, 16): SPLIT_N = 4 @@ -646,35 +646,35 @@ def scatter_mm_meta( TILE_N = 128 GROUP_SIZE = 8 num_stages = 1 - num_warps = 1 # noqa: E225,E231,E702 + num_warps = 1 # noqa: E225, E231 elif (Ms, Ks) == (32, 32): SPLIT_N = 4 TILE_M = 32 TILE_N = 64 GROUP_SIZE = 4 num_stages = 1 - num_warps = 1 # noqa: E225,E231,E702 + num_warps = 1 # noqa: E225, E231 elif (Ms, Ks) == (64, 64): SPLIT_N = 4 TILE_M = 64 TILE_N = 128 GROUP_SIZE = 4 num_stages = 1 - num_warps = 4 # noqa: E225,E231,E702 + num_warps = 4 # noqa: E225, E231 elif (Ms, Ks) == (128, 128): SPLIT_N = 8 TILE_M = 64 TILE_N = 64 GROUP_SIZE = 4 num_stages = 1 - num_warps = 4 # noqa: E225,E231,E702 + num_warps = 4 # noqa: E225, E231 elif (Ms, Ks) == (256, 256): SPLIT_N = 4 TILE_M = 64 TILE_N = 64 GROUP_SIZE = 2 num_stages = 1 - num_warps = 4 # noqa: E225,E231,E702 + num_warps = 4 # noqa: E225, E231 elif (M, K, N) == (4096,) * 3: if (Ms, Ks) == (16, 16): SPLIT_N = 2 @@ -682,21 +682,21 @@ def scatter_mm_meta( TILE_N = 256 GROUP_SIZE = 2 num_stages = 1 - num_warps = 2 # noqa: E225,E231,E702 + num_warps = 2 # noqa: E225, E231 elif (Ms, Ks) == (32, 32): SPLIT_N = 2 TILE_M = 32 TILE_N = 64 GROUP_SIZE = 2 num_stages = 1 - num_warps = 1 # noqa: E225,E231,E702 + num_warps = 1 # noqa: E225, E231 elif (Ms, Ks) == (64, 64): SPLIT_N = 2 TILE_M = 64 TILE_N = 128 GROUP_SIZE = 2 num_stages = 1 - num_warps = 4 # noqa: E225,E231,E702 + num_warps = 4 # noqa: E225, E231 if SPLIT_N is None: # Assume NVIDIA GeForce RTX 2060 SUPER: @@ -785,7 +785,7 @@ def bsr_dense_addmm_meta( if sparsity is None: sparsity = 0.5 if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}: - device_name = torch.cuda.get_device_name() + device_name = _get_device_name() key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1) if dtype is out_dtype: version_dtype = dtype diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index 64f9ca15ca8da..7ced55af3b335 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -111,6 +111,15 @@ from torch.testing import make_tensor +def _get_device_name() -> str: + """Return the current accelerator device name for use as a Triton tuning cache key.""" + if torch.cuda.is_available(): + return torch.cuda.get_device_name() + if torch.xpu.is_available(): + return torch.xpu.get_device_name() + return "" + + def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=False): """Return triton kernel meta parameters of the specified op and its inputs key. @@ -135,7 +144,7 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F mappings that match with the given `key`. """ if device_name is None: - device_name = torch.cuda.get_device_name() + device_name = _get_device_name() op_data = _operation_device_version_data.get((op, device_name, version)) if op_data is None and not exact: @@ -144,7 +153,10 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F # meta parameters have been computed. In the following we'll # assume that there is a set of GPU models that all have # a similar set of optimal meta parameters. - if re.match(r"NVIDIA A100[^\d]", device_name) is not None: + if ( + device_name is not None + and re.match(r"NVIDIA A100[^\d]", device_name) is not None + ): device_name = "NVIDIA A100-SXM4-80GB" else: return @@ -483,7 +495,7 @@ def optimize_scatter_mm( key = (m, k, n, bm, bk) version = (0, dtype, sparsity) - device_name = torch.cuda.get_device_name() + device_name = _get_device_name() reference_meta = dict( GROUP_SIZE=1, @@ -576,7 +588,7 @@ def step_meta_parameter(name, value, direction, meta, m=m, n=n, k=k, bm=bm, bk=b print(f"{meta=} {speedup=:.1f} % {timing=:.3f} ms") if speedup < 0: return - device_name = torch.cuda.get_device_name() + device_name = _get_device_name() update( "scatter_mm", device_name, version, key, tuple(meta[k] for k in sorted(meta)) @@ -738,7 +750,7 @@ def step_meta_parameter(name, value, direction, meta, M=M, N=N, K=K, BM=BM, BK=B if store and not ( may_skip_update and meta == initial_meta and initial_meta is not reference_meta ): - device_name = torch.cuda.get_device_name() + device_name = _get_device_name() update( opname, device_name, @@ -937,7 +949,7 @@ def test_func(): print(sparsity1, index, key, meta_lst, speeddiff) if index > 0: - device_name = torch.cuda.get_device_name() + device_name = _get_device_name() meta = get_meta( op, key, version=(0, dtype, meta_lst[0][1]), exact=True ) diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index 7e9fb868c335f..d56259548f8af 100644 --- a/torch/sparse/semi_structured.py +++ b/torch/sparse/semi_structured.py @@ -19,6 +19,8 @@ semi_sparse_mm, semi_sparse_scaled_mm, semi_sparse_t, + semi_sparse_to, + semi_sparse_to_copy, semi_sparse_values, semi_sparse_view, ) @@ -75,7 +77,7 @@ class SparseSemiStructuredTensor(torch.Tensor): __slots__ = ["packed", "meta", "packed_t", "meta_t", "compressed_swizzled_bitmask"] @staticmethod - def __new__( # noqa: PYI034 + def __new__( cls, shape: torch.Size, packed: torch.Tensor | None, @@ -231,9 +233,10 @@ def _load_dispatch_table(cls, custom_dispatch_table=None) -> None: torch.ops.aten.matmul: semi_sparse_mm, torch.ops.aten.addmm: semi_sparse_addmm, torch.ops.aten.linear: semi_sparse_linear, - torch.ops.aten._to_copy: fallback_dispatcher, + torch.ops.aten._to_copy: semi_sparse_to_copy, torch.ops.aten._scaled_mm: semi_sparse_scaled_mm, torch.ops.aten.clone: semi_sparse_clone, + torch.ops.aten.to: semi_sparse_to, } if custom_dispatch_table is not None: cls.SPARSE_DISPATCH.update(custom_dispatch_table) @@ -281,35 +284,16 @@ def _validate_device_dim_dtype_shape(cls, original_tensor: torch.Tensor) -> None f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})" ) - @classmethod - def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor: - """ - Calculates padding for dense tensor and pads tensor if necessary. - If padding is not required, this function returns the original tensor. - """ - # only 2d matmul - if dense_input.dim() != 2: - raise AssertionError(f"dense_input must be 2D, got {dense_input.dim()}D") - - # check shape - m, n = dense_input.shape - min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows - min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols - - # calculate padding - to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0 - to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0 - if to_pad_m or to_pad_n: - return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m)) - else: - return dense_input - def to_dense(self): # type:ignore[override] col = self.shape[-1] return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device)) @classmethod - def from_dense(cls, original_tensor: torch.Tensor) -> "SparseSemiStructuredTensor": + def from_dense( + cls, + original_tensor: torch.Tensor, + alg_id: int = _DEFAULT_ALG_ID, + ) -> "SparseSemiStructuredTensor": raise NotImplementedError def _mm( @@ -325,6 +309,7 @@ def _mm( def to_sparse_semi_structured( original_tensor: torch.Tensor, transposed: bool = False, + alg_id: int = SparseSemiStructuredTensor._DEFAULT_ALG_ID, ) -> SparseSemiStructuredTensor: """ This function converts a dense tensor into a sparse semi-structured tensor. @@ -338,6 +323,8 @@ def to_sparse_semi_structured( Args: original_tensor (Tensor): the dense tensor to convert transposed (bool, optional): deprecated arg to be removed in another release. Do not use. + alg_id (int, optional): the algorithm id to use for cuSPARSELt matmul. Defaults to 0. + Can be obtained via ``torch._cslt_sparse_mm_search``. Returns: SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor Raises: @@ -387,7 +374,7 @@ def to_sparse_semi_structured( else torch.sparse.SparseSemiStructuredTensorCUSPARSELT ) - return SPARSE_SUBCLASS.from_dense(original_tensor) + return SPARSE_SUBCLASS.from_dense(original_tensor, alg_id=alg_id) class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): @@ -412,7 +399,9 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor): @classmethod def from_dense( - cls, original_tensor: torch.Tensor + cls, + original_tensor: torch.Tensor, + alg_id: int = SparseSemiStructuredTensor._DEFAULT_ALG_ID, ) -> "SparseSemiStructuredTensorCUTLASS": cls._validate_device_dim_dtype_shape(original_tensor) ( @@ -519,7 +508,12 @@ def prune_dense_static_sort( ) def _mm( - self, B: torch.Tensor, *, bias: torch.Tensor | None = None, **kwargs + self, + B: torch.Tensor, + *, + bias: torch.Tensor | None = None, + should_transpose_dense: bool = False, + **kwargs, ) -> torch.Tensor: if isinstance(B, SparseSemiStructuredTensor): raise ValueError( @@ -535,13 +529,18 @@ def _mm( f"`{cls_name}` matmul: operation is not supported" ) else: - if bias is None: - res = torch._sparse_semi_structured_mm(self.packed, self.meta, B) - else: - res = torch._sparse_semi_structured_addmm( - bias, self.packed, self.meta, B - ) - return res[: self.shape[0]] + _ensure_cutlass_mm_registered() + constraints = self._DTYPE_SHAPE_CONSTRAINTS[B.dtype] + return torch.ops.semi_structured.cutlass_mm( + B, + self.packed, + self.meta, + bias, + self.shape[0], + constraints.dense_min_rows, + constraints.dense_min_cols, + should_transpose_dense, + ) class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): @@ -566,7 +565,9 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor): @classmethod def from_dense( - cls, original_tensor: torch.Tensor + cls, + original_tensor: torch.Tensor, + alg_id: int = SparseSemiStructuredTensor._DEFAULT_ALG_ID, ) -> "SparseSemiStructuredTensorCUSPARSELT": cls._validate_device_dim_dtype_shape(original_tensor) # pyrefly: ignore [no-matching-overload] @@ -578,7 +579,7 @@ def from_dense( meta_t=None, compressed_swizzled_bitmask=None, fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE, - alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID, + alg_id_cusparselt=alg_id, requires_grad=original_tensor.requires_grad, ) @@ -648,7 +649,12 @@ def prune_dense_static_sort( ) def _mm( - self, B: torch.Tensor, *, bias: torch.Tensor | None = None, **kwargs + self, + B: torch.Tensor, + *, + bias: torch.Tensor | None = None, + should_transpose_dense: bool = False, + **kwargs, ) -> torch.Tensor: if isinstance(B, SparseSemiStructuredTensor): raise ValueError( @@ -682,11 +688,153 @@ def _mm( f"`{self.__class__.__name__}` matmul: operation is not supported" ) else: - res = torch._cslt_sparse_mm( - self.packed, + _ensure_cusparselt_mm_registered() + constraints = self._DTYPE_SHAPE_CONSTRAINTS[B.dtype] + return torch.ops.semi_structured.cusparselt_mm( B, - bias=bias, - transpose_result=self.fuse_transpose_cusparselt, - alg_id=self.alg_id_cusparselt, + self.packed, + bias, + self.shape[0], + constraints.dense_min_rows, + constraints.dense_min_cols, + self.fuse_transpose_cusparselt, + self.alg_id_cusparselt, + should_transpose_dense, ) - return res.t() if self.fuse_transpose_cusparselt else res + + +_cutlass_mm_registered = False + + +def _ensure_cutlass_mm_registered(): + """Lazily register the cutlass_mm custom op. + + Registration is deferred to avoid importing torch.library at module load + time, since torch.sparse is imported early during ``import torch``. + """ + global _cutlass_mm_registered + if _cutlass_mm_registered: + return + _cutlass_mm_registered = True + + from torch.library import custom_op + + @custom_op("semi_structured::cutlass_mm", mutates_args=()) + def cutlass_mm( + dense: torch.Tensor, + packed: torch.Tensor, + meta: torch.Tensor, + bias: torch.Tensor | None, + out_features: int, + min_rows: int, + min_cols: int, + should_transpose_dense: bool, + ) -> torch.Tensor: + m, n = dense.shape + to_pad_m = (-m) % min_rows + to_pad_n = (-n) % min_cols + need_pad = to_pad_m != 0 or to_pad_n != 0 + dense_padded = dense + if need_pad: + dense_padded = torch.nn.functional.pad(dense, (0, to_pad_n, 0, to_pad_m)) + mm_input = dense_padded.t() if should_transpose_dense else dense_padded + if bias is None: + res = torch._sparse_semi_structured_mm(packed, meta, mm_input) + else: + res = torch._sparse_semi_structured_addmm(bias, packed, meta, mm_input) + if need_pad: + out_cols = m if should_transpose_dense else n + return ( + res[:out_features] + .narrow(1, 0, out_cols) + .clone(memory_format=torch.contiguous_format) + ) + return res.contiguous() + + @cutlass_mm.register_fake + def _cutlass_mm_fake( + dense: torch.Tensor, + packed: torch.Tensor, + meta: torch.Tensor, + bias: torch.Tensor | None, + out_features: int, + min_rows: int, + min_cols: int, + transpose_dense: bool, + ) -> torch.Tensor: + out_cols = dense.shape[0] if transpose_dense else dense.shape[1] + return torch.empty( + out_features, + out_cols, + dtype=dense.dtype, + device=dense.device, + ) + + +_cusparselt_mm_registered = False + + +def _ensure_cusparselt_mm_registered(): + """Lazily register the cusparselt_mm custom op.""" + global _cusparselt_mm_registered + if _cusparselt_mm_registered: + return + _cusparselt_mm_registered = True + + from torch.library import custom_op + + @custom_op("semi_structured::cusparselt_mm", mutates_args=()) + def cusparselt_mm( + dense: torch.Tensor, + packed: torch.Tensor, + bias: torch.Tensor | None, + out_features: int, + min_rows: int, + min_cols: int, + fuse_transpose: bool, + alg_id: int, + should_transpose_dense: bool = False, + ) -> torch.Tensor: + m, n = dense.shape + to_pad_m = (-m) % min_rows + to_pad_n = (-n) % min_cols + need_pad = to_pad_m != 0 or to_pad_n != 0 + dense_padded = dense + if need_pad: + dense_padded = torch.nn.functional.pad(dense, (0, to_pad_n, 0, to_pad_m)) + mm_input = dense_padded.t() if should_transpose_dense else dense_padded + res = torch._cslt_sparse_mm( + packed, + mm_input, + bias=bias, + transpose_result=fuse_transpose, + alg_id=alg_id, + ) + if fuse_transpose: + res = res.t() + if need_pad: + out_cols = m if should_transpose_dense else n + return res.narrow(1, 0, out_cols).clone( + memory_format=torch.contiguous_format + ) + return res.contiguous() + + @cusparselt_mm.register_fake + def _cusparselt_mm_fake( + dense: torch.Tensor, + packed: torch.Tensor, + bias: torch.Tensor | None, + out_features: int, + min_rows: int, + min_cols: int, + fuse_transpose: bool, + alg_id: int, + should_transpose_dense: bool, + ) -> torch.Tensor: + out_cols = dense.shape[0] if should_transpose_dense else dense.shape[1] + return torch.empty( + out_features, + out_cols, + dtype=dense.dtype, + device=dense.device, + ) diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py index 2605411ed7c9d..80dfdc63347e5 100644 --- a/torch/testing/_comparison.py +++ b/torch/testing/_comparison.py @@ -18,6 +18,37 @@ HAS_NUMPY = False np = None # type: ignore[assignment] +_HAS_DTENSOR = torch.distributed.is_available() + + +def _unwrap_dtensor_for_comparison(actual, expected): + """Handle DTensor inputs for assertEqual/assert_close.""" + if not _HAS_DTENSOR: + return actual, expected + from torch.distributed.tensor import DTensor + + actual_dt = isinstance(actual, DTensor) + expected_dt = isinstance(expected, DTensor) + if actual_dt and expected_dt: + if actual.placements != expected.placements: + raise AssertionError( + f"DTensor placements do not match: " + f"{actual.placements} != {expected.placements}" + ) + if actual.device_mesh != expected.device_mesh: + raise AssertionError( + f"DTensor device meshes do not match: " + f"{actual.device_mesh} != {expected.device_mesh}" + ) + return actual.to_local(), expected.to_local() + elif actual_dt != expected_dt: + raise TypeError( + "Comparing a DTensor to a non-DTensor is ambiguous. " + "Call .full_tensor() to compare the full logical tensor " + "or .to_local() to compare the local shard." + ) + return actual, expected + class ErrorMeta(Exception): """Internal testing exception that makes that carries error metadata.""" @@ -337,7 +368,7 @@ def unravel_flat_index(flat_index: int) -> tuple[int, ...]: ) -class UnsupportedInputs(Exception): # noqa: B903 +class UnsupportedInputs(Exception): """Exception to be raised during the construction of a :class:`Pair` in case it doesn't support the inputs.""" @@ -1288,7 +1319,7 @@ def not_close_error_metas( ) except ErrorMeta as error_meta: # Explicitly raising from None to hide the internal traceback - raise error_meta.to_error() from None # noqa: RSE102 + raise error_meta.to_error() from None error_metas: list[ErrorMeta] = [] for pair in pairs: @@ -1573,6 +1604,8 @@ def assert_close( # Hide this function from `pytest`'s traceback __tracebackhide__ = True + actual, expected = _unwrap_dtensor_for_comparison(actual, expected) + error_metas = not_close_error_metas( actual, expected, diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 8575c0a75f77d..2449f76d3f82f 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -5,7 +5,7 @@ import functools import torch import torch.cuda -from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS, IS_MACOS +from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS, IS_MACOS, TEST_XPU import inspect import contextlib import os @@ -43,6 +43,7 @@ IS_SM89 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 9)) IS_SM90 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)) IS_SM100 = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() == (10, 0)) +IS_SM12X = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 12) @contextlib.contextmanager def blas_library_context(backend): @@ -76,8 +77,16 @@ def evaluate_platform_supports_flash_attention(): return evaluate_gfx_arch_within(arch_list) if TEST_CUDA: return not IS_WINDOWS and SM80OrLater + if TEST_XPU: + return True return False +def evaluate_platform_supports_ck_sdpa(): + if TEST_WITH_ROCM: + return torch.backends.cuda.is_ck_sdpa_available() + else: + return False + def evaluate_platform_supports_efficient_attention(): if TEST_WITH_ROCM: arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] @@ -86,6 +95,8 @@ def evaluate_platform_supports_efficient_attention(): return evaluate_gfx_arch_within(arch_list) if TEST_CUDA: return True + if TEST_XPU: + return True return False def evaluate_platform_supports_cudnn_attention(): @@ -111,12 +122,16 @@ def evaluate_platform_supports_green_context(): PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM +PLATFORM_SUPPORTS_CK_SDPA: bool = LazyVal(lambda: evaluate_platform_supports_ck_sdpa()) + def evaluate_platform_supports_bf16(): if torch.version.cuda: return SM80OrLater elif torch.version.hip: return True + elif TEST_XPU: + return True return False @@ -140,6 +155,18 @@ def evaluate_platform_supports_half_atomics(): PLATFORM_SUPPORTS_GREEN_CONTEXT: bool = LazyVal(lambda: evaluate_platform_supports_green_context()) +def evaluate_platform_supports_workqueue_config(): + if IS_WINDOWS: + return False + if not _get_torch_cuda_version() >= (13, 1): + return False + driver_version = torch.utils.collect_env.get_nvidia_driver_version(torch.utils.collect_env.run) + if driver_version is None: + return False + return int(driver_version.split('.')[0]) >= 590 + +PLATFORM_SUPPORTS_WORKQUEUE_CONFIG: bool = LazyVal(lambda: evaluate_platform_supports_workqueue_config()) + def evaluate_platform_supports_fp8(): if torch.cuda.is_available(): if torch.version.hip: @@ -151,9 +178,13 @@ def evaluate_platform_supports_fp8(): for arch in archs: if arch in torch.cuda.get_device_properties(0).gcnArchName: return True + return False else: return SM90OrLater or torch.cuda.get_device_capability() == (8, 9) - return False + if torch.xpu.is_available(): + return True + # As CPU supports FP8 and is always available, return True. + return True def evaluate_platform_supports_fp8_grouped_gemm(): if torch.cuda.is_available(): @@ -183,8 +214,21 @@ def evaluate_platform_supports_mxfp8_grouped_gemm(): return built_with_mslk and IS_SM100 return False +def evaluate_platform_supports_fp8_sparse(): + if torch.cuda.is_available(): + if torch.version.hip: + return 'gfx950' in torch.cuda.get_device_properties(0).gcnArchName + else: + return ( + (SM90OrLater or torch.cuda.get_device_capability() == (8, 9)) + and torch.backends.cusparselt.is_available() + and torch.backends.cusparselt.version() >= 602 + ) + return False + PLATFORM_SUPPORTS_MX_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mx_gemm()) PLATFORM_SUPPORTS_FP8: bool = LazyVal(lambda: evaluate_platform_supports_fp8()) +PLATFORM_SUPPORTS_FP8_SPARSE: bool = LazyVal(lambda: evaluate_platform_supports_fp8_sparse()) PLATFORM_SUPPORTS_FP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_fp8_grouped_gemm()) PLATFORM_SUPPORTS_MXFP8_GROUPED_GEMM: bool = LazyVal(lambda: evaluate_platform_supports_mxfp8_grouped_gemm()) @@ -192,7 +236,7 @@ def evaluate_platform_supports_mxfp8_grouped_gemm(): try: import numba.cuda TEST_NUMBA_CUDA = numba.cuda.is_available() - except (ImportError, RuntimeError): + except (ImportError, RuntimeError, OSError): TEST_NUMBA_CUDA = False TEST_NUMBA = False else: @@ -442,6 +486,9 @@ def xfailIfSM100OrLater(func): def xfailIfSM120OrLater(func): return func if not SM120OrLater else unittest.expectedFailure(func) +def xfailIfSM12X(func): + return func if not IS_SM12X else unittest.expectedFailure(func) + def xfailIfDistributedNotSupported(func): return func if not (IS_MACOS or IS_JETSON) else unittest.expectedFailure(func) diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index e2762861ccdac..4dc5dc360df86 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -17,6 +17,7 @@ import torch from torch._inductor.utils import GPU_TYPES +from torch._utils import _is_privateuse1_backend_available from torch.testing._internal.common_cuda import ( _get_torch_cuda_version, _get_torch_hipblaslt_version, @@ -33,7 +34,6 @@ get_tracked_input, IS_FBCODE, IS_MACOS, - is_privateuse1_backend_available, IS_REMOTE_GPU, IS_S390X, IS_SANDCASTLE, @@ -318,6 +318,15 @@ def _update_param_kwargs(param_kwargs, name, value): class DeviceTypeTestBase(TestCase): device_type: str = "generic_device_type" + # When True, @onlyOn-based decorators (@onlyCUDA, @onlyMPS, etc.) will not + # skip tests for this device type. This is a pragmatic short-term solution to + # allow PrivateUse1 backends to run tests that are currently gated behind + # device-specific decorators. It is intended to be used together with the + # skip mechanism (see https://github.com/pytorch/pytorch/issues/177253). + # In the longer term, we are incrementally migrating accelerator tests to be + # device-generic and removing @onlyCUDA on tests that should be device-generic. + bypass_device_restrictions: bool = False + # Flag to disable test suite early due to unrecoverable error such as CUDA error. _stop_test_suite = False @@ -670,6 +679,7 @@ class PrivateUse1TestBase(DeviceTypeTestBase): primary_device: ClassVar[str] device_mod = None device_type = "privateuse1" + bypass_device_restrictions = False @classmethod def get_primary_device(cls): @@ -723,7 +733,7 @@ def get_device_type_test_bases(): if torch.cuda.is_available(): test_bases.append(CUDATestBase) - if is_privateuse1_backend_available(): + if _is_privateuse1_backend_available(): test_bases.append(PrivateUse1TestBase) # Disable MPS testing in generic device testing temporarily while we're # ramping up support. @@ -750,7 +760,7 @@ def filter_desired_device_types(device_type_test_bases, except_for=None, only_fo # This handles the case where PrivateUse1TestBase.device_type has been # changed from "privateuse1" to the actual backend name (e.g., "openreg") # by setUpClass being called during previous instantiate_device_type_tests calls - if is_privateuse1_backend_available(): + if _is_privateuse1_backend_available(): privateuse1_backend_name = torch._C._get_privateuse1_backend_name() def func_replace(x: str) -> str: @@ -1168,7 +1178,7 @@ def test_wrapper(*args, **kwargs): except Exception as e: tracked_input = get_tracked_input() if PRINT_REPRO_ON_FAILURE and tracked_input is not None: - e_tracked = Exception( # noqa: TRY002 + e_tracked = Exception( f"{str(e)}\n\nCaused by {tracked_input.type_desc} " f"at index {tracked_input.index}: " f"{_serialize_sample(tracked_input.val)}" @@ -1444,6 +1454,8 @@ def __call__(self, fn): @wraps(fn) def only_fn(slf, *args, **kwargs): if slf.device_type not in self.device_type: + if getattr(slf, "bypass_device_restrictions", False): + return fn(slf, *args, **kwargs) reason = f"Only runs on {self.device_type}" if IS_SANDCASTLE or IS_FBCODE: print( diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 86e36e5b094d0..78b27edaa82d4 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -609,7 +609,10 @@ def create_tcp_store( TIMEOUT_DEFAULT = 500 else: TIMEOUT_DEFAULT = int(os.getenv("DISTRIBUTED_TESTS_DEFAULT_TIMEOUT", "300")) -TIMEOUT_OVERRIDE = {"test_ddp_uneven_inputs": 400} +TIMEOUT_OVERRIDE = { + "test_ddp_uneven_inputs": 400, + "test_DistributedDataParallel": 500, +} # https://github.com/pytorch/pytorch/issues/75665 @@ -737,6 +740,32 @@ def cleanup_temp_dir() -> None: tmp_dir.cleanup() +def retrieve_result_from_completion_queue( + process: torch.multiprocessing.Process, + completion_queue: torch.multiprocessing.Queue, + timeout: int | None = None, +) -> Any: + """Get result from the completion_queue associated with process. + + When the process finished without putting a result or the timeout expired an exception instance will be returned""" + queue_timeout = 120 if timeout is None else max(10, min(120, timeout // 4)) + start_time = time.time() + # Periodically check the process for liveness + while True: + try: + return completion_queue.get(timeout=queue_timeout) + except queue.Empty: + # If the process is no longer alive we cannot get a result from the queue unless it is there right now. + # This can happen if the timeout occurred just before the process put its result and terminated. + # So do a last check for emptiness before considering it as a failure. + if not process.is_alive() and completion_queue.empty(): + return RuntimeError(f"Exited with {process.exitcode}") + if timeout is not None: + elapsed = time.time() - start_time + if elapsed > timeout: + return RuntimeError(f"Process timed out out after {elapsed}s") + + # Most tests operate with this worldsize if TEST_WITH_ROCM: DEFAULT_WORLD_SIZE = min(4, max(2, torch.cuda.device_count())) @@ -941,11 +970,11 @@ def run_test(self, test_name: str, parent_pipe) -> None: try: getattr(self, test_name)() except unittest.SkipTest as se: - logger.info( # noqa: G200 + logger.info( "Process %s skipping test %s for following reason: %s", self.rank, test_name, - str(se), + se, ) sys.exit(TEST_SKIPS["generic"].exit_code) except Exception: @@ -1280,7 +1309,7 @@ def worker(rank, world_pg, store): ) try: callback() - except BaseException as ex: # noqa: B036 + except BaseException as ex: # Exceptions are handled in MultiThreadedTestCase MultiThreadedTestCase.exception_queue.put((rank, sys.exc_info())) ProcessLocalGroup.exception_handle( @@ -1441,7 +1470,7 @@ def run_test_with_threaded_pg(self, test_name, rank, world_size): try: getattr(self, test_name)() - except BaseException as ex: # noqa: B036 + except BaseException as ex: self.exception_queue.put((rank, sys.exc_info())) ProcessLocalGroup.exception_handle( ex @@ -1496,7 +1525,7 @@ def _check_return_codes(cls, failed_ranks, timeout, fn): "Thread %s skipping test %s for following reason: %s", rank, fn, - str(exc), + exc, ) if skip_code < 0: skip_code = TEST_SKIPS["generic"].exit_code @@ -1812,7 +1841,7 @@ def _worker_loop(cls, rank, world_size, rdvz_file, task_queue, completion_queue) try: cls._run_test_given_id(test_id) completion_queue.put(test_id) - except BaseException as ex: # noqa: B036 + except BaseException as ex: if isinstance(ex, SystemExit): # Get exit code from the process exit_code = getattr(ex, "code", None) @@ -1877,7 +1906,7 @@ def _spawn_processes(cls, world_size) -> None: cls.processes.append(process) cls.task_queues.append(task_queue) cls.completion_queues.append(completion_queue) - logger.debug("Started process %s with pid %s", rank, process.pid) # noqa: UP031 + logger.debug("Started process %s with pid %s", rank, process.pid) @classmethod def _get_world_size(cls, device_type: str) -> int: @@ -2019,8 +2048,12 @@ def wrapper(self): # Drain all completion queues before raising any exception, # so stale results don't desync subsequent tests. deferred_exception = None - for i, completion_queue in enumerate(self.completion_queues): - rv = completion_queue.get() + for i, (p, completion_queue) in enumerate( + zip(self.processes, self.completion_queues) + ): + rv = retrieve_result_from_completion_queue( + p, completion_queue, timeout=get_timeout(self.id()) + ) if deferred_exception is not None: # Already captured an exception; just drain continue @@ -2030,7 +2063,7 @@ def wrapper(self): if isinstance(rv, BaseException): logger.warning( f"Detected failure from Rank {i} in: {self.id()}, " # noqa: G004 - f"skipping rest of tests in Test class: {self.__class__.__name__}" # noqa: G004 + f"skipping rest of tests in Test class: {self.__class__.__name__}" ) self.__class__.poison_pill = True deferred_exception = rv diff --git a/torch/testing/_internal/common_dtype.py b/torch/testing/_internal/common_dtype.py index 20c8d88b46ccf..6e62009b6b601 100644 --- a/torch/testing/_internal/common_dtype.py +++ b/torch/testing/_internal/common_dtype.py @@ -222,6 +222,20 @@ def get_all_qint_dtypes() -> list[torch.dtype]: return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4] +def highest_precision_float(device): + if torch.device(device).type == "mps": + return torch.float32 + else: + return torch.float64 + + +def highest_precision_complex(device): + if torch.device(device).type == "mps": + return torch.complex64 + else: + return torch.complex128 + + float_to_corresponding_complex_type_map = { torch.float16: torch.complex32, torch.float32: torch.complex64, diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py index cebeb1a4c0197..0b18ced69f04d 100644 --- a/torch/testing/_internal/common_fsdp.py +++ b/torch/testing/_internal/common_fsdp.py @@ -115,7 +115,7 @@ class FSDPTestModel(nn.Module, ABC): @abstractmethod def get_input(self, device) -> tuple[torch.Tensor, ...]: - """Returns an input for the model as as tuple.""" + """Returns an input for the model as a tuple.""" ... @abstractmethod diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 7e9bb20e3361f..2f33bde89df51 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -23,6 +23,8 @@ _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types, floating_and_complex_types_and, all_types_and_complex_and, all_types_and, all_types_and_complex, integral_types_and, empty_types, complex_types_and, integral_types, custom_types, all_types_complex_float8_and, float8_types, + highest_precision_complex, + highest_precision_float, ) from torch.testing._internal.common_device_type import ( onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, @@ -39,6 +41,8 @@ ) from torch.testing._internal.common_utils import ( make_fullrank_matrices_with_distinct_singular_values, + IS_ARM64, + IS_CPU_EXT_SVE_SUPPORTED, TEST_WITH_ROCM, IS_FBCODE, IS_WINDOWS, IS_MACOS, TEST_SCIPY, torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN, GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW, @@ -46,11 +50,11 @@ ) from torch.testing._utils import wrapper_set_seed -import torch._refs as refs # noqa: F401 +import torch._refs as refs import torch._refs.nn.functional import torch._refs.special import torch._refs.linalg -import torch._prims as prims # noqa: F401 +import torch._prims as prims from torch.utils import _pytree as pytree @@ -1246,7 +1250,7 @@ def sample_inputs_mv(self, device, dtype, requires_grad, **kwargs): yield SampleInput(make_arg(S, M), make_arg(M)) def sample_inputs_bmm(self, device, dtype, requires_grad, **kwargs): - make_arg = partial(make_tensor, dtype=dtype, device=device, low=-1, high=+1, requires_grad=requires_grad) + make_arg = partial(make_tensor, dtype=dtype, device=device, low=None, high=None, requires_grad=requires_grad) yield SampleInput(make_arg(M, S, M), make_arg(M, M, S)) def sample_inputs_dot_vdot(self, device, dtype, requires_grad, **kwargs): @@ -1613,7 +1617,7 @@ def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs): ((S,), {'dtype': dtype, 'device': device}), # Hard-code some dtypes/devices. We want to test cases where the # (dtype, device) is different from the input's (dtype, device) - ((S,), {'dtype': torch.double if device != 'mps:0' else torch.float}), + ((S,), {'dtype': highest_precision_float(device)}), ((S,), {'device': 'cpu'}), ((S,), {'dtype': torch.double, 'device': 'cpu'}), ] @@ -1807,7 +1811,7 @@ def sample_inputs_new_fns(self, device, dtype, requires_grad, *, is_strided=Fals ((S,), (2, 3), (7, 8), {'dtype': dtype, 'device': device}), # Hard-code some dtypes/devices. We want to test cases where the # (dtype, device) is different from the input's (dtype, device) - ((S,), (10,), (S,), {'dtype': torch.double if device != 'mps:0' else torch.float}), + ((S,), (10,), (S,), {'dtype': highest_precision_float(device)}), ((S,), (1, 1, 12), (S, L, M), {'device': 'cpu'}), ((S,), (2, 2, 2), (L, M, S), {'dtype': torch.double, 'device': 'cpu'}), ] @@ -1930,7 +1934,7 @@ def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs): def get_val(dtype): return make_tensor([], dtype=dtype, device="cpu").item() - double_dtype = torch.double if device != "mps:0" else torch.float + double_dtype = highest_precision_float(device) inputs = [ ((), get_val(dtype), {}), ((S, S), get_val(dtype), {}), @@ -2529,7 +2533,7 @@ def reference_inputs_cat(op, device, dtype, requires_grad, **kwargs): # Noncontiguous type promoting tensors a = make_arg((3, 4, 2)) - b = make_arg((3, 2, 2), noncontiguous=True, dtype=torch.double) + b = make_arg((3, 2, 2), noncontiguous=True, dtype=highest_precision_float(device)) c = make_arg((3, 3, 2), dtype=torch.float16).permute(1, 0, 2) yield SampleInput((a, b, c), kwargs={'dim': 1}) @@ -2688,7 +2692,7 @@ def error_inputs_gather(op_info, device, **kwargs): # Creates new src & idx since SampleInputs can't share tensors src = torch.tensor(((1, 2), (3, 4)), device=device, dtype=torch.float32) idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long) - out = torch.empty((2, 2), device=device, dtype=torch.float64) + out = torch.empty((2, 2), device=device, dtype=torch.float16 if torch.device(device).type == 'mps' else torch.float64) yield ErrorInput(SampleInput(src, args=(1, idx), kwargs={'out': out}), error_regex="Expected out tensor to have dtype") @@ -2751,7 +2755,7 @@ def error_inputs_scatter_and_scatter_add(op_info, device, **kwargs): # Error when self.dtype != src.dtype (and src is not a scalar) src = make_tensor((2, 5), device=device, dtype=torch.float32) idx = torch.tensor(((0, 1), (1, 2)), device=device, dtype=torch.long) - dst = torch.zeros((3, 5), device=device, dtype=torch.double) + dst = torch.zeros((3, 5), device=device, dtype=torch.float16 if torch.device(device).type == 'mps' else torch.double) yield ErrorInput(SampleInput(dst, args=(0, idx, src)), error_regex="Expected self.dtype to be equal to src.dtype") @@ -2827,7 +2831,8 @@ def error_inputs_t(op_info, device, **kwargs): def error_inputs_multinomial(op_info, device, **kwargs): - x = torch.empty(1, 2, 3, dtype=torch.double, device=device) + dtype = highest_precision_float(device) + x = torch.empty(1, 2, 3, dtype=dtype, device=device) yield ErrorInput(SampleInput(x, args=(2,)), error_regex="prob_dist must be 1 or 2 dim") @@ -2835,24 +2840,24 @@ def error_inputs_multinomial(op_info, device, **kwargs): yield ErrorInput(SampleInput(x, args=(2,)), error_regex="multinomial only supports floating-point dtypes for input") - x = torch.empty(1, 2, dtype=torch.double, device=device) - y = torch.empty(1, 2, dtype=torch.double, device=device) + x = torch.empty(1, 2, dtype=dtype, device=device) + y = torch.empty(1, 2, dtype=dtype, device=device) yield ErrorInput(SampleInput(x, args=(2,), kwargs=dict(out=y)), error_regex="multinomial expects Long tensor out") - x = torch.empty(2, dtype=torch.double, device=device) + x = torch.empty(2, dtype=dtype, device=device) yield ErrorInput(SampleInput(x, args=(0,)), error_regex="cannot sample n_sample <= 0 samples") - x = torch.empty(2, dtype=torch.double, device=device) + x = torch.empty(2, dtype=dtype, device=device) yield ErrorInput(SampleInput(x, args=(-1,)), error_regex="cannot sample n_sample <= 0 samples") - x = torch.empty(2, dtype=torch.double, device=device) + x = torch.empty(2, dtype=dtype, device=device) yield ErrorInput(SampleInput(x, args=(3, False,)), error_regex="cannot sample n_sample > prob_dist") - x = torch.empty(16777217, dtype=torch.double, device=device) + x = torch.empty(16777217, dtype=dtype, device=device) yield ErrorInput(SampleInput(x, args=(3,)), error_regex="number of categories cannot exceed") @@ -3178,7 +3183,14 @@ def sample_inputs_histc(op_info, device, dtype, requires_grad, **kwargs): sizes = ((), (S,), (S, S), (S, S, S), (S, 1, S), (S, 0, S)) - for size, min, max in product(sizes, [0, -10], [0, 10]): + minima = [0] + maxima = [10] + + if dtype.is_signed: + minima.append(-10) + maxima.append(0) + + for size, min, max in product(sizes, minima, maxima): # construct sample input omitting bins arg yield SampleInput(make_arg(size), min=min, max=max) @@ -4054,8 +4066,8 @@ def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs): def error_inputs_conv1d(opinfo, device, **kwargs): - dtype = torch.float64 if device != 'mps:0' else torch.float32 - cdtype = torch.complex128 if device != 'mps:0' else torch.complex64 + dtype = highest_precision_float(device) + cdtype = highest_precision_complex(device) make_arg = partial(make_tensor, device=device, dtype=dtype) make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) make_complex_arg = partial(make_tensor, device=device, dtype=cdtype) @@ -4116,8 +4128,8 @@ def error_inputs_conv1d(opinfo, device, **kwargs): def error_inputs_conv2d(opinfo, device, **kwargs): - dtype = torch.float64 if device != 'mps:0' else torch.float32 - cdtype = torch.complex128 if device != 'mps:0' else torch.complex64 + dtype = highest_precision_float(device) + cdtype = highest_precision_complex(device) make_arg = partial(make_tensor, device=device, dtype=dtype) make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) make_complex_arg = partial(make_tensor, device=device, dtype=cdtype) @@ -4255,8 +4267,8 @@ def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs): def error_inputs_conv3d(opinfo, device, **kwargs): - dtype = torch.float64 if device != 'mps:0' else torch.float32 - cdtype = torch.complex128 if device != 'mps:0' else torch.complex64 + dtype = highest_precision_float(device) + cdtype = highest_precision_complex(device) make_arg = partial(make_tensor, device=device, dtype=dtype) make_int_arg = partial(make_tensor, device=device, dtype=torch.int64) make_complex_arg = partial(make_tensor, device=device, dtype=cdtype) @@ -6996,8 +7008,6 @@ def skips_mvlgamma(skip_redundant=False): DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=(torch.int8,)), - # NotImplementedError: The operator 'aten::mvlgamma.out' is not currently implemented for the MPS device - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager', device_type='mps'), ) if skip_redundant: # Redundant tests @@ -9110,10 +9120,10 @@ def sample_inputs_scaled_mm_v2(op_info, device, dtype, requires_grad, **kwargs): mat2_fp4, [scale1, global_scale1], [ScalingType.BlockWise1x16, ScalingType.TensorWise], - [SwizzleType.SWIZZLE_32_4_4, ], + [SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE], [scale2, global_scale2], [ScalingType.BlockWise1x16, ScalingType.TensorWise], - [SwizzleType.SWIZZLE_32_4_4, ], + [SwizzleType.SWIZZLE_32_4_4, SwizzleType.NO_SWIZZLE], None, # bias torch.bfloat16, # out_dtype ) @@ -9410,8 +9420,9 @@ def sample_inputs_l1_loss(op_info, device, dtype, requires_grad, **kwargs): # test COMPLEX_TO_FLOAT promotion if dtype.is_complex: make = partial(make_tensor, (), device=device, requires_grad=requires_grad) - yield SampleInput(make(dtype=dtype), args=(make(dtype=torch.double),)) - yield SampleInput(make(dtype=torch.double), args=(make(dtype=dtype),)) + other_dtype = highest_precision_float(device) + yield SampleInput(make(dtype=dtype), args=(make(dtype=other_dtype),)) + yield SampleInput(make(dtype=other_dtype), args=(make(dtype=dtype),)) def error_inputs_l1_loss(op_info, device, **kwargs): make = partial(make_tensor, device=device, dtype=torch.float32) @@ -9775,6 +9786,8 @@ def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dt if self.arity == 1: if "foreach_abs" in opinfo.name and dtype in complex_types(): return True + if "foreach_clone" in opinfo.name: + return False # unary if opinfo.ref in (torch.abs, torch.neg): return False @@ -9953,7 +9966,7 @@ def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, * for ord in (0, 1, 2, -1, -2, float('inf'), float('-inf')): input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs) disable_fastpath = True - if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): + if ord in (0, 1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): disable_fastpath = False yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath) @@ -9968,7 +9981,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): for num_tensors, ord, out_dtype, intersperse_empty_tensors in product( num_input_tensors, (0, 1, 2, -1, -2, float('inf'), float('-inf')), - (None,) + (torch.complex128,) if dtype in complex_types() else (torch.float64,), + (None,) + (highest_precision_complex(device),) if dtype in complex_types() else (highest_precision_float(device),), (True, False), ): # inf norm and negative norms on empty tensors is not supported by our reference func vector norm: @@ -9979,7 +9992,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): _foreach_inputs_kwargs["intersperse_empty_tensors"] = intersperse_empty_tensors input = sample_inputs_foreach(None, device, dtype, num_tensors, zero_size=False, **_foreach_inputs_kwargs) disable_fastpath = True - if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): + if ord in (0, 1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): disable_fastpath = False yield ForeachSampleInput(input, ord=ord, disable_fastpath=disable_fastpath, dtype=out_dtype) @@ -9996,7 +10009,7 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): for input in nan_inputs: x = torch.tensor(input, device=device) disable_fastpath = True - if ord in (1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): + if ord in (0, 1, 2, float('inf')) and dtype in floating_types_and(torch.half, torch.bfloat16): disable_fastpath = False yield ForeachSampleInput([x], ord=ord, disable_fastpath=disable_fastpath) @@ -11242,6 +11255,13 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): ), ), ), + ForeachFuncInfo( + "clone", + sample_inputs_func=foreach_inputs_sample_func(1, False, False), + supports_forward_ad=True, + supports_autograd=True, + supports_inplace_autograd=True, + ), ] foreach_binary_op_db: list[OpInfo] = [ @@ -13043,6 +13063,9 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cuda', dtypes=[torch.cfloat], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), )), @@ -13544,6 +13567,9 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cuda', dtypes=(torch.chalf,), active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='xpu', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), )), UnaryUfuncInfo('cosh', ref=np_unary_ufunc_integer_promotion_wrapper(np.cosh), @@ -13576,6 +13602,9 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.expectedFailure, 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cuda', dtypes=(torch.chalf,), active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='xpu', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), )), OpInfo('cov', dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), @@ -13602,7 +13631,7 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # undefined value tensor: # File "", line 3 # def the_method(i0): - # return torch.cov(i0, correction=0, fweights=None, aweights=tensor([0.0518, 0.4681], dtype=torch.float32, requires_grad=True)) # noqa: B950 + # return torch.cov(i0, correction=0, fweights=None, aweights=tensor([0.0518, 0.4681], dtype=torch.float32, requires_grad=True)) # ~~~~~~ <--- HERE DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), DecorateInfo(toleranceOverride({torch.float16: tol(atol=8e-3, rtol=1.4e-3)}), @@ -13626,21 +13655,8 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k skips=( # cumsum does not handle correctly out= dtypes DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), - # The following dtypes did not work in forward but are listed by the OpInfo: {torch.complex64} + # The following dtypes did not work in forward but are listed by the OpInfo: {torch.bool} DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - # RuntimeError: cumulative ops are not yet supported for complex - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', - device_type='mps', dtypes=(torch.complex64,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_out_requires_grad_error', - device_type='mps', dtypes=(torch.complex64,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager', - device_type='mps', dtypes=(torch.complex64,) - ), ), sample_inputs_func=sample_inputs_cumulative_ops), OpInfo('cumprod', @@ -13651,21 +13667,8 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k skips=( # cumprod does not handle correctly out= dtypes DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), - # The following dtypes did not work in forward but are listed by the OpInfo: {torch.complex64} + # The following dtypes did not work in forward but are listed by the OpInfo: {torch.bool} DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - # RuntimeError: cumulative ops are not yet supported for complex - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', - device_type='mps', dtypes=(torch.complex64,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_out_requires_grad_error', - device_type='mps', dtypes=(torch.complex64,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager', - device_type='mps', dtypes=(torch.complex64,) - ), ), # gradgradcheck fails in fast_mode=True: #56275 sample_inputs_func=sample_inputs_cumprod, @@ -13817,6 +13820,8 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # Reference: https://github.com/pytorch/pytorch/issues/48010 DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='xpu', dtypes=(torch.chalf,), active_if=IS_WINDOWS), ), assert_autodiffed=True, supports_forward_ad=True, @@ -14371,7 +14376,7 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # see discussion : https://github.com/pytorch/pytorch/issues/56660 # RuntimeError: # Arguments for call are not valid. - DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)), # noqa: B950 + DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, torch.complex64)), DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'), DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'), ), @@ -15414,18 +15419,15 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k dtypes=floating_and_complex_types_and(torch.half, torch.bfloat16), sample_inputs_func=sample_inputs_normalize, supports_forward_ad=True, - skips=( - # Exception: norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - ), supports_fwgrad_bwgrad=True), OpInfo('aminmax', ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)), dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), decorators=(onlyNativeDeviceTypes,), - supports_autograd=False, + supports_autograd=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, sample_inputs_func=sample_inputs_aminmax, skips=( # Exception: MPS supports tensors with dimensions <= 16, but got 65. @@ -15528,8 +15530,8 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k sample_inputs_func=sample_inputs_as_strided_scatter, error_inputs_func=error_inputs_as_strided_scatter, skips=( - DecorateInfo(unittest.skip('Works for int64, fails for everything else'), 'TestCommon', 'test_noncontiguous_samples'), # noqa: B950 - DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'), # noqa: B950 + DecorateInfo(unittest.skip('Works for int64, fails for everything else'), 'TestCommon', 'test_noncontiguous_samples'), + DecorateInfo(unittest.skip('Fails in most cases, passes on LAZY for some reason'), 'TestCommon', 'test_variant_consistency_eager'), DecorateInfo(unittest.skip('Fails on cuda'), 'TestCommon', 'test_complex_half_reference_testing', active_if=not TEST_WITH_ROCM), DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_grad'), @@ -16263,7 +16265,7 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k aliases=('group_norm',), ref=reference_group_norm, dtypes=floating_types_and(torch.float16, torch.bfloat16), - dtypesIfMPS=floating_types_and(torch.float16, torch.bfloat16, torch.int32, torch.int16), + dtypesIfMPS=floating_types_and(torch.float16, torch.bfloat16, torch.int32, torch.int16, torch.int8, torch.uint8), supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -16278,6 +16280,9 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "test_comprehensive", device_type="cpu" ), + # MPS supports int8/uint8 but CPU does not, so consistency test cannot run + DecorateInfo(unittest.expectedFailure, 'TestConsistency', 'test_output_match', + device_type='mps', dtypes=(torch.int8, torch.uint8)), ], sample_inputs_func=sample_inputs_group_norm, reference_inputs_func=reference_inputs_group_norm, @@ -16632,7 +16637,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k supports_out=False, skips=( DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'), - DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), # Error: The operator 'aten::_upsample_bilinear2d_aa_backward.grad_input' @@ -17059,9 +17063,8 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k supports_expanded_weight=True, skips=( # RuntimeError: MPS device does not support linear for non-float inputs - # RuntimeError: mps linear does not support complex types DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64, torch.int64)), + DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.int64,)), ), decorators=( # Strides are not the same! @@ -17618,6 +17621,9 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'TestUnaryUfuncs', "test_reference_numerics_extremal", dtypes=(torch.complex64, torch.complex128), device_type='cpu', active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), ), # tan(j * pi/2 * odd_number) is nan which also make tanhshrink nan. reference_numerics_filter=NumericsFilter( @@ -17653,11 +17659,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, - skips=( - # RuntimeError: norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - ), ), OpInfo( "nn.functional.triplet_margin_with_distance_loss", @@ -17680,14 +17681,12 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "TestNormalizeOperators", "test_normalize_operator_exhaustive", ), - # RuntimeError: norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), ), ), BinaryUfuncInfo('nextafter', dtypes=floating_types_and(torch.bfloat16, torch.half), - supports_autograd=False, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, supports_rhs_python_scalar=False), OpInfo( "to", @@ -18010,30 +18009,15 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k sample_inputs_func=sample_inputs_mode,), make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_1', domain=(1, None), - skips=skips_mvlgamma() + ( - # NotImplementedError: The operator 'aten::mvlgamma.out' is not currently - # implemented for the MPS device - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='mps'), - ), + skips=skips_mvlgamma(), sample_kwargs=lambda device, dtype, input: ({'p': 1}, {'d': 1})), make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_3', domain=(2, None), - skips=skips_mvlgamma() + ( - # NotImplementedError: The operator 'aten::mvlgamma.out' is not currently - # implemented for the MPS device - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='mps'), - ), + skips=skips_mvlgamma(), sample_kwargs=lambda device, dtype, input: ({'p': 3}, {'d': 3})), make_mvlgamma_opinfo(variant_test_name='mvlgamma_p_5', domain=(3, None), - skips=skips_mvlgamma() + ( - # NotImplementedError: The operator 'aten::mvlgamma.out' is not currently - # implemented for the MPS device - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='mps'), - ), + skips=skips_mvlgamma(), sample_kwargs=lambda device, dtype, input: ({'p': 5}, {'d': 5})), BinaryUfuncInfo('ne', ref=np.not_equal, @@ -18120,18 +18104,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # Could not allocate memory to change Tensor SizesAndStrides! check_batched_forward_grad=False, supports_fwgrad_bwgrad=True, - skips=( - # norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', - device_type='mps', dtypes=(torch.complex64,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager', - device_type='mps', dtypes=(torch.complex64,) - ), - ), sample_inputs_func=sample_inputs_dist), OpInfo('outer', op=torch.outer, @@ -18452,6 +18424,10 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='xpu', dtypes=(torch.cfloat, torch.cdouble), active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='xpu', dtypes=(torch.chalf,), active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), DecorateInfo(toleranceOverride({torch.float16: tol(atol=1e-3, rtol=2e-3)}), @@ -18484,6 +18460,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='xpu', dtypes=(torch.chalf,), active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), @@ -18822,6 +18803,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='xpu', dtypes=(torch.chalf,), active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), # FIXME: @@ -18871,6 +18857,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', + device_type='xpu', dtypes=(torch.chalf,), active_if=IS_WINDOWS), DecorateInfo(unittest.skip("Skipped! sparse backward not supported"), 'TestSparseUnaryUfuncs', 'test_sparse_fn_grad'), DecorateInfo(toleranceOverride({torch.complex64: tol(atol=3e-5, rtol=7e-6)}), @@ -18955,6 +18946,15 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # RuntimeError: linalg.solve.triangular(); Only float is supported! DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), + # see https://github.com/pytorch/pytorch/issues/177251 + DecorateInfo( + unittest.expectedFailure, + 'TestOperators', + 'test_jvp', + device_type='cpu', + dtypes=[torch.float32], + active_if=IS_ARM64 and IS_CPU_EXT_SVE_SUPPORTED, + ), )), UnaryUfuncInfo('trunc', aliases=('fix', ), @@ -19109,6 +19109,8 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # tensor(inf+nanj, device='cuda:0') DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble]), DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_meta_inplace', dtypes=[torch.bool]), DecorateInfo(unittest.expectedFailure, 'TestMeta', 'test_dispatch_meta_inplace', @@ -19239,22 +19241,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'TestSchemaCheckModeOpInfo', 'test_schema_correctness', dtypes=(torch.complex64, torch.complex128)), - DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out', - device_type='mps', dtypes=[torch.float32]), - DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager', - device_type='mps', dtypes=[torch.float32]), - DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', - device_type='mps', dtypes=[torch.float32]), - # The operator 'aten::take' is not currently implemented for the MPS device - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='mps'), - # RuntimeError: svd_backward: The singular vectors in the complex - # case are specified up to multiplication by e^{i phi}. The - # specified loss function depends on this phase term, making it - # ill-defined. - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', - device_type='mps', dtypes=(torch.complex64,) - ), )), OpInfo('svd_lowrank', op=lambda *args, **kwargs: wrapper_set_seed( @@ -19585,12 +19571,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k supports_forward_ad=True, supports_fwgrad_bwgrad=True, error_inputs_func=error_inputs_gather, - skips=( - # RuntimeError: gather(): Yet not supported for complex - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='mps'), - ), ), OpInfo('index_fill', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.complex32), @@ -19795,11 +19775,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # Compiler issue on ROCm. Regression started in ROCm 6.4. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', dtypes=[torch.bool], active_if=TEST_WITH_ROCM), - # RuntimeError: scatter(): Yet not supported for complex - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='mps'), )), UnaryUfuncInfo( 'bfloat16', @@ -20493,8 +20468,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k error_inputs_func=error_inputs_multinomial, skips=( DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='mps'), # Strides are not the same! # This may not be reproducible in CI DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_out'), @@ -20615,11 +20588,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # Compiler issue on ROCm. Regression started in ROCm 6.4. DecorateInfo(unittest.skip('Skipped!'), 'TestCommon', 'test_non_standard_bool_values', dtypes=[torch.bool], active_if=TEST_WITH_ROCM), - # RuntimeError: scatter(): Yet not supported for complex - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_errors', device_type='mps'), )), OpInfo('stack', dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), @@ -20686,6 +20654,9 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k OpInfo('histc', dtypes=floating_types_and(torch.bfloat16, torch.float16), dtypesIfCUDA=floating_types_and(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64), + dtypesIfMPS=floating_types_and( + torch.bfloat16, torch.float16, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64 + ), sample_inputs_func=sample_inputs_histc, supports_out=True, supports_autograd=False, @@ -20751,8 +20722,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k check_batched_forward_grad=False, assert_autodiffed=True, skips=( - # https://github.com/pytorch/pytorch/issues/89353 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'), # RuntimeError: Arguments for call not valid. # Expected a value of type 'List[Tensor]' for argument # 'tensors' but instead found type 'Tensor (inferred)'. @@ -20762,8 +20731,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # see https://github.com/pytorch/pytorch/issues/99806 # RuntimeError: The size of tensor a (25) must match the size of tensor b (0) at non-singleton dimension 0. DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_fn_gradgrad'), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', device_type='mps', dtypes=(torch.int64,)), )), OpInfo('unbind', dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), @@ -20869,7 +20836,7 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k device_type='cpu', dtypes=(torch.float16,)), DecorateInfo(toleranceOverride({torch.float16: tol(atol=3e-4, rtol=3e-6)}), "TestConsistency", "test_output_match", device_type="mps"), - # RuntimeError: norm ops are not supported for complex yet + # RuntimeError: Failed to create function state object for: renorm_float2 DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), )), @@ -20885,9 +20852,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k sample_inputs_func=sample_repeat_tile, skips=( DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"), - # Exception: repeat(): Not supported for complex yet! - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), )), OpInfo('squeeze', ref=_squeeze_ref, @@ -21001,9 +20965,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k decorators=( # RuntimeError: view size is not compatible with input tensor's size and stride DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), - # MPS: gather(): Yet not supported for complex - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), )), ShapeFuncInfo('tile', ref=np.tile, @@ -21013,11 +20974,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k supports_out=False, supports_forward_ad=True, supports_fwgrad_bwgrad=True, - skips=( - # RuntimeError: repeat(): Not supported for complex yet! - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - ), sample_inputs_func=sample_repeat_tile), OpInfo('trapz', # TODO: in the future, 'trapz' should be made a proper alias of 'trapezoid' dtypes=all_types_and_complex_and(torch.float16, torch.bfloat16), @@ -21059,16 +21015,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k toleranceOverride({torch.float16: tol(atol=4e-3, rtol=4e-3)}), 'TestInductorOpInfo', 'test_comprehensive', ), - # RuntimeError: cumulative ops are not yet supported for complex - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager', - device_type='mps', dtypes=(torch.complex64,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', - device_type='mps', dtypes=(torch.complex64,) - ), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), ), sample_inputs_func=sample_cumulative_trapezoid,), OpInfo('unsqueeze', @@ -21369,8 +21315,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "test_comprehensive", device_type="cuda" ), - # RuntimeError: Failed to create function state object for: logcumsumexp_outer_float2 - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), # AssertionError: UserWarning not triggered : Resized a non-empty tensor but did not warn about it. DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='mps'), ), @@ -21502,9 +21446,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "test_variant_consistency_jit", dtypes=(torch.float32,), ), - # Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), ), ), UnaryUfuncInfo('lgamma', @@ -21631,7 +21572,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k sample_inputs_func=sample_inputs_nonzero_static, supports_out=False, supports_autograd=False, - decorators=[onlyCPU], skips=( DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'), @@ -21822,15 +21762,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k supports_forward_ad=True, supports_fwgrad_bwgrad=True, skips=( - # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an unsafe cast - # from a result of dtype torch.float32 into an out= with dtype torch.long - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_out', - device_type='mps', dtypes=(torch.float32,) - ), - # Error: norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.cfloat, torch.chalf)), # Dispatches in Python to vector_norm. Not sure how to make this test happy # Happens to pass on complex64. Also a mystery DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', @@ -21876,12 +21807,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'test_output_match', ), - # RuntimeError: norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - # AssertionError: RuntimeError not raised : Expected RuntimeError when doing an - # unsafe cast from a result of dtype torch.float32 into an out= with dtype torch.long - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='mps', dtypes=(torch.float32,)), # Issue with conj and torch dispatch, see https://github.com/pytorch/pytorch/issues/82479 DecorateInfo( unittest.skip("Skipped!"), @@ -21904,18 +21829,10 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # fast gradcheck produces NaNs gradcheck_fast_mode=False, skips=( - # RuntimeError: norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - # AssertionError: RuntimeError not raised : Expected RuntimeError - # when doing an unsafe cast from a result of dtype torch.float32 - # into an out= with dtype torch.long - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='mps', dtypes=(torch.float32,)), DecorateInfo( toleranceOverride({torch.float16: tol(atol=2e-3, rtol=1e-3)}), 'TestInductorOpInfo', 'test_comprehensive', device_type='cuda', ), - # Error: norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.cfloat, torch.chalf)), # Dispatches in Python to vector_norm. Not sure how to make this test happy # Happens to pass on complex64. Also a mystery DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', @@ -22252,13 +22169,7 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k supports_out=False, sample_inputs_func=sample_inputs_grid_sample, reference_inputs_func=reference_inputs_grid_sample, - supports_gradgrad=False, - skips=( - # Exception: The operator 'aten::grid_sampler_2d_backward' is not currently implemented for the MPS device - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='mps'), - ), + supports_gradgrad=True, gradcheck_nondet_tol=1e-15), # TODO: delete this OpInfo once we add meta support for grid_sampler_3d OpInfo( @@ -22266,15 +22177,14 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k dtypes=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, sample_inputs_func=sample_inputs_grid_sampler_2d, - supports_gradgrad=False, + supports_gradgrad=True, gradcheck_nondet_tol=1e-15, skips=( DecorateInfo(slowTest, 'TestDecomp', 'test_comprehensive', dtypes=(torch.float32, torch.float64), active_if=IS_WINDOWS), - # Exception: The operator 'aten::grid_sampler_2d_backward' is not currently implemented for the MPS device - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='mps'), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=2e-5, rtol=2e-6), + torch.float16: tol(atol=5e-3, rtol=4e-3)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), ),), # TODO: Remove grid_sampler_3d tests once `nn.functional.grid_sample` has # MPS support for all cases. @@ -22283,7 +22193,7 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k dtypes=floating_types_and(torch.float16, torch.bfloat16), supports_out=False, sample_inputs_func=sample_inputs_grid_sampler_3d, - supports_gradgrad=False, + supports_gradgrad=True, gradcheck_nondet_tol=1e-15, skips=( # NOTE: Only run on MPS @@ -22291,14 +22201,13 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.skip('Skipped!'), device_type='cuda'), DecorateInfo(unittest.skip('Skipped!'), device_type='xpu'), DecorateInfo(unittest.skip('Skipped!'), device_type='meta'), - # Error: The operator 'aten::grid_sampler_3d_backward' is not currently implemented for the MPS device. - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='mps'), - DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4), - torch.float16: tol(atol=1e-4, rtol=1e-4), - torch.bfloat16: tol(atol=1e-4, rtol=1e-4)}), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), "TestConsistency", "test_output_match", device_type="mps"), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-3, rtol=2e-6), + torch.float16: tol(atol=5e-3, rtol=2e-2)}), + "TestConsistency", "test_output_grad_match", device_type="mps"), + DecorateInfo(toleranceOverride({torch.float32: tol(atol=5e-5, rtol=5e-4)}), + "TestCommon", "test_noncontiguous_samples", device_type="mps"), ),), OpInfo( "argwhere", @@ -22705,8 +22614,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k unittest.skip('Skipped!'), 'TestReductions', 'test_ref_small_input', device_type='xpu', dtypes=[torch.complex128, torch.int8, torch.int16, torch.int32, torch.int64]), - # RuntimeError: cumulative ops are not yet supported for complex - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), ), ), ReductionOpInfo( @@ -22801,6 +22708,8 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'), # Error: The operator 'aten::hash_tensor.out' is not currently implemented for the MPS device DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps'), + # NotImplementedError: aten::hash_tensor.out + DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_default', device_type='mps'), ) ), OpInfo( @@ -23016,11 +22925,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "test_variant_consistency_jit", dtypes=(torch.float32, torch.complex64), ), - # Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - # RuntimeError: norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref', device_type='mps', dtypes=(torch.int64,)), ), ), OpInfo( @@ -23345,16 +23249,17 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "_refs.lerp", torch_opinfo_name="lerp", skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype + # Exception: Dtypes torch.float32 and * are not equal! DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.bfloat16, torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8) + dtypes=(torch.bool, torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8) ), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type='mps', dtypes=(torch.bool,)), DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.bfloat16, torch.bool, torch.float16, torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8) + dtypes=(torch.bool, torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8) ), + # RuntimeError: Failed to create function state object for: abs_dense_bool_bool + DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type='mps', dtypes=(torch.bool,)), ), ), PythonRefInfo( @@ -23586,15 +23491,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'), DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_conj_view'), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.bfloat16, torch.float16,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.bfloat16, torch.float16,) - ), # RuntimeError: value cannot be converted to type uint8_t without overflow DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', @@ -23678,12 +23574,13 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_executor', dtypes=(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64), device_type="cuda"), - # Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning', device_type='mps'), + # RuntimeError: no _refs support for aten.copy.default DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps'), + # AssertionError: Tensor-likes are not equal! + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', + dtypes=(torch.bool, torch.int16, torch.int32, torch.int64, torch.int8, torch.uint8) + ), ), ), PythonRefInfo( @@ -23769,9 +23666,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), - # MPS: gather(): Yet not supported for complex - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), ), ), PythonRefInfo( @@ -23876,6 +23770,10 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'test_reference_numerics_large', device_type='cuda', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), ), ), ElementwiseUnaryPythonRefInfo( @@ -23900,6 +23798,10 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'test_reference_numerics_large', device_type='cuda', dtypes=[torch.cfloat], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), ), ), ElementwiseUnaryPythonRefInfo( @@ -23952,6 +23854,10 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'test_reference_numerics_large', device_type='cuda', dtypes=(torch.chalf,), active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='xpu', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), ), ), ElementwiseUnaryPythonRefInfo( @@ -23986,6 +23892,10 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'test_reference_numerics_large', device_type='cuda', dtypes=(torch.chalf,), active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='xpu', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), ), ), ElementwiseUnaryPythonRefInfo( @@ -24030,6 +23940,10 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='xpu', dtypes=[torch.chalf], + active_if=IS_WINDOWS), ), ), ElementwiseUnaryPythonRefInfo( @@ -24047,6 +23961,10 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='xpu', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), ), ), ElementwiseUnaryPythonRefInfo( @@ -24158,33 +24076,13 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "_refs.special.multigammaln", torch_opinfo_name="mvlgamma", torch_opinfo_variant_name="mvlgamma_p_3", - skips=skips_mvlgamma() + ( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.bfloat16, torch.float16) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.bfloat16, torch.float16) - ), - ), + skips=skips_mvlgamma(), ), ElementwiseUnaryPythonRefInfo( "_refs.special.multigammaln", torch_opinfo_name="mvlgamma", torch_opinfo_variant_name="mvlgamma_p_5", - skips=skips_mvlgamma() + ( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.bfloat16, torch.float16) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.bfloat16, torch.float16) - ), - ), + skips=skips_mvlgamma(), ), ElementwiseUnaryPythonRefInfo( "_refs.log", @@ -24233,14 +24131,15 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_name="log_softmax", torch_opinfo_variant_name="with_dtype", skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 + # AssertionError: Tensor-likes are not close! + # RuntimeError: softmax only supported for floating types DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64, torch.complex32) + dtypes=(torch.float32, torch.complex64, torch.complex32) ), DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64, torch.complex32) + dtypes=(torch.float32, torch.complex64, torch.complex32) ), ), ), @@ -24375,6 +24274,14 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'test_reference_numerics_large', dtypes=(torch.cfloat, torch.cdouble,), device_type='cpu', active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='xpu', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), ), ), ElementwiseUnaryPythonRefInfo( @@ -24387,15 +24294,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_small', dtypes=[torch.cfloat]), - # TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype. - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.complex64,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.complex64,) - ), ), ), ElementwiseUnaryPythonRefInfo( @@ -24418,6 +24316,14 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.int8]), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='xpu', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), ), ), PythonRefInfo( @@ -24425,15 +24331,15 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_name="softmax", torch_opinfo_variant_name="with_dtype", skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - # TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype. + # AssertionError: Tensor-likes are not close! + # RuntimeError: softmax only supported for floating types DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64), + dtypes=(torch.float32, torch.complex64), ), DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64), + dtypes=(torch.float32, torch.complex64), ), ), ), @@ -24491,6 +24397,14 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='xpu', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), ) ), ElementwiseUnaryPythonRefInfo( @@ -24510,6 +24424,14 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'test_reference_numerics_large', device_type='cpu', dtypes=[torch.cfloat, torch.cdouble], active_if=(IS_MACOS or IS_WINDOWS)), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_large', + device_type='xpu', + dtypes=(torch.chalf,), active_if=IS_WINDOWS), ), ), ElementwiseUnaryPythonRefInfo( @@ -24524,10 +24446,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_variant_name="with_dtype", supports_out=False, skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 + # AssertionError: Tensor-likes are not close! + # NotImplementedError: log_softmax for complex is not supported for MPS DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64, torch.complex32) + dtypes=(torch.float32, torch.complex64, torch.complex32) ), ), ), @@ -24537,11 +24460,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_variant_name="with_dtype", supports_out=False, skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - # TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype. + # Exception: softmax only supported for floating types + # AssertionError: Tensor-likes are not close! DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64), + dtypes=(torch.float32, torch.complex64), ), ), ), @@ -24551,17 +24474,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k ElementwiseUnaryPythonRefInfo( "_refs.special.logit", torch_opinfo_name="logit", - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.complex64, torch.float16) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.complex64, torch.float16) - ), - ), ), # # Elementwise Unary nn.functional OpInfos @@ -24700,12 +24612,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "_refs.nn.functional.pairwise_distance", torch_opinfo_name="nn.functional.pairwise_distance", supports_out=True, - skips=( - # Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - # RuntimeError: norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - ), ), PythonRefInfo( "_refs.nn.functional.pdist", @@ -24738,10 +24644,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_variant_name="with_dtype", supports_out=False, skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 + # NotImplementedError: log_softmax for complex is not supported for MPS + # AssertionError: Tensor-likes are not close! DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64, torch.complex32) + dtypes=(torch.float32, torch.complex64, torch.complex32) ), ), ), @@ -24756,17 +24663,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k PythonRefInfo( "_refs.nn.functional.poisson_nll_loss", torch_opinfo_name="nn.functional.poisson_nll_loss", - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.float16, torch.bfloat16) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.float16, torch.bfloat16) - ), - ), ), ElementwiseUnaryPythonRefInfo( "_refs.nn.functional.prelu", @@ -24841,11 +24737,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_variant_name="with_dtype", supports_out=False, skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - # TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype. + # RuntimeError: softmax only supported for floating types + # AssertionError: Tensor-likes are not close! DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.float32, torch.float16, torch.complex64), + dtypes=(torch.float32, torch.complex64), ), ), ), @@ -24855,11 +24751,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k torch_opinfo_variant_name="with_dtype", supports_out=False, skips=( - # TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype. - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 + # RuntimeError: softmax only supported for floating types + # AssertionError: Tensor-likes are not close! DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.complex64, torch.float16, torch.float32), + dtypes=(torch.complex64, torch.float32), ), ), ), @@ -24887,11 +24783,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k PythonRefInfo( "_refs.nn.functional.l1_loss", torch_opinfo_name="nn.functional.l1_loss", - skips=( - # Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - ), ), PythonRefInfo( "_refs.nn.functional.margin_ranking_loss", @@ -24908,13 +24799,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k PythonRefInfo( "_refs.nn.functional.hinge_embedding_loss", torch_opinfo_name="nn.functional.hinge_embedding_loss", - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.bfloat16, torch.float16,), - ), - ), ), PythonRefInfo( "_refs.nn.functional.nll_loss", @@ -24995,15 +24879,10 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k dtypes=(torch.complex64, torch.complex128), device_type='cpu', active_if=(IS_MACOS or IS_WINDOWS)), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.float16, torch.bfloat16) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.float16, torch.bfloat16) - ), + DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', + 'test_reference_numerics_extremal', + device_type='xpu', dtypes=[torch.cfloat, torch.cdouble], + active_if=IS_WINDOWS), ), ), ElementwiseUnaryPythonRefInfo( @@ -25044,15 +24923,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'TestBinaryUfuncs', 'test_reference_numerics_extremal_values', dtypes=(torch.complex64, torch.complex128)), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.bfloat16,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.float16,) - ), ), ), ElementwiseBinaryPythonRefInfo( @@ -25180,15 +25050,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k decorators=( # See https://github.com/pytorch/pytorch/issues/111126 DecorateInfo(unittest.expectedFailure, 'TestBinaryUfuncs', 'test_type_promotion'), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.bfloat16, torch.float16) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.bfloat16, torch.float16) - ), ), ), ElementwiseBinaryPythonRefInfo( @@ -25394,17 +25255,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k ElementwiseUnaryPythonRefInfo( "_refs.logical_not", torch_opinfo_name="logical_not", - skips=( - # RuntimeError: Undefined type ComplexDouble - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.complex64,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.complex64,) - ), - ), ), ElementwiseBinaryPythonRefInfo( "_refs.logical_or", @@ -25457,15 +25307,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', dtypes=(torch.complex32,), device_type='cuda' ), - # TypeError: Trying to convert ComplexDouble to the MPS backend but it does not have support for that dtype - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.complex32,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.complex32,) - ), ) ), ElementwiseBinaryPythonRefInfo( @@ -25599,16 +25440,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k 'TestBinaryUfuncs', 'test_reference_numerics_small_values', dtypes=(torch.uint8,)), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - # NotImplementedError: "_local_scalar_dense_mps" not implemented for 'ComplexHalf' - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.float16,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.float16, torch.bfloat16) - ), ), ), ElementwiseBinaryPythonRefInfo( @@ -25644,17 +25475,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k PythonRefInfo( "_refs.addcdiv", torch_opinfo_name="addcdiv", - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.bfloat16, torch.float16,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.bfloat16, torch.float16,) - ), - ), ), PythonRefInfo( "_refs.addcmul", @@ -25668,18 +25488,17 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k dtypes=(torch.float16,), device_type="cpu"), DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref_torch_fallback', dtypes=(torch.float16,), device_type="cpu"), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 # AssertionError: Tensor-likes are not close! DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', dtypes=( torch.uint8, torch.int8, torch.int64, torch.int32, - torch.int16, torch.float16, torch.bfloat16, + torch.int16, ) ), DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', dtypes=( torch.uint8, torch.int8, torch.int64, torch.int32, - torch.int16, torch.float16, torch.bfloat16, + torch.int16, ) ), ), @@ -25690,11 +25509,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k skips=( # test error disabled since rhs non-tensor python scalar is supported DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.float32, torch.float16,) - ), ), ), ElementwiseBinaryPythonRefInfo( @@ -25703,11 +25517,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k skips=( # test error disabled since rhs non-tensor python scalar is supported DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.float32, torch.float16,) - ), ), ), PythonRefInfo( @@ -25725,9 +25534,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # Greatest relative difference: 8.519846983548175e-06 at index (4,) (up to 1.3e-06 allowed) DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_python_ref', dtypes=(torch.uint8,), device_type="cpu"), - # RuntimeError: norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), ) ), ElementwiseBinaryPythonRefInfo( @@ -26013,10 +25819,15 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k skips=( # FIXME: AssertionError: RuntimeError not raised DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps'), + # RuntimeError: Failed to create function state object for: cat_int32_t_* + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', + dtypes=(torch.complex64, torch.complex32) + ), + DecorateInfo( + unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', + dtypes=(torch.complex64, torch.complex32) + ), ), ), PythonRefInfo( @@ -26147,11 +25958,12 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k skips=( # RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', dtypes=(torch.int32, torch.int16) + unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', + dtypes=(torch.int32, torch.int16, torch.int8, torch.uint8) ), DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta', device_type='mps', - dtypes=(torch.int32, torch.int16) + dtypes=(torch.int32, torch.int16, torch.int8, torch.uint8) ), ) ), @@ -26199,23 +26011,7 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k "_refs.renorm", torch_opinfo_name="renorm", skips=( - # RuntimeError: norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo(unittest.expectedFailure, 'TestCommon', device_type='mps', dtypes=(torch.complex64,)), - # Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', dtypes=(torch.float16,)), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', - device_type='mps', dtypes=(torch.float16,) - ), - ), - ), - PythonRefInfo( - "_refs.repeat", - torch_opinfo_name="repeat", - validate_view_consistency=False, - skips=( - # Exception: repeat(): Not supported for complex yet! + # RuntimeError: Failed to create function state object for: renorm_float2 DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', dtypes=(torch.complex64,) @@ -26226,6 +26022,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k ), ), ), + PythonRefInfo( + "_refs.repeat", + torch_opinfo_name="repeat", + validate_view_consistency=False, + ), PythonRefInfo( "_refs.reshape", torch_opinfo_name="reshape", @@ -26577,15 +26378,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k skips=( # doesn't test out behavior properly for this operator DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), - # RuntimeError: cumulative ops are not yet supported for complex - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.complex64,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.complex64,) - ), ), ), PythonRefInfo( @@ -26595,15 +26387,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k skips=( # doesn't test out behavior properly for this operator DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'), - # RuntimeError: cumulative ops are not yet supported for complex - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', device_type='mps', - dtypes=(torch.complex64,) - ), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', - dtypes=(torch.complex64,) - ), ), ), PythonRefInfo( @@ -26696,14 +26479,11 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k decorators=( DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref',), # RuntimeError: MPS device does not support addr for non-float input - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS - # framework doesn't support float64. DecorateInfo( unittest.expectedFailure, 'TestCommon', 'test_python_ref_torch_fallback', device_type='mps', dtypes=( torch.uint8, torch.int8, torch.int64, torch.int32, - torch.int16, torch.float16, torch.complex64, torch.bool, - torch.bfloat16 + torch.int16, torch.complex64, torch.bool, ) ), ), @@ -26719,19 +26499,6 @@ def sample_inputs_abs(op_info, device, dtype, requires_grad, op_kwargs=None, **k # Uses vector_norm inside and vector_norm is affected by # https://github.com/pytorch/pytorch/issues/77216 validate_view_consistency=False, - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, 'TestCommon', 'test_python_ref', - device_type='mps', dtypes=(torch.float16,) - ), - # RuntimeError: norm ops are not supported for complex yet - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_dtypes', device_type='mps'), - DecorateInfo( - unittest.expectedFailure, 'TestCommon', device_type='mps', - dtypes=(torch.complex32, torch.complex64) - ), - ), ), # # Tensor Creation Reference OpInfos diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 10807aa111e33..3e5d4af0d2a58 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -25,7 +25,7 @@ nllloss_reference, nlllossNd_reference, smoothl1loss_reference, softmarginloss_reference, get_reduction) from torch.testing._internal.common_utils import ( freeze_rng_state, skipIfMPS, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS, - skipIfTorchDynamo) + skipIfTorchDynamo, skipIfXpu) from types import ModuleType import operator @@ -913,7 +913,10 @@ def module_inputs_torch_nn_BatchNorm1d(module_info, device, dtype, requires_grad desc='3d_input_not_affine'), ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False), forward_input=FunctionInput(make_input((0, 5, 9))), - desc='zero_batch')] + desc='zero_batch'), + ModuleInput(constructor_input=FunctionInput(10, bias=False), + forward_input=FunctionInput(make_input((4, 10))), + desc='affine_not_bias'),] def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad, training, **kwargs): @@ -936,7 +939,10 @@ def module_inputs_torch_nn_BatchNorm2d(module_info, device, dtype, requires_grad desc='not_tracking_stats'), ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False), forward_input=FunctionInput(make_input((0, 5, 2, 2))), - desc='zero_batch')] + desc='zero_batch'), + ModuleInput(constructor_input=FunctionInput(3, bias=False), + forward_input=FunctionInput(make_input((2, 3, 6, 6))), + desc='affine_not_bias'),] def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad, training, **kwargs): @@ -959,7 +965,10 @@ def module_inputs_torch_nn_BatchNorm3d(module_info, device, dtype, requires_grad desc='not_tracking_stats'), ModuleInput(constructor_input=FunctionInput(5, 1e-3, 0.3, False), forward_input=FunctionInput(make_input((0, 5, 2, 2, 2))), - desc='zero_batch')] + desc='zero_batch'), + ModuleInput(constructor_input=FunctionInput(3, bias=False), + forward_input=FunctionInput(make_input((2, 3, 4, 4, 4))), + desc='affine_not_bias'),] def module_error_inputs_torch_nn_BatchNorm1d_2d_3d(module_info, device, dtype, requires_grad, training, **kwargs): @@ -1837,6 +1846,10 @@ def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, constructor_input=FunctionInput(3, 6, 1e-3), forward_input=FunctionInput(make_input((4, 6, 5))), desc='1d_affine'), + ModuleInput( + constructor_input=FunctionInput(3, 6, 1e-3, bias=False), + forward_input=FunctionInput(make_input((4, 6, 5))), + desc='1d_affine_not_bias'), ModuleInput( constructor_input=FunctionInput(3, 12, 1e-3), forward_input=FunctionInput(make_input((4, 12))), @@ -1857,6 +1870,10 @@ def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, constructor_input=FunctionInput(3, 6, 1e-3), forward_input=FunctionInput(make_input((4, 6, 2, 3))), desc='2d_affine'), + ModuleInput( + constructor_input=FunctionInput(3, 9, 1e-3, bias=False), + forward_input=FunctionInput(make_input((4, 9, 2, 3))), + desc='2d_affine_not_bias'), ModuleInput( constructor_input=FunctionInput(3, 3, 1e-3, False), forward_input=FunctionInput(make_input((4, 3, 2, 3))), @@ -1864,8 +1881,7 @@ def module_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, ModuleInput( constructor_input=FunctionInput(1, 3, 1e-3, False), forward_input=FunctionInput(make_input((4, 3, 2, 3))), - desc='2d_no_affine_LN'), - ] + desc='2d_no_affine_LN'),] def module_error_inputs_torch_nn_GroupNorm(module_info, device, dtype, requires_grad, training, **kwargs): @@ -2054,8 +2070,21 @@ def module_inputs_torch_nn_InstanceNormNd(module_info, device, dtype, requires_g ), forward_input=FunctionInput(make_input(input_no_batch_shape)), reference_fn=no_batch_dim_reference_fn, - desc='no_batch_dim') - ] + desc='no_batch_dim'), + ModuleInput( + constructor_input=( + FunctionInput(eps, momentum, affine=True) if lazy else + FunctionInput(num_features, eps, momentum, affine=True) + ), + forward_input=FunctionInput(make_input(input_batch_shape)), + desc='affine'), + ModuleInput( + constructor_input=( + FunctionInput(eps, momentum, affine=True, bias=False) if lazy else + FunctionInput(num_features, eps, momentum, affine=True, bias=False) + ), + forward_input=FunctionInput(make_input(input_batch_shape)), + desc='affine_not_bias'),] def module_inputs_torch_nn_LayerNorm(module_info, device, dtype, requires_grad, training, **kwargs): make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -2256,6 +2285,47 @@ def module_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, desc='return_indices'), ] + +def module_error_inputs_torch_nn_MaxPool2d(module_info, device, dtype, requires_grad, training, **kwargs): + """ + Error inputs for MaxPool2d that test error messages for invalid inputs. + """ + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + return [ + # Wrong input dimensions: 2D input instead of 3D/4D + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(2), + forward_input=FunctionInput(make_input((3, 4))), # 2D input + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=RuntimeError, + error_regex=r"non-empty 3D or 4D \(batch mode\) tensor expected for input" + ), + # Wrong input dimensions: 5D input + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(2), + forward_input=FunctionInput(make_input((1, 2, 3, 4, 5))), # 5D input + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=RuntimeError, + error_regex=r"non-empty 3D or 4D \(batch mode\) tensor expected for input" + ), + # Invalid padding: padding > kernel_size / 2 + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(3, padding=5), # kernel=3, pad=5 > 3/2 + forward_input=FunctionInput(make_input((1, 1, 10, 10))), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=RuntimeError, + error_regex=r"pad should be at most half of effective kernel size" + ), + ] + + def module_inputs_torch_nn_MaxPool3d(module_info, device, dtype, requires_grad, training, **kwargs): make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -2815,6 +2885,57 @@ def module_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, ] +def module_error_inputs_torch_nn_Embedding(module_info, device, dtype, requires_grad, training, **kwargs): + """ + Error inputs for Embedding that test error messages for invalid inputs. + """ + samples = [] + + # Out of range indices: index exceeds num_embeddings + # Only test on CPU - CUDA triggers kernel assertion instead of Python exception + if torch.device(device).type == 'cpu': + samples.append( + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(num_embeddings=10, embedding_dim=3), + forward_input=FunctionInput(torch.tensor([0, 5, 15], device=device, dtype=torch.long)), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=IndexError, + error_regex=r"index out of range in self" + ) + ) + + # Float indices: wrong dtype for indices (works on all devices) + samples.append( + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(num_embeddings=10, embedding_dim=3), + forward_input=FunctionInput(torch.tensor([1.5, 2.5], device=device, dtype=torch.float32)), + ), + error_on=ModuleErrorEnum.FORWARD_ERROR, + error_type=RuntimeError, + error_regex=r"Expected tensor for argument.*indices.*to have.*scalar type.*Long.*Int" + ) + ) + + # Negative num_embeddings (construction error, device-independent) + samples.append( + ErrorModuleInput( + ModuleInput( + constructor_input=FunctionInput(num_embeddings=-1, embedding_dim=3), + forward_input=FunctionInput(), + ), + error_on=ModuleErrorEnum.CONSTRUCTION_ERROR, + error_type=RuntimeError, + error_regex=r"Trying to create tensor with negative dimension" + ) + ) + + return samples + + + def module_inputs_torch_nn_MultiheadAttention(module_info, device, dtype, requires_grad, training, **kwargs): # Currently all samples below are for validating the no-batch-dim support. make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -3856,6 +3977,8 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad # Not implemented for chalf on CPU DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity', dtypes=(torch.chalf,), device_type='cuda'), + DecorateInfo(skipIfXpu, 'TestModule', 'test_cpu_gpu_parity', + dtypes=(torch.chalf,), device_type='xpu'), ), decorators=( DecorateInfo(precisionOverride({torch.float32: 1e-04}), 'TestModule', 'test_memory_format'), @@ -3875,7 +3998,7 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad # Not implemented for chalf on CPU DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity', dtypes=(torch.chalf,), device_type='cuda'), - DecorateInfo(unittest.expectedFailure, 'TestModule', 'test_cpu_gpu_parity', + DecorateInfo(skipIfXpu, 'TestModule', 'test_cpu_gpu_parity', dtypes=(torch.chalf,), device_type='xpu'), ), decorators=( @@ -4074,6 +4197,7 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.MaxPool2d, module_inputs_func=module_inputs_torch_nn_MaxPool2d, + module_error_inputs_func=module_error_inputs_torch_nn_MaxPool2d, ), ModuleInfo(torch.nn.MaxPool3d, module_inputs_func=module_inputs_torch_nn_MaxPool3d, @@ -4360,6 +4484,7 @@ def module_error_inputs_torch_nn_Pad3d(module_info, device, dtype, requires_grad ), ModuleInfo(torch.nn.Embedding, module_inputs_func=module_inputs_torch_nn_Embedding, + module_error_inputs_func=module_error_inputs_torch_nn_Embedding, gradcheck_nondet_tol=GRADCHECK_NONDET_TOL, decorators=[ DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}), diff --git a/torch/testing/_internal/common_mps.py b/torch/testing/_internal/common_mps.py index ee8ef700f2cbf..d9ba9e955c91e 100644 --- a/torch/testing/_internal/common_mps.py +++ b/torch/testing/_internal/common_mps.py @@ -55,12 +55,16 @@ def mps_ops_modifier( "cos", "cosh", "cross", + "cumsum", + "cumprod", + "cumulative_trapezoid", "diag", "diag_embed", "diagflat", "diagonal", "diagonal_copy", "diagonal_scatter", + "dist", "divno_rounding_mode", "dsplit", "empty", @@ -72,6 +76,7 @@ def mps_ops_modifier( "expand", "expand_as", "expand_copy", + "gather", "flatten", "fill", "full", @@ -94,13 +99,16 @@ def mps_ops_modifier( "linalg.diagonal", "linalg.householder_product", "linalg.svd", + "linalg.vander", "linalg.vecdot", + "linalg.vector_norm", "log10", "log1p", "log2", "log", "logaddexp", "logaddexp2", + "logcumsumexp", "mH", "mT", "masked_fill", @@ -122,12 +130,22 @@ def mps_ops_modifier( "nn.functional.conv_transpose2d", "nn.functional.conv_transpose3d", "nn.functional.feature_alpha_dropoutwithout_train", + "nn.functional.l1_loss", + "nn.functional.linear", + "nn.functional.normalize", "nn.functional.padcircular", + "nn.functional.pairwise_distance", "nn.functional.softminwith_dtype", "nn.functional.softsign", "nn.functional.tanhshrink", + "nn.functional.triplet_margin_loss", + "nn.functional.triplet_margin_with_distance_loss", "nn.functional.unfold", "nonzero", + "nonzero_static", + "norm", + "normfro", + "norminf", "ones", "ones_like", "outer", @@ -137,6 +155,7 @@ def mps_ops_modifier( "randn", "ravel", "real", + "repeat", "repeat_interleave", "reshape_as", "reshape", @@ -145,6 +164,8 @@ def mps_ops_modifier( "rsqrt", "rsub", "scalar_tensor", + "scatter", + "scatter_add", "select", "sgn", "sigmoid", @@ -169,9 +190,11 @@ def mps_ops_modifier( "svd", "t", "t_copy", + "take_along_dim", "tanh", "tan", "tensor_split", + "tile", "transpose", "transpose_copy", "tril", @@ -265,7 +288,10 @@ def mps_ops_modifier( "logical_xor", "logsumexp", "long", + "masked.cumsum", + "masked.cumprod", "masked.mean", + "masked.normalize", "masked.prod", "masked.std", "masked.sum", @@ -318,7 +344,6 @@ def mps_ops_modifier( "put": None, "frexp": None, "geqrf": None, - "nn.functional.grid_sample": None, # Unsupported Border padding mode "hash_tensor": None, "heaviside": None, # "kthvalue": None, @@ -607,7 +632,6 @@ def mps_ops_modifier( "float_power": None, "linalg.matrix_rankhermitian": None, "linalg.pinvhermitian": None, - "nonzero_static": None, # MPS: input sizes must be divisible by output sizes "nn.functional.adaptive_avg_pool1d": None, "nn.functional.adaptive_avg_pool2d": None, @@ -806,6 +830,10 @@ def mps_ops_modifier( # Unsupported # This doesn't work on M1, but is partially working on M2 with the exception of torch.float16 "nn.functional.conv3d": None, + # MPS uses float32 intermediates (opmath_t) while CPU uses native + # half/bfloat16 precision, causing unbounded divergence. + # Half precision is covered by test_grid_sampler_3d_half_precision. + "nn.functional.grid_sample": [torch.float16, torch.bfloat16], } def addDecorator(op: OpInfo, d: DecorateInfo) -> None: @@ -901,11 +929,8 @@ def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: "scalar_tensor": [torch.float16, torch.float32], "cdist": None, "masked.scatter": [torch.float16, torch.float32], - "grid_sampler_2d": None, - "grid_sampler_3d": None, "igamma": None, # currently not supported for any device "igammac": None, # currently not supported for any device - "aminmax": [torch.float32, torch.float16], "special.i1": [torch.float16], # "i1_backward" not implemented for 'Half' "special.i1e": [torch.float16], # "i1e_backward" not implemented for 'Half' # Correctness issues @@ -922,8 +947,6 @@ def mps_ops_grad_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: # CPU errors # derivative for zeta is not implemented "special.zeta": None, - # derivative for aten::nextafter is not implemented on CPU - "nextafter": None, # derivative for aten::floor_divide is not implemented on CPU "floor_divide": [torch.float16, torch.float32], # derivative for aten::_histogramdd_from_bin_cts is not implemented on CPU @@ -1008,11 +1031,6 @@ def mps_ops_error_inputs_modifier(ops: Sequence[OpInfo]) -> Sequence[OpInfo]: "clamp_max", "clamp_min", "masked_scatter", - # unsupported float64 dtype - "multinomial", - "gather", - "scatter", - "scatter_add", # MPS does not support tensor dimensions > 16 "amax", "amin", diff --git a/torch/testing/_internal/common_ops_unbacked.py b/torch/testing/_internal/common_ops_unbacked.py index 56a47a263cfc4..766126de04375 100644 --- a/torch/testing/_internal/common_ops_unbacked.py +++ b/torch/testing/_internal/common_ops_unbacked.py @@ -124,7 +124,6 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None): xfail("masked.var"), xfail("max_pool2d_with_indices_backward"), xfail("multinomial"), - xfail("nanquantile"), xfail("nn.functional.adaptive_avg_pool1d"), xfail("nn.functional.adaptive_avg_pool2d"), xfail("nn.functional.adaptive_avg_pool3d"), @@ -153,7 +152,6 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None): xfail("nn.functional.fractional_max_pool2d"), xfail("nn.functional.fractional_max_pool3d"), xfail("nn.functional.gaussian_nll_loss"), - xfail("nn.functional.glu"), xfail("nn.functional.grid_sample"), xfail("nn.functional.group_norm"), xfail("nn.functional.huber_loss"), @@ -196,7 +194,6 @@ def skip(op_name, variant_name="", *, device_type=None, dtypes=None): xfail("ormqr"), xfail("pca_lowrank"), xfail("pinverse"), - xfail("quantile"), xfail("qr"), xfail("rand_like"), xfail("randint_like"), diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index 527bc4e6a7a5b..fb67c4d845021 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -506,6 +506,28 @@ def optim_error_inputs_func_adagrad(device, dtype): error_regex="Invalid lr_decay value: -0.5", ), ] + if _get_device_type(device) == "cuda": + sample_tensor = torch.empty((), device=device, dtype=dtype) + error_inputs += [ + ErrorOptimizerInput( + OptimizerInput( + params=[sample_tensor], + kwargs={"foreach": True, "fused": True}, + desc="`fused` and `foreach` cannot be `True` together", + ), + error_type=RuntimeError, + error_regex="`fused` and `foreach` cannot be `True` together", + ), + ErrorOptimizerInput( + OptimizerInput( + params=[sample_tensor], + kwargs={"fused": True, "differentiable": True}, + desc="`fused` does not support `differentiable`", + ), + error_type=RuntimeError, + error_regex="`fused` does not support `differentiable`", + ), + ] return error_inputs @@ -1535,6 +1557,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "CompiledOptimizerParityTests", "test_correctness", device_type="xpu", + active_if=lambda kwargs: kwargs.get("use_closure", False), ), DecorateInfo( skipIfTorchDynamo("See #133268 regarding dtype being None"), @@ -1598,7 +1621,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "maximize", "capturable", ), - supports_fused_on=("cpu",), + supports_fused_on=("cpu", "cuda"), supports_sparse=True, metadata_for_sparse=( {"lr": 0.1, "weight_decay": 0, "lr_decay": 0}, @@ -1623,6 +1646,16 @@ def _get_optim_inputs_including_global_cliquey_kwargs( "TestOptimRenewed", "test_fused_matches_forloop", ), + DecorateInfo( + toleranceOverride( + { # https://github.com/pytorch/pytorch/issues/116202 + torch.float32: tol(atol=5e-04, rtol=0.015), + } + ), + "TestOptimRenewed", + "test_mixed_device_dtype", + active_if=TEST_WITH_TORCHDYNAMO, + ), ), skips=( DecorateInfo( diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 8b8507433c7d8..ee70df65617d5 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -82,17 +82,19 @@ ) from torch.testing import make_tensor from torch.testing._comparison import ( + _unwrap_dtensor_for_comparison, BooleanPair, NonePair, + not_close_error_metas, NumberPair, Pair, TensorLikePair, ) -from torch.testing._comparison import not_close_error_metas from torch.testing._internal.common_dtype import get_all_dtypes from torch.utils._import_utils import _check_module_exists import torch.utils._pytree as pytree from torch.utils import cpp_extension +from torch._utils import _is_privateuse1_backend_available try: import pytest # type: ignore[import-not-found] has_pytest = True @@ -1065,7 +1067,7 @@ def wait_for_process(p, timeout=None): else: p.kill() raise - except: # noqa: B001,E722, copied from python core library + except: p.kill() raise finally: @@ -1452,12 +1454,6 @@ def TemporaryDirectoryName(suffix=None): yield d -def is_privateuse1_backend_available(): - privateuse1_backend_name = torch._C._get_privateuse1_backend_name() - privateuse1_backend_module = getattr(torch, privateuse1_backend_name, None) - return (is_available := getattr(privateuse1_backend_module, "is_available", None)) and is_available() - - def make_lazy_class(cls): def lazy_init(self, cb): @@ -1511,7 +1507,7 @@ class LazyVal: TEST_ACCELERATOR = LazyVal(lambda: torch.accelerator.is_available()) # type: ignore[call-arg] TEST_MULTIACCELERATOR = LazyVal(lambda: torch.accelerator.device_count() > 1) # type: ignore[call-arg] custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name(), None) -TEST_PRIVATEUSE1 = is_privateuse1_backend_available() +TEST_PRIVATEUSE1 = _is_privateuse1_backend_available() TEST_PRIVATEUSE1_DEVICE_TYPE = torch._C._get_privateuse1_backend_name() TEST_NUMBA = _check_module_exists('numba') TEST_TRANSFORMERS = _check_module_exists('transformers') @@ -1523,6 +1519,44 @@ class LazyVal: TEST_Z3 = _check_module_exists('z3') +# DSL availability (lazy evaluation to avoid import overhead) +class LazyDSLCheck: + """Lazy DSL availability checker to avoid import-time overhead""" + def __init__(self): + self._registry = None + self._import_attempted = False + + def _get_registry(self): + if not self._import_attempted: + self._import_attempted = True + try: + from torch._native.dsl_registry import dsl_registry + self._registry = dsl_registry + except ImportError: + self._registry = None + return self._registry + + def is_available(self, dsl_name: str) -> bool: + """Check if specific DSL is available""" + registry = self._get_registry() + return registry.is_dsl_available(dsl_name) if registry is not None else False + + def list_available(self) -> list[str]: + """Get list of available DSLs""" + registry = self._get_registry() + return list(registry.list_available_dsls()) if registry is not None else [] + + def list_all(self) -> list[str]: + """Get list of all registered DSLs""" + registry = self._get_registry() + return list(registry.list_all_dsls()) if registry is not None else [] + +_dsl_checker = LazyDSLCheck() + +# Lazy constants to avoid import-time overhead +TEST_TRITON_DSL = LazyVal(lambda: _dsl_checker.is_available('triton')) +TEST_CUTEDSL = LazyVal(lambda: _dsl_checker.is_available('cutedsl')) + def split_if_not_empty(x: str): return x.split(",") if len(x) != 0 else [] @@ -1530,6 +1564,34 @@ def split_if_not_empty(x: str): skipIfNoDill = unittest.skipIf(not TEST_DILL, "no dill") +# DSL skip decorators (following existing pattern) +skipIfNoTritonDSL = unittest.skipIf(not TEST_TRITON_DSL, "Triton DSL not available") +skipIfNoCuteDSL = unittest.skipIf(not TEST_CUTEDSL, "CuTeDSL not available") + +def skipIfDSLUnavailable(dsl_name: str, reason: str | None = None): + """Skip test if specific DSL is not available""" + available = _dsl_checker.is_available(dsl_name) + msg = reason or f"{dsl_name} DSL not available" + return unittest.skipIf(not available, msg) + +def skipUnlessDSLAvailable(dsl_name: str, reason: str | None = None): + """Skip test unless specific DSL is available""" + available = _dsl_checker.is_available(dsl_name) + msg = reason or f"{dsl_name} DSL required" + return unittest.skipUnless(available, msg) + +def get_available_dsls() -> list[str]: + """Get list of available DSL names for test parameterization""" + return _dsl_checker.list_available() + +def is_dsl_available(dsl_name: str) -> bool: + """Check if specific DSL is available for conditional testing""" + return _dsl_checker.is_available(dsl_name) + +def get_all_dsls() -> list[str]: + """Get all registered DSL names (available or not) for comprehensive testing""" + return _dsl_checker.list_all() + NO_MULTIPROCESSING_SPAWN: bool = False TEST_WITH_ASAN: bool = TestEnvironment.def_flag( @@ -1551,6 +1613,7 @@ def split_if_not_empty(x: str): TEST_WITH_ROCM: bool = TestEnvironment.def_flag( "TEST_WITH_ROCM", env_var="PYTORCH_TEST_WITH_ROCM", + implied_by_fn=lambda: torch.version.hip is not None, ) TEST_WITH_MTIA: bool = TestEnvironment.def_flag( "TEST_WITH_MTIA", @@ -2002,7 +2065,7 @@ def has_corresponding_torch_dtype(np_dtype): def skipIfNNModuleInlined( msg="test doesn't currently work with nn module inlining", - condition=torch._dynamo.config.inline_inbuilt_nn_modules, + condition=True, ): def decorator(fn): if not isinstance(fn, type): @@ -2108,12 +2171,31 @@ def wrapper(*args, **kwargs): return dec_fn def skipIfMPS(fn): + sig = inspect.signature(fn) + has_device_arg = "device" in sig.parameters + + if not has_device_arg: + warnings.warn( + f"skipIfMPS applied to {fn.__qualname__} which has no 'device' parameter. " + "Consider using device-generic tests with instantiate_device_type_tests instead.", + stacklevel=2, + ) + @wraps(fn) def wrapper(*args, **kwargs): - if TEST_MPS: + if has_device_arg: + # For device-generic tests, only skip when actually running on MPS + slf = args[0] if args else None + if slf is not None: + device_type = getattr(slf, "device_type", None) or getattr( + slf, "device", None + ) + if isinstance(device_type, str) and device_type == "mps": + raise unittest.SkipTest("test doesn't currently work with MPS") + elif TEST_MPS: raise unittest.SkipTest("test doesn't currently work with MPS") - else: - fn(*args, **kwargs) + return fn(*args, **kwargs) + return wrapper @@ -2160,7 +2242,7 @@ def dec_fn(fn): @wraps(fn) def wrapper(*args, **kwargs): - if IS_WINDOWS: # noqa: F821 + if IS_WINDOWS: raise unittest.SkipTest(reason) else: return fn(*args, **kwargs) @@ -2175,7 +2257,7 @@ def dec_fn(fn): @wraps(fn) def wrapper(*args, **kwargs): - if IS_WINDOWS and torch.xpu.is_available(): # noqa: F821 + if IS_WINDOWS and torch.xpu.is_available(): raise unittest.SkipTest(reason) else: return fn(*args, **kwargs) @@ -2227,6 +2309,10 @@ def _fn(*args, **kwargs): fn(*args, **kwargs) finally: torch.backends.cuda.preferred_blas_library(_preferred_backend) + if torch.backends.cuda.is_built(): + torch._C._cuda_resetCublasWorkspaceSize() + torch._C._cuda_resetCublasLtWorkspaceSize() + torch._C._cuda_clearCublasWorkspaces() return _fn @@ -3463,7 +3549,7 @@ def expect_failure(f, file_name): def wrapper(*args, **kwargs): try: f(*args, **kwargs) - except BaseException as e: # noqa: B036 + except BaseException as e: self.skipTest(e) raise RuntimeError(f"Unexpected success, please remove `{file_name}`") return wrapper @@ -3485,7 +3571,7 @@ def ignore_failure(f, file_name): def wrapper(*args, **kwargs): try: f(*args, **kwargs) - except BaseException as e: # noqa: B036 + except BaseException as e: self.skipTest(e) method = getattr(self, self._testMethodName) if getattr(method, "__unittest_expecting_failure__", False): @@ -4315,6 +4401,8 @@ def to_list(input): if isinstance(y, torch.Tensor) and y.is_nested and y.layout == torch.strided: y = y.unbind() + x, y = _unwrap_dtensor_for_comparison(x, y) + error_metas = not_close_error_metas( x, y, @@ -5877,6 +5965,9 @@ def repl_frame(m): if suppress_prefix: s = re.sub(r"Cannot export model.+\n\n", "", s) s = re.sub(r" +$", "", s, flags=re.MULTILINE) + # Normalize caret-only lines by stripping leading whitespace, since + # col_offset in bytecode positions can vary across Python point releases + s = re.sub(r"^[ ]+(\^+)$", r"\1", s, flags=re.MULTILINE) return s @@ -6056,3 +6147,48 @@ def get_gcc_major_version(): return int(out.split(".")[0]) except Exception: return None + + +def run_concurrently(worker_func, num_threads=None, args=(), kwargs=None): + # Adapted from CPython test suite. Runs worker_func in multiple threads + # concurrently to help expose thread-safety issues. Works best in + # combination with ThreadSanitizer (TSan). + from collections.abc import Iterable + + if kwargs is None: + kwargs = {} + if num_threads is None: + num_threads = len(worker_func) + if not isinstance(worker_func, Iterable): + worker_func = [worker_func] * num_threads + + barrier = threading.Barrier(num_threads) + + results = [None] * num_threads + exc_value = None + + def wrapper_func(idx, func, *args, **kwargs): + # Wait for all threads to reach this point before proceeding. + try: + barrier.wait() + res = func(*args, **kwargs) + results[idx] = res + except Exception as e: + nonlocal exc_value + exc_value = e + + workers = [ + threading.Thread(target=wrapper_func, args=(i, func, *args), + kwargs=kwargs, daemon=True) + for i, func in enumerate(worker_func) + ] + for w in workers: + w.start() + for w in workers: + w.join() + + # If a worker thread raises an exception, re-raise it. + if exc_value is not None: + raise exc_value + + return results diff --git a/torch/testing/_internal/common_xpu.py b/torch/testing/_internal/common_xpu.py new file mode 100644 index 0000000000000..ee4f8c921eac5 --- /dev/null +++ b/torch/testing/_internal/common_xpu.py @@ -0,0 +1,88 @@ +import enum +import functools + +import torch +import torch.xpu +from torch.testing._internal.common_utils import IS_WINDOWS, LazyVal, TEST_XPU + + +XPU_ALREADY_INITIALIZED_ON_IMPORT = torch.xpu.is_initialized() + + +class XPUCodename(enum.Enum): + PVC = "PVC" # Intel® Data Center GPU Max Series + BMG = "BMG" # Intel® Arc™ Pro Battlemage Graphics + + +class XPUArch(enum.IntEnum): + Unknown = 0 + Xe = 1 # Xe HPC + Xe2 = 2 + + +# device_id -> GPU codename +# From https://github.com/intel/intel-graphics-compiler/blob/master/inc/common/igfxfmid.h +_DEVICE_ID_TO_CODENAME = { + 0x0BD0: XPUCodename.PVC, + 0x0BD4: XPUCodename.PVC, + 0x0BD5: XPUCodename.PVC, + 0x0BD6: XPUCodename.PVC, + 0x0BD7: XPUCodename.PVC, + 0x0BD8: XPUCodename.PVC, + 0x0BD9: XPUCodename.PVC, + 0x0BDA: XPUCodename.PVC, + 0x0BDB: XPUCodename.PVC, + 0x0B69: XPUCodename.PVC, + 0x0B6E: XPUCodename.PVC, + 0xE202: XPUCodename.BMG, + 0xE20B: XPUCodename.BMG, + 0xE20C: XPUCodename.BMG, + 0xE20D: XPUCodename.BMG, + 0xE210: XPUCodename.BMG, + 0xE212: XPUCodename.BMG, + 0xE215: XPUCodename.BMG, + 0xE216: XPUCodename.BMG, + 0xE220: XPUCodename.BMG, + 0xE221: XPUCodename.BMG, + 0xE222: XPUCodename.BMG, + 0xE223: XPUCodename.BMG, +} + +# GPU codename -> architecture +_CODENAME_TO_ARCH = { + XPUCodename.PVC: XPUArch.Xe, + XPUCodename.BMG: XPUArch.Xe2, +} + + +@functools.lru_cache(1) +def get_xpu_codename() -> XPUCodename | None: + device_id = torch.xpu.get_device_capability()["device_id"] + return _DEVICE_ID_TO_CODENAME.get(device_id) + + +@functools.lru_cache(1) +def get_xpu_arch() -> XPUArch | None: + codename = get_xpu_codename() + return _CODENAME_TO_ARCH.get(codename, XPUArch.Unknown) + + +Xe2_Or_Later = LazyVal( + lambda: torch.xpu.is_available() and get_xpu_arch() >= XPUArch.Xe2 +) + + +def evaluate_platform_supports_flash_attention(): + if TEST_XPU: + return not IS_WINDOWS and Xe2_Or_Later + return False + + +PLATFORM_SUPPORTS_FLASH_ATTENTION_XPU: bool = LazyVal( + lambda: evaluate_platform_supports_flash_attention() +) + +# Importing this module should NOT eagerly initialize XPU +if not XPU_ALREADY_INITIALIZED_ON_IMPORT: + if torch.xpu.is_initialized(): + raise AssertionError("XPU should not be initialized on import") diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index dde1b4d7bc491..68c6baa863e98 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -1461,7 +1461,49 @@ def _to_local_shard(a): dt = DTensor.from_local(local, device_mesh, (plc,)) full = dt.redistribute(device_mesh, (Replicate(),)).to_local() if ref.shape != full.shape or not torch.allclose( - ref, full, atol=1e-5, rtol=1e-5 + ref, full, atol=1e-5, rtol=1e-5, equal_nan=True ): return False return True + + +@contextlib.contextmanager +def op_strategy_context(op_overload, strategy_func, schema_info=None): + """ + Context manager for setting and clearing op strategies. + Args: + op_overload: The operator overload to set or clear the strategy for. + strategy_func: The strategy function to set for the operator overload. + schema_info: Optional schema information for the operator overload. + Yields: + None + """ + from torch.distributed.tensor._ops.utils import register_op_strategy + from torch.distributed.tensor.debug import _clear_sharding_prop_cache + + propagator = DTensor._op_dispatcher.sharding_propagator + _origin_op_strategy_funcs = None + _origin_op_strategy_schema = None + try: + # register the op strategy + if op_overload in propagator.op_strategy_funcs: + _origin_op_strategy_funcs = propagator.op_strategy_funcs[op_overload] + del propagator.op_strategy_funcs[op_overload] + if op_overload in propagator.op_to_schema_info: + _origin_op_strategy_schema = propagator.op_to_schema_info[op_overload] + del propagator.op_to_schema_info[op_overload] + register_op_strategy(op_overload, schema_info=schema_info)(strategy_func) + yield + finally: + # clear this op strategy cache + if _origin_op_strategy_funcs is None: + if op_overload in propagator.op_strategy_funcs: + del propagator.op_strategy_funcs[op_overload] + else: + propagator.op_strategy_funcs[op_overload] = _origin_op_strategy_funcs + if _origin_op_strategy_schema is None: + if op_overload in propagator.op_to_schema_info: + del propagator.op_to_schema_info[op_overload] + else: + propagator.op_to_schema_info[op_overload] = _origin_op_strategy_schema + _clear_sharding_prop_cache() diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index f182a724773ab..18d3e8a11b385 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -198,9 +198,6 @@ class TestNamedTupleInput_1(NamedTuple): BACKEND = os.environ["BACKEND"] INIT_METHOD = os.getenv("INIT_METHOD", "env://") -DEFAULT_TIMEOUT = 300 -CUSTOMIZED_TIMEOUT = {"test_DistributedDataParallel": 500} - def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False): event_list = ( @@ -402,14 +399,6 @@ def forward(self, x): return F.relu(self.lin1(x)) -def get_timeout(test_id): - test_name = test_id.split(".")[-1] - if test_name in CUSTOMIZED_TIMEOUT: - return CUSTOMIZED_TIMEOUT[test_name] - else: - return DEFAULT_TIMEOUT - - default_pg_timeout = 60 CUSTOM_PG_TIMEOUT = { @@ -905,6 +894,24 @@ def test_barrier_timeout_full_group(self): if group_id is not None: self._test_barrier_timeout(group_id, timeout) + @skip_but_pass_in_sandcastle_if( + BACKEND != "gloo", "Only gloo backend supports timeouts" + ) + def test_barrier_timeout_arg(self): + """Test that the timeout argument to barrier() overrides PG default. + + Create a PG with a large default timeout, then have only rank 0 + call barrier with a tiny timeout. The barrier should time out using + the per-call timeout (1ms) rather than the PG default (300s). + """ + pg = dist.new_group(timeout=timedelta(seconds=300)) + + if dist.get_rank() == 0: + with self.assertRaisesRegex(RuntimeError, "Timed out waiting 1ms"): + dist.barrier(group=pg, timeout=timedelta(seconds=0.001)) + + dist.destroy_process_group(pg) + @skip_but_pass_in_sandcastle_if( BACKEND not in DistTestCases.backend_feature["subgroup"], f"The {BACKEND} backend does not support creating subgroups on CUDA devices", @@ -5004,7 +5011,7 @@ def __init__(self) -> None: self.register_buffer("buffer", torch.randn(1, 2)) self.p = torch.nn.Parameter(torch.randn(10, 5), requires_grad=False) - def forward(self_, x): # noqa: B902 + def forward(self_, x): params = self_.m.parameters() for p in params: self.assertEqual(mp_config.param_dtype, p.dtype) @@ -7975,11 +7982,11 @@ def dict_validator(x): } class ToyModel(torch.nn.Module): - def __init__(self_): # noqa: B902 + def __init__(self_): super().__init__() self_.lin = nn.Linear(10, 10, bias=False) - def forward(self_, x, expected_type): # noqa: B902 + def forward(self_, x, expected_type): # Similar to scatter, the recursive to in the single-device # case does not move tensors if they are in a custom type. self.assertTrue(isinstance(x, expected_type)) @@ -8038,11 +8045,11 @@ def test_ddp_namedtuple(self): b = torch.rand(batch, dim, device=self.rank) class NamedTupleModule(torch.nn.Module): - def __init__(self_): # noqa: B902 + def __init__(self_): super().__init__() self_.lin = nn.Linear(10, 1) - def forward(self_, input, expected_type): # noqa: B902 + def forward(self_, input, expected_type): # Without NamedTuple support, this would be of type tuple. self.assertTrue( isinstance(input, expected_type), diff --git a/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py index fde1fe2355c29..86f7bbbe21097 100644 --- a/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py +++ b/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py @@ -19,7 +19,7 @@ def local_add(t1, t2): @torch.jit.script -def remote_add(t1, t2, dst: str): # noqa: E999 +def remote_add(t1, t2, dst: str): return rpc_async(dst, local_add, (t1, t2)).wait() diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py index 82a5d66e87f38..fc82f987127b4 100644 --- a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py @@ -508,7 +508,7 @@ def two_args_two_kwargs( @torch.jit.script def assorted_types_args_kwargs( - tensor_arg: Tensor, # noqa: E999 + tensor_arg: Tensor, str_arg: str, int_arg: int, tensor_kwarg: Tensor = torch.tensor([2, 2]), @@ -684,7 +684,7 @@ def test_less_than_needed_args_are_specified(self): @torch.jit.script def script_rpc_async_call_with_less_args( - dst_worker_name: str, # noqa: E999 + dst_worker_name: str, ): args = (torch.tensor([1, 1]),) kwargs = {} @@ -729,7 +729,7 @@ def test_unexepected_kwarg_is_specified(self): # Notice, kwargs matching happens during execution. @torch.jit.script def script_rpc_async_call_with_unexpected_kwarg( - dst_worker_name: str, # noqa: E999 + dst_worker_name: str, ): args = (torch.tensor([1, 1]), torch.tensor([2, 2])) kwargs = {"third_kwarg": torch.tensor([1, 1])} diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index e3b2452d5eea6..2084e55babf6f 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -3549,7 +3549,7 @@ def test_custom_exception_throw_during_reconstruction(self): print(f"Got msg {msg}") self.assertTrue("Original exception on remote side was" in msg) self.assertTrue("CustomException" in msg) - except BaseException as e: # noqa: B036 + except BaseException as e: raise RuntimeError(f"Failure - expected RuntimeError, got {e}") from e finally: self.assertTrue(exc_caught) diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 526f340e7222a..80a249b8d909f 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -5,8 +5,16 @@ import torch from functorch.experimental.control_flow import map -from torch.nn.attention.flex_attention import _create_empty_block_mask, flex_attention +from torch._higher_order_ops.flex_attention import ( + flex_attention as flex_attention_hop, +) +from torch.nn.attention.flex_attention import ( + _create_empty_block_mask, + create_block_mask, + flex_attention, +) from torch.testing import make_tensor +from torch._higher_order_ops.inline_asm_elementwise import inline_asm_elementwise from torch.testing._internal.common_device_type import onlyCUDA from torch.testing._internal.common_dtype import all_types_and, custom_types from torch.testing._internal.opinfo.core import DecorateInfo, OpInfo, SampleInput @@ -195,6 +203,99 @@ def score_mod(score, b, h, m, n): yield SampleInput(q, k, v, score_mod, block_mask) +def sample_inputs_flex_attention_backward( + opinfo, device, dtype, requires_grad, **kwargs +): + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=False + ) + + def score_mod(score, b, h, m, n): + return score + + def mask_mod(b, h, m, n): + return m >= n + + q, k, v = (make_arg(2, 2, 128, 16, low=0.1, high=2) for _ in range(3)) + block_mask = create_block_mask(mask_mod, B=2, H=2, Q_LEN=128, KV_LEN=128, device=device) + scale = 1.0 / q.size(-1) ** 0.5 + out, logsumexp, _ = flex_attention_hop( + q, k, v, score_mod, block_mask.as_tuple(), scale, {}, + ) + yield SampleInput( + q, + args=( + k, v, out.detach(), logsumexp.detach(), torch.rand_like(out), None, + score_mod, None, block_mask.as_tuple(), + scale, {}, (), (), + ), + ) + + +def sample_inputs_flex_attention_backward_explicit_buffers( + opinfo, device, dtype, requires_grad, **kwargs +): + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=False + ) + mask_offset = torch.full((), 128, device=device, dtype=torch.int32) + + def score_mod(score, b, h, m, n): + return score + + def mask_mod(b, h, m, n): + return m + mask_offset >= n + + q, k, v = (make_arg(2, 2, 128, 16, low=0.1, high=2) for _ in range(3)) + block_mask = create_block_mask(mask_mod, B=2, H=2, Q_LEN=128, KV_LEN=128, device=device) + scale = 1.0 / q.size(-1) ** 0.5 + out, logsumexp, _ = flex_attention_hop( + q, k, v, score_mod, block_mask.as_tuple(), scale, {}, + ) + yield SampleInput( + q, + args=( + k, v, out.detach(), logsumexp.detach(), torch.rand_like(out), None, + score_mod, None, block_mask.as_tuple(), + scale, {}, (), (), + ), + ) + + +def simple_flex_attention_backward( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, +): + return torch.ops.higher_order.flex_attention_backward( + query, + key, + value, + out, + logsumexp, + grad_out, + grad_logsumexp, + fw_graph, + joint_graph, + block_mask, + scale, + kernel_options, + score_mod_other_buffers, + mask_mod_other_buffers, + ) + + def sample_inputs_while_loop(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial( make_tensor, device=device, dtype=dtype, requires_grad=False @@ -292,6 +393,27 @@ def fn(x): return invoke_quant_packed(fn, x)[0] * 2.0 +def sample_inputs_inline_asm(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2)) + + +def simple_inline_asm(x): + if torch.version.hip: + return inline_asm_elementwise( + x, + asm_str="v_mov_b32_e32 $0, $1", + constraints="=v, v", + dtype=torch.float32, + ) + + return inline_asm_elementwise( + x, asm_str="mov.f32 $0, $1;", constraints="=f,f", dtype=torch.float32 + ) + + hop_db = [ OpInfo( name="scan", @@ -476,14 +598,39 @@ def fn(x): OpInfo( name="flex_attention_backward", variant_test_name="simple", - op=flex_attention, - sample_inputs_func=sample_inputs_flex_attention, + op=simple_flex_attention_backward, + sample_inputs_func=sample_inputs_flex_attention_backward, dtypes=custom_types(torch.float16, torch.float32), supports_out=False, check_batched_grad=False, check_batched_gradgrad=False, check_batched_forward_grad=False, check_inplace_batched_forward_grad=False, + supports_autograd=False, + supports_gradgrad=False, + skips=( + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"), + DecorateInfo( + unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export" + ), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"), + DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"), + ), + decorators=[onlyCUDA], + ), + OpInfo( + name="flex_attention_backward", + variant_test_name="explicit_buffers", + op=simple_flex_attention_backward, + sample_inputs_func=sample_inputs_flex_attention_backward_explicit_buffers, + dtypes=custom_types(torch.float16, torch.float32), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=False, + supports_gradgrad=False, skips=( DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"), DecorateInfo( @@ -520,4 +667,18 @@ def fn(x): ), ], ), + OpInfo( + name="inline_asm_elementwise", + variant_test_name="simple", + op=simple_inline_asm, + sample_inputs_func=sample_inputs_inline_asm, + dtypes=custom_types(torch.float32), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=False, + decorators=[onlyCUDA], + ), ] diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index 83f721b8fad9d..3cc246537b2e7 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -10,7 +10,7 @@ from subprocess import CalledProcessError import torch -import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +import torch._inductor.async_compile import torch._inductor.config as config from torch._inductor.codecache import CppCodeCache from torch._inductor.codegen.common import ( @@ -181,14 +181,19 @@ def skip_windows_ci(name: str, file: str) -> None: requires_helion = functools.partial(unittest.skipIf, not HAS_HELION, "requires helion") -def requires_cuda_with_enough_memory(min_mem_required): +def requires_gpu_with_enough_memory(min_mem_required): def inner(fn): + total_memory = sys.maxsize + if torch.xpu.is_available(): + total_memory = torch.xpu.get_device_properties().total_memory + elif torch.cuda.is_available(): + total_memory = torch.cuda.get_device_properties().total_memory if ( - not torch.cuda.is_available() - or torch.cuda.get_device_properties().total_memory < min_mem_required + not (torch.cuda.is_available() or torch.xpu.is_available()) + or total_memory < min_mem_required ): return unittest.skip( - f"Only if the CUDA device has at least {min_mem_required / 1e9:.3f}GB memory to be safe" + f"Only if the GPU device has at least {min_mem_required / 1e9:.3f}GB memory to be safe" )(fn) else: return fn @@ -388,7 +393,7 @@ def __init__(self, name_to_buffer=None): self.constants = {} self.scheduler = None - def get_dtype(self, buffer_name: str) -> torch.dtype: # noqa: ARG002 + def get_dtype(self, buffer_name: str) -> torch.dtype: """Return default dtype for any buffer (for testing).""" return torch.float32 diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py index 714387f49d960..670b34c64ecdf 100644 --- a/torch/testing/_internal/opinfo/definitions/_masked.py +++ b/torch/testing/_internal/opinfo/definitions/_masked.py @@ -641,16 +641,10 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar DecorateInfo( unittest.skip("Skipped!"), "TestJit", "test_variant_consistency_jit" ), - # Exception: cumulative ops are not yet supported for complex + # The following dtypes worked in forward but are not listed by the OpInfo: {torch.bool}. DecorateInfo( unittest.expectedFailure, "TestCommon", "test_dtypes", device_type="mps" ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - device_type="mps", - dtypes=(torch.complex64,), - ), ), # Can reuse the same inputs; dim is required in both sample_inputs_func=sample_inputs_masked_cumops, @@ -688,16 +682,10 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar "test_comprehensive", device_type="cuda", ), - # Exception: cumulative ops are not yet supported for complex + # The following dtypes worked in forward but are not listed by the OpInfo: {torch.bool}. DecorateInfo( unittest.expectedFailure, "TestCommon", "test_dtypes", device_type="mps" ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - device_type="mps", - dtypes=(torch.complex64,), - ), ), # Can reuse the same inputs; dim is required in both sample_inputs_func=sample_inputs_masked_cumops, @@ -1315,16 +1303,6 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar DecorateInfo( unittest.expectedFailure, "TestJit", "test_variant_consistency_jit" ), - # Exception: norm ops are not supported for complex yet - DecorateInfo( - unittest.expectedFailure, "TestCommon", "test_dtypes", device_type="mps" - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - device_type="mps", - dtypes=(torch.complex64,), - ), ), gradcheck_wrapper=gradcheck_wrapper_masked_operation, # Runs very slowly on slow gradcheck - alternatively reduce input sizes diff --git a/torch/testing/_internal/opinfo/definitions/linalg.py b/torch/testing/_internal/opinfo/definitions/linalg.py index bd03c509a47fb..b96a66af8db95 100644 --- a/torch/testing/_internal/opinfo/definitions/linalg.py +++ b/torch/testing/_internal/opinfo/definitions/linalg.py @@ -33,6 +33,9 @@ ) from torch.testing._internal.common_utils import ( GRADCHECK_NONDET_TOL, + IS_ARM64, + IS_CPU_CAPABILITY_SVE256, + IS_LINUX, make_fullrank_matrices_with_distinct_singular_values, skipIfSlowGradcheckEnv, slowTest, @@ -1599,6 +1602,15 @@ def make_input(): ), # Exception: The operator 'aten::linalg_lstsq.out' is not currently implemented for the MPS device DecorateInfo(unittest.expectedFailure, "TestCommon", device_type="mps"), + # see https://github.com/pytorch/pytorch/issues/177249 + DecorateInfo( + unittest.expectedFailure, + "TestJit", + "test_variant_consistency_jit", + device_type="cpu", + dtypes=[torch.complex64], + active_if=IS_LINUX and IS_ARM64 and not IS_CPU_CAPABILITY_SVE256, + ), ), ), OpInfo( @@ -1789,14 +1801,6 @@ def make_input(): "test_noncontiguous_samples", device_type="mps", ), - # RuntimeError: norm ops are not supported for complex yet - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_out_requires_grad_error", - device_type="mps", - dtypes=(torch.complex64,), - ), ), ), OpInfo( @@ -1861,14 +1865,6 @@ def make_input(): "test_noncontiguous_samples", device_type="mps", ), - # RuntimeError: norm ops are not supported for complex yet - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_out_requires_grad_error", - device_type="mps", - dtypes=(torch.complex64,), - ), ), ), OpInfo( @@ -1920,16 +1916,6 @@ def make_input(): "test_variant_consistency_eager", device_type="mps", ), - # Exception: norm ops are not supported for complex yet - DecorateInfo( - unittest.expectedFailure, "TestCommon", "test_dtypes", device_type="mps" - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - device_type="mps", - dtypes=(torch.complex64,), - ), ), ), OpInfo( @@ -1983,16 +1969,6 @@ def make_input(): "TestCommon", "test_numpy_ref_mps", ), - # Exception: cumulative ops are not yet supported for complex - DecorateInfo( - unittest.expectedFailure, "TestCommon", "test_dtypes", device_type="mps" - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - device_type="mps", - dtypes=(torch.complex64,), - ), ), ), ReductionOpInfo( @@ -2010,18 +1986,6 @@ def make_input(): dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16), generate_args_kwargs=sample_kwargs_vector_norm, aten_name="linalg_vector_norm", - skips=( - # Exception: norm ops are not supported for complex yet - DecorateInfo( - unittest.expectedFailure, "TestCommon", "test_dtypes", device_type="mps" - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - device_type="mps", - dtypes=(torch.complex64,), - ), - ), ), OpInfo( "linalg.lu_factor", @@ -2622,6 +2586,15 @@ def make_input(): "test_variant_consistency_eager", device_type="mps", ), + # see https://github.com/pytorch/pytorch/issues/177264 + DecorateInfo( + unittest.expectedFailure, + "TestEagerFusionOpInfo", + "test_aot_autograd_symbolic_exhaustive", + device_type="cpu", + dtypes=[torch.float32], + active_if=IS_ARM64 and IS_LINUX, + ), ), ), OpInfo( @@ -2641,27 +2614,6 @@ def make_input(): sample_inputs_func=sample_inputs_svd, decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off], skips=( - DecorateInfo( - unittest.skip("Skipped!"), - "TestCommon", - "test_out", - device_type="mps", - dtypes=[torch.float32], - ), - DecorateInfo( - unittest.skip("Skipped!"), - "TestCommon", - "test_variant_consistency_eager", - device_type="mps", - dtypes=[torch.float32], - ), - DecorateInfo( - unittest.skip("Skipped!"), - "TestJit", - "test_variant_consistency_jit", - device_type="mps", - dtypes=[torch.float32], - ), DecorateInfo( unittest.skip("Skipped!"), "TestFakeTensor", @@ -2678,24 +2630,6 @@ def make_input(): dtypes=[torch.float32], active_if=TEST_WITH_ROCM, ), - # MPS: AssertionError: The values for attribute 'shape' do not match: torch.Size([0, 0]) != torch.Size([0, 1]). - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_out_warning", - device_type="mps", - ), - # MPS: RuntimeError: svd_backward: The singular vectors in the - # complex case are specified up to multiplication by e^{i phi}. The - # specified loss function depends on this phase term, making it - # ill-defined. - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_noncontiguous_samples", - device_type="mps", - dtypes=(torch.complex64,), - ), ), ), OpInfo( @@ -2818,20 +2752,6 @@ def make_input(): DecorateInfo( unittest.expectedFailure, "TestCommon", "test_python_ref_errors" ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.bfloat16, torch.float16), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.bfloat16, torch.float16), - ), ), ), PythonRefInfo( @@ -2844,61 +2764,12 @@ def make_input(): "_refs.linalg.vecdot", torch_opinfo_name="linalg.vecdot", op_db=op_db, - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - ), ), ReductionPythonRefInfo( "_refs.linalg.vector_norm", torch_opinfo_name="linalg.vector_norm", supports_out=True, op_db=op_db, - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float16,), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.float16,), - ), - # Exception: norm ops are not supported for complex yet - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.complex64,), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.complex64,), - ), - ), ), PythonRefInfo( "_refs.linalg.matrix_norm", @@ -2915,29 +2786,14 @@ def make_input(): "TestCommon", "test_python_ref_torch_fallback", device_type="mps", - dtypes=(torch.float32,), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float32,), + dtypes=(torch.float32, torch.complex64), ), - # Exception: norm ops are not supported for complex yet DecorateInfo( unittest.expectedFailure, "TestCommon", "test_python_ref", device_type="mps", - dtypes=(torch.complex64,), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.complex64,), + dtypes=(torch.float32, torch.complex64), ), ), ), diff --git a/torch/testing/_internal/opinfo/definitions/sparse.py b/torch/testing/_internal/opinfo/definitions/sparse.py index 7f8781ce87b24..b08954c823e16 100644 --- a/torch/testing/_internal/opinfo/definitions/sparse.py +++ b/torch/testing/_internal/opinfo/definitions/sparse.py @@ -3,8 +3,9 @@ import os import torch -from torch.testing import make_tensor # noqa: F401 -from torch.testing._internal.opinfo.core import ( # noqa: F401 +from torch.testing import make_tensor +from torch.testing._internal.common_dtype import highest_precision_float +from torch.testing._internal.opinfo.core import ( BinaryUfuncInfo, ErrorInput, generate_elementwise_binary_tensors, @@ -769,8 +770,9 @@ def _sample_inputs_sparse_like_fns( tensor, args=(), kwargs=dict(device=device, dtype=dtype, layout=layout) ) - if dtype is not torch.float64: - yield SampleInput(tensor, args=(), kwargs=dict(dtype=torch.float64)) + hpf = highest_precision_float(device) + if dtype is not hpf: + yield SampleInput(tensor, args=(), kwargs=dict(dtype=hpf)) if torch.cuda.is_available(): other_device = "cuda" if tensor.device.type == "cpu" else "cpu" diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py index d375ba21358ff..1626ead643244 100644 --- a/torch/testing/_internal/opinfo/definitions/special.py +++ b/torch/testing/_internal/opinfo/definitions/special.py @@ -878,23 +878,6 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): }, ), ), - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - ), ), ElementwiseUnaryPythonRefInfo( "_refs.special.bessel_j1", @@ -908,23 +891,6 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): }, ), ), - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - ), ), ElementwiseUnaryPythonRefInfo( "_refs.special.entr", @@ -995,21 +961,6 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): "test_reference_numerics_large", dtypes=(torch.int8,), ), - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float16,), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.float16,), - ), ), ), ElementwiseUnaryPythonRefInfo( @@ -1041,23 +992,6 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): "_refs.special.ndtr", torch_opinfo_name="special.ndtr", op_db=op_db, - skips=( - # TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64 - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - DecorateInfo( - unittest.expectedFailure, - "TestCommon", - "test_python_ref_torch_fallback", - device_type="mps", - dtypes=(torch.float16, torch.bfloat16), - ), - ), ), ElementwiseUnaryPythonRefInfo( "_refs.special.ndtri", diff --git a/torch/testing/_internal/torchbind_impls.py b/torch/testing/_internal/torchbind_impls.py index 02ec2623271c9..9d76e67ee16d7 100644 --- a/torch/testing/_internal/torchbind_impls.py +++ b/torch/testing/_internal/torchbind_impls.py @@ -96,7 +96,6 @@ def meta_takes_foo_tensor_return(foo, x): def register_fake_classes(): - # noqa: F841 @torch._library.register_fake_class("_TorchScriptTesting::_Foo") class FakeFoo: def __init__(self, x: int, y: int): diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index c6b2e33af17d1..9961330e57bb9 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -13,6 +13,9 @@ requires_cuda_and_triton = unittest.skipUnless( HAS_CUDA_AND_TRITON, "requires cuda and triton" ) +requires_xpu_and_triton = unittest.skipUnless( + HAS_XPU_AND_TRITON, "requires xpu and triton" +) requires_gpu_and_triton = unittest.skipUnless( HAS_XPU_AND_TRITON or HAS_CUDA_AND_TRITON, "requires gpu and triton" ) diff --git a/torch/testing/_utils.py b/torch/testing/_utils.py index b86edfdd67f3c..02eeb1f46f5c8 100644 --- a/torch/testing/_utils.py +++ b/torch/testing/_utils.py @@ -33,8 +33,12 @@ def freeze_rng_state(): # which we need to disable to get and set rng state with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch(): rng_state = torch.get_rng_state() - if torch.cuda.is_available(): - cuda_rng_state = torch.cuda.get_rng_state() + if torch.accelerator.is_available(): + accelerator = torch.accelerator.current_accelerator(check_available=True) + if accelerator is not None: + accelerator_rng_state = torch.get_device_module( + accelerator.type + ).get_rng_state() try: yield finally: @@ -47,6 +51,12 @@ def freeze_rng_state(): # # NB: Mode disable is to avoid running cross-ref tests on this seeding with torch.utils._mode_utils.no_dispatch(), torch._C._DisableFuncTorch(): - if torch.cuda.is_available(): - torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined] + if torch.accelerator.is_available(): + accelerator = torch.accelerator.current_accelerator( + check_available=True + ) + if accelerator is not None: + torch.get_device_module(accelerator.type).set_rng_state( + accelerator_rng_state # type: ignore[possibly-undefined] + ) torch.set_rng_state(rng_state) diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 2bb21c31c4c86..99e7cf8f9cc77 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -9,11 +9,11 @@ import tokenize import unittest from collections.abc import Callable +from contextvars import ContextVar from dataclasses import dataclass from types import FunctionType, ModuleType from typing import Any, Generic, NoReturn, Optional, TYPE_CHECKING, TypeVar from typing_extensions import deprecated -from unittest import mock from torch._utils_internal import justknobs_check @@ -21,6 +21,10 @@ # Types saved/loaded in configs CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict) +# Immutable scalar types that don't need deepcopy when returned from configs. +# Everything else is defensively copied to prevent accidental mutation. +_IMMUTABLE_CONFIG_TYPES = (int, float, bool, type(None), str, tuple) + # Duplicated, because mypy needs these types statically T = TypeVar("T", bound=int | float | bool | str | list | set | tuple | dict | None) @@ -46,7 +50,7 @@ class _Config(Generic[T]): If multiple env variables are given, the precedence order is from left to right. user_override: If a user sets a value (i.e. foo.bar=True), that - has precedence over everything after this. + has precedence over everything after this. User overrides are thread-local. env_name_default: If set, this environment variable will override everything after this. If multiple env variables are given, the precedence order is from @@ -203,7 +207,7 @@ def visit( annotated_type = type_hints.get(key, None) if isinstance(value, CONFIG_TYPES): config[name] = _ConfigEntry( - _Config(default=value, value_type=annotated_type) + _Config(default=value, value_type=annotated_type), name ) if dest is module: delattr(module, key) @@ -211,7 +215,7 @@ def visit( if annotated_type is not None and value.value_type is None: value.value_type = annotated_type - config[name] = _ConfigEntry(value) + config[name] = _ConfigEntry(value, name) if dest is module: delattr(module, key) @@ -291,7 +295,7 @@ class _ConfigEntry: value_type: type # The value specified by the user when they overrode the configuration # _UNSET_SENTINEL indicates the value is not set. - user_override: Any = _UNSET_SENTINEL + user_override: ContextVar[object] # The justknob to check for this config justknob: str | None = None # environment variables are read at install time @@ -315,7 +319,7 @@ class _ConfigEntry: deprecation_message: str | None = None _deprecation_warned: bool = False - def __init__(self, config: _Config) -> None: + def __init__(self, config: _Config, name: str) -> None: self.default = config.default self.value_type = ( config.value_type if config.value_type is not None else type(self.default) @@ -327,6 +331,7 @@ def __init__(self, config: _Config) -> None: self.deprecation_message = config.deprecation_message self._deprecation_warned = False + self.user_override = ContextVar(name, default=_UNSET_SENTINEL) if config.env_name_default is not None: for val in config.env_name_default: if (env_value := _read_env_variable(val)) is not None: @@ -400,7 +405,7 @@ def __setattr__(self, name: str, value: object) -> None: if config.alias is not None: self._set_alias_val(config, value) else: - config.user_override = value + config.user_override.set(value) self._is_dirty = True config.hide = False @@ -421,8 +426,9 @@ def __getattr__(self, name: str) -> Any: if config.env_value_force is not _UNSET_SENTINEL: return config.env_value_force - if config.user_override is not _UNSET_SENTINEL: - return config.user_override + user_override = config.user_override.get() + if user_override is not _UNSET_SENTINEL: + return user_override if config.env_value_default is not _UNSET_SENTINEL: return config.env_value_default @@ -431,12 +437,11 @@ def __getattr__(self, name: str) -> Any: # JK only supports bools and ints return justknobs_check(name=config.justknob, default=config.default) - # Note that reference types can still be modified, so we - # copy them to user_overrides in case the user overrides - # them - if isinstance(config.default, (list, set, dict)): - config.user_override = copy.deepcopy(config.default) - return config.user_override + # Reference types can still be modified, so copy them to + # user_overrides to prevent accidental mutation of defaults. + if not isinstance(config.default, _IMMUTABLE_CONFIG_TYPES): + config.user_override.set(copy.deepcopy(config.default)) + return config.user_override.get() return config.default except KeyError as e: @@ -447,7 +452,7 @@ def __delattr__(self, name: str) -> None: self._is_dirty = True # must support delete because unittest.mock.patch deletes # then recreate things - self._config[name].user_override = _UNSET_SENTINEL + self._config[name].user_override.set(_UNSET_SENTINEL) self._config[name].hide = True def _get_alias_module_and_name( @@ -498,10 +503,10 @@ def _is_default(self, name: str) -> bool: or config_val.env_value_force == config_val.default ) - unset = config_val.user_override is _UNSET_SENTINEL + unset = config_val.user_override.get() is _UNSET_SENTINEL # Handle reference types specially to avoid spammy warnings - if isinstance(config_val.default, (list, set, dict)): - unset = unset or config_val.user_override == config_val.default + if not isinstance(config_val.default, _IMMUTABLE_CONFIG_TYPES): + unset = unset or config_val.user_override.get() == config_val.default return unset and not_set_env_default and not_set_env_force def _get_dict( @@ -527,7 +532,9 @@ def _get_dict( it skips it. """ config: dict[str, Any] = {} - for key in self._config: + for key, entry in self._config.items(): + if entry.alias is not None: + continue if ignored_keys and key in ignored_keys: continue if ignored_prefixes: @@ -535,14 +542,23 @@ def _get_dict( continue if skip_default and self._is_default(key): continue - if self._config[key].alias is not None: - continue - curr_entry = self._config[key] - has_been_warned = curr_entry._deprecation_warned - curr_entry._deprecation_warned = True - config[key] = copy.deepcopy(getattr(self, key)) - curr_entry._deprecation_warned = has_been_warned + # Read value directly, bypassing __getattr__ overhead + # (deprecation warnings, alias resolution). + user_override = entry.user_override.get() + if entry.env_value_force is not _UNSET_SENTINEL: + val = entry.env_value_force + elif user_override is not _UNSET_SENTINEL: + val = user_override + elif entry.env_value_default is not _UNSET_SENTINEL: + val = entry.env_value_default + elif entry.justknob is not None: + val = justknobs_check(name=entry.justknob, default=entry.default) + else: + val = entry.default + if not isinstance(val, _IMMUTABLE_CONFIG_TYPES): + val = copy.deepcopy(val) + config[key] = val return config @@ -565,7 +581,19 @@ def save_config_portable( if ignore_private_configs: prefixes.append("_") prefixes.extend(getattr(self, "_cache_config_ignore_prefix", [])) - return self._get_dict(ignored_prefixes=prefixes) + config = self._get_dict(ignored_prefixes=prefixes) + for key in getattr(self, "_cache_config_factory_keys", []): + if key in config and config[key] is not None: + instance = config[key]() + if hasattr(instance, "uuid"): + config[key] = instance.uuid() + else: + raise RuntimeError( + f"Config '{key}' is set to {config[key]} which does not " + f"implement uuid(). Implement uuid() for cache key " + f"participation." + ) + return config def codegen_config(self) -> str: """Convert config to Python statements that replicate current config. @@ -688,7 +716,7 @@ def patch( **kwargs: dict[str, Any], ) -> "ContextDecorator": """ - Decorator and/or context manager to make temporary changes to a config. + Decorator and/or context manager to make temporary changes to a config. Note that patched settings are thread-local. As a decorator: @@ -732,28 +760,40 @@ def foo(...): ) if not isinstance(changes, dict): raise AssertionError(f"expected `dict` got {type(changes)}") - prior: dict[str, Any] = {} config = self class ConfigPatch(ContextDecorator): def __init__(self) -> None: self.changes = changes + self._prior: ContextVar[tuple[dict[str, Any], ...]] = ContextVar( + f"{config.__name__}.ConfigPatch[{id(self)}]", + default=(), + ) def __enter__(self) -> None: - if prior: - raise AssertionError( - "prior should be empty when entering ConfigPatch" - ) + prior: dict[str, Any] = {} for key in self.changes: # KeyError on invalid entry prior[key] = config.__getattr__(key) - for k, v in self.changes.items(): - config.__setattr__(k, v) + prior_stack = self._prior.get() + self._prior.set((*prior_stack, prior)) + try: + for k, v in self.changes.items(): + config.__setattr__(k, v) + except Exception: + self._prior.set(prior_stack) + raise def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore[no-untyped-def] + prior_stack = self._prior.get() + if not prior_stack: + raise AssertionError( + "prior should not be empty when exiting ConfigPatch" + ) + prior = prior_stack[-1] + self._prior.set(prior_stack[:-1]) for k, v in prior.items(): config.__setattr__(k, v) - prior.clear() return ConfigPatch() @@ -778,13 +818,13 @@ def _make_closure_patcher(self, **changes: dict[str, Any]) -> Any: config = self._config def change() -> Callable[[], None]: - prior = {k: config[k].user_override for k in changes} + prior = {k: config[k].user_override.get() for k in changes} for k, v in changes.items(): - self._config[k].user_override = v + self._config[k].user_override.set(v) def revert() -> None: for k, v in prior.items(): - self._config[k].user_override = v + self._config[k].user_override.set(v) return revert @@ -853,15 +893,6 @@ def __delattr__(self, name: str) -> None: return self._config.__delattr__(self._prefix + name) -def patch_object(obj: object, name: str, value: object) -> object: - """ - Workaround `mock.patch.object` issue with ConfigModule - """ - if isinstance(obj, ConfigModule): - return obj.patch(name, value) - return mock.patch.object(obj, name, value) - - def get_tristate_env(name: str, default: Any = None) -> bool | None: value = os.environ.get(name) if value == "1": diff --git a/torch/utils/_config_typing.pyi b/torch/utils/_config_typing.pyi index 9cae7368cfa5e..46bd81be5bc57 100644 --- a/torch/utils/_config_typing.pyi +++ b/torch/utils/_config_typing.pyi @@ -22,7 +22,7 @@ Note that the import should happen before the call to install_config_module(), o """ if not TYPE_CHECKING: # noqa: PYI002 - raise AssertionError("Do not use at runtime") # noqa: W291 + raise AssertionError("Do not use at runtime") def save_config() -> bytes: ... def save_config_portable(*, ignore_private_configs: bool = True) -> dict[str, Any]: ... diff --git a/torch/utils/_contextlib.py b/torch/utils/_contextlib.py index 408cdfe7d7b77..f1af9091e7553 100644 --- a/torch/utils/_contextlib.py +++ b/torch/utils/_contextlib.py @@ -50,7 +50,7 @@ def generator_context(*args, **kwargs): gen.close() raise - except BaseException: # noqa: B036 + except BaseException: # Propagate the exception thrown at us by the caller with ctx_factory(): response = gen.throw(*sys.exc_info()) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 892301070e114..d13cc7eb156f1 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -1,7 +1,9 @@ +# Owner(s): ["module: pytree"] + """ Contains utility functions for working with nested python data structures. -A *pytree* is Python nested data structure. It is a tree in the sense that +A *pytree* is a Python nested data structure. It is a tree in the sense that nodes are Python collections (e.g., list, tuple, dict) and the leaves are Python values. Furthermore, a pytree should not contain reference cycles. @@ -22,13 +24,21 @@ import torch.utils._pytree as python_pytree from torch.torch_version import TorchVersion as _TorchVersion from torch.utils._pytree import ( + Context, + DumpableContext, + FlattenFn, + FlattenWithKeysFn, + FromDumpableContextFn, is_namedtuple, is_namedtuple_class, is_namedtuple_instance, is_structseq, is_structseq_class, is_structseq_instance, - KeyEntry, + KeyPath, + PyTree, + ToDumpableContextFn, + UnflattenFn, ) @@ -101,19 +111,8 @@ U = TypeVar("U") R = TypeVar("R") - TreeSpec: TypeAlias = PyTreeSpec - -Context = Any -PyTree = Any -FlattenFn = Callable[[PyTree], tuple[list[Any], Context]] -UnflattenFn = Callable[[Iterable[Any], Context], PyTree] OpTreeUnflattenFn = Callable[[Context, Iterable[Any]], PyTree] -DumpableContext = Any # Any json dumpable text -ToDumpableContextFn = Callable[[Context], DumpableContext] -FromDumpableContextFn = Callable[[DumpableContext], Context] -KeyPath = tuple[KeyEntry, ...] -FlattenWithKeysFn = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]] # Keep deprecated alias for backward compatibility FlattenFunc = FlattenFn # deprecated @@ -163,7 +162,7 @@ def register_pytree_node( Example:: >>> # xdoctest: +SKIP - >>> # Registry a Python type with lambda functions + >>> # Register a Python type with lambda functions >>> register_pytree_node( ... set, ... lambda s: (sorted(s), None, None), @@ -1022,16 +1021,22 @@ def treespec_loads(serialized: str) -> TreeSpec: return treespec -class _DummyLeaf: +class _Asterisk(str): + __slots__ = () + + def __new__(cls) -> Self: + return super().__new__(cls, "*") + def __repr__(self) -> str: - return "*" + return "*" # no quotes + + +_asterisk = _Asterisk() +del _Asterisk def treespec_pprint(treespec: TreeSpec) -> str: - dummy_tree = tree_unflatten( - [_DummyLeaf() for _ in range(treespec.num_leaves)], - treespec, - ) + dummy_tree = tree_unflatten([_asterisk] * treespec.num_leaves, treespec) return repr(dummy_tree) diff --git a/torch/utils/_inspect.py b/torch/utils/_inspect.py new file mode 100644 index 0000000000000..67d29b9bc7069 --- /dev/null +++ b/torch/utils/_inspect.py @@ -0,0 +1,90 @@ +import inspect +from typing import Any + + +def _signature_metadata( + sig: inspect.Signature, +) -> tuple[tuple[inspect.Parameter, ...], bool, int]: + """ + Returns tuple(sig.parameters.values()), if any has VAR_POSITIONAL or VAR_KEYWORD, and the max_positional + """ + params = tuple(sig.parameters.values()) + has_var_args = False + max_positional = 0 + + for p in params: + kind = p.kind + if kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + has_var_args = True + if kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ): + max_positional += 1 + + return params, has_var_args, max_positional + + +def _fast_bind( + sig: inspect.Signature, *args: Any, **kwargs: Any +) -> inspect.BoundArguments: + """ + Fast path for inspect.Signature.bind() for signatures without + VAR_POSITIONAL or VAR_KEYWORD parameters. Falls back to sig.bind() + for signatures that contain *args or **kwargs. + """ + params, has_var_args, max_positional = _signature_metadata(sig) + + # fallback for complex signatures + if has_var_args: + return sig.bind(*args, **kwargs) + + len_args = len(args) + + if len_args > max_positional: + raise TypeError( + f"Too many positional arguments: expected max {max_positional}, got {len_args}" + ) + + arguments: dict[str, Any] = {} + arg_i = 0 + + for p in params: + name = p.name + kind = p.kind + + if kind is inspect.Parameter.POSITIONAL_ONLY: + if name in kwargs: + raise TypeError( + f"Got some positional-only arguments passed as keyword arguments: '{name}'" + ) + if arg_i < len_args: + arguments[name] = args[arg_i] + arg_i += 1 + elif p.default is inspect.Parameter.empty: + raise TypeError(f"Missing required argument '{name}'") + + elif kind is inspect.Parameter.POSITIONAL_OR_KEYWORD: + if arg_i < len_args: + if name in kwargs: + raise TypeError(f"Multiple values for argument '{name}'") + arguments[name] = args[arg_i] + arg_i += 1 + elif name in kwargs: + arguments[name] = kwargs[name] + elif p.default is inspect.Parameter.empty: + raise TypeError(f"Missing required argument '{name}'") + + elif kind is inspect.Parameter.KEYWORD_ONLY: + if name in kwargs: + arguments[name] = kwargs[name] + elif p.default is inspect.Parameter.empty: + raise TypeError(f"Missing required argument '{name}'") + + # disallow extra keyword arguments not in the signature + # cause kwargs have been processed by sig.bind at the beginning + for name in kwargs: + if name not in sig.parameters: + raise TypeError(f"Got an unexpected keyword argument '{name}'") + + return inspect.BoundArguments(sig, arguments) # type: ignore[arg-type] diff --git a/torch/utils/_pallas.py b/torch/utils/_pallas.py index 84e04d2835e7b..7a1f6fdb9ced6 100644 --- a/torch/utils/_pallas.py +++ b/torch/utils/_pallas.py @@ -85,7 +85,7 @@ def has_jax_tpu_backend() -> bool: def has_torch_tpu() -> bool: """Check if torch_tpu is installed and available.""" try: - import torch_tpu.api # noqa: F401 # type: ignore[import] + import torch_tpu.api # type: ignore[import] # Verify hardware/runtime access torch_tpu.api.tpu_device() diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 9b07d32c950dd..83898df359472 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -23,7 +23,9 @@ if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Mapping, Sequence + + from torch._opaque_base import OpaqueBase # TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it: @@ -449,19 +451,42 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) -# Subtypes which have __tensor_flatten__ and __tensor_unflatten__. -class TensorWithFlatten(Protocol): +# Python typing cannot express Intersection[torch.Tensor, Protocol], so this +# protocol repeats the Tensor members that call sites commonly use after +# is_traceable_wrapper_subclass() narrows a value to this protocol. +class TraceableWrapperSubclass(Protocol): + """ + Canonical protocol for wrapper tensor subclasses that PT2 can trace through. + + ``__tensor_flatten__`` must return stable attribute names for the inner + values that participate in tracing together with any metadata needed to + rebuild the outer subclass. + + ``__tensor_unflatten__`` must reconstruct an equivalent wrapper subclass + instance from the flattened inner values, metadata, and requested outer + size/stride. Callers may pass tensor attrs and registered reference-type + opaques in ``inner_tensors``. The returned tensor is expected to preserve + the requested outer size/stride. If ``attrs, metadata = + x.__tensor_flatten__()``, then ``type(x).__tensor_unflatten__`` must round- + trip from ``{name: getattr(x, name) for name in attrs}``, ``metadata``, + ``x.size()``, and ``x.stride()`` to an equivalent instance of ``x``. + + ``__tensor_unflatten__`` may be implemented as either a ``@staticmethod`` + or a ``@classmethod``; the runtime check below intentionally uses duck + typing to support both forms, even though static type checkers may only + recognize the ``@staticmethod`` form as conforming to this protocol. + """ + def __tensor_flatten__(self) -> tuple[Sequence[str], object]: ... @staticmethod def __tensor_unflatten__( - inner_tensors: int, flatten_spec: int, outer_size: int, outer_stride: int + inner_tensors: Mapping[str, torch.Tensor | OpaqueBase], + metadata: object, + outer_size: Sequence[int | torch.SymInt], + outer_stride: Sequence[int | torch.SymInt], ) -> torch.Tensor: ... - # It would be really nice to be able to say that the return of - # is_traceable_wrapper_subclass() is Intersection[torch.Tensor, - # TensorWithFlatten] - but that doesn't exist. - shape: torch._C.Size @overload @@ -512,51 +537,36 @@ def to( ) -> torch.Tensor: ... -def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: +TensorWithFlatten = TraceableWrapperSubclass + + +def _has_traceable_wrapper_subclass_protocol(t: object) -> bool: + return hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__") + + +def is_traceable_wrapper_subclass(t: object) -> TypeIs[TraceableWrapperSubclass]: """ - Returns whether or not a tensor subclass that implements __torch_dispatch__ - is 'traceable' with torch.compile. - In order for a tensor subclass to support TorchDispatchMode-style tracing in PT2, - It must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__. - It is also expected to obey some restrictions around traceability and aliasing: - * The subclass's __torch_dispatch__() implementation should desugar into pytorch - dispatcher operations that can be traced into a graph. - * The subclass should use return_and_correct_aliasing(). This is needed today to make - sure that torch.compile does the right thing in a few cases around input mutation - and output aliasing. - - Expected magic method signatures: - attrs, ctx = t.__tensor_flatten__() - attrs: list of attribute name strings for inner tensors - ctx: dict containing any other subclass-specific metadata needed for unflattening - - t = MySubClass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride) - inner_tensors: dict mapping attribute name -> tensor for each inner tensor - ctx: dict with subclass metadata in the form that __tensor_flatten__() produces - outer_size: expected (possibly symbolic) size that the returned subclass - instance should have. Note that this arg is useful for certain subclasses - that require the shape info to be constructed. In most cases, this arg can be - safely ignored. - outer_stride: expected (possibly symbolic) stride that the returned subclass - instance should have. Note that this arg is useful for certain subclasses - that require the stride info to be constructed. In most cases, this arg can be - safely ignored. + Returns whether ``t`` is a tensor subclass that matches the + ``TraceableWrapperSubclass`` protocol at runtime. + + See ``TraceableWrapperSubclass`` for the canonical flatten/unflatten + contract. Matching the protocol is necessary but not sufficient for full + PT2 support: the subclass's ``__torch_dispatch__`` implementation must also + desugar into traceable dispatcher operations and preserve aliasing semantics + with ``return_and_correct_aliasing()`` when needed. """ is_subclass = isinstance(t, torch.Tensor) and type(t) is not torch.Tensor - return ( - is_subclass - and hasattr(t, "__tensor_flatten__") - and hasattr(t, "__tensor_unflatten__") - ) + return is_subclass and _has_traceable_wrapper_subclass_protocol(t) -def is_traceable_wrapper_subclass_type(t: type) -> TypeIs[type[TensorWithFlatten]]: +def is_traceable_wrapper_subclass_type( + t: type, +) -> TypeIs[type[TraceableWrapperSubclass]]: """Same as above, but takes a type argument instead of an instance.""" return ( issubclass(t, torch.Tensor) and t is not torch.Tensor - and hasattr(t, "__tensor_flatten__") - and hasattr(t, "__tensor_unflatten__") + and _has_traceable_wrapper_subclass_protocol(t) ) diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 1a2c38d30274c..de90d98f77951 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -1,7 +1,9 @@ +# Owner(s): ["module: pytree"] + """ Contains utility functions for working with nested python data structures. -A *pytree* is Python nested data structure. It is a tree in the sense that +A *pytree* is a Python nested data structure. It is a tree in the sense that nodes are Python collections (e.g., list, tuple, dict) and the leaves are Python values. Furthermore, a pytree should not contain reference cycles. @@ -20,6 +22,7 @@ import importlib import importlib.metadata import json +import logging import sys import threading import types @@ -45,6 +48,9 @@ from torch.torch_version import TorchVersion as _TorchVersion +log = logging.getLogger(__name__) + + if TYPE_CHECKING: import torch.utils._cxx_pytree as cxx_pytree @@ -615,10 +621,24 @@ def _private_register_pytree_node( from torch._library.opaque_object import is_opaque_type if isinstance(cls, type) and is_opaque_type(cls): - raise ValueError( - f"{cls} cannot be registered as a pytree as it has been " - "registered as an opaque object. Opaque objects must be pytree leaves." - ) + # TODO: remove this allowance once downstream callers stop calling + # register_constant on Enum subclasses. Enums are now natively + # supported as opaque value types and don't need pytree registration. + import enum + + if issubclass(cls, enum.Enum): + log.warning( + "%s is an Enum subclass and is now natively supported by " + "torch.compile as an opaque value type. Calling " + "register_constant() on Enum subclasses is deprecated and " + "will be an error in a future release.", + cls, + ) + else: + raise ValueError( + f"{cls} cannot be registered as a pytree as it has been " + "registered as an opaque object. Opaque objects must be pytree leaves." + ) with _NODE_REGISTRY_LOCK: if cls in SUPPORTED_NODES: @@ -1029,6 +1049,7 @@ def _deque_unflatten(values: Iterable[T], context: Context) -> deque[T]: STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict}) # pyrefly: ignore [no-matching-overload] BUILTIN_TYPES: frozenset[type] = frozenset( + # pyrefly: ignore [bad-argument-type] { tuple, list, @@ -1105,6 +1126,8 @@ def _is_leaf(tree: PyTree, is_leaf: Callable[[PyTree], bool] | None = None) -> b # is_leaf(): whether the root Node is a leaf @dataclasses.dataclass(init=False, frozen=True, eq=True, repr=False, slots=True) class TreeSpec: + """Representing the structure of the pytree.""" + type: Any _context: Context _children: list[Self] @@ -1188,15 +1211,20 @@ def children_specs(self) -> list[Self]: return self._children def is_leaf(self) -> bool: + """Test whether the treespec represents a leaf.""" return self.num_nodes == 1 and self.num_leaves == 1 def children(self) -> list[Self]: + """Get all the child treespecs.""" return self._children.copy() def child(self, index: int) -> Self: + """Get the child treespec at the given index.""" return self._children[index] def flatten_up_to(self, tree: PyTree) -> list[PyTree]: + """Flatten the subtrees in ``tree`` up to the structure of this treespec and return a list of subtrees.""" + def helper(treespec: TreeSpec, node: PyTree, subtrees: list[PyTree]) -> None: if treespec.is_leaf(): subtrees.append(node) @@ -1280,6 +1308,7 @@ def helper(treespec: TreeSpec, node: PyTree, subtrees: list[PyTree]) -> None: return subtrees def unflatten(self, leaves: Iterable[Any]) -> PyTree: + """Reconstruct a pytree from the leaves.""" if not isinstance(leaves, (list, tuple)): leaves = list(leaves) if len(leaves) != self.num_leaves: @@ -2064,16 +2093,22 @@ def treespec_loads(serialized: str) -> TreeSpec: ) -class _DummyLeaf: +class _Asterisk(str): + __slots__ = () + + def __new__(cls) -> Self: + return super().__new__(cls, "*") + def __repr__(self) -> str: - return "*" + return "*" # no quotes + + +_asterisk = _Asterisk() +del _Asterisk def treespec_pprint(treespec: TreeSpec) -> str: - dummy_tree = tree_unflatten( - [_DummyLeaf() for _ in range(treespec.num_leaves)], - treespec, - ) + dummy_tree = tree_unflatten([_asterisk] * treespec.num_leaves, treespec) return repr(dummy_tree) diff --git a/torch/utils/_runtime_estimation.py b/torch/utils/_runtime_estimation.py index 9efe613cefc85..b2d56dcfc7f79 100644 --- a/torch/utils/_runtime_estimation.py +++ b/torch/utils/_runtime_estimation.py @@ -1,8 +1,7 @@ import torch from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps from torch.fx.experimental.symbolic_shapes import ( - has_hint, - size_hint, + optimization_hint, statically_known_true, ) from torch.utils._ordered_set import OrderedSet @@ -121,12 +120,9 @@ def get_num_bytes(t: torch.Tensor) -> int: """ real_numel = 1 for size, stride in zip(t.shape, t.stride()): - if not has_hint(size) or not has_hint(stride): - return 0 - # For dims with stride=0 (expanded/broadcast), only 1 element accessed if not statically_known_true(stride == 0): - real_numel *= size_hint(size) + real_numel *= optimization_hint(size, fallback=0) return real_numel * t.element_size() diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 1f95d5337f6a1..93f1bf9f56dc3 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -808,8 +808,8 @@ def _collapse_arguments(cls, args, **assumptions): if isinstance(a, other): a0 = a.args[0] if ( # noqa: E712 - (a0 > T) if other == Max else (a0 < T) # noqa: E712 - ) == True: # noqa: E712 + (a0 > T) if other == Max else (a0 < T) + ) == True: args[i] = cls.identity # type: ignore[attr-defined] # remove redundant symbolic args @@ -938,13 +938,11 @@ def _find_localzeros(cls, values, **options): _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731 _eval_is_antihermitian = lambda s: _torf( # noqa: E731 - i.is_antihermitian - for i in s.args # noqa: E731 - ) # noqa: E731 + i.is_antihermitian for i in s.args + ) _eval_is_commutative = lambda s: _torf( # noqa: E731 - i.is_commutative - for i in s.args # noqa: E731 - ) # noqa: E731 + i.is_commutative for i in s.args + ) _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731 _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731 _eval_is_even = lambda s: _torf(i.is_even for i in s.args) # noqa: E731 @@ -957,13 +955,11 @@ def _find_localzeros(cls, values, **options): _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731 _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731 _eval_is_nonnegative = lambda s: _torf( # noqa: E731 - i.is_nonnegative - for i in s.args # noqa: E731 - ) # noqa: E731 + i.is_nonnegative for i in s.args + ) _eval_is_nonpositive = lambda s: _torf( # noqa: E731 - i.is_nonpositive - for i in s.args # noqa: E731 - ) # noqa: E731 + i.is_nonpositive for i in s.args + ) _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731 _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731 _eval_is_polar = lambda s: _torf(i.is_polar for i in s.args) # noqa: E731 @@ -972,13 +968,11 @@ def _find_localzeros(cls, values, **options): _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731 _eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731 _eval_is_extended_real = lambda s: _torf( # noqa: E731 - i.is_extended_real - for i in s.args # noqa: E731 - ) # noqa: E731 + i.is_extended_real for i in s.args + ) _eval_is_transcendental = lambda s: _torf( # noqa: E731 - i.is_transcendental - for i in s.args # noqa: E731 - ) # noqa: E731 + i.is_transcendental for i in s.args + ) _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731 diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 8d42d14968de7..e9ca80812b04d 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -97,6 +97,9 @@ def _print_NegativeInfinity(self, expr: sympy.Expr) -> str: f"_print_NegativeInfinity not implemented for {type(self)}" ) + def _print_NaN(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_NaN not implemented for {type(self)}") + def _print_FloorDiv(self, expr: sympy.Expr) -> str: raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") @@ -176,6 +179,9 @@ def _print_Infinity(self, expr: sympy.Expr) -> str: def _print_NegativeInfinity(self, expr: sympy.Expr) -> str: return "-math.inf" + def _print_NaN(self, expr: sympy.Expr) -> str: + return "math.nan" + # WARNING: this is dangerous for Triton, which has C-style modulus def _print_PythonMod(self, expr: sympy.Expr) -> str: return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5) @@ -191,8 +197,17 @@ def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5) def _helper_sqrt(self, expr: sympy.Expr) -> str: - # pyrefly: ignore [missing-attribute] - return f"math.sqrt({self._print(expr)})" + # NB: We use torch._sym_sqrt here instead of math.sqrt because the + # guard expression may be evaluated with SymInt/SymFloat inputs (e.g. + # during cache hit re-evaluation in evaluate_guards_expression). + # math.sqrt on a SymFloat triggers evaluate_expr which forces + # concretization/specialization of the symbol, creating spurious + # guards that didn't exist in the original program. + # torch._sym_sqrt properly propagates through the symbolic system + # without forcing specialization. + # See https://github.com/pytorch/pytorch/issues/152435 + # pyrefly: ignore [missing-attribute] + return f"torch._sym_sqrt({self._print(expr)})" def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str: return self._helper_sqrt(expr.args[0]) @@ -372,6 +387,9 @@ def _print_Where(self, expr: sympy.Expr) -> str: ) return f"{c} ? {p} : {q}" + def _print_Or(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " || ", precedence(expr)) + def _print_Piecewise(self, expr: sympy.Expr) -> str: # Convert Piecewise(expr_cond_pairs) to nested ternary operators # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) @@ -643,3 +661,6 @@ def _print_Infinity(self, expr: sympy.Expr) -> str: def _print_NegativeInfinity(self, expr: sympy.Expr) -> str: return f"-{self._print_Infinity(expr)}" + + def _print_NaN(self, expr: sympy.Expr) -> str: + return "std::numeric_limits::quiet_NaN()" diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 3118e89ebfcf0..3879cef698154 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -177,7 +177,7 @@ def sqrt(x): @staticmethod def pow(a, b): - # pyrefly: ignore [bad-argument-type] + # pyrefly: ignore [bad-argument-count, bad-argument-type] return _keep_float(FloatPow)(a, b) @staticmethod @@ -340,6 +340,20 @@ def bitwise_or(a, b): def bitwise_xor(a, b): return a ^ b + @staticmethod + def expr_cond_pair(expr, cond): + return (expr, cond) + + @staticmethod + def piecewise(*pairs): + # Build nested sym_ite from right to left. + # Piecewise((e1, c1), (e2, c2), ..., (en, True)) becomes + # sym_ite(c1, e1, sym_ite(c2, e2, ... en)) + result = pairs[-1][0] + for expr, cond in reversed(pairs[:-1]): + result = torch.sym_ite(cond, expr, result) + return result + # Like PythonReferenceAnalysis, but some export-unfriendly choices of # operators to make things faster diff --git a/torch/utils/_sympy/singleton_int.py b/torch/utils/_sympy/singleton_int.py index 57d5615e55271..208533d5fac4b 100644 --- a/torch/utils/_sympy/singleton_int.py +++ b/torch/utils/_sympy/singleton_int.py @@ -80,14 +80,14 @@ def _eval_is_ge(a, b): @dispatch(SingletonInt, sympy.Integer) # type: ignore[no-redef] -def _eval_is_ge(a, b): # noqa: F811 +def _eval_is_ge(a, b): if b <= 2: return sympy.true raise ValueError("Symbolic SingletonInt: Relation is indeterminate") @dispatch(SingletonInt, SingletonInt) # type: ignore[no-redef] -def _eval_is_ge(a, b): # noqa: F811 +def _eval_is_ge(a, b): if a._val == b._val: if a._coeff >= b._coeff: return sympy.true diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index 0e1eb82f8869e..944c19fe5c28a 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -76,9 +76,25 @@ def rename_privateuse1_backend(backend_name: str) -> None: >>> a = torch.ones(2, device="foo") """ + from torch._C._profiler import ProfilerActivity + _rename_privateuse1_backend(backend_name) global _privateuse1_backend_name _privateuse1_backend_name = backend_name + # Mirror the rename in ProfilerActivity so users can write e.g. + # ProfilerActivity. instead of ProfilerActivity.PrivateUse1. + pu1 = ProfilerActivity.PrivateUse1 + alias = backend_name.upper() + setattr(ProfilerActivity, alias, pu1) + + original_repr = ProfilerActivity.__repr__ + + def custom_repr(self): + if self == pu1: + return f"" + return original_repr(self) + + ProfilerActivity.__repr__ = custom_repr def _check_register_once(module, attr) -> None: @@ -186,9 +202,10 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) - ) def wrap_module_to( + # pyrefly: ignore [invalid-type-var] self: torch.nn.modules.module.T, device: int | torch.device | None = None, - ) -> torch.nn.modules.module.T: + ) -> torch.nn.modules.module.T: # pyrefly: ignore [invalid-type-var] r"""Move all model parameters and buffers to the custom device. This also makes associated parameters and buffers different objects. So diff --git a/torch/utils/benchmark/utils/common.py b/torch/utils/benchmark/utils/common.py index d4f328d19083f..e28957a33be35 100644 --- a/torch/utils/benchmark/utils/common.py +++ b/torch/utils/benchmark/utils/common.py @@ -223,7 +223,7 @@ def __repr__(self) -> str: {'Median: ' if n > 1 else ''}{self._median / time_scale:.2f} {time_unit} {iqr_filter}IQR: {self.iqr / time_scale:.2f} {time_unit} ({self._p25 / time_scale:.2f} to {self._p75 / time_scale:.2f}) {n} measurement{'s' if n > 1 else ''}, {self.number_per_run} runs {'per measurement,' if n > 1 else ','} {self.num_threads} thread{'s' if self.num_threads > 1 else ''} -{newline.join(self._warnings)}""".strip() # noqa: B950 +{newline.join(self._warnings)}""".strip() return "\n".join(l for l in repr_str.splitlines(keepends=False) if skip_line not in l) diff --git a/torch/utils/checkpoint.py b/torch/utils/checkpoint.py index a21ffdb7fef29..d5f2bb039262b 100644 --- a/torch/utils/checkpoint.py +++ b/torch/utils/checkpoint.py @@ -2,6 +2,7 @@ from __future__ import annotations import contextlib +import itertools import platform import uuid import warnings @@ -138,10 +139,11 @@ def get_device_type() -> str: def _infer_device_type(*args): device_types = [] - def add_device_types(arg) -> None: + def add_device_types(arg): nonlocal device_types if isinstance(arg, torch.Tensor) and arg.device.type != "cpu": device_types.append(arg.device.type) + return arg tree_map(add_device_types, args) device_types_set = set(device_types) @@ -174,10 +176,11 @@ def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]: # the conditionals short-circuit. fwd_device_ids = [] - def add_device_ids(arg) -> None: + def add_device_ids(arg): nonlocal fwd_device_ids if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}: fwd_device_ids.append(arg.get_device()) + return arg tree_map(add_device_ids, args) fwd_device_states = [] @@ -993,8 +996,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint # checkpointing mechanism. error_cb is invoked when an error is detected # during unpack. - # record_context_cpp is not support on non-linux non-x86_64 platforms - cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux' + cpp_tb = platform.machine() in ('x86_64', 'aarch64', 'arm64') and platform.system() == 'Linux' class CaptureLogs: def __init__(self) -> None: @@ -1069,6 +1071,10 @@ class _StopRecomputationError(Exception): class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks): def __init__(self, target_frame_ref: ReferenceType, gid: GraphExecGroup | int) -> None: + # Dynamo guards on WeakKeyDictionary internals are unstable here + # (dict length/keys change every call), causing recompilation storms. + # with `.compile()` so we disable + @torch._dynamo.disable def pack_hook(x): x = x.detach() if x.requires_grad else x target_frame = target_frame_ref() @@ -1196,8 +1202,10 @@ def unpack_hook_with_error_cb(holder): def _is_compiling(func, args, kwargs): - # Check if we are under AOTAutograd tracing - # Checking that a functional mode is active should always do what we want + # Check if we are under AOTAutograd tracing or export tracing + # Checking that a proxy mode is active should always do what we want + if torch.compiler._is_non_strict_tracing(): + return False return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) is not None @@ -1258,8 +1266,9 @@ class SelectiveCheckpointContext: >>> context_fn=context_fn, >>> ) """ - def __init__(self, *, is_recompute) -> None: + def __init__(self, *, is_recompute, op_output=None) -> None: self.is_recompute = is_recompute + self.op_output = op_output class CheckpointPolicy(enum.Enum): @@ -1316,25 +1325,29 @@ def ignore_compile_internals(cls): return True # Used together with _CachedTorchDispatchMode to implement SAC. - def __init__(self, policy_fn, storage) -> None: + def __init__(self, policy_fn, storage, ac_graph_id=None) -> None: self.policy_fn = policy_fn self.storage = storage + self.ac_graph_id = ac_graph_id def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if func in SAC_IGNORED_OPS: - return func(*args, **kwargs) - kwargs = {} if kwargs is None else kwargs - policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False), - func, *args, **kwargs) - if isinstance(policy, bool): - policy = _policy_from_bool(policy) - is_compiling = _is_compiling(func, args, kwargs) if is_compiling: - # Overwrite each node's "recompute" tag to add in the user annotation. - fx_traceback.current_meta["recompute"] = policy + fx_traceback.current_meta["ac_graph_id"] = self.ac_graph_id + fx_traceback.current_meta["recompute"] = CheckpointPolicy.PREFER_RECOMPUTE + + if func in SAC_IGNORED_OPS: + return func(*args, **kwargs) + + proxy_mode = None + graph_len_before = 0 + if is_compiling: + from torch.fx.experimental.proxy_tensor import get_proxy_mode + proxy_mode = get_proxy_mode() + if proxy_mode is not None: + graph_len_before = len(proxy_mode.tracer.graph.nodes) out = func(*args, **kwargs) @@ -1346,6 +1359,18 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): else: any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns) + policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False, op_output=out), + func, *args, **kwargs) + if isinstance(policy, bool): + policy = _policy_from_bool(policy) + + if is_compiling: + if proxy_mode is not None: + graph = proxy_mode.tracer.graph + num_new = len(graph.nodes) - graph_len_before + for node in itertools.islice(reversed(graph.nodes), num_new): + node.meta["recompute"] = policy + if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out)) return out @@ -1572,6 +1597,8 @@ def _checkpoint_without_reentrant_generator( ), contextlib.nullcontext(), ) + error_on_nested_fx_trace = torch._dynamo.config.error_on_nested_fx_trace + is_non_strict_tracing = torch.compiler._is_non_strict_tracing() def recompute_fn(*args) -> None: # This will be called later during recomputation. This wrapping enables @@ -1590,7 +1617,20 @@ def recompute_fn(*args) -> None: device_autocast_ctx = torch.amp.autocast( device_type=device_type, **device_autocast_kwargs ) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext() - with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context, device_ctx: # type: ignore[attr-defined] + nested_fx_trace_ctx = ( + torch._dynamo.config.patch( + error_on_nested_fx_trace=error_on_nested_fx_trace + ) + if is_non_strict_tracing + else contextlib.nullcontext() + ) + with ( + device_autocast_ctx, + torch.amp.autocast("cpu", **cpu_autocast_kwargs), + recompute_context, + device_ctx, + nested_fx_trace_ctx, + ): # type: ignore[attr-defined] fn(*args, **kwargs) new_frame = _CheckpointFrame( diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index a184eba0ce8a9..3e9d5d3823cbe 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -71,8 +71,8 @@ '12.5': ((6, 0, 0), (14, 0)), '12.6': ((6, 0, 0), (14, 0)), '12.7': ((6, 0, 0), (14, 0)), - '12.8': ((6, 0, 0), (14, 0)), - '12.9': ((6, 0, 0), (14, 0)), + '12.8': ((6, 0, 0), (15, 0)), + '12.9': ((6, 0, 0), (15, 0)), '13.0': ((6, 0, 0), (16, 0)), } @@ -2261,8 +2261,11 @@ def _jit_compile(name, hipified_sources = set() for source in sources: s_abs = os.path.abspath(source) - hipified_sources.add(hipify_result[s_abs].hipified_path if s_abs in hipify_result else s_abs) - + if s_abs in hipify_result and hipify_result[s_abs].hipified_path is not None: + hipified_s_abs = hipify_result[s_abs].hipified_path + else: + hipified_s_abs = s_abs + hipified_sources.add(hipified_s_abs) sources = list(hipified_sources) _write_ninja_file_and_build_library( diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 0b51561ca6e7f..4eaa755980787 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -423,8 +423,6 @@ def __init__( self.check_worker_number_rationality() - torch.set_vital("Dataloader", "enabled", "True") # type: ignore[attr-defined] - def _get_iterator(self) -> _BaseDataLoaderIter: if self.num_workers == 0: return _SingleProcessDataLoaderIter(self) @@ -606,31 +604,10 @@ def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked): return # try to compute a suggested max number of worker based on system's resource - max_num_worker_suggest = None - cpuset_checked = False - if hasattr(os, "sched_getaffinity"): - try: - max_num_worker_suggest = len(os.sched_getaffinity(0)) - cpuset_checked = True - except Exception: - pass - if max_num_worker_suggest is None: - # os.cpu_count() could return Optional[int] - # get cpu count first and check None in order to satisfy mypy check - cpu_count = os.cpu_count() - if cpu_count is not None: - max_num_worker_suggest = cpu_count - - if max_num_worker_suggest is None: - warnings.warn( - _create_warning_msg( - max_num_worker_suggest, self.num_workers, cpuset_checked - ), - stacklevel=2, - ) - return + max_num_worker_suggest = torch._utils.cpu_count() + cpuset_checked = hasattr(os, "sched_getaffinity") - if self.num_workers > max_num_worker_suggest: + if max_num_worker_suggest is None or self.num_workers > max_num_worker_suggest: warnings.warn( _create_warning_msg( max_num_worker_suggest, self.num_workers, cpuset_checked diff --git a/torch/utils/data/graph.py b/torch/utils/data/graph.py index 9820aac095ae3..3f8dbb17ceaea 100644 --- a/torch/utils/data/graph.py +++ b/torch/utils/data/graph.py @@ -11,6 +11,7 @@ __all__ = ["traverse", "traverse_dps"] DataPipe = IterDataPipe | MapDataPipe +# pyrefly: ignore [invalid-type-alias] DataPipeGraph = dict[int, tuple[DataPipe, "DataPipeGraph"]] diff --git a/torch/utils/debug_log.py b/torch/utils/debug_log.py new file mode 100644 index 0000000000000..5ecf80a135d31 --- /dev/null +++ b/torch/utils/debug_log.py @@ -0,0 +1,68 @@ +"""Compile-safe backward gradient logging for multiple tensors. + +``debug_grad_log`` logs gradient norms during backward for one or more tensors. +It is a leaf function with a ``register_multi_grad_hook`` that fires exactly +once when all requires_grad tensor inputs have their gradients computed. + +Example:: + + import torch + from torch.utils.debug_log import debug_grad_log + + x = torch.randn(4, requires_grad=True) + y = torch.randn(4, requires_grad=True) + z = x * 2 + y * 3 + + debug_grad_log(x, y) + + z.sum().backward() + # Logs: [rank 0][bwd] t0_grad_norm=... t1_grad_norm=... +""" + +import logging + +import torch +from torch._dynamo.decorators import leaf_function + + +__all__ = ["debug_grad_log"] + +log = logging.getLogger(__name__) + + +def _get_rank() -> int: + if not torch.distributed.is_available(): + return 0 + import torch.distributed as dist + + return dist.get_rank() if dist.is_initialized() else 0 + + +@leaf_function +def debug_grad_log(*tensors): + """Log gradient norms of multiple tensors during backward. + + This is a no-op in the forward pass. During backward, the hook fires + exactly once when all requires_grad tensor inputs have their gradients + computed, and logs ``[rank R][bwd] t0_grad_norm=... t1_grad_norm=...``. + + Args: + *tensors: One or more tensors to monitor. + + Returns: + None. Call without assignment: ``debug_grad_log(x, y)``. + """ + return None + + +@debug_grad_log.register_fake # pyrefly: ignore[missing-attribute] +def _debug_grad_log_fake(*tensors): + return None + + +@debug_grad_log.register_multi_grad_hook # pyrefly: ignore[missing-attribute] +def _debug_grad_log_hook(*grads): + norms = " ".join( + f"t{i}_grad_norm={g.norm().item():.4f}" for i, g in enumerate(grads) + ) + log.info("[rank %d][bwd] %s", _get_rank(), norms) diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index 76394f88cc32d..5f015a32f9c31 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -608,6 +608,85 @@ def _efficient_attention_backward_flop( ) +def _varlen_attn_forward_flop( + query, + key, + value, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + *args, + out_val=None, + **kwargs, +) -> int: + """Count flops for varlen_attn forward.""" + sizes = _unpack_flash_attention_nested_shapes( + query=query, + key=key, + value=value, + cum_seq_q=cu_seq_q, + cum_seq_k=cu_seq_k if cu_seq_k is not None else cu_seq_q, + max_q=max_q, + max_k=max_k, + ) + return sum( + sdpa_flop_count(query_shape, key_shape, value_shape) + for query_shape, key_shape, value_shape, _ in sizes + ) + + +def _varlen_attn_out_flop( + out, + query, + key, + value, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + *args, + out_val=None, + **kwargs, +) -> int: + """Count flops for varlen_attn_out forward.""" + return _varlen_attn_forward_flop( + query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, + ) + + +def _varlen_attn_backward_flop( + grad_out, + query, + key, + value, + out, + lse, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + *args, + out_val=None, + **kwargs, +) -> int: + """Count flops for varlen_attn backward.""" + sizes = _unpack_flash_attention_nested_shapes( + query=query, + key=key, + value=value, + grad_out=grad_out, + cum_seq_q=cu_seq_q, + cum_seq_k=cu_seq_k, + max_q=max_q, + max_k=max_k, + ) + return sum( + sdpa_backward_flop_count(grad_out_shape, query_shape, key_shape, value_shape) + for query_shape, key_shape, value_shape, grad_out_shape in sizes + ) + + flop_registry = { aten.mm: mm_flop, aten.addmm: addmm_flop, diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index 2ed9ce7c07879..8df2eec8c1140 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -766,6 +766,7 @@ ("CU_MEM_HANDLE_TYPE_WIN32", "hipMemHandleTypeWin32"), ("CU_MEM_HANDLE_TYPE_WIN32_KMT", "hipMemHandleTypeWin32Kmt"), ("CU_MEM_LOCATION_TYPE_DEVICE", "hipMemLocationTypeDevice"), + ("CU_MEM_LOCATION_TYPE_HOST", "hipMemLocationTypeHost"), ("CU_MEM_LOCATION_TYPE_INVALID", "hipMemLocationTypeInvalid"), ("CU_MEM_OPERATION_TYPE_MAP", "hipMemOperationTypeMap"), ("CU_MEM_OPERATION_TYPE_UNMAP", "hipMemOperationTypeUnmap"), @@ -1360,6 +1361,10 @@ ("cudaEventBlockingSync", "hipEventBlockingSync"), ("cudaEventDisableTiming", "hipEventDisableTiming"), ("cudaEventInterprocess", "hipEventInterprocess"), + ("cudaEventRecordDefault", "hipEventRecordDefault"), + ("cudaEventRecordExternal", "hipEventRecordExternal"), + ("cudaEventWaitDefault", "hipEventWaitDefault"), + ("cudaEventWaitExternal", "hipEventWaitExternal"), ("cudaStreamCreate", "hipStreamCreate"), ("cudaStreamCreateWithFlags", "hipStreamCreateWithFlags"), ("cudaStreamCreateWithPriority", "hipStreamCreateWithPriority"), @@ -2676,6 +2681,7 @@ ("cub::RowMajorTid", "hipcub::RowMajorTid"), ("cub::CachingDeviceAllocator", "hipcub::CachingDeviceAllocator"), ("cub::CountingInputIterator", "hipcub::CountingInputIterator"), + ("cub::DeviceHistogram", "hipcub::DeviceHistogram"), ("cub::DeviceRadixSort", "hipcub::DeviceRadixSort"), ("cub::DeviceReduce", "hipcub::DeviceReduce"), ("cub::DeviceRunLengthEncode", "hipcub::DeviceRunLengthEncode"), @@ -3392,6 +3398,23 @@ ("cudnnTensorDescriptor_t ", "miopenTensorDescriptor_t "), ("CUDNN_ENFORCE", "MIOPEN_ENFORCE"), ("CUDNN_CHECK", "MIOPEN_CHECK"), + # NVSHMEM → rocSHMEM mappings (only symbols used in hipified files: + # NVSHMEMSymmetricMemory.cpp and nvshmem_team_manager.hpp). + ("NVSHMEM_TEAM_INVALID", "rocshmem::ROCSHMEM_TEAM_INVALID"), + ("NVSHMEM_TEAM_WORLD", "rocshmem::ROCSHMEM_TEAM_WORLD"), + ("NVSHMEMX_INIT_WITH_UNIQUEID", "rocshmem::ROCSHMEM_INIT_WITH_UNIQUEID"), + + ("nvshmem_malloc", "rocshmem::rocshmem_malloc"), + ("nvshmem_free", "rocshmem::rocshmem_free"), + ("nvshmem_ptr", "rocshmem::rocshmem_ptr"), + ("nvshmem_team_t", "rocshmem::rocshmem_team_t"), + ("nvshmem_team_split_strided", "rocshmem::rocshmem_team_split_strided"), + + ("nvshmemx_uniqueid_t", "rocshmem::rocshmem_uniqueid_t"), + ("nvshmemx_get_uniqueid", "rocshmem::rocshmem_get_uniqueid"), + ("nvshmemx_init_attr", "rocshmem::rocshmem_init_attr"), + ("nvshmemx_init_attr_t", "rocshmem::rocshmem_init_attr_t"), + ("nvshmemx_set_attr_uniqueid_args", "rocshmem::rocshmem_set_attr_uniqueid_args"), ]) C10_MAPPINGS = collections.OrderedDict([ @@ -3415,6 +3438,7 @@ ("c10/cuda/CUDAEvent.h", "c10/hip/HIPEvent.h"), ("c10/cuda/impl/CUDAGuardImpl.h", "c10/hip/impl/HIPGuardImpl.h"), ("c10/cuda/impl/CUDATest.h", "c10/hip/impl/HIPTest.h"), + ("CUDATest.hpp", "HIPTest.hpp"), ("c10/cuda/impl/cuda_cmake_macros.h", "c10/hip/impl/hip_cmake_macros.h"), # TODO: Remove these. They were necessary for Meta-internal builds. ("c10::hip::c10_hip_check_implementation", "c10::cuda::c10_cuda_check_implementation"), diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 557f8c128a905..3919d61bbaa0c 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -647,6 +647,8 @@ def is_pytorch_file(rel_filepath) -> bool: return True if rel_filepath.startswith("tools/autograd/templates/"): return True + if rel_filepath.startswith("test/cpp/c10d/"): + return True return False diff --git a/torch/utils/hooks.py b/torch/utils/hooks.py index 1e3f6fb9ab09d..dd8afceb162eb 100644 --- a/torch/utils/hooks.py +++ b/torch/utils/hooks.py @@ -225,7 +225,7 @@ def hook(_, grad_output): if self.input_tensors_index is None: warnings.warn("Full backward hook is firing when gradients are computed " "with respect to module outputs since no inputs require gradients. See " - "https://docs.pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook " # noqa: B950 + "https://docs.pytorch.org/docs/main/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook " "for more details.", stacklevel=5) grad_inputs = self._pack_with_none([], [], self.n_inputs) diff --git a/torch/utils/viz/MemoryViz.js b/torch/utils/viz/MemoryViz.js index 20b08b801eade..a7a05c6d91f5d 100644 --- a/torch/utils/viz/MemoryViz.js +++ b/torch/utils/viz/MemoryViz.js @@ -1,3 +1,61 @@ +/** + * ================================================================================ + * MemoryViz.js - PyTorch Memory Visualization Tool + * ================================================================================ + * + * OVERVIEW: + * --------- + * This file contains the core visualization logic for PyTorch's memory profiler. + * It renders memory allocation timelines, stack traces, and provides interactive + * exploration of memory snapshots captured during model execution. + * + * KEY FEATURES: + * - Multiple visualization tabs/views for different memory analysis perspectives + * - Interactive stack trace display (supports both click and hover modes) + * - Zoom and brush controls for navigating large memory timelines + * - Support for loading memory snapshot files (.pickle format) + * + * ================================================================================ + * TESTING INSTRUCTIONS FOR ENGINEERS & AGENTS + * ================================================================================ + * + * 1. LOCAL TESTING SETUP: + * - Create a simple HTML file that references this JS file: + * + * + * + * MemoryViz Test + * + * + * + * + * + * - Serve locally using: python3 -m http.server 8888 + * - Open http://localhost:8888 in your browser + * + * 2. WHAT TO TEST: + * - Ensure ALL tabs/views render correctly and switch properly + * - Verify BOTH interaction modes work: + * * Click mode: stack traces appear on click + * * Hover mode: stack traces appear on mouseover + * - Test zoom and brush controls for timeline navigation + * - Verify memory allocation blocks are rendered and interactive + * + * 3. TEST DATA REQUIREMENTS: + * - DO NOT just test with small dummy .pickle files + * - Use realistic, decent-sized .pickle files (10-100+ MB range) + * - Large files stress-test rendering performance and memory handling + * - Test with snapshots from real model training/inference runs + * + * 4. COMMON ISSUES TO WATCH FOR: + * - Performance degradation with large snapshots + * - Stack trace popups not appearing or positioning incorrectly + * - Tab switching not updating the visualization properly + * - Zoom/brush state not persisting across interactions + * + * ================================================================================ + */ + 'use strict'; import * as d3 from "https://cdn.jsdelivr.net/npm/d3@7/+esm"; @@ -5,6 +63,9 @@ import {axisLeft} from "https://cdn.jsdelivr.net/npm/d3-axis@3/+esm"; import {scaleLinear} from "https://cdn.jsdelivr.net/npm/d3-scale@4/+esm"; import {zoom, zoomIdentity} from "https://cdn.jsdelivr.net/npm/d3-zoom@3/+esm"; import {brushX} from "https://cdn.jsdelivr.net/npm/d3-brush@3/+esm"; +import {process_alloc_data, isPrivatePoolId, formatSize, formatAddr, + elideRepeats, frameFilter, format_user_metadata, + format_forward_frames, format_frames} from "./process_alloc_data.js"; // Global configuration for trace interaction mode // 'hover' = show trace on hover (default) @@ -48,8 +109,8 @@ function version_space() { }; } -function Segment(addr, size, stream, frames, version, user_metadata) { - return {addr, size, stream, version, frames, user_metadata}; +function Segment(addr, size, stream, frames, version, user_metadata, segment_pool_id) { + return {addr, size, stream, version, frames, user_metadata, segment_pool_id}; } function Block(addr, size, requested_size, frames, free_requested, version, user_metadata) { @@ -115,25 +176,6 @@ function EventSelector(outer, events, stack_info, memory_view) { return es; } -function formatSize(num, showBytes = true) { - const orig = num; - // https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size - const units = ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']; - for (const unit of units) { - if (Math.abs(num) < 1024.0) { - if (showBytes) { - return `${num.toFixed(1)}${unit}B (${orig} bytes)`; - } - return `${num.toFixed(1)}${unit}B`; - } - num /= 1024.0; - } - return `${num.toFixed(1)}YiB`; -} -function formatAddr(event) { - const prefix = event.action.startsWith('segment') ? 's\'' : 'b\''; - return `${prefix}${event.addr.toString(16)}_${event.version}`; -} function formatEvent(event) { const stream = event.stream === null ? '' : `\n (stream ${event.stream})`; @@ -246,6 +288,7 @@ function MemoryView(outer, stack_info, snapshot, device) { seg.frames || [], seg.version, seg.user_metadata, + seg.segment_pool_id, ), ); for (const b of seg.blocks) { @@ -490,13 +533,17 @@ function MemoryView(outer, stack_info, snapshot, device) { const user_metadata_str = format_user_metadata(t.user_metadata); const frames_str = format_frames(t.frames); const forward_frames_str = format_forward_frames(t.forward_frames); + let pool_str = ''; + if (isPrivatePoolId(t.segment_pool_id)) { + pool_str = `, pool_id (${t.segment_pool_id[0]}, ${t.segment_pool_id[1]})`; + } return ( `s${t.addr.toString(16)}_${t.version}: segment ${formatSize( t.size, )} allocated, ` + `${formatSize(free)} free${internal} (stream ${ t.stream - })\n` + + }${pool_str})\n` + (user_metadata_str ? user_metadata_str + '\n' : '') + frames_str + forward_frames_str @@ -564,11 +611,15 @@ function MemoryView(outer, stack_info, snapshot, device) { const user_metadata_str = format_user_metadata(t.user_metadata); const frames_str = format_frames(t.frames); const forward_frames_str = format_forward_frames(t.forward_frames); + let pool_str = ''; + if (isPrivatePoolId(t.segment?.segment_pool_id)) { + pool_str = `, pool_id (${t.segment.segment_pool_id[0]}, ${t.segment.segment_pool_id[1]})`; + } return ( `b${t.addr.toString(16)}_${t.version} ` + `${formatSize(t.requested_size)} allocation${requested} (stream ${ t.segment.stream - })\n` + + }${pool_str})\n` + (user_metadata_str ? user_metadata_str + '\n' : '') + frames_str + forward_frames_str @@ -751,14 +802,6 @@ function annotate_snapshot(snapshot) { } } snapshot.device_traces = new_traces; - // if every event was on the default stream, we elide stream printing - if (next_stream == 1) { - for (const device_trace of snapshot.device_traces) { - for (const t of device_trace) { - t.stream = null; - } - } - } for (const seg of snapshot.segments) { seg.stream = stream_name(seg.stream); @@ -778,6 +821,7 @@ function annotate_snapshot(snapshot) { } } b.version = snapshot.block_version(b.addr, false); + b.segment_pool_id = seg.segment_pool_id; // Note [BigInt and Number Safe Arithmetic] // Device pointer addresses may be represented as either Number or BigInt. // Use explicit conversions to perform arithmetic safely and avoid mixing @@ -794,364 +838,6 @@ function annotate_snapshot(snapshot) { } } -function elideRepeats(frames) { - const result = []; - const length = frames.length; - for (let i = 0; i < length; ) { - let j = i + 1; - const f = frames[i]; - while (j < length && f === frames[j]) { - j++; - } - switch (j - i) { - case 1: - result.push(f); - break; - case 2: - result.push(f, f); - break; - default: - result.push(f, ``); - break; - } - i = j; - } - return result; -} -function frameFilter({name, filename}) { - const omitFunctions = [ - 'unwind::unwind', - 'CapturedTraceback::gather', - 'gather_with_cpp', - '_start', - '__libc_start_main', - 'PyEval_', - 'PyObject_', - 'PyFunction_', - ]; - - const omitFilenames = [ - 'core/boxing', - '/Register', - '/Redispatch', - 'pythonrun.c', - 'Modules/main.c', - 'Objects/call.c', - 'Objects/methodobject.c', - 'pycore_ceval.h', - 'ceval.c', - 'cpython/abstract.h', - ]; - - for (const of of omitFunctions) { - if (name.includes(of)) { - return false; - } - } - - for (const of of omitFilenames) { - if (filename.includes(of)) { - return false; - } - } - - return true; -} - -function format_user_metadata(user_metadata) { - if (!user_metadata) { - return ''; - } - // Handle string metadata - if (typeof user_metadata === 'string') { - return `User Metadata:\n ${user_metadata}`; - } - // Handle object metadata - if (typeof user_metadata === 'object' && Object.keys(user_metadata).length === 0) { - return ''; - } - const metadata_lines = Object.entries(user_metadata) - .map(([key, value]) => ` ${key}: ${value}`); - return 'User Metadata:\n' + metadata_lines.join('\n'); -} - -function format_forward_frames(forward_frames) { - if (!forward_frames || forward_frames.length === 0) { - return ''; - } - // forward_frames is a list of strings (each string is a frame line from the forward pass) - // Each frame string already includes newlines, so we just join them directly - let frames_str = forward_frames.join(''); - // Ensure we don't have a trailing newline that could cause display issues - frames_str = frames_str.trimEnd(); - return `\n\n=== Forward Pass Stack Trace (where this tensor was created) ===\n${frames_str}`; -} - -function format_frames(frames) { - if (frames.length === 0) { - return ( - `This block has no frames. Potential causes:\n` + - `1) This block was allocated before _record_memory_history was enabled.\n` + - `2) The context or stacks passed to _record_memory_history does not include this block. Consider changing context to 'state', 'alloc', or 'all', or changing stacks to 'all'.\n` + - `3) This event occurred during backward, which has no python frames, and memory history did not include C++ frames. Use stacks='all' to record both C++ and python frames.` - ); - } - const frame_strings = frames - .filter(frameFilter) - .map(f => { - let frame_str = `${f.filename}:${f.line}:${f.name}`; - - // Add FX debug information if available - if (f.fx_node_op || f.fx_node_name || f.fx_node_target) { - const fx_parts = []; - if (f.fx_node_name) fx_parts.push(`node=${f.fx_node_name}`); - if (f.fx_node_op) fx_parts.push(`op=${f.fx_node_op}`); - if (f.fx_node_target) fx_parts.push(`target=${f.fx_node_target}`); - frame_str += `\n >> FX: ${fx_parts.join(', ')}`; - } - - if (f.fx_original_trace) { - frame_str += `\n >> Original Model Code:`; - const original_lines = f.fx_original_trace.trim().split('\n'); - // Show all lines of the original trace - for (const line of original_lines) { - frame_str += `\n ${line}`; - } - } - - return frame_str; - }); - return elideRepeats(frame_strings).join('\n'); -} - -function process_alloc_data(snapshot, device, plot_segments, max_entries) { - const elements = []; - const initially_allocated = []; - const actions = []; - const addr_to_alloc = {}; - - const alloc = plot_segments ? 'segment_alloc' : 'alloc'; - const [free, free_completed] = plot_segments - ? ['segment_free', 'segment_free'] - : ['free', 'free_completed']; - for (const e of snapshot.device_traces[device]) { - switch (e.action) { - case alloc: - elements.push(e); - addr_to_alloc[e.addr] = elements.length - 1; - actions.push(elements.length - 1); - break; - case free: - case free_completed: - if (e.addr in addr_to_alloc) { - actions.push(addr_to_alloc[e.addr]); - delete addr_to_alloc[e.addr]; - } else { - elements.push(e); - initially_allocated.push(elements.length - 1); - actions.push(elements.length - 1); - } - break; - default: - break; - } - } - for (const seg of snapshot.segments) { - if (seg.device !== device) { - continue; - } - if (plot_segments) { - if (!(seg.address in addr_to_alloc)) { - const element = { - action: 'alloc', - addr: seg.address, - size: seg.total_size, - frames: [], - stream: seg.stream, - version: seg.version, - }; - elements.push(element); - initially_allocated.push(elements.length - 1); - } - } else { - for (const b of seg.blocks) { - if (b.state === 'active_allocated' && !(b.addr in addr_to_alloc)) { - const element = { - action: 'alloc', - addr: b.addr, - size: b.requested_size, - frames: b.frames, - stream: seg.stream, - version: b.version, - }; - elements.push(element); - initially_allocated.push(elements.length - 1); - } - } - } - } - initially_allocated.reverse(); - // if there are no actions, the graph will be blank, - // but if there are existing allocations we do not want to hide them - // by having just one allocate action it will show a flat graph with all segments - if (actions.length === 0 && initially_allocated.length > 0) { - actions.push(initially_allocated.pop()); - } - - const current = []; - const current_data = []; - const data = []; - let max_size = 0; - - let total_mem = 0; - let total_summarized_mem = 0; - let timestep = 0; - - const max_at_time = []; - - const summarized_mem = { - elem: 'summarized', - timesteps: [], - offsets: [total_mem], - size: [], - color: 0, - }; - const summarized_elems = {}; - - function advance(n) { - summarized_mem.timesteps.push(timestep); - summarized_mem.offsets.push(total_mem); - summarized_mem.size.push(total_summarized_mem); - timestep += n; - for (let i = 0; i < n; i++) { - max_at_time.push(total_mem + total_summarized_mem); - } - } - - const sizes = elements - .map((x, i) => [x.size, i]) - .sort(([x, _xi], [y, _yi]) => y - x); - - const draw_elem = {}; - for (const [_s, e] of sizes.slice(0, max_entries)) { - draw_elem[e] = true; - } - - function add_allocation(elem) { - const element_obj = elements[elem]; - const size = element_obj.size; - current.push(elem); - let color = elem; - if (snapshot.categories.length > 0) { - color = snapshot.categories.indexOf(element_obj.category || 'unknown'); - } - const e = { - elem, - timesteps: [timestep], - offsets: [total_mem], - size, - color, - }; - current_data.push(e); - data.push(e); - total_mem += size; - element_obj.max_allocated_mem = total_mem + total_summarized_mem; - } - - for (const elem of initially_allocated) { - if (elem in draw_elem) { - add_allocation(elem); - } else { - total_summarized_mem += elements[elem].size; - summarized_elems[elem] = true; - } - } - - for (const elem of actions) { - const size = elements[elem].size; - if (!(elem in draw_elem)) { - if (elem in summarized_elems) { - advance(1); - total_summarized_mem -= size; - summarized_elems[elem] = null; - } else { - total_summarized_mem += size; - summarized_elems[elem] = true; - advance(1); - } - continue; - } - const idx = current.findLastIndex(x => x === elem); - // first time we see an action we add it - // second time we remove it - if (idx === -1) { - add_allocation(elem); - advance(1); - } else { - advance(1); - const removed = current_data[idx]; - removed.timesteps.push(timestep); - removed.offsets.push(removed.offsets.at(-1)); - current.splice(idx, 1); - current_data.splice(idx, 1); - - if (idx < current.length) { - for (let j = idx; j < current.length; j++) { - const e = current_data[j]; - e.timesteps.push(timestep); - e.offsets.push(e.offsets.at(-1)); - e.timesteps.push(timestep + 3); - e.offsets.push(e.offsets.at(-1) - size); - } - advance(3); - } - total_mem -= size; - } - max_size = Math.max(total_mem + total_summarized_mem, max_size); - } - - for (const elem of current_data) { - elem.timesteps.push(timestep); - elem.offsets.push(elem.offsets.at(-1)); - } - data.push(summarized_mem); - - return { - max_size, - allocations_over_time: data, - max_at_time, - summarized_mem, - elements_length: elements.length, - context_for_id: id => { - const elem = elements[id]; - let text = `Addr: ${formatAddr(elem)}`; - text = `${text}, Size: ${formatSize(elem.size)} allocation`; - text = `${text}, Total memory used after allocation: ${formatSize( - elem.max_allocated_mem, - )}`; - const context = elem?.compile_context ?? 'None'; - text = `${text}, Compile context: ${context}`; - if (elem.stream !== null) { - text = `${text}, stream ${elem.stream}`; - } - if (elem.timestamp !== null) { - var d = new Date(elem.time_us / 1000); - text = `${text}, timestamp ${d}`; - } - if (!elem.action.includes('alloc')) { - text = `${text}\nalloc not recorded, stack trace for free:`; - } - const user_metadata_str = format_user_metadata(elem.user_metadata); - if (user_metadata_str) { - text = `${text}\n${user_metadata_str}`; - } - text = `${text}\n${format_frames(elem.frames)}`; - text = `${text}${format_forward_frames(elem.forward_frames)}`; - return text; - }, - }; -} - function MemoryPlot( svg, data, @@ -1213,7 +899,11 @@ function MemoryPlot( .enter() .append('polygon') .attr('points', format_points) - .attr('fill', d => colors[d.color % colors.length]); + .attr('fill', d => colors[d.color % colors.length]) + .attr('opacity', d => d.opacity ?? 1) + .attr('stroke', d => typeof d.elem === 'string' && d.elem.startsWith('pool:') ? 'black' : null) + .attr('stroke-width', d => typeof d.elem === 'string' && d.elem.startsWith('pool:') ? 3 : null) + .attr('vector-effect', d => typeof d.elem === 'string' && d.elem.startsWith('pool:') ? 'non-scaling-stroke' : null); const axis = plot_coordinate_space.append('g').call(yaxis); @@ -1272,11 +962,40 @@ function MemoryPlot( function ContextViewer(text, data) { let current_selected = null; + function restore_search_highlight(d) { + if (!d) return; + const addr = d.attr('data-search-match') === 'true'; + const frame = d.attr('data-frame-match') === 'true'; + if (addr && frame) { + d.attr('stroke', '#ff00ff') + .attr('stroke-width', 3) + .attr('stroke-dasharray', '6,3') + .attr('vector-effect', 'non-scaling-stroke'); + } else if (addr) { + d.attr('stroke', 'red') + .attr('stroke-width', 2) + .attr('stroke-dasharray', null) + .attr('vector-effect', 'non-scaling-stroke'); + } else if (frame) { + d.attr('stroke', '#2196F3') + .attr('stroke-width', 2) + .attr('stroke-dasharray', null) + .attr('vector-effect', 'non-scaling-stroke'); + } + } + return { default_selected: null, set_selected: d => { if (current_selected !== null) { - current_selected.attr('stroke', null).attr('stroke-width', null); + const prev = current_selected.datum(); + const is_pool = prev && typeof prev.elem === 'string' && prev.elem.startsWith('pool:'); + current_selected + .attr('stroke', is_pool ? 'black' : null) + .attr('stroke-width', is_pool ? 3 : null) + .attr('stroke-dasharray', null) + .attr('vector-effect', is_pool ? 'non-scaling-stroke' : null); + restore_search_highlight(current_selected); } if (d === null) { text.text(''); @@ -1287,11 +1006,16 @@ function ContextViewer(text, data) { 'Small tensors that were not plotted to cutdown on render time.\n' + 'Use detail slider to see smaller allocations.', ); + } else if (typeof dd.elem === 'string' && dd.elem.startsWith('pool:')) { + const pool_key = dd.elem.slice(5); + const capacity = Array.isArray(dd.size) ? dd.size.at(-1) : dd.size; + text.text(`Private Pool (${pool_key}): capacity ${formatSize(capacity)}`); } else { text.text(`${dd.elem} ${data.context_for_id(dd.elem)}`); } + const is_pool_sel = typeof dd.elem === 'string' && dd.elem.startsWith('pool:'); d.attr('stroke', 'black') - .attr('stroke-width', 1) + .attr('stroke-width', is_pool_sel ? 5 : 1) .attr('vector-effect', 'non-scaling-stroke'); } current_selected = d; @@ -1387,13 +1111,20 @@ function create_trace_view( device, plot_segments = false, max_entries = 15000, + include_private_inactive = false, ) { const left_pad = 70; - const data = process_alloc_data(snapshot, device, plot_segments, max_entries); + const data = process_alloc_data(snapshot, device, plot_segments, max_entries, include_private_inactive); dst.selectAll('svg').remove(); dst.selectAll('div').remove(); max_entries = Math.min(max_entries, data.elements_length); + if (include_private_inactive) { + dst.append('div') + .attr('style', 'padding: 4px 8px; background: #fff3cd; border: 1px solid #ffc107; font-size: 13px; margin-bottom: 4px;') + .text('Note: Private pool memory (the gray bar) is shown as allocated until the pool\'s segment is freed. ' + + 'This view requires that MemPools are not deleted before torch.cuda.memory._snapshot() is called.'); + } const d = dst.append('div'); d.append('input') .attr('type', 'range') @@ -1401,12 +1132,28 @@ function create_trace_view( .attr('max', data.elements_length) .attr('value', max_entries) .on('change', function () { - create_trace_view(dst, snapshot, device, plot_segments, this.value); + create_trace_view(dst, snapshot, device, plot_segments, this.value, include_private_inactive); }); d.append('label').text( `Detail: ${max_entries} of ${data.elements_length} entries`, ); + d.append('span').text(' | '); + const search_input = d.append('input') + .attr('type', 'text') + .attr('placeholder', 'Search address (hex)...') + .attr('style', 'width: 180px; margin-left: 4px; font-family: monospace;'); + const search_label = d.append('label') + .attr('style', 'margin-left: 4px;'); + + d.append('span').text(' | '); + const frame_input = d.append('input') + .attr('type', 'text') + .attr('placeholder', 'Search stack frame...') + .attr('style', 'width: 200px; margin-left: 4px; font-family: monospace;'); + const frame_label = d.append('label') + .attr('style', 'margin-left: 4px;'); + const grid_container = dst .append('div') .attr( @@ -1443,6 +1190,61 @@ function create_trace_view( ); const delegate = ContextViewer(context_div.append('pre').text('none'), data); plot.set_delegate(delegate); + + function apply_search_highlights() { + const addr_query = search_input.node().value.toLowerCase().trim(); + const frame_query = frame_input.node().value.toLowerCase().trim(); + const polygons = plot_svg.selectAll('polygon'); + let addr_matches = 0; + let frame_matches = 0; + polygons.each(function () { + const dd = d3.select(this).datum(); + if (!dd || typeof dd.elem !== 'number') { + d3.select(this) + .attr('data-search-match', null) + .attr('data-frame-match', null); + return; + } + const ctx = data.context_for_id(dd.elem); + const ctx_lower = ctx.toLowerCase(); + const addr_hit = addr_query && ctx_lower.includes(addr_query); + const frame_hit = frame_query && ctx_lower.includes(frame_query); + d3.select(this) + .attr('data-search-match', addr_hit ? 'true' : null) + .attr('data-frame-match', frame_hit ? 'true' : null); + if (addr_hit && frame_hit) { + d3.select(this) + .attr('stroke', '#ff00ff') + .attr('stroke-width', 3) + .attr('stroke-dasharray', '6,3') + .attr('vector-effect', 'non-scaling-stroke'); + } else if (addr_hit) { + d3.select(this) + .attr('stroke', 'red') + .attr('stroke-width', 2) + .attr('stroke-dasharray', null) + .attr('vector-effect', 'non-scaling-stroke'); + } else if (frame_hit) { + d3.select(this) + .attr('stroke', '#2196F3') + .attr('stroke-width', 2) + .attr('stroke-dasharray', null) + .attr('vector-effect', 'non-scaling-stroke'); + } else { + d3.select(this) + .attr('stroke', null) + .attr('stroke-width', null) + .attr('stroke-dasharray', null); + } + if (addr_hit) addr_matches++; + if (frame_hit) frame_matches++; + }); + search_label.text(addr_query ? `${addr_matches} match${addr_matches !== 1 ? 'es' : ''}` : ''); + frame_label.text(frame_query ? `${frame_matches} match${frame_matches !== 1 ? 'es' : ''}` : ''); + } + + search_input.on('input', apply_search_highlights); + frame_input.on('input', apply_search_highlights); } function create_settings_view(dst, snapshot, device) { @@ -1764,6 +1566,8 @@ function decode_base64(input) { const kinds = { 'Active Memory Timeline': create_trace_view, + 'Allocated Memory (incl. Private Pools)': (dst, snapshot, device) => + create_trace_view(dst, snapshot, device, false, 15000, true), 'Allocator State History': create_segment_view, 'Active Cached Segment Timeline': (dst, snapshot, device) => create_trace_view(dst, snapshot, device, true), diff --git a/torch/utils/viz/process_alloc_data.js b/torch/utils/viz/process_alloc_data.js new file mode 100644 index 0000000000000..8beea47677194 --- /dev/null +++ b/torch/utils/viz/process_alloc_data.js @@ -0,0 +1,1119 @@ +// Pure data-processing functions for PyTorch memory visualization. +// Extracted from MemoryViz.js so they can be tested independently (no d3/DOM deps). +// +// This file is the single source of truth for these functions: +// - MemoryViz.js imports them via ESM: import {...} from "./process_alloc_data.js" +// - Node.js tests load this file by stripping the export line and eval-ing +// +// TRACE EVENT ACTIONS (from c10/core/CachingDeviceAllocator.h TraceEntry::Action): +// +// "alloc" - Sub-allocation returned to user from the caching allocator. +// Recorded in alloc_found_block() (CUDACachingAllocator.cpp:1834). +// +// "free_requested" - User code called free (tensor out of scope). The block may not +// be immediately returned to the free pool if it's in use on +// another stream via record_stream. +// Recorded in free() (CUDACachingAllocator.cpp:2123). +// +// "free_completed" - Block actually returned to the allocator's free pool. For simple +// cases this fires immediately after free_requested. For cross-stream +// blocks, deferred until CUDA events confirm all streams are done. +// Recorded in free_block() (CUDACachingAllocator.cpp:3148). +// +// "segment_alloc" - New segment allocated from OS via cudaMalloc (or cuMemCreate for +// expandable segments). +// Recorded in alloc_from_expandable_segment() (CUDACachingAllocator.cpp:3548). +// +// "segment_free" - Segment returned to OS via cudaFree. Happens during empty_cache() +// or defragmentation. Only for non-expandable segments. +// Recorded in release_block() (CUDACachingAllocator.cpp:3686). +// +// "segment_map" - Physical pages mapped into an expandable segment via cuMemMap. +// The segment grows. Only with expandable segments enabled. +// Recorded in alloc_from_expandable_segment() (CUDACachingAllocator.cpp:3092). +// +// "segment_unmap" - Physical pages unmapped from an expandable segment via cuMemUnmap. +// Virtual address range retained, physical memory returned to OS. +// Only with expandable segments. Causes "pool_id unknown" for any +// trace events whose addresses fall in the unmapped range, since +// the segment no longer exists at snapshot time. +// Recorded in unmap_block() (CUDACachingAllocator.cpp:3790). +// +// "snapshot" - A call to torch.cuda.memory._snapshot(). Timestamp marker to +// correlate trace events with snapshot state. addr=0. +// Recorded in snapshot() (CUDACachingAllocator.cpp:2689). +// +// "oom" - Allocator failed to satisfy an allocation after all retries. +// addr=device_free (bytes free on GPU), size=requested allocation. +// Recorded in malloc() (CUDACachingAllocator.cpp:1629). +// +// HOW SEGMENT EVENTS ARE USED IN VISUALIZATION: +// +// The snapshot pickle contains two separate data sources: +// 1. device_traces - Ring buffer of TraceEntry actions (alloc, free, segment_map, etc.) +// 2. segments - Point-in-time dump of all segments/blocks at _snapshot() time +// +// Block-level views ("Active Memory Timeline", "Allocated Memory (incl. Private Pools)"): +// - process_alloc_data matches "alloc" and "free_completed" from device_traces. +// - segment_alloc/segment_free/segment_map/segment_unmap are skipped in the main +// alloc/free switch, but when include_private_inactive=true, segment events for +// private pools are captured separately (pool_segment_events) and used to drive +// pool envelope sizing based on reserved memory rather than active allocations. +// - The segments snapshot is used to resolve pool_id via find_pool_id() and to +// compute initial reserved memory per pool for envelope sizing. +// +// Segment-level view ("Active Cached Segment Timeline"): +// - process_alloc_data is called with plot_segments=true. +// - Matches "segment_alloc" and "segment_free" instead of alloc/free. +// - segment_map/segment_unmap are NOT matched (they don't appear in the switch). +// - Segments from the snapshot that weren't seen in the trace are added as +// initially_allocated (Phase 2). +// +// Allocator State History ("Allocator State History"): +// - EventSelector lists ALL trace events including segment_map/segment_unmap. +// - MemoryView renders the segment/block layout from the segments snapshot. +// - Clicking an event in the list redraws the layout at that point in time. +// +// Ring buffer overflow: +// - All trace event types share the same ring buffer. When it overflows, older +// events are overwritten. The allocator_settings.trace_alloc_overflowed flag +// indicates this happened, and trace_alloc_max_entries gives the buffer size. +// - Segment snapshot data (segments array) is NOT affected by ring buffer overflow. +// - The segment snapshot is always complete regardless of overflow. + +/** + * Returns true if pool_id represents a private (user-created) memory pool, + * as opposed to the default pool [0, 0]. + * + * @param {number[]|null} pool_id - Two-element array [owner_id, pool_id] from + * the CUDA caching allocator. The default pool is [0, 0]; any other non-null + * value is a private pool (e.g. FSDP's MemPool). + * @returns {boolean} + */ +function isPrivatePoolId(pool_id) { + return pool_id && !(pool_id[0] === 0 && pool_id[1] === 0); +} + +/** + * Formats a byte count as a human-readable string (e.g. "1.5GiB (1610612736 bytes)"). + * + * @param {number} num - Size in bytes. + * @param {boolean} [showBytes=true] - Whether to include the raw byte count in parentheses. + * @returns {string} + */ +function formatSize(num, showBytes = true) { + const orig = num; + // https://stackoverflow.com/questions/1094841/get-human-readable-version-of-file-size + const units = ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']; + for (const unit of units) { + if (Math.abs(num) < 1024.0) { + if (showBytes) { + return `${num.toFixed(1)}${unit}B (${orig} bytes)`; + } + return `${num.toFixed(1)}${unit}B`; + } + num /= 1024.0; + } + return `${num.toFixed(1)}YiB`; +} + +/** + * Formats a trace event's address as a display string like "b'7f4c00000_3". + * Segment-level events get an "s'" prefix, block-level events get "b'". + * + * @param {{action: string, addr: number|BigInt, version: number}} event + * @returns {string} + */ +function formatAddr(event) { + const prefix = event.action.startsWith('segment') ? 's\'' : 'b\''; + return `${prefix}${event.addr.toString(16)}_${event.version}`; +} + +/** + * Collapses consecutive duplicate strings in an array. If a string appears + * N > 2 times in a row, it's replaced with [str, ""]. + * Used to compress repetitive stack frames in the display. + * + * @param {string[]} frames + * @returns {string[]} + */ +function elideRepeats(frames) { + const result = []; + const length = frames.length; + for (let i = 0; i < length; ) { + let j = i + 1; + const f = frames[i]; + while (j < length && f === frames[j]) { + j++; + } + switch (j - i) { + case 1: + result.push(f); + break; + case 2: + result.push(f, f); + break; + default: + result.push(f, ``); + break; + } + i = j; + } + return result; +} + +/** + * Returns false for stack frames that are internal runtime noise + * (e.g. Python interpreter internals, C++ dispatch machinery). + * Used as a filter predicate on frame arrays. + * + * @param {{name: string, filename: string}} frame + * @returns {boolean} + */ +function frameFilter({name, filename}) { + const omitFunctions = [ + 'unwind::unwind', + 'CapturedTraceback::gather', + 'gather_with_cpp', + '_start', + '__libc_start_main', + 'PyEval_', + 'PyObject_', + 'PyFunction_', + ]; + + const omitFilenames = [ + 'core/boxing', + '/Register', + '/Redispatch', + 'pythonrun.c', + 'Modules/main.c', + 'Objects/call.c', + 'Objects/methodobject.c', + 'pycore_ceval.h', + 'ceval.c', + 'cpython/abstract.h', + ]; + + for (const of of omitFunctions) { + if (name.includes(of)) { + return false; + } + } + + for (const of of omitFilenames) { + if (filename.includes(of)) { + return false; + } + } + + return true; +} + +/** + * Formats user-attached metadata (from torch.cuda.memory._record_memory_history) + * as a display string. Returns '' if no metadata is present. + * + * @param {string|Object|null|undefined} user_metadata + * @returns {string} + */ +function format_user_metadata(user_metadata) { + if (!user_metadata) { + return ''; + } + if (typeof user_metadata === 'string') { + return `User Metadata:\n ${user_metadata}`; + } + if (typeof user_metadata === 'object' && Object.keys(user_metadata).length === 0) { + return ''; + } + const metadata_lines = Object.entries(user_metadata) + .map(([key, value]) => ` ${key}: ${value}`); + return 'User Metadata:\n' + metadata_lines.join('\n'); +} + +/** + * Formats the forward-pass stack trace (captured via torch.autograd) as a + * display string showing where a tensor was originally created. + * + * @param {string[]|null|undefined} forward_frames + * @returns {string} + */ +function format_forward_frames(forward_frames) { + if (!forward_frames || forward_frames.length === 0) { + return ''; + } + let frames_str = forward_frames.join(''); + frames_str = frames_str.trimEnd(); + return `\n\n=== Forward Pass Stack Trace (where this tensor was created) ===\n${frames_str}`; +} + +/** + * Formats an array of stack frames into a human-readable string. + * Filters out runtime noise via frameFilter, annotates FX graph debug info + * when available, and collapses consecutive duplicate frames. + * + * @param {{filename: string, line: number, name: string, + * fx_node_op?: string, fx_node_name?: string, + * fx_node_target?: string, fx_original_trace?: string}[]} frames + * @returns {string} + */ +function format_frames(frames) { + if (frames.length === 0) { + return ( + `This block has no frames. Potential causes:\n` + + `1) This block was allocated before _record_memory_history was enabled.\n` + + `2) The context or stacks passed to _record_memory_history does not include this block. Consider changing context to 'state', 'alloc', or 'all', or changing stacks to 'all'.\n` + + `3) This event occurred during backward, which has no python frames, and memory history did not include C++ frames. Use stacks='all' to record both C++ and python frames.\n` + + `4) This block was reconstructed from the allocator's segment snapshot (not from a trace event). The snapshot records which blocks exist at the moment _snapshot() is called, but does not carry stack frames. This typically happens for blocks that were allocated before tracing started and never freed, or for inactive blocks in private memory pools.\n` + + `5) The original alloc event was evicted from the trace ring buffer (older entries are overwritten when the buffer is full). Increase the max_entries argument to _record_memory_history to retain more events.` + ); + } + const frame_strings = frames + .filter(frameFilter) + .map(f => { + let frame_str = `${f.filename}:${f.line}:${f.name}`; + + if (f.fx_node_op || f.fx_node_name || f.fx_node_target) { + const fx_parts = []; + if (f.fx_node_name) fx_parts.push(`node=${f.fx_node_name}`); + if (f.fx_node_op) fx_parts.push(`op=${f.fx_node_op}`); + if (f.fx_node_target) fx_parts.push(`target=${f.fx_node_target}`); + frame_str += `\n >> FX: ${fx_parts.join(', ')}`; + } + + if (f.fx_original_trace) { + frame_str += `\n >> Original Model Code:`; + const original_lines = f.fx_original_trace.trim().split('\n'); + for (const line of original_lines) { + frame_str += `\n ${line}`; + } + } + + return frame_str; + }); + return elideRepeats(frame_strings).join('\n'); +} + +/** + * Transforms a memory snapshot into a stacked-area timeline suitable for + * rendering by MemoryPlot. This is the core data-processing function behind + * the "Active Memory Timeline" and "Allocated Memory (incl. Private Pools)" + * visualization tabs. + * + * HIGH-LEVEL ALGORITHM: + * + * 1. TRACE EVENT MATCHING: Scans device_traces to pair alloc events with their + * corresponding free_completed events (by address). Events whose matching + * alloc was lost (e.g. ring buffer wrap) become "initially_allocated" — + * blocks assumed to exist at the start of the trace. + * + * 2. SEGMENT SNAPSHOT: Supplements trace data with the current segment state. + * Blocks marked active_allocated (or inactive in private pools, when + * include_private_inactive=true) that weren't seen in the trace are also + * added as initially_allocated. + * + * 3. DETAIL LIMITING: Only the largest max_entries elements get individual + * rectangles in the plot. Smaller elements are aggregated into a single + * "summarized" band to keep rendering fast. + * + * 4. STACKED AREA CONSTRUCTION: Replays alloc/free events in order, building + * a stacked-area dataset where each element has timesteps, y-offsets, and + * a size. Elements are stacked bottom-to-top; frees remove from the stack + * and shift elements above downward. + * + * 5. PRIVATE POOL ENVELOPES (include_private_inactive=true): Each private pool + * (e.g. FSDP's MemPool) gets a single gray "envelope" rectangle whose + * height is the pool's reserved memory (from segment_map/segment_unmap + * events and the segment snapshot). Active blocks within the pool are + * rendered as colored stripes inside the envelope. The envelope only grows + * (never shrinks), representing the pool's actual GPU memory footprint. + * This correctly handles fragmentation: when a large alloc triggers a + * segment_map because existing free blocks aren't contiguous, the envelope + * grows by the reserved amount, not just the active allocation. + * + * Initially-allocated private pool blocks are PRE-LOADED into pool state + * so that when their free event appears in the trace, they are correctly + * recognized as frees (not misinterpreted as new allocations). + * + * NOTE ON FREE EVENT MATCHING: The C++ allocator emits 'free_requested' and + * 'free_completed' for each deallocation. This function matches against 'free' + * (which no longer appears in modern traces — effectively dead code) and + * 'free_completed'. Only free_completed does the actual matching. Matching + * both 'free_requested' AND 'free_completed' would cause double-processing + * since they share the same address. + * + * @param {Object} snapshot - Memory snapshot from torch.cuda.memory._snapshot(). + * @param {Object[]} snapshot.segments - Current allocator segment state. + * @param {Object[][]} snapshot.device_traces - Per-device arrays of trace events. + * Each event has {action, addr, size, frames, stream, segment_pool_id?, ...}. + * @param {string[]} snapshot.categories - Category names for color-coding. + * @param {number} device - Device index into snapshot.device_traces. + * @param {boolean} plot_segments - If true, plot segment-level (cudaMalloc) + * events instead of sub-allocation events. + * @param {number} max_entries - Maximum number of elements to render individually. + * Elements beyond this limit are aggregated into the "summarized" band. + * @param {boolean} [include_private_inactive=false] - If true, include inactive + * blocks from private pools and render pool envelopes. Used by the + * "Allocated Memory (incl. Private Pools)" tab. + * + * @returns {{ + * max_size: number, + * allocations_over_time: Object[], + * max_at_time: number[], + * summarized_mem: Object, + * elements_length: number, + * context_for_id: function(number): string + * }} + * - max_size: peak total memory observed during the action replay (used for + * y-axis scaling). Note: this is only updated inside the action loop, so + * the initial state from initially_allocated may not be reflected here + * (use max_at_time for the true peak). + * - allocations_over_time: array of stacked-area data objects, each with + * {elem, timesteps[], offsets[], size, color}. + * - max_at_time: total memory at each timestep (for minimap rendering). + * - summarized_mem: the aggregated band for small elements. + * - elements_length: total number of unique allocation elements. + * - context_for_id: function that returns a human-readable description + * string for a given element index (address, size, stack trace, etc.). + */ +function process_alloc_data(snapshot, device, plot_segments, max_entries, include_private_inactive = false) { + const elements = []; + // Contains two types of blocks + // 1. free without alloc in trace + // 2. actively allocated in segments, but no matching alloc in trace + const initially_allocated = []; + const actions = []; + const addr_to_alloc = {}; + + const device_segments = snapshot.segments + .filter(s => s.device === device) + .sort((a, b) => { + if (a.address === b.address) return 0; + return a.address < b.address ? -1 : 1; + }); + + // Binary search to find which segment contains a given address. + function find_pool_id(addr) { + let left = 0; + let right = device_segments.length - 1; + while (left <= right) { + const mid = Math.floor((left + right) / 2); + const seg = device_segments[mid]; + const seg_end = seg.address + (typeof seg.address === "bigint" ? BigInt(seg.total_size) : seg.total_size); + if (addr < seg.address) { + right = mid - 1; + } else if (addr >= seg_end) { + left = mid + 1; + } else { + return seg.segment_pool_id; + } + } + return null; + } + + const alloc = plot_segments ? 'segment_alloc' : 'alloc'; + const [free, free_completed] = plot_segments + ? ['segment_free', 'segment_free'] + : ['free', 'free_completed']; + // pool_segment_events tracks segment_map/segment_unmap events for private + // pools, recording their position relative to the actions list. This lets + // the Phase 3 replay grow pool envelopes based on actual reserved memory + // (segments) rather than just active allocations. + const pool_segment_events = []; + for (const e of snapshot.device_traces[device]) { + switch (e.action) { + case alloc: + elements.push(e); + addr_to_alloc[e.addr] = elements.length - 1; + actions.push(elements.length - 1); + break; + case free: + case free_completed: + if (e.addr in addr_to_alloc) { + // Matched: reuse the element from the alloc event + actions.push(addr_to_alloc[e.addr]); + delete addr_to_alloc[e.addr]; + } else { + // Unmatched free: alloc happened before recording (or was evicted + // from the ring buffer). Create a new element from the free event; + // its stack trace will show the free site, not the alloc site. + elements.push(e); + initially_allocated.push(elements.length - 1); + actions.push(elements.length - 1); + } + break; + default: + break; + } + if (include_private_inactive && + (e.action === 'segment_alloc' || e.action === 'segment_free' || + e.action === 'segment_map' || e.action === 'segment_unmap')) { + const pid = find_pool_id(e.addr); + if (isPrivatePoolId(pid)) { + const is_add = e.action === 'segment_alloc' || e.action === 'segment_map'; + pool_segment_events.push({ + position: actions.length, + delta: is_add ? e.size : -e.size, + pool_key: format_pool_key(pid, e.stream ?? 0), + }); + } + } + } + + // --- Phase 2: Add elements from the snapshot --- + for (const seg of snapshot.segments) { + if (seg.device !== device) { + continue; + } + if (plot_segments) { + if (!(seg.address in addr_to_alloc)) { + const element = { + action: 'alloc', + addr: seg.address, + size: seg.total_size, + frames: [], + stream: seg.stream, + version: seg.version, + }; + elements.push(element); + initially_allocated.push(elements.length - 1); + } + } else { + for (const b of seg.blocks) { + const addr = b.addr ?? b.address; + if (b.state === 'active_allocated' && !(addr in addr_to_alloc)) { + const element = { + action: 'alloc', + addr, + size: b.requested_size, + frames: b.frames, + stream: seg.stream, + version: b.version, + segment_pool_id: seg.segment_pool_id, + ghost: true, + }; + elements.push(element); + initially_allocated.push(elements.length - 1); + } + } + } + } + + // Resolve pool IDs for trace elements by looking up which segment they fall in + for (const elem of elements) { + if (!elem.segment_pool_id) { + elem.segment_pool_id = find_pool_id(elem.addr); + } + } + + initially_allocated.reverse(); + // If there are no trace actions but there are existing allocations, + // show a flat graph with the initial state + if (actions.length === 0 && initially_allocated.length > 0) { + actions.push(initially_allocated.pop()); + } + + // --- Phase 3: Build the stacked-area timeline --- + const current = []; // stack of element indices (bottom to top) + const current_data = []; // parallel array of visualization data objects + const data = []; // all data objects (including completed ones) + let max_size = 0; + + let total_mem = 0; + let total_summarized_mem = 0; + let timestep = 0; + + const max_at_time = []; + + const summarized_mem = { + elem: 'summarized', + timesteps: [], + offsets: [total_mem], + size: [], + color: 0, + }; + const summarized_elems = {}; + + // Record the current memory state and advance time by n steps + function advance(n) { + summarized_mem.timesteps.push(timestep); + summarized_mem.offsets.push(total_mem); + summarized_mem.size.push(total_summarized_mem); + timestep += n; + for (let i = 0; i < n; i++) { + max_at_time.push(total_mem + total_summarized_mem); + } + } + + // Only render the largest max_entries elements individually (across all + // pools). Pools with larger allocations naturally get more of the budget. + // Remaining pool elements go into per-pool summarized stripes; remaining + // non-pool elements go into the global summarized band. + const sizes = elements + .map((x, i) => [x.size, i]) + .sort(([x, _xi], [y, _yi]) => y - x); + + const draw_elem = {}; + for (const [_s, e] of sizes.slice(0, max_entries)) { + draw_elem[e] = true; + } + + // Push an element onto the memory stack + function add_allocation(elem) { + const element_obj = elements[elem]; + const size = element_obj.size; + current.push(elem); + let color = elem; + if (snapshot.categories.length > 0) { + color = snapshot.categories.indexOf(element_obj.category || 'unknown'); + } + const e = { + elem, + timesteps: [timestep], + offsets: [total_mem], + size, + color, + }; + if (element_obj.ghost) e.ghost = true; + current_data.push(e); + data.push(e); + total_mem += size; + element_obj.max_allocated_mem = total_mem + total_summarized_mem; + } + + // --- Pool envelope tracking (only when include_private_inactive=true) --- + // Each private pool gets a gray envelope whose height = high-water mark. + // Active blocks are rendered as colored stripes within the envelope. + const pools = {}; + const pool_active_elems = {}; + + function format_pool_key(pid, stream) { + return `${pid[0]},${pid[1]},s${stream}`; + } + + function get_pool_key(elem_idx) { + const pid = elements[elem_idx].segment_pool_id; + if (!isPrivatePoolId(pid)) return null; + return format_pool_key(pid, elements[elem_idx].stream); + } + + function get_or_create_pool(pool_key) { + if (!(pool_key in pools)) { + pools[pool_key] = { + max: 0, active: 0, reserved: 0, + drawn_active: 0, summarized_active: 0, + envelope_data: null, summarized_data: null, + block_stack: [], // [{elem, size, inner_offset, stripe_data}] + }; + } + return pools[pool_key]; + } + + function elem_color(elem_idx) { + if (snapshot.categories.length > 0) { + return snapshot.categories.indexOf(elements[elem_idx].category || 'unknown'); + } + return elem_idx; + } + + function shift_pool_stripes(pool, delta) { + for (const block of pool.block_stack) { + const s = block.stripe_data; + s.timesteps.push(timestep); + s.offsets.push(s.offsets.at(-1)); + s.timesteps.push(timestep + 3); + s.offsets.push(s.offsets.at(-1) + delta); + } + if (pool.summarized_data) { + const sd = pool.summarized_data; + sd.timesteps.push(timestep); + sd.offsets.push(sd.offsets.at(-1)); + sd.size.push(sd.size.at(-1)); + sd.timesteps.push(timestep + 3); + sd.offsets.push(sd.offsets.at(-1) + delta); + sd.size.push(sd.size.at(-1)); + } + } + + // Update or create the per-pool summarized stripe. Sits on top of drawn + // stripes (offset = envelope base + drawn_active), size = summarized_active. + function update_pool_summary(pool, ts) { + if (!pool.envelope_data) return; + const base = pool.envelope_data.offsets.at(-1) + pool.drawn_active; + if (pool.summarized_data === null) { + pool.summarized_data = { + elem: 'summarized', + timesteps: [ts], + offsets: [base], + size: [pool.summarized_active], + color: 0, + opacity: 0.3, + }; + data.push(pool.summarized_data); + } else { + const sd = pool.summarized_data; + sd.timesteps.push(ts); + sd.offsets.push(base); + sd.size.push(pool.summarized_active); + } + } + + // Animate shifting all elements above idx by delta (used when an element + // is inserted or removed from the middle of the stack) + function shift_elements_above(idx, delta) { + for (let j = idx; j < current.length; j++) { + const e = current_data[j]; + e.timesteps.push(timestep); + e.offsets.push(e.offsets.at(-1)); + e.timesteps.push(timestep + 3); + e.offsets.push(e.offsets.at(-1) + delta); + if (Array.isArray(e.size)) { + e.size.push(e.size.at(-1)); + e.size.push(e.size.at(-1)); + } + const pk = typeof current[j] === 'string' && current[j].startsWith('pool:') + ? current[j].slice(5) : null; + if (pk && pk in pools) { + shift_pool_stripes(pools[pk], delta); + } + } + } + + // Shift all elements stacked above a pool envelope by delta (no animation). + // Used during timestep-0 initialization when there are no transition frames. + function shift_above_pool_no_anim(pool_key, delta) { + const pidx = current.indexOf(`pool:${pool_key}`); + if (pidx >= 0) { + for (let j = pidx + 1; j < current.length; j++) { + const e = current_data[j]; + e.offsets[e.offsets.length - 1] += delta; + } + } + } + + // Grow a pool envelope to accommodate new_size bytes (the larger of active + // allocations and reserved segment memory). The envelope only grows (never + // shrinks) — it represents the pool's actual GPU memory footprint. + function grow_pool_envelope(pool, pool_key, new_size) { + if (new_size <= pool.max) return; + const delta = new_size - pool.max; + pool.max = new_size; + const env = pool.envelope_data; + env.timesteps.push(timestep); + env.offsets.push(env.offsets.at(-1)); + env.size.push(env.size.at(-1)); + env.timesteps.push(timestep + 3); + env.offsets.push(env.offsets.at(-1)); + env.size.push(pool.max); + const pidx = current.indexOf(`pool:${pool_key}`); + if (pidx >= 0) { + shift_elements_above(pidx + 1, delta); + } + total_mem += delta; + advance(3); + } + + // --- Process initially_allocated elements --- + // These are blocks that existed before the trace window started. They come + // from two sources: + // 1. Unmatched free events (free_completed without a prior alloc in trace) + // 2. active_allocated blocks in the segment snapshot with no trace event + // + // For private pool blocks: pre-load into pool state at timestep 0 (no + // animation). This serves two purposes: + // - The envelope starts at the correct initial size + // - When the free event fires during replay, it's recognized as a free + // (not misinterpreted as a new allocation) + // + // For non-pool blocks: added to the global stack (draw_elem) or global + // summarized band. + for (const elem of initially_allocated) { + if (include_private_inactive && get_pool_key(elem)) { + const pk = get_pool_key(elem); + const size = elements[elem].size; + const pool = get_or_create_pool(pk); + // Mark as active so the replay loop recognizes the free event + pool_active_elems[elem] = pk; + + // Create pool envelope on first encounter + if (pool.envelope_data === null) { + const env = { + elem: `pool:${pk}`, + timesteps: [0], + offsets: [total_mem], + size: [0], + color: 9, + }; + pool.envelope_data = env; + // Add to the global stack so elements above it shift when it grows + current.push(`pool:${pk}`); + current_data.push(env); + data.push(env); + } + + pool.active += size; + + // Grow envelope to fit: use max(active, reserved) because active can + // exceed reserved when block sizes are stale (e.g. segment shrank via + // unmap after the block was allocated). + const init_target = Math.max(pool.active, pool.reserved); + if (init_target > pool.max) { + const delta = init_target - pool.max; + pool.max = init_target; + const env = pool.envelope_data; + env.size[env.size.length - 1] = pool.max; + total_mem += delta; + // Shift all elements stacked above this pool's envelope up by delta + shift_above_pool_no_anim(pk, delta); + } + + if (elem in draw_elem) { + const inner_offset = pool.drawn_active; + pool.drawn_active += size; + const stripe = { + elem, + timesteps: [0], + offsets: [pool.envelope_data.offsets.at(-1) + inner_offset], + size, + color: elem_color(elem), + opacity: 0.5, + ghost: elements[elem].ghost || false, + }; + pool.block_stack.push({elem, size, inner_offset, stripe_data: stripe}); + data.push(stripe); + } else { + pool.summarized_active += size; + } + continue; + } + // Non-pool element: render individually or add to global summarized band + if (elem in draw_elem) { + add_allocation(elem); + } else { + total_summarized_mem += elements[elem].size; + summarized_elems[elem] = true; + } + } + + // Fix up pool stripe offsets — stripes are not in current_data so they + // don't get shifted when other pools grow during initially_allocated + // processing. Recompute from the envelope's final offset. + // Also create per-pool summarized data for initial non-drawn elements. + for (const pk in pools) { + const p = pools[pk]; + if (!p.envelope_data) continue; + const env_offset = p.envelope_data.offsets.at(-1); + for (const block of p.block_stack) { + const s = block.stripe_data; + for (let i = 0; i < s.offsets.length; i++) { + s.offsets[i] = env_offset + block.inner_offset; + } + } + if (p.summarized_active > 0) { + p.summarized_data = { + elem: 'summarized', + timesteps: [0], + offsets: [env_offset + p.drawn_active], + size: [p.summarized_active], + color: 0, + opacity: 0.3, + }; + data.push(p.summarized_data); + } + } + + // --- Initialize pool reserved memory from snapshot --- + // The envelope height for each private pool should reflect its reserved + // (segment) memory, not just active allocations. We compute the initial + // reserved value so that replaying segment_map/segment_unmap events from + // the trace arrives at the correct final value (the snapshot total). + // + // Formula: initial = snapshot_total - net_trace_delta + // - snapshot_total: sum of segment total_size for this pool (ground truth + // at snapshot time) + // - net_trace_delta: sum of segment_map sizes minus segment_unmap sizes + // for this pool in the trace + // + // This works regardless of trace truncation (ring buffer overflow): the + // initial value represents the reserved memory at the start of the trace + // window, not at program start. If there are no segment events in the + // trace for a pool, net_trace_delta is 0 and initial = snapshot_total. + if (include_private_inactive) { + const snapshot_reserved = {}; + for (const seg of device_segments) { + const pid = seg.segment_pool_id; + if (isPrivatePoolId(pid)) { + const pk = format_pool_key(pid, seg.stream ?? 0); + snapshot_reserved[pk] = (snapshot_reserved[pk] || 0) + seg.total_size; + } + } + const net_from_trace = {}; + for (const se of pool_segment_events) { + net_from_trace[se.pool_key] = (net_from_trace[se.pool_key] || 0) + se.delta; + } + for (const pk in snapshot_reserved) { + const pool = get_or_create_pool(pk); + pool.reserved = snapshot_reserved[pk] - (net_from_trace[pk] || 0); + // Grow envelope to initial reserved (no animation — pre-existing) + if (pool.reserved > pool.max && pool.envelope_data) { + const delta = pool.reserved - pool.max; + pool.max = pool.reserved; + const env = pool.envelope_data; + env.size[env.size.length - 1] = pool.max; + total_mem += delta; + shift_above_pool_no_anim(pk, delta); + } + } + // Fix up pool stripe offsets again after reserved-based envelope growth + for (const pk in pools) { + const p = pools[pk]; + if (!p.envelope_data) continue; + const env_offset = p.envelope_data.offsets.at(-1); + for (const block of p.block_stack) { + const s = block.stripe_data; + for (let i = 0; i < s.offsets.length; i++) { + s.offsets[i] = env_offset + block.inner_offset; + } + } + } + } + + // --- Replay alloc/free actions to build the timeline --- + let seg_event_idx = 0; + for (let action_i = 0; action_i < actions.length; action_i++) { + // Process segment events that occurred at or before this action position. + // These grow pool envelopes based on actual reserved memory changes. + while (seg_event_idx < pool_segment_events.length && + pool_segment_events[seg_event_idx].position <= action_i) { + const se = pool_segment_events[seg_event_idx]; + const pool = get_or_create_pool(se.pool_key); + pool.reserved += se.delta; + if (pool.reserved > pool.max && pool.envelope_data) { + grow_pool_envelope(pool, se.pool_key, pool.reserved); + } + seg_event_idx++; + } + + const elem = actions[action_i]; + const size = elements[elem].size; + const pool_key = include_private_inactive ? get_pool_key(elem) : null; + + if (pool_key) { + // --- Private pool element --- + if (!(elem in pool_active_elems)) { + // Pool alloc: add to pool, grow envelope if needed + pool_active_elems[elem] = pool_key; + const pool = get_or_create_pool(pool_key); + + if (pool.envelope_data === null) { + const env = { + elem: `pool:${pool_key}`, + timesteps: [timestep], + offsets: [total_mem], + size: [0], + color: 9, + }; + pool.envelope_data = env; + current.push(`pool:${pool_key}`); + current_data.push(env); + data.push(env); + } + + pool.active += size; + + const envelope_target = Math.max(pool.active, pool.reserved); + if (envelope_target > pool.max) { + grow_pool_envelope(pool, pool_key, envelope_target); + } + + if (elem in draw_elem) { + const inner_offset = pool.drawn_active; + pool.drawn_active += size; + const stripe = { + elem, + timesteps: [timestep], + offsets: [pool.envelope_data.offsets.at(-1) + inner_offset], + size, + color: elem_color(elem), + opacity: 0.5, + }; + pool.block_stack.push({elem, size, inner_offset, stripe_data: stripe}); + data.push(stripe); + // Shift summarized stripe up (it sits on top of drawn stripes) + if (pool.summarized_data) { + update_pool_summary(pool, timestep); + } + } else { + pool.summarized_active += size; + update_pool_summary(pool, timestep); + } + advance(1); + elements[elem].max_allocated_mem = total_mem + total_summarized_mem; + } else { + // Pool free: end stripe, shift stripes above down within the pool. + // The envelope stays at its high-water mark (never shrinks). + const pool = pools[pool_key]; + const block_idx = pool.block_stack.findIndex(b => b.elem === elem); + if (block_idx >= 0) { + // Drawn stripe freed + advance(1); + const block = pool.block_stack[block_idx]; + block.stripe_data.timesteps.push(timestep); + block.stripe_data.offsets.push(block.stripe_data.offsets.at(-1)); + + pool.block_stack.splice(block_idx, 1); + pool.active -= size; + pool.drawn_active -= size; + + // Shift drawn stripes above and the summarized stripe down + const need_shift = block_idx < pool.block_stack.length || pool.summarized_data; + if (need_shift) { + for (let j = block_idx; j < pool.block_stack.length; j++) { + const b = pool.block_stack[j]; + b.inner_offset -= size; + const s = b.stripe_data; + s.timesteps.push(timestep); + s.offsets.push(s.offsets.at(-1)); + s.timesteps.push(timestep + 3); + s.offsets.push(pool.envelope_data.offsets.at(-1) + b.inner_offset); + } + if (pool.summarized_data) { + update_pool_summary(pool, timestep); + } + advance(3); + } + } else { + // Non-drawn element freed — summarized stripe shrinks on top + pool.active -= size; + pool.summarized_active -= size; + update_pool_summary(pool, timestep); + advance(1); + } + delete pool_active_elems[elem]; + } + max_size = Math.max(total_mem + total_summarized_mem, max_size); + continue; + } + + // --- Non-pool element --- + if (!(elem in draw_elem)) { + // Too small to render individually — goes into the summarized band + if (elem in summarized_elems) { + advance(1); + total_summarized_mem -= size; + summarized_elems[elem] = null; + } else { + total_summarized_mem += size; + summarized_elems[elem] = true; + advance(1); + } + continue; + } + const idx = current.findLastIndex(x => x === elem); + if (idx === -1) { + // First appearance → alloc + add_allocation(elem); + advance(1); + } else { + // Second appearance → free: remove from stack, shift elements above down + advance(1); + const removed = current_data[idx]; + removed.timesteps.push(timestep); + removed.offsets.push(removed.offsets.at(-1)); + current.splice(idx, 1); + current_data.splice(idx, 1); + + if (idx < current.length) { + shift_elements_above(idx, -size); + advance(3); + } + total_mem -= size; + } + max_size = Math.max(total_mem + total_summarized_mem, max_size); + } + + // Process any remaining segment events after the last action + while (seg_event_idx < pool_segment_events.length) { + const se = pool_segment_events[seg_event_idx]; + const pool = get_or_create_pool(se.pool_key); + pool.reserved += se.delta; + if (pool.reserved > pool.max && pool.envelope_data) { + grow_pool_envelope(pool, se.pool_key, pool.reserved); + } + max_size = Math.max(total_mem + total_summarized_mem, max_size); + seg_event_idx++; + } + + // --- Finalize: close all still-active elements --- + for (const elem of current_data) { + elem.timesteps.push(timestep); + elem.offsets.push(elem.offsets.at(-1)); + if (Array.isArray(elem.size)) { + elem.size.push(elem.size.at(-1)); + } + } + for (const pk in pools) { + for (const block of pools[pk].block_stack) { + const s = block.stripe_data; + s.timesteps.push(timestep); + s.offsets.push(s.offsets.at(-1)); + } + if (pools[pk].summarized_data) { + const sd = pools[pk].summarized_data; + sd.timesteps.push(timestep); + sd.offsets.push(sd.offsets.at(-1)); + sd.size.push(sd.size.at(-1)); + } + } + data.push(summarized_mem); + + return { + max_size, + allocations_over_time: data, + max_at_time, + summarized_mem, + elements_length: elements.length, + context_for_id: id => { + const elem = elements[id]; + let text = `Addr: ${formatAddr(elem)}`; + text = `${text}, Size: ${formatSize(elem.size)} allocation`; + text = `${text}, Total memory used after allocation: ${formatSize( + elem.max_allocated_mem, + )}`; + const context = elem?.compile_context ?? 'None'; + text = `${text}, Compile context: ${context}`; + if (elem.stream !== null) { + text = `${text}, stream ${elem.stream}`; + } + if (elem.segment_pool_id) { + text = `${text}, pool_id (${elem.segment_pool_id[0]}, ${elem.segment_pool_id[1]})`; + } else { + text = `${text}, pool_id unknown`; + } + if (elem.timestamp !== null) { + var d = new Date(elem.time_us / 1000); + text = `${text}, timestamp ${d}`; + } + if (!elem.action.includes('alloc')) { + text = `${text}\nalloc not recorded, stack trace for free:`; + } + if (elem.ghost) { + text = `${text}\n[Ghost block] This block exists in the segment snapshot but has no alloc trace events. ` + + `It was allocated before _record_memory_history() was called, or its alloc event was evicted ` + + `from the trace ring buffer. The block is still active (not freed) at snapshot time.`; + } + const user_metadata_str = format_user_metadata(elem.user_metadata); + if (user_metadata_str) { + text = `${text}\n${user_metadata_str}`; + } + text = `${text}\n${format_frames(elem.frames)}`; + text = `${text}${format_forward_frames(elem.forward_frames)}`; + return text; + }, + }; +} + +export { process_alloc_data, isPrivatePoolId, formatSize, formatAddr, + elideRepeats, frameFilter, format_user_metadata, + format_forward_frames, format_frames }; diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index dd810e1a966b7..373a12a2ad78b 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -9,8 +9,10 @@ from __future__ import annotations +import os import threading import traceback +import warnings from functools import lru_cache from typing import Any, NewType, TYPE_CHECKING @@ -44,6 +46,7 @@ _is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False) _lazy_seed_tracker = _LazySeedTracker() default_generators: tuple[torch._C.Generator] = () # type: ignore[assignment] +_cached_device_count: int | None = None def _is_compiled() -> bool: @@ -66,12 +69,174 @@ def _maybe_exchange_device(device: int) -> int: raise NotImplementedError("PyTorch was compiled without XPU support") -@lru_cache(maxsize=1) +def _parse_visible_devices() -> list[int]: + r"""Parse ``ZE_AFFINITY_MASK`` and return visible device ordinals. + + Returns a list of non-negative device ordinals specified by the mask. + When the mask is unset, returns ``[0, 1, ..., 127]`` (the maximum range + for ``int8_t`` device indices). Returns an empty list for unsupported + COMPOSITE-style masks (e.g. ``"0.0,0.1"``). + """ + var = os.getenv("ZE_AFFINITY_MASK") + if var is None: + # DeviceIndex is stored as int8_t, so valid indices are 0–127 + # (up to 128 devices). Return the full range when no mask is set. + return list(range(128)) + + visible_devices: list[int] = [] + for elem in var.split(","): + try: + x = int(elem.strip()) + except ValueError: + # A non-integer token (e.g. "0.0" in COMPOSITE-mode format) + # means the mask is unsupported here; signal that by returning + # an empty list. + return [] + if x >= 0 and x not in visible_devices: + visible_devices.append(x) + return visible_devices + + +def _raw_device_count_zes(visible_mask: list[int]) -> int: + r"""Return the number of visible XPU devices via Level Zero Sysman. + + Enumerates devices from the first Level Zero Sysman driver and counts those + whose logical index appears in *visible_mask*. Only devices listed in + the visible mask participate in counting. + + Discrete GPUs (dGPUs) take priority: if any visible dGPU is found, only + dGPUs are counted; integrated GPUs (iGPUs) are counted only when no + visible dGPU exists. + + For tiled dGPUs (``numSubdevices > 0``), the counting depends on + ``ZE_FLAT_DEVICE_HIERARCHY``: + + - **FLAT / COMBINED** (default): each sub-device is exposed as a + separate top-level device and counted individually. + - **COMPOSITE**: sub-devices are hidden; the whole physical device + counts as one. + + Returns a negative value on initialization or enumeration failure. + """ + from ctypes import byref, c_uint32 + + try: + import pyzes # type: ignore[import] + except ImportError: + return -1 + + def _zes_check(rc: int, msg: str) -> bool: + """Return True if the call failed (rc != 0) after issuing a warning.""" + if rc != 0: + warnings.warn(msg, stacklevel=3) + return rc != 0 + + if _zes_check(pyzes.zesInit(0), "Can't initialize Level Zero Sysman"): + return -1 + + driver_count = c_uint32(0) + if _zes_check( + pyzes.zesDriverGet(byref(driver_count), None), + "Can't get Level Zero Sysman driver count", + ): + return -1 + if driver_count.value == 0: + return 0 + + drivers = (pyzes.zes_driver_handle_t * driver_count.value)() + if _zes_check( + pyzes.zesDriverGet(byref(driver_count), drivers), + "Can't get Level Zero Sysman driver handles", + ): + return -1 + + device_count = c_uint32(0) + if _zes_check( + pyzes.zesDeviceGet(drivers[0], byref(device_count), None), + "Can't get Level Zero Sysman device count", + ): + return -1 + + devices = (pyzes.zes_device_handle_t * device_count.value)() + if _zes_check( + pyzes.zesDeviceGet(drivers[0], byref(device_count), devices), + "Can't get Level Zero Sysman device handles", + ): + return -1 + + # --- Count visible dGPUs and iGPUs --- + ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = 1 << 0 + hierarchy = os.getenv("ZE_FLAT_DEVICE_HIERARCHY") + expose_sub_devices = hierarchy != "COMPOSITE" + + visible = set(visible_mask) + logical_index = 0 + num_igpu = 0 + num_dgpu = 0 + + for device in devices: + props = pyzes.zes_device_properties_t() + props.stype = pyzes.ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES + if _zes_check( + pyzes.zesDeviceGetProperties(device, byref(props)), + "Can't get Level Zero Sysman device properties", + ): + return -1 + + is_integrated = bool(props.core.flags & ZE_DEVICE_PROPERTY_FLAG_INTEGRATED) + + # Determine how many logical indices this physical device occupies. + # Tiled dGPUs in FLAT/COMBINED mode expose each sub-device separately; + # everything else (iGPU, non-tiled dGPU, COMPOSITE mode) counts as one. + num_slots = ( + props.numSubdevices + if not is_integrated and props.numSubdevices > 0 and expose_sub_devices + else 1 + ) + + for _ in range(num_slots): + if logical_index in visible: + if is_integrated: + num_igpu += 1 + else: + num_dgpu += 1 + logical_index += 1 + + # Prefer dGPU count; fall back to iGPU count only when no dGPU is visible. + return num_dgpu or num_igpu + + +def _device_count_zes() -> int: + r"""Return the number of visible XPU devices, or -1 on failure.""" + visible_devices = _parse_visible_devices() + if not visible_devices: + return -1 + return _raw_device_count_zes(visible_devices) + + def device_count() -> int: - r"""Return the number of XPU device available.""" + r""" + Return the number of XPU device available. + + .. note:: This API will NOT poison fork if Level Zero Sysman discovery succeeds. + See :ref:`multiprocessing-poison-fork-note` for more details. + """ if not _is_compiled(): return 0 - return torch._C._xpu_getDeviceCount() + global _cached_device_count + if _cached_device_count is not None: + return _cached_device_count + if _initialized or hasattr(_tls, "is_initializing"): + count = torch._C._xpu_getDeviceCount() + else: + zes_count = _device_count_zes() + count = torch._C._xpu_getDeviceCount() if zes_count < 0 else zes_count + # Do not cache the device count prior to XPU initialization, because + # the number of devices can change due to changes to ZE_AFFINITY_MASK + # setting prior to XPU initialization. + if _initialized: + _cached_device_count = count + return count def is_available() -> bool: @@ -285,6 +450,8 @@ def get_device_properties( - ``gpu_eu_count`` (int): number of EUs (Execution Unit). - ``max_work_group_size``: (int): maximum number of work-items permitted in a work-group. - ``max_num_sub_groups`` (int): maximum number of sub-groups supported in a work-group. + - ``memory_clock_rate`` (int) maximum clock rate of device's global memory in MHz. + - ``memory_bus_width`` (int) maximum bus width between device and memory in bits. - ``sub_group_sizes``: (list[int]): a list of supported sub-group sizes. - ``local_mem_size`` (int): device local memory capacity that can be allocated per work-group in bytes. - ``has_fp16`` (bool): whether float16 dtype is supported. diff --git a/torch/xpu/graphs.py b/torch/xpu/graphs.py index 51780050f5937..8e89ddedbe263 100644 --- a/torch/xpu/graphs.py +++ b/torch/xpu/graphs.py @@ -136,14 +136,14 @@ def raw_xpu_graph(self) -> int: r"""Returns the underlying xpuGraph_t. ``keep_graph`` must be True. XPU doesn't provide APIs to manipulate this object. - """ # noqa: B950 + """ return super().raw_xpu_graph() def raw_xpu_graph_exec(self) -> int: r"""Returns the underlying xpuGraphExec_t. ``instantiate`` must have been called if ``keep_graph`` is True, or ``capture_end`` must have been called if ``keep_graph`` is False. If you call ``instantiate()`` after ``raw_xpu_graph_exec()``, the previously returned xpuGraphExec_t will be destroyed. It is your responsibility not to use this object after destruction. XPU doesn't provide APIs to manipulate this object. - """ # noqa: B950 + """ return super().raw_xpu_graph_exec() @@ -162,7 +162,7 @@ class graph: For effective memory sharing, if you pass a ``pool`` used by a previous capture and the previous capture used an explicit ``stream`` argument, you should pass the same ``stream`` argument to this capture. - """ # noqa: B950 + """ default_capture_stream: torch.xpu.Stream | None = None diff --git a/torchgen/_autoheuristic/benchmark_runner.py b/torchgen/_autoheuristic/benchmark_runner.py index 117058b3373f3..67e24aec82d36 100644 --- a/torchgen/_autoheuristic/benchmark_runner.py +++ b/torchgen/_autoheuristic/benchmark_runner.py @@ -55,12 +55,15 @@ def add_base_arguments(self) -> None: def run(self) -> None: torch.set_default_device("cuda") args = self.parser.parse_args() + # Set environment variables to control autoheuristic behavior + import os + if args.use_heuristic: - torch._inductor.config.autoheuristic_use = self.name - torch._inductor.config.autoheuristic_collect = "" + os.environ["TORCHINDUCTOR_AUTOHEURISTIC_USE"] = self.name + os.environ["TORCHINDUCTOR_AUTOHEURISTIC_COLLECT"] = "" else: - torch._inductor.config.autoheuristic_use = "" - torch._inductor.config.autoheuristic_collect = self.name + os.environ["TORCHINDUCTOR_AUTOHEURISTIC_USE"] = "" + os.environ["TORCHINDUCTOR_AUTOHEURISTIC_COLLECT"] = self.name torch._inductor.config.autoheuristic_log_path = args.o if args.device is not None: torch.cuda.set_device(args.device) diff --git a/torchgen/_autoheuristic/pad_mm/collect_known_mm_shapes.py b/torchgen/_autoheuristic/pad_mm/collect_known_mm_shapes.py new file mode 100644 index 0000000000000..45a5854044b33 --- /dev/null +++ b/torchgen/_autoheuristic/pad_mm/collect_known_mm_shapes.py @@ -0,0 +1,235 @@ +import argparse +import csv +import sys +from pathlib import Path + + +# Add parent directory to path for imports +sys.path.append(str(Path(__file__).absolute().parents[1])) +sys.path.append( + str( + Path(__file__).absolute().parents[3] + / "benchmarks" + / "dynamo" + / "microbenchmarks" + ) +) + +from operator_inp_utils import ( # type: ignore[import-not-found] + deserialize_args, + OperatorInputsLoader, +) + +import torch +from torch._inductor.fx_passes.pad_mm import ( + get_alignment_size_dtype, # type: ignore[import-not-found] +) +from torch._subclasses.fake_tensor import FakeTensorMode + + +def is_aligned(dim: int, align_size: int) -> bool: + """Check if dimension is aligned to the given alignment size.""" + return dim % align_size == 0 + + +def extract_mm_shapes_from_loader( + loader: OperatorInputsLoader, +) -> list[tuple[int, int, int, torch.dtype, torch.dtype]]: + """Extract matrix multiplication shapes from an OperatorInputsLoader using deserialize_args with FakeTensorMode.""" + shapes = [] + + # Matrix multiplication operators to look for + mm_operators = ["aten.mm.default", "aten.addmm.default", "aten.bmm.default"] + + # Use FakeTensorMode to avoid instantiating actual tensors + with FakeTensorMode(): + for op_name in mm_operators: + if op_name not in loader.operator_db: + continue + + # Count shapes extracted from this operator + shape_count = 0 + + # Access the raw string data directly from operator_db and reuse existing parsing + for input_str in loader.operator_db[op_name]: + try: + # Use deserialize_args to parse inputs - will create fake tensors + args, kwargs = deserialize_args(input_str) + + if op_name == "aten.mm.default": + # mm(input, mat2) -> result + if len(args) >= 2: + a, b = args[0], args[1] + if isinstance(a, torch.Tensor) and isinstance( + b, torch.Tensor + ): + a_shape, a_dtype = tuple(a.shape), a.dtype + b_shape, b_dtype = tuple(b.shape), b.dtype + if len(a_shape) == 2 and len(b_shape) == 2: + m, k = a_shape + k2, n = b_shape + if k == k2: # Valid matrix multiplication + shapes.append((m, k, n, a_dtype, b_dtype)) + shape_count += 1 + + elif op_name == "aten.addmm.default": + # addmm(bias, input, mat2) -> result + if len(args) >= 3: + _, a, b = args[0], args[1], args[2] + if isinstance(a, torch.Tensor) and isinstance( + b, torch.Tensor + ): + a_shape, a_dtype = tuple(a.shape), a.dtype + b_shape, b_dtype = tuple(b.shape), b.dtype + if len(a_shape) == 2 and len(b_shape) == 2: + m, k = a_shape + k2, n = b_shape + if k == k2: # Valid matrix multiplication + shapes.append((m, k, n, a_dtype, b_dtype)) + shape_count += 1 + + elif op_name == "aten.bmm.default": + # bmm(input, mat2) -> result (batch matrix multiplication) + if len(args) >= 2: + a, b = args[0], args[1] + if isinstance(a, torch.Tensor) and isinstance( + b, torch.Tensor + ): + a_shape, a_dtype = tuple(a.shape), a.dtype + b_shape, b_dtype = tuple(b.shape), b.dtype + if len(a_shape) == 3 and len(b_shape) == 3: + batch1, m, k = a_shape + batch2, k2, n = b_shape + if ( + batch1 == batch2 and k == k2 + ): # Valid batch matrix multiplication + shapes.append((m, k, n, a_dtype, b_dtype)) + shape_count += 1 + + except Exception: + # Skip invalid inputs + continue + + print(f" Extracted {shape_count} shapes from {op_name}") + + return shapes + + +def filter_unaligned_shapes( + shapes: list[tuple[int, int, int, torch.dtype, torch.dtype]], +) -> list[tuple[int, int, int, torch.dtype, torch.dtype]]: + """Filter shapes to keep only those that are not completely aligned (so padding is relevant).""" + filtered_shapes = [] + + for m, k, n, dtype1, dtype2 in shapes: + # Use the primary dtype for alignment calculation (assume both dtypes are similar for alignment purposes) + dtype = dtype1 + try: + align_size = get_alignment_size_dtype(dtype) + + # Only keep shapes where not all dimensions are aligned + if not all(is_aligned(dim, align_size) for dim in [m, k, n]): + filtered_shapes.append((m, k, n, dtype1, dtype2)) + + except Exception: + # If we can't get alignment size, skip this shape + continue + + return filtered_shapes + + +def collect_known_mm_shapes() -> list[tuple[int, int, int, torch.dtype, torch.dtype]]: + """ + Collect known matrix multiplication shapes from HuggingFace, TIMM, and TorchBench datasets. + + Returns: + List of tuples containing (m, k, n, dtype1, dtype2) for matrix multiplication shapes + that are not completely aligned (so padding is relevant). + """ + all_shapes = [] + + loaders = [] + + # Try to load each dataset + try: + hf_loader = OperatorInputsLoader.get_huggingface_loader() + loaders.append(("HuggingFace", hf_loader)) + except Exception as e: + print(f"Warning: Could not load HuggingFace dataset: {e}") + + try: + timm_loader = OperatorInputsLoader.get_timm_loader() + loaders.append(("TIMM", timm_loader)) + except Exception as e: + print(f"Warning: Could not load TIMM dataset: {e}") + + try: + torchbench_loader = OperatorInputsLoader.get_torchbench_loader() + loaders.append(("TorchBench", torchbench_loader)) + except Exception as e: + print(f"Warning: Could not load TorchBench dataset: {e}") + + # Extract shapes from each loader + for dataset_name, loader in loaders: + print(f"Extracting shapes from {dataset_name}...") + + shapes = extract_mm_shapes_from_loader(loader) + print(f"Found {len(shapes)} matrix multiplication shapes from {dataset_name}") + all_shapes.extend(shapes) + + # Remove duplicates + unique_shapes = list(set(all_shapes)) + print(f"Total unique shapes before filtering: {len(unique_shapes)}") + + # Filter for unaligned shapes only + filtered_shapes = filter_unaligned_shapes(unique_shapes) + print(f"Shapes after filtering for unaligned: {len(filtered_shapes)}") + + return filtered_shapes + + +def main(output_file="mm_shapes.csv"): + shapes = collect_known_mm_shapes() + + print(f"\nCollected {len(shapes)} real-world matrix multiplication shapes") + + # Convert dtype objects to strings and filter for desired dtypes + dtype_map = { + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.float32: "float32", + } + + # Convert to desired format and filter dtypes + csv_rows = [] + for m, k, n, dtype1, dtype2 in shapes: + # Use the first dtype and convert to string + if dtype1 in dtype_map: + dtype_str = dtype_map[dtype1] + csv_rows.append([m, k, n, dtype_str]) + + # Save to CSV file + with open(output_file, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + # Write header + writer.writerow(["M", "K", "N", "dtype"]) + # Write data rows + writer.writerows(csv_rows) + + print(f"Saved matrix multiplication shapes to {output_file}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Collect matrix multiplication shapes from real-world datasets and save to CSV" + ) + parser.add_argument( + "--output", + "-o", + type=str, + default="mm_shapes.csv", + help="Output CSV filename (default: mm_shapes.csv)", + ) + + args = parser.parse_args() + main(args.output) diff --git a/torchgen/_autoheuristic/pad_mm/evaluate_pad_mm_heuristics.py b/torchgen/_autoheuristic/pad_mm/evaluate_pad_mm_heuristics.py new file mode 100644 index 0000000000000..42b37d74c3092 --- /dev/null +++ b/torchgen/_autoheuristic/pad_mm/evaluate_pad_mm_heuristics.py @@ -0,0 +1,426 @@ +#!/usr/bin/env python3 + +import argparse +import csv +import functools + +import torch +from torch._inductor.autoheuristic.autoheuristic_utils import AHContext, AHMetadata +from torch._inductor.fx_passes.pad_mm import get_alignment_size_dtype +from torch._inductor.runtime.benchmarking import benchmarker +from torch._inductor.utils import get_gpu_shared_memory + + +def fits_in_memory(dtype, m: int, k: int, n: int) -> bool: + threshold_memory = torch.cuda.get_device_properties(0).total_memory / 4 + return dtype.itemsize * (m * k + k * n + m * n) < threshold_memory + + +def set_precision(dtype, float32_precision: str = "highest") -> None: + precision = float32_precision if dtype == torch.float32 else "high" + torch.set_float32_matmul_precision(precision) + + +def get_heuristic_decision(m: int, k: int, n: int, dtype: torch.dtype) -> str | None: + from torch._inductor.autoheuristic.autoheuristic import AutoHeuristic, LocalFeedback + from torch._inductor.fx_passes.pad_mm import ( + get_alignment_size, + get_context, + get_padded_length, + pad_mm_operations, + pad_mm_precondition, + ) + + torch._inductor.config.autoheuristic_use.pad_mm = True + + if not torch._inductor.config.run_autoheuristic("pad_mm"): + return None + + a = torch.randn(m, k, dtype=dtype, device="cuda") + b = torch.randn(k, n, dtype=dtype, device="cuda") + + m_padded_length = get_padded_length(m, get_alignment_size(a)) + k_padded_length = get_padded_length(k, get_alignment_size(a)) + n_padded_length = get_padded_length(n, get_alignment_size(b)) + + context = get_context( + a, + b, + mat1_pre_padded=False, + mat2_pre_padded=False, + m_padded_length=m_padded_length, + k_padded_length=k_padded_length, + n_padded_length=n_padded_length, + ) + + def dummy_feedback(choice: str) -> float: + return 1.0 + + def fallback() -> str: + return "no_decision" + + autoheuristic = AutoHeuristic( + fallback=fallback, + choices=["orig", "pad"], + feedback=LocalFeedback(dummy_feedback), + context=context, + name="pad_mm", + augment_context=pad_mm_operations(), + precondition=pad_mm_precondition, + ) + + choice = autoheuristic.get_choice() + return choice + + +def benchmark_both_choices( + m: int, + k: int, + n: int, + dtype: torch.dtype, + num_reps: int = 3, + float32_precision: str = "highest", +) -> tuple[float, float]: + set_precision(dtype, float32_precision) + a = torch.randn(m, k, dtype=dtype, device="cuda") + b = torch.randn(k, n, dtype=dtype, device="cuda") + + # Use existing benchmarking infrastructure with proper cache management + # benchmarker returns time in milliseconds, so convert to seconds for consistency + orig_time_ms = benchmarker.benchmark( + torch.mm, fn_args=(a, b), rep=num_reps, is_vetted_benchmarking=True + ) + orig_time = orig_time_ms / 1000.0 # Convert ms to seconds + + from torch._inductor.fx_passes.pad_mm import ( + get_alignment_size, + get_padded_length, + pad_mm, + ) + + m_padded_length = get_padded_length(a.shape[0], get_alignment_size(a)) + k_padded_length = get_padded_length(a.shape[1], get_alignment_size(a)) + n_padded_length = get_padded_length(b.shape[1], get_alignment_size(b)) + + if m_padded_length == 0 and k_padded_length == 0 and n_padded_length == 0: + return orig_time, orig_time + + pad_time_ms = benchmarker.benchmark( + pad_mm, + fn_args=(a, b, m_padded_length, k_padded_length, n_padded_length), + rep=num_reps, + is_vetted_benchmarking=True, + ) + pad_time = pad_time_ms / 1000.0 # Convert ms to seconds + + return orig_time, pad_time + + +def load_shapes_from_csv(csv_file: str) -> list: + shapes = [] + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + m, k, n = int(row["M"]), int(row["K"]), int(row["N"]) + dtype_str = row["dtype"] + + if dtype_str == "float16": + dtype = torch.float16 + elif dtype_str == "bfloat16": + dtype = torch.bfloat16 + elif dtype_str == "float32": + dtype = torch.float32 + else: + continue + + shapes.append((m, k, n, dtype)) + + print(f"Loaded {len(shapes)} shapes from {csv_file}") + return shapes + + +@functools.cache +def get_shared_mem_size(): + return get_gpu_shared_memory() + + +def check_shape_passes_precondition(m: int, k: int, n: int, dtype: torch.dtype) -> bool: + """ + Check if a shape passes the same precondition used by the actual pad_mm AutoHeuristics. + + This uses the exact same pad_mm_precondition function that the AutoHeuristic system + uses, avoiding hardcoded magic numbers by delegating to the source of truth. + """ + from torch._inductor.autoheuristic.autoheuristic_utils import pad_mm_precondition + + shared_memory = get_shared_mem_size() + device_capa = torch.cuda.get_device_capability() + + # Create the same metadata and context that AutoHeuristics uses + metadata = AHMetadata( + shared_memory=shared_memory, + device_capa=device_capa, + choices=["orig", "pad"], # Required but not used for precondition check + name="pad_mm", # Required but not used for precondition check + ) + + context = AHContext() + context.add_feature("m", m) + context.add_feature("k", k) + context.add_feature("n", n) + + # Use the actual pad_mm_precondition function - no hardcoded values! + return pad_mm_precondition(metadata, context) + + +def filter_shapes(shapes: list) -> list: + filtered = [] + aligned_count = 0 + precondition_failed_count = 0 + memory_count = 0 + + for m, k, n, dtype in shapes: + # Check if already aligned + align_size = get_alignment_size_dtype(dtype) + is_aligned = all((dim % align_size == 0) for dim in [m, k, n]) + + if is_aligned: + aligned_count += 1 + continue + + # Check if passes the actual precondition used by pad_mm AutoHeuristics + if not check_shape_passes_precondition(m, k, n, dtype): + precondition_failed_count += 1 + continue + + # Check if fits in memory + if not fits_in_memory(dtype, m, k, n): + memory_count += 1 + continue + + # This shape is suitable for evaluation + filtered.append((m, k, n, dtype)) + + print("Filtering results:") + print(f" Already aligned (skipped): {aligned_count}") + print(f" Failed pad_mm_precondition (skipped): {precondition_failed_count}") + print(f" Too large for memory (skipped): {memory_count}") + print(f" Suitable for evaluation: {len(filtered)}") + + return filtered + + +def main(): + parser = argparse.ArgumentParser( + description="Evaluate trained AutoHeuristics for pad_mm optimization" + ) + parser.add_argument("csv_file", help="Path to CSV file with M,K,N,dtype columns") + parser.add_argument( + "--num-reps", type=int, default=3, help="Benchmark repetitions (default: 3)" + ) + parser.add_argument( + "--device", type=int, default=None, help="CUDA device (default: current)" + ) + parser.add_argument( + "--max-shapes", + type=int, + default=10000, + help="Max shapes to test (default: 10000)", + ) + parser.add_argument( + "--float32_matmul_precision", + type=str, + choices=["high", "highest"], + default="highest", + help="Matmul precision for float32 (default: highest). Non-fp32 always uses 'high'.", + ) + + args = parser.parse_args() + + torch.set_default_device("cuda") + if args.device is not None: + torch.cuda.set_device(args.device) + + print(f"Using CUDA device: {torch.cuda.current_device()}") + print() + + shapes = load_shapes_from_csv(args.csv_file) + if not shapes: + print("No shapes found!") + return + + shapes = filter_shapes(shapes) + if not shapes: + print("No suitable shapes found!") + return + + if len(shapes) > args.max_shapes: + shapes = shapes[: args.max_shapes] + print(f"Limited to first {args.max_shapes} shapes") + + print(f"Evaluating {len(shapes)} shapes with {args.num_reps} reps each") + print() + + total_decisions = 0 + correct_decisions = 0 + true_positives = 0 # Chose pad, should pad + true_negatives = 0 # Chose orig, should orig + false_positives = 0 # Chose pad, should orig + false_negatives = 0 # Chose orig, should pad + no_decision_shapes = 0 + + tp_speedups = [] # Speed-up percentages for true positives + fp_slowdowns = [] # Speed-down percentages for false positives + + # Track non-confident decisions and confident decisions by dtype + no_decision_shape_list = [] # List of (M, K, N, dtype) where heuristic chose no_decision + confident_by_dtype = {} # Count of confident decisions by dtype + + for i, (m, k, n, dtype) in enumerate(shapes, 1): + print(f"Shape {i}/{len(shapes)}: M={m}, K={k}, N={n}, dtype={dtype}") + + heuristic_choice = get_heuristic_decision(m, k, n, dtype) + print(f" Heuristic: {heuristic_choice}") + + orig_time, pad_time = benchmark_both_choices( + m, k, n, dtype, args.num_reps, args.float32_matmul_precision + ) + ground_truth = "pad" if pad_time < orig_time else "orig" + + print(f" Times: orig={orig_time:.6f}s, pad={pad_time:.6f}s") + print(f" Ground truth: {ground_truth}") + + if heuristic_choice == "no_decision": + # Heuristic punted to benchmarking - this is correct behavior for small/uncertain shapes + no_decision_shapes += 1 + no_decision_shape_list.append((m, k, n, dtype)) + print(" Heuristic chose to benchmark (conservative)") + else: + # Heuristic made a confident decision - evaluate accuracy + total_decisions += 1 + # Track confident decisions by dtype + dtype_str = str(dtype).replace("torch.", "") + confident_by_dtype[dtype_str] = confident_by_dtype.get(dtype_str, 0) + 1 + if heuristic_choice == ground_truth: + correct_decisions += 1 + print(" ✓ CORRECT") + if heuristic_choice == "pad": + true_positives += 1 # Correctly chose pad + # Calculate speed-up: (orig_time - pad_time) / orig_time * 100 + speedup = (orig_time - pad_time) / orig_time * 100 + tp_speedups.append(speedup) + print(f" Speed-up: {speedup:.1f}%") + else: + true_negatives += 1 # Correctly chose orig + else: + print(" ✗ WRONG") + if heuristic_choice == "pad" and ground_truth == "orig": + false_positives += 1 + # Calculate speed-down: (pad_time - orig_time) / orig_time * 100 + slowdown = (pad_time - orig_time) / orig_time * 100 + fp_slowdowns.append(slowdown) + print(f" Speed-down: {slowdown:.1f}%") + elif heuristic_choice == "orig" and ground_truth == "pad": + false_negatives += 1 + + print(f" Confidence Rate: {total_decisions}/{i}") + if total_decisions > 0: + accuracy = correct_decisions / total_decisions * 100 + tp_rate = true_positives / total_decisions * 100 + tn_rate = true_negatives / total_decisions * 100 + fp_rate = false_positives / total_decisions * 100 + fn_rate = false_negatives / total_decisions * 100 + + # Compute average speedup/slowdown + avg_tp_speedup = sum(tp_speedups) / len(tp_speedups) if tp_speedups else 0 + avg_fp_slowdown = ( + sum(fp_slowdowns) / len(fp_slowdowns) if fp_slowdowns else 0 + ) + + print( + f" Accuracy: {correct_decisions}/{total_decisions} ({accuracy:.1f}%) " + f"| TP: {tp_rate:.1f}% (avg speedup: {avg_tp_speedup:.1f}%) " + f"| TN: {tn_rate:.1f}% " + f"| FP: {fp_rate:.1f}% (avg slowdown: {avg_fp_slowdown:.1f}%)" + f"| FN: {fn_rate:.1f}%" + ) + + print() + + print("=== FINAL RESULTS ===") + print(f"Confident decisions: {total_decisions}") + print(f"#Shapes without confident decisions: {no_decision_shapes}") + + if total_decisions > 0: + accuracy = correct_decisions / total_decisions * 100 + tp_rate = true_positives / total_decisions * 100 + tn_rate = true_negatives / total_decisions * 100 + fp_rate = false_positives / total_decisions * 100 + fn_rate = false_negatives / total_decisions * 100 + + avg_tp_speedup = sum(tp_speedups) / len(tp_speedups) if tp_speedups else 0 + avg_fp_slowdown = sum(fp_slowdowns) / len(fp_slowdowns) if fp_slowdowns else 0 + + print( + f"\nConfident decision accuracy: {accuracy:.1f}% ({correct_decisions}/{total_decisions})" + ) + + if tp_speedups: + print( + f"True Positives (chose pad, should pad): {tp_rate:.1f}% ({true_positives}) " + f"| Avg speed-up: {avg_tp_speedup:.1f}%" + ) + else: + print( + f"True Positives (chose pad, should pad): {tp_rate:.1f}% ({true_positives})" + ) + + print( + f"True Negatives (chose orig, should orig): {tn_rate:.1f}% ({true_negatives})" + ) + + if fp_slowdowns: + print( + f"False Positives (chose pad, should orig): {fp_rate:.1f}% ({false_positives}) " + f"| Avg speed-down: {avg_fp_slowdown:.1f}%" + ) + else: + print( + f"False Positives (chose pad, should orig): {fp_rate:.1f}% ({false_positives})" + ) + + print( + f"False Negatives (chose orig, should pad): {fn_rate:.1f}% ({false_negatives})" + ) + else: + print("No confident decisions made!") + + total_evaluated = total_decisions + no_decision_shapes + if total_evaluated > 0: + print( + f"\nConfidence rate: ({total_decisions}/{total_evaluated} made confident decisions)" + ) + + # Print shapes where AutoHeuristics did not make a confident decision + print(f"\n=== NON-CONFIDENT DECISIONS ({len(no_decision_shape_list)}) ===") + if no_decision_shape_list: + print("Shapes where AutoHeuristics chose 'no_decision' (non-confident):") + for m, k, n, dtype in no_decision_shape_list: + dtype_str = str(dtype).replace("torch.", "") + print(f" M={m}, K={k}, N={n}, dtype={dtype_str}") + else: + print("All shapes had confident decisions!") + + # Print confident decisions by dtype + print("\n=== CONFIDENT DECISIONS BY DTYPE ===") + if confident_by_dtype: + print("Number of confident decisions per dtype:") + for dtype_str, count in sorted(confident_by_dtype.items()): + print(f" {dtype_str}: {count} confident decisions") + print(f"Total confident decisions: {sum(confident_by_dtype.values())}") + else: + print("No confident decisions made!") + + +if __name__ == "__main__": + main() diff --git a/torchgen/_autoheuristic/pad_mm/gen_data_pad_mm.py b/torchgen/_autoheuristic/pad_mm/gen_data_pad_mm.py index b476bacfb67db..eba7dd19b53b2 100644 --- a/torchgen/_autoheuristic/pad_mm/gen_data_pad_mm.py +++ b/torchgen/_autoheuristic/pad_mm/gen_data_pad_mm.py @@ -1,5 +1,7 @@ +import csv import random import sys +from collections.abc import Generator from pathlib import Path from typing import Any @@ -10,9 +12,13 @@ from benchmark_utils import ( # type: ignore[import-not-found] fits_in_memory, get_mm_tensors, + get_random_between_pow2, set_precision, transpose_tensors, ) +from collect_known_mm_shapes import ( + collect_known_mm_shapes, # type: ignore[import-not-found] +) import torch from torch._inductor.fx_passes.pad_mm import ( # type: ignore[import-not-found] @@ -29,10 +35,121 @@ class BenchmarkRunnerPadMM(BenchmarkRunner): # type: ignore[misc, no-any-unimpo def __init__(self) -> None: super().__init__("pad_mm") + # Add CLI argument for additional shape CSV files + self.parser.add_argument( + "--additional-shape-csv", + nargs="*", + default=[], + help="List of CSV files containing additional matrix multiplication shapes (M,K,N,dtype format)", + ) + + # Initialize additional_shape_collections + self.additional_shape_collections: list[ + list[tuple[int, int, int, torch.dtype, torch.dtype]] + ] = [] + + # Initialize the shape generator (will be set up after parsing args) + self.shape_generator = None + + def load_shapes_from_csv( + self, csv_file: str + ) -> list[tuple[int, int, int, torch.dtype, torch.dtype]]: + """Load matrix multiplication shapes from a CSV file in M,K,N,dtype format.""" + shapes = [] + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + + try: + with open(csv_file) as f: + reader = csv.DictReader(f) + for row in reader: + m = int(row["M"]) + k = int(row["K"]) + n = int(row["N"]) + dtype_str = row["dtype"] + + if dtype_str in dtype_map: + dtype = dtype_map[dtype_str] + # Store as (m, k, n, dtype1, dtype2) with same dtype for both + shapes.append((m, k, n, dtype, dtype)) + else: + print( + f"Warning: Unknown dtype '{dtype_str}' in {csv_file}, skipping row" + ) + + print(f"Loaded {len(shapes)} shapes from {csv_file}") + except Exception as e: + print(f"Error loading shapes from {csv_file}: {e}") + + return shapes + + def setup_shape_collections(self, csv_files: list[str]) -> None: + """Setup additional shape collections from CSV files and built-in collection.""" + self.additional_shape_collections = [] + + # Load shapes from provided CSV files first + for csv_file in csv_files: + shapes = self.load_shapes_from_csv(csv_file) + if shapes: + self.additional_shape_collections.append(shapes) + + self.additional_shape_collections.append(collect_known_mm_shapes()) + self.shape_generator = self.generate_mm_shapes() + + def generate_mm_shapes(self) -> Generator[tuple[int, int, int, Any], None, None]: + """Generator that yields (m, k, n, dtype) tuples for matrix multiplication. + + First exhausts all shapes from additional_shape_collections, then generates random shapes. + Only yields unaligned shapes since external CSV shapes may not be pre-filtered. + """ + # Phase 1: Use all shapes from additional shape collections + for collection in self.additional_shape_collections: + for m, k, n, dtype1, _ in collection: + # Filter for unaligned shapes only (external CSVs may not be pre-filtered) + align_size = get_alignment_size_dtype(dtype1) + if not all(self.is_aligned(dim, align_size) for dim in [m, k, n]): + # Check if it fits in memory + if fits_in_memory(dtype1, m, k, n): + yield (m, k, n, dtype1) + + # Phase 2: Generate infinite random shapes + + while True: + # Generate random dtype + dtype_choices = [torch.float16, torch.bfloat16, torch.float32] + dtype = random.choices(dtype_choices)[0] + + # Generate random shape for this dtype + uniform = random.choices([True, False])[0] + align_size = get_alignment_size_dtype(dtype) + + # Keep trying until we get a valid unaligned shape that fits in memory + while True: + if uniform: + m = random.randint(1, 65536) + k = random.randint(1, 65536) + n = random.randint(1, 65536) + else: + m = self.get_random_dim() + k = self.get_random_dim() + n = self.get_random_dim() + + # Skip if all dimensions are aligned (we need unaligned for padding to be relevant) + if all(self.is_aligned(dim, align_size) for dim in [m, k, n]): + continue + + # Check if it fits in memory + if fits_in_memory(dtype, m, k, n): + yield (m, k, n, dtype) + break + def create_input(self) -> tuple[Any, ...]: - dtype = self.get_dtype() + # Get the next shape from the generator + m, k, n, dtype = next(self.shape_generator) set_precision(dtype) - m, k, n = self.get_m_k_n(dtype) (transpose_left, transpose_right) = transpose_tensors() prepadded_left = self.prepadded() @@ -107,40 +224,45 @@ def get_random_dim( return 2 ** random.randint(min_power2, max_power2) # type: ignore[no-any-return] else: # choose a random number between 2^i and 2^(i+1) - return self.get_random_between_pow2(min_power2, max_power2) # type: ignore[no-any-return] + return get_random_between_pow2(min_power2, max_power2) # type: ignore[no-any-return] def is_aligned(self, dim: int, align_size: int) -> bool: return dim % align_size == 0 - def get_m_k_n(self, dtype: Any) -> tuple[int, int, int]: - uniform = random.choices([True, False])[0] - align_size = get_alignment_size_dtype(dtype) + def prepadded(self, p_prepadded: float = 0.2) -> bool: + # p_prepadded: probability that a tensor is "prepadded", i.e. pad_mm excludes time it takes to pad from benchmarking + return random.choices([True, False], [p_prepadded, 1 - p_prepadded])[0] + + def run(self) -> None: + """Override run to setup shape collections before running.""" + import time - # repeat until tensors fit in memory - while True: - if uniform: - m = random.randint(1, 65536) - k = random.randint(1, 65536) - n = random.randint(1, 65536) - else: - m = self.get_random_dim() - k = self.get_random_dim() - n = self.get_random_dim() + from tqdm import tqdm - if all(self.is_aligned(dim, align_size) for dim in [m, k, n]): - # skip if already aligned - continue + torch.set_default_device("cuda") + args = self.parser.parse_args() - if fits_in_memory(dtype, m, k, n): - return (m, k, n) + # Setup shape collections based on CLI arguments + self.setup_shape_collections(args.additional_shape_csv) - def prepadded(self, p_prepadded: float = 0.2) -> bool: - # p_prepadded: probability that a tensor is "prepadded", i.e. pad_mm excludes time it takes to pad from benchmarking - return random.choices([True, False], [p_prepadded, 1 - p_prepadded])[0] + # Set up torch configuration (copied from parent run method) + + if args.use_heuristic: + torch._inductor.config.autoheuristic_use.pad_mm = True + torch._inductor.config.autoheuristic_collect.pad_mm = False + else: + torch._inductor.config.autoheuristic_use.pad_mm = False + torch._inductor.config.autoheuristic_collect.pad_mm = True + torch._inductor.config.autoheuristic_log_path = args.o + if args.device is not None: + torch.cuda.set_device(args.device) + random.seed(time.time()) - def get_dtype(self) -> Any: - dtype_choices = [torch.float16, torch.bfloat16, torch.float32] - return random.choices(dtype_choices)[0] + # Run the main benchmarking loop + for _ in tqdm(range(args.num_samples)): + input = self.create_input() + for _ in range(args.num_reps): + self.run_benchmark(*input) if __name__ == "__main__": diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index 04af02ce848db..939c9fdc1cdd9 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -152,6 +152,12 @@ "aten.randn.default": {}, "aten.randn.generator": {}, "aten.randperm.default": {}, + "aten.rand_like.default": {}, + "aten.rand_like.generator": {}, + "aten.randint_like.default": {}, + "aten.randint_like.low_dtype": {}, + "aten.randn_like.default": {}, + "aten.randn_like.generator": {}, "aten.repeat_interleave.Tensor": {}, "aten.replication_pad1d_backward.default": {}, "aten.replication_pad2d_backward.default": {}, diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py index 23a5cd6b1b61e..477be3c528433 100644 --- a/torchgen/api/autograd.py +++ b/torchgen/api/autograd.py @@ -240,7 +240,7 @@ class DifferentiableInput: # Represents a differentiable `Return`. -# How it it different from the `Return` type? +# How is it different from the `Return` type? # - The name in `Return` is optional. Here it is always populated using the same # `cpp.return_names()` method. # TODO: some cpp naming logic (e.g. resolving name conflict) might be irrelevant? diff --git a/torchgen/api/python.py b/torchgen/api/python.py index ca971e854b234..9039e1c57f2c1 100644 --- a/torchgen/api/python.py +++ b/torchgen/api/python.py @@ -288,7 +288,7 @@ def argument_str_pyi( name += "_" # pyi merges the _out and functional variants into the same signature, with an optional out arg - if name == "out" and type_str == "Tensor" and not deprecated: + if name == "out" and not deprecated: type_str = f"{type_str} | None".replace(" | None | None", " | None") # pyi deprecated signatures don't get defaults for their out arg @@ -975,14 +975,17 @@ def argument_type_str_pyi(t: Type) -> str: if str(t.elem) == "int": ret = "_int | _size" if t.size is not None else "_size" elif t.is_tensor_like(): - # TODO: this doesn't seem right... - # Tensor?[] currently translates to tuple[Tensor, ...] | list[Tensor] | None - # It should probably translate to tuple[Tensor | None, ...] | list[Tensor | None] - add_optional = True + # Tensor?[] translates to tuple[Tensor | None, ...] | list[Tensor | None] | None + # Tensor[] translates to tuple[Tensor, ...] | list[Tensor] + if isinstance(t.elem, OptionalType): + add_optional = True + elem_str = "Tensor | None" + else: + elem_str = "Tensor" ret = ( - "Tensor | tuple[Tensor, ...] | list[Tensor]" + f"Tensor | tuple[{elem_str}, ...] | list[{elem_str}]" if t.size is not None - else "tuple[Tensor, ...] | list[Tensor]" + else f"tuple[{elem_str}, ...] | list[{elem_str}]" ) elif str(t.elem) == "float": ret = "Sequence[_float]" @@ -1473,7 +1476,7 @@ def dispatch_lambda_exprs( inits.extend( [ f"auto __{name} = {arg_parser_expr};", - f"::std::optional {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;", # noqa: B950 + f"::std::optional {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;", ] ) lambda_args_exprs[name] = name diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index 17d4d4a646a55..251ba64248a3c 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -37,6 +37,7 @@ NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup, + OperatorName, Return, SchemaKind, SelfArgument, @@ -76,6 +77,16 @@ ] ) +# Eager cumulative out variants compute in the out dtype when dtype is omitted. +# Functionalization normally lowers mutable ops through their functional variants, +# so these need to thread the out dtype explicitly to preserve eager semantics. +CUMULATIVE_OUT_OPS_PRESERVING_OUT_DTYPE = { + OperatorName.parse("cumsum.out"), + OperatorName.parse("cumprod.out"), + OperatorName.parse("cumsum.dimname_out"), + OperatorName.parse("cumprod.dimname_out"), +} + # This file contains codegen that relates to the functionalization pass. # It includes: # - gen_functionalization_definition @@ -607,6 +618,40 @@ def wrap_propagate_mutations_and_return( {returns_str}""" +def maybe_replace_cumulative_out_dtype_exprs( + f: NativeFunction, + functional_sig: DispatcherSignature, + functional_exprs: list[str], +) -> list[str]: + if ( + f.func.kind() != SchemaKind.out + or f.func.name not in CUMULATIVE_OUT_OPS_PRESERVING_OUT_DTYPE + ): + return functional_exprs + + if len(f.func.arguments.out) != 1: + raise AssertionError( + f"Expected a single out argument for cumulative out op: {f.func.name}" + ) + + dtype_arg_idx = next( + (i for i, arg in enumerate(functional_sig.arguments()) if arg.name == "dtype"), + None, + ) + if dtype_arg_idx is None: + raise AssertionError( + f"Expected dtype argument for cumulative out op: {f.func.name}" + ) + + adjusted_exprs = functional_exprs.copy() + dtype_expr = adjusted_exprs[dtype_arg_idx] + adjusted_exprs[dtype_arg_idx] = ( + f"{dtype_expr}.has_value() ? {dtype_expr} : " + f"std::optional({f.func.arguments.out[0].name}_.scalar_type())" + ) + return adjusted_exprs + + # Generates the Functionalization kernel for: # - mutation ops (inplace and out= ops) @with_native_function_and @@ -678,6 +723,9 @@ def emit_inplace_functionalization_body( e.expr for e in translate(unwrapped_args_ctx, functional_sig.arguments(), method=False) ] + functional_exprs = maybe_replace_cumulative_out_dtype_exprs( + f, functional_sig, functional_exprs + ) meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig) # We don't want to run the inplace meta func for ops like .set_(), because: diff --git a/torchgen/model.py b/torchgen/model.py index 54aeaab9fb9d5..395d5e0b8c363 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -755,6 +755,10 @@ def from_yaml( if namespace == "aten" and "pt2_compliant_tag" in valid_tags: tags_inp.append("pt2_compliant_tag") + # All out= ops receive the "out" tag. + if func.is_out_fn() and "out" in valid_tags: + tags_inp.append("out") + tags: set[str] = set() for t in tags_inp: if len(valid_tags) == 0: diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index 107802f60051e..b677e961a0cdb 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -238,9 +238,9 @@ def generate_out_args_from_schema( # Helper function: given a mutable FunctionSchema, generate its corresponding out= variant # Example before: -# _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950 +# _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # Example after: -# _fused_moving_avg_obs_fq_helper._out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) # noqa: B950 +# _fused_moving_avg_obs_fq_helper._out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema: # Generating an out= schema from a mutable schema. if func.kind() != SchemaKind.mutable: @@ -345,6 +345,8 @@ def generate_function( tags = {"generated"} | set( f.tags & {"nondeterministic_seeded", "view_copy", "pt2_compliant_tag"} ) + if func.is_out_fn(): + tags.add("out") return ( NativeFunction( diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py index 0521e49f77297..f5c42b4e22cad 100644 --- a/torchgen/operator_versions/gen_mobile_upgraders.py +++ b/torchgen/operator_versions/gen_mobile_upgraders.py @@ -100,7 +100,7 @@ class ByteCode(Enum): ONE_UPGRADER_IN_VERSION_MAP = CodeTemplate( """Upgrader({${upgrader_min_version}, ${upgrader_max_version}, "${upgrader_name}", ${bytecode_func_index}})""" -) # noqa: E501 +) ONE_OPERATOR_IN_VERSION_MAP = CodeTemplate( """ diff --git a/torchgen/static_runtime/generator.py b/torchgen/static_runtime/generator.py index e15b2514830ac..d8aba4d13bcde 100644 --- a/torchgen/static_runtime/generator.py +++ b/torchgen/static_runtime/generator.py @@ -270,7 +270,7 @@ def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: # the string, just test the dang thing directly if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type(): # Returns a non-Tensor value. - logger.info("NON-TENSOR RET TYPE: %s", str(func)) + logger.info("NON-TENSOR RET TYPE: %s", func) return False return True diff --git a/version.txt b/version.txt index f5c124eae685e..5bd2aacf8ee24 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -2.12.0a0 +2.13.0a0